| 1 | //===- SparseBufferRewriting.cpp - Sparse buffer rewriting rules ----------===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | // |
| 9 | // This file implements rewriting rules that are specific to sparse tensor |
| 10 | // primitives with memref operands. |
| 11 | // |
| 12 | //===----------------------------------------------------------------------===// |
| 13 | |
| 14 | #include "Utils/CodegenUtils.h" |
| 15 | |
| 16 | #include "mlir/Dialect/Arith/IR/Arith.h" |
| 17 | #include "mlir/Dialect/Func/IR/FuncOps.h" |
| 18 | #include "mlir/Dialect/Linalg/IR/Linalg.h" |
| 19 | #include "mlir/Dialect/Math/IR/Math.h" |
| 20 | #include "mlir/Dialect/MemRef/IR/MemRef.h" |
| 21 | #include "mlir/Dialect/SCF/IR/SCF.h" |
| 22 | #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" |
| 23 | #include "mlir/Dialect/SparseTensor/Transforms/Passes.h" |
| 24 | #include "mlir/Support/LLVM.h" |
| 25 | |
| 26 | using namespace mlir; |
| 27 | using namespace mlir::sparse_tensor; |
| 28 | |
| 29 | //===---------------------------------------------------------------------===// |
| 30 | // Helper methods for the actual rewriting rules. |
| 31 | //===---------------------------------------------------------------------===// |
| 32 | |
// Argument positions shared by the generated sorting helpers: the two index
// bounds come first, followed by the buffer that holds the compared values.
static constexpr uint64_t loIdx = 0;
static constexpr uint64_t hiIdx = 1;
static constexpr uint64_t xStartIdx = 2;

// Name prefixes for the generated helper functions (the remainder of each name
// is mangled from the operand types).
static constexpr char kPartitionFuncNamePrefix[] = "_sparse_partition_";
static constexpr char kBinarySearchFuncNamePrefix[] = "_sparse_binary_search_";
static constexpr char kHybridQuickSortFuncNamePrefix[] =
    "_sparse_hybrid_qsort_";
static constexpr char kSortStableFuncNamePrefix[] = "_sparse_sort_stable_";
static constexpr char kShiftDownFuncNamePrefix[] = "_sparse_shift_down_";
static constexpr char kHeapSortFuncNamePrefix[] = "_sparse_heap_sort_";
static constexpr char kQuickSortFuncNamePrefix[] = "_sparse_qsort_";
| 47 | |
| 48 | using FuncGeneratorType = function_ref<void(OpBuilder &, ModuleOp, func::FuncOp, |
| 49 | AffineMap, uint64_t, uint32_t)>; |
| 50 | |
| 51 | /// Constructs a function name with this format to facilitate quick sort: |
| 52 | /// <namePrefix><xPerm>_<x type>_<y0 type>..._<yn type> for sort |
| 53 | /// <namePrefix><xPerm>_<x type>_coo_<ny>_<y0 type>..._<yn type> for sort_coo |
| 54 | static void getMangledSortHelperFuncName(llvm::raw_svector_ostream &nameOstream, |
| 55 | StringRef namePrefix, AffineMap xPerm, |
| 56 | uint64_t ny, ValueRange operands) { |
| 57 | nameOstream << namePrefix; |
| 58 | for (auto res : xPerm.getResults()) |
| 59 | nameOstream << cast<AffineDimExpr>(Val&: res).getPosition() << "_" ; |
| 60 | |
| 61 | nameOstream << getMemRefType(operands[xStartIdx]).getElementType(); |
| 62 | nameOstream << "_coo_" << ny; |
| 63 | |
| 64 | constexpr uint64_t yBufferOffset = 1; |
| 65 | for (Value v : operands.drop_front(n: xStartIdx + yBufferOffset)) |
| 66 | nameOstream << "_" << getMemRefType(v).getElementType(); |
| 67 | } |
| 68 | |
| 69 | /// Looks up a function that is appropriate for the given operands being |
| 70 | /// sorted, and creates such a function if it doesn't exist yet. The |
| 71 | /// parameters `xPerm` and `ny` tell the number of x and y values provided |
| 72 | /// by the buffer in xStartIdx. |
| 73 | // |
| 74 | // All sorting function generators take (lo, hi, xs, ys) in `operands` as |
| 75 | // parameters for the sorting functions. Other parameters, such as the recursive |
| 76 | // call depth, are appended to the end of the parameter list as |
| 77 | // "trailing parameters". |
| 78 | static FlatSymbolRefAttr getMangledSortHelperFunc( |
| 79 | OpBuilder &builder, func::FuncOp insertPoint, TypeRange resultTypes, |
| 80 | StringRef namePrefix, AffineMap xPerm, uint64_t ny, ValueRange operands, |
| 81 | FuncGeneratorType createFunc, uint32_t nTrailingP = 0) { |
| 82 | SmallString<32> nameBuffer; |
| 83 | llvm::raw_svector_ostream nameOstream(nameBuffer); |
| 84 | getMangledSortHelperFuncName(nameOstream, namePrefix, xPerm, ny, |
| 85 | operands: operands.drop_back(n: nTrailingP)); |
| 86 | |
| 87 | ModuleOp module = insertPoint->getParentOfType<ModuleOp>(); |
| 88 | MLIRContext *context = module.getContext(); |
| 89 | auto result = SymbolRefAttr::get(context, nameOstream.str()); |
| 90 | auto func = module.lookupSymbol<func::FuncOp>(result.getAttr()); |
| 91 | |
| 92 | if (!func) { |
| 93 | // Create the function. |
| 94 | OpBuilder::InsertionGuard insertionGuard(builder); |
| 95 | builder.setInsertionPoint(insertPoint); |
| 96 | Location loc = insertPoint.getLoc(); |
| 97 | func = builder.create<func::FuncOp>( |
| 98 | loc, nameOstream.str(), |
| 99 | FunctionType::get(context, operands.getTypes(), resultTypes)); |
| 100 | func.setPrivate(); |
| 101 | createFunc(builder, module, func, xPerm, ny, nTrailingP); |
| 102 | } |
| 103 | |
| 104 | return result; |
| 105 | } |
| 106 | |
| 107 | /// Creates a code block to process each pair of (xs[i], xs[j]) for sorting. |
| 108 | /// The code to process the value pairs is generated by `bodyBuilder`. |
| 109 | static void forEachIJPairInXs( |
| 110 | OpBuilder &builder, Location loc, ValueRange args, AffineMap xPerm, |
| 111 | uint64_t ny, |
| 112 | function_ref<void(uint64_t, Value, Value, Value)> bodyBuilder) { |
| 113 | Value cstep = constantIndex(builder, loc, i: xPerm.getNumResults() + ny); |
| 114 | Value iOffset = builder.create<arith::MulIOp>(loc, args[0], cstep); |
| 115 | Value jOffset = builder.create<arith::MulIOp>(loc, args[1], cstep); |
| 116 | for (unsigned k = 0, e = xPerm.getNumResults(); k < e; k++) { |
| 117 | unsigned actualK = cast<AffineDimExpr>(Val: xPerm.getResult(idx: k)).getPosition(); |
| 118 | Value ak = constantIndex(builder, loc, i: actualK); |
| 119 | Value i = builder.create<arith::AddIOp>(loc, ak, iOffset); |
| 120 | Value j = builder.create<arith::AddIOp>(loc, ak, jOffset); |
| 121 | Value buffer = args[xStartIdx]; |
| 122 | |
| 123 | bodyBuilder(k, i, j, buffer); |
| 124 | } |
| 125 | } |
| 126 | |
| 127 | /// Creates a code block to process each pair of (xys[i], xys[j]) for sorting. |
| 128 | /// The code to process the value pairs is generated by `bodyBuilder`. |
| 129 | static void forEachIJPairInAllBuffers( |
| 130 | OpBuilder &builder, Location loc, ValueRange args, AffineMap xPerm, |
| 131 | uint64_t ny, |
| 132 | function_ref<void(uint64_t, Value, Value, Value)> bodyBuilder) { |
| 133 | |
| 134 | // Create code for the first (xPerm + ny) buffers. |
| 135 | SmallVector<AffineExpr> exps(xPerm.getResults()); |
| 136 | for (unsigned y = 0; y < ny; y++) { |
| 137 | exps.push_back(Elt: builder.getAffineDimExpr(position: y + xPerm.getNumResults())); |
| 138 | } |
| 139 | AffineMap xyPerm = AffineMap::get(dimCount: exps.size(), symbolCount: 0, results: exps, context: builder.getContext()); |
| 140 | assert(xyPerm.isPermutation()); |
| 141 | |
| 142 | forEachIJPairInXs(builder, loc, args, xPerm: xyPerm, ny: 0, bodyBuilder); |
| 143 | |
| 144 | constexpr uint64_t numHandledBuffers = 1; |
| 145 | // Create code for the remaining buffers. |
| 146 | Value i = args[0]; |
| 147 | Value j = args[1]; |
| 148 | for (const auto &arg : |
| 149 | llvm::enumerate(First: args.drop_front(n: xStartIdx + numHandledBuffers))) { |
| 150 | bodyBuilder(arg.index() + xPerm.getNumResults() + ny, i, j, arg.value()); |
| 151 | } |
| 152 | } |
| 153 | |
| 154 | /// Creates a code block for swapping the values in index i and j for all the |
| 155 | /// buffers. |
| 156 | // |
| 157 | // The generated IR corresponds to this C like algorithm: |
| 158 | // swap(x0[i], x0[j]); |
| 159 | // swap(x1[i], x1[j]); |
| 160 | // ... |
| 161 | // swap(xn[i], xn[j]); |
| 162 | // swap(y0[i], y0[j]); |
| 163 | // ... |
| 164 | // swap(yn[i], yn[j]); |
| 165 | static void createSwap(OpBuilder &builder, Location loc, ValueRange args, |
| 166 | AffineMap xPerm, uint64_t ny) { |
| 167 | auto swapOnePair = [&](uint64_t unused, Value i, Value j, Value buffer) { |
| 168 | Value vi = builder.create<memref::LoadOp>(loc, buffer, i); |
| 169 | Value vj = builder.create<memref::LoadOp>(loc, buffer, j); |
| 170 | builder.create<memref::StoreOp>(loc, vj, buffer, i); |
| 171 | builder.create<memref::StoreOp>(loc, vi, buffer, j); |
| 172 | }; |
| 173 | |
| 174 | forEachIJPairInAllBuffers(builder, loc, args, xPerm, ny, bodyBuilder: swapOnePair); |
| 175 | } |
| 176 | |
| 177 | /// Creates code to compare all the (xs[i], xs[j]) pairs. The method to compare |
| 178 | /// each pair is create via `compareBuilder`. |
| 179 | static Value createInlinedCompareImplementation( |
| 180 | OpBuilder &builder, Location loc, ValueRange args, AffineMap xPerm, |
| 181 | uint64_t ny, |
| 182 | function_ref<Value(OpBuilder &, Location, Value, Value, Value, bool, bool)> |
| 183 | compareBuilder) { |
| 184 | Value result; |
| 185 | auto bodyBuilder = [&](uint64_t k, Value i, Value j, Value buffer) { |
| 186 | bool isFirstDim = (k == 0); |
| 187 | bool isLastDim = (k == xPerm.getNumResults() - 1); |
| 188 | Value val = |
| 189 | compareBuilder(builder, loc, i, j, buffer, isFirstDim, isLastDim); |
| 190 | if (isFirstDim) { |
| 191 | result = val; |
| 192 | } else if (!isLastDim) { |
| 193 | OpBuilder::InsertionGuard insertionGuard(builder); |
| 194 | auto ifOp = cast<scf::IfOp>(val.getDefiningOp()); |
| 195 | builder.setInsertionPointAfter(ifOp); |
| 196 | builder.create<scf::YieldOp>(loc, ifOp.getResult(0)); |
| 197 | } |
| 198 | }; |
| 199 | |
| 200 | forEachIJPairInXs(builder, loc, args, xPerm, ny, bodyBuilder); |
| 201 | |
| 202 | builder.setInsertionPointAfterValue(result); |
| 203 | return result; |
| 204 | } |
| 205 | |
| 206 | /// Generates code to compare whether x[i] is equal to x[j] and returns the |
| 207 | /// result of the comparison. |
| 208 | static Value createEqCompare(OpBuilder &builder, Location loc, Value i, Value j, |
| 209 | Value x, bool isFirstDim, bool isLastDim) { |
| 210 | Value vi = builder.create<memref::LoadOp>(loc, x, i); |
| 211 | Value vj = builder.create<memref::LoadOp>(loc, x, j); |
| 212 | |
| 213 | Value res; |
| 214 | if (isLastDim) { |
| 215 | res = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, vi, vj); |
| 216 | // For 1D, we create a compare without any control flow. Otherwise, we |
| 217 | // create YieldOp to return the result in the nested if-stmt. |
| 218 | if (!isFirstDim) |
| 219 | builder.create<scf::YieldOp>(loc, res); |
| 220 | } else { |
| 221 | Value ne = |
| 222 | builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne, vi, vj); |
| 223 | scf::IfOp ifOp = builder.create<scf::IfOp>(loc, builder.getIntegerType(1), |
| 224 | ne, /*else=*/true); |
| 225 | // If (x[i] != x[j]). |
| 226 | builder.setInsertionPointToStart(&ifOp.getThenRegion().front()); |
| 227 | Value f = constantI1(builder, loc, b: false); |
| 228 | builder.create<scf::YieldOp>(loc, f); |
| 229 | |
| 230 | // If (x[i] == x[j]). Set up the insertion point for the nested if-stmt that |
| 231 | // checks the remaining dimensions. |
| 232 | builder.setInsertionPointToStart(&ifOp.getElseRegion().front()); |
| 233 | res = ifOp.getResult(0); |
| 234 | } |
| 235 | |
| 236 | return res; |
| 237 | } |
| 238 | |
| 239 | /// Creates code to compare whether xs[i] is equal to xs[j]. |
| 240 | // |
| 241 | // The generate IR corresponds to this C like algorithm: |
| 242 | // if (x0[i] != x0[j]) |
| 243 | // return false; |
| 244 | // else |
| 245 | // if (x1[i] != x1[j]) |
| 246 | // return false; |
| 247 | // else if (x2[2] != x2[j])) |
| 248 | // and so on ... |
| 249 | static Value createInlinedEqCompare(OpBuilder &builder, Location loc, |
| 250 | ValueRange args, AffineMap xPerm, |
| 251 | uint64_t ny, uint32_t nTrailingP = 0) { |
| 252 | // Compare functions don't use trailing parameters. |
| 253 | (void)nTrailingP; |
| 254 | assert(nTrailingP == 0); |
| 255 | return createInlinedCompareImplementation(builder, loc, args, xPerm, ny, |
| 256 | compareBuilder: createEqCompare); |
| 257 | } |
| 258 | |
| 259 | /// Generates code to compare whether x[i] is less than x[j] and returns the |
| 260 | /// result of the comparison. |
| 261 | static Value createLessThanCompare(OpBuilder &builder, Location loc, Value i, |
| 262 | Value j, Value x, bool isFirstDim, |
| 263 | bool isLastDim) { |
| 264 | Value vi = builder.create<memref::LoadOp>(loc, x, i); |
| 265 | Value vj = builder.create<memref::LoadOp>(loc, x, j); |
| 266 | |
| 267 | Value res; |
| 268 | if (isLastDim) { |
| 269 | res = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, vi, vj); |
| 270 | // For 1D, we create a compare without any control flow. Otherwise, we |
| 271 | // create YieldOp to return the result in the nested if-stmt. |
| 272 | if (!isFirstDim) |
| 273 | builder.create<scf::YieldOp>(loc, res); |
| 274 | } else { |
| 275 | Value ne = |
| 276 | builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne, vi, vj); |
| 277 | scf::IfOp ifOp = builder.create<scf::IfOp>(loc, builder.getIntegerType(1), |
| 278 | ne, /*else=*/true); |
| 279 | // If (x[i] != x[j]). |
| 280 | builder.setInsertionPointToStart(&ifOp.getThenRegion().front()); |
| 281 | Value lt = |
| 282 | builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, vi, vj); |
| 283 | builder.create<scf::YieldOp>(loc, lt); |
| 284 | |
| 285 | // If (x[i] == x[j]). Set up the insertion point for the nested if-stmt that |
| 286 | // checks the remaining dimensions. |
| 287 | builder.setInsertionPointToStart(&ifOp.getElseRegion().front()); |
| 288 | res = ifOp.getResult(0); |
| 289 | } |
| 290 | |
| 291 | return res; |
| 292 | } |
| 293 | |
| 294 | /// Creates code to compare whether xs[i] is less than xs[j]. |
| 295 | // |
| 296 | // The generate IR corresponds to this C like algorithm: |
| 297 | // if (x0[i] != x0[j]) |
| 298 | // return x0[i] < x0[j]; |
| 299 | // else if (x1[j] != x1[i]) |
| 300 | // return x1[i] < x1[j]; |
| 301 | // else |
| 302 | // and so on ... |
| 303 | static Value createInlinedLessThan(OpBuilder &builder, Location loc, |
| 304 | ValueRange args, AffineMap xPerm, |
| 305 | uint64_t ny, uint32_t nTrailingP = 0) { |
| 306 | // Compare functions don't use trailing parameters. |
| 307 | (void)nTrailingP; |
| 308 | assert(nTrailingP == 0); |
| 309 | return createInlinedCompareImplementation(builder, loc, args, xPerm, ny, |
| 310 | compareBuilder: createLessThanCompare); |
| 311 | } |
| 312 | |
| 313 | /// Creates a function to use a binary search to find the insertion point for |
| 314 | /// inserting xs[hi] to the sorted values xs[lo..hi). |
| 315 | // |
| 316 | // The generate IR corresponds to this C like algorithm: |
| 317 | // p = hi |
| 318 | // while (lo < hi) |
| 319 | // mid = (lo + hi) >> 1 |
| 320 | // if (xs[p] < xs[mid]) |
| 321 | // hi = mid |
| 322 | // else |
| 323 | // lo = mid - 1 |
| 324 | // return lo; |
| 325 | // |
| 326 | static void createBinarySearchFunc(OpBuilder &builder, ModuleOp module, |
| 327 | func::FuncOp func, AffineMap xPerm, |
| 328 | uint64_t ny, uint32_t nTrailingP = 0) { |
| 329 | // Binary search doesn't use trailing parameters. |
| 330 | (void)nTrailingP; |
| 331 | assert(nTrailingP == 0); |
| 332 | OpBuilder::InsertionGuard insertionGuard(builder); |
| 333 | Block *entryBlock = func.addEntryBlock(); |
| 334 | builder.setInsertionPointToStart(entryBlock); |
| 335 | |
| 336 | Location loc = func.getLoc(); |
| 337 | ValueRange args = entryBlock->getArguments(); |
| 338 | Value p = args[hiIdx]; |
| 339 | SmallVector<Type, 2> types(2, p.getType()); // Only two types. |
| 340 | scf::WhileOp whileOp = builder.create<scf::WhileOp>( |
| 341 | loc, types, SmallVector<Value, 2>{args[loIdx], args[hiIdx]}); |
| 342 | |
| 343 | // The before-region of the WhileOp. |
| 344 | Block *before = |
| 345 | builder.createBlock(&whileOp.getBefore(), {}, types, {loc, loc}); |
| 346 | builder.setInsertionPointToEnd(before); |
| 347 | Value cond1 = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, |
| 348 | before->getArgument(0), |
| 349 | before->getArgument(1)); |
| 350 | builder.create<scf::ConditionOp>(loc, cond1, before->getArguments()); |
| 351 | |
| 352 | // The after-region of the WhileOp. |
| 353 | Block *after = |
| 354 | builder.createBlock(&whileOp.getAfter(), {}, types, {loc, loc}); |
| 355 | builder.setInsertionPointToEnd(after); |
| 356 | Value lo = after->getArgument(i: 0); |
| 357 | Value hi = after->getArgument(i: 1); |
| 358 | // Compute mid = (lo + hi) >> 1. |
| 359 | Value c1 = constantIndex(builder, loc, i: 1); |
| 360 | Value mid = builder.create<arith::ShRUIOp>( |
| 361 | loc, builder.create<arith::AddIOp>(loc, lo, hi), c1); |
| 362 | Value midp1 = builder.create<arith::AddIOp>(loc, mid, c1); |
| 363 | |
| 364 | // Compare xs[p] < xs[mid]. |
| 365 | SmallVector<Value> compareOperands{p, mid}; |
| 366 | constexpr uint64_t numXBuffers = 1; |
| 367 | compareOperands.append(in_start: args.begin() + xStartIdx, |
| 368 | in_end: args.begin() + xStartIdx + numXBuffers); |
| 369 | Value cond2 = createInlinedLessThan(builder, loc, args: compareOperands, xPerm, ny); |
| 370 | // Update lo and hi for the WhileOp as follows: |
| 371 | // if (xs[p] < xs[mid])) |
| 372 | // hi = mid; |
| 373 | // else |
| 374 | // lo = mid + 1; |
| 375 | Value newLo = builder.create<arith::SelectOp>(loc, cond2, lo, midp1); |
| 376 | Value newHi = builder.create<arith::SelectOp>(loc, cond2, mid, hi); |
| 377 | builder.create<scf::YieldOp>(loc, ValueRange{newLo, newHi}); |
| 378 | |
| 379 | builder.setInsertionPointAfter(whileOp); |
| 380 | builder.create<func::ReturnOp>(loc, whileOp.getResult(0)); |
| 381 | } |
| 382 | |
| 383 | /// Creates code to advance i in a loop based on xs[p] as follows: |
| 384 | /// while (xs[i] < xs[p]) i += step (step > 0) |
| 385 | /// or |
| 386 | /// while (xs[i] > xs[p]) i += step (step < 0) |
| 387 | /// The routine returns i as well as a boolean value to indicate whether |
| 388 | /// xs[i] == xs[p]. |
| 389 | static std::pair<Value, Value> createScanLoop(OpBuilder &builder, |
| 390 | ModuleOp module, |
| 391 | func::FuncOp func, ValueRange xs, |
| 392 | Value i, Value p, AffineMap xPerm, |
| 393 | uint64_t ny, int step) { |
| 394 | Location loc = func.getLoc(); |
| 395 | scf::WhileOp whileOp = |
| 396 | builder.create<scf::WhileOp>(loc, TypeRange{i.getType()}, ValueRange{i}); |
| 397 | |
| 398 | Block *before = |
| 399 | builder.createBlock(&whileOp.getBefore(), {}, {i.getType()}, {loc}); |
| 400 | builder.setInsertionPointToEnd(before); |
| 401 | SmallVector<Value> compareOperands; |
| 402 | if (step > 0) { |
| 403 | compareOperands.push_back(Elt: before->getArgument(i: 0)); |
| 404 | compareOperands.push_back(Elt: p); |
| 405 | } else { |
| 406 | assert(step < 0); |
| 407 | compareOperands.push_back(Elt: p); |
| 408 | compareOperands.push_back(Elt: before->getArgument(i: 0)); |
| 409 | } |
| 410 | compareOperands.append(in_start: xs.begin(), in_end: xs.end()); |
| 411 | Value cond = createInlinedLessThan(builder, loc, args: compareOperands, xPerm, ny); |
| 412 | builder.create<scf::ConditionOp>(loc, cond, before->getArguments()); |
| 413 | |
| 414 | Block *after = |
| 415 | builder.createBlock(&whileOp.getAfter(), {}, {i.getType()}, {loc}); |
| 416 | builder.setInsertionPointToEnd(after); |
| 417 | Value cs = constantIndex(builder, loc, i: step); |
| 418 | i = builder.create<arith::AddIOp>(loc, after->getArgument(0), cs); |
| 419 | builder.create<scf::YieldOp>(loc, ValueRange{i}); |
| 420 | i = whileOp.getResult(0); |
| 421 | |
| 422 | builder.setInsertionPointAfter(whileOp); |
| 423 | compareOperands[0] = i; |
| 424 | compareOperands[1] = p; |
| 425 | Value compareEq = |
| 426 | createInlinedEqCompare(builder, loc, args: compareOperands, xPerm, ny); |
| 427 | |
| 428 | return std::make_pair(whileOp.getResult(0), compareEq); |
| 429 | } |
| 430 | |
| 431 | /// Creates and returns an IfOp to compare two elements and swap the elements |
| 432 | /// if compareFunc(data[b], data[a]) returns true. The new insertion point is |
| 433 | /// right after the swap instructions. |
| 434 | static scf::IfOp createCompareThenSwap(OpBuilder &builder, Location loc, |
| 435 | AffineMap xPerm, uint64_t ny, |
| 436 | SmallVectorImpl<Value> &swapOperands, |
| 437 | SmallVectorImpl<Value> &compareOperands, |
| 438 | Value a, Value b) { |
| 439 | // Compare(data[b], data[a]). |
| 440 | compareOperands[0] = b; |
| 441 | compareOperands[1] = a; |
| 442 | Value cond = createInlinedLessThan(builder, loc, args: compareOperands, xPerm, ny); |
| 443 | scf::IfOp ifOp = builder.create<scf::IfOp>(loc, cond, /*else=*/false); |
| 444 | builder.setInsertionPointToStart(&ifOp.getThenRegion().front()); |
| 445 | swapOperands[0] = b; |
| 446 | swapOperands[1] = a; |
| 447 | createSwap(builder, loc, args: swapOperands, xPerm, ny); |
| 448 | return ifOp; |
| 449 | } |
| 450 | |
| 451 | /// Creates code to insert the 3rd element to a list of two sorted elements. |
| 452 | static void createInsert3rd(OpBuilder &builder, Location loc, AffineMap xPerm, |
| 453 | uint64_t ny, SmallVectorImpl<Value> &swapOperands, |
| 454 | SmallVectorImpl<Value> &compareOperands, Value v0, |
| 455 | Value v1, Value v2) { |
| 456 | scf::IfOp ifOp = createCompareThenSwap(builder, loc, xPerm, ny, swapOperands, |
| 457 | compareOperands, v1, v2); |
| 458 | createCompareThenSwap(builder, loc, xPerm, ny, swapOperands, compareOperands, |
| 459 | a: v0, b: v1); |
| 460 | builder.setInsertionPointAfter(ifOp); |
| 461 | } |
| 462 | |
| 463 | /// Creates code to sort 3 elements. |
| 464 | static void createSort3(OpBuilder &builder, Location loc, AffineMap xPerm, |
| 465 | uint64_t ny, SmallVectorImpl<Value> &swapOperands, |
| 466 | SmallVectorImpl<Value> &compareOperands, Value v0, |
| 467 | Value v1, Value v2) { |
| 468 | // Sort the first 2 elements. |
| 469 | scf::IfOp ifOp1 = createCompareThenSwap(builder, loc, xPerm, ny, swapOperands, |
| 470 | compareOperands, v0, v1); |
| 471 | builder.setInsertionPointAfter(ifOp1); |
| 472 | |
| 473 | // Insert the 3th element. |
| 474 | createInsert3rd(builder, loc, xPerm, ny, swapOperands, compareOperands, v0, |
| 475 | v1, v2); |
| 476 | } |
| 477 | |
| 478 | /// Creates code to sort 5 elements. |
| 479 | static void createSort5(OpBuilder &builder, Location loc, AffineMap xPerm, |
| 480 | uint64_t ny, SmallVectorImpl<Value> &swapOperands, |
| 481 | SmallVectorImpl<Value> &compareOperands, Value v0, |
| 482 | Value v1, Value v2, Value v3, Value v4) { |
| 483 | // Sort the first 3 elements. |
| 484 | createSort3(builder, loc, xPerm, ny, swapOperands, compareOperands, v0, v1, |
| 485 | v2); |
| 486 | |
| 487 | auto insert4th = [&]() { |
| 488 | scf::IfOp ifOp = createCompareThenSwap( |
| 489 | builder, loc, xPerm, ny, swapOperands, compareOperands, v2, v3); |
| 490 | createInsert3rd(builder, loc, xPerm, ny, swapOperands, compareOperands, v0, |
| 491 | v1, v2); |
| 492 | builder.setInsertionPointAfter(ifOp); |
| 493 | }; |
| 494 | |
| 495 | // Insert the 4th element. |
| 496 | insert4th(); |
| 497 | |
| 498 | // Insert the 5th element. |
| 499 | scf::IfOp ifOp = createCompareThenSwap(builder, loc, xPerm, ny, swapOperands, |
| 500 | compareOperands, v3, v4); |
| 501 | insert4th(); |
| 502 | builder.setInsertionPointAfter(ifOp); |
| 503 | } |
| 504 | |
| 505 | /// Creates a code block to swap the values in indices lo, mi, and hi so that |
| 506 | /// data[lo], data[mi] and data[hi] are sorted in non-decreasing values. When |
| 507 | /// the number of values in range [lo, hi) is more than a threshold, we also |
| 508 | /// include the middle of [lo, mi) and [mi, hi) and sort a total of five values. |
| 509 | static void createChoosePivot(OpBuilder &builder, ModuleOp module, |
| 510 | func::FuncOp func, AffineMap xPerm, uint64_t ny, |
| 511 | Value lo, Value hi, Value mi, ValueRange args) { |
| 512 | SmallVector<Value> compareOperands{mi, lo}; |
| 513 | constexpr uint64_t numXBuffers = 1; |
| 514 | compareOperands.append(in_start: args.begin() + xStartIdx, |
| 515 | in_end: args.begin() + xStartIdx + numXBuffers); |
| 516 | SmallVector<Value> swapOperands{mi, lo}; |
| 517 | swapOperands.append(in_start: args.begin() + xStartIdx, in_end: args.end()); |
| 518 | Location loc = func.getLoc(); |
| 519 | Value c1 = constantIndex(builder, loc, i: 1); |
| 520 | Value hiP1 = builder.create<arith::AddIOp>(loc, hi, c1); |
| 521 | Value len = builder.create<arith::SubIOp>(loc, hiP1, lo); |
| 522 | Value lenThreshold = constantIndex(builder, loc, i: 1000); |
| 523 | Value lenCond = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, |
| 524 | len, lenThreshold); |
| 525 | scf::IfOp lenIf = builder.create<scf::IfOp>(loc, lenCond, /*else=*/true); |
| 526 | |
| 527 | // When len < 1000, choose pivot from median of 3 values. |
| 528 | builder.setInsertionPointToStart(&lenIf.getThenRegion().front()); |
| 529 | createSort3(builder, loc, xPerm, ny, swapOperands, compareOperands, v0: lo, v1: mi, |
| 530 | v2: hi); |
| 531 | |
| 532 | // When len >= 1000, choose pivot from median of 5 values. |
| 533 | builder.setInsertionPointToStart(&lenIf.getElseRegion().front()); |
| 534 | Value miP1 = builder.create<arith::AddIOp>(loc, hi, c1); |
| 535 | Value a = builder.create<arith::AddIOp>(loc, lo, miP1); |
| 536 | // Value a is the middle between [loc, mi]. |
| 537 | a = builder.create<arith::ShRUIOp>(loc, a, c1); |
| 538 | Value b = builder.create<arith::AddIOp>(loc, mi, hiP1); |
| 539 | // Value b is the middle between [mi, hi]. |
| 540 | b = builder.create<arith::ShRUIOp>(loc, b, c1); |
| 541 | createSort5(builder, loc, xPerm, ny, swapOperands, compareOperands, v0: lo, v1: a, v2: mi, |
| 542 | v3: b, v4: hi); |
| 543 | |
| 544 | builder.setInsertionPointAfter(lenIf); |
| 545 | } |
| 546 | |
| 547 | /// Creates a function to perform quick sort partition on the values in the |
| 548 | /// range of index [lo, hi), assuming lo < hi. |
| 549 | // |
| 550 | // The generated IR corresponds to this C like algorithm: |
| 551 | // int partition(lo, hi, xs) { |
| 552 | // p = (lo+hi)/2 // pivot index |
| 553 | // i = lo |
| 554 | // j = hi-1 |
| 555 | // while (true) do { |
| 556 | // while (xs[i] < xs[p]) i ++; |
| 557 | // i_eq = (xs[i] == xs[p]); |
| 558 | // while (xs[j] > xs[p]) j --; |
| 559 | // j_eq = (xs[j] == xs[p]); |
| 560 | // |
| 561 | // if (i >= j) return j + 1; |
| 562 | // |
| 563 | // if (i < j) { |
| 564 | // swap(xs[i], xs[j]) |
| 565 | // if (i == p) { |
| 566 | // p = j; |
| 567 | // } else if (j == p) { |
| 568 | // p = i; |
| 569 | // } |
| 570 | // if (i_eq && j_eq) { |
| 571 | // ++i; |
| 572 | // --j; |
| 573 | // } |
| 574 | // } |
| 575 | // } |
| 576 | // } |
| 577 | static void createPartitionFunc(OpBuilder &builder, ModuleOp module, |
| 578 | func::FuncOp func, AffineMap xPerm, uint64_t ny, |
| 579 | uint32_t nTrailingP = 0) { |
| 580 | // Quick sort partition doesn't use trailing parameters. |
| 581 | (void)nTrailingP; |
| 582 | assert(nTrailingP == 0); |
| 583 | OpBuilder::InsertionGuard insertionGuard(builder); |
| 584 | |
| 585 | Block *entryBlock = func.addEntryBlock(); |
| 586 | builder.setInsertionPointToStart(entryBlock); |
| 587 | |
| 588 | Location loc = func.getLoc(); |
| 589 | ValueRange args = entryBlock->getArguments(); |
| 590 | Value lo = args[loIdx]; |
| 591 | Value hi = args[hiIdx]; |
| 592 | Value sum = builder.create<arith::AddIOp>(loc, lo, hi); |
| 593 | Value c1 = constantIndex(builder, loc, i: 1); |
| 594 | Value p = builder.create<arith::ShRUIOp>(loc, sum, c1); |
| 595 | |
| 596 | Value i = lo; |
| 597 | Value j = builder.create<arith::SubIOp>(loc, hi, c1); |
| 598 | createChoosePivot(builder, module, func, xPerm, ny, i, j, p, args); |
| 599 | Value trueVal = constantI1(builder, loc, b: true); // The value for while (true) |
| 600 | SmallVector<Value, 4> operands{i, j, p, trueVal}; // Exactly four values. |
| 601 | SmallVector<Type, 4> types{i.getType(), j.getType(), p.getType(), |
| 602 | trueVal.getType()}; |
| 603 | scf::WhileOp whileOp = builder.create<scf::WhileOp>(loc, types, operands); |
| 604 | |
| 605 | // The before-region of the WhileOp. |
| 606 | Block *before = builder.createBlock(&whileOp.getBefore(), {}, types, |
| 607 | {loc, loc, loc, loc}); |
| 608 | builder.setInsertionPointToEnd(before); |
| 609 | builder.create<scf::ConditionOp>(loc, before->getArgument(3), |
| 610 | before->getArguments()); |
| 611 | |
| 612 | // The after-region of the WhileOp. |
| 613 | Block *after = |
| 614 | builder.createBlock(&whileOp.getAfter(), {}, types, {loc, loc, loc, loc}); |
| 615 | builder.setInsertionPointToEnd(after); |
| 616 | i = after->getArgument(i: 0); |
| 617 | j = after->getArgument(i: 1); |
| 618 | p = after->getArgument(i: 2); |
| 619 | |
| 620 | constexpr uint64_t numXBuffers = 1; |
| 621 | auto [iresult, iCompareEq] = |
| 622 | createScanLoop(builder, module, func, args.slice(n: xStartIdx, m: numXBuffers), |
| 623 | i, p, xPerm, ny, 1); |
| 624 | i = iresult; |
| 625 | auto [jresult, jCompareEq] = |
| 626 | createScanLoop(builder, module, func, args.slice(n: xStartIdx, m: numXBuffers), |
| 627 | j, p, xPerm, ny, -1); |
| 628 | j = jresult; |
| 629 | |
| 630 | // If i < j: |
| 631 | Value cond = |
| 632 | builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, i, j); |
| 633 | scf::IfOp ifOp = builder.create<scf::IfOp>(loc, types, cond, /*else=*/true); |
| 634 | builder.setInsertionPointToStart(&ifOp.getThenRegion().front()); |
| 635 | SmallVector<Value> swapOperands{i, j}; |
| 636 | swapOperands.append(in_start: args.begin() + xStartIdx, in_end: args.end()); |
| 637 | createSwap(builder, loc, args: swapOperands, xPerm, ny); |
| 638 | // If the pivot is moved, update p with the new pivot. |
| 639 | Value icond = |
| 640 | builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, i, p); |
| 641 | scf::IfOp ifOpI = builder.create<scf::IfOp>(loc, TypeRange{p.getType()}, |
| 642 | icond, /*else=*/true); |
| 643 | builder.setInsertionPointToStart(&ifOpI.getThenRegion().front()); |
| 644 | builder.create<scf::YieldOp>(loc, ValueRange{j}); |
| 645 | builder.setInsertionPointToStart(&ifOpI.getElseRegion().front()); |
| 646 | Value jcond = |
| 647 | builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, j, p); |
| 648 | scf::IfOp ifOpJ = builder.create<scf::IfOp>(loc, TypeRange{p.getType()}, |
| 649 | jcond, /*else=*/true); |
| 650 | builder.setInsertionPointToStart(&ifOpJ.getThenRegion().front()); |
| 651 | builder.create<scf::YieldOp>(loc, ValueRange{i}); |
| 652 | builder.setInsertionPointToStart(&ifOpJ.getElseRegion().front()); |
| 653 | builder.create<scf::YieldOp>(loc, ValueRange{p}); |
| 654 | builder.setInsertionPointAfter(ifOpJ); |
| 655 | builder.create<scf::YieldOp>(loc, ifOpJ.getResults()); |
| 656 | builder.setInsertionPointAfter(ifOpI); |
| 657 | Value compareEqIJ = |
| 658 | builder.create<arith::AndIOp>(loc, iCompareEq, jCompareEq); |
| 659 | scf::IfOp ifOp2 = builder.create<scf::IfOp>( |
| 660 | loc, TypeRange{i.getType(), j.getType()}, compareEqIJ, /*else=*/true); |
| 661 | builder.setInsertionPointToStart(&ifOp2.getThenRegion().front()); |
| 662 | Value i2 = builder.create<arith::AddIOp>(loc, i, c1); |
| 663 | Value j2 = builder.create<arith::SubIOp>(loc, j, c1); |
| 664 | builder.create<scf::YieldOp>(loc, ValueRange{i2, j2}); |
| 665 | builder.setInsertionPointToStart(&ifOp2.getElseRegion().front()); |
| 666 | builder.create<scf::YieldOp>(loc, ValueRange{i, j}); |
| 667 | builder.setInsertionPointAfter(ifOp2); |
| 668 | builder.create<scf::YieldOp>( |
| 669 | loc, |
| 670 | ValueRange{ifOp2.getResult(0), ifOp2.getResult(1), ifOpI.getResult(0), |
| 671 | /*cont=*/constantI1(builder, loc, true)}); |
| 672 | |
| 673 | // False branch for if i < j (i.e., i >= j): |
| 674 | builder.setInsertionPointToStart(&ifOp.getElseRegion().front()); |
| 675 | p = builder.create<arith::AddIOp>(loc, j, |
| 676 | constantOne(builder, loc, j.getType())); |
| 677 | builder.create<scf::YieldOp>( |
| 678 | loc, ValueRange{i, j, p, /*cont=*/constantI1(builder, loc, false)}); |
| 679 | |
| 680 | // Return for the whileOp. |
| 681 | builder.setInsertionPointAfter(ifOp); |
| 682 | builder.create<scf::YieldOp>(loc, ifOp.getResults()); |
| 683 | |
| 684 | // Return for the function. |
| 685 | builder.setInsertionPointAfter(whileOp); |
| 686 | builder.create<func::ReturnOp>(loc, whileOp.getResult(2)); |
| 687 | } |
| 688 | |
| 689 | /// Computes (n-2)/n, assuming n has index type. |
| 690 | static Value createSubTwoDividedByTwo(OpBuilder &builder, Location loc, |
| 691 | Value n) { |
| 692 | Value i2 = constantIndex(builder, loc, i: 2); |
| 693 | Value res = builder.create<arith::SubIOp>(loc, n, i2); |
| 694 | Value i1 = constantIndex(builder, loc, i: 1); |
| 695 | return builder.create<arith::ShRUIOp>(loc, res, i1); |
| 696 | } |
| 697 | |
| 698 | /// Creates a function to heapify the subtree with root `start` within the full |
| 699 | /// binary tree in the range of index [first, first + n). |
| 700 | // |
| 701 | // The generated IR corresponds to this C like algorithm: |
| 702 | // void shiftDown(first, start, n, data) { |
| 703 | // if (n >= 2) { |
| 704 | // child = start - first |
| 705 | // if ((n-2)/2 >= child) { |
| 706 | // // Left child exists. |
| 707 | // child = child * 2 + 1 // Initialize the bigger child to left child. |
| 708 | // childIndex = child + first |
| 709 | // if (child+1 < n && data[childIndex] < data[childIndex+1]) |
| 710 | // // Right child exits and is bigger. |
| 711 | // childIndex++; child++; |
| 712 | // // Shift data[start] down to where it belongs in the subtree. |
| 713 | // while (data[start] < data[childIndex) { |
| 714 | // swap(data[start], data[childIndex]) |
| 715 | // start = childIndex |
| 716 | // if ((n - 2)/2 >= child) { |
| 717 | // // Left child exists. |
| 718 | // child = 2*child + 1 |
| 719 | // childIndex = child + 1 |
| 720 | // if (child + 1) < n && data[childIndex] < data[childIndex+1] |
| 721 | // childIndex++; child++; |
| 722 | // } |
| 723 | // } |
| 724 | // } |
| 725 | // } |
| 726 | // } |
| 727 | // |
| 728 | static void createShiftDownFunc(OpBuilder &builder, ModuleOp module, |
| 729 | func::FuncOp func, AffineMap xPerm, uint64_t ny, |
| 730 | uint32_t nTrailingP) { |
| 731 | // The value n is passed in as a trailing parameter. |
| 732 | assert(nTrailingP == 1); |
| 733 | OpBuilder::InsertionGuard insertionGuard(builder); |
| 734 | Block *entryBlock = func.addEntryBlock(); |
| 735 | builder.setInsertionPointToStart(entryBlock); |
| 736 | |
| 737 | Location loc = func.getLoc(); |
| 738 | Value n = entryBlock->getArguments().back(); |
| 739 | ValueRange args = entryBlock->getArguments().drop_back(); |
| 740 | Value first = args[loIdx]; |
| 741 | Value start = args[hiIdx]; |
| 742 | |
| 743 | // If (n >= 2). |
| 744 | Value c2 = constantIndex(builder, loc, i: 2); |
| 745 | Value condN = |
| 746 | builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::uge, n, c2); |
| 747 | scf::IfOp ifN = builder.create<scf::IfOp>(loc, condN, /*else=*/false); |
| 748 | builder.setInsertionPointToStart(&ifN.getThenRegion().front()); |
| 749 | Value child = builder.create<arith::SubIOp>(loc, start, first); |
| 750 | |
| 751 | // If ((n-2)/2 >= child). |
| 752 | Value t = createSubTwoDividedByTwo(builder, loc, n); |
| 753 | Value condNc = |
| 754 | builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::uge, t, child); |
| 755 | scf::IfOp ifNc = builder.create<scf::IfOp>(loc, condNc, /*else=*/false); |
| 756 | |
| 757 | builder.setInsertionPointToStart(&ifNc.getThenRegion().front()); |
| 758 | Value c1 = constantIndex(builder, loc, i: 1); |
| 759 | SmallVector<Value> compareOperands{start, start}; |
| 760 | constexpr uint64_t numXBuffers = 1; |
| 761 | compareOperands.append(in_start: args.begin() + xStartIdx, |
| 762 | in_end: args.begin() + xStartIdx + numXBuffers); |
| 763 | |
| 764 | // Generate code to inspect the children of 'r' and return the larger child |
| 765 | // as follows: |
| 766 | // child = r * 2 + 1 // Left child. |
| 767 | // childIndex = child + first |
| 768 | // if (child+1 < n && data[childIndex] < data[childIndex+1]) |
| 769 | // childIndex ++; child ++ // Right child is bigger. |
| 770 | auto getLargerChild = [&](Value r) -> std::pair<Value, Value> { |
| 771 | Value lChild = builder.create<arith::ShLIOp>(loc, r, c1); |
| 772 | lChild = builder.create<arith::AddIOp>(loc, lChild, c1); |
| 773 | Value lChildIdx = builder.create<arith::AddIOp>(loc, lChild, first); |
| 774 | Value rChild = builder.create<arith::AddIOp>(loc, lChild, c1); |
| 775 | Value cond1 = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, |
| 776 | rChild, n); |
| 777 | SmallVector<Type, 2> ifTypes(2, r.getType()); |
| 778 | scf::IfOp if1 = |
| 779 | builder.create<scf::IfOp>(loc, ifTypes, cond1, /*else=*/true); |
| 780 | builder.setInsertionPointToStart(&if1.getThenRegion().front()); |
| 781 | Value rChildIdx = builder.create<arith::AddIOp>(loc, rChild, first); |
| 782 | // Compare data[left] < data[right]. |
| 783 | compareOperands[0] = lChildIdx; |
| 784 | compareOperands[1] = rChildIdx; |
| 785 | Value cond2 = |
| 786 | createInlinedLessThan(builder, loc, args: compareOperands, xPerm, ny); |
| 787 | scf::IfOp if2 = |
| 788 | builder.create<scf::IfOp>(loc, ifTypes, cond2, /*else=*/true); |
| 789 | builder.setInsertionPointToStart(&if2.getThenRegion().front()); |
| 790 | builder.create<scf::YieldOp>(loc, ValueRange{rChild, rChildIdx}); |
| 791 | builder.setInsertionPointToStart(&if2.getElseRegion().front()); |
| 792 | builder.create<scf::YieldOp>(loc, ValueRange{lChild, lChildIdx}); |
| 793 | builder.setInsertionPointAfter(if2); |
| 794 | builder.create<scf::YieldOp>(loc, if2.getResults()); |
| 795 | builder.setInsertionPointToStart(&if1.getElseRegion().front()); |
| 796 | builder.create<scf::YieldOp>(loc, ValueRange{lChild, lChildIdx}); |
| 797 | builder.setInsertionPointAfter(if1); |
| 798 | return std::make_pair(if1.getResult(0), if1.getResult(1)); |
| 799 | }; |
| 800 | |
| 801 | Value childIdx; |
| 802 | std::tie(args&: child, args&: childIdx) = getLargerChild(child); |
| 803 | |
| 804 | // While (data[start] < data[childIndex]). |
| 805 | SmallVector<Type, 3> types(3, child.getType()); |
| 806 | scf::WhileOp whileOp = builder.create<scf::WhileOp>( |
| 807 | loc, types, SmallVector<Value, 2>{start, child, childIdx}); |
| 808 | |
| 809 | // The before-region of the WhileOp. |
| 810 | SmallVector<Location, 3> locs(3, loc); |
| 811 | Block *before = builder.createBlock(&whileOp.getBefore(), {}, types, locs); |
| 812 | builder.setInsertionPointToEnd(before); |
| 813 | start = before->getArgument(i: 0); |
| 814 | childIdx = before->getArgument(i: 2); |
| 815 | compareOperands[0] = start; |
| 816 | compareOperands[1] = childIdx; |
| 817 | Value cond = createInlinedLessThan(builder, loc, args: compareOperands, xPerm, ny); |
| 818 | builder.create<scf::ConditionOp>(loc, cond, before->getArguments()); |
| 819 | |
| 820 | // The after-region of the WhileOp. |
| 821 | Block *after = builder.createBlock(&whileOp.getAfter(), {}, types, locs); |
| 822 | start = after->getArgument(i: 0); |
| 823 | child = after->getArgument(i: 1); |
| 824 | childIdx = after->getArgument(i: 2); |
| 825 | SmallVector<Value> swapOperands{start, childIdx}; |
| 826 | swapOperands.append(in_start: args.begin() + xStartIdx, in_end: args.end()); |
| 827 | createSwap(builder, loc, args: swapOperands, xPerm, ny); |
| 828 | start = childIdx; |
| 829 | Value cond2 = |
| 830 | builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::uge, t, child); |
| 831 | scf::IfOp if2 = builder.create<scf::IfOp>( |
| 832 | loc, TypeRange{child.getType(), child.getType()}, cond2, /*else=*/true); |
| 833 | builder.setInsertionPointToStart(&if2.getThenRegion().front()); |
| 834 | auto [newChild, newChildIdx] = getLargerChild(child); |
| 835 | builder.create<scf::YieldOp>(loc, ValueRange{newChild, newChildIdx}); |
| 836 | builder.setInsertionPointToStart(&if2.getElseRegion().front()); |
| 837 | builder.create<scf::YieldOp>(loc, ValueRange{child, childIdx}); |
| 838 | builder.setInsertionPointAfter(if2); |
| 839 | builder.create<scf::YieldOp>( |
| 840 | loc, ValueRange{start, if2.getResult(0), if2.getResult(1)}); |
| 841 | |
| 842 | builder.setInsertionPointAfter(ifN); |
| 843 | builder.create<func::ReturnOp>(loc); |
| 844 | } |
| 845 | |
| 846 | /// Creates a function to perform heap sort on the values in the range of index |
| 847 | /// [lo, hi) with the assumption hi - lo >= 2. |
| 848 | // |
| 849 | // The generate IR corresponds to this C like algorithm: |
| 850 | // void heapSort(lo, hi, data) { |
| 851 | // n = hi - lo |
| 852 | // for i = (n-2)/2 downto 0 |
| 853 | // shiftDown(lo, lo+i, n) |
| 854 | // |
| 855 | // for l = n downto 2 |
| 856 | // swap(lo, lo+l-1) |
| 857 | // shiftdown(lo, lo, l-1) |
| 858 | // } |
| 859 | static void createHeapSortFunc(OpBuilder &builder, ModuleOp module, |
| 860 | func::FuncOp func, AffineMap xPerm, uint64_t ny, |
| 861 | uint32_t nTrailingP) { |
| 862 | // Heap sort function doesn't have trailing parameters. |
| 863 | (void)nTrailingP; |
| 864 | assert(nTrailingP == 0); |
| 865 | OpBuilder::InsertionGuard insertionGuard(builder); |
| 866 | Block *entryBlock = func.addEntryBlock(); |
| 867 | builder.setInsertionPointToStart(entryBlock); |
| 868 | |
| 869 | Location loc = func.getLoc(); |
| 870 | ValueRange args = entryBlock->getArguments(); |
| 871 | Value lo = args[loIdx]; |
| 872 | Value hi = args[hiIdx]; |
| 873 | Value n = builder.create<arith::SubIOp>(loc, hi, lo); |
| 874 | |
| 875 | // For i = (n-2)/2 downto 0. |
| 876 | Value c0 = constantIndex(builder, loc, i: 0); |
| 877 | Value c1 = constantIndex(builder, loc, i: 1); |
| 878 | Value s = createSubTwoDividedByTwo(builder, loc, n); |
| 879 | Value up = builder.create<arith::AddIOp>(loc, s, c1); |
| 880 | scf::ForOp forI = builder.create<scf::ForOp>(loc, c0, up, c1); |
| 881 | builder.setInsertionPointToStart(forI.getBody()); |
| 882 | Value i = builder.create<arith::SubIOp>(loc, s, forI.getInductionVar()); |
| 883 | Value lopi = builder.create<arith::AddIOp>(loc, lo, i); |
| 884 | SmallVector<Value> shiftDownOperands = {lo, lopi}; |
| 885 | shiftDownOperands.append(in_start: args.begin() + xStartIdx, in_end: args.end()); |
| 886 | shiftDownOperands.push_back(Elt: n); |
| 887 | FlatSymbolRefAttr shiftDownFunc = getMangledSortHelperFunc( |
| 888 | builder, func, TypeRange(), kShiftDownFuncNamePrefix, xPerm, ny, |
| 889 | shiftDownOperands, createShiftDownFunc, /*nTrailingP=*/1); |
| 890 | builder.create<func::CallOp>(loc, shiftDownFunc, TypeRange(), |
| 891 | shiftDownOperands); |
| 892 | |
| 893 | builder.setInsertionPointAfter(forI); |
| 894 | // For l = n downto 2. |
| 895 | up = builder.create<arith::SubIOp>(loc, n, c1); |
| 896 | scf::ForOp forL = builder.create<scf::ForOp>(loc, c0, up, c1); |
| 897 | builder.setInsertionPointToStart(forL.getBody()); |
| 898 | Value l = builder.create<arith::SubIOp>(loc, n, forL.getInductionVar()); |
| 899 | Value loplm1 = builder.create<arith::AddIOp>(loc, lo, l); |
| 900 | loplm1 = builder.create<arith::SubIOp>(loc, loplm1, c1); |
| 901 | SmallVector<Value> swapOperands{lo, loplm1}; |
| 902 | swapOperands.append(in_start: args.begin() + xStartIdx, in_end: args.end()); |
| 903 | createSwap(builder, loc, args: swapOperands, xPerm, ny); |
| 904 | shiftDownOperands[1] = lo; |
| 905 | shiftDownOperands[shiftDownOperands.size() - 1] = |
| 906 | builder.create<arith::SubIOp>(loc, l, c1); |
| 907 | builder.create<func::CallOp>(loc, shiftDownFunc, TypeRange(), |
| 908 | shiftDownOperands); |
| 909 | |
| 910 | builder.setInsertionPointAfter(forL); |
| 911 | builder.create<func::ReturnOp>(loc); |
| 912 | } |
| 913 | |
| 914 | /// A helper for generating code to perform quick sort. It partitions [lo, hi), |
| 915 | /// recursively calls quick sort to process the smaller partition and returns |
| 916 | /// the bigger partition to be processed by the enclosed while-loop. |
| 917 | static std::pair<Value, Value> |
| 918 | createQuickSort(OpBuilder &builder, ModuleOp module, func::FuncOp func, |
| 919 | ValueRange args, AffineMap xPerm, uint64_t ny, |
| 920 | uint32_t nTrailingP) { |
| 921 | MLIRContext *context = module.getContext(); |
| 922 | Location loc = func.getLoc(); |
| 923 | Value lo = args[loIdx]; |
| 924 | Value hi = args[hiIdx]; |
| 925 | SmallVector<Type, 2> types(2, lo.getType()); // Only two types. |
| 926 | |
| 927 | FlatSymbolRefAttr partitionFunc = getMangledSortHelperFunc( |
| 928 | builder, func, {IndexType::get(context)}, kPartitionFuncNamePrefix, xPerm, |
| 929 | ny, args.drop_back(nTrailingP), createPartitionFunc); |
| 930 | Value p = builder |
| 931 | .create<func::CallOp>(loc, partitionFunc, |
| 932 | TypeRange{IndexType::get(context)}, |
| 933 | args.drop_back(nTrailingP)) |
| 934 | .getResult(0); |
| 935 | |
| 936 | Value lenLow = builder.create<arith::SubIOp>(loc, p, lo); |
| 937 | Value lenHigh = builder.create<arith::SubIOp>(loc, hi, p); |
| 938 | // Partition already sorts array with len <= 2 |
| 939 | Value c2 = constantIndex(builder, loc, i: 2); |
| 940 | Value len = builder.create<arith::SubIOp>(loc, hi, lo); |
| 941 | Value lenGtTwo = |
| 942 | builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ugt, len, c2); |
| 943 | scf::IfOp ifLenGtTwo = |
| 944 | builder.create<scf::IfOp>(loc, types, lenGtTwo, /*else=*/true); |
| 945 | builder.setInsertionPointToStart(&ifLenGtTwo.getElseRegion().front()); |
| 946 | // Returns an empty range to mark the entire region is fully sorted. |
| 947 | builder.create<scf::YieldOp>(loc, ValueRange{lo, lo}); |
| 948 | |
| 949 | // Else len > 2, need recursion. |
| 950 | builder.setInsertionPointToStart(&ifLenGtTwo.getThenRegion().front()); |
| 951 | Value cond = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ule, |
| 952 | lenLow, lenHigh); |
| 953 | |
| 954 | Value c0 = constantIndex(builder, loc, i: 0); |
| 955 | scf::IfOp ifOp = builder.create<scf::IfOp>(loc, types, cond, /*else=*/true); |
| 956 | |
| 957 | auto mayRecursion = [&](Value low, Value high, Value len) { |
| 958 | Value cond = |
| 959 | builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne, len, c0); |
| 960 | scf::IfOp ifOp = builder.create<scf::IfOp>(loc, cond, /*else=*/false); |
| 961 | builder.setInsertionPointToStart(&ifOp.getThenRegion().front()); |
| 962 | SmallVector<Value> operands{low, high}; |
| 963 | operands.append(in_start: args.begin() + xStartIdx, in_end: args.end()); |
| 964 | builder.create<func::CallOp>(loc, func, operands); |
| 965 | builder.setInsertionPointAfter(ifOp); |
| 966 | }; |
| 967 | |
| 968 | // Recursively call quickSort to process the smaller partition and return |
| 969 | // the bigger partition to be processed by the enclosed while-loop. |
| 970 | builder.setInsertionPointToStart(&ifOp.getThenRegion().front()); |
| 971 | mayRecursion(lo, p, lenLow); |
| 972 | builder.create<scf::YieldOp>(loc, ValueRange{p, hi}); |
| 973 | |
| 974 | builder.setInsertionPointToStart(&ifOp.getElseRegion().front()); |
| 975 | mayRecursion(p, hi, lenHigh); |
| 976 | builder.create<scf::YieldOp>(loc, ValueRange{lo, p}); |
| 977 | |
| 978 | builder.setInsertionPointAfter(ifOp); |
| 979 | builder.create<scf::YieldOp>(loc, ifOp.getResults()); |
| 980 | |
| 981 | builder.setInsertionPointAfter(ifLenGtTwo); |
| 982 | return std::make_pair(ifLenGtTwo.getResult(0), ifLenGtTwo.getResult(1)); |
| 983 | } |
| 984 | |
| 985 | /// Creates a function to perform insertion sort on the values in the range of |
| 986 | /// index [lo, hi). |
| 987 | // |
| 988 | // The generate IR corresponds to this C like algorithm: |
| 989 | // void insertionSort(lo, hi, data) { |
| 990 | // for (i = lo+1; i < hi; i++) { |
| 991 | // d = data[i]; |
| 992 | // p = binarySearch(lo, i-1, data) |
| 993 | // for (j = 0; j > i - p; j++) |
| 994 | // data[i-j] = data[i-j-1] |
| 995 | // data[p] = d |
| 996 | // } |
| 997 | // } |
| 998 | static void createSortStableFunc(OpBuilder &builder, ModuleOp module, |
| 999 | func::FuncOp func, AffineMap xPerm, |
| 1000 | uint64_t ny, uint32_t nTrailingP) { |
| 1001 | // Stable sort function doesn't use trailing parameters. |
| 1002 | (void)nTrailingP; |
| 1003 | assert(nTrailingP == 0); |
| 1004 | OpBuilder::InsertionGuard insertionGuard(builder); |
| 1005 | Block *entryBlock = func.addEntryBlock(); |
| 1006 | builder.setInsertionPointToStart(entryBlock); |
| 1007 | |
| 1008 | MLIRContext *context = module.getContext(); |
| 1009 | Location loc = func.getLoc(); |
| 1010 | ValueRange args = entryBlock->getArguments(); |
| 1011 | Value c1 = constantIndex(builder, loc, i: 1); |
| 1012 | Value lo = args[loIdx]; |
| 1013 | Value hi = args[hiIdx]; |
| 1014 | Value lop1 = builder.create<arith::AddIOp>(loc, lo, c1); |
| 1015 | |
| 1016 | // Start the outer for-stmt with induction variable i. |
| 1017 | scf::ForOp forOpI = builder.create<scf::ForOp>(loc, lop1, hi, c1); |
| 1018 | builder.setInsertionPointToStart(forOpI.getBody()); |
| 1019 | Value i = forOpI.getInductionVar(); |
| 1020 | |
| 1021 | // Binary search to find the insertion point p. |
| 1022 | SmallVector<Value> operands{lo, i}; |
| 1023 | operands.append(in_start: args.begin() + xStartIdx, in_end: args.end()); |
| 1024 | FlatSymbolRefAttr searchFunc = getMangledSortHelperFunc( |
| 1025 | builder, func, {IndexType::get(context)}, kBinarySearchFuncNamePrefix, |
| 1026 | xPerm, ny, operands, createBinarySearchFunc); |
| 1027 | Value p = builder |
| 1028 | .create<func::CallOp>(loc, searchFunc, TypeRange{c1.getType()}, |
| 1029 | operands) |
| 1030 | .getResult(0); |
| 1031 | |
| 1032 | // Move the value at data[i] to a temporary location. |
| 1033 | operands[0] = operands[1] = i; |
| 1034 | SmallVector<Value> d; |
| 1035 | forEachIJPairInAllBuffers( |
| 1036 | builder, loc, args: operands, xPerm, ny, |
| 1037 | bodyBuilder: [&](uint64_t unused, Value i, Value unused2, Value buffer) { |
| 1038 | d.push_back(builder.create<memref::LoadOp>(loc, buffer, i)); |
| 1039 | }); |
| 1040 | |
| 1041 | // Start the inner for-stmt with induction variable j, for moving data[p..i) |
| 1042 | // to data[p+1..i+1). |
| 1043 | Value imp = builder.create<arith::SubIOp>(loc, i, p); |
| 1044 | Value c0 = constantIndex(builder, loc, i: 0); |
| 1045 | scf::ForOp forOpJ = builder.create<scf::ForOp>(loc, c0, imp, c1); |
| 1046 | builder.setInsertionPointToStart(forOpJ.getBody()); |
| 1047 | Value j = forOpJ.getInductionVar(); |
| 1048 | Value imj = builder.create<arith::SubIOp>(loc, i, j); |
| 1049 | operands[1] = imj; |
| 1050 | operands[0] = builder.create<arith::SubIOp>(loc, imj, c1); |
| 1051 | forEachIJPairInAllBuffers( |
| 1052 | builder, loc, args: operands, xPerm, ny, |
| 1053 | bodyBuilder: [&](uint64_t unused, Value imjm1, Value imj, Value buffer) { |
| 1054 | Value t = builder.create<memref::LoadOp>(loc, buffer, imjm1); |
| 1055 | builder.create<memref::StoreOp>(loc, t, buffer, imj); |
| 1056 | }); |
| 1057 | |
| 1058 | // Store the value at data[i] to data[p]. |
| 1059 | builder.setInsertionPointAfter(forOpJ); |
| 1060 | operands[0] = operands[1] = p; |
| 1061 | forEachIJPairInAllBuffers( |
| 1062 | builder, loc, args: operands, xPerm, ny, |
| 1063 | bodyBuilder: [&](uint64_t k, Value p, Value usused, Value buffer) { |
| 1064 | builder.create<memref::StoreOp>(loc, d[k], buffer, p); |
| 1065 | }); |
| 1066 | |
| 1067 | builder.setInsertionPointAfter(forOpI); |
| 1068 | builder.create<func::ReturnOp>(loc); |
| 1069 | } |
| 1070 | |
| 1071 | /// Creates a function to perform quick sort or a hybrid quick sort on the |
| 1072 | /// values in the range of index [lo, hi). |
| 1073 | // |
| 1074 | // |
| 1075 | // When nTrailingP == 0, the generated IR corresponds to this C like algorithm: |
| 1076 | // void quickSort(lo, hi, data) { |
| 1077 | // while (lo + 1 < hi) { |
| 1078 | // p = partition(low, high, data); |
| 1079 | // if (len(lo, p) < len(p+1, hi)) { |
| 1080 | // quickSort(lo, p, data); |
| 1081 | // lo = p+1; |
| 1082 | // } else { |
| 1083 | // quickSort(p + 1, hi, data); |
| 1084 | // hi = p; |
| 1085 | // } |
| 1086 | // } |
| 1087 | // } |
| 1088 | // |
| 1089 | // When nTrailingP == 1, the generated IR corresponds to this C like algorithm: |
| 1090 | // void hybridQuickSort(lo, hi, data, depthLimit) { |
| 1091 | // while (lo + 1 < hi) { |
| 1092 | // len = hi - lo; |
| 1093 | // if (len <= limit) { |
| 1094 | // insertionSort(lo, hi, data); |
| 1095 | // } else { |
| 1096 | // depthLimit --; |
| 1097 | // if (depthLimit <= 0) { |
| 1098 | // heapSort(lo, hi, data); |
| 1099 | // } else { |
| 1100 | // p = partition(low, high, data); |
| 1101 | // if (len(lo, p) < len(p+1, hi)) { |
| 1102 | // quickSort(lo, p, data, depthLimit); |
| 1103 | // lo = p+1; |
| 1104 | // } else { |
| 1105 | // quickSort(p + 1, hi, data, depthLimit); |
| 1106 | // hi = p; |
| 1107 | // } |
| 1108 | // } |
| 1109 | // } |
| 1110 | // } |
| 1111 | // } |
| 1112 | // |
| 1113 | static void createQuickSortFunc(OpBuilder &builder, ModuleOp module, |
| 1114 | func::FuncOp func, AffineMap xPerm, uint64_t ny, |
| 1115 | uint32_t nTrailingP) { |
| 1116 | assert(nTrailingP == 1 || nTrailingP == 0); |
| 1117 | bool isHybrid = (nTrailingP == 1); |
| 1118 | OpBuilder::InsertionGuard insertionGuard(builder); |
| 1119 | Block *entryBlock = func.addEntryBlock(); |
| 1120 | builder.setInsertionPointToStart(entryBlock); |
| 1121 | |
| 1122 | Location loc = func.getLoc(); |
| 1123 | SmallVector<Value> args; |
| 1124 | args.append(in_start: entryBlock->getArguments().begin(), |
| 1125 | in_end: entryBlock->getArguments().end()); |
| 1126 | Value lo = args[loIdx]; |
| 1127 | Value hi = args[hiIdx]; |
| 1128 | SmallVector<Type, 2> types(2, lo.getType()); // Only two types. |
| 1129 | scf::WhileOp whileOp = |
| 1130 | builder.create<scf::WhileOp>(loc, types, SmallVector<Value, 2>{lo, hi}); |
| 1131 | |
| 1132 | // The before-region of the WhileOp. |
| 1133 | Block *before = |
| 1134 | builder.createBlock(&whileOp.getBefore(), {}, types, {loc, loc}); |
| 1135 | builder.setInsertionPointToEnd(before); |
| 1136 | lo = before->getArgument(i: 0); |
| 1137 | hi = before->getArgument(i: 1); |
| 1138 | Value loP1 = |
| 1139 | builder.create<arith::AddIOp>(loc, lo, constantIndex(builder, loc, 1)); |
| 1140 | Value needSort = |
| 1141 | builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, loP1, hi); |
| 1142 | builder.create<scf::ConditionOp>(loc, needSort, before->getArguments()); |
| 1143 | |
| 1144 | // The after-region of the WhileOp. |
| 1145 | Block *after = |
| 1146 | builder.createBlock(&whileOp.getAfter(), {}, types, {loc, loc}); |
| 1147 | builder.setInsertionPointToEnd(after); |
| 1148 | lo = after->getArgument(i: 0); |
| 1149 | hi = after->getArgument(i: 1); |
| 1150 | args[0] = lo; |
| 1151 | args[1] = hi; |
| 1152 | |
| 1153 | if (isHybrid) { |
| 1154 | Value len = builder.create<arith::SubIOp>(loc, hi, lo); |
| 1155 | Value lenLimit = constantIndex(builder, loc, i: 30); |
| 1156 | Value lenCond = builder.create<arith::CmpIOp>( |
| 1157 | loc, arith::CmpIPredicate::ule, len, lenLimit); |
| 1158 | scf::IfOp lenIf = |
| 1159 | builder.create<scf::IfOp>(loc, types, lenCond, /*else=*/true); |
| 1160 | |
| 1161 | // When len <= limit. |
| 1162 | builder.setInsertionPointToStart(&lenIf.getThenRegion().front()); |
| 1163 | FlatSymbolRefAttr insertionSortFunc = getMangledSortHelperFunc( |
| 1164 | builder, func, TypeRange(), kSortStableFuncNamePrefix, xPerm, ny, |
| 1165 | ValueRange(args).drop_back(n: nTrailingP), createSortStableFunc); |
| 1166 | builder.create<func::CallOp>(loc, insertionSortFunc, TypeRange(), |
| 1167 | ValueRange(args).drop_back(nTrailingP)); |
| 1168 | builder.create<scf::YieldOp>(loc, ValueRange{lo, lo}); |
| 1169 | |
| 1170 | // When len > limit. |
| 1171 | builder.setInsertionPointToStart(&lenIf.getElseRegion().front()); |
| 1172 | Value depthLimit = args.back(); |
| 1173 | depthLimit = builder.create<arith::SubIOp>(loc, depthLimit, |
| 1174 | constantI64(builder, loc, 1)); |
| 1175 | Value depthCond = |
| 1176 | builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ule, |
| 1177 | depthLimit, constantI64(builder, loc, 0)); |
| 1178 | scf::IfOp depthIf = |
| 1179 | builder.create<scf::IfOp>(loc, types, depthCond, /*else=*/true); |
| 1180 | |
| 1181 | // When depth exceeds limit. |
| 1182 | builder.setInsertionPointToStart(&depthIf.getThenRegion().front()); |
| 1183 | FlatSymbolRefAttr heapSortFunc = getMangledSortHelperFunc( |
| 1184 | builder, func, TypeRange(), kHeapSortFuncNamePrefix, xPerm, ny, |
| 1185 | ValueRange(args).drop_back(n: nTrailingP), createHeapSortFunc); |
| 1186 | builder.create<func::CallOp>(loc, heapSortFunc, TypeRange(), |
| 1187 | ValueRange(args).drop_back(nTrailingP)); |
| 1188 | builder.create<scf::YieldOp>(loc, ValueRange{lo, lo}); |
| 1189 | |
| 1190 | // When depth doesn't exceed limit. |
| 1191 | builder.setInsertionPointToStart(&depthIf.getElseRegion().front()); |
| 1192 | args.back() = depthLimit; |
| 1193 | std::tie(args&: lo, args&: hi) = |
| 1194 | createQuickSort(builder, module, func, args, xPerm, ny, nTrailingP); |
| 1195 | builder.create<scf::YieldOp>(loc, ValueRange{lo, hi}); |
| 1196 | |
| 1197 | builder.setInsertionPointAfter(depthIf); |
| 1198 | lo = depthIf.getResult(0); |
| 1199 | hi = depthIf.getResult(1); |
| 1200 | builder.create<scf::YieldOp>(loc, ValueRange{lo, hi}); |
| 1201 | |
| 1202 | builder.setInsertionPointAfter(lenIf); |
| 1203 | lo = lenIf.getResult(0); |
| 1204 | hi = lenIf.getResult(1); |
| 1205 | } else { |
| 1206 | std::tie(args&: lo, args&: hi) = |
| 1207 | createQuickSort(builder, module, func, args, xPerm, ny, nTrailingP); |
| 1208 | } |
| 1209 | |
| 1210 | // New [lo, hi) for the next while-loop iteration. |
| 1211 | builder.create<scf::YieldOp>(loc, ValueRange{lo, hi}); |
| 1212 | |
| 1213 | // After the while-loop. |
| 1214 | builder.setInsertionPointAfter(whileOp); |
| 1215 | builder.create<func::ReturnOp>(loc); |
| 1216 | } |
| 1217 | |
| 1218 | /// Implements the rewriting for operator sort and sort_coo. |
| 1219 | template <typename OpTy> |
| 1220 | LogicalResult matchAndRewriteSortOp(OpTy op, ValueRange xys, AffineMap xPerm, |
| 1221 | uint64_t ny, PatternRewriter &rewriter) { |
| 1222 | Location loc = op.getLoc(); |
| 1223 | SmallVector<Value> operands{constantIndex(builder&: rewriter, loc, i: 0), op.getN()}; |
| 1224 | |
| 1225 | // Convert `values` to have dynamic shape and append them to `operands`. |
| 1226 | for (Value v : xys) { |
| 1227 | auto mtp = getMemRefType(v); |
| 1228 | if (!mtp.isDynamicDim(0)) { |
| 1229 | auto newMtp = |
| 1230 | MemRefType::get({ShapedType::kDynamic}, mtp.getElementType()); |
| 1231 | v = rewriter.create<memref::CastOp>(loc, newMtp, v); |
| 1232 | } |
| 1233 | operands.push_back(Elt: v); |
| 1234 | } |
| 1235 | |
| 1236 | auto insertPoint = op->template getParentOfType<func::FuncOp>(); |
| 1237 | if (!insertPoint) |
| 1238 | return failure(); |
| 1239 | |
| 1240 | SmallString<32> funcName; |
| 1241 | FuncGeneratorType funcGenerator; |
| 1242 | uint32_t nTrailingP = 0; |
| 1243 | switch (op.getAlgorithm()) { |
| 1244 | case SparseTensorSortKind::HybridQuickSort: { |
| 1245 | funcName = kHybridQuickSortFuncNamePrefix; |
| 1246 | funcGenerator = createQuickSortFunc; |
| 1247 | nTrailingP = 1; |
| 1248 | // As a heuristics, set depthLimit = 2 * log2(n). |
| 1249 | Value lo = operands[loIdx]; |
| 1250 | Value hi = operands[hiIdx]; |
| 1251 | Value len = rewriter.create<arith::IndexCastOp>( |
| 1252 | loc, rewriter.getI64Type(), |
| 1253 | rewriter.create<arith::SubIOp>(loc, hi, lo)); |
| 1254 | Value depthLimit = rewriter.create<arith::SubIOp>( |
| 1255 | loc, constantI64(rewriter, loc, 64), |
| 1256 | rewriter.create<math::CountLeadingZerosOp>(loc, len)); |
| 1257 | operands.push_back(Elt: depthLimit); |
| 1258 | break; |
| 1259 | } |
| 1260 | case SparseTensorSortKind::QuickSort: |
| 1261 | funcName = kQuickSortFuncNamePrefix; |
| 1262 | funcGenerator = createQuickSortFunc; |
| 1263 | break; |
| 1264 | case SparseTensorSortKind::InsertionSortStable: |
| 1265 | funcName = kSortStableFuncNamePrefix; |
| 1266 | funcGenerator = createSortStableFunc; |
| 1267 | break; |
| 1268 | case SparseTensorSortKind::HeapSort: |
| 1269 | funcName = kHeapSortFuncNamePrefix; |
| 1270 | funcGenerator = createHeapSortFunc; |
| 1271 | break; |
| 1272 | } |
| 1273 | |
| 1274 | FlatSymbolRefAttr func = |
| 1275 | getMangledSortHelperFunc(rewriter, insertPoint, TypeRange(), funcName, |
| 1276 | xPerm, ny, operands, funcGenerator, nTrailingP); |
| 1277 | rewriter.replaceOpWithNewOp<func::CallOp>(op, func, TypeRange(), operands); |
| 1278 | return success(); |
| 1279 | } |
| 1280 | |
| 1281 | //===---------------------------------------------------------------------===// |
| 1282 | // The actual sparse buffer rewriting rules. |
| 1283 | //===---------------------------------------------------------------------===// |
| 1284 | |
| 1285 | namespace { |
| 1286 | /// Sparse rewriting rule for the push_back operator. |
| 1287 | struct PushBackRewriter : OpRewritePattern<PushBackOp> { |
| 1288 | public: |
| 1289 | using OpRewritePattern<PushBackOp>::OpRewritePattern; |
| 1290 | PushBackRewriter(MLIRContext *context, bool enableInit) |
| 1291 | : OpRewritePattern(context), enableBufferInitialization(enableInit) {} |
| 1292 | LogicalResult matchAndRewrite(PushBackOp op, |
| 1293 | PatternRewriter &rewriter) const override { |
| 1294 | // Rewrite push_back(buffer, value, n) to: |
| 1295 | // new_size = size(buffer) + n |
| 1296 | // if (new_size > capacity(buffer)) |
| 1297 | // while new_size > new_capacity |
| 1298 | // new_capacity = new_capacity*2 |
| 1299 | // new_buffer = realloc(buffer, new_capacity) |
| 1300 | // buffer = new_buffer |
| 1301 | // subBuffer = subviewof(buffer) |
| 1302 | // linalg.fill subBuffer value |
| 1303 | // |
| 1304 | // size(buffer) += n |
| 1305 | // |
| 1306 | // The capacity check is skipped when the attribute inbounds is presented. |
| 1307 | Location loc = op->getLoc(); |
| 1308 | Value c0 = constantIndex(builder&: rewriter, loc, i: 0); |
| 1309 | Value buffer = op.getInBuffer(); |
| 1310 | Value capacity = rewriter.create<memref::DimOp>(loc, buffer, c0); |
| 1311 | Value size = op.getCurSize(); |
| 1312 | Value value = op.getValue(); |
| 1313 | |
| 1314 | Value n = op.getN() ? op.getN() : constantIndex(builder&: rewriter, loc, i: 1); |
| 1315 | Value newSize = rewriter.create<arith::AddIOp>(loc, size, n); |
| 1316 | auto nValue = dyn_cast_or_null<arith::ConstantIndexOp>(Val: n.getDefiningOp()); |
| 1317 | bool nIsOne = (nValue && nValue.value() == 1); |
| 1318 | |
| 1319 | if (!op.getInbounds()) { |
| 1320 | Value cond = rewriter.create<arith::CmpIOp>( |
| 1321 | loc, arith::CmpIPredicate::ugt, newSize, capacity); |
| 1322 | |
| 1323 | Value c2 = constantIndex(builder&: rewriter, loc, i: 2); |
| 1324 | auto bufferType = |
| 1325 | MemRefType::get({ShapedType::kDynamic}, value.getType()); |
| 1326 | scf::IfOp ifOp = rewriter.create<scf::IfOp>(loc, bufferType, cond, |
| 1327 | /*else=*/true); |
| 1328 | // True branch. |
| 1329 | rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front()); |
| 1330 | if (nIsOne) { |
| 1331 | capacity = rewriter.create<arith::MulIOp>(loc, capacity, c2); |
| 1332 | } else { |
| 1333 | // Use a do-while loop to calculate the new capacity as follows: |
| 1334 | // do { new_capacity *= 2 } while (size > new_capacity) |
| 1335 | scf::WhileOp whileOp = |
| 1336 | rewriter.create<scf::WhileOp>(loc, capacity.getType(), capacity); |
| 1337 | |
| 1338 | // The before-region of the WhileOp. |
| 1339 | Block *before = rewriter.createBlock(&whileOp.getBefore(), {}, |
| 1340 | {capacity.getType()}, {loc}); |
| 1341 | rewriter.setInsertionPointToEnd(before); |
| 1342 | |
| 1343 | capacity = |
| 1344 | rewriter.create<arith::MulIOp>(loc, before->getArgument(0), c2); |
| 1345 | cond = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ugt, |
| 1346 | newSize, capacity); |
| 1347 | rewriter.create<scf::ConditionOp>(loc, cond, ValueRange{capacity}); |
| 1348 | // The after-region of the WhileOp. |
| 1349 | Block *after = rewriter.createBlock(&whileOp.getAfter(), {}, |
| 1350 | {capacity.getType()}, {loc}); |
| 1351 | rewriter.setInsertionPointToEnd(after); |
| 1352 | rewriter.create<scf::YieldOp>(loc, after->getArguments()); |
| 1353 | |
| 1354 | rewriter.setInsertionPointAfter(whileOp); |
| 1355 | capacity = whileOp.getResult(0); |
| 1356 | } |
| 1357 | |
| 1358 | Value newBuffer = |
| 1359 | rewriter.create<memref::ReallocOp>(loc, bufferType, buffer, capacity); |
| 1360 | if (enableBufferInitialization) { |
| 1361 | Value fillSize = rewriter.create<arith::SubIOp>(loc, capacity, newSize); |
| 1362 | Value fillValue = constantZero(builder&: rewriter, loc, tp: value.getType()); |
| 1363 | Value subBuffer = rewriter.create<memref::SubViewOp>( |
| 1364 | loc, newBuffer, /*offset=*/ValueRange{newSize}, |
| 1365 | /*size=*/ValueRange{fillSize}, |
| 1366 | /*step=*/ValueRange{constantIndex(rewriter, loc, 1)}); |
| 1367 | rewriter.create<linalg::FillOp>(loc, fillValue, subBuffer); |
| 1368 | } |
| 1369 | rewriter.create<scf::YieldOp>(loc, newBuffer); |
| 1370 | |
| 1371 | // False branch. |
| 1372 | rewriter.setInsertionPointToStart(&ifOp.getElseRegion().front()); |
| 1373 | rewriter.create<scf::YieldOp>(loc, buffer); |
| 1374 | |
| 1375 | // Prepare for adding the value to the end of the buffer. |
| 1376 | rewriter.setInsertionPointAfter(ifOp); |
| 1377 | buffer = ifOp.getResult(0); |
| 1378 | } |
| 1379 | |
| 1380 | // Add the value to the end of the buffer. |
| 1381 | if (nIsOne) { |
| 1382 | rewriter.create<memref::StoreOp>(loc, value, buffer, size); |
| 1383 | } else { |
| 1384 | Value subBuffer = rewriter.create<memref::SubViewOp>( |
| 1385 | loc, buffer, /*offset=*/ValueRange{size}, /*size=*/ValueRange{n}, |
| 1386 | /*step=*/ValueRange{constantIndex(rewriter, loc, 1)}); |
| 1387 | rewriter.create<linalg::FillOp>(loc, value, subBuffer); |
| 1388 | } |
| 1389 | |
| 1390 | // Update the buffer size. |
| 1391 | rewriter.replaceOp(op, {buffer, newSize}); |
| 1392 | return success(); |
| 1393 | } |
| 1394 | |
| 1395 | private: |
| 1396 | bool enableBufferInitialization; |
| 1397 | }; |
| 1398 | |
| 1399 | /// Sparse rewriting rule for the sort_coo operator. |
| 1400 | struct SortRewriter : public OpRewritePattern<SortOp> { |
| 1401 | public: |
| 1402 | using OpRewritePattern<SortOp>::OpRewritePattern; |
| 1403 | |
| 1404 | LogicalResult matchAndRewrite(SortOp op, |
| 1405 | PatternRewriter &rewriter) const override { |
| 1406 | SmallVector<Value> xys; |
| 1407 | xys.push_back(Elt: op.getXy()); |
| 1408 | xys.append(op.getYs().begin(), op.getYs().end()); |
| 1409 | |
| 1410 | auto xPerm = op.getPermMap(); |
| 1411 | uint64_t ny = 0; |
| 1412 | if (auto nyAttr = op.getNyAttr()) |
| 1413 | ny = nyAttr.getInt(); |
| 1414 | |
| 1415 | return matchAndRewriteSortOp(op, xys, xPerm, ny, rewriter); |
| 1416 | } |
| 1417 | }; |
| 1418 | |
| 1419 | } // namespace |
| 1420 | |
| 1421 | //===---------------------------------------------------------------------===// |
| 1422 | // Methods that add patterns described in this file to a pattern list. |
| 1423 | //===---------------------------------------------------------------------===// |
| 1424 | |
| 1425 | void mlir::populateSparseBufferRewriting(RewritePatternSet &patterns, |
| 1426 | bool enableBufferInitialization) { |
| 1427 | patterns.add<PushBackRewriter>(arg: patterns.getContext(), |
| 1428 | args&: enableBufferInitialization); |
| 1429 | patterns.add<SortRewriter>(arg: patterns.getContext()); |
| 1430 | } |
| 1431 | |