diff --git a/include/tree_sitter/parser.h b/include/tree_sitter/parser.h index e1872ab4..09c09c54 100644 --- a/include/tree_sitter/parser.h +++ b/include/tree_sitter/parser.h @@ -80,12 +80,19 @@ typedef union { unsigned int count; } TSParseActionEntry; +typedef union { + TSSymbol symbol; + unsigned int count; +} TSInProgressSymbolEntry; + struct TSLanguage { size_t symbol_count; const char **symbol_names; const TSSymbolMetadata *symbol_metadata; const unsigned short *parse_table; const TSParseActionEntry *parse_actions; + const unsigned short *in_progress_symbol_table; + const TSInProgressSymbolEntry *in_progress_symbols; const TSStateId *lex_states; TSTree *(*lex_fn)(TSLexer *, TSStateId, bool); }; @@ -181,19 +188,21 @@ enum { { .type = TSParseActionTypeAccept } \ } -#define EXPORT_LANGUAGE(language_name) \ - static TSLanguage language = { \ - .symbol_count = SYMBOL_COUNT, \ - .symbol_metadata = ts_symbol_metadata, \ - .parse_table = (const unsigned short *)ts_parse_table, \ - .parse_actions = ts_parse_actions, \ - .lex_states = ts_lex_states, \ - .symbol_names = ts_symbol_names, \ - .lex_fn = ts_lex, \ - }; \ - \ - const TSLanguage *language_name() { \ - return &language; \ +#define EXPORT_LANGUAGE(language_name) \ + static TSLanguage language = { \ + .symbol_count = SYMBOL_COUNT, \ + .symbol_metadata = ts_symbol_metadata, \ + .parse_table = (const unsigned short *)ts_parse_table, \ + .parse_actions = ts_parse_actions, \ + .in_progress_symbol_table = ts_in_progress_symbol_table, \ + .in_progress_symbols = ts_in_progress_symbols, \ + .lex_states = ts_lex_states, \ + .symbol_names = ts_symbol_names, \ + .lex_fn = ts_lex, \ + }; \ + \ + const TSLanguage *language_name() { \ + return &language; \ } #ifdef __cplusplus diff --git a/src/compiler/build_tables/build_parse_table.cc b/src/compiler/build_tables/build_parse_table.cc index 2592e5e3..634a5afd 100644 --- a/src/compiler/build_tables/build_parse_table.cc +++ b/src/compiler/build_tables/build_parse_table.cc @@ -131,6 +131,12 @@ class ParseTableBuilder { auto pair = parse_state_ids.find(item_set); if (pair == parse_state_ids.end()) { ParseStateId state_id = parse_table.add_state(); + for (const auto &entry : item_set.entries) { + const ParseItem &item = entry.first; + if (item.step_index > 0 && item.lhs() != rules::START()) + parse_table.states[state_id].in_progress_symbols.insert(item.lhs()); + } + parse_state_ids[item_set] = state_id; item_sets_to_process.push_back({ item_set, state_id }); return state_id; diff --git a/src/compiler/generate_code/c_code.cc b/src/compiler/generate_code/c_code.cc index 1e09a977..3819e0b2 100644 --- a/src/compiler/generate_code/c_code.cc +++ b/src/compiler/generate_code/c_code.cc @@ -74,7 +74,9 @@ class CCodeGenerator { const LexicalGrammar lexical_grammar; map sanitized_names; vector>> parse_actions; + vector>> in_progress_symbols; size_t next_parse_action_list_index; + size_t next_in_progress_symbol_list_index; public: CCodeGenerator(string name, const ParseTable &parse_table, @@ -86,7 +88,8 @@ class CCodeGenerator { lex_table(lex_table), syntax_grammar(syntax_grammar), lexical_grammar(lexical_grammar), - next_parse_action_list_index(0) {} + next_parse_action_list_index(0), + next_in_progress_symbol_list_index(0) {} string code() { buffer = ""; @@ -100,6 +103,7 @@ class CCodeGenerator { add_lex_states_list(); add_out_of_context_parse_states_list(); add_parse_table(); + add_in_progress_symbol_table(); add_parser_export(); return buffer; @@ -233,7 +237,7 @@ class CCodeGenerator { } void add_parse_table() { - add_parse_actions({ ParseAction::Error() }); + add_parse_action_list_id({ ParseAction::Error() }); size_t state_id = 0; line("#pragma GCC diagnostic push"); @@ -247,7 +251,7 @@ class CCodeGenerator { indent([&]() { for (const auto &pair : state.actions) { line("[" + symbol_id(pair.first) + "] = "); - add(to_string(add_parse_actions(pair.second))); + add(to_string(add_parse_action_list_id(pair.second))); add(","); } }); @@ -263,6 +267,25 @@ class CCodeGenerator { line(); } + void add_in_progress_symbol_table() { + line("static unsigned short ts_in_progress_symbol_table[] = {"); + + indent([&]() { + size_t state_id = 0; + for (const ParseState &state : parse_table.states) { + line("[" + to_string(state_id) + "] = "); + add(to_string(add_in_progress_symbol_list_id(state.in_progress_symbols))); + add(","); + state_id++; + } + }); + + line("};"); + line(); + add_in_progress_symbols_list(); + line(); + } + void add_parser_export() { line("EXPORT_LANGUAGE(ts_language_" + name + ");"); line(); @@ -345,7 +368,6 @@ class CCodeGenerator { to_string(pair.second.size()) + "},"); for (const ParseAction &action : pair.second) { - index++; add(" "); switch (action.type) { case ParseActionTypeError: @@ -383,7 +405,27 @@ class CCodeGenerator { line("};"); } - size_t add_parse_actions(const vector &actions) { + void add_in_progress_symbols_list() { + line("static TSInProgressSymbolEntry ts_in_progress_symbols[] = {"); + + indent([&]() { + for (const auto &pair : in_progress_symbols) { + size_t index = pair.first; + line("[" + to_string(index) + "] = {.count = " + + to_string(pair.second.size()) + "},"); + + for (const rules::Symbol &symbol : pair.second) { + add(" "); + add("{" + symbol_id(symbol) + "}"); + add(","); + } + } + }); + + line("};"); + } + + size_t add_parse_action_list_id(const vector &actions) { for (const auto &pair : parse_actions) { if (pair.second == actions) { return pair.first; @@ -396,6 +438,19 @@ class CCodeGenerator { return result; } + size_t add_in_progress_symbol_list_id(const set &symbols) { + for (const auto &pair : in_progress_symbols) { + if (pair.second == symbols) { + return pair.first; + } + } + + size_t result = next_in_progress_symbol_list_index; + in_progress_symbols.push_back({ result, symbols }); + next_in_progress_symbol_list_index += 1 + symbols.size(); + return result; + } + void add_action_flags(const ParseAction &action) { if (action.fragile && action.can_hide_split) add("FRAGILE|CAN_HIDE_SPLIT"); diff --git a/src/compiler/parse_table.h b/src/compiler/parse_table.h index d3dba182..99232707 100644 --- a/src/compiler/parse_table.h +++ b/src/compiler/parse_table.h @@ -84,6 +84,7 @@ class ParseState { void each_advance_action(std::function); std::map> actions; + std::set in_progress_symbols; LexStateId lex_state_id; };