Add a randomized test for query matching

This commit is contained in:
Max Brunsfeld 2021-05-31 23:14:36 -07:00
parent 3ac53cb645
commit f69c4861c3
6 changed files with 480 additions and 55 deletions

42
Cargo.lock generated
View file

@ -162,6 +162,22 @@ dependencies = [
"lazy_static",
]
[[package]]
name = "ctor"
version = "0.1.20"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5e98e2ad1a782e33928b96fc3948e7c355e5af34ba4de7670fe8bac2a3b2006d"
dependencies = [
"quote",
"syn",
]
[[package]]
name = "diff"
version = "0.1.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0e25ea47919b1560c4e3b7fe0aaab9becf5b84a10325ddf7db0f0ba5e1026499"
[[package]]
name = "difference"
version = "2.0.0"
@ -360,6 +376,15 @@ version = "1.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "10acf907b94fc1b1a152d08ef97e7759650268cf986bf127f387e602b02c7e5a"
[[package]]
name = "output_vt100"
version = "0.1.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "53cdc5b785b7a58c5aad8216b3dfa114df64b0b06ae6e1501cef91df2fbdf8f9"
dependencies = [
"winapi",
]
[[package]]
name = "percent-encoding"
version = "2.1.0"
@ -372,6 +397,18 @@ version = "0.2.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ac74c624d6b2d21f425f752262f42188365d7b8ff1aff74c82e45136510a4857"
[[package]]
name = "pretty_assertions"
version = "0.7.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1cab0e7c02cf376875e9335e0ba1da535775beb5450d21e1dffca068818ed98b"
dependencies = [
"ansi_term 0.12.1",
"ctor",
"diff",
"output_vt100",
]
[[package]]
name = "proc-macro2"
version = "1.0.24"
@ -568,9 +605,9 @@ checksum = "8ea5119cdb4c55b55d432abb513a0429384878c15dde60cc77b1c99de1a95a6a"
[[package]]
name = "syn"
version = "1.0.60"
version = "1.0.67"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c700597eca8a5a762beb35753ef6b94df201c81cca676604f547495a0d7f0081"
checksum = "6498a9efc342871f91cc2d0d694c674368b4ceb40f62b65a7a08c3792935e702"
dependencies = [
"proc-macro2",
"quote",
@ -701,6 +738,7 @@ dependencies = [
"indexmap",
"lazy_static",
"log",
"pretty_assertions",
"rand",
"regex",
"regex-syntax",

View file

@ -76,6 +76,7 @@ features = ["std"]
[dev-dependencies]
rand = "0.8"
tempfile = "3"
pretty_assertions = "0.7.2"
[build-dependencies]
toml = "0.5"

View file

@ -1,40 +1,22 @@
use super::helpers::edits::{get_random_edit, invert_edit};
use super::helpers::fixtures::{fixtures_dir, get_language, get_test_language};
use super::helpers::random::Rand;
use super::helpers::scope_sequence::ScopeSequence;
use crate::generate;
use crate::parse::perform_edit;
use crate::test::{parse_tests, print_diff, print_diff_key, strip_sexp_fields, TestEntry};
use crate::util;
use lazy_static::lazy_static;
use std::{env, fs, time, usize};
use super::helpers::{
edits::{get_random_edit, invert_edit},
fixtures::{fixtures_dir, get_language, get_test_language},
random::Rand,
scope_sequence::ScopeSequence,
EXAMPLE_FILTER, LANGUAGE_FILTER, LOG_ENABLED, LOG_GRAPH_ENABLED, SEED, TRIAL_FILTER,
};
use crate::{
generate,
parse::perform_edit,
test::{parse_tests, print_diff, print_diff_key, strip_sexp_fields, TestEntry},
util,
};
use std::{fs, usize};
use tree_sitter::{allocations, LogType, Node, Parser, Tree};
const EDIT_COUNT: usize = 3;
const TRIAL_COUNT: usize = 10;
lazy_static! {
    // Each flag is true when the corresponding environment variable can be read.
    static ref LOG_ENABLED: bool = env::var("TREE_SITTER_TEST_ENABLE_LOG").is_ok();
    static ref LOG_GRAPH_ENABLED: bool = env::var("TREE_SITTER_TEST_ENABLE_LOG_GRAPHS").is_ok();

    // Optional filters, `None` when the variable is not set.
    static ref LANGUAGE_FILTER: Option<String> = env::var("TREE_SITTER_TEST_LANGUAGE_FILTER").ok();
    static ref EXAMPLE_FILTER: Option<String> = env::var("TREE_SITTER_TEST_EXAMPLE_FILTER").ok();
    static ref TRIAL_FILTER: Option<usize> = env::var("TREE_SITTER_TEST_TRIAL_FILTER")
        .ok()
        .map(|s| {
            s.parse::<usize>()
                .expect("TREE_SITTER_TEST_TRIAL_FILTER must be a non-negative integer")
        });

    // Seed for randomized tests: taken from `TREE_SITTER_TEST_SEED` when set,
    // otherwise derived from the current Unix time in seconds. The chosen
    // value is printed so that a failing run can be reproduced.
    pub static ref SEED: usize = {
        let seed = env::var("TREE_SITTER_TEST_SEED")
            .map(|s| {
                s.parse::<usize>()
                    .expect("TREE_SITTER_TEST_SEED must be a non-negative integer")
            })
            // Lazily evaluated: the clock is only read when no seed was supplied.
            .unwrap_or_else(|_| {
                time::SystemTime::now()
                    .duration_since(time::UNIX_EPOCH)
                    .unwrap()
                    // Truncation on narrow targets is acceptable for a seed.
                    .as_secs() as usize
            });
        eprintln!("\n\nRandom seed: {}\n", seed);
        seed
    };
}
#[test]
fn test_bash_corpus() {
test_language_corpus("bash");

View file

@ -1,4 +1,32 @@
pub(super) mod edits;
pub(super) mod fixtures;
pub(super) mod query_helpers;
pub(super) mod random;
pub(super) mod scope_sequence;
use lazy_static::lazy_static;
use std::{env, time, usize};
lazy_static! {
    // Seed for randomized tests: taken from `TREE_SITTER_TEST_SEED` when set,
    // otherwise derived from the current Unix time in seconds. The chosen
    // value is printed so that a failing run can be reproduced.
    pub static ref SEED: usize = {
        let seed = env::var("TREE_SITTER_TEST_SEED")
            .map(|s| {
                s.parse::<usize>()
                    .expect("TREE_SITTER_TEST_SEED must be a non-negative integer")
            })
            // Lazily evaluated: the clock is only read when no seed was supplied.
            .unwrap_or_else(|_| {
                time::SystemTime::now()
                    .duration_since(time::UNIX_EPOCH)
                    .unwrap()
                    // Truncation on narrow targets is acceptable for a seed.
                    .as_secs() as usize
            });
        eprintln!("\n\nRandom seed: {}\n", seed);
        seed
    };

    // Each flag is true when the corresponding environment variable can be read.
    pub static ref LOG_ENABLED: bool = env::var("TREE_SITTER_TEST_ENABLE_LOG").is_ok();
    pub static ref LOG_GRAPH_ENABLED: bool = env::var("TREE_SITTER_TEST_ENABLE_LOG_GRAPHS").is_ok();

    // Optional filters, `None` when the variable is not set.
    pub static ref LANGUAGE_FILTER: Option<String> =
        env::var("TREE_SITTER_TEST_LANGUAGE_FILTER").ok();
    pub static ref EXAMPLE_FILTER: Option<String> =
        env::var("TREE_SITTER_TEST_EXAMPLE_FILTER").ok();
    pub static ref TRIAL_FILTER: Option<usize> = env::var("TREE_SITTER_TEST_TRIAL_FILTER")
        .ok()
        .map(|s| {
            s.parse::<usize>()
                .expect("TREE_SITTER_TEST_TRIAL_FILTER must be a non-negative integer")
        });
}

View file

@ -0,0 +1,306 @@
use rand::prelude::Rng;
use std::{cmp::Ordering, fmt::Write, ops::Range};
use tree_sitter::{Node, Point, Tree, TreeCursor};
/// A query pattern in tree form.
///
/// It can be rendered as query source text with `to_string` and evaluated
/// directly against a syntax tree with `matches_in_tree`, which gives the
/// expected results to compare a real query run against.
#[derive(Debug)]
pub struct Pattern {
/// Node kind to match. `Some("_")` is a wildcard; `None` is used for the
/// wrapper that groups several sibling patterns in one pair of parentheses.
kind: Option<&'static str>,
/// Whether the pattern is written in parenthesized form. A named pattern
/// with a concrete kind or a `_` wildcard only matches named nodes.
named: bool,
/// Field name the matched node must have in its parent, if any.
field: Option<&'static str>,
/// Capture name attached to the matched node, written as `@name`.
capture: Option<String>,
/// Patterns that must match children of the matched node, in order.
children: Vec<Pattern>,
}
/// One match of a `Pattern` (or of a real query) against a tree.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct Match<'a, 'tree> {
/// `(capture name, captured node)` pairs, in the order they were found.
pub captures: Vec<(&'a str, Node<'tree>)>,
/// Last node reached by the match. It is consulted first when ordering
/// matches; `Pattern::matches_in_tree` resets it to `None` after sorting.
pub last_node: Option<Node<'tree>>,
}
/// Pool of capture names that random patterns draw from.
const CAPTURE_NAMES: &[&str] = &[
    "one", "two", "three", "four", "five", "six", "seven", "eight",
];
impl Pattern {
/// Builds a random pattern modelled on a randomly chosen node of `tree`.
///
/// Returns the pattern together with the range of source positions that
/// runs from the start of the chosen node to the end of the node the cursor
/// stopped on while collecting siblings.
pub fn random_pattern_in_tree(tree: &Tree, rng: &mut impl Rng) -> (Self, Range<Point>) {
    let mut cursor = tree.walk();

    // Descend to the node at a random byte offset and depth.
    // `max(1)` keeps the sampled range non-empty: `gen_range` panics on an
    // empty range, which a zero-length tree would otherwise produce. For any
    // non-empty tree the sampled values are unchanged.
    let mut max_depth = 0;
    let byte_offset = rng.gen_range(0..cursor.node().end_byte().max(1));
    while cursor.goto_first_child_for_byte(byte_offset).is_some() {
        max_depth += 1;
    }
    let depth = rng.gen_range(0..=max_depth);
    for _ in 0..depth {
        cursor.goto_parent();
    }

    // Build a pattern that matches that node.
    // Sometimes include subsequent siblings of the node (at most 5 roots).
    let pattern_start = cursor.node().start_position();
    let mut roots = vec![Self::random_pattern_for_node(&mut cursor, rng)];
    while roots.len() < 5 && cursor.goto_next_sibling() {
        if rng.gen_bool(0.2) {
            roots.push(Self::random_pattern_for_node(&mut cursor, rng));
        }
    }
    let pattern_end = cursor.node().end_position();

    // Wrap the roots in a kind-less group, i.e. a parenthesized sibling list.
    let mut pattern = Self {
        kind: None,
        named: true,
        field: None,
        capture: None,
        children: roots,
    };

    // A single root needs no wrapper.
    if pattern.children.len() == 1 {
        pattern = pattern.children.pop().unwrap();
    }
    // In a parenthesized list of sibling patterns, the first
    // sibling can't be an anonymous `_` wildcard. In that case the group is
    // dropped and only the last collected sibling pattern is kept.
    else if pattern.children[0].kind == Some("_") && !pattern.children[0].named {
        pattern = pattern.children.pop().unwrap();
    }
    // In a parenthesized list of sibling patterns, the first
    // sibling can't have a field name.
    else {
        pattern.children[0].field = None;
    }

    (pattern, pattern_start..pattern_end)
}
/// Builds a random pattern for the node under `cursor`, recursing into some
/// of its children. The cursor is back on the same node when this returns.
fn random_pattern_for_node(cursor: &mut TreeCursor, rng: &mut impl Rng) -> Self {
    let node = cursor.node();
    let node_is_named = node.is_named();

    // 90% of the time use the node's own kind; otherwise a `_` wildcard,
    // which stays "named" only for named nodes (and then only 80% of the time).
    let (kind, named) = if rng.gen_bool(0.9) {
        (Some(node.kind()), node_is_named)
    } else {
        (Some("_"), node_is_named && rng.gen_bool(0.8))
    };

    // 75% of the time require the node's field name (if it has one).
    let field = rng.gen_bool(0.75).then(|| cursor.field_name()).flatten();

    // 70% of the time attach a capture with a randomly picked name.
    let capture = rng
        .gen_bool(0.7)
        .then(|| CAPTURE_NAMES[rng.gen_range(0..CAPTURE_NAMES.len())].to_string());

    // Only named patterns get child patterns. The loop advances to the next
    // sibling before sampling, so the first child is never included, and the
    // limit is checked after each push, so one child can be added even when
    // the limit drawn is 0.
    let mut children = Vec::new();
    if named && cursor.goto_first_child() {
        let child_limit = rng.gen_range(0..4);
        while cursor.goto_next_sibling() {
            if !rng.gen_bool(0.6) {
                continue;
            }
            children.push(Self::random_pattern_for_node(cursor, rng));
            if children.len() >= child_limit {
                break;
            }
        }
        cursor.goto_parent();
    }

    Self {
        kind,
        named,
        field,
        capture,
        children,
    }
}
/// Renders the pattern as query source text, starting at indentation 0.
pub fn to_string(&self) -> String {
    let mut rendered = String::new();
    self.write_to_string(&mut rendered, 0);
    rendered
}
/// Appends the query source text for this pattern to `string`.
///
/// `indent` is the column of the enclosing pattern; each child pattern is
/// placed on its own line, indented two columns further.
fn write_to_string(&self, string: &mut String, indent: usize) {
    if let Some(field) = self.field {
        write!(string, "{}: ", field).unwrap();
    }

    if self.named {
        // Parenthesized form: `(kind child child ...)`. A kind-less group
        // puts its first child directly after the opening parenthesis.
        string.push('(');
        let mut has_contents = false;
        if let Some(kind) = self.kind {
            string.push_str(kind);
            has_contents = true;
        }
        let child_indent = indent + 2;
        for child in &self.children {
            if has_contents {
                string.push('\n');
                string.push_str(&" ".repeat(child_indent));
            }
            child.write_to_string(string, child_indent);
            has_contents = true;
        }
        string.push(')');
    } else if self.kind == Some("_") {
        // Anonymous wildcard.
        string.push('_');
    } else {
        // Anonymous node: a quoted string literal. Backslashes are escaped
        // before quotes so that a kind containing `\` still yields a valid
        // literal and the quote escapes are not doubled.
        let kind = self
            .kind
            .expect("an anonymous pattern always has a node kind");
        let escaped = kind.replace('\\', "\\\\").replace('"', "\\\"");
        write!(string, "\"{}\"", escaped).unwrap();
    }

    if let Some(capture) = &self.capture {
        write!(string, " @{}", capture).unwrap();
    }
}
/// Computes every match of this pattern in `tree` by brute force.
///
/// The result is sorted with `Match`'s ordering, then `last_node` is cleared
/// on every match and adjacent duplicates are removed.
pub fn matches_in_tree<'tree>(&self, tree: &'tree Tree) -> Vec<Match<'_, 'tree>> {
    let mut found = Vec::new();

    // Depth-first walk of the whole tree, retrying the entire pattern at
    // each node. `descending` is true while the cursor sits on a node that
    // has not been tried yet.
    let mut cursor = tree.walk();
    let mut descending = true;
    loop {
        if descending {
            found.extend(self.match_node(&mut cursor));
            descending = cursor.goto_first_child();
        } else if cursor.goto_next_sibling() {
            descending = true;
        } else if !cursor.goto_parent() {
            break;
        }
    }

    found.sort_unstable();
    for m in &mut found {
        m.last_node = None;
    }
    found.dedup();
    found
}
/// Returns every way this pattern matches at the cursor's current node.
///
/// The result is empty when the node's kind, namedness or field does not
/// fit, or when the pattern has child patterns that cannot all be matched,
/// in order, against children of the node. Matched children do not have to
/// be adjacent. Each `goto_first_child` is paired with a `goto_parent`, so
/// the cursor is back on the starting node on return.
pub fn match_node<'tree>(&self, cursor: &mut TreeCursor<'tree>) -> Vec<Match<'_, 'tree>> {
let node = cursor.node();
// If a kind is specified, check that it matches the node.
// A `_` wildcard only rejects unnamed nodes when the pattern itself is named;
// a concrete kind must agree on both the kind and the namedness.
if let Some(kind) = self.kind {
if kind == "_" {
if self.named && !node.is_named() {
return Vec::new();
}
} else if kind != node.kind() || self.named != node.is_named() {
return Vec::new();
}
}
// If a field is specified, check that it matches the node.
if let Some(field) = self.field {
if cursor.field_name() != Some(field) {
return Vec::new();
}
}
// Create a match for the current node.
let mat = Match {
captures: if let Some(name) = &self.capture {
vec![(name.as_str(), node)]
} else {
Vec::new()
},
last_node: Some(node),
};
// If there are no child patterns to match, then return this single match.
if self.children.is_empty() {
return vec![mat];
}
// Find every matching combination of child patterns and child nodes.
// Each state is (index of the next child pattern to match, match built so far).
// If the node has no children, nothing is found and the result stays empty.
let mut finished_matches = Vec::<Match>::new();
if cursor.goto_first_child() {
let mut match_states = vec![(0, mat)];
loop {
// States produced at this child node; they are only tried from the next sibling on.
let mut new_match_states = Vec::new();
for (pattern_index, mat) in &match_states {
let child_pattern = &self.children[*pattern_index];
let child_matches = child_pattern.match_node(cursor);
for child_match in child_matches {
// Extend the partial match with the child's captures and end position.
let mut combined_match = mat.clone();
combined_match.last_node = child_match.last_node;
combined_match
.captures
.extend_from_slice(&child_match.captures);
if pattern_index + 1 < self.children.len() {
// More child patterns remain: keep going from the next one.
new_match_states.push((*pattern_index + 1, combined_match));
} else {
// All child patterns matched. A finished match with identical captures
// is not recorded twice; its `last_node` is only moved to the new end
// position when the final child pattern has a capture.
let mut existing = false;
for existing_match in finished_matches.iter_mut() {
if existing_match.captures == combined_match.captures {
if child_pattern.capture.is_some() {
existing_match.last_node = combined_match.last_node;
}
existing = true;
}
}
if !existing {
finished_matches.push(combined_match);
}
}
}
}
// Earlier states are kept alongside the new ones, so a child pattern can
// still match a later sibling (matched children need not be adjacent).
match_states.extend_from_slice(&new_match_states);
if !cursor.goto_next_sibling() {
break;
}
}
cursor.goto_parent();
}
finished_matches
}
}
impl<'a, 'tree> PartialOrd for Match<'a, 'tree> {
    /// Delegates to the total order from `Ord`, so the result is always `Some`.
    fn partial_cmp(&self, rhs: &Self) -> Option<Ordering> {
        Some(Ord::cmp(self, rhs))
    }
}
impl<'a, 'tree> Ord for Match<'a, 'tree> {
    // Tree-sitter returns matches in the order that they terminate
    // during a depth-first walk of the tree. If multiple matches
    // terminate on the same node, those matches are produced in the
    // order that their captures were discovered.
    //
    // Comparison therefore goes: last nodes (only when both are present),
    // then captured nodes pairwise, then the number of captures.
    fn cmp(&self, other: &Self) -> Ordering {
        let by_last_node = match (self.last_node, other.last_node) {
            (Some(ours), Some(theirs)) => compare_depth_first(ours, theirs),
            _ => Ordering::Equal,
        };

        by_last_node
            .then_with(|| {
                self.captures
                    .iter()
                    .zip(other.captures.iter())
                    .map(|(ours, theirs)| compare_depth_first(ours.1, theirs.1))
                    .find(|ordering| *ordering != Ordering::Equal)
                    .unwrap_or(Ordering::Equal)
            })
            .then_with(|| self.captures.len().cmp(&other.captures.len()))
    }
}
/// Orders two nodes by start byte; when they start at the same byte, the
/// node with the larger end byte sorts first.
fn compare_depth_first(a: Node, b: Node) -> Ordering {
    let (range_a, range_b) = (a.byte_range(), b.byte_range());
    match range_a.start.cmp(&range_b.start) {
        Ordering::Equal => range_b.end.cmp(&range_a.end),
        unequal => unequal,
    }
}

View file

@ -1,7 +1,10 @@
use super::helpers::fixtures::get_language;
use super::helpers::{
fixtures::get_language,
query_helpers::{Match, Pattern},
};
use lazy_static::lazy_static;
use std::env;
use std::fmt::Write;
use rand::{prelude::StdRng, SeedableRng};
use std::{env, fmt::Write};
use tree_sitter::{
allocations, Language, Node, Parser, Point, Query, QueryCapture, QueryCursor, QueryError,
QueryErrorKind, QueryMatch, QueryPredicate, QueryPredicateArg, QueryProperty,
@ -3444,7 +3447,74 @@ fn test_query_alternative_predicate_prefix() {
}
#[test]
fn test_query_step_is_definite() {
/// Randomized check of query matching: builds random patterns from a parsed
/// Rust file, computes the expected matches with the naive matcher in
/// `query_helpers`, and compares them with what `QueryCursor` returns.
fn test_query_random() {
    use pretty_assertions::assert_eq;

    allocations::record(|| {
        let language = get_language("rust");
        let mut parser = Parser::new();
        parser.set_language(language).unwrap();
        let mut cursor = QueryCursor::new();
        cursor.set_match_limit(64);

        // One source text is used for generating patterns, for the tree that
        // is queried, and as the text handed to the query cursor, so the
        // cursor never sees text that differs from what the tree was parsed from.
        let source = include_str!("helpers/query_helpers.rs");
        let pattern_tree = parser.parse(source, None).unwrap();
        let test_tree = parser.parse(source, None).unwrap();

        // The base seed is pinned so that every run exercises the same 100
        // patterns; `*SEED` could be used here instead to vary them per run.
        let start_seed = 0;

        for i in 0..100 {
            let seed = (start_seed + i) as u64;
            let mut rand = StdRng::seed_from_u64(seed);
            let (pattern_ast, range) = Pattern::random_pattern_in_tree(&pattern_tree, &mut rand);
            let pattern = pattern_ast.to_string();
            let expected_matches = pattern_ast.matches_in_tree(&test_tree);

            eprintln!(
                "seed: {}\nsource_range: {:?}\npattern:\n{}\nexpected match count: {}\n",
                seed,
                range,
                pattern,
                expected_matches.len(),
            );

            let query = Query::new(language, &pattern).unwrap();
            let mut actual_matches = cursor
                .matches(&query, test_tree.root_node(), source.as_bytes())
                .map(|mat| Match {
                    last_node: None,
                    captures: mat
                        .captures
                        .iter()
                        .map(|c| (query.capture_names()[c.index as usize].as_str(), c.node))
                        .collect::<Vec<_>>(),
                })
                .collect::<Vec<_>>();

            // The actual matches are left in the order the cursor produced
            // them (no sorting); only adjacent duplicates are dropped.
            actual_matches.dedup();

            // When the match limit was exceeded the cursor may have dropped
            // matches, so the comparison is skipped for that pattern.
            if !cursor.did_exceed_match_limit() {
                assert_eq!(
                    actual_matches, expected_matches,
                    "seed: {}, pattern:\n{}",
                    seed, pattern
                );
            }
        }
    });
}
#[test]
fn test_query_is_pattern_guaranteed_at_step() {
struct Row {
language: Language,
description: &'static str,
@ -3454,19 +3524,19 @@ fn test_query_step_is_definite() {
let rows = &[
Row {
description: "no definite steps",
description: "no guaranteed steps",
language: get_language("python"),
pattern: r#"(expression_statement (string))"#,
results_by_substring: &[("expression_statement", false), ("string", false)],
},
Row {
description: "all definite steps",
description: "all guaranteed steps",
language: get_language("javascript"),
pattern: r#"(object "{" "}")"#,
results_by_substring: &[("object", false), ("{", true), ("}", true)],
},
Row {
description: "an indefinite step that is optional",
description: "a fallible step that is optional",
language: get_language("javascript"),
pattern: r#"(object "{" (identifier)? @foo "}")"#,
results_by_substring: &[
@ -3477,7 +3547,7 @@ fn test_query_step_is_definite() {
],
},
Row {
description: "multiple indefinite steps that are optional",
description: "multiple fallible steps that are optional",
language: get_language("javascript"),
pattern: r#"(object "{" (identifier)? @id1 ("," (identifier) @id2)? "}")"#,
results_by_substring: &[
@ -3489,13 +3559,13 @@ fn test_query_step_is_definite() {
],
},
Row {
description: "definite step after indefinite step",
description: "guaranteed step after fallible step",
language: get_language("javascript"),
pattern: r#"(pair (property_identifier) ":")"#,
results_by_substring: &[("pair", false), ("property_identifier", false), (":", true)],
},
Row {
description: "indefinite step in between two definite steps",
description: "fallible step in between two guaranteed steps",
language: get_language("javascript"),
pattern: r#"(ternary_expression
condition: (_)
@ -3512,13 +3582,13 @@ fn test_query_step_is_definite() {
],
},
Row {
description: "one definite step after a repetition",
description: "one guaranteed step after a repetition",
language: get_language("javascript"),
pattern: r#"(object "{" (_) "}")"#,
results_by_substring: &[("object", false), ("{", false), ("(_)", false), ("}", true)],
},
Row {
description: "definite steps after multiple repetitions",
description: "guaranteed steps after multiple repetitions",
language: get_language("json"),
pattern: r#"(object "{" (pair) "," (pair) "," (_) "}")"#,
results_by_substring: &[
@ -3532,7 +3602,7 @@ fn test_query_step_is_definite() {
],
},
Row {
description: "a definite with a field",
description: "a guaranteed step with a field",
language: get_language("javascript"),
pattern: r#"(binary_expression left: (identifier) right: (_))"#,
results_by_substring: &[
@ -3542,7 +3612,7 @@ fn test_query_step_is_definite() {
],
},
Row {
description: "multiple definite steps with fields",
description: "multiple guaranteed steps with fields",
language: get_language("javascript"),
pattern: r#"(function_declaration name: (identifier) body: (statement_block))"#,
results_by_substring: &[
@ -3552,7 +3622,7 @@ fn test_query_step_is_definite() {
],
},
Row {
description: "nesting, one definite step",
description: "nesting, one guaranteed step",
language: get_language("javascript"),
pattern: r#"
(function_declaration
@ -3568,7 +3638,7 @@ fn test_query_step_is_definite() {
],
},
Row {
description: "definite step after some deeply nested hidden nodes",
description: "a guaranteed step after some deeply nested hidden nodes",
language: get_language("ruby"),
pattern: r#"
(singleton_class
@ -3582,7 +3652,7 @@ fn test_query_step_is_definite() {
],
},
Row {
description: "nesting, no definite steps",
description: "nesting, no guaranteed steps",
language: get_language("javascript"),
pattern: r#"
(call_expression
@ -3593,7 +3663,7 @@ fn test_query_step_is_definite() {
results_by_substring: &[("property_identifier", false), ("template_string", false)],
},
Row {
description: "a definite step after a nested node",
description: "a guaranteed step after a nested node",
language: get_language("javascript"),
pattern: r#"
(subscript_expression
@ -3609,7 +3679,7 @@ fn test_query_step_is_definite() {
],
},
Row {
description: "a step that is indefinite due to a predicate",
description: "a step that is fallible due to a predicate",
language: get_language("javascript"),
pattern: r#"
(subscript_expression
@ -3626,7 +3696,7 @@ fn test_query_step_is_definite() {
],
},
Row {
description: "alternation where one branch has definite steps",
description: "alternation where one branch has guaranteed steps",
language: get_language("javascript"),
pattern: r#"
[
@ -3645,7 +3715,7 @@ fn test_query_step_is_definite() {
],
},
Row {
description: "aliased parent node",
description: "guaranteed step at the end of an aliased parent node",
language: get_language("ruby"),
pattern: r#"
(method_parameters "(" (identifier) @id")")