1//===- SparseBufferRewriting.cpp - Sparse buffer rewriting rules ----------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements rewriting rules that are specific to sparse tensor
10// primitives with memref operands.
11//
12//===----------------------------------------------------------------------===//
13
14#include "Utils/CodegenUtils.h"
15
16#include "mlir/Dialect/Arith/IR/Arith.h"
17#include "mlir/Dialect/Func/IR/FuncOps.h"
18#include "mlir/Dialect/Linalg/IR/Linalg.h"
19#include "mlir/Dialect/Math/IR/Math.h"
20#include "mlir/Dialect/MemRef/IR/MemRef.h"
21#include "mlir/Dialect/SCF/IR/SCF.h"
22#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
23#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
24#include "mlir/Support/LLVM.h"
25
26using namespace mlir;
27using namespace mlir::sparse_tensor;
28
29//===---------------------------------------------------------------------===//
30// Helper methods for the actual rewriting rules.
31//===---------------------------------------------------------------------===//
32
// Positions of the leading operands shared by the generated sort helpers: the
// two index bounds come first and the buffers being sorted start at
// `xStartIdx`.
static constexpr uint64_t loIdx = 0;
static constexpr uint64_t hiIdx = 1;
static constexpr uint64_t xStartIdx = 2;

// Name prefixes of the generated helper functions. The mangled operand
// information is appended to these prefixes to form the final symbol name.
static constexpr char kPartitionFuncNamePrefix[] = "_sparse_partition_";
static constexpr char kBinarySearchFuncNamePrefix[] = "_sparse_binary_search_";
static constexpr char kHybridQuickSortFuncNamePrefix[] =
    "_sparse_hybrid_qsort_";
static constexpr char kSortStableFuncNamePrefix[] = "_sparse_sort_stable_";
static constexpr char kShiftDownFuncNamePrefix[] = "_sparse_shift_down_";
static constexpr char kHeapSortFuncNamePrefix[] = "_sparse_heap_sort_";
static constexpr char kQuickSortFuncNamePrefix[] = "_sparse_qsort_";
47
48using FuncGeneratorType = function_ref<void(OpBuilder &, ModuleOp, func::FuncOp,
49 AffineMap, uint64_t, uint32_t)>;
50
51/// Constructs a function name with this format to facilitate quick sort:
52/// <namePrefix><xPerm>_<x type>_<y0 type>..._<yn type> for sort
53/// <namePrefix><xPerm>_<x type>_coo_<ny>_<y0 type>..._<yn type> for sort_coo
54static void getMangledSortHelperFuncName(llvm::raw_svector_ostream &nameOstream,
55 StringRef namePrefix, AffineMap xPerm,
56 uint64_t ny, ValueRange operands) {
57 nameOstream << namePrefix;
58 for (auto res : xPerm.getResults())
59 nameOstream << cast<AffineDimExpr>(Val&: res).getPosition() << "_";
60
61 nameOstream << getMemRefType(operands[xStartIdx]).getElementType();
62 nameOstream << "_coo_" << ny;
63
64 constexpr uint64_t yBufferOffset = 1;
65 for (Value v : operands.drop_front(n: xStartIdx + yBufferOffset))
66 nameOstream << "_" << getMemRefType(v).getElementType();
67}
68
69/// Looks up a function that is appropriate for the given operands being
70/// sorted, and creates such a function if it doesn't exist yet. The
71/// parameters `xPerm` and `ny` tell the number of x and y values provided
72/// by the buffer in xStartIdx.
73//
74// All sorting function generators take (lo, hi, xs, ys) in `operands` as
75// parameters for the sorting functions. Other parameters, such as the recursive
76// call depth, are appended to the end of the parameter list as
77// "trailing parameters".
78static FlatSymbolRefAttr getMangledSortHelperFunc(
79 OpBuilder &builder, func::FuncOp insertPoint, TypeRange resultTypes,
80 StringRef namePrefix, AffineMap xPerm, uint64_t ny, ValueRange operands,
81 FuncGeneratorType createFunc, uint32_t nTrailingP = 0) {
82 SmallString<32> nameBuffer;
83 llvm::raw_svector_ostream nameOstream(nameBuffer);
84 getMangledSortHelperFuncName(nameOstream, namePrefix, xPerm, ny,
85 operands: operands.drop_back(n: nTrailingP));
86
87 ModuleOp module = insertPoint->getParentOfType<ModuleOp>();
88 MLIRContext *context = module.getContext();
89 auto result = SymbolRefAttr::get(context, nameOstream.str());
90 auto func = module.lookupSymbol<func::FuncOp>(result.getAttr());
91
92 if (!func) {
93 // Create the function.
94 OpBuilder::InsertionGuard insertionGuard(builder);
95 builder.setInsertionPoint(insertPoint);
96 Location loc = insertPoint.getLoc();
97 func = builder.create<func::FuncOp>(
98 loc, nameOstream.str(),
99 FunctionType::get(context, operands.getTypes(), resultTypes));
100 func.setPrivate();
101 createFunc(builder, module, func, xPerm, ny, nTrailingP);
102 }
103
104 return result;
105}
106
107/// Creates a code block to process each pair of (xs[i], xs[j]) for sorting.
108/// The code to process the value pairs is generated by `bodyBuilder`.
109static void forEachIJPairInXs(
110 OpBuilder &builder, Location loc, ValueRange args, AffineMap xPerm,
111 uint64_t ny,
112 function_ref<void(uint64_t, Value, Value, Value)> bodyBuilder) {
113 Value cstep = constantIndex(builder, loc, i: xPerm.getNumResults() + ny);
114 Value iOffset = builder.create<arith::MulIOp>(loc, args[0], cstep);
115 Value jOffset = builder.create<arith::MulIOp>(loc, args[1], cstep);
116 for (unsigned k = 0, e = xPerm.getNumResults(); k < e; k++) {
117 unsigned actualK = cast<AffineDimExpr>(Val: xPerm.getResult(idx: k)).getPosition();
118 Value ak = constantIndex(builder, loc, i: actualK);
119 Value i = builder.create<arith::AddIOp>(loc, ak, iOffset);
120 Value j = builder.create<arith::AddIOp>(loc, ak, jOffset);
121 Value buffer = args[xStartIdx];
122
123 bodyBuilder(k, i, j, buffer);
124 }
125}
126
127/// Creates a code block to process each pair of (xys[i], xys[j]) for sorting.
128/// The code to process the value pairs is generated by `bodyBuilder`.
129static void forEachIJPairInAllBuffers(
130 OpBuilder &builder, Location loc, ValueRange args, AffineMap xPerm,
131 uint64_t ny,
132 function_ref<void(uint64_t, Value, Value, Value)> bodyBuilder) {
133
134 // Create code for the first (xPerm + ny) buffers.
135 SmallVector<AffineExpr> exps(xPerm.getResults().begin(),
136 xPerm.getResults().end());
137 for (unsigned y = 0; y < ny; y++) {
138 exps.push_back(Elt: builder.getAffineDimExpr(position: y + xPerm.getNumResults()));
139 }
140 AffineMap xyPerm = AffineMap::get(dimCount: exps.size(), symbolCount: 0, results: exps, context: builder.getContext());
141 assert(xyPerm.isPermutation());
142
143 forEachIJPairInXs(builder, loc, args, xPerm: xyPerm, ny: 0, bodyBuilder);
144
145 constexpr uint64_t numHandledBuffers = 1;
146 // Create code for the remaining buffers.
147 Value i = args[0];
148 Value j = args[1];
149 for (const auto &arg :
150 llvm::enumerate(First: args.drop_front(n: xStartIdx + numHandledBuffers))) {
151 bodyBuilder(arg.index() + xPerm.getNumResults() + ny, i, j, arg.value());
152 }
153}
154
155/// Creates a code block for swapping the values in index i and j for all the
156/// buffers.
157//
158// The generated IR corresponds to this C like algorithm:
159// swap(x0[i], x0[j]);
160// swap(x1[i], x1[j]);
161// ...
162// swap(xn[i], xn[j]);
163// swap(y0[i], y0[j]);
164// ...
165// swap(yn[i], yn[j]);
166static void createSwap(OpBuilder &builder, Location loc, ValueRange args,
167 AffineMap xPerm, uint64_t ny) {
168 auto swapOnePair = [&](uint64_t unused, Value i, Value j, Value buffer) {
169 Value vi = builder.create<memref::LoadOp>(loc, buffer, i);
170 Value vj = builder.create<memref::LoadOp>(loc, buffer, j);
171 builder.create<memref::StoreOp>(loc, vj, buffer, i);
172 builder.create<memref::StoreOp>(loc, vi, buffer, j);
173 };
174
175 forEachIJPairInAllBuffers(builder, loc, args, xPerm, ny, bodyBuilder: swapOnePair);
176}
177
178/// Creates code to compare all the (xs[i], xs[j]) pairs. The method to compare
179/// each pair is create via `compareBuilder`.
180static Value createInlinedCompareImplementation(
181 OpBuilder &builder, Location loc, ValueRange args, AffineMap xPerm,
182 uint64_t ny,
183 function_ref<Value(OpBuilder &, Location, Value, Value, Value, bool, bool)>
184 compareBuilder) {
185 Value result;
186 auto bodyBuilder = [&](uint64_t k, Value i, Value j, Value buffer) {
187 bool isFirstDim = (k == 0);
188 bool isLastDim = (k == xPerm.getNumResults() - 1);
189 Value val =
190 compareBuilder(builder, loc, i, j, buffer, isFirstDim, isLastDim);
191 if (isFirstDim) {
192 result = val;
193 } else if (!isLastDim) {
194 OpBuilder::InsertionGuard insertionGuard(builder);
195 auto ifOp = cast<scf::IfOp>(val.getDefiningOp());
196 builder.setInsertionPointAfter(ifOp);
197 builder.create<scf::YieldOp>(loc, ifOp.getResult(0));
198 }
199 };
200
201 forEachIJPairInXs(builder, loc, args, xPerm, ny, bodyBuilder);
202
203 builder.setInsertionPointAfterValue(result);
204 return result;
205}
206
207/// Generates code to compare whether x[i] is equal to x[j] and returns the
208/// result of the comparison.
209static Value createEqCompare(OpBuilder &builder, Location loc, Value i, Value j,
210 Value x, bool isFirstDim, bool isLastDim) {
211 Value vi = builder.create<memref::LoadOp>(loc, x, i);
212 Value vj = builder.create<memref::LoadOp>(loc, x, j);
213
214 Value res;
215 if (isLastDim) {
216 res = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, vi, vj);
217 // For 1D, we create a compare without any control flow. Otherwise, we
218 // create YieldOp to return the result in the nested if-stmt.
219 if (!isFirstDim)
220 builder.create<scf::YieldOp>(loc, res);
221 } else {
222 Value ne =
223 builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne, vi, vj);
224 scf::IfOp ifOp = builder.create<scf::IfOp>(loc, builder.getIntegerType(1),
225 ne, /*else=*/true);
226 // If (x[i] != x[j]).
227 builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
228 Value f = constantI1(builder, loc, b: false);
229 builder.create<scf::YieldOp>(loc, f);
230
231 // If (x[i] == x[j]). Set up the insertion point for the nested if-stmt that
232 // checks the remaining dimensions.
233 builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
234 res = ifOp.getResult(0);
235 }
236
237 return res;
238}
239
240/// Creates code to compare whether xs[i] is equal to xs[j].
241//
242// The generate IR corresponds to this C like algorithm:
243// if (x0[i] != x0[j])
244// return false;
245// else
246// if (x1[i] != x1[j])
247// return false;
248// else if (x2[2] != x2[j]))
249// and so on ...
250static Value createInlinedEqCompare(OpBuilder &builder, Location loc,
251 ValueRange args, AffineMap xPerm,
252 uint64_t ny, uint32_t nTrailingP = 0) {
253 // Compare functions don't use trailing parameters.
254 (void)nTrailingP;
255 assert(nTrailingP == 0);
256 return createInlinedCompareImplementation(builder, loc, args, xPerm, ny,
257 compareBuilder: createEqCompare);
258}
259
260/// Generates code to compare whether x[i] is less than x[j] and returns the
261/// result of the comparison.
262static Value createLessThanCompare(OpBuilder &builder, Location loc, Value i,
263 Value j, Value x, bool isFirstDim,
264 bool isLastDim) {
265 Value vi = builder.create<memref::LoadOp>(loc, x, i);
266 Value vj = builder.create<memref::LoadOp>(loc, x, j);
267
268 Value res;
269 if (isLastDim) {
270 res = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, vi, vj);
271 // For 1D, we create a compare without any control flow. Otherwise, we
272 // create YieldOp to return the result in the nested if-stmt.
273 if (!isFirstDim)
274 builder.create<scf::YieldOp>(loc, res);
275 } else {
276 Value ne =
277 builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne, vi, vj);
278 scf::IfOp ifOp = builder.create<scf::IfOp>(loc, builder.getIntegerType(1),
279 ne, /*else=*/true);
280 // If (x[i] != x[j]).
281 builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
282 Value lt =
283 builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, vi, vj);
284 builder.create<scf::YieldOp>(loc, lt);
285
286 // If (x[i] == x[j]). Set up the insertion point for the nested if-stmt that
287 // checks the remaining dimensions.
288 builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
289 res = ifOp.getResult(0);
290 }
291
292 return res;
293}
294
295/// Creates code to compare whether xs[i] is less than xs[j].
296//
297// The generate IR corresponds to this C like algorithm:
298// if (x0[i] != x0[j])
299// return x0[i] < x0[j];
300// else if (x1[j] != x1[i])
301// return x1[i] < x1[j];
302// else
303// and so on ...
304static Value createInlinedLessThan(OpBuilder &builder, Location loc,
305 ValueRange args, AffineMap xPerm,
306 uint64_t ny, uint32_t nTrailingP = 0) {
307 // Compare functions don't use trailing parameters.
308 (void)nTrailingP;
309 assert(nTrailingP == 0);
310 return createInlinedCompareImplementation(builder, loc, args, xPerm, ny,
311 compareBuilder: createLessThanCompare);
312}
313
314/// Creates a function to use a binary search to find the insertion point for
315/// inserting xs[hi] to the sorted values xs[lo..hi).
316//
317// The generate IR corresponds to this C like algorithm:
318// p = hi
319// while (lo < hi)
320// mid = (lo + hi) >> 1
321// if (xs[p] < xs[mid])
322// hi = mid
323// else
324// lo = mid - 1
325// return lo;
326//
327static void createBinarySearchFunc(OpBuilder &builder, ModuleOp module,
328 func::FuncOp func, AffineMap xPerm,
329 uint64_t ny, uint32_t nTrailingP = 0) {
330 // Binary search doesn't use trailing parameters.
331 (void)nTrailingP;
332 assert(nTrailingP == 0);
333 OpBuilder::InsertionGuard insertionGuard(builder);
334 Block *entryBlock = func.addEntryBlock();
335 builder.setInsertionPointToStart(entryBlock);
336
337 Location loc = func.getLoc();
338 ValueRange args = entryBlock->getArguments();
339 Value p = args[hiIdx];
340 SmallVector<Type, 2> types(2, p.getType()); // Only two types.
341 scf::WhileOp whileOp = builder.create<scf::WhileOp>(
342 loc, types, SmallVector<Value, 2>{args[loIdx], args[hiIdx]});
343
344 // The before-region of the WhileOp.
345 Block *before =
346 builder.createBlock(&whileOp.getBefore(), {}, types, {loc, loc});
347 builder.setInsertionPointToEnd(before);
348 Value cond1 = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult,
349 before->getArgument(0),
350 before->getArgument(1));
351 builder.create<scf::ConditionOp>(loc, cond1, before->getArguments());
352
353 // The after-region of the WhileOp.
354 Block *after =
355 builder.createBlock(&whileOp.getAfter(), {}, types, {loc, loc});
356 builder.setInsertionPointToEnd(after);
357 Value lo = after->getArgument(i: 0);
358 Value hi = after->getArgument(i: 1);
359 // Compute mid = (lo + hi) >> 1.
360 Value c1 = constantIndex(builder, loc, i: 1);
361 Value mid = builder.create<arith::ShRUIOp>(
362 loc, builder.create<arith::AddIOp>(loc, lo, hi), c1);
363 Value midp1 = builder.create<arith::AddIOp>(loc, mid, c1);
364
365 // Compare xs[p] < xs[mid].
366 SmallVector<Value> compareOperands{p, mid};
367 constexpr uint64_t numXBuffers = 1;
368 compareOperands.append(in_start: args.begin() + xStartIdx,
369 in_end: args.begin() + xStartIdx + numXBuffers);
370 Value cond2 = createInlinedLessThan(builder, loc, args: compareOperands, xPerm, ny);
371 // Update lo and hi for the WhileOp as follows:
372 // if (xs[p] < xs[mid]))
373 // hi = mid;
374 // else
375 // lo = mid + 1;
376 Value newLo = builder.create<arith::SelectOp>(loc, cond2, lo, midp1);
377 Value newHi = builder.create<arith::SelectOp>(loc, cond2, mid, hi);
378 builder.create<scf::YieldOp>(loc, ValueRange{newLo, newHi});
379
380 builder.setInsertionPointAfter(whileOp);
381 builder.create<func::ReturnOp>(loc, whileOp.getResult(0));
382}
383
384/// Creates code to advance i in a loop based on xs[p] as follows:
385/// while (xs[i] < xs[p]) i += step (step > 0)
386/// or
387/// while (xs[i] > xs[p]) i += step (step < 0)
388/// The routine returns i as well as a boolean value to indicate whether
389/// xs[i] == xs[p].
390static std::pair<Value, Value> createScanLoop(OpBuilder &builder,
391 ModuleOp module,
392 func::FuncOp func, ValueRange xs,
393 Value i, Value p, AffineMap xPerm,
394 uint64_t ny, int step) {
395 Location loc = func.getLoc();
396 scf::WhileOp whileOp =
397 builder.create<scf::WhileOp>(loc, TypeRange{i.getType()}, ValueRange{i});
398
399 Block *before =
400 builder.createBlock(&whileOp.getBefore(), {}, {i.getType()}, {loc});
401 builder.setInsertionPointToEnd(before);
402 SmallVector<Value> compareOperands;
403 if (step > 0) {
404 compareOperands.push_back(Elt: before->getArgument(i: 0));
405 compareOperands.push_back(Elt: p);
406 } else {
407 assert(step < 0);
408 compareOperands.push_back(Elt: p);
409 compareOperands.push_back(Elt: before->getArgument(i: 0));
410 }
411 compareOperands.append(in_start: xs.begin(), in_end: xs.end());
412 Value cond = createInlinedLessThan(builder, loc, args: compareOperands, xPerm, ny);
413 builder.create<scf::ConditionOp>(loc, cond, before->getArguments());
414
415 Block *after =
416 builder.createBlock(&whileOp.getAfter(), {}, {i.getType()}, {loc});
417 builder.setInsertionPointToEnd(after);
418 Value cs = constantIndex(builder, loc, i: step);
419 i = builder.create<arith::AddIOp>(loc, after->getArgument(0), cs);
420 builder.create<scf::YieldOp>(loc, ValueRange{i});
421 i = whileOp.getResult(0);
422
423 builder.setInsertionPointAfter(whileOp);
424 compareOperands[0] = i;
425 compareOperands[1] = p;
426 Value compareEq =
427 createInlinedEqCompare(builder, loc, args: compareOperands, xPerm, ny);
428
429 return std::make_pair(whileOp.getResult(0), compareEq);
430}
431
432/// Creates and returns an IfOp to compare two elements and swap the elements
433/// if compareFunc(data[b], data[a]) returns true. The new insertion point is
434/// right after the swap instructions.
435static scf::IfOp createCompareThenSwap(OpBuilder &builder, Location loc,
436 AffineMap xPerm, uint64_t ny,
437 SmallVectorImpl<Value> &swapOperands,
438 SmallVectorImpl<Value> &compareOperands,
439 Value a, Value b) {
440 // Compare(data[b], data[a]).
441 compareOperands[0] = b;
442 compareOperands[1] = a;
443 Value cond = createInlinedLessThan(builder, loc, args: compareOperands, xPerm, ny);
444 scf::IfOp ifOp = builder.create<scf::IfOp>(loc, cond, /*else=*/false);
445 builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
446 swapOperands[0] = b;
447 swapOperands[1] = a;
448 createSwap(builder, loc, args: swapOperands, xPerm, ny);
449 return ifOp;
450}
451
452/// Creates code to insert the 3rd element to a list of two sorted elements.
453static void createInsert3rd(OpBuilder &builder, Location loc, AffineMap xPerm,
454 uint64_t ny, SmallVectorImpl<Value> &swapOperands,
455 SmallVectorImpl<Value> &compareOperands, Value v0,
456 Value v1, Value v2) {
457 scf::IfOp ifOp = createCompareThenSwap(builder, loc, xPerm, ny, swapOperands,
458 compareOperands, v1, v2);
459 createCompareThenSwap(builder, loc, xPerm, ny, swapOperands, compareOperands,
460 a: v0, b: v1);
461 builder.setInsertionPointAfter(ifOp);
462}
463
464/// Creates code to sort 3 elements.
465static void createSort3(OpBuilder &builder, Location loc, AffineMap xPerm,
466 uint64_t ny, SmallVectorImpl<Value> &swapOperands,
467 SmallVectorImpl<Value> &compareOperands, Value v0,
468 Value v1, Value v2) {
469 // Sort the first 2 elements.
470 scf::IfOp ifOp1 = createCompareThenSwap(builder, loc, xPerm, ny, swapOperands,
471 compareOperands, v0, v1);
472 builder.setInsertionPointAfter(ifOp1);
473
474 // Insert the 3th element.
475 createInsert3rd(builder, loc, xPerm, ny, swapOperands, compareOperands, v0,
476 v1, v2);
477}
478
479/// Creates code to sort 5 elements.
480static void createSort5(OpBuilder &builder, Location loc, AffineMap xPerm,
481 uint64_t ny, SmallVectorImpl<Value> &swapOperands,
482 SmallVectorImpl<Value> &compareOperands, Value v0,
483 Value v1, Value v2, Value v3, Value v4) {
484 // Sort the first 3 elements.
485 createSort3(builder, loc, xPerm, ny, swapOperands, compareOperands, v0, v1,
486 v2);
487
488 auto insert4th = [&]() {
489 scf::IfOp ifOp = createCompareThenSwap(
490 builder, loc, xPerm, ny, swapOperands, compareOperands, v2, v3);
491 createInsert3rd(builder, loc, xPerm, ny, swapOperands, compareOperands, v0,
492 v1, v2);
493 builder.setInsertionPointAfter(ifOp);
494 };
495
496 // Insert the 4th element.
497 insert4th();
498
499 // Insert the 5th element.
500 scf::IfOp ifOp = createCompareThenSwap(builder, loc, xPerm, ny, swapOperands,
501 compareOperands, v3, v4);
502 insert4th();
503 builder.setInsertionPointAfter(ifOp);
504}
505
506/// Creates a code block to swap the values in indices lo, mi, and hi so that
507/// data[lo], data[mi] and data[hi] are sorted in non-decreasing values. When
508/// the number of values in range [lo, hi) is more than a threshold, we also
509/// include the middle of [lo, mi) and [mi, hi) and sort a total of five values.
510static void createChoosePivot(OpBuilder &builder, ModuleOp module,
511 func::FuncOp func, AffineMap xPerm, uint64_t ny,
512 Value lo, Value hi, Value mi, ValueRange args) {
513 SmallVector<Value> compareOperands{mi, lo};
514 constexpr uint64_t numXBuffers = 1;
515 compareOperands.append(in_start: args.begin() + xStartIdx,
516 in_end: args.begin() + xStartIdx + numXBuffers);
517 SmallVector<Value> swapOperands{mi, lo};
518 swapOperands.append(in_start: args.begin() + xStartIdx, in_end: args.end());
519 Location loc = func.getLoc();
520 Value c1 = constantIndex(builder, loc, i: 1);
521 Value hiP1 = builder.create<arith::AddIOp>(loc, hi, c1);
522 Value len = builder.create<arith::SubIOp>(loc, hiP1, lo);
523 Value lenThreshold = constantIndex(builder, loc, i: 1000);
524 Value lenCond = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult,
525 len, lenThreshold);
526 scf::IfOp lenIf = builder.create<scf::IfOp>(loc, lenCond, /*else=*/true);
527
528 // When len < 1000, choose pivot from median of 3 values.
529 builder.setInsertionPointToStart(&lenIf.getThenRegion().front());
530 createSort3(builder, loc, xPerm, ny, swapOperands, compareOperands, v0: lo, v1: mi,
531 v2: hi);
532
533 // When len >= 1000, choose pivot from median of 5 values.
534 builder.setInsertionPointToStart(&lenIf.getElseRegion().front());
535 Value miP1 = builder.create<arith::AddIOp>(loc, hi, c1);
536 Value a = builder.create<arith::AddIOp>(loc, lo, miP1);
537 // Value a is the middle between [loc, mi].
538 a = builder.create<arith::ShRUIOp>(loc, a, c1);
539 Value b = builder.create<arith::AddIOp>(loc, mi, hiP1);
540 // Value b is the middle between [mi, hi].
541 b = builder.create<arith::ShRUIOp>(loc, b, c1);
542 createSort5(builder, loc, xPerm, ny, swapOperands, compareOperands, v0: lo, v1: a, v2: mi,
543 v3: b, v4: hi);
544
545 builder.setInsertionPointAfter(lenIf);
546}
547
548/// Creates a function to perform quick sort partition on the values in the
549/// range of index [lo, hi), assuming lo < hi.
550//
551// The generated IR corresponds to this C like algorithm:
552// int partition(lo, hi, xs) {
553// p = (lo+hi)/2 // pivot index
554// i = lo
555// j = hi-1
556// while (true) do {
557// while (xs[i] < xs[p]) i ++;
558// i_eq = (xs[i] == xs[p]);
559// while (xs[j] > xs[p]) j --;
560// j_eq = (xs[j] == xs[p]);
561//
562// if (i >= j) return j + 1;
563//
564// if (i < j) {
565// swap(xs[i], xs[j])
566// if (i == p) {
567// p = j;
568// } else if (j == p) {
569// p = i;
570// }
571// if (i_eq && j_eq) {
572// ++i;
573// --j;
574// }
575// }
576// }
577// }
578static void createPartitionFunc(OpBuilder &builder, ModuleOp module,
579 func::FuncOp func, AffineMap xPerm, uint64_t ny,
580 uint32_t nTrailingP = 0) {
581 // Quick sort partition doesn't use trailing parameters.
582 (void)nTrailingP;
583 assert(nTrailingP == 0);
584 OpBuilder::InsertionGuard insertionGuard(builder);
585
586 Block *entryBlock = func.addEntryBlock();
587 builder.setInsertionPointToStart(entryBlock);
588
589 Location loc = func.getLoc();
590 ValueRange args = entryBlock->getArguments();
591 Value lo = args[loIdx];
592 Value hi = args[hiIdx];
593 Value sum = builder.create<arith::AddIOp>(loc, lo, hi);
594 Value c1 = constantIndex(builder, loc, i: 1);
595 Value p = builder.create<arith::ShRUIOp>(loc, sum, c1);
596
597 Value i = lo;
598 Value j = builder.create<arith::SubIOp>(loc, hi, c1);
599 createChoosePivot(builder, module, func, xPerm, ny, i, j, p, args);
600 Value trueVal = constantI1(builder, loc, b: true); // The value for while (true)
601 SmallVector<Value, 4> operands{i, j, p, trueVal}; // Exactly four values.
602 SmallVector<Type, 4> types{i.getType(), j.getType(), p.getType(),
603 trueVal.getType()};
604 scf::WhileOp whileOp = builder.create<scf::WhileOp>(loc, types, operands);
605
606 // The before-region of the WhileOp.
607 Block *before = builder.createBlock(&whileOp.getBefore(), {}, types,
608 {loc, loc, loc, loc});
609 builder.setInsertionPointToEnd(before);
610 builder.create<scf::ConditionOp>(loc, before->getArgument(3),
611 before->getArguments());
612
613 // The after-region of the WhileOp.
614 Block *after =
615 builder.createBlock(&whileOp.getAfter(), {}, types, {loc, loc, loc, loc});
616 builder.setInsertionPointToEnd(after);
617 i = after->getArgument(i: 0);
618 j = after->getArgument(i: 1);
619 p = after->getArgument(i: 2);
620
621 constexpr uint64_t numXBuffers = 1;
622 auto [iresult, iCompareEq] =
623 createScanLoop(builder, module, func, args.slice(xStartIdx, numXBuffers),
624 i, p, xPerm, ny, 1);
625 i = iresult;
626 auto [jresult, jCompareEq] =
627 createScanLoop(builder, module, func, args.slice(xStartIdx, numXBuffers),
628 j, p, xPerm, ny, -1);
629 j = jresult;
630
631 // If i < j:
632 Value cond =
633 builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, i, j);
634 scf::IfOp ifOp = builder.create<scf::IfOp>(loc, types, cond, /*else=*/true);
635 builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
636 SmallVector<Value> swapOperands{i, j};
637 swapOperands.append(in_start: args.begin() + xStartIdx, in_end: args.end());
638 createSwap(builder, loc, args: swapOperands, xPerm, ny);
639 // If the pivot is moved, update p with the new pivot.
640 Value icond =
641 builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, i, p);
642 scf::IfOp ifOpI = builder.create<scf::IfOp>(loc, TypeRange{p.getType()},
643 icond, /*else=*/true);
644 builder.setInsertionPointToStart(&ifOpI.getThenRegion().front());
645 builder.create<scf::YieldOp>(loc, ValueRange{j});
646 builder.setInsertionPointToStart(&ifOpI.getElseRegion().front());
647 Value jcond =
648 builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, j, p);
649 scf::IfOp ifOpJ = builder.create<scf::IfOp>(loc, TypeRange{p.getType()},
650 jcond, /*else=*/true);
651 builder.setInsertionPointToStart(&ifOpJ.getThenRegion().front());
652 builder.create<scf::YieldOp>(loc, ValueRange{i});
653 builder.setInsertionPointToStart(&ifOpJ.getElseRegion().front());
654 builder.create<scf::YieldOp>(loc, ValueRange{p});
655 builder.setInsertionPointAfter(ifOpJ);
656 builder.create<scf::YieldOp>(loc, ifOpJ.getResults());
657 builder.setInsertionPointAfter(ifOpI);
658 Value compareEqIJ =
659 builder.create<arith::AndIOp>(loc, iCompareEq, jCompareEq);
660 scf::IfOp ifOp2 = builder.create<scf::IfOp>(
661 loc, TypeRange{i.getType(), j.getType()}, compareEqIJ, /*else=*/true);
662 builder.setInsertionPointToStart(&ifOp2.getThenRegion().front());
663 Value i2 = builder.create<arith::AddIOp>(loc, i, c1);
664 Value j2 = builder.create<arith::SubIOp>(loc, j, c1);
665 builder.create<scf::YieldOp>(loc, ValueRange{i2, j2});
666 builder.setInsertionPointToStart(&ifOp2.getElseRegion().front());
667 builder.create<scf::YieldOp>(loc, ValueRange{i, j});
668 builder.setInsertionPointAfter(ifOp2);
669 builder.create<scf::YieldOp>(
670 loc,
671 ValueRange{ifOp2.getResult(0), ifOp2.getResult(1), ifOpI.getResult(0),
672 /*cont=*/constantI1(builder, loc, true)});
673
674 // False branch for if i < j (i.e., i >= j):
675 builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
676 p = builder.create<arith::AddIOp>(loc, j,
677 constantOne(builder, loc, j.getType()));
678 builder.create<scf::YieldOp>(
679 loc, ValueRange{i, j, p, /*cont=*/constantI1(builder, loc, false)});
680
681 // Return for the whileOp.
682 builder.setInsertionPointAfter(ifOp);
683 builder.create<scf::YieldOp>(loc, ifOp.getResults());
684
685 // Return for the function.
686 builder.setInsertionPointAfter(whileOp);
687 builder.create<func::ReturnOp>(loc, whileOp.getResult(2));
688}
689
690/// Computes (n-2)/n, assuming n has index type.
691static Value createSubTwoDividedByTwo(OpBuilder &builder, Location loc,
692 Value n) {
693 Value i2 = constantIndex(builder, loc, i: 2);
694 Value res = builder.create<arith::SubIOp>(loc, n, i2);
695 Value i1 = constantIndex(builder, loc, i: 1);
696 return builder.create<arith::ShRUIOp>(loc, res, i1);
697}
698
699/// Creates a function to heapify the subtree with root `start` within the full
700/// binary tree in the range of index [first, first + n).
701//
702// The generated IR corresponds to this C like algorithm:
703// void shiftDown(first, start, n, data) {
704// if (n >= 2) {
705// child = start - first
706// if ((n-2)/2 >= child) {
707// // Left child exists.
708// child = child * 2 + 1 // Initialize the bigger child to left child.
709// childIndex = child + first
710// if (child+1 < n && data[childIndex] < data[childIndex+1])
711// // Right child exits and is bigger.
712// childIndex++; child++;
713// // Shift data[start] down to where it belongs in the subtree.
714// while (data[start] < data[childIndex) {
715// swap(data[start], data[childIndex])
716// start = childIndex
717// if ((n - 2)/2 >= child) {
718// // Left child exists.
719// child = 2*child + 1
720// childIndex = child + 1
721// if (child + 1) < n && data[childIndex] < data[childIndex+1]
722// childIndex++; child++;
723// }
724// }
725// }
726// }
727// }
728//
729static void createShiftDownFunc(OpBuilder &builder, ModuleOp module,
730 func::FuncOp func, AffineMap xPerm, uint64_t ny,
731 uint32_t nTrailingP) {
732 // The value n is passed in as a trailing parameter.
733 assert(nTrailingP == 1);
734 OpBuilder::InsertionGuard insertionGuard(builder);
735 Block *entryBlock = func.addEntryBlock();
736 builder.setInsertionPointToStart(entryBlock);
737
738 Location loc = func.getLoc();
739 Value n = entryBlock->getArguments().back();
740 ValueRange args = entryBlock->getArguments().drop_back();
741 Value first = args[loIdx];
742 Value start = args[hiIdx];
743
744 // If (n >= 2).
745 Value c2 = constantIndex(builder, loc, i: 2);
746 Value condN =
747 builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::uge, n, c2);
748 scf::IfOp ifN = builder.create<scf::IfOp>(loc, condN, /*else=*/false);
749 builder.setInsertionPointToStart(&ifN.getThenRegion().front());
750 Value child = builder.create<arith::SubIOp>(loc, start, first);
751
752 // If ((n-2)/2 >= child).
753 Value t = createSubTwoDividedByTwo(builder, loc, n);
754 Value condNc =
755 builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::uge, t, child);
756 scf::IfOp ifNc = builder.create<scf::IfOp>(loc, condNc, /*else=*/false);
757
758 builder.setInsertionPointToStart(&ifNc.getThenRegion().front());
759 Value c1 = constantIndex(builder, loc, i: 1);
760 SmallVector<Value> compareOperands{start, start};
761 constexpr uint64_t numXBuffers = 1;
762 compareOperands.append(in_start: args.begin() + xStartIdx,
763 in_end: args.begin() + xStartIdx + numXBuffers);
764
765 // Generate code to inspect the children of 'r' and return the larger child
766 // as follows:
767 // child = r * 2 + 1 // Left child.
768 // childIndex = child + first
769 // if (child+1 < n && data[childIndex] < data[childIndex+1])
770 // childIndex ++; child ++ // Right child is bigger.
771 auto getLargerChild = [&](Value r) -> std::pair<Value, Value> {
772 Value lChild = builder.create<arith::ShLIOp>(loc, r, c1);
773 lChild = builder.create<arith::AddIOp>(loc, lChild, c1);
774 Value lChildIdx = builder.create<arith::AddIOp>(loc, lChild, first);
775 Value rChild = builder.create<arith::AddIOp>(loc, lChild, c1);
776 Value cond1 = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult,
777 rChild, n);
778 SmallVector<Type, 2> ifTypes(2, r.getType());
779 scf::IfOp if1 =
780 builder.create<scf::IfOp>(loc, ifTypes, cond1, /*else=*/true);
781 builder.setInsertionPointToStart(&if1.getThenRegion().front());
782 Value rChildIdx = builder.create<arith::AddIOp>(loc, rChild, first);
783 // Compare data[left] < data[right].
784 compareOperands[0] = lChildIdx;
785 compareOperands[1] = rChildIdx;
786 Value cond2 =
787 createInlinedLessThan(builder, loc, args: compareOperands, xPerm, ny);
788 scf::IfOp if2 =
789 builder.create<scf::IfOp>(loc, ifTypes, cond2, /*else=*/true);
790 builder.setInsertionPointToStart(&if2.getThenRegion().front());
791 builder.create<scf::YieldOp>(loc, ValueRange{rChild, rChildIdx});
792 builder.setInsertionPointToStart(&if2.getElseRegion().front());
793 builder.create<scf::YieldOp>(loc, ValueRange{lChild, lChildIdx});
794 builder.setInsertionPointAfter(if2);
795 builder.create<scf::YieldOp>(loc, if2.getResults());
796 builder.setInsertionPointToStart(&if1.getElseRegion().front());
797 builder.create<scf::YieldOp>(loc, ValueRange{lChild, lChildIdx});
798 builder.setInsertionPointAfter(if1);
799 return std::make_pair(if1.getResult(0), if1.getResult(1));
800 };
801
802 Value childIdx;
803 std::tie(args&: child, args&: childIdx) = getLargerChild(child);
804
805 // While (data[start] < data[childIndex]).
806 SmallVector<Type, 3> types(3, child.getType());
807 scf::WhileOp whileOp = builder.create<scf::WhileOp>(
808 loc, types, SmallVector<Value, 2>{start, child, childIdx});
809
810 // The before-region of the WhileOp.
811 SmallVector<Location, 3> locs(3, loc);
812 Block *before = builder.createBlock(&whileOp.getBefore(), {}, types, locs);
813 builder.setInsertionPointToEnd(before);
814 start = before->getArgument(i: 0);
815 childIdx = before->getArgument(i: 2);
816 compareOperands[0] = start;
817 compareOperands[1] = childIdx;
818 Value cond = createInlinedLessThan(builder, loc, args: compareOperands, xPerm, ny);
819 builder.create<scf::ConditionOp>(loc, cond, before->getArguments());
820
821 // The after-region of the WhileOp.
822 Block *after = builder.createBlock(&whileOp.getAfter(), {}, types, locs);
823 start = after->getArgument(i: 0);
824 child = after->getArgument(i: 1);
825 childIdx = after->getArgument(i: 2);
826 SmallVector<Value> swapOperands{start, childIdx};
827 swapOperands.append(in_start: args.begin() + xStartIdx, in_end: args.end());
828 createSwap(builder, loc, args: swapOperands, xPerm, ny);
829 start = childIdx;
830 Value cond2 =
831 builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::uge, t, child);
832 scf::IfOp if2 = builder.create<scf::IfOp>(
833 loc, TypeRange{child.getType(), child.getType()}, cond2, /*else=*/true);
834 builder.setInsertionPointToStart(&if2.getThenRegion().front());
835 auto [newChild, newChildIdx] = getLargerChild(child);
836 builder.create<scf::YieldOp>(loc, ValueRange{newChild, newChildIdx});
837 builder.setInsertionPointToStart(&if2.getElseRegion().front());
838 builder.create<scf::YieldOp>(loc, ValueRange{child, childIdx});
839 builder.setInsertionPointAfter(if2);
840 builder.create<scf::YieldOp>(
841 loc, ValueRange{start, if2.getResult(0), if2.getResult(1)});
842
843 builder.setInsertionPointAfter(ifN);
844 builder.create<func::ReturnOp>(loc);
845}
846
847/// Creates a function to perform heap sort on the values in the range of index
848/// [lo, hi) with the assumption hi - lo >= 2.
849//
850// The generate IR corresponds to this C like algorithm:
851// void heapSort(lo, hi, data) {
852// n = hi - lo
853// for i = (n-2)/2 downto 0
854// shiftDown(lo, lo+i, n)
855//
856// for l = n downto 2
857// swap(lo, lo+l-1)
858// shiftdown(lo, lo, l-1)
859// }
860static void createHeapSortFunc(OpBuilder &builder, ModuleOp module,
861 func::FuncOp func, AffineMap xPerm, uint64_t ny,
862 uint32_t nTrailingP) {
863 // Heap sort function doesn't have trailing parameters.
864 (void)nTrailingP;
865 assert(nTrailingP == 0);
866 OpBuilder::InsertionGuard insertionGuard(builder);
867 Block *entryBlock = func.addEntryBlock();
868 builder.setInsertionPointToStart(entryBlock);
869
870 Location loc = func.getLoc();
871 ValueRange args = entryBlock->getArguments();
872 Value lo = args[loIdx];
873 Value hi = args[hiIdx];
874 Value n = builder.create<arith::SubIOp>(loc, hi, lo);
875
876 // For i = (n-2)/2 downto 0.
877 Value c0 = constantIndex(builder, loc, i: 0);
878 Value c1 = constantIndex(builder, loc, i: 1);
879 Value s = createSubTwoDividedByTwo(builder, loc, n);
880 Value up = builder.create<arith::AddIOp>(loc, s, c1);
881 scf::ForOp forI = builder.create<scf::ForOp>(loc, c0, up, c1);
882 builder.setInsertionPointToStart(forI.getBody());
883 Value i = builder.create<arith::SubIOp>(loc, s, forI.getInductionVar());
884 Value lopi = builder.create<arith::AddIOp>(loc, lo, i);
885 SmallVector<Value> shiftDownOperands = {lo, lopi};
886 shiftDownOperands.append(in_start: args.begin() + xStartIdx, in_end: args.end());
887 shiftDownOperands.push_back(Elt: n);
888 FlatSymbolRefAttr shiftDownFunc = getMangledSortHelperFunc(
889 builder, func, TypeRange(), kShiftDownFuncNamePrefix, xPerm, ny,
890 shiftDownOperands, createShiftDownFunc, /*nTrailingP=*/1);
891 builder.create<func::CallOp>(loc, shiftDownFunc, TypeRange(),
892 shiftDownOperands);
893
894 builder.setInsertionPointAfter(forI);
895 // For l = n downto 2.
896 up = builder.create<arith::SubIOp>(loc, n, c1);
897 scf::ForOp forL = builder.create<scf::ForOp>(loc, c0, up, c1);
898 builder.setInsertionPointToStart(forL.getBody());
899 Value l = builder.create<arith::SubIOp>(loc, n, forL.getInductionVar());
900 Value loplm1 = builder.create<arith::AddIOp>(loc, lo, l);
901 loplm1 = builder.create<arith::SubIOp>(loc, loplm1, c1);
902 SmallVector<Value> swapOperands{lo, loplm1};
903 swapOperands.append(in_start: args.begin() + xStartIdx, in_end: args.end());
904 createSwap(builder, loc, args: swapOperands, xPerm, ny);
905 shiftDownOperands[1] = lo;
906 shiftDownOperands[shiftDownOperands.size() - 1] =
907 builder.create<arith::SubIOp>(loc, l, c1);
908 builder.create<func::CallOp>(loc, shiftDownFunc, TypeRange(),
909 shiftDownOperands);
910
911 builder.setInsertionPointAfter(forL);
912 builder.create<func::ReturnOp>(loc);
913}
914
915/// A helper for generating code to perform quick sort. It partitions [lo, hi),
916/// recursively calls quick sort to process the smaller partition and returns
917/// the bigger partition to be processed by the enclosed while-loop.
918static std::pair<Value, Value>
919createQuickSort(OpBuilder &builder, ModuleOp module, func::FuncOp func,
920 ValueRange args, AffineMap xPerm, uint64_t ny,
921 uint32_t nTrailingP) {
922 MLIRContext *context = module.getContext();
923 Location loc = func.getLoc();
924 Value lo = args[loIdx];
925 Value hi = args[hiIdx];
926 SmallVector<Type, 2> types(2, lo.getType()); // Only two types.
927
928 FlatSymbolRefAttr partitionFunc = getMangledSortHelperFunc(
929 builder, func, {IndexType::get(context)}, kPartitionFuncNamePrefix, xPerm,
930 ny, args.drop_back(nTrailingP), createPartitionFunc);
931 Value p = builder
932 .create<func::CallOp>(loc, partitionFunc,
933 TypeRange{IndexType::get(context)},
934 args.drop_back(nTrailingP))
935 .getResult(0);
936
937 Value lenLow = builder.create<arith::SubIOp>(loc, p, lo);
938 Value lenHigh = builder.create<arith::SubIOp>(loc, hi, p);
939 // Partition already sorts array with len <= 2
940 Value c2 = constantIndex(builder, loc, i: 2);
941 Value len = builder.create<arith::SubIOp>(loc, hi, lo);
942 Value lenGtTwo =
943 builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ugt, len, c2);
944 scf::IfOp ifLenGtTwo =
945 builder.create<scf::IfOp>(loc, types, lenGtTwo, /*else=*/true);
946 builder.setInsertionPointToStart(&ifLenGtTwo.getElseRegion().front());
947 // Returns an empty range to mark the entire region is fully sorted.
948 builder.create<scf::YieldOp>(loc, ValueRange{lo, lo});
949
950 // Else len > 2, need recursion.
951 builder.setInsertionPointToStart(&ifLenGtTwo.getThenRegion().front());
952 Value cond = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ule,
953 lenLow, lenHigh);
954
955 Value c0 = constantIndex(builder, loc, i: 0);
956 scf::IfOp ifOp = builder.create<scf::IfOp>(loc, types, cond, /*else=*/true);
957
958 auto mayRecursion = [&](Value low, Value high, Value len) {
959 Value cond =
960 builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne, len, c0);
961 scf::IfOp ifOp = builder.create<scf::IfOp>(loc, cond, /*else=*/false);
962 builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
963 SmallVector<Value> operands{low, high};
964 operands.append(in_start: args.begin() + xStartIdx, in_end: args.end());
965 builder.create<func::CallOp>(loc, func, operands);
966 builder.setInsertionPointAfter(ifOp);
967 };
968
969 // Recursively call quickSort to process the smaller partition and return
970 // the bigger partition to be processed by the enclosed while-loop.
971 builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
972 mayRecursion(lo, p, lenLow);
973 builder.create<scf::YieldOp>(loc, ValueRange{p, hi});
974
975 builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
976 mayRecursion(p, hi, lenHigh);
977 builder.create<scf::YieldOp>(loc, ValueRange{lo, p});
978
979 builder.setInsertionPointAfter(ifOp);
980 builder.create<scf::YieldOp>(loc, ifOp.getResults());
981
982 builder.setInsertionPointAfter(ifLenGtTwo);
983 return std::make_pair(ifLenGtTwo.getResult(0), ifLenGtTwo.getResult(1));
984}
985
986/// Creates a function to perform insertion sort on the values in the range of
987/// index [lo, hi).
988//
989// The generate IR corresponds to this C like algorithm:
990// void insertionSort(lo, hi, data) {
991// for (i = lo+1; i < hi; i++) {
992// d = data[i];
993// p = binarySearch(lo, i-1, data)
994// for (j = 0; j > i - p; j++)
995// data[i-j] = data[i-j-1]
996// data[p] = d
997// }
998// }
999static void createSortStableFunc(OpBuilder &builder, ModuleOp module,
1000 func::FuncOp func, AffineMap xPerm,
1001 uint64_t ny, uint32_t nTrailingP) {
1002 // Stable sort function doesn't use trailing parameters.
1003 (void)nTrailingP;
1004 assert(nTrailingP == 0);
1005 OpBuilder::InsertionGuard insertionGuard(builder);
1006 Block *entryBlock = func.addEntryBlock();
1007 builder.setInsertionPointToStart(entryBlock);
1008
1009 MLIRContext *context = module.getContext();
1010 Location loc = func.getLoc();
1011 ValueRange args = entryBlock->getArguments();
1012 Value c1 = constantIndex(builder, loc, i: 1);
1013 Value lo = args[loIdx];
1014 Value hi = args[hiIdx];
1015 Value lop1 = builder.create<arith::AddIOp>(loc, lo, c1);
1016
1017 // Start the outer for-stmt with induction variable i.
1018 scf::ForOp forOpI = builder.create<scf::ForOp>(loc, lop1, hi, c1);
1019 builder.setInsertionPointToStart(forOpI.getBody());
1020 Value i = forOpI.getInductionVar();
1021
1022 // Binary search to find the insertion point p.
1023 SmallVector<Value> operands{lo, i};
1024 operands.append(in_start: args.begin() + xStartIdx, in_end: args.end());
1025 FlatSymbolRefAttr searchFunc = getMangledSortHelperFunc(
1026 builder, func, {IndexType::get(context)}, kBinarySearchFuncNamePrefix,
1027 xPerm, ny, operands, createBinarySearchFunc);
1028 Value p = builder
1029 .create<func::CallOp>(loc, searchFunc, TypeRange{c1.getType()},
1030 operands)
1031 .getResult(0);
1032
1033 // Move the value at data[i] to a temporary location.
1034 operands[0] = operands[1] = i;
1035 SmallVector<Value> d;
1036 forEachIJPairInAllBuffers(
1037 builder, loc, args: operands, xPerm, ny,
1038 bodyBuilder: [&](uint64_t unused, Value i, Value unused2, Value buffer) {
1039 d.push_back(builder.create<memref::LoadOp>(loc, buffer, i));
1040 });
1041
1042 // Start the inner for-stmt with induction variable j, for moving data[p..i)
1043 // to data[p+1..i+1).
1044 Value imp = builder.create<arith::SubIOp>(loc, i, p);
1045 Value c0 = constantIndex(builder, loc, i: 0);
1046 scf::ForOp forOpJ = builder.create<scf::ForOp>(loc, c0, imp, c1);
1047 builder.setInsertionPointToStart(forOpJ.getBody());
1048 Value j = forOpJ.getInductionVar();
1049 Value imj = builder.create<arith::SubIOp>(loc, i, j);
1050 operands[1] = imj;
1051 operands[0] = builder.create<arith::SubIOp>(loc, imj, c1);
1052 forEachIJPairInAllBuffers(
1053 builder, loc, args: operands, xPerm, ny,
1054 bodyBuilder: [&](uint64_t unused, Value imjm1, Value imj, Value buffer) {
1055 Value t = builder.create<memref::LoadOp>(loc, buffer, imjm1);
1056 builder.create<memref::StoreOp>(loc, t, buffer, imj);
1057 });
1058
1059 // Store the value at data[i] to data[p].
1060 builder.setInsertionPointAfter(forOpJ);
1061 operands[0] = operands[1] = p;
1062 forEachIJPairInAllBuffers(
1063 builder, loc, args: operands, xPerm, ny,
1064 bodyBuilder: [&](uint64_t k, Value p, Value usused, Value buffer) {
1065 builder.create<memref::StoreOp>(loc, d[k], buffer, p);
1066 });
1067
1068 builder.setInsertionPointAfter(forOpI);
1069 builder.create<func::ReturnOp>(loc);
1070}
1071
1072/// Creates a function to perform quick sort or a hybrid quick sort on the
1073/// values in the range of index [lo, hi).
1074//
1075//
1076// When nTrailingP == 0, the generated IR corresponds to this C like algorithm:
1077// void quickSort(lo, hi, data) {
1078// while (lo + 1 < hi) {
1079// p = partition(low, high, data);
1080// if (len(lo, p) < len(p+1, hi)) {
1081// quickSort(lo, p, data);
1082// lo = p+1;
1083// } else {
1084// quickSort(p + 1, hi, data);
1085// hi = p;
1086// }
1087// }
1088// }
1089//
1090// When nTrailingP == 1, the generated IR corresponds to this C like algorithm:
1091// void hybridQuickSort(lo, hi, data, depthLimit) {
1092// while (lo + 1 < hi) {
1093// len = hi - lo;
1094// if (len <= limit) {
1095// insertionSort(lo, hi, data);
1096// } else {
1097// depthLimit --;
1098// if (depthLimit <= 0) {
1099// heapSort(lo, hi, data);
1100// } else {
1101// p = partition(low, high, data);
1102// if (len(lo, p) < len(p+1, hi)) {
1103// quickSort(lo, p, data, depthLimit);
1104// lo = p+1;
1105// } else {
1106// quickSort(p + 1, hi, data, depthLimit);
1107// hi = p;
1108// }
1109// }
1110// }
1111// }
1112// }
1113//
1114static void createQuickSortFunc(OpBuilder &builder, ModuleOp module,
1115 func::FuncOp func, AffineMap xPerm, uint64_t ny,
1116 uint32_t nTrailingP) {
1117 assert(nTrailingP == 1 || nTrailingP == 0);
1118 bool isHybrid = (nTrailingP == 1);
1119 OpBuilder::InsertionGuard insertionGuard(builder);
1120 Block *entryBlock = func.addEntryBlock();
1121 builder.setInsertionPointToStart(entryBlock);
1122
1123 Location loc = func.getLoc();
1124 SmallVector<Value> args;
1125 args.append(in_start: entryBlock->getArguments().begin(),
1126 in_end: entryBlock->getArguments().end());
1127 Value lo = args[loIdx];
1128 Value hi = args[hiIdx];
1129 SmallVector<Type, 2> types(2, lo.getType()); // Only two types.
1130 scf::WhileOp whileOp =
1131 builder.create<scf::WhileOp>(loc, types, SmallVector<Value, 2>{lo, hi});
1132
1133 // The before-region of the WhileOp.
1134 Block *before =
1135 builder.createBlock(&whileOp.getBefore(), {}, types, {loc, loc});
1136 builder.setInsertionPointToEnd(before);
1137 lo = before->getArgument(i: 0);
1138 hi = before->getArgument(i: 1);
1139 Value loP1 =
1140 builder.create<arith::AddIOp>(loc, lo, constantIndex(builder, loc, 1));
1141 Value needSort =
1142 builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, loP1, hi);
1143 builder.create<scf::ConditionOp>(loc, needSort, before->getArguments());
1144
1145 // The after-region of the WhileOp.
1146 Block *after =
1147 builder.createBlock(&whileOp.getAfter(), {}, types, {loc, loc});
1148 builder.setInsertionPointToEnd(after);
1149 lo = after->getArgument(i: 0);
1150 hi = after->getArgument(i: 1);
1151 args[0] = lo;
1152 args[1] = hi;
1153
1154 if (isHybrid) {
1155 Value len = builder.create<arith::SubIOp>(loc, hi, lo);
1156 Value lenLimit = constantIndex(builder, loc, i: 30);
1157 Value lenCond = builder.create<arith::CmpIOp>(
1158 loc, arith::CmpIPredicate::ule, len, lenLimit);
1159 scf::IfOp lenIf =
1160 builder.create<scf::IfOp>(loc, types, lenCond, /*else=*/true);
1161
1162 // When len <= limit.
1163 builder.setInsertionPointToStart(&lenIf.getThenRegion().front());
1164 FlatSymbolRefAttr insertionSortFunc = getMangledSortHelperFunc(
1165 builder, func, TypeRange(), kSortStableFuncNamePrefix, xPerm, ny,
1166 ValueRange(args).drop_back(nTrailingP), createSortStableFunc);
1167 builder.create<func::CallOp>(loc, insertionSortFunc, TypeRange(),
1168 ValueRange(args).drop_back(nTrailingP));
1169 builder.create<scf::YieldOp>(loc, ValueRange{lo, lo});
1170
1171 // When len > limit.
1172 builder.setInsertionPointToStart(&lenIf.getElseRegion().front());
1173 Value depthLimit = args.back();
1174 depthLimit = builder.create<arith::SubIOp>(loc, depthLimit,
1175 constantI64(builder, loc, 1));
1176 Value depthCond =
1177 builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ule,
1178 depthLimit, constantI64(builder, loc, 0));
1179 scf::IfOp depthIf =
1180 builder.create<scf::IfOp>(loc, types, depthCond, /*else=*/true);
1181
1182 // When depth exceeds limit.
1183 builder.setInsertionPointToStart(&depthIf.getThenRegion().front());
1184 FlatSymbolRefAttr heapSortFunc = getMangledSortHelperFunc(
1185 builder, func, TypeRange(), kHeapSortFuncNamePrefix, xPerm, ny,
1186 ValueRange(args).drop_back(nTrailingP), createHeapSortFunc);
1187 builder.create<func::CallOp>(loc, heapSortFunc, TypeRange(),
1188 ValueRange(args).drop_back(nTrailingP));
1189 builder.create<scf::YieldOp>(loc, ValueRange{lo, lo});
1190
1191 // When depth doesn't exceed limit.
1192 builder.setInsertionPointToStart(&depthIf.getElseRegion().front());
1193 args.back() = depthLimit;
1194 std::tie(lo, hi) =
1195 createQuickSort(builder, module, func, args, xPerm, ny, nTrailingP);
1196 builder.create<scf::YieldOp>(loc, ValueRange{lo, hi});
1197
1198 builder.setInsertionPointAfter(depthIf);
1199 lo = depthIf.getResult(0);
1200 hi = depthIf.getResult(1);
1201 builder.create<scf::YieldOp>(loc, ValueRange{lo, hi});
1202
1203 builder.setInsertionPointAfter(lenIf);
1204 lo = lenIf.getResult(0);
1205 hi = lenIf.getResult(1);
1206 } else {
1207 std::tie(lo, hi) =
1208 createQuickSort(builder, module, func, args, xPerm, ny, nTrailingP);
1209 }
1210
1211 // New [lo, hi) for the next while-loop iteration.
1212 builder.create<scf::YieldOp>(loc, ValueRange{lo, hi});
1213
1214 // After the while-loop.
1215 builder.setInsertionPointAfter(whileOp);
1216 builder.create<func::ReturnOp>(loc);
1217}
1218
1219/// Implements the rewriting for operator sort and sort_coo.
1220template <typename OpTy>
1221LogicalResult matchAndRewriteSortOp(OpTy op, ValueRange xys, AffineMap xPerm,
1222 uint64_t ny, PatternRewriter &rewriter) {
1223 Location loc = op.getLoc();
1224 SmallVector<Value> operands{constantIndex(builder&: rewriter, loc, i: 0), op.getN()};
1225
1226 // Convert `values` to have dynamic shape and append them to `operands`.
1227 for (Value v : xys) {
1228 auto mtp = getMemRefType(v);
1229 if (!mtp.isDynamicDim(0)) {
1230 auto newMtp =
1231 MemRefType::get({ShapedType::kDynamic}, mtp.getElementType());
1232 v = rewriter.create<memref::CastOp>(loc, newMtp, v);
1233 }
1234 operands.push_back(Elt: v);
1235 }
1236
1237 auto insertPoint = op->template getParentOfType<func::FuncOp>();
1238 if (!insertPoint)
1239 return failure();
1240
1241 SmallString<32> funcName;
1242 FuncGeneratorType funcGenerator;
1243 uint32_t nTrailingP = 0;
1244 switch (op.getAlgorithm()) {
1245 case SparseTensorSortKind::HybridQuickSort: {
1246 funcName = kHybridQuickSortFuncNamePrefix;
1247 funcGenerator = createQuickSortFunc;
1248 nTrailingP = 1;
1249 // As a heuristics, set depthLimit = 2 * log2(n).
1250 Value lo = operands[loIdx];
1251 Value hi = operands[hiIdx];
1252 Value len = rewriter.create<arith::IndexCastOp>(
1253 loc, rewriter.getI64Type(),
1254 rewriter.create<arith::SubIOp>(loc, hi, lo));
1255 Value depthLimit = rewriter.create<arith::SubIOp>(
1256 loc, constantI64(rewriter, loc, 64),
1257 rewriter.create<math::CountLeadingZerosOp>(loc, len));
1258 operands.push_back(Elt: depthLimit);
1259 break;
1260 }
1261 case SparseTensorSortKind::QuickSort:
1262 funcName = kQuickSortFuncNamePrefix;
1263 funcGenerator = createQuickSortFunc;
1264 break;
1265 case SparseTensorSortKind::InsertionSortStable:
1266 funcName = kSortStableFuncNamePrefix;
1267 funcGenerator = createSortStableFunc;
1268 break;
1269 case SparseTensorSortKind::HeapSort:
1270 funcName = kHeapSortFuncNamePrefix;
1271 funcGenerator = createHeapSortFunc;
1272 break;
1273 }
1274
1275 FlatSymbolRefAttr func =
1276 getMangledSortHelperFunc(rewriter, insertPoint, TypeRange(), funcName,
1277 xPerm, ny, operands, funcGenerator, nTrailingP);
1278 rewriter.replaceOpWithNewOp<func::CallOp>(op, func, TypeRange(), operands);
1279 return success();
1280}
1281
1282//===---------------------------------------------------------------------===//
1283// The actual sparse buffer rewriting rules.
1284//===---------------------------------------------------------------------===//
1285
1286namespace {
1287/// Sparse rewriting rule for the push_back operator.
1288struct PushBackRewriter : OpRewritePattern<PushBackOp> {
1289public:
1290 using OpRewritePattern<PushBackOp>::OpRewritePattern;
1291 PushBackRewriter(MLIRContext *context, bool enableInit)
1292 : OpRewritePattern(context), enableBufferInitialization(enableInit) {}
1293 LogicalResult matchAndRewrite(PushBackOp op,
1294 PatternRewriter &rewriter) const override {
1295 // Rewrite push_back(buffer, value, n) to:
1296 // new_size = size(buffer) + n
1297 // if (new_size > capacity(buffer))
1298 // while new_size > new_capacity
1299 // new_capacity = new_capacity*2
1300 // new_buffer = realloc(buffer, new_capacity)
1301 // buffer = new_buffer
1302 // subBuffer = subviewof(buffer)
1303 // linalg.fill subBuffer value
1304 //
1305 // size(buffer) += n
1306 //
1307 // The capacity check is skipped when the attribute inbounds is presented.
1308 Location loc = op->getLoc();
1309 Value c0 = constantIndex(builder&: rewriter, loc, i: 0);
1310 Value buffer = op.getInBuffer();
1311 Value capacity = rewriter.create<memref::DimOp>(loc, buffer, c0);
1312 Value size = op.getCurSize();
1313 Value value = op.getValue();
1314
1315 Value n = op.getN() ? op.getN() : constantIndex(builder&: rewriter, loc, i: 1);
1316 Value newSize = rewriter.create<arith::AddIOp>(loc, size, n);
1317 auto nValue = dyn_cast_or_null<arith::ConstantIndexOp>(Val: n.getDefiningOp());
1318 bool nIsOne = (nValue && nValue.value() == 1);
1319
1320 if (!op.getInbounds()) {
1321 Value cond = rewriter.create<arith::CmpIOp>(
1322 loc, arith::CmpIPredicate::ugt, newSize, capacity);
1323
1324 Value c2 = constantIndex(builder&: rewriter, loc, i: 2);
1325 auto bufferType =
1326 MemRefType::get({ShapedType::kDynamic}, value.getType());
1327 scf::IfOp ifOp = rewriter.create<scf::IfOp>(loc, bufferType, cond,
1328 /*else=*/true);
1329 // True branch.
1330 rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front());
1331 if (nIsOne) {
1332 capacity = rewriter.create<arith::MulIOp>(loc, capacity, c2);
1333 } else {
1334 // Use a do-while loop to calculate the new capacity as follows:
1335 // do { new_capacity *= 2 } while (size > new_capacity)
1336 scf::WhileOp whileOp =
1337 rewriter.create<scf::WhileOp>(loc, capacity.getType(), capacity);
1338
1339 // The before-region of the WhileOp.
1340 Block *before = rewriter.createBlock(&whileOp.getBefore(), {},
1341 {capacity.getType()}, {loc});
1342 rewriter.setInsertionPointToEnd(before);
1343
1344 capacity =
1345 rewriter.create<arith::MulIOp>(loc, before->getArgument(0), c2);
1346 cond = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ugt,
1347 newSize, capacity);
1348 rewriter.create<scf::ConditionOp>(loc, cond, ValueRange{capacity});
1349 // The after-region of the WhileOp.
1350 Block *after = rewriter.createBlock(&whileOp.getAfter(), {},
1351 {capacity.getType()}, {loc});
1352 rewriter.setInsertionPointToEnd(after);
1353 rewriter.create<scf::YieldOp>(loc, after->getArguments());
1354
1355 rewriter.setInsertionPointAfter(whileOp);
1356 capacity = whileOp.getResult(0);
1357 }
1358
1359 Value newBuffer =
1360 rewriter.create<memref::ReallocOp>(loc, bufferType, buffer, capacity);
1361 if (enableBufferInitialization) {
1362 Value fillSize = rewriter.create<arith::SubIOp>(loc, capacity, newSize);
1363 Value fillValue = constantZero(builder&: rewriter, loc, tp: value.getType());
1364 Value subBuffer = rewriter.create<memref::SubViewOp>(
1365 loc, newBuffer, /*offset=*/ValueRange{newSize},
1366 /*size=*/ValueRange{fillSize},
1367 /*step=*/ValueRange{constantIndex(rewriter, loc, 1)});
1368 rewriter.create<linalg::FillOp>(loc, fillValue, subBuffer);
1369 }
1370 rewriter.create<scf::YieldOp>(loc, newBuffer);
1371
1372 // False branch.
1373 rewriter.setInsertionPointToStart(&ifOp.getElseRegion().front());
1374 rewriter.create<scf::YieldOp>(loc, buffer);
1375
1376 // Prepare for adding the value to the end of the buffer.
1377 rewriter.setInsertionPointAfter(ifOp);
1378 buffer = ifOp.getResult(0);
1379 }
1380
1381 // Add the value to the end of the buffer.
1382 if (nIsOne) {
1383 rewriter.create<memref::StoreOp>(loc, value, buffer, size);
1384 } else {
1385 Value subBuffer = rewriter.create<memref::SubViewOp>(
1386 loc, buffer, /*offset=*/ValueRange{size}, /*size=*/ValueRange{n},
1387 /*step=*/ValueRange{constantIndex(rewriter, loc, 1)});
1388 rewriter.create<linalg::FillOp>(loc, value, subBuffer);
1389 }
1390
1391 // Update the buffer size.
1392 rewriter.replaceOp(op, {buffer, newSize});
1393 return success();
1394 }
1395
1396private:
1397 bool enableBufferInitialization;
1398};
1399
1400/// Sparse rewriting rule for the sort_coo operator.
1401struct SortRewriter : public OpRewritePattern<SortOp> {
1402public:
1403 using OpRewritePattern<SortOp>::OpRewritePattern;
1404
1405 LogicalResult matchAndRewrite(SortOp op,
1406 PatternRewriter &rewriter) const override {
1407 SmallVector<Value> xys;
1408 xys.push_back(Elt: op.getXy());
1409 xys.append(op.getYs().begin(), op.getYs().end());
1410
1411 auto xPerm = op.getPermMap();
1412 uint64_t ny = 0;
1413 if (auto nyAttr = op.getNyAttr())
1414 ny = nyAttr.getInt();
1415
1416 return matchAndRewriteSortOp(op, xys, xPerm, ny, rewriter);
1417 }
1418};
1419
1420} // namespace
1421
1422//===---------------------------------------------------------------------===//
1423// Methods that add patterns described in this file to a pattern list.
1424//===---------------------------------------------------------------------===//
1425
1426void mlir::populateSparseBufferRewriting(RewritePatternSet &patterns,
1427 bool enableBufferInitialization) {
1428 patterns.add<PushBackRewriter>(arg: patterns.getContext(),
1429 args&: enableBufferInitialization);
1430 patterns.add<SortRewriter>(arg: patterns.getContext());
1431}
1432

source code of mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp