tree-sitter/crates/cli/src/tests/helpers/query_helpers.rs

366 lines
12 KiB
Rust

use std::{cmp::Ordering, fmt::Write, ops::Range};
use rand::prelude::Rng;
use streaming_iterator::{IntoStreamingIterator, StreamingIterator};
use tree_sitter::{
Language, Node, Parser, Point, Query, QueryCapture, QueryCursor, QueryMatch, Tree, TreeCursor,
};
/// A simplified model of a tree-sitter query pattern.
///
/// Patterns are generated at random from a syntax tree, rendered to query
/// source text, and matched against a tree with a naive reference algorithm.
#[derive(Debug)]
pub struct Pattern {
/// Node kind to match. `Some("_")` is the wildcard. `None` means no kind
/// check is performed; it is used for the synthetic parent that groups
/// several sibling patterns inside a bare pair of parentheses.
kind: Option<&'static str>,
/// When `true` the pattern is written in parenthesized form and (if a kind is
/// given) only matches named nodes. When `false` it is written either as a
/// bare `_` or as a quoted string literal.
named: bool,
/// Optional field name, rendered as a `field: ` prefix and required to equal
/// the cursor's field name when matching.
field: Option<&'static str>,
/// Optional capture name, rendered as a ` @name` suffix.
capture: Option<String>,
/// Patterns for child nodes, matched in order against the node's children.
children: Vec<Self>,
}
/// One match of a [`Pattern`] computed by the naive reference matcher.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct Match<'a, 'tree> {
/// `(capture name, captured node)` pairs, in the order they were recorded
/// while matching.
pub captures: Vec<(&'a str, Node<'tree>)>,
/// The node on which the match finished. It is the primary sort key in the
/// `Ord` impl, and `Pattern::matches_in_tree` resets it to `None` after
/// sorting so that it does not take part in deduplication.
pub last_node: Option<Node<'tree>>,
}
/// Pool of capture names from which `Pattern::random_pattern_for_node` picks
/// one at random whenever it decides to capture a node.
const CAPTURE_NAMES: &[&str] = &[
"one", "two", "three", "four", "five", "six", "seven", "eight",
];
impl Pattern {
/// Builds a random pattern derived from a randomly chosen node of `tree`.
///
/// A node is selected by descending towards a random byte offset and then
/// climbing back up a random number of levels. A pattern is generated for
/// that node, and each following sibling has a 20% chance of contributing a
/// further pattern (at most five patterns in total).
///
/// Returns the pattern together with the point range running from the start
/// of the selected node to the end of the last sibling the cursor visited.
///
/// An empty tree (root ending at byte 0) does not panic: the offset falls
/// back to 0 and the pattern is generated from the root node.
pub fn random_pattern_in_tree(tree: &Tree, rng: &mut impl Rng) -> (Self, Range<Point>) {
    let mut cursor = tree.walk();

    // Descend to the deepest node containing a random byte offset.
    // `gen_range` panics on an empty range, so an empty tree uses offset 0
    // without consuming any randomness.
    let end_byte = cursor.node().end_byte();
    let byte_offset = if end_byte == 0 {
        0
    } else {
        rng.gen_range(0..end_byte)
    };
    let mut max_depth = 0;
    while cursor.goto_first_child_for_byte(byte_offset).is_some() {
        max_depth += 1;
    }

    // Climb back up by a random amount so that every depth can be selected.
    let depth = rng.gen_range(0..=max_depth);
    for _ in 0..depth {
        cursor.goto_parent();
    }

    // Build a pattern that matches that node.
    // Sometimes include subsequent siblings of the node.
    let pattern_start = cursor.node().start_position();
    let mut roots = vec![Self::random_pattern_for_node(&mut cursor, rng)];
    while roots.len() < 5 && cursor.goto_next_sibling() {
        if rng.gen_bool(0.2) {
            roots.push(Self::random_pattern_for_node(&mut cursor, rng));
        }
    }
    let pattern_end = cursor.node().end_position();

    // Wrap the sibling patterns in a kind-less parent, which renders as a
    // bare parenthesized list.
    let mut pattern = Self {
        kind: None,
        named: true,
        field: None,
        capture: None,
        children: roots,
    };

    if pattern.children.len() == 1 ||
        // In a parenthesized list of sibling patterns, the first
        // sibling can't be an anonymous `_` wildcard.
        (pattern.children[0].kind == Some("_") && !pattern.children[0].named)
    {
        // Fall back to a single pattern (the last one generated).
        pattern = pattern.children.pop().unwrap();
    }
    // In a parenthesized list of sibling patterns, the first
    // sibling can't have a field name.
    else {
        pattern.children[0].field = None;
    }

    (pattern, pattern_start..pattern_end)
}
/// Generates a random pattern for the node under `cursor`, recursing into
/// some of its children. The cursor is back on the same node on return.
fn random_pattern_for_node(cursor: &mut TreeCursor, rng: &mut impl Rng) -> Self {
    let node = cursor.node();

    // 90% of the time name the node's own kind; otherwise emit a `_` wildcard.
    // A wildcard for a named node is itself written in named form 80% of the
    // time; a wildcard for an anonymous node is always anonymous.
    let use_own_kind = rng.gen_bool(0.9);
    let kind = Some(if use_own_kind { node.kind() } else { "_" });
    let named = node.is_named() && (use_own_kind || rng.gen_bool(0.8));

    // 75% of the time keep the node's field name (if it has one).
    let field = rng.gen_bool(0.75).then(|| cursor.field_name()).flatten();

    // 70% of the time attach a randomly chosen capture name.
    let capture = rng
        .gen_bool(0.7)
        .then(|| CAPTURE_NAMES[rng.gen_range(0..CAPTURE_NAMES.len())].to_string());

    // Only patterns written in named (parenthesized) form get child patterns.
    let mut children = Vec::new();
    if named && cursor.goto_first_child() {
        let max_children = rng.gen_range(0..4);
        // NOTE(review): the cursor is advanced with `goto_next_sibling` before
        // any child is considered, so the first child never contributes a
        // pattern. The limit is checked only after a push, so a
        // `max_children` of 0 still admits one child.
        while cursor.goto_next_sibling() {
            if !rng.gen_bool(0.6) {
                continue;
            }
            children.push(Self::random_pattern_for_node(cursor, rng));
            if children.len() >= max_children {
                break;
            }
        }
        cursor.goto_parent();
    }

    Self {
        kind,
        named,
        field,
        capture,
        children,
    }
}
/// Appends this pattern, rendered as tree-sitter query source, to `string`.
///
/// `indent` is the column at which this pattern starts; each nested child is
/// placed on its own line, indented two columns further than its parent.
///
/// # Panics
///
/// Panics if the pattern is anonymous (`named == false`) but has no `kind`.
fn write_to_string(&self, string: &mut String, indent: usize) {
    if let Some(field) = self.field {
        write!(string, "{field}: ").unwrap();
    }

    if self.named {
        string.push('(');
        let mut has_contents = false;
        if let Some(kind) = self.kind {
            string.push_str(kind);
            has_contents = true;
        }
        let child_indent = indent + 2;
        for child in &self.children {
            // The first item of a kind-less group stays on the opening line.
            if has_contents {
                string.push('\n');
                string.push_str(&" ".repeat(child_indent));
            }
            child.write_to_string(string, child_indent);
            has_contents = true;
        }
        string.push(')');
    } else if self.kind == Some("_") {
        string.push('_');
    } else {
        // Anonymous nodes are written as string literals. Backslashes must be
        // escaped before quotes; otherwise a kind such as `\` would render as
        // `"\"`, which the query parser reads as an unterminated string.
        let kind = self.kind.expect("anonymous pattern must have a kind");
        let escaped = kind.replace('\\', "\\\\").replace('"', "\\\"");
        write!(string, "\"{escaped}\"").unwrap();
    }

    if let Some(capture) = &self.capture {
        write!(string, " @{capture}").unwrap();
    }
}
/// Computes every match of this pattern in `tree` with a naive algorithm:
/// the whole pattern is retried at each node of a depth-first walk.
///
/// The result is sorted with `Match`'s `Ord` impl; `last_node` is then
/// cleared on every match and adjacent duplicates are removed.
pub fn matches_in_tree<'tree>(&self, tree: &'tree Tree) -> Vec<Match<'_, 'tree>> {
    let mut found = Vec::new();
    let mut cursor = tree.walk();

    'walk: loop {
        // Try the pattern at the current node, then keep descending.
        found.extend(self.match_node(&mut cursor));
        if cursor.goto_first_child() {
            continue;
        }
        // No children: climb until a next sibling exists, or stop at the root.
        loop {
            if cursor.goto_next_sibling() {
                break;
            }
            if !cursor.goto_parent() {
                break 'walk;
            }
        }
    }

    found.sort_unstable();
    // `last_node` only matters for ordering; drop it so that it does not
    // prevent otherwise identical matches from being deduplicated.
    found.iter_mut().for_each(|m| m.last_node = None);
    found.dedup();
    found
}
/// Returns every way in which this pattern matches with the cursor's current
/// node as the pattern's root node.
///
/// An empty vector means the pattern does not match here. Every cursor move
/// into the children is paired with a `goto_parent`, so the cursor is back on
/// the starting node when the function returns.
pub fn match_node<'tree>(&self, cursor: &mut TreeCursor<'tree>) -> Vec<Match<'_, 'tree>> {
let node = cursor.node();
// If a kind is specified, check that it matches the node. A named wildcard
// requires a named node, an anonymous wildcard accepts any node, and a
// concrete kind must agree on both the kind string and named-ness.
if let Some(kind) = self.kind {
if kind == "_" {
if self.named && !node.is_named() {
return Vec::new();
}
} else if kind != node.kind() || self.named != node.is_named() {
return Vec::new();
}
}
// If a field is specified, check that it matches the node.
if let Some(field) = self.field {
if cursor.field_name() != Some(field) {
return Vec::new();
}
}
// Create a match for the current node, holding this pattern's own capture
// (if it has one).
let mat = Match {
captures: self
.capture
.as_ref()
.map_or_else(Vec::new, |name| vec![(name.as_str(), node)]),
last_node: Some(node),
};
// If there are no child patterns to match, then return this single match.
if self.children.is_empty() {
return vec![mat];
}
// Find every matching combination of child patterns and child nodes.
// If the node has no children, no match is produced.
let mut finished_matches = Vec::<Match>::new();
if cursor.goto_first_child() {
// Each state pairs the index of the next child pattern to satisfy with
// the partial match accumulated so far.
let mut match_states = vec![(0, mat)];
loop {
// States created at this child node are collected separately and only
// appended after the loop over the existing states.
let mut new_match_states = Vec::new();
for (pattern_index, mat) in &match_states {
let child_pattern = &self.children[*pattern_index];
let child_matches = child_pattern.match_node(cursor);
for child_match in child_matches {
// Extend the partial match with this child's captures; the match
// now ends wherever the child match ended.
let mut combined_match = mat.clone();
combined_match.last_node = child_match.last_node;
combined_match
.captures
.extend_from_slice(&child_match.captures);
if pattern_index + 1 < self.children.len() {
new_match_states.push((*pattern_index + 1, combined_match));
} else {
// All child patterns are satisfied. Finished matches with identical
// captures are merged into one; the existing entry's `last_node` is
// overwritten only when the final child pattern has a capture.
let mut existing = false;
for existing_match in &mut finished_matches {
if existing_match.captures == combined_match.captures {
if child_pattern.capture.is_some() {
existing_match.last_node = combined_match.last_node;
}
existing = true;
}
}
if !existing {
finished_matches.push(combined_match);
}
}
}
}
// Existing states are kept, so each child pattern may still match a later
// sibling (child nodes may be skipped between matched children).
match_states.extend_from_slice(&new_match_states);
if !cursor.goto_next_sibling() {
break;
}
}
cursor.goto_parent();
}
finished_matches
}
}
impl std::fmt::Display for Pattern {
    /// Writes the pattern as tree-sitter query source text, starting at
    /// indentation column 0.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut rendered = String::new();
        self.write_to_string(&mut rendered, 0);
        f.write_str(&rendered)
    }
}
impl PartialOrd for Match<'_, '_> {
    /// Delegates to the total order defined by the `Ord` impl, so the result
    /// is always `Some`.
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(Ord::cmp(self, other))
    }
}
impl Ord for Match<'_, '_> {
    // Tree-sitter yields matches in the order in which they terminate during
    // a depth-first walk of the tree. Matches that terminate on the same node
    // are yielded in the order in which their captures were discovered.
    //
    // The comparison therefore uses, in order: the terminating nodes (only
    // when both matches still have one), the captured nodes pairwise, and
    // finally the number of captures.
    fn cmp(&self, other: &Self) -> Ordering {
        let by_last_node = match (self.last_node, other.last_node) {
            (Some(ours), Some(theirs)) => compare_depth_first(ours, theirs),
            _ => Ordering::Equal,
        };

        by_last_node
            .then_with(|| {
                self.captures
                    .iter()
                    .zip(&other.captures)
                    .map(|(ours, theirs)| compare_depth_first(ours.1, theirs.1))
                    .find(|ordering| ordering.is_ne())
                    .unwrap_or(Ordering::Equal)
            })
            .then_with(|| self.captures.len().cmp(&other.captures.len()))
    }
}
fn compare_depth_first(a: Node, b: Node) -> Ordering {
let a = a.byte_range();
let b = b.byte_range();
a.start.cmp(&b.start).then_with(|| b.end.cmp(&a.end))
}
/// Parses `source` with `language`, runs `query` over the whole tree, and
/// asserts that the collected matches equal `expected`.
///
/// Each expected entry is a pattern index plus that match's
/// `(capture name, captured text)` pairs. The function also asserts that the
/// query cursor did not exceed its match limit.
///
/// # Panics
///
/// Panics if the language cannot be set, if parsing yields no tree, or if
/// either assertion fails.
pub fn assert_query_matches(
    language: &Language,
    query: &Query,
    source: &str,
    expected: &[(usize, Vec<(&str, &str)>)],
) {
    let tree = {
        let mut parser = Parser::new();
        parser.set_language(language).unwrap();
        parser.parse(source, None).unwrap()
    };

    let mut query_cursor = QueryCursor::new();
    let found = query_cursor.matches(query, tree.root_node(), source.as_bytes());
    pretty_assertions::assert_eq!(expected, collect_matches(found, query, source));
    pretty_assertions::assert_eq!(false, query_cursor.did_exceed_match_limit());
}
/// Drains `matches` into a vector of `(pattern index, captures)` entries,
/// where the captures are `(capture name, captured source text)` pairs.
pub fn collect_matches<'a>(
    matches: impl StreamingIterator<Item = QueryMatch<'a, 'a>>,
    query: &'a Query,
    source: &'a str,
) -> Vec<(usize, Vec<(&'a str, &'a str)>)> {
    matches.fold(Vec::new(), |mut collected, query_match| {
        let captures = query_match.captures.iter().into_streaming_iter_ref();
        collected.push((
            query_match.pattern_index,
            format_captures(captures, query, source),
        ));
        collected
    })
}
/// Drains a stream of `(match, capture index)` items, selecting the indexed
/// capture from each match, and returns them as
/// `(capture name, captured source text)` pairs.
pub fn collect_captures<'a>(
    captures: impl StreamingIterator<Item = (QueryMatch<'a, 'a>, usize)>,
    query: &'a Query,
    source: &'a str,
) -> Vec<(&'a str, &'a str)> {
    let selected = captures.map(|entry| {
        let (query_match, capture_index) = entry;
        query_match.captures[*capture_index]
    });
    format_captures(selected, query, source)
}
/// Converts each capture into a `(capture name, captured source text)` pair.
///
/// # Panics
///
/// Panics if a captured node's text is not valid UTF-8.
fn format_captures<'a>(
    captures: impl StreamingIterator<Item = QueryCapture<'a>>,
    query: &'a Query,
    source: &'a str,
) -> Vec<(&'a str, &'a str)> {
    let names = query.capture_names();
    let bytes = source.as_bytes();

    let mut formatted = Vec::new();
    captures.for_each(|capture| {
        let name = names[capture.index as usize];
        let text = capture.node.utf8_text(bytes).unwrap();
        formatted.push((name, text));
    });
    formatted
}