tree-sitter/src/lib.rs

660 lines
18 KiB
Rust
Raw Normal View History

2016-07-10 14:03:00 -07:00
mod ffi;
2018-05-18 11:15:37 -07:00
use std::fmt;
2016-07-10 14:03:00 -07:00
use std::ffi::CStr;
use std::marker::PhantomData;
use std::os::raw::{c_char, c_int, c_void};
use std::ptr;
/// Raw pointer to a language definition (`ffi::TSLanguage`), as accepted by
/// `Parser::set_language`.
pub type Language = *const ffi::TSLanguage;
2016-07-10 14:03:00 -07:00
pub trait Utf16Input {
fn read(&mut self) -> &[u16];
fn seek(&mut self, u32, Point);
2016-07-10 14:03:00 -07:00
}
pub trait Utf8Input {
fn read(&mut self) -> &[u8];
fn seek(&mut self, u32, Point);
2016-07-10 14:03:00 -07:00
}
/// Category of a log message handed to a `Logger` callback.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum LogType {
    /// Reported when the C library tags the message as a parse message.
    Parse,
    /// Reported for every other message tag.
    Lex,
}
2018-05-18 14:06:49 -07:00
/// Boxed callback that receives a `LogType` and the message text; installed
/// with `Parser::set_logger`.
type Logger<'a> = Box<FnMut(LogType, &str) + 'a>;
2018-05-18 11:15:37 -07:00
/// A row/column position.
///
/// The derived ordering compares `row` first and `column` second.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct Point {
    pub row: u32,
    pub column: u32,
}
2018-05-18 11:15:37 -07:00
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
2016-07-10 14:03:00 -07:00
pub struct InputEdit {
pub start_byte: u32,
pub old_end_byte: u32,
pub new_end_byte: u32,
pub start_position: Point,
pub old_end_position: Point,
pub new_end_position: Point,
}
/// Wrapper around an `ffi::TSNode` value; the `'a` lifetime is carried only
/// by the `PhantomData` marker.
pub struct Node<'a>(ffi::TSNode, PhantomData<&'a ()>);
/// Wrapper around a raw `ffi::TSParser` pointer; created by `Parser::new`
/// and deleted by the `Drop` impl.
pub struct Parser(*mut ffi::TSParser);
2018-05-18 14:06:49 -07:00
/// Wrapper around a raw `ffi::TSTree` pointer, as returned by the parser's
/// parse methods; deleted by the `Drop` impl.
pub struct Tree(*mut ffi::TSTree);
2016-07-10 14:03:00 -07:00
/// Wrapper around an `ffi::TSTreeCursor` value, obtained from `Tree::walk`;
/// the `'a` lifetime is carried only by the `PhantomData` marker.
pub struct TreeCursor<'a>(ffi::TSTreeCursor, PhantomData<&'a ()>);
/// `Utf8Input` over a single in-memory byte slice; `Parser::parse_str`
/// builds one from the bytes of its `&str` argument.
struct FlatInput<'a> {
/// The complete input.
bytes: &'a [u8],
/// Index into `bytes` where the next `read` starts.
offset: usize,
}
2016-07-10 14:03:00 -07:00
impl Parser {
pub fn new() -> Parser {
unsafe {
let parser = ffi::ts_parser_new();
Parser(parser)
}
}
2018-05-18 11:15:37 -07:00
pub fn set_language(&mut self, language: Language) -> Result<(), String> {
2016-07-10 14:03:00 -07:00
unsafe {
2018-05-18 11:15:37 -07:00
let version = ffi::ts_language_version(language) as usize;
if version == ffi::TREE_SITTER_LANGUAGE_VERSION {
ffi::ts_parser_set_language(self.0, language);
Ok(())
} else {
Err(format!(
"Incompatible language version {}. Expected {}.",
version,
ffi::TREE_SITTER_LANGUAGE_VERSION
))
}
2016-07-10 14:03:00 -07:00
}
}
2018-05-18 14:06:49 -07:00
pub fn logger(&self) -> Option<&Logger> {
let logger = unsafe { ffi::ts_parser_logger(self.0) };
unsafe { (logger.payload as *mut Logger).as_ref() }
}
pub fn set_logger(&mut self, logger: Option<Logger>) {
let prev_logger = unsafe { ffi::ts_parser_logger(self.0) };
if !prev_logger.payload.is_null() {
unsafe { Box::from_raw(prev_logger.payload as *mut Logger) };
}
2016-07-10 14:03:00 -07:00
let c_logger;
if let Some(logger) = logger {
2018-05-18 14:06:49 -07:00
let container = Box::new(logger);
unsafe extern "C" fn log(
payload: *mut c_void,
c_log_type: ffi::TSLogType,
c_message: *const c_char,
) {
let callback = (payload as *mut Logger).as_mut().unwrap();
if let Ok(message) = CStr::from_ptr(c_message).to_str() {
let log_type = if c_log_type == ffi::TSLogType_TSLogTypeParse {
LogType::Parse
} else {
LogType::Lex
};
callback(log_type, message);
}
};
let raw_container = Box::into_raw(container);
c_logger = ffi::TSLogger {
2018-05-18 14:06:49 -07:00
payload: raw_container as *mut c_void,
log: Some(log),
};
} else {
c_logger = ffi::TSLogger { payload: ptr::null_mut(), log: None };
}
2016-07-10 14:03:00 -07:00
unsafe { ffi::ts_parser_set_logger(self.0, c_logger) };
}
2018-05-18 14:27:08 -07:00
pub fn parse_str(&mut self, input: &str, old_tree: Option<&Tree>) -> Option<Tree> {
let mut input = FlatInput { bytes: input.as_bytes(), offset: 0};
self.parse_utf8(&mut input, old_tree)
}
2016-07-10 14:03:00 -07:00
pub fn parse_utf8<T: Utf8Input>(
&mut self,
input: &mut T,
2018-05-18 14:27:08 -07:00
old_tree: Option<&Tree>,
2016-07-10 14:03:00 -07:00
) -> Option<Tree> {
unsafe extern "C" fn read<T: Utf8Input>(
payload: *mut c_void,
bytes_read: *mut u32,
) -> *const c_char {
let input = (payload as *mut T).as_mut().unwrap();
let result = input.read();
*bytes_read = result.len() as u32;
return result.as_ptr() as *const c_char;
};
unsafe extern "C" fn seek<T: Utf8Input>(
payload: *mut c_void,
byte: u32,
position: ffi::TSPoint,
) -> c_int {
let input = (payload as *mut T).as_mut().unwrap();
input.seek(
byte,
Point {
row: position.row,
column: position.column,
},
);
return 1;
};
let c_input = ffi::TSInput {
payload: input as *mut T as *mut c_void,
read: Some(read::<T>),
seek: Some(seek::<T>),
encoding: ffi::TSInputEncoding_TSInputEncodingUTF8,
};
let old_tree_ptr = old_tree.map_or(ptr::null_mut(), |t| t.0);
let new_tree_ptr = unsafe { ffi::ts_parser_parse(self.0, old_tree_ptr, c_input) };
if new_tree_ptr.is_null() {
None
} else {
2018-05-18 14:06:49 -07:00
Some(Tree(new_tree_ptr))
2016-07-10 14:03:00 -07:00
}
}
pub fn parse_utf16<T: Utf16Input>(
&mut self,
input: &mut T,
2018-05-18 14:27:08 -07:00
old_tree: Option<&Tree>,
2016-07-10 14:03:00 -07:00
) -> Option<Tree> {
unsafe extern "C" fn read<T: Utf16Input>(
payload: *mut c_void,
bytes_read: *mut u32,
) -> *const c_char {
let input = (payload as *mut T).as_mut().unwrap();
let result = input.read();
*bytes_read = result.len() as u32 * 2;
return result.as_ptr() as *const c_char;
};
unsafe extern "C" fn seek<T: Utf16Input>(
payload: *mut c_void,
byte: u32,
position: ffi::TSPoint,
) -> c_int {
let input = (payload as *mut T).as_mut().unwrap();
input.seek(
byte / 2,
Point {
row: position.row,
column: position.column / 2,
},
);
return 1;
};
let c_input = ffi::TSInput {
payload: input as *mut T as *mut c_void,
read: Some(read::<T>),
seek: Some(seek::<T>),
encoding: ffi::TSInputEncoding_TSInputEncodingUTF8,
};
let old_tree_ptr = old_tree.map_or(ptr::null_mut(), |t| t.0);
let new_tree_ptr = unsafe { ffi::ts_parser_parse(self.0, old_tree_ptr, c_input) };
if new_tree_ptr.is_null() {
None
} else {
2018-05-18 14:06:49 -07:00
Some(Tree(new_tree_ptr))
2016-07-10 14:03:00 -07:00
}
}
}
impl Drop for Parser {
fn drop(&mut self) {
2018-05-18 14:06:49 -07:00
self.set_logger(None);
2016-07-10 14:03:00 -07:00
unsafe { ffi::ts_parser_delete(self.0) }
}
}
impl Tree {
pub fn root_node(&self) -> Node {
Node::new(unsafe { ffi::ts_tree_root_node(self.0) }).unwrap()
}
pub fn edit(&mut self, edit: &InputEdit) {
let edit = ffi::TSInputEdit {
start_byte: edit.start_byte,
old_end_byte: edit.old_end_byte,
new_end_byte: edit.new_end_byte,
start_point: edit.start_position.into(),
old_end_point: edit.old_end_position.into(),
new_end_point: edit.new_end_position.into(),
};
unsafe { ffi::ts_tree_edit(self.0, &edit) };
}
pub fn walk(&self) -> TreeCursor {
TreeCursor(unsafe { ffi::ts_tree_cursor_new(self.0) }, PhantomData)
}
}
impl Drop for Tree {
    /// Deletes the underlying tree via `ts_tree_delete`.
    fn drop(&mut self) {
        let raw_tree = self.0;
        unsafe {
            ffi::ts_tree_delete(raw_tree);
        }
    }
}
impl Clone for Tree {
fn clone(&self) -> Tree {
2018-05-18 14:06:49 -07:00
unsafe { Tree(ffi::ts_tree_copy(self.0)) }
2016-07-10 14:03:00 -07:00
}
}
2018-05-18 14:27:08 -07:00
impl<'tree> Node<'tree> {
2016-07-10 14:03:00 -07:00
fn new(node: ffi::TSNode) -> Option<Self> {
if node.id.is_null() {
None
} else {
Some(Node(node, PhantomData))
}
}
2018-05-18 11:15:37 -07:00
pub fn kind_id(&self) -> u16 {
unsafe { ffi::ts_node_symbol(self.0) }
}
2018-05-18 10:44:14 -07:00
pub fn kind(&self) -> &'static str {
unsafe { CStr::from_ptr(ffi::ts_node_type(self.0)) }.to_str().unwrap()
2016-07-10 14:03:00 -07:00
}
2018-05-18 10:44:14 -07:00
pub fn is_named(&self) -> bool {
unsafe { ffi::ts_node_is_named(self.0) }
}
pub fn has_changes(&self) -> bool {
unsafe { ffi::ts_node_has_changes(self.0) }
}
pub fn has_error(&self) -> bool {
unsafe { ffi::ts_node_has_error(self.0) }
}
pub fn start_byte(&self) -> u32 {
2016-07-10 14:03:00 -07:00
unsafe { ffi::ts_node_start_byte(self.0) }
}
2018-05-18 10:44:14 -07:00
pub fn end_byte(&self) -> u32 {
2016-07-10 14:03:00 -07:00
unsafe { ffi::ts_node_end_byte(self.0) }
}
pub fn start_position(&self) -> Point {
let result = unsafe { ffi::ts_node_start_point(self.0) };
Point {
row: result.row,
column: result.column,
}
}
pub fn end_position(&self) -> Point {
let result = unsafe { ffi::ts_node_end_point(self.0) };
Point {
row: result.row,
column: result.column,
}
}
2018-05-18 14:27:08 -07:00
pub fn child(&self, i: u32) -> Option<Self> {
2016-07-10 14:03:00 -07:00
Self::new(unsafe { ffi::ts_node_child(self.0, i) })
}
pub fn child_count(&self) -> u32 {
unsafe { ffi::ts_node_child_count(self.0) }
}
2018-05-18 14:27:08 -07:00
pub fn named_child<'a>(&'a self, i: u32) -> Option<Self> {
2018-05-18 10:44:14 -07:00
Self::new(unsafe { ffi::ts_node_named_child(self.0, i) })
}
pub fn named_child_count(&self) -> u32 {
unsafe { ffi::ts_node_named_child_count(self.0) }
}
2018-05-18 14:27:08 -07:00
pub fn parent(&self) -> Option<Self> {
2016-07-10 14:03:00 -07:00
Self::new(unsafe { ffi::ts_node_parent(self.0) })
}
2018-05-18 14:27:08 -07:00
pub fn next_sibling(&self) -> Option<Self> {
2018-05-18 10:44:14 -07:00
Self::new(unsafe { ffi::ts_node_next_sibling(self.0) })
}
2018-05-18 14:27:08 -07:00
pub fn prev_sibling(&self) -> Option<Self> {
2018-05-18 10:44:14 -07:00
Self::new(unsafe { ffi::ts_node_prev_sibling(self.0) })
}
2018-05-18 14:27:08 -07:00
pub fn next_named_sibling(&self) -> Option<Self> {
2018-05-18 10:44:14 -07:00
Self::new(unsafe { ffi::ts_node_next_named_sibling(self.0) })
}
2018-05-18 14:27:08 -07:00
pub fn prev_named_sibling(&self) -> Option<Self> {
2018-05-18 10:44:14 -07:00
Self::new(unsafe { ffi::ts_node_prev_named_sibling(self.0) })
}
pub fn to_sexp(&self) -> String {
2018-05-18 11:15:37 -07:00
extern "C" { fn free(pointer: *mut c_void); }
let c_string = unsafe { ffi::ts_node_string(self.0) };
let result = unsafe { CStr::from_ptr(c_string) }.to_str().unwrap().to_string();
unsafe { free(c_string as *mut c_void) };
result
}
2016-07-10 14:03:00 -07:00
}
2018-05-18 11:15:37 -07:00
impl<'a> PartialEq for Node<'a> {
    /// Two nodes are equal when their raw `id` values are equal.
    fn eq(&self, other: &Self) -> bool {
        let (lhs, rhs) = (self.0.id, other.0.id);
        lhs == rhs
    }
}
impl<'a> fmt::Debug for Node<'a> {
    /// Formats the node as `{Node <kind> <start> - <end>}`, with the
    /// positions rendered through `Point`'s `Display` impl.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let kind = self.kind();
        let start = self.start_position();
        let end = self.end_position();
        write!(f, "{{Node {} {} - {}}}", kind, start, end)
    }
}
2016-07-10 14:03:00 -07:00
impl<'a> TreeCursor<'a> {
2018-05-18 10:44:14 -07:00
pub fn node(&'a self) -> Node<'a> {
2016-07-10 14:03:00 -07:00
Node(
unsafe { ffi::ts_tree_cursor_current_node(&self.0) },
PhantomData,
)
}
2018-05-18 10:44:14 -07:00
pub fn goto_first_child(&mut self) -> bool {
2016-07-10 14:03:00 -07:00
return unsafe { ffi::ts_tree_cursor_goto_first_child(&mut self.0) };
}
2018-05-18 10:44:14 -07:00
pub fn goto_parent(&mut self) -> bool {
2016-07-10 14:03:00 -07:00
return unsafe { ffi::ts_tree_cursor_goto_parent(&mut self.0) };
}
2018-05-18 10:44:14 -07:00
pub fn goto_next_sibling(&mut self) -> bool {
2016-07-10 14:03:00 -07:00
return unsafe { ffi::ts_tree_cursor_goto_next_sibling(&mut self.0) };
}
2018-05-18 10:44:14 -07:00
pub fn goto_first_child_for_index(&mut self, index: u32) -> Option<u32> {
2016-07-10 14:03:00 -07:00
let result = unsafe { ffi::ts_tree_cursor_goto_first_child_for_byte(&mut self.0, index) };
if result < 0 {
None
} else {
Some(result as u32)
}
}
}
impl<'a> Drop for TreeCursor<'a> {
    /// Releases the cursor via `ts_tree_cursor_delete`.
    fn drop(&mut self) {
        let raw_cursor = &mut self.0;
        unsafe {
            ffi::ts_tree_cursor_delete(raw_cursor);
        }
    }
}
2018-05-18 14:27:08 -07:00
impl Point {
pub fn new(row: u32, column: u32) -> Self {
Point { row, column }
}
}
2018-05-18 11:15:37 -07:00
impl fmt::Display for Point {
fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> {
write!(f, "({}, {})", self.row, self.column)
}
}
2016-07-10 14:03:00 -07:00
/// Converts a `Point` into the C representation by copying `row` and
/// `column`. `Point::into()` keeps working through the standard blanket
/// `Into` impl.
impl From<Point> for ffi::TSPoint {
    fn from(point: Point) -> ffi::TSPoint {
        ffi::TSPoint {
            row: point.row,
            column: point.column,
        }
    }
}
impl<'a> Utf8Input for FlatInput<'a> {
    /// Returns everything from the current offset to the end of the buffer
    /// and moves the offset to the end, so the next call yields an empty
    /// slice.
    fn read(&mut self) -> &[u8] {
        // `seek` stores whatever offset it is given. An offset past the end
        // must produce an empty chunk rather than a slice-index panic,
        // because this runs underneath an `extern "C"` callback.
        let remaining = self.bytes.get(self.offset..).unwrap_or(&[]);
        self.offset = self.bytes.len();
        remaining
    }

    /// Moves the read offset to the byte index `offset`; the row/column
    /// position is not needed for a flat buffer.
    fn seek(&mut self, offset: u32, _position: Point) {
        self.offset = offset as usize;
    }
}
2016-07-10 14:03:00 -07:00
#[cfg(test)]
mod tests {
    use super::*;

    extern "C" {
        fn tree_sitter_rust() -> Language;
    }

    /// Language pointer returned by the externally linked `tree_sitter_rust`.
    fn rust() -> Language {
        unsafe { tree_sitter_rust() }
    }

    /// Builds a parser with the `rust()` language already assigned.
    fn rust_parser() -> Parser {
        let mut parser = Parser::new();
        parser.set_language(rust()).unwrap();
        parser
    }

    /// S-expression of the second named child of the tree's first named
    /// child.
    fn parameters_sexp(tree: &Tree) -> String {
        tree.root_node()
            .named_child(0).unwrap()
            .named_child(1).unwrap()
            .to_sexp()
    }

    #[test]
    fn test_basic_parsing() {
        let mut parser = rust_parser();
        let source = "
struct Stuff {}
fn main() {}
";
        let tree = parser.parse_str(source, None).unwrap();

        let root = tree.root_node();
        assert_eq!(root.kind(), "source_file");
        assert_eq!(
            root.to_sexp(),
            "(source_file (struct_item (type_identifier) (field_declaration_list)) (function_item (identifier) (parameters) (block)))"
        );

        let first_child = root.child(0).unwrap();
        assert_eq!(first_child.kind(), "struct_item");
    }

    #[test]
    fn test_logging() {
        let mut parser = rust_parser();

        let mut messages = Vec::new();
        parser.set_logger(Some(Box::new(|log_type, message| {
            messages.push((log_type, message.to_string()));
        })));

        let source = "
struct Stuff {}
fn main() {}
";
        parser.parse_str(source, None).unwrap();

        let expected_parse = (LogType::Parse, "reduce sym:struct_item, child_count:3".to_string());
        let expected_lex = (LogType::Lex, "skip character:' '".to_string());
        assert!(messages.contains(&expected_parse));
        assert!(messages.contains(&expected_lex));
    }

    #[test]
    fn test_tree_cursor() {
        let mut parser = rust_parser();
        let source = "
struct Stuff {
a: A;
b: Option<B>,
}
";
        let tree = parser.parse_str(source, None).unwrap();

        let mut cursor = tree.walk();
        assert_eq!(cursor.node().kind(), "source_file");

        assert!(cursor.goto_first_child());
        assert_eq!(cursor.node().kind(), "struct_item");

        assert!(cursor.goto_first_child());
        assert_eq!(cursor.node().kind(), "struct");
        assert!(!cursor.node().is_named());

        assert!(cursor.goto_next_sibling());
        assert_eq!(cursor.node().kind(), "type_identifier");
        assert!(cursor.node().is_named());

        assert!(cursor.goto_next_sibling());
        assert_eq!(cursor.node().kind(), "field_declaration_list");
        assert!(cursor.node().is_named());
    }

    #[test]
    fn test_custom_utf8_input() {
        /// Input that hands out one line per `read` call.
        struct LineBasedInput {
            lines: &'static [&'static str],
            row: usize,
            column: usize,
        }

        impl Utf8Input for LineBasedInput {
            fn read(&mut self) -> &[u8] {
                if self.row >= self.lines.len() {
                    return &[];
                }
                let line = self.lines[self.row].as_bytes();
                let chunk = &line[self.column..];
                self.row += 1;
                self.column = 0;
                chunk
            }

            fn seek(&mut self, _byte: u32, position: Point) {
                self.row = position.row as usize;
                self.column = position.column as usize;
            }
        }

        let mut parser = rust_parser();
        let mut input = LineBasedInput {
            lines: &["pub fn main() {", "}"],
            row: 0,
            column: 0,
        };

        let tree = parser.parse_utf8(&mut input, None).unwrap();
        let root = tree.root_node();
        assert_eq!(root.kind(), "source_file");
        assert!(!root.has_error());

        let first_child = root.child(0).unwrap();
        assert_eq!(first_child.kind(), "function_item");
    }

    #[test]
    fn test_node_equality() {
        let mut parser = rust_parser();
        let tree = parser.parse_str("struct A {}", None).unwrap();

        let (first, second) = (tree.root_node(), tree.root_node());
        assert_eq!(first, second);
        assert_eq!(first.child(0).unwrap(), second.child(0).unwrap());
        assert_ne!(first.child(0).unwrap(), second);
    }

    #[test]
    fn test_editing() {
        /// Input that yields one byte per `read` and records every byte it
        /// hands out.
        struct SpyInput {
            bytes: &'static [u8],
            offset: usize,
            bytes_read: Vec<u8>,
        }

        impl Utf8Input for SpyInput {
            fn read(&mut self) -> &[u8] {
                if self.offset >= self.bytes.len() {
                    return &[];
                }
                let chunk = &self.bytes[self.offset..self.offset + 1];
                self.bytes_read.extend_from_slice(chunk);
                self.offset += 1;
                chunk
            }

            fn seek(&mut self, byte: u32, _position: Point) {
                self.offset = byte as usize;
            }
        }

        let mut input = SpyInput {
            bytes: "fn test(a: A, c: C) {}".as_bytes(),
            offset: 0,
            bytes_read: Vec::new(),
        };

        let mut parser = rust_parser();
        let mut tree = parser.parse_utf8(&mut input, None).unwrap();
        assert_eq!(
            parameters_sexp(&tree),
            "(parameters (parameter (identifier) (type_identifier)) (parameter (identifier) (type_identifier)))"
        );

        // Reset the spy, swap in the edited text and describe the insertion
        // to the old tree before re-parsing.
        input.offset = 0;
        input.bytes_read.clear();
        input.bytes = "fn test(a: A, b: B, c: C) {}".as_bytes();
        let insertion = InputEdit {
            start_byte: 14,
            old_end_byte: 14,
            new_end_byte: 20,
            start_position: Point::new(0, 14),
            old_end_position: Point::new(0, 14),
            new_end_position: Point::new(0, 20),
        };
        tree.edit(&insertion);

        let new_tree = parser.parse_utf8(&mut input, Some(&tree)).unwrap();
        assert_eq!(
            parameters_sexp(&new_tree),
            "(parameters (parameter (identifier) (type_identifier)) (parameter (identifier) (type_identifier)) (parameter (identifier) (type_identifier)))"
        );

        // Only the inserted parameter should appear among the bytes read.
        let retokenized_content = String::from_utf8(input.bytes_read).unwrap();
        assert!(retokenized_content.contains("b: B"));
        assert!(!retokenized_content.contains("a: A"));
        assert!(!retokenized_content.contains("c: C"));
        assert!(!retokenized_content.contains("{}"));
    }
}