diff --git a/cli/src/tests/parser_test.rs b/cli/src/tests/parser_test.rs index 8a549d48..9a04c70d 100644 --- a/cli/src/tests/parser_test.rs +++ b/cli/src/tests/parser_test.rs @@ -5,7 +5,7 @@ use crate::generate::generate_parser_for_grammar; use crate::parse::{perform_edit, Edit}; use std::sync::atomic::{AtomicUsize, Ordering}; use std::{thread, time}; -use tree_sitter::{InputEdit, LogType, Parser, Point, Range}; +use tree_sitter::{IncludedRangesError, InputEdit, LogType, Parser, Point, Range}; #[test] fn test_parsing_simple_string() { @@ -671,7 +671,9 @@ fn test_parsing_with_one_included_range() { let script_content_node = html_tree.root_node().child(1).unwrap().child(1).unwrap(); assert_eq!(script_content_node.kind(), "raw_text"); - parser.set_included_ranges(&[script_content_node.range()]); + parser + .set_included_ranges(&[script_content_node.range()]) + .unwrap(); parser.set_language(get_language("javascript")).unwrap(); let js_tree = parser.parse(source_code, None).unwrap(); @@ -711,26 +713,28 @@ fn test_parsing_with_multiple_included_ranges() { let close_quote_node = template_string_node.child(3).unwrap(); parser.set_language(get_language("html")).unwrap(); - parser.set_included_ranges(&[ - Range { - start_byte: open_quote_node.end_byte(), - start_point: open_quote_node.end_position(), - end_byte: interpolation_node1.start_byte(), - end_point: interpolation_node1.start_position(), - }, - Range { - start_byte: interpolation_node1.end_byte(), - start_point: interpolation_node1.end_position(), - end_byte: interpolation_node2.start_byte(), - end_point: interpolation_node2.start_position(), - }, - Range { - start_byte: interpolation_node2.end_byte(), - start_point: interpolation_node2.end_position(), - end_byte: close_quote_node.start_byte(), - end_point: close_quote_node.start_position(), - }, - ]); + parser + .set_included_ranges(&[ + Range { + start_byte: open_quote_node.end_byte(), + start_point: open_quote_node.end_position(), + end_byte: interpolation_node1.start_byte(), + end_point: interpolation_node1.start_position(), + }, + Range { + start_byte: interpolation_node1.end_byte(), + start_point: interpolation_node1.end_position(), + end_byte: interpolation_node2.start_byte(), + end_point: interpolation_node2.start_position(), + }, + Range { + start_byte: interpolation_node2.end_byte(), + start_point: interpolation_node2.end_position(), + end_byte: close_quote_node.start_byte(), + end_point: close_quote_node.start_position(), + }, + ]) + .unwrap(); let html_tree = parser.parse(source_code, None).unwrap(); assert_eq!( @@ -779,6 +783,47 @@ fn test_parsing_with_multiple_included_ranges() { ); } +#[test] +fn test_parsing_error_in_invalid_included_ranges() { + let mut parser = Parser::new(); + + // Ranges are not ordered + let error = parser + .set_included_ranges(&[ + Range { + start_byte: 23, + end_byte: 29, + start_point: Point::new(0, 23), + end_point: Point::new(0, 29), + }, + Range { + start_byte: 0, + end_byte: 5, + start_point: Point::new(0, 0), + end_point: Point::new(0, 5), + }, + Range { + start_byte: 50, + end_byte: 60, + start_point: Point::new(0, 50), + end_point: Point::new(0, 60), + }, + ]) + .unwrap_err(); + assert_eq!(error, IncludedRangesError(1)); + + // Range ends before it starts + let error = parser + .set_included_ranges(&[Range { + start_byte: 10, + end_byte: 5, + start_point: Point::new(0, 10), + end_point: Point::new(0, 5), + }]) + .unwrap_err(); + assert_eq!(error, IncludedRangesError(0)); +} + #[test] fn test_parsing_utf16_code_with_errors_at_the_end_of_an_included_range() { let source_code = ""; @@ -789,12 +834,14 @@ fn test_parsing_utf16_code_with_errors_at_the_end_of_an_included_range() { let mut parser = Parser::new(); parser.set_language(get_language("javascript")).unwrap(); - parser.set_included_ranges(&[Range { - start_byte, - end_byte, - start_point: Point::new(0, start_byte), - end_point: Point::new(0, end_byte), - }]); + parser + .set_included_ranges(&[Range { + start_byte, + end_byte, + start_point: Point::new(0, start_byte), + end_point: Point::new(0, end_byte), + }]) + .unwrap(); let tree = parser.parse_utf16(&utf16_source_code, None).unwrap(); assert_eq!(tree.root_node().to_sexp(), "(program (ERROR (identifier)))"); } @@ -809,20 +856,22 @@ fn test_parsing_with_external_scanner_that_uses_included_range_boundaries() { let mut parser = Parser::new(); parser.set_language(get_language("javascript")).unwrap(); - parser.set_included_ranges(&[ - Range { - start_byte: range1_start_byte, - end_byte: range1_end_byte, - start_point: Point::new(0, range1_start_byte), - end_point: Point::new(0, range1_end_byte), - }, - Range { - start_byte: range2_start_byte, - end_byte: range2_end_byte, - start_point: Point::new(0, range2_start_byte), - end_point: Point::new(0, range2_end_byte), - }, - ]); + parser + .set_included_ranges(&[ + Range { + start_byte: range1_start_byte, + end_byte: range1_end_byte, + start_point: Point::new(0, range1_start_byte), + end_point: Point::new(0, range1_end_byte), + }, + Range { + start_byte: range2_start_byte, + end_byte: range2_end_byte, + start_point: Point::new(0, range2_start_byte), + end_point: Point::new(0, range2_end_byte), + }, + ]) + .unwrap(); let tree = parser.parse(source_code, None).unwrap(); let root = tree.root_node(); @@ -870,20 +919,22 @@ fn test_parsing_with_a_newly_excluded_range() { let directive_start = source_code.find("<%=").unwrap(); let directive_end = source_code.find("").unwrap(); let source_code_end = source_code.len(); - parser.set_included_ranges(&[ - Range { - start_byte: 0, - end_byte: directive_start, - start_point: Point::new(0, 0), - end_point: Point::new(0, directive_start), - }, - Range { - start_byte: directive_end, - end_byte: source_code_end, - start_point: Point::new(0, directive_end), - end_point: Point::new(0, source_code_end), - }, - ]); + parser + .set_included_ranges(&[ + Range { + start_byte: 0, + end_byte: directive_start, + start_point: Point::new(0, 0), + end_point: Point::new(0, directive_start), + }, + Range { + start_byte: directive_end, + end_byte: source_code_end, + start_point: Point::new(0, directive_end), + end_point: Point::new(0, source_code_end), + }, + ]) + .unwrap(); let tree = parser.parse(&source_code, Some(&first_tree)).unwrap(); assert_eq!( @@ -944,7 +995,7 @@ fn test_parsing_with_a_newly_included_range() { // Parse only the first code directive as JavaScript let mut parser = Parser::new(); parser.set_language(get_language("javascript")).unwrap(); - parser.set_included_ranges(&ranges[0..1]); + parser.set_included_ranges(&ranges[0..1]).unwrap(); let first_tree = parser.parse(source_code, None).unwrap(); assert_eq!( first_tree.root_node().to_sexp(), @@ -955,7 +1006,7 @@ fn test_parsing_with_a_newly_included_range() { ); // Parse both the code directives as JavaScript, using the old tree as a reference. - parser.set_included_ranges(&ranges); + parser.set_included_ranges(&ranges).unwrap(); let tree = parser.parse(&source_code, Some(&first_tree)).unwrap(); assert_eq!( tree.root_node().to_sexp(), @@ -1011,20 +1062,22 @@ fn test_parsing_with_included_ranges_and_missing_tokens() { // There's a missing `a` token at the beginning of the code. It must be inserted // at the beginning of the first included range, not at {0, 0}. let source_code = "__bc__bc__"; - parser.set_included_ranges(&[ - Range { - start_byte: 2, - end_byte: 4, - start_point: Point::new(0, 2), - end_point: Point::new(0, 4), - }, - Range { - start_byte: 6, - end_byte: 8, - start_point: Point::new(0, 6), - end_point: Point::new(0, 8), - }, - ]); + parser + .set_included_ranges(&[ + Range { + start_byte: 2, + end_byte: 4, + start_point: Point::new(0, 2), + end_point: Point::new(0, 4), + }, + Range { + start_byte: 6, + end_byte: 8, + start_point: Point::new(0, 6), + end_point: Point::new(0, 8), + }, + ]) + .unwrap(); let tree = parser.parse(source_code, None).unwrap(); let root = tree.root_node(); diff --git a/highlight/src/lib.rs b/highlight/src/lib.rs index 7364291e..ef3b2304 100644 --- a/highlight/src/lib.rs +++ b/highlight/src/lib.rs @@ -140,6 +140,7 @@ impl Highlighter { end_point: Point::new(usize::MAX, usize::MAX), }], )?; + assert_ne!(layers.len(), 0); let mut result = HighlightIter { source, byte_offset: 0, @@ -333,78 +334,85 @@ impl<'a> HighlightIterLayer<'a> { let mut result = Vec::with_capacity(1); let mut queue = Vec::new(); loop { - highlighter - .parser - .set_language(config.language) - .map_err(|_| Error::InvalidLanguage)?; - highlighter.parser.set_included_ranges(&ranges); - unsafe { highlighter.parser.set_cancellation_flag(cancellation_flag) }; - let tree = highlighter - .parser - .parse(source, None) - .ok_or(Error::Cancelled)?; - unsafe { highlighter.parser.set_cancellation_flag(None) }; - let mut cursor = highlighter.cursors.pop().unwrap_or(QueryCursor::new()); + if highlighter.parser.set_included_ranges(&ranges).is_ok() { + highlighter + .parser + .set_language(config.language) + .map_err(|_| Error::InvalidLanguage)?; - // Process combined injections. - if let Some(combined_injections_query) = &config.combined_injections_query { - let mut injections_by_pattern_index = - vec![(None, Vec::new(), false); combined_injections_query.pattern_count()]; - let matches = - cursor.matches(combined_injections_query, tree.root_node(), |n: Node| { - &source[n.byte_range()] - }); - for mat in matches { - let entry = &mut injections_by_pattern_index[mat.pattern_index]; - let (language_name, content_node, include_children) = - injection_for_match(config, combined_injections_query, &mat, source); - if language_name.is_some() { - entry.0 = language_name; + unsafe { highlighter.parser.set_cancellation_flag(cancellation_flag) }; + let tree = highlighter + .parser + .parse(source, None) + .ok_or(Error::Cancelled)?; + unsafe { highlighter.parser.set_cancellation_flag(None) }; + let mut cursor = highlighter.cursors.pop().unwrap_or(QueryCursor::new()); + + // Process combined injections. + if let Some(combined_injections_query) = &config.combined_injections_query { + let mut injections_by_pattern_index = + vec![(None, Vec::new(), false); combined_injections_query.pattern_count()]; + let matches = + cursor.matches(combined_injections_query, tree.root_node(), |n: Node| { + &source[n.byte_range()] + }); + for mat in matches { + let entry = &mut injections_by_pattern_index[mat.pattern_index]; + let (language_name, content_node, include_children) = + injection_for_match(config, combined_injections_query, &mat, source); + if language_name.is_some() { + entry.0 = language_name; + } + if let Some(content_node) = content_node { + entry.1.push(content_node); + } + entry.2 = include_children; } - if let Some(content_node) = content_node { - entry.1.push(content_node); - } - entry.2 = include_children; - } - for (lang_name, content_nodes, includes_children) in injections_by_pattern_index { - if let (Some(lang_name), false) = (lang_name, content_nodes.is_empty()) { - if let Some(next_config) = (injection_callback)(lang_name) { - let ranges = - Self::intersect_ranges(&ranges, &content_nodes, includes_children); - if !ranges.is_empty() { - queue.push((next_config, depth + 1, ranges)); + for (lang_name, content_nodes, includes_children) in injections_by_pattern_index + { + if let (Some(lang_name), false) = (lang_name, content_nodes.is_empty()) { + if let Some(next_config) = (injection_callback)(lang_name) { + let ranges = Self::intersect_ranges( + &ranges, + &content_nodes, + includes_children, + ); + if !ranges.is_empty() { + queue.push((next_config, depth + 1, ranges)); + } } } } } + + // The `captures` iterator borrows the `Tree` and the `QueryCursor`, which + // prevents them from being moved. But both of these values are really just + // pointers, so it's actually ok to move them. + let tree_ref = unsafe { mem::transmute::<_, &'static Tree>(&tree) }; + let cursor_ref = + unsafe { mem::transmute::<_, &'static mut QueryCursor>(&mut cursor) }; + let captures = cursor_ref + .captures(&config.query, tree_ref.root_node(), move |n: Node| { + &source[n.byte_range()] + }) + .peekable(); + + result.push(HighlightIterLayer { + highlight_end_stack: Vec::new(), + scope_stack: vec![LocalScope { + inherits: false, + range: 0..usize::MAX, + local_defs: Vec::new(), + }], + cursor, + depth, + _tree: tree, + captures, + config, + ranges, + }); } - // The `captures` iterator borrows the `Tree` and the `QueryCursor`, which - // prevents them from being moved. But both of these values are really just - // pointers, so it's actually ok to move them. - let tree_ref = unsafe { mem::transmute::<_, &'static Tree>(&tree) }; - let cursor_ref = unsafe { mem::transmute::<_, &'static mut QueryCursor>(&mut cursor) }; - let captures = cursor_ref - .captures(&config.query, tree_ref.root_node(), move |n: Node| { - &source[n.byte_range()] - }) - .peekable(); - - result.push(HighlightIterLayer { - highlight_end_stack: Vec::new(), - scope_stack: vec![LocalScope { - inherits: false, - range: 0..usize::MAX, - local_defs: Vec::new(), - }], - cursor, - depth, - _tree: tree, - captures, - config, - ranges, - }); - if queue.is_empty() { break; } else { diff --git a/lib/binding_rust/bindings.rs b/lib/binding_rust/bindings.rs index dfc280fa..75c5bb12 100644 --- a/lib/binding_rust/bindings.rs +++ b/lib/binding_rust/bindings.rs @@ -167,7 +167,22 @@ extern "C" { #[doc = " The second and third parameters specify the location and length of an array"] #[doc = " of ranges. The parser does *not* take ownership of these ranges; it copies"] #[doc = " the data, so it doesn\'t matter how these ranges are allocated."] - pub fn ts_parser_set_included_ranges(self_: *mut TSParser, ranges: *const TSRange, length: u32); + #[doc = ""] + #[doc = " If `length` is zero, then the entire document will be parsed. Otherwise,"] + #[doc = " the given ranges must be ordered from earliest to latest in the document,"] + #[doc = " and they must not overlap. That is, the following must hold for all"] + #[doc = " `i` < `length - 1`:"] + #[doc = ""] + #[doc = " ranges[i].end_byte <= ranges[i + 1].start_byte"] + #[doc = ""] + #[doc = " If this requirement is not satisfied, the operation will fail, the ranges"] + #[doc = " will not be assigned, and this function will return `false`. On success,"] + #[doc = " this function returns `true`"] + pub fn ts_parser_set_included_ranges( + self_: *mut TSParser, + ranges: *const TSRange, + length: u32, + ) -> bool; } extern "C" { #[doc = " Get the ranges of text that the parser will include when parsing."] @@ -659,9 +674,11 @@ extern "C" { ) -> *const ::std::os::raw::c_char; } extern "C" { - #[doc = " Disable a certain capture within a query. This prevents the capture"] - #[doc = " from being returned in matches, and also avoids any resource usage"] - #[doc = " associated with recording the capture."] + #[doc = " Disable a certain capture within a query."] + #[doc = ""] + #[doc = " This prevents the capture from being returned in matches, and also avoids"] + #[doc = " any resource usage associated with recording the capture. Currently, there"] + #[doc = " is no way to undo this."] pub fn ts_query_disable_capture( arg1: *mut TSQuery, arg2: *const ::std::os::raw::c_char, @@ -669,9 +686,10 @@ extern "C" { ); } extern "C" { - #[doc = " Disable a certain pattern within a query. This prevents the pattern"] - #[doc = " from matching and removes most of the overhead associated with the"] - #[doc = " pattern."] + #[doc = " Disable a certain pattern within a query."] + #[doc = ""] + #[doc = " This prevents the pattern from matching and removes most of the overhead"] + #[doc = " associated with the pattern. Currently, there is no way to undo this."] pub fn ts_query_disable_pattern(arg1: *mut TSQuery, arg2: u32); } extern "C" { diff --git a/lib/binding_rust/lib.rs b/lib/binding_rust/lib.rs index 0f308d45..a61169b1 100644 --- a/lib/binding_rust/lib.rs +++ b/lib/binding_rust/lib.rs @@ -142,6 +142,10 @@ pub struct LanguageError { version: usize, } +/// An error that occurred in `Parser::set_included_ranges`. +#[derive(Debug, PartialEq, Eq)] +pub struct IncludedRangesError(pub usize); + /// An error that occurred when trying to create a `Query`. #[derive(Debug, PartialEq, Eq)] pub enum QueryError { @@ -508,16 +512,41 @@ impl Parser { /// allows you to parse only a *portion* of a document but still return a syntax /// tree whose ranges match up with the document as a whole. You can also pass /// multiple disjoint ranges. - pub fn set_included_ranges(&mut self, ranges: &[Range]) { + /// + /// If `ranges` is empty, then the entire document will be parsed. Otherwise, + /// the given ranges must be ordered from earliest to latest in the document, + /// and they must not overlap. That is, the following must hold for all + /// `i` < `length - 1`: + /// + /// ranges[i].end_byte <= ranges[i + 1].start_byte + /// + /// If this requirement is not satisfied, method will panic. + pub fn set_included_ranges<'a>( + &mut self, + ranges: &'a [Range], + ) -> Result<(), IncludedRangesError> { let ts_ranges: Vec = ranges.iter().cloned().map(|range| range.into()).collect(); - unsafe { + let result = unsafe { ffi::ts_parser_set_included_ranges( self.0.as_ptr(), ts_ranges.as_ptr(), ts_ranges.len() as u32, ) }; + + if result { + Ok(()) + } else { + let mut prev_end_byte = 0; + for (i, range) in ranges.iter().enumerate() { + if range.start_byte < prev_end_byte || range.end_byte < range.start_byte { + return Err(IncludedRangesError(i)); + } + prev_end_byte = range.end_byte; + } + Err(IncludedRangesError(0)) + } } /// Get the parser's current cancellation flag pointer. diff --git a/lib/include/tree_sitter/api.h b/lib/include/tree_sitter/api.h index 17b6d0c4..0c7b9804 100644 --- a/lib/include/tree_sitter/api.h +++ b/lib/include/tree_sitter/api.h @@ -174,8 +174,19 @@ const TSLanguage *ts_parser_language(const TSParser *self); * The second and third parameters specify the location and length of an array * of ranges. The parser does *not* take ownership of these ranges; it copies * the data, so it doesn't matter how these ranges are allocated. + * + * If `length` is zero, then the entire document will be parsed. Otherwise, + * the given ranges must be ordered from earliest to latest in the document, + * and they must not overlap. That is, the following must hold for all + * `i` < `length - 1`: + * + * ranges[i].end_byte <= ranges[i + 1].start_byte + * + * If this requirement is not satisfied, the operation will fail, the ranges + * will not be assigned, and this function will return `false`. On success, + * this function returns `true` */ -void ts_parser_set_included_ranges( +bool ts_parser_set_included_ranges( TSParser *self, const TSRange *ranges, uint32_t length diff --git a/lib/src/lexer.c b/lib/src/lexer.c index e2ca8519..3f8a4c0a 100644 --- a/lib/src/lexer.c +++ b/lib/src/lexer.c @@ -355,7 +355,7 @@ void ts_lexer_mark_end(Lexer *self) { ts_lexer__mark_end(&self->data); } -void ts_lexer_set_included_ranges( +bool ts_lexer_set_included_ranges( Lexer *self, const TSRange *ranges, uint32_t count @@ -363,6 +363,16 @@ void ts_lexer_set_included_ranges( if (count == 0 || !ranges) { ranges = &DEFAULT_RANGE; count = 1; + } else { + uint32_t previous_byte = 0; + for (unsigned i = 0; i < count; i++) { + const TSRange *range = &ranges[i]; + if ( + range->start_byte < previous_byte || + range->end_byte < range->start_byte + ) return false; + previous_byte = range->end_byte; + } } size_t size = count * sizeof(TSRange); @@ -370,6 +380,7 @@ void ts_lexer_set_included_ranges( memcpy(self->included_ranges, ranges, size); self->included_range_count = count; ts_lexer_goto(self, self->current_position); + return true; } TSRange *ts_lexer_included_ranges(const Lexer *self, uint32_t *count) { diff --git a/lib/src/lexer.h b/lib/src/lexer.h index 8cd9c267..5e392945 100644 --- a/lib/src/lexer.h +++ b/lib/src/lexer.h @@ -38,7 +38,7 @@ void ts_lexer_start(Lexer *); void ts_lexer_finish(Lexer *, uint32_t *); void ts_lexer_advance_to_end(Lexer *); void ts_lexer_mark_end(Lexer *); -void ts_lexer_set_included_ranges(Lexer *self, const TSRange *ranges, uint32_t count); +bool ts_lexer_set_included_ranges(Lexer *self, const TSRange *ranges, uint32_t count); TSRange *ts_lexer_included_ranges(const Lexer *self, uint32_t *count); #ifdef __cplusplus diff --git a/lib/src/parser.c b/lib/src/parser.c index f381afcc..76bdcbfa 100644 --- a/lib/src/parser.c +++ b/lib/src/parser.c @@ -1761,8 +1761,12 @@ void ts_parser_set_timeout_micros(TSParser *self, uint64_t timeout_micros) { self->timeout_duration = duration_from_micros(timeout_micros); } -void ts_parser_set_included_ranges(TSParser *self, const TSRange *ranges, uint32_t count) { - ts_lexer_set_included_ranges(&self->lexer, ranges, count); +bool ts_parser_set_included_ranges( + TSParser *self, + const TSRange *ranges, + uint32_t count +) { + return ts_lexer_set_included_ranges(&self->lexer, ranges, count); } const TSRange *ts_parser_included_ranges(const TSParser *self, uint32_t *count) {