1use core::mem::size_of;
2
3use crate::memmem::{
4 prefilter::{PrefilterFnTy, PrefilterState},
5 vector::Vector,
6 NeedleInfo,
7};
8
9/// The implementation of the forward vector accelerated candidate finder.
10///
11/// This is inspired by the "generic SIMD" algorithm described here:
12/// http://0x80.pl/articles/simd-strfind.html#algorithm-1-generic-simd
13///
14/// The main difference is that this is just a prefilter. That is, it reports
15/// candidates once they are seen and doesn't attempt to confirm them. Also,
16/// the bytes this routine uses to check for candidates are selected based on
17/// an a priori background frequency distribution. This means that on most
18/// haystacks, this will on average spend more time in vectorized code than you
19/// would if you just selected the first and last bytes of the needle.
20///
21/// Note that a non-prefilter variant of this algorithm can be found in the
22/// parent module, but it only works on smaller needles.
23///
24/// `prestate`, `ninfo`, `haystack` and `needle` are the four prefilter
25/// function parameters. `fallback` is a prefilter that is used if the haystack
26/// is too small to be handled with the given vector size.
27///
28/// This routine is not safe because it is intended for callers to specialize
29/// this with a particular vector (e.g., __m256i) and then call it with the
30/// relevant target feature (e.g., avx2) enabled.
31///
32/// # Panics
33///
34/// If `needle.len() <= 1`, then this panics.
35///
36/// # Safety
37///
38/// Since this is meant to be used with vector functions, callers need to
39/// specialize this inside of a function with a `target_feature` attribute.
40/// Therefore, callers must ensure that whatever target feature is being used
41/// supports the vector functions that this function is specialized for. (For
42/// the specific vector functions used, see the Vector trait implementations.)
43#[inline(always)]
44pub(crate) unsafe fn find<V: Vector>(
45 prestate: &mut PrefilterState,
46 ninfo: &NeedleInfo,
47 haystack: &[u8],
48 needle: &[u8],
49 fallback: PrefilterFnTy,
50) -> Option<usize> {
51 assert!(needle.len() >= 2, "needle must be at least 2 bytes");
52 let (rare1i, rare2i) = ninfo.rarebytes.as_rare_ordered_usize();
53 let min_haystack_len = rare2i + size_of::<V>();
54 if haystack.len() < min_haystack_len {
55 return fallback(prestate, ninfo, haystack, needle);
56 }
57
58 let start_ptr = haystack.as_ptr();
59 let end_ptr = start_ptr.add(haystack.len());
60 let max_ptr = end_ptr.sub(min_haystack_len);
61 let mut ptr = start_ptr;
62
63 let rare1chunk = V::splat(needle[rare1i]);
64 let rare2chunk = V::splat(needle[rare2i]);
65
66 // N.B. I did experiment with unrolling the loop to deal with size(V)
67 // bytes at a time and 2*size(V) bytes at a time. The double unroll
68 // was marginally faster while the quadruple unroll was unambiguously
69 // slower. In the end, I decided the complexity from unrolling wasn't
70 // worth it. I used the memmem/krate/prebuilt/huge-en/ benchmarks to
71 // compare.
72 while ptr <= max_ptr {
73 let m = find_in_chunk2(ptr, rare1i, rare2i, rare1chunk, rare2chunk);
74 if let Some(chunki) = m {
75 return Some(matched(prestate, start_ptr, ptr, chunki));
76 }
77 ptr = ptr.add(size_of::<V>());
78 }
79 if ptr < end_ptr {
80 // This routine immediately quits if a candidate match is found.
81 // That means that if we're here, no candidate matches have been
82 // found at or before 'ptr'. Thus, we don't need to mask anything
83 // out even though we might technically search part of the haystack
84 // that we've already searched (because we know it can't match).
85 ptr = max_ptr;
86 let m = find_in_chunk2(ptr, rare1i, rare2i, rare1chunk, rare2chunk);
87 if let Some(chunki) = m {
88 return Some(matched(prestate, start_ptr, ptr, chunki));
89 }
90 }
91 prestate.update(haystack.len());
92 None
93}
94
95// Below are two different techniques for checking whether a candidate
// match exists in a given chunk or not. find_in_chunk2 checks two bytes,
// whereas find_in_chunk3 checks three bytes. The idea behind checking
98// three bytes is that while we do a bit more work per iteration, we
99// decrease the chances of a false positive match being reported and thus
100// make the search faster overall. This actually works out for the
101// memmem/krate/prebuilt/huge-en/never-all-common-bytes benchmark, where
102// using find_in_chunk3 is about 25% faster than find_in_chunk2. However,
103// it turns out that find_in_chunk2 is faster for all other benchmarks, so
104// perhaps the extra check isn't worth it in practice.
105//
106// For now, we go with find_in_chunk2, but we leave find_in_chunk3 around
107// to make it easy to switch to and benchmark when possible.
108
109/// Search for an occurrence of two rare bytes from the needle in the current
110/// chunk pointed to by ptr.
111///
112/// rare1chunk and rare2chunk correspond to vectors with the rare1 and rare2
113/// bytes repeated in each 8-bit lane, respectively.
114///
115/// # Safety
116///
117/// It must be safe to do an unaligned read of size(V) bytes starting at both
118/// (ptr + rare1i) and (ptr + rare2i).
119#[inline(always)]
120unsafe fn find_in_chunk2<V: Vector>(
121 ptr: *const u8,
122 rare1i: usize,
123 rare2i: usize,
124 rare1chunk: V,
125 rare2chunk: V,
126) -> Option<usize> {
127 let chunk0: V = V::load_unaligned(data:ptr.add(count:rare1i));
128 let chunk1: V = V::load_unaligned(data:ptr.add(count:rare2i));
129
130 let eq0: V = chunk0.cmpeq(vector2:rare1chunk);
131 let eq1: V = chunk1.cmpeq(vector2:rare2chunk);
132
133 let match_offsets: u32 = eq0.and(vector2:eq1).movemask();
134 if match_offsets == 0 {
135 return None;
136 }
137 Some(match_offsets.trailing_zeros() as usize)
138}
139
140/// Search for an occurrence of two rare bytes and the first byte (even if one
141/// of the rare bytes is equivalent to the first byte) from the needle in the
142/// current chunk pointed to by ptr.
143///
144/// firstchunk, rare1chunk and rare2chunk correspond to vectors with the first,
145/// rare1 and rare2 bytes repeated in each 8-bit lane, respectively.
146///
147/// # Safety
148///
149/// It must be safe to do an unaligned read of size(V) bytes starting at ptr,
150/// (ptr + rare1i) and (ptr + rare2i).
151#[allow(dead_code)]
152#[inline(always)]
153unsafe fn find_in_chunk3<V: Vector>(
154 ptr: *const u8,
155 rare1i: usize,
156 rare2i: usize,
157 firstchunk: V,
158 rare1chunk: V,
159 rare2chunk: V,
160) -> Option<usize> {
161 let chunk0: V = V::load_unaligned(data:ptr);
162 let chunk1: V = V::load_unaligned(data:ptr.add(count:rare1i));
163 let chunk2: V = V::load_unaligned(data:ptr.add(count:rare2i));
164
165 let eq0: V = chunk0.cmpeq(vector2:firstchunk);
166 let eq1: V = chunk1.cmpeq(vector2:rare1chunk);
167 let eq2: V = chunk2.cmpeq(vector2:rare2chunk);
168
169 let match_offsets: u32 = eq0.and(eq1).and(vector2:eq2).movemask();
170 if match_offsets == 0 {
171 return None;
172 }
173 Some(match_offsets.trailing_zeros() as usize)
174}
175
176/// Accepts a chunk-relative offset and returns a haystack relative offset
177/// after updating the prefilter state.
178///
179/// Why do we use this unlineable function when a search completes? Well,
180/// I don't know. Really. Obviously this function was not here initially.
181/// When doing profiling, the codegen for the inner loop here looked bad and
182/// I didn't know why. There were a couple extra 'add' instructions and an
183/// extra 'lea' instruction that I couldn't explain. I hypothesized that the
184/// optimizer was having trouble untangling the hot code in the loop from the
185/// code that deals with a candidate match. By putting the latter into an
186/// unlineable function, it kind of forces the issue and it had the intended
187/// effect: codegen improved measurably. It's good for a ~10% improvement
188/// across the board on the memmem/krate/prebuilt/huge-en/ benchmarks.
189#[cold]
190#[inline(never)]
191fn matched(
192 prestate: &mut PrefilterState,
193 start_ptr: *const u8,
194 ptr: *const u8,
195 chunki: usize,
196) -> usize {
197 let found: usize = diff(a:ptr, b:start_ptr) + chunki;
198 prestate.update(skipped:found);
199 found
200}
201
/// Returns the number of bytes separating `b` from `a`, computed as the
/// address of `a` minus the address of `b`.
///
/// `a` must not be less than `b`. This is checked with a debug assertion
/// only.
fn diff(a: *const u8, b: *const u8) -> usize {
    debug_assert!(a >= b);
    let (hi, lo) = (a as usize, b as usize);
    hi - lo
}
208