diff --git a/include/tree_sitter/parser.h b/include/tree_sitter/parser.h index 9e1860ba..603015ac 100644 --- a/include/tree_sitter/parser.h +++ b/include/tree_sitter/parser.h @@ -13,8 +13,6 @@ extern "C" { #define ts_parse_state_error ((TSStateId)-1) #define TS_DEBUG_BUFFER_SIZE 512 -typedef struct TSTree TSTree; - typedef unsigned short TSStateId; typedef struct { @@ -31,23 +29,31 @@ typedef struct { bool structural : 1; } TSSymbolMetadata; +typedef enum { + TSTransitionTypeMain, + TSTransitionTypeSeparator, + TSTransitionTypeError, +} TSTransitionType; + typedef struct TSLexer { - void (*start_fn)(struct TSLexer *, TSStateId); - bool (*advance_fn)(struct TSLexer *, TSStateId, bool); - TSTree *(*accept_fn)(struct TSLexer *, TSSymbol, TSSymbolMetadata, - const char *, bool fragile); + void (*advance)(struct TSLexer *, TSStateId, TSTransitionType); + + TSLength current_position; + TSLength token_end_position; + TSLength token_start_position; + TSLength error_end_position; const char *chunk; size_t chunk_start; size_t chunk_size; - TSLength current_position; - TSLength token_end_position; - TSLength token_start_position; - size_t lookahead_size; int32_t lookahead; TSStateId starting_state; + TSSymbol result_symbol; + bool result_is_fragile; + bool result_follows_error; + int32_t first_unexpected_character; TSInput input; TSDebugger debugger; @@ -89,17 +95,16 @@ struct TSLanguage { const TSParseActionEntry *parse_actions; const TSStateId *lex_states; const TSParseAction *recovery_actions; - TSTree *(*lex_fn)(TSLexer *, TSStateId, bool); + void (*lex_fn)(TSLexer *, TSStateId, bool); }; /* * Lexer Macros */ -#define START_LEXER() \ - lexer->start_fn(lexer, state); \ - int32_t lookahead; \ - next_state: \ +#define START_LEXER() \ + int32_t lookahead; \ + next_state: \ lookahead = lexer->lookahead; #define GO_TO_STATE(state_value) \ @@ -108,36 +113,40 @@ struct TSLanguage { goto next_state; \ } -#define ADVANCE(state_value) \ - { \ - lexer->advance_fn(lexer, state_value, true); \ - GO_TO_STATE(state_value); \ +#define ADVANCE(state_value) \ + { \ + lexer->advance(lexer, state_value, TSTransitionTypeMain); \ + GO_TO_STATE(state_value); \ } -#define SKIP(state_value) \ - { \ - lexer->advance_fn(lexer, state_value, false); \ - GO_TO_STATE(state_value); \ +#define SKIP(state_value) \ + { \ + lexer->advance(lexer, state_value, TSTransitionTypeSeparator); \ + GO_TO_STATE(state_value); \ } -#define ACCEPT_FRAGILE_TOKEN(symbol) \ - return lexer->accept_fn(lexer, symbol, ts_symbol_metadata[symbol], \ - ts_symbol_names[symbol], true); - -#define ACCEPT_TOKEN(symbol) \ - return lexer->accept_fn(lexer, symbol, ts_symbol_metadata[symbol], \ - ts_symbol_names[symbol], false); - -#define LEX_ERROR() \ - if (error_mode) { \ - if (state == ts_lex_state_error) \ - lexer->advance_fn(lexer, state, true); \ - GO_TO_STATE(ts_lex_state_error) \ - } else { \ - return lexer->accept_fn(lexer, ts_builtin_sym_error, (TSSymbolMetadata){}, \ - "ERROR", false); \ +#define ACCEPT_FRAGILE_TOKEN(symbol_value) \ + { \ + lexer->result_is_fragile = true; \ + lexer->result_symbol = symbol_value; \ + return; \ } +#define ACCEPT_TOKEN(symbol_value) \ + { \ + lexer->result_symbol = symbol_value; \ + return; \ + } + +#define LEX_ERROR() \ + if (error_mode) { \ + if (state == ts_lex_state_error) \ + lexer->advance(lexer, state, TSTransitionTypeError); \ + } else { \ + error_mode = true; \ + } \ + GO_TO_STATE(ts_lex_state_error) + /* * Parse Table Macros */ diff --git a/spec/runtime/document_spec.cc b/spec/runtime/document_spec.cc index d943c653..fdcb227e 100644 --- a/spec/runtime/document_spec.cc +++ b/spec/runtime/document_spec.cc @@ -162,7 +162,7 @@ describe("Document", [&]() { ts_document_parse(doc); AssertThat(debugger->messages, Contains("lookahead char:'1'")); - AssertThat(debugger->messages, Contains("advance state:1")); + AssertThat(debugger->messages, Contains("accept_token sym:[")); AssertThat(debugger->messages, Contains("accept_token sym:number")); }); diff --git a/spec/runtime/parser_spec.cc b/spec/runtime/parser_spec.cc index 020aa319..e75b1b1f 100644 --- a/spec/runtime/parser_spec.cc +++ b/spec/runtime/parser_spec.cc @@ -88,7 +88,13 @@ describe("Parser", [&]() { ts_document_set_language(doc, get_test_language("json")); }); - describe("when the error occurs at the beginning of a token", [&]() { + auto get_node_text = [&](TSNode node) { + size_t start = ts_node_start_byte(node); + size_t end = ts_node_end_byte(node); + return input->content.substr(start, end - start); + }; + + describe("when there is an invalid substring right before a valid token", [&]() { it("computes the error node's size and position correctly", [&]() { set_text(" [123, @@@@@, true]"); @@ -96,18 +102,24 @@ describe("Parser", [&]() { "(array (number) (ERROR (UNEXPECTED '@')) (true))"); TSNode error = ts_node_named_child(root, 1); - TSNode last = ts_node_named_child(root, 2); - AssertThat(ts_node_name(error, doc), Equals("ERROR")); - AssertThat(ts_node_start_byte(error), Equals(strlen(" [123, "))); - AssertThat(ts_node_end_byte(error), Equals(strlen(" [123, @@@@@,"))); + AssertThat(get_node_text(error), Equals("@@@@@,")); + AssertThat(ts_node_child_count(error), Equals(2)); + AssertThat(ts_node_child_count(error), Equals(2)); - AssertThat(ts_node_name(last, doc), Equals("true")); - AssertThat(ts_node_start_byte(last), Equals(strlen(" [123, @@@@@, "))) + TSNode garbage = ts_node_child(error, 0); + AssertThat(get_node_text(garbage), Equals("@@@@@")); + + TSNode comma = ts_node_child(error, 1); + AssertThat(get_node_text(comma), Equals(",")); + + TSNode node_after_error = ts_node_named_child(root, 2); + AssertThat(ts_node_name(node_after_error, doc), Equals("true")); + AssertThat(get_node_text(node_after_error), Equals("true")); }); }); - describe("when the error occurs in the middle of a token", [&]() { + describe("when there is an unexpected string in the middle of a token", [&]() { it("computes the error node's size and position correctly", [&]() { set_text(" [123, faaaaalse, true]"); @@ -115,54 +127,40 @@ describe("Parser", [&]() { "(array (number) (ERROR (UNEXPECTED 'a')) (true))"); TSNode error = ts_node_named_child(root, 1); - TSNode last = ts_node_named_child(root, 2); - AssertThat(ts_node_symbol(error), Equals(ts_builtin_sym_error)); - AssertThat(ts_node_name(error, doc), Equals("ERROR")); - AssertThat(ts_node_start_byte(error), Equals(strlen(" [123, "))) - AssertThat(ts_node_end_byte(error), Equals(strlen(" [123, faaaaalse,"))) + AssertThat(get_node_text(error), Equals("faaaaalse,")); + AssertThat(ts_node_child_count(error), Equals(2)); + TSNode garbage = ts_node_child(error, 0); + AssertThat(ts_node_name(garbage, doc), Equals("ERROR")); + AssertThat(get_node_text(garbage), Equals("faaaaalse")); + + TSNode comma = ts_node_child(error, 1); + AssertThat(ts_node_name(comma, doc), Equals(",")); + AssertThat(get_node_text(comma), Equals(",")); + + TSNode last = ts_node_named_child(root, 2); AssertThat(ts_node_name(last, doc), Equals("true")); AssertThat(ts_node_start_byte(last), Equals(strlen(" [123, faaaaalse, "))); }); }); - describe("when the error occurs after one or more tokens", [&]() { + describe("when there is one unexpected token between two valid tokens", [&]() { it("computes the error node's size and position correctly", [&]() { set_text(" [123, true false, true]"); assert_root_node( - "(array (number) (ERROR (true) (UNEXPECTED 'f')) (false) (true))"); + "(array (number) (ERROR (true)) (false) (true))"); TSNode error = ts_node_named_child(root, 1); - TSNode last = ts_node_named_child(root, 2); - AssertThat(ts_node_name(error, doc), Equals("ERROR")); - AssertThat(ts_node_start_byte(error), Equals(strlen(" [123, "))); - AssertThat(ts_node_end_byte(error), Equals(strlen(" [123, true "))); + AssertThat(get_node_text(error), Equals("true")); + AssertThat(ts_node_child_count(error), Equals(1)); + TSNode last = ts_node_named_child(root, 2); AssertThat(ts_node_name(last, doc), Equals("false")); - AssertThat(ts_node_start_byte(last), Equals(strlen(" [123, true "))); - }); - }); - - describe("when the error is an empty string", [&]() { - it("computes the error node's size and position correctly", [&]() { - set_text(" [123, , true]"); - - assert_root_node( - "(array (number) (ERROR (UNEXPECTED ',')) (true))"); - - TSNode error = ts_node_named_child(root, 1); - TSNode last = ts_node_named_child(root, 2); - - AssertThat(ts_node_name(error, doc), Equals("ERROR")); - AssertThat(ts_node_start_byte(error), Equals(strlen(" [123, "))); - AssertThat(ts_node_end_byte(error), Equals(strlen(" [123, ,"))) - - AssertThat(ts_node_name(last, doc), Equals("true")); - AssertThat(ts_node_start_byte(last), Equals(strlen(" [123, , "))); + AssertThat(get_node_text(last), Equals("false")); }); }); }); @@ -275,7 +273,7 @@ describe("Parser", [&]() { insert_text(strlen("var x = y"), " *"); assert_root_node( - "(program (var_declaration (identifier) (ERROR (identifier) (UNEXPECTED ';'))))"); + "(program (var_declaration (identifier) (ERROR (identifier))))"); insert_text(strlen("var x = y *"), " z"); @@ -378,7 +376,7 @@ describe("Parser", [&]() { assert_root_node( "(program " - "(expression_statement (number) (ERROR (UNEXPECTED '4') (number))) " + "(expression_statement (number) (ERROR (number))) " "(expression_statement (math_op (number) (number))))"); }); }); diff --git a/src/compiler/generate_code/c_code.cc b/src/compiler/generate_code/c_code.cc index dacc434d..4ab67644 100644 --- a/src/compiler/generate_code/c_code.cc +++ b/src/compiler/generate_code/c_code.cc @@ -192,7 +192,7 @@ class CCodeGenerator { void add_lex_function() { line( - "static TSTree *ts_lex(TSLexer *lexer, TSStateId state, bool error_mode) " + "static void ts_lex(TSLexer *lexer, TSStateId state, bool error_mode) " "{"); indent([&]() { line("START_LEXER();"); diff --git a/src/runtime/language.c b/src/runtime/language.c index e0106893..ef57aa16 100644 --- a/src/runtime/language.c +++ b/src/runtime/language.c @@ -2,11 +2,16 @@ #include "runtime/language.h" #include "runtime/tree.h" +static const TSParseAction ERROR_SHIFT_EXTRA = { + .type = TSParseActionTypeShift, .extra = true, +}; + const TSParseAction *ts_language_actions(const TSLanguage *self, TSStateId state, TSSymbol symbol, size_t *count) { if (state == ts_parse_state_error) { *count = 1; - return &self->recovery_actions[symbol]; + return (symbol == ts_builtin_sym_error) ? &ERROR_SHIFT_EXTRA + : &self->recovery_actions[symbol]; } size_t action_index = 0; diff --git a/src/runtime/lexer.c b/src/runtime/lexer.c index 47331e15..5afb82d8 100644 --- a/src/runtime/lexer.c +++ b/src/runtime/lexer.c @@ -47,24 +47,11 @@ static void ts_lexer__get_lookahead(TSLexer *self) { LOG_LOOKAHEAD(); } -static void ts_lexer__start(TSLexer *self, TSStateId lex_state) { - LOG("start_lex state:%d, pos:%lu", lex_state, self->current_position.chars); - LOG_LOOKAHEAD(); - - self->starting_state = lex_state; - self->token_start_position = self->current_position; - if (!self->chunk) - ts_lexer__get_chunk(self); - if (!self->lookahead_size) - ts_lexer__get_lookahead(self); -} - -static bool ts_lexer__advance(TSLexer *self, TSStateId state, - bool in_main_token) { - LOG("advance state:%d", state); +static void ts_lexer__advance(TSLexer *self, TSStateId state, + TSTransitionType transition_type) { if (self->chunk == empty_chunk) - return false; + return; if (self->lookahead_size) { self->current_position.bytes += self->lookahead_size; @@ -78,53 +65,41 @@ static bool ts_lexer__advance(TSLexer *self, TSStateId state, } } - if (!in_main_token) - self->token_start_position = self->current_position; + switch (transition_type) { + case TSTransitionTypeSeparator: + if (self->result_follows_error) { + LOG("skip_error state:%d", state); + } else { + LOG("skip_separator state:%d", state); + self->token_start_position = self->current_position; + } + break; + case TSTransitionTypeError: + LOG("skip_error state:%d", state); + self->result_follows_error = true; + self->error_end_position = self->current_position; + if (!self->first_unexpected_character) + self->first_unexpected_character = self->lookahead; + break; + default: + LOG("advance state:%d", state); + break; + } if (self->current_position.bytes >= self->chunk_start + self->chunk_size) ts_lexer__get_chunk(self); ts_lexer__get_lookahead(self); - return true; -} - -static TSTree *ts_lexer__accept(TSLexer *self, TSSymbol symbol, - TSSymbolMetadata metadata, - const char *symbol_name, bool fragile) { - TSLength size = - ts_length_sub(self->current_position, self->token_start_position); - TSLength padding = - ts_length_sub(self->token_start_position, self->token_end_position); - self->token_end_position = self->current_position; - - TSTree *result; - if (symbol == ts_builtin_sym_error) { - LOG("error_char"); - result = ts_tree_make_error(size, padding, self->lookahead); - } else { - LOG("accept_token sym:%s", symbol_name); - result = ts_tree_make_leaf(symbol, padding, size, metadata); - } - - if (!result) - return NULL; - - if (fragile) - result->lex_state = self->starting_state; - - return result; } /* - * The lexer's methods are stored as struct fields so that generated parsers - * can call them without needing to be linked against this library. + * The lexer's advance method is stored as a struct field so that generated + * parsers can call it without needing to be linked against this library. */ void ts_lexer_init(TSLexer *self) { *self = (TSLexer){ - .start_fn = ts_lexer__start, - .advance_fn = ts_lexer__advance, - .accept_fn = ts_lexer__accept, + .advance = ts_lexer__advance, .chunk = NULL, .chunk_start = 0, .debugger = ts_debugger_null(), @@ -154,3 +129,40 @@ void ts_lexer_reset(TSLexer *self, TSLength position) { ts_lexer__reset(self, position); return; } + +void ts_lexer_start(TSLexer *self, TSStateId lex_state) { + LOG("start_lex state:%d, pos:%lu", lex_state, self->current_position.chars); + LOG_LOOKAHEAD(); + + self->starting_state = lex_state; + self->token_start_position = self->current_position; + self->result_follows_error = false; + self->result_is_fragile = false; + self->result_symbol = 0; + self->first_unexpected_character = 0; + + if (!self->chunk) + ts_lexer__get_chunk(self); + if (!self->lookahead_size) + ts_lexer__get_lookahead(self); +} + +void ts_lexer_finish(TSLexer *self, TSLexerResult *result) { + result->padding = + ts_length_sub(self->token_start_position, self->token_end_position); + + if (self->result_follows_error) { + result->symbol = ts_builtin_sym_error; + result->size = + ts_length_sub(self->error_end_position, self->token_start_position); + result->first_unexpected_character = self->first_unexpected_character; + result->is_fragile = true; + ts_lexer_reset(self, self->error_end_position); + } else { + result->symbol = self->result_symbol; + result->size = + ts_length_sub(self->current_position, self->token_start_position); + result->is_fragile = self->result_is_fragile; + self->token_end_position = self->current_position; + } +} diff --git a/src/runtime/lexer.h b/src/runtime/lexer.h index 8141375d..97ee3787 100644 --- a/src/runtime/lexer.h +++ b/src/runtime/lexer.h @@ -7,9 +7,19 @@ extern "C" { #include "tree_sitter/parser.h" +typedef struct { + TSSymbol symbol; + TSLength padding; + TSLength size; + bool is_fragile; + int32_t first_unexpected_character; +} TSLexerResult; + void ts_lexer_init(TSLexer *); void ts_lexer_set_input(TSLexer *, TSInput); void ts_lexer_reset(TSLexer *, TSLength); +void ts_lexer_start(TSLexer *, TSStateId); +void ts_lexer_finish(TSLexer *, TSLexerResult *); #ifdef __cplusplus } diff --git a/src/runtime/node.h b/src/runtime/node.h index c9bfd0b5..d6007c94 100644 --- a/src/runtime/node.h +++ b/src/runtime/node.h @@ -2,6 +2,7 @@ #define RUNTIME_NODE_H_ #include "tree_sitter/parser.h" +#include "runtime/tree.h" TSNode ts_node_make(const TSTree *, size_t character, size_t byte, size_t row); diff --git a/src/runtime/parser.c b/src/runtime/parser.c index b25ebd1d..03430387 100644 --- a/src/runtime/parser.c +++ b/src/runtime/parser.c @@ -206,6 +206,32 @@ static bool ts_parser__can_reuse(TSParser *self, StackVersion version, return true; } +static TSTree *ts_parser__lex(TSParser *self, TSStateId state, bool error_mode) { + TSLexerResult lex_result; + ts_lexer_start(&self->lexer, state); + self->language->lex_fn(&self->lexer, state, error_mode); + ts_lexer_finish(&self->lexer, &lex_result); + + TSTree *result; + if (lex_result.symbol == ts_builtin_sym_error) { + result = ts_tree_make_error(lex_result.size, lex_result.padding, + lex_result.first_unexpected_character); + } else { + LOG("accept_token sym:%s", SYM_NAME(lex_result.symbol)); + result = ts_tree_make_leaf( + lex_result.symbol, lex_result.padding, lex_result.size, + ts_language_symbol_metadata(self->language, lex_result.symbol)); + } + + if (!result) + return NULL; + + if (lex_result.is_fragile) + result->lex_state = state; + + return result; +} + static TSTree *ts_parser__get_lookahead(TSParser *self, StackVersion version, ReusableNode *reusable_node) { TSLength position = ts_stack_top_position(self->stack, version); @@ -250,7 +276,7 @@ static TSTree *ts_parser__get_lookahead(TSParser *self, StackVersion version, bool error_mode = parse_state == ts_parse_state_error; TSStateId lex_state = error_mode ? 0 : self->language->lex_states[parse_state]; LOG("lex state:%d", lex_state); - return self->language->lex_fn(&self->lexer, lex_state, error_mode); + return ts_parser__lex(self, lex_state, error_mode); } static bool ts_parser__select_tree(TSParser *self, TSTree *left, TSTree *right) { @@ -666,6 +692,8 @@ error: static bool ts_parser__handle_error(TSParser *self, StackVersion version, TSStateId state, TSTree *lookahead) { + size_t previous_version_count = ts_stack_version_count(self->stack); + bool has_shift_action = false; array_clear(&self->reduce_actions); for (TSSymbol symbol = 0; symbol < self->language->symbol_count; symbol++) { @@ -695,17 +723,19 @@ static bool ts_parser__handle_error(TSParser *self, StackVersion version, false, true); CHECK(reduction.status != ReduceFailed); assert(reduction.status == ReduceSucceeded); - CHECK(ts_parser__shift(self, reduction.slice.version, ts_parse_state_error, - lookahead, false)); + CHECK(ts_stack_push(self->stack, reduction.slice.version, NULL, false, + ts_parse_state_error)); } if (has_shift_action) { - CHECK( - ts_parser__shift(self, version, ts_parse_state_error, lookahead, false)); + CHECK(ts_stack_push(self->stack, version, NULL, false, ts_parse_state_error)); } else { ts_stack_renumber_version(self->stack, reduction.slice.version, version); } + ts_stack_merge_new(self->stack, version, previous_version_count); + assert(ts_stack_version_count(self->stack) == previous_version_count); + return true; error: @@ -793,7 +823,7 @@ static ParseActionResult ts_parser__consume_lookahead(TSParser *self, if (ts_stack_version_count(self->stack) == 1 && !self->finished_tree) { LOG_ACTION("handle_error"); CHECK(ts_parser__handle_error(self, version, state, lookahead)); - return ParseActionUpdated; + break; } else { LOG_ACTION("bail version:%d", version); ts_stack_remove_version(self->stack, version); diff --git a/src/runtime/stack.c b/src/runtime/stack.c index 0d6f27a0..921e477c 100644 --- a/src/runtime/stack.c +++ b/src/runtime/stack.c @@ -487,6 +487,21 @@ void ts_stack_merge_from(Stack *self, StackVersion start_version) { } } +void ts_stack_merge_new(Stack *self, StackVersion reference_version, + StackVersion first_new_version) { + StackNode *reference_node = self->heads.contents[reference_version].node; + for (size_t i = first_new_version; i < self->heads.size; i++) { + StackNode *node = self->heads.contents[i].node; + if (reference_node->state == node->state && + reference_node->position.chars == node->position.chars) { + for (size_t j = 0; j < node->link_count; j++) + stack_node_add_link(reference_node, node->links[j]); + ts_stack_remove_version(self, i); + i--; + } + } +} + void ts_stack_merge(Stack *self) { ts_stack_merge_from(self, 0); } diff --git a/src/runtime/stack.h b/src/runtime/stack.h index 31873e36..6a983710 100644 --- a/src/runtime/stack.h +++ b/src/runtime/stack.h @@ -107,6 +107,9 @@ void ts_stack_merge_from(Stack *, StackVersion); void ts_stack_merge(Stack *); +void ts_stack_merge_new(Stack *, StackVersion base_version, + StackVersion new_version); + void ts_stack_renumber_version(Stack *, StackVersion, StackVersion); StackVersion ts_stack_duplicate_version(Stack *, StackVersion); diff --git a/src/runtime/tree.c b/src/runtime/tree.c index 32347794..f691866d 100644 --- a/src/runtime/tree.c +++ b/src/runtime/tree.c @@ -378,11 +378,6 @@ void ts_tree_edit(TSTree *self, TSInputEdit edit) { } } -void ts_tree_steal_padding(TSTree *self, TSTree *other) { - self->size = ts_length_add(self->size, other->padding); - other->padding = ts_length_zero(); -} - static size_t write_lookahead_to_string(char *string, size_t limit, char lookahead) { switch (lookahead) { diff --git a/src/runtime/tree.h b/src/runtime/tree.h index d48fb1bc..9cf7c15a 100644 --- a/src/runtime/tree.h +++ b/src/runtime/tree.h @@ -13,7 +13,7 @@ extern "C" { extern TSStateId TS_TREE_STATE_INDEPENDENT; extern TSStateId TS_TREE_STATE_ERROR; -struct TSTree { +typedef struct TSTree { struct { struct TSTree *parent; size_t index; @@ -43,7 +43,7 @@ struct TSTree { bool fragile_left : 1; bool fragile_right : 1; bool has_changes : 1; -}; +} TSTree; typedef Array(TSTree *) TreeArray; TreeArray ts_tree_array_copy(TreeArray *); @@ -65,7 +65,6 @@ size_t ts_tree_end_column(const TSTree *self); void ts_tree_set_children(TSTree *, size_t, TSTree **); void ts_tree_assign_parents(TSTree *); void ts_tree_edit(TSTree *, TSInputEdit); -void ts_tree_steal_padding(TSTree *, TSTree *); char *ts_tree_string(const TSTree *, const TSLanguage *, bool include_all); size_t ts_tree_last_error_size(const TSTree *);