diff --git a/include/tree_sitter/parser.h b/include/tree_sitter/parser.h index a335dd6d..e099fd7f 100644 --- a/include/tree_sitter/parser.h +++ b/include/tree_sitter/parser.h @@ -65,6 +65,7 @@ typedef union { typedef struct TSLanguage { uint32_t symbol_count; uint32_t token_count; + uint32_t external_token_count; const char **symbol_names; const TSSymbolMetadata *symbol_metadata; const unsigned short *parse_table; @@ -75,7 +76,7 @@ typedef struct TSLanguage { const bool *external_token_lists; struct { void * (*create)(); - bool (*scan)(TSLexer *, const bool *symbol_whitelist); + bool (*scan)(void *, TSLexer *, const bool *symbol_whitelist); void (*destroy)(void *); } external_scanner; } TSLanguage; @@ -158,7 +159,6 @@ typedef struct TSLanguage { { .type = TSParseActionTypeAccept } \ } - #define GET_LANGUAGE(...) \ static TSLanguage language = { \ .symbol_count = SYMBOL_COUNT, \ @@ -169,6 +169,7 @@ typedef struct TSLanguage { .lex_modes = ts_lex_modes, \ .symbol_names = ts_symbol_names, \ .lex_fn = ts_lex, \ + .external_token_count = EXTERNAL_TOKEN_COUNT, \ .external_token_lists = (const bool *)ts_external_token_lists, \ .external_token_symbol_map = ts_external_token_symbol_map, \ .external_scanner = {__VA_ARGS__} \ diff --git a/spec/compiler/prepare_grammar/extract_tokens_spec.cc b/spec/compiler/prepare_grammar/extract_tokens_spec.cc index 577dead1..613d31cc 100644 --- a/spec/compiler/prepare_grammar/extract_tokens_spec.cc +++ b/spec/compiler/prepare_grammar/extract_tokens_spec.cc @@ -5,6 +5,7 @@ #include "compiler/prepare_grammar/extract_tokens.h" #include "helpers/rule_helpers.h" #include "helpers/equals_pointer.h" +#include "helpers/stream_methods.h" START_TEST @@ -211,9 +212,10 @@ describe("extract_tokens", []() { }, { choice({ i_sym(1), blank() }) }, {}}); AssertThat(get<2>(result), !Equals(CompileError::none())); - AssertThat(get<2>(result), Equals( - CompileError(TSCompileErrorTypeInvalidUbiquitousToken, - "Not a token: (choice (sym 1) (blank))"))); + AssertThat(get<2>(result), Equals(CompileError( + TSCompileErrorTypeInvalidUbiquitousToken, + "Not a token: (choice (non-terminal 1) (blank))" + ))); }); }); }); diff --git a/spec/fixtures/external_scanners/external_scan.c b/spec/fixtures/external_scanners/external_scan.c index 7abab3ae..41ef3706 100644 --- a/spec/fixtures/external_scanners/external_scan.c +++ b/spec/fixtures/external_scanners/external_scan.c @@ -1,13 +1,108 @@ #include +#include + +enum { + percent_string, + percent_string_start, + percent_string_end +}; + +typedef struct { + int32_t open_delimiter; + int32_t close_delimiter; + uint32_t depth; +} Scanner; void *ts_language_external_scanner_example_external_scanner_create() { - puts("HELLO FROM EXTERNAL SCANNER"); - return 0; + Scanner *scanner = malloc(sizeof(Scanner)); + *scanner = (Scanner){ + .open_delimiter = 0, + .close_delimiter = 0, + .depth = 0 + }; + return scanner; } -bool ts_language_external_scanner_example_external_scanner_scan() { - return true; +bool ts_language_external_scanner_example_external_scanner_scan( + void *payload, TSLexer *lexer, const bool *whitelist) { + Scanner *scanner = payload; + + if (whitelist[percent_string]) { + while (lexer->lookahead == ' ' || + lexer->lookahead == '\t' || + lexer->lookahead == '\n') { + lexer->advance(lexer, 0, true); + } + + if (lexer->lookahead != '%') return false; + lexer->advance(lexer, 0, false); + + switch (lexer->lookahead) { + case '(': + scanner->open_delimiter = '('; + scanner->close_delimiter = ')'; + scanner->depth = 1; + break; + case '[': + scanner->open_delimiter = '['; + scanner->close_delimiter = ']'; + scanner->depth = 1; + break; + case '{': + scanner->open_delimiter = '{'; + scanner->close_delimiter = '}'; + scanner->depth = 1; + break; + default: + return false; + } + + lexer->advance(lexer, 0, false); + + for (;;) { + if (scanner->depth == 0) { + lexer->result_symbol = percent_string; + return true; + } + + if (lexer->lookahead == scanner->open_delimiter) { + scanner->depth++; + } else if (lexer->lookahead == scanner->close_delimiter) { + scanner->depth--; + } else if (lexer->lookahead == '#') { + lexer->advance(lexer, 0, false); + if (lexer->lookahead == '{') { + lexer->advance(lexer, 0, false); + lexer->result_symbol = percent_string_start; + return true; + } + } + + lexer->advance(lexer, 0, false); + } + } else if (whitelist[percent_string_end]) { + if (lexer->lookahead != '}') return false; + lexer->advance(lexer, 0, false); + + for (;;) { + if (scanner->depth == 0) { + lexer->result_symbol = percent_string_end; + return true; + } + + if (lexer->lookahead == scanner->open_delimiter) { + scanner->depth++; + } else if (lexer->lookahead == scanner->close_delimiter) { + scanner->depth--; + } + + lexer->advance(lexer, 0, false); + } + } + + return false; } -void ts_language_external_scanner_example_external_scanner_destroy() { +void ts_language_external_scanner_example_external_scanner_destroy(void *payload) { + free(payload); } diff --git a/spec/integration/compile_grammar_spec.cc b/spec/integration/compile_grammar_spec.cc index 21307c89..934b428c 100644 --- a/spec/integration/compile_grammar_spec.cc +++ b/spec/integration/compile_grammar_spec.cc @@ -562,7 +562,11 @@ describe("compile_grammar", []() { "spec/fixtures/external_scanners/external_scan.c" )); - ts_document_set_input_string(document, "%|hi|"); + ts_document_set_input_string(document, "%(sup (external) scanner?)"); + ts_document_parse(document); + assert_root_node("(string)"); + + ts_document_set_input_string(document, "%{sup {} external {} scanner?}"); ts_document_parse(document); assert_root_node("(string)"); diff --git a/src/compiler/generate_code/c_code.cc b/src/compiler/generate_code/c_code.cc index a5a9c17a..36bd7cab 100644 --- a/src/compiler/generate_code/c_code.cc +++ b/src/compiler/generate_code/c_code.cc @@ -128,7 +128,7 @@ class CCodeGenerator { void add_stats() { line("#define STATE_COUNT " + to_string(parse_table.states.size())); line("#define SYMBOL_COUNT " + to_string(parse_table.symbols.size())); - line("#define TOKEN_COUNT " + to_string(lexical_grammar.variables.size() + 1)); + line("#define TOKEN_COUNT " + to_string(lexical_grammar.variables.size() + 1 + syntax_grammar.external_tokens.size())); line("#define EXTERNAL_TOKEN_COUNT " + to_string(syntax_grammar.external_tokens.size())); line(); } @@ -327,7 +327,7 @@ class CCodeGenerator { string external_scanner_name = "ts_language_" + name + "_external_scanner"; line("void *" + external_scanner_name + "_create();"); - line("bool " + external_scanner_name + "_scan();"); + line("bool " + external_scanner_name + "_scan(void *, TSLexer *, const bool *);"); line("void " + external_scanner_name + "_destroy();"); line(); diff --git a/src/compiler/rules/symbol.h b/src/compiler/rules/symbol.h index 46272dc5..4aacf1b2 100644 --- a/src/compiler/rules/symbol.h +++ b/src/compiler/rules/symbol.h @@ -13,8 +13,8 @@ class Symbol : public Rule { typedef enum { Terminal, - NonTerminal, External, + NonTerminal, } Type; Symbol(Index index, Type type); diff --git a/src/runtime/document.c b/src/runtime/document.c index 65f9e435..c68d8c62 100644 --- a/src/runtime/document.c +++ b/src/runtime/document.c @@ -37,7 +37,7 @@ const TSLanguage *ts_document_language(TSDocument *self) { void ts_document_set_language(TSDocument *self, const TSLanguage *language) { ts_document_invalidate(self); - self->parser.language = language; + parser_set_language(&self->parser, language); if (self->tree) { ts_tree_release(self->tree); self->tree = NULL; diff --git a/src/runtime/language.h b/src/runtime/language.h index a4f44b11..5a2693db 100644 --- a/src/runtime/language.h +++ b/src/runtime/language.h @@ -49,6 +49,12 @@ static inline TSStateId ts_language_next_state(const TSLanguage *self, } } +static inline const bool * +ts_language_enabled_external_tokens(const TSLanguage *self, + unsigned external_scanner_state) { + return self->external_token_lists + self->external_token_count * external_scanner_state; +} + #ifdef __cplusplus } #endif diff --git a/src/runtime/lexer.c b/src/runtime/lexer.c index 32910935..77d76ec6 100644 --- a/src/runtime/lexer.c +++ b/src/runtime/lexer.c @@ -123,9 +123,7 @@ void ts_lexer_reset(Lexer *self, Length position) { return; } -void ts_lexer_start(Lexer *self, TSStateId lex_state) { - LOG("start_lex state:%d, pos:%u", lex_state, self->current_position.chars); - +void ts_lexer_start(Lexer *self) { self->token_start_position = self->current_position; self->data.result_symbol = 0; diff --git a/src/runtime/lexer.h b/src/runtime/lexer.h index 1b047e5b..682c3f93 100644 --- a/src/runtime/lexer.h +++ b/src/runtime/lexer.h @@ -30,7 +30,7 @@ typedef struct { void ts_lexer_init(Lexer *); void ts_lexer_set_input(Lexer *, TSInput); void ts_lexer_reset(Lexer *, Length); -void ts_lexer_start(Lexer *, TSStateId); +void ts_lexer_start(Lexer *); #ifdef __cplusplus } diff --git a/src/runtime/parser.c b/src/runtime/parser.c index c37b7871..997103c8 100644 --- a/src/runtime/parser.c +++ b/src/runtime/parser.c @@ -209,23 +209,43 @@ static bool parser__condense_stack(Parser *self) { } static Tree *parser__lex(Parser *self, TSStateId parse_state) { + Length start_position = self->lexer.current_position; + ts_lexer_start(&self->lexer); + + TSLexMode lex_mode = self->language->lex_modes[parse_state]; + if (lex_mode.external_tokens) { + const bool *external_tokens = ts_language_enabled_external_tokens(self->language, lex_mode.external_tokens); + if (self->language->external_scanner.scan( + self->external_scanner_payload, + &self->lexer.data, + external_tokens + )) { + TSSymbol symbol = self->language->external_token_symbol_map[self->lexer.data.result_symbol]; + Length padding = length_sub(self->lexer.token_start_position, start_position); + Length size = length_sub(self->lexer.current_position, self->lexer.token_start_position); + TSSymbolMetadata metadata = ts_language_symbol_metadata(self->language, symbol); + Tree *result = ts_tree_make_leaf(symbol, padding, size, metadata); + result->parse_state = parse_state; + return result; + } else { + ts_lexer_reset(&self->lexer, start_position); + } + } + TSStateId start_state = self->language->lex_modes[parse_state].lex_state; TSStateId current_state = start_state; - Length start_position = self->lexer.current_position; LOG("lex state:%d", start_state); bool skipped_error = false; int32_t first_error_character = 0; Length error_start_position, error_end_position; - ts_lexer_start(&self->lexer, start_state); - while (!self->language->lex_fn(&self->lexer.data, current_state)) { if (current_state != ERROR_STATE) { LOG("retry_in_error_mode"); current_state = ERROR_STATE; ts_lexer_reset(&self->lexer, start_position); - ts_lexer_start(&self->lexer, current_state); + ts_lexer_start(&self->lexer); continue; } @@ -247,7 +267,6 @@ static Tree *parser__lex(Parser *self, TSStateId parse_state) { } Tree *result; - if (skipped_error) { Length padding = length_sub(error_start_position, start_position); Length size = length_sub(error_end_position, error_start_position); @@ -255,18 +274,12 @@ static Tree *parser__lex(Parser *self, TSStateId parse_state) { result = ts_tree_make_error(size, padding, first_error_character); } else { TSSymbol symbol = self->lexer.data.result_symbol; - Length padding = - length_sub(self->lexer.token_start_position, start_position); - Length size = length_sub(self->lexer.current_position, - self->lexer.token_start_position); - result = - ts_tree_make_leaf(symbol, padding, size, - ts_language_symbol_metadata(self->language, symbol)); + Length padding = length_sub(self->lexer.token_start_position, start_position); + Length size = length_sub(self->lexer.current_position, self->lexer.token_start_position); + TSSymbolMetadata metadata = ts_language_symbol_metadata(self->language, symbol); + result = ts_tree_make_leaf(symbol, padding, size, metadata); } - if (!result) - return NULL; - result->parse_state = parse_state; result->first_leaf.lex_state = start_state; return result; @@ -1106,6 +1119,15 @@ bool parser_init(Parser *self) { return true; } +void parser_set_language(Parser *self, const TSLanguage *language) { + self->language = language; + if (language->external_scanner.create) { + self->external_scanner_payload = language->external_scanner.create(); + } else { + self->external_scanner_payload = NULL; + } +} + void parser_destroy(Parser *self) { if (self->stack) ts_stack_delete(self->stack); diff --git a/src/runtime/parser.h b/src/runtime/parser.h index 41512e12..54c041b3 100644 --- a/src/runtime/parser.h +++ b/src/runtime/parser.h @@ -29,11 +29,13 @@ typedef struct { ReusableNode reusable_node; TreePath tree_path1; TreePath tree_path2; + void *external_scanner_payload; } Parser; bool parser_init(Parser *); void parser_destroy(Parser *); Tree *parser_parse(Parser *, TSInput, Tree *); +void parser_set_language(Parser *, const TSLanguage *); #ifdef __cplusplus }