1 | // This is adapted from `fallback.rs` from rust-memchr. It's modified to return |
2 | // the 'inverse' query of memchr, e.g. finding the first byte not in the |
3 | // provided set. This is simple for the 1-byte case. |
4 | |
5 | use core::{cmp, usize}; |
6 | |
// Number of bytes in one machine word. The scans below process the haystack
// a word at a time once the pointer is aligned.
const USIZE_BYTES: usize = core::mem::size_of::<usize>();

// The number of bytes to loop at in one iteration of memchr/memrchr.
// Two words per iteration amortizes the loop-condition overhead.
const LOOP_SIZE: usize = 2 * USIZE_BYTES;
11 | |
/// Repeat the given byte into a word size number. That is, every 8 bits
/// is equivalent to the given byte. For example, if `b` is `\x4E` or
/// `01001110` in binary, then the returned value on a 32-bit system would be:
/// `01001110_01001110_01001110_01001110`.
#[inline(always)]
fn repeat_byte(b: u8) -> usize {
    // Filling a native-endian byte array with `b` and reinterpreting it as a
    // word places `b` in every byte lane.
    usize::from_ne_bytes([b; core::mem::size_of::<usize>()])
}
20 | |
/// Return the first index in `haystack` whose byte does NOT equal `n1`
/// (the inverse of `memchr`), or `None` if every byte equals `n1`.
pub fn inv_memchr(n1: u8, haystack: &[u8]) -> Option<usize> {
    // `n1` broadcast into every byte of a word: `word ^ vn1 == 0` iff every
    // byte of `word` equals `n1`.
    let vn1 = repeat_byte(n1);
    // Scalar predicate: true for bytes that do NOT match `n1`.
    let confirm = |byte| byte != n1;
    let loop_size = cmp::min(LOOP_SIZE, haystack.len());
    // Low-bit mask used to round pointers to word alignment.
    let align = USIZE_BYTES - 1;
    let start_ptr = haystack.as_ptr();

    unsafe {
        let end_ptr = haystack.as_ptr().add(haystack.len());
        let mut ptr = start_ptr;

        // Haystack too short for even one word read: pure byte-at-a-time scan.
        if haystack.len() < USIZE_BYTES {
            return forward_search(start_ptr, end_ptr, ptr, confirm);
        }

        // Probe the first (possibly unaligned) word. A nonzero XOR means some
        // byte differs from `n1`, so the match is within this word; let the
        // scalar loop pinpoint it.
        let chunk = read_unaligned_usize(ptr);
        if (chunk ^ vn1) != 0 {
            return forward_search(start_ptr, end_ptr, ptr, confirm);
        }

        // Advance to the next word boundary. The bytes being skipped were just
        // verified above to all equal `n1`.
        ptr = ptr.add(USIZE_BYTES - (start_ptr as usize & align));
        debug_assert!(ptr > start_ptr);
        debug_assert!(end_ptr.sub(USIZE_BYTES) >= start_ptr);
        // Aligned scan, two words per iteration; break on the first pair that
        // contains any non-`n1` byte.
        while loop_size == LOOP_SIZE && ptr <= end_ptr.sub(loop_size) {
            debug_assert_eq!(0, (ptr as usize) % USIZE_BYTES);

            let a = *(ptr as *const usize);
            let b = *(ptr.add(USIZE_BYTES) as *const usize);
            let eqa = (a ^ vn1) != 0;
            let eqb = (b ^ vn1) != 0;
            if eqa || eqb {
                break;
            }
            ptr = ptr.add(LOOP_SIZE);
        }
        // Scalar scan over the tail (and over the pair that broke the loop,
        // to locate the exact byte).
        forward_search(start_ptr, end_ptr, ptr, confirm)
    }
}
59 | |
/// Return the last index in `haystack` whose byte does NOT equal `n1`
/// (the inverse of `memrchr`), or `None` if every byte equals `n1`.
pub fn inv_memrchr(n1: u8, haystack: &[u8]) -> Option<usize> {
    // `n1` broadcast into every byte of a word: `word ^ vn1 == 0` iff every
    // byte of `word` equals `n1`.
    let vn1 = repeat_byte(n1);
    // Scalar predicate: true for bytes that do NOT match `n1`.
    let confirm = |byte| byte != n1;
    let loop_size = cmp::min(LOOP_SIZE, haystack.len());
    // Low-bit mask used to round pointers to word alignment.
    let align = USIZE_BYTES - 1;
    let start_ptr = haystack.as_ptr();

    unsafe {
        let end_ptr = haystack.as_ptr().add(haystack.len());
        // The cursor is an exclusive upper bound; the scan moves backward.
        let mut ptr = end_ptr;

        // Haystack too short for even one word read: pure byte-at-a-time scan.
        if haystack.len() < USIZE_BYTES {
            return reverse_search(start_ptr, end_ptr, ptr, confirm);
        }

        // Probe the last (possibly unaligned) word. A nonzero XOR means some
        // byte differs from `n1`, so the match is within this word; let the
        // scalar loop pinpoint it.
        let chunk = read_unaligned_usize(ptr.sub(USIZE_BYTES));
        if (chunk ^ vn1) != 0 {
            return reverse_search(start_ptr, end_ptr, ptr, confirm);
        }

        // Round the cursor down to a word boundary. The bytes being skipped
        // were just verified above to all equal `n1`.
        ptr = ptr.sub(end_ptr as usize & align);
        debug_assert!(start_ptr <= ptr && ptr <= end_ptr);
        // Aligned backward scan, two words per iteration; break on the first
        // pair that contains any non-`n1` byte.
        while loop_size == LOOP_SIZE && ptr >= start_ptr.add(loop_size) {
            debug_assert_eq!(0, (ptr as usize) % USIZE_BYTES);

            let a = *(ptr.sub(2 * USIZE_BYTES) as *const usize);
            let b = *(ptr.sub(1 * USIZE_BYTES) as *const usize);
            let eqa = (a ^ vn1) != 0;
            let eqb = (b ^ vn1) != 0;
            if eqa || eqb {
                break;
            }
            ptr = ptr.sub(loop_size);
        }
        // Scalar scan over the head (and over the pair that broke the loop,
        // to locate the exact byte).
        reverse_search(start_ptr, end_ptr, ptr, confirm)
    }
}
98 | |
99 | #[inline (always)] |
100 | unsafe fn forward_search<F: Fn(u8) -> bool>( |
101 | start_ptr: *const u8, |
102 | end_ptr: *const u8, |
103 | mut ptr: *const u8, |
104 | confirm: F, |
105 | ) -> Option<usize> { |
106 | debug_assert!(start_ptr <= ptr); |
107 | debug_assert!(ptr <= end_ptr); |
108 | |
109 | while ptr < end_ptr { |
110 | if confirm(*ptr) { |
111 | return Some(sub(a:ptr, b:start_ptr)); |
112 | } |
113 | ptr = ptr.offset(count:1); |
114 | } |
115 | None |
116 | } |
117 | |
118 | #[inline (always)] |
119 | unsafe fn reverse_search<F: Fn(u8) -> bool>( |
120 | start_ptr: *const u8, |
121 | end_ptr: *const u8, |
122 | mut ptr: *const u8, |
123 | confirm: F, |
124 | ) -> Option<usize> { |
125 | debug_assert!(start_ptr <= ptr); |
126 | debug_assert!(ptr <= end_ptr); |
127 | |
128 | while ptr > start_ptr { |
129 | ptr = ptr.offset(count:-1); |
130 | if confirm(*ptr) { |
131 | return Some(sub(a:ptr, b:start_ptr)); |
132 | } |
133 | } |
134 | None |
135 | } |
136 | |
137 | unsafe fn read_unaligned_usize(ptr: *const u8) -> usize { |
138 | (ptr as *const usize).read_unaligned() |
139 | } |
140 | |
/// Subtract `b` from `a` and return the difference. `a` should be greater
/// than or equal to `b`.
fn sub(a: *const u8, b: *const u8) -> usize {
    let (hi, lo) = (a as usize, b as usize);
    debug_assert!(hi >= lo);
    hi - lo
}
147 | |
148 | /// Safe wrapper around `forward_search` |
149 | #[inline ] |
150 | pub(crate) fn forward_search_bytes<F: Fn(u8) -> bool>( |
151 | s: &[u8], |
152 | confirm: F, |
153 | ) -> Option<usize> { |
154 | unsafe { |
155 | let start: *const u8 = s.as_ptr(); |
156 | let end: *const u8 = start.add(count:s.len()); |
157 | forward_search(start_ptr:start, end_ptr:end, ptr:start, confirm) |
158 | } |
159 | } |
160 | |
161 | /// Safe wrapper around `reverse_search` |
162 | #[inline ] |
163 | pub(crate) fn reverse_search_bytes<F: Fn(u8) -> bool>( |
164 | s: &[u8], |
165 | confirm: F, |
166 | ) -> Option<usize> { |
167 | unsafe { |
168 | let start: *const u8 = s.as_ptr(); |
169 | let end: *const u8 = start.add(count:s.len()); |
170 | reverse_search(start_ptr:start, end_ptr:end, ptr:end, confirm) |
171 | } |
172 | } |
173 | |
#[cfg(all(test, feature = "std"))]
mod tests {
    use super::{inv_memchr, inv_memrchr};

    // Search string, search byte, inv_memchr result, inv_memrchr result.
    // These are expanded into a much larger set of tests in build_tests.
    const TESTS: &[(&[u8], u8, usize, usize)] = &[
        (b"z", b'a', 0, 0),
        (b"zz", b'a', 0, 1),
        (b"aza", b'a', 1, 1),
        (b"zaz", b'a', 0, 2),
        (b"zza", b'a', 0, 1),
        (b"zaa", b'a', 0, 0),
        (b"zzz", b'a', 0, 2),
    ];

    type TestCase = (Vec<u8>, u8, Option<(usize, usize)>);

    // Expand TESTS by padding each case with runs of the search byte, which
    // shifts the expected positions in a predictable way and exercises the
    // word-at-a-time fast paths on many lengths and alignments.
    fn build_tests() -> Vec<TestCase> {
        #[cfg(not(miri))]
        const MAX_PER: usize = 515;
        #[cfg(miri)]
        const MAX_PER: usize = 10;

        let mut result = vec![];
        for &(search, byte, fwd_pos, rev_pos) in TESTS {
            result.push((search.to_vec(), byte, Some((fwd_pos, rev_pos))));
            for i in 1..MAX_PER {
                // Add a bunch of copies of the search byte to the end.
                let mut suffixed: Vec<u8> = search.into();
                suffixed.extend(std::iter::repeat(byte).take(i));
                result.push((suffixed, byte, Some((fwd_pos, rev_pos))));

                // Add a bunch of copies of the search byte to the start.
                let mut prefixed: Vec<u8> =
                    std::iter::repeat(byte).take(i).collect();
                prefixed.extend(search);
                result.push((
                    prefixed,
                    byte,
                    Some((fwd_pos + i, rev_pos + i)),
                ));

                // Add a bunch of copies of the search byte to both ends.
                let mut surrounded: Vec<u8> =
                    std::iter::repeat(byte).take(i).collect();
                surrounded.extend(search);
                surrounded.extend(std::iter::repeat(byte).take(i));
                result.push((
                    surrounded,
                    byte,
                    Some((fwd_pos + i, rev_pos + i)),
                ));
            }
        }

        // Build non-matching tests for several sizes.
        for i in 0..MAX_PER {
            result.push((
                std::iter::repeat(b'\0').take(i).collect(),
                b'\0',
                None,
            ));
        }

        result
    }

    #[test]
    fn test_inv_memchr() {
        use crate::{ByteSlice, B};

        #[cfg(not(miri))]
        const MAX_OFFSET: usize = 130;
        #[cfg(miri)]
        const MAX_OFFSET: usize = 13;

        for (search, byte, matching) in build_tests() {
            assert_eq!(
                inv_memchr(byte, &search),
                matching.map(|m| m.0),
                "inv_memchr when searching for {:?} in {:?}",
                byte as char,
                // Better printing.
                B(&search).as_bstr(),
            );
            assert_eq!(
                inv_memrchr(byte, &search),
                matching.map(|m| m.1),
                "inv_memrchr when searching for {:?} in {:?}",
                byte as char,
                // Better printing.
                B(&search).as_bstr(),
            );
            // Test a rather large number of offsets for potential alignment
            // issues.
            for offset in 1..MAX_OFFSET {
                if offset >= search.len() {
                    break;
                }
                // If this would cause us to shift the results off the end,
                // skip it so that we don't have to recompute them.
                if let Some((f, r)) = matching {
                    if offset > f || offset > r {
                        break;
                    }
                }
                let realigned = &search[offset..];

                let forward_pos = matching.map(|m| m.0 - offset);
                let reverse_pos = matching.map(|m| m.1 - offset);

                assert_eq!(
                    inv_memchr(byte, &realigned),
                    forward_pos,
                    "inv_memchr when searching (realigned by {}) for {:?} in {:?}",
                    offset,
                    byte as char,
                    realigned.as_bstr(),
                );
                assert_eq!(
                    inv_memrchr(byte, &realigned),
                    reverse_pos,
                    "inv_memrchr when searching (realigned by {}) for {:?} in {:?}",
                    offset,
                    byte as char,
                    realigned.as_bstr(),
                );
            }
        }
    }
}
306 | |