| 1 | //===- Sparsification.cpp - Implementation of sparsification --------------===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | // |
| 9 | // This file implements converting sparse tensor types to actual sparse code. |
| 10 | // |
| 11 | //===----------------------------------------------------------------------===// |
| 12 | |
| 13 | #include "Utils/CodegenEnv.h" |
| 14 | #include "Utils/CodegenUtils.h" |
| 15 | #include "Utils/LoopEmitter.h" |
| 16 | |
| 17 | #include "mlir/Dialect/Affine/IR/AffineOps.h" |
| 18 | #include "mlir/Dialect/Arith/IR/Arith.h" |
| 19 | #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" |
| 20 | #include "mlir/Dialect/Bufferization/IR/Bufferization.h" |
| 21 | #include "mlir/Dialect/Func/IR/FuncOps.h" |
| 22 | #include "mlir/Dialect/LLVMIR/LLVMDialect.h" |
| 23 | #include "mlir/Dialect/Linalg/IR/Linalg.h" |
| 24 | #include "mlir/Dialect/Linalg/Utils/Utils.h" |
| 25 | #include "mlir/Dialect/MemRef/IR/MemRef.h" |
| 26 | #include "mlir/Dialect/SCF/IR/SCF.h" |
| 27 | #include "mlir/Dialect/SCF/Transforms/Transforms.h" |
| 28 | #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" |
| 29 | #include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h" |
| 30 | #include "mlir/Dialect/SparseTensor/Transforms/Passes.h" |
| 31 | #include "mlir/Dialect/SparseTensor/Utils/Merger.h" |
| 32 | #include "mlir/Dialect/Tensor/IR/Tensor.h" |
| 33 | #include "mlir/IR/AffineExprVisitor.h" |
| 34 | #include "mlir/IR/Matchers.h" |
| 35 | #include "mlir/IR/TensorEncoding.h" |
| 36 | #include "llvm/ADT/SmallBitVector.h" |
| 37 | |
| 38 | #include <optional> |
| 39 | |
| 40 | using namespace mlir; |
| 41 | using namespace mlir::sparse_tensor; |
| 42 | |
| 43 | //===----------------------------------------------------------------------===// |
| 44 | // Sparsifier analysis methods. |
| 45 | //===----------------------------------------------------------------------===// |
| 46 | |
| 47 | /// Returns true iff the affine expression is invariant. Sets the |
| 48 | /// parameter `isCurrentLoop` when the expression just became invariant. |
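| | /// For example (illustrative only): with curr == 2, the expression d0 + d1 |
| | /// is invariant, since d0 was generated by an enclosing loop (0 < 2) and |
| | /// d1 + 1 == curr; the latter also sets `isCurrentLoop` to true. |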
| 49 | static bool isInvariantAffine(AffineExpr a, LoopId curr, bool &isCurrentLoop) { |
| 50 | switch (a.getKind()) { |
| 51 | case AffineExprKind::DimId: { |
| 52 | const LoopId i = cast<AffineDimExpr>(a).getPosition(); |
| 53 | if (i + 1 == curr) { |
| 54 | isCurrentLoop = true; |
| 55 | return true; // becomes invariant at current loop |
| 56 | } |
| 57 | return i < curr; // invariant when already generated |
| 58 | } |
| 59 | case AffineExprKind::Add: |
| 60 | case AffineExprKind::Mul: { |
| 61 | auto binOp = cast<AffineBinaryOpExpr>(a); |
| 62 | return isInvariantAffine(binOp.getLHS(), curr, isCurrentLoop) && |
| 63 | isInvariantAffine(binOp.getRHS(), curr, isCurrentLoop); |
| 64 | } |
| 65 | default: { |
| 66 | assert(isa<AffineConstantExpr>(a)); |
| 67 | return true; |
| 68 | } |
| 69 | } |
| 70 | } |
| 71 | |
| 72 | /// Helper method to inspect affine expressions. Rejects cases where the |
| 73 | /// same index is used more than once. Also rejects compound affine |
| 74 | /// expressions in sparse dimensions. |
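| | /// For example (illustrative only): a map (d0, d1) -> (d0 : compressed, d0 : dense) |
| | /// is rejected because d0 subscripts two levels, whereas a compound expression |
| | /// such as d0 + d1 is only admitted on levels with dense semantics. |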
| 75 | static bool findAffine(Merger &merger, TensorId tid, Level lvl, AffineExpr a, |
| 76 | LevelType lt, bool setLvlFormat = true) { |
| 77 | switch (a.getKind()) { |
| 78 | case AffineExprKind::DimId: { |
| 79 | const LoopId idx = merger.makeLoopId(cast<AffineDimExpr>(a).getPosition()); |
| 80 | if (!isUndefLT(merger.getLvlType(tid, idx))) |
| 81 | return false; // used more than once |
| 82 | if (setLvlFormat) |
| 83 | merger.setLevelAndType(tid, idx, lvl, lt); |
| 84 | return true; |
| 85 | } |
| 86 | case AffineExprKind::Add: |
| 87 | case AffineExprKind::Mul: |
| 88 | case AffineExprKind::Constant: { |
| 89 | assert(lt.hasDenseSemantic()); |
| 90 | if (auto binOp = dyn_cast<AffineBinaryOpExpr>(a)) { |
| 91 | // We do not set the dim level format for an affine expression like |
| 92 | // d0 + d1 on either loop index d0 or d1. We continue the recursion |
| 93 | // merely to check whether the current affine expression is admissible. |
| 94 | return findAffine(merger, tid, lvl, binOp.getLHS(), lt, /*setLvlFormat=*/false) && |
| 95 | findAffine(merger, tid, lvl, binOp.getRHS(), lt, /*setLvlFormat=*/false); |
| 96 | } |
| 97 | // Fall through when it is a constant affine expression. |
| 98 | return true; |
| 99 | } |
| 100 | default: |
| 101 | return false; |
| 102 | } |
| 103 | } |
| 104 | |
| 105 | /// Helper method to inspect affine expressions for index variable reduction |
| 106 | /// based codegen. It finds the dependent index set for all tensor levels in the |
| 107 | /// current expression we are generating. |
| 108 | /// |
| 109 | /// For example, when handling A[i+j][j+k], we build the two way mapping in |
| 110 | /// merger between (tensor, level) pairs and their dependent index variable set: |
| 111 | /// A_0 <=> [i, j] and A_1 <=> [j, k] |
| 112 | /// |
| 113 | /// It rejects cases (returns false) |
| 114 | /// 1st, when the same index is used more than once, e.g., A[i+j][i] |
| 115 | /// 2nd, when multiplication is used in the non-trivial index expression. |
| 116 | /// 3rd, when a constant operand is used in the non-trivial index expression. |
| 117 | /// |
| 118 | /// TODO: constant should be easy to handle. |
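| | /// For example (illustrative only), an index expression `2 * d0 + d1` on some |
| | /// level of a tensor is decomposed by the Add case into `2 * d0` and `d1`, and |
| | /// the Mul case then records d0 with coefficient 2 for that level. |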
| 119 | static bool findDepIdxSet(Merger &merger, TensorId tensor, Level lvl, |
| 120 | AffineExpr a, LevelType lt, bool isSubExp = false, |
| 121 | int64_t coefficient = 1) { |
| 122 | switch (a.getKind()) { |
| 123 | case AffineExprKind::DimId: { |
| 124 | // Only allow positive coefficients on AffineDimExpr. |
| 125 | if (coefficient <= 0) |
| 126 | return false; |
| 127 | |
| 128 | const LoopId idx = merger.makeLoopId(cast<AffineDimExpr>(a).getPosition()); |
| 129 | if (!isUndefLT(merger.getLvlType(tensor, idx))) |
| 130 | return false; // used more than once, e.g., A[i][i] |
| 131 | |
| 132 | // TODO: Generalize the following two cases. A[i] (with trivial index |
| 133 | // expression) can be treated as a special affine index expression. We do |
| 134 | // not necessarily need to differentiate them. |
| 135 | if (!isSubExp) { |
| 136 | assert(coefficient == 1); |
| 137 | merger.setLevelAndType(tensor, idx, lvl, lt); |
| 138 | } |
| 139 | |
| 140 | if (isSubExp) { |
| 141 | // The current loop appears in more than one affine expression on the |
| 142 | // same tensor. We cannot handle this case, e.g., A[i+j][i+k], where |
| 143 | // `i` is used twice. |
| 144 | if (merger.hasDependentLvl(idx, tensor)) { |
| 145 | // TODO: This can be supported by coiterating slices if the loop idx |
| 146 | // appears in affine indices of different tensors, or by taking slices |
| 147 | // along multiple dimensions when it is on the same tensor. |
| 148 | // E.g., |
| 149 | // `d0 + d1` for indexing t0[lvl0] and `d0 + d2` for indexing t1[lvl0] |
| 150 | // d0_1 = getNextSliceOffset t0 along lvl0 |
| 151 | // d0_2 = getNextSliceOffset t1 along lvl0 |
| 152 | // if d0_1 == d0_2 then d0 = d0_1 = d0_2 |
| 153 | // else increase min(d0_1, d0_2). |
| 154 | return false; |
| 155 | } |
| 156 | merger.setLoopDependentTensorLevel(idx, tensor, lvl, lt, coefficient); |
| 157 | } |
| 158 | return true; |
| 159 | } |
| 160 | case AffineExprKind::Constant: |
| 161 | case AffineExprKind::Mul: { |
| 162 | // TODO: Support index expressions like `2 * d0`; currently only the more |
| 163 | // complicated cases like `2 * d0 + d1` are supported. |
| 164 | if (!isSubExp) |
| 165 | return false; |
| 166 | |
| 167 | // TODO: Support Constant AffineExp for slice-based codegen |
| 168 | if (isa<AffineConstantExpr>(a)) |
| 169 | llvm_unreachable("Not yet implemented"); |
| 170 | |
| 171 | auto binOp = cast<AffineBinaryOpExpr>(a); |
| 172 | auto lhs = binOp.getLHS(), rhs = binOp.getRHS(); |
| 173 | if (isa<AffineConstantExpr>(rhs)) |
| 174 | std::swap(lhs, rhs); |
| 175 | // Must be in form of `constant * d`. |
| 176 | assert(isa<AffineConstantExpr>(lhs) && isa<AffineDimExpr>(rhs)); |
| 177 | int64_t coefficient = cast<AffineConstantExpr>(lhs).getValue(); |
| 178 | return findDepIdxSet(merger, tensor, lvl, rhs, lt, isSubExp, coefficient); |
| 179 | } |
| 180 | case AffineExprKind::Add: { |
| 181 | auto binOp = cast<AffineBinaryOpExpr>(Val&: a); |
| 182 | return findDepIdxSet(merger, tensor, lvl, binOp.getLHS(), lt, /*isSubExp=*/true) && |
| 183 | findDepIdxSet(merger, tensor, lvl, binOp.getRHS(), lt, /*isSubExp=*/true); |
| 184 | } |
| 185 | default: |
| 186 | return false; |
| 187 | } |
| 188 | } |
| 189 | |
| 190 | /// Gets the total number of compound affine expressions in the |
| 191 | /// `getMatchingIndexingMap` for the given tensor. For the following inputs: |
| 192 | /// |
| 193 | /// map = (d0, d1, d2) => (d0 + d1 : compressed, d2 : compressed) |
| 194 | /// |
| 195 | /// Returns 1 (because the first level is compressed and its corresponding |
| 196 | /// indexing-expression is `d0 + d1`) |
| 197 | static unsigned getNumNonTrivialIdxExpOnSparseLvls(AffineMap map, |
| 198 | Value tensor) { |
| 199 | // The `tensor` is not guaranteed to have `RankedTensorType`, therefore |
| 200 | // we can't use `getRankedTensorType`/`getSparseTensorType` here. |
| 201 | // However, we don't need to handle `StorageSpecifierType`, so we |
| 202 | // can use `SparseTensorType` once we guard against non-tensors. |
| 203 | const auto rtp = dyn_cast<RankedTensorType>(tensor.getType()); |
| 204 | if (!rtp) |
| 205 | return 0; |
| 206 | const SparseTensorType stt(rtp); |
| 207 | |
| 208 | const Level lvlRank = stt.getLvlRank(); |
| 209 | const auto exprs = map.getResults(); |
| 210 | assert(static_cast<Dimension>(exprs.size()) == lvlRank && |
| 211 | "AffineMap does not have dimension-rank many results" ); |
| 212 | unsigned num = 0; |
| 213 | for (Level l = 0; l < lvlRank; l++) { |
| 214 | if (!isa<AffineDimExpr>(exprs[l]) && !stt.getLvlType(l).hasDenseSemantic()) |
| 215 | num++; |
| 216 | } |
| 217 | return num; |
| 218 | } |
| 219 | |
| 220 | /// Gets the total number of sparse levels with compound affine |
| 221 | /// expressions, summed over all operands of the `GenericOp`. |
| 222 | static unsigned getNumNonTrivialIdxExpOnSparseLvls(linalg::GenericOp op) { |
| 223 | unsigned num = 0; |
| 224 | for (OpOperand &t : op->getOpOperands()) |
| 225 | num += getNumNonTrivialIdxExpOnSparseLvls(op.getMatchingIndexingMap(&t), |
| 226 | t.get()); |
| 227 | return num; |
| 228 | } |
| 229 | |
| 230 | // Returns true iff output has nontrivial affine indices. |
| 231 | static bool hasNonTrivialAffineOnSparseOut(linalg::GenericOp op) { |
| 232 | OpOperand *out = op.getDpsInitOperand(0); |
| 233 | if (getSparseTensorType(out->get()).isAllDense()) |
| 234 | return false; |
| 235 | return getNumNonTrivialIdxExpOnSparseLvls(op.getMatchingIndexingMap(out), |
| 236 | out->get()); |
| 237 | } |
| 238 | |
| 239 | /// Helper method to inspect sparse encodings in the tensor types. |
| 240 | /// Fills the per-dimension sparsity information for all tensors. |
| 241 | /// Returns true if the sparse annotations and affine subscript |
| 242 | /// expressions of all tensors are admissible. Returns false if |
| 243 | /// no annotations are found or inadmissible constructs occur. |
| 244 | /// We currently support two different ways to handle non-trivial index |
| 245 | /// expressions on sparse tensors, and they accept different affine expressions. |
| 246 | /// When using the dependent index reduction-based approach, only affine |
| 247 | /// addition index expressions are currently supported. |
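| | /// For example (illustrative only), for a kernel C(i, j) = A(i + j, j) * B(i, j) |
| | /// with a sparse A, the levels of A are analyzed with findDepIdxSet when the |
| | /// index-reduction-based approach is requested, while the trivially indexed |
| | /// B and C are handled by findAffine. |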
| 248 | static bool findSparseAnnotations(CodegenEnv &env, bool idxReducBased) { |
| 249 | bool annotated = false; |
| 250 | for (OpOperand &t : env.op()->getOpOperands()) { |
| 251 | const TensorId tid = env.makeTensorId(t.getOperandNumber()); |
| 252 | const auto map = env.op().getMatchingIndexingMap(&t); |
| 253 | const auto enc = getSparseTensorEncoding(t.get().getType()); |
| 254 | if (enc) |
| 255 | annotated = true; |
| 256 | const Level lvlRank = map.getNumResults(); |
| 257 | assert(!enc || lvlRank == enc.getLvlRank()); |
| 258 | assert(static_cast<Level>(env.op().getRank(&t)) == lvlRank); |
| 259 | // We only need to do index reduction if there is at least one |
| 260 | // non-trivial index expression on sparse levels. If all non-trivial |
| 261 | // index expressions are on dense levels, we can efficiently rely on |
| 262 | // random access to locate the elements. |
| 263 | bool needIdxReduc = |
| 264 | enc && getNumNonTrivialIdxExpOnSparseLvls(map, t.get()) != 0; |
| 265 | // If the current tensor being inspected requires affine indices, it |
| 266 | // needs to be sliced. |
| 267 | for (Level l = 0; l < lvlRank; l++) { |
| 268 | const AffineExpr a = map.getResult(l); |
| 269 | const LevelType lt = enc.getLvlType(l); |
| 270 | if (idxReducBased && needIdxReduc) { |
| 271 | if (!findDepIdxSet(env.merger(), tid, l, a, lt)) |
| 272 | return false; // inadmissible affine expression |
| 273 | } else { |
| 274 | if (!findAffine(env.merger(), tid, l, a, lt)) |
| 275 | return false; // inadmissible affine expression |
| 276 | } |
| 277 | } |
| 278 | } |
| 279 | return annotated; |
| 280 | } |
| 281 | |
| 282 | //===----------------------------------------------------------------------===// |
| 283 | // Sparsifier synthesis methods (statements and expressions). |
| 284 | //===----------------------------------------------------------------------===// |
| 285 | |
| 286 | /// Local bufferization of all dense and sparse data structures. |
| 287 | static void genBuffers(CodegenEnv &env, OpBuilder &builder) { |
| 288 | linalg::GenericOp op = env.op(); |
| 289 | Location loc = op.getLoc(); |
| 290 | assert(op.getNumOperands() == op.getNumDpsInputs() + 1); |
| 291 | |
| 292 | SmallVector<Range, 4> loopRange = |
| 293 | llvm::cast<linalg::LinalgOp>(op.getOperation()) |
| 294 | .createLoopRanges(builder, loc); |
| 295 | |
| 296 | env.emitter().initializeLoopEmit( |
| 297 | builder, loc, |
| 298 | /// Generates buffer for the output tensor. |
| 299 | /// Note that all sparse kernels assume that when all elements are written |
| 300 | /// to (viz. x(i) = y(i) * z(i)), the output buffer is already initialized |
| 301 | /// to all zeroes and only nonzero values are computed and written out. |
| 302 | /// For updates (viz. x(i) += y(i) * z(i)), only nonzero values are used |
| 303 | /// for the updates and no assumption on the original contents of the |
| 304 | /// output buffer is necessary. |
| 305 | [&op](OpBuilder &builder, Location loc, Value memref, |
| 306 | Value tensor) -> Value { |
| 307 | // Must not be a sparse tensor. |
| 308 | assert(!getSparseTensorEncoding(tensor.getType())); |
| 309 | // Two output tensor references should point to the same object. |
| 310 | OpOperand *lhs = op.getDpsInitOperand(0); |
| 311 | assert(lhs->get() == tensor); |
| 312 | // An output tensor can simply materialize from the buffer of the tensor |
| 313 | // that appears in the outs() clause. For updates, this has the |
| 314 | // advantage that only the nonzero values are involved in the |
| 315 | // computation, keeping the operation O(nnz). In all other cases, we are |
| 316 | // forced to zero out the buffer to enforce the assumption above, which |
| 317 | // may negatively impact running complexity (viz. O(n^2 + nnz) vs. |
| 318 | // O(nnz) for matrices). |
| 319 | // TODO: use better analysis to avoid zeroing out the buffer? |
| 320 | bool isInit = op.isInitTensor(lhs); |
| 321 | Value init = memref; |
| 322 | if (!isInit) { |
| 323 | Value zero = constantZero(builder, loc, |
| 324 | getElementTypeOrSelf(tensor.getType())); |
| 325 | builder.create<linalg::FillOp>(loc, ValueRange{zero}, |
| 326 | ValueRange{init}); |
| 327 | } |
| 328 | return init; |
| 329 | }, |
| 330 | [&loopRange](OpBuilder &b, Location loc, Level l) { |
| 331 | assert(l < loopRange.size()); |
| 332 | return mlir::getValueOrCreateConstantIndexOp(b, loc, loopRange[l].size); |
| 333 | }); |
| 334 | } |
| 335 | |
| 336 | /// Generates index for load/store on sparse tensor. |
| 337 | static Value genIndex(CodegenEnv &env, OpOperand *t) { |
| 338 | const auto map = env.op().getMatchingIndexingMap(t); |
| 339 | const auto stt = getSparseTensorType(t->get()); |
| 340 | const Level lvlRank = stt.getLvlRank(); |
| 341 | assert(static_cast<Level>(map.getNumResults()) == lvlRank); |
| 342 | const AffineExpr a = map.getResult(lvlRank - 1); |
| 343 | assert(a.getKind() == AffineExprKind::DimId); |
| 344 | const LoopId idx = env.makeLoopId(cast<AffineDimExpr>(a).getPosition()); |
| 345 | return env.getLoopVar(idx); |
| 346 | } |
| 347 | |
| 348 | /// Generates subscript for load/store on a dense or sparse tensor. |
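| | /// For example (illustrative only), for a dense operand with indexing map |
| | /// (d0, d1) -> (d0, d1) this pushes both loop coordinates onto `args`, while |
| | /// for a sparse operand only the last level's position is pushed and the |
| | /// corresponding values buffer (or the tensor itself) is returned. |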
| 349 | static Value genSubscript(CodegenEnv &env, OpBuilder &builder, OpOperand *t, |
| 350 | SmallVectorImpl<Value> &args) { |
| 351 | const Location loc = env.op().getLoc(); |
| 352 | const TensorId tid = env.makeTensorId(t->getOperandNumber()); |
| 353 | const auto map = env.op().getMatchingIndexingMap(t); |
| 354 | const auto stt = getSparseTensorType(t->get()); |
| 355 | if (stt.hasEncoding()) { |
| 356 | // For sparse tensors we only push the last-level's position onto `args`. |
| 357 | const auto pos = env.emitter().getValPosits(tid); |
| 358 | assert(!pos.empty()); |
| 359 | args.append(pos); |
| 360 | // Simply returns the tensor to extract value using iterators. |
| 361 | if (env.options().sparseEmitStrategy == SparseEmitStrategy::kSparseIterator) |
| 362 | return t->get(); |
| 363 | } else { |
| 364 | // For dense tensors we push all level's coordinates onto `args`. |
| 365 | const Level lvlRank = stt.getLvlRank(); |
| 366 | assert(static_cast<Level>(map.getNumResults()) == lvlRank); |
| 367 | for (Level l = 0; l < lvlRank; l++) { |
| 368 | const auto lvlExpr = map.getResult(l); |
| 369 | const auto lvlCrd = env.emitter().genAffine(builder, loc, lvlExpr); |
| 370 | args.push_back(lvlCrd); |
| 371 | } |
| 372 | } |
| 373 | return env.emitter().getValBuffer()[tid]; |
| 374 | } |
| 375 | |
| 376 | /// Generates insertion code to implement dynamic tensor load. |
| 377 | static Value genInsertionLoad(CodegenEnv &env, OpBuilder &builder, |
| 378 | OpOperand *t) { |
| 379 | linalg::GenericOp op = env.op(); |
| 380 | Location loc = op.getLoc(); |
| 381 | // Direct lexicographic coordinate order, tensor loads as zero. |
| 382 | if (!env.isExpand()) { |
| 383 | Type tp = getElementTypeOrSelf(t->get().getType()); |
| 384 | return constantZero(builder, loc, tp); |
| 385 | } |
| 386 | // Load from expanded access pattern. |
| 387 | Value index = genIndex(env, t); |
| 388 | return builder.create<memref::LoadOp>(loc, env.getExpandValues(), index); |
| 389 | } |
| 390 | |
| 391 | /// Generates insertion code to implement dynamic tensor load for reduction. |
| 392 | static Value genInsertionLoadReduce(CodegenEnv &env, OpBuilder &builder, |
| 393 | OpOperand *t) { |
| 394 | linalg::GenericOp op = env.op(); |
| 395 | Location loc = op.getLoc(); |
| 396 | Value identity = env.getCustomRedId(); |
| 397 | // Direct lexicographic coordinate order, tensor loads as identity. |
| 398 | if (!env.isExpand()) |
| 399 | return identity; |
| 400 | // Load from expanded access pattern if filled, identity otherwise. |
| 401 | Value values = env.getExpandValues(); |
| 402 | Value filled = env.getExpandFilled(); |
| 403 | Value index = genIndex(env, t); |
| 404 | Value isFilled = builder.create<memref::LoadOp>(loc, filled, index); |
| 405 | Value valAtIndex = builder.create<memref::LoadOp>(loc, values, index); |
| 406 | return builder.create<arith::SelectOp>(loc, isFilled, valAtIndex, identity); |
| 407 | } |
| 408 | |
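| | /// Generates a conditional insertion into a sparse output. The emitted IR has |
| | /// roughly the following shape (illustrative only): |
| | /// |
| | ///   %r = scf.if %cond -> (tensor<...>) { |
| | ///     %t = tensor.insert %v into %sparseOut[%ivs] : tensor<...> |
| | ///     scf.yield %t : tensor<...> |
| | ///   } else { |
| | ///     scf.yield %sparseOut : tensor<...> |
| | ///   } |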
| 409 | static Value genConditionalInsert(Location loc, OpBuilder &builder, Value cond, |
| 410 | Value sparseOut, ValueRange ivs, Value v) { |
| 411 | scf::IfOp condInsert = |
| 412 | builder.create<scf::IfOp>(loc, sparseOut.getType(), cond, true); |
| 413 | // True branch. |
| 414 | builder.setInsertionPointToStart(condInsert.thenBlock()); |
| 415 | Value res = builder.create<tensor::InsertOp>(loc, v, sparseOut, ivs); |
| 416 | builder.create<scf::YieldOp>(loc, res); |
| 417 | // False branch. |
| 418 | builder.setInsertionPointToStart(condInsert.elseBlock()); |
| 419 | builder.create<scf::YieldOp>(loc, sparseOut); |
| 420 | // Value assignment. |
| 421 | builder.setInsertionPointAfter(condInsert); |
| 422 | return condInsert.getResult(0); |
| 423 | } |
| 424 | |
| 425 | /// Generates insertion code to implement dynamic tensor store. |
| 426 | static void genInsertionStore(CodegenEnv &env, OpBuilder &builder, OpOperand *t, |
| 427 | Value rhs) { |
| 428 | linalg::GenericOp op = env.op(); |
| 429 | Location loc = op.getLoc(); |
| 430 | // Direct insertion in lexicographic coordinate order. |
| 431 | if (!env.isExpand()) { |
| 432 | const LoopId numLoops = op.getRank(t); |
| 433 | // Retrieves the first `numLoops` induction variables. |
| 434 | SmallVector<Value> ivs = llvm::to_vector(llvm::drop_end( |
| 435 | env.emitter().getLoopIVsRange(), env.getCurrentDepth() - numLoops)); |
| 436 | Value chain = env.getInsertionChain(); |
| 437 | if (env.isValidLexInsert()) { |
| 438 | // Generates runtime check for a valid lex during reduction, |
| 439 | // to avoid inserting the identity value for empty reductions. |
| 440 | // if (validLexInsert) then |
| 441 | // insert(rhs) into chain |
| 442 | // return updated chain |
| 443 | // else |
| 444 | // return unmodified chain |
| 445 | Value out = genConditionalInsert(loc, builder, env.getValidLexInsert(), |
| 446 | chain, ivs, rhs); |
| 447 | env.updateInsertionChain(out); |
| 448 | } else { |
| 449 | Value sparseOut; |
| 450 | if (!hasAnySparseType(env.op().getInputs().getTypes())) { |
| 451 | // This is an all-dense -> sparse kernel, test rhs != 0 before |
| 452 | // insertion. |
| 453 | Value nz = genIsNonzero(builder, loc, rhs); |
| 454 | sparseOut = genConditionalInsert(loc, builder, nz, chain, ivs, rhs); |
| 455 | } else { |
| 456 | sparseOut = builder.create<tensor::InsertOp>(loc, rhs, chain, ivs); |
| 457 | } |
| 458 | // Generates regular insertion chain. |
| 459 | env.updateInsertionChain(sparseOut); |
| 460 | } |
| 461 | return; |
| 462 | } |
| 463 | // Generates insertion code along expanded access pattern. |
| 464 | // if (!expFilled[i]) then |
| 465 | // expFilled[i] = true |
| 466 | // expAdded[inserts++] = i |
| 467 | // endif |
| 468 | // values[i] = rhs |
| 469 | Value values = env.getExpandValues(); |
| 470 | Value filled = env.getExpandFilled(); |
| 471 | Value added = env.getExpandAdded(); |
| 472 | Value count = env.getExpandCount(); |
| 473 | Value index = genIndex(env, t); |
| 474 | Value fval = constantI1(builder, loc, false); |
| 475 | Value tval = constantI1(builder, loc, true); |
| 476 | // If statement. |
| 477 | Value isFilled = builder.create<memref::LoadOp>(loc, filled, index); |
| 478 | Value cond = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, |
| 479 | isFilled, fval); |
| 480 | scf::IfOp ifOp = builder.create<scf::IfOp>(loc, builder.getIndexType(), cond, |
| 481 | /*else=*/true); |
| 482 | // True branch. |
| 483 | builder.setInsertionPointToStart(&ifOp.getThenRegion().front()); |
| 484 | builder.create<memref::StoreOp>(loc, tval, filled, index); |
| 485 | builder.create<memref::StoreOp>(loc, index, added, count); |
| 486 | Value one = constantIndex(builder, loc, 1); |
| 487 | Value add = builder.create<arith::AddIOp>(loc, count, one); |
| 488 | builder.create<scf::YieldOp>(loc, add); |
| 489 | // False branch. |
| 490 | builder.setInsertionPointToStart(&ifOp.getElseRegion().front()); |
| 491 | builder.create<scf::YieldOp>(loc, count); |
| 492 | builder.setInsertionPointAfter(ifOp); |
| 493 | // Value assignment. |
| 494 | env.updateExpandCount(ifOp.getResult(0)); |
| 495 | builder.create<memref::StoreOp>(loc, rhs, values, index); |
| 496 | } |
| 497 | |
| 498 | /// Generates a load on a dense or sparse tensor. |
| 499 | static Value genTensorLoad(CodegenEnv &env, OpBuilder &builder, ExprId exp) { |
| 500 | // Test if the load was hoisted to a higher loop nest. |
| 501 | Value val = env.exp(exp).val; |
| 502 | if (val) |
| 503 | return val; |
| 504 | // Get tensor operand. |
| 505 | linalg::GenericOp op = env.op(); |
| 506 | Location loc = op.getLoc(); |
| 507 | OpOperand *t = &op->getOpOperand(env.exp(exp).tensor); |
| 508 | // Fold binary-valued tensor into explicit value. |
| 509 | const auto stt = getSparseTensorType(t->get()); |
| 510 | if (auto explVal = stt.getExplicitVal()) |
| 511 | return genValFromAttr(builder, loc, explVal); |
| 512 | // Load during insertion. |
| 513 | if (env.isSparseOutput(t)) { |
| 514 | if (env.isCustomReduc()) |
| 515 | return genInsertionLoadReduce(env, builder, t); |
| 516 | return genInsertionLoad(env, builder, t); |
| 517 | } |
| 518 | |
| 519 | // Actual load. |
| 520 | SmallVector<Value> args; |
| 521 | Value ptr = genSubscript(env, builder, t, args); |
| 522 | if (llvm::isa<TensorType>(ptr.getType())) { |
| 523 | assert(env.options().sparseEmitStrategy == |
| 524 | SparseEmitStrategy::kSparseIterator); |
| 525 | return builder.create<ExtractValOp>(loc, ptr, llvm::getSingleElement(args)); |
| 526 | } |
| 527 | return builder.create<memref::LoadOp>(loc, ptr, args); |
| 528 | } |
| 529 | |
| 530 | /// Generates a store on a dense or sparse tensor. |
| 531 | static void genTensorStore(CodegenEnv &env, OpBuilder &builder, ExprId exp, |
| 532 | Value rhs) { |
| 533 | // Only unary and binary are allowed to return an uninitialized rhs |
| 534 | // to indicate missing output. Or otherwise a custom reduction that |
| 535 | // received no value to accumulate. |
| 536 | if (!rhs) { |
| 537 | assert(env.exp(exp).kind == TensorExp::Kind::kUnary || |
| 538 | env.exp(exp).kind == TensorExp::Kind::kBinary || |
| 539 | env.exp(exp).kind == TensorExp::Kind::kReduce); |
| 540 | return; |
| 541 | } |
| 542 | // Test if this is a scalarized reduction. |
| 543 | if (env.isReduc()) { |
| 544 | env.updateReduc(rhs); |
| 545 | return; |
| 546 | } |
| 547 | // Regular store. |
| 548 | linalg::GenericOp op = env.op(); |
| 549 | Location loc = op.getLoc(); |
| 550 | OpOperand *t = op.getDpsInitOperand(0); |
| 551 | if (!env.isSparseOutput(t)) { |
| 552 | SmallVector<Value> args; |
| 553 | Value ptr = genSubscript(env, builder, t, args); |
| 554 | builder.create<memref::StoreOp>(loc, rhs, ptr, args); |
| 555 | return; |
| 556 | } |
| 557 | // Store during sparse insertion. |
| 558 | if (env.exp(exp).kind != TensorExp::Kind::kSelect) { |
| 559 | genInsertionStore(env, builder, t, rhs); |
| 560 | return; |
| 561 | } |
| 562 | // Select operation insertion. |
| 563 | Value chain = env.getInsertionChain(); |
| 564 | scf::IfOp ifOp = |
| 565 | builder.create<scf::IfOp>(loc, chain.getType(), rhs, /*else=*/true); |
| 566 | builder.setInsertionPointToStart(&ifOp.getThenRegion().front()); |
| 567 | // Existing value was preserved to be used here. |
| 568 | assert(env.exp(exp).val); |
| 569 | Value v0 = env.exp(exp).val; |
| 570 | genInsertionStore(env, builder, t, v0); |
| 571 | env.merger().clearExprValue(exp); |
| 572 | // Yield modified insertion chain along true branch. |
| 573 | Value mchain = env.getInsertionChain(); |
| 574 | builder.create<scf::YieldOp>(op.getLoc(), mchain); |
| 575 | // Yield original insertion chain along false branch. |
| 576 | builder.setInsertionPointToStart(&ifOp.getElseRegion().front()); |
| 577 | builder.create<scf::YieldOp>(loc, chain); |
| 578 | // Done with if statement. |
| 579 | env.updateInsertionChain(ifOp->getResult(0)); |
| 580 | builder.setInsertionPointAfter(ifOp); |
| 581 | } |
| 582 | |
| 583 | /// Generates an invariant value. |
| 584 | inline static Value genInvariantValue(CodegenEnv &env, ExprId exp) { |
| 585 | return env.exp(exp).val; |
| 586 | } |
| 587 | |
| 588 | /// Semi-ring branches are simply inlined by the sparsifier. Prior |
| 589 | /// analysis has verified that all computations are "local" to the inlined |
| 590 | /// branch or otherwise invariantly defined outside the loop nest, with the |
| 591 | /// exception of index computations, which need to be relinked to actual |
| 592 | /// inlined cloned code. |
| 593 | static Value relinkBranch(CodegenEnv &env, RewriterBase &rewriter, Block *block, |
| 594 | Value e) { |
| 595 | if (auto arg = dyn_cast<BlockArgument>(e)) { |
| 596 | // Direct arguments of the original linalg op must be converted |
| 597 | // into dense tensor loads. Note that we should not encounter |
| 598 | // anything else. This needs to be verified by semi-ring ops. |
| 599 | linalg::GenericOp op = env.op(); |
| 600 | if (arg.getOwner()->getParentOp() == op) { |
| 601 | const TensorId tid = env.makeTensorId(arg.getArgNumber()); |
| 602 | OpOperand *t = &op->getOpOperand(tid); |
| 603 | assert(!getSparseTensorType(t->get()).hasEncoding()); // dense! |
| 604 | SmallVector<Value> args; |
| 605 | Value ptr = genSubscript(env, rewriter, t, args); |
| 606 | return rewriter.create<memref::LoadOp>(op.getLoc(), ptr, args); |
| 607 | } |
| 608 | } else if (Operation *def = e.getDefiningOp()) { |
| 609 | // Handle index computation. |
| 610 | if (auto indexOp = dyn_cast<linalg::IndexOp>(def)) |
| 611 | return env.getLoopVar(env.makeLoopId(indexOp.getDim())); |
| 612 | // When still defined in new body, recurse into operands. |
| 613 | if (def->getBlock() == block) { |
| 614 | rewriter.setInsertionPoint(def); |
| 615 | for (unsigned i = 0, n = def->getNumOperands(); i < n; i++) { |
| 616 | rewriter.modifyOpInPlace(def, [&]() { |
| 617 | def->setOperand( |
| 618 | i, relinkBranch(env, rewriter, block, def->getOperand(i))); |
| 619 | }); |
| 620 | } |
| 621 | } |
| 622 | } |
| 623 | return e; |
| 624 | } |
| 625 | |
| 626 | /// Recursively generates tensor expression. |
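| | /// For example (illustrative only), for an expression like a(i, j) * b(i, j) |
| | /// the recursion first materializes the two tensor loads (or their hoisted |
| | /// values) and then asks the merger to build the arithmetic operation, e.g. |
| | /// an arith.mulf, on the two operands. |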
| 627 | static Value genExp(CodegenEnv &env, RewriterBase &rewriter, ExprId e) { |
| 628 | if (e == ::mlir::sparse_tensor::detail::kInvalidId) |
| 629 | return Value(); |
| 630 | |
| 631 | linalg::GenericOp op = env.op(); |
| 632 | Location loc = op.getLoc(); |
| 633 | const TensorExp &exp = env.exp(e); |
| 634 | const auto kind = exp.kind; |
| 635 | if (kind == TensorExp::Kind::kTensor) |
| 636 | return genTensorLoad(env, rewriter, e); |
| 637 | if (kind == TensorExp::Kind::kInvariant) |
| 638 | return genInvariantValue(env, e); |
| 639 | if (kind == TensorExp::Kind::kLoopVar) |
| 640 | return env.getLoopVar(exp.loop); |
| 641 | |
| 642 | if (kind == TensorExp::Kind::kReduce) |
| 643 | env.startCustomReduc(e); // enter custom |
| 644 | |
| 645 | // If either lhs/rhs is a synthetic zero, we infer the type for the zero value |
| 646 | // based on the type of the other operand. |
| 647 | Value v0, v1; |
| 648 | if (exp.children.e0 != ::mlir::sparse_tensor::detail::kInvalidId && |
| 649 | env.exp(exp.children.e0).kind == TensorExp::Kind::kSynZero) { |
| 650 | v1 = genExp(env, rewriter, exp.children.e1); |
| 651 | v0 = constantZero(rewriter, loc, v1.getType()); |
| 652 | } else if (exp.children.e1 != ::mlir::sparse_tensor::detail::kInvalidId && |
| 653 | env.exp(exp.children.e1).kind == TensorExp::Kind::kSynZero) { |
| 654 | v0 = genExp(env, rewriter, exp.children.e0); |
| 655 | v1 = constantZero(rewriter, loc, v0.getType()); |
| 656 | } else { |
| 657 | v0 = genExp(env, rewriter, exp.children.e0); |
| 658 | v1 = genExp(env, rewriter, exp.children.e1); |
| 659 | } |
| 660 | |
| 661 | Value ee; |
| 662 | if (kind == TensorExp::Kind::kReduce && (!v0 || !v1)) { |
| 663 | // custom reduce did not receive a value |
| 664 | } else { |
| 665 | ee = env.merger().buildExp(rewriter, loc, e, v0, v1); |
| 666 | if (ee && |
| 667 | (kind == TensorExp::Kind::kUnary || kind == TensorExp::Kind::kBinary || |
| 668 | kind == TensorExp::Kind::kBinaryBranch || |
| 669 | kind == TensorExp::Kind::kReduce || |
| 670 | kind == TensorExp::Kind::kSelect)) { |
| 671 | OpBuilder::InsertionGuard guard(rewriter); |
| 672 | ee = relinkBranch(env, rewriter, ee.getParentBlock(), ee); |
| 673 | } |
| 674 | } |
| 675 | |
| 676 | if (kind == TensorExp::Kind::kReduce) |
| 677 | env.endCustomReduc(); // exit custom |
| 678 | |
| 679 | if (kind == TensorExp::Kind::kSelect) |
| 680 | env.merger().setExprValue(e, v0); // Preserve value for later use. |
| 681 | |
| 682 | return ee; |
| 683 | } |
| 684 | |
| 685 | /// Hoists loop invariant tensor loads for which indices have been exhausted. |
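| | /// For example (illustrative only), in a kernel x(i) = a(i, j) * b(i) the load |
| | /// of b(i) no longer depends on the inner loop j, so it is hoisted out of the |
| | /// j loop and its value is reused there. |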
| 686 | static void genInvariants(CodegenEnv &env, OpBuilder &builder, ExprId exp, |
| 687 | LoopId curr, bool isStart) { |
| 688 | if (exp == ::mlir::sparse_tensor::detail::kInvalidId) |
| 689 | return; |
| 690 | if (env.exp(exp).kind == TensorExp::Kind::kTensor) { |
| 691 | // Inspect tensor indices. |
| 692 | linalg::GenericOp op = env.op(); |
| 693 | OpOperand &t = op->getOpOperand(env.exp(exp).tensor); |
| 694 | const auto map = op.getMatchingIndexingMap(&t); |
| 695 | const auto stt = getSparseTensorType(t.get()); |
| 696 | const Level lvlRank = stt.getLvlRank(); |
| 697 | assert(static_cast<Level>(map.getNumResults()) == lvlRank); |
| 698 | bool isCurrentLoop = curr == 0; // for scalar tensors |
| 699 | for (Level l = 0; l < lvlRank; l++) { |
| 700 | const AffineExpr a = map.getResult(l); |
| 701 | if (!isInvariantAffine(a, curr, /*out*/ isCurrentLoop)) |
| 702 | return; // still in play |
| 703 | } |
| 704 | // All exhausted at current level. |
| 705 | if (!isCurrentLoop) |
| 706 | return; |
| 707 | // Generate code for a scalarized reduction or invariant. Note that |
| 708 | // because custom reduction lhs may occur several times in the IR, |
| 709 | // we have a built-in safety for only initializing and wrapping-up |
| 710 | // the scalarized reduction once. |
| 711 | OpOperand *lhs = op.getDpsInitOperand(0); |
| 712 | if (lhs == &t) { |
| 713 | // Start or end a scalarized reduction. |
| 714 | if (isStart) { |
| 715 | if (env.isCustomReduc()) { |
| 716 | if (!env.isReduc()) |
| 717 | env.startReduc(exp, env.getCustomRedId()); |
| 718 | } else { |
| 719 | env.startReduc(exp, genTensorLoad(env, builder, exp)); |
| 720 | } |
| 721 | if (env.hasSparseOutput()) |
| 722 | env.startValidLexInsert( |
| 723 | constantI1(builder, env.op().getLoc(), false)); |
| 724 | } else { |
| 725 | if (!env.isCustomReduc() || env.isReduc()) |
| 726 | genTensorStore(env, builder, exp, env.endReduc()); |
| 727 | if (env.hasSparseOutput()) |
| 728 | env.endValidLexInsert(); |
| 729 | } |
| 730 | } else { |
| 731 | // Start or end loop invariant hoisting of a tensor load. |
| 732 | if (isStart) { |
| 733 | env.merger().setExprValue(exp, genTensorLoad(env, builder, exp)); |
| 734 | } else { |
| 735 | env.merger().clearExprValue(exp); |
| 736 | } |
| 737 | } |
| 738 | } else if (env.exp(exp).kind != TensorExp::Kind::kInvariant && |
| 739 | env.exp(exp).kind != TensorExp::Kind::kLoopVar && |
| 740 | env.exp(exp).kind != TensorExp::Kind::kSynZero) { |
| 741 | // Traverse into the binary operations. Note that we only hoist |
| 742 | // tensor loads, since subsequent MLIR/LLVM passes know how to |
| 743 | // deal with all other kinds of derived loop invariants. |
| 744 | if (env.exp(exp).kind == TensorExp::Kind::kReduce) |
| 745 | env.startCustomReduc(exp); // enter custom |
| 746 | const ExprId e0 = env.exp(exp).children.e0; |
| 747 | const ExprId e1 = env.exp(exp).children.e1; |
| 748 | genInvariants(env, builder, e0, curr, isStart); |
| 749 | genInvariants(env, builder, e1, curr, isStart); |
| 750 | if (env.exp(exp).kind == TensorExp::Kind::kReduce) |
| 751 | env.endCustomReduc(); // exit custom |
| 752 | } |
| 753 | } |
| 754 | |
| 755 | /// Generates an expanded access pattern in innermost dimension. |
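| | /// The start of the pattern materializes the four buffers of an expansion, |
| | /// roughly as follows (illustrative only): |
| | /// |
| | ///   %values, %filled, %added, %count = sparse_tensor.expand %tensor |
| | /// |
| | /// and the end compresses the gathered entries back into the insertion chain |
| | /// with sparse_tensor.compress. |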
| 756 | static void genExpand(CodegenEnv &env, OpBuilder &builder, LoopId curr, |
| 757 | bool isStart) { |
| 758 | linalg::GenericOp op = env.op(); |
| 759 | OpOperand *lhs = op.getDpsInitOperand(0); |
| 760 | if (!env.atExpandLevel(lhs, op.getRank(lhs), curr)) |
| 761 | return; // not needed at current level |
| 762 | assert(!env.isReduc()); |
| 763 | // Generate start or end of an expanded access pattern. Note that because |
| 764 | // an expansion does not rely on the ongoing contents of the sparse storage |
| 765 | // scheme, we can use the original tensor as incoming SSA value (which |
| 766 | // simplifies codegen a bit). If expansion on the actual contents is ever |
| 767 | // needed, we will need to use the SSA value in the insertion chain instead. |
| 768 | Value tensor = lhs->get(); |
| 769 | Location loc = op.getLoc(); |
| 770 | if (isStart) { |
| 771 | auto dynShape = {ShapedType::kDynamic}; |
| 772 | Type etp = cast<ShapedType>(tensor.getType()).getElementType(); |
| 773 | Type t1 = MemRefType::get(dynShape, etp); |
| 774 | Type t2 = MemRefType::get(dynShape, builder.getI1Type()); |
| 775 | Type t3 = MemRefType::get(dynShape, builder.getIndexType()); |
| 776 | Type t4 = builder.getIndexType(); |
| 777 | auto r = builder.create<ExpandOp>(loc, TypeRange({t1, t2, t3, t4}), tensor); |
| 778 | assert(r.getNumResults() == 4); |
| 779 | env.startExpand(r.getResult(0), r.getResult(1), r.getResult(2), |
| 780 | r.getResult(3)); |
| 781 | } else { |
| 782 | SmallVector<Value> indices; |
| 783 | for (LoopId i = 0; i < curr; i++) |
| 784 | indices.push_back(env.emitter().getLoopIV(i)); |
| 785 | Value values = env.getExpandValues(); |
| 786 | Value filled = env.getExpandFilled(); |
| 787 | Value added = env.getExpandAdded(); |
| 788 | Value count = env.getExpandCount(); |
| 789 | Value chain = env.getInsertionChain(); |
| 790 | Value compress = builder.create<CompressOp>(loc, values, filled, added, |
| 791 | count, chain, indices); |
| 792 | env.updateInsertionChain(compress); |
| 793 | env.endExpand(); |
| 794 | } |
| 795 | } |
| 796 | |
| 797 | /// Returns parallelization strategy. Any implicit loop in the Linalg |
| 798 | /// operation that is marked "parallel" is a candidate. Whether it is actually |
| 799 | /// converted to a parallel operation depends on the requested strategy. |
| 800 | static bool isParallelFor(CodegenEnv &env, bool isOuter, bool isSparse) { |
| 801 | // Reject parallelization of sparse output. |
| 802 | if (env.hasSparseOutput()) |
| 803 | return false; |
| 804 | // Parallel loops on tensor expansion can cause data races. |
| 805 | if (env.isExpand()) |
| 806 | return false; |
| 807 | // Inspect strategy. |
| 808 | switch (env.options().parallelizationStrategy) { |
| 809 | case SparseParallelizationStrategy::kNone: |
| 810 | return false; |
| 811 | case SparseParallelizationStrategy::kDenseOuterLoop: |
| 812 | return isOuter && !isSparse; |
| 813 | case SparseParallelizationStrategy::kAnyStorageOuterLoop: |
| 814 | return isOuter; |
| 815 | case SparseParallelizationStrategy::kDenseAnyLoop: |
| 816 | return !isSparse; |
| 817 | case SparseParallelizationStrategy::kAnyStorageAnyLoop: |
| 818 | return true; |
| 819 | } |
| 820 | llvm_unreachable("unexpected parallelization strategy"); |
| 821 | } |
| 822 | |
| 823 | /// Whether or not the current loop being generated should be parallelized (if |
| 824 | /// possible) according to the configuration. |
| 825 | static bool shouldTryParallize(CodegenEnv &env, LoopId curr, |
| 826 | ArrayRef<TensorLevel> tidLvls) { |
| 827 | linalg::GenericOp op = env.op(); |
| 828 | auto iteratorTypes = op.getIteratorTypesArray(); |
| 829 | bool isSparse = llvm::any_of(tidLvls, [curr, &env](TensorLevel tidLvl) { |
| 830 | // Queries the LT based on the tensor and loop id, as requested by |
| 831 | // `CodegenEnv::lt(TensorId, LoopId)`. The returned LT from CodegenEnv |
| 832 | // should be consistent with the LT indexed by <TensorId, Level>. |
| 833 | const auto lt = env.lt(env.unpackTensorLevel(tidLvl).first, curr); |
| 834 | return lt.hasSparseSemantic(); |
| 835 | }); |
| 836 | return isParallelFor(env, /*isOuter=*/curr == 0, isSparse); |
| 837 | } |
| 838 | |
| 839 | /// Emit a loop to coiterate over the list of tensor levels. The generated loop |
| 840 | /// can either be a for loop or while loop depending on whether there is at most |
| 841 | /// one sparse level in the list. |
| 842 | static Operation *genCoIteration(CodegenEnv &env, OpBuilder &builder, |
| 843 | ArrayRef<TensorLevel> tidLvls, |
| 844 | unsigned numCases, bool tryParallel, |
| 845 | bool needsUniv) { |
| 846 | Operation *loop = *env.genLoopBoundary([&](MutableArrayRef<Value> reduc) { |
| 847 | // Construct while-loop with a parameter for each index. |
| 848 | return env.emitter().enterCoIterationOverTensorsAtLvls( |
| 849 | builder, env.op().getLoc(), tidLvls, numCases, reduc, tryParallel, |
| 850 | needsUniv); |
| 851 | }); |
| 852 | assert(loop); |
| 853 | return loop; |
| 854 | } |
| 855 | |
| 856 | /// Generates a for-loop or a while-loop, depending on whether it implements |
| 857 | /// singleton iteration or co-iteration over the given conjunction. |
| 858 | static Operation *genLoop(CodegenEnv &env, OpBuilder &builder, LoopId curr, |
| 859 | unsigned numCases, bool needsUniv, |
| 860 | ArrayRef<TensorLevel> tidLvls) { |
| 861 | bool tryParallel = shouldTryParallize(env, curr, tidLvls); |
| 862 | return genCoIteration(env, builder, tidLvls, numCases, tryParallel, |
| 863 | needsUniv); |
| 864 | } |
| 865 | |
| 866 | /// Generates the induction structure for a while-loop. |
| 867 | static void finalizeWhileOp(CodegenEnv &env, OpBuilder &builder, |
| 868 | bool needsUniv) { |
| 869 | Location loc = env.op().getLoc(); |
| 870 | // Finalize each else branch of all if statements. |
| 871 | if (env.isReduc() || env.isExpand() || env.getInsertionChain()) { |
| 872 | while (auto ifOp = dyn_cast_or_null<scf::IfOp>( |
| 873 | builder.getInsertionBlock()->getParentOp())) { |
| 874 | // Break on IfOp for slicing filtering. |
| 875 | if (ifOp->getAttr(LoopEmitter::getLoopEmitterLoopAttrName()) == |
| 876 | StringAttr::get(ifOp->getContext(), "slice")) |
| 877 | break; |
| 878 | |
| 879 | unsigned y = 0; |
| 880 | SmallVector<Value> yields; |
| 881 | if (env.isReduc()) { |
| 882 | yields.push_back(env.getReduc()); |
| 883 | env.updateReduc(ifOp.getResult(y++)); |
| 884 | if (env.isValidLexInsert()) { |
| 885 | yields.push_back(env.getValidLexInsert()); |
| 886 | env.updateValidLexInsert(ifOp.getResult(y++)); |
| 887 | } |
| 888 | } |
| 889 | if (env.isExpand()) { |
| 890 | yields.push_back(env.getExpandCount()); |
| 891 | env.updateExpandCount(ifOp->getResult(y++)); |
| 892 | } |
| 893 | if (env.getInsertionChain()) { |
| 894 | yields.push_back(env.getInsertionChain()); |
| 895 | env.updateInsertionChain(ifOp->getResult(y++)); |
| 896 | } |
| 897 | assert(y == yields.size()); |
| 898 | builder.create<scf::YieldOp>(loc, yields); |
| 899 | builder.setInsertionPointAfter(ifOp); |
| 900 | } |
| 901 | } |
| 902 | // No need to set the insertion point here as LoopEmitter keeps track of the |
| 903 | // basic block where scf::Yield should be inserted. |
| 904 | } |
| 905 | |
| 906 | /// Generates a case region in the coiterate operation. |
| 907 | static void genCoIterationCase(CodegenEnv &env, OpBuilder &builder, |
| 908 | unsigned caseIdx, LatPointId allCase, |
| 909 | LatPointId curCase, |
| 910 | MutableArrayRef<Value> reduc) { |
| 911 | assert(allCase == curCase || env.merger().latGT(allCase, curCase)); |
| 912 | const BitVector &allCaseBits = env.merger().lat(allCase).simple; |
| 913 | const BitVector &curCaseBits = env.merger().lat(curCase).simple; |
| 914 | |
| 915 | /// Computes the subset of iterators that are valid in the current case being |
| 916 | /// generated. |
| 917 | I64BitSet caseBit(0); |
| 918 | for (auto [idx, set] : llvm::enumerate(allCaseBits.set_bits())) |
| 919 | if (curCaseBits.test(set)) |
| 920 | caseBit.set(idx); |
| 921 | |
| 922 | env.emitter().enterCurrentCoIterationCase(builder, env.op().getLoc(), caseBit, |
| 923 | caseIdx, reduc); |
| 924 | } |
| 925 | |
| 926 | /// Generates a single if-statement within a while-loop. |
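| | /// For example (illustrative only), when coiterating two sparse operands the |
| | /// guard compares each sparse coordinate against the loop variable, roughly: |
| | /// |
| | ///   %c0 = arith.cmpi eq, %crdA, %iv |
| | ///   %c1 = arith.cmpi eq, %crdB, %iv |
| | ///   %cond = arith.andi %c0, %c1 |
| | ///   scf.if %cond { ... } else { ... } |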
| 927 | static scf::IfOp genIf(CodegenEnv &env, OpBuilder &builder, LoopId curr, |
| 928 | LatPointId p) { |
| 929 | Location loc = env.op().getLoc(); |
| 930 | SmallVector<Type> types; |
| 931 | Value cond; |
| 932 | env.merger().foreachTensorLoopId( |
| 933 | p, /*simple=*/true, |
| 934 | [&](TensorLoopId b, TensorId tid, std::optional<Level> lvl, LevelType lt, |
| 935 | bool isIdxRed) { |
| 936 | if (isIdxRed) { |
| 937 | // Since there is no 1:1 mapping from loop to level (multiple loops |
| 938 | // are required to resolve one level with non-trivial index |
| 939 | // expression), we need to reconstruct the tensor level types if this |
| 940 | // loop requires index reduction condition. |
| 941 | assert(lvl.has_value() && isUndefLT(lt)); |
| 942 | auto stt = getSparseTensorType(env.op().getInputs()[tid]); |
| 943 | lt = stt.getLvlType(*lvl); |
| 944 | } |
| 945 | assert(curr == env.merger().loop(b)); |
| 946 | Value clause; |
| 947 | if (lt.hasSparseSemantic()) { |
| 948 | assert(lvl.has_value()); |
| 949 | const Value crd = env.emitter().getCoord(tid, *lvl); |
| 950 | const Value lvar = env.getLoopVar(curr); |
| 951 | clause = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, |
| 952 | crd, lvar); |
| 953 | } else { |
| 954 | assert(lt.hasDenseSemantic() || isUndefLT(lt)); |
| 955 | clause = constantI1(builder, loc, true); |
| 956 | } |
| 957 | cond = cond ? builder.create<arith::AndIOp>(loc, cond, clause) : clause; |
| 958 | }); |
| 959 | if (env.isReduc()) { |
| 960 | types.push_back(env.getReduc().getType()); |
| 961 | if (env.isValidLexInsert()) |
| 962 | types.push_back(env.getValidLexInsert().getType()); |
| 963 | } |
| 964 | if (env.isExpand()) |
| 965 | types.push_back(builder.getIndexType()); |
| 966 | if (env.getInsertionChain()) |
| 967 | types.push_back(env.getInsertionChain().getType()); |
| 968 | scf::IfOp ifOp = builder.create<scf::IfOp>(loc, types, cond, /*else=*/true); |
| 969 | builder.setInsertionPointToStart(&ifOp.getThenRegion().front()); |
| 970 | return ifOp; |
| 971 | } |
| 972 | |
| 973 | /// Generates end of true branch of if-statement within a while-loop. |
| 974 | static void endIf(CodegenEnv &env, OpBuilder &builder, scf::IfOp ifOp, |
| 975 | Value redInput, Value cntInput, Value insInput, |
| 976 | Value validIns) { |
| 977 | SmallVector<Value> operands; |
| 978 | if (env.isReduc()) { |
| 979 | operands.push_back(env.getReduc()); |
| 980 | env.updateReduc(redInput); |
| 981 | if (env.isValidLexInsert()) { |
| 982 | // Any overlapping indices during a reduction creates a valid lex insert. |
| 983 | operands.push_back(constantI1(builder, env.op().getLoc(), true)); |
| 984 | env.updateValidLexInsert(validIns); |
| 985 | } |
| 986 | } |
| 987 | if (env.isExpand()) { |
| 988 | operands.push_back(env.getExpandCount()); |
| 989 | env.updateExpandCount(cntInput); |
| 990 | } |
| 991 | if (env.getInsertionChain()) { |
| 992 | operands.push_back(env.getInsertionChain()); |
| 993 | env.updateInsertionChain(insInput); |
| 994 | } |
| 995 | if (!operands.empty()) |
| 996 | builder.create<scf::YieldOp>(env.op().getLoc(), operands); |
| 997 | builder.setInsertionPointToStart(&ifOp.getElseRegion().front()); |
| 998 | } |
| 999 | |
| 1000 | //===----------------------------------------------------------------------===// |
| 1001 | // Sparsifier synthesis methods (loop sequence). |
| 1002 | //===----------------------------------------------------------------------===// |
| 1003 | |
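| | /// Invokes `callback` for every (tensor, level) pair that participates in the |
| | /// lattice point `li` at loop `curr`; for dense levels whose compound affine |
| | /// address just became computable, the affine expression is passed along as |
| | /// well. Returns true iff a single loop condition suffices and it is not on a |
| | /// non-unique level (or sparse-iterator emission is requested), so that a |
| | /// simple for-loop / iterate-loop can be used. |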
| 1004 | static bool getAllTidLvlsInLatPoints( |
| 1005 | CodegenEnv &env, LatPointId li, LoopId curr, |
| 1006 | llvm::function_ref<void(TensorLevel, AffineExpr)> callback) { |
| 1007 | const BitVector &simple = env.lat(li).simple; |
| 1008 | const TensorId outTid = env.merger().getOutTensorID(); |
| 1009 | const std::optional<Level> outLvl = env.merger().getLvl(outTid, curr); |
| 1010 | |
| 1011 | unsigned numloopCond = 0; |
| 1012 | bool hasNonUnique = false; |
| 1013 | env.merger().foreachTensorLoopId( |
| 1014 | li, [&, curr](TensorLoopId b, TensorId tid, std::optional<Level> lvl, |
| 1015 | LevelType lt, bool isIdxReduc) { |
| 1016 | if (simple[b]) { |
| 1017 | if (isIdxReduc) { |
| 1018 | callback(env.makeTensorLevel(tid, *lvl), nullptr); |
| 1019 | numloopCond++; |
| 1020 | return; |
| 1021 | } |
| 1022 | if (isUndefLT(lt)) { |
| 1023 | // An undefined lt in the lattice means we probably intend to |
| 1024 | // generate a dense loop according to the synthetic tensor (for |
| 1025 | // invariants and the sparse output tensor). |
| 1026 | if (env.merger().getSynTensorID() == tid) { |
| 1027 | // Coiterating with an invariant |
| 1028 | // e.g., out = prod(in[i][j] op invariant); |
| 1029 | // or a broadcast |
| 1030 | // e.g., out[i][j] = in[i] (j is undef for input) |
| 1031 | // |
| 1032 | // The level of the synthetic tensor is the current loop depth; |
| 1033 | // the rank of the synthetic tensor equals the number of loops. |
| 1034 | assert(curr == env.getCurrentDepth()); |
| 1035 | lvl = curr; |
| 1036 | } else if (!lvl) { |
| 1037 | // Skips invalid lvl (e.g., when this is a zero ranked tensor). |
| 1038 | return; |
| 1039 | } |
| 1040 | } |
| 1041 | hasNonUnique = !isUniqueLT(lt) || hasNonUnique; |
| 1042 | callback(env.makeTensorLevel(tid, *lvl), nullptr); |
| 1043 | numloopCond++; |
| 1044 | } else if (lt.hasDenseSemantic() || isIdxReduc) { |
| 1045 | callback(env.makeTensorLevel(tid, *lvl), nullptr); |
| 1046 | } else { |
| 1047 | assert(isUndefLT(lt)); |
| 1048 | linalg::GenericOp op = env.op(); |
| 1049 | if (tid >= op.getNumDpsInputs()) |
| 1050 | // We only handle affine expression on input tensors (for now). |
| 1051 | return; |
| 1052 | OpOperand *operand = &op->getOpOperand(tid); |
| 1053 | const auto stt = getSparseTensorType(operand->get()); |
| 1054 | // Non-annotated dense tensors require no special handling. |
| 1055 | if (!stt.hasEncoding()) |
| 1056 | return; |
| 1057 | |
| 1058 | ArrayRef<AffineExpr> affines = |
| 1059 | op.getMatchingIndexingMap(operand).getResults(); |
| 1060 | const Level lvlRank = stt.getLvlRank(); |
| 1061 | assert(affines.size() == static_cast<size_t>(lvlRank)); |
| 1062 | for (Level l = 0; l < lvlRank; l++) { |
| 1063 | AffineExpr exp = affines[l]; |
| 1064 | // Skip simple affine expression and non-dense levels (which |
| 1065 | // have their own filter loop). |
| 1066 | LevelType lt = stt.getLvlType(l); |
| 1067 | if (isa<AffineDimExpr>(exp) || !lt.hasDenseSemantic()) |
| 1068 | continue; |
| 1069 | |
| 1070 | // Constant affine expression are handled in genLoop. |
| 1071 | if (!isa<AffineConstantExpr>(exp)) { |
| 1072 | bool isCurrentLoop = false; |
| 1073 | assert(curr == env.getCurrentDepth()); |
| 1074 | if (isInvariantAffine(exp, curr + 1, /*out*/ isCurrentLoop) && |
| 1075 | isCurrentLoop) { |
| 1076 | // If the compound affine is invariant and we are right at the |
| 1077 | // level. We need to generate the address according to the |
| 1078 | // affine expression. This is also the best place we can do it |
| 1079 | // to avoid putting it inside inner loops. |
| 1080 | callback(env.makeTensorLevel(tid, l), exp); |
| 1081 | } |
| 1082 | } |
| 1083 | } |
| 1084 | } |
| 1085 | }); |
| 1086 | |
| 1087 | if (isDenseLT(env.lt(outTid, curr))) { |
| 1088 | auto stt = getSparseTensorType(env.op().getOutputs().front()); |
| 1089 | // Note that we generate dense indices of the output tensor unconditionally, |
| 1090 | // since they may not appear in the lattice, but may be needed for |
| 1091 | // linearized env. |
| 1092 | // TODO: we should avoid introducing corner cases for all-dense sparse |
| 1093 | // tensors. |
| 1094 | if (stt.hasEncoding() && stt.isAllDense()) |
| 1095 | callback(env.makeTensorLevel(outTid, *outLvl), nullptr); |
| 1096 | } |
| 1097 | |
| 1098 | if (numloopCond == 0) { |
| 1099 | // Corner case where the loop bound is defined by an *unused* operand; in |
| 1100 | // this case, we just generate a dense "fake" loop by iterating over the |
| 1101 | // synthetic tensor. |
| 1102 | callback(env.makeTensorLevel(env.merger().getSynTensorID(), curr), nullptr); |
| 1103 | numloopCond++; |
| 1104 | } |
| 1105 | // If we just need one loop condition and the condition is not imposed on a |
| 1106 | // non-unique level, the loop can be generated by a for loop. |
| 1107 | // Or, if we are generating sparse-iterator-based loops, we always generate |
| 1108 | // `sparse_tensor.iterate` regardless whether the level is unique or not. |
| 1109 | return numloopCond == 1 && |
| 1110 | (!hasNonUnique || env.options().sparseEmitStrategy == |
| 1111 | SparseEmitStrategy::kSparseIterator); |
| 1112 | } |
| 1113 | |
| 1114 | /// Starts a loop sequence at given level. Returns true if |
| 1115 | /// the universal loop index must be maintained at this level. |
| 1116 | static bool startLoopSeq(CodegenEnv &env, OpBuilder &builder, ExprId exp, |
| 1117 | LoopId curr, LatSetId lts) { |
| 1118 | assert(!env.getLoopVar(curr)); |
| 1119 | // Emit invariants at this loop sequence level. |
| 1120 | genInvariants(env, builder, exp, curr, /*isStart=*/true); |
| 1121 | // Emit access pattern expansion for sparse tensor output. |
| 1122 | genExpand(env, builder, curr, /*isStart=*/true); |
| 1123 | // Emit further initialization at this loop sequence level. |
| 1124 | const LatPointId l0 = env.set(lts)[0]; |
| 1125 | |
| 1126 | SmallVector<TensorLevel> tidLvls; |
| 1127 | getAllTidLvlsInLatPoints(env, l0, curr, [&](TensorLevel tl, AffineExpr) { |
| 1128 | // TODO: remove this! The same tensor level might be added for multiple |
| 1129 | // times due to the special handling for all-dense "sparse" output tensor |
| 1130 | // (see L1038). |
| 1131 | if (llvm::is_contained(tidLvls, tl)) |
| 1132 | return; |
| 1133 | tidLvls.emplace_back(tl); |
| 1134 | }); |
| 1135 | |
| 1136 | env.emitter().enterNewLoopSeq(builder, env.op().getLoc(), tidLvls); |
| 1137 | |
| 1138 | // Maintain the universal index only if it is actually |
| 1139 | // consumed by a subsequent lattice point. |
| 1140 | for (const LatPointId li : env.set(lts).drop_front()) |
| 1141 | if (!env.merger().hasAnySparse(env.lat(li).simple)) |
| 1142 | return true; |
| 1143 | |
| 1144 | return false; |
| 1145 | } |
| 1146 | |
| 1147 | // Generates dense affine address for encoding. |
| 1148 | static void genConstantDenseAddressFromLevel(CodegenEnv &env, |
| 1149 | OpBuilder &builder, TensorId tid, |
| 1150 | Level startLvl) { |
| 1151 | // TODO: Handle affine expression on output tensor. |
| 1152 | linalg::GenericOp op = env.op(); |
| 1153 | assert(tid < op.getNumDpsInputs()); |
| 1154 | OpOperand *input = op.getDpsInputOperands()[tid]; |
| 1155 | const auto lvlExprs = op.getMatchingIndexingMap(input).getResults(); |
| 1156 | const auto enc = getSparseTensorEncoding(input->get().getType()); |
| 1157 | if (enc) { |
| 1158 | const Location loc = op.getLoc(); |
| 1159 | const TensorId tid = env.makeTensorId(input->getOperandNumber()); |
| 1160 | const Level lvlRank = enc.getLvlRank(); |
| 1161 | assert(lvlExprs.size() == static_cast<size_t>(lvlRank)); |
| 1162 | for (Level l = startLvl; l < lvlRank; l++) { |
| 1163 | AffineExpr lvlExpr = lvlExprs[l]; |
| 1164 | if (enc.getLvlType(l).hasDenseSemantic() && |
| 1165 | isa<AffineConstantExpr>(lvlExpr)) |
| 1166 | env.emitter().locateLvlAtAffineAddress( |
| 1167 | builder, loc, env.makeTensorLevel(tid, l), lvlExpr); |
| 1168 | else |
| 1169 | return; // break on first non-dense non-constant level |
| 1170 | } |
| 1171 | } |
| 1172 | } |
| 1173 | |
| 1174 | // We can generate address for constant affine expression before any loops |
| 1175 | // starting from the first level as they do not depend on anything. |
| 1176 | // E.g., [Dense, Dense, Sparse] -> (1, 2, d0), the addresses for the first two |
| 1177 | // levels can be determined before loops. |
| 1178 | static void genInitConstantDenseAddress(CodegenEnv &env, |
| 1179 | RewriterBase &rewriter) { |
| 1180 | for (TensorId tid = 0, e = env.op().getNumDpsInputs(); tid < e; tid++) |
| 1181 | genConstantDenseAddressFromLevel(env, rewriter, tid, /*startLvl=*/0);
| 1182 | } |
| 1183 | |
| 1184 | /// Returns true if the lattice bit can be iterated by a for loop. |
| 1185 | static bool translateBitsToTidLvlPairs( |
| 1186 | CodegenEnv &env, LatPointId li, LoopId curr, |
| 1187 | SmallVectorImpl<TensorLevel> &tidLvls, |
| 1188 | SmallVectorImpl<std::pair<TensorLevel, AffineExpr>> &affineTidLvls) { |
| 1189 | return getAllTidLvlsInLatPoints(env, li, curr,
| 1190 | [&](TensorLevel tl, AffineExpr exp) {
| 1191 | if (exp)
| 1192 | affineTidLvls.emplace_back(tl, exp);
| 1193 | else
| 1194 | tidLvls.emplace_back(tl);
| 1195 | }); |
| 1196 | } |
| 1197 | |
| 1198 | /// Starts a single loop in current sequence. |
| 1199 | static std::pair<Operation *, bool> startLoop(CodegenEnv &env, |
| 1200 | OpBuilder &builder, LoopId curr, |
| 1201 | LatPointId li, unsigned numCases, |
| 1202 | bool needsUniv) { |
| 1203 | // TODO: numCases is only used when generating iterator-based loops. Clean
| 1204 | // up after the migration is complete.
| 1205 | // The set of tensors + lvls to generate loops on.
| 1206 | SmallVector<TensorLevel> tidLvls; |
| 1207 | |
| 1208 | // The set of dense tensors with non-trivial affine expressions that just
| 1209 | // became invariant, whose addresses are generated at the current level.
| 1210 | SmallVector<std::pair<TensorLevel, AffineExpr>> affineTidLvls; |
| 1211 | bool isSingleCond = |
| 1212 | translateBitsToTidLvlPairs(env, li, curr, tidLvls, affineTidLvls); |
| 1213 | |
| 1214 | // Emit the for/while-loop control. |
| 1215 | Operation *loop = genLoop(env, builder, curr, numCases, needsUniv, tidLvls); |
| 1216 | Location loc = env.op().getLoc(); |
| 1217 | for (auto [tidLvl, exp] : affineTidLvls) { |
| 1218 | env.emitter().locateLvlAtAffineAddress(builder, loc, tidLvl, exp);
| 1219 | } |
| 1220 | |
| 1221 | // Until now, we have entered every <tid, lvl> pair in {cond, extra, |
| 1222 | // affine}Tids/Lvls. The addresses of the upcoming levels which depend
| 1223 | // on constant affine expressions may now be determined.
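| | // E.g., for a level map like (d0, 2, d1) with a dense constant level 1, the
| | // address of that level depends on the position reached at level 0, so it
| | // can only be located once the d0 loop has been entered (illustrative case).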
| 1224 | auto allTidLvls = |
| 1225 | llvm::concat<TensorLevel>(tidLvls, llvm::make_first_range(affineTidLvls));
| 1226 | for (auto [tid, lvl] : env.unpackTensorLevelRange(allTidLvls)) { |
| 1227 | if (tid != env.merger().getOutTensorID() && |
| 1228 | tid != env.merger().getSynTensorID()) |
| 1229 | genConstantDenseAddressFromLevel(env, builder, tid, lvl + 1); |
| 1230 | } |
| 1231 | |
| 1232 | return std::make_pair(loop, isSingleCond);
| 1233 | } |
| 1234 | |
| 1235 | /// Ends a single loop in the current sequence. Returns the new value for needsUniv.
| 1236 | static bool endLoop(CodegenEnv &env, RewriterBase &rewriter, Operation *loop, |
| 1237 | LatPointId li, bool needsUniv, bool isSingleCond) { |
| 1238 | // Either a for-loop or a while-loop that iterates over a slice. |
| 1239 | if (isSingleCond) { |
| 1240 | // Any iteration creates a valid lex insert. |
| 1241 | if (env.isReduc() && env.isValidLexInsert()) |
| 1242 | env.updateValidLexInsert(constantI1(rewriter, env.op().getLoc(), true));
| 1243 | } else if (auto whileOp = dyn_cast<scf::WhileOp>(loop)) { |
| 1244 | // End a while-loop. |
| 1245 | finalizeWhileOp(env, rewriter, needsUniv);
| 1246 | } else { |
| 1247 | needsUniv = false; |
| 1248 | } |
| 1249 | env.genLoopBoundary([&](MutableArrayRef<Value> reduc) {
| 1250 | env.emitter().exitCurrentLoop(rewriter, env.op().getLoc(), reduc);
| 1251 | return std::nullopt; |
| 1252 | }); |
| 1253 | return needsUniv; |
| 1254 | } |
| 1255 | |
| 1256 | /// Ends a loop sequence at given level. |
| 1257 | static void endLoopSeq(CodegenEnv &env, OpBuilder &builder, unsigned exp, |
| 1258 | unsigned at) { |
| 1259 | assert(!env.getLoopVar(at)); |
| 1260 | env.emitter().exitCurrentLoopSeq(builder, env.op().getLoc());
| 1261 | // Unmark bookkeeping of invariants and loop index.
| 1262 | genInvariants(env, builder, exp, at, /*isStart=*/false);
| 1263 | // Finalize access pattern expansion for sparse tensor output.
| 1264 | genExpand(env, builder, at, /*isStart=*/false);
| 1265 | } |
| 1266 | |
| 1267 | /// Recursively generates code while computing iteration lattices in order |
| 1268 | /// to manage the complexity of implementing co-iteration over unions |
| 1269 | /// and intersections of sparse iteration spaces.
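| | ///
| | /// As an illustrative sketch (not the verbatim output), for
| | /// x(i) = a(i) + b(i) with both a and b sparse, the emitted loop sequence
| | /// for loop index i looks roughly like:
| | ///
| | ///   while (a and b both have entries left)   // lattice point a(i)+b(i)
| | ///     if (ia == i && ib == i) x(i) = a(i) + b(i)
| | ///     else if (ia == i)       x(i) = a(i)
| | ///     else if (ib == i)       x(i) = b(i)
| | ///   for (remaining entries of a)             // lattice point a(i)
| | ///     x(i) = a(i)
| | ///   for (remaining entries of b)             // lattice point b(i)
| | ///     x(i) = b(i)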
| 1270 | static void genStmt(CodegenEnv &env, RewriterBase &rewriter, ExprId exp, |
| 1271 | LoopId curr) { |
| 1272 | assert(curr == env.getCurrentDepth()); |
| 1273 | |
| 1274 | // At each leaf, assign remaining tensor (sub)expression to output tensor. |
| 1275 | if (curr == env.getLoopNum()) { |
| 1276 | Value rhs = genExp(env, rewriter, exp);
| 1277 | genTensorStore(env, rewriter, exp, rhs);
| 1278 | return; |
| 1279 | } |
| 1280 | |
| 1281 | // Construct iteration lattices for current loop index. |
| 1282 | const LatSetId lts = |
| 1283 | env.merger().optimizeSet(env.merger().buildLattices(exp, curr));
| 1284 | |
| 1285 | // Start a loop sequence. |
| 1286 | bool needsUniv = startLoopSeq(env, rewriter, exp, curr, lts);
| 1287 | |
| 1288 | // When using sparse-iterator-based loops, we only need one loop, as
| 1289 | // opposed to a loop sequence, to cover all the iteration spaces.
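| | // E.g., one iterator-based loop (a plain iteration for a single lattice
| | // point, or a co-iteration with one case region per lattice point) covers
| | // what the regular path below splits into a while-loop plus clean-up loops
| | // (an illustrative description of the strategy, not of one specific op form).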
| 1290 | const unsigned lsize = env.set(lts).size(); |
| 1291 | if (env.generatingSparseIterator()) { |
| 1292 | // Get the largest lattice point and start a loop. |
| 1293 | const LatPointId li = env.set(lts)[0]; |
| 1294 | auto [loop, isSingleCond] = |
| 1295 | startLoop(env, rewriter, curr, li, /*numCases=*/lsize, needsUniv);
| 1296 | assert(isSingleCond == llvm::isa<IterateOp>(loop)); |
| 1297 | // We cannot change this to `for (const LatPointId li : env.set(lts))` |
| 1298 | // because the loop body causes data-movement which invalidates |
| 1299 | // the iterator. |
| 1300 | for (unsigned j = 0; j < lsize; j++) { |
| 1301 | const LatPointId lj = env.set(lts)[j]; |
| 1302 | const ExprId ej = env.lat(lj).exp;
| 1303 | // Recurse into body of each branch. |
| 1304 | if (!isSingleCond) { |
| 1305 | env.genLoopBoundary([&, curr, j, li, lj](MutableArrayRef<Value> reduc) {
| 1306 | genCoIterationCase(env, rewriter, /*caseIdx*/ j, li, lj, reduc);
| 1307 | genStmt(env, rewriter, ej, curr + 1);
| 1308 | // TODO: handle yield values.
| 1309 | assert(reduc.empty() && "Not Implemented");
| 1310 | rewriter.create<sparse_tensor::YieldOp>(env.op().getLoc()); |
| 1311 | return std::nullopt; |
| 1312 | }); |
| 1313 | // endIf(env, rewriter, ifOp, redInput, cntInput, insInput, validIns); |
| 1314 | } else { |
| 1315 | genStmt(env, rewriter, ej, curr + 1);
| 1316 | } |
| 1317 | } |
| 1318 | // End a loop. |
| 1319 | needsUniv = endLoop(env, rewriter, loop, curr, needsUniv, isSingleCond);
| 1320 | } else { |
| 1321 | // Emit a loop for every lattice point L0 >= Li in this loop sequence. |
| 1322 | for (unsigned i = 0; i < lsize; i++) { |
| 1323 | const LatPointId li = env.set(lts)[i]; |
| 1324 | // Start a loop. |
| 1325 | auto [loop, isSingleCond] = |
| 1326 | startLoop(env, rewriter, curr, li, /*numCases=*/lsize, needsUniv);
| 1327 | |
| 1328 | // Visit all lattice points with Li >= Lj to generate the
| 1329 | // loop-body, possibly with if statements for coiteration. |
| 1330 | Value redInput = env.getReduc(); |
| 1331 | Value cntInput = env.getExpandCount(); |
| 1332 | Value insInput = env.getInsertionChain(); |
| 1333 | Value validIns = env.getValidLexInsert(); |
| 1334 | // We cannot change this to `for (const LatPointId lj : env.set(lts))` |
| 1335 | // because the loop body causes data-movement which invalidates the |
| 1336 | // iterator. |
| 1337 | for (unsigned j = 0; j < lsize; j++) { |
| 1338 | const LatPointId lj = env.set(lts)[j]; |
| 1339 | const ExprId ej = env.lat(lj).exp;
| 1340 | if (li == lj || env.merger().latGT(li, lj)) {
| 1341 | // Recurse into body of each branch. |
| 1342 | if (!isSingleCond) { |
| 1343 | scf::IfOp ifOp = genIf(env, rewriter, curr, lj); |
| 1344 | genStmt(env, rewriter, ej, curr + 1);
| 1345 | endIf(env, rewriter, ifOp, redInput, cntInput, insInput, validIns); |
| 1346 | } else { |
| 1347 | genStmt(env, rewriter, ej, curr + 1);
| 1348 | } |
| 1349 | } |
| 1350 | } |
| 1351 | |
| 1352 | // End a loop. |
| 1353 | needsUniv = endLoop(env, rewriter, loop, curr, needsUniv, isSingleCond);
| 1354 | } |
| 1355 | } |
| 1356 | |
| 1357 | // End a loop sequence. |
| 1358 | endLoopSeq(env, rewriter, exp, curr);
| 1359 | assert(curr == env.getCurrentDepth()); |
| 1360 | } |
| 1361 | |
| 1362 | /// Converts the result computed by the sparse kernel into the required form. |
| 1363 | static void genResult(CodegenEnv &env, RewriterBase &rewriter) { |
| 1364 | linalg::GenericOp op = env.op(); |
| 1365 | OpOperand *lhs = op.getDpsInitOperand(0); |
| 1366 | Value tensor = lhs->get(); |
| 1367 | Type resType = tensor.getType(); |
| 1368 | if (getSparseTensorEncoding(resType)) {
| 1369 | // The sparse tensor rematerializes from the original sparse tensor's |
| 1370 | // underlying sparse storage format. For an insertion chain, the |
| 1371 | // tensor materializes from the chain with 'hasInserts' enabled. |
| 1372 | bool hasInserts = false; |
| 1373 | if (Value chain = env.getInsertionChain()) { |
| 1374 | hasInserts = true; |
| 1375 | tensor = chain; |
| 1376 | } |
| 1377 | rewriter.replaceOpWithNewOp<LoadOp>(op, resType, tensor, hasInserts); |
| 1378 | } else { |
| 1379 | // To rematerialize a non-annotated tensor, simply load it
| 1380 | // from the bufferized value. |
| 1381 | Value val = env.emitter().getValBuffer()[env.merger().getOutTensorID()]; |
| 1382 | rewriter.replaceOpWithNewOp<bufferization::ToTensorOp>(op, resType, val); |
| 1383 | } |
| 1384 | } |
| 1385 | |
| 1386 | //===----------------------------------------------------------------------===// |
| 1387 | // Sparsifier rewriting methods. |
| 1388 | //===----------------------------------------------------------------------===// |
| 1389 | |
| 1390 | namespace { |
| 1391 | |
| 1392 | /// Sparse rewriting rule for a generic Linalg operation.
| 1393 | struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> { |
| 1394 | public: |
| 1395 | GenericOpSparsifier(MLIRContext *context, SparsificationOptions o) |
| 1396 | : OpRewritePattern<linalg::GenericOp>(context), options(o) {} |
| 1397 | |
| 1398 | LogicalResult matchAndRewrite(linalg::GenericOp op, |
| 1399 | PatternRewriter &rewriter) const override { |
| 1400 | // Only accept single output operations with pure tensor semantics. |
| 1401 | if (op.getNumDpsInits() != 1 || !op.hasPureTensorSemantics()) |
| 1402 | return failure(); |
| 1403 | |
| 1404 | // Only accept trivial affine indices. |
| 1405 | if (hasNonTrivialAffineOnSparseOut(op)) |
| 1406 | return failure(); |
| 1407 | |
| 1408 | // Only accept scheduled loops. |
| 1409 | if (!op->hasAttr("sorted")) {
| 1410 | return rewriter.notifyMatchFailure(
| 1411 | op, "Loops not yet scheduled, try running --sparse-reinterpret-map "
| 1412 | "before sparsification.");
| 1413 | } |
| 1414 | |
| 1415 | // Must have been demapped as well if the generic op is sorted. |
| 1416 | assert(!hasAnyNonIdentityOperandsOrResults(op)); |
| 1417 | |
| 1418 | // Sets up a code generation environment. |
| 1419 | const unsigned numTensors = op->getNumOperands(); |
| 1420 | const unsigned numLoops = op.getNumLoops(); |
| 1421 | bool needIdxRed = getNumNonTrivialIdxExpOnSparseLvls(op) != 0; |
| 1422 | // If we have an indexing map like (d0) -> (0, d0), there might be more
| 1423 | // levels than loops because of the constant index, which means we cannot
| 1424 | // use numLoops as the upper bound for the ranks of all tensors.
| 1425 | // TODO: Constant indices are currently not supported on sparse tensors,
| 1426 | // but are allowed on non-annotated dense tensors. Supporting them would
| 1427 | // also be required for rank-reducing sparse tensor slices.
| 1428 | Level maxLvlRank = 0; |
| 1429 | for (auto operand : op.getOperands()) { |
| 1430 | if (auto rtp = dyn_cast<RankedTensorType>(operand.getType())) { |
| 1431 | maxLvlRank = std::max(maxLvlRank, SparseTensorType(rtp).getLvlRank()); |
| 1432 | } |
| 1433 | } |
| 1434 | |
| 1435 | // Detects sparse annotations and translates the per-level sparsity |
| 1436 | // information for all tensors to loop indices in the kernel. |
| 1437 | CodegenEnv env(op, options, numTensors, numLoops, maxLvlRank); |
| 1438 | if (!findSparseAnnotations(env, needIdxRed))
| 1439 | return failure(); |
| 1440 | |
| 1441 | // Only standard reduction operations (add, sub, or, xor) that can be |
| 1442 | // sparsified by merely reducing the stored values are admissible. More |
| 1443 | // elaborate reduction operations (such as mul, and, min, max) would need |
| 1444 | // to know whether implicit zeros occur as well. They can still be |
| 1445 | // implemented with a custom reduction operation, accepted here as well. |
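| | // E.g., x += a(i) only needs the stored entries of a, so skipping implicit
| | // zeros is safe, whereas x *= a(i) would be wrong whenever an implicit zero
| | // exists (the product should collapse to zero); such reductions are only
| | // accepted below in their custom sparse_tensor.reduce form.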
| 1446 | if (op.getNumReductionLoops() > 0) { |
| 1447 | Operation *yield = op.getRegion().front().getTerminator(); |
| 1448 | assert(isa<linalg::YieldOp>(yield)); |
| 1449 | Operation *redop = yield->getOperand(0).getDefiningOp();
| 1450 | if (!isa<arith::AddFOp>(redop) && !isa<complex::AddOp>(redop) && |
| 1451 | !isa<arith::AddIOp>(redop) && !isa<arith::SubFOp>(redop) && |
| 1452 | !isa<complex::SubOp>(redop) && !isa<arith::SubIOp>(redop) && |
| 1453 | !isa<arith::OrIOp>(redop) && !isa<arith::XOrIOp>(redop) && |
| 1454 | !isa<ReduceOp>(redop)) { |
| 1455 | return failure(); |
| 1456 | } |
| 1457 | } |
| 1458 | |
| 1459 | // Constructs the tensor expression tree from `op`; returns failure if the
| 1460 | // tree cannot be built or the tensor expression is inadmissible.
| 1461 | if (failed(env.initTensorExp()))
| 1462 | return failure(); |
| 1463 | |
| 1464 | // Recursively generates code if admissible. |
| 1465 | env.startEmit(options.sparseEmitStrategy);
| 1466 | genBuffers(env, rewriter);
| 1467 | // TODO: Constant affine expressions should be handled differently when
| 1468 | // using slice-based codegen; it does not matter now because we already
| 1469 | // reject constant expressions at an earlier stage.
| 1470 | genInitConstantDenseAddress(env, rewriter);
| 1471 | genStmt(env, rewriter, env.getExprId(), 0);
| 1472 | genResult(env, rewriter); |
| 1473 | return success(); |
| 1474 | } |
| 1475 | |
| 1476 | private: |
| 1477 | /// Options to control sparse code generation. |
| 1478 | SparsificationOptions options; |
| 1479 | }; |
| 1480 | |
| 1481 | } // namespace |
| 1482 | |
| 1483 | /// Populates the given patterns list with rewriting rules required for |
| 1484 | /// the sparsification of linear algebra operations. |
| 1485 | void mlir::populateSparsificationPatterns( |
| 1486 | RewritePatternSet &patterns, const SparsificationOptions &options) { |
| 1487 | patterns.add<GenericOpSparsifier>(patterns.getContext(), options);
| 1488 | } |
| 1489 | |