diff --git a/spec/compiler/build_tables/rule_transitions_spec.cpp b/spec/compiler/build_tables/rule_transitions_spec.cpp index 56606849..66f9b936 100644 --- a/spec/compiler/build_tables/rule_transitions_spec.cpp +++ b/spec/compiler/build_tables/rule_transitions_spec.cpp @@ -7,32 +7,24 @@ using namespace build_tables; START_TEST describe("rule transitions", []() { - rule_ptr symbol1 = sym("1"); - rule_ptr symbol2 = sym("2"); - rule_ptr symbol3 = sym("3"); - rule_ptr symbol4 = sym("4"); - rule_ptr char1 = character('a'); + auto symbol1 = sym("1"); + auto symbol2 = sym("2"); + auto symbol3 = sym("3"); + auto symbol4 = sym("4"); + auto char1 = character({ 'a' }); it("handles symbols", [&]() { AssertThat( - rule_transitions(symbol1), - Equals(transition_map({ + sym_transitions(symbol1), + Equals(transition_map({ { symbol1, blank() } }))); }); - it("handles characters", [&]() { - AssertThat( - rule_transitions(char1), - Equals(transition_map({ - { char1, blank() } - }))); - }); - it("handles choices", [&]() { AssertThat( - rule_transitions(choice({ symbol1, symbol2 })), - Equals(transition_map({ + sym_transitions(choice({ symbol1, symbol2 })), + Equals(transition_map({ { symbol1, blank() }, { symbol2, blank() } }))); @@ -40,77 +32,84 @@ describe("rule transitions", []() { it("handles sequences", [&]() { AssertThat( - rule_transitions(seq({ symbol1, symbol2 })), - Equals(transition_map({ + sym_transitions(seq({ symbol1, symbol2 })), + Equals(transition_map({ { symbol1, symbol2 } }))); }); - it("handles_long_sequences", [&]() { + it("handles long sequences", [&]() { AssertThat( - rule_transitions(seq({ + sym_transitions(seq({ symbol1, symbol2, symbol3, symbol4 })), - Equals(transition_map({ + Equals(transition_map({ { symbol1, seq({ symbol2, symbol3, symbol4 }) } }))); }); it("handles sequences whose left sides can be blank", [&]() { AssertThat( - rule_transitions(seq({ + sym_transitions(seq({ choice({ - sym("x"), + symbol1, blank(), }), seq({ - sym("x"), - sym("y") + symbol1, + symbol2 }) - })), Equals(transition_map({ - { sym("x"), choice({ seq({ sym("x"), sym("y") }), sym("y"), }) } + })), Equals(transition_map({ + { symbol1, choice({ seq({ symbol1, symbol2 }), symbol2, }) } }))); }); it("handles choices with common starting symbols", [&]() { AssertThat( - rule_transitions( + sym_transitions( choice({ seq({ symbol1, symbol2 }), seq({ symbol1, symbol3 }) })), - Equals(transition_map({ + Equals(transition_map({ { symbol1, choice({ symbol2, symbol3 }) } }))); }); + it("handles characters", [&]() { + AssertThat( + char_transitions(char1), + Equals(transition_map({ + { char1, blank() } + }))); + }); + it("handles strings", [&]() { AssertThat( - rule_transitions(str("bad")), - Equals(transition_map({ - { character('b'), seq({ character('a'), character('d') }) - } - }))); + char_transitions(str("bad")), + Equals(transition_map({ + { character({ 'b' }, true), seq({ character('a'), character('d') }) } + }))); }); it("handles patterns", [&]() { AssertThat( - rule_transitions(pattern("a|b")), - Equals(transition_map({ - { character('a'), blank() }, - { character('b'), blank() } + char_transitions(pattern("a|b")), + Equals(transition_map({ + { character({ 'a' }, true), blank() }, + { character({ 'b' }, true), blank() } }))); }); it("handles repeats", [&]() { rule_ptr rule = repeat(str("ab")); AssertThat( - rule_transitions(rule), - Equals(transition_map({ + char_transitions(rule), + Equals(transition_map({ { - character('a'), + character({ 'a' }, true), seq({ character('b'), choice({ @@ -122,10 +121,10 @@ describe("rule transitions", []() { rule = repeat(str("a")); AssertThat( - rule_transitions(rule), - Equals(transition_map({ + char_transitions(rule), + Equals(transition_map({ { - character('a'), + character({ 'a' }, true), choice({ rule, blank() @@ -143,14 +142,14 @@ describe("rule transitions", []() { character('"'), }); - AssertThat(rule_transitions(rule), Equals(transition_map({ + AssertThat(char_transitions(rule), Equals(transition_map({ { character({ '"' }, false), seq({ choice({ repeat(character({ '"' }, false)), blank(), }), character('"'), }) }, - { character('"'), blank() }, + { character({ '"' }, true), blank() }, }))); }); }); diff --git a/src/compiler/build_tables/follow_sets.cpp b/src/compiler/build_tables/follow_sets.cpp index b9a68c12..56579205 100644 --- a/src/compiler/build_tables/follow_sets.cpp +++ b/src/compiler/build_tables/follow_sets.cpp @@ -15,14 +15,14 @@ namespace tree_sitter { unordered_map> follow_sets(const ParseItem &item, const Grammar &grammar) { unordered_map> result; - for (auto pair : rule_transitions(item.rule)) { - auto symbol = dynamic_pointer_cast(pair.first); - if (symbol && grammar.has_definition(*symbol)) { + for (auto pair : sym_transitions(item.rule)) { + auto symbol = *pair.first; + if (grammar.has_definition(symbol)) { auto following_non_terminals = first_set(pair.second, grammar); if (rule_can_be_blank(pair.second)) { following_non_terminals.insert(item.lookahead_sym); } - result.insert({ *symbol, following_non_terminals }); + result.insert({ symbol, following_non_terminals }); } } diff --git a/src/compiler/build_tables/item_set_transitions.cpp b/src/compiler/build_tables/item_set_transitions.cpp index 57bb0cba..d3ee0335 100644 --- a/src/compiler/build_tables/item_set_transitions.cpp +++ b/src/compiler/build_tables/item_set_transitions.cpp @@ -21,13 +21,11 @@ namespace tree_sitter { transition_map result; for (LexItem item : item_set) { transition_map item_transitions; - for (auto transition : rule_transitions(item.rule)) { - auto rule = dynamic_pointer_cast(transition.first); - if (rule.get()) { - auto new_item = LexItem(item.lhs, transition.second); - auto new_item_set = LexItemSet({ new_item }); - item_transitions.add(rule, make_shared(new_item_set)); - } + for (auto transition : char_transitions(item.rule)) { + auto rule = transition.first; + auto new_item = LexItem(item.lhs, transition.second); + auto new_item_set = LexItemSet({ new_item }); + item_transitions.add(rule, make_shared(new_item_set)); } result.merge(item_transitions, [](shared_ptr left, shared_ptr right) -> shared_ptr { @@ -42,15 +40,13 @@ namespace tree_sitter { transition_map result; for (ParseItem item : item_set) { transition_map item_transitions; - for (auto transition : rule_transitions(item.rule)) { - auto rule = dynamic_pointer_cast(transition.first); - if (rule.get()) { - auto consumed_symbols = item.consumed_symbols; - consumed_symbols.push_back(rule->is_auxiliary); - auto new_item = ParseItem(item.lhs, transition.second, consumed_symbols, item.lookahead_sym); - auto new_item_set = item_set_closure(ParseItemSet({ new_item }), grammar); - item_transitions.add(rule, make_shared(new_item_set)); - } + for (auto transition : sym_transitions(item.rule)) { + auto rule = transition.first; + auto consumed_symbols = item.consumed_symbols; + consumed_symbols.push_back(rule->is_auxiliary); + auto new_item = ParseItem(item.lhs, transition.second, consumed_symbols, item.lookahead_sym); + auto new_item_set = item_set_closure(ParseItemSet({ new_item }), grammar); + item_transitions.add(rule, make_shared(new_item_set)); } result.merge(item_transitions, [](shared_ptr left, shared_ptr right) -> shared_ptr { diff --git a/src/compiler/build_tables/rule_transitions.cpp b/src/compiler/build_tables/rule_transitions.cpp index 622b32e9..251ee0bc 100644 --- a/src/compiler/build_tables/rule_transitions.cpp +++ b/src/compiler/build_tables/rule_transitions.cpp @@ -9,41 +9,55 @@ namespace tree_sitter { return typeid(*rule) == typeid(Blank); } + template class TransitionsVisitor : public rules::Visitor { public: - transition_map value; + transition_map value; + + static transition_map transitions(const rule_ptr rule) { + TransitionsVisitor visitor; + rule->accept(visitor); + return visitor.value; + } + + void visit_atom(const Rule *rule) { + auto atom = dynamic_cast(rule); + if (atom) { + value = transition_map({{ std::make_shared(*atom), blank() }}); + } + } void visit(const CharacterSet *rule) { - value = transition_map({{ rule->copy(), blank() }}); + visit_atom(rule); } void visit(const Symbol *rule) { - value = transition_map({{ rule->copy(), blank() }}); + visit_atom(rule); } void visit(const Choice *rule) { - value = rule_transitions(rule->left); - value.merge(rule_transitions(rule->right), [&](rule_ptr left, rule_ptr right) -> rule_ptr { + value = transitions(rule->left); + value.merge(transitions(rule->right), [&](rule_ptr left, rule_ptr right) -> rule_ptr { return choice({ left, right }); }); } void visit(const Seq *rule) { - value = rule_transitions(rule->left).map([&](const rule_ptr left_rule) -> rule_ptr { + value = transitions(rule->left).template map([&](const rule_ptr left_rule) -> rule_ptr { if (is_blank(left_rule)) return rule->right; else return seq({ left_rule, rule->right }); }); if (rule_can_be_blank(rule->left)) { - value.merge(rule_transitions(rule->right), [&](rule_ptr left, rule_ptr right) -> rule_ptr { + value.merge(transitions(rule->right), [&](rule_ptr left, rule_ptr right) -> rule_ptr { return choice({ left, right }); }); } } void visit(const Repeat *rule) { - value = rule_transitions(rule->content).map([&](const rule_ptr &value) -> rule_ptr { + value = transitions(rule->content).template map([&](const rule_ptr &value) -> rule_ptr { return seq({ value, choice({ rule->copy(), blank() }) }); }); } @@ -52,20 +66,22 @@ namespace tree_sitter { rule_ptr result = character(rule->value[0]); for (int i = 1; i < rule->value.length(); i++) result = seq({ result, character(rule->value[i]) }); - value = rule_transitions(result); + value = transitions(result); } void visit(const Pattern *rule) { - value = rule_transitions(rule->to_rule_tree()); + value = transitions(rule->to_rule_tree()); } }; - transition_map rule_transitions(const rule_ptr &rule) { - TransitionsVisitor visitor; - rule->accept(visitor); - return visitor.value; + transition_map char_transitions(const rule_ptr &rule) { + return TransitionsVisitor::transitions(rule); } - + + transition_map sym_transitions(const rule_ptr &rule) { + return TransitionsVisitor::transitions(rule); + } + class EpsilonVisitor : public rules::Visitor { public: bool value; diff --git a/src/compiler/build_tables/rule_transitions.h b/src/compiler/build_tables/rule_transitions.h index 19db0ac3..dd23b29c 100644 --- a/src/compiler/build_tables/rule_transitions.h +++ b/src/compiler/build_tables/rule_transitions.h @@ -1,13 +1,15 @@ #ifndef __tree_sitter__transitions__ #define __tree_sitter__transitions__ -#include "rule.h" +#include "character_set.h" +#include "symbol.h" #include "transition_map.h" namespace tree_sitter { namespace build_tables { bool rule_can_be_blank(const rules::rule_ptr &rule); - transition_map rule_transitions(const rules::rule_ptr &rule); + transition_map char_transitions(const rules::rule_ptr &rule); + transition_map sym_transitions(const rules::rule_ptr &rule); } } diff --git a/src/compiler/rules/character_set.h b/src/compiler/rules/character_set.h index 5838e7eb..d8d33ff9 100644 --- a/src/compiler/rules/character_set.h +++ b/src/compiler/rules/character_set.h @@ -50,6 +50,8 @@ namespace tree_sitter { std::set ranges; }; + + typedef std::shared_ptr char_ptr; } } diff --git a/src/compiler/rules/rules.cpp b/src/compiler/rules/rules.cpp index 8e6a3fe0..eadab77e 100644 --- a/src/compiler/rules/rules.cpp +++ b/src/compiler/rules/rules.cpp @@ -11,16 +11,16 @@ namespace tree_sitter { return make_shared(); } - rule_ptr character(char value) { + char_ptr character(char value) { set ranges = { value }; return make_shared(ranges); } - rule_ptr character(const set &ranges) { + char_ptr character(const set &ranges) { return make_shared(ranges); } - rule_ptr character(const set &ranges, bool sign) { + char_ptr character(const set &ranges, bool sign) { return make_shared(ranges, sign); } @@ -52,7 +52,7 @@ namespace tree_sitter { return make_shared(value); } - rule_ptr sym(const string &name) { + sym_ptr sym(const string &name) { return make_shared(name, false); } diff --git a/src/compiler/rules/rules.h b/src/compiler/rules/rules.h index 02fee8b1..8208ed36 100644 --- a/src/compiler/rules/rules.h +++ b/src/compiler/rules/rules.h @@ -16,16 +16,16 @@ namespace tree_sitter { namespace rules { rule_ptr blank(); - rule_ptr character(char value); - rule_ptr character(const std::set &matches); - rule_ptr character(const std::set &matches, bool); + char_ptr character(char value); + char_ptr character(const std::set &matches); + char_ptr character(const std::set &matches, bool); rule_ptr choice(const std::vector &rules); rule_ptr pattern(const std::string &value); rule_ptr repeat(const rule_ptr content); rule_ptr seq(const std::vector &rules); rule_ptr str(const std::string &value); - rule_ptr sym(const std::string &name); + sym_ptr sym(const std::string &name); rule_ptr aux_sym(const std::string &name); } } diff --git a/src/compiler/rules/symbol.h b/src/compiler/rules/symbol.h index 6488cadb..7b4d1d8b 100644 --- a/src/compiler/rules/symbol.h +++ b/src/compiler/rules/symbol.h @@ -22,6 +22,8 @@ namespace tree_sitter { std::string name; bool is_auxiliary; }; + + typedef std::shared_ptr sym_ptr; } }