diff --git a/cli/src/init.rs b/cli/src/init.rs index 10ac73a6..709612af 100644 --- a/cli/src/init.rs +++ b/cli/src/init.rs @@ -666,15 +666,10 @@ pub fn generate_grammar_files( allow_update, |path| generate_file(path, SETUP_PY_TEMPLATE, language_name, &generate_opts), |path| { - let mut contents = fs::read_to_string(path)?; - if !contents.contains("egg_info") || !contents.contains("Py_GIL_DISABLED") { + let contents = fs::read_to_string(path)?; + if !contents.contains("build_ext") { eprintln!("Replacing setup.py"); generate_file(path, SETUP_PY_TEMPLATE, language_name, &generate_opts)?; - } else { - contents = contents - .replace("path\nfrom platform import system", "name as os_name, path") - .replace("system() != \"Windows\"", "os_name != \"nt\""); - write_file(path, contents)?; } Ok(()) }, diff --git a/cli/src/templates/setup.py b/cli/src/templates/setup.py index 9fdc3098..7f92eaee 100644 --- a/cli/src/templates/setup.py +++ b/cli/src/templates/setup.py @@ -1,30 +1,12 @@ -from os import name as os_name, path +from os import path from sysconfig import get_config_var from setuptools import Extension, find_packages, setup from setuptools.command.build import build +from setuptools.command.build_ext import build_ext from setuptools.command.egg_info import egg_info from wheel.bdist_wheel import bdist_wheel -sources = [ - "bindings/python/tree_sitter_LOWER_PARSER_NAME/binding.c", - "src/parser.c", -] -if path.exists("src/scanner.c"): - sources.append("src/scanner.c") - -macros: list[tuple[str, str | None]] = [ - ("PY_SSIZE_T_CLEAN", None), - ("TREE_SITTER_HIDE_SYMBOLS", None), -] -if limited_api := not get_config_var("Py_GIL_DISABLED"): - macros.append(("Py_LIMITED_API", "0x030A0000")) - -if os_name != "nt": - cflags = ["-std=c11", "-fvisibility=hidden"] -else: - cflags = ["/std:c11", "/utf-8"] - class Build(build): def run(self): @@ -34,6 +16,19 @@ class Build(build): super().run() +class BuildExt(build_ext): + def build_extension(self, ext: Extension): + if self.compiler.compiler_type != "msvc": + ext.extra_compile_args = ["-std=c11", "-fvisibility=hidden"] + else: + ext.extra_compile_args = ["/std:c11", "/utf-8"] + if path.exists("src/scanner.c"): + ext.sources.append("src/scanner.c") + if ext.py_limited_api: + ext.define_macros.append(("Py_LIMITED_API", "0x030A0000")) + super().build_extension(ext) + + class BdistWheel(bdist_wheel): def get_tag(self): python, abi, platform = super().get_tag() @@ -60,15 +55,21 @@ setup( ext_modules=[ Extension( name="_binding", - sources=sources, - extra_compile_args=cflags, - define_macros=macros, + sources=[ + "bindings/python/tree_sitter_LOWER_PARSER_NAME/binding.c", + "src/parser.c", + ], + define_macros=[ + ("PY_SSIZE_T_CLEAN", None), + ("TREE_SITTER_HIDE_SYMBOLS", None), + ], include_dirs=["src"], - py_limited_api=limited_api, + py_limited_api=not get_config_var("Py_GIL_DISABLED"), ) ], cmdclass={ "build": Build, + "build_ext": BuildExt, "bdist_wheel": BdistWheel, "egg_info": EggInfo, },