From 3c4957e940ed7cb12910f824386e8a2e32873270 Mon Sep 17 00:00:00 2001 From: MrPrezident <2295306+MrPrezident@users.noreply.github.com> Date: Thu, 19 Oct 2023 13:00:07 -0500 Subject: [PATCH] fix "test_point_range_captures not working" Fix for https://github.com/tree-sitter/py-tree-sitter/issues/105 --- cli/src/tests/query_test.rs | 67 +++++++++++++++++++++++++++++++++++++ lib/src/query.c | 15 +++++++-- 2 files changed, 80 insertions(+), 2 deletions(-) diff --git a/cli/src/tests/query_test.rs b/cli/src/tests/query_test.rs index 13e4f8d0..31d7e3f3 100644 --- a/cli/src/tests/query_test.rs +++ b/cli/src/tests/query_test.rs @@ -2105,6 +2105,73 @@ fn test_query_captures_within_byte_range() { }); } +#[test] +fn test_query_cursor_next_capture_with_byte_range() { + allocations::record(|| { + let language = get_language("python"); + let query = Query::new( + language, + "(function_definition name: (identifier) @function) + (attribute attribute: (identifier) @property) + ((identifier) @variable)", + ) + .unwrap(); + + let source = "def func():\n foo.bar.baz()\n"; + // ^ ^ ^ ^ + // byte_pos 0 12 17 27 + // point_pos (0,0) (1,0) (1,5) (1,15) + + let mut parser = Parser::new(); + parser.set_language(language).unwrap(); + let tree = parser.parse(source, None).unwrap(); + + let mut cursor = QueryCursor::new(); + let captures = + cursor + .set_byte_range(12..17) + .captures(&query, tree.root_node(), source.as_bytes()); + + assert_eq!( + collect_captures(captures, &query, source), + &[("variable", "foo"),] + ); + }); +} + +#[test] +fn test_query_cursor_next_capture_with_point_range() { + allocations::record(|| { + let language = get_language("python"); + let query = Query::new( + language, + "(function_definition name: (identifier) @function) + (attribute attribute: (identifier) @property) + ((identifier) @variable)", + ) + .unwrap(); + + let source = "def func():\n foo.bar.baz()\n"; + // ^ ^ ^ ^ + // byte_pos 0 12 17 27 + // point_pos (0,0) (1,0) (1,5) (1,15) + + let mut parser = Parser::new(); + parser.set_language(language).unwrap(); + let tree = parser.parse(source, None).unwrap(); + + let mut cursor = QueryCursor::new(); + let captures = cursor + .set_point_range(Point::new(1, 0)..Point::new(1, 5)) + .captures(&query, tree.root_node(), source.as_bytes()); + + assert_eq!( + collect_captures(captures, &query, source), + &[("variable", "foo"),] + ); + }); +} + #[test] fn test_query_matches_with_unrooted_patterns_intersecting_byte_range() { allocations::record(|| { diff --git a/lib/src/query.c b/lib/src/query.c index 4e623ae7..826bfc67 100644 --- a/lib/src/query.c +++ b/lib/src/query.c @@ -4048,9 +4048,20 @@ bool ts_query_cursor_next_capture( continue; } - // Skip captures that precede the cursor's start byte. TSNode node = captures->contents[state->consumed_capture_count].node; - if (ts_node_end_byte(node) <= self->start_byte) { + + bool node_precedes_range = ( + ts_node_end_byte(node) <= self->start_byte || + point_lte(ts_node_end_point(node), self->start_point) + ); + bool node_follows_range = ( + ts_node_start_byte(node) >= self->end_byte || + point_gte(ts_node_start_point(node), self->end_point) + ); + bool node_outside_of_range = node_precedes_range || node_follows_range; + + // Skip captures that are outside of the cursor's range. + if (node_outside_of_range) { state->consumed_capture_count++; continue; }