From 3f424c01216bf4a1ef4cafdf2d60763d6f77fa3f Mon Sep 17 00:00:00 2001 From: Amaan Qureshi Date: Thu, 29 Aug 2024 17:21:52 -0400 Subject: [PATCH] feat: add an API to time out query executions Currently, if a predicate is hard to match on the Rust side, a sizable query against a very large file can take forever, and ends up hanging. This commit adds an API function `ts_query_cursor_set_timeout_micros` to limit how long query execution is allowed to take, thereby negating the chance of a hang to occur. --- cli/src/tests/query_test.rs | 25 ++++++++++++++++++ lib/binding_rust/bindings.rs | 12 +++++++-- lib/binding_rust/lib.rs | 20 +++++++++++++++ lib/binding_web/binding.c | 8 ++++-- lib/binding_web/binding.js | 4 +++ lib/binding_web/test/query-test.js | 11 ++++++++ lib/binding_web/tree-sitter-web.d.ts | 1 + lib/include/tree_sitter/api.h | 16 ++++++++++++ lib/src/parser.c | 4 +-- lib/src/query.c | 38 ++++++++++++++++++++++++++-- script/generate-bindings | 1 + 11 files changed, 132 insertions(+), 8 deletions(-) diff --git a/cli/src/tests/query_test.rs b/cli/src/tests/query_test.rs index f37821c1..d404d19a 100644 --- a/cli/src/tests/query_test.rs +++ b/cli/src/tests/query_test.rs @@ -5146,3 +5146,28 @@ fn test_query_on_empty_source_code() { &[(0, vec![("program", "")])], ); } + +#[test] +fn test_query_execution_with_timeout() { + let language = get_language("javascript"); + let mut parser = Parser::new(); + parser.set_language(&language).unwrap(); + + let source_code = "function foo() { while (true) { } }\n".repeat(1000); + let tree = parser.parse(&source_code, None).unwrap(); + + let query = Query::new(&language, "(function_declaration) @function").unwrap(); + let mut cursor = QueryCursor::new(); + + cursor.set_timeout_micros(1000); + let matches = cursor + .matches(&query, tree.root_node(), source_code.as_bytes()) + .count(); + assert!(matches < 1000); + + cursor.set_timeout_micros(0); + let matches = cursor + .matches(&query, tree.root_node(), source_code.as_bytes()) + .count(); + assert_eq!(matches, 1000); +} diff --git a/lib/binding_rust/bindings.rs b/lib/binding_rust/bindings.rs index 3f831266..feaa8ca8 100644 --- a/lib/binding_rust/bindings.rs +++ b/lib/binding_rust/bindings.rs @@ -1,4 +1,4 @@ -/* automatically generated by rust-bindgen 0.69.4 */ +/* automatically generated by rust-bindgen 0.70.0 */ pub const TREE_SITTER_LANGUAGE_VERSION: u32 = 14; pub const TREE_SITTER_MIN_COMPATIBLE_LANGUAGE_VERSION: u32 = 13; @@ -462,7 +462,7 @@ extern "C" { pub fn ts_tree_cursor_delete(self_: *mut TSTreeCursor); } extern "C" { - #[doc = " Re-initialize a tree cursor to start at a different node."] + #[doc = " Re-initialize a tree cursor to start at the original node that the cursor was\n constructed with."] pub fn ts_tree_cursor_reset(self_: *mut TSTreeCursor, node: TSNode); } extern "C" { @@ -637,6 +637,14 @@ extern "C" { extern "C" { pub fn ts_query_cursor_set_match_limit(self_: *mut TSQueryCursor, limit: u32); } +extern "C" { + #[doc = " Set the maximum duration in microseconds that query execution should be allowed to\n take before halting.\n\n If query execution takes longer than this, it will halt early, returning NULL.\n See [`ts_query_cursor_next_match`] or [`ts_query_cursor_next_capture`] for more information."] + pub fn ts_query_cursor_set_timeout_micros(self_: *mut TSQueryCursor, timeout_micros: u64); +} +extern "C" { + #[doc = " Get the duration in microseconds that query execution is allowed to take."] + pub fn ts_query_cursor_timeout_micros(self_: *const TSQueryCursor) -> u64; +} extern "C" { #[doc = " Set the range of bytes or (row, column) positions in which the query\n will be executed."] pub fn ts_query_cursor_set_byte_range( diff --git a/lib/binding_rust/lib.rs b/lib/binding_rust/lib.rs index c97fd5ca..971afc75 100644 --- a/lib/binding_rust/lib.rs +++ b/lib/binding_rust/lib.rs @@ -2362,6 +2362,26 @@ impl QueryCursor { } } + /// Set the maximum duration in microseconds that query execution should be allowed to + /// take before halting. + /// + /// If query execution takes longer than this, it will halt early, returning None. + #[doc(alias = "ts_query_cursor_set_timeout_micros")] + pub fn set_timeout_micros(&mut self, timeout: u64) { + unsafe { + ffi::ts_query_cursor_set_timeout_micros(self.ptr.as_ptr(), timeout); + } + } + + /// Get the duration in microseconds that query execution is allowed to take. + /// + /// This is set via [`set_timeout_micros`](QueryCursor::set_timeout_micros). + #[doc(alias = "ts_query_cursor_timeout_micros")] + #[must_use] + pub fn timeout_micros(&self) -> u64 { + unsafe { ffi::ts_query_cursor_timeout_micros(self.ptr.as_ptr()) } + } + /// Check if, on its last execution, this cursor exceeded its maximum number /// of in-progress matches. #[doc(alias = "ts_query_cursor_did_exceed_match_limit")] diff --git a/lib/binding_web/binding.c b/lib/binding_web/binding.c index fba62eba..36efb042 100644 --- a/lib/binding_web/binding.c +++ b/lib/binding_web/binding.c @@ -792,7 +792,8 @@ void ts_query_matches_wasm( uint32_t start_index, uint32_t end_index, uint32_t match_limit, - uint32_t max_start_depth + uint32_t max_start_depth, + uint32_t timeout_micros ) { if (!scratch_query_cursor) { scratch_query_cursor = ts_query_cursor_new(); @@ -810,6 +811,7 @@ void ts_query_matches_wasm( ts_query_cursor_set_byte_range(scratch_query_cursor, start_index, end_index); ts_query_cursor_set_match_limit(scratch_query_cursor, match_limit); ts_query_cursor_set_max_start_depth(scratch_query_cursor, max_start_depth); + ts_query_cursor_set_timeout_micros(scratch_query_cursor, timeout_micros); ts_query_cursor_exec(scratch_query_cursor, self, node); uint32_t index = 0; @@ -847,7 +849,8 @@ void ts_query_captures_wasm( uint32_t start_index, uint32_t end_index, uint32_t match_limit, - uint32_t max_start_depth + uint32_t max_start_depth, + uint32_t timeout_micros ) { if (!scratch_query_cursor) { scratch_query_cursor = ts_query_cursor_new(); @@ -862,6 +865,7 @@ void ts_query_captures_wasm( ts_query_cursor_set_byte_range(scratch_query_cursor, start_index, end_index); ts_query_cursor_set_match_limit(scratch_query_cursor, match_limit); ts_query_cursor_set_max_start_depth(scratch_query_cursor, max_start_depth); + ts_query_cursor_set_timeout_micros(scratch_query_cursor, timeout_micros); ts_query_cursor_exec(scratch_query_cursor, self, node); unsigned index = 0; diff --git a/lib/binding_web/binding.js b/lib/binding_web/binding.js index 2b4696c3..a626aa01 100644 --- a/lib/binding_web/binding.js +++ b/lib/binding_web/binding.js @@ -1279,6 +1279,7 @@ class Query { endIndex = 0, matchLimit = 0xFFFFFFFF, maxStartDepth = 0xFFFFFFFF, + timeoutMicros = 0, } = {}, ) { if (typeof matchLimit !== 'number') { @@ -1298,6 +1299,7 @@ class Query { endIndex, matchLimit, maxStartDepth, + timeoutMicros, ); const rawCount = getValue(TRANSFER_BUFFER, 'i32'); @@ -1342,6 +1344,7 @@ class Query { endIndex = 0, matchLimit = 0xFFFFFFFF, maxStartDepth = 0xFFFFFFFF, + timeoutMicros = 0, } = {}, ) { if (typeof matchLimit !== 'number') { @@ -1361,6 +1364,7 @@ class Query { endIndex, matchLimit, maxStartDepth, + timeoutMicros, ); const count = getValue(TRANSFER_BUFFER, 'i32'); diff --git a/lib/binding_web/test/query-test.js b/lib/binding_web/test/query-test.js index fad6b3cf..db4c10f8 100644 --- a/lib/binding_web/test/query-test.js +++ b/lib/binding_web/test/query-test.js @@ -451,6 +451,17 @@ describe('Query', () => { ]); }); }); + + describe('Set a timeout', () => + it('returns less than the expected matches', () => { + tree = parser.parse('function foo() while (true) { } }\n'.repeat(1000)); + query = JavaScript.query('(function_declaration name: (identifier) @function)'); + const matches = query.matches(tree.rootNode, { timeoutMicros: 1000 }); + assert.isBelow(matches.length, 1000); + const matches2 = query.matches(tree.rootNode, { timeoutMicros: 0 }); + assert.equal(matches2.length, 1000); + }) + ); }); function formatMatches(matches) { diff --git a/lib/binding_web/tree-sitter-web.d.ts b/lib/binding_web/tree-sitter-web.d.ts index 97a48077..8a1fa071 100644 --- a/lib/binding_web/tree-sitter-web.d.ts +++ b/lib/binding_web/tree-sitter-web.d.ts @@ -179,6 +179,7 @@ declare module 'web-tree-sitter' { endIndex?: number; matchLimit?: number; maxStartDepth?: number; + timeoutMicros?: number; }; export interface PredicateResult { diff --git a/lib/include/tree_sitter/api.h b/lib/include/tree_sitter/api.h index c1fbad25..5ea845f5 100644 --- a/lib/include/tree_sitter/api.h +++ b/lib/include/tree_sitter/api.h @@ -983,6 +983,22 @@ bool ts_query_cursor_did_exceed_match_limit(const TSQueryCursor *self); uint32_t ts_query_cursor_match_limit(const TSQueryCursor *self); void ts_query_cursor_set_match_limit(TSQueryCursor *self, uint32_t limit); +/** + * Set the maximum duration in microseconds that query execution should be allowed to + * take before halting. + * + * If query execution takes longer than this, it will halt early, returning NULL. + * See [`ts_query_cursor_next_match`] or [`ts_query_cursor_next_capture`] for more information. + */ +void ts_query_cursor_set_timeout_micros(TSQueryCursor *self, uint64_t timeout_micros); + +/** + * Get the duration in microseconds that query execution is allowed to take. + * + * This is set via [`ts_query_cursor_set_timeout_micros`]. + */ +uint64_t ts_query_cursor_timeout_micros(const TSQueryCursor *self); + /** * Set the range of bytes or (row, column) positions in which the query * will be executed. diff --git a/lib/src/parser.c b/lib/src/parser.c index 2927d820..5db2cf50 100644 --- a/lib/src/parser.c +++ b/lib/src/parser.c @@ -83,7 +83,7 @@ static const unsigned MAX_VERSION_COUNT = 6; static const unsigned MAX_VERSION_COUNT_OVERFLOW = 4; static const unsigned MAX_SUMMARY_DEPTH = 16; static const unsigned MAX_COST_DIFFERENCE = 16 * ERROR_COST_PER_SKIPPED_TREE; -static const unsigned OP_COUNT_PER_TIMEOUT_CHECK = 100; +static const unsigned OP_COUNT_PER_PARSER_TIMEOUT_CHECK = 100; typedef struct { Subtree token; @@ -1565,7 +1565,7 @@ static bool ts_parser__advance( // If a cancellation flag or a timeout was provided, then check every // time a fixed number of parse actions has been processed. - if (++self->operation_count == OP_COUNT_PER_TIMEOUT_CHECK) { + if (++self->operation_count == OP_COUNT_PER_PARSER_TIMEOUT_CHECK) { self->operation_count = 0; } if ( diff --git a/lib/src/query.c b/lib/src/query.c index c9e8fbd0..4941f507 100644 --- a/lib/src/query.c +++ b/lib/src/query.c @@ -1,6 +1,7 @@ #include "tree_sitter/api.h" #include "./alloc.h" #include "./array.h" +#include "./clock.h" #include "./language.h" #include "./point.h" #include "./tree_cursor.h" @@ -312,6 +313,9 @@ struct TSQueryCursor { TSPoint start_point; TSPoint end_point; uint32_t next_state_id; + TSClock end_clock; + TSDuration timeout_duration; + unsigned operation_count; bool on_visible_node; bool ascending; bool halted; @@ -322,6 +326,7 @@ static const TSQueryError PARENT_DONE = -1; static const uint16_t PATTERN_DONE_MARKER = UINT16_MAX; static const uint16_t NONE = UINT16_MAX; static const TSSymbol WILDCARD_SYMBOL = 0; +static const unsigned OP_COUNT_PER_QUERY_TIMEOUT_CHECK = 100; /********** * Stream @@ -2986,6 +2991,9 @@ TSQueryCursor *ts_query_cursor_new(void) { .start_point = {0, 0}, .end_point = POINT_MAX, .max_start_depth = UINT32_MAX, + .timeout_duration = 0, + .end_clock = clock_null(), + .operation_count = 0, }; array_reserve(&self->states, 8); array_reserve(&self->finished_states, 8); @@ -3012,6 +3020,14 @@ void ts_query_cursor_set_match_limit(TSQueryCursor *self, uint32_t limit) { self->capture_list_pool.max_capture_list_count = limit; } +uint64_t ts_query_cursor_timeout_micros(const TSQueryCursor *self) { + return duration_to_micros(self->timeout_duration); +} + +void ts_query_cursor_set_timeout_micros(TSQueryCursor *self, uint64_t timeout_micros) { + self->timeout_duration = duration_from_micros(timeout_micros); +} + #ifdef DEBUG_EXECUTE_QUERY #define LOG(...) fprintf(stderr, __VA_ARGS__) #else @@ -3023,7 +3039,7 @@ void ts_query_cursor_exec( const TSQuery *query, TSNode node ) { - if (query) { + if (query) { LOG("query steps:\n"); for (unsigned i = 0; i < query->steps.size; i++) { QueryStep *step = &query->steps.contents[i]; @@ -3060,6 +3076,12 @@ void ts_query_cursor_exec( self->halted = false; self->query = query; self->did_exceed_match_limit = false; + self->operation_count = 0; + if (self->timeout_duration) { + self->end_clock = clock_after(clock_now(), self->timeout_duration); + } else { + self->end_clock = clock_null(); + } } void ts_query_cursor_set_byte_range( @@ -3456,7 +3478,19 @@ static inline bool ts_query_cursor__advance( } } - if (did_match || self->halted) return did_match; + if (++self->operation_count == OP_COUNT_PER_QUERY_TIMEOUT_CHECK) { + self->operation_count = 0; + } + if ( + did_match || + self->halted || + ( + self->operation_count == 0 && + !clock_is_null(self->end_clock) && clock_is_gt(clock_now(), self->end_clock) + ) + ) { + return did_match; + } // Exit the current node. if (self->ascending) { diff --git a/script/generate-bindings b/script/generate-bindings index fe83352b..a0022d8f 100755 --- a/script/generate-bindings +++ b/script/generate-bindings @@ -37,6 +37,7 @@ bindgen \ --blocklist-type '^__.*' \ --no-prepend-enum-name \ --no-copy "$no_copy" \ + --use-core \ "$header_path" \ -- \ -D TREE_SITTER_FEATURE_WASM \