use core::{cmp, fmt, mem, u16, usize};

use alloc::{boxed::Box, string::String, vec, vec::Vec};

use crate::{
    packed::{api::MatchKind, ext::Pointer},
    PatternID,
};

/// A non-empty collection of non-empty patterns to search for.
///
/// This collection of patterns is what is passed around to both execute
/// searches and to construct the searchers themselves. Namely, this permits
/// searches to avoid copying all of the patterns, and allows us to keep only
/// one copy throughout all packed searchers.
///
/// Note that this collection is not a set. The same pattern can appear more
/// than once.
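///
/// As a rough sketch of how this type is used internally (the exact call
/// sequence depends on the packed builder, so treat this as illustrative
/// only):
///
/// ```ignore
/// let mut patterns = Patterns::new();
/// patterns.add(b"foo");
/// patterns.add(b"foobar");
/// patterns.set_match_kind(MatchKind::LeftmostLongest);
/// // Under leftmost-longest semantics, iteration yields "foobar" before
/// // "foo", since longer patterns take priority.
/// for (id, p) in patterns.iter() {
///     // use id and p.bytes()
/// }
/// ```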
#[derive(Clone, Debug)]
pub(crate) struct Patterns {
    /// The match semantics supported by this collection of patterns.
    ///
    /// The match semantics determines the order of the iterator over patterns.
    /// For leftmost-first, patterns are provided in the same order in which
    /// they were provided by the caller. For leftmost-longest, patterns are
    /// provided in descending order of length, with ties broken by the order
    /// in which they were provided by the caller.
    kind: MatchKind,
    /// The collection of patterns, indexed by their identifier.
    by_id: Vec<Vec<u8>>,
    /// The order of patterns defined for iteration, given by pattern
    /// identifiers. The order of `by_id` and `order` is always the same for
    /// leftmost-first semantics, but may be different for leftmost-longest
    /// semantics.
    order: Vec<PatternID>,
    /// The length of the smallest pattern, in bytes.
    minimum_len: usize,
    /// The total number of pattern bytes across the entire collection. This
    /// is used for reporting total heap usage in constant time.
    total_pattern_bytes: usize,
}

// BREADCRUMBS: I think we want to experiment with a different bucket
// representation. Basically, each bucket is just a Range<usize> to a single
// contiguous allocation? Maybe length-prefixed patterns or something? The
// idea is to try to get rid of the pointer chasing in verification. I don't
// know that that is the issue, but I suspect it is.

impl Patterns {
    /// Create a new, empty collection of patterns.
    ///
    /// The identifier of each pattern corresponds to the index at which it
    /// was added, which is also its position in the `by_id` slice.
    ///
    /// Note that a collection must have at least one pattern added to it
    /// (via `add`, which panics on empty patterns) before it is used for
    /// searching.
    pub(crate) fn new() -> Patterns {
        Patterns {
            kind: MatchKind::default(),
            by_id: vec![],
            order: vec![],
            minimum_len: usize::MAX,
            total_pattern_bytes: 0,
        }
    }

    /// Add a pattern to this collection.
    ///
    /// This panics if the pattern given is empty.
    pub(crate) fn add(&mut self, bytes: &[u8]) {
        assert!(!bytes.is_empty());
        assert!(self.by_id.len() <= u16::MAX as usize);

        let id = PatternID::new(self.by_id.len()).unwrap();
        self.order.push(id);
        self.by_id.push(bytes.to_vec());
        self.minimum_len = cmp::min(self.minimum_len, bytes.len());
        self.total_pattern_bytes += bytes.len();
    }

    /// Set the match kind semantics for this collection of patterns.
    ///
    /// If the kind is not set, then the default is leftmost-first.
    pub(crate) fn set_match_kind(&mut self, kind: MatchKind) {
        self.kind = kind;
        match self.kind {
            MatchKind::LeftmostFirst => {
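                // Identifiers are assigned in the order in which patterns
                // were added, so sorting by identifier restores the caller's
                // original order, which is exactly what leftmost-first
                // semantics require.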
                self.order.sort();
            }
            MatchKind::LeftmostLongest => {
                let (order, by_id) = (&mut self.order, &mut self.by_id);
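                // `sort_by` is a stable sort, so patterns with equal lengths
                // keep their existing (insertion) order. That provides the
                // tie-breaking behavior documented on the `kind` field above.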
                order.sort_by(|&id1, &id2| {
                    by_id[id1].len().cmp(&by_id[id2].len()).reverse()
                });
            }
        }
    }

    /// Return the number of patterns in this collection.
    ///
    /// This is guaranteed to be greater than zero.
    pub(crate) fn len(&self) -> usize {
        self.by_id.len()
    }

    /// Returns true if and only if this collection of patterns is empty.
    pub(crate) fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Returns the approximate total amount of heap used by these patterns, in
    /// units of bytes.
    pub(crate) fn memory_usage(&self) -> usize {
        self.order.len() * mem::size_of::<PatternID>()
            + self.by_id.len() * mem::size_of::<Vec<u8>>()
            + self.total_pattern_bytes
    }

    /// Clears all heap memory associated with this collection of patterns and
    /// resets all state such that it is a valid empty collection.
    pub(crate) fn reset(&mut self) {
        self.kind = MatchKind::default();
        self.by_id.clear();
        self.order.clear();
        self.minimum_len = usize::MAX;
        self.total_pattern_bytes = 0;
    }

    /// Returns the length, in bytes, of the smallest pattern.
    ///
    /// This is guaranteed to be at least one.
    pub(crate) fn minimum_len(&self) -> usize {
        self.minimum_len
    }

    /// Returns the match semantics used by these patterns.
    pub(crate) fn match_kind(&self) -> &MatchKind {
        &self.kind
    }

    /// Return the pattern with the given identifier. If such a pattern does
    /// not exist, then this panics.
    pub(crate) fn get(&self, id: PatternID) -> Pattern<'_> {
        Pattern(&self.by_id[id])
    }

    /// Return the pattern with the given identifier without performing bounds
    /// checks.
    ///
    /// # Safety
    ///
    /// Callers must ensure that a pattern with the given identifier exists
    /// before using this method.
    pub(crate) unsafe fn get_unchecked(&self, id: PatternID) -> Pattern<'_> {
        Pattern(self.by_id.get_unchecked(id.as_usize()))
    }

    /// Return an iterator over all the patterns in this collection, in the
    /// order in which they should be matched.
    ///
    /// Specifically, in a naive multi-pattern matcher, the following is
    /// guaranteed to satisfy the match semantics of this collection of
    /// patterns:
    ///
    /// ```ignore
    /// for i in 0..haystack.len():
    ///     for p in patterns.iter():
    ///         if haystack[i..].starts_with(p.bytes()):
    ///             return Match(p.id(), i, i + p.bytes().len())
    /// ```
    ///
    /// Namely, among the patterns in a collection, if they are matched in
    /// the order provided by this iterator, then the result is guaranteed
    /// to satisfy the correct match semantics. (Either leftmost-first or
    /// leftmost-longest.)
    pub(crate) fn iter(&self) -> PatternIter<'_> {
        PatternIter { patterns: self, i: 0 }
    }
}

/// An iterator over the patterns in the `Patterns` collection.
///
/// The order of the patterns provided by this iterator is consistent with the
/// match semantics of the originating collection of patterns.
///
/// The lifetime `'p` corresponds to the lifetime of the collection of patterns
/// this is iterating over.
#[derive(Debug)]
pub(crate) struct PatternIter<'p> {
    patterns: &'p Patterns,
    i: usize,
}

impl<'p> Iterator for PatternIter<'p> {
    type Item = (PatternID, Pattern<'p>);

    fn next(&mut self) -> Option<(PatternID, Pattern<'p>)> {
        if self.i >= self.patterns.len() {
            return None;
        }
        let id = self.patterns.order[self.i];
        let p = self.patterns.get(id);
        self.i += 1;
        Some((id, p))
    }
}

/// A pattern that is used in packed searching.
#[derive(Clone)]
pub(crate) struct Pattern<'a>(&'a [u8]);

impl<'a> fmt::Debug for Pattern<'a> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("Pattern")
            .field("lit", &String::from_utf8_lossy(&self.0))
            .finish()
    }
}

impl<'p> Pattern<'p> {
    /// Returns the length of this pattern, in bytes.
    pub(crate) fn len(&self) -> usize {
        self.0.len()
    }

    /// Returns the bytes of this pattern.
    pub(crate) fn bytes(&self) -> &[u8] {
        &self.0
    }

    /// Returns the first `len` low nybbles from this pattern. If this pattern
    /// is shorter than `len`, then this panics.
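    ///
    /// For example, the low nybbles of `b"AZ"` are `[0x1, 0xA]`, since
    /// `b'A' == 0x41` and `b'Z' == 0x5A`.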
    pub(crate) fn low_nybbles(&self, len: usize) -> Box<[u8]> {
        let mut nybs = vec![0; len].into_boxed_slice();
        for (i, byte) in self.bytes().iter().take(len).enumerate() {
            nybs[i] = byte & 0xF;
        }
        nybs
    }

    /// Returns true if this pattern is a prefix of the given bytes.
    #[inline(always)]
    pub(crate) fn is_prefix(&self, bytes: &[u8]) -> bool {
        is_prefix(bytes, self.bytes())
    }

    /// Returns true if this pattern is a prefix of the haystack given by the
    /// raw `start` and `end` pointers.
    ///
    /// # Safety
    ///
    /// * It must be the case that `start < end` and that the distance between
    /// them accurately reflects the length of the haystack. That is, it must
    /// be valid to read `end.distance(start)` bytes starting at `start`.
    /// * Both `start` and `end` must be valid for reads.
    /// * Both `start` and `end` must point to an initialized value.
    /// * Both `start` and `end` must point to the same allocated object and
    /// must either be in bounds or at most one byte past the end of the
    /// allocated object.
    /// * Both `start` and `end` must be _derived from_ a pointer to the same
    /// object.
    /// * The distance between `start` and `end` must not overflow `isize`.
    /// * The distance being in bounds must not rely on "wrapping around" the
    /// address space.
    #[inline(always)]
    pub(crate) unsafe fn is_prefix_raw(
        &self,
        start: *const u8,
        end: *const u8,
    ) -> bool {
        let patlen = self.bytes().len();
        let haylen = end.distance(start);
        if patlen > haylen {
            return false;
        }
        // SAFETY: We've checked that the haystack has length at least equal
        // to this pattern. All other safety concerns are the responsibility
        // of the caller.
        is_equal_raw(start, self.bytes().as_ptr(), patlen)
    }
}

/// Returns true if and only if `needle` is a prefix of `haystack`.
///
/// This uses a latency optimized variant of `memcmp` internally which *might*
/// make this faster for very short strings.
///
/// # Inlining
///
/// This routine is marked `inline(always)`. If you want to call this function
/// in a way that is not always inlined, you'll need to wrap a call to it in
/// another function that is marked as `inline(never)` or just `inline`.
#[inline(always)]
fn is_prefix(haystack: &[u8], needle: &[u8]) -> bool {
    if needle.len() > haystack.len() {
        return false;
    }
    // SAFETY: Our pointers are derived directly from borrowed slices which
    // uphold all of our safety guarantees except for length. We account for
    // length with the check above.
    unsafe { is_equal_raw(haystack.as_ptr(), needle.as_ptr(), needle.len()) }
}

/// Compare corresponding bytes in `x` and `y` for equality.
///
/// That is, this returns true if and only if `x.len() == y.len()` and
/// `x[i] == y[i]` for all `0 <= i < x.len()`.
///
/// Note that this isn't used in the main code. It only exists as a convenient
/// way of testing `is_equal_raw` in the tests below.
///
/// # Inlining
///
/// This routine is marked `inline(always)`. If you want to call this function
/// in a way that is not always inlined, you'll need to wrap a call to it in
/// another function that is marked as `inline(never)` or just `inline`.
///
/// # Motivation
///
/// Why not use slice equality instead? Well, slice equality usually results in
/// a call out to the current platform's `libc` which might not be inlineable
/// or have other overhead. This routine isn't guaranteed to be a win, but it
/// might be in some cases.
#[cfg(test)]
#[inline(always)]
fn is_equal(x: &[u8], y: &[u8]) -> bool {
    if x.len() != y.len() {
        return false;
    }
    // SAFETY: Our pointers are derived directly from borrowed slices which
    // uphold all of our safety guarantees except for length. We account for
    // length with the check above.
    unsafe { is_equal_raw(x.as_ptr(), y.as_ptr(), x.len()) }
}

/// Compare `n` bytes at the given pointers for equality.
///
/// This returns true if and only if `*x.add(i) == *y.add(i)` for all
/// `0 <= i < n`.
///
/// # Inlining
///
/// This routine is marked `inline(always)`. If you want to call this function
/// in a way that is not always inlined, you'll need to wrap a call to it in
/// another function that is marked as `inline(never)` or just `inline`.
///
/// # Motivation
///
/// Why not use slice equality instead? Well, slice equality usually results in
/// a call out to the current platform's `libc` which might not be inlineable
/// or have other overhead. This routine isn't guaranteed to be a win, but it
/// might be in some cases.
///
/// # Safety
///
/// * Both `x` and `y` must be valid for reads of up to `n` bytes.
/// * Both `x` and `y` must point to an initialized value.
/// * Both `x` and `y` must each point to an allocated object and
/// must either be in bounds or at most one byte past the end of the
/// allocated object. `x` and `y` do not need to point to the same allocated
/// object, but they may.
/// * Both `x` and `y` must be _derived from_ a pointer to their respective
/// allocated objects.
/// * The distance between `x` and `x+n` must not overflow `isize`. Similarly
/// for `y` and `y+n`.
/// * The distance being in bounds must not rely on "wrapping around" the
/// address space.
#[inline(always)]
unsafe fn is_equal_raw(mut x: *const u8, mut y: *const u8, n: usize) -> bool {
    // If we don't have enough bytes to do 4-byte at a time loads, then
    // handle each possible length specially. Note that I used to have a
    // byte-at-a-time loop here and that turned out to be quite a bit slower
    // for the memmem/pathological/defeat-simple-vector-alphabet benchmark.
    if n < 4 {
        return match n {
            0 => true,
            1 => x.read() == y.read(),
            2 => {
                x.cast::<u16>().read_unaligned()
                    == y.cast::<u16>().read_unaligned()
            }
            // I also tried copy_nonoverlapping here and it looks like the
            // codegen is the same.
            3 => x.cast::<[u8; 3]>().read() == y.cast::<[u8; 3]>().read(),
            _ => unreachable!(),
        };
    }
    // When we have 4 or more bytes to compare, then proceed in chunks of 4 at
    // a time using unaligned loads.
    //
    // Also, why do 4 byte loads instead of, say, 8 byte loads? The reason is
    // that this particular version of memcmp is likely to be called with tiny
    // needles. That means that if we do 8 byte loads, then a higher proportion
    // of memcmp calls will use the slower variant above. With that said, this
    // is a hypothesis and is only loosely supported by benchmarks. There's
    // likely some improvement that could be made here. The main thing here
    // though is to optimize for latency, not throughput.

    // SAFETY: The caller is responsible for ensuring the pointers we get are
    // valid and readable for at least `n` bytes. We also do unaligned loads,
    // so there's no need to ensure we're aligned. (This is justified by this
    // routine being specifically for short strings.)
    let xend = x.add(n.wrapping_sub(4));
    let yend = y.add(n.wrapping_sub(4));
    while x < xend {
        let vx = x.cast::<u32>().read_unaligned();
        let vy = y.cast::<u32>().read_unaligned();
        if vx != vy {
            return false;
        }
        x = x.add(4);
        y = y.add(4);
    }
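    // Finish by comparing the last four bytes with a load that ends exactly
    // at the last byte. This final load may overlap with bytes already
    // compared by the loop above, but any overlapping bytes are known to be
    // equal at this point, so re-checking them doesn't change the result.
    // (For example, with `n == 6`, the loop compares bytes `0..4` and this
    // final load compares bytes `2..6`.)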
    let vx = xend.cast::<u32>().read_unaligned();
    let vy = yend.cast::<u32>().read_unaligned();
    vx == vy
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn equals_different_lengths() {
        assert!(!is_equal(b"", b"a"));
        assert!(!is_equal(b"a", b""));
        assert!(!is_equal(b"ab", b"a"));
        assert!(!is_equal(b"a", b"ab"));
    }

    #[test]
    fn equals_mismatch() {
        let one_mismatch = [
            (&b"a"[..], &b"x"[..]),
            (&b"ab"[..], &b"ax"[..]),
            (&b"abc"[..], &b"abx"[..]),
            (&b"abcd"[..], &b"abcx"[..]),
            (&b"abcde"[..], &b"abcdx"[..]),
            (&b"abcdef"[..], &b"abcdex"[..]),
            (&b"abcdefg"[..], &b"abcdefx"[..]),
            (&b"abcdefgh"[..], &b"abcdefgx"[..]),
            (&b"abcdefghi"[..], &b"abcdefghx"[..]),
            (&b"abcdefghij"[..], &b"abcdefghix"[..]),
            (&b"abcdefghijk"[..], &b"abcdefghijx"[..]),
            (&b"abcdefghijkl"[..], &b"abcdefghijkx"[..]),
            (&b"abcdefghijklm"[..], &b"abcdefghijklx"[..]),
            (&b"abcdefghijklmn"[..], &b"abcdefghijklmx"[..]),
        ];
        for (x, y) in one_mismatch {
            assert_eq!(x.len(), y.len(), "lengths should match");
            assert!(!is_equal(x, y));
            assert!(!is_equal(y, x));
        }
    }

    #[test]
    fn equals_yes() {
        assert!(is_equal(b"", b""));
        assert!(is_equal(b"a", b"a"));
        assert!(is_equal(b"ab", b"ab"));
        assert!(is_equal(b"abc", b"abc"));
        assert!(is_equal(b"abcd", b"abcd"));
        assert!(is_equal(b"abcde", b"abcde"));
        assert!(is_equal(b"abcdef", b"abcdef"));
        assert!(is_equal(b"abcdefg", b"abcdefg"));
        assert!(is_equal(b"abcdefgh", b"abcdefgh"));
        assert!(is_equal(b"abcdefghi", b"abcdefghi"));
    }

    #[test]
    fn prefix() {
        assert!(is_prefix(b"", b""));
        assert!(is_prefix(b"a", b""));
        assert!(is_prefix(b"ab", b""));
        assert!(is_prefix(b"foo", b"foo"));
        assert!(is_prefix(b"foobar", b"foo"));

        assert!(!is_prefix(b"foo", b"fob"));
        assert!(!is_prefix(b"foobar", b"fob"));
    }
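
    // A sanity-check sketch of `Patterns` iteration order under
    // leftmost-longest semantics, using only the API defined above: longer
    // patterns are yielded first, with ties broken by insertion order. The
    // specific patterns used here are purely illustrative.
    #[test]
    fn patterns_leftmost_longest_order() {
        use crate::packed::api::MatchKind;

        let mut patterns = Patterns::new();
        patterns.add(b"a");
        patterns.add(b"abc");
        patterns.add(b"ab");
        patterns.set_match_kind(MatchKind::LeftmostLongest);

        let mut it = patterns.iter();
        let (id, p) = it.next().unwrap();
        assert_eq!((1, 3), (id.as_usize(), p.len()));
        let (id, p) = it.next().unwrap();
        assert_eq!((2, 2), (id.as_usize(), p.len()));
        let (id, p) = it.next().unwrap();
        assert_eq!((0, 1), (id.as_usize(), p.len()));
        assert!(it.next().is_none());
    }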
}