// The following ~400 lines of code exist for exactly one purpose, which is
// to optimize this code:
//
//     byte_slice.iter().position(|&b| b > 0x7F).unwrap_or(byte_slice.len())
//
// Yes... Overengineered is a word that comes to mind, but this is effectively
// a very similar problem to memchr, and virtually nobody has been able to
// resist optimizing the crap out of that (except for perhaps the BSD and MUSL
// folks). In particular, this routine makes a very common case (ASCII) very
// fast, which seems worth it. We do stop short of adding AVX variants of the
// code below in order to retain our sanity and also to avoid needing to deal
// with runtime target feature detection. RESIST!
//
// In order to understand the SIMD version below, it would be good to read
// this comment describing how my memchr routine works:
// https://github.com/BurntSushi/rust-memchr/blob/b0a29f267f4a7fad8ffcc8fe8377a06498202883/src/x86/sse2.rs#L19-L106
//
// The primary difference with memchr is that for ASCII, we can do a bit less
// work. In particular, we don't need to detect the presence of a specific
// byte, but rather, whether any byte has its most significant bit set. That
// means we can effectively skip the _mm_cmpeq_epi8 step and jump straight to
// _mm_movemask_epi8.
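//
// As a tiny worked example of the scalar analogue: if the eight bytes
// [b'a', b'b', b'c', 0xE2, 0, 0, 0, 0] are read as a little-endian usize
// `chunk`, then `chunk & ASCII_MASK` leaves only the high bit of the byte at
// index 3 set (0xE2 has its most significant bit set), so
// `trailing_zeros() / 8 == 3` recovers the index of the first non-ASCII
// byte. The SSE2 version below applies the same idea 16 bytes at a time.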

#[cfg(any(test, miri, not(target_arch = "x86_64")))]
const USIZE_BYTES: usize = core::mem::size_of::<usize>();
#[cfg(any(test, miri, not(target_arch = "x86_64")))]
const FALLBACK_LOOP_SIZE: usize = 2 * USIZE_BYTES;

// This is a mask where the most significant bit of each byte in the usize
// is set. We test this bit to determine whether a character is ASCII or not.
// Namely, a single byte is regarded as an ASCII codepoint if and only if its
// most significant bit is not set.
#[cfg(any(test, miri, not(target_arch = "x86_64")))]
const ASCII_MASK_U64: u64 = 0x8080808080808080;
#[cfg(any(test, miri, not(target_arch = "x86_64")))]
const ASCII_MASK: usize = ASCII_MASK_U64 as usize;

/// Returns the index of the first non-ASCII byte in the given slice.
///
/// If the slice only contains ASCII bytes, then the length of the slice is
/// returned.
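///
/// # Examples
///
/// A sketch of the intended behavior (the example is `ignore`d so it is not
/// run as a doctest):
///
/// ```ignore
/// // The snowman ☃ encodes as the bytes 0xE2 0x98 0x83, so the first
/// // non-ASCII byte of "abc☃" is at index 3.
/// assert_eq!(3, first_non_ascii_byte("abc☃".as_bytes()));
/// // An all-ASCII slice returns its own length.
/// assert_eq!(3, first_non_ascii_byte(b"abc"));
/// ```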
pub fn first_non_ascii_byte(slice: &[u8]) -> usize {
    #[cfg(any(miri, not(target_arch = "x86_64")))]
    {
        first_non_ascii_byte_fallback(slice)
    }

    #[cfg(all(not(miri), target_arch = "x86_64"))]
    {
        first_non_ascii_byte_sse2(slice)
    }
}

#[cfg(any(test, miri, not(target_arch = "x86_64")))]
fn first_non_ascii_byte_fallback(slice: &[u8]) -> usize {
    let align = USIZE_BYTES - 1;
    let start_ptr = slice.as_ptr();
    let end_ptr = slice[slice.len()..].as_ptr();
    let mut ptr = start_ptr;

    unsafe {
        if slice.len() < USIZE_BYTES {
            return first_non_ascii_byte_slow(start_ptr, end_ptr, ptr);
        }

        let chunk = read_unaligned_usize(ptr);
        let mask = chunk & ASCII_MASK;
        if mask != 0 {
            return first_non_ascii_byte_mask(mask);
        }

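        // Round `ptr` up to the next usize-aligned address (a full word
        // forward if it is already aligned). The bytes skipped over here
        // were already covered by the unaligned read above.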
        ptr = ptr_add(ptr, USIZE_BYTES - (start_ptr as usize & align));
        debug_assert!(ptr > start_ptr);
        debug_assert!(ptr_sub(end_ptr, USIZE_BYTES) >= start_ptr);
        if slice.len() >= FALLBACK_LOOP_SIZE {
            while ptr <= ptr_sub(end_ptr, FALLBACK_LOOP_SIZE) {
                debug_assert_eq!(0, (ptr as usize) % USIZE_BYTES);

                let a = *(ptr as *const usize);
                let b = *(ptr_add(ptr, USIZE_BYTES) as *const usize);
                if (a | b) & ASCII_MASK != 0 {
                    // What a kludge. We wrap the position finding code into
                    // a non-inlineable function, which makes the codegen in
                    // the tight loop above a bit better by avoiding a
                    // couple extra movs. We pay for it by two additional
                    // stores, but only in the case of finding a non-ASCII
                    // byte.
                    #[inline(never)]
                    unsafe fn findpos(
                        start_ptr: *const u8,
                        ptr: *const u8,
                    ) -> usize {
                        let a = *(ptr as *const usize);
                        let b = *(ptr_add(ptr, USIZE_BYTES) as *const usize);

                        let mut at = sub(ptr, start_ptr);
                        let maska = a & ASCII_MASK;
                        if maska != 0 {
                            return at + first_non_ascii_byte_mask(maska);
                        }

                        at += USIZE_BYTES;
                        let maskb = b & ASCII_MASK;
                        debug_assert!(maskb != 0);
                        return at + first_non_ascii_byte_mask(maskb);
                    }
                    return findpos(start_ptr, ptr);
                }
                ptr = ptr_add(ptr, FALLBACK_LOOP_SIZE);
            }
        }
        first_non_ascii_byte_slow(start_ptr, end_ptr, ptr)
    }
}

#[cfg(all(not(miri), target_arch = "x86_64"))]
fn first_non_ascii_byte_sse2(slice: &[u8]) -> usize {
    use core::arch::x86_64::*;

    const VECTOR_SIZE: usize = core::mem::size_of::<__m128i>();
    const VECTOR_ALIGN: usize = VECTOR_SIZE - 1;
    const VECTOR_LOOP_SIZE: usize = 4 * VECTOR_SIZE;

    let start_ptr = slice.as_ptr();
    let end_ptr = slice[slice.len()..].as_ptr();
    let mut ptr = start_ptr;

    unsafe {
        if slice.len() < VECTOR_SIZE {
            return first_non_ascii_byte_slow(start_ptr, end_ptr, ptr);
        }

        let chunk = _mm_loadu_si128(ptr as *const __m128i);
        let mask = _mm_movemask_epi8(chunk);
        if mask != 0 {
            return mask.trailing_zeros() as usize;
        }

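        // Round `ptr` up to the next 16-byte aligned address (a full vector
        // forward if it is already aligned). The bytes skipped over here
        // were already covered by the unaligned load above.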
        ptr = ptr.add(VECTOR_SIZE - (start_ptr as usize & VECTOR_ALIGN));
        debug_assert!(ptr > start_ptr);
        debug_assert!(end_ptr.sub(VECTOR_SIZE) >= start_ptr);
        if slice.len() >= VECTOR_LOOP_SIZE {
            while ptr <= ptr_sub(end_ptr, VECTOR_LOOP_SIZE) {
                debug_assert_eq!(0, (ptr as usize) % VECTOR_SIZE);

                let a = _mm_load_si128(ptr as *const __m128i);
                let b = _mm_load_si128(ptr.add(VECTOR_SIZE) as *const __m128i);
                let c =
                    _mm_load_si128(ptr.add(2 * VECTOR_SIZE) as *const __m128i);
                let d =
                    _mm_load_si128(ptr.add(3 * VECTOR_SIZE) as *const __m128i);

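                // OR the four vectors together: if any byte in any of them
                // has its most significant bit set, the combined vector has
                // a non-zero movemask, and only then do we pay to figure out
                // exactly which vector and byte it was.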
                let or1 = _mm_or_si128(a, b);
                let or2 = _mm_or_si128(c, d);
                let or3 = _mm_or_si128(or1, or2);
                if _mm_movemask_epi8(or3) != 0 {
                    let mut at = sub(ptr, start_ptr);
                    let mask = _mm_movemask_epi8(a);
                    if mask != 0 {
                        return at + mask.trailing_zeros() as usize;
                    }

                    at += VECTOR_SIZE;
                    let mask = _mm_movemask_epi8(b);
                    if mask != 0 {
                        return at + mask.trailing_zeros() as usize;
                    }

                    at += VECTOR_SIZE;
                    let mask = _mm_movemask_epi8(c);
                    if mask != 0 {
                        return at + mask.trailing_zeros() as usize;
                    }

                    at += VECTOR_SIZE;
                    let mask = _mm_movemask_epi8(d);
                    debug_assert!(mask != 0);
                    return at + mask.trailing_zeros() as usize;
                }
                ptr = ptr_add(ptr, VECTOR_LOOP_SIZE);
            }
        }
        while ptr <= end_ptr.sub(VECTOR_SIZE) {
            debug_assert!(sub(end_ptr, ptr) >= VECTOR_SIZE);

            let chunk = _mm_loadu_si128(ptr as *const __m128i);
            let mask = _mm_movemask_epi8(chunk);
            if mask != 0 {
                return sub(ptr, start_ptr) + mask.trailing_zeros() as usize;
            }
            ptr = ptr.add(VECTOR_SIZE);
        }
        first_non_ascii_byte_slow(start_ptr, end_ptr, ptr)
    }
}

#[inline(always)]
unsafe fn first_non_ascii_byte_slow(
    start_ptr: *const u8,
    end_ptr: *const u8,
    mut ptr: *const u8,
) -> usize {
    debug_assert!(start_ptr <= ptr);
    debug_assert!(ptr <= end_ptr);

    while ptr < end_ptr {
        if *ptr > 0x7F {
            return sub(ptr, start_ptr);
        }
        ptr = ptr.offset(1);
    }
    sub(end_ptr, start_ptr)
}

/// Compute the position of the first non-ASCII byte in the given mask.
///
/// The mask should be computed by `chunk & ASCII_MASK`, where `chunk` is
/// 8 contiguous bytes of the slice being checked where *at least* one of
/// those bytes is not an ASCII byte.
///
/// The position returned is always in the inclusive range [0, 7].
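///
/// For example, if the byte at index 3 of an 8-byte chunk is the only
/// non-ASCII byte, then on a little-endian target its high bit becomes bit 31
/// of the mask and `trailing_zeros() / 8` is 3, while on a big-endian target
/// that same bit is bit 39 (counting from the least significant end) and
/// `leading_zeros() / 8` is likewise 3.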
#[cfg(any(test, miri, not(target_arch = "x86_64")))]
fn first_non_ascii_byte_mask(mask: usize) -> usize {
    #[cfg(target_endian = "little")]
    {
        mask.trailing_zeros() as usize / 8
    }
    #[cfg(target_endian = "big")]
    {
        mask.leading_zeros() as usize / 8
    }
}

/// Increment the given pointer by the given amount.
unsafe fn ptr_add(ptr: *const u8, amt: usize) -> *const u8 {
    debug_assert!(amt < ::core::isize::MAX as usize);
    ptr.offset(amt as isize)
}

/// Decrement the given pointer by the given amount.
unsafe fn ptr_sub(ptr: *const u8, amt: usize) -> *const u8 {
    debug_assert!(amt < ::core::isize::MAX as usize);
    ptr.offset((amt as isize).wrapping_neg())
}

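/// Read a `usize` from `ptr` without requiring any particular alignment, by
/// copying the bytes into a zero-initialized local.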
#[cfg(any(test, miri, not(target_arch = "x86_64")))]
unsafe fn read_unaligned_usize(ptr: *const u8) -> usize {
    use core::ptr;

    let mut n: usize = 0;
    ptr::copy_nonoverlapping(ptr, &mut n as *mut _ as *mut u8, USIZE_BYTES);
    n
}

/// Subtract `b` from `a` and return the difference. `a` should be greater
/// than or equal to `b`.
fn sub(a: *const u8, b: *const u8) -> usize {
    debug_assert!(a >= b);
    (a as usize) - (b as usize)
}

#[cfg(test)]
mod tests {
    use super::*;

    // Our testing approach here is to try and exhaustively test every case.
    // This includes the position at which a non-ASCII byte occurs in addition
    // to the alignment of the slice that we're searching.

    #[test]
    fn positive_fallback_forward() {
        for i in 0..517 {
            let s = "a".repeat(i);
            assert_eq!(
                i,
                first_non_ascii_byte_fallback(s.as_bytes()),
                "i: {:?}, len: {:?}, s: {:?}",
                i,
                s.len(),
                s
            );
        }
    }

    #[test]
    #[cfg(target_arch = "x86_64")]
    #[cfg(not(miri))]
    fn positive_sse2_forward() {
        for i in 0..517 {
            let b = "a".repeat(i).into_bytes();
            assert_eq!(b.len(), first_non_ascii_byte_sse2(&b));
        }
    }

    #[test]
    #[cfg(not(miri))]
    fn negative_fallback_forward() {
        for i in 0..517 {
            for align in 0..65 {
                let mut s = "a".repeat(i);
                s.push_str("☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃");
                let s = s.get(align..).unwrap_or("");
                assert_eq!(
                    i.saturating_sub(align),
                    first_non_ascii_byte_fallback(s.as_bytes()),
                    "i: {:?}, align: {:?}, len: {:?}, s: {:?}",
                    i,
                    align,
                    s.len(),
                    s
                );
            }
        }
    }

    #[test]
    #[cfg(target_arch = "x86_64")]
    #[cfg(not(miri))]
    fn negative_sse2_forward() {
        for i in 0..517 {
            for align in 0..65 {
                let mut s = "a".repeat(i);
                s.push_str("☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃");
                let s = s.get(align..).unwrap_or("");
                assert_eq!(
                    i.saturating_sub(align),
                    first_non_ascii_byte_sse2(s.as_bytes()),
                    "i: {:?}, align: {:?}, len: {:?}, s: {:?}",
                    i,
                    align,
                    s.len(),
                    s
                );
            }
        }
    }
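
    // A lighter-weight sanity check: both entry points should agree with the
    // naive definition from the comment at the top of this file. The lengths
    // and the 0xE2 filler byte used here are arbitrary choices.
    #[test]
    #[cfg(not(miri))]
    fn agrees_with_naive_definition() {
        for i in 0..130 {
            let mut bytes = "a".repeat(i).into_bytes();
            bytes.push(0xE2);
            bytes.extend_from_slice(b"bbbbbbb");
            let expected = bytes
                .iter()
                .position(|&b| b > 0x7F)
                .unwrap_or(bytes.len());
            assert_eq!(expected, first_non_ascii_byte(&bytes));
            assert_eq!(expected, first_non_ascii_byte_fallback(&bytes));
        }
    }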
}