feat(generate)!: use regex_syntax::Hir for expanding regexes

Co-authored-by: Piotr Osiewicz <24362066+osiewicz@users.noreply.github.com>
This commit is contained in:
Amaan Qureshi 2024-10-26 18:31:40 -04:00
parent c707f3ee9e
commit c8cf75fd30
2 changed files with 97 additions and 271 deletions

View file

@ -58,7 +58,8 @@ impl CharacterSet {
/// Create a character set with a given *inclusive* range of characters.
#[allow(clippy::single_range_in_vec_init)]
pub fn from_range(mut first: char, mut last: char) -> Self {
#[cfg(test)]
fn from_range(mut first: char, mut last: char) -> Self {
if first > last {
swap(&mut first, &mut last);
}
@ -286,7 +287,8 @@ impl CharacterSet {
/// Produces a `CharacterSet` containing every character that is in _exactly one_ of `self` or
/// `other`, but is not present in both sets.
pub fn symmetric_difference(mut self, mut other: Self) -> Self {
#[cfg(test)]
fn symmetric_difference(mut self, mut other: Self) -> Self {
self.remove_intersection(&mut other);
self.add(&other)
}

View file

@ -2,9 +2,9 @@ use std::collections::HashMap;
use anyhow::{anyhow, Context, Result};
use lazy_static::lazy_static;
use regex_syntax::ast::{
parse, Ast, ClassPerlKind, ClassSet, ClassSetBinaryOpKind, ClassSetItem, ClassUnicodeKind,
RepetitionKind, RepetitionRange,
use regex_syntax::{
hir::{Class, Hir, HirKind},
ParserBuilder,
};
use super::ExtractedLexicalGrammar;
@ -114,8 +114,25 @@ impl NfaBuilder {
fn expand_rule(&mut self, rule: &Rule, mut next_state_id: u32) -> Result<bool> {
match rule {
Rule::Pattern(s, f) => {
let ast = parse::Parser::new().parse(s)?;
self.expand_regex(&ast, next_state_id, f.contains('i'))
// With unicode enabled, `\w`, `\s` and `\d` expand to character sets that are much
// larger than intended, so we replace them with the actual
// character sets they should represent. If the full unicode range
// of `\w`, `\s` or `\d` are needed then `\p{L}`, `\p{Z}` and `\p{N}` should be
// used.
let s = s
.replace(r"\w", r"[0-9A-Za-z_]")
.replace(r"\s", r"[\t-\r ]")
.replace(r"\d", r"[0-9]")
.replace(r"\W", r"[^0-9A-Za-z_]")
.replace(r"\S", r"[^\t-\r ]")
.replace(r"\D", r"[^0-9]");
let mut parser = ParserBuilder::new()
.case_insensitive(f.contains('i'))
.unicode(true)
.utf8(false)
.build();
let hir = parser.parse(&s)?;
self.expand_regex(&hir, next_state_id)
}
Rule::String(s) => {
for c in s.chars().rev() {
@ -183,125 +200,90 @@ impl NfaBuilder {
}
}
fn expand_regex(
&mut self,
ast: &Ast,
mut next_state_id: u32,
case_insensitive: bool,
) -> Result<bool> {
const fn inverse_char(c: char) -> char {
match c {
'a'..='z' => (c as u8 - b'a' + b'A') as char,
'A'..='Z' => (c as u8 - b'A' + b'a') as char,
c => c,
}
}
fn with_inverse_char(mut chars: CharacterSet) -> CharacterSet {
for char in chars.clone().chars() {
let inverted = inverse_char(char);
if char != inverted {
chars = chars.add_char(inverted);
fn expand_regex(&mut self, hir: &Hir, mut next_state_id: u32) -> Result<bool> {
match hir.kind() {
HirKind::Empty => Ok(false),
HirKind::Literal(literal) => {
for character in std::str::from_utf8(&literal.0)?.chars().rev() {
let char_set = CharacterSet::from_char(character);
self.push_advance(char_set, next_state_id);
next_state_id = self.nfa.last_state_id();
}
}
chars
}
match ast {
Ast::Empty(_) => Ok(false),
Ast::Flags(_) => Err(anyhow!("Regex error: Flags are not supported")),
Ast::Literal(literal) => {
let mut char_set = CharacterSet::from_char(literal.c);
if case_insensitive {
let inverted = inverse_char(literal.c);
if literal.c != inverted {
char_set = char_set.add_char(inverted);
Ok(true)
}
HirKind::Class(class) => match class {
Class::Unicode(class) => {
let mut chars = CharacterSet::default();
for c in class.ranges() {
chars = chars.add_range(c.start(), c.end());
}
// For some reason, the long s `ſ` is included if the letter `s` is in a
// pattern, so we remove it.
if chars.range_count() == 3
&& chars
.ranges()
// exact check to ensure that `ſ` wasn't intentionally added.
.all(|r| ['s'..='s', 'S'..='S', 'ſ'..='ſ'].contains(&r))
{
chars = chars.difference(CharacterSet::from_char('ſ'));
}
self.push_advance(chars, next_state_id);
Ok(true)
}
self.push_advance(char_set, next_state_id);
Ok(true)
}
Ast::Dot(_) => {
self.push_advance(CharacterSet::from_char('\n').negate(), next_state_id);
Ok(true)
}
Ast::Assertion(_) => Err(anyhow!("Regex error: Assertions are not supported")),
Ast::ClassUnicode(class) => {
let mut chars = self.expand_unicode_character_class(&class.kind)?;
if class.negated {
chars = chars.negate();
Class::Bytes(bytes_class) => {
let mut chars = CharacterSet::default();
for c in bytes_class.ranges() {
chars = chars.add_range(c.start().into(), c.end().into());
}
self.push_advance(chars, next_state_id);
Ok(true)
}
if case_insensitive {
chars = with_inverse_char(chars);
},
HirKind::Look(_) => Err(anyhow!("Regex error: Assertions are not supported")),
HirKind::Repetition(repetition) => match (repetition.min, repetition.max) {
(0, Some(1)) => self.expand_zero_or_one(&repetition.sub, next_state_id),
(1, None) => self.expand_one_or_more(&repetition.sub, next_state_id),
(0, None) => self.expand_zero_or_more(&repetition.sub, next_state_id),
(min, Some(max)) if min == max => {
self.expand_count(&repetition.sub, min, next_state_id)
}
self.push_advance(chars, next_state_id);
Ok(true)
}
Ast::ClassPerl(class) => {
let mut chars = self.expand_perl_character_class(&class.kind);
if class.negated {
chars = chars.negate();
}
if case_insensitive {
chars = with_inverse_char(chars);
}
self.push_advance(chars, next_state_id);
Ok(true)
}
Ast::ClassBracketed(class) => {
let mut chars = self.translate_class_set(&class.kind)?;
if class.negated {
chars = chars.negate();
}
if case_insensitive {
chars = with_inverse_char(chars);
}
self.push_advance(chars, next_state_id);
Ok(true)
}
Ast::Repetition(repetition) => match repetition.op.kind {
RepetitionKind::ZeroOrOne => {
self.expand_zero_or_one(&repetition.ast, next_state_id, case_insensitive)
}
RepetitionKind::OneOrMore => {
self.expand_one_or_more(&repetition.ast, next_state_id, case_insensitive)
}
RepetitionKind::ZeroOrMore => {
self.expand_zero_or_more(&repetition.ast, next_state_id, case_insensitive)
}
RepetitionKind::Range(RepetitionRange::Exactly(count)) => {
self.expand_count(&repetition.ast, count, next_state_id, case_insensitive)
}
RepetitionKind::Range(RepetitionRange::AtLeast(min)) => {
if self.expand_zero_or_more(&repetition.ast, next_state_id, case_insensitive)? {
self.expand_count(&repetition.ast, min, next_state_id, case_insensitive)
(min, None) => {
if self.expand_zero_or_more(&repetition.sub, next_state_id)? {
self.expand_count(&repetition.sub, min, next_state_id)
} else {
Ok(false)
}
}
RepetitionKind::Range(RepetitionRange::Bounded(min, max)) => {
let mut result =
self.expand_count(&repetition.ast, min, next_state_id, case_insensitive)?;
(min, Some(max)) => {
let mut result = self.expand_count(&repetition.sub, min, next_state_id)?;
for _ in min..max {
if result {
next_state_id = self.nfa.last_state_id();
}
if self.expand_zero_or_one(
&repetition.ast,
next_state_id,
case_insensitive,
)? {
if self.expand_zero_or_one(&repetition.sub, next_state_id)? {
result = true;
}
}
Ok(result)
}
},
Ast::Group(group) => self.expand_regex(&group.ast, next_state_id, case_insensitive),
Ast::Alternation(alternation) => {
HirKind::Capture(capture) => self.expand_regex(&capture.sub, next_state_id),
HirKind::Concat(concat) => {
let mut result = false;
for hir in concat.iter().rev() {
if self.expand_regex(hir, next_state_id)? {
result = true;
next_state_id = self.nfa.last_state_id();
}
}
Ok(result)
}
HirKind::Alternation(alternations) => {
let mut alternative_state_ids = Vec::new();
for ast in &alternation.asts {
if self.expand_regex(ast, next_state_id, case_insensitive)? {
for hir in alternations {
if self.expand_regex(hir, next_state_id)? {
alternative_state_ids.push(self.nfa.last_state_id());
} else {
alternative_state_ids.push(next_state_id);
@ -310,58 +292,21 @@ impl NfaBuilder {
alternative_state_ids.sort_unstable();
alternative_state_ids.dedup();
alternative_state_ids.retain(|i| *i != self.nfa.last_state_id());
for alternative_state_id in alternative_state_ids {
self.push_split(alternative_state_id);
}
Ok(true)
}
Ast::Concat(concat) => {
let mut result = false;
for ast in concat.asts.iter().rev() {
if self.expand_regex(ast, next_state_id, case_insensitive)? {
result = true;
next_state_id = self.nfa.last_state_id();
}
}
Ok(result)
}
}
}
fn translate_class_set(&self, class_set: &ClassSet) -> Result<CharacterSet> {
match &class_set {
ClassSet::Item(item) => self.expand_character_class(item),
ClassSet::BinaryOp(binary_op) => {
let mut lhs_char_class = self.translate_class_set(&binary_op.lhs)?;
let mut rhs_char_class = self.translate_class_set(&binary_op.rhs)?;
match binary_op.kind {
ClassSetBinaryOpKind::Intersection => {
Ok(lhs_char_class.remove_intersection(&mut rhs_char_class))
}
ClassSetBinaryOpKind::Difference => {
Ok(lhs_char_class.difference(rhs_char_class))
}
ClassSetBinaryOpKind::SymmetricDifference => {
Ok(lhs_char_class.symmetric_difference(rhs_char_class))
}
}
}
}
}
fn expand_one_or_more(
&mut self,
ast: &Ast,
next_state_id: u32,
case_insensitive: bool,
) -> Result<bool> {
fn expand_one_or_more(&mut self, hir: &Hir, next_state_id: u32) -> Result<bool> {
self.nfa.states.push(NfaState::Accept {
variable_index: 0,
precedence: 0,
}); // Placeholder for split
let split_state_id = self.nfa.last_state_id();
if self.expand_regex(ast, split_state_id, case_insensitive)? {
if self.expand_regex(hir, split_state_id)? {
self.nfa.states[split_state_id as usize] =
NfaState::Split(self.nfa.last_state_id(), next_state_id);
Ok(true)
@ -371,13 +316,8 @@ impl NfaBuilder {
}
}
fn expand_zero_or_one(
&mut self,
ast: &Ast,
next_state_id: u32,
case_insensitive: bool,
) -> Result<bool> {
if self.expand_regex(ast, next_state_id, case_insensitive)? {
fn expand_zero_or_one(&mut self, hir: &Hir, next_state_id: u32) -> Result<bool> {
if self.expand_regex(hir, next_state_id)? {
self.push_split(next_state_id);
Ok(true)
} else {
@ -385,13 +325,8 @@ impl NfaBuilder {
}
}
fn expand_zero_or_more(
&mut self,
ast: &Ast,
next_state_id: u32,
case_insensitive: bool,
) -> Result<bool> {
if self.expand_one_or_more(ast, next_state_id, case_insensitive)? {
fn expand_zero_or_more(&mut self, hir: &Hir, next_state_id: u32) -> Result<bool> {
if self.expand_one_or_more(hir, next_state_id)? {
self.push_split(next_state_id);
Ok(true)
} else {
@ -399,16 +334,10 @@ impl NfaBuilder {
}
}
fn expand_count(
&mut self,
ast: &Ast,
count: u32,
mut next_state_id: u32,
case_insensitive: bool,
) -> Result<bool> {
fn expand_count(&mut self, hir: &Hir, count: u32, mut next_state_id: u32) -> Result<bool> {
let mut result = false;
for _ in 0..count {
if self.expand_regex(ast, next_state_id, case_insensitive)? {
if self.expand_regex(hir, next_state_id)? {
result = true;
next_state_id = self.nfa.last_state_id();
}
@ -416,111 +345,6 @@ impl NfaBuilder {
Ok(result)
}
fn expand_character_class(&self, item: &ClassSetItem) -> Result<CharacterSet> {
match item {
ClassSetItem::Empty(_) => Ok(CharacterSet::empty()),
ClassSetItem::Literal(literal) => Ok(CharacterSet::from_char(literal.c)),
ClassSetItem::Range(range) => Ok(CharacterSet::from_range(range.start.c, range.end.c)),
ClassSetItem::Union(union) => {
let mut result = CharacterSet::empty();
for item in &union.items {
result = result.add(&self.expand_character_class(item)?);
}
Ok(result)
}
ClassSetItem::Perl(class) => Ok(self.expand_perl_character_class(&class.kind)),
ClassSetItem::Unicode(class) => {
let mut set = self.expand_unicode_character_class(&class.kind)?;
if class.negated {
set = set.negate();
}
Ok(set)
}
ClassSetItem::Bracketed(class) => {
let mut set = self.translate_class_set(&class.kind)?;
if class.negated {
set = set.negate();
}
Ok(set)
}
ClassSetItem::Ascii(_) => Err(anyhow!(
"Regex error: Unsupported character class syntax {item:?}",
)),
}
}
fn expand_unicode_character_class(&self, class: &ClassUnicodeKind) -> Result<CharacterSet> {
let mut chars = CharacterSet::empty();
let category_letter;
match class {
ClassUnicodeKind::OneLetter(le) => {
category_letter = le.to_string();
}
ClassUnicodeKind::Named(class_name) => {
let actual_class_name = UNICODE_CATEGORY_ALIASES
.get(class_name.as_str())
.or_else(|| UNICODE_PROPERTY_ALIASES.get(class_name.as_str()))
.unwrap_or(class_name);
if actual_class_name.len() == 1 {
category_letter = actual_class_name.clone();
} else {
let code_points =
UNICODE_CATEGORIES
.get(actual_class_name.as_str())
.or_else(|| UNICODE_PROPERTIES.get(actual_class_name.as_str()))
.ok_or_else(|| {
anyhow!(
"Regex error: Unsupported unicode character class {class_name}",
)
})?;
for c in code_points {
if let Some(c) = char::from_u32(*c) {
chars = chars.add_char(c);
}
}
return Ok(chars);
}
}
ClassUnicodeKind::NamedValue { .. } => {
return Err(anyhow!(
"Regex error: Key-value unicode properties are not supported"
))
}
}
for (category, code_points) in UNICODE_CATEGORIES.iter() {
if category.starts_with(&category_letter) {
for c in code_points {
if let Some(c) = char::from_u32(*c) {
chars = chars.add_char(c);
}
}
}
}
Ok(chars)
}
fn expand_perl_character_class(&self, item: &ClassPerlKind) -> CharacterSet {
match item {
ClassPerlKind::Digit => CharacterSet::from_range('0', '9'),
ClassPerlKind::Space => CharacterSet::empty()
.add_char(' ')
.add_char('\t')
.add_char('\r')
.add_char('\n')
.add_char('\x0B')
.add_char('\x0C'),
ClassPerlKind::Word => CharacterSet::empty()
.add_char('_')
.add_range('A', 'Z')
.add_range('a', 'z')
.add_range('0', '9'),
}
}
fn push_advance(&mut self, chars: CharacterSet, state_id: u32) {
let precedence = *self.precedence_stack.last().unwrap();
self.nfa.states.push(NfaState::Advance {