1 | //===- LoopEmitter.h --------------------------------------------*- C++ -*-===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #ifndef MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_UTILS_LOOPEMITTER_H_ |
10 | #define MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_UTILS_LOOPEMITTER_H_ |
11 | |
12 | #include <vector> |
13 | |
14 | #include "SparseTensorIterator.h" |
15 | |
16 | #include "mlir/Dialect/SparseTensor/IR/Enums.h" |
17 | #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" |
18 | #include "mlir/Dialect/SparseTensor/Transforms/Passes.h" |
19 | #include "mlir/Dialect/SparseTensor/Utils/Merger.h" |
20 | #include "mlir/IR/PatternMatch.h" |
21 | |
22 | namespace mlir { |
23 | namespace sparse_tensor { |
24 | |
// A compressed <tensor id, level> pair. The pairing is encoded as a single
// unsigned value (see LoopEmitter::makeTensorLevel/unpackTensorLevel for the
// encoding/decoding scheme) so that a (tid, lvl) coordinate can be passed
// around as one scalar.
using TensorLevel = unsigned;
27 | |
28 | // |
29 | // SparseTensorLoopEmiter class, manages sparse tensors and helps to |
30 | // generate loop structure to (co)-iterate sparse tensors. |
31 | // |
32 | // An example usage: |
33 | // To generate the following loops over T1<?x?> and T2<?x?> |
34 | // |
35 | // for i in TENSOR_1_0 { |
36 | // for j : TENSOR_2_0 { |
37 | // for k : TENSOR_1_1 {} |
38 | // for k : TENSOR_2_1 {} |
39 | // } |
40 | // } |
41 | // |
42 | // One can use |
43 | // |
44 | // LoopEmiter loopEmiter({T1, T1}); |
45 | // loopEmiter.initializeLoopEmit(); |
46 | // loopEmiter.enterLoopOverTensorAtLvl(T1, 0); |
47 | // loopEmiter.enterLoopOverTensorAtLvl(T2, 0); |
48 | // loopEmiter.enterLoopOverTensorAtLvl(T1, 1); |
49 | // loopEmiter.exitCurrentLoop(); |
50 | // loopEmiter.enterLoopOverTensorAtLvl(T2, 1); |
51 | // loopEmiter.exitCurrentLoop(); // exit k |
52 | // loopEmiter.exitCurrentLoop(); // exit j |
53 | // loopEmiter.exitCurrentLoop(); // exit i |
54 | // |
55 | class LoopEmitter { |
56 | public: |
57 | /// Optional callback function to setup dense output tensors when |
58 | /// initializing the loop emitter (e.g., to fill a dense output with zeros). |
59 | using OutputUpdater = function_ref<Value(OpBuilder &builder, Location loc, |
60 | Value memref, Value tensor)>; |
61 | |
62 | /// Optional callback function to set the bound for the synthetic tensor, |
63 | /// which essentially is the dense loop bound. |
64 | using SynTensorBoundSetter = |
65 | function_ref<Value(OpBuilder &builder, Location loc, Level lvl)>; |
66 | |
67 | // Map from [tid, lvl] to a list of dependent [LoopId, coeffecient] for |
68 | // subscript expressions on sparse tensors. |
69 | // |
70 | // E.g., for affine index (2 * d0 + d1), it depends on loop d0 and d1 (for |
71 | // affine expression reduction) and uses 2 and 1 for coefficients on d0, d1 |
72 | // respectively. If the list is empty, it means that there is no affine |
73 | // expression on the input [tid, lvl]. |
74 | // |
75 | // NOTE: LoopEmitter assumes that the loop id is consistent with the loop |
76 | // order, i.e., loop `d0` will be generated before loop `d1`. |
77 | using DependentLvlGetter = |
78 | function_ref<std::vector<std::pair<LoopId, unsigned>>(TensorId, Level)>; |
79 | |
80 | LoopEmitter() = default; |
81 | |
82 | /// Takes an array of input tensors, which the generated loops will |
83 | /// iterate over. Each tensor is given a `TensorId` (numerically equal |
84 | /// to the position of that tensor `Value` in the array). Setting |
85 | /// `isSparseOut` indicates that the sparse output tensor is empty, |
86 | /// so the loop emitter will generate loops over it according to the |
87 | /// level-sizes. |
88 | void |
89 | initialize(ValueRange tensors, StringAttr loopTag = nullptr, |
90 | bool hasOutput = false, bool isSparseOut = false, |
91 | unsigned numLoops = 0, DependentLvlGetter getter = nullptr, |
92 | SparseEmitStrategy emitStrategy = SparseEmitStrategy::kFunctional); |
93 | |
94 | explicit LoopEmitter( |
95 | ValueRange tensors, StringAttr loopTag = nullptr, bool hasOutput = false, |
96 | bool isSparseOut = false, unsigned numLoops = 0, |
97 | DependentLvlGetter getter = nullptr, |
98 | SparseEmitStrategy emitStrategy = SparseEmitStrategy::kFunctional); |
99 | |
100 | /// Starts a loop emitting session by generating all the buffers needed |
101 | /// for iterating over the tensors. |
102 | void initializeLoopEmit(OpBuilder &builder, Location loc, |
103 | OutputUpdater updater = nullptr, |
104 | SynTensorBoundSetter synSetter = nullptr); |
105 | |
106 | /// Generates code to compute an affine expression whose variables are |
107 | /// `LoopId`s (i.e., `cast<AffineDimExpr>(a).getPosition()` is a valid |
108 | /// `LoopId`). |
109 | Value genAffine(OpBuilder &builder, Location loc, AffineExpr a); |
110 | |
111 | /// Enters a new loop sequence, the loops within the same sequence starts |
112 | /// from the break points of previous loop instead of starting over from 0. |
113 | /// e.g., |
114 | /// { |
115 | /// // loop sequence start. |
116 | /// p0 = while(xxx) |
117 | /// ... |
118 | /// break p0 |
119 | /// |
120 | /// // Starts loop from p0 |
121 | /// for (i = p0; i < end; i++) |
122 | /// ... |
123 | /// // loop sequence end. |
124 | /// } |
125 | void enterNewLoopSeq(OpBuilder &builder, Location loc, |
126 | ArrayRef<TensorLevel> tidLvls); |
127 | |
128 | /// Exits the current loop sequence, this will reset universal index to 0. |
129 | void exitCurrentLoopSeq(OpBuilder &builder, Location loc); |
130 | |
131 | /// Emits the address for a dense level based on the value evaluated by the |
132 | /// provided affine expression. |
133 | void locateLvlAtAffineAddress(OpBuilder &builder, Location loc, |
134 | TensorLevel tidLvl, AffineExpr lvlExpr); |
135 | |
136 | // TODO: Get rid of `lvls` in the argument list? Track the level we |
137 | // are currently at internally. Then it would be enterNextLvlForTensor. |
138 | // Still need a way to specify the lvl for non-annotated tensors though, |
139 | // as those can be accessed out of order. |
140 | // |
141 | /// Emits a co-iteration loop over a set of tensors. |
142 | /// Emits loop over tensor_tid_lvl, it assumes that loops between |
143 | /// tensor_tid_[0, lvl - 1] have already been generated. |
144 | /// The function will also perform in-place update on the `reduc` vector to |
145 | /// return the reduction variable used inside the generated loop. |
146 | Operation *enterCoIterationOverTensorsAtLvls( |
147 | OpBuilder &builder, Location loc, ArrayRef<TensorLevel> tidLvls, |
148 | unsigned numCases, MutableArrayRef<Value> reduc = {}, |
149 | bool isParallel = false, bool needsUniv = false); |
150 | |
151 | Region *enterCurrentCoIterationCase(OpBuilder &builder, Location loc, |
152 | I64BitSet caseBit, unsigned caseIdx, |
153 | MutableArrayRef<Value> reduc); |
154 | |
155 | /// Generates code to exit the current loop (e.g., generates yields, forwards |
156 | /// loop induction variables, etc). |
157 | void exitCurrentLoop(RewriterBase &rewriter, Location loc, |
158 | MutableArrayRef<Value> reduc = {}); |
159 | |
160 | /// Get the range of values for all induction variables. |
161 | auto getLoopIVsRange() const { |
162 | return llvm::map_range(C: loopStack, F: [](const LoopInfo &li) { return li.iv; }); |
163 | } |
164 | |
165 | /// Fills the out-parameter with the loop induction variables for all |
166 | /// loops in the current loop-stack. |
167 | SmallVector<Value> getLoopIVs() const { |
168 | return llvm::to_vector(Range: getLoopIVsRange()); |
169 | } |
170 | |
171 | /// Gets the current depth of the loop-stack. |
172 | LoopId getCurrentDepth() const { return llvm::range_size(Range: getLoopIVsRange()); } |
173 | |
174 | /// Gets loop induction variable for the given loop |
175 | Value getLoopIV(LoopId n) const { |
176 | if (n >= getCurrentDepth()) |
177 | return Value(); |
178 | auto it = getLoopIVsRange().begin(); |
179 | std::advance(i&: it, n: n); |
180 | return *it; |
181 | } |
182 | |
183 | /// Gets the total number of manifest tensors (excluding the synthetic |
184 | /// tensor). |
185 | unsigned getNumManifestTensors() const { return tensors.size(); } |
186 | |
187 | /// Gets the total number of tensors that loopEmitter is operating on. |
188 | unsigned getNumTensors() const { |
189 | // Manifest tensors with one synthetic tensor at the end. |
190 | return getNumManifestTensors() + 1; |
191 | } |
192 | |
193 | /// Gets the TensorId for synthetic tensor. |
194 | TensorId getSynTensorId() const { return tensors.size(); } |
195 | |
196 | /// Gets the TensorId for output tensor. |
197 | TensorId getOutTensorId() const { |
198 | assert(hasOutput); |
199 | return getNumManifestTensors() - 1; |
200 | } |
201 | |
202 | /// Compresses a TensorId and Level into a TensorLevel. |
203 | TensorLevel makeTensorLevel(TensorId t, Level l) const { |
204 | return l * getNumTensors() + t; |
205 | } |
206 | |
207 | /// De-compresses a TensorLevel back to a pair of TensorId and Level. |
208 | std::pair<TensorId, Level> unpackTensorLevel(TensorLevel tidLvl) const { |
209 | unsigned nt = getNumTensors(); |
210 | return std::make_pair(x: tidLvl % nt, y: tidLvl / nt); |
211 | } |
212 | |
213 | /// Converts a range of TensorLevel to a range of std::pair<TensorId, Level> |
214 | template <class ContainerTy> |
215 | auto unpackTensorLevelRange(ContainerTy &&c) const { |
216 | using EltTy = decltype(*c.begin()); |
217 | static_assert(std::is_same_v<llvm::remove_cvref_t<EltTy>, TensorLevel>, |
218 | "Must be unpacking a TensorLevel range" ); |
219 | return llvm::map_range(std::forward<ContainerTy>(c), [this](EltTy tl) { |
220 | return this->unpackTensorLevel(tidLvl: tl); |
221 | }); |
222 | } |
223 | |
224 | /// |
225 | /// Getters. |
226 | /// |
227 | SmallVector<Value> getValPosits(TensorId tid) const { |
228 | // Returns the iterator if we are generating sparse (co)iterate-based loops. |
229 | if (emitStrategy == SparseEmitStrategy::kSparseIterator) |
230 | return {spIterVals[tid].back()}; |
231 | |
232 | // Returns {[batch coords], last-level position}. |
233 | SmallVector<Value> batchCrds = iters[tid].back().back()->getBatchCrds(); |
234 | Value lastLvlPos = iters[tid].back().back()->getCurPosition().front(); |
235 | batchCrds.push_back(Elt: lastLvlPos); |
236 | return batchCrds; |
237 | }; |
238 | Value getCoord(TensorId tid, Level lvl) const { |
239 | return getCurIterator(tid, lvl).getCrd(); |
240 | }; |
241 | const std::vector<Value> &getValBuffer() const { return valBuffer; }; |
242 | |
243 | constexpr static llvm::StringLiteral getLoopEmitterLoopAttrName() { |
244 | return llvm::StringLiteral("Emitted from" ); |
245 | } |
246 | |
247 | private: |
248 | /// |
249 | /// Structure definitions that hold different kinds of loops information. |
250 | /// |
251 | |
252 | // LoopInfo stores information of a loop generated by LoopEmitter. E.g., |
253 | // the set of tensors levels that the loop is iterating over. |
254 | struct LoopInfo final { |
255 | LoopInfo(ArrayRef<TensorLevel> tidLvls, Operation *loop, Block *userBlock, |
256 | Value iv, StringAttr loopTag) |
257 | : tidLvls(tidLvls), loop(loop), userCodeBlock(userBlock), iv(iv) { |
258 | // Attached a special tag to loop emitter generated loop. |
259 | if (loopTag) |
260 | loop->setAttr(LoopEmitter::getLoopEmitterLoopAttrName(), loopTag); |
261 | } |
262 | // The set of <tensor, lvl>, with *only* trivial index expressions, that are |
263 | // used as the condition for the generated loop. Extra information is |
264 | // required for levels with non-tivial index expressions, which is |
265 | // maintained by the sliceDrivenInfo array below. |
266 | const llvm::SmallVector<TensorLevel> tidLvls; |
267 | Operation *loop; // the loop operation |
268 | Block *const userCodeBlock; // the block holding users' generated code. |
269 | Value iv; // the induction variable for the loop |
270 | }; |
271 | |
272 | void categorizeIterators(ArrayRef<TensorLevel> tidLvls, |
273 | SmallVectorImpl<SparseIterator *> &raIters, |
274 | SmallVectorImpl<SparseIterator *> &spIters); |
275 | /// |
276 | /// LoopEmitter internal helper functions. |
277 | /// |
278 | |
279 | using LoopBodyBuilder = llvm::function_ref<void(OpBuilder &, Location, Value, |
280 | MutableArrayRef<Value>)>; |
281 | |
282 | /// Whether the list of the sparse condition should be iterated by for loop. |
283 | bool shouldIteratedByForLoop(ArrayRef<SparseIterator *> spIters); |
284 | |
285 | /// Generates instructions to compute the coordinate of tensors[tid][lvl] |
286 | /// under the current loop context. The final argument is the |
287 | /// collapsed-output level, whereas this function handles converting |
288 | /// that to the uncollapsed-input level |
289 | Value genSparseCrd(OpBuilder &builder, Location loc, TensorId tid, |
290 | Level dstLvl); |
291 | |
292 | bool isSynTensor(TensorId tid) const { return tid == getSynTensorId(); } |
293 | |
294 | bool isOutputTensor(TensorId tid) const { |
295 | return hasOutput && tid == getOutTensorId(); |
296 | } |
297 | |
298 | bool isSparseOutput(TensorId tid) const { |
299 | return isOutputTensor(tid) && isSparseOut; |
300 | } |
301 | |
302 | bool isValidLevel(TensorId tid, Level lvl) const { |
303 | return tid < lvls.size() && lvl < lvls[tid].size(); |
304 | } |
305 | |
306 | /// Prepares loop for iterating over `tensor[lvl]`, under the assumption |
307 | /// that `tensor[0...lvl-1]` loops have already been set up. |
308 | void prepareLoopOverTensorAtLvl(OpBuilder &builder, Location loc, |
309 | TensorId tid, Level lvl); |
310 | |
311 | /// Emits a for loop to iterate over a tensor level with the provided |
312 | /// lower bound `lo` and upper bound `hi`. Apart from iterating just |
313 | /// single tensor level, for loops can be used for slice-driven loop on |
314 | /// dense level too. |
315 | /// Returns a pair: the loop generated and the value for the induction |
316 | /// variable. |
317 | std::pair<Operation *, Value> |
318 | emitForLoopOverTensorAtLvl(OpBuilder &builder, Location loc, |
319 | SparseIterator &iter, MutableArrayRef<Value> reduc, |
320 | bool isParallel); |
321 | |
322 | /// Emits a while loop to co-iterate over a list of sparse condition, or |
323 | /// (complex) single sparse condition that can not be handled by for loop |
324 | /// (e.g., index reduction loop). |
325 | /// Returns a pair: the loop generated and the value for the induction |
326 | /// variable (which is the minimum coordinate of all the tensor that being |
327 | /// iterated). |
328 | std::pair<Operation *, Value> |
329 | emitWhileLoopOverTensorsAtLvls(OpBuilder &builder, Location loc, |
330 | ArrayRef<SparseIterator *> iters, |
331 | MutableArrayRef<Value> reduc, bool needsUniv); |
332 | |
333 | /// Exits a for loop, returns the reduction results, e.g., |
334 | /// For sequential for loops: |
335 | /// %ret = for () { |
336 | /// ... |
337 | /// %val = addi %args, %c |
338 | /// yield %val |
339 | /// } |
340 | /// For parallel loops, the following generated code by users: |
341 | /// %ret = parallel () init(%args) { |
342 | /// ... |
343 | /// %val = op %args, %c |
344 | /// } |
345 | /// will be transformed into |
346 | /// %ret = parallel () init(%args) { |
347 | /// ... |
348 | /// scf.reduce(%c) bb0(%0, %1){ |
349 | /// %val = op %0, %1 |
350 | /// scf.reduce.return %val |
351 | /// } |
352 | /// } |
353 | /// NOTE: only one instruction will be moved into reduce block, |
354 | /// transformation will fail if multiple instructions are used to compute |
355 | /// the reduction value. Return %ret to user, while %val is provided by |
356 | /// users (`reduc`). |
357 | void exitForLoop(RewriterBase &rewriter, Location loc, |
358 | MutableArrayRef<Value> reduc); |
359 | |
360 | /// Exits a while loop, returns the reduction results. |
361 | void exitWhileLoop(OpBuilder &builder, Location loc, |
362 | MutableArrayRef<Value> reduc); |
363 | |
364 | // |
365 | // Slice-driven loop related methods. |
366 | // |
367 | |
368 | void initSubSectIterator(OpBuilder &builder, Location loc); |
369 | |
370 | /// Get the reduced number of contraints on tensor[tid][lvl]. |
371 | unsigned redDepOnLevel(TensorId tid, Level lvl) const { |
372 | return levelReducedDep[tid][lvl]; |
373 | }; |
374 | |
375 | SparseIterator &getCurIterator(TensorId tid, Level lvl) const { |
376 | if (dependentLvlMap[tid][lvl].empty()) |
377 | return *iters[tid][lvl].back(); |
378 | |
379 | assert(redDepOnLevel(tid, lvl) >= 1); |
380 | return *iters[tid][lvl][redDepOnLevel(tid, lvl) - 1]; |
381 | } |
382 | |
383 | std::unique_ptr<SparseIterator> |
384 | makeLevelIterator(OpBuilder &builder, Location loc, TensorId tid, Level l); |
385 | |
386 | /// A optional string attribute that should be attached to the loop |
387 | /// generated by loop emitter, it might help following passes to identify |
388 | /// loops that operates on sparse tensors more easily. |
389 | StringAttr loopTag; |
390 | /// Whether the loop emitter needs to treat the last tensor as the output |
391 | /// tensor. |
392 | bool hasOutput; |
393 | bool isSparseOut; |
394 | SparseEmitStrategy emitStrategy; |
395 | |
396 | // |
397 | // Fields which have `numTensor` many entries. |
398 | // |
399 | |
400 | /// Input and (optional) output tensors. |
401 | std::vector<Value> tensors; |
402 | std::vector<Value> loopHighs; |
403 | std::vector<std::vector<std::unique_ptr<SparseTensorLevel>>> lvls; |
404 | std::vector<std::vector<std::vector<std::unique_ptr<SparseIterator>>>> iters; |
405 | std::vector<Value> valBuffer; // to_value |
406 | |
407 | // Map from [tid, level] to a list of dependent [tidlevel, coefficient]. |
408 | // See comments for `DependentLvlGetter`. |
409 | std::vector<std::vector<std::vector<std::pair<LoopId, unsigned>>>> |
410 | dependentLvlMap; |
411 | |
412 | // The (size, stride) for each conceptual slice used for index reduction |
413 | // loops. |
414 | std::vector<std::vector<std::vector<std::pair<Value, unsigned>>>> sliceMeta; |
415 | |
416 | // The number of reduced dependencies on a tensor level so far. |
417 | std::vector<std::vector<unsigned>> levelReducedDep; |
418 | |
419 | // |
420 | // Fields which have at most `numLoops` many entries. |
421 | // |
422 | |
423 | /// Loop Stack, stores the information of all the nested loops that are |
424 | /// alive. |
425 | std::vector<LoopInfo> loopStack; |
426 | |
427 | // Loop Sequence Stack, stores the universal index for the current loop |
428 | // sequence. and a list of tid level that the loop sequence traverse. |
429 | std::vector<std::pair<Value, std::vector<TensorLevel>>> loopSeqStack; |
430 | |
431 | // |
432 | // EXPERIMENTAL: |
433 | // Fields for generating sparse-iterator-based loop. |
434 | // |
435 | |
436 | std::vector<std::vector<Value>> spIterVals; |
437 | }; |
438 | |
//
// Utils functions to generate sparse loops.
//

// Generate a while loop that co-iterates over a set of iterators.
//
// `reduc` carries the in-flight reduction values and is updated in place;
// `uniIdx` (when non-null) is the universal index driving the co-iteration.
// When `userReducFirst` is set, user reduction values are placed before the
// iterator-managed values in the loop's argument list.
// Returns the generated loop operation and the induction value.
std::pair<Operation *, Value> genCoIteration(OpBuilder &builder, Location loc,
                                             ArrayRef<SparseIterator *> iters,
                                             MutableArrayRef<Value> reduc,
                                             Value uniIdx,
                                             bool userReducFirst = false);
449 | |
450 | } // namespace sparse_tensor |
451 | } // namespace mlir |
452 | |
453 | #endif // MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_UTILS_LOOPEMITTER_H_ |
454 | |