Refactor error comparisons

* Deal with mergeability outside of error comparison function
* Make `better_version_exists` function pure (don't halt other versions
as a side effect).
* Tweak error comparison logic

Signed-off-by: Rick Winfrey <rewinfrey@github.com>
This commit is contained in:
Max Brunsfeld 2017-09-13 16:38:15 -07:00 committed by Rick Winfrey
parent 71595ffde6
commit d291af9a31
9 changed files with 153 additions and 164 deletions

View file

@ -89,7 +89,6 @@
],
'sources': [
'src/runtime/document.c',
'src/runtime/error_costs.c',
'src/runtime/get_changed_ranges.c',
'src/runtime/language.c',
'src/runtime/lexer.c',

View file

@ -1,42 +0,0 @@
#include "runtime/error_costs.h"
static const unsigned MAX_COST_DIFFERENCE = 16 * ERROR_COST_PER_SKIPPED_TREE;
static const unsigned MAX_PUSH_COUNT_WITH_COUNT_DIFFERENCE = 24;
ErrorComparison error_status_compare(ErrorStatus a, ErrorStatus b, bool are_mergeable) {
ErrorComparison result = ErrorComparisonNone;
if (!a.recovering && b.recovering) {
if (a.push_count > MAX_PUSH_COUNT_WITH_COUNT_DIFFERENCE) {
return ErrorComparisonTakeLeft;
} else {
result = ErrorComparisonPreferLeft;
}
}
if (!b.recovering && a.recovering) {
if (b.push_count > MAX_PUSH_COUNT_WITH_COUNT_DIFFERENCE) {
return ErrorComparisonTakeRight;
} else {
result = ErrorComparisonPreferRight;
}
}
if (a.cost < b.cost) {
if (are_mergeable || (b.cost - a.cost) * (1 + a.push_count) > MAX_COST_DIFFERENCE) {
return ErrorComparisonTakeLeft;
} else {
return ErrorComparisonPreferLeft;
}
}
if (b.cost < a.cost) {
if (are_mergeable || (a.cost - b.cost) * (1 + b.push_count) > MAX_COST_DIFFERENCE) {
return ErrorComparisonTakeRight;
} else {
return ErrorComparisonPreferRight;
}
}
return result;
}

View file

@ -1,36 +1,9 @@
#ifndef RUNTIME_ERROR_COSTS_H_
#define RUNTIME_ERROR_COSTS_H_
#include <stdbool.h>
#ifdef __cplusplus
extern "C" {
#endif
#define ERROR_STATE 0
#define ERROR_COST_PER_SKIPPED_TREE 100
#define ERROR_COST_PER_SKIPPED_LINE 30
#define ERROR_COST_PER_SKIPPED_CHAR 1
typedef struct {
unsigned cost;
unsigned push_count;
unsigned depth;
bool recovering;
} ErrorStatus;
typedef enum {
ErrorComparisonTakeLeft,
ErrorComparisonPreferLeft,
ErrorComparisonNone,
ErrorComparisonPreferRight,
ErrorComparisonTakeRight,
} ErrorComparison;
ErrorComparison error_status_compare(ErrorStatus a, ErrorStatus b, bool can_merge);
#ifdef __cplusplus
}
#endif
#endif

View file

@ -36,6 +36,21 @@
static const unsigned MAX_VERSION_COUNT = 6;
static const unsigned MAX_SUMMARY_DEPTH = 16;
static const int MAX_COST_DIFFERENCE = 16 * ERROR_COST_PER_SKIPPED_TREE;
typedef struct {
unsigned cost;
unsigned push_count;
bool is_in_error;
} ErrorStatus;
typedef enum {
ErrorComparisonTakeLeft,
ErrorComparisonPreferLeft,
ErrorComparisonNone,
ErrorComparisonPreferRight,
ErrorComparisonTakeRight,
} ErrorComparison;
static void parser__log(Parser *self) {
if (self->lexer.logger.log) {
@ -120,10 +135,72 @@ static void parser__breakdown_lookahead(Parser *self, Tree **lookahead,
}
}
static bool parser__condense_stack(Parser *self) {
bool all_versions_have_error = true;
unsigned old_version_count = ts_stack_version_count(self->stack);
static ErrorComparison parser__compare_versions(Parser *self, ErrorStatus a, ErrorStatus b) {
if (!a.is_in_error && b.is_in_error) {
if (a.cost < b.cost) {
return ErrorComparisonTakeLeft;
} else {
return ErrorComparisonPreferLeft;
}
}
if (a.is_in_error && !b.is_in_error) {
if (b.cost < a.cost) {
return ErrorComparisonTakeRight;
} else {
return ErrorComparisonPreferRight;
}
}
if (a.cost < b.cost) {
if ((b.cost - a.cost) * (1 + a.push_count) > MAX_COST_DIFFERENCE) {
return ErrorComparisonTakeLeft;
} else {
return ErrorComparisonPreferLeft;
}
}
if (b.cost < a.cost) {
if ((a.cost - b.cost) * (1 + b.push_count) > MAX_COST_DIFFERENCE) {
return ErrorComparisonTakeRight;
} else {
return ErrorComparisonPreferRight;
}
}
return ErrorComparisonNone;
}
static bool parser__better_version_exists(Parser *self, StackVersion version,
bool is_in_error, unsigned cost) {
if (self->finished_tree && self->finished_tree->error_cost <= cost) return true;
ErrorStatus status = {.cost = cost, .is_in_error = is_in_error, .push_count = 0};
for (StackVersion i = 0, n = ts_stack_version_count(self->stack); i < n; i++) {
if (i == version || ts_stack_is_halted(self->stack, i)) continue;
ErrorStatus status_i = {
.cost = ts_stack_error_cost(self->stack, i),
.is_in_error = ts_stack_top_state(self->stack, i) == ERROR_STATE,
.push_count = ts_stack_push_count(self->stack, i)
};
switch (parser__compare_versions(self, status, status_i)) {
case ErrorComparisonTakeRight:
return true;
case ErrorComparisonPreferRight:
if (ts_stack_can_merge(self->stack, i, version)) return true;
default:
break;
}
}
return false;
}
static bool parser__condense_stack(Parser *self) {
bool made_changes = false;
unsigned min_error_cost = UINT_MAX;
bool all_versions_have_error = true;
for (StackVersion i = 0; i < ts_stack_version_count(self->stack); i++) {
if (ts_stack_is_halted(self->stack, i)) {
ts_stack_remove_version(self->stack, i);
@ -131,35 +208,47 @@ static bool parser__condense_stack(Parser *self) {
continue;
}
ErrorStatus right_error_status = ts_stack_error_status(self->stack, i);
if (!right_error_status.recovering) all_versions_have_error = false;
ErrorStatus status_i = {
.cost = ts_stack_error_cost(self->stack, i),
.push_count = ts_stack_push_count(self->stack, i),
.is_in_error = ts_stack_top_state(self->stack, i) == ERROR_STATE,
};
if (!status_i.is_in_error) all_versions_have_error = false;
if (status_i.cost < min_error_cost) min_error_cost = status_i.cost;
for (StackVersion j = 0; j < i; j++) {
bool can_merge = ts_stack_can_merge(self->stack, i, j);
ErrorStatus left_error_status = ts_stack_error_status(self->stack, j);
ErrorStatus status_j = {
.cost = ts_stack_error_cost(self->stack, j),
.push_count = ts_stack_push_count(self->stack, j),
.is_in_error = ts_stack_top_state(self->stack, j) == ERROR_STATE,
};
switch (error_status_compare(left_error_status, right_error_status, can_merge)) {
bool can_merge = ts_stack_can_merge(self->stack, j, i);
switch (parser__compare_versions(self, status_j, status_i)) {
case ErrorComparisonTakeLeft:
made_changes = true;
ts_stack_remove_version(self->stack, i);
i--;
j = i;
break;
case ErrorComparisonTakeRight:
ts_stack_remove_version(self->stack, j);
i--;
j--;
break;
case ErrorComparisonPreferLeft:
if (can_merge) {
made_changes = true;
ts_stack_remove_version(self->stack, i);
i--;
j = i;
}
break;
case ErrorComparisonNone:
if (can_merge) {
made_changes = true;
ts_stack_force_merge(self->stack, j, i);
i--;
j = i;
}
break;
case ErrorComparisonPreferRight:
made_changes = true;
if (can_merge) {
ts_stack_remove_version(self->stack, j);
i--;
@ -169,12 +258,11 @@ static bool parser__condense_stack(Parser *self) {
j = i;
}
break;
case ErrorComparisonNone:
if (can_merge) {
ts_stack_force_merge(self->stack, j, i);
i--;
}
case ErrorComparisonTakeRight:
made_changes = true;
ts_stack_remove_version(self->stack, j);
i--;
j--;
break;
}
}
@ -182,15 +270,17 @@ static bool parser__condense_stack(Parser *self) {
while (ts_stack_version_count(self->stack) > MAX_VERSION_COUNT) {
ts_stack_remove_version(self->stack, MAX_VERSION_COUNT);
made_changes = true;
}
unsigned new_version_count = ts_stack_version_count(self->stack);
if (new_version_count != old_version_count) {
if (made_changes) {
LOG("condense");
LOG_STACK();
}
return all_versions_have_error && new_version_count > 0;
return
(all_versions_have_error && ts_stack_version_count(self->stack) > 0) ||
(self->finished_tree && self->finished_tree->error_cost < min_error_cost);
}
static void parser__restore_external_scanner(Parser *self, Tree *external_token) {
@ -501,30 +591,6 @@ static bool parser__select_tree(Parser *self, Tree *left, Tree *right) {
}
}
static bool parser__better_version_exists(Parser *self, StackVersion version,
ErrorStatus my_error_status) {
if (self->finished_tree && self->finished_tree->error_cost <= my_error_status.cost) return true;
for (StackVersion i = 0, n = ts_stack_version_count(self->stack); i < n; i++) {
if (i == version || ts_stack_is_halted(self->stack, i)) continue;
switch (error_status_compare(my_error_status,
ts_stack_error_status(self->stack, i),
ts_stack_can_merge(self->stack, i, version))) {
case ErrorComparisonTakeLeft:
LOG("halt_other version:%u", i);
ts_stack_halt(self->stack, i);
break;
case ErrorComparisonTakeRight:
if (i < version) return true;
default:
break;
}
}
return false;
}
static void parser__shift(Parser *self, StackVersion version, TSStateId state,
Tree *lookahead, bool extra) {
if (extra != lookahead->extra) {
@ -766,10 +832,8 @@ static bool parser__do_potential_reductions(Parser *self, StackVersion version)
static void parser__handle_error(Parser *self, StackVersion version, TSSymbol lookahead_symbol) {
// If there are other stack versions that are clearly better than this one,
// just halt this version.
ErrorStatus error_status = ts_stack_error_status(self->stack, version);
error_status.recovering = true;
error_status.cost += ERROR_COST_PER_SKIPPED_TREE;
if (parser__better_version_exists(self, version, error_status)) {
unsigned new_cost = ts_stack_error_cost(self->stack, version) + ERROR_COST_PER_SKIPPED_TREE;
if (parser__better_version_exists(self, version, true, new_cost)) {
ts_stack_halt(self->stack, version);
LOG("bail_on_error");
return;
@ -837,15 +901,11 @@ static void parser__recover(Parser *self, StackVersion version, Tree *lookahead)
if (entry.state == ERROR_STATE) continue;
unsigned depth = entry.depth + ts_stack_depth_since_error(self->stack, version);
ErrorStatus status = {
.recovering = false,
.push_count = 0,
.cost =
depth * ERROR_COST_PER_SKIPPED_TREE +
(position.chars - entry.position.chars) * ERROR_COST_PER_SKIPPED_CHAR +
(position.extent.row - entry.position.extent.row) * ERROR_COST_PER_SKIPPED_LINE
};
if (parser__better_version_exists(self, version, status)) break;
unsigned new_cost =
depth * ERROR_COST_PER_SKIPPED_TREE +
(position.chars - entry.position.chars) * ERROR_COST_PER_SKIPPED_CHAR +
(position.extent.row - entry.position.extent.row) * ERROR_COST_PER_SKIPPED_LINE;
if (parser__better_version_exists(self, version, false, new_cost)) break;
unsigned count = 0;
if (ts_language_actions(self->language, entry.state, lookahead->symbol, &count) && count > 0) {
@ -932,8 +992,7 @@ static void parser__recover(Parser *self, StackVersion version, Tree *lookahead)
bool can_be_extra = ts_language_symbol_metadata(self->language, lookahead->symbol).extra;
parser__shift(self, version, ERROR_STATE, lookahead, can_be_extra);
ErrorStatus error_status = ts_stack_error_status(self->stack, version);
if (parser__better_version_exists(self, version, error_status)) {
if (parser__better_version_exists(self, version, true, ts_stack_error_cost(self->stack, version))) {
ts_stack_halt(self->stack, version);
}
}
@ -1095,10 +1154,14 @@ Tree *parser_parse(Parser *self, TSInput input, Tree *old_tree, bool halt_on_err
self->reusable_node = reusable_node;
bool all_versions_have_error = parser__condense_stack(self);
if (halt_on_error && all_versions_have_error) {
parser__halt_parse(self);
break;
bool should_halt = parser__condense_stack(self);
if (should_halt) {
if (self->finished_tree) {
break;
} else if (halt_on_error) {
parser__halt_parse(self);
break;
}
}
self->is_split = (version > 1);

View file

@ -167,7 +167,8 @@ static void stack_node_add_link(StackNode *self, StackLink link) {
StackLink existing_link = self->links[i];
if (stack__tree_is_equivalent(existing_link.tree, link.tree)) {
if (existing_link.node == link.node) return;
if (existing_link.node->state == link.node->state) {
if (existing_link.node->state == link.node->state &&
existing_link.node->position.bytes == link.node->position.bytes) {
for (int j = 0; j < link.node->link_count; j++) {
stack_node_add_link(existing_link.node, link.node->links[j]);
}
@ -380,13 +381,9 @@ void ts_stack_set_last_external_token(Stack *self, StackVersion version, Tree *t
head->last_external_token = token;
}
ErrorStatus ts_stack_error_status(const Stack *self, StackVersion version) {
unsigned ts_stack_error_cost(const Stack *self, StackVersion version) {
StackHead *head = array_get(&self->heads, version);
return (ErrorStatus){
.cost = head->node->error_cost,
.recovering = head->node->state == ERROR_STATE,
.push_count = head->push_count,
};
return head->node->error_cost;
}
void ts_stack_push(Stack *self, StackVersion version, Tree *tree, bool pending, TSStateId state) {

View file

@ -108,7 +108,7 @@ void ts_stack_record_summary(Stack *, StackVersion, unsigned max_depth);
StackSummary *ts_stack_get_summary(Stack *, StackVersion);
ErrorStatus ts_stack_error_status(const Stack *, StackVersion);
unsigned ts_stack_error_cost(const Stack *, StackVersion version);
bool ts_stack_merge(Stack *, StackVersion, StackVersion);

View file

@ -129,15 +129,13 @@ int b() {
(compound_statement
(declaration
(type_identifier)
(ERROR (identifier))
(ERROR (identifier) (identifier))
(init_declarator
(identifier)
(ERROR (identifier))
(number_literal)))
(declaration
(type_identifier)
(ERROR (identifier))
(ERROR (identifier) (identifier))
(init_declarator
(identifier)
(ERROR (identifier))
(number_literal))))))

View file

@ -37,10 +37,10 @@ h i j k;
(identifier)
(ERROR (identifier) (identifier)))
(statement_block
(ERROR (identifier) (identifier) (identifier))
(expression_statement (identifier))))
(ERROR (identifier) (identifier) (identifier))
(expression_statement (identifier)))
(ERROR (identifier))
(expression_statement (identifier) (ERROR (identifier) (identifier)))))
(ERROR (identifier))
(expression_statement (identifier) (ERROR (identifier) (identifier))))
===================================================
one invalid subtree right after the viable prefix

View file

@ -91,15 +91,15 @@ describe("Parser", [&]() {
TSNode error = ts_node_named_child(ts_node_child(root, 0), 1);
AssertThat(ts_node_type(error, document), Equals("ERROR"));
AssertThat(get_node_text(error), Equals(", @@@@@"));
AssertThat(get_node_text(error), Equals("@@@@@,"));
AssertThat(ts_node_child_count(error), Equals<size_t>(2));
TSNode comma = ts_node_child(error, 0);
AssertThat(get_node_text(comma), Equals(","));
TSNode garbage = ts_node_child(error, 1);
TSNode garbage = ts_node_child(error, 0);
AssertThat(get_node_text(garbage), Equals("@@@@@"));
TSNode comma = ts_node_child(error, 1);
AssertThat(get_node_text(comma), Equals(","));
TSNode node_after_error = ts_node_next_named_sibling(error);
AssertThat(ts_node_type(node_after_error, document), Equals("true"));
AssertThat(get_node_text(node_after_error), Equals("true"));
@ -116,16 +116,17 @@ describe("Parser", [&]() {
TSNode error = ts_node_named_child(ts_node_child(root, 0), 1);
AssertThat(ts_node_type(error, document), Equals("ERROR"));
AssertThat(get_node_text(error), Equals("faaaaalse,"));
AssertThat(ts_node_child_count(error), Equals<size_t>(2));
TSNode comma = ts_node_child(error, 0);
AssertThat(ts_node_type(comma, document), Equals(","));
AssertThat(get_node_text(comma), Equals(","));
TSNode garbage = ts_node_child(error, 1);
TSNode garbage = ts_node_child(error, 0);
AssertThat(ts_node_type(garbage, document), Equals("ERROR"));
AssertThat(get_node_text(garbage), Equals("faaaaalse"));
TSNode comma = ts_node_child(error, 1);
AssertThat(ts_node_type(comma, document), Equals(","));
AssertThat(get_node_text(comma), Equals(","));
TSNode last = ts_node_next_named_sibling(error);
AssertThat(ts_node_type(last, document), Equals("true"));
AssertThat(ts_node_start_byte(last), Equals(strlen(" [123, faaaaalse, ")));