Merge pull request #163 from tree-sitter/loosen-keyword-capture-token-criteria

Loosen keyword capture token criteria
Max Brunsfeld 2018-05-25 21:51:28 -07:00 committed by GitHub
commit f0e557fa78
8 changed files with 252 additions and 119 deletions

@@ -15,6 +15,19 @@
#include "compiler/rule.h"
#include "utf8proc.h"
namespace std {
using tree_sitter::rules::Symbol;
size_t hash<pair<Symbol::Index, Symbol::Index>>::operator()(
const pair<Symbol::Index, Symbol::Index> &p
) const {
hash<Symbol::Index> hasher;
return hasher(p.first) ^ hasher(p.second);
}
} // namespace std
namespace tree_sitter {
namespace build_tables {
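The hash specialization above combines the two index hashes with XOR, which is commutative; the lookup side (states_with, just below) swaps the indices into ascending order before probing, so either ordering of a token pair resolves to the same entry. A minimal standalone sketch of the same pattern, with hypothetical names:

#include <cstddef>
#include <functional>
#include <unordered_map>
#include <utility>

// Symmetric hash: XOR is commutative, so {a, b} and {b, a} hash identically,
// matching the smaller-index-first normalization used for the map keys.
struct PairHash {
  std::size_t operator()(const std::pair<int, int> &p) const {
    std::hash<int> hasher;
    return hasher(p.first) ^ hasher(p.second);
  }
};

int main() {
  std::unordered_map<std::pair<int, int>, int, PairHash> entries;
  int a = 7, b = 3;
  if (a > b) std::swap(a, b);  // normalize the key before any lookup
  entries[{a, b}] += 1;
  return entries.count({3, 7}) == 1 ? 0 : 1;
}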
@@ -36,8 +49,24 @@ using rules::Symbol;
using rules::Metadata;
using rules::Seq;
static const std::unordered_set<ParseStateId> EMPTY;
bool CoincidentTokenIndex::contains(Symbol a, Symbol b) const {
return a == b || !states_with(a, b).empty();
}
const std::unordered_set<ParseStateId> &CoincidentTokenIndex::states_with(Symbol a, Symbol b) const {
if (a.index > b.index) std::swap(a, b);
auto iter = entries.find({a.index, b.index});
if (iter == entries.end()) {
return EMPTY;
} else {
return iter->second;
}
}
template <bool include_all>
class StartOrEndCharacterAggregator {
class CharacterAggregator {
public:
void apply(const Rule &rule) {
rule.match(
@@ -62,8 +91,8 @@ class StartOrEndCharacterAggregator {
CharacterSet result;
};
using StartingCharacterAggregator = StartOrEndCharacterAggregator<false>;
using AllCharacterAggregator = StartOrEndCharacterAggregator<true>;
using StartingCharacterAggregator = CharacterAggregator<false>;
using AllCharacterAggregator = CharacterAggregator<true>;
class LexTableBuilderImpl : public LexTableBuilder {
LexTable main_lex_table;
@@ -75,7 +104,8 @@ class LexTableBuilderImpl : public LexTableBuilder {
CharacterSet separator_start_characters;
vector<CharacterSet> starting_characters_by_token;
vector<CharacterSet> following_characters_by_token;
const vector<LookaheadSet> &coincident_tokens_by_token;
const CoincidentTokenIndex &coincident_token_index;
ParseTable *parse_table;
vector<ConflictStatus> conflict_matrix;
bool conflict_detection_mode;
LookaheadSet keyword_symbols;
@@ -86,11 +116,13 @@ class LexTableBuilderImpl : public LexTableBuilder {
LexTableBuilderImpl(const SyntaxGrammar &syntax_grammar,
const LexicalGrammar &lexical_grammar,
const unordered_map<Symbol, LookaheadSet> &following_tokens_by_token,
const vector<LookaheadSet> &coincident_tokens)
const CoincidentTokenIndex &coincident_token_index,
ParseTable *parse_table)
: grammar(lexical_grammar),
starting_characters_by_token(lexical_grammar.variables.size()),
following_characters_by_token(lexical_grammar.variables.size()),
coincident_tokens_by_token(coincident_tokens),
coincident_token_index(coincident_token_index),
parse_table(parse_table),
conflict_matrix(lexical_grammar.variables.size() * lexical_grammar.variables.size(), DoesNotMatch),
conflict_detection_mode(false),
keyword_capture_token(rules::NONE()) {
@@ -106,51 +138,53 @@ class LexTableBuilderImpl : public LexTableBuilder {
separator_start_characters = separator_character_aggregator.result;
// Compute the set of characters that each token can start with and the set of non-separator
// characters that can follow each token.
// characters that can follow each token. Also identify all of the tokens that can be
// considered 'keywords'.
LOG_START("characterizing tokens");
LookaheadSet potential_keyword_symbols;
for (unsigned i = 0, n = grammar.variables.size(); i < n; i++) {
Symbol token = Symbol::terminal(i);
StartingCharacterAggregator starting_character_aggregator;
starting_character_aggregator.apply(grammar.variables[i].rule);
starting_characters_by_token[i] = starting_character_aggregator.result;
StartingCharacterAggregator following_character_aggregator;
const auto &following_tokens = following_tokens_by_token.find(Symbol::terminal(i));
const auto &following_tokens = following_tokens_by_token.find(token);
if (following_tokens != following_tokens_by_token.end()) {
following_tokens->second.for_each([&](Symbol following_token) {
following_character_aggregator.apply(grammar.variables[following_token.index].rule);
return true;
});
}
following_characters_by_token[i] = following_character_aggregator.result;
if (grammar.variables[i].is_string) {
AllCharacterAggregator aggregator;
aggregator.apply(grammar.variables[i].rule);
bool all_alpha = true, all_lower = true;
for (auto character : aggregator.result.included_chars) {
if (!iswalpha(character) && character != '_') all_alpha = false;
if (!iswlower(character)) all_lower = false;
AllCharacterAggregator all_character_aggregator;
all_character_aggregator.apply(grammar.variables[i].rule);
if (
!starting_character_aggregator.result.includes_all &&
!all_character_aggregator.result.includes_all
) {
bool starts_alpha = true, all_alnum = true;
for (auto character : starting_character_aggregator.result.included_chars) {
if (!iswalpha(character) && character != '_') {
starts_alpha = false;
}
}
if (all_lower) {
keyword_symbols.insert(Symbol::terminal(i));
for (auto character : all_character_aggregator.result.included_chars) {
if (!iswalnum(character) && character != '_') {
all_alnum = false;
}
}
// TODO - Refactor this. In general, a keyword token cannot be followed immediately
// by another alphanumeric character. But this requirement is currently not expressed
// anywhere in the grammar. So without this hack, we would be overly conservative about
// merging parse states because we would often consider `identifier` tokens to *conflict*
// with keyword tokens.
if (all_alpha) {
following_character_aggregator.result
.exclude('a', 'z')
.exclude('A', 'Z')
.exclude('0', '9')
.exclude('_')
.exclude('$');
if (starts_alpha && all_alnum) {
LOG("potential keyword: %s", token_name(token).c_str());
potential_keyword_symbols.insert(token);
}
}
following_characters_by_token[i] = following_character_aggregator.result;
}
LOG_END();
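The keyword test above runs on character sets aggregated from each rule; informally, a string token is a potential keyword when it starts with a letter or underscore and contains only letters, digits, and underscores. A simplified sketch of that predicate over a literal token string (the includes_all guard is omitted, and the helper name is hypothetical):

#include <cwctype>
#include <string>

// Hypothetical helper: the real check inspects aggregated CharacterSets,
// not a concrete string.
bool is_potential_keyword(const std::wstring &token) {
  if (token.empty()) return false;
  // starts_alpha: the first character is alphabetic or '_'.
  if (!std::iswalpha(token.front()) && token.front() != L'_') return false;
  // all_alnum: every character is alphanumeric or '_'.
  for (wchar_t c : token) {
    if (!std::iswalnum(c) && c != L'_') return false;
  }
  return true;
}

int main() {
  return (is_potential_keyword(L"return") && !is_potential_keyword(L"+=")) ? 0 : 1;
}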
// For each pair of tokens, generate a lex table for just those two tokens and record what
// conflicts arise.
@@ -171,50 +205,101 @@ class LexTableBuilderImpl : public LexTableBuilder {
}
LOG_END();
// Find a 'keyword capture token' that matches all of the identified keywords.
LOG_START("finding keyword capture token");
for (Symbol::Index i = 0, n = grammar.variables.size(); i < n; i++) {
Symbol symbol = Symbol::terminal(i);
bool matches_all_keywords = true;
keyword_symbols.for_each([&](Symbol keyword_symbol) {
if (!(get_conflict_status(symbol, keyword_symbol) & MatchesSameString)) {
matches_all_keywords = false;
Symbol candidate = Symbol::terminal(i);
LookaheadSet homonyms;
potential_keyword_symbols.for_each([&](Symbol other_token) {
if (get_conflict_status(other_token, candidate) & MatchesShorterStringWithinSeparators) {
homonyms.clear();
return false;
}
if (get_conflict_status(candidate, other_token) == MatchesSameString) {
homonyms.insert(other_token);
}
return true;
});
if (!matches_all_keywords) continue;
if (homonyms.empty()) continue;
// Don't use a token to capture keywords if it overlaps with separator characters.
AllCharacterAggregator capture_aggregator;
capture_aggregator.apply(grammar.variables[i].rule);
if (capture_aggregator.result.intersects(separator_start_characters)) continue;
LOG_START(
"keyword capture token candidate: %s, homonym count: %lu",
token_name(candidate).c_str(),
homonyms.size()
);
homonyms.for_each([&](Symbol homonym1) {
homonyms.for_each([&](Symbol homonym2) {
if (get_conflict_status(homonym1, homonym2) & MatchesSameString) {
LOG(
"conflict between homonyms %s %s",
token_name(homonym1).c_str(),
token_name(homonym2).c_str()
);
homonyms.remove(homonym1);
}
return false;
});
return true;
});
// Don't use a token to capture keywords if it conflicts with other tokens
// that occur in the same state as a keyword.
bool shadows_other_tokens = false;
for (Symbol::Index j = 0; j < n; j++) {
Symbol other_symbol = Symbol::terminal(j);
if ((get_conflict_status(other_symbol, symbol) & (MatchesShorterStringWithinSeparators|MatchesLongerStringWithValidNextChar)) &&
!keyword_symbols.contains(other_symbol) &&
keyword_symbols.intersects(coincident_tokens_by_token[j])) {
shadows_other_tokens = true;
break;
Symbol other_token = Symbol::terminal(j);
if (other_token == candidate || homonyms.contains(other_token)) continue;
bool candidate_shadows_other = get_conflict_status(other_token, candidate);
bool other_shadows_candidate = get_conflict_status(candidate, other_token);
if (candidate_shadows_other || other_shadows_candidate) {
homonyms.for_each([&](Symbol homonym) {
bool other_shadows_homonym = get_conflict_status(homonym, other_token);
bool candidate_was_already_present = true;
for (ParseStateId state_id : coincident_token_index.states_with(homonym, other_token)) {
if (!parse_table->states[state_id].has_terminal_entry(candidate)) {
candidate_was_already_present = false;
break;
}
}
if (candidate_was_already_present) return true;
if (candidate_shadows_other) {
homonyms.remove(homonym);
LOG(
"remove %s because candidate would shadow %s",
token_name(homonym).c_str(),
token_name(other_token).c_str()
);
} else if (other_shadows_candidate && !other_shadows_homonym) {
homonyms.remove(homonym);
LOG(
"remove %s because %s would shadow candidate",
token_name(homonym).c_str(),
token_name(other_token).c_str()
);
}
return true;
});
}
}
if (shadows_other_tokens) continue;
// If multiple keyword capture tokens are found, don't bother extracting
// the keywords into their own function.
if (keyword_capture_token == rules::NONE()) {
keyword_capture_token = symbol;
} else {
keyword_capture_token = rules::NONE();
break;
if (homonyms.size() > keyword_symbols.size()) {
LOG_START("found capture token. homonyms:");
homonyms.for_each([&](Symbol homonym) {
LOG("%s", token_name(homonym).c_str());
return true;
});
LOG_END();
keyword_symbols = homonyms;
keyword_capture_token = candidate;
}
LOG_END();
}
LOG_END();
}
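In outline, the loop above does the following for each candidate token: collect the potential keywords whose text the candidate also matches exactly (its "homonyms"); bail out if the candidate can match a shorter string within separators or its characters overlap the separator characters; prune homonyms that collide with each other or that would be shadowed in coincident parse states; and keep the candidate only if it captures more keywords than the best found so far. A toy, self-contained illustration of just the homonym-collection step, under the simplifying assumption that tokens are literal strings and the candidate is an identifier pattern:

#include <cstdio>
#include <regex>
#include <string>
#include <vector>

int main() {
  // Stand-in for MatchesSameString: the candidate's pattern accepts exactly
  // the keyword's text. Names here are illustrative, not tree-sitter APIs.
  std::regex identifier_pattern("[a-zA-Z_][a-zA-Z0-9_]*");
  std::vector<std::string> potential_keywords = {"if", "else", "new", "in"};
  std::vector<std::string> homonyms;
  for (const std::string &keyword : potential_keywords) {
    if (std::regex_match(keyword, identifier_pattern)) homonyms.push_back(keyword);
  }
  // The candidate capturing the most homonyms becomes the keyword capture token.
  std::printf("captured %zu keywords\n", homonyms.size());
  return 0;
}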
BuildResult build(ParseTable *parse_table) {
BuildResult build() {
clear();
conflict_detection_mode = false;
vector<pair<LookaheadSet, vector<ParseState *>>> starting_token_sets;
@@ -250,8 +335,8 @@ class LexTableBuilderImpl : public LexTableBuilder {
add_lex_state(keyword_lex_table, item_set_for_terminals(keyword_symbols, false));
mark_fragile_tokens(parse_table);
remove_duplicate_lex_states(main_lex_table, parse_table);
mark_fragile_tokens();
remove_duplicate_lex_states(main_lex_table);
return {main_lex_table, keyword_lex_table, keyword_capture_token};
}
@@ -266,10 +351,11 @@
private:
bool record_conflict(Symbol shadowed_token, Symbol other_token, ConflictStatus status) {
if (!conflict_detection_mode) return false;
unsigned index = shadowed_token.index * grammar.variables.size() + other_token.index;
bool old_value = conflict_matrix[index] & status;
bool was_set = conflict_matrix[index] & status;
conflict_matrix[index] = static_cast<ConflictStatus>(conflict_matrix[index] | status);
return old_value;
return !was_set;
}
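The change to record_conflict flips its return value: it previously reported whether the status bit was already set, and now reports whether this call newly set it, which appears to be what lets the callers below log each distinct conflict exactly once. A standalone sketch of the set-and-test idiom, with illustrative names:

#include <cstdint>

// OR the status into the entry and report whether this call was the first
// to record it.
bool record_status(std::uint8_t &entry, std::uint8_t status) {
  bool was_set = (entry & status) != 0;
  entry = static_cast<std::uint8_t>(entry | status);
  return !was_set;
}

int main() {
  std::uint8_t entry = 0;
  bool first = record_status(entry, 1u << 2);  // true: newly recorded
  bool again = record_status(entry, 1u << 2);  // false: already recorded
  return (first && !again) ? 0 : 1;
}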
LexStateId add_lex_state(LexTable &lex_table, const LexItemSet &item_set) {
@@ -313,8 +399,12 @@ class LexTableBuilderImpl : public LexTableBuilder {
auto advance_symbol = transition.destination.entries.begin()->lhs;
auto &following_chars = following_characters_by_token[accept_action.symbol.index];
CharacterSet conflicting_following_chars = characters.intersection(following_chars);
CharacterSet conflicting_sep_chars = characters.intersection(separator_start_characters);
if (!conflicting_following_chars.is_empty()) {
if (conflicting_following_chars.is_empty()) {
conflicting_following_chars = characters.intersection(separator_start_characters);
}
if (conflicting_following_chars.is_empty()) {
record_conflict(accept_action.symbol, advance_symbol, MatchesLongerString);
} else {
if (record_conflict(
accept_action.symbol,
advance_symbol,
@@ -327,21 +417,6 @@ class LexTableBuilderImpl : public LexTableBuilder {
log_char(*conflicting_following_chars.included_chars.begin())
);
}
} else if (!conflicting_sep_chars.is_empty()) {
if (record_conflict(
accept_action.symbol,
advance_symbol,
MatchesLongerStringWithValidNextChar
)) {
LOG(
"%s shadows %s followed by '%s'",
token_name(advance_symbol).c_str(),
token_name(accept_action.symbol).c_str(),
log_char(*conflicting_sep_chars.included_chars.begin())
);
}
} else {
record_conflict(accept_action.symbol, advance_symbol, MatchesLongerString);
}
}
}
@@ -364,9 +439,21 @@ class LexTableBuilderImpl : public LexTableBuilder {
AcceptTokenAction &existing_action = lex_table.states[state_id].accept_action;
if (existing_action.is_present()) {
if (should_replace_accept_action(existing_action, action)) {
record_conflict(existing_action.symbol, action.symbol, MatchesSameString);
if (record_conflict(existing_action.symbol, action.symbol, MatchesSameString)) {
LOG(
"%s shadows %s - same length",
token_name(action.symbol).c_str(),
token_name(existing_action.symbol).c_str()
);
}
} else {
record_conflict(action.symbol, existing_action.symbol, MatchesSameString);
if (record_conflict(action.symbol, existing_action.symbol, MatchesSameString)) {
LOG(
"%s shadows %s - same length",
token_name(existing_action.symbol).c_str(),
token_name(action.symbol).c_str()
);
}
continue;
}
}
@@ -375,7 +462,7 @@ class LexTableBuilderImpl : public LexTableBuilder {
}
}
void mark_fragile_tokens(ParseTable *parse_table) {
void mark_fragile_tokens() {
for (ParseState &state : parse_table->states) {
for (auto &entry : state.terminal_entries) {
Symbol token = entry.first;
@@ -401,7 +488,7 @@ class LexTableBuilderImpl : public LexTableBuilder {
const LookaheadSet &existing_set = in_left ? right : *left;
existing_set.for_each([&](Symbol existing_symbol) {
if ((get_conflict_status(existing_symbol, different_symbol) & CannotDistinguish) ||
!coincident_tokens_by_token[different_symbol.index].contains(existing_symbol)) {
!coincident_token_index.contains(different_symbol, existing_symbol)) {
is_compatible = false;
return false;
}
@@ -417,7 +504,7 @@ class LexTableBuilderImpl : public LexTableBuilder {
return is_compatible;
}
void remove_duplicate_lex_states(LexTable &lex_table, ParseTable *parse_table) {
void remove_duplicate_lex_states(LexTable &lex_table) {
for (LexState &state : lex_table.states) {
state.accept_action.is_string = false;
state.accept_action.precedence = 0;
@@ -541,7 +628,7 @@ class LexTableBuilderImpl : public LexTableBuilder {
main_lex_state_ids.clear();
}
string token_name(rules::Symbol &symbol) {
string token_name(const rules::Symbol &symbol) {
const LexicalVariable &variable = grammar.variables[symbol.index];
if (variable.type == VariableTypeNamed) {
return variable.name;
@@ -563,17 +650,19 @@
unique_ptr<LexTableBuilder> LexTableBuilder::create(const SyntaxGrammar &syntax_grammar,
const LexicalGrammar &lexical_grammar,
const unordered_map<Symbol, LookaheadSet> &following_tokens,
const vector<LookaheadSet> &coincident_tokens) {
const CoincidentTokenIndex &coincident_tokens,
ParseTable *parse_table) {
return unique_ptr<LexTableBuilder>(new LexTableBuilderImpl(
syntax_grammar,
lexical_grammar,
following_tokens,
coincident_tokens
coincident_tokens,
parse_table
));
}
LexTableBuilder::BuildResult LexTableBuilder::build(ParseTable *parse_table) {
return static_cast<LexTableBuilderImpl *>(this)->build(parse_table);
LexTableBuilder::BuildResult LexTableBuilder::build() {
return static_cast<LexTableBuilderImpl *>(this)->build();
}
ConflictStatus LexTableBuilder::get_conflict_status(Symbol a, Symbol b) const {

@@ -4,9 +4,22 @@
#include <memory>
#include <vector>
#include <unordered_map>
#include <set>
#include <unordered_set>
#include <utility>
#include "compiler/parse_table.h"
#include "compiler/lex_table.h"
namespace std {
using tree_sitter::rules::Symbol;
template <>
struct hash<pair<Symbol::Index, Symbol::Index>> {
size_t operator()(const pair<Symbol::Index, Symbol::Index> &) const;
};
} // namespace std
namespace tree_sitter {
struct ParseTable;
@@ -30,12 +43,23 @@ enum ConflictStatus {
),
};
struct CoincidentTokenIndex {
std::unordered_map<
std::pair<rules::Symbol::Index, rules::Symbol::Index>,
std::unordered_set<ParseStateId>
> entries;
bool contains(rules::Symbol, rules::Symbol) const;
const std::unordered_set<ParseStateId> &states_with(rules::Symbol, rules::Symbol) const;
};
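CoincidentTokenIndex replaces the previous vector<LookaheadSet>: instead of only recording which tokens co-occur, it maps each token pair to the set of parse states where they co-occur, which the builder uses to check whether a capture candidate was already present in those states. A self-contained analogue using an ordered map (the real struct uses the unordered_map and hash specialization above; names here are illustrative):

#include <algorithm>
#include <map>
#include <set>
#include <utility>

using StateId = unsigned;

struct ToyCoincidentIndex {
  // Keys are stored smaller-index-first, matching the population code.
  std::map<std::pair<int, int>, std::set<StateId>> entries;

  const std::set<StateId> &states_with(int a, int b) const {
    static const std::set<StateId> empty;
    if (a > b) std::swap(a, b);
    auto iter = entries.find({a, b});
    return iter == entries.end() ? empty : iter->second;
  }

  bool contains(int a, int b) const {
    return a == b || !states_with(a, b).empty();
  }
};

int main() {
  ToyCoincidentIndex index;
  index.entries[{3, 7}].insert(12);  // tokens 3 and 7 co-occur in state 12
  return index.contains(7, 3) ? 0 : 1;  // argument order does not matter
}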
class LexTableBuilder {
public:
static std::unique_ptr<LexTableBuilder> create(const SyntaxGrammar &,
const LexicalGrammar &,
const std::unordered_map<rules::Symbol, LookaheadSet> &,
const std::vector<LookaheadSet> &);
const CoincidentTokenIndex &,
ParseTable *);
struct BuildResult {
LexTable main_table;
@@ -43,7 +67,7 @@ class LexTableBuilder {
rules::Symbol keyword_capture_token;
};
BuildResult build(ParseTable *);
BuildResult build();
ConflictStatus get_conflict_status(rules::Symbol, rules::Symbol) const;

@@ -117,5 +117,31 @@ bool LookaheadSet::insert(const Symbol &symbol) {
return false;
}
bool LookaheadSet::remove(const Symbol &symbol) {
if (symbol == rules::END_OF_INPUT()) {
if (eof) {
eof = false;
return true;
}
return false;
}
auto &bits = symbol.is_external() ? external_bits : terminal_bits;
if (bits.size() > static_cast<size_t>(symbol.index)) {
if (bits[symbol.index]) {
bits[symbol.index] = false;
return true;
}
}
return false;
}
void LookaheadSet::clear() {
eof = false;
terminal_bits.clear();
external_bits.clear();
}
} // namespace build_tables
} // namespace tree_sitter
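Like insert, the new remove reports whether the set actually changed: true only when the eof flag or the corresponding bit was previously set. A minimal standalone analogue of the bit-vector branch:

#include <cstddef>
#include <vector>

// Clear the bit if present; report whether the set changed.
bool remove_index(std::vector<bool> &bits, std::size_t index) {
  if (bits.size() > index && bits[index]) {
    bits[index] = false;
    return true;
  }
  return false;
}

int main() {
  std::vector<bool> bits(8, false);
  bits[3] = true;
  bool changed = remove_index(bits, 3);  // true
  bool again = remove_index(bits, 3);    // false: already removed
  return (changed && !again) ? 0 : 1;
}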

@@ -22,6 +22,8 @@ class LookaheadSet {
bool contains(const rules::Symbol &) const;
bool insert_all(const LookaheadSet &);
bool insert(const rules::Symbol &);
bool remove(const rules::Symbol &);
void clear();
bool intersects(const LookaheadSet &) const;
template <typename Callback>

@@ -52,28 +52,14 @@ class ParseTableBuilderImpl : public ParseTableBuilder {
ParseItemSetBuilder item_set_builder;
unique_ptr<LexTableBuilder> lex_table_builder;
unordered_map<Symbol, LookaheadSet> following_tokens_by_token;
vector<LookaheadSet> coincident_tokens_by_token;
CoincidentTokenIndex coincident_token_index;
set<std::pair<Symbol, Symbol>> logged_conflict_tokens;
public:
ParseTableBuilderImpl(const SyntaxGrammar &syntax_grammar, const LexicalGrammar &lexical_grammar)
: grammar(syntax_grammar),
lexical_grammar(lexical_grammar),
item_set_builder(syntax_grammar, lexical_grammar),
coincident_tokens_by_token(lexical_grammar.variables.size()) {
for (unsigned i = 0, n = lexical_grammar.variables.size(); i < n; i++) {
coincident_tokens_by_token[i].insert(rules::END_OF_INPUT());
if (lexical_grammar.variables[i].is_string) {
for (unsigned j = 0; j < i; j++) {
if (lexical_grammar.variables[j].is_string) {
coincident_tokens_by_token[i].insert(Symbol::terminal(j));
coincident_tokens_by_token[j].insert(Symbol::terminal(i));
}
}
}
}
}
item_set_builder(syntax_grammar, lexical_grammar) {}
BuildResult build() {
// Ensure that the empty rename sequence has index 0.
@@ -106,7 +92,8 @@ class ParseTableBuilderImpl : public ParseTableBuilder {
grammar,
lexical_grammar,
following_tokens_by_token,
coincident_tokens_by_token
coincident_token_index,
&parse_table
);
build_error_parse_state(error_state_id);
@@ -115,7 +102,7 @@ class ParseTableBuilderImpl : public ParseTableBuilder {
eliminate_unit_reductions();
populate_used_terminals();
auto lex_table_result = lex_table_builder->build(&parse_table);
auto lex_table_result = lex_table_builder->build();
return {
parse_table,
lex_table_result.main_table,
@@ -161,8 +148,7 @@ class ParseTableBuilderImpl : public ParseTableBuilder {
bool conflicts_with_other_tokens = false;
for (unsigned j = 0; j < lexical_grammar.variables.size(); j++) {
Symbol other_token = Symbol::terminal(j);
if (j != i &&
!coincident_tokens_by_token[token.index].contains(other_token) &&
if (!coincident_token_index.contains(token, other_token) &&
(lex_table_builder->get_conflict_status(other_token, token) & CannotMerge)) {
conflicts_with_other_tokens = true;
break;
@@ -184,7 +170,7 @@ class ParseTableBuilderImpl : public ParseTableBuilder {
} else {
bool conflicts_with_other_tokens = false;
conflict_free_tokens.for_each([&](Symbol other_token) {
if (!coincident_tokens_by_token[token.index].contains(other_token) &&
if (!coincident_token_index.contains(token, other_token) &&
(lex_table_builder->get_conflict_status(other_token, token) & CannotMerge)) {
LOG(
"exclude %s: conflicts with %s",
@@ -332,8 +318,10 @@ class ParseTableBuilderImpl : public ParseTableBuilder {
if (iter->first.is_built_in() || iter->first.is_external()) continue;
for (auto other_iter = terminals.begin(); other_iter != iter; ++other_iter) {
if (other_iter->first.is_built_in() || other_iter->first.is_external()) continue;
coincident_tokens_by_token[iter->first.index].insert(other_iter->first);
coincident_tokens_by_token[other_iter->first.index].insert(iter->first);
coincident_token_index.entries[{
other_iter->first.index,
iter->first.index
}].insert(state_id);
}
}

@@ -20,7 +20,6 @@ void _print_indent();
#define LOG_END(...) \
do { \
_outdent_logs(); \
LOG(""); \
} while (0)
#define LOG(...) \

@@ -123,6 +123,10 @@ bool ParseState::has_shift_action() const {
return (!nonterminal_entries.empty());
}
bool ParseState::has_terminal_entry(rules::Symbol symbol) const {
return terminal_entries.find(symbol) != terminal_entries.end();
}
void ParseState::each_referenced_state(function<void(ParseStateId *)> fn) {
for (auto &entry : terminal_entries)
for (ParseAction &action : entry.second.actions)

@@ -65,6 +65,7 @@ struct ParseState {
bool merge(const ParseState &);
void each_referenced_state(std::function<void(ParseStateId *)>);
bool has_shift_action() const;
bool has_terminal_entry(rules::Symbol) const;
std::map<rules::Symbol, ParseTableEntry> terminal_entries;
std::map<rules::Symbol::Index, ParseStateId> nonterminal_entries;