From 783c087aecf9f2bffd57abd5f4562fc9d108f00e Mon Sep 17 00:00:00 2001 From: Max Brunsfeld Date: Wed, 25 Mar 2020 11:26:52 -0700 Subject: [PATCH] tags: Handle cancellation --- cli/src/highlight.rs | 22 +++------------- cli/src/tags.rs | 5 +++- cli/src/tests/tags_test.rs | 53 +++++++++++++++++++++++++++++++++----- cli/src/util.rs | 19 +++++++++++++- tags/src/c_lib.rs | 28 ++++++++++++++++++-- tags/src/lib.rs | 45 ++++++++++++++++++++++---------- 6 files changed, 130 insertions(+), 42 deletions(-) diff --git a/cli/src/highlight.rs b/cli/src/highlight.rs index c80e6083..c6b1193d 100644 --- a/cli/src/highlight.rs +++ b/cli/src/highlight.rs @@ -1,3 +1,4 @@ +use super::util; use crate::error::Result; use crate::loader::Loader; use ansi_term::Color; @@ -6,10 +7,8 @@ use serde::ser::SerializeMap; use serde::{Deserialize, Deserializer, Serialize, Serializer}; use serde_json::{json, Value}; use std::collections::HashMap; -use std::sync::atomic::{AtomicUsize, Ordering}; -use std::sync::Arc; use std::time::Instant; -use std::{fs, io, path, str, thread, usize}; +use std::{fs, io, path, str, usize}; use tree_sitter_highlight::{HighlightConfiguration, HighlightEvent, Highlighter, HtmlRenderer}; pub const HTML_HEADER: &'static str = " @@ -273,19 +272,6 @@ fn color_to_css(color: Color) -> &'static str { } } -fn cancel_on_stdin() -> Arc { - let result = Arc::new(AtomicUsize::new(0)); - thread::spawn({ - let flag = result.clone(); - move || { - let mut line = String::new(); - io::stdin().read_line(&mut line).unwrap(); - flag.store(1, Ordering::Relaxed); - } - }); - result -} - pub fn ansi( loader: &Loader, theme: &Theme, @@ -296,7 +282,7 @@ pub fn ansi( let stdout = io::stdout(); let mut stdout = stdout.lock(); let time = Instant::now(); - let cancellation_flag = cancel_on_stdin(); + let cancellation_flag = util::cancel_on_stdin(); let mut highlighter = Highlighter::new(); let events = highlighter.highlight(config, source, Some(&cancellation_flag), |string| { @@ -341,7 +327,7 @@ pub fn html( let stdout = io::stdout(); let mut stdout = stdout.lock(); let time = Instant::now(); - let cancellation_flag = cancel_on_stdin(); + let cancellation_flag = util::cancel_on_stdin(); let mut highlighter = Highlighter::new(); let events = highlighter.highlight(config, source, Some(&cancellation_flag), |string| { diff --git a/cli/src/tags.rs b/cli/src/tags.rs index c65d5479..d6704ec5 100644 --- a/cli/src/tags.rs +++ b/cli/src/tags.rs @@ -1,4 +1,5 @@ use super::loader::Loader; +use super::util; use crate::error::{Error, Result}; use std::io::{self, Write}; use std::path::Path; @@ -15,6 +16,7 @@ pub fn generate_tags(loader: &Loader, scope: Option<&str>, paths: &[String]) -> } let mut context = TagsContext::new(); + let cancellation_flag = util::cancel_on_stdin(); let stdout = io::stdout(); let mut stdout = stdout.lock(); @@ -36,7 +38,8 @@ pub fn generate_tags(loader: &Loader, scope: Option<&str>, paths: &[String]) -> writeln!(&mut stdout, "{}", &path_str[1..path_str.len() - 1])?; let source = fs::read(path)?; - for tag in context.generate_tags(tags_config, &source) { + for tag in context.generate_tags(tags_config, &source, Some(&cancellation_flag))? { + let tag = tag?; write!( &mut stdout, " {:<8} {:<40}\t{:>9}-{:<9}", diff --git a/cli/src/tests/tags_test.rs b/cli/src/tests/tags_test.rs index 1b6cb2f3..41907a3c 100644 --- a/cli/src/tests/tags_test.rs +++ b/cli/src/tests/tags_test.rs @@ -3,7 +3,7 @@ use super::helpers::fixtures::{get_language, get_language_queries_path}; use std::ffi::CString; use std::{fs, ptr, slice, str}; use tree_sitter_tags::c_lib as c; -use tree_sitter_tags::{TagKind, TagsConfiguration, TagsContext}; +use tree_sitter_tags::{Error, TagKind, TagsConfiguration, TagsContext}; const PYTHON_TAG_QUERY: &'static str = r#" ((function_definition @@ -79,8 +79,10 @@ fn test_tags_python() { "#; let tags = tag_context - .generate_tags(&tags_config, source) - .collect::>(); + .generate_tags(&tags_config, source, None) + .unwrap() + .collect::, _>>() + .unwrap(); assert_eq!( tags.iter() @@ -128,8 +130,10 @@ fn test_tags_javascript() { let mut tag_context = TagsContext::new(); let tags = tag_context - .generate_tags(&tags_config, source) - .collect::>(); + .generate_tags(&tags_config, source, None) + .unwrap() + .collect::, _>>() + .unwrap(); assert_eq!( tags.iter() @@ -178,8 +182,10 @@ fn test_tags_ruby() { let mut tag_context = TagsContext::new(); let tags = tag_context - .generate_tags(&tags_config, source.as_bytes()) - .collect::>(); + .generate_tags(&tags_config, source.as_bytes(), None) + .unwrap() + .collect::, _>>() + .unwrap(); assert_eq!( tags.iter() @@ -201,6 +207,39 @@ fn test_tags_ruby() { ); } +#[test] +fn test_tags_cancellation() { + use std::sync::atomic::{AtomicUsize, Ordering}; + + allocations::record(|| { + // Large javascript document + let source = (0..500) + .map(|_| "/* hi */ class A { /* ok */ b() {} }\n") + .collect::(); + + let cancellation_flag = AtomicUsize::new(0); + let language = get_language("javascript"); + let tags_config = TagsConfiguration::new(language, JS_TAG_QUERY, "").unwrap(); + + let mut tag_context = TagsContext::new(); + let tags = tag_context + .generate_tags(&tags_config, source.as_bytes(), Some(&cancellation_flag)) + .unwrap(); + + for (i, tag) in tags.enumerate() { + if i == 150 { + cancellation_flag.store(1, Ordering::SeqCst); + } + if let Err(e) = tag { + assert_eq!(e, Error::Cancelled); + return; + } + } + + panic!("Expected to halt tagging with an error"); + }); +} + #[test] fn test_tags_via_c_api() { allocations::record(|| { diff --git a/cli/src/util.rs b/cli/src/util.rs index e880bea1..8978ecc1 100644 --- a/cli/src/util.rs +++ b/cli/src/util.rs @@ -1,12 +1,29 @@ +use std::io; +use std::sync::atomic::{AtomicUsize, Ordering}; +use std::sync::Arc; +use std::thread; +use tree_sitter::Parser; + #[cfg(unix)] use std::path::PathBuf; #[cfg(unix)] use std::process::{Child, ChildStdin, Command, Stdio}; -use tree_sitter::Parser; #[cfg(unix)] const HTML_HEADER: &[u8] = b"\n\n\n"; +pub fn cancel_on_stdin() -> Arc { + let result = Arc::new(AtomicUsize::new(0)); + thread::spawn({ + let flag = result.clone(); + move || { + let mut line = String::new(); + io::stdin().read_line(&mut line).unwrap(); + flag.store(1, Ordering::Relaxed); + } + }); + result +} #[cfg(windows)] pub struct LogSession(); diff --git a/tags/src/c_lib.rs b/tags/src/c_lib.rs index 0d61fb46..df785aa7 100644 --- a/tags/src/c_lib.rs +++ b/tags/src/c_lib.rs @@ -2,6 +2,7 @@ use super::{Error, TagKind, TagsConfiguration, TagsContext}; use std::collections::HashMap; use std::ffi::CStr; use std::process::abort; +use std::sync::atomic::AtomicUsize; use std::{fmt, ptr, slice, str}; use tree_sitter::Language; @@ -15,6 +16,7 @@ pub enum TSTagsError { InvalidUtf8, InvalidRegex, InvalidQuery, + Unknown, } #[repr(C)] @@ -100,6 +102,7 @@ pub extern "C" fn ts_tagger_add_language( } Err(Error::Query(_)) => TSTagsError::InvalidQuery, Err(Error::Regex(_)) => TSTagsError::InvalidRegex, + Err(_) => TSTagsError::Unknown, } } @@ -110,7 +113,7 @@ pub extern "C" fn ts_tagger_tag( source_code: *const u8, source_code_len: u32, output: *mut TSTagsBuffer, - cancellation_flag: *const usize, + cancellation_flag: *const AtomicUsize, ) -> TSTagsError { let tagger = unwrap_mut_ptr(this); let buffer = unwrap_mut_ptr(output); @@ -120,8 +123,29 @@ pub extern "C" fn ts_tagger_tag( buffer.tags.clear(); buffer.docs.clear(); let source_code = unsafe { slice::from_raw_parts(source_code, source_code_len as usize) }; + let cancellation_flag = unsafe { cancellation_flag.as_ref() }; + + let tags = match buffer + .context + .generate_tags(config, source_code, cancellation_flag) + { + Ok(tags) => tags, + Err(e) => { + return match e { + Error::InvalidLanguage => TSTagsError::InvalidLanguage, + Error::Cancelled => TSTagsError::Timeout, + _ => TSTagsError::Timeout, + } + } + }; + + for tag in tags { + let tag = if let Ok(tag) = tag { + tag + } else { + return TSTagsError::Timeout; + }; - for tag in buffer.context.generate_tags(config, source_code) { let prev_docs_len = buffer.docs.len(); if let Some(docs) = tag.docs { buffer.docs.extend_from_slice(docs.as_bytes()); diff --git a/tags/src/lib.rs b/tags/src/lib.rs index 566efe52..c3642c8f 100644 --- a/tags/src/lib.rs +++ b/tags/src/lib.rs @@ -3,12 +3,14 @@ pub mod c_lib; use memchr::{memchr, memrchr}; use regex::Regex; use std::ops::Range; +use std::sync::atomic::{AtomicUsize, Ordering}; use std::{fmt, mem, str}; use tree_sitter::{ Language, Parser, Point, Query, QueryCursor, QueryError, QueryPredicateArg, Tree, }; const MAX_LINE_LEN: usize = 180; +const CANCELLATION_CHECK_INTERVAL: usize = 100; /// Contains the data neeeded to compute tags for code written in a /// particular language. @@ -53,10 +55,12 @@ pub enum TagKind { Call, } -#[derive(Debug)] +#[derive(Debug, PartialEq)] pub enum Error { Query(QueryError), Regex(regex::Error), + Cancelled, + InvalidLanguage, } #[derive(Debug, Default)] @@ -88,6 +92,8 @@ where _tree: Tree, source: &'a [u8], config: &'a TagsConfiguration, + cancellation_flag: Option<&'a AtomicUsize>, + iter_count: usize, tag_queue: Vec<(Tag, usize)>, scopes: Vec>, } @@ -201,14 +207,13 @@ impl TagsContext { &'a mut self, config: &'a TagsConfiguration, source: &'a [u8], - ) -> impl Iterator + 'a { + cancellation_flag: Option<&'a AtomicUsize>, + ) -> Result> + 'a, Error> { self.parser .set_language(config.language) - .expect("Incompatible language"); - let tree = self - .parser - .parse(source, None) - .expect("Parsing failed unexpectedly"); + .map_err(|_| Error::InvalidLanguage)?; + unsafe { self.parser.set_cancellation_flag(cancellation_flag) }; + let tree = self.parser.parse(source, None).ok_or(Error::Cancelled)?; // The `matches` iterator borrows the `Tree`, which prevents it from being moved. // But the tree is really just a pointer, so it's actually ok to move it. @@ -218,18 +223,20 @@ impl TagsContext { .matches(&config.query, tree_ref.root_node(), move |node| { &source[node.byte_range()] }); - TagsIter { + Ok(TagsIter { _tree: tree, matches, source, config, + cancellation_flag, tag_queue: Vec::new(), + iter_count: 0, scopes: vec![LocalScope { range: 0..source.len(), inherits: false, local_defs: Vec::new(), }], - } + }) } } @@ -237,17 +244,29 @@ impl<'a, I> Iterator for TagsIter<'a, I> where I: Iterator>, { - type Item = Tag; + type Item = Result; - fn next(&mut self) -> Option { + fn next(&mut self) -> Option { loop { + // Periodically check for cancellation, returning `Cancelled` error if the + // cancellation flag was flipped. + if let Some(cancellation_flag) = self.cancellation_flag { + self.iter_count += 1; + if self.iter_count >= CANCELLATION_CHECK_INTERVAL { + self.iter_count = 0; + if cancellation_flag.load(Ordering::Relaxed) != 0 { + return Some(Err(Error::Cancelled)); + } + } + } + // If there is a queued tag for an earlier node in the syntax tree, then pop // it off of the queue and return it. if let Some(last_entry) = self.tag_queue.last() { if self.tag_queue.len() > 1 && self.tag_queue[0].0.name_range.end < last_entry.0.name_range.start { - return Some(self.tag_queue.remove(0).0); + return Some(Ok(self.tag_queue.remove(0).0)); } } @@ -420,7 +439,7 @@ where } // If there are no more matches, then drain the queue. else if !self.tag_queue.is_empty() { - return Some(self.tag_queue.remove(0).0); + return Some(Ok(self.tag_queue.remove(0).0)); } else { return None; }