Refactor construction of out-of-context states

This commit is contained in:
Max Brunsfeld 2016-04-25 21:59:40 -07:00
parent e99a3925e0
commit 31f6b2e24a
7 changed files with 51 additions and 41 deletions

View file

@ -9,7 +9,6 @@
#include "compiler/build_tables/lex_conflict_manager.h"
#include "compiler/build_tables/remove_duplicate_states.h"
#include "compiler/build_tables/lex_item.h"
#include "compiler/build_tables/does_match_any_line.h"
#include "compiler/parse_table.h"
#include "compiler/lexical_grammar.h"
#include "compiler/rules/built_in_symbols.h"
@ -48,11 +47,10 @@ class LexTableBuilder {
}
LexTable build() {
add_lex_state(build_lex_item_set(parse_table->all_symbols(), true));
add_lex_state_for_parse_state(&parse_table->error_state);
for (ParseState &parse_state : parse_table->states)
parse_state.lex_state_id =
add_lex_state(build_lex_item_set(parse_state.expected_inputs(), false));
add_lex_state_for_parse_state(&parse_state);
mark_fragile_tokens();
remove_duplicate_lex_states();
@ -61,7 +59,7 @@ class LexTableBuilder {
}
private:
LexItemSet build_lex_item_set(const set<Symbol> &symbols, bool error) {
LexItemSet build_lex_item_set(const set<Symbol> &symbols) {
LexItemSet result;
for (const Symbol &symbol : symbols) {
vector<rule_ptr> rules;
@ -69,8 +67,6 @@ class LexTableBuilder {
rules.push_back(CharacterSet().include(0).copy());
} else if (symbol.is_token) {
rule_ptr rule = lex_grammar.variables[symbol.index].rule;
if (error && does_match_any_line(rule))
continue;
auto choice = rule->as<rules::Choice>();
if (choice)
@ -98,6 +94,11 @@ class LexTableBuilder {
return result;
}
// Builds a lex item set from the tokens this parse state expects
// (parse_state->expected_inputs()), registers it through add_lex_state, and
// stores the returned lex state id back on the parse state.
// Takes a pointer because the parse state is modified in place; it is
// dereferenced unconditionally, so the caller must pass a non-null pointer.
void add_lex_state_for_parse_state(ParseState *parse_state) {
parse_state->lex_state_id =
add_lex_state(build_lex_item_set(parse_state->expected_inputs()));
}
LexStateId add_lex_state(const LexItemSet &item_set) {
const auto &pair = lex_state_ids.find(item_set);
if (pair == lex_state_ids.end()) {

View file

@ -15,6 +15,7 @@
#include "compiler/syntax_grammar.h"
#include "compiler/rules/symbol.h"
#include "compiler/rules/built_in_symbols.h"
#include "compiler/build_tables/does_match_any_line.h"
namespace tree_sitter {
namespace build_tables {
@ -105,17 +106,20 @@ class ParseTableBuilder {
}
void add_out_of_context_parse_states() {
map<Symbol, set<Symbol>> symbols_by_first = symbols_by_first_symbol(grammar);
auto symbols_by_first = symbols_by_first_symbol(grammar);
for (size_t i = 0; i < lexical_grammar.variables.size(); i++) {
Symbol symbol(i, true);
if (!grammar.extra_tokens.count(symbol))
if (!does_match_any_line(lexical_grammar.variables[i].rule))
add_out_of_context_parse_state(symbol, symbols_by_first[symbol]);
}
for (size_t i = 0; i < grammar.variables.size(); i++) {
Symbol symbol(i, false);
add_out_of_context_parse_state(Symbol(i, false), symbols_by_first[symbol]);
add_out_of_context_parse_state(symbol, symbols_by_first[symbol]);
}
parse_table.error_state.actions[rules::END_OF_INPUT()].clear();
}
void add_out_of_context_parse_state(const rules::Symbol &symbol,
@ -133,8 +137,11 @@ class ParseTableBuilder {
}
}
ParseStateId state = add_parse_state(item_set);
parse_table.out_of_context_state_indices[symbol] = state;
if (!item_set.entries.empty()) {
ParseStateId state = add_parse_state(item_set);
parse_table.error_state.actions[symbol].push_back(
ParseAction::Shift(state, PrecedenceRange()));
}
}
ParseStateId add_parse_state(const ParseItemSet &item_set) {
@ -265,11 +272,12 @@ class ParseTableBuilder {
auto replacements =
remove_duplicate_states<ParseState, ParseAction>(&parse_table.states);
for (auto &pair : parse_table.out_of_context_state_indices) {
auto replacement = replacements.find(pair.second);
if (replacement != replacements.end())
pair.second = replacement->second;
}
parse_table.error_state.each_advance_action(
[&replacements](ParseAction *action) {
auto replacement = replacements.find(action->state_index);
if (replacement != replacements.end())
action->state_index = replacement->second;
});
}
ParseAction *add_action(ParseStateId state_id, Symbol lookahead,

View file

@ -46,12 +46,11 @@ std::map<size_t, size_t> remove_duplicate_states(std::vector<StateType> *states)
}
for (StateType &state : *states)
state.each_advance_action(
[&duplicates, &new_replacements](ActionType *action) {
auto new_replacement = new_replacements.find(action->state_index);
if (new_replacement != new_replacements.end())
action->state_index = new_replacement->second;
});
state.each_advance_action([&new_replacements](ActionType *action) {
auto new_replacement = new_replacements.find(action->state_index);
if (new_replacement != new_replacements.end())
action->state_index = new_replacement->second;
});
for (auto i = duplicates.rbegin(); i != duplicates.rend(); ++i)
states->erase(states->begin() + i->first);

View file

@ -223,15 +223,12 @@ class CCodeGenerator {
void add_out_of_context_parse_states_list() {
line("static TSStateId ts_out_of_context_states[SYMBOL_COUNT] = {");
indent([&]() {
for (const auto &entry : parse_table.symbols) {
for (const auto &entry : parse_table.error_state.actions) {
const rules::Symbol &symbol = entry.first;
if (symbol.is_built_in())
continue;
auto iter = parse_table.out_of_context_state_indices.find(symbol);
string state = (iter != parse_table.out_of_context_state_indices.end())
? to_string(iter->second)
: "ts_parse_state_error";
line("[" + symbol_id(symbol) + "] = " + state + ",");
if (!entry.second.empty()) {
ParseStateId state = entry.second[0].state_index;
line("[" + symbol_id(symbol) + "] = " + to_string(state) + ",");
}
}
});
line("};");

View file

@ -102,8 +102,8 @@ class ParseTable {
ParseAction action);
std::vector<ParseState> states;
ParseState error_state;
std::map<rules::Symbol, ParseTableSymbolMetadata> symbols;
std::map<rules::Symbol, size_t> out_of_context_state_indices;
};
} // namespace tree_sitter

View file

@ -264,11 +264,13 @@ static bool ts_parser__select_tree(TSParser *self, TSTree *left, TSTree *right)
if (!right)
return false;
if (right->error_size < left->error_size) {
LOG_ACTION("select_smaller_error symbol:%s, over_symbol:%s", SYM_NAME(right->symbol), SYM_NAME(left->symbol));
LOG_ACTION("select_smaller_error symbol:%s, over_symbol:%s",
SYM_NAME(right->symbol), SYM_NAME(left->symbol));
return true;
}
if (left->error_size < right->error_size) {
LOG_ACTION("select_smaller_error symbol:%s, over_symbol:%s", SYM_NAME(left->symbol), SYM_NAME(right->symbol));
LOG_ACTION("select_smaller_error symbol:%s, over_symbol:%s",
SYM_NAME(left->symbol), SYM_NAME(right->symbol));
return false;
}
return ts_tree_compare(right, left) < 0;
@ -314,7 +316,7 @@ error:
}
static bool ts_parser__switch_children(TSParser *self, TSTree *tree,
TSTree **children, size_t count) {
TSTree **children, size_t count) {
self->scratch_tree.symbol = tree->symbol;
self->scratch_tree.child_count = 0;
ts_tree_set_children(&self->scratch_tree, count, children);
@ -534,8 +536,8 @@ static RepairResult ts_parser__repair_error(TSParser *self, StackSlice slice,
LOG_ACTION(
"repair_found sym:%s, child_count:%lu, match_count:%lu, skipped:%lu",
SYM_NAME(symbol), repair.count_below_error + count_above_error, repair.in_progress_state_count,
skip_count);
SYM_NAME(symbol), repair.count_below_error + count_above_error,
repair.in_progress_state_count, skip_count);
if (skip_count > 0) {
TreeArray skipped_children = array_new();

View file

@ -134,9 +134,9 @@ static StackVersion ts_stack__add_version(Stack *self, StackNode *node) {
static bool ts_stack__add_slice(Stack *self, StackNode *node, TreeArray *trees) {
for (size_t i = self->slices.size - 1; i + 1 > 0; i--) {
StackVersion version = self->slices.contents[i].version;
StackVersion version = self->slices.contents[i].version;
if (self->heads.contents[version] == node) {
StackSlice slice = {*trees, version};
StackSlice slice = { *trees, version };
return array_insert(&self->slices, i + 1, slice);
}
}
@ -144,7 +144,7 @@ static bool ts_stack__add_slice(Stack *self, StackNode *node, TreeArray *trees)
StackVersion version = ts_stack__add_version(self, node);
if (version == STACK_VERSION_NONE)
return false;
StackSlice slice = {*trees, version};
StackSlice slice = { *trees, version };
return array_push(&self->slices, slice);
}
@ -442,7 +442,10 @@ int ts_stack_print_dot_graph(Stack *self, const char **symbol_names, FILE *f) {
for (size_t i = 0; i < self->heads.size; i++) {
StackNode *node = self->heads.contents[i];
fprintf(f, "node_head_%lu [shape=none, label=\"\"]\n", i);
fprintf(f, "node_head_%lu -> node_%p [label=%lu, arrowhead=none, fontcolor=blue, weight=10000]\n", i, node, i);
fprintf(f,
"node_head_%lu -> node_%p [label=%lu, arrowhead=none, "
"fontcolor=blue, weight=10000]\n",
i, node, i);
array_push(&self->pop_paths, ((PopPath){.node = node }));
}