diff --git a/lib/binding_rust/bindings.rs b/lib/binding_rust/bindings.rs index e3d2a300..f6515608 100644 --- a/lib/binding_rust/bindings.rs +++ b/lib/binding_rust/bindings.rs @@ -742,8 +742,10 @@ extern "C" { extern "C" { #[doc = " Get the numerical id for the given node type string."] pub fn ts_language_symbol_for_name( - arg1: *const TSLanguage, - arg2: *const ::std::os::raw::c_char, + self_: *const TSLanguage, + string: *const ::std::os::raw::c_char, + length: u32, + is_named: bool, ) -> TSSymbol; } extern "C" { diff --git a/lib/include/tree_sitter/api.h b/lib/include/tree_sitter/api.h index 2a23c201..40187e3d 100644 --- a/lib/include/tree_sitter/api.h +++ b/lib/include/tree_sitter/api.h @@ -817,7 +817,12 @@ const char *ts_language_symbol_name(const TSLanguage *, TSSymbol); /** * Get the numerical id for the given node type string. */ -TSSymbol ts_language_symbol_for_name(const TSLanguage *, const char *); +TSSymbol ts_language_symbol_for_name( + const TSLanguage *self, + const char *string, + uint32_t length, + bool is_named +); /** * Get the number of distinct field names in the language. diff --git a/lib/src/language.c b/lib/src/language.c index e96a3cbf..3c08bba8 100644 --- a/lib/src/language.c +++ b/lib/src/language.c @@ -47,14 +47,19 @@ const char *ts_language_symbol_name(const TSLanguage *language, TSSymbol symbol) } } -TSSymbol ts_language_symbol_for_name(const TSLanguage *self, const char *name) { - if (!strcmp(name, "ERROR")) return ts_builtin_sym_error; - +TSSymbol ts_language_symbol_for_name( + const TSLanguage *self, + const char *string, + uint32_t length, + bool is_named +) { + if (!strncmp(string, "ERROR", length)) return ts_builtin_sym_error; uint32_t count = ts_language_symbol_count(self); for (TSSymbol i = 0; i < count; i++) { - if (!strcmp(self->symbol_names[i], name)) { - return i; - } + TSSymbolMetadata metadata = ts_language_symbol_metadata(self, i); + if (!metadata.visible || metadata.named != is_named) continue; + const char *symbol_name = self->symbol_names[i]; + if (!strncmp(symbol_name, string, length) && !symbol_name[length]) return i; } return 0; } diff --git a/lib/src/query.c b/lib/src/query.c index 0a613167..6bbf37c2 100644 --- a/lib/src/query.c +++ b/lib/src/query.c @@ -322,22 +322,6 @@ static uint16_t symbol_table_insert_name( * Query *********/ -static TSSymbol ts_query_intern_node_name( - const TSQuery *self, - const char *name, - uint32_t length, - TSSymbolType symbol_type -) { - if (!strncmp(name, "ERROR", length)) return ts_builtin_sym_error; - uint32_t symbol_count = ts_language_symbol_count(self->language); - for (TSSymbol i = 0; i < symbol_count; i++) { - if (ts_language_symbol_type(self->language, i) != symbol_type) continue; - const char *symbol_name = ts_language_symbol_name(self->language, i); - if (!strncmp(symbol_name, name, length) && !symbol_name[length]) return i; - } - return 0; -} - // The `pattern_map` contains a mapping from TSSymbol values to indices in the // `steps` array. For a given syntax node, the `pattern_map` makes it possible // to quickly find the starting steps of all of the patterns whose root matches @@ -592,11 +576,11 @@ static TSQueryError ts_query_parse_pattern( const char *node_name = stream->input; stream_scan_identifier(stream); uint32_t length = stream->input - node_name; - symbol = ts_query_intern_node_name( - self, + symbol = ts_language_symbol_for_name( + self->language, node_name, length, - TSSymbolTypeRegular + true ); if (!symbol) { stream_reset(stream, node_name); @@ -643,11 +627,11 @@ static TSQueryError ts_query_parse_pattern( uint32_t length = stream->input - string_content; // Add a step for the node - TSSymbol symbol = ts_query_intern_node_name( - self, + TSSymbol symbol = ts_language_symbol_for_name( + self->language, string_content, length, - TSSymbolTypeAnonymous + false ); if (!symbol) { stream_reset(stream, string_content);