use core::{cmp, fmt, mem, u16, usize};

use alloc::{string::String, vec, vec::Vec};

use crate::packed::api::MatchKind;

/// The type used for representing a pattern identifier.
///
/// We don't use `usize` here because our packed searchers don't scale to
/// huge numbers of patterns, so we keep things a bit smaller.
pub type PatternID = u16;

/// A non-empty collection of non-empty patterns to search for.
///
/// This collection of patterns is what is passed around to both execute
/// searches and to construct the searchers themselves. Namely, this permits
/// searches to avoid copying all of the patterns, and allows us to keep only
/// one copy throughout all packed searchers.
///
/// Note that this collection is not a set. The same pattern can appear more
/// than once.
#[derive(Clone, Debug)]
pub struct Patterns {
    /// The match semantics supported by this collection of patterns.
    ///
    /// The match semantics determines the order of the iterator over
    /// patterns. For leftmost-first, patterns are provided in the same order
    /// as they were provided by the caller. For leftmost-longest, patterns
    /// are provided in descending order of length, with ties broken by the
    /// order in which they were provided by the caller.
    kind: MatchKind,
    /// The collection of patterns, indexed by their identifier.
    by_id: Vec<Vec<u8>>,
    /// The order of patterns defined for iteration, given by pattern
    /// identifiers. The order of `by_id` and `order` is always the same for
    /// leftmost-first semantics, but may be different for leftmost-longest
    /// semantics.
    order: Vec<PatternID>,
    /// The length of the smallest pattern, in bytes.
    minimum_len: usize,
    /// The largest pattern identifier. This should always be equal to the
    /// number of patterns in this collection minus one.
    max_pattern_id: PatternID,
    /// The total number of pattern bytes across the entire collection. This
    /// is used for reporting total heap usage in constant time.
    total_pattern_bytes: usize,
}

impl Patterns {
    /// Create a new, empty collection of patterns with default (i.e.,
    /// leftmost-first) match semantics. Patterns are added via `add`, and
    /// the identifier of each pattern is the index at which it was added.
    pub fn new() -> Patterns {
        Patterns {
            kind: MatchKind::default(),
            by_id: vec![],
            order: vec![],
            minimum_len: usize::MAX,
            max_pattern_id: 0,
            total_pattern_bytes: 0,
        }
    }

    /// Add a pattern to this collection.
    ///
    /// This panics if the pattern given is empty, or if a new pattern
    /// identifier could not be allocated (i.e., the collection already
    /// contains `u16::MAX + 1` patterns).
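    ///
    /// A sketch of basic usage (this type is internal to the crate, so the
    /// doc test is not run):
    ///
    /// ```ignore
    /// let mut patterns = Patterns::new();
    /// patterns.add(b"foo"); // assigned PatternID 0
    /// patterns.add(b"quux"); // assigned PatternID 1
    /// assert_eq!(2, patterns.len());
    /// assert_eq!(3, patterns.minimum_len());
    /// ```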
    pub fn add(&mut self, bytes: &[u8]) {
        assert!(!bytes.is_empty());
        assert!(self.by_id.len() <= u16::MAX as usize);

        let id = self.by_id.len() as u16;
        self.max_pattern_id = id;
        self.order.push(id);
        self.by_id.push(bytes.to_vec());
        self.minimum_len = cmp::min(self.minimum_len, bytes.len());
        self.total_pattern_bytes += bytes.len();
    }

    /// Set the match kind semantics for this collection of patterns.
    ///
    /// If the kind is not set, then the default is leftmost-first.
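    ///
    /// A sketch of the reordering this implies, for two hypothetical
    /// patterns added in the order shown:
    ///
    /// ```ignore
    /// let mut patterns = Patterns::new();
    /// patterns.add(b"sam"); // id 0
    /// patterns.add(b"samwise"); // id 1
    /// patterns.set_match_kind(MatchKind::LeftmostLongest);
    /// // Longer patterns now come first, so "samwise" (id 1) is yielded
    /// // before "sam" (id 0).
    /// let ids: Vec<PatternID> = patterns.iter().map(|(id, _)| id).collect();
    /// assert_eq!(vec![1, 0], ids);
    /// ```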
    pub fn set_match_kind(&mut self, kind: MatchKind) {
        self.kind = kind;
        match self.kind {
            MatchKind::LeftmostFirst => {
                self.order.sort();
            }
            MatchKind::LeftmostLongest => {
                let (order, by_id) = (&mut self.order, &mut self.by_id);
                order.sort_by(|&id1, &id2| {
                    by_id[id1 as usize]
                        .len()
                        .cmp(&by_id[id2 as usize].len())
                        .reverse()
                });
            }
        }
    }

    /// Return the number of patterns in this collection.
    ///
    /// This is guaranteed to be greater than zero.
    pub fn len(&self) -> usize {
        self.by_id.len()
    }

    /// Returns true if and only if this collection of patterns is empty.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Returns the approximate total amount of heap used by these patterns,
    /// in units of bytes.
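    ///
    /// As a rough sketch of the accounting on a typical 64-bit target
    /// (where a `Vec<u8>` header is 24 bytes and a `PatternID` is 2 bytes),
    /// a collection holding `b"foo"` and `b"quux"` would report
    /// `2 * 2 + 2 * 24 + 7 = 59` bytes.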
    pub fn memory_usage(&self) -> usize {
        self.order.len() * mem::size_of::<PatternID>()
            + self.by_id.len() * mem::size_of::<Vec<u8>>()
            + self.total_pattern_bytes
    }

    /// Clears all heap memory associated with this collection of patterns
    /// and resets all state such that it is a valid empty collection.
    pub fn reset(&mut self) {
        self.kind = MatchKind::default();
        self.by_id.clear();
        self.order.clear();
        self.minimum_len = usize::MAX;
        self.max_pattern_id = 0;
        self.total_pattern_bytes = 0;
    }

    /// Return the maximum pattern identifier in this collection. This can be
    /// useful in searchers for ensuring that the collection of patterns they
    /// are provided at search time and at build time have the same size.
    pub fn max_pattern_id(&self) -> PatternID {
        assert_eq!(self.max_pattern_id as usize + 1, self.len());
        self.max_pattern_id
    }

    /// Returns the length, in bytes, of the smallest pattern.
    ///
    /// This is guaranteed to be at least one.
    pub fn minimum_len(&self) -> usize {
        self.minimum_len
    }

    /// Returns the match semantics used by these patterns.
    pub fn match_kind(&self) -> &MatchKind {
        &self.kind
    }

    /// Return the pattern with the given identifier. If such a pattern does
    /// not exist, then this panics.
    pub fn get(&self, id: PatternID) -> Pattern<'_> {
        Pattern(&self.by_id[id as usize])
    }

    /// Return the pattern with the given identifier without performing
    /// bounds checks.
    ///
    /// # Safety
    ///
    /// Callers must ensure that a pattern with the given identifier exists
    /// before using this method.
    #[cfg(all(feature = "std", target_arch = "x86_64"))]
    pub unsafe fn get_unchecked(&self, id: PatternID) -> Pattern<'_> {
        Pattern(self.by_id.get_unchecked(id as usize))
    }

    /// Return an iterator over all the patterns in this collection, in the
    /// order in which they should be matched.
    ///
    /// Specifically, in a naive multi-pattern matcher, the following is
    /// guaranteed to satisfy the match semantics of this collection of
    /// patterns:
    ///
    /// ```ignore
    /// for i in 0..haystack.len():
    ///     for p in patterns.iter():
    ///         if haystack[i..].starts_with(p.bytes()):
    ///             return Match(p.id(), i, i + p.bytes().len())
    /// ```
    ///
    /// Namely, among the patterns in a collection, if they are matched in
    /// the order provided by this iterator, then the result is guaranteed
    /// to satisfy the correct match semantics. (Either leftmost-first or
    /// leftmost-longest.)
    pub fn iter(&self) -> PatternIter<'_> {
        PatternIter { patterns: self, i: 0 }
    }
}

/// An iterator over the patterns in the `Patterns` collection.
///
/// The order of the patterns provided by this iterator is consistent with
/// the match semantics of the originating collection of patterns.
///
/// The lifetime `'p` corresponds to the lifetime of the collection of
/// patterns this is iterating over.
#[derive(Debug)]
pub struct PatternIter<'p> {
    patterns: &'p Patterns,
    i: usize,
}

impl<'p> Iterator for PatternIter<'p> {
    type Item = (PatternID, Pattern<'p>);

    fn next(&mut self) -> Option<(PatternID, Pattern<'p>)> {
        if self.i >= self.patterns.len() {
            return None;
        }
        let id = self.patterns.order[self.i];
        let p = self.patterns.get(id);
        self.i += 1;
        Some((id, p))
    }
}

/// A pattern that is used in packed searching.
#[derive(Clone)]
pub struct Pattern<'a>(&'a [u8]);

impl<'a> fmt::Debug for Pattern<'a> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("Pattern")
            .field("lit", &String::from_utf8_lossy(self.0))
            .finish()
    }
}

impl<'p> Pattern<'p> {
    /// Returns the length of this pattern, in bytes.
    pub fn len(&self) -> usize {
        self.0.len()
    }

    /// Returns the bytes of this pattern.
    pub fn bytes(&self) -> &[u8] {
        self.0
    }

    /// Returns the first `len` low nybbles from this pattern. If this
    /// pattern is shorter than `len`, then this panics.
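    ///
    /// For example, the low nybbles of `b"az"` are `[0x1, 0xA]`, since
    /// `b'a' == 0x61` and `b'z' == 0x7A`.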
    #[cfg(all(feature = "std", target_arch = "x86_64"))]
    pub fn low_nybbles(&self, len: usize) -> Vec<u8> {
        let mut nybs = vec![];
        for &b in self.bytes().iter().take(len) {
            nybs.push(b & 0xF);
        }
        nybs
    }

    /// Returns true if this pattern is a prefix of the given bytes.
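    ///
    /// For example, the pattern `b"foo"` is a prefix of `b"foobar"`, but
    /// not of `b"fo"` or `b"xfoo"`.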
    #[inline(always)]
    pub fn is_prefix(&self, bytes: &[u8]) -> bool {
        self.len() <= bytes.len() && self.equals(&bytes[..self.len()])
    }

    /// Returns true if and only if this pattern equals the given bytes.
    #[inline(always)]
    pub fn equals(&self, bytes: &[u8]) -> bool {
        // Why not just use memcmp for this? Well, memcmp requires calling
        // out to libc, and this routine is called in fairly hot code paths.
        // Beyond the cost of calling out to libc, it also seems to result in
        // worse codegen. By rolling our own memcmp in pure Rust, it appears
        // to be more friendly to the optimizer.
        //
        // This results in an improvement in just about every benchmark. Some
        // improvements are smaller than others, but in some cases, up to 30%
        // faster.

        let (x, y) = (self.bytes(), bytes);
        if x.len() != y.len() {
            return false;
        }
        // If we don't have enough bytes to do 4-byte at a time loads, then
        // fall back to the naive slow version.
        if x.len() < 4 {
            for (&b1, &b2) in x.iter().zip(y) {
                if b1 != b2 {
                    return false;
                }
            }
            return true;
        }
        // When we have 4 or more bytes to compare, then proceed in chunks of
        // 4 at a time using unaligned loads.
        //
        // Also, why do 4-byte loads instead of, say, 8-byte loads? The
        // reason is that this particular version of memcmp is likely to be
        // called with tiny needles. That means that if we did 8-byte loads,
        // then a higher proportion of memcmp calls would use the slower
        // variant above. With that said, this is a hypothesis and is only
        // loosely supported by benchmarks. There's likely some improvement
        // that could be made here. The main thing here though is to optimize
        // for latency, not throughput.

        // SAFETY: Via the conditional above, we know that both `px` and `py`
        // have the same length, so `px < pxend` implies that `py < pyend`.
        // Thus, dereferencing both `px` and `py` in the loop below is safe.
        //
        // Moreover, we set `pxend` and `pyend` to be 4 bytes before the
        // actual end of `px` and `py`. Thus, the final dereference outside
        // of the loop is guaranteed to be valid. (The final comparison will
        // overlap with the last comparison done in the loop for lengths that
        // aren't multiples of four.)
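        //
        // For example, if `x.len() == 6`, then the loop compares bytes
        // [0, 4) and the final load compares bytes [2, 6), re-checking
        // bytes 2 and 3.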
        //
        // Finally, we needn't worry about alignment here, since we do
        // unaligned loads.
        unsafe {
            let (mut px, mut py) = (x.as_ptr(), y.as_ptr());
            let (pxend, pyend) = (px.add(x.len() - 4), py.add(y.len() - 4));
            while px < pxend {
                let vx = (px as *const u32).read_unaligned();
                let vy = (py as *const u32).read_unaligned();
                if vx != vy {
                    return false;
                }
                px = px.add(4);
                py = py.add(4);
            }
            let vx = (pxend as *const u32).read_unaligned();
            let vy = (pyend as *const u32).read_unaligned();
            vx == vy
        }
    }
}