//! Parallel quicksort.
//!
//! This implementation is copied verbatim from `std::slice::sort_unstable` and then parallelized.
//! The only difference from the original is that calls to `recurse` are executed in parallel using
//! `rayon_core::join`.

use std::cmp;
use std::marker::PhantomData;
use std::mem::{self, MaybeUninit};
use std::ptr;

/// When dropped, copies from `src` into `dest`.
#[must_use]
struct CopyOnDrop<'a, T> {
    src: *const T,
    dest: *mut T,
    /// `src` is often a local pointer here, make sure we have appropriate
    /// PhantomData so that dropck can protect us.
    marker: PhantomData<&'a mut T>,
}

impl<'a, T> CopyOnDrop<'a, T> {
    /// Constructs from a source pointer and a destination.
    /// Assumes `dest` lives longer than `src`, since there is no easy way to
    /// copy down lifetime information from another pointer.
    unsafe fn new(src: &'a T, dest: *mut T) -> Self {
        CopyOnDrop {
            src,
            dest,
            marker: PhantomData,
        }
    }
}

impl<T> Drop for CopyOnDrop<'_, T> {
    fn drop(&mut self) {
        // SAFETY: This is a helper class.
        // Please refer to its usage for correctness.
        // Namely, one must be sure that `src` and `dest` do not overlap as required by
        // `ptr::copy_nonoverlapping`.
        unsafe {
            ptr::copy_nonoverlapping(self.src, self.dest, 1);
        }
    }
}
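
// Illustrative usage pattern for `CopyOnDrop` (a sketch for exposition, mirroring how
// the sort routines below use the guard; `elem` is a hypothetical `*mut T` pointing
// into the slice): read an element into a stack copy, and let the guard write it back
// even if a later comparison panics.
//
//     let tmp = mem::ManuallyDrop::new(ptr::read(elem));
//     let _guard = unsafe { CopyOnDrop::new(&*tmp, elem) };
//     // ... comparisons that may panic ...
//     // `_guard` drops here and copies `tmp` back into the slice.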

/// Shifts the first element to the right until it encounters a greater or equal element.
fn shift_head<T, F>(v: &mut [T], is_less: &F)
where
    F: Fn(&T, &T) -> bool,
{
    let len = v.len();
    // SAFETY: The unsafe operations below involve indexing without a bounds check (by offsetting a
    // pointer) and copying memory (`ptr::copy_nonoverlapping`).
    //
    // a. Indexing:
    //  1. We checked that the size of the array is >= 2.
    //  2. All the indexing that we will do is always between `0 <= index < len` at most.
    //
    // b. Memory copying:
    //  1. We are obtaining pointers to references which are guaranteed to be valid.
    //  2. They cannot overlap because we obtain pointers to different indices of the slice.
    //     Namely, `i` and `i - 1`.
    //  3. If the slice is properly aligned, the elements are properly aligned.
    //     It is the caller's responsibility to make sure the slice is properly aligned.
    //
    // See comments below for further detail.
    unsafe {
        // If the first two elements are out-of-order...
        if len >= 2 && is_less(v.get_unchecked(1), v.get_unchecked(0)) {
            // Read the first element into a stack-allocated variable. If a following comparison
            // operation panics, `hole` will get dropped and automatically write the element back
            // into the slice.
            let tmp = mem::ManuallyDrop::new(ptr::read(v.get_unchecked(0)));
            let v = v.as_mut_ptr();
            let mut hole = CopyOnDrop::new(&*tmp, v.add(1));
            ptr::copy_nonoverlapping(v.add(1), v.add(0), 1);

            for i in 2..len {
                if !is_less(&*v.add(i), &*tmp) {
                    break;
                }

                // Move `i`-th element one place to the left, thus shifting the hole to the right.
                ptr::copy_nonoverlapping(v.add(i), v.add(i - 1), 1);
                hole.dest = v.add(i);
            }
            // `hole` gets dropped and thus copies `tmp` into the remaining hole in `v`.
        }
    }
}

/// Shifts the last element to the left until it encounters a smaller or equal element.
fn shift_tail<T, F>(v: &mut [T], is_less: &F)
where
    F: Fn(&T, &T) -> bool,
{
    let len = v.len();
    // SAFETY: The unsafe operations below involve indexing without a bounds check (by offsetting a
    // pointer) and copying memory (`ptr::copy_nonoverlapping`).
    //
    // a. Indexing:
    //  1. We checked that the size of the array is >= 2.
    //  2. All the indexing that we will do is always between `0 <= index < len - 1` at most.
    //
    // b. Memory copying:
    //  1. We are obtaining pointers to references which are guaranteed to be valid.
    //  2. They cannot overlap because we obtain pointers to different indices of the slice.
    //     Namely, `i` and `i + 1`.
    //  3. If the slice is properly aligned, the elements are properly aligned.
    //     It is the caller's responsibility to make sure the slice is properly aligned.
    //
    // See comments below for further detail.
    unsafe {
        // If the last two elements are out-of-order...
        if len >= 2 && is_less(v.get_unchecked(len - 1), v.get_unchecked(len - 2)) {
            // Read the last element into a stack-allocated variable. If a following comparison
            // operation panics, `hole` will get dropped and automatically write the element back
            // into the slice.
            let tmp = mem::ManuallyDrop::new(ptr::read(v.get_unchecked(len - 1)));
            let v = v.as_mut_ptr();
            let mut hole = CopyOnDrop::new(&*tmp, v.add(len - 2));
            ptr::copy_nonoverlapping(v.add(len - 2), v.add(len - 1), 1);

            for i in (0..len - 2).rev() {
                if !is_less(&*tmp, &*v.add(i)) {
                    break;
                }

                // Move `i`-th element one place to the right, thus shifting the hole to the left.
                ptr::copy_nonoverlapping(v.add(i), v.add(i + 1), 1);
                hole.dest = v.add(i);
            }
            // `hole` gets dropped and thus copies `tmp` into the remaining hole in `v`.
        }
    }
}
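
// Worked example (a sketch, not part of the original source): `shift_head` on
// [5, 1, 2, 9] with `<` reads 5 into the hole and shifts 1 and 2 one place left,
// yielding [1, 2, 5, 9]; symmetrically, `shift_tail` on [2, 5, 9, 1] shifts 1
// left past 9, 5, and 2, also yielding [1, 2, 5, 9].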

/// Partially sorts a slice by shifting several out-of-order elements around.
///
/// Returns `true` if the slice is sorted at the end. This function is *O*(*n*) worst-case.
#[cold]
fn partial_insertion_sort<T, F>(v: &mut [T], is_less: &F) -> bool
where
    F: Fn(&T, &T) -> bool,
{
    // Maximum number of adjacent out-of-order pairs that will get shifted.
    const MAX_STEPS: usize = 5;
    // If the slice is shorter than this, don't shift any elements.
    const SHORTEST_SHIFTING: usize = 50;

    let len = v.len();
    let mut i = 1;

    for _ in 0..MAX_STEPS {
        // SAFETY: We already explicitly did the bounds checking with `i < len`.
        // All our subsequent indexing is only in the range `0 <= index < len`.
        unsafe {
            // Find the next pair of adjacent out-of-order elements.
            while i < len && !is_less(v.get_unchecked(i), v.get_unchecked(i - 1)) {
                i += 1;
            }
        }

        // Are we done?
        if i == len {
            return true;
        }

        // Don't shift elements on short arrays; that has a performance cost.
        if len < SHORTEST_SHIFTING {
            return false;
        }

        // Swap the found pair of elements. This puts them in correct order.
        v.swap(i - 1, i);

        // Shift the smaller element to the left.
        shift_tail(&mut v[..i], is_less);
        // Shift the greater element to the right.
        shift_head(&mut v[i..], is_less);
    }

    // Didn't manage to sort the slice in the limited number of steps.
    false
}

/// Sorts a slice using insertion sort, which is *O*(*n*^2) worst-case.
fn insertion_sort<T, F>(v: &mut [T], is_less: &F)
where
    F: Fn(&T, &T) -> bool,
{
    for i in 1..v.len() {
        shift_tail(&mut v[..i + 1], is_less);
    }
}
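
// Example trace (a sketch): on [3, 1, 2], the `i = 1` pass shifts 1 left,
// giving [1, 3, 2], and the `i = 2` pass shifts 2 left, giving [1, 2, 3].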

/// Sorts `v` using heapsort, which guarantees *O*(*n* \* log(*n*)) worst-case.
#[cold]
fn heapsort<T, F>(v: &mut [T], is_less: &F)
where
    F: Fn(&T, &T) -> bool,
{
    // This binary heap respects the invariant `parent >= child`.
    let sift_down = |v: &mut [T], mut node| {
        loop {
            // Children of `node`.
            let mut child = 2 * node + 1;
            if child >= v.len() {
                break;
            }

            // Choose the greater child.
            if child + 1 < v.len() && is_less(&v[child], &v[child + 1]) {
                child += 1;
            }

            // Stop if the invariant holds at `node`.
            if !is_less(&v[node], &v[child]) {
                break;
            }

            // Swap `node` with the greater child, move one step down, and continue sifting.
            v.swap(node, child);
            node = child;
        }
    };

    // Build the heap in linear time.
    for i in (0..v.len() / 2).rev() {
        sift_down(v, i);
    }

    // Pop maximal elements from the heap.
    for i in (1..v.len()).rev() {
        v.swap(0, i);
        sift_down(&mut v[..i], 0);
    }
}
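
// Heap layout example (a sketch): the heap lives implicitly in the slice, with
// node `i` having children `2 * i + 1` and `2 * i + 2`. In [9, 5, 7, 1, 3],
// node 1 (value 5) has children at indices 3 (value 1) and 4 (value 3), and
// the max-heap invariant `parent >= child` holds at every node.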

/// Partitions `v` into elements smaller than `pivot`, followed by elements greater than or equal
/// to `pivot`.
///
/// Returns the number of elements smaller than `pivot`.
///
/// Partitioning is performed block-by-block in order to minimize the cost of branching operations.
/// This idea is presented in the [BlockQuicksort][pdf] paper.
///
/// [pdf]: https://drops.dagstuhl.de/opus/volltexte/2016/6389/pdf/LIPIcs-ESA-2016-38.pdf
fn partition_in_blocks<T, F>(v: &mut [T], pivot: &T, is_less: &F) -> usize
where
    F: Fn(&T, &T) -> bool,
{
    // Number of elements in a typical block.
    const BLOCK: usize = 128;

    // The partitioning algorithm repeats the following steps until completion:
    //
    // 1. Trace a block from the left side to identify elements greater than or equal to the pivot.
    // 2. Trace a block from the right side to identify elements smaller than the pivot.
    // 3. Exchange the identified elements between the left and right side.
    //
    // We keep the following variables for a block of elements:
    //
    // 1. `block` - Number of elements in the block.
    // 2. `start` - Start pointer into the `offsets` array.
    // 3. `end` - End pointer into the `offsets` array.
    // 4. `offsets` - Indices of out-of-order elements within the block.

    // The current block on the left side (from `l` to `l.add(block_l)`).
    let mut l = v.as_mut_ptr();
    let mut block_l = BLOCK;
    let mut start_l = ptr::null_mut();
    let mut end_l = ptr::null_mut();
    let mut offsets_l = [MaybeUninit::<u8>::uninit(); BLOCK];

    // The current block on the right side (from `r.sub(block_r)` to `r`).
    // SAFETY: The documentation for `.add()` specifically mentions that `vec.as_ptr().add(vec.len())` is always safe.
    let mut r = unsafe { l.add(v.len()) };
    let mut block_r = BLOCK;
    let mut start_r = ptr::null_mut();
    let mut end_r = ptr::null_mut();
    let mut offsets_r = [MaybeUninit::<u8>::uninit(); BLOCK];

    // FIXME: When we get VLAs, try creating one array of length `min(v.len(), 2 * BLOCK)` rather
    // than two fixed-size arrays of length `BLOCK`. VLAs might be more cache-efficient.

    // Returns the number of elements between pointers `l` (inclusive) and `r` (exclusive).
    fn width<T>(l: *mut T, r: *mut T) -> usize {
        assert!(mem::size_of::<T>() > 0);
        // FIXME: this should *likely* use `offset_from`, but more
        // investigation is needed (including running tests in miri).
        // TODO unstable: (r.addr() - l.addr()) / mem::size_of::<T>()
        (r as usize - l as usize) / mem::size_of::<T>()
    }

    loop {
        // We are done with partitioning block-by-block when `l` and `r` get very close. Then we do
        // some patch-up work in order to partition the remaining elements in between.
        let is_done = width(l, r) <= 2 * BLOCK;

        if is_done {
            // Number of remaining elements (still not compared to the pivot).
            let mut rem = width(l, r);
            if start_l < end_l || start_r < end_r {
                rem -= BLOCK;
            }

            // Adjust block sizes so that the left and right block don't overlap, but get perfectly
            // aligned to cover the whole remaining gap.
            if start_l < end_l {
                block_r = rem;
            } else if start_r < end_r {
                block_l = rem;
            } else {
                // There were the same number of elements to switch on both blocks during the last
                // iteration, so there are no remaining elements on either block. Cover the remaining
                // items with roughly equally-sized blocks.
                block_l = rem / 2;
                block_r = rem - block_l;
            }
            debug_assert!(block_l <= BLOCK && block_r <= BLOCK);
            debug_assert!(width(l, r) == block_l + block_r);
        }

        if start_l == end_l {
            // Trace `block_l` elements from the left side.
            // TODO unstable: start_l = MaybeUninit::slice_as_mut_ptr(&mut offsets_l);
            start_l = offsets_l.as_mut_ptr() as *mut u8;
            end_l = start_l;
            let mut elem = l;

            for i in 0..block_l {
                // SAFETY: The unsafe operations below involve the use of `offset`.
                // According to the conditions required by the function, we satisfy them because:
                // 1. `offsets_l` is stack-allocated, and thus considered a separate allocated object.
                // 2. The function `is_less` returns a `bool`.
                //    Casting a `bool` will never overflow `isize`.
                // 3. We have guaranteed that `block_l` will be `<= BLOCK`.
                //    Plus, `end_l` was initially set to the begin pointer of `offsets_l`, which was declared on the stack.
                //    Thus, we know that even in the worst case (all invocations of `is_less` return false) we will only be at most 1 byte past the end.
                // Another unsafe operation here is dereferencing `elem`.
                // However, `elem` was initially the begin pointer to the slice, which is always valid.
                unsafe {
                    // Branchless comparison.
                    *end_l = i as u8;
                    end_l = end_l.offset(!is_less(&*elem, pivot) as isize);
                    elem = elem.offset(1);
                }
            }
        }

        if start_r == end_r {
            // Trace `block_r` elements from the right side.
            // TODO unstable: start_r = MaybeUninit::slice_as_mut_ptr(&mut offsets_r);
            start_r = offsets_r.as_mut_ptr() as *mut u8;
            end_r = start_r;
            let mut elem = r;

            for i in 0..block_r {
                // SAFETY: The unsafe operations below involve the use of `offset`.
                // According to the conditions required by the function, we satisfy them because:
                // 1. `offsets_r` is stack-allocated, and thus considered a separate allocated object.
                // 2. The function `is_less` returns a `bool`.
                //    Casting a `bool` will never overflow `isize`.
                // 3. We have guaranteed that `block_r` will be `<= BLOCK`.
                //    Plus, `end_r` was initially set to the begin pointer of `offsets_r`, which was declared on the stack.
                //    Thus, we know that even in the worst case (all invocations of `is_less` return true) we will only be at most 1 byte past the end.
                // Another unsafe operation here is dereferencing `elem`.
                // However, `elem` was initially `1 * sizeof(T)` past the end and we decrement it by `1 * sizeof(T)` before accessing it.
                // Plus, `block_r` was asserted to be at most `BLOCK`, so `elem` will at most be pointing to the beginning of the slice.
                unsafe {
                    // Branchless comparison.
                    elem = elem.offset(-1);
                    *end_r = i as u8;
                    end_r = end_r.offset(is_less(&*elem, pivot) as isize);
                }
            }
        }

        // Number of out-of-order elements to swap between the left and right side.
        let count = cmp::min(width(start_l, end_l), width(start_r, end_r));

        if count > 0 {
            macro_rules! left {
                () => {
                    l.offset(*start_l as isize)
                };
            }
            macro_rules! right {
                () => {
                    r.offset(-(*start_r as isize) - 1)
                };
            }

            // Instead of swapping one pair at a time, it is more efficient to perform a cyclic
            // permutation. This is not strictly equivalent to swapping, but produces a similar
            // result using fewer memory operations.

            // SAFETY: The use of `ptr::read` is valid because there is at least one element in
            // both `offsets_l` and `offsets_r`, so `left!` is a valid pointer to read from.
            //
            // The uses of `left!` involve calls to `offset` on `l`, which points to the
            // beginning of `v`. All the offsets pointed-to by `start_l` are at most `block_l`, so
            // these `offset` calls are safe as all reads are within the block. The same argument
            // applies for the uses of `right!`.
            //
            // The calls to `start_l.offset` are valid because there are at most `count-1` of them,
            // plus the final one at the end of the unsafe block, where `count` is the minimum number
            // of collected offsets in `offsets_l` and `offsets_r`, so there is no risk of there not
            // being enough elements. The same reasoning applies to the calls to `start_r.offset`.
            //
            // The calls to `copy_nonoverlapping` are safe because `left!` and `right!` are guaranteed
            // not to overlap, and are valid because of the reasoning above.
            unsafe {
                let tmp = ptr::read(left!());
                ptr::copy_nonoverlapping(right!(), left!(), 1);

                for _ in 1..count {
                    start_l = start_l.offset(1);
                    ptr::copy_nonoverlapping(left!(), right!(), 1);
                    start_r = start_r.offset(1);
                    ptr::copy_nonoverlapping(right!(), left!(), 1);
                }

                ptr::copy_nonoverlapping(&tmp, right!(), 1);
                mem::forget(tmp);
                start_l = start_l.offset(1);
                start_r = start_r.offset(1);
            }
        }

        if start_l == end_l {
            // All out-of-order elements in the left block were moved. Move to the next block.

            // block-width-guarantee
            // SAFETY: if `!is_done` then the slice width is guaranteed to be at least `2*BLOCK` wide. There
            // are at most `BLOCK` elements in `offsets_l` because of its size, so the `offset` operation is
            // safe. Otherwise, the debug assertions in the `is_done` case guarantee that
            // `width(l, r) == block_l + block_r`, namely, that the block sizes have been adjusted to account
            // for the smaller number of remaining elements.
            l = unsafe { l.add(block_l) };
        }

        if start_r == end_r {
            // All out-of-order elements in the right block were moved. Move to the previous block.

            // SAFETY: Same argument as [block-width-guarantee]. Either this is a full block `2*BLOCK`-wide,
            // or `block_r` has been adjusted for the last handful of elements.
            r = unsafe { r.offset(-(block_r as isize)) };
        }

        if is_done {
            break;
        }
    }

    // All that remains now is at most one block (either the left or the right) with out-of-order
    // elements that need to be moved. Such remaining elements can be simply shifted to the end
    // within their block.

    if start_l < end_l {
        // The left block remains.
        // Move its remaining out-of-order elements to the far right.
        debug_assert_eq!(width(l, r), block_l);
        while start_l < end_l {
            // remaining-elements-safety
            // SAFETY: while the loop condition holds there are still elements in `offsets_l`, so it
            // is safe to point `end_l` to the previous element.
            //
            // The `ptr::swap` is safe if both its arguments are valid for reads and writes:
            //  - Per the debug assert above, the distance between `l` and `r` is `block_l`
            //    elements, so there can be at most `block_l` remaining offsets between `start_l`
            //    and `end_l`. This means `r` will be moved at most `block_l` steps back, which
            //    makes the `r.offset` calls valid (at that point `l == r`).
            //  - `offsets_l` contains valid offsets into `v` collected during the partitioning of
            //    the last block, so the `l.offset` calls are valid.
            unsafe {
                end_l = end_l.offset(-1);
                ptr::swap(l.offset(*end_l as isize), r.offset(-1));
                r = r.offset(-1);
            }
        }
        width(v.as_mut_ptr(), r)
    } else if start_r < end_r {
        // The right block remains.
        // Move its remaining out-of-order elements to the far left.
        debug_assert_eq!(width(l, r), block_r);
        while start_r < end_r {
            // SAFETY: See the reasoning in [remaining-elements-safety].
            unsafe {
                end_r = end_r.offset(-1);
                ptr::swap(l, r.offset(-(*end_r as isize) - 1));
                l = l.offset(1);
            }
        }
        width(v.as_mut_ptr(), l)
    } else {
        // Nothing else to do, we're done.
        width(v.as_mut_ptr(), l)
    }
}
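
// Illustrative trace of `partition_in_blocks` (a sketch, using a tiny block size
// for exposition): with BLOCK = 4, pivot = 5, and v = [8, 1, 9, 2, 7, 3, 6, 4],
// the left scan records the offsets of elements >= 5 ([0, 2], i.e. 8 and 9), and
// the right scan, walking from the end, records the offsets of elements < 5
// ([0, 2], i.e. 4 and 3). The cyclic permutation then moves the small elements
// left and the large elements right, leaving [4, 1, 3, 2, 7, 8, 6, 9], i.e. all
// elements smaller than the pivot on the left side.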

/// Partitions `v` into elements smaller than `v[pivot]`, followed by elements greater than or
/// equal to `v[pivot]`.
///
/// Returns a tuple of:
///
/// 1. Number of elements smaller than `v[pivot]`.
/// 2. True if `v` was already partitioned.
fn partition<T, F>(v: &mut [T], pivot: usize, is_less: &F) -> (usize, bool)
where
    F: Fn(&T, &T) -> bool,
{
    let (mid, was_partitioned) = {
        // Place the pivot at the beginning of the slice.
        v.swap(0, pivot);
        let (pivot, v) = v.split_at_mut(1);
        let pivot = &mut pivot[0];

        // Read the pivot into a stack-allocated variable for efficiency. If a following comparison
        // operation panics, the pivot will be automatically written back into the slice.

        // SAFETY: `pivot` is a reference to the first element of `v`, so `ptr::read` is safe.
        let tmp = mem::ManuallyDrop::new(unsafe { ptr::read(pivot) });
        let _pivot_guard = unsafe { CopyOnDrop::new(&*tmp, pivot) };
        let pivot = &*tmp;

        // Find the first pair of out-of-order elements.
        let mut l = 0;
        let mut r = v.len();

        // SAFETY: The unsafe blocks below involve indexing into the slice.
        // For the first one: we do the bounds checking here with `l < r`.
        // For the second one: we start with `l == 0` and `r == v.len()`, and we check `l < r`
        // before every indexing operation, so `r - 1` always stays within bounds.
        unsafe {
            // Find the first element greater than or equal to the pivot.
            while l < r && is_less(v.get_unchecked(l), pivot) {
                l += 1;
            }

            // Find the last element smaller than the pivot.
            while l < r && !is_less(v.get_unchecked(r - 1), pivot) {
                r -= 1;
            }
        }

        (
            l + partition_in_blocks(&mut v[l..r], pivot, is_less),
            l >= r,
        )

        // `_pivot_guard` goes out of scope and writes the pivot (which is a stack-allocated
        // variable) back into the slice where it originally was. This step is critical in ensuring
        // safety!
    };

    // Place the pivot between the two partitions.
    v.swap(0, mid);

    (mid, was_partitioned)
}
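
// Example (a sketch): partitioning [5, 9, 1, 3, 7] around `v[0] = 5` returns
// `(2, false)`: two elements (1 and 3) are smaller than 5, the pivot ends up
// at index 2, and 9 and 7 land after it (in some order).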

/// Partitions `v` into elements equal to `v[pivot]` followed by elements greater than `v[pivot]`.
///
/// Returns the number of elements equal to the pivot. It is assumed that `v` does not contain
/// elements smaller than the pivot.
fn partition_equal<T, F>(v: &mut [T], pivot: usize, is_less: &F) -> usize
where
    F: Fn(&T, &T) -> bool,
{
    // Place the pivot at the beginning of the slice.
    v.swap(0, pivot);
    let (pivot, v) = v.split_at_mut(1);
    let pivot = &mut pivot[0];

    // Read the pivot into a stack-allocated variable for efficiency. If a following comparison
    // operation panics, the pivot will be automatically written back into the slice.
    // SAFETY: The pointer here is valid because it is obtained from a reference to a slice.
    let tmp = mem::ManuallyDrop::new(unsafe { ptr::read(pivot) });
    let _pivot_guard = unsafe { CopyOnDrop::new(&*tmp, pivot) };
    let pivot = &*tmp;

    // Now partition the slice.
    let mut l = 0;
    let mut r = v.len();
    loop {
        // SAFETY: The unsafe blocks below involve indexing into the slice.
        // For the first one: we do the bounds checking here with `l < r`.
        // For the second one: we start with `l == 0` and `r == v.len()`, and we check `l < r`
        // before every indexing operation, so `r - 1` always stays within bounds.
        unsafe {
            // Find the first element greater than the pivot.
            while l < r && !is_less(pivot, v.get_unchecked(l)) {
                l += 1;
            }

            // Find the last element equal to the pivot.
            while l < r && is_less(pivot, v.get_unchecked(r - 1)) {
                r -= 1;
            }

            // Are we done?
            if l >= r {
                break;
            }

            // Swap the found pair of out-of-order elements.
            r -= 1;
            let ptr = v.as_mut_ptr();
            ptr::swap(ptr.add(l), ptr.add(r));
            l += 1;
        }
    }

    // We found `l` elements equal to the pivot. Add 1 to account for the pivot itself.
    l + 1

    // `_pivot_guard` goes out of scope and writes the pivot (which is a stack-allocated variable)
    // back into the slice where it originally was. This step is critical in ensuring safety!
}
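
// Example (a sketch): for v = [2, 5, 2, 2, 9] with `pivot = 0` (value 2), and
// given that no element is smaller than 2, `partition_equal` returns 3: the
// two other 2's are gathered at the front, plus one for the pivot itself.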

/// Scatters some elements around in an attempt to break patterns that might cause imbalanced
/// partitions in quicksort.
#[cold]
fn break_patterns<T>(v: &mut [T]) {
    let len = v.len();
    if len >= 8 {
        // Pseudorandom number generator from the "Xorshift RNGs" paper by George Marsaglia.
        let mut random = len as u32;
        let mut gen_u32 = || {
            random ^= random << 13;
            random ^= random >> 17;
            random ^= random << 5;
            random
        };
        let mut gen_usize = || {
            if usize::BITS <= 32 {
                gen_u32() as usize
            } else {
                (((gen_u32() as u64) << 32) | (gen_u32() as u64)) as usize
            }
        };

        // Take random numbers modulo this number.
        // The number fits into `usize` because `len` is not greater than `isize::MAX`.
        let modulus = len.next_power_of_two();

        // Some pivot candidates will be found near this index. Let's randomize them.
        let pos = len / 4 * 2;

        for i in 0..3 {
            // Generate a random number modulo `len`. However, in order to avoid costly operations
            // we first take it modulo a power of two, and then decrease by `len` until it fits
            // into the range `[0, len - 1]`.
            let mut other = gen_usize() & (modulus - 1);

            // `other` is guaranteed to be less than `2 * len`.
            if other >= len {
                other -= len;
            }

            v.swap(pos - 1 + i, other);
        }
    }
}

/// Chooses a pivot in `v` and returns the index and `true` if the slice is likely already sorted.
///
/// Elements in `v` might be reordered in the process.
fn choose_pivot<T, F>(v: &mut [T], is_less: &F) -> (usize, bool)
where
    F: Fn(&T, &T) -> bool,
{
    // Minimum length to choose the median-of-medians method.
    // Shorter slices use the simple median-of-three method.
    const SHORTEST_MEDIAN_OF_MEDIANS: usize = 50;
    // Maximum number of swaps that can be performed in this function.
    const MAX_SWAPS: usize = 4 * 3;

    let len = v.len();

    // Three indices near which we are going to choose a pivot.
    #[allow(clippy::identity_op)]
    let mut a = len / 4 * 1;
    let mut b = len / 4 * 2;
    let mut c = len / 4 * 3;

    // Counts the total number of swaps we are about to perform while sorting indices.
    let mut swaps = 0;

    if len >= 8 {
        // Swaps indices so that `v[a] <= v[b]`.
        // SAFETY: `len >= 8` so there are at least two elements in the neighborhoods of
        // `a`, `b` and `c`. This means the three calls to `sort_adjacent` result in
        // corresponding calls to `sort3` with valid 3-item neighborhoods around each
        // pointer, which in turn means the calls to `sort2` are done with valid
        // references. Thus the `v.get_unchecked` calls are safe, as is the `ptr::swap`
        // call.
        let mut sort2 = |a: &mut usize, b: &mut usize| unsafe {
            if is_less(v.get_unchecked(*b), v.get_unchecked(*a)) {
                ptr::swap(a, b);
                swaps += 1;
            }
        };

        // Swaps indices so that `v[a] <= v[b] <= v[c]`.
        let mut sort3 = |a: &mut usize, b: &mut usize, c: &mut usize| {
            sort2(a, b);
            sort2(b, c);
            sort2(a, b);
        };

        if len >= SHORTEST_MEDIAN_OF_MEDIANS {
            // Finds the median of `v[a - 1], v[a], v[a + 1]` and stores the index into `a`.
            let mut sort_adjacent = |a: &mut usize| {
                let tmp = *a;
                sort3(&mut (tmp - 1), a, &mut (tmp + 1));
            };

            // Find medians in the neighborhoods of `a`, `b`, and `c`.
            sort_adjacent(&mut a);
            sort_adjacent(&mut b);
            sort_adjacent(&mut c);
        }

        // Find the median among `a`, `b`, and `c`.
        sort3(&mut a, &mut b, &mut c);
    }

    if swaps < MAX_SWAPS {
        (b, swaps == 0)
    } else {
        // The maximum number of swaps was performed. Chances are the slice is descending or mostly
        // descending, so reversing will probably help sort it faster.
        v.reverse();
        (len - 1 - b, true)
    }
}

/// Sorts `v` recursively.
///
/// If the slice had a predecessor in the original array, it is specified as `pred`.
///
/// `limit` is the number of allowed imbalanced partitions before switching to `heapsort`. If zero,
/// this function will immediately switch to heapsort.
fn recurse<'a, T, F>(mut v: &'a mut [T], is_less: &F, mut pred: Option<&'a mut T>, mut limit: u32)
where
    T: Send,
    F: Fn(&T, &T) -> bool + Sync,
{
    // Slices of up to this length get sorted using insertion sort.
    const MAX_INSERTION: usize = 20;
    // If both partitions are up to this length, we continue sequentially. This number is as small
    // as possible while keeping the overhead of Rayon's task scheduling negligible.
    const MAX_SEQUENTIAL: usize = 2000;

    // True if the last partitioning was reasonably balanced.
    let mut was_balanced = true;
    // True if the last partitioning didn't shuffle elements (the slice was already partitioned).
    let mut was_partitioned = true;

    loop {
        let len = v.len();

        // Very short slices get sorted using insertion sort.
        if len <= MAX_INSERTION {
            insertion_sort(v, is_less);
            return;
        }

        // If too many bad pivot choices were made, simply fall back to heapsort in order to
        // guarantee `O(n * log(n))` worst-case.
        if limit == 0 {
            heapsort(v, is_less);
            return;
        }

        // If the last partitioning was imbalanced, try breaking patterns in the slice by shuffling
        // some elements around. Hopefully we'll choose a better pivot this time.
        if !was_balanced {
            break_patterns(v);
            limit -= 1;
        }

        // Choose a pivot and try guessing whether the slice is already sorted.
        let (pivot, likely_sorted) = choose_pivot(v, is_less);

        // If the last partitioning was decently balanced and didn't shuffle elements, and if pivot
        // selection predicts the slice is likely already sorted...
        if was_balanced && was_partitioned && likely_sorted {
            // Try identifying several out-of-order elements and shifting them to correct
            // positions. If the slice ends up being completely sorted, we're done.
            if partial_insertion_sort(v, is_less) {
                return;
            }
        }

        // If the chosen pivot is equal to the predecessor, then it's the smallest element in the
        // slice. Partition the slice into elements equal to and elements greater than the pivot.
        // This case is usually hit when the slice contains many duplicate elements.
        if let Some(ref p) = pred {
            if !is_less(p, &v[pivot]) {
                let mid = partition_equal(v, pivot, is_less);

                // Continue sorting elements greater than the pivot.
                v = &mut v[mid..];
                continue;
            }
        }

        // Partition the slice.
        let (mid, was_p) = partition(v, pivot, is_less);
        was_balanced = cmp::min(mid, len - mid) >= len / 8;
        was_partitioned = was_p;

        // Split the slice into `left`, `pivot`, and `right`.
        let (left, right) = v.split_at_mut(mid);
        let (pivot, right) = right.split_at_mut(1);
        let pivot = &mut pivot[0];

        if cmp::max(left.len(), right.len()) <= MAX_SEQUENTIAL {
            // Recurse into the shorter side only in order to minimize the total number of recursive
            // calls and consume less stack space. Then just continue with the longer side (this is
            // akin to tail recursion).
            if left.len() < right.len() {
                recurse(left, is_less, pred, limit);
                v = right;
                pred = Some(pivot);
            } else {
                recurse(right, is_less, Some(pivot), limit);
                v = left;
            }
        } else {
            // Sort the left and right half in parallel.
            rayon_core::join(
                || recurse(left, is_less, pred, limit),
                || recurse(right, is_less, Some(pivot), limit),
            );
            break;
        }
    }
}

845 | /// Sorts `v` using pattern-defeating quicksort in parallel. |

846 | /// |

847 | /// The algorithm is unstable, in-place, and *O*(*n* \* log(*n*)) worst-case. |

848 | pub(super) fn par_quicksort<T, F>(v: &mut [T], is_less: F) |

849 | where |

850 | T: Send, |

851 | F: Fn(&T, &T) -> bool + Sync, |

852 | { |

853 | // Sorting has no meaningful behavior on zero-sized types. |

854 | if mem::size_of::<T>() == 0 { |

855 | return; |

856 | } |

857 | |

858 | // Limit the number of imbalanced partitions to `floor(log2(len)) + 1`. |

859 | let limit = usize::BITS - v.len().leading_zeros(); |

860 | |

861 | recurse(v, &is_less, None, limit); |

862 | } |

#[cfg(test)]
mod tests {
    use super::heapsort;
    use rand::distributions::Uniform;
    use rand::{thread_rng, Rng};

    #[test]
    fn test_heapsort() {
        let rng = &mut thread_rng();

        for len in (0..25).chain(500..501) {
            for &modulus in &[5, 10, 100] {
                let dist = Uniform::new(0, modulus);
                for _ in 0..100 {
                    let v: Vec<i32> = rng.sample_iter(&dist).take(len).collect();

                    // Test heapsort using `<` operator.
                    let mut tmp = v.clone();
                    heapsort(&mut tmp, &|a, b| a < b);
                    assert!(tmp.windows(2).all(|w| w[0] <= w[1]));

                    // Test heapsort using `>` operator.
                    let mut tmp = v.clone();
                    heapsort(&mut tmp, &|a, b| a > b);
                    assert!(tmp.windows(2).all(|w| w[0] >= w[1]));
                }
            }
        }

        // Sort using a completely random comparison function.
        // This will reorder the elements *somehow*, but won't panic.
        let mut v: Vec<_> = (0..100).collect();
        heapsort(&mut v, &|_, _| thread_rng().gen());
        heapsort(&mut v, &|a, b| a < b);

        for (i, &entry) in v.iter().enumerate() {
            assert_eq!(entry, i);
        }
    }
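
    #[test]
    fn test_par_quicksort() {
        // Smoke test for the parallel entry point (a minimal sketch added for
        // illustration; it assumes a default rayon-core thread pool can be spun
        // up in the test environment). The 5_000-element case is long enough to
        // exercise the parallel `join` path, while the short lengths hit the
        // sequential fallbacks.
        let rng = &mut thread_rng();
        let dist = Uniform::new(0, 1000);

        for &len in &[0, 1, 20, 5_000] {
            let mut v: Vec<i32> = rng.sample_iter(&dist).take(len).collect();
            super::par_quicksort(&mut v, |a, b| a < b);
            assert!(v.windows(2).all(|w| w[0] <= w[1]));
        }
    }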

}