Merge pull request #648 from tree-sitter/tagging-improvements

Tagging improvements
This commit is contained in:
Max Brunsfeld 2020-07-10 13:48:23 -07:00 committed by GitHub
commit 0c2dc4c1e9
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 464 additions and 229 deletions

View file

@ -83,7 +83,7 @@ impl<'a> From<tree_sitter_highlight::Error> for Error {
impl<'a> From<tree_sitter_tags::Error> for Error {
fn from(error: tree_sitter_tags::Error) -> Self {
Error::new(format!("{:?}", error))
Error::new(format!("{}", error))
}
}

View file

@ -90,13 +90,8 @@ fn run() -> error::Result<()> {
)
.subcommand(
SubCommand::with_name("tags")
.arg(
Arg::with_name("format")
.short("f")
.long("format")
.value_name("json|protobuf")
.help("Determine output format (default: json)"),
)
.arg(Arg::with_name("quiet").long("quiet").short("q"))
.arg(Arg::with_name("time").long("quiet").short("t"))
.arg(Arg::with_name("scope").long("scope").takes_value(true))
.arg(
Arg::with_name("inputs")
@ -104,12 +99,6 @@ fn run() -> error::Result<()> {
.index(1)
.required(true)
.multiple(true),
)
.arg(
Arg::with_name("v")
.short("v")
.multiple(true)
.help("Sets the level of verbosity"),
),
)
.subcommand(
@ -149,8 +138,14 @@ fn run() -> error::Result<()> {
.arg(Arg::with_name("path").index(1).multiple(true)),
)
.subcommand(
SubCommand::with_name("web-ui").about("Test a parser interactively in the browser")
.arg(Arg::with_name("quiet").long("quiet").short("q").help("open in default browser")),
SubCommand::with_name("web-ui")
.about("Test a parser interactively in the browser")
.arg(
Arg::with_name("quiet")
.long("quiet")
.short("q")
.help("open in default browser"),
),
)
.subcommand(
SubCommand::with_name("dump-languages")
@ -268,7 +263,13 @@ fn run() -> error::Result<()> {
} else if let Some(matches) = matches.subcommand_matches("tags") {
loader.find_all_languages(&config.parser_directories)?;
let paths = collect_paths(matches.values_of("inputs").unwrap())?;
tags::generate_tags(&loader, matches.value_of("scope"), &paths)?;
tags::generate_tags(
&loader,
matches.value_of("scope"),
&paths,
matches.is_present("quiet"),
matches.is_present("time"),
)?;
} else if let Some(matches) = matches.subcommand_matches("highlight") {
loader.configure_highlights(&config.theme.highlight_names);
loader.find_all_languages(&config.parser_directories)?;

View file

@ -3,10 +3,17 @@ use super::util;
use crate::error::{Error, Result};
use std::io::{self, Write};
use std::path::Path;
use std::time::Instant;
use std::{fs, str};
use tree_sitter_tags::TagsContext;
pub fn generate_tags(loader: &Loader, scope: Option<&str>, paths: &[String]) -> Result<()> {
pub fn generate_tags(
loader: &Loader,
scope: Option<&str>,
paths: &[String],
quiet: bool,
time: bool,
) -> Result<()> {
let mut lang = None;
if let Some(scope) = scope {
lang = loader.language_configuration_for_scope(scope)?;
@ -34,28 +41,50 @@ pub fn generate_tags(loader: &Loader, scope: Option<&str>, paths: &[String]) ->
};
if let Some(tags_config) = language_config.tags_config(language)? {
let path_str = format!("{:?}", path);
writeln!(&mut stdout, "{}", &path_str[1..path_str.len() - 1])?;
let indent;
if paths.len() > 1 {
if !quiet {
writeln!(&mut stdout, "{}", path.to_string_lossy())?;
}
indent = "\t"
} else {
indent = "";
};
let source = fs::read(path)?;
let t0 = Instant::now();
for tag in context.generate_tags(tags_config, &source, Some(&cancellation_flag))? {
let tag = tag?;
write!(
&mut stdout,
" {:<8} {:<40}\t{:>9}-{:<9}",
tag.kind,
str::from_utf8(&source[tag.name_range]).unwrap_or(""),
tag.span.start,
tag.span.end,
)?;
if let Some(docs) = tag.docs {
if docs.len() > 120 {
write!(&mut stdout, "\t{:?}...", &docs[0..120])?;
} else {
write!(&mut stdout, "\t{:?}", &docs)?;
if !quiet {
write!(
&mut stdout,
"{}{:<10}\t | {:<8}\t{} {} - {} `{}`",
indent,
str::from_utf8(&source[tag.name_range]).unwrap_or(""),
&tags_config.syntax_type_name(tag.syntax_type_id),
if tag.is_definition { "def" } else { "ref" },
tag.span.start,
tag.span.end,
str::from_utf8(&source[tag.line_range]).unwrap_or(""),
)?;
if let Some(docs) = tag.docs {
if docs.len() > 120 {
write!(&mut stdout, "\t{:?}...", &docs[0..120])?;
} else {
write!(&mut stdout, "\t{:?}", &docs)?;
}
}
writeln!(&mut stdout, "")?;
}
writeln!(&mut stdout, "")?;
}
if time {
writeln!(
&mut stdout,
"{}time: {}ms",
indent,
t0.elapsed().as_millis(),
)?;
}
} else {
eprintln!("No tags config found for path {:?}", path);

View file

@ -1,73 +1,79 @@
use super::helpers::allocations;
use super::helpers::fixtures::{get_language, get_language_queries_path};
use std::ffi::CStr;
use std::ffi::CString;
use std::{fs, ptr, slice, str};
use tree_sitter::Point;
use tree_sitter_tags::c_lib as c;
use tree_sitter_tags::{Error, TagKind, TagsConfiguration, TagsContext};
use tree_sitter_tags::{Error, TagsConfiguration, TagsContext};
const PYTHON_TAG_QUERY: &'static str = r#"
(
(function_definition
name: (identifier) @name
body: (block . (expression_statement (string) @doc))) @function
(#strip! @doc "(^['\"\\s]*)|(['\"\\s]*$)")
(function_definition
name: (identifier) @name
body: (block . (expression_statement (string) @doc))) @definition.function
(#strip! @doc "(^['\"\\s]*)|(['\"\\s]*$)")
)
(function_definition
name: (identifier) @name) @function
name: (identifier) @name) @definition.function
(
(class_definition
name: (identifier) @name
body: (block
. (expression_statement (string) @doc))) @class
(#strip! @doc "(^['\"\\s]*)|(['\"\\s]*$)")
(class_definition
name: (identifier) @name
body: (block
. (expression_statement (string) @doc))) @definition.class
(#strip! @doc "(^['\"\\s]*)|(['\"\\s]*$)")
)
(class_definition
name: (identifier) @name) @class
name: (identifier) @name) @definition.class
(call
function: (identifier) @name) @call
function: (identifier) @name) @reference.call
(call
function: (attribute
attribute: (identifier) @name)) @reference.call
"#;
const JS_TAG_QUERY: &'static str = r#"
(
(comment)* @doc .
(class_declaration
name: (identifier) @name) @class
(#select-adjacent! @doc @class)
name: (identifier) @name) @definition.class
(#select-adjacent! @doc @definition.class)
(#strip! @doc "(^[/\\*\\s]*)|([/\\*\\s]*$)")
)
(
(comment)* @doc .
(method_definition
name: (property_identifier) @name) @method
(#select-adjacent! @doc @method)
name: (property_identifier) @name) @definition.method
(#select-adjacent! @doc @definition.method)
(#strip! @doc "(^[/\\*\\s]*)|([/\\*\\s]*$)")
)
(
(comment)* @doc .
(function_declaration
name: (identifier) @name) @function
(#select-adjacent! @doc @function)
name: (identifier) @name) @definition.function
(#select-adjacent! @doc @definition.function)
(#strip! @doc "(^[/\\*\\s]*)|([/\\*\\s]*$)")
)
(call_expression
function: (identifier) @name) @call
function: (identifier) @name) @reference.call
"#;
const RUBY_TAG_QUERY: &'static str = r#"
(method
name: (identifier) @name) @method
name: (identifier) @name) @definition.method
(method_call
method: (identifier) @name) @call
method: (identifier) @name) @reference.call
((identifier) @name @call
((identifier) @name @reference.call
(#is-not? local))
"#;
@ -99,20 +105,20 @@ fn test_tags_python() {
assert_eq!(
tags.iter()
.map(|t| (substr(source, &t.name_range), t.kind))
.map(|t| (
substr(source, &t.name_range),
tags_config.syntax_type_name(t.syntax_type_id)
))
.collect::<Vec<_>>(),
&[
("Customer", TagKind::Class),
("age", TagKind::Function),
("compute_age", TagKind::Call),
("Customer", "class"),
("age", "function"),
("compute_age", "call"),
]
);
assert_eq!(substr(source, &tags[0].line_range), " class Customer:");
assert_eq!(
substr(source, &tags[1].line_range),
" def age(self):"
);
assert_eq!(substr(source, &tags[0].line_range), "class Customer:");
assert_eq!(substr(source, &tags[1].line_range), "def age(self):");
assert_eq!(tags[0].docs.as_ref().unwrap(), "Data about a customer");
assert_eq!(tags[1].docs.as_ref().unwrap(), "Get the customer's age");
}
@ -150,12 +156,16 @@ fn test_tags_javascript() {
assert_eq!(
tags.iter()
.map(|t| (substr(source, &t.name_range), t.kind))
.map(|t| (
substr(source, &t.name_range),
t.span.clone(),
tags_config.syntax_type_name(t.syntax_type_id)
))
.collect::<Vec<_>>(),
&[
("Customer", TagKind::Class),
("getAge", TagKind::Method),
("Agent", TagKind::Class)
("Customer", Point::new(5, 10)..Point::new(5, 18), "class",),
("getAge", Point::new(9, 8)..Point::new(9, 14), "method",),
("Agent", Point::new(15, 10)..Point::new(15, 15), "class",)
]
);
assert_eq!(
@ -166,6 +176,26 @@ fn test_tags_javascript() {
assert_eq!(tags[2].docs, None);
}
#[test]
fn test_tags_columns_measured_in_utf16_code_units() {
let language = get_language("python");
let tags_config = TagsConfiguration::new(language, PYTHON_TAG_QUERY, "").unwrap();
let mut tag_context = TagsContext::new();
let source = r#""❤️❤️❤️".hello_α_ω()"#.as_bytes();
let tag = tag_context
.generate_tags(&tags_config, source, None)
.unwrap()
.next()
.unwrap()
.unwrap();
assert_eq!(substr(source, &tag.name_range), "hello_α");
assert_eq!(tag.span, Point::new(0, 21)..Point::new(0, 32));
assert_eq!(tag.utf16_column_range, 9..18);
}
#[test]
fn test_tags_ruby() {
let language = get_language("ruby");
@ -204,18 +234,18 @@ fn test_tags_ruby() {
tags.iter()
.map(|t| (
substr(source.as_bytes(), &t.name_range),
t.kind,
tags_config.syntax_type_name(t.syntax_type_id),
(t.span.start.row, t.span.start.column),
))
.collect::<Vec<_>>(),
&[
("foo", TagKind::Method, (2, 0)),
("bar", TagKind::Call, (7, 4)),
("a", TagKind::Call, (7, 8)),
("b", TagKind::Call, (7, 11)),
("each", TagKind::Call, (9, 14)),
("baz", TagKind::Call, (13, 8)),
("b", TagKind::Call, (13, 15),),
("foo", "method", (2, 4)),
("bar", "call", (7, 4)),
("a", "call", (7, 8)),
("b", "call", (7, 11)),
("each", "call", (9, 14)),
("baz", "call", (13, 8)),
("b", "call", (13, 15),),
]
);
}
@ -253,6 +283,14 @@ fn test_tags_cancellation() {
});
}
#[test]
fn test_invalid_capture() {
let language = get_language("python");
let e = TagsConfiguration::new(language, "(identifier) @method", "")
.expect_err("expected InvalidCapture error");
assert_eq!(e, Error::InvalidCapture("method".to_string()));
}
#[test]
fn test_tags_via_c_api() {
allocations::record(|| {
@ -316,29 +354,29 @@ fn test_tags_via_c_api() {
})
.unwrap();
let syntax_types: Vec<&str> = unsafe {
let mut len: u32 = 0;
let ptr =
c::ts_tagger_syntax_kinds_for_scope_name(tagger, c_scope_name.as_ptr(), &mut len);
slice::from_raw_parts(ptr, len as usize)
.iter()
.map(|i| CStr::from_ptr(*i).to_str().unwrap())
.collect()
};
assert_eq!(
tags.iter()
.map(|tag| (
tag.kind,
syntax_types[tag.syntax_type_id as usize],
&source_code[tag.name_start_byte as usize..tag.name_end_byte as usize],
&source_code[tag.line_start_byte as usize..tag.line_end_byte as usize],
&docs[tag.docs_start_byte as usize..tag.docs_end_byte as usize],
))
.collect::<Vec<_>>(),
&[
(
c::TSTagKind::Function,
"b",
"function b() {",
"one\ntwo\nthree"
),
(
c::TSTagKind::Class,
"C",
"class C extends D {",
"four\nfive"
),
(c::TSTagKind::Call, "b", "b(a);", "")
("function", "b", "function b() {", "one\ntwo\nthree"),
("class", "C", "class C extends D {", "four\nfive"),
("call", "b", "b(a);", "")
]
);

View file

@ -16,18 +16,10 @@ typedef enum {
TSTagsInvalidUtf8,
TSTagsInvalidRegex,
TSTagsInvalidQuery,
TSTagsInvalidCapture,
} TSTagsError;
typedef enum {
TSTagKindFunction,
TSTagKindMethod,
TSTagKindClass,
TSTagKindModule,
TSTagKindCall,
} TSTagKind;
typedef struct {
TSTagKind kind;
uint32_t start_byte;
uint32_t end_byte;
uint32_t name_start_byte;
@ -36,8 +28,12 @@ typedef struct {
uint32_t line_end_byte;
TSPoint start_point;
TSPoint end_point;
uint32_t utf16_start_column;
uint32_t utf16_end_column;
uint32_t docs_start_byte;
uint32_t docs_end_byte;
uint32_t syntax_type_id;
bool is_definition;
} TSTag;
typedef struct TSTagger TSTagger;
@ -89,6 +85,9 @@ uint32_t ts_tags_buffer_tags_len(const TSTagsBuffer *);
const char *ts_tags_buffer_docs(const TSTagsBuffer *);
uint32_t ts_tags_buffer_docs_len(const TSTagsBuffer *);
// Get the syntax kinds for a scope.
const char **ts_tagger_syntax_kinds_for_scope_name(const TSTagger *, const char *scope_name, uint32_t *len);
#ifdef __cplusplus
}
#endif

View file

@ -1,4 +1,4 @@
use super::{Error, TagKind, TagsConfiguration, TagsContext};
use super::{Error, TagsConfiguration, TagsContext};
use std::collections::HashMap;
use std::ffi::CStr;
use std::process::abort;
@ -16,19 +16,10 @@ pub enum TSTagsError {
InvalidUtf8,
InvalidRegex,
InvalidQuery,
InvalidCapture,
Unknown,
}
#[repr(C)]
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum TSTagKind {
Function,
Method,
Class,
Module,
Call,
}
#[repr(C)]
pub struct TSPoint {
row: u32,
@ -37,7 +28,6 @@ pub struct TSPoint {
#[repr(C)]
pub struct TSTag {
pub kind: TSTagKind,
pub start_byte: u32,
pub end_byte: u32,
pub name_start_byte: u32,
@ -46,8 +36,12 @@ pub struct TSTag {
pub line_end_byte: u32,
pub start_point: TSPoint,
pub end_point: TSPoint,
pub utf16_start_colum: u32,
pub utf16_end_colum: u32,
pub docs_start_byte: u32,
pub docs_end_byte: u32,
pub syntax_type_id: u32,
pub is_definition: bool,
}
pub struct TSTagger {
@ -102,7 +96,9 @@ pub extern "C" fn ts_tagger_add_language(
}
Err(Error::Query(_)) => TSTagsError::InvalidQuery,
Err(Error::Regex(_)) => TSTagsError::InvalidRegex,
Err(_) => TSTagsError::Unknown,
Err(Error::Cancelled) => TSTagsError::Timeout,
Err(Error::InvalidLanguage) => TSTagsError::InvalidLanguage,
Err(Error::InvalidCapture(_)) => TSTagsError::InvalidCapture,
}
}
@ -153,13 +149,6 @@ pub extern "C" fn ts_tagger_tag(
buffer.docs.extend_from_slice(docs.as_bytes());
}
buffer.tags.push(TSTag {
kind: match tag.kind {
TagKind::Function => TSTagKind::Function,
TagKind::Method => TSTagKind::Method,
TagKind::Class => TSTagKind::Class,
TagKind::Module => TSTagKind::Module,
TagKind::Call => TSTagKind::Call,
},
start_byte: tag.range.start as u32,
end_byte: tag.range.end as u32,
name_start_byte: tag.name_range.start as u32,
@ -174,8 +163,12 @@ pub extern "C" fn ts_tagger_tag(
row: tag.span.end.row as u32,
column: tag.span.end.column as u32,
},
utf16_start_colum: tag.utf16_column_range.start as u32,
utf16_end_colum: tag.utf16_column_range.end as u32,
docs_start_byte: prev_docs_len as u32,
docs_end_byte: buffer.docs.len() as u32,
syntax_type_id: tag.syntax_type_id,
is_definition: tag.is_definition,
});
}
@ -223,6 +216,24 @@ pub extern "C" fn ts_tags_buffer_docs_len(this: *const TSTagsBuffer) -> u32 {
buffer.docs.len() as u32
}
#[no_mangle]
pub extern "C" fn ts_tagger_syntax_kinds_for_scope_name(
this: *mut TSTagger,
scope_name: *const i8,
len: *mut u32,
) -> *const *const i8 {
let tagger = unwrap_mut_ptr(this);
let scope_name = unsafe { unwrap(CStr::from_ptr(scope_name).to_str()) };
let len = unwrap_mut_ptr(len);
*len = 0;
if let Some(config) = tagger.languages.get(scope_name) {
*len = config.c_syntax_type_names.len() as u32;
return config.c_syntax_type_names.as_ptr() as *const *const i8;
}
std::ptr::null()
}
fn unwrap_ptr<'a, T>(result: *const T) -> &'a T {
unsafe { result.as_ref() }.unwrap_or_else(|| {
eprintln!("{}:{} - pointer must not be null", file!(), line!());

View file

@ -1,10 +1,12 @@
pub mod c_lib;
use memchr::{memchr, memrchr};
use memchr::memchr;
use regex::Regex;
use std::collections::HashMap;
use std::ffi::{CStr, CString};
use std::ops::Range;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::{fmt, mem, str};
use std::{char, fmt, mem, str};
use tree_sitter::{
Language, Parser, Point, Query, QueryCursor, QueryError, QueryPredicateArg, Tree,
};
@ -18,12 +20,10 @@ const CANCELLATION_CHECK_INTERVAL: usize = 100;
pub struct TagsConfiguration {
pub language: Language,
pub query: Query,
call_capture_index: Option<u32>,
class_capture_index: Option<u32>,
syntax_type_names: Vec<Box<[u8]>>,
c_syntax_type_names: Vec<*const u8>,
capture_map: HashMap<u32, NamedCapture>,
doc_capture_index: Option<u32>,
function_capture_index: Option<u32>,
method_capture_index: Option<u32>,
module_capture_index: Option<u32>,
name_capture_index: Option<u32>,
local_scope_capture_index: Option<u32>,
local_definition_capture_index: Option<u32>,
@ -31,6 +31,12 @@ pub struct TagsConfiguration {
pattern_info: Vec<PatternInfo>,
}
#[derive(Debug)]
pub struct NamedCapture {
pub syntax_type_id: u32,
pub is_definition: bool,
}
pub struct TagsContext {
parser: Parser,
cursor: QueryCursor,
@ -38,21 +44,14 @@ pub struct TagsContext {
#[derive(Debug, Clone)]
pub struct Tag {
pub kind: TagKind,
pub range: Range<usize>,
pub name_range: Range<usize>,
pub line_range: Range<usize>,
pub span: Range<Point>,
pub utf16_column_range: Range<usize>,
pub docs: Option<String>,
}
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum TagKind {
Function,
Method,
Class,
Module,
Call,
pub is_definition: bool,
pub syntax_type_id: u32,
}
#[derive(Debug, PartialEq)]
@ -61,6 +60,7 @@ pub enum Error {
Regex(regex::Error),
Cancelled,
InvalidLanguage,
InvalidCapture(String),
}
#[derive(Debug, Default)]
@ -91,6 +91,7 @@ where
matches: I,
_tree: Tree,
source: &'a [u8],
prev_line_info: Option<LineInfo>,
config: &'a TagsConfiguration,
cancellation_flag: Option<&'a AtomicUsize>,
iter_count: usize,
@ -98,6 +99,18 @@ where
scopes: Vec<LocalScope<'a>>,
}
struct LineInfo {
utf8_position: Point,
utf8_byte: usize,
utf16_column: usize,
line_range: Range<usize>,
}
struct LossyUtf8<'a> {
bytes: &'a [u8],
in_replacement: bool,
}
impl TagsConfiguration {
pub fn new(language: Language, tags_query: &str, locals_query: &str) -> Result<Self, Error> {
let query = Query::new(language, &format!("{}{}", locals_query, tags_query))?;
@ -111,31 +124,55 @@ impl TagsConfiguration {
}
}
let mut call_capture_index = None;
let mut class_capture_index = None;
let mut capture_map = HashMap::new();
let mut syntax_type_names = Vec::new();
let mut doc_capture_index = None;
let mut function_capture_index = None;
let mut method_capture_index = None;
let mut module_capture_index = None;
let mut name_capture_index = None;
let mut local_scope_capture_index = None;
let mut local_definition_capture_index = None;
for (i, name) in query.capture_names().iter().enumerate() {
let index = match name.as_str() {
"call" => &mut call_capture_index,
"class" => &mut class_capture_index,
"doc" => &mut doc_capture_index,
"function" => &mut function_capture_index,
"method" => &mut method_capture_index,
"module" => &mut module_capture_index,
"name" => &mut name_capture_index,
"local.scope" => &mut local_scope_capture_index,
"local.definition" => &mut local_definition_capture_index,
_ => continue,
};
*index = Some(i as u32);
match name.as_str() {
"" => continue,
"name" => name_capture_index = Some(i as u32),
"doc" => doc_capture_index = Some(i as u32),
"local.scope" => local_scope_capture_index = Some(i as u32),
"local.definition" => local_definition_capture_index = Some(i as u32),
"local.reference" => continue,
_ => {
let mut is_definition = false;
let kind = if name.starts_with("definition.") {
is_definition = true;
name.trim_start_matches("definition.")
} else if name.starts_with("reference.") {
name.trim_start_matches("reference.")
} else {
return Err(Error::InvalidCapture(name.to_string()));
};
if let Ok(cstr) = CString::new(kind) {
let c_kind = cstr.to_bytes_with_nul().to_vec().into_boxed_slice();
let syntax_type_id = syntax_type_names
.iter()
.position(|n| n == &c_kind)
.unwrap_or_else(|| {
syntax_type_names.push(c_kind);
syntax_type_names.len() - 1
}) as u32;
capture_map.insert(
i as u32,
NamedCapture {
syntax_type_id,
is_definition,
},
);
}
}
}
}
let c_syntax_type_names = syntax_type_names.iter().map(|s| s.as_ptr()).collect();
let pattern_info = (0..query.pattern_count())
.map(|pattern_index| {
let mut info = PatternInfo::default();
@ -180,12 +217,10 @@ impl TagsConfiguration {
Ok(TagsConfiguration {
language,
query,
function_capture_index,
class_capture_index,
method_capture_index,
module_capture_index,
syntax_type_names,
c_syntax_type_names,
capture_map,
doc_capture_index,
call_capture_index,
name_capture_index,
tags_pattern_index,
local_scope_capture_index,
@ -193,6 +228,14 @@ impl TagsConfiguration {
pattern_info,
})
}
pub fn syntax_type_name(&self, id: u32) -> &str {
unsafe {
let cstr = CStr::from_ptr(self.syntax_type_names[id as usize].as_ptr() as *const i8)
.to_bytes();
str::from_utf8(cstr).expect("syntax type name was not valid utf-8")
}
}
}
impl TagsContext {
@ -230,6 +273,7 @@ impl TagsContext {
source,
config,
cancellation_flag,
prev_line_info: None,
tag_queue: Vec::new(),
iter_count: 0,
scopes: vec![LocalScope {
@ -300,10 +344,11 @@ where
continue;
}
let mut name_range = None;
let mut name_node = None;
let mut doc_nodes = Vec::new();
let mut tag_node = None;
let mut kind = TagKind::Call;
let mut syntax_type_id = 0;
let mut is_definition = false;
let mut docs_adjacent_node = None;
for capture in mat.captures {
@ -314,28 +359,21 @@ where
}
if index == self.config.name_capture_index {
name_range = Some(capture.node.byte_range());
name_node = Some(capture.node);
} else if index == self.config.doc_capture_index {
doc_nodes.push(capture.node);
} else if index == self.config.call_capture_index {
}
if let Some(named_capture) = self.config.capture_map.get(&capture.index) {
tag_node = Some(capture.node);
kind = TagKind::Call;
} else if index == self.config.class_capture_index {
tag_node = Some(capture.node);
kind = TagKind::Class;
} else if index == self.config.function_capture_index {
tag_node = Some(capture.node);
kind = TagKind::Function;
} else if index == self.config.method_capture_index {
tag_node = Some(capture.node);
kind = TagKind::Method;
} else if index == self.config.module_capture_index {
tag_node = Some(capture.node);
kind = TagKind::Module;
syntax_type_id = named_capture.syntax_type_id;
is_definition = named_capture.is_definition;
}
}
if let (Some(tag_node), Some(name_range)) = (tag_node, name_range) {
if let (Some(tag_node), Some(name_node)) = (tag_node, name_node) {
let name_range = name_node.byte_range();
if pattern_info.name_must_be_non_local {
let mut is_local = false;
for scope in self.scopes.iter().rev() {
@ -399,42 +437,73 @@ where
}
}
let range = tag_node.byte_range();
let span = name_node.start_position()..name_node.end_position();
// Compute tag properties that depend on the text of the containing line. If the
// previous tag occurred on the same line, then reuse results from the previous tag.
let line_range;
let mut prev_utf16_column = 0;
let mut prev_utf8_byte = name_range.start - span.start.column;
let line_info = self.prev_line_info.as_ref().and_then(|info| {
if info.utf8_position.row == span.start.row {
Some(info)
} else {
None
}
});
if let Some(line_info) = line_info {
line_range = line_info.line_range.clone();
if line_info.utf8_position.column <= span.start.column {
prev_utf8_byte = line_info.utf8_byte;
prev_utf16_column = line_info.utf16_column;
}
} else {
line_range = self::line_range(
self.source,
name_range.start,
span.start,
MAX_LINE_LEN,
);
}
let utf16_start_column = prev_utf16_column
+ utf16_len(&self.source[prev_utf8_byte..name_range.start]);
let utf16_end_column =
utf16_start_column + utf16_len(&self.source[name_range.clone()]);
let utf16_column_range = utf16_start_column..utf16_end_column;
self.prev_line_info = Some(LineInfo {
utf8_position: span.end,
utf8_byte: name_range.end,
utf16_column: utf16_end_column,
line_range: line_range.clone(),
});
let tag = Tag {
line_range,
span,
utf16_column_range,
range,
name_range,
docs,
is_definition,
syntax_type_id,
};
// Only create one tag per node. The tag queue is sorted by node position
// to allow for fast lookup.
let range = tag_node.byte_range();
match self
.tag_queue
.binary_search_by_key(&(name_range.end, name_range.start), |(tag, _)| {
(tag.name_range.end, tag.name_range.start)
}) {
match self.tag_queue.binary_search_by_key(
&(tag.name_range.end, tag.name_range.start),
|(tag, _)| (tag.name_range.end, tag.name_range.start),
) {
Ok(i) => {
let (tag, pattern_index) = &mut self.tag_queue[i];
let (existing_tag, pattern_index) = &mut self.tag_queue[i];
if *pattern_index > mat.pattern_index {
*pattern_index = mat.pattern_index;
*tag = Tag {
line_range: line_range(self.source, range.start, MAX_LINE_LEN),
span: tag_node.start_position()..tag_node.end_position(),
kind,
range,
name_range,
docs,
};
*existing_tag = tag;
}
}
Err(i) => self.tag_queue.insert(
i,
(
Tag {
line_range: line_range(self.source, range.start, MAX_LINE_LEN),
span: tag_node.start_position()..tag_node.end_position(),
kind,
range,
name_range,
docs,
},
mat.pattern_index,
),
),
Err(i) => self.tag_queue.insert(i, (tag, mat.pattern_index)),
}
}
}
@ -448,16 +517,12 @@ where
}
}
impl fmt::Display for TagKind {
impl fmt::Display for Error {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self {
TagKind::Call => "Call",
TagKind::Module => "Module",
TagKind::Class => "Class",
TagKind::Method => "Method",
TagKind::Function => "Function",
Error::InvalidCapture(name) => write!(f, "Invalid capture @{}. Expected one of: @definition.*, @reference.*, @doc, @name, @local.(scope|definition|reference).", name),
_ => write!(f, "{:?}", self)
}
.fmt(f)
}
}
@ -473,11 +538,90 @@ impl From<QueryError> for Error {
}
}
fn line_range(text: &[u8], index: usize, max_line_len: usize) -> Range<usize> {
let start = memrchr(b'\n', &text[0..index]).map_or(0, |i| i + 1);
let max_line_len = max_line_len.min(text.len() - start);
let end = start + memchr(b'\n', &text[start..(start + max_line_len)]).unwrap_or(max_line_len);
start..end
// TODO: Remove this struct at at some point. If `core::str::lossy::Utf8Lossy`
// is ever stabilized, we should use that. Otherwise, this struct could be moved
// into some module that's shared between `tree-sitter-tags` and `tree-sitter-highlight`.
impl<'a> LossyUtf8<'a> {
fn new(bytes: &'a [u8]) -> Self {
LossyUtf8 {
bytes,
in_replacement: false,
}
}
}
impl<'a> Iterator for LossyUtf8<'a> {
type Item = &'a str;
fn next(&mut self) -> Option<&'a str> {
if self.bytes.is_empty() {
return None;
}
if self.in_replacement {
self.in_replacement = false;
return Some("\u{fffd}");
}
match str::from_utf8(self.bytes) {
Ok(valid) => {
self.bytes = &[];
Some(valid)
}
Err(error) => {
if let Some(error_len) = error.error_len() {
let error_start = error.valid_up_to();
if error_start > 0 {
let result =
unsafe { str::from_utf8_unchecked(&self.bytes[..error_start]) };
self.bytes = &self.bytes[(error_start + error_len)..];
self.in_replacement = true;
Some(result)
} else {
self.bytes = &self.bytes[error_len..];
Some("\u{fffd}")
}
} else {
None
}
}
}
}
}
fn line_range(
text: &[u8],
start_byte: usize,
start_point: Point,
max_line_len: usize,
) -> Range<usize> {
// Trim leading whitespace
let mut line_start_byte = start_byte - start_point.column;
while line_start_byte < text.len() && text[line_start_byte].is_ascii_whitespace() {
line_start_byte += 1;
}
let max_line_len = max_line_len.min(text.len() - line_start_byte);
let text_after_line_start = &text[line_start_byte..(line_start_byte + max_line_len)];
let line_len = if let Some(len) = memchr(b'\n', text_after_line_start) {
len
} else if let Err(e) = str::from_utf8(text_after_line_start) {
e.valid_up_to()
} else {
max_line_len
};
// Trim trailing whitespace
let mut line_end_byte = line_start_byte + line_len;
while line_end_byte > line_start_byte && text[line_end_byte - 1].is_ascii_whitespace() {
line_end_byte -= 1;
}
line_start_byte..line_end_byte
}
fn utf16_len(bytes: &[u8]) -> usize {
LossyUtf8::new(bytes)
.flat_map(|chunk| chunk.chars().map(char::len_utf16))
.sum()
}
#[cfg(test)]
@ -486,14 +630,27 @@ mod tests {
#[test]
fn test_get_line() {
let text = b"abc\ndefg\nhijkl";
assert_eq!(line_range(text, 0, 10), 0..3);
assert_eq!(line_range(text, 1, 10), 0..3);
assert_eq!(line_range(text, 2, 10), 0..3);
assert_eq!(line_range(text, 3, 10), 0..3);
assert_eq!(line_range(text, 1, 2), 0..2);
assert_eq!(line_range(text, 4, 10), 4..8);
assert_eq!(line_range(text, 5, 10), 4..8);
assert_eq!(line_range(text, 11, 10), 9..14);
let text = "abc\ndefg❤hij\nklmno".as_bytes();
assert_eq!(line_range(text, 5, Point::new(1, 1), 30), 4..14);
assert_eq!(line_range(text, 5, Point::new(1, 1), 6), 4..8);
assert_eq!(line_range(text, 17, Point::new(2, 2), 30), 15..20);
assert_eq!(line_range(text, 17, Point::new(2, 2), 4), 15..19);
}
#[test]
fn test_get_line_trims() {
let text = b" foo\nbar\n";
assert_eq!(line_range(text, 0, Point::new(0, 0), 10), 3..6);
let text = b"\t func foo \nbar\n";
assert_eq!(line_range(text, 0, Point::new(0, 0), 10), 2..10);
let r = line_range(text, 0, Point::new(0, 0), 14);
assert_eq!(r, 2..10);
assert_eq!(str::from_utf8(&text[r]).unwrap_or(""), "func foo");
let r = line_range(text, 12, Point::new(1, 0), 14);
assert_eq!(r, 12..15);
assert_eq!(str::from_utf8(&text[r]).unwrap_or(""), "bar");
}
}