diff --git a/include/tree_sitter/parser.h b/include/tree_sitter/parser.h index 9219032b..10d17582 100644 --- a/include/tree_sitter/parser.h +++ b/include/tree_sitter/parser.h @@ -14,7 +14,6 @@ typedef unsigned short TSStateId; #define ts_builtin_sym_error ((TSSymbol)-1) #define ts_builtin_sym_end 0 -#define ts_builtin_sym_start 1 typedef struct { bool visible : 1; @@ -60,6 +59,7 @@ typedef union { typedef struct TSLanguage { size_t symbol_count; + size_t token_count; const char **symbol_names; const TSSymbolMetadata *symbol_metadata; const unsigned short *parse_table; @@ -103,6 +103,9 @@ typedef struct TSLanguage { * Parse Table Macros */ +#define STATE(id) id +#define ACTIONS(id) id + #define SHIFT(to_state_value) \ { \ { \ @@ -146,6 +149,7 @@ typedef struct TSLanguage { #define EXPORT_LANGUAGE(language_name) \ static TSLanguage language = { \ .symbol_count = SYMBOL_COUNT, \ + .token_count = TOKEN_COUNT, \ .symbol_metadata = ts_symbol_metadata, \ .parse_table = (const unsigned short *)ts_parse_table, \ .parse_actions = ts_parse_actions, \ diff --git a/spec/compiler/build_tables/lex_conflict_manager_spec.cc b/spec/compiler/build_tables/lex_conflict_manager_spec.cc index b62b9137..7f43e175 100644 --- a/spec/compiler/build_tables/lex_conflict_manager_spec.cc +++ b/spec/compiler/build_tables/lex_conflict_manager_spec.cc @@ -1,5 +1,6 @@ #include "spec_helper.h" #include "helpers/rule_helpers.h" +#include "helpers/stream_methods.h" #include "compiler/rules/built_in_symbols.h" #include "compiler/parse_table.h" #include "compiler/build_tables/lex_conflict_manager.h" @@ -36,7 +37,7 @@ describe("LexConflictManager::resolve(new_action, old_action)", []() { it("adds the preferred token as a possible homonym for the discarded one", [&]() { conflict_manager.resolve(AcceptTokenAction(sym2, 1, false), AcceptTokenAction(sym1, 2, false)); - AssertThat(conflict_manager.possible_homonyms[sym2], Contains(sym1)); + AssertThat(conflict_manager.possible_homonyms[sym2.index], Contains(sym1.index)); }); }); @@ -78,7 +79,7 @@ describe("LexConflictManager::resolve(new_action, old_action)", []() { it("adds the in-progress tokens as possible extensions of the discarded token", [&]() { conflict_manager.resolve(item_set, AdvanceAction(1, { 1, 2 }, true), AcceptTokenAction(sym3, 3, true)); - AssertThat(conflict_manager.possible_extensions[sym3], Contains(sym4)); + AssertThat(conflict_manager.possible_extensions[sym3.index], Contains(sym4.index)); }); }); }); diff --git a/spec/fixtures/error_corpus/c_errors.txt b/spec/fixtures/error_corpus/c_errors.txt index 4dfb0894..05bc7ec2 100644 --- a/spec/fixtures/error_corpus/c_errors.txt +++ b/spec/fixtures/error_corpus/c_errors.txt @@ -127,6 +127,5 @@ int b() { (ERROR (identifier) (identifier)) (identifier) (number_literal))) (declaration - (ERROR (identifier) (identifier)) (identifier) - (init_declarator (identifier) (number_literal)))))) + (init_declarator (ERROR (identifier) (identifier)) (identifier) (number_literal)))))) diff --git a/spec/fixtures/error_corpus/javascript_errors.txt b/spec/fixtures/error_corpus/javascript_errors.txt index 0c5e976e..c308e443 100644 --- a/spec/fixtures/error_corpus/javascript_errors.txt +++ b/spec/fixtures/error_corpus/javascript_errors.txt @@ -14,10 +14,8 @@ e f; (ERROR (identifier)) (identifier) (statement_block - (ERROR (identifier)) - (expression_statement (identifier)))) - (ERROR (identifier)) - (expression_statement (identifier))) + (expression_statement (ERROR (identifier)) (identifier)))) + (expression_statement (ERROR (identifier)) (identifier))) ======================================================= multiple invalid tokens right after the viable prefix @@ -35,8 +33,7 @@ h i j k; (ERROR (identifier) (identifier)) (identifier) (statement_block - (ERROR (identifier) (identifier) (identifier)) - (expression_statement (identifier)))) + (expression_statement (ERROR (identifier) (identifier) (identifier)) (identifier)))) (expression_statement (ERROR (identifier) (identifier) (identifier)) (identifier))) diff --git a/spec/helpers/stream_methods.cc b/spec/helpers/stream_methods.cc index 514d6181..69483ed3 100644 --- a/spec/helpers/stream_methods.cc +++ b/spec/helpers/stream_methods.cc @@ -75,21 +75,15 @@ ostream &operator<<(ostream &stream, const ParseAction &action) { } } +ostream &operator<<(ostream &stream, const ParseTableEntry &entry) { + return stream << entry.actions; +} + ostream &operator<<(ostream &stream, const ParseState &state) { - stream << string("# {"); - for (auto &action : entry.second.actions) { - stream << string(" ") << action; - } - stream << string("}"); - started = true; - } - stream << string(">"); - return stream; + stream << string("#"); } ostream &operator<<(ostream &stream, const ProductionStep &step) { diff --git a/spec/runtime/tree_spec.cc b/spec/runtime/tree_spec.cc index 83bb67a5..9f451829 100644 --- a/spec/runtime/tree_spec.cc +++ b/spec/runtime/tree_spec.cc @@ -23,7 +23,7 @@ void assert_consistent(const Tree *tree) { START_TEST enum { - cat = ts_builtin_sym_start, + cat = 1, dog, eel, fox, diff --git a/src/compiler/build_tables/build_lex_table.cc b/src/compiler/build_tables/build_lex_table.cc index 56de23cf..0b75c368 100644 --- a/src/compiler/build_tables/build_lex_table.cc +++ b/src/compiler/build_tables/build_lex_table.cc @@ -114,14 +114,11 @@ class LexTableBuilder { void mark_fragile_tokens() { for (ParseState &state : parse_table->states) { - for (auto &entry : state.entries) { - if (!entry.first.is_token) - continue; - + for (auto &entry : state.terminal_entries) { auto homonyms = conflict_manager.possible_homonyms.find(entry.first); if (homonyms != conflict_manager.possible_homonyms.end()) - for (const Symbol &homonym : homonyms->second) - if (state.entries.count(homonym)) { + for (int homonym : homonyms->second) + if (state.terminal_entries.count(homonym)) { entry.second.reusable = false; break; } @@ -131,8 +128,8 @@ class LexTableBuilder { auto extensions = conflict_manager.possible_extensions.find(entry.first); if (extensions != conflict_manager.possible_extensions.end()) - for (const Symbol &extension : extensions->second) - if (state.entries.count(extension)) { + for (int extension : extensions->second) + if (state.terminal_entries.count(extension)) { entry.second.depends_on_lookahead = true; break; } @@ -147,7 +144,7 @@ class LexTableBuilder { } auto replacements = - remove_duplicate_states(&lex_table); + remove_duplicate_states(&lex_table); for (ParseState &parse_state : parse_table->states) { auto replacement = replacements.find(parse_state.lex_state_id); diff --git a/src/compiler/build_tables/build_parse_table.cc b/src/compiler/build_tables/build_parse_table.cc index 829d8e93..8e69e228 100644 --- a/src/compiler/build_tables/build_parse_table.cc +++ b/src/compiler/build_tables/build_parse_table.cc @@ -116,17 +116,16 @@ class ParseTableBuilder { } for (const Symbol &symbol : grammar.extra_tokens) { - if (!error_state.entries.count(symbol)) { - error_state.entries[symbol].actions.push_back(ParseAction::ShiftExtra()); + if (!error_state.terminal_entries.count(symbol.index)) { + error_state.terminal_entries[symbol.index].actions.push_back(ParseAction::ShiftExtra()); } } for (size_t i = 0; i < grammar.variables.size(); i++) { - Symbol symbol(i, false); - add_out_of_context_parse_state(&error_state, symbol); + add_out_of_context_parse_state(&error_state, Symbol(i, false)); } - error_state.entries[END_OF_INPUT()].actions.push_back(ParseAction::Recover(0)); + error_state.terminal_entries[END_OF_INPUT().index].actions.push_back(ParseAction::Recover(0)); parse_table.states[0] = error_state; } @@ -135,7 +134,11 @@ class ParseTableBuilder { const ParseItemSet &item_set = recovery_states[symbol]; if (!item_set.entries.empty()) { ParseStateId state = add_parse_state(item_set); - error_state->entries[symbol].actions.push_back(ParseAction::Recover(state)); + if (symbol.is_token) { + error_state->terminal_entries[symbol.index].actions.assign({ ParseAction::Recover(state) }); + } else { + error_state->nonterminal_entries[symbol.index] = state; + } } } @@ -158,14 +161,19 @@ class ParseTableBuilder { const ParseItemSet &next_item_set = transition.second.first; const PrecedenceRange &precedence = transition.second.second; - ParseAction *new_action = add_action( - state_id, symbol, ParseAction::Shift(0, precedence), item_set); - - if (!allow_any_conflict) + if (!allow_any_conflict) { recovery_states[symbol].add(next_item_set); + } - if (new_action) - new_action->state_index = add_parse_state(next_item_set); + if (symbol.is_token) { + ParseAction *new_action = add_terminal_action( + state_id, symbol, ParseAction::Shift(0, precedence), item_set); + if (new_action) { + new_action->state_index = add_parse_state(next_item_set); + } + } else { + parse_table.set_nonterminal_action(state_id, symbol.index, add_parse_state(next_item_set)); + } } } @@ -185,8 +193,9 @@ class ParseTableBuilder { status.associativity, *item.production); } - for (const auto &lookahead_sym : *lookahead_symbols.entries) - add_action(state_id, lookahead_sym, action, item_set); + for (const Symbol &lookahead : *lookahead_symbols.entries) { + add_terminal_action(state_id, lookahead, action, item_set); + } } } } @@ -195,24 +204,25 @@ class ParseTableBuilder { ParseAction action = ParseAction::ShiftExtra(); ParseState &state = parse_table.states[state_id]; for (const Symbol &extra_symbol : grammar.extra_tokens) - if (!state.entries.count(extra_symbol) || state.has_shift_action() || - allow_any_conflict) - parse_table.add_action(state_id, extra_symbol, action); + if (!state.terminal_entries.count(extra_symbol.index) || + state.has_shift_action() || allow_any_conflict) + parse_table.add_terminal_action(state_id, extra_symbol.index, action); } void mark_fragile_actions() { for (ParseState &state : parse_table.states) { set symbols_with_multiple_actions; - for (auto &entry : state.entries) { - const Symbol &symbol = entry.first; + for (auto &entry : state.terminal_entries) { + const Symbol symbol(entry.first, true); auto &actions = entry.second.actions; - if (actions.size() > 1) + if (actions.size() > 1) { symbols_with_multiple_actions.insert(symbol); + } for (ParseAction &action : actions) { - if (action.type == ParseActionTypeReduce && !action.extra) { + if (action.type == ParseActionTypeReduce) { if (has_fragile_production(action.production)) action.fragile = true; @@ -231,21 +241,8 @@ class ParseTableBuilder { break; } } - if (!erased) + if (!erased) { ++i; - } - } - - if (!symbols_with_multiple_actions.empty()) { - for (auto &entry : state.entries) { - if (!entry.first.is_token) { - set first_set = get_first_set(entry.first); - for (const Symbol &symbol : symbols_with_multiple_actions) { - if (first_set.count(symbol)) { - entry.second.reusable = false; - break; - } - } } } } @@ -253,33 +250,34 @@ class ParseTableBuilder { } void remove_duplicate_parse_states() { - remove_duplicate_states(&parse_table); + remove_duplicate_states(&parse_table); } - ParseAction *add_action(ParseStateId state_id, Symbol lookahead, - const ParseAction &new_action, - const ParseItemSet &item_set) { + ParseAction *add_terminal_action(ParseStateId state_id, Symbol lookahead, + const ParseAction &new_action, + const ParseItemSet &item_set) { const ParseState &state = parse_table.states[state_id]; - const auto ¤t_entry = state.entries.find(lookahead); - if (current_entry == state.entries.end()) - return &parse_table.set_action(state_id, lookahead, new_action); + const auto ¤t_entry = state.terminal_entries.find(lookahead.index); + if (current_entry == state.terminal_entries.end()) + return &parse_table.set_terminal_action(state_id, lookahead.index, new_action); if (allow_any_conflict) - return &parse_table.add_action(state_id, lookahead, new_action); + return &parse_table.add_terminal_action(state_id, lookahead.index, new_action); const ParseAction old_action = current_entry->second.actions[0]; auto resolution = conflict_manager.resolve(new_action, old_action); switch (resolution.second) { case ConflictTypeNone: - if (resolution.first) - return &parse_table.set_action(state_id, lookahead, new_action); + if (resolution.first) { + return &parse_table.set_terminal_action(state_id, lookahead.index, new_action); + } break; case ConflictTypeResolved: { if (resolution.first) { if (old_action.type == ParseActionTypeReduce) fragile_productions.insert(old_action.production); - return &parse_table.set_action(state_id, lookahead, new_action); + return &parse_table.set_terminal_action(state_id, lookahead.index, new_action); } else { if (new_action.type == ParseActionTypeReduce) fragile_productions.insert(new_action.production); @@ -293,7 +291,7 @@ class ParseTableBuilder { fragile_productions.insert(old_action.production); if (new_action.type == ParseActionTypeReduce) fragile_productions.insert(new_action.production); - return &parse_table.add_action(state_id, lookahead, new_action); + return &parse_table.add_terminal_action(state_id, lookahead.index, new_action); } break; } diff --git a/src/compiler/build_tables/lex_conflict_manager.cc b/src/compiler/build_tables/lex_conflict_manager.cc index b89228d4..3fc22ed2 100644 --- a/src/compiler/build_tables/lex_conflict_manager.cc +++ b/src/compiler/build_tables/lex_conflict_manager.cc @@ -14,7 +14,7 @@ bool LexConflictManager::resolve(const LexItemSet &item_set, return true; if (new_action.precedence_range.max >= old_action.precedence) { for (const LexItem &item : item_set.entries) - possible_extensions[old_action.symbol].insert(item.lhs); + possible_extensions[old_action.symbol.index].insert(item.lhs.index); return true; } else { return false; @@ -44,9 +44,9 @@ bool LexConflictManager::resolve(const AcceptTokenAction &new_action, result = false; if (result) - possible_homonyms[old_action.symbol].insert(new_action.symbol); + possible_homonyms[old_action.symbol.index].insert(new_action.symbol.index); else - possible_homonyms[new_action.symbol].insert(old_action.symbol); + possible_homonyms[new_action.symbol.index].insert(old_action.symbol.index); return result; } diff --git a/src/compiler/build_tables/lex_conflict_manager.h b/src/compiler/build_tables/lex_conflict_manager.h index 8fb0f075..9777dc36 100644 --- a/src/compiler/build_tables/lex_conflict_manager.h +++ b/src/compiler/build_tables/lex_conflict_manager.h @@ -21,8 +21,8 @@ class LexConflictManager { const AcceptTokenAction &); bool resolve(const AcceptTokenAction &, const AcceptTokenAction &); - std::map> possible_homonyms; - std::map> possible_extensions; + std::map> possible_homonyms; + std::map> possible_extensions; }; } // namespace build_tables diff --git a/src/compiler/build_tables/remove_duplicate_states.h b/src/compiler/build_tables/remove_duplicate_states.h index 601737a5..a154c05a 100644 --- a/src/compiler/build_tables/remove_duplicate_states.h +++ b/src/compiler/build_tables/remove_duplicate_states.h @@ -7,7 +7,7 @@ namespace tree_sitter { namespace build_tables { -template +template std::map remove_duplicate_states(TableType *table) { std::map replacements; @@ -46,10 +46,10 @@ std::map remove_duplicate_states(TableType *table) { } for (auto &state : table->states) - state.each_advance_action([&new_replacements](ActionType *action) { - auto new_replacement = new_replacements.find(action->state_index); + state.each_referenced_state([&new_replacements](int64_t *state_index) { + auto new_replacement = new_replacements.find(*state_index); if (new_replacement != new_replacements.end()) - action->state_index = new_replacement->second; + *state_index = new_replacement->second; }); for (auto i = duplicates.rbegin(); i != duplicates.rend(); ++i) diff --git a/src/compiler/generate_code/c_code.cc b/src/compiler/generate_code/c_code.cc index 65244fdf..78a8c707 100644 --- a/src/compiler/generate_code/c_code.cc +++ b/src/compiler/generate_code/c_code.cc @@ -115,6 +115,7 @@ class CCodeGenerator { void add_state_and_symbol_counts() { line("#define STATE_COUNT " + to_string(parse_table.states.size())); line("#define SYMBOL_COUNT " + to_string(parse_table.symbols.size())); + line("#define TOKEN_COUNT " + to_string(lexical_grammar.variables.size() + 1)); line(); } @@ -222,10 +223,15 @@ class CCodeGenerator { for (const auto &state : parse_table.states) { line("[" + to_string(state_id++) + "] = {"); indent([&]() { - for (const auto &entry : state.entries) { - line("[" + symbol_id(entry.first) + "] = "); + for (const auto &entry : state.nonterminal_entries) { + line("[" + symbol_id(rules::Symbol(entry.first)) + "] = STATE("); + add(to_string(entry.second)); + add("),"); + } + for (const auto &entry : state.terminal_entries) { + line("[" + symbol_id(rules::Symbol(entry.first, true)) + "] = ACTIONS("); add(to_string(add_parse_action_list_id(entry.second))); - add(","); + add("),"); } }); line("},"); diff --git a/src/compiler/lex_table.cc b/src/compiler/lex_table.cc index 852586e5..8f8d2ded 100644 --- a/src/compiler/lex_table.cc +++ b/src/compiler/lex_table.cc @@ -57,9 +57,9 @@ bool LexState::operator==(const LexState &other) const { is_token_start == other.is_token_start; } -void LexState::each_advance_action(function fn) { +void LexState::each_referenced_state(function fn) { for (auto &entry : advance_actions) - fn(&entry.second); + fn(&entry.second.state_index); } LexStateId LexTable::add_state() { diff --git a/src/compiler/lex_table.h b/src/compiler/lex_table.h index d508e9da..ac7357a1 100644 --- a/src/compiler/lex_table.h +++ b/src/compiler/lex_table.h @@ -11,6 +11,8 @@ namespace tree_sitter { +typedef int64_t LexStateId; + typedef enum { LexActionTypeError, LexActionTypeAccept, @@ -24,7 +26,7 @@ struct AdvanceAction { bool operator==(const AdvanceAction &other) const; - size_t state_index; + LexStateId state_index; PrecedenceRange precedence_range; bool in_main_token; }; @@ -52,15 +54,13 @@ class LexState { LexState(); std::set expected_inputs() const; bool operator==(const LexState &) const; - void each_advance_action(std::function); + void each_referenced_state(std::function); std::map advance_actions; AcceptTokenAction accept_action; bool is_token_start; }; -typedef int64_t LexStateId; - class LexTable { public: LexStateId add_state(); diff --git a/src/compiler/parse_table.cc b/src/compiler/parse_table.cc index ef0e235d..47218d36 100644 --- a/src/compiler/parse_table.cc +++ b/src/compiler/parse_table.cc @@ -125,29 +125,34 @@ bool ParseTableEntry::operator==(const ParseTableEntry &other) const { ParseState::ParseState() : lex_state_id(-1) {} bool ParseState::has_shift_action() const { - for (const auto &pair : entries) + for (const auto &pair : terminal_entries) if (pair.second.actions.size() > 0 && pair.second.actions.back().type == ParseActionTypeShift) return true; - return false; + return (!nonterminal_entries.empty()); } set ParseState::expected_inputs() const { set result; - for (auto &entry : entries) - result.insert(entry.first); + for (auto &entry : terminal_entries) + result.insert(Symbol(entry.first, true)); + for (auto &entry : nonterminal_entries) + result.insert(Symbol(entry.first, false)); return result; } -void ParseState::each_advance_action(function fn) { - for (auto &entry : entries) +void ParseState::each_referenced_state(function fn) { + for (auto &entry : terminal_entries) for (ParseAction &action : entry.second.actions) if (action.type == ParseActionTypeShift || ParseActionTypeRecover) - fn(&action); + fn(&action.state_index); + for (auto &entry : nonterminal_entries) + fn(&entry.second); } bool ParseState::operator==(const ParseState &other) const { - return entries == other.entries; + return terminal_entries == other.terminal_entries && + nonterminal_entries == other.nonterminal_entries; } set ParseTable::all_symbols() const { @@ -162,35 +167,34 @@ ParseStateId ParseTable::add_state() { return states.size() - 1; } -ParseAction &ParseTable::set_action(ParseStateId id, Symbol symbol, - ParseAction action) { - if (action.type == ParseActionTypeShift && action.extra) - symbols[symbol].extra = true; - else - symbols[symbol].structural = true; - - states[id].entries[symbol].actions = { action }; - return *states[id].entries[symbol].actions.begin(); +ParseAction &ParseTable::set_terminal_action(ParseStateId state_id, int index, + ParseAction action) { + states[state_id].terminal_entries[index].actions.clear(); + return add_terminal_action(state_id, index, action); } -ParseAction &ParseTable::add_action(ParseStateId id, Symbol symbol, - ParseAction action) { +ParseAction &ParseTable::add_terminal_action(ParseStateId state_id, int index, + ParseAction action) { + Symbol symbol(index, true); if (action.type == ParseActionTypeShift && action.extra) symbols[symbol].extra = true; else symbols[symbol].structural = true; - ParseState &state = states[id]; - for (ParseAction &existing_action : state.entries[symbol].actions) - if (existing_action == action) - return existing_action; + ParseTableEntry &entry = states[state_id].terminal_entries[index]; + entry.actions.push_back(action); + return *entry.actions.rbegin(); +} - state.entries[symbol].actions.push_back(action); - return *state.entries[symbol].actions.rbegin(); +void ParseTable::set_nonterminal_action(ParseStateId state_id, int index, + ParseStateId next_state_id) { + Symbol symbol(index, false); + symbols[symbol].structural = true; + states[state_id].nonterminal_entries[index] = next_state_id; } static bool has_entry(const ParseState &state, const ParseTableEntry &entry) { - for (const auto &pair : state.entries) + for (const auto &pair : state.terminal_entries) if (pair.second == entry) return true; return false; @@ -200,13 +204,16 @@ bool ParseTable::merge_state(size_t i, size_t j) { ParseState &state = states[i]; ParseState &other = states[j]; - for (auto &entry : state.entries) { - const Symbol &symbol = entry.first; + if (state.nonterminal_entries != other.nonterminal_entries) + return false; + + for (auto &entry : state.terminal_entries) { + Symbol symbol(entry.first, true); const vector &actions = entry.second.actions; - const auto &other_entry = other.entries.find(symbol); - if (other_entry == other.entries.end()) { - if (mergeable_symbols.count(symbol) == 0 && !symbol.is_built_in() && symbol.is_token) + const auto &other_entry = other.terminal_entries.find(symbol.index); + if (other_entry == other.terminal_entries.end()) { + if (mergeable_symbols.count(symbol) == 0 && !symbol.is_built_in()) return false; if (actions.back().type != ParseActionTypeReduce) return false; @@ -219,12 +226,12 @@ bool ParseTable::merge_state(size_t i, size_t j) { set symbols_to_merge; - for (auto &entry : other.entries) { - const Symbol &symbol = entry.first; + for (auto &entry : other.terminal_entries) { + Symbol symbol(entry.first, true); const vector &actions = entry.second.actions; - if (!state.entries.count(symbol)) { - if (mergeable_symbols.count(symbol) == 0 && !symbol.is_built_in() && symbol.is_token) + if (!state.terminal_entries.count(symbol.index)) { + if (mergeable_symbols.count(symbol) == 0 && !symbol.is_built_in()) return false; if (actions.back().type != ParseActionTypeReduce) return false; @@ -235,7 +242,7 @@ bool ParseTable::merge_state(size_t i, size_t j) { } for (const Symbol &symbol : symbols_to_merge) - state.entries[symbol] = other.entries.find(symbol)->second; + state.terminal_entries[symbol.index] = other.terminal_entries.find(symbol.index)->second; return true; } diff --git a/src/compiler/parse_table.h b/src/compiler/parse_table.h index cf1b2a1b..6c883c26 100644 --- a/src/compiler/parse_table.h +++ b/src/compiler/parse_table.h @@ -1,6 +1,7 @@ #ifndef COMPILER_PARSE_TABLE_H_ #define COMPILER_PARSE_TABLE_H_ +#include #include #include #include @@ -13,7 +14,7 @@ namespace tree_sitter { -typedef uint64_t ParseStateId; +typedef int64_t ParseStateId; enum ParseActionType { ParseActionTypeError, @@ -72,10 +73,11 @@ class ParseState { std::set expected_inputs() const; bool operator==(const ParseState &) const; bool merge(const ParseState &); - void each_advance_action(std::function); + void each_referenced_state(std::function); bool has_shift_action() const; - std::map entries; + std::map terminal_entries; + std::map nonterminal_entries; LexStateId lex_state_id; }; @@ -88,10 +90,9 @@ class ParseTable { public: std::set all_symbols() const; ParseStateId add_state(); - ParseAction &set_action(ParseStateId state_id, rules::Symbol symbol, - ParseAction action); - ParseAction &add_action(ParseStateId state_id, rules::Symbol symbol, - ParseAction action); + ParseAction &add_terminal_action(ParseStateId state_id, int, ParseAction); + ParseAction &set_terminal_action(ParseStateId state_id, int index, ParseAction); + void set_nonterminal_action(ParseStateId state_id, int index, ParseStateId); bool merge_state(size_t i, size_t j); std::vector states; diff --git a/src/compiler/rules/symbol.cc b/src/compiler/rules/symbol.cc index cdfb78cf..697a3465 100644 --- a/src/compiler/rules/symbol.cc +++ b/src/compiler/rules/symbol.cc @@ -37,9 +37,9 @@ string Symbol::to_string() const { } bool Symbol::operator<(const Symbol &other) const { - if (!is_token && other.is_token) - return true; if (is_token && !other.is_token) + return true; + if (!is_token && other.is_token) return false; return (index < other.index); } diff --git a/src/runtime/language.c b/src/runtime/language.c index 0bc4ae7e..78ce0a7f 100644 --- a/src/runtime/language.c +++ b/src/runtime/language.c @@ -19,6 +19,7 @@ void ts_language_table_entry(const TSLanguage *self, TSStateId state, } action_index = 0; } else { + assert(symbol < self->token_count); action_index = self->parse_table[state * self->symbol_count + symbol]; } diff --git a/src/runtime/language.h b/src/runtime/language.h index 3941d875..7aefeed9 100644 --- a/src/runtime/language.h +++ b/src/runtime/language.h @@ -40,6 +40,23 @@ static inline const TSParseAction *ts_language_last_action( return NULL; } +static inline TSStateId ts_language_next_state(const TSLanguage *self, + TSStateId state, + TSSymbol symbol) { + if (symbol == ts_builtin_sym_error) { + return 0; + } else if (symbol < self->token_count) { + const TSParseAction *action = ts_language_last_action(self, state, symbol); + if (action && (action->type == TSParseActionTypeShift || action->type == TSParseActionTypeRecover)) { + return action->params.to_state; + } else { + return 0; + } + } else { + return self->parse_table[state * self->symbol_count + symbol]; + } +} + static inline bool ts_language_is_reusable(const TSLanguage *self, TSStateId state, TSSymbol symbol) { TableEntry entry; diff --git a/src/runtime/parser.c b/src/runtime/parser.c index 04556ebb..e5b6f517 100644 --- a/src/runtime/parser.c +++ b/src/runtime/parser.c @@ -87,11 +87,7 @@ static bool parser__breakdown_top_of_stack(Parser *self, StackVersion version) { if (child->symbol == ts_builtin_sym_error) { state = ERROR_STATE; } else if (!child->extra) { - const TSParseAction *action = - ts_language_last_action(self->language, state, child->symbol); - assert(action && (action->type == TSParseActionTypeShift || - action->type == TSParseActionTypeRecover)); - state = action->params.to_state; + state = ts_language_next_state(self->language, state, child->symbol); } ts_stack_push(self->stack, slice.version, child, pending, state); @@ -486,13 +482,8 @@ static Reduction parser__reduce(Parser *self, StackVersion version, parent->parse_state = state; } - const TSParseAction *action = - ts_language_last_action(language, state, symbol); - assert(action->type == TSParseActionTypeShift || - action->type == TSParseActionTypeRecover); - - if (action->type == TSParseActionTypeRecover && child_count > 1 && - allow_skipping) { + TSStateId next_state = ts_language_next_state(language, state, symbol); + if (state == ERROR_STATE && allow_skipping) { StackVersion other_version = ts_stack_duplicate_version(self->stack, slice.version); @@ -508,10 +499,10 @@ static Reduction parser__reduce(Parser *self, StackVersion version, ts_stack_remove_version(self->stack, other_version); } - parser__push(self, slice.version, parent, action->params.to_state); + parser__push(self, slice.version, parent, next_state); for (size_t j = parent->child_count; j < slice.trees.size; j++) { Tree *tree = slice.trees.contents[j]; - parser__push(self, slice.version, tree, action->params.to_state); + parser__push(self, slice.version, tree, next_state); } } @@ -540,26 +531,24 @@ static inline const TSParseAction *parser__reductions_after_sequence( if (child_count == tree_count_below) break; Tree *tree = trees_below->contents[trees_below->size - 1 - i]; - const TSParseAction *action = - ts_language_last_action(self->language, state, tree->symbol); - if (!action || action->type != TSParseActionTypeShift) + TSStateId next_state = ts_language_next_state(self->language, state, tree->symbol); + if (next_state == ERROR_STATE) return NULL; - if (action->extra || tree->extra) - continue; - child_count++; - state = action->params.to_state; + if (next_state != state) { + child_count++; + state = next_state; + } } for (size_t i = 0; i < trees_above->size; i++) { Tree *tree = trees_above->contents[i]; - const TSParseAction *action = - ts_language_last_action(self->language, state, tree->symbol); - if (!action || action->type != TSParseActionTypeShift) + TSStateId next_state = ts_language_next_state(self->language, state, tree->symbol); + if (next_state == ERROR_STATE) return NULL; - if (action->extra || tree->extra) - continue; - child_count++; - state = action->params.to_state; + if (next_state != state) { + child_count++; + state = next_state; + } } const TSParseAction *actions = @@ -610,15 +599,9 @@ static StackIterateAction parser__error_repair_callback( continue; } - const TSParseAction *repair_symbol_action = - ts_language_last_action(self->language, state, repair->symbol); - if (!repair_symbol_action || - repair_symbol_action->type != TSParseActionTypeShift) - continue; - - TSStateId state_after_repair = repair_symbol_action->params.to_state; - if (!ts_language_last_action(self->language, state_after_repair, - lookahead_symbol)) + TSStateId state_after_repair = ts_language_next_state(self->language, state, repair->symbol); + if (state == ERROR_STATE || state_after_repair == ERROR_STATE || + !ts_language_last_action(self->language, state_after_repair, lookahead_symbol)) continue; if (count_needed_below_error != last_repair_count) { @@ -795,7 +778,7 @@ static bool parser__do_potential_reductions( size_t previous_version_count = ts_stack_version_count(self->stack); array_clear(&self->reduce_actions); - for (TSSymbol symbol = 0; symbol < self->language->symbol_count; symbol++) { + for (TSSymbol symbol = 0; symbol < self->language->token_count; symbol++) { TableEntry entry; ts_language_table_entry(self->language, state, symbol, &entry); for (size_t i = 0; i < entry.action_count; i++) { @@ -915,6 +898,9 @@ static void parser__handle_error(Parser *self, StackVersion version, ts_stack_push(self->stack, version, NULL, false, ERROR_STATE); while (ts_stack_version_count(self->stack) > previous_version_count) { ts_stack_push(self->stack, previous_version_count, NULL, false, ERROR_STATE); + + LOG_STACK(); + assert(ts_stack_merge(self->stack, version, previous_version_count)); } } @@ -982,6 +968,17 @@ static void parser__advance(Parser *self, StackVersion version, switch (action.type) { case TSParseActionTypeShift: { + bool extra = action.extra; + TSStateId next_state; + + if (action.extra) { + next_state = state; + LOG("shift_extra"); + } else { + next_state = action.params.to_state; + LOG("shift state:%u", next_state); + } + if (lookahead->child_count > 0) { if (parser__breakdown_lookahead(self, &lookahead, state, reusable_node)) { @@ -992,20 +989,10 @@ static void parser__advance(Parser *self, StackVersion version, } } - action = *ts_language_last_action(self->language, state, - lookahead->symbol); + next_state = ts_language_next_state(self->language, state, lookahead->symbol); } - TSStateId next_state; - if (action.extra) { - next_state = state; - LOG("shift_extra"); - } else { - next_state = action.params.to_state; - LOG("shift state:%u", next_state); - } - - parser__shift(self, version, next_state, lookahead, action.extra); + parser__shift(self, version, next_state, lookahead, extra); if (lookahead == reusable_node->tree) parser__pop_reusable_node(reusable_node);