1 | //===- CodegenEnv.cpp - Code generation environment class ----------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "CodegenEnv.h" |
10 | |
11 | #include "mlir/Dialect/Bufferization/IR/Bufferization.h" |
12 | #include "mlir/Dialect/Linalg/Utils/Utils.h" |
13 | #include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h" |
14 | #include "mlir/Dialect/Tensor/IR/Tensor.h" |
15 | |
16 | #include <optional> |
17 | |
18 | using namespace mlir; |
19 | using namespace mlir::sparse_tensor; |
20 | |
21 | //===----------------------------------------------------------------------===// |
22 | // Code generation environment helper functions |
23 | //===----------------------------------------------------------------------===// |
24 | |
25 | /// Returns true if tensor materializes uninitialized into the computation. |
26 | static bool isMaterializing(Value val) { |
27 | return val.getDefiningOp<tensor::EmptyOp>() || |
28 | val.getDefiningOp<bufferization::AllocTensorOp>(); |
29 | } |
30 | |
31 | /// Sorts the dependent loops such that it is ordered in the same sequence in |
32 | /// which loops will be generated. |
33 | static void sortDependentLoops(std::vector<LoopCoeffPair> &target) { |
34 | llvm::sort(C&: target, Comp: [](const LoopCoeffPair &l, const LoopCoeffPair &r) { |
35 | assert(std::addressof(l) == std::addressof(r) || l != r); |
36 | return l.first < r.first; |
37 | }); |
38 | } |
39 | //===----------------------------------------------------------------------===// |
40 | // Code generation environment constructor and general methods |
41 | //===----------------------------------------------------------------------===// |
42 | |
43 | CodegenEnv::CodegenEnv(linalg::GenericOp linop, SparsificationOptions opts, |
44 | unsigned numTensors, unsigned numLoops, unsigned maxRank) |
45 | : linalgOp(linop), sparseOptions(opts), |
46 | latticeMerger(numTensors, numLoops, maxRank), loopEmitter(), |
47 | sparseOut(nullptr), outerParNest(-1u), insChain(), expValues(), |
48 | expFilled(), expAdded(), expCount(), redVal(), redExp(detail::kInvalidId), |
49 | redCustom(detail::kInvalidId), redValidLexInsert() {} |
50 | |
51 | LogicalResult CodegenEnv::initTensorExp() { |
52 | // Builds the tensor expression for the Linalg operation in SSA form. |
53 | std::optional<ExprId> optExp = latticeMerger.buildTensorExpFromLinalg(op()); |
54 | if (!optExp || !isAdmissibleTensorExp(e: *optExp)) |
55 | return failure(); |
56 | |
57 | tensorExp = *optExp; |
58 | return success(); |
59 | } |
60 | |
61 | void CodegenEnv::startEmit(SparseEmitStrategy emitStrategy) { |
62 | assert(insChain == nullptr && "must only start emitting once" ); |
63 | if (sparseOut) { |
64 | insChain = sparseOut->get(); |
65 | latticeMerger.setHasSparseOut(true); |
66 | } |
67 | |
68 | // Sort the related loop array such that they are in the same order as they |
69 | // appears on the topoOrder. |
70 | // TODO: since we only handle affine addition for slice based codegen, and |
71 | // addition is assoicative, the order how we evaluate the expression does |
72 | // not matter. However, to support multiplication, the order of the loop |
73 | // index should match the evaluation order to the affine expression AST. |
74 | |
75 | // Initialize loop emitter. |
76 | SmallVector<Value> tensors; // input tensors passed to loop emitter |
77 | for (OpOperand &t : linalgOp->getOpOperands()) { |
78 | tensors.push_back(t.get()); |
79 | const TensorId tid = makeTensorId(t.getOperandNumber()); |
80 | const Level lvlRank = linalgOp.getMatchingIndexingMap(&t).getNumResults(); |
81 | const auto enc = getSparseTensorEncoding(t.get().getType()); |
82 | (void)enc; |
83 | assert(!enc || lvlRank == enc.getLvlRank()); |
84 | for (Level lvl = 0; lvl < lvlRank; lvl++) |
85 | sortDependentLoops(latticeMerger.getDependentLoops(tid, lvl)); |
86 | } |
87 | loopEmitter.initialize( |
88 | tensors, |
89 | StringAttr::get(linalgOp.getContext(), |
90 | linalg::GenericOp::getOperationName()), |
91 | /*hasOutput=*/true, |
92 | /*isSparseOut=*/sparseOut != nullptr, /*numLoops=*/getLoopNum(), |
93 | // TODO: compute the map and pass it to loop emitter directly instead of |
94 | // passing in a callback. |
95 | /*dependentLvlGetter=*/ |
96 | [this](TensorId t, Level lvl) -> std::vector<LoopCoeffPair> { |
97 | return merger().getDependentLoops(t, lvl); |
98 | }, |
99 | emitStrategy); |
100 | } |
101 | |
102 | std::optional<Operation *> CodegenEnv::genLoopBoundary( |
103 | function_ref<std::optional<Operation *>(MutableArrayRef<Value> parameters)> |
104 | callback) { |
105 | SmallVector<Value> params; |
106 | if (isReduc()) { |
107 | params.push_back(Elt: redVal); |
108 | if (isValidLexInsert()) |
109 | params.push_back(Elt: redValidLexInsert); |
110 | } else { |
111 | assert(!isValidLexInsert()); |
112 | } |
113 | if (isExpand()) |
114 | params.push_back(Elt: expCount); |
115 | if (insChain != nullptr) |
116 | params.push_back(Elt: insChain); |
117 | auto r = callback(params); // may update parameters |
118 | unsigned i = 0; |
119 | if (isReduc()) { |
120 | updateReduc(val: params[i++]); |
121 | if (isValidLexInsert()) |
122 | updateValidLexInsert(val: params[i++]); |
123 | } |
124 | if (isExpand()) |
125 | updateExpandCount(count: params[i++]); |
126 | if (insChain != nullptr) |
127 | updateInsertionChain(chain: params[i]); |
128 | return r; |
129 | } |
130 | |
131 | //===----------------------------------------------------------------------===// |
132 | // Code generation environment verify functions. |
133 | //===----------------------------------------------------------------------===// |
134 | |
135 | bool CodegenEnv::isAdmissibleTensorExp(ExprId exp) { |
136 | // We reject any expression that makes a reduction from `-outTensor`, as those |
137 | // expressions create a dependency between the current iteration (i) and the |
138 | // previous iteration (i-1). It would require iterating over the whole |
139 | // coordinate space, which prevent exploiting sparsity for faster code. |
140 | for (utils::IteratorType it : linalgOp.getIteratorTypesArray()) { |
141 | if (it == utils::IteratorType::reduction) { |
142 | if (latticeMerger.hasNegateOnOut(exp)) |
143 | return false; |
144 | break; |
145 | } |
146 | } |
147 | |
148 | OpOperand *lhs = linalgOp.getDpsInitOperand(0); |
149 | const TensorId tensor = makeTensorId(t: lhs->getOperandNumber()); |
150 | // An non-annotated output tensor is assumed dense, and becomes a random |
151 | // access n-dim memref. Admissible since insertions cannot occur. |
152 | if (getSparseTensorType(val: lhs->get()).isAllDense()) |
153 | return true; |
154 | |
155 | // A tensor expression with a sparse output tensor that changes its values |
156 | // but not its nonzero structure, an operation called "simply dynamic" in |
157 | // [Bik96,Ch9], is also admissible without special env. |
158 | if (latticeMerger.isSingleCondition(t: tensor, e: exp)) |
159 | return true; |
160 | |
161 | // Accept "truly dynamic" if the output tensor materializes uninitialized |
162 | // into the computation and insertions occur in lexicographic index order. |
163 | sparseOut = lhs; |
164 | |
165 | // Find the outermost parallel nest to determine whether compress/expand is |
166 | // needed. |
167 | outerParNest = 0; |
168 | const auto iteratorTypes = linalgOp.getIteratorTypesArray(); |
169 | for (unsigned i = 0, e = getLoopNum(); i < e; i++) { |
170 | if (linalg::isReductionIterator(iteratorTypes[i])) |
171 | break; // terminate at first reduction |
172 | outerParNest++; |
173 | } |
174 | |
175 | // Inadmissible kernel should have already been rejected by the previous |
176 | // path during loop scheduling. |
177 | assert(static_cast<int64_t>(outerParNest) >= |
178 | linalgOp.getRank(linalgOp.getDpsInitOperand(0)) - 1); |
179 | return isMaterializing(val: lhs->get()); |
180 | } |
181 | |
182 | //===----------------------------------------------------------------------===// |
183 | // Code generation environment topological sort methods |
184 | //===----------------------------------------------------------------------===// |
185 | |
186 | Value CodegenEnv::getLoopVar(LoopId i) const { |
187 | return loopEmitter.getLoopIV(i); |
188 | } |
189 | |
190 | //===----------------------------------------------------------------------===// |
191 | // Code generation environment sparse tensor output and expansion methods |
192 | //===----------------------------------------------------------------------===// |
193 | |
194 | void CodegenEnv::updateInsertionChain(Value chain) { |
195 | assert(sparseOut != nullptr && insChain != nullptr); |
196 | insChain = chain; |
197 | } |
198 | |
199 | bool CodegenEnv::atExpandLevel(OpOperand *o, unsigned rank, LoopId n) const { |
200 | return sparseOut == o && outerParNest == static_cast<LoopId>(rank - 1) && |
201 | outerParNest == n; |
202 | } |
203 | |
204 | void CodegenEnv::startExpand(Value values, Value filled, Value added, |
205 | Value count) { |
206 | assert(sparseOut != nullptr && expValues == nullptr); |
207 | expValues = values; |
208 | expFilled = filled; |
209 | expAdded = added; |
210 | expCount = count; |
211 | } |
212 | |
213 | void CodegenEnv::updateExpandCount(Value count) { |
214 | assert(sparseOut != nullptr && expValues != nullptr); |
215 | expCount = count; |
216 | } |
217 | |
218 | void CodegenEnv::endExpand() { |
219 | assert(sparseOut != nullptr && expValues != nullptr); |
220 | expValues = expFilled = expAdded = expCount = Value(); |
221 | } |
222 | |
223 | //===----------------------------------------------------------------------===// |
224 | // Code generation environment reduction methods |
225 | //===----------------------------------------------------------------------===// |
226 | |
227 | void CodegenEnv::startReduc(ExprId exp, Value val) { |
228 | assert(!isReduc() && exp != detail::kInvalidId && val); |
229 | redExp = exp; |
230 | redVal = val; |
231 | latticeMerger.setExprValue(e: exp, v: val); |
232 | } |
233 | |
234 | void CodegenEnv::updateReduc(Value val) { |
235 | assert(isReduc() && val); |
236 | redVal = val; |
237 | latticeMerger.clearExprValue(e: redExp); |
238 | latticeMerger.setExprValue(e: redExp, v: val); |
239 | } |
240 | |
241 | Value CodegenEnv::endReduc() { |
242 | assert(isReduc()); |
243 | Value val = redVal; |
244 | redVal = val; |
245 | latticeMerger.clearExprValue(e: redExp); |
246 | redExp = detail::kInvalidId; |
247 | return val; |
248 | } |
249 | |
250 | void CodegenEnv::startValidLexInsert(Value val) { |
251 | assert(!isValidLexInsert() && isReduc() && val); |
252 | redValidLexInsert = val; |
253 | } |
254 | |
255 | void CodegenEnv::updateValidLexInsert(Value val) { |
256 | assert(redValidLexInsert && isReduc() && val); |
257 | redValidLexInsert = val; |
258 | } |
259 | |
260 | void CodegenEnv::endValidLexInsert() { |
261 | assert(isValidLexInsert() && !isReduc()); |
262 | redValidLexInsert = Value(); |
263 | } |
264 | |
265 | void CodegenEnv::startCustomReduc(ExprId exp) { |
266 | assert(!isCustomReduc() && exp != detail::kInvalidId); |
267 | redCustom = exp; |
268 | } |
269 | |
270 | Value CodegenEnv::getCustomRedId() const { |
271 | assert(isCustomReduc()); |
272 | return dyn_cast<sparse_tensor::ReduceOp>(exp(redCustom).op).getIdentity(); |
273 | } |
274 | |
275 | void CodegenEnv::endCustomReduc() { |
276 | assert(isCustomReduc()); |
277 | redCustom = detail::kInvalidId; |
278 | } |
279 | |