use core::{cmp, fmt, mem, u16, usize};

use alloc::{boxed::Box, string::String, vec, vec::Vec};

use crate::{
    packed::{api::MatchKind, ext::Pointer},
    PatternID,
};

/// A non-empty collection of non-empty patterns to search for.
///
/// This collection of patterns is what is passed around to both execute
/// searches and to construct the searchers themselves. Namely, this permits
/// searches to avoid copying all of the patterns, and allows us to keep only
/// one copy throughout all packed searchers.
///
/// Note that this collection is not a set. The same pattern can appear more
/// than once.
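///
/// A brief usage sketch (marked `ignore` since this type is crate-internal),
/// using only the APIs defined below:
///
/// ```ignore
/// let mut patterns = Patterns::new();
/// patterns.add(b"foo");
/// patterns.add(b"foobar");
/// patterns.set_match_kind(MatchKind::LeftmostLongest);
/// // Leftmost-longest semantics iterate longer patterns first, so
/// // b"foobar" is yielded before b"foo".
/// let (_, first) = patterns.iter().next().unwrap();
/// assert_eq!(first.bytes(), b"foobar");
/// ```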
#[derive(Clone, Debug)]
pub(crate) struct Patterns {
    /// The match semantics supported by this collection of patterns.
    ///
    /// The match semantics determine the order of the iterator over
    /// patterns. For leftmost-first, patterns are provided in the same order
    /// as they were provided by the caller. For leftmost-longest, patterns
    /// are provided in descending order of length, with ties broken by the
    /// order in which they were provided by the caller.
    kind: MatchKind,
    /// The collection of patterns, indexed by their identifier.
    by_id: Vec<Vec<u8>>,
    /// The order of patterns defined for iteration, given by pattern
    /// identifiers. The order of `by_id` and `order` is always the same for
    /// leftmost-first semantics, but may be different for leftmost-longest
    /// semantics.
    order: Vec<PatternID>,
    /// The length of the smallest pattern, in bytes.
    minimum_len: usize,
    /// The total number of pattern bytes across the entire collection. This
    /// is used for reporting total heap usage in constant time.
    total_pattern_bytes: usize,
}

// BREADCRUMBS: I think we want to experiment with a different bucket
// representation. Basically, each bucket is just a Range<usize> into a
// single contiguous allocation? Maybe length-prefixed patterns or something?
// The idea is to try to get rid of the pointer chasing in verification. I
// don't know that that is the issue, but I suspect it is.

impl Patterns {
    /// Create a new, empty collection of patterns. The match semantics
    /// default to leftmost-first and may be changed with `set_match_kind`.
    ///
    /// The ID of each pattern is the index at which it occurs in the
    /// `by_id` slice, i.e., the order in which it was added via `add`. Note
    /// that `add` panics on an empty pattern, so a fully constructed
    /// collection contains only non-empty patterns.
    pub(crate) fn new() -> Patterns {
        Patterns {
            kind: MatchKind::default(),
            by_id: vec![],
            order: vec![],
            minimum_len: usize::MAX,
            total_pattern_bytes: 0,
        }
    }

    /// Add a pattern to this collection.
    ///
    /// This panics if the pattern given is empty, or if adding it would
    /// exceed the maximum number of supported patterns.
    pub(crate) fn add(&mut self, bytes: &[u8]) {
        assert!(!bytes.is_empty());
        assert!(self.by_id.len() <= u16::MAX as usize);

        let id = PatternID::new(self.by_id.len()).unwrap();
        self.order.push(id);
        self.by_id.push(bytes.to_vec());
        self.minimum_len = cmp::min(self.minimum_len, bytes.len());
        self.total_pattern_bytes += bytes.len();
    }

    /// Set the match kind semantics for this collection of patterns.
    ///
    /// If the kind is not set, then the default is leftmost-first.
    pub(crate) fn set_match_kind(&mut self, kind: MatchKind) {
        self.kind = kind;
        match self.kind {
            MatchKind::LeftmostFirst => {
                self.order.sort();
            }
            MatchKind::LeftmostLongest => {
                let (order, by_id) = (&mut self.order, &mut self.by_id);
                order.sort_by(|&id1, &id2| {
                    by_id[id1].len().cmp(&by_id[id2].len()).reverse()
                });
            }
        }
    }

    /// Return the number of patterns in this collection.
    ///
    /// This is guaranteed to be greater than zero.
    pub(crate) fn len(&self) -> usize {
        self.by_id.len()
    }

    /// Returns true if and only if this collection of patterns is empty.
    pub(crate) fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Returns the approximate total amount of heap used by these patterns,
    /// in units of bytes.
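    ///
    /// For example, on a typical 64-bit target where `Vec<u8>` occupies 24
    /// bytes and, assuming `PatternID` is a `u32` wrapper (4 bytes), a
    /// collection holding `b"ab"` and `b"c"` reports
    /// `2*4 + 2*24 + 3 = 59` bytes.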
    pub(crate) fn memory_usage(&self) -> usize {
        self.order.len() * mem::size_of::<PatternID>()
            + self.by_id.len() * mem::size_of::<Vec<u8>>()
            + self.total_pattern_bytes
    }

    /// Clears all heap memory associated with this collection of patterns
    /// and resets all state such that it is a valid empty collection.
    pub(crate) fn reset(&mut self) {
        self.kind = MatchKind::default();
        self.by_id.clear();
        self.order.clear();
        self.minimum_len = usize::MAX;
        // Without this, `memory_usage` would over-report after a reset.
        self.total_pattern_bytes = 0;
    }

    /// Returns the length, in bytes, of the smallest pattern.
    ///
    /// This is guaranteed to be at least one.
    pub(crate) fn minimum_len(&self) -> usize {
        self.minimum_len
    }

    /// Returns the match semantics used by these patterns.
    pub(crate) fn match_kind(&self) -> &MatchKind {
        &self.kind
    }

    /// Return the pattern with the given identifier. If such a pattern does
    /// not exist, then this panics.
    pub(crate) fn get(&self, id: PatternID) -> Pattern<'_> {
        Pattern(&self.by_id[id])
    }

    /// Return the pattern with the given identifier without performing
    /// bounds checks.
    ///
    /// # Safety
    ///
    /// Callers must ensure that a pattern with the given identifier exists
    /// before using this method.
    pub(crate) unsafe fn get_unchecked(&self, id: PatternID) -> Pattern<'_> {
        Pattern(self.by_id.get_unchecked(id.as_usize()))
    }

    /// Return an iterator over all the patterns in this collection, in the
    /// order in which they should be matched.
    ///
    /// Specifically, in a naive multi-pattern matcher, the following is
    /// guaranteed to satisfy the match semantics of this collection of
    /// patterns:
    ///
    /// ```ignore
    /// for i in 0..haystack.len():
    ///   for p in patterns.iter():
    ///     if haystack[i..].starts_with(p.bytes()):
    ///       return Match(p.id(), i, i + p.bytes().len())
    /// ```
    ///
    /// Namely, among the patterns in a collection, if they are matched in
    /// the order provided by this iterator, then the result is guaranteed
    /// to satisfy the correct match semantics. (Either leftmost-first or
    /// leftmost-longest.)
    pub(crate) fn iter(&self) -> PatternIter<'_> {
        PatternIter { patterns: self, i: 0 }
    }
}

/// An iterator over the patterns in the `Patterns` collection.
///
/// The order of the patterns provided by this iterator is consistent with
/// the match semantics of the originating collection of patterns.
///
/// The lifetime `'p` corresponds to the lifetime of the collection of
/// patterns this is iterating over.
#[derive(Debug)]
pub(crate) struct PatternIter<'p> {
    patterns: &'p Patterns,
    i: usize,
}

impl<'p> Iterator for PatternIter<'p> {
    type Item = (PatternID, Pattern<'p>);

    fn next(&mut self) -> Option<(PatternID, Pattern<'p>)> {
        if self.i >= self.patterns.len() {
            return None;
        }
        let id = self.patterns.order[self.i];
        let p = self.patterns.get(id);
        self.i += 1;
        Some((id, p))
    }
}

/// A pattern that is used in packed searching.
#[derive(Clone)]
pub(crate) struct Pattern<'a>(&'a [u8]);

impl<'a> fmt::Debug for Pattern<'a> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("Pattern")
            .field("lit", &String::from_utf8_lossy(&self.0))
            .finish()
    }
}

impl<'p> Pattern<'p> {
    /// Returns the length of this pattern, in bytes.
    pub(crate) fn len(&self) -> usize {
        self.0.len()
    }

    /// Returns the bytes of this pattern.
    pub(crate) fn bytes(&self) -> &[u8] {
        &self.0
    }

    /// Returns the first `len` low nybbles from this pattern. If this
    /// pattern is shorter than `len`, then the remaining nybbles are zero.
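    ///
    /// For example (a small sketch using this module's own types), the low
    /// nybble of a byte is its bottom four bits:
    ///
    /// ```ignore
    /// let p = Pattern(&b"\x12\xAB"[..]);
    /// assert_eq!(&*p.low_nybbles(2), &[0x2, 0xB]);
    /// ```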
    pub(crate) fn low_nybbles(&self, len: usize) -> Box<[u8]> {
        let mut nybs = vec![0; len].into_boxed_slice();
        for (i, byte) in self.bytes().iter().take(len).enumerate() {
            nybs[i] = byte & 0xF;
        }
        nybs
    }

    /// Returns true if this pattern is a prefix of the given bytes.
    #[inline(always)]
    pub(crate) fn is_prefix(&self, bytes: &[u8]) -> bool {
        is_prefix(bytes, self.bytes())
    }

    /// Returns true if this pattern is a prefix of the haystack given by the
    /// raw `start` and `end` pointers.
    ///
    /// # Safety
    ///
    /// * It must be the case that `start <= end`, so that the distance
    /// between them is well defined.
    /// * Both `start` and `end` must be valid for reads.
    /// * Both `start` and `end` must point to an initialized value.
    /// * Both `start` and `end` must point to the same allocated object and
    /// must either be in bounds or at most one byte past the end of the
    /// allocated object.
    /// * Both `start` and `end` must be _derived from_ a pointer to the same
    /// object.
    /// * The distance between `start` and `end` must not overflow `isize`.
    /// * The distance being in bounds must not rely on "wrapping around" the
    /// address space.
    #[inline(always)]
    pub(crate) unsafe fn is_prefix_raw(
        &self,
        start: *const u8,
        end: *const u8,
    ) -> bool {
        let patlen = self.bytes().len();
        let haylen = end.distance(start);
        if patlen > haylen {
            return false;
        }
        // SAFETY: We've checked that the haystack has length at least equal
        // to this pattern. All other safety concerns are the responsibility
        // of the caller.
        is_equal_raw(start, self.bytes().as_ptr(), patlen)
    }
}

/// Returns true if and only if `needle` is a prefix of `haystack`.
///
/// This uses a latency optimized variant of `memcmp` internally which
/// *might* make this faster for very short strings.
///
/// # Inlining
///
/// This routine is marked `inline(always)`. If you want to call this
/// function in a way that is not always inlined, you'll need to wrap a call
/// to it in another function that is marked as `inline(never)` or just
/// `inline`.
#[inline(always)]
fn is_prefix(haystack: &[u8], needle: &[u8]) -> bool {
    if needle.len() > haystack.len() {
        return false;
    }
    // SAFETY: Our pointers are derived directly from borrowed slices which
    // uphold all of our safety guarantees except for length. We account for
    // length with the check above.
    unsafe { is_equal_raw(haystack.as_ptr(), needle.as_ptr(), needle.len()) }
}

/// Compare corresponding bytes in `x` and `y` for equality.
///
/// That is, this returns true if and only if `x.len() == y.len()` and
/// `x[i] == y[i]` for all `0 <= i < x.len()`.
///
/// Note that this isn't used outside of tests; it exists only as a
/// convenient way of testing `is_equal_raw`.
///
/// # Inlining
///
/// This routine is marked `inline(always)`. If you want to call this
/// function in a way that is not always inlined, you'll need to wrap a call
/// to it in another function that is marked as `inline(never)` or just
/// `inline`.
///
/// # Motivation
///
/// Why not use slice equality instead? Well, slice equality usually results
/// in a call out to the current platform's `libc` which might not be
/// inlineable or have other overhead. This routine isn't guaranteed to be a
/// win, but it might be in some cases.
#[cfg(test)]
#[inline(always)]
fn is_equal(x: &[u8], y: &[u8]) -> bool {
    if x.len() != y.len() {
        return false;
    }
    // SAFETY: Our pointers are derived directly from borrowed slices which
    // uphold all of our safety guarantees except for length. We account for
    // length with the check above.
    unsafe { is_equal_raw(x.as_ptr(), y.as_ptr(), x.len()) }
}

/// Compare `n` bytes at the given pointers for equality.
///
/// This returns true if and only if `*x.add(i) == *y.add(i)` for all
/// `0 <= i < n`.
///
/// # Inlining
///
/// This routine is marked `inline(always)`. If you want to call this
/// function in a way that is not always inlined, you'll need to wrap a call
/// to it in another function that is marked as `inline(never)` or just
/// `inline`.
///
/// # Motivation
///
/// Why not use slice equality instead? Well, slice equality usually results
/// in a call out to the current platform's `libc` which might not be
/// inlineable or have other overhead. This routine isn't guaranteed to be a
/// win, but it might be in some cases.
///
/// # Safety
///
/// * Both `x` and `y` must be valid for reads of up to `n` bytes.
/// * Both `x` and `y` must point to an initialized value.
/// * Both `x` and `y` must each point to an allocated object and
/// must either be in bounds or at most one byte past the end of the
/// allocated object. `x` and `y` do not need to point to the same allocated
/// object, but they may.
/// * Both `x` and `y` must be _derived from_ a pointer to their respective
/// allocated objects.
/// * The distance between `x` and `x+n` must not overflow `isize`. Similarly
/// for `y` and `y+n`.
/// * The distance being in bounds must not rely on "wrapping around" the
/// address space.
#[inline(always)]
unsafe fn is_equal_raw(mut x: *const u8, mut y: *const u8, n: usize) -> bool {
    // If we don't have enough bytes to do 4-byte at a time loads, then
    // handle each possible length specially. Note that I used to have a
    // byte-at-a-time loop here and that turned out to be quite a bit slower
    // for the memmem/pathological/defeat-simple-vector-alphabet benchmark.
    if n < 4 {
        return match n {
            0 => true,
            1 => x.read() == y.read(),
            2 => {
                x.cast::<u16>().read_unaligned()
                    == y.cast::<u16>().read_unaligned()
            }
            // I also tried copy_nonoverlapping here and it looks like the
            // codegen is the same.
            3 => x.cast::<[u8; 3]>().read() == y.cast::<[u8; 3]>().read(),
            _ => unreachable!(),
        };
    }
    // When we have 4 or more bytes to compare, then proceed in chunks of 4
    // at a time using unaligned loads.
    //
    // Also, why do 4 byte loads instead of, say, 8 byte loads? The reason is
    // that this particular version of memcmp is likely to be called with
    // tiny needles. That means that if we do 8 byte loads, then a higher
    // proportion of memcmp calls will use the slower variant above. With
    // that said, this is a hypothesis and is only loosely supported by
    // benchmarks. There's likely some improvement that could be made here.
    // The main thing here though is to optimize for latency, not throughput.
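    //
    // As a worked example of how the loop plus the final overlapping load
    // cover every byte: for n=6, the loop compares bytes [0, 4), and the
    // trailing loads below compare bytes [2, 6). Bytes 2 and 3 are compared
    // twice, which is harmless, and no byte is skipped.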

    // SAFETY: The caller is responsible for ensuring the pointers we get are
    // valid and readable for at least `n` bytes. We also do unaligned loads,
    // so there's no need to ensure we're aligned. (This is justified by this
    // routine being specifically for short strings.)
    let xend = x.add(n.wrapping_sub(4));
    let yend = y.add(n.wrapping_sub(4));
    while x < xend {
        let vx = x.cast::<u32>().read_unaligned();
        let vy = y.cast::<u32>().read_unaligned();
        if vx != vy {
            return false;
        }
        x = x.add(4);
        y = y.add(4);
    }
    let vx = xend.cast::<u32>().read_unaligned();
    let vy = yend.cast::<u32>().read_unaligned();
    vx == vy
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn equals_different_lengths() {
        assert!(!is_equal(b"", b"a"));
        assert!(!is_equal(b"a", b""));
        assert!(!is_equal(b"ab", b"a"));
        assert!(!is_equal(b"a", b"ab"));
    }

    #[test]
    fn equals_mismatch() {
        let one_mismatch = [
            (&b"a"[..], &b"x"[..]),
            (&b"ab"[..], &b"ax"[..]),
            (&b"abc"[..], &b"abx"[..]),
            (&b"abcd"[..], &b"abcx"[..]),
            (&b"abcde"[..], &b"abcdx"[..]),
            (&b"abcdef"[..], &b"abcdex"[..]),
            (&b"abcdefg"[..], &b"abcdefx"[..]),
            (&b"abcdefgh"[..], &b"abcdefgx"[..]),
            (&b"abcdefghi"[..], &b"abcdefghx"[..]),
            (&b"abcdefghij"[..], &b"abcdefghix"[..]),
            (&b"abcdefghijk"[..], &b"abcdefghijx"[..]),
            (&b"abcdefghijkl"[..], &b"abcdefghijkx"[..]),
            (&b"abcdefghijklm"[..], &b"abcdefghijklx"[..]),
            (&b"abcdefghijklmn"[..], &b"abcdefghijklmx"[..]),
        ];
        for (x, y) in one_mismatch {
            assert_eq!(x.len(), y.len(), "lengths should match");
            assert!(!is_equal(x, y));
            assert!(!is_equal(y, x));
        }
    }

    #[test]
    fn equals_yes() {
        assert!(is_equal(b"", b""));
        assert!(is_equal(b"a", b"a"));
        assert!(is_equal(b"ab", b"ab"));
        assert!(is_equal(b"abc", b"abc"));
        assert!(is_equal(b"abcd", b"abcd"));
        assert!(is_equal(b"abcde", b"abcde"));
        assert!(is_equal(b"abcdef", b"abcdef"));
        assert!(is_equal(b"abcdefg", b"abcdefg"));
        assert!(is_equal(b"abcdefgh", b"abcdefgh"));
        assert!(is_equal(b"abcdefghi", b"abcdefghi"));
    }

    #[test]
    fn prefix() {
        assert!(is_prefix(b"", b""));
        assert!(is_prefix(b"a", b""));
        assert!(is_prefix(b"ab", b""));
        assert!(is_prefix(b"foo", b"foo"));
        assert!(is_prefix(b"foobar", b"foo"));

        assert!(!is_prefix(b"foo", b"fob"));
        assert!(!is_prefix(b"foobar", b"fob"));
    }
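
    // Sketch tests for the `Patterns` API above, using only items defined
    // in this module. The first checks that leftmost-longest semantics
    // reorder iteration by descending length, with ties broken by insertion
    // order (relying on `sort_by` being stable).
    #[test]
    fn order_leftmost_longest() {
        let mut patterns = Patterns::new();
        patterns.add(b"a");
        patterns.add(b"abc");
        patterns.add(b"ab");
        patterns.add(b"xyz");
        patterns.set_match_kind(MatchKind::LeftmostLongest);
        let lens: Vec<usize> =
            patterns.iter().map(|(_, p)| p.len()).collect();
        assert_eq!(lens, vec![3, 3, 2, 1]);
    }

    // The second mirrors `prefix` above, but goes through the raw pointer
    // variant, `Pattern::is_prefix_raw`.
    #[test]
    fn prefix_raw() {
        let mut patterns = Patterns::new();
        patterns.add(b"foo");
        let pat = patterns.get(PatternID::new(0).unwrap());
        let hay = b"foobar";
        // SAFETY: Both pointers are derived from the same slice, with
        // `start <= end` and the whole range valid for reads.
        unsafe {
            assert!(pat
                .is_prefix_raw(hay.as_ptr(), hay.as_ptr().add(hay.len())));
        }
    }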
}