// The following ~400 lines of code exist for exactly one purpose, which is
// to optimize this code:
//
//     byte_slice.iter().position(|&b| b > 0x7F).unwrap_or(byte_slice.len())
//
// Yes... Overengineered is a word that comes to mind, but this is effectively
// a very similar problem to memchr, and virtually nobody has been able to
// resist optimizing the crap out of that (except for perhaps the BSD and musl
// folks). In particular, this routine makes a very common case (ASCII) very
// fast, which seems worth it. We do stop short of adding AVX variants of the
// code below in order to retain our sanity and also to avoid needing to deal
// with runtime target feature detection. RESIST!
//
// In order to understand the SIMD version below, it would be good to read this
// comment describing how my memchr routine works:
// https://github.com/BurntSushi/rust-memchr/blob/b0a29f267f4a7fad8ffcc8fe8377a06498202883/src/x86/sse2.rs#L19-L106
//
// The primary difference with memchr is that for ASCII, we can do a bit less
// work. In particular, we don't need to detect the presence of a specific
// byte, but rather, whether any byte has its most significant bit set. That
// means we can effectively skip the _mm_cmpeq_epi8 step and jump straight to
// _mm_movemask_epi8.
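//
// As a small illustration of that last point (using a hypothetical 16-byte
// chunk): _mm_movemask_epi8 gathers the most significant bit of each of the
// 16 bytes into a 16-bit integer, so loading b"0123456789abcde\xFF" and
// taking its movemask yields 0b1000_0000_0000_0000, whose trailing_zeros()
// is 15, the index of the lone non-ASCII byte.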

#[cfg(any(test, miri, not(target_arch = "x86_64")))]
const USIZE_BYTES: usize = core::mem::size_of::<usize>();
#[cfg(any(test, miri, not(target_arch = "x86_64")))]
const FALLBACK_LOOP_SIZE: usize = 2 * USIZE_BYTES;

// This is a mask where the most significant bit of each byte in the usize
// is set. We test this bit to determine whether a character is ASCII or not.
// Namely, a single byte is regarded as an ASCII codepoint if and only if its
// most significant bit is not set.
#[cfg(any(test, miri, not(target_arch = "x86_64")))]
const ASCII_MASK_U64: u64 = 0x8080808080808080;
#[cfg(any(test, miri, not(target_arch = "x86_64")))]
const ASCII_MASK: usize = ASCII_MASK_U64 as usize;
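
// As a worked example of how ASCII_MASK is used (on a little-endian target,
// with a hypothetical input): reading the 8 bytes b"ab\xC3\xA9cdef" as a
// usize gives 0x6665_6463_A9C3_6261, AND-ing with ASCII_MASK leaves
// 0x0000_0000_8080_0000, and trailing_zeros() / 8 == 2, which is the index
// of the first non-ASCII byte (0xC3). See first_non_ascii_byte_mask below.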

/// Returns the index of the first non-ASCII byte in the given slice.
///
/// If the slice contains only ASCII bytes, then the length of the slice is
/// returned.
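///
/// # Examples
///
/// A small usage sketch (marked `ignore` here since the path under which this
/// function is exported by the enclosing crate isn't shown in this module):
///
/// ```ignore
/// // Every byte is ASCII, so the slice length is returned.
/// assert_eq!(3, first_non_ascii_byte(b"abc"));
/// // The snowman (U+2603) is encoded as the three non-ASCII bytes E2 98 83,
/// // the first of which sits at index 3.
/// assert_eq!(3, first_non_ascii_byte("abc☃".as_bytes()));
/// ```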
pub fn first_non_ascii_byte(slice: &[u8]) -> usize {
    #[cfg(any(miri, not(target_arch = "x86_64")))]
    {
        first_non_ascii_byte_fallback(slice)
    }

    #[cfg(all(not(miri), target_arch = "x86_64"))]
    {
        first_non_ascii_byte_sse2(slice)
    }
}

#[cfg(any(test, miri, not(target_arch = "x86_64")))]
fn first_non_ascii_byte_fallback(slice: &[u8]) -> usize {
    let align = USIZE_BYTES - 1;
    let start_ptr = slice.as_ptr();
    let end_ptr = slice[slice.len()..].as_ptr();
    let mut ptr = start_ptr;

    unsafe {
        if slice.len() < USIZE_BYTES {
            return first_non_ascii_byte_slow(start_ptr, end_ptr, ptr);
        }

        let chunk = read_unaligned_usize(ptr);
        let mask = chunk & ASCII_MASK;
        if mask != 0 {
            return first_non_ascii_byte_mask(mask);
        }

        ptr = ptr_add(ptr, USIZE_BYTES - (start_ptr as usize & align));
        debug_assert!(ptr > start_ptr);
        debug_assert!(ptr_sub(end_ptr, USIZE_BYTES) >= start_ptr);
        if slice.len() >= FALLBACK_LOOP_SIZE {
            while ptr <= ptr_sub(end_ptr, FALLBACK_LOOP_SIZE) {
                debug_assert_eq!(0, (ptr as usize) % USIZE_BYTES);

                let a = *(ptr as *const usize);
                let b = *(ptr_add(ptr, USIZE_BYTES) as *const usize);
                if (a | b) & ASCII_MASK != 0 {
                    // What a kludge. We wrap the position finding code into
                    // a non-inlineable function, which makes the codegen in
                    // the tight loop above a bit better by avoiding a
                    // couple extra movs. We pay for it by two additional
                    // stores, but only in the case of finding a non-ASCII
                    // byte.
                    #[inline(never)]
                    unsafe fn findpos(
                        start_ptr: *const u8,
                        ptr: *const u8,
                    ) -> usize {
                        let a = *(ptr as *const usize);
                        let b = *(ptr_add(ptr, USIZE_BYTES) as *const usize);

                        let mut at = sub(ptr, start_ptr);
                        let maska = a & ASCII_MASK;
                        if maska != 0 {
                            return at + first_non_ascii_byte_mask(maska);
                        }

                        at += USIZE_BYTES;
                        let maskb = b & ASCII_MASK;
                        debug_assert!(maskb != 0);
                        return at + first_non_ascii_byte_mask(maskb);
                    }
                    return findpos(start_ptr, ptr);
                }
                ptr = ptr_add(ptr, FALLBACK_LOOP_SIZE);
            }
        }
        first_non_ascii_byte_slow(start_ptr, end_ptr, ptr)
    }
}

#[cfg(all(not(miri), target_arch = "x86_64"))]
fn first_non_ascii_byte_sse2(slice: &[u8]) -> usize {
    use core::arch::x86_64::*;

    const VECTOR_SIZE: usize = core::mem::size_of::<__m128i>();
    const VECTOR_ALIGN: usize = VECTOR_SIZE - 1;
    const VECTOR_LOOP_SIZE: usize = 4 * VECTOR_SIZE;

    let start_ptr = slice.as_ptr();
    let end_ptr = slice[slice.len()..].as_ptr();
    let mut ptr = start_ptr;

    unsafe {
        if slice.len() < VECTOR_SIZE {
            return first_non_ascii_byte_slow(start_ptr, end_ptr, ptr);
        }

        let chunk = _mm_loadu_si128(ptr as *const __m128i);
        let mask = _mm_movemask_epi8(chunk);
        if mask != 0 {
            return mask.trailing_zeros() as usize;
        }

        ptr = ptr.add(VECTOR_SIZE - (start_ptr as usize & VECTOR_ALIGN));
        debug_assert!(ptr > start_ptr);
        debug_assert!(end_ptr.sub(VECTOR_SIZE) >= start_ptr);
        if slice.len() >= VECTOR_LOOP_SIZE {
            while ptr <= ptr_sub(end_ptr, VECTOR_LOOP_SIZE) {
                debug_assert_eq!(0, (ptr as usize) % VECTOR_SIZE);

                let a = _mm_load_si128(ptr as *const __m128i);
                let b = _mm_load_si128(ptr.add(VECTOR_SIZE) as *const __m128i);
                let c =
                    _mm_load_si128(ptr.add(2 * VECTOR_SIZE) as *const __m128i);
                let d =
                    _mm_load_si128(ptr.add(3 * VECTOR_SIZE) as *const __m128i);

                let or1 = _mm_or_si128(a, b);
                let or2 = _mm_or_si128(c, d);
                let or3 = _mm_or_si128(or1, or2);
                if _mm_movemask_epi8(or3) != 0 {
                    let mut at = sub(ptr, start_ptr);
                    let mask = _mm_movemask_epi8(a);
                    if mask != 0 {
                        return at + mask.trailing_zeros() as usize;
                    }

                    at += VECTOR_SIZE;
                    let mask = _mm_movemask_epi8(b);
                    if mask != 0 {
                        return at + mask.trailing_zeros() as usize;
                    }

                    at += VECTOR_SIZE;
                    let mask = _mm_movemask_epi8(c);
                    if mask != 0 {
                        return at + mask.trailing_zeros() as usize;
                    }

                    at += VECTOR_SIZE;
                    let mask = _mm_movemask_epi8(d);
                    debug_assert!(mask != 0);
                    return at + mask.trailing_zeros() as usize;
                }
                ptr = ptr_add(ptr, VECTOR_LOOP_SIZE);
            }
        }
        while ptr <= end_ptr.sub(VECTOR_SIZE) {
            debug_assert!(sub(end_ptr, ptr) >= VECTOR_SIZE);

            let chunk = _mm_loadu_si128(ptr as *const __m128i);
            let mask = _mm_movemask_epi8(chunk);
            if mask != 0 {
                return sub(ptr, start_ptr) + mask.trailing_zeros() as usize;
            }
            ptr = ptr.add(VECTOR_SIZE);
        }
        first_non_ascii_byte_slow(start_ptr, end_ptr, ptr)
    }
}

#[inline(always)]
unsafe fn first_non_ascii_byte_slow(
    start_ptr: *const u8,
    end_ptr: *const u8,
    mut ptr: *const u8,
) -> usize {
    debug_assert!(start_ptr <= ptr);
    debug_assert!(ptr <= end_ptr);

    while ptr < end_ptr {
        if *ptr > 0x7F {
            return sub(ptr, start_ptr);
        }
        ptr = ptr.offset(1);
    }
    sub(end_ptr, start_ptr)
}

/// Compute the position of the first non-ASCII byte in the given mask.
///
/// The mask should be computed by `chunk & ASCII_MASK`, where `chunk` is 8
/// contiguous bytes of the slice being checked where *at least* one of those
/// bytes is not an ASCII byte.
///
/// The position returned is always in the inclusive range [0, 7].
#[cfg(any(test, miri, not(target_arch = "x86_64")))]
fn first_non_ascii_byte_mask(mask: usize) -> usize {
    #[cfg(target_endian = "little")]
    {
        mask.trailing_zeros() as usize / 8
    }
    #[cfg(target_endian = "big")]
    {
        mask.leading_zeros() as usize / 8
    }
}

/// Increment the given pointer by the given amount.
unsafe fn ptr_add(ptr: *const u8, amt: usize) -> *const u8 {
    debug_assert!(amt < ::core::isize::MAX as usize);
    ptr.offset(amt as isize)
}

/// Decrement the given pointer by the given amount.
unsafe fn ptr_sub(ptr: *const u8, amt: usize) -> *const u8 {
    debug_assert!(amt < ::core::isize::MAX as usize);
    ptr.offset((amt as isize).wrapping_neg())
}

#[cfg(any(test, miri, not(target_arch = "x86_64")))]
unsafe fn read_unaligned_usize(ptr: *const u8) -> usize {
    use core::ptr;

    let mut n: usize = 0;
    ptr::copy_nonoverlapping(ptr, &mut n as *mut _ as *mut u8, USIZE_BYTES);
    n
}

/// Subtract `b` from `a` and return the difference. `a` should be greater
/// than or equal to `b`.
fn sub(a: *const u8, b: *const u8) -> usize {
    debug_assert!(a >= b);
    (a as usize) - (b as usize)
}

#[cfg(test)]
mod tests {
    use super::*;

    // Our testing approach here is to try and exhaustively test every case.
    // This includes the position at which a non-ASCII byte occurs in addition
    // to the alignment of the slice that we're searching.

    #[test]
    fn positive_fallback_forward() {
        for i in 0..517 {
            let s = "a".repeat(i);
            assert_eq!(
                i,
                first_non_ascii_byte_fallback(s.as_bytes()),
                "i: {:?}, len: {:?}, s: {:?}",
                i,
                s.len(),
                s
            );
        }
    }

    #[test]
    #[cfg(target_arch = "x86_64")]
    #[cfg(not(miri))]
    fn positive_sse2_forward() {
        for i in 0..517 {
            let b = "a".repeat(i).into_bytes();
            assert_eq!(b.len(), first_non_ascii_byte_sse2(&b));
        }
    }

    #[test]
    #[cfg(not(miri))]
    fn negative_fallback_forward() {
        for i in 0..517 {
            for align in 0..65 {
                let mut s = "a".repeat(i);
                s.push_str("☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃");
                let s = s.get(align..).unwrap_or("");
                assert_eq!(
                    i.saturating_sub(align),
                    first_non_ascii_byte_fallback(s.as_bytes()),
                    "i: {:?}, align: {:?}, len: {:?}, s: {:?}",
                    i,
                    align,
                    s.len(),
                    s
                );
            }
        }
    }

    #[test]
    #[cfg(target_arch = "x86_64")]
    #[cfg(not(miri))]
    fn negative_sse2_forward() {
        for i in 0..517 {
            for align in 0..65 {
                let mut s = "a".repeat(i);
                s.push_str("☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃");
                let s = s.get(align..).unwrap_or("");
                assert_eq!(
                    i.saturating_sub(align),
                    first_non_ascii_byte_sse2(s.as_bytes()),
                    "i: {:?}, align: {:?}, len: {:?}, s: {:?}",
                    i,
                    align,
                    s.len(),
                    s
                );
            }
        }
    }
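
    // An extra sanity check, added here as a sketch, that exercises the
    // public entry point directly (whichever implementation it dispatches to
    // on the platform running the tests).
    #[test]
    fn mixed_public_entry_point() {
        assert_eq!(0, first_non_ascii_byte(b""));
        assert_eq!(3, first_non_ascii_byte(b"abc"));
        assert_eq!(3, first_non_ascii_byte("abc☃def".as_bytes()));
        assert_eq!(0, first_non_ascii_byte("☃snow".as_bytes()));
    }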
}