diff --git a/spec/compiler/build_tables/distinctive_tokens_spec.cc b/spec/compiler/build_tables/distinctive_tokens_spec.cc index 1c9d8794..c5d197b3 100644 --- a/spec/compiler/build_tables/distinctive_tokens_spec.cc +++ b/spec/compiler/build_tables/distinctive_tokens_spec.cc @@ -27,7 +27,7 @@ describe("recovery_tokens(rule)", []() { })), }; - AssertThat(recovery_tokens(grammar), Equals>({ + AssertThat(recovery_tokens(grammar), Equals>({ Symbol(1, true), })); }); diff --git a/src/compiler/build_tables/build_parse_table.cc b/src/compiler/build_tables/build_parse_table.cc index 93820424..e0af2c4e 100644 --- a/src/compiler/build_tables/build_parse_table.cc +++ b/src/compiler/build_tables/build_parse_table.cc @@ -75,6 +75,8 @@ class ParseTableBuilder { for (const auto &pair2 : state.entries) parse_table.symbols[pair1.first].compatible_symbols.insert(pair2.first); + parse_table.mergeable_symbols = recovery_tokens(lexical_grammar); + build_error_parse_state(); allow_any_conflict = true; @@ -112,7 +114,7 @@ class ParseTableBuilder { void build_error_parse_state() { ParseState error_state; - for (const Symbol &symbol : recovery_tokens(lexical_grammar)) + for (auto &symbol : parse_table.mergeable_symbols) add_out_of_context_parse_state(&error_state, symbol); for (const Symbol &symbol : grammar.extra_tokens) diff --git a/src/compiler/build_tables/recovery_tokens.cc b/src/compiler/build_tables/recovery_tokens.cc index 0aacb7c3..e8d96aad 100644 --- a/src/compiler/build_tables/recovery_tokens.cc +++ b/src/compiler/build_tables/recovery_tokens.cc @@ -11,7 +11,7 @@ namespace tree_sitter { namespace build_tables { using rules::Symbol; -using std::vector; +using std::set; template class CharacterAggregator : public rules::RuleFn { @@ -47,8 +47,8 @@ class FirstCharacters : public CharacterAggregator {}; class LastCharacters : public CharacterAggregator {}; class AllCharacters : public CharacterAggregator {}; -vector recovery_tokens(const LexicalGrammar &grammar) { - vector result; +set recovery_tokens(const LexicalGrammar &grammar) { + set result; AllCharacters all_separator_characters; for (const rule_ptr &separator : grammar.separators) @@ -79,7 +79,7 @@ vector recovery_tokens(const LexicalGrammar &grammar) { !all_characters.result.intersects(all_separator_characters.result); if ((has_distinct_start && has_distinct_end) || has_no_separators) - result.push_back(Symbol(i, true)); + result.insert(Symbol(i, true)); } return result; diff --git a/src/compiler/build_tables/recovery_tokens.h b/src/compiler/build_tables/recovery_tokens.h index db477d76..c97a8cfd 100644 --- a/src/compiler/build_tables/recovery_tokens.h +++ b/src/compiler/build_tables/recovery_tokens.h @@ -3,7 +3,7 @@ #include "compiler/rule.h" #include "compiler/rules/symbol.h" -#include +#include namespace tree_sitter { @@ -11,7 +11,7 @@ struct LexicalGrammar; namespace build_tables { -std::vector recovery_tokens(const LexicalGrammar &); +std::set recovery_tokens(const LexicalGrammar &); } // namespace build_tables } // namespace tree_sitter diff --git a/src/compiler/parse_table.cc b/src/compiler/parse_table.cc index d345f0e4..b63d472b 100644 --- a/src/compiler/parse_table.cc +++ b/src/compiler/parse_table.cc @@ -206,6 +206,8 @@ bool ParseTable::merge_state(size_t i, size_t j) { const auto &other_entry = other.entries.find(symbol); if (other_entry == other.entries.end()) { + if (mergeable_symbols.count(symbol) == 0) + return false; if (actions.back().type != ParseActionTypeReduce) return false; if (!has_entry(other, entry.second)) @@ -222,6 +224,8 @@ bool ParseTable::merge_state(size_t i, size_t j) { const vector &actions = entry.second.actions; if (!state.entries.count(symbol)) { + if (mergeable_symbols.count(symbol) == 0) + return false; if (actions.back().type != ParseActionTypeReduce) return false; if (!has_entry(state, entry.second)) diff --git a/src/compiler/parse_table.h b/src/compiler/parse_table.h index 1a00b273..cd24d32a 100644 --- a/src/compiler/parse_table.h +++ b/src/compiler/parse_table.h @@ -97,6 +97,8 @@ class ParseTable { std::vector states; std::map symbols; + + std::set mergeable_symbols; }; } // namespace tree_sitter