1 | //===- SparseTensorIterator.cpp -------------------------------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "SparseTensorIterator.h" |
10 | #include "CodegenUtils.h" |
11 | |
12 | #include "mlir/Dialect/MemRef/IR/MemRef.h" |
13 | #include "mlir/Dialect/SCF/IR/SCF.h" |
14 | #include "mlir/Dialect/Tensor/IR/Tensor.h" |
15 | |
16 | using namespace mlir; |
17 | using namespace mlir::sparse_tensor; |
18 | using ValuePair = std::pair<Value, Value>; |
19 | using ValueTuple = std::tuple<Value, Value, Value>; |
20 | |
21 | //===----------------------------------------------------------------------===// |
22 | // File local helper functions/macros. |
23 | //===----------------------------------------------------------------------===// |
24 | #define CMPI(p, lhs, rhs) \ |
25 | (b.create<arith::CmpIOp>(l, arith::CmpIPredicate::p, (lhs), (rhs)) \ |
26 | .getResult()) |
27 | |
28 | #define C_FALSE (constantI1(b, l, false)) |
29 | #define C_TRUE (constantI1(b, l, true)) |
30 | #define C_IDX(v) (constantIndex(b, l, (v))) |
31 | #define YIELD(vs) (b.create<scf::YieldOp>(l, (vs))) |
32 | #define ADDI(lhs, rhs) (b.create<arith::AddIOp>(l, (lhs), (rhs)).getResult()) |
33 | #define ORI(lhs, rhs) (b.create<arith::OrIOp>(l, (lhs), (rhs)).getResult()) |
34 | #define ANDI(lhs, rhs) (b.create<arith::AndIOp>(l, (lhs), (rhs)).getResult()) |
35 | #define SUBI(lhs, rhs) (b.create<arith::SubIOp>(l, (lhs), (rhs)).getResult()) |
36 | #define MULI(lhs, rhs) (b.create<arith::MulIOp>(l, (lhs), (rhs)).getResult()) |
37 | #define MINUI(lhs, rhs) (b.create<arith::MinUIOp>(l, (lhs), (rhs)).getResult()) |
38 | #define REMUI(lhs, rhs) (b.create<arith::RemUIOp>(l, (lhs), (rhs)).getResult()) |
39 | #define DIVUI(lhs, rhs) (b.create<arith::DivUIOp>(l, (lhs), (rhs)).getResult()) |
40 | #define SELECT(c, lhs, rhs) \ |
41 | (b.create<arith::SelectOp>(l, (c), (lhs), (rhs)).getResult()) |
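// Note: all macros above assume an OpBuilder named `b` and a Location named
// `l` in the enclosing scope; e.g., ADDI(lhs, rhs) expands to
// b.create<arith::AddIOp>(l, (lhs), (rhs)).getResult().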
42 | |
43 | //===----------------------------------------------------------------------===// |
44 | // SparseTensorLevel derived classes. |
45 | //===----------------------------------------------------------------------===// |
46 | |
47 | namespace { |
48 | |
49 | template <bool hasPosBuffer> |
50 | class SparseLevel : public SparseTensorLevel { |
51 | // It is either an array of size 2 or size 1 depending on whether the sparse |
52 | // level requires a position array. |
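  // For example, a compressed level stores {posBuffer, crdBuffer}, while a
  // singleton level stores only {crdBuffer} (see the derived class
  // constructors below).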
53 | using BufferT = std::conditional_t<hasPosBuffer, std::array<Value, 2>, |
54 | std::array<Value, 1>>; |
55 | |
56 | public: |
57 | SparseLevel(unsigned tid, Level lvl, LevelType lt, Value lvlSize, |
58 | BufferT buffers) |
59 | : SparseTensorLevel(tid, lvl, lt, lvlSize), buffers(buffers) {} |
60 | |
61 | ValueRange getLvlBuffers() const override { return buffers; } |
62 | |
63 | Value peekCrdAt(OpBuilder &b, Location l, ValueRange batchPrefix, |
64 | Value iv) const override { |
65 | SmallVector<Value> memCrd(batchPrefix); |
    memCrd.push_back(iv);
67 | return genIndexLoad(b, l, getCrdBuf(), memCrd); |
68 | } |
69 | |
70 | protected: |
71 | template <typename T = void, typename = std::enable_if_t<hasPosBuffer, T>> |
72 | Value getPosBuf() const { |
73 | return buffers[0]; |
74 | } |
75 | |
76 | Value getCrdBuf() const { |
77 | if constexpr (hasPosBuffer) |
78 | return buffers[1]; |
79 | else |
80 | return buffers[0]; |
81 | } |
82 | |
83 | const BufferT buffers; |
84 | }; |
85 | |
86 | class DenseLevel : public SparseTensorLevel { |
87 | public: |
88 | DenseLevel(unsigned tid, Level lvl, Value lvlSize) |
89 | : SparseTensorLevel(tid, lvl, LevelFormat::Dense, lvlSize) {} |
90 | |
91 | Value peekCrdAt(OpBuilder &, Location, ValueRange, Value) const override { |
    llvm_unreachable("locate random-accessible level instead");
93 | } |
94 | |
95 | ValueRange getLvlBuffers() const override { return {}; } |
96 | |
97 | ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix, |
98 | ValueRange parentPos, Value inPadZone) const override { |
    assert(parentPos.size() == 1 && "Dense level can not be non-unique.");
    assert(!inPadZone && "Not implemented");
101 | Value p = parentPos.front(); |
102 | Value posLo = MULI(p, lvlSize); |
103 | return {posLo, lvlSize}; |
104 | } |
105 | }; |
106 | |
107 | class BatchLevel : public SparseTensorLevel { |
108 | public: |
109 | BatchLevel(unsigned tid, Level lvl, Value lvlSize) |
110 | : SparseTensorLevel(tid, lvl, LevelFormat::Batch, lvlSize) {} |
111 | |
112 | Value peekCrdAt(OpBuilder &, Location, ValueRange, Value) const override { |
    llvm_unreachable("locate random-accessible level instead");
114 | } |
115 | |
116 | ValueRange getLvlBuffers() const override { return {}; } |
117 | |
118 | ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange, |
119 | ValueRange parentPos, Value inPadZone) const override { |
    assert(!inPadZone && "Not implemented");
    assert(parentPos.size() == 1 && "Batch level can not be non-unique.");
122 | // No need to linearize the position for non-annotated tensors. |
123 | return {C_IDX(0), lvlSize}; |
124 | } |
125 | }; |
126 | |
127 | class CompressedLevel : public SparseLevel</*hasPosBuf=*/true> { |
128 | public: |
129 | CompressedLevel(unsigned tid, Level lvl, LevelType lt, Value lvlSize, |
130 | Value posBuffer, Value crdBuffer) |
131 | : SparseLevel(tid, lvl, lt, lvlSize, {posBuffer, crdBuffer}) {} |
132 | |
133 | ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix, |
134 | ValueRange parentPos, Value inPadZone) const override { |
135 | |
    assert(parentPos.size() == 1 &&
           "compressed level must be the first non-unique level.");
138 | |
139 | auto loadRange = [&b, l, parentPos, batchPrefix, this]() -> ValuePair { |
140 | Value p = parentPos.front(); |
141 | SmallVector<Value> memCrd(batchPrefix); |
      memCrd.push_back(p);
      Value pLo = genIndexLoad(b, l, getPosBuf(), memCrd);
      memCrd.back() = ADDI(p, C_IDX(1));
      Value pHi = genIndexLoad(b, l, getPosBuf(), memCrd);
146 | return {pLo, pHi}; |
147 | }; |
148 | |
149 | if (inPadZone == nullptr) |
150 | return loadRange(); |
151 | |
152 | SmallVector<Type, 2> types{b.getIndexType(), b.getIndexType()}; |
153 | scf::IfOp posRangeIf = b.create<scf::IfOp>(l, types, inPadZone, true); |
154 | // True branch, returns a "fake" empty range [0, 0) if parent |
155 | // iterator is in pad zone. |
156 | b.setInsertionPointToStart(posRangeIf.thenBlock()); |
157 | |
158 | SmallVector<Value, 2> emptyRange{C_IDX(0), C_IDX(0)}; |
159 | b.create<scf::YieldOp>(l, emptyRange); |
160 | |
161 | // False branch, returns the actual range. |
162 | b.setInsertionPointToStart(posRangeIf.elseBlock()); |
163 | auto [pLo, pHi] = loadRange(); |
164 | SmallVector<Value, 2> loadedRange{pLo, pHi}; |
165 | b.create<scf::YieldOp>(l, loadedRange); |
166 | |
167 | b.setInsertionPointAfter(posRangeIf); |
168 | ValueRange posRange = posRangeIf.getResults(); |
169 | return {posRange.front(), posRange.back()}; |
170 | } |
};
172 | |
173 | class LooseCompressedLevel : public SparseLevel</*hasPosBuf=*/true> { |
174 | public: |
175 | LooseCompressedLevel(unsigned tid, Level lvl, LevelType lt, Value lvlSize, |
176 | Value posBuffer, Value crdBuffer) |
177 | : SparseLevel(tid, lvl, lt, lvlSize, {posBuffer, crdBuffer}) {} |
178 | |
179 | ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix, |
180 | ValueRange parentPos, Value inPadZone) const override { |
    assert(parentPos.size() == 1 &&
           "loose-compressed level must be the first non-unique level.");
    assert(!inPadZone && "Not implemented");
184 | SmallVector<Value> memCrd(batchPrefix); |
185 | Value p = parentPos.front(); |
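    // A loose-compressed level stores a [lo, hi) position pair per parent
    // position, hence the loads below read positions 2 * p and 2 * p + 1.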
186 | p = MULI(p, C_IDX(2)); |
    memCrd.push_back(p);
    Value pLo = genIndexLoad(b, l, getPosBuf(), memCrd);
    memCrd.back() = ADDI(p, C_IDX(1));
    Value pHi = genIndexLoad(b, l, getPosBuf(), memCrd);
191 | return {pLo, pHi}; |
192 | } |
};
194 | |
195 | class SingletonLevel : public SparseLevel</*hasPosBuf=*/false> { |
196 | public: |
197 | SingletonLevel(unsigned tid, Level lvl, LevelType lt, Value lvlSize, |
198 | Value crdBuffer) |
199 | : SparseLevel(tid, lvl, lt, lvlSize, {crdBuffer}) {} |
200 | |
201 | ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix, |
202 | ValueRange parentPos, Value inPadZone) const override { |
203 | assert(parentPos.size() == 1 || parentPos.size() == 2); |
    assert(!inPadZone && "Not implemented");
205 | Value p = parentPos.front(); |
206 | Value segHi = parentPos.size() == 2 ? parentPos.back() : nullptr; |
207 | |
208 | if (segHi == nullptr) |
209 | return {p, ADDI(p, C_IDX(1))}; |
210 | // Use the segHi as the loop upper bound. |
211 | return {p, segHi}; |
212 | } |
213 | |
214 | ValuePair |
215 | collapseRangeBetween(OpBuilder &b, Location l, ValueRange batchPrefix, |
216 | std::pair<Value, Value> parentRange) const override { |
217 | // Singleton level keeps the same range after collapsing. |
218 | return parentRange; |
219 | }; |
220 | }; |
221 | |
222 | class NOutOfMLevel : public SparseLevel</*hasPosBuf=*/false> { |
223 | public: |
224 | NOutOfMLevel(unsigned tid, Level lvl, LevelType lt, Value lvlSize, |
225 | Value crdBuffer) |
226 | : SparseLevel(tid, lvl, lt, lvlSize, {crdBuffer}) {} |
227 | |
228 | ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix, |
229 | ValueRange parentPos, Value inPadZone) const override { |
230 | assert(parentPos.size() == 1 && isUnique() && |
           "n:m level can not be non-unique.");
    assert(!inPadZone && "Not implemented");
233 | // Each n:m blk has exactly n specified elements. |
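    // E.g., assuming 2:4 structured sparsity (n = 2), block i covers
    // positions [2 * i, 2 * i + 2).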
234 | auto n = getN(lt); |
235 | Value posLo = MULI(parentPos.front(), C_IDX(n)); |
236 | return {posLo, ADDI(posLo, C_IDX(n))}; |
237 | } |
238 | }; |
239 | |
240 | } // namespace |
241 | |
242 | //===----------------------------------------------------------------------===// |
243 | // File local helpers |
244 | //===----------------------------------------------------------------------===// |
245 | |
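// Roughly generates the following structure (an illustrative sketch):
//
//   %r:N = scf.if (it.genNotEnd()) {
//     %crd = it.deref()
//     scf.yield builder(%crd)
//   } else {
//     scf.yield elseRet
//   }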
246 | static scf::ValueVector genWhenInBound( |
247 | OpBuilder &b, Location l, SparseIterator &it, ValueRange elseRet, |
248 | llvm::function_ref<scf::ValueVector(OpBuilder &, Location, Value)> |
249 | builder) { |
250 | TypeRange ifRetTypes = elseRet.getTypes(); |
251 | auto ifOp = b.create<scf::IfOp>(l, ifRetTypes, it.genNotEnd(b, l), true); |
252 | |
253 | b.setInsertionPointToStart(ifOp.thenBlock()); |
254 | Value crd = it.deref(b, l); |
255 | scf::ValueVector ret = builder(b, l, crd); |
256 | YIELD(ret); |
257 | |
258 | b.setInsertionPointToStart(ifOp.elseBlock()); |
259 | YIELD(elseRet); |
260 | |
261 | b.setInsertionPointAfter(ifOp); |
262 | return ifOp.getResults(); |
263 | } |
264 | |
/// Generates code to compute the *absolute* offset of the slice based on the
/// provided minimum coordinates in the slice.
/// E.g., when reducing d0 + d1 + d2, we need two slices to fully reduce the
/// expression, i.e., s1 = slice(T, d0), s2 = slice(s1, d1). The *absolute*
/// offset is the offset computed relative to the initial tensor T.
///
/// When isNonEmpty == true, the computed offset is meaningless and should not
/// be used at runtime; the method currently generates code to return 0 in
/// that case.
274 | /// |
275 | /// offset = minCrd >= size ? minCrd - size + 1 : 0; |
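/// For example, with size = 3: minCrd = 7 yields offset = 7 - 3 + 1 = 5,
/// while minCrd = 2 yields offset = 0.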
276 | static Value offsetFromMinCrd(OpBuilder &b, Location l, Value minCrd, |
277 | Value size) { |
278 | Value geSize = CMPI(uge, minCrd, size); |
279 | // Compute minCrd - size + 1. |
280 | Value mms = SUBI(ADDI(minCrd, C_IDX(1)), size); |
281 | // This is the absolute offset related to the actual tensor. |
282 | return SELECT(geSize, mms, C_IDX(0)); |
283 | } |
284 | |
285 | //===----------------------------------------------------------------------===// |
286 | // SparseIterator derived classes. |
287 | //===----------------------------------------------------------------------===// |
288 | |
289 | namespace { |
290 | |
// The iterator that traverses a concrete sparse tensor level. High-level
// abstract iterators wrap it to achieve more complex goals (such as collapsing
// several levels). It also owns the common storage that holds the mlir::Values
// for itself as well as for the wrappers.
295 | class ConcreteIterator : public SparseIterator { |
296 | protected: |
297 | ConcreteIterator(const SparseTensorLevel &stl, IterKind kind, |
298 | unsigned cursorValCnt) |
299 | : SparseIterator(kind, stl.tid, stl.lvl, cursorValCnt, cursorValsStorage), |
300 | stl(stl), cursorValsStorage(cursorValCnt, nullptr) { |
301 | assert(getCursor().size() == cursorValCnt); |
302 | }; |
303 | |
304 | public: |
305 | // For LLVM-style RTTI. |
306 | static bool classof(const SparseIterator *from) { |
307 | return from->kind == IterKind::kTrivial; |
308 | } |
309 | |
310 | bool isBatchIterator() const override { |
311 | return stl.getLT().isa<LevelFormat::Batch>(); |
312 | } |
313 | bool randomAccessible() const override { |
314 | return stl.getLT().hasDenseSemantic(); |
315 | }; |
316 | bool iteratableByFor() const override { return kind != IterKind::kDedup; }; |
317 | Value upperBound(OpBuilder &b, Location l) const override { |
318 | return stl.getSize(); |
319 | }; |
320 | |
321 | protected: |
322 | const SparseTensorLevel &stl; |
323 | // Owner of the storage, all wrappers build on top of a concrete iterator |
324 | // share the same storage such that the iterator values are always |
325 | // synchronized. |
326 | SmallVector<Value> cursorValsStorage; |
327 | }; |
328 | |
329 | class TrivialIterator : public ConcreteIterator { |
330 | public: |
331 | TrivialIterator(const SparseTensorLevel &stl) |
332 | : ConcreteIterator(stl, IterKind::kTrivial, /*itValCnt=*/1) {} |
333 | |
334 | TrivialIterator(OpBuilder &b, Location l, const SparseTensorLevel &stl, |
335 | Value posLo, Value posHi) |
336 | : ConcreteIterator(stl, IterKind::kTrivial, /*itValCnt=*/1), posLo(posLo), |
337 | posHi(posHi) { |
    seek(posLo);
339 | } |
340 | |
341 | std::string getDebugInterfacePrefix() const override { |
    return std::string("trivial<") + stl.toString() + ">";
343 | } |
344 | SmallVector<Type> getCursorValTypes(OpBuilder &b) const override { |
345 | return {b.getIndexType()}; |
346 | } |
347 | |
348 | SmallVector<Value> serialize() const override { |
349 | SmallVector<Value> ret; |
    ret.push_back(getItPos());
    if (randomAccessible()) {
      // Loop high is implicit (defined by `upperBound()`) for random-access
      // iterators, but we need to memorize posLo for linearization.
      ret.push_back(posLo);
    } else {
      ret.push_back(posHi);
357 | } |
358 | return ret; |
359 | }; |
360 | |
361 | void deserialize(ValueRange vs) override { |
362 | assert(vs.size() == 2); |
    seek(vs.front());
364 | if (randomAccessible()) |
365 | posLo = vs.back(); |
366 | else |
367 | posHi = vs.back(); |
368 | }; |
369 | |
370 | void genInitImpl(OpBuilder &b, Location l, |
371 | const SparseIterator *parent) override; |
372 | |
373 | ValuePair genForCond(OpBuilder &b, Location l) override { |
374 | if (randomAccessible()) |
375 | return {deref(b, l), upperBound(b, l)}; |
    return std::make_pair(getItPos(), posHi);
377 | } |
378 | |
379 | Value genNotEndImpl(OpBuilder &b, Location l) override { |
    // We use the first level's bound as the bound for the collapsed set of
    // levels.
381 | return CMPI(ult, getItPos(), posHi); |
382 | } |
383 | |
384 | Value derefImpl(OpBuilder &b, Location l) override { |
385 | if (randomAccessible()) { |
386 | updateCrd(SUBI(getItPos(), posLo)); |
387 | } else { |
      updateCrd(stl.peekCrdAt(b, l, getBatchCrds(), getItPos()));
389 | } |
390 | return getCrd(); |
391 | }; |
392 | |
393 | ValueRange forwardImpl(OpBuilder &b, Location l) override { |
394 | seek(ADDI(getItPos(), C_IDX(1))); |
395 | return getCursor(); |
396 | } |
397 | |
398 | ValueRange forwardIf(OpBuilder &b, Location l, Value cond) override { |
399 | Value curPos = getCursor().front(); |
400 | Value nxPos = forward(b, l).front(); |
401 | seek(SELECT(cond, nxPos, curPos)); |
402 | return getCursor(); |
403 | } |
404 | |
405 | void locateImpl(OpBuilder &b, Location l, Value crd) override { |
406 | assert(randomAccessible()); |
407 | // Seek to the linearized position. |
408 | seek(ADDI(crd, posLo)); |
409 | updateCrd(crd); |
410 | if (isBatchIterator()) { |
411 | // If this is a batch iterator, also update the batch coordinate. |
412 | assert(batchCrds.size() > lvl); |
413 | batchCrds[lvl] = crd; |
414 | } |
415 | } |
416 | |
417 | Value getItPos() const { return getCursor().front(); } |
418 | Value posLo, posHi; |
419 | }; |
420 | |
421 | class DedupIterator : public ConcreteIterator { |
422 | private: |
423 | Value genSegmentHigh(OpBuilder &b, Location l, Value pos); |
424 | |
425 | public: |
426 | DedupIterator(const SparseTensorLevel &stl) |
427 | : ConcreteIterator(stl, IterKind::kDedup, /*itValCnt=*/2) { |
428 | assert(!stl.isUnique()); |
429 | } |
430 | |
431 | DedupIterator(OpBuilder &b, Location l, const SparseTensorLevel &stl, |
432 | Value posLo, Value posHi) |
433 | : ConcreteIterator(stl, IterKind::kDedup, /*itValCnt=*/2), posHi(posHi) { |
434 | assert(!stl.isUnique()); |
    seek({posLo, genSegmentHigh(b, l, posLo)});
436 | } |
437 | |
438 | // For LLVM-style RTTI. |
439 | static bool classof(const SparseIterator *from) { |
440 | return from->kind == IterKind::kDedup; |
441 | } |
442 | |
443 | std::string getDebugInterfacePrefix() const override { |
    return std::string("dedup<") + stl.toString() + ">";
445 | } |
446 | SmallVector<Type> getCursorValTypes(OpBuilder &b) const override { |
447 | return {b.getIndexType(), b.getIndexType()}; |
448 | } |
449 | |
450 | void genInitImpl(OpBuilder &b, Location l, |
451 | const SparseIterator *parent) override { |
452 | Value c0 = C_IDX(0); |
453 | ValueRange pPos = c0; |
454 | |
455 | // If the parent iterator is a batch iterator, we also start from 0 (but |
456 | // on a different batch). |
457 | if (parent && !parent->isBatchIterator()) |
458 | pPos = parent->getCurPosition(); |
459 | |
460 | Value posLo; |
461 | ValueRange batchPrefix = parent ? parent->getBatchCrds() : ValueRange{}; |
    std::tie(posLo, posHi) = stl.peekRangeAt(b, l, batchPrefix, pPos);

    seek({posLo, genSegmentHigh(b, l, posLo)});
465 | } |
466 | |
467 | SmallVector<Value> serialize() const override { |
468 | SmallVector<Value> ret; |
    ret.append(getCursor().begin(), getCursor().end());
    ret.push_back(posHi);
471 | return ret; |
472 | }; |
473 | void deserialize(ValueRange vs) override { |
474 | assert(vs.size() == 3); |
    seek(vs.take_front(getCursor().size()));
476 | posHi = vs.back(); |
477 | }; |
478 | |
479 | Value genNotEndImpl(OpBuilder &b, Location l) override { |
480 | return CMPI(ult, getPos(), posHi); |
481 | } |
482 | |
483 | Value derefImpl(OpBuilder &b, Location l) override { |
    updateCrd(stl.peekCrdAt(b, l, getBatchCrds(), getPos()));
485 | return getCrd(); |
486 | }; |
487 | |
488 | ValueRange forwardImpl(OpBuilder &b, Location l) override { |
489 | Value nxPos = getSegHi(); // forward the position to the next segment. |
    seek({nxPos, genSegmentHigh(b, l, nxPos)});
491 | return getCursor(); |
492 | } |
493 | |
494 | Value getPos() const { return getCursor()[0]; } |
495 | Value getSegHi() const { return getCursor()[1]; } |
496 | |
497 | Value posHi; |
498 | }; |
499 | |
// A utility base iterator that delegates all methods to the wrapped iterator.
501 | class SimpleWrapIterator : public SparseIterator { |
502 | public: |
  SimpleWrapIterator(std::unique_ptr<SparseIterator> &&wrap, IterKind kind,
                     unsigned extraCursorVal = 0)
      : SparseIterator(kind, *wrap, extraCursorVal), wrap(std::move(wrap)) {}
506 | |
507 | SmallVector<Type> getCursorValTypes(OpBuilder &b) const override { |
508 | return wrap->getCursorValTypes(b); |
509 | } |
510 | bool isBatchIterator() const override { return wrap->isBatchIterator(); } |
511 | bool randomAccessible() const override { return wrap->randomAccessible(); }; |
512 | bool iteratableByFor() const override { return wrap->iteratableByFor(); }; |
513 | |
514 | SmallVector<Value> serialize() const override { return wrap->serialize(); }; |
515 | void deserialize(ValueRange vs) override { wrap->deserialize(vs); }; |
516 | ValueRange getCurPosition() const override { return wrap->getCurPosition(); } |
517 | void genInitImpl(OpBuilder &b, Location l, |
518 | const SparseIterator *parent) override { |
    wrap->genInit(b, l, parent);
520 | } |
521 | Value genNotEndImpl(OpBuilder &b, Location l) override { |
522 | return wrap->genNotEndImpl(b, l); |
523 | } |
524 | ValueRange forwardImpl(OpBuilder &b, Location l) override { |
525 | return wrap->forward(b, l); |
526 | }; |
527 | Value upperBound(OpBuilder &b, Location l) const override { |
528 | return wrap->upperBound(b, l); |
529 | }; |
530 | |
531 | Value derefImpl(OpBuilder &b, Location l) override { |
532 | return wrap->derefImpl(b, l); |
533 | } |
534 | |
535 | void locateImpl(OpBuilder &b, Location l, Value crd) override { |
536 | return wrap->locate(b, l, crd); |
537 | } |
538 | |
539 | SparseIterator &getWrappedIterator() const { return *wrap; } |
540 | |
541 | protected: |
542 | std::unique_ptr<SparseIterator> wrap; |
543 | }; |
544 | |
545 | // |
// A filter iterator wrapped around another iterator. The filter iterator
// updates the wrapped iterator *in-place*.
548 | // |
549 | class FilterIterator : public SimpleWrapIterator { |
  // Coordinate translation between the crd loaded from the wrapped iterator
  // and the filter iterator's own crd.
552 | Value fromWrapCrd(OpBuilder &b, Location l, Value wrapCrd) const { |
553 | // crd = (wrapCrd - offset) / stride |
554 | return DIVUI(SUBI(wrapCrd, offset), stride); |
555 | } |
556 | Value toWrapCrd(OpBuilder &b, Location l, Value crd) const { |
557 | // wrapCrd = crd * stride + offset |
558 | return ADDI(MULI(crd, stride), offset); |
559 | } |
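  // Illustrative example: with offset = 2 and stride = 3, wrapped coordinate 8
  // maps to crd (8 - 2) / 3 = 2, and crd 2 maps back to wrapped coordinate
  // 2 * 3 + 2 = 8.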
560 | |
561 | Value genCrdNotLegitPredicate(OpBuilder &b, Location l, Value wrapCrd); |
562 | |
563 | Value genShouldFilter(OpBuilder &b, Location l); |
564 | |
565 | public: |
  // TODO: avoid unnecessary checks when offset == 0 and/or when stride == 1
  // and/or when crd always < size.
568 | FilterIterator(std::unique_ptr<SparseIterator> &&wrap, Value offset, |
569 | Value stride, Value size) |
570 | : SimpleWrapIterator(std::move(wrap), IterKind::kFilter), offset(offset), |
571 | stride(stride), size(size) {} |
572 | |
573 | // For LLVM-style RTTI. |
574 | static bool classof(const SparseIterator *from) { |
575 | return from->kind == IterKind::kFilter; |
576 | } |
577 | |
578 | std::string getDebugInterfacePrefix() const override { |
    return std::string("filter<") + wrap->getDebugInterfacePrefix() + ">";
580 | } |
581 | |
582 | bool iteratableByFor() const override { return randomAccessible(); }; |
583 | Value upperBound(OpBuilder &b, Location l) const override { return size; }; |
584 | |
585 | void genInitImpl(OpBuilder &b, Location l, |
586 | const SparseIterator *parent) override { |
    wrap->genInit(b, l, parent);
    if (!randomAccessible()) {
      // TODO: we can skip this when stride == 1 and offset == 0; we could also
      // use binary search here.
      forwardIf(b, l, genShouldFilter(b, l));
    } else {
      // Else, locate to the slice.offset, which is the first coordinate
      // included by the slice.
      wrap->locate(b, l, offset);
596 | } |
597 | } |
598 | |
599 | Value genNotEndImpl(OpBuilder &b, Location l) override; |
600 | |
601 | Value derefImpl(OpBuilder &b, Location l) override { |
    updateCrd(fromWrapCrd(b, l, wrap->deref(b, l)));
603 | return getCrd(); |
604 | } |
605 | |
606 | void locateImpl(OpBuilder &b, Location l, Value crd) override { |
607 | assert(randomAccessible()); |
    wrap->locate(b, l, toWrapCrd(b, l, crd));
609 | updateCrd(crd); |
610 | } |
611 | |
612 | ValueRange forwardImpl(OpBuilder &b, Location l) override; |
613 | |
614 | Value offset, stride, size; |
615 | }; |
616 | |
617 | // |
618 | // A pad iterator wrapped from another iterator. The pad iterator updates |
619 | // the wrapped iterator *in-place*. |
620 | // |
621 | class PadIterator : public SimpleWrapIterator { |
622 | |
623 | public: |
624 | PadIterator(std::unique_ptr<SparseIterator> &&wrap, Value padLow, |
625 | Value padHigh) |
626 | : SimpleWrapIterator(std::move(wrap), IterKind::kPad, |
627 | wrap->randomAccessible() ? 1 : 0), |
628 | padLow(padLow), padHigh(padHigh) {} |
629 | |
630 | // For LLVM-style RTTI. |
631 | static bool classof(const SparseIterator *from) { |
632 | return from->kind == IterKind::kPad; |
633 | } |
634 | |
635 | std::string getDebugInterfacePrefix() const override { |
    return std::string("pad<") + wrap->getDebugInterfacePrefix() + ">";
637 | } |
638 | |
  // Returns a pair of values for the *lower* and *upper* bound respectively.
640 | ValuePair genForCond(OpBuilder &b, Location l) override { |
641 | if (randomAccessible()) |
642 | return {getCrd(), upperBound(b, l)}; |
643 | return wrap->genForCond(b, l); |
644 | } |
645 | |
  // For the padded dense iterator, we append an `inPadZone: bool` in addition
  // to the values used by the wrapped iterator.
648 | ValueRange getCurPosition() const override { return getCursor(); } |
649 | |
650 | SmallVector<Type> getCursorValTypes(OpBuilder &b) const override { |
651 | SmallVector<Type> ret = wrap->getCursorValTypes(b); |
652 | // Need an extra boolean value `inPadZone` for padded dense iterator. |
653 | if (randomAccessible()) |
654 | ret.push_back(b.getI1Type()); |
655 | |
656 | return ret; |
657 | } |
658 | |
659 | // The upper bound after padding becomes `size + padLow + padHigh`. |
660 | Value upperBound(OpBuilder &b, Location l) const override { |
661 | return ADDI(ADDI(wrap->upperBound(b, l), padLow), padHigh); |
662 | }; |
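  // Illustrative example: with padLow = 2, padHigh = 1 and a wrapped size of
  // 4, the padded size is 7; coordinates 0-1 fall in the low pad zone, 2-5 map
  // to wrapped coordinates 0-3, and 6 falls in the high pad zone.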
663 | |
664 | // The pad_coord = coord + pad_lo |
665 | Value derefImpl(OpBuilder &b, Location l) override { |
666 | updateCrd(ADDI(wrap->deref(b, l), padLow)); |
667 | return getCrd(); |
668 | } |
669 | |
670 | void locateImpl(OpBuilder &b, Location l, Value crd) override { |
671 | assert(randomAccessible()); |
672 | wrap->locate(b, l, SUBI(crd, padLow)); |
673 | |
674 | // inPadZone = crd < padLow || crd >= size + padLow. |
675 | Value inPadLow = CMPI(ult, crd, padLow); |
676 | Value inPadHigh = CMPI(uge, crd, ADDI(wrap->upperBound(b, l), padLow)); |
677 | getMutCursorVals().back() = ORI(inPadLow, inPadHigh); |
678 | |
679 | updateCrd(crd); |
680 | } |
681 | |
682 | Value padLow, padHigh; |
683 | }; |
684 | |
685 | class NonEmptySubSectIterator : public SparseIterator { |
686 | public: |
687 | using TraverseBuilder = llvm::function_ref<scf::ValueVector( |
688 | OpBuilder &, Location, const SparseIterator *, ValueRange)>; |
689 | |
690 | NonEmptySubSectIterator(OpBuilder &b, Location l, |
691 | const SparseIterator *parent, |
692 | std::unique_ptr<SparseIterator> &&delegate, |
693 | Value subSectSz) |
694 | : SparseIterator(IterKind::kNonEmptySubSect, 3, subSectMeta, *delegate), |
695 | parent(parent), delegate(std::move(delegate)), |
696 | tupleSz(this->delegate->serialize().size()), subSectSz(subSectSz) { |
    auto *p = dyn_cast_or_null<NonEmptySubSectIterator>(parent);
698 | if (p == nullptr) { |
699 | // Extract subsections along the root level. |
700 | maxTupleCnt = C_IDX(1); |
701 | } else if (p->lvl == lvl) { |
702 | // Extract subsections along the same level. |
703 | maxTupleCnt = p->maxTupleCnt; |
      assert(false && "Not implemented.");
705 | } else { |
706 | // Extract subsections along the previous level. |
707 | assert(p->lvl + 1 == lvl); |
708 | maxTupleCnt = MULI(p->maxTupleCnt, p->subSectSz); |
709 | } |
710 | // We don't need an extra buffer to find subsections on random-accessible |
711 | // levels. |
712 | if (randomAccessible()) |
713 | return; |
714 | subSectPosBuf = allocSubSectPosBuf(b, l); |
715 | } |
716 | |
717 | // For LLVM-style RTTI. |
718 | static bool classof(const SparseIterator *from) { |
719 | return from->kind == IterKind::kNonEmptySubSect; |
720 | } |
721 | |
722 | std::string getDebugInterfacePrefix() const override { |
    return std::string("ne_sub<") + delegate->getDebugInterfacePrefix() + ">";
724 | } |
725 | SmallVector<Type> getCursorValTypes(OpBuilder &b) const override { |
726 | // minCrd, absolute offset, notEnd |
727 | return {b.getIndexType(), b.getIndexType(), b.getI1Type()}; |
728 | } |
729 | |
730 | // The sliced pointer buffer is organized as: |
731 | // [[itVal0, itVal1, ..., pNx0], |
732 | // [itVal0, itVal1, ..., pNx0], |
733 | // ...] |
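  // That is, the buffer has maxTupleCnt rows of tupleSz + 1 values: each row
  // caches one serialized iterator state followed by the starting tupleId of
  // the next level (see storeNxLvlStart/loadNxLvlStart below).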
734 | Value allocSubSectPosBuf(OpBuilder &b, Location l) { |
735 | return b.create<memref::AllocaOp>( |
736 | l, |
737 | MemRefType::get({ShapedType::kDynamic, tupleSz + 1}, b.getIndexType()), |
738 | maxTupleCnt); |
739 | } |
740 | |
741 | void storeNxLvlStart(OpBuilder &b, Location l, Value tupleId, |
742 | Value start) const { |
743 | b.create<memref::StoreOp>(l, start, subSectPosBuf, |
744 | ValueRange{tupleId, C_IDX(tupleSz)}); |
745 | } |
746 | |
747 | Value loadNxLvlStart(OpBuilder &b, Location l, Value tupleId) const { |
748 | return b.create<memref::LoadOp>(l, subSectPosBuf, |
749 | ValueRange{tupleId, C_IDX(tupleSz)}); |
750 | } |
751 | |
752 | void storeCursorVals(OpBuilder &b, Location l, Value tupleId, |
753 | ValueRange itVals) const { |
754 | assert(itVals.size() == tupleSz); |
755 | for (unsigned i = 0; i < tupleSz; i++) { |
756 | b.create<memref::StoreOp>(l, itVals[i], subSectPosBuf, |
757 | ValueRange{tupleId, C_IDX(i)}); |
758 | } |
759 | } |
760 | |
761 | SmallVector<Value> loadCursorVals(OpBuilder &b, Location l, |
762 | Value tupleId) const { |
763 | SmallVector<Value> ret; |
764 | for (unsigned i = 0; i < tupleSz; i++) { |
765 | Value v = b.create<memref::LoadOp>(l, subSectPosBuf, |
766 | ValueRange{tupleId, C_IDX(i)}); |
      ret.push_back(v);
768 | } |
769 | return ret; |
770 | } |
771 | |
772 | bool isSubSectRoot() const { |
    return !parent || !llvm::isa<NonEmptySubSectIterator>(parent);
774 | } |
775 | |
  // Generates code that inflates the current subsection tree up to the current
  // level such that every leaf node is visited.
778 | ValueRange inflateSubSectTree(OpBuilder &b, Location l, ValueRange reduc, |
779 | TraverseBuilder builder) const; |
780 | |
781 | bool isBatchIterator() const override { return delegate->isBatchIterator(); } |
782 | bool randomAccessible() const override { |
783 | return delegate->randomAccessible(); |
784 | }; |
785 | bool iteratableByFor() const override { return randomAccessible(); }; |
786 | Value upperBound(OpBuilder &b, Location l) const override { |
    auto *p = dyn_cast_or_null<NonEmptySubSectIterator>(parent);
788 | Value parentUB = |
789 | p && p->lvl == lvl ? p->upperBound(b, l) : delegate->upperBound(b, l); |
790 | return ADDI(SUBI(parentUB, subSectSz), C_IDX(1)); |
791 | }; |
792 | |
793 | void genInitImpl(OpBuilder &b, Location l, const SparseIterator *) override; |
794 | |
795 | void locateImpl(OpBuilder &b, Location l, Value crd) override { |
796 | Value absOff = crd; |
797 | |
798 | if (isSubSectRoot()) |
      delegate->locate(b, l, absOff);
    else
      assert(parent->lvl + 1 == lvl);

    seek(ValueRange{absOff, absOff, C_TRUE});
804 | updateCrd(crd); |
805 | } |
806 | |
807 | Value toSubSectCrd(OpBuilder &b, Location l, Value wrapCrd) const { |
808 | return SUBI(wrapCrd, getAbsOff()); |
809 | } |
810 | |
811 | Value genNotEndImpl(OpBuilder &b, Location l) override { |
812 | return getNotEnd(); |
813 | }; |
814 | |
815 | Value derefImpl(OpBuilder &b, Location l) override { |
816 | // Use the relative offset to coiterate. |
817 | Value crd; |
    auto *p = dyn_cast_or_null<NonEmptySubSectIterator>(parent);
819 | if (p && p->lvl == lvl) |
820 | crd = SUBI(getAbsOff(), p->getAbsOff()); |
821 | crd = getAbsOff(); |
822 | |
823 | updateCrd(crd); |
824 | return crd; |
825 | }; |
826 | |
827 | ValueRange forwardImpl(OpBuilder &b, Location l) override; |
828 | |
829 | Value getMinCrd() const { return subSectMeta[0]; } |
830 | Value getAbsOff() const { return subSectMeta[1]; } |
831 | Value getNotEnd() const { return subSectMeta[2]; } |
832 | |
833 | const SparseIterator *parent; |
834 | std::unique_ptr<SparseIterator> delegate; |
835 | |
836 | // Number of values required to serialize the wrapped iterator. |
837 | const unsigned tupleSz; |
  // Max number of tuples, and the actual number of tuples.
839 | Value maxTupleCnt, tupleCnt; |
840 | // The memory used to cache the tuple serialized from the wrapped iterator. |
841 | Value subSectPosBuf; |
842 | |
843 | const Value subSectSz; |
844 | |
845 | // minCrd, absolute offset, notEnd |
846 | SmallVector<Value, 3> subSectMeta{nullptr, nullptr, nullptr}; |
847 | }; |
848 | |
849 | class SubSectIterator; |
850 | |
// A wrapper that helps generate code to traverse a subsection, used
// by both `NonEmptySubSectIterator` and `SubSectIterator`.
853 | struct SubSectIterHelper { |
854 | explicit SubSectIterHelper(const SubSectIterator &iter); |
855 | explicit SubSectIterHelper(const NonEmptySubSectIterator &subSect); |
856 | |
857 | // Delegate methods. |
858 | void deserializeFromTupleId(OpBuilder &b, Location l, Value tupleId); |
859 | void locate(OpBuilder &b, Location l, Value crd); |
860 | Value genNotEnd(OpBuilder &b, Location l); |
861 | Value deref(OpBuilder &b, Location l); |
862 | ValueRange forward(OpBuilder &b, Location l); |
863 | |
864 | const NonEmptySubSectIterator &subSect; |
865 | SparseIterator &wrap; |
866 | }; |
867 | |
868 | class SubSectIterator : public SparseIterator { |
869 | public: |
870 | SubSectIterator(const NonEmptySubSectIterator &subSect, |
871 | const SparseIterator &parent, |
872 | std::unique_ptr<SparseIterator> &&wrap) |
873 | : SparseIterator(IterKind::kSubSect, *wrap, |
874 | /*extraCursorCnt=*/wrap->randomAccessible() ? 0 : 1), |
875 | subSect(subSect), wrap(std::move(wrap)), parent(parent), helper(*this) { |
876 | assert(subSect.tid == tid && subSect.lvl == lvl); |
877 | assert(parent.kind != IterKind::kSubSect || parent.lvl + 1 == lvl); |
878 | }; |
879 | |
880 | // For LLVM-style RTTI. |
881 | static bool classof(const SparseIterator *from) { |
882 | return from->kind == IterKind::kSubSect; |
883 | } |
884 | |
885 | std::string getDebugInterfacePrefix() const override { |
    return std::string("subsect<") + wrap->getDebugInterfacePrefix() + ">";
887 | } |
888 | SmallVector<Type> getCursorValTypes(OpBuilder &b) const override { |
889 | SmallVector<Type> ret = wrap->getCursorValTypes(b); |
890 | if (!randomAccessible()) |
891 | ret.push_back(b.getIndexType()); // The extra counter. |
892 | return ret; |
893 | } |
894 | |
895 | bool isBatchIterator() const override { return wrap->isBatchIterator(); } |
896 | bool randomAccessible() const override { return wrap->randomAccessible(); }; |
897 | bool iteratableByFor() const override { return randomAccessible(); }; |
898 | Value upperBound(OpBuilder &b, Location l) const override { |
899 | return subSect.subSectSz; |
900 | } |
901 | |
902 | ValueRange getCurPosition() const override { return wrap->getCurPosition(); }; |
903 | |
904 | Value getNxLvlTupleId(OpBuilder &b, Location l) const { |
905 | if (randomAccessible()) { |
906 | return ADDI(getCrd(), nxLvlTupleStart); |
907 | }; |
908 | return ADDI(getCursor().back(), nxLvlTupleStart); |
909 | } |
910 | |
911 | void genInitImpl(OpBuilder &b, Location l, const SparseIterator *) override { |
912 | if (randomAccessible()) { |
      if (auto *p = llvm::dyn_cast<SubSectIterator>(&parent)) {
914 | assert(p->lvl + 1 == lvl); |
915 | wrap->genInit(b, l, p); |
916 | // Linearize the dense subsection index. |
917 | nxLvlTupleStart = MULI(subSect.subSectSz, p->getNxLvlTupleId(b, l)); |
918 | } else { |
919 | assert(subSect.lvl == lvl && subSect.isSubSectRoot()); |
        wrap->deserialize(subSect.delegate->serialize());
921 | nxLvlTupleStart = C_IDX(0); |
922 | } |
923 | return; |
924 | } |
925 | assert(!randomAccessible()); |
926 | assert(getCursor().size() == wrap->getCursor().size() + 1); |
927 | // Extra counter that counts the number of actually visited coordinates in |
928 | // the sparse subsection. |
929 | getMutCursorVals().back() = C_IDX(0); |
930 | Value tupleId; |
    if (auto *p = llvm::dyn_cast<SubSectIterator>(&parent)) {
932 | assert(p->lvl + 1 == lvl); |
933 | tupleId = p->getNxLvlTupleId(b, l); |
934 | } else { |
935 | assert(subSect.lvl == lvl && subSect.isSubSectRoot()); |
936 | tupleId = C_IDX(0); |
937 | } |
938 | nxLvlTupleStart = subSect.loadNxLvlStart(b, l, tupleId); |
939 | helper.deserializeFromTupleId(b, l, tupleId); |
940 | } |
941 | |
942 | void locateImpl(OpBuilder &b, Location l, Value crd) override { |
943 | helper.locate(b, l, crd); |
944 | updateCrd(crd); |
945 | } |
946 | |
947 | Value genNotEndImpl(OpBuilder &b, Location l) override { |
948 | return helper.genNotEnd(b, l); |
949 | } |
950 | |
951 | Value derefImpl(OpBuilder &b, Location l) override { |
952 | Value crd = helper.deref(b, l); |
953 | updateCrd(crd); |
954 | return crd; |
955 | }; |
956 | |
957 | ValueRange forwardImpl(OpBuilder &b, Location l) override { |
958 | helper.forward(b, l); |
959 | assert(!randomAccessible()); |
960 | assert(getCursor().size() == wrap->getCursor().size() + 1); |
961 | getMutCursorVals().back() = ADDI(getCursor().back(), C_IDX(1)); |
962 | return getCursor(); |
963 | }; |
964 | |
965 | Value nxLvlTupleStart; |
966 | |
967 | const NonEmptySubSectIterator &subSect; |
968 | std::unique_ptr<SparseIterator> wrap; |
969 | const SparseIterator &parent; |
970 | |
971 | SubSectIterHelper helper; |
972 | }; |
973 | |
974 | } // namespace |
975 | |
976 | //===----------------------------------------------------------------------===// |
977 | // SparseIterator derived classes implementation. |
978 | //===----------------------------------------------------------------------===// |
979 | |
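// Note: under SparseEmitStrategy::kDebugInterface, the methods below emit
// placeholder operations named "<prefix>.begin", "<prefix>.not_end",
// "<prefix>.locate", "<prefix>.deref", and "<prefix>.next" (e.g.,
// "trivial<...>.begin") instead of the actual lowering.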
980 | void SparseIterator::genInit(OpBuilder &b, Location l, |
981 | const SparseIterator *p) { |
982 | if (emitStrategy == SparseEmitStrategy::kDebugInterface) { |
983 | std::string prefix = getDebugInterfacePrefix(); |
    Operation *begin = b.create(l, b.getStringAttr(prefix + ".begin"), {},
                                getCursorValTypes(b));
    seek(begin->getResults());
987 | return; |
988 | } |
  // Inherit batch coordinates from the parent.
  if (p)
    inherentBatch(*p);
992 | // TODO: support lowering to function call. |
993 | return genInitImpl(b, l, p); |
994 | } |
995 | |
996 | Value SparseIterator::genNotEnd(OpBuilder &b, Location l) { |
997 | if (emitStrategy == SparseEmitStrategy::kDebugInterface) { |
998 | std::string prefix = getDebugInterfacePrefix(); |
    Operation *notEnd = b.create(l, b.getStringAttr(prefix + ".not_end"),
                                 getCursor(), b.getI1Type());
    return notEnd->getResult(0);
1002 | } |
1003 | // TODO: support lowering to function call. |
1004 | return genNotEndImpl(b, l); |
1005 | } |
1006 | |
1007 | void SparseIterator::locate(OpBuilder &b, Location l, Value crd) { |
1008 | if (emitStrategy == SparseEmitStrategy::kDebugInterface) { |
1009 | std::string prefix = getDebugInterfacePrefix(); |
1010 | SmallVector<Value> args = getCursor(); |
    args.push_back(crd);
    Operation *locate = b.create(l, b.getStringAttr(prefix + ".locate"), args,
                                 getCursorValTypes(b));
    seek(locate->getResults());
1015 | updateCrd(crd); |
1016 | return; |
1017 | } |
1018 | return locateImpl(b, l, crd); |
1019 | } |
1020 | |
1021 | Value SparseIterator::deref(OpBuilder &b, Location l) { |
1022 | if (emitStrategy == SparseEmitStrategy::kDebugInterface) { |
1023 | std::string prefix = getDebugInterfacePrefix(); |
1024 | SmallVector<Value> args = getCursor(); |
    Operation *deref = b.create(l, b.getStringAttr(prefix + ".deref"),
                                getCursor(), b.getIndexType());
    updateCrd(deref->getResult(0));
1028 | return getCrd(); |
1029 | } |
1030 | return derefImpl(b, l); |
1031 | } |
1032 | |
1033 | ValueRange SparseIterator::forward(OpBuilder &b, Location l) { |
1034 | assert(!randomAccessible()); |
1035 | if (emitStrategy == SparseEmitStrategy::kDebugInterface) { |
1036 | std::string prefix = getDebugInterfacePrefix(); |
    Operation *next = b.create(l, b.getStringAttr(prefix + ".next"),
                               getCursor(), getCursorValTypes(b));
    seek(next->getResults());
1040 | return getCursor(); |
1041 | } |
1042 | return forwardImpl(b, l); |
1043 | } |
1044 | |
1045 | ValueRange SparseIterator::forwardIf(OpBuilder &b, Location l, Value cond) { |
1046 | auto ifOp = b.create<scf::IfOp>(l, getCursor().getTypes(), cond, true); |
1047 | // Generate else branch first, otherwise iterator values will be updated by |
1048 | // `forward()`. |
1049 | b.setInsertionPointToStart(ifOp.elseBlock()); |
1050 | YIELD(getCursor()); |
1051 | |
1052 | b.setInsertionPointToStart(ifOp.thenBlock()); |
1053 | YIELD(forward(b, l)); |
1054 | |
1055 | b.setInsertionPointAfter(ifOp); |
  seek(ifOp.getResults());
1057 | return getCursor(); |
1058 | } |
1059 | |
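// Generates code that returns the end position of the segment of duplicate
// coordinates starting at `pos`. Illustrative example: for coordinates
// [2, 2, 2, 5] and pos = 0, the generated while loop yields segment high 3.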
1060 | Value DedupIterator::genSegmentHigh(OpBuilder &b, Location l, Value pos) { |
1061 | auto whileOp = b.create<scf::WhileOp>( |
1062 | l, pos.getType(), pos, |
1063 | /*beforeBuilder=*/ |
1064 | [this, pos](OpBuilder &b, Location l, ValueRange ivs) { |
1065 | Value inBound = CMPI(ult, ivs.front(), posHi); |
1066 | auto ifInBound = b.create<scf::IfOp>(l, b.getI1Type(), inBound, true); |
1067 | { |
1068 | OpBuilder::InsertionGuard guard(b); |
1069 | // If in bound, load the next coordinates and check duplication. |
1070 | b.setInsertionPointToStart(ifInBound.thenBlock()); |
1071 | Value headCrd = stl.peekCrdAt(b, l, getBatchCrds(), pos); |
1072 | Value tailCrd = stl.peekCrdAt(b, l, getBatchCrds(), ivs.front()); |
1073 | Value isDup = CMPI(eq, headCrd, tailCrd); |
1074 | YIELD(isDup); |
1075 | // Else, the position is out of bound, yield false. |
1076 | b.setInsertionPointToStart(ifInBound.elseBlock()); |
1077 | YIELD(constantI1(b, l, false)); |
1078 | } |
1079 | b.create<scf::ConditionOp>(l, ifInBound.getResults()[0], ivs); |
1080 | }, |
1081 | /*afterBuilder=*/ |
1082 | [](OpBuilder &b, Location l, ValueRange ivs) { |
1083 | Value nxPos = ADDI(ivs[0], C_IDX(1)); |
1084 | YIELD(nxPos); |
1085 | }); |
1086 | // Return the segment high. |
1087 | return whileOp.getResult(0); |
1088 | } |
1089 | |
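// Returns a predicate indicating that the coordinate loaded from the wrapped
// iterator does *not* belong to the filtered view. Illustrative example: with
// offset = 1, stride = 2, size = 3, wrapCrd = 6 is off-stride ((6 - 1) / 2 = 2
// but 2 * 2 + 1 = 5 != 6), wrapCrd = 0 lies below the offset, and wrapCrd = 7
// maps to crd = 3 >= size; all three are rejected.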
1090 | Value FilterIterator::genCrdNotLegitPredicate(OpBuilder &b, Location l, |
1091 | Value wrapCrd) { |
1092 | Value crd = fromWrapCrd(b, l, wrapCrd); |
1093 | // Test whether the coordinate is on stride. |
1094 | Value notlegit = CMPI(ne, toWrapCrd(b, l, crd), wrapCrd); |
1095 | // Test wrapCrd < offset |
1096 | notlegit = ORI(CMPI(ult, wrapCrd, offset), notlegit); |
1097 | // Test crd >= length |
1098 | notlegit = ORI(CMPI(uge, crd, size), notlegit); |
1099 | return notlegit; |
1100 | } |
1101 | |
1102 | Value FilterIterator::genShouldFilter(OpBuilder &b, Location l) { |
1103 | auto r = genWhenInBound( |
      b, l, *wrap, C_FALSE,
      [this](OpBuilder &b, Location l, Value wrapCrd) -> scf::ValueVector {
1106 | Value notLegit = genCrdNotLegitPredicate(b, l, wrapCrd); |
1107 | return {notLegit}; |
1108 | }); |
  return llvm::getSingleElement(r);
1110 | } |
1111 | |
1112 | Value FilterIterator::genNotEndImpl(OpBuilder &b, Location l) { |
1113 | assert(!wrap->randomAccessible()); |
1114 | auto r = genWhenInBound( |
      b, l, *wrap, C_FALSE,
      [this](OpBuilder &b, Location l, Value wrapCrd) -> scf::ValueVector {
1117 | Value crd = fromWrapCrd(b, l, wrapCrd); |
1118 | // crd < size |
1119 | return {CMPI(ult, crd, size)}; |
1120 | }); |
  return llvm::getSingleElement(r);
1122 | } |
1123 | |
1124 | ValueRange FilterIterator::forwardImpl(OpBuilder &b, Location l) { |
1125 | assert(!randomAccessible()); |
1126 | // Generates |
1127 | // |
1128 | // bool isFirst = true; |
1129 | // while !it.end() && (!legit(*it) || isFirst) |
1130 | // wrap ++; |
1131 | // isFirst = false; |
1132 | // |
  // We do not hoist the first `wrap++` outside the loop but use an `isFirst`
1134 | // flag here because `wrap++` might have a complex implementation (e.g., to |
1135 | // forward a subsection). |
  Value isFirst = constantI1(b, l, true);
1137 | |
1138 | SmallVector<Value> whileArgs(getCursor().begin(), getCursor().end()); |
  whileArgs.push_back(isFirst);
1140 | auto whileOp = b.create<scf::WhileOp>( |
1141 | l, ValueRange(whileArgs).getTypes(), whileArgs, |
1142 | /*beforeBuilder=*/ |
1143 | [this](OpBuilder &b, Location l, ValueRange ivs) { |
1144 | ValueRange isFirst = linkNewScope(ivs); |
1145 | scf::ValueVector cont = |
1146 | genWhenInBound(b, l, *wrap, C_FALSE, |
1147 | [this, isFirst](OpBuilder &b, Location l, |
1148 | Value wrapCrd) -> scf::ValueVector { |
1149 | // crd < size && !legit(); |
1150 | Value notLegit = |
1151 | genCrdNotLegitPredicate(b, l, wrapCrd); |
1152 | Value crd = fromWrapCrd(b, l, wrapCrd); |
1153 | Value ret = ANDI(CMPI(ult, crd, size), notLegit); |
1154 | ret = ORI(ret, llvm::getSingleElement(isFirst)); |
1155 | return {ret}; |
1156 | }); |
1157 | b.create<scf::ConditionOp>(l, cont.front(), ivs); |
1158 | }, |
1159 | /*afterBuilder=*/ |
1160 | [this](OpBuilder &b, Location l, ValueRange ivs) { |
1161 | linkNewScope(ivs); |
1162 | wrap->forward(b, l); |
1163 | SmallVector<Value> yieldVals(getCursor().begin(), getCursor().end()); |
1164 | yieldVals.push_back(constantI1(b, l, false)); |
1165 | YIELD(yieldVals); |
1166 | }); |
1167 | |
1168 | b.setInsertionPointAfter(whileOp); |
  linkNewScope(whileOp.getResults());
1170 | return getCursor(); |
1171 | } |
1172 | |
1173 | SubSectIterHelper::SubSectIterHelper(const NonEmptySubSectIterator &subSect) |
1174 | : subSect(subSect), wrap(*subSect.delegate) {} |
1175 | |
1176 | SubSectIterHelper::SubSectIterHelper(const SubSectIterator &iter) |
1177 | : subSect(iter.subSect), wrap(*iter.wrap) {} |
1178 | |
1179 | void SubSectIterHelper::deserializeFromTupleId(OpBuilder &b, Location l, |
1180 | Value tupleId) { |
1181 | assert(!subSect.randomAccessible()); |
  wrap.deserialize(subSect.loadCursorVals(b, l, tupleId));
1183 | } |
1184 | |
1185 | void SubSectIterHelper::locate(OpBuilder &b, Location l, Value crd) { |
1186 | Value absCrd = ADDI(crd, subSect.getAbsOff()); |
  wrap.locate(b, l, absCrd);
1188 | } |
1189 | |
1190 | Value SubSectIterHelper::genNotEnd(OpBuilder &b, Location l) { |
1191 | assert(!wrap.randomAccessible()); |
1192 | auto r = genWhenInBound( |
      b, l, wrap, C_FALSE,
      [this](OpBuilder &b, Location l, Value wrapCrd) -> scf::ValueVector {
1195 | Value crd = SUBI(wrapCrd, subSect.getAbsOff()); |
1196 | // crd < size |
1197 | return {CMPI(ult, crd, subSect.subSectSz)}; |
1198 | }); |
  return llvm::getSingleElement(r);
1200 | } |
1201 | |
1202 | Value SubSectIterHelper::deref(OpBuilder &b, Location l) { |
1203 | Value wrapCrd = wrap.deref(b, l); |
1204 | Value crd = subSect.toSubSectCrd(b, l, wrapCrd); |
1205 | return crd; |
1206 | } |
1207 | |
1208 | ValueRange SubSectIterHelper::forward(OpBuilder &b, Location l) { |
1209 | return wrap.forward(b, l); |
1210 | } |
1211 | |
1212 | ValueRange NonEmptySubSectIterator::inflateSubSectTree( |
1213 | OpBuilder &b, Location l, ValueRange reduc, TraverseBuilder builder) const { |
1214 | // Set up the helper to help traverse a sparse subsection. |
1215 | SubSectIterHelper helper(*this); |
1216 | if (!randomAccessible()) { |
    // The subsection tree has been expanded up to this level and cached;
    // traverse all the leaves and expand them to the next level.
1219 | SmallVector<Value> iterArgs; |
1220 | iterArgs.push_back(C_IDX(0)); |
    iterArgs.append(reduc.begin(), reduc.end());
1222 | auto forEachLeaf = b.create<scf::ForOp>( |
1223 | l, /*lb=*/C_IDX(0), /*ub=*/tupleCnt, /*step=*/C_IDX(1), iterArgs, |
1224 | [&helper, &builder](OpBuilder &b, Location l, Value tupleId, |
1225 | ValueRange iterArgs) { |
1226 | // Deserialize the iterator at the cached position (tupleId). |
1227 | helper.deserializeFromTupleId(b, l, tupleId); |
1228 | |
1229 | Value cnt = iterArgs.front(); |
1230 | // Record the number of leaf nodes included in the subsection. |
          // The number indicates the starting tupleId for the next level that
          // corresponds to the current node.
1233 | helper.subSect.storeNxLvlStart(b, l, tupleId, cnt); |
1234 | |
1235 | SmallVector<Value> whileArgs(helper.wrap.getCursor()); |
1236 | whileArgs.append(iterArgs.begin(), iterArgs.end()); |
1237 | |
1238 | auto whileOp = b.create<scf::WhileOp>( |
1239 | l, ValueRange(whileArgs).getTypes(), whileArgs, |
1240 | /*beforeBuilder=*/ |
1241 | [&helper](OpBuilder &b, Location l, ValueRange ivs) { |
1242 | helper.wrap.linkNewScope(ivs); |
1243 | b.create<scf::ConditionOp>(l, helper.genNotEnd(b, l), ivs); |
1244 | }, |
1245 | /*afterBuilder=*/ |
1246 | [&helper, &builder](OpBuilder &b, Location l, ValueRange ivs) { |
1247 | ValueRange remIter = helper.wrap.linkNewScope(ivs); |
1248 | Value cnt = remIter.front(); |
1249 | ValueRange userIter = remIter.drop_front(); |
1250 | scf::ValueVector userNx = builder(b, l, &helper.wrap, userIter); |
1251 | |
1252 | SmallVector<Value> nxIter = helper.forward(b, l); |
1253 | nxIter.push_back(ADDI(cnt, C_IDX(1))); |
1254 | nxIter.append(userNx.begin(), userNx.end()); |
1255 | YIELD(nxIter); |
1256 | }); |
1257 | ValueRange res = helper.wrap.linkNewScope(whileOp.getResults()); |
1258 | YIELD(res); |
1259 | }); |
1260 | return forEachLeaf.getResults().drop_front(); |
1261 | } |
1262 | |
1263 | assert(randomAccessible()); |
  // Helper lambda that traverses the current dense subsection range.
1265 | auto visitDenseSubSect = [&, this](OpBuilder &b, Location l, |
1266 | const SparseIterator *parent, |
1267 | ValueRange reduc) { |
1268 | assert(!parent || parent->lvl + 1 == lvl); |
    delegate->genInit(b, l, parent);
1270 | auto forOp = b.create<scf::ForOp>( |
1271 | l, /*lb=*/C_IDX(0), /*ub=*/subSectSz, /*step=*/C_IDX(1), reduc, |
1272 | [&](OpBuilder &b, Location l, Value crd, ValueRange iterArgs) { |
1273 | helper.locate(b, l, crd); |
1274 | scf::ValueVector nx = builder(b, l, &helper.wrap, iterArgs); |
1275 | YIELD(nx); |
1276 | }); |
1277 | return forOp.getResults(); |
1278 | }; |
1279 | |
1280 | if (isSubSectRoot()) { |
1281 | return visitDenseSubSect(b, l, parent, reduc); |
1282 | } |
1283 | // Else, this is not the root, recurse until root. |
  auto *p = llvm::cast<NonEmptySubSectIterator>(parent);
1285 | assert(p->lvl + 1 == lvl); |
1286 | return p->inflateSubSectTree(b, l, reduc, visitDenseSubSect); |
1287 | } |
1288 | |
1289 | void TrivialIterator::genInitImpl(OpBuilder &b, Location l, |
1290 | const SparseIterator *parent) { |
1291 | |
1292 | if (isBatchIterator() && batchCrds.size() <= stl.lvl) |
    batchCrds.resize(stl.lvl + 1, nullptr);
1294 | |
1295 | Value c0 = C_IDX(0); |
1296 | ValueRange pPos = c0; |
1297 | Value inPadZone = nullptr; |
1298 | // If the parent iterator is a batch iterator, we also start from 0 (but |
1299 | // on a different batch). |
1300 | if (parent && !parent->isBatchIterator()) { |
1301 | pPos = parent->getCurPosition(); |
    if (llvm::isa<PadIterator>(parent) && parent->randomAccessible()) {
      // A padded dense iterator creates a "sparse" padded zone, which needs to
      // be handled specially.
1305 | inPadZone = pPos.back(); |
1306 | pPos = pPos.drop_back(); |
1307 | } |
1308 | } |
1309 | |
1310 | ValueRange batchPrefix = parent ? parent->getBatchCrds() : ValueRange{}; |
  std::tie(posLo, posHi) = stl.peekRangeAt(b, l, batchPrefix, pPos, inPadZone);
  // Seek to the lowest position.
  seek(posLo);
1314 | } |
1315 | |
1316 | void NonEmptySubSectIterator::genInitImpl(OpBuilder &b, Location l, |
1317 | const SparseIterator *) { |
1318 | Value c0 = C_IDX(0); |
1319 | if (!isSubSectRoot()) { |
1320 | assert(parent->lvl + 1 == lvl); |
1321 | if (randomAccessible()) { |
      // We cannot call wrap->genInit() here to initialize the wrapped
      // iterator, because the parent of the current iterator is still
      // unresolved.
      seek({/*minCrd=*/c0, /*offset=*/c0, /*notEnd=*/C_TRUE});
1326 | return; |
1327 | } |
1328 | |
    auto *p = cast<NonEmptySubSectIterator>(parent);
1330 | SmallVector<Value, 3> reduc = { |
1331 | C_IDX(-1), // minCrd (max signless integer) |
1332 | c0, // tupleId |
1333 | }; |
1334 | |
1335 | // Expand the subsection tree from the parent level to the current level. |
1336 | ValueRange result = p->inflateSubSectTree( |
1337 | b, l, reduc, |
        [this](OpBuilder &b, Location l, const SparseIterator *parent,
1339 | ValueRange reduc) -> scf::ValueVector { |
1340 | assert(parent->lvl + 1 == lvl && reduc.size() == 2); |
1341 | Value minCrd = reduc.front(); |
1342 | Value tupleId = reduc.back(); |
1343 | |
1344 | // Initialize the subsection range. |
1345 | SubSectIterHelper helper(*this); |
          helper.wrap.genInit(b, l, parent);
1347 | |
1348 | // Update minCrd. |
          minCrd = genWhenInBound(b, l, helper.wrap, minCrd,
                                  [minCrd](OpBuilder &b, Location l,
1351 | Value crd) -> scf::ValueVector { |
1352 | Value min = MINUI(crd, minCrd); |
1353 | return {min}; |
1354 | }) |
1355 | .front(); |
1356 | |
1357 | // Cache the sparse range. |
          storeCursorVals(b, l, tupleId, helper.wrap.serialize());
1359 | tupleId = ADDI(tupleId, C_IDX(1)); |
1360 | return {minCrd, tupleId}; |
1361 | }); |
1362 | assert(result.size() == 2); |
1363 | tupleCnt = result.back(); |
1364 | |
1365 | Value minCrd = result.front(); |
    Value absOff = offsetFromMinCrd(b, l, minCrd, subSectSz);
1367 | Value notEnd = CMPI(ne, minCrd, C_IDX(-1)); |
    seek({minCrd, absOff, notEnd});
1369 | return; |
1370 | } |
1371 | |
1372 | // This is the root level of the subsection, which means that it is resolved |
1373 | // to one node. |
1374 | assert(isSubSectRoot()); |
1375 | |
1376 | // Initialize the position, the position marks the *lower bound* of the |
1377 | // subRange. The higher bound is determined by the size of the subsection. |
  delegate->genInit(b, l, parent);
1379 | if (randomAccessible()) { |
    seek({/*minCrd=*/c0, /*offset=*/c0, /*notEnd=*/C_TRUE});
1381 | return; |
1382 | } |
1383 | |
1384 | // Only have one root node. |
1385 | tupleCnt = C_IDX(1); |
1386 | // Cache the sparse range. |
  storeCursorVals(b, l, c0, delegate->serialize());
1388 | SmallVector<Value> elseRet{c0, c0, /*notEnd=*/C_FALSE}; |
1389 | auto meta = genWhenInBound( |
      b, l, *delegate, elseRet,
      [this](OpBuilder &b, Location l, Value crd) -> scf::ValueVector {
        Value offset = offsetFromMinCrd(b, l, crd, subSectSz);
1393 | return {crd, offset, C_TRUE}; |
1394 | }); |
1395 | |
  seek(meta);
1397 | } |
1398 | |
1399 | ValueRange NonEmptySubSectIterator::forwardImpl(OpBuilder &b, Location l) { |
1400 | assert(!randomAccessible()); |
1401 | Value c0 = C_IDX(0), c1 = C_IDX(1); |
  // Forward to the next non-empty slice by generating
1403 | // |
1404 | // if (minCrd > offset) { |
1405 | // offset += 1 |
1406 | // } else { |
1407 | // minCrd = nextMinInSlice(); |
1408 | // offset = minCrd - size + 1; |
1409 | // } |
1410 | // |
1411 | // if (offset + size > parents.size) |
1412 | // isNonEmpty = false; |
1413 | Value fastPathP = CMPI(ugt, getMinCrd(), getAbsOff()); |
1414 | auto ifOp = b.create<scf::IfOp>(l, getCursor().getTypes(), fastPathP, true); |
1415 | { |
1416 | OpBuilder::InsertionGuard guard(b); |
1417 | // Take the fast path |
1418 | // if (minCrd > offset) |
1419 | // offset += 1 |
1420 | b.setInsertionPointToStart(&ifOp.getThenRegion().front()); |
1421 | Value nxOffset = ADDI(getAbsOff(), c1); |
1422 | YIELD((ValueRange{getMinCrd(), nxOffset, getNotEnd()})); |
1423 | |
1424 | // else /*minCrd == offset*/ { |
1425 | // for (i = 0; i < tupleCnt; i++) { |
1426 | // wrap->deserialize(pos[i]); |
1427 | // minCrd=min(minCrd, *wrap); |
1428 | // } |
1429 | // offset = minCrd - size + 1; |
1430 | // } |
1431 | b.setInsertionPointToStart(&ifOp.getElseRegion().front()); |
1432 | SmallVector<Value, 2> loopArgs{C_IDX(-1), // nextMinCrd |
1433 | C_FALSE}; // isNotEnd |
1434 | auto loopNest = scf::buildLoopNest( |
        b, l, c0, tupleCnt, c1, loopArgs,
        [this](OpBuilder &b, Location l, ValueRange ivs,
1437 | ValueRange iterArgs) -> scf::ValueVector { |
1438 | Value tupleId = ivs.front(); |
1439 | SubSectIterHelper helper(*this); |
1440 | helper.deserializeFromTupleId(b, l, tupleId); |
1441 | |
1442 | return genWhenInBound( |
              b, l, *delegate, /*elseRet=*/iterArgs,
              [this, iterArgs, tupleId](OpBuilder &b, Location l,
1445 | Value crd) -> scf::ValueVector { |
1446 | // if coord == minCrd |
1447 | // wrap->forward(); |
1448 | Value isMin = CMPI(eq, crd, getMinCrd()); |
                delegate->forwardIf(b, l, isMin);
1450 | // Update the forwarded iterator values if needed. |
1451 | auto ifIsMin = b.create<scf::IfOp>(l, isMin, false); |
1452 | b.setInsertionPointToStart(&ifIsMin.getThenRegion().front()); |
                storeCursorVals(b, l, tupleId, delegate->serialize());
1454 | b.setInsertionPointAfter(ifIsMin); |
1455 | // if (!wrap.end()) |
1456 | // yield(min(nxMinCrd, *wrap), true) |
1457 | Value nxMin = iterArgs[0]; |
                return genWhenInBound(b, l, *delegate, /*elseRet=*/iterArgs,
                                      [nxMin](OpBuilder &b, Location l,
1460 | Value crd) -> scf::ValueVector { |
1461 | Value nx = b.create<arith::MinUIOp>( |
1462 | l, crd, nxMin); |
1463 | return {nx, C_TRUE}; |
1464 | }); |
1465 | }); |
1466 | }); |
1467 | |
1468 | scf::ForOp forOp = loopNest.loops.front(); |
1469 | b.setInsertionPointAfter(forOp); |
1470 | |
1471 | Value nxMinCrd = forOp.getResult(0); |
1472 | Value nxNotEnd = forOp.getResult(1); |
    Value nxAbsOff = offsetFromMinCrd(b, l, nxMinCrd, subSectSz);
1474 | YIELD((ValueRange{nxMinCrd, nxAbsOff, nxNotEnd})); |
1475 | } |
1476 | |
1477 | Value nxMinCrd = ifOp.getResult(0); |
1478 | Value nxAbsOff = ifOp.getResult(1); |
1479 | Value nxNotEnd = ifOp.getResult(2); |
1480 | |
1481 | // We should at least forward the offset by one. |
1482 | Value minAbsOff = ADDI(getAbsOff(), c1); |
1483 | nxAbsOff = b.create<arith::MaxUIOp>(l, minAbsOff, nxAbsOff); |
1484 | |
  seek(ValueRange{nxMinCrd, nxAbsOff, nxNotEnd});
  // The coordinate should not exceed the space upper bound.
  Value crd = deref(b, l);
  nxNotEnd = ANDI(nxNotEnd, CMPI(ult, crd, upperBound(b, l)));

  seek(ValueRange{nxMinCrd, nxAbsOff, nxNotEnd});
1491 | return getCursor(); |
1492 | } |
1493 | |
1494 | //===----------------------------------------------------------------------===// |
1495 | // SparseIterationSpace Implementation |
1496 | //===----------------------------------------------------------------------===// |
1497 | |
1498 | mlir::sparse_tensor::SparseIterationSpace::SparseIterationSpace( |
1499 | Location l, OpBuilder &b, Value t, unsigned tid, |
1500 | std::pair<Level, Level> lvlRange, ValueRange parentPos) |
1501 | : lvls() { |
1502 | auto [lvlLo, lvlHi] = lvlRange; |
1503 | |
1504 | Value c0 = C_IDX(0); |
1505 | if (parentPos.empty()) |
1506 | parentPos = c0; |
1507 | |
1508 | for (Level lvl = lvlLo; lvl < lvlHi; lvl++) |
    lvls.emplace_back(makeSparseTensorLevel(b, l, t, tid, lvl));
1510 | |
1511 | bound = lvls.front()->peekRangeAt(b, l, /*batchPrefix=*/{}, parentPos); |
1512 | for (auto &lvl : getLvlRef().drop_front()) |
    bound = lvl->collapseRangeBetween(b, l, /*batchPrefix=*/{}, bound);
1514 | } |
1515 | |
1516 | SparseIterationSpace mlir::sparse_tensor::SparseIterationSpace::fromValues( |
1517 | IterSpaceType dstTp, ValueRange values, unsigned int tid) { |
1518 | // Reconstruct every sparse tensor level. |
1519 | SparseIterationSpace space; |
1520 | for (auto [i, lt] : llvm::enumerate(dstTp.getLvlTypes())) { |
1521 | unsigned bufferCnt = 0; |
1522 | if (lt.isWithPosLT()) |
1523 | bufferCnt++; |
1524 | if (lt.isWithCrdLT()) |
1525 | bufferCnt++; |
1526 | // Sparse tensor buffers. |
1527 | ValueRange buffers = values.take_front(bufferCnt); |
1528 | values = values.drop_front(bufferCnt); |
1529 | |
1530 | // Level size. |
1531 | Value sz = values.front(); |
1532 | values = values.drop_front(); |
1533 | space.lvls.push_back( |
1534 | makeSparseTensorLevel(lt, sz, buffers, tid, i + dstTp.getLoLvl())); |
1535 | } |
1536 | // Two bounds. |
  space.bound = std::make_pair(values[0], values[1]);
  values = values.drop_front(2);
1539 | |
1540 | // Must have consumed all values. |
1541 | assert(values.empty()); |
1542 | return space; |
1543 | } |
1544 | |
1545 | std::unique_ptr<SparseIterator> |
SparseIterationSpace::extractIterator(OpBuilder &b, Location l) const {
  return makeSimpleIterator(b, l, *this);
1548 | } |
1549 | |
1550 | //===----------------------------------------------------------------------===// |
1551 | // SparseIterator factory functions. |
1552 | //===----------------------------------------------------------------------===// |
1553 | |
/// Helper function to create a SparseTensorLevel object from the given level
/// type, size, and buffers.
1555 | std::unique_ptr<SparseTensorLevel> |
1556 | sparse_tensor::makeSparseTensorLevel(LevelType lt, Value sz, ValueRange b, |
1557 | unsigned t, Level l) { |
1558 | assert(lt.getNumBuffer() == b.size()); |
1559 | switch (lt.getLvlFmt()) { |
1560 | case LevelFormat::Dense: |
    return std::make_unique<DenseLevel>(t, l, sz);
  case LevelFormat::Batch:
    return std::make_unique<BatchLevel>(t, l, sz);
  case LevelFormat::Compressed:
    return std::make_unique<CompressedLevel>(t, l, lt, sz, b[0], b[1]);
  case LevelFormat::LooseCompressed:
    return std::make_unique<LooseCompressedLevel>(t, l, lt, sz, b[0], b[1]);
  case LevelFormat::Singleton:
    return std::make_unique<SingletonLevel>(t, l, lt, sz, b[0]);
  case LevelFormat::NOutOfM:
    return std::make_unique<NOutOfMLevel>(t, l, lt, sz, b[0]);
  case LevelFormat::Undef:
    llvm_unreachable("undefined level format");
  }
  llvm_unreachable("unrecognizable level format");
1576 | } |
1577 | |
1578 | std::unique_ptr<SparseTensorLevel> |
1579 | sparse_tensor::makeSparseTensorLevel(OpBuilder &b, Location l, Value t, |
1580 | unsigned tid, Level lvl) { |
  auto stt = getSparseTensorType(t);

  LevelType lt = stt.getLvlType(lvl);
1584 | Value sz = stt.hasEncoding() ? b.create<LvlOp>(l, t, lvl).getResult() |
1585 | : b.create<tensor::DimOp>(l, t, lvl).getResult(); |
1586 | |
1587 | SmallVector<Value, 2> buffers; |
1588 | if (lt.isWithPosLT()) { |
1589 | Value pos = b.create<ToPositionsOp>(l, t, lvl); |
    buffers.push_back(pos);
  }
  if (lt.isWithCrdLT()) {
    Value pos = b.create<ToCoordinatesOp>(l, t, lvl);
    buffers.push_back(pos);
  }
  return makeSparseTensorLevel(lt, sz, buffers, tid, lvl);
1597 | } |
1598 | |
1599 | std::pair<std::unique_ptr<SparseTensorLevel>, std::unique_ptr<SparseIterator>> |
1600 | sparse_tensor::makeSynLevelAndIterator(Value sz, unsigned tid, unsigned lvl, |
1601 | SparseEmitStrategy strategy) { |
  auto stl = std::make_unique<BatchLevel>(tid, lvl, sz);
  auto it = std::make_unique<TrivialIterator>(*stl);
  it->setSparseEmitStrategy(strategy);
  return std::make_pair(std::move(stl), std::move(it));
1606 | } |
1607 | |
1608 | std::unique_ptr<SparseIterator> |
1609 | sparse_tensor::makeSimpleIterator(OpBuilder &b, Location l, |
1610 | const SparseIterationSpace &iterSpace) { |
1611 | // assert(iterSpace.getSpaceDim() == 1); |
1612 | std::unique_ptr<SparseIterator> ret; |
1613 | if (!iterSpace.isUnique()) { |
    // We always deduplicate the non-unique level, but we should optimize it
    // away if possible.
    ret = std::make_unique<DedupIterator>(b, l, iterSpace.getLastLvl(),
                                          iterSpace.getBoundLo(),
                                          iterSpace.getBoundHi());
1619 | } else { |
    ret = std::make_unique<TrivialIterator>(b, l, iterSpace.getLastLvl(),
                                            iterSpace.getBoundLo(),
                                            iterSpace.getBoundHi());
1623 | } |
1624 | ret->setSparseEmitStrategy(SparseEmitStrategy::kFunctional); |
1625 | return ret; |
1626 | } |
1627 | |
1628 | std::unique_ptr<SparseIterator> |
1629 | sparse_tensor::makeSimpleIterator(const SparseTensorLevel &stl, |
1630 | SparseEmitStrategy strategy) { |
1631 | std::unique_ptr<SparseIterator> ret; |
  if (!isUniqueLT(stl.getLT())) {
    // We always deduplicate the non-unique level, but we should optimize it
    // away if possible.
    ret = std::make_unique<DedupIterator>(stl);
  } else {
    ret = std::make_unique<TrivialIterator>(stl);
1638 | } |
1639 | ret->setSparseEmitStrategy(strategy); |
1640 | return ret; |
1641 | } |
1642 | |
1643 | std::unique_ptr<SparseIterator> |
1644 | sparse_tensor::makeSlicedLevelIterator(std::unique_ptr<SparseIterator> &&sit, |
1645 | Value offset, Value stride, Value size, |
1646 | SparseEmitStrategy strategy) { |
1647 | |
  auto ret =
      std::make_unique<FilterIterator>(std::move(sit), offset, stride, size);
1650 | ret->setSparseEmitStrategy(strategy); |
1651 | return ret; |
1652 | } |
1653 | |
1654 | std::unique_ptr<SparseIterator> |
1655 | sparse_tensor::makePaddedIterator(std::unique_ptr<SparseIterator> &&sit, |
1656 | Value padLow, Value padHigh, |
1657 | SparseEmitStrategy strategy) { |
  auto ret = std::make_unique<PadIterator>(std::move(sit), padLow, padHigh);
1659 | ret->setSparseEmitStrategy(strategy); |
1660 | return ret; |
1661 | } |
1662 | |
1663 | static const SparseIterator *tryUnwrapFilter(const SparseIterator *it) { |
  auto *filter = llvm::dyn_cast_or_null<FilterIterator>(it);
1665 | if (filter) |
1666 | return &filter->getWrappedIterator(); |
1667 | return it; |
1668 | } |
1669 | |
1670 | std::unique_ptr<SparseIterator> sparse_tensor::makeNonEmptySubSectIterator( |
1671 | OpBuilder &b, Location l, const SparseIterator *parent, Value loopBound, |
1672 | std::unique_ptr<SparseIterator> &&delegate, Value size, unsigned stride, |
1673 | SparseEmitStrategy strategy) { |
1674 | |
  // Try to unwrap the NonEmptySubSectIterator from a filter parent.
  parent = tryUnwrapFilter(parent);
  std::unique_ptr<SparseIterator> it =
      std::make_unique<NonEmptySubSectIterator>(b, l, parent,
                                                std::move(delegate), size);
1680 | |
1681 | if (stride != 1) { |
1682 | // TODO: We can safely skip bound checking on sparse levels, but for dense |
1683 | // iteration space, we need the bound to infer the dense loop range. |
    it = std::make_unique<FilterIterator>(std::move(it), /*offset=*/C_IDX(0),
                                          C_IDX(stride), /*size=*/loopBound);
1686 | } |
1687 | it->setSparseEmitStrategy(strategy); |
1688 | return it; |
1689 | } |
1690 | |
1691 | std::unique_ptr<SparseIterator> sparse_tensor::makeTraverseSubSectIterator( |
1692 | OpBuilder &b, Location l, const SparseIterator &subSectIter, |
1693 | const SparseIterator &parent, std::unique_ptr<SparseIterator> &&wrap, |
1694 | Value loopBound, unsigned stride, SparseEmitStrategy strategy) { |
1695 | |
1696 | // This must be a subsection iterator or a filtered subsection iterator. |
  auto &subSect =
      llvm::cast<NonEmptySubSectIterator>(*tryUnwrapFilter(&subSectIter));

  std::unique_ptr<SparseIterator> it = std::make_unique<SubSectIterator>(
      subSect, *tryUnwrapFilter(&parent), std::move(wrap));
1702 | |
1703 | if (stride != 1) { |
    it = std::make_unique<FilterIterator>(std::move(it), /*offset=*/C_IDX(0),
                                          C_IDX(stride), /*size=*/loopBound);
1706 | } |
1707 | it->setSparseEmitStrategy(strategy); |
1708 | return it; |
1709 | } |
1710 | |
1711 | #undef CMPI |
1712 | #undef C_IDX |
1713 | #undef YIELD |
1714 | #undef ADDI |
1715 | #undef ANDI |
1716 | #undef SUBI |
1717 | #undef MULI |
1718 | #undef SELECT |
1719 | |