Handle set! predicate function in queries

This commit is contained in:
Max Brunsfeld 2019-09-18 17:35:47 -07:00
parent ff9a2c1f53
commit b15e90bd26
5 changed files with 140 additions and 25 deletions

View file

@ -632,6 +632,33 @@ fn test_query_captures_with_text_conditions() {
});
}
#[test]
fn test_query_captures_with_set_properties() {
allocations::record(|| {
let language = get_language("javascript");
let query = Query::new(
language,
r#"
((call_expression (identifier) @foo)
(set! name something)
(set! age 24))
(property_identifier) @bar"#,
)
.unwrap();
assert_eq!(
query.pattern_properties(0),
&[
("name".to_string(), "something".to_string()),
("age".to_string(), "24".to_string()),
]
);
assert_eq!(query.pattern_properties(1), &[])
});
}
#[test]
fn test_query_captures_with_duplicates() {
allocations::record(|| {

View file

@ -150,6 +150,7 @@ pub struct Query {
ptr: NonNull<ffi::TSQuery>,
capture_names: Vec<String>,
predicates: Vec<Vec<QueryPredicate>>,
properties: Vec<Box<[(String, String)]>>,
}
pub struct QueryCursor(NonNull<ffi::TSQueryCursor>);
@ -1002,6 +1003,7 @@ impl Query {
ptr: unsafe { NonNull::new_unchecked(ptr) },
capture_names: Vec::with_capacity(capture_count as usize),
predicates: Vec::with_capacity(pattern_count),
properties: Vec::with_capacity(pattern_count),
};
// Build a vector of strings to store the capture names.
@ -1042,6 +1044,7 @@ impl Query {
let type_capture = ffi::TSQueryPredicateStepType_TSQueryPredicateStepTypeCapture;
let type_string = ffi::TSQueryPredicateStepType_TSQueryPredicateStepTypeString;
let mut pattern_properties = Vec::new();
let mut pattern_predicates = Vec::new();
for p in predicate_steps.split(|s| s.type_ == type_done) {
if p.is_empty() {
@ -1057,7 +1060,7 @@ impl Query {
// Build a predicate for each of the known predicate function names.
let operator_name = &string_values[p[0].value_id as usize];
pattern_predicates.push(match operator_name.as_str() {
match operator_name.as_str() {
"eq?" => {
if p.len() != 3 {
return Err(QueryError::Predicate(format!(
@ -1072,17 +1075,14 @@ impl Query {
)));
}
if p[2].type_ == type_capture {
Ok(QueryPredicate::CaptureEqCapture(
p[1].value_id,
p[2].value_id,
))
pattern_predicates.push(if p[2].type_ == type_capture {
QueryPredicate::CaptureEqCapture(p[1].value_id, p[2].value_id)
} else {
Ok(QueryPredicate::CaptureEqString(
QueryPredicate::CaptureEqString(
p[1].value_id,
string_values[p[2].value_id as usize].clone(),
))
}
)
});
}
"match?" => {
@ -1106,20 +1106,44 @@ impl Query {
}
let regex = &string_values[p[2].value_id as usize];
Ok(QueryPredicate::CaptureMatchString(
pattern_predicates.push(QueryPredicate::CaptureMatchString(
p[1].value_id,
regex::bytes::Regex::new(regex)
.map_err(|_| QueryError::Predicate(format!("Invalid regex '{}'", regex)))?,
))
regex::bytes::Regex::new(regex).map_err(|_| {
QueryError::Predicate(format!("Invalid regex '{}'", regex))
})?,
));
}
_ => Err(QueryError::Predicate(format!(
"Unknown query predicate function {}",
operator_name,
))),
}?);
"set!" => {
if p.len() != 3 {
return Err(QueryError::Predicate(format!(
"Wrong number of arguments to set! predicate. Expected 2, got {}.",
p.len() - 1
)));
}
if p[1].type_ != type_string || p[2].type_ != type_string {
return Err(QueryError::Predicate(
"Argument to set! predicate must be strings.".to_string(),
));
}
let key = &string_values[p[1].value_id as usize];
let value = &string_values[p[2].value_id as usize];
pattern_properties.push((key.to_string(), value.to_string()));
}
_ => {
return Err(QueryError::Predicate(format!(
"Unknown query predicate function {}",
operator_name,
)))
}
}
}
result
.properties
.push(pattern_properties.into_boxed_slice());
result.predicates.push(pattern_predicates);
}
@ -1146,6 +1170,10 @@ impl Query {
pub fn capture_names(&self) -> &[String] {
&self.capture_names
}
pub fn pattern_properties(&self, index: usize) -> &[(String, String)] {
&self.properties[index]
}
}
impl QueryCursor {

View file

@ -719,6 +719,7 @@ class Language {
stringValues[i] = UTF8ToString(valueAddress, nameLength);
}
const properties = new Array(patternCount);
const predicates = new Array(patternCount);
for (let i = 0; i < patternCount; i++) {
const predicatesAddress = C._ts_query_predicates_for_pattern(
@ -729,6 +730,8 @@ class Language {
const stepCount = getValue(TRANSFER_BUFFER, 'i32');
predicates[i] = [];
properties[i] = null;
const steps = [];
let stepAddress = predicatesAddress;
for (let j = 0; j < stepCount; j++) {
@ -741,14 +744,28 @@ class Language {
} else if (stepType === PREDICATE_STEP_TYPE_STRING) {
steps.push({type: 'string', value: stringValues[stepValueId]});
} else if (steps.length > 0) {
predicates[i].push(buildQueryPredicate(steps));
const predicate = buildQueryPredicate(steps);
if (typeof predicate === 'function') {
predicates[i].push(predicate);
} else {
if (!properties[i]) properties[i] = {};
properties[i][predicate.key] = predicate.value;
}
steps.length = 0;
}
}
Object.freeze(properties[i]);
}
C._free(sourceAddress);
return new Query(INTERNAL, address, captureNames, predicates);
return new Query(
INTERNAL,
address,
captureNames,
predicates,
Object.freeze(properties)
);
}
static load(url) {
@ -784,11 +801,12 @@ class Language {
}
class Query {
constructor(internal, address, captureNames, predicates) {
constructor(internal, address, captureNames, predicates, properties) {
assertInternal(internal);
this[0] = address;
this.captureNames = captureNames;
this.predicates = predicates;
this.patternProperties = properties;
}
delete() {
@ -826,6 +844,9 @@ class Query {
if (this.predicates[pattern].every(p => p(captures))) {
result[i] = {pattern, captures};
}
const properties = this.patternProperties[pattern];
if (properties) result[i].properties = properties;
}
C._free(startAddress);
@ -851,6 +872,7 @@ class Query {
const startAddress = getValue(TRANSFER_BUFFER + SIZE_OF_INT, 'i32');
const result = [];
const captures = [];
let address = startAddress;
for (let i = 0; i < count; i++) {
const pattern = getValue(address, 'i32');
@ -860,11 +882,14 @@ class Query {
const captureIndex = getValue(address, 'i32');
address += SIZE_OF_INT;
const captures = new Array(captureCount);
captures.length = captureCount
address = unmarshalCaptures(this, node.tree, address, captures);
if (this.predicates[pattern].every(p => p(captures))) {
result.push(captures[captureIndex]);
const capture = captures[captureIndex];
const properties = this.patternProperties[pattern];
if (properties) capture.properties = properties;
result.push(capture);
}
}
@ -927,6 +952,18 @@ function buildQueryPredicate(steps) {
return false;
}
case 'set!':
if (steps.length !== 3) throw new Error(
`Wrong number of arguments to \`set!\` predicate. Expected 2, got ${steps.length - 1}.`
);
if (steps[1].type !== 'string' || steps[2].type !== 'string') throw new Error(
`Arguments to \`set!\` predicate must be a strings.".`
);
return {
key: steps[1].value,
value: steps[2].value,
};
default:
throw new Error(`Unknown query predicate \`${steps[0].value}\``);
}

View file

@ -197,6 +197,23 @@ describe("Query", () => {
]
);
});
it('handles patterns with properties', () => {
tree = parser.parse(`a(b.c);`);
query = JavaScript.query(`
((call_expression (identifier) @func)
(set! foo bar)
(set! baz quux))
(property_identifier) @prop
`);
const captures = query.captures(tree.rootNode);
assert.deepEqual(formatCaptures(captures), [
{name: 'func', text: 'a', properties: {foo: 'bar', baz: 'quux'}},
{name: 'prop', text: 'c'},
]);
});
});
});
@ -208,5 +225,10 @@ function formatMatches(matches) {
}
function formatCaptures(captures) {
return captures.map(({name, node}) => ({ name, text: node.text }))
return captures.map(c => {
const node = c.node;
delete c.node;
c.text = node.text;
return c;
})
}

View file

@ -196,7 +196,7 @@ static void stream_skip_whitespace(Stream *stream) {
}
static bool stream_is_ident_start(Stream *stream) {
return iswalpha(stream->next) || stream->next == '_' || stream->next == '-';
return iswalnum(stream->next) || stream->next == '_' || stream->next == '-';
}
static void stream_scan_identifier(Stream *stream) {
@ -417,6 +417,7 @@ static TSQueryError ts_query_parse_predicate(
for (;;) {
if (stream->next == ')') {
stream_advance(stream);
stream_skip_whitespace(stream);
array_back(&self->predicates_by_pattern)->length++;
array_push(&self->predicate_steps, ((TSQueryPredicateStep) {
.type = TSQueryPredicateStepTypeDone,