//===- LoopEmitter.cpp ----------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "LoopEmitter.h"
#include "CodegenUtils.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"

using namespace mlir;
using namespace mlir::sparse_tensor;

//===----------------------------------------------------------------------===//
// File local shorthand macros
//===----------------------------------------------------------------------===//

#define CMPI(p, l, r)                                                          \
  (builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::p, (l), (r))       \
       .getResult())

#define C_IDX(v) (constantIndex(builder, loc, (v)))
#define YIELD(vs) (builder.create<scf::YieldOp>(loc, (vs)))
#define ADDI(lhs, rhs) (builder.create<arith::AddIOp>(loc, (lhs), (rhs)))
#define ANDI(lhs, rhs) (builder.create<arith::AndIOp>(loc, (lhs), (rhs)))
#define SUBI(lhs, rhs) (builder.create<arith::SubIOp>(loc, (lhs), (rhs)))
#define MULI(lhs, rhs) (builder.create<arith::MulIOp>(loc, (lhs), (rhs)))
#define REMUI(lhs, rhs) (builder.create<arith::RemUIOp>(loc, (lhs), (rhs)))
#define DIVUI(lhs, rhs) (builder.create<arith::DivUIOp>(loc, (lhs), (rhs)))
#define SELECT(c, l, r) (builder.create<arith::SelectOp>(loc, (c), (l), (r)))
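// Note that these macros expand at their point of use and refer to the local
// `builder` and `loc` variables of the enclosing function.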

//===----------------------------------------------------------------------===//
// Debugging utils
//===----------------------------------------------------------------------===//

#ifndef NDEBUG
LLVM_ATTRIBUTE_UNUSED static void dumpIndexMemRef(OpBuilder &builder,
                                                  Location loc, Value memref) {
  memref = builder.create<memref::CastOp>(
      loc, UnrankedMemRefType::get(builder.getIndexType(), 0), memref);
  createFuncCall(builder, loc, "printMemrefInd", TypeRange{},
                 ValueRange{memref}, EmitCInterface::On);
}
#endif

//===----------------------------------------------------------------------===//
// File local helper functions.
//===----------------------------------------------------------------------===//

// For index reduction loops, since the tensors are sliced into non-contiguous
// fragments, we need a triple [pLo, pHi, pPtr], in which the pair (pLo, pHi)
// specifies the range of the fragment, and pPtr specifies the index of the
// corresponding fragment in the child level (i.e., a pointer to the sliced
// position array).
static Value genSliceOffset(OpBuilder &builder, Location loc, Value tensor,
                            Level lvl) {
  auto enc = getSparseTensorEncoding(tensor.getType());
  return createOrFoldSliceOffsetOp(builder, loc, tensor, toDim(enc, lvl));
}

static Value genSliceStride(OpBuilder &builder, Location loc, Value tensor,
                            Level lvl) {
  auto enc = getSparseTensorEncoding(tensor.getType());
  return createOrFoldSliceStrideOp(builder, loc, tensor, toDim(enc, lvl));
}

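// Returns true if `attr` is an integer or floating-point attribute holding a
// zero value.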
static bool isIntOrFPZero(Attribute attr) {
  if (auto f = llvm::dyn_cast<FloatAttr>(attr); f && f.getValue().isZero())
    return true;
  if (auto i = llvm::dyn_cast<IntegerAttr>(attr); i && i.getValue().isZero())
    return true;
  return false;
}

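// Materializes an OpFoldResult as an index Value, creating a constant op when
// the result holds a static integer.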
static Value unFoldOpIntResult(OpBuilder &builder, Location loc,
                               OpFoldResult ofr) {
  if (std::optional<int64_t> i = getConstantIntValue(ofr); i.has_value())
    return constantIndex(builder, loc, *i);
  return cast<Value>(ofr);
}

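// Folds away a `tensor.pad` that pads with a constant zero (and preserves the
// sparse encoding) by returning its source tensor; otherwise returns `t`
// unchanged.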
static Value tryFoldTensors(Value t) {
  // TODO: this should be done through a folding pass after switching to
  // `sparse_tensor.iterate`-based sparsification.
  auto stt = tryGetSparseTensorType(t);
  auto padOp = t.getDefiningOp<tensor::PadOp>();
  if (padOp && stt.has_value() && stt->hasEncoding() &&
      padOp.getSourceType().getEncoding() == stt->getEncoding() &&
      stt->getEncoding().isIdentity()) {
    // Try fusing padOp with zeros.
    Attribute padCst;
    if (matchPattern(padOp.getBody()->getTerminator(),
                     m_Op<tensor::YieldOp>(m_Constant(&padCst))) &&
        isIntOrFPZero(padCst)) {
      return padOp.getSource();
    }
  }
  return t;
}

//===----------------------------------------------------------------------===//
// Sparse tensor loop emitter class implementations
//===----------------------------------------------------------------------===//

LoopEmitter::LoopEmitter(ValueRange tensors, StringAttr loopTag, bool hasOutput,
                         bool isSparseOut, unsigned numLoops,
                         DependentLvlGetter dimGetter,
                         SparseEmitStrategy emitStrategy) {
  initialize(tensors, loopTag, hasOutput, isSparseOut, numLoops, dimGetter,
             emitStrategy);
}

void LoopEmitter::initialize(ValueRange ts, StringAttr loopTag, bool hasOutput,
                             bool isSparseOut, unsigned numLoops,
                             DependentLvlGetter dimGetter,
                             SparseEmitStrategy emitStrategy) {
  // First initialize the top-level type of the fields.
  this->loopTag = loopTag;
  this->hasOutput = hasOutput;
  this->isSparseOut = isSparseOut;
  this->emitStrategy = emitStrategy;

  const unsigned numManifestTensors = ts.size();
  const unsigned synTensorId = numManifestTensors;
  const unsigned numTensors = numManifestTensors + 1;
  // The tensors array (len == numManifestTensors).
  this->tensors.assign(ts.begin(), ts.end());
  // Arrays with len == numTensors.
  this->valBuffer.assign(numTensors, nullptr);
  this->lvls.resize(numTensors);
  this->iters.resize(numTensors);
  this->spIterVals.resize(numTensors);

  // These zeros will be overwritten below, but we need to initialize
  // them to something since we'll need random-access assignment.
  this->loopStack.reserve(numLoops);
  this->loopSeqStack.reserve(numLoops);

  // Index-reduction related fields.
  this->dependentLvlMap.assign(
      numTensors, std::vector<std::vector<std::pair<TensorLevel, unsigned>>>());
  this->sliceMeta.assign(
      numTensors, std::vector<std::vector<std::pair<Value, unsigned>>>());
  this->levelReducedDep.assign(numTensors, std::vector<unsigned>());

  // Initialize nested types of `TensorId`-indexed fields.
  for (TensorId tid = 0; tid < numTensors; tid++) {
    Level lvlRank;
    if (tid == synTensorId) {
      // The synthetic tensor is (conceptually) an all-dense tensor with rank
      // equal to the total number of loops (each level can potentially be
      // mapped to one of the loops being generated).
      lvlRank = numLoops;
    } else {
      const Value t = tensors[tid];
      // A scalar or zero-ranked tensor.
      if (isZeroRankedTensorOrScalar(t.getType()))
        continue;

      auto rtp = getRankedTensorType(t);
      const SparseTensorType stt(rtp);
      lvlRank = stt.getLvlRank();
    }

    lvls[tid].resize(lvlRank);
    iters[tid].resize(lvlRank);
    spIterVals[tid].resize(lvlRank);
    loopHighs.assign(numLoops, nullptr);

    // Slice-driven loops related initialization.
    levelReducedDep[tid].assign(lvlRank, 0);
    dependentLvlMap[tid].assign(
        lvlRank, std::vector<std::pair<TensorLevel, unsigned>>());
    sliceMeta[tid].assign(lvlRank, std::vector<std::pair<Value, unsigned>>());
    if (dimGetter && !isSynTensor(tid)) {
      for (Level l = 0; l < lvlRank; l++) {
        std::vector<std::pair<LoopId, unsigned>> deps = dimGetter(tid, l);
        // Sort the dependent loops by loop order.
        llvm::sort(deps, llvm::less_first());

        dependentLvlMap[tid][l] = std::move(deps);
        unsigned depends = dependentLvlMap[tid][l].size();
        if (depends == 0)
          continue;
        sliceMeta[tid][l].reserve(depends);
      }
    }
  }
}

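// Creates an iterator over level `l` of tensor `t`: the simple level iterator
// is wrapped in a padded iterator when the tensor comes from a foldable
// zero-padding `tensor.pad`, and in a sliced iterator when the encoding
// denotes a tensor slice.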
std::unique_ptr<SparseIterator>
LoopEmitter::makeLevelIterator(OpBuilder &builder, Location loc, TensorId t,
                               Level l) {
  Value tensor = tensors[t];
  auto stt = getSparseTensorType(tensor);
  auto it = makeSimpleIterator(*lvls[t][l], emitStrategy);

  Value folded = tryFoldTensors(tensor);
  if (folded != tensor) {
    auto padOp = tensor.getDefiningOp<tensor::PadOp>();
    assert(padOp);
    if (padOp.getPaddedDims().test(l)) {
      Value low = unFoldOpIntResult(builder, loc, padOp.getMixedLowPad()[l]);
      Value high = unFoldOpIntResult(builder, loc, padOp.getMixedHighPad()[l]);
      auto padIt = makePaddedIterator(std::move(it), low, high, emitStrategy);
      return padIt;
    }
  }

  if (stt.hasEncoding() && stt.getEncoding().isSlice()) {
    Value offset = genSliceOffset(builder, loc, tensor, l);
    Value stride = genSliceStride(builder, loc, tensor, l);
    auto slicedIt = makeSlicedLevelIterator(
        std::move(it), offset, stride, lvls[t][l]->getSize(), emitStrategy);
    return slicedIt;
  }

  return it;
}

void LoopEmitter::initializeLoopEmit(
    OpBuilder &builder, Location loc, LoopEmitter::OutputUpdater updater,
    LoopEmitter::SynTensorBoundSetter synSetter) {

  // For every manifest tensor, set up the values buffer.
  for (TensorId t = 0, numTensors = getNumManifestTensors(); t < numTensors;
       t++) {
    // TODO: this should be done through a folding pass after switching to
    // `sparse_tensor.iterate`-based sparsification.
    const Value tensor = tryFoldTensors(tensors[t]);
    const auto rtp = dyn_cast<RankedTensorType>(tensor.getType());
    // Skips only scalars; zero-ranked tensors still need to be bufferized and
    // (probably) filled with zeros by users.
    if (!rtp)
      continue;

    auto stt = getSparseTensorType(tensor);
    const auto shape = rtp.getShape();

    // Perform the required bufferization. Dense inputs materialize from the
    // input tensors. Sparse inputs use sparse primitives to obtain the values.
    // Delegates extra output initialization to clients.
    bool isOutput = isOutputTensor(t);
    Type elementType = stt.getElementType();
    if (!stt.hasEncoding()) {
      // Non-annotated dense tensors.
      BaseMemRefType denseTp = MemRefType::get(shape, elementType);

      // TODO: if we unconditionally use a fully dynamic layout here, it breaks
      // some vectorization passes which require a static stride of 1.
      // Is it possible to call the vectorization pass after bufferization?
      if (llvm::isa_and_nonnull<tensor::ExtractSliceOp>(tensor.getDefiningOp()))
        denseTp = bufferization::getMemRefTypeWithFullyDynamicLayout(rtp);

      Value denseVal =
          builder.create<bufferization::ToBufferOp>(loc, denseTp, tensor);
      // Dense outputs need special handling.
      if (isOutput && updater)
        denseVal = updater(builder, loc, denseVal, tensor);

      valBuffer[t] = denseVal;
    } else {
      // Annotated sparse tensors.
      // We also need the value buffer for all-dense annotated "sparse"
      // tensors.
      valBuffer[t] = builder.create<ToValuesOp>(loc, tensor);
    }
  }

  // The sparse iterator values will only be available after the loop is
  // constructed.
  if (emitStrategy == SparseEmitStrategy::kSparseIterator)
    return;

  // For every synthetic tensor, set the high bound by calling the callback.
  if (synSetter) {
    TensorId synId = getSynTensorId();
    for (unsigned i = 0, e = loopHighs.size(); i < e; i++) {
      Value sz = loopHighs[i] = synSetter(builder, loc, i);
      auto [stl, it] = makeSynLevelAndIterator(sz, synId, i, emitStrategy);
      lvls[synId][i] = std::move(stl);
      iters[synId][i].emplace_back(std::move(it));
    }
  }

  // For every manifest tensor:
  // * For every level:
  //   * get the positions and coordinates buffers
  //   * get/compute the level-size, which is also used as the upper-bound
  //     on positions.
  for (TensorId t = 0, numTensors = getNumManifestTensors(); t < numTensors;
       t++) {
    // TODO: this should be done through a folding pass after switching to
    // `sparse_tensor.iterate`-based sparsification.
    const Value tensor = tryFoldTensors(tensors[t]);
    const auto rtp = dyn_cast<RankedTensorType>(tensor.getType());
    if (!rtp)
      // Skips only scalars; zero-ranked tensors still need to be bufferized
      // and (probably) filled with zeros by users.
      continue;

    auto stt = getSparseTensorType(tensor);
    const Level lvlRank = stt.getLvlRank();

    // Scan all levels of current tensor.
    for (Level l = 0; l < lvlRank; l++) {
      // Find upper bound in current dimension.
      lvls[t][l] = makeSparseTensorLevel(builder, loc, tensor, t, l);
      if (!dependentLvlMap[t][l].empty())
        continue;

      auto it = makeLevelIterator(builder, loc, t, l);
      iters[t][l].emplace_back(std::move(it));
    }
    // NOTE: we can also prepare for the 0th level here in advance. This would
    // hoist some loop preparation out of tensor iteration, but would also
    // (undesirably) hoist the code outside of if-conditions.
  }
  // TODO: avoid treating subsection iterator as a special case.
  initSubSectIterator(builder, loc);
}

void LoopEmitter::initSubSectIterator(OpBuilder &builder, Location loc) {
  Value c0 = C_IDX(0);
  for (TensorId t = 0, e = tensors.size(); t < e; t++) {
    auto rtp = dyn_cast<RankedTensorType>(tensors[t].getType());
    if (!rtp)
      continue;

    Level lvlRank = SparseTensorType(rtp).getLvlRank();

    // Compute the dependency reduction order.
    auto remDepStack = dependentLvlMap;
    std::vector<std::tuple<LoopId, TensorId, Level>> depRedOrder;
    for (Level lvl = 0; lvl < lvlRank; lvl++) {
      // Reverse queue into a stack.
      std::reverse(remDepStack[t][lvl].begin(), remDepStack[t][lvl].end());
      for (auto [loop, coeff] : dependentLvlMap[t][lvl])
        depRedOrder.emplace_back(std::make_tuple(loop, t, lvl));
    }

    if (depRedOrder.empty())
      continue;

    llvm::sort(depRedOrder, llvm::less_first());

    SmallVector<SparseIterator *> lastIter(tensors.size(), nullptr);
    for (auto [loop, t, lvl] : depRedOrder) {
      std::pair<LoopId, unsigned> curDep = remDepStack[t][lvl].back();
      assert(curDep.first == loop);
      remDepStack[t][lvl].pop_back();

      auto lvlIt = makeLevelIterator(builder, loc, t, lvl);
      const SparseIterator *parent = lastIter[t];
      if (!parent && lvl > 0) {
        if (dependentLvlMap[t][lvl - 1].empty()) {
          parent = iters[t][lvl - 1].back().get();
        }
      }

      std::unique_ptr<SparseIterator> it;
      if (!remDepStack[t][lvl].empty()) {
        // Compute the subsection size.
        Value size = c0;
        for (auto [loop, stride] : remDepStack[t][lvl]) {
          Value idxMax = SUBI(loopHighs[loop], C_IDX(1));
          size = ADDI(size, ADDI(MULI(idxMax, C_IDX(stride)), C_IDX(1)));
        }
        it = makeNonEmptySubSectIterator(builder, loc, parent, loopHighs[loop],
                                         std::move(lvlIt), size, curDep.second,
                                         emitStrategy);
      } else {
        const SparseIterator &subSectIter = *iters[t][lvl].back();
        it = makeTraverseSubSectIterator(builder, loc, subSectIter, *parent,
                                         std::move(lvlIt), loopHighs[loop],
                                         curDep.second, emitStrategy);
      }
      lastIter[t] = it.get();
      iters[t][lvl].emplace_back(std::move(it));
    }
  }
}

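// Splits the iterators for the given tensor levels into random-accessible
// (dense) iterators and sparse iterators; the sparse iterators are further
// sorted by kind (AffineUnRed > Affine > Slice > Trivial).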
void LoopEmitter::categorizeIterators(
    ArrayRef<TensorLevel> tidLvls, SmallVectorImpl<SparseIterator *> &raIters,
    SmallVectorImpl<SparseIterator *> &spIters) {
  // Finds out the tensor level that we should use to generate loops. Among all
  // the tensor levels, there is at most one sparse tensor level.
  for (auto [t, l] : unpackTensorLevelRange(tidLvls)) {
    SparseIterator *it = &getCurIterator(t, l);
    if (it->randomAccessible())
      raIters.push_back(it);
    else
      spIters.push_back(it);
  }

  llvm::stable_sort(spIters, [](auto lhs, auto rhs) {
    // AffineUnRed > Affine > Slice > Trivial
    return static_cast<uint8_t>(lhs->kind) > static_cast<uint8_t>(rhs->kind);
  });
}

void LoopEmitter::enterNewLoopSeq(OpBuilder &builder, Location loc,
                                  ArrayRef<TensorLevel> tidLvls) {
  // TODO: sort
  assert(loopSeqStack.size() == loopStack.size());

  if (emitStrategy != SparseEmitStrategy::kSparseIterator) {
    // Prepares for all the tensors used in the current loop sequence.
    for (auto [tid, lvl] : unpackTensorLevelRange(tidLvls)) {
      levelReducedDep[tid][lvl]++;
      prepareLoopOverTensorAtLvl(builder, loc, tid, lvl);
    }
  }

  // Universal Index starts from 0.
  loopSeqStack.emplace_back(C_IDX(0), tidLvls.vec());
}

void LoopEmitter::exitCurrentLoopSeq(OpBuilder &builder, Location loc) {
  assert(loopSeqStack.size() == loopStack.size() + 1);

  // Depending on whether the slice is resolved or not at the current loop
  // sequence, end them in different ways.
  for (auto [tid, lvl] : unpackTensorLevelRange(loopSeqStack.back().second))
    levelReducedDep[tid][lvl]--;

  loopSeqStack.pop_back();
}

Value LoopEmitter::genAffine(OpBuilder &builder, Location loc, AffineExpr a) {
  switch (a.getKind()) {
  case AffineExprKind::DimId: {
    // FIXME: since the one callsite in Sparsification passes in a
    // level-expression, the `getPosition` must in fact be a `Dimension`.
    // However, elsewhere we have been led to expect that `loopIdToOrd`
    // should be indexed by `LoopId`...
    const auto loopId = cast<AffineDimExpr>(a).getPosition();
    return loopStack[loopId].iv;
  }
  case AffineExprKind::Add: {
    auto binOp = cast<AffineBinaryOpExpr>(a);
    return ADDI(genAffine(builder, loc, binOp.getLHS()),
                genAffine(builder, loc, binOp.getRHS()));
  }
  case AffineExprKind::Mul: {
    auto binOp = cast<AffineBinaryOpExpr>(a);
    return MULI(genAffine(builder, loc, binOp.getLHS()),
                genAffine(builder, loc, binOp.getRHS()));
  }
  case AffineExprKind::Constant: {
    int64_t c = cast<AffineConstantExpr>(a).getValue();
    return C_IDX(c);
  }
  default:
    llvm_unreachable("unexpected affine subscript");
  }
}

std::pair<Operation *, Value> LoopEmitter::emitForLoopOverTensorAtLvl(
    OpBuilder &builder, Location loc, SparseIterator &iter,
    MutableArrayRef<Value> reduc, bool isParallel) {

  // TODO: support dynamic slices.
  // Uses the first dimension here to build the loop bound (which is also the
  // biggest range).

  Value step = C_IDX(1);
  auto [lo, hi] = iter.genForCond(builder, loc);
  Operation *loop = nullptr;
  Value iv;
  if (isParallel) {
    scf::ParallelOp parOp =
        builder.create<scf::ParallelOp>(loc, lo, hi, step, reduc);
    builder.setInsertionPointToStart(parOp.getBody());
    assert(parOp.getNumReductions() == reduc.size());
    iv = parOp.getInductionVars()[0];

    // In-place update on the reduction variable vector.
    // Note that the init vals are not the actual reduction variables but are
    // instead used as a "special handle" to (temporarily) represent them. The
    // expression on the init vals will be moved into scf.reduce and replaced
    // with the block arguments when exiting the loop (see exitForLoop). This
    // is needed because we cannot build the actual reduction block and get the
    // actual reduction variable before users fill the parallel loop body.
    for (int i = 0, e = reduc.size(); i < e; i++)
      reduc[i] = parOp.getInitVals()[i];
    loop = parOp;
  } else {
    scf::ForOp forOp = builder.create<scf::ForOp>(loc, lo, hi, step, reduc);
    builder.setInsertionPointToStart(forOp.getBody());
    iv = forOp.getInductionVar();

    // In-place update on the reduction variable vector.
    assert(forOp.getNumRegionIterArgs() == reduc.size());
    for (int i = 0, e = reduc.size(); i < e; i++)
      reduc[i] = forOp.getRegionIterArg(i);
    loop = forOp;
  }
  assert(loop && iv);

  Value crd = iv;
  if (!iter.randomAccessible()) {
    iter.linkNewScope(iv);
    crd = iter.deref(builder, loc);
  } else {
    iter.locate(builder, loc, iv);
  }

  return {loop, crd};
}

std::pair<Operation *, Value> LoopEmitter::emitWhileLoopOverTensorsAtLvls(
    OpBuilder &builder, Location loc, ArrayRef<SparseIterator *> spIters,
    MutableArrayRef<Value> reduc, bool needsUniv) {
  return genCoIteration(builder, loc, spIters, reduc,
                        needsUniv ? loopSeqStack.back().first : nullptr);
}

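// At most one sparse iterator can be lowered to a simple scf.for (provided the
// iterator supports it); co-iterating over multiple sparse iterators requires
// an scf.while.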
bool LoopEmitter::shouldIteratedByForLoop(ArrayRef<SparseIterator *> spIters) {
  // If we need to co-iterate over two or more sparse tensors, we need a while
  // loop.
  if (spIters.size() > 1)
    return false;

  if (spIters.size() == 1)
    return spIters.front()->iteratableByFor();

  return true;
}

Region *LoopEmitter::enterCurrentCoIterationCase(OpBuilder &builder,
                                                 Location loc,
                                                 I64BitSet caseBit,
                                                 unsigned caseIdx,
                                                 MutableArrayRef<Value> reduc) {
  auto coIterOp = cast<CoIterateOp>(loopStack.back().loop);
  SmallVector<Attribute> cases(coIterOp.getCases().getAsRange<Attribute>());
  cases[caseIdx] = builder.getI64IntegerAttr(caseBit);

  coIterOp.setCasesAttr(builder.getArrayAttr(cases));
  Region &caseRegion = coIterOp.getRegion(caseIdx);
  assert(caseRegion.getBlocks().empty() &&
         "re-initialize the same coiteration case region.");

  // Each block starts with a list of user-provided iteration arguments.
  TypeRange iterArgsTps = coIterOp.getInitArgs().getTypes();
  // Followed by a list of used coordinates of index type.
  SmallVector<Type> blockArgTps(coIterOp.getCrdUsedLvls().count(),
                                builder.getIndexType());

  blockArgTps.append(iterArgsTps.begin(), iterArgsTps.end());
  // Ends with a set of iterators that defines the actual iteration space.
  for (auto i : caseBit.bits()) {
    blockArgTps.push_back(
        cast<IterSpaceType>(coIterOp.getIterSpaces()[i].getType())
            .getIteratorType());
  }
  SmallVector<Location> locs(blockArgTps.size(), loc);
  caseRegion.emplaceBlock().addArguments(blockArgTps, locs);

  // Entering the new region scope, updating the SSA chain.
  builder.setInsertionPointToStart(&caseRegion.front());
  // Update the coordinates.
  loopStack.back().iv = coIterOp.getCrds(caseIdx).front();
  // Updates loop iteration arguments.
  ValueRange iterArgs = coIterOp.getRegionIterArgs(caseIdx);
  llvm::copy(iterArgs, reduc.begin());
  // Updates sparse iterator values.
  ValueRange iters = coIterOp.getRegionIterators(caseIdx);
  ArrayRef<TensorLevel> tidLvls = loopStack.back().tidLvls;
  for (auto [i, tl] : llvm::enumerate(unpackTensorLevelRange(tidLvls))) {
    if (caseBit[i]) {
      spIterVals[tl.first][tl.second] = iters.front();
      iters = iters.drop_front();
    } else {
      spIterVals[tl.first][tl.second] = nullptr;
    }
  }
  // Must have consumed all iterator SSA values.
  assert(iters.empty());
  return &caseRegion;
}

Operation *LoopEmitter::enterCoIterationOverTensorsAtLvls(
    OpBuilder &builder, Location loc, ArrayRef<TensorLevel> tidLvls,
    unsigned numCases, MutableArrayRef<Value> reduc, bool tryParallel,
    bool needsUniv) {
  // TODO: Argument `numCases` is only used when generating iterator-based
  // sparse loops. Simplify the code upon feature completion.
  // TODO: handle coiteration with sparse iterator.
  if (emitStrategy == SparseEmitStrategy::kSparseIterator) {
    if (tidLvls.size() == 1) {
      auto [tid, lvl] = unpackTensorLevel(tidLvls.front());
      Value t = tensors[tid];

      // Extract and iterate over the iteration space.
      ExtractIterSpaceOp extractSpaceOp =
          lvl == 0 ? builder.create<ExtractIterSpaceOp>(loc, t)
                   : builder.create<ExtractIterSpaceOp>(
                         loc, t, spIterVals[tid][lvl - 1], lvl);

      IterateOp iterOp = builder.create<IterateOp>(
          loc, extractSpaceOp.getExtractedSpace(), reduc);
      spIterVals[tid][lvl] = iterOp.getIterator();

      // Update the reduction variables.
      llvm::copy(iterOp.getRegionIterArgs(), reduc.begin());
      // Set the insertion point to the loop body.
      builder.setInsertionPointToStart(iterOp.getBody());
      loopStack.emplace_back(tidLvls, iterOp, builder.getInsertionBlock(),
                             iterOp.getCrds().front(), loopTag);
      return iterOp;
    }

    // CoIteration Loops.
    SmallVector<Value> spaces;
    for (auto [tid, lvl] : unpackTensorLevelRange(tidLvls)) {
      Value t = tensors[tid];
      ExtractIterSpaceOp extractSpaceOp =
          lvl == 0 ? builder.create<ExtractIterSpaceOp>(loc, t)
                   : builder.create<ExtractIterSpaceOp>(
                         loc, t, spIterVals[tid][lvl - 1], lvl);
      spaces.push_back(extractSpaceOp.getExtractedSpace());
    }
    auto coIterOp = builder.create<CoIterateOp>(loc, spaces, reduc, numCases);
    // The CoIterateOp has neither an insertion block nor an induction
    // variable.
    // TODO: the `struct LoopInfo` should be simplified after full migration.
    loopStack.emplace_back(tidLvls, coIterOp, /*insertion block*/ nullptr,
                           /*induction variable*/ nullptr, loopTag);
    return coIterOp;
  }

  // TODO: support multiple returns on parallel for?
  tryParallel = tryParallel && reduc.size() <= 1;

  SmallVector<SparseIterator *> raIters;
  SmallVector<SparseIterator *> spIters;
  categorizeIterators(tidLvls, raIters, spIters);

  // Only when there is at least one sparse condition do we really need the
  // universal index.
  // TODO: Maybe we should instead require the merger to pass in a valid value
  // in the first place instead of adjusting it in LoopEmitter?
  needsUniv = !spIters.empty() && needsUniv;
  // The TensorLevel used for loop conditions.
  // If there is any sparse level, we need to use the sparse condition.
  // If all levels are dense, we can pick an arbitrary one (a dense
  // slice-driven loop can be generated using a simple ForOp as well).
  Operation *l = nullptr;
  Value iv = nullptr;
  SmallVector<TensorLevel> tls;

  // Generates loops differently depending on whether we need a slice-driven
  // loop or a simple level traversal loop.
  if (shouldIteratedByForLoop(spIters) && !needsUniv) {
    assert(spIters.size() <= 1);
    SparseIterator &it = spIters.empty() ? *raIters.front() : *spIters.front();
    std::tie(l, iv) =
        emitForLoopOverTensorAtLvl(builder, loc, it, reduc, tryParallel);
    tls.push_back(makeTensorLevel(it.tid, it.lvl));
  } else {
    for (auto *it : spIters) {
      tls.push_back(makeTensorLevel(it->tid, it->lvl));
    }

    if (needsUniv)
      for (auto *it : raIters)
        tls.push_back(makeTensorLevel(it->tid, it->lvl));

    std::tie(l, iv) =
        emitWhileLoopOverTensorsAtLvls(builder, loc, spIters, reduc, needsUniv);
  }

  // Enter dense tensor levels.
  for (SparseIterator *it : raIters)
    it->locate(builder, loc, iv);

  // NOTE: we can also prepare for the next dim here in advance.
  // Pushes the loop onto the stack.
  loopStack.emplace_back(tls, l, builder.getInsertionBlock(), iv, loopTag);
  return l;
}

void LoopEmitter::locateLvlAtAffineAddress(OpBuilder &builder, Location loc,
                                           TensorLevel tidLvl,
                                           AffineExpr lvlExpr) {
  auto [tid, lvl] = unpackTensorLevel(tidLvl);

  const SparseIterator *parent =
      lvl == 0 ? nullptr : iters[tid][lvl - 1].back().get();
  auto &it = getCurIterator(tid, lvl);
  it.genInit(builder, loc, parent);

  assert(it.kind == IterKind::kTrivial && it.randomAccessible());
  Value lvlCrd = genAffine(builder, loc, lvlExpr);
  it.locate(builder, loc, lvlCrd);
}

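// Initializes the current iterator for (`tid`, `lvl`) before a loop is emitted
// over it, positioning random-accessible iterators at coordinate 0.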
void LoopEmitter::prepareLoopOverTensorAtLvl(OpBuilder &builder, Location loc,
                                             TensorId tid, Level lvl) {
  // If this is the first level, there is no parent iterator for the current
  // iterator.
  // If the current iterator is a subsection-based iterator, the parent
  // iterator is memorized by the iterator itself.
  bool hasParent = lvl == 0 || !dependentLvlMap[tid][lvl].empty();

  const SparseIterator *parent =
      hasParent ? nullptr : iters[tid][lvl - 1].back().get();
  auto &it = getCurIterator(tid, lvl);
  it.genInit(builder, loc, parent);

  // Locates the random-accessible iterator at coordinate 0.
  if (it.randomAccessible())
    it.locate(builder, loc, C_IDX(0));
}

void LoopEmitter::exitForLoop(RewriterBase &rewriter, Location loc,
                              MutableArrayRef<Value> reduc) {
  const LoopInfo &loopInfo = loopStack.back();
  if (emitStrategy == SparseEmitStrategy::kSparseIterator) {
    auto iterateOp = llvm::cast<IterateOp>(loopInfo.loop);
    assert(reduc.size() == iterateOp.getNumResults());
    rewriter.create<sparse_tensor::YieldOp>(loc, reduc);
    // Exit the loop.
    rewriter.setInsertionPointAfter(iterateOp);
    // In-place update reduction variables.
    llvm::copy(iterateOp.getResults(), reduc.begin());
    return;
  }
  if (auto forOp = llvm::dyn_cast<scf::ForOp>(loopInfo.loop)) {
    if (!reduc.empty()) {
      assert(reduc.size() == forOp.getNumResults());
      rewriter.create<scf::YieldOp>(loc, reduc);
    }
    // Exit the loop.
    rewriter.setInsertionPointAfter(forOp);
    // In-place update reduction variables.
    llvm::copy(forOp.getResults(), reduc.begin());
  } else {
    auto parOp = llvm::cast<scf::ParallelOp>(loopInfo.loop);
    if (!reduc.empty()) {
      assert(reduc.size() == parOp.getInitVals().size() && reduc.size() == 1);
      Operation *redExp = reduc.front().getDefiningOp();
      // The reduction expression should have no uses yet.
      assert(redExp->getUses().empty());
      // This must be a binary operation.
      // NOTE: It is the users' responsibility to ensure the operation is
      // commutative.
      assert(redExp->getNumOperands() == 2 && redExp->getNumResults() == 1);

      Value redVal = parOp.getInitVals().front();
      Value curVal;
      if (redExp->getOperand(0) == redVal)
        curVal = redExp->getOperand(1);
      else if (redExp->getOperand(1) == redVal)
        curVal = redExp->getOperand(0);
      // One of the operands must be the init value (which is also the
      // previous reduction value).
      assert(curVal);
#ifndef NDEBUG
      // The reduction expression should be the only user of the reduction val
      // inside the parallel for.
      unsigned numUsers = 0;
      for (Operation *op : redVal.getUsers()) {
        if (op->getParentOp() == parOp)
          numUsers++;
      }
      assert(numUsers == 1);
#endif // NDEBUG

      rewriter.setInsertionPointAfter(redExp);
      auto redOp = rewriter.create<scf::ReduceOp>(loc, curVal);
      // Attach to the reduction op.
      Block *redBlock = &redOp.getReductions().front().front();
      rewriter.setInsertionPointToEnd(redBlock);
      Operation *newRed = rewriter.clone(*redExp);
      // Replaces arguments of the reduction expression by using the block
      // arguments from scf.reduce.
      rewriter.modifyOpInPlace(
          newRed, [&]() { newRed->setOperands(redBlock->getArguments()); });
      // Erases the out-dated reduction expression.
      rewriter.eraseOp(redExp);
      rewriter.setInsertionPointToEnd(redBlock);
      rewriter.create<scf::ReduceReturnOp>(loc, newRed->getResult(0));
    }
    rewriter.setInsertionPointAfter(parOp);
    // In-place update reduction variables.
    for (unsigned i = 0, e = parOp.getResults().size(); i < e; i++)
      reduc[i] = parOp.getResult(i);
  }
}

void LoopEmitter::exitWhileLoop(OpBuilder &builder, Location loc,
                                MutableArrayRef<Value> reduc) {
  const LoopInfo &loopInfo = loopStack.back();
  auto whileOp = llvm::cast<scf::WhileOp>(loopInfo.loop);
  Value iv = loopInfo.iv;
  Value one = C_IDX(1);

  // Finalize the induction. Note that the induction could be performed
  // in the individual if-branches to avoid re-evaluating the conditions.
  // However, that would result in a rather elaborate forest of yield
  // instructions during code generation. Moreover, performing the induction
  // after the if-statements more closely resembles code generated by TACO.
  SmallVector<Value> operands;
  ValueRange whileRes = whileOp.getResults();

  for (auto [tid, lvl] : unpackTensorLevelRange(loopInfo.tidLvls)) {
    SparseIterator &it = getCurIterator(tid, lvl);
    if (!it.randomAccessible()) {
      // Forward the sparse iterator.
      Value cmp = CMPI(eq, it.getCrd(), iv);
      it.forwardIf(builder, loc, cmp);
      operands.append(it.getCursor().begin(), it.getCursor().end());
      // const Value newPos = whileOp->getResult(o++);
      // Following loops continue iteration from the break point of the
      // current while loop.
      whileRes = it.linkNewScope(whileRes);
    } else {
      // Make sure the randomly accessible (dense) iterator is set to the right
      // position according to the universal index.
      Value uniIdx = whileOp.getResults().back();
      it.locate(builder, loc, uniIdx);
    }
  }

  // Reduction value from users.
  for (auto &i : reduc) {
    operands.push_back(i);
    // Update user reduction variables.
    i = whileRes.front();
    whileRes = whileRes.drop_front();
  }

  // An (optional) universal index.
  if (operands.size() < whileOp.getNumResults()) {
    assert(operands.size() + 1 == whileOp.getNumResults());
    // The last one is the universal index.
    operands.push_back(ADDI(iv, one));
    // Update the loop starting point of the current loop sequence.
    loopSeqStack.back().first = whileOp->getResults().back();
  }

  if (!operands.empty())
    YIELD(operands);

  builder.setInsertionPointAfter(whileOp);
}

void LoopEmitter::exitCurrentLoop(RewriterBase &rewriter, Location loc,
                                  MutableArrayRef<Value> reduc) {
  // Clean up the values; this helps us discover potential bugs at an earlier
  // stage (instead of silently using a wrong value).
  const LoopInfo &loopInfo = loopStack.back();
  if (emitStrategy == SparseEmitStrategy::kSparseIterator) {
    Operation *p = loopInfo.loop;
    if (isa<IterateOp>(p))
      rewriter.create<sparse_tensor::YieldOp>(loc, reduc);

    // Exit the loop.
    rewriter.setInsertionPointAfter(p);
    // In-place update reduction variables.
    llvm::copy(p->getResults(), reduc.begin());
    loopStack.pop_back();
    return;
  }

  // Sets the insertion point to the right position.
  rewriter.setInsertionPointToEnd(loopInfo.userCodeBlock);
  if (!loopInfo.userCodeBlock->empty() &&
      llvm::isa<scf::YieldOp>(&loopInfo.userCodeBlock->back())) {
    // scf::While/For inserts an implicit yield op when there are no loop
    // iter args. In this case, we need to insert the code before the yield.
    assert(loopInfo.userCodeBlock->back().getNumResults() == 0);
    rewriter.setInsertionPoint(&loopInfo.userCodeBlock->back());
  }

  if (llvm::isa<scf::WhileOp>(loopInfo.loop)) {
    exitWhileLoop(rewriter, loc, reduc);
  } else {
    exitForLoop(rewriter, loc, reduc);
  }

  assert(loopStack.size() == loopSeqStack.size());
  loopStack.pop_back();
}

//===----------------------------------------------------------------------===//
// Loop generation utils
//===----------------------------------------------------------------------===//

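// The block arguments of the generated scf.while consist of the cursor values
// of every sparse iterator and the user-provided reduction values (the latter
// placed first when `userReducFirst` is set, as `sparse_tensor.coiterate`
// requires), followed by an optional universal index.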
std::pair<Operation *, Value> sparse_tensor::genCoIteration(
    OpBuilder &builder, Location loc, ArrayRef<SparseIterator *> spIters,
    MutableArrayRef<Value> reduc, Value uniIdx, bool userReducFirst) {
  // NOTE: the slice-driven tensor-related reduction variable must
  // appear before normal tensors.

  // The set of induction variables for the while loop.
  SmallVector<Value> ivs;

  // TODO: remove the flag after full migration. Currently the
  // `sparse_tensor.coiterate` operation (must) put user-provided reduction
  // values at the front of the block argument list, while direct
  // sparsification to scf loops puts them at the end.
  if (userReducFirst)
    ivs.append(reduc.begin(), reduc.end());

  // Construct the while-loop with a parameter for each coordinate.
  for (SparseIterator *it : spIters) {
    ValueRange itVals = it->getCursor();
    ivs.append(itVals.begin(), itVals.end());
  }

  if (!userReducFirst)
    ivs.append(reduc.begin(), reduc.end());

  // Update universal index.
  if (uniIdx)
    ivs.push_back(uniIdx);

  // Ensures all operands are valid.
  assert(llvm::all_of(ivs, [](Value v) { return v != nullptr; }));
  TypeRange types = ValueRange(ivs).getTypes();
  auto whileOp = builder.create<scf::WhileOp>(loc, types, ivs);

  SmallVector<Location> locs(types.size(), loc);
  Block *before = builder.createBlock(&whileOp.getBefore(), {}, types, locs);
  Block *after = builder.createBlock(&whileOp.getAfter(), {}, types, locs);

  // Generates loop conditions.
  builder.setInsertionPointToStart(before);
  ValueRange bArgs = before->getArguments();
  Value whileCond = nullptr; // bool values for loop condition.

  for (SparseIterator *it : spIters) {
    auto [cond, remArgs] = it->genWhileCond(builder, loc, bArgs);
    whileCond = !whileCond ? cond : ANDI(whileCond, cond);
    bArgs = remArgs;
  }
  // The remaining block arguments are user-provided reduction values and an
  // optional universal index. Make sure their sizes match.
  assert(bArgs.size() == reduc.size() + (uniIdx ? 1 : 0));
  builder.create<scf::ConditionOp>(loc, whileCond, before->getArguments());

  // Generates loop body.
  builder.setInsertionPointToStart(after);
  ValueRange aArgs = after->getArguments();

  for (SparseIterator *it : spIters) {
    aArgs = it->linkNewScope(aArgs);
    // Dereference the iterator to cache the coordinate.
    it->deref(builder, loc);
  }

  // In-place update on reduction variable.
  for (unsigned i = 0, e = reduc.size(); i < e; i++)
    reduc[i] = aArgs[i];

  Value min;
  // Finds the minimum coordinate.
  if (!uniIdx) {
    for (SparseIterator *it : spIters) {
      if (min) {
        Value cmp = CMPI(ult, it->getCrd(), min);
        min = SELECT(cmp, it->getCrd(), min);
      } else {
        min = it->getCrd();
      }
    }
  } else {
    // Otherwise, the universal index is the minimum position.
    min = whileOp.getAfterArguments().back();
  }

  return {whileOp, min};
}

#undef CMPI
#undef C_IDX
#undef YIELD
#undef ADDI
#undef ANDI
#undef SUBI
#undef MULI
#undef REMUI
#undef DIVUI
#undef SELECT
