1// This is adapted from `fallback.rs` from rust-memchr. It's modified to return
2// the 'inverse' query of memchr, e.g. finding the first byte not in the
3// provided set. This is simple for the 1-byte case.
4
5use core::{cmp, usize};
6
/// The number of bytes in a `usize` on the target platform.
const USIZE_BYTES: usize = core::mem::size_of::<usize>();

// The number of bytes to loop at in one iteration of memchr/memrchr.
// Two words are examined per iteration of the main loops below.
const LOOP_SIZE: usize = 2 * USIZE_BYTES;
11
/// Broadcast `b` into every byte of a word. For example, if `b` is `\x4E`
/// (`01001110` in binary), the result on a 32-bit system is
/// `01001110_01001110_01001110_01001110`.
#[inline(always)]
fn repeat_byte(b: u8) -> usize {
    // `usize::MAX / 255` is the pattern 0x0101...01 (one in every byte);
    // multiplying by `b` therefore copies `b` into each byte position.
    let one_in_every_byte = usize::MAX / 255;
    one_in_every_byte * (b as usize)
}
20
/// Return the index of the first byte in `haystack` that is NOT equal to
/// `n1`, or `None` if every byte equals `n1`.
///
/// This is the "inverse" of `memchr`: it uses the same word-at-a-time
/// scanning strategy, but searches for the first *mismatching* byte.
pub fn inv_memchr(n1: u8, haystack: &[u8]) -> Option<usize> {
    // A word with `n1` in every byte; XOR against a chunk is zero iff the
    // whole chunk consists solely of `n1` bytes.
    let vn1 = repeat_byte(n1);
    let confirm = |byte| byte != n1;
    let loop_size = cmp::min(LOOP_SIZE, haystack.len());
    // Mask extracting a pointer's misalignment from a word boundary.
    let align = USIZE_BYTES - 1;
    let start_ptr = haystack.as_ptr();

    unsafe {
        let end_ptr = haystack.as_ptr().add(haystack.len());
        let mut ptr = start_ptr;

        // Too short for even one word-sized read: plain byte-wise scan.
        if haystack.len() < USIZE_BYTES {
            return forward_search(start_ptr, end_ptr, ptr, confirm);
        }

        // Probe the first (possibly unaligned) word. A nonzero XOR means a
        // mismatching byte lives within it, so the byte-wise scan starting
        // at `ptr` will find the answer immediately.
        let chunk = read_unaligned_usize(ptr);
        if (chunk ^ vn1) != 0 {
            return forward_search(start_ptr, end_ptr, ptr, confirm);
        }

        // Advance to the next word boundary. The bytes skipped over were
        // all equal to `n1` (verified by the unaligned probe above).
        ptr = ptr.add(USIZE_BYTES - (start_ptr as usize & align));
        debug_assert!(ptr > start_ptr);
        debug_assert!(end_ptr.sub(USIZE_BYTES) >= start_ptr);
        // Main loop: inspect two aligned words per iteration; break out as
        // soon as either contains a byte different from `n1`.
        while loop_size == LOOP_SIZE && ptr <= end_ptr.sub(loop_size) {
            debug_assert_eq!(0, (ptr as usize) % USIZE_BYTES);

            let a = *(ptr as *const usize);
            let b = *(ptr.add(USIZE_BYTES) as *const usize);
            // Nonzero XOR => at least one byte in the word differs.
            let eqa = (a ^ vn1) != 0;
            let eqb = (b ^ vn1) != 0;
            if eqa || eqb {
                break;
            }
            ptr = ptr.add(LOOP_SIZE);
        }
        // Byte-wise scan of the remainder (including the word pair that
        // triggered the break, to pin down the exact index).
        forward_search(start_ptr, end_ptr, ptr, confirm)
    }
}
59
/// Return the index of the last byte in `haystack` that is NOT equal to
/// `n1`, or `None` if every byte equals `n1`.
pub fn inv_memrchr(n1: u8, haystack: &[u8]) -> Option<usize> {
    // A word with `n1` in every byte; XOR against a chunk is zero iff the
    // whole chunk consists solely of `n1` bytes.
    let vn1 = repeat_byte(n1);
    let confirm = |byte| byte != n1;
    let loop_size = cmp::min(LOOP_SIZE, haystack.len());
    // Mask extracting a pointer's misalignment from a word boundary.
    let align = USIZE_BYTES - 1;
    let start_ptr = haystack.as_ptr();

    unsafe {
        let end_ptr = haystack.as_ptr().add(haystack.len());
        // `ptr` is an exclusive upper bound; the search walks backwards.
        let mut ptr = end_ptr;

        // Too short for even one word-sized read: plain byte-wise scan.
        if haystack.len() < USIZE_BYTES {
            return reverse_search(start_ptr, end_ptr, ptr, confirm);
        }

        // Probe the last (possibly unaligned) word. A nonzero XOR means a
        // mismatching byte lives within it, so the reverse byte-wise scan
        // from `ptr` will find the answer immediately.
        let chunk = read_unaligned_usize(ptr.sub(USIZE_BYTES));
        if (chunk ^ vn1) != 0 {
            return reverse_search(start_ptr, end_ptr, ptr, confirm);
        }

        // Round `ptr` down to a word boundary. The trailing bytes skipped
        // over were all equal to `n1` (verified by the probe above).
        ptr = ptr.sub(end_ptr as usize & align);
        debug_assert!(start_ptr <= ptr && ptr <= end_ptr);
        // Main loop: inspect the two aligned words preceding `ptr` per
        // iteration; break out as soon as either contains a byte different
        // from `n1`.
        while loop_size == LOOP_SIZE && ptr >= start_ptr.add(loop_size) {
            debug_assert_eq!(0, (ptr as usize) % USIZE_BYTES);

            let a = *(ptr.sub(2 * USIZE_BYTES) as *const usize);
            let b = *(ptr.sub(1 * USIZE_BYTES) as *const usize);
            // Nonzero XOR => at least one byte in the word differs.
            let eqa = (a ^ vn1) != 0;
            let eqb = (b ^ vn1) != 0;
            if eqa || eqb {
                break;
            }
            ptr = ptr.sub(loop_size);
        }
        // Reverse byte-wise scan of the remainder (including the word pair
        // that triggered the break, to pin down the exact index).
        reverse_search(start_ptr, end_ptr, ptr, confirm)
    }
}
98
99#[inline(always)]
100unsafe fn forward_search<F: Fn(u8) -> bool>(
101 start_ptr: *const u8,
102 end_ptr: *const u8,
103 mut ptr: *const u8,
104 confirm: F,
105) -> Option<usize> {
106 debug_assert!(start_ptr <= ptr);
107 debug_assert!(ptr <= end_ptr);
108
109 while ptr < end_ptr {
110 if confirm(*ptr) {
111 return Some(sub(a:ptr, b:start_ptr));
112 }
113 ptr = ptr.offset(count:1);
114 }
115 None
116}
117
118#[inline(always)]
119unsafe fn reverse_search<F: Fn(u8) -> bool>(
120 start_ptr: *const u8,
121 end_ptr: *const u8,
122 mut ptr: *const u8,
123 confirm: F,
124) -> Option<usize> {
125 debug_assert!(start_ptr <= ptr);
126 debug_assert!(ptr <= end_ptr);
127
128 while ptr > start_ptr {
129 ptr = ptr.offset(count:-1);
130 if confirm(*ptr) {
131 return Some(sub(a:ptr, b:start_ptr));
132 }
133 }
134 None
135}
136
/// Load `size_of::<usize>()` bytes starting at `ptr` as a native-endian
/// word. `ptr` need not be aligned.
///
/// # Safety
///
/// `ptr` must be valid for reading a full word's worth of bytes.
unsafe fn read_unaligned_usize(ptr: *const u8) -> usize {
    let mut word: usize = 0;
    // Byte-wise copy into `word`; equivalent to an unaligned word load.
    core::ptr::copy_nonoverlapping(
        ptr,
        &mut word as *mut usize as *mut u8,
        core::mem::size_of::<usize>(),
    );
    word
}
140
/// Distance in bytes from `b` up to `a`. The caller must ensure that `a`
/// is greater than or equal to `b`.
fn sub(a: *const u8, b: *const u8) -> usize {
    let (hi, lo) = (a as usize, b as usize);
    debug_assert!(hi >= lo);
    hi - lo
}
147
148/// Safe wrapper around `forward_search`
149#[inline]
150pub(crate) fn forward_search_bytes<F: Fn(u8) -> bool>(
151 s: &[u8],
152 confirm: F,
153) -> Option<usize> {
154 unsafe {
155 let start: *const u8 = s.as_ptr();
156 let end: *const u8 = start.add(count:s.len());
157 forward_search(start_ptr:start, end_ptr:end, ptr:start, confirm)
158 }
159}
160
161/// Safe wrapper around `reverse_search`
162#[inline]
163pub(crate) fn reverse_search_bytes<F: Fn(u8) -> bool>(
164 s: &[u8],
165 confirm: F,
166) -> Option<usize> {
167 unsafe {
168 let start: *const u8 = s.as_ptr();
169 let end: *const u8 = start.add(count:s.len());
170 reverse_search(start_ptr:start, end_ptr:end, ptr:end, confirm)
171 }
172}
173
#[cfg(all(test, feature = "std"))]
mod tests {
    use super::{inv_memchr, inv_memrchr};

    // Base cases: (haystack, needle byte, expected inv_memchr position,
    // expected inv_memrchr position). `build_tests` expands each entry
    // into a much larger set of padded variants.
    const TESTS: &[(&[u8], u8, usize, usize)] = &[
        (b"z", b'a', 0, 0),
        (b"zz", b'a', 0, 1),
        (b"aza", b'a', 1, 1),
        (b"zaz", b'a', 0, 2),
        (b"zza", b'a', 0, 1),
        (b"zaa", b'a', 0, 0),
        (b"zzz", b'a', 0, 2),
    ];

    type TestCase = (Vec<u8>, u8, Option<(usize, usize)>);

    fn build_tests() -> Vec<TestCase> {
        #[cfg(not(miri))]
        const MAX_PER: usize = 515;
        #[cfg(miri)]
        const MAX_PER: usize = 10;

        let mut cases = vec![];
        for &(search, byte, fwd_pos, rev_pos) in TESTS {
            cases.push((search.to_vec(), byte, Some((fwd_pos, rev_pos))));
            for pad in 1..MAX_PER {
                let padding = vec![byte; pad];

                // Needle bytes appended after the haystack: expected
                // positions are unchanged.
                let mut suffixed = search.to_vec();
                suffixed.extend_from_slice(&padding);
                cases.push((suffixed, byte, Some((fwd_pos, rev_pos))));

                // Needle bytes prepended before the haystack: expected
                // positions shift right by the pad length.
                let mut prefixed = padding.clone();
                prefixed.extend_from_slice(search);
                cases.push((
                    prefixed,
                    byte,
                    Some((fwd_pos + pad, rev_pos + pad)),
                ));

                // Needle bytes on both sides: same shift as the prefix case.
                let mut surrounded = padding.clone();
                surrounded.extend_from_slice(search);
                surrounded.extend_from_slice(&padding);
                cases.push((
                    surrounded,
                    byte,
                    Some((fwd_pos + pad, rev_pos + pad)),
                ));
            }
        }

        // Haystacks consisting solely of the needle byte never match, at
        // several lengths (including empty).
        for len in 0..MAX_PER {
            cases.push((vec![b'\0'; len], b'\0', None));
        }

        cases
    }

    #[test]
    fn test_inv_memchr() {
        use crate::{ByteSlice, B};

        #[cfg(not(miri))]
        const MAX_OFFSET: usize = 130;
        #[cfg(miri)]
        const MAX_OFFSET: usize = 13;

        for (search, byte, matching) in build_tests() {
            assert_eq!(
                inv_memchr(byte, &search),
                matching.map(|m| m.0),
                "inv_memchr when searching for {:?} in {:?}",
                byte as char,
                // better printing
                B(&search).as_bstr(),
            );
            assert_eq!(
                inv_memrchr(byte, &search),
                matching.map(|m| m.1),
                "inv_memrchr when searching for {:?} in {:?}",
                byte as char,
                // better printing
                B(&search).as_bstr(),
            );
            // Re-run both searches from many different starting offsets to
            // exercise potential alignment issues.
            for offset in 1..MAX_OFFSET {
                if offset >= search.len() {
                    break;
                }
                // Stop before an offset would shift an expected match off
                // the front, so the expectations below stay valid without
                // recomputation.
                let shifted_off = matching
                    .map_or(false, |(f, r)| offset > f || offset > r);
                if shifted_off {
                    break;
                }
                let realigned = &search[offset..];

                let forward_pos = matching.map(|(f, _)| f - offset);
                let reverse_pos = matching.map(|(_, r)| r - offset);

                assert_eq!(
                    inv_memchr(byte, realigned),
                    forward_pos,
                    "inv_memchr when searching (realigned by {}) for {:?} in {:?}",
                    offset,
                    byte as char,
                    realigned.as_bstr(),
                );
                assert_eq!(
                    inv_memrchr(byte, realigned),
                    reverse_pos,
                    "inv_memrchr when searching (realigned by {}) for {:?} in {:?}",
                    offset,
                    byte as char,
                    realigned.as_bstr(),
                );
            }
        }
    }
}
306