diff --git a/crates/cli/src/init.rs b/crates/cli/src/init.rs index 2f3792a6..8e9cad06 100644 --- a/crates/cli/src/init.rs +++ b/crates/cli/src/init.rs @@ -646,9 +646,24 @@ pub fn generate_grammar_files( generate_file(path, INIT_PY_TEMPLATE, language_name, &generate_opts) })?; - missing_path(lang_path.join("__init__.pyi"), |path| { - generate_file(path, INIT_PYI_TEMPLATE, language_name, &generate_opts) - })?; + missing_path_else( + lang_path.join("__init__.pyi"), + allow_update, + |path| generate_file(path, INIT_PYI_TEMPLATE, language_name, &generate_opts), + |path| { + let mut contents = fs::read_to_string(path)?; + if !contents.contains("CapsuleType") { + contents = contents + .replace( + "from typing import Final", + "from typing import Final\nfrom typing_extensions import CapsuleType" + ) + .replace("-> object:", "-> CapsuleType:"); + write_file(path, contents)?; + } + Ok(()) + }, + )?; missing_path(lang_path.join("py.typed"), |path| { generate_file(path, "", language_name, &generate_opts) // py.typed is empty diff --git a/crates/cli/src/templates/__init__.pyi b/crates/cli/src/templates/__init__.pyi index abf6633f..5c63215d 100644 --- a/crates/cli/src/templates/__init__.pyi +++ b/crates/cli/src/templates/__init__.pyi @@ -1,4 +1,5 @@ from typing import Final +from typing_extensions import CapsuleType # NOTE: uncomment these to include any queries that this grammar contains: @@ -7,4 +8,4 @@ from typing import Final # LOCALS_QUERY: Final[str] # TAGS_QUERY: Final[str] -def language() -> object: ... +def language() -> CapsuleType: ...