From db360b73fb33d5c03a226b42b1bfa60398645873 Mon Sep 17 00:00:00 2001 From: Max Brunsfeld Date: Sat, 13 Oct 2018 14:09:36 -0700 Subject: [PATCH] Add Tree.walk_with_properties --- Cargo.toml | 5 + src/lib.rs | 294 +++++++++++++++++++++++++++++++++++++++++++++++++++-- 2 files changed, 292 insertions(+), 7 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 9adbcfd1..485d369e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -20,5 +20,10 @@ include = [ "/vendor/tree-sitter/src/runtime/*", ] +[dependencies] +serde = "1.0" +serde_json = "1.0" +serde_derive = "1.0" + [build-dependencies] cc = "1.0" diff --git a/src/lib.rs b/src/lib.rs index 4a132a3f..19b9a670 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,11 +1,17 @@ mod ffi; +#[macro_use] +extern crate serde_derive; +extern crate serde_json; + +use std::collections::HashMap; use std::ffi::CStr; use std::fmt; use std::io::{self, Read, Seek}; use std::marker::PhantomData; use std::os::raw::{c_char, c_void}; use std::ptr; +use std::str; #[derive(Clone, Copy)] #[repr(transparent)] @@ -19,7 +25,7 @@ pub enum LogType { type Logger<'a> = Box; -#[derive(Clone, Copy, Debug, PartialEq, Eq)] +#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)] pub struct Point { pub row: u32, pub column: u32, @@ -35,6 +41,22 @@ pub struct InputEdit { pub new_end_position: Point, } +struct PropertyTransition { + state_id: u32, + child_index: Option, +} + +struct PropertyState { + transitions: HashMap>, + property_set_id: u32, + default_next_state_id: u32, +} + +pub struct PropertySheet { + states: Vec, + property_sets: Vec>, +} + pub struct Node<'a>(ffi::TSNode, PhantomData<&'a ()>); pub struct Parser(*mut ffi::TSParser); @@ -43,6 +65,13 @@ pub struct Tree(*mut ffi::TSTree); pub struct TreeCursor<'a>(ffi::TSTreeCursor, PhantomData<&'a ()>); +pub struct TreePropertyCursor<'a> { + cursor: TreeCursor<'a>, + state_stack: Vec, + child_index_stack: Vec, + property_sheet: &'a PropertySheet, +} + impl Language { pub fn node_kind_count(&self) -> usize { unsafe { ffi::ts_language_symbol_count(self.0) as usize } @@ -310,6 +339,13 @@ impl Tree { pub fn walk(&self) -> TreeCursor { self.root_node().walk() } + + pub fn walk_with_properties<'a>( + &'a self, + property_sheet: &'a PropertySheet, + ) -> TreePropertyCursor<'a> { + TreePropertyCursor::new(self, property_sheet) + } } unsafe impl Send for Tree {} @@ -437,6 +473,14 @@ impl<'tree> Node<'tree> { result } + pub fn utf8_text<'a>(&self, source: &'a str) -> Result<&'a str, str::Utf8Error> { + str::from_utf8(&source.as_bytes()[self.start_byte() as usize..self.end_byte() as usize]) + } + + pub fn utf16_text<'a>(&self, source: &'a [u16]) -> &'a [u16] { + &source[self.start_byte() as usize..self.end_byte() as usize] + } + pub fn walk(&self) -> TreeCursor<'tree> { TreeCursor(unsafe { ffi::ts_tree_cursor_new(self.0) }, PhantomData) } @@ -461,7 +505,7 @@ impl<'a> fmt::Debug for Node<'a> { } impl<'a> TreeCursor<'a> { - pub fn node(&'a self) -> Node<'a> { + pub fn node(&self) -> Node<'a> { Node( unsafe { ffi::ts_tree_cursor_current_node(&self.0) }, PhantomData, @@ -496,6 +540,87 @@ impl<'a> Drop for TreeCursor<'a> { } } +impl<'a> TreePropertyCursor<'a> { + fn new(tree: &'a Tree, property_sheet: &'a PropertySheet) -> Self { + Self { + cursor: tree.root_node().walk(), + child_index_stack: vec![0], + state_stack: vec![0], + property_sheet, + } + } + + pub fn node(&self) -> Node<'a> { + self.cursor.node() + } + + pub fn node_properties(&self) -> &'a HashMap { + &self.property_sheet.property_sets[self.current_state().property_set_id as usize] + } + + pub fn goto_first_child(&mut self) -> bool { + if self.cursor.goto_first_child() { + let child_index = 0; + let next_state_id = { + let state = &self.current_state(); + let kind_id = self.cursor.node().kind_id(); + self.next_state(state, kind_id, child_index) + }; + self.state_stack.push(next_state_id); + self.child_index_stack.push(child_index); + true + } else { + false + } + } + + pub fn goto_next_sibling(&mut self) -> bool { + if self.cursor.goto_next_sibling() { + let child_index = self.child_index_stack.pop().unwrap() + 1; + self.state_stack.pop(); + let next_state_id = { + let state = &self.current_state(); + let kind_id = self.cursor.node().kind_id(); + self.next_state(state, kind_id, child_index) + }; + self.state_stack.push(next_state_id); + self.child_index_stack.push(child_index); + true + } else { + false + } + } + + pub fn goto_parent(&mut self) -> bool { + if self.cursor.goto_parent() { + self.state_stack.pop(); + self.child_index_stack.pop(); + true + } else { + false + } + } + + fn next_state(&self, state: &PropertyState, node_kind_id: u16, node_child_index: u32) -> u32 { + state + .transitions + .get(&node_kind_id) + .and_then(|transitions| { + for transition in transitions.iter() { + if transition.child_index == Some(node_child_index) || transition.child_index == None { + return Some(transition.state_id); + } + } + None + }) + .unwrap_or(state.default_next_state_id) + } + + fn current_state(&self) -> &PropertyState { + &self.property_sheet.states[*self.state_stack.last().unwrap() as usize] + } +} + impl Point { pub fn new(row: u32, column: u32) -> Self { Point { row, column } @@ -526,6 +651,64 @@ impl From for Point { } } +impl PropertySheet { + pub fn new(language: Language, json: &str) -> Result { + #[derive(Deserialize, Debug)] + struct PropertyTransitionJSON { + #[serde(rename = "type")] + kind: String, + named: bool, + index: Option, + state_id: u32, + } + + #[derive(Deserialize, Debug)] + struct PropertyStateJSON { + transitions: Vec, + property_set_id: u32, + default_next_state_id: u32, + } + + #[derive(Deserialize, Debug)] + struct PropertySheetJSON { + states: Vec, + property_sets: Vec>, + } + + let input: PropertySheetJSON = serde_json::from_str(json)?; + Ok(PropertySheet { + property_sets: input.property_sets, + states: input + .states + .iter() + .map(|state| { + let mut transitions = HashMap::new(); + let node_kind_count = language.node_kind_count(); + for transition in state.transitions.iter() { + for i in 0..node_kind_count { + let i = i as u16; + if language.node_kind_is_named(i) == transition.named + && transition.kind == language.node_kind_for_id(i) + { + let entry = transitions.entry(i).or_insert(Vec::new()); + entry.push(PropertyTransition { + child_index: transition.index, + state_id: transition.state_id, + }); + } + } + } + PropertyState { + transitions, + default_next_state_id: state.default_next_state_id, + property_set_id: state.property_set_id, + } + }) + .collect(), + }) + } +} + #[cfg(test)] mod tests { use super::*; @@ -600,11 +783,11 @@ mod tests { let tree = parser .parse_str( " - struct Stuff { - a: A; - b: Option, - } - ", + struct Stuff { + a: A; + b: Option, + } + ", None, ) .unwrap(); @@ -628,6 +811,103 @@ mod tests { assert_eq!(cursor.node().is_named(), true); } + #[test] + fn test_tree_property_matching() { + let mut parser = Parser::new(); + parser.set_language(rust()).unwrap(); + let tree = parser.parse_str("fn f1() { f2(); }", None).unwrap(); + + let property_sheet = PropertySheet::new( + rust(), + r##" + { + "states": [ + { + "transitions": [ + {"type": "call_expression", "named": true, "state_id": 1}, + {"type": "function_item", "named": true, "state_id": 2} + ], + "default_next_state_id": 0, + "property_set_id": 0 + }, + { + "transitions": [ + {"type": "identifier", "named": true, "state_id": 3} + ], + "default_next_state_id": 0, + "property_set_id": 0 + }, + { + "transitions": [ + {"type": "identifier", "named": true, "state_id": 4} + ], + "default_next_state_id": 0, + "property_set_id": 0 + }, + { + "transitions": [], + "default_next_state_id": 0, + "property_set_id": 1 + }, + { + "transitions": [], + "default_next_state_id": 0, + "property_set_id": 2 + } + ], + "property_sets": [ + {}, + {"reference": "function"}, + {"define": "function"} + ] + } + "##, + ) + .unwrap(); + + let mut cursor = tree.walk_with_properties(&property_sheet); + assert_eq!(cursor.node().kind(), "source_file"); + assert_eq!(*cursor.node_properties(), HashMap::new()); + + assert!(cursor.goto_first_child()); + assert_eq!(cursor.node().kind(), "function_item"); + assert_eq!(*cursor.node_properties(), HashMap::new()); + + assert!(cursor.goto_first_child()); + assert_eq!(cursor.node().kind(), "fn"); + assert_eq!(*cursor.node_properties(), HashMap::new()); + assert!(!cursor.goto_first_child()); + + assert!(cursor.goto_next_sibling()); + assert_eq!(cursor.node().kind(), "identifier"); + assert_eq!(cursor.node_properties()["define"], "function"); + assert!(!cursor.goto_first_child()); + + assert!(cursor.goto_next_sibling()); + assert_eq!(cursor.node().kind(), "parameters"); + assert_eq!(*cursor.node_properties(), HashMap::new()); + + assert!(cursor.goto_first_child()); + assert_eq!(cursor.node().kind(), "("); + assert!(cursor.goto_next_sibling()); + assert_eq!(cursor.node().kind(), ")"); + assert_eq!(*cursor.node_properties(), HashMap::new()); + + assert!(cursor.goto_parent()); + assert!(cursor.goto_next_sibling()); + assert_eq!(cursor.node().kind(), "block"); + assert_eq!(*cursor.node_properties(), HashMap::new()); + + assert!(cursor.goto_first_child()); + assert!(cursor.goto_next_sibling()); + assert_eq!(cursor.node().kind(), "call_expression"); + assert_eq!(*cursor.node_properties(), HashMap::new()); + + assert!(cursor.goto_first_child()); + assert_eq!(cursor.node().kind(), "identifier"); + assert_eq!(cursor.node_properties()["reference"], "function"); + } + #[test] fn test_custom_utf8_input() { let mut parser = Parser::new();