// This is adapted from `fallback.rs` from rust-memchr. It's modified to return
// the 'inverse' query of memchr, e.g. finding the first byte not in the
// provided set. This is simple for the 1-byte case.

use core::{cmp, usize};

const USIZE_BYTES: usize = core::mem::size_of::<usize>();
const ALIGN_MASK: usize = core::mem::align_of::<usize>() - 1;

// The number of bytes to loop at in one iteration of memchr/memrchr.
const LOOP_SIZE: usize = 2 * USIZE_BYTES;

/// Repeat the given byte into a word size number. That is, every 8 bits
/// is equivalent to the given byte. For example, if `b` is `\x4E` or
/// `01001110` in binary, then the returned value on a 32-bit system would be:
/// `01001110_01001110_01001110_01001110`.
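///
/// This works because `usize::MAX / 255` is the constant with every byte set
/// to `0x01` (e.g. `0x01010101` on a 32-bit system), so multiplying it by `b`
/// broadcasts `b` into every byte of the word.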
#[inline(always)]
fn repeat_byte(b: u8) -> usize {
    (b as usize) * (usize::MAX / 255)
}

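/// Return the first index not matching the byte `n1` in `haystack`.
///
/// A sketch of the expected behavior (marked `ignore` since this function may
/// not be reachable from a public path in a doctest):
///
/// ```ignore
/// assert_eq!(inv_memchr(b'a', b"aaz"), Some(2));
/// assert_eq!(inv_memchr(b'a', b"aaa"), None);
/// ```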
pub fn inv_memchr(n1: u8, haystack: &[u8]) -> Option<usize> {
    let vn1 = repeat_byte(n1);
    let confirm = |byte| byte != n1;
    let loop_size = cmp::min(LOOP_SIZE, haystack.len());
    let start_ptr = haystack.as_ptr();

    unsafe {
        let end_ptr = haystack.as_ptr().add(haystack.len());
        let mut ptr = start_ptr;

        if haystack.len() < USIZE_BYTES {
            return forward_search(start_ptr, end_ptr, ptr, confirm);
        }

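        // XOR the first word of the haystack against the broadcast needle.
        // Any byte equal to `n1` becomes zero, so a nonzero result means at
        // least one byte in this word differs from `n1`.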
        let chunk = read_unaligned_usize(ptr);
        if (chunk ^ vn1) != 0 {
            return forward_search(start_ptr, end_ptr, ptr, confirm);
        }

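        // Advance to the next word-aligned address. Any bytes skipped here
        // were already covered by the unaligned read above, since the bump is
        // at most `USIZE_BYTES`.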
        ptr = ptr.add(USIZE_BYTES - (start_ptr as usize & ALIGN_MASK));
        debug_assert!(ptr > start_ptr);
        debug_assert!(end_ptr.sub(USIZE_BYTES) >= start_ptr);
        while loop_size == LOOP_SIZE && ptr <= end_ptr.sub(loop_size) {
            debug_assert_eq!(0, (ptr as usize) % USIZE_BYTES);

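            // Read two aligned words per iteration. If either contains a
            // byte that differs from `n1` (nonzero XOR), fall through to the
            // byte-at-a-time search below to pin down its exact position.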
            let a = *(ptr as *const usize);
            let b = *(ptr.add(USIZE_BYTES) as *const usize);
            let eqa = (a ^ vn1) != 0;
            let eqb = (b ^ vn1) != 0;
            if eqa || eqb {
                break;
            }
            ptr = ptr.add(LOOP_SIZE);
        }
        forward_search(start_ptr, end_ptr, ptr, confirm)
    }
}

/// Return the last index not matching the byte `n1` in `haystack`.
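///
/// A sketch of the expected behavior (marked `ignore` for the same reason as
/// the `inv_memchr` example above):
///
/// ```ignore
/// assert_eq!(inv_memrchr(b'a', b"zaz"), Some(2));
/// assert_eq!(inv_memrchr(b'a', b"aaa"), None);
/// ```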
pub fn inv_memrchr(n1: u8, haystack: &[u8]) -> Option<usize> {
    let vn1 = repeat_byte(n1);
    let confirm = |byte| byte != n1;
    let loop_size = cmp::min(LOOP_SIZE, haystack.len());
    let start_ptr = haystack.as_ptr();

    unsafe {
        let end_ptr = haystack.as_ptr().add(haystack.len());
        let mut ptr = end_ptr;

        if haystack.len() < USIZE_BYTES {
            return reverse_search(start_ptr, end_ptr, ptr, confirm);
        }

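        // Check the last word of the haystack with an unaligned read. A
        // nonzero XOR against the broadcast needle means some byte in this
        // word differs from `n1`.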
        let chunk = read_unaligned_usize(ptr.sub(USIZE_BYTES));
        if (chunk ^ vn1) != 0 {
            return reverse_search(start_ptr, end_ptr, ptr, confirm);
        }

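        // Move `ptr` down to a word-aligned address. The tail bytes skipped
        // here were already covered by the unaligned read above.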
        ptr = ptr.sub(end_ptr as usize & ALIGN_MASK);
        debug_assert!(start_ptr <= ptr && ptr <= end_ptr);
        while loop_size == LOOP_SIZE && ptr >= start_ptr.add(loop_size) {
            debug_assert_eq!(0, (ptr as usize) % USIZE_BYTES);

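            // Read the two aligned words just below `ptr`. If either
            // contains a byte that differs from `n1`, finish with the
            // byte-at-a-time reverse search below.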
            let a = *(ptr.sub(2 * USIZE_BYTES) as *const usize);
            let b = *(ptr.sub(1 * USIZE_BYTES) as *const usize);
            let eqa = (a ^ vn1) != 0;
            let eqb = (b ^ vn1) != 0;
            if eqa || eqb {
                break;
            }
            ptr = ptr.sub(loop_size);
        }
        reverse_search(start_ptr, end_ptr, ptr, confirm)
    }
}

#[inline(always)]
unsafe fn forward_search<F: Fn(u8) -> bool>(
    start_ptr: *const u8,
    end_ptr: *const u8,
    mut ptr: *const u8,
    confirm: F,
) -> Option<usize> {
    debug_assert!(start_ptr <= ptr);
    debug_assert!(ptr <= end_ptr);

    while ptr < end_ptr {
        if confirm(*ptr) {
            return Some(sub(ptr, start_ptr));
        }
        ptr = ptr.offset(1);
    }
    None
}

#[inline(always)]
unsafe fn reverse_search<F: Fn(u8) -> bool>(
    start_ptr: *const u8,
    end_ptr: *const u8,
    mut ptr: *const u8,
    confirm: F,
) -> Option<usize> {
    debug_assert!(start_ptr <= ptr);
    debug_assert!(ptr <= end_ptr);

    while ptr > start_ptr {
        ptr = ptr.offset(-1);
        if confirm(*ptr) {
            return Some(sub(ptr, start_ptr));
        }
    }
    None
}

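/// Read a `usize` from `ptr` without requiring alignment.
///
/// Callers must ensure `ptr` is valid for reading `USIZE_BYTES` bytes.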
unsafe fn read_unaligned_usize(ptr: *const u8) -> usize {
    (ptr as *const usize).read_unaligned()
}

/// Subtract `b` from `a` and return the difference. `a` should be greater than
/// or equal to `b`.
fn sub(a: *const u8, b: *const u8) -> usize {
    debug_assert!(a >= b);
    (a as usize) - (b as usize)
}

/// Safe wrapper around `forward_search`
#[inline]
pub(crate) fn forward_search_bytes<F: Fn(u8) -> bool>(
    s: &[u8],
    confirm: F,
) -> Option<usize> {
    unsafe {
        let start: *const u8 = s.as_ptr();
        let end: *const u8 = start.add(s.len());
        forward_search(start, end, start, confirm)
    }
}

/// Safe wrapper around `reverse_search`
#[inline]
pub(crate) fn reverse_search_bytes<F: Fn(u8) -> bool>(
    s: &[u8],
    confirm: F,
) -> Option<usize> {
    unsafe {
        let start: *const u8 = s.as_ptr();
        let end: *const u8 = start.add(s.len());
        reverse_search(start, end, end, confirm)
    }
}

#[cfg(all(test, feature = "std"))]
mod tests {
    use alloc::{vec, vec::Vec};

    use super::{inv_memchr, inv_memrchr};

    // search string, search byte, inv_memchr result, inv_memrchr result.
    // these are expanded into a much larger set of tests in `build_tests`.
    const TESTS: &[(&[u8], u8, usize, usize)] = &[
        (b"z", b'a', 0, 0),
        (b"zz", b'a', 0, 1),
        (b"aza", b'a', 1, 1),
        (b"zaz", b'a', 0, 2),
        (b"zza", b'a', 0, 1),
        (b"zaa", b'a', 0, 0),
        (b"zzz", b'a', 0, 2),
    ];

    type TestCase = (Vec<u8>, u8, Option<(usize, usize)>);

    fn build_tests() -> Vec<TestCase> {
        #[cfg(not(miri))]
        const MAX_PER: usize = 515;
        #[cfg(miri)]
        const MAX_PER: usize = 10;

        let mut result = vec![];
        for &(search, byte, fwd_pos, rev_pos) in TESTS {
            result.push((search.to_vec(), byte, Some((fwd_pos, rev_pos))));
            for i in 1..MAX_PER {
                // add a bunch of copies of the search byte to the end.
                let mut suffixed: Vec<u8> = search.into();
                suffixed.extend(std::iter::repeat(byte).take(i));
                result.push((suffixed, byte, Some((fwd_pos, rev_pos))));

                // add a bunch of copies of the search byte to the start.
                let mut prefixed: Vec<u8> =
                    std::iter::repeat(byte).take(i).collect();
                prefixed.extend(search);
                result.push((
                    prefixed,
                    byte,
                    Some((fwd_pos + i, rev_pos + i)),
                ));

                // add a bunch of copies of the search byte to both ends.
                let mut surrounded: Vec<u8> =
                    std::iter::repeat(byte).take(i).collect();
                surrounded.extend(search);
                surrounded.extend(std::iter::repeat(byte).take(i));
                result.push((
                    surrounded,
                    byte,
                    Some((fwd_pos + i, rev_pos + i)),
                ));
            }
        }

        // build non-matching tests for several sizes
        for i in 0..MAX_PER {
            result.push((
                std::iter::repeat(b'\0').take(i).collect(),
                b'\0',
                None,
            ));
        }

        result
    }

    #[test]
    fn test_inv_memchr() {
        use crate::{ByteSlice, B};

        #[cfg(not(miri))]
        const MAX_OFFSET: usize = 130;
        #[cfg(miri)]
        const MAX_OFFSET: usize = 13;

        for (search, byte, matching) in build_tests() {
            assert_eq!(
                inv_memchr(byte, &search),
                matching.map(|m| m.0),
                "inv_memchr when searching for {:?} in {:?}",
                byte as char,
                // better printing
                B(&search).as_bstr(),
            );
            assert_eq!(
                inv_memrchr(byte, &search),
                matching.map(|m| m.1),
                "inv_memrchr when searching for {:?} in {:?}",
                byte as char,
                // better printing
                B(&search).as_bstr(),
            );
            // Test a rather large number of offsets for potential alignment
            // issues.
            for offset in 1..MAX_OFFSET {
                if offset >= search.len() {
                    break;
                }
                // If this would cause us to shift the results off the end,
                // skip it so that we don't have to recompute them.
                if let Some((f, r)) = matching {
                    if offset > f || offset > r {
                        break;
                    }
                }
                let realigned = &search[offset..];

                let forward_pos = matching.map(|m| m.0 - offset);
                let reverse_pos = matching.map(|m| m.1 - offset);

                assert_eq!(
                    inv_memchr(byte, &realigned),
                    forward_pos,
                    "inv_memchr when searching (realigned by {}) for {:?} in {:?}",
                    offset,
                    byte as char,
                    realigned.as_bstr(),
                );
                assert_eq!(
                    inv_memrchr(byte, &realigned),
                    reverse_pos,
                    "inv_memrchr when searching (realigned by {}) for {:?} in {:?}",
                    offset,
                    byte as char,
                    realigned.as_bstr(),
                );
            }
        }
    }
}
307 | |