Merge pull request #583 from tree-sitter/tags

Add a 'tags' crate, for computing ctags-style code navigation tags
This commit is contained in:
Max Brunsfeld 2020-04-03 11:20:51 -07:00 committed by GitHub
commit 21175142af
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
21 changed files with 1952 additions and 223 deletions

32
Cargo.lock generated
View file

@ -5,7 +5,7 @@ name = "aho-corasick"
version = "0.6.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
dependencies = [
"memchr 2.1.1 (registry+https://github.com/rust-lang/crates.io-index)",
"memchr 2.3.3 (registry+https://github.com/rust-lang/crates.io-index)",
]
[[package]]
@ -303,13 +303,8 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
[[package]]
name = "memchr"
version = "2.1.1"
version = "2.3.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
dependencies = [
"cfg-if 0.1.6 (registry+https://github.com/rust-lang/crates.io-index)",
"libc 0.2.61 (registry+https://github.com/rust-lang/crates.io-index)",
"version_check 0.1.5 (registry+https://github.com/rust-lang/crates.io-index)",
]
[[package]]
name = "nodrop"
@ -536,7 +531,7 @@ version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
dependencies = [
"aho-corasick 0.6.9 (registry+https://github.com/rust-lang/crates.io-index)",
"memchr 2.1.1 (registry+https://github.com/rust-lang/crates.io-index)",
"memchr 2.3.3 (registry+https://github.com/rust-lang/crates.io-index)",
"regex-syntax 0.6.4 (registry+https://github.com/rust-lang/crates.io-index)",
"thread_local 0.3.6 (registry+https://github.com/rust-lang/crates.io-index)",
"utf8-ranges 1.0.2 (registry+https://github.com/rust-lang/crates.io-index)",
@ -769,6 +764,7 @@ dependencies = [
"tiny_http 0.6.2 (registry+https://github.com/rust-lang/crates.io-index)",
"tree-sitter 0.6.3",
"tree-sitter-highlight 0.1.6",
"tree-sitter-tags 0.1.6",
"webbrowser 0.5.1 (registry+https://github.com/rust-lang/crates.io-index)",
]
@ -777,9 +773,15 @@ name = "tree-sitter-highlight"
version = "0.1.6"
dependencies = [
"regex 1.1.0 (registry+https://github.com/rust-lang/crates.io-index)",
"serde 1.0.80 (registry+https://github.com/rust-lang/crates.io-index)",
"serde_derive 1.0.80 (registry+https://github.com/rust-lang/crates.io-index)",
"serde_json 1.0.33 (registry+https://github.com/rust-lang/crates.io-index)",
"tree-sitter 0.6.3",
]
[[package]]
name = "tree-sitter-tags"
version = "0.1.6"
dependencies = [
"memchr 2.3.3 (registry+https://github.com/rust-lang/crates.io-index)",
"regex 1.1.0 (registry+https://github.com/rust-lang/crates.io-index)",
"tree-sitter 0.6.3",
]
@ -842,11 +844,6 @@ name = "vec_map"
version = "0.8.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
[[package]]
name = "version_check"
version = "0.1.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
[[package]]
name = "void"
version = "1.0.2"
@ -926,7 +923,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
"checksum lock_api 0.1.5 (registry+https://github.com/rust-lang/crates.io-index)" = "62ebf1391f6acad60e5c8b43706dde4582df75c06698ab44511d15016bc2442c"
"checksum log 0.4.6 (registry+https://github.com/rust-lang/crates.io-index)" = "c84ec4b527950aa83a329754b01dbe3f58361d1c5efacd1f6d68c494d08a17c6"
"checksum matches 0.1.8 (registry+https://github.com/rust-lang/crates.io-index)" = "7ffc5c5338469d4d3ea17d269fa8ea3512ad247247c30bd2df69e68309ed0a08"
"checksum memchr 2.1.1 (registry+https://github.com/rust-lang/crates.io-index)" = "0a3eb002f0535929f1199681417029ebea04aadc0c7a4224b46be99c7f5d6a16"
"checksum memchr 2.3.3 (registry+https://github.com/rust-lang/crates.io-index)" = "3728d817d99e5ac407411fa471ff9800a778d88a24685968b36824eaf4bee400"
"checksum nodrop 0.1.13 (registry+https://github.com/rust-lang/crates.io-index)" = "2f9667ddcc6cc8a43afc9b7917599d7216aa09c463919ea32c59ed6cac8bc945"
"checksum num-integer 0.1.39 (registry+https://github.com/rust-lang/crates.io-index)" = "e83d528d2677f0518c570baf2b7abdcf0cd2d248860b68507bdcb3e91d4c0cea"
"checksum num-traits 0.2.6 (registry+https://github.com/rust-lang/crates.io-index)" = "0b3a5d7cc97d6d30d8b9bc8fa19bf45349ffe46241e8816f50f62f6d6aaabee1"
@ -987,7 +984,6 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
"checksum url 1.7.2 (registry+https://github.com/rust-lang/crates.io-index)" = "dd4e7c0d531266369519a4aa4f399d748bd37043b00bde1e4ff1f60a120b355a"
"checksum utf8-ranges 1.0.2 (registry+https://github.com/rust-lang/crates.io-index)" = "796f7e48bef87609f7ade7e06495a87d5cd06c7866e6a5cbfceffc558a243737"
"checksum vec_map 0.8.1 (registry+https://github.com/rust-lang/crates.io-index)" = "05c78687fb1a80548ae3250346c3db86a80a7cdd77bda190189f2d0a0987c81a"
"checksum version_check 0.1.5 (registry+https://github.com/rust-lang/crates.io-index)" = "914b1a6776c4c929a602fafd8bc742e06365d4bcbe48c30f9cca5824f70dc9dd"
"checksum void 1.0.2 (registry+https://github.com/rust-lang/crates.io-index)" = "6a02e4885ed3bc0f2de90ea6dd45ebcbb66dacffe03547fadbb0eeae2770887d"
"checksum webbrowser 0.5.1 (registry+https://github.com/rust-lang/crates.io-index)" = "c01efd7cb6939b7f34983f1edff0550e5b21b49e2db4495656295922df8939ac"
"checksum widestring 0.4.0 (registry+https://github.com/rust-lang/crates.io-index)" = "effc0e4ff8085673ea7b9b2e3c73f6bd4d118810c9009ed8f1e16bd96c331db6"

View file

@ -44,6 +44,10 @@ path = "../lib"
version = ">= 0.1.0"
path = "../highlight"
[dependencies.tree-sitter-tags]
version = ">= 0.1.0"
path = "../tags"
[dependencies.serde_json]
version = "1.0"
features = ["preserve_order"]

View file

@ -81,6 +81,12 @@ impl<'a> From<tree_sitter_highlight::Error> for Error {
}
}
impl<'a> From<tree_sitter_tags::Error> for Error {
fn from(error: tree_sitter_tags::Error) -> Self {
Error::new(format!("{:?}", error))
}
}
impl From<serde_json::Error> for Error {
fn from(error: serde_json::Error) -> Self {
Error::new(error.to_string())

View file

@ -325,12 +325,13 @@ impl Generator {
add_line!(self, "static TSSymbol ts_symbol_map[] = {{");
indent!(self);
for symbol in &self.parse_table.symbols {
let mut mapping = symbol;
// There can be multiple symbols in the grammar that have the same name and kind,
// due to simple aliases. When that happens, ensure that they map to the same
// public-facing symbol. If one of the symbols is not aliased, choose that one
// to be the public-facing symbol. Otherwise, pick the symbol with the lowest
// numeric value.
let mut mapping = symbol;
if let Some(alias) = self.simple_aliases.get(symbol) {
let kind = alias.kind();
for other_symbol in &self.parse_table.symbols {
@ -344,6 +345,20 @@ impl Generator {
}
}
}
// Two anonymous tokens with different flags but the same string value
// should be represented with the same symbol in the public API. Examples:
// * "<" and token(prec(1, "<"))
// * "(" and token.immediate("(")
else if symbol.is_terminal() {
let metadata = self.metadata_for_symbol(*symbol);
for other_symbol in &self.parse_table.symbols {
let other_metadata = self.metadata_for_symbol(*other_symbol);
if other_metadata == metadata {
mapping = other_symbol;
break;
}
}
}
add_line!(
self,

View file

@ -1,3 +1,4 @@
use super::util;
use crate::error::Result;
use crate::loader::Loader;
use ansi_term::Color;
@ -6,10 +7,8 @@ use serde::ser::SerializeMap;
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use serde_json::{json, Value};
use std::collections::HashMap;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use std::time::Instant;
use std::{fs, io, path, str, thread, usize};
use std::{fs, io, path, str, usize};
use tree_sitter_highlight::{HighlightConfiguration, HighlightEvent, Highlighter, HtmlRenderer};
pub const HTML_HEADER: &'static str = "
@ -273,19 +272,6 @@ fn color_to_css(color: Color) -> &'static str {
}
}
// Spawn a background thread that blocks reading one line from stdin and then
// sets the returned flag to 1. Long-running highlight jobs poll this flag so
// the user can cancel by pressing enter.
fn cancel_on_stdin() -> Arc<AtomicUsize> {
let result = Arc::new(AtomicUsize::new(0));
thread::spawn({
// Clone the Arc so the thread owns one handle and the caller keeps the other.
let flag = result.clone();
move || {
let mut line = String::new();
io::stdin().read_line(&mut line).unwrap();
// Relaxed is sufficient: pollers only need to eventually observe the store.
flag.store(1, Ordering::Relaxed);
}
});
result
}
pub fn ansi(
loader: &Loader,
theme: &Theme,
@ -296,7 +282,7 @@ pub fn ansi(
let stdout = io::stdout();
let mut stdout = stdout.lock();
let time = Instant::now();
let cancellation_flag = cancel_on_stdin();
let cancellation_flag = util::cancel_on_stdin();
let mut highlighter = Highlighter::new();
let events = highlighter.highlight(config, source, Some(&cancellation_flag), |string| {
@ -341,7 +327,7 @@ pub fn html(
let stdout = io::stdout();
let mut stdout = stdout.lock();
let time = Instant::now();
let cancellation_flag = cancel_on_stdin();
let cancellation_flag = util::cancel_on_stdin();
let mut highlighter = Highlighter::new();
let events = highlighter.highlight(config, source, Some(&cancellation_flag), |string| {

View file

@ -6,6 +6,7 @@ pub mod loader;
pub mod logger;
pub mod parse;
pub mod query;
pub mod tags;
pub mod test;
pub mod test_highlight;
pub mod util;

View file

@ -12,6 +12,7 @@ use std::time::SystemTime;
use std::{fs, mem};
use tree_sitter::Language;
use tree_sitter_highlight::HighlightConfiguration;
use tree_sitter_tags::TagsConfiguration;
#[cfg(unix)]
const DYLIB_EXTENSION: &'static str = "so";
@ -31,8 +32,10 @@ pub struct LanguageConfiguration<'a> {
pub highlights_filenames: Option<Vec<String>>,
pub injections_filenames: Option<Vec<String>>,
pub locals_filenames: Option<Vec<String>>,
pub tags_filenames: Option<Vec<String>>,
language_id: usize,
highlight_config: OnceCell<Option<HighlightConfiguration>>,
tags_config: OnceCell<Option<TagsConfiguration>>,
highlight_names: &'a Mutex<Vec<String>>,
use_all_highlight_names: bool,
}
@ -432,6 +435,8 @@ impl Loader {
injections: PathsJSON,
#[serde(default)]
locals: PathsJSON,
#[serde(default)]
tags: PathsJSON,
}
#[derive(Deserialize)]
@ -479,8 +484,10 @@ impl Loader {
injection_regex: Self::regex(config_json.injection_regex),
injections_filenames: config_json.injections.into_vec(),
locals_filenames: config_json.locals.into_vec(),
tags_filenames: config_json.tags.into_vec(),
highlights_filenames: config_json.highlights.into_vec(),
highlight_config: OnceCell::new(),
tags_config: OnceCell::new(),
highlight_names: &*self.highlight_names,
use_all_highlight_names: self.use_all_highlight_names,
};
@ -512,7 +519,9 @@ impl Loader {
injections_filenames: None,
locals_filenames: None,
highlights_filenames: None,
tags_filenames: None,
highlight_config: OnceCell::new(),
tags_config: OnceCell::new(),
highlight_names: &*self.highlight_names,
use_all_highlight_names: self.use_all_highlight_names,
};
@ -534,32 +543,11 @@ impl<'a> LanguageConfiguration<'a> {
pub fn highlight_config(&self, language: Language) -> Result<Option<&HighlightConfiguration>> {
self.highlight_config
.get_or_try_init(|| {
let queries_path = self.root_path.join("queries");
let read_queries = |paths: &Option<Vec<String>>, default_path: &str| {
if let Some(paths) = paths.as_ref() {
let mut query = String::new();
for path in paths {
let path = self.root_path.join(path);
query += &fs::read_to_string(&path).map_err(Error::wrap(|| {
format!("Failed to read query file {:?}", path)
}))?;
}
Ok(query)
} else {
let path = queries_path.join(default_path);
if path.exists() {
fs::read_to_string(&path).map_err(Error::wrap(|| {
format!("Failed to read query file {:?}", path)
}))
} else {
Ok(String::new())
}
}
};
let highlights_query = read_queries(&self.highlights_filenames, "highlights.scm")?;
let injections_query = read_queries(&self.injections_filenames, "injections.scm")?;
let locals_query = read_queries(&self.locals_filenames, "locals.scm")?;
let highlights_query =
self.read_queries(&self.highlights_filenames, "highlights.scm")?;
let injections_query =
self.read_queries(&self.injections_filenames, "injections.scm")?;
let locals_query = self.read_queries(&self.locals_filenames, "locals.scm")?;
if highlights_query.is_empty() {
Ok(None)
@ -587,6 +575,47 @@ impl<'a> LanguageConfiguration<'a> {
})
.map(Option::as_ref)
}
// Lazily load and cache this language's tags configuration.
// Reads the configured tags/locals query files (or the default `tags.scm` /
// `locals.scm` under the grammar's `queries` directory). Returns `Ok(None)`
// when the language has no tags query at all; the result is computed at most
// once via the `OnceCell`.
pub fn tags_config(&self, language: Language) -> Result<Option<&TagsConfiguration>> {
self.tags_config
.get_or_try_init(|| {
let tags_query = self.read_queries(&self.tags_filenames, "tags.scm")?;
let locals_query = self.read_queries(&self.locals_filenames, "locals.scm")?;
// No tags query means tagging is simply unsupported for this language.
if tags_query.is_empty() {
Ok(None)
} else {
TagsConfiguration::new(language, &tags_query, &locals_query)
.map_err(Error::wrap(|| {
format!("Failed to load queries in {:?}", self.root_path)
}))
.map(|config| Some(config))
}
})
.map(Option::as_ref)
}
fn read_queries(&self, paths: &Option<Vec<String>>, default_path: &str) -> Result<String> {
if let Some(paths) = paths.as_ref() {
let mut query = String::new();
for path in paths {
let path = self.root_path.join(path);
query += &fs::read_to_string(&path).map_err(Error::wrap(|| {
format!("Failed to read query file {:?}", path)
}))?;
}
Ok(query)
} else {
let queries_path = self.root_path.join("queries");
let path = queries_path.join(default_path);
if path.exists() {
fs::read_to_string(&path).map_err(Error::wrap(|| {
format!("Failed to read query file {:?}", path)
}))
} else {
Ok(String::new())
}
}
}
}
fn needs_recompile(

View file

@ -6,8 +6,8 @@ use std::process::exit;
use std::{env, fs, u64};
use tree_sitter::Language;
use tree_sitter_cli::{
config, error, generate, highlight, loader, logger, parse, query, test, test_highlight, wasm,
web_ui,
config, error, generate, highlight, loader, logger, parse, query, tags, test, test_highlight,
wasm, web_ui,
};
const BUILD_VERSION: &'static str = env!("CARGO_PKG_VERSION");
@ -88,6 +88,30 @@ fn run() -> error::Result<()> {
.arg(Arg::with_name("scope").long("scope").takes_value(true))
.arg(Arg::with_name("captures").long("captures").short("c")),
)
.subcommand(
SubCommand::with_name("tags")
.arg(
Arg::with_name("format")
.short("f")
.long("format")
.value_name("json|protobuf")
.help("Determine output format (default: json)"),
)
.arg(Arg::with_name("scope").long("scope").takes_value(true))
.arg(
Arg::with_name("inputs")
.help("The source file to use")
.index(1)
.required(true)
.multiple(true),
)
.arg(
Arg::with_name("v")
.short("v")
.multiple(true)
.help("Sets the level of verbosity"),
),
)
.subcommand(
SubCommand::with_name("test")
.about("Run a parser's tests")
@ -240,6 +264,10 @@ fn run() -> error::Result<()> {
)?;
let query_path = Path::new(matches.value_of("query-path").unwrap());
query::query_files_at_paths(language, paths, query_path, ordered_captures)?;
} else if let Some(matches) = matches.subcommand_matches("tags") {
loader.find_all_languages(&config.parser_directories)?;
let paths = collect_paths(matches.values_of("inputs").unwrap())?;
tags::generate_tags(&loader, matches.value_of("scope"), &paths)?;
} else if let Some(matches) = matches.subcommand_matches("highlight") {
loader.configure_highlights(&config.theme.highlight_names);
loader.find_all_languages(&config.parser_directories)?;
@ -251,19 +279,17 @@ fn run() -> error::Result<()> {
println!("{}", highlight::HTML_HEADER);
}
let language_config;
let mut lang = None;
if let Some(scope) = matches.value_of("scope") {
language_config = loader.language_configuration_for_scope(scope)?;
if language_config.is_none() {
lang = loader.language_configuration_for_scope(scope)?;
if lang.is_none() {
return Error::err(format!("Unknown scope '{}'", scope));
}
} else {
language_config = None;
}
for path in paths {
let path = Path::new(&path);
let (language, language_config) = match language_config {
let (language, language_config) = match lang {
Some(v) => v,
None => match loader.language_configuration_for_file_name(path)? {
Some(v) => v,
@ -274,23 +300,21 @@ fn run() -> error::Result<()> {
},
};
let source = fs::read(path)?;
if let Some(highlight_config) = language_config.highlight_config(language)? {
let source = fs::read(path)?;
if html_mode {
highlight::html(&loader, &config.theme, &source, highlight_config, time)?;
} else {
highlight::ansi(&loader, &config.theme, &source, highlight_config, time)?;
}
} else {
return Error::err(format!("No syntax highlighting query found"));
eprintln!("No syntax highlighting config found for path {:?}", path);
}
}
if html_mode {
println!("{}", highlight::HTML_FOOTER);
}
} else if let Some(matches) = matches.subcommand_matches("build-wasm") {
let grammar_path = current_dir.join(matches.value_of("path").unwrap_or(""));
wasm::compile_language_to_wasm(&grammar_path, matches.is_present("docker"))?;

66
cli/src/tags.rs Normal file
View file

@ -0,0 +1,66 @@
use super::loader::Loader;
use super::util;
use crate::error::{Error, Result};
use std::io::{self, Write};
use std::path::Path;
use std::{fs, str};
use tree_sitter_tags::TagsContext;
/// Generate and print code-navigation tags for each of the given source files.
///
/// When `scope` is provided, that language is used for every file (an unknown
/// scope is an error). Otherwise each file's language is detected from its
/// file name; files with no detected language or no tags configuration are
/// skipped with a message on stderr. Tagging can be cancelled by pressing
/// enter (see `util::cancel_on_stdin`).
pub fn generate_tags(loader: &Loader, scope: Option<&str>, paths: &[String]) -> Result<()> {
    let mut lang = None;
    if let Some(scope) = scope {
        lang = loader.language_configuration_for_scope(scope)?;
        if lang.is_none() {
            return Error::err(format!("Unknown scope '{}'", scope));
        }
    }

    let mut context = TagsContext::new();
    let cancellation_flag = util::cancel_on_stdin();
    let stdout = io::stdout();
    // Lock stdout once instead of per write.
    let mut stdout = stdout.lock();

    for path in paths {
        let path = Path::new(&path);
        let (language, language_config) = match lang {
            Some(v) => v,
            None => match loader.language_configuration_for_file_name(path)? {
                Some(v) => v,
                None => {
                    eprintln!("No language found for path {:?}", path);
                    continue;
                }
            },
        };

        if let Some(tags_config) = language_config.tags_config(language)? {
            // Fix: print the path with `display()` rather than slicing the
            // quotes off the `Debug` output (which also escapes backslashes).
            writeln!(&mut stdout, "{}", path.display())?;
            let source = fs::read(path)?;
            for tag in context.generate_tags(tags_config, &source, Some(&cancellation_flag))? {
                let tag = tag?;
                write!(
                    &mut stdout,
                    " {:<8} {:<40}\t{:>9}-{:<9}",
                    tag.kind,
                    str::from_utf8(&source[tag.name_range]).unwrap_or(""),
                    tag.span.start,
                    tag.span.end,
                )?;
                if let Some(docs) = tag.docs {
                    if docs.len() > 120 {
                        // Fix: back up to a char boundary before slicing so
                        // multi-byte UTF-8 in the docs can't cause a panic.
                        let mut end = 120;
                        while !docs.is_char_boundary(end) {
                            end -= 1;
                        }
                        write!(&mut stdout, "\t{:?}...", &docs[..end])?;
                    } else {
                        write!(&mut stdout, "\t{:?}", &docs)?;
                    }
                }
                // Fix: `writeln!(w)` instead of `writeln!(w, "")`.
                writeln!(&mut stdout)?;
            }
        } else {
            eprintln!("No tags config found for path {:?}", path);
        }
    }
    Ok(())
}

View file

@ -4,5 +4,6 @@ mod highlight_test;
mod node_test;
mod parser_test;
mod query_test;
mod tags_test;
mod test_highlight_test;
mod tree_test;

View file

@ -2,7 +2,8 @@ use super::helpers::allocations;
use super::helpers::fixtures::get_language;
use std::fmt::Write;
use tree_sitter::{
Node, Parser, Query, QueryCapture, QueryCursor, QueryError, QueryMatch, QueryProperty,
Node, Parser, Query, QueryCapture, QueryCursor, QueryError, QueryMatch, QueryPredicate,
QueryPredicateArg, QueryProperty,
};
#[test]
@ -438,6 +439,10 @@ fn test_query_matches_with_named_wildcard() {
fn test_query_matches_with_wildcard_at_the_root() {
allocations::record(|| {
let language = get_language("javascript");
let mut cursor = QueryCursor::new();
let mut parser = Parser::new();
parser.set_language(language).unwrap();
let query = Query::new(
language,
"
@ -452,20 +457,41 @@ fn test_query_matches_with_wildcard_at_the_root() {
let source = "/* one */ var x; /* two */ function y() {} /* three */ class Z {}";
let mut parser = Parser::new();
parser.set_language(language).unwrap();
let tree = parser.parse(source, None).unwrap();
let mut cursor = QueryCursor::new();
let matches = cursor.matches(&query, tree.root_node(), to_callback(source));
assert_eq!(
collect_matches(matches, &query, source),
&[(0, vec![("doc", "/* two */"), ("name", "y")]),]
);
let query = Query::new(
language,
"
(* (string) @a)
(* (number) @b)
(* (true) @c)
(* (false) @d)
",
)
.unwrap();
let source = "['hi', x(true), {y: false}]";
let tree = parser.parse(source, None).unwrap();
let matches = cursor.matches(&query, tree.root_node(), to_callback(source));
assert_eq!(
collect_matches(matches, &query, source),
&[
(0, vec![("a", "'hi'")]),
(2, vec![("c", "true")]),
(3, vec![("d", "false")]),
]
);
});
}
#[test]
fn test_query_with_immediate_siblings() {
fn test_query_matches_with_immediate_siblings() {
allocations::record(|| {
let language = get_language("python");
@ -515,6 +541,107 @@ fn test_query_with_immediate_siblings() {
});
}
// A `+` repetition on a leaf node (comment) should fold every run of adjacent
// comments that immediately precedes a declaration into a single match.
#[test]
fn test_query_matches_with_repeated_leaf_nodes() {
allocations::record(|| {
let language = get_language("javascript");
let query = Query::new(
language,
"
(*
(comment)+ @doc
.
(class_declaration
name: (identifier) @name))
(*
(comment)+ @doc
.
(function_declaration
name: (identifier) @name))
",
)
.unwrap();
let source = "
// one
// two
a();
// three
{
// four
// five
// six
class B {}
// seven
c();
// eight
function d() {}
}
";
let mut parser = Parser::new();
parser.set_language(language).unwrap();
let tree = parser.parse(source, None).unwrap();
let mut cursor = QueryCursor::new();
let matches = cursor.matches(&query, tree.root_node(), to_callback(source));
// Only comments directly adjacent to a class/function declaration match;
// comments followed by plain calls (one, two, three, seven) are ignored.
assert_eq!(
collect_matches(matches, &query, source),
&[
(
0,
vec![
("doc", "// four"),
("doc", "// five"),
("doc", "// six"),
("name", "B")
]
),
(1, vec![("doc", "// eight"), ("name", "d")]),
]
);
});
}
// A `+` repetition on an internal (non-leaf) node: each decorator of a method
// is captured separately within one match.
#[test]
fn test_query_matches_with_repeated_internal_nodes() {
allocations::record(|| {
let language = get_language("javascript");
let mut parser = Parser::new();
parser.set_language(language).unwrap();
let mut cursor = QueryCursor::new();
let query = Query::new(
language,
"
(*
(method_definition
(decorator (identifier) @deco)+
name: (property_identifier) @name))
",
)
.unwrap();
let source = "
class A {
@c
@d
e() {}
}
";
let tree = parser.parse(source, None).unwrap();
let matches = cursor.matches(&query, tree.root_node(), to_callback(source));
assert_eq!(
collect_matches(matches, &query, source),
&[(0, vec![("deco", "c"), ("deco", "d"), ("name", "e")]),]
);
})
}
#[test]
fn test_query_matches_in_language_with_simple_aliases() {
allocations::record(|| {
@ -550,6 +677,41 @@ fn test_query_matches_in_language_with_simple_aliases() {
});
}
// Distinct grammar tokens that share the same string value (e.g. the two '<'
// tokens in Rust) should be treated as the same symbol by queries, so a quoted
// literal pattern matches all of them.
#[test]
fn test_query_matches_with_different_tokens_with_the_same_string_value() {
allocations::record(|| {
let language = get_language("rust");
let query = Query::new(
language,
r#"
"<" @less
">" @greater
"#,
)
.unwrap();
// In Rust, there are two '<' tokens: one for the binary operator,
// and one with higher precedence for generics.
let source = "const A: B<C> = d < e || f > g;";
let mut parser = Parser::new();
parser.set_language(language).unwrap();
let tree = parser.parse(&source, None).unwrap();
let mut cursor = QueryCursor::new();
let matches = cursor.matches(&query, tree.root_node(), to_callback(source));
assert_eq!(
collect_matches(matches, &query, source),
&[
(0, vec![("less", "<")]),
(1, vec![("greater", ">")]),
(0, vec![("less", "<")]),
(1, vec![("greater", ">")]),
]
);
});
}
#[test]
fn test_query_matches_with_too_many_permutations_to_track() {
allocations::record(|| {
@ -880,7 +1042,7 @@ fn test_query_captures_with_text_conditions() {
}
#[test]
fn test_query_captures_with_set_properties() {
fn test_query_captures_with_predicates() {
allocations::record(|| {
let language = get_language("javascript");
@ -889,7 +1051,8 @@ fn test_query_captures_with_set_properties() {
r#"
((call_expression (identifier) @foo)
(set! name something)
(set! cool))
(set! cool)
(something! @foo omg))
((property_identifier) @bar
(is? cool)
@ -904,6 +1067,16 @@ fn test_query_captures_with_set_properties() {
QueryProperty::new("cool", None, None),
]
);
assert_eq!(
query.general_predicates(0),
&[QueryPredicate {
operator: "something!".to_string().into_boxed_str(),
args: vec![
QueryPredicateArg::Capture(0),
QueryPredicateArg::String("omg".to_string().into_boxed_str()),
],
},]
);
assert_eq!(query.property_settings(1), &[]);
assert_eq!(query.property_predicates(0), &[]);
assert_eq!(
@ -917,7 +1090,7 @@ fn test_query_captures_with_set_properties() {
}
#[test]
fn test_query_captures_with_set_quoted_properties() {
fn test_query_captures_with_quoted_predicate_args() {
allocations::record(|| {
let language = get_language("javascript");

347
cli/src/tests/tags_test.rs Normal file
View file

@ -0,0 +1,347 @@
use super::helpers::allocations;
use super::helpers::fixtures::{get_language, get_language_queries_path};
use std::ffi::CString;
use std::{fs, ptr, slice, str};
use tree_sitter_tags::c_lib as c;
use tree_sitter_tags::{Error, TagKind, TagsConfiguration, TagsContext};
// Tag query for Python: functions and classes (optionally with a docstring,
// when the body begins with a string expression) plus call sites. The `strip!`
// predicate trims quote characters and whitespace from the captured docs.
const PYTHON_TAG_QUERY: &'static str = r#"
((function_definition
name: (identifier) @name
body: (block . (expression_statement (string) @doc))) @function
(strip! @doc "(^['\"\\s]*)|(['\"\\s]*$)"))
(function_definition
name: (identifier) @name) @function
((class_definition
name: (identifier) @name
body: (block . (expression_statement (string) @doc))) @class
(strip! @doc "(^['\"\\s]*)|(['\"\\s]*$)"))
(class_definition
name: (identifier) @name) @class
(call
function: (identifier) @name) @call
"#;
// Tag query for JavaScript: classes, methods, and functions documented by the
// run of comments immediately preceding them (`select-adjacent!` keeps only
// the adjacent comments), plus call sites.
const JS_TAG_QUERY: &'static str = r#"
((*
(comment)+ @doc .
(class_declaration
name: (identifier) @name) @class)
(select-adjacent! @doc @class)
(strip! @doc "(^[/\\*\\s]*)|([/\\*\\s]*$)"))
((*
(comment)+ @doc .
(method_definition
name: (property_identifier) @name) @method)
(select-adjacent! @doc @method)
(strip! @doc "(^[/\\*\\s]*)|([/\\*\\s]*$)"))
((*
(comment)+ @doc .
(function_declaration
name: (identifier) @name) @function)
(select-adjacent! @doc @function)
(strip! @doc "(^[/\\*\\s]*)|([/\\*\\s]*$)"))
(call_expression function: (identifier) @name) @call
"#;
// Tag query for Ruby: method definitions and calls. A bare identifier counts
// as a call only when the locals analysis says it is not a local variable
// (`is-not? local`), exercising the locals query passed alongside.
const RUBY_TAG_QUERY: &'static str = r#"
(method
name: (identifier) @name) @method
(method_call
method: (identifier) @name) @call
((identifier) @name @call
(is-not? local))
"#;
// Python tagging: classes and functions are tagged with their docstrings
// (stripped of quotes/whitespace), and calls are tagged by callee name.
#[test]
fn test_tags_python() {
let language = get_language("python");
let tags_config = TagsConfiguration::new(language, PYTHON_TAG_QUERY, "").unwrap();
let mut tag_context = TagsContext::new();
// NOTE(review): the stray "}" before the closing delimiter looks accidental
// in this Python snippet; the error-tolerant parser still yields the
// expected tags — confirm intent.
let source = br#"
class Customer:
"""
Data about a customer
"""
def age(self):
'''
Get the customer's age
'''
compute_age(self.id)
}
"#;
let tags = tag_context
.generate_tags(&tags_config, source, None)
.unwrap()
.collect::<Result<Vec<_>, _>>()
.unwrap();
assert_eq!(
tags.iter()
.map(|t| (substr(source, &t.name_range), t.kind))
.collect::<Vec<_>>(),
&[
("Customer", TagKind::Class),
("age", TagKind::Function),
("compute_age", TagKind::Call),
]
);
assert_eq!(substr(source, &tags[0].line_range), " class Customer:");
assert_eq!(
substr(source, &tags[1].line_range),
" def age(self):"
);
assert_eq!(tags[0].docs.as_ref().unwrap(), "Data about a customer");
assert_eq!(tags[1].docs.as_ref().unwrap(), "Get the customer's age");
}
// JavaScript tagging: docs come from the comment run directly adjacent to the
// declaration — `// hi` (separated) and `// ok` (followed by an undocumented
// class line pattern) show the adjacency/stripping behavior.
#[test]
fn test_tags_javascript() {
let language = get_language("javascript");
let tags_config = TagsConfiguration::new(language, JS_TAG_QUERY, "").unwrap();
let source = br#"
// hi
// Data about a customer.
// bla bla bla
class Customer {
/*
* Get the customer's age
*/
getAge() {
}
}
// ok
class Agent {
}
"#;
let mut tag_context = TagsContext::new();
let tags = tag_context
.generate_tags(&tags_config, source, None)
.unwrap()
.collect::<Result<Vec<_>, _>>()
.unwrap();
assert_eq!(
tags.iter()
.map(|t| (substr(source, &t.name_range), t.kind))
.collect::<Vec<_>>(),
&[
("Customer", TagKind::Class),
("getAge", TagKind::Method),
("Agent", TagKind::Class)
]
);
// Multi-line doc comments are joined with newlines after stripping.
assert_eq!(
tags[0].docs.as_ref().unwrap(),
"Data about a customer.\nbla bla bla"
);
assert_eq!(tags[1].docs.as_ref().unwrap(), "Get the customer's age");
// NOTE(review): `Agent` has `// ok` above it yet gets no docs — presumably
// stripped to empty or excluded by select-adjacent!; verify against the
// tags crate's doc handling.
assert_eq!(tags[2].docs, None);
}
// Ruby tagging with a real locals query: identifiers are tagged as calls only
// when local-variable analysis says they are not in scope (see the inline
// comments in the source snippet for the expected classification).
#[test]
fn test_tags_ruby() {
let language = get_language("ruby");
let locals_query =
fs::read_to_string(get_language_queries_path("ruby").join("locals.scm")).unwrap();
let tags_config = TagsConfiguration::new(language, RUBY_TAG_QUERY, &locals_query).unwrap();
let source = strip_whitespace(
8,
"
b = 1
def foo()
c = 1
# a is a method because it is not in scope
# b is a method because `b` doesn't capture variables from its containing scope
bar a, b, c
[1, 2, 3].each do |a|
# a is a parameter
# b is a method
# c is a variable, because the block captures variables from its containing scope.
baz a, b, c
end
end",
);
let mut tag_context = TagsContext::new();
let tags = tag_context
.generate_tags(&tags_config, source.as_bytes(), None)
.unwrap()
.collect::<Result<Vec<_>, _>>()
.unwrap();
// Expectations include the (row, column) of each tag's starting position.
assert_eq!(
tags.iter()
.map(|t| (
substr(source.as_bytes(), &t.name_range),
t.kind,
(t.span.start.row, t.span.start.column),
))
.collect::<Vec<_>>(),
&[
("foo", TagKind::Method, (2, 0)),
("bar", TagKind::Call, (7, 4)),
("a", TagKind::Call, (7, 8)),
("b", TagKind::Call, (7, 11)),
("each", TagKind::Call, (9, 14)),
("baz", TagKind::Call, (13, 8)),
("b", TagKind::Call, (13, 15),),
]
);
}
// Setting the cancellation flag mid-iteration should make the lazy tags
// iterator yield `Error::Cancelled` instead of running to completion.
#[test]
fn test_tags_cancellation() {
use std::sync::atomic::{AtomicUsize, Ordering};
allocations::record(|| {
// Large javascript document
let source = (0..500)
.map(|_| "/* hi */ class A { /* ok */ b() {} }\n")
.collect::<String>();
let cancellation_flag = AtomicUsize::new(0);
let language = get_language("javascript");
let tags_config = TagsConfiguration::new(language, JS_TAG_QUERY, "").unwrap();
let mut tag_context = TagsContext::new();
let tags = tag_context
.generate_tags(&tags_config, source.as_bytes(), Some(&cancellation_flag))
.unwrap();
for (i, tag) in tags.enumerate() {
// Trip the flag partway through; a later item must surface the error.
if i == 150 {
cancellation_flag.store(1, Ordering::SeqCst);
}
if let Err(e) = tag {
assert_eq!(e, Error::Cancelled);
return;
}
}
panic!("Expected to halt tagging with an error");
});
}
// Exercise the C API surface: register a language+query with a tagger, tag a
// buffer, and read the packed tag structs and concatenated docs back out of
// the raw buffer pointers. Run under the allocation recorder to catch leaks.
#[test]
fn test_tags_via_c_api() {
allocations::record(|| {
let tagger = c::ts_tagger_new();
let buffer = c::ts_tags_buffer_new();
let scope_name = "source.js";
let language = get_language("javascript");
let source_code = strip_whitespace(
12,
"
var a = 1;
// one
// two
// three
function b() {
}
// four
// five
class C extends D {
}
b(a);",
);
let c_scope_name = CString::new(scope_name).unwrap();
// Locals query pointer is null; length 0 — only the tags query is used.
let result = c::ts_tagger_add_language(
tagger,
c_scope_name.as_ptr(),
language,
JS_TAG_QUERY.as_ptr(),
ptr::null(),
JS_TAG_QUERY.len() as u32,
0,
);
assert_eq!(result, c::TSTagsError::Ok);
let result = c::ts_tagger_tag(
tagger,
c_scope_name.as_ptr(),
source_code.as_ptr(),
source_code.len() as u32,
buffer,
ptr::null(),
);
assert_eq!(result, c::TSTagsError::Ok);
// SAFETY relies on the buffer accessors returning a valid ptr/len pair
// for data owned by `buffer`, which outlives these slices.
let tags = unsafe {
slice::from_raw_parts(
c::ts_tags_buffer_tags(buffer),
c::ts_tags_buffer_tags_len(buffer) as usize,
)
};
let docs = str::from_utf8(unsafe {
slice::from_raw_parts(
c::ts_tags_buffer_docs(buffer) as *const u8,
c::ts_tags_buffer_docs_len(buffer) as usize,
)
})
.unwrap();
// Each tag carries byte offsets into the source and into the shared
// docs buffer; resolve them back to strings for comparison.
assert_eq!(
tags.iter()
.map(|tag| (
tag.kind,
&source_code[tag.name_start_byte as usize..tag.name_end_byte as usize],
&source_code[tag.line_start_byte as usize..tag.line_end_byte as usize],
&docs[tag.docs_start_byte as usize..tag.docs_end_byte as usize],
))
.collect::<Vec<_>>(),
&[
(
c::TSTagKind::Function,
"b",
"function b() {",
"one\ntwo\nthree"
),
(
c::TSTagKind::Class,
"C",
"class C extends D {",
"four\nfive"
),
(c::TSTagKind::Call, "b", "b(a);", "")
]
);
c::ts_tags_buffer_delete(buffer);
c::ts_tagger_delete(tagger);
});
}
// Borrow the given byte range of `source` as UTF-8 text, panicking if the
// selected bytes are not valid UTF-8.
fn substr<'a>(source: &'a [u8], range: &std::ops::Range<usize>) -> &'a str {
    let bytes = &source[range.start..range.end];
    std::str::from_utf8(bytes).unwrap()
}
// Dedent a multi-line string literal: drop the first line (the one right
// after the opening quote), then remove up to `indent` leading characters
// from every remaining line and rejoin with newlines.
fn strip_whitespace(indent: usize, s: &str) -> String {
    let mut lines = s.lines();
    lines.next(); // discard the first (empty) line
    let dedented: Vec<&str> = lines
        .map(|line| &line[indent.min(line.len())..])
        .collect();
    dedented.join("\n")
}

View file

@ -1,12 +1,29 @@
use std::io;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use std::thread;
use tree_sitter::Parser;
#[cfg(unix)]
use std::path::PathBuf;
#[cfg(unix)]
use std::process::{Child, ChildStdin, Command, Stdio};
use tree_sitter::Parser;
#[cfg(unix)]
const HTML_HEADER: &[u8] = b"<!DOCTYPE html>\n<style>svg { width: 100%; }</style>\n\n";
/// Return a shared flag that flips to 1 once a line arrives on stdin.
///
/// A detached background thread performs the blocking read; callers poll the
/// returned `AtomicUsize` (non-zero means "cancel") during long-running work.
pub fn cancel_on_stdin() -> Arc<AtomicUsize> {
    let flag = Arc::new(AtomicUsize::new(0));
    let watcher = Arc::clone(&flag);
    thread::spawn(move || {
        let mut line = String::new();
        io::stdin().read_line(&mut line).unwrap();
        // Relaxed ordering: pollers only need to eventually see the store.
        watcher.store(1, Ordering::Relaxed);
    });
    flag
}
// Empty stand-in for the log session on Windows. NOTE(review): the unix
// variant (declared elsewhere in this file) carries real state — confirm that
// all LogSession methods are likewise stubbed for windows.
#[cfg(windows)]
pub struct LogSession();

View file

@ -18,9 +18,6 @@ crate-type = ["lib", "staticlib"]
[dependencies]
regex = "1"
serde = "1.0"
serde_json = "1.0"
serde_derive = "1.0"
[dependencies.tree-sitter]
version = ">= 0.3.7"

View file

@ -95,6 +95,7 @@ pub struct Query {
text_predicates: Vec<Box<[TextPredicate]>>,
property_settings: Vec<Box<[QueryProperty]>>,
property_predicates: Vec<Box<[(QueryProperty, bool)]>>,
general_predicates: Vec<Box<[QueryPredicate]>>,
}
/// A stateful object for executing a `Query` on a syntax `Tree`.
@ -108,6 +109,19 @@ pub struct QueryProperty {
pub capture_id: Option<usize>,
}
#[derive(Debug, PartialEq, Eq)]
pub enum QueryPredicateArg {
Capture(u32),
String(Box<str>),
}
/// A key-value pair associated with a particular pattern in a `Query`.
#[derive(Debug, PartialEq, Eq)]
pub struct QueryPredicate {
pub operator: Box<str>,
pub args: Vec<QueryPredicateArg>,
}
/// A match of a `Query` to a particular set of `Node`s.
pub struct QueryMatch<'a> {
pub pattern_index: usize,
@ -1194,6 +1208,7 @@ impl Query {
text_predicates: Vec::with_capacity(pattern_count),
property_predicates: Vec::with_capacity(pattern_count),
property_settings: Vec::with_capacity(pattern_count),
general_predicates: Vec::with_capacity(pattern_count),
};
// Build a vector of strings to store the capture names.
@ -1237,6 +1252,7 @@ impl Query {
let mut text_predicates = Vec::new();
let mut property_predicates = Vec::new();
let mut property_settings = Vec::new();
let mut general_predicates = Vec::new();
for p in predicate_steps.split(|s| s.type_ == type_done) {
if p.is_empty() {
continue;
@ -1328,12 +1344,21 @@ impl Query {
operator_name == "is?",
)),
_ => {
return Err(QueryError::Predicate(format!(
"Unknown query predicate function {}",
operator_name,
)))
}
_ => general_predicates.push(QueryPredicate {
operator: operator_name.clone().into_boxed_str(),
args: p[1..]
.iter()
.map(|a| {
if a.type_ == type_capture {
QueryPredicateArg::Capture(a.value_id)
} else {
QueryPredicateArg::String(
string_values[a.value_id as usize].clone().into_boxed_str(),
)
}
})
.collect(),
}),
}
}
@ -1346,6 +1371,9 @@ impl Query {
result
.property_settings
.push(property_settings.into_boxed_slice());
result
.general_predicates
.push(general_predicates.into_boxed_slice());
}
Ok(result)
}
@ -1375,15 +1403,30 @@ impl Query {
}
/// Get the properties that are checked for the given pattern index.
///
/// This includes predicates with the operators `is?` and `is-not?`.
pub fn property_predicates(&self, index: usize) -> &[(QueryProperty, bool)] {
&self.property_predicates[index]
}
/// Get the properties that are set for the given pattern index.
///
/// This includes predicates with the operator `set!`.
pub fn property_settings(&self, index: usize) -> &[QueryProperty] {
&self.property_settings[index]
}
/// Get the other user-defined predicates associated with the given index.
///
/// This includes predicate with operators other than:
/// * `match?`
/// * `eq?` and `not-eq?
/// * `is?` and `is-not?`
/// * `set!`
pub fn general_predicates(&self, index: usize) -> &[QueryPredicate] {
&self.general_predicates[index]
}
/// Disable a certain capture within a query.
///
/// This prevents the capture from being returned in matches, and also avoids any
@ -1420,46 +1463,39 @@ impl Query {
)));
}
let mut i = 0;
let mut capture_id = None;
if args[i].type_ == ffi::TSQueryPredicateStepType_TSQueryPredicateStepTypeCapture {
capture_id = Some(args[i].value_id as usize);
i += 1;
if i == args.len() {
return Err(QueryError::Predicate(format!(
"No key specified for {} predicate.",
function_name,
)));
}
if args[i].type_ == ffi::TSQueryPredicateStepType_TSQueryPredicateStepTypeCapture {
return Err(QueryError::Predicate(format!(
"Invalid arguments to {} predicate. Expected string, got @{}",
function_name, capture_names[args[i].value_id as usize]
)));
}
}
let key = &string_values[args[i].value_id as usize];
i += 1;
let mut key = None;
let mut value = None;
if i < args.len() {
if args[i].type_ == ffi::TSQueryPredicateStepType_TSQueryPredicateStepTypeCapture {
for arg in args {
if arg.type_ == ffi::TSQueryPredicateStepType_TSQueryPredicateStepTypeCapture {
if capture_id.is_some() {
return Err(QueryError::Predicate(format!(
"Invalid arguments to {} predicate. Unexpected second capture name @{}",
function_name, capture_names[args[i].value_id as usize]
function_name, capture_names[arg.value_id as usize]
)));
} else {
capture_id = Some(args[i].value_id as usize);
}
capture_id = Some(arg.value_id as usize);
} else if key.is_none() {
key = Some(&string_values[arg.value_id as usize]);
} else if value.is_none() {
value = Some(string_values[arg.value_id as usize].as_str());
} else {
value = Some(string_values[args[i].value_id as usize].as_str());
return Err(QueryError::Predicate(format!(
"Invalid arguments to {} predicate. Unexpected third argument @{}",
function_name, string_values[arg.value_id as usize]
)));
}
}
Ok(QueryProperty::new(key, value, capture_id))
if let Some(key) = key {
Ok(QueryProperty::new(key, value, capture_id))
} else {
return Err(QueryError::Predicate(format!(
"Invalid arguments to {} predicate. Missing key argument",
function_name,
)));
}
}
}

View file

@ -35,15 +35,21 @@ typedef struct {
* captured in this pattern.
* - `depth` - The depth where this node occurs in the pattern. The root node
* of the pattern has depth zero.
* - `repeat_step_index` - If this step is part of a repetition, the index of
* the beginning of the repetition. A `NONE` value means this step is not
* part of a repetition.
*/
typedef struct {
TSSymbol symbol;
TSFieldId field;
uint16_t capture_ids[MAX_STEP_CAPTURE_COUNT];
uint16_t depth: 13;
uint16_t repeat_step_index;
uint16_t depth: 11;
bool contains_captures: 1;
bool is_pattern_start: 1;
bool is_immediate: 1;
bool is_last: 1;
bool is_repeated: 1;
} QueryStep;
/*
@ -85,23 +91,27 @@ typedef struct {
* represented as one of these states.
*/
typedef struct {
uint32_t id;
uint16_t start_depth;
uint16_t pattern_index;
uint16_t step_index;
uint16_t capture_count;
uint16_t capture_list_id;
uint16_t consumed_capture_count;
uint32_t id;
uint16_t repeat_match_count;
uint16_t step_index_on_failure;
uint8_t capture_list_id;
bool seeking_non_match;
} QueryState;
typedef Array(TSQueryCapture) CaptureList;
/*
* CaptureListPool - A collection of *lists* of captures. Each QueryState
* needs to maintain its own list of captures. They are all represented as
* slices of one shared array. The CaptureListPool keeps track of which
* parts of the shared array are currently in use by a QueryState.
* needs to maintain its own list of captures. To avoid repeated allocations,
* the reuses a fixed set of capture lists, and keeps track of which ones
* are currently in use.
*/
typedef struct {
Array(TSQueryCapture) list;
CaptureList list[32];
uint32_t usage_map;
} CaptureListPool;
@ -119,7 +129,6 @@ struct TSQuery {
Array(Slice) predicates_by_pattern;
Array(uint32_t) start_bytes_by_pattern;
const TSLanguage *language;
uint16_t max_capture_count;
uint16_t wildcard_root_pattern_count;
TSSymbol *symbol_map;
};
@ -233,24 +242,25 @@ static void stream_scan_identifier(Stream *stream) {
static CaptureListPool capture_list_pool_new() {
return (CaptureListPool) {
.list = array_new(),
.usage_map = UINT32_MAX,
};
}
static void capture_list_pool_reset(CaptureListPool *self, uint16_t list_size) {
static void capture_list_pool_reset(CaptureListPool *self) {
self->usage_map = UINT32_MAX;
uint32_t total_size = MAX_STATE_COUNT * list_size;
array_reserve(&self->list, total_size);
self->list.size = total_size;
for (unsigned i = 0; i < 32; i++) {
array_clear(&self->list[i]);
}
}
static void capture_list_pool_delete(CaptureListPool *self) {
array_delete(&self->list);
for (unsigned i = 0; i < 32; i++) {
array_delete(&self->list[i]);
}
}
static TSQueryCapture *capture_list_pool_get(CaptureListPool *self, uint16_t id) {
return &self->list.contents[id * (self->list.size / MAX_STATE_COUNT)];
static CaptureList *capture_list_pool_get(CaptureListPool *self, uint16_t id) {
return &self->list[id];
}
static bool capture_list_pool_is_empty(const CaptureListPool *self) {
@ -269,6 +279,7 @@ static uint16_t capture_list_pool_acquire(CaptureListPool *self) {
}
static void capture_list_pool_release(CaptureListPool *self, uint16_t id) {
array_clear(&self->list[id]);
self->usage_map |= bitmask_for_index(id);
}
@ -407,7 +418,11 @@ static QueryStep query_step__new(
.field = 0,
.capture_ids = {NONE, NONE, NONE, NONE},
.contains_captures = false,
.is_repeated = false,
.is_last = false,
.is_pattern_start = false,
.is_immediate = is_immediate,
.repeat_step_index = NONE,
};
}
@ -842,27 +857,43 @@ static TSQueryError ts_query__parse_pattern(
stream_skip_whitespace(stream);
// Parse an '@'-prefixed capture pattern
while (stream->next == '@') {
stream_advance(stream);
// Parse the capture name
if (!stream_is_ident_start(stream)) return TSQueryErrorSyntax;
const char *capture_name = stream->input;
stream_scan_identifier(stream);
uint32_t length = stream->input - capture_name;
// Add the capture id to the first step of the pattern
uint16_t capture_id = symbol_table_insert_name(
&self->captures,
capture_name,
length
);
// Parse suffixes modifiers for this pattern
for (;;) {
QueryStep *step = &self->steps.contents[starting_step_index];
query_step__add_capture(step, capture_id);
(*capture_count)++;
stream_skip_whitespace(stream);
if (stream->next == '+') {
stream_advance(stream);
step->is_repeated = true;
array_back(&self->steps)->repeat_step_index = starting_step_index;
stream_skip_whitespace(stream);
}
// Parse an '@'-prefixed capture pattern
else if (stream->next == '@') {
stream_advance(stream);
// Parse the capture name
if (!stream_is_ident_start(stream)) return TSQueryErrorSyntax;
const char *capture_name = stream->input;
stream_scan_identifier(stream);
uint32_t length = stream->input - capture_name;
// Add the capture id to the first step of the pattern
uint16_t capture_id = symbol_table_insert_name(
&self->captures,
capture_name,
length
);
query_step__add_capture(step, capture_id);
(*capture_count)++;
stream_skip_whitespace(stream);
}
// No more suffix modifiers
else {
break;
}
}
return 0;
@ -912,16 +943,14 @@ TSQuery *ts_query_new(
.predicates_by_pattern = array_new(),
.symbol_map = symbol_map,
.wildcard_root_pattern_count = 0,
.max_capture_count = 0,
.language = language,
};
// Parse all of the S-expressions in the given string.
Stream stream = stream_new(source, source_len);
stream_skip_whitespace(&stream);
uint32_t start_step_index;
while (stream.input < stream.end) {
start_step_index = self->steps.size;
uint32_t start_step_index = self->steps.size;
uint32_t capture_count = 0;
array_push(&self->start_bytes_by_pattern, stream.input - source);
array_push(&self->predicates_by_pattern, ((Slice) {
@ -939,7 +968,19 @@ TSQuery *ts_query_new(
return NULL;
}
// If a pattern has a wildcard at its root, optimize the matching process
// by skipping matching the wildcard.
if (
self->steps.contents[start_step_index].symbol == WILDCARD_SYMBOL
) {
QueryStep *second_step = &self->steps.contents[start_step_index + 1];
if (second_step->symbol != WILDCARD_SYMBOL && second_step->depth != PATTERN_DONE_MARKER) {
start_step_index += 1;
}
}
// Maintain a map that can look up patterns for a given root symbol.
self->steps.contents[start_step_index].is_pattern_start = true;
ts_query__pattern_map_insert(
self,
self->steps.contents[start_step_index].symbol,
@ -948,13 +989,6 @@ TSQuery *ts_query_new(
if (self->steps.contents[start_step_index].symbol == WILDCARD_SYMBOL) {
self->wildcard_root_pattern_count++;
}
// Keep track of the maximum number of captures in pattern, because
// that numer determines how much space is needed to store each capture
// list.
if (capture_count > self->max_capture_count) {
self->max_capture_count = capture_count;
}
}
ts_query__finalize_steps(self);
@ -1089,7 +1123,7 @@ void ts_query_cursor_exec(
array_clear(&self->states);
array_clear(&self->finished_states);
ts_tree_cursor_reset(&self->cursor, node);
capture_list_pool_reset(&self->capture_list_pool, query->max_capture_count);
capture_list_pool_reset(&self->capture_list_pool);
self->next_state_id = 0;
self->depth = 0;
self->ascending = false;
@ -1133,12 +1167,12 @@ static bool ts_query_cursor__first_in_progress_capture(
bool result = false;
for (unsigned i = 0; i < self->states.size; i++) {
const QueryState *state = &self->states.contents[i];
if (state->capture_count > 0) {
const TSQueryCapture *captures = capture_list_pool_get(
&self->capture_list_pool,
state->capture_list_id
);
uint32_t capture_byte = ts_node_start_byte(captures[0].node);
const CaptureList *captures = capture_list_pool_get(
&self->capture_list_pool,
state->capture_list_id
);
if (captures->size > 0) {
uint32_t capture_byte = ts_node_start_byte(captures->contents[0].node);
if (
!result ||
capture_byte < *byte_offset ||
@ -1161,6 +1195,19 @@ static bool ts_query__cursor_add_state(
TSQueryCursor *self,
const PatternEntry *pattern
) {
QueryStep *step = &self->query->steps.contents[pattern->step_index];
// If this pattern begins with a repetition, then avoid creating
// new states after already matching the repetition one or more times.
// The query should only one match for the repetition - the one that
// started the earliest.
if (step->is_repeated) {
for (unsigned i = 0; i < self->states.size; i++) {
QueryState *state = &self->states.contents[i];
if (state->step_index == pattern->step_index) return true;
}
}
uint32_t list_id = capture_list_pool_acquire(&self->capture_list_pool);
// If there are no capture lists left in the pool, then terminate whichever
@ -1186,14 +1233,20 @@ static bool ts_query__cursor_add_state(
}
}
LOG(" start state. pattern:%u\n", pattern->pattern_index);
LOG(
" start state. pattern:%u, step:%u\n",
pattern->pattern_index,
pattern->step_index
);
array_push(&self->states, ((QueryState) {
.capture_list_id = list_id,
.step_index = pattern->step_index,
.pattern_index = pattern->pattern_index,
.start_depth = self->depth,
.capture_count = 0,
.start_depth = self->depth - step->depth,
.consumed_capture_count = 0,
.repeat_match_count = 0,
.step_index_on_failure = NONE,
.seeking_non_match = false,
}));
return true;
}
@ -1207,15 +1260,15 @@ static QueryState *ts_query__cursor_copy_state(
array_push(&self->states, *state);
QueryState *new_state = array_back(&self->states);
new_state->capture_list_id = new_list_id;
TSQueryCapture *old_captures = capture_list_pool_get(
CaptureList *old_captures = capture_list_pool_get(
&self->capture_list_pool,
state->capture_list_id
);
TSQueryCapture *new_captures = capture_list_pool_get(
CaptureList *new_captures = capture_list_pool_get(
&self->capture_list_pool,
new_list_id
);
memcpy(new_captures, old_captures, state->capture_count * sizeof(TSQueryCapture));
array_push_all(new_captures, old_captures);
return new_state;
}
@ -1372,6 +1425,24 @@ static inline bool ts_query_cursor__advance(TSQueryCursor *self) {
}
if (!node_does_match) {
// If this QueryState has processed a repeating sequence, and that repeating
// sequence has ended, move on to the *next* step of this state's pattern.
if (
state->step_index_on_failure != NONE &&
(!later_sibling_can_match || step->is_repeated)
) {
LOG(
" finish repetition state. pattern:%u, step:%u\n",
state->pattern_index,
state->step_index
);
state->step_index = state->step_index_on_failure;
state->step_index_on_failure = NONE;
state->repeat_match_count = 0;
i--;
continue;
}
if (!later_sibling_can_match) {
LOG(
" discard state. pattern:%u, step:%u\n",
@ -1386,9 +1457,17 @@ static inline bool ts_query_cursor__advance(TSQueryCursor *self) {
i--;
n--;
}
state->seeking_non_match = false;
continue;
}
// The `seeking_non_match` flag indicates that a previous QueryState
// has already begun processing this repeating sequence, so that *this*
// QueryState should not begin matching until a separate repeating sequence
// is found.
if (state->seeking_non_match) continue;
// Some patterns can match their root node in multiple ways,
// capturing different children. If this pattern step could match
// later children within the same parent, then this query state
@ -1398,11 +1477,20 @@ static inline bool ts_query_cursor__advance(TSQueryCursor *self) {
// siblings.
QueryState *next_state = state;
if (
step->depth > 0 &&
!step->is_pattern_start &&
step->contains_captures &&
later_sibling_can_match
later_sibling_can_match &&
state->repeat_match_count == 0
) {
QueryState *copy = ts_query__cursor_copy_state(self, state);
// The QueryState that matched this node has begun matching a repeating
// sequence. The QueryState that *skipped* this node should not start
// matching later elements of the same repeating sequence.
if (step->is_repeated) {
state->seeking_non_match = true;
}
if (copy) {
LOG(
" split state. pattern:%u, step:%u\n",
@ -1411,55 +1499,71 @@ static inline bool ts_query_cursor__advance(TSQueryCursor *self) {
);
next_state = copy;
} else {
LOG(" canot split state.\n");
LOG(" cannot split state.\n");
}
}
LOG(
" advance state. pattern:%u, step:%u\n",
next_state->pattern_index,
next_state->step_index
);
// If the current node is captured in this pattern, add it to the
// capture list.
for (unsigned j = 0; j < MAX_STEP_CAPTURE_COUNT; j++) {
uint16_t capture_id = step->capture_ids[j];
if (step->capture_ids[j] == NONE) break;
LOG(
" capture node. pattern:%u, capture_id:%u\n",
next_state->pattern_index,
capture_id
);
TSQueryCapture *capture_list = capture_list_pool_get(
CaptureList *capture_list = capture_list_pool_get(
&self->capture_list_pool,
next_state->capture_list_id
);
capture_list[next_state->capture_count++] = (TSQueryCapture) {
array_push(capture_list, ((TSQueryCapture) {
node,
capture_id
};
}));
LOG(
" capture node. pattern:%u, capture_id:%u, capture_count:%u\n",
next_state->pattern_index,
capture_id,
capture_list->size
);
}
// If the pattern is now done, then remove it from the list of
// in-progress states, and add it to the list of finished states.
next_state->step_index++;
QueryStep *next_step = step + 1;
if (next_step->depth == PATTERN_DONE_MARKER) {
LOG(" finish pattern %u\n", next_state->pattern_index);
// If this is the end of a repetition, then jump back to the beginning
// of that repetition.
if (step->repeat_step_index != NONE) {
next_state->step_index_on_failure = next_state->step_index + 1;
next_state->step_index = step->repeat_step_index;
next_state->repeat_match_count++;
LOG(
" continue repeat. pattern:%u, match_count:%u\n",
next_state->pattern_index,
next_state->repeat_match_count
);
} else {
next_state->step_index++;
LOG(
" advance state. pattern:%u, step:%u\n",
next_state->pattern_index,
next_state->step_index
);
next_state->id = self->next_state_id++;
array_push(&self->finished_states, *next_state);
if (next_state == state) {
array_erase(&self->states, i);
i--;
n--;
} else {
self->states.size--;
QueryStep *next_step = step + 1;
// If the pattern is now done, then remove it from the list of
// in-progress states, and add it to the list of finished states.
if (next_step->depth == PATTERN_DONE_MARKER) {
LOG(" finish pattern %u\n", next_state->pattern_index);
next_state->id = self->next_state_id++;
array_push(&self->finished_states, *next_state);
if (next_state == state) {
array_erase(&self->states, i);
i--;
n--;
} else {
self->states.size--;
}
}
}
}
// Continue descending if possible.
if (ts_tree_cursor_goto_first_child(&self->cursor)) {
self->depth++;
@ -1485,11 +1589,12 @@ bool ts_query_cursor_next_match(
QueryState *state = &self->finished_states.contents[0];
match->id = state->id;
match->pattern_index = state->pattern_index;
match->capture_count = state->capture_count;
match->captures = capture_list_pool_get(
CaptureList *captures = capture_list_pool_get(
&self->capture_list_pool,
state->capture_list_id
);
match->captures = captures->contents;
match->capture_count = captures->size;
capture_list_pool_release(&self->capture_list_pool, state->capture_list_id);
array_erase(&self->finished_states, 0);
return true;
@ -1542,13 +1647,13 @@ bool ts_query_cursor_next_capture(
uint32_t first_finished_pattern_index = first_unfinished_pattern_index;
for (unsigned i = 0; i < self->finished_states.size; i++) {
const QueryState *state = &self->finished_states.contents[i];
if (state->capture_count > state->consumed_capture_count) {
const TSQueryCapture *captures = capture_list_pool_get(
&self->capture_list_pool,
state->capture_list_id
);
CaptureList *captures = capture_list_pool_get(
&self->capture_list_pool,
state->capture_list_id
);
if (captures->size > state->consumed_capture_count) {
uint32_t capture_byte = ts_node_start_byte(
captures[state->consumed_capture_count].node
captures->contents[state->consumed_capture_count].node
);
if (
capture_byte < first_finished_capture_byte ||
@ -1580,11 +1685,12 @@ bool ts_query_cursor_next_capture(
];
match->id = state->id;
match->pattern_index = state->pattern_index;
match->capture_count = state->capture_count;
match->captures = capture_list_pool_get(
CaptureList *captures = capture_list_pool_get(
&self->capture_list_pool,
state->capture_list_id
);
match->captures = captures->contents;
match->capture_count = captures->size;
*capture_index = state->consumed_capture_count;
state->consumed_capture_count++;
return true;

25
tags/Cargo.toml Normal file
View file

@ -0,0 +1,25 @@
[package]
name = "tree-sitter-tags"
description = "Library for extracting tag information"
version = "0.1.6"
authors = [
"Max Brunsfeld <maxbrunsfeld@gmail.com>",
"Patrick Thomson <patrickt@github.com>"
]
license = "MIT"
readme = "README.md"
edition = "2018"
keywords = ["incremental", "parsing", "syntax", "tagging"]
categories = ["parsing", "text-editors"]
repository = "https://github.com/tree-sitter/tree-sitter"
[lib]
crate-type = ["lib", "staticlib"]
[dependencies]
regex = "1"
memchr = "2.3"
[dependencies.tree-sitter]
version = ">= 0.3.7"
path = "../lib"

60
tags/README.md Normal file
View file

@ -0,0 +1,60 @@
Tree-sitter Tags
=========================
### Usage
Compile some languages into your app, and declare them:
```rust
extern "C" { fn tree_sitter_python() -> Language; }
extern "C" { fn tree_sitter_javascript() -> Language; }
```
Create a tag context. You need one of these for each thread that you're using for tag computation:
```rust
use tree_sitter_tags::TagsContext;
let mut context = TagsContext::new();
```
Load some tagging queries from the `queries` directory of some language repositories:
```rust
use tree_sitter_tags::TagsConfiguration;

let python_language = unsafe { tree_sitter_python() };
let javascript_language = unsafe { tree_sitter_javascript() };

let python_config = TagsConfiguration::new(
    python_language,
    &fs::read_to_string("./tree-sitter-python/queries/tags.scm").unwrap(),
    &fs::read_to_string("./tree-sitter-python/queries/locals.scm").unwrap(),
).unwrap();

let javascript_config = TagsConfiguration::new(
    javascript_language,
    &fs::read_to_string("./tree-sitter-javascript/queries/tags.scm").unwrap(),
    &fs::read_to_string("./tree-sitter-javascript/queries/locals.scm").unwrap(),
).unwrap();
```
Compute code navigation tags for some source code:
```rust
use tree_sitter_tags::Tag;

let tags = context.generate_tags(
    &javascript_config,
    b"class A { getB() { return c(); } }",
    None,
).unwrap();

for tag in tags {
    let tag: Tag = tag.unwrap();
    println!("kind: {:?}", tag.kind);
    println!("range: {:?}", tag.range);
    println!("name_range: {:?}", tag.name_range);
    println!("docs: {:?}", tag.docs);
}
```

View file

@ -0,0 +1,96 @@
#ifndef TREE_SITTER_TAGS_H_
#define TREE_SITTER_TAGS_H_
#ifdef __cplusplus
extern "C" {
#endif
#include <stdint.h>
#include "tree_sitter/api.h"
typedef enum {
TSTagsOk,
TSTagsUnknownScope,
TSTagsTimeout,
TSTagsInvalidLanguage,
TSTagsInvalidUtf8,
TSTagsInvalidRegex,
TSTagsInvalidQuery,
} TSTagsError;
typedef enum {
TSTagKindFunction,
TSTagKindMethod,
TSTagKindClass,
TSTagKindModule,
TSTagKindCall,
} TSTagKind;
typedef struct {
TSTagKind kind;
uint32_t start_byte;
uint32_t end_byte;
uint32_t name_start_byte;
uint32_t name_end_byte;
uint32_t line_start_byte;
uint32_t line_end_byte;
TSPoint start_point;
TSPoint end_point;
uint32_t docs_start_byte;
uint32_t docs_end_byte;
} TSTag;
typedef struct TSTagger TSTagger;
typedef struct TSTagsBuffer TSTagsBuffer;
// Construct a tagger.
TSTagger *ts_tagger_new();
// Delete a tagger.
void ts_tagger_delete(TSTagger *);
// Add a `TSLanguage` to a tagger. The language is associated with a scope name,
// which can be used later to select a language for tagging. Along with the language,
// you must provide two tree query strings, one for matching tags themselves, and one
// specifying local variable definitions.
TSTagsError ts_tagger_add_language(
TSTagger *self,
const char *scope_name,
const TSLanguage *language,
const char *tags_query,
const char *locals_query,
uint32_t tags_query_len,
uint32_t locals_query_len
);
// Compute code navigation tags for a given document. You must first
// create a `TSTagsBuffer` to hold the output.
TSTagsError ts_tagger_tag(
const TSTagger *self,
const char *scope_name,
const char *source_code,
uint32_t source_code_len,
TSTagsBuffer *output,
const size_t *cancellation_flag
);
// A tags buffer stores the results produced by a tagging call. It can be reused
// for multiple calls.
TSTagsBuffer *ts_tags_buffer_new();
// Delete a tags buffer.
void ts_tags_buffer_delete(TSTagsBuffer *);
// Access the tags within a tag buffer.
const TSTag *ts_tags_buffer_tags(const TSTagsBuffer *);
uint32_t ts_tags_buffer_tags_len(const TSTagsBuffer *);
// Access the string containing all of the docs
const char *ts_tags_buffer_docs(const TSTagsBuffer *);
uint32_t ts_tags_buffer_docs_len(const TSTagsBuffer *);
#ifdef __cplusplus
}
#endif
#endif // TREE_SITTER_TAGS_H_

245
tags/src/c_lib.rs Normal file
View file

@ -0,0 +1,245 @@
use super::{Error, TagKind, TagsConfiguration, TagsContext};
use std::collections::HashMap;
use std::ffi::CStr;
use std::os::raw::c_char;
use std::process::abort;
use std::sync::atomic::AtomicUsize;
use std::{fmt, slice, str};
use tree_sitter::Language;
/// Status codes returned by the C tagging API.
///
/// `#[repr(C)]` so the values cross the FFI boundary as a plain C enum.
/// NOTE(review): `Unknown` has no counterpart in `TSTagsError` in `tags.h`
/// — confirm the header is kept in sync with this enum.
#[repr(C)]
#[derive(Debug, PartialEq, Eq)]
pub enum TSTagsError {
    Ok,
    UnknownScope,
    Timeout,
    InvalidLanguage,
    InvalidUtf8,
    InvalidRegex,
    InvalidQuery,
    Unknown,
}
/// C-compatible mirror of the crate's `TagKind`; `ts_tagger_tag` performs
/// the one-to-one conversion.
#[repr(C)]
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum TSTagKind {
    Function,
    Method,
    Class,
    Module,
    Call,
}
/// C-compatible row/column position.
#[repr(C)]
pub struct TSPoint {
    row: u32,
    column: u32,
}
/// C-compatible tag record. The `start`/`end`, `name_*`, and `line_*`
/// offsets are byte offsets into the source passed to `ts_tagger_tag`;
/// `docs_start_byte`/`docs_end_byte` index into the shared docs buffer
/// exposed by `ts_tags_buffer_docs`.
#[repr(C)]
pub struct TSTag {
    pub kind: TSTagKind,
    pub start_byte: u32,
    pub end_byte: u32,
    pub name_start_byte: u32,
    pub name_end_byte: u32,
    pub line_start_byte: u32,
    pub line_end_byte: u32,
    pub start_point: TSPoint,
    pub end_point: TSPoint,
    pub docs_start_byte: u32,
    pub docs_end_byte: u32,
}
/// Opaque handle mapping scope names to their tagging configurations.
pub struct TSTagger {
    languages: HashMap<String, TagsConfiguration>,
}
/// Opaque output buffer. It owns a `TagsContext` so parser state can be
/// reused across repeated tagging calls.
pub struct TSTagsBuffer {
    context: TagsContext,
    tags: Vec<TSTag>,
    // Concatenated doc text for all tags; individual tags slice into this
    // via their `docs_*` byte offsets.
    docs: Vec<u8>,
}
/// C entry point: allocate a tagger with no registered languages.
/// The caller owns the result and must free it with `ts_tagger_delete`.
#[no_mangle]
pub extern "C" fn ts_tagger_new() -> *mut TSTagger {
    let tagger = TSTagger {
        languages: HashMap::new(),
    };
    Box::into_raw(Box::new(tagger))
}
/// C entry point: free a tagger previously returned by `ts_tagger_new`.
#[no_mangle]
pub extern "C" fn ts_tagger_delete(this: *mut TSTagger) {
    // SAFETY: callers must pass a pointer obtained from `ts_tagger_new`
    // exactly once; reconstructing the Box drops and deallocates it.
    unsafe { drop(Box::from_raw(this)) }
}
#[no_mangle]
pub extern "C" fn ts_tagger_add_language(
this: *mut TSTagger,
scope_name: *const i8,
language: Language,
tags_query: *const u8,
locals_query: *const u8,
tags_query_len: u32,
locals_query_len: u32,
) -> TSTagsError {
let tagger = unwrap_mut_ptr(this);
let scope_name = unsafe { unwrap(CStr::from_ptr(scope_name).to_str()) };
let tags_query = unsafe { slice::from_raw_parts(tags_query, tags_query_len as usize) };
let locals_query = unsafe { slice::from_raw_parts(locals_query, locals_query_len as usize) };
let tags_query = match str::from_utf8(tags_query) {
Ok(e) => e,
Err(_) => return TSTagsError::InvalidUtf8,
};
let locals_query = match str::from_utf8(locals_query) {
Ok(e) => e,
Err(_) => return TSTagsError::InvalidUtf8,
};
match TagsConfiguration::new(language, tags_query, locals_query) {
Ok(c) => {
tagger.languages.insert(scope_name.to_string(), c);
TSTagsError::Ok
}
Err(Error::Query(_)) => TSTagsError::InvalidQuery,
Err(Error::Regex(_)) => TSTagsError::InvalidRegex,
Err(_) => TSTagsError::Unknown,
}
}
#[no_mangle]
pub extern "C" fn ts_tagger_tag(
this: *mut TSTagger,
scope_name: *const i8,
source_code: *const u8,
source_code_len: u32,
output: *mut TSTagsBuffer,
cancellation_flag: *const AtomicUsize,
) -> TSTagsError {
let tagger = unwrap_mut_ptr(this);
let buffer = unwrap_mut_ptr(output);
let scope_name = unsafe { unwrap(CStr::from_ptr(scope_name).to_str()) };
if let Some(config) = tagger.languages.get(scope_name) {
buffer.tags.clear();
buffer.docs.clear();
let source_code = unsafe { slice::from_raw_parts(source_code, source_code_len as usize) };
let cancellation_flag = unsafe { cancellation_flag.as_ref() };
let tags = match buffer
.context
.generate_tags(config, source_code, cancellation_flag)
{
Ok(tags) => tags,
Err(e) => {
return match e {
Error::InvalidLanguage => TSTagsError::InvalidLanguage,
Error::Cancelled => TSTagsError::Timeout,
_ => TSTagsError::Timeout,
}
}
};
for tag in tags {
let tag = if let Ok(tag) = tag {
tag
} else {
buffer.tags.clear();
buffer.docs.clear();
return TSTagsError::Timeout;
};
let prev_docs_len = buffer.docs.len();
if let Some(docs) = tag.docs {
buffer.docs.extend_from_slice(docs.as_bytes());
}
buffer.tags.push(TSTag {
kind: match tag.kind {
TagKind::Function => TSTagKind::Function,
TagKind::Method => TSTagKind::Method,
TagKind::Class => TSTagKind::Class,
TagKind::Module => TSTagKind::Module,
TagKind::Call => TSTagKind::Call,
},
start_byte: tag.range.start as u32,
end_byte: tag.range.end as u32,
name_start_byte: tag.name_range.start as u32,
name_end_byte: tag.name_range.end as u32,
line_start_byte: tag.line_range.start as u32,
line_end_byte: tag.line_range.end as u32,
start_point: TSPoint {
row: tag.span.start.row as u32,
column: tag.span.start.column as u32,
},
end_point: TSPoint {
row: tag.span.end.row as u32,
column: tag.span.end.column as u32,
},
docs_start_byte: prev_docs_len as u32,
docs_end_byte: buffer.docs.len() as u32,
});
}
TSTagsError::Ok
} else {
TSTagsError::UnknownScope
}
}
/// C entry point: allocate an empty tags buffer, reusable across tagging
/// calls. The caller must free it with `ts_tags_buffer_delete`.
#[no_mangle]
pub extern "C" fn ts_tags_buffer_new() -> *mut TSTagsBuffer {
    let buffer = TSTagsBuffer {
        context: TagsContext::new(),
        tags: Vec::with_capacity(64),
        docs: Vec::with_capacity(64),
    };
    Box::into_raw(Box::new(buffer))
}
/// C entry point: free a buffer previously returned by `ts_tags_buffer_new`.
#[no_mangle]
pub extern "C" fn ts_tags_buffer_delete(this: *mut TSTagsBuffer) {
    // SAFETY: callers must pass a pointer obtained from `ts_tags_buffer_new`
    // exactly once; reconstructing the Box drops and deallocates it.
    unsafe { drop(Box::from_raw(this)) }
}
/// C entry point: pointer to the tags produced by the most recent
/// `ts_tagger_tag` call on this buffer. Pair with `ts_tags_buffer_tags_len`.
#[no_mangle]
pub extern "C" fn ts_tags_buffer_tags(this: *const TSTagsBuffer) -> *const TSTag {
    unwrap_ptr(this).tags.as_ptr()
}
/// C entry point: number of tags currently stored in the buffer.
#[no_mangle]
pub extern "C" fn ts_tags_buffer_tags_len(this: *const TSTagsBuffer) -> u32 {
    unwrap_ptr(this).tags.len() as u32
}
#[no_mangle]
pub extern "C" fn ts_tags_buffer_docs(this: *const TSTagsBuffer) -> *const i8 {
let buffer = unwrap_ptr(this);
buffer.docs.as_ptr() as *const i8
}
/// C entry point: length in bytes of the docs buffer.
#[no_mangle]
pub extern "C" fn ts_tags_buffer_docs_len(this: *const TSTagsBuffer) -> u32 {
    unwrap_ptr(this).docs.len() as u32
}
/// Convert a raw const pointer to a reference, aborting the process with a
/// diagnostic if the pointer is null (FFI callers must never pass null).
fn unwrap_ptr<'a, T>(result: *const T) -> &'a T {
    // SAFETY: non-null pointers passed over the C API are required to be
    // valid for the duration of the call.
    match unsafe { result.as_ref() } {
        Some(value) => value,
        None => {
            eprintln!("{}:{} - pointer must not be null", file!(), line!());
            abort();
        }
    }
}
/// Convert a raw mutable pointer to a mutable reference, aborting the
/// process with a diagnostic if the pointer is null.
fn unwrap_mut_ptr<'a, T>(result: *mut T) -> &'a mut T {
    // SAFETY: non-null pointers passed over the C API are required to be
    // valid and uniquely borrowed for the duration of the call.
    match unsafe { result.as_mut() } {
        Some(value) => value,
        None => {
            eprintln!("{}:{} - pointer must not be null", file!(), line!());
            abort();
        }
    }
}
/// Unwrap a `Result`, printing the error and aborting the process on
/// failure — C callers cannot catch a Rust panic, so abort is the only
/// safe failure mode at this boundary.
fn unwrap<T, E: fmt::Display>(result: Result<T, E>) -> T {
    match result {
        Ok(value) => value,
        Err(error) => {
            eprintln!("tree-sitter tag error: {}", error);
            abort();
        }
    }
}

499
tags/src/lib.rs Normal file
View file

@ -0,0 +1,499 @@
pub mod c_lib;
use memchr::{memchr, memrchr};
use regex::Regex;
use std::ops::Range;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::{fmt, mem, str};
use tree_sitter::{
Language, Parser, Point, Query, QueryCursor, QueryError, QueryPredicateArg, Tree,
};
const MAX_LINE_LEN: usize = 180;
const CANCELLATION_CHECK_INTERVAL: usize = 100;
/// Contains the data needed to compute tags for code written in a
/// particular language.
#[derive(Debug)]
pub struct TagsConfiguration {
    pub language: Language,
    // Combined query built by concatenating the locals query and the tags
    // query (see `TagsConfiguration::new`).
    pub query: Query,
    // Capture indices looked up from the query's capture names; `None` when
    // the query does not declare the corresponding capture.
    call_capture_index: Option<u32>,
    class_capture_index: Option<u32>,
    doc_capture_index: Option<u32>,
    function_capture_index: Option<u32>,
    method_capture_index: Option<u32>,
    module_capture_index: Option<u32>,
    name_capture_index: Option<u32>,
    local_scope_capture_index: Option<u32>,
    local_definition_capture_index: Option<u32>,
    // Index of the first pattern that comes from the tags query (patterns
    // before it come from the locals query).
    tags_pattern_index: usize,
    // Per-pattern metadata derived from the query's predicates.
    pattern_info: Vec<PatternInfo>,
}
/// Reusable state — a parser and a query cursor — for computing tags.
pub struct TagsContext {
    parser: Parser,
    cursor: QueryCursor,
}
/// A single code-navigation tag. All `Range<usize>` fields are byte offsets
/// into the source text.
#[derive(Debug, Clone)]
pub struct Tag {
    pub kind: TagKind,
    /// Byte range of the syntax node that produced this tag.
    pub range: Range<usize>,
    /// Byte range of the tag's name.
    pub name_range: Range<usize>,
    /// Byte range of the line containing the tag, truncated to at most
    /// `MAX_LINE_LEN` bytes.
    pub line_range: Range<usize>,
    /// Row/column positions corresponding to `range`.
    pub span: Range<Point>,
    /// Associated doc-comment text, joined with newlines, if any was captured.
    pub docs: Option<String>,
}
/// The kind of syntactic entity a tag refers to, mirroring the capture names
/// recognized in the tags query.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum TagKind {
    Function,
    Method,
    Class,
    Module,
    Call,
}
/// Errors that can arise while configuring or computing tags.
#[derive(Debug, PartialEq)]
pub enum Error {
    /// The combined locals + tags query failed to compile.
    Query(QueryError),
    /// A `strip!` predicate carried an invalid regex pattern.
    Regex(regex::Error),
    /// A cancellation flag became non-zero during parsing or iteration.
    Cancelled,
    /// The language could not be assigned to the parser.
    InvalidLanguage,
}
/// Per-pattern metadata precomputed from the query's predicates in
/// `TagsConfiguration::new`.
#[derive(Debug, Default)]
struct PatternInfo {
    // Capture named by a `select-adjacent!` predicate: doc nodes are kept
    // only when adjacent to this capture's node.
    docs_adjacent_capture: Option<u32>,
    // Whether scopes created by this pattern see definitions from enclosing
    // scopes; true unless `local.scope-inherits` is set to "false".
    local_scope_inherits: bool,
    // Set by a negated `local` property: the tag is suppressed when its name
    // matches a local definition.
    name_must_be_non_local: bool,
    // Regex from a `strip!` predicate, removed from doc text before use.
    doc_strip_regex: Option<Regex>,
}
/// A local definition recorded from a `local.definition` capture.
#[derive(Debug)]
struct LocalDef<'a> {
    // The definition's name, borrowed from the source text.
    name: &'a [u8],
    // Byte range of the captured node.
    value_range: Range<usize>,
}
/// A lexical scope recorded from a `local.scope` capture.
#[derive(Debug)]
struct LocalScope<'a> {
    // Whether names from enclosing scopes are visible in this scope.
    inherits: bool,
    // Byte range of the scope within the source.
    range: Range<usize>,
    // Definitions discovered inside this scope.
    local_defs: Vec<LocalDef<'a>>,
}
/// Lazy iterator over tags, created by `TagsContext::generate_tags`.
struct TagsIter<'a, I>
where
    I: Iterator<Item = tree_sitter::QueryMatch<'a>>,
{
    matches: I,
    // Owns the parse tree that `matches` borrows; never read directly.
    _tree: Tree,
    source: &'a [u8],
    config: &'a TagsConfiguration,
    cancellation_flag: Option<&'a AtomicUsize>,
    // Counts iterations between cancellation-flag polls.
    iter_count: usize,
    // Pending tags sorted by name position, each paired with the index of the
    // pattern that produced it.
    tag_queue: Vec<(Tag, usize)>,
    // Stack of local scopes, innermost last.
    scopes: Vec<LocalScope<'a>>,
}
impl TagsConfiguration {
    /// Build a tags configuration for `language` from a tags query and a
    /// (possibly empty) locals query.
    ///
    /// The locals query is concatenated in front of the tags query into a
    /// single `Query`, so patterns with index below `tags_pattern_index`
    /// come from the locals query and the rest produce tags.
    ///
    /// # Errors
    /// Returns `Error::Query` when the combined query fails to compile, and
    /// `Error::Regex` when a `strip!` predicate carries an invalid pattern.
    pub fn new(language: Language, tags_query: &str, locals_query: &str) -> Result<Self, Error> {
        let query = Query::new(language, &format!("{}{}", locals_query, tags_query))?;
        // Patterns whose source offset falls inside the locals text belong to
        // the locals query; count them to find where tags patterns begin.
        let tags_query_offset = locals_query.len();
        let tags_pattern_index = (0..query.pattern_count())
            .filter(|&i| query.start_byte_for_pattern(i) < tags_query_offset)
            .count();
        // Resolve the capture names this crate understands to their indices.
        // Captures absent from the query stay `None`.
        let mut call_capture_index = None;
        let mut class_capture_index = None;
        let mut doc_capture_index = None;
        let mut function_capture_index = None;
        let mut method_capture_index = None;
        let mut module_capture_index = None;
        let mut name_capture_index = None;
        let mut local_scope_capture_index = None;
        let mut local_definition_capture_index = None;
        for (i, name) in query.capture_names().iter().enumerate() {
            let index = match name.as_str() {
                "call" => &mut call_capture_index,
                "class" => &mut class_capture_index,
                "doc" => &mut doc_capture_index,
                "function" => &mut function_capture_index,
                "method" => &mut method_capture_index,
                "module" => &mut module_capture_index,
                "name" => &mut name_capture_index,
                "local.scope" => &mut local_scope_capture_index,
                "local.definition" => &mut local_definition_capture_index,
                _ => continue,
            };
            *index = Some(i as u32);
        }
        // Precompute per-pattern metadata from the query's predicates.
        let pattern_info = (0..query.pattern_count())
            .map(|pattern_index| {
                let mut info = PatternInfo::default();
                // A negated `local` property means the tag's name must not
                // match a local definition.
                for (property, is_positive) in query.property_predicates(pattern_index) {
                    if !is_positive && property.key.as_ref() == "local" {
                        info.name_must_be_non_local = true;
                    }
                }
                // Scopes inherit by default; `local.scope-inherits` set to
                // "false" opts this pattern out.
                info.local_scope_inherits = true;
                for property in query.property_settings(pattern_index) {
                    if property.key.as_ref() == "local.scope-inherits"
                        && property
                            .value
                            .as_ref()
                            .map_or(false, |v| v.as_ref() == "false")
                    {
                        info.local_scope_inherits = false;
                    }
                }
                // `select-adjacent!` and `strip!` predicates on the `doc`
                // capture customize how doc comments are gathered.
                if let Some(doc_capture_index) = doc_capture_index {
                    for predicate in query.general_predicates(pattern_index) {
                        if predicate.args.get(0)
                            == Some(&QueryPredicateArg::Capture(doc_capture_index))
                        {
                            match (predicate.operator.as_ref(), predicate.args.get(1)) {
                                ("select-adjacent!", Some(QueryPredicateArg::Capture(index))) => {
                                    info.docs_adjacent_capture = Some(*index);
                                }
                                ("strip!", Some(QueryPredicateArg::String(pattern))) => {
                                    let regex = Regex::new(pattern.as_ref())?;
                                    info.doc_strip_regex = Some(regex);
                                }
                                _ => {}
                            }
                        }
                    }
                }
                Ok(info)
            })
            .collect::<Result<Vec<_>, Error>>()?;
        Ok(TagsConfiguration {
            language,
            query,
            function_capture_index,
            class_capture_index,
            method_capture_index,
            module_capture_index,
            doc_capture_index,
            call_capture_index,
            name_capture_index,
            tags_pattern_index,
            local_scope_capture_index,
            local_definition_capture_index,
            pattern_info,
        })
    }
}
impl TagsContext {
    /// Create a context with a fresh parser and query cursor, both reused
    /// across calls to `generate_tags`.
    pub fn new() -> Self {
        TagsContext {
            parser: Parser::new(),
            cursor: QueryCursor::new(),
        }
    }
    /// Parse `source` with `config.language` and return a lazy iterator over
    /// the tags matched by `config.query`.
    ///
    /// If `cancellation_flag` is provided, parsing and iteration stop with
    /// `Error::Cancelled` once the flag becomes non-zero.
    ///
    /// # Errors
    /// `Error::InvalidLanguage` if the language cannot be assigned to the
    /// parser; `Error::Cancelled` if the parse itself is aborted.
    pub fn generate_tags<'a>(
        &'a mut self,
        config: &'a TagsConfiguration,
        source: &'a [u8],
        cancellation_flag: Option<&'a AtomicUsize>,
    ) -> Result<impl Iterator<Item = Result<Tag, Error>> + 'a, Error> {
        self.parser
            .set_language(config.language)
            .map_err(|_| Error::InvalidLanguage)?;
        self.parser.reset();
        // SAFETY: the flag is borrowed for 'a, which also bounds the returned
        // iterator, so it outlives this parse. NOTE(review): the parser keeps
        // the raw pointer between calls — confirm it is never read after the
        // iterator is dropped.
        unsafe { self.parser.set_cancellation_flag(cancellation_flag) };
        let tree = self.parser.parse(source, None).ok_or(Error::Cancelled)?;
        // The `matches` iterator borrows the `Tree`, which prevents it from being moved.
        // But the tree is really just a pointer, so it's actually ok to move it.
        let tree_ref = unsafe { mem::transmute::<_, &'static Tree>(&tree) };
        let matches = self
            .cursor
            .matches(&config.query, tree_ref.root_node(), move |node| {
                &source[node.byte_range()]
            });
        Ok(TagsIter {
            // The iterator owns the tree so the borrow above stays valid.
            _tree: tree,
            matches,
            source,
            config,
            cancellation_flag,
            tag_queue: Vec::new(),
            iter_count: 0,
            // Start with one root scope spanning the entire source.
            scopes: vec![LocalScope {
                range: 0..source.len(),
                inherits: false,
                local_defs: Vec::new(),
            }],
        })
    }
}
impl<'a, I> Iterator for TagsIter<'a, I>
where
    I: Iterator<Item = tree_sitter::QueryMatch<'a>>,
{
    type Item = Result<Tag, Error>;
    /// Produce the next tag, consuming query matches as needed. Locals
    /// patterns update the scope stack; tags patterns are buffered in
    /// `tag_queue` so overlapping matches for the same name can be
    /// deduplicated before being emitted in source order.
    fn next(&mut self) -> Option<Self::Item> {
        loop {
            // Periodically check for cancellation, returning `Cancelled` error if the
            // cancellation flag was flipped.
            if let Some(cancellation_flag) = self.cancellation_flag {
                self.iter_count += 1;
                if self.iter_count >= CANCELLATION_CHECK_INTERVAL {
                    self.iter_count = 0;
                    if cancellation_flag.load(Ordering::Relaxed) != 0 {
                        return Some(Err(Error::Cancelled));
                    }
                }
            }
            // If there is a queued tag for an earlier node in the syntax tree, then pop
            // it off of the queue and return it.
            if let Some(last_entry) = self.tag_queue.last() {
                if self.tag_queue.len() > 1
                    && self.tag_queue[0].0.name_range.end < last_entry.0.name_range.start
                {
                    return Some(Ok(self.tag_queue.remove(0).0));
                }
            }
            // If there is another match, then compute its tag and add it to the
            // tag queue.
            if let Some(mat) = self.matches.next() {
                let pattern_info = &self.config.pattern_info[mat.pattern_index];
                // Patterns below `tags_pattern_index` come from the locals
                // query: they maintain the scope stack instead of yielding tags.
                if mat.pattern_index < self.config.tags_pattern_index {
                    for capture in mat.captures {
                        let index = Some(capture.index);
                        let range = capture.node.byte_range();
                        if index == self.config.local_scope_capture_index {
                            self.scopes.push(LocalScope {
                                range,
                                inherits: pattern_info.local_scope_inherits,
                                local_defs: Vec::new(),
                            });
                        } else if index == self.config.local_definition_capture_index {
                            // Record the definition in the innermost scope
                            // that contains it.
                            if let Some(scope) = self.scopes.iter_mut().rev().find(|scope| {
                                scope.range.start <= range.start && scope.range.end >= range.end
                            }) {
                                scope.local_defs.push(LocalDef {
                                    name: &self.source[range.clone()],
                                    value_range: range,
                                });
                            }
                        }
                    }
                    continue;
                }
                // Gather the captures describing this tag. The last matching
                // kind capture wins; `Call` is the fallback kind.
                let mut name_range = None;
                let mut doc_nodes = Vec::new();
                let mut tag_node = None;
                let mut kind = TagKind::Call;
                let mut docs_adjacent_node = None;
                for capture in mat.captures {
                    let index = Some(capture.index);
                    if index == self.config.pattern_info[mat.pattern_index].docs_adjacent_capture {
                        docs_adjacent_node = Some(capture.node);
                    }
                    if index == self.config.name_capture_index {
                        name_range = Some(capture.node.byte_range());
                    } else if index == self.config.doc_capture_index {
                        doc_nodes.push(capture.node);
                    } else if index == self.config.call_capture_index {
                        tag_node = Some(capture.node);
                        kind = TagKind::Call;
                    } else if index == self.config.class_capture_index {
                        tag_node = Some(capture.node);
                        kind = TagKind::Class;
                    } else if index == self.config.function_capture_index {
                        tag_node = Some(capture.node);
                        kind = TagKind::Function;
                    } else if index == self.config.method_capture_index {
                        tag_node = Some(capture.node);
                        kind = TagKind::Method;
                    } else if index == self.config.module_capture_index {
                        tag_node = Some(capture.node);
                        kind = TagKind::Module;
                    }
                }
                if let (Some(tag_node), Some(name_range)) = (tag_node, name_range) {
                    // Suppress the tag when its name matches a visible local
                    // definition and the pattern requires a non-local name.
                    if pattern_info.name_must_be_non_local {
                        let mut is_local = false;
                        for scope in self.scopes.iter().rev() {
                            if scope.range.start <= name_range.start
                                && scope.range.end >= name_range.end
                            {
                                if scope
                                    .local_defs
                                    .iter()
                                    .any(|d| d.name == &self.source[name_range.clone()])
                                {
                                    is_local = true;
                                    break;
                                }
                                // A non-inheriting scope hides outer defs.
                                if !scope.inherits {
                                    break;
                                }
                            }
                        }
                        if is_local {
                            continue;
                        }
                    }
                    // If needed, filter the doc nodes based on their ranges, selecting
                    // only the slice that are adjacent to some specified node.
                    let mut docs_start_index = 0;
                    if let (Some(docs_adjacent_node), false) =
                        (docs_adjacent_node, doc_nodes.is_empty())
                    {
                        docs_start_index = doc_nodes.len();
                        let mut start_row = docs_adjacent_node.start_position().row;
                        // Walk backwards, keeping each doc node that ends on
                        // the line just above the previously kept node.
                        while docs_start_index > 0 {
                            let doc_node = &doc_nodes[docs_start_index - 1];
                            let prev_doc_end_row = doc_node.end_position().row;
                            if prev_doc_end_row + 1 >= start_row {
                                docs_start_index -= 1;
                                start_row = doc_node.start_position().row;
                            } else {
                                break;
                            }
                        }
                    }
                    // Generate a doc string from all of the doc nodes, applying any strip regexes.
                    let mut docs = None;
                    for doc_node in &doc_nodes[docs_start_index..] {
                        if let Ok(content) = str::from_utf8(&self.source[doc_node.byte_range()]) {
                            let content = if let Some(regex) = &pattern_info.doc_strip_regex {
                                regex.replace_all(content, "").to_string()
                            } else {
                                content.to_string()
                            };
                            match &mut docs {
                                None => docs = Some(content),
                                Some(d) => {
                                    d.push('\n');
                                    d.push_str(&content);
                                }
                            }
                        }
                    }
                    // Only create one tag per node. The tag queue is sorted by node position
                    // to allow for fast lookup.
                    let range = tag_node.byte_range();
                    match self
                        .tag_queue
                        .binary_search_by_key(&(name_range.end, name_range.start), |(tag, _)| {
                            (tag.name_range.end, tag.name_range.start)
                        }) {
                        // A queued tag has the same name range: keep whichever
                        // came from the lower-indexed pattern.
                        Ok(i) => {
                            let (tag, pattern_index) = &mut self.tag_queue[i];
                            if *pattern_index > mat.pattern_index {
                                *pattern_index = mat.pattern_index;
                                *tag = Tag {
                                    line_range: line_range(self.source, range.start, MAX_LINE_LEN),
                                    span: tag_node.start_position()..tag_node.end_position(),
                                    kind,
                                    range,
                                    name_range,
                                    docs,
                                };
                            }
                        }
                        Err(i) => self.tag_queue.insert(
                            i,
                            (
                                Tag {
                                    line_range: line_range(self.source, range.start, MAX_LINE_LEN),
                                    span: tag_node.start_position()..tag_node.end_position(),
                                    kind,
                                    range,
                                    name_range,
                                    docs,
                                },
                                mat.pattern_index,
                            ),
                        ),
                    }
                }
            }
            // If there are no more matches, then drain the queue.
            else if !self.tag_queue.is_empty() {
                return Some(Ok(self.tag_queue.remove(0).0));
            } else {
                return None;
            }
        }
    }
}
impl fmt::Display for TagKind {
    /// Write the tag kind's human-readable name.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let name = match self {
            TagKind::Call => "Call",
            TagKind::Module => "Module",
            TagKind::Class => "Class",
            TagKind::Method => "Method",
            TagKind::Function => "Function",
        };
        f.write_str(name)
    }
}
impl From<regex::Error> for Error {
    /// Wrap a regex compilation failure, enabling `?` on `Regex::new`.
    fn from(error: regex::Error) -> Self {
        Self::Regex(error)
    }
}
impl From<QueryError> for Error {
fn from(error: QueryError) -> Self {
Error::Query(error)
}
}
/// Return the byte range of the line containing byte `index`, truncated to at
/// most `max_line_len` bytes. The trailing newline is not included.
fn line_range(text: &[u8], index: usize, max_line_len: usize) -> Range<usize> {
    // The line begins just past the previous newline, or at the start of `text`.
    let line_start = memrchr(b'\n', &text[0..index]).map_or(0, |i| i + 1);
    // Never scan past the end of `text`.
    let window = max_line_len.min(text.len() - line_start);
    // The line ends at the next newline inside the window, or the window's end.
    let line_len = memchr(b'\n', &text[line_start..line_start + window]).unwrap_or(window);
    line_start..line_start + line_len
}
#[cfg(test)]
mod tests {
    use super::*;
    #[test]
    fn test_get_line() {
        // Lines: "abc" = bytes 0..3, "defg" = bytes 4..8, "hijkl" = bytes 9..14.
        let text = b"abc\ndefg\nhijkl";
        assert_eq!(line_range(text, 0, 10), 0..3);
        assert_eq!(line_range(text, 1, 10), 0..3);
        assert_eq!(line_range(text, 2, 10), 0..3);
        assert_eq!(line_range(text, 3, 10), 0..3);
        // A small `max_line_len` truncates the line.
        assert_eq!(line_range(text, 1, 2), 0..2);
        assert_eq!(line_range(text, 4, 10), 4..8);
        assert_eq!(line_range(text, 5, 10), 4..8);
        // The final line has no trailing newline.
        assert_eq!(line_range(text, 11, 10), 9..14);
    }
}
}