From 9d669abac45b586ad6d7e838643b214b34f5c4b3 Mon Sep 17 00:00:00 2001 From: Amaan Qureshi Date: Mon, 3 Jul 2023 20:59:01 -0400 Subject: [PATCH] feat: add encoding flag and automatically check if a file might be utf-16 --- cli/src/main.rs | 33 ++++++++++++++---- cli/src/parse.rs | 91 ++++++++++++++++++++++++++++++------------------ 2 files changed, 84 insertions(+), 40 deletions(-) diff --git a/cli/src/main.rs b/cli/src/main.rs index 0a863b1c..18e50aad 100644 --- a/cli/src/main.rs +++ b/cli/src/main.rs @@ -3,8 +3,8 @@ use clap::{App, AppSettings, Arg, SubCommand}; use glob::glob; use std::path::{Path, PathBuf}; use std::{env, fs, u64}; -use tree_sitter::Point; -use tree_sitter_cli::parse::ParseOutput; +use tree_sitter::{ffi, Point}; +use tree_sitter_cli::parse::{ParseFileOptions, ParseOutput}; use tree_sitter_cli::{ generate, highlight, logger, parse, playground, query, tags, test, test_highlight, test_tags, util, wasm, @@ -162,6 +162,12 @@ fn run() -> Result<()> { .takes_value(true) .multiple(true) .number_of_values(1), + ) + .arg( + Arg::with_name("encoding") + .help("The encoding of the input files") + .long("encoding") + .takes_value(true), ), ) .subcommand( @@ -399,6 +405,16 @@ fn run() -> Result<()> { ParseOutput::Normal }; + let encoding = + matches + .values_of("encoding") + .map_or(Ok(None), |mut e| match e.next() { + Some("utf16") => Ok(Some(ffi::TSInputEncoding_TSInputEncodingUTF16)), + Some("utf8") => Ok(Some(ffi::TSInputEncoding_TSInputEncodingUTF8)), + Some(_) => Err(anyhow!("Invalid encoding. Expected one of: utf8, utf16")), + None => Ok(None), + })?; + let time = matches.is_present("time"); let edits = matches .values_of("edits") @@ -431,18 +447,21 @@ fn run() -> Result<()> { let language = loader.select_language(path, ¤t_dir, matches.value_of("scope"))?; - let this_file_errored = parse::parse_file_at_path( + let opts = ParseFileOptions { language, path, - &edits, + edits: &edits, max_path_length, output, - time, + print_time: time, timeout, debug, debug_graph, - Some(&cancellation_flag), - )?; + cancellation_flag: Some(&cancellation_flag), + encoding, + }; + + let this_file_errored = parse::parse_file_at_path(opts)?; if should_track_stats { stats.total_parses += 1; diff --git a/cli/src/parse.rs b/cli/src/parse.rs index 3e28e51a..6e62e1cf 100644 --- a/cli/src/parse.rs +++ b/cli/src/parse.rs @@ -5,7 +5,7 @@ use std::path::Path; use std::sync::atomic::AtomicUsize; use std::time::Instant; use std::{fmt, fs, usize}; -use tree_sitter::{InputEdit, Language, LogType, Parser, Point, Tree}; +use tree_sitter::{ffi, InputEdit, Language, LogType, Parser, Point, Tree}; #[derive(Debug)] pub struct Edit { @@ -38,37 +38,40 @@ pub enum ParseOutput { Dot, } -pub fn parse_file_at_path( - language: Language, - path: &Path, - edits: &Vec<&str>, - max_path_length: usize, - output: ParseOutput, - print_time: bool, - timeout: u64, - debug: bool, - debug_graph: bool, - cancellation_flag: Option<&AtomicUsize>, -) -> Result { +pub struct ParseFileOptions<'a> { + pub language: Language, + pub path: &'a Path, + pub edits: &'a [&'a str], + pub max_path_length: usize, + pub output: ParseOutput, + pub print_time: bool, + pub timeout: u64, + pub debug: bool, + pub debug_graph: bool, + pub cancellation_flag: Option<&'a AtomicUsize>, + pub encoding: Option, +} + +pub fn parse_file_at_path(opts: ParseFileOptions) -> Result { let mut _log_session = None; let mut parser = Parser::new(); - parser.set_language(language)?; - let mut source_code = - fs::read(path).with_context(|| format!("Error reading source file {:?}", path))?; + parser.set_language(opts.language)?; + let mut source_code = fs::read(opts.path) + .with_context(|| format!("Error reading source file {:?}", opts.path))?; // If the `--cancel` flag was passed, then cancel the parse // when the user types a newline. - unsafe { parser.set_cancellation_flag(cancellation_flag) }; + unsafe { parser.set_cancellation_flag(opts.cancellation_flag) }; // Set a timeout based on the `--time` flag. - parser.set_timeout_micros(timeout); + parser.set_timeout_micros(opts.timeout); // Render an HTML graph if `--debug-graph` was passed - if debug_graph { + if opts.debug_graph { _log_session = Some(util::log_graphs(&mut parser, "log.html")?); } // Log to stderr if `--debug` was passed - else if debug { + else if opts.debug { parser.set_logger(Some(Box::new(|log_type, message| { if log_type == LogType::Lex { io::stderr().write(b" ").unwrap(); @@ -78,22 +81,44 @@ pub fn parse_file_at_path( } let time = Instant::now(); - let tree = parser.parse(&source_code, None); + + #[inline(always)] + fn is_utf16_bom(bom_bytes: &[u8]) -> bool { + bom_bytes == [0xFF, 0xFE] || bom_bytes == [0xFE, 0xFF] + } + + let tree = match opts.encoding { + Some(encoding) if encoding == ffi::TSInputEncoding_TSInputEncodingUTF16 => { + let source_code_utf16 = source_code + .chunks_exact(2) + .map(|chunk| u16::from_le_bytes([chunk[0], chunk[1]])) + .collect::>(); + parser.parse_utf16(&source_code_utf16, None) + } + None if is_utf16_bom(&source_code[0..2]) => { + let source_code_utf16 = source_code + .chunks_exact(2) + .map(|chunk| u16::from_le_bytes([chunk[0], chunk[1]])) + .collect::>(); + parser.parse_utf16(&source_code_utf16, None) + } + _ => parser.parse(&source_code, None), + }; let stdout = io::stdout(); let mut stdout = stdout.lock(); if let Some(mut tree) = tree { - if debug_graph && !edits.is_empty() { + if opts.debug_graph && !opts.edits.is_empty() { println!("BEFORE:\n{}", String::from_utf8_lossy(&source_code)); } - for (i, edit) in edits.iter().enumerate() { + for (i, edit) in opts.edits.iter().enumerate() { let edit = parse_edit_flag(&source_code, edit)?; perform_edit(&mut tree, &mut source_code, &edit); tree = parser.parse(&source_code, Some(&tree)).unwrap(); - if debug_graph { + if opts.debug_graph { println!("AFTER {}:\n{}", i, String::from_utf8_lossy(&source_code)); } } @@ -102,7 +127,7 @@ pub fn parse_file_at_path( let duration_ms = duration.as_secs() * 1000 + duration.subsec_nanos() as u64 / 1000000; let mut cursor = tree.walk(); - if matches!(output, ParseOutput::Normal) { + if matches!(opts.output, ParseOutput::Normal) { let mut needs_newline = false; let mut indent_level = 0; let mut did_visit_children = false; @@ -158,7 +183,7 @@ pub fn parse_file_at_path( println!(""); } - if matches!(output, ParseOutput::Xml) { + if matches!(opts.output, ParseOutput::Xml) { let mut needs_newline = false; let mut indent_level = 0; let mut did_visit_children = false; @@ -213,7 +238,7 @@ pub fn parse_file_at_path( println!(""); } - if matches!(output, ParseOutput::Dot) { + if matches!(opts.output, ParseOutput::Dot) { util::print_tree_graph(&tree, "log.html").unwrap(); } @@ -234,13 +259,13 @@ pub fn parse_file_at_path( } } - if first_error.is_some() || print_time { + if first_error.is_some() || opts.print_time { write!( &mut stdout, "{:width$}\t{} ms", - path.to_str().unwrap(), + opts.path.to_str().unwrap(), duration_ms, - width = max_path_length + width = opts.max_path_length )?; if let Some(node) = first_error { let start = node.start_position(); @@ -269,15 +294,15 @@ pub fn parse_file_at_path( } return Ok(first_error.is_some()); - } else if print_time { + } else if opts.print_time { let duration = time.elapsed(); let duration_ms = duration.as_secs() * 1000 + duration.subsec_nanos() as u64 / 1000000; writeln!( &mut stdout, "{:width$}\t{} ms (timed out)", - path.to_str().unwrap(), + opts.path.to_str().unwrap(), duration_ms, - width = max_path_length + width = opts.max_path_length )?; }