diff --git a/Cargo.lock b/Cargo.lock index 52d02cc4..7b3e299c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -162,6 +162,22 @@ dependencies = [ "lazy_static", ] +[[package]] +name = "ctor" +version = "0.1.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5e98e2ad1a782e33928b96fc3948e7c355e5af34ba4de7670fe8bac2a3b2006d" +dependencies = [ + "quote", + "syn", +] + +[[package]] +name = "diff" +version = "0.1.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0e25ea47919b1560c4e3b7fe0aaab9becf5b84a10325ddf7db0f0ba5e1026499" + [[package]] name = "difference" version = "2.0.0" @@ -360,6 +376,15 @@ version = "1.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "10acf907b94fc1b1a152d08ef97e7759650268cf986bf127f387e602b02c7e5a" +[[package]] +name = "output_vt100" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "53cdc5b785b7a58c5aad8216b3dfa114df64b0b06ae6e1501cef91df2fbdf8f9" +dependencies = [ + "winapi", +] + [[package]] name = "percent-encoding" version = "2.1.0" @@ -372,6 +397,18 @@ version = "0.2.10" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ac74c624d6b2d21f425f752262f42188365d7b8ff1aff74c82e45136510a4857" +[[package]] +name = "pretty_assertions" +version = "0.7.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1cab0e7c02cf376875e9335e0ba1da535775beb5450d21e1dffca068818ed98b" +dependencies = [ + "ansi_term 0.12.1", + "ctor", + "diff", + "output_vt100", +] + [[package]] name = "proc-macro2" version = "1.0.24" @@ -568,9 +605,9 @@ checksum = "8ea5119cdb4c55b55d432abb513a0429384878c15dde60cc77b1c99de1a95a6a" [[package]] name = "syn" -version = "1.0.60" +version = "1.0.67" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c700597eca8a5a762beb35753ef6b94df201c81cca676604f547495a0d7f0081" +checksum = 
"6498a9efc342871f91cc2d0d694c674368b4ceb40f62b65a7a08c3792935e702" dependencies = [ "proc-macro2", "quote", @@ -701,6 +738,7 @@ dependencies = [ "indexmap", "lazy_static", "log", + "pretty_assertions", "rand", "regex", "regex-syntax", diff --git a/cli/Cargo.toml b/cli/Cargo.toml index e559842f..2a0d4927 100644 --- a/cli/Cargo.toml +++ b/cli/Cargo.toml @@ -76,6 +76,7 @@ features = ["std"] [dev-dependencies] rand = "0.8" tempfile = "3" +pretty_assertions = "0.7.2" [build-dependencies] toml = "0.5" diff --git a/cli/src/tests/corpus_test.rs b/cli/src/tests/corpus_test.rs index 5699f063..d2e586de 100644 --- a/cli/src/tests/corpus_test.rs +++ b/cli/src/tests/corpus_test.rs @@ -1,40 +1,22 @@ -use super::helpers::edits::{get_random_edit, invert_edit}; -use super::helpers::fixtures::{fixtures_dir, get_language, get_test_language}; -use super::helpers::random::Rand; -use super::helpers::scope_sequence::ScopeSequence; -use crate::generate; -use crate::parse::perform_edit; -use crate::test::{parse_tests, print_diff, print_diff_key, strip_sexp_fields, TestEntry}; -use crate::util; -use lazy_static::lazy_static; -use std::{env, fs, time, usize}; +use super::helpers::{ + edits::{get_random_edit, invert_edit}, + fixtures::{fixtures_dir, get_language, get_test_language}, + random::Rand, + scope_sequence::ScopeSequence, + EXAMPLE_FILTER, LANGUAGE_FILTER, LOG_ENABLED, LOG_GRAPH_ENABLED, SEED, TRIAL_FILTER, +}; +use crate::{ + generate, + parse::perform_edit, + test::{parse_tests, print_diff, print_diff_key, strip_sexp_fields, TestEntry}, + util, +}; +use std::{fs, usize}; use tree_sitter::{allocations, LogType, Node, Parser, Tree}; const EDIT_COUNT: usize = 3; const TRIAL_COUNT: usize = 10; -lazy_static! 
{ - static ref LOG_ENABLED: bool = env::var("TREE_SITTER_TEST_ENABLE_LOG").is_ok(); - static ref LOG_GRAPH_ENABLED: bool = env::var("TREE_SITTER_TEST_ENABLE_LOG_GRAPHS").is_ok(); - static ref LANGUAGE_FILTER: Option = env::var("TREE_SITTER_TEST_LANGUAGE_FILTER").ok(); - static ref EXAMPLE_FILTER: Option = env::var("TREE_SITTER_TEST_EXAMPLE_FILTER").ok(); - static ref TRIAL_FILTER: Option = env::var("TREE_SITTER_TEST_TRIAL_FILTER") - .map(|s| usize::from_str_radix(&s, 10).unwrap()) - .ok(); - pub static ref SEED: usize = { - let seed = env::var("TREE_SITTER_TEST_SEED") - .map(|s| usize::from_str_radix(&s, 10).unwrap()) - .unwrap_or( - time::SystemTime::now() - .duration_since(time::UNIX_EPOCH) - .unwrap() - .as_secs() as usize, - ); - eprintln!("\n\nRandom seed: {}\n", seed); - seed - }; -} - #[test] fn test_bash_corpus() { test_language_corpus("bash"); diff --git a/cli/src/tests/helpers/mod.rs b/cli/src/tests/helpers/mod.rs index 3a75dad3..e492a42e 100644 --- a/cli/src/tests/helpers/mod.rs +++ b/cli/src/tests/helpers/mod.rs @@ -1,4 +1,32 @@ pub(super) mod edits; pub(super) mod fixtures; +pub(super) mod query_helpers; pub(super) mod random; pub(super) mod scope_sequence; + +use lazy_static::lazy_static; +use std::{env, time, usize}; + +lazy_static! 
{ + pub static ref SEED: usize = { + let seed = env::var("TREE_SITTER_TEST_SEED") + .map(|s| usize::from_str_radix(&s, 10).unwrap()) + .unwrap_or( + time::SystemTime::now() + .duration_since(time::UNIX_EPOCH) + .unwrap() + .as_secs() as usize, + ); + eprintln!("\n\nRandom seed: {}\n", seed); + seed + }; + pub static ref LOG_ENABLED: bool = env::var("TREE_SITTER_TEST_ENABLE_LOG").is_ok(); + pub static ref LOG_GRAPH_ENABLED: bool = env::var("TREE_SITTER_TEST_ENABLE_LOG_GRAPHS").is_ok(); + pub static ref LANGUAGE_FILTER: Option = + env::var("TREE_SITTER_TEST_LANGUAGE_FILTER").ok(); + pub static ref EXAMPLE_FILTER: Option = + env::var("TREE_SITTER_TEST_EXAMPLE_FILTER").ok(); + pub static ref TRIAL_FILTER: Option = env::var("TREE_SITTER_TEST_TRIAL_FILTER") + .map(|s| usize::from_str_radix(&s, 10).unwrap()) + .ok(); +} diff --git a/cli/src/tests/helpers/query_helpers.rs b/cli/src/tests/helpers/query_helpers.rs new file mode 100644 index 00000000..78ae559c --- /dev/null +++ b/cli/src/tests/helpers/query_helpers.rs @@ -0,0 +1,306 @@ +use rand::prelude::Rng; +use std::{cmp::Ordering, fmt::Write, ops::Range}; +use tree_sitter::{Node, Point, Tree, TreeCursor}; + +#[derive(Debug)] +pub struct Pattern { + kind: Option<&'static str>, + named: bool, + field: Option<&'static str>, + capture: Option, + children: Vec, +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct Match<'a, 'tree> { + pub captures: Vec<(&'a str, Node<'tree>)>, + pub last_node: Option>, +} + +const CAPTURE_NAMES: &'static [&'static str] = &[ + "one", "two", "three", "four", "five", "six", "seven", "eight", +]; + +impl Pattern { + pub fn random_pattern_in_tree(tree: &Tree, rng: &mut impl Rng) -> (Self, Range) { + let mut cursor = tree.walk(); + + // Descend to the node at a random byte offset and depth. 
+ let mut max_depth = 0; + let byte_offset = rng.gen_range(0..cursor.node().end_byte()); + while cursor.goto_first_child_for_byte(byte_offset).is_some() { + max_depth += 1; + } + let depth = rng.gen_range(0..=max_depth); + for _ in 0..depth { + cursor.goto_parent(); + } + + // Build a pattern that matches that node. + // Sometimes include subsequent siblings of the node. + let pattern_start = cursor.node().start_position(); + let mut roots = vec![Self::random_pattern_for_node(&mut cursor, rng)]; + while roots.len() < 5 && cursor.goto_next_sibling() { + if rng.gen_bool(0.2) { + roots.push(Self::random_pattern_for_node(&mut cursor, rng)); + } + } + let pattern_end = cursor.node().end_position(); + + let mut pattern = Self { + kind: None, + named: true, + field: None, + capture: None, + children: roots, + }; + + if pattern.children.len() == 1 { + pattern = pattern.children.pop().unwrap(); + } + // In a parenthesized list of sibling patterns, the first + // sibling can't be an anonymous `_` wildcard. + else if pattern.children[0].kind == Some("_") && !pattern.children[0].named { + pattern = pattern.children.pop().unwrap(); + } + // In a parenthesized list of sibling patterns, the first + // sibling can't have a field name. + else { + pattern.children[0].field = None; + } + + (pattern, pattern_start..pattern_end) + } + + fn random_pattern_for_node(cursor: &mut TreeCursor, rng: &mut impl Rng) -> Self { + let node = cursor.node(); + + // Sometimes specify the node's type, sometimes use a wildcard. + let (kind, named) = if rng.gen_bool(0.9) { + (Some(node.kind()), node.is_named()) + } else { + (Some("_"), node.is_named() && rng.gen_bool(0.8)) + }; + + // Sometimes specify the node's field. + let field = if rng.gen_bool(0.75) { + cursor.field_name() + } else { + None + }; + + // Sometimes capture the node. 
+ let capture = if rng.gen_bool(0.7) { + Some(CAPTURE_NAMES[rng.gen_range(0..CAPTURE_NAMES.len())].to_string()) + } else { + None + }; + + // Walk the children and include child patterns for some of them. + let mut children = Vec::new(); + if named && cursor.goto_first_child() { + let max_children = rng.gen_range(0..4); + while cursor.goto_next_sibling() { + if rng.gen_bool(0.6) { + let child_ast = Self::random_pattern_for_node(cursor, rng); + children.push(child_ast); + if children.len() >= max_children { + break; + } + } + } + cursor.goto_parent(); + } + + Self { + kind, + named, + field, + capture, + children, + } + } + + pub fn to_string(&self) -> String { + let mut result = String::new(); + self.write_to_string(&mut result, 0); + result + } + + fn write_to_string(&self, string: &mut String, indent: usize) { + if let Some(field) = self.field { + write!(string, "{}: ", field).unwrap(); + } + + if self.named { + string.push('('); + let mut has_contents = false; + if let Some(kind) = &self.kind { + write!(string, "{}", kind).unwrap(); + has_contents = true; + } + for child in &self.children { + let indent = indent + 2; + if has_contents { + string.push('\n'); + string.push_str(&" ".repeat(indent)); + } + child.write_to_string(string, indent); + has_contents = true; + } + string.push(')'); + } else if self.kind == Some("_") { + string.push('_'); + } else { + write!(string, "\"{}\"", self.kind.unwrap().replace("\"", "\\\"")).unwrap(); + } + + if let Some(capture) = &self.capture { + write!(string, " @{}", capture).unwrap(); + } + } + + pub fn matches_in_tree<'tree>(&self, tree: &'tree Tree) -> Vec> { + let mut matches = Vec::new(); + + // Compute the matches naively: walk the tree and + // retry the entire pattern for each node. 
+ let mut cursor = tree.walk(); + let mut ascending = false; + loop { + if ascending { + if cursor.goto_next_sibling() { + ascending = false; + } else if !cursor.goto_parent() { + break; + } + } else { + let matches_here = self.match_node(&mut cursor); + matches.extend_from_slice(&matches_here); + if !cursor.goto_first_child() { + ascending = true; + } + } + } + + matches.sort_unstable(); + matches.iter_mut().for_each(|m| m.last_node = None); + matches.dedup(); + matches + } + + pub fn match_node<'tree>(&self, cursor: &mut TreeCursor<'tree>) -> Vec> { + let node = cursor.node(); + + // If a kind is specified, check that it matches the node. + if let Some(kind) = self.kind { + if kind == "_" { + if self.named && !node.is_named() { + return Vec::new(); + } + } else if kind != node.kind() || self.named != node.is_named() { + return Vec::new(); + } + } + + // If a field is specified, check that it matches the node. + if let Some(field) = self.field { + if cursor.field_name() != Some(field) { + return Vec::new(); + } + } + + // Create a match for the current node. + let mat = Match { + captures: if let Some(name) = &self.capture { + vec![(name.as_str(), node)] + } else { + Vec::new() + }, + last_node: Some(node), + }; + + // If there are no child patterns to match, then return this single match. + if self.children.is_empty() { + return vec![mat]; + } + + // Find every matching combination of child patterns and child nodes. 
+ let mut finished_matches = Vec::::new(); + if cursor.goto_first_child() { + let mut match_states = vec![(0, mat)]; + loop { + let mut new_match_states = Vec::new(); + for (pattern_index, mat) in &match_states { + let child_pattern = &self.children[*pattern_index]; + let child_matches = child_pattern.match_node(cursor); + for child_match in child_matches { + let mut combined_match = mat.clone(); + combined_match.last_node = child_match.last_node; + combined_match + .captures + .extend_from_slice(&child_match.captures); + if pattern_index + 1 < self.children.len() { + new_match_states.push((*pattern_index + 1, combined_match)); + } else { + let mut existing = false; + for existing_match in finished_matches.iter_mut() { + if existing_match.captures == combined_match.captures { + if child_pattern.capture.is_some() { + existing_match.last_node = combined_match.last_node; + } + existing = true; + } + } + if !existing { + finished_matches.push(combined_match); + } + } + } + } + match_states.extend_from_slice(&new_match_states); + if !cursor.goto_next_sibling() { + break; + } + } + cursor.goto_parent(); + } + finished_matches + } +} + +impl<'a, 'tree> PartialOrd for Match<'a, 'tree> { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +impl<'a, 'tree> Ord for Match<'a, 'tree> { + // Tree-sitter returns matches in the order that they terminate + // during a depth-first walk of the tree. If multiple matches + // terminate on the same node, those matches are produced in the + // order that their captures were discovered. 
+ fn cmp(&self, other: &Self) -> Ordering { + if let Some((last_node_a, last_node_b)) = self.last_node.zip(other.last_node) { + let cmp = compare_depth_first(last_node_a, last_node_b); + if cmp.is_ne() { + return cmp; + } + } + + for (a, b) in self.captures.iter().zip(other.captures.iter()) { + let cmp = compare_depth_first(a.1, b.1); + if !cmp.is_eq() { + return cmp; + } + } + + self.captures.len().cmp(&other.captures.len()) + } +} + +fn compare_depth_first(a: Node, b: Node) -> Ordering { + let a = a.byte_range(); + let b = b.byte_range(); + a.start.cmp(&b.start).then_with(|| b.end.cmp(&a.end)) +} diff --git a/cli/src/tests/query_test.rs b/cli/src/tests/query_test.rs index 6b28cdd5..d552d422 100644 --- a/cli/src/tests/query_test.rs +++ b/cli/src/tests/query_test.rs @@ -1,7 +1,10 @@ -use super::helpers::fixtures::get_language; +use super::helpers::{ + fixtures::get_language, + query_helpers::{Match, Pattern}, +}; use lazy_static::lazy_static; -use std::env; -use std::fmt::Write; +use rand::{prelude::StdRng, SeedableRng}; +use std::{env, fmt::Write}; use tree_sitter::{ allocations, Language, Node, Parser, Point, Query, QueryCapture, QueryCursor, QueryError, QueryErrorKind, QueryMatch, QueryPredicate, QueryPredicateArg, QueryProperty, @@ -3444,7 +3447,74 @@ fn test_query_alternative_predicate_prefix() { } #[test] -fn test_query_step_is_definite() { +fn test_query_random() { + use pretty_assertions::assert_eq; + + allocations::record(|| { + let language = get_language("rust"); + let mut parser = Parser::new(); + parser.set_language(language).unwrap(); + let mut cursor = QueryCursor::new(); + cursor.set_match_limit(64); + + let pattern_tree = parser + .parse(include_str!("helpers/query_helpers.rs"), None) + .unwrap(); + let test_tree = parser + .parse(include_str!("helpers/query_helpers.rs"), None) + .unwrap(); + + // let start_seed = *SEED; + let start_seed = 0; + + for i in 0..100 { + let seed = (start_seed + i) as u64; + let mut rand = StdRng::seed_from_u64(seed); + 
let (pattern_ast, range) = Pattern::random_pattern_in_tree(&pattern_tree, &mut rand); + let pattern = pattern_ast.to_string(); + let expected_matches = pattern_ast.matches_in_tree(&test_tree); + + eprintln!( + "seed: {}\nsource_range: {:?}\npattern:\n{}\nexpected match count: {}\n", + seed, + range, + pattern, + expected_matches.len(), + ); + + let query = Query::new(language, &pattern).unwrap(); + let mut actual_matches = cursor + .matches( + &query, + test_tree.root_node(), + (include_str!("parser_test.rs")).as_bytes(), + ) + .map(|mat| Match { + last_node: None, + captures: mat + .captures + .iter() + .map(|c| (query.capture_names()[c.index as usize].as_str(), c.node)) + .collect::>(), + }) + .collect::>(); + + // actual_matches.sort_unstable(); + actual_matches.dedup(); + + if !cursor.did_exceed_match_limit() { + assert_eq!( + actual_matches, expected_matches, + "seed: {}, pattern:\n{}", + seed, pattern + ); + } + } + }); +} + +#[test] +fn test_query_is_pattern_guaranteed_at_step() { struct Row { language: Language, description: &'static str, @@ -3454,19 +3524,19 @@ fn test_query_step_is_definite() { let rows = &[ Row { - description: "no definite steps", + description: "no guaranteed steps", language: get_language("python"), pattern: r#"(expression_statement (string))"#, results_by_substring: &[("expression_statement", false), ("string", false)], }, Row { - description: "all definite steps", + description: "all guaranteed steps", language: get_language("javascript"), pattern: r#"(object "{" "}")"#, results_by_substring: &[("object", false), ("{", true), ("}", true)], }, Row { - description: "an indefinite step that is optional", + description: "a fallible step that is optional", language: get_language("javascript"), pattern: r#"(object "{" (identifier)? 
@foo "}")"#, results_by_substring: &[ @@ -3477,7 +3547,7 @@ fn test_query_step_is_definite() { ], }, Row { - description: "multiple indefinite steps that are optional", + description: "multiple fallible steps that are optional", language: get_language("javascript"), pattern: r#"(object "{" (identifier)? @id1 ("," (identifier) @id2)? "}")"#, results_by_substring: &[ @@ -3489,13 +3559,13 @@ fn test_query_step_is_definite() { ], }, Row { - description: "definite step after indefinite step", + description: "guaranteed step after fallible step", language: get_language("javascript"), pattern: r#"(pair (property_identifier) ":")"#, results_by_substring: &[("pair", false), ("property_identifier", false), (":", true)], }, Row { - description: "indefinite step in between two definite steps", + description: "fallible step in between two guaranteed steps", language: get_language("javascript"), pattern: r#"(ternary_expression condition: (_) @@ -3512,13 +3582,13 @@ fn test_query_step_is_definite() { ], }, Row { - description: "one definite step after a repetition", + description: "one guaranteed step after a repetition", language: get_language("javascript"), pattern: r#"(object "{" (_) "}")"#, results_by_substring: &[("object", false), ("{", false), ("(_)", false), ("}", true)], }, Row { - description: "definite steps after multiple repetitions", + description: "guaranteed steps after multiple repetitions", language: get_language("json"), pattern: r#"(object "{" (pair) "," (pair) "," (_) "}")"#, results_by_substring: &[ @@ -3532,7 +3602,7 @@ fn test_query_step_is_definite() { ], }, Row { - description: "a definite with a field", + description: "a guaranteed step with a field", language: get_language("javascript"), pattern: r#"(binary_expression left: (identifier) right: (_))"#, results_by_substring: &[ @@ -3542,7 +3612,7 @@ fn test_query_step_is_definite() { ], }, Row { - description: "multiple definite steps with fields", + description: "multiple guaranteed steps with fields", 
language: get_language("javascript"), pattern: r#"(function_declaration name: (identifier) body: (statement_block))"#, results_by_substring: &[ @@ -3552,7 +3622,7 @@ fn test_query_step_is_definite() { ], }, Row { - description: "nesting, one definite step", + description: "nesting, one guaranteed step", language: get_language("javascript"), pattern: r#" (function_declaration @@ -3568,7 +3638,7 @@ fn test_query_step_is_definite() { ], }, Row { - description: "definite step after some deeply nested hidden nodes", + description: "a guaranteed step after some deeply nested hidden nodes", language: get_language("ruby"), pattern: r#" (singleton_class @@ -3582,7 +3652,7 @@ fn test_query_step_is_definite() { ], }, Row { - description: "nesting, no definite steps", + description: "nesting, no guaranteed steps", language: get_language("javascript"), pattern: r#" (call_expression @@ -3593,7 +3663,7 @@ fn test_query_step_is_definite() { results_by_substring: &[("property_identifier", false), ("template_string", false)], }, Row { - description: "a definite step after a nested node", + description: "a guaranteed step after a nested node", language: get_language("javascript"), pattern: r#" (subscript_expression @@ -3609,7 +3679,7 @@ fn test_query_step_is_definite() { ], }, Row { - description: "a step that is indefinite due to a predicate", + description: "a step that is fallible due to a predicate", language: get_language("javascript"), pattern: r#" (subscript_expression @@ -3626,7 +3696,7 @@ fn test_query_step_is_definite() { ], }, Row { - description: "alternation where one branch has definite steps", + description: "alternation where one branch has guaranteed steps", language: get_language("javascript"), pattern: r#" [ @@ -3645,7 +3715,7 @@ fn test_query_step_is_definite() { ], }, Row { - description: "aliased parent node", + description: "guaranteed step at the end of an aliased parent node", language: get_language("ruby"), pattern: r#" (method_parameters "(" (identifier) 
@id")")