1//===- SparseTensorIterator.cpp -------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "SparseTensorIterator.h"
10#include "CodegenUtils.h"
11
12#include "mlir/Dialect/MemRef/IR/MemRef.h"
13#include "mlir/Dialect/SCF/IR/SCF.h"
14#include "mlir/Dialect/Tensor/IR/Tensor.h"
15
16using namespace mlir;
17using namespace mlir::sparse_tensor;
18using ValuePair = std::pair<Value, Value>;
19using ValueTuple = std::tuple<Value, Value, Value>;
20
21//===----------------------------------------------------------------------===//
22// File local helper functions/macros.
23//===----------------------------------------------------------------------===//
// Shorthand macros for creating frequently used arith/scf operations. Every
// macro assumes an `OpBuilder b` and a `Location l` are in scope at the
// expansion site.
#define CMPI(p, lhs, rhs)                                                      \
  (b.create<arith::CmpIOp>(l, arith::CmpIPredicate::p, (lhs), (rhs))           \
       .getResult())

#define C_FALSE (constantI1(b, l, false))
#define C_TRUE (constantI1(b, l, true))
#define C_IDX(v) (constantIndex(b, l, (v)))
#define YIELD(vs) (b.create<scf::YieldOp>(l, (vs)))
#define ADDI(lhs, rhs) (b.create<arith::AddIOp>(l, (lhs), (rhs)).getResult())
#define ORI(lhs, rhs) (b.create<arith::OrIOp>(l, (lhs), (rhs)).getResult())
#define ANDI(lhs, rhs) (b.create<arith::AndIOp>(l, (lhs), (rhs)).getResult())
#define SUBI(lhs, rhs) (b.create<arith::SubIOp>(l, (lhs), (rhs)).getResult())
#define MULI(lhs, rhs) (b.create<arith::MulIOp>(l, (lhs), (rhs)).getResult())
#define MINUI(lhs, rhs) (b.create<arith::MinUIOp>(l, (lhs), (rhs)).getResult())
#define REMUI(lhs, rhs) (b.create<arith::RemUIOp>(l, (lhs), (rhs)).getResult())
#define DIVUI(lhs, rhs) (b.create<arith::DivUIOp>(l, (lhs), (rhs)).getResult())
#define SELECT(c, lhs, rhs)                                                    \
  (b.create<arith::SelectOp>(l, (c), (lhs), (rhs)).getResult())
42
43//===----------------------------------------------------------------------===//
44// SparseTensorLevel derived classes.
45//===----------------------------------------------------------------------===//
46
47namespace {
48
49template <bool hasPosBuffer>
50class SparseLevel : public SparseTensorLevel {
51 // It is either an array of size 2 or size 1 depending on whether the sparse
52 // level requires a position array.
53 using BufferT = std::conditional_t<hasPosBuffer, std::array<Value, 2>,
54 std::array<Value, 1>>;
55
56public:
57 SparseLevel(unsigned tid, Level lvl, LevelType lt, Value lvlSize,
58 BufferT buffers)
59 : SparseTensorLevel(tid, lvl, lt, lvlSize), buffers(buffers) {}
60
61 ValueRange getLvlBuffers() const override { return buffers; }
62
63 Value peekCrdAt(OpBuilder &b, Location l, ValueRange batchPrefix,
64 Value iv) const override {
65 SmallVector<Value> memCrd(batchPrefix);
66 memCrd.push_back(Elt: iv);
67 return genIndexLoad(b, l, getCrdBuf(), memCrd);
68 }
69
70protected:
71 template <typename T = void, typename = std::enable_if_t<hasPosBuffer, T>>
72 Value getPosBuf() const {
73 return buffers[0];
74 }
75
76 Value getCrdBuf() const {
77 if constexpr (hasPosBuffer)
78 return buffers[1];
79 else
80 return buffers[0];
81 }
82
83 const BufferT buffers;
84};
85
86class DenseLevel : public SparseTensorLevel {
87public:
88 DenseLevel(unsigned tid, Level lvl, Value lvlSize)
89 : SparseTensorLevel(tid, lvl, LevelFormat::Dense, lvlSize) {}
90
91 Value peekCrdAt(OpBuilder &, Location, ValueRange, Value) const override {
92 llvm_unreachable("locate random-accessible level instead");
93 }
94
95 ValueRange getLvlBuffers() const override { return {}; }
96
97 ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
98 ValueRange parentPos, Value inPadZone) const override {
99 assert(parentPos.size() == 1 && "Dense level can not be non-unique.");
100 assert(!inPadZone && "Not implemented");
101 Value p = parentPos.front();
102 Value posLo = MULI(p, lvlSize);
103 return {posLo, lvlSize};
104 }
105};
106
// A batch level: behaves like a dense level but positions are per-batch and
// therefore never linearized against the parent position.
class BatchLevel : public SparseTensorLevel {
public:
  BatchLevel(unsigned tid, Level lvl, Value lvlSize)
      : SparseTensorLevel(tid, lvl, LevelFormat::Batch, lvlSize) {}

  // Batch levels are random-accessible; coordinates are never peeked.
  Value peekCrdAt(OpBuilder &, Location, ValueRange, Value) const override {
    llvm_unreachable("locate random-accessible level instead");
  }

  // No backing buffers for a batch level.
  ValueRange getLvlBuffers() const override { return {}; }

  // Always returns the full range [0, lvlSize) regardless of parent position.
  ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange,
                        ValueRange parentPos, Value inPadZone) const override {
    assert(!inPadZone && "Not implemented");
    assert(parentPos.size() == 1 && "Dense level can not be non-unique.");
    // No need to linearize the position for non-annotated tensors.
    return {C_IDX(0), lvlSize};
  }
};
126
127class CompressedLevel : public SparseLevel</*hasPosBuf=*/true> {
128public:
129 CompressedLevel(unsigned tid, Level lvl, LevelType lt, Value lvlSize,
130 Value posBuffer, Value crdBuffer)
131 : SparseLevel(tid, lvl, lt, lvlSize, {posBuffer, crdBuffer}) {}
132
133 ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
134 ValueRange parentPos, Value inPadZone) const override {
135
136 assert(parentPos.size() == 1 &&
137 "compressed level must be the first non-unique level.");
138
139 auto loadRange = [&b, l, parentPos, batchPrefix, this]() -> ValuePair {
140 Value p = parentPos.front();
141 SmallVector<Value> memCrd(batchPrefix);
142 memCrd.push_back(Elt: p);
143 Value pLo = genIndexLoad(builder&: b, loc: l, mem: getPosBuf(), s: memCrd);
144 memCrd.back() = ADDI(p, C_IDX(1));
145 Value pHi = genIndexLoad(builder&: b, loc: l, mem: getPosBuf(), s: memCrd);
146 return {pLo, pHi};
147 };
148
149 if (inPadZone == nullptr)
150 return loadRange();
151
152 SmallVector<Type, 2> types{b.getIndexType(), b.getIndexType()};
153 scf::IfOp posRangeIf = b.create<scf::IfOp>(l, types, inPadZone, true);
154 // True branch, returns a "fake" empty range [0, 0) if parent
155 // iterator is in pad zone.
156 b.setInsertionPointToStart(posRangeIf.thenBlock());
157
158 SmallVector<Value, 2> emptyRange{C_IDX(0), C_IDX(0)};
159 b.create<scf::YieldOp>(l, emptyRange);
160
161 // False branch, returns the actual range.
162 b.setInsertionPointToStart(posRangeIf.elseBlock());
163 auto [pLo, pHi] = loadRange();
164 SmallVector<Value, 2> loadedRange{pLo, pHi};
165 b.create<scf::YieldOp>(l, loadedRange);
166
167 b.setInsertionPointAfter(posRangeIf);
168 ValueRange posRange = posRangeIf.getResults();
169 return {posRange.front(), posRange.back()};
170 }
171}; // namespace
172
173class LooseCompressedLevel : public SparseLevel</*hasPosBuf=*/true> {
174public:
175 LooseCompressedLevel(unsigned tid, Level lvl, LevelType lt, Value lvlSize,
176 Value posBuffer, Value crdBuffer)
177 : SparseLevel(tid, lvl, lt, lvlSize, {posBuffer, crdBuffer}) {}
178
179 ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
180 ValueRange parentPos, Value inPadZone) const override {
181 assert(parentPos.size() == 1 &&
182 "loose-compressed level must be the first non-unique level.");
183 assert(!inPadZone && "Not implemented");
184 SmallVector<Value> memCrd(batchPrefix);
185 Value p = parentPos.front();
186 p = MULI(p, C_IDX(2));
187 memCrd.push_back(Elt: p);
188 Value pLo = genIndexLoad(builder&: b, loc: l, mem: getPosBuf(), s: memCrd);
189 memCrd.back() = ADDI(p, C_IDX(1));
190 Value pHi = genIndexLoad(builder&: b, loc: l, mem: getPosBuf(), s: memCrd);
191 return {pLo, pHi};
192 }
193}; // namespace
194
195class SingletonLevel : public SparseLevel</*hasPosBuf=*/false> {
196public:
197 SingletonLevel(unsigned tid, Level lvl, LevelType lt, Value lvlSize,
198 Value crdBuffer)
199 : SparseLevel(tid, lvl, lt, lvlSize, {crdBuffer}) {}
200
201 ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
202 ValueRange parentPos, Value inPadZone) const override {
203 assert(parentPos.size() == 1 || parentPos.size() == 2);
204 assert(!inPadZone && "Not implemented");
205 Value p = parentPos.front();
206 Value segHi = parentPos.size() == 2 ? parentPos.back() : nullptr;
207
208 if (segHi == nullptr)
209 return {p, ADDI(p, C_IDX(1))};
210 // Use the segHi as the loop upper bound.
211 return {p, segHi};
212 }
213
214 ValuePair
215 collapseRangeBetween(OpBuilder &b, Location l, ValueRange batchPrefix,
216 std::pair<Value, Value> parentRange) const override {
217 // Singleton level keeps the same range after collapsing.
218 return parentRange;
219 };
220};
221
222class NOutOfMLevel : public SparseLevel</*hasPosBuf=*/false> {
223public:
224 NOutOfMLevel(unsigned tid, Level lvl, LevelType lt, Value lvlSize,
225 Value crdBuffer)
226 : SparseLevel(tid, lvl, lt, lvlSize, {crdBuffer}) {}
227
228 ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
229 ValueRange parentPos, Value inPadZone) const override {
230 assert(parentPos.size() == 1 && isUnique() &&
231 "n:m level can not be non-unique.");
232 assert(!inPadZone && "Not implemented");
233 // Each n:m blk has exactly n specified elements.
234 auto n = getN(lt);
235 Value posLo = MULI(parentPos.front(), C_IDX(n));
236 return {posLo, ADDI(posLo, C_IDX(n))};
237 }
238};
239
240} // namespace
241
242//===----------------------------------------------------------------------===//
243// File local helpers
244//===----------------------------------------------------------------------===//
245
246static scf::ValueVector genWhenInBound(
247 OpBuilder &b, Location l, SparseIterator &it, ValueRange elseRet,
248 llvm::function_ref<scf::ValueVector(OpBuilder &, Location, Value)>
249 builder) {
250 TypeRange ifRetTypes = elseRet.getTypes();
251 auto ifOp = b.create<scf::IfOp>(l, ifRetTypes, it.genNotEnd(b, l), true);
252
253 b.setInsertionPointToStart(ifOp.thenBlock());
254 Value crd = it.deref(b, l);
255 scf::ValueVector ret = builder(b, l, crd);
256 YIELD(ret);
257
258 b.setInsertionPointToStart(ifOp.elseBlock());
259 YIELD(elseRet);
260
261 b.setInsertionPointAfter(ifOp);
262 return ifOp.getResults();
263}
264
/// Generates code to compute the *absolute* offset of the slice based on the
/// provided minimum coordinates in the slice.
/// E.g., when reducing d0 + d1 + d2, we need two slices to fully reduce the
/// expression, i.e., s1 = slice(T, d0), s2 = slice(s1, d1). The *absolute*
/// offset is the offset computed relative to the initial tensor T.
///
/// When isNonEmpty == true, the computed offset is meaningless and should not
/// be used during runtime; the method currently generates code to return 0 in
/// that case.
///
/// offset = minCrd >= size ? minCrd - size + 1 : 0;
276static Value offsetFromMinCrd(OpBuilder &b, Location l, Value minCrd,
277 Value size) {
278 Value geSize = CMPI(uge, minCrd, size);
279 // Compute minCrd - size + 1.
280 Value mms = SUBI(ADDI(minCrd, C_IDX(1)), size);
281 // This is the absolute offset related to the actual tensor.
282 return SELECT(geSize, mms, C_IDX(0));
283}
284
285//===----------------------------------------------------------------------===//
286// SparseIterator derived classes.
287//===----------------------------------------------------------------------===//
288
289namespace {
290
// The iterator that traverses a concrete sparse tensor level. High-level
// abstract iterators wrap it to achieve more complex goals (such as collapsing
// several levels). It also holds the common storage of the mlir::Values
// for itself as well as for wrappers.
class ConcreteIterator : public SparseIterator {
protected:
  // Constructs an iterator over level `stl`, reserving `cursorValCnt` slots
  // of cursor storage for the base class to reference.
  // NOTE(review): the base-class ctor receives a reference to
  // `cursorValsStorage` before that member is constructed; this assumes
  // SparseIterator only stores the reference during construction — verify in
  // the header.
  ConcreteIterator(const SparseTensorLevel &stl, IterKind kind,
                   unsigned cursorValCnt)
      : SparseIterator(kind, stl.tid, stl.lvl, cursorValCnt, cursorValsStorage),
        stl(stl), cursorValsStorage(cursorValCnt, nullptr) {
    assert(getCursor().size() == cursorValCnt);
  };

public:
  // For LLVM-style RTTI.
  // NOTE(review): only kTrivial matches here even though DedupIterator
  // (kDedup) also derives from ConcreteIterator — presumably intentional,
  // since DedupIterator defines its own classof; confirm.
  static bool classof(const SparseIterator *from) {
    return from->kind == IterKind::kTrivial;
  }

  bool isBatchIterator() const override {
    return stl.getLT().isa<LevelFormat::Batch>();
  }
  bool randomAccessible() const override {
    return stl.getLT().hasDenseSemantic();
  };
  // Dedup iterators must be driven by a while-loop (their segment bounds are
  // data-dependent); all other concrete iterators can use scf.for.
  bool iteratableByFor() const override { return kind != IterKind::kDedup; };
  Value upperBound(OpBuilder &b, Location l) const override {
    return stl.getSize();
  };

protected:
  const SparseTensorLevel &stl;
  // Owner of the storage, all wrappers build on top of a concrete iterator
  // share the same storage such that the iterator values are always
  // synchronized.
  SmallVector<Value> cursorValsStorage;
};
328
329class TrivialIterator : public ConcreteIterator {
330public:
331 TrivialIterator(const SparseTensorLevel &stl)
332 : ConcreteIterator(stl, IterKind::kTrivial, /*itValCnt=*/1) {}
333
334 TrivialIterator(OpBuilder &b, Location l, const SparseTensorLevel &stl,
335 Value posLo, Value posHi)
336 : ConcreteIterator(stl, IterKind::kTrivial, /*itValCnt=*/1), posLo(posLo),
337 posHi(posHi) {
338 seek(vals: posLo);
339 }
340
341 std::string getDebugInterfacePrefix() const override {
342 return std::string("trivial<") + stl.toString() + ">";
343 }
344 SmallVector<Type> getCursorValTypes(OpBuilder &b) const override {
345 return {b.getIndexType()};
346 }
347
348 SmallVector<Value> serialize() const override {
349 SmallVector<Value> ret;
350 ret.push_back(Elt: getItPos());
351 if (randomAccessible()) {
352 // Loop high is implicit (defined by `upperBound()`) for random-access
353 // iterator, but we need to memorize posLo for linearization.
354 ret.push_back(Elt: posLo);
355 } else {
356 ret.push_back(Elt: posHi);
357 }
358 return ret;
359 };
360
361 void deserialize(ValueRange vs) override {
362 assert(vs.size() == 2);
363 seek(vals: vs.front());
364 if (randomAccessible())
365 posLo = vs.back();
366 else
367 posHi = vs.back();
368 };
369
370 void genInitImpl(OpBuilder &b, Location l,
371 const SparseIterator *parent) override;
372
373 ValuePair genForCond(OpBuilder &b, Location l) override {
374 if (randomAccessible())
375 return {deref(b, l), upperBound(b, l)};
376 return std::make_pair(x: getItPos(), y&: posHi);
377 }
378
379 Value genNotEndImpl(OpBuilder &b, Location l) override {
380 // We used the first level bound as the bound the collapsed set of levels.
381 return CMPI(ult, getItPos(), posHi);
382 }
383
384 Value derefImpl(OpBuilder &b, Location l) override {
385 if (randomAccessible()) {
386 updateCrd(SUBI(getItPos(), posLo));
387 } else {
388 updateCrd(crd: stl.peekCrdAt(b, l, batchPrefix: getBatchCrds(), iv: getItPos()));
389 }
390 return getCrd();
391 };
392
393 ValueRange forwardImpl(OpBuilder &b, Location l) override {
394 seek(ADDI(getItPos(), C_IDX(1)));
395 return getCursor();
396 }
397
398 ValueRange forwardIf(OpBuilder &b, Location l, Value cond) override {
399 Value curPos = getCursor().front();
400 Value nxPos = forward(b, l).front();
401 seek(SELECT(cond, nxPos, curPos));
402 return getCursor();
403 }
404
405 void locateImpl(OpBuilder &b, Location l, Value crd) override {
406 assert(randomAccessible());
407 // Seek to the linearized position.
408 seek(ADDI(crd, posLo));
409 updateCrd(crd);
410 if (isBatchIterator()) {
411 // If this is a batch iterator, also update the batch coordinate.
412 assert(batchCrds.size() > lvl);
413 batchCrds[lvl] = crd;
414 }
415 }
416
417 Value getItPos() const { return getCursor().front(); }
418 Value posLo, posHi;
419};
420
421class DedupIterator : public ConcreteIterator {
422private:
423 Value genSegmentHigh(OpBuilder &b, Location l, Value pos);
424
425public:
426 DedupIterator(const SparseTensorLevel &stl)
427 : ConcreteIterator(stl, IterKind::kDedup, /*itValCnt=*/2) {
428 assert(!stl.isUnique());
429 }
430
431 DedupIterator(OpBuilder &b, Location l, const SparseTensorLevel &stl,
432 Value posLo, Value posHi)
433 : ConcreteIterator(stl, IterKind::kDedup, /*itValCnt=*/2), posHi(posHi) {
434 assert(!stl.isUnique());
435 seek(vals: {posLo, genSegmentHigh(b, l, pos: posLo)});
436 }
437
438 // For LLVM-style RTTI.
439 static bool classof(const SparseIterator *from) {
440 return from->kind == IterKind::kDedup;
441 }
442
443 std::string getDebugInterfacePrefix() const override {
444 return std::string("dedup<") + stl.toString() + ">";
445 }
446 SmallVector<Type> getCursorValTypes(OpBuilder &b) const override {
447 return {b.getIndexType(), b.getIndexType()};
448 }
449
450 void genInitImpl(OpBuilder &b, Location l,
451 const SparseIterator *parent) override {
452 Value c0 = C_IDX(0);
453 ValueRange pPos = c0;
454
455 // If the parent iterator is a batch iterator, we also start from 0 (but
456 // on a different batch).
457 if (parent && !parent->isBatchIterator())
458 pPos = parent->getCurPosition();
459
460 Value posLo;
461 ValueRange batchPrefix = parent ? parent->getBatchCrds() : ValueRange{};
462 std::tie(args&: posLo, args&: posHi) = stl.peekRangeAt(b, l, batchPrefix, parentPos: pPos);
463
464 seek(vals: {posLo, genSegmentHigh(b, l, pos: posLo)});
465 }
466
467 SmallVector<Value> serialize() const override {
468 SmallVector<Value> ret;
469 ret.append(in_start: getCursor().begin(), in_end: getCursor().end());
470 ret.push_back(Elt: posHi);
471 return ret;
472 };
473 void deserialize(ValueRange vs) override {
474 assert(vs.size() == 3);
475 seek(vals: vs.take_front(n: getCursor().size()));
476 posHi = vs.back();
477 };
478
479 Value genNotEndImpl(OpBuilder &b, Location l) override {
480 return CMPI(ult, getPos(), posHi);
481 }
482
483 Value derefImpl(OpBuilder &b, Location l) override {
484 updateCrd(crd: stl.peekCrdAt(b, l, batchPrefix: getBatchCrds(), iv: getPos()));
485 return getCrd();
486 };
487
488 ValueRange forwardImpl(OpBuilder &b, Location l) override {
489 Value nxPos = getSegHi(); // forward the position to the next segment.
490 seek(vals: {nxPos, genSegmentHigh(b, l, pos: nxPos)});
491 return getCursor();
492 }
493
494 Value getPos() const { return getCursor()[0]; }
495 Value getSegHi() const { return getCursor()[1]; }
496
497 Value posHi;
498};
499
500// A util base-iterator that delegates all methods to the wrapped iterator.
501class SimpleWrapIterator : public SparseIterator {
502public:
503 SimpleWrapIterator(std::unique_ptr<SparseIterator> &&wrap, IterKind kind,
504 unsigned extraCursorVal = 0)
505 : SparseIterator(kind, *wrap, extraCursorVal), wrap(std::move(wrap)) {}
506
507 SmallVector<Type> getCursorValTypes(OpBuilder &b) const override {
508 return wrap->getCursorValTypes(b);
509 }
510 bool isBatchIterator() const override { return wrap->isBatchIterator(); }
511 bool randomAccessible() const override { return wrap->randomAccessible(); };
512 bool iteratableByFor() const override { return wrap->iteratableByFor(); };
513
514 SmallVector<Value> serialize() const override { return wrap->serialize(); };
515 void deserialize(ValueRange vs) override { wrap->deserialize(vs); };
516 ValueRange getCurPosition() const override { return wrap->getCurPosition(); }
517 void genInitImpl(OpBuilder &b, Location l,
518 const SparseIterator *parent) override {
519 wrap->genInit(b, l, p: parent);
520 }
521 Value genNotEndImpl(OpBuilder &b, Location l) override {
522 return wrap->genNotEndImpl(b, l);
523 }
524 ValueRange forwardImpl(OpBuilder &b, Location l) override {
525 return wrap->forward(b, l);
526 };
527 Value upperBound(OpBuilder &b, Location l) const override {
528 return wrap->upperBound(b, l);
529 };
530
531 Value derefImpl(OpBuilder &b, Location l) override {
532 return wrap->derefImpl(b, l);
533 }
534
535 void locateImpl(OpBuilder &b, Location l, Value crd) override {
536 return wrap->locate(b, l, crd);
537 }
538
539 SparseIterator &getWrappedIterator() const { return *wrap; }
540
541protected:
542 std::unique_ptr<SparseIterator> wrap;
543};
544
//
// A filter iterator wrapped around another iterator. The filter iterator
// updates the wrapped iterator *in-place*.
//
549class FilterIterator : public SimpleWrapIterator {
550 // Coorindate translation between crd loaded from the wrap iterator and the
551 // filter iterator.
552 Value fromWrapCrd(OpBuilder &b, Location l, Value wrapCrd) const {
553 // crd = (wrapCrd - offset) / stride
554 return DIVUI(SUBI(wrapCrd, offset), stride);
555 }
556 Value toWrapCrd(OpBuilder &b, Location l, Value crd) const {
557 // wrapCrd = crd * stride + offset
558 return ADDI(MULI(crd, stride), offset);
559 }
560
561 Value genCrdNotLegitPredicate(OpBuilder &b, Location l, Value wrapCrd);
562
563 Value genShouldFilter(OpBuilder &b, Location l);
564
565public:
566 // TODO: avoid unnessary check when offset == 0 and/or when stride == 1 and/or
567 // when crd always < size.
568 FilterIterator(std::unique_ptr<SparseIterator> &&wrap, Value offset,
569 Value stride, Value size)
570 : SimpleWrapIterator(std::move(wrap), IterKind::kFilter), offset(offset),
571 stride(stride), size(size) {}
572
573 // For LLVM-style RTTI.
574 static bool classof(const SparseIterator *from) {
575 return from->kind == IterKind::kFilter;
576 }
577
578 std::string getDebugInterfacePrefix() const override {
579 return std::string("filter<") + wrap->getDebugInterfacePrefix() + ">";
580 }
581
582 bool iteratableByFor() const override { return randomAccessible(); };
583 Value upperBound(OpBuilder &b, Location l) const override { return size; };
584
585 void genInitImpl(OpBuilder &b, Location l,
586 const SparseIterator *parent) override {
587 wrap->genInit(b, l, p: parent);
588 if (!randomAccessible()) {
589 // TODO: we can skip this when stride == 1 and offset == 0, we can also
590 // use binary search here.
591 forwardIf(b, l, cond: genShouldFilter(b, l));
592 } else {
593 // Else, locate to the slice.offset, which is the first coordinate
594 // included by the slice.
595 wrap->locate(b, l, crd: offset);
596 }
597 }
598
599 Value genNotEndImpl(OpBuilder &b, Location l) override;
600
601 Value derefImpl(OpBuilder &b, Location l) override {
602 updateCrd(crd: fromWrapCrd(b, l, wrapCrd: wrap->deref(b, l)));
603 return getCrd();
604 }
605
606 void locateImpl(OpBuilder &b, Location l, Value crd) override {
607 assert(randomAccessible());
608 wrap->locate(b, l, crd: toWrapCrd(b, l, crd));
609 updateCrd(crd);
610 }
611
612 ValueRange forwardImpl(OpBuilder &b, Location l) override;
613
614 Value offset, stride, size;
615};
616
617//
618// A pad iterator wrapped from another iterator. The pad iterator updates
619// the wrapped iterator *in-place*.
620//
class PadIterator : public SimpleWrapIterator {

public:
  // NOTE(review): `wrap` is referenced after `std::move(wrap)` in the
  // mem-initializer; this is safe because the base ctor takes the unique_ptr
  // by rvalue reference, so no move actually happens until the base
  // constructor runs (after all arguments are evaluated).
  PadIterator(std::unique_ptr<SparseIterator> &&wrap, Value padLow,
              Value padHigh)
      : SimpleWrapIterator(std::move(wrap), IterKind::kPad,
                           wrap->randomAccessible() ? 1 : 0),
        padLow(padLow), padHigh(padHigh) {}

  // For LLVM-style RTTI.
  static bool classof(const SparseIterator *from) {
    return from->kind == IterKind::kPad;
  }

  std::string getDebugInterfacePrefix() const override {
    return std::string("pad<") + wrap->getDebugInterfacePrefix() + ">";
  }

  // Returns a pair of values for *upper*, *lower* bound respectively.
  ValuePair genForCond(OpBuilder &b, Location l) override {
    if (randomAccessible())
      return {getCrd(), upperBound(b, l)};
    return wrap->genForCond(b, l);
  }

  // For padded dense iterator, we append a `inPadZone: bool` in addition to
  // values used by the wrapped iterator.
  ValueRange getCurPosition() const override { return getCursor(); }

  SmallVector<Type> getCursorValTypes(OpBuilder &b) const override {
    SmallVector<Type> ret = wrap->getCursorValTypes(b);
    // Need an extra boolean value `inPadZone` for padded dense iterator.
    if (randomAccessible())
      ret.push_back(b.getI1Type());

    return ret;
  }

  // The upper bound after padding becomes `size + padLow + padHigh`.
  Value upperBound(OpBuilder &b, Location l) const override {
    return ADDI(ADDI(wrap->upperBound(b, l), padLow), padHigh);
  };

  // The pad_coord = coord + pad_lo
  Value derefImpl(OpBuilder &b, Location l) override {
    updateCrd(ADDI(wrap->deref(b, l), padLow));
    return getCrd();
  }

  // Locates the wrapped iterator at `crd - padLow` and records whether `crd`
  // falls inside the low/high padding in the extra cursor slot.
  // NOTE(review): `crd - padLow` underflows (index arithmetic) when
  // crd < padLow; presumably harmless because `inPadZone` is set in that
  // case — confirm against the consumers of getCurPosition().
  void locateImpl(OpBuilder &b, Location l, Value crd) override {
    assert(randomAccessible());
    wrap->locate(b, l, SUBI(crd, padLow));

    // inPadZone = crd < padLow || crd >= size + padLow.
    Value inPadLow = CMPI(ult, crd, padLow);
    Value inPadHigh = CMPI(uge, crd, ADDI(wrap->upperBound(b, l), padLow));
    getMutCursorVals().back() = ORI(inPadLow, inPadHigh);

    updateCrd(crd);
  }

  Value padLow, padHigh;
};
684
685class NonEmptySubSectIterator : public SparseIterator {
686public:
687 using TraverseBuilder = llvm::function_ref<scf::ValueVector(
688 OpBuilder &, Location, const SparseIterator *, ValueRange)>;
689
690 NonEmptySubSectIterator(OpBuilder &b, Location l,
691 const SparseIterator *parent,
692 std::unique_ptr<SparseIterator> &&delegate,
693 Value subSectSz)
694 : SparseIterator(IterKind::kNonEmptySubSect, 3, subSectMeta, *delegate),
695 parent(parent), delegate(std::move(delegate)),
696 tupleSz(this->delegate->serialize().size()), subSectSz(subSectSz) {
697 auto *p = dyn_cast_or_null<NonEmptySubSectIterator>(Val: parent);
698 if (p == nullptr) {
699 // Extract subsections along the root level.
700 maxTupleCnt = C_IDX(1);
701 } else if (p->lvl == lvl) {
702 // Extract subsections along the same level.
703 maxTupleCnt = p->maxTupleCnt;
704 assert(false && "Not implemented.");
705 } else {
706 // Extract subsections along the previous level.
707 assert(p->lvl + 1 == lvl);
708 maxTupleCnt = MULI(p->maxTupleCnt, p->subSectSz);
709 }
710 // We don't need an extra buffer to find subsections on random-accessible
711 // levels.
712 if (randomAccessible())
713 return;
714 subSectPosBuf = allocSubSectPosBuf(b, l);
715 }
716
717 // For LLVM-style RTTI.
718 static bool classof(const SparseIterator *from) {
719 return from->kind == IterKind::kNonEmptySubSect;
720 }
721
722 std::string getDebugInterfacePrefix() const override {
723 return std::string("ne_sub<") + delegate->getDebugInterfacePrefix() + ">";
724 }
725 SmallVector<Type> getCursorValTypes(OpBuilder &b) const override {
726 // minCrd, absolute offset, notEnd
727 return {b.getIndexType(), b.getIndexType(), b.getI1Type()};
728 }
729
730 // The sliced pointer buffer is organized as:
731 // [[itVal0, itVal1, ..., pNx0],
732 // [itVal0, itVal1, ..., pNx0],
733 // ...]
734 Value allocSubSectPosBuf(OpBuilder &b, Location l) {
735 return b.create<memref::AllocaOp>(
736 l,
737 MemRefType::get({ShapedType::kDynamic, tupleSz + 1}, b.getIndexType()),
738 maxTupleCnt);
739 }
740
741 void storeNxLvlStart(OpBuilder &b, Location l, Value tupleId,
742 Value start) const {
743 b.create<memref::StoreOp>(l, start, subSectPosBuf,
744 ValueRange{tupleId, C_IDX(tupleSz)});
745 }
746
747 Value loadNxLvlStart(OpBuilder &b, Location l, Value tupleId) const {
748 return b.create<memref::LoadOp>(l, subSectPosBuf,
749 ValueRange{tupleId, C_IDX(tupleSz)});
750 }
751
752 void storeCursorVals(OpBuilder &b, Location l, Value tupleId,
753 ValueRange itVals) const {
754 assert(itVals.size() == tupleSz);
755 for (unsigned i = 0; i < tupleSz; i++) {
756 b.create<memref::StoreOp>(l, itVals[i], subSectPosBuf,
757 ValueRange{tupleId, C_IDX(i)});
758 }
759 }
760
761 SmallVector<Value> loadCursorVals(OpBuilder &b, Location l,
762 Value tupleId) const {
763 SmallVector<Value> ret;
764 for (unsigned i = 0; i < tupleSz; i++) {
765 Value v = b.create<memref::LoadOp>(l, subSectPosBuf,
766 ValueRange{tupleId, C_IDX(i)});
767 ret.push_back(Elt: v);
768 }
769 return ret;
770 }
771
772 bool isSubSectRoot() const {
773 return !parent || !llvm::isa<NonEmptySubSectIterator>(Val: parent);
774 }
775
776 // Generate code that inflate the current subsection tree till the current
777 // level such that every leaf node is visited.
778 ValueRange inflateSubSectTree(OpBuilder &b, Location l, ValueRange reduc,
779 TraverseBuilder builder) const;
780
781 bool isBatchIterator() const override { return delegate->isBatchIterator(); }
782 bool randomAccessible() const override {
783 return delegate->randomAccessible();
784 };
785 bool iteratableByFor() const override { return randomAccessible(); };
786 Value upperBound(OpBuilder &b, Location l) const override {
787 auto *p = dyn_cast_or_null<NonEmptySubSectIterator>(Val: parent);
788 Value parentUB =
789 p && p->lvl == lvl ? p->upperBound(b, l) : delegate->upperBound(b, l);
790 return ADDI(SUBI(parentUB, subSectSz), C_IDX(1));
791 };
792
793 void genInitImpl(OpBuilder &b, Location l, const SparseIterator *) override;
794
795 void locateImpl(OpBuilder &b, Location l, Value crd) override {
796 Value absOff = crd;
797
798 if (isSubSectRoot())
799 delegate->locate(b, l, crd: absOff);
800 else
801 assert(parent->lvl + 1 == lvl);
802
803 seek(vals: ValueRange{absOff, absOff, C_TRUE});
804 updateCrd(crd);
805 }
806
807 Value toSubSectCrd(OpBuilder &b, Location l, Value wrapCrd) const {
808 return SUBI(wrapCrd, getAbsOff());
809 }
810
811 Value genNotEndImpl(OpBuilder &b, Location l) override {
812 return getNotEnd();
813 };
814
815 Value derefImpl(OpBuilder &b, Location l) override {
816 // Use the relative offset to coiterate.
817 Value crd;
818 auto *p = dyn_cast_or_null<NonEmptySubSectIterator>(Val: parent);
819 if (p && p->lvl == lvl)
820 crd = SUBI(getAbsOff(), p->getAbsOff());
821 crd = getAbsOff();
822
823 updateCrd(crd);
824 return crd;
825 };
826
827 ValueRange forwardImpl(OpBuilder &b, Location l) override;
828
829 Value getMinCrd() const { return subSectMeta[0]; }
830 Value getAbsOff() const { return subSectMeta[1]; }
831 Value getNotEnd() const { return subSectMeta[2]; }
832
833 const SparseIterator *parent;
834 std::unique_ptr<SparseIterator> delegate;
835
836 // Number of values required to serialize the wrapped iterator.
837 const unsigned tupleSz;
838 // Max number of tuples, and the actual number of tuple.
839 Value maxTupleCnt, tupleCnt;
840 // The memory used to cache the tuple serialized from the wrapped iterator.
841 Value subSectPosBuf;
842
843 const Value subSectSz;
844
845 // minCrd, absolute offset, notEnd
846 SmallVector<Value, 3> subSectMeta{nullptr, nullptr, nullptr};
847};
848
849class SubSectIterator;
850
// A wrapper that helps generating code to traverse a subsection, used
// by both `NonEmptySubSectIterator` and `SubSectIterator`.
struct SubSectIterHelper {
  explicit SubSectIterHelper(const SubSectIterator &iter);
  explicit SubSectIterHelper(const NonEmptySubSectIterator &subSect);

  // Delegate methods.
  // Presumably restores `wrap`'s cursor from the tuple cached at `tupleId`
  // in the subsection position buffer (defined elsewhere in this file).
  void deserializeFromTupleId(OpBuilder &b, Location l, Value tupleId);
  void locate(OpBuilder &b, Location l, Value crd);
  Value genNotEnd(OpBuilder &b, Location l);
  Value deref(OpBuilder &b, Location l);
  ValueRange forward(OpBuilder &b, Location l);

  // The subsection being traversed and the iterator used to visit it.
  const NonEmptySubSectIterator &subSect;
  SparseIterator &wrap;
};
867
868class SubSectIterator : public SparseIterator {
869public:
870 SubSectIterator(const NonEmptySubSectIterator &subSect,
871 const SparseIterator &parent,
872 std::unique_ptr<SparseIterator> &&wrap)
873 : SparseIterator(IterKind::kSubSect, *wrap,
874 /*extraCursorCnt=*/wrap->randomAccessible() ? 0 : 1),
875 subSect(subSect), wrap(std::move(wrap)), parent(parent), helper(*this) {
876 assert(subSect.tid == tid && subSect.lvl == lvl);
877 assert(parent.kind != IterKind::kSubSect || parent.lvl + 1 == lvl);
878 };
879
880 // For LLVM-style RTTI.
881 static bool classof(const SparseIterator *from) {
882 return from->kind == IterKind::kSubSect;
883 }
884
885 std::string getDebugInterfacePrefix() const override {
886 return std::string("subsect<") + wrap->getDebugInterfacePrefix() + ">";
887 }
888 SmallVector<Type> getCursorValTypes(OpBuilder &b) const override {
889 SmallVector<Type> ret = wrap->getCursorValTypes(b);
890 if (!randomAccessible())
891 ret.push_back(b.getIndexType()); // The extra counter.
892 return ret;
893 }
894
895 bool isBatchIterator() const override { return wrap->isBatchIterator(); }
896 bool randomAccessible() const override { return wrap->randomAccessible(); };
897 bool iteratableByFor() const override { return randomAccessible(); };
898 Value upperBound(OpBuilder &b, Location l) const override {
899 return subSect.subSectSz;
900 }
901
902 ValueRange getCurPosition() const override { return wrap->getCurPosition(); };
903
904 Value getNxLvlTupleId(OpBuilder &b, Location l) const {
905 if (randomAccessible()) {
906 return ADDI(getCrd(), nxLvlTupleStart);
907 };
908 return ADDI(getCursor().back(), nxLvlTupleStart);
909 }
910
911 void genInitImpl(OpBuilder &b, Location l, const SparseIterator *) override {
912 if (randomAccessible()) {
913 if (auto *p = llvm::dyn_cast<SubSectIterator>(Val: &parent)) {
914 assert(p->lvl + 1 == lvl);
915 wrap->genInit(b, l, p);
916 // Linearize the dense subsection index.
917 nxLvlTupleStart = MULI(subSect.subSectSz, p->getNxLvlTupleId(b, l));
918 } else {
919 assert(subSect.lvl == lvl && subSect.isSubSectRoot());
920 wrap->deserialize(vs: subSect.delegate->serialize());
921 nxLvlTupleStart = C_IDX(0);
922 }
923 return;
924 }
925 assert(!randomAccessible());
926 assert(getCursor().size() == wrap->getCursor().size() + 1);
927 // Extra counter that counts the number of actually visited coordinates in
928 // the sparse subsection.
929 getMutCursorVals().back() = C_IDX(0);
930 Value tupleId;
931 if (auto *p = llvm::dyn_cast<SubSectIterator>(Val: &parent)) {
932 assert(p->lvl + 1 == lvl);
933 tupleId = p->getNxLvlTupleId(b, l);
934 } else {
935 assert(subSect.lvl == lvl && subSect.isSubSectRoot());
936 tupleId = C_IDX(0);
937 }
938 nxLvlTupleStart = subSect.loadNxLvlStart(b, l, tupleId);
939 helper.deserializeFromTupleId(b, l, tupleId);
940 }
941
942 void locateImpl(OpBuilder &b, Location l, Value crd) override {
943 helper.locate(b, l, crd);
944 updateCrd(crd);
945 }
946
947 Value genNotEndImpl(OpBuilder &b, Location l) override {
948 return helper.genNotEnd(b, l);
949 }
950
951 Value derefImpl(OpBuilder &b, Location l) override {
952 Value crd = helper.deref(b, l);
953 updateCrd(crd);
954 return crd;
955 };
956
957 ValueRange forwardImpl(OpBuilder &b, Location l) override {
958 helper.forward(b, l);
959 assert(!randomAccessible());
960 assert(getCursor().size() == wrap->getCursor().size() + 1);
961 getMutCursorVals().back() = ADDI(getCursor().back(), C_IDX(1));
962 return getCursor();
963 };
964
965 Value nxLvlTupleStart;
966
967 const NonEmptySubSectIterator &subSect;
968 std::unique_ptr<SparseIterator> wrap;
969 const SparseIterator &parent;
970
971 SubSectIterHelper helper;
972};
973
974} // namespace
975
976//===----------------------------------------------------------------------===//
977// SparseIterator derived classes implementation.
978//===----------------------------------------------------------------------===//
979
980void SparseIterator::genInit(OpBuilder &b, Location l,
981 const SparseIterator *p) {
982 if (emitStrategy == SparseEmitStrategy::kDebugInterface) {
983 std::string prefix = getDebugInterfacePrefix();
984 Operation *begin = b.create(l, b.getStringAttr(prefix + ".begin"), {},
985 getCursorValTypes(b));
986 seek(vals: begin->getResults());
987 return;
988 }
989 // Inherent batch coordinates from parents.
990 if (p)
991 inherentBatch(parent: *p);
992 // TODO: support lowering to function call.
993 return genInitImpl(b, l, p);
994}
995
996Value SparseIterator::genNotEnd(OpBuilder &b, Location l) {
997 if (emitStrategy == SparseEmitStrategy::kDebugInterface) {
998 std::string prefix = getDebugInterfacePrefix();
999 Operation *notEnd = b.create(l, b.getStringAttr(prefix + ".not_end"),
1000 getCursor(), b.getI1Type());
1001 return notEnd->getResult(idx: 0);
1002 }
1003 // TODO: support lowering to function call.
1004 return genNotEndImpl(b, l);
1005}
1006
1007void SparseIterator::locate(OpBuilder &b, Location l, Value crd) {
1008 if (emitStrategy == SparseEmitStrategy::kDebugInterface) {
1009 std::string prefix = getDebugInterfacePrefix();
1010 SmallVector<Value> args = getCursor();
1011 args.push_back(Elt: crd);
1012 Operation *locate = b.create(l, b.getStringAttr(prefix + ".locate"), args,
1013 getCursorValTypes(b));
1014 seek(vals: locate->getResults());
1015 updateCrd(crd);
1016 return;
1017 }
1018 return locateImpl(b, l, crd);
1019}
1020
1021Value SparseIterator::deref(OpBuilder &b, Location l) {
1022 if (emitStrategy == SparseEmitStrategy::kDebugInterface) {
1023 std::string prefix = getDebugInterfacePrefix();
1024 SmallVector<Value> args = getCursor();
1025 Operation *deref = b.create(l, b.getStringAttr(prefix + ".deref"),
1026 getCursor(), b.getIndexType());
1027 updateCrd(crd: deref->getResult(idx: 0));
1028 return getCrd();
1029 }
1030 return derefImpl(b, l);
1031}
1032
1033ValueRange SparseIterator::forward(OpBuilder &b, Location l) {
1034 assert(!randomAccessible());
1035 if (emitStrategy == SparseEmitStrategy::kDebugInterface) {
1036 std::string prefix = getDebugInterfacePrefix();
1037 Operation *next = b.create(l, b.getStringAttr(prefix + ".next"),
1038 getCursor(), getCursorValTypes(b));
1039 seek(vals: next->getResults());
1040 return getCursor();
1041 }
1042 return forwardImpl(b, l);
1043}
1044
1045ValueRange SparseIterator::forwardIf(OpBuilder &b, Location l, Value cond) {
1046 auto ifOp = b.create<scf::IfOp>(l, getCursor().getTypes(), cond, true);
1047 // Generate else branch first, otherwise iterator values will be updated by
1048 // `forward()`.
1049 b.setInsertionPointToStart(ifOp.elseBlock());
1050 YIELD(getCursor());
1051
1052 b.setInsertionPointToStart(ifOp.thenBlock());
1053 YIELD(forward(b, l));
1054
1055 b.setInsertionPointAfter(ifOp);
1056 seek(vals: ifOp.getResults());
1057 return getCursor();
1058}
1059
// Generates code that computes the segment high for the duplicated
// coordinate starting at `pos`, i.e., the first position past the run of
// equal coordinates, by scanning forward with an scf.while loop.
Value DedupIterator::genSegmentHigh(OpBuilder &b, Location l, Value pos) {
  auto whileOp = b.create<scf::WhileOp>(
      l, pos.getType(), pos,
      /*beforeBuilder=*/
      [this, pos](OpBuilder &b, Location l, ValueRange ivs) {
        // Continue scanning while in bounds and the coordinate at the current
        // position equals the coordinate at the segment start.
        Value inBound = CMPI(ult, ivs.front(), posHi);
        auto ifInBound = b.create<scf::IfOp>(l, b.getI1Type(), inBound, true);
        {
          OpBuilder::InsertionGuard guard(b);
          // If in bound, load the next coordinates and check duplication.
          b.setInsertionPointToStart(ifInBound.thenBlock());
          Value headCrd = stl.peekCrdAt(b, l, getBatchCrds(), pos);
          Value tailCrd = stl.peekCrdAt(b, l, getBatchCrds(), ivs.front());
          Value isDup = CMPI(eq, headCrd, tailCrd);
          YIELD(isDup);
          // Else, the position is out of bound, yield false.
          b.setInsertionPointToStart(ifInBound.elseBlock());
          YIELD(constantI1(b, l, false));
        }
        b.create<scf::ConditionOp>(l, ifInBound.getResults()[0], ivs);
      },
      /*afterBuilder=*/
      [](OpBuilder &b, Location l, ValueRange ivs) {
        // Advance to the next position and re-check.
        Value nxPos = ADDI(ivs[0], C_IDX(1));
        YIELD(nxPos);
      });
  // Return the segment high.
  return whileOp.getResult(0);
}
1089
1090Value FilterIterator::genCrdNotLegitPredicate(OpBuilder &b, Location l,
1091 Value wrapCrd) {
1092 Value crd = fromWrapCrd(b, l, wrapCrd);
1093 // Test whether the coordinate is on stride.
1094 Value notlegit = CMPI(ne, toWrapCrd(b, l, crd), wrapCrd);
1095 // Test wrapCrd < offset
1096 notlegit = ORI(CMPI(ult, wrapCrd, offset), notlegit);
1097 // Test crd >= length
1098 notlegit = ORI(CMPI(uge, crd, size), notlegit);
1099 return notlegit;
1100}
1101
1102Value FilterIterator::genShouldFilter(OpBuilder &b, Location l) {
1103 auto r = genWhenInBound(
1104 b, l, it&: *wrap, C_FALSE,
1105 builder: [this](OpBuilder &b, Location l, Value wrapCrd) -> scf::ValueVector {
1106 Value notLegit = genCrdNotLegitPredicate(b, l, wrapCrd);
1107 return {notLegit};
1108 });
1109 return llvm::getSingleElement(C&: r);
1110}
1111
1112Value FilterIterator::genNotEndImpl(OpBuilder &b, Location l) {
1113 assert(!wrap->randomAccessible());
1114 auto r = genWhenInBound(
1115 b, l, it&: *wrap, C_FALSE,
1116 builder: [this](OpBuilder &b, Location l, Value wrapCrd) -> scf::ValueVector {
1117 Value crd = fromWrapCrd(b, l, wrapCrd);
1118 // crd < size
1119 return {CMPI(ult, crd, size)};
1120 });
1121 return llvm::getSingleElement(C&: r);
1122}
1123
1124ValueRange FilterIterator::forwardImpl(OpBuilder &b, Location l) {
1125 assert(!randomAccessible());
1126 // Generates
1127 //
1128 // bool isFirst = true;
1129 // while !it.end() && (!legit(*it) || isFirst)
1130 // wrap ++;
1131 // isFirst = false;
1132 //
1133 // We do not hoist the first `wrap++` outside the loop but use a `isFirst`
1134 // flag here because `wrap++` might have a complex implementation (e.g., to
1135 // forward a subsection).
1136 Value isFirst = constantI1(builder&: b, loc: l, b: true);
1137
1138 SmallVector<Value> whileArgs(getCursor().begin(), getCursor().end());
1139 whileArgs.push_back(Elt: isFirst);
1140 auto whileOp = b.create<scf::WhileOp>(
1141 l, ValueRange(whileArgs).getTypes(), whileArgs,
1142 /*beforeBuilder=*/
1143 [this](OpBuilder &b, Location l, ValueRange ivs) {
1144 ValueRange isFirst = linkNewScope(ivs);
1145 scf::ValueVector cont =
1146 genWhenInBound(b, l, *wrap, C_FALSE,
1147 [this, isFirst](OpBuilder &b, Location l,
1148 Value wrapCrd) -> scf::ValueVector {
1149 // crd < size && !legit();
1150 Value notLegit =
1151 genCrdNotLegitPredicate(b, l, wrapCrd);
1152 Value crd = fromWrapCrd(b, l, wrapCrd);
1153 Value ret = ANDI(CMPI(ult, crd, size), notLegit);
1154 ret = ORI(ret, llvm::getSingleElement(isFirst));
1155 return {ret};
1156 });
1157 b.create<scf::ConditionOp>(l, cont.front(), ivs);
1158 },
1159 /*afterBuilder=*/
1160 [this](OpBuilder &b, Location l, ValueRange ivs) {
1161 linkNewScope(ivs);
1162 wrap->forward(b, l);
1163 SmallVector<Value> yieldVals(getCursor().begin(), getCursor().end());
1164 yieldVals.push_back(constantI1(b, l, false));
1165 YIELD(yieldVals);
1166 });
1167
1168 b.setInsertionPointAfter(whileOp);
1169 linkNewScope(pos: whileOp.getResults());
1170 return getCursor();
1171}
1172
// Constructs a helper that traverses the subsection's own delegate iterator.
SubSectIterHelper::SubSectIterHelper(const NonEmptySubSectIterator &subSect)
    : subSect(subSect), wrap(*subSect.delegate) {}

// Constructs a helper that traverses the iterator wrapped by a
// `SubSectIterator`.
SubSectIterHelper::SubSectIterHelper(const SubSectIterator &iter)
    : subSect(iter.subSect), wrap(*iter.wrap) {}
1178
1179void SubSectIterHelper::deserializeFromTupleId(OpBuilder &b, Location l,
1180 Value tupleId) {
1181 assert(!subSect.randomAccessible());
1182 wrap.deserialize(vs: subSect.loadCursorVals(b, l, tupleId));
1183}
1184
1185void SubSectIterHelper::locate(OpBuilder &b, Location l, Value crd) {
1186 Value absCrd = ADDI(crd, subSect.getAbsOff());
1187 wrap.locate(b, l, crd: absCrd);
1188}
1189
1190Value SubSectIterHelper::genNotEnd(OpBuilder &b, Location l) {
1191 assert(!wrap.randomAccessible());
1192 auto r = genWhenInBound(
1193 b, l, it&: wrap, C_FALSE,
1194 builder: [this](OpBuilder &b, Location l, Value wrapCrd) -> scf::ValueVector {
1195 Value crd = SUBI(wrapCrd, subSect.getAbsOff());
1196 // crd < size
1197 return {CMPI(ult, crd, subSect.subSectSz)};
1198 });
1199 return llvm::getSingleElement(C&: r);
1200}
1201
1202Value SubSectIterHelper::deref(OpBuilder &b, Location l) {
1203 Value wrapCrd = wrap.deref(b, l);
1204 Value crd = subSect.toSubSectCrd(b, l, wrapCrd);
1205 return crd;
1206}
1207
// Forwards the wrapped iterator; the subsection-local coordinate translation
// happens at dereference time.
ValueRange SubSectIterHelper::forward(OpBuilder &b, Location l) {
  return wrap.forward(b, l);
}
1211
1212ValueRange NonEmptySubSectIterator::inflateSubSectTree(
1213 OpBuilder &b, Location l, ValueRange reduc, TraverseBuilder builder) const {
1214 // Set up the helper to help traverse a sparse subsection.
1215 SubSectIterHelper helper(*this);
1216 if (!randomAccessible()) {
1217 // The subsection tree have been expanded till the level and cached,
1218 // traverse all the leaves and expanded to the next level.
1219 SmallVector<Value> iterArgs;
1220 iterArgs.push_back(C_IDX(0));
1221 iterArgs.append(in_start: reduc.begin(), in_end: reduc.end());
1222 auto forEachLeaf = b.create<scf::ForOp>(
1223 l, /*lb=*/C_IDX(0), /*ub=*/tupleCnt, /*step=*/C_IDX(1), iterArgs,
1224 [&helper, &builder](OpBuilder &b, Location l, Value tupleId,
1225 ValueRange iterArgs) {
1226 // Deserialize the iterator at the cached position (tupleId).
1227 helper.deserializeFromTupleId(b, l, tupleId);
1228
1229 Value cnt = iterArgs.front();
1230 // Record the number of leaf nodes included in the subsection.
1231 // The number indicates the starting tupleId for the next level that
1232 // is corresponding to the current node.
1233 helper.subSect.storeNxLvlStart(b, l, tupleId, cnt);
1234
1235 SmallVector<Value> whileArgs(helper.wrap.getCursor());
1236 whileArgs.append(iterArgs.begin(), iterArgs.end());
1237
1238 auto whileOp = b.create<scf::WhileOp>(
1239 l, ValueRange(whileArgs).getTypes(), whileArgs,
1240 /*beforeBuilder=*/
1241 [&helper](OpBuilder &b, Location l, ValueRange ivs) {
1242 helper.wrap.linkNewScope(ivs);
1243 b.create<scf::ConditionOp>(l, helper.genNotEnd(b, l), ivs);
1244 },
1245 /*afterBuilder=*/
1246 [&helper, &builder](OpBuilder &b, Location l, ValueRange ivs) {
1247 ValueRange remIter = helper.wrap.linkNewScope(ivs);
1248 Value cnt = remIter.front();
1249 ValueRange userIter = remIter.drop_front();
1250 scf::ValueVector userNx = builder(b, l, &helper.wrap, userIter);
1251
1252 SmallVector<Value> nxIter = helper.forward(b, l);
1253 nxIter.push_back(ADDI(cnt, C_IDX(1)));
1254 nxIter.append(userNx.begin(), userNx.end());
1255 YIELD(nxIter);
1256 });
1257 ValueRange res = helper.wrap.linkNewScope(whileOp.getResults());
1258 YIELD(res);
1259 });
1260 return forEachLeaf.getResults().drop_front();
1261 }
1262
1263 assert(randomAccessible());
1264 // Helper lambda that traverse the current dense subsection range.
1265 auto visitDenseSubSect = [&, this](OpBuilder &b, Location l,
1266 const SparseIterator *parent,
1267 ValueRange reduc) {
1268 assert(!parent || parent->lvl + 1 == lvl);
1269 delegate->genInit(b, l, p: parent);
1270 auto forOp = b.create<scf::ForOp>(
1271 l, /*lb=*/C_IDX(0), /*ub=*/subSectSz, /*step=*/C_IDX(1), reduc,
1272 [&](OpBuilder &b, Location l, Value crd, ValueRange iterArgs) {
1273 helper.locate(b, l, crd);
1274 scf::ValueVector nx = builder(b, l, &helper.wrap, iterArgs);
1275 YIELD(nx);
1276 });
1277 return forOp.getResults();
1278 };
1279
1280 if (isSubSectRoot()) {
1281 return visitDenseSubSect(b, l, parent, reduc);
1282 }
1283 // Else, this is not the root, recurse until root.
1284 auto *p = llvm::cast<NonEmptySubSectIterator>(Val: parent);
1285 assert(p->lvl + 1 == lvl);
1286 return p->inflateSubSectTree(b, l, reduc, visitDenseSubSect);
1287}
1288
1289void TrivialIterator::genInitImpl(OpBuilder &b, Location l,
1290 const SparseIterator *parent) {
1291
1292 if (isBatchIterator() && batchCrds.size() <= stl.lvl)
1293 batchCrds.resize(N: stl.lvl + 1, NV: nullptr);
1294
1295 Value c0 = C_IDX(0);
1296 ValueRange pPos = c0;
1297 Value inPadZone = nullptr;
1298 // If the parent iterator is a batch iterator, we also start from 0 (but
1299 // on a different batch).
1300 if (parent && !parent->isBatchIterator()) {
1301 pPos = parent->getCurPosition();
1302 if (llvm::isa<PadIterator>(Val: parent) && parent->randomAccessible()) {
1303 // A padded dense iterator create "sparse" padded zone, which need to be
1304 // handled specially.
1305 inPadZone = pPos.back();
1306 pPos = pPos.drop_back();
1307 }
1308 }
1309
1310 ValueRange batchPrefix = parent ? parent->getBatchCrds() : ValueRange{};
1311 std::tie(args&: posLo, args&: posHi) = stl.peekRangeAt(b, l, batchPrefix, parentPos: pPos, inPadZone);
1312 // Seek to the lowest position.
1313 seek(vals: posLo);
1314}
1315
1316void NonEmptySubSectIterator::genInitImpl(OpBuilder &b, Location l,
1317 const SparseIterator *) {
1318 Value c0 = C_IDX(0);
1319 if (!isSubSectRoot()) {
1320 assert(parent->lvl + 1 == lvl);
1321 if (randomAccessible()) {
1322 // We can not call wrap->genInit() here to initialize the wrapped
1323 // iterator, because the parent of the curent iterator is still
1324 // unresolved.
1325 seek(vals: {/*minCrd=*/c0, /*offset=*/c0, /*notEnd=*/C_TRUE});
1326 return;
1327 }
1328
1329 auto *p = cast<NonEmptySubSectIterator>(Val: parent);
1330 SmallVector<Value, 3> reduc = {
1331 C_IDX(-1), // minCrd (max signless integer)
1332 c0, // tupleId
1333 };
1334
1335 // Expand the subsection tree from the parent level to the current level.
1336 ValueRange result = p->inflateSubSectTree(
1337 b, l, reduc,
1338 builder: [this](OpBuilder &b, Location l, const SparseIterator *parent,
1339 ValueRange reduc) -> scf::ValueVector {
1340 assert(parent->lvl + 1 == lvl && reduc.size() == 2);
1341 Value minCrd = reduc.front();
1342 Value tupleId = reduc.back();
1343
1344 // Initialize the subsection range.
1345 SubSectIterHelper helper(*this);
1346 helper.wrap.genInit(b, l, p: parent);
1347
1348 // Update minCrd.
1349 minCrd = genWhenInBound(b, l, it&: helper.wrap, elseRet: minCrd,
1350 builder: [minCrd](OpBuilder &b, Location l,
1351 Value crd) -> scf::ValueVector {
1352 Value min = MINUI(crd, minCrd);
1353 return {min};
1354 })
1355 .front();
1356
1357 // Cache the sparse range.
1358 storeCursorVals(b, l, tupleId, itVals: helper.wrap.serialize());
1359 tupleId = ADDI(tupleId, C_IDX(1));
1360 return {minCrd, tupleId};
1361 });
1362 assert(result.size() == 2);
1363 tupleCnt = result.back();
1364
1365 Value minCrd = result.front();
1366 Value absOff = offsetFromMinCrd(b, l, minCrd, size: subSectSz);
1367 Value notEnd = CMPI(ne, minCrd, C_IDX(-1));
1368 seek(vals: {minCrd, absOff, notEnd});
1369 return;
1370 }
1371
1372 // This is the root level of the subsection, which means that it is resolved
1373 // to one node.
1374 assert(isSubSectRoot());
1375
1376 // Initialize the position, the position marks the *lower bound* of the
1377 // subRange. The higher bound is determined by the size of the subsection.
1378 delegate->genInit(b, l, p: parent);
1379 if (randomAccessible()) {
1380 seek(vals: {/*minCrd=*/c0, /*offset=*/c0, /*notEnd=*/C_TRUE});
1381 return;
1382 }
1383
1384 // Only have one root node.
1385 tupleCnt = C_IDX(1);
1386 // Cache the sparse range.
1387 storeCursorVals(b, l, tupleId: c0, itVals: delegate->serialize());
1388 SmallVector<Value> elseRet{c0, c0, /*notEnd=*/C_FALSE};
1389 auto meta = genWhenInBound(
1390 b, l, it&: *delegate, elseRet,
1391 builder: [this](OpBuilder &b, Location l, Value crd) -> scf::ValueVector {
1392 Value offset = offsetFromMinCrd(b, l, minCrd: crd, size: subSectSz);
1393 return {crd, offset, C_TRUE};
1394 });
1395
1396 seek(vals: meta);
1397}
1398
1399ValueRange NonEmptySubSectIterator::forwardImpl(OpBuilder &b, Location l) {
1400 assert(!randomAccessible());
1401 Value c0 = C_IDX(0), c1 = C_IDX(1);
1402 // Forward to the next non empty slice by generating
1403 //
1404 // if (minCrd > offset) {
1405 // offset += 1
1406 // } else {
1407 // minCrd = nextMinInSlice();
1408 // offset = minCrd - size + 1;
1409 // }
1410 //
1411 // if (offset + size > parents.size)
1412 // isNonEmpty = false;
1413 Value fastPathP = CMPI(ugt, getMinCrd(), getAbsOff());
1414 auto ifOp = b.create<scf::IfOp>(l, getCursor().getTypes(), fastPathP, true);
1415 {
1416 OpBuilder::InsertionGuard guard(b);
1417 // Take the fast path
1418 // if (minCrd > offset)
1419 // offset += 1
1420 b.setInsertionPointToStart(&ifOp.getThenRegion().front());
1421 Value nxOffset = ADDI(getAbsOff(), c1);
1422 YIELD((ValueRange{getMinCrd(), nxOffset, getNotEnd()}));
1423
1424 // else /*minCrd == offset*/ {
1425 // for (i = 0; i < tupleCnt; i++) {
1426 // wrap->deserialize(pos[i]);
1427 // minCrd=min(minCrd, *wrap);
1428 // }
1429 // offset = minCrd - size + 1;
1430 // }
1431 b.setInsertionPointToStart(&ifOp.getElseRegion().front());
1432 SmallVector<Value, 2> loopArgs{C_IDX(-1), // nextMinCrd
1433 C_FALSE}; // isNotEnd
1434 auto loopNest = scf::buildLoopNest(
1435 builder&: b, loc: l, lbs: c0, ubs: tupleCnt, steps: c1, iterArgs: loopArgs,
1436 bodyBuilder: [this](OpBuilder &b, Location l, ValueRange ivs,
1437 ValueRange iterArgs) -> scf::ValueVector {
1438 Value tupleId = ivs.front();
1439 SubSectIterHelper helper(*this);
1440 helper.deserializeFromTupleId(b, l, tupleId);
1441
1442 return genWhenInBound(
1443 b, l, it&: *delegate, /*elseRet=*/iterArgs,
1444 builder: [this, iterArgs, tupleId](OpBuilder &b, Location l,
1445 Value crd) -> scf::ValueVector {
1446 // if coord == minCrd
1447 // wrap->forward();
1448 Value isMin = CMPI(eq, crd, getMinCrd());
1449 delegate->forwardIf(b, l, cond: isMin);
1450 // Update the forwarded iterator values if needed.
1451 auto ifIsMin = b.create<scf::IfOp>(l, isMin, false);
1452 b.setInsertionPointToStart(&ifIsMin.getThenRegion().front());
1453 storeCursorVals(b, l, tupleId, itVals: delegate->serialize());
1454 b.setInsertionPointAfter(ifIsMin);
1455 // if (!wrap.end())
1456 // yield(min(nxMinCrd, *wrap), true)
1457 Value nxMin = iterArgs[0];
1458 return genWhenInBound(b, l, it&: *delegate, /*elseRet=*/iterArgs,
1459 builder: [nxMin](OpBuilder &b, Location l,
1460 Value crd) -> scf::ValueVector {
1461 Value nx = b.create<arith::MinUIOp>(
1462 l, crd, nxMin);
1463 return {nx, C_TRUE};
1464 });
1465 });
1466 });
1467
1468 scf::ForOp forOp = loopNest.loops.front();
1469 b.setInsertionPointAfter(forOp);
1470
1471 Value nxMinCrd = forOp.getResult(0);
1472 Value nxNotEnd = forOp.getResult(1);
1473 Value nxAbsOff = offsetFromMinCrd(b, l, minCrd: nxMinCrd, size: subSectSz);
1474 YIELD((ValueRange{nxMinCrd, nxAbsOff, nxNotEnd}));
1475 }
1476
1477 Value nxMinCrd = ifOp.getResult(0);
1478 Value nxAbsOff = ifOp.getResult(1);
1479 Value nxNotEnd = ifOp.getResult(2);
1480
1481 // We should at least forward the offset by one.
1482 Value minAbsOff = ADDI(getAbsOff(), c1);
1483 nxAbsOff = b.create<arith::MaxUIOp>(l, minAbsOff, nxAbsOff);
1484
1485 seek(vals: ValueRange{nxMinCrd, nxAbsOff, nxNotEnd});
1486 // The coordinate should not exceeds the space upper bound.
1487 Value crd = deref(b, l);
1488 nxNotEnd = ANDI(nxNotEnd, CMPI(ult, crd, upperBound(b, l)));
1489
1490 seek(vals: ValueRange{nxMinCrd, nxAbsOff, nxNotEnd});
1491 return getCursor();
1492}
1493
1494//===----------------------------------------------------------------------===//
1495// SparseIterationSpace Implementation
1496//===----------------------------------------------------------------------===//
1497
1498mlir::sparse_tensor::SparseIterationSpace::SparseIterationSpace(
1499 Location l, OpBuilder &b, Value t, unsigned tid,
1500 std::pair<Level, Level> lvlRange, ValueRange parentPos)
1501 : lvls() {
1502 auto [lvlLo, lvlHi] = lvlRange;
1503
1504 Value c0 = C_IDX(0);
1505 if (parentPos.empty())
1506 parentPos = c0;
1507
1508 for (Level lvl = lvlLo; lvl < lvlHi; lvl++)
1509 lvls.emplace_back(Args: makeSparseTensorLevel(b, l, t, tid, lvl));
1510
1511 bound = lvls.front()->peekRangeAt(b, l, /*batchPrefix=*/{}, parentPos);
1512 for (auto &lvl : getLvlRef().drop_front())
1513 bound = lvl->collapseRangeBetween(b, l, /*batchPrefix=*/{}, parentRange: bound);
1514}
1515
1516SparseIterationSpace mlir::sparse_tensor::SparseIterationSpace::fromValues(
1517 IterSpaceType dstTp, ValueRange values, unsigned int tid) {
1518 // Reconstruct every sparse tensor level.
1519 SparseIterationSpace space;
1520 for (auto [i, lt] : llvm::enumerate(dstTp.getLvlTypes())) {
1521 unsigned bufferCnt = 0;
1522 if (lt.isWithPosLT())
1523 bufferCnt++;
1524 if (lt.isWithCrdLT())
1525 bufferCnt++;
1526 // Sparse tensor buffers.
1527 ValueRange buffers = values.take_front(bufferCnt);
1528 values = values.drop_front(bufferCnt);
1529
1530 // Level size.
1531 Value sz = values.front();
1532 values = values.drop_front();
1533 space.lvls.push_back(
1534 makeSparseTensorLevel(lt, sz, buffers, tid, i + dstTp.getLoLvl()));
1535 }
1536 // Two bounds.
1537 space.bound = std::make_pair(x: values[0], y: values[1]);
1538 values = values.drop_front(n: 2);
1539
1540 // Must have consumed all values.
1541 assert(values.empty());
1542 return space;
1543}
1544
1545std::unique_ptr<SparseIterator>
1546SparseIterationSpace::extractIterator(OpBuilder &b, Location l) const {
1547 return makeSimpleIterator(b, l, iterSpace: *this);
1548}
1549
1550//===----------------------------------------------------------------------===//
1551// SparseIterator factory functions.
1552//===----------------------------------------------------------------------===//
1553
1554/// Helper function to create a TensorLevel object from given `tensor`.
1555std::unique_ptr<SparseTensorLevel>
1556sparse_tensor::makeSparseTensorLevel(LevelType lt, Value sz, ValueRange b,
1557 unsigned t, Level l) {
1558 assert(lt.getNumBuffer() == b.size());
1559 switch (lt.getLvlFmt()) {
1560 case LevelFormat::Dense:
1561 return std::make_unique<DenseLevel>(args&: t, args&: l, args&: sz);
1562 case LevelFormat::Batch:
1563 return std::make_unique<BatchLevel>(args&: t, args&: l, args&: sz);
1564 case LevelFormat::Compressed:
1565 return std::make_unique<CompressedLevel>(args&: t, args&: l, args&: lt, args&: sz, args: b[0], args: b[1]);
1566 case LevelFormat::LooseCompressed:
1567 return std::make_unique<LooseCompressedLevel>(args&: t, args&: l, args&: lt, args&: sz, args: b[0], args: b[1]);
1568 case LevelFormat::Singleton:
1569 return std::make_unique<SingletonLevel>(args&: t, args&: l, args&: lt, args&: sz, args: b[0]);
1570 case LevelFormat::NOutOfM:
1571 return std::make_unique<NOutOfMLevel>(args&: t, args&: l, args&: lt, args&: sz, args: b[0]);
1572 case LevelFormat::Undef:
1573 llvm_unreachable("undefined level format");
1574 }
1575 llvm_unreachable("unrecognizable level format");
1576}
1577
1578std::unique_ptr<SparseTensorLevel>
1579sparse_tensor::makeSparseTensorLevel(OpBuilder &b, Location l, Value t,
1580 unsigned tid, Level lvl) {
1581 auto stt = getSparseTensorType(val: t);
1582
1583 LevelType lt = stt.getLvlType(l: lvl);
1584 Value sz = stt.hasEncoding() ? b.create<LvlOp>(l, t, lvl).getResult()
1585 : b.create<tensor::DimOp>(l, t, lvl).getResult();
1586
1587 SmallVector<Value, 2> buffers;
1588 if (lt.isWithPosLT()) {
1589 Value pos = b.create<ToPositionsOp>(l, t, lvl);
1590 buffers.push_back(Elt: pos);
1591 }
1592 if (lt.isWithCrdLT()) {
1593 Value pos = b.create<ToCoordinatesOp>(l, t, lvl);
1594 buffers.push_back(Elt: pos);
1595 }
1596 return makeSparseTensorLevel(lt, sz, b: buffers, t: tid, l: lvl);
1597}
1598
1599std::pair<std::unique_ptr<SparseTensorLevel>, std::unique_ptr<SparseIterator>>
1600sparse_tensor::makeSynLevelAndIterator(Value sz, unsigned tid, unsigned lvl,
1601 SparseEmitStrategy strategy) {
1602 auto stl = std::make_unique<BatchLevel>(args&: tid, args&: lvl, args&: sz);
1603 auto it = std::make_unique<TrivialIterator>(args&: *stl);
1604 it->setSparseEmitStrategy(strategy);
1605 return std::make_pair(x: std::move(stl), y: std::move(it));
1606}
1607
1608std::unique_ptr<SparseIterator>
1609sparse_tensor::makeSimpleIterator(OpBuilder &b, Location l,
1610 const SparseIterationSpace &iterSpace) {
1611 // assert(iterSpace.getSpaceDim() == 1);
1612 std::unique_ptr<SparseIterator> ret;
1613 if (!iterSpace.isUnique()) {
1614 // We always dedupliate the non-unique level, but we should optimize it away
1615 // if possible.
1616 ret = std::make_unique<DedupIterator>(args&: b, args&: l, args: iterSpace.getLastLvl(),
1617 args: iterSpace.getBoundLo(),
1618 args: iterSpace.getBoundHi());
1619 } else {
1620 ret = std::make_unique<TrivialIterator>(args&: b, args&: l, args: iterSpace.getLastLvl(),
1621 args: iterSpace.getBoundLo(),
1622 args: iterSpace.getBoundHi());
1623 }
1624 ret->setSparseEmitStrategy(SparseEmitStrategy::kFunctional);
1625 return ret;
1626}
1627
1628std::unique_ptr<SparseIterator>
1629sparse_tensor::makeSimpleIterator(const SparseTensorLevel &stl,
1630 SparseEmitStrategy strategy) {
1631 std::unique_ptr<SparseIterator> ret;
1632 if (!isUniqueLT(lt: stl.getLT())) {
1633 // We always dedupliate the non-unique level, but we should optimize it away
1634 // if possible.
1635 ret = std::make_unique<DedupIterator>(args: stl);
1636 } else {
1637 ret = std::make_unique<TrivialIterator>(args: stl);
1638 }
1639 ret->setSparseEmitStrategy(strategy);
1640 return ret;
1641}
1642
1643std::unique_ptr<SparseIterator>
1644sparse_tensor::makeSlicedLevelIterator(std::unique_ptr<SparseIterator> &&sit,
1645 Value offset, Value stride, Value size,
1646 SparseEmitStrategy strategy) {
1647
1648 auto ret =
1649 std::make_unique<FilterIterator>(args: std::move(sit), args&: offset, args&: stride, args&: size);
1650 ret->setSparseEmitStrategy(strategy);
1651 return ret;
1652}
1653
1654std::unique_ptr<SparseIterator>
1655sparse_tensor::makePaddedIterator(std::unique_ptr<SparseIterator> &&sit,
1656 Value padLow, Value padHigh,
1657 SparseEmitStrategy strategy) {
1658 auto ret = std::make_unique<PadIterator>(args: std::move(sit), args&: padLow, args&: padHigh);
1659 ret->setSparseEmitStrategy(strategy);
1660 return ret;
1661}
1662
1663static const SparseIterator *tryUnwrapFilter(const SparseIterator *it) {
1664 auto *filter = llvm::dyn_cast_or_null<FilterIterator>(Val: it);
1665 if (filter)
1666 return &filter->getWrappedIterator();
1667 return it;
1668}
1669
1670std::unique_ptr<SparseIterator> sparse_tensor::makeNonEmptySubSectIterator(
1671 OpBuilder &b, Location l, const SparseIterator *parent, Value loopBound,
1672 std::unique_ptr<SparseIterator> &&delegate, Value size, unsigned stride,
1673 SparseEmitStrategy strategy) {
1674
1675 // Try unwrap the NonEmptySubSectIterator from a filter parent.
1676 parent = tryUnwrapFilter(it: parent);
1677 std::unique_ptr<SparseIterator> it =
1678 std::make_unique<NonEmptySubSectIterator>(args&: b, args&: l, args&: parent,
1679 args: std::move(delegate), args&: size);
1680
1681 if (stride != 1) {
1682 // TODO: We can safely skip bound checking on sparse levels, but for dense
1683 // iteration space, we need the bound to infer the dense loop range.
1684 it = std::make_unique<FilterIterator>(args: std::move(it), /*offset=*/C_IDX(0),
1685 C_IDX(stride), /*size=*/args&: loopBound);
1686 }
1687 it->setSparseEmitStrategy(strategy);
1688 return it;
1689}
1690
1691std::unique_ptr<SparseIterator> sparse_tensor::makeTraverseSubSectIterator(
1692 OpBuilder &b, Location l, const SparseIterator &subSectIter,
1693 const SparseIterator &parent, std::unique_ptr<SparseIterator> &&wrap,
1694 Value loopBound, unsigned stride, SparseEmitStrategy strategy) {
1695
1696 // This must be a subsection iterator or a filtered subsection iterator.
1697 auto &subSect =
1698 llvm::cast<NonEmptySubSectIterator>(Val: *tryUnwrapFilter(it: &subSectIter));
1699
1700 std::unique_ptr<SparseIterator> it = std::make_unique<SubSectIterator>(
1701 args: subSect, args: *tryUnwrapFilter(it: &parent), args: std::move(wrap));
1702
1703 if (stride != 1) {
1704 it = std::make_unique<FilterIterator>(args: std::move(it), /*offset=*/C_IDX(0),
1705 C_IDX(stride), /*size=*/args&: loopBound);
1706 }
1707 it->setSparseEmitStrategy(strategy);
1708 return it;
1709}
1710
// Undefine all file-local helper macros (keep in sync with the definitions
// at the top of this file).
#undef CMPI
#undef C_FALSE
#undef C_TRUE
#undef C_IDX
#undef YIELD
#undef ADDI
#undef ORI
#undef ANDI
#undef SUBI
#undef MULI
#undef MINUI
#undef REMUI
#undef DIVUI
#undef SELECT
1719

// Source: mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.cpp