1//===- LoopEmitter.h --------------------------------------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#ifndef MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_UTILS_LOOPEMITTER_H_
10#define MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_UTILS_LOOPEMITTER_H_
11
12#include <vector>
13
14#include "SparseTensorIterator.h"
15
16#include "mlir/Dialect/SparseTensor/IR/Enums.h"
17#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
18#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
19#include "mlir/Dialect/SparseTensor/Utils/Merger.h"
20#include "mlir/IR/PatternMatch.h"
21
22namespace mlir {
23namespace sparse_tensor {
24
// A compressed <tensor id, level> pair, packed into a single unsigned value
// (see `LoopEmitter::makeTensorLevel` / `unpackTensorLevel` for the encoding).
using TensorLevel = unsigned;
27
//
// SparseTensorLoopEmitter class, manages sparse tensors and helps to
// generate loop structure to (co)-iterate sparse tensors.
//
// An example usage:
// To generate the following loops over T1<?x?> and T2<?x?>
//
// for i in TENSOR_1_0 {
//   for j : TENSOR_2_0 {
//     for k : TENSOR_1_1 {}
//     for k : TENSOR_2_1 {}
//   }
// }
//
// One can use
//
// LoopEmitter loopEmitter({T1, T2});
// loopEmitter.initializeLoopEmit();
// loopEmitter.enterLoopOverTensorAtLvl(T1, 0);
// loopEmitter.enterLoopOverTensorAtLvl(T2, 0);
// loopEmitter.enterLoopOverTensorAtLvl(T1, 1);
// loopEmitter.exitCurrentLoop();
// loopEmitter.enterLoopOverTensorAtLvl(T2, 1);
// loopEmitter.exitCurrentLoop(); // exit k
// loopEmitter.exitCurrentLoop(); // exit j
// loopEmitter.exitCurrentLoop(); // exit i
//
55class LoopEmitter {
56public:
57 /// Optional callback function to setup dense output tensors when
58 /// initializing the loop emitter (e.g., to fill a dense output with zeros).
59 using OutputUpdater = function_ref<Value(OpBuilder &builder, Location loc,
60 Value memref, Value tensor)>;
61
62 /// Optional callback function to set the bound for the synthetic tensor,
63 /// which essentially is the dense loop bound.
64 using SynTensorBoundSetter =
65 function_ref<Value(OpBuilder &builder, Location loc, Level lvl)>;
66
67 // Map from [tid, lvl] to a list of dependent [LoopId, coeffecient] for
68 // subscript expressions on sparse tensors.
69 //
70 // E.g., for affine index (2 * d0 + d1), it depends on loop d0 and d1 (for
71 // affine expression reduction) and uses 2 and 1 for coefficients on d0, d1
72 // respectively. If the list is empty, it means that there is no affine
73 // expression on the input [tid, lvl].
74 //
75 // NOTE: LoopEmitter assumes that the loop id is consistent with the loop
76 // order, i.e., loop `d0` will be generated before loop `d1`.
77 using DependentLvlGetter =
78 function_ref<std::vector<std::pair<LoopId, unsigned>>(TensorId, Level)>;
79
80 LoopEmitter() = default;
81
82 /// Takes an array of input tensors, which the generated loops will
83 /// iterate over. Each tensor is given a `TensorId` (numerically equal
84 /// to the position of that tensor `Value` in the array). Setting
85 /// `isSparseOut` indicates that the sparse output tensor is empty,
86 /// so the loop emitter will generate loops over it according to the
87 /// level-sizes.
88 void
89 initialize(ValueRange tensors, StringAttr loopTag = nullptr,
90 bool hasOutput = false, bool isSparseOut = false,
91 unsigned numLoops = 0, DependentLvlGetter getter = nullptr,
92 SparseEmitStrategy emitStrategy = SparseEmitStrategy::kFunctional);
93
94 explicit LoopEmitter(
95 ValueRange tensors, StringAttr loopTag = nullptr, bool hasOutput = false,
96 bool isSparseOut = false, unsigned numLoops = 0,
97 DependentLvlGetter getter = nullptr,
98 SparseEmitStrategy emitStrategy = SparseEmitStrategy::kFunctional);
99
100 /// Starts a loop emitting session by generating all the buffers needed
101 /// for iterating over the tensors.
102 void initializeLoopEmit(OpBuilder &builder, Location loc,
103 OutputUpdater updater = nullptr,
104 SynTensorBoundSetter synSetter = nullptr);
105
106 /// Generates code to compute an affine expression whose variables are
107 /// `LoopId`s (i.e., `cast<AffineDimExpr>(a).getPosition()` is a valid
108 /// `LoopId`).
109 Value genAffine(OpBuilder &builder, Location loc, AffineExpr a);
110
111 /// Enters a new loop sequence, the loops within the same sequence starts
112 /// from the break points of previous loop instead of starting over from 0.
113 /// e.g.,
114 /// {
115 /// // loop sequence start.
116 /// p0 = while(xxx)
117 /// ...
118 /// break p0
119 ///
120 /// // Starts loop from p0
121 /// for (i = p0; i < end; i++)
122 /// ...
123 /// // loop sequence end.
124 /// }
125 void enterNewLoopSeq(OpBuilder &builder, Location loc,
126 ArrayRef<TensorLevel> tidLvls);
127
128 /// Exits the current loop sequence, this will reset universal index to 0.
129 void exitCurrentLoopSeq(OpBuilder &builder, Location loc);
130
131 /// Emits the address for a dense level based on the value evaluated by the
132 /// provided affine expression.
133 void locateLvlAtAffineAddress(OpBuilder &builder, Location loc,
134 TensorLevel tidLvl, AffineExpr lvlExpr);
135
136 // TODO: Get rid of `lvls` in the argument list? Track the level we
137 // are currently at internally. Then it would be enterNextLvlForTensor.
138 // Still need a way to specify the lvl for non-annotated tensors though,
139 // as those can be accessed out of order.
140 //
141 /// Emits a co-iteration loop over a set of tensors.
142 /// Emits loop over tensor_tid_lvl, it assumes that loops between
143 /// tensor_tid_[0, lvl - 1] have already been generated.
144 /// The function will also perform in-place update on the `reduc` vector to
145 /// return the reduction variable used inside the generated loop.
146 Operation *enterCoIterationOverTensorsAtLvls(
147 OpBuilder &builder, Location loc, ArrayRef<TensorLevel> tidLvls,
148 unsigned numCases, MutableArrayRef<Value> reduc = {},
149 bool isParallel = false, bool needsUniv = false);
150
151 Region *enterCurrentCoIterationCase(OpBuilder &builder, Location loc,
152 I64BitSet caseBit, unsigned caseIdx,
153 MutableArrayRef<Value> reduc);
154
155 /// Generates code to exit the current loop (e.g., generates yields, forwards
156 /// loop induction variables, etc).
157 void exitCurrentLoop(RewriterBase &rewriter, Location loc,
158 MutableArrayRef<Value> reduc = {});
159
160 /// Get the range of values for all induction variables.
161 auto getLoopIVsRange() const {
162 return llvm::map_range(C: loopStack, F: [](const LoopInfo &li) { return li.iv; });
163 }
164
165 /// Fills the out-parameter with the loop induction variables for all
166 /// loops in the current loop-stack.
167 SmallVector<Value> getLoopIVs() const {
168 return llvm::to_vector(Range: getLoopIVsRange());
169 }
170
171 /// Gets the current depth of the loop-stack.
172 LoopId getCurrentDepth() const { return llvm::range_size(Range: getLoopIVsRange()); }
173
174 /// Gets loop induction variable for the given loop
175 Value getLoopIV(LoopId n) const {
176 if (n >= getCurrentDepth())
177 return Value();
178 auto it = getLoopIVsRange().begin();
179 std::advance(i&: it, n: n);
180 return *it;
181 }
182
183 /// Gets the total number of manifest tensors (excluding the synthetic
184 /// tensor).
185 unsigned getNumManifestTensors() const { return tensors.size(); }
186
187 /// Gets the total number of tensors that loopEmitter is operating on.
188 unsigned getNumTensors() const {
189 // Manifest tensors with one synthetic tensor at the end.
190 return getNumManifestTensors() + 1;
191 }
192
193 /// Gets the TensorId for synthetic tensor.
194 TensorId getSynTensorId() const { return tensors.size(); }
195
196 /// Gets the TensorId for output tensor.
197 TensorId getOutTensorId() const {
198 assert(hasOutput);
199 return getNumManifestTensors() - 1;
200 }
201
202 /// Compresses a TensorId and Level into a TensorLevel.
203 TensorLevel makeTensorLevel(TensorId t, Level l) const {
204 return l * getNumTensors() + t;
205 }
206
207 /// De-compresses a TensorLevel back to a pair of TensorId and Level.
208 std::pair<TensorId, Level> unpackTensorLevel(TensorLevel tidLvl) const {
209 unsigned nt = getNumTensors();
210 return std::make_pair(x: tidLvl % nt, y: tidLvl / nt);
211 }
212
213 /// Converts a range of TensorLevel to a range of std::pair<TensorId, Level>
214 template <class ContainerTy>
215 auto unpackTensorLevelRange(ContainerTy &&c) const {
216 using EltTy = decltype(*c.begin());
217 static_assert(std::is_same_v<llvm::remove_cvref_t<EltTy>, TensorLevel>,
218 "Must be unpacking a TensorLevel range");
219 return llvm::map_range(std::forward<ContainerTy>(c), [this](EltTy tl) {
220 return this->unpackTensorLevel(tidLvl: tl);
221 });
222 }
223
224 ///
225 /// Getters.
226 ///
227 SmallVector<Value> getValPosits(TensorId tid) const {
228 // Returns the iterator if we are generating sparse (co)iterate-based loops.
229 if (emitStrategy == SparseEmitStrategy::kSparseIterator)
230 return {spIterVals[tid].back()};
231
232 // Returns {[batch coords], last-level position}.
233 SmallVector<Value> batchCrds = iters[tid].back().back()->getBatchCrds();
234 Value lastLvlPos = iters[tid].back().back()->getCurPosition().front();
235 batchCrds.push_back(Elt: lastLvlPos);
236 return batchCrds;
237 };
238 Value getCoord(TensorId tid, Level lvl) const {
239 return getCurIterator(tid, lvl).getCrd();
240 };
241 const std::vector<Value> &getValBuffer() const { return valBuffer; };
242
243 constexpr static llvm::StringLiteral getLoopEmitterLoopAttrName() {
244 return llvm::StringLiteral("Emitted from");
245 }
246
247private:
248 ///
249 /// Structure definitions that hold different kinds of loops information.
250 ///
251
252 // LoopInfo stores information of a loop generated by LoopEmitter. E.g.,
253 // the set of tensors levels that the loop is iterating over.
254 struct LoopInfo final {
255 LoopInfo(ArrayRef<TensorLevel> tidLvls, Operation *loop, Block *userBlock,
256 Value iv, StringAttr loopTag)
257 : tidLvls(tidLvls), loop(loop), userCodeBlock(userBlock), iv(iv) {
258 // Attached a special tag to loop emitter generated loop.
259 if (loopTag)
260 loop->setAttr(LoopEmitter::getLoopEmitterLoopAttrName(), loopTag);
261 }
262 // The set of <tensor, lvl>, with *only* trivial index expressions, that are
263 // used as the condition for the generated loop. Extra information is
264 // required for levels with non-tivial index expressions, which is
265 // maintained by the sliceDrivenInfo array below.
266 const llvm::SmallVector<TensorLevel> tidLvls;
267 Operation *loop; // the loop operation
268 Block *const userCodeBlock; // the block holding users' generated code.
269 Value iv; // the induction variable for the loop
270 };
271
272 void categorizeIterators(ArrayRef<TensorLevel> tidLvls,
273 SmallVectorImpl<SparseIterator *> &raIters,
274 SmallVectorImpl<SparseIterator *> &spIters);
275 ///
276 /// LoopEmitter internal helper functions.
277 ///
278
279 using LoopBodyBuilder = llvm::function_ref<void(OpBuilder &, Location, Value,
280 MutableArrayRef<Value>)>;
281
282 /// Whether the list of the sparse condition should be iterated by for loop.
283 bool shouldIteratedByForLoop(ArrayRef<SparseIterator *> spIters);
284
285 /// Generates instructions to compute the coordinate of tensors[tid][lvl]
286 /// under the current loop context. The final argument is the
287 /// collapsed-output level, whereas this function handles converting
288 /// that to the uncollapsed-input level
289 Value genSparseCrd(OpBuilder &builder, Location loc, TensorId tid,
290 Level dstLvl);
291
292 bool isSynTensor(TensorId tid) const { return tid == getSynTensorId(); }
293
294 bool isOutputTensor(TensorId tid) const {
295 return hasOutput && tid == getOutTensorId();
296 }
297
298 bool isSparseOutput(TensorId tid) const {
299 return isOutputTensor(tid) && isSparseOut;
300 }
301
302 bool isValidLevel(TensorId tid, Level lvl) const {
303 return tid < lvls.size() && lvl < lvls[tid].size();
304 }
305
306 /// Prepares loop for iterating over `tensor[lvl]`, under the assumption
307 /// that `tensor[0...lvl-1]` loops have already been set up.
308 void prepareLoopOverTensorAtLvl(OpBuilder &builder, Location loc,
309 TensorId tid, Level lvl);
310
311 /// Emits a for loop to iterate over a tensor level with the provided
312 /// lower bound `lo` and upper bound `hi`. Apart from iterating just
313 /// single tensor level, for loops can be used for slice-driven loop on
314 /// dense level too.
315 /// Returns a pair: the loop generated and the value for the induction
316 /// variable.
317 std::pair<Operation *, Value>
318 emitForLoopOverTensorAtLvl(OpBuilder &builder, Location loc,
319 SparseIterator &iter, MutableArrayRef<Value> reduc,
320 bool isParallel);
321
322 /// Emits a while loop to co-iterate over a list of sparse condition, or
323 /// (complex) single sparse condition that can not be handled by for loop
324 /// (e.g., index reduction loop).
325 /// Returns a pair: the loop generated and the value for the induction
326 /// variable (which is the minimum coordinate of all the tensor that being
327 /// iterated).
328 std::pair<Operation *, Value>
329 emitWhileLoopOverTensorsAtLvls(OpBuilder &builder, Location loc,
330 ArrayRef<SparseIterator *> iters,
331 MutableArrayRef<Value> reduc, bool needsUniv);
332
333 /// Exits a for loop, returns the reduction results, e.g.,
334 /// For sequential for loops:
335 /// %ret = for () {
336 /// ...
337 /// %val = addi %args, %c
338 /// yield %val
339 /// }
340 /// For parallel loops, the following generated code by users:
341 /// %ret = parallel () init(%args) {
342 /// ...
343 /// %val = op %args, %c
344 /// }
345 /// will be transformed into
346 /// %ret = parallel () init(%args) {
347 /// ...
348 /// scf.reduce(%c) bb0(%0, %1){
349 /// %val = op %0, %1
350 /// scf.reduce.return %val
351 /// }
352 /// }
353 /// NOTE: only one instruction will be moved into reduce block,
354 /// transformation will fail if multiple instructions are used to compute
355 /// the reduction value. Return %ret to user, while %val is provided by
356 /// users (`reduc`).
357 void exitForLoop(RewriterBase &rewriter, Location loc,
358 MutableArrayRef<Value> reduc);
359
360 /// Exits a while loop, returns the reduction results.
361 void exitWhileLoop(OpBuilder &builder, Location loc,
362 MutableArrayRef<Value> reduc);
363
364 //
365 // Slice-driven loop related methods.
366 //
367
368 void initSubSectIterator(OpBuilder &builder, Location loc);
369
370 /// Get the reduced number of contraints on tensor[tid][lvl].
371 unsigned redDepOnLevel(TensorId tid, Level lvl) const {
372 return levelReducedDep[tid][lvl];
373 };
374
375 SparseIterator &getCurIterator(TensorId tid, Level lvl) const {
376 if (dependentLvlMap[tid][lvl].empty())
377 return *iters[tid][lvl].back();
378
379 assert(redDepOnLevel(tid, lvl) >= 1);
380 return *iters[tid][lvl][redDepOnLevel(tid, lvl) - 1];
381 }
382
383 std::unique_ptr<SparseIterator>
384 makeLevelIterator(OpBuilder &builder, Location loc, TensorId tid, Level l);
385
386 /// A optional string attribute that should be attached to the loop
387 /// generated by loop emitter, it might help following passes to identify
388 /// loops that operates on sparse tensors more easily.
389 StringAttr loopTag;
390 /// Whether the loop emitter needs to treat the last tensor as the output
391 /// tensor.
392 bool hasOutput;
393 bool isSparseOut;
394 SparseEmitStrategy emitStrategy;
395
396 //
397 // Fields which have `numTensor` many entries.
398 //
399
400 /// Input and (optional) output tensors.
401 std::vector<Value> tensors;
402 std::vector<Value> loopHighs;
403 std::vector<std::vector<std::unique_ptr<SparseTensorLevel>>> lvls;
404 std::vector<std::vector<std::vector<std::unique_ptr<SparseIterator>>>> iters;
405 std::vector<Value> valBuffer; // to_value
406
407 // Map from [tid, level] to a list of dependent [tidlevel, coefficient].
408 // See comments for `DependentLvlGetter`.
409 std::vector<std::vector<std::vector<std::pair<LoopId, unsigned>>>>
410 dependentLvlMap;
411
412 // The (size, stride) for each conceptual slice used for index reduction
413 // loops.
414 std::vector<std::vector<std::vector<std::pair<Value, unsigned>>>> sliceMeta;
415
416 // The number of reduced dependencies on a tensor level so far.
417 std::vector<std::vector<unsigned>> levelReducedDep;
418
419 //
420 // Fields which have at most `numLoops` many entries.
421 //
422
423 /// Loop Stack, stores the information of all the nested loops that are
424 /// alive.
425 std::vector<LoopInfo> loopStack;
426
427 // Loop Sequence Stack, stores the universal index for the current loop
428 // sequence. and a list of tid level that the loop sequence traverse.
429 std::vector<std::pair<Value, std::vector<TensorLevel>>> loopSeqStack;
430
431 //
432 // EXPERIMENTAL:
433 // Fields for generating sparse-iterator-based loop.
434 //
435
436 std::vector<std::vector<Value>> spIterVals;
437};
438
439//
440// Utils functions to generate sparse loops.
441//
442
/// Generates a while loop that co-iterates over the given set of sparse
/// iterators, returning the loop operation and the value of the loop
/// induction variable.
/// NOTE(review): `uniIdx`, when non-null, presumably carries the universal
/// index driving the co-iteration, and `userReducFirst` the ordering of
/// user reduction values relative to iterator state in the loop arguments
/// — confirm against the implementation in LoopEmitter.cpp.
std::pair<Operation *, Value> genCoIteration(OpBuilder &builder, Location loc,
                                             ArrayRef<SparseIterator *> iters,
                                             MutableArrayRef<Value> reduc,
                                             Value uniIdx,
                                             bool userReducFirst = false);
449
450} // namespace sparse_tensor
451} // namespace mlir
452
453#endif // MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_UTILS_LOOPEMITTER_H_
454

source code of mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.h