diff --git a/src/runtime/parser.c b/src/runtime/parser.c index b0748090..f2c4c2ca 100644 --- a/src/runtime/parser.c +++ b/src/runtime/parser.c @@ -59,6 +59,7 @@ typedef struct { TSParser *parser; TSSymbol lookahead_symbol; TreeArray *trees_above_error; + size_t tree_count_above_error; bool found_repair; ReduceAction best_repair; TSStateId best_repair_next_state; @@ -482,92 +483,118 @@ error: return (Reduction){ ReduceFailed }; } -static bool ts_parser__is_valid_repair( - const TSParser *self, const TreeArray *trees_below, - const TreeArray *trees_above, TSStateId start_state, TSSymbol goal_symbol, - size_t goal_count_below, TSSymbol lookahead_symbol) { - const TSLanguage *language = self->language; +static inline const TSParseAction * ts_parser__reductions_after_sequence( + TSParser *self, TSStateId start_state, const TreeArray *trees_below, + size_t tree_count_below, const TreeArray *trees_above, + TSSymbol lookahead_symbol, size_t *count) { TSStateId state = start_state; - size_t count_below = 0; + size_t child_count = 0; + *count = 0; - for (size_t i = trees_below->size - 1; i + 1 > 0; i--) { - TSTree *tree = trees_below->contents[i]; + for (size_t i = 0; i < tree_count_below; i++) { + TSTree *tree = trees_below->contents[trees_below->size - 1 - i]; TSParseAction action = - ts_language_last_action(language, state, tree->symbol); + ts_language_last_action(self->language, state, tree->symbol); if (action.type != TSParseActionTypeShift) - return false; + return NULL; if (action.extra || tree->extra) continue; - + child_count++; state = action.data.to_state; - count_below++; - - if (count_below == goal_count_below) { - for (size_t j = 0; j < trees_above->size; j++) { - TSTree *tree = trees_above->contents[j]; - TSParseAction action = - ts_language_last_action(language, state, tree->symbol); - if (action.type != TSParseActionTypeShift) - return false; - if (action.extra || tree->extra) - continue; - - state = action.data.to_state; - } - - size_t action_count = 0; - const TSParseAction *actions = - ts_language_actions(language, state, lookahead_symbol, &action_count); - for (size_t k = 0; k < action_count; k++) - if (actions[k].type == TSParseActionTypeReduce && - actions[k].data.symbol == goal_symbol) - return true; - } } - return false; + for (size_t i = 0; i < trees_above->size; i++) { + TSTree *tree = trees_above->contents[i]; + TSParseAction action = + ts_language_last_action(self->language, state, tree->symbol); + if (action.type != TSParseActionTypeShift) + return NULL; + if (action.extra || tree->extra) + continue; + child_count++; + state = action.data.to_state; + } + + const TSParseAction *actions = ts_language_actions( + self->language, state, lookahead_symbol, count); + + if (actions[*count - 1].type != TSParseActionTypeReduce) + (*count)--; + + while (*count > 0 && actions[0].data.child_count < child_count) { + actions++; + (*count)--; + } + + while (*count > 0 && actions[*count - 1].data.child_count > child_count) { + (*count)--; + } + + return actions; } static StackIterateAction ts_parser__error_repair_callback( void *payload, TSStateId state, TreeArray *trees, size_t tree_count, bool is_done, bool is_pending) { + ErrorRepairSession *session = (ErrorRepairSession *)payload; TSParser *self = session->parser; - const TSLanguage *language = self->language; TSSymbol lookahead_symbol = session->lookahead_symbol; ReduceActionSet *repairs = &self->reduce_actions; TreeArray *trees_above_error = session->trees_above_error; + size_t tree_count_above_error = session->tree_count_above_error; + StackIterateAction result = StackIterateNone; + size_t last_repair_count = 0; + size_t repair_reduction_count = -1; + const TSParseAction *repair_reductions = NULL; + for (size_t i = 0; i < repairs->size; i++) { ReduceAction *repair = &repairs->contents[i]; - if (repair->count > tree_count) - continue; + size_t count_needed_below_error = repair->count - tree_count_above_error; + if (count_needed_below_error > tree_count) + break; - size_t skip_count = tree_count - repair->count; + size_t skip_count = tree_count - count_needed_below_error; if (session->found_repair && skip_count >= session->best_repair_skip_count) { array_erase(repairs, i--); continue; } TSParseAction repair_symbol_action = - ts_language_last_action(language, state, repair->symbol); + ts_language_last_action(self->language, state, repair->symbol); if (repair_symbol_action.type != TSParseActionTypeShift) continue; TSStateId state_after_repair = repair_symbol_action.data.to_state; - if (!ts_language_has_action(language, state_after_repair, lookahead_symbol)) + if (!ts_language_has_action(self->language, state_after_repair, lookahead_symbol)) continue; - if (ts_parser__is_valid_repair(self, trees, trees_above_error, state, - repair->symbol, repair->count, - lookahead_symbol)) { - result |= StackIteratePop; - session->found_repair = true; - session->best_repair = *repair; - session->best_repair_skip_count = skip_count; - session->best_repair_next_state = state_after_repair; - array_erase(repairs, i--); + if (count_needed_below_error != last_repair_count) { + assert(count_needed_below_error > last_repair_count); + last_repair_count = count_needed_below_error; + repair_reductions = + ts_parser__reductions_after_sequence(self, state, trees, + count_needed_below_error, + trees_above_error, lookahead_symbol, + &repair_reduction_count); + } + + if (!repair_reductions) + continue; + + for (size_t j = 0; j < repair_reduction_count; j++) { + const TSParseAction repair_reduction = repair_reductions[j]; + if (repair_reduction.data.symbol == repair->symbol) { + result |= StackIteratePop; + session->found_repair = true; + session->best_repair = *repair; + session->best_repair_skip_count = skip_count; + session->best_repair_next_state = state_after_repair; + array_erase(repairs, i--); + break; + } } } @@ -581,22 +608,22 @@ static RepairResult ts_parser__repair_error(TSParser *self, StackSlice slice, TSTree *lookahead, const TSParseAction *actions, size_t action_count) { - size_t count_above_error = ts_tree_array_essential_count(&slice.trees); ErrorRepairSession session = { .parser = self, .lookahead_symbol = lookahead->symbol, .found_repair = false, .trees_above_error = &slice.trees, + .tree_count_above_error = ts_tree_array_essential_count(&slice.trees), }; array_clear(&self->reduce_actions); for (size_t i = 0; i < action_count; i++) if (actions[i].type == TSParseActionTypeReduce && - actions[i].data.child_count > count_above_error) + actions[i].data.child_count > session.tree_count_above_error) CHECK(array_push(&self->reduce_actions, ((ReduceAction){ .symbol = actions[i].data.symbol, - .count = actions[i].data.child_count - count_above_error, + .count = actions[i].data.child_count, }))); StackPopResult pop = ts_stack_iterate( @@ -613,6 +640,7 @@ static RepairResult ts_parser__repair_error(TSParser *self, StackSlice slice, ReduceAction repair = session.best_repair; TSStateId next_state = session.best_repair_next_state; size_t skip_count = session.best_repair_skip_count; + size_t count_below = repair.count - session.tree_count_above_error; TSSymbol symbol = repair.symbol; StackSlice new_slice = array_pop(&pop.slices); @@ -628,12 +656,12 @@ static RepairResult ts_parser__repair_error(TSParser *self, StackSlice slice, TreeArray skipped_children = array_new(); CHECK(array_grow(&skipped_children, skip_count)); - for (size_t i = repair.count; i < children.size; i++) + for (size_t i = count_below; i < children.size; i++) array_push(&skipped_children, children.contents[i]); TSTree *error = ts_tree_make_error_node(&skipped_children); CHECK(error); - children.size = repair.count; + children.size = count_below; array_push(&children, error); for (size_t i = 0; i < slice.trees.size; i++) @@ -647,8 +675,7 @@ static RepairResult ts_parser__repair_error(TSParser *self, StackSlice slice, CHECK(ts_parser__push(self, slice.version, parent, next_state)); LOG_ACTION("repair_found sym:%s, child_count:%lu, skipped:%lu", - SYM_NAME(symbol), repair.count + count_above_error, - parent->error_size); + SYM_NAME(symbol), repair.count, parent->error_size); unsigned my_error_cost = ts_stack_error_cost(self->stack, slice.version); unsigned my_error_depth = ts_stack_error_depth(self->stack, slice.version); @@ -657,7 +684,8 @@ static RepairResult ts_parser__repair_error(TSParser *self, StackSlice slice, unsigned error_cost = ts_stack_error_cost(self->stack, i); unsigned error_depth = ts_stack_error_depth(self->stack, i); if (error_depth > my_error_depth + 1 || - (error_depth == my_error_depth + 1 && error_cost >= my_error_cost)) { + (error_depth == my_error_depth + 1 && error_cost >= my_error_cost) || + (error_depth == my_error_depth && error_cost >= my_error_cost + ERROR_COST_THRESHOLD)) { LOG_ACTION("halt_other version:%u", i); ts_stack_halt(self->stack, i); }