//===- SparseTensorIterator.h ----------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_UTILS_SPARSETENSORITERATOR_H_
#define MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_UTILS_SPARSETENSORITERATOR_H_

#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"

namespace mlir {
namespace sparse_tensor {

// Forward declaration.
class SparseIterator;

/// The base class for all types of sparse tensor levels. It provides interfaces
/// to query the loop range (see `peekRangeAt`) and look up the coordinates (see
/// `peekCrdAt`).
class SparseTensorLevel {
  SparseTensorLevel(SparseTensorLevel &&) = delete;
  SparseTensorLevel(const SparseTensorLevel &) = delete;
  SparseTensorLevel &operator=(SparseTensorLevel &&) = delete;
  SparseTensorLevel &operator=(const SparseTensorLevel &) = delete;

public:
  virtual ~SparseTensorLevel() = default;

  std::string toString() const {
    return std::string(toMLIRString(lt)) + "[" + std::to_string(tid) + "," +
           std::to_string(lvl) + "]";
  }

  virtual Value peekCrdAt(OpBuilder &b, Location l, ValueRange batchPrefix,
                          Value iv) const = 0;

  /// Peeks the lower and upper bound to *fully* traverse the level with
  /// the given position `parentPos` (see `SparseIterator::getCurPosition()`)
  /// at which the immediate parent level is currently positioned. Returns a
  /// pair of values for *posLo* and *loopHi*, respectively.
  ///
  /// For a dense level, *posLo* is the linearized position at the beginning,
  /// while *loopHi* is the largest *coordinate*; this also implies that the
  /// smallest *coordinate* to start the loop at is 0.
  ///
  /// For a sparse level, [posLo, loopHi) specifies the range of index pointers
  /// used to load coordinates from the coordinate buffer.
  virtual std::pair<Value, Value>
  peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
              ValueRange parentPos, Value inPadZone = nullptr) const = 0;
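
  // A minimal usage sketch for `peekRangeAt` (illustrative only; `stl`, `b`,
  // `l`, and `parentPos` (a ValueRange positioning the parent level) are
  // assumed to be provided by the surrounding lowering code, and `stl` is
  // assumed to wrap a sparse level):
  //
  //   auto [posLo, posHi] = stl.peekRangeAt(b, l, /*batchPrefix=*/{},
  //                                         parentPos);
  //   // Each position in [posLo, posHi) can then be mapped to a coordinate
  //   // via stl.peekCrdAt(b, l, /*batchPrefix=*/{}, pos).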

  virtual std::pair<Value, Value>
  collapseRangeBetween(OpBuilder &b, Location l, ValueRange batchPrefix,
                       std::pair<Value, Value> parentRange) const {
    llvm_unreachable("Not Implemented");
  };

  Level getLevel() const { return lvl; }
  LevelType getLT() const { return lt; }
  Value getSize() const { return lvlSize; }
  virtual ValueRange getLvlBuffers() const = 0;

  //
  // Level properties
  //
  bool isUnique() const { return isUniqueLT(lt); }

protected:
  SparseTensorLevel(unsigned tid, unsigned lvl, LevelType lt, Value lvlSize)
      : tid(tid), lvl(lvl), lt(lt), lvlSize(lvlSize) {};

public:
  const unsigned tid, lvl;
  const LevelType lt;
  const Value lvlSize;
};

enum class IterKind : uint8_t {
  kTrivial,
  kDedup,
  kSubSect,
  kNonEmptySubSect,
  kFilter,
  kPad,
};

/// A `SparseIterationSpace` represents a sparse set of coordinates defined by
/// (possibly multiple) levels of a specific sparse tensor.
/// TODO: remove `SparseTensorLevel` and switch to SparseIterationSpace when
/// feature complete.
class SparseIterationSpace {
public:
  SparseIterationSpace() = default;
  SparseIterationSpace(SparseIterationSpace &) = delete;
  SparseIterationSpace(SparseIterationSpace &&) = default;

  // Constructs an N-D iteration space.
  SparseIterationSpace(Location loc, OpBuilder &b, Value t, unsigned tid,
                       std::pair<Level, Level> lvlRange, ValueRange parentPos);

  // Constructs a 1-D iteration space.
  SparseIterationSpace(Location loc, OpBuilder &b, Value t, unsigned tid,
                       Level lvl, ValueRange parentPos)
      : SparseIterationSpace(loc, b, t, tid, {lvl, lvl + 1}, parentPos) {};

  bool isUnique() const { return lvls.back()->isUnique(); }

  unsigned getSpaceDim() const { return lvls.size(); }

  // Reconstructs an iteration space directly from the provided ValueRange.
  static SparseIterationSpace fromValues(IterSpaceType dstTp, ValueRange values,
                                         unsigned tid);

  // The inverse operation of `fromValues`.
  SmallVector<Value> toValues() const {
    SmallVector<Value> vals;
    for (auto &stl : lvls) {
      llvm::append_range(vals, stl->getLvlBuffers());
      vals.push_back(stl->getSize());
    }
    vals.append({bound.first, bound.second});
    return vals;
  }
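
  // A small round-trip sketch (illustrative only; `space` is an existing
  // iteration space, and `dstTp`/`tid` are assumed to be available from the
  // surrounding lowering code):
  //
  //   SmallVector<Value> flat = space.toValues();
  //   SparseIterationSpace restored =
  //       SparseIterationSpace::fromValues(dstTp, flat, tid);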

  const SparseTensorLevel &getLastLvl() const { return *lvls.back(); }
  ArrayRef<std::unique_ptr<SparseTensorLevel>> getLvlRef() const {
    return lvls;
  }

  Value getBoundLo() const { return bound.first; }
  Value getBoundHi() const { return bound.second; }

  // Extracts an iterator to iterate over the sparse iteration space.
  std::unique_ptr<SparseIterator> extractIterator(OpBuilder &b,
                                                  Location l) const;

private:
  SmallVector<std::unique_ptr<SparseTensorLevel>> lvls;
  std::pair<Value, Value> bound;
};
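
// A construction sketch (illustrative only; `b`, `loc`, `tensor`, `tid`,
// `lvl`, and `parentPos` are assumed to come from the surrounding lowering
// code):
//
//   SparseIterationSpace space(loc, b, tensor, tid, lvl, parentPos);
//   std::unique_ptr<SparseIterator> it = space.extractIterator(b, loc);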

/// Helper class that generates loop conditions, etc., to traverse a
/// sparse tensor level.
class SparseIterator {
  SparseIterator(SparseIterator &&) = delete;
  SparseIterator(const SparseIterator &) = delete;
  SparseIterator &operator=(SparseIterator &&) = delete;
  SparseIterator &operator=(const SparseIterator &) = delete;

protected:
  SparseIterator(IterKind kind, unsigned tid, unsigned lvl,
                 unsigned cursorValsCnt,
                 SmallVectorImpl<Value> &cursorValStorage)
      : batchCrds(0), kind(kind), tid(tid), lvl(lvl), crd(nullptr),
        cursorValsCnt(cursorValsCnt), cursorValsStorageRef(cursorValStorage) {};

  SparseIterator(IterKind kind, unsigned cursorValsCnt,
                 SmallVectorImpl<Value> &cursorValStorage,
                 const SparseIterator &delegate)
      : SparseIterator(kind, delegate.tid, delegate.lvl, cursorValsCnt,
                       cursorValStorage) {};

  SparseIterator(IterKind kind, const SparseIterator &wrap,
                 unsigned extraCursorCnt = 0)
      : SparseIterator(kind, wrap.tid, wrap.lvl,
                       extraCursorCnt + wrap.cursorValsCnt,
                       wrap.cursorValsStorageRef) {
    assert(wrap.cursorValsCnt == wrap.cursorValsStorageRef.size());
    cursorValsStorageRef.append(extraCursorCnt, nullptr);
    assert(cursorValsStorageRef.size() == wrap.cursorValsCnt + extraCursorCnt);
  };

public:
  virtual ~SparseIterator() = default;

  void setSparseEmitStrategy(SparseEmitStrategy strategy) {
    emitStrategy = strategy;
  }

  virtual std::string getDebugInterfacePrefix() const = 0;
  virtual SmallVector<Type> getCursorValTypes(OpBuilder &b) const = 0;

  Value getCrd() const { return crd; }
  ValueRange getBatchCrds() const { return batchCrds; }
  ValueRange getCursor() const {
    return ValueRange(cursorValsStorageRef).take_front(cursorValsCnt);
  };

  // Sets the iterator to the specified position.
  void seek(ValueRange vals) {
    assert(vals.size() == cursorValsCnt);
    std::copy(vals.begin(), vals.end(), cursorValsStorageRef.begin());
    // Now that the iterator is re-positioned, the coordinate becomes invalid.
    crd = nullptr;
  }

  // Reconstructs an iterator directly from the provided ValueRange.
  static std::unique_ptr<SparseIterator>
  fromValues(IteratorType dstTp, ValueRange values, unsigned tid);

  // The inverse operation of `fromValues`.
  SmallVector<Value> toValues() const { llvm_unreachable("Not implemented"); }

  //
  // Iterator properties.
  //

  // Whether the iterator is an iterator over a batch level.
  virtual bool isBatchIterator() const = 0;

  // Whether the iterator supports random access (i.e., lookup by
  // *coordinate*). A random-access iterator must also traverse a dense space.
  virtual bool randomAccessible() const = 0;

  // Whether the iterator can simply be traversed by a for loop.
  virtual bool iteratableByFor() const { return false; };

  // Gets the upper bound of the sparse space that the iterator might visit. A
  // sparse space is a subset of a dense space [0, bound); this function
  // returns *bound*.
  virtual Value upperBound(OpBuilder &b, Location l) const = 0;

  // Serializes and deserializes the current status to/from a set of values.
  // The ValueRange should contain values that are sufficient to recover the
  // current iterating position (i.e., itVals) as well as the loop bound.
  //
  // Not every type of iterator supports these operations; e.g., the non-empty
  // subsection iterator does not, because the number of non-empty subsections
  // cannot be determined easily.
  //
  // NOTE: All the values should have index type.
  virtual SmallVector<Value> serialize() const {
    llvm_unreachable("unsupported");
  };
  virtual void deserialize(ValueRange vs) { llvm_unreachable("unsupported"); };
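
  // A serialization round-trip sketch (illustrative only; `it` is an iterator
  // that supports the operation, and the values would typically be threaded
  // through a loop region as iteration arguments before being restored):
  //
  //   SmallVector<Value> state = it.serialize();
  //   // ... pass `state` across a region boundary ...
  //   it.deserialize(state);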

  //
  // Core functions.
  //

  // Initializes the iterator according to the parent iterator's state.
  void genInit(OpBuilder &b, Location l, const SparseIterator *p);

  // Forwards the iterator to the next element.
  ValueRange forward(OpBuilder &b, Location l);

  // Locates the iterator at the position specified by *crd*; this can only
  // be done on an iterator that supports random access.
  void locate(OpBuilder &b, Location l, Value crd);

  // Returns a boolean value that equals `!it.end()`.
  Value genNotEnd(OpBuilder &b, Location l);

  // Dereferences the iterator, i.e., loads the coordinate at the current
  // position.
  //
  // The method assumes that the iterator is not currently exhausted (i.e.,
  // it != it.end()).
  Value deref(OpBuilder &b, Location l);
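
  // A typical driving sequence (illustrative sketch only; the loop structure
  // itself is emitted by the caller, and `b`, `l`, and `parent` are assumed
  // to be in scope):
  //
  //   it.genInit(b, l, parent);        // position at the first element
  //   while (it.genNotEnd(b, l)) {     // conceptual loop; the caller actually
  //     Value crd = it.deref(b, l);    // emits the corresponding loop ops
  //     ...                            // emit the loop body using crd
  //     it.forward(b, l);              // advance to the next element
  //   }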

  // Actual implementation provided by the derived class.
  virtual void genInitImpl(OpBuilder &, Location, const SparseIterator *) = 0;
  virtual ValueRange forwardImpl(OpBuilder &b, Location l) = 0;
  virtual void locateImpl(OpBuilder &b, Location l, Value crd) {
    llvm_unreachable("Unsupported");
  }
  virtual Value genNotEndImpl(OpBuilder &b, Location l) = 0;
  virtual Value derefImpl(OpBuilder &b, Location l) = 0;
  // Gets the ValueRange that together specifies the current position of the
  // iterator. For a unique level, the position can be a single index pointing
  // to the current coordinate being visited. For a non-unique level, an extra
  // index for the `segment high` is needed to specify the range of duplicated
  // coordinates. The ValueRange should be able to uniquely identify the
  // sparse range for the next level. See `SparseTensorLevel::peekRangeAt()`.
  //
  // Not every type of iterator supports the operation; e.g., the non-empty
  // subsection iterator does not, because it represents a range of coordinates
  // instead of just one.
  virtual ValueRange getCurPosition() const { return getCursor(); };

  // Returns a pair of values for the *lower* and *upper* bounds, respectively.
  virtual std::pair<Value, Value> genForCond(OpBuilder &b, Location l) {
    assert(randomAccessible());
    // Random-access iterators are traversed by coordinate, i.e., [curCrd, UB).
    return {getCrd(), upperBound(b, l)};
  }
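
  // A sketch of how the bounds might be consumed (illustrative only; loop
  // construction details are up to the caller):
  //
  //   auto [lo, hi] = it.genForCond(b, l);
  //   // Emit a loop over [lo, hi) and, inside the body, call
  //   // it.locate(b, l, iv) with the induction variable `iv`.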

  // Generates a bool value for scf::ConditionOp.
  std::pair<Value, ValueRange> genWhileCond(OpBuilder &b, Location l,
                                            ValueRange vs) {
    ValueRange rem = linkNewScope(vs);
    return std::make_pair(genNotEnd(b, l), rem);
  }

  // Generates a conditional it.next() in the following form:
  //
  //   if (cond)
  //     yield it.next
  //   else
  //     yield it
  //
  // The function is virtual to allow alternative implementations. For example,
  // if it.next() is trivial to compute, we can use a select operation instead.
  // E.g.,
  //
  //   it = select cond ? it+1 : it
  virtual ValueRange forwardIf(OpBuilder &b, Location l, Value cond);

  // Updates the SSA values for the iterator after entering a new scope.
  ValueRange linkNewScope(ValueRange pos) {
    assert(!randomAccessible() && "random accessible iterators are traversed "
                                  "by coordinate, call locate() instead.");
    seek(pos.take_front(cursorValsCnt));
    return pos.drop_front(cursorValsCnt);
  };
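
  // A sketch of how the while-loop helpers compose (illustrative only;
  // `beforeArgs`/`afterArgs` stand for the region arguments of a while loop
  // created by the caller):
  //
  //   // In the "before" region:
  //   auto [cond, rest] = it.genWhileCond(b, l, beforeArgs);
  //   // `cond` feeds the loop condition; in the "after" region the cursor is
  //   // re-bound to the new block arguments:
  //   ValueRange remaining = it.linkNewScope(afterArgs);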

protected:
  void updateCrd(Value crd) { this->crd = crd; }

  MutableArrayRef<Value> getMutCursorVals() {
    MutableArrayRef<Value> ref = cursorValsStorageRef;
    return ref.take_front(cursorValsCnt);
  }

  void inherentBatch(const SparseIterator &parent) {
    batchCrds = parent.batchCrds;
  }

  SparseEmitStrategy emitStrategy;
  SmallVector<Value> batchCrds;

public:
  const IterKind kind;     // For LLVM-style RTTI.
  const unsigned tid, lvl; // Tensor and level identifiers.

private:
  Value crd; // The sparse coordinate used to coiterate.

  // A range of values that together define the current state of the iterator.
  // Only loop-variant values should be included.
  //
  // For trivial iterators, it is the position; for dedup iterators, it
  // consists of the position and the segment high; for the non-empty
  // subsection iterator, it is the metadata that specifies the subsection.
  // Note that a wrapped iterator shares the same cursor storage with its
  // wrapper, which means the wrapped iterator might only own a subset of all
  // the values stored in that storage.
  const unsigned cursorValsCnt;
  SmallVectorImpl<Value> &cursorValsStorageRef;
};

/// Helper function to create a TensorLevel object from the given tensor `t`.
std::unique_ptr<SparseTensorLevel> makeSparseTensorLevel(OpBuilder &b,
                                                         Location l, Value t,
                                                         unsigned tid,
                                                         Level lvl);

/// Helper function to create a TensorLevel object from the given ValueRange.
std::unique_ptr<SparseTensorLevel> makeSparseTensorLevel(LevelType lt, Value sz,
                                                         ValueRange buffers,
                                                         unsigned tid, Level l);

/// Helper function to create a simple SparseIterator object that iterates
/// over the entire iteration space.
std::unique_ptr<SparseIterator>
makeSimpleIterator(OpBuilder &b, Location l,
                   const SparseIterationSpace &iterSpace);

/// Helper function to create a simple SparseIterator object that iterates
/// over the sparse tensor level.
/// TODO: switch to `SparseIterationSpace` (which supports N-D iterators) when
/// feature complete.
std::unique_ptr<SparseIterator> makeSimpleIterator(
    const SparseTensorLevel &stl,
    SparseEmitStrategy strategy = SparseEmitStrategy::kFunctional);
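
// A creation sketch (illustrative only; `b`, `l`, `t`, `tid`, and `lvl` are
// assumed to come from the surrounding lowering code):
//
//   auto stl = makeSparseTensorLevel(b, l, t, tid, lvl);
//   auto it = makeSimpleIterator(*stl);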

/// Helper function to create a synthetic SparseIterator object that iterates
/// over a dense space specified by [0, `sz`).
std::pair<std::unique_ptr<SparseTensorLevel>, std::unique_ptr<SparseIterator>>
makeSynLevelAndIterator(Value sz, unsigned tid, unsigned lvl,
                        SparseEmitStrategy strategy);

/// Helper function to create a SparseIterator object that iterates over a
/// sliced space; the original space (before slicing) is traversed by `sit`.
std::unique_ptr<SparseIterator>
makeSlicedLevelIterator(std::unique_ptr<SparseIterator> &&sit, Value offset,
                        Value stride, Value size, SparseEmitStrategy strategy);

/// Helper function to create a SparseIterator object that iterates over a
/// padded sparse level (the padded value must be zero).
std::unique_ptr<SparseIterator>
makePaddedIterator(std::unique_ptr<SparseIterator> &&sit, Value padLow,
                   Value padHigh, SparseEmitStrategy strategy);

/// Helper function to create a SparseIterator object that iterates over the
/// set of non-empty subsections.
std::unique_ptr<SparseIterator> makeNonEmptySubSectIterator(
    OpBuilder &b, Location l, const SparseIterator *parent, Value loopBound,
    std::unique_ptr<SparseIterator> &&delegate, Value size, unsigned stride,
    SparseEmitStrategy strategy);

/// Helper function to create a SparseIterator object that iterates over a
/// non-empty subsection created by NonEmptySubSectIterator.
std::unique_ptr<SparseIterator> makeTraverseSubSectIterator(
    OpBuilder &b, Location l, const SparseIterator &subsectIter,
    const SparseIterator &parent, std::unique_ptr<SparseIterator> &&wrap,
    Value loopBound, unsigned stride, SparseEmitStrategy strategy);

} // namespace sparse_tensor
} // namespace mlir

#endif // MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_UTILS_SPARSETENSORITERATOR_H_
