| 1 | //===- Sparsification.cpp - Implementation of sparsification --------------===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | // |
| 9 | // This file implements converting sparse tensor types to actual sparse code. |
| 10 | // |
| 11 | //===----------------------------------------------------------------------===// |
| 12 | |
| 13 | #include "Utils/CodegenEnv.h" |
| 14 | #include "Utils/CodegenUtils.h" |
| 15 | #include "Utils/LoopEmitter.h" |
| 16 | |
| 17 | #include "mlir/Dialect/Arith/IR/Arith.h" |
| 18 | #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" |
| 19 | #include "mlir/Dialect/Bufferization/IR/Bufferization.h" |
| 20 | #include "mlir/Dialect/Func/IR/FuncOps.h" |
| 21 | #include "mlir/Dialect/Linalg/IR/Linalg.h" |
| 22 | #include "mlir/Dialect/Linalg/Utils/Utils.h" |
| 23 | #include "mlir/Dialect/MemRef/IR/MemRef.h" |
| 24 | #include "mlir/Dialect/SCF/IR/SCF.h" |
| 25 | #include "mlir/Dialect/SCF/Transforms/Transforms.h" |
| 26 | #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" |
| 27 | #include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h" |
| 28 | #include "mlir/Dialect/SparseTensor/Transforms/Passes.h" |
| 29 | #include "mlir/Dialect/SparseTensor/Utils/Merger.h" |
| 30 | #include "mlir/Dialect/Tensor/IR/Tensor.h" |
| 31 | |
| 32 | #include <optional> |
| 33 | |
| 34 | using namespace mlir; |
| 35 | using namespace mlir::sparse_tensor; |
| 36 | |
| 37 | //===----------------------------------------------------------------------===// |
| 38 | // Sparsifier analysis methods. |
| 39 | //===----------------------------------------------------------------------===// |
| 40 | |
| 41 | /// Returns true iff affine expression is invariant. Sets the |
| 42 | /// parameter `isCurrentLoop` when expression just became invariant. |
| 43 | static bool isInvariantAffine(AffineExpr a, LoopId curr, bool &isCurrentLoop) { |
| 44 | switch (a.getKind()) { |
| 45 | case AffineExprKind::DimId: { |
| 46 | const LoopId i = cast<AffineDimExpr>(Val&: a).getPosition(); |
| 47 | if (i + 1 == curr) { |
| 48 | isCurrentLoop = true; |
| 49 | return true; // becomes invariant at current loop |
| 50 | } |
| 51 | return i < curr; // invariant when already generated |
| 52 | } |
| 53 | case AffineExprKind::Add: |
| 54 | case AffineExprKind::Mul: { |
| 55 | auto binOp = cast<AffineBinaryOpExpr>(Val&: a); |
| 56 | return isInvariantAffine(a: binOp.getLHS(), curr, isCurrentLoop) && |
| 57 | isInvariantAffine(a: binOp.getRHS(), curr, isCurrentLoop); |
| 58 | } |
| 59 | default: { |
| 60 | assert(isa<AffineConstantExpr>(a)); |
| 61 | return true; |
| 62 | } |
| 63 | } |
| 64 | } |
| 65 | |
| 66 | /// Helper method to inspect affine expressions. Rejects cases where the |
| 67 | /// same index is used more than once. Also rejects compound affine |
| 68 | /// expressions in sparse dimensions. |
| 69 | static bool findAffine(Merger &merger, TensorId tid, Level lvl, AffineExpr a, |
| 70 | LevelType lt, bool setLvlFormat = true) { |
| 71 | switch (a.getKind()) { |
| 72 | case AffineExprKind::DimId: { |
| 73 | const LoopId idx = merger.makeLoopId(i: cast<AffineDimExpr>(Val&: a).getPosition()); |
| 74 | if (!isUndefLT(lt: merger.getLvlType(t: tid, i: idx))) |
| 75 | return false; // used more than once |
| 76 | if (setLvlFormat) |
| 77 | merger.setLevelAndType(t: tid, i: idx, lvl, lt); |
| 78 | return true; |
| 79 | } |
| 80 | case AffineExprKind::Add: |
| 81 | case AffineExprKind::Mul: |
| 82 | case AffineExprKind::Constant: { |
| 83 | assert(lt.hasDenseSemantic()); |
| 84 | if (auto binOp = dyn_cast<AffineBinaryOpExpr>(Val&: a)) { |
| 85 | // We do not set dim level format for affine expression like d0 + d1 on |
| 86 | // either loop index at d0 or d1. We continue the recursion merely to |
| 87 | // check whether current affine is admissible or not. |
| 88 | return findAffine(merger, tid, lvl, a: binOp.getLHS(), lt, setLvlFormat: false) && |
| 89 | findAffine(merger, tid, lvl, a: binOp.getRHS(), lt, setLvlFormat: false); |
| 90 | } |
| 91 | // Falls through when it is a constant Affine |
| 92 | return true; |
| 93 | } |
| 94 | default: |
| 95 | return false; |
| 96 | } |
| 97 | } |
| 98 | |
| 99 | /// Helper method to inspect affine expressions for index variable reduction |
| 100 | /// based codegen. It finds the dependent index set for all tensor levels in the |
| 101 | /// current expression we are generating. |
| 102 | /// |
| 103 | /// For example, when handling A[i+j][j+k], we build the two way mapping in |
| 104 | /// merger between (tensor, level) pairs and their dependent index variable set: |
| 105 | /// A_0 <=> [i, j] and A_1 <=> [j, k] |
| 106 | /// |
| 107 | /// It rejects cases (returns false) |
| 108 | /// 1st, when the same index is used more than once, e.g., A[i+j][i] |
| 109 | /// 2nd, when multiplication is used in the non-trivial index expression. |
| 110 | /// 3rd, when a constant operand is used in the non-trivial index expression. |
| 111 | /// |
| 112 | /// TODO: constant should be easy to handle. |
static bool findDepIdxSet(Merger &merger, TensorId tensor, Level lvl,
                          AffineExpr a, LevelType lt, bool isSubExp = false,
                          int64_t coefficient = 1) {
  switch (a.getKind()) {
  case AffineExprKind::DimId: {
    // Only allow positive coefficients on AffineDimExpr.
    if (coefficient <= 0)
      return false;

    const LoopId idx = merger.makeLoopId(i: cast<AffineDimExpr>(Val&: a).getPosition());
    if (!isUndefLT(lt: merger.getLvlType(t: tensor, i: idx)))
      return false; // used more than once, e.g., A[i][i]

    // TODO: Generalizes the following two cases. A[i] (with trivial index
    // expression) can be treated as a special affine index expression. We do
    // not necessarily need to differentiate them.
    if (!isSubExp) {
      // Trivial index expression (the whole subscript is a single dim):
      // map the loop directly to this tensor level.
      assert(coefficient == 1);
      merger.setLevelAndType(t: tensor, i: idx, lvl, lt);
    }

    if (isSubExp) {
      // The current loops appears in more than one affine expressions on the
      // same tensor. We can not handle this case. e.g., A[i+j][i+k], `i` is
      // used twice.
      if (merger.hasDependentLvl(i: idx, t: tensor)) {
        // TODO: This can be supported by coiterate slices if the loop idx is
        // appeared on affine index for different tensor, or take slice on
        // multiple dimensions when it is on the same tensor.
        // E.g.,
        // `d0 + d1` for indexing t0[lvl0] and `d0 + d2` for indexing t1[lvl0]
        // d0_1 = getNextSliceOffset t0 along lvl0
        // d0_2 = getNextSliceOffset t1 along lvl0
        // if d0_1 == d0_2 then d0 = d0_1 = d0_2
        // else increase min(d0_1, d0_2).
        return false;
      }
      // Record that loop `idx` drives this (tensor, level) pair with the
      // given coefficient, for index-reduction based codegen.
      merger.setLoopDependentTensorLevel(i: idx, t: tensor, lvl, lt, coefficient);
    }
    return true;
  }
  case AffineExprKind::Constant:
  case AffineExprKind::Mul: {
    // TODO: Support index expression like `2 * d0`, we now only support more
    // complicated cases like `2 * d0 + d1`.
    if (!isSubExp)
      return false;

    // TODO: Support Constant AffineExp for slice-based codegen
    if (isa<AffineConstantExpr>(Val: a))
      llvm_unreachable("Not yet implemented" );

    auto binOp = cast<AffineBinaryOpExpr>(Val&: a);
    auto lhs = binOp.getLHS(), rhs = binOp.getRHS();
    // Canonicalize so the constant ends up on the lhs.
    if (isa<AffineConstantExpr>(Val: rhs))
      std::swap(a&: lhs, b&: rhs);
    // Must be in form of `constant * d`.
    assert(isa<AffineConstantExpr>(lhs) && isa<AffineDimExpr>(rhs));
    int64_t coefficient = cast<AffineConstantExpr>(Val&: lhs).getValue();
    return findDepIdxSet(merger, tensor, lvl, a: rhs, lt, isSubExp, coefficient);
  }
  case AffineExprKind::Add: {
    // Addition: both operands are treated as sub-expressions of a
    // compound index expression.
    auto binOp = cast<AffineBinaryOpExpr>(Val&: a);
    return findDepIdxSet(merger, tensor, lvl, a: binOp.getLHS(), lt, isSubExp: true) &&
           findDepIdxSet(merger, tensor, lvl, a: binOp.getRHS(), lt, isSubExp: true);
  }
  default:
    return false;
  }
}
| 183 | |
| 184 | /// Gets the total number of compound affine expressions in the |
| 185 | /// `getMatchingIndexingMap` for the given tensor. For the following inputs: |
| 186 | /// |
| 187 | /// map = (d0, d1, d2) => (d0 + d1 : compressed, d2 : compressed) |
| 188 | /// |
| 189 | /// Returns 1 (because the first level is compressed and its corresponding |
| 190 | /// indexing-expression is `d0 + d1`) |
| 191 | static unsigned getNumNonTrivialIdxExpOnSparseLvls(AffineMap map, |
| 192 | Value tensor) { |
| 193 | // The `tensor` is not guaranteed to have `RankedTensorType`, therefore |
| 194 | // we can't use `getRankedTensorType`/`getSparseTensorType` here. |
| 195 | // However, we don't need to handle `StorageSpecifierType`, so we |
| 196 | // can use `SparseTensorType` once we guard against non-tensors. |
| 197 | const auto rtp = dyn_cast<RankedTensorType>(Val: tensor.getType()); |
| 198 | if (!rtp) |
| 199 | return 0; |
| 200 | const SparseTensorType stt(rtp); |
| 201 | |
| 202 | const Level lvlRank = stt.getLvlRank(); |
| 203 | const auto exprs = map.getResults(); |
| 204 | assert(static_cast<Dimension>(exprs.size()) == lvlRank && |
| 205 | "AffineMap does not have dimension-rank many results" ); |
| 206 | unsigned num = 0; |
| 207 | for (Level l = 0; l < lvlRank; l++) { |
| 208 | if (!isa<AffineDimExpr>(Val: exprs[l]) && !stt.getLvlType(l).hasDenseSemantic()) |
| 209 | num++; |
| 210 | } |
| 211 | return num; |
| 212 | } |
| 213 | |
| 214 | /// Gets the total number of sparse levels with compound affine |
| 215 | /// expressions, summed over all operands of the `GenericOp`. |
| 216 | static unsigned getNumNonTrivialIdxExpOnSparseLvls(linalg::GenericOp op) { |
| 217 | unsigned num = 0; |
| 218 | for (OpOperand &t : op->getOpOperands()) |
| 219 | num += getNumNonTrivialIdxExpOnSparseLvls(map: op.getMatchingIndexingMap(opOperand: &t), |
| 220 | tensor: t.get()); |
| 221 | return num; |
| 222 | } |
| 223 | |
| 224 | // Returns true iff output has nontrivial affine indices. |
| 225 | static bool hasNonTrivialAffineOnSparseOut(linalg::GenericOp op) { |
| 226 | OpOperand *out = op.getDpsInitOperand(i: 0); |
| 227 | if (getSparseTensorType(val: out->get()).isAllDense()) |
| 228 | return false; |
| 229 | return getNumNonTrivialIdxExpOnSparseLvls(map: op.getMatchingIndexingMap(opOperand: out), |
| 230 | tensor: out->get()); |
| 231 | } |
| 232 | |
| 233 | /// Helper method to inspect sparse encodings in the tensor types. |
| 234 | /// Fills the per-dimension sparsity information for all tensors. |
| 235 | /// Returns true if the sparse annotations and affine subscript |
| 236 | /// expressions of all tensors are admissible. Returns false if |
| 237 | /// no annotations are found or inadmissible constructs occur. |
| 238 | /// We currently support two different ways to handle non-trivial index |
| 239 | /// expression on sparse tensors, and they accept different affine expressions. |
| 240 | /// When using dependent index reducton-based approach, it currently only |
| 241 | /// supports affine addition index expression. |
| 242 | static bool findSparseAnnotations(CodegenEnv &env, bool idxReducBased) { |
| 243 | bool annotated = false; |
| 244 | for (OpOperand &t : env.op()->getOpOperands()) { |
| 245 | const TensorId tid = env.makeTensorId(t: t.getOperandNumber()); |
| 246 | const auto map = env.op().getMatchingIndexingMap(opOperand: &t); |
| 247 | const auto enc = getSparseTensorEncoding(type: t.get().getType()); |
| 248 | if (enc) |
| 249 | annotated = true; |
| 250 | const Level lvlRank = map.getNumResults(); |
| 251 | assert(!enc || lvlRank == enc.getLvlRank()); |
| 252 | assert(static_cast<Level>(env.op().getRank(&t)) == lvlRank); |
| 253 | // We only need to do index reduction if there is at least one |
| 254 | // non-trivial index expression on sparse levels. If all non-trivial |
| 255 | // index expression is on dense levels, we can efficiently rely on |
| 256 | // the random access to locate the element. |
| 257 | bool needIdxReduc = |
| 258 | enc && getNumNonTrivialIdxExpOnSparseLvls(map, tensor: t.get()) != 0; |
| 259 | // If then current tensor being inspected requires affine index, it need |
| 260 | // to be sliced. |
| 261 | for (Level l = 0; l < lvlRank; l++) { |
| 262 | const AffineExpr a = map.getResult(idx: l); |
| 263 | const LevelType lt = enc.getLvlType(l); |
| 264 | if (idxReducBased && needIdxReduc) { |
| 265 | if (!findDepIdxSet(merger&: env.merger(), tensor: tid, lvl: l, a, lt)) |
| 266 | return false; // inadmissible affine expression |
| 267 | } else { |
| 268 | if (!findAffine(merger&: env.merger(), tid, lvl: l, a, lt)) |
| 269 | return false; // inadmissible affine expression |
| 270 | } |
| 271 | } |
| 272 | } |
| 273 | return annotated; |
| 274 | } |
| 275 | |
| 276 | //===----------------------------------------------------------------------===// |
| 277 | // Sparsifier synthesis methods (statements and expressions). |
| 278 | //===----------------------------------------------------------------------===// |
| 279 | |
| 280 | /// Local bufferization of all dense and sparse data structures. |
static void genBuffers(CodegenEnv &env, OpBuilder &builder) {
  linalg::GenericOp op = env.op();
  Location loc = op.getLoc();
  // Sparse kernels are expected to have exactly one output (init) operand.
  assert(op.getNumOperands() == op.getNumDpsInputs() + 1);

  // Compute the ranges (upper bounds) for all loops of the kernel.
  SmallVector<Range, 4> loopRange =
      llvm::cast<linalg::LinalgOp>(Val: op.getOperation())
          .createLoopRanges(b&: builder, loc);

  // Let the loop emitter materialize all buffers, providing callbacks for
  // output initialization and synthetic loop bounds.
  env.emitter().initializeLoopEmit(
      builder, loc,
      /// Generates buffer for the output tensor.
      /// Note that all sparse kernels assume that when all elements are written
      /// to (viz. x(i) = y(i) * z(i)), the output buffer is already initialized
      /// to all zeroes and only nonzeroes values are computed and written out.
      /// For updates (viz. x(i) += y(i) * z(i)), only nonzeroes values are used
      /// for the updates and no assumption on the original contents of the
      /// output buffer is necessary.
      updater: [&op](OpBuilder &builder, Location loc, Value memref,
             Value tensor) -> Value {
        // Must not be a sparse tensor.
        assert(!getSparseTensorEncoding(tensor.getType()));
        // Two output tensor references should point to the same object.
        OpOperand *lhs = op.getDpsInitOperand(i: 0);
        assert(lhs->get() == tensor);
        // An output tensor can simply materialize from the buffer of the tensor
        // that appears in the outs() clause. For updates, this has the
        // advantage that only the nonzero value are involved in the
        // computation, keeping the operation O(nnz). In all other cases, we are
        // forced to zero out the buffer to enforce the assumption above, which
        // may negatively impact running complexity (viz. O(n^2 + nnz) vs.
        // O(nnz) for matrices).
        // TODO: use better analysis to avoid zeroing out the buffer?
        bool isInit = op.isInitTensor(opOperand: lhs);
        Value init = memref;
        if (!isInit) {
          // Not an init tensor: fill the buffer with zeroes first.
          Value zero = constantZero(builder, loc,
                                    tp: getElementTypeOrSelf(type: tensor.getType()));
          builder.create<linalg::FillOp>(location: loc, args: ValueRange{zero},
                                         args: ValueRange{init});
        }
        return init;
      },
      // Provides the upper bound for loop level `l` as an SSA value.
      synSetter: [&loopRange](OpBuilder &b, Location loc, Level l) {
        assert(l < loopRange.size());
        return mlir::getValueOrCreateConstantIndexOp(b, loc, ofr: loopRange[l].size);
      });
}
| 329 | |
| 330 | /// Generates index for load/store on sparse tensor. |
| 331 | static Value genIndex(CodegenEnv &env, OpOperand *t) { |
| 332 | const auto map = env.op().getMatchingIndexingMap(opOperand: t); |
| 333 | const auto stt = getSparseTensorType(val: t->get()); |
| 334 | const Level lvlRank = stt.getLvlRank(); |
| 335 | assert(static_cast<Level>(map.getNumResults()) == lvlRank); |
| 336 | const AffineExpr a = map.getResult(idx: lvlRank - 1); |
| 337 | assert(a.getKind() == AffineExprKind::DimId); |
| 338 | const LoopId idx = env.makeLoopId(i: cast<AffineDimExpr>(Val: a).getPosition()); |
| 339 | return env.getLoopVar(i: idx); |
| 340 | } |
| 341 | |
| 342 | /// Generates subscript for load/store on a dense or sparse tensor. |
| 343 | static Value genSubscript(CodegenEnv &env, OpBuilder &builder, OpOperand *t, |
| 344 | SmallVectorImpl<Value> &args) { |
| 345 | const Location loc = env.op().getLoc(); |
| 346 | const TensorId tid = env.makeTensorId(t: t->getOperandNumber()); |
| 347 | const auto map = env.op().getMatchingIndexingMap(opOperand: t); |
| 348 | const auto stt = getSparseTensorType(val: t->get()); |
| 349 | if (stt.hasEncoding()) { |
| 350 | // For sparse tensors we only push the last-level's position onto `args`. |
| 351 | const auto pos = env.emitter().getValPosits(tid); |
| 352 | assert(!pos.empty()); |
| 353 | args.append(RHS: pos); |
| 354 | // Simply returns the tensor to extract value using iterators. |
| 355 | if (env.options().sparseEmitStrategy == SparseEmitStrategy::kSparseIterator) |
| 356 | return t->get(); |
| 357 | } else { |
| 358 | // For dense tensors we push all level's coordinates onto `args`. |
| 359 | const Level lvlRank = stt.getLvlRank(); |
| 360 | assert(static_cast<Level>(map.getNumResults()) == lvlRank); |
| 361 | for (Level l = 0; l < lvlRank; l++) { |
| 362 | const auto lvlExpr = map.getResult(idx: l); |
| 363 | const auto lvlCrd = env.emitter().genAffine(builder, loc, a: lvlExpr); |
| 364 | args.push_back(Elt: lvlCrd); |
| 365 | } |
| 366 | } |
| 367 | return env.emitter().getValBuffer()[tid]; |
| 368 | } |
| 369 | |
| 370 | /// Generates insertion code to implement dynamic tensor load. |
| 371 | static Value genInsertionLoad(CodegenEnv &env, OpBuilder &builder, |
| 372 | OpOperand *t) { |
| 373 | linalg::GenericOp op = env.op(); |
| 374 | Location loc = op.getLoc(); |
| 375 | // Direct lexicographic coordinate order, tensor loads as zero. |
| 376 | if (!env.isExpand()) { |
| 377 | Type tp = getElementTypeOrSelf(type: t->get().getType()); |
| 378 | return constantZero(builder, loc, tp); |
| 379 | } |
| 380 | // Load from expanded access pattern. |
| 381 | Value index = genIndex(env, t); |
| 382 | return builder.create<memref::LoadOp>(location: loc, args: env.getExpandValues(), args&: index); |
| 383 | } |
| 384 | |
| 385 | /// Generates insertion code to implement dynamic tensor load for reduction. |
| 386 | static Value genInsertionLoadReduce(CodegenEnv &env, OpBuilder &builder, |
| 387 | OpOperand *t) { |
| 388 | linalg::GenericOp op = env.op(); |
| 389 | Location loc = op.getLoc(); |
| 390 | Value identity = env.getCustomRedId(); |
| 391 | // Direct lexicographic coordinate order, tensor loads as identity. |
| 392 | if (!env.isExpand()) |
| 393 | return identity; |
| 394 | // Load from expanded access pattern if filled, identity otherwise. |
| 395 | Value values = env.getExpandValues(); |
| 396 | Value filled = env.getExpandFilled(); |
| 397 | Value index = genIndex(env, t); |
| 398 | Value isFilled = builder.create<memref::LoadOp>(location: loc, args&: filled, args&: index); |
| 399 | Value valAtIndex = builder.create<memref::LoadOp>(location: loc, args&: values, args&: index); |
| 400 | return builder.create<arith::SelectOp>(location: loc, args&: isFilled, args&: valAtIndex, args&: identity); |
| 401 | } |
| 402 | |
| 403 | static Value genConditionalInsert(Location loc, OpBuilder &builder, Value cond, |
| 404 | Value sparseOut, ValueRange ivs, Value v) { |
| 405 | scf::IfOp condInsert = |
| 406 | builder.create<scf::IfOp>(location: loc, args: sparseOut.getType(), args&: cond, args: true); |
| 407 | // True branch. |
| 408 | builder.setInsertionPointToStart(condInsert.thenBlock()); |
| 409 | Value res = builder.create<tensor::InsertOp>(location: loc, args&: v, args&: sparseOut, args&: ivs); |
| 410 | builder.create<scf::YieldOp>(location: loc, args&: res); |
| 411 | // False branch. |
| 412 | builder.setInsertionPointToStart(condInsert.elseBlock()); |
| 413 | builder.create<scf::YieldOp>(location: loc, args&: sparseOut); |
| 414 | // Value assignment. |
| 415 | builder.setInsertionPointAfter(condInsert); |
| 416 | return condInsert.getResult(i: 0); |
| 417 | } |
| 418 | |
| 419 | /// Generates insertion code to implement dynamic tensor store. |
/// Generates insertion code to implement dynamic tensor store.
static void genInsertionStore(CodegenEnv &env, OpBuilder &builder, OpOperand *t,
                              Value rhs) {
  linalg::GenericOp op = env.op();
  Location loc = op.getLoc();
  // Direct insertion in lexicographic coordinate order.
  if (!env.isExpand()) {
    const LoopId numLoops = op.getRank(opOperand: t);
    // Retrieves the first `numLoop` induction variables.
    SmallVector<Value> ivs = llvm::to_vector(Range: llvm::drop_end(
        RangeOrContainer: env.emitter().getLoopIVsRange(), N: env.getCurrentDepth() - numLoops));
    Value chain = env.getInsertionChain();
    if (env.isValidLexInsert()) {
      // Generates runtime check for a valid lex during reduction,
      // to avoid inserting the identity value for empty reductions.
      // if (validLexInsert) then
      //   insert(rhs) into chain
      //   return updated chain
      // else
      //   return unmodified chain
      Value out = genConditionalInsert(loc, builder, cond: env.getValidLexInsert(),
                                       sparseOut: chain, ivs, v: rhs);
      env.updateInsertionChain(chain: out);
    } else {
      Value sparseOut;
      if (!hasAnySparseType(types: env.op().getInputs().getTypes())) {
        // This is an all-dense -> sparse kernel, test rhs != 0 before
        // insertion.
        Value nz = genIsNonzero(builder, loc, v: rhs);
        sparseOut = genConditionalInsert(loc, builder, cond: nz, sparseOut: chain, ivs, v: rhs);
      } else {
        // At least one sparse input: insert unconditionally.
        sparseOut = builder.create<tensor::InsertOp>(location: loc, args&: rhs, args&: chain, args&: ivs);
      }
      // Generates regular insertion chain.
      env.updateInsertionChain(chain: sparseOut);
    }
    return;
  }
  // Generates insertion code along expanded access pattern.
  //   if (!expFilled[i]) then
  //     expFilled[i] = true
  //     expAdded[inserts++] = i
  //   endif
  //   values[i] = rhs
  Value values = env.getExpandValues();
  Value filled = env.getExpandFilled();
  Value added = env.getExpandAdded();
  Value count = env.getExpandCount();
  Value index = genIndex(env, t);
  Value fval = constantI1(builder, loc, b: false);
  Value tval = constantI1(builder, loc, b: true);
  // If statement: only bookkeep the coordinate on first fill.
  Value isFilled = builder.create<memref::LoadOp>(location: loc, args&: filled, args&: index);
  Value cond = builder.create<arith::CmpIOp>(location: loc, args: arith::CmpIPredicate::eq,
                                             args&: isFilled, args&: fval);
  scf::IfOp ifOp = builder.create<scf::IfOp>(location: loc, args: builder.getIndexType(), args&: cond,
                                             /*else=*/args: true);
  // True branch: mark filled, record coordinate, bump the insert count.
  builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
  builder.create<memref::StoreOp>(location: loc, args&: tval, args&: filled, args&: index);
  builder.create<memref::StoreOp>(location: loc, args&: index, args&: added, args&: count);
  Value one = constantIndex(builder, loc, i: 1);
  Value add = builder.create<arith::AddIOp>(location: loc, args&: count, args&: one);
  builder.create<scf::YieldOp>(location: loc, args&: add);
  // False branch: count stays unchanged.
  builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
  builder.create<scf::YieldOp>(location: loc, args&: count);
  builder.setInsertionPointAfter(ifOp);
  // Value assignment: update count and (unconditionally) store the value.
  env.updateExpandCount(count: ifOp.getResult(i: 0));
  builder.create<memref::StoreOp>(location: loc, args&: rhs, args&: values, args&: index);
}
| 491 | |
| 492 | /// Generates a load on a dense or sparse tensor. |
/// Generates a load on a dense or sparse tensor.
static Value genTensorLoad(CodegenEnv &env, OpBuilder &builder, ExprId exp) {
  // Test if the load was hoisted to a higher loop nest.
  Value val = env.exp(e: exp).val;
  if (val)
    return val;
  // Get tensor operand.
  linalg::GenericOp op = env.op();
  Location loc = op.getLoc();
  OpOperand *t = &op->getOpOperand(idx: env.exp(e: exp).tensor);
  // Fold binary-valued tensor into explicit value.
  const auto stt = getSparseTensorType(val: t->get());
  if (auto explVal = stt.getExplicitVal())
    return genValFromAttr(builder, loc, attr: explVal);
  // Load during insertion (the operand is the sparse output being built).
  if (env.isSparseOutput(o: t)) {
    if (env.isCustomReduc())
      return genInsertionLoadReduce(env, builder, t);
    return genInsertionLoad(env, builder, t);
  }

  // Actual load.
  SmallVector<Value> args;
  Value ptr = genSubscript(env, builder, t, args);
  // A tensor-typed subscript result indicates the sparse-iterator strategy:
  // extract the value through the iterator instead of a memref load.
  if (llvm::isa<TensorType>(Val: ptr.getType())) {
    assert(env.options().sparseEmitStrategy ==
           SparseEmitStrategy::kSparseIterator);
    return builder.create<ExtractValOp>(location: loc, args&: ptr, args&: llvm::getSingleElement(C&: args));
  }
  return builder.create<memref::LoadOp>(location: loc, args&: ptr, args);
}
| 523 | |
| 524 | /// Generates a store on a dense or sparse tensor. |
/// Generates a store on a dense or sparse tensor.
static void genTensorStore(CodegenEnv &env, OpBuilder &builder, ExprId exp,
                           Value rhs) {
  // Only unary and binary are allowed to return an uninitialized rhs
  // to indicate missing output. Or otherwise a custom reduction that
  // received no value to accumulate.
  if (!rhs) {
    assert(env.exp(exp).kind == TensorExp::Kind::kUnary ||
           env.exp(exp).kind == TensorExp::Kind::kBinary ||
           env.exp(exp).kind == TensorExp::Kind::kReduce);
    return;
  }
  // Test if this is a scalarized reduction.
  if (env.isReduc()) {
    env.updateReduc(val: rhs);
    return;
  }
  // Regular store.
  linalg::GenericOp op = env.op();
  Location loc = op.getLoc();
  OpOperand *t = op.getDpsInitOperand(i: 0);
  if (!env.isSparseOutput(o: t)) {
    // Dense output: plain memref store at the subscripted location.
    SmallVector<Value> args;
    Value ptr = genSubscript(env, builder, t, args);
    builder.create<memref::StoreOp>(location: loc, args&: rhs, args&: ptr, args);
    return;
  }
  // Store during sparse insertion.
  if (env.exp(e: exp).kind != TensorExp::Kind::kSelect) {
    genInsertionStore(env, builder, t, rhs);
    return;
  }
  // Select operation insertion: only insert when the select condition
  // (in `rhs`) holds, threading the insertion chain through an scf.if.
  Value chain = env.getInsertionChain();
  scf::IfOp ifOp =
      builder.create<scf::IfOp>(location: loc, args: chain.getType(), args&: rhs, /*else=*/args: true);
  builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
  // Existing value was preserved to be used here.
  assert(env.exp(exp).val);
  Value v0 = env.exp(e: exp).val;
  genInsertionStore(env, builder, t, rhs: v0);
  env.merger().clearExprValue(e: exp);
  // Yield modified insertion chain along true branch.
  Value mchain = env.getInsertionChain();
  builder.create<scf::YieldOp>(location: op.getLoc(), args&: mchain);
  // Yield original insertion chain along false branch.
  builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
  builder.create<scf::YieldOp>(location: loc, args&: chain);
  // Done with if statement.
  env.updateInsertionChain(chain: ifOp->getResult(idx: 0));
  builder.setInsertionPointAfter(ifOp);
}
| 576 | |
| 577 | /// Generates an invariant value. |
| 578 | inline static Value genInvariantValue(CodegenEnv &env, ExprId exp) { |
| 579 | return env.exp(e: exp).val; |
| 580 | } |
| 581 | |
| 582 | /// Semi-ring branches are simply inlined by the sparsifier. Prior |
| 583 | /// analysis has verified that all computations are "local" to the inlined |
| 584 | /// branch or otherwise invariantly defined outside the loop nest, with the |
| 585 | /// exception of index computations, which need to be relinked to actual |
| 586 | /// inlined cloned code. |
static Value relinkBranch(CodegenEnv &env, RewriterBase &rewriter, Block *block,
                          Value e) {
  if (auto arg = dyn_cast<BlockArgument>(Val&: e)) {
    // Direct arguments of the original linalg op must be converted
    // into dense tensor loads. Note that we should not encounter
    // anything else. This needs to be verified by semi-ring ops.
    linalg::GenericOp op = env.op();
    if (arg.getOwner()->getParentOp() == op) {
      const TensorId tid = env.makeTensorId(t: arg.getArgNumber());
      OpOperand *t = &op->getOpOperand(idx: tid);
      assert(!getSparseTensorType(t->get()).hasEncoding()); // dense!
      SmallVector<Value> args;
      Value ptr = genSubscript(env, builder&: rewriter, t, args);
      return rewriter.create<memref::LoadOp>(location: op.getLoc(), args&: ptr, args);
    }
  } else if (Operation *def = e.getDefiningOp()) {
    // Handle index computation: replace linalg.index by the loop variable.
    if (auto indexOp = dyn_cast<linalg::IndexOp>(Val: def))
      return env.getLoopVar(i: env.makeLoopId(i: indexOp.getDim()));
    // When still defined in new body, recurse into operands.
    if (def->getBlock() == block) {
      rewriter.setInsertionPoint(def);
      // Relink each operand in place so nested computations are fixed up too.
      for (unsigned i = 0, n = def->getNumOperands(); i < n; i++) {
        rewriter.modifyOpInPlace(root: def, callable: [&]() {
          def->setOperand(
              idx: i, value: relinkBranch(env, rewriter, block, e: def->getOperand(idx: i)));
        });
      }
    }
  }
  // Anything else is returned unchanged.
  return e;
}
| 619 | |
| 620 | /// Recursively generates tensor expression. |
/// Recursively generates tensor expression.
static Value genExp(CodegenEnv &env, RewriterBase &rewriter, ExprId e) {
  if (e == ::mlir::sparse_tensor::detail::kInvalidId)
    return Value();

  linalg::GenericOp op = env.op();
  Location loc = op.getLoc();
  const TensorExp &exp = env.exp(e);
  const auto kind = exp.kind;
  // Leaf cases: tensor loads, invariants, and loop variables.
  if (kind == TensorExp::Kind::kTensor)
    return genTensorLoad(env, builder&: rewriter, exp: e);
  if (kind == TensorExp::Kind::kInvariant)
    return genInvariantValue(env, exp: e);
  if (kind == TensorExp::Kind::kLoopVar)
    return env.getLoopVar(i: exp.loop);

  if (kind == TensorExp::Kind::kReduce)
    env.startCustomReduc(exp: e); // enter custom

  // If either lhs/rhs is a synthetic zero, we infer the type for the zero value
  // based on the type of the other operand.
  Value v0, v1;
  if (exp.children.e0 != ::mlir::sparse_tensor::detail::kInvalidId &&
      env.exp(e: exp.children.e0).kind == TensorExp::Kind::kSynZero) {
    v1 = genExp(env, rewriter, e: exp.children.e1);
    v0 = constantZero(builder&: rewriter, loc, tp: v1.getType());
  } else if (exp.children.e1 != ::mlir::sparse_tensor::detail::kInvalidId &&
             env.exp(e: exp.children.e1).kind == TensorExp::Kind::kSynZero) {
    v0 = genExp(env, rewriter, e: exp.children.e0);
    v1 = constantZero(builder&: rewriter, loc, tp: v0.getType());
  } else {
    // Regular case: recursively generate both children.
    v0 = genExp(env, rewriter, e: exp.children.e0);
    v1 = genExp(env, rewriter, e: exp.children.e1);
  }

  Value ee;
  if (kind == TensorExp::Kind::kReduce && (!v0 || !v1)) {
    // custom reduce did not receive a value
  } else {
    // Build the actual operation, then relink any semi-ring branches so
    // that inlined code refers to the cloned environment.
    ee = env.merger().buildExp(rewriter, loc, e, v0, v1);
    if (ee &&
        (kind == TensorExp::Kind::kUnary || kind == TensorExp::Kind::kBinary ||
         kind == TensorExp::Kind::kBinaryBranch ||
         kind == TensorExp::Kind::kReduce ||
         kind == TensorExp::Kind::kSelect)) {
      OpBuilder::InsertionGuard guard(rewriter);
      ee = relinkBranch(env, rewriter, block: ee.getParentBlock(), e: ee);
    }
  }

  if (kind == TensorExp::Kind::kReduce)
    env.endCustomReduc(); // exit custom

  if (kind == TensorExp::Kind::kSelect)
    env.merger().setExprValue(e, v: v0); // Preserve value for later use.

  return ee;
}
| 678 | |
/// Hoists loop invariant tensor loads for which indices have been exhausted.
///
/// Called once with isStart=true when entering a loop sequence (to hoist
/// loads and initialize scalarized reductions) and once with isStart=false
/// when exiting it (to store reduction results and clear cached values).
/// Recurses through the expression tree; only tensor leaves are acted upon.
static void genInvariants(CodegenEnv &env, OpBuilder &builder, ExprId exp,
                          LoopId curr, bool isStart) {
  if (exp == ::mlir::sparse_tensor::detail::kInvalidId)
    return;
  if (env.exp(e: exp).kind == TensorExp::Kind::kTensor) {
    // Inspect tensor indices.
    linalg::GenericOp op = env.op();
    OpOperand &t = op->getOpOperand(idx: env.exp(e: exp).tensor);
    const auto map = op.getMatchingIndexingMap(opOperand: &t);
    const auto stt = getSparseTensorType(val: t.get());
    const Level lvlRank = stt.getLvlRank();
    assert(static_cast<Level>(map.getNumResults()) == lvlRank);
    bool isCurrentLoop = curr == 0; // for scalar tensors
    for (Level l = 0; l < lvlRank; l++) {
      const AffineExpr a = map.getResult(idx: l);
      if (!isInvariantAffine(a, curr, /*out*/ isCurrentLoop))
        return; // still in play
    }
    // All exhausted at current level.
    if (!isCurrentLoop)
      return;
    // Generate code for a scalarized reduction or invariant. Note that
    // because custom reduction lhs may occur several times in the IR,
    // we have a built-in safety for only initializing and wrapping-up
    // the scalarized reduction once.
    OpOperand *lhs = op.getDpsInitOperand(i: 0);
    if (lhs == &t) {
      // Start or end a scalarized reduction.
      if (isStart) {
        // For a custom reduction, the identity value serves as the initial
        // reduction value; otherwise the current output value is loaded.
        if (env.isCustomReduc()) {
          if (!env.isReduc())
            env.startReduc(exp, val: env.getCustomRedId());
        } else {
          env.startReduc(exp, val: genTensorLoad(env, builder, exp));
        }
        // Sparse output needs a flag tracking whether a lex insert is valid.
        if (env.hasSparseOutput())
          env.startValidLexInsert(
              val: constantI1(builder, loc: env.op().getLoc(), b: false));
      } else {
        // Wrap up: store the final reduction value back into the output.
        if (!env.isCustomReduc() || env.isReduc())
          genTensorStore(env, builder, exp, rhs: env.endReduc());
        if (env.hasSparseOutput())
          env.endValidLexInsert();
      }
    } else {
      // Start or end loop invariant hoisting of a tensor load.
      if (isStart) {
        env.merger().setExprValue(e: exp, v: genTensorLoad(env, builder, exp));
      } else {
        env.merger().clearExprValue(e: exp);
      }
    }
  } else if (env.exp(e: exp).kind != TensorExp::Kind::kInvariant &&
             env.exp(e: exp).kind != TensorExp::Kind::kLoopVar &&
             env.exp(e: exp).kind != TensorExp::Kind::kSynZero) {
    // Traverse into the binary operations. Note that we only hoist
    // tensor loads, since subsequent MLIR/LLVM passes know how to
    // deal with all other kinds of derived loop invariants.
    if (env.exp(e: exp).kind == TensorExp::Kind::kReduce)
      env.startCustomReduc(exp); // enter custom
    const ExprId e0 = env.exp(e: exp).children.e0;
    const ExprId e1 = env.exp(e: exp).children.e1;
    genInvariants(env, builder, exp: e0, curr, isStart);
    genInvariants(env, builder, exp: e1, curr, isStart);
    if (env.exp(e: exp).kind == TensorExp::Kind::kReduce)
      env.endCustomReduc(); // exit custom
  }
}
| 748 | |
| 749 | /// Generates an expanded access pattern in innermost dimension. |
| 750 | static void genExpand(CodegenEnv &env, OpBuilder &builder, LoopId curr, |
| 751 | bool isStart) { |
| 752 | linalg::GenericOp op = env.op(); |
| 753 | OpOperand *lhs = op.getDpsInitOperand(i: 0); |
| 754 | if (!env.atExpandLevel(o: lhs, rank: op.getRank(opOperand: lhs), n: curr)) |
| 755 | return; // not needed at current level |
| 756 | assert(!env.isReduc()); |
| 757 | // Generate start or end of an expanded access pattern. Note that because |
| 758 | // an expansion does not rely on the ongoing contents of the sparse storage |
| 759 | // scheme, we can use the original tensor as incoming SSA value (which |
| 760 | // simplifies codegen a bit). If expansion on the actual contents is ever |
| 761 | // needed, we will need to use the SSA value in the insertion chain instead. |
| 762 | Value tensor = lhs->get(); |
| 763 | Location loc = op.getLoc(); |
| 764 | if (isStart) { |
| 765 | auto dynShape = {ShapedType::kDynamic}; |
| 766 | Type etp = cast<ShapedType>(Val: tensor.getType()).getElementType(); |
| 767 | Type t1 = MemRefType::get(shape: dynShape, elementType: etp); |
| 768 | Type t2 = MemRefType::get(shape: dynShape, elementType: builder.getI1Type()); |
| 769 | Type t3 = MemRefType::get(shape: dynShape, elementType: builder.getIndexType()); |
| 770 | Type t4 = builder.getIndexType(); |
| 771 | auto r = builder.create<ExpandOp>(location: loc, args: TypeRange({t1, t2, t3, t4}), args&: tensor); |
| 772 | assert(r.getNumResults() == 4); |
| 773 | env.startExpand(values: r.getResult(i: 0), filled: r.getResult(i: 1), added: r.getResult(i: 2), |
| 774 | count: r.getResult(i: 3)); |
| 775 | } else { |
| 776 | SmallVector<Value> indices; |
| 777 | for (LoopId i = 0; i < curr; i++) |
| 778 | indices.push_back(Elt: env.emitter().getLoopIV(n: i)); |
| 779 | Value values = env.getExpandValues(); |
| 780 | Value filled = env.getExpandFilled(); |
| 781 | Value added = env.getExpandAdded(); |
| 782 | Value count = env.getExpandCount(); |
| 783 | Value chain = env.getInsertionChain(); |
| 784 | Value compress = builder.create<CompressOp>(location: loc, args&: values, args&: filled, args&: added, |
| 785 | args&: count, args&: chain, args&: indices); |
| 786 | env.updateInsertionChain(chain: compress); |
| 787 | env.endExpand(); |
| 788 | } |
| 789 | } |
| 790 | |
| 791 | /// Returns parallelization strategy. Any implicit loop in the Linalg |
| 792 | /// operation that is marked "parallel" is a candidate. Whether it is actually |
| 793 | /// converted to a parallel operation depends on the requested strategy. |
| 794 | static bool isParallelFor(CodegenEnv &env, bool isOuter, bool isSparse) { |
| 795 | // Reject parallelization of sparse output. |
| 796 | if (env.hasSparseOutput()) |
| 797 | return false; |
| 798 | // Parallel loops on tensor expansion can cause data races. |
| 799 | if (env.isExpand()) |
| 800 | return false; |
| 801 | // Inspect strategy. |
| 802 | switch (env.options().parallelizationStrategy) { |
| 803 | case SparseParallelizationStrategy::kNone: |
| 804 | return false; |
| 805 | case SparseParallelizationStrategy::kDenseOuterLoop: |
| 806 | return isOuter && !isSparse; |
| 807 | case SparseParallelizationStrategy::kAnyStorageOuterLoop: |
| 808 | return isOuter; |
| 809 | case SparseParallelizationStrategy::kDenseAnyLoop: |
| 810 | return !isSparse; |
| 811 | case SparseParallelizationStrategy::kAnyStorageAnyLoop: |
| 812 | return true; |
| 813 | } |
| 814 | llvm_unreachable("unexpected parallelization strategy" ); |
| 815 | } |
| 816 | |
| 817 | /// Whether or not the current loop being generated should be parallized (if |
| 818 | /// possible) according to the configuration. |
| 819 | static bool shouldTryParallize(CodegenEnv &env, LoopId curr, |
| 820 | ArrayRef<TensorLevel> tidLvls) { |
| 821 | linalg::GenericOp op = env.op(); |
| 822 | auto iteratorTypes = op.getIteratorTypesArray(); |
| 823 | bool isSparse = llvm::any_of(Range&: tidLvls, P: [curr, &env](TensorLevel tidLvl) { |
| 824 | // Queries the LT based on the tensor and loop id, as requested by |
| 825 | // `CodegenEnv::lt(TensorId, LoopId)`. The returned LT from CodegenEnv |
| 826 | // should be consistent with the LT indexed by <TensorId, Level>. |
| 827 | const auto lt = env.lt(t: env.unpackTensorLevel(tl: tidLvl).first, i: curr); |
| 828 | return lt.hasSparseSemantic(); |
| 829 | }); |
| 830 | return isParallelFor(env, /*isOuter=*/curr == 0, isSparse); |
| 831 | } |
| 832 | |
| 833 | /// Emit a loop to coiterate over the list of tensor levels. The generated loop |
| 834 | /// can either be a for loop or while loop depending on whether there is at most |
| 835 | /// one sparse level in the list. |
| 836 | static Operation *genCoIteration(CodegenEnv &env, OpBuilder &builder, |
| 837 | ArrayRef<TensorLevel> tidLvls, |
| 838 | unsigned numCases, bool tryParallel, |
| 839 | bool needsUniv) { |
| 840 | Operation *loop = *env.genLoopBoundary(callback: [&](MutableArrayRef<Value> reduc) { |
| 841 | // Construct while-loop with a parameter for each index. |
| 842 | return env.emitter().enterCoIterationOverTensorsAtLvls( |
| 843 | builder, loc: env.op().getLoc(), tidLvls, numCases, reduc, isParallel: tryParallel, |
| 844 | needsUniv); |
| 845 | }); |
| 846 | assert(loop); |
| 847 | return loop; |
| 848 | } |
| 849 | |
| 850 | /// Generates a for-loop or a while-loop, depending on whether it implements |
| 851 | /// singleton iteration or co-iteration over the given conjunction. |
| 852 | static Operation *genLoop(CodegenEnv &env, OpBuilder &builder, LoopId curr, |
| 853 | unsigned numCases, bool needsUniv, |
| 854 | ArrayRef<TensorLevel> tidLvls) { |
| 855 | bool tryParallel = shouldTryParallize(env, curr, tidLvls); |
| 856 | return genCoIteration(env, builder, tidLvls, numCases, tryParallel, |
| 857 | needsUniv); |
| 858 | } |
| 859 | |
/// Generates the induction structure for a while-loop.
///
/// Walks outward through any scf.if statements still enclosing the current
/// insertion point and, for each, emits the else-branch yield that threads
/// the reduction value, valid-lex-insert flag, expansion count, and insertion
/// chain through the if results. The yield operand order here must match the
/// result-type order set up in genIf.
static void finalizeWhileOp(CodegenEnv &env, OpBuilder &builder,
                            bool needsUniv) {
  Location loc = env.op().getLoc();
  // Finalize each else branch of all if statements.
  if (env.isReduc() || env.isExpand() || env.getInsertionChain()) {
    while (auto ifOp = dyn_cast_or_null<scf::IfOp>(
               Val: builder.getInsertionBlock()->getParentOp())) {
      // Break on IfOp for slicing filtering.
      if (ifOp->getAttr(name: LoopEmitter::getLoopEmitterLoopAttrName()) ==
          StringAttr::get(context: ifOp->getContext(), bytes: "slice" ))
        break;

      // Collect yield operands and rebind env state to the if results,
      // keeping `y` in sync with the if's result index.
      unsigned y = 0;
      SmallVector<Value> yields;
      if (env.isReduc()) {
        yields.push_back(Elt: env.getReduc());
        env.updateReduc(val: ifOp.getResult(i: y++));
        if (env.isValidLexInsert()) {
          yields.push_back(Elt: env.getValidLexInsert());
          env.updateValidLexInsert(val: ifOp.getResult(i: y++));
        }
      }
      if (env.isExpand()) {
        yields.push_back(Elt: env.getExpandCount());
        env.updateExpandCount(count: ifOp->getResult(idx: y++));
      }
      if (env.getInsertionChain()) {
        yields.push_back(Elt: env.getInsertionChain());
        env.updateInsertionChain(chain: ifOp->getResult(idx: y++));
      }
      assert(y == yields.size());
      builder.create<scf::YieldOp>(location: loc, args&: yields);
      // Continue with the next enclosing if, if any.
      builder.setInsertionPointAfter(ifOp);
    }
  }
  // No need to set the insertion point here as LoopEmitter keeps track of the
  // basic block where scf::Yield should be inserted.
}
| 899 | |
| 900 | /// Generates a case region in the coiterate operation. |
| 901 | static void genCoIterationCase(CodegenEnv &env, OpBuilder &builder, |
| 902 | unsigned caseIdx, LatPointId allCase, |
| 903 | LatPointId curCase, |
| 904 | MutableArrayRef<Value> reduc) { |
| 905 | assert(allCase == curCase || env.merger().latGT(allCase, curCase)); |
| 906 | const BitVector &allCaseBits = env.merger().lat(p: allCase).simple; |
| 907 | const BitVector &curCaseBits = env.merger().lat(p: curCase).simple; |
| 908 | |
| 909 | /// Computes the subset of iterators that are valid in the current case being |
| 910 | /// generated. |
| 911 | I64BitSet caseBit(0); |
| 912 | for (auto [idx, set] : llvm::enumerate(First: allCaseBits.set_bits())) |
| 913 | if (curCaseBits.test(Idx: set)) |
| 914 | caseBit.set(idx); |
| 915 | |
| 916 | env.emitter().enterCurrentCoIterationCase(builder, loc: env.op().getLoc(), caseBit, |
| 917 | caseIdx, reduc); |
| 918 | } |
| 919 | |
/// Generates a single if-statement within a while-loop.
///
/// Builds the condition as a conjunction over all tensor-loop ids in lattice
/// point `p`: sparse levels contribute a coordinate-equals-loop-var test,
/// dense/undefined levels contribute constant true. The if-result types are
/// set up in the order (reduction, valid-lex-insert, expand count, insertion
/// chain), which endIf/finalizeWhileOp must mirror when yielding. Leaves the
/// insertion point at the start of the then-region.
static scf::IfOp genIf(CodegenEnv &env, OpBuilder &builder, LoopId curr,
                       LatPointId p) {
  Location loc = env.op().getLoc();
  SmallVector<Type> types;
  Value cond;
  env.merger().foreachTensorLoopId(
      p, /*simple=*/true,
      callback: [&](TensorLoopId b, TensorId tid, std::optional<Level> lvl, LevelType lt,
          bool isIdxRed) {
        if (isIdxRed) {
          // Since there is no 1:1 mapping from loop to level (multiple loops
          // are required to resolve one level with non-trivial index
          // expression), we need to reconstruct the tensor level types if this
          // loop requires index reduction condition.
          assert(lvl.has_value() && isUndefLT(lt));
          auto stt = getSparseTensorType(val: env.op().getInputs()[tid]);
          lt = stt.getLvlType(l: *lvl);
        }
        assert(curr == env.merger().loop(b));
        Value clause;
        if (lt.hasSparseSemantic()) {
          assert(lvl.has_value());
          // Sparse clause: the stored coordinate matches the loop index.
          const Value crd = env.emitter().getCoord(tid, lvl: *lvl);
          const Value lvar = env.getLoopVar(i: curr);
          clause = builder.create<arith::CmpIOp>(location: loc, args: arith::CmpIPredicate::eq,
                                                 args: crd, args: lvar);
        } else {
          // Dense or undefined levels always participate.
          assert(lt.hasDenseSemantic() || isUndefLT(lt));
          clause = constantI1(builder, loc, b: true);
        }
        cond = cond ? builder.create<arith::AndIOp>(location: loc, args&: cond, args&: clause) : clause;
      });
  // Result types, in the fixed order mirrored by endIf/finalizeWhileOp.
  if (env.isReduc()) {
    types.push_back(Elt: env.getReduc().getType());
    if (env.isValidLexInsert())
      types.push_back(Elt: env.getValidLexInsert().getType());
  }
  if (env.isExpand())
    types.push_back(Elt: builder.getIndexType());
  if (env.getInsertionChain())
    types.push_back(Elt: env.getInsertionChain().getType());
  scf::IfOp ifOp = builder.create<scf::IfOp>(location: loc, args&: types, args&: cond, /*else=*/args: true);
  builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
  return ifOp;
}
| 966 | |
| 967 | /// Generates end of true branch of if-statement within a while-loop. |
| 968 | static void endIf(CodegenEnv &env, OpBuilder &builder, scf::IfOp ifOp, |
| 969 | Value redInput, Value cntInput, Value insInput, |
| 970 | Value validIns) { |
| 971 | SmallVector<Value> operands; |
| 972 | if (env.isReduc()) { |
| 973 | operands.push_back(Elt: env.getReduc()); |
| 974 | env.updateReduc(val: redInput); |
| 975 | if (env.isValidLexInsert()) { |
| 976 | // Any overlapping indices during a reduction creates a valid lex insert. |
| 977 | operands.push_back(Elt: constantI1(builder, loc: env.op().getLoc(), b: true)); |
| 978 | env.updateValidLexInsert(val: validIns); |
| 979 | } |
| 980 | } |
| 981 | if (env.isExpand()) { |
| 982 | operands.push_back(Elt: env.getExpandCount()); |
| 983 | env.updateExpandCount(count: cntInput); |
| 984 | } |
| 985 | if (env.getInsertionChain()) { |
| 986 | operands.push_back(Elt: env.getInsertionChain()); |
| 987 | env.updateInsertionChain(chain: insInput); |
| 988 | } |
| 989 | if (!operands.empty()) |
| 990 | builder.create<scf::YieldOp>(location: env.op().getLoc(), args&: operands); |
| 991 | builder.setInsertionPointToStart(&ifOp.getElseRegion().front()); |
| 992 | } |
| 993 | |
| 994 | //===----------------------------------------------------------------------===// |
| 995 | // Sparsifier synthesis methods (loop sequence). |
| 996 | //===----------------------------------------------------------------------===// |
| 997 | |
/// Reports, via `callback`, every <tensor, level> pair that participates in
/// the loop for lattice point `li` at loop `curr`. Pairs that drive the loop
/// condition are reported with a null AffineExpr; dense levels whose
/// non-trivial affine expression just became invariant are reported together
/// with that expression so their address can be generated at this level.
/// Returns true iff the loop can be emitted as a simple for-loop (single
/// loop condition on a unique level, or iterator-based emission).
static bool getAllTidLvlsInLatPoints(
    CodegenEnv &env, LatPointId li, LoopId curr,
    llvm::function_ref<void(TensorLevel, AffineExpr)> callback) {
  const BitVector &simple = env.lat(l: li).simple;
  const TensorId outTid = env.merger().getOutTensorID();
  const std::optional<Level> outLvl = env.merger().getLvl(t: outTid, i: curr);

  // Number of <tid, lvl> pairs that contribute a loop condition, and whether
  // any of them sits on a non-unique level (which forces a while-loop).
  unsigned numloopCond = 0;
  bool hasNonUnique = false;
  env.merger().foreachTensorLoopId(
      p: li, callback: [&, curr](TensorLoopId b, TensorId tid, std::optional<Level> lvl,
                   LevelType lt, bool isIdxReduc) {
        if (simple[b]) {
          // This tensor-loop id is part of the (simplified) loop condition.
          if (isIdxReduc) {
            callback(env.makeTensorLevel(t: tid, l: *lvl), nullptr);
            numloopCond++;
            return;
          }
          if (isUndefLT(lt)) {
            // An undefined lt in the lattices, we probably mean to
            // generate a dense loop according to the synthetic tensor (for
            // invariants and sparse output tensor).
            if (env.merger().getSynTensorID() == tid) {
              // Coiterating with an invariant
              // e.g., out = prod(in[i][j] op invariant);
              // or a broadcast
              // e.g., out[i][j] = in[i] (j is undef for input)
              //
              // The level of the synthetic tensor is the current loop depth;
              // the rank of the synthetic tensor equals to number of loops.
              assert(curr == env.getCurrentDepth());
              lvl = curr;
            } else if (!lvl) {
              // Skips invalid lvl (e.g., when this is a zero ranked tensor).
              return;
            }
          }
          hasNonUnique = !isUniqueLT(lt) || hasNonUnique;
          callback(env.makeTensorLevel(t: tid, l: *lvl), nullptr);
          numloopCond++;
        } else if (lt.hasDenseSemantic() || isIdxReduc) {
          // Dense level (or index reduction) that must be co-iterated but
          // does not constrain the loop condition.
          callback(env.makeTensorLevel(t: tid, l: *lvl), nullptr);
        } else {
          assert(isUndefLT(lt));
          linalg::GenericOp op = env.op();
          if (tid >= op.getNumDpsInputs())
            // We only handle affine expression on input tensors (for now).
            return;
          OpOperand *operand = &op->getOpOperand(idx: tid);
          const auto stt = getSparseTensorType(val: operand->get());
          // Non-annotated dense tensors requires no special handling.
          if (!stt.hasEncoding())
            return;

          ArrayRef<AffineExpr> affines =
              op.getMatchingIndexingMap(opOperand: operand).getResults();
          const Level lvlRank = stt.getLvlRank();
          assert(affines.size() == static_cast<size_t>(lvlRank));
          for (Level l = 0; l < lvlRank; l++) {
            AffineExpr exp = affines[l];
            // Skip simple affine expression and non-dense levels (which
            // have their own filter loop).
            LevelType lt = stt.getLvlType(l);
            if (isa<AffineDimExpr>(Val: exp) || !lt.hasDenseSemantic())
              continue;

            // Constant affine expression are handled in genLoop.
            if (!isa<AffineConstantExpr>(Val: exp)) {
              bool isCurrentLoop = false;
              assert(curr == env.getCurrentDepth());
              if (isInvariantAffine(a: exp, curr: curr + 1, /*out*/ isCurrentLoop) &&
                  isCurrentLoop) {
                // If the compound affine is invariant and we are right at the
                // level. We need to generate the address according to the
                // affine expression. This is also the best place we can do it
                // to avoid putting it inside inner loops.
                callback(env.makeTensorLevel(t: tid, l), exp);
              }
            }
          }
        }
      });

  if (isDenseLT(lt: env.lt(t: outTid, i: curr))) {
    auto stt = getSparseTensorType(val: env.op().getOutputs().front());
    // Note that we generate dense indices of the output tensor unconditionally,
    // since they may not appear in the lattice, but may be needed for
    // linearized env.
    // TODO: we should avoid introducing corner cases for all-dense sparse
    // tensors.
    if (stt.hasEncoding() && stt.isAllDense())
      callback(env.makeTensorLevel(t: outTid, l: *outLvl), nullptr);
  }

  if (numloopCond == 0) {
    // Corner cases where the loop bound is defined by a *unused* operand, in
    // this case, we just generate a dense "fake" loop by iterating over the
    // synthetic tensor.
    callback(env.makeTensorLevel(t: env.merger().getSynTensorID(), l: curr), nullptr);
    numloopCond++;
  }
  // If we just need to one loop conditions and the conditions is not imposed on
  // non-unique level, the loop can be generated by a for loop.
  // Or, if we are generating sparse-iterator-based loops, we always generate
  // `sparse_tensor.iterate` regardless whether the level is unique or not.
  return numloopCond == 1 &&
         (!hasNonUnique || env.options().sparseEmitStrategy ==
                               SparseEmitStrategy::kSparseIterator);
}
| 1107 | |
| 1108 | /// Starts a loop sequence at given level. Returns true if |
| 1109 | /// the universal loop index must be maintained at this level. |
| 1110 | static bool startLoopSeq(CodegenEnv &env, OpBuilder &builder, ExprId exp, |
| 1111 | LoopId curr, LatSetId lts) { |
| 1112 | assert(!env.getLoopVar(curr)); |
| 1113 | // Emit invariants at this loop sequence level. |
| 1114 | genInvariants(env, builder, exp, curr, /*isStart=*/true); |
| 1115 | // Emit access pattern expansion for sparse tensor output. |
| 1116 | genExpand(env, builder, curr, /*isStart=*/true); |
| 1117 | // Emit further initialization at this loop sequence level. |
| 1118 | const LatPointId l0 = env.set(lts)[0]; |
| 1119 | |
| 1120 | SmallVector<TensorLevel> tidLvls; |
| 1121 | getAllTidLvlsInLatPoints(env, li: l0, curr, callback: [&](TensorLevel tl, AffineExpr) { |
| 1122 | // TODO: remove this! The same tensor level might be added for multiple |
| 1123 | // times due to the special handling for all-dense "sparse" output tensor |
| 1124 | // (see L1038). |
| 1125 | if (llvm::is_contained(Range&: tidLvls, Element: tl)) |
| 1126 | return; |
| 1127 | tidLvls.emplace_back(Args&: tl); |
| 1128 | }); |
| 1129 | |
| 1130 | env.emitter().enterNewLoopSeq(builder, loc: env.op().getLoc(), tidLvls); |
| 1131 | |
| 1132 | // Maintain the universal index only if it is actually |
| 1133 | // consumed by a subsequent lattice point. |
| 1134 | for (const LatPointId li : env.set(lts).drop_front()) |
| 1135 | if (!env.merger().hasAnySparse(bits: env.lat(l: li).simple)) |
| 1136 | return true; |
| 1137 | |
| 1138 | return false; |
| 1139 | } |
| 1140 | |
| 1141 | // Generates dense affine address for encoding. |
| 1142 | static void genConstantDenseAddressFromLevel(CodegenEnv &env, |
| 1143 | OpBuilder &builder, TensorId tid, |
| 1144 | Level startLvl) { |
| 1145 | // TODO: Handle affine expression on output tensor. |
| 1146 | linalg::GenericOp op = env.op(); |
| 1147 | assert(tid < op.getNumDpsInputs()); |
| 1148 | OpOperand *input = op.getDpsInputOperands()[tid]; |
| 1149 | const auto lvlExprs = op.getMatchingIndexingMap(opOperand: input).getResults(); |
| 1150 | const auto enc = getSparseTensorEncoding(type: input->get().getType()); |
| 1151 | if (enc) { |
| 1152 | const Location loc = op.getLoc(); |
| 1153 | const TensorId tid = env.makeTensorId(t: input->getOperandNumber()); |
| 1154 | const Level lvlRank = enc.getLvlRank(); |
| 1155 | assert(lvlExprs.size() == static_cast<size_t>(lvlRank)); |
| 1156 | for (Level l = startLvl; l < lvlRank; l++) { |
| 1157 | AffineExpr lvlExpr = lvlExprs[l]; |
| 1158 | if (enc.getLvlType(l).hasDenseSemantic() && |
| 1159 | isa<AffineConstantExpr>(Val: lvlExpr)) |
| 1160 | env.emitter().locateLvlAtAffineAddress( |
| 1161 | builder, loc, tidLvl: env.makeTensorLevel(t: tid, l), lvlExpr); |
| 1162 | else |
| 1163 | return; // break on first non-dense non-constant level |
| 1164 | } |
| 1165 | } |
| 1166 | } |
| 1167 | |
| 1168 | // We can generate address for constant affine expression before any loops |
| 1169 | // starting from the first level as they do not depend on anything. |
| 1170 | // E.g., [Dense, Dense, Sparse] -> (1, 2, d0), the addresses for the first two |
| 1171 | // levels can be determined before loops. |
| 1172 | static void genInitConstantDenseAddress(CodegenEnv &env, |
| 1173 | RewriterBase &rewriter) { |
| 1174 | for (TensorId tid = 0, e = env.op().getNumDpsInputs(); tid < e; tid++) |
| 1175 | genConstantDenseAddressFromLevel(env, builder&: rewriter, tid, startLvl: 0); |
| 1176 | } |
| 1177 | |
| 1178 | /// Returns true if the lattice bit can be iterated by a for loop. |
| 1179 | static bool translateBitsToTidLvlPairs( |
| 1180 | CodegenEnv &env, LatPointId li, LoopId curr, |
| 1181 | SmallVectorImpl<TensorLevel> &tidLvls, |
| 1182 | SmallVectorImpl<std::pair<TensorLevel, AffineExpr>> &affineTidLvls) { |
| 1183 | return getAllTidLvlsInLatPoints(env, li, curr, |
| 1184 | callback: [&](TensorLevel tl, AffineExpr exp) { |
| 1185 | if (exp) |
| 1186 | affineTidLvls.emplace_back(Args&: tl, Args&: exp); |
| 1187 | else |
| 1188 | tidLvls.emplace_back(Args&: tl); |
| 1189 | }); |
| 1190 | } |
| 1191 | |
| 1192 | /// Starts a single loop in current sequence. |
| 1193 | static std::pair<Operation *, bool> startLoop(CodegenEnv &env, |
| 1194 | OpBuilder &builder, LoopId curr, |
| 1195 | LatPointId li, unsigned numCases, |
| 1196 | bool needsUniv) { |
| 1197 | // TODO: numCases only used when generating iterator-based loops. Cleanup |
| 1198 | // after fully migration. |
| 1199 | // The set of tensors + lvls to generate loops on |
| 1200 | SmallVector<TensorLevel> tidLvls; |
| 1201 | |
| 1202 | // The set of dense tensors with non-trivial affine expression that just |
| 1203 | // becomes invariant and the address are generated at the current level. |
| 1204 | SmallVector<std::pair<TensorLevel, AffineExpr>> affineTidLvls; |
| 1205 | bool isSingleCond = |
| 1206 | translateBitsToTidLvlPairs(env, li, curr, tidLvls, affineTidLvls); |
| 1207 | |
| 1208 | // Emit the for/while-loop control. |
| 1209 | Operation *loop = genLoop(env, builder, curr, numCases, needsUniv, tidLvls); |
| 1210 | Location loc = env.op().getLoc(); |
| 1211 | for (auto [tidLvl, exp] : affineTidLvls) { |
| 1212 | env.emitter().locateLvlAtAffineAddress(builder, loc, tidLvl, lvlExpr: exp); |
| 1213 | } |
| 1214 | |
| 1215 | // Until now, we have entered every <tid, lvl> pair in {cond, extra, |
| 1216 | // affine}Tids/Lvls. The addresses of the upcoming levels which are dependent |
| 1217 | // on constant affines expression may now be determined. |
| 1218 | auto allTidLvls = |
| 1219 | llvm::concat<TensorLevel>(Ranges&: tidLvls, Ranges: llvm::make_first_range(c&: affineTidLvls)); |
| 1220 | for (auto [tid, lvl] : env.unpackTensorLevelRange(c&: allTidLvls)) { |
| 1221 | if (tid != env.merger().getOutTensorID() && |
| 1222 | tid != env.merger().getSynTensorID()) |
| 1223 | genConstantDenseAddressFromLevel(env, builder, tid, startLvl: lvl + 1); |
| 1224 | } |
| 1225 | |
| 1226 | return std::make_pair(x&: loop, y&: isSingleCond); |
| 1227 | } |
| 1228 | |
| 1229 | /// Ends a single loop in current sequence. Returns new values for needsUniv. |
| 1230 | static bool endLoop(CodegenEnv &env, RewriterBase &rewriter, Operation *loop, |
| 1231 | LatPointId li, bool needsUniv, bool isSingleCond) { |
| 1232 | // Either a for-loop or a while-loop that iterates over a slice. |
| 1233 | if (isSingleCond) { |
| 1234 | // Any iteration creates a valid lex insert. |
| 1235 | if (env.isReduc() && env.isValidLexInsert()) |
| 1236 | env.updateValidLexInsert(val: constantI1(builder&: rewriter, loc: env.op().getLoc(), b: true)); |
| 1237 | } else if (auto whileOp = dyn_cast<scf::WhileOp>(Val: loop)) { |
| 1238 | // End a while-loop. |
| 1239 | finalizeWhileOp(env, builder&: rewriter, needsUniv); |
| 1240 | } else { |
| 1241 | needsUniv = false; |
| 1242 | } |
| 1243 | env.genLoopBoundary(callback: [&](MutableArrayRef<Value> reduc) { |
| 1244 | env.emitter().exitCurrentLoop(rewriter, loc: env.op().getLoc(), reduc); |
| 1245 | return std::nullopt; |
| 1246 | }); |
| 1247 | return needsUniv; |
| 1248 | } |
| 1249 | |
| 1250 | /// Ends a loop sequence at given level. |
| 1251 | static void endLoopSeq(CodegenEnv &env, OpBuilder &builder, unsigned exp, |
| 1252 | unsigned at) { |
| 1253 | assert(!env.getLoopVar(at)); |
| 1254 | env.emitter().exitCurrentLoopSeq(builder, loc: env.op().getLoc()); |
| 1255 | // Unmark bookkeeping of invariants and loop index. |
| 1256 | genInvariants(env, builder, exp, curr: at, /*isStart=*/false); |
| 1257 | // Finalize access pattern expansion for sparse tensor output. |
| 1258 | genExpand(env, builder, curr: at, /*isStart=*/false); |
| 1259 | } |
| 1260 | |
| 1261 | /// Recursively generates code while computing iteration lattices in order |
| 1262 | /// to manage the complexity of implementing co-iteration over unions |
| 1263 | /// and intersections of sparse iterations spaces. |
static void genStmt(CodegenEnv &env, RewriterBase &rewriter, ExprId exp,
                    LoopId curr) {
  assert(curr == env.getCurrentDepth());

  // At each leaf, assign remaining tensor (sub)expression to output tensor.
  if (curr == env.getLoopNum()) {
    Value rhs = genExp(env, rewriter, e: exp);
    genTensorStore(env, builder&: rewriter, exp, rhs);
    return;
  }

  // Construct iteration lattices for current loop index.
  const LatSetId lts =
      env.merger().optimizeSet(s: env.merger().buildLattices(e: exp, i: curr));

  // Start a loop sequence. The returned flag records whether a universal
  // (dense) loop index is still needed to cover the iteration space.
  bool needsUniv = startLoopSeq(env, builder&: rewriter, exp, curr, lts);

  // When using sparse-iterator-based loops, we only need one loop, as
  // opposed to a loop sequence, to cover all the iterator spaces.
  const unsigned lsize = env.set(lts).size();
  if (env.generatingSparseIterator()) {
    // Get the largest lattice point and start a loop.
    const LatPointId li = env.set(lts)[0];
    auto [loop, isSingleCond] =
        startLoop(env, builder&: rewriter, curr, li, numCases: lsize, needsUniv);
    assert(isSingleCond == llvm::isa<IterateOp>(loop));
    // We cannot change this to `for (const LatPointId li : env.set(lts))`
    // because the loop body causes data-movement which invalidates
    // the iterator.
    for (unsigned j = 0; j < lsize; j++) {
      const LatPointId lj = env.set(lts)[j];
      const ExprId ej = env.lat(l: lj).exp;
      // Recurse into body of each branch.
      if (!isSingleCond) {
        // Emit one case region per lattice point; each case recurses into
        // the next loop level before yielding back to the co-iteration.
        env.genLoopBoundary(callback: [&, curr, j, li, lj](MutableArrayRef<Value> reduc) {
          genCoIterationCase(env, builder&: rewriter, /*caseIdx*/ j, allCase: li, curCase: lj, reduc);
          genStmt(env, rewriter, exp: ej, curr: curr + 1);
          // TODO: handle yield values.
          assert(reduc.empty() && "Not Implemented" );
          rewriter.create<sparse_tensor::YieldOp>(location: env.op().getLoc());
          return std::nullopt;
        });
        // endIf(env, rewriter, ifOp, redInput, cntInput, insInput, validIns);
      } else {
        genStmt(env, rewriter, exp: ej, curr: curr + 1);
      }
    }
    // End a loop.
    needsUniv = endLoop(env, rewriter, loop, li: curr, needsUniv, isSingleCond);
  } else {
    // Emit a loop for every lattice point L0 >= Li in this loop sequence.
    for (unsigned i = 0; i < lsize; i++) {
      const LatPointId li = env.set(lts)[i];
      // Start a loop.
      auto [loop, isSingleCond] =
          startLoop(env, builder&: rewriter, curr, li, numCases: lsize, needsUniv);

      // Visit all lattices points with Li >= Lj to generate the
      // loop-body, possibly with if statements for coiteration.
      // Snapshot the reduction/expansion/insertion state on loop entry so
      // each generated `if` can be finalized against these input values.
      Value redInput = env.getReduc();
      Value cntInput = env.getExpandCount();
      Value insInput = env.getInsertionChain();
      Value validIns = env.getValidLexInsert();
      // We cannot change this to `for (const LatPointId lj : env.set(lts))`
      // because the loop body causes data-movement which invalidates the
      // iterator.
      for (unsigned j = 0; j < lsize; j++) {
        const LatPointId lj = env.set(lts)[j];
        const ExprId ej = env.lat(l: lj).exp;
        if (li == lj || env.merger().latGT(p0: li, p1: lj)) {
          // Recurse into body of each branch.
          if (!isSingleCond) {
            scf::IfOp ifOp = genIf(env, builder&: rewriter, curr, p: lj);
            genStmt(env, rewriter, exp: ej, curr: curr + 1);
            endIf(env, builder&: rewriter, ifOp, redInput, cntInput, insInput, validIns);
          } else {
            genStmt(env, rewriter, exp: ej, curr: curr + 1);
          }
        }
      }

      // End a loop.
      needsUniv = endLoop(env, rewriter, loop, li: curr, needsUniv, isSingleCond);
    }
  }

  // End a loop sequence.
  endLoopSeq(env, builder&: rewriter, exp, at: curr);
  assert(curr == env.getCurrentDepth());
}
| 1355 | |
| 1356 | /// Converts the result computed by the sparse kernel into the required form. |
| 1357 | static void genResult(CodegenEnv &env, RewriterBase &rewriter) { |
| 1358 | linalg::GenericOp op = env.op(); |
| 1359 | OpOperand *lhs = op.getDpsInitOperand(i: 0); |
| 1360 | Value tensor = lhs->get(); |
| 1361 | Type resType = tensor.getType(); |
| 1362 | if (getSparseTensorEncoding(type: resType)) { |
| 1363 | // The sparse tensor rematerializes from the original sparse tensor's |
| 1364 | // underlying sparse storage format. For an insertion chain, the |
| 1365 | // tensor materializes from the chain with 'hasInserts' enabled. |
| 1366 | bool hasInserts = false; |
| 1367 | if (Value chain = env.getInsertionChain()) { |
| 1368 | hasInserts = true; |
| 1369 | tensor = chain; |
| 1370 | } |
| 1371 | rewriter.replaceOpWithNewOp<LoadOp>(op, args&: resType, args&: tensor, args&: hasInserts); |
| 1372 | } else { |
| 1373 | // To rematerialize an non-annotated tensor, simply load it |
| 1374 | // from the bufferized value. |
| 1375 | Value val = env.emitter().getValBuffer()[env.merger().getOutTensorID()]; |
| 1376 | rewriter.replaceOpWithNewOp<bufferization::ToTensorOp>(op, args&: resType, args&: val); |
| 1377 | } |
| 1378 | } |
| 1379 | |
| 1380 | //===----------------------------------------------------------------------===// |
| 1381 | // Sparsifier rewriting methods. |
| 1382 | //===----------------------------------------------------------------------===// |
| 1383 | |
| 1384 | namespace { |
| 1385 | |
/// Sparse rewriting rule for generic Linalg operation.
| 1387 | struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> { |
| 1388 | public: |
| 1389 | GenericOpSparsifier(MLIRContext *context, SparsificationOptions o) |
| 1390 | : OpRewritePattern<linalg::GenericOp>(context), options(o) {} |
| 1391 | |
| 1392 | LogicalResult matchAndRewrite(linalg::GenericOp op, |
| 1393 | PatternRewriter &rewriter) const override { |
| 1394 | // Only accept single output operations with pure tensor semantics. |
| 1395 | if (op.getNumDpsInits() != 1 || !op.hasPureTensorSemantics()) |
| 1396 | return failure(); |
| 1397 | |
| 1398 | // Only accept trivial affine indices. |
| 1399 | if (hasNonTrivialAffineOnSparseOut(op)) |
| 1400 | return failure(); |
| 1401 | |
| 1402 | // Only accept scheduled loops. |
| 1403 | if (!op->hasAttr(name: "sorted" )) { |
| 1404 | return rewriter.notifyMatchFailure( |
| 1405 | arg&: op, msg: "Loops not yet scheduled, try run --sparse-reinterpret-map " |
| 1406 | "before sparsification." ); |
| 1407 | } |
| 1408 | |
| 1409 | // Must have been demapped as well if the generic op is sorted. |
| 1410 | assert(!hasAnyNonIdentityOperandsOrResults(op)); |
| 1411 | |
| 1412 | // Sets up a code generation environment. |
| 1413 | const unsigned numTensors = op->getNumOperands(); |
| 1414 | const unsigned numLoops = op.getNumLoops(); |
| 1415 | bool needIdxRed = getNumNonTrivialIdxExpOnSparseLvls(op) != 0; |
| 1416 | // If we have indexing map like (d0) -> (0, d0), there might be more |
| 1417 | // levels then loops because of the constant index, that means we can not |
| 1418 | // use numLoops as the upper bound for ranks of all tensors. |
| 1419 | // TODO: Constant indices are currently not support on sparse tensor, but |
| 1420 | // are allowed in non-annotated dense tensor. Support it, it would be |
| 1421 | // required for sparse tensor slice rank reducing too. |
| 1422 | Level maxLvlRank = 0; |
| 1423 | for (auto operand : op.getOperands()) { |
| 1424 | if (auto rtp = dyn_cast<RankedTensorType>(Val: operand.getType())) { |
| 1425 | maxLvlRank = std::max(a: maxLvlRank, b: SparseTensorType(rtp).getLvlRank()); |
| 1426 | } |
| 1427 | } |
| 1428 | |
| 1429 | // Detects sparse annotations and translates the per-level sparsity |
| 1430 | // information for all tensors to loop indices in the kernel. |
| 1431 | CodegenEnv env(op, options, numTensors, numLoops, maxLvlRank); |
| 1432 | if (!findSparseAnnotations(env, idxReducBased: needIdxRed)) |
| 1433 | return failure(); |
| 1434 | |
| 1435 | // Only standard reduction operations (add, sub, or, xor) that can be |
| 1436 | // sparsified by merely reducing the stored values are admissible. More |
| 1437 | // elaborate reduction operations (such as mul, and, min, max) would need |
| 1438 | // to know whether implicit zeros occur as well. They can still be |
| 1439 | // implemented with a custom reduction operation, accepted here as well. |
| 1440 | if (op.getNumReductionLoops() > 0) { |
| 1441 | Operation *yield = op.getRegion().front().getTerminator(); |
| 1442 | assert(isa<linalg::YieldOp>(yield)); |
| 1443 | Operation *redop = yield->getOperand(idx: 0).getDefiningOp(); |
| 1444 | if (!isa<arith::AddFOp>(Val: redop) && !isa<complex::AddOp>(Val: redop) && |
| 1445 | !isa<arith::AddIOp>(Val: redop) && !isa<arith::SubFOp>(Val: redop) && |
| 1446 | !isa<complex::SubOp>(Val: redop) && !isa<arith::SubIOp>(Val: redop) && |
| 1447 | !isa<arith::OrIOp>(Val: redop) && !isa<arith::XOrIOp>(Val: redop) && |
| 1448 | !isa<ReduceOp>(Val: redop)) { |
| 1449 | return failure(); |
| 1450 | } |
| 1451 | } |
| 1452 | |
| 1453 | // Constructs the tensor expressions tree from `op`, returns failure if the |
| 1454 | // tree can not be built or the tensor expression is inadmissible. |
| 1455 | if (failed(Result: env.initTensorExp())) |
| 1456 | return failure(); |
| 1457 | |
| 1458 | // Recursively generates code if admissible. |
| 1459 | env.startEmit(emitStrategy: options.sparseEmitStrategy); |
| 1460 | genBuffers(env, builder&: rewriter); |
| 1461 | // TODO: Constant affine expression should be handled differently when using |
| 1462 | // slice-based codegen, it does not matter now because we already reject the |
| 1463 | // constant expression at an earlier stage. |
| 1464 | genInitConstantDenseAddress(env, rewriter); |
| 1465 | genStmt(env, rewriter, exp: env.getExprId(), curr: 0); |
| 1466 | genResult(env, rewriter); |
| 1467 | return success(); |
| 1468 | } |
| 1469 | |
| 1470 | private: |
| 1471 | /// Options to control sparse code generation. |
| 1472 | SparsificationOptions options; |
| 1473 | }; |
| 1474 | |
| 1475 | } // namespace |
| 1476 | |
| 1477 | /// Populates the given patterns list with rewriting rules required for |
| 1478 | /// the sparsification of linear algebra operations. |
| 1479 | void mlir::populateSparsificationPatterns( |
| 1480 | RewritePatternSet &patterns, const SparsificationOptions &options) { |
| 1481 | patterns.add<GenericOpSparsifier>(arg: patterns.getContext(), args: options); |
| 1482 | } |
| 1483 | |