Create separate lexer function for keywords

This commit is contained in:
Max Brunsfeld 2018-03-07 11:56:59 -08:00
parent 16cdd2ffbe
commit c0cc35ff07
11 changed files with 231 additions and 114 deletions

View file

@ -81,6 +81,8 @@ typedef struct TSLanguage {
const TSSymbol *alias_sequences;
uint16_t max_alias_sequence_length;
bool (*lex_fn)(TSLexer *, TSStateId);
bool (*keyword_lex_fn)(TSLexer *, TSStateId);
TSSymbol keyword_capture_token;
struct {
const bool *states;
const TSSymbol *symbol_map;

View file

@ -9,7 +9,7 @@ extern "C" {
#include <stdint.h>
#include <stdbool.h>
#define TREE_SITTER_LANGUAGE_VERSION 6
#define TREE_SITTER_LANGUAGE_VERSION 7
typedef unsigned short TSSymbol;
typedef struct TSLanguage TSLanguage;

View file

@ -34,13 +34,14 @@ using rules::Symbol;
using rules::Metadata;
using rules::Seq;
template <bool is_start>
template <bool include_all>
class StartOrEndCharacterAggregator {
public:
void apply(const Rule &rule) {
rule.match(
[this](const Seq &sequence) {
apply(is_start ? *sequence.left : *sequence.right);
apply(*sequence.left);
if (include_all) apply(*sequence.right);
},
[this](const rules::Choice &rule) {
@ -59,15 +60,17 @@ class StartOrEndCharacterAggregator {
CharacterSet result;
};
using StartingCharacterAggregator = StartOrEndCharacterAggregator<true>;
using EndingCharacterAggregator = StartOrEndCharacterAggregator<false>;
using StartingCharacterAggregator = StartOrEndCharacterAggregator<false>;
using AllCharacterAggregator = StartOrEndCharacterAggregator<true>;
class LexTableBuilderImpl : public LexTableBuilder {
LexTable lex_table;
LexTable main_lex_table;
LexTable keyword_lex_table;
const LexicalGrammar grammar;
vector<Rule> separator_rules;
LexConflictManager conflict_manager;
unordered_map<LexItemSet, LexStateId> lex_state_ids;
unordered_map<LexItemSet, LexStateId> main_lex_state_ids;
unordered_map<LexItemSet, LexStateId> keyword_lex_state_ids;
CharacterSet separator_start_characters;
vector<CharacterSet> starting_characters_by_token;
vector<CharacterSet> following_characters_by_token;
@ -75,6 +78,8 @@ class LexTableBuilderImpl : public LexTableBuilder {
const vector<LookaheadSet> &coincident_tokens_by_token;
vector<bool> conflict_status_by_token;
bool conflict_detection_mode;
LookaheadSet keyword_symbols;
Symbol keyword_capture_token;
public:
LexTableBuilderImpl(const SyntaxGrammar &syntax_grammar,
@ -86,7 +91,8 @@ class LexTableBuilderImpl : public LexTableBuilder {
following_characters_by_token(lexical_grammar.variables.size()),
shadowed_tokens_by_token(lexical_grammar.variables.size()),
coincident_tokens_by_token(coincident_tokens),
conflict_detection_mode(false) {
conflict_detection_mode(false),
keyword_capture_token(rules::NONE()) {
// Compute the possible separator rules and the set of separator characters that can occur
// immediately after any token.
@ -113,18 +119,32 @@ class LexTableBuilderImpl : public LexTableBuilder {
});
}
// TODO - Refactor this. In general, a keyword token cannot be followed immediately by
// another alphanumeric character. But this requirement is currently not expressed anywhere in
// the grammar. So without this hack, we would be overly conservative about merging parse
// states because we would often consider `identifier` tokens to *conflict* with keyword
// tokens.
if (is_keyword(grammar.variables[i])) {
following_character_aggregator.result
.exclude('a', 'z')
.exclude('A', 'Z')
.exclude('0', '9')
.exclude('_')
.exclude('$');
if (grammar.variables[i].is_string) {
AllCharacterAggregator aggregator;
aggregator.apply(grammar.variables[i].rule);
bool all_alpha = true, all_lower = true;
for (auto character : aggregator.result.included_chars) {
if (!iswalpha(character)) all_alpha = true;
if (!iswlower(character)) all_lower = false;
}
if (all_lower) {
keyword_symbols.insert(Symbol::terminal(i));
}
// TODO - Refactor this. In general, a keyword token cannot be followed immediately
// by another alphanumeric character. But this requirement is currently not expressed
// anywhere in the grammar. So without this hack, we would be overly conservative about
// merging parse states because we would often consider `identifier` tokens to *conflict*
// with keyword tokens.
if (all_alpha) {
following_character_aggregator.result
.exclude('a', 'z')
.exclude('A', 'Z')
.exclude('0', '9')
.exclude('_')
.exclude('$');
}
}
following_characters_by_token[i] = following_character_aggregator.result;
@ -139,18 +159,35 @@ class LexTableBuilderImpl : public LexTableBuilder {
starting_characters_by_token[i].intersects(separator_start_characters) ||
starting_characters_by_token[j].intersects(separator_start_characters)) {
clear();
add_lex_state(item_set_for_terminals(LookaheadSet({
add_lex_state(main_lex_table, item_set_for_terminals(LookaheadSet({
Symbol::terminal(i),
Symbol::terminal(j)
})));
}), true));
if (conflict_status_by_token[i]) shadowed_tokens_by_token[j].insert(Symbol::terminal(i));
if (conflict_status_by_token[j]) shadowed_tokens_by_token[i].insert(Symbol::terminal(j));
}
}
}
for (Symbol::Index i = 0, n = grammar.variables.size(); i < n; i++) {
Symbol symbol = Symbol::terminal(i);
bool matches_all_keywords = true;
keyword_symbols.for_each([&](Symbol keyword_symbol) {
if (!shadowed_tokens_by_token[keyword_symbol.index].count(symbol)) {
matches_all_keywords = false;
}
});
if (matches_all_keywords && (
keyword_capture_token == rules::NONE() ||
shadowed_tokens_by_token[symbol.index].size() <
shadowed_tokens_by_token[keyword_capture_token.index].size()
)) keyword_capture_token = symbol;
}
}
LexTable build(ParseTable *parse_table) {
BuildResult build(ParseTable *parse_table) {
clear();
conflict_detection_mode = false;
vector<pair<LookaheadSet, vector<ParseState *>>> starting_token_sets;
@ -158,7 +195,11 @@ class LexTableBuilderImpl : public LexTableBuilder {
for (ParseState &parse_state : parse_table->states) {
LookaheadSet token_set;
for (auto &entry : parse_state.terminal_entries) {
token_set.insert(entry.first);
if (keyword_capture_token.is_terminal() && keyword_symbols.contains(entry.first)) {
token_set.insert(keyword_capture_token);
} else {
token_set.insert(entry.first);
}
}
bool did_merge = false;
@ -174,14 +215,17 @@ class LexTableBuilderImpl : public LexTableBuilder {
}
for (auto &pair : starting_token_sets) {
LexStateId state_id = add_lex_state(item_set_for_terminals(pair.first));
LexStateId state_id = add_lex_state(main_lex_table, item_set_for_terminals(pair.first, true));
for (ParseState *parse_state : pair.second) {
parse_state->lex_state_id = state_id;
}
}
add_lex_state(keyword_lex_table, item_set_for_terminals(keyword_symbols, false));
mark_fragile_tokens(parse_table);
remove_duplicate_lex_states(parse_table);
return lex_table;
remove_duplicate_lex_states(main_lex_table, parse_table);
return {main_lex_table, keyword_lex_table, keyword_capture_token};
}
const set<Symbol> &get_incompatible_tokens(Symbol::Index index) const {
@ -189,36 +233,24 @@ class LexTableBuilderImpl : public LexTableBuilder {
}
private:
bool is_keyword(const LexicalVariable &variable) {
EndingCharacterAggregator aggregator;
aggregator.apply(variable.rule);
return
!aggregator.result.includes_all &&
aggregator.result.included_chars.size() == 1 &&
iswalpha(*aggregator.result.included_chars.begin());
}
LexStateId add_lex_state(const LexItemSet &item_set) {
LexStateId add_lex_state(LexTable &lex_table, const LexItemSet &item_set) {
auto &lex_state_ids = &lex_table == &main_lex_table ?
main_lex_state_ids :
keyword_lex_state_ids;
const auto &pair = lex_state_ids.find(item_set);
if (pair == lex_state_ids.end()) {
LexStateId state_id = lex_table.states.size();
lex_table.states.push_back(LexState());
lex_state_ids[item_set] = state_id;
add_accept_token_actions(item_set, state_id);
add_advance_actions(item_set, state_id);
add_accept_token_actions(lex_table, item_set, state_id);
add_advance_actions(lex_table, item_set, state_id);
return state_id;
} else {
return pair->second;
}
}
void clear() {
lex_table.states.clear();
lex_state_ids.clear();
conflict_status_by_token = vector<bool>(grammar.variables.size(), false);
}
void add_advance_actions(const LexItemSet &item_set, LexStateId state_id) {
void add_advance_actions(LexTable &lex_table, const LexItemSet &item_set, LexStateId state_id) {
for (const auto &pair : item_set.transitions()) {
const CharacterSet &characters = pair.first;
const LexItemSet::Transition &transition = pair.second;
@ -253,12 +285,12 @@ class LexTableBuilderImpl : public LexTableBuilder {
if (!prefer_advancing) continue;
}
action.state_index = add_lex_state(transition.destination);
action.state_index = add_lex_state(lex_table, transition.destination);
lex_table.states[state_id].advance_actions[characters] = action;
}
}
void add_accept_token_actions(const LexItemSet &item_set, LexStateId state_id) {
void add_accept_token_actions(LexTable &lex_table, const LexItemSet &item_set, LexStateId state_id) {
for (const LexItem &item : item_set.entries) {
LexItem::CompletionStatus completion_status = item.completion_status();
if (completion_status.is_done) {
@ -340,7 +372,7 @@ class LexTableBuilderImpl : public LexTableBuilder {
return is_compatible;
}
void remove_duplicate_lex_states(ParseTable *parse_table) {
void remove_duplicate_lex_states(LexTable &lex_table, ParseTable *parse_table) {
for (LexState &state : lex_table.states) {
state.accept_action.is_string = false;
state.accept_action.precedence = 0;
@ -407,22 +439,26 @@ class LexTableBuilderImpl : public LexTableBuilder {
}
}
LexItemSet item_set_for_terminals(const LookaheadSet &terminals) {
LexItemSet item_set_for_terminals(const LookaheadSet &terminals, bool with_separators) {
LexItemSet result;
terminals.for_each([&](Symbol symbol) {
if (symbol.is_terminal()) {
for (const auto &rule : rules_for_symbol(symbol)) {
for (const auto &separator_rule : separator_rules) {
result.entries.insert(LexItem(
symbol,
Metadata::separator(
Rule::seq({
separator_rule,
Metadata::main_token(rule)
})
)
));
}
if (with_separators) {
for (const auto &separator_rule : separator_rules) {
result.entries.insert(LexItem(
symbol,
Metadata::separator(
Rule::seq({
separator_rule,
Metadata::main_token(rule)
})
)
));
}
} else {
result.entries.insert(LexItem(symbol, Metadata::main_token(rule)));
}
}
}
});
@ -444,6 +480,12 @@ class LexTableBuilderImpl : public LexTableBuilder {
}
);
}
void clear() {
main_lex_table.states.clear();
main_lex_state_ids.clear();
conflict_status_by_token = vector<bool>(grammar.variables.size(), false);
}
};
unique_ptr<LexTableBuilder> LexTableBuilder::create(const SyntaxGrammar &syntax_grammar,
@ -458,7 +500,7 @@ unique_ptr<LexTableBuilder> LexTableBuilder::create(const SyntaxGrammar &syntax_
));
}
LexTable LexTableBuilder::build(ParseTable *parse_table) {
LexTableBuilder::BuildResult LexTableBuilder::build(ParseTable *parse_table) {
return static_cast<LexTableBuilderImpl *>(this)->build(parse_table);
}

View file

@ -23,7 +23,14 @@ class LexTableBuilder {
const LexicalGrammar &,
const std::unordered_map<rules::Symbol, LookaheadSet> &,
const std::vector<LookaheadSet> &);
LexTable build(ParseTable *);
struct BuildResult {
LexTable main_table;
LexTable keyword_table;
rules::Symbol keyword_capture_token;
};
BuildResult build(ParseTable *);
const std::set<rules::Symbol> &get_incompatible_tokens(rules::Symbol::Index) const;
protected:

View file

@ -73,7 +73,7 @@ class ParseTableBuilderImpl : public ParseTableBuilder {
}
}
tuple<ParseTable, LexTable, CompileError> build() {
BuildResult build() {
// Ensure that the empty rename sequence has index 0.
parse_table.alias_sequences.push_back({});
@ -92,7 +92,13 @@ class ParseTableBuilderImpl : public ParseTableBuilder {
}});
CompileError error = process_part_state_queue();
if (error) return make_tuple(parse_table, LexTable(), error);
if (error) return {
parse_table,
LexTable(),
LexTable(),
rules::NONE(),
error,
};
lex_table_builder = LexTableBuilder::create(
grammar,
@ -105,8 +111,14 @@ class ParseTableBuilderImpl : public ParseTableBuilder {
remove_precedence_values();
remove_duplicate_parse_states();
auto lex_table = lex_table_builder->build(&parse_table);
return make_tuple(parse_table, lex_table, CompileError::none());
auto lex_table_result = lex_table_builder->build(&parse_table);
return {
parse_table,
lex_table_result.main_table,
lex_table_result.keyword_table,
lex_table_result.keyword_capture_token,
CompileError::none()
};
}
private:
@ -770,7 +782,7 @@ unique_ptr<ParseTableBuilder> ParseTableBuilder::create(
return unique_ptr<ParseTableBuilder>(new ParseTableBuilderImpl(syntax_grammar, lexical_grammar));
}
tuple<ParseTable, LexTable, CompileError> ParseTableBuilder::build() {
ParseTableBuilder::BuildResult ParseTableBuilder::build() {
return static_cast<ParseTableBuilderImpl *>(this)->build();
}

View file

@ -17,7 +17,16 @@ namespace build_tables {
class ParseTableBuilder {
public:
static std::unique_ptr<ParseTableBuilder> create(const SyntaxGrammar &, const LexicalGrammar &);
std::tuple<ParseTable, LexTable, CompileError> build();
struct BuildResult {
ParseTable parse_table;
LexTable main_lex_table;
LexTable keyword_lex_table;
rules::Symbol keyword_capture_token;
CompileError error;
};
BuildResult build();
protected:
ParseTableBuilder() = default;

View file

@ -9,6 +9,7 @@
namespace tree_sitter {
using std::move;
using std::pair;
using std::string;
using std::vector;
@ -23,26 +24,32 @@ extern "C" TSCompileResult ts_compile_grammar(const char *input) {
}
auto prepare_grammar_result = prepare_grammar::prepare_grammar(parse_result.grammar);
const SyntaxGrammar &syntax_grammar = get<0>(prepare_grammar_result);
const LexicalGrammar &lexical_grammar = get<1>(prepare_grammar_result);
SyntaxGrammar &syntax_grammar = get<0>(prepare_grammar_result);
LexicalGrammar &lexical_grammar = get<1>(prepare_grammar_result);
CompileError error = get<2>(prepare_grammar_result);
if (error.type) {
return { nullptr, strdup(error.message.c_str()), error.type };
return {nullptr, strdup(error.message.c_str()), error.type};
}
auto builder = build_tables::ParseTableBuilder::create(syntax_grammar, lexical_grammar);
auto table_build_result = builder->build();
const ParseTable &parse_table = get<0>(table_build_result);
const LexTable &lex_table = get<1>(table_build_result);
error = get<2>(table_build_result);
if (error.type) {
return { nullptr, strdup(error.message.c_str()), error.type };
auto build_tables_result = builder->build();
error = build_tables_result.error;
if (error.type != 0) {
return {nullptr, strdup(error.message.c_str()), error.type};
}
string code = generate_code::c_code(parse_result.name, parse_table, lex_table,
syntax_grammar, lexical_grammar);
string code = generate_code::c_code(
parse_result.name,
move(build_tables_result.parse_table),
move(build_tables_result.main_lex_table),
move(build_tables_result.keyword_lex_table),
build_tables_result.keyword_capture_token,
move(syntax_grammar),
move(lexical_grammar)
);
return { strdup(code.c_str()), nullptr, TSCompileErrorTypeNone };
return {
strdup(code.c_str()), nullptr, TSCompileErrorTypeNone };
}
} // namespace tree_sitter

View file

@ -18,6 +18,7 @@ namespace generate_code {
using std::function;
using std::map;
using std::move;
using std::pair;
using std::set;
using std::string;
@ -70,7 +71,9 @@ class CCodeGenerator {
const string name;
const ParseTable parse_table;
const LexTable lex_table;
const LexTable main_lex_table;
const LexTable keyword_lex_table;
Symbol keyword_capture_token;
const SyntaxGrammar syntax_grammar;
const LexicalGrammar lexical_grammar;
map<string, string> sanitized_names;
@ -80,15 +83,17 @@ class CCodeGenerator {
set<Alias> unique_aliases;
public:
CCodeGenerator(string name, const ParseTable &parse_table,
const LexTable &lex_table, const SyntaxGrammar &syntax_grammar,
const LexicalGrammar &lexical_grammar)
CCodeGenerator(string name, ParseTable &&parse_table, LexTable &&main_lex_table,
LexTable &&keyword_lex_table, Symbol keyword_capture_token,
SyntaxGrammar &&syntax_grammar, LexicalGrammar &&lexical_grammar)
: indent_level(0),
name(name),
parse_table(parse_table),
lex_table(lex_table),
syntax_grammar(syntax_grammar),
lexical_grammar(lexical_grammar),
parse_table(move(parse_table)),
main_lex_table(move(main_lex_table)),
keyword_lex_table(move(keyword_lex_table)),
keyword_capture_token(keyword_capture_token),
syntax_grammar(move(syntax_grammar)),
lexical_grammar(move(lexical_grammar)),
next_parse_action_list_index(0) {}
string code() {
@ -105,7 +110,12 @@ class CCodeGenerator {
add_alias_sequences();
}
add_lex_function();
add_lex_function("ts_lex", main_lex_table);
if (keyword_capture_token != rules::NONE()) {
add_lex_function("ts_lex_keywords", keyword_lex_table);
}
add_lex_modes_list();
if (!syntax_grammar.external_tokens.empty()) {
@ -273,8 +283,8 @@ class CCodeGenerator {
line();
}
void add_lex_function() {
line("static bool ts_lex(TSLexer *lexer, TSStateId state) {");
void add_lex_function(string name, const LexTable &lex_table) {
line("static bool " + name + "(TSLexer *lexer, TSStateId state) {");
indent([&]() {
line("START_LEXER();");
_switch("state", [&]() {
@ -457,6 +467,12 @@ class CCodeGenerator {
line(".max_alias_sequence_length = MAX_ALIAS_SEQUENCE_LENGTH,");
line(".lex_fn = ts_lex,");
if (keyword_capture_token != rules::NONE()) {
line(".keyword_lex_fn = ts_lex_keywords,");
line(".keyword_capture_token = " + symbol_id(keyword_capture_token) + ",");
}
line(".external_token_count = EXTERNAL_TOKEN_COUNT,");
if (!syntax_grammar.external_tokens.empty()) {
@ -832,15 +848,17 @@ class CCodeGenerator {
}
};
string c_code(string name, const ParseTable &parse_table,
const LexTable &lex_table, const SyntaxGrammar &syntax_grammar,
const LexicalGrammar &lexical_grammar) {
string c_code(string name, ParseTable &&parse_table, LexTable &&lex_table,
LexTable &&keyword_lex_table, Symbol keyword_capture_token,
SyntaxGrammar &&syntax_grammar, LexicalGrammar &&lexical_grammar) {
return CCodeGenerator(
name,
parse_table,
lex_table,
syntax_grammar,
lexical_grammar
move(parse_table),
move(lex_table),
move(keyword_lex_table),
keyword_capture_token,
move(syntax_grammar),
move(lexical_grammar)
).code();
}

View file

@ -2,6 +2,7 @@
#define COMPILER_GENERATE_CODE_C_CODE_H_
#include <string>
#include "compiler/rule.h"
namespace tree_sitter {
@ -12,8 +13,15 @@ struct ParseTable;
namespace generate_code {
std::string c_code(std::string, const ParseTable &, const LexTable &,
const SyntaxGrammar &, const LexicalGrammar &);
std::string c_code(
std::string,
ParseTable &&,
LexTable &&,
LexTable &&,
rules::Symbol,
SyntaxGrammar &&,
LexicalGrammar &&
);
} // namespace generate_code
} // namespace tree_sitter

View file

@ -73,4 +73,4 @@ inline Symbol NONE() {
} // namespace rules
} // namespace tree_sitter
#endif // COMPILER_RULES_SYMBOL_H_
#endif // COMPILER_RULES_SYMBOL_H_

View file

@ -358,9 +358,6 @@ static Tree *parser__lex(Parser *self, StackVersion version, TSStateId parse_sta
);
ts_lexer_start(&self->lexer);
if (self->language->lex_fn(&self->lexer.data, lex_mode.lex_state)) {
if (length_is_undefined(self->lexer.token_end_position)) {
self->lexer.token_end_position = self->lexer.current_position;
}
break;
}
@ -398,23 +395,39 @@ static Tree *parser__lex(Parser *self, StackVersion version, TSStateId parse_sta
error_end_position = self->lexer.current_position;
}
if (self->lexer.current_position.bytes > last_byte_scanned) {
last_byte_scanned = self->lexer.current_position.bytes;
}
Tree *result;
if (skipped_error) {
Length padding = length_sub(error_start_position, start_position);
Length size = length_sub(error_end_position, error_start_position);
result = ts_tree_make_error(&self->tree_pool, size, padding, first_error_character, self->language);
} else {
TSSymbol symbol = self->lexer.data.result_symbol;
if (found_external_token) {
symbol = self->language->external_scanner.symbol_map[symbol];
}
if (self->lexer.token_end_position.bytes < self->lexer.token_start_position.bytes) {
self->lexer.token_start_position = self->lexer.token_end_position;
}
TSSymbol symbol = self->lexer.data.result_symbol;
Length padding = length_sub(self->lexer.token_start_position, start_position);
Length size = length_sub(self->lexer.token_end_position, self->lexer.token_start_position);
if (found_external_token) {
symbol = self->language->external_scanner.symbol_map[symbol];
} else if (symbol == self->language->keyword_capture_token && symbol != 0) {
uint32_t end_byte = self->lexer.token_end_position.bytes;
ts_lexer_reset(&self->lexer, self->lexer.token_start_position);
ts_lexer_start(&self->lexer);
if (
self->language->keyword_lex_fn(&self->lexer.data, 0) &&
self->lexer.token_end_position.bytes == end_byte &&
ts_language_has_actions(self->language, parse_state, self->lexer.data.result_symbol)
) {
symbol = self->lexer.data.result_symbol;
}
}
result = ts_tree_make_leaf(&self->tree_pool, symbol, padding, size, self->language);
if (found_external_token) {
@ -427,9 +440,6 @@ static Tree *parser__lex(Parser *self, StackVersion version, TSStateId parse_sta
}
}
if (self->lexer.current_position.bytes > last_byte_scanned) {
last_byte_scanned = self->lexer.current_position.bytes;
}
result->bytes_scanned = last_byte_scanned - start_position.bytes + 1;
result->parse_state = parse_state;
result->first_leaf.lex_mode = lex_mode;
@ -466,7 +476,9 @@ static bool parser__can_reuse_first_leaf(Parser *self, TSStateId state, Tree *tr
TSLexMode current_lex_mode = self->language->lex_modes[state];
return
(tree->first_leaf.lex_mode.lex_state == current_lex_mode.lex_state &&
tree->first_leaf.lex_mode.external_lex_state == current_lex_mode.external_lex_state) ||
tree->first_leaf.lex_mode.external_lex_state == current_lex_mode.external_lex_state &&
(tree->first_leaf.symbol != self->language->keyword_capture_token ||
tree->parse_state == state)) ||
(current_lex_mode.external_lex_state == 0 &&
tree->size.bytes > 0 &&
table_entry->is_reusable &&