1 | //===- SparseTensorIterator.h ---------------------------------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #ifndef MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_UTILS_SPARSETENSORITERATOR_H_ |
10 | #define MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_UTILS_SPARSETENSORITERATOR_H_ |
11 | |
12 | #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" |
13 | #include "mlir/Dialect/SparseTensor/Transforms/Passes.h" |
14 | |
15 | namespace mlir { |
16 | namespace sparse_tensor { |
17 | |
18 | // Forward declaration. |
19 | class SparseIterator; |
20 | |
21 | /// The base class for all types of sparse tensor levels. It provides interfaces |
22 | /// to query the loop range (see `peekRangeAt`) and look up the coordinates (see |
23 | /// `peekCrdAt`). |
24 | class SparseTensorLevel { |
25 | SparseTensorLevel(SparseTensorLevel &&) = delete; |
26 | SparseTensorLevel(const SparseTensorLevel &) = delete; |
27 | SparseTensorLevel &operator=(SparseTensorLevel &&) = delete; |
28 | SparseTensorLevel &operator=(const SparseTensorLevel &) = delete; |
29 | |
30 | public: |
31 | virtual ~SparseTensorLevel() = default; |
32 | |
33 | std::string toString() const { |
    return std::string(toMLIRString(lt)) + "[" + std::to_string(tid) + "," +
           std::to_string(lvl) + "]";
36 | } |
37 | |
38 | virtual Value peekCrdAt(OpBuilder &b, Location l, ValueRange batchPrefix, |
39 | Value iv) const = 0; |
40 | |
  /// Peeks the lower and upper bound to *fully* traverse the level with the
  /// given position `parentPos` (see SparseIterator::getCurPosition()) that
  /// the immediate parent level is currently at. Returns a pair of values for
  /// *posLo* and *loopHi* respectively.
  ///
  /// For a dense level, *posLo* is the linearized position at the beginning,
  /// while *loopHi* is the largest *coordinate*; this also implies that the
  /// smallest *coordinate* to start the loop at is 0.
  ///
  /// For a sparse level, [posLo, loopHi) specifies the range of index pointers
  /// used to load coordinates from the coordinate buffer.
52 | virtual std::pair<Value, Value> |
53 | peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix, |
54 | ValueRange parentPos, Value inPadZone = nullptr) const = 0; |
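
  // Illustrative sketch (hypothetical caller code, not part of this
  // interface): for a compressed level, the returned pair typically drives a
  // position loop over the coordinate buffer, assuming a builder `b`,
  // location `loc`, and the parent's `batchPrefix`/`parentPos` are at hand:
  //
  //   auto [posLo, posHi] = lvl.peekRangeAt(b, loc, batchPrefix, parentPos);
  //   // emit: scf.for %p = %posLo to %posHi step %c1 { ... }
  //   // and inside the loop body:
  //   Value crd = lvl.peekCrdAt(b, loc, batchPrefix, p);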
55 | |
56 | virtual std::pair<Value, Value> |
57 | collapseRangeBetween(OpBuilder &b, Location l, ValueRange batchPrefix, |
58 | std::pair<Value, Value> parentRange) const { |
59 | llvm_unreachable("Not Implemented" ); |
60 | }; |
61 | |
62 | Level getLevel() const { return lvl; } |
63 | LevelType getLT() const { return lt; } |
64 | Value getSize() const { return lvlSize; } |
65 | virtual ValueRange getLvlBuffers() const = 0; |
66 | |
67 | // |
68 | // Level properties |
69 | // |
70 | bool isUnique() const { return isUniqueLT(lt); } |
71 | |
72 | protected: |
73 | SparseTensorLevel(unsigned tid, unsigned lvl, LevelType lt, Value lvlSize) |
74 | : tid(tid), lvl(lvl), lt(lt), lvlSize(lvlSize) {}; |
75 | |
76 | public: |
77 | const unsigned tid, lvl; |
78 | const LevelType lt; |
79 | const Value lvlSize; |
80 | }; |
81 | |
82 | enum class IterKind : uint8_t { |
83 | kTrivial, |
84 | kDedup, |
85 | kSubSect, |
86 | kNonEmptySubSect, |
87 | kFilter, |
88 | kPad, |
89 | }; |
90 | |
91 | /// A `SparseIterationSpace` represents a sparse set of coordinates defined by |
92 | /// (possibly multiple) levels of a specific sparse tensor. |
93 | /// TODO: remove `SparseTensorLevel` and switch to SparseIterationSpace when |
94 | /// feature complete. |
95 | class SparseIterationSpace { |
96 | public: |
97 | SparseIterationSpace() = default; |
98 | SparseIterationSpace(SparseIterationSpace &) = delete; |
99 | SparseIterationSpace(SparseIterationSpace &&) = default; |
100 | |
  // Constructs an N-D iteration space.
102 | SparseIterationSpace(Location loc, OpBuilder &b, Value t, unsigned tid, |
103 | std::pair<Level, Level> lvlRange, ValueRange parentPos); |
104 | |
105 | // Constructs a 1-D iteration space. |
106 | SparseIterationSpace(Location loc, OpBuilder &b, Value t, unsigned tid, |
107 | Level lvl, ValueRange parentPos) |
108 | : SparseIterationSpace(loc, b, t, tid, {lvl, lvl + 1}, parentPos) {}; |
109 | |
110 | bool isUnique() const { return lvls.back()->isUnique(); } |
111 | |
112 | unsigned getSpaceDim() const { return lvls.size(); } |
113 | |
  // Reconstructs an iteration space directly from the provided ValueRange.
115 | static SparseIterationSpace fromValues(IterSpaceType dstTp, ValueRange values, |
116 | unsigned tid); |
117 | |
118 | // The inverse operation of `fromValues`. |
119 | SmallVector<Value> toValues() const { |
120 | SmallVector<Value> vals; |
121 | for (auto &stl : lvls) { |
122 | llvm::append_range(C&: vals, R: stl->getLvlBuffers()); |
123 | vals.push_back(Elt: stl->getSize()); |
124 | } |
125 | vals.append(IL: {bound.first, bound.second}); |
126 | return vals; |
127 | } |
128 | |
129 | const SparseTensorLevel &getLastLvl() const { return *lvls.back(); } |
130 | ArrayRef<std::unique_ptr<SparseTensorLevel>> getLvlRef() const { |
131 | return lvls; |
132 | } |
133 | |
134 | Value getBoundLo() const { return bound.first; } |
135 | Value getBoundHi() const { return bound.second; } |
136 | |
  // Extracts an iterator to iterate over the sparse iteration space.
  std::unique_ptr<SparseIterator> extractIterator(OpBuilder &b,
                                                  Location l) const;
140 | |
141 | private: |
142 | SmallVector<std::unique_ptr<SparseTensorLevel>> lvls; |
143 | std::pair<Value, Value> bound; |
144 | }; |
145 | |
/// Helper class that generates loop conditions, etc., to traverse a
147 | /// sparse tensor level. |
148 | class SparseIterator { |
149 | SparseIterator(SparseIterator &&) = delete; |
150 | SparseIterator(const SparseIterator &) = delete; |
151 | SparseIterator &operator=(SparseIterator &&) = delete; |
152 | SparseIterator &operator=(const SparseIterator &) = delete; |
153 | |
154 | protected: |
155 | SparseIterator(IterKind kind, unsigned tid, unsigned lvl, |
156 | unsigned cursorValsCnt, |
157 | SmallVectorImpl<Value> &cursorValStorage) |
158 | : batchCrds(0), kind(kind), tid(tid), lvl(lvl), crd(nullptr), |
159 | cursorValsCnt(cursorValsCnt), cursorValsStorageRef(cursorValStorage) {}; |
160 | |
161 | SparseIterator(IterKind kind, unsigned cursorValsCnt, |
162 | SmallVectorImpl<Value> &cursorValStorage, |
163 | const SparseIterator &delegate) |
164 | : SparseIterator(kind, delegate.tid, delegate.lvl, cursorValsCnt, |
165 | cursorValStorage) {}; |
166 | |
167 | SparseIterator(IterKind kind, const SparseIterator &wrap, |
168 | unsigned = 0) |
169 | : SparseIterator(kind, wrap.tid, wrap.lvl, |
170 | extraCursorCnt + wrap.cursorValsCnt, |
171 | wrap.cursorValsStorageRef) { |
172 | assert(wrap.cursorValsCnt == wrap.cursorValsStorageRef.size()); |
    cursorValsStorageRef.append(extraCursorCnt, nullptr);
174 | assert(cursorValsStorageRef.size() == wrap.cursorValsCnt + extraCursorCnt); |
175 | }; |
176 | |
177 | public: |
178 | virtual ~SparseIterator() = default; |
179 | |
180 | void setSparseEmitStrategy(SparseEmitStrategy strategy) { |
181 | emitStrategy = strategy; |
182 | } |
183 | |
184 | virtual std::string getDebugInterfacePrefix() const = 0; |
185 | virtual SmallVector<Type> getCursorValTypes(OpBuilder &b) const = 0; |
186 | |
187 | Value getCrd() const { return crd; } |
188 | ValueRange getBatchCrds() const { return batchCrds; } |
189 | ValueRange getCursor() const { |
    return ValueRange(cursorValsStorageRef).take_front(cursorValsCnt);
191 | }; |
192 | |
  // Sets the iterator to the specified position.
194 | void seek(ValueRange vals) { |
195 | assert(vals.size() == cursorValsCnt); |
    std::copy(vals.begin(), vals.end(), cursorValsStorageRef.begin());
197 | // Now that the iterator is re-positioned, the coordinate becomes invalid. |
198 | crd = nullptr; |
199 | } |
200 | |
  // Reconstructs an iterator directly from the provided ValueRange.
202 | static std::unique_ptr<SparseIterator> |
203 | fromValues(IteratorType dstTp, ValueRange values, unsigned tid); |
204 | |
205 | // The inverse operation of `fromValues`. |
206 | SmallVector<Value> toValues() const { llvm_unreachable("Not implemented" ); } |
207 | |
208 | // |
209 | // Iterator properties. |
210 | // |
211 | |
  // Whether the iterator is an iterator over a batch level.
213 | virtual bool isBatchIterator() const = 0; |
214 | |
  // Whether the iterator supports random access (i.e., supports look up by
  // *coordinate*). A random access iterator must also traverse a dense space.
217 | virtual bool randomAccessible() const = 0; |
218 | |
  // Whether the iterator can simply be traversed by a for loop.
220 | virtual bool iteratableByFor() const { return false; }; |
221 | |
  // Gets the upper bound of the sparse space that the iterator might visit. A
  // sparse space is a subset of a dense space [0, bound); this function
  // returns *bound*.
225 | virtual Value upperBound(OpBuilder &b, Location l) const = 0; |
226 | |
  // Serializes and deserializes the current state to/from a set of values.
  // The ValueRange should contain values that are sufficient to recover the
  // current iterating position (i.e., itVals) as well as the loop bound.
  //
  // Not every type of iterator supports these operations, e.g., the non-empty
  // subsection iterator does not because the number of non-empty subsections
  // cannot be determined easily.
234 | // |
235 | // NOTE: All the values should have index type. |
236 | virtual SmallVector<Value> serialize() const { |
237 | llvm_unreachable("unsupported" ); |
238 | }; |
239 | virtual void deserialize(ValueRange vs) { llvm_unreachable("unsupported" ); }; |
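
  // Illustrative sketch (hypothetical emitter code): serialized values are
  // typically passed as loop-carried arguments and the state restored inside
  // the new region, assuming the iterator kind supports serialization:
  //
  //   SmallVector<Value> state = it->serialize();
  //   // ... forward `state` as iter_args of an scf.for/scf.while ...
  //   it->deserialize(regionArgs.take_front(state.size()));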
240 | |
241 | // |
242 | // Core functions. |
243 | // |
244 | |
245 | // Initializes the iterator according to the parent iterator's state. |
246 | void genInit(OpBuilder &b, Location l, const SparseIterator *p); |
247 | |
248 | // Forwards the iterator to the next element. |
249 | ValueRange forward(OpBuilder &b, Location l); |
250 | |
  // Locates the iterator at the position specified by *crd*; this can only
  // be done on an iterator that supports random access.
253 | void locate(OpBuilder &b, Location l, Value crd); |
254 | |
255 | // Returns a boolean value that equals `!it.end()` |
256 | Value genNotEnd(OpBuilder &b, Location l); |
257 | |
258 | // Dereferences the iterator, loads the coordinate at the current position. |
259 | // |
260 | // The method assumes that the iterator is not currently exhausted (i.e., |
261 | // it != it.end()). |
262 | Value deref(OpBuilder &b, Location l); |
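
  // Illustrative sketch (hypothetical emitter code) of the traversal protocol
  // built on the core functions above:
  //
  //   it->genInit(b, loc, parentIt); // position at the first element
  //   // emit a while loop guarded by it->genNotEnd(b, loc); in its body:
  //   Value crd = it->deref(b, loc); // coordinate at the current position
  //   // ... emit the loop body using `crd` ...
  //   it->forward(b, loc);           // advance to the next element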
263 | |
  // Actual implementation provided by derived classes.
265 | virtual void genInitImpl(OpBuilder &, Location, const SparseIterator *) = 0; |
266 | virtual ValueRange forwardImpl(OpBuilder &b, Location l) = 0; |
267 | virtual void locateImpl(OpBuilder &b, Location l, Value crd) { |
268 | llvm_unreachable("Unsupported" ); |
269 | } |
270 | virtual Value genNotEndImpl(OpBuilder &b, Location l) = 0; |
271 | virtual Value derefImpl(OpBuilder &b, Location l) = 0; |
  // Gets the ValueRange that together specifies the current position of the
  // iterator. For a unique level, the position can be a single index pointing
  // to the current coordinate being visited. For a non-unique level, an extra
  // index for the `segment high` is needed to specify the range of duplicated
  // coordinates. The ValueRange should be able to uniquely identify the
  // sparse range for the next level. See SparseTensorLevel::peekRangeAt().
  //
  // Not every type of iterator supports the operation, e.g., the non-empty
  // subsection iterator does not because it represents a range of coordinates
  // instead of just one.
282 | virtual ValueRange getCurPosition() const { return getCursor(); }; |
283 | |
  // Returns a pair of values for the *lower* and *upper* bound respectively.
285 | virtual std::pair<Value, Value> genForCond(OpBuilder &b, Location l) { |
286 | assert(randomAccessible()); |
287 | // Random-access iterator is traversed by coordinate, i.e., [curCrd, UB). |
288 | return {getCrd(), upperBound(b, l)}; |
289 | } |
290 | |
291 | // Generates a bool value for scf::ConditionOp. |
292 | std::pair<Value, ValueRange> genWhileCond(OpBuilder &b, Location l, |
293 | ValueRange vs) { |
    ValueRange rem = linkNewScope(vs);
    return std::make_pair(genNotEnd(b, l), rem);
296 | } |
297 | |
  // Generates a conditional it.next() in the following form:
299 | // |
300 | // if (cond) |
301 | // yield it.next |
302 | // else |
303 | // yield it |
304 | // |
305 | // The function is virtual to allow alternative implementation. For example, |
306 | // if it.next() is trivial to compute, we can use a select operation instead. |
307 | // E.g., |
308 | // |
309 | // it = select cond ? it+1 : it |
310 | virtual ValueRange forwardIf(OpBuilder &b, Location l, Value cond); |
311 | |
  // Updates the SSA values for the iterator after entering a new scope.
313 | ValueRange linkNewScope(ValueRange pos) { |
314 | assert(!randomAccessible() && "random accessible iterators are traversed " |
315 | "by coordinate, call locate() instead." ); |
    seek(pos.take_front(cursorValsCnt));
    return pos.drop_front(cursorValsCnt);
318 | }; |
319 | |
320 | protected: |
321 | void updateCrd(Value crd) { this->crd = crd; } |
322 | |
323 | MutableArrayRef<Value> getMutCursorVals() { |
324 | MutableArrayRef<Value> ref = cursorValsStorageRef; |
    return ref.take_front(cursorValsCnt);
326 | } |
327 | |
328 | void inherentBatch(const SparseIterator &parent) { |
329 | batchCrds = parent.batchCrds; |
330 | } |
331 | |
332 | SparseEmitStrategy emitStrategy; |
333 | SmallVector<Value> batchCrds; |
334 | |
335 | public: |
336 | const IterKind kind; // For LLVM-style RTTI. |
  const unsigned tid, lvl; // Tensor and level identifiers.
338 | |
339 | private: |
  Value crd; // The sparse coordinate used to coiterate.
341 | |
  // A range of values that together define the current state of the
  // iterator. Only loop variants should be included.
  //
  // For trivial iterators, it is the position; for dedup iterators, it
  // consists of the position and the segment high; for the non-empty
  // subsection iterator, it is the metadata that specifies the subsection.
  // Note that a wrapped iterator shares the same storage for itVals with its
  // wrapper, which means the wrapped iterator might only own a subset of all
  // the values stored in itValStorage.
351 | const unsigned cursorValsCnt; |
352 | SmallVectorImpl<Value> &cursorValsStorageRef; |
353 | }; |
354 | |
/// Helper function to create a TensorLevel object from the given `tensor`.
356 | std::unique_ptr<SparseTensorLevel> makeSparseTensorLevel(OpBuilder &b, |
357 | Location l, Value t, |
358 | unsigned tid, |
359 | Level lvl); |
360 | |
/// Helper function to create a TensorLevel object from the given ValueRange.
362 | std::unique_ptr<SparseTensorLevel> makeSparseTensorLevel(LevelType lt, Value sz, |
363 | ValueRange buffers, |
364 | unsigned tid, Level l); |
365 | |
/// Helper function to create a simple SparseIterator object that iterates
/// over the entire iteration space.
368 | std::unique_ptr<SparseIterator> |
369 | makeSimpleIterator(OpBuilder &b, Location l, |
370 | const SparseIterationSpace &iterSpace); |
371 | |
/// Helper function to create a simple SparseIterator object that iterates
/// over the sparse tensor level.
/// TODO: switch to `SparseIterationSpace` (which supports N-D iterators) when
/// feature complete.
376 | std::unique_ptr<SparseIterator> makeSimpleIterator( |
377 | const SparseTensorLevel &stl, |
378 | SparseEmitStrategy strategy = SparseEmitStrategy::kFunctional); |
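
// Illustrative sketch (hypothetical usage, not prescribed by this header):
// build a level and a matching iterator for tensor `t`, then traverse it:
//
//   std::unique_ptr<SparseTensorLevel> stl =
//       makeSparseTensorLevel(b, loc, t, /*tid=*/0, /*lvl=*/0);
//   std::unique_ptr<SparseIterator> it = makeSimpleIterator(*stl);
//   it->genInit(b, loc, /*parent=*/nullptr);
//   // ... emit the loop using it->genNotEnd/deref/forward ...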
379 | |
380 | /// Helper function to create a synthetic SparseIterator object that iterates |
381 | /// over a dense space specified by [0,`sz`). |
382 | std::pair<std::unique_ptr<SparseTensorLevel>, std::unique_ptr<SparseIterator>> |
383 | makeSynLevelAndIterator(Value sz, unsigned tid, unsigned lvl, |
384 | SparseEmitStrategy strategy); |
385 | |
/// Helper function to create a SparseIterator object that iterates over a
/// sliced space; the original space (before slicing) is traversed by `sit`.
388 | std::unique_ptr<SparseIterator> |
389 | makeSlicedLevelIterator(std::unique_ptr<SparseIterator> &&sit, Value offset, |
390 | Value stride, Value size, SparseEmitStrategy strategy); |
391 | |
392 | /// Helper function to create a SparseIterator object that iterates over a |
393 | /// padded sparse level (the padded value must be zero). |
394 | std::unique_ptr<SparseIterator> |
395 | makePaddedIterator(std::unique_ptr<SparseIterator> &&sit, Value padLow, |
396 | Value padHigh, SparseEmitStrategy strategy); |
397 | |
/// Helper function to create a SparseIterator object that iterates over the
/// set of non-empty subsections.
400 | std::unique_ptr<SparseIterator> makeNonEmptySubSectIterator( |
401 | OpBuilder &b, Location l, const SparseIterator *parent, Value loopBound, |
402 | std::unique_ptr<SparseIterator> &&delegate, Value size, unsigned stride, |
403 | SparseEmitStrategy strategy); |
404 | |
405 | /// Helper function to create a SparseIterator object that iterates over a |
406 | /// non-empty subsection created by NonEmptySubSectIterator. |
407 | std::unique_ptr<SparseIterator> makeTraverseSubSectIterator( |
408 | OpBuilder &b, Location l, const SparseIterator &subsectIter, |
409 | const SparseIterator &parent, std::unique_ptr<SparseIterator> &&wrap, |
410 | Value loopBound, unsigned stride, SparseEmitStrategy strategy); |
411 | |
412 | } // namespace sparse_tensor |
413 | } // namespace mlir |
414 | |
415 | #endif // MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_UTILS_SPARSETENSORITERATOR_H_ |
416 | |