1 | use crate::util::int::Usize; |
2 | |
3 | /// A representation of byte oriented equivalence classes. |
4 | /// |
5 | /// This is used in finite state machines to reduce the size of the transition |
6 | /// table. This can have a particularly large impact not only on the total size |
7 | /// of an FSM, but also on FSM build times because it reduces the number of |
8 | /// transitions that need to be visited/set. |
9 | #[derive(Clone, Copy)] |
10 | pub(crate) struct ByteClasses([u8; 256]); |
11 | |
12 | impl ByteClasses { |
13 | /// Creates a new set of equivalence classes where all bytes are mapped to |
14 | /// the same class. |
15 | pub(crate) fn empty() -> ByteClasses { |
16 | ByteClasses([0; 256]) |
17 | } |
18 | |
19 | /// Creates a new set of equivalence classes where each byte belongs to |
20 | /// its own equivalence class. |
21 | pub(crate) fn singletons() -> ByteClasses { |
22 | let mut classes = ByteClasses::empty(); |
23 | for b in 0..=255 { |
24 | classes.set(b, b); |
25 | } |
26 | classes |
27 | } |
28 | |
29 | /// Set the equivalence class for the given byte. |
30 | #[inline ] |
31 | pub(crate) fn set(&mut self, byte: u8, class: u8) { |
32 | self.0[usize::from(byte)] = class; |
33 | } |
34 | |
35 | /// Get the equivalence class for the given byte. |
36 | #[inline ] |
37 | pub(crate) fn get(&self, byte: u8) -> u8 { |
38 | self.0[usize::from(byte)] |
39 | } |
40 | |
41 | /// Return the total number of elements in the alphabet represented by |
42 | /// these equivalence classes. Equivalently, this returns the total number |
43 | /// of equivalence classes. |
44 | #[inline ] |
45 | pub(crate) fn alphabet_len(&self) -> usize { |
46 | // Add one since the number of equivalence classes is one bigger than |
47 | // the last one. |
48 | usize::from(self.0[255]) + 1 |
49 | } |
50 | |
51 | /// Returns the stride, as a base-2 exponent, required for these |
52 | /// equivalence classes. |
53 | /// |
54 | /// The stride is always the smallest power of 2 that is greater than or |
55 | /// equal to the alphabet length. This is done so that converting between |
56 | /// state IDs and indices can be done with shifts alone, which is much |
57 | /// faster than integer division. The "stride2" is the exponent. i.e., |
58 | /// `2^stride2 = stride`. |
59 | pub(crate) fn stride2(&self) -> usize { |
60 | let zeros = self.alphabet_len().next_power_of_two().trailing_zeros(); |
61 | usize::try_from(zeros).unwrap() |
62 | } |
63 | |
64 | /// Returns the stride for these equivalence classes, which corresponds |
65 | /// to the smallest power of 2 greater than or equal to the number of |
66 | /// equivalence classes. |
67 | pub(crate) fn stride(&self) -> usize { |
68 | 1 << self.stride2() |
69 | } |
70 | |
71 | /// Returns true if and only if every byte in this class maps to its own |
72 | /// equivalence class. Equivalently, there are 257 equivalence classes |
73 | /// and each class contains exactly one byte (plus the special EOI class). |
74 | #[inline ] |
75 | pub(crate) fn is_singleton(&self) -> bool { |
76 | self.alphabet_len() == 256 |
77 | } |
78 | |
79 | /// Returns an iterator over all equivalence classes in this set. |
80 | pub(crate) fn iter(&self) -> ByteClassIter { |
81 | ByteClassIter { it: 0..self.alphabet_len() } |
82 | } |
83 | |
84 | /// Returns an iterator of the bytes in the given equivalence class. |
85 | pub(crate) fn elements(&self, class: u8) -> ByteClassElements { |
86 | ByteClassElements { classes: self, class, bytes: 0..=255 } |
87 | } |
88 | |
89 | /// Returns an iterator of byte ranges in the given equivalence class. |
90 | /// |
91 | /// That is, a sequence of contiguous ranges are returned. Typically, every |
92 | /// class maps to a single contiguous range. |
93 | fn element_ranges(&self, class: u8) -> ByteClassElementRanges { |
94 | ByteClassElementRanges { elements: self.elements(class), range: None } |
95 | } |
96 | } |
97 | |
98 | impl core::fmt::Debug for ByteClasses { |
99 | fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { |
100 | if self.is_singleton() { |
101 | write!(f, "ByteClasses(<one-class-per-byte>)" ) |
102 | } else { |
103 | write!(f, "ByteClasses(" )?; |
104 | for (i, class) in self.iter().enumerate() { |
105 | if i > 0 { |
106 | write!(f, ", " )?; |
107 | } |
108 | write!(f, "{:?} => [" , class)?; |
109 | for (start, end) in self.element_ranges(class) { |
110 | if start == end { |
111 | write!(f, "{:?}" , start)?; |
112 | } else { |
113 | write!(f, "{:?}-{:?}" , start, end)?; |
114 | } |
115 | } |
116 | write!(f, "]" )?; |
117 | } |
118 | write!(f, ")" ) |
119 | } |
120 | } |
121 | } |
122 | |
123 | /// An iterator over each equivalence class. |
124 | #[derive(Debug)] |
125 | pub(crate) struct ByteClassIter { |
126 | it: core::ops::Range<usize>, |
127 | } |
128 | |
129 | impl Iterator for ByteClassIter { |
130 | type Item = u8; |
131 | |
132 | fn next(&mut self) -> Option<u8> { |
133 | self.it.next().map(|class| class.as_u8()) |
134 | } |
135 | } |
136 | |
137 | /// An iterator over all elements in a specific equivalence class. |
138 | #[derive(Debug)] |
139 | pub(crate) struct ByteClassElements<'a> { |
140 | classes: &'a ByteClasses, |
141 | class: u8, |
142 | bytes: core::ops::RangeInclusive<u8>, |
143 | } |
144 | |
145 | impl<'a> Iterator for ByteClassElements<'a> { |
146 | type Item = u8; |
147 | |
148 | fn next(&mut self) -> Option<u8> { |
149 | while let Some(byte) = self.bytes.next() { |
150 | if self.class == self.classes.get(byte) { |
151 | return Some(byte); |
152 | } |
153 | } |
154 | None |
155 | } |
156 | } |
157 | |
158 | /// An iterator over all elements in an equivalence class expressed as a |
159 | /// sequence of contiguous ranges. |
160 | #[derive(Debug)] |
161 | pub(crate) struct ByteClassElementRanges<'a> { |
162 | elements: ByteClassElements<'a>, |
163 | range: Option<(u8, u8)>, |
164 | } |
165 | |
166 | impl<'a> Iterator for ByteClassElementRanges<'a> { |
167 | type Item = (u8, u8); |
168 | |
169 | fn next(&mut self) -> Option<(u8, u8)> { |
170 | loop { |
171 | let element = match self.elements.next() { |
172 | None => return self.range.take(), |
173 | Some(element) => element, |
174 | }; |
175 | match self.range.take() { |
176 | None => { |
177 | self.range = Some((element, element)); |
178 | } |
179 | Some((start, end)) => { |
180 | if usize::from(end) + 1 != usize::from(element) { |
181 | self.range = Some((element, element)); |
182 | return Some((start, end)); |
183 | } |
184 | self.range = Some((start, element)); |
185 | } |
186 | } |
187 | } |
188 | } |
189 | } |
190 | |
191 | /// A partitioning of bytes into equivalence classes. |
192 | /// |
193 | /// A byte class set keeps track of an *approximation* of equivalence classes |
194 | /// of bytes during NFA construction. That is, every byte in an equivalence |
195 | /// class cannot discriminate between a match and a non-match. |
196 | /// |
197 | /// Note that this may not compute the minimal set of equivalence classes. |
198 | /// Basically, any byte in a pattern given to the noncontiguous NFA builder |
199 | /// will automatically be treated as its own equivalence class. All other |
200 | /// bytes---any byte not in any pattern---will be treated as their own |
201 | /// equivalence classes. In theory, all bytes not in any pattern should |
202 | /// be part of a single equivalence class, but in practice, we only treat |
203 | /// contiguous ranges of bytes as an equivalence class. So the number of |
204 | /// classes computed may be bigger than necessary. This usually doesn't make |
205 | /// much of a difference, and keeps the implementation simple. |
206 | #[derive(Clone, Debug)] |
207 | pub(crate) struct ByteClassSet(ByteSet); |
208 | |
209 | impl Default for ByteClassSet { |
210 | fn default() -> ByteClassSet { |
211 | ByteClassSet::empty() |
212 | } |
213 | } |
214 | |
215 | impl ByteClassSet { |
216 | /// Create a new set of byte classes where all bytes are part of the same |
217 | /// equivalence class. |
218 | pub(crate) fn empty() -> Self { |
219 | ByteClassSet(ByteSet::empty()) |
220 | } |
221 | |
222 | /// Indicate the the range of byte given (inclusive) can discriminate a |
223 | /// match between it and all other bytes outside of the range. |
224 | pub(crate) fn set_range(&mut self, start: u8, end: u8) { |
225 | debug_assert!(start <= end); |
226 | if start > 0 { |
227 | self.0.add(start - 1); |
228 | } |
229 | self.0.add(end); |
230 | } |
231 | |
232 | /// Convert this boolean set to a map that maps all byte values to their |
233 | /// corresponding equivalence class. The last mapping indicates the largest |
234 | /// equivalence class identifier (which is never bigger than 255). |
235 | pub(crate) fn byte_classes(&self) -> ByteClasses { |
236 | let mut classes = ByteClasses::empty(); |
237 | let mut class = 0u8; |
238 | let mut b = 0u8; |
239 | loop { |
240 | classes.set(b, class); |
241 | if b == 255 { |
242 | break; |
243 | } |
244 | if self.0.contains(b) { |
245 | class = class.checked_add(1).unwrap(); |
246 | } |
247 | b = b.checked_add(1).unwrap(); |
248 | } |
249 | classes |
250 | } |
251 | } |
252 | |
/// A small, allocation-free set of byte values that is cheap to copy.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub(crate) struct ByteSet {
    bits: BitSet,
}

/// The bit storage behind a `ByteSet`: 256 membership bits spread over two
/// `u128` words. It is a separate type so it can carry a hand-written `Debug`
/// impl while `ByteSet` keeps its derived one (and its name in the output).
#[derive(Clone, Copy, Default, Eq, PartialEq)]
struct BitSet([u128; 2]);

impl ByteSet {
    /// Returns a set containing no bytes.
    pub(crate) fn empty() -> ByteSet {
        ByteSet::default()
    }

    /// Inserts `byte` into the set. Inserting a byte that is already a
    /// member leaves the set unchanged.
    pub(crate) fn add(&mut self, byte: u8) {
        let (word, mask) = ByteSet::locate(byte);
        self.bits.0[word] |= mask;
    }

    /// Reports whether `byte` is a member of the set.
    pub(crate) fn contains(&self, byte: u8) -> bool {
        let (word, mask) = ByteSet::locate(byte);
        (self.bits.0[word] & mask) != 0
    }

    /// Maps a byte to the index of the word holding its bit plus a mask that
    /// selects that bit. Bytes 0-127 live in word 0, bytes 128-255 in word 1.
    fn locate(byte: u8) -> (usize, u128) {
        (usize::from(byte / 128), 1u128 << (byte % 128))
    }
}

impl core::fmt::Debug for BitSet {
    /// Lists the member bytes, in ascending order, as a set.
    fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
        let set = ByteSet { bits: *self };
        f.debug_set()
            .entries((0u8..=255).filter(|&b| set.contains(b)))
            .finish()
    }
}
298 | |
#[cfg(test)]
mod tests {
    use alloc::{vec, vec::Vec};

    use super::*;

    #[test]
    fn byte_classes() {
        let mut set = ByteClassSet::empty();
        set.set_range(b'a', b'z');

        let classes = set.byte_classes();
        assert_eq!(classes.get(0), 0);
        assert_eq!(classes.get(1), 0);
        assert_eq!(classes.get(2), 0);
        assert_eq!(classes.get(b'a' - 1), 0);
        assert_eq!(classes.get(b'a'), 1);
        assert_eq!(classes.get(b'm'), 1);
        assert_eq!(classes.get(b'z'), 1);
        assert_eq!(classes.get(b'z' + 1), 2);
        assert_eq!(classes.get(254), 2);
        assert_eq!(classes.get(255), 2);

        let mut set = ByteClassSet::empty();
        set.set_range(0, 2);
        set.set_range(4, 6);
        let classes = set.byte_classes();
        assert_eq!(classes.get(0), 0);
        assert_eq!(classes.get(1), 0);
        assert_eq!(classes.get(2), 0);
        assert_eq!(classes.get(3), 1);
        assert_eq!(classes.get(4), 2);
        assert_eq!(classes.get(5), 2);
        assert_eq!(classes.get(6), 2);
        assert_eq!(classes.get(7), 3);
        assert_eq!(classes.get(255), 3);
    }

    #[test]
    fn full_byte_classes() {
        let mut set = ByteClassSet::empty();
        for b in 0u8..=255 {
            set.set_range(b, b);
        }
        assert_eq!(set.byte_classes().alphabet_len(), 256);
    }

    #[test]
    fn elements_typical() {
        let mut set = ByteClassSet::empty();
        set.set_range(b'b', b'd');
        set.set_range(b'g', b'm');
        set.set_range(b'z', b'z');
        let classes = set.byte_classes();
        // class 0: \x00-a
        // class 1: b-d
        // class 2: e-f
        // class 3: g-m
        // class 4: n-y
        // class 5: z-z
        // class 6: \x7B-\xFF
        assert_eq!(classes.alphabet_len(), 7);

        let elements = classes.elements(0).collect::<Vec<_>>();
        assert_eq!(elements.len(), 98);
        assert_eq!(elements[0], b'\x00');
        assert_eq!(elements[97], b'a');

        let elements = classes.elements(1).collect::<Vec<_>>();
        assert_eq!(elements, vec![b'b', b'c', b'd'],);

        let elements = classes.elements(2).collect::<Vec<_>>();
        assert_eq!(elements, vec![b'e', b'f'],);

        let elements = classes.elements(3).collect::<Vec<_>>();
        assert_eq!(elements, vec![b'g', b'h', b'i', b'j', b'k', b'l', b'm',],);

        let elements = classes.elements(4).collect::<Vec<_>>();
        assert_eq!(elements.len(), 12);
        assert_eq!(elements[0], b'n');
        assert_eq!(elements[11], b'y');

        let elements = classes.elements(5).collect::<Vec<_>>();
        assert_eq!(elements, vec![b'z']);

        let elements = classes.elements(6).collect::<Vec<_>>();
        assert_eq!(elements.len(), 133);
        assert_eq!(elements[0], b'\x7B');
        assert_eq!(elements[132], b'\xFF');
    }

    #[test]
    fn elements_singletons() {
        let classes = ByteClasses::singletons();
        assert_eq!(classes.alphabet_len(), 256);

        let elements = classes.elements(b'a').collect::<Vec<_>>();
        assert_eq!(elements, vec![b'a']);
    }

    #[test]
    fn elements_empty() {
        let classes = ByteClasses::empty();
        assert_eq!(classes.alphabet_len(), 1);

        let elements = classes.elements(0).collect::<Vec<_>>();
        assert_eq!(elements.len(), 256);
        assert_eq!(elements[0], b'\x00');
        assert_eq!(elements[255], b'\xFF');
    }

    // The stride is the alphabet length rounded up to a power of two.
    #[test]
    fn stride_rounds_up_to_power_of_two() {
        let classes = ByteClasses::empty();
        assert_eq!(classes.stride2(), 0);
        assert_eq!(classes.stride(), 1);
        assert!(!classes.is_singleton());

        let classes = ByteClasses::singletons();
        assert_eq!(classes.stride2(), 8);
        assert_eq!(classes.stride(), 256);
        assert!(classes.is_singleton());

        let mut set = ByteClassSet::empty();
        set.set_range(b'b', b'd');
        set.set_range(b'g', b'm');
        set.set_range(b'z', b'z');
        let classes = set.byte_classes();
        assert_eq!(classes.alphabet_len(), 7);
        assert_eq!(classes.stride2(), 3);
        assert_eq!(classes.stride(), 8);
        assert!(!classes.is_singleton());
    }

    #[test]
    fn element_ranges_contiguous_and_split() {
        let mut set = ByteClassSet::empty();
        set.set_range(b'a', b'z');
        let classes = set.byte_classes();
        assert_eq!(classes.element_ranges(0).collect::<Vec<_>>(), vec![(0, b'a' - 1)]);
        assert_eq!(classes.element_ranges(1).collect::<Vec<_>>(), vec![(b'a', b'z')]);
        assert_eq!(classes.element_ranges(2).collect::<Vec<_>>(), vec![(b'z' + 1, 255)]);

        // Classes assigned by hand need not be contiguous.
        let mut classes = ByteClasses::empty();
        classes.set(5, 1);
        classes.set(255, 1);
        assert_eq!(classes.element_ranges(0).collect::<Vec<_>>(), vec![(0, 4), (6, 254)]);
        assert_eq!(classes.element_ranges(1).collect::<Vec<_>>(), vec![(5, 5), (255, 255)]);
    }

    // Bytes 127 and 128 sit on either side of the two-word split.
    #[test]
    fn byte_set_bucket_boundaries() {
        let mut set = ByteSet::empty();
        for b in [0u8, 127, 128, 255] {
            set.add(b);
        }
        for b in 0u8..=255 {
            assert_eq!(set.contains(b), matches!(b, 0 | 127 | 128 | 255), "byte {}", b);
        }
    }
}
410 | |