Merge pull request #2140 from lewis6991/master

feat: add `ts_query_cursor_set_max_depth()`
This commit is contained in:
Andrew Hlynskyi 2023-04-17 11:56:18 +03:00 committed by GitHub
commit af92bfc022
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 268 additions and 43 deletions

139
Cargo.lock generated
View file

@ -191,13 +191,13 @@ checksum = "7fcaabb2fef8c910e7f4c7ce9f67a1283a1715879a7c230ca9d6d1ae31f16d91"
[[package]]
name = "errno"
version = "0.3.0"
version = "0.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "50d6a0976c999d473fe89ad888d5a284e55366d9dc9038b1ba2aa15128c4afa0"
checksum = "4bcfec3a70f97c962c307b2d2c56e358cf1d00b558d74262b5f929ee8cc7e73a"
dependencies = [
"errno-dragonfly",
"libc",
"windows-sys",
"windows-sys 0.48.0",
]
[[package]]
@ -230,9 +230,9 @@ dependencies = [
[[package]]
name = "getrandom"
version = "0.2.8"
version = "0.2.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c05aeb6a22b8f62540c194aac980f2115af067bfe15a0734d7277a768d396b31"
checksum = "c85e1d9ab2eadba7e5040d4e09cbd6d072b76a557ad64e797c2cb9d4da21d7e4"
dependencies = [
"cfg-if",
"libc",
@ -301,6 +301,12 @@ dependencies = [
"hashbrown",
]
[[package]]
name = "indoc"
version = "2.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9f2cb48b81b1dc9f39676bf99f5499babfec7cd8fe14307f7b3d747208fb5690"
[[package]]
name = "instant"
version = "0.1.12"
@ -312,13 +318,13 @@ dependencies = [
[[package]]
name = "io-lifetimes"
version = "1.0.9"
version = "1.0.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "09270fd4fa1111bc614ed2246c7ef56239a3063d5be0d1ec3b589c505d400aeb"
checksum = "9c66c74d2ae7e79a5a8f7ac924adbe38ee42a859c6539ad869eb51f0b52dc220"
dependencies = [
"hermit-abi 0.3.1",
"libc",
"windows-sys",
"windows-sys 0.48.0",
]
[[package]]
@ -340,7 +346,7 @@ dependencies = [
"log",
"thiserror",
"walkdir",
"windows-sys",
"windows-sys 0.45.0",
]
[[package]]
@ -578,16 +584,16 @@ checksum = "08d43f7aa6b08d49f382cde6a7982047c3426db949b1424bc4b7ec9ae12c6ce2"
[[package]]
name = "rustix"
version = "0.37.7"
version = "0.37.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2aae838e49b3d63e9274e1c01833cc8139d3fec468c3b84688c628f44b1ae11d"
checksum = "85597d61f83914ddeba6a47b3b8ffe7365107221c2e557ed94426489fefb5f77"
dependencies = [
"bitflags",
"errno",
"io-lifetimes",
"libc",
"linux-raw-sys",
"windows-sys",
"windows-sys 0.48.0",
]
[[package]]
@ -613,29 +619,29 @@ checksum = "bebd363326d05ec3e2f532ab7660680f3b02130d780c299bca73469d521bc0ed"
[[package]]
name = "serde"
version = "1.0.159"
version = "1.0.160"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3c04e8343c3daeec41f58990b9d77068df31209f2af111e059e9fe9646693065"
checksum = "bb2f3770c8bce3bcda7e149193a069a0f4365bda1fa5cd88e03bca26afc1216c"
dependencies = [
"serde_derive",
]
[[package]]
name = "serde_derive"
version = "1.0.159"
version = "1.0.160"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4c614d17805b093df4b147b51339e7e44bf05ef59fba1e45d83500bcfb4d8585"
checksum = "291a097c63d8497e00160b166a967a4a79c64f3facdd01cbd7502231688d77df"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.13",
"syn 2.0.15",
]
[[package]]
name = "serde_json"
version = "1.0.95"
version = "1.0.96"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d721eca97ac802aa7777b701877c8004d950fc142651367300d21c1cc0194744"
checksum = "057d394a50403bcac12672b2b18fb387ab6d289d957dab67dd201875391e52f1"
dependencies = [
"indexmap",
"itoa",
@ -668,9 +674,9 @@ dependencies = [
[[package]]
name = "syn"
version = "2.0.13"
version = "2.0.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4c9da457c5285ac1f936ebd076af6dac17a61cfe7826f2076b4d015cf47bc8ec"
checksum = "a34fcf3e8b60f57e6a14301a2e916d323af98b0ea63c599441eec8558660c822"
dependencies = [
"proc-macro2",
"quote",
@ -687,7 +693,7 @@ dependencies = [
"fastrand",
"redox_syscall 0.3.5",
"rustix",
"windows-sys",
"windows-sys 0.45.0",
]
[[package]]
@ -716,7 +722,7 @@ checksum = "f9456a42c5b0d803c8cd86e73dd7cc9edd429499f37a3550d286d5e86720569f"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.13",
"syn 2.0.15",
]
[[package]]
@ -778,6 +784,7 @@ dependencies = [
"glob",
"html-escape",
"indexmap",
"indoc",
"lazy_static",
"log",
"path-slash",
@ -999,9 +1006,9 @@ dependencies = [
[[package]]
name = "webbrowser"
version = "0.8.8"
version = "0.8.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "579cc485bd5ce5bfa0d738e4921dd0b956eca9800be1fd2e5257ebe95bc4617e"
checksum = "b692165700260bbd40fbc5ff23766c03e339fbaca907aeea5cb77bf0a553ca83"
dependencies = [
"core-foundation",
"dirs 4.0.0",
@ -1062,7 +1069,16 @@ version = "0.45.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "75283be5efb2831d37ea142365f009c02ec203cd29a3ebecbc093d52315b66d0"
dependencies = [
"windows-targets",
"windows-targets 0.42.2",
]
[[package]]
name = "windows-sys"
version = "0.48.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "677d2418bec65e3338edb076e806bc1ec15693c5d0104683f2efe857f61056a9"
dependencies = [
"windows-targets 0.48.0",
]
[[package]]
@ -1071,13 +1087,28 @@ version = "0.42.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8e5180c00cd44c9b1c88adb3693291f1cd93605ded80c250a75d472756b4d071"
dependencies = [
"windows_aarch64_gnullvm",
"windows_aarch64_msvc",
"windows_i686_gnu",
"windows_i686_msvc",
"windows_x86_64_gnu",
"windows_x86_64_gnullvm",
"windows_x86_64_msvc",
"windows_aarch64_gnullvm 0.42.2",
"windows_aarch64_msvc 0.42.2",
"windows_i686_gnu 0.42.2",
"windows_i686_msvc 0.42.2",
"windows_x86_64_gnu 0.42.2",
"windows_x86_64_gnullvm 0.42.2",
"windows_x86_64_msvc 0.42.2",
]
[[package]]
name = "windows-targets"
version = "0.48.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7b1eb6f0cd7c80c79759c929114ef071b87354ce476d9d94271031c0497adfd5"
dependencies = [
"windows_aarch64_gnullvm 0.48.0",
"windows_aarch64_msvc 0.48.0",
"windows_i686_gnu 0.48.0",
"windows_i686_msvc 0.48.0",
"windows_x86_64_gnu 0.48.0",
"windows_x86_64_gnullvm 0.48.0",
"windows_x86_64_msvc 0.48.0",
]
[[package]]
@ -1086,38 +1117,80 @@ version = "0.42.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "597a5118570b68bc08d8d59125332c54f1ba9d9adeedeef5b99b02ba2b0698f8"
[[package]]
name = "windows_aarch64_gnullvm"
version = "0.48.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "91ae572e1b79dba883e0d315474df7305d12f569b400fcf90581b06062f7e1bc"
[[package]]
name = "windows_aarch64_msvc"
version = "0.42.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e08e8864a60f06ef0d0ff4ba04124db8b0fb3be5776a5cd47641e942e58c4d43"
[[package]]
name = "windows_aarch64_msvc"
version = "0.48.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b2ef27e0d7bdfcfc7b868b317c1d32c641a6fe4629c171b8928c7b08d98d7cf3"
[[package]]
name = "windows_i686_gnu"
version = "0.42.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c61d927d8da41da96a81f029489353e68739737d3beca43145c8afec9a31a84f"
[[package]]
name = "windows_i686_gnu"
version = "0.48.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "622a1962a7db830d6fd0a69683c80a18fda201879f0f447f065a3b7467daa241"
[[package]]
name = "windows_i686_msvc"
version = "0.42.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "44d840b6ec649f480a41c8d80f9c65108b92d89345dd94027bfe06ac444d1060"
[[package]]
name = "windows_i686_msvc"
version = "0.48.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4542c6e364ce21bf45d69fdd2a8e455fa38d316158cfd43b3ac1c5b1b19f8e00"
[[package]]
name = "windows_x86_64_gnu"
version = "0.42.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8de912b8b8feb55c064867cf047dda097f92d51efad5b491dfb98f6bbb70cb36"
[[package]]
name = "windows_x86_64_gnu"
version = "0.48.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ca2b8a661f7628cbd23440e50b05d705db3686f894fc9580820623656af974b1"
[[package]]
name = "windows_x86_64_gnullvm"
version = "0.42.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "26d41b46a36d453748aedef1486d5c7a85db22e56aff34643984ea85514e94a3"
[[package]]
name = "windows_x86_64_gnullvm"
version = "0.48.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7896dbc1f41e08872e9d5e8f8baa8fdd2677f29468c4e156210174edc7f7b953"
[[package]]
name = "windows_x86_64_msvc"
version = "0.42.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9aec5da331524158c6d1a4ac0ab1541149c0b9505fde06423b02f5ef0106b9f0"
[[package]]
name = "windows_x86_64_msvc"
version = "0.48.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1a515f5799fe4961cb532f983ce2b23082366b898e52ffbce459c86f67c8378a"

View file

@ -78,6 +78,7 @@ tempfile = "3"
pretty_assertions = "0.7.2"
ctor = "0.1"
unindent = "0.2"
indoc = "2.0.1"
[build-dependencies]
toml = "0.5"

View file

@ -4,6 +4,7 @@ use super::helpers::{
query_helpers::{Match, Pattern},
ITERATION_COUNT,
};
use indoc::indoc;
use lazy_static::lazy_static;
use rand::{prelude::StdRng, SeedableRng};
use std::{env, fmt::Write};
@ -4469,6 +4470,111 @@ fn test_capture_quantifiers() {
});
}
#[test]
fn test_query_max_start_depth() {
struct Row {
description: &'static str,
pattern: &'static str,
depth: u32,
matches: &'static [(usize, &'static [(&'static str, &'static str)])],
}
let source = indoc! {"
if (a1 && a2) {
if (b1 && b2) { }
if (c) { }
}
if (d) {
if (e1 && e2) { }
if (f) { }
}
"};
#[rustfmt::skip]
let rows = &[
Row {
description: "depth 0: match all",
depth: 0,
pattern: r#"
(if_statement) @capture
"#,
matches: &[
(0, &[("capture", "if (a1 && a2) {\n if (b1 && b2) { }\n if (c) { }\n}")]),
(0, &[("capture", "if (b1 && b2) { }")]),
(0, &[("capture", "if (c) { }")]),
(0, &[("capture", "if (d) {\n if (e1 && e2) { }\n if (f) { }\n}")]),
(0, &[("capture", "if (e1 && e2) { }")]),
(0, &[("capture", "if (f) { }")]),
]
},
Row {
description: "depth 1: match 2 if statements at the top level",
depth: 1,
pattern: r#"
(if_statement) @capture
"#,
matches : &[
(0, &[("capture", "if (a1 && a2) {\n if (b1 && b2) { }\n if (c) { }\n}")]),
(0, &[("capture", "if (d) {\n if (e1 && e2) { }\n if (f) { }\n}")]),
]
},
Row {
description: "depth 1 with deep pattern: match the only the first if statement",
depth: 1,
pattern: r#"
(if_statement
condition: (parenthesized_expression
(binary_expression)
)
) @capture
"#,
matches: &[
(0, &[("capture", "if (a1 && a2) {\n if (b1 && b2) { }\n if (c) { }\n}")]),
]
},
Row {
description: "depth 3 with deep pattern: match all if statements with a binexpr condition",
depth: 3,
pattern: r#"
(if_statement
condition: (parenthesized_expression
(binary_expression)
)
) @capture
"#,
matches: &[
(0, &[("capture", "if (a1 && a2) {\n if (b1 && b2) { }\n if (c) { }\n}")]),
(0, &[("capture", "if (b1 && b2) { }")]),
(0, &[("capture", "if (e1 && e2) { }")]),
]
},
];
allocations::record(|| {
let language = get_language("c");
let mut parser = Parser::new();
parser.set_language(language).unwrap();
let tree = parser.parse(source, None).unwrap();
let mut cursor = QueryCursor::new();
for row in rows.iter() {
eprintln!(" query example: {:?}", row.description);
let query = Query::new(language, row.pattern).unwrap();
cursor.set_max_start_depth(row.depth);
let matches = cursor.matches(&query, tree.root_node(), source.as_bytes());
let expected = row
.matches
.iter()
.map(|x| (x.0, x.1.to_vec()))
.collect::<Vec<_>>();
assert_eq!(collect_matches(matches, &query, source), expected);
}
});
}
fn assert_query_matches(
language: Language,
query: &Query,

View file

@ -1,4 +1,4 @@
/* automatically generated by rust-bindgen 0.64.0 */
/* automatically generated by rust-bindgen 0.65.1 */
pub type TSSymbol = u16;
pub type TSFieldId = u16;
@ -580,6 +580,10 @@ extern "C" {
capture_index: *mut u32,
) -> bool;
}
extern "C" {
#[doc = " Set the maximum start depth for a cursor.\n\n This prevents cursors from exploring children nodes at a certain depth.\n Note if a pattern includes many children, then they will still be checked.\n\n Set to `0` to remove the maximum start depth."]
pub fn ts_query_cursor_set_max_start_depth(arg1: *mut TSQueryCursor, arg2: u32);
}
extern "C" {
#[doc = " Get the number of distinct node types in the language."]
pub fn ts_language_symbol_count(arg1: *const TSLanguage) -> u32;

View file

@ -1932,6 +1932,14 @@ impl QueryCursor {
}
self
}
#[doc(alias = "ts_query_cursor_set_max_start_depth")]
pub fn set_max_start_depth(&mut self, max_start_depth: u32) -> &mut Self {
unsafe {
ffi::ts_query_cursor_set_max_start_depth(self.ptr.as_ptr(), max_start_depth);
}
self
}
}
impl<'a, 'tree> QueryMatch<'a, 'tree> {

View file

@ -892,6 +892,16 @@ bool ts_query_cursor_next_capture(
uint32_t *capture_index
);
/**
* Set the maximum start depth for a cursor.
*
* This prevents cursors from exploring children nodes at a certain depth.
* Note if a pattern includes many children, then they will still be checked.
*
* Set to `0` to remove the maximum start depth.
*/
void ts_query_cursor_set_max_start_depth(TSQueryCursor *, uint32_t);
/**********************/
/* Section - Language */
/**********************/

View file

@ -305,6 +305,7 @@ struct TSQueryCursor {
Array(QueryState) finished_states;
CaptureListPool capture_list_pool;
uint32_t depth;
uint32_t max_start_depth;
uint32_t start_byte;
uint32_t end_byte;
TSPoint start_point;
@ -2976,6 +2977,7 @@ TSQueryCursor *ts_query_cursor_new(void) {
.end_byte = UINT32_MAX,
.start_point = {0, 0},
.end_point = POINT_MAX,
.max_start_depth = UINT32_MAX,
};
array_reserve(&self->states, 8);
array_reserve(&self->finished_states, 8);
@ -3346,9 +3348,15 @@ static QueryState *ts_query_cursor__copy_state(
return &self->states.contents[state_index + 1];
}
static inline bool ts_query_cursor__should_descend_outside_of_range(
TSQueryCursor *self
static inline bool ts_query_cursor__should_descend(
TSQueryCursor *self,
bool node_intersects_range
) {
if (node_intersects_range && self->depth < self->max_start_depth) {
return true;
}
// If there are in-progress matches whose remaining steps occur
// deeper in the tree, then descend.
for (unsigned i = 0; i < self->states.size; i++) {
@ -3362,6 +3370,10 @@ static inline bool ts_query_cursor__should_descend_outside_of_range(
}
}
if (self->depth >= self->max_start_depth) {
return false;
}
// If the current node is hidden, then a non-rooted pattern might match
// one if its roots inside of this node, and match another of its roots
// as part of a sibling node, so we may need to descend.
@ -3555,12 +3567,14 @@ static inline bool ts_query_cursor__advance(
// If this node matches the first step of the pattern, then add a new
// state at the start of this pattern.
QueryStep *step = &self->query->steps.contents[pattern->step_index];
uint32_t start_depth = self->depth - step->depth;
if (
(pattern->is_rooted ?
node_intersects_range :
(parent_intersects_range && !parent_is_error)) &&
(!step->field || field_id == step->field) &&
(!step->supertype_symbol || supertype_count > 0)
(!step->supertype_symbol || supertype_count > 0) &&
(start_depth <= self->max_start_depth)
) {
ts_query_cursor__add_state(self, pattern);
}
@ -3573,6 +3587,7 @@ static inline bool ts_query_cursor__advance(
PatternEntry *pattern = &self->query->pattern_map.contents[i];
QueryStep *step = &self->query->steps.contents[pattern->step_index];
uint32_t start_depth = self->depth - step->depth;
do {
// If this node matches the first step of the pattern, then add a new
// state at the start of this pattern.
@ -3580,7 +3595,8 @@ static inline bool ts_query_cursor__advance(
(pattern->is_rooted ?
node_intersects_range :
(parent_intersects_range && !parent_is_error)) &&
(!step->field || field_id == step->field)
(!step->field || field_id == step->field) &&
(start_depth <= self->max_start_depth)
) {
ts_query_cursor__add_state(self, pattern);
}
@ -3881,10 +3897,7 @@ static inline bool ts_query_cursor__advance(
}
}
bool should_descend =
node_intersects_range ||
ts_query_cursor__should_descend_outside_of_range(self);
if (should_descend) {
if (ts_query_cursor__should_descend(self, node_intersects_range)) {
switch (ts_tree_cursor_goto_first_child_internal(&self->cursor)) {
case TreeCursorStepVisible:
self->depth++;
@ -4075,4 +4088,15 @@ bool ts_query_cursor_next_capture(
}
}
void ts_query_cursor_set_max_start_depth(
TSQueryCursor *self,
uint32_t max_start_depth
) {
if (max_start_depth == 0) {
self->max_start_depth = UINT32_MAX;
} else {
self->max_start_depth = max_start_depth;
}
}
#undef LOG

View file

@ -8,7 +8,6 @@ bindgen \
--allowlist-type '^TS.*' \
--allowlist-function '^ts_.*' \
--blocklist-type '^__.*' \
--size_t-is-usize \
$header_path > $output_path
echo "" >> $output_path