1//===- SparseTensorIterator.h ---------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#ifndef MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_UTILS_SPARSETENSORITERATOR_H_
10#define MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_UTILS_SPARSETENSORITERATOR_H_
11
12#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
13#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
14
15namespace mlir {
16namespace sparse_tensor {
17
18/// The base class for all types of sparse tensor levels. It provides interfaces
19/// to query the loop range (see `peekRangeAt`) and look up the coordinates (see
20/// `peekCrdAt`).
21class SparseTensorLevel {
22 SparseTensorLevel(SparseTensorLevel &&) = delete;
23 SparseTensorLevel(const SparseTensorLevel &) = delete;
24 SparseTensorLevel &operator=(SparseTensorLevel &&) = delete;
25 SparseTensorLevel &operator=(const SparseTensorLevel &) = delete;
26
27public:
28 virtual ~SparseTensorLevel() = default;
29
30 std::string toString() const {
31 return std::string(toMLIRString(lt)) + "[" + std::to_string(val: tid) + "," +
32 std::to_string(val: lvl) + "]";
33 }
34
35 virtual Value peekCrdAt(OpBuilder &b, Location l, ValueRange batchPrefix,
36 Value iv) const = 0;
37
38 /// Peeks the lower and upper bound to *fully* traverse the level with
39 /// the given position `p` that the immediate parent level is current at.
40 /// Returns a pair of values for *posLo* and *loopHi* respectively.
41 ///
42 /// For a dense level, the *posLo* is the linearized position at beginning,
43 /// while *loopHi* is the largest *coordinate*, it also implies that the
44 /// smallest *coordinate* to start the loop is 0.
45 ///
46 /// For a sparse level, [posLo, loopHi) specifies the range of index pointer
47 /// to load coordinate from the coordinate buffer.
48 ///
49 /// `bound` is only used when the level is `non-unique` and deduplication is
50 /// required. It specifies the max upper bound of the non-unique segment.
51 virtual std::pair<Value, Value> peekRangeAt(OpBuilder &b, Location l,
52 ValueRange batchPrefix, Value p,
53 Value segHi = Value()) const = 0;
54
55 Level getLevel() const { return lvl; }
56 LevelType getLT() const { return lt; }
57 Value getSize() const { return lvlSize; }
58 virtual ValueRange getLvlBuffers() const = 0;
59
60 //
61 // Level properties
62 //
63 bool isUnique() const { return isUniqueLT(lt); }
64
65protected:
66 SparseTensorLevel(unsigned tid, unsigned lvl, LevelType lt, Value lvlSize)
67 : tid(tid), lvl(lvl), lt(lt), lvlSize(lvlSize){};
68
69public:
70 const unsigned tid, lvl;
71 const LevelType lt;
72 const Value lvlSize;
73};
74
/// The kind of a SparseIterator, used for LLVM-style RTTI dispatch
/// (stored in SparseIterator::kind).
enum class IterKind : uint8_t {
  kTrivial,         // Plain iterator over a single sparse tensor level.
  kDedup,           // Deduplicating iterator; its cursor also tracks the
                    // segment high for non-unique levels.
  kSubSect,         // Traverses one subsection (see
                    // makeTraverseSubSectIterator).
  kNonEmptySubSect, // Iterates over the set of non-empty subsections.
  kFilter,          // Presumably the sliced-space iterator created by
                    // makeSlicedLevelIterator — confirm in the .cpp.
};
82
83/// Helper class that generates loop conditions, etc, to traverse a
84/// sparse tensor level.
85class SparseIterator {
86 SparseIterator(SparseIterator &&) = delete;
87 SparseIterator(const SparseIterator &) = delete;
88 SparseIterator &operator=(SparseIterator &&) = delete;
89 SparseIterator &operator=(const SparseIterator &) = delete;
90
91protected:
92 SparseIterator(IterKind kind, unsigned tid, unsigned lvl,
93 unsigned cursorValsCnt,
94 SmallVectorImpl<Value> &cursorValStorage)
95 : batchCrds(0), kind(kind), tid(tid), lvl(lvl), crd(nullptr),
96 cursorValsCnt(cursorValsCnt), cursorValsStorageRef(cursorValStorage){};
97
98 SparseIterator(IterKind kind, unsigned cursorValsCnt,
99 SmallVectorImpl<Value> &cursorValStorage,
100 const SparseIterator &delegate)
101 : SparseIterator(kind, delegate.tid, delegate.lvl, cursorValsCnt,
102 cursorValStorage){};
103
104 SparseIterator(IterKind kind, const SparseIterator &wrap,
105 unsigned extraCursorCnt = 0)
106 : SparseIterator(kind, wrap.tid, wrap.lvl,
107 extraCursorCnt + wrap.cursorValsCnt,
108 wrap.cursorValsStorageRef) {
109 assert(wrap.cursorValsCnt == wrap.cursorValsStorageRef.size());
110 cursorValsStorageRef.append(NumInputs: extraCursorCnt, Elt: nullptr);
111 assert(cursorValsStorageRef.size() == wrap.cursorValsCnt + extraCursorCnt);
112 };
113
114public:
115 virtual ~SparseIterator() = default;
116
117 void setSparseEmitStrategy(SparseEmitStrategy strategy) {
118 emitStrategy = strategy;
119 }
120
121 virtual std::string getDebugInterfacePrefix() const = 0;
122 virtual SmallVector<Type> getCursorValTypes(OpBuilder &b) const = 0;
123
124 Value getCrd() const { return crd; }
125 ValueRange getBatchCrds() const { return batchCrds; }
126 ValueRange getCursor() const {
127 return ValueRange(cursorValsStorageRef).take_front(n: cursorValsCnt);
128 };
129
130 // Sets the iterate to the specified position.
131 void seek(ValueRange vals) {
132 assert(vals.size() == cursorValsCnt);
133 std::copy(first: vals.begin(), last: vals.end(), result: cursorValsStorageRef.begin());
134 // Now that the iterator is re-positioned, the coordinate becomes invalid.
135 crd = nullptr;
136 }
137
138 //
139 // Iterator properties.
140 //
141
142 // Whether the iterator is a iterator over a batch level.
143 virtual bool isBatchIterator() const = 0;
144
145 // Whether the iterator support random access (i.e., support look up by
146 // *coordinate*). A random access iterator must also traverses a dense space.
147 virtual bool randomAccessible() const = 0;
148
149 // Whether the iterator can simply traversed by a for loop.
150 virtual bool iteratableByFor() const { return false; };
151
152 // Get the upper bound of the sparse space that the iterator might visited. A
153 // sparse space is a subset of a dense space [0, bound), this function returns
154 // *bound*.
155 virtual Value upperBound(OpBuilder &b, Location l) const = 0;
156
157 // Serializes and deserializes the current status to/from a set of values. The
158 // ValueRange should contain values that are sufficient to recover the current
159 // iterating postion (i.e., itVals) as well as loop bound.
160 //
161 // Not every type of iterator supports the operations, e.g., non-empty
162 // subsection iterator does not because the the number of non-empty
163 // subsections can not be determined easily.
164 //
165 // NOTE: All the values should have index type.
166 virtual SmallVector<Value> serialize() const {
167 llvm_unreachable("unsupported");
168 };
169 virtual void deserialize(ValueRange vs) { llvm_unreachable("unsupported"); };
170
171 //
172 // Core functions.
173 //
174
175 // Initializes the iterator according to the parent iterator's state.
176 void genInit(OpBuilder &b, Location l, const SparseIterator *p);
177
178 // Forwards the iterator to the next element.
179 ValueRange forward(OpBuilder &b, Location l);
180
181 // Locate the iterator to the position specified by *crd*, this can only
182 // be done on an iterator that supports randm access.
183 void locate(OpBuilder &b, Location l, Value crd);
184
185 // Returns a boolean value that equals `!it.end()`
186 Value genNotEnd(OpBuilder &b, Location l);
187
188 // Dereferences the iterator, loads the coordinate at the current position.
189 //
190 // The method assumes that the iterator is not currently exhausted (i.e.,
191 // it != it.end()).
192 Value deref(OpBuilder &b, Location l);
193
194 // Actual Implementation provided by derived class.
195 virtual void genInitImpl(OpBuilder &, Location, const SparseIterator *) = 0;
196 virtual ValueRange forwardImpl(OpBuilder &b, Location l) = 0;
197 virtual void locateImpl(OpBuilder &b, Location l, Value crd) {
198 llvm_unreachable("Unsupported");
199 }
200 virtual Value genNotEndImpl(OpBuilder &b, Location l) = 0;
201 virtual Value derefImpl(OpBuilder &b, Location l) = 0;
202 // Gets the current position and the optional *position high* (for
203 // non-unique iterators), the value is essentially the number of sparse
204 // coordinate that the iterator is current visiting. It should be able to
205 // uniquely identify the sparse range for the next level. See
206 // SparseTensorLevel::peekRangeAt();
207 //
208 // Not every type of iterator supports the operation, e.g., non-empty
209 // subsection iterator does not because it represent a range of coordinates
210 // instead of just one.
211 virtual std::pair<Value, Value> getCurPosition() const {
212 llvm_unreachable("unsupported");
213 };
214
215 // Returns a pair of values for *upper*, *lower* bound respectively.
216 virtual std::pair<Value, Value> genForCond(OpBuilder &b, Location l) {
217 assert(randomAccessible());
218 // Random-access iterator is traversed by coordinate, i.e., [curCrd, UB).
219 return {getCrd(), upperBound(b, l)};
220 }
221
222 // Generates a bool value for scf::ConditionOp.
223 std::pair<Value, ValueRange> genWhileCond(OpBuilder &b, Location l,
224 ValueRange vs) {
225 ValueRange rem = linkNewScope(pos: vs);
226 return std::make_pair(x: genNotEnd(b, l), y&: rem);
227 }
228
229 // Generate a conditional it.next() in the following form
230 //
231 // if (cond)
232 // yield it.next
233 // else
234 // yield it
235 //
236 // The function is virtual to allow alternative implementation. For example,
237 // if it.next() is trivial to compute, we can use a select operation instead.
238 // E.g.,
239 //
240 // it = select cond ? it+1 : it
241 virtual ValueRange forwardIf(OpBuilder &b, Location l, Value cond);
242
243 // Update the SSA value for the iterator after entering a new scope.
244 ValueRange linkNewScope(ValueRange pos) {
245 assert(!randomAccessible() && "random accessible iterators are traversed "
246 "by coordinate, call locate() instead.");
247 seek(vals: pos.take_front(n: cursorValsCnt));
248 return pos.drop_front(n: cursorValsCnt);
249 };
250
251protected:
252 void updateCrd(Value crd) { this->crd = crd; }
253
254 MutableArrayRef<Value> getMutCursorVals() {
255 MutableArrayRef<Value> ref = cursorValsStorageRef;
256 return ref.take_front(N: cursorValsCnt);
257 }
258
259 void inherentBatch(const SparseIterator &parent) {
260 batchCrds = parent.batchCrds;
261 }
262
263 SparseEmitStrategy emitStrategy;
264 SmallVector<Value> batchCrds;
265
266public:
267 const IterKind kind; // For LLVM-style RTTI.
268 const unsigned tid, lvl; // tensor level identifier.
269
270private:
271 Value crd; // The sparse coordinate used to coiterate;
272
273 // A range of value that together defines the current state of the
274 // iterator. Only loop variants should be included.
275 //
276 // For trivial iterators, it is the position; for dedup iterators, it consists
277 // of the positon and the segment high, for non-empty subsection iterator, it
278 // is the metadata that specifies the subsection.
279 // Note that the wrapped iterator shares the same storage to maintain itVals
280 // with it wrapper, which means the wrapped iterator might only own a subset
281 // of all the values stored in itValStorage.
282 const unsigned cursorValsCnt;
283 SmallVectorImpl<Value> &cursorValsStorageRef;
284};
285
/// Helper function to create a TensorLevel object from the given `tensor`.
std::unique_ptr<SparseTensorLevel> makeSparseTensorLevel(OpBuilder &builder,
                                                         Location loc, Value t,
                                                         unsigned tid, Level l);

/// Helper function to create a simple SparseIterator object that iterates
/// over the SparseTensorLevel.
std::unique_ptr<SparseIterator> makeSimpleIterator(const SparseTensorLevel &stl,
                                                   SparseEmitStrategy strategy);

/// Helper function to create a synthetic SparseIterator object that iterates
/// over a dense space specified by [0, `sz`).
std::pair<std::unique_ptr<SparseTensorLevel>, std::unique_ptr<SparseIterator>>
makeSynLevelAndIterator(Value sz, unsigned tid, unsigned lvl,
                        SparseEmitStrategy strategy);

/// Helper function to create a SparseIterator object that iterates over a
/// sliced space; the original space (before slicing) is traversed by `sit`.
std::unique_ptr<SparseIterator>
makeSlicedLevelIterator(std::unique_ptr<SparseIterator> &&sit, Value offset,
                        Value stride, Value size, SparseEmitStrategy strategy);

/// Helper function to create a SparseIterator object that iterates over the
/// non-empty subsections set.
std::unique_ptr<SparseIterator> makeNonEmptySubSectIterator(
    OpBuilder &b, Location l, const SparseIterator *parent, Value loopBound,
    std::unique_ptr<SparseIterator> &&delegate, Value size, unsigned stride,
    SparseEmitStrategy strategy);

/// Helper function to create a SparseIterator object that iterates over a
/// non-empty subsection created by NonEmptySubSectIterator.
std::unique_ptr<SparseIterator> makeTraverseSubSectIterator(
    OpBuilder &b, Location l, const SparseIterator &subsectIter,
    const SparseIterator &parent, std::unique_ptr<SparseIterator> &&wrap,
    Value loopBound, unsigned stride, SparseEmitStrategy strategy);
321
322} // namespace sparse_tensor
323} // namespace mlir
324
325#endif // MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_UTILS_SPARSETENSORITERATOR_H_
326

source code of mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.h