diff --git a/include/tree_sitter/parser.h b/include/tree_sitter/parser.h index 81f9a92e..7b59e14f 100644 --- a/include/tree_sitter/parser.h +++ b/include/tree_sitter/parser.h @@ -9,8 +9,7 @@ extern "C" { #include #include "tree_sitter/runtime.h" -#define ts_lex_state_error 0 -#define ts_parse_state_error ((TSStateId)-1) +#define TS_STATE_ERROR 0 #define TS_DEBUG_BUFFER_SIZE 512 typedef unsigned short TSStateId; @@ -95,7 +94,6 @@ struct TSLanguage { const unsigned short *parse_table; const TSParseActionEntry *parse_actions; const TSStateId *lex_states; - const TSParseAction *recovery_actions; bool (*lex_fn)(TSLexer *, TSStateId, bool); }; @@ -134,9 +132,9 @@ struct TSLanguage { #define LEX_ERROR() \ if (error_mode) { \ - if (state == ts_lex_state_error) \ + if (state == TS_STATE_ERROR) \ lexer->advance(lexer, state, TSTransitionTypeError); \ - GO_TO_STATE(ts_lex_state_error); \ + GO_TO_STATE(TS_STATE_ERROR); \ } else { \ return false; \ } @@ -150,11 +148,10 @@ struct TSLanguage { { .type = TSParseActionTypeShift, .to_state = to_state_value } \ } -#define RECOVER(to_state_value) \ - { .type = TSParseActionTypeRecover, .to_state = to_state_value } - -#define RECOVER_EXTRA() \ - { .type = TSParseActionTypeShift, .extra = true, } +#define RECOVER(to_state_value) \ + { \ + { .type = TSParseActionTypeRecover, .to_state = to_state_value } \ + } #define SHIFT_EXTRA() \ { \ @@ -196,7 +193,6 @@ struct TSLanguage { .symbol_metadata = ts_symbol_metadata, \ .parse_table = (const unsigned short *)ts_parse_table, \ .parse_actions = ts_parse_actions, \ - .recovery_actions = ts_recovery_actions, \ .lex_states = ts_lex_states, \ .symbol_names = ts_symbol_names, \ .lex_fn = ts_lex, \ diff --git a/spec/runtime/stack_spec.cc b/spec/runtime/stack_spec.cc index b418cc16..7c3d48a2 100644 --- a/spec/runtime/stack_spec.cc +++ b/spec/runtime/stack_spec.cc @@ -8,7 +8,7 @@ #include "runtime/alloc.h" enum { - stateA = 1, + stateA = 2, stateB, stateC, stateD, stateE, stateF, stateG, stateH, stateI, stateJ }; @@ -94,7 +94,7 @@ describe("Stack", [&]() { describe("push(version, tree, is_pending, state)", [&]() { it("adds entries to the given version of the stack", [&]() { AssertThat(ts_stack_version_count(stack), Equals(1)); - AssertThat(ts_stack_top_state(stack, 0), Equals(0)); + AssertThat(ts_stack_top_state(stack, 0), Equals(1)); AssertThat(ts_stack_top_position(stack, 0), Equals(ts_length_zero())); // . <──0── A* @@ -116,7 +116,7 @@ describe("Stack", [&]() { {stateC, 0}, {stateB, 1}, {stateA, 2}, - {0, 3}, + {1, 3}, }))); }); }); @@ -149,7 +149,7 @@ describe("Stack", [&]() { {stateB, 1}, {stateC, 1}, {stateA, 2}, - {0, 3}, + {1, 3}, }))); }); @@ -191,7 +191,7 @@ describe("Stack", [&]() { {stateB, 2}, {stateC, 2}, {stateA, 3}, - {0, 4}, + {1, 4}, }))); }); }); @@ -234,14 +234,14 @@ describe("Stack", [&]() { StackSlice slice = pop.slices.contents[0]; AssertThat(slice.trees, Equals(vector({ trees[0], trees[1], trees[2] }))); - AssertThat(ts_stack_top_state(stack, 1), Equals(0)); + AssertThat(ts_stack_top_state(stack, 1), Equals(1)); free_slice_array(&pop.slices); }); it("stops popping entries early if it reaches an error tree", [&]() { // . <──0── A <──1── B <──2── C <──3── ERROR <──4── D* - ts_stack_push(stack, 0, trees[3], false, ts_parse_state_error); + ts_stack_push(stack, 0, trees[3], false, TS_STATE_ERROR); ts_stack_push(stack, 0, trees[4], false, stateD); // . <──0── A <──1── B <──2── C <──3── ERROR <──4── D* @@ -251,7 +251,7 @@ describe("Stack", [&]() { AssertThat(pop.status, Equals(StackPopResult::StackPopStoppedAtError)); AssertThat(ts_stack_version_count(stack), Equals(2)); - AssertThat(ts_stack_top_state(stack, 1), Equals(ts_parse_state_error)); + AssertThat(ts_stack_top_state(stack, 1), Equals(TS_STATE_ERROR)); AssertThat(pop.slices.size, Equals(1)); StackSlice slice = pop.slices.contents[0]; @@ -284,7 +284,7 @@ describe("Stack", [&]() { {stateB, 3}, {stateE, 3}, {stateA, 4}, - {0, 5}, + {1, 5}, }))); }); @@ -315,17 +315,17 @@ describe("Stack", [&]() { {stateB, 3}, {stateE, 3}, {stateA, 4}, - {0, 5}, + {1, 5}, }))); AssertThat(get_stack_entries(stack, 1), Equals(vector({ {stateB, 0}, {stateA, 1}, - {0, 2}, + {1, 2}, }))); AssertThat(get_stack_entries(stack, 2), Equals(vector({ {stateE, 0}, {stateA, 1}, - {0, 2}, + {1, 2}, }))); free_slice_array(&pop.slices); @@ -406,7 +406,7 @@ describe("Stack", [&]() { {stateE, 3}, {stateG, 3}, {stateA, 4}, - {0, 5}, + {1, 5}, }))); // . <──0── A <──1── B <──2── C <──3── D <──10── I* @@ -457,7 +457,7 @@ describe("Stack", [&]() { AssertThat(get_stack_entries(stack, 0), Equals(vector({ {stateA, 0}, - {0, 1}, + {1, 1}, }))); free_slice_array(&pop.slices); @@ -480,7 +480,7 @@ describe("Stack", [&]() { AssertThat(get_stack_entries(stack, 0), Equals(vector({ {stateA, 0}, - {0, 1}, + {1, 1}, }))); free_slice_array(&pop.slices); @@ -496,7 +496,7 @@ describe("Stack", [&]() { AssertThat(get_stack_entries(stack, 0), Equals(vector({ {stateB, 0}, {stateA, 1}, - {0, 2}, + {1, 2}, }))); free_slice_array(&pop.slices); diff --git a/src/compiler/build_tables/build_lex_table.cc b/src/compiler/build_tables/build_lex_table.cc index be5a3b4d..93de59f0 100644 --- a/src/compiler/build_tables/build_lex_table.cc +++ b/src/compiler/build_tables/build_lex_table.cc @@ -55,8 +55,6 @@ class LexTableBuilder { } LexTable build() { - add_lex_state_for_parse_state(&parse_table->error_state); - for (ParseState &parse_state : parse_table->states) add_lex_state_for_parse_state(&parse_state); diff --git a/src/compiler/build_tables/build_parse_table.cc b/src/compiler/build_tables/build_parse_table.cc index 48ac26b9..3ef8e6ea 100644 --- a/src/compiler/build_tables/build_parse_table.cc +++ b/src/compiler/build_tables/build_parse_table.cc @@ -56,6 +56,9 @@ class ParseTableBuilder { ProductionStep(start_symbol, 0, rules::AssociativityNone), }); + // Placeholder for error state + add_parse_state(ParseItemSet()); + add_parse_state(ParseItemSet({ { ParseItem(rules::START(), start_production, 0), @@ -67,7 +70,7 @@ class ParseTableBuilder { if (error.type != TSCompileErrorTypeNone) return { parse_table, error }; - add_out_of_context_parse_states(); + build_error_parse_state(); allow_any_conflict = true; process_part_state_queue(); @@ -104,31 +107,35 @@ class ParseTableBuilder { return CompileError::none(); } - void add_out_of_context_parse_states() { + void build_error_parse_state() { + ParseState error_state; + for (const Symbol &symbol : recovery_tokens(lexical_grammar)) { - add_out_of_context_parse_state(symbol); + add_out_of_context_parse_state(&error_state, symbol); } for (const Symbol &symbol : grammar.extra_tokens) { - parse_table.error_state.entries[symbol].actions.push_back( - ParseAction::ShiftExtra()); + error_state.entries[symbol].actions.push_back(ParseAction::ShiftExtra()); } for (size_t i = 0; i < grammar.variables.size(); i++) { Symbol symbol(i, false); - add_out_of_context_parse_state(symbol); + add_out_of_context_parse_state(&error_state, symbol); } - parse_table.error_state.entries[rules::END_OF_INPUT()].actions.push_back( - ParseAction::Shift(0, PrecedenceRange())); + error_state.entries[rules::END_OF_INPUT()].actions.push_back( + ParseAction::Recover(0)); + + parse_table.states[0] = error_state; } - void add_out_of_context_parse_state(const rules::Symbol &symbol) { + void add_out_of_context_parse_state(ParseState *error_state, + const rules::Symbol &symbol) { const ParseItemSet &item_set = recovery_states[symbol]; if (!item_set.entries.empty()) { ParseStateId state = add_parse_state(item_set); - parse_table.error_state.entries[symbol].actions.push_back( - ParseAction::Shift(state, PrecedenceRange())); + error_state->entries[symbol].actions.push_back( + ParseAction::Recover(state)); } } @@ -171,11 +178,7 @@ class ParseTableBuilder { if (status.is_done) { ParseAction action; if (item.lhs() == rules::START()) { - if (state_id == 1) { - action = ParseAction::Accept(); - } else { - continue; - } + action = ParseAction::Accept(); } else { action = ParseAction::Reduce(Symbol(item.variable_index), item.step_index, status.precedence, @@ -265,15 +268,7 @@ class ParseTableBuilder { } void remove_duplicate_parse_states() { - auto replacements = - remove_duplicate_states(&parse_table.states); - - parse_table.error_state.each_advance_action( - [&replacements](ParseAction *action) { - auto replacement = replacements.find(action->state_index); - if (replacement != replacements.end()) - action->state_index = replacement->second; - }); + remove_duplicate_states(&parse_table.states); } ParseAction *add_action(ParseStateId state_id, Symbol lookahead, diff --git a/src/compiler/generate_code/c_code.cc b/src/compiler/generate_code/c_code.cc index 101947ca..af664fa4 100644 --- a/src/compiler/generate_code/c_code.cc +++ b/src/compiler/generate_code/c_code.cc @@ -100,7 +100,6 @@ class CCodeGenerator { add_symbol_node_types_list(); add_lex_function(); add_lex_states_list(); - add_recovery_parse_states_list(); add_parse_table(); add_parser_export(); @@ -211,27 +210,6 @@ class CCodeGenerator { line(); } - void add_recovery_parse_states_list() { - line("static TSParseAction ts_recovery_actions[SYMBOL_COUNT] = {"); - indent([&]() { - for (const auto &pair : parse_table.symbols) { - const rules::Symbol &symbol = pair.first; - line("[" + symbol_id(pair.first) + "] = "); - const auto &entry = parse_table.error_state.entries.find(symbol); - if (entry != parse_table.error_state.entries.end()) { - ParseAction action = entry->second.actions[0]; - if (!action.extra) { - add("RECOVER(" + to_string(action.state_index) + "),"); - continue; - } - } - add("RECOVER_EXTRA(),"); - } - }); - line("};"); - line(); - } - void add_parse_table() { add_parse_action_list_id(ParseTableEntry{ {}, true, false }); @@ -372,6 +350,9 @@ class CCodeGenerator { to_string(action.consumed_symbol_count) + ")"); } break; + case ParseActionTypeRecover: + add("RECOVER(" + to_string(action.state_index) + ")"); + break; default: {} } add(","); diff --git a/src/compiler/parse_table.cc b/src/compiler/parse_table.cc index af917030..cd2a549a 100644 --- a/src/compiler/parse_table.cc +++ b/src/compiler/parse_table.cc @@ -53,6 +53,11 @@ ParseAction ParseAction::Shift(ParseStateId state_index, precedence_range, rules::AssociativityNone, nullptr); } +ParseAction ParseAction::Recover(ParseStateId state_index) { + return ParseAction(ParseActionTypeRecover, state_index, Symbol(-1), 0, + PrecedenceRange(), rules::AssociativityNone, nullptr); +} + ParseAction ParseAction::ShiftExtra() { ParseAction action; action.type = ParseActionTypeShift; @@ -138,7 +143,7 @@ set ParseState::expected_inputs() const { void ParseState::each_advance_action(function fn) { for (auto &entry : entries) for (ParseAction &action : entry.second.actions) - if (action.type == ParseActionTypeShift) + if (action.type == ParseActionTypeShift || ParseActionTypeRecover) fn(&action); } diff --git a/src/compiler/parse_table.h b/src/compiler/parse_table.h index 6b1a5d9b..4ce13bd5 100644 --- a/src/compiler/parse_table.h +++ b/src/compiler/parse_table.h @@ -20,6 +20,7 @@ enum ParseActionType { ParseActionTypeShift, ParseActionTypeReduce, ParseActionTypeAccept, + ParseActionTypeRecover, }; class ParseAction { @@ -32,6 +33,7 @@ class ParseAction { static ParseAction Accept(); static ParseAction Error(); static ParseAction Shift(ParseStateId state_index, PrecedenceRange precedence); + static ParseAction Recover(ParseStateId state_index); static ParseAction Reduce(rules::Symbol symbol, size_t consumed_symbol_count, int precedence, rules::Associativity, const Production &); @@ -87,7 +89,6 @@ class ParseTable { ParseAction action); std::vector states; - ParseState error_state; std::map symbols; }; diff --git a/src/runtime/language.c b/src/runtime/language.c index b117945b..8ab122ae 100644 --- a/src/runtime/language.c +++ b/src/runtime/language.c @@ -8,21 +8,20 @@ static const TSParseAction ERROR_SHIFT_EXTRA = { void ts_language_table_entry(const TSLanguage *self, TSStateId state, TSSymbol symbol, TableEntry *result) { - if (state == ts_parse_state_error) { - result->action_count = 1; - result->is_reusable = false; - result->depends_on_lookahead = false; - result->actions = (symbol == ts_builtin_sym_error) - ? &ERROR_SHIFT_EXTRA - : &self->recovery_actions[symbol]; - return; + size_t action_index; + if (symbol == ts_builtin_sym_error) { + if (state == TS_STATE_ERROR) { + result->action_count = 1; + result->is_reusable = false; + result->depends_on_lookahead = false; + result->actions = &ERROR_SHIFT_EXTRA; + return; + } + action_index = 0; + } else { + action_index = self->parse_table[state * self->symbol_count + symbol]; } - size_t action_index = - (symbol != ts_builtin_sym_error) - ? self->parse_table[state * self->symbol_count + symbol] - : 0; - const TSParseActionEntry *entry = &self->parse_actions[action_index]; result->action_count = entry->count; result->is_reusable = entry->reusable; diff --git a/src/runtime/language.h b/src/runtime/language.h index 8485143d..3941d875 100644 --- a/src/runtime/language.h +++ b/src/runtime/language.h @@ -49,11 +49,6 @@ static inline bool ts_language_is_reusable(const TSLanguage *self, TSSymbolMetadata ts_language_symbol_metadata(const TSLanguage *, TSSymbol); -static inline TSStateId ts_language_lex_state(const TSLanguage *self, - TSStateId state) { - return state == ts_parse_state_error ? 0 : self->lex_states[state]; -} - #ifdef __cplusplus } #endif diff --git a/src/runtime/parser.c b/src/runtime/parser.c index fd8b2507..4107ed65 100644 --- a/src/runtime/parser.c +++ b/src/runtime/parser.c @@ -119,7 +119,7 @@ static BreakdownResult ts_parser__breakdown_top_of_stack(TSParser *self, pending = child->child_count > 0; if (child->symbol == ts_builtin_sym_error) { - state = ts_parse_state_error; + state = TS_STATE_ERROR; } else if (!child->extra) { const TSParseAction *action = ts_language_last_action(self->language, state, child->symbol); @@ -235,7 +235,7 @@ static bool ts_parser__can_reuse(TSParser *self, StackVersion version, return false; } - TSStateId lex_state = ts_language_lex_state(self->language, state); + TSStateId lex_state = self->language->lex_states[state]; if (tree->first_leaf.lex_state != lex_state) { if (tree->child_count > 0) { TableEntry leaf_entry; @@ -266,7 +266,7 @@ static bool ts_parser__can_reuse(TSParser *self, StackVersion version, static TSTree *ts_parser__lex(TSParser *self, TSStateId parse_state, bool error_mode) { - TSStateId state = error_mode ? 0 : self->language->lex_states[parse_state]; + TSStateId state = self->language->lex_states[parse_state]; LOG("lex state:%d", state); TSLength position = self->lexer.current_position; @@ -275,7 +275,7 @@ static TSTree *ts_parser__lex(TSParser *self, TSStateId parse_state, if (!self->language->lex_fn(&self->lexer, state, error_mode)) { ts_lexer_reset(&self->lexer, position); ts_lexer_start(&self->lexer, state); - assert(self->language->lex_fn(&self->lexer, 0, true)); + assert(self->language->lex_fn(&self->lexer, TS_STATE_ERROR, true)); } TSLexerResult lex_result; @@ -332,7 +332,7 @@ static TSTree *ts_parser__get_lookahead(TSParser *self, StackVersion version, ts_lexer_reset(&self->lexer, position); TSStateId parse_state = ts_stack_top_state(self->stack, version); - bool error_mode = parse_state == ts_parse_state_error; + bool error_mode = parse_state == TS_STATE_ERROR; return ts_parser__lex(self, parse_state, error_mode); error: @@ -518,18 +518,20 @@ static Reduction ts_parser__reduce(TSParser *self, StackVersion version, if (action->type == TSParseActionTypeRecover && allow_skipping) { unsigned error_depth = ts_stack_error_depth(self->stack, slice.version); - unsigned error_cost = ts_stack_error_cost(self->stack, slice.version) + 1; - if (!ts_parser__better_version_exists(self, slice.version, error_depth, error_cost)) { + unsigned error_cost = + ts_stack_error_cost(self->stack, slice.version) + 1; + if (!ts_parser__better_version_exists(self, slice.version, error_depth, + error_cost)) { StackVersion other_version = ts_stack_duplicate_version(self->stack, slice.version); CHECK(other_version != STACK_VERSION_NONE); CHECK(ts_stack_push(self->stack, other_version, parent, false, - ts_parse_state_error)); + TS_STATE_ERROR)); for (size_t j = parent->child_count; j < slice.trees.size; j++) { TSTree *tree = slice.trees.contents[j]; CHECK(ts_stack_push(self->stack, other_version, tree, false, - ts_parse_state_error)); + TS_STATE_ERROR)); } for (StackVersion v = version + 1; v < initial_version_count; v++) @@ -896,10 +898,10 @@ static bool ts_parser__handle_error(TSParser *self, StackVersion version, if (did_reduce && !has_shift_action) ts_stack_renumber_version(self->stack, previous_version_count, version); - CHECK(ts_stack_push(self->stack, version, NULL, false, ts_parse_state_error)); + CHECK(ts_stack_push(self->stack, version, NULL, false, TS_STATE_ERROR)); while (ts_stack_version_count(self->stack) > previous_version_count) { CHECK(ts_stack_push(self->stack, previous_version_count, NULL, false, - ts_parse_state_error)); + TS_STATE_ERROR)); assert(ts_stack_merge(self->stack, version, previous_version_count)); } @@ -932,7 +934,7 @@ static bool ts_parser__recover(TSParser *self, StackVersion version, StackVersion new_version = ts_stack_duplicate_version(self->stack, version); CHECK(new_version != STACK_VERSION_NONE); CHECK(ts_parser__shift( - self, new_version, ts_parse_state_error, lookahead, + self, new_version, TS_STATE_ERROR, lookahead, ts_language_symbol_metadata(self->language, lookahead->symbol).extra)); CHECK(ts_parser__shift(self, version, state, lookahead, false)); diff --git a/src/runtime/stack.c b/src/runtime/stack.c index 07eb0676..078699bc 100644 --- a/src/runtime/stack.c +++ b/src/runtime/stack.c @@ -112,7 +112,7 @@ static StackNode *stack_node_new(StackNode *next, TSTree *tree, bool is_pending, if (tree) { ts_tree_retain(tree); - if (state == ts_parse_state_error) { + if (state == TS_STATE_ERROR) { if (!tree->extra) { node->error_cost++; } @@ -285,7 +285,7 @@ Stack *ts_stack_new() { goto error; self->base_node = - stack_node_new(NULL, NULL, false, 0, ts_length_zero(), &self->node_pool); + stack_node_new(NULL, NULL, false, 1, ts_length_zero(), &self->node_pool); stack_node_retain(self->base_node); if (!self->base_node) goto error; @@ -377,7 +377,7 @@ INLINE StackIterateAction pop_count_callback(void *payload, TSStateId state, return StackIteratePop | StackIterateStop; } - if (state == ts_parse_state_error) { + if (state == TS_STATE_ERROR) { if (pop_session->found_valid_path || pop_session->found_error) { return StackIterateStop; } else { @@ -544,7 +544,7 @@ bool ts_stack_print_dot_graph(Stack *self, const char **symbol_names, FILE *f) { all_paths_done = false; fprintf(f, "node_%p [", node); - if (node->state == ts_parse_state_error) + if (node->state == TS_STATE_ERROR) fprintf(f, "label=\"?\""); else if (node->link_count == 1 && node->links[0].tree && node->links[0].tree->extra)