From 0faae52132526a41a6e129ae80aa75b04a2f6e6b Mon Sep 17 00:00:00 2001 From: Max Brunsfeld Date: Wed, 31 Aug 2016 10:51:59 -0700 Subject: [PATCH] Fix some inconsistencies in error cost calculation Signed-off-by: Nathan Sobo --- .../error_corpus/javascript_errors.txt | 13 +-- src/runtime/error_costs.h | 43 +++++++++ src/runtime/parser.c | 92 +++++++++---------- src/runtime/stack.c | 48 +++++++--- src/runtime/stack.h | 5 +- src/runtime/tree.c | 11 ++- src/runtime/tree.h | 2 +- 7 files changed, 135 insertions(+), 79 deletions(-) create mode 100644 src/runtime/error_costs.h diff --git a/spec/fixtures/error_corpus/javascript_errors.txt b/spec/fixtures/error_corpus/javascript_errors.txt index 973db0a8..d9e88a3e 100644 --- a/spec/fixtures/error_corpus/javascript_errors.txt +++ b/spec/fixtures/error_corpus/javascript_errors.txt @@ -59,12 +59,13 @@ if ({a: 'b'} {c: 'd'}) { (ERROR (object (pair (identifier) (string)))) (object (pair (identifier) (string))) (statement_block - (expression_statement (function - (formal_parameters (identifier)) - (statement_block (expression_statement (identifier)))) - (ERROR (function - (formal_parameters (identifier)) - (statement_block (expression_statement (identifier))))))))) + (expression_statement + (function + (formal_parameters (identifier)) + (statement_block (expression_statement (identifier)))) + (ERROR (function + (formal_parameters (identifier)) + (statement_block (expression_statement (identifier))))))))) =================================================== one invalid token at the end of the file diff --git a/src/runtime/error_costs.h b/src/runtime/error_costs.h new file mode 100644 index 00000000..b1640a30 --- /dev/null +++ b/src/runtime/error_costs.h @@ -0,0 +1,43 @@ +#ifndef RUNTIME_ERROR_COSTS_H_ +#define RUNTIME_ERROR_COSTS_H_ + +#define ERROR_COST_PER_SKIPPED_TREE 10 +#define ERROR_COST_PER_SKIPPED_LINE 3 +#define ERROR_COST_PER_SKIPPED_CHAR 0 + +typedef struct { + unsigned cost; + unsigned count; + unsigned depth; +} ErrorStatus; + +static inline int error_status_compare(ErrorStatus a, ErrorStatus b) { + static unsigned ERROR_COST_THRESHOLD = 3 * ERROR_COST_PER_SKIPPED_TREE; + static unsigned ERROR_COUNT_THRESHOLD = 1; + + // TODO remove + a.cost += ERROR_COST_PER_SKIPPED_TREE * a.count; + b.cost += ERROR_COST_PER_SKIPPED_TREE * b.count; + + if ((a.count + ERROR_COUNT_THRESHOLD < b.count) || + (a.count < b.count && a.cost <= b.cost)) { + return -1; + } + + if ((b.count + ERROR_COUNT_THRESHOLD < a.count) || + (b.count < a.count && b.cost <= a.cost)) { + return 1; + } + + if (a.cost + ERROR_COST_THRESHOLD < b.cost) { + return -1; + } + + if (b.cost + ERROR_COST_THRESHOLD < a.cost) { + return 1; + } + + return 0; +} + +#endif diff --git a/src/runtime/parser.c b/src/runtime/parser.c index 0942d38b..c6966bdf 100644 --- a/src/runtime/parser.c +++ b/src/runtime/parser.c @@ -12,6 +12,7 @@ #include "runtime/language.h" #include "runtime/alloc.h" #include "runtime/reduce_action.h" +#include "runtime/error_costs.h" #define LOG(...) \ if (self->lexer.debugger.debug_fn) { \ @@ -47,8 +48,6 @@ goto error; \ } -static const unsigned ERROR_COST_THRESHOLD = 3; - typedef struct { Parser *parser; TSSymbol lookahead_symbol; @@ -208,25 +207,35 @@ static bool parser__condense_stack(Parser *self) { bool result = false; for (StackVersion i = 0; i < ts_stack_version_count(self->stack); i++) { if (ts_stack_is_halted(self->stack, i)) { - result = true; ts_stack_remove_version(self->stack, i); + result = true; i--; continue; } - bool did_merge = false; - for (size_t j = 0; j < i; j++) { + ErrorStatus error_status = ts_stack_error_status(self->stack, i); + for (StackVersion j = 0; j < i; j++) { if (ts_stack_merge(self->stack, j, i)) { - did_merge = true; + result = true; + i--; break; } - } - if (did_merge) { - result = true; - i--; - continue; + switch (error_status_compare(error_status, + ts_stack_error_status(self->stack, j))) { + case -1: + ts_stack_remove_version(self->stack, j); + result = true; + i--; + j--; + break; + case 1: + ts_stack_remove_version(self->stack, i); + result = true; + i--; + break; + } } } return result; @@ -365,34 +374,23 @@ static bool parser__select_tree(Parser *self, TSTree *left, TSTree *right) { } static bool parser__better_version_exists(Parser *self, StackVersion version, - unsigned my_error_count, - unsigned my_error_cost) { - if (self->finished_tree && self->finished_tree->error_cost <= my_error_cost) + ErrorStatus my_error_status) { + if (self->finished_tree && + self->finished_tree->error_cost <= my_error_status.cost) return true; for (StackVersion i = 0, n = ts_stack_version_count(self->stack); i < n; i++) { if (i == version || ts_stack_is_halted(self->stack, i)) continue; - unsigned error_cost = ts_stack_error_cost(self->stack, i); - unsigned error_count = ts_stack_error_count(self->stack, i); - - if ((error_count > my_error_count + 1) || - (error_count > my_error_count && error_cost >= my_error_cost) || - (my_error_count == 0 && error_cost > my_error_cost) || - (error_count == my_error_count && - error_cost >= my_error_cost + ERROR_COST_THRESHOLD)) { - LOG("halt_other version:%u", i); - ts_stack_halt(self->stack, i); - continue; - } - - if ((my_error_count > error_count + 1) || - (my_error_count > error_count && my_error_cost >= error_cost) || - (error_count == 0 && my_error_cost > error_cost) || - (my_error_count == error_count && - my_error_cost >= error_cost + ERROR_COST_THRESHOLD)) { - return true; + switch (error_status_compare(my_error_status, + ts_stack_error_status(self->stack, i))) { + case -1: + LOG("halt_other version:%u", i); + ts_stack_halt(self->stack, i); + break; + case 1: + return true; } } @@ -516,11 +514,8 @@ static Reduction parser__reduce(Parser *self, StackVersion version, if (action->type == TSParseActionTypeRecover && child_count > 1 && allow_skipping) { - unsigned error_count = ts_stack_error_count(self->stack, slice.version); - unsigned error_cost = - ts_stack_error_cost(self->stack, slice.version) + 1; - if (!parser__better_version_exists(self, slice.version, error_count, - error_cost)) { + ErrorStatus error_status = ts_stack_error_status(self->stack, version); + if (!parser__better_version_exists(self, version, error_status)) { StackVersion other_version = ts_stack_duplicate_version(self->stack, slice.version); CHECK(other_version != STACK_VERSION_NONE); @@ -753,15 +748,13 @@ static RepairResult parser__repair_error(Parser *self, StackSlice slice, CHECK(parent); CHECK(parser__push(self, slice.version, parent, next_state)); - unsigned error_cost = ts_stack_error_cost(self->stack, slice.version); - unsigned error_count = ts_stack_error_count(self->stack, slice.version); - if (parser__better_version_exists(self, slice.version, error_count, - error_cost)) { + ErrorStatus error_status = ts_stack_error_status(self->stack, slice.version); + if (parser__better_version_exists(self, slice.version, error_status)) { LOG("no_better_repair_found"); ts_stack_halt(self->stack, slice.version); return RepairNoneFound; } else { - LOG("repair_found sym:%s, child_count:%lu, skipped:%lu", SYM_NAME(symbol), + LOG("repair_found sym:%s, child_count:%lu, cost:%u", SYM_NAME(symbol), repair.count, parent->error_cost); return RepairSucceeded; } @@ -954,9 +947,9 @@ error: static bool parser__handle_error(Parser *self, StackVersion version, TSSymbol lookahead_symbol) { - unsigned error_cost = ts_stack_error_cost(self->stack, version); - unsigned error_count = ts_stack_error_count(self->stack, version) + 1; - if (parser__better_version_exists(self, version, error_count, error_cost)) { + ErrorStatus error_status = ts_stack_error_status(self->stack, version); + error_status.count++; + if (parser__better_version_exists(self, version, error_status)) { ts_stack_halt(self->stack, version); LOG("bail_on_error"); return true; @@ -1005,9 +998,8 @@ static bool parser__recover(Parser *self, StackVersion version, TSStateId state, return parser__accept(self, version); } - unsigned error_cost = ts_stack_error_cost(self->stack, version); - unsigned error_count = ts_stack_error_count(self->stack, version); - if (parser__better_version_exists(self, version, error_count, error_cost)) { + ErrorStatus error_status = ts_stack_error_status(self->stack, version); + if (parser__better_version_exists(self, version, error_status)) { ts_stack_halt(self->stack, version); LOG("bail_on_recovery"); return true; @@ -1142,7 +1134,7 @@ static bool parser__advance(Parser *self, StackVersion version, } case TSParseActionTypeAccept: { - if (ts_stack_error_count(self->stack, version) > 0) + if (ts_stack_error_status(self->stack, version).count > 0) continue; LOG("accept"); diff --git a/src/runtime/stack.c b/src/runtime/stack.c index edd5136f..7e9dc0e9 100644 --- a/src/runtime/stack.c +++ b/src/runtime/stack.c @@ -28,6 +28,7 @@ struct StackNode { short unsigned int ref_count; unsigned error_cost; unsigned error_count; + unsigned error_depth; }; typedef struct { @@ -97,8 +98,6 @@ static StackNode *stack_node_new(StackNode *next, TSTree *tree, bool is_pending, .links = {}, .state = state, .position = position, - .error_count = 0, - .error_cost = 0, }; if (next) { @@ -107,22 +106,33 @@ static StackNode *stack_node_new(StackNode *next, TSTree *tree, bool is_pending, node->link_count = 1; node->links[0] = (StackLink){ next, tree, is_pending }; - node->error_cost = next->error_cost; node->error_count = next->error_count; + node->error_cost = next->error_cost; + node->error_depth = next->error_depth; if (tree) { ts_tree_retain(tree); + node->error_cost += tree->error_cost; + if (state == TS_STATE_ERROR) { if (!tree->extra) { - node->error_cost += 1 + tree->padding.rows + tree->size.rows; + node->error_cost += ERROR_COST_PER_SKIPPED_TREE + + ERROR_COST_PER_SKIPPED_CHAR * + (tree->padding.chars + tree->size.chars) + + ERROR_COST_PER_SKIPPED_LINE * + (tree->padding.rows + tree->size.rows); } } else { - node->error_cost += tree->error_cost; + node->error_depth++; } } else { - node->error_cost++; node->error_count++; + node->error_depth = 0; } + } else { + node->error_count = 0; + node->error_cost = 0; + node->error_depth = 0; } return node; @@ -339,12 +349,18 @@ TSLength ts_stack_top_position(const Stack *self, StackVersion version) { return array_get(&self->heads, version)->node->position; } -unsigned ts_stack_error_cost(const Stack *self, StackVersion version) { - return array_get(&self->heads, version)->node->error_cost; +ErrorStatus ts_stack_error_status(const Stack *self, StackVersion version) { + StackNode *node = array_get(&self->heads, version)->node; + return (ErrorStatus){ + .cost = node->error_cost, + .count = node->error_count, + .depth = node->error_depth, + }; } unsigned ts_stack_error_count(const Stack *self, StackVersion version) { - return array_get(&self->heads, version)->node->error_count; + StackNode *node = array_get(&self->heads, version)->node; + return node->error_count; } bool ts_stack_push(Stack *self, StackVersion version, TSTree *tree, @@ -551,11 +567,12 @@ bool ts_stack_print_dot_graph(Stack *self, const char **symbol_names, FILE *f) { fprintf(f, "shape=point margin=0 label=\"\""); else fprintf(f, "label=\"%d\"", node->state); - fprintf( - f, - " tooltip=\"position: %lu,%lu\nerror-count: %u\nerror-cost: %u\"];\n", - node->position.rows, node->position.columns, node->error_count, - node->error_cost); + + fprintf(f, + " tooltip=\"position: %lu,%lu\nerror_count: %u\nerror_cost: %u\n" + "error_depth: %u\"];\n", + node->position.rows, node->position.columns, node->error_count, + node->error_cost, node->error_depth); for (int j = 0; j < node->link_count; j++) { StackLink link = node->links[j]; @@ -581,7 +598,8 @@ bool ts_stack_print_dot_graph(Stack *self, const char **symbol_names, FILE *f) { } if (!link.tree->named) fprintf(f, "'"); - fprintf(f, "\""); + fprintf(f, "\" labeltooltip=\"error_cost: %u\"", + link.tree->error_cost); } fprintf(f, "];\n"); diff --git a/src/runtime/stack.h b/src/runtime/stack.h index 9f9dbd1c..4eba9f3d 100644 --- a/src/runtime/stack.h +++ b/src/runtime/stack.h @@ -8,6 +8,7 @@ extern "C" { #include "tree_sitter/parser.h" #include "runtime/array.h" #include "runtime/tree.h" +#include "runtime/error_costs.h" #include typedef struct Stack Stack; @@ -95,9 +96,7 @@ StackPopResult ts_stack_pop_pending(Stack *, StackVersion); StackPopResult ts_stack_pop_all(Stack *, StackVersion); -unsigned ts_stack_error_count(const Stack *, StackVersion); - -unsigned ts_stack_error_cost(const Stack *, StackVersion); +ErrorStatus ts_stack_error_status(const Stack *, StackVersion); bool ts_stack_merge(Stack *, StackVersion, StackVersion); diff --git a/src/runtime/tree.c b/src/runtime/tree.c index cd24cf63..76e0b6a8 100644 --- a/src/runtime/tree.c +++ b/src/runtime/tree.c @@ -7,6 +7,7 @@ #include "runtime/alloc.h" #include "runtime/tree.h" #include "runtime/length.h" +#include "runtime/error_costs.h" TSStateId TS_TREE_STATE_NONE = USHRT_MAX; @@ -150,10 +151,11 @@ void ts_tree_set_children(TSTree *self, size_t child_count, TSTree **children) { } if (self->symbol == ts_builtin_sym_error) { - self->error_cost = self->size.rows; + self->error_cost += ERROR_COST_PER_SKIPPED_CHAR * self->size.chars + + ERROR_COST_PER_SKIPPED_LINE * self->size.rows; for (size_t i = 0; i < child_count; i++) if (!self->children[i]->extra) - self->error_cost++; + self->error_cost += ERROR_COST_PER_SKIPPED_TREE; } if (child_count > 0) { @@ -424,8 +426,9 @@ void ts_tree__print_dot_graph(const TSTree *self, size_t offset, if (self->extra) fprintf(f, ", fontcolor=gray"); - fprintf(f, ", tooltip=\"range:%lu - %lu\nstate:%d\"]\n", offset, - offset + ts_tree_total_chars(self), self->parse_state); + fprintf(f, ", tooltip=\"range:%lu - %lu\nstate:%d\nerror-cost:%u\"]\n", + offset, offset + ts_tree_total_chars(self), self->parse_state, + self->error_cost); for (size_t i = 0; i < self->child_count; i++) { const TSTree *child = self->children[i]; ts_tree__print_dot_graph(child, offset, language, f); diff --git a/src/runtime/tree.h b/src/runtime/tree.h index 2771402b..d147b4ed 100644 --- a/src/runtime/tree.h +++ b/src/runtime/tree.h @@ -33,7 +33,7 @@ typedef struct TSTree { TSSymbol symbol; TSStateId parse_state; - size_t error_cost; + unsigned error_cost; struct { TSSymbol symbol;