diff --git a/lib/src/node.c b/lib/src/node.c index de59504e..36bfb24c 100644 --- a/lib/src/node.c +++ b/lib/src/node.c @@ -569,24 +569,58 @@ recur: return ts_node__null(); } -const char *ts_node_field_name_for_child(TSNode self, uint32_t child_index) { - const TSFieldMapEntry *field_map_start = NULL, *field_map_end = NULL; - if (!ts_node_child_count(self)) { +static inline const char *ts_node__field_name_from_language(TSNode self, uint32_t structural_child_index) { + const TSFieldMapEntry *field_map, *field_map_end; + ts_language_field_map( + self.tree->language, + ts_node__subtree(self).ptr->production_id, + &field_map, + &field_map_end + ); + for (; field_map != field_map_end; field_map++) { + if (!field_map->inherited && field_map->child_index == structural_child_index) { + return self.tree->language->field_names[field_map->field_id]; + } + } return NULL; - } +} - ts_language_field_map( - self.tree->language, - ts_node__subtree(self).ptr->production_id, - &field_map_start, - &field_map_end - ); +const char *ts_node_field_name_for_child(TSNode self, uint32_t child_index) { + TSNode result = self; + bool did_descend = true; + const char *inherited_field_name = NULL; - for (const TSFieldMapEntry *i = field_map_start; i < field_map_end; i++) { - if (i->child_index == child_index) { - return self.tree->language->field_names[i->field_id]; + while (did_descend) { + did_descend = false; + + TSNode child; + uint32_t index = 0; + NodeChildIterator iterator = ts_node_iterate_children(&result); + while (ts_node_child_iterator_next(&iterator, &child)) { + if (ts_node__is_relevant(child, true)) { + if (index == child_index) { + const char *field_name = ts_node__field_name_from_language(result, iterator.structural_child_index - 1); + if (field_name) return field_name; + return inherited_field_name; + } + index++; + } else { + uint32_t grandchild_index = child_index - index; + uint32_t grandchild_count = ts_node__relevant_child_count(child, true); + if (grandchild_index < grandchild_count) { + const char *field_name = ts_node__field_name_from_language(result, iterator.structural_child_index - 1); + if (field_name) inherited_field_name = field_name; + + did_descend = true; + result = child; + child_index = grandchild_index; + break; + } + index += grandchild_count; + } } } + return NULL; }