diff --git a/include/tree_sitter/parser.h b/include/tree_sitter/parser.h index 88322e03..ea655c1d 100644 --- a/include/tree_sitter/parser.h +++ b/include/tree_sitter/parser.h @@ -41,6 +41,7 @@ typedef struct { struct { TSStateId state; bool extra : 1; + bool repetition : 1; }; struct { TSSymbol symbol; @@ -138,6 +139,17 @@ typedef struct TSLanguage { } \ } +#define SHIFT_REPEAT(state_value) \ + { \ + { \ + .type = TSParseActionTypeShift, \ + .params = { \ + .state = state_value, \ + .repetition = true \ + }, \ + } \ + } + #define RECOVER() \ { \ { .type = TSParseActionTypeRecover } \ diff --git a/include/tree_sitter/runtime.h b/include/tree_sitter/runtime.h index c4dde4fe..b5f86270 100644 --- a/include/tree_sitter/runtime.h +++ b/include/tree_sitter/runtime.h @@ -9,7 +9,7 @@ extern "C" { #include #include -#define TREE_SITTER_LANGUAGE_VERSION 4 +#define TREE_SITTER_LANGUAGE_VERSION 5 typedef unsigned short TSSymbol; typedef struct TSLanguage TSLanguage; diff --git a/src/compiler/build_tables/parse_table_builder.cc b/src/compiler/build_tables/parse_table_builder.cc index 9e5e1bf1..c705c2c1 100644 --- a/src/compiler/build_tables/parse_table_builder.cc +++ b/src/compiler/build_tables/parse_table_builder.cc @@ -504,6 +504,20 @@ class ParseTableBuilderImpl : public ParseTableBuilder { } if (entry.actions.back().type == ParseActionTypeShift) { + Symbol symbol = conflicting_items.begin()->lhs(); + if (symbol.is_non_terminal() && grammar.variables[symbol.index].type == VariableTypeAuxiliary) { + bool all_symbols_match = true; + for (const ParseItem &conflicting_item : conflicting_items) { + if (conflicting_item.lhs() != symbol) { + all_symbols_match = false; + break; + } + } + if (all_symbols_match) { + entry.actions.back().repetition = true; + return ""; + } + } // If the shift action has higher precedence, prefer it over any of the // reduce actions. diff --git a/src/compiler/generate_code/c_code.cc b/src/compiler/generate_code/c_code.cc index 092f1966..efd1c7dc 100644 --- a/src/compiler/generate_code/c_code.cc +++ b/src/compiler/generate_code/c_code.cc @@ -623,6 +623,8 @@ class CCodeGenerator { case ParseActionTypeShift: if (action.extra) { add("SHIFT_EXTRA()"); + } else if (action.repetition) { + add("SHIFT_REPEAT(" + to_string(action.state_index) + ")"); } else { add("SHIFT(" + to_string(action.state_index) + ")"); } diff --git a/src/compiler/parse_table.cc b/src/compiler/parse_table.cc index 6eaeaabd..c913ce34 100644 --- a/src/compiler/parse_table.cc +++ b/src/compiler/parse_table.cc @@ -21,7 +21,8 @@ ParseAction::ParseAction() : associativity(rules::AssociativityNone), alias_sequence_id(0), fragile(false), - extra(false) {} + extra(false), + repetition(false) {} ParseAction ParseAction::Error() { return ParseAction(); @@ -78,7 +79,8 @@ bool ParseAction::operator==(const ParseAction &other) const { associativity == other.associativity && alias_sequence_id == other.alias_sequence_id && extra == other.extra && - fragile == other.fragile; + fragile == other.fragile && + repetition == other.repetition; } bool ParseAction::operator<(const ParseAction &other) const { @@ -100,6 +102,8 @@ bool ParseAction::operator<(const ParseAction &other) const { if (other.extra && !extra) return false; if (fragile && !other.fragile) return true; if (other.fragile && !fragile) return false; + if (repetition && !other.repetition) return true; + if (other.repetition && !repetition) return false; return alias_sequence_id < other.alias_sequence_id; } diff --git a/src/compiler/parse_table.h b/src/compiler/parse_table.h index d652c144..ce28c176 100644 --- a/src/compiler/parse_table.h +++ b/src/compiler/parse_table.h @@ -45,6 +45,7 @@ struct ParseAction { unsigned alias_sequence_id; bool fragile; bool extra; + bool repetition; }; struct ParseTableEntry { diff --git a/src/compiler/prepare_grammar/expand_repeats.cc b/src/compiler/prepare_grammar/expand_repeats.cc index 9f663d67..46230867 100644 --- a/src/compiler/prepare_grammar/expand_repeats.cc +++ b/src/compiler/prepare_grammar/expand_repeats.cc @@ -57,7 +57,7 @@ class ExpandRepeats { helper_rule_name, VariableTypeAuxiliary, rules::Choice{{ - rules::Seq{repeat_symbol, inner_rule}, + rules::Seq{repeat_symbol, repeat_symbol}, inner_rule, }} }); diff --git a/src/runtime/parser.c b/src/runtime/parser.c index 97098c98..1bf543b3 100644 --- a/src/runtime/parser.c +++ b/src/runtime/parser.c @@ -1097,6 +1097,7 @@ static void parser__advance(Parser *self, StackVersion version, ReusableNode *re switch (action.type) { case TSParseActionTypeShift: { + if (action.params.repetition) break; TSStateId next_state; if (action.params.extra) { next_state = state; diff --git a/test/compiler/prepare_grammar/expand_repeats_test.cc b/test/compiler/prepare_grammar/expand_repeats_test.cc index 87c8b879..250bd59b 100644 --- a/test/compiler/prepare_grammar/expand_repeats_test.cc +++ b/test/compiler/prepare_grammar/expand_repeats_test.cc @@ -23,7 +23,7 @@ describe("expand_repeats", []() { AssertThat(result.variables, Equals(vector{ Variable{"rule0", VariableTypeNamed, Symbol::non_terminal(1)}, Variable{"rule0_repeat1", VariableTypeAuxiliary, Rule::choice({ - Rule::seq({ Symbol::non_terminal(1), Symbol::terminal(0) }), + Rule::seq({ Symbol::non_terminal(1), Symbol::non_terminal(1) }), Symbol::terminal(0), })}, })); @@ -48,7 +48,7 @@ describe("expand_repeats", []() { Symbol::non_terminal(1), })}, Variable{"rule0_repeat1", VariableTypeAuxiliary, Rule::choice({ - Rule::seq({ Symbol::non_terminal(1), Symbol::terminal(11) }), + Rule::seq({ Symbol::non_terminal(1), Symbol::non_terminal(1) }), Symbol::terminal(11) })}, })); @@ -73,7 +73,7 @@ describe("expand_repeats", []() { Symbol::non_terminal(1), })}, Variable{"rule0_repeat1", VariableTypeAuxiliary, Rule::choice({ - Rule::seq({ Symbol::non_terminal(1), Symbol::terminal(11) }), + Rule::seq({ Symbol::non_terminal(1), Symbol::non_terminal(1) }), Symbol::terminal(11), })}, })); @@ -106,7 +106,7 @@ describe("expand_repeats", []() { Symbol::non_terminal(2), })}, Variable{"rule0_repeat1", VariableTypeAuxiliary, Rule::choice({ - Rule::seq({ Symbol::non_terminal(2), Symbol::terminal(4) }), + Rule::seq({ Symbol::non_terminal(2), Symbol::non_terminal(2) }), Symbol::terminal(4), })}, })); @@ -131,11 +131,11 @@ describe("expand_repeats", []() { Symbol::non_terminal(2), })}, Variable{"rule0_repeat1", VariableTypeAuxiliary, Rule::choice({ - Rule::seq({ Symbol::non_terminal(1), Symbol::terminal(10) }), + Rule::seq({ Symbol::non_terminal(1), Symbol::non_terminal(1) }), Symbol::terminal(10), })}, Variable{"rule0_repeat2", VariableTypeAuxiliary, Rule::choice({ - Rule::seq({ Symbol::non_terminal(2), Symbol::terminal(11) }), + Rule::seq({ Symbol::non_terminal(2), Symbol::non_terminal(2) }), Symbol::terminal(11), })}, })); @@ -156,11 +156,11 @@ describe("expand_repeats", []() { Variable{"rule0", VariableTypeNamed, Symbol::non_terminal(2)}, Variable{"rule1", VariableTypeNamed, Symbol::non_terminal(3)}, Variable{"rule0_repeat1", VariableTypeAuxiliary, Rule::choice({ - Rule::seq({ Symbol::non_terminal(2), Symbol::terminal(10) }), + Rule::seq({ Symbol::non_terminal(2), Symbol::non_terminal(2) }), Symbol::terminal(10), })}, Variable{"rule1_repeat1", VariableTypeAuxiliary, Rule::choice({ - Rule::seq({ Symbol::non_terminal(3), Symbol::terminal(11) }), + Rule::seq({ Symbol::non_terminal(3), Symbol::non_terminal(3) }), Symbol::terminal(11), })}, }));