diff --git a/spec/compiler/generate_parsers.cpp b/spec/compiler/generate_parsers.cpp index f3fbfa86..99843d41 100644 --- a/spec/compiler/generate_parsers.cpp +++ b/spec/compiler/generate_parsers.cpp @@ -1,6 +1,7 @@ #include "spec_helper.h" #include "table_builder.h" #include "parse_table.h" +#include "prepare_grammar.h" #include "c_code.h" #include @@ -11,9 +12,13 @@ describe("code generation", []() { it("works for the arithmetic grammar", [&]() { Grammar grammar = test_grammars::arithmetic(); - auto tables = lr::build_tables(grammar); - string code = code_gen::c_code(grammar, tables.first, tables.second); - std::ofstream(test_parser_dir + "/arithmetic.c") << code; + auto grammars = prepare_grammar(grammar); + auto tables = lr::build_tables(grammars.first, grammars.second); + auto rule_names = grammars.first.rule_names(); + auto token_names = grammars.second.rule_names(); + rule_names.insert(rule_names.end(), token_names.begin(), token_names.end()); + auto code = code_gen::c_code(rule_names, tables.first, tables.second); + ofstream(test_parser_dir + "/arithmetic.c") << code; }); }); diff --git a/spec/compiler/lr/item_set_spec.cpp b/spec/compiler/lr/item_set_spec.cpp deleted file mode 100644 index 89b86927..00000000 --- a/spec/compiler/lr/item_set_spec.cpp +++ /dev/null @@ -1,68 +0,0 @@ -#include "spec_helper.h" -#include - -using namespace tree_sitter::lr; -using namespace tree_sitter::rules; - -static item_set_ptr item_set(const std::initializer_list &items) { - return item_set_ptr(new ItemSet(items)); -} - -START_TEST - -describe("item sets", []() { - Grammar grammar = test_grammars::arithmetic(); - - it("computes the closure of an item set under symbol expansion", [&]() { - Item item = Item::at_beginning_of_rule("expression", grammar); - ItemSet set = ItemSet(item, grammar); - - AssertThat( - set, - EqualsContainer(ItemSet({ - Item("expression", grammar.rule("expression"), 0), - Item("term", grammar.rule("term"), 0), - Item("factor", grammar.rule("factor"), 0), - Item("variable", grammar.rule("variable"), 0), - Item("number", grammar.rule("number"), 0), - Item("left_paren", grammar.rule("left_paren"), 0), - }))); - }); - - it("computes transitions", [&]() { - Item item = Item::at_beginning_of_rule("factor", grammar); - ItemSet set = ItemSet(item, grammar); - - AssertThat( - set.sym_transitions(grammar), - Equals(transition_map({ - { sym("variable"), item_set({ Item("factor", blank(), 1) }) }, - { sym("number"), item_set({ Item("factor", blank(), 1) }) }, - { sym("left_paren"), std::make_shared(Item("factor", seq({ sym("expression"), sym("right_paren") }), 1), grammar) }, - }))); - }); - - it("computes character transitions", [&]() { - Item item = Item::at_beginning_of_rule("factor", grammar); - ItemSet set = ItemSet(item, grammar); - - AssertThat( - set.char_transitions(grammar), - Equals(transition_map({ - { character(CharClassWord), item_set({ Item("variable", choice({ repeat(character(CharClassWord)), blank() }), 1) }) }, - { character(CharClassDigit), item_set({ Item("number", choice({ repeat(character(CharClassDigit)), blank() }), 1) }) }, - { character('('), item_set({ Item("left_paren", blank(), 1) }) } - }))); - }); - - it("can be hashed", [&]() { - ItemSet set1 = ItemSet(Item::at_beginning_of_rule("factor", grammar), grammar); - ItemSet set2 = ItemSet(Item::at_beginning_of_rule("factor", grammar), grammar); - AssertThat(std::hash()(set1), Equals(std::hash()(set2))); - - ItemSet set3 = ItemSet(Item::at_beginning_of_rule("term", grammar), grammar); - AssertThat(std::hash()(set1), !Equals(std::hash()(set3))); - }); -}); - -END_TEST diff --git a/spec/compiler/lr/table_builder_spec.cpp b/spec/compiler/lr/table_builder_spec.cpp index 5e16d8d0..fc093764 100644 --- a/spec/compiler/lr/table_builder_spec.cpp +++ b/spec/compiler/lr/table_builder_spec.cpp @@ -2,15 +2,44 @@ #include using namespace tree_sitter::lr; +using namespace tree_sitter::rules; + typedef std::unordered_set parse_actions; typedef std::unordered_set lex_actions; START_TEST describe("building parse and lex tables", []() { - Grammar grammar = test_grammars::arithmetic(); - ParseTable table = build_tables(grammar).first; - LexTable lex_table = build_tables(grammar).second; + Grammar grammar({ + { "expression", choice({ + seq({ + sym("term"), + token("plus-token"), + sym("term") }), + sym("term") }) }, + { "term", choice({ + sym("variable"), + sym("number"), + seq({ + token("left-paren-token"), + sym("expression"), + token("right-paren-token") + }) }) }, + { "variable", token("variable-token") }, + { "number", token("number-token") } + }); + + Grammar lex_grammar({ + { "plus-token", character('+') }, + { "variable-token", pattern("\\w+") }, + { "number-token", pattern("\\d+") }, + { "left-paren-token", character('(') }, + { "right-paren-token", character(')') } + }); + + pair tables = build_tables(grammar, lex_grammar); + ParseTable table = tables.first; + LexTable lex_table = tables.second; function parse_state = [&](size_t index) { return table.states[index]; @@ -25,16 +54,18 @@ describe("building parse and lex tables", []() { AssertThat(parse_state(0).actions, Equals(unordered_map({ { "expression", parse_actions({ ParseAction::Shift(1) }) }, { "term", parse_actions({ ParseAction::Shift(2) }) }, - { "factor", parse_actions({ ParseAction::Shift(5) }) }, - { "variable", parse_actions({ ParseAction::Shift(8) }) }, - { "number", parse_actions({ ParseAction::Shift(8) }) }, - { "left_paren", parse_actions({ ParseAction::Shift(9) }) } + { "number", parse_actions({ ParseAction::Shift(5) }) }, + { "variable", parse_actions({ ParseAction::Shift(5) }) }, + + { "left-paren-token", parse_actions({ ParseAction::Shift(6) }) }, + { "variable-token", parse_actions({ ParseAction::Shift(9) }) }, + { "number-token", parse_actions({ ParseAction::Shift(10) }) }, }))); AssertThat(lex_state(0).actions, Equals(unordered_map({ - { CharMatchClass(CharClassWord), lex_actions({ LexAction::Advance(1) }) }, - { CharMatchClass(CharClassDigit), lex_actions({ LexAction::Advance(4) }) }, - { CharMatchSpecific('('), lex_actions({ LexAction::Advance(11) }) } + { CharMatchSpecific('('), lex_actions({ LexAction::Advance(1) }) }, + { CharMatchClass(CharClassWord), lex_actions({ LexAction::Advance(2) }) }, + { CharMatchClass(CharClassDigit), lex_actions({ LexAction::Advance(3) }) }, }))); }); @@ -46,15 +77,7 @@ describe("building parse and lex tables", []() { it("has the right next states", [&]() { AssertThat(parse_state(2).actions, Equals(unordered_map({ - { "plus", parse_actions({ ParseAction::Shift(3) }) }, - }))); - - AssertThat(parse_state(3).actions, Equals(unordered_map({ - { "variable", parse_actions({ ParseAction::Shift(8) }) }, - { "factor", parse_actions({ ParseAction::Shift(5) }) }, - { "left_paren", parse_actions({ ParseAction::Shift(9) }) }, - { "number", parse_actions({ ParseAction::Shift(8) }) }, - { "term", parse_actions({ ParseAction::Shift(4) }) }, + { "plus-token", parse_actions({ ParseAction::Shift(3) }) }, }))); }); }); diff --git a/spec/compiler/spec_helper.h b/spec/compiler/spec_helper.h index 7c59a241..548f4df3 100644 --- a/spec/compiler/spec_helper.h +++ b/spec/compiler/spec_helper.h @@ -2,6 +2,7 @@ #define TreeSitter_SpecHelper_h #include "bandit/bandit.h" +#include #include "transition_map.h" #include "rules.h" #include "item.h" diff --git a/spec/fixtures/grammars/arithmetic.cpp b/spec/fixtures/grammars/arithmetic.cpp index 837e647f..3453d281 100644 --- a/spec/fixtures/grammars/arithmetic.cpp +++ b/spec/fixtures/grammars/arithmetic.cpp @@ -10,28 +10,24 @@ namespace test_grammars { { "expression", choice({ seq({ sym("term"), - sym("plus"), + character('+'), sym("term") }), sym("term") }) }, { "term", choice({ seq({ sym("factor"), - sym("times"), + character('*'), sym("factor") }), sym("factor") }) }, { "factor", choice({ sym("variable"), sym("number"), seq({ - sym("left_paren"), + character('('), sym("expression"), - sym("right_paren") }) }) }, + character(')') }) }) }, { "number", pattern("\\d+") }, { "variable", pattern("\\w+") }, - { "plus", str("+") }, - { "times", str("*") }, - { "left_paren", str("(") }, - { "right_paren", str(")") } }); } } diff --git a/spec/fixtures/parsers/arithmetic.c b/spec/fixtures/parsers/arithmetic.c index f77cdd19..1b634782 100644 --- a/spec/fixtures/parsers/arithmetic.c +++ b/spec/fixtures/parsers/arithmetic.c @@ -4,13 +4,15 @@ typedef enum { ts_symbol_expression, ts_symbol_term, - ts_symbol_right_paren, ts_symbol_number, ts_symbol_factor, ts_symbol_variable, - ts_symbol_plus, - ts_symbol_times, - ts_symbol_left_paren, + ts_symbol_6, + ts_symbol_5, + ts_symbol_4, + ts_symbol_3, + ts_symbol_2, + ts_symbol_1, ts_symbol___END__ } ts_symbol; @@ -18,107 +20,43 @@ static void ts_lex(TSParser *parser) { START_LEXER(); switch (LEX_STATE()) { case 0: - if (LOOKAHEAD_CHAR() == '(') - ADVANCE(11); - if (isdigit(LOOKAHEAD_CHAR())) - ADVANCE(4); if (isalnum(LOOKAHEAD_CHAR())) + ADVANCE(2); + if (isdigit(LOOKAHEAD_CHAR())) + ADVANCE(3); + if (LOOKAHEAD_CHAR() == '(') ADVANCE(1); LEX_ERROR(); case 1: - if (isalnum(LOOKAHEAD_CHAR())) - ADVANCE(2); - LEX_ERROR(); + ACCEPT_TOKEN(ts_symbol_2); case 2: if (isalnum(LOOKAHEAD_CHAR())) - ADVANCE(3); - LEX_ERROR(); + ADVANCE(2); + ACCEPT_TOKEN(ts_symbol_1); case 3: - if (isalnum(LOOKAHEAD_CHAR())) - ADVANCE(1); - LEX_ERROR(); - case 4: if (isdigit(LOOKAHEAD_CHAR())) - ADVANCE(5); + ADVANCE(3); + ACCEPT_TOKEN(ts_symbol_4); + case 4: LEX_ERROR(); case 5: - if (isdigit(LOOKAHEAD_CHAR())) + if (LOOKAHEAD_CHAR() == '+') ADVANCE(6); LEX_ERROR(); case 6: - if (isdigit(LOOKAHEAD_CHAR())) - ADVANCE(7); - LEX_ERROR(); + ACCEPT_TOKEN(ts_symbol_6); case 7: - if (isdigit(LOOKAHEAD_CHAR())) + if (LOOKAHEAD_CHAR() == '*') ADVANCE(8); LEX_ERROR(); case 8: - if (isdigit(LOOKAHEAD_CHAR())) - ADVANCE(9); - LEX_ERROR(); + ACCEPT_TOKEN(ts_symbol_5); case 9: - if (isdigit(LOOKAHEAD_CHAR())) + if (LOOKAHEAD_CHAR() == ')') ADVANCE(10); LEX_ERROR(); case 10: - if (isdigit(LOOKAHEAD_CHAR())) - ADVANCE(8); - LEX_ERROR(); - case 11: - ACCEPT_TOKEN(ts_symbol_left_paren); - case 12: - ACCEPT_TOKEN(ts_symbol___START__); - case 13: - if (LOOKAHEAD_CHAR() == '+') - ADVANCE(14); - LEX_ERROR(); - case 14: - ACCEPT_TOKEN(ts_symbol_plus); - case 15: - if (LOOKAHEAD_CHAR() == '(') - ADVANCE(11); - if (isdigit(LOOKAHEAD_CHAR())) - ADVANCE(4); - if (isalnum(LOOKAHEAD_CHAR())) - ADVANCE(1); - LEX_ERROR(); - case 16: - ACCEPT_TOKEN(ts_symbol_expression); - case 17: - if (LOOKAHEAD_CHAR() == '*') - ADVANCE(18); - LEX_ERROR(); - case 18: - ACCEPT_TOKEN(ts_symbol_times); - case 19: - if (LOOKAHEAD_CHAR() == '(') - ADVANCE(11); - if (isdigit(LOOKAHEAD_CHAR())) - ADVANCE(4); - if (isalnum(LOOKAHEAD_CHAR())) - ADVANCE(1); - LEX_ERROR(); - case 20: - ACCEPT_TOKEN(ts_symbol_term); - case 21: - ACCEPT_TOKEN(ts_symbol_factor); - case 22: - if (LOOKAHEAD_CHAR() == '(') - ADVANCE(11); - if (isdigit(LOOKAHEAD_CHAR())) - ADVANCE(4); - if (isalnum(LOOKAHEAD_CHAR())) - ADVANCE(1); - LEX_ERROR(); - case 23: - if (LOOKAHEAD_CHAR() == ')') - ADVANCE(24); - LEX_ERROR(); - case 24: - ACCEPT_TOKEN(ts_symbol_right_paren); - case 25: - ACCEPT_TOKEN(ts_symbol_factor); + ACCEPT_TOKEN(ts_symbol_3); default: LEX_ERROR(); } @@ -131,14 +69,18 @@ TSTree ts_parse_arithmetic(const char *input) { case 0: SET_LEX_STATE(0); switch (LOOKAHEAD_SYM()) { - case ts_symbol_left_paren: - SHIFT(9); + case ts_symbol_4: + SHIFT(13); case ts_symbol_variable: SHIFT(8); case ts_symbol_factor: SHIFT(5); case ts_symbol_number: SHIFT(8); + case ts_symbol_1: + SHIFT(12); + case ts_symbol_2: + SHIFT(9); case ts_symbol_term: SHIFT(2); case ts_symbol_expression: @@ -147,7 +89,7 @@ TSTree ts_parse_arithmetic(const char *input) { PARSE_ERROR(); } case 1: - SET_LEX_STATE(12); + SET_LEX_STATE(4); switch (LOOKAHEAD_SYM()) { case ts_symbol___END__: ACCEPT_INPUT(); @@ -155,48 +97,56 @@ TSTree ts_parse_arithmetic(const char *input) { PARSE_ERROR(); } case 2: - SET_LEX_STATE(13); + SET_LEX_STATE(5); switch (LOOKAHEAD_SYM()) { - case ts_symbol_plus: + case ts_symbol_6: SHIFT(3); default: - PARSE_ERROR(); + REDUCE(ts_symbol_expression, 1); } case 3: - SET_LEX_STATE(15); + SET_LEX_STATE(0); switch (LOOKAHEAD_SYM()) { + case ts_symbol_4: + SHIFT(13); + case ts_symbol_1: + SHIFT(12); + case ts_symbol_term: + SHIFT(4); + case ts_symbol_2: + SHIFT(9); case ts_symbol_variable: SHIFT(8); case ts_symbol_factor: SHIFT(5); - case ts_symbol_left_paren: - SHIFT(9); case ts_symbol_number: SHIFT(8); - case ts_symbol_term: - SHIFT(4); default: PARSE_ERROR(); } case 4: - SET_LEX_STATE(16); + SET_LEX_STATE(4); switch (LOOKAHEAD_SYM()) { default: REDUCE(ts_symbol_expression, 3); } case 5: - SET_LEX_STATE(17); + SET_LEX_STATE(7); switch (LOOKAHEAD_SYM()) { - case ts_symbol_times: + case ts_symbol_5: SHIFT(6); default: - PARSE_ERROR(); + REDUCE(ts_symbol_term, 1); } case 6: - SET_LEX_STATE(19); + SET_LEX_STATE(0); switch (LOOKAHEAD_SYM()) { - case ts_symbol_left_paren: + case ts_symbol_4: + SHIFT(13); + case ts_symbol_2: SHIFT(9); + case ts_symbol_1: + SHIFT(12); case ts_symbol_number: SHIFT(8); case ts_symbol_variable: @@ -207,28 +157,32 @@ TSTree ts_parse_arithmetic(const char *input) { PARSE_ERROR(); } case 7: - SET_LEX_STATE(20); + SET_LEX_STATE(4); switch (LOOKAHEAD_SYM()) { default: REDUCE(ts_symbol_term, 3); } case 8: - SET_LEX_STATE(21); + SET_LEX_STATE(4); switch (LOOKAHEAD_SYM()) { default: REDUCE(ts_symbol_factor, 1); } case 9: - SET_LEX_STATE(22); + SET_LEX_STATE(0); switch (LOOKAHEAD_SYM()) { - case ts_symbol_left_paren: - SHIFT(9); + case ts_symbol_4: + SHIFT(13); case ts_symbol_variable: SHIFT(8); case ts_symbol_factor: SHIFT(5); case ts_symbol_number: SHIFT(8); + case ts_symbol_1: + SHIFT(12); + case ts_symbol_2: + SHIFT(9); case ts_symbol_term: SHIFT(2); case ts_symbol_expression: @@ -237,19 +191,31 @@ TSTree ts_parse_arithmetic(const char *input) { PARSE_ERROR(); } case 10: - SET_LEX_STATE(23); + SET_LEX_STATE(9); switch (LOOKAHEAD_SYM()) { - case ts_symbol_right_paren: + case ts_symbol_3: SHIFT(11); default: PARSE_ERROR(); } case 11: - SET_LEX_STATE(25); + SET_LEX_STATE(4); switch (LOOKAHEAD_SYM()) { default: REDUCE(ts_symbol_factor, 3); } + case 12: + SET_LEX_STATE(4); + switch (LOOKAHEAD_SYM()) { + default: + REDUCE(ts_symbol_variable, 1); + } + case 13: + SET_LEX_STATE(4); + switch (LOOKAHEAD_SYM()) { + default: + REDUCE(ts_symbol_number, 1); + } default: PARSE_ERROR(); } diff --git a/spec/main.cpp b/spec/main.cpp index 4314b094..1562a5fd 100644 --- a/spec/main.cpp +++ b/spec/main.cpp @@ -2,6 +2,11 @@ int main(int argc, char *argv[]) { - char *args[] = {nullptr, (char *)"--no-color"}; - return bandit::run(2, args); + const char *args[] = { + "", + "--no-color", + "--only=" + "", + }; + return bandit::run(4, const_cast(args)); } \ No newline at end of file diff --git a/src/compiler/code_gen/c_code.cpp b/src/compiler/code_gen/c_code.cpp index 29dd9034..00699422 100644 --- a/src/compiler/code_gen/c_code.cpp +++ b/src/compiler/code_gen/c_code.cpp @@ -51,12 +51,12 @@ namespace tree_sitter { } class CCodeGenerator { - const Grammar grammar; + const vector rule_names; const ParseTable parse_table; const LexTable lex_table; public: - CCodeGenerator(const Grammar &grammar, const ParseTable &parse_table, const LexTable &lex_table) : - grammar(grammar), + CCodeGenerator(vector rule_names, const ParseTable &parse_table, const LexTable &lex_table) : + rule_names(rule_names), parse_table(parse_table), lex_table(lex_table) {} @@ -152,7 +152,7 @@ namespace tree_sitter { string symbol_enum() { string result = "typedef enum {\n"; - for (string rule_name : grammar.rule_names()) + for (string rule_name : rule_names) result += indent(symbol_id(rule_name)) + ",\n"; result += indent(symbol_id(ParseTable::END_OF_INPUT)); return result + "\n" @@ -192,8 +192,8 @@ namespace tree_sitter { } }; - string c_code(const Grammar &grammar, const ParseTable &parse_table, const LexTable &lex_table) { - return CCodeGenerator(grammar, parse_table, lex_table).code(); + string c_code(const vector rule_names, const ParseTable &parse_table, const LexTable &lex_table) { + return CCodeGenerator(rule_names, parse_table, lex_table).code(); } } } \ No newline at end of file diff --git a/src/compiler/code_gen/c_code.h b/src/compiler/code_gen/c_code.h index e4a44d51..f355b893 100644 --- a/src/compiler/code_gen/c_code.h +++ b/src/compiler/code_gen/c_code.h @@ -7,7 +7,7 @@ namespace tree_sitter { namespace code_gen { - std::string c_code(const Grammar &grammar, const lr::ParseTable &parse_table, const lr::LexTable &lex_table); + std::string c_code(std::vector rule_names, const lr::ParseTable &parse_table, const lr::LexTable &lex_table); } } diff --git a/src/compiler/lr/item.cpp b/src/compiler/lr/item.cpp index 233acfbd..11823ae3 100644 --- a/src/compiler/lr/item.cpp +++ b/src/compiler/lr/item.cpp @@ -16,17 +16,22 @@ namespace tree_sitter { Item Item::at_beginning_of_rule(const std::string &rule_name, const Grammar &grammar) { return Item(rule_name, grammar.rule(rule_name), 0); } - + + Item Item::at_beginning_of_token(const std::string &rule_name, const Grammar &grammar) { + return Item(rule_name, grammar.rule(rule_name), -1); + } + transition_map Item::transitions() const { return lr::transitions(rule).map([&](rules::rule_ptr to_rule) -> item_ptr { - return std::make_shared(rule_name, to_rule, consumed_sym_count + 1); + int next_sym_count = (consumed_sym_count == -1) ? -1 : (consumed_sym_count + 1); + return std::make_shared(rule_name, to_rule, next_sym_count); }); }; - vector Item::next_symbols() const { - vector result; + vector Item::next_symbols() const { + vector result; for (auto pair : lr::transitions(rule)) { - shared_ptr sym = dynamic_pointer_cast(pair.first); + auto sym = dynamic_pointer_cast(pair.first); if (sym) result.push_back(*sym); } return result; @@ -39,7 +44,10 @@ namespace tree_sitter { } bool Item::is_done() const { - return *rule == rules::Blank(); + for (auto pair : transitions()) { + if (*pair.first == rules::Blank()) return true; + } + return false; } std::ostream& operator<<(ostream &stream, const Item &item) { diff --git a/src/compiler/lr/item.h b/src/compiler/lr/item.h index e2e15036..6a8d6ffb 100644 --- a/src/compiler/lr/item.h +++ b/src/compiler/lr/item.h @@ -3,7 +3,7 @@ #include #include "rule.h" -#include "symbol.h" +#include "non_terminal.h" #include "transition_map.h" namespace tree_sitter { @@ -17,9 +17,10 @@ namespace tree_sitter { public: Item(const std::string &rule_name, const rules::rule_ptr rule, int consumed_sym_count); static Item at_beginning_of_rule(const std::string &rule_name, const Grammar &grammar); + static Item at_beginning_of_token(const std::string &rule_name, const Grammar &grammar); transition_map transitions() const; - std::vector next_symbols() const; + std::vector next_symbols() const; bool operator==(const Item &other) const; bool is_done() const; diff --git a/src/compiler/lr/item_set.cpp b/src/compiler/lr/item_set.cpp index 517aa99e..22129a46 100644 --- a/src/compiler/lr/item_set.cpp +++ b/src/compiler/lr/item_set.cpp @@ -18,7 +18,7 @@ namespace tree_sitter { static void add_item(vector &vector, const Item &item, const Grammar &grammar) { if (!vector_contains(vector, item)) { vector.push_back(item); - for (rules::Symbol rule : item.next_symbols()) { + for (rules::NonTerminal rule : item.next_symbols()) { Item next_item = Item::at_beginning_of_rule(rule.name, grammar); add_item(vector, next_item, grammar); } @@ -33,29 +33,17 @@ namespace tree_sitter { ItemSet::ItemSet(const Item &item, const Grammar &grammar) : contents(closure_in_grammar(item, grammar)) {} - template - static transition_map transitions(const ItemSet &item_set, const Grammar &grammar) { - transition_map result; - for (auto item : item_set) { + transition_map ItemSet::all_transitions(const Grammar &grammar) const { + transition_map result; + for (auto item : *this) { auto item_transitions = item.transitions(); for (auto pair : item_transitions) { - std::shared_ptr rule = dynamic_pointer_cast(pair.first); - Item item = *pair.second; - if (rule.get() != nullptr) - result.add(rule, std::make_shared(item, grammar)); + result.add(pair.first, std::make_shared(*pair.second, grammar)); } } return result; } - - transition_map ItemSet::char_transitions(const Grammar &grammar) const { - return transitions(*this, grammar); - } - transition_map ItemSet::sym_transitions(const Grammar &grammar) const { - return transitions(*this, grammar); - } - bool ItemSet::operator==(const tree_sitter::lr::ItemSet &other) const { return contents == other.contents; } diff --git a/src/compiler/lr/item_set.h b/src/compiler/lr/item_set.h index a662282a..df6eae7f 100644 --- a/src/compiler/lr/item_set.h +++ b/src/compiler/lr/item_set.h @@ -21,9 +21,26 @@ namespace tree_sitter { const_iterator begin() const; const_iterator end() const; size_t size() const; + + transition_map all_transitions(const Grammar &grammar) const; - transition_map char_transitions(const Grammar &grammar) const; - transition_map sym_transitions(const Grammar &grammar) const; + template + transition_map transitions(const Grammar &grammar) const { + transition_map result; + for (auto transition : all_transitions(grammar)) { + auto rule = std::dynamic_pointer_cast(transition.first); + if (rule.get()) result.add(rule, transition.second); + } + return result; + } + + template + std::vector next_inputs(const Grammar &grammar) const { + std::vector result; + for (auto pair : transitions(grammar)) + result.push_back(*pair.first); + return result; + } bool operator==(const ItemSet &other) const; }; diff --git a/src/compiler/lr/lex_table.cpp b/src/compiler/lr/lex_table.cpp index 89af6974..4614f28c 100644 --- a/src/compiler/lr/lex_table.cpp +++ b/src/compiler/lr/lex_table.cpp @@ -49,8 +49,6 @@ namespace tree_sitter { LexState::LexState() : actions(unordered_map>()) {} // Table - LexTable::LexTable(vector rule_names) : symbol_names(rule_names) {} - size_t LexTable::add_state() { states.push_back(LexState()); return states.size() - 1; diff --git a/src/compiler/lr/lex_table.h b/src/compiler/lr/lex_table.h index d5444d45..5982af0e 100644 --- a/src/compiler/lr/lex_table.h +++ b/src/compiler/lr/lex_table.h @@ -55,8 +55,6 @@ namespace tree_sitter { class LexTable { public: - LexTable(std::vector rule_names); - size_t add_state(); void add_action(size_t state_index, CharMatch match, LexAction action); void add_default_action(size_t state_index, LexAction action); @@ -64,7 +62,6 @@ namespace tree_sitter { static const std::string START; static const std::string END_OF_INPUT; std::vector states; - const std::vector symbol_names; }; } } diff --git a/src/compiler/lr/parse_table.cpp b/src/compiler/lr/parse_table.cpp index fdb523ea..3e2d8c2d 100644 --- a/src/compiler/lr/parse_table.cpp +++ b/src/compiler/lr/parse_table.cpp @@ -55,10 +55,6 @@ namespace tree_sitter { {} // Table - ParseTable::ParseTable(vector symbol_names) : - symbol_names(symbol_names), - states(vector()) {}; - size_t ParseTable::add_state() { states.push_back(ParseState()); return states.size() - 1; diff --git a/src/compiler/lr/parse_table.h b/src/compiler/lr/parse_table.h index 805800e8..15f7e83f 100644 --- a/src/compiler/lr/parse_table.h +++ b/src/compiler/lr/parse_table.h @@ -59,8 +59,6 @@ namespace tree_sitter { class ParseTable { public: - ParseTable(std::vector rule_names); - size_t add_state(); void add_action(size_t state_index, std::string symbol_name, ParseAction action); void add_default_action(size_t state_index, ParseAction action); @@ -68,7 +66,6 @@ namespace tree_sitter { static const std::string START; static const std::string END_OF_INPUT; std::vector states; - const std::vector symbol_names; }; } } diff --git a/src/compiler/lr/table_builder.cpp b/src/compiler/lr/table_builder.cpp index 318800f3..121a4176 100644 --- a/src/compiler/lr/table_builder.cpp +++ b/src/compiler/lr/table_builder.cpp @@ -5,6 +5,8 @@ #include "item_set.h" #include "grammar.h" +#include + using namespace std; namespace tree_sitter { @@ -13,6 +15,7 @@ namespace tree_sitter { class TableBuilder { const Grammar grammar; + const Grammar lex_grammar; std::unordered_map parse_state_indices; std::unordered_map lex_state_indices; ParseTable parse_table; @@ -29,7 +32,7 @@ namespace tree_sitter { } void add_shift_actions(const ItemSet &item_set, size_t state_index) { - for (auto transition : item_set.sym_transitions(grammar)) { + for (auto transition : item_set.transitions(grammar)) { rules::Symbol symbol = *transition.first; ItemSet item_set = *transition.second; size_t new_state_index = add_parse_state(item_set); @@ -38,7 +41,7 @@ namespace tree_sitter { } void add_advance_actions(const ItemSet &item_set, size_t state_index) { - for (auto transition : item_set.char_transitions(grammar)) { + for (auto transition : item_set.transitions(grammar)) { rules::Character rule = *transition.first; ItemSet item_set = *transition.second; size_t new_state_index = add_lex_state(item_set); @@ -77,13 +80,21 @@ namespace tree_sitter { return state_index; } + ItemSet lex_item_set_for_parse_item_set(const ItemSet &parse_item_set) { + vector items; + for (rules::Token token : parse_item_set.next_inputs(grammar)) + items.push_back(Item::at_beginning_of_token(token.name, lex_grammar)); + return ItemSet(items); + } + size_t add_parse_state(const ItemSet &item_set) { auto state_index = parse_state_index_for_item_set(item_set); if (state_index == NOT_FOUND) { state_index = parse_table.add_state(); parse_state_indices[item_set] = state_index; - parse_table.states[state_index].lex_state_index = add_lex_state(item_set); + ItemSet lex_item_set = lex_item_set_for_parse_item_set(item_set); + parse_table.states[state_index].lex_state_index = add_lex_state(lex_item_set); add_shift_actions(item_set, state_index); add_reduce_actions(item_set, state_index); } @@ -92,13 +103,9 @@ namespace tree_sitter { public: - TableBuilder(const Grammar &grammar) : + TableBuilder(const Grammar &grammar, const Grammar &lex_grammar) : grammar(grammar), - parse_table(ParseTable(grammar.rule_names())), - lex_table(LexTable(grammar.rule_names())), - parse_state_indices(unordered_map()), - lex_state_indices(unordered_map()) - {}; + lex_grammar(lex_grammar) {}; std::pair build() { auto item = Item(ParseTable::START, rules::sym(grammar.start_rule_name), 0); @@ -108,8 +115,8 @@ namespace tree_sitter { } }; - std::pair build_tables(const tree_sitter::Grammar &grammar) { - return TableBuilder(grammar).build(); + std::pair build_tables(const Grammar &grammar, const Grammar &lex_grammar) { + return TableBuilder(grammar, lex_grammar).build(); } } } \ No newline at end of file diff --git a/src/compiler/lr/table_builder.h b/src/compiler/lr/table_builder.h index cdb934fb..7f688fa5 100644 --- a/src/compiler/lr/table_builder.h +++ b/src/compiler/lr/table_builder.h @@ -8,7 +8,7 @@ namespace tree_sitter { class Grammar; namespace lr { - std::pair build_tables(const Grammar &grammar); + std::pair build_tables(const Grammar &grammar, const Grammar &lex_grammar); } } diff --git a/src/compiler/lr/transitions.cpp b/src/compiler/lr/transitions.cpp index 0a41039c..1b98465b 100644 --- a/src/compiler/lr/transitions.cpp +++ b/src/compiler/lr/transitions.cpp @@ -10,7 +10,7 @@ namespace tree_sitter { transition_map value; void visit(const Blank *rule) { - value = transition_map(); + value = transition_map({{ blank(), blank() }}); } void visit(const Character *rule) { @@ -20,7 +20,11 @@ namespace tree_sitter { void visit(const Symbol *rule) { value = transition_map({{ rule->copy(), blank() }}); } - + + void visit(const Token *rule) { + value = transition_map({{ rule->copy(), blank() }}); + } + void visit(const Choice *rule) { value = transitions(rule->left); value.merge(transitions(rule->right), [&](rule_ptr left, rule_ptr right) -> rule_ptr { @@ -39,7 +43,7 @@ namespace tree_sitter { void visit(const Repeat *rule) { value = transitions(rule->content).map([&](const rule_ptr &value) -> rule_ptr { - return seq({ value, choice({ repeat(rule->content), blank() }) }); + return seq({ value, choice({ rule->copy(), blank() }) }); }); } diff --git a/src/compiler/rules/non_terminal.cpp b/src/compiler/rules/non_terminal.cpp new file mode 100644 index 00000000..2d107b66 --- /dev/null +++ b/src/compiler/rules/non_terminal.cpp @@ -0,0 +1,28 @@ +#include "rules.h" +#include "transition_map.h" + +using std::string; +using std::hash; + +namespace tree_sitter { + namespace rules { + NonTerminal::NonTerminal(const std::string &name) : Symbol(name) {}; + + bool NonTerminal::operator==(const Rule &rule) const { + const NonTerminal *other = dynamic_cast(&rule); + return other && (other->name == name); + } + + rule_ptr NonTerminal::copy() const { + return std::make_shared(*this); + } + + string NonTerminal::to_string() const { + return string("#"; + } + + void NonTerminal::accept(Visitor &visitor) const { + visitor.visit(this); + } + } +} \ No newline at end of file diff --git a/src/compiler/rules/non_terminal.h b/src/compiler/rules/non_terminal.h new file mode 100644 index 00000000..77a548f7 --- /dev/null +++ b/src/compiler/rules/non_terminal.h @@ -0,0 +1,20 @@ +#ifndef __tree_sitter__non_terminal__ +#define __tree_sitter__non_terminal__ + +#include "symbol.h" + +namespace tree_sitter { + namespace rules { + class NonTerminal : public Symbol { + public: + NonTerminal(const std::string &name); + + bool operator==(const Rule& other) const; + rule_ptr copy() const; + std::string to_string() const; + void accept(Visitor &visitor) const; + }; + } +} + +#endif diff --git a/src/compiler/rules/rules.cpp b/src/compiler/rules/rules.cpp index 2a511628..6920df83 100644 --- a/src/compiler/rules/rules.cpp +++ b/src/compiler/rules/rules.cpp @@ -47,7 +47,7 @@ namespace tree_sitter { } sym_ptr sym(const string &name) { - return make_shared(name); + return make_shared(name); } rule_ptr token(const std::string &name) { diff --git a/src/compiler/rules/rules.h b/src/compiler/rules/rules.h index 173a57a2..6c4aa6fa 100644 --- a/src/compiler/rules/rules.h +++ b/src/compiler/rules/rules.h @@ -11,6 +11,7 @@ #include "pattern.h" #include "character.h" #include "repeat.h" +#include "non_terminal.h" #include "visitor.h" namespace tree_sitter { diff --git a/src/compiler/rules/token.cpp b/src/compiler/rules/token.cpp index fd13c732..9c279784 100644 --- a/src/compiler/rules/token.cpp +++ b/src/compiler/rules/token.cpp @@ -6,17 +6,13 @@ using std::hash; namespace tree_sitter { namespace rules { - Token::Token(const std::string &name) : name(name) {}; + Token::Token(const std::string &name) : Symbol(name) {}; bool Token::operator==(const Rule &rule) const { const Token *other = dynamic_cast(&rule); return other && (other->name == name); } - size_t Token::hash_code() const { - return typeid(this).hash_code() ^ hash()(name); - } - rule_ptr Token::copy() const { return std::make_shared(*this); } diff --git a/src/compiler/rules/token.h b/src/compiler/rules/token.h index a05931d6..96d9b5a3 100644 --- a/src/compiler/rules/token.h +++ b/src/compiler/rules/token.h @@ -1,21 +1,18 @@ #ifndef __tree_sitter__token__ #define __tree_sitter__token__ -#include "rule.h" +#include "symbol.h" namespace tree_sitter { namespace rules { - class Token : public Rule { + class Token : public Symbol { public: Token(const std::string &name); bool operator==(const Rule& other) const; - size_t hash_code() const; rule_ptr copy() const; std::string to_string() const; void accept(Visitor &visitor) const; - - const std::string name; }; } } diff --git a/tree_sitter.xcodeproj/project.pbxproj b/tree_sitter.xcodeproj/project.pbxproj index cb8b8857..12d995c7 100644 --- a/tree_sitter.xcodeproj/project.pbxproj +++ b/tree_sitter.xcodeproj/project.pbxproj @@ -16,6 +16,8 @@ 1213061B182C84DF00FCF928 /* item.cpp in Sources */ = {isa = PBXBuildFile; fileRef = 12130619182C84DF00FCF928 /* item.cpp */; }; 12130622182C85D300FCF928 /* item_set.cpp in Sources */ = {isa = PBXBuildFile; fileRef = 12130620182C85D300FCF928 /* item_set.cpp */; }; 1214930E181E200B008E9BDA /* main.cpp in Sources */ = {isa = PBXBuildFile; fileRef = 121492E9181E200B008E9BDA /* main.cpp */; }; + 121D8B2E187763F3003CF44B /* non_terminal.cpp in Sources */ = {isa = PBXBuildFile; fileRef = 121D8B2C187763F3003CF44B /* non_terminal.cpp */; }; + 121D8B2F1877AD1C003CF44B /* non_terminal.cpp in Sources */ = {isa = PBXBuildFile; fileRef = 121D8B2C187763F3003CF44B /* non_terminal.cpp */; }; 1225CC6418765693000D4723 /* prepare_grammar_spec.cpp in Sources */ = {isa = PBXBuildFile; fileRef = 1225CC6318765693000D4723 /* prepare_grammar_spec.cpp */; }; 1225CC6718765737000D4723 /* prepare_grammar.cpp in Sources */ = {isa = PBXBuildFile; fileRef = 1225CC6518765737000D4723 /* prepare_grammar.cpp */; }; 1225CC6A187661C7000D4723 /* extract_tokens.cpp in Sources */ = {isa = PBXBuildFile; fileRef = 1225CC68187661C7000D4723 /* extract_tokens.cpp */; }; @@ -50,7 +52,6 @@ 12FD40CA185EEB5E0041A84E /* rule.cpp in Sources */ = {isa = PBXBuildFile; fileRef = 1251209A1830145300C9B56A /* rule.cpp */; }; 12FD40CB185EEB5E0041A84E /* pattern.cpp in Sources */ = {isa = PBXBuildFile; fileRef = 27A340F3EEB184C040521323 /* pattern.cpp */; }; 12FD40D2185EEB970041A84E /* arithmetic.c in Sources */ = {isa = PBXBuildFile; fileRef = 12FD4065185E7C2F0041A84E /* arithmetic.c */; }; - 12FD40D5185FEEDB0041A84E /* item_set_spec.cpp in Sources */ = {isa = PBXBuildFile; fileRef = 1213061D182C857100FCF928 /* item_set_spec.cpp */; }; 12FD40D6185FEEDB0041A84E /* table_builder_spec.cpp in Sources */ = {isa = PBXBuildFile; fileRef = 12512092182F307C00C9B56A /* table_builder_spec.cpp */; }; 12FD40D7185FEEDB0041A84E /* item_spec.cpp in Sources */ = {isa = PBXBuildFile; fileRef = 12D1369C18328C5A005F3369 /* item_spec.cpp */; }; 12FD40D8185FEEDF0041A84E /* rules_spec.cpp in Sources */ = {isa = PBXBuildFile; fileRef = 121492EA181E200B008E9BDA /* rules_spec.cpp */; }; @@ -108,11 +109,12 @@ 12130616182C3D2900FCF928 /* string.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = string.h; sourceTree = ""; }; 12130619182C84DF00FCF928 /* item.cpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; path = item.cpp; sourceTree = ""; }; 1213061A182C84DF00FCF928 /* item.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = item.h; sourceTree = ""; }; - 1213061D182C857100FCF928 /* item_set_spec.cpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; name = item_set_spec.cpp; path = spec/compiler/lr/item_set_spec.cpp; sourceTree = SOURCE_ROOT; }; 12130620182C85D300FCF928 /* item_set.cpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; path = item_set.cpp; sourceTree = ""; }; 12130621182C85D300FCF928 /* item_set.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = item_set.h; sourceTree = ""; }; 121492E9181E200B008E9BDA /* main.cpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; name = main.cpp; path = spec/main.cpp; sourceTree = SOURCE_ROOT; }; 121492EA181E200B008E9BDA /* rules_spec.cpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; name = rules_spec.cpp; path = spec/compiler/rules/rules_spec.cpp; sourceTree = SOURCE_ROOT; }; + 121D8B2C187763F3003CF44B /* non_terminal.cpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; path = non_terminal.cpp; sourceTree = ""; }; + 121D8B2D187763F3003CF44B /* non_terminal.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = non_terminal.h; sourceTree = ""; }; 1225CC6318765693000D4723 /* prepare_grammar_spec.cpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; path = prepare_grammar_spec.cpp; sourceTree = ""; }; 1225CC6518765737000D4723 /* prepare_grammar.cpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; path = prepare_grammar.cpp; sourceTree = ""; }; 1225CC6618765737000D4723 /* prepare_grammar.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = prepare_grammar.h; sourceTree = ""; }; @@ -210,6 +212,8 @@ 1225CC6F1876AFFF000D4723 /* token.h */, 12FD40E618639B910041A84E /* visitor.cpp */, 12FD40E41862B3530041A84E /* visitor.h */, + 121D8B2C187763F3003CF44B /* non_terminal.cpp */, + 121D8B2D187763F3003CF44B /* non_terminal.h */, ); path = rules; sourceTree = ""; @@ -237,7 +241,6 @@ isa = PBXGroup; children = ( 12ED72A6186FC8220089229B /* transitions_spec.cpp */, - 1213061D182C857100FCF928 /* item_set_spec.cpp */, 12512092182F307C00C9B56A /* table_builder_spec.cpp */, 12D1369C18328C5A005F3369 /* item_spec.cpp */, ); @@ -483,8 +486,8 @@ buildActionMask = 2147483647; files = ( 12FD40D7185FEEDB0041A84E /* item_spec.cpp in Sources */, - 12FD40D5185FEEDB0041A84E /* item_set_spec.cpp in Sources */, 12130614182C3A1700FCF928 /* seq.cpp in Sources */, + 121D8B2E187763F3003CF44B /* non_terminal.cpp in Sources */, 1225CC6A187661C7000D4723 /* extract_tokens.cpp in Sources */, 129D242C183EB1EB00FE9F71 /* table_builder.cpp in Sources */, 125120A4183083BD00C9B56A /* arithmetic.cpp in Sources */, @@ -527,6 +530,7 @@ 12FD40B3185EEB5E0041A84E /* seq.cpp in Sources */, 12FD40B4185EEB5E0041A84E /* table_builder.cpp in Sources */, 12FD40B6185EEB5E0041A84E /* arithmetic.cpp in Sources */, + 121D8B2F1877AD1C003CF44B /* non_terminal.cpp in Sources */, 12FD40DD185FF12C0041A84E /* parser.c in Sources */, 12FD40B8185EEB5E0041A84E /* item.cpp in Sources */, 12FD40B9185EEB5E0041A84E /* string.cpp in Sources */,