diff --git a/cli/src/tests/tags_test.rs b/cli/src/tests/tags_test.rs index d4cbc687..756f63e7 100644 --- a/cli/src/tests/tags_test.rs +++ b/cli/src/tests/tags_test.rs @@ -1,7 +1,7 @@ use super::helpers::allocations; -use super::helpers::fixtures::get_language; +use super::helpers::fixtures::{get_language, get_language_queries_path}; use std::ffi::CString; -use std::{ptr, slice, str}; +use std::{fs, ptr, slice, str}; use tree_sitter_tags::c_lib as c; use tree_sitter_tags::{TagKind, TagsConfiguration, TagsContext}; @@ -47,10 +47,22 @@ const JS_TAG_QUERY: &'static str = r#" (call_expression function: (identifier) @name) @call "#; +const RUBY_TAG_QUERY: &'static str = r#" +(method + name: (identifier) @name) @method + +(method_call + method: (identifier) @name) @call + +((identifier) @name @call + (is-not? local)) +"#; + #[test] fn test_tags_python() { let language = get_language("python"); let tags_config = TagsConfiguration::new(language, PYTHON_TAG_QUERY, "").unwrap(); + let mut tag_context = TagsContext::new(); let source = br#" class Customer: @@ -66,7 +78,6 @@ fn test_tags_python() { } "#; - let mut tag_context = TagsContext::new(); let tags = tag_context .generate_tags(&tags_config, source) .collect::>(); @@ -95,8 +106,6 @@ fn test_tags_python() { fn test_tags_javascript() { let language = get_language("javascript"); let tags_config = TagsConfiguration::new(language, JS_TAG_QUERY, "").unwrap(); - - let mut tag_context = TagsContext::new(); let source = br#" // hi @@ -116,6 +125,8 @@ fn test_tags_javascript() { } "#; + + let mut tag_context = TagsContext::new(); let tags = tag_context .generate_tags(&tags_config, source) .collect::>(); @@ -138,6 +149,58 @@ fn test_tags_javascript() { assert_eq!(tags[2].docs, None); } +#[test] +fn test_tags_ruby() { + let language = get_language("ruby"); + let locals_query = + fs::read_to_string(get_language_queries_path("ruby").join("locals.scm")).unwrap(); + let tags_config = TagsConfiguration::new(language, RUBY_TAG_QUERY, &locals_query).unwrap(); + let source = strip_whitespace( + 8, + " + b = 1 + + def foo() + c = 1 + + # a is a method because it is not in scope + # b is a method because `b` doesn't capture variables from its containing scope + bar a, b, c + + [1, 2, 3].each do |a| + # a is a parameter + # b is a method + # c is a variable, because the block captures variables from its containing scope. + baz a, b, c + end + end", + ); + + let mut tag_context = TagsContext::new(); + let tags = tag_context + .generate_tags(&tags_config, source.as_bytes()) + .collect::>(); + + assert_eq!( + tags.iter() + .map(|t| ( + substr(source.as_bytes(), &t.name_range), + t.kind, + (t.span.start.row, t.span.start.column), + )) + .collect::>(), + &[ + ("foo", TagKind::Method, (2, 0)), + ("bar", TagKind::Call, (7, 4)), + ("a", TagKind::Call, (7, 8)), + ("b", TagKind::Call, (7, 11)), + ("each", TagKind::Call, (9, 14)), + ("baz", TagKind::Call, (13, 8)), + ("b", TagKind::Call, (13, 15),), + ] + ); +} + #[test] fn test_tags_via_c_api() { allocations::record(|| { @@ -146,7 +209,9 @@ fn test_tags_via_c_api() { let scope_name = "source.js"; let language = get_language("javascript"); - let source_code = " + let source_code = strip_whitespace( + 12, + " var a = 1; // one @@ -161,13 +226,8 @@ fn test_tags_via_c_api() { } - b(a);" - .lines() - .skip(1) - // remove extra indentation - .map(|line| &line[line.len().min(12)..]) - .collect::>() - .join("\n"); + b(a);", + ); let c_scope_name = CString::new(scope_name).unwrap(); let result = c::ts_tagger_add_language( @@ -238,3 +298,11 @@ fn test_tags_via_c_api() { fn substr<'a>(source: &'a [u8], range: &std::ops::Range) -> &'a str { std::str::from_utf8(&source[range.clone()]).unwrap() } + +fn strip_whitespace(indent: usize, s: &str) -> String { + s.lines() + .skip(1) + .map(|line| &line[line.len().min(indent)..]) + .collect::>() + .join("\n") +} diff --git a/tags/src/lib.rs b/tags/src/lib.rs index e5695845..c3a52303 100644 --- a/tags/src/lib.rs +++ b/tags/src/lib.rs @@ -5,7 +5,7 @@ use regex::Regex; use std::ops::Range; use std::{fmt, mem, str}; use tree_sitter::{ - Language, Node, Parser, Point, Query, QueryCursor, QueryError, QueryPredicateArg, Tree, + Language, Parser, Point, Query, QueryCursor, QueryError, QueryPredicateArg, Tree, }; const MAX_LINE_LEN: usize = 180; @@ -23,8 +23,10 @@ pub struct TagsConfiguration { method_capture_index: Option, module_capture_index: Option, name_capture_index: Option, + local_scope_capture_index: Option, + local_definition_capture_index: Option, + tags_pattern_index: usize, pattern_info: Vec, - _locals_pattern_index: usize, } pub struct TagsContext { @@ -60,9 +62,24 @@ pub enum Error { #[derive(Debug, Default)] struct PatternInfo { docs_adjacent_capture: Option, + local_scope_inherits: bool, + name_must_be_non_local: bool, doc_strip_regex: Option, } +#[derive(Debug)] +struct LocalDef<'a> { + name: &'a [u8], + value_range: Range, +} + +#[derive(Debug)] +struct LocalScope<'a> { + inherits: bool, + range: Range, + local_defs: Vec>, +} + struct TagsIter<'a, I> where I: Iterator>, @@ -71,19 +88,20 @@ where _tree: Tree, source: &'a [u8], config: &'a TagsConfiguration, - tag_queue: Vec<(Node<'a>, usize, Tag)>, + tag_queue: Vec<(Tag, usize)>, + scopes: Vec>, } impl TagsConfiguration { pub fn new(language: Language, tags_query: &str, locals_query: &str) -> Result { - let query = Query::new(language, &format!("{}{}", tags_query, locals_query))?; + let query = Query::new(language, &format!("{}{}", locals_query, tags_query))?; - let locals_query_offset = tags_query.len(); - let mut locals_pattern_index = 0; + let tags_query_offset = locals_query.len(); + let mut tags_pattern_index = 0; for i in 0..(query.pattern_count()) { let pattern_offset = query.start_byte_for_pattern(i); - if pattern_offset < locals_query_offset { - locals_pattern_index += 1; + if pattern_offset < tags_query_offset { + tags_pattern_index += 1; } } @@ -94,6 +112,8 @@ impl TagsConfiguration { let mut method_capture_index = None; let mut module_capture_index = None; let mut name_capture_index = None; + let mut local_scope_capture_index = None; + let mut local_definition_capture_index = None; for (i, name) in query.capture_names().iter().enumerate() { let index = match name.as_str() { "call" => &mut call_capture_index, @@ -103,6 +123,8 @@ impl TagsConfiguration { "method" => &mut method_capture_index, "module" => &mut module_capture_index, "name" => &mut name_capture_index, + "local.scope" => &mut local_scope_capture_index, + "local.definition" => &mut local_definition_capture_index, _ => continue, }; *index = Some(i as u32); @@ -111,6 +133,22 @@ impl TagsConfiguration { let pattern_info = (0..query.pattern_count()) .map(|pattern_index| { let mut info = PatternInfo::default(); + for (property, is_positive) in query.property_predicates(pattern_index) { + if !is_positive && property.key.as_ref() == "local" { + info.name_must_be_non_local = true; + } + } + info.local_scope_inherits = true; + for property in query.property_settings(pattern_index) { + if property.key.as_ref() == "local.scope-inherits" + && property + .value + .as_ref() + .map_or(false, |v| v.as_ref() == "false") + { + info.local_scope_inherits = false; + } + } if let Some(doc_capture_index) = doc_capture_index { for predicate in query.general_predicates(pattern_index) { if predicate.args.get(0) @@ -143,8 +181,10 @@ impl TagsConfiguration { doc_capture_index, call_capture_index, name_capture_index, + tags_pattern_index, + local_scope_capture_index, + local_definition_capture_index, pattern_info, - _locals_pattern_index: locals_pattern_index, }) } } @@ -179,11 +219,16 @@ impl TagsContext { &source[node.byte_range()] }); TagsIter { + _tree: tree, matches, source, config, tag_queue: Vec::new(), - _tree: tree, + scopes: vec![LocalScope { + range: 0..source.len(), + inherits: false, + local_defs: Vec::new(), + }], } } } @@ -200,15 +245,41 @@ where // it off of the queue and return it. if let Some(last_entry) = self.tag_queue.last() { if self.tag_queue.len() > 1 - && self.tag_queue[0].0.end_byte() < last_entry.0.start_byte() + && self.tag_queue[0].0.name_range.end < last_entry.0.name_range.start { - return Some(self.tag_queue.remove(0).2); + return Some(self.tag_queue.remove(0).0); } } // If there is another match, then compute its tag and add it to the // tag queue. if let Some(mat) = self.matches.next() { + let pattern_info = &self.config.pattern_info[mat.pattern_index]; + + if mat.pattern_index < self.config.tags_pattern_index { + for capture in mat.captures { + let index = Some(capture.index); + let range = capture.node.byte_range(); + if index == self.config.local_scope_capture_index { + self.scopes.push(LocalScope { + range, + inherits: pattern_info.local_scope_inherits, + local_defs: Vec::new(), + }); + } else if index == self.config.local_definition_capture_index { + if let Some(scope) = self.scopes.iter_mut().rev().find(|scope| { + scope.range.start <= range.start && scope.range.end >= range.end + }) { + scope.local_defs.push(LocalDef { + name: &self.source[range.clone()], + value_range: range, + }); + } + } + } + continue; + } + let mut name_range = None; let mut doc_nodes = Vec::new(); let mut tag_node = None; @@ -245,6 +316,30 @@ where } if let (Some(tag_node), Some(name_range)) = (tag_node, name_range) { + if pattern_info.name_must_be_non_local { + let mut is_local = false; + for scope in self.scopes.iter().rev() { + if scope.range.start <= name_range.start + && scope.range.end >= name_range.end + { + if scope + .local_defs + .iter() + .any(|d| d.name == &self.source[name_range.clone()]) + { + is_local = true; + break; + } + if !scope.inherits { + break; + } + } + } + if is_local { + continue; + } + } + // If needed, filter the doc nodes based on their ranges, selecting // only the slice that are adjacent to some specified node. let mut docs_start_index = 0; @@ -269,9 +364,7 @@ where let mut docs = None; for doc_node in &doc_nodes[docs_start_index..] { if let Ok(content) = str::from_utf8(&self.source[doc_node.byte_range()]) { - let content = if let Some(regex) = - &self.config.pattern_info[mat.pattern_index].doc_strip_regex - { + let content = if let Some(regex) = &pattern_info.doc_strip_regex { regex.replace_all(content, "").to_string() } else { content.to_string() @@ -289,12 +382,13 @@ where // Only create one tag per node. The tag queue is sorted by node position // to allow for fast lookup. let range = tag_node.byte_range(); - match self.tag_queue.binary_search_by_key( - &(range.end, range.start, tag_node.id()), - |(node, _, _)| (node.end_byte(), node.start_byte(), node.id()), - ) { + match self + .tag_queue + .binary_search_by_key(&(name_range.end, name_range.start), |(tag, _)| { + (tag.name_range.end, tag.name_range.start) + }) { Ok(i) => { - let (_, pattern_index, tag) = &mut self.tag_queue[i]; + let (tag, pattern_index) = &mut self.tag_queue[i]; if *pattern_index > mat.pattern_index { *pattern_index = mat.pattern_index; *tag = Tag { @@ -310,8 +404,6 @@ where Err(i) => self.tag_queue.insert( i, ( - tag_node, - mat.pattern_index, Tag { line_range: line_range(self.source, range.start, MAX_LINE_LEN), span: tag_node.start_position()..tag_node.start_position(), @@ -320,6 +412,7 @@ where name_range, docs, }, + mat.pattern_index, ), ), } @@ -327,7 +420,7 @@ where } // If there are no more matches, then drain the queue. else if !self.tag_queue.is_empty() { - return Some(self.tag_queue.remove(0).2); + return Some(self.tag_queue.remove(0).0); } else { return None; }