diff --git a/cli/src/tests/query_test.rs b/cli/src/tests/query_test.rs index 0d078411..945b3b1f 100644 --- a/cli/src/tests/query_test.rs +++ b/cli/src/tests/query_test.rs @@ -82,18 +82,29 @@ fn test_query_errors_on_invalid_syntax() { 1, [ "((identifier) ()", // - " ^", + " ^", ] .join("\n") )) ); assert_eq!( - Query::new(language, r#"((identifier) @x (eq? @x a"#), + Query::new(language, r#"((identifier) (#a)"#), Err(QueryError::Syntax( 1, [ - r#"((identifier) @x (eq? @x a"#, - r#" ^"#, + "((identifier) (#a)", // + " ^", + ] + .join("\n") + )) + ); + assert_eq!( + Query::new(language, r#"((identifier) @x (#eq? @x a"#), + Err(QueryError::Syntax( + 1, + [ + r#"((identifier) @x (#eq? @x a"#, + r#" ^"#, ] .join("\n") )) @@ -136,18 +147,23 @@ fn test_query_errors_on_invalid_conditions() { assert_eq!( Query::new(language, "((identifier) @id (@id))"), - Err(QueryError::Predicate( - "Expected predicate to start with a function name. Got @id.".to_string() + Err(QueryError::Syntax( + 1, + [ + "((identifier) @id (@id))", // + " ^" + ] + .join("\n") )) ); assert_eq!( - Query::new(language, "((identifier) @id (eq? @id))"), + Query::new(language, "((identifier) @id (#eq? @id))"), Err(QueryError::Predicate( - "Wrong number of arguments to eq? predicate. Expected 2, got 1.".to_string() + "Wrong number of arguments to #eq? predicate. Expected 2, got 1.".to_string() )) ); assert_eq!( - Query::new(language, "((identifier) @id (eq? @id @ok))"), + Query::new(language, "((identifier) @id (#eq? @id @ok))"), Err(QueryError::Capture(1, "ok".to_string())) ); }); @@ -1158,13 +1174,13 @@ fn test_query_captures_with_text_conditions() { language, r#" ((identifier) @constant - (match? @constant "^[A-Z]{2,}$")) + (#match? @constant "^[A-Z]{2,}$")) ((identifier) @constructor - (match? @constructor "^[A-Z]")) + (#match? @constructor "^[A-Z]")) ((identifier) @function.builtin - (eq? @function.builtin "require")) + (#eq? @function.builtin "require")) (identifier) @variable "#, @@ -1207,13 +1223,13 @@ fn test_query_captures_with_predicates() { language, r#" ((call_expression (identifier) @foo) - (set! name something) - (set! cool) - (something! @foo omg)) + (#set! name something) + (#set! cool) + (#something! @foo omg)) ((property_identifier) @bar - (is? cool) - (is-not? name something))"#, + (#is? cool) + (#is-not? name something))"#, ) .unwrap(); @@ -1259,13 +1275,13 @@ fn test_query_captures_with_quoted_predicate_args() { language, r#" ((call_expression (identifier) @foo) - (set! one "\"something\ngreat\"")) + (#set! one "\"something\ngreat\"")) ((identifier) - (set! two "\\s(\r?\n)*$")) + (#set! two "\\s(\r?\n)*$")) ((function_declaration) - (set! three "\"something\ngreat\"")) + (#set! three "\"something\ngreat\"")) "#, ) .unwrap(); @@ -1403,7 +1419,7 @@ fn test_query_captures_with_many_nested_results_with_fields() { consequence: (member_expression object: (identifier) @right) alternative: (null)) - (eq? @left @right)) + (#eq? @left @right)) "#, ) .unwrap(); @@ -1689,7 +1705,7 @@ fn test_query_start_byte_for_pattern() { .trim_start(); let patterns_3 = " - ((identifier) @b (match? @b i)) + ((identifier) @b (#match? @b i)) (function_declaration name: (identifier) @c) (method_definition name: (identifier) @d) " diff --git a/cli/src/tests/tags_test.rs b/cli/src/tests/tags_test.rs index 41907a3c..9bfd1f56 100644 --- a/cli/src/tests/tags_test.rs +++ b/cli/src/tests/tags_test.rs @@ -6,46 +6,58 @@ use tree_sitter_tags::c_lib as c; use tree_sitter_tags::{Error, TagKind, TagsConfiguration, TagsContext}; const PYTHON_TAG_QUERY: &'static str = r#" -((function_definition - name: (identifier) @name - body: (block . (expression_statement (string) @doc))) @function - (strip! @doc "(^['\"\\s]*)|(['\"\\s]*$)")) +( + (function_definition + name: (identifier) @name + body: (block . (expression_statement (string) @doc))) @function + (#strip! @doc "(^['\"\\s]*)|(['\"\\s]*$)") +) + (function_definition name: (identifier) @name) @function -((class_definition - name: (identifier) @name - body: (block . (expression_statement (string) @doc))) @class - (strip! @doc "(^['\"\\s]*)|(['\"\\s]*$)")) + +( + (class_definition + name: (identifier) @name + body: (block + . (expression_statement (string) @doc))) @class + (#strip! @doc "(^['\"\\s]*)|(['\"\\s]*$)") +) + (class_definition name: (identifier) @name) @class + (call function: (identifier) @name) @call "#; const JS_TAG_QUERY: &'static str = r#" -((* +( (comment)+ @doc . (class_declaration - name: (identifier) @name) @class) - (select-adjacent! @doc @class) - (strip! @doc "(^[/\\*\\s]*)|([/\\*\\s]*$)")) + name: (identifier) @name) @class + (#select-adjacent! @doc @class) + (#strip! @doc "(^[/\\*\\s]*)|([/\\*\\s]*$)") +) -((* +( (comment)+ @doc . (method_definition - name: (property_identifier) @name) @method) - (select-adjacent! @doc @method) - (strip! @doc "(^[/\\*\\s]*)|([/\\*\\s]*$)")) + name: (property_identifier) @name) @method + (#select-adjacent! @doc @method) + (#strip! @doc "(^[/\\*\\s]*)|([/\\*\\s]*$)") +) -((* +( (comment)+ @doc . (function_declaration - name: (identifier) @name) @function) - (select-adjacent! @doc @function) - (strip! @doc "(^[/\\*\\s]*)|([/\\*\\s]*$)")) + name: (identifier) @name) @function + (#select-adjacent! @doc @function) + (#strip! @doc "(^[/\\*\\s]*)|([/\\*\\s]*$)") +) (call_expression function: (identifier) @name) @call - "#; +"#; const RUBY_TAG_QUERY: &'static str = r#" (method @@ -55,7 +67,7 @@ const RUBY_TAG_QUERY: &'static str = r#" method: (identifier) @name) @call ((identifier) @name @call - (is-not? local)) + (#is-not? local)) "#; #[test] diff --git a/lib/binding_rust/lib.rs b/lib/binding_rust/lib.rs index a13d9168..c0aba32f 100644 --- a/lib/binding_rust/lib.rs +++ b/lib/binding_rust/lib.rs @@ -1271,13 +1271,13 @@ impl Query { "eq?" | "not-eq?" => { if p.len() != 3 { return Err(QueryError::Predicate(format!( - "Wrong number of arguments to eq? predicate. Expected 2, got {}.", + "Wrong number of arguments to #eq? predicate. Expected 2, got {}.", p.len() - 1 ))); } if p[1].type_ != type_capture { return Err(QueryError::Predicate(format!( - "First argument to eq? predicate must be a capture name. Got literal \"{}\".", + "First argument to #eq? predicate must be a capture name. Got literal \"{}\".", string_values[p[1].value_id as usize], ))); } @@ -1301,19 +1301,19 @@ impl Query { "match?" => { if p.len() != 3 { return Err(QueryError::Predicate(format!( - "Wrong number of arguments to match? predicate. Expected 2, got {}.", + "Wrong number of arguments to #match? predicate. Expected 2, got {}.", p.len() - 1 ))); } if p[1].type_ != type_capture { return Err(QueryError::Predicate(format!( - "First argument to match? predicate must be a capture name. Got literal \"{}\".", + "First argument to #match? predicate must be a capture name. Got literal \"{}\".", string_values[p[1].value_id as usize], ))); } if p[2].type_ == type_capture { return Err(QueryError::Predicate(format!( - "Second argument to match? predicate must be a literal. Got capture @{}.", + "Second argument to #match? predicate must be a literal. Got capture @{}.", result.capture_names[p[2].value_id as usize], ))); } diff --git a/lib/src/query.c b/lib/src/query.c index 801b98e2..49cbb92f 100644 --- a/lib/src/query.c +++ b/lib/src/query.c @@ -567,8 +567,20 @@ static TSQueryError ts_query__parse_predicate( TSQuery *self, Stream *stream ) { - if (stream->next == ')') return PARENT_DONE; - if (stream->next != '(') return TSQueryErrorSyntax; + if (!stream_is_ident_start(stream)) return TSQueryErrorSyntax; + const char *predicate_name = stream->input; + stream_scan_identifier(stream); + uint32_t length = stream->input - predicate_name; + uint16_t id = symbol_table_insert_name( + &self->predicate_values, + predicate_name, + length + ); + array_back(&self->predicates_by_pattern)->length++; + array_push(&self->predicate_steps, ((TSQueryPredicateStep) { + .type = TSQueryPredicateStepTypeString, + .value_id = id, + })); stream_advance(stream); stream_skip_whitespace(stream); @@ -703,35 +715,16 @@ static TSQueryError ts_query__parse_pattern( return PARENT_DONE; } - // Parse a parenthesized node expression + // Parse either: + // * A parenthesized sequence of nodes + // * A predicate + // * A named node else if (stream->next == '(') { stream_advance(stream); stream_skip_whitespace(stream); - // At the top-level, a nested list represents one root pattern followed by - // zero-or-more predicates. - if (stream->next == '(' && depth == 0) { - TSQueryError e = ts_query__parse_pattern(self, stream, 0, capture_count, is_immediate); - if (e) return e; - - // Parse the predicates. - stream_skip_whitespace(stream); - for (;;) { - TSQueryError e = ts_query__parse_predicate(self, stream); - if (e == PARENT_DONE) { - stream_advance(stream); - stream_skip_whitespace(stream); - return 0; - } else if (e) { - return e; - } - } - } - - // When nested inside of a larger pattern, a nested list just represents - // multiple sibling nodes which are grouped, possibly so that a postfix - // operator can be applied to the group. - else if (depth > 0 && (stream->next == '(' || stream->next == '"' )) { + // If this parenthesis is followed by a node, then it represents grouping. + if (stream->next == '(' || stream->next == '"') { bool child_is_immediate = false; for (;;) { if (stream->next == '.') { @@ -755,7 +748,16 @@ static TSQueryError ts_query__parse_pattern( child_is_immediate = false; } - } else { + } + + // This parenthesis is the start of a predicate + else if (stream->next == '#') { + stream_advance(stream); + return ts_query__parse_predicate(self, stream); + } + + // Otherwise, this parenthesis is the start of a named node. + else { TSSymbol symbol; // Parse the wildcard symbol