Fix some inconsistencies in error cost calculation

Signed-off-by: Nathan Sobo <nathan@github.com>
This commit is contained in:
Max Brunsfeld 2016-08-31 10:51:59 -07:00 committed by Nathan Sobo
parent 883a7c8266
commit 0faae52132
7 changed files with 135 additions and 79 deletions

View file

@ -59,12 +59,13 @@ if ({a: 'b'} {c: 'd'}) {
(ERROR (object (pair (identifier) (string))))
(object (pair (identifier) (string)))
(statement_block
(expression_statement (function
(formal_parameters (identifier))
(statement_block (expression_statement (identifier))))
(ERROR (function
(formal_parameters (identifier))
(statement_block (expression_statement (identifier)))))))))
(expression_statement
(function
(formal_parameters (identifier))
(statement_block (expression_statement (identifier))))
(ERROR (function
(formal_parameters (identifier))
(statement_block (expression_statement (identifier)))))))))
===================================================
one invalid token at the end of the file

43
src/runtime/error_costs.h Normal file
View file

@ -0,0 +1,43 @@
#ifndef RUNTIME_ERROR_COSTS_H_
#define RUNTIME_ERROR_COSTS_H_
#define ERROR_COST_PER_SKIPPED_TREE 10
#define ERROR_COST_PER_SKIPPED_LINE 3
#define ERROR_COST_PER_SKIPPED_CHAR 0
typedef struct {
unsigned cost;
unsigned count;
unsigned depth;
} ErrorStatus;
static inline int error_status_compare(ErrorStatus a, ErrorStatus b) {
static unsigned ERROR_COST_THRESHOLD = 3 * ERROR_COST_PER_SKIPPED_TREE;
static unsigned ERROR_COUNT_THRESHOLD = 1;
// TODO remove
a.cost += ERROR_COST_PER_SKIPPED_TREE * a.count;
b.cost += ERROR_COST_PER_SKIPPED_TREE * b.count;
if ((a.count + ERROR_COUNT_THRESHOLD < b.count) ||
(a.count < b.count && a.cost <= b.cost)) {
return -1;
}
if ((b.count + ERROR_COUNT_THRESHOLD < a.count) ||
(b.count < a.count && b.cost <= a.cost)) {
return 1;
}
if (a.cost + ERROR_COST_THRESHOLD < b.cost) {
return -1;
}
if (b.cost + ERROR_COST_THRESHOLD < a.cost) {
return 1;
}
return 0;
}
#endif

View file

@ -12,6 +12,7 @@
#include "runtime/language.h"
#include "runtime/alloc.h"
#include "runtime/reduce_action.h"
#include "runtime/error_costs.h"
#define LOG(...) \
if (self->lexer.debugger.debug_fn) { \
@ -47,8 +48,6 @@
goto error; \
}
static const unsigned ERROR_COST_THRESHOLD = 3;
typedef struct {
Parser *parser;
TSSymbol lookahead_symbol;
@ -208,25 +207,35 @@ static bool parser__condense_stack(Parser *self) {
bool result = false;
for (StackVersion i = 0; i < ts_stack_version_count(self->stack); i++) {
if (ts_stack_is_halted(self->stack, i)) {
result = true;
ts_stack_remove_version(self->stack, i);
result = true;
i--;
continue;
}
bool did_merge = false;
for (size_t j = 0; j < i; j++) {
ErrorStatus error_status = ts_stack_error_status(self->stack, i);
for (StackVersion j = 0; j < i; j++) {
if (ts_stack_merge(self->stack, j, i)) {
did_merge = true;
result = true;
i--;
break;
}
}
if (did_merge) {
result = true;
i--;
continue;
switch (error_status_compare(error_status,
ts_stack_error_status(self->stack, j))) {
case -1:
ts_stack_remove_version(self->stack, j);
result = true;
i--;
j--;
break;
case 1:
ts_stack_remove_version(self->stack, i);
result = true;
i--;
break;
}
}
}
return result;
@ -365,34 +374,23 @@ static bool parser__select_tree(Parser *self, TSTree *left, TSTree *right) {
}
static bool parser__better_version_exists(Parser *self, StackVersion version,
unsigned my_error_count,
unsigned my_error_cost) {
if (self->finished_tree && self->finished_tree->error_cost <= my_error_cost)
ErrorStatus my_error_status) {
if (self->finished_tree &&
self->finished_tree->error_cost <= my_error_status.cost)
return true;
for (StackVersion i = 0, n = ts_stack_version_count(self->stack); i < n; i++) {
if (i == version || ts_stack_is_halted(self->stack, i))
continue;
unsigned error_cost = ts_stack_error_cost(self->stack, i);
unsigned error_count = ts_stack_error_count(self->stack, i);
if ((error_count > my_error_count + 1) ||
(error_count > my_error_count && error_cost >= my_error_cost) ||
(my_error_count == 0 && error_cost > my_error_cost) ||
(error_count == my_error_count &&
error_cost >= my_error_cost + ERROR_COST_THRESHOLD)) {
LOG("halt_other version:%u", i);
ts_stack_halt(self->stack, i);
continue;
}
if ((my_error_count > error_count + 1) ||
(my_error_count > error_count && my_error_cost >= error_cost) ||
(error_count == 0 && my_error_cost > error_cost) ||
(my_error_count == error_count &&
my_error_cost >= error_cost + ERROR_COST_THRESHOLD)) {
return true;
switch (error_status_compare(my_error_status,
ts_stack_error_status(self->stack, i))) {
case -1:
LOG("halt_other version:%u", i);
ts_stack_halt(self->stack, i);
break;
case 1:
return true;
}
}
@ -516,11 +514,8 @@ static Reduction parser__reduce(Parser *self, StackVersion version,
if (action->type == TSParseActionTypeRecover && child_count > 1 &&
allow_skipping) {
unsigned error_count = ts_stack_error_count(self->stack, slice.version);
unsigned error_cost =
ts_stack_error_cost(self->stack, slice.version) + 1;
if (!parser__better_version_exists(self, slice.version, error_count,
error_cost)) {
ErrorStatus error_status = ts_stack_error_status(self->stack, version);
if (!parser__better_version_exists(self, version, error_status)) {
StackVersion other_version =
ts_stack_duplicate_version(self->stack, slice.version);
CHECK(other_version != STACK_VERSION_NONE);
@ -753,15 +748,13 @@ static RepairResult parser__repair_error(Parser *self, StackSlice slice,
CHECK(parent);
CHECK(parser__push(self, slice.version, parent, next_state));
unsigned error_cost = ts_stack_error_cost(self->stack, slice.version);
unsigned error_count = ts_stack_error_count(self->stack, slice.version);
if (parser__better_version_exists(self, slice.version, error_count,
error_cost)) {
ErrorStatus error_status = ts_stack_error_status(self->stack, slice.version);
if (parser__better_version_exists(self, slice.version, error_status)) {
LOG("no_better_repair_found");
ts_stack_halt(self->stack, slice.version);
return RepairNoneFound;
} else {
LOG("repair_found sym:%s, child_count:%lu, skipped:%lu", SYM_NAME(symbol),
LOG("repair_found sym:%s, child_count:%lu, cost:%u", SYM_NAME(symbol),
repair.count, parent->error_cost);
return RepairSucceeded;
}
@ -954,9 +947,9 @@ error:
static bool parser__handle_error(Parser *self, StackVersion version,
TSSymbol lookahead_symbol) {
unsigned error_cost = ts_stack_error_cost(self->stack, version);
unsigned error_count = ts_stack_error_count(self->stack, version) + 1;
if (parser__better_version_exists(self, version, error_count, error_cost)) {
ErrorStatus error_status = ts_stack_error_status(self->stack, version);
error_status.count++;
if (parser__better_version_exists(self, version, error_status)) {
ts_stack_halt(self->stack, version);
LOG("bail_on_error");
return true;
@ -1005,9 +998,8 @@ static bool parser__recover(Parser *self, StackVersion version, TSStateId state,
return parser__accept(self, version);
}
unsigned error_cost = ts_stack_error_cost(self->stack, version);
unsigned error_count = ts_stack_error_count(self->stack, version);
if (parser__better_version_exists(self, version, error_count, error_cost)) {
ErrorStatus error_status = ts_stack_error_status(self->stack, version);
if (parser__better_version_exists(self, version, error_status)) {
ts_stack_halt(self->stack, version);
LOG("bail_on_recovery");
return true;
@ -1142,7 +1134,7 @@ static bool parser__advance(Parser *self, StackVersion version,
}
case TSParseActionTypeAccept: {
if (ts_stack_error_count(self->stack, version) > 0)
if (ts_stack_error_status(self->stack, version).count > 0)
continue;
LOG("accept");

View file

@ -28,6 +28,7 @@ struct StackNode {
short unsigned int ref_count;
unsigned error_cost;
unsigned error_count;
unsigned error_depth;
};
typedef struct {
@ -97,8 +98,6 @@ static StackNode *stack_node_new(StackNode *next, TSTree *tree, bool is_pending,
.links = {},
.state = state,
.position = position,
.error_count = 0,
.error_cost = 0,
};
if (next) {
@ -107,22 +106,33 @@ static StackNode *stack_node_new(StackNode *next, TSTree *tree, bool is_pending,
node->link_count = 1;
node->links[0] = (StackLink){ next, tree, is_pending };
node->error_cost = next->error_cost;
node->error_count = next->error_count;
node->error_cost = next->error_cost;
node->error_depth = next->error_depth;
if (tree) {
ts_tree_retain(tree);
node->error_cost += tree->error_cost;
if (state == TS_STATE_ERROR) {
if (!tree->extra) {
node->error_cost += 1 + tree->padding.rows + tree->size.rows;
node->error_cost += ERROR_COST_PER_SKIPPED_TREE +
ERROR_COST_PER_SKIPPED_CHAR *
(tree->padding.chars + tree->size.chars) +
ERROR_COST_PER_SKIPPED_LINE *
(tree->padding.rows + tree->size.rows);
}
} else {
node->error_cost += tree->error_cost;
node->error_depth++;
}
} else {
node->error_cost++;
node->error_count++;
node->error_depth = 0;
}
} else {
node->error_count = 0;
node->error_cost = 0;
node->error_depth = 0;
}
return node;
@ -339,12 +349,18 @@ TSLength ts_stack_top_position(const Stack *self, StackVersion version) {
return array_get(&self->heads, version)->node->position;
}
unsigned ts_stack_error_cost(const Stack *self, StackVersion version) {
return array_get(&self->heads, version)->node->error_cost;
ErrorStatus ts_stack_error_status(const Stack *self, StackVersion version) {
StackNode *node = array_get(&self->heads, version)->node;
return (ErrorStatus){
.cost = node->error_cost,
.count = node->error_count,
.depth = node->error_depth,
};
}
unsigned ts_stack_error_count(const Stack *self, StackVersion version) {
return array_get(&self->heads, version)->node->error_count;
StackNode *node = array_get(&self->heads, version)->node;
return node->error_count;
}
bool ts_stack_push(Stack *self, StackVersion version, TSTree *tree,
@ -551,11 +567,12 @@ bool ts_stack_print_dot_graph(Stack *self, const char **symbol_names, FILE *f) {
fprintf(f, "shape=point margin=0 label=\"\"");
else
fprintf(f, "label=\"%d\"", node->state);
fprintf(
f,
" tooltip=\"position: %lu,%lu\nerror-count: %u\nerror-cost: %u\"];\n",
node->position.rows, node->position.columns, node->error_count,
node->error_cost);
fprintf(f,
" tooltip=\"position: %lu,%lu\nerror_count: %u\nerror_cost: %u\n"
"error_depth: %u\"];\n",
node->position.rows, node->position.columns, node->error_count,
node->error_cost, node->error_depth);
for (int j = 0; j < node->link_count; j++) {
StackLink link = node->links[j];
@ -581,7 +598,8 @@ bool ts_stack_print_dot_graph(Stack *self, const char **symbol_names, FILE *f) {
}
if (!link.tree->named)
fprintf(f, "'");
fprintf(f, "\"");
fprintf(f, "\" labeltooltip=\"error_cost: %u\"",
link.tree->error_cost);
}
fprintf(f, "];\n");

View file

@ -8,6 +8,7 @@ extern "C" {
#include "tree_sitter/parser.h"
#include "runtime/array.h"
#include "runtime/tree.h"
#include "runtime/error_costs.h"
#include <stdio.h>
typedef struct Stack Stack;
@ -95,9 +96,7 @@ StackPopResult ts_stack_pop_pending(Stack *, StackVersion);
StackPopResult ts_stack_pop_all(Stack *, StackVersion);
unsigned ts_stack_error_count(const Stack *, StackVersion);
unsigned ts_stack_error_cost(const Stack *, StackVersion);
ErrorStatus ts_stack_error_status(const Stack *, StackVersion);
bool ts_stack_merge(Stack *, StackVersion, StackVersion);

View file

@ -7,6 +7,7 @@
#include "runtime/alloc.h"
#include "runtime/tree.h"
#include "runtime/length.h"
#include "runtime/error_costs.h"
TSStateId TS_TREE_STATE_NONE = USHRT_MAX;
@ -150,10 +151,11 @@ void ts_tree_set_children(TSTree *self, size_t child_count, TSTree **children) {
}
if (self->symbol == ts_builtin_sym_error) {
self->error_cost = self->size.rows;
self->error_cost += ERROR_COST_PER_SKIPPED_CHAR * self->size.chars +
ERROR_COST_PER_SKIPPED_LINE * self->size.rows;
for (size_t i = 0; i < child_count; i++)
if (!self->children[i]->extra)
self->error_cost++;
self->error_cost += ERROR_COST_PER_SKIPPED_TREE;
}
if (child_count > 0) {
@ -424,8 +426,9 @@ void ts_tree__print_dot_graph(const TSTree *self, size_t offset,
if (self->extra)
fprintf(f, ", fontcolor=gray");
fprintf(f, ", tooltip=\"range:%lu - %lu\nstate:%d\"]\n", offset,
offset + ts_tree_total_chars(self), self->parse_state);
fprintf(f, ", tooltip=\"range:%lu - %lu\nstate:%d\nerror-cost:%u\"]\n",
offset, offset + ts_tree_total_chars(self), self->parse_state,
self->error_cost);
for (size_t i = 0; i < self->child_count; i++) {
const TSTree *child = self->children[i];
ts_tree__print_dot_graph(child, offset, language, f);

View file

@ -33,7 +33,7 @@ typedef struct TSTree {
TSSymbol symbol;
TSStateId parse_state;
size_t error_cost;
unsigned error_cost;
struct {
TSSymbol symbol;