diff --git a/src/compiler/prepare_grammar/extract_tokens.cc b/src/compiler/prepare_grammar/extract_tokens.cc index 7510ff71..c63cb968 100644 --- a/src/compiler/prepare_grammar/extract_tokens.cc +++ b/src/compiler/prepare_grammar/extract_tokens.cc @@ -18,7 +18,6 @@ namespace tree_sitter { namespace prepare_grammar { -using std::dynamic_pointer_cast; using std::make_shared; using std::make_tuple; using std::map; @@ -122,8 +121,8 @@ tuple extract_tokens */ size_t i = 0; for (const Variable &variable : processed_variables) { - auto symbol = dynamic_pointer_cast(variable.rule); - if (symbol.get() && symbol->is_token && !symbol->is_built_in() && + auto symbol = variable.rule->as(); + if (symbol && symbol->is_token && !symbol->is_built_in() && extractor.token_usage_counts[symbol->index] == 1) { lexical_grammar.variables[symbol->index].type = variable.type; lexical_grammar.variables[symbol->index].name = variable.name; @@ -160,8 +159,8 @@ tuple extract_tokens continue; } - auto symbol = dynamic_pointer_cast(rule); - if (!symbol.get()) + auto symbol = rule->as(); + if (!symbol) return make_tuple(syntax_grammar, lexical_grammar, ubiq_token_err(rule->to_string())); diff --git a/src/compiler/rule.cc b/src/compiler/rule.cc index b4617257..8cb8ce95 100644 --- a/src/compiler/rule.cc +++ b/src/compiler/rule.cc @@ -1,5 +1,5 @@ #include "compiler/rule.h" -#include +#include namespace tree_sitter { diff --git a/src/compiler/rule.h b/src/compiler/rule.h index 55cf09e8..2dbe4046 100644 --- a/src/compiler/rule.h +++ b/src/compiler/rule.h @@ -22,6 +22,11 @@ class Rule { virtual std::string to_string() const = 0; virtual void accept(rules::Visitor *visitor) const = 0; virtual ~Rule(); + + template + const T * as() const { + return dynamic_cast(this); + } }; } // namespace tree_sitter diff --git a/src/compiler/rules/blank.cc b/src/compiler/rules/blank.cc index b5eee146..6348bf62 100644 --- a/src/compiler/rules/blank.cc +++ b/src/compiler/rules/blank.cc @@ -13,7 +13,7 @@ rule_ptr Blank::build() { } bool Blank::operator==(const Rule &rule) const { - return dynamic_cast(&rule) != nullptr; + return rule.as() != nullptr; } size_t Blank::hash_code() const { diff --git a/src/compiler/rules/character_set.cc b/src/compiler/rules/character_set.cc index d18808ee..f5618a07 100644 --- a/src/compiler/rules/character_set.cc +++ b/src/compiler/rules/character_set.cc @@ -64,7 +64,7 @@ CharacterSet::CharacterSet() : includes_all(false), included_chars({}), excluded_chars({}) {} bool CharacterSet::operator==(const Rule &rule) const { - const CharacterSet *other = dynamic_cast(&rule); + const CharacterSet *other = rule.as(); return other && (includes_all == other->includes_all) && (included_chars == other->included_chars) && (excluded_chars == other->excluded_chars); diff --git a/src/compiler/rules/choice.cc b/src/compiler/rules/choice.cc index b4b08e9e..975adbef 100644 --- a/src/compiler/rules/choice.cc +++ b/src/compiler/rules/choice.cc @@ -10,18 +10,17 @@ using std::string; using std::make_shared; using std::vector; using std::set; -using std::dynamic_pointer_cast; Choice::Choice(const vector &elements) : elements(elements) {} void add_choice_element(vector *vec, const rule_ptr new_rule) { - auto choice = dynamic_pointer_cast(new_rule); - if (choice.get()) { + auto choice = new_rule->as(); + if (choice) { for (auto &child : choice->elements) add_choice_element(vec, child); } else { - for (auto &el : *vec) - if (el->operator==(*new_rule)) + for (auto &element : *vec) + if (element->operator==(*new_rule)) return; vec->push_back(new_rule); } @@ -38,7 +37,7 @@ rule_ptr Choice::build(const vector &inputs) { } bool Choice::operator==(const Rule &rule) const { - const Choice *other = dynamic_cast(&rule); + const Choice *other = rule.as(); if (!other) return false; size_t size = elements.size(); diff --git a/src/compiler/rules/metadata.cc b/src/compiler/rules/metadata.cc index d3d56748..63334a2e 100644 --- a/src/compiler/rules/metadata.cc +++ b/src/compiler/rules/metadata.cc @@ -13,8 +13,12 @@ using std::map; Metadata::Metadata(rule_ptr rule, map values) : rule(rule), value(values) {} +rule_ptr Metadata::build(rule_ptr rule, map values) { + return std::make_shared(rule, values); +} + bool Metadata::operator==(const Rule &rule) const { - auto other = dynamic_cast(&rule); + auto other = rule.as(); return other && other->value == value && other->rule->operator==(*this->rule); } diff --git a/src/compiler/rules/metadata.h b/src/compiler/rules/metadata.h index 48b856cc..d93bbc9f 100644 --- a/src/compiler/rules/metadata.h +++ b/src/compiler/rules/metadata.h @@ -19,6 +19,7 @@ enum MetadataKey { class Metadata : public Rule { public: Metadata(rule_ptr rule, std::map value); + static rule_ptr build(rule_ptr rule, std::map value); bool operator==(const Rule &other) const; size_t hash_code() const; diff --git a/src/compiler/rules/named_symbol.cc b/src/compiler/rules/named_symbol.cc index b5afe89f..d846580b 100644 --- a/src/compiler/rules/named_symbol.cc +++ b/src/compiler/rules/named_symbol.cc @@ -11,7 +11,7 @@ using std::hash; NamedSymbol::NamedSymbol(const std::string &name) : name(name) {} bool NamedSymbol::operator==(const Rule &rule) const { - auto other = dynamic_cast(&rule); + auto other = rule.as(); return other && other->name == name; } diff --git a/src/compiler/rules/pattern.cc b/src/compiler/rules/pattern.cc index 83d44698..5ac8f97b 100644 --- a/src/compiler/rules/pattern.cc +++ b/src/compiler/rules/pattern.cc @@ -12,7 +12,7 @@ using std::hash; Pattern::Pattern(const string &string) : value(string) {} bool Pattern::operator==(tree_sitter::Rule const &other) const { - auto pattern = dynamic_cast(&other); + auto pattern = other.as(); return pattern && (pattern->value == value); } diff --git a/src/compiler/rules/repeat.cc b/src/compiler/rules/repeat.cc index 40f41ed3..64d793bb 100644 --- a/src/compiler/rules/repeat.cc +++ b/src/compiler/rules/repeat.cc @@ -6,22 +6,21 @@ namespace tree_sitter { namespace rules { -using std::dynamic_pointer_cast; using std::make_shared; using std::string; Repeat::Repeat(const rule_ptr content) : content(content) {} rule_ptr Repeat::build(const rule_ptr &rule) { - auto inner_repeat = dynamic_pointer_cast(rule); - if (inner_repeat.get()) - return inner_repeat; + auto inner_repeat = rule->as(); + if (inner_repeat) + return rule; else return make_shared(rule); } bool Repeat::operator==(const Rule &rule) const { - const Repeat *other = dynamic_cast(&rule); + auto other = rule.as(); return other && (*other->content == *content); } diff --git a/src/compiler/rules/seq.cc b/src/compiler/rules/seq.cc index 54dc5821..cc934a5c 100644 --- a/src/compiler/rules/seq.cc +++ b/src/compiler/rules/seq.cc @@ -2,6 +2,7 @@ #include #include "compiler/rules/visitor.h" #include "compiler/rules/blank.h" +#include "compiler/rules/metadata.h" namespace tree_sitter { namespace rules { @@ -14,14 +15,25 @@ Seq::Seq(rule_ptr left, rule_ptr right) : left(left), right(right) {} rule_ptr Seq::build(const std::vector &rules) { rule_ptr result = make_shared(); - for (auto &rule : rules) - result = (typeid(*result) != typeid(Blank)) ? make_shared(result, rule) - : rule; + for (auto &rule : rules) { + auto blank = rule->as(); + if (blank) + continue; + + auto metadata = rule->as(); + if (metadata && metadata->rule->as()) + continue; + + if (result->as()) + result = rule; + else + result = make_shared(result, rule); + } return result; } bool Seq::operator==(const Rule &rule) const { - const Seq *other = dynamic_cast(&rule); + const Seq *other = rule.as(); return other && (*other->left == *left) && (*other->right == *right); } diff --git a/src/compiler/rules/string.cc b/src/compiler/rules/string.cc index 075a4267..8a77b169 100644 --- a/src/compiler/rules/string.cc +++ b/src/compiler/rules/string.cc @@ -11,7 +11,7 @@ using std::hash; String::String(string value) : value(value) {} bool String::operator==(const Rule &rule) const { - const String *other = dynamic_cast(&rule); + auto other = rule.as(); return other && (other->value == value); } diff --git a/src/compiler/rules/symbol.cc b/src/compiler/rules/symbol.cc index 33472874..cdfb78cf 100644 --- a/src/compiler/rules/symbol.cc +++ b/src/compiler/rules/symbol.cc @@ -19,7 +19,7 @@ bool Symbol::operator==(const Symbol &other) const { } bool Symbol::operator==(const Rule &rule) const { - const Symbol *other = dynamic_cast(&rule); + auto other = rule.as(); return other && this->operator==(*other); }