Fix error when set_included_ranges is called with an invalid range list

This commit is contained in:
Max Brunsfeld 2020-01-17 10:31:28 -08:00
parent de8877db35
commit 9f63139a10
8 changed files with 284 additions and 150 deletions

View file

@ -5,7 +5,7 @@ use crate::generate::generate_parser_for_grammar;
use crate::parse::{perform_edit, Edit};
use std::sync::atomic::{AtomicUsize, Ordering};
use std::{thread, time};
use tree_sitter::{InputEdit, LogType, Parser, Point, Range};
use tree_sitter::{IncludedRangesError, InputEdit, LogType, Parser, Point, Range};
#[test]
fn test_parsing_simple_string() {
@ -671,7 +671,9 @@ fn test_parsing_with_one_included_range() {
let script_content_node = html_tree.root_node().child(1).unwrap().child(1).unwrap();
assert_eq!(script_content_node.kind(), "raw_text");
parser.set_included_ranges(&[script_content_node.range()]);
parser
.set_included_ranges(&[script_content_node.range()])
.unwrap();
parser.set_language(get_language("javascript")).unwrap();
let js_tree = parser.parse(source_code, None).unwrap();
@ -711,26 +713,28 @@ fn test_parsing_with_multiple_included_ranges() {
let close_quote_node = template_string_node.child(3).unwrap();
parser.set_language(get_language("html")).unwrap();
parser.set_included_ranges(&[
Range {
start_byte: open_quote_node.end_byte(),
start_point: open_quote_node.end_position(),
end_byte: interpolation_node1.start_byte(),
end_point: interpolation_node1.start_position(),
},
Range {
start_byte: interpolation_node1.end_byte(),
start_point: interpolation_node1.end_position(),
end_byte: interpolation_node2.start_byte(),
end_point: interpolation_node2.start_position(),
},
Range {
start_byte: interpolation_node2.end_byte(),
start_point: interpolation_node2.end_position(),
end_byte: close_quote_node.start_byte(),
end_point: close_quote_node.start_position(),
},
]);
parser
.set_included_ranges(&[
Range {
start_byte: open_quote_node.end_byte(),
start_point: open_quote_node.end_position(),
end_byte: interpolation_node1.start_byte(),
end_point: interpolation_node1.start_position(),
},
Range {
start_byte: interpolation_node1.end_byte(),
start_point: interpolation_node1.end_position(),
end_byte: interpolation_node2.start_byte(),
end_point: interpolation_node2.start_position(),
},
Range {
start_byte: interpolation_node2.end_byte(),
start_point: interpolation_node2.end_position(),
end_byte: close_quote_node.start_byte(),
end_point: close_quote_node.start_position(),
},
])
.unwrap();
let html_tree = parser.parse(source_code, None).unwrap();
assert_eq!(
@ -779,6 +783,47 @@ fn test_parsing_with_multiple_included_ranges() {
);
}
#[test]
fn test_parsing_error_in_invalid_included_ranges() {
let mut parser = Parser::new();
// Ranges are not ordered
let error = parser
.set_included_ranges(&[
Range {
start_byte: 23,
end_byte: 29,
start_point: Point::new(0, 23),
end_point: Point::new(0, 29),
},
Range {
start_byte: 0,
end_byte: 5,
start_point: Point::new(0, 0),
end_point: Point::new(0, 5),
},
Range {
start_byte: 50,
end_byte: 60,
start_point: Point::new(0, 50),
end_point: Point::new(0, 60),
},
])
.unwrap_err();
assert_eq!(error, IncludedRangesError(1));
// Range ends before it starts
let error = parser
.set_included_ranges(&[Range {
start_byte: 10,
end_byte: 5,
start_point: Point::new(0, 10),
end_point: Point::new(0, 5),
}])
.unwrap_err();
assert_eq!(error, IncludedRangesError(0));
}
#[test]
fn test_parsing_utf16_code_with_errors_at_the_end_of_an_included_range() {
let source_code = "<script>a.</script>";
@ -789,12 +834,14 @@ fn test_parsing_utf16_code_with_errors_at_the_end_of_an_included_range() {
let mut parser = Parser::new();
parser.set_language(get_language("javascript")).unwrap();
parser.set_included_ranges(&[Range {
start_byte,
end_byte,
start_point: Point::new(0, start_byte),
end_point: Point::new(0, end_byte),
}]);
parser
.set_included_ranges(&[Range {
start_byte,
end_byte,
start_point: Point::new(0, start_byte),
end_point: Point::new(0, end_byte),
}])
.unwrap();
let tree = parser.parse_utf16(&utf16_source_code, None).unwrap();
assert_eq!(tree.root_node().to_sexp(), "(program (ERROR (identifier)))");
}
@ -809,20 +856,22 @@ fn test_parsing_with_external_scanner_that_uses_included_range_boundaries() {
let mut parser = Parser::new();
parser.set_language(get_language("javascript")).unwrap();
parser.set_included_ranges(&[
Range {
start_byte: range1_start_byte,
end_byte: range1_end_byte,
start_point: Point::new(0, range1_start_byte),
end_point: Point::new(0, range1_end_byte),
},
Range {
start_byte: range2_start_byte,
end_byte: range2_end_byte,
start_point: Point::new(0, range2_start_byte),
end_point: Point::new(0, range2_end_byte),
},
]);
parser
.set_included_ranges(&[
Range {
start_byte: range1_start_byte,
end_byte: range1_end_byte,
start_point: Point::new(0, range1_start_byte),
end_point: Point::new(0, range1_end_byte),
},
Range {
start_byte: range2_start_byte,
end_byte: range2_end_byte,
start_point: Point::new(0, range2_start_byte),
end_point: Point::new(0, range2_end_byte),
},
])
.unwrap();
let tree = parser.parse(source_code, None).unwrap();
let root = tree.root_node();
@ -870,20 +919,22 @@ fn test_parsing_with_a_newly_excluded_range() {
let directive_start = source_code.find("<%=").unwrap();
let directive_end = source_code.find("</span>").unwrap();
let source_code_end = source_code.len();
parser.set_included_ranges(&[
Range {
start_byte: 0,
end_byte: directive_start,
start_point: Point::new(0, 0),
end_point: Point::new(0, directive_start),
},
Range {
start_byte: directive_end,
end_byte: source_code_end,
start_point: Point::new(0, directive_end),
end_point: Point::new(0, source_code_end),
},
]);
parser
.set_included_ranges(&[
Range {
start_byte: 0,
end_byte: directive_start,
start_point: Point::new(0, 0),
end_point: Point::new(0, directive_start),
},
Range {
start_byte: directive_end,
end_byte: source_code_end,
start_point: Point::new(0, directive_end),
end_point: Point::new(0, source_code_end),
},
])
.unwrap();
let tree = parser.parse(&source_code, Some(&first_tree)).unwrap();
assert_eq!(
@ -944,7 +995,7 @@ fn test_parsing_with_a_newly_included_range() {
// Parse only the first code directive as JavaScript
let mut parser = Parser::new();
parser.set_language(get_language("javascript")).unwrap();
parser.set_included_ranges(&ranges[0..1]);
parser.set_included_ranges(&ranges[0..1]).unwrap();
let first_tree = parser.parse(source_code, None).unwrap();
assert_eq!(
first_tree.root_node().to_sexp(),
@ -955,7 +1006,7 @@ fn test_parsing_with_a_newly_included_range() {
);
// Parse both the code directives as JavaScript, using the old tree as a reference.
parser.set_included_ranges(&ranges);
parser.set_included_ranges(&ranges).unwrap();
let tree = parser.parse(&source_code, Some(&first_tree)).unwrap();
assert_eq!(
tree.root_node().to_sexp(),
@ -1011,20 +1062,22 @@ fn test_parsing_with_included_ranges_and_missing_tokens() {
// There's a missing `a` token at the beginning of the code. It must be inserted
// at the beginning of the first included range, not at {0, 0}.
let source_code = "__bc__bc__";
parser.set_included_ranges(&[
Range {
start_byte: 2,
end_byte: 4,
start_point: Point::new(0, 2),
end_point: Point::new(0, 4),
},
Range {
start_byte: 6,
end_byte: 8,
start_point: Point::new(0, 6),
end_point: Point::new(0, 8),
},
]);
parser
.set_included_ranges(&[
Range {
start_byte: 2,
end_byte: 4,
start_point: Point::new(0, 2),
end_point: Point::new(0, 4),
},
Range {
start_byte: 6,
end_byte: 8,
start_point: Point::new(0, 6),
end_point: Point::new(0, 8),
},
])
.unwrap();
let tree = parser.parse(source_code, None).unwrap();
let root = tree.root_node();

View file

@ -140,6 +140,7 @@ impl Highlighter {
end_point: Point::new(usize::MAX, usize::MAX),
}],
)?;
assert_ne!(layers.len(), 0);
let mut result = HighlightIter {
source,
byte_offset: 0,
@ -333,78 +334,85 @@ impl<'a> HighlightIterLayer<'a> {
let mut result = Vec::with_capacity(1);
let mut queue = Vec::new();
loop {
highlighter
.parser
.set_language(config.language)
.map_err(|_| Error::InvalidLanguage)?;
highlighter.parser.set_included_ranges(&ranges);
unsafe { highlighter.parser.set_cancellation_flag(cancellation_flag) };
let tree = highlighter
.parser
.parse(source, None)
.ok_or(Error::Cancelled)?;
unsafe { highlighter.parser.set_cancellation_flag(None) };
let mut cursor = highlighter.cursors.pop().unwrap_or(QueryCursor::new());
if highlighter.parser.set_included_ranges(&ranges).is_ok() {
highlighter
.parser
.set_language(config.language)
.map_err(|_| Error::InvalidLanguage)?;
// Process combined injections.
if let Some(combined_injections_query) = &config.combined_injections_query {
let mut injections_by_pattern_index =
vec![(None, Vec::new(), false); combined_injections_query.pattern_count()];
let matches =
cursor.matches(combined_injections_query, tree.root_node(), |n: Node| {
&source[n.byte_range()]
});
for mat in matches {
let entry = &mut injections_by_pattern_index[mat.pattern_index];
let (language_name, content_node, include_children) =
injection_for_match(config, combined_injections_query, &mat, source);
if language_name.is_some() {
entry.0 = language_name;
unsafe { highlighter.parser.set_cancellation_flag(cancellation_flag) };
let tree = highlighter
.parser
.parse(source, None)
.ok_or(Error::Cancelled)?;
unsafe { highlighter.parser.set_cancellation_flag(None) };
let mut cursor = highlighter.cursors.pop().unwrap_or(QueryCursor::new());
// Process combined injections.
if let Some(combined_injections_query) = &config.combined_injections_query {
let mut injections_by_pattern_index =
vec![(None, Vec::new(), false); combined_injections_query.pattern_count()];
let matches =
cursor.matches(combined_injections_query, tree.root_node(), |n: Node| {
&source[n.byte_range()]
});
for mat in matches {
let entry = &mut injections_by_pattern_index[mat.pattern_index];
let (language_name, content_node, include_children) =
injection_for_match(config, combined_injections_query, &mat, source);
if language_name.is_some() {
entry.0 = language_name;
}
if let Some(content_node) = content_node {
entry.1.push(content_node);
}
entry.2 = include_children;
}
if let Some(content_node) = content_node {
entry.1.push(content_node);
}
entry.2 = include_children;
}
for (lang_name, content_nodes, includes_children) in injections_by_pattern_index {
if let (Some(lang_name), false) = (lang_name, content_nodes.is_empty()) {
if let Some(next_config) = (injection_callback)(lang_name) {
let ranges =
Self::intersect_ranges(&ranges, &content_nodes, includes_children);
if !ranges.is_empty() {
queue.push((next_config, depth + 1, ranges));
for (lang_name, content_nodes, includes_children) in injections_by_pattern_index
{
if let (Some(lang_name), false) = (lang_name, content_nodes.is_empty()) {
if let Some(next_config) = (injection_callback)(lang_name) {
let ranges = Self::intersect_ranges(
&ranges,
&content_nodes,
includes_children,
);
if !ranges.is_empty() {
queue.push((next_config, depth + 1, ranges));
}
}
}
}
}
// The `captures` iterator borrows the `Tree` and the `QueryCursor`, which
// prevents them from being moved. But both of these values are really just
// pointers, so it's actually ok to move them.
let tree_ref = unsafe { mem::transmute::<_, &'static Tree>(&tree) };
let cursor_ref =
unsafe { mem::transmute::<_, &'static mut QueryCursor>(&mut cursor) };
let captures = cursor_ref
.captures(&config.query, tree_ref.root_node(), move |n: Node| {
&source[n.byte_range()]
})
.peekable();
result.push(HighlightIterLayer {
highlight_end_stack: Vec::new(),
scope_stack: vec![LocalScope {
inherits: false,
range: 0..usize::MAX,
local_defs: Vec::new(),
}],
cursor,
depth,
_tree: tree,
captures,
config,
ranges,
});
}
// The `captures` iterator borrows the `Tree` and the `QueryCursor`, which
// prevents them from being moved. But both of these values are really just
// pointers, so it's actually ok to move them.
let tree_ref = unsafe { mem::transmute::<_, &'static Tree>(&tree) };
let cursor_ref = unsafe { mem::transmute::<_, &'static mut QueryCursor>(&mut cursor) };
let captures = cursor_ref
.captures(&config.query, tree_ref.root_node(), move |n: Node| {
&source[n.byte_range()]
})
.peekable();
result.push(HighlightIterLayer {
highlight_end_stack: Vec::new(),
scope_stack: vec![LocalScope {
inherits: false,
range: 0..usize::MAX,
local_defs: Vec::new(),
}],
cursor,
depth,
_tree: tree,
captures,
config,
ranges,
});
if queue.is_empty() {
break;
} else {

View file

@ -167,7 +167,22 @@ extern "C" {
#[doc = " The second and third parameters specify the location and length of an array"]
#[doc = " of ranges. The parser does *not* take ownership of these ranges; it copies"]
#[doc = " the data, so it doesn\'t matter how these ranges are allocated."]
pub fn ts_parser_set_included_ranges(self_: *mut TSParser, ranges: *const TSRange, length: u32);
#[doc = ""]
#[doc = " If `length` is zero, then the entire document will be parsed. Otherwise,"]
#[doc = " the given ranges must be ordered from earliest to latest in the document,"]
#[doc = " and they must not overlap. That is, the following must hold for all"]
#[doc = " `i` < `length - 1`:"]
#[doc = ""]
#[doc = " ranges[i].end_byte <= ranges[i + 1].start_byte"]
#[doc = ""]
#[doc = " If this requirement is not satisfied, the operation will fail, the ranges"]
#[doc = " will not be assigned, and this function will return `false`. On success,"]
#[doc = " this function returns `true`"]
pub fn ts_parser_set_included_ranges(
self_: *mut TSParser,
ranges: *const TSRange,
length: u32,
) -> bool;
}
extern "C" {
#[doc = " Get the ranges of text that the parser will include when parsing."]
@ -659,9 +674,11 @@ extern "C" {
) -> *const ::std::os::raw::c_char;
}
extern "C" {
#[doc = " Disable a certain capture within a query. This prevents the capture"]
#[doc = " from being returned in matches, and also avoids any resource usage"]
#[doc = " associated with recording the capture."]
#[doc = " Disable a certain capture within a query."]
#[doc = ""]
#[doc = " This prevents the capture from being returned in matches, and also avoids"]
#[doc = " any resource usage associated with recording the capture. Currently, there"]
#[doc = " is no way to undo this."]
pub fn ts_query_disable_capture(
arg1: *mut TSQuery,
arg2: *const ::std::os::raw::c_char,
@ -669,9 +686,10 @@ extern "C" {
);
}
extern "C" {
#[doc = " Disable a certain pattern within a query. This prevents the pattern"]
#[doc = " from matching and removes most of the overhead associated with the"]
#[doc = " pattern."]
#[doc = " Disable a certain pattern within a query."]
#[doc = ""]
#[doc = " This prevents the pattern from matching and removes most of the overhead"]
#[doc = " associated with the pattern. Currently, there is no way to undo this."]
pub fn ts_query_disable_pattern(arg1: *mut TSQuery, arg2: u32);
}
extern "C" {

View file

@ -142,6 +142,10 @@ pub struct LanguageError {
version: usize,
}
/// An error that occurred in `Parser::set_included_ranges`.
#[derive(Debug, PartialEq, Eq)]
pub struct IncludedRangesError(pub usize);
/// An error that occurred when trying to create a `Query`.
#[derive(Debug, PartialEq, Eq)]
pub enum QueryError {
@ -508,16 +512,41 @@ impl Parser {
/// allows you to parse only a *portion* of a document but still return a syntax
/// tree whose ranges match up with the document as a whole. You can also pass
/// multiple disjoint ranges.
pub fn set_included_ranges(&mut self, ranges: &[Range]) {
///
/// If `ranges` is empty, then the entire document will be parsed. Otherwise,
/// the given ranges must be ordered from earliest to latest in the document,
/// and they must not overlap. That is, the following must hold for all
/// `i` < `length - 1`:
///
/// ranges[i].end_byte <= ranges[i + 1].start_byte
///
/// If this requirement is not satisfied, method will panic.
pub fn set_included_ranges<'a>(
&mut self,
ranges: &'a [Range],
) -> Result<(), IncludedRangesError> {
let ts_ranges: Vec<ffi::TSRange> =
ranges.iter().cloned().map(|range| range.into()).collect();
unsafe {
let result = unsafe {
ffi::ts_parser_set_included_ranges(
self.0.as_ptr(),
ts_ranges.as_ptr(),
ts_ranges.len() as u32,
)
};
if result {
Ok(())
} else {
let mut prev_end_byte = 0;
for (i, range) in ranges.iter().enumerate() {
if range.start_byte < prev_end_byte || range.end_byte < range.start_byte {
return Err(IncludedRangesError(i));
}
prev_end_byte = range.end_byte;
}
Err(IncludedRangesError(0))
}
}
/// Get the parser's current cancellation flag pointer.

View file

@ -174,8 +174,19 @@ const TSLanguage *ts_parser_language(const TSParser *self);
* The second and third parameters specify the location and length of an array
* of ranges. The parser does *not* take ownership of these ranges; it copies
* the data, so it doesn't matter how these ranges are allocated.
*
* If `length` is zero, then the entire document will be parsed. Otherwise,
* the given ranges must be ordered from earliest to latest in the document,
* and they must not overlap. That is, the following must hold for all
* `i` < `length - 1`:
*
* ranges[i].end_byte <= ranges[i + 1].start_byte
*
* If this requirement is not satisfied, the operation will fail, the ranges
* will not be assigned, and this function will return `false`. On success,
* this function returns `true`
*/
void ts_parser_set_included_ranges(
bool ts_parser_set_included_ranges(
TSParser *self,
const TSRange *ranges,
uint32_t length

View file

@ -355,7 +355,7 @@ void ts_lexer_mark_end(Lexer *self) {
ts_lexer__mark_end(&self->data);
}
void ts_lexer_set_included_ranges(
bool ts_lexer_set_included_ranges(
Lexer *self,
const TSRange *ranges,
uint32_t count
@ -363,6 +363,16 @@ void ts_lexer_set_included_ranges(
if (count == 0 || !ranges) {
ranges = &DEFAULT_RANGE;
count = 1;
} else {
uint32_t previous_byte = 0;
for (unsigned i = 0; i < count; i++) {
const TSRange *range = &ranges[i];
if (
range->start_byte < previous_byte ||
range->end_byte < range->start_byte
) return false;
previous_byte = range->end_byte;
}
}
size_t size = count * sizeof(TSRange);
@ -370,6 +380,7 @@ void ts_lexer_set_included_ranges(
memcpy(self->included_ranges, ranges, size);
self->included_range_count = count;
ts_lexer_goto(self, self->current_position);
return true;
}
TSRange *ts_lexer_included_ranges(const Lexer *self, uint32_t *count) {

View file

@ -38,7 +38,7 @@ void ts_lexer_start(Lexer *);
void ts_lexer_finish(Lexer *, uint32_t *);
void ts_lexer_advance_to_end(Lexer *);
void ts_lexer_mark_end(Lexer *);
void ts_lexer_set_included_ranges(Lexer *self, const TSRange *ranges, uint32_t count);
bool ts_lexer_set_included_ranges(Lexer *self, const TSRange *ranges, uint32_t count);
TSRange *ts_lexer_included_ranges(const Lexer *self, uint32_t *count);
#ifdef __cplusplus

View file

@ -1761,8 +1761,12 @@ void ts_parser_set_timeout_micros(TSParser *self, uint64_t timeout_micros) {
self->timeout_duration = duration_from_micros(timeout_micros);
}
void ts_parser_set_included_ranges(TSParser *self, const TSRange *ranges, uint32_t count) {
ts_lexer_set_included_ranges(&self->lexer, ranges, count);
bool ts_parser_set_included_ranges(
TSParser *self,
const TSRange *ranges,
uint32_t count
) {
return ts_lexer_set_included_ranges(&self->lexer, ranges, count);
}
const TSRange *ts_parser_included_ranges(const TSParser *self, uint32_t *count) {