diff --git a/cli/src/tests/tags_test.rs b/cli/src/tests/tags_test.rs index bca35f71..d4cbc687 100644 --- a/cli/src/tests/tags_test.rs +++ b/cli/src/tests/tags_test.rs @@ -1,30 +1,56 @@ +use super::helpers::allocations; use super::helpers::fixtures::get_language; +use std::ffi::CString; +use std::{ptr, slice, str}; +use tree_sitter_tags::c_lib as c; use tree_sitter_tags::{TagKind, TagsConfiguration, TagsContext}; +const PYTHON_TAG_QUERY: &'static str = r#" +((function_definition + name: (identifier) @name + body: (block . (expression_statement (string) @doc))) @function + (strip! @doc "(^['\"\\s]*)|(['\"\\s]*$)")) +(function_definition + name: (identifier) @name) @function +((class_definition + name: (identifier) @name + body: (block . (expression_statement (string) @doc))) @class + (strip! @doc "(^['\"\\s]*)|(['\"\\s]*$)")) +(class_definition + name: (identifier) @name) @class +(call + function: (identifier) @name) @call +"#; + +const JS_TAG_QUERY: &'static str = r#" +((* + (comment)+ @doc . + (class_declaration + name: (identifier) @name) @class) + (select-adjacent! @doc @class) + (strip! @doc "(^[/\\*\\s]*)|([/\\*\\s]*$)")) + +((* + (comment)+ @doc . + (method_definition + name: (property_identifier) @name) @method) + (select-adjacent! @doc @method) + (strip! @doc "(^[/\\*\\s]*)|([/\\*\\s]*$)")) + +((* + (comment)+ @doc . + (function_declaration + name: (identifier) @name) @function) + (select-adjacent! @doc @function) + (strip! @doc "(^[/\\*\\s]*)|([/\\*\\s]*$)")) + +(call_expression function: (identifier) @name) @call + "#; + #[test] fn test_tags_python() { let language = get_language("python"); - let tags_config = TagsConfiguration::new( - language, - r#" - ((function_definition - name: (identifier) @name - body: (block . (expression_statement (string) @doc))) @function - (strip! @doc "(^['\"\\s]*)|(['\"\\s]*$)")) - (function_definition - name: (identifier) @name) @function - ((class_definition - name: (identifier) @name - body: (block . (expression_statement (string) @doc))) @class - (strip! @doc "(^['\"\\s]*)|(['\"\\s]*$)")) - (class_definition - name: (identifier) @name) @class - (call - function: (identifier) @name) @call - "#, - "", - ) - .unwrap(); + let tags_config = TagsConfiguration::new(language, PYTHON_TAG_QUERY, "").unwrap(); let source = br#" class Customer: @@ -68,27 +94,7 @@ fn test_tags_python() { #[test] fn test_tags_javascript() { let language = get_language("javascript"); - let tags_config = TagsConfiguration::new( - language, - r#" - ((* - (comment)+ @doc - . - (class_declaration - name: (identifier) @name) @class) - (select-adjacent! @doc @class) - (strip! @doc "(^[/\\*\\s]*)|([/\\*\\s]*$)")) - ((* - (comment)+ @doc - . - (method_definition - name: (property_identifier) @name) @method) -; (select-adjacent! @doc @method) - (strip! @doc "(^[/\\*\\s]*)|([/\\*\\s]*$)")) - "#, - "", - ) - .unwrap(); + let tags_config = TagsConfiguration::new(language, JS_TAG_QUERY, "").unwrap(); let mut tag_context = TagsContext::new(); let source = br#" @@ -132,6 +138,103 @@ fn test_tags_javascript() { assert_eq!(tags[2].docs, None); } +#[test] +fn test_tags_via_c_api() { + allocations::record(|| { + let tagger = c::ts_tagger_new(); + let buffer = c::ts_tags_buffer_new(); + let scope_name = "source.js"; + let language = get_language("javascript"); + + let source_code = " + var a = 1; + + // one + // two + // three + function b() { + } + + // four + // five + class C extends D { + + } + + b(a);" + .lines() + .skip(1) + // remove extra indentation + .map(|line| &line[line.len().min(12)..]) + .collect::>() + .join("\n"); + + let c_scope_name = CString::new(scope_name).unwrap(); + let result = c::ts_tagger_add_language( + tagger, + c_scope_name.as_ptr(), + language, + JS_TAG_QUERY.as_ptr(), + ptr::null(), + JS_TAG_QUERY.len() as u32, + 0, + ); + assert_eq!(result, c::TSTagsError::Ok); + + let result = c::ts_tagger_tag( + tagger, + c_scope_name.as_ptr(), + source_code.as_ptr(), + source_code.len() as u32, + buffer, + ptr::null(), + ); + assert_eq!(result, c::TSTagsError::Ok); + let tags = unsafe { + slice::from_raw_parts( + c::ts_tags_buffer_tags(buffer), + c::ts_tags_buffer_tags_len(buffer) as usize, + ) + }; + let docs = str::from_utf8(unsafe { + slice::from_raw_parts( + c::ts_tags_buffer_docs(buffer) as *const u8, + c::ts_tags_buffer_docs_len(buffer) as usize, + ) + }) + .unwrap(); + + assert_eq!( + tags.iter() + .map(|tag| ( + tag.kind, + &source_code[tag.name_start_byte as usize..tag.name_end_byte as usize], + &source_code[tag.line_start_byte as usize..tag.line_end_byte as usize], + &docs[tag.docs_start_byte as usize..tag.docs_end_byte as usize], + )) + .collect::>(), + &[ + ( + c::TSTagKind::Function, + "b", + "function b() {", + "one\ntwo\nthree" + ), + ( + c::TSTagKind::Class, + "C", + "class C extends D {", + "four\nfive" + ), + (c::TSTagKind::Call, "b", "b(a);", "") + ] + ); + + c::ts_tags_buffer_delete(buffer); + c::ts_tagger_delete(tagger); + }); +} + fn substr<'a>(source: &'a [u8], range: &std::ops::Range) -> &'a str { std::str::from_utf8(&source[range.clone()]).unwrap() } diff --git a/tags/include/tree_sitter/tags.h b/tags/include/tree_sitter/tags.h index d492ad31..6054edc4 100644 --- a/tags/include/tree_sitter/tags.h +++ b/tags/include/tree_sitter/tags.h @@ -82,8 +82,12 @@ TSTagsBuffer *ts_tags_buffer_new(); void ts_tags_buffer_delete(TSTagsBuffer *); // Access the tags within a tag buffer. -const TSTag *ts_tags_buffer_line_offsets(const TSTagsBuffer *); -uint32_t ts_tags_buffer_len(const TSTagsBuffer *); +const TSTag *ts_tags_buffer_tags(const TSTagsBuffer *); +uint32_t ts_tags_buffer_tags_len(const TSTagsBuffer *); + +// Access the string containing all of the docs +const char *ts_tags_buffer_docs(const TSTagsBuffer *); +uint32_t ts_tags_buffer_docs_len(const TSTagsBuffer *); #ifdef __cplusplus } diff --git a/tags/src/c_lib.rs b/tags/src/c_lib.rs index 714d956e..83ef9c5f 100644 --- a/tags/src/c_lib.rs +++ b/tags/src/c_lib.rs @@ -6,7 +6,8 @@ use std::{fmt, slice, str}; use tree_sitter::Language; #[repr(C)] -enum TSTagsError { +#[derive(Debug, PartialEq, Eq)] +pub enum TSTagsError { Ok, UnknownScope, Timeout, @@ -17,7 +18,8 @@ enum TSTagsError { } #[repr(C)] -enum TSTagKind { +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub enum TSTagKind { Function, Method, Class, @@ -26,38 +28,49 @@ enum TSTagKind { } #[repr(C)] -struct TSPoint { +pub struct TSPoint { row: u32, column: u32, } #[repr(C)] -struct TSTag { - kind: TSTagKind, - start_byte: u32, - end_byte: u32, - name_start_byte: u32, - name_end_byte: u32, - line_start_byte: u32, - line_end_byte: u32, - start_point: TSPoint, - end_point: TSPoint, - docs: *const u8, - docs_length: u32, +pub struct TSTag { + pub kind: TSTagKind, + pub start_byte: u32, + pub end_byte: u32, + pub name_start_byte: u32, + pub name_end_byte: u32, + pub line_start_byte: u32, + pub line_end_byte: u32, + pub start_point: TSPoint, + pub end_point: TSPoint, + pub docs_start_byte: u32, + pub docs_end_byte: u32, } -struct TSTagger { +pub struct TSTagger { languages: HashMap, } -struct TSTagsBuffer { +pub struct TSTagsBuffer { context: TagsContext, tags: Vec, docs: Vec, } #[no_mangle] -unsafe extern "C" fn ts_tagger_add_language( +pub extern "C" fn ts_tagger_new() -> *mut TSTagger { + Box::into_raw(Box::new(TSTagger { + languages: HashMap::new(), + })) +} + +pub extern "C" fn ts_tagger_delete(this: *mut TSTagger) { + drop(unsafe { Box::from_raw(this) }) +} + +#[no_mangle] +pub extern "C" fn ts_tagger_add_language( this: *mut TSTagger, scope_name: *const i8, language: Language, @@ -67,9 +80,9 @@ unsafe extern "C" fn ts_tagger_add_language( locals_query_len: u32, ) -> TSTagsError { let tagger = unwrap_mut_ptr(this); - let scope_name = unwrap(CStr::from_ptr(scope_name).to_str()); - let tags_query = slice::from_raw_parts(tags_query, tags_query_len as usize); - let locals_query = slice::from_raw_parts(locals_query, locals_query_len as usize); + let scope_name = unsafe { unwrap(CStr::from_ptr(scope_name).to_str()) }; + let tags_query = unsafe { slice::from_raw_parts(tags_query, tags_query_len as usize) }; + let locals_query = unsafe { slice::from_raw_parts(locals_query, locals_query_len as usize) }; let tags_query = match str::from_utf8(tags_query) { Ok(e) => e, Err(_) => return TSTagsError::InvalidUtf8, @@ -89,7 +102,7 @@ unsafe extern "C" fn ts_tagger_add_language( } #[no_mangle] -unsafe extern "C" fn ts_tagger_tag( +pub extern "C" fn ts_tagger_tag( this: *mut TSTagger, scope_name: *const i8, source_code: *const u8, @@ -99,15 +112,17 @@ unsafe extern "C" fn ts_tagger_tag( ) -> TSTagsError { let tagger = unwrap_mut_ptr(this); let buffer = unwrap_mut_ptr(output); - let scope_name = unwrap(CStr::from_ptr(scope_name).to_str()); + let scope_name = unsafe { unwrap(CStr::from_ptr(scope_name).to_str()) }; if let Some(config) = tagger.languages.get(scope_name) { - let source_code = slice::from_raw_parts(source_code, source_code_len as usize); + buffer.tags.clear(); + buffer.docs.clear(); + let source_code = unsafe { slice::from_raw_parts(source_code, source_code_len as usize) }; + for tag in buffer.context.generate_tags(config, source_code) { let prev_docs_len = buffer.docs.len(); if let Some(docs) = tag.docs { buffer.docs.extend_from_slice(docs.as_bytes()); } - let docs = &buffer.docs[prev_docs_len..]; buffer.tags.push(TSTag { kind: match tag.kind { TagKind::Function => TSTagKind::Function, @@ -130,10 +145,11 @@ unsafe extern "C" fn ts_tagger_tag( row: tag.span.end.row as u32, column: tag.span.end.column as u32, }, - docs: docs.as_ptr(), - docs_length: docs.len() as u32, + docs_start_byte: prev_docs_len as u32, + docs_end_byte: buffer.docs.len() as u32, }); } + TSTagsError::Ok } else { TSTagsError::UnknownScope @@ -141,7 +157,7 @@ unsafe extern "C" fn ts_tagger_tag( } #[no_mangle] -extern "C" fn ts_tags_buffer_new() -> *mut TSTagsBuffer { +pub extern "C" fn ts_tags_buffer_new() -> *mut TSTagsBuffer { Box::into_raw(Box::new(TSTagsBuffer { context: TagsContext::new(), tags: Vec::new(), @@ -150,22 +166,34 @@ extern "C" fn ts_tags_buffer_new() -> *mut TSTagsBuffer { } #[no_mangle] -extern "C" fn ts_tags_buffer_delete(this: *mut TSTagsBuffer) { +pub extern "C" fn ts_tags_buffer_delete(this: *mut TSTagsBuffer) { drop(unsafe { Box::from_raw(this) }) } #[no_mangle] -extern "C" fn ts_tags_buffer_line_offsets(this: *const TSTagsBuffer) -> *const TSTag { +pub extern "C" fn ts_tags_buffer_tags(this: *const TSTagsBuffer) -> *const TSTag { let buffer = unwrap_ptr(this); buffer.tags.as_ptr() } #[no_mangle] -extern "C" fn ts_tags_buffer_len(this: *const TSTagsBuffer) -> u32 { +pub extern "C" fn ts_tags_buffer_tags_len(this: *const TSTagsBuffer) -> u32 { let buffer = unwrap_ptr(this); buffer.tags.len() as u32 } +#[no_mangle] +pub extern "C" fn ts_tags_buffer_docs(this: *const TSTagsBuffer) -> *const i8 { + let buffer = unwrap_ptr(this); + buffer.docs.as_ptr() as *const i8 +} + +#[no_mangle] +pub extern "C" fn ts_tags_buffer_docs_len(this: *const TSTagsBuffer) -> u32 { + let buffer = unwrap_ptr(this); + buffer.docs.len() as u32 +} + fn unwrap_ptr<'a, T>(result: *const T) -> &'a T { unsafe { result.as_ref() }.unwrap_or_else(|| { eprintln!("{}:{} - pointer must not be null", file!(), line!()); diff --git a/tags/src/lib.rs b/tags/src/lib.rs index 5f579d1d..e5695845 100644 --- a/tags/src/lib.rs +++ b/tags/src/lib.rs @@ -1,4 +1,4 @@ -mod c_lib; +pub mod c_lib; use memchr::{memchr, memrchr}; use regex::Regex;