diff --git a/Cargo.lock b/Cargo.lock index e57779cb..f539e181 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5,7 +5,7 @@ name = "aho-corasick" version = "0.6.9" source = "registry+https://github.com/rust-lang/crates.io-index" dependencies = [ - "memchr 2.1.1 (registry+https://github.com/rust-lang/crates.io-index)", + "memchr 2.3.3 (registry+https://github.com/rust-lang/crates.io-index)", ] [[package]] @@ -303,13 +303,8 @@ source = "registry+https://github.com/rust-lang/crates.io-index" [[package]] name = "memchr" -version = "2.1.1" +version = "2.3.3" source = "registry+https://github.com/rust-lang/crates.io-index" -dependencies = [ - "cfg-if 0.1.6 (registry+https://github.com/rust-lang/crates.io-index)", - "libc 0.2.61 (registry+https://github.com/rust-lang/crates.io-index)", - "version_check 0.1.5 (registry+https://github.com/rust-lang/crates.io-index)", -] [[package]] name = "nodrop" @@ -536,7 +531,7 @@ version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" dependencies = [ "aho-corasick 0.6.9 (registry+https://github.com/rust-lang/crates.io-index)", - "memchr 2.1.1 (registry+https://github.com/rust-lang/crates.io-index)", + "memchr 2.3.3 (registry+https://github.com/rust-lang/crates.io-index)", "regex-syntax 0.6.4 (registry+https://github.com/rust-lang/crates.io-index)", "thread_local 0.3.6 (registry+https://github.com/rust-lang/crates.io-index)", "utf8-ranges 1.0.2 (registry+https://github.com/rust-lang/crates.io-index)", @@ -769,6 +764,7 @@ dependencies = [ "tiny_http 0.6.2 (registry+https://github.com/rust-lang/crates.io-index)", "tree-sitter 0.6.3", "tree-sitter-highlight 0.1.6", + "tree-sitter-tags 0.1.6", "webbrowser 0.5.1 (registry+https://github.com/rust-lang/crates.io-index)", ] @@ -777,9 +773,15 @@ name = "tree-sitter-highlight" version = "0.1.6" dependencies = [ "regex 1.1.0 (registry+https://github.com/rust-lang/crates.io-index)", - "serde 1.0.80 (registry+https://github.com/rust-lang/crates.io-index)", - "serde_derive 
1.0.80 (registry+https://github.com/rust-lang/crates.io-index)", - "serde_json 1.0.33 (registry+https://github.com/rust-lang/crates.io-index)", + "tree-sitter 0.6.3", +] + +[[package]] +name = "tree-sitter-tags" +version = "0.1.6" +dependencies = [ + "memchr 2.3.3 (registry+https://github.com/rust-lang/crates.io-index)", + "regex 1.1.0 (registry+https://github.com/rust-lang/crates.io-index)", "tree-sitter 0.6.3", ] @@ -842,11 +844,6 @@ name = "vec_map" version = "0.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" -[[package]] -name = "version_check" -version = "0.1.5" -source = "registry+https://github.com/rust-lang/crates.io-index" - [[package]] name = "void" version = "1.0.2" @@ -926,7 +923,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" "checksum lock_api 0.1.5 (registry+https://github.com/rust-lang/crates.io-index)" = "62ebf1391f6acad60e5c8b43706dde4582df75c06698ab44511d15016bc2442c" "checksum log 0.4.6 (registry+https://github.com/rust-lang/crates.io-index)" = "c84ec4b527950aa83a329754b01dbe3f58361d1c5efacd1f6d68c494d08a17c6" "checksum matches 0.1.8 (registry+https://github.com/rust-lang/crates.io-index)" = "7ffc5c5338469d4d3ea17d269fa8ea3512ad247247c30bd2df69e68309ed0a08" -"checksum memchr 2.1.1 (registry+https://github.com/rust-lang/crates.io-index)" = "0a3eb002f0535929f1199681417029ebea04aadc0c7a4224b46be99c7f5d6a16" +"checksum memchr 2.3.3 (registry+https://github.com/rust-lang/crates.io-index)" = "3728d817d99e5ac407411fa471ff9800a778d88a24685968b36824eaf4bee400" "checksum nodrop 0.1.13 (registry+https://github.com/rust-lang/crates.io-index)" = "2f9667ddcc6cc8a43afc9b7917599d7216aa09c463919ea32c59ed6cac8bc945" "checksum num-integer 0.1.39 (registry+https://github.com/rust-lang/crates.io-index)" = "e83d528d2677f0518c570baf2b7abdcf0cd2d248860b68507bdcb3e91d4c0cea" "checksum num-traits 0.2.6 (registry+https://github.com/rust-lang/crates.io-index)" = "0b3a5d7cc97d6d30d8b9bc8fa19bf45349ffe46241e8816f50f62f6d6aaabee1" 
@@ -987,7 +984,6 @@ source = "registry+https://github.com/rust-lang/crates.io-index" "checksum url 1.7.2 (registry+https://github.com/rust-lang/crates.io-index)" = "dd4e7c0d531266369519a4aa4f399d748bd37043b00bde1e4ff1f60a120b355a" "checksum utf8-ranges 1.0.2 (registry+https://github.com/rust-lang/crates.io-index)" = "796f7e48bef87609f7ade7e06495a87d5cd06c7866e6a5cbfceffc558a243737" "checksum vec_map 0.8.1 (registry+https://github.com/rust-lang/crates.io-index)" = "05c78687fb1a80548ae3250346c3db86a80a7cdd77bda190189f2d0a0987c81a" -"checksum version_check 0.1.5 (registry+https://github.com/rust-lang/crates.io-index)" = "914b1a6776c4c929a602fafd8bc742e06365d4bcbe48c30f9cca5824f70dc9dd" "checksum void 1.0.2 (registry+https://github.com/rust-lang/crates.io-index)" = "6a02e4885ed3bc0f2de90ea6dd45ebcbb66dacffe03547fadbb0eeae2770887d" "checksum webbrowser 0.5.1 (registry+https://github.com/rust-lang/crates.io-index)" = "c01efd7cb6939b7f34983f1edff0550e5b21b49e2db4495656295922df8939ac" "checksum widestring 0.4.0 (registry+https://github.com/rust-lang/crates.io-index)" = "effc0e4ff8085673ea7b9b2e3c73f6bd4d118810c9009ed8f1e16bd96c331db6" diff --git a/cli/Cargo.toml b/cli/Cargo.toml index 4304d1b6..27706945 100644 --- a/cli/Cargo.toml +++ b/cli/Cargo.toml @@ -44,6 +44,10 @@ path = "../lib" version = ">= 0.1.0" path = "../highlight" +[dependencies.tree-sitter-tags] +version = ">= 0.1.0" +path = "../tags" + [dependencies.serde_json] version = "1.0" features = ["preserve_order"] diff --git a/cli/src/error.rs b/cli/src/error.rs index 73dcb732..824bd92f 100644 --- a/cli/src/error.rs +++ b/cli/src/error.rs @@ -81,6 +81,12 @@ impl<'a> From for Error { } } +impl<'a> From for Error { + fn from(error: tree_sitter_tags::Error) -> Self { + Error::new(format!("{:?}", error)) + } +} + impl From for Error { fn from(error: serde_json::Error) -> Self { Error::new(error.to_string()) diff --git a/cli/src/generate/render.rs b/cli/src/generate/render.rs index 824c3bcf..e8c59d07 100644 --- 
a/cli/src/generate/render.rs +++ b/cli/src/generate/render.rs @@ -325,12 +325,13 @@ impl Generator { add_line!(self, "static TSSymbol ts_symbol_map[] = {{"); indent!(self); for symbol in &self.parse_table.symbols { + let mut mapping = symbol; + // There can be multiple symbols in the grammar that have the same name and kind, // due to simple aliases. When that happens, ensure that they map to the same // public-facing symbol. If one of the symbols is not aliased, choose that one // to be the public-facing symbol. Otherwise, pick the symbol with the lowest // numeric value. - let mut mapping = symbol; if let Some(alias) = self.simple_aliases.get(symbol) { let kind = alias.kind(); for other_symbol in &self.parse_table.symbols { @@ -344,6 +345,20 @@ impl Generator { } } } + // Two anonymous tokens with different flags but the same string value + // should be represented with the same symbol in the public API. Examples: + // * "<" and token(prec(1, "<")) + // * "(" and token.immediate("(") + else if symbol.is_terminal() { + let metadata = self.metadata_for_symbol(*symbol); + for other_symbol in &self.parse_table.symbols { + let other_metadata = self.metadata_for_symbol(*other_symbol); + if other_metadata == metadata { + mapping = other_symbol; + break; + } + } + } add_line!( self, diff --git a/cli/src/highlight.rs b/cli/src/highlight.rs index c80e6083..c6b1193d 100644 --- a/cli/src/highlight.rs +++ b/cli/src/highlight.rs @@ -1,3 +1,4 @@ +use super::util; use crate::error::Result; use crate::loader::Loader; use ansi_term::Color; @@ -6,10 +7,8 @@ use serde::ser::SerializeMap; use serde::{Deserialize, Deserializer, Serialize, Serializer}; use serde_json::{json, Value}; use std::collections::HashMap; -use std::sync::atomic::{AtomicUsize, Ordering}; -use std::sync::Arc; use std::time::Instant; -use std::{fs, io, path, str, thread, usize}; +use std::{fs, io, path, str, usize}; use tree_sitter_highlight::{HighlightConfiguration, HighlightEvent, Highlighter, HtmlRenderer}; pub 
const HTML_HEADER: &'static str = " @@ -273,19 +272,6 @@ fn color_to_css(color: Color) -> &'static str { } } -fn cancel_on_stdin() -> Arc { - let result = Arc::new(AtomicUsize::new(0)); - thread::spawn({ - let flag = result.clone(); - move || { - let mut line = String::new(); - io::stdin().read_line(&mut line).unwrap(); - flag.store(1, Ordering::Relaxed); - } - }); - result -} - pub fn ansi( loader: &Loader, theme: &Theme, @@ -296,7 +282,7 @@ pub fn ansi( let stdout = io::stdout(); let mut stdout = stdout.lock(); let time = Instant::now(); - let cancellation_flag = cancel_on_stdin(); + let cancellation_flag = util::cancel_on_stdin(); let mut highlighter = Highlighter::new(); let events = highlighter.highlight(config, source, Some(&cancellation_flag), |string| { @@ -341,7 +327,7 @@ pub fn html( let stdout = io::stdout(); let mut stdout = stdout.lock(); let time = Instant::now(); - let cancellation_flag = cancel_on_stdin(); + let cancellation_flag = util::cancel_on_stdin(); let mut highlighter = Highlighter::new(); let events = highlighter.highlight(config, source, Some(&cancellation_flag), |string| { diff --git a/cli/src/lib.rs b/cli/src/lib.rs index 945fe339..97c288a1 100644 --- a/cli/src/lib.rs +++ b/cli/src/lib.rs @@ -6,6 +6,7 @@ pub mod loader; pub mod logger; pub mod parse; pub mod query; +pub mod tags; pub mod test; pub mod test_highlight; pub mod util; diff --git a/cli/src/loader.rs b/cli/src/loader.rs index 1f9a1978..cf2eb143 100644 --- a/cli/src/loader.rs +++ b/cli/src/loader.rs @@ -12,6 +12,7 @@ use std::time::SystemTime; use std::{fs, mem}; use tree_sitter::Language; use tree_sitter_highlight::HighlightConfiguration; +use tree_sitter_tags::TagsConfiguration; #[cfg(unix)] const DYLIB_EXTENSION: &'static str = "so"; @@ -31,8 +32,10 @@ pub struct LanguageConfiguration<'a> { pub highlights_filenames: Option>, pub injections_filenames: Option>, pub locals_filenames: Option>, + pub tags_filenames: Option>, language_id: usize, highlight_config: OnceCell>, + 
tags_config: OnceCell>, highlight_names: &'a Mutex>, use_all_highlight_names: bool, } @@ -432,6 +435,8 @@ impl Loader { injections: PathsJSON, #[serde(default)] locals: PathsJSON, + #[serde(default)] + tags: PathsJSON, } #[derive(Deserialize)] @@ -479,8 +484,10 @@ impl Loader { injection_regex: Self::regex(config_json.injection_regex), injections_filenames: config_json.injections.into_vec(), locals_filenames: config_json.locals.into_vec(), + tags_filenames: config_json.tags.into_vec(), highlights_filenames: config_json.highlights.into_vec(), highlight_config: OnceCell::new(), + tags_config: OnceCell::new(), highlight_names: &*self.highlight_names, use_all_highlight_names: self.use_all_highlight_names, }; @@ -512,7 +519,9 @@ impl Loader { injections_filenames: None, locals_filenames: None, highlights_filenames: None, + tags_filenames: None, highlight_config: OnceCell::new(), + tags_config: OnceCell::new(), highlight_names: &*self.highlight_names, use_all_highlight_names: self.use_all_highlight_names, }; @@ -534,32 +543,11 @@ impl<'a> LanguageConfiguration<'a> { pub fn highlight_config(&self, language: Language) -> Result> { self.highlight_config .get_or_try_init(|| { - let queries_path = self.root_path.join("queries"); - let read_queries = |paths: &Option>, default_path: &str| { - if let Some(paths) = paths.as_ref() { - let mut query = String::new(); - for path in paths { - let path = self.root_path.join(path); - query += &fs::read_to_string(&path).map_err(Error::wrap(|| { - format!("Failed to read query file {:?}", path) - }))?; - } - Ok(query) - } else { - let path = queries_path.join(default_path); - if path.exists() { - fs::read_to_string(&path).map_err(Error::wrap(|| { - format!("Failed to read query file {:?}", path) - })) - } else { - Ok(String::new()) - } - } - }; - - let highlights_query = read_queries(&self.highlights_filenames, "highlights.scm")?; - let injections_query = read_queries(&self.injections_filenames, "injections.scm")?; - let locals_query = 
read_queries(&self.locals_filenames, "locals.scm")?; + let highlights_query = + self.read_queries(&self.highlights_filenames, "highlights.scm")?; + let injections_query = + self.read_queries(&self.injections_filenames, "injections.scm")?; + let locals_query = self.read_queries(&self.locals_filenames, "locals.scm")?; if highlights_query.is_empty() { Ok(None) @@ -587,6 +575,47 @@ impl<'a> LanguageConfiguration<'a> { }) .map(Option::as_ref) } + + pub fn tags_config(&self, language: Language) -> Result> { + self.tags_config + .get_or_try_init(|| { + let tags_query = self.read_queries(&self.tags_filenames, "tags.scm")?; + let locals_query = self.read_queries(&self.locals_filenames, "locals.scm")?; + if tags_query.is_empty() { + Ok(None) + } else { + TagsConfiguration::new(language, &tags_query, &locals_query) + .map_err(Error::wrap(|| { + format!("Failed to load queries in {:?}", self.root_path) + })) + .map(|config| Some(config)) + } + }) + .map(Option::as_ref) + } + + fn read_queries(&self, paths: &Option>, default_path: &str) -> Result { + if let Some(paths) = paths.as_ref() { + let mut query = String::new(); + for path in paths { + let path = self.root_path.join(path); + query += &fs::read_to_string(&path).map_err(Error::wrap(|| { + format!("Failed to read query file {:?}", path) + }))?; + } + Ok(query) + } else { + let queries_path = self.root_path.join("queries"); + let path = queries_path.join(default_path); + if path.exists() { + fs::read_to_string(&path).map_err(Error::wrap(|| { + format!("Failed to read query file {:?}", path) + })) + } else { + Ok(String::new()) + } + } + } } fn needs_recompile( diff --git a/cli/src/main.rs b/cli/src/main.rs index 79d310fe..c5c0e0e0 100644 --- a/cli/src/main.rs +++ b/cli/src/main.rs @@ -6,8 +6,8 @@ use std::process::exit; use std::{env, fs, u64}; use tree_sitter::Language; use tree_sitter_cli::{ - config, error, generate, highlight, loader, logger, parse, query, test, test_highlight, wasm, - web_ui, + config, error, generate, 
highlight, loader, logger, parse, query, tags, test, test_highlight, + wasm, web_ui, }; const BUILD_VERSION: &'static str = env!("CARGO_PKG_VERSION"); @@ -88,6 +88,30 @@ fn run() -> error::Result<()> { .arg(Arg::with_name("scope").long("scope").takes_value(true)) .arg(Arg::with_name("captures").long("captures").short("c")), ) + .subcommand( + SubCommand::with_name("tags") + .arg( + Arg::with_name("format") + .short("f") + .long("format") + .value_name("json|protobuf") + .help("Determine output format (default: json)"), + ) + .arg(Arg::with_name("scope").long("scope").takes_value(true)) + .arg( + Arg::with_name("inputs") + .help("The source file to use") + .index(1) + .required(true) + .multiple(true), + ) + .arg( + Arg::with_name("v") + .short("v") + .multiple(true) + .help("Sets the level of verbosity"), + ), + ) .subcommand( SubCommand::with_name("test") .about("Run a parser's tests") @@ -240,6 +264,10 @@ fn run() -> error::Result<()> { )?; let query_path = Path::new(matches.value_of("query-path").unwrap()); query::query_files_at_paths(language, paths, query_path, ordered_captures)?; + } else if let Some(matches) = matches.subcommand_matches("tags") { + loader.find_all_languages(&config.parser_directories)?; + let paths = collect_paths(matches.values_of("inputs").unwrap())?; + tags::generate_tags(&loader, matches.value_of("scope"), &paths)?; } else if let Some(matches) = matches.subcommand_matches("highlight") { loader.configure_highlights(&config.theme.highlight_names); loader.find_all_languages(&config.parser_directories)?; @@ -251,19 +279,17 @@ fn run() -> error::Result<()> { println!("{}", highlight::HTML_HEADER); } - let language_config; + let mut lang = None; if let Some(scope) = matches.value_of("scope") { - language_config = loader.language_configuration_for_scope(scope)?; - if language_config.is_none() { + lang = loader.language_configuration_for_scope(scope)?; + if lang.is_none() { return Error::err(format!("Unknown scope '{}'", scope)); } - } else { - 
language_config = None; } for path in paths { let path = Path::new(&path); - let (language, language_config) = match language_config { + let (language, language_config) = match lang { Some(v) => v, None => match loader.language_configuration_for_file_name(path)? { Some(v) => v, @@ -274,23 +300,21 @@ fn run() -> error::Result<()> { }, }; - let source = fs::read(path)?; - if let Some(highlight_config) = language_config.highlight_config(language)? { + let source = fs::read(path)?; if html_mode { highlight::html(&loader, &config.theme, &source, highlight_config, time)?; } else { highlight::ansi(&loader, &config.theme, &source, highlight_config, time)?; } } else { - return Error::err(format!("No syntax highlighting query found")); + eprintln!("No syntax highlighting config found for path {:?}", path); } } if html_mode { println!("{}", highlight::HTML_FOOTER); } - } else if let Some(matches) = matches.subcommand_matches("build-wasm") { let grammar_path = current_dir.join(matches.value_of("path").unwrap_or("")); wasm::compile_language_to_wasm(&grammar_path, matches.is_present("docker"))?; diff --git a/cli/src/tags.rs b/cli/src/tags.rs new file mode 100644 index 00000000..d6704ec5 --- /dev/null +++ b/cli/src/tags.rs @@ -0,0 +1,66 @@ +use super::loader::Loader; +use super::util; +use crate::error::{Error, Result}; +use std::io::{self, Write}; +use std::path::Path; +use std::{fs, str}; +use tree_sitter_tags::TagsContext; + +pub fn generate_tags(loader: &Loader, scope: Option<&str>, paths: &[String]) -> Result<()> { + let mut lang = None; + if let Some(scope) = scope { + lang = loader.language_configuration_for_scope(scope)?; + if lang.is_none() { + return Error::err(format!("Unknown scope '{}'", scope)); + } + } + + let mut context = TagsContext::new(); + let cancellation_flag = util::cancel_on_stdin(); + let stdout = io::stdout(); + let mut stdout = stdout.lock(); + + for path in paths { + let path = Path::new(&path); + let (language, language_config) = match lang { + 
Some(v) => v, + None => match loader.language_configuration_for_file_name(path)? { + Some(v) => v, + None => { + eprintln!("No language found for path {:?}", path); + continue; + } + }, + }; + + if let Some(tags_config) = language_config.tags_config(language)? { + let path_str = format!("{:?}", path); + writeln!(&mut stdout, "{}", &path_str[1..path_str.len() - 1])?; + + let source = fs::read(path)?; + for tag in context.generate_tags(tags_config, &source, Some(&cancellation_flag))? { + let tag = tag?; + write!( + &mut stdout, + " {:<8} {:<40}\t{:>9}-{:<9}", + tag.kind, + str::from_utf8(&source[tag.name_range]).unwrap_or(""), + tag.span.start, + tag.span.end, + )?; + if let Some(docs) = tag.docs { + if docs.len() > 120 { + write!(&mut stdout, "\t{:?}...", &docs[0..120])?; + } else { + write!(&mut stdout, "\t{:?}", &docs)?; + } + } + writeln!(&mut stdout, "")?; + } + } else { + eprintln!("No tags config found for path {:?}", path); + } + } + + Ok(()) +} diff --git a/cli/src/tests/mod.rs b/cli/src/tests/mod.rs index 0ccb0ae0..ac54db00 100644 --- a/cli/src/tests/mod.rs +++ b/cli/src/tests/mod.rs @@ -4,5 +4,6 @@ mod highlight_test; mod node_test; mod parser_test; mod query_test; +mod tags_test; mod test_highlight_test; mod tree_test; diff --git a/cli/src/tests/query_test.rs b/cli/src/tests/query_test.rs index 87420501..f69074a8 100644 --- a/cli/src/tests/query_test.rs +++ b/cli/src/tests/query_test.rs @@ -2,7 +2,8 @@ use super::helpers::allocations; use super::helpers::fixtures::get_language; use std::fmt::Write; use tree_sitter::{ - Node, Parser, Query, QueryCapture, QueryCursor, QueryError, QueryMatch, QueryProperty, + Node, Parser, Query, QueryCapture, QueryCursor, QueryError, QueryMatch, QueryPredicate, + QueryPredicateArg, QueryProperty, }; #[test] @@ -438,6 +439,10 @@ fn test_query_matches_with_named_wildcard() { fn test_query_matches_with_wildcard_at_the_root() { allocations::record(|| { let language = get_language("javascript"); + let mut cursor = 
QueryCursor::new(); + let mut parser = Parser::new(); + parser.set_language(language).unwrap(); + let query = Query::new( language, " @@ -452,20 +457,41 @@ fn test_query_matches_with_wildcard_at_the_root() { let source = "/* one */ var x; /* two */ function y() {} /* three */ class Z {}"; - let mut parser = Parser::new(); - parser.set_language(language).unwrap(); let tree = parser.parse(source, None).unwrap(); - let mut cursor = QueryCursor::new(); let matches = cursor.matches(&query, tree.root_node(), to_callback(source)); - assert_eq!( collect_matches(matches, &query, source), &[(0, vec![("doc", "/* two */"), ("name", "y")]),] ); + + let query = Query::new( + language, + " + (* (string) @a) + (* (number) @b) + (* (true) @c) + (* (false) @d) + ", + ) + .unwrap(); + + let source = "['hi', x(true), {y: false}]"; + + let tree = parser.parse(source, None).unwrap(); + let matches = cursor.matches(&query, tree.root_node(), to_callback(source)); + assert_eq!( + collect_matches(matches, &query, source), + &[ + (0, vec![("a", "'hi'")]), + (2, vec![("c", "true")]), + (3, vec![("d", "false")]), + ] + ); }); } + #[test] -fn test_query_with_immediate_siblings() { +fn test_query_matches_with_immediate_siblings() { allocations::record(|| { let language = get_language("python"); @@ -515,6 +541,107 @@ fn test_query_with_immediate_siblings() { }); } +#[test] +fn test_query_matches_with_repeated_leaf_nodes() { + allocations::record(|| { + let language = get_language("javascript"); + + let query = Query::new( + language, + " + (* + (comment)+ @doc + . + (class_declaration + name: (identifier) @name)) + + (* + (comment)+ @doc + . 
+ (function_declaration + name: (identifier) @name)) + ", + ) + .unwrap(); + + let source = " + // one + // two + a(); + + // three + { + // four + // five + // six + class B {} + + // seven + c(); + + // eight + function d() {} + } + "; + + let mut parser = Parser::new(); + parser.set_language(language).unwrap(); + let tree = parser.parse(source, None).unwrap(); + let mut cursor = QueryCursor::new(); + let matches = cursor.matches(&query, tree.root_node(), to_callback(source)); + + assert_eq!( + collect_matches(matches, &query, source), + &[ + ( + 0, + vec![ + ("doc", "// four"), + ("doc", "// five"), + ("doc", "// six"), + ("name", "B") + ] + ), + (1, vec![("doc", "// eight"), ("name", "d")]), + ] + ); + }); +} + +#[test] +fn test_query_matches_with_repeated_internal_nodes() { + allocations::record(|| { + let language = get_language("javascript"); + let mut parser = Parser::new(); + parser.set_language(language).unwrap(); + let mut cursor = QueryCursor::new(); + + let query = Query::new( + language, + " + (* + (method_definition + (decorator (identifier) @deco)+ + name: (property_identifier) @name)) + ", + ) + .unwrap(); + let source = " + class A { + @c + @d + e() {} + } + "; + let tree = parser.parse(source, None).unwrap(); + let matches = cursor.matches(&query, tree.root_node(), to_callback(source)); + assert_eq!( + collect_matches(matches, &query, source), + &[(0, vec![("deco", "c"), ("deco", "d"), ("name", "e")]),] + ); + }) +} + #[test] fn test_query_matches_in_language_with_simple_aliases() { allocations::record(|| { @@ -550,6 +677,41 @@ fn test_query_matches_in_language_with_simple_aliases() { }); } +#[test] +fn test_query_matches_with_different_tokens_with_the_same_string_value() { + allocations::record(|| { + let language = get_language("rust"); + let query = Query::new( + language, + r#" + "<" @less + ">" @greater + "#, + ) + .unwrap(); + + // In Rust, there are two '<' tokens: one for the binary operator, + // and one with higher precedence for 
generics. + let source = "const A: B = d < e || f > g;"; + + let mut parser = Parser::new(); + parser.set_language(language).unwrap(); + let tree = parser.parse(&source, None).unwrap(); + let mut cursor = QueryCursor::new(); + let matches = cursor.matches(&query, tree.root_node(), to_callback(source)); + + assert_eq!( + collect_matches(matches, &query, source), + &[ + (0, vec![("less", "<")]), + (1, vec![("greater", ">")]), + (0, vec![("less", "<")]), + (1, vec![("greater", ">")]), + ] + ); + }); +} + #[test] fn test_query_matches_with_too_many_permutations_to_track() { allocations::record(|| { @@ -880,7 +1042,7 @@ fn test_query_captures_with_text_conditions() { } #[test] -fn test_query_captures_with_set_properties() { +fn test_query_captures_with_predicates() { allocations::record(|| { let language = get_language("javascript"); @@ -889,7 +1051,8 @@ fn test_query_captures_with_set_properties() { r#" ((call_expression (identifier) @foo) (set! name something) - (set! cool)) + (set! cool) + (something! @foo omg)) ((property_identifier) @bar (is? 
cool) @@ -904,6 +1067,16 @@ fn test_query_captures_with_set_properties() { QueryProperty::new("cool", None, None), ] ); + assert_eq!( + query.general_predicates(0), + &[QueryPredicate { + operator: "something!".to_string().into_boxed_str(), + args: vec![ + QueryPredicateArg::Capture(0), + QueryPredicateArg::String("omg".to_string().into_boxed_str()), + ], + },] + ); assert_eq!(query.property_settings(1), &[]); assert_eq!(query.property_predicates(0), &[]); assert_eq!( @@ -917,7 +1090,7 @@ fn test_query_captures_with_set_properties() { } #[test] -fn test_query_captures_with_set_quoted_properties() { +fn test_query_captures_with_quoted_predicate_args() { allocations::record(|| { let language = get_language("javascript"); diff --git a/cli/src/tests/tags_test.rs b/cli/src/tests/tags_test.rs new file mode 100644 index 00000000..41907a3c --- /dev/null +++ b/cli/src/tests/tags_test.rs @@ -0,0 +1,347 @@ +use super::helpers::allocations; +use super::helpers::fixtures::{get_language, get_language_queries_path}; +use std::ffi::CString; +use std::{fs, ptr, slice, str}; +use tree_sitter_tags::c_lib as c; +use tree_sitter_tags::{Error, TagKind, TagsConfiguration, TagsContext}; + +const PYTHON_TAG_QUERY: &'static str = r#" +((function_definition + name: (identifier) @name + body: (block . (expression_statement (string) @doc))) @function + (strip! @doc "(^['\"\\s]*)|(['\"\\s]*$)")) +(function_definition + name: (identifier) @name) @function +((class_definition + name: (identifier) @name + body: (block . (expression_statement (string) @doc))) @class + (strip! @doc "(^['\"\\s]*)|(['\"\\s]*$)")) +(class_definition + name: (identifier) @name) @class +(call + function: (identifier) @name) @call +"#; + +const JS_TAG_QUERY: &'static str = r#" +((* + (comment)+ @doc . + (class_declaration + name: (identifier) @name) @class) + (select-adjacent! @doc @class) + (strip! @doc "(^[/\\*\\s]*)|([/\\*\\s]*$)")) + +((* + (comment)+ @doc . 
+ (method_definition + name: (property_identifier) @name) @method) + (select-adjacent! @doc @method) + (strip! @doc "(^[/\\*\\s]*)|([/\\*\\s]*$)")) + +((* + (comment)+ @doc . + (function_declaration + name: (identifier) @name) @function) + (select-adjacent! @doc @function) + (strip! @doc "(^[/\\*\\s]*)|([/\\*\\s]*$)")) + +(call_expression function: (identifier) @name) @call + "#; + +const RUBY_TAG_QUERY: &'static str = r#" +(method + name: (identifier) @name) @method + +(method_call + method: (identifier) @name) @call + +((identifier) @name @call + (is-not? local)) +"#; + +#[test] +fn test_tags_python() { + let language = get_language("python"); + let tags_config = TagsConfiguration::new(language, PYTHON_TAG_QUERY, "").unwrap(); + let mut tag_context = TagsContext::new(); + + let source = br#" + class Customer: + """ + Data about a customer + """ + + def age(self): + ''' + Get the customer's age + ''' + compute_age(self.id) + } + "#; + + let tags = tag_context + .generate_tags(&tags_config, source, None) + .unwrap() + .collect::, _>>() + .unwrap(); + + assert_eq!( + tags.iter() + .map(|t| (substr(source, &t.name_range), t.kind)) + .collect::>(), + &[ + ("Customer", TagKind::Class), + ("age", TagKind::Function), + ("compute_age", TagKind::Call), + ] + ); + + assert_eq!(substr(source, &tags[0].line_range), " class Customer:"); + assert_eq!( + substr(source, &tags[1].line_range), + " def age(self):" + ); + assert_eq!(tags[0].docs.as_ref().unwrap(), "Data about a customer"); + assert_eq!(tags[1].docs.as_ref().unwrap(), "Get the customer's age"); +} + +#[test] +fn test_tags_javascript() { + let language = get_language("javascript"); + let tags_config = TagsConfiguration::new(language, JS_TAG_QUERY, "").unwrap(); + let source = br#" + // hi + + // Data about a customer. 
+ // bla bla bla + class Customer { + /* + * Get the customer's age + */ + getAge() { + } + } + + // ok + + class Agent { + + } + "#; + + let mut tag_context = TagsContext::new(); + let tags = tag_context + .generate_tags(&tags_config, source, None) + .unwrap() + .collect::, _>>() + .unwrap(); + + assert_eq!( + tags.iter() + .map(|t| (substr(source, &t.name_range), t.kind)) + .collect::>(), + &[ + ("Customer", TagKind::Class), + ("getAge", TagKind::Method), + ("Agent", TagKind::Class) + ] + ); + assert_eq!( + tags[0].docs.as_ref().unwrap(), + "Data about a customer.\nbla bla bla" + ); + assert_eq!(tags[1].docs.as_ref().unwrap(), "Get the customer's age"); + assert_eq!(tags[2].docs, None); +} + +#[test] +fn test_tags_ruby() { + let language = get_language("ruby"); + let locals_query = + fs::read_to_string(get_language_queries_path("ruby").join("locals.scm")).unwrap(); + let tags_config = TagsConfiguration::new(language, RUBY_TAG_QUERY, &locals_query).unwrap(); + let source = strip_whitespace( + 8, + " + b = 1 + + def foo() + c = 1 + + # a is a method because it is not in scope + # b is a method because `b` doesn't capture variables from its containing scope + bar a, b, c + + [1, 2, 3].each do |a| + # a is a parameter + # b is a method + # c is a variable, because the block captures variables from its containing scope. 
+ baz a, b, c + end + end", + ); + + let mut tag_context = TagsContext::new(); + let tags = tag_context + .generate_tags(&tags_config, source.as_bytes(), None) + .unwrap() + .collect::, _>>() + .unwrap(); + + assert_eq!( + tags.iter() + .map(|t| ( + substr(source.as_bytes(), &t.name_range), + t.kind, + (t.span.start.row, t.span.start.column), + )) + .collect::>(), + &[ + ("foo", TagKind::Method, (2, 0)), + ("bar", TagKind::Call, (7, 4)), + ("a", TagKind::Call, (7, 8)), + ("b", TagKind::Call, (7, 11)), + ("each", TagKind::Call, (9, 14)), + ("baz", TagKind::Call, (13, 8)), + ("b", TagKind::Call, (13, 15),), + ] + ); +} + +#[test] +fn test_tags_cancellation() { + use std::sync::atomic::{AtomicUsize, Ordering}; + + allocations::record(|| { + // Large javascript document + let source = (0..500) + .map(|_| "/* hi */ class A { /* ok */ b() {} }\n") + .collect::(); + + let cancellation_flag = AtomicUsize::new(0); + let language = get_language("javascript"); + let tags_config = TagsConfiguration::new(language, JS_TAG_QUERY, "").unwrap(); + + let mut tag_context = TagsContext::new(); + let tags = tag_context + .generate_tags(&tags_config, source.as_bytes(), Some(&cancellation_flag)) + .unwrap(); + + for (i, tag) in tags.enumerate() { + if i == 150 { + cancellation_flag.store(1, Ordering::SeqCst); + } + if let Err(e) = tag { + assert_eq!(e, Error::Cancelled); + return; + } + } + + panic!("Expected to halt tagging with an error"); + }); +} + +#[test] +fn test_tags_via_c_api() { + allocations::record(|| { + let tagger = c::ts_tagger_new(); + let buffer = c::ts_tags_buffer_new(); + let scope_name = "source.js"; + let language = get_language("javascript"); + + let source_code = strip_whitespace( + 12, + " + var a = 1; + + // one + // two + // three + function b() { + } + + // four + // five + class C extends D { + + } + + b(a);", + ); + + let c_scope_name = CString::new(scope_name).unwrap(); + let result = c::ts_tagger_add_language( + tagger, + c_scope_name.as_ptr(), + language, 
+ JS_TAG_QUERY.as_ptr(), + ptr::null(), + JS_TAG_QUERY.len() as u32, + 0, + ); + assert_eq!(result, c::TSTagsError::Ok); + + let result = c::ts_tagger_tag( + tagger, + c_scope_name.as_ptr(), + source_code.as_ptr(), + source_code.len() as u32, + buffer, + ptr::null(), + ); + assert_eq!(result, c::TSTagsError::Ok); + let tags = unsafe { + slice::from_raw_parts( + c::ts_tags_buffer_tags(buffer), + c::ts_tags_buffer_tags_len(buffer) as usize, + ) + }; + let docs = str::from_utf8(unsafe { + slice::from_raw_parts( + c::ts_tags_buffer_docs(buffer) as *const u8, + c::ts_tags_buffer_docs_len(buffer) as usize, + ) + }) + .unwrap(); + + assert_eq!( + tags.iter() + .map(|tag| ( + tag.kind, + &source_code[tag.name_start_byte as usize..tag.name_end_byte as usize], + &source_code[tag.line_start_byte as usize..tag.line_end_byte as usize], + &docs[tag.docs_start_byte as usize..tag.docs_end_byte as usize], + )) + .collect::>(), + &[ + ( + c::TSTagKind::Function, + "b", + "function b() {", + "one\ntwo\nthree" + ), + ( + c::TSTagKind::Class, + "C", + "class C extends D {", + "four\nfive" + ), + (c::TSTagKind::Call, "b", "b(a);", "") + ] + ); + + c::ts_tags_buffer_delete(buffer); + c::ts_tagger_delete(tagger); + }); +} + +fn substr<'a>(source: &'a [u8], range: &std::ops::Range) -> &'a str { + std::str::from_utf8(&source[range.clone()]).unwrap() +} + +fn strip_whitespace(indent: usize, s: &str) -> String { + s.lines() + .skip(1) + .map(|line| &line[line.len().min(indent)..]) + .collect::>() + .join("\n") +} diff --git a/cli/src/util.rs b/cli/src/util.rs index e880bea1..8978ecc1 100644 --- a/cli/src/util.rs +++ b/cli/src/util.rs @@ -1,12 +1,29 @@ +use std::io; +use std::sync::atomic::{AtomicUsize, Ordering}; +use std::sync::Arc; +use std::thread; +use tree_sitter::Parser; + #[cfg(unix)] use std::path::PathBuf; #[cfg(unix)] use std::process::{Child, ChildStdin, Command, Stdio}; -use tree_sitter::Parser; #[cfg(unix)] const HTML_HEADER: &[u8] = b"\n\n\n"; +pub fn cancel_on_stdin() -> Arc { 
+ let result = Arc::new(AtomicUsize::new(0)); + thread::spawn({ + let flag = result.clone(); + move || { + let mut line = String::new(); + io::stdin().read_line(&mut line).unwrap(); + flag.store(1, Ordering::Relaxed); + } + }); + result +} #[cfg(windows)] pub struct LogSession(); diff --git a/highlight/Cargo.toml b/highlight/Cargo.toml index b0d32c02..94a4e032 100644 --- a/highlight/Cargo.toml +++ b/highlight/Cargo.toml @@ -18,9 +18,6 @@ crate-type = ["lib", "staticlib"] [dependencies] regex = "1" -serde = "1.0" -serde_json = "1.0" -serde_derive = "1.0" [dependencies.tree-sitter] version = ">= 0.3.7" diff --git a/lib/binding_rust/lib.rs b/lib/binding_rust/lib.rs index b40d97e5..a13d9168 100644 --- a/lib/binding_rust/lib.rs +++ b/lib/binding_rust/lib.rs @@ -95,6 +95,7 @@ pub struct Query { text_predicates: Vec>, property_settings: Vec>, property_predicates: Vec>, + general_predicates: Vec>, } /// A stateful object for executing a `Query` on a syntax `Tree`. @@ -108,6 +109,19 @@ pub struct QueryProperty { pub capture_id: Option, } +#[derive(Debug, PartialEq, Eq)] +pub enum QueryPredicateArg { + Capture(u32), + String(Box), +} + +/// A key-value pair associated with a particular pattern in a `Query`. +#[derive(Debug, PartialEq, Eq)] +pub struct QueryPredicate { + pub operator: Box, + pub args: Vec, +} + /// A match of a `Query` to a particular set of `Node`s. pub struct QueryMatch<'a> { pub pattern_index: usize, @@ -1194,6 +1208,7 @@ impl Query { text_predicates: Vec::with_capacity(pattern_count), property_predicates: Vec::with_capacity(pattern_count), property_settings: Vec::with_capacity(pattern_count), + general_predicates: Vec::with_capacity(pattern_count), }; // Build a vector of strings to store the capture names. 
@@ -1237,6 +1252,7 @@ impl Query { let mut text_predicates = Vec::new(); let mut property_predicates = Vec::new(); let mut property_settings = Vec::new(); + let mut general_predicates = Vec::new(); for p in predicate_steps.split(|s| s.type_ == type_done) { if p.is_empty() { continue; @@ -1328,12 +1344,21 @@ impl Query { operator_name == "is?", )), - _ => { - return Err(QueryError::Predicate(format!( - "Unknown query predicate function {}", - operator_name, - ))) - } + _ => general_predicates.push(QueryPredicate { + operator: operator_name.clone().into_boxed_str(), + args: p[1..] + .iter() + .map(|a| { + if a.type_ == type_capture { + QueryPredicateArg::Capture(a.value_id) + } else { + QueryPredicateArg::String( + string_values[a.value_id as usize].clone().into_boxed_str(), + ) + } + }) + .collect(), + }), } } @@ -1346,6 +1371,9 @@ impl Query { result .property_settings .push(property_settings.into_boxed_slice()); + result + .general_predicates + .push(general_predicates.into_boxed_slice()); } Ok(result) } @@ -1375,15 +1403,30 @@ impl Query { } /// Get the properties that are checked for the given pattern index. + /// + /// This includes predicates with the operators `is?` and `is-not?`. pub fn property_predicates(&self, index: usize) -> &[(QueryProperty, bool)] { &self.property_predicates[index] } /// Get the properties that are set for the given pattern index. + /// + /// This includes predicates with the operator `set!`. pub fn property_settings(&self, index: usize) -> &[QueryProperty] { &self.property_settings[index] } + /// Get the other user-defined predicates associated with the given index. + /// + /// This includes predicate with operators other than: + /// * `match?` + /// * `eq?` and `not-eq? + /// * `is?` and `is-not?` + /// * `set!` + pub fn general_predicates(&self, index: usize) -> &[QueryPredicate] { + &self.general_predicates[index] + } + /// Disable a certain capture within a query. 
/// /// This prevents the capture from being returned in matches, and also avoids any @@ -1420,46 +1463,39 @@ impl Query { ))); } - let mut i = 0; let mut capture_id = None; - if args[i].type_ == ffi::TSQueryPredicateStepType_TSQueryPredicateStepTypeCapture { - capture_id = Some(args[i].value_id as usize); - i += 1; - - if i == args.len() { - return Err(QueryError::Predicate(format!( - "No key specified for {} predicate.", - function_name, - ))); - } - if args[i].type_ == ffi::TSQueryPredicateStepType_TSQueryPredicateStepTypeCapture { - return Err(QueryError::Predicate(format!( - "Invalid arguments to {} predicate. Expected string, got @{}", - function_name, capture_names[args[i].value_id as usize] - ))); - } - } - - let key = &string_values[args[i].value_id as usize]; - i += 1; - + let mut key = None; let mut value = None; - if i < args.len() { - if args[i].type_ == ffi::TSQueryPredicateStepType_TSQueryPredicateStepTypeCapture { + + for arg in args { + if arg.type_ == ffi::TSQueryPredicateStepType_TSQueryPredicateStepTypeCapture { if capture_id.is_some() { return Err(QueryError::Predicate(format!( "Invalid arguments to {} predicate. Unexpected second capture name @{}", - function_name, capture_names[args[i].value_id as usize] + function_name, capture_names[arg.value_id as usize] ))); - } else { - capture_id = Some(args[i].value_id as usize); } + capture_id = Some(arg.value_id as usize); + } else if key.is_none() { + key = Some(&string_values[arg.value_id as usize]); + } else if value.is_none() { + value = Some(string_values[arg.value_id as usize].as_str()); } else { - value = Some(string_values[args[i].value_id as usize].as_str()); + return Err(QueryError::Predicate(format!( + "Invalid arguments to {} predicate. 
Unexpected third argument @{}", + function_name, string_values[arg.value_id as usize] + ))); } } - Ok(QueryProperty::new(key, value, capture_id)) + if let Some(key) = key { + Ok(QueryProperty::new(key, value, capture_id)) + } else { + return Err(QueryError::Predicate(format!( + "Invalid arguments to {} predicate. Missing key argument", + function_name, + ))); + } } } diff --git a/lib/src/query.c b/lib/src/query.c index 65144395..87ab05b5 100644 --- a/lib/src/query.c +++ b/lib/src/query.c @@ -35,15 +35,21 @@ typedef struct { * captured in this pattern. * - `depth` - The depth where this node occurs in the pattern. The root node * of the pattern has depth zero. + * - `repeat_step_index` - If this step is part of a repetition, the index of + * the beginning of the repetition. A `NONE` value means this step is not + * part of a repetition. */ typedef struct { TSSymbol symbol; TSFieldId field; uint16_t capture_ids[MAX_STEP_CAPTURE_COUNT]; - uint16_t depth: 13; + uint16_t repeat_step_index; + uint16_t depth: 11; bool contains_captures: 1; + bool is_pattern_start: 1; bool is_immediate: 1; bool is_last: 1; + bool is_repeated: 1; } QueryStep; /* @@ -85,23 +91,27 @@ typedef struct { * represented as one of these states. */ typedef struct { + uint32_t id; uint16_t start_depth; uint16_t pattern_index; uint16_t step_index; - uint16_t capture_count; - uint16_t capture_list_id; uint16_t consumed_capture_count; - uint32_t id; + uint16_t repeat_match_count; + uint16_t step_index_on_failure; + uint8_t capture_list_id; + bool seeking_non_match; } QueryState; +typedef Array(TSQueryCapture) CaptureList; + /* * CaptureListPool - A collection of *lists* of captures. Each QueryState - * needs to maintain its own list of captures. They are all represented as - * slices of one shared array. The CaptureListPool keeps track of which - * parts of the shared array are currently in use by a QueryState. + * needs to maintain its own list of captures. 
To avoid repeated allocations,
+ * this pool reuses a fixed set of capture lists, and keeps track of which ones
+ * are currently in use.
 */
 typedef struct {
- Array(TSQueryCapture) list;
+ CaptureList list[32];
 uint32_t usage_map;
 } CaptureListPool;
@@ -119,7 +129,6 @@ struct TSQuery {
 Array(Slice) predicates_by_pattern;
 Array(uint32_t) start_bytes_by_pattern;
 const TSLanguage *language;
- uint16_t max_capture_count;
 uint16_t wildcard_root_pattern_count;
 TSSymbol *symbol_map;
 };
@@ -233,24 +242,25 @@ static void stream_scan_identifier(Stream *stream) {
 static CaptureListPool capture_list_pool_new() {
 return (CaptureListPool) {
- .list = array_new(),
 .usage_map = UINT32_MAX,
 };
 }
-static void capture_list_pool_reset(CaptureListPool *self, uint16_t list_size) {
+static void capture_list_pool_reset(CaptureListPool *self) {
 self->usage_map = UINT32_MAX;
- uint32_t total_size = MAX_STATE_COUNT * list_size;
- array_reserve(&self->list, total_size);
- self->list.size = total_size;
+ for (unsigned i = 0; i < 32; i++) {
+ array_clear(&self->list[i]);
+ }
 }
 static void capture_list_pool_delete(CaptureListPool *self) {
- array_delete(&self->list);
+ for (unsigned i = 0; i < 32; i++) {
+ array_delete(&self->list[i]);
+ }
 }
-static TSQueryCapture *capture_list_pool_get(CaptureListPool *self, uint16_t id) {
- return &self->list.contents[id * (self->list.size / MAX_STATE_COUNT)];
+static CaptureList *capture_list_pool_get(CaptureListPool *self, uint16_t id) {
+ return &self->list[id];
 }
 static bool capture_list_pool_is_empty(const CaptureListPool *self) {
@@ -269,6 +279,7 @@ static uint16_t capture_list_pool_acquire(CaptureListPool *self) {
 }
 static void capture_list_pool_release(CaptureListPool *self, uint16_t id) {
+ array_clear(&self->list[id]);
 self->usage_map |= bitmask_for_index(id);
 }
@@ -407,7 +418,11 @@ static QueryStep query_step__new(
 .field = 0,
 .capture_ids = {NONE, NONE, NONE, NONE},
 .contains_captures = false,
+ .is_repeated = false,
+ .is_last = false,
+ 
.is_pattern_start = false, .is_immediate = is_immediate, + .repeat_step_index = NONE, }; } @@ -842,27 +857,43 @@ static TSQueryError ts_query__parse_pattern( stream_skip_whitespace(stream); - // Parse an '@'-prefixed capture pattern - while (stream->next == '@') { - stream_advance(stream); - - // Parse the capture name - if (!stream_is_ident_start(stream)) return TSQueryErrorSyntax; - const char *capture_name = stream->input; - stream_scan_identifier(stream); - uint32_t length = stream->input - capture_name; - - // Add the capture id to the first step of the pattern - uint16_t capture_id = symbol_table_insert_name( - &self->captures, - capture_name, - length - ); + // Parse suffixes modifiers for this pattern + for (;;) { QueryStep *step = &self->steps.contents[starting_step_index]; - query_step__add_capture(step, capture_id); - (*capture_count)++; - stream_skip_whitespace(stream); + if (stream->next == '+') { + stream_advance(stream); + step->is_repeated = true; + array_back(&self->steps)->repeat_step_index = starting_step_index; + stream_skip_whitespace(stream); + } + + // Parse an '@'-prefixed capture pattern + else if (stream->next == '@') { + stream_advance(stream); + + // Parse the capture name + if (!stream_is_ident_start(stream)) return TSQueryErrorSyntax; + const char *capture_name = stream->input; + stream_scan_identifier(stream); + uint32_t length = stream->input - capture_name; + + // Add the capture id to the first step of the pattern + uint16_t capture_id = symbol_table_insert_name( + &self->captures, + capture_name, + length + ); + query_step__add_capture(step, capture_id); + (*capture_count)++; + + stream_skip_whitespace(stream); + } + + // No more suffix modifiers + else { + break; + } } return 0; @@ -912,16 +943,14 @@ TSQuery *ts_query_new( .predicates_by_pattern = array_new(), .symbol_map = symbol_map, .wildcard_root_pattern_count = 0, - .max_capture_count = 0, .language = language, }; // Parse all of the S-expressions in the given string. 
Stream stream = stream_new(source, source_len); stream_skip_whitespace(&stream); - uint32_t start_step_index; while (stream.input < stream.end) { - start_step_index = self->steps.size; + uint32_t start_step_index = self->steps.size; uint32_t capture_count = 0; array_push(&self->start_bytes_by_pattern, stream.input - source); array_push(&self->predicates_by_pattern, ((Slice) { @@ -939,7 +968,19 @@ TSQuery *ts_query_new( return NULL; } + // If a pattern has a wildcard at its root, optimize the matching process + // by skipping matching the wildcard. + if ( + self->steps.contents[start_step_index].symbol == WILDCARD_SYMBOL + ) { + QueryStep *second_step = &self->steps.contents[start_step_index + 1]; + if (second_step->symbol != WILDCARD_SYMBOL && second_step->depth != PATTERN_DONE_MARKER) { + start_step_index += 1; + } + } + // Maintain a map that can look up patterns for a given root symbol. + self->steps.contents[start_step_index].is_pattern_start = true; ts_query__pattern_map_insert( self, self->steps.contents[start_step_index].symbol, @@ -948,13 +989,6 @@ TSQuery *ts_query_new( if (self->steps.contents[start_step_index].symbol == WILDCARD_SYMBOL) { self->wildcard_root_pattern_count++; } - - // Keep track of the maximum number of captures in pattern, because - // that numer determines how much space is needed to store each capture - // list. 
- if (capture_count > self->max_capture_count) {
- self->max_capture_count = capture_count;
- }
 }
 ts_query__finalize_steps(self);
@@ -1089,7 +1123,7 @@ void ts_query_cursor_exec(
 array_clear(&self->states);
 array_clear(&self->finished_states);
 ts_tree_cursor_reset(&self->cursor, node);
- capture_list_pool_reset(&self->capture_list_pool, query->max_capture_count);
+ capture_list_pool_reset(&self->capture_list_pool);
 self->next_state_id = 0;
 self->depth = 0;
 self->ascending = false;
@@ -1133,12 +1167,12 @@ static bool ts_query_cursor__first_in_progress_capture(
 bool result = false;
 for (unsigned i = 0; i < self->states.size; i++) {
 const QueryState *state = &self->states.contents[i];
- if (state->capture_count > 0) {
- const TSQueryCapture *captures = capture_list_pool_get(
- &self->capture_list_pool,
- state->capture_list_id
- );
- uint32_t capture_byte = ts_node_start_byte(captures[0].node);
+ const CaptureList *captures = capture_list_pool_get(
+ &self->capture_list_pool,
+ state->capture_list_id
+ );
+ if (captures->size > 0) {
+ uint32_t capture_byte = ts_node_start_byte(captures->contents[0].node);
 if (
 !result ||
 capture_byte < *byte_offset ||
@@ -1161,6 +1195,19 @@ static bool ts_query__cursor_add_state(
 TSQueryCursor *self,
 const PatternEntry *pattern
 ) {
+ QueryStep *step = &self->query->steps.contents[pattern->step_index];
+
+ // If this pattern begins with a repetition, then avoid creating
+ // new states after already matching the repetition one or more times.
+ // The query should produce only one match for the repetition - the one that
+ // started the earliest. 
+ if (step->is_repeated) { + for (unsigned i = 0; i < self->states.size; i++) { + QueryState *state = &self->states.contents[i]; + if (state->step_index == pattern->step_index) return true; + } + } + uint32_t list_id = capture_list_pool_acquire(&self->capture_list_pool); // If there are no capture lists left in the pool, then terminate whichever @@ -1186,14 +1233,20 @@ static bool ts_query__cursor_add_state( } } - LOG(" start state. pattern:%u\n", pattern->pattern_index); + LOG( + " start state. pattern:%u, step:%u\n", + pattern->pattern_index, + pattern->step_index + ); array_push(&self->states, ((QueryState) { .capture_list_id = list_id, .step_index = pattern->step_index, .pattern_index = pattern->pattern_index, - .start_depth = self->depth, - .capture_count = 0, + .start_depth = self->depth - step->depth, .consumed_capture_count = 0, + .repeat_match_count = 0, + .step_index_on_failure = NONE, + .seeking_non_match = false, })); return true; } @@ -1207,15 +1260,15 @@ static QueryState *ts_query__cursor_copy_state( array_push(&self->states, *state); QueryState *new_state = array_back(&self->states); new_state->capture_list_id = new_list_id; - TSQueryCapture *old_captures = capture_list_pool_get( + CaptureList *old_captures = capture_list_pool_get( &self->capture_list_pool, state->capture_list_id ); - TSQueryCapture *new_captures = capture_list_pool_get( + CaptureList *new_captures = capture_list_pool_get( &self->capture_list_pool, new_list_id ); - memcpy(new_captures, old_captures, state->capture_count * sizeof(TSQueryCapture)); + array_push_all(new_captures, old_captures); return new_state; } @@ -1372,6 +1425,24 @@ static inline bool ts_query_cursor__advance(TSQueryCursor *self) { } if (!node_does_match) { + // If this QueryState has processed a repeating sequence, and that repeating + // sequence has ended, move on to the *next* step of this state's pattern. 
+ if ( + state->step_index_on_failure != NONE && + (!later_sibling_can_match || step->is_repeated) + ) { + LOG( + " finish repetition state. pattern:%u, step:%u\n", + state->pattern_index, + state->step_index + ); + state->step_index = state->step_index_on_failure; + state->step_index_on_failure = NONE; + state->repeat_match_count = 0; + i--; + continue; + } + if (!later_sibling_can_match) { LOG( " discard state. pattern:%u, step:%u\n", @@ -1386,9 +1457,17 @@ static inline bool ts_query_cursor__advance(TSQueryCursor *self) { i--; n--; } + + state->seeking_non_match = false; continue; } + // The `seeking_non_match` flag indicates that a previous QueryState + // has already begun processing this repeating sequence, so that *this* + // QueryState should not begin matching until a separate repeating sequence + // is found. + if (state->seeking_non_match) continue; + // Some patterns can match their root node in multiple ways, // capturing different children. If this pattern step could match // later children within the same parent, then this query state @@ -1398,11 +1477,20 @@ static inline bool ts_query_cursor__advance(TSQueryCursor *self) { // siblings. QueryState *next_state = state; if ( - step->depth > 0 && + !step->is_pattern_start && step->contains_captures && - later_sibling_can_match + later_sibling_can_match && + state->repeat_match_count == 0 ) { QueryState *copy = ts_query__cursor_copy_state(self, state); + + // The QueryState that matched this node has begun matching a repeating + // sequence. The QueryState that *skipped* this node should not start + // matching later elements of the same repeating sequence. + if (step->is_repeated) { + state->seeking_non_match = true; + } + if (copy) { LOG( " split state. pattern:%u, step:%u\n", @@ -1411,55 +1499,71 @@ static inline bool ts_query_cursor__advance(TSQueryCursor *self) { ); next_state = copy; } else { - LOG(" canot split state.\n"); + LOG(" cannot split state.\n"); } } - LOG( - " advance state. 
pattern:%u, step:%u\n", - next_state->pattern_index, - next_state->step_index - ); - // If the current node is captured in this pattern, add it to the // capture list. for (unsigned j = 0; j < MAX_STEP_CAPTURE_COUNT; j++) { uint16_t capture_id = step->capture_ids[j]; if (step->capture_ids[j] == NONE) break; - LOG( - " capture node. pattern:%u, capture_id:%u\n", - next_state->pattern_index, - capture_id - ); - TSQueryCapture *capture_list = capture_list_pool_get( + CaptureList *capture_list = capture_list_pool_get( &self->capture_list_pool, next_state->capture_list_id ); - capture_list[next_state->capture_count++] = (TSQueryCapture) { + array_push(capture_list, ((TSQueryCapture) { node, capture_id - }; + })); + LOG( + " capture node. pattern:%u, capture_id:%u, capture_count:%u\n", + next_state->pattern_index, + capture_id, + capture_list->size + ); } - // If the pattern is now done, then remove it from the list of - // in-progress states, and add it to the list of finished states. - next_state->step_index++; - QueryStep *next_step = step + 1; - if (next_step->depth == PATTERN_DONE_MARKER) { - LOG(" finish pattern %u\n", next_state->pattern_index); + // If this is the end of a repetition, then jump back to the beginning + // of that repetition. + if (step->repeat_step_index != NONE) { + next_state->step_index_on_failure = next_state->step_index + 1; + next_state->step_index = step->repeat_step_index; + next_state->repeat_match_count++; + LOG( + " continue repeat. pattern:%u, match_count:%u\n", + next_state->pattern_index, + next_state->repeat_match_count + ); + } else { + next_state->step_index++; + LOG( + " advance state. 
pattern:%u, step:%u\n", + next_state->pattern_index, + next_state->step_index + ); - next_state->id = self->next_state_id++; - array_push(&self->finished_states, *next_state); - if (next_state == state) { - array_erase(&self->states, i); - i--; - n--; - } else { - self->states.size--; + QueryStep *next_step = step + 1; + + // If the pattern is now done, then remove it from the list of + // in-progress states, and add it to the list of finished states. + if (next_step->depth == PATTERN_DONE_MARKER) { + LOG(" finish pattern %u\n", next_state->pattern_index); + + next_state->id = self->next_state_id++; + array_push(&self->finished_states, *next_state); + if (next_state == state) { + array_erase(&self->states, i); + i--; + n--; + } else { + self->states.size--; + } } } } + // Continue descending if possible. if (ts_tree_cursor_goto_first_child(&self->cursor)) { self->depth++; @@ -1485,11 +1589,12 @@ bool ts_query_cursor_next_match( QueryState *state = &self->finished_states.contents[0]; match->id = state->id; match->pattern_index = state->pattern_index; - match->capture_count = state->capture_count; - match->captures = capture_list_pool_get( + CaptureList *captures = capture_list_pool_get( &self->capture_list_pool, state->capture_list_id ); + match->captures = captures->contents; + match->capture_count = captures->size; capture_list_pool_release(&self->capture_list_pool, state->capture_list_id); array_erase(&self->finished_states, 0); return true; @@ -1542,13 +1647,13 @@ bool ts_query_cursor_next_capture( uint32_t first_finished_pattern_index = first_unfinished_pattern_index; for (unsigned i = 0; i < self->finished_states.size; i++) { const QueryState *state = &self->finished_states.contents[i]; - if (state->capture_count > state->consumed_capture_count) { - const TSQueryCapture *captures = capture_list_pool_get( - &self->capture_list_pool, - state->capture_list_id - ); + CaptureList *captures = capture_list_pool_get( + &self->capture_list_pool, + 
state->capture_list_id + ); + if (captures->size > state->consumed_capture_count) { uint32_t capture_byte = ts_node_start_byte( - captures[state->consumed_capture_count].node + captures->contents[state->consumed_capture_count].node ); if ( capture_byte < first_finished_capture_byte || @@ -1580,11 +1685,12 @@ bool ts_query_cursor_next_capture( ]; match->id = state->id; match->pattern_index = state->pattern_index; - match->capture_count = state->capture_count; - match->captures = capture_list_pool_get( + CaptureList *captures = capture_list_pool_get( &self->capture_list_pool, state->capture_list_id ); + match->captures = captures->contents; + match->capture_count = captures->size; *capture_index = state->consumed_capture_count; state->consumed_capture_count++; return true; diff --git a/tags/Cargo.toml b/tags/Cargo.toml new file mode 100644 index 00000000..43557bb2 --- /dev/null +++ b/tags/Cargo.toml @@ -0,0 +1,25 @@ +[package] +name = "tree-sitter-tags" +description = "Library for extracting tag information" +version = "0.1.6" +authors = [ + "Max Brunsfeld ", + "Patrick Thomson " +] +license = "MIT" +readme = "README.md" +edition = "2018" +keywords = ["incremental", "parsing", "syntax", "tagging"] +categories = ["parsing", "text-editors"] +repository = "https://github.com/tree-sitter/tree-sitter" + +[lib] +crate-type = ["lib", "staticlib"] + +[dependencies] +regex = "1" +memchr = "2.3" + +[dependencies.tree-sitter] +version = ">= 0.3.7" +path = "../lib" diff --git a/tags/README.md b/tags/README.md new file mode 100644 index 00000000..7a55c254 --- /dev/null +++ b/tags/README.md @@ -0,0 +1,60 @@ +Tree-sitter Tags +========================= + +### Usage + +Compile some languages into your app, and declare them: + +```rust +extern "C" tree_sitter_python(); +extern "C" tree_sitter_javascript(); +``` + +Create a tag context. 
You need one of these for each thread that you're using for tag computation:
+
+```rust
+use tree_sitter_tags::TagsContext;
+
+let context = TagsContext::new();
+```
+
+Load some tagging queries from the `queries` directory of some language repositories:
+
+```rust
+use tree_sitter_tags::TagsConfiguration;
+
+let python_language = unsafe { tree_sitter_python() };
+let javascript_language = unsafe { tree_sitter_javascript() };
+
+let python_config = TagsConfiguration::new(
+    python_language,
+    &fs::read_to_string("./tree-sitter-python/queries/tags.scm").unwrap(),
+    &fs::read_to_string("./tree-sitter-python/queries/locals.scm").unwrap(),
+).unwrap();
+
+let javascript_config = TagsConfiguration::new(
+    javascript_language,
+    &fs::read_to_string("./tree-sitter-javascript/queries/tags.scm").unwrap(),
+    &fs::read_to_string("./tree-sitter-javascript/queries/locals.scm").unwrap(),
+).unwrap();
+```
+
+Compute code navigation tags for some source code:
+
+```rust
+use tree_sitter_tags::Tag;
+
+let tags = context.generate_tags(
+    &javascript_config,
+    b"class A { getB() { return c(); } }",
+    None,
+    |_| None
+);
+
+for tag in tags {
+    println!("kind: {:?}", tag.kind);
+    println!("range: {:?}", tag.range);
+    println!("name_range: {:?}", tag.name_range);
+    println!("docs: {:?}", tag.docs);
+}
+```
diff --git a/tags/include/tree_sitter/tags.h b/tags/include/tree_sitter/tags.h
new file mode 100644
index 00000000..946dc6f1
--- /dev/null
+++ b/tags/include/tree_sitter/tags.h
@@ -0,0 +1,96 @@
+#ifndef TREE_SITTER_TAGS_H_
+#define TREE_SITTER_TAGS_H_
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+#include <stdint.h>
+#include "tree_sitter/api.h"
+
+typedef enum {
+ TSTagsOk,
+ TSTagsUnknownScope,
+ TSTagsTimeout,
+ TSTagsInvalidLanguage,
+ TSTagsInvalidUtf8,
+ TSTagsInvalidRegex,
+ TSTagsInvalidQuery,
+} TSTagsError;
+
+typedef enum {
+ TSTagKindFunction,
+ TSTagKindMethod,
+ TSTagKindClass,
+ TSTagKindModule,
+ TSTagKindCall,
+} TSTagKind;
+
+typedef struct {
+ 
TSTagKind kind;
+ uint32_t start_byte;
+ uint32_t end_byte;
+ uint32_t name_start_byte;
+ uint32_t name_end_byte;
+ uint32_t line_start_byte;
+ uint32_t line_end_byte;
+ TSPoint start_point;
+ TSPoint end_point;
+ uint32_t docs_start_byte;
+ uint32_t docs_end_byte;
+} TSTag;
+
+typedef struct TSTagger TSTagger;
+typedef struct TSTagsBuffer TSTagsBuffer;
+
+// Construct a tagger.
+TSTagger *ts_tagger_new();
+
+// Delete a tagger.
+void ts_tagger_delete(TSTagger *);
+
+// Add a `TSLanguage` to a tagger. The language is associated with a scope name,
+// which can be used later to select a language for tagging. Along with the language,
+// you must provide two tree query strings, one for matching tags themselves, and one
+// specifying local variable definitions.
+TSTagsError ts_tagger_add_language(
+ TSTagger *self,
+ const char *scope_name,
+ const TSLanguage *language,
+ const char *tags_query,
+ const char *locals_query,
+ uint32_t tags_query_len,
+ uint32_t locals_query_len
+);
+
+// Compute code navigation tags for a given document. You must first
+// create a `TSTagsBuffer` to hold the output.
+TSTagsError ts_tagger_tag(
+ const TSTagger *self,
+ const char *scope_name,
+ const char *source_code,
+ uint32_t source_code_len,
+ TSTagsBuffer *output,
+ const size_t *cancellation_flag
+);
+
+// A tags buffer stores the results produced by a tagging call. It can be reused
+// for multiple calls.
+TSTagsBuffer *ts_tags_buffer_new();
+
+// Delete a tags buffer.
+void ts_tags_buffer_delete(TSTagsBuffer *);
+
+// Access the tags within a tag buffer. 
+const TSTag *ts_tags_buffer_tags(const TSTagsBuffer *); +uint32_t ts_tags_buffer_tags_len(const TSTagsBuffer *); + +// Access the string containing all of the docs +const char *ts_tags_buffer_docs(const TSTagsBuffer *); +uint32_t ts_tags_buffer_docs_len(const TSTagsBuffer *); + +#ifdef __cplusplus +} +#endif + +#endif // TREE_SITTER_TAGS_H_ diff --git a/tags/src/c_lib.rs b/tags/src/c_lib.rs new file mode 100644 index 00000000..0c367977 --- /dev/null +++ b/tags/src/c_lib.rs @@ -0,0 +1,245 @@ +use super::{Error, TagKind, TagsConfiguration, TagsContext}; +use std::collections::HashMap; +use std::ffi::CStr; +use std::process::abort; +use std::sync::atomic::AtomicUsize; +use std::{fmt, slice, str}; +use tree_sitter::Language; + +#[repr(C)] +#[derive(Debug, PartialEq, Eq)] +pub enum TSTagsError { + Ok, + UnknownScope, + Timeout, + InvalidLanguage, + InvalidUtf8, + InvalidRegex, + InvalidQuery, + Unknown, +} + +#[repr(C)] +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub enum TSTagKind { + Function, + Method, + Class, + Module, + Call, +} + +#[repr(C)] +pub struct TSPoint { + row: u32, + column: u32, +} + +#[repr(C)] +pub struct TSTag { + pub kind: TSTagKind, + pub start_byte: u32, + pub end_byte: u32, + pub name_start_byte: u32, + pub name_end_byte: u32, + pub line_start_byte: u32, + pub line_end_byte: u32, + pub start_point: TSPoint, + pub end_point: TSPoint, + pub docs_start_byte: u32, + pub docs_end_byte: u32, +} + +pub struct TSTagger { + languages: HashMap, +} + +pub struct TSTagsBuffer { + context: TagsContext, + tags: Vec, + docs: Vec, +} + +#[no_mangle] +pub extern "C" fn ts_tagger_new() -> *mut TSTagger { + Box::into_raw(Box::new(TSTagger { + languages: HashMap::new(), + })) +} + +#[no_mangle] +pub extern "C" fn ts_tagger_delete(this: *mut TSTagger) { + drop(unsafe { Box::from_raw(this) }) +} + +#[no_mangle] +pub extern "C" fn ts_tagger_add_language( + this: *mut TSTagger, + scope_name: *const i8, + language: Language, + tags_query: *const u8, + locals_query: 
*const u8, + tags_query_len: u32, + locals_query_len: u32, +) -> TSTagsError { + let tagger = unwrap_mut_ptr(this); + let scope_name = unsafe { unwrap(CStr::from_ptr(scope_name).to_str()) }; + let tags_query = unsafe { slice::from_raw_parts(tags_query, tags_query_len as usize) }; + let locals_query = unsafe { slice::from_raw_parts(locals_query, locals_query_len as usize) }; + let tags_query = match str::from_utf8(tags_query) { + Ok(e) => e, + Err(_) => return TSTagsError::InvalidUtf8, + }; + let locals_query = match str::from_utf8(locals_query) { + Ok(e) => e, + Err(_) => return TSTagsError::InvalidUtf8, + }; + + match TagsConfiguration::new(language, tags_query, locals_query) { + Ok(c) => { + tagger.languages.insert(scope_name.to_string(), c); + TSTagsError::Ok + } + Err(Error::Query(_)) => TSTagsError::InvalidQuery, + Err(Error::Regex(_)) => TSTagsError::InvalidRegex, + Err(_) => TSTagsError::Unknown, + } +} + +#[no_mangle] +pub extern "C" fn ts_tagger_tag( + this: *mut TSTagger, + scope_name: *const i8, + source_code: *const u8, + source_code_len: u32, + output: *mut TSTagsBuffer, + cancellation_flag: *const AtomicUsize, +) -> TSTagsError { + let tagger = unwrap_mut_ptr(this); + let buffer = unwrap_mut_ptr(output); + let scope_name = unsafe { unwrap(CStr::from_ptr(scope_name).to_str()) }; + + if let Some(config) = tagger.languages.get(scope_name) { + buffer.tags.clear(); + buffer.docs.clear(); + let source_code = unsafe { slice::from_raw_parts(source_code, source_code_len as usize) }; + let cancellation_flag = unsafe { cancellation_flag.as_ref() }; + + let tags = match buffer + .context + .generate_tags(config, source_code, cancellation_flag) + { + Ok(tags) => tags, + Err(e) => { + return match e { + Error::InvalidLanguage => TSTagsError::InvalidLanguage, + Error::Cancelled => TSTagsError::Timeout, + _ => TSTagsError::Timeout, + } + } + }; + + for tag in tags { + let tag = if let Ok(tag) = tag { + tag + } else { + buffer.tags.clear(); + buffer.docs.clear(); + 
return TSTagsError::Timeout; + }; + + let prev_docs_len = buffer.docs.len(); + if let Some(docs) = tag.docs { + buffer.docs.extend_from_slice(docs.as_bytes()); + } + buffer.tags.push(TSTag { + kind: match tag.kind { + TagKind::Function => TSTagKind::Function, + TagKind::Method => TSTagKind::Method, + TagKind::Class => TSTagKind::Class, + TagKind::Module => TSTagKind::Module, + TagKind::Call => TSTagKind::Call, + }, + start_byte: tag.range.start as u32, + end_byte: tag.range.end as u32, + name_start_byte: tag.name_range.start as u32, + name_end_byte: tag.name_range.end as u32, + line_start_byte: tag.line_range.start as u32, + line_end_byte: tag.line_range.end as u32, + start_point: TSPoint { + row: tag.span.start.row as u32, + column: tag.span.start.column as u32, + }, + end_point: TSPoint { + row: tag.span.end.row as u32, + column: tag.span.end.column as u32, + }, + docs_start_byte: prev_docs_len as u32, + docs_end_byte: buffer.docs.len() as u32, + }); + } + + TSTagsError::Ok + } else { + TSTagsError::UnknownScope + } +} + +#[no_mangle] +pub extern "C" fn ts_tags_buffer_new() -> *mut TSTagsBuffer { + Box::into_raw(Box::new(TSTagsBuffer { + context: TagsContext::new(), + tags: Vec::with_capacity(64), + docs: Vec::with_capacity(64), + })) +} + +#[no_mangle] +pub extern "C" fn ts_tags_buffer_delete(this: *mut TSTagsBuffer) { + drop(unsafe { Box::from_raw(this) }) +} + +#[no_mangle] +pub extern "C" fn ts_tags_buffer_tags(this: *const TSTagsBuffer) -> *const TSTag { + let buffer = unwrap_ptr(this); + buffer.tags.as_ptr() +} + +#[no_mangle] +pub extern "C" fn ts_tags_buffer_tags_len(this: *const TSTagsBuffer) -> u32 { + let buffer = unwrap_ptr(this); + buffer.tags.len() as u32 +} + +#[no_mangle] +pub extern "C" fn ts_tags_buffer_docs(this: *const TSTagsBuffer) -> *const i8 { + let buffer = unwrap_ptr(this); + buffer.docs.as_ptr() as *const i8 +} + +#[no_mangle] +pub extern "C" fn ts_tags_buffer_docs_len(this: *const TSTagsBuffer) -> u32 { + let buffer = unwrap_ptr(this); 
+ buffer.docs.len() as u32
+}
+
+fn unwrap_ptr<'a, T>(result: *const T) -> &'a T {
+ unsafe { result.as_ref() }.unwrap_or_else(|| {
+ eprintln!("{}:{} - pointer must not be null", file!(), line!());
+ abort();
+ })
+}
+
+fn unwrap_mut_ptr<'a, T>(result: *mut T) -> &'a mut T {
+ unsafe { result.as_mut() }.unwrap_or_else(|| {
+ eprintln!("{}:{} - pointer must not be null", file!(), line!());
+ abort();
+ })
+}
+
+fn unwrap<T, E: fmt::Display>(result: Result<T, E>) -> T {
+ result.unwrap_or_else(|error| {
+ eprintln!("tree-sitter tag error: {}", error);
+ abort();
+ })
+}
diff --git a/tags/src/lib.rs b/tags/src/lib.rs
new file mode 100644
index 00000000..8d1853bb
--- /dev/null
+++ b/tags/src/lib.rs
@@ -0,0 +1,499 @@
+pub mod c_lib;
+
+use memchr::{memchr, memrchr};
+use regex::Regex;
+use std::ops::Range;
+use std::sync::atomic::{AtomicUsize, Ordering};
+use std::{fmt, mem, str};
+use tree_sitter::{
+ Language, Parser, Point, Query, QueryCursor, QueryError, QueryPredicateArg, Tree,
+};
+
+const MAX_LINE_LEN: usize = 180;
+const CANCELLATION_CHECK_INTERVAL: usize = 100;
+
+/// Contains the data needed to compute tags for code written in a
+/// particular language. 
+#[derive(Debug)] +pub struct TagsConfiguration { + pub language: Language, + pub query: Query, + call_capture_index: Option, + class_capture_index: Option, + doc_capture_index: Option, + function_capture_index: Option, + method_capture_index: Option, + module_capture_index: Option, + name_capture_index: Option, + local_scope_capture_index: Option, + local_definition_capture_index: Option, + tags_pattern_index: usize, + pattern_info: Vec, +} + +pub struct TagsContext { + parser: Parser, + cursor: QueryCursor, +} + +#[derive(Debug, Clone)] +pub struct Tag { + pub kind: TagKind, + pub range: Range, + pub name_range: Range, + pub line_range: Range, + pub span: Range, + pub docs: Option, +} + +#[derive(Copy, Clone, Debug, PartialEq, Eq)] +pub enum TagKind { + Function, + Method, + Class, + Module, + Call, +} + +#[derive(Debug, PartialEq)] +pub enum Error { + Query(QueryError), + Regex(regex::Error), + Cancelled, + InvalidLanguage, +} + +#[derive(Debug, Default)] +struct PatternInfo { + docs_adjacent_capture: Option, + local_scope_inherits: bool, + name_must_be_non_local: bool, + doc_strip_regex: Option, +} + +#[derive(Debug)] +struct LocalDef<'a> { + name: &'a [u8], + value_range: Range, +} + +#[derive(Debug)] +struct LocalScope<'a> { + inherits: bool, + range: Range, + local_defs: Vec>, +} + +struct TagsIter<'a, I> +where + I: Iterator>, +{ + matches: I, + _tree: Tree, + source: &'a [u8], + config: &'a TagsConfiguration, + cancellation_flag: Option<&'a AtomicUsize>, + iter_count: usize, + tag_queue: Vec<(Tag, usize)>, + scopes: Vec>, +} + +impl TagsConfiguration { + pub fn new(language: Language, tags_query: &str, locals_query: &str) -> Result { + let query = Query::new(language, &format!("{}{}", locals_query, tags_query))?; + + let tags_query_offset = locals_query.len(); + let mut tags_pattern_index = 0; + for i in 0..(query.pattern_count()) { + let pattern_offset = query.start_byte_for_pattern(i); + if pattern_offset < tags_query_offset { + tags_pattern_index += 1; 
+ } + } + + let mut call_capture_index = None; + let mut class_capture_index = None; + let mut doc_capture_index = None; + let mut function_capture_index = None; + let mut method_capture_index = None; + let mut module_capture_index = None; + let mut name_capture_index = None; + let mut local_scope_capture_index = None; + let mut local_definition_capture_index = None; + for (i, name) in query.capture_names().iter().enumerate() { + let index = match name.as_str() { + "call" => &mut call_capture_index, + "class" => &mut class_capture_index, + "doc" => &mut doc_capture_index, + "function" => &mut function_capture_index, + "method" => &mut method_capture_index, + "module" => &mut module_capture_index, + "name" => &mut name_capture_index, + "local.scope" => &mut local_scope_capture_index, + "local.definition" => &mut local_definition_capture_index, + _ => continue, + }; + *index = Some(i as u32); + } + + let pattern_info = (0..query.pattern_count()) + .map(|pattern_index| { + let mut info = PatternInfo::default(); + for (property, is_positive) in query.property_predicates(pattern_index) { + if !is_positive && property.key.as_ref() == "local" { + info.name_must_be_non_local = true; + } + } + info.local_scope_inherits = true; + for property in query.property_settings(pattern_index) { + if property.key.as_ref() == "local.scope-inherits" + && property + .value + .as_ref() + .map_or(false, |v| v.as_ref() == "false") + { + info.local_scope_inherits = false; + } + } + if let Some(doc_capture_index) = doc_capture_index { + for predicate in query.general_predicates(pattern_index) { + if predicate.args.get(0) + == Some(&QueryPredicateArg::Capture(doc_capture_index)) + { + match (predicate.operator.as_ref(), predicate.args.get(1)) { + ("select-adjacent!", Some(QueryPredicateArg::Capture(index))) => { + info.docs_adjacent_capture = Some(*index); + } + ("strip!", Some(QueryPredicateArg::String(pattern))) => { + let regex = Regex::new(pattern.as_ref())?; + info.doc_strip_regex = 
Some(regex); + } + _ => {} + } + } + } + } + return Ok(info); + }) + .collect::, Error>>()?; + + Ok(TagsConfiguration { + language, + query, + function_capture_index, + class_capture_index, + method_capture_index, + module_capture_index, + doc_capture_index, + call_capture_index, + name_capture_index, + tags_pattern_index, + local_scope_capture_index, + local_definition_capture_index, + pattern_info, + }) + } +} + +impl TagsContext { + pub fn new() -> Self { + TagsContext { + parser: Parser::new(), + cursor: QueryCursor::new(), + } + } + + pub fn generate_tags<'a>( + &'a mut self, + config: &'a TagsConfiguration, + source: &'a [u8], + cancellation_flag: Option<&'a AtomicUsize>, + ) -> Result> + 'a, Error> { + self.parser + .set_language(config.language) + .map_err(|_| Error::InvalidLanguage)?; + self.parser.reset(); + unsafe { self.parser.set_cancellation_flag(cancellation_flag) }; + let tree = self.parser.parse(source, None).ok_or(Error::Cancelled)?; + + // The `matches` iterator borrows the `Tree`, which prevents it from being moved. + // But the tree is really just a pointer, so it's actually ok to move it. + let tree_ref = unsafe { mem::transmute::<_, &'static Tree>(&tree) }; + let matches = self + .cursor + .matches(&config.query, tree_ref.root_node(), move |node| { + &source[node.byte_range()] + }); + Ok(TagsIter { + _tree: tree, + matches, + source, + config, + cancellation_flag, + tag_queue: Vec::new(), + iter_count: 0, + scopes: vec![LocalScope { + range: 0..source.len(), + inherits: false, + local_defs: Vec::new(), + }], + }) + } +} + +impl<'a, I> Iterator for TagsIter<'a, I> +where + I: Iterator>, +{ + type Item = Result; + + fn next(&mut self) -> Option { + loop { + // Periodically check for cancellation, returning `Cancelled` error if the + // cancellation flag was flipped. 
+ if let Some(cancellation_flag) = self.cancellation_flag { + self.iter_count += 1; + if self.iter_count >= CANCELLATION_CHECK_INTERVAL { + self.iter_count = 0; + if cancellation_flag.load(Ordering::Relaxed) != 0 { + return Some(Err(Error::Cancelled)); + } + } + } + + // If there is a queued tag for an earlier node in the syntax tree, then pop + // it off of the queue and return it. + if let Some(last_entry) = self.tag_queue.last() { + if self.tag_queue.len() > 1 + && self.tag_queue[0].0.name_range.end < last_entry.0.name_range.start + { + return Some(Ok(self.tag_queue.remove(0).0)); + } + } + + // If there is another match, then compute its tag and add it to the + // tag queue. + if let Some(mat) = self.matches.next() { + let pattern_info = &self.config.pattern_info[mat.pattern_index]; + + if mat.pattern_index < self.config.tags_pattern_index { + for capture in mat.captures { + let index = Some(capture.index); + let range = capture.node.byte_range(); + if index == self.config.local_scope_capture_index { + self.scopes.push(LocalScope { + range, + inherits: pattern_info.local_scope_inherits, + local_defs: Vec::new(), + }); + } else if index == self.config.local_definition_capture_index { + if let Some(scope) = self.scopes.iter_mut().rev().find(|scope| { + scope.range.start <= range.start && scope.range.end >= range.end + }) { + scope.local_defs.push(LocalDef { + name: &self.source[range.clone()], + value_range: range, + }); + } + } + } + continue; + } + + let mut name_range = None; + let mut doc_nodes = Vec::new(); + let mut tag_node = None; + let mut kind = TagKind::Call; + let mut docs_adjacent_node = None; + + for capture in mat.captures { + let index = Some(capture.index); + + if index == self.config.pattern_info[mat.pattern_index].docs_adjacent_capture { + docs_adjacent_node = Some(capture.node); + } + + if index == self.config.name_capture_index { + name_range = Some(capture.node.byte_range()); + } else if index == self.config.doc_capture_index { + 
doc_nodes.push(capture.node); + } else if index == self.config.call_capture_index { + tag_node = Some(capture.node); + kind = TagKind::Call; + } else if index == self.config.class_capture_index { + tag_node = Some(capture.node); + kind = TagKind::Class; + } else if index == self.config.function_capture_index { + tag_node = Some(capture.node); + kind = TagKind::Function; + } else if index == self.config.method_capture_index { + tag_node = Some(capture.node); + kind = TagKind::Method; + } else if index == self.config.module_capture_index { + tag_node = Some(capture.node); + kind = TagKind::Module; + } + } + + if let (Some(tag_node), Some(name_range)) = (tag_node, name_range) { + if pattern_info.name_must_be_non_local { + let mut is_local = false; + for scope in self.scopes.iter().rev() { + if scope.range.start <= name_range.start + && scope.range.end >= name_range.end + { + if scope + .local_defs + .iter() + .any(|d| d.name == &self.source[name_range.clone()]) + { + is_local = true; + break; + } + if !scope.inherits { + break; + } + } + } + if is_local { + continue; + } + } + + // If needed, filter the doc nodes based on their ranges, selecting + // only the slice that are adjacent to some specified node. + let mut docs_start_index = 0; + if let (Some(docs_adjacent_node), false) = + (docs_adjacent_node, doc_nodes.is_empty()) + { + docs_start_index = doc_nodes.len(); + let mut start_row = docs_adjacent_node.start_position().row; + while docs_start_index > 0 { + let doc_node = &doc_nodes[docs_start_index - 1]; + let prev_doc_end_row = doc_node.end_position().row; + if prev_doc_end_row + 1 >= start_row { + docs_start_index -= 1; + start_row = doc_node.start_position().row; + } else { + break; + } + } + } + + // Generate a doc string from all of the doc nodes, applying any strip regexes. + let mut docs = None; + for doc_node in &doc_nodes[docs_start_index..] 
{ + if let Ok(content) = str::from_utf8(&self.source[doc_node.byte_range()]) { + let content = if let Some(regex) = &pattern_info.doc_strip_regex { + regex.replace_all(content, "").to_string() + } else { + content.to_string() + }; + match &mut docs { + None => docs = Some(content), + Some(d) => { + d.push('\n'); + d.push_str(&content); + } + } + } + } + + // Only create one tag per node. The tag queue is sorted by node position + // to allow for fast lookup. + let range = tag_node.byte_range(); + match self + .tag_queue + .binary_search_by_key(&(name_range.end, name_range.start), |(tag, _)| { + (tag.name_range.end, tag.name_range.start) + }) { + Ok(i) => { + let (tag, pattern_index) = &mut self.tag_queue[i]; + if *pattern_index > mat.pattern_index { + *pattern_index = mat.pattern_index; + *tag = Tag { + line_range: line_range(self.source, range.start, MAX_LINE_LEN), + span: tag_node.start_position()..tag_node.end_position(), + kind, + range, + name_range, + docs, + }; + } + } + Err(i) => self.tag_queue.insert( + i, + ( + Tag { + line_range: line_range(self.source, range.start, MAX_LINE_LEN), + span: tag_node.start_position()..tag_node.end_position(), + kind, + range, + name_range, + docs, + }, + mat.pattern_index, + ), + ), + } + } + } + // If there are no more matches, then drain the queue. 
+ else if !self.tag_queue.is_empty() { + return Some(Ok(self.tag_queue.remove(0).0)); + } else { + return None; + } + } + } +} + +impl fmt::Display for TagKind { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + TagKind::Call => "Call", + TagKind::Module => "Module", + TagKind::Class => "Class", + TagKind::Method => "Method", + TagKind::Function => "Function", + } + .fmt(f) + } +} + +impl From for Error { + fn from(error: regex::Error) -> Self { + Error::Regex(error) + } +} + +impl From for Error { + fn from(error: QueryError) -> Self { + Error::Query(error) + } +} + +fn line_range(text: &[u8], index: usize, max_line_len: usize) -> Range { + let start = memrchr(b'\n', &text[0..index]).map_or(0, |i| i + 1); + let max_line_len = max_line_len.min(text.len() - start); + let end = start + memchr(b'\n', &text[start..(start + max_line_len)]).unwrap_or(max_line_len); + start..end +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_get_line() { + let text = b"abc\ndefg\nhijkl"; + assert_eq!(line_range(text, 0, 10), 0..3); + assert_eq!(line_range(text, 1, 10), 0..3); + assert_eq!(line_range(text, 2, 10), 0..3); + assert_eq!(line_range(text, 3, 10), 0..3); + assert_eq!(line_range(text, 1, 2), 0..2); + assert_eq!(line_range(text, 4, 10), 4..8); + assert_eq!(line_range(text, 5, 10), 4..8); + assert_eq!(line_range(text, 11, 10), 9..14); + } +}