1//===- SparseTensorIterator.h ---------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#ifndef MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_UTILS_SPARSETENSORITERATOR_H_
10#define MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_UTILS_SPARSETENSORITERATOR_H_
11
12#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
13#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
14
15namespace mlir {
16namespace sparse_tensor {
17
// `SparseIterationSpace::extractIterator` (declared below) hands out a
// `std::unique_ptr<SparseIterator>` before the iterator class is defined, so
// the class name must be introduced ahead of time.
class SparseIterator;
20
21/// The base class for all types of sparse tensor levels. It provides interfaces
22/// to query the loop range (see `peekRangeAt`) and look up the coordinates (see
23/// `peekCrdAt`).
24class SparseTensorLevel {
25 SparseTensorLevel(SparseTensorLevel &&) = delete;
26 SparseTensorLevel(const SparseTensorLevel &) = delete;
27 SparseTensorLevel &operator=(SparseTensorLevel &&) = delete;
28 SparseTensorLevel &operator=(const SparseTensorLevel &) = delete;
29
30public:
31 virtual ~SparseTensorLevel() = default;
32
33 std::string toString() const {
34 return std::string(toMLIRString(lt)) + "[" + std::to_string(val: tid) + "," +
35 std::to_string(val: lvl) + "]";
36 }
37
38 virtual Value peekCrdAt(OpBuilder &b, Location l, ValueRange batchPrefix,
39 Value iv) const = 0;
40
41 /// Peeks the lower and upper bound to *fully* traverse the level with
42 /// the given position `parentPos`, see SparseTensorIterator::getCurPostion(),
43 /// that the immediate parent level is current at. Returns a pair of values
44 /// for *posLo* and *loopHi* respectively.
45 ///
46 /// For a dense level, the *posLo* is the linearized position at beginning,
47 /// while *loopHi* is the largest *coordinate*, it also implies that the
48 /// smallest *coordinate* to start the loop is 0.
49 ///
50 /// For a sparse level, [posLo, loopHi) specifies the range of index pointer
51 /// to load coordinate from the coordinate buffer.
52 virtual std::pair<Value, Value>
53 peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
54 ValueRange parentPos, Value inPadZone = nullptr) const = 0;
55
56 virtual std::pair<Value, Value>
57 collapseRangeBetween(OpBuilder &b, Location l, ValueRange batchPrefix,
58 std::pair<Value, Value> parentRange) const {
59 llvm_unreachable("Not Implemented");
60 };
61
62 Level getLevel() const { return lvl; }
63 LevelType getLT() const { return lt; }
64 Value getSize() const { return lvlSize; }
65 virtual ValueRange getLvlBuffers() const = 0;
66
67 //
68 // Level properties
69 //
70 bool isUnique() const { return isUniqueLT(lt); }
71
72protected:
73 SparseTensorLevel(unsigned tid, unsigned lvl, LevelType lt, Value lvlSize)
74 : tid(tid), lvl(lvl), lt(lt), lvlSize(lvlSize) {};
75
76public:
77 const unsigned tid, lvl;
78 const LevelType lt;
79 const Value lvlSize;
80};
81
/// Discriminator stored in `SparseIterator::kind`, enabling LLVM-style RTTI
/// over the concrete iterator classes. The underlying values are spelled out
/// explicitly; they match the implicit 0-based numbering.
enum class IterKind : uint8_t {
  kTrivial = 0,
  kDedup = 1,
  kSubSect = 2,
  kNonEmptySubSect = 3,
  kFilter = 4,
  kPad = 5,
};
90
91/// A `SparseIterationSpace` represents a sparse set of coordinates defined by
92/// (possibly multiple) levels of a specific sparse tensor.
93/// TODO: remove `SparseTensorLevel` and switch to SparseIterationSpace when
94/// feature complete.
95class SparseIterationSpace {
96public:
97 SparseIterationSpace() = default;
98 SparseIterationSpace(SparseIterationSpace &) = delete;
99 SparseIterationSpace(SparseIterationSpace &&) = default;
100
101 // Constructs a N-D iteration space.
102 SparseIterationSpace(Location loc, OpBuilder &b, Value t, unsigned tid,
103 std::pair<Level, Level> lvlRange, ValueRange parentPos);
104
105 // Constructs a 1-D iteration space.
106 SparseIterationSpace(Location loc, OpBuilder &b, Value t, unsigned tid,
107 Level lvl, ValueRange parentPos)
108 : SparseIterationSpace(loc, b, t, tid, {lvl, lvl + 1}, parentPos) {};
109
110 bool isUnique() const { return lvls.back()->isUnique(); }
111
112 unsigned getSpaceDim() const { return lvls.size(); }
113
114 // Reconstructs a iteration space directly from the provided ValueRange.
115 static SparseIterationSpace fromValues(IterSpaceType dstTp, ValueRange values,
116 unsigned tid);
117
118 // The inverse operation of `fromValues`.
119 SmallVector<Value> toValues() const {
120 SmallVector<Value> vals;
121 for (auto &stl : lvls) {
122 llvm::append_range(C&: vals, R: stl->getLvlBuffers());
123 vals.push_back(Elt: stl->getSize());
124 }
125 vals.append(IL: {bound.first, bound.second});
126 return vals;
127 }
128
129 const SparseTensorLevel &getLastLvl() const { return *lvls.back(); }
130 ArrayRef<std::unique_ptr<SparseTensorLevel>> getLvlRef() const {
131 return lvls;
132 }
133
134 Value getBoundLo() const { return bound.first; }
135 Value getBoundHi() const { return bound.second; }
136
137 // Extract an iterator to iterate over the sparse iteration space.
138 std::unique_ptr<SparseIterator> extractIterator(OpBuilder &b,
139 Location l) const;
140
141private:
142 SmallVector<std::unique_ptr<SparseTensorLevel>> lvls;
143 std::pair<Value, Value> bound;
144};
145
146/// Helper class that generates loop conditions, etc, to traverse a
147/// sparse tensor level.
148class SparseIterator {
149 SparseIterator(SparseIterator &&) = delete;
150 SparseIterator(const SparseIterator &) = delete;
151 SparseIterator &operator=(SparseIterator &&) = delete;
152 SparseIterator &operator=(const SparseIterator &) = delete;
153
154protected:
155 SparseIterator(IterKind kind, unsigned tid, unsigned lvl,
156 unsigned cursorValsCnt,
157 SmallVectorImpl<Value> &cursorValStorage)
158 : batchCrds(0), kind(kind), tid(tid), lvl(lvl), crd(nullptr),
159 cursorValsCnt(cursorValsCnt), cursorValsStorageRef(cursorValStorage) {};
160
161 SparseIterator(IterKind kind, unsigned cursorValsCnt,
162 SmallVectorImpl<Value> &cursorValStorage,
163 const SparseIterator &delegate)
164 : SparseIterator(kind, delegate.tid, delegate.lvl, cursorValsCnt,
165 cursorValStorage) {};
166
167 SparseIterator(IterKind kind, const SparseIterator &wrap,
168 unsigned extraCursorCnt = 0)
169 : SparseIterator(kind, wrap.tid, wrap.lvl,
170 extraCursorCnt + wrap.cursorValsCnt,
171 wrap.cursorValsStorageRef) {
172 assert(wrap.cursorValsCnt == wrap.cursorValsStorageRef.size());
173 cursorValsStorageRef.append(NumInputs: extraCursorCnt, Elt: nullptr);
174 assert(cursorValsStorageRef.size() == wrap.cursorValsCnt + extraCursorCnt);
175 };
176
177public:
178 virtual ~SparseIterator() = default;
179
180 void setSparseEmitStrategy(SparseEmitStrategy strategy) {
181 emitStrategy = strategy;
182 }
183
184 virtual std::string getDebugInterfacePrefix() const = 0;
185 virtual SmallVector<Type> getCursorValTypes(OpBuilder &b) const = 0;
186
187 Value getCrd() const { return crd; }
188 ValueRange getBatchCrds() const { return batchCrds; }
189 ValueRange getCursor() const {
190 return ValueRange(cursorValsStorageRef).take_front(n: cursorValsCnt);
191 };
192
193 // Sets the iterate to the specified position.
194 void seek(ValueRange vals) {
195 assert(vals.size() == cursorValsCnt);
196 std::copy(first: vals.begin(), last: vals.end(), result: cursorValsStorageRef.begin());
197 // Now that the iterator is re-positioned, the coordinate becomes invalid.
198 crd = nullptr;
199 }
200
201 // Reconstructs a iteration space directly from the provided ValueRange.
202 static std::unique_ptr<SparseIterator>
203 fromValues(IteratorType dstTp, ValueRange values, unsigned tid);
204
205 // The inverse operation of `fromValues`.
206 SmallVector<Value> toValues() const { llvm_unreachable("Not implemented"); }
207
208 //
209 // Iterator properties.
210 //
211
212 // Whether the iterator is a iterator over a batch level.
213 virtual bool isBatchIterator() const = 0;
214
215 // Whether the iterator support random access (i.e., support look up by
216 // *coordinate*). A random access iterator must also traverses a dense space.
217 virtual bool randomAccessible() const = 0;
218
219 // Whether the iterator can simply traversed by a for loop.
220 virtual bool iteratableByFor() const { return false; };
221
222 // Get the upper bound of the sparse space that the iterator might visited. A
223 // sparse space is a subset of a dense space [0, bound), this function returns
224 // *bound*.
225 virtual Value upperBound(OpBuilder &b, Location l) const = 0;
226
227 // Serializes and deserializes the current status to/from a set of values. The
228 // ValueRange should contain values that are sufficient to recover the current
229 // iterating postion (i.e., itVals) as well as loop bound.
230 //
231 // Not every type of iterator supports the operations, e.g., non-empty
232 // subsection iterator does not because the the number of non-empty
233 // subsections can not be determined easily.
234 //
235 // NOTE: All the values should have index type.
236 virtual SmallVector<Value> serialize() const {
237 llvm_unreachable("unsupported");
238 };
239 virtual void deserialize(ValueRange vs) { llvm_unreachable("unsupported"); };
240
241 //
242 // Core functions.
243 //
244
245 // Initializes the iterator according to the parent iterator's state.
246 void genInit(OpBuilder &b, Location l, const SparseIterator *p);
247
248 // Forwards the iterator to the next element.
249 ValueRange forward(OpBuilder &b, Location l);
250
251 // Locate the iterator to the position specified by *crd*, this can only
252 // be done on an iterator that supports randm access.
253 void locate(OpBuilder &b, Location l, Value crd);
254
255 // Returns a boolean value that equals `!it.end()`
256 Value genNotEnd(OpBuilder &b, Location l);
257
258 // Dereferences the iterator, loads the coordinate at the current position.
259 //
260 // The method assumes that the iterator is not currently exhausted (i.e.,
261 // it != it.end()).
262 Value deref(OpBuilder &b, Location l);
263
264 // Actual Implementation provided by derived class.
265 virtual void genInitImpl(OpBuilder &, Location, const SparseIterator *) = 0;
266 virtual ValueRange forwardImpl(OpBuilder &b, Location l) = 0;
267 virtual void locateImpl(OpBuilder &b, Location l, Value crd) {
268 llvm_unreachable("Unsupported");
269 }
270 virtual Value genNotEndImpl(OpBuilder &b, Location l) = 0;
271 virtual Value derefImpl(OpBuilder &b, Location l) = 0;
272 // Gets the ValueRange that together specifies the current position of the
273 // iterator. For a unique level, the position can be a single index points to
274 // the current coordinate being visited. For a non-unique level, an extra
275 // index for the `segment high` is needed to to specifies the range of
276 // duplicated coordinates. The ValueRange should be able to uniquely identify
277 // the sparse range for the next level. See SparseTensorLevel::peekRangeAt();
278 //
279 // Not every type of iterator supports the operation, e.g., non-empty
280 // subsection iterator does not because it represent a range of coordinates
281 // instead of just one.
282 virtual ValueRange getCurPosition() const { return getCursor(); };
283
284 // Returns a pair of values for *upper*, *lower* bound respectively.
285 virtual std::pair<Value, Value> genForCond(OpBuilder &b, Location l) {
286 assert(randomAccessible());
287 // Random-access iterator is traversed by coordinate, i.e., [curCrd, UB).
288 return {getCrd(), upperBound(b, l)};
289 }
290
291 // Generates a bool value for scf::ConditionOp.
292 std::pair<Value, ValueRange> genWhileCond(OpBuilder &b, Location l,
293 ValueRange vs) {
294 ValueRange rem = linkNewScope(pos: vs);
295 return std::make_pair(x: genNotEnd(b, l), y&: rem);
296 }
297
298 // Generate a conditional it.next() in the following form
299 //
300 // if (cond)
301 // yield it.next
302 // else
303 // yield it
304 //
305 // The function is virtual to allow alternative implementation. For example,
306 // if it.next() is trivial to compute, we can use a select operation instead.
307 // E.g.,
308 //
309 // it = select cond ? it+1 : it
310 virtual ValueRange forwardIf(OpBuilder &b, Location l, Value cond);
311
312 // Update the SSA value for the iterator after entering a new scope.
313 ValueRange linkNewScope(ValueRange pos) {
314 assert(!randomAccessible() && "random accessible iterators are traversed "
315 "by coordinate, call locate() instead.");
316 seek(vals: pos.take_front(n: cursorValsCnt));
317 return pos.drop_front(n: cursorValsCnt);
318 };
319
320protected:
321 void updateCrd(Value crd) { this->crd = crd; }
322
323 MutableArrayRef<Value> getMutCursorVals() {
324 MutableArrayRef<Value> ref = cursorValsStorageRef;
325 return ref.take_front(N: cursorValsCnt);
326 }
327
328 void inherentBatch(const SparseIterator &parent) {
329 batchCrds = parent.batchCrds;
330 }
331
332 SparseEmitStrategy emitStrategy;
333 SmallVector<Value> batchCrds;
334
335public:
336 const IterKind kind; // For LLVM-style RTTI.
337 const unsigned tid, lvl; // tensor level identifier.
338
339private:
340 Value crd; // The sparse coordinate used to coiterate;
341
342 // A range of value that together defines the current state of the
343 // iterator. Only loop variants should be included.
344 //
345 // For trivial iterators, it is the position; for dedup iterators, it consists
346 // of the positon and the segment high, for non-empty subsection iterator, it
347 // is the metadata that specifies the subsection.
348 // Note that the wrapped iterator shares the same storage to maintain itVals
349 // with it wrapper, which means the wrapped iterator might only own a subset
350 // of all the values stored in itValStorage.
351 const unsigned cursorValsCnt;
352 SmallVectorImpl<Value> &cursorValsStorageRef;
353};
354
355/// Helper function to create a TensorLevel object from given `tensor`.
356std::unique_ptr<SparseTensorLevel> makeSparseTensorLevel(OpBuilder &b,
357 Location l, Value t,
358 unsigned tid,
359 Level lvl);
360
361/// Helper function to create a TensorLevel object from given ValueRange.
362std::unique_ptr<SparseTensorLevel> makeSparseTensorLevel(LevelType lt, Value sz,
363 ValueRange buffers,
364 unsigned tid, Level l);
365
366/// Helper function to create a simple SparseIterator object that iterate
367/// over the entire iteration space.
368std::unique_ptr<SparseIterator>
369makeSimpleIterator(OpBuilder &b, Location l,
370 const SparseIterationSpace &iterSpace);
371
372/// Helper function to create a simple SparseIterator object that iterate
373/// over the sparse tensor level.
374/// TODO: switch to `SparseIterationSpace` (which support N-D iterator) when
375/// feature complete.
376std::unique_ptr<SparseIterator> makeSimpleIterator(
377 const SparseTensorLevel &stl,
378 SparseEmitStrategy strategy = SparseEmitStrategy::kFunctional);
379
380/// Helper function to create a synthetic SparseIterator object that iterates
381/// over a dense space specified by [0,`sz`).
382std::pair<std::unique_ptr<SparseTensorLevel>, std::unique_ptr<SparseIterator>>
383makeSynLevelAndIterator(Value sz, unsigned tid, unsigned lvl,
384 SparseEmitStrategy strategy);
385
386/// Helper function to create a SparseIterator object that iterates over a
387/// sliced space, the orignal space (before slicing) is traversed by `sit`.
388std::unique_ptr<SparseIterator>
389makeSlicedLevelIterator(std::unique_ptr<SparseIterator> &&sit, Value offset,
390 Value stride, Value size, SparseEmitStrategy strategy);
391
392/// Helper function to create a SparseIterator object that iterates over a
393/// padded sparse level (the padded value must be zero).
394std::unique_ptr<SparseIterator>
395makePaddedIterator(std::unique_ptr<SparseIterator> &&sit, Value padLow,
396 Value padHigh, SparseEmitStrategy strategy);
397
398/// Helper function to create a SparseIterator object that iterate over the
399/// non-empty subsections set.
400std::unique_ptr<SparseIterator> makeNonEmptySubSectIterator(
401 OpBuilder &b, Location l, const SparseIterator *parent, Value loopBound,
402 std::unique_ptr<SparseIterator> &&delegate, Value size, unsigned stride,
403 SparseEmitStrategy strategy);
404
405/// Helper function to create a SparseIterator object that iterates over a
406/// non-empty subsection created by NonEmptySubSectIterator.
407std::unique_ptr<SparseIterator> makeTraverseSubSectIterator(
408 OpBuilder &b, Location l, const SparseIterator &subsectIter,
409 const SparseIterator &parent, std::unique_ptr<SparseIterator> &&wrap,
410 Value loopBound, unsigned stride, SparseEmitStrategy strategy);
411
412} // namespace sparse_tensor
413} // namespace mlir
414
415#endif // MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_UTILS_SPARSETENSORITERATOR_H_
416

// Source: mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.h