diff --git a/cli/src/init.rs b/cli/src/init.rs index dffc0294..a4959b53 100644 --- a/cli/src/init.rs +++ b/cli/src/init.rs @@ -608,14 +608,32 @@ pub fn generate_grammar_files( })?; missing_path(path.join("tests"), create_dir)?.apply(|path| { - missing_path(path.join("test_binding.py"), |path| { - generate_file( - path, - TEST_BINDING_PY_TEMPLATE, - language_name, - &generate_opts, - ) - })?; + missing_path_else( + path.join("test_binding.py"), + allow_update, + |path| { + generate_file( + path, + TEST_BINDING_PY_TEMPLATE, + language_name, + &generate_opts, + ) + }, + |path| { + let mut contents = fs::read_to_string(path)?; + if !contents.contains("Parser(Language(") { + contents = contents + .replace("tree_sitter.Language(", "Parser(Language(") + .replace(".language())\n", ".language()))\n") + .replace( + "import tree_sitter\n", + "from tree_sitter import Language, Parser\n", + ); + write_file(path, contents)?; + } + Ok(()) + }, + )?; Ok(()) })?; diff --git a/cli/src/templates/test_binding.py b/cli/src/templates/test_binding.py index 31aef9ac..a832c368 100644 --- a/cli/src/templates/test_binding.py +++ b/cli/src/templates/test_binding.py @@ -1,12 +1,12 @@ from unittest import TestCase -import tree_sitter +from tree_sitter import Language, Parser import tree_sitter_LOWER_PARSER_NAME class TestLanguage(TestCase): def test_can_load_grammar(self): try: - tree_sitter.Language(tree_sitter_LOWER_PARSER_NAME.language()) + Parser(Language(tree_sitter_LOWER_PARSER_NAME.language())) except Exception: self.fail("Error loading TITLE_PARSER_NAME grammar")