diff --git a/cli/loader/src/lib.rs b/cli/loader/src/lib.rs index f5dffa02..d556aa27 100644 --- a/cli/loader/src/lib.rs +++ b/cli/loader/src/lib.rs @@ -87,7 +87,7 @@ const BUILD_TARGET: &str = env!("BUILD_TARGET"); pub struct LanguageConfiguration<'a> { pub scope: Option, pub content_regex: Option, - pub _first_line_regex: Option, + pub first_line_regex: Option, pub injection_regex: Option, pub file_types: Vec, pub root_path: PathBuf, @@ -109,6 +109,7 @@ pub struct Loader { language_configurations: Vec>, language_configuration_ids_by_file_type: HashMap>, language_configuration_in_current_path: Option, + language_configuration_ids_by_first_line_regex: HashMap>, highlight_names: Box>>, use_all_highlight_names: bool, debug_build: bool, @@ -140,6 +141,7 @@ impl Loader { language_configurations: Vec::new(), language_configuration_ids_by_file_type: HashMap::new(), language_configuration_in_current_path: None, + language_configuration_ids_by_first_line_regex: HashMap::new(), highlight_names: Box::new(Mutex::new(Vec::new())), use_all_highlight_names: true, debug_build: false, @@ -241,6 +243,26 @@ impl Loader { .and_then(|extension| { self.language_configuration_ids_by_file_type.get(extension) }) + }) + .or_else(|| { + let Ok(file) = fs::File::open(path) else { + return None; + }; + let reader = BufReader::new(file); + let Some(Ok(first_line)) = std::io::BufRead::lines(reader).next() else { + return None; + }; + + self.language_configuration_ids_by_first_line_regex + .iter() + .find(|(regex, _)| { + if let Some(regex) = Self::regex(Some(regex)) { + regex.is_match(&first_line) + } else { + false + } + }) + .map(|(_, ids)| ids) }); if let Some(configuration_ids) = configuration_ids { @@ -871,9 +893,9 @@ impl Loader { scope: config_json.scope, language_id, file_types: config_json.file_types.unwrap_or(Vec::new()), - content_regex: Self::regex(config_json.content_regex), - _first_line_regex: Self::regex(config_json.first_line_regex), - injection_regex: Self::regex(config_json.injection_regex), + content_regex: Self::regex(config_json.content_regex.as_deref()), + first_line_regex: Self::regex(config_json.first_line_regex.as_deref()), + injection_regex: Self::regex(config_json.injection_regex.as_deref()), injections_filenames: config_json.injections.into_vec(), locals_filenames: config_json.locals.into_vec(), tags_filenames: config_json.tags.into_vec(), @@ -890,6 +912,12 @@ impl Loader { .or_default() .push(self.language_configurations.len()); } + if let Some(first_line_regex) = &configuration.first_line_regex { + self.language_configuration_ids_by_first_line_regex + .entry(first_line_regex.to_string()) + .or_default() + .push(self.language_configurations.len()); + } self.language_configurations .push(unsafe { mem::transmute(configuration) }); @@ -920,7 +948,7 @@ impl Loader { file_types: Vec::new(), scope: None, content_regex: None, - _first_line_regex: None, + first_line_regex: None, injection_regex: None, injections_filenames: None, locals_filenames: None, @@ -940,8 +968,8 @@ impl Loader { Ok(&self.language_configurations[initial_language_configuration_count..]) } - fn regex(pattern: Option) -> Option { - pattern.and_then(|r| RegexBuilder::new(&r).multi_line(true).build().ok()) + fn regex(pattern: Option<&str>) -> Option { + pattern.and_then(|r| RegexBuilder::new(r).multi_line(true).build().ok()) } pub fn select_language( diff --git a/cli/src/tests/detect_language.rs b/cli/src/tests/detect_language.rs new file mode 100644 index 00000000..d28522a0 --- /dev/null +++ b/cli/src/tests/detect_language.rs @@ -0,0 +1,122 @@ +use crate::tests::helpers::fixtures::scratch_dir; + +use std::path::Path; +use tree_sitter_loader::Loader; + +#[test] +fn detect_language_by_first_line_regex() { + let strace_dir = tree_sitter_dir( + r#"{ + "name": "tree-sitter-strace", + "version": "0.0.1", + "tree-sitter": [ + { + "scope": "source.strace", + "file-types": [ + "strace" + ], + "first-line-regex": "[0-9:.]* *execve" + } + ] +} +"#, + "strace", + ); + + let mut loader = Loader::with_parser_lib_path(scratch_dir().to_path_buf()); + let config = loader + .find_language_configurations_at_path(strace_dir.path(), false) + .unwrap(); + + // this is just to validate that we can read the package.json correctly + assert_eq!(config[0].scope.as_ref().unwrap(), "source.strace"); + + let file_name = strace_dir.path().join("strace.log"); + std::fs::write(&file_name, "execve\nworld").unwrap(); + assert_eq!( + get_lang_scope(&mut loader, &file_name), + Some("source.strace".into()) + ); + + let file_name = strace_dir.path().join("strace.log"); + std::fs::write(&file_name, "447845 execve\nworld").unwrap(); + assert_eq!( + get_lang_scope(&mut loader, &file_name), + Some("source.strace".into()) + ); + + let file_name = strace_dir.path().join("strace.log"); + std::fs::write(&file_name, "hello\nexecve").unwrap(); + assert!(get_lang_scope(&mut loader, &file_name).is_none()); + + let file_name = strace_dir.path().join("strace.log"); + std::fs::write(&file_name, "").unwrap(); + assert!(get_lang_scope(&mut loader, &file_name).is_none()); + + let dummy_dir = tree_sitter_dir( + r#"{ + "name": "tree-sitter-dummy", + "version": "0.0.1", + "tree-sitter": [ + { + "scope": "source.dummy", + "file-types": [ + "dummy" + ] + } + ] +} +"#, + "dummy", + ); + + // file-type takes precedence over first-line-regex + loader + .find_language_configurations_at_path(dummy_dir.path(), false) + .unwrap(); + let file_name = dummy_dir.path().join("strace.dummy"); + std::fs::write(&file_name, "execve").unwrap(); + assert_eq!( + get_lang_scope(&mut loader, &file_name), + Some("source.dummy".into()) + ); +} + +fn tree_sitter_dir(package_json: &str, name: &str) -> tempfile::TempDir { + let temp_dir = tempfile::tempdir().unwrap(); + std::fs::write(temp_dir.path().join("package.json"), package_json).unwrap(); + std::fs::create_dir(temp_dir.path().join("src")).unwrap(); + std::fs::create_dir(temp_dir.path().join("src/tree_sitter")).unwrap(); + std::fs::write( + temp_dir.path().join("src/grammar.json"), + format!(r#"{{"name":"{name}"}}"#), + ) + .unwrap(); + std::fs::write( + temp_dir.path().join("src/parser.c"), + format!( + r##" + #include "tree_sitter/parser.h" + #ifdef _WIN32 + #define extern __declspec(dllexport) + #endif + extern const TSLanguage *tree_sitter_{name}(void) {{}} + "## + ), + ) + .unwrap(); + std::fs::write( + temp_dir.path().join("src/tree_sitter/parser.h"), + include_str!("../../../lib/src/parser.h"), + ) + .unwrap(); + temp_dir +} + +// if we manage to get the language scope, it means we correctly detected the file-type +fn get_lang_scope(loader: &mut Loader, file_name: &Path) -> Option { + loader + .language_configuration_for_file_name(file_name) + .unwrap() + .and_then(|r| r.1.scope.clone()) +} diff --git a/cli/src/tests/helpers/fixtures.rs b/cli/src/tests/helpers/fixtures.rs index bf186d5f..6a04d4c7 100644 --- a/cli/src/tests/helpers/fixtures.rs +++ b/cli/src/tests/helpers/fixtures.rs @@ -27,6 +27,10 @@ pub fn fixtures_dir() -> &'static Path { &FIXTURES_DIR } +pub fn scratch_dir() -> &'static Path { + &SCRATCH_DIR +} + pub fn get_language(name: &str) -> Language { TEST_LOADER .load_language_at_path( diff --git a/cli/src/tests/mod.rs b/cli/src/tests/mod.rs index e09dc838..8630c950 100644 --- a/cli/src/tests/mod.rs +++ b/cli/src/tests/mod.rs @@ -1,5 +1,6 @@ mod async_context_test; mod corpus_test; +mod detect_language; mod github_issue_test; mod helpers; mod highlight_test;