1//===- CodegenEnv.cpp - Code generation environment class ----------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "CodegenEnv.h"
10
11#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
12#include "mlir/Dialect/Linalg/Utils/Utils.h"
13#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
14#include "mlir/Dialect/Tensor/IR/Tensor.h"
15
16#include <optional>
17
18using namespace mlir;
19using namespace mlir::sparse_tensor;
20
21//===----------------------------------------------------------------------===//
22// Code generation environment helper functions
23//===----------------------------------------------------------------------===//
24
25/// Returns true if tensor materializes uninitialized into the computation.
26static bool isMaterializing(Value val) {
27 return val.getDefiningOp<tensor::EmptyOp>() ||
28 val.getDefiningOp<bufferization::AllocTensorOp>();
29}
30
31/// Sorts the dependent loops such that it is ordered in the same sequence in
32/// which loops will be generated.
33static void sortDependentLoops(std::vector<LoopCoeffPair> &target) {
34 std::sort(first: target.begin(), last: target.end(),
35 comp: [](const LoopCoeffPair &l, const LoopCoeffPair &r) {
36 assert(std::addressof(l) == std::addressof(r) || l != r);
37 return l.first < r.first;
38 });
39}
40//===----------------------------------------------------------------------===//
41// Code generation environment constructor and general methods
42//===----------------------------------------------------------------------===//
43
/// Constructs a code generation environment for the given linalg.generic op
/// under the given sparsification options. All bookkeeping starts out in the
/// "inactive" state: no sparse output operand, no insertion chain, no
/// expansion in progress, and no (custom) reduction in progress (signaled by
/// the detail::kInvalidId sentinels for redExp/redCustom).
CodegenEnv::CodegenEnv(linalg::GenericOp linop, SparsificationOptions opts,
                       unsigned numTensors, unsigned numLoops, unsigned maxRank)
    : linalgOp(linop), sparseOptions(opts),
      latticeMerger(numTensors, numLoops, maxRank), loopEmitter(),
      sparseOut(nullptr), outerParNest(-1u), insChain(), expValues(),
      expFilled(), expAdded(), expCount(), redVal(), redExp(detail::kInvalidId),
      redCustom(detail::kInvalidId), redValidLexInsert() {}
51
52LogicalResult CodegenEnv::initTensorExp() {
53 // Builds the tensor expression for the Linalg operation in SSA form.
54 std::optional<ExprId> optExp = latticeMerger.buildTensorExpFromLinalg(op());
55 if (!optExp || !isAdmissibleTensorExp(e: *optExp))
56 return failure();
57
58 tensorExp = *optExp;
59 return success();
60}
61
/// Prepares this environment for code emission: hooks up the sparse output
/// insertion chain (if any), sorts per-tensor dependent loops into generation
/// order, and initializes the loop emitter over all operands of the linalg op.
void CodegenEnv::startEmit(SparseEmitStrategy emitStrategy) {
  assert(insChain == nullptr && "must only start emitting once");
  if (sparseOut) {
    // Seed the insertion chain with the sparse output tensor and inform the
    // merger so it can account for the sparse output during lattice
    // optimization.
    insChain = sparseOut->get();
    latticeMerger.setHasSparseOut(true);
  }

  // Sort the related loop array such that they are in the same order as they
  // appear in the topological order.
  // TODO: since we only handle affine addition for slice based codegen, and
  // addition is associative, the order in which we evaluate the expression
  // does not matter. However, to support multiplication, the order of the
  // loop index should match the evaluation order of the affine expression
  // AST.

  // Initialize loop emitter.
  SmallVector<Value> tensors; // input tensors passed to loop emitter
  for (OpOperand &t : linalgOp->getOpOperands()) {
    tensors.push_back(t.get());
    const TensorId tid = makeTensorId(t.getOperandNumber());
    // Number of levels equals the rank of the operand's indexing map.
    const Level lvlRank = linalgOp.getMatchingIndexingMap(&t).getNumResults();
    const auto enc = getSparseTensorEncoding(t.get().getType());
    (void)enc; // only used in the assertion below
    assert(!enc || lvlRank == enc.getLvlRank());
    for (Level lvl = 0; lvl < lvlRank; lvl++)
      sortDependentLoops(latticeMerger.getDependentLoops(tid, lvl));
  }
  loopEmitter.initialize(
      tensors,
      StringAttr::get(linalgOp.getContext(),
                      linalg::GenericOp::getOperationName()),
      /*hasOutput=*/true,
      /*isSparseOut=*/sparseOut != nullptr, /*numLoops=*/getLoopNum(),
      // TODO: compute the map and pass it to loop emitter directly instead of
      // passing in a callback.
      /*dependentLvlGetter=*/
      [this](TensorId t, Level lvl) -> std::vector<LoopCoeffPair> {
        return merger().getDependentLoops(t, lvl);
      },
      emitStrategy);
}
102
103std::optional<Operation *> CodegenEnv::genLoopBoundary(
104 function_ref<std::optional<Operation *>(MutableArrayRef<Value> parameters)>
105 callback) {
106 SmallVector<Value> params;
107 if (isReduc()) {
108 params.push_back(Elt: redVal);
109 if (isValidLexInsert())
110 params.push_back(Elt: redValidLexInsert);
111 } else {
112 assert(!isValidLexInsert());
113 }
114 if (isExpand())
115 params.push_back(Elt: expCount);
116 if (insChain != nullptr)
117 params.push_back(Elt: insChain);
118 auto r = callback(params); // may update parameters
119 unsigned i = 0;
120 if (isReduc()) {
121 updateReduc(val: params[i++]);
122 if (isValidLexInsert())
123 updateValidLexInsert(val: params[i++]);
124 }
125 if (isExpand())
126 updateExpandCount(count: params[i++]);
127 if (insChain != nullptr)
128 updateInsertionChain(chain: params[i]);
129 return r;
130}
131
132//===----------------------------------------------------------------------===//
133// Code generation environment verify functions.
134//===----------------------------------------------------------------------===//
135
136bool CodegenEnv::isAdmissibleTensorExp(ExprId exp) {
137 // We reject any expression that makes a reduction from `-outTensor`, as those
138 // expressions create a dependency between the current iteration (i) and the
139 // previous iteration (i-1). It would require iterating over the whole
140 // coordinate space, which prevent exploiting sparsity for faster code.
141 for (utils::IteratorType it : linalgOp.getIteratorTypesArray()) {
142 if (it == utils::IteratorType::reduction) {
143 if (latticeMerger.hasNegateOnOut(exp))
144 return false;
145 break;
146 }
147 }
148
149 OpOperand *lhs = linalgOp.getDpsInitOperand(0);
150 const TensorId tensor = makeTensorId(t: lhs->getOperandNumber());
151 // An non-annotated output tensor is assumed dense, and becomes a random
152 // access n-dim memref. Admissible since insertions cannot occur.
153 if (getSparseTensorType(val: lhs->get()).isAllDense())
154 return true;
155
156 // A tensor expression with a sparse output tensor that changes its values
157 // but not its nonzero structure, an operation called "simply dynamic" in
158 // [Bik96,Ch9], is also admissible without special env.
159 if (latticeMerger.isSingleCondition(t: tensor, e: exp))
160 return true;
161
162 // Accept "truly dynamic" if the output tensor materializes uninitialized
163 // into the computation and insertions occur in lexicographic index order.
164 sparseOut = lhs;
165
166 // Find the outermost parallel nest to determine whether compress/expand is
167 // needed.
168 outerParNest = 0;
169 const auto iteratorTypes = linalgOp.getIteratorTypesArray();
170 for (unsigned i = 0, e = getLoopNum(); i < e; i++) {
171 if (linalg::isReductionIterator(iteratorTypes[i]))
172 break; // terminate at first reduction
173 outerParNest++;
174 }
175
176 // Inadmissible kernel should have already been rejected by the previous
177 // path during loop scheduling.
178 assert(static_cast<int64_t>(outerParNest) >=
179 linalgOp.getRank(linalgOp.getDpsInitOperand(0)) - 1);
180 return isMaterializing(val: lhs->get());
181}
182
183//===----------------------------------------------------------------------===//
184// Code generation environment topological sort methods
185//===----------------------------------------------------------------------===//
186
/// Returns the induction-variable SSA value of loop `i`, as tracked by the
/// loop emitter.
Value CodegenEnv::getLoopVar(LoopId i) const {
  return loopEmitter.getLoopIV(i);
}
190
191//===----------------------------------------------------------------------===//
192// Code generation environment sparse tensor output and expansion methods
193//===----------------------------------------------------------------------===//
194
195void CodegenEnv::updateInsertionChain(Value chain) {
196 assert(sparseOut != nullptr && insChain != nullptr);
197 insChain = chain;
198}
199
200bool CodegenEnv::atExpandLevel(OpOperand *o, unsigned rank, LoopId n) const {
201 return sparseOut == o && outerParNest == static_cast<LoopId>(rank - 1) &&
202 outerParNest == n;
203}
204
205void CodegenEnv::startExpand(Value values, Value filled, Value added,
206 Value count) {
207 assert(sparseOut != nullptr && expValues == nullptr);
208 expValues = values;
209 expFilled = filled;
210 expAdded = added;
211 expCount = count;
212}
213
214void CodegenEnv::updateExpandCount(Value count) {
215 assert(sparseOut != nullptr && expValues != nullptr);
216 expCount = count;
217}
218
219void CodegenEnv::endExpand() {
220 assert(sparseOut != nullptr && expValues != nullptr);
221 expValues = expFilled = expAdded = expCount = Value();
222}
223
224//===----------------------------------------------------------------------===//
225// Code generation environment reduction methods
226//===----------------------------------------------------------------------===//
227
228void CodegenEnv::startReduc(ExprId exp, Value val) {
229 assert(!isReduc() && exp != detail::kInvalidId && val);
230 redExp = exp;
231 redVal = val;
232 latticeMerger.setExprValue(e: exp, v: val);
233}
234
235void CodegenEnv::updateReduc(Value val) {
236 assert(isReduc() && val);
237 redVal = val;
238 latticeMerger.clearExprValue(e: redExp);
239 latticeMerger.setExprValue(e: redExp, v: val);
240}
241
242Value CodegenEnv::endReduc() {
243 assert(isReduc());
244 Value val = redVal;
245 redVal = val;
246 latticeMerger.clearExprValue(e: redExp);
247 redExp = detail::kInvalidId;
248 return val;
249}
250
251void CodegenEnv::startValidLexInsert(Value val) {
252 assert(!isValidLexInsert() && isReduc() && val);
253 redValidLexInsert = val;
254}
255
256void CodegenEnv::updateValidLexInsert(Value val) {
257 assert(redValidLexInsert && isReduc() && val);
258 redValidLexInsert = val;
259}
260
261void CodegenEnv::endValidLexInsert() {
262 assert(isValidLexInsert() && !isReduc());
263 redValidLexInsert = Value();
264}
265
266void CodegenEnv::startCustomReduc(ExprId exp) {
267 assert(!isCustomReduc() && exp != detail::kInvalidId);
268 redCustom = exp;
269}
270
271Value CodegenEnv::getCustomRedId() const {
272 assert(isCustomReduc());
273 return dyn_cast<sparse_tensor::ReduceOp>(exp(redCustom).op).getIdentity();
274}
275
276void CodegenEnv::endCustomReduc() {
277 assert(isCustomReduc());
278 redCustom = detail::kInvalidId;
279}
280

source code of mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenEnv.cpp