diff --git a/cli/src/error.rs b/cli/src/error.rs index 824bd92f..d583d1b9 100644 --- a/cli/src/error.rs +++ b/cli/src/error.rs @@ -83,7 +83,7 @@ impl<'a> From for Error { impl<'a> From for Error { fn from(error: tree_sitter_tags::Error) -> Self { - Error::new(format!("{:?}", error)) + Error::new(format!("{}", error)) } } diff --git a/cli/src/main.rs b/cli/src/main.rs index 757c70eb..713bf28f 100644 --- a/cli/src/main.rs +++ b/cli/src/main.rs @@ -90,13 +90,8 @@ fn run() -> error::Result<()> { ) .subcommand( SubCommand::with_name("tags") - .arg( - Arg::with_name("format") - .short("f") - .long("format") - .value_name("json|protobuf") - .help("Determine output format (default: json)"), - ) + .arg(Arg::with_name("quiet").long("quiet").short("q")) + .arg(Arg::with_name("time").long("quiet").short("t")) .arg(Arg::with_name("scope").long("scope").takes_value(true)) .arg( Arg::with_name("inputs") @@ -104,12 +99,6 @@ fn run() -> error::Result<()> { .index(1) .required(true) .multiple(true), - ) - .arg( - Arg::with_name("v") - .short("v") - .multiple(true) - .help("Sets the level of verbosity"), ), ) .subcommand( @@ -149,8 +138,14 @@ fn run() -> error::Result<()> { .arg(Arg::with_name("path").index(1).multiple(true)), ) .subcommand( - SubCommand::with_name("web-ui").about("Test a parser interactively in the browser") - .arg(Arg::with_name("quiet").long("quiet").short("q").help("open in default browser")), + SubCommand::with_name("web-ui") + .about("Test a parser interactively in the browser") + .arg( + Arg::with_name("quiet") + .long("quiet") + .short("q") + .help("open in default browser"), + ), ) .subcommand( SubCommand::with_name("dump-languages") @@ -268,7 +263,13 @@ fn run() -> error::Result<()> { } else if let Some(matches) = matches.subcommand_matches("tags") { loader.find_all_languages(&config.parser_directories)?; let paths = collect_paths(matches.values_of("inputs").unwrap())?; - tags::generate_tags(&loader, matches.value_of("scope"), &paths)?; + tags::generate_tags( + &loader, + matches.value_of("scope"), + &paths, + matches.is_present("quiet"), + matches.is_present("time"), + )?; } else if let Some(matches) = matches.subcommand_matches("highlight") { loader.configure_highlights(&config.theme.highlight_names); loader.find_all_languages(&config.parser_directories)?; diff --git a/cli/src/tags.rs b/cli/src/tags.rs index d6704ec5..5ea00f39 100644 --- a/cli/src/tags.rs +++ b/cli/src/tags.rs @@ -3,10 +3,17 @@ use super::util; use crate::error::{Error, Result}; use std::io::{self, Write}; use std::path::Path; +use std::time::Instant; use std::{fs, str}; use tree_sitter_tags::TagsContext; -pub fn generate_tags(loader: &Loader, scope: Option<&str>, paths: &[String]) -> Result<()> { +pub fn generate_tags( + loader: &Loader, + scope: Option<&str>, + paths: &[String], + quiet: bool, + time: bool, +) -> Result<()> { let mut lang = None; if let Some(scope) = scope { lang = loader.language_configuration_for_scope(scope)?; @@ -34,28 +41,50 @@ pub fn generate_tags(loader: &Loader, scope: Option<&str>, paths: &[String]) -> }; if let Some(tags_config) = language_config.tags_config(language)? { - let path_str = format!("{:?}", path); - writeln!(&mut stdout, "{}", &path_str[1..path_str.len() - 1])?; + let indent; + if paths.len() > 1 { + if !quiet { + writeln!(&mut stdout, "{}", path.to_string_lossy())?; + } + indent = "\t" + } else { + indent = ""; + }; let source = fs::read(path)?; + let t0 = Instant::now(); for tag in context.generate_tags(tags_config, &source, Some(&cancellation_flag))? { let tag = tag?; - write!( - &mut stdout, - " {:<8} {:<40}\t{:>9}-{:<9}", - tag.kind, - str::from_utf8(&source[tag.name_range]).unwrap_or(""), - tag.span.start, - tag.span.end, - )?; - if let Some(docs) = tag.docs { - if docs.len() > 120 { - write!(&mut stdout, "\t{:?}...", &docs[0..120])?; - } else { - write!(&mut stdout, "\t{:?}", &docs)?; + if !quiet { + write!( + &mut stdout, + "{}{:<10}\t | {:<8}\t{} {} - {} `{}`", + indent, + str::from_utf8(&source[tag.name_range]).unwrap_or(""), + &tags_config.syntax_type_name(tag.syntax_type_id), + if tag.is_definition { "def" } else { "ref" }, + tag.span.start, + tag.span.end, + str::from_utf8(&source[tag.line_range]).unwrap_or(""), + )?; + if let Some(docs) = tag.docs { + if docs.len() > 120 { + write!(&mut stdout, "\t{:?}...", &docs[0..120])?; + } else { + write!(&mut stdout, "\t{:?}", &docs)?; + } } + writeln!(&mut stdout, "")?; } - writeln!(&mut stdout, "")?; + } + + if time { + writeln!( + &mut stdout, + "{}time: {}ms", + indent, + t0.elapsed().as_millis(), + )?; } } else { eprintln!("No tags config found for path {:?}", path); diff --git a/cli/src/tests/tags_test.rs b/cli/src/tests/tags_test.rs index fad8ebd8..f00e83ac 100644 --- a/cli/src/tests/tags_test.rs +++ b/cli/src/tests/tags_test.rs @@ -1,73 +1,79 @@ use super::helpers::allocations; use super::helpers::fixtures::{get_language, get_language_queries_path}; +use std::ffi::CStr; use std::ffi::CString; use std::{fs, ptr, slice, str}; +use tree_sitter::Point; use tree_sitter_tags::c_lib as c; -use tree_sitter_tags::{Error, TagKind, TagsConfiguration, TagsContext}; +use tree_sitter_tags::{Error, TagsConfiguration, TagsContext}; const PYTHON_TAG_QUERY: &'static str = r#" ( - (function_definition - name: (identifier) @name - body: (block . (expression_statement (string) @doc))) @function - (#strip! @doc "(^['\"\\s]*)|(['\"\\s]*$)") + (function_definition + name: (identifier) @name + body: (block . (expression_statement (string) @doc))) @definition.function + (#strip! @doc "(^['\"\\s]*)|(['\"\\s]*$)") ) (function_definition - name: (identifier) @name) @function + name: (identifier) @name) @definition.function ( - (class_definition - name: (identifier) @name - body: (block - . (expression_statement (string) @doc))) @class - (#strip! @doc "(^['\"\\s]*)|(['\"\\s]*$)") + (class_definition + name: (identifier) @name + body: (block + . (expression_statement (string) @doc))) @definition.class + (#strip! @doc "(^['\"\\s]*)|(['\"\\s]*$)") ) (class_definition - name: (identifier) @name) @class + name: (identifier) @name) @definition.class (call - function: (identifier) @name) @call + function: (identifier) @name) @reference.call + +(call + function: (attribute + attribute: (identifier) @name)) @reference.call "#; const JS_TAG_QUERY: &'static str = r#" ( (comment)* @doc . (class_declaration - name: (identifier) @name) @class - (#select-adjacent! @doc @class) + name: (identifier) @name) @definition.class + (#select-adjacent! @doc @definition.class) (#strip! @doc "(^[/\\*\\s]*)|([/\\*\\s]*$)") ) ( (comment)* @doc . (method_definition - name: (property_identifier) @name) @method - (#select-adjacent! @doc @method) + name: (property_identifier) @name) @definition.method + (#select-adjacent! @doc @definition.method) (#strip! @doc "(^[/\\*\\s]*)|([/\\*\\s]*$)") ) ( (comment)* @doc . (function_declaration - name: (identifier) @name) @function - (#select-adjacent! @doc @function) + name: (identifier) @name) @definition.function + (#select-adjacent! @doc @definition.function) (#strip! @doc "(^[/\\*\\s]*)|([/\\*\\s]*$)") ) (call_expression - function: (identifier) @name) @call + function: (identifier) @name) @reference.call "#; const RUBY_TAG_QUERY: &'static str = r#" (method - name: (identifier) @name) @method + name: (identifier) @name) @definition.method (method_call - method: (identifier) @name) @call + method: (identifier) @name) @reference.call -((identifier) @name @call +((identifier) @name @reference.call (#is-not? local)) "#; @@ -99,20 +105,20 @@ fn test_tags_python() { assert_eq!( tags.iter() - .map(|t| (substr(source, &t.name_range), t.kind)) + .map(|t| ( + substr(source, &t.name_range), + tags_config.syntax_type_name(t.syntax_type_id) + )) .collect::>(), &[ - ("Customer", TagKind::Class), - ("age", TagKind::Function), - ("compute_age", TagKind::Call), + ("Customer", "class"), + ("age", "function"), + ("compute_age", "call"), ] ); - assert_eq!(substr(source, &tags[0].line_range), " class Customer:"); - assert_eq!( - substr(source, &tags[1].line_range), - " def age(self):" - ); + assert_eq!(substr(source, &tags[0].line_range), "class Customer:"); + assert_eq!(substr(source, &tags[1].line_range), "def age(self):"); assert_eq!(tags[0].docs.as_ref().unwrap(), "Data about a customer"); assert_eq!(tags[1].docs.as_ref().unwrap(), "Get the customer's age"); } @@ -150,12 +156,16 @@ fn test_tags_javascript() { assert_eq!( tags.iter() - .map(|t| (substr(source, &t.name_range), t.kind)) + .map(|t| ( + substr(source, &t.name_range), + t.span.clone(), + tags_config.syntax_type_name(t.syntax_type_id) + )) .collect::>(), &[ - ("Customer", TagKind::Class), - ("getAge", TagKind::Method), - ("Agent", TagKind::Class) + ("Customer", Point::new(5, 10)..Point::new(5, 18), "class",), + ("getAge", Point::new(9, 8)..Point::new(9, 14), "method",), + ("Agent", Point::new(15, 10)..Point::new(15, 15), "class",) ] ); assert_eq!( @@ -166,6 +176,26 @@ fn test_tags_javascript() { assert_eq!(tags[2].docs, None); } +#[test] +fn test_tags_columns_measured_in_utf16_code_units() { + let language = get_language("python"); + let tags_config = TagsConfiguration::new(language, PYTHON_TAG_QUERY, "").unwrap(); + let mut tag_context = TagsContext::new(); + + let source = r#""❤️❤️❤️".hello_α_ω()"#.as_bytes(); + + let tag = tag_context + .generate_tags(&tags_config, source, None) + .unwrap() + .next() + .unwrap() + .unwrap(); + + assert_eq!(substr(source, &tag.name_range), "hello_α_ω"); + assert_eq!(tag.span, Point::new(0, 21)..Point::new(0, 32)); + assert_eq!(tag.utf16_column_range, 9..18); +} + #[test] fn test_tags_ruby() { let language = get_language("ruby"); @@ -204,18 +234,18 @@ fn test_tags_ruby() { tags.iter() .map(|t| ( substr(source.as_bytes(), &t.name_range), - t.kind, + tags_config.syntax_type_name(t.syntax_type_id), (t.span.start.row, t.span.start.column), )) .collect::>(), &[ - ("foo", TagKind::Method, (2, 0)), - ("bar", TagKind::Call, (7, 4)), - ("a", TagKind::Call, (7, 8)), - ("b", TagKind::Call, (7, 11)), - ("each", TagKind::Call, (9, 14)), - ("baz", TagKind::Call, (13, 8)), - ("b", TagKind::Call, (13, 15),), + ("foo", "method", (2, 4)), + ("bar", "call", (7, 4)), + ("a", "call", (7, 8)), + ("b", "call", (7, 11)), + ("each", "call", (9, 14)), + ("baz", "call", (13, 8)), + ("b", "call", (13, 15),), ] ); } @@ -253,6 +283,14 @@ fn test_tags_cancellation() { }); } +#[test] +fn test_invalid_capture() { + let language = get_language("python"); + let e = TagsConfiguration::new(language, "(identifier) @method", "") + .expect_err("expected InvalidCapture error"); + assert_eq!(e, Error::InvalidCapture("method".to_string())); +} + #[test] fn test_tags_via_c_api() { allocations::record(|| { @@ -316,29 +354,29 @@ fn test_tags_via_c_api() { }) .unwrap(); + let syntax_types: Vec<&str> = unsafe { + let mut len: u32 = 0; + let ptr = + c::ts_tagger_syntax_kinds_for_scope_name(tagger, c_scope_name.as_ptr(), &mut len); + slice::from_raw_parts(ptr, len as usize) + .iter() + .map(|i| CStr::from_ptr(*i).to_str().unwrap()) + .collect() + }; + assert_eq!( tags.iter() .map(|tag| ( - tag.kind, + syntax_types[tag.syntax_type_id as usize], &source_code[tag.name_start_byte as usize..tag.name_end_byte as usize], &source_code[tag.line_start_byte as usize..tag.line_end_byte as usize], &docs[tag.docs_start_byte as usize..tag.docs_end_byte as usize], )) .collect::>(), &[ - ( - c::TSTagKind::Function, - "b", - "function b() {", - "one\ntwo\nthree" - ), - ( - c::TSTagKind::Class, - "C", - "class C extends D {", - "four\nfive" - ), - (c::TSTagKind::Call, "b", "b(a);", "") + ("function", "b", "function b() {", "one\ntwo\nthree"), + ("class", "C", "class C extends D {", "four\nfive"), + ("call", "b", "b(a);", "") ] ); diff --git a/tags/include/tree_sitter/tags.h b/tags/include/tree_sitter/tags.h index 946dc6f1..f2b17075 100644 --- a/tags/include/tree_sitter/tags.h +++ b/tags/include/tree_sitter/tags.h @@ -16,18 +16,10 @@ typedef enum { TSTagsInvalidUtf8, TSTagsInvalidRegex, TSTagsInvalidQuery, + TSTagsInvalidCapture, } TSTagsError; -typedef enum { - TSTagKindFunction, - TSTagKindMethod, - TSTagKindClass, - TSTagKindModule, - TSTagKindCall, -} TSTagKind; - typedef struct { - TSTagKind kind; uint32_t start_byte; uint32_t end_byte; uint32_t name_start_byte; @@ -36,8 +28,12 @@ typedef struct { uint32_t line_end_byte; TSPoint start_point; TSPoint end_point; + uint32_t utf16_start_column; + uint32_t utf16_end_column; uint32_t docs_start_byte; uint32_t docs_end_byte; + uint32_t syntax_type_id; + bool is_definition; } TSTag; typedef struct TSTagger TSTagger; @@ -89,6 +85,9 @@ uint32_t ts_tags_buffer_tags_len(const TSTagsBuffer *); const char *ts_tags_buffer_docs(const TSTagsBuffer *); uint32_t ts_tags_buffer_docs_len(const TSTagsBuffer *); +// Get the syntax kinds for a scope. +const char **ts_tagger_syntax_kinds_for_scope_name(const TSTagger *, const char *scope_name, uint32_t *len); + #ifdef __cplusplus } #endif diff --git a/tags/src/c_lib.rs b/tags/src/c_lib.rs index 0c367977..07e1e19a 100644 --- a/tags/src/c_lib.rs +++ b/tags/src/c_lib.rs @@ -1,4 +1,4 @@ -use super::{Error, TagKind, TagsConfiguration, TagsContext}; +use super::{Error, TagsConfiguration, TagsContext}; use std::collections::HashMap; use std::ffi::CStr; use std::process::abort; @@ -16,19 +16,10 @@ pub enum TSTagsError { InvalidUtf8, InvalidRegex, InvalidQuery, + InvalidCapture, Unknown, } -#[repr(C)] -#[derive(Clone, Copy, Debug, PartialEq, Eq)] -pub enum TSTagKind { - Function, - Method, - Class, - Module, - Call, -} - #[repr(C)] pub struct TSPoint { row: u32, @@ -37,7 +28,6 @@ pub struct TSPoint { #[repr(C)] pub struct TSTag { - pub kind: TSTagKind, pub start_byte: u32, pub end_byte: u32, pub name_start_byte: u32, @@ -46,8 +36,12 @@ pub struct TSTag { pub line_end_byte: u32, pub start_point: TSPoint, pub end_point: TSPoint, + pub utf16_start_colum: u32, + pub utf16_end_colum: u32, pub docs_start_byte: u32, pub docs_end_byte: u32, + pub syntax_type_id: u32, + pub is_definition: bool, } pub struct TSTagger { @@ -102,7 +96,9 @@ pub extern "C" fn ts_tagger_add_language( } Err(Error::Query(_)) => TSTagsError::InvalidQuery, Err(Error::Regex(_)) => TSTagsError::InvalidRegex, - Err(_) => TSTagsError::Unknown, + Err(Error::Cancelled) => TSTagsError::Timeout, + Err(Error::InvalidLanguage) => TSTagsError::InvalidLanguage, + Err(Error::InvalidCapture(_)) => TSTagsError::InvalidCapture, } } @@ -153,13 +149,6 @@ pub extern "C" fn ts_tagger_tag( buffer.docs.extend_from_slice(docs.as_bytes()); } buffer.tags.push(TSTag { - kind: match tag.kind { - TagKind::Function => TSTagKind::Function, - TagKind::Method => TSTagKind::Method, - TagKind::Class => TSTagKind::Class, - TagKind::Module => TSTagKind::Module, - TagKind::Call => TSTagKind::Call, - }, start_byte: tag.range.start as u32, end_byte: tag.range.end as u32, name_start_byte: tag.name_range.start as u32, @@ -174,8 +163,12 @@ pub extern "C" fn ts_tagger_tag( row: tag.span.end.row as u32, column: tag.span.end.column as u32, }, + utf16_start_colum: tag.utf16_column_range.start as u32, + utf16_end_colum: tag.utf16_column_range.end as u32, docs_start_byte: prev_docs_len as u32, docs_end_byte: buffer.docs.len() as u32, + syntax_type_id: tag.syntax_type_id, + is_definition: tag.is_definition, }); } @@ -223,6 +216,24 @@ pub extern "C" fn ts_tags_buffer_docs_len(this: *const TSTagsBuffer) -> u32 { buffer.docs.len() as u32 } +#[no_mangle] +pub extern "C" fn ts_tagger_syntax_kinds_for_scope_name( + this: *mut TSTagger, + scope_name: *const i8, + len: *mut u32, +) -> *const *const i8 { + let tagger = unwrap_mut_ptr(this); + let scope_name = unsafe { unwrap(CStr::from_ptr(scope_name).to_str()) }; + let len = unwrap_mut_ptr(len); + + *len = 0; + if let Some(config) = tagger.languages.get(scope_name) { + *len = config.c_syntax_type_names.len() as u32; + return config.c_syntax_type_names.as_ptr() as *const *const i8; + } + std::ptr::null() +} + fn unwrap_ptr<'a, T>(result: *const T) -> &'a T { unsafe { result.as_ref() }.unwrap_or_else(|| { eprintln!("{}:{} - pointer must not be null", file!(), line!()); diff --git a/tags/src/lib.rs b/tags/src/lib.rs index 8d1853bb..dcbb9984 100644 --- a/tags/src/lib.rs +++ b/tags/src/lib.rs @@ -1,10 +1,12 @@ pub mod c_lib; -use memchr::{memchr, memrchr}; +use memchr::memchr; use regex::Regex; +use std::collections::HashMap; +use std::ffi::{CStr, CString}; use std::ops::Range; use std::sync::atomic::{AtomicUsize, Ordering}; -use std::{fmt, mem, str}; +use std::{char, fmt, mem, str}; use tree_sitter::{ Language, Parser, Point, Query, QueryCursor, QueryError, QueryPredicateArg, Tree, }; @@ -18,12 +20,10 @@ const CANCELLATION_CHECK_INTERVAL: usize = 100; pub struct TagsConfiguration { pub language: Language, pub query: Query, - call_capture_index: Option, - class_capture_index: Option, + syntax_type_names: Vec>, + c_syntax_type_names: Vec<*const u8>, + capture_map: HashMap, doc_capture_index: Option, - function_capture_index: Option, - method_capture_index: Option, - module_capture_index: Option, name_capture_index: Option, local_scope_capture_index: Option, local_definition_capture_index: Option, @@ -31,6 +31,12 @@ pub struct TagsConfiguration { pattern_info: Vec, } +#[derive(Debug)] +pub struct NamedCapture { + pub syntax_type_id: u32, + pub is_definition: bool, +} + pub struct TagsContext { parser: Parser, cursor: QueryCursor, @@ -38,21 +44,14 @@ pub struct TagsContext { #[derive(Debug, Clone)] pub struct Tag { - pub kind: TagKind, pub range: Range, pub name_range: Range, pub line_range: Range, pub span: Range, + pub utf16_column_range: Range, pub docs: Option, -} - -#[derive(Copy, Clone, Debug, PartialEq, Eq)] -pub enum TagKind { - Function, - Method, - Class, - Module, - Call, + pub is_definition: bool, + pub syntax_type_id: u32, } #[derive(Debug, PartialEq)] @@ -61,6 +60,7 @@ pub enum Error { Regex(regex::Error), Cancelled, InvalidLanguage, + InvalidCapture(String), } #[derive(Debug, Default)] @@ -91,6 +91,7 @@ where matches: I, _tree: Tree, source: &'a [u8], + prev_line_info: Option, config: &'a TagsConfiguration, cancellation_flag: Option<&'a AtomicUsize>, iter_count: usize, @@ -98,6 +99,18 @@ where scopes: Vec>, } +struct LineInfo { + utf8_position: Point, + utf8_byte: usize, + utf16_column: usize, + line_range: Range, +} + +struct LossyUtf8<'a> { + bytes: &'a [u8], + in_replacement: bool, +} + impl TagsConfiguration { pub fn new(language: Language, tags_query: &str, locals_query: &str) -> Result { let query = Query::new(language, &format!("{}{}", locals_query, tags_query))?; @@ -111,31 +124,55 @@ impl TagsConfiguration { } } - let mut call_capture_index = None; - let mut class_capture_index = None; + let mut capture_map = HashMap::new(); + let mut syntax_type_names = Vec::new(); let mut doc_capture_index = None; - let mut function_capture_index = None; - let mut method_capture_index = None; - let mut module_capture_index = None; let mut name_capture_index = None; let mut local_scope_capture_index = None; let mut local_definition_capture_index = None; for (i, name) in query.capture_names().iter().enumerate() { - let index = match name.as_str() { - "call" => &mut call_capture_index, - "class" => &mut class_capture_index, - "doc" => &mut doc_capture_index, - "function" => &mut function_capture_index, - "method" => &mut method_capture_index, - "module" => &mut module_capture_index, - "name" => &mut name_capture_index, - "local.scope" => &mut local_scope_capture_index, - "local.definition" => &mut local_definition_capture_index, - _ => continue, - }; - *index = Some(i as u32); + match name.as_str() { + "" => continue, + "name" => name_capture_index = Some(i as u32), + "doc" => doc_capture_index = Some(i as u32), + "local.scope" => local_scope_capture_index = Some(i as u32), + "local.definition" => local_definition_capture_index = Some(i as u32), + "local.reference" => continue, + _ => { + let mut is_definition = false; + + let kind = if name.starts_with("definition.") { + is_definition = true; + name.trim_start_matches("definition.") + } else if name.starts_with("reference.") { + name.trim_start_matches("reference.") + } else { + return Err(Error::InvalidCapture(name.to_string())); + }; + + if let Ok(cstr) = CString::new(kind) { + let c_kind = cstr.to_bytes_with_nul().to_vec().into_boxed_slice(); + let syntax_type_id = syntax_type_names + .iter() + .position(|n| n == &c_kind) + .unwrap_or_else(|| { + syntax_type_names.push(c_kind); + syntax_type_names.len() - 1 + }) as u32; + capture_map.insert( + i as u32, + NamedCapture { + syntax_type_id, + is_definition, + }, + ); + } + } + } } + let c_syntax_type_names = syntax_type_names.iter().map(|s| s.as_ptr()).collect(); + let pattern_info = (0..query.pattern_count()) .map(|pattern_index| { let mut info = PatternInfo::default(); @@ -180,12 +217,10 @@ impl TagsConfiguration { Ok(TagsConfiguration { language, query, - function_capture_index, - class_capture_index, - method_capture_index, - module_capture_index, + syntax_type_names, + c_syntax_type_names, + capture_map, doc_capture_index, - call_capture_index, name_capture_index, tags_pattern_index, local_scope_capture_index, @@ -193,6 +228,14 @@ impl TagsConfiguration { pattern_info, }) } + + pub fn syntax_type_name(&self, id: u32) -> &str { + unsafe { + let cstr = CStr::from_ptr(self.syntax_type_names[id as usize].as_ptr() as *const i8) + .to_bytes(); + str::from_utf8(cstr).expect("syntax type name was not valid utf-8") + } + } } impl TagsContext { @@ -230,6 +273,7 @@ impl TagsContext { source, config, cancellation_flag, + prev_line_info: None, tag_queue: Vec::new(), iter_count: 0, scopes: vec![LocalScope { @@ -300,10 +344,11 @@ where continue; } - let mut name_range = None; + let mut name_node = None; let mut doc_nodes = Vec::new(); let mut tag_node = None; - let mut kind = TagKind::Call; + let mut syntax_type_id = 0; + let mut is_definition = false; let mut docs_adjacent_node = None; for capture in mat.captures { @@ -314,28 +359,21 @@ where } if index == self.config.name_capture_index { - name_range = Some(capture.node.byte_range()); + name_node = Some(capture.node); } else if index == self.config.doc_capture_index { doc_nodes.push(capture.node); - } else if index == self.config.call_capture_index { + } + + if let Some(named_capture) = self.config.capture_map.get(&capture.index) { tag_node = Some(capture.node); - kind = TagKind::Call; - } else if index == self.config.class_capture_index { - tag_node = Some(capture.node); - kind = TagKind::Class; - } else if index == self.config.function_capture_index { - tag_node = Some(capture.node); - kind = TagKind::Function; - } else if index == self.config.method_capture_index { - tag_node = Some(capture.node); - kind = TagKind::Method; - } else if index == self.config.module_capture_index { - tag_node = Some(capture.node); - kind = TagKind::Module; + syntax_type_id = named_capture.syntax_type_id; + is_definition = named_capture.is_definition; } } - if let (Some(tag_node), Some(name_range)) = (tag_node, name_range) { + if let (Some(tag_node), Some(name_node)) = (tag_node, name_node) { + let name_range = name_node.byte_range(); + if pattern_info.name_must_be_non_local { let mut is_local = false; for scope in self.scopes.iter().rev() { @@ -399,42 +437,73 @@ where } } + let range = tag_node.byte_range(); + let span = name_node.start_position()..name_node.end_position(); + + // Compute tag properties that depend on the text of the containing line. If the + // previous tag occurred on the same line, then reuse results from the previous tag. + let line_range; + let mut prev_utf16_column = 0; + let mut prev_utf8_byte = name_range.start - span.start.column; + let line_info = self.prev_line_info.as_ref().and_then(|info| { + if info.utf8_position.row == span.start.row { + Some(info) + } else { + None + } + }); + if let Some(line_info) = line_info { + line_range = line_info.line_range.clone(); + if line_info.utf8_position.column <= span.start.column { + prev_utf8_byte = line_info.utf8_byte; + prev_utf16_column = line_info.utf16_column; + } + } else { + line_range = self::line_range( + self.source, + name_range.start, + span.start, + MAX_LINE_LEN, + ); + } + + let utf16_start_column = prev_utf16_column + + utf16_len(&self.source[prev_utf8_byte..name_range.start]); + let utf16_end_column = + utf16_start_column + utf16_len(&self.source[name_range.clone()]); + let utf16_column_range = utf16_start_column..utf16_end_column; + + self.prev_line_info = Some(LineInfo { + utf8_position: span.end, + utf8_byte: name_range.end, + utf16_column: utf16_end_column, + line_range: line_range.clone(), + }); + let tag = Tag { + line_range, + span, + utf16_column_range, + range, + name_range, + docs, + is_definition, + syntax_type_id, + }; + // Only create one tag per node. The tag queue is sorted by node position // to allow for fast lookup. - let range = tag_node.byte_range(); - match self - .tag_queue - .binary_search_by_key(&(name_range.end, name_range.start), |(tag, _)| { - (tag.name_range.end, tag.name_range.start) - }) { + match self.tag_queue.binary_search_by_key( + &(tag.name_range.end, tag.name_range.start), + |(tag, _)| (tag.name_range.end, tag.name_range.start), + ) { Ok(i) => { - let (tag, pattern_index) = &mut self.tag_queue[i]; + let (existing_tag, pattern_index) = &mut self.tag_queue[i]; if *pattern_index > mat.pattern_index { *pattern_index = mat.pattern_index; - *tag = Tag { - line_range: line_range(self.source, range.start, MAX_LINE_LEN), - span: tag_node.start_position()..tag_node.end_position(), - kind, - range, - name_range, - docs, - }; + *existing_tag = tag; } } - Err(i) => self.tag_queue.insert( - i, - ( - Tag { - line_range: line_range(self.source, range.start, MAX_LINE_LEN), - span: tag_node.start_position()..tag_node.end_position(), - kind, - range, - name_range, - docs, - }, - mat.pattern_index, - ), - ), + Err(i) => self.tag_queue.insert(i, (tag, mat.pattern_index)), } } } @@ -448,16 +517,12 @@ where } } -impl fmt::Display for TagKind { +impl fmt::Display for Error { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { - TagKind::Call => "Call", - TagKind::Module => "Module", - TagKind::Class => "Class", - TagKind::Method => "Method", - TagKind::Function => "Function", + Error::InvalidCapture(name) => write!(f, "Invalid capture @{}. Expected one of: @definition.*, @reference.*, @doc, @name, @local.(scope|definition|reference).", name), + _ => write!(f, "{:?}", self) } - .fmt(f) } } @@ -473,11 +538,90 @@ impl From for Error { } } -fn line_range(text: &[u8], index: usize, max_line_len: usize) -> Range { - let start = memrchr(b'\n', &text[0..index]).map_or(0, |i| i + 1); - let max_line_len = max_line_len.min(text.len() - start); - let end = start + memchr(b'\n', &text[start..(start + max_line_len)]).unwrap_or(max_line_len); - start..end +// TODO: Remove this struct at at some point. If `core::str::lossy::Utf8Lossy` +// is ever stabilized, we should use that. Otherwise, this struct could be moved +// into some module that's shared between `tree-sitter-tags` and `tree-sitter-highlight`. +impl<'a> LossyUtf8<'a> { + fn new(bytes: &'a [u8]) -> Self { + LossyUtf8 { + bytes, + in_replacement: false, + } + } +} + +impl<'a> Iterator for LossyUtf8<'a> { + type Item = &'a str; + + fn next(&mut self) -> Option<&'a str> { + if self.bytes.is_empty() { + return None; + } + if self.in_replacement { + self.in_replacement = false; + return Some("\u{fffd}"); + } + match str::from_utf8(self.bytes) { + Ok(valid) => { + self.bytes = &[]; + Some(valid) + } + Err(error) => { + if let Some(error_len) = error.error_len() { + let error_start = error.valid_up_to(); + if error_start > 0 { + let result = + unsafe { str::from_utf8_unchecked(&self.bytes[..error_start]) }; + self.bytes = &self.bytes[(error_start + error_len)..]; + self.in_replacement = true; + Some(result) + } else { + self.bytes = &self.bytes[error_len..]; + Some("\u{fffd}") + } + } else { + None + } + } + } + } +} + +fn line_range( + text: &[u8], + start_byte: usize, + start_point: Point, + max_line_len: usize, +) -> Range { + // Trim leading whitespace + let mut line_start_byte = start_byte - start_point.column; + while line_start_byte < text.len() && text[line_start_byte].is_ascii_whitespace() { + line_start_byte += 1; + } + + let max_line_len = max_line_len.min(text.len() - line_start_byte); + let text_after_line_start = &text[line_start_byte..(line_start_byte + max_line_len)]; + let line_len = if let Some(len) = memchr(b'\n', text_after_line_start) { + len + } else if let Err(e) = str::from_utf8(text_after_line_start) { + e.valid_up_to() + } else { + max_line_len + }; + + // Trim trailing whitespace + let mut line_end_byte = line_start_byte + line_len; + while line_end_byte > line_start_byte && text[line_end_byte - 1].is_ascii_whitespace() { + line_end_byte -= 1; + } + + line_start_byte..line_end_byte +} + +fn utf16_len(bytes: &[u8]) -> usize { + LossyUtf8::new(bytes) + .flat_map(|chunk| chunk.chars().map(char::len_utf16)) + .sum() } #[cfg(test)] @@ -486,14 +630,27 @@ mod tests { #[test] fn test_get_line() { - let text = b"abc\ndefg\nhijkl"; - assert_eq!(line_range(text, 0, 10), 0..3); - assert_eq!(line_range(text, 1, 10), 0..3); - assert_eq!(line_range(text, 2, 10), 0..3); - assert_eq!(line_range(text, 3, 10), 0..3); - assert_eq!(line_range(text, 1, 2), 0..2); - assert_eq!(line_range(text, 4, 10), 4..8); - assert_eq!(line_range(text, 5, 10), 4..8); - assert_eq!(line_range(text, 11, 10), 9..14); + let text = "abc\ndefg❤hij\nklmno".as_bytes(); + assert_eq!(line_range(text, 5, Point::new(1, 1), 30), 4..14); + assert_eq!(line_range(text, 5, Point::new(1, 1), 6), 4..8); + assert_eq!(line_range(text, 17, Point::new(2, 2), 30), 15..20); + assert_eq!(line_range(text, 17, Point::new(2, 2), 4), 15..19); + } + + #[test] + fn test_get_line_trims() { + let text = b" foo\nbar\n"; + assert_eq!(line_range(text, 0, Point::new(0, 0), 10), 3..6); + + let text = b"\t func foo \nbar\n"; + assert_eq!(line_range(text, 0, Point::new(0, 0), 10), 2..10); + + let r = line_range(text, 0, Point::new(0, 0), 14); + assert_eq!(r, 2..10); + assert_eq!(str::from_utf8(&text[r]).unwrap_or(""), "func foo"); + + let r = line_range(text, 12, Point::new(1, 0), 14); + assert_eq!(r, 12..15); + assert_eq!(str::from_utf8(&text[r]).unwrap_or(""), "bar"); } }