From 4fa0b02d67253c36b04f5ef81ac9f00ff2ed086c Mon Sep 17 00:00:00 2001 From: Max Brunsfeld Date: Wed, 11 Sep 2019 12:06:38 -0700 Subject: [PATCH] Fix management of capture lists in query execution --- cli/src/tests/helpers/allocations.rs | 6 + cli/src/tests/query_test.rs | 447 +++++++++++++++------------ lib/src/query.c | 79 +++-- 3 files changed, 322 insertions(+), 210 deletions(-) diff --git a/cli/src/tests/helpers/allocations.rs b/cli/src/tests/helpers/allocations.rs index c64762bd..2f89c173 100644 --- a/cli/src/tests/helpers/allocations.rs +++ b/cli/src/tests/helpers/allocations.rs @@ -51,6 +51,12 @@ pub fn stop_recording() { } } +pub fn record(f: impl FnOnce()) { + start_recording(); + f(); + stop_recording(); +} + fn record_alloc(ptr: *mut c_void) { let mut recorder = RECORDER.lock(); if recorder.enabled { diff --git a/cli/src/tests/query_test.rs b/cli/src/tests/query_test.rs index 28becda2..cfa6a2ba 100644 --- a/cli/src/tests/query_test.rs +++ b/cli/src/tests/query_test.rs @@ -4,235 +4,302 @@ use tree_sitter::{Parser, Query, QueryError, QueryMatch}; #[test] fn test_query_errors_on_invalid_syntax() { - allocations::start_recording(); + allocations::record(|| { + let language = get_language("javascript"); - let language = get_language("javascript"); + assert!(Query::new(language, "(if_statement)").is_ok()); + assert!(Query::new(language, "(if_statement condition:(identifier))").is_ok()); - assert!(Query::new(language, "(if_statement)").is_ok()); - assert!(Query::new(language, "(if_statement condition:(identifier))").is_ok()); + // Mismatched parens + assert_eq!( + Query::new(language, "(if_statement"), + Err(QueryError::Syntax(13)) + ); + assert_eq!( + Query::new(language, "(if_statement))"), + Err(QueryError::Syntax(14)) + ); - // Mismatched parens - assert_eq!( - Query::new(language, "(if_statement"), - Err(QueryError::Syntax(13)) - ); - assert_eq!( - Query::new(language, "(if_statement))"), - Err(QueryError::Syntax(14)) - ); + // Return an error at the 
*beginning* of a bare identifier not followed by a colon. + // If there's a colon but no pattern, return an error at the end of the colon. + assert_eq!( + Query::new(language, "(if_statement identifier)"), + Err(QueryError::Syntax(14)) + ); + assert_eq!( + Query::new(language, "(if_statement condition:)"), + Err(QueryError::Syntax(24)) + ); - // Return an error at the *beginning* of a bare identifier not followed a colon. - // If there's a colon but no pattern, return an error at the end of the colon. - assert_eq!( - Query::new(language, "(if_statement identifier)"), - Err(QueryError::Syntax(14)) - ); - assert_eq!( - Query::new(language, "(if_statement condition:)"), - Err(QueryError::Syntax(24)) - ); - - assert_eq!( - Query::new(language, "(if_statement condition:)"), - Err(QueryError::Syntax(24)) - ); - - allocations::stop_recording(); + assert_eq!( + Query::new(language, "(if_statement condition:)"), + Err(QueryError::Syntax(24)) + ); + }); } #[test] fn test_query_errors_on_invalid_symbols() { - allocations::start_recording(); + allocations::record(|| { + let language = get_language("javascript"); - let language = get_language("javascript"); - - assert_eq!( - Query::new(language, "(non_existent1)"), - Err(QueryError::NodeType("non_existent1")) - ); - assert_eq!( - Query::new(language, "(if_statement (non_existent2))"), - Err(QueryError::NodeType("non_existent2")) - ); - assert_eq!( - Query::new(language, "(if_statement condition: (non_existent3))"), - Err(QueryError::NodeType("non_existent3")) - ); - assert_eq!( - Query::new(language, "(if_statement not_a_field: (identifier))"), - Err(QueryError::Field("not_a_field")) - ); - - allocations::stop_recording(); + assert_eq!( + Query::new(language, "(non_existent1)"), + Err(QueryError::NodeType("non_existent1")) + ); + assert_eq!( + Query::new(language, "(if_statement (non_existent2))"), + Err(QueryError::NodeType("non_existent2")) + ); + assert_eq!( + Query::new(language, "(if_statement condition: (non_existent3))"), + 
Err(QueryError::NodeType("non_existent3")) + ); + assert_eq!( + Query::new(language, "(if_statement not_a_field: (identifier))"), + Err(QueryError::Field("not_a_field")) + ); + }); } #[test] fn test_query_capture_names() { - allocations::start_recording(); + allocations::record(|| { + let language = get_language("javascript"); + let query = Query::new( + language, + r#" + (if_statement + condition: (binary_expression + left: * @left-operand + operator: "||" + right: * @right-operand) + consequence: (statement_block) @body) - let language = get_language("javascript"); - let query = Query::new( - language, - r#" - (if_statement - condition: (binary_expression - left: * @left-operand - operator: "||" - right: * @right-operand) - consequence: (statement_block) @body) + (while_statement + condition:* @loop-condition) + "#, + ) + .unwrap(); - (while_statement - condition:* @loop-condition) - "#, - ) - .unwrap(); - - assert_eq!( - query.capture_names(), - &[ - "left-operand".to_string(), - "right-operand".to_string(), - "body".to_string(), - "loop-condition".to_string(), - ] - ); - - drop(query); - allocations::stop_recording(); + assert_eq!( + query.capture_names(), + &[ + "left-operand".to_string(), + "right-operand".to_string(), + "body".to_string(), + "loop-condition".to_string(), + ] + ); + }); } #[test] fn test_query_exec_with_simple_pattern() { - allocations::start_recording(); + allocations::record(|| { + let language = get_language("javascript"); + let query = Query::new( + language, + "(function_declaration name: (identifier) @fn-name)", + ) + .unwrap(); - let language = get_language("javascript"); - let query = Query::new( - language, - "(function_declaration name: (identifier) @fn-name)", - ) - .unwrap(); + let source = "function one() { two(); function three() {} }"; + let mut parser = Parser::new(); + parser.set_language(language).unwrap(); + let tree = parser.parse(source, None).unwrap(); - let source = "function one() { two(); function three() {} }"; - let 
mut parser = Parser::new(); - parser.set_language(language).unwrap(); - let tree = parser.parse(source, None).unwrap(); + let context = query.context(); + let matches = context.exec(tree.root_node()); - let context = query.context(); - let matches = context.exec(tree.root_node()); - - assert_eq!( - collect_matches(matches, &query, source), - &[ - (0, vec![("fn-name", "one")]), - (0, vec![("fn-name", "three")]) - ], - ); - - drop(context); - drop(parser); - drop(query); - drop(tree); - allocations::stop_recording(); + assert_eq!( + collect_matches(matches, &query, source), + &[ + (0, vec![("fn-name", "one")]), + (0, vec![("fn-name", "three")]) + ], + ); + }); } #[test] fn test_query_exec_with_multiple_matches_same_root() { - allocations::start_recording(); + allocations::record(|| { + let language = get_language("javascript"); + let query = Query::new( + language, + "(class_declaration + name: (identifier) @the-class-name + (class_body + (method_definition + name: (property_identifier) @the-method-name)))", + ) + .unwrap(); - let language = get_language("javascript"); - let query = Query::new( - language, - "(class_declaration - name: (identifier) @the-class-name - (class_body - (method_definition - name: (property_identifier) @the-method-name)))", - ) - .unwrap(); + let source = " + class Person { + // the constructor + constructor(name) { this.name = name; } - let source = " - class Person { - // the constructor - constructor(name) { this.name = name; } + // the getter + getFullName() { return this.name; } + } + "; - // the getter - getFullName() { return this.name; } - } - "; + let mut parser = Parser::new(); + parser.set_language(language).unwrap(); + let tree = parser.parse(source, None).unwrap(); + let context = query.context(); + let matches = context.exec(tree.root_node()); - let mut parser = Parser::new(); - parser.set_language(language).unwrap(); - let tree = parser.parse(source, None).unwrap(); - let context = query.context(); - let matches = 
context.exec(tree.root_node()); - - assert_eq!( - collect_matches(matches, &query, source), - &[ - ( - 0, - vec![ - ("the-class-name", "Person"), - ("the-method-name", "constructor") - ] - ), - ( - 0, - vec![ - ("the-class-name", "Person"), - ("the-method-name", "getFullName") - ] - ), - ], - ); - - drop(context); - drop(parser); - drop(query); - drop(tree); - allocations::stop_recording(); + assert_eq!( + collect_matches(matches, &query, source), + &[ + ( + 0, + vec![ + ("the-class-name", "Person"), + ("the-method-name", "constructor") + ] + ), + ( + 0, + vec![ + ("the-class-name", "Person"), + ("the-method-name", "getFullName") + ] + ), + ], + ); + }); } #[test] fn test_query_exec_multiple_patterns() { - allocations::start_recording(); + allocations::record(|| { + let language = get_language("javascript"); + let query = Query::new( + language, + " + (function_declaration name:(identifier) @fn-def) + (call_expression function:(identifier) @fn-ref) + ", + ) + .unwrap(); - let language = get_language("javascript"); - let query = Query::new( - language, - " - (function_declaration name:(identifier) @fn-def) - (call_expression function:(identifier) @fn-ref) + let source = " + function f1() { + f2(f3()); + } + "; + + let mut parser = Parser::new(); + parser.set_language(language).unwrap(); + let tree = parser.parse(source, None).unwrap(); + let context = query.context(); + let matches = context.exec(tree.root_node()); + + assert_eq!( + collect_matches(matches, &query, source), + &[ + (0, vec![("fn-def", "f1")]), + (1, vec![("fn-ref", "f2")]), + (1, vec![("fn-ref", "f3")]), + ], + ); + }); +} + +#[test] +fn test_query_exec_nested_matches_without_fields() { + allocations::record(|| { + let language = get_language("javascript"); + let query = Query::new( + language, + " + (array + (array + (identifier) @element-1 + (identifier) @element-2)) + ", + ) + .unwrap(); + + let source = " + [[a]]; + [[c, d], [e, f, g]]; + [[h], [i]]; + "; + + let mut parser = Parser::new(); + 
parser.set_language(language).unwrap(); + let tree = parser.parse(source, None).unwrap(); + let context = query.context(); + let matches = context.exec(tree.root_node()); + + assert_eq!( + collect_matches(matches, &query, source), + &[ + (0, vec![("element-1", "c"), ("element-2", "d")]), + (0, vec![("element-1", "e"), ("element-2", "f")]), + (0, vec![("element-1", "f"), ("element-2", "g")]), + (0, vec![("element-1", "e"), ("element-2", "g")]), + ], + ); + }); +} + +#[test] +fn test_query_exec_many_matches() { + allocations::record(|| { + let language = get_language("javascript"); + let query = Query::new(language, "(array (identifier) @element)").unwrap(); + + let source = "[hello];\n".repeat(50); + + let mut parser = Parser::new(); + parser.set_language(language).unwrap(); + let tree = parser.parse(&source, None).unwrap(); + let context = query.context(); + let matches = context.exec(tree.root_node()); + + assert_eq!( + collect_matches(matches, &query, source.as_str()), + vec![(0, vec![("element", "hello")]); 50], + ); + }); +} + +#[test] +fn test_query_exec_too_many_match_permutations_to_track() { + allocations::record(|| { + let language = get_language("javascript"); + let query = Query::new( + language, + " + (array (identifier) @pre (identifier) @post) ", - ) - .unwrap(); + ) + .unwrap(); - let source = " - function f1() { - f2(f3()); - } - "; + let mut source = "hello, ".repeat(50); + source.insert(0, '['); + source.push_str("];"); - let mut parser = Parser::new(); - parser.set_language(language).unwrap(); - let tree = parser.parse(source, None).unwrap(); - let context = query.context(); - let matches = context.exec(tree.root_node()); + let mut parser = Parser::new(); + parser.set_language(language).unwrap(); + let tree = parser.parse(&source, None).unwrap(); + let context = query.context(); + let matches = context.exec(tree.root_node()); - assert_eq!( - collect_matches(matches, &query, source), - &[ - (0, vec![("fn-def", "f1")]), - (1, vec![("fn-ref", 
"f2")]), - (1, vec![("fn-ref", "f3")]), - ], - ); - - drop(context); - drop(parser); - drop(query); - drop(tree); - allocations::stop_recording(); + // For this pathological query, some match permutations will be dropped. + // Just check that a subset of the results are returned, and no crash or + // leak occurs. + assert_eq!( + collect_matches(matches, &query, source.as_str())[0], + (0, vec![("pre", "hello"), ("post", "hello")]), + ); + }); } fn collect_matches<'a>( diff --git a/lib/src/query.c b/lib/src/query.c index 167de1d7..9325424b 100644 --- a/lib/src/query.c +++ b/lib/src/query.c @@ -107,6 +107,12 @@ static const uint16_t NONE = UINT16_MAX; static const TSSymbol WILDCARD_SYMBOL = 0; static const uint16_t MAX_STATE_COUNT = 32; +#ifdef DEBUG_QUERY +#define LOG printf +#else +#define LOG(...) +#endif + /********** * Stream **********/ @@ -183,15 +189,23 @@ static TSQueryCapture *capture_list_pool_get(CaptureListPool *self, uint16_t id) return &self->contents[id * self->list_size]; } +static inline uint32_t capture_list_bitmask_for_id(uint16_t id) { + // An id of zero corresponds to the highest-order bit in the bitmask. + return (1u << (31 - id)); +} + static uint16_t capture_list_pool_acquire(CaptureListPool *self) { + // In the usage_map bitmask, ones represent free lists, and zeros represent + // lists that are in use. A free list can quickly be found by counting + // the leading zeros in the usage map. 
uint16_t id = count_leading_zeros(self->usage_map); if (id == 32) return NONE; - self->usage_map &= ~(1 << id); + self->usage_map &= ~capture_list_bitmask_for_id(id); return id; } static void capture_list_pool_release(CaptureListPool *self, uint16_t id) { - self->usage_map |= (1 << id); + self->usage_map |= capture_list_bitmask_for_id(id); } /********* @@ -586,9 +600,31 @@ void ts_query_context_exec(TSQueryContext *self, TSNode node) { self->ascending = false; } +static QueryState *ts_query_context_copy_state( + TSQueryContext *self, + QueryState *state +) { + uint32_t capture_list_id = capture_list_pool_acquire(&self->capture_list_pool); + if (capture_list_id == NONE) return NULL; + array_push(&self->states, *state); + QueryState *new_state = array_back(&self->states); + new_state->capture_list_id = capture_list_id; + TSQueryCapture *old_captures = capture_list_pool_get( + &self->capture_list_pool, + state->capture_list_id + ); + TSQueryCapture *new_captures = capture_list_pool_get( + &self->capture_list_pool, + capture_list_id + ); + memcpy(new_captures, old_captures, state->capture_count * sizeof(TSQueryCapture)); + return new_state; +} + bool ts_query_context_next(TSQueryContext *self) { if (self->finished_states.size > 0) { - array_pop(&self->finished_states); + QueryState state = array_pop(&self->finished_states); + capture_list_pool_release(&self->capture_list_pool, state.capture_list_id); } while (self->finished_states.size == 0) { @@ -598,9 +634,14 @@ bool ts_query_context_next(TSQueryContext *self) { uint32_t deleted_count = 0; for (unsigned i = 0, n = self->states.size; i < n; i++) { QueryState *state = &self->states.contents[i]; - if (state->start_depth == self->depth) { + QueryStep *step = &self->query->steps.contents[state->step_index]; - // printf("FAIL STATE pattern: %u, step: %u\n", state->pattern_index, state->step_index); + if (state->start_depth + step->depth > self->depth) { + LOG( + "fail state with pattern: %u, step: %u\n", + 
state->pattern_index, + state->step_index + ); capture_list_pool_release( &self->capture_list_pool, @@ -612,9 +653,9 @@ bool ts_query_context_next(TSQueryContext *self) { } } - // if (deleted_count) { - // printf("FAILED %u of %u states\n", deleted_count, self->states.size); - // } + if (deleted_count) { + LOG("failed %u of %u states\n", deleted_count, self->states.size); + } self->states.size -= deleted_count; @@ -631,7 +672,7 @@ bool ts_query_context_next(TSQueryContext *self) { TSNode node = ts_tree_cursor_current_node(&self->cursor); TSSymbol symbol = ts_node_symbol(node); - // printf("DESCEND INTO NODE: %s\n", ts_node_type(node)); + LOG("enter node %s\n", ts_node_type(node)); // Add new states for any patterns whose root node is a wildcard. for (unsigned i = 0; i < self->query->wildcard_root_pattern_count; i++) { @@ -678,7 +719,7 @@ bool ts_query_context_next(TSQueryContext *self) { if (field_id != step->field) continue; } - // printf("START NEW STATE: %u\n", slice->pattern_index); + LOG("start pattern %u\n", slice->pattern_index); // If the node matches the first step of the pattern, then add // a new in-progress state. First, acquire a list to hold the @@ -733,19 +774,15 @@ bool ts_query_context_next(TSQueryContext *self) { // siblings. 
QueryState *next_state = state; if (step->depth > 0 && (!step->field || field_occurs_in_later_sibling)) { - uint32_t capture_list_id = capture_list_pool_acquire( - &self->capture_list_pool - ); - if (capture_list_id != NONE) { - array_push(&self->states, *state); - next_state = array_back(&self->states); - next_state->capture_list_id = capture_list_id; - } + QueryState *copy = ts_query_context_copy_state(self, state); + if (copy) next_state = copy; } + LOG("advance state for pattern %u\n", next_state->pattern_index); + // Record captures if (step->capture_id != NONE) { - // printf("CAPTURE id: %u\n", step->capture_id); + LOG("capture id %u\n", step->capture_id); TSQueryCapture *capture_list = capture_list_pool_get( &self->capture_list_pool, @@ -762,7 +799,7 @@ bool ts_query_context_next(TSQueryContext *self) { next_state->step_index++; QueryStep *next_step = step + 1; if (next_step->depth == PATTERN_DONE_MARKER) { - // printf("FINISHED MATCH pattern: %u\n", next_state->pattern_index); + LOG("finish pattern %u\n", next_state->pattern_index); array_push(&self->finished_states, *next_state); if (next_state == state) { @@ -808,3 +845,5 @@ const TSQueryCapture *ts_query_context_matched_captures( } return NULL; } + +#undef LOG