Add unit test for tagging via C API. Fix docs handling

This commit is contained in:
Max Brunsfeld 2020-03-18 10:38:20 -07:00
parent e3e1bdba75
commit 651fa38c93
4 changed files with 211 additions and 76 deletions

View file

@ -1,30 +1,56 @@
use super::helpers::allocations;
use super::helpers::fixtures::get_language;
use std::ffi::CString;
use std::{ptr, slice, str};
use tree_sitter_tags::c_lib as c;
use tree_sitter_tags::{TagKind, TagsConfiguration, TagsContext};
const PYTHON_TAG_QUERY: &'static str = r#"
((function_definition
name: (identifier) @name
body: (block . (expression_statement (string) @doc))) @function
(strip! @doc "(^['\"\\s]*)|(['\"\\s]*$)"))
(function_definition
name: (identifier) @name) @function
((class_definition
name: (identifier) @name
body: (block . (expression_statement (string) @doc))) @class
(strip! @doc "(^['\"\\s]*)|(['\"\\s]*$)"))
(class_definition
name: (identifier) @name) @class
(call
function: (identifier) @name) @call
"#;
const JS_TAG_QUERY: &'static str = r#"
((*
(comment)+ @doc .
(class_declaration
name: (identifier) @name) @class)
(select-adjacent! @doc @class)
(strip! @doc "(^[/\\*\\s]*)|([/\\*\\s]*$)"))
((*
(comment)+ @doc .
(method_definition
name: (property_identifier) @name) @method)
(select-adjacent! @doc @method)
(strip! @doc "(^[/\\*\\s]*)|([/\\*\\s]*$)"))
((*
(comment)+ @doc .
(function_declaration
name: (identifier) @name) @function)
(select-adjacent! @doc @function)
(strip! @doc "(^[/\\*\\s]*)|([/\\*\\s]*$)"))
(call_expression function: (identifier) @name) @call
"#;
#[test]
fn test_tags_python() {
let language = get_language("python");
let tags_config = TagsConfiguration::new(
language,
r#"
((function_definition
name: (identifier) @name
body: (block . (expression_statement (string) @doc))) @function
(strip! @doc "(^['\"\\s]*)|(['\"\\s]*$)"))
(function_definition
name: (identifier) @name) @function
((class_definition
name: (identifier) @name
body: (block . (expression_statement (string) @doc))) @class
(strip! @doc "(^['\"\\s]*)|(['\"\\s]*$)"))
(class_definition
name: (identifier) @name) @class
(call
function: (identifier) @name) @call
"#,
"",
)
.unwrap();
let tags_config = TagsConfiguration::new(language, PYTHON_TAG_QUERY, "").unwrap();
let source = br#"
class Customer:
@ -68,27 +94,7 @@ fn test_tags_python() {
#[test]
fn test_tags_javascript() {
let language = get_language("javascript");
let tags_config = TagsConfiguration::new(
language,
r#"
((*
(comment)+ @doc
.
(class_declaration
name: (identifier) @name) @class)
(select-adjacent! @doc @class)
(strip! @doc "(^[/\\*\\s]*)|([/\\*\\s]*$)"))
((*
(comment)+ @doc
.
(method_definition
name: (property_identifier) @name) @method)
; (select-adjacent! @doc @method)
(strip! @doc "(^[/\\*\\s]*)|([/\\*\\s]*$)"))
"#,
"",
)
.unwrap();
let tags_config = TagsConfiguration::new(language, JS_TAG_QUERY, "").unwrap();
let mut tag_context = TagsContext::new();
let source = br#"
@ -132,6 +138,103 @@ fn test_tags_javascript() {
assert_eq!(tags[2].docs, None);
}
#[test]
fn test_tags_via_c_api() {
allocations::record(|| {
let tagger = c::ts_tagger_new();
let buffer = c::ts_tags_buffer_new();
let scope_name = "source.js";
let language = get_language("javascript");
let source_code = "
var a = 1;
// one
// two
// three
function b() {
}
// four
// five
class C extends D {
}
b(a);"
.lines()
.skip(1)
// remove extra indentation
.map(|line| &line[line.len().min(12)..])
.collect::<Vec<_>>()
.join("\n");
let c_scope_name = CString::new(scope_name).unwrap();
let result = c::ts_tagger_add_language(
tagger,
c_scope_name.as_ptr(),
language,
JS_TAG_QUERY.as_ptr(),
ptr::null(),
JS_TAG_QUERY.len() as u32,
0,
);
assert_eq!(result, c::TSTagsError::Ok);
let result = c::ts_tagger_tag(
tagger,
c_scope_name.as_ptr(),
source_code.as_ptr(),
source_code.len() as u32,
buffer,
ptr::null(),
);
assert_eq!(result, c::TSTagsError::Ok);
let tags = unsafe {
slice::from_raw_parts(
c::ts_tags_buffer_tags(buffer),
c::ts_tags_buffer_tags_len(buffer) as usize,
)
};
let docs = str::from_utf8(unsafe {
slice::from_raw_parts(
c::ts_tags_buffer_docs(buffer) as *const u8,
c::ts_tags_buffer_docs_len(buffer) as usize,
)
})
.unwrap();
assert_eq!(
tags.iter()
.map(|tag| (
tag.kind,
&source_code[tag.name_start_byte as usize..tag.name_end_byte as usize],
&source_code[tag.line_start_byte as usize..tag.line_end_byte as usize],
&docs[tag.docs_start_byte as usize..tag.docs_end_byte as usize],
))
.collect::<Vec<_>>(),
&[
(
c::TSTagKind::Function,
"b",
"function b() {",
"one\ntwo\nthree"
),
(
c::TSTagKind::Class,
"C",
"class C extends D {",
"four\nfive"
),
(c::TSTagKind::Call, "b", "b(a);", "")
]
);
c::ts_tags_buffer_delete(buffer);
c::ts_tagger_delete(tagger);
});
}
fn substr<'a>(source: &'a [u8], range: &std::ops::Range<usize>) -> &'a str {
std::str::from_utf8(&source[range.clone()]).unwrap()
}

View file

@ -82,8 +82,12 @@ TSTagsBuffer *ts_tags_buffer_new();
void ts_tags_buffer_delete(TSTagsBuffer *);
// Access the tags within a tag buffer.
const TSTag *ts_tags_buffer_line_offsets(const TSTagsBuffer *);
uint32_t ts_tags_buffer_len(const TSTagsBuffer *);
const TSTag *ts_tags_buffer_tags(const TSTagsBuffer *);
uint32_t ts_tags_buffer_tags_len(const TSTagsBuffer *);
// Access the string containing all of the docs
const char *ts_tags_buffer_docs(const TSTagsBuffer *);
uint32_t ts_tags_buffer_docs_len(const TSTagsBuffer *);
#ifdef __cplusplus
}

View file

@ -6,7 +6,8 @@ use std::{fmt, slice, str};
use tree_sitter::Language;
#[repr(C)]
enum TSTagsError {
#[derive(Debug, PartialEq, Eq)]
pub enum TSTagsError {
Ok,
UnknownScope,
Timeout,
@ -17,7 +18,8 @@ enum TSTagsError {
}
#[repr(C)]
enum TSTagKind {
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum TSTagKind {
Function,
Method,
Class,
@ -26,38 +28,49 @@ enum TSTagKind {
}
#[repr(C)]
struct TSPoint {
pub struct TSPoint {
row: u32,
column: u32,
}
#[repr(C)]
struct TSTag {
kind: TSTagKind,
start_byte: u32,
end_byte: u32,
name_start_byte: u32,
name_end_byte: u32,
line_start_byte: u32,
line_end_byte: u32,
start_point: TSPoint,
end_point: TSPoint,
docs: *const u8,
docs_length: u32,
pub struct TSTag {
pub kind: TSTagKind,
pub start_byte: u32,
pub end_byte: u32,
pub name_start_byte: u32,
pub name_end_byte: u32,
pub line_start_byte: u32,
pub line_end_byte: u32,
pub start_point: TSPoint,
pub end_point: TSPoint,
pub docs_start_byte: u32,
pub docs_end_byte: u32,
}
struct TSTagger {
pub struct TSTagger {
languages: HashMap<String, TagsConfiguration>,
}
struct TSTagsBuffer {
pub struct TSTagsBuffer {
context: TagsContext,
tags: Vec<TSTag>,
docs: Vec<u8>,
}
#[no_mangle]
unsafe extern "C" fn ts_tagger_add_language(
pub extern "C" fn ts_tagger_new() -> *mut TSTagger {
Box::into_raw(Box::new(TSTagger {
languages: HashMap::new(),
}))
}
pub extern "C" fn ts_tagger_delete(this: *mut TSTagger) {
drop(unsafe { Box::from_raw(this) })
}
#[no_mangle]
pub extern "C" fn ts_tagger_add_language(
this: *mut TSTagger,
scope_name: *const i8,
language: Language,
@ -67,9 +80,9 @@ unsafe extern "C" fn ts_tagger_add_language(
locals_query_len: u32,
) -> TSTagsError {
let tagger = unwrap_mut_ptr(this);
let scope_name = unwrap(CStr::from_ptr(scope_name).to_str());
let tags_query = slice::from_raw_parts(tags_query, tags_query_len as usize);
let locals_query = slice::from_raw_parts(locals_query, locals_query_len as usize);
let scope_name = unsafe { unwrap(CStr::from_ptr(scope_name).to_str()) };
let tags_query = unsafe { slice::from_raw_parts(tags_query, tags_query_len as usize) };
let locals_query = unsafe { slice::from_raw_parts(locals_query, locals_query_len as usize) };
let tags_query = match str::from_utf8(tags_query) {
Ok(e) => e,
Err(_) => return TSTagsError::InvalidUtf8,
@ -89,7 +102,7 @@ unsafe extern "C" fn ts_tagger_add_language(
}
#[no_mangle]
unsafe extern "C" fn ts_tagger_tag(
pub extern "C" fn ts_tagger_tag(
this: *mut TSTagger,
scope_name: *const i8,
source_code: *const u8,
@ -99,15 +112,17 @@ unsafe extern "C" fn ts_tagger_tag(
) -> TSTagsError {
let tagger = unwrap_mut_ptr(this);
let buffer = unwrap_mut_ptr(output);
let scope_name = unwrap(CStr::from_ptr(scope_name).to_str());
let scope_name = unsafe { unwrap(CStr::from_ptr(scope_name).to_str()) };
if let Some(config) = tagger.languages.get(scope_name) {
let source_code = slice::from_raw_parts(source_code, source_code_len as usize);
buffer.tags.clear();
buffer.docs.clear();
let source_code = unsafe { slice::from_raw_parts(source_code, source_code_len as usize) };
for tag in buffer.context.generate_tags(config, source_code) {
let prev_docs_len = buffer.docs.len();
if let Some(docs) = tag.docs {
buffer.docs.extend_from_slice(docs.as_bytes());
}
let docs = &buffer.docs[prev_docs_len..];
buffer.tags.push(TSTag {
kind: match tag.kind {
TagKind::Function => TSTagKind::Function,
@ -130,10 +145,11 @@ unsafe extern "C" fn ts_tagger_tag(
row: tag.span.end.row as u32,
column: tag.span.end.column as u32,
},
docs: docs.as_ptr(),
docs_length: docs.len() as u32,
docs_start_byte: prev_docs_len as u32,
docs_end_byte: buffer.docs.len() as u32,
});
}
TSTagsError::Ok
} else {
TSTagsError::UnknownScope
@ -141,7 +157,7 @@ unsafe extern "C" fn ts_tagger_tag(
}
#[no_mangle]
extern "C" fn ts_tags_buffer_new() -> *mut TSTagsBuffer {
pub extern "C" fn ts_tags_buffer_new() -> *mut TSTagsBuffer {
Box::into_raw(Box::new(TSTagsBuffer {
context: TagsContext::new(),
tags: Vec::new(),
@ -150,22 +166,34 @@ extern "C" fn ts_tags_buffer_new() -> *mut TSTagsBuffer {
}
#[no_mangle]
extern "C" fn ts_tags_buffer_delete(this: *mut TSTagsBuffer) {
pub extern "C" fn ts_tags_buffer_delete(this: *mut TSTagsBuffer) {
drop(unsafe { Box::from_raw(this) })
}
#[no_mangle]
extern "C" fn ts_tags_buffer_line_offsets(this: *const TSTagsBuffer) -> *const TSTag {
pub extern "C" fn ts_tags_buffer_tags(this: *const TSTagsBuffer) -> *const TSTag {
let buffer = unwrap_ptr(this);
buffer.tags.as_ptr()
}
#[no_mangle]
extern "C" fn ts_tags_buffer_len(this: *const TSTagsBuffer) -> u32 {
pub extern "C" fn ts_tags_buffer_tags_len(this: *const TSTagsBuffer) -> u32 {
let buffer = unwrap_ptr(this);
buffer.tags.len() as u32
}
#[no_mangle]
pub extern "C" fn ts_tags_buffer_docs(this: *const TSTagsBuffer) -> *const i8 {
let buffer = unwrap_ptr(this);
buffer.docs.as_ptr() as *const i8
}
#[no_mangle]
pub extern "C" fn ts_tags_buffer_docs_len(this: *const TSTagsBuffer) -> u32 {
let buffer = unwrap_ptr(this);
buffer.docs.len() as u32
}
fn unwrap_ptr<'a, T>(result: *const T) -> &'a T {
unsafe { result.as_ref() }.unwrap_or_else(|| {
eprintln!("{}:{} - pointer must not be null", file!(), line!());

View file

@ -1,4 +1,4 @@
mod c_lib;
pub mod c_lib;
use memchr::{memchr, memrchr};
use regex::Regex;