diff --git a/spec/runtime/parse_stack_spec.cc b/spec/runtime/parse_stack_spec.cc index e0c480bf..a854adca 100644 --- a/spec/runtime/parse_stack_spec.cc +++ b/spec/runtime/parse_stack_spec.cc @@ -4,7 +4,7 @@ #include "runtime/length.h" enum { - stateA, stateB, stateC, stateD, stateE, stateF, stateG, stageH + stateA, stateB, stateC, stateD, stateE, stateF, stateG, stateH }; enum { @@ -187,7 +187,7 @@ describe("ParseStack", [&]() { AssertThat(merged, IsTrue()); }); - it("re-joins the heads, creating an 'ambiguity' node", [&]() { + it("re-joins the heads, creating an ambiguity node", [&]() { AssertThat(ts_parse_stack_head_count(stack), Equals(1)); ParseStackNode *head = ts_parse_stack_head(stack, 0); @@ -212,6 +212,38 @@ describe("ParseStack", [&]() { AssertThat(head->successor_count, Equals(1)); }); }); + + describe("when a head with multiple paths is reduced", [&]() { + before_each([&]() { + // A0__B1__C2__D3__G5 + // \______E4__F3__/ + ts_parse_stack_shift(stack, 0, stateG, trees[5]); + ts_parse_stack_shift(stack, 1, stateG, trees[5]); + }); + + it("reduces along all paths, creating an ambiguity node", [&]() { + // A0__B1__C2__H6 + // \______E4__/ + ts_parse_stack_reduce(stack, 0, stateH, symbol6, 2); + AssertThat(ts_parse_stack_head_count(stack), Equals(1)); + + ParseStackNode *head = ts_parse_stack_head(stack, 0); + AssertThat(head->state, Equals(stateH)); + AssertThat(head->tree, Fulfills(EqualsTree( + ts_tree_make_ambiguity(2, tree_array({ + ts_tree_make_node(symbol6, 2, tree_array({ + trees[3], + trees[5], + }), false), + ts_tree_make_node(symbol6, 2, tree_array({ + trees[3], + trees[5], + }), false) + })), + symbol_names))); + AssertThat(head->successor_count, Equals(2)); + }); + }); }); }); diff --git a/src/runtime/parse_stack.c b/src/runtime/parse_stack.c index ff40136b..fa3841f4 100644 --- a/src/runtime/parse_stack.c +++ b/src/runtime/parse_stack.c @@ -1,5 +1,6 @@ #include "tree_sitter/parser.h" #include "runtime/tree.h" +#include "runtime/tree_vector.h" #include "runtime/parse_stack.h" #include "runtime/length.h" #include @@ -64,42 +65,95 @@ bool ts_parse_stack_shift(ParseStack *this, int head_index, TSStateId state, TST return false; } +#define MAX_PATH_COUNT 8 + bool ts_parse_stack_reduce(ParseStack *this, int head_index, TSStateId state, TSSymbol symbol, int child_count) { - ParseStackNode *head = this->heads[head_index]; + int path_count = 1; + ParseStackNode *nodes_by_path[MAX_PATH_COUNT] = {this->heads[head_index]}; + TreeVector children_by_path[MAX_PATH_COUNT] = {tree_vector_new(child_count)}; + size_t child_counts_by_path[MAX_PATH_COUNT] = {child_count}; /* - * Walk down the stack to determine which symbols will be reduced. - * The child node count is known ahead of time, but some children - * may be ubiquitous tokens, which don't count. + * Reduce along every possible path in parallel. Stop when the given number + * of child trees have been collected along every path. */ - ParseStackNode *next_node = head; - for (int i = 0; i < child_count; i++) { - TSTree *child = next_node->tree; - if (ts_tree_is_extra(child)) - child_count++; - next_node = next_node->successors[0]; - if (!next_node) - break; + bool all_paths_done = false; + while (!all_paths_done) { + all_paths_done = true; + int current_path_count = path_count; + for (int path = 0; path < current_path_count; path++) { + if (children_by_path[path].size == child_counts_by_path[path]) + continue; + else + all_paths_done = false; + + /* + * Children that are 'extra' do not count towards the total child count. + */ + ParseStackNode *node = nodes_by_path[path]; + if (ts_tree_is_extra(node->tree)) + child_counts_by_path[path]++; + + /* + * If a node has more than one successor, create new paths for each of + * the additional successors. + */ + tree_vector_push(&children_by_path[path], node->tree); + + for (int i = 0; i < node->successor_count; i++) { + int next_path; + if (i > 0) { + if (path_count == MAX_PATH_COUNT) break; + next_path = path_count; + child_counts_by_path[next_path] = child_counts_by_path[path]; + children_by_path[next_path] = tree_vector_copy(&children_by_path[path]); + path_count++; + } else { + next_path = path; + } + + nodes_by_path[next_path] = node->successors[i]; + } + } } - TSTree **children = malloc(child_count * sizeof(TSTree *)); - next_node = head; - for (int i = 0; i < child_count; i++) { - children[child_count - i - 1] = next_node->tree; - next_node = next_node->successors[0]; + TSTree *parent; + if (path_count > 1) { + TSTree **trees_by_path = malloc(path_count * sizeof(TSTree *)); + for (int path = 0; path < path_count; path++) { + stack_node_retain(nodes_by_path[path]); + tree_vector_reverse(&children_by_path[path]); + trees_by_path[path] = ts_tree_make_node( + symbol, + child_counts_by_path[path], + children_by_path[path].contents, + false + ); + parent = ts_tree_make_ambiguity(path_count, trees_by_path); + } + } else { + stack_node_retain(nodes_by_path[0]); + tree_vector_reverse(&children_by_path[0]); + parent = ts_tree_make_node( + symbol, + child_counts_by_path[0], + children_by_path[0].contents, + false + ); } - TSTree *parent = ts_tree_make_node(symbol, child_count, children, false); - - stack_node_retain(next_node); stack_node_release(this->heads[head_index]); - this->heads[head_index] = next_node; + this->heads[head_index] = nodes_by_path[0]; if (parse_stack_merge_head(this, head_index, state, parent)) return true; - this->heads[head_index] = stack_node_new(next_node, state, parent); + this->heads[head_index] = stack_node_new(nodes_by_path[0], state, parent); + for (int i = 1; i < path_count; i++) { + stack_node_add_successor(this->heads[head_index], nodes_by_path[i]); + } + return false; } diff --git a/src/runtime/tree_vector.h b/src/runtime/tree_vector.h new file mode 100644 index 00000000..110166b0 --- /dev/null +++ b/src/runtime/tree_vector.h @@ -0,0 +1,59 @@ +#ifndef RUNTIME_TREE_VECTOR_H_ +#define RUNTIME_TREE_VECTOR_H_ + +#ifdef __cplusplus +extern "C" { +#endif + +#include +#include "./tree.h" + +typedef struct { + TSTree **contents; + size_t capacity; + size_t size; +} TreeVector; + +static inline TreeVector tree_vector_new(size_t size) { + return (TreeVector) { + .contents = malloc(size * sizeof(TSTree *)), + .capacity = size, + .size = 0, + }; +} + +static inline void tree_vector_push(TreeVector *this, TSTree *tree) { + if (this->size == this->capacity) { + this->capacity += 4; + this->contents = realloc(this->contents, this->capacity * sizeof(TSTree *)); + } + this->contents[this->size++] = tree; +} + +static inline void tree_vector_reverse(TreeVector *this) { + TSTree *swap; + size_t limit = this->size / 2; + for (size_t i = 0; i < limit; i++) { + swap = this->contents[i]; + this->contents[i] = this->contents[this->size - 1 - i]; + this->contents[this->size - 1 - i] = swap; + } +} + +static inline TreeVector tree_vector_copy(TreeVector *this) { + return (TreeVector) { + .contents = memcpy( + malloc(this->capacity * sizeof(TSTree *)), + this->contents, + this->size * sizeof(TSTree *) + ), + .capacity = this->capacity, + .size = this->size, + }; +} + +#ifdef __cplusplus +} +#endif + +#endif // RUNTIME_TREE_VECTOR_H_