Refactor C code generator

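It's been rewritten in a less functional style: the generator now appends to a
single shared buffer instead of returning and joining intermediate strings.
Those string copies were actually taking significant time for large parsers.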
This commit is contained in:
Max Brunsfeld 2014-06-16 21:29:04 -07:00
parent 1daaf4485f
commit c312f985c8
3 changed files with 224 additions and 189 deletions

View file

@ -1,6 +1,7 @@
#include <map>
#include <set>
#include <string>
#include <functional>
#include <utility>
#include <vector>
#include "compiler/generate_code/c_code.h"
@ -11,45 +12,18 @@
namespace tree_sitter {
using std::string;
using std::to_string;
using std::function;
using std::map;
using std::vector;
using std::set;
using std::pair;
using util::join;
using util::indent;
using util::escape_char;
namespace generate_code {
// Builds the text of a C `switch` statement: a header line testing
// `condition`, then `body` passed through `indent`, then a closing brace,
// joined with `join`.
string _switch(string condition, string body) {
  const string header = "switch (" + condition + ") {";
  const string footer = "}";
  return join({ header, indent(body), footer });
}
// Builds the text of a `case <value>:` label followed by `body` passed
// through `indent`.
string _case(string value, string body) {
  const string label = "case " + value + ":";
  // The empty last element makes `join` append one more separator after the
  // body.
  return join({ label, indent(body), string() });
}
// Builds the text of a `default:` label followed by `body` passed through
// `indent`.
string _default(string body) {
  const string label = "default:";
  return join({ label, indent(body) });
}
// Builds the text of a brace-less `if (<condition>)` followed by `body`
// passed through `indent`.
string _if(string condition, string body) {
  const string header = "if (" + condition + ")";
  // The empty last element makes `join` append one more separator after the
  // body.
  return join({ header, indent(body), string() });
}
class CCodeGenerator {
string buffer;
size_t indent_level;
const string name;
const ParseTable parse_table;
const LexTable lex_table;
@ -63,6 +37,7 @@ namespace tree_sitter {
const LexTable &lex_table,
const PreparedGrammar &syntax_grammar,
const PreparedGrammar &lexical_grammar) :
indent_level(0),
name(name),
parse_table(parse_table),
lex_table(lex_table),
@ -70,21 +45,56 @@ namespace tree_sitter {
lexical_grammar(lexical_grammar) {}
string code() {
return join({
includes(),
state_and_symbol_counts(),
symbol_enum(),
symbol_names_list(),
ubiquitous_symbols_list(),
hidden_symbols_list(),
lex_function(),
lex_states_list(),
parse_table_array(),
parser_export(),
}, "\n\n") + "\n";
buffer = "";
includes();
state_and_symbol_counts();
symbol_enum();
symbol_names_list();
ubiquitous_symbols_list();
hidden_symbols_list();
lex_function();
lex_states_list();
parse_table_array();
parser_export();
return buffer;
}
private:
void _switch(string condition, function<void()> body) {
line("switch (" + condition + ") {");
indent(body);
line("}");
}
void _case(string value, function<void()> body) {
line("case " + value + ":");
indent(body);
}
void _default(function<void()> body) {
line("default:");
indent(body);
}
void _if(function<void()> condition, function<void()> body) {
line("if (");
indent(condition);
add(")");
indent(body);
}
void indent(function<void()> body) {
indent(body, 1);
}
void indent(function<void()> body, size_t n) {
indent_level += n;
body();
indent_level -= n;
}
// Selects the grammar a symbol belongs to: the lexical grammar when
// `symbol.is_token()` is true, the syntax grammar otherwise.
const PreparedGrammar & grammar_for_symbol(const rules::Symbol &symbol) {
  if (symbol.is_token())
    return lexical_grammar;
  return syntax_grammar;
}
@ -161,177 +171,226 @@ namespace tree_sitter {
}
}
string condition_for_character_set(const rules::CharacterSet &set) {
vector<string> parts;
void condition_for_character_set(const rules::CharacterSet &set) {
if (set.ranges.size() == 1) {
return condition_for_character_range(*set.ranges.begin());
add(condition_for_character_range(*set.ranges.begin()));
} else {
for (auto &match : set.ranges)
parts.push_back("(" + condition_for_character_range(match) + ")");
return join(parts, " ||\n ");
bool first = true;
for (auto &match : set.ranges) {
string part = "(" + condition_for_character_range(match) + ")";
if (first) {
add(part);
} else {
add(" ||");
line(part);
}
first = false;
}
}
}
string condition_for_character_rule(const rules::CharacterSet &rule) {
vector<string> parts;
void condition_for_character_rule(const rules::CharacterSet &rule) {
pair<rules::CharacterSet, bool> representation = rule.most_compact_representation();
if (representation.second)
return condition_for_character_set(representation.first);
else
return "!(" + condition_for_character_set(rule.complement()) + ")";
if (representation.second) {
condition_for_character_set(representation.first);
} else {
add("!(");
condition_for_character_set(rule.complement());
add(")");
}
}
string code_for_parse_action(const ParseAction &action) {
void code_for_parse_action(const ParseAction &action) {
switch (action.type) {
case ParseActionTypeAccept:
return "ACCEPT_INPUT()";
add("ACCEPT_INPUT()");
break;
case ParseActionTypeShift:
return "SHIFT(" + to_string(action.state_index) + ")";
add("SHIFT(" + to_string(action.state_index) + ")");
break;
case ParseActionTypeReduce:
return "REDUCE(" +
add("REDUCE(" +
symbol_id(action.symbol) + ", " +
to_string(action.consumed_symbol_count) + ")";
default:
return "";
to_string(action.consumed_symbol_count) + ")");
break;
default:;
}
}
string code_for_lex_actions(const LexAction &action,
void code_for_lex_actions(const LexAction &action,
const set<rules::CharacterSet> &expected_inputs) {
switch (action.type) {
case LexActionTypeAdvance:
return "ADVANCE(" + lex_state_index(action.state_index) + ");";
line("ADVANCE(" + lex_state_index(action.state_index) + ");");
break;
case LexActionTypeAccept:
return "ACCEPT_TOKEN(" + symbol_id(action.symbol) + ");";
line("ACCEPT_TOKEN(" + symbol_id(action.symbol) + ");");
break;
case LexActionTypeError:
return "LEX_ERROR();";
default:
return "";
line("LEX_ERROR();");
break;
default:;
}
}
string code_for_lex_state(const LexState &lex_state) {
string result = "";
void code_for_lex_state(const LexState &lex_state) {
auto expected_inputs = lex_state.expected_inputs();
if (lex_state.is_token_start)
result += "START_TOKEN();" "\n";
line("START_TOKEN();");
for (auto pair : lex_state.actions)
if (!pair.first.is_empty())
result += _if(condition_for_character_rule(pair.first),
code_for_lex_actions(pair.second, expected_inputs));
result += code_for_lex_actions(lex_state.default_action, expected_inputs);
return result;
_if([&]() { condition_for_character_rule(pair.first); },
[&]() { code_for_lex_actions(pair.second, expected_inputs); });
code_for_lex_actions(lex_state.default_action, expected_inputs);
}
string switch_on_lex_state() {
string body = "";
for (size_t i = 0; i < lex_table.states.size(); i++)
body += _case(lex_state_index(i), code_for_lex_state(lex_table.states[i]));
body += _case("ts_lex_state_error", code_for_lex_state(lex_table.error_state));
body += _default("LEX_PANIC();");
return _switch("lex_state", body);
}
string state_and_symbol_counts() {
return join({
"#define STATE_COUNT " + to_string(parse_table.states.size()),
"#define SYMBOL_COUNT " + to_string(parse_table.symbols.size())
void switch_on_lex_state() {
_switch("lex_state", [&]() {
for (size_t i = 0; i < lex_table.states.size(); i++)
_case(lex_state_index(i), [&]() {
code_for_lex_state(lex_table.states[i]);
});
_case("ts_lex_state_error", [&]() {
code_for_lex_state(lex_table.error_state);
});
_default([&]() {
line("LEX_PANIC();");
});
});
}
string symbol_enum() {
string result = "enum {\n";
bool at_start = true;
for (auto symbol : parse_table.symbols)
if (!symbol.is_built_in()) {
if (at_start)
result += indent(symbol_id(symbol)) + " = ts_start_sym,\n";
else
result += indent(symbol_id(symbol)) + ",\n";
at_start = false;
}
return result + "};";
void state_and_symbol_counts() {
line("#define STATE_COUNT " + to_string(parse_table.states.size()));
line("#define SYMBOL_COUNT " + to_string(parse_table.symbols.size()));
line();
}
string symbol_names_list() {
void symbol_enum() {
line("enum {");
indent([&]() {
bool at_start = true;
for (auto symbol : parse_table.symbols)
if (!symbol.is_built_in()) {
if (at_start)
line(symbol_id(symbol) + " = ts_start_sym,");
else
line(symbol_id(symbol) + ",");
at_start = false;
}
});
line("};");
line();
}
void symbol_names_list() {
set<rules::Symbol> symbols(parse_table.symbols);
symbols.insert(rules::END_OF_INPUT());
symbols.insert(rules::ERROR());
string result = "SYMBOL_NAMES = {\n";
for (auto symbol : parse_table.symbols)
result += indent("[" + symbol_id(symbol) + "] = \"" + symbol_name(symbol)) + "\",\n";
return result + "};";
}
string ubiquitous_symbols_list() {
string result = "UBIQUITOUS_SYMBOLS = {\n";
for (auto &symbol : syntax_grammar.ubiquitous_tokens())
result += indent("[" + symbol_id(symbol) + "] = 1,") + "\n";
return result + "};";
}
string hidden_symbols_list() {
string result = "HIDDEN_SYMBOLS = {\n";
for (auto &symbol : parse_table.symbols)
if (!symbol.is_built_in() && (symbol.is_auxiliary() || grammar_for_symbol(symbol).rule_name(symbol)[0] == '_'))
result += indent("[" + symbol_id(symbol) + "] = 1,") + "\n";
return result + "};";
}
string includes() {
return "#include \"tree_sitter/parser.h\"";
}
string lex_function() {
return join({
"LEX_FN() {",
indent("START_LEXER();"),
indent(switch_on_lex_state()),
"}"
line("SYMBOL_NAMES = {");
indent([&]() {
for (auto symbol : parse_table.symbols)
line("[" + symbol_id(symbol) + "] = \"" + symbol_name(symbol) + "\",");
});
line("};");
line();
}
template<typename T>
vector<string> map_to_string(const vector<T> &inputs, std::function<string(T)> f) {
vector<string> result;
for (auto &item : inputs)
result.push_back(f(item));
return result;
void ubiquitous_symbols_list() {
line("UBIQUITOUS_SYMBOLS = {");
indent([&]() {
for (auto &symbol : syntax_grammar.ubiquitous_tokens())
line("[" + symbol_id(symbol) + "] = 1,");
});
line("};");
line();
}
string lex_states_list() {
void hidden_symbols_list() {
line("HIDDEN_SYMBOLS = {");
indent([&]() {
for (auto &symbol : parse_table.symbols)
if (!symbol.is_built_in() && (symbol.is_auxiliary() || grammar_for_symbol(symbol).rule_name(symbol)[0] == '_'))
line("[" + symbol_id(symbol) + "] = 1,");
});
line("};");
line();
}
void includes() {
add("#include \"tree_sitter/parser.h\"");
line();
}
void lex_function() {
line("LEX_FN() {");
indent([&]() {
line("START_LEXER();");
switch_on_lex_state();
});
line("}");
line();
}
void lex_states_list() {
line("LEX_STATES = {");
indent([&]() {
size_t state_id = 0;
for (auto &state : parse_table.states)
line("[" + to_string(state_id++) + "] = " + lex_state_index(state.lex_state_id) + ",");
});
line("};");
line();
}
void parse_table_array() {
size_t state_id = 0;
return join({
"LEX_STATES = {",
indent(join(map_to_string<ParseState>(parse_table.states, [&](ParseState state) {
return "[" + to_string(state_id++) + "] = " + lex_state_index(state.lex_state_id) + ",";
}))),
"};"
line("#pragma GCC diagnostic push");
line("#pragma GCC diagnostic ignored \"-Wmissing-field-initializers\"");
line();
line("PARSE_TABLE = {");
indent([&]() {
for (auto &state : parse_table.states) {
line("[" + to_string(state_id++) + "] = {");
indent([&]() {
for (auto &pair : state.actions) {
line("[" + symbol_id(pair.first) + "] = ");
code_for_parse_action(pair.second);
add(",");
}
});
line("},");
}
});
line("};");
line();
line("#pragma GCC diagnostic pop");
line();
}
string parse_table_array() {
size_t state_id = 0;
return join({
"#pragma GCC diagnostic push",
"#pragma GCC diagnostic ignored \"-Wmissing-field-initializers\"",
"",
"PARSE_TABLE = {",
indent(join(map_to_string<ParseState>(parse_table.states, [&](ParseState state) {
string result = "[" + to_string(state_id++) + "] = {\n";
for (auto &pair : state.actions)
result += indent("[" + symbol_id(pair.first) + "] = " + code_for_parse_action(pair.second) + ",") + "\n";
return result + "},";
}), "\n")),
"};",
"",
"#pragma GCC diagnostic pop"
});
void parser_export() {
line("EXPORT_PARSER(ts_parser_" + name + ");");
line();
}
string parser_export() {
return "EXPORT_PARSER(ts_parser_" + name + ");";
void line() {
line("");
}
void line(string input) {
add("\n");
if (!input.empty()) {
string space;
for (size_t i = 0; i < indent_level; i++)
space += " ";
add(space + input);
}
}
void add(string input) {
buffer += input;
}
};

View file

@ -23,7 +23,7 @@ namespace tree_sitter {
str_replace(&input, "\n", "\\n");
return input;
}
string escape_char(char character) {
switch (character) {
case '\0':
@ -44,26 +44,5 @@ namespace tree_sitter {
return string() + character;
}
}
// Concatenates `lines`, inserting `separator` between consecutive elements.
// Returns an empty string for an empty vector; no separator is added before
// the first or after the last element.
//
// The signature stays by-value to match the header declaration; the loop
// iterates by const reference so each element is no longer copied.
string join(vector<string> lines, string separator) {
  string result;
  bool started = false;
  for (const auto &line : lines) {
    if (started) result += separator;
    started = true;
    result += line;
  }
  return result;
}
// Joins `lines` with a newline between consecutive elements.
string join(vector<string> lines) {
  const string newline = "\n";
  return join(lines, newline);
}
// Prefixes `input` with one indentation unit and uses `util::str_replace` to
// put the same unit after every newline in it.
string indent(string input) {
  const string unit = " ";
  const string newline = "\n";
  util::str_replace(&input, newline, newline + unit);
  return unit + input;
}
}
}

View file

@ -10,9 +10,6 @@ namespace tree_sitter {
void str_replace(std::string *input, const std::string &search, const std::string &replace);
std::string escape_string(std::string input);
std::string escape_char(char character);
std::string indent(std::string input);
std::string join(std::vector<std::string> lines, std::string separator);
std::string join(std::vector<std::string> lines);
}
}