Fix some inconsistencies in error cost calculation
Signed-off-by: Nathan Sobo <nathan@github.com>
This commit is contained in:
parent
883a7c8266
commit
0faae52132
7 changed files with 135 additions and 79 deletions
13
spec/fixtures/error_corpus/javascript_errors.txt
vendored
13
spec/fixtures/error_corpus/javascript_errors.txt
vendored
|
|
@ -59,12 +59,13 @@ if ({a: 'b'} {c: 'd'}) {
|
|||
(ERROR (object (pair (identifier) (string))))
|
||||
(object (pair (identifier) (string)))
|
||||
(statement_block
|
||||
(expression_statement (function
|
||||
(formal_parameters (identifier))
|
||||
(statement_block (expression_statement (identifier))))
|
||||
(ERROR (function
|
||||
(formal_parameters (identifier))
|
||||
(statement_block (expression_statement (identifier)))))))))
|
||||
(expression_statement
|
||||
(function
|
||||
(formal_parameters (identifier))
|
||||
(statement_block (expression_statement (identifier))))
|
||||
(ERROR (function
|
||||
(formal_parameters (identifier))
|
||||
(statement_block (expression_statement (identifier)))))))))
|
||||
|
||||
===================================================
|
||||
one invalid token at the end of the file
|
||||
|
|
|
|||
43
src/runtime/error_costs.h
Normal file
43
src/runtime/error_costs.h
Normal file
|
|
@ -0,0 +1,43 @@
|
|||
#ifndef RUNTIME_ERROR_COSTS_H_
|
||||
#define RUNTIME_ERROR_COSTS_H_
|
||||
|
||||
/* Cost added for each tree that is skipped while recovering from an error. */
#define ERROR_COST_PER_SKIPPED_TREE 10
/* Cost added for each line (row) spanned by skipped content. */
#define ERROR_COST_PER_SKIPPED_LINE 3
/* Cost added for each skipped character; zero means characters are free. */
#define ERROR_COST_PER_SKIPPED_CHAR 0

/*
 * Snapshot of the error bookkeeping used to rank one parse against another.
 *
 *   cost  - accumulated error cost
 *   count - number of errors
 *   depth - error depth; carried along but not consulted by
 *           error_status_compare()
 */
typedef struct {
  unsigned cost;
  unsigned count;
  unsigned depth;
} ErrorStatus;

/*
 * Three-way comparison of two error statuses.
 *
 * Returns -1 when `a` is decisively better than `b`, 1 when `b` is decisively
 * better than `a`, and 0 when neither one clearly wins.
 *
 * Error counts are examined first:
 *   - a status whose count is lower by more than ERROR_COUNT_THRESHOLD wins
 *     outright, whatever the costs are;
 *   - a status whose count is lower by any amount wins if its cost is not
 *     higher than the other's.
 * Otherwise the costs decide, and only when they differ by more than
 * ERROR_COST_THRESHOLD.
 *
 * Both arguments are passed by value, so the adjustment made to the costs
 * below never reaches the caller.
 */
static inline int error_status_compare(ErrorStatus a, ErrorStatus b) {
  /* These are never written, so keep them const rather than mutable statics
     living in every translation unit that includes this header. */
  static const unsigned ERROR_COST_THRESHOLD = 3 * ERROR_COST_PER_SKIPPED_TREE;
  static const unsigned ERROR_COUNT_THRESHOLD = 1;

  // TODO remove
  // Folds a per-error penalty into each cost before comparing.
  a.cost += ERROR_COST_PER_SKIPPED_TREE * a.count;
  b.cost += ERROR_COST_PER_SKIPPED_TREE * b.count;

  /* `a` has fewer errors: by a wide margin, or without costing more. */
  if ((a.count + ERROR_COUNT_THRESHOLD < b.count) ||
      (a.count < b.count && a.cost <= b.cost)) {
    return -1;
  }

  /* Mirror image of the test above, in favour of `b`. */
  if ((b.count + ERROR_COUNT_THRESHOLD < a.count) ||
      (b.count < a.count && b.cost <= a.cost)) {
    return 1;
  }

  /* Counts did not settle it; a cost gap above the threshold does. */
  if (a.cost + ERROR_COST_THRESHOLD < b.cost) {
    return -1;
  }

  if (b.cost + ERROR_COST_THRESHOLD < a.cost) {
    return 1;
  }

  return 0;
}
|
||||
|
||||
#endif
|
||||
|
|
@ -12,6 +12,7 @@
|
|||
#include "runtime/language.h"
|
||||
#include "runtime/alloc.h"
|
||||
#include "runtime/reduce_action.h"
|
||||
#include "runtime/error_costs.h"
|
||||
|
||||
#define LOG(...) \
|
||||
if (self->lexer.debugger.debug_fn) { \
|
||||
|
|
@ -47,8 +48,6 @@
|
|||
goto error; \
|
||||
}
|
||||
|
||||
static const unsigned ERROR_COST_THRESHOLD = 3;
|
||||
|
||||
typedef struct {
|
||||
Parser *parser;
|
||||
TSSymbol lookahead_symbol;
|
||||
|
|
@ -208,25 +207,35 @@ static bool parser__condense_stack(Parser *self) {
|
|||
bool result = false;
|
||||
for (StackVersion i = 0; i < ts_stack_version_count(self->stack); i++) {
|
||||
if (ts_stack_is_halted(self->stack, i)) {
|
||||
result = true;
|
||||
ts_stack_remove_version(self->stack, i);
|
||||
result = true;
|
||||
i--;
|
||||
continue;
|
||||
}
|
||||
|
||||
bool did_merge = false;
|
||||
for (size_t j = 0; j < i; j++) {
|
||||
ErrorStatus error_status = ts_stack_error_status(self->stack, i);
|
||||
|
||||
for (StackVersion j = 0; j < i; j++) {
|
||||
if (ts_stack_merge(self->stack, j, i)) {
|
||||
did_merge = true;
|
||||
result = true;
|
||||
i--;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (did_merge) {
|
||||
result = true;
|
||||
i--;
|
||||
continue;
|
||||
switch (error_status_compare(error_status,
|
||||
ts_stack_error_status(self->stack, j))) {
|
||||
case -1:
|
||||
ts_stack_remove_version(self->stack, j);
|
||||
result = true;
|
||||
i--;
|
||||
j--;
|
||||
break;
|
||||
case 1:
|
||||
ts_stack_remove_version(self->stack, i);
|
||||
result = true;
|
||||
i--;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
return result;
|
||||
|
|
@ -365,34 +374,23 @@ static bool parser__select_tree(Parser *self, TSTree *left, TSTree *right) {
|
|||
}
|
||||
|
||||
static bool parser__better_version_exists(Parser *self, StackVersion version,
|
||||
unsigned my_error_count,
|
||||
unsigned my_error_cost) {
|
||||
if (self->finished_tree && self->finished_tree->error_cost <= my_error_cost)
|
||||
ErrorStatus my_error_status) {
|
||||
if (self->finished_tree &&
|
||||
self->finished_tree->error_cost <= my_error_status.cost)
|
||||
return true;
|
||||
|
||||
for (StackVersion i = 0, n = ts_stack_version_count(self->stack); i < n; i++) {
|
||||
if (i == version || ts_stack_is_halted(self->stack, i))
|
||||
continue;
|
||||
|
||||
unsigned error_cost = ts_stack_error_cost(self->stack, i);
|
||||
unsigned error_count = ts_stack_error_count(self->stack, i);
|
||||
|
||||
if ((error_count > my_error_count + 1) ||
|
||||
(error_count > my_error_count && error_cost >= my_error_cost) ||
|
||||
(my_error_count == 0 && error_cost > my_error_cost) ||
|
||||
(error_count == my_error_count &&
|
||||
error_cost >= my_error_cost + ERROR_COST_THRESHOLD)) {
|
||||
LOG("halt_other version:%u", i);
|
||||
ts_stack_halt(self->stack, i);
|
||||
continue;
|
||||
}
|
||||
|
||||
if ((my_error_count > error_count + 1) ||
|
||||
(my_error_count > error_count && my_error_cost >= error_cost) ||
|
||||
(error_count == 0 && my_error_cost > error_cost) ||
|
||||
(my_error_count == error_count &&
|
||||
my_error_cost >= error_cost + ERROR_COST_THRESHOLD)) {
|
||||
return true;
|
||||
switch (error_status_compare(my_error_status,
|
||||
ts_stack_error_status(self->stack, i))) {
|
||||
case -1:
|
||||
LOG("halt_other version:%u", i);
|
||||
ts_stack_halt(self->stack, i);
|
||||
break;
|
||||
case 1:
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -516,11 +514,8 @@ static Reduction parser__reduce(Parser *self, StackVersion version,
|
|||
|
||||
if (action->type == TSParseActionTypeRecover && child_count > 1 &&
|
||||
allow_skipping) {
|
||||
unsigned error_count = ts_stack_error_count(self->stack, slice.version);
|
||||
unsigned error_cost =
|
||||
ts_stack_error_cost(self->stack, slice.version) + 1;
|
||||
if (!parser__better_version_exists(self, slice.version, error_count,
|
||||
error_cost)) {
|
||||
ErrorStatus error_status = ts_stack_error_status(self->stack, version);
|
||||
if (!parser__better_version_exists(self, version, error_status)) {
|
||||
StackVersion other_version =
|
||||
ts_stack_duplicate_version(self->stack, slice.version);
|
||||
CHECK(other_version != STACK_VERSION_NONE);
|
||||
|
|
@ -753,15 +748,13 @@ static RepairResult parser__repair_error(Parser *self, StackSlice slice,
|
|||
CHECK(parent);
|
||||
CHECK(parser__push(self, slice.version, parent, next_state));
|
||||
|
||||
unsigned error_cost = ts_stack_error_cost(self->stack, slice.version);
|
||||
unsigned error_count = ts_stack_error_count(self->stack, slice.version);
|
||||
if (parser__better_version_exists(self, slice.version, error_count,
|
||||
error_cost)) {
|
||||
ErrorStatus error_status = ts_stack_error_status(self->stack, slice.version);
|
||||
if (parser__better_version_exists(self, slice.version, error_status)) {
|
||||
LOG("no_better_repair_found");
|
||||
ts_stack_halt(self->stack, slice.version);
|
||||
return RepairNoneFound;
|
||||
} else {
|
||||
LOG("repair_found sym:%s, child_count:%lu, skipped:%lu", SYM_NAME(symbol),
|
||||
LOG("repair_found sym:%s, child_count:%lu, cost:%u", SYM_NAME(symbol),
|
||||
repair.count, parent->error_cost);
|
||||
return RepairSucceeded;
|
||||
}
|
||||
|
|
@ -954,9 +947,9 @@ error:
|
|||
|
||||
static bool parser__handle_error(Parser *self, StackVersion version,
|
||||
TSSymbol lookahead_symbol) {
|
||||
unsigned error_cost = ts_stack_error_cost(self->stack, version);
|
||||
unsigned error_count = ts_stack_error_count(self->stack, version) + 1;
|
||||
if (parser__better_version_exists(self, version, error_count, error_cost)) {
|
||||
ErrorStatus error_status = ts_stack_error_status(self->stack, version);
|
||||
error_status.count++;
|
||||
if (parser__better_version_exists(self, version, error_status)) {
|
||||
ts_stack_halt(self->stack, version);
|
||||
LOG("bail_on_error");
|
||||
return true;
|
||||
|
|
@ -1005,9 +998,8 @@ static bool parser__recover(Parser *self, StackVersion version, TSStateId state,
|
|||
return parser__accept(self, version);
|
||||
}
|
||||
|
||||
unsigned error_cost = ts_stack_error_cost(self->stack, version);
|
||||
unsigned error_count = ts_stack_error_count(self->stack, version);
|
||||
if (parser__better_version_exists(self, version, error_count, error_cost)) {
|
||||
ErrorStatus error_status = ts_stack_error_status(self->stack, version);
|
||||
if (parser__better_version_exists(self, version, error_status)) {
|
||||
ts_stack_halt(self->stack, version);
|
||||
LOG("bail_on_recovery");
|
||||
return true;
|
||||
|
|
@ -1142,7 +1134,7 @@ static bool parser__advance(Parser *self, StackVersion version,
|
|||
}
|
||||
|
||||
case TSParseActionTypeAccept: {
|
||||
if (ts_stack_error_count(self->stack, version) > 0)
|
||||
if (ts_stack_error_status(self->stack, version).count > 0)
|
||||
continue;
|
||||
|
||||
LOG("accept");
|
||||
|
|
|
|||
|
|
@ -28,6 +28,7 @@ struct StackNode {
|
|||
short unsigned int ref_count;
|
||||
unsigned error_cost;
|
||||
unsigned error_count;
|
||||
unsigned error_depth;
|
||||
};
|
||||
|
||||
typedef struct {
|
||||
|
|
@ -97,8 +98,6 @@ static StackNode *stack_node_new(StackNode *next, TSTree *tree, bool is_pending,
|
|||
.links = {},
|
||||
.state = state,
|
||||
.position = position,
|
||||
.error_count = 0,
|
||||
.error_cost = 0,
|
||||
};
|
||||
|
||||
if (next) {
|
||||
|
|
@ -107,22 +106,33 @@ static StackNode *stack_node_new(StackNode *next, TSTree *tree, bool is_pending,
|
|||
node->link_count = 1;
|
||||
node->links[0] = (StackLink){ next, tree, is_pending };
|
||||
|
||||
node->error_cost = next->error_cost;
|
||||
node->error_count = next->error_count;
|
||||
node->error_cost = next->error_cost;
|
||||
node->error_depth = next->error_depth;
|
||||
|
||||
if (tree) {
|
||||
ts_tree_retain(tree);
|
||||
node->error_cost += tree->error_cost;
|
||||
|
||||
if (state == TS_STATE_ERROR) {
|
||||
if (!tree->extra) {
|
||||
node->error_cost += 1 + tree->padding.rows + tree->size.rows;
|
||||
node->error_cost += ERROR_COST_PER_SKIPPED_TREE +
|
||||
ERROR_COST_PER_SKIPPED_CHAR *
|
||||
(tree->padding.chars + tree->size.chars) +
|
||||
ERROR_COST_PER_SKIPPED_LINE *
|
||||
(tree->padding.rows + tree->size.rows);
|
||||
}
|
||||
} else {
|
||||
node->error_cost += tree->error_cost;
|
||||
node->error_depth++;
|
||||
}
|
||||
} else {
|
||||
node->error_cost++;
|
||||
node->error_count++;
|
||||
node->error_depth = 0;
|
||||
}
|
||||
} else {
|
||||
node->error_count = 0;
|
||||
node->error_cost = 0;
|
||||
node->error_depth = 0;
|
||||
}
|
||||
|
||||
return node;
|
||||
|
|
@ -339,12 +349,18 @@ TSLength ts_stack_top_position(const Stack *self, StackVersion version) {
|
|||
return array_get(&self->heads, version)->node->position;
|
||||
}
|
||||
|
||||
unsigned ts_stack_error_cost(const Stack *self, StackVersion version) {
|
||||
return array_get(&self->heads, version)->node->error_cost;
|
||||
ErrorStatus ts_stack_error_status(const Stack *self, StackVersion version) {
|
||||
StackNode *node = array_get(&self->heads, version)->node;
|
||||
return (ErrorStatus){
|
||||
.cost = node->error_cost,
|
||||
.count = node->error_count,
|
||||
.depth = node->error_depth,
|
||||
};
|
||||
}
|
||||
|
||||
unsigned ts_stack_error_count(const Stack *self, StackVersion version) {
|
||||
return array_get(&self->heads, version)->node->error_count;
|
||||
StackNode *node = array_get(&self->heads, version)->node;
|
||||
return node->error_count;
|
||||
}
|
||||
|
||||
bool ts_stack_push(Stack *self, StackVersion version, TSTree *tree,
|
||||
|
|
@ -551,11 +567,12 @@ bool ts_stack_print_dot_graph(Stack *self, const char **symbol_names, FILE *f) {
|
|||
fprintf(f, "shape=point margin=0 label=\"\"");
|
||||
else
|
||||
fprintf(f, "label=\"%d\"", node->state);
|
||||
fprintf(
|
||||
f,
|
||||
" tooltip=\"position: %lu,%lu\nerror-count: %u\nerror-cost: %u\"];\n",
|
||||
node->position.rows, node->position.columns, node->error_count,
|
||||
node->error_cost);
|
||||
|
||||
fprintf(f,
|
||||
" tooltip=\"position: %lu,%lu\nerror_count: %u\nerror_cost: %u\n"
|
||||
"error_depth: %u\"];\n",
|
||||
node->position.rows, node->position.columns, node->error_count,
|
||||
node->error_cost, node->error_depth);
|
||||
|
||||
for (int j = 0; j < node->link_count; j++) {
|
||||
StackLink link = node->links[j];
|
||||
|
|
@ -581,7 +598,8 @@ bool ts_stack_print_dot_graph(Stack *self, const char **symbol_names, FILE *f) {
|
|||
}
|
||||
if (!link.tree->named)
|
||||
fprintf(f, "'");
|
||||
fprintf(f, "\"");
|
||||
fprintf(f, "\" labeltooltip=\"error_cost: %u\"",
|
||||
link.tree->error_cost);
|
||||
}
|
||||
|
||||
fprintf(f, "];\n");
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ extern "C" {
|
|||
#include "tree_sitter/parser.h"
|
||||
#include "runtime/array.h"
|
||||
#include "runtime/tree.h"
|
||||
#include "runtime/error_costs.h"
|
||||
#include <stdio.h>
|
||||
|
||||
typedef struct Stack Stack;
|
||||
|
|
@ -95,9 +96,7 @@ StackPopResult ts_stack_pop_pending(Stack *, StackVersion);
|
|||
|
||||
StackPopResult ts_stack_pop_all(Stack *, StackVersion);
|
||||
|
||||
unsigned ts_stack_error_count(const Stack *, StackVersion);
|
||||
|
||||
unsigned ts_stack_error_cost(const Stack *, StackVersion);
|
||||
ErrorStatus ts_stack_error_status(const Stack *, StackVersion);
|
||||
|
||||
bool ts_stack_merge(Stack *, StackVersion, StackVersion);
|
||||
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@
|
|||
#include "runtime/alloc.h"
|
||||
#include "runtime/tree.h"
|
||||
#include "runtime/length.h"
|
||||
#include "runtime/error_costs.h"
|
||||
|
||||
TSStateId TS_TREE_STATE_NONE = USHRT_MAX;
|
||||
|
||||
|
|
@ -150,10 +151,11 @@ void ts_tree_set_children(TSTree *self, size_t child_count, TSTree **children) {
|
|||
}
|
||||
|
||||
if (self->symbol == ts_builtin_sym_error) {
|
||||
self->error_cost = self->size.rows;
|
||||
self->error_cost += ERROR_COST_PER_SKIPPED_CHAR * self->size.chars +
|
||||
ERROR_COST_PER_SKIPPED_LINE * self->size.rows;
|
||||
for (size_t i = 0; i < child_count; i++)
|
||||
if (!self->children[i]->extra)
|
||||
self->error_cost++;
|
||||
self->error_cost += ERROR_COST_PER_SKIPPED_TREE;
|
||||
}
|
||||
|
||||
if (child_count > 0) {
|
||||
|
|
@ -424,8 +426,9 @@ void ts_tree__print_dot_graph(const TSTree *self, size_t offset,
|
|||
if (self->extra)
|
||||
fprintf(f, ", fontcolor=gray");
|
||||
|
||||
fprintf(f, ", tooltip=\"range:%lu - %lu\nstate:%d\"]\n", offset,
|
||||
offset + ts_tree_total_chars(self), self->parse_state);
|
||||
fprintf(f, ", tooltip=\"range:%lu - %lu\nstate:%d\nerror-cost:%u\"]\n",
|
||||
offset, offset + ts_tree_total_chars(self), self->parse_state,
|
||||
self->error_cost);
|
||||
for (size_t i = 0; i < self->child_count; i++) {
|
||||
const TSTree *child = self->children[i];
|
||||
ts_tree__print_dot_graph(child, offset, language, f);
|
||||
|
|
|
|||
|
|
@ -33,7 +33,7 @@ typedef struct TSTree {
|
|||
|
||||
TSSymbol symbol;
|
||||
TSStateId parse_state;
|
||||
size_t error_cost;
|
||||
unsigned error_cost;
|
||||
|
||||
struct {
|
||||
TSSymbol symbol;
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue