use std::cmp;
use std::fmt;
use std::mem;
use std::u16;
use std::usize;

use crate::packed::api::MatchKind;

/// The type used for representing a pattern identifier.
///
/// We don't use `usize` here because our packed searchers don't scale to
/// huge numbers of patterns, so we keep things a bit smaller.
pub type PatternID = u16;

/// A non-empty collection of non-empty patterns to search for.
///
/// This collection of patterns is what is passed around to both execute
/// searches and to construct the searchers themselves. Namely, this permits
/// searches to avoid copying all of the patterns, and allows us to keep only
/// one copy throughout all packed searchers.
///
/// Note that this collection is not a set. The same pattern can appear more
/// than once.
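///
/// A usage sketch (illustrative only, not a doctest), showing how the
/// methods below fit together:
///
/// ```ignore
/// let mut patterns = Patterns::new();
/// patterns.add(b"foo");    // pattern ID 0
/// patterns.add(b"foobar"); // pattern ID 1
/// patterns.set_match_kind(MatchKind::LeftmostLongest);
/// // With leftmost-longest semantics, iteration yields "foobar" (ID 1)
/// // before "foo" (ID 0), since longer patterns are preferred.
/// for (id, pattern) in patterns.iter() {
///     let _ = (id, pattern.bytes());
/// }
/// ```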
#[derive(Clone, Debug)]
pub struct Patterns {
    /// The match semantics supported by this collection of patterns.
    ///
    /// The match semantics determines the order of the iterator over
    /// patterns. For leftmost-first, patterns are provided in the same order
    /// in which they were provided by the caller. For leftmost-longest,
    /// patterns are provided in descending order of length, with ties broken
    /// by the order in which they were provided by the caller.
    kind: MatchKind,
    /// The collection of patterns, indexed by their identifier.
    by_id: Vec<Vec<u8>>,
    /// The order of patterns defined for iteration, given by pattern
    /// identifiers. The order of `by_id` and `order` is always the same for
    /// leftmost-first semantics, but may be different for leftmost-longest
    /// semantics.
    order: Vec<PatternID>,
    /// The length of the smallest pattern, in bytes.
    minimum_len: usize,
    /// The largest pattern identifier. This should always be equivalent to
    /// the number of patterns minus one in this collection.
    max_pattern_id: PatternID,
    /// The total number of pattern bytes across the entire collection. This
    /// is used for reporting total heap usage in constant time.
    total_pattern_bytes: usize,
}

impl Patterns {
    /// Create a new, empty collection of patterns. The ID of each pattern is
    /// the index at which it occurs in the `by_id` slice, i.e., the order in
    /// which it was added via `add`.
    ///
    /// While this collection starts out empty, it is intended to be used as a
    /// non-empty collection of non-empty patterns; `add` panics when given an
    /// empty pattern.
    pub fn new() -> Patterns {
        Patterns {
            kind: MatchKind::default(),
            by_id: vec![],
            order: vec![],
            minimum_len: usize::MAX,
            max_pattern_id: 0,
            total_pattern_bytes: 0,
        }
    }

    /// Add a pattern to this collection.
    ///
    /// This panics if the pattern given is empty, or if the number of
    /// patterns would overflow a `PatternID`.
    pub fn add(&mut self, bytes: &[u8]) {
        assert!(!bytes.is_empty());
        assert!(self.by_id.len() < u16::MAX as usize);

        let id = self.by_id.len() as u16;
        self.max_pattern_id = id;
        self.order.push(id);
        self.by_id.push(bytes.to_vec());
        self.minimum_len = cmp::min(self.minimum_len, bytes.len());
        self.total_pattern_bytes += bytes.len();
    }

    /// Set the match kind semantics for this collection of patterns.
    ///
    /// If the kind is not set, then the default is leftmost-first.
    pub fn set_match_kind(&mut self, kind: MatchKind) {
        match kind {
            MatchKind::LeftmostFirst => {
                self.order.sort();
            }
            MatchKind::LeftmostLongest => {
                let (order, by_id) = (&mut self.order, &mut self.by_id);
                order.sort_by(|&id1, &id2| {
                    by_id[id1 as usize]
                        .len()
                        .cmp(&by_id[id2 as usize].len())
                        .reverse()
                });
            }
            MatchKind::__Nonexhaustive => unreachable!(),
        }
    }

    /// Return the number of patterns in this collection.
    ///
    /// This is guaranteed to be greater than zero.
    pub fn len(&self) -> usize {
        self.by_id.len()
    }

    /// Returns true if and only if this collection of patterns is empty.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Returns the approximate total amount of heap used by these patterns,
    /// in units of bytes.
    pub fn heap_bytes(&self) -> usize {
        self.order.len() * mem::size_of::<PatternID>()
            + self.by_id.len() * mem::size_of::<Vec<u8>>()
            + self.total_pattern_bytes
    }

    /// Clears all heap memory associated with this collection of patterns and
    /// resets all state such that it is a valid empty collection.
    pub fn reset(&mut self) {
        self.kind = MatchKind::default();
        self.by_id.clear();
        self.order.clear();
        self.minimum_len = usize::MAX;
        self.max_pattern_id = 0;
        self.total_pattern_bytes = 0;
    }

    /// Return the maximum pattern identifier in this collection. This can be
    /// useful in searchers for ensuring that the collection of patterns they
    /// are provided at search time and at build time have the same size.
    pub fn max_pattern_id(&self) -> PatternID {
        assert_eq!((self.max_pattern_id + 1) as usize, self.len());
        self.max_pattern_id
    }

    /// Returns the length, in bytes, of the smallest pattern.
    ///
    /// This is guaranteed to be at least one.
    pub fn minimum_len(&self) -> usize {
        self.minimum_len
    }

    /// Returns the match semantics used by these patterns.
    pub fn match_kind(&self) -> &MatchKind {
        &self.kind
    }

    /// Return the pattern with the given identifier. If such a pattern does
    /// not exist, then this panics.
    pub fn get(&self, id: PatternID) -> Pattern<'_> {
        Pattern(&self.by_id[id as usize])
    }

    /// Return the pattern with the given identifier without performing bounds
    /// checks.
    ///
    /// # Safety
    ///
    /// Callers must ensure that a pattern with the given identifier exists
    /// before using this method.
    #[cfg(target_arch = "x86_64")]
    pub unsafe fn get_unchecked(&self, id: PatternID) -> Pattern<'_> {
        Pattern(self.by_id.get_unchecked(id as usize))
    }

    /// Return an iterator over all the patterns in this collection, in the
    /// order in which they should be matched.
    ///
    /// Specifically, in a naive multi-pattern matcher, the following is
    /// guaranteed to satisfy the match semantics of this collection of
    /// patterns:
    ///
    /// ```ignore
    /// for i in 0..haystack.len():
    ///     for p in patterns.iter():
    ///         if haystack[i..].starts_with(p.bytes()):
    ///             return Match(p.id(), i, i + p.bytes().len())
    /// ```
    ///
    /// Namely, among the patterns in a collection, if they are matched in
    /// the order provided by this iterator, then the result is guaranteed
    /// to satisfy the correct match semantics. (Either leftmost-first or
    /// leftmost-longest.)
    pub fn iter(&self) -> PatternIter<'_> {
        PatternIter { patterns: self, i: 0 }
    }
}

/// An iterator over the patterns in the `Patterns` collection.
///
/// The order of the patterns provided by this iterator is consistent with the
/// match semantics of the originating collection of patterns.
///
/// The lifetime `'p` corresponds to the lifetime of the collection of patterns
/// this is iterating over.
#[derive(Debug)]
pub struct PatternIter<'p> {
    patterns: &'p Patterns,
    i: usize,
}

impl<'p> Iterator for PatternIter<'p> {
    type Item = (PatternID, Pattern<'p>);

    fn next(&mut self) -> Option<(PatternID, Pattern<'p>)> {
        if self.i >= self.patterns.len() {
            return None;
        }
        let id = self.patterns.order[self.i];
        let p = self.patterns.get(id);
        self.i += 1;
        Some((id, p))
    }
}

/// A pattern that is used in packed searching.
#[derive(Clone)]
pub struct Pattern<'a>(&'a [u8]);

impl<'a> fmt::Debug for Pattern<'a> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("Pattern")
            .field("lit", &String::from_utf8_lossy(&self.0))
            .finish()
    }
}

impl<'p> Pattern<'p> {
    /// Returns the length of this pattern, in bytes.
    pub fn len(&self) -> usize {
        self.0.len()
    }

    /// Returns the bytes of this pattern.
    pub fn bytes(&self) -> &[u8] {
        &self.0
    }

    /// Returns the low nybbles of the first `len` bytes of this pattern. If
    /// this pattern is shorter than `len`, then only the nybbles that are
    /// present are returned.
    #[cfg(target_arch = "x86_64")]
    pub fn low_nybbles(&self, len: usize) -> Vec<u8> {
        let mut nybs = vec![];
        for &b in self.bytes().iter().take(len) {
            nybs.push(b & 0xF);
        }
        nybs
    }

    /// Returns true if this pattern is a prefix of the given bytes.
    #[inline(always)]
    pub fn is_prefix(&self, bytes: &[u8]) -> bool {
        self.len() <= bytes.len() && self.equals(&bytes[..self.len()])
    }

    /// Returns true if and only if this pattern equals the given bytes.
    #[inline(always)]
    pub fn equals(&self, bytes: &[u8]) -> bool {
        // Why not just use memcmp for this? Well, memcmp requires calling out
        // to libc, and this routine is called in fairly hot code paths. Other
        // than just calling out to libc, it also seems to result in worse
        // codegen. By rolling our own memcmp in pure Rust, it seems to be
        // more friendly to the optimizer.
        //
        // This results in an improvement in just about every benchmark. Some
        // smaller than others, but in some cases, up to 30% faster.

        if self.len() != bytes.len() {
            return false;
        }
        if self.len() < 8 {
            for (&b1, &b2) in self.bytes().iter().zip(bytes) {
                if b1 != b2 {
                    return false;
                }
            }
            return true;
        }
        // When we have 8 or more bytes to compare, then proceed in chunks of
        // 8 at a time using unaligned loads.
        let mut p1 = self.bytes().as_ptr();
        let mut p2 = bytes.as_ptr();
        let p1end = self.bytes()[self.len() - 8..].as_ptr();
        let p2end = bytes[bytes.len() - 8..].as_ptr();
        // SAFETY: Via the conditional above, we know that the slices pointed
        // to by `p1` and `p2` have the same length, so `p1 < p1end` implies
        // that `p2 < p2end`. Thus, dereferencing both `p1` and `p2` in the
        // loop below is safe.
        //
        // Moreover, we set `p1end` and `p2end` to be 8 bytes before the
        // actual end of `p1` and `p2`. Thus, the final dereference outside of
        // the loop is guaranteed to be valid.
        //
        // Finally, we needn't worry about 64-bit alignment here, since we
        // do unaligned loads.
        unsafe {
            while p1 < p1end {
                let v1 = (p1 as *const u64).read_unaligned();
                let v2 = (p2 as *const u64).read_unaligned();
                if v1 != v2 {
                    return false;
                }
                p1 = p1.add(8);
                p2 = p2.add(8);
            }
            let v1 = (p1end as *const u64).read_unaligned();
            let v2 = (p2end as *const u64).read_unaligned();
            v1 == v2
        }
    }
}
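
// A small illustrative test sketch: the test names and cases below are
// examples only, exercising the ordering and comparison behavior documented
// above.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn leftmost_longest_orders_by_length() {
        let mut patterns = Patterns::new();
        patterns.add(b"foo");
        patterns.add(b"foobar");
        patterns.add(b"a");
        patterns.set_match_kind(MatchKind::LeftmostLongest);

        // Iteration yields patterns in descending order of length.
        let lens: Vec<usize> =
            patterns.iter().map(|(_, p)| p.len()).collect();
        assert_eq!(lens, vec![6, 3, 1]);
        assert_eq!(patterns.minimum_len(), 1);
        assert_eq!(patterns.max_pattern_id(), 2);
    }

    #[test]
    fn pattern_equality_and_prefix() {
        let mut patterns = Patterns::new();
        // Longer than 8 bytes, so `equals` exercises the unaligned-load path.
        patterns.add(b"abcdefghijk");
        let pat = patterns.get(0);
        assert!(pat.equals(b"abcdefghijk"));
        assert!(!pat.equals(b"abcdefghijX"));
        assert!(pat.is_prefix(b"abcdefghijklmn"));
        assert!(!pat.is_prefix(b"abc"));
    }
}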