1 | //===- SparseBufferRewriting.cpp - Sparse buffer rewriting rules ----------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This file implements rewriting rules that are specific to sparse tensor |
10 | // primitives with memref operands. |
11 | // |
12 | //===----------------------------------------------------------------------===// |
13 | |
14 | #include "Utils/CodegenUtils.h" |
15 | |
16 | #include "mlir/Dialect/Arith/IR/Arith.h" |
17 | #include "mlir/Dialect/Func/IR/FuncOps.h" |
18 | #include "mlir/Dialect/Linalg/IR/Linalg.h" |
19 | #include "mlir/Dialect/Math/IR/Math.h" |
20 | #include "mlir/Dialect/MemRef/IR/MemRef.h" |
21 | #include "mlir/Dialect/SCF/IR/SCF.h" |
22 | #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" |
23 | #include "mlir/Dialect/SparseTensor/Transforms/Passes.h" |
24 | #include "mlir/Support/LLVM.h" |
25 | |
26 | using namespace mlir; |
27 | using namespace mlir::sparse_tensor; |
28 | |
29 | //===---------------------------------------------------------------------===// |
30 | // Helper methods for the actual rewriting rules. |
31 | //===---------------------------------------------------------------------===// |
32 | |
// Layout of the leading operands shared by every generated sort helper: the
// `lo` and `hi` bounds come first, followed by the buffer holding the x values.
static constexpr uint64_t loIdx = 0;
static constexpr uint64_t hiIdx = loIdx + 1;
static constexpr uint64_t xStartIdx = hiIdx + 1;

// Name prefixes of the generated sorting helper functions.
static constexpr char kPartitionFuncNamePrefix[] = "_sparse_partition_";
static constexpr char kBinarySearchFuncNamePrefix[] = "_sparse_binary_search_";
static constexpr char kHybridQuickSortFuncNamePrefix[] = "_sparse_hybrid_qsort_";
static constexpr char kSortStableFuncNamePrefix[] = "_sparse_sort_stable_";
static constexpr char kShiftDownFuncNamePrefix[] = "_sparse_shift_down_";
static constexpr char kHeapSortFuncNamePrefix[] = "_sparse_heap_sort_";
static constexpr char kQuickSortFuncNamePrefix[] = "_sparse_qsort_";
47 | |
48 | using FuncGeneratorType = function_ref<void(OpBuilder &, ModuleOp, func::FuncOp, |
49 | AffineMap, uint64_t, uint32_t)>; |
50 | |
51 | /// Constructs a function name with this format to facilitate quick sort: |
52 | /// <namePrefix><xPerm>_<x type>_<y0 type>..._<yn type> for sort |
53 | /// <namePrefix><xPerm>_<x type>_coo_<ny>_<y0 type>..._<yn type> for sort_coo |
54 | static void getMangledSortHelperFuncName(llvm::raw_svector_ostream &nameOstream, |
55 | StringRef namePrefix, AffineMap xPerm, |
56 | uint64_t ny, ValueRange operands) { |
57 | nameOstream << namePrefix; |
58 | for (auto res : xPerm.getResults()) |
59 | nameOstream << cast<AffineDimExpr>(Val&: res).getPosition() << "_" ; |
60 | |
61 | nameOstream << getMemRefType(operands[xStartIdx]).getElementType(); |
62 | nameOstream << "_coo_" << ny; |
63 | |
64 | constexpr uint64_t yBufferOffset = 1; |
65 | for (Value v : operands.drop_front(n: xStartIdx + yBufferOffset)) |
66 | nameOstream << "_" << getMemRefType(v).getElementType(); |
67 | } |
68 | |
69 | /// Looks up a function that is appropriate for the given operands being |
70 | /// sorted, and creates such a function if it doesn't exist yet. The |
71 | /// parameters `xPerm` and `ny` tell the number of x and y values provided |
72 | /// by the buffer in xStartIdx. |
73 | // |
74 | // All sorting function generators take (lo, hi, xs, ys) in `operands` as |
75 | // parameters for the sorting functions. Other parameters, such as the recursive |
76 | // call depth, are appended to the end of the parameter list as |
77 | // "trailing parameters". |
78 | static FlatSymbolRefAttr getMangledSortHelperFunc( |
79 | OpBuilder &builder, func::FuncOp insertPoint, TypeRange resultTypes, |
80 | StringRef namePrefix, AffineMap xPerm, uint64_t ny, ValueRange operands, |
81 | FuncGeneratorType createFunc, uint32_t nTrailingP = 0) { |
82 | SmallString<32> nameBuffer; |
83 | llvm::raw_svector_ostream nameOstream(nameBuffer); |
84 | getMangledSortHelperFuncName(nameOstream, namePrefix, xPerm, ny, |
85 | operands: operands.drop_back(n: nTrailingP)); |
86 | |
87 | ModuleOp module = insertPoint->getParentOfType<ModuleOp>(); |
88 | MLIRContext *context = module.getContext(); |
89 | auto result = SymbolRefAttr::get(context, nameOstream.str()); |
90 | auto func = module.lookupSymbol<func::FuncOp>(result.getAttr()); |
91 | |
92 | if (!func) { |
93 | // Create the function. |
94 | OpBuilder::InsertionGuard insertionGuard(builder); |
95 | builder.setInsertionPoint(insertPoint); |
96 | Location loc = insertPoint.getLoc(); |
97 | func = builder.create<func::FuncOp>( |
98 | loc, nameOstream.str(), |
99 | FunctionType::get(context, operands.getTypes(), resultTypes)); |
100 | func.setPrivate(); |
101 | createFunc(builder, module, func, xPerm, ny, nTrailingP); |
102 | } |
103 | |
104 | return result; |
105 | } |
106 | |
107 | /// Creates a code block to process each pair of (xs[i], xs[j]) for sorting. |
108 | /// The code to process the value pairs is generated by `bodyBuilder`. |
109 | static void forEachIJPairInXs( |
110 | OpBuilder &builder, Location loc, ValueRange args, AffineMap xPerm, |
111 | uint64_t ny, |
112 | function_ref<void(uint64_t, Value, Value, Value)> bodyBuilder) { |
113 | Value cstep = constantIndex(builder, loc, i: xPerm.getNumResults() + ny); |
114 | Value iOffset = builder.create<arith::MulIOp>(loc, args[0], cstep); |
115 | Value jOffset = builder.create<arith::MulIOp>(loc, args[1], cstep); |
116 | for (unsigned k = 0, e = xPerm.getNumResults(); k < e; k++) { |
117 | unsigned actualK = cast<AffineDimExpr>(Val: xPerm.getResult(idx: k)).getPosition(); |
118 | Value ak = constantIndex(builder, loc, i: actualK); |
119 | Value i = builder.create<arith::AddIOp>(loc, ak, iOffset); |
120 | Value j = builder.create<arith::AddIOp>(loc, ak, jOffset); |
121 | Value buffer = args[xStartIdx]; |
122 | |
123 | bodyBuilder(k, i, j, buffer); |
124 | } |
125 | } |
126 | |
127 | /// Creates a code block to process each pair of (xys[i], xys[j]) for sorting. |
128 | /// The code to process the value pairs is generated by `bodyBuilder`. |
129 | static void forEachIJPairInAllBuffers( |
130 | OpBuilder &builder, Location loc, ValueRange args, AffineMap xPerm, |
131 | uint64_t ny, |
132 | function_ref<void(uint64_t, Value, Value, Value)> bodyBuilder) { |
133 | |
134 | // Create code for the first (xPerm + ny) buffers. |
135 | SmallVector<AffineExpr> exps(xPerm.getResults().begin(), |
136 | xPerm.getResults().end()); |
137 | for (unsigned y = 0; y < ny; y++) { |
138 | exps.push_back(Elt: builder.getAffineDimExpr(position: y + xPerm.getNumResults())); |
139 | } |
140 | AffineMap xyPerm = AffineMap::get(dimCount: exps.size(), symbolCount: 0, results: exps, context: builder.getContext()); |
141 | assert(xyPerm.isPermutation()); |
142 | |
143 | forEachIJPairInXs(builder, loc, args, xPerm: xyPerm, ny: 0, bodyBuilder); |
144 | |
145 | constexpr uint64_t numHandledBuffers = 1; |
146 | // Create code for the remaining buffers. |
147 | Value i = args[0]; |
148 | Value j = args[1]; |
149 | for (const auto &arg : |
150 | llvm::enumerate(First: args.drop_front(n: xStartIdx + numHandledBuffers))) { |
151 | bodyBuilder(arg.index() + xPerm.getNumResults() + ny, i, j, arg.value()); |
152 | } |
153 | } |
154 | |
155 | /// Creates a code block for swapping the values in index i and j for all the |
156 | /// buffers. |
157 | // |
158 | // The generated IR corresponds to this C like algorithm: |
159 | // swap(x0[i], x0[j]); |
160 | // swap(x1[i], x1[j]); |
161 | // ... |
162 | // swap(xn[i], xn[j]); |
163 | // swap(y0[i], y0[j]); |
164 | // ... |
165 | // swap(yn[i], yn[j]); |
166 | static void createSwap(OpBuilder &builder, Location loc, ValueRange args, |
167 | AffineMap xPerm, uint64_t ny) { |
168 | auto swapOnePair = [&](uint64_t unused, Value i, Value j, Value buffer) { |
169 | Value vi = builder.create<memref::LoadOp>(loc, buffer, i); |
170 | Value vj = builder.create<memref::LoadOp>(loc, buffer, j); |
171 | builder.create<memref::StoreOp>(loc, vj, buffer, i); |
172 | builder.create<memref::StoreOp>(loc, vi, buffer, j); |
173 | }; |
174 | |
175 | forEachIJPairInAllBuffers(builder, loc, args, xPerm, ny, bodyBuilder: swapOnePair); |
176 | } |
177 | |
178 | /// Creates code to compare all the (xs[i], xs[j]) pairs. The method to compare |
179 | /// each pair is create via `compareBuilder`. |
180 | static Value createInlinedCompareImplementation( |
181 | OpBuilder &builder, Location loc, ValueRange args, AffineMap xPerm, |
182 | uint64_t ny, |
183 | function_ref<Value(OpBuilder &, Location, Value, Value, Value, bool, bool)> |
184 | compareBuilder) { |
185 | Value result; |
186 | auto bodyBuilder = [&](uint64_t k, Value i, Value j, Value buffer) { |
187 | bool isFirstDim = (k == 0); |
188 | bool isLastDim = (k == xPerm.getNumResults() - 1); |
189 | Value val = |
190 | compareBuilder(builder, loc, i, j, buffer, isFirstDim, isLastDim); |
191 | if (isFirstDim) { |
192 | result = val; |
193 | } else if (!isLastDim) { |
194 | OpBuilder::InsertionGuard insertionGuard(builder); |
195 | auto ifOp = cast<scf::IfOp>(val.getDefiningOp()); |
196 | builder.setInsertionPointAfter(ifOp); |
197 | builder.create<scf::YieldOp>(loc, ifOp.getResult(0)); |
198 | } |
199 | }; |
200 | |
201 | forEachIJPairInXs(builder, loc, args, xPerm, ny, bodyBuilder); |
202 | |
203 | builder.setInsertionPointAfterValue(result); |
204 | return result; |
205 | } |
206 | |
207 | /// Generates code to compare whether x[i] is equal to x[j] and returns the |
208 | /// result of the comparison. |
209 | static Value createEqCompare(OpBuilder &builder, Location loc, Value i, Value j, |
210 | Value x, bool isFirstDim, bool isLastDim) { |
211 | Value vi = builder.create<memref::LoadOp>(loc, x, i); |
212 | Value vj = builder.create<memref::LoadOp>(loc, x, j); |
213 | |
214 | Value res; |
215 | if (isLastDim) { |
216 | res = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, vi, vj); |
217 | // For 1D, we create a compare without any control flow. Otherwise, we |
218 | // create YieldOp to return the result in the nested if-stmt. |
219 | if (!isFirstDim) |
220 | builder.create<scf::YieldOp>(loc, res); |
221 | } else { |
222 | Value ne = |
223 | builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne, vi, vj); |
224 | scf::IfOp ifOp = builder.create<scf::IfOp>(loc, builder.getIntegerType(1), |
225 | ne, /*else=*/true); |
226 | // If (x[i] != x[j]). |
227 | builder.setInsertionPointToStart(&ifOp.getThenRegion().front()); |
228 | Value f = constantI1(builder, loc, b: false); |
229 | builder.create<scf::YieldOp>(loc, f); |
230 | |
231 | // If (x[i] == x[j]). Set up the insertion point for the nested if-stmt that |
232 | // checks the remaining dimensions. |
233 | builder.setInsertionPointToStart(&ifOp.getElseRegion().front()); |
234 | res = ifOp.getResult(0); |
235 | } |
236 | |
237 | return res; |
238 | } |
239 | |
240 | /// Creates code to compare whether xs[i] is equal to xs[j]. |
241 | // |
242 | // The generate IR corresponds to this C like algorithm: |
243 | // if (x0[i] != x0[j]) |
244 | // return false; |
245 | // else |
246 | // if (x1[i] != x1[j]) |
247 | // return false; |
248 | // else if (x2[2] != x2[j])) |
249 | // and so on ... |
250 | static Value createInlinedEqCompare(OpBuilder &builder, Location loc, |
251 | ValueRange args, AffineMap xPerm, |
252 | uint64_t ny, uint32_t nTrailingP = 0) { |
253 | // Compare functions don't use trailing parameters. |
254 | (void)nTrailingP; |
255 | assert(nTrailingP == 0); |
256 | return createInlinedCompareImplementation(builder, loc, args, xPerm, ny, |
257 | compareBuilder: createEqCompare); |
258 | } |
259 | |
260 | /// Generates code to compare whether x[i] is less than x[j] and returns the |
261 | /// result of the comparison. |
262 | static Value createLessThanCompare(OpBuilder &builder, Location loc, Value i, |
263 | Value j, Value x, bool isFirstDim, |
264 | bool isLastDim) { |
265 | Value vi = builder.create<memref::LoadOp>(loc, x, i); |
266 | Value vj = builder.create<memref::LoadOp>(loc, x, j); |
267 | |
268 | Value res; |
269 | if (isLastDim) { |
270 | res = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, vi, vj); |
271 | // For 1D, we create a compare without any control flow. Otherwise, we |
272 | // create YieldOp to return the result in the nested if-stmt. |
273 | if (!isFirstDim) |
274 | builder.create<scf::YieldOp>(loc, res); |
275 | } else { |
276 | Value ne = |
277 | builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne, vi, vj); |
278 | scf::IfOp ifOp = builder.create<scf::IfOp>(loc, builder.getIntegerType(1), |
279 | ne, /*else=*/true); |
280 | // If (x[i] != x[j]). |
281 | builder.setInsertionPointToStart(&ifOp.getThenRegion().front()); |
282 | Value lt = |
283 | builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, vi, vj); |
284 | builder.create<scf::YieldOp>(loc, lt); |
285 | |
286 | // If (x[i] == x[j]). Set up the insertion point for the nested if-stmt that |
287 | // checks the remaining dimensions. |
288 | builder.setInsertionPointToStart(&ifOp.getElseRegion().front()); |
289 | res = ifOp.getResult(0); |
290 | } |
291 | |
292 | return res; |
293 | } |
294 | |
295 | /// Creates code to compare whether xs[i] is less than xs[j]. |
296 | // |
297 | // The generate IR corresponds to this C like algorithm: |
298 | // if (x0[i] != x0[j]) |
299 | // return x0[i] < x0[j]; |
300 | // else if (x1[j] != x1[i]) |
301 | // return x1[i] < x1[j]; |
302 | // else |
303 | // and so on ... |
304 | static Value createInlinedLessThan(OpBuilder &builder, Location loc, |
305 | ValueRange args, AffineMap xPerm, |
306 | uint64_t ny, uint32_t nTrailingP = 0) { |
307 | // Compare functions don't use trailing parameters. |
308 | (void)nTrailingP; |
309 | assert(nTrailingP == 0); |
310 | return createInlinedCompareImplementation(builder, loc, args, xPerm, ny, |
311 | compareBuilder: createLessThanCompare); |
312 | } |
313 | |
314 | /// Creates a function to use a binary search to find the insertion point for |
315 | /// inserting xs[hi] to the sorted values xs[lo..hi). |
316 | // |
317 | // The generate IR corresponds to this C like algorithm: |
318 | // p = hi |
319 | // while (lo < hi) |
320 | // mid = (lo + hi) >> 1 |
321 | // if (xs[p] < xs[mid]) |
322 | // hi = mid |
323 | // else |
324 | // lo = mid - 1 |
325 | // return lo; |
326 | // |
327 | static void createBinarySearchFunc(OpBuilder &builder, ModuleOp module, |
328 | func::FuncOp func, AffineMap xPerm, |
329 | uint64_t ny, uint32_t nTrailingP = 0) { |
330 | // Binary search doesn't use trailing parameters. |
331 | (void)nTrailingP; |
332 | assert(nTrailingP == 0); |
333 | OpBuilder::InsertionGuard insertionGuard(builder); |
334 | Block *entryBlock = func.addEntryBlock(); |
335 | builder.setInsertionPointToStart(entryBlock); |
336 | |
337 | Location loc = func.getLoc(); |
338 | ValueRange args = entryBlock->getArguments(); |
339 | Value p = args[hiIdx]; |
340 | SmallVector<Type, 2> types(2, p.getType()); // Only two types. |
341 | scf::WhileOp whileOp = builder.create<scf::WhileOp>( |
342 | loc, types, SmallVector<Value, 2>{args[loIdx], args[hiIdx]}); |
343 | |
344 | // The before-region of the WhileOp. |
345 | Block *before = |
346 | builder.createBlock(&whileOp.getBefore(), {}, types, {loc, loc}); |
347 | builder.setInsertionPointToEnd(before); |
348 | Value cond1 = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, |
349 | before->getArgument(0), |
350 | before->getArgument(1)); |
351 | builder.create<scf::ConditionOp>(loc, cond1, before->getArguments()); |
352 | |
353 | // The after-region of the WhileOp. |
354 | Block *after = |
355 | builder.createBlock(&whileOp.getAfter(), {}, types, {loc, loc}); |
356 | builder.setInsertionPointToEnd(after); |
357 | Value lo = after->getArgument(i: 0); |
358 | Value hi = after->getArgument(i: 1); |
359 | // Compute mid = (lo + hi) >> 1. |
360 | Value c1 = constantIndex(builder, loc, i: 1); |
361 | Value mid = builder.create<arith::ShRUIOp>( |
362 | loc, builder.create<arith::AddIOp>(loc, lo, hi), c1); |
363 | Value midp1 = builder.create<arith::AddIOp>(loc, mid, c1); |
364 | |
365 | // Compare xs[p] < xs[mid]. |
366 | SmallVector<Value> compareOperands{p, mid}; |
367 | constexpr uint64_t numXBuffers = 1; |
368 | compareOperands.append(in_start: args.begin() + xStartIdx, |
369 | in_end: args.begin() + xStartIdx + numXBuffers); |
370 | Value cond2 = createInlinedLessThan(builder, loc, args: compareOperands, xPerm, ny); |
371 | // Update lo and hi for the WhileOp as follows: |
372 | // if (xs[p] < xs[mid])) |
373 | // hi = mid; |
374 | // else |
375 | // lo = mid + 1; |
376 | Value newLo = builder.create<arith::SelectOp>(loc, cond2, lo, midp1); |
377 | Value newHi = builder.create<arith::SelectOp>(loc, cond2, mid, hi); |
378 | builder.create<scf::YieldOp>(loc, ValueRange{newLo, newHi}); |
379 | |
380 | builder.setInsertionPointAfter(whileOp); |
381 | builder.create<func::ReturnOp>(loc, whileOp.getResult(0)); |
382 | } |
383 | |
384 | /// Creates code to advance i in a loop based on xs[p] as follows: |
385 | /// while (xs[i] < xs[p]) i += step (step > 0) |
386 | /// or |
387 | /// while (xs[i] > xs[p]) i += step (step < 0) |
388 | /// The routine returns i as well as a boolean value to indicate whether |
389 | /// xs[i] == xs[p]. |
390 | static std::pair<Value, Value> createScanLoop(OpBuilder &builder, |
391 | ModuleOp module, |
392 | func::FuncOp func, ValueRange xs, |
393 | Value i, Value p, AffineMap xPerm, |
394 | uint64_t ny, int step) { |
395 | Location loc = func.getLoc(); |
396 | scf::WhileOp whileOp = |
397 | builder.create<scf::WhileOp>(loc, TypeRange{i.getType()}, ValueRange{i}); |
398 | |
399 | Block *before = |
400 | builder.createBlock(&whileOp.getBefore(), {}, {i.getType()}, {loc}); |
401 | builder.setInsertionPointToEnd(before); |
402 | SmallVector<Value> compareOperands; |
403 | if (step > 0) { |
404 | compareOperands.push_back(Elt: before->getArgument(i: 0)); |
405 | compareOperands.push_back(Elt: p); |
406 | } else { |
407 | assert(step < 0); |
408 | compareOperands.push_back(Elt: p); |
409 | compareOperands.push_back(Elt: before->getArgument(i: 0)); |
410 | } |
411 | compareOperands.append(in_start: xs.begin(), in_end: xs.end()); |
412 | Value cond = createInlinedLessThan(builder, loc, args: compareOperands, xPerm, ny); |
413 | builder.create<scf::ConditionOp>(loc, cond, before->getArguments()); |
414 | |
415 | Block *after = |
416 | builder.createBlock(&whileOp.getAfter(), {}, {i.getType()}, {loc}); |
417 | builder.setInsertionPointToEnd(after); |
418 | Value cs = constantIndex(builder, loc, i: step); |
419 | i = builder.create<arith::AddIOp>(loc, after->getArgument(0), cs); |
420 | builder.create<scf::YieldOp>(loc, ValueRange{i}); |
421 | i = whileOp.getResult(0); |
422 | |
423 | builder.setInsertionPointAfter(whileOp); |
424 | compareOperands[0] = i; |
425 | compareOperands[1] = p; |
426 | Value compareEq = |
427 | createInlinedEqCompare(builder, loc, args: compareOperands, xPerm, ny); |
428 | |
429 | return std::make_pair(whileOp.getResult(0), compareEq); |
430 | } |
431 | |
432 | /// Creates and returns an IfOp to compare two elements and swap the elements |
433 | /// if compareFunc(data[b], data[a]) returns true. The new insertion point is |
434 | /// right after the swap instructions. |
435 | static scf::IfOp createCompareThenSwap(OpBuilder &builder, Location loc, |
436 | AffineMap xPerm, uint64_t ny, |
437 | SmallVectorImpl<Value> &swapOperands, |
438 | SmallVectorImpl<Value> &compareOperands, |
439 | Value a, Value b) { |
440 | // Compare(data[b], data[a]). |
441 | compareOperands[0] = b; |
442 | compareOperands[1] = a; |
443 | Value cond = createInlinedLessThan(builder, loc, args: compareOperands, xPerm, ny); |
444 | scf::IfOp ifOp = builder.create<scf::IfOp>(loc, cond, /*else=*/false); |
445 | builder.setInsertionPointToStart(&ifOp.getThenRegion().front()); |
446 | swapOperands[0] = b; |
447 | swapOperands[1] = a; |
448 | createSwap(builder, loc, args: swapOperands, xPerm, ny); |
449 | return ifOp; |
450 | } |
451 | |
452 | /// Creates code to insert the 3rd element to a list of two sorted elements. |
453 | static void createInsert3rd(OpBuilder &builder, Location loc, AffineMap xPerm, |
454 | uint64_t ny, SmallVectorImpl<Value> &swapOperands, |
455 | SmallVectorImpl<Value> &compareOperands, Value v0, |
456 | Value v1, Value v2) { |
457 | scf::IfOp ifOp = createCompareThenSwap(builder, loc, xPerm, ny, swapOperands, |
458 | compareOperands, v1, v2); |
459 | createCompareThenSwap(builder, loc, xPerm, ny, swapOperands, compareOperands, |
460 | a: v0, b: v1); |
461 | builder.setInsertionPointAfter(ifOp); |
462 | } |
463 | |
464 | /// Creates code to sort 3 elements. |
465 | static void createSort3(OpBuilder &builder, Location loc, AffineMap xPerm, |
466 | uint64_t ny, SmallVectorImpl<Value> &swapOperands, |
467 | SmallVectorImpl<Value> &compareOperands, Value v0, |
468 | Value v1, Value v2) { |
469 | // Sort the first 2 elements. |
470 | scf::IfOp ifOp1 = createCompareThenSwap(builder, loc, xPerm, ny, swapOperands, |
471 | compareOperands, v0, v1); |
472 | builder.setInsertionPointAfter(ifOp1); |
473 | |
474 | // Insert the 3th element. |
475 | createInsert3rd(builder, loc, xPerm, ny, swapOperands, compareOperands, v0, |
476 | v1, v2); |
477 | } |
478 | |
479 | /// Creates code to sort 5 elements. |
480 | static void createSort5(OpBuilder &builder, Location loc, AffineMap xPerm, |
481 | uint64_t ny, SmallVectorImpl<Value> &swapOperands, |
482 | SmallVectorImpl<Value> &compareOperands, Value v0, |
483 | Value v1, Value v2, Value v3, Value v4) { |
484 | // Sort the first 3 elements. |
485 | createSort3(builder, loc, xPerm, ny, swapOperands, compareOperands, v0, v1, |
486 | v2); |
487 | |
488 | auto insert4th = [&]() { |
489 | scf::IfOp ifOp = createCompareThenSwap( |
490 | builder, loc, xPerm, ny, swapOperands, compareOperands, v2, v3); |
491 | createInsert3rd(builder, loc, xPerm, ny, swapOperands, compareOperands, v0, |
492 | v1, v2); |
493 | builder.setInsertionPointAfter(ifOp); |
494 | }; |
495 | |
496 | // Insert the 4th element. |
497 | insert4th(); |
498 | |
499 | // Insert the 5th element. |
500 | scf::IfOp ifOp = createCompareThenSwap(builder, loc, xPerm, ny, swapOperands, |
501 | compareOperands, v3, v4); |
502 | insert4th(); |
503 | builder.setInsertionPointAfter(ifOp); |
504 | } |
505 | |
506 | /// Creates a code block to swap the values in indices lo, mi, and hi so that |
507 | /// data[lo], data[mi] and data[hi] are sorted in non-decreasing values. When |
508 | /// the number of values in range [lo, hi) is more than a threshold, we also |
509 | /// include the middle of [lo, mi) and [mi, hi) and sort a total of five values. |
510 | static void createChoosePivot(OpBuilder &builder, ModuleOp module, |
511 | func::FuncOp func, AffineMap xPerm, uint64_t ny, |
512 | Value lo, Value hi, Value mi, ValueRange args) { |
513 | SmallVector<Value> compareOperands{mi, lo}; |
514 | constexpr uint64_t numXBuffers = 1; |
515 | compareOperands.append(in_start: args.begin() + xStartIdx, |
516 | in_end: args.begin() + xStartIdx + numXBuffers); |
517 | SmallVector<Value> swapOperands{mi, lo}; |
518 | swapOperands.append(in_start: args.begin() + xStartIdx, in_end: args.end()); |
519 | Location loc = func.getLoc(); |
520 | Value c1 = constantIndex(builder, loc, i: 1); |
521 | Value hiP1 = builder.create<arith::AddIOp>(loc, hi, c1); |
522 | Value len = builder.create<arith::SubIOp>(loc, hiP1, lo); |
523 | Value lenThreshold = constantIndex(builder, loc, i: 1000); |
524 | Value lenCond = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, |
525 | len, lenThreshold); |
526 | scf::IfOp lenIf = builder.create<scf::IfOp>(loc, lenCond, /*else=*/true); |
527 | |
528 | // When len < 1000, choose pivot from median of 3 values. |
529 | builder.setInsertionPointToStart(&lenIf.getThenRegion().front()); |
530 | createSort3(builder, loc, xPerm, ny, swapOperands, compareOperands, v0: lo, v1: mi, |
531 | v2: hi); |
532 | |
533 | // When len >= 1000, choose pivot from median of 5 values. |
534 | builder.setInsertionPointToStart(&lenIf.getElseRegion().front()); |
535 | Value miP1 = builder.create<arith::AddIOp>(loc, hi, c1); |
536 | Value a = builder.create<arith::AddIOp>(loc, lo, miP1); |
537 | // Value a is the middle between [loc, mi]. |
538 | a = builder.create<arith::ShRUIOp>(loc, a, c1); |
539 | Value b = builder.create<arith::AddIOp>(loc, mi, hiP1); |
540 | // Value b is the middle between [mi, hi]. |
541 | b = builder.create<arith::ShRUIOp>(loc, b, c1); |
542 | createSort5(builder, loc, xPerm, ny, swapOperands, compareOperands, v0: lo, v1: a, v2: mi, |
543 | v3: b, v4: hi); |
544 | |
545 | builder.setInsertionPointAfter(lenIf); |
546 | } |
547 | |
548 | /// Creates a function to perform quick sort partition on the values in the |
549 | /// range of index [lo, hi), assuming lo < hi. |
550 | // |
551 | // The generated IR corresponds to this C like algorithm: |
552 | // int partition(lo, hi, xs) { |
553 | // p = (lo+hi)/2 // pivot index |
554 | // i = lo |
555 | // j = hi-1 |
556 | // while (true) do { |
557 | // while (xs[i] < xs[p]) i ++; |
558 | // i_eq = (xs[i] == xs[p]); |
559 | // while (xs[j] > xs[p]) j --; |
560 | // j_eq = (xs[j] == xs[p]); |
561 | // |
562 | // if (i >= j) return j + 1; |
563 | // |
564 | // if (i < j) { |
565 | // swap(xs[i], xs[j]) |
566 | // if (i == p) { |
567 | // p = j; |
568 | // } else if (j == p) { |
569 | // p = i; |
570 | // } |
571 | // if (i_eq && j_eq) { |
572 | // ++i; |
573 | // --j; |
574 | // } |
575 | // } |
576 | // } |
577 | // } |
578 | static void createPartitionFunc(OpBuilder &builder, ModuleOp module, |
579 | func::FuncOp func, AffineMap xPerm, uint64_t ny, |
580 | uint32_t nTrailingP = 0) { |
581 | // Quick sort partition doesn't use trailing parameters. |
582 | (void)nTrailingP; |
583 | assert(nTrailingP == 0); |
584 | OpBuilder::InsertionGuard insertionGuard(builder); |
585 | |
586 | Block *entryBlock = func.addEntryBlock(); |
587 | builder.setInsertionPointToStart(entryBlock); |
588 | |
589 | Location loc = func.getLoc(); |
590 | ValueRange args = entryBlock->getArguments(); |
591 | Value lo = args[loIdx]; |
592 | Value hi = args[hiIdx]; |
593 | Value sum = builder.create<arith::AddIOp>(loc, lo, hi); |
594 | Value c1 = constantIndex(builder, loc, i: 1); |
595 | Value p = builder.create<arith::ShRUIOp>(loc, sum, c1); |
596 | |
597 | Value i = lo; |
598 | Value j = builder.create<arith::SubIOp>(loc, hi, c1); |
599 | createChoosePivot(builder, module, func, xPerm, ny, i, j, p, args); |
600 | Value trueVal = constantI1(builder, loc, b: true); // The value for while (true) |
601 | SmallVector<Value, 4> operands{i, j, p, trueVal}; // Exactly four values. |
602 | SmallVector<Type, 4> types{i.getType(), j.getType(), p.getType(), |
603 | trueVal.getType()}; |
604 | scf::WhileOp whileOp = builder.create<scf::WhileOp>(loc, types, operands); |
605 | |
606 | // The before-region of the WhileOp. |
607 | Block *before = builder.createBlock(&whileOp.getBefore(), {}, types, |
608 | {loc, loc, loc, loc}); |
609 | builder.setInsertionPointToEnd(before); |
610 | builder.create<scf::ConditionOp>(loc, before->getArgument(3), |
611 | before->getArguments()); |
612 | |
613 | // The after-region of the WhileOp. |
614 | Block *after = |
615 | builder.createBlock(&whileOp.getAfter(), {}, types, {loc, loc, loc, loc}); |
616 | builder.setInsertionPointToEnd(after); |
617 | i = after->getArgument(i: 0); |
618 | j = after->getArgument(i: 1); |
619 | p = after->getArgument(i: 2); |
620 | |
621 | constexpr uint64_t numXBuffers = 1; |
622 | auto [iresult, iCompareEq] = |
623 | createScanLoop(builder, module, func, args.slice(xStartIdx, numXBuffers), |
624 | i, p, xPerm, ny, 1); |
625 | i = iresult; |
626 | auto [jresult, jCompareEq] = |
627 | createScanLoop(builder, module, func, args.slice(xStartIdx, numXBuffers), |
628 | j, p, xPerm, ny, -1); |
629 | j = jresult; |
630 | |
631 | // If i < j: |
632 | Value cond = |
633 | builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, i, j); |
634 | scf::IfOp ifOp = builder.create<scf::IfOp>(loc, types, cond, /*else=*/true); |
635 | builder.setInsertionPointToStart(&ifOp.getThenRegion().front()); |
636 | SmallVector<Value> swapOperands{i, j}; |
637 | swapOperands.append(in_start: args.begin() + xStartIdx, in_end: args.end()); |
638 | createSwap(builder, loc, args: swapOperands, xPerm, ny); |
639 | // If the pivot is moved, update p with the new pivot. |
640 | Value icond = |
641 | builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, i, p); |
642 | scf::IfOp ifOpI = builder.create<scf::IfOp>(loc, TypeRange{p.getType()}, |
643 | icond, /*else=*/true); |
644 | builder.setInsertionPointToStart(&ifOpI.getThenRegion().front()); |
645 | builder.create<scf::YieldOp>(loc, ValueRange{j}); |
646 | builder.setInsertionPointToStart(&ifOpI.getElseRegion().front()); |
647 | Value jcond = |
648 | builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, j, p); |
649 | scf::IfOp ifOpJ = builder.create<scf::IfOp>(loc, TypeRange{p.getType()}, |
650 | jcond, /*else=*/true); |
651 | builder.setInsertionPointToStart(&ifOpJ.getThenRegion().front()); |
652 | builder.create<scf::YieldOp>(loc, ValueRange{i}); |
653 | builder.setInsertionPointToStart(&ifOpJ.getElseRegion().front()); |
654 | builder.create<scf::YieldOp>(loc, ValueRange{p}); |
655 | builder.setInsertionPointAfter(ifOpJ); |
656 | builder.create<scf::YieldOp>(loc, ifOpJ.getResults()); |
657 | builder.setInsertionPointAfter(ifOpI); |
658 | Value compareEqIJ = |
659 | builder.create<arith::AndIOp>(loc, iCompareEq, jCompareEq); |
660 | scf::IfOp ifOp2 = builder.create<scf::IfOp>( |
661 | loc, TypeRange{i.getType(), j.getType()}, compareEqIJ, /*else=*/true); |
662 | builder.setInsertionPointToStart(&ifOp2.getThenRegion().front()); |
663 | Value i2 = builder.create<arith::AddIOp>(loc, i, c1); |
664 | Value j2 = builder.create<arith::SubIOp>(loc, j, c1); |
665 | builder.create<scf::YieldOp>(loc, ValueRange{i2, j2}); |
666 | builder.setInsertionPointToStart(&ifOp2.getElseRegion().front()); |
667 | builder.create<scf::YieldOp>(loc, ValueRange{i, j}); |
668 | builder.setInsertionPointAfter(ifOp2); |
669 | builder.create<scf::YieldOp>( |
670 | loc, |
671 | ValueRange{ifOp2.getResult(0), ifOp2.getResult(1), ifOpI.getResult(0), |
672 | /*cont=*/constantI1(builder, loc, true)}); |
673 | |
674 | // False branch for if i < j (i.e., i >= j): |
675 | builder.setInsertionPointToStart(&ifOp.getElseRegion().front()); |
676 | p = builder.create<arith::AddIOp>(loc, j, |
677 | constantOne(builder, loc, j.getType())); |
678 | builder.create<scf::YieldOp>( |
679 | loc, ValueRange{i, j, p, /*cont=*/constantI1(builder, loc, false)}); |
680 | |
681 | // Return for the whileOp. |
682 | builder.setInsertionPointAfter(ifOp); |
683 | builder.create<scf::YieldOp>(loc, ifOp.getResults()); |
684 | |
685 | // Return for the function. |
686 | builder.setInsertionPointAfter(whileOp); |
687 | builder.create<func::ReturnOp>(loc, whileOp.getResult(2)); |
688 | } |
689 | |
690 | /// Computes (n-2)/n, assuming n has index type. |
691 | static Value createSubTwoDividedByTwo(OpBuilder &builder, Location loc, |
692 | Value n) { |
693 | Value i2 = constantIndex(builder, loc, i: 2); |
694 | Value res = builder.create<arith::SubIOp>(loc, n, i2); |
695 | Value i1 = constantIndex(builder, loc, i: 1); |
696 | return builder.create<arith::ShRUIOp>(loc, res, i1); |
697 | } |
698 | |
699 | /// Creates a function to heapify the subtree with root `start` within the full |
700 | /// binary tree in the range of index [first, first + n). |
701 | // |
702 | // The generated IR corresponds to this C like algorithm: |
703 | // void shiftDown(first, start, n, data) { |
704 | // if (n >= 2) { |
705 | // child = start - first |
706 | // if ((n-2)/2 >= child) { |
707 | // // Left child exists. |
708 | // child = child * 2 + 1 // Initialize the bigger child to left child. |
709 | // childIndex = child + first |
710 | // if (child+1 < n && data[childIndex] < data[childIndex+1]) |
711 | // // Right child exits and is bigger. |
712 | // childIndex++; child++; |
713 | // // Shift data[start] down to where it belongs in the subtree. |
714 | // while (data[start] < data[childIndex) { |
715 | // swap(data[start], data[childIndex]) |
716 | // start = childIndex |
717 | // if ((n - 2)/2 >= child) { |
718 | // // Left child exists. |
719 | // child = 2*child + 1 |
720 | // childIndex = child + 1 |
721 | // if (child + 1) < n && data[childIndex] < data[childIndex+1] |
722 | // childIndex++; child++; |
723 | // } |
724 | // } |
725 | // } |
726 | // } |
727 | // } |
728 | // |
729 | static void createShiftDownFunc(OpBuilder &builder, ModuleOp module, |
730 | func::FuncOp func, AffineMap xPerm, uint64_t ny, |
731 | uint32_t nTrailingP) { |
732 | // The value n is passed in as a trailing parameter. |
733 | assert(nTrailingP == 1); |
734 | OpBuilder::InsertionGuard insertionGuard(builder); |
735 | Block *entryBlock = func.addEntryBlock(); |
736 | builder.setInsertionPointToStart(entryBlock); |
737 | |
738 | Location loc = func.getLoc(); |
739 | Value n = entryBlock->getArguments().back(); |
740 | ValueRange args = entryBlock->getArguments().drop_back(); |
741 | Value first = args[loIdx]; |
742 | Value start = args[hiIdx]; |
743 | |
744 | // If (n >= 2). |
745 | Value c2 = constantIndex(builder, loc, i: 2); |
746 | Value condN = |
747 | builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::uge, n, c2); |
748 | scf::IfOp ifN = builder.create<scf::IfOp>(loc, condN, /*else=*/false); |
749 | builder.setInsertionPointToStart(&ifN.getThenRegion().front()); |
750 | Value child = builder.create<arith::SubIOp>(loc, start, first); |
751 | |
752 | // If ((n-2)/2 >= child). |
753 | Value t = createSubTwoDividedByTwo(builder, loc, n); |
754 | Value condNc = |
755 | builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::uge, t, child); |
756 | scf::IfOp ifNc = builder.create<scf::IfOp>(loc, condNc, /*else=*/false); |
757 | |
758 | builder.setInsertionPointToStart(&ifNc.getThenRegion().front()); |
759 | Value c1 = constantIndex(builder, loc, i: 1); |
760 | SmallVector<Value> compareOperands{start, start}; |
761 | constexpr uint64_t numXBuffers = 1; |
762 | compareOperands.append(in_start: args.begin() + xStartIdx, |
763 | in_end: args.begin() + xStartIdx + numXBuffers); |
764 | |
765 | // Generate code to inspect the children of 'r' and return the larger child |
766 | // as follows: |
767 | // child = r * 2 + 1 // Left child. |
768 | // childIndex = child + first |
769 | // if (child+1 < n && data[childIndex] < data[childIndex+1]) |
770 | // childIndex ++; child ++ // Right child is bigger. |
771 | auto getLargerChild = [&](Value r) -> std::pair<Value, Value> { |
772 | Value lChild = builder.create<arith::ShLIOp>(loc, r, c1); |
773 | lChild = builder.create<arith::AddIOp>(loc, lChild, c1); |
774 | Value lChildIdx = builder.create<arith::AddIOp>(loc, lChild, first); |
775 | Value rChild = builder.create<arith::AddIOp>(loc, lChild, c1); |
776 | Value cond1 = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, |
777 | rChild, n); |
778 | SmallVector<Type, 2> ifTypes(2, r.getType()); |
779 | scf::IfOp if1 = |
780 | builder.create<scf::IfOp>(loc, ifTypes, cond1, /*else=*/true); |
781 | builder.setInsertionPointToStart(&if1.getThenRegion().front()); |
782 | Value rChildIdx = builder.create<arith::AddIOp>(loc, rChild, first); |
783 | // Compare data[left] < data[right]. |
784 | compareOperands[0] = lChildIdx; |
785 | compareOperands[1] = rChildIdx; |
786 | Value cond2 = |
787 | createInlinedLessThan(builder, loc, args: compareOperands, xPerm, ny); |
788 | scf::IfOp if2 = |
789 | builder.create<scf::IfOp>(loc, ifTypes, cond2, /*else=*/true); |
790 | builder.setInsertionPointToStart(&if2.getThenRegion().front()); |
791 | builder.create<scf::YieldOp>(loc, ValueRange{rChild, rChildIdx}); |
792 | builder.setInsertionPointToStart(&if2.getElseRegion().front()); |
793 | builder.create<scf::YieldOp>(loc, ValueRange{lChild, lChildIdx}); |
794 | builder.setInsertionPointAfter(if2); |
795 | builder.create<scf::YieldOp>(loc, if2.getResults()); |
796 | builder.setInsertionPointToStart(&if1.getElseRegion().front()); |
797 | builder.create<scf::YieldOp>(loc, ValueRange{lChild, lChildIdx}); |
798 | builder.setInsertionPointAfter(if1); |
799 | return std::make_pair(if1.getResult(0), if1.getResult(1)); |
800 | }; |
801 | |
802 | Value childIdx; |
803 | std::tie(args&: child, args&: childIdx) = getLargerChild(child); |
804 | |
805 | // While (data[start] < data[childIndex]). |
806 | SmallVector<Type, 3> types(3, child.getType()); |
807 | scf::WhileOp whileOp = builder.create<scf::WhileOp>( |
808 | loc, types, SmallVector<Value, 2>{start, child, childIdx}); |
809 | |
810 | // The before-region of the WhileOp. |
811 | SmallVector<Location, 3> locs(3, loc); |
812 | Block *before = builder.createBlock(&whileOp.getBefore(), {}, types, locs); |
813 | builder.setInsertionPointToEnd(before); |
814 | start = before->getArgument(i: 0); |
815 | childIdx = before->getArgument(i: 2); |
816 | compareOperands[0] = start; |
817 | compareOperands[1] = childIdx; |
818 | Value cond = createInlinedLessThan(builder, loc, args: compareOperands, xPerm, ny); |
819 | builder.create<scf::ConditionOp>(loc, cond, before->getArguments()); |
820 | |
821 | // The after-region of the WhileOp. |
822 | Block *after = builder.createBlock(&whileOp.getAfter(), {}, types, locs); |
823 | start = after->getArgument(i: 0); |
824 | child = after->getArgument(i: 1); |
825 | childIdx = after->getArgument(i: 2); |
826 | SmallVector<Value> swapOperands{start, childIdx}; |
827 | swapOperands.append(in_start: args.begin() + xStartIdx, in_end: args.end()); |
828 | createSwap(builder, loc, args: swapOperands, xPerm, ny); |
829 | start = childIdx; |
830 | Value cond2 = |
831 | builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::uge, t, child); |
832 | scf::IfOp if2 = builder.create<scf::IfOp>( |
833 | loc, TypeRange{child.getType(), child.getType()}, cond2, /*else=*/true); |
834 | builder.setInsertionPointToStart(&if2.getThenRegion().front()); |
835 | auto [newChild, newChildIdx] = getLargerChild(child); |
836 | builder.create<scf::YieldOp>(loc, ValueRange{newChild, newChildIdx}); |
837 | builder.setInsertionPointToStart(&if2.getElseRegion().front()); |
838 | builder.create<scf::YieldOp>(loc, ValueRange{child, childIdx}); |
839 | builder.setInsertionPointAfter(if2); |
840 | builder.create<scf::YieldOp>( |
841 | loc, ValueRange{start, if2.getResult(0), if2.getResult(1)}); |
842 | |
843 | builder.setInsertionPointAfter(ifN); |
844 | builder.create<func::ReturnOp>(loc); |
845 | } |
846 | |
847 | /// Creates a function to perform heap sort on the values in the range of index |
848 | /// [lo, hi) with the assumption hi - lo >= 2. |
849 | // |
850 | // The generate IR corresponds to this C like algorithm: |
851 | // void heapSort(lo, hi, data) { |
852 | // n = hi - lo |
853 | // for i = (n-2)/2 downto 0 |
854 | // shiftDown(lo, lo+i, n) |
855 | // |
856 | // for l = n downto 2 |
857 | // swap(lo, lo+l-1) |
858 | // shiftdown(lo, lo, l-1) |
859 | // } |
860 | static void createHeapSortFunc(OpBuilder &builder, ModuleOp module, |
861 | func::FuncOp func, AffineMap xPerm, uint64_t ny, |
862 | uint32_t nTrailingP) { |
863 | // Heap sort function doesn't have trailing parameters. |
864 | (void)nTrailingP; |
865 | assert(nTrailingP == 0); |
866 | OpBuilder::InsertionGuard insertionGuard(builder); |
867 | Block *entryBlock = func.addEntryBlock(); |
868 | builder.setInsertionPointToStart(entryBlock); |
869 | |
870 | Location loc = func.getLoc(); |
871 | ValueRange args = entryBlock->getArguments(); |
872 | Value lo = args[loIdx]; |
873 | Value hi = args[hiIdx]; |
874 | Value n = builder.create<arith::SubIOp>(loc, hi, lo); |
875 | |
876 | // For i = (n-2)/2 downto 0. |
877 | Value c0 = constantIndex(builder, loc, i: 0); |
878 | Value c1 = constantIndex(builder, loc, i: 1); |
879 | Value s = createSubTwoDividedByTwo(builder, loc, n); |
880 | Value up = builder.create<arith::AddIOp>(loc, s, c1); |
881 | scf::ForOp forI = builder.create<scf::ForOp>(loc, c0, up, c1); |
882 | builder.setInsertionPointToStart(forI.getBody()); |
883 | Value i = builder.create<arith::SubIOp>(loc, s, forI.getInductionVar()); |
884 | Value lopi = builder.create<arith::AddIOp>(loc, lo, i); |
885 | SmallVector<Value> shiftDownOperands = {lo, lopi}; |
886 | shiftDownOperands.append(in_start: args.begin() + xStartIdx, in_end: args.end()); |
887 | shiftDownOperands.push_back(Elt: n); |
888 | FlatSymbolRefAttr shiftDownFunc = getMangledSortHelperFunc( |
889 | builder, func, TypeRange(), kShiftDownFuncNamePrefix, xPerm, ny, |
890 | shiftDownOperands, createShiftDownFunc, /*nTrailingP=*/1); |
891 | builder.create<func::CallOp>(loc, shiftDownFunc, TypeRange(), |
892 | shiftDownOperands); |
893 | |
894 | builder.setInsertionPointAfter(forI); |
895 | // For l = n downto 2. |
896 | up = builder.create<arith::SubIOp>(loc, n, c1); |
897 | scf::ForOp forL = builder.create<scf::ForOp>(loc, c0, up, c1); |
898 | builder.setInsertionPointToStart(forL.getBody()); |
899 | Value l = builder.create<arith::SubIOp>(loc, n, forL.getInductionVar()); |
900 | Value loplm1 = builder.create<arith::AddIOp>(loc, lo, l); |
901 | loplm1 = builder.create<arith::SubIOp>(loc, loplm1, c1); |
902 | SmallVector<Value> swapOperands{lo, loplm1}; |
903 | swapOperands.append(in_start: args.begin() + xStartIdx, in_end: args.end()); |
904 | createSwap(builder, loc, args: swapOperands, xPerm, ny); |
905 | shiftDownOperands[1] = lo; |
906 | shiftDownOperands[shiftDownOperands.size() - 1] = |
907 | builder.create<arith::SubIOp>(loc, l, c1); |
908 | builder.create<func::CallOp>(loc, shiftDownFunc, TypeRange(), |
909 | shiftDownOperands); |
910 | |
911 | builder.setInsertionPointAfter(forL); |
912 | builder.create<func::ReturnOp>(loc); |
913 | } |
914 | |
915 | /// A helper for generating code to perform quick sort. It partitions [lo, hi), |
916 | /// recursively calls quick sort to process the smaller partition and returns |
917 | /// the bigger partition to be processed by the enclosed while-loop. |
918 | static std::pair<Value, Value> |
919 | createQuickSort(OpBuilder &builder, ModuleOp module, func::FuncOp func, |
920 | ValueRange args, AffineMap xPerm, uint64_t ny, |
921 | uint32_t nTrailingP) { |
922 | MLIRContext *context = module.getContext(); |
923 | Location loc = func.getLoc(); |
924 | Value lo = args[loIdx]; |
925 | Value hi = args[hiIdx]; |
926 | SmallVector<Type, 2> types(2, lo.getType()); // Only two types. |
927 | |
928 | FlatSymbolRefAttr partitionFunc = getMangledSortHelperFunc( |
929 | builder, func, {IndexType::get(context)}, kPartitionFuncNamePrefix, xPerm, |
930 | ny, args.drop_back(nTrailingP), createPartitionFunc); |
931 | Value p = builder |
932 | .create<func::CallOp>(loc, partitionFunc, |
933 | TypeRange{IndexType::get(context)}, |
934 | args.drop_back(nTrailingP)) |
935 | .getResult(0); |
936 | |
937 | Value lenLow = builder.create<arith::SubIOp>(loc, p, lo); |
938 | Value lenHigh = builder.create<arith::SubIOp>(loc, hi, p); |
939 | // Partition already sorts array with len <= 2 |
940 | Value c2 = constantIndex(builder, loc, i: 2); |
941 | Value len = builder.create<arith::SubIOp>(loc, hi, lo); |
942 | Value lenGtTwo = |
943 | builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ugt, len, c2); |
944 | scf::IfOp ifLenGtTwo = |
945 | builder.create<scf::IfOp>(loc, types, lenGtTwo, /*else=*/true); |
946 | builder.setInsertionPointToStart(&ifLenGtTwo.getElseRegion().front()); |
947 | // Returns an empty range to mark the entire region is fully sorted. |
948 | builder.create<scf::YieldOp>(loc, ValueRange{lo, lo}); |
949 | |
950 | // Else len > 2, need recursion. |
951 | builder.setInsertionPointToStart(&ifLenGtTwo.getThenRegion().front()); |
952 | Value cond = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ule, |
953 | lenLow, lenHigh); |
954 | |
955 | Value c0 = constantIndex(builder, loc, i: 0); |
956 | scf::IfOp ifOp = builder.create<scf::IfOp>(loc, types, cond, /*else=*/true); |
957 | |
958 | auto mayRecursion = [&](Value low, Value high, Value len) { |
959 | Value cond = |
960 | builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne, len, c0); |
961 | scf::IfOp ifOp = builder.create<scf::IfOp>(loc, cond, /*else=*/false); |
962 | builder.setInsertionPointToStart(&ifOp.getThenRegion().front()); |
963 | SmallVector<Value> operands{low, high}; |
964 | operands.append(in_start: args.begin() + xStartIdx, in_end: args.end()); |
965 | builder.create<func::CallOp>(loc, func, operands); |
966 | builder.setInsertionPointAfter(ifOp); |
967 | }; |
968 | |
969 | // Recursively call quickSort to process the smaller partition and return |
970 | // the bigger partition to be processed by the enclosed while-loop. |
971 | builder.setInsertionPointToStart(&ifOp.getThenRegion().front()); |
972 | mayRecursion(lo, p, lenLow); |
973 | builder.create<scf::YieldOp>(loc, ValueRange{p, hi}); |
974 | |
975 | builder.setInsertionPointToStart(&ifOp.getElseRegion().front()); |
976 | mayRecursion(p, hi, lenHigh); |
977 | builder.create<scf::YieldOp>(loc, ValueRange{lo, p}); |
978 | |
979 | builder.setInsertionPointAfter(ifOp); |
980 | builder.create<scf::YieldOp>(loc, ifOp.getResults()); |
981 | |
982 | builder.setInsertionPointAfter(ifLenGtTwo); |
983 | return std::make_pair(ifLenGtTwo.getResult(0), ifLenGtTwo.getResult(1)); |
984 | } |
985 | |
986 | /// Creates a function to perform insertion sort on the values in the range of |
987 | /// index [lo, hi). |
988 | // |
989 | // The generate IR corresponds to this C like algorithm: |
990 | // void insertionSort(lo, hi, data) { |
991 | // for (i = lo+1; i < hi; i++) { |
992 | // d = data[i]; |
993 | // p = binarySearch(lo, i-1, data) |
994 | // for (j = 0; j > i - p; j++) |
995 | // data[i-j] = data[i-j-1] |
996 | // data[p] = d |
997 | // } |
998 | // } |
999 | static void createSortStableFunc(OpBuilder &builder, ModuleOp module, |
1000 | func::FuncOp func, AffineMap xPerm, |
1001 | uint64_t ny, uint32_t nTrailingP) { |
1002 | // Stable sort function doesn't use trailing parameters. |
1003 | (void)nTrailingP; |
1004 | assert(nTrailingP == 0); |
1005 | OpBuilder::InsertionGuard insertionGuard(builder); |
1006 | Block *entryBlock = func.addEntryBlock(); |
1007 | builder.setInsertionPointToStart(entryBlock); |
1008 | |
1009 | MLIRContext *context = module.getContext(); |
1010 | Location loc = func.getLoc(); |
1011 | ValueRange args = entryBlock->getArguments(); |
1012 | Value c1 = constantIndex(builder, loc, i: 1); |
1013 | Value lo = args[loIdx]; |
1014 | Value hi = args[hiIdx]; |
1015 | Value lop1 = builder.create<arith::AddIOp>(loc, lo, c1); |
1016 | |
1017 | // Start the outer for-stmt with induction variable i. |
1018 | scf::ForOp forOpI = builder.create<scf::ForOp>(loc, lop1, hi, c1); |
1019 | builder.setInsertionPointToStart(forOpI.getBody()); |
1020 | Value i = forOpI.getInductionVar(); |
1021 | |
1022 | // Binary search to find the insertion point p. |
1023 | SmallVector<Value> operands{lo, i}; |
1024 | operands.append(in_start: args.begin() + xStartIdx, in_end: args.end()); |
1025 | FlatSymbolRefAttr searchFunc = getMangledSortHelperFunc( |
1026 | builder, func, {IndexType::get(context)}, kBinarySearchFuncNamePrefix, |
1027 | xPerm, ny, operands, createBinarySearchFunc); |
1028 | Value p = builder |
1029 | .create<func::CallOp>(loc, searchFunc, TypeRange{c1.getType()}, |
1030 | operands) |
1031 | .getResult(0); |
1032 | |
1033 | // Move the value at data[i] to a temporary location. |
1034 | operands[0] = operands[1] = i; |
1035 | SmallVector<Value> d; |
1036 | forEachIJPairInAllBuffers( |
1037 | builder, loc, args: operands, xPerm, ny, |
1038 | bodyBuilder: [&](uint64_t unused, Value i, Value unused2, Value buffer) { |
1039 | d.push_back(builder.create<memref::LoadOp>(loc, buffer, i)); |
1040 | }); |
1041 | |
1042 | // Start the inner for-stmt with induction variable j, for moving data[p..i) |
1043 | // to data[p+1..i+1). |
1044 | Value imp = builder.create<arith::SubIOp>(loc, i, p); |
1045 | Value c0 = constantIndex(builder, loc, i: 0); |
1046 | scf::ForOp forOpJ = builder.create<scf::ForOp>(loc, c0, imp, c1); |
1047 | builder.setInsertionPointToStart(forOpJ.getBody()); |
1048 | Value j = forOpJ.getInductionVar(); |
1049 | Value imj = builder.create<arith::SubIOp>(loc, i, j); |
1050 | operands[1] = imj; |
1051 | operands[0] = builder.create<arith::SubIOp>(loc, imj, c1); |
1052 | forEachIJPairInAllBuffers( |
1053 | builder, loc, args: operands, xPerm, ny, |
1054 | bodyBuilder: [&](uint64_t unused, Value imjm1, Value imj, Value buffer) { |
1055 | Value t = builder.create<memref::LoadOp>(loc, buffer, imjm1); |
1056 | builder.create<memref::StoreOp>(loc, t, buffer, imj); |
1057 | }); |
1058 | |
1059 | // Store the value at data[i] to data[p]. |
1060 | builder.setInsertionPointAfter(forOpJ); |
1061 | operands[0] = operands[1] = p; |
1062 | forEachIJPairInAllBuffers( |
1063 | builder, loc, args: operands, xPerm, ny, |
1064 | bodyBuilder: [&](uint64_t k, Value p, Value usused, Value buffer) { |
1065 | builder.create<memref::StoreOp>(loc, d[k], buffer, p); |
1066 | }); |
1067 | |
1068 | builder.setInsertionPointAfter(forOpI); |
1069 | builder.create<func::ReturnOp>(loc); |
1070 | } |
1071 | |
1072 | /// Creates a function to perform quick sort or a hybrid quick sort on the |
1073 | /// values in the range of index [lo, hi). |
1074 | // |
1075 | // |
1076 | // When nTrailingP == 0, the generated IR corresponds to this C like algorithm: |
1077 | // void quickSort(lo, hi, data) { |
1078 | // while (lo + 1 < hi) { |
1079 | // p = partition(low, high, data); |
1080 | // if (len(lo, p) < len(p+1, hi)) { |
1081 | // quickSort(lo, p, data); |
1082 | // lo = p+1; |
1083 | // } else { |
1084 | // quickSort(p + 1, hi, data); |
1085 | // hi = p; |
1086 | // } |
1087 | // } |
1088 | // } |
1089 | // |
1090 | // When nTrailingP == 1, the generated IR corresponds to this C like algorithm: |
1091 | // void hybridQuickSort(lo, hi, data, depthLimit) { |
1092 | // while (lo + 1 < hi) { |
1093 | // len = hi - lo; |
1094 | // if (len <= limit) { |
1095 | // insertionSort(lo, hi, data); |
1096 | // } else { |
1097 | // depthLimit --; |
1098 | // if (depthLimit <= 0) { |
1099 | // heapSort(lo, hi, data); |
1100 | // } else { |
1101 | // p = partition(low, high, data); |
1102 | // if (len(lo, p) < len(p+1, hi)) { |
1103 | // quickSort(lo, p, data, depthLimit); |
1104 | // lo = p+1; |
1105 | // } else { |
1106 | // quickSort(p + 1, hi, data, depthLimit); |
1107 | // hi = p; |
1108 | // } |
1109 | // } |
1110 | // } |
1111 | // } |
1112 | // } |
1113 | // |
1114 | static void createQuickSortFunc(OpBuilder &builder, ModuleOp module, |
1115 | func::FuncOp func, AffineMap xPerm, uint64_t ny, |
1116 | uint32_t nTrailingP) { |
1117 | assert(nTrailingP == 1 || nTrailingP == 0); |
1118 | bool isHybrid = (nTrailingP == 1); |
1119 | OpBuilder::InsertionGuard insertionGuard(builder); |
1120 | Block *entryBlock = func.addEntryBlock(); |
1121 | builder.setInsertionPointToStart(entryBlock); |
1122 | |
1123 | Location loc = func.getLoc(); |
1124 | SmallVector<Value> args; |
1125 | args.append(in_start: entryBlock->getArguments().begin(), |
1126 | in_end: entryBlock->getArguments().end()); |
1127 | Value lo = args[loIdx]; |
1128 | Value hi = args[hiIdx]; |
1129 | SmallVector<Type, 2> types(2, lo.getType()); // Only two types. |
1130 | scf::WhileOp whileOp = |
1131 | builder.create<scf::WhileOp>(loc, types, SmallVector<Value, 2>{lo, hi}); |
1132 | |
1133 | // The before-region of the WhileOp. |
1134 | Block *before = |
1135 | builder.createBlock(&whileOp.getBefore(), {}, types, {loc, loc}); |
1136 | builder.setInsertionPointToEnd(before); |
1137 | lo = before->getArgument(i: 0); |
1138 | hi = before->getArgument(i: 1); |
1139 | Value loP1 = |
1140 | builder.create<arith::AddIOp>(loc, lo, constantIndex(builder, loc, 1)); |
1141 | Value needSort = |
1142 | builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, loP1, hi); |
1143 | builder.create<scf::ConditionOp>(loc, needSort, before->getArguments()); |
1144 | |
1145 | // The after-region of the WhileOp. |
1146 | Block *after = |
1147 | builder.createBlock(&whileOp.getAfter(), {}, types, {loc, loc}); |
1148 | builder.setInsertionPointToEnd(after); |
1149 | lo = after->getArgument(i: 0); |
1150 | hi = after->getArgument(i: 1); |
1151 | args[0] = lo; |
1152 | args[1] = hi; |
1153 | |
1154 | if (isHybrid) { |
1155 | Value len = builder.create<arith::SubIOp>(loc, hi, lo); |
1156 | Value lenLimit = constantIndex(builder, loc, i: 30); |
1157 | Value lenCond = builder.create<arith::CmpIOp>( |
1158 | loc, arith::CmpIPredicate::ule, len, lenLimit); |
1159 | scf::IfOp lenIf = |
1160 | builder.create<scf::IfOp>(loc, types, lenCond, /*else=*/true); |
1161 | |
1162 | // When len <= limit. |
1163 | builder.setInsertionPointToStart(&lenIf.getThenRegion().front()); |
1164 | FlatSymbolRefAttr insertionSortFunc = getMangledSortHelperFunc( |
1165 | builder, func, TypeRange(), kSortStableFuncNamePrefix, xPerm, ny, |
1166 | ValueRange(args).drop_back(nTrailingP), createSortStableFunc); |
1167 | builder.create<func::CallOp>(loc, insertionSortFunc, TypeRange(), |
1168 | ValueRange(args).drop_back(nTrailingP)); |
1169 | builder.create<scf::YieldOp>(loc, ValueRange{lo, lo}); |
1170 | |
1171 | // When len > limit. |
1172 | builder.setInsertionPointToStart(&lenIf.getElseRegion().front()); |
1173 | Value depthLimit = args.back(); |
1174 | depthLimit = builder.create<arith::SubIOp>(loc, depthLimit, |
1175 | constantI64(builder, loc, 1)); |
1176 | Value depthCond = |
1177 | builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ule, |
1178 | depthLimit, constantI64(builder, loc, 0)); |
1179 | scf::IfOp depthIf = |
1180 | builder.create<scf::IfOp>(loc, types, depthCond, /*else=*/true); |
1181 | |
1182 | // When depth exceeds limit. |
1183 | builder.setInsertionPointToStart(&depthIf.getThenRegion().front()); |
1184 | FlatSymbolRefAttr heapSortFunc = getMangledSortHelperFunc( |
1185 | builder, func, TypeRange(), kHeapSortFuncNamePrefix, xPerm, ny, |
1186 | ValueRange(args).drop_back(nTrailingP), createHeapSortFunc); |
1187 | builder.create<func::CallOp>(loc, heapSortFunc, TypeRange(), |
1188 | ValueRange(args).drop_back(nTrailingP)); |
1189 | builder.create<scf::YieldOp>(loc, ValueRange{lo, lo}); |
1190 | |
1191 | // When depth doesn't exceed limit. |
1192 | builder.setInsertionPointToStart(&depthIf.getElseRegion().front()); |
1193 | args.back() = depthLimit; |
1194 | std::tie(lo, hi) = |
1195 | createQuickSort(builder, module, func, args, xPerm, ny, nTrailingP); |
1196 | builder.create<scf::YieldOp>(loc, ValueRange{lo, hi}); |
1197 | |
1198 | builder.setInsertionPointAfter(depthIf); |
1199 | lo = depthIf.getResult(0); |
1200 | hi = depthIf.getResult(1); |
1201 | builder.create<scf::YieldOp>(loc, ValueRange{lo, hi}); |
1202 | |
1203 | builder.setInsertionPointAfter(lenIf); |
1204 | lo = lenIf.getResult(0); |
1205 | hi = lenIf.getResult(1); |
1206 | } else { |
1207 | std::tie(lo, hi) = |
1208 | createQuickSort(builder, module, func, args, xPerm, ny, nTrailingP); |
1209 | } |
1210 | |
1211 | // New [lo, hi) for the next while-loop iteration. |
1212 | builder.create<scf::YieldOp>(loc, ValueRange{lo, hi}); |
1213 | |
1214 | // After the while-loop. |
1215 | builder.setInsertionPointAfter(whileOp); |
1216 | builder.create<func::ReturnOp>(loc); |
1217 | } |
1218 | |
1219 | /// Implements the rewriting for operator sort and sort_coo. |
1220 | template <typename OpTy> |
1221 | LogicalResult matchAndRewriteSortOp(OpTy op, ValueRange xys, AffineMap xPerm, |
1222 | uint64_t ny, PatternRewriter &rewriter) { |
1223 | Location loc = op.getLoc(); |
1224 | SmallVector<Value> operands{constantIndex(builder&: rewriter, loc, i: 0), op.getN()}; |
1225 | |
1226 | // Convert `values` to have dynamic shape and append them to `operands`. |
1227 | for (Value v : xys) { |
1228 | auto mtp = getMemRefType(v); |
1229 | if (!mtp.isDynamicDim(0)) { |
1230 | auto newMtp = |
1231 | MemRefType::get({ShapedType::kDynamic}, mtp.getElementType()); |
1232 | v = rewriter.create<memref::CastOp>(loc, newMtp, v); |
1233 | } |
1234 | operands.push_back(Elt: v); |
1235 | } |
1236 | |
1237 | auto insertPoint = op->template getParentOfType<func::FuncOp>(); |
1238 | if (!insertPoint) |
1239 | return failure(); |
1240 | |
1241 | SmallString<32> funcName; |
1242 | FuncGeneratorType funcGenerator; |
1243 | uint32_t nTrailingP = 0; |
1244 | switch (op.getAlgorithm()) { |
1245 | case SparseTensorSortKind::HybridQuickSort: { |
1246 | funcName = kHybridQuickSortFuncNamePrefix; |
1247 | funcGenerator = createQuickSortFunc; |
1248 | nTrailingP = 1; |
1249 | // As a heuristics, set depthLimit = 2 * log2(n). |
1250 | Value lo = operands[loIdx]; |
1251 | Value hi = operands[hiIdx]; |
1252 | Value len = rewriter.create<arith::IndexCastOp>( |
1253 | loc, rewriter.getI64Type(), |
1254 | rewriter.create<arith::SubIOp>(loc, hi, lo)); |
1255 | Value depthLimit = rewriter.create<arith::SubIOp>( |
1256 | loc, constantI64(rewriter, loc, 64), |
1257 | rewriter.create<math::CountLeadingZerosOp>(loc, len)); |
1258 | operands.push_back(Elt: depthLimit); |
1259 | break; |
1260 | } |
1261 | case SparseTensorSortKind::QuickSort: |
1262 | funcName = kQuickSortFuncNamePrefix; |
1263 | funcGenerator = createQuickSortFunc; |
1264 | break; |
1265 | case SparseTensorSortKind::InsertionSortStable: |
1266 | funcName = kSortStableFuncNamePrefix; |
1267 | funcGenerator = createSortStableFunc; |
1268 | break; |
1269 | case SparseTensorSortKind::HeapSort: |
1270 | funcName = kHeapSortFuncNamePrefix; |
1271 | funcGenerator = createHeapSortFunc; |
1272 | break; |
1273 | } |
1274 | |
1275 | FlatSymbolRefAttr func = |
1276 | getMangledSortHelperFunc(rewriter, insertPoint, TypeRange(), funcName, |
1277 | xPerm, ny, operands, funcGenerator, nTrailingP); |
1278 | rewriter.replaceOpWithNewOp<func::CallOp>(op, func, TypeRange(), operands); |
1279 | return success(); |
1280 | } |
1281 | |
1282 | //===---------------------------------------------------------------------===// |
1283 | // The actual sparse buffer rewriting rules. |
1284 | //===---------------------------------------------------------------------===// |
1285 | |
1286 | namespace { |
1287 | /// Sparse rewriting rule for the push_back operator. |
1288 | struct PushBackRewriter : OpRewritePattern<PushBackOp> { |
1289 | public: |
1290 | using OpRewritePattern<PushBackOp>::OpRewritePattern; |
1291 | PushBackRewriter(MLIRContext *context, bool enableInit) |
1292 | : OpRewritePattern(context), enableBufferInitialization(enableInit) {} |
1293 | LogicalResult matchAndRewrite(PushBackOp op, |
1294 | PatternRewriter &rewriter) const override { |
1295 | // Rewrite push_back(buffer, value, n) to: |
1296 | // new_size = size(buffer) + n |
1297 | // if (new_size > capacity(buffer)) |
1298 | // while new_size > new_capacity |
1299 | // new_capacity = new_capacity*2 |
1300 | // new_buffer = realloc(buffer, new_capacity) |
1301 | // buffer = new_buffer |
1302 | // subBuffer = subviewof(buffer) |
1303 | // linalg.fill subBuffer value |
1304 | // |
1305 | // size(buffer) += n |
1306 | // |
1307 | // The capacity check is skipped when the attribute inbounds is presented. |
1308 | Location loc = op->getLoc(); |
1309 | Value c0 = constantIndex(builder&: rewriter, loc, i: 0); |
1310 | Value buffer = op.getInBuffer(); |
1311 | Value capacity = rewriter.create<memref::DimOp>(loc, buffer, c0); |
1312 | Value size = op.getCurSize(); |
1313 | Value value = op.getValue(); |
1314 | |
1315 | Value n = op.getN() ? op.getN() : constantIndex(builder&: rewriter, loc, i: 1); |
1316 | Value newSize = rewriter.create<arith::AddIOp>(loc, size, n); |
1317 | auto nValue = dyn_cast_or_null<arith::ConstantIndexOp>(Val: n.getDefiningOp()); |
1318 | bool nIsOne = (nValue && nValue.value() == 1); |
1319 | |
1320 | if (!op.getInbounds()) { |
1321 | Value cond = rewriter.create<arith::CmpIOp>( |
1322 | loc, arith::CmpIPredicate::ugt, newSize, capacity); |
1323 | |
1324 | Value c2 = constantIndex(builder&: rewriter, loc, i: 2); |
1325 | auto bufferType = |
1326 | MemRefType::get({ShapedType::kDynamic}, value.getType()); |
1327 | scf::IfOp ifOp = rewriter.create<scf::IfOp>(loc, bufferType, cond, |
1328 | /*else=*/true); |
1329 | // True branch. |
1330 | rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front()); |
1331 | if (nIsOne) { |
1332 | capacity = rewriter.create<arith::MulIOp>(loc, capacity, c2); |
1333 | } else { |
1334 | // Use a do-while loop to calculate the new capacity as follows: |
1335 | // do { new_capacity *= 2 } while (size > new_capacity) |
1336 | scf::WhileOp whileOp = |
1337 | rewriter.create<scf::WhileOp>(loc, capacity.getType(), capacity); |
1338 | |
1339 | // The before-region of the WhileOp. |
1340 | Block *before = rewriter.createBlock(&whileOp.getBefore(), {}, |
1341 | {capacity.getType()}, {loc}); |
1342 | rewriter.setInsertionPointToEnd(before); |
1343 | |
1344 | capacity = |
1345 | rewriter.create<arith::MulIOp>(loc, before->getArgument(0), c2); |
1346 | cond = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ugt, |
1347 | newSize, capacity); |
1348 | rewriter.create<scf::ConditionOp>(loc, cond, ValueRange{capacity}); |
1349 | // The after-region of the WhileOp. |
1350 | Block *after = rewriter.createBlock(&whileOp.getAfter(), {}, |
1351 | {capacity.getType()}, {loc}); |
1352 | rewriter.setInsertionPointToEnd(after); |
1353 | rewriter.create<scf::YieldOp>(loc, after->getArguments()); |
1354 | |
1355 | rewriter.setInsertionPointAfter(whileOp); |
1356 | capacity = whileOp.getResult(0); |
1357 | } |
1358 | |
1359 | Value newBuffer = |
1360 | rewriter.create<memref::ReallocOp>(loc, bufferType, buffer, capacity); |
1361 | if (enableBufferInitialization) { |
1362 | Value fillSize = rewriter.create<arith::SubIOp>(loc, capacity, newSize); |
1363 | Value fillValue = constantZero(builder&: rewriter, loc, tp: value.getType()); |
1364 | Value subBuffer = rewriter.create<memref::SubViewOp>( |
1365 | loc, newBuffer, /*offset=*/ValueRange{newSize}, |
1366 | /*size=*/ValueRange{fillSize}, |
1367 | /*step=*/ValueRange{constantIndex(rewriter, loc, 1)}); |
1368 | rewriter.create<linalg::FillOp>(loc, fillValue, subBuffer); |
1369 | } |
1370 | rewriter.create<scf::YieldOp>(loc, newBuffer); |
1371 | |
1372 | // False branch. |
1373 | rewriter.setInsertionPointToStart(&ifOp.getElseRegion().front()); |
1374 | rewriter.create<scf::YieldOp>(loc, buffer); |
1375 | |
1376 | // Prepare for adding the value to the end of the buffer. |
1377 | rewriter.setInsertionPointAfter(ifOp); |
1378 | buffer = ifOp.getResult(0); |
1379 | } |
1380 | |
1381 | // Add the value to the end of the buffer. |
1382 | if (nIsOne) { |
1383 | rewriter.create<memref::StoreOp>(loc, value, buffer, size); |
1384 | } else { |
1385 | Value subBuffer = rewriter.create<memref::SubViewOp>( |
1386 | loc, buffer, /*offset=*/ValueRange{size}, /*size=*/ValueRange{n}, |
1387 | /*step=*/ValueRange{constantIndex(rewriter, loc, 1)}); |
1388 | rewriter.create<linalg::FillOp>(loc, value, subBuffer); |
1389 | } |
1390 | |
1391 | // Update the buffer size. |
1392 | rewriter.replaceOp(op, {buffer, newSize}); |
1393 | return success(); |
1394 | } |
1395 | |
1396 | private: |
1397 | bool enableBufferInitialization; |
1398 | }; |
1399 | |
1400 | /// Sparse rewriting rule for the sort_coo operator. |
1401 | struct SortRewriter : public OpRewritePattern<SortOp> { |
1402 | public: |
1403 | using OpRewritePattern<SortOp>::OpRewritePattern; |
1404 | |
1405 | LogicalResult matchAndRewrite(SortOp op, |
1406 | PatternRewriter &rewriter) const override { |
1407 | SmallVector<Value> xys; |
1408 | xys.push_back(Elt: op.getXy()); |
1409 | xys.append(op.getYs().begin(), op.getYs().end()); |
1410 | |
1411 | auto xPerm = op.getPermMap(); |
1412 | uint64_t ny = 0; |
1413 | if (auto nyAttr = op.getNyAttr()) |
1414 | ny = nyAttr.getInt(); |
1415 | |
1416 | return matchAndRewriteSortOp(op, xys, xPerm, ny, rewriter); |
1417 | } |
1418 | }; |
1419 | |
1420 | } // namespace |
1421 | |
1422 | //===---------------------------------------------------------------------===// |
1423 | // Methods that add patterns described in this file to a pattern list. |
1424 | //===---------------------------------------------------------------------===// |
1425 | |
1426 | void mlir::populateSparseBufferRewriting(RewritePatternSet &patterns, |
1427 | bool enableBufferInitialization) { |
1428 | patterns.add<PushBackRewriter>(arg: patterns.getContext(), |
1429 | args&: enableBufferInitialization); |
1430 | patterns.add<SortRewriter>(arg: patterns.getContext()); |
1431 | } |
1432 | |