diff --git a/cli/src/tests/query_test.rs b/cli/src/tests/query_test.rs index 132f7076..9a002ae8 100644 --- a/cli/src/tests/query_test.rs +++ b/cli/src/tests/query_test.rs @@ -519,6 +519,51 @@ fn test_query_errors_on_impossible_patterns() { .join("\n") }) ); + assert_eq!( + Query::new(&js_lang, "(identifier/identifier)").unwrap_err(), + QueryError { + row: 0, + offset: 0, + column: 0, + kind: QueryErrorKind::Structure, + message: [ + "(identifier/identifier)", // + "^" + ] + .join("\n") + } + ); + + if js_lang.version() >= 15 { + assert_eq!( + Query::new(&js_lang, "(statement/identifier)").unwrap_err(), + QueryError { + row: 0, + offset: 0, + column: 0, + kind: QueryErrorKind::Structure, + message: [ + "(statement/identifier)", // + "^" + ] + .join("\n") + } + ); + assert_eq!( + Query::new(&js_lang, "(statement/pattern)").unwrap_err(), + QueryError { + row: 0, + offset: 0, + column: 0, + kind: QueryErrorKind::Structure, + message: [ + "(statement/pattern)", // + "^" + ] + .join("\n") + } + ); + } }); } diff --git a/lib/src/query.c b/lib/src/query.c index 9114630a..20317d24 100644 --- a/lib/src/query.c +++ b/lib/src/query.c @@ -2316,10 +2316,10 @@ static TSQueryError ts_query__parse_pattern( else { TSSymbol symbol; bool is_missing = false; + const char *node_name = stream->input; // Parse a normal node name if (stream_is_ident_start(stream)) { - const char *node_name = stream->input; stream_scan_identifier(stream); uint32_t length = (uint32_t)(stream->input - node_name); @@ -2406,26 +2406,56 @@ static TSQueryError ts_query__parse_pattern( stream_skip_whitespace(stream); if (stream->next == '/') { + if (!step->supertype_symbol) { + stream_reset(stream, node_name - 1); // reset to the start of the node + return TSQueryErrorStructure; + } + stream_advance(stream); if (!stream_is_ident_start(stream)) { return TSQueryErrorSyntax; } - const char *node_name = stream->input; + const char *subtype_node_name = stream->input; stream_scan_identifier(stream); - uint32_t length = (uint32_t)(stream->input - node_name); + uint32_t length = (uint32_t)(stream->input - subtype_node_name); step->symbol = ts_language_symbol_for_name( self->language, - node_name, + subtype_node_name, length, true ); if (!step->symbol) { - stream_reset(stream, node_name); + stream_reset(stream, subtype_node_name); return TSQueryErrorNodeType; } + // Get all the possible subtypes for the given supertype, + // and check if the given subtype is valid. + if (self->language->version >= LANGUAGE_VERSION_WITH_RESERVED_WORDS) { + uint32_t subtype_length; + const TSSymbol *subtypes = ts_language_subtypes( + self->language, + step->supertype_symbol, + &subtype_length + ); + + bool subtype_is_valid = false; + for (uint32_t i = 0; i < subtype_length; i++) { + if (subtypes[i] == step->symbol) { + subtype_is_valid = true; + break; + } + } + + // This subtype is not valid for the given supertype. + if (!subtype_is_valid) { + stream_reset(stream, node_name - 1); // reset to the start of the node + return TSQueryErrorStructure; + } + } + stream_skip_whitespace(stream); }