//===- LoopEmitter.cpp ----------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "LoopEmitter.h"
#include "CodegenUtils.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"

using namespace mlir;
using namespace mlir::sparse_tensor;

//===----------------------------------------------------------------------===//
// File local shorthand macros
//===----------------------------------------------------------------------===//

#define CMPI(p, l, r)                                                          \
  (builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::p, (l), (r))       \
       .getResult())

#define C_IDX(v) (constantIndex(builder, loc, (v)))
#define YIELD(vs) (builder.create<scf::YieldOp>(loc, (vs)))
#define ADDI(lhs, rhs) (builder.create<arith::AddIOp>(loc, (lhs), (rhs)))
#define ANDI(lhs, rhs) (builder.create<arith::AndIOp>(loc, (lhs), (rhs)))
#define SUBI(lhs, rhs) (builder.create<arith::SubIOp>(loc, (lhs), (rhs)))
#define MULI(lhs, rhs) (builder.create<arith::MulIOp>(loc, (lhs), (rhs)))
#define REMUI(lhs, rhs) (builder.create<arith::RemUIOp>(loc, (lhs), (rhs)))
#define DIVUI(lhs, rhs) (builder.create<arith::DivUIOp>(loc, (lhs), (rhs)))
#define SELECT(c, l, r) (builder.create<arith::SelectOp>(loc, (c), (l), (r)))
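
// All shorthand macros above assume that an OpBuilder named `builder` and a
// Location named `loc` are in scope at the expansion site, e.g.,
//   Value two = ADDI(C_IDX(1), C_IDX(1));
// creates an arith.addi of two index constants at the current insertion point.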

//===----------------------------------------------------------------------===//
// Debugging utils
//===----------------------------------------------------------------------===//

#ifndef NDEBUG
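// Dumps the content of an index-typed memref for debugging: the memref is
// cast to an unranked memref and handed to the "printMemrefInd" runtime
// helper through a C-interface call.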
LLVM_ATTRIBUTE_UNUSED static void dumpIndexMemRef(OpBuilder &builder,
                                                  Location loc, Value memref) {
  memref = builder.create<memref::CastOp>(
      loc, UnrankedMemRefType::get(builder.getIndexType(), 0), memref);
  createFuncCall(builder, loc, "printMemrefInd", TypeRange{},
                 ValueRange{memref}, EmitCInterface::On);
}
#endif

//===----------------------------------------------------------------------===//
// File local helper functions.
//===----------------------------------------------------------------------===//

// For index reduction loops, since the tensors are sliced into non-contiguous
// fragments, we need a triple [pLo, pHi, pPtr], in which the pair (pLo, pHi)
// specifies the range of the fragment, and pPtr specifies the index of the
// corresponding fragment in the child level (i.e., a pointer to the sliced
// position array).
static Value genSliceOffset(OpBuilder &builder, Location loc, Value tensor,
                            Level lvl) {
  auto enc = getSparseTensorEncoding(tensor.getType());
  return createOrFoldSliceOffsetOp(builder, loc, tensor, toDim(enc, lvl));
}

static Value genSliceStride(OpBuilder &builder, Location loc, Value tensor,
                            Level lvl) {
  auto enc = getSparseTensorEncoding(tensor.getType());
  return createOrFoldSliceStrideOp(builder, loc, tensor, toDim(enc, lvl));
}

//===----------------------------------------------------------------------===//
// Sparse tensor loop emitter class implementations
//===----------------------------------------------------------------------===//

LoopEmitter::LoopEmitter(ValueRange tensors, StringAttr loopTag, bool hasOutput,
                         bool isSparseOut, unsigned numLoops,
                         DependentLvlGetter dimGetter,
                         SparseEmitStrategy emitStrategy) {
  initialize(tensors, loopTag, hasOutput, isSparseOut, numLoops, dimGetter,
             emitStrategy);
}

void LoopEmitter::initialize(ValueRange ts, StringAttr loopTag, bool hasOutput,
                             bool isSparseOut, unsigned numLoops,
                             DependentLvlGetter dimGetter,
                             SparseEmitStrategy emitStrategy) {
  // First initialize the top-level type of the fields.
  this->loopTag = loopTag;
  this->hasOutput = hasOutput;
  this->isSparseOut = isSparseOut;
  this->emitStrategy = emitStrategy;

  const unsigned numManifestTensors = ts.size();
  const unsigned synTensorId = numManifestTensors;
  const unsigned numTensors = numManifestTensors + 1;
  // Tensors array (len == numManifestTensors).
  this->tensors.assign(ts.begin(), ts.end());
  // Arrays with len == numTensors.
  this->valBuffer.assign(numTensors, nullptr);
  this->lvls.resize(numTensors);
  this->iters.resize(numTensors);

  // These zeros will be overwritten below, but we need to initialize
  // them to something since we'll need random-access assignment.
  this->loopStack.reserve(numLoops);
  this->loopSeqStack.reserve(numLoops);

  // Index-reduction related fields.
  this->dependentLvlMap.assign(
      numTensors, std::vector<std::vector<std::pair<TensorLevel, unsigned>>>());
  this->sliceMeta.assign(
      numTensors, std::vector<std::vector<std::pair<Value, unsigned>>>());
  this->levelReducedDep.assign(numTensors, std::vector<unsigned>());

  // Initialize nested types of `TensorId`-indexed fields.
  for (TensorId tid = 0; tid < numTensors; tid++) {
    Level lvlRank;
    if (tid == synTensorId) {
      // Synthetic tensor (conceptually) is an all-dense tensor with rank equal
      // to the total number of loops (each level can potentially be mapped to
      // one of the loops being generated).
      lvlRank = numLoops;
    } else {
      const Value t = tensors[tid];
      // A scalar or 0-dimension tensor.
      if (isZeroRankedTensorOrScalar(t.getType()))
        continue;

      auto rtp = getRankedTensorType(t);
      const SparseTensorType stt(rtp);
      lvlRank = stt.getLvlRank();
    }

    lvls[tid].resize(lvlRank);
    iters[tid].resize(lvlRank);
    loopHighs.assign(numLoops, nullptr);

    // Slice-driven loops related initialization.
    levelReducedDep[tid].assign(lvlRank, 0);
    dependentLvlMap[tid].assign(
        lvlRank, std::vector<std::pair<TensorLevel, unsigned>>());
    sliceMeta[tid].assign(lvlRank, std::vector<std::pair<Value, unsigned>>());
    if (dimGetter && !isSynTensor(tid)) {
      for (Level l = 0; l < lvlRank; l++) {
        std::vector<std::pair<LoopId, unsigned>> deps = dimGetter(tid, l);
        // Sort the loops by loop order.
        std::sort(deps.begin(), deps.end(),
                  [](auto &lhs, auto &rhs) { return lhs.first < rhs.first; });

        dependentLvlMap[tid][l] = std::move(deps);
        unsigned depends = dependentLvlMap[tid][l].size();
        if (depends == 0)
          continue;
        sliceMeta[tid][l].reserve(depends);
      }
    }
  }
}

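// Creates the iterator over a single storage level of `tensors[t]`. For a
// sliced tensor (an encoding with a slice), the plain level iterator is
// additionally wrapped so that it applies the slice offset and stride.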
std::unique_ptr<SparseIterator>
LoopEmitter::makeLevelIterator(OpBuilder &builder, Location loc, TensorId t,
                               Level l) {
  auto it = makeSimpleIterator(*lvls[t][l], emitStrategy);
  auto stt = getSparseTensorType(tensors[t]);
  if (stt.hasEncoding() && stt.getEncoding().isSlice()) {
    Value offset = genSliceOffset(builder, loc, tensors[t], l);
    Value stride = genSliceStride(builder, loc, tensors[t], l);
    auto slicedIt = makeSlicedLevelIterator(
        std::move(it), offset, stride, lvls[t][l]->getSize(), emitStrategy);
    return slicedIt;
  }
  return it;
}

void LoopEmitter::initializeLoopEmit(
    OpBuilder &builder, Location loc, LoopEmitter::OutputUpdater updater,
    LoopEmitter::SynTensorBoundSetter synSetter) {
  // For the synthetic tensor, set the high bound of each level by calling the
  // callback.
  if (synSetter) {
    TensorId synId = getSynTensorId();
    for (unsigned i = 0, e = loopHighs.size(); i < e; i++) {
      Value sz = loopHighs[i] = synSetter(builder, loc, i);
      auto [stl, it] = makeSynLevelAndIterator(sz, synId, i, emitStrategy);
      lvls[synId][i] = std::move(stl);
      iters[synId][i].emplace_back(std::move(it));
    }
  }

  // For every manifest tensor:
  // * get the values buffer.
  // * For every level:
  //   * get the positions and coordinates buffers
  //   * get/compute the level-size, which is also used as the upper-bound
  //     on positions.
  for (TensorId t = 0, numTensors = getNumManifestTensors(); t < numTensors;
       t++) {
    const Value tensor = tensors[t];
    const auto rtp = dyn_cast<RankedTensorType>(tensor.getType());
    if (!rtp)
      // Skips only scalars; zero-ranked tensors still need to be bufferized
      // and (probably) filled with zeros by users.
      continue;
    // FIXME: the definition of `lvlRank` looks more like a dim-rank;
    // but the variable is used as a level everywhere below, which
    // suggests there may be some dim/lvl confusion going on here.
    auto stt = getSparseTensorType(tensor);
    const Level lvlRank = stt.getLvlRank();
    const auto shape = rtp.getShape();

    SmallVector<Value> lvlSzs;
    for (Level l = 0; l < stt.getLvlRank(); l++) {
      if (stt.hasEncoding())
        lvlSzs.push_back(builder.create<LvlOp>(loc, tensor, l));
      else
        lvlSzs.push_back(builder.create<tensor::DimOp>(loc, tensor, l));
    }

    // Scan all levels of current tensor.
    for (Level l = 0; l < lvlRank; l++) {
      // Find upper bound in current dimension.
      lvls[t][l] = makeSparseTensorLevel(builder, loc, tensor, t, l);
      if (!dependentLvlMap[t][l].empty())
        continue;

      auto it = makeLevelIterator(builder, loc, t, l);
      iters[t][l].emplace_back(std::move(it));
    }

234
235 // Perform the required bufferization. Dense inputs materialize
236 // from the input tensors. Sparse inputs use sparse primitives to obtain the
237 // values.
238 // Delegates extra output initialization to clients.
239 bool isOutput = isOutputTensor(tid: t);
240 Type elementType = stt.getElementType();
241 if (!stt.hasEncoding()) {
242 // Non-annotated dense tensors.
243 BaseMemRefType denseTp = MemRefType::get(shape, elementType);
244
245 // TODO: if we unconditionally use fully dynamic layout here, it breaks
246 // some vectorization passes which requires static stride = 1.
247 // Is it possible to call vectorization pass after bufferization?
248 if (llvm::isa_and_nonnull<tensor::ExtractSliceOp>(Val: tensor.getDefiningOp()))
249 denseTp = bufferization::getMemRefTypeWithFullyDynamicLayout(tensorType: rtp);
250
251 Value denseVal =
252 builder.create<bufferization::ToMemrefOp>(loc, denseTp, tensor);
253 // Dense outputs need special handling.
254 if (isOutput && updater)
255 denseVal = updater(builder, loc, denseVal, tensor);
256
257 valBuffer[t] = denseVal;
258 } else {
259 // Annotated sparse tensors.
260 // We also need the value buffer for all-dense annotated "sparse"
261 // tensors.
262 valBuffer[t] = builder.create<ToValuesOp>(loc, tensor);
263 }
264 // NOTE: we can also prepare for 0 lvl here in advance, this will hoist
265 // some loop preparation from tensor iteration, but will also (undesirably)
266 // hoist the code ouside if-conditions.
267 }
268 // TODO: avoid treating subsection iterator as a special case.
269 initSubSectIterator(builder, loc);
270}
271
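// Creates the subsection iterators used by slice-driven (index-reduction)
// loops: for every level with dependent loops, the dependencies are reduced
// in loop order, building a non-empty-subsection iterator while dependencies
// remain unresolved and a subsection-traversal iterator for the last one.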
void LoopEmitter::initSubSectIterator(OpBuilder &builder, Location loc) {
  Value c0 = C_IDX(0);
  for (TensorId t = 0, e = tensors.size(); t < e; t++) {
    auto rtp = dyn_cast<RankedTensorType>(tensors[t].getType());
    if (!rtp)
      continue;

    Level lvlRank = SparseTensorType(rtp).getLvlRank();

    // Compute the dependency reduction order.
    auto remDepStack = dependentLvlMap;
    std::vector<std::tuple<LoopId, TensorId, Level>> depRedOrder;
    for (Level lvl = 0; lvl < lvlRank; lvl++) {
      // Reverse queue into a stack.
      std::reverse(remDepStack[t][lvl].begin(), remDepStack[t][lvl].end());
      for (auto [loop, coeff] : dependentLvlMap[t][lvl])
        depRedOrder.emplace_back(std::make_tuple(loop, t, lvl));
    }

    if (depRedOrder.empty())
      continue;

    std::sort(depRedOrder.begin(), depRedOrder.end(),
              [](auto &l, auto &r) { return std::get<0>(l) < std::get<0>(r); });

    SmallVector<SparseIterator *> lastIter(tensors.size(), nullptr);
    for (auto [loop, t, lvl] : depRedOrder) {
      std::pair<LoopId, unsigned> curDep = remDepStack[t][lvl].back();
      assert(curDep.first == loop);
      remDepStack[t][lvl].pop_back();

      auto lvlIt = makeLevelIterator(builder, loc, t, lvl);
      const SparseIterator *parent = lastIter[t];
      if (!parent && lvl > 0) {
        if (dependentLvlMap[t][lvl - 1].empty()) {
          parent = iters[t][lvl - 1].back().get();
        }
      }

      std::unique_ptr<SparseIterator> it;
      if (!remDepStack[t][lvl].empty()) {
        // Compute the subsection size.
        Value size = c0;
        for (auto [loop, stride] : remDepStack[t][lvl]) {
          Value idxMax = SUBI(loopHighs[loop], C_IDX(1));
          size = ADDI(size, ADDI(MULI(idxMax, C_IDX(stride)), C_IDX(1)));
        }
        it = makeNonEmptySubSectIterator(builder, loc, parent, loopHighs[loop],
                                         std::move(lvlIt), size, curDep.second,
                                         emitStrategy);
      } else {
        const SparseIterator &subSectIter = *iters[t][lvl].back();
        it = makeTraverseSubSectIterator(builder, loc, subSectIter, *parent,
                                         std::move(lvlIt), loopHighs[loop],
                                         curDep.second, emitStrategy);
      }
      lastIter[t] = it.get();
      iters[t][lvl].emplace_back(std::move(it));
    }
  }
}

void LoopEmitter::categorizeIterators(
    ArrayRef<TensorLevel> tidLvls, SmallVectorImpl<SparseIterator *> &raIters,
    SmallVectorImpl<SparseIterator *> &spIters) {
  // Finds out the tensor levels that we should use to generate loops. Among
  // all the tensor levels, there is at most one sparse tensor level.
  for (auto [t, l] : unpackTensorLevelRange(tidLvls)) {
    SparseIterator *it = &getCurIterator(t, l);
    if (it->randomAccessible())
      raIters.push_back(it);
    else
      spIters.push_back(it);
  }

  std::stable_sort(spIters.begin(), spIters.end(), [](auto lhs, auto rhs) {
    // AffineUnRed > Affine > Slice > Trivial
    return static_cast<uint8_t>(lhs->kind) > static_cast<uint8_t>(rhs->kind);
  });
}

void LoopEmitter::enterNewLoopSeq(OpBuilder &builder, Location loc,
                                  ArrayRef<TensorLevel> tidLvls) {
  // TODO: sort
  assert(loopSeqStack.size() == loopStack.size());

  // Prepares for all the tensors used in the current loop sequence.
  for (auto [tid, lvl] : unpackTensorLevelRange(tidLvls)) {
    levelReducedDep[tid][lvl]++;
    prepareLoopOverTensorAtLvl(builder, loc, tid, lvl);
  }

  // Universal Index starts from 0.
  loopSeqStack.emplace_back(C_IDX(0), tidLvls.vec());
}

void LoopEmitter::exitCurrentLoopSeq(OpBuilder &builder, Location loc) {
  assert(loopSeqStack.size() == loopStack.size() + 1);

  // Depending on whether the slice is resolved or not at the current loop
  // sequence, end them in different ways.
  for (auto [tid, lvl] : unpackTensorLevelRange(loopSeqStack.back().second))
    levelReducedDep[tid][lvl]--;

  loopSeqStack.pop_back();
}

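// Generates code that computes the value of the given affine expression in
// terms of the loop induction variables: a dim expression maps to the IV of
// the corresponding loop, add/mul are emitted as arith ops, and constants
// become index constants.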
Value LoopEmitter::genAffine(OpBuilder &builder, Location loc, AffineExpr a) {
  switch (a.getKind()) {
  case AffineExprKind::DimId: {
    // FIXME: since the one callsite in Sparsification passes in a
    // level-expression, the `getPosition` must in fact be a `Dimension`.
    // However, elsewhere we have been led to expect that `loopIdToOrd`
    // should be indexed by `LoopId`...
    const auto loopId = cast<AffineDimExpr>(a).getPosition();
    return loopStack[loopId].iv;
  }
  case AffineExprKind::Add: {
    auto binOp = cast<AffineBinaryOpExpr>(a);
    return ADDI(genAffine(builder, loc, binOp.getLHS()),
                genAffine(builder, loc, binOp.getRHS()));
  }
  case AffineExprKind::Mul: {
    auto binOp = cast<AffineBinaryOpExpr>(a);
    return MULI(genAffine(builder, loc, binOp.getLHS()),
                genAffine(builder, loc, binOp.getRHS()));
  }
  case AffineExprKind::Constant: {
    int64_t c = cast<AffineConstantExpr>(a).getValue();
    return C_IDX(c);
  }
  default:
    llvm_unreachable("unexpected affine subscript");
  }
}

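// Emits a single scf.for (or scf.parallel, when `isParallel` is set) loop
// that drives the given iterator, updates `reduc` in place to the loop-carried
// reduction variables, and returns the loop op together with the coordinate
// for the current iteration.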
std::pair<Operation *, Value> LoopEmitter::emitForLoopOverTensorAtLvl(
    OpBuilder &builder, Location loc, SparseIterator &iter,
    MutableArrayRef<Value> reduc, bool isParallel) {

  // TODO: support dynamic slices.
  // Uses the first dimension here to build the loop bound (which is also the
  // biggest range).

  Value step = C_IDX(1);
  auto [lo, hi] = iter.genForCond(builder, loc);
  Operation *loop = nullptr;
  Value iv;
  if (isParallel) {
    scf::ParallelOp parOp =
        builder.create<scf::ParallelOp>(loc, lo, hi, step, reduc);
    builder.setInsertionPointToStart(parOp.getBody());
    assert(parOp.getNumReductions() == reduc.size());
    iv = parOp.getInductionVars()[0];

    // In-place update on the reduction variable vector.
    // Note that the init vals are not the actual reduction variables but are
    // instead used as a "special handle" to (temporarily) represent them. The
    // expression on the init vals will be moved into scf.reduce and replaced
    // with the block arguments when exiting the loop (see exitForLoop). This
    // is needed because we cannot build the actual reduction block and get
    // the actual reduction variable before users fill the parallel loop body.
    for (int i = 0, e = reduc.size(); i < e; i++)
      reduc[i] = parOp.getInitVals()[i];
    loop = parOp;
  } else {
    scf::ForOp forOp = builder.create<scf::ForOp>(loc, lo, hi, step, reduc);
    builder.setInsertionPointToStart(forOp.getBody());
    iv = forOp.getInductionVar();

    // In-place update on the reduction variable vector.
    assert(forOp.getNumRegionIterArgs() == reduc.size());
    for (int i = 0, e = reduc.size(); i < e; i++)
      reduc[i] = forOp.getRegionIterArg(i);
    loop = forOp;
  }
  assert(loop && iv);

  Value crd = iv;
  if (!iter.randomAccessible()) {
    iter.linkNewScope(iv);
    crd = iter.deref(builder, loc);
  } else {
    iter.locate(builder, loc, iv);
  }

  return {loop, crd};
}

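// Emits an scf.while loop that co-iterates over all the given sparse
// iterators; the loop carries every iterator cursor, the user reduction
// values, and (optionally) the universal index. Returns the while op and the
// minimum coordinate across the iterators for this iteration.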
std::pair<Operation *, Value> LoopEmitter::emitWhileLoopOverTensorsAtLvls(
    OpBuilder &builder, Location loc, ArrayRef<SparseIterator *> spIters,
    MutableArrayRef<Value> reduc, bool needsUniv) {
  // NOTE: the slice-driven tensor-related reduction variables must
  // appear before normal tensors.

  // The set of induction variables for the while loop.
  SmallVector<Value> ivs;

  // Construct the while-loop with a parameter for each coordinate.
  for (SparseIterator *it : spIters) {
    ValueRange itVals = it->getCursor();
    ivs.append(itVals.begin(), itVals.end());
  }

  // The position where the user-supplied reduction variables start.
  ivs.append(reduc.begin(), reduc.end());
  // Update universal index.
  if (needsUniv)
    ivs.push_back(loopSeqStack.back().first);

  // Ensures all operands are valid.
  assert(llvm::all_of(ivs, [](Value v) { return v != nullptr; }));
  TypeRange types = ValueRange(ivs).getTypes();
  auto whileOp = builder.create<scf::WhileOp>(loc, types, ivs);

  SmallVector<Location> locs(types.size(), loc);
  Block *before = builder.createBlock(&whileOp.getBefore(), {}, types, locs);
  Block *after = builder.createBlock(&whileOp.getAfter(), {}, types, locs);

  // Generates loop conditions.
  builder.setInsertionPointToStart(before);
  ValueRange bArgs = before->getArguments();
  Value whileCond = nullptr; // bool values for loop condition.

  for (SparseIterator *it : spIters) {
    auto [cond, remArgs] = it->genWhileCond(builder, loc, bArgs);
    whileCond = !whileCond ? cond : ANDI(whileCond, cond);
    bArgs = remArgs;
  }
  // The remaining block arguments are user-provided reduction values and an
  // optional universal index. Make sure their sizes match.
  assert(bArgs.size() == reduc.size() + (needsUniv ? 1 : 0));
  builder.create<scf::ConditionOp>(loc, whileCond, before->getArguments());

  // Generates loop body.
  builder.setInsertionPointToStart(after);
  ValueRange aArgs = after->getArguments();
  // Since some LoopCondKind might need extra checks to filter out invalid
  // iterations, we maintain another array to hold the iteration arguments to
  // yield if the checks fail.
  SmallVector<Value> nextArgs(aArgs.begin(), aArgs.end());

  for (SparseIterator *it : spIters) {
    aArgs = it->linkNewScope(aArgs);
    // Dereference the iterator to cache the coordinate.
    it->deref(builder, loc);
  }

  // In-place update on reduction variable.
  assert(aArgs.size() == reduc.size() + (needsUniv ? 1 : 0));
  for (unsigned i = 0, e = reduc.size(); i < e; i++)
    reduc[i] = aArgs[i];

  Value min;
  // Finds the minimum coordinate.
  if (!needsUniv) {
    for (SparseIterator *it : spIters) {
      if (min) {
        Value cmp = CMPI(ult, it->getCrd(), min);
        min = SELECT(cmp, it->getCrd(), min);
      } else {
        min = it->getCrd();
      }
    }
  } else {
    // Otherwise, the universal index is the minimal position.
    min = whileOp.getAfterArguments().back();
  }

  return {whileOp, min};
}

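// Returns true if the loop condition can be expressed as a simple scf.for
// (i.e., no co-iteration over multiple sparse iterators is needed and the
// single sparse iterator, if any, is iterable by a for loop).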
bool LoopEmitter::shouldIteratedByForLoop(ArrayRef<SparseIterator *> spIters) {
  // If we need to co-iterate over two sparse tensors, we need a while loop.
  if (spIters.size() > 1)
    return false;

  if (spIters.size() == 1)
    return spIters.front()->iteratableByFor();

  return true;
}

Operation *LoopEmitter::enterCoIterationOverTensorsAtLvls(
    OpBuilder &builder, Location loc, ArrayRef<TensorLevel> tidLvls,
    MutableArrayRef<Value> reduc, bool tryParallel, bool needsUniv) {

  // TODO: support multiple return on parallel for?
  tryParallel = tryParallel && reduc.size() <= 1;

  SmallVector<SparseIterator *> raIters;
  SmallVector<SparseIterator *> spIters;
  categorizeIterators(tidLvls, raIters, spIters);

  // Only when there is at least one sparse condition do we really need the
  // universal index.
  // TODO: Maybe we should instead require the merger to pass in a valid value
  // in the first place instead of adjusting it in LoopEmitter?
  needsUniv = !spIters.empty() && needsUniv;
  // The TensorLevel used for loop conditions.
  // If there is any sparse level, we need to use the sparse condition.
  // If all levels are dense, we can pick an arbitrary one (a dense
  // slice-driven loop can be generated using a simple ForOp as well).
  Operation *l = nullptr;
  Value iv = nullptr;
  SmallVector<TensorLevel> tls;

  // Generates loops differently depending on whether we need a slice-driven
  // loop or a simple level traversal loop.
  if (shouldIteratedByForLoop(spIters) && !needsUniv) {
    assert(spIters.size() <= 1);
    SparseIterator &it = spIters.empty() ? *raIters.front() : *spIters.front();
    std::tie(l, iv) =
        emitForLoopOverTensorAtLvl(builder, loc, it, reduc, tryParallel);
    tls.push_back(makeTensorLevel(it.tid, it.lvl));
  } else {
    for (auto *it : spIters) {
      tls.push_back(makeTensorLevel(it->tid, it->lvl));
    }

    if (needsUniv)
      for (auto *it : raIters)
        tls.push_back(makeTensorLevel(it->tid, it->lvl));

    std::tie(l, iv) =
        emitWhileLoopOverTensorsAtLvls(builder, loc, spIters, reduc, needsUniv);
  }

  // Enter dense tensor levels.
  for (SparseIterator *it : raIters)
    it->locate(builder, loc, iv);

  // NOTE: we can also prepare for the next dim here in advance.
  // Pushes the loop into stack.
  loopStack.emplace_back(tls, l, builder.getInsertionBlock(), iv, loopTag);
  return l;
}

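// Positions the (randomly accessible) iterator of the given tensor level at
// the coordinate computed from the affine expression, initializing the
// iterator from its parent level first.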
void LoopEmitter::locateLvlAtAffineAddress(OpBuilder &builder, Location loc,
                                           TensorLevel tidLvl,
                                           AffineExpr lvlExpr) {
  auto [tid, lvl] = unpackTensorLevel(tidLvl);

  const SparseIterator *parent =
      lvl == 0 ? nullptr : iters[tid][lvl - 1].back().get();
  auto &it = getCurIterator(tid, lvl);
  it.genInit(builder, loc, parent);

  assert(it.kind == IterKind::kTrivial && it.randomAccessible());
  Value lvlCrd = genAffine(builder, loc, lvlExpr);
  it.locate(builder, loc, lvlCrd);
}

void LoopEmitter::prepareLoopOverTensorAtLvl(OpBuilder &builder, Location loc,
                                             TensorId tid, Level lvl) {
  // If this is the first level, there is no parent iterator for the current
  // iterator.
  // If the current iterator is a subsection-based iterator, the parent
  // iterator is memorized by the iterator itself.
  bool hasParent = lvl == 0 || !dependentLvlMap[tid][lvl].empty();

  const SparseIterator *parent =
      hasParent ? nullptr : iters[tid][lvl - 1].back().get();
  auto &it = getCurIterator(tid, lvl);
  it.genInit(builder, loc, parent);

  // Locates the random accessible iterator to 0.
  if (it.randomAccessible())
    it.locate(builder, loc, C_IDX(0));
}

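// Exits a for/parallel loop: for scf.for, yields the reduction values; for
// scf.parallel, moves the reduction expression into an scf.reduce region.
// In both cases the caller-visible reduction variables in `reduc` are updated
// in place with the loop results.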
void LoopEmitter::exitForLoop(RewriterBase &rewriter, Location loc,
                              MutableArrayRef<Value> reduc) {
  const LoopInfo &loopInfo = loopStack.back();
  if (auto forOp = llvm::dyn_cast<scf::ForOp>(loopInfo.loop)) {
    if (!reduc.empty()) {
      assert(reduc.size() == forOp.getNumResults());
      rewriter.create<scf::YieldOp>(loc, reduc);
    }
    // Exit the loop.
    rewriter.setInsertionPointAfter(forOp);
    // In-place update reduction variables.
    for (unsigned i = 0, e = forOp.getResults().size(); i < e; i++)
      reduc[i] = forOp.getResult(i);
  } else {
    auto parOp = llvm::cast<scf::ParallelOp>(loopInfo.loop);
    if (!reduc.empty()) {
      assert(reduc.size() == parOp.getInitVals().size() && reduc.size() == 1);
      Operation *redExp = reduc.front().getDefiningOp();
      // The reduction expression should have no use.
      assert(redExp->getUses().empty());
      // This must be a binary operation.
      // NOTE: it is the user's responsibility to ensure the operation is
      // commutative.
      assert(redExp->getNumOperands() == 2 && redExp->getNumResults() == 1);

      Value redVal = parOp.getInitVals().front();
      Value curVal;
      if (redExp->getOperand(0) == redVal)
        curVal = redExp->getOperand(1);
      else if (redExp->getOperand(1) == redVal)
        curVal = redExp->getOperand(0);
      // One of the operands must be the init value (which is also the
      // previous reduction value).
      assert(curVal);
#ifndef NDEBUG
      // The reduction expression should be the only user of the reduction val
      // inside the parallel for.
      unsigned numUsers = 0;
      for (Operation *op : redVal.getUsers()) {
        if (op->getParentOp() == parOp)
          numUsers++;
      }
      assert(numUsers == 1);
#endif // NDEBUG

      rewriter.setInsertionPointAfter(redExp);
      auto redOp = rewriter.create<scf::ReduceOp>(loc, curVal);
      // Attach to the reduction op.
      Block *redBlock = &redOp.getReductions().front().front();
      rewriter.setInsertionPointToEnd(redBlock);
      Operation *newRed = rewriter.clone(*redExp);
      // Replaces arguments of the reduction expression by using the block
      // arguments from scf.reduce.
      rewriter.modifyOpInPlace(
          newRed, [&]() { newRed->setOperands(redBlock->getArguments()); });
      // Erases the outdated reduction expression.
      rewriter.eraseOp(redExp);
      rewriter.setInsertionPointToEnd(redBlock);
      rewriter.create<scf::ReduceReturnOp>(loc, newRed->getResult(0));
    }
    rewriter.setInsertionPointAfter(parOp);
    // In-place update reduction variables.
    for (unsigned i = 0, e = parOp.getResults().size(); i < e; i++)
      reduc[i] = parOp.getResult(i);
  }
}

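// Exits a co-iteration while loop: forwards each sparse iterator past the
// coordinate consumed in this iteration, re-locates dense iterators using the
// universal index, and yields the updated cursors, reduction values, and
// (optional) universal index.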
void LoopEmitter::exitWhileLoop(OpBuilder &builder, Location loc,
                                MutableArrayRef<Value> reduc) {
  const LoopInfo &loopInfo = loopStack.back();
  auto whileOp = llvm::cast<scf::WhileOp>(loopInfo.loop);
  Value iv = loopInfo.iv;
  Value one = C_IDX(1);

  // Finalize the induction. Note that the induction could be performed
  // in the individual if-branches to avoid re-evaluating the conditions.
  // However, that would result in a rather elaborate forest of yield
  // instructions during code generation. Moreover, performing the induction
  // after the if-statements more closely resembles code generated by TACO.
  SmallVector<Value> operands;
  ValueRange whileRes = whileOp.getResults();

  for (auto [tid, lvl] : unpackTensorLevelRange(loopInfo.tidLvls)) {
    SparseIterator &it = getCurIterator(tid, lvl);
    if (!it.randomAccessible()) {
      // Forward the sparse iterator.
      Value cmp = CMPI(eq, it.getCrd(), iv);
      it.forwardIf(builder, loc, cmp);
      operands.append(it.getCursor().begin(), it.getCursor().end());
      // const Value newPos = whileOp->getResult(o++);
      // Following loops continue iteration from the break point of the
      // current while loop.
      whileRes = it.linkNewScope(whileRes);
    } else {
      // Make sure the randomly accessible (dense) iterator is set to the
      // right position according to the universal index.
      Value uniIdx = whileOp.getResults().back();
      it.locate(builder, loc, uniIdx);
    }
  }

  // Reduction value from users.
  for (auto &i : reduc) {
    operands.push_back(i);
    // Update user reduction variables.
    i = whileRes.front();
    whileRes = whileRes.drop_front();
  }

  // An (optional) universal index.
  if (operands.size() < whileOp.getNumResults()) {
    assert(operands.size() + 1 == whileOp.getNumResults());
    // The last one is the universal index.
    operands.push_back(ADDI(iv, one));
    // Update the loop starting point of the current loop sequence.
    loopSeqStack.back().first = whileOp->getResults().back();
  }

  if (!operands.empty())
    YIELD(operands);

  builder.setInsertionPointAfter(whileOp);
}

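// Exits the innermost loop on the loop stack (while or for), emitting the
// corresponding terminator and updating `reduc` with the loop results.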
void LoopEmitter::exitCurrentLoop(RewriterBase &rewriter, Location loc,
                                  MutableArrayRef<Value> reduc) {
  // Clean up the values; this helps us discover potential bugs at an earlier
  // stage (instead of silently using a wrong value).
  const LoopInfo &loopInfo = loopStack.back();

  // Sets the insertion point to the right position.
  rewriter.setInsertionPointToEnd(loopInfo.userCodeBlock);
  if (!loopInfo.userCodeBlock->empty() &&
      llvm::isa<scf::YieldOp>(&loopInfo.userCodeBlock->back())) {
    // scf::While/For inserts an implicit yield op when there are no loop
    // iter args. In this case, we need to insert the code before the yield.
    assert(loopInfo.userCodeBlock->back().getNumResults() == 0);
    rewriter.setInsertionPoint(&loopInfo.userCodeBlock->back());
  }

  if (llvm::isa<scf::WhileOp>(loopInfo.loop)) {
    exitWhileLoop(rewriter, loc, reduc);
  } else {
    exitForLoop(rewriter, loc, reduc);
  }

  assert(loopStack.size() == loopSeqStack.size());
  loopStack.pop_back();
}

#undef CMPI
#undef C_IDX
#undef YIELD
#undef ADDI
#undef ANDI
#undef SUBI
#undef MULI
#undef REMUI
#undef DIVUI
#undef SELECT