diff --git a/include/tree_sitter/parser.h b/include/tree_sitter/parser.h index 88322e03..ea655c1d 100644 --- a/include/tree_sitter/parser.h +++ b/include/tree_sitter/parser.h @@ -41,6 +41,7 @@ typedef struct { struct { TSStateId state; bool extra : 1; + bool repetition : 1; }; struct { TSSymbol symbol; @@ -138,6 +139,17 @@ typedef struct TSLanguage { } \ } +#define SHIFT_REPEAT(state_value) \ + { \ + { \ + .type = TSParseActionTypeShift, \ + .params = { \ + .state = state_value, \ + .repetition = true \ + }, \ + } \ + } + #define RECOVER() \ { \ { .type = TSParseActionTypeRecover } \ diff --git a/include/tree_sitter/runtime.h b/include/tree_sitter/runtime.h index c4dde4fe..b5f86270 100644 --- a/include/tree_sitter/runtime.h +++ b/include/tree_sitter/runtime.h @@ -9,7 +9,7 @@ extern "C" { #include #include -#define TREE_SITTER_LANGUAGE_VERSION 4 +#define TREE_SITTER_LANGUAGE_VERSION 5 typedef unsigned short TSSymbol; typedef struct TSLanguage TSLanguage; diff --git a/src/compiler/build_tables/parse_table_builder.cc b/src/compiler/build_tables/parse_table_builder.cc index 9e5e1bf1..c705c2c1 100644 --- a/src/compiler/build_tables/parse_table_builder.cc +++ b/src/compiler/build_tables/parse_table_builder.cc @@ -504,6 +504,20 @@ class ParseTableBuilderImpl : public ParseTableBuilder { } if (entry.actions.back().type == ParseActionTypeShift) { + Symbol symbol = conflicting_items.begin()->lhs(); + if (symbol.is_non_terminal() && grammar.variables[symbol.index].type == VariableTypeAuxiliary) { + bool all_symbols_match = true; + for (const ParseItem &conflicting_item : conflicting_items) { + if (conflicting_item.lhs() != symbol) { + all_symbols_match = false; + break; + } + } + if (all_symbols_match) { + entry.actions.back().repetition = true; + return ""; + } + } // If the shift action has higher precedence, prefer it over any of the // reduce actions. diff --git a/src/compiler/generate_code/c_code.cc b/src/compiler/generate_code/c_code.cc index 092f1966..efd1c7dc 100644 --- a/src/compiler/generate_code/c_code.cc +++ b/src/compiler/generate_code/c_code.cc @@ -623,6 +623,8 @@ class CCodeGenerator { case ParseActionTypeShift: if (action.extra) { add("SHIFT_EXTRA()"); + } else if (action.repetition) { + add("SHIFT_REPEAT(" + to_string(action.state_index) + ")"); } else { add("SHIFT(" + to_string(action.state_index) + ")"); } diff --git a/src/compiler/parse_table.cc b/src/compiler/parse_table.cc index 6eaeaabd..c913ce34 100644 --- a/src/compiler/parse_table.cc +++ b/src/compiler/parse_table.cc @@ -21,7 +21,8 @@ ParseAction::ParseAction() : associativity(rules::AssociativityNone), alias_sequence_id(0), fragile(false), - extra(false) {} + extra(false), + repetition(false) {} ParseAction ParseAction::Error() { return ParseAction(); @@ -78,7 +79,8 @@ bool ParseAction::operator==(const ParseAction &other) const { associativity == other.associativity && alias_sequence_id == other.alias_sequence_id && extra == other.extra && - fragile == other.fragile; + fragile == other.fragile && + repetition == other.repetition; } bool ParseAction::operator<(const ParseAction &other) const { @@ -100,6 +102,8 @@ bool ParseAction::operator<(const ParseAction &other) const { if (other.extra && !extra) return false; if (fragile && !other.fragile) return true; if (other.fragile && !fragile) return false; + if (repetition && !other.repetition) return true; + if (other.repetition && !repetition) return false; return alias_sequence_id < other.alias_sequence_id; } diff --git a/src/compiler/parse_table.h b/src/compiler/parse_table.h index d652c144..ce28c176 100644 --- a/src/compiler/parse_table.h +++ b/src/compiler/parse_table.h @@ -45,6 +45,7 @@ struct ParseAction { unsigned alias_sequence_id; bool fragile; bool extra; + bool repetition; }; struct ParseTableEntry { diff --git a/src/compiler/prepare_grammar/expand_repeats.cc b/src/compiler/prepare_grammar/expand_repeats.cc index 9f663d67..46230867 100644 --- a/src/compiler/prepare_grammar/expand_repeats.cc +++ b/src/compiler/prepare_grammar/expand_repeats.cc @@ -57,7 +57,7 @@ class ExpandRepeats { helper_rule_name, VariableTypeAuxiliary, rules::Choice{{ - rules::Seq{repeat_symbol, inner_rule}, + rules::Seq{repeat_symbol, repeat_symbol}, inner_rule, }} }); diff --git a/src/runtime/parser.c b/src/runtime/parser.c index 97098c98..a21de531 100644 --- a/src/runtime/parser.c +++ b/src/runtime/parser.c @@ -1097,6 +1097,7 @@ static void parser__advance(Parser *self, StackVersion version, ReusableNode *re switch (action.type) { case TSParseActionTypeShift: { + if (action.params.repetition) break; TSStateId next_state; if (action.params.extra) { next_state = state; @@ -1248,10 +1249,11 @@ Tree *parser_parse(Parser *self, TSInput input, Tree *old_tree, bool halt_on_err self->in_ambiguity = version > 1; } while (version != 0); - LOG("done"); - LOG_TREE(); ts_stack_clear(self->stack); parser__set_cached_token(self, 0, NULL, NULL); ts_tree_assign_parents(self->finished_tree, &self->tree_pool, self->language); + + LOG("done"); + LOG_TREE(); return self->finished_tree; } diff --git a/src/runtime/tree.c b/src/runtime/tree.c index 49d9b3a1..d1b551c8 100644 --- a/src/runtime/tree.c +++ b/src/runtime/tree.c @@ -203,18 +203,71 @@ Tree *ts_tree_make_copy(TreePool *pool, Tree *self) { return result; } +static void ts_tree__compress(Tree *self, unsigned count, const TSLanguage *language) { + Tree *tree = self; + for (unsigned i = 0; i < count; i++) { + Tree *child = tree->children[0]; + if (child->symbol != tree->symbol) break; + + Tree *grandchild = child->children[0]; + if (grandchild->symbol != tree->symbol) break; + if (grandchild->children[0]->symbol != tree->symbol) break; + if (child->ref_count > 1 || grandchild->ref_count > 1) break; + + tree->children[0] = grandchild; + grandchild->context.parent = tree; + grandchild->context.index = -1; + + child->children[0] = grandchild->children[1]; + child->children[0]->context.parent = child; + child->children[0]->context.index = -1; + + grandchild->children[1] = child; + grandchild->children[1]->context.parent = grandchild; + grandchild->children[1]->context.index = -1; + + tree = grandchild; + } + + while (tree != self) { + tree = tree->context.parent; + Tree *child = tree->children[0]; + Tree *grandchild = child->children[1]; + ts_tree_set_children(grandchild, 2, grandchild->children, language); + ts_tree_set_children(child, 2, child->children, language); + ts_tree_set_children(tree, 2, tree->children, language); + } +} + +void ts_tree__balance(Tree *self, const TSLanguage *language) { + if (self->children[0]->repeat_depth > self->children[1]->repeat_depth) { + unsigned n = self->children[0]->repeat_depth - self->children[1]->repeat_depth; + for (unsigned i = n / 2; i > 0; i /= 2) { + ts_tree__compress(self, i, language); + n -= i; + } + } +} + void ts_tree_assign_parents(Tree *self, TreePool *pool, const TSLanguage *language) { self->context.parent = NULL; array_clear(&pool->tree_stack); array_push(&pool->tree_stack, self); while (pool->tree_stack.size > 0) { Tree *tree = array_pop(&pool->tree_stack); + + if (tree->repeat_depth > 0) { + ts_tree__balance(tree, language); + } + Length offset = length_zero(); const TSSymbol *alias_sequence = ts_language_alias_sequence(language, tree->alias_sequence_id); uint32_t non_extra_index = 0; + bool earlier_child_was_changed = false; for (uint32_t i = 0; i < tree->child_count; i++) { Tree *child = tree->children[i]; - if (child->context.parent != tree || child->context.index != i) { + if (earlier_child_was_changed || child->context.parent != tree || child->context.index != i) { + earlier_child_was_changed = true; child->context.parent = tree; child->context.index = i; child->context.offset = offset; @@ -236,13 +289,14 @@ void ts_tree_assign_parents(Tree *self, TreePool *pool, const TSLanguage *langua void ts_tree_set_children(Tree *self, uint32_t child_count, Tree **children, const TSLanguage *language) { - if (self->child_count > 0) ts_free(self->children); + if (self->child_count > 0 && children != self->children) ts_free(self->children); self->children = children; self->child_count = child_count; self->named_child_count = 0; self->visible_child_count = 0; self->error_cost = 0; + self->repeat_depth = 0; self->has_external_tokens = false; self->dynamic_precedence = 0; @@ -298,10 +352,24 @@ void ts_tree_set_children(Tree *self, uint32_t child_count, Tree **children, if (child_count > 0) { self->first_leaf = children[0]->first_leaf; - if (children[0]->fragile_left) + if (children[0]->fragile_left) { self->fragile_left = true; - if (children[child_count - 1]->fragile_right) + } + if (children[child_count - 1]->fragile_right) { self->fragile_right = true; + } + if ( + self->child_count == 2 && + !self->visible && !self->named && + self->children[0]->symbol == self->symbol && + self->children[1]->symbol == self->symbol + ) { + if (self->children[0]->repeat_depth > self->children[1]->repeat_depth) { + self->repeat_depth = self->children[0]->repeat_depth + 1; + } else { + self->repeat_depth = self->children[1]->repeat_depth + 1; + } + } } } @@ -342,6 +410,7 @@ Tree *ts_tree_make_missing_leaf(TreePool *pool, TSSymbol symbol, const TSLanguag result->error_cost = ERROR_COST_PER_MISSING_TREE; return result; } + void ts_tree_retain(Tree *self) { assert(self->ref_count > 0); self->ref_count++; @@ -633,9 +702,9 @@ void ts_tree__print_dot_graph(const Tree *self, uint32_t byte_offset, if (self->extra) fprintf(f, ", fontcolor=gray"); - fprintf(f, ", tooltip=\"range:%u - %u\nstate:%d\nerror-cost:%u\"]\n", - byte_offset, byte_offset + ts_tree_total_bytes(self), self->parse_state, - self->error_cost); + fprintf(f, ", tooltip=\"address:%p\nrange:%u - %u\nstate:%d\nerror-cost:%u\nrepeat-depth:%u\"]\n", + self, byte_offset, byte_offset + ts_tree_total_bytes(self), self->parse_state, + self->error_cost, self->repeat_depth); for (uint32_t i = 0; i < self->child_count; i++) { const Tree *child = self->children[i]; ts_tree__print_dot_graph(child, byte_offset, language, f); diff --git a/src/runtime/tree.h b/src/runtime/tree.h index 9d33561d..0e3c2880 100644 --- a/src/runtime/tree.h +++ b/src/runtime/tree.h @@ -50,6 +50,7 @@ typedef struct Tree { TSSymbol symbol; TSStateId parse_state; unsigned error_cost; + unsigned repeat_depth; struct { TSSymbol symbol; diff --git a/test/compiler/prepare_grammar/expand_repeats_test.cc b/test/compiler/prepare_grammar/expand_repeats_test.cc index 87c8b879..250bd59b 100644 --- a/test/compiler/prepare_grammar/expand_repeats_test.cc +++ b/test/compiler/prepare_grammar/expand_repeats_test.cc @@ -23,7 +23,7 @@ describe("expand_repeats", []() { AssertThat(result.variables, Equals(vector{ Variable{"rule0", VariableTypeNamed, Symbol::non_terminal(1)}, Variable{"rule0_repeat1", VariableTypeAuxiliary, Rule::choice({ - Rule::seq({ Symbol::non_terminal(1), Symbol::terminal(0) }), + Rule::seq({ Symbol::non_terminal(1), Symbol::non_terminal(1) }), Symbol::terminal(0), })}, })); @@ -48,7 +48,7 @@ describe("expand_repeats", []() { Symbol::non_terminal(1), })}, Variable{"rule0_repeat1", VariableTypeAuxiliary, Rule::choice({ - Rule::seq({ Symbol::non_terminal(1), Symbol::terminal(11) }), + Rule::seq({ Symbol::non_terminal(1), Symbol::non_terminal(1) }), Symbol::terminal(11) })}, })); @@ -73,7 +73,7 @@ describe("expand_repeats", []() { Symbol::non_terminal(1), })}, Variable{"rule0_repeat1", VariableTypeAuxiliary, Rule::choice({ - Rule::seq({ Symbol::non_terminal(1), Symbol::terminal(11) }), + Rule::seq({ Symbol::non_terminal(1), Symbol::non_terminal(1) }), Symbol::terminal(11), })}, })); @@ -106,7 +106,7 @@ describe("expand_repeats", []() { Symbol::non_terminal(2), })}, Variable{"rule0_repeat1", VariableTypeAuxiliary, Rule::choice({ - Rule::seq({ Symbol::non_terminal(2), Symbol::terminal(4) }), + Rule::seq({ Symbol::non_terminal(2), Symbol::non_terminal(2) }), Symbol::terminal(4), })}, })); @@ -131,11 +131,11 @@ describe("expand_repeats", []() { Symbol::non_terminal(2), })}, Variable{"rule0_repeat1", VariableTypeAuxiliary, Rule::choice({ - Rule::seq({ Symbol::non_terminal(1), Symbol::terminal(10) }), + Rule::seq({ Symbol::non_terminal(1), Symbol::non_terminal(1) }), Symbol::terminal(10), })}, Variable{"rule0_repeat2", VariableTypeAuxiliary, Rule::choice({ - Rule::seq({ Symbol::non_terminal(2), Symbol::terminal(11) }), + Rule::seq({ Symbol::non_terminal(2), Symbol::non_terminal(2) }), Symbol::terminal(11), })}, })); @@ -156,11 +156,11 @@ describe("expand_repeats", []() { Variable{"rule0", VariableTypeNamed, Symbol::non_terminal(2)}, Variable{"rule1", VariableTypeNamed, Symbol::non_terminal(3)}, Variable{"rule0_repeat1", VariableTypeAuxiliary, Rule::choice({ - Rule::seq({ Symbol::non_terminal(2), Symbol::terminal(10) }), + Rule::seq({ Symbol::non_terminal(2), Symbol::non_terminal(2) }), Symbol::terminal(10), })}, Variable{"rule1_repeat1", VariableTypeAuxiliary, Rule::choice({ - Rule::seq({ Symbol::non_terminal(3), Symbol::terminal(11) }), + Rule::seq({ Symbol::non_terminal(3), Symbol::non_terminal(3) }), Symbol::terminal(11), })}, })); diff --git a/test/integration/real_grammars.cc b/test/integration/real_grammars.cc index 13256364..6d5ac27e 100644 --- a/test/integration/real_grammars.cc +++ b/test/integration/real_grammars.cc @@ -96,6 +96,7 @@ for (auto &language_name : test_languages) { uint32_t range_count; ScopeSequence old_scope_sequence = build_scope_sequence(document, input->content); ts_document_parse_and_get_changed_ranges(document, &ranges, &range_count); + assert_correct_tree_size(document, input->content); ScopeSequence new_scope_sequence = build_scope_sequence(document, input->content); verify_changed_ranges(old_scope_sequence, new_scope_sequence, @@ -119,6 +120,7 @@ for (auto &language_name : test_languages) { uint32_t range_count; ScopeSequence old_scope_sequence = build_scope_sequence(document, input->content); ts_document_parse_and_get_changed_ranges(document, &ranges, &range_count); + assert_correct_tree_size(document, input->content); ScopeSequence new_scope_sequence = build_scope_sequence(document, input->content); verify_changed_ranges(old_scope_sequence, new_scope_sequence,