From 22c550c9d6ccf97806e17ff01c25b7e11cc4b443 Mon Sep 17 00:00:00 2001 From: Max Brunsfeld Date: Mon, 9 May 2016 14:31:44 -0700 Subject: [PATCH] Discard tokens after error detection to find the best repair * Use GLR stack-splitting to try all numbers of tokens to discard until a repair is found. * Check the validity of repairs by looking at the child trees, rather than the statically-computed 'in-progress symbols' list --- include/tree_sitter/parser.h | 40 +- project.gyp | 1 - .../symbols_by_first_symbol_spec.cc | 116 ------ spec/runtime/parser_spec.cc | 8 +- spec/runtime/stack_spec.cc | 2 +- .../build_tables/build_parse_table.cc | 37 +- src/compiler/build_tables/parse_item.cc | 5 + src/compiler/build_tables/parse_item.h | 1 + .../build_tables/symbols_by_first_symbol.cc | 50 --- .../build_tables/symbols_by_first_symbol.h | 20 - src/compiler/generate_code/c_code.cc | 49 +-- src/compiler/parse_table.h | 1 - src/runtime/language.c | 48 +-- src/runtime/language.h | 9 +- src/runtime/parser.c | 392 +++++++++++------- src/runtime/parser.h | 6 +- src/runtime/reduce_action.h | 32 ++ src/runtime/stack.c | 96 +++-- src/runtime/stack.h | 7 + src/runtime/tree.c | 39 +- src/runtime/tree.h | 1 + 21 files changed, 454 insertions(+), 506 deletions(-) delete mode 100644 spec/compiler/build_tables/symbols_by_first_symbol_spec.cc delete mode 100644 src/compiler/build_tables/symbols_by_first_symbol.cc delete mode 100644 src/compiler/build_tables/symbols_by_first_symbol.h create mode 100644 src/runtime/reduce_action.h diff --git a/include/tree_sitter/parser.h b/include/tree_sitter/parser.h index ccccb7ba..19965c07 100644 --- a/include/tree_sitter/parser.h +++ b/include/tree_sitter/parser.h @@ -60,6 +60,7 @@ typedef enum { TSParseActionTypeShift, TSParseActionTypeReduce, TSParseActionTypeAccept, + TSParseActionTypeRecover, } TSParseActionType; typedef struct { @@ -92,10 +93,8 @@ struct TSLanguage { const TSSymbolMetadata *symbol_metadata; const unsigned short *parse_table; const TSParseActionEntry *parse_actions; - const unsigned short *in_progress_symbol_table; - const TSInProgressSymbolEntry *in_progress_symbols; const TSStateId *lex_states; - const TSStateId *recovery_states; + const TSParseAction *recovery_actions; TSTree *(*lex_fn)(TSLexer *, TSStateId, bool); }; @@ -164,6 +163,11 @@ enum { } \ } +#define RECOVER(to_state_value) \ + { \ + .type = TSParseActionTypeRecover, .data = {.to_state = to_state_value } \ + } + #define SHIFT_EXTRA() \ { \ { .type = TSParseActionTypeShift, .extra = true } \ @@ -191,22 +195,20 @@ enum { { .type = TSParseActionTypeAccept } \ } -#define EXPORT_LANGUAGE(language_name) \ - static TSLanguage language = { \ - .symbol_count = SYMBOL_COUNT, \ - .symbol_metadata = ts_symbol_metadata, \ - .parse_table = (const unsigned short *)ts_parse_table, \ - .parse_actions = ts_parse_actions, \ - .in_progress_symbol_table = ts_in_progress_symbol_table, \ - .in_progress_symbols = ts_in_progress_symbols, \ - .recovery_states = ts_recovery_states, \ - .lex_states = ts_lex_states, \ - .symbol_names = ts_symbol_names, \ - .lex_fn = ts_lex, \ - }; \ - \ - const TSLanguage *language_name() { \ - return &language; \ +#define EXPORT_LANGUAGE(language_name) \ + static TSLanguage language = { \ + .symbol_count = SYMBOL_COUNT, \ + .symbol_metadata = ts_symbol_metadata, \ + .parse_table = (const unsigned short *)ts_parse_table, \ + .parse_actions = ts_parse_actions, \ + .recovery_actions = ts_recovery_actions, \ + .lex_states = ts_lex_states, \ + .symbol_names = ts_symbol_names, \ + .lex_fn = ts_lex, \ + }; \ + \ + const TSLanguage *language_name() { \ + return &language; \ } #ifdef __cplusplus diff --git a/project.gyp b/project.gyp index 74327d98..74124b15 100644 --- a/project.gyp +++ b/project.gyp @@ -23,7 +23,6 @@ 'src/compiler/build_tables/parse_item.cc', 'src/compiler/build_tables/parse_conflict_manager.cc', 'src/compiler/build_tables/rule_can_be_blank.cc', - 'src/compiler/build_tables/symbols_by_first_symbol.cc', 'src/compiler/compile.cc', 'src/compiler/generate_code/c_code.cc', 'src/compiler/lex_table.cc', diff --git a/spec/compiler/build_tables/symbols_by_first_symbol_spec.cc b/spec/compiler/build_tables/symbols_by_first_symbol_spec.cc deleted file mode 100644 index b6948368..00000000 --- a/spec/compiler/build_tables/symbols_by_first_symbol_spec.cc +++ /dev/null @@ -1,116 +0,0 @@ -#include "spec_helper.h" -#include "helpers/stream_methods.h" -#include "compiler/build_tables/symbols_by_first_symbol.h" -#include "compiler/syntax_grammar.h" - -using namespace rules; -using build_tables::symbols_by_first_symbol; - -START_TEST - -describe("symbols_by_first_symbol", [&]() { - SyntaxGrammar grammar{{ - - // starts with token-11 and token-13 - SyntaxVariable("rule-0", VariableTypeNamed, vector({ - Production({ - ProductionStep(Symbol(11, true), 0, rules::AssociativityNone), - ProductionStep(Symbol(12, true), 0, rules::AssociativityNone), - }), - Production({ - ProductionStep(Symbol(13, true), 0, rules::AssociativityNone), - ProductionStep(Symbol(14, true), 0, rules::AssociativityNone), - }), - })), - - // starts with rule-0, which implies token-11 and token-13 - SyntaxVariable("rule-1", VariableTypeNamed, vector({ - Production({ - ProductionStep(Symbol(0), 0, rules::AssociativityNone), - ProductionStep(Symbol(12, true), 0, rules::AssociativityNone), - }), - })), - - // starts with token-15 and rule-1, which implies token-11 and token-13 - SyntaxVariable("rule-2", VariableTypeNamed, vector({ - Production({ - ProductionStep(Symbol(1), 0, rules::AssociativityNone), - }), - Production({ - ProductionStep(Symbol(15, true), 0, rules::AssociativityNone), - }), - })), - - // starts with token-15 - SyntaxVariable("rule-3", VariableTypeNamed, vector({ - Production({ - ProductionStep(Symbol(15, true), 0, rules::AssociativityNone), - }), - })) - }, {}, {}}; - - it("gives the set of non-terminals that can start with any given terminal", [&]() { - auto result = symbols_by_first_symbol(grammar); - - AssertThat(result, Equals(map>({ - { - Symbol(11, true), { - Symbol(11, true), - Symbol(0), - Symbol(1), - Symbol(2), - } - }, - { - Symbol(12, true), { - Symbol(12, true), - } - }, - { - Symbol(13, true), { - Symbol(13, true), - Symbol(0), - Symbol(1), - Symbol(2), - } - }, - { - Symbol(14, true), { - Symbol(14, true), - } - }, - { - Symbol(15, true), { - Symbol(15, true), - Symbol(2), - Symbol(3) - } - }, - { - Symbol(0), { - Symbol(0), - Symbol(1), - Symbol(2), - } - }, - { - Symbol(1), { - Symbol(1), - Symbol(2), - } - }, - { - Symbol(2), { - Symbol(2), - } - }, - { - Symbol(3), { - Symbol(3), - } - } - }))); - }); -}); - -END_TEST diff --git a/spec/runtime/parser_spec.cc b/spec/runtime/parser_spec.cc index e1616520..020aa319 100644 --- a/spec/runtime/parser_spec.cc +++ b/spec/runtime/parser_spec.cc @@ -100,7 +100,7 @@ describe("Parser", [&]() { AssertThat(ts_node_name(error, doc), Equals("ERROR")); AssertThat(ts_node_start_byte(error), Equals(strlen(" [123, "))); - AssertThat(ts_node_end_byte(error), Equals(strlen(" [123, @@@@@"))); + AssertThat(ts_node_end_byte(error), Equals(strlen(" [123, @@@@@,"))); AssertThat(ts_node_name(last, doc), Equals("true")); AssertThat(ts_node_start_byte(last), Equals(strlen(" [123, @@@@@, "))) @@ -121,7 +121,7 @@ describe("Parser", [&]() { AssertThat(ts_node_name(error, doc), Equals("ERROR")); AssertThat(ts_node_start_byte(error), Equals(strlen(" [123, "))) - AssertThat(ts_node_end_byte(error), Equals(strlen(" [123, faaaaalse"))) + AssertThat(ts_node_end_byte(error), Equals(strlen(" [123, faaaaalse,"))) AssertThat(ts_node_name(last, doc), Equals("true")); AssertThat(ts_node_start_byte(last), Equals(strlen(" [123, faaaaalse, "))); @@ -159,7 +159,7 @@ describe("Parser", [&]() { AssertThat(ts_node_name(error, doc), Equals("ERROR")); AssertThat(ts_node_start_byte(error), Equals(strlen(" [123, "))); - AssertThat(ts_node_end_byte(error), Equals(strlen(" [123, "))) + AssertThat(ts_node_end_byte(error), Equals(strlen(" [123, ,"))) AssertThat(ts_node_name(last, doc), Equals("true")); AssertThat(ts_node_start_byte(last), Equals(strlen(" [123, , "))); @@ -378,7 +378,7 @@ describe("Parser", [&]() { assert_root_node( "(program " - "(expression_statement (ERROR (number) (UNEXPECTED '4')) (number)) " + "(expression_statement (number) (ERROR (UNEXPECTED '4') (number))) " "(expression_statement (math_op (number) (number))))"); }); }); diff --git a/spec/runtime/stack_spec.cc b/spec/runtime/stack_spec.cc index df015464..d1601a40 100644 --- a/spec/runtime/stack_spec.cc +++ b/spec/runtime/stack_spec.cc @@ -53,7 +53,7 @@ vector get_stack_entries(Stack *stack, StackVersion version) { ts_stack_iterate( stack, version, - [](void *payload, TSStateId state, size_t tree_count, bool is_done, bool is_pending) -> StackIterateAction { + [](void *payload, TSStateId state, TreeArray *trees, size_t tree_count, bool is_done, bool is_pending) -> StackIterateAction { auto entries = static_cast *>(payload); StackEntry entry = {state, tree_count}; if (find(entries->begin(), entries->end(), entry) == entries->end()) diff --git a/src/compiler/build_tables/build_parse_table.cc b/src/compiler/build_tables/build_parse_table.cc index 33ee970b..ebc9d969 100644 --- a/src/compiler/build_tables/build_parse_table.cc +++ b/src/compiler/build_tables/build_parse_table.cc @@ -10,7 +10,6 @@ #include "compiler/build_tables/remove_duplicate_states.h" #include "compiler/build_tables/parse_item.h" #include "compiler/build_tables/item_set_closure.h" -#include "compiler/build_tables/symbols_by_first_symbol.h" #include "compiler/lexical_grammar.h" #include "compiler/syntax_grammar.h" #include "compiler/rules/symbol.h" @@ -35,6 +34,7 @@ class ParseTableBuilder { const SyntaxGrammar grammar; const LexicalGrammar lexical_grammar; ParseConflictManager conflict_manager; + unordered_map recovery_states; unordered_map parse_state_ids; vector> item_sets_to_process; ParseTable parse_table; @@ -106,35 +106,21 @@ class ParseTableBuilder { } void add_out_of_context_parse_states() { - auto symbols_by_first = symbols_by_first_symbol(grammar); - for (const Symbol &symbol : recovery_tokens(lexical_grammar)) { - add_out_of_context_parse_state(symbol, symbols_by_first[symbol]); + add_out_of_context_parse_state(symbol); } for (size_t i = 0; i < grammar.variables.size(); i++) { Symbol symbol(i, false); - add_out_of_context_parse_state(symbol, symbols_by_first[symbol]); + add_out_of_context_parse_state(symbol); } - parse_table.error_state.actions[rules::END_OF_INPUT()].clear(); + parse_table.error_state.actions[rules::END_OF_INPUT()].push_back( + ParseAction::Shift(0, PrecedenceRange())); } - void add_out_of_context_parse_state(const rules::Symbol &symbol, - const set &symbols) { - ParseItemSet item_set; - - for (const auto &parse_state_entry : parse_state_ids) { - ParseItemSet state_item_set = parse_state_entry.first; - for (const auto &pair : state_item_set.entries) { - const ParseItem &item = pair.first; - const LookaheadSet &lookahead_set = pair.second; - if (symbols.count(item.next_symbol())) { - item_set.entries[item].insert_all(lookahead_set); - } - } - } - + void add_out_of_context_parse_state(const rules::Symbol &symbol) { + const ParseItemSet &item_set = recovery_states[symbol]; if (!item_set.entries.empty()) { ParseStateId state = add_parse_state(item_set); parse_table.error_state.actions[symbol].push_back( @@ -146,11 +132,6 @@ class ParseTableBuilder { auto pair = parse_state_ids.find(item_set); if (pair == parse_state_ids.end()) { ParseStateId state_id = parse_table.add_state(); - for (const auto &entry : item_set.entries) { - const ParseItem &item = entry.first; - if (item.step_index > 0 && item.lhs() != rules::START()) - parse_table.states[state_id].in_progress_symbols.insert(item.lhs()); - } parse_state_ids[item_set] = state_id; item_sets_to_process.push_back({ item_set, state_id }); @@ -168,6 +149,10 @@ class ParseTableBuilder { ParseAction *new_action = add_action( state_id, symbol, ParseAction::Shift(0, precedence), item_set); + + if (!allow_any_conflict) + recovery_states[symbol].add(next_item_set); + if (new_action) new_action->state_index = add_parse_state(next_item_set); } diff --git a/src/compiler/build_tables/parse_item.cc b/src/compiler/build_tables/parse_item.cc index 01ea3ec1..8c4dfe40 100644 --- a/src/compiler/build_tables/parse_item.cc +++ b/src/compiler/build_tables/parse_item.cc @@ -128,5 +128,10 @@ ParseItemSet::TransitionMap ParseItemSet::transitions() const { return result; } +void ParseItemSet::add(const ParseItemSet &other) { + for (const auto &pair : other.entries) + entries[pair.first].insert_all(pair.second); +} + } // namespace build_tables } // namespace tree_sitter diff --git a/src/compiler/build_tables/parse_item.h b/src/compiler/build_tables/parse_item.h index 31a4a59a..404b67c3 100644 --- a/src/compiler/build_tables/parse_item.h +++ b/src/compiler/build_tables/parse_item.h @@ -53,6 +53,7 @@ class ParseItemSet { TransitionMap transitions() const; bool operator==(const ParseItemSet &) const; + void add(const ParseItemSet &); std::map entries; }; diff --git a/src/compiler/build_tables/symbols_by_first_symbol.cc b/src/compiler/build_tables/symbols_by_first_symbol.cc deleted file mode 100644 index 526e18c5..00000000 --- a/src/compiler/build_tables/symbols_by_first_symbol.cc +++ /dev/null @@ -1,50 +0,0 @@ -#include "compiler/build_tables/symbols_by_first_symbol.h" -#include "compiler/syntax_grammar.h" -#include "compiler/rules/symbol.h" - -namespace tree_sitter { -namespace build_tables { - -using std::map; -using std::set; -using rules::Symbol; - -map> symbols_by_first_symbol(const SyntaxGrammar &grammar) { - map> result; - - size_t variable_index = -1; - for (const SyntaxVariable &variable : grammar.variables) { - variable_index++; - Symbol symbol(variable_index); - result[symbol].insert(symbol); - for (const Production &production : variable.productions) - if (!production.empty()) { - Symbol first_symbol = production[0].symbol; - result[first_symbol].insert(symbol); - - for (const ProductionStep &step : production) { - result[step.symbol].insert(step.symbol); - } - } - } - - bool done = false; - while (!done) { - done = true; - for (auto &entry : result) { - set new_symbols; - for (const Symbol &symbol : entry.second) - for (const Symbol &other_symbol : result[symbol]) - new_symbols.insert(other_symbol); - - for (const Symbol &new_symbol : new_symbols) - if (entry.second.insert(new_symbol).second) - done = false; - } - } - - return result; -} - -} // namespace build_tables -} // namespace tree_sitter diff --git a/src/compiler/build_tables/symbols_by_first_symbol.h b/src/compiler/build_tables/symbols_by_first_symbol.h deleted file mode 100644 index 9a2b5ed2..00000000 --- a/src/compiler/build_tables/symbols_by_first_symbol.h +++ /dev/null @@ -1,20 +0,0 @@ -#ifndef COMPILER_BUILD_TABLES_SYMBOLS_BY_FIRST_SYMBOL_H_ -#define COMPILER_BUILD_TABLES_SYMBOLS_BY_FIRST_SYMBOL_H_ - -#include -#include -#include "compiler/rules/symbol.h" - -namespace tree_sitter { - -struct SyntaxGrammar; - -namespace build_tables { - -std::map> symbols_by_first_symbol( - const SyntaxGrammar &); - -} // namespace build_tables -} // namespace tree_sitter - -#endif // COMPILER_BUILD_TABLES_SYMBOLS_BY_FIRST_SYMBOL_H_ diff --git a/src/compiler/generate_code/c_code.cc b/src/compiler/generate_code/c_code.cc index 60268376..2fd14e4d 100644 --- a/src/compiler/generate_code/c_code.cc +++ b/src/compiler/generate_code/c_code.cc @@ -102,7 +102,6 @@ class CCodeGenerator { add_lex_states_list(); add_recovery_parse_states_list(); add_parse_table(); - add_in_progress_symbol_table(); add_parser_export(); return buffer; @@ -221,13 +220,14 @@ class CCodeGenerator { } void add_recovery_parse_states_list() { - line("static TSStateId ts_recovery_states[SYMBOL_COUNT] = {"); + line("static TSParseAction ts_recovery_actions[SYMBOL_COUNT] = {"); indent([&]() { for (const auto &entry : parse_table.error_state.actions) { const rules::Symbol &symbol = entry.first; if (!entry.second.empty()) { ParseStateId state = entry.second[0].state_index; - line("[" + symbol_id(symbol) + "] = " + to_string(state) + ","); + line("[" + symbol_id(symbol) + "] = RECOVER(" + to_string(state) + + "),"); } } }); @@ -266,29 +266,6 @@ class CCodeGenerator { line(); } - void add_in_progress_symbol_table() { - add_in_progress_symbol_list_id({}); - line("static unsigned short ts_in_progress_symbol_table[STATE_COUNT] = {"); - - indent([&]() { - size_t state_id = 0; - for (const ParseState &state : parse_table.states) { - if (!state.in_progress_symbols.empty()) { - line("[" + to_string(state_id) + "] = "); - add(to_string( - add_in_progress_symbol_list_id(state.in_progress_symbols))); - add(","); - } - state_id++; - } - }); - - line("};"); - line(); - add_in_progress_symbols_list(); - line(); - } - void add_parser_export() { line("EXPORT_LANGUAGE(ts_language_" + name + ");"); line(); @@ -408,26 +385,6 @@ class CCodeGenerator { line("};"); } - void add_in_progress_symbols_list() { - line("static TSInProgressSymbolEntry ts_in_progress_symbols[] = {"); - - indent([&]() { - for (const auto &pair : in_progress_symbols) { - size_t index = pair.first; - line("[" + to_string(index) + "] = {.count = " + - to_string(pair.second.size()) + "},"); - - for (const rules::Symbol &symbol : pair.second) { - add(" "); - add("{" + symbol_id(symbol) + "}"); - add(","); - } - } - }); - - line("};"); - } - size_t add_parse_action_list_id(const vector &actions) { for (const auto &pair : parse_actions) { if (pair.second == actions) { diff --git a/src/compiler/parse_table.h b/src/compiler/parse_table.h index a31ca916..8510c5fa 100644 --- a/src/compiler/parse_table.h +++ b/src/compiler/parse_table.h @@ -84,7 +84,6 @@ class ParseState { void each_advance_action(std::function); std::map> actions; - std::set in_progress_symbols; LexStateId lex_state_id; }; diff --git a/src/runtime/language.c b/src/runtime/language.c index 712f6a90..e0106893 100644 --- a/src/runtime/language.c +++ b/src/runtime/language.c @@ -1,32 +1,33 @@ #include "tree_sitter/parser.h" +#include "runtime/language.h" +#include "runtime/tree.h" -const TSParseAction *ts_language_actions(const TSLanguage *language, - TSStateId state, TSSymbol symbol, - size_t *count) { - size_t action_index = 0; - if (symbol != ts_builtin_sym_error) { - if (state == ts_parse_state_error) - state = language->recovery_states[symbol]; - action_index = - (language->parse_table + (state * language->symbol_count))[symbol]; +const TSParseAction *ts_language_actions(const TSLanguage *self, TSStateId state, + TSSymbol symbol, size_t *count) { + if (state == ts_parse_state_error) { + *count = 1; + return &self->recovery_actions[symbol]; } - *count = language->parse_actions[action_index].count; - const TSParseActionEntry *entry = language->parse_actions + action_index + 1; + size_t action_index = 0; + if (symbol != ts_builtin_sym_error) + action_index = self->parse_table[state * self->symbol_count + symbol]; + + *count = self->parse_actions[action_index].count; + const TSParseActionEntry *entry = self->parse_actions + action_index + 1; return (const TSParseAction *)entry; } -TSParseAction ts_language_last_action(const TSLanguage *language, - TSStateId state, TSSymbol sym) { +TSParseAction ts_language_last_action(const TSLanguage *self, TSStateId state, + TSSymbol sym) { size_t count; - const TSParseAction *actions = - ts_language_actions(language, state, sym, &count); + const TSParseAction *actions = ts_language_actions(self, state, sym, &count); return actions[count - 1]; } -bool ts_language_has_action(const TSLanguage *language, TSStateId state, +bool ts_language_has_action(const TSLanguage *self, TSStateId state, TSSymbol symbol) { - TSParseAction action = ts_language_last_action(language, state, symbol); + TSParseAction action = ts_language_last_action(self, state, symbol); return action.type != TSParseActionTypeError; } @@ -50,16 +51,3 @@ const char *ts_language_symbol_name(const TSLanguage *language, TSSymbol symbol) else return language->symbol_names[symbol]; } - -bool ts_language_symbol_is_in_progress(const TSLanguage *self, TSStateId state, - TSSymbol symbol) { - if (state == ts_parse_state_error) - return false; - unsigned index = self->in_progress_symbol_table[state]; - unsigned short count = self->in_progress_symbols[index].count; - const TSInProgressSymbolEntry *entries = self->in_progress_symbols + index + 1; - for (size_t i = 0; i < count; i++) - if (entries[i].symbol == symbol) - return true; - return false; -} diff --git a/src/runtime/language.h b/src/runtime/language.h index 28c4a8e9..d5f95297 100644 --- a/src/runtime/language.h +++ b/src/runtime/language.h @@ -6,6 +6,7 @@ extern "C" { #endif #include "tree_sitter/parser.h" +#include "runtime/tree.h" bool ts_language_symbol_is_in_progress(const TSLanguage *, TSStateId, TSSymbol); @@ -15,8 +16,12 @@ TSParseAction ts_language_last_action(const TSLanguage *, TSStateId, TSSymbol); bool ts_language_has_action(const TSLanguage *, TSStateId, TSSymbol); -TSSymbolMetadata ts_language_symbol_metadata(const TSLanguage *language, - TSSymbol symbol); +TSSymbolMetadata ts_language_symbol_metadata(const TSLanguage *, TSSymbol); + +static inline TSStateId ts_language_lex_state(const TSLanguage *self, + TSStateId state) { + return state == ts_parse_state_error ? 0 : self->lex_states[state]; +} #ifdef __cplusplus } diff --git a/src/runtime/parser.c b/src/runtime/parser.c index 304241db..52e2469d 100644 --- a/src/runtime/parser.c +++ b/src/runtime/parser.c @@ -10,6 +10,7 @@ #include "runtime/array.h" #include "runtime/language.h" #include "runtime/alloc.h" +#include "runtime/reduce_action.h" #define LOG(...) \ if (self->lexer.debugger.debug_fn) { \ @@ -51,19 +52,12 @@ typedef struct { size_t char_index; } ReusableNode; -struct ErrorRepair { - TSSymbol symbol; - size_t in_progress_state_count; - size_t count_below_error; -}; - typedef struct { TSParser *parser; TSSymbol lookahead_symbol; - ErrorRepairArray *repairs; + TreeArray *trees_above_error; bool found_repair; - ErrorRepair best_repair; - TSStateId best_repair_state; + ReduceAction best_repair; TSStateId best_repair_next_state; size_t best_repair_skip_count; } ErrorRepairSession; @@ -198,7 +192,7 @@ static bool ts_parser__can_reuse(TSParser *self, StackVersion version, TSStateId top_state = ts_stack_top_state(self->stack, version); if (tree->lex_state != TS_TREE_STATE_INDEPENDENT && - tree->lex_state != self->language->lex_states[top_state]) + tree->lex_state != ts_language_lex_state(self->language, top_state)) return false; const TSParseAction action = @@ -253,9 +247,10 @@ static TSTree *ts_parser__get_lookahead(TSParser *self, StackVersion version, ts_lexer_reset(&self->lexer, position); TSStateId parse_state = ts_stack_top_state(self->stack, version); - TSStateId lex_state = self->language->lex_states[parse_state]; + bool error_mode = parse_state == ts_parse_state_error; + TSStateId lex_state = error_mode ? 0 : self->language->lex_states[parse_state]; LOG("lex state:%d", lex_state); - return self->language->lex_fn(&self->lexer, lex_state, false); + return self->language->lex_fn(&self->lexer, lex_state, error_mode); } static bool ts_parser__select_tree(TSParser *self, TSTree *left, TSTree *right) { @@ -273,12 +268,23 @@ static bool ts_parser__select_tree(TSParser *self, TSTree *left, TSTree *right) SYM_NAME(left->symbol), SYM_NAME(right->symbol)); return false; } - return ts_tree_compare(right, left) < 0; -} -static void ts_parser__remove_version(TSParser *self, StackVersion version) { - LOG_ACTION("bail version:%d", version); - ts_stack_remove_version(self->stack, version); + int comparison = ts_tree_compare(left, right); + switch (comparison) { + case -1: + LOG_ACTION("select_earlier symbol:%s, over_symbol:%s", + SYM_NAME(left->symbol), SYM_NAME(right->symbol)); + return false; + break; + case 1: + LOG_ACTION("select_earlier symbol:%s, over_symbol:%s", + SYM_NAME(right->symbol), SYM_NAME(left->symbol)); + return true; + default: + LOG_ACTION("select_existing symbol:%s, over_symbol:%s", + SYM_NAME(left->symbol), SYM_NAME(right->symbol)); + return false; + } } static bool ts_parser__push(TSParser *self, StackVersion version, TSTree *tree, @@ -291,7 +297,6 @@ static bool ts_parser__push(TSParser *self, StackVersion version, TSTree *tree, static bool ts_parser__shift(TSParser *self, StackVersion version, TSStateId state, TSTree *lookahead, bool extra) { if (extra) { - LOG_ACTION("shift_extra"); TSSymbolMetadata metadata = ts_language_symbol_metadata(self->language, lookahead->symbol); if (metadata.structural && ts_stack_version_count(self->stack) > 1) { @@ -301,7 +306,6 @@ static bool ts_parser__shift(TSParser *self, StackVersion version, } lookahead->extra = true; } else { - LOG_ACTION("shift state:%u", state); ts_tree_retain(lookahead); } @@ -337,13 +341,6 @@ static bool ts_parser__switch_children(TSParser *self, TSTree *tree, static Reduction ts_parser__reduce(TSParser *self, StackVersion version, TSSymbol symbol, unsigned count, bool extra, bool fragile) { - if (extra) { - LOG_ACTION("reduce_extra"); - } else { - LOG_ACTION("reduce sym:%s, child_count:%u, fragile:%s", SYM_NAME(symbol), - count, BOOL_STRING(fragile)); - } - size_t initial_version_count = ts_stack_version_count(self->stack); StackPopResult pop = ts_stack_pop_count(self->stack, version, count); switch (pop.status) { @@ -406,7 +403,8 @@ static Reduction ts_parser__reduce(TSParser *self, StackVersion version, new_state = state; } else { TSParseAction action = ts_language_last_action(language, state, symbol); - assert(action.type == TSParseActionTypeShift); + assert(action.type == TSParseActionTypeShift || + action.type == TSParseActionTypeRecover); new_state = action.data.to_state; } @@ -425,57 +423,96 @@ error: return (Reduction){ ReduceFailed }; } -static StackIterateAction ts_parser__error_repair_callback(void *payload, - TSStateId state, - size_t tree_count, - bool is_done, - bool is_pending) { - ErrorRepairSession *session = (ErrorRepairSession *)payload; - const TSParser *self = session->parser; +static bool ts_parser__is_valid_repair( + const TSParser *self, const TreeArray *trees_below, + const TreeArray *trees_above, TSStateId start_state, TSSymbol goal_symbol, + size_t goal_count_below, TSSymbol lookahead_symbol) { const TSLanguage *language = self->language; - TSSymbol lookahead_symbol = session->lookahead_symbol; - StackIterateAction result = StackIterateNone; + TSStateId state = start_state; + size_t count_below = 0; - for (size_t i = 0; i < session->repairs->size; i++) { - ErrorRepair *repair = &session->repairs->contents[i]; - TSSymbol symbol = repair->symbol; + for (size_t i = trees_below->size - 1; i + 1 > 0; i--) { + TSTree *tree = trees_below->contents[i]; + TSParseAction action = + ts_language_last_action(language, state, tree->symbol); + if (action.type != TSParseActionTypeShift) + return false; + if (action.extra || tree->extra) + continue; - if (tree_count >= repair->count_below_error) { - size_t skip_count = tree_count - repair->count_below_error; + state = action.data.to_state; + count_below++; - if (session->found_repair && skip_count > session->best_repair_skip_count) { - array_erase(session->repairs, i--); - continue; + if (count_below == goal_count_below) { + for (size_t j = 0; j < trees_above->size; j++) { + TSTree *tree = trees_above->contents[j]; + TSParseAction action = + ts_language_last_action(language, state, tree->symbol); + if (action.type != TSParseActionTypeShift) + return false; + if (action.extra || tree->extra) + continue; + + state = action.data.to_state; } - if (repair->in_progress_state_count > 0) { - TSParseAction action = ts_language_last_action(language, state, symbol); - if (action.type == TSParseActionTypeShift) { - TSStateId next_state = action.data.to_state; - if (ts_language_has_action(language, next_state, lookahead_symbol) && - (!session->found_repair || - repair->in_progress_state_count > - session->best_repair.in_progress_state_count)) { - result |= StackIteratePop; - session->found_repair = true; - session->best_repair = *repair; - session->best_repair_state = state; - session->best_repair_skip_count = skip_count; - session->best_repair_next_state = next_state; - array_erase(session->repairs, i--); - continue; - } - } - } + size_t action_count = 0; + const TSParseAction *actions = + ts_language_actions(language, state, lookahead_symbol, &action_count); + for (size_t k = 0; k < action_count; k++) + if (actions[k].type == TSParseActionTypeReduce && + actions[k].data.symbol == goal_symbol) + return true; } - - if (ts_language_symbol_is_in_progress(self->language, state, symbol)) - repair->in_progress_state_count++; - else - repair->in_progress_state_count = 0; } - if (session->repairs->size == 0) + return false; +} + +static StackIterateAction ts_parser__error_repair_callback( + void *payload, TSStateId state, TreeArray *trees, size_t tree_count, + bool is_done, bool is_pending) { + ErrorRepairSession *session = (ErrorRepairSession *)payload; + TSParser *self = session->parser; + const TSLanguage *language = self->language; + TSSymbol lookahead_symbol = session->lookahead_symbol; + ReduceActionSet *repairs = &self->reduce_actions; + TreeArray *trees_above_error = session->trees_above_error; + StackIterateAction result = StackIterateNone; + + for (size_t i = 0; i < repairs->size; i++) { + ReduceAction *repair = &repairs->contents[i]; + if (repair->count > tree_count) + continue; + + size_t skip_count = tree_count - repair->count; + if (session->found_repair && skip_count >= session->best_repair_skip_count) { + array_erase(repairs, i--); + continue; + } + + TSParseAction repair_symbol_action = + ts_language_last_action(language, state, repair->symbol); + if (repair_symbol_action.type != TSParseActionTypeShift) + continue; + + TSStateId state_after_repair = repair_symbol_action.data.to_state; + if (!ts_language_has_action(language, state_after_repair, lookahead_symbol)) + continue; + + if (ts_parser__is_valid_repair(self, trees, trees_above_error, state, + repair->symbol, repair->count, + lookahead_symbol)) { + result |= StackIteratePop; + session->found_repair = true; + session->best_repair = *repair; + session->best_repair_skip_count = skip_count; + session->best_repair_next_state = state_after_repair; + array_erase(repairs, i--); + } + } + + if (repairs->size == 0) result |= StackIterateStop; return result; @@ -486,26 +523,22 @@ static RepairResult ts_parser__repair_error(TSParser *self, StackSlice slice, const TSParseAction *actions, size_t action_count) { size_t count_above_error = ts_tree_array_essential_count(&slice.trees); - LOG_ACTION("repair count_above_error:%lu", count_above_error); - ErrorRepairSession session = { .parser = self, .lookahead_symbol = lookahead->symbol, - .repairs = &self->error_repairs, .found_repair = false, + .trees_above_error = &slice.trees, }; - array_clear(&self->error_repairs); + array_clear(&self->reduce_actions); for (size_t i = 0; i < action_count; i++) if (actions[i].type == TSParseActionTypeReduce && actions[i].data.child_count > count_above_error) - CHECK(array_push( - &self->error_repairs, - ((ErrorRepair){ - .symbol = actions[i].data.symbol, - .count_below_error = actions[i].data.child_count - count_above_error, - .in_progress_state_count = 0, - }))); + CHECK(array_push(&self->reduce_actions, + ((ReduceAction){ + .symbol = actions[i].data.symbol, + .count = actions[i].data.child_count - count_above_error, + }))); StackPopResult pop = ts_stack_iterate( self->stack, slice.version, ts_parser__error_repair_callback, &session); @@ -518,7 +551,7 @@ static RepairResult ts_parser__repair_error(TSParser *self, StackSlice slice, return RepairNoneFound; } - ErrorRepair repair = session.best_repair; + ReduceAction repair = session.best_repair; TSStateId next_state = session.best_repair_next_state; size_t skip_count = session.best_repair_skip_count; TSSymbol symbol = repair.symbol; @@ -534,20 +567,18 @@ static RepairResult ts_parser__repair_error(TSParser *self, StackSlice slice, ts_stack_remove_version(self->stack, other_slice.version); } - LOG_ACTION( - "repair_found sym:%s, child_count:%lu, match_count:%lu, skipped:%lu", - SYM_NAME(symbol), repair.count_below_error + count_above_error, - repair.in_progress_state_count, skip_count); + LOG_ACTION("repair_found sym:%s, child_count:%lu, skipped:%lu", + SYM_NAME(symbol), repair.count + count_above_error, skip_count); if (skip_count > 0) { TreeArray skipped_children = array_new(); CHECK(array_grow(&skipped_children, skip_count)); - for (size_t i = repair.count_below_error; i < children_below.size; i++) + for (size_t i = repair.count; i < children_below.size; i++) array_push(&skipped_children, children_below.contents[i]); TSTree *error = ts_tree_make_error_node(&skipped_children); CHECK(error); - children_below.size = repair.count_below_error; + children_below.size = repair.count; array_push(&children_below, error); } @@ -581,7 +612,6 @@ static void ts_parser__start(TSParser *self, TSInput input, } static bool ts_parser__accept(TSParser *self, StackVersion version) { - LOG_ACTION("accept"); StackPopResult pop = ts_stack_pop_all(self->stack, version); CHECK(pop.status); CHECK(pop.slices.size); @@ -624,67 +654,91 @@ error: return false; } -static ParseActionResult ts_parser__handle_error(TSParser *self, - StackVersion version, - TSTree *invalid_tree) { - const TSLanguage *language = self->language; - TreeArray invalid_trees = array_new(); - TSTree *next_token = self->language->lex_fn(&self->lexer, 0, true); - ts_tree_retain(invalid_tree); - CHECK(array_push(&invalid_trees, invalid_tree)); - LOG_ACTION("handle_error %s", SYM_NAME(invalid_tree->symbol)); - - for (;;) { - if (next_token->symbol == ts_builtin_sym_end) { - LOG_ACTION("fail_to_recover"); - - ts_tree_release(next_token); - TSTree *error = ts_tree_make_error_node(&invalid_trees); - CHECK(error); - CHECK(ts_parser__push(self, version, error, 0)); - - TSTree *parent = ts_tree_make_leaf( - ts_builtin_sym_start, ts_length_zero(), ts_length_zero(), - ts_language_symbol_metadata(language, ts_builtin_sym_start)); - CHECK(parent); - CHECK(ts_parser__push(self, version, parent, 0)); - CHECK(ts_parser__accept(self, version)); - return ParseActionRemoved; +static bool ts_parser__handle_error(TSParser *self, StackVersion version, + TSStateId state, TSTree *lookahead) { + bool has_shift_action = false; + array_clear(&self->reduce_actions); + for (TSSymbol symbol = 0; symbol < self->language->symbol_count; symbol++) { + size_t action_count; + const TSParseAction *actions = + ts_language_actions(self->language, state, symbol, &action_count); + for (size_t i = 0; i < action_count; i++) { + TSParseAction action = actions[i]; + if (action.extra) + continue; + if (action.type == TSParseActionTypeShift || + action.type == TSParseActionTypeRecover) + has_shift_action = true; + if (action.type == TSParseActionTypeReduce) + CHECK(ts_reduce_action_set_add( + &self->reduce_actions, + (ReduceAction){ + .symbol = action.data.symbol, .count = action.data.child_count, + })); } - - TSLength position = self->lexer.current_position; - TSTree *following_token = language->lex_fn(&self->lexer, 0, true); - CHECK(following_token); - - if (!ts_language_symbol_metadata(language, next_token->symbol).extra) { - TSParseAction action = ts_language_last_action( - language, ts_parse_state_error, next_token->symbol); - assert(action.type == TSParseActionTypeShift); - TSStateId next_state = action.data.to_state; - - if (ts_language_has_action(language, next_state, following_token->symbol) && - !ts_language_symbol_metadata(language, following_token->symbol).extra) { - LOG_ACTION("resume_without_context state:%d", next_state); - - ts_tree_release(following_token); - ts_lexer_reset(&self->lexer, position); - ts_tree_steal_padding(*array_back(&invalid_trees), next_token); - TSTree *error = ts_tree_make_error_node(&invalid_trees); - CHECK(error); - CHECK(ts_parser__push(self, version, error, ts_parse_state_error)); - CHECK(ts_parser__push(self, version, next_token, next_state)); - return ParseActionUpdated; - } - } - - CHECK(array_push(&invalid_trees, next_token)); - next_token = following_token; } + Reduction reduction; + for (size_t i = 0; i < self->reduce_actions.size; i++) { + ReduceAction repair = self->reduce_actions.contents[i]; + reduction = ts_parser__reduce(self, version, repair.symbol, repair.count, + false, true); + CHECK(reduction.status != ReduceFailed); + assert(reduction.status == ReduceSucceeded); + CHECK(ts_parser__shift(self, reduction.slice.version, ts_parse_state_error, + lookahead, false)); + } + + if (has_shift_action) { + CHECK( + ts_parser__shift(self, version, ts_parse_state_error, lookahead, false)); + } else { + ts_stack_renumber_version(self->stack, reduction.slice.version, version); + } + + return true; + error: return false; } +static bool ts_parser__recover(TSParser *self, StackVersion version, + TSStateId state, TSTree *lookahead) { + size_t error_length = ts_stack_error_length(self->stack, version); + + bool has_repaired = false; + for (StackVersion i = 0; i < ts_stack_version_count(self->stack); i++) + if (i != version && ts_stack_error_length(self->stack, i) == 0 && + ts_stack_last_repaired_error_size(self->stack, i) <= error_length) { + has_repaired = true; + break; + } + + if (has_repaired) { + LOG_ACTION("final_recover state:%u, error_length:%lu ", state, error_length); + } else { + StackVersion new_version = ts_stack_duplicate_version(self->stack, version); + CHECK(new_version != STACK_VERSION_NONE); + CHECK(ts_parser__shift( + self, new_version, ts_parse_state_error, lookahead, + ts_language_symbol_metadata(self->language, lookahead->symbol).extra)); + LOG_ACTION("recover_and_discard state:%u, error_length:%lu", state, + error_length); + } + + CHECK(ts_parser__shift(self, version, state, lookahead, false)); + return true; + +error: + return false; +} + +static bool ts_parser__recover_eof(TSParser *self, StackVersion version) { + TreeArray children = array_new(); + TSTree *parent = ts_tree_make_error_node(&children); + return ts_parser__push(self, version, parent, 1); +} + static ParseActionResult ts_parser__consume_lookahead(TSParser *self, StackVersion version, TSTree *lookahead) { @@ -727,21 +781,40 @@ static ParseActionResult ts_parser__consume_lookahead(TSParser *self, } if (ts_stack_version_count(self->stack) == 1 && !self->finished_tree) { - return ts_parser__handle_error(self, version, lookahead); + LOG_ACTION("handle_error"); + CHECK(ts_parser__handle_error(self, version, state, lookahead)); + return ParseActionUpdated; } else { - ts_parser__remove_version(self, version); + LOG_ACTION("bail version:%d", version); + ts_stack_remove_version(self->stack, version); return ParseActionRemoved; } } case TSParseActionTypeShift: { - CHECK(ts_parser__shift(self, version, - action.extra ? state : action.data.to_state, - lookahead, action.extra)); + TSStateId next_state; + if (action.extra) { + next_state = state; + LOG_ACTION("shift_extra"); + } else { + next_state = action.data.to_state; + LOG_ACTION("shift state:%u", next_state); + } + + CHECK(ts_parser__shift(self, version, next_state, lookahead, + action.extra)); return ParseActionUpdated; } case TSParseActionTypeReduce: { + if (action.extra) { + LOG_ACTION("reduce_extra"); + } else { + LOG_ACTION("reduce sym:%s, child_count:%u, fragile:%s", + SYM_NAME(action.data.symbol), action.data.child_count, + BOOL_STRING(action.fragile)); + } + Reduction reduction = ts_parser__reduce( self, version, action.data.symbol, action.data.child_count, action.extra, action.fragile); @@ -756,6 +829,7 @@ static ParseActionResult ts_parser__consume_lookahead(TSParser *self, error_repair_depth = ts_tree_array_essential_count(&reduction.slice.trees); + LOG_ACTION("repair count_above_error:%lu", error_repair_depth); switch (ts_parser__repair_error(self, reduction.slice, lookahead, actions, action_count)) { case RepairFailed: @@ -774,9 +848,21 @@ static ParseActionResult ts_parser__consume_lookahead(TSParser *self, } case TSParseActionTypeAccept: { + LOG_ACTION("accept"); CHECK(ts_parser__accept(self, version)); return ParseActionRemoved; } + + case TSParseActionTypeRecover: { + if (lookahead->symbol == ts_builtin_sym_end) { + LOG_ACTION("recover_eof"); + CHECK(ts_parser__recover_eof(self, version)); + } else { + CHECK(ts_parser__recover(self, version, action.data.to_state, + lookahead)); + } + return ParseActionUpdated; + } } } @@ -792,13 +878,13 @@ bool ts_parser_init(TSParser *self) { ts_lexer_init(&self->lexer); self->finished_tree = NULL; self->stack = NULL; - array_init(&self->error_repairs); + array_init(&self->reduce_actions); self->stack = ts_stack_new(); if (!self->stack) goto error; - if (!array_grow(&self->error_repairs, 4)) + if (!array_grow(&self->reduce_actions, 4)) goto error; return true; @@ -808,16 +894,16 @@ error: ts_stack_delete(self->stack); self->stack = NULL; } - if (self->error_repairs.contents) - array_delete(&self->error_repairs); + if (self->reduce_actions.contents) + array_delete(&self->reduce_actions); return false; } void ts_parser_destroy(TSParser *self) { if (self->stack) ts_stack_delete(self->stack); - if (self->error_repairs.contents) - array_delete(&self->error_repairs); + if (self->reduce_actions.contents) + array_delete(&self->reduce_actions); } TSDebugger ts_parser_debugger(const TSParser *self) { diff --git a/src/runtime/parser.h b/src/runtime/parser.h index c5de3ca1..823f07f6 100644 --- a/src/runtime/parser.h +++ b/src/runtime/parser.h @@ -7,15 +7,13 @@ extern "C" { #include "runtime/stack.h" #include "runtime/array.h" - -typedef struct ErrorRepair ErrorRepair; -typedef Array(ErrorRepair) ErrorRepairArray; +#include "runtime/reduce_action.h" typedef struct { TSLexer lexer; Stack *stack; const TSLanguage *language; - ErrorRepairArray error_repairs; + ReduceActionSet reduce_actions; TSTree *finished_tree; bool is_split; bool print_debugging_graphs; diff --git a/src/runtime/reduce_action.h b/src/runtime/reduce_action.h new file mode 100644 index 00000000..3b8841e5 --- /dev/null +++ b/src/runtime/reduce_action.h @@ -0,0 +1,32 @@ +#ifndef RUNTIME_REDUCE_ACTION_H_ +#define RUNTIME_REDUCE_ACTION_H_ + +#ifdef __cplusplus +extern "C" { +#endif + +#include "runtime/array.h" +#include "tree_sitter/runtime.h" + +typedef struct { + TSSymbol symbol; + size_t count; +} ReduceAction; + +typedef Array(ReduceAction) ReduceActionSet; + +static inline bool ts_reduce_action_set_add(ReduceActionSet *self, + ReduceAction new_action) { + for (size_t i = 0; i < self->size; i++) { + ReduceAction action = self->contents[i]; + if (action.symbol == new_action.symbol && action.count == new_action.count) + return true; + } + return array_push(self, new_action); +} + +#ifdef __cplusplus +} +#endif + +#endif // RUNTIME_REDUCE_ACTION_H_ diff --git a/src/runtime/stack.c b/src/runtime/stack.c index cd431668..c01340ed 100644 --- a/src/runtime/stack.c +++ b/src/runtime/stack.c @@ -26,6 +26,7 @@ struct StackNode { StackLink links[MAX_LINK_COUNT]; short unsigned int link_count; short unsigned int ref_count; + size_t error_length; }; typedef struct { @@ -38,6 +39,7 @@ typedef struct { typedef struct { size_t goal_tree_count; bool found_error; + bool found_valid_path; } StackPopSession; typedef Array(StackNode *) StackNodeArray; @@ -90,6 +92,7 @@ static StackNode *stack_node_new(StackNode *next, TSTree *tree, bool is_pending, .links = {}, .state = state, .position = position, + .error_length = (state == ts_parse_state_error) ? 1 : 0, }; if (next) { @@ -97,6 +100,7 @@ static StackNode *stack_node_new(StackNode *next, TSTree *tree, bool is_pending, stack_node_retain(next); node->link_count = 1; node->links[0] = (StackLink){ next, tree, is_pending }; + node->error_length += next->error_length; } return node; @@ -168,8 +172,9 @@ INLINE StackPopResult stack__iter(Stack *self, StackVersion version, StackNode *node = path->node; bool is_done = node == self->base_node; - StackIterateAction action = callback( - payload, node->state, path->tree_count, is_done, path->is_pending); + StackIterateAction action = + callback(payload, node->state, &path->trees, path->tree_count, is_done, + path->is_pending); bool should_pop = action & StackIteratePop; bool should_stop = action & StackIterateStop || node->link_count == 0; @@ -207,7 +212,7 @@ INLINE StackPopResult stack__iter(Stack *self, StackVersion version, next_path->node = link.node; if (!link.is_pending) next_path->is_pending = false; - if (!link.tree->extra && link.tree->symbol != ts_builtin_sym_error) + if (!link.tree->extra) next_path->tree_count++; if (!array_push(&next_path->trees, link.tree)) goto error; @@ -302,6 +307,24 @@ TSLength ts_stack_top_position(const Stack *self, StackVersion version) { return (*array_get(&self->heads, version))->position; } +size_t ts_stack_error_length(const Stack *self, StackVersion version) { + return (*array_get(&self->heads, version))->error_length; +} + +size_t ts_stack_last_repaired_error_size(const Stack *self, + StackVersion version) { + StackNode *node = (*array_get(&self->heads, version)); + for (;;) { + if (node->link_count == 0) + break; + TSTree *tree = node->links[0].tree; + if (tree->error_size > 0) + return ts_tree_last_error_size(tree); + node = node->links[0].node; + } + return 0; +} + bool ts_stack_push(Stack *self, StackVersion version, TSTree *tree, bool is_pending, TSStateId state) { StackNode *node = *array_get(&self->heads, version); @@ -321,40 +344,48 @@ StackPopResult ts_stack_iterate(Stack *self, StackVersion version, } INLINE StackIterateAction pop_count_callback(void *payload, TSStateId state, - size_t tree_count, bool is_done, - bool is_pending) { + TreeArray *trees, size_t tree_count, + bool is_done, bool is_pending) { StackPopSession *pop_session = (StackPopSession *)payload; - if (pop_session->found_error) - return StackIterateStop; - if (tree_count == pop_session->goal_tree_count) + + if (tree_count == pop_session->goal_tree_count) { + pop_session->found_valid_path = true; return StackIteratePop | StackIterateStop; + } + if (state == ts_parse_state_error) { - pop_session->found_error = true; - return StackIteratePop | StackIterateStop; + if (pop_session->found_valid_path || pop_session->found_error) { + return StackIterateStop; + } else { + pop_session->found_error = true; + return StackIteratePop | StackIterateStop; + } } return StackIterateNone; } StackPopResult ts_stack_pop_count(Stack *self, StackVersion version, size_t count) { - StackPopSession session = {.goal_tree_count = count, .found_error = false }; + StackPopSession session = { + .goal_tree_count = count, .found_error = false, .found_valid_path = false, + }; StackPopResult pop = stack__iter(self, version, pop_count_callback, &session); if (pop.status && session.found_error) { - pop.status = StackPopStoppedAtError; - StackSlice stopped_slice = array_pop(&pop.slices); - for (size_t i = 0; i < pop.slices.size; i++) { - StackSlice slice = pop.slices.contents[i]; - ts_tree_array_delete(&slice.trees); - ts_stack_remove_version(self, slice.version); - stopped_slice.version--; + if (session.found_valid_path) { + StackSlice error_slice = pop.slices.contents[0]; + ts_stack_remove_version(self, error_slice.version); + ts_tree_array_delete(&error_slice.trees); + array_erase(&pop.slices, 0); + pop.slices.contents[0].version--; + } else { + pop.status = StackPopStoppedAtError; } - pop.slices.size = 1; - pop.slices.contents[0] = stopped_slice; } return pop; } INLINE StackIterateAction pop_pending_callback(void *payload, TSStateId state, + TreeArray *trees, size_t tree_count, bool is_done, bool is_pending) { if (tree_count >= 1) { @@ -378,8 +409,8 @@ StackPopResult ts_stack_pop_pending(Stack *self, StackVersion version) { } INLINE StackIterateAction pop_all_callback(void *payload, TSStateId state, - size_t tree_count, bool is_done, - bool is_pending) { + TreeArray *trees, size_t tree_count, + bool is_done, bool is_pending) { return is_done ? (StackIteratePop | StackIterateStop) : StackIterateNone; } @@ -401,6 +432,14 @@ void ts_stack_renumber_version(Stack *self, StackVersion v1, StackVersion v2) { array_erase(&self->heads, v1); } +StackVersion ts_stack_duplicate_version(Stack *self, StackVersion version) { + assert(version < self->heads.size); + if (!array_push(&self->heads, self->heads.contents[version])) + return STACK_VERSION_NONE; + stack_node_retain(*array_back(&self->heads)); + return self->heads.size - 1; +} + void ts_stack_merge_from(Stack *self, StackVersion start_version) { for (size_t i = start_version + 1; i < self->heads.size; i++) { StackNode *node = self->heads.contents[i]; @@ -408,11 +447,16 @@ void ts_stack_merge_from(Stack *self, StackVersion start_version) { StackNode *prior_node = self->heads.contents[j]; if (prior_node->state == node->state && prior_node->position.chars == node->position.chars) { - for (size_t k = 0; k < node->link_count; k++) { - StackLink link = node->links[k]; - stack_node_add_link(prior_node, link); + if (prior_node->error_length < node->error_length) { + ts_stack_remove_version(self, i); + } else if (node->error_length < prior_node->error_length) { + ts_stack_remove_version(self, j); + } else { + for (size_t k = 0; k < node->link_count; k++) + stack_node_add_link(prior_node, node->links[k]); + ts_stack_remove_version(self, i); } - ts_stack_remove_version(self, i--); + i--; break; } } diff --git a/src/runtime/stack.h b/src/runtime/stack.h index d5761650..fc21fef4 100644 --- a/src/runtime/stack.h +++ b/src/runtime/stack.h @@ -41,6 +41,7 @@ enum { typedef unsigned int StackIterateAction; typedef StackIterateAction (*StackIterateCallback)(void *, TSStateId state, + TreeArray *trees, size_t tree_count, bool is_done, bool is_pending); @@ -72,6 +73,10 @@ TSStateId ts_stack_top_state(const Stack *, StackVersion); */ TSLength ts_stack_top_position(const Stack *, StackVersion); +size_t ts_stack_error_length(const Stack *, StackVersion); + +size_t ts_stack_last_repaired_error_size(const Stack *, StackVersion); + /* * Push a tree and state onto the given head of the stack. This could cause * the version to merge with an existing version. @@ -100,6 +105,8 @@ void ts_stack_merge(Stack *); void ts_stack_renumber_version(Stack *, StackVersion, StackVersion); +StackVersion ts_stack_duplicate_version(Stack *, StackVersion); + /* * Remove the given version from the stack. */ diff --git a/src/runtime/tree.c b/src/runtime/tree.c index 8a51b149..32347794 100644 --- a/src/runtime/tree.c +++ b/src/runtime/tree.c @@ -107,6 +107,31 @@ recur: } } +static void ts_tree_total_tokens(const TSTree *self, size_t *result) { +recur: + if (self->child_count == 0) { + (*result)++; + } else { + for (size_t i = 1; i < self->child_count; i++) + ts_tree_total_tokens(self->children[i], result); + self = self->children[0]; + goto recur; + } +} + +size_t ts_tree_last_error_size(const TSTree *self) { + if (self->symbol == ts_builtin_sym_error) + return self->error_size; + + for (size_t i = self->child_count - 1; i + 1 > 0; i--) { + size_t result = ts_tree_last_error_size(self->children[i]); + if (result > 0) + return result; + } + + return 0; +} + void ts_tree_set_children(TSTree *self, size_t child_count, TSTree **children) { if (self->child_count > 0) ts_free(self->children); @@ -115,7 +140,7 @@ void ts_tree_set_children(TSTree *self, size_t child_count, TSTree **children) { self->child_count = child_count; self->named_child_count = 0; self->visible_child_count = 0; - size_t error_size = 0; + self->error_size = 0; for (size_t i = 0; i < child_count; i++) { TSTree *child = children[i]; @@ -127,6 +152,8 @@ void ts_tree_set_children(TSTree *self, size_t child_count, TSTree **children) { self->size = ts_length_add(self->size, ts_tree_total_size(child)); } + self->error_size += child->error_size; + if (child->visible) { self->visible_child_count++; if (child->named) @@ -139,15 +166,13 @@ void ts_tree_set_children(TSTree *self, size_t child_count, TSTree **children) { if (child->symbol == ts_builtin_sym_error) { self->fragile_left = self->fragile_right = true; self->parse_state = TS_TREE_STATE_ERROR; - } else { - error_size += child->error_size; } } - if (self->symbol == ts_builtin_sym_error) - self->error_size = self->size.chars; - else - self->error_size = error_size; + if (self->symbol == ts_builtin_sym_error) { + self->error_size = 0; + ts_tree_total_tokens(self, &self->error_size); + } if (child_count > 0) { self->lex_state = children[0]->lex_state; diff --git a/src/runtime/tree.h b/src/runtime/tree.h index 2ace03ca..d48fb1bc 100644 --- a/src/runtime/tree.h +++ b/src/runtime/tree.h @@ -67,6 +67,7 @@ void ts_tree_assign_parents(TSTree *); void ts_tree_edit(TSTree *, TSInputEdit); void ts_tree_steal_padding(TSTree *, TSTree *); char *ts_tree_string(const TSTree *, const TSLanguage *, bool include_all); +size_t ts_tree_last_error_size(const TSTree *); static inline size_t ts_tree_total_chars(const TSTree *self) { return self->padding.chars + self->size.chars;