use rand::prelude::Rng; use std::{cmp::Ordering, fmt::Write, ops::Range}; use tree_sitter::{Node, Point, Tree, TreeCursor}; #[derive(Debug)] pub struct Pattern { kind: Option<&'static str>, named: bool, field: Option<&'static str>, capture: Option, children: Vec, } #[derive(Clone, Debug, PartialEq, Eq)] pub struct Match<'a, 'tree> { pub captures: Vec<(&'a str, Node<'tree>)>, pub last_node: Option>, } const CAPTURE_NAMES: &'static [&'static str] = &[ "one", "two", "three", "four", "five", "six", "seven", "eight", ]; impl Pattern { pub fn random_pattern_in_tree(tree: &Tree, rng: &mut impl Rng) -> (Self, Range) { let mut cursor = tree.walk(); // Descend to the node at a random byte offset and depth. let mut max_depth = 0; let byte_offset = rng.gen_range(0..cursor.node().end_byte()); while cursor.goto_first_child_for_byte(byte_offset).is_some() { max_depth += 1; } let depth = rng.gen_range(0..=max_depth); for _ in 0..depth { cursor.goto_parent(); } // Build a pattern that matches that node. // Sometimes include subsequent siblings of the node. let pattern_start = cursor.node().start_position(); let mut roots = vec![Self::random_pattern_for_node(&mut cursor, rng)]; while roots.len() < 5 && cursor.goto_next_sibling() { if rng.gen_bool(0.2) { roots.push(Self::random_pattern_for_node(&mut cursor, rng)); } } let pattern_end = cursor.node().end_position(); let mut pattern = Self { kind: None, named: true, field: None, capture: None, children: roots, }; if pattern.children.len() == 1 { pattern = pattern.children.pop().unwrap(); } // In a parenthesized list of sibling patterns, the first // sibling can't be an anonymous `_` wildcard. else if pattern.children[0].kind == Some("_") && !pattern.children[0].named { pattern = pattern.children.pop().unwrap(); } // In a parenthesized list of sibling patterns, the first // sibling can't have a field name. else { pattern.children[0].field = None; } (pattern, pattern_start..pattern_end) } fn random_pattern_for_node(cursor: &mut TreeCursor, rng: &mut impl Rng) -> Self { let node = cursor.node(); // Sometimes specify the node's type, sometimes use a wildcard. let (kind, named) = if rng.gen_bool(0.9) { (Some(node.kind()), node.is_named()) } else { (Some("_"), node.is_named() && rng.gen_bool(0.8)) }; // Sometimes specify the node's field. let field = if rng.gen_bool(0.75) { cursor.field_name() } else { None }; // Sometimes capture the node. let capture = if rng.gen_bool(0.7) { Some(CAPTURE_NAMES[rng.gen_range(0..CAPTURE_NAMES.len())].to_string()) } else { None }; // Walk the children and include child patterns for some of them. let mut children = Vec::new(); if named && cursor.goto_first_child() { let max_children = rng.gen_range(0..4); while cursor.goto_next_sibling() { if rng.gen_bool(0.6) { let child_ast = Self::random_pattern_for_node(cursor, rng); children.push(child_ast); if children.len() >= max_children { break; } } } cursor.goto_parent(); } Self { kind, named, field, capture, children, } } pub fn to_string(&self) -> String { let mut result = String::new(); self.write_to_string(&mut result, 0); result } fn write_to_string(&self, string: &mut String, indent: usize) { if let Some(field) = self.field { write!(string, "{}: ", field).unwrap(); } if self.named { string.push('('); let mut has_contents = false; if let Some(kind) = &self.kind { write!(string, "{}", kind).unwrap(); has_contents = true; } for child in &self.children { let indent = indent + 2; if has_contents { string.push('\n'); string.push_str(&" ".repeat(indent)); } child.write_to_string(string, indent); has_contents = true; } string.push(')'); } else if self.kind == Some("_") { string.push('_'); } else { write!(string, "\"{}\"", self.kind.unwrap().replace("\"", "\\\"")).unwrap(); } if let Some(capture) = &self.capture { write!(string, " @{}", capture).unwrap(); } } pub fn matches_in_tree<'tree>(&self, tree: &'tree Tree) -> Vec> { let mut matches = Vec::new(); // Compute the matches naively: walk the tree and // retry the entire pattern for each node. let mut cursor = tree.walk(); let mut ascending = false; loop { if ascending { if cursor.goto_next_sibling() { ascending = false; } else if !cursor.goto_parent() { break; } } else { let matches_here = self.match_node(&mut cursor); matches.extend_from_slice(&matches_here); if !cursor.goto_first_child() { ascending = true; } } } matches.sort_unstable(); matches.iter_mut().for_each(|m| m.last_node = None); matches.dedup(); matches } pub fn match_node<'tree>(&self, cursor: &mut TreeCursor<'tree>) -> Vec> { let node = cursor.node(); // If a kind is specified, check that it matches the node. if let Some(kind) = self.kind { if kind == "_" { if self.named && !node.is_named() { return Vec::new(); } } else if kind != node.kind() || self.named != node.is_named() { return Vec::new(); } } // If a field is specified, check that it matches the node. if let Some(field) = self.field { if cursor.field_name() != Some(field) { return Vec::new(); } } // Create a match for the current node. let mat = Match { captures: if let Some(name) = &self.capture { vec![(name.as_str(), node)] } else { Vec::new() }, last_node: Some(node), }; // If there are no child patterns to match, then return this single match. if self.children.is_empty() { return vec![mat]; } // Find every matching combination of child patterns and child nodes. let mut finished_matches = Vec::::new(); if cursor.goto_first_child() { let mut match_states = vec![(0, mat)]; loop { let mut new_match_states = Vec::new(); for (pattern_index, mat) in &match_states { let child_pattern = &self.children[*pattern_index]; let child_matches = child_pattern.match_node(cursor); for child_match in child_matches { let mut combined_match = mat.clone(); combined_match.last_node = child_match.last_node; combined_match .captures .extend_from_slice(&child_match.captures); if pattern_index + 1 < self.children.len() { new_match_states.push((*pattern_index + 1, combined_match)); } else { let mut existing = false; for existing_match in finished_matches.iter_mut() { if existing_match.captures == combined_match.captures { if child_pattern.capture.is_some() { existing_match.last_node = combined_match.last_node; } existing = true; } } if !existing { finished_matches.push(combined_match); } } } } match_states.extend_from_slice(&new_match_states); if !cursor.goto_next_sibling() { break; } } cursor.goto_parent(); } finished_matches } } impl<'a, 'tree> PartialOrd for Match<'a, 'tree> { fn partial_cmp(&self, other: &Self) -> Option { Some(self.cmp(other)) } } impl<'a, 'tree> Ord for Match<'a, 'tree> { // Tree-sitter returns matches in the order that they terminate // during a depth-first walk of the tree. If multiple matches // terminate on the same node, those matches are produced in the // order that their captures were discovered. fn cmp(&self, other: &Self) -> Ordering { if let Some((last_node_a, last_node_b)) = self.last_node.zip(other.last_node) { let cmp = compare_depth_first(last_node_a, last_node_b); if cmp.is_ne() { return cmp; } } for (a, b) in self.captures.iter().zip(other.captures.iter()) { let cmp = compare_depth_first(a.1, b.1); if !cmp.is_eq() { return cmp; } } self.captures.len().cmp(&other.captures.len()) } } fn compare_depth_first(a: Node, b: Node) -> Ordering { let a = a.byte_range(); let b = b.byte_range(); a.start.cmp(&b.start).then_with(|| b.end.cmp(&a.end)) }