//===- CodegenEnv.cpp - Code generation environment class ----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "CodegenEnv.h"

#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"

#include <optional>

using namespace mlir;
using namespace mlir::sparse_tensor;

//===----------------------------------------------------------------------===//
// Code generation environment helper functions
//===----------------------------------------------------------------------===//

/// Returns true if the tensor materializes uninitialized into the computation.
static bool isMaterializing(Value val) {
  return val.getDefiningOp<tensor::EmptyOp>() ||
         val.getDefiningOp<bufferization::AllocTensorOp>();
}

/// Sorts the dependent loops such that they are ordered in the same sequence
/// in which the loops will be generated.
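/// Each LoopCoeffPair pairs a loop identifier with the coefficient it carries
/// in the affine index expression, so comparing on `first` orders the entries
/// by loop.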
static void sortDependentLoops(std::vector<LoopCoeffPair> &target) {
  llvm::sort(target, [](const LoopCoeffPair &l, const LoopCoeffPair &r) {
    assert(std::addressof(l) == std::addressof(r) || l != r);
    return l.first < r.first;
  });
}

//===----------------------------------------------------------------------===//
// Code generation environment constructor and general methods
//===----------------------------------------------------------------------===//

CodegenEnv::CodegenEnv(linalg::GenericOp linop, SparsificationOptions opts,
                       unsigned numTensors, unsigned numLoops, unsigned maxRank)
    : linalgOp(linop), sparseOptions(opts),
      latticeMerger(numTensors, numLoops, maxRank), loopEmitter(),
      sparseOut(nullptr), outerParNest(-1u), insChain(), expValues(),
      expFilled(), expAdded(), expCount(), redVal(), redExp(detail::kInvalidId),
      redCustom(detail::kInvalidId), redValidLexInsert() {}

LogicalResult CodegenEnv::initTensorExp() {
  // Builds the tensor expression for the Linalg operation in SSA form.
  std::optional<ExprId> optExp = latticeMerger.buildTensorExpFromLinalg(op());
  if (!optExp || !isAdmissibleTensorExp(*optExp))
    return failure();

  tensorExp = *optExp;
  return success();
}

void CodegenEnv::startEmit(SparseEmitStrategy emitStrategy) {
  assert(insChain == nullptr && "must only start emitting once");
  if (sparseOut) {
    insChain = sparseOut->get();
    latticeMerger.setHasSparseOut(true);
  }

  // Sort the related loop arrays such that they are in the same order as
  // they appear in topoOrder.
  // TODO: since we only handle affine addition for slice-based codegen, and
  // addition is associative, the order in which we evaluate the expression
  // does not matter. However, to support multiplication, the order of the
  // loop indices should match the evaluation order of the affine expression
  // AST.

  // Initialize loop emitter.
  SmallVector<Value> tensors; // input tensors passed to loop emitter
  for (OpOperand &t : linalgOp->getOpOperands()) {
    tensors.push_back(t.get());
    const TensorId tid = makeTensorId(t.getOperandNumber());
    const Level lvlRank = linalgOp.getMatchingIndexingMap(&t).getNumResults();
    const auto enc = getSparseTensorEncoding(t.get().getType());
    (void)enc;
    assert(!enc || lvlRank == enc.getLvlRank());
    for (Level lvl = 0; lvl < lvlRank; lvl++)
      sortDependentLoops(latticeMerger.getDependentLoops(tid, lvl));
  }
  loopEmitter.initialize(
      tensors,
      StringAttr::get(linalgOp.getContext(),
                      linalg::GenericOp::getOperationName()),
      /*hasOutput=*/true,
      /*isSparseOut=*/sparseOut != nullptr, /*numLoops=*/getLoopNum(),
      // TODO: compute the map and pass it to the loop emitter directly
      // instead of passing in a callback.
      /*dependentLvlGetter=*/
      [this](TensorId t, Level lvl) -> std::vector<LoopCoeffPair> {
        return merger().getDependentLoops(t, lvl);
      },
      emitStrategy);
}

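// Generates a loop boundary by packing all loop-carried SSA values (the
// current reduction value, the lex-insert flag, the expansion count, and the
// insertion chain, when present), handing them to the callback, and then
// writing back the possibly updated values in the same order.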
std::optional<Operation *> CodegenEnv::genLoopBoundary(
    function_ref<std::optional<Operation *>(MutableArrayRef<Value> parameters)>
        callback) {
  SmallVector<Value> params;
  if (isReduc()) {
    params.push_back(redVal);
    if (isValidLexInsert())
      params.push_back(redValidLexInsert);
  } else {
    assert(!isValidLexInsert());
  }
  if (isExpand())
    params.push_back(expCount);
  if (insChain != nullptr)
    params.push_back(insChain);
  auto r = callback(params); // may update parameters
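  // Unpack the (possibly updated) loop-carried values in the exact order in
  // which they were pushed above.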
  unsigned i = 0;
  if (isReduc()) {
    updateReduc(params[i++]);
    if (isValidLexInsert())
      updateValidLexInsert(params[i++]);
  }
  if (isExpand())
    updateExpandCount(params[i++]);
  if (insChain != nullptr)
    updateInsertionChain(params[i]);
  return r;
}

//===----------------------------------------------------------------------===//
// Code generation environment verify functions.
//===----------------------------------------------------------------------===//

bool CodegenEnv::isAdmissibleTensorExp(ExprId exp) {
  // We reject any expression that makes a reduction from `-outTensor`, as
  // those expressions create a dependency between the current iteration (i)
  // and the previous iteration (i-1). It would require iterating over the
  // whole coordinate space, which prevents exploiting sparsity for faster
  // code.
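  // For instance, a hypothetical reduction of the form
  //   x(i) = sum_j ( a(i,j) - x(i) )
  // subtracts the output inside the reduction and is rejected below.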
  for (utils::IteratorType it : linalgOp.getIteratorTypesArray()) {
    if (it == utils::IteratorType::reduction) {
      if (latticeMerger.hasNegateOnOut(exp))
        return false;
      break;
    }
  }

  OpOperand *lhs = linalgOp.getDpsInitOperand(0);
  const TensorId tensor = makeTensorId(lhs->getOperandNumber());
  // A non-annotated output tensor is assumed dense, and becomes a random
  // access n-dim memref. Admissible since insertions cannot occur.
  if (getSparseTensorType(lhs->get()).isAllDense())
    return true;

  // A tensor expression with a sparse output tensor that changes its values
  // but not its nonzero structure, an operation called "simply dynamic" in
  // [Bik96,Ch9], is also admissible without a special environment.
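  // (E.g., a kernel that merely scales the stored values of the output in
  // place changes values but not the nonzero structure.)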
  if (latticeMerger.isSingleCondition(tensor, exp))
    return true;

  // Accept "truly dynamic" if the output tensor materializes uninitialized
  // into the computation and insertions occur in lexicographic index order.
  sparseOut = lhs;

  // Find the outermost parallel nest to determine whether compress/expand is
  // needed.
  outerParNest = 0;
  const auto iteratorTypes = linalgOp.getIteratorTypesArray();
  for (unsigned i = 0, e = getLoopNum(); i < e; i++) {
    if (linalg::isReductionIterator(iteratorTypes[i]))
      break; // terminate at first reduction
    outerParNest++;
  }
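  // E.g., with iterator types (parallel, parallel, reduction), the loop above
  // yields outerParNest == 2.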

  // An inadmissible kernel should have already been rejected by the previous
  // pass during loop scheduling.
  assert(static_cast<int64_t>(outerParNest) >=
         linalgOp.getRank(linalgOp.getDpsInitOperand(0)) - 1);
  return isMaterializing(lhs->get());
}

//===----------------------------------------------------------------------===//
// Code generation environment topological sort methods
//===----------------------------------------------------------------------===//

Value CodegenEnv::getLoopVar(LoopId i) const {
  return loopEmitter.getLoopIV(i);
}

//===----------------------------------------------------------------------===//
// Code generation environment sparse tensor output and expansion methods
//===----------------------------------------------------------------------===//

void CodegenEnv::updateInsertionChain(Value chain) {
  assert(sparseOut != nullptr && insChain != nullptr);
  insChain = chain;
}

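// Expansion applies to the sparse output right past its outermost parallel
// nest, i.e., at loop n == outerParNest == rank - 1.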
bool CodegenEnv::atExpandLevel(OpOperand *o, unsigned rank, LoopId n) const {
  return sparseOut == o && outerParNest == static_cast<LoopId>(rank - 1) &&
         outerParNest == n;
}

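// Enters the access-pattern expansion scheme: `values` and `filled` are the
// dense scratch buffers, `added` collects the new coordinates, and `count`
// tracks how many entries have been added.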
void CodegenEnv::startExpand(Value values, Value filled, Value added,
                             Value count) {
  assert(sparseOut != nullptr && expValues == nullptr);
  expValues = values;
  expFilled = filled;
  expAdded = added;
  expCount = count;
}

void CodegenEnv::updateExpandCount(Value count) {
  assert(sparseOut != nullptr && expValues != nullptr);
  expCount = count;
}

void CodegenEnv::endExpand() {
  assert(sparseOut != nullptr && expValues != nullptr);
  expValues = expFilled = expAdded = expCount = Value();
}

//===----------------------------------------------------------------------===//
// Code generation environment reduction methods
//===----------------------------------------------------------------------===//

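// Reduction values are mirrored into the lattice merger (setExprValue /
// clearExprValue below) so that expression evaluation always sees the
// current SSA value of the reduction.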
void CodegenEnv::startReduc(ExprId exp, Value val) {
  assert(!isReduc() && exp != detail::kInvalidId && val);
  redExp = exp;
  redVal = val;
  latticeMerger.setExprValue(exp, val);
}

void CodegenEnv::updateReduc(Value val) {
  assert(isReduc() && val);
  redVal = val;
  latticeMerger.clearExprValue(redExp);
  latticeMerger.setExprValue(redExp, val);
}

Value CodegenEnv::endReduc() {
  assert(isReduc());
  Value val = redVal;
  redVal = Value();
  latticeMerger.clearExprValue(redExp);
  redExp = detail::kInvalidId;
  return val;
}

void CodegenEnv::startValidLexInsert(Value val) {
  assert(!isValidLexInsert() && isReduc() && val);
  redValidLexInsert = val;
}

void CodegenEnv::updateValidLexInsert(Value val) {
  assert(redValidLexInsert && isReduc() && val);
  redValidLexInsert = val;
}

void CodegenEnv::endValidLexInsert() {
  assert(isValidLexInsert() && !isReduc());
  redValidLexInsert = Value();
}

void CodegenEnv::startCustomReduc(ExprId exp) {
  assert(!isCustomReduc() && exp != detail::kInvalidId);
  redCustom = exp;
}

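// The identity value of a custom reduction is taken from the
// sparse_tensor.reduce operation that defines the reduction expression.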
Value CodegenEnv::getCustomRedId() const {
  assert(isCustomReduc());
  return cast<sparse_tensor::ReduceOp>(exp(redCustom).op).getIdentity();
}

void CodegenEnv::endCustomReduc() {
  assert(isCustomReduc());
  redCustom = detail::kInvalidId;
}