From b15e90bd265a9899814bf6231690dabb2066fc69 Mon Sep 17 00:00:00 2001 From: Max Brunsfeld Date: Wed, 18 Sep 2019 17:35:47 -0700 Subject: [PATCH] Handle set! predicate function in queries --- cli/src/tests/query_test.rs | 27 +++++++++++++ lib/binding_rust/lib.rs | 64 +++++++++++++++++++++--------- lib/binding_web/binding.js | 47 +++++++++++++++++++--- lib/binding_web/test/query-test.js | 24 ++++++++++- lib/src/query.c | 3 +- 5 files changed, 140 insertions(+), 25 deletions(-) diff --git a/cli/src/tests/query_test.rs b/cli/src/tests/query_test.rs index 9344598d..e3fe6e6e 100644 --- a/cli/src/tests/query_test.rs +++ b/cli/src/tests/query_test.rs @@ -632,6 +632,33 @@ fn test_query_captures_with_text_conditions() { }); } +#[test] +fn test_query_captures_with_set_properties() { + allocations::record(|| { + let language = get_language("javascript"); + + let query = Query::new( + language, + r#" + ((call_expression (identifier) @foo) + (set! name something) + (set! age 24)) + + (property_identifier) @bar"#, + ) + .unwrap(); + + assert_eq!( + query.pattern_properties(0), + &[ + ("name".to_string(), "something".to_string()), + ("age".to_string(), "24".to_string()), + ] + ); + assert_eq!(query.pattern_properties(1), &[]) + }); +} + #[test] fn test_query_captures_with_duplicates() { allocations::record(|| { diff --git a/lib/binding_rust/lib.rs b/lib/binding_rust/lib.rs index 6a7fe88f..7cba9bfb 100644 --- a/lib/binding_rust/lib.rs +++ b/lib/binding_rust/lib.rs @@ -150,6 +150,7 @@ pub struct Query { ptr: NonNull, capture_names: Vec, predicates: Vec>, + properties: Vec>, } pub struct QueryCursor(NonNull); @@ -1002,6 +1003,7 @@ impl Query { ptr: unsafe { NonNull::new_unchecked(ptr) }, capture_names: Vec::with_capacity(capture_count as usize), predicates: Vec::with_capacity(pattern_count), + properties: Vec::with_capacity(pattern_count), }; // Build a vector of strings to store the capture names. @@ -1042,6 +1044,7 @@ impl Query { let type_capture = ffi::TSQueryPredicateStepType_TSQueryPredicateStepTypeCapture; let type_string = ffi::TSQueryPredicateStepType_TSQueryPredicateStepTypeString; + let mut pattern_properties = Vec::new(); let mut pattern_predicates = Vec::new(); for p in predicate_steps.split(|s| s.type_ == type_done) { if p.is_empty() { @@ -1057,7 +1060,7 @@ impl Query { // Build a predicate for each of the known predicate function names. let operator_name = &string_values[p[0].value_id as usize]; - pattern_predicates.push(match operator_name.as_str() { + match operator_name.as_str() { "eq?" => { if p.len() != 3 { return Err(QueryError::Predicate(format!( @@ -1072,17 +1075,14 @@ impl Query { ))); } - if p[2].type_ == type_capture { - Ok(QueryPredicate::CaptureEqCapture( - p[1].value_id, - p[2].value_id, - )) + pattern_predicates.push(if p[2].type_ == type_capture { + QueryPredicate::CaptureEqCapture(p[1].value_id, p[2].value_id) } else { - Ok(QueryPredicate::CaptureEqString( + QueryPredicate::CaptureEqString( p[1].value_id, string_values[p[2].value_id as usize].clone(), - )) - } + ) + }); } "match?" => { @@ -1106,20 +1106,44 @@ impl Query { } let regex = &string_values[p[2].value_id as usize]; - Ok(QueryPredicate::CaptureMatchString( + pattern_predicates.push(QueryPredicate::CaptureMatchString( p[1].value_id, - regex::bytes::Regex::new(regex) - .map_err(|_| QueryError::Predicate(format!("Invalid regex '{}'", regex)))?, - )) + regex::bytes::Regex::new(regex).map_err(|_| { + QueryError::Predicate(format!("Invalid regex '{}'", regex)) + })?, + )); } - _ => Err(QueryError::Predicate(format!( - "Unknown query predicate function {}", - operator_name, - ))), - }?); + "set!" => { + if p.len() != 3 { + return Err(QueryError::Predicate(format!( + "Wrong number of arguments to set! predicate. Expected 2, got {}.", + p.len() - 1 + ))); + } + if p[1].type_ != type_string || p[2].type_ != type_string { + return Err(QueryError::Predicate( + "Argument to set! predicate must be strings.".to_string(), + )); + } + let key = &string_values[p[1].value_id as usize]; + let value = &string_values[p[2].value_id as usize]; + + pattern_properties.push((key.to_string(), value.to_string())); + } + + _ => { + return Err(QueryError::Predicate(format!( + "Unknown query predicate function {}", + operator_name, + ))) + } + } } + result + .properties + .push(pattern_properties.into_boxed_slice()); result.predicates.push(pattern_predicates); } @@ -1146,6 +1170,10 @@ impl Query { pub fn capture_names(&self) -> &[String] { &self.capture_names } + + pub fn pattern_properties(&self, index: usize) -> &[(String, String)] { + &self.properties[index] + } } impl QueryCursor { diff --git a/lib/binding_web/binding.js b/lib/binding_web/binding.js index 20c3fdbe..525132e2 100644 --- a/lib/binding_web/binding.js +++ b/lib/binding_web/binding.js @@ -719,6 +719,7 @@ class Language { stringValues[i] = UTF8ToString(valueAddress, nameLength); } + const properties = new Array(patternCount); const predicates = new Array(patternCount); for (let i = 0; i < patternCount; i++) { const predicatesAddress = C._ts_query_predicates_for_pattern( @@ -729,6 +730,8 @@ class Language { const stepCount = getValue(TRANSFER_BUFFER, 'i32'); predicates[i] = []; + properties[i] = null; + const steps = []; let stepAddress = predicatesAddress; for (let j = 0; j < stepCount; j++) { @@ -741,14 +744,28 @@ class Language { } else if (stepType === PREDICATE_STEP_TYPE_STRING) { steps.push({type: 'string', value: stringValues[stepValueId]}); } else if (steps.length > 0) { - predicates[i].push(buildQueryPredicate(steps)); + const predicate = buildQueryPredicate(steps); + if (typeof predicate === 'function') { + predicates[i].push(predicate); + } else { + if (!properties[i]) properties[i] = {}; + properties[i][predicate.key] = predicate.value; + } steps.length = 0; } } + + Object.freeze(properties[i]); } C._free(sourceAddress); - return new Query(INTERNAL, address, captureNames, predicates); + return new Query( + INTERNAL, + address, + captureNames, + predicates, + Object.freeze(properties) + ); } static load(url) { @@ -784,11 +801,12 @@ class Language { } class Query { - constructor(internal, address, captureNames, predicates) { + constructor(internal, address, captureNames, predicates, properties) { assertInternal(internal); this[0] = address; this.captureNames = captureNames; this.predicates = predicates; + this.patternProperties = properties; } delete() { @@ -826,6 +844,9 @@ class Query { if (this.predicates[pattern].every(p => p(captures))) { result[i] = {pattern, captures}; } + + const properties = this.patternProperties[pattern]; + if (properties) result[i].properties = properties; } C._free(startAddress); @@ -851,6 +872,7 @@ class Query { const startAddress = getValue(TRANSFER_BUFFER + SIZE_OF_INT, 'i32'); const result = []; + const captures = []; let address = startAddress; for (let i = 0; i < count; i++) { const pattern = getValue(address, 'i32'); @@ -860,11 +882,14 @@ class Query { const captureIndex = getValue(address, 'i32'); address += SIZE_OF_INT; - const captures = new Array(captureCount); + captures.length = captureCount address = unmarshalCaptures(this, node.tree, address, captures); if (this.predicates[pattern].every(p => p(captures))) { - result.push(captures[captureIndex]); + const capture = captures[captureIndex]; + const properties = this.patternProperties[pattern]; + if (properties) capture.properties = properties; + result.push(capture); } } @@ -927,6 +952,18 @@ function buildQueryPredicate(steps) { return false; } + case 'set!': + if (steps.length !== 3) throw new Error( + `Wrong number of arguments to \`set!\` predicate. Expected 2, got ${steps.length - 1}.` + ); + if (steps[1].type !== 'string' || steps[2].type !== 'string') throw new Error( + `Arguments to \`set!\` predicate must be a strings.".` + ); + return { + key: steps[1].value, + value: steps[2].value, + }; + default: throw new Error(`Unknown query predicate \`${steps[0].value}\``); } diff --git a/lib/binding_web/test/query-test.js b/lib/binding_web/test/query-test.js index b927ceff..f002374c 100644 --- a/lib/binding_web/test/query-test.js +++ b/lib/binding_web/test/query-test.js @@ -197,6 +197,23 @@ describe("Query", () => { ] ); }); + + it('handles patterns with properties', () => { + tree = parser.parse(`a(b.c);`); + query = JavaScript.query(` + ((call_expression (identifier) @func) + (set! foo bar) + (set! baz quux)) + + (property_identifier) @prop + `); + + const captures = query.captures(tree.rootNode); + assert.deepEqual(formatCaptures(captures), [ + {name: 'func', text: 'a', properties: {foo: 'bar', baz: 'quux'}}, + {name: 'prop', text: 'c'}, + ]); + }); }); }); @@ -208,5 +225,10 @@ function formatMatches(matches) { } function formatCaptures(captures) { - return captures.map(({name, node}) => ({ name, text: node.text })) + return captures.map(c => { + const node = c.node; + delete c.node; + c.text = node.text; + return c; + }) } diff --git a/lib/src/query.c b/lib/src/query.c index 6b3c7ddf..8e71052d 100644 --- a/lib/src/query.c +++ b/lib/src/query.c @@ -196,7 +196,7 @@ static void stream_skip_whitespace(Stream *stream) { } static bool stream_is_ident_start(Stream *stream) { - return iswalpha(stream->next) || stream->next == '_' || stream->next == '-'; + return iswalnum(stream->next) || stream->next == '_' || stream->next == '-'; } static void stream_scan_identifier(Stream *stream) { @@ -417,6 +417,7 @@ static TSQueryError ts_query_parse_predicate( for (;;) { if (stream->next == ')') { stream_advance(stream); + stream_skip_whitespace(stream); array_back(&self->predicates_by_pattern)->length++; array_push(&self->predicate_steps, ((TSQueryPredicateStep) { .type = TSQueryPredicateStepTypeDone,