diff --git a/cli/src/query.rs b/cli/src/query.rs index 56b86740..704a2c56 100644 --- a/cli/src/query.rs +++ b/cli/src/query.rs @@ -60,6 +60,7 @@ pub fn query_files_at_paths( )?; results.push(CaptureInfo { name: capture_name.to_string(), + position: capture.node.start_position(), }); } } else { @@ -86,6 +87,7 @@ pub fn query_files_at_paths( } results.push(CaptureInfo { name: capture_name.to_string(), + position: capture.node.start_position(), }); } } diff --git a/cli/src/query/assert.rs b/cli/src/query/assert.rs index d5998eaf..1b31c1c0 100644 --- a/cli/src/query/assert.rs +++ b/cli/src/query/assert.rs @@ -1,6 +1,8 @@ +use super::super::error; use super::super::error::Result; use lazy_static::lazy_static; use regex::Regex; +use std::collections::hash_map::HashMap; use std::fs; use tree_sitter::Point; @@ -10,8 +12,10 @@ lazy_static! { static ref METADATA_REGEX: Regex = Regex::new(r#"(\w+): ([^\s,]+), (\d+), (\d+)"#).unwrap(); } +#[derive(Debug, Eq, PartialEq)] pub struct CaptureInfo { pub name: String, + pub position: Point, } #[derive(Debug, Eq, PartialEq)] @@ -44,7 +48,7 @@ impl From> for Assertion { } } -pub fn assert_expected_captures(_captures: Vec, path: String) -> Result<()> { +pub fn assert_expected_captures(captures: Vec, path: String) -> Result<()> { let contents = fs::read_to_string(path)?; let assertions: Vec = METADATA_REGEX @@ -52,9 +56,22 @@ pub fn assert_expected_captures(_captures: Vec, path: String) -> Re .map(|c| Assertion::from(c)) .collect(); - for a in assertions { - println!("a: {:?}", a); - } + let per_position_index: HashMap = + assertions.iter().map(|a| (a.position, a)).collect(); + for capture in &captures { + let oFound = per_position_index.get(&capture.position); + if oFound.is_none() { + continue; + } + let found = oFound.unwrap(); + let joined = format!("{}.{}", found.capture_class, found.capture_type); + if joined != capture.name && capture.name != "name" { + Err(error::Error::new(format!( + "Assertion failed: at {}, found {}, expected {}", + capture.position, capture.name, joined + )))? + } + } Ok(()) } diff --git a/test/fixtures/queries/python.py b/test/fixtures/queries/python.py index c90830a7..a48ed2de 100644 --- a/test/fixtures/queries/python.py +++ b/test/fixtures/queries/python.py @@ -1,7 +1,7 @@ def foo(): pass -# declaration: function: 0, 0 +# definition: function: 0, 0 def bar(): -# declaration: function, 3, 0 +# definition: function, 3, 0 foo() # reference: call, 5, 4