// The following ~400 lines of code exist for exactly one purpose, which is
// to optimize this code:
//
//     byte_slice.iter().position(|&b| b > 0x7F).unwrap_or(byte_slice.len())
//
// Yes... Overengineered is a word that comes to mind, but this is effectively
// a very similar problem to memchr, and virtually nobody has been able to
// resist optimizing the crap out of that (except for perhaps the BSD and MUSL
// folks). In particular, this routine makes a very common case (ASCII) very
// fast, which seems worth it. We do stop short of adding AVX variants of the
// code below in order to retain our sanity and also to avoid needing to deal
// with runtime target feature detection. RESIST!
//
// In order to understand the SIMD version below, it would be good to read this
// comment describing how my memchr routine works:
// https://github.com/BurntSushi/rust-memchr/blob/b0a29f267f4a7fad8ffcc8fe8377a06498202883/src/x86/sse2.rs#L19-L106
//
// The primary difference with memchr is that for ASCII, we can do a bit less
// work. In particular, we don't need to detect the presence of a specific
// byte, but rather, whether any byte has its most significant bit set. That
// means we can effectively skip the _mm_cmpeq_epi8 step and jump straight to
// _mm_movemask_epi8.
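//
// As a worked example of that last point: for the 16 bytes of
// "abcdefgh☃hello" (the snowman is the UTF-8 sequence 0xE2 0x98 0x83 at
// indices 8, 9 and 10), _mm_movemask_epi8 gathers the most significant bit
// of each byte into the mask 0b0000_0111_0000_0000, and trailing_zeros()
// on that mask is 8, the index of the first non-ASCII byte.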

#[cfg(any(test, miri, not(target_arch = "x86_64")))]
const USIZE_BYTES: usize = core::mem::size_of::<usize>();
#[cfg(any(test, miri, not(target_arch = "x86_64")))]
const ALIGN_MASK: usize = core::mem::align_of::<usize>() - 1;
#[cfg(any(test, miri, not(target_arch = "x86_64")))]
const FALLBACK_LOOP_SIZE: usize = 2 * USIZE_BYTES;

// This is a mask where the most significant bit of each byte in the usize
// is set. We test this bit to determine whether a character is ASCII or not.
// Namely, a single byte is regarded as an ASCII codepoint if and only if its
// most significant bit is not set.
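//
// For instance, the all-ASCII chunk b"aaaaaaaa" loads (on a 64-bit target)
// as 0x6161_6161_6161_6161, which ANDed with the mask yields 0, while any
// byte >= 0x80 leaves its 0x80 bit standing in `chunk & ASCII_MASK`.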
#[cfg(any(test, miri, not(target_arch = "x86_64")))]
const ASCII_MASK_U64: u64 = 0x8080808080808080;
#[cfg(any(test, miri, not(target_arch = "x86_64")))]
const ASCII_MASK: usize = ASCII_MASK_U64 as usize;

/// Returns the index of the first non-ASCII byte in the given slice.
///
/// If the slice only contains ASCII bytes, then the length of the slice is
/// returned.
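///
/// This is an optimized equivalent of the naive scalar search
/// `slice.iter().position(|&b| b > 0x7F).unwrap_or(slice.len())`.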
pub fn first_non_ascii_byte(slice: &[u8]) -> usize {
    #[cfg(any(miri, not(target_arch = "x86_64")))]
    {
        first_non_ascii_byte_fallback(slice)
    }

    #[cfg(all(not(miri), target_arch = "x86_64"))]
    {
        first_non_ascii_byte_sse2(slice)
    }
}

#[cfg(any(test, miri, not(target_arch = "x86_64")))]
fn first_non_ascii_byte_fallback(slice: &[u8]) -> usize {
    let start_ptr = slice.as_ptr();
    let end_ptr = slice[slice.len()..].as_ptr();
    let mut ptr = start_ptr;

    unsafe {
        if slice.len() < USIZE_BYTES {
            return first_non_ascii_byte_slow(start_ptr, end_ptr, ptr);
        }

        let chunk = read_unaligned_usize(ptr);
        let mask = chunk & ASCII_MASK;
        if mask != 0 {
            return first_non_ascii_byte_mask(mask);
        }

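        // Round `ptr` up to the next usize-aligned address. The bytes
        // between `start_ptr` and that boundary were already covered by the
        // unaligned read above, so nothing is skipped.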
        ptr = ptr_add(ptr, USIZE_BYTES - (start_ptr as usize & ALIGN_MASK));
        debug_assert!(ptr > start_ptr);
        debug_assert!(ptr_sub(end_ptr, USIZE_BYTES) >= start_ptr);
        if slice.len() >= FALLBACK_LOOP_SIZE {
            while ptr <= ptr_sub(end_ptr, FALLBACK_LOOP_SIZE) {
                debug_assert_eq!(0, (ptr as usize) % USIZE_BYTES);

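                // Read two aligned words per iteration and OR them together,
                // so a single mask test covers FALLBACK_LOOP_SIZE bytes.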
                let a = *(ptr as *const usize);
                let b = *(ptr_add(ptr, USIZE_BYTES) as *const usize);
                if (a | b) & ASCII_MASK != 0 {
                    // What a kludge. We wrap the position finding code into
                    // a non-inlineable function, which makes the codegen in
                    // the tight loop above a bit better by avoiding a
                    // couple extra movs. We pay for it with two additional
                    // stores, but only in the case of finding a non-ASCII
                    // byte.
                    #[inline(never)]
                    unsafe fn findpos(
                        start_ptr: *const u8,
                        ptr: *const u8,
                    ) -> usize {
                        let a = *(ptr as *const usize);
                        let b = *(ptr_add(ptr, USIZE_BYTES) as *const usize);

                        let mut at = sub(ptr, start_ptr);
                        let maska = a & ASCII_MASK;
                        if maska != 0 {
                            return at + first_non_ascii_byte_mask(maska);
                        }

                        at += USIZE_BYTES;
                        let maskb = b & ASCII_MASK;
                        debug_assert!(maskb != 0);
                        return at + first_non_ascii_byte_mask(maskb);
                    }
                    return findpos(start_ptr, ptr);
                }
                ptr = ptr_add(ptr, FALLBACK_LOOP_SIZE);
            }
        }
        first_non_ascii_byte_slow(start_ptr, end_ptr, ptr)
    }
}

#[cfg(all(not(miri), target_arch = "x86_64"))]
fn first_non_ascii_byte_sse2(slice: &[u8]) -> usize {
    use core::arch::x86_64::*;

    const VECTOR_SIZE: usize = core::mem::size_of::<__m128i>();
    const VECTOR_ALIGN: usize = VECTOR_SIZE - 1;
    const VECTOR_LOOP_SIZE: usize = 4 * VECTOR_SIZE;

    let start_ptr = slice.as_ptr();
    let end_ptr = slice[slice.len()..].as_ptr();
    let mut ptr = start_ptr;

    unsafe {
        if slice.len() < VECTOR_SIZE {
            return first_non_ascii_byte_slow(start_ptr, end_ptr, ptr);
        }

        let chunk = _mm_loadu_si128(ptr as *const __m128i);
        let mask = _mm_movemask_epi8(chunk);
        if mask != 0 {
            return mask.trailing_zeros() as usize;
        }

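        // Round `ptr` up to the next 16-byte-aligned address; the bytes
        // before it were covered by the unaligned load above.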
        ptr = ptr.add(VECTOR_SIZE - (start_ptr as usize & VECTOR_ALIGN));
        debug_assert!(ptr > start_ptr);
        debug_assert!(end_ptr.sub(VECTOR_SIZE) >= start_ptr);
        if slice.len() >= VECTOR_LOOP_SIZE {
            while ptr <= ptr_sub(end_ptr, VECTOR_LOOP_SIZE) {
                debug_assert_eq!(0, (ptr as usize) % VECTOR_SIZE);

                let a = _mm_load_si128(ptr as *const __m128i);
                let b = _mm_load_si128(ptr.add(VECTOR_SIZE) as *const __m128i);
                let c =
                    _mm_load_si128(ptr.add(2 * VECTOR_SIZE) as *const __m128i);
                let d =
                    _mm_load_si128(ptr.add(3 * VECTOR_SIZE) as *const __m128i);

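                // OR the four vectors together first: one movemask on the
                // combined vector tells us whether any of the 64 bytes has
                // its most significant bit set, keeping the hot loop
                // branch-light. Only on a hit do we re-test each vector to
                // pinpoint the byte.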
                let or1 = _mm_or_si128(a, b);
                let or2 = _mm_or_si128(c, d);
                let or3 = _mm_or_si128(or1, or2);
                if _mm_movemask_epi8(or3) != 0 {
                    let mut at = sub(ptr, start_ptr);
                    let mask = _mm_movemask_epi8(a);
                    if mask != 0 {
                        return at + mask.trailing_zeros() as usize;
                    }

                    at += VECTOR_SIZE;
                    let mask = _mm_movemask_epi8(b);
                    if mask != 0 {
                        return at + mask.trailing_zeros() as usize;
                    }

                    at += VECTOR_SIZE;
                    let mask = _mm_movemask_epi8(c);
                    if mask != 0 {
                        return at + mask.trailing_zeros() as usize;
                    }

                    at += VECTOR_SIZE;
                    let mask = _mm_movemask_epi8(d);
                    debug_assert!(mask != 0);
                    return at + mask.trailing_zeros() as usize;
                }
                ptr = ptr_add(ptr, VECTOR_LOOP_SIZE);
            }
        }
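        // Sweep what's left a full vector at a time with unaligned loads;
        // the final stretch of fewer than VECTOR_SIZE bytes falls through
        // to the byte-at-a-time slow path.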
        while ptr <= end_ptr.sub(VECTOR_SIZE) {
            debug_assert!(sub(end_ptr, ptr) >= VECTOR_SIZE);

            let chunk = _mm_loadu_si128(ptr as *const __m128i);
            let mask = _mm_movemask_epi8(chunk);
            if mask != 0 {
                return sub(ptr, start_ptr) + mask.trailing_zeros() as usize;
            }
            ptr = ptr.add(VECTOR_SIZE);
        }
        first_non_ascii_byte_slow(start_ptr, end_ptr, ptr)
    }
}

#[inline(always)]
unsafe fn first_non_ascii_byte_slow(
    start_ptr: *const u8,
    end_ptr: *const u8,
    mut ptr: *const u8,
) -> usize {
    debug_assert!(start_ptr <= ptr);
    debug_assert!(ptr <= end_ptr);

    while ptr < end_ptr {
        if *ptr > 0x7F {
            return sub(ptr, start_ptr);
        }
        ptr = ptr.offset(1);
    }
    sub(end_ptr, start_ptr)
}

/// Compute the position of the first non-ASCII byte in the given mask.
///
/// The mask should be computed by `chunk & ASCII_MASK`, where `chunk` is
/// 8 contiguous bytes of the slice being checked where *at least* one of those
/// bytes is not an ASCII byte.
///
/// The position returned is always in the inclusive range [0, 7].
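///
/// For example, on a little-endian target, the chunk b"ab\xE2xxxxx" yields
/// `mask == 0x0000_0000_0080_0000`, and `mask.trailing_zeros() / 8 == 2`,
/// the index of the offending byte `0xE2`.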
#[cfg(any(test, miri, not(target_arch = "x86_64")))]
fn first_non_ascii_byte_mask(mask: usize) -> usize {
    #[cfg(target_endian = "little")]
    {
        mask.trailing_zeros() as usize / 8
    }
    #[cfg(target_endian = "big")]
    {
        mask.leading_zeros() as usize / 8
    }
}

/// Increment the given pointer by the given amount.
unsafe fn ptr_add(ptr: *const u8, amt: usize) -> *const u8 {
    ptr.add(amt)
}

/// Decrement the given pointer by the given amount.
unsafe fn ptr_sub(ptr: *const u8, amt: usize) -> *const u8 {
    ptr.sub(amt)
}

#[cfg(any(test, miri, not(target_arch = "x86_64")))]
unsafe fn read_unaligned_usize(ptr: *const u8) -> usize {
    use core::ptr;

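    // A portable unaligned load: copy USIZE_BYTES bytes into a zeroed
    // usize. This typically compiles down to a single unaligned load on
    // targets that support one.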
    let mut n: usize = 0;
    ptr::copy_nonoverlapping(ptr, &mut n as *mut _ as *mut u8, USIZE_BYTES);
    n
}

/// Subtract `b` from `a` and return the difference. `a` should be greater than
/// or equal to `b`.
fn sub(a: *const u8, b: *const u8) -> usize {
    debug_assert!(a >= b);
    (a as usize) - (b as usize)
}

#[cfg(test)]
mod tests {
    use super::*;

    // Our testing approach here is to try and exhaustively test every case.
    // This includes the position at which a non-ASCII byte occurs in addition
    // to the alignment of the slice that we're searching.

    #[test]
    fn positive_fallback_forward() {
        for i in 0..517 {
            let s = "a".repeat(i);
            assert_eq!(
                i,
                first_non_ascii_byte_fallback(s.as_bytes()),
                "i: {:?}, len: {:?}, s: {:?}",
                i,
                s.len(),
                s
            );
        }
    }

    #[test]
    #[cfg(target_arch = "x86_64")]
    #[cfg(not(miri))]
    fn positive_sse2_forward() {
        for i in 0..517 {
            let b = "a".repeat(i).into_bytes();
            assert_eq!(b.len(), first_non_ascii_byte_sse2(&b));
        }
    }

    #[test]
    #[cfg(not(miri))]
    fn negative_fallback_forward() {
        for i in 0..517 {
            for align in 0..65 {
                let mut s = "a".repeat(i);
                s.push_str("☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃");
                let s = s.get(align..).unwrap_or("");
                assert_eq!(
                    i.saturating_sub(align),
                    first_non_ascii_byte_fallback(s.as_bytes()),
                    "i: {:?}, align: {:?}, len: {:?}, s: {:?}",
                    i,
                    align,
                    s.len(),
                    s
                );
            }
        }
    }

    #[test]
    #[cfg(target_arch = "x86_64")]
    #[cfg(not(miri))]
    fn negative_sse2_forward() {
        for i in 0..517 {
            for align in 0..65 {
                let mut s = "a".repeat(i);
                s.push_str("☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃");
                let s = s.get(align..).unwrap_or("");
                assert_eq!(
                    i.saturating_sub(align),
                    first_non_ascii_byte_sse2(s.as_bytes()),
                    "i: {:?}, align: {:?}, len: {:?}, s: {:?}",
                    i,
                    align,
                    s.len(),
                    s
                );
            }
        }
    }
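
    // A minimal cross-check: the public entry point must agree with the
    // naive scalar definition quoted in the comment at the top of this
    // file. The inputs here are illustrative, not exhaustive.
    #[test]
    fn matches_naive_definition() {
        let inputs: &[&[u8]] = &[
            b"",
            b"abc",
            b"abc\xFFxyz",
            b"\x80",
            "abcdefgh☃hello".as_bytes(),
        ];
        for &bytes in inputs {
            let naive =
                bytes.iter().position(|&b| b > 0x7F).unwrap_or(bytes.len());
            assert_eq!(naive, first_non_ascii_byte(bytes));
        }
    }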
}