//===- CodegenEnv.cpp - Code generation environment class ----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "CodegenEnv.h"

#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"

#include <optional>

using namespace mlir;
using namespace mlir::sparse_tensor;

//===----------------------------------------------------------------------===//
// Code generation environment helper functions
//===----------------------------------------------------------------------===//

/// Returns true if tensor materializes uninitialized into the computation.
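/// For example (illustrative only; the tensor type and encoding are made up),
/// a value such as
///   %t = tensor.empty() : tensor<8x8xf64, #CSR>
/// materializes uninitialized into the computation, whereas a tensor passed in
/// as a function argument or produced by an earlier kernel does not.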
static bool isMaterializing(Value val) {
  return val.getDefiningOp<tensor::EmptyOp>() ||
         val.getDefiningOp<bufferization::AllocTensorOp>();
}

/// Sorts the dependent loops so that they are ordered in the same sequence in
/// which the loops will be generated.
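/// For instance (illustrative), the dependent loop pairs {(l2, c2), (l0, c0)}
/// would be reordered to {(l0, c0), (l2, c2)}, i.e., sorted by ascending loop
/// id.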
static void sortDependentLoops(std::vector<LoopCoeffPair> &target) {
  std::sort(target.begin(), target.end(),
            [](const LoopCoeffPair &l, const LoopCoeffPair &r) {
              assert(l.first != r.first && "a loop should not appear twice");
              return l.first < r.first;
            });
}
//===----------------------------------------------------------------------===//
// Code generation environment constructor and general methods
//===----------------------------------------------------------------------===//

CodegenEnv::CodegenEnv(linalg::GenericOp linop, SparsificationOptions opts,
                       unsigned numTensors, unsigned numLoops, unsigned maxRank)
    : linalgOp(linop), sparseOptions(opts),
      latticeMerger(numTensors, numLoops, maxRank), loopEmitter(),
      sparseOut(nullptr), outerParNest(-1u), insChain(), expValues(),
      expFilled(), expAdded(), expCount(), redVal(), redExp(detail::kInvalidId),
      redCustom(detail::kInvalidId), redValidLexInsert() {}

LogicalResult CodegenEnv::initTensorExp() {
  // Builds the tensor expression for the Linalg operation in SSA form.
  std::optional<ExprId> optExp = latticeMerger.buildTensorExpFromLinalg(op());
  if (!optExp || !isAdmissibleTensorExp(*optExp))
    return failure();

  tensorExp = *optExp;
  return success();
}

void CodegenEnv::startEmit(SparseEmitStrategy emitStrategy) {
  assert(insChain == nullptr && "must only start emitting once");
  if (sparseOut) {
    insChain = sparseOut->get();
    latticeMerger.setHasSparseOut(true);
  }

  // Sort the related loop arrays so that they are in the same order as they
  // appear in the topological order.
  // TODO: since we only handle affine addition for slice-based codegen, and
  // addition is associative, the order in which we evaluate the expression
  // does not matter. However, to support multiplication, the order of the
  // loop indices should match the evaluation order of the affine expression
  // AST.

  // Initialize loop emitter.
  SmallVector<Value> tensors; // input tensors passed to loop emitter
  for (OpOperand &t : linalgOp->getOpOperands()) {
    tensors.push_back(t.get());
    const TensorId tid = makeTensorId(t.getOperandNumber());
    const Level lvlRank = linalgOp.getMatchingIndexingMap(&t).getNumResults();
    const auto enc = getSparseTensorEncoding(t.get().getType());
    (void)enc;
    assert(!enc || lvlRank == enc.getLvlRank());
    for (Level lvl = 0; lvl < lvlRank; lvl++)
      sortDependentLoops(latticeMerger.getDependentLoops(tid, lvl));
  }
  loopEmitter.initialize(
      tensors,
      StringAttr::get(linalgOp.getContext(),
                      linalg::GenericOp::getOperationName()),
      /*hasOutput=*/true,
      /*isSparseOut=*/sparseOut != nullptr, /*numLoops=*/getLoopNum(),
      // TODO: compute the map and pass it to the loop emitter directly instead
      // of passing in a callback.
      /*dependentLvlGetter=*/
      [this](TensorId t, Level lvl) -> std::vector<LoopCoeffPair> {
        return merger().getDependentLoops(t, lvl);
      },
      emitStrategy);
}

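// genLoopBoundary() packs the loop-carried SSA values tracked by the
// environment (reduction value, optional valid-lex-insert flag, optional
// expansion count, optional insertion chain, in that order) into a parameter
// array, hands them to the callback that builds the loop boundary, and then
// reads the possibly updated values back. A rough usage sketch (illustrative
// only; createForLoopWithIterArgs is a hypothetical helper):
//
//   env.genLoopBoundary([&](MutableArrayRef<Value> params)
//                           -> std::optional<Operation *> {
//     scf::ForOp forOp = createForLoopWithIterArgs(builder, loc, params);
//     for (unsigned i = 0, e = params.size(); i < e; i++)
//       params[i] = forOp.getRegionIterArgs()[i]; // switch to loop arguments
//     return forOp.getOperation();
//   });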
std::optional<Operation *> CodegenEnv::genLoopBoundary(
    function_ref<std::optional<Operation *>(MutableArrayRef<Value> parameters)>
        callback) {
  SmallVector<Value> params;
  if (isReduc()) {
    params.push_back(redVal);
    if (isValidLexInsert())
      params.push_back(redValidLexInsert);
  } else {
    assert(!isValidLexInsert());
  }
  if (isExpand())
    params.push_back(expCount);
  if (insChain != nullptr)
    params.push_back(insChain);
  auto r = callback(params); // may update parameters
  unsigned i = 0;
  if (isReduc()) {
    updateReduc(params[i++]);
    if (isValidLexInsert())
      updateValidLexInsert(params[i++]);
  }
  if (isExpand())
    updateExpandCount(params[i++]);
  if (insChain != nullptr)
    updateInsertionChain(params[i]);
  return r;
}

//===----------------------------------------------------------------------===//
// Code generation environment verify functions.
//===----------------------------------------------------------------------===//

bool CodegenEnv::isAdmissibleTensorExp(ExprId exp) {
  // We reject any expression that makes a reduction from `-outTensor`, as those
  // expressions create a dependency between the current iteration (i) and the
  // previous iteration (i-1). It would require iterating over the whole
  // coordinate space, which prevents exploiting sparsity for faster code.
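  // For example (illustrative only), a reduction kernel whose body computes
  // something of the form "out = a(i) - out", i.e. an expression that
  // subtracts (negates) the output tensor itself, falls into this category.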
  for (utils::IteratorType it : linalgOp.getIteratorTypesArray()) {
    if (it == utils::IteratorType::reduction) {
      if (latticeMerger.hasNegateOnOut(exp))
        return false;
      break;
    }
  }

  OpOperand *lhs = linalgOp.getDpsInitOperand(0);
  const TensorId tensor = makeTensorId(lhs->getOperandNumber());
  // A non-annotated output tensor is assumed dense, and becomes a random
  // access n-dim memref. Admissible since insertions cannot occur.
  if (getSparseTensorType(lhs->get()).isAllDense())
    return true;

  // A tensor expression with a sparse output tensor that changes its values
  // but not its nonzero structure, an operation called "simply dynamic" in
  // [Bik96,Ch9], is also admissible without a special environment.
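  // For example (illustrative only), scaling every stored value in place, as
  // in x(i) = x(i) * 2.0 with a sparse x, rewrites existing nonzero entries
  // but never inserts new ones.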
  if (latticeMerger.isSingleCondition(tensor, exp))
    return true;

  // Accept "truly dynamic" if the output tensor materializes uninitialized
  // into the computation and insertions occur in lexicographic index order.
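  // For example (illustrative only), a kernel whose sparse output is created
  // by a tensor.empty right before the kernel falls into this category: the
  // output starts out empty and all nonzero entries are inserted while the
  // kernel runs, in lexicographic index order (possibly via expand/compress
  // for the inner dimensions).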
  sparseOut = lhs;

  // Find the outermost parallel nest to determine whether compress/expand is
  // needed.
  outerParNest = 0;
  const auto iteratorTypes = linalgOp.getIteratorTypesArray();
  for (unsigned i = 0, e = getLoopNum(); i < e; i++) {
    if (linalg::isReductionIterator(iteratorTypes[i]))
      break; // terminate at first reduction
    outerParNest++;
  }

  // An inadmissible kernel should have already been rejected by the previous
  // pass during loop scheduling.
  assert(static_cast<int64_t>(outerParNest) >=
         linalgOp.getRank(linalgOp.getDpsInitOperand(0)) - 1);
  return isMaterializing(lhs->get());
}

//===----------------------------------------------------------------------===//
// Code generation environment topological sort methods
//===----------------------------------------------------------------------===//

Value CodegenEnv::getLoopVar(LoopId i) const {
  return loopEmitter.getLoopIV(i);
}

//===----------------------------------------------------------------------===//
// Code generation environment sparse tensor output and expansion methods
//===----------------------------------------------------------------------===//

void CodegenEnv::updateInsertionChain(Value chain) {
  assert(sparseOut != nullptr && insChain != nullptr);
  insChain = chain;
}

bool CodegenEnv::atExpandLevel(OpOperand *o, unsigned rank, LoopId n) const {
  return sparseOut == o && outerParNest == static_cast<LoopId>(rank - 1) &&
         outerParNest == n;
}

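// The four expansion values (values, filled, added, count) implement the
// access-pattern expansion used for insertions into a sparse output:
// startExpand() records the SSA values produced by the expansion,
// updateExpandCount() threads the running count through loop boundaries
// (see genLoopBoundary), and endExpand() clears the state again once the
// entries have been compressed back into the output. A rough sketch of the
// intended call sequence (illustrative only):
//
//   env.startExpand(values, filled, added, count);
//   ... emit the loops; the count is updated at each loop boundary ...
//   env.endExpand();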
void CodegenEnv::startExpand(Value values, Value filled, Value added,
                             Value count) {
  assert(sparseOut != nullptr && expValues == nullptr);
  expValues = values;
  expFilled = filled;
  expAdded = added;
  expCount = count;
}

void CodegenEnv::updateExpandCount(Value count) {
  assert(sparseOut != nullptr && expValues != nullptr);
  expCount = count;
}

void CodegenEnv::endExpand() {
  assert(sparseOut != nullptr && expValues != nullptr);
  expValues = expFilled = expAdded = expCount = Value();
}

//===----------------------------------------------------------------------===//
// Code generation environment reduction methods
//===----------------------------------------------------------------------===//

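// Reduction state is kept as the pair (redExp, redVal): startReduc() installs
// the initial value for the reduction expression, updateReduc() replaces the
// running value (typically with a loop block argument when crossing a loop
// boundary, see genLoopBoundary), and endReduc() returns the final value and
// clears the state. A rough sketch of the intended call sequence (illustrative
// only):
//
//   env.startReduc(exp, initVal);
//   ... emit the loops; the running value is threaded at each boundary ...
//   Value result = env.endReduc();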
void CodegenEnv::startReduc(ExprId exp, Value val) {
  assert(!isReduc() && exp != detail::kInvalidId && val);
  redExp = exp;
  redVal = val;
  latticeMerger.setExprValue(exp, val);
}

void CodegenEnv::updateReduc(Value val) {
  assert(isReduc() && val);
  redVal = val;
  latticeMerger.clearExprValue(redExp);
  latticeMerger.setExprValue(redExp, val);
}

Value CodegenEnv::endReduc() {
  assert(isReduc());
  Value val = redVal;
  redVal = Value(); // reset the running reduction value
  latticeMerger.clearExprValue(redExp);
  redExp = detail::kInvalidId;
  return val;
}

void CodegenEnv::startValidLexInsert(Value val) {
  assert(!isValidLexInsert() && isReduc() && val);
  redValidLexInsert = val;
}

void CodegenEnv::updateValidLexInsert(Value val) {
  assert(redValidLexInsert && isReduc() && val);
  redValidLexInsert = val;
}

void CodegenEnv::endValidLexInsert() {
  assert(isValidLexInsert() && !isReduc());
  redValidLexInsert = Value();
}

void CodegenEnv::startCustomReduc(ExprId exp) {
  assert(!isCustomReduc() && exp != detail::kInvalidId);
  redCustom = exp;
}

Value CodegenEnv::getCustomRedId() const {
  assert(isCustomReduc());
  return cast<sparse_tensor::ReduceOp>(exp(redCustom).op).getIdentity();
}

void CodegenEnv::endCustomReduc() {
  assert(isCustomReduc());
  redCustom = detail::kInvalidId;
}