diff --git a/src/runtime/parser.c b/src/runtime/parser.c index 2771dfa6..7ccd4cd9 100644 --- a/src/runtime/parser.c +++ b/src/runtime/parser.c @@ -483,7 +483,7 @@ error: return (Reduction){ ReduceFailed }; } -static inline const TSParseAction * ts_parser__reductions_after_sequence( +static inline const TSParseAction *ts_parser__reductions_after_sequence( TSParser *self, TSStateId start_state, const TreeArray *trees_below, size_t tree_count_below, const TreeArray *trees_above, TSSymbol lookahead_symbol, size_t *count) { @@ -515,8 +515,8 @@ static inline const TSParseAction * ts_parser__reductions_after_sequence( state = action.data.to_state; } - const TSParseAction *actions = ts_language_actions( - self->language, state, lookahead_symbol, count); + const TSParseAction *actions = + ts_language_actions(self->language, state, lookahead_symbol, count); if (actions[*count - 1].type != TSParseActionTypeReduce) (*count)--; @@ -568,17 +568,16 @@ static StackIterateAction ts_parser__error_repair_callback( continue; TSStateId state_after_repair = repair_symbol_action.data.to_state; - if (!ts_language_has_action(self->language, state_after_repair, lookahead_symbol)) + if (!ts_language_has_action(self->language, state_after_repair, + lookahead_symbol)) continue; if (count_needed_below_error != last_repair_count) { assert(count_needed_below_error > last_repair_count); last_repair_count = count_needed_below_error; - repair_reductions = - ts_parser__reductions_after_sequence(self, state, trees, - count_needed_below_error, - trees_above_error, lookahead_symbol, - &repair_reduction_count); + repair_reductions = ts_parser__reductions_after_sequence( + self, state, trees, count_needed_below_error, trees_above_error, + lookahead_symbol, &repair_reduction_count); } if (!repair_reductions) @@ -620,11 +619,11 @@ static RepairResult ts_parser__repair_error(TSParser *self, StackSlice slice, for (size_t i = 0; i < action_count; i++) if (actions[i].type == TSParseActionTypeReduce && actions[i].data.child_count > session.tree_count_above_error) - CHECK(array_push(&self->reduce_actions, - ((ReduceAction){ - .symbol = actions[i].data.symbol, - .count = actions[i].data.child_count, - }))); + CHECK(array_push( + &self->reduce_actions, + ((ReduceAction){ + .symbol = actions[i].data.symbol, .count = actions[i].data.child_count, + }))); StackPopResult pop = ts_stack_iterate( self->stack, slice.version, ts_parser__error_repair_callback, &session); @@ -680,15 +679,16 @@ static RepairResult ts_parser__repair_error(TSParser *self, StackSlice slice, unsigned my_error_cost = ts_stack_error_cost(self->stack, slice.version); unsigned my_error_depth = ts_stack_error_depth(self->stack, slice.version); for (StackVersion i = 0; i < ts_stack_version_count(self->stack); i++) { - if (i != slice.version) { - unsigned error_cost = ts_stack_error_cost(self->stack, i); - unsigned error_depth = ts_stack_error_depth(self->stack, i); - if (error_depth > my_error_depth + 1 || - (error_depth == my_error_depth + 1 && error_cost >= my_error_cost) || - (error_depth == my_error_depth && error_cost >= my_error_cost + ERROR_COST_THRESHOLD)) { - LOG_ACTION("halt_other version:%u", i); - ts_stack_halt(self->stack, i); - } + if (i == slice.version || ts_stack_is_halted(self->stack, i)) + continue; + unsigned error_cost = ts_stack_error_cost(self->stack, i); + unsigned error_depth = ts_stack_error_depth(self->stack, i); + if (error_depth > my_error_depth + 1 || + (error_depth == my_error_depth + 1 && error_cost >= my_error_cost) || + (error_depth == my_error_depth && + error_cost >= my_error_cost + ERROR_COST_THRESHOLD)) { + LOG_ACTION("halt_other version:%u", i); + ts_stack_halt(self->stack, i); } } @@ -756,10 +756,41 @@ error: return false; } +static bool ts_parser__halt_if_better_version_exists(TSParser *self, + StackVersion version, + unsigned my_error_depth, + unsigned my_error_cost) { + for (StackVersion i = 0; i < ts_stack_version_count(self->stack); i++) { + if (i == version || ts_stack_is_halted(self->stack, i)) + continue; + unsigned error_cost = ts_stack_error_cost(self->stack, i); + unsigned error_depth = ts_stack_error_depth(self->stack, i); + if (error_depth < my_error_depth - 1 || + (error_depth == my_error_depth - 1 && error_cost < my_error_cost) || + (error_depth == my_error_depth && + error_cost + ERROR_COST_THRESHOLD <= my_error_cost)) { + ts_stack_halt(self->stack, version); + return true; + } + } + + return false; +} + static bool ts_parser__handle_error(TSParser *self, StackVersion version, TSStateId state, TSTree *lookahead) { size_t previous_version_count = ts_stack_version_count(self->stack); + unsigned my_error_cost = ts_stack_error_cost(self->stack, version); + unsigned my_error_depth = ts_stack_error_depth(self->stack, version) + 1; + if (ts_parser__halt_if_better_version_exists(self, version, my_error_depth, + my_error_cost)) { + LOG_ACTION("bail_on_error"); + return true; + } + + LOG_ACTION("handle_error"); + bool has_shift_action = false; array_clear(&self->reduce_actions); for (TSSymbol symbol = 0; symbol < self->language->symbol_count; symbol++) { @@ -834,19 +865,10 @@ static bool ts_parser__recover(TSParser *self, StackVersion version, unsigned my_error_cost = ts_stack_error_cost(self->stack, version); unsigned my_error_depth = ts_stack_error_depth(self->stack, version); - for (StackVersion i = 0; i < ts_stack_version_count(self->stack); i++) { - if (i != version) { - unsigned error_cost = ts_stack_error_cost(self->stack, i); - unsigned error_depth = ts_stack_error_depth(self->stack, i); - if (error_depth < my_error_depth - 1 || - (error_depth == my_error_depth - 1 && error_cost <= my_error_cost) || - (error_depth == my_error_depth && - error_cost + ERROR_COST_THRESHOLD <= my_error_cost)) { - ts_stack_halt(self->stack, version); - LOG_ACTION("bail_on_error"); - return true; - } - } + if (ts_parser__halt_if_better_version_exists(self, version, my_error_depth, + my_error_cost)) { + LOG_ACTION("bail_on_recovery"); + return true; } LOG_ACTION("recover state:%u", state); @@ -904,8 +926,9 @@ static bool ts_parser__consume_lookahead(TSParser *self, StackVersion version, break; } - LOG_ACTION("handle_error"); CHECK(ts_parser__handle_error(self, version, state, lookahead)); + if (ts_stack_is_halted(self->stack, version)) + return true; error_repair_failed = false; break; } @@ -975,8 +998,8 @@ static bool ts_parser__consume_lookahead(TSParser *self, StackVersion version, } case TSParseActionTypeRecover: { - CHECK(ts_parser__recover(self, version, action.data.to_state, - lookahead)); + CHECK( + ts_parser__recover(self, version, action.data.to_state, lookahead)); return true; } } diff --git a/src/runtime/stack.c b/src/runtime/stack.c index 7c464e7a..3d25a895 100644 --- a/src/runtime/stack.c +++ b/src/runtime/stack.c @@ -553,7 +553,8 @@ bool ts_stack_print_dot_graph(Stack *self, const char **symbol_names, FILE *f) { fprintf(f, "shape=point margin=0 label=\"\""); else fprintf(f, "label=\"%d\"", node->state); - fprintf(f, " tooltip=\"error-count:%u, error-cost:%u\"];\n", node->error_depth, node->error_cost); + fprintf(f, " tooltip=\"error-count:%u, error-cost:%u\"];\n", + node->error_depth, node->error_cost); for (int j = 0; j < node->link_count; j++) { StackLink link = node->links[j];