diff --git a/src/compiler/build_tables/build_lex_table.cc b/src/compiler/build_tables/build_lex_table.cc index cea7fc44..c4abe367 100644 --- a/src/compiler/build_tables/build_lex_table.cc +++ b/src/compiler/build_tables/build_lex_table.cc @@ -9,7 +9,6 @@ #include "compiler/build_tables/lex_conflict_manager.h" #include "compiler/build_tables/remove_duplicate_states.h" #include "compiler/build_tables/lex_item.h" -#include "compiler/build_tables/does_match_any_line.h" #include "compiler/parse_table.h" #include "compiler/lexical_grammar.h" #include "compiler/rules/built_in_symbols.h" @@ -48,11 +47,10 @@ class LexTableBuilder { } LexTable build() { - add_lex_state(build_lex_item_set(parse_table->all_symbols(), true)); + add_lex_state_for_parse_state(&parse_table->error_state); for (ParseState &parse_state : parse_table->states) - parse_state.lex_state_id = - add_lex_state(build_lex_item_set(parse_state.expected_inputs(), false)); + add_lex_state_for_parse_state(&parse_state); mark_fragile_tokens(); remove_duplicate_lex_states(); @@ -61,7 +59,7 @@ class LexTableBuilder { } private: - LexItemSet build_lex_item_set(const set &symbols, bool error) { + LexItemSet build_lex_item_set(const set &symbols) { LexItemSet result; for (const Symbol &symbol : symbols) { vector rules; @@ -69,8 +67,6 @@ class LexTableBuilder { rules.push_back(CharacterSet().include(0).copy()); } else if (symbol.is_token) { rule_ptr rule = lex_grammar.variables[symbol.index].rule; - if (error && does_match_any_line(rule)) - continue; auto choice = rule->as(); if (choice) @@ -98,6 +94,11 @@ class LexTableBuilder { return result; } + void add_lex_state_for_parse_state(ParseState *parse_state) { + parse_state->lex_state_id = + add_lex_state(build_lex_item_set(parse_state->expected_inputs())); + } + LexStateId add_lex_state(const LexItemSet &item_set) { const auto &pair = lex_state_ids.find(item_set); if (pair == lex_state_ids.end()) { diff --git a/src/compiler/build_tables/build_parse_table.cc b/src/compiler/build_tables/build_parse_table.cc index 9bd0b8fb..ff3d945c 100644 --- a/src/compiler/build_tables/build_parse_table.cc +++ b/src/compiler/build_tables/build_parse_table.cc @@ -15,6 +15,7 @@ #include "compiler/syntax_grammar.h" #include "compiler/rules/symbol.h" #include "compiler/rules/built_in_symbols.h" +#include "compiler/build_tables/does_match_any_line.h" namespace tree_sitter { namespace build_tables { @@ -105,17 +106,20 @@ class ParseTableBuilder { } void add_out_of_context_parse_states() { - map> symbols_by_first = symbols_by_first_symbol(grammar); + auto symbols_by_first = symbols_by_first_symbol(grammar); + for (size_t i = 0; i < lexical_grammar.variables.size(); i++) { Symbol symbol(i, true); - if (!grammar.extra_tokens.count(symbol)) + if (!does_match_any_line(lexical_grammar.variables[i].rule)) add_out_of_context_parse_state(symbol, symbols_by_first[symbol]); } for (size_t i = 0; i < grammar.variables.size(); i++) { Symbol symbol(i, false); - add_out_of_context_parse_state(Symbol(i, false), symbols_by_first[symbol]); + add_out_of_context_parse_state(symbol, symbols_by_first[symbol]); } + + parse_table.error_state.actions[rules::END_OF_INPUT()].clear(); } void add_out_of_context_parse_state(const rules::Symbol &symbol, @@ -133,8 +137,11 @@ class ParseTableBuilder { } } - ParseStateId state = add_parse_state(item_set); - parse_table.out_of_context_state_indices[symbol] = state; + if (!item_set.entries.empty()) { + ParseStateId state = add_parse_state(item_set); + parse_table.error_state.actions[symbol].push_back( + ParseAction::Shift(state, PrecedenceRange())); + } } ParseStateId add_parse_state(const ParseItemSet &item_set) { @@ -265,11 +272,12 @@ class ParseTableBuilder { auto replacements = remove_duplicate_states(&parse_table.states); - for (auto &pair : parse_table.out_of_context_state_indices) { - auto replacement = replacements.find(pair.second); - if (replacement != replacements.end()) - pair.second = replacement->second; - } + parse_table.error_state.each_advance_action( + [&replacements](ParseAction *action) { + auto replacement = replacements.find(action->state_index); + if (replacement != replacements.end()) + action->state_index = replacement->second; + }); } ParseAction *add_action(ParseStateId state_id, Symbol lookahead, diff --git a/src/compiler/build_tables/remove_duplicate_states.h b/src/compiler/build_tables/remove_duplicate_states.h index 2a6a9cdb..b70bb351 100644 --- a/src/compiler/build_tables/remove_duplicate_states.h +++ b/src/compiler/build_tables/remove_duplicate_states.h @@ -46,12 +46,11 @@ std::map remove_duplicate_states(std::vector *states) } for (StateType &state : *states) - state.each_advance_action( - [&duplicates, &new_replacements](ActionType *action) { - auto new_replacement = new_replacements.find(action->state_index); - if (new_replacement != new_replacements.end()) - action->state_index = new_replacement->second; - }); + state.each_advance_action([&new_replacements](ActionType *action) { + auto new_replacement = new_replacements.find(action->state_index); + if (new_replacement != new_replacements.end()) + action->state_index = new_replacement->second; + }); for (auto i = duplicates.rbegin(); i != duplicates.rend(); ++i) states->erase(states->begin() + i->first); diff --git a/src/compiler/generate_code/c_code.cc b/src/compiler/generate_code/c_code.cc index f407f7ff..ed6bd5ff 100644 --- a/src/compiler/generate_code/c_code.cc +++ b/src/compiler/generate_code/c_code.cc @@ -223,15 +223,12 @@ class CCodeGenerator { void add_out_of_context_parse_states_list() { line("static TSStateId ts_out_of_context_states[SYMBOL_COUNT] = {"); indent([&]() { - for (const auto &entry : parse_table.symbols) { + for (const auto &entry : parse_table.error_state.actions) { const rules::Symbol &symbol = entry.first; - if (symbol.is_built_in()) - continue; - auto iter = parse_table.out_of_context_state_indices.find(symbol); - string state = (iter != parse_table.out_of_context_state_indices.end()) - ? to_string(iter->second) - : "ts_parse_state_error"; - line("[" + symbol_id(symbol) + "] = " + state + ","); + if (!entry.second.empty()) { + ParseStateId state = entry.second[0].state_index; + line("[" + symbol_id(symbol) + "] = " + to_string(state) + ","); + } } }); line("};"); diff --git a/src/compiler/parse_table.h b/src/compiler/parse_table.h index 99232707..a31ca916 100644 --- a/src/compiler/parse_table.h +++ b/src/compiler/parse_table.h @@ -102,8 +102,8 @@ class ParseTable { ParseAction action); std::vector states; + ParseState error_state; std::map symbols; - std::map out_of_context_state_indices; }; } // namespace tree_sitter diff --git a/src/runtime/parser.c b/src/runtime/parser.c index b7c0848d..304241db 100644 --- a/src/runtime/parser.c +++ b/src/runtime/parser.c @@ -264,11 +264,13 @@ static bool ts_parser__select_tree(TSParser *self, TSTree *left, TSTree *right) if (!right) return false; if (right->error_size < left->error_size) { - LOG_ACTION("select_smaller_error symbol:%s, over_symbol:%s", SYM_NAME(right->symbol), SYM_NAME(left->symbol)); + LOG_ACTION("select_smaller_error symbol:%s, over_symbol:%s", + SYM_NAME(right->symbol), SYM_NAME(left->symbol)); return true; } if (left->error_size < right->error_size) { - LOG_ACTION("select_smaller_error symbol:%s, over_symbol:%s", SYM_NAME(left->symbol), SYM_NAME(right->symbol)); + LOG_ACTION("select_smaller_error symbol:%s, over_symbol:%s", + SYM_NAME(left->symbol), SYM_NAME(right->symbol)); return false; } return ts_tree_compare(right, left) < 0; @@ -314,7 +316,7 @@ error: } static bool ts_parser__switch_children(TSParser *self, TSTree *tree, - TSTree **children, size_t count) { + TSTree **children, size_t count) { self->scratch_tree.symbol = tree->symbol; self->scratch_tree.child_count = 0; ts_tree_set_children(&self->scratch_tree, count, children); @@ -534,8 +536,8 @@ static RepairResult ts_parser__repair_error(TSParser *self, StackSlice slice, LOG_ACTION( "repair_found sym:%s, child_count:%lu, match_count:%lu, skipped:%lu", - SYM_NAME(symbol), repair.count_below_error + count_above_error, repair.in_progress_state_count, - skip_count); + SYM_NAME(symbol), repair.count_below_error + count_above_error, + repair.in_progress_state_count, skip_count); if (skip_count > 0) { TreeArray skipped_children = array_new(); diff --git a/src/runtime/stack.c b/src/runtime/stack.c index 56af6de3..ef306f82 100644 --- a/src/runtime/stack.c +++ b/src/runtime/stack.c @@ -134,9 +134,9 @@ static StackVersion ts_stack__add_version(Stack *self, StackNode *node) { static bool ts_stack__add_slice(Stack *self, StackNode *node, TreeArray *trees) { for (size_t i = self->slices.size - 1; i + 1 > 0; i--) { - StackVersion version = self->slices.contents[i].version; + StackVersion version = self->slices.contents[i].version; if (self->heads.contents[version] == node) { - StackSlice slice = {*trees, version}; + StackSlice slice = { *trees, version }; return array_insert(&self->slices, i + 1, slice); } } @@ -144,7 +144,7 @@ static bool ts_stack__add_slice(Stack *self, StackNode *node, TreeArray *trees) StackVersion version = ts_stack__add_version(self, node); if (version == STACK_VERSION_NONE) return false; - StackSlice slice = {*trees, version}; + StackSlice slice = { *trees, version }; return array_push(&self->slices, slice); } @@ -442,7 +442,10 @@ int ts_stack_print_dot_graph(Stack *self, const char **symbol_names, FILE *f) { for (size_t i = 0; i < self->heads.size; i++) { StackNode *node = self->heads.contents[i]; fprintf(f, "node_head_%lu [shape=none, label=\"\"]\n", i); - fprintf(f, "node_head_%lu -> node_%p [label=%lu, arrowhead=none, fontcolor=blue, weight=10000]\n", i, node, i); + fprintf(f, + "node_head_%lu -> node_%p [label=%lu, arrowhead=none, " + "fontcolor=blue, weight=10000]\n", + i, node, i); array_push(&self->pop_paths, ((PopPath){.node = node })); }