tags: Handle cancellation

This commit is contained in:
Max Brunsfeld 2020-03-25 11:26:52 -07:00
parent ae075e75f0
commit 783c087aec
6 changed files with 130 additions and 42 deletions

View file

@ -1,3 +1,4 @@
use super::util;
use crate::error::Result;
use crate::loader::Loader;
use ansi_term::Color;
@ -6,10 +7,8 @@ use serde::ser::SerializeMap;
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use serde_json::{json, Value};
use std::collections::HashMap;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use std::time::Instant;
use std::{fs, io, path, str, thread, usize};
use std::{fs, io, path, str, usize};
use tree_sitter_highlight::{HighlightConfiguration, HighlightEvent, Highlighter, HtmlRenderer};
pub const HTML_HEADER: &'static str = "
@ -273,19 +272,6 @@ fn color_to_css(color: Color) -> &'static str {
}
}
fn cancel_on_stdin() -> Arc<AtomicUsize> {
let result = Arc::new(AtomicUsize::new(0));
thread::spawn({
let flag = result.clone();
move || {
let mut line = String::new();
io::stdin().read_line(&mut line).unwrap();
flag.store(1, Ordering::Relaxed);
}
});
result
}
pub fn ansi(
loader: &Loader,
theme: &Theme,
@ -296,7 +282,7 @@ pub fn ansi(
let stdout = io::stdout();
let mut stdout = stdout.lock();
let time = Instant::now();
let cancellation_flag = cancel_on_stdin();
let cancellation_flag = util::cancel_on_stdin();
let mut highlighter = Highlighter::new();
let events = highlighter.highlight(config, source, Some(&cancellation_flag), |string| {
@ -341,7 +327,7 @@ pub fn html(
let stdout = io::stdout();
let mut stdout = stdout.lock();
let time = Instant::now();
let cancellation_flag = cancel_on_stdin();
let cancellation_flag = util::cancel_on_stdin();
let mut highlighter = Highlighter::new();
let events = highlighter.highlight(config, source, Some(&cancellation_flag), |string| {

View file

@ -1,4 +1,5 @@
use super::loader::Loader;
use super::util;
use crate::error::{Error, Result};
use std::io::{self, Write};
use std::path::Path;
@ -15,6 +16,7 @@ pub fn generate_tags(loader: &Loader, scope: Option<&str>, paths: &[String]) ->
}
let mut context = TagsContext::new();
let cancellation_flag = util::cancel_on_stdin();
let stdout = io::stdout();
let mut stdout = stdout.lock();
@ -36,7 +38,8 @@ pub fn generate_tags(loader: &Loader, scope: Option<&str>, paths: &[String]) ->
writeln!(&mut stdout, "{}", &path_str[1..path_str.len() - 1])?;
let source = fs::read(path)?;
for tag in context.generate_tags(tags_config, &source) {
for tag in context.generate_tags(tags_config, &source, Some(&cancellation_flag))? {
let tag = tag?;
write!(
&mut stdout,
" {:<8} {:<40}\t{:>9}-{:<9}",

View file

@ -3,7 +3,7 @@ use super::helpers::fixtures::{get_language, get_language_queries_path};
use std::ffi::CString;
use std::{fs, ptr, slice, str};
use tree_sitter_tags::c_lib as c;
use tree_sitter_tags::{TagKind, TagsConfiguration, TagsContext};
use tree_sitter_tags::{Error, TagKind, TagsConfiguration, TagsContext};
const PYTHON_TAG_QUERY: &'static str = r#"
((function_definition
@ -79,8 +79,10 @@ fn test_tags_python() {
"#;
let tags = tag_context
.generate_tags(&tags_config, source)
.collect::<Vec<_>>();
.generate_tags(&tags_config, source, None)
.unwrap()
.collect::<Result<Vec<_>, _>>()
.unwrap();
assert_eq!(
tags.iter()
@ -128,8 +130,10 @@ fn test_tags_javascript() {
let mut tag_context = TagsContext::new();
let tags = tag_context
.generate_tags(&tags_config, source)
.collect::<Vec<_>>();
.generate_tags(&tags_config, source, None)
.unwrap()
.collect::<Result<Vec<_>, _>>()
.unwrap();
assert_eq!(
tags.iter()
@ -178,8 +182,10 @@ fn test_tags_ruby() {
let mut tag_context = TagsContext::new();
let tags = tag_context
.generate_tags(&tags_config, source.as_bytes())
.collect::<Vec<_>>();
.generate_tags(&tags_config, source.as_bytes(), None)
.unwrap()
.collect::<Result<Vec<_>, _>>()
.unwrap();
assert_eq!(
tags.iter()
@ -201,6 +207,39 @@ fn test_tags_ruby() {
);
}
#[test]
fn test_tags_cancellation() {
use std::sync::atomic::{AtomicUsize, Ordering};
allocations::record(|| {
// Large javascript document
let source = (0..500)
.map(|_| "/* hi */ class A { /* ok */ b() {} }\n")
.collect::<String>();
let cancellation_flag = AtomicUsize::new(0);
let language = get_language("javascript");
let tags_config = TagsConfiguration::new(language, JS_TAG_QUERY, "").unwrap();
let mut tag_context = TagsContext::new();
let tags = tag_context
.generate_tags(&tags_config, source.as_bytes(), Some(&cancellation_flag))
.unwrap();
for (i, tag) in tags.enumerate() {
if i == 150 {
cancellation_flag.store(1, Ordering::SeqCst);
}
if let Err(e) = tag {
assert_eq!(e, Error::Cancelled);
return;
}
}
panic!("Expected to halt tagging with an error");
});
}
#[test]
fn test_tags_via_c_api() {
allocations::record(|| {

View file

@ -1,12 +1,29 @@
use std::io;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use std::thread;
use tree_sitter::Parser;
#[cfg(unix)]
use std::path::PathBuf;
#[cfg(unix)]
use std::process::{Child, ChildStdin, Command, Stdio};
use tree_sitter::Parser;
#[cfg(unix)]
const HTML_HEADER: &[u8] = b"<!DOCTYPE html>\n<style>svg { width: 100%; }</style>\n\n";
pub fn cancel_on_stdin() -> Arc<AtomicUsize> {
let result = Arc::new(AtomicUsize::new(0));
thread::spawn({
let flag = result.clone();
move || {
let mut line = String::new();
io::stdin().read_line(&mut line).unwrap();
flag.store(1, Ordering::Relaxed);
}
});
result
}
#[cfg(windows)]
pub struct LogSession();

View file

@ -2,6 +2,7 @@ use super::{Error, TagKind, TagsConfiguration, TagsContext};
use std::collections::HashMap;
use std::ffi::CStr;
use std::process::abort;
use std::sync::atomic::AtomicUsize;
use std::{fmt, ptr, slice, str};
use tree_sitter::Language;
@ -15,6 +16,7 @@ pub enum TSTagsError {
InvalidUtf8,
InvalidRegex,
InvalidQuery,
Unknown,
}
#[repr(C)]
@ -100,6 +102,7 @@ pub extern "C" fn ts_tagger_add_language(
}
Err(Error::Query(_)) => TSTagsError::InvalidQuery,
Err(Error::Regex(_)) => TSTagsError::InvalidRegex,
Err(_) => TSTagsError::Unknown,
}
}
@ -110,7 +113,7 @@ pub extern "C" fn ts_tagger_tag(
source_code: *const u8,
source_code_len: u32,
output: *mut TSTagsBuffer,
cancellation_flag: *const usize,
cancellation_flag: *const AtomicUsize,
) -> TSTagsError {
let tagger = unwrap_mut_ptr(this);
let buffer = unwrap_mut_ptr(output);
@ -120,8 +123,29 @@ pub extern "C" fn ts_tagger_tag(
buffer.tags.clear();
buffer.docs.clear();
let source_code = unsafe { slice::from_raw_parts(source_code, source_code_len as usize) };
let cancellation_flag = unsafe { cancellation_flag.as_ref() };
let tags = match buffer
.context
.generate_tags(config, source_code, cancellation_flag)
{
Ok(tags) => tags,
Err(e) => {
return match e {
Error::InvalidLanguage => TSTagsError::InvalidLanguage,
Error::Cancelled => TSTagsError::Timeout,
_ => TSTagsError::Timeout,
}
}
};
for tag in tags {
let tag = if let Ok(tag) = tag {
tag
} else {
return TSTagsError::Timeout;
};
for tag in buffer.context.generate_tags(config, source_code) {
let prev_docs_len = buffer.docs.len();
if let Some(docs) = tag.docs {
buffer.docs.extend_from_slice(docs.as_bytes());

View file

@ -3,12 +3,14 @@ pub mod c_lib;
use memchr::{memchr, memrchr};
use regex::Regex;
use std::ops::Range;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::{fmt, mem, str};
use tree_sitter::{
Language, Parser, Point, Query, QueryCursor, QueryError, QueryPredicateArg, Tree,
};
const MAX_LINE_LEN: usize = 180;
const CANCELLATION_CHECK_INTERVAL: usize = 100;
/// Contains the data neeeded to compute tags for code written in a
/// particular language.
@ -53,10 +55,12 @@ pub enum TagKind {
Call,
}
#[derive(Debug)]
#[derive(Debug, PartialEq)]
pub enum Error {
Query(QueryError),
Regex(regex::Error),
Cancelled,
InvalidLanguage,
}
#[derive(Debug, Default)]
@ -88,6 +92,8 @@ where
_tree: Tree,
source: &'a [u8],
config: &'a TagsConfiguration,
cancellation_flag: Option<&'a AtomicUsize>,
iter_count: usize,
tag_queue: Vec<(Tag, usize)>,
scopes: Vec<LocalScope<'a>>,
}
@ -201,14 +207,13 @@ impl TagsContext {
&'a mut self,
config: &'a TagsConfiguration,
source: &'a [u8],
) -> impl Iterator<Item = Tag> + 'a {
cancellation_flag: Option<&'a AtomicUsize>,
) -> Result<impl Iterator<Item = Result<Tag, Error>> + 'a, Error> {
self.parser
.set_language(config.language)
.expect("Incompatible language");
let tree = self
.parser
.parse(source, None)
.expect("Parsing failed unexpectedly");
.map_err(|_| Error::InvalidLanguage)?;
unsafe { self.parser.set_cancellation_flag(cancellation_flag) };
let tree = self.parser.parse(source, None).ok_or(Error::Cancelled)?;
// The `matches` iterator borrows the `Tree`, which prevents it from being moved.
// But the tree is really just a pointer, so it's actually ok to move it.
@ -218,18 +223,20 @@ impl TagsContext {
.matches(&config.query, tree_ref.root_node(), move |node| {
&source[node.byte_range()]
});
TagsIter {
Ok(TagsIter {
_tree: tree,
matches,
source,
config,
cancellation_flag,
tag_queue: Vec::new(),
iter_count: 0,
scopes: vec![LocalScope {
range: 0..source.len(),
inherits: false,
local_defs: Vec::new(),
}],
}
})
}
}
@ -237,17 +244,29 @@ impl<'a, I> Iterator for TagsIter<'a, I>
where
I: Iterator<Item = tree_sitter::QueryMatch<'a>>,
{
type Item = Tag;
type Item = Result<Tag, Error>;
fn next(&mut self) -> Option<Tag> {
fn next(&mut self) -> Option<Self::Item> {
loop {
// Periodically check for cancellation, returning `Cancelled` error if the
// cancellation flag was flipped.
if let Some(cancellation_flag) = self.cancellation_flag {
self.iter_count += 1;
if self.iter_count >= CANCELLATION_CHECK_INTERVAL {
self.iter_count = 0;
if cancellation_flag.load(Ordering::Relaxed) != 0 {
return Some(Err(Error::Cancelled));
}
}
}
// If there is a queued tag for an earlier node in the syntax tree, then pop
// it off of the queue and return it.
if let Some(last_entry) = self.tag_queue.last() {
if self.tag_queue.len() > 1
&& self.tag_queue[0].0.name_range.end < last_entry.0.name_range.start
{
return Some(self.tag_queue.remove(0).0);
return Some(Ok(self.tag_queue.remove(0).0));
}
}
@ -420,7 +439,7 @@ where
}
// If there are no more matches, then drain the queue.
else if !self.tag_queue.is_empty() {
return Some(self.tag_queue.remove(0).0);
return Some(Ok(self.tag_queue.remove(0).0));
} else {
return None;
}