//===- SparseTensorIterator.h ---------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_UTILS_SPARSETENSORITERATOR_H_
#define MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_UTILS_SPARSETENSORITERATOR_H_

#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"

namespace mlir {
namespace sparse_tensor {

/// The base class for all types of sparse tensor levels. It provides
/// interfaces to query the loop range (see `peekRangeAt`) and to look up
/// coordinates (see `peekCrdAt`).
class SparseTensorLevel {
  SparseTensorLevel(SparseTensorLevel &&) = delete;
  SparseTensorLevel(const SparseTensorLevel &) = delete;
  SparseTensorLevel &operator=(SparseTensorLevel &&) = delete;
  SparseTensorLevel &operator=(const SparseTensorLevel &) = delete;

public:
  virtual ~SparseTensorLevel() = default;

  std::string toString() const {
    return std::string(toMLIRString(lt)) + "[" + std::to_string(tid) + "," +
           std::to_string(lvl) + "]";
  }

  virtual Value peekCrdAt(OpBuilder &b, Location l, ValueRange batchPrefix,
                          Value iv) const = 0;

  /// Peeks the lower and upper bound to *fully* traverse the level at the
  /// given position `p`, i.e., the position at which the immediate parent
  /// level currently is. Returns a pair of values for *posLo* and *loopHi*,
  /// respectively.
  ///
  /// For a dense level, *posLo* is the linearized position at the beginning,
  /// while *loopHi* is the largest *coordinate*; this also implies that the
  /// smallest *coordinate* to start the loop at is 0.
  ///
  /// For a sparse level, [posLo, loopHi) specifies the range of indices used
  /// to load coordinates from the coordinate buffer.
  ///
  /// `segHi` is only used when the level is `non-unique` and deduplication
  /// is required; it specifies the max upper bound of the non-unique segment.
  virtual std::pair<Value, Value> peekRangeAt(OpBuilder &b, Location l,
                                              ValueRange batchPrefix, Value p,
                                              Value segHi = Value()) const = 0;
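
  // A sketch of how `peekRangeAt` is typically consumed (hypothetical
  // caller-side names `builder`, `loc`, `batchPrefix`, and `p`; not part of
  // this interface):
  //
  //   auto [posLo, loopHi] = lvl.peekRangeAt(builder, loc, batchPrefix, p);
  //   // Sparse level: load coordinates from positions in [posLo, loopHi).
  //   // Dense level: coordinates run from 0 up to loopHi, starting at the
  //   // linearized position posLo.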

  Level getLevel() const { return lvl; }
  LevelType getLT() const { return lt; }
  Value getSize() const { return lvlSize; }
  virtual ValueRange getLvlBuffers() const = 0;

  //
  // Level properties
  //
  bool isUnique() const { return isUniqueLT(lt); }

protected:
  SparseTensorLevel(unsigned tid, unsigned lvl, LevelType lt, Value lvlSize)
      : tid(tid), lvl(lvl), lt(lt), lvlSize(lvlSize) {}

public:
  const unsigned tid, lvl;
  const LevelType lt;
  const Value lvlSize;
};

enum class IterKind : uint8_t {
  kTrivial,
  kDedup,
  kSubSect,
  kNonEmptySubSect,
  kFilter,
};

/// Helper class that generates loop conditions, etc., to traverse a
/// sparse tensor level.
class SparseIterator {
  SparseIterator(SparseIterator &&) = delete;
  SparseIterator(const SparseIterator &) = delete;
  SparseIterator &operator=(SparseIterator &&) = delete;
  SparseIterator &operator=(const SparseIterator &) = delete;

protected:
  SparseIterator(IterKind kind, unsigned tid, unsigned lvl,
                 unsigned cursorValsCnt,
                 SmallVectorImpl<Value> &cursorValStorage)
      : batchCrds(0), kind(kind), tid(tid), lvl(lvl), crd(nullptr),
        cursorValsCnt(cursorValsCnt), cursorValsStorageRef(cursorValStorage) {}

  SparseIterator(IterKind kind, unsigned cursorValsCnt,
                 SmallVectorImpl<Value> &cursorValStorage,
                 const SparseIterator &delegate)
      : SparseIterator(kind, delegate.tid, delegate.lvl, cursorValsCnt,
                       cursorValStorage) {}

  SparseIterator(IterKind kind, const SparseIterator &wrap,
                 unsigned extraCursorCnt = 0)
      : SparseIterator(kind, wrap.tid, wrap.lvl,
                       extraCursorCnt + wrap.cursorValsCnt,
                       wrap.cursorValsStorageRef) {
    assert(wrap.cursorValsCnt == wrap.cursorValsStorageRef.size());
    cursorValsStorageRef.append(extraCursorCnt, nullptr);
    assert(cursorValsStorageRef.size() == wrap.cursorValsCnt + extraCursorCnt);
  }

public:
  virtual ~SparseIterator() = default;

  void setSparseEmitStrategy(SparseEmitStrategy strategy) {
    emitStrategy = strategy;
  }

  virtual std::string getDebugInterfacePrefix() const = 0;
  virtual SmallVector<Type> getCursorValTypes(OpBuilder &b) const = 0;

  Value getCrd() const { return crd; }
  ValueRange getBatchCrds() const { return batchCrds; }
  ValueRange getCursor() const {
    return ValueRange(cursorValsStorageRef).take_front(cursorValsCnt);
  }

  // Sets the iterator to the specified position.
  void seek(ValueRange vals) {
    assert(vals.size() == cursorValsCnt);
    std::copy(vals.begin(), vals.end(), cursorValsStorageRef.begin());
    // Now that the iterator is re-positioned, the coordinate becomes invalid.
    crd = nullptr;
  }

  //
  // Iterator properties.
  //

  // Whether the iterator is an iterator over a batch level.
  virtual bool isBatchIterator() const = 0;

  // Whether the iterator supports random access (i.e., look up by
  // *coordinate*). A random-access iterator must also traverse a dense space.
  virtual bool randomAccessible() const = 0;

  // Whether the iterator can simply be traversed by a for loop.
  virtual bool iteratableByFor() const { return false; }

  // Gets the upper bound of the sparse space that the iterator might visit. A
  // sparse space is a subset of a dense space [0, bound); this function
  // returns *bound*.
  virtual Value upperBound(OpBuilder &b, Location l) const = 0;

  // Serializes and deserializes the current status to/from a set of values.
  // The ValueRange should contain values that are sufficient to recover the
  // current iterating position (i.e., the cursor values) as well as the loop
  // bound.
  //
  // Not every type of iterator supports these operations; e.g., the non-empty
  // subsection iterator does not, because the number of non-empty subsections
  // cannot be determined easily.
  //
  // NOTE: All the values should have index type.
  virtual SmallVector<Value> serialize() const {
    llvm_unreachable("unsupported");
  }
  virtual void deserialize(ValueRange vs) { llvm_unreachable("unsupported"); }
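
  // A sketch of a serialize/deserialize round trip across a region boundary
  // (hypothetical caller-side code; `regionArgs` is an assumption):
  //
  //   SmallVector<Value> vals = it.serialize(); // capture state as values
  //   // ... pass `vals` into the new region as block arguments ...
  //   it.deserialize(regionArgs);               // restore the state inside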

  //
  // Core functions.
  //

  // Initializes the iterator according to the parent iterator's state.
  void genInit(OpBuilder &b, Location l, const SparseIterator *p);

  // Forwards the iterator to the next element.
  ValueRange forward(OpBuilder &b, Location l);

  // Locates the iterator at the position specified by *crd*; this can only
  // be done on an iterator that supports random access.
  void locate(OpBuilder &b, Location l, Value crd);

  // Returns a boolean value that equals `!it.end()`.
  Value genNotEnd(OpBuilder &b, Location l);

  // Dereferences the iterator, loading the coordinate at the current
  // position.
  //
  // The method assumes that the iterator is not currently exhausted (i.e.,
  // it != it.end()).
  Value deref(OpBuilder &b, Location l);
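
  // Taken together, the core functions above support a traversal pattern
  // along the lines of the following sketch (the while loop is conceptual;
  // the emitted IR uses scf.while, and `b`, `l`, and `parent` are assumed
  // caller-side values):
  //
  //   it.genInit(b, l, parent);      // position at the first element
  //   while (it.genNotEnd(b, l)) {   // continue while not exhausted
  //     Value crd = it.deref(b, l);  // load the current coordinate
  //     // ... use crd ...
  //     it.forward(b, l);            // advance to the next element
  //   }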

  // Actual implementation provided by derived classes.
  virtual void genInitImpl(OpBuilder &, Location, const SparseIterator *) = 0;
  virtual ValueRange forwardImpl(OpBuilder &b, Location l) = 0;
  virtual void locateImpl(OpBuilder &b, Location l, Value crd) {
    llvm_unreachable("Unsupported");
  }
  virtual Value genNotEndImpl(OpBuilder &b, Location l) = 0;
  virtual Value derefImpl(OpBuilder &b, Location l) = 0;
  // Gets the current position and the optional *position high* (for
  // non-unique iterators); the value is essentially the position of the
  // sparse coordinate that the iterator is currently visiting. It should
  // uniquely identify the sparse range for the next level. See
  // SparseTensorLevel::peekRangeAt().
  //
  // Not every type of iterator supports the operation; e.g., the non-empty
  // subsection iterator does not, because it represents a range of
  // coordinates instead of just one.
  virtual std::pair<Value, Value> getCurPosition() const {
    llvm_unreachable("unsupported");
  }

  // Returns a pair of values for the *lower* and *upper* bound, respectively.
  virtual std::pair<Value, Value> genForCond(OpBuilder &b, Location l) {
    assert(randomAccessible());
    // A random-access iterator is traversed by coordinate, i.e., [curCrd, UB).
    return {getCrd(), upperBound(b, l)};
  }

  // Generates a bool value for scf::ConditionOp.
  std::pair<Value, ValueRange> genWhileCond(OpBuilder &b, Location l,
                                            ValueRange vs) {
    ValueRange rem = linkNewScope(vs);
    return std::make_pair(genNotEnd(b, l), rem);
  }

  // Generates a conditional it.next() in the following form:
  //
  //   if (cond)
  //     yield it.next
  //   else
  //     yield it
  //
  // The function is virtual to allow alternative implementations. For
  // example, if it.next() is trivial to compute, we can use a select
  // operation instead. E.g.,
  //
  //   it = select cond ? it+1 : it
  virtual ValueRange forwardIf(OpBuilder &b, Location l, Value cond);

  // Updates the SSA values for the iterator after entering a new scope.
  ValueRange linkNewScope(ValueRange pos) {
    assert(!randomAccessible() && "random accessible iterators are traversed "
                                  "by coordinate, call locate() instead.");
    seek(pos.take_front(cursorValsCnt));
    return pos.drop_front(cursorValsCnt);
  }
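
  // A sketch of how `genWhileCond` and `linkNewScope` cooperate when emitting
  // an scf.while (hypothetical caller-side code; `before` names the "before"
  // block of the loop):
  //
  //   ValueRange args = before->getArguments();
  //   auto [cond, rem] = it.genWhileCond(b, l, args); // re-links, then tests
  //   b.create<scf::ConditionOp>(l, cond, args);      // `rem` holds leftover
  //                                                   // values, if any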

protected:
  void updateCrd(Value crd) { this->crd = crd; }

  MutableArrayRef<Value> getMutCursorVals() {
    MutableArrayRef<Value> ref = cursorValsStorageRef;
    return ref.take_front(cursorValsCnt);
  }

  void inherentBatch(const SparseIterator &parent) {
    batchCrds = parent.batchCrds;
  }

  SparseEmitStrategy emitStrategy;
  SmallVector<Value> batchCrds;

public:
  const IterKind kind;     // For LLVM-style RTTI.
  const unsigned tid, lvl; // Tensor and level identifiers.

private:
  Value crd; // The sparse coordinate used to coiterate.

  // A range of values that together define the current state of the
  // iterator. Only loop-variant values should be included.
  //
  // For trivial iterators, it is the position; for dedup iterators, it
  // consists of the position and the segment high; for the non-empty
  // subsection iterator, it is the metadata that specifies the subsection.
  // Note that a wrapped iterator shares the same cursor storage with its
  // wrapper, which means the wrapped iterator might only own a subset of
  // all the values stored in that storage.
  const unsigned cursorValsCnt;
  SmallVectorImpl<Value> &cursorValsStorageRef;
};

/// Helper function to create a TensorLevel object from the given tensor `t`.
std::unique_ptr<SparseTensorLevel> makeSparseTensorLevel(OpBuilder &builder,
                                                         Location loc, Value t,
                                                         unsigned tid, Level l);

/// Helper function to create a simple SparseIterator object that iterates
/// over the SparseTensorLevel.
std::unique_ptr<SparseIterator> makeSimpleIterator(const SparseTensorLevel &stl,
                                                   SparseEmitStrategy strategy);
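
/// A sketch combining the two helpers above (hypothetical caller-side values
/// `builder`, `loc`, and `tensor`; the strategy choice is an assumption):
///
///   auto stl = makeSparseTensorLevel(builder, loc, tensor, /*tid=*/0,
///                                    /*l=*/0);
///   auto it = makeSimpleIterator(*stl, SparseEmitStrategy::kFunctional);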

/// Helper function to create a synthetic SparseIterator object that iterates
/// over a dense space specified by [0, `sz`).
std::pair<std::unique_ptr<SparseTensorLevel>, std::unique_ptr<SparseIterator>>
makeSynLevelAndIterator(Value sz, unsigned tid, unsigned lvl,
                        SparseEmitStrategy strategy);

/// Helper function to create a SparseIterator object that iterates over a
/// sliced space; the original space (before slicing) is traversed by `sit`.
std::unique_ptr<SparseIterator>
makeSlicedLevelIterator(std::unique_ptr<SparseIterator> &&sit, Value offset,
                        Value stride, Value size, SparseEmitStrategy strategy);

/// Helper function to create a SparseIterator object that iterates over the
/// set of non-empty subsections.
std::unique_ptr<SparseIterator> makeNonEmptySubSectIterator(
    OpBuilder &b, Location l, const SparseIterator *parent, Value loopBound,
    std::unique_ptr<SparseIterator> &&delegate, Value size, unsigned stride,
    SparseEmitStrategy strategy);

/// Helper function to create a SparseIterator object that iterates over a
/// non-empty subsection created by NonEmptySubSectIterator.
std::unique_ptr<SparseIterator> makeTraverseSubSectIterator(
    OpBuilder &b, Location l, const SparseIterator &subsectIter,
    const SparseIterator &parent, std::unique_ptr<SparseIterator> &&wrap,
    Value loopBound, unsigned stride, SparseEmitStrategy strategy);

} // namespace sparse_tensor
} // namespace mlir

#endif // MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_UTILS_SPARSETENSORITERATOR_H_