diff --git a/src/compiler/build_tables/build_parse_table.cc b/src/compiler/build_tables/build_parse_table.cc index e6a93099..bcb602d5 100644 --- a/src/compiler/build_tables/build_parse_table.cc +++ b/src/compiler/build_tables/build_parse_table.cc @@ -50,6 +50,7 @@ class ParseTableBuilder { ParseItemSetBuilder item_set_builder; set fragile_productions; vector> incompatible_tokens_by_index; + vector> following_terminals_by_terminal_index; bool processing_recovery_states; public: @@ -57,6 +58,8 @@ class ParseTableBuilder { : grammar(grammar), lexical_grammar(lex_grammar), item_set_builder(grammar, lex_grammar), + incompatible_tokens_by_index(lexical_grammar.variables.size()), + following_terminals_by_terminal_index(lexical_grammar.variables.size()), processing_recovery_states(false) {} pair build() { @@ -314,8 +317,6 @@ class ParseTableBuilder { } void compute_unmergable_token_pairs() { - incompatible_tokens_by_index.resize(lexical_grammar.variables.size()); - auto lex_table_builder = LexTableBuilder::create(lexical_grammar); for (unsigned i = 0, n = lexical_grammar.variables.size(); i < n; i++) { Symbol token = Symbol::terminal(i); @@ -323,7 +324,7 @@ class ParseTableBuilder { for (unsigned j = 0; j < n; j++) { if (i == j) continue; - if (lex_table_builder->detect_conflict(i, j)) { + if (lex_table_builder->detect_conflict(i, j, following_terminals_by_terminal_index)) { incompatible_indices.insert(Symbol::terminal(j)); } } @@ -690,6 +691,23 @@ class ParseTableBuilder { } SymbolSequence append_symbol(const SymbolSequence &sequence, const Symbol &symbol) { + if (!sequence.empty()) { + const LookaheadSet &left_tokens = item_set_builder.get_last_set(sequence.back()); + const LookaheadSet &right_tokens = item_set_builder.get_first_set(symbol); + + if (!left_tokens.empty() && !right_tokens.empty()) { + for (const Symbol &left_symbol : *left_tokens.entries) { + if (left_symbol.is_terminal() && !left_symbol.is_built_in()) { + for (const Symbol &right_symbol : *right_tokens.entries) { + if (right_symbol.is_terminal() && !right_symbol.is_built_in()) { + following_terminals_by_terminal_index[left_symbol.index].insert(right_symbol.index); + } + } + } + } + } + } + SymbolSequence result(sequence.size() + 1); result.assign(sequence.begin(), sequence.end()); result.push_back(symbol); diff --git a/src/compiler/build_tables/lex_table_builder.cc b/src/compiler/build_tables/lex_table_builder.cc index 4101f854..4f59a374 100644 --- a/src/compiler/build_tables/lex_table_builder.cc +++ b/src/compiler/build_tables/lex_table_builder.cc @@ -5,6 +5,7 @@ #include #include #include +#include #include #include "compiler/build_tables/lex_conflict_manager.h" #include "compiler/build_tables/lex_item.h" @@ -70,14 +71,16 @@ class LexTableBuilderImpl : public LexTableBuilder { LexTable lex_table; const LexicalGrammar grammar; vector separator_rules; - CharacterSet separator_start_characters; - CharacterSet token_start_characters; LexConflictManager conflict_manager; unordered_map lex_state_ids; - public: - vector shadowed_token_indices; + map following_characters_by_token_index; + CharacterSet separator_start_characters; + CharacterSet current_conflict_detection_following_characters; + Symbol::Index current_conflict_detection_token_index; + bool current_conflict_value; + public: LexTableBuilderImpl(const LexicalGrammar &grammar) : grammar(grammar) { StartingCharacterAggregator separator_character_aggregator; for (const auto &rule : grammar.separators) { @@ -86,20 +89,6 @@ class LexTableBuilderImpl : public LexTableBuilder { } separator_rules.push_back(Blank{}); separator_start_characters = separator_character_aggregator.result; - - StartingCharacterAggregator token_start_character_aggregator; - for (const auto &variable : grammar.variables) { - token_start_character_aggregator.apply(variable.rule); - } - token_start_characters = token_start_character_aggregator.result; - token_start_characters - .exclude('a', 'z') - .exclude('A', 'Z') - .exclude('0', '9') - .exclude('_') - .exclude('$'); - - shadowed_token_indices.resize(grammar.variables.size()); } LexTable build(ParseTable *parse_table) { @@ -113,7 +102,10 @@ class LexTableBuilderImpl : public LexTableBuilder { return lex_table; } - bool detect_conflict(Symbol::Index left, Symbol::Index right) { + bool detect_conflict(Symbol::Index left, Symbol::Index right, + const vector> &following_terminals_by_terminal_index) { + clear(); + StartingCharacterAggregator left_starting_characters; StartingCharacterAggregator right_starting_characters; left_starting_characters.apply(grammar.variables[left].rule); @@ -124,12 +116,47 @@ class LexTableBuilderImpl : public LexTableBuilder { return false; } - clear(); - map terminals; - terminals[Symbol::terminal(left)]; - terminals[Symbol::terminal(right)]; - add_lex_state(item_set_for_terminals(terminals)); - return shadowed_token_indices[right]; + auto following_characters_entry = following_characters_by_token_index.find(right); + if (following_characters_entry == following_characters_by_token_index.end()) { + StartingCharacterAggregator aggregator; + for (auto following_token_index : following_terminals_by_terminal_index[right]) { + aggregator.apply(grammar.variables[following_token_index].rule); + } + following_characters_entry = + following_characters_by_token_index.insert({right, aggregator.result}).first; + + // TODO - Refactor this. In general, a keyword token cannot be followed immediately by + // another alphanumeric character. But this requirement is currently not expressed anywhere in + // the grammar. So without this hack, we would be overly conservative about merging parse + // states because we would often consider `identifier` tokens to *conflict* with keyword + // tokens. + if (is_keyword(grammar.variables[right])) { + following_characters_entry->second + .exclude('a', 'z') + .exclude('A', 'Z') + .exclude('0', '9') + .exclude('_') + .exclude('$'); + } + } + + current_conflict_detection_token_index = right; + current_conflict_detection_following_characters = following_characters_entry->second; + add_lex_state(item_set_for_terminals({{Symbol::terminal(left), {}}, {Symbol::terminal(right), {}}})); + return current_conflict_value; + } + + bool is_keyword(const LexicalVariable &variable) { + return variable.is_string && iswalpha(get_last_character(variable.rule)); + } + + static uint32_t get_last_character(const Rule &rule) { + return rule.match( + [](const Seq &sequence) { return get_last_character(*sequence.right); }, + [](const rules::CharacterSet &rule) { return *rule.included_chars.begin(); }, + [](const rules::Metadata &rule) { return get_last_character(*rule.rule); }, + [](auto) { return 0; } + ); } LexStateId add_lex_state(const LexItemSet &item_set) { @@ -149,7 +176,8 @@ class LexTableBuilderImpl : public LexTableBuilder { void clear() { lex_table.states.clear(); lex_state_ids.clear(); - shadowed_token_indices.assign(grammar.variables.size(), false); + current_conflict_detection_following_characters = CharacterSet(); + current_conflict_value = false; } private: @@ -166,17 +194,18 @@ class LexTableBuilderImpl : public LexTableBuilder { for (const LexItem &item : transition.destination.entries) { if (item.lhs == accept_action.symbol) { can_advance_for_accepted_token = true; - } else if (!prefer_advancing && !transition.in_main_token && !item.lhs.is_built_in()) { - shadowed_token_indices[item.lhs.index] = true; + } else if (item.lhs.index == current_conflict_detection_token_index && + !prefer_advancing && !transition.in_main_token) { + current_conflict_value = true; } } - if (!can_advance_for_accepted_token) { - if (characters.intersects(separator_start_characters) || - (grammar.variables[accept_action.symbol.index].is_string && - characters.intersects(token_start_characters))) { - shadowed_token_indices[accept_action.symbol.index] = true; - } + if (accept_action.symbol.index == current_conflict_detection_token_index && + !can_advance_for_accepted_token && + (characters.intersects(separator_start_characters) || + (characters.intersects(current_conflict_detection_following_characters) && + grammar.variables[accept_action.symbol.index].is_string))) { + current_conflict_value = true; } if (!prefer_advancing) continue; @@ -346,8 +375,9 @@ LexTable LexTableBuilder::build(ParseTable *parse_table) { return static_cast(this)->build(parse_table); } -bool LexTableBuilder::detect_conflict(Symbol::Index left, Symbol::Index right) { - return static_cast(this)->detect_conflict(left, right); +bool LexTableBuilder::detect_conflict(Symbol::Index left, Symbol::Index right, + const vector> &following_terminals) { + return static_cast(this)->detect_conflict(left, right, following_terminals); } } // namespace build_tables diff --git a/src/compiler/build_tables/lex_table_builder.h b/src/compiler/build_tables/lex_table_builder.h index 91f24f70..3b896bb7 100644 --- a/src/compiler/build_tables/lex_table_builder.h +++ b/src/compiler/build_tables/lex_table_builder.h @@ -2,6 +2,8 @@ #define COMPILER_BUILD_TABLES_LEX_TABLE_BUILDER_H_ #include +#include +#include #include "compiler/lex_table.h" namespace tree_sitter { @@ -15,7 +17,11 @@ class LexTableBuilder { public: static std::unique_ptr create(const LexicalGrammar &); LexTable build(ParseTable *); - bool detect_conflict(rules::Symbol::Index, rules::Symbol::Index); + bool detect_conflict( + rules::Symbol::Index, + rules::Symbol::Index, + const std::vector> &following_terminals_by_terminal_index + ); protected: LexTableBuilder() = default; }; diff --git a/src/compiler/build_tables/parse_item_set_builder.cc b/src/compiler/build_tables/parse_item_set_builder.cc index 36c3942f..77fde864 100644 --- a/src/compiler/build_tables/parse_item_set_builder.cc +++ b/src/compiler/build_tables/parse_item_set_builder.cc @@ -1,4 +1,5 @@ #include "compiler/build_tables/parse_item_set_builder.h" +#include #include #include #include @@ -26,18 +27,20 @@ ParseItemSetBuilder::ParseItemSetBuilder(const SyntaxGrammar &grammar, for (size_t i = 0, n = lexical_grammar.variables.size(); i < n; i++) { Symbol symbol = Symbol::terminal(i); - first_sets.insert({symbol, LookaheadSet({ symbol })}); + first_sets.insert({symbol, LookaheadSet({symbol})}); + last_sets.insert({symbol, LookaheadSet({symbol})}); } for (size_t i = 0, n = grammar.external_tokens.size(); i < n; i++) { Symbol symbol = Symbol::external(i); - first_sets.insert({symbol, LookaheadSet({ symbol })}); + first_sets.insert({symbol, LookaheadSet({symbol})}); + last_sets.insert({symbol, LookaheadSet({symbol})}); } for (size_t i = 0, n = grammar.variables.size(); i < n; i++) { Symbol symbol = Symbol::non_terminal(i); - LookaheadSet first_set; + LookaheadSet first_set; processed_non_terminals.clear(); symbols_to_process.clear(); symbols_to_process.push_back(symbol); @@ -57,6 +60,26 @@ ParseItemSetBuilder::ParseItemSetBuilder(const SyntaxGrammar &grammar, } first_sets.insert({symbol, first_set}); + + LookaheadSet last_set; + processed_non_terminals.clear(); + symbols_to_process.clear(); + symbols_to_process.push_back(symbol); + while (!symbols_to_process.empty()) { + Symbol current_symbol = symbols_to_process.back(); + symbols_to_process.pop_back(); + + if (!current_symbol.is_non_terminal()) { + last_set.insert(current_symbol); + } else if (processed_non_terminals.insert(current_symbol.index).second) { + for (const Production &production : grammar.variables[current_symbol.index].productions) { + if (!production.empty()) { + symbols_to_process.push_back(production.back().symbol); + } + } + } + } + last_sets.insert({symbol, last_set}); } vector components_to_process; @@ -161,5 +184,9 @@ LookaheadSet ParseItemSetBuilder::get_first_set(const rules::Symbol &symbol) con return first_sets.find(symbol)->second; } +LookaheadSet ParseItemSetBuilder::get_last_set(const rules::Symbol &symbol) const { + return last_sets.find(symbol)->second; +} + } // namespace build_tables } // namespace tree_sitter diff --git a/src/compiler/build_tables/parse_item_set_builder.h b/src/compiler/build_tables/parse_item_set_builder.h index a319d698..b0334e68 100644 --- a/src/compiler/build_tables/parse_item_set_builder.h +++ b/src/compiler/build_tables/parse_item_set_builder.h @@ -20,6 +20,7 @@ class ParseItemSetBuilder { }; std::map first_sets; + std::map last_sets; std::map> component_cache; std::vector> item_set_buffer; @@ -27,6 +28,7 @@ class ParseItemSetBuilder { ParseItemSetBuilder(const SyntaxGrammar &, const LexicalGrammar &); void apply_transitive_closure(ParseItemSet *); LookaheadSet get_first_set(const rules::Symbol &) const; + LookaheadSet get_last_set(const rules::Symbol &) const; }; } // namespace build_tables diff --git a/test/compiler/build_tables/lex_table_builder_test.cc b/test/compiler/build_tables/lex_table_builder_test.cc index 7376bd2c..e9f70aee 100644 --- a/test/compiler/build_tables/lex_table_builder_test.cc +++ b/test/compiler/build_tables/lex_table_builder_test.cc @@ -16,7 +16,7 @@ describe("LexTableBuilder::detect_conflict", []() { auto builder = LexTableBuilder::create(LexicalGrammar{ { LexicalVariable{ - "token_1", + "token_0", VariableTypeNamed, Rule::seq({ CharacterSet({ 'a' }), @@ -26,7 +26,7 @@ describe("LexTableBuilder::detect_conflict", []() { false }, LexicalVariable{ - "token_2", + "token_1", VariableTypeNamed, Rule::seq({ CharacterSet({ 'b' }), @@ -39,22 +39,22 @@ describe("LexTableBuilder::detect_conflict", []() { separators }); - AssertThat(builder->detect_conflict(0, 1), IsFalse()); - AssertThat(builder->detect_conflict(1, 0), IsFalse()); + AssertThat(builder->detect_conflict(0, 1, {{}, {}}), IsFalse()); + AssertThat(builder->detect_conflict(1, 0, {{}, {}}), IsFalse()); }); - it("returns true when one token matches a string that the other matches, " - "plus some addition content that begins with a separator character", [&]() { + it("returns true when the left token can match a string that the right token matches, " + "plus a separator character", [&]() { LexicalGrammar grammar{ { LexicalVariable{ - "token_1", + "token_0", VariableTypeNamed, Rule::repeat(CharacterSet().include_all().exclude('\n')), // regex: /.+/ false }, LexicalVariable{ - "token_2", + "token_1", VariableTypeNamed, Rule::seq({ CharacterSet({ 'a' }), CharacterSet({ 'b' }), CharacterSet({ 'c' }) }), // string: 'abc' true @@ -64,24 +64,32 @@ describe("LexTableBuilder::detect_conflict", []() { }; auto builder = LexTableBuilder::create(grammar); - AssertThat(builder->detect_conflict(0, 1), IsTrue()); - AssertThat(builder->detect_conflict(1, 0), IsFalse()); + AssertThat(builder->detect_conflict(0, 1, {{}, {}}), IsTrue()); + AssertThat(builder->detect_conflict(1, 0, {{}, {}}), IsFalse()); grammar.variables[1].is_string = false; - AssertThat(builder->detect_conflict(0, 1), IsTrue()); - AssertThat(builder->detect_conflict(1, 0), IsFalse()); + AssertThat(builder->detect_conflict(0, 1, {{}, {}}), IsTrue()); + AssertThat(builder->detect_conflict(1, 0, {{}, {}}), IsFalse()); }); - it("returns true when one token matches a string that the other matches, " - "plus some addition content that matches another one-character token", [&]() { + it("returns true when the left token matches a string that the right token matches, " + "plus the first character of some token that can follow the right token", [&]() { LexicalGrammar grammar{ { + LexicalVariable{ + "token_0", + VariableTypeNamed, + Rule::seq({ + CharacterSet({ '>' }), + CharacterSet({ '=' }), + }), + true + }, LexicalVariable{ "token_1", VariableTypeNamed, Rule::seq({ CharacterSet({ '>' }), - CharacterSet({ '>' }), }), true }, @@ -89,7 +97,7 @@ describe("LexTableBuilder::detect_conflict", []() { "token_2", VariableTypeNamed, Rule::seq({ - CharacterSet({ '>' }), + CharacterSet({ '=' }), }), true }, @@ -97,9 +105,17 @@ describe("LexTableBuilder::detect_conflict", []() { separators }; + // If no tokens can follow token_1, then there's no conflict auto builder = LexTableBuilder::create(grammar); - AssertThat(builder->detect_conflict(0, 1), IsTrue()); - AssertThat(builder->detect_conflict(1, 0), IsFalse()); + vector> following_tokens_by_token_index(3); + AssertThat(builder->detect_conflict(0, 1, following_tokens_by_token_index), IsFalse()); + AssertThat(builder->detect_conflict(1, 0, following_tokens_by_token_index), IsFalse()); + + // If token_2 can follow token_1, then token_0 conflicts with token_1 + builder = LexTableBuilder::create(grammar); + following_tokens_by_token_index[1].insert(2); + AssertThat(builder->detect_conflict(0, 1, following_tokens_by_token_index), IsTrue()); + AssertThat(builder->detect_conflict(1, 0, following_tokens_by_token_index), IsFalse()); }); });