1use crate::util::int::Usize;
2
3/// A representation of byte oriented equivalence classes.
4///
5/// This is used in finite state machines to reduce the size of the transition
6/// table. This can have a particularly large impact not only on the total size
7/// of an FSM, but also on FSM build times because it reduces the number of
8/// transitions that need to be visited/set.
9#[derive(Clone, Copy)]
10pub(crate) struct ByteClasses([u8; 256]);
11
12impl ByteClasses {
13 /// Creates a new set of equivalence classes where all bytes are mapped to
14 /// the same class.
15 pub(crate) fn empty() -> ByteClasses {
16 ByteClasses([0; 256])
17 }
18
19 /// Creates a new set of equivalence classes where each byte belongs to
20 /// its own equivalence class.
21 pub(crate) fn singletons() -> ByteClasses {
22 let mut classes = ByteClasses::empty();
23 for b in 0..=255 {
24 classes.set(b, b);
25 }
26 classes
27 }
28
29 /// Set the equivalence class for the given byte.
30 #[inline]
31 pub(crate) fn set(&mut self, byte: u8, class: u8) {
32 self.0[usize::from(byte)] = class;
33 }
34
35 /// Get the equivalence class for the given byte.
36 #[inline]
37 pub(crate) fn get(&self, byte: u8) -> u8 {
38 self.0[usize::from(byte)]
39 }
40
41 /// Return the total number of elements in the alphabet represented by
42 /// these equivalence classes. Equivalently, this returns the total number
43 /// of equivalence classes.
44 #[inline]
45 pub(crate) fn alphabet_len(&self) -> usize {
46 // Add one since the number of equivalence classes is one bigger than
47 // the last one.
48 usize::from(self.0[255]) + 1
49 }
50
51 /// Returns the stride, as a base-2 exponent, required for these
52 /// equivalence classes.
53 ///
54 /// The stride is always the smallest power of 2 that is greater than or
55 /// equal to the alphabet length. This is done so that converting between
56 /// state IDs and indices can be done with shifts alone, which is much
57 /// faster than integer division. The "stride2" is the exponent. i.e.,
58 /// `2^stride2 = stride`.
59 pub(crate) fn stride2(&self) -> usize {
60 let zeros = self.alphabet_len().next_power_of_two().trailing_zeros();
61 usize::try_from(zeros).unwrap()
62 }
63
64 /// Returns the stride for these equivalence classes, which corresponds
65 /// to the smallest power of 2 greater than or equal to the number of
66 /// equivalence classes.
67 pub(crate) fn stride(&self) -> usize {
68 1 << self.stride2()
69 }
70
71 /// Returns true if and only if every byte in this class maps to its own
72 /// equivalence class. Equivalently, there are 257 equivalence classes
73 /// and each class contains exactly one byte (plus the special EOI class).
74 #[inline]
75 pub(crate) fn is_singleton(&self) -> bool {
76 self.alphabet_len() == 256
77 }
78
79 /// Returns an iterator over all equivalence classes in this set.
80 pub(crate) fn iter(&self) -> ByteClassIter {
81 ByteClassIter { it: 0..self.alphabet_len() }
82 }
83
84 /// Returns an iterator of the bytes in the given equivalence class.
85 pub(crate) fn elements(&self, class: u8) -> ByteClassElements {
86 ByteClassElements { classes: self, class, bytes: 0..=255 }
87 }
88
89 /// Returns an iterator of byte ranges in the given equivalence class.
90 ///
91 /// That is, a sequence of contiguous ranges are returned. Typically, every
92 /// class maps to a single contiguous range.
93 fn element_ranges(&self, class: u8) -> ByteClassElementRanges {
94 ByteClassElementRanges { elements: self.elements(class), range: None }
95 }
96}
97
98impl core::fmt::Debug for ByteClasses {
99 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
100 if self.is_singleton() {
101 write!(f, "ByteClasses(<one-class-per-byte>)")
102 } else {
103 write!(f, "ByteClasses(")?;
104 for (i, class) in self.iter().enumerate() {
105 if i > 0 {
106 write!(f, ", ")?;
107 }
108 write!(f, "{:?} => [", class)?;
109 for (start, end) in self.element_ranges(class) {
110 if start == end {
111 write!(f, "{:?}", start)?;
112 } else {
113 write!(f, "{:?}-{:?}", start, end)?;
114 }
115 }
116 write!(f, "]")?;
117 }
118 write!(f, ")")
119 }
120 }
121}
122
123/// An iterator over each equivalence class.
124#[derive(Debug)]
125pub(crate) struct ByteClassIter {
126 it: core::ops::Range<usize>,
127}
128
129impl Iterator for ByteClassIter {
130 type Item = u8;
131
132 fn next(&mut self) -> Option<u8> {
133 self.it.next().map(|class| class.as_u8())
134 }
135}
136
137/// An iterator over all elements in a specific equivalence class.
138#[derive(Debug)]
139pub(crate) struct ByteClassElements<'a> {
140 classes: &'a ByteClasses,
141 class: u8,
142 bytes: core::ops::RangeInclusive<u8>,
143}
144
145impl<'a> Iterator for ByteClassElements<'a> {
146 type Item = u8;
147
148 fn next(&mut self) -> Option<u8> {
149 while let Some(byte) = self.bytes.next() {
150 if self.class == self.classes.get(byte) {
151 return Some(byte);
152 }
153 }
154 None
155 }
156}
157
158/// An iterator over all elements in an equivalence class expressed as a
159/// sequence of contiguous ranges.
160#[derive(Debug)]
161pub(crate) struct ByteClassElementRanges<'a> {
162 elements: ByteClassElements<'a>,
163 range: Option<(u8, u8)>,
164}
165
166impl<'a> Iterator for ByteClassElementRanges<'a> {
167 type Item = (u8, u8);
168
169 fn next(&mut self) -> Option<(u8, u8)> {
170 loop {
171 let element = match self.elements.next() {
172 None => return self.range.take(),
173 Some(element) => element,
174 };
175 match self.range.take() {
176 None => {
177 self.range = Some((element, element));
178 }
179 Some((start, end)) => {
180 if usize::from(end) + 1 != usize::from(element) {
181 self.range = Some((element, element));
182 return Some((start, end));
183 }
184 self.range = Some((start, element));
185 }
186 }
187 }
188 }
189}
190
191/// A partitioning of bytes into equivalence classes.
192///
193/// A byte class set keeps track of an *approximation* of equivalence classes
194/// of bytes during NFA construction. That is, every byte in an equivalence
195/// class cannot discriminate between a match and a non-match.
196///
197/// Note that this may not compute the minimal set of equivalence classes.
198/// Basically, any byte in a pattern given to the noncontiguous NFA builder
199/// will automatically be treated as its own equivalence class. All other
200/// bytes---any byte not in any pattern---will be treated as their own
201/// equivalence classes. In theory, all bytes not in any pattern should
202/// be part of a single equivalence class, but in practice, we only treat
203/// contiguous ranges of bytes as an equivalence class. So the number of
204/// classes computed may be bigger than necessary. This usually doesn't make
205/// much of a difference, and keeps the implementation simple.
206#[derive(Clone, Debug)]
207pub(crate) struct ByteClassSet(ByteSet);
208
209impl Default for ByteClassSet {
210 fn default() -> ByteClassSet {
211 ByteClassSet::empty()
212 }
213}
214
215impl ByteClassSet {
216 /// Create a new set of byte classes where all bytes are part of the same
217 /// equivalence class.
218 pub(crate) fn empty() -> Self {
219 ByteClassSet(ByteSet::empty())
220 }
221
222 /// Indicate the the range of byte given (inclusive) can discriminate a
223 /// match between it and all other bytes outside of the range.
224 pub(crate) fn set_range(&mut self, start: u8, end: u8) {
225 debug_assert!(start <= end);
226 if start > 0 {
227 self.0.add(start - 1);
228 }
229 self.0.add(end);
230 }
231
232 /// Convert this boolean set to a map that maps all byte values to their
233 /// corresponding equivalence class. The last mapping indicates the largest
234 /// equivalence class identifier (which is never bigger than 255).
235 pub(crate) fn byte_classes(&self) -> ByteClasses {
236 let mut classes = ByteClasses::empty();
237 let mut class = 0u8;
238 let mut b = 0u8;
239 loop {
240 classes.set(b, class);
241 if b == 255 {
242 break;
243 }
244 if self.0.contains(b) {
245 class = class.checked_add(1).unwrap();
246 }
247 b = b.checked_add(1).unwrap();
248 }
249 classes
250 }
251}
252
/// A simple set of bytes that is reasonably cheap to copy and allocation free.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub(crate) struct ByteSet {
    bits: BitSet,
}

/// The representation of a byte set: 256 bits split across two `u128`
/// words. Bytes `0..=127` live in word 0 and bytes `128..=255` in word 1.
/// It is split out so that it can have a compact Debug impl while the
/// derived Debug output still shows "ByteSet".
#[derive(Clone, Copy, Default, Eq, PartialEq)]
struct BitSet([u128; 2]);

impl ByteSet {
    /// Create an empty set of bytes.
    pub(crate) fn empty() -> ByteSet {
        ByteSet { bits: BitSet([0, 0]) }
    }

    /// Insert `byte` into this set. Inserting a byte that is already present
    /// leaves the set unchanged.
    pub(crate) fn add(&mut self, byte: u8) {
        let (word, mask) = Self::slot(byte);
        self.bits.0[word] |= mask;
    }

    /// Report whether `byte` is a member of this set.
    pub(crate) fn contains(&self, byte: u8) -> bool {
        let (word, mask) = Self::slot(byte);
        self.bits.0[word] & mask != 0
    }

    /// Locate `byte`: returns the index of its word (the high bit of the
    /// byte) and a single-bit mask selecting it within that word (the low
    /// seven bits).
    fn slot(byte: u8) -> (usize, u128) {
        (usize::from(byte >> 7), 1u128 << (byte & 0x7F))
    }
}

impl core::fmt::Debug for BitSet {
    /// Prints the members as a set of byte values, e.g. `{0, 127, 255}`.
    fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
        let set = ByteSet { bits: *self };
        f.debug_set().entries((0u8..=255).filter(|&b| set.contains(b))).finish()
    }
}
298
#[cfg(test)]
mod tests {
    use alloc::{format, vec, vec::Vec};

    use super::*;

    #[test]
    fn byte_classes() {
        let mut set = ByteClassSet::empty();
        set.set_range(b'a', b'z');

        let classes = set.byte_classes();
        assert_eq!(classes.get(0), 0);
        assert_eq!(classes.get(1), 0);
        assert_eq!(classes.get(2), 0);
        assert_eq!(classes.get(b'a' - 1), 0);
        assert_eq!(classes.get(b'a'), 1);
        assert_eq!(classes.get(b'm'), 1);
        assert_eq!(classes.get(b'z'), 1);
        assert_eq!(classes.get(b'z' + 1), 2);
        assert_eq!(classes.get(254), 2);
        assert_eq!(classes.get(255), 2);

        let mut set = ByteClassSet::empty();
        set.set_range(0, 2);
        set.set_range(4, 6);
        let classes = set.byte_classes();
        assert_eq!(classes.get(0), 0);
        assert_eq!(classes.get(1), 0);
        assert_eq!(classes.get(2), 0);
        assert_eq!(classes.get(3), 1);
        assert_eq!(classes.get(4), 2);
        assert_eq!(classes.get(5), 2);
        assert_eq!(classes.get(6), 2);
        assert_eq!(classes.get(7), 3);
        assert_eq!(classes.get(255), 3);
    }

    #[test]
    fn full_byte_classes() {
        let mut set = ByteClassSet::empty();
        for b in 0u8..=255 {
            set.set_range(b, b);
        }
        assert_eq!(set.byte_classes().alphabet_len(), 256);
    }

    #[test]
    fn elements_typical() {
        let mut set = ByteClassSet::empty();
        set.set_range(b'b', b'd');
        set.set_range(b'g', b'm');
        set.set_range(b'z', b'z');
        let classes = set.byte_classes();
        // class 0: \x00-a
        // class 1: b-d
        // class 2: e-f
        // class 3: g-m
        // class 4: n-y
        // class 5: z-z
        // class 6: \x7B-\xFF
        assert_eq!(classes.alphabet_len(), 7);

        let elements = classes.elements(0).collect::<Vec<_>>();
        assert_eq!(elements.len(), 98);
        assert_eq!(elements[0], b'\x00');
        assert_eq!(elements[97], b'a');

        let elements = classes.elements(1).collect::<Vec<_>>();
        assert_eq!(elements, vec![b'b', b'c', b'd'],);

        let elements = classes.elements(2).collect::<Vec<_>>();
        assert_eq!(elements, vec![b'e', b'f'],);

        let elements = classes.elements(3).collect::<Vec<_>>();
        assert_eq!(elements, vec![b'g', b'h', b'i', b'j', b'k', b'l', b'm',],);

        let elements = classes.elements(4).collect::<Vec<_>>();
        assert_eq!(elements.len(), 12);
        assert_eq!(elements[0], b'n');
        assert_eq!(elements[11], b'y');

        let elements = classes.elements(5).collect::<Vec<_>>();
        assert_eq!(elements, vec![b'z']);

        let elements = classes.elements(6).collect::<Vec<_>>();
        assert_eq!(elements.len(), 133);
        assert_eq!(elements[0], b'\x7B');
        assert_eq!(elements[132], b'\xFF');
    }

    #[test]
    fn elements_singletons() {
        let classes = ByteClasses::singletons();
        assert_eq!(classes.alphabet_len(), 256);

        let elements = classes.elements(b'a').collect::<Vec<_>>();
        assert_eq!(elements, vec![b'a']);
    }

    #[test]
    fn elements_empty() {
        let classes = ByteClasses::empty();
        assert_eq!(classes.alphabet_len(), 1);

        let elements = classes.elements(0).collect::<Vec<_>>();
        assert_eq!(elements.len(), 256);
        assert_eq!(elements[0], b'\x00');
        assert_eq!(elements[255], b'\xFF');
    }

    #[test]
    fn stride_is_next_power_of_two() {
        let empty = ByteClasses::empty();
        assert_eq!(empty.stride2(), 0);
        assert_eq!(empty.stride(), 1);

        // Three classes round up to a stride of four.
        let mut set = ByteClassSet::empty();
        set.set_range(b'a', b'z');
        let classes = set.byte_classes();
        assert_eq!(classes.alphabet_len(), 3);
        assert_eq!(classes.stride2(), 2);
        assert_eq!(classes.stride(), 4);

        let singletons = ByteClasses::singletons();
        assert_eq!(singletons.stride2(), 8);
        assert_eq!(singletons.stride(), 256);
    }

    #[test]
    fn singleton_detection_and_iter() {
        assert!(ByteClasses::singletons().is_singleton());
        assert!(!ByteClasses::empty().is_singleton());

        assert_eq!(ByteClasses::empty().iter().collect::<Vec<_>>(), vec![0]);
        let all = ByteClasses::singletons().iter().collect::<Vec<_>>();
        assert_eq!(all.len(), 256);
        assert_eq!(all[0], 0);
        assert_eq!(all[255], 255);
    }

    #[test]
    fn element_ranges_contiguous_and_split() {
        let mut set = ByteClassSet::empty();
        set.set_range(b'a', b'z');
        let classes = set.byte_classes();
        assert_eq!(classes.element_ranges(0).collect::<Vec<_>>(), vec![(0, b'a' - 1)]);
        assert_eq!(classes.element_ranges(1).collect::<Vec<_>>(), vec![(b'a', b'z')]);
        assert_eq!(classes.element_ranges(2).collect::<Vec<_>>(), vec![(b'z' + 1, 255)]);

        // Hand-built classes need not be contiguous.
        let mut classes = ByteClasses::empty();
        classes.set(b'b', 1);
        classes.set(b'c', 1);
        classes.set(b'x', 1);
        assert_eq!(
            classes.element_ranges(1).collect::<Vec<_>>(),
            vec![(b'b', b'c'), (b'x', b'x')],
        );
        assert_eq!(
            classes.element_ranges(0).collect::<Vec<_>>(),
            vec![(0, b'a'), (b'd', b'w'), (b'y', 255)],
        );
        // A class with no members yields no ranges.
        assert_eq!(classes.element_ranges(2).count(), 0);
    }

    #[test]
    fn debug_output() {
        assert_eq!(
            format!("{:?}", ByteClasses::singletons()),
            "ByteClasses(<one-class-per-byte>)",
        );
        let mut set = ByteClassSet::empty();
        set.set_range(b'a', b'z');
        assert_eq!(
            format!("{:?}", set.byte_classes()),
            "ByteClasses(0 => [0-96], 1 => [97-122], 2 => [123-255])",
        );
    }

    #[test]
    fn byte_set_bucket_boundaries() {
        let mut set = ByteSet::empty();
        for &b in &[0u8, 127, 128, 255] {
            assert!(!set.contains(b));
            set.add(b);
            assert!(set.contains(b));
        }
        // Adding an existing member is a no-op.
        let before = set;
        set.add(128);
        assert_eq!(set, before);
        for b in 1u8..=126 {
            assert!(!set.contains(b));
        }
        for b in 129u8..=254 {
            assert!(!set.contains(b));
        }
    }
}
410