Don't merge stack versions with different error costs

This commit is contained in:
Max Brunsfeld 2016-06-12 17:27:08 -07:00
parent ef2c3a10e3
commit 2b80e66188

View file

@ -26,8 +26,7 @@ struct StackNode {
StackLink links[MAX_LINK_COUNT];
short unsigned int link_count;
short unsigned int ref_count;
unsigned min_error_cost;
unsigned max_error_cost;
unsigned error_cost;
unsigned error_depth;
};
@ -101,8 +100,7 @@ static StackNode *stack_node_new(StackNode *next, TSTree *tree, bool is_pending,
.state = state,
.position = position,
.error_depth = 0,
.min_error_cost = is_error ? 1 : 0,
.max_error_cost = is_error ? 1 : 0,
.error_cost = is_error ? 1 : 0,
};
if (next) {
@ -110,14 +108,12 @@ static StackNode *stack_node_new(StackNode *next, TSTree *tree, bool is_pending,
node->links[0] = (StackLink){ next, tree, is_pending };
node->link_count = 1;
node->min_error_cost += next->min_error_cost;
node->max_error_cost += next->max_error_cost;
node->error_cost += next->error_cost;
node->error_depth = next->error_depth;
if (tree) {
ts_tree_retain(tree);
node->min_error_cost += tree->error_size;
node->max_error_cost += tree->error_size;
node->error_cost += tree->error_size;
} else {
node->error_depth++;
}
@ -142,15 +138,8 @@ static void stack_node_add_link(StackNode *self, StackLink link) {
if (self->link_count < MAX_LINK_COUNT) {
stack_node_retain(link.node);
if (link.tree) {
if (link.tree)
ts_tree_retain(link.tree);
size_t min_error_cost = link.tree->error_size + link.node->min_error_cost;
size_t max_error_cost = link.tree->error_size + link.node->max_error_cost;
if (min_error_cost < self->min_error_cost)
self->min_error_cost = min_error_cost;
if (max_error_cost < self->max_error_cost)
self->max_error_cost = max_error_cost;
}
self->links[self->link_count++] = (StackLink){
link.node, link.tree, link.is_pending,
@ -339,7 +328,7 @@ TSLength ts_stack_top_position(const Stack *self, StackVersion version) {
}
unsigned ts_stack_error_cost(const Stack *self, StackVersion version) {
return array_get(&self->heads, version)->node->min_error_cost;
return array_get(&self->heads, version)->node->error_cost;
}
unsigned ts_stack_error_depth(const Stack *self, StackVersion version) {
@ -486,7 +475,8 @@ bool ts_stack_merge(Stack *self, StackVersion version, StackVersion new_version)
if (new_node->state == node->state &&
new_node->position.chars == node->position.chars &&
new_node->error_depth == node->error_depth) {
new_node->error_depth == node->error_depth &&
new_node->error_cost == node->error_cost) {
for (size_t j = 0; j < new_node->link_count; j++)
stack_node_add_link(node, new_node->links[j]);
ts_stack_remove_version(self, new_version);
@ -496,31 +486,6 @@ bool ts_stack_merge(Stack *self, StackVersion version, StackVersion new_version)
}
}
void stack_node_remove_link(StackNode *self, size_t i,
StackNodeArray *node_pool) {
self->link_count--;
ts_tree_release(self->links[i].tree);
stack_node_release(self->links[i].node, node_pool);
memmove(&self->links[i], &self->links[i + 1],
(self->link_count - i) * sizeof(StackLink));
}
void stack_node_prune_paths_with_error_cost(StackNode *self, size_t cost,
StackNodeArray *node_pool) {
for (size_t i = 0; i < self->link_count; i++) {
StackLink link = self->links[i];
size_t link_cost = cost;
if (link.tree)
link_cost -= link.tree->error_size;
if (link.node->min_error_cost >= link_cost) {
stack_node_remove_link(self, i, node_pool);
i--;
} else if (link.node->max_error_cost >= link_cost) {
stack_node_prune_paths_with_error_cost(link.node, link_cost, node_pool);
}
}
}
void ts_stack_halt(Stack *self, StackVersion version) {
array_get(&self->heads, version)->is_halted = true;
}
@ -588,12 +553,7 @@ bool ts_stack_print_dot_graph(Stack *self, const char **symbol_names, FILE *f) {
fprintf(f, "shape=point margin=0 label=\"\"");
else
fprintf(f, "label=\"%d\"", node->state);
fprintf(f, " tooltip=\"error-count:%u, error-cost:", node->error_depth);
if (node->min_error_cost == node->max_error_cost)
fprintf(f, "%u", node->min_error_cost);
else
fprintf(f, "%u-%u", node->min_error_cost, node->max_error_cost);
fprintf(f, "\"];\n");
fprintf(f, " tooltip=\"error-count:%u, error-cost:%u\"];\n", node->error_depth, node->error_cost);
for (int j = 0; j < node->link_count; j++) {
StackLink link = node->links[j];