1//===- SparseBufferRewriting.cpp - Sparse buffer rewriting rules ----------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements rewriting rules that are specific to sparse tensor
10// primitives with memref operands.
11//
12//===----------------------------------------------------------------------===//
13
14#include "Utils/CodegenUtils.h"
15
16#include "mlir/Dialect/Arith/IR/Arith.h"
17#include "mlir/Dialect/Func/IR/FuncOps.h"
18#include "mlir/Dialect/Linalg/IR/Linalg.h"
19#include "mlir/Dialect/Math/IR/Math.h"
20#include "mlir/Dialect/MemRef/IR/MemRef.h"
21#include "mlir/Dialect/SCF/IR/SCF.h"
22#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
23#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
24#include "mlir/Support/LLVM.h"
25
26using namespace mlir;
27using namespace mlir::sparse_tensor;
28
29//===---------------------------------------------------------------------===//
30// Helper methods for the actual rewriting rules.
31//===---------------------------------------------------------------------===//
32
// Operand layout shared by every generated sort helper: the lower bound, the
// upper bound, and then the first (x) buffer.
static constexpr uint64_t loIdx = 0;
static constexpr uint64_t hiIdx = 1;
static constexpr uint64_t xStartIdx = 2;

// Name prefixes used when mangling the generated helper functions.
static constexpr char kPartitionFuncNamePrefix[] = "_sparse_partition_";
static constexpr char kBinarySearchFuncNamePrefix[] = "_sparse_binary_search_";
static constexpr char kHybridQuickSortFuncNamePrefix[] = "_sparse_hybrid_qsort_";
static constexpr char kSortStableFuncNamePrefix[] = "_sparse_sort_stable_";
static constexpr char kShiftDownFuncNamePrefix[] = "_sparse_shift_down_";
static constexpr char kHeapSortFuncNamePrefix[] = "_sparse_heap_sort_";
static constexpr char kQuickSortFuncNamePrefix[] = "_sparse_qsort_";
47
48using FuncGeneratorType = function_ref<void(OpBuilder &, ModuleOp, func::FuncOp,
49 AffineMap, uint64_t, uint32_t)>;
50
51/// Constructs a function name with this format to facilitate quick sort:
52/// <namePrefix><xPerm>_<x type>_<y0 type>..._<yn type> for sort
53/// <namePrefix><xPerm>_<x type>_coo_<ny>_<y0 type>..._<yn type> for sort_coo
54static void getMangledSortHelperFuncName(llvm::raw_svector_ostream &nameOstream,
55 StringRef namePrefix, AffineMap xPerm,
56 uint64_t ny, ValueRange operands) {
57 nameOstream << namePrefix;
58 for (auto res : xPerm.getResults())
59 nameOstream << cast<AffineDimExpr>(Val&: res).getPosition() << "_";
60
61 nameOstream << getMemRefType(operands[xStartIdx]).getElementType();
62 nameOstream << "_coo_" << ny;
63
64 constexpr uint64_t yBufferOffset = 1;
65 for (Value v : operands.drop_front(n: xStartIdx + yBufferOffset))
66 nameOstream << "_" << getMemRefType(v).getElementType();
67}
68
69/// Looks up a function that is appropriate for the given operands being
70/// sorted, and creates such a function if it doesn't exist yet. The
71/// parameters `xPerm` and `ny` tell the number of x and y values provided
72/// by the buffer in xStartIdx.
73//
74// All sorting function generators take (lo, hi, xs, ys) in `operands` as
75// parameters for the sorting functions. Other parameters, such as the recursive
76// call depth, are appended to the end of the parameter list as
77// "trailing parameters".
78static FlatSymbolRefAttr getMangledSortHelperFunc(
79 OpBuilder &builder, func::FuncOp insertPoint, TypeRange resultTypes,
80 StringRef namePrefix, AffineMap xPerm, uint64_t ny, ValueRange operands,
81 FuncGeneratorType createFunc, uint32_t nTrailingP = 0) {
82 SmallString<32> nameBuffer;
83 llvm::raw_svector_ostream nameOstream(nameBuffer);
84 getMangledSortHelperFuncName(nameOstream, namePrefix, xPerm, ny,
85 operands: operands.drop_back(n: nTrailingP));
86
87 ModuleOp module = insertPoint->getParentOfType<ModuleOp>();
88 MLIRContext *context = module.getContext();
89 auto result = SymbolRefAttr::get(context, nameOstream.str());
90 auto func = module.lookupSymbol<func::FuncOp>(result.getAttr());
91
92 if (!func) {
93 // Create the function.
94 OpBuilder::InsertionGuard insertionGuard(builder);
95 builder.setInsertionPoint(insertPoint);
96 Location loc = insertPoint.getLoc();
97 func = builder.create<func::FuncOp>(
98 loc, nameOstream.str(),
99 FunctionType::get(context, operands.getTypes(), resultTypes));
100 func.setPrivate();
101 createFunc(builder, module, func, xPerm, ny, nTrailingP);
102 }
103
104 return result;
105}
106
107/// Creates a code block to process each pair of (xs[i], xs[j]) for sorting.
108/// The code to process the value pairs is generated by `bodyBuilder`.
109static void forEachIJPairInXs(
110 OpBuilder &builder, Location loc, ValueRange args, AffineMap xPerm,
111 uint64_t ny,
112 function_ref<void(uint64_t, Value, Value, Value)> bodyBuilder) {
113 Value cstep = constantIndex(builder, loc, i: xPerm.getNumResults() + ny);
114 Value iOffset = builder.create<arith::MulIOp>(loc, args[0], cstep);
115 Value jOffset = builder.create<arith::MulIOp>(loc, args[1], cstep);
116 for (unsigned k = 0, e = xPerm.getNumResults(); k < e; k++) {
117 unsigned actualK = cast<AffineDimExpr>(Val: xPerm.getResult(idx: k)).getPosition();
118 Value ak = constantIndex(builder, loc, i: actualK);
119 Value i = builder.create<arith::AddIOp>(loc, ak, iOffset);
120 Value j = builder.create<arith::AddIOp>(loc, ak, jOffset);
121 Value buffer = args[xStartIdx];
122
123 bodyBuilder(k, i, j, buffer);
124 }
125}
126
127/// Creates a code block to process each pair of (xys[i], xys[j]) for sorting.
128/// The code to process the value pairs is generated by `bodyBuilder`.
129static void forEachIJPairInAllBuffers(
130 OpBuilder &builder, Location loc, ValueRange args, AffineMap xPerm,
131 uint64_t ny,
132 function_ref<void(uint64_t, Value, Value, Value)> bodyBuilder) {
133
134 // Create code for the first (xPerm + ny) buffers.
135 SmallVector<AffineExpr> exps(xPerm.getResults());
136 for (unsigned y = 0; y < ny; y++) {
137 exps.push_back(Elt: builder.getAffineDimExpr(position: y + xPerm.getNumResults()));
138 }
139 AffineMap xyPerm = AffineMap::get(dimCount: exps.size(), symbolCount: 0, results: exps, context: builder.getContext());
140 assert(xyPerm.isPermutation());
141
142 forEachIJPairInXs(builder, loc, args, xPerm: xyPerm, ny: 0, bodyBuilder);
143
144 constexpr uint64_t numHandledBuffers = 1;
145 // Create code for the remaining buffers.
146 Value i = args[0];
147 Value j = args[1];
148 for (const auto &arg :
149 llvm::enumerate(First: args.drop_front(n: xStartIdx + numHandledBuffers))) {
150 bodyBuilder(arg.index() + xPerm.getNumResults() + ny, i, j, arg.value());
151 }
152}
153
154/// Creates a code block for swapping the values in index i and j for all the
155/// buffers.
156//
157// The generated IR corresponds to this C like algorithm:
158// swap(x0[i], x0[j]);
159// swap(x1[i], x1[j]);
160// ...
161// swap(xn[i], xn[j]);
162// swap(y0[i], y0[j]);
163// ...
164// swap(yn[i], yn[j]);
165static void createSwap(OpBuilder &builder, Location loc, ValueRange args,
166 AffineMap xPerm, uint64_t ny) {
167 auto swapOnePair = [&](uint64_t unused, Value i, Value j, Value buffer) {
168 Value vi = builder.create<memref::LoadOp>(loc, buffer, i);
169 Value vj = builder.create<memref::LoadOp>(loc, buffer, j);
170 builder.create<memref::StoreOp>(loc, vj, buffer, i);
171 builder.create<memref::StoreOp>(loc, vi, buffer, j);
172 };
173
174 forEachIJPairInAllBuffers(builder, loc, args, xPerm, ny, bodyBuilder: swapOnePair);
175}
176
177/// Creates code to compare all the (xs[i], xs[j]) pairs. The method to compare
178/// each pair is create via `compareBuilder`.
179static Value createInlinedCompareImplementation(
180 OpBuilder &builder, Location loc, ValueRange args, AffineMap xPerm,
181 uint64_t ny,
182 function_ref<Value(OpBuilder &, Location, Value, Value, Value, bool, bool)>
183 compareBuilder) {
184 Value result;
185 auto bodyBuilder = [&](uint64_t k, Value i, Value j, Value buffer) {
186 bool isFirstDim = (k == 0);
187 bool isLastDim = (k == xPerm.getNumResults() - 1);
188 Value val =
189 compareBuilder(builder, loc, i, j, buffer, isFirstDim, isLastDim);
190 if (isFirstDim) {
191 result = val;
192 } else if (!isLastDim) {
193 OpBuilder::InsertionGuard insertionGuard(builder);
194 auto ifOp = cast<scf::IfOp>(val.getDefiningOp());
195 builder.setInsertionPointAfter(ifOp);
196 builder.create<scf::YieldOp>(loc, ifOp.getResult(0));
197 }
198 };
199
200 forEachIJPairInXs(builder, loc, args, xPerm, ny, bodyBuilder);
201
202 builder.setInsertionPointAfterValue(result);
203 return result;
204}
205
206/// Generates code to compare whether x[i] is equal to x[j] and returns the
207/// result of the comparison.
208static Value createEqCompare(OpBuilder &builder, Location loc, Value i, Value j,
209 Value x, bool isFirstDim, bool isLastDim) {
210 Value vi = builder.create<memref::LoadOp>(loc, x, i);
211 Value vj = builder.create<memref::LoadOp>(loc, x, j);
212
213 Value res;
214 if (isLastDim) {
215 res = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, vi, vj);
216 // For 1D, we create a compare without any control flow. Otherwise, we
217 // create YieldOp to return the result in the nested if-stmt.
218 if (!isFirstDim)
219 builder.create<scf::YieldOp>(loc, res);
220 } else {
221 Value ne =
222 builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne, vi, vj);
223 scf::IfOp ifOp = builder.create<scf::IfOp>(loc, builder.getIntegerType(1),
224 ne, /*else=*/true);
225 // If (x[i] != x[j]).
226 builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
227 Value f = constantI1(builder, loc, b: false);
228 builder.create<scf::YieldOp>(loc, f);
229
230 // If (x[i] == x[j]). Set up the insertion point for the nested if-stmt that
231 // checks the remaining dimensions.
232 builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
233 res = ifOp.getResult(0);
234 }
235
236 return res;
237}
238
239/// Creates code to compare whether xs[i] is equal to xs[j].
240//
241// The generate IR corresponds to this C like algorithm:
242// if (x0[i] != x0[j])
243// return false;
244// else
245// if (x1[i] != x1[j])
246// return false;
247// else if (x2[2] != x2[j]))
248// and so on ...
249static Value createInlinedEqCompare(OpBuilder &builder, Location loc,
250 ValueRange args, AffineMap xPerm,
251 uint64_t ny, uint32_t nTrailingP = 0) {
252 // Compare functions don't use trailing parameters.
253 (void)nTrailingP;
254 assert(nTrailingP == 0);
255 return createInlinedCompareImplementation(builder, loc, args, xPerm, ny,
256 compareBuilder: createEqCompare);
257}
258
259/// Generates code to compare whether x[i] is less than x[j] and returns the
260/// result of the comparison.
261static Value createLessThanCompare(OpBuilder &builder, Location loc, Value i,
262 Value j, Value x, bool isFirstDim,
263 bool isLastDim) {
264 Value vi = builder.create<memref::LoadOp>(loc, x, i);
265 Value vj = builder.create<memref::LoadOp>(loc, x, j);
266
267 Value res;
268 if (isLastDim) {
269 res = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, vi, vj);
270 // For 1D, we create a compare without any control flow. Otherwise, we
271 // create YieldOp to return the result in the nested if-stmt.
272 if (!isFirstDim)
273 builder.create<scf::YieldOp>(loc, res);
274 } else {
275 Value ne =
276 builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne, vi, vj);
277 scf::IfOp ifOp = builder.create<scf::IfOp>(loc, builder.getIntegerType(1),
278 ne, /*else=*/true);
279 // If (x[i] != x[j]).
280 builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
281 Value lt =
282 builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, vi, vj);
283 builder.create<scf::YieldOp>(loc, lt);
284
285 // If (x[i] == x[j]). Set up the insertion point for the nested if-stmt that
286 // checks the remaining dimensions.
287 builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
288 res = ifOp.getResult(0);
289 }
290
291 return res;
292}
293
294/// Creates code to compare whether xs[i] is less than xs[j].
295//
296// The generate IR corresponds to this C like algorithm:
297// if (x0[i] != x0[j])
298// return x0[i] < x0[j];
299// else if (x1[j] != x1[i])
300// return x1[i] < x1[j];
301// else
302// and so on ...
303static Value createInlinedLessThan(OpBuilder &builder, Location loc,
304 ValueRange args, AffineMap xPerm,
305 uint64_t ny, uint32_t nTrailingP = 0) {
306 // Compare functions don't use trailing parameters.
307 (void)nTrailingP;
308 assert(nTrailingP == 0);
309 return createInlinedCompareImplementation(builder, loc, args, xPerm, ny,
310 compareBuilder: createLessThanCompare);
311}
312
313/// Creates a function to use a binary search to find the insertion point for
314/// inserting xs[hi] to the sorted values xs[lo..hi).
315//
316// The generate IR corresponds to this C like algorithm:
317// p = hi
318// while (lo < hi)
319// mid = (lo + hi) >> 1
320// if (xs[p] < xs[mid])
321// hi = mid
322// else
323// lo = mid - 1
324// return lo;
325//
326static void createBinarySearchFunc(OpBuilder &builder, ModuleOp module,
327 func::FuncOp func, AffineMap xPerm,
328 uint64_t ny, uint32_t nTrailingP = 0) {
329 // Binary search doesn't use trailing parameters.
330 (void)nTrailingP;
331 assert(nTrailingP == 0);
332 OpBuilder::InsertionGuard insertionGuard(builder);
333 Block *entryBlock = func.addEntryBlock();
334 builder.setInsertionPointToStart(entryBlock);
335
336 Location loc = func.getLoc();
337 ValueRange args = entryBlock->getArguments();
338 Value p = args[hiIdx];
339 SmallVector<Type, 2> types(2, p.getType()); // Only two types.
340 scf::WhileOp whileOp = builder.create<scf::WhileOp>(
341 loc, types, SmallVector<Value, 2>{args[loIdx], args[hiIdx]});
342
343 // The before-region of the WhileOp.
344 Block *before =
345 builder.createBlock(&whileOp.getBefore(), {}, types, {loc, loc});
346 builder.setInsertionPointToEnd(before);
347 Value cond1 = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult,
348 before->getArgument(0),
349 before->getArgument(1));
350 builder.create<scf::ConditionOp>(loc, cond1, before->getArguments());
351
352 // The after-region of the WhileOp.
353 Block *after =
354 builder.createBlock(&whileOp.getAfter(), {}, types, {loc, loc});
355 builder.setInsertionPointToEnd(after);
356 Value lo = after->getArgument(i: 0);
357 Value hi = after->getArgument(i: 1);
358 // Compute mid = (lo + hi) >> 1.
359 Value c1 = constantIndex(builder, loc, i: 1);
360 Value mid = builder.create<arith::ShRUIOp>(
361 loc, builder.create<arith::AddIOp>(loc, lo, hi), c1);
362 Value midp1 = builder.create<arith::AddIOp>(loc, mid, c1);
363
364 // Compare xs[p] < xs[mid].
365 SmallVector<Value> compareOperands{p, mid};
366 constexpr uint64_t numXBuffers = 1;
367 compareOperands.append(in_start: args.begin() + xStartIdx,
368 in_end: args.begin() + xStartIdx + numXBuffers);
369 Value cond2 = createInlinedLessThan(builder, loc, args: compareOperands, xPerm, ny);
370 // Update lo and hi for the WhileOp as follows:
371 // if (xs[p] < xs[mid]))
372 // hi = mid;
373 // else
374 // lo = mid + 1;
375 Value newLo = builder.create<arith::SelectOp>(loc, cond2, lo, midp1);
376 Value newHi = builder.create<arith::SelectOp>(loc, cond2, mid, hi);
377 builder.create<scf::YieldOp>(loc, ValueRange{newLo, newHi});
378
379 builder.setInsertionPointAfter(whileOp);
380 builder.create<func::ReturnOp>(loc, whileOp.getResult(0));
381}
382
383/// Creates code to advance i in a loop based on xs[p] as follows:
384/// while (xs[i] < xs[p]) i += step (step > 0)
385/// or
386/// while (xs[i] > xs[p]) i += step (step < 0)
387/// The routine returns i as well as a boolean value to indicate whether
388/// xs[i] == xs[p].
389static std::pair<Value, Value> createScanLoop(OpBuilder &builder,
390 ModuleOp module,
391 func::FuncOp func, ValueRange xs,
392 Value i, Value p, AffineMap xPerm,
393 uint64_t ny, int step) {
394 Location loc = func.getLoc();
395 scf::WhileOp whileOp =
396 builder.create<scf::WhileOp>(loc, TypeRange{i.getType()}, ValueRange{i});
397
398 Block *before =
399 builder.createBlock(&whileOp.getBefore(), {}, {i.getType()}, {loc});
400 builder.setInsertionPointToEnd(before);
401 SmallVector<Value> compareOperands;
402 if (step > 0) {
403 compareOperands.push_back(Elt: before->getArgument(i: 0));
404 compareOperands.push_back(Elt: p);
405 } else {
406 assert(step < 0);
407 compareOperands.push_back(Elt: p);
408 compareOperands.push_back(Elt: before->getArgument(i: 0));
409 }
410 compareOperands.append(in_start: xs.begin(), in_end: xs.end());
411 Value cond = createInlinedLessThan(builder, loc, args: compareOperands, xPerm, ny);
412 builder.create<scf::ConditionOp>(loc, cond, before->getArguments());
413
414 Block *after =
415 builder.createBlock(&whileOp.getAfter(), {}, {i.getType()}, {loc});
416 builder.setInsertionPointToEnd(after);
417 Value cs = constantIndex(builder, loc, i: step);
418 i = builder.create<arith::AddIOp>(loc, after->getArgument(0), cs);
419 builder.create<scf::YieldOp>(loc, ValueRange{i});
420 i = whileOp.getResult(0);
421
422 builder.setInsertionPointAfter(whileOp);
423 compareOperands[0] = i;
424 compareOperands[1] = p;
425 Value compareEq =
426 createInlinedEqCompare(builder, loc, args: compareOperands, xPerm, ny);
427
428 return std::make_pair(whileOp.getResult(0), compareEq);
429}
430
431/// Creates and returns an IfOp to compare two elements and swap the elements
432/// if compareFunc(data[b], data[a]) returns true. The new insertion point is
433/// right after the swap instructions.
434static scf::IfOp createCompareThenSwap(OpBuilder &builder, Location loc,
435 AffineMap xPerm, uint64_t ny,
436 SmallVectorImpl<Value> &swapOperands,
437 SmallVectorImpl<Value> &compareOperands,
438 Value a, Value b) {
439 // Compare(data[b], data[a]).
440 compareOperands[0] = b;
441 compareOperands[1] = a;
442 Value cond = createInlinedLessThan(builder, loc, args: compareOperands, xPerm, ny);
443 scf::IfOp ifOp = builder.create<scf::IfOp>(loc, cond, /*else=*/false);
444 builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
445 swapOperands[0] = b;
446 swapOperands[1] = a;
447 createSwap(builder, loc, args: swapOperands, xPerm, ny);
448 return ifOp;
449}
450
451/// Creates code to insert the 3rd element to a list of two sorted elements.
452static void createInsert3rd(OpBuilder &builder, Location loc, AffineMap xPerm,
453 uint64_t ny, SmallVectorImpl<Value> &swapOperands,
454 SmallVectorImpl<Value> &compareOperands, Value v0,
455 Value v1, Value v2) {
456 scf::IfOp ifOp = createCompareThenSwap(builder, loc, xPerm, ny, swapOperands,
457 compareOperands, v1, v2);
458 createCompareThenSwap(builder, loc, xPerm, ny, swapOperands, compareOperands,
459 a: v0, b: v1);
460 builder.setInsertionPointAfter(ifOp);
461}
462
463/// Creates code to sort 3 elements.
464static void createSort3(OpBuilder &builder, Location loc, AffineMap xPerm,
465 uint64_t ny, SmallVectorImpl<Value> &swapOperands,
466 SmallVectorImpl<Value> &compareOperands, Value v0,
467 Value v1, Value v2) {
468 // Sort the first 2 elements.
469 scf::IfOp ifOp1 = createCompareThenSwap(builder, loc, xPerm, ny, swapOperands,
470 compareOperands, v0, v1);
471 builder.setInsertionPointAfter(ifOp1);
472
473 // Insert the 3th element.
474 createInsert3rd(builder, loc, xPerm, ny, swapOperands, compareOperands, v0,
475 v1, v2);
476}
477
478/// Creates code to sort 5 elements.
479static void createSort5(OpBuilder &builder, Location loc, AffineMap xPerm,
480 uint64_t ny, SmallVectorImpl<Value> &swapOperands,
481 SmallVectorImpl<Value> &compareOperands, Value v0,
482 Value v1, Value v2, Value v3, Value v4) {
483 // Sort the first 3 elements.
484 createSort3(builder, loc, xPerm, ny, swapOperands, compareOperands, v0, v1,
485 v2);
486
487 auto insert4th = [&]() {
488 scf::IfOp ifOp = createCompareThenSwap(
489 builder, loc, xPerm, ny, swapOperands, compareOperands, v2, v3);
490 createInsert3rd(builder, loc, xPerm, ny, swapOperands, compareOperands, v0,
491 v1, v2);
492 builder.setInsertionPointAfter(ifOp);
493 };
494
495 // Insert the 4th element.
496 insert4th();
497
498 // Insert the 5th element.
499 scf::IfOp ifOp = createCompareThenSwap(builder, loc, xPerm, ny, swapOperands,
500 compareOperands, v3, v4);
501 insert4th();
502 builder.setInsertionPointAfter(ifOp);
503}
504
505/// Creates a code block to swap the values in indices lo, mi, and hi so that
506/// data[lo], data[mi] and data[hi] are sorted in non-decreasing values. When
507/// the number of values in range [lo, hi) is more than a threshold, we also
508/// include the middle of [lo, mi) and [mi, hi) and sort a total of five values.
509static void createChoosePivot(OpBuilder &builder, ModuleOp module,
510 func::FuncOp func, AffineMap xPerm, uint64_t ny,
511 Value lo, Value hi, Value mi, ValueRange args) {
512 SmallVector<Value> compareOperands{mi, lo};
513 constexpr uint64_t numXBuffers = 1;
514 compareOperands.append(in_start: args.begin() + xStartIdx,
515 in_end: args.begin() + xStartIdx + numXBuffers);
516 SmallVector<Value> swapOperands{mi, lo};
517 swapOperands.append(in_start: args.begin() + xStartIdx, in_end: args.end());
518 Location loc = func.getLoc();
519 Value c1 = constantIndex(builder, loc, i: 1);
520 Value hiP1 = builder.create<arith::AddIOp>(loc, hi, c1);
521 Value len = builder.create<arith::SubIOp>(loc, hiP1, lo);
522 Value lenThreshold = constantIndex(builder, loc, i: 1000);
523 Value lenCond = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult,
524 len, lenThreshold);
525 scf::IfOp lenIf = builder.create<scf::IfOp>(loc, lenCond, /*else=*/true);
526
527 // When len < 1000, choose pivot from median of 3 values.
528 builder.setInsertionPointToStart(&lenIf.getThenRegion().front());
529 createSort3(builder, loc, xPerm, ny, swapOperands, compareOperands, v0: lo, v1: mi,
530 v2: hi);
531
532 // When len >= 1000, choose pivot from median of 5 values.
533 builder.setInsertionPointToStart(&lenIf.getElseRegion().front());
534 Value miP1 = builder.create<arith::AddIOp>(loc, hi, c1);
535 Value a = builder.create<arith::AddIOp>(loc, lo, miP1);
536 // Value a is the middle between [loc, mi].
537 a = builder.create<arith::ShRUIOp>(loc, a, c1);
538 Value b = builder.create<arith::AddIOp>(loc, mi, hiP1);
539 // Value b is the middle between [mi, hi].
540 b = builder.create<arith::ShRUIOp>(loc, b, c1);
541 createSort5(builder, loc, xPerm, ny, swapOperands, compareOperands, v0: lo, v1: a, v2: mi,
542 v3: b, v4: hi);
543
544 builder.setInsertionPointAfter(lenIf);
545}
546
547/// Creates a function to perform quick sort partition on the values in the
548/// range of index [lo, hi), assuming lo < hi.
549//
550// The generated IR corresponds to this C like algorithm:
551// int partition(lo, hi, xs) {
552// p = (lo+hi)/2 // pivot index
553// i = lo
554// j = hi-1
555// while (true) do {
556// while (xs[i] < xs[p]) i ++;
557// i_eq = (xs[i] == xs[p]);
558// while (xs[j] > xs[p]) j --;
559// j_eq = (xs[j] == xs[p]);
560//
561// if (i >= j) return j + 1;
562//
563// if (i < j) {
564// swap(xs[i], xs[j])
565// if (i == p) {
566// p = j;
567// } else if (j == p) {
568// p = i;
569// }
570// if (i_eq && j_eq) {
571// ++i;
572// --j;
573// }
574// }
575// }
576// }
577static void createPartitionFunc(OpBuilder &builder, ModuleOp module,
578 func::FuncOp func, AffineMap xPerm, uint64_t ny,
579 uint32_t nTrailingP = 0) {
580 // Quick sort partition doesn't use trailing parameters.
581 (void)nTrailingP;
582 assert(nTrailingP == 0);
583 OpBuilder::InsertionGuard insertionGuard(builder);
584
585 Block *entryBlock = func.addEntryBlock();
586 builder.setInsertionPointToStart(entryBlock);
587
588 Location loc = func.getLoc();
589 ValueRange args = entryBlock->getArguments();
590 Value lo = args[loIdx];
591 Value hi = args[hiIdx];
592 Value sum = builder.create<arith::AddIOp>(loc, lo, hi);
593 Value c1 = constantIndex(builder, loc, i: 1);
594 Value p = builder.create<arith::ShRUIOp>(loc, sum, c1);
595
596 Value i = lo;
597 Value j = builder.create<arith::SubIOp>(loc, hi, c1);
598 createChoosePivot(builder, module, func, xPerm, ny, i, j, p, args);
599 Value trueVal = constantI1(builder, loc, b: true); // The value for while (true)
600 SmallVector<Value, 4> operands{i, j, p, trueVal}; // Exactly four values.
601 SmallVector<Type, 4> types{i.getType(), j.getType(), p.getType(),
602 trueVal.getType()};
603 scf::WhileOp whileOp = builder.create<scf::WhileOp>(loc, types, operands);
604
605 // The before-region of the WhileOp.
606 Block *before = builder.createBlock(&whileOp.getBefore(), {}, types,
607 {loc, loc, loc, loc});
608 builder.setInsertionPointToEnd(before);
609 builder.create<scf::ConditionOp>(loc, before->getArgument(3),
610 before->getArguments());
611
612 // The after-region of the WhileOp.
613 Block *after =
614 builder.createBlock(&whileOp.getAfter(), {}, types, {loc, loc, loc, loc});
615 builder.setInsertionPointToEnd(after);
616 i = after->getArgument(i: 0);
617 j = after->getArgument(i: 1);
618 p = after->getArgument(i: 2);
619
620 constexpr uint64_t numXBuffers = 1;
621 auto [iresult, iCompareEq] =
622 createScanLoop(builder, module, func, args.slice(n: xStartIdx, m: numXBuffers),
623 i, p, xPerm, ny, 1);
624 i = iresult;
625 auto [jresult, jCompareEq] =
626 createScanLoop(builder, module, func, args.slice(n: xStartIdx, m: numXBuffers),
627 j, p, xPerm, ny, -1);
628 j = jresult;
629
630 // If i < j:
631 Value cond =
632 builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, i, j);
633 scf::IfOp ifOp = builder.create<scf::IfOp>(loc, types, cond, /*else=*/true);
634 builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
635 SmallVector<Value> swapOperands{i, j};
636 swapOperands.append(in_start: args.begin() + xStartIdx, in_end: args.end());
637 createSwap(builder, loc, args: swapOperands, xPerm, ny);
638 // If the pivot is moved, update p with the new pivot.
639 Value icond =
640 builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, i, p);
641 scf::IfOp ifOpI = builder.create<scf::IfOp>(loc, TypeRange{p.getType()},
642 icond, /*else=*/true);
643 builder.setInsertionPointToStart(&ifOpI.getThenRegion().front());
644 builder.create<scf::YieldOp>(loc, ValueRange{j});
645 builder.setInsertionPointToStart(&ifOpI.getElseRegion().front());
646 Value jcond =
647 builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, j, p);
648 scf::IfOp ifOpJ = builder.create<scf::IfOp>(loc, TypeRange{p.getType()},
649 jcond, /*else=*/true);
650 builder.setInsertionPointToStart(&ifOpJ.getThenRegion().front());
651 builder.create<scf::YieldOp>(loc, ValueRange{i});
652 builder.setInsertionPointToStart(&ifOpJ.getElseRegion().front());
653 builder.create<scf::YieldOp>(loc, ValueRange{p});
654 builder.setInsertionPointAfter(ifOpJ);
655 builder.create<scf::YieldOp>(loc, ifOpJ.getResults());
656 builder.setInsertionPointAfter(ifOpI);
657 Value compareEqIJ =
658 builder.create<arith::AndIOp>(loc, iCompareEq, jCompareEq);
659 scf::IfOp ifOp2 = builder.create<scf::IfOp>(
660 loc, TypeRange{i.getType(), j.getType()}, compareEqIJ, /*else=*/true);
661 builder.setInsertionPointToStart(&ifOp2.getThenRegion().front());
662 Value i2 = builder.create<arith::AddIOp>(loc, i, c1);
663 Value j2 = builder.create<arith::SubIOp>(loc, j, c1);
664 builder.create<scf::YieldOp>(loc, ValueRange{i2, j2});
665 builder.setInsertionPointToStart(&ifOp2.getElseRegion().front());
666 builder.create<scf::YieldOp>(loc, ValueRange{i, j});
667 builder.setInsertionPointAfter(ifOp2);
668 builder.create<scf::YieldOp>(
669 loc,
670 ValueRange{ifOp2.getResult(0), ifOp2.getResult(1), ifOpI.getResult(0),
671 /*cont=*/constantI1(builder, loc, true)});
672
673 // False branch for if i < j (i.e., i >= j):
674 builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
675 p = builder.create<arith::AddIOp>(loc, j,
676 constantOne(builder, loc, j.getType()));
677 builder.create<scf::YieldOp>(
678 loc, ValueRange{i, j, p, /*cont=*/constantI1(builder, loc, false)});
679
680 // Return for the whileOp.
681 builder.setInsertionPointAfter(ifOp);
682 builder.create<scf::YieldOp>(loc, ifOp.getResults());
683
684 // Return for the function.
685 builder.setInsertionPointAfter(whileOp);
686 builder.create<func::ReturnOp>(loc, whileOp.getResult(2));
687}
688
689/// Computes (n-2)/n, assuming n has index type.
690static Value createSubTwoDividedByTwo(OpBuilder &builder, Location loc,
691 Value n) {
692 Value i2 = constantIndex(builder, loc, i: 2);
693 Value res = builder.create<arith::SubIOp>(loc, n, i2);
694 Value i1 = constantIndex(builder, loc, i: 1);
695 return builder.create<arith::ShRUIOp>(loc, res, i1);
696}
697
698/// Creates a function to heapify the subtree with root `start` within the full
699/// binary tree in the range of index [first, first + n).
700//
701// The generated IR corresponds to this C like algorithm:
702// void shiftDown(first, start, n, data) {
703// if (n >= 2) {
704// child = start - first
705// if ((n-2)/2 >= child) {
706// // Left child exists.
707// child = child * 2 + 1 // Initialize the bigger child to left child.
708// childIndex = child + first
// if (child+1 < n && data[childIndex] < data[childIndex+1])
// // Right child exists and is bigger.
// childIndex++; child++;
// // Shift data[start] down to where it belongs in the subtree.
// while (data[start] < data[childIndex]) {
// swap(data[start], data[childIndex])
// start = childIndex
// if ((n - 2)/2 >= child) {
// // Left child exists.
// child = 2*child + 1
// childIndex = child + first
// if (child + 1 < n && data[childIndex] < data[childIndex+1])
// childIndex++; child++;
722// }
723// }
724// }
725// }
726// }
727//
728static void createShiftDownFunc(OpBuilder &builder, ModuleOp module,
729 func::FuncOp func, AffineMap xPerm, uint64_t ny,
730 uint32_t nTrailingP) {
731 // The value n is passed in as a trailing parameter.
732 assert(nTrailingP == 1);
733 OpBuilder::InsertionGuard insertionGuard(builder);
734 Block *entryBlock = func.addEntryBlock();
735 builder.setInsertionPointToStart(entryBlock);
736
737 Location loc = func.getLoc();
738 Value n = entryBlock->getArguments().back();
739 ValueRange args = entryBlock->getArguments().drop_back();
740 Value first = args[loIdx];
741 Value start = args[hiIdx];
742
743 // If (n >= 2).
744 Value c2 = constantIndex(builder, loc, i: 2);
745 Value condN =
746 builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::uge, n, c2);
747 scf::IfOp ifN = builder.create<scf::IfOp>(loc, condN, /*else=*/false);
748 builder.setInsertionPointToStart(&ifN.getThenRegion().front());
749 Value child = builder.create<arith::SubIOp>(loc, start, first);
750
751 // If ((n-2)/2 >= child).
752 Value t = createSubTwoDividedByTwo(builder, loc, n);
753 Value condNc =
754 builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::uge, t, child);
755 scf::IfOp ifNc = builder.create<scf::IfOp>(loc, condNc, /*else=*/false);
756
757 builder.setInsertionPointToStart(&ifNc.getThenRegion().front());
758 Value c1 = constantIndex(builder, loc, i: 1);
759 SmallVector<Value> compareOperands{start, start};
760 constexpr uint64_t numXBuffers = 1;
761 compareOperands.append(in_start: args.begin() + xStartIdx,
762 in_end: args.begin() + xStartIdx + numXBuffers);
763
764 // Generate code to inspect the children of 'r' and return the larger child
765 // as follows:
766 // child = r * 2 + 1 // Left child.
767 // childIndex = child + first
768 // if (child+1 < n && data[childIndex] < data[childIndex+1])
769 // childIndex ++; child ++ // Right child is bigger.
770 auto getLargerChild = [&](Value r) -> std::pair<Value, Value> {
771 Value lChild = builder.create<arith::ShLIOp>(loc, r, c1);
772 lChild = builder.create<arith::AddIOp>(loc, lChild, c1);
773 Value lChildIdx = builder.create<arith::AddIOp>(loc, lChild, first);
774 Value rChild = builder.create<arith::AddIOp>(loc, lChild, c1);
775 Value cond1 = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult,
776 rChild, n);
777 SmallVector<Type, 2> ifTypes(2, r.getType());
778 scf::IfOp if1 =
779 builder.create<scf::IfOp>(loc, ifTypes, cond1, /*else=*/true);
780 builder.setInsertionPointToStart(&if1.getThenRegion().front());
781 Value rChildIdx = builder.create<arith::AddIOp>(loc, rChild, first);
782 // Compare data[left] < data[right].
783 compareOperands[0] = lChildIdx;
784 compareOperands[1] = rChildIdx;
785 Value cond2 =
786 createInlinedLessThan(builder, loc, args: compareOperands, xPerm, ny);
787 scf::IfOp if2 =
788 builder.create<scf::IfOp>(loc, ifTypes, cond2, /*else=*/true);
789 builder.setInsertionPointToStart(&if2.getThenRegion().front());
790 builder.create<scf::YieldOp>(loc, ValueRange{rChild, rChildIdx});
791 builder.setInsertionPointToStart(&if2.getElseRegion().front());
792 builder.create<scf::YieldOp>(loc, ValueRange{lChild, lChildIdx});
793 builder.setInsertionPointAfter(if2);
794 builder.create<scf::YieldOp>(loc, if2.getResults());
795 builder.setInsertionPointToStart(&if1.getElseRegion().front());
796 builder.create<scf::YieldOp>(loc, ValueRange{lChild, lChildIdx});
797 builder.setInsertionPointAfter(if1);
798 return std::make_pair(if1.getResult(0), if1.getResult(1));
799 };
800
801 Value childIdx;
802 std::tie(args&: child, args&: childIdx) = getLargerChild(child);
803
804 // While (data[start] < data[childIndex]).
805 SmallVector<Type, 3> types(3, child.getType());
806 scf::WhileOp whileOp = builder.create<scf::WhileOp>(
807 loc, types, SmallVector<Value, 2>{start, child, childIdx});
808
809 // The before-region of the WhileOp.
810 SmallVector<Location, 3> locs(3, loc);
811 Block *before = builder.createBlock(&whileOp.getBefore(), {}, types, locs);
812 builder.setInsertionPointToEnd(before);
813 start = before->getArgument(i: 0);
814 childIdx = before->getArgument(i: 2);
815 compareOperands[0] = start;
816 compareOperands[1] = childIdx;
817 Value cond = createInlinedLessThan(builder, loc, args: compareOperands, xPerm, ny);
818 builder.create<scf::ConditionOp>(loc, cond, before->getArguments());
819
820 // The after-region of the WhileOp.
821 Block *after = builder.createBlock(&whileOp.getAfter(), {}, types, locs);
822 start = after->getArgument(i: 0);
823 child = after->getArgument(i: 1);
824 childIdx = after->getArgument(i: 2);
825 SmallVector<Value> swapOperands{start, childIdx};
826 swapOperands.append(in_start: args.begin() + xStartIdx, in_end: args.end());
827 createSwap(builder, loc, args: swapOperands, xPerm, ny);
828 start = childIdx;
829 Value cond2 =
830 builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::uge, t, child);
831 scf::IfOp if2 = builder.create<scf::IfOp>(
832 loc, TypeRange{child.getType(), child.getType()}, cond2, /*else=*/true);
833 builder.setInsertionPointToStart(&if2.getThenRegion().front());
834 auto [newChild, newChildIdx] = getLargerChild(child);
835 builder.create<scf::YieldOp>(loc, ValueRange{newChild, newChildIdx});
836 builder.setInsertionPointToStart(&if2.getElseRegion().front());
837 builder.create<scf::YieldOp>(loc, ValueRange{child, childIdx});
838 builder.setInsertionPointAfter(if2);
839 builder.create<scf::YieldOp>(
840 loc, ValueRange{start, if2.getResult(0), if2.getResult(1)});
841
842 builder.setInsertionPointAfter(ifN);
843 builder.create<func::ReturnOp>(loc);
844}
845
846/// Creates a function to perform heap sort on the values in the range of index
847/// [lo, hi) with the assumption hi - lo >= 2.
848//
849// The generate IR corresponds to this C like algorithm:
850// void heapSort(lo, hi, data) {
851// n = hi - lo
852// for i = (n-2)/2 downto 0
853// shiftDown(lo, lo+i, n)
854//
855// for l = n downto 2
856// swap(lo, lo+l-1)
857// shiftdown(lo, lo, l-1)
858// }
859static void createHeapSortFunc(OpBuilder &builder, ModuleOp module,
860 func::FuncOp func, AffineMap xPerm, uint64_t ny,
861 uint32_t nTrailingP) {
862 // Heap sort function doesn't have trailing parameters.
863 (void)nTrailingP;
864 assert(nTrailingP == 0);
865 OpBuilder::InsertionGuard insertionGuard(builder);
866 Block *entryBlock = func.addEntryBlock();
867 builder.setInsertionPointToStart(entryBlock);
868
869 Location loc = func.getLoc();
870 ValueRange args = entryBlock->getArguments();
871 Value lo = args[loIdx];
872 Value hi = args[hiIdx];
873 Value n = builder.create<arith::SubIOp>(loc, hi, lo);
874
875 // For i = (n-2)/2 downto 0.
876 Value c0 = constantIndex(builder, loc, i: 0);
877 Value c1 = constantIndex(builder, loc, i: 1);
878 Value s = createSubTwoDividedByTwo(builder, loc, n);
879 Value up = builder.create<arith::AddIOp>(loc, s, c1);
880 scf::ForOp forI = builder.create<scf::ForOp>(loc, c0, up, c1);
881 builder.setInsertionPointToStart(forI.getBody());
882 Value i = builder.create<arith::SubIOp>(loc, s, forI.getInductionVar());
883 Value lopi = builder.create<arith::AddIOp>(loc, lo, i);
884 SmallVector<Value> shiftDownOperands = {lo, lopi};
885 shiftDownOperands.append(in_start: args.begin() + xStartIdx, in_end: args.end());
886 shiftDownOperands.push_back(Elt: n);
887 FlatSymbolRefAttr shiftDownFunc = getMangledSortHelperFunc(
888 builder, func, TypeRange(), kShiftDownFuncNamePrefix, xPerm, ny,
889 shiftDownOperands, createShiftDownFunc, /*nTrailingP=*/1);
890 builder.create<func::CallOp>(loc, shiftDownFunc, TypeRange(),
891 shiftDownOperands);
892
893 builder.setInsertionPointAfter(forI);
894 // For l = n downto 2.
895 up = builder.create<arith::SubIOp>(loc, n, c1);
896 scf::ForOp forL = builder.create<scf::ForOp>(loc, c0, up, c1);
897 builder.setInsertionPointToStart(forL.getBody());
898 Value l = builder.create<arith::SubIOp>(loc, n, forL.getInductionVar());
899 Value loplm1 = builder.create<arith::AddIOp>(loc, lo, l);
900 loplm1 = builder.create<arith::SubIOp>(loc, loplm1, c1);
901 SmallVector<Value> swapOperands{lo, loplm1};
902 swapOperands.append(in_start: args.begin() + xStartIdx, in_end: args.end());
903 createSwap(builder, loc, args: swapOperands, xPerm, ny);
904 shiftDownOperands[1] = lo;
905 shiftDownOperands[shiftDownOperands.size() - 1] =
906 builder.create<arith::SubIOp>(loc, l, c1);
907 builder.create<func::CallOp>(loc, shiftDownFunc, TypeRange(),
908 shiftDownOperands);
909
910 builder.setInsertionPointAfter(forL);
911 builder.create<func::ReturnOp>(loc);
912}
913
914/// A helper for generating code to perform quick sort. It partitions [lo, hi),
915/// recursively calls quick sort to process the smaller partition and returns
916/// the bigger partition to be processed by the enclosed while-loop.
917static std::pair<Value, Value>
918createQuickSort(OpBuilder &builder, ModuleOp module, func::FuncOp func,
919 ValueRange args, AffineMap xPerm, uint64_t ny,
920 uint32_t nTrailingP) {
921 MLIRContext *context = module.getContext();
922 Location loc = func.getLoc();
923 Value lo = args[loIdx];
924 Value hi = args[hiIdx];
925 SmallVector<Type, 2> types(2, lo.getType()); // Only two types.
926
927 FlatSymbolRefAttr partitionFunc = getMangledSortHelperFunc(
928 builder, func, {IndexType::get(context)}, kPartitionFuncNamePrefix, xPerm,
929 ny, args.drop_back(nTrailingP), createPartitionFunc);
930 Value p = builder
931 .create<func::CallOp>(loc, partitionFunc,
932 TypeRange{IndexType::get(context)},
933 args.drop_back(nTrailingP))
934 .getResult(0);
935
936 Value lenLow = builder.create<arith::SubIOp>(loc, p, lo);
937 Value lenHigh = builder.create<arith::SubIOp>(loc, hi, p);
938 // Partition already sorts array with len <= 2
939 Value c2 = constantIndex(builder, loc, i: 2);
940 Value len = builder.create<arith::SubIOp>(loc, hi, lo);
941 Value lenGtTwo =
942 builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ugt, len, c2);
943 scf::IfOp ifLenGtTwo =
944 builder.create<scf::IfOp>(loc, types, lenGtTwo, /*else=*/true);
945 builder.setInsertionPointToStart(&ifLenGtTwo.getElseRegion().front());
946 // Returns an empty range to mark the entire region is fully sorted.
947 builder.create<scf::YieldOp>(loc, ValueRange{lo, lo});
948
949 // Else len > 2, need recursion.
950 builder.setInsertionPointToStart(&ifLenGtTwo.getThenRegion().front());
951 Value cond = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ule,
952 lenLow, lenHigh);
953
954 Value c0 = constantIndex(builder, loc, i: 0);
955 scf::IfOp ifOp = builder.create<scf::IfOp>(loc, types, cond, /*else=*/true);
956
957 auto mayRecursion = [&](Value low, Value high, Value len) {
958 Value cond =
959 builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne, len, c0);
960 scf::IfOp ifOp = builder.create<scf::IfOp>(loc, cond, /*else=*/false);
961 builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
962 SmallVector<Value> operands{low, high};
963 operands.append(in_start: args.begin() + xStartIdx, in_end: args.end());
964 builder.create<func::CallOp>(loc, func, operands);
965 builder.setInsertionPointAfter(ifOp);
966 };
967
968 // Recursively call quickSort to process the smaller partition and return
969 // the bigger partition to be processed by the enclosed while-loop.
970 builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
971 mayRecursion(lo, p, lenLow);
972 builder.create<scf::YieldOp>(loc, ValueRange{p, hi});
973
974 builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
975 mayRecursion(p, hi, lenHigh);
976 builder.create<scf::YieldOp>(loc, ValueRange{lo, p});
977
978 builder.setInsertionPointAfter(ifOp);
979 builder.create<scf::YieldOp>(loc, ifOp.getResults());
980
981 builder.setInsertionPointAfter(ifLenGtTwo);
982 return std::make_pair(ifLenGtTwo.getResult(0), ifLenGtTwo.getResult(1));
983}
984
985/// Creates a function to perform insertion sort on the values in the range of
986/// index [lo, hi).
987//
988// The generate IR corresponds to this C like algorithm:
989// void insertionSort(lo, hi, data) {
990// for (i = lo+1; i < hi; i++) {
991// d = data[i];
992// p = binarySearch(lo, i-1, data)
993// for (j = 0; j > i - p; j++)
994// data[i-j] = data[i-j-1]
995// data[p] = d
996// }
997// }
998static void createSortStableFunc(OpBuilder &builder, ModuleOp module,
999 func::FuncOp func, AffineMap xPerm,
1000 uint64_t ny, uint32_t nTrailingP) {
1001 // Stable sort function doesn't use trailing parameters.
1002 (void)nTrailingP;
1003 assert(nTrailingP == 0);
1004 OpBuilder::InsertionGuard insertionGuard(builder);
1005 Block *entryBlock = func.addEntryBlock();
1006 builder.setInsertionPointToStart(entryBlock);
1007
1008 MLIRContext *context = module.getContext();
1009 Location loc = func.getLoc();
1010 ValueRange args = entryBlock->getArguments();
1011 Value c1 = constantIndex(builder, loc, i: 1);
1012 Value lo = args[loIdx];
1013 Value hi = args[hiIdx];
1014 Value lop1 = builder.create<arith::AddIOp>(loc, lo, c1);
1015
1016 // Start the outer for-stmt with induction variable i.
1017 scf::ForOp forOpI = builder.create<scf::ForOp>(loc, lop1, hi, c1);
1018 builder.setInsertionPointToStart(forOpI.getBody());
1019 Value i = forOpI.getInductionVar();
1020
1021 // Binary search to find the insertion point p.
1022 SmallVector<Value> operands{lo, i};
1023 operands.append(in_start: args.begin() + xStartIdx, in_end: args.end());
1024 FlatSymbolRefAttr searchFunc = getMangledSortHelperFunc(
1025 builder, func, {IndexType::get(context)}, kBinarySearchFuncNamePrefix,
1026 xPerm, ny, operands, createBinarySearchFunc);
1027 Value p = builder
1028 .create<func::CallOp>(loc, searchFunc, TypeRange{c1.getType()},
1029 operands)
1030 .getResult(0);
1031
1032 // Move the value at data[i] to a temporary location.
1033 operands[0] = operands[1] = i;
1034 SmallVector<Value> d;
1035 forEachIJPairInAllBuffers(
1036 builder, loc, args: operands, xPerm, ny,
1037 bodyBuilder: [&](uint64_t unused, Value i, Value unused2, Value buffer) {
1038 d.push_back(builder.create<memref::LoadOp>(loc, buffer, i));
1039 });
1040
1041 // Start the inner for-stmt with induction variable j, for moving data[p..i)
1042 // to data[p+1..i+1).
1043 Value imp = builder.create<arith::SubIOp>(loc, i, p);
1044 Value c0 = constantIndex(builder, loc, i: 0);
1045 scf::ForOp forOpJ = builder.create<scf::ForOp>(loc, c0, imp, c1);
1046 builder.setInsertionPointToStart(forOpJ.getBody());
1047 Value j = forOpJ.getInductionVar();
1048 Value imj = builder.create<arith::SubIOp>(loc, i, j);
1049 operands[1] = imj;
1050 operands[0] = builder.create<arith::SubIOp>(loc, imj, c1);
1051 forEachIJPairInAllBuffers(
1052 builder, loc, args: operands, xPerm, ny,
1053 bodyBuilder: [&](uint64_t unused, Value imjm1, Value imj, Value buffer) {
1054 Value t = builder.create<memref::LoadOp>(loc, buffer, imjm1);
1055 builder.create<memref::StoreOp>(loc, t, buffer, imj);
1056 });
1057
1058 // Store the value at data[i] to data[p].
1059 builder.setInsertionPointAfter(forOpJ);
1060 operands[0] = operands[1] = p;
1061 forEachIJPairInAllBuffers(
1062 builder, loc, args: operands, xPerm, ny,
1063 bodyBuilder: [&](uint64_t k, Value p, Value usused, Value buffer) {
1064 builder.create<memref::StoreOp>(loc, d[k], buffer, p);
1065 });
1066
1067 builder.setInsertionPointAfter(forOpI);
1068 builder.create<func::ReturnOp>(loc);
1069}
1070
1071/// Creates a function to perform quick sort or a hybrid quick sort on the
1072/// values in the range of index [lo, hi).
1073//
1074//
1075// When nTrailingP == 0, the generated IR corresponds to this C like algorithm:
1076// void quickSort(lo, hi, data) {
1077// while (lo + 1 < hi) {
1078// p = partition(low, high, data);
1079// if (len(lo, p) < len(p+1, hi)) {
1080// quickSort(lo, p, data);
1081// lo = p+1;
1082// } else {
1083// quickSort(p + 1, hi, data);
1084// hi = p;
1085// }
1086// }
1087// }
1088//
1089// When nTrailingP == 1, the generated IR corresponds to this C like algorithm:
1090// void hybridQuickSort(lo, hi, data, depthLimit) {
1091// while (lo + 1 < hi) {
1092// len = hi - lo;
1093// if (len <= limit) {
1094// insertionSort(lo, hi, data);
1095// } else {
1096// depthLimit --;
1097// if (depthLimit <= 0) {
1098// heapSort(lo, hi, data);
1099// } else {
1100// p = partition(low, high, data);
1101// if (len(lo, p) < len(p+1, hi)) {
1102// quickSort(lo, p, data, depthLimit);
1103// lo = p+1;
1104// } else {
1105// quickSort(p + 1, hi, data, depthLimit);
1106// hi = p;
1107// }
1108// }
1109// }
1110// }
1111// }
1112//
1113static void createQuickSortFunc(OpBuilder &builder, ModuleOp module,
1114 func::FuncOp func, AffineMap xPerm, uint64_t ny,
1115 uint32_t nTrailingP) {
1116 assert(nTrailingP == 1 || nTrailingP == 0);
1117 bool isHybrid = (nTrailingP == 1);
1118 OpBuilder::InsertionGuard insertionGuard(builder);
1119 Block *entryBlock = func.addEntryBlock();
1120 builder.setInsertionPointToStart(entryBlock);
1121
1122 Location loc = func.getLoc();
1123 SmallVector<Value> args;
1124 args.append(in_start: entryBlock->getArguments().begin(),
1125 in_end: entryBlock->getArguments().end());
1126 Value lo = args[loIdx];
1127 Value hi = args[hiIdx];
1128 SmallVector<Type, 2> types(2, lo.getType()); // Only two types.
1129 scf::WhileOp whileOp =
1130 builder.create<scf::WhileOp>(loc, types, SmallVector<Value, 2>{lo, hi});
1131
1132 // The before-region of the WhileOp.
1133 Block *before =
1134 builder.createBlock(&whileOp.getBefore(), {}, types, {loc, loc});
1135 builder.setInsertionPointToEnd(before);
1136 lo = before->getArgument(i: 0);
1137 hi = before->getArgument(i: 1);
1138 Value loP1 =
1139 builder.create<arith::AddIOp>(loc, lo, constantIndex(builder, loc, 1));
1140 Value needSort =
1141 builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, loP1, hi);
1142 builder.create<scf::ConditionOp>(loc, needSort, before->getArguments());
1143
1144 // The after-region of the WhileOp.
1145 Block *after =
1146 builder.createBlock(&whileOp.getAfter(), {}, types, {loc, loc});
1147 builder.setInsertionPointToEnd(after);
1148 lo = after->getArgument(i: 0);
1149 hi = after->getArgument(i: 1);
1150 args[0] = lo;
1151 args[1] = hi;
1152
1153 if (isHybrid) {
1154 Value len = builder.create<arith::SubIOp>(loc, hi, lo);
1155 Value lenLimit = constantIndex(builder, loc, i: 30);
1156 Value lenCond = builder.create<arith::CmpIOp>(
1157 loc, arith::CmpIPredicate::ule, len, lenLimit);
1158 scf::IfOp lenIf =
1159 builder.create<scf::IfOp>(loc, types, lenCond, /*else=*/true);
1160
1161 // When len <= limit.
1162 builder.setInsertionPointToStart(&lenIf.getThenRegion().front());
1163 FlatSymbolRefAttr insertionSortFunc = getMangledSortHelperFunc(
1164 builder, func, TypeRange(), kSortStableFuncNamePrefix, xPerm, ny,
1165 ValueRange(args).drop_back(n: nTrailingP), createSortStableFunc);
1166 builder.create<func::CallOp>(loc, insertionSortFunc, TypeRange(),
1167 ValueRange(args).drop_back(nTrailingP));
1168 builder.create<scf::YieldOp>(loc, ValueRange{lo, lo});
1169
1170 // When len > limit.
1171 builder.setInsertionPointToStart(&lenIf.getElseRegion().front());
1172 Value depthLimit = args.back();
1173 depthLimit = builder.create<arith::SubIOp>(loc, depthLimit,
1174 constantI64(builder, loc, 1));
1175 Value depthCond =
1176 builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ule,
1177 depthLimit, constantI64(builder, loc, 0));
1178 scf::IfOp depthIf =
1179 builder.create<scf::IfOp>(loc, types, depthCond, /*else=*/true);
1180
1181 // When depth exceeds limit.
1182 builder.setInsertionPointToStart(&depthIf.getThenRegion().front());
1183 FlatSymbolRefAttr heapSortFunc = getMangledSortHelperFunc(
1184 builder, func, TypeRange(), kHeapSortFuncNamePrefix, xPerm, ny,
1185 ValueRange(args).drop_back(n: nTrailingP), createHeapSortFunc);
1186 builder.create<func::CallOp>(loc, heapSortFunc, TypeRange(),
1187 ValueRange(args).drop_back(nTrailingP));
1188 builder.create<scf::YieldOp>(loc, ValueRange{lo, lo});
1189
1190 // When depth doesn't exceed limit.
1191 builder.setInsertionPointToStart(&depthIf.getElseRegion().front());
1192 args.back() = depthLimit;
1193 std::tie(args&: lo, args&: hi) =
1194 createQuickSort(builder, module, func, args, xPerm, ny, nTrailingP);
1195 builder.create<scf::YieldOp>(loc, ValueRange{lo, hi});
1196
1197 builder.setInsertionPointAfter(depthIf);
1198 lo = depthIf.getResult(0);
1199 hi = depthIf.getResult(1);
1200 builder.create<scf::YieldOp>(loc, ValueRange{lo, hi});
1201
1202 builder.setInsertionPointAfter(lenIf);
1203 lo = lenIf.getResult(0);
1204 hi = lenIf.getResult(1);
1205 } else {
1206 std::tie(args&: lo, args&: hi) =
1207 createQuickSort(builder, module, func, args, xPerm, ny, nTrailingP);
1208 }
1209
1210 // New [lo, hi) for the next while-loop iteration.
1211 builder.create<scf::YieldOp>(loc, ValueRange{lo, hi});
1212
1213 // After the while-loop.
1214 builder.setInsertionPointAfter(whileOp);
1215 builder.create<func::ReturnOp>(loc);
1216}
1217
1218/// Implements the rewriting for operator sort and sort_coo.
1219template <typename OpTy>
1220LogicalResult matchAndRewriteSortOp(OpTy op, ValueRange xys, AffineMap xPerm,
1221 uint64_t ny, PatternRewriter &rewriter) {
1222 Location loc = op.getLoc();
1223 SmallVector<Value> operands{constantIndex(builder&: rewriter, loc, i: 0), op.getN()};
1224
1225 // Convert `values` to have dynamic shape and append them to `operands`.
1226 for (Value v : xys) {
1227 auto mtp = getMemRefType(v);
1228 if (!mtp.isDynamicDim(0)) {
1229 auto newMtp =
1230 MemRefType::get({ShapedType::kDynamic}, mtp.getElementType());
1231 v = rewriter.create<memref::CastOp>(loc, newMtp, v);
1232 }
1233 operands.push_back(Elt: v);
1234 }
1235
1236 auto insertPoint = op->template getParentOfType<func::FuncOp>();
1237 if (!insertPoint)
1238 return failure();
1239
1240 SmallString<32> funcName;
1241 FuncGeneratorType funcGenerator;
1242 uint32_t nTrailingP = 0;
1243 switch (op.getAlgorithm()) {
1244 case SparseTensorSortKind::HybridQuickSort: {
1245 funcName = kHybridQuickSortFuncNamePrefix;
1246 funcGenerator = createQuickSortFunc;
1247 nTrailingP = 1;
1248 // As a heuristics, set depthLimit = 2 * log2(n).
1249 Value lo = operands[loIdx];
1250 Value hi = operands[hiIdx];
1251 Value len = rewriter.create<arith::IndexCastOp>(
1252 loc, rewriter.getI64Type(),
1253 rewriter.create<arith::SubIOp>(loc, hi, lo));
1254 Value depthLimit = rewriter.create<arith::SubIOp>(
1255 loc, constantI64(rewriter, loc, 64),
1256 rewriter.create<math::CountLeadingZerosOp>(loc, len));
1257 operands.push_back(Elt: depthLimit);
1258 break;
1259 }
1260 case SparseTensorSortKind::QuickSort:
1261 funcName = kQuickSortFuncNamePrefix;
1262 funcGenerator = createQuickSortFunc;
1263 break;
1264 case SparseTensorSortKind::InsertionSortStable:
1265 funcName = kSortStableFuncNamePrefix;
1266 funcGenerator = createSortStableFunc;
1267 break;
1268 case SparseTensorSortKind::HeapSort:
1269 funcName = kHeapSortFuncNamePrefix;
1270 funcGenerator = createHeapSortFunc;
1271 break;
1272 }
1273
1274 FlatSymbolRefAttr func =
1275 getMangledSortHelperFunc(rewriter, insertPoint, TypeRange(), funcName,
1276 xPerm, ny, operands, funcGenerator, nTrailingP);
1277 rewriter.replaceOpWithNewOp<func::CallOp>(op, func, TypeRange(), operands);
1278 return success();
1279}
1280
1281//===---------------------------------------------------------------------===//
1282// The actual sparse buffer rewriting rules.
1283//===---------------------------------------------------------------------===//
1284
1285namespace {
1286/// Sparse rewriting rule for the push_back operator.
1287struct PushBackRewriter : OpRewritePattern<PushBackOp> {
1288public:
1289 using OpRewritePattern<PushBackOp>::OpRewritePattern;
1290 PushBackRewriter(MLIRContext *context, bool enableInit)
1291 : OpRewritePattern(context), enableBufferInitialization(enableInit) {}
1292 LogicalResult matchAndRewrite(PushBackOp op,
1293 PatternRewriter &rewriter) const override {
1294 // Rewrite push_back(buffer, value, n) to:
1295 // new_size = size(buffer) + n
1296 // if (new_size > capacity(buffer))
1297 // while new_size > new_capacity
1298 // new_capacity = new_capacity*2
1299 // new_buffer = realloc(buffer, new_capacity)
1300 // buffer = new_buffer
1301 // subBuffer = subviewof(buffer)
1302 // linalg.fill subBuffer value
1303 //
1304 // size(buffer) += n
1305 //
1306 // The capacity check is skipped when the attribute inbounds is presented.
1307 Location loc = op->getLoc();
1308 Value c0 = constantIndex(builder&: rewriter, loc, i: 0);
1309 Value buffer = op.getInBuffer();
1310 Value capacity = rewriter.create<memref::DimOp>(loc, buffer, c0);
1311 Value size = op.getCurSize();
1312 Value value = op.getValue();
1313
1314 Value n = op.getN() ? op.getN() : constantIndex(builder&: rewriter, loc, i: 1);
1315 Value newSize = rewriter.create<arith::AddIOp>(loc, size, n);
1316 auto nValue = dyn_cast_or_null<arith::ConstantIndexOp>(Val: n.getDefiningOp());
1317 bool nIsOne = (nValue && nValue.value() == 1);
1318
1319 if (!op.getInbounds()) {
1320 Value cond = rewriter.create<arith::CmpIOp>(
1321 loc, arith::CmpIPredicate::ugt, newSize, capacity);
1322
1323 Value c2 = constantIndex(builder&: rewriter, loc, i: 2);
1324 auto bufferType =
1325 MemRefType::get({ShapedType::kDynamic}, value.getType());
1326 scf::IfOp ifOp = rewriter.create<scf::IfOp>(loc, bufferType, cond,
1327 /*else=*/true);
1328 // True branch.
1329 rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front());
1330 if (nIsOne) {
1331 capacity = rewriter.create<arith::MulIOp>(loc, capacity, c2);
1332 } else {
1333 // Use a do-while loop to calculate the new capacity as follows:
1334 // do { new_capacity *= 2 } while (size > new_capacity)
1335 scf::WhileOp whileOp =
1336 rewriter.create<scf::WhileOp>(loc, capacity.getType(), capacity);
1337
1338 // The before-region of the WhileOp.
1339 Block *before = rewriter.createBlock(&whileOp.getBefore(), {},
1340 {capacity.getType()}, {loc});
1341 rewriter.setInsertionPointToEnd(before);
1342
1343 capacity =
1344 rewriter.create<arith::MulIOp>(loc, before->getArgument(0), c2);
1345 cond = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ugt,
1346 newSize, capacity);
1347 rewriter.create<scf::ConditionOp>(loc, cond, ValueRange{capacity});
1348 // The after-region of the WhileOp.
1349 Block *after = rewriter.createBlock(&whileOp.getAfter(), {},
1350 {capacity.getType()}, {loc});
1351 rewriter.setInsertionPointToEnd(after);
1352 rewriter.create<scf::YieldOp>(loc, after->getArguments());
1353
1354 rewriter.setInsertionPointAfter(whileOp);
1355 capacity = whileOp.getResult(0);
1356 }
1357
1358 Value newBuffer =
1359 rewriter.create<memref::ReallocOp>(loc, bufferType, buffer, capacity);
1360 if (enableBufferInitialization) {
1361 Value fillSize = rewriter.create<arith::SubIOp>(loc, capacity, newSize);
1362 Value fillValue = constantZero(builder&: rewriter, loc, tp: value.getType());
1363 Value subBuffer = rewriter.create<memref::SubViewOp>(
1364 loc, newBuffer, /*offset=*/ValueRange{newSize},
1365 /*size=*/ValueRange{fillSize},
1366 /*step=*/ValueRange{constantIndex(rewriter, loc, 1)});
1367 rewriter.create<linalg::FillOp>(loc, fillValue, subBuffer);
1368 }
1369 rewriter.create<scf::YieldOp>(loc, newBuffer);
1370
1371 // False branch.
1372 rewriter.setInsertionPointToStart(&ifOp.getElseRegion().front());
1373 rewriter.create<scf::YieldOp>(loc, buffer);
1374
1375 // Prepare for adding the value to the end of the buffer.
1376 rewriter.setInsertionPointAfter(ifOp);
1377 buffer = ifOp.getResult(0);
1378 }
1379
1380 // Add the value to the end of the buffer.
1381 if (nIsOne) {
1382 rewriter.create<memref::StoreOp>(loc, value, buffer, size);
1383 } else {
1384 Value subBuffer = rewriter.create<memref::SubViewOp>(
1385 loc, buffer, /*offset=*/ValueRange{size}, /*size=*/ValueRange{n},
1386 /*step=*/ValueRange{constantIndex(rewriter, loc, 1)});
1387 rewriter.create<linalg::FillOp>(loc, value, subBuffer);
1388 }
1389
1390 // Update the buffer size.
1391 rewriter.replaceOp(op, {buffer, newSize});
1392 return success();
1393 }
1394
1395private:
1396 bool enableBufferInitialization;
1397};
1398
1399/// Sparse rewriting rule for the sort_coo operator.
1400struct SortRewriter : public OpRewritePattern<SortOp> {
1401public:
1402 using OpRewritePattern<SortOp>::OpRewritePattern;
1403
1404 LogicalResult matchAndRewrite(SortOp op,
1405 PatternRewriter &rewriter) const override {
1406 SmallVector<Value> xys;
1407 xys.push_back(Elt: op.getXy());
1408 xys.append(op.getYs().begin(), op.getYs().end());
1409
1410 auto xPerm = op.getPermMap();
1411 uint64_t ny = 0;
1412 if (auto nyAttr = op.getNyAttr())
1413 ny = nyAttr.getInt();
1414
1415 return matchAndRewriteSortOp(op, xys, xPerm, ny, rewriter);
1416 }
1417};
1418
1419} // namespace
1420
1421//===---------------------------------------------------------------------===//
1422// Methods that add patterns described in this file to a pattern list.
1423//===---------------------------------------------------------------------===//
1424
1425void mlir::populateSparseBufferRewriting(RewritePatternSet &patterns,
1426 bool enableBufferInitialization) {
1427 patterns.add<PushBackRewriter>(arg: patterns.getContext(),
1428 args&: enableBufferInitialization);
1429 patterns.add<SortRewriter>(arg: patterns.getContext());
1430}
1431

source code of mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp