1 | use core::{cmp, fmt, mem, u16, usize}; |
2 | |
3 | use alloc::{boxed::Box, string::String, vec, vec::Vec}; |
4 | |
5 | use crate::{ |
6 | packed::{api::MatchKind, ext::Pointer}, |
7 | PatternID, |
8 | }; |
9 | |
/// A non-empty collection of non-empty patterns to search for.
///
/// This collection of patterns is what is passed around to both execute
/// searches and to construct the searchers themselves. Namely, this permits
/// searches to avoid copying all of the patterns, and allows us to keep only
/// one copy throughout all packed searchers.
///
/// Note that this collection is not a set. The same pattern can appear more
/// than once.
///
/// NOTE(review): `new()` below does produce an empty collection; the
/// "non-empty" invariant appears to be enforced by callers before searching —
/// confirm against the construction sites.
#[derive(Clone, Debug)]
pub(crate) struct Patterns {
    /// The match semantics supported by this collection of patterns.
    ///
    /// The match semantics determines the order of the iterator over patterns.
    /// For leftmost-first, patterns are provided in the same order as were
    /// provided by the caller. For leftmost-longest, patterns are provided in
    /// descending order of length, with ties broken by the order in which they
    /// were provided by the caller.
    kind: MatchKind,
    /// The collection of patterns, indexed by their identifier.
    by_id: Vec<Vec<u8>>,
    /// The order of patterns defined for iteration, given by pattern
    /// identifiers. The order of `by_id` and `order` is always the same for
    /// leftmost-first semantics, but may be different for leftmost-longest
    /// semantics.
    order: Vec<PatternID>,
    /// The length of the smallest pattern, in bytes.
    ///
    /// This is `usize::MAX` (the identity for `min`) while the collection is
    /// empty.
    minimum_len: usize,
    /// The total number of pattern bytes across the entire collection. This
    /// is used for reporting total heap usage in constant time.
    total_pattern_bytes: usize,
}
42 | |
43 | // BREADCRUMBS: I think we want to experiment with a different bucket |
44 | // representation. Basically, each bucket is just a Range<usize> to a single |
45 | // contiguous allocation? Maybe length-prefixed patterns or something? The |
46 | // idea is to try to get rid of the pointer chasing in verification. I don't |
47 | // know that that is the issue, but I suspect it is. |
48 | |
49 | impl Patterns { |
50 | /// Create a new collection of patterns for the given match semantics. The |
51 | /// ID of each pattern is the index of the pattern at which it occurs in |
52 | /// the `by_id` slice. |
53 | /// |
54 | /// If any of the patterns in the slice given are empty, then this panics. |
55 | /// Similarly, if the number of patterns given is zero, then this also |
56 | /// panics. |
57 | pub(crate) fn new() -> Patterns { |
58 | Patterns { |
59 | kind: MatchKind::default(), |
60 | by_id: vec![], |
61 | order: vec![], |
62 | minimum_len: usize::MAX, |
63 | total_pattern_bytes: 0, |
64 | } |
65 | } |
66 | |
67 | /// Add a pattern to this collection. |
68 | /// |
69 | /// This panics if the pattern given is empty. |
70 | pub(crate) fn add(&mut self, bytes: &[u8]) { |
71 | assert!(!bytes.is_empty()); |
72 | assert!(self.by_id.len() <= u16::MAX as usize); |
73 | |
74 | let id = PatternID::new(self.by_id.len()).unwrap(); |
75 | self.order.push(id); |
76 | self.by_id.push(bytes.to_vec()); |
77 | self.minimum_len = cmp::min(self.minimum_len, bytes.len()); |
78 | self.total_pattern_bytes += bytes.len(); |
79 | } |
80 | |
81 | /// Set the match kind semantics for this collection of patterns. |
82 | /// |
83 | /// If the kind is not set, then the default is leftmost-first. |
84 | pub(crate) fn set_match_kind(&mut self, kind: MatchKind) { |
85 | self.kind = kind; |
86 | match self.kind { |
87 | MatchKind::LeftmostFirst => { |
88 | self.order.sort(); |
89 | } |
90 | MatchKind::LeftmostLongest => { |
91 | let (order, by_id) = (&mut self.order, &mut self.by_id); |
92 | order.sort_by(|&id1, &id2| { |
93 | by_id[id1].len().cmp(&by_id[id2].len()).reverse() |
94 | }); |
95 | } |
96 | } |
97 | } |
98 | |
99 | /// Return the number of patterns in this collection. |
100 | /// |
101 | /// This is guaranteed to be greater than zero. |
102 | pub(crate) fn len(&self) -> usize { |
103 | self.by_id.len() |
104 | } |
105 | |
106 | /// Returns true if and only if this collection of patterns is empty. |
107 | pub(crate) fn is_empty(&self) -> bool { |
108 | self.len() == 0 |
109 | } |
110 | |
111 | /// Returns the approximate total amount of heap used by these patterns, in |
112 | /// units of bytes. |
113 | pub(crate) fn memory_usage(&self) -> usize { |
114 | self.order.len() * mem::size_of::<PatternID>() |
115 | + self.by_id.len() * mem::size_of::<Vec<u8>>() |
116 | + self.total_pattern_bytes |
117 | } |
118 | |
119 | /// Clears all heap memory associated with this collection of patterns and |
120 | /// resets all state such that it is a valid empty collection. |
121 | pub(crate) fn reset(&mut self) { |
122 | self.kind = MatchKind::default(); |
123 | self.by_id.clear(); |
124 | self.order.clear(); |
125 | self.minimum_len = usize::MAX; |
126 | } |
127 | |
128 | /// Returns the length, in bytes, of the smallest pattern. |
129 | /// |
130 | /// This is guaranteed to be at least one. |
131 | pub(crate) fn minimum_len(&self) -> usize { |
132 | self.minimum_len |
133 | } |
134 | |
135 | /// Returns the match semantics used by these patterns. |
136 | pub(crate) fn match_kind(&self) -> &MatchKind { |
137 | &self.kind |
138 | } |
139 | |
140 | /// Return the pattern with the given identifier. If such a pattern does |
141 | /// not exist, then this panics. |
142 | pub(crate) fn get(&self, id: PatternID) -> Pattern<'_> { |
143 | Pattern(&self.by_id[id]) |
144 | } |
145 | |
146 | /// Return the pattern with the given identifier without performing bounds |
147 | /// checks. |
148 | /// |
149 | /// # Safety |
150 | /// |
151 | /// Callers must ensure that a pattern with the given identifier exists |
152 | /// before using this method. |
153 | pub(crate) unsafe fn get_unchecked(&self, id: PatternID) -> Pattern<'_> { |
154 | Pattern(self.by_id.get_unchecked(id.as_usize())) |
155 | } |
156 | |
157 | /// Return an iterator over all the patterns in this collection, in the |
158 | /// order in which they should be matched. |
159 | /// |
160 | /// Specifically, in a naive multi-pattern matcher, the following is |
161 | /// guaranteed to satisfy the match semantics of this collection of |
162 | /// patterns: |
163 | /// |
164 | /// ```ignore |
165 | /// for i in 0..haystack.len(): |
166 | /// for p in patterns.iter(): |
167 | /// if haystack[i..].starts_with(p.bytes()): |
168 | /// return Match(p.id(), i, i + p.bytes().len()) |
169 | /// ``` |
170 | /// |
171 | /// Namely, among the patterns in a collection, if they are matched in |
172 | /// the order provided by this iterator, then the result is guaranteed |
173 | /// to satisfy the correct match semantics. (Either leftmost-first or |
174 | /// leftmost-longest.) |
175 | pub(crate) fn iter(&self) -> PatternIter<'_> { |
176 | PatternIter { patterns: self, i: 0 } |
177 | } |
178 | } |
179 | |
/// An iterator over the patterns in the `Patterns` collection.
///
/// The order of the patterns provided by this iterator is consistent with the
/// match semantics of the originating collection of patterns.
///
/// The lifetime `'p` corresponds to the lifetime of the collection of patterns
/// this is iterating over.
#[derive(Debug)]
pub(crate) struct PatternIter<'p> {
    /// The collection of patterns being iterated over.
    patterns: &'p Patterns,
    /// The current position into `patterns.order`.
    i: usize,
}
192 | |
193 | impl<'p> Iterator for PatternIter<'p> { |
194 | type Item = (PatternID, Pattern<'p>); |
195 | |
196 | fn next(&mut self) -> Option<(PatternID, Pattern<'p>)> { |
197 | if self.i >= self.patterns.len() { |
198 | return None; |
199 | } |
200 | let id = self.patterns.order[self.i]; |
201 | let p = self.patterns.get(id); |
202 | self.i += 1; |
203 | Some((id, p)) |
204 | } |
205 | } |
206 | |
/// A pattern that is used in packed searching.
///
/// This is a cheap wrapper around the bytes of a single pattern, borrowed
/// from its originating `Patterns` collection.
#[derive(Clone)]
pub(crate) struct Pattern<'a>(&'a [u8]);
210 | |
211 | impl<'a> fmt::Debug for Pattern<'a> { |
212 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { |
213 | f.debug_struct("Pattern" ) |
214 | .field("lit" , &String::from_utf8_lossy(&self.0)) |
215 | .finish() |
216 | } |
217 | } |
218 | |
impl<'p> Pattern<'p> {
    /// Returns the length of this pattern, in bytes.
    pub(crate) fn len(&self) -> usize {
        self.0.len()
    }

    /// Returns the bytes of this pattern.
    pub(crate) fn bytes(&self) -> &[u8] {
        &self.0
    }

    /// Returns the first `len` low nybbles from this pattern.
    ///
    /// If this pattern is shorter than `len`, the trailing entries are left
    /// as zero. (NOTE(review): an earlier comment claimed this panics for
    /// short patterns, but `take(len)` simply stops early and the buffer is
    /// zero-initialized, so no panic is possible here.)
    pub(crate) fn low_nybbles(&self, len: usize) -> Box<[u8]> {
        let mut nybs = vec![0; len].into_boxed_slice();
        for (i, byte) in self.bytes().iter().take(len).enumerate() {
            nybs[i] = byte & 0xF;
        }
        nybs
    }

    /// Returns true if this pattern is a prefix of the given bytes.
    #[inline(always)]
    pub(crate) fn is_prefix(&self, bytes: &[u8]) -> bool {
        is_prefix(bytes, self.bytes())
    }

    /// Returns true if this pattern is a prefix of the haystack given by the
    /// raw `start` and `end` pointers.
    ///
    /// # Safety
    ///
    /// * It must be the case that `start < end` and that the distance between
    /// them is at least equal to `V::BYTES`. That is, it must always be valid
    /// to do at least an unaligned load of `V` at `start`.
    /// * Both `start` and `end` must be valid for reads.
    /// * Both `start` and `end` must point to an initialized value.
    /// * Both `start` and `end` must point to the same allocated object and
    /// must either be in bounds or at most one byte past the end of the
    /// allocated object.
    /// * Both `start` and `end` must be _derived from_ a pointer to the same
    /// object.
    /// * The distance between `start` and `end` must not overflow `isize`.
    /// * The distance being in bounds must not rely on "wrapping around" the
    /// address space.
    #[inline(always)]
    pub(crate) unsafe fn is_prefix_raw(
        &self,
        start: *const u8,
        end: *const u8,
    ) -> bool {
        let patlen = self.bytes().len();
        // `distance` comes from the crate's pointer extension trait.
        let haylen = end.distance(start);
        if patlen > haylen {
            return false;
        }
        // SAFETY: We've checked that the haystack has length at least equal
        // to this pattern. All other safety concerns are the responsibility
        // of the caller.
        is_equal_raw(start, self.bytes().as_ptr(), patlen)
    }
}
281 | |
282 | /// Returns true if and only if `needle` is a prefix of `haystack`. |
283 | /// |
284 | /// This uses a latency optimized variant of `memcmp` internally which *might* |
285 | /// make this faster for very short strings. |
286 | /// |
287 | /// # Inlining |
288 | /// |
289 | /// This routine is marked `inline(always)`. If you want to call this function |
290 | /// in a way that is not always inlined, you'll need to wrap a call to it in |
291 | /// another function that is marked as `inline(never)` or just `inline`. |
292 | #[inline (always)] |
293 | fn is_prefix(haystack: &[u8], needle: &[u8]) -> bool { |
294 | if needle.len() > haystack.len() { |
295 | return false; |
296 | } |
297 | // SAFETY: Our pointers are derived directly from borrowed slices which |
298 | // uphold all of our safety guarantees except for length. We account for |
299 | // length with the check above. |
300 | unsafe { is_equal_raw(haystack.as_ptr(), needle.as_ptr(), needle.len()) } |
301 | } |
302 | |
/// Compare corresponding bytes in `x` and `y` for equality.
///
/// That is, this returns true if and only if `x.len() == y.len()` and
/// `x[i] == y[i]` for all `0 <= i < x.len()`.
///
/// Note that this isn't used. We only use it in tests as a convenient way
/// of testing `is_equal_raw`.
///
/// # Inlining
///
/// This routine is marked `inline(always)`. If you want to call this function
/// in a way that is not always inlined, you'll need to wrap a call to it in
/// another function that is marked as `inline(never)` or just `inline`.
///
/// # Motivation
///
/// Why not use slice equality instead? Well, slice equality usually results in
/// a call out to the current platform's `libc` which might not be inlineable
/// or have other overhead. This routine isn't guaranteed to be a win, but it
/// might be in some cases.
#[cfg(test)]
#[inline(always)]
fn is_equal(x: &[u8], y: &[u8]) -> bool {
    let n = x.len();
    if n != y.len() {
        return false;
    }
    // SAFETY: Both pointers come straight from live borrowed slices, which
    // satisfies every requirement of `is_equal_raw` other than the length
    // one, and the check above guarantees both slices hold `n` bytes.
    unsafe { is_equal_raw(x.as_ptr(), y.as_ptr(), n) }
}
334 | |
/// Compare `n` bytes at the given pointers for equality.
///
/// This returns true if and only if `*x.add(i) == *y.add(i)` for all
/// `0 <= i < n`.
///
/// # Inlining
///
/// This routine is marked `inline(always)`. If you want to call this function
/// in a way that is not always inlined, you'll need to wrap a call to it in
/// another function that is marked as `inline(never)` or just `inline`.
///
/// # Motivation
///
/// Why not use slice equality instead? Well, slice equality usually results in
/// a call out to the current platform's `libc` which might not be inlineable
/// or have other overhead. This routine isn't guaranteed to be a win, but it
/// might be in some cases.
///
/// # Safety
///
/// * Both `x` and `y` must be valid for reads of up to `n` bytes.
/// * Both `x` and `y` must point to an initialized value.
/// * Both `x` and `y` must each point to an allocated object and
/// must either be in bounds or at most one byte past the end of the
/// allocated object. `x` and `y` do not need to point to the same allocated
/// object, but they may.
/// * Both `x` and `y` must be _derived from_ a pointer to their respective
/// allocated objects.
/// * The distance between `x` and `x+n` must not overflow `isize`. Similarly
/// for `y` and `y+n`.
/// * The distance being in bounds must not rely on "wrapping around" the
/// address space.
#[inline(always)]
unsafe fn is_equal_raw(mut x: *const u8, mut y: *const u8, n: usize) -> bool {
    // Inputs too small for a 4-byte load get a bespoke comparison per
    // length. (A byte-at-a-time loop here previously measured quite a bit
    // slower on the memmem/pathological/defeat-simple-vector-alphabet
    // benchmark.)
    match n {
        0 => return true,
        1 => return x.read() == y.read(),
        2 => {
            return x.cast::<u16>().read_unaligned()
                == y.cast::<u16>().read_unaligned();
        }
        // copy_nonoverlapping was also tried here and the codegen looked
        // the same.
        3 => return x.cast::<[u8; 3]>().read() == y.cast::<[u8; 3]>().read(),
        _ => {}
    }
    // With at least 4 bytes available, march through both buffers in 4-byte
    // unaligned loads, finishing with one (possibly overlapping) load of the
    // final 4 bytes.
    //
    // Why 4-byte loads instead of, say, 8-byte loads? This memcmp variant is
    // likely to be called with tiny needles, so 8-byte loads would push a
    // higher proportion of calls into the slower per-length cases above.
    // That's a latency-over-throughput hypothesis only loosely supported by
    // benchmarks; there's likely room for improvement here.
    //
    // SAFETY: The caller guarantees both pointers are valid and readable for
    // at least `n` bytes (and `n >= 4` holds past the match above), and the
    // unaligned loads make alignment a non-issue.
    let xend = x.add(n - 4);
    let yend = y.add(n - 4);
    while x < xend {
        if x.cast::<u32>().read_unaligned() != y.cast::<u32>().read_unaligned()
        {
            return false;
        }
        x = x.add(4);
        y = y.add(4);
    }
    xend.cast::<u32>().read_unaligned() == yend.cast::<u32>().read_unaligned()
}
417 | |
#[cfg(test)]
mod tests {
    use super::*;

    // Length mismatches must short-circuit to false before any raw
    // comparison happens.
    #[test]
    fn equals_different_lengths() {
        assert!(!is_equal(b"", b"a"));
        assert!(!is_equal(b"a", b""));
        assert!(!is_equal(b"ab", b"a"));
        assert!(!is_equal(b"a", b"ab"));
    }

    // One mismatching byte (in the last position) at every length from 1
    // through 14, covering the specialized short cases (< 4 bytes), a single
    // 4-byte chunk, multiple chunks, and the overlapping tail load in
    // `is_equal_raw`.
    #[test]
    fn equals_mismatch() {
        let one_mismatch = [
            (&b"a"[..], &b"x"[..]),
            (&b"ab"[..], &b"ax"[..]),
            (&b"abc"[..], &b"abx"[..]),
            (&b"abcd"[..], &b"abcx"[..]),
            (&b"abcde"[..], &b"abcdx"[..]),
            (&b"abcdef"[..], &b"abcdex"[..]),
            (&b"abcdefg"[..], &b"abcdefx"[..]),
            (&b"abcdefgh"[..], &b"abcdefgx"[..]),
            (&b"abcdefghi"[..], &b"abcdefghx"[..]),
            (&b"abcdefghij"[..], &b"abcdefghix"[..]),
            (&b"abcdefghijk"[..], &b"abcdefghijx"[..]),
            (&b"abcdefghijkl"[..], &b"abcdefghijkx"[..]),
            (&b"abcdefghijklm"[..], &b"abcdefghijklx"[..]),
            (&b"abcdefghijklmn"[..], &b"abcdefghijklmx"[..]),
        ];
        for (x, y) in one_mismatch {
            assert_eq!(x.len(), y.len(), "lengths should match");
            assert!(!is_equal(x, y));
            assert!(!is_equal(y, x));
        }
    }

    // Equal inputs at every length from 0 through 14, mirroring the lengths
    // exercised by `equals_mismatch` above. (Previously this stopped at 9,
    // leaving the multi-chunk equal paths for lengths 10..=14 untested.)
    #[test]
    fn equals_yes() {
        assert!(is_equal(b"", b""));
        assert!(is_equal(b"a", b"a"));
        assert!(is_equal(b"ab", b"ab"));
        assert!(is_equal(b"abc", b"abc"));
        assert!(is_equal(b"abcd", b"abcd"));
        assert!(is_equal(b"abcde", b"abcde"));
        assert!(is_equal(b"abcdef", b"abcdef"));
        assert!(is_equal(b"abcdefg", b"abcdefg"));
        assert!(is_equal(b"abcdefgh", b"abcdefgh"));
        assert!(is_equal(b"abcdefghi", b"abcdefghi"));
        assert!(is_equal(b"abcdefghij", b"abcdefghij"));
        assert!(is_equal(b"abcdefghijk", b"abcdefghijk"));
        assert!(is_equal(b"abcdefghijkl", b"abcdefghijkl"));
        assert!(is_equal(b"abcdefghijklm", b"abcdefghijklm"));
        assert!(is_equal(b"abcdefghijklmn", b"abcdefghijklmn"));
    }

    // The empty needle is a prefix of everything; otherwise the needle must
    // match the haystack's leading bytes exactly.
    #[test]
    fn prefix() {
        assert!(is_prefix(b"", b""));
        assert!(is_prefix(b"a", b""));
        assert!(is_prefix(b"ab", b""));
        assert!(is_prefix(b"foo", b"foo"));
        assert!(is_prefix(b"foobar", b"foo"));

        assert!(!is_prefix(b"foo", b"fob"));
        assert!(!is_prefix(b"foobar", b"fob"));
    }
}
481 | |