diff --git a/cli/src/lib.rs b/cli/src/lib.rs index 97c288a1..e00323b7 100644 --- a/cli/src/lib.rs +++ b/cli/src/lib.rs @@ -6,6 +6,7 @@ pub mod loader; pub mod logger; pub mod parse; pub mod query; +pub mod query_testing; pub mod tags; pub mod test; pub mod test_highlight; diff --git a/cli/src/main.rs b/cli/src/main.rs index 1eaa6a75..64ec7253 100644 --- a/cli/src/main.rs +++ b/cli/src/main.rs @@ -94,7 +94,8 @@ fn run() -> error::Result<()> { .takes_value(true), ) .arg(Arg::with_name("scope").long("scope").takes_value(true)) - .arg(Arg::with_name("captures").long("captures").short("c")), + .arg(Arg::with_name("captures").long("captures").short("c")) + .arg(Arg::with_name("test").long("test")), ) .subcommand( SubCommand::with_name("tags") @@ -289,7 +290,15 @@ fn run() -> error::Result<()> { let r: Vec<&str> = br.split(":").collect(); (r[0].parse().unwrap(), r[1].parse().unwrap()) }); - query::query_files_at_paths(language, paths, query_path, ordered_captures, range)?; + let should_test = matches.is_present("test"); + query::query_files_at_paths( + language, + paths, + query_path, + ordered_captures, + range, + should_test, + )?; } else if let Some(matches) = matches.subcommand_matches("tags") { loader.find_all_languages(&config.parser_directories)?; let paths = collect_paths(matches.value_of("paths-file"), matches.values_of("paths"))?; diff --git a/cli/src/query.rs b/cli/src/query.rs index e71e6254..485fdb82 100644 --- a/cli/src/query.rs +++ b/cli/src/query.rs @@ -1,4 +1,5 @@ use super::error::{Error, Result}; +use crate::query_testing; use std::fs; use std::io::{self, Write}; use std::path::Path; @@ -10,6 +11,7 @@ pub fn query_files_at_paths( query_path: &Path, ordered_captures: bool, range: Option<(usize, usize)>, + should_test: bool, ) -> Result<()> { let stdout = io::stdout(); let mut stdout = stdout.lock(); @@ -29,6 +31,8 @@ pub fn query_files_at_paths( parser.set_language(language).map_err(|e| e.to_string())?; for path in paths { + let mut results = Vec::new(); + writeln!(&mut stdout, "{}", path)?; let source_code = fs::read(&path).map_err(Error::wrap(|| { @@ -42,14 +46,20 @@ pub fn query_files_at_paths( query_cursor.captures(&query, tree.root_node(), text_callback) { let capture = mat.captures[capture_index]; + let capture_name = &query.capture_names()[capture.index as usize]; writeln!( &mut stdout, " pattern: {}, capture: {}, row: {}, text: {:?}", mat.pattern_index, - &query.capture_names()[capture.index as usize], + capture_name, capture.node.start_position().row, capture.node.utf8_text(&source_code).unwrap_or("") )?; + results.push(query_testing::CaptureInfo { + name: capture_name.to_string(), + start: capture.node.start_position(), + end: capture.node.end_position(), + }); } } else { for m in query_cursor.matches(&query, tree.root_node(), text_callback) { @@ -57,11 +67,12 @@ pub fn query_files_at_paths( for capture in m.captures { let start = capture.node.start_position(); let end = capture.node.end_position(); + let capture_name = &query.capture_names()[capture.index as usize]; if end.row == start.row { writeln!( &mut stdout, " capture: {}, start: {}, text: {:?}", - &query.capture_names()[capture.index as usize], + capture_name, start, capture.node.utf8_text(&source_code).unwrap_or("") )?; @@ -69,14 +80,20 @@ pub fn query_files_at_paths( writeln!( &mut stdout, " capture: {}, start: {}, end: {}", - &query.capture_names()[capture.index as usize], - start, - end, + capture_name, start, end, )?; } + results.push(query_testing::CaptureInfo { + name: capture_name.to_string(), + start: capture.node.start_position(), + end: capture.node.end_position(), + }); } } } + if should_test { + query_testing::assert_expected_captures(results, path, &mut parser, language)? + } } Ok(()) diff --git a/cli/src/query_testing.rs b/cli/src/query_testing.rs new file mode 100644 index 00000000..ef02ec69 --- /dev/null +++ b/cli/src/query_testing.rs @@ -0,0 +1,150 @@ +use crate::error; +use crate::error::Result; +use lazy_static::lazy_static; +use regex::Regex; +use std::fs; +use tree_sitter::{Language, Parser, Point}; + +lazy_static! { + static ref CAPTURE_NAME_REGEX: Regex = Regex::new("[\\w_\\-.]+").unwrap(); +} + +#[derive(Debug, Eq, PartialEq)] +pub struct CaptureInfo { + pub name: String, + pub start: Point, + pub end: Point, +} + +#[derive(Debug, PartialEq, Eq)] +pub struct Assertion { + pub position: Point, + pub expected_capture_name: String, +} + +/// Parse the given source code, finding all of the comments that contain +/// highlighting assertions. Return a vector of (position, expected highlight name) +/// pairs. +pub fn parse_position_comments( + parser: &mut Parser, + language: Language, + source: &[u8], +) -> Result> { + let mut result = Vec::new(); + let mut assertion_ranges = Vec::new(); + + // Parse the code. + parser.set_included_ranges(&[]).unwrap(); + parser.set_language(language).unwrap(); + let tree = parser.parse(source, None).unwrap(); + + // Walk the tree, finding comment nodes that contain assertions. + let mut ascending = false; + let mut cursor = tree.root_node().walk(); + loop { + if ascending { + let node = cursor.node(); + + // Find every comment node. + if node.kind().contains("comment") { + if let Ok(text) = node.utf8_text(source) { + let mut position = node.start_position(); + if position.row == 0 { + continue; + } + + // Find the arrow character ("^" or '<-") in the comment. A left arrow + // refers to the column where the comment node starts. An up arrow refers + // to its own column. + let mut has_left_caret = false; + let mut has_arrow = false; + let mut arrow_end = 0; + for (i, c) in text.char_indices() { + arrow_end = i + 1; + if c == '-' && has_left_caret { + has_arrow = true; + break; + } + if c == '^' { + has_arrow = true; + position.column += i; + break; + } + has_left_caret = c == '<'; + } + + // If the comment node contains an arrow and a highlight name, record the + // highlight name and the position. + if let (true, Some(mat)) = + (has_arrow, CAPTURE_NAME_REGEX.find(&text[arrow_end..])) + { + assertion_ranges.push((node.start_position(), node.end_position())); + result.push(Assertion { + position: position, + expected_capture_name: mat.as_str().to_string(), + }); + } + } + } + + // Continue walking the tree. + if cursor.goto_next_sibling() { + ascending = false; + } else if !cursor.goto_parent() { + break; + } + } else if !cursor.goto_first_child() { + ascending = true; + } + } + + // Adjust the row number in each assertion's position to refer to the line of + // code *above* the assertion. There can be multiple lines of assertion comments, + // so the positions may have to be decremented by more than one row. + let mut i = 0; + for assertion in result.iter_mut() { + loop { + let on_assertion_line = assertion_ranges[i..] + .iter() + .any(|(start, _)| start.row == assertion.position.row); + if on_assertion_line { + assertion.position.row -= 1; + } else { + while i < assertion_ranges.len() + && assertion_ranges[i].0.row < assertion.position.row + { + i += 1; + } + break; + } + } + } + + // The assertions can end up out of order due to the line adjustments. + result.sort_unstable_by_key(|a| a.position); + + Ok(result) +} + +pub fn assert_expected_captures( + infos: Vec, + path: String, + parser: &mut Parser, + language: Language, +) -> Result<()> { + let contents = fs::read_to_string(path)?; + let pairs = parse_position_comments(parser, language, contents.as_bytes())?; + for info in &infos { + if let Some(found) = pairs.iter().find(|p| { + p.position.row == info.start.row && p.position >= info.start && p.position < info.end + }) { + if found.expected_capture_name != info.name && info.name != "name" { + Err(error::Error::new(format!( + "Assertion failed: at {}, found {}, expected {}", + info.start, found.expected_capture_name, info.name + )))? + } + } + } + Ok(()) +} diff --git a/cli/src/test_highlight.rs b/cli/src/test_highlight.rs index cf163c05..df870bf6 100644 --- a/cli/src/test_highlight.rs +++ b/cli/src/test_highlight.rs @@ -1,17 +1,12 @@ use super::error::Result; use crate::loader::Loader; +use crate::query_testing::{parse_position_comments, Assertion}; use ansi_term::Colour; -use lazy_static::lazy_static; -use regex::Regex; use std::fs; use std::path::Path; -use tree_sitter::{Language, Parser, Point}; +use tree_sitter::Point; use tree_sitter_highlight::{Highlight, HighlightConfiguration, HighlightEvent, Highlighter}; -lazy_static! { - static ref HIGHLIGHT_NAME_REGEX: Regex = Regex::new("[\\w_\\-.]+").unwrap(); -} - pub struct Failure { row: usize, column: usize, @@ -86,23 +81,20 @@ pub fn test_highlights(loader: &Loader, directory: &Path) -> Result<()> { Ok(()) } } - -pub fn test_highlight( - loader: &Loader, - highlighter: &mut Highlighter, - highlight_config: &HighlightConfiguration, - source: &[u8], +pub fn iterate_assertions( + assertions: &Vec, + highlights: &Vec<(Point, Point, Highlight)>, + highlight_names: &Vec, ) -> Result { - // Highlight the file, and parse out all of the highlighting assertions. - let highlight_names = loader.highlight_names(); - let highlights = get_highlight_positions(loader, highlighter, highlight_config, source)?; - let assertions = parse_highlight_test(highlighter.parser(), highlight_config.language, source)?; - // Iterate through all of the highlighting assertions, checking each one against the // actual highlights. let mut i = 0; let mut actual_highlights = Vec::<&String>::new(); - for (position, expected_highlight) in &assertions { + for Assertion { + position, + expected_capture_name: expected_highlight, + } in assertions + { let mut passed = false; actual_highlights.clear(); @@ -156,103 +148,80 @@ pub fn test_highlight( Ok(assertions.len()) } -/// Parse the given source code, finding all of the comments that contain -/// highlighting assertions. Return a vector of (position, expected highlight name) -/// pairs. -pub fn parse_highlight_test( - parser: &mut Parser, - language: Language, +pub fn test_highlight( + loader: &Loader, + highlighter: &mut Highlighter, + highlight_config: &HighlightConfiguration, source: &[u8], -) -> Result> { - let mut result = Vec::new(); - let mut assertion_ranges = Vec::new(); +) -> Result { + // Highlight the file, and parse out all of the highlighting assertions. + let highlight_names = loader.highlight_names(); + let highlights = get_highlight_positions(loader, highlighter, highlight_config, source)?; + let assertions = + parse_position_comments(highlighter.parser(), highlight_config.language, source)?; - // Parse the code. - parser.set_included_ranges(&[]).unwrap(); - parser.set_language(language).unwrap(); - let tree = parser.parse(source, None).unwrap(); + iterate_assertions(&assertions, &highlights, &highlight_names)?; - // Walk the tree, finding comment nodes that contain assertions. - let mut ascending = false; - let mut cursor = tree.root_node().walk(); - loop { - if ascending { - let node = cursor.node(); - - // Find every comment node. - if node.kind().contains("comment") { - if let Ok(text) = node.utf8_text(source) { - let mut position = node.start_position(); - if position.row == 0 { - continue; - } - - // Find the arrow character ("^" or '<-") in the comment. A left arrow - // refers to the column where the comment node starts. An up arrow refers - // to its own column. - let mut has_left_caret = false; - let mut has_arrow = false; - let mut arrow_end = 0; - for (i, c) in text.char_indices() { - arrow_end = i + 1; - if c == '-' && has_left_caret { - has_arrow = true; - break; - } - if c == '^' { - has_arrow = true; - position.column += i; - break; - } - has_left_caret = c == '<'; - } - - // If the comment node contains an arrow and a highlight name, record the - // highlight name and the position. - if let (true, Some(mat)) = - (has_arrow, HIGHLIGHT_NAME_REGEX.find(&text[arrow_end..])) - { - assertion_ranges.push((node.start_position(), node.end_position())); - result.push((position, mat.as_str().to_string())); - } - } - } - - // Continue walking the tree. - if cursor.goto_next_sibling() { - ascending = false; - } else if !cursor.goto_parent() { - break; - } - } else if !cursor.goto_first_child() { - ascending = true; - } - } - - // Adjust the row number in each assertion's position to refer to the line of - // code *above* the assertion. There can be multiple lines of assertion comments, - // so the positions may have to be decremented by more than one row. + // Iterate through all of the highlighting assertions, checking each one against the + // actual highlights. let mut i = 0; - for (position, _) in result.iter_mut() { - loop { - let on_assertion_line = assertion_ranges[i..] - .iter() - .any(|(start, _)| start.row == position.row); - if on_assertion_line { - position.row -= 1; - } else { - while i < assertion_ranges.len() && assertion_ranges[i].0.row < position.row { + let mut actual_highlights = Vec::<&String>::new(); + for Assertion { + position, + expected_capture_name: expected_highlight, + } in &assertions + { + let mut passed = false; + actual_highlights.clear(); + + 'highlight_loop: loop { + // The assertions are ordered by position, so skip past all of the highlights that + // end at or before this assertion's position. + if let Some(highlight) = highlights.get(i) { + if highlight.1 <= *position { i += 1; + continue; } + + // Iterate through all of the highlights that start at or before this assertion's, + // position, looking for one that matches the assertion. + let mut j = i; + while let (false, Some(highlight)) = (passed, highlights.get(j)) { + if highlight.0 > *position { + break 'highlight_loop; + } + + // If the highlight matches the assertion, this test passes. Otherwise, + // add this highlight to the list of actual highlights that span the + // assertion's position, in order to generate an error message in the event + // of a failure. + let highlight_name = &highlight_names[(highlight.2).0]; + if *highlight_name == *expected_highlight { + passed = true; + break 'highlight_loop; + } else { + actual_highlights.push(highlight_name); + } + + j += 1; + } + } else { break; } } + + if !passed { + return Err(Failure { + row: position.row, + column: position.column, + expected_highlight: expected_highlight.clone(), + actual_highlights: actual_highlights.into_iter().cloned().collect(), + } + .into()); + } } - // The assertions can end up out of order due to the line adjustments. - result.sort_unstable_by_key(|a| a.0); - - Ok(result) + Ok(assertions.len()) } pub fn get_highlight_positions( diff --git a/cli/src/tests/test_highlight_test.rs b/cli/src/tests/test_highlight_test.rs index 6a857dd9..1a658281 100644 --- a/cli/src/tests/test_highlight_test.rs +++ b/cli/src/tests/test_highlight_test.rs @@ -1,5 +1,6 @@ use super::helpers::fixtures::{get_highlight_config, get_language, test_loader}; -use crate::test_highlight::{get_highlight_positions, parse_highlight_test}; +use crate::query_testing::{parse_position_comments, Assertion}; +use crate::test_highlight::get_highlight_positions; use tree_sitter::{Parser, Point}; use tree_sitter_highlight::{Highlight, Highlighter}; @@ -25,13 +26,23 @@ fn test_highlight_test_with_basic_test() { ] .join("\n"); - let assertions = parse_highlight_test(&mut Parser::new(), language, source.as_bytes()).unwrap(); + let assertions = + parse_position_comments(&mut Parser::new(), language, source.as_bytes()).unwrap(); assert_eq!( assertions, &[ - (Point::new(0, 5), "function".to_string()), - (Point::new(0, 11), "keyword".to_string()), - (Point::new(3, 9), "variable.parameter".to_string()), + Assertion { + position: Point::new(0, 5), + expected_capture_name: "function".to_string() + }, + Assertion { + position: Point::new(0, 11), + expected_capture_name: "keyword".to_string() + }, + Assertion { + position: Point::new(3, 9), + expected_capture_name: "variable.parameter".to_string() + }, ] );