//===- LoopEmitter.cpp ----------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

9#include "LoopEmitter.h"
10#include "CodegenUtils.h"
11
12#include "mlir/Dialect/Arith/IR/Arith.h"
13#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
14#include "mlir/Dialect/Linalg/Utils/Utils.h"
15#include "mlir/Dialect/MemRef/IR/MemRef.h"
16#include "mlir/Dialect/SCF/IR/SCF.h"
17#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
18#include "mlir/Dialect/Tensor/IR/Tensor.h"
19
20using namespace mlir;
21using namespace mlir::sparse_tensor;
22
23//===----------------------------------------------------------------------===//
24// File local shorthand macros
25//===----------------------------------------------------------------------===//
26
27#define CMPI(p, l, r) \
28 (builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::p, (l), (r)) \
29 .getResult())
30
31#define C_IDX(v) (constantIndex(builder, loc, (v)))
32#define YIELD(vs) (builder.create<scf::YieldOp>(loc, (vs)))
33#define ADDI(lhs, rhs) (builder.create<arith::AddIOp>(loc, (lhs), (rhs)))
34#define ANDI(lhs, rhs) (builder.create<arith::AndIOp>(loc, (lhs), (rhs)))
35#define SUBI(lhs, rhs) (builder.create<arith::SubIOp>(loc, (lhs), (rhs)))
36#define MULI(lhs, rhs) (builder.create<arith::MulIOp>(loc, (lhs), (rhs)))
37#define REMUI(lhs, rhs) (builder.create<arith::RemUIOp>(loc, (lhs), (rhs)))
38#define DIVUI(lhs, rhs) (builder.create<arith::DivUIOp>(loc, (lhs), (rhs)))
39#define SELECT(c, l, r) (builder.create<arith::SelectOp>(loc, (c), (l), (r)))
40
41//===----------------------------------------------------------------------===//
42// Debugging utils
43//===----------------------------------------------------------------------===//
44
45#ifndef NDEBUG
// Debug helper: prints the contents of an index-typed memref to the runtime's
// `printMemrefInd` entry point. Debug builds only (guarded by NDEBUG).
LLVM_ATTRIBUTE_UNUSED static void dumpIndexMemRef(OpBuilder &builder,
                                                  Location loc, Value memref) {
  // Cast to an unranked memref so one runtime entry point handles any rank.
  memref = builder.create<memref::CastOp>(
      loc, UnrankedMemRefType::get(builder.getIndexType(), 0), memref);
  createFuncCall(builder, loc, "printMemrefInd", TypeRange{},
                 ValueRange{memref}, EmitCInterface::On);
}
53#endif
54
55//===----------------------------------------------------------------------===//
56// File local helper functions.
57//===----------------------------------------------------------------------===//
58
59// For index reduction loops, since the tensor are sliced into non-continuous
60// fragments, we need a triple [pLo, pHi, pPtr], in which the pair (pLo, pHi)
61// specifies the range of the fragment, and pPtr specifies the index of the
62// corresponding fragment in the child level (i.e., a pointer to the sliced
63// position array).
64static Value genSliceOffset(OpBuilder &builder, Location loc, Value tensor,
65 Level lvl) {
66 auto enc = getSparseTensorEncoding(type: tensor.getType());
67 return createOrFoldSliceOffsetOp(builder, loc, tensor, dim: toDim(enc, l: lvl));
68}
69
70static Value genSliceStride(OpBuilder &builder, Location loc, Value tensor,
71 Level lvl) {
72 auto enc = getSparseTensorEncoding(type: tensor.getType());
73 return createOrFoldSliceStrideOp(builder, loc, tensor, dim: toDim(enc, l: lvl));
74}
75
76static bool isIntOrFPZero(Attribute attr) {
77 if (auto f = llvm::dyn_cast<FloatAttr>(Val&: attr); f && f.getValue().isZero())
78 return true;
79 if (auto i = llvm::dyn_cast<IntegerAttr>(Val&: attr); i && i.getValue().isZero())
80 return true;
81 return false;
82}
83
84static Value unFoldOpIntResult(OpBuilder &builder, Location loc,
85 OpFoldResult ofr) {
86 if (std::optional<int64_t> i = getConstantIntValue(ofr); i.has_value())
87 return constantIndex(builder, loc, i: *i);
88 return cast<Value>(Val&: ofr);
89}
90
// Folds away a `tensor.pad` whose padding value is a constant zero, returning
// the pad's source instead of `t`. Only applies when both source and result
// carry the same identity-mapped sparse encoding; otherwise `t` is returned
// unchanged.
static Value tryFoldTensors(Value t) {
  // TODO: this should be done through a folding pass after switching to
  // `sparse_tensor.iterate`-based sparsification.
  auto stt = tryGetSparseTensorType(val: t);
  auto padOp = t.getDefiningOp<tensor::PadOp>();
  if (padOp && stt.has_value() && stt->hasEncoding() &&
      padOp.getSourceType().getEncoding() == stt->getEncoding() &&
      stt->getEncoding().isIdentity()) {
    // Try fusing padOp with zeros: the pad body must be a lone yield of a
    // constant that is an int/fp zero.
    Attribute padCst;
    if (matchPattern(op: padOp.getBody()->getTerminator(),
                     pattern: m_Op<tensor::YieldOp>(matchers: m_Constant(bind_value: &padCst))) &&
        isIntOrFPZero(attr: padCst)) {
      return padOp.getSource();
    }
  }
  return t;
}
109
110//===----------------------------------------------------------------------===//
111// Sparse tensor loop emitter class implementations
112//===----------------------------------------------------------------------===//
113
114LoopEmitter::LoopEmitter(ValueRange tensors, StringAttr loopTag, bool hasOutput,
115 bool isSparseOut, unsigned numLoops,
116 DependentLvlGetter dimGetter,
117 SparseEmitStrategy emitStrategy) {
118 initialize(tensors, loopTag, hasOutput, isSparseOut, numLoops, getter: dimGetter);
119}
120
// Initializes all per-tensor and per-level bookkeeping for the emitter.
// One extra "synthetic" tensor is appended after the manifest tensors to
// model loop dimensions that are not bound to any real tensor level.
void LoopEmitter::initialize(ValueRange ts, StringAttr loopTag, bool hasOutput,
                             bool isSparseOut, unsigned numLoops,
                             DependentLvlGetter dimGetter,
                             SparseEmitStrategy emitStrategy) {
  // First initialize the top-level type of the fields.
  this->loopTag = loopTag;
  this->hasOutput = hasOutput;
  this->isSparseOut = isSparseOut;
  this->emitStrategy = emitStrategy;

  // The synthetic tensor is stored at index `numManifestTensors`.
  const unsigned numManifestTensors = ts.size();
  const unsigned synTensorId = numManifestTensors;
  const unsigned numTensors = numManifestTensors + 1;
  // tensors array (len == numManifestTensor).
  this->tensors.assign(first: ts.begin(), last: ts.end());
  // Arrays with len == numTensor.
  this->valBuffer.assign(n: numTensors, val: nullptr);
  this->lvls.resize(new_size: numTensors);
  this->iters.resize(new_size: numTensors);
  this->spIterVals.resize(new_size: numTensors);

  // These zeros will be overwritten below, but we need to initialize
  // them to something since we'll need random-access assignment.
  this->loopStack.reserve(n: numLoops);
  this->loopSeqStack.reserve(n: numLoops);

  // Index-reduction related fields.
  this->dependentLvlMap.assign(
      n: numTensors, val: std::vector<std::vector<std::pair<TensorLevel, unsigned>>>());
  this->sliceMeta.assign(
      n: numTensors, val: std::vector<std::vector<std::pair<Value, unsigned>>>());
  this->levelReducedDep.assign(n: numTensors, val: std::vector<unsigned>());

  // Initialize nested types of `TensorId`-indexed fields.
  for (TensorId tid = 0; tid < numTensors; tid++) {
    Level lvlRank;
    if (tid == synTensorId) {
      // Synthetic tensor (conceptually) is an all-dense tensor with rank equal
      // to the total number of loops (each level can potentially be mapped to
      // one of the loop being generated).
      lvlRank = numLoops;
    } else {
      const Value t = tensors[tid];
      // a scalar or 0-dimension tensors
      if (isZeroRankedTensorOrScalar(type: t.getType()))
        continue;

      auto rtp = getRankedTensorType(t);
      const SparseTensorType stt(rtp);
      lvlRank = stt.getLvlRank();
    }

    lvls[tid].resize(new_size: lvlRank);
    iters[tid].resize(new_size: lvlRank);
    spIterVals[tid].resize(new_size: lvlRank);
    // NOTE(review): re-assigned identically on every iteration of the tid
    // loop; looks like it could be hoisted before the loop — confirm.
    loopHighs.assign(n: numLoops, val: nullptr);

    // Slice-driven loops related initialization.
    levelReducedDep[tid].assign(n: lvlRank, val: 0);
    dependentLvlMap[tid].assign(
        n: lvlRank, val: std::vector<std::pair<TensorLevel, unsigned>>());
    sliceMeta[tid].assign(n: lvlRank, val: std::vector<std::pair<Value, unsigned>>());
    if (dimGetter && !isSynTensor(tid)) {
      for (Level l = 0; l < lvlRank; l++) {
        std::vector<std::pair<LoopId, unsigned>> deps = dimGetter(tid, l);
        // Sort the loop by order.
        llvm::sort(C&: deps, Comp: llvm::less_first());

        dependentLvlMap[tid][l] = std::move(deps);
        unsigned depends = dependentLvlMap[tid][l].size();
        if (depends == 0)
          continue;
        sliceMeta[tid][l].reserve(n: depends);
      }
    }
  }
}
198
// Creates the iterator for level `l` of tensor `t`: a plain level iterator by
// default, wrapped into a padded iterator when the tensor is a foldable
// zero-pad on this level, or into a sliced iterator when the encoding is a
// slice.
std::unique_ptr<SparseIterator>
LoopEmitter::makeLevelIterator(OpBuilder &builder, Location loc, TensorId t,
                               Level l) {
  Value tensor = tensors[t];
  auto stt = getSparseTensorType(val: tensor);
  auto it = makeSimpleIterator(stl: *lvls[t][l], strategy: emitStrategy);

  // A successful fold means `tensor` is a zero-pad of `folded`.
  Value folded = tryFoldTensors(t: tensor);
  if (folded != tensor) {
    auto padOp = tensor.getDefiningOp<tensor::PadOp>();
    assert(padOp);
    if (padOp.getPaddedDims().test(Idx: l)) {
      // Wrap the plain iterator with the low/high padding amounts.
      Value low = unFoldOpIntResult(builder, loc, ofr: padOp.getMixedLowPad()[l]);
      Value high = unFoldOpIntResult(builder, loc, ofr: padOp.getMixedHighPad()[l]);
      auto padIt = makePaddedIterator(sit: std::move(it), padLow: low, padHigh: high, strategy: emitStrategy);
      return padIt;
    }
  }

  // Slice encodings get an offset/stride-adjusted iterator.
  if (stt.hasEncoding() && stt.getEncoding().isSlice()) {
    Value offset = genSliceOffset(builder, loc, tensor, lvl: l);
    Value stride = genSliceStride(builder, loc, tensor, lvl: l);
    auto slicedIt = makeSlicedLevelIterator(
        sit: std::move(it), offset, stride, size: lvls[t][l]->getSize(), strategy: emitStrategy);
    return slicedIt;
  }

  return it;
}
228
// Materializes value buffers for every manifest tensor, sets up synthetic
// loop bounds via `synSetter`, and builds the per-level iterators. Skipped
// (after bufferization) entirely under the sparse-iterator emit strategy,
// where iterator values only exist once loops are built.
void LoopEmitter::initializeLoopEmit(
    OpBuilder &builder, Location loc, LoopEmitter::OutputUpdater updater,
    LoopEmitter::SynTensorBoundSetter synSetter) {

  // For every manifest tensor, set up the values buffer.
  for (TensorId t = 0, numTensors = getNumManifestTensors(); t < numTensors;
       t++) {
    // TODO: this should be done through a folding pass after switching to
    // `sparse_tensor.iterate`-based sparsification.
    const Value tensor = tryFoldTensors(t: tensors[t]);
    const auto rtp = dyn_cast<RankedTensorType>(Val: tensor.getType());
    // Skips only scalar, zero ranked tensor still need to be bufferized and
    // (probably) filled with zeros by users.
    if (!rtp)
      continue;

    auto stt = getSparseTensorType(val: tensor);
    const auto shape = rtp.getShape();

    // Perform the required bufferization. Dense inputs materialize from the
    // input tensors. Sparse inputs use sparse primitives to obtain the values.
    // Delegates extra output initialization to clients.
    bool isOutput = isOutputTensor(tid: t);
    Type elementType = stt.getElementType();
    if (!stt.hasEncoding()) {
      // Non-annotated dense tensors.
      BaseMemRefType denseTp = MemRefType::get(shape, elementType);

      // TODO: if we unconditionally use fully dynamic layout here, it breaks
      // some vectorization passes which requires static stride = 1.
      // Is it possible to call vectorization pass after bufferization?
      if (llvm::isa_and_nonnull<tensor::ExtractSliceOp>(Val: tensor.getDefiningOp()))
        denseTp = bufferization::getMemRefTypeWithFullyDynamicLayout(tensorType: rtp);

      Value denseVal =
          builder.create<bufferization::ToBufferOp>(location: loc, args&: denseTp, args: tensor);
      // Dense outputs need special handling.
      if (isOutput && updater)
        denseVal = updater(builder, loc, denseVal, tensor);

      valBuffer[t] = denseVal;
    } else {
      // Annotated sparse tensors.
      // We also need the value buffer for all-dense annotated "sparse"
      // tensors.
      valBuffer[t] = builder.create<ToValuesOp>(location: loc, args: tensor);
    }
  }

  // The sparse iterator values will only be available after the loop is
  // constructed.
  if (emitStrategy == SparseEmitStrategy::kSparseIterator)
    return;

  // For every synthetic tensor, set the high bound by calling the callback.
  if (synSetter) {
    TensorId synId = getSynTensorId();
    for (unsigned i = 0, e = loopHighs.size(); i < e; i++) {
      Value sz = loopHighs[i] = synSetter(builder, loc, i);
      auto [stl, it] = makeSynLevelAndIterator(sz, tid: synId, lvl: i, strategy: emitStrategy);
      lvls[synId][i] = std::move(stl);
      iters[synId][i].emplace_back(args: std::move(it));
    }
  }

  // For every manifest tensor:
  // * For every level:
  // * get the positions and coordinates buffers
  // * get/compute the level-size, which is also used as the upper-bound
  // on positions.
  for (TensorId t = 0, numTensors = getNumManifestTensors(); t < numTensors;
       t++) {
    // TODO: this should be done through a folding pass after switching to
    // `sparse_tensor.iterate`-based sparsification.
    const Value tensor = tryFoldTensors(t: tensors[t]);
    const auto rtp = dyn_cast<RankedTensorType>(Val: tensor.getType());
    if (!rtp)
      // Skips only scalar, zero ranked tensor still need to be bufferized and
      // (probably) filled with zeros by users.
      continue;

    auto stt = getSparseTensorType(val: tensor);
    const Level lvlRank = stt.getLvlRank();

    // Scan all levels of current tensor.
    for (Level l = 0; l < lvlRank; l++) {
      // Find upper bound in current dimension.
      lvls[t][l] = makeSparseTensorLevel(b&: builder, l: loc, t: tensor, tid: t, lvl: l);
      // Levels with loop dependencies are handled by initSubSectIterator.
      if (!dependentLvlMap[t][l].empty())
        continue;

      auto it = makeLevelIterator(builder, loc, t, l);
      iters[t][l].emplace_back(args: std::move(it));
    }
    // NOTE: we can also prepare for 0 lvl here in advance, this will hoist
    // some loop preparation from tensor iteration, but will also (undesirably)
    // hoist the code outside if-conditions.
  }
  // TODO: avoid treating subsection iterator as a special case.
  initSubSectIterator(builder, loc);
}
330
// Builds subsection iterators for levels whose coordinates depend on more
// than one loop (index-reduction). Dependencies are processed in loop order;
// each unresolved dependency produces a non-empty subsection iterator, the
// last one a traversal iterator over the previously built subsection.
void LoopEmitter::initSubSectIterator(OpBuilder &builder, Location loc) {
  Value c0 = C_IDX(0);
  for (TensorId t = 0, e = tensors.size(); t < e; t++) {
    auto rtp = dyn_cast<RankedTensorType>(Val: tensors[t].getType());
    if (!rtp)
      continue;

    Level lvlRank = SparseTensorType(rtp).getLvlRank();

    // Compute the dependency reduction order.
    auto remDepStack = dependentLvlMap;
    std::vector<std::tuple<LoopId, TensorId, Level>> depRedOrder;
    for (Level lvl = 0; lvl < lvlRank; lvl++) {
      // Reverse queue into a stack.
      std::reverse(first: remDepStack[t][lvl].begin(), last: remDepStack[t][lvl].end());
      for (auto [loop, coeff] : dependentLvlMap[t][lvl])
        depRedOrder.emplace_back(args: std::make_tuple(args&: loop, args&: t, args&: lvl));
    }

    if (depRedOrder.empty())
      continue;

    // Process dependencies in ascending loop order.
    llvm::sort(C&: depRedOrder, Comp: llvm::less_first());

    SmallVector<SparseIterator *> lastIter(tensors.size(), nullptr);
    for (auto [loop, t, lvl] : depRedOrder) {
      std::pair<LoopId, unsigned> curDep = remDepStack[t][lvl].back();
      assert(curDep.first == loop);
      remDepStack[t][lvl].pop_back();

      auto lvlIt = makeLevelIterator(builder, loc, t, l: lvl);
      const SparseIterator *parent = lastIter[t];
      // Fall back to the previous level's iterator when no subsection
      // iterator was built for this tensor yet.
      if (!parent && lvl > 0) {
        if (dependentLvlMap[t][lvl - 1].empty()) {
          parent = iters[t][lvl - 1].back().get();
        }
      }

      std::unique_ptr<SparseIterator> it;
      if (!remDepStack[t][lvl].empty()) {
        // Compute the subsection size: sum over remaining deps of
        // (loopHigh - 1) * stride + 1.
        Value size = c0;
        for (auto [loop, stride] : remDepStack[t][lvl]) {
          Value idxMax = SUBI(loopHighs[loop], C_IDX(1));
          size = ADDI(size, ADDI(MULI(idxMax, C_IDX(stride)), C_IDX(1)));
        }
        it = makeNonEmptySubSectIterator(b&: builder, l: loc, parent, loopBound: loopHighs[loop],
                                         delegate: std::move(lvlIt), size, stride: curDep.second,
                                         strategy: emitStrategy);
      } else {
        // Last dependency: traverse the subsection built above.
        const SparseIterator &subSectIter = *iters[t][lvl].back();
        it = makeTraverseSubSectIterator(b&: builder, l: loc, subsectIter: subSectIter, parent: *parent,
                                         wrap: std::move(lvlIt), loopBound: loopHighs[loop],
                                         stride: curDep.second, strategy: emitStrategy);
      }
      lastIter[t] = it.get();
      iters[t][lvl].emplace_back(args: std::move(it));
    }
  }
}
391
392void LoopEmitter::categorizeIterators(
393 ArrayRef<TensorLevel> tidLvls, SmallVectorImpl<SparseIterator *> &raIters,
394 SmallVectorImpl<SparseIterator *> &spIters) {
395 // Finds out the tensor level that we should use to generate loops. Amongs all
396 // the tensor levels, there is at most one sparse tensor level.
397 for (auto [t, l] : unpackTensorLevelRange(c&: tidLvls)) {
398 SparseIterator *it = &getCurIterator(tid: t, lvl: l);
399 if (it->randomAccessible())
400 raIters.push_back(Elt: it);
401 else
402 spIters.push_back(Elt: it);
403 }
404
405 llvm::stable_sort(Range&: spIters, C: [](auto lhs, auto rhs) {
406 // AffineUnRed > Affine > Slice > Trivial
407 return static_cast<uint8_t>(lhs->kind) > static_cast<uint8_t>(rhs->kind);
408 });
409}
410
// Starts a new loop sequence over `tidLvls`: bumps each level's dependency
// counter, prepares its iterator, and pushes a new record (universal index
// starting at 0) onto `loopSeqStack`.
void LoopEmitter::enterNewLoopSeq(OpBuilder &builder, Location loc,
                                  ArrayRef<TensorLevel> tidLvls) {
  // TODO: sort
  assert(loopSeqStack.size() == loopStack.size());

  // Iterator preparation is deferred under the sparse-iterator strategy.
  if (emitStrategy != SparseEmitStrategy::kSparseIterator) {
    // Prepares for all the tensors used in the current loop sequence.
    for (auto [tid, lvl] : unpackTensorLevelRange(c&: tidLvls)) {
      levelReducedDep[tid][lvl]++;
      prepareLoopOverTensorAtLvl(builder, loc, tid, lvl);
    }
  }

  // Universal Index starts from 0.
  loopSeqStack.emplace_back(C_IDX(0), args: tidLvls.vec());
}
427
428void LoopEmitter::exitCurrentLoopSeq(OpBuilder &builder, Location loc) {
429 assert(loopSeqStack.size() == loopStack.size() + 1);
430
431 // Depending on whether the slice is resolved or not at current loop sequence,
432 // end them in different ways.
433 for (auto [tid, lvl] : unpackTensorLevelRange(c&: loopSeqStack.back().second))
434 levelReducedDep[tid][lvl]--;
435
436 loopSeqStack.pop_back();
437}
438
// Lowers an affine expression to arithmetic on loop induction variables.
// Only DimId, Add, Mul, and Constant expressions are expected here; anything
// else is a programming error.
Value LoopEmitter::genAffine(OpBuilder &builder, Location loc, AffineExpr a) {
  switch (a.getKind()) {
  case AffineExprKind::DimId: {
    // FIXME: since the one callsite in Sparsification passes in a
    // level-expression, the `getPosition` must in fact be a `Dimension`.
    // However, elsewhere we have been lead to expect that `loopIdToOrd`
    // should be indexed by `LoopId`...
    const auto loopId = cast<AffineDimExpr>(Val&: a).getPosition();
    return loopStack[loopId].iv;
  }
  case AffineExprKind::Add: {
    // Recurse on both operands.
    auto binOp = cast<AffineBinaryOpExpr>(Val&: a);
    return ADDI(genAffine(builder, loc, binOp.getLHS()),
                genAffine(builder, loc, binOp.getRHS()));
  }
  case AffineExprKind::Mul: {
    auto binOp = cast<AffineBinaryOpExpr>(Val&: a);
    return MULI(genAffine(builder, loc, binOp.getLHS()),
                genAffine(builder, loc, binOp.getRHS()));
  }
  case AffineExprKind::Constant: {
    int64_t c = cast<AffineConstantExpr>(Val&: a).getValue();
    return C_IDX(c);
  }
  default:
    llvm_unreachable("unexpected affine subscript");
  }
}
467
// Emits a single for-style loop (scf.for or scf.parallel) over `iter`.
// Returns the loop op and the coordinate value for the loop body. `reduc` is
// updated in place to the loop-carried reduction variables.
std::pair<Operation *, Value> LoopEmitter::emitForLoopOverTensorAtLvl(
    OpBuilder &builder, Location loc, SparseIterator &iter,
    MutableArrayRef<Value> reduc, bool isParallel) {

  // TODO: support dynamic slices.
  // Uses the first dimension here to build the loop bound (which is also the
  // biggest range).

  Value step = C_IDX(1);
  auto [lo, hi] = iter.genForCond(b&: builder, l: loc);
  Operation *loop = nullptr;
  Value iv;
  if (isParallel) {
    scf::ParallelOp parOp =
        builder.create<scf::ParallelOp>(location: loc, args&: lo, args&: hi, args&: step, args&: reduc);
    builder.setInsertionPointToStart(parOp.getBody());
    assert(parOp.getNumReductions() == reduc.size());
    iv = parOp.getInductionVars()[0];

    // In-place update on the reduction variable vector.
    // Note that the init vals is not the actual reduction variables but instead
    // used as a "special handle" to (temporarily) represent them. The
    // expression on init vals will be moved into scf.reduce and replaced with
    // the block arguments when exiting the loop (see exitForLoop). This is
    // needed as we can not build the actual reduction block and get the actual
    // reduction variable before users fill parallel loop body.
    for (int i = 0, e = reduc.size(); i < e; i++)
      reduc[i] = parOp.getInitVals()[i];
    loop = parOp;
  } else {
    scf::ForOp forOp = builder.create<scf::ForOp>(location: loc, args&: lo, args&: hi, args&: step, args&: reduc);
    builder.setInsertionPointToStart(forOp.getBody());
    iv = forOp.getInductionVar();

    // In-place update on the reduction variable vector.
    assert(forOp.getNumRegionIterArgs() == reduc.size());
    for (int i = 0, e = reduc.size(); i < e; i++)
      reduc[i] = forOp.getRegionIterArg(index: i);
    loop = forOp;
  }
  assert(loop && iv);

  // For sparse iterators the induction variable is a position; dereference it
  // to obtain the coordinate. Random-accessible iterators use `iv` directly.
  Value crd = iv;
  if (!iter.randomAccessible()) {
    iter.linkNewScope(pos: iv);
    crd = iter.deref(b&: builder, l: loc);
  } else {
    iter.locate(b&: builder, l: loc, crd: iv);
  }

  return {loop, crd};
}
520
// Emits a co-iterating while loop over the given sparse iterators; threads
// the universal index of the current loop sequence through when `needsUniv`.
std::pair<Operation *, Value> LoopEmitter::emitWhileLoopOverTensorsAtLvls(
    OpBuilder &builder, Location loc, ArrayRef<SparseIterator *> spIters,
    MutableArrayRef<Value> reduc, bool needsUniv) {
  return genCoIteration(builder, loc, iters: spIters, reduc,
                        uniIdx: needsUniv ? loopSeqStack.back().first : nullptr);
}
527
528bool LoopEmitter::shouldIteratedByForLoop(ArrayRef<SparseIterator *> spIters) {
529 // If we need to co-iterate over two sparse tensors, we need a while loop
530 if (spIters.size() > 1)
531 return false;
532
533 if (spIters.size() == 1)
534 return spIters.front()->iteratableByFor();
535
536 return true;
537}
538
// Initializes case region `caseIdx` of the innermost CoIterateOp for the
// iterator subset `caseBit`: records the case attribute, creates the block
// (iter args, used coordinates, live iterators), and updates the SSA
// bookkeeping (`reduc`, loop iv, `spIterVals`) for code emitted inside it.
Region *LoopEmitter::enterCurrentCoIterationCase(OpBuilder &builder,
                                                 Location loc,
                                                 I64BitSet caseBit,
                                                 unsigned caseIdx,
                                                 MutableArrayRef<Value> reduc) {
  auto coIterOp = cast<CoIterateOp>(Val: loopStack.back().loop);
  SmallVector<Attribute> cases(coIterOp.getCases().getAsRange<Attribute>());
  cases[caseIdx] = builder.getI64IntegerAttr(value: caseBit);

  coIterOp.setCasesAttr(builder.getArrayAttr(value: cases));
  Region &caseRegion = coIterOp.getRegion(i: caseIdx);
  assert(caseRegion.getBlocks().empty() &&
         "re-initialize the same coiteration case region.");

  // Each block starts with by a list of user-provided iteration arguments.
  TypeRange iterArgsTps = coIterOp.getInitArgs().getTypes();
  // Followed by a list of used coordinates of index type.
  SmallVector<Type> blockArgTps(coIterOp.getCrdUsedLvls().count(),
                                builder.getIndexType());

  blockArgTps.append(in_start: iterArgsTps.begin(), in_end: iterArgsTps.end());
  // Ends with a set of iterators that defines the actually iteration space.
  for (auto i : caseBit.bits()) {
    blockArgTps.push_back(
        Elt: cast<IterSpaceType>(Val: coIterOp.getIterSpaces()[i].getType())
            .getIteratorType());
  }
  SmallVector<Location> locs(blockArgTps.size(), loc);
  caseRegion.emplaceBlock().addArguments(types: blockArgTps, locs);

  // Entering the new region scope, updating the SSA chain.
  builder.setInsertionPointToStart(&caseRegion.front());
  // Update the coordinates.
  loopStack.back().iv = coIterOp.getCrds(regionIdx: caseIdx).front();
  // Updates loop iteration arguments.
  ValueRange iterArgs = coIterOp.getRegionIterArgs(regionIdx: caseIdx);
  llvm::copy(Range&: iterArgs, Out: reduc.begin());
  // Updates sparse iterator values: live iterators (in `caseBit`) consume the
  // region's iterator arguments in order; dead ones are cleared.
  ValueRange iters = coIterOp.getRegionIterators(regionIdx: caseIdx);
  ArrayRef<TensorLevel> tidLvls = loopStack.back().tidLvls;
  for (auto [i, tl] : llvm::enumerate(First: unpackTensorLevelRange(c&: tidLvls))) {
    if (caseBit[i]) {
      spIterVals[tl.first][tl.second] = iters.front();
      iters = iters.drop_front();
    } else {
      spIterVals[tl.first][tl.second] = nullptr;
    }
  }
  // Must have consumed all iterator SSA values.
  assert(iters.empty());
  return &caseRegion;
}
591
// Emits the loop that (co-)iterates over the given tensor levels. Under the
// sparse-iterator strategy it builds sparse_tensor.iterate/coiterate ops;
// otherwise it emits a scf.for (single condition) or scf.while
// (co-iteration), pushing the loop info onto `loopStack`.
Operation *LoopEmitter::enterCoIterationOverTensorsAtLvls(
    OpBuilder &builder, Location loc, ArrayRef<TensorLevel> tidLvls,
    unsigned numCases, MutableArrayRef<Value> reduc, bool tryParallel,
    bool needsUniv) {
  // TODO: Argument `numCases` only used when generating iterator-based sparse
  // loops. Simplify the code upon feature complete.
  // TODO: handle coiteration with sparse iterator.
  if (emitStrategy == SparseEmitStrategy::kSparseIterator) {
    if (tidLvls.size() == 1) {
      auto [tid, lvl] = unpackTensorLevel(tidLvl: tidLvls.front());
      Value t = tensors[tid];

      // Extract and iterate over the iteration space.
      ExtractIterSpaceOp extractSpaceOp =
          lvl == 0 ? builder.create<ExtractIterSpaceOp>(location: loc, args&: t)
                   : builder.create<ExtractIterSpaceOp>(
                         location: loc, args&: t, args&: spIterVals[tid][lvl - 1], args&: lvl);

      IterateOp iterOp = builder.create<IterateOp>(
          location: loc, args: extractSpaceOp.getExtractedSpace(), args&: reduc);
      spIterVals[tid][lvl] = iterOp.getIterator();

      // Update the reduction variables.
      llvm::copy(Range: iterOp.getRegionIterArgs(), Out: reduc.begin());
      // Set the insertion point to loop body.
      builder.setInsertionPointToStart(iterOp.getBody());
      loopStack.emplace_back(args&: tidLvls, args&: iterOp, args: builder.getInsertionBlock(),
                             args&: iterOp.getCrds().front(), args&: loopTag);
      return iterOp;
    }

    // CoIteration Loops.
    SmallVector<Value> spaces;
    for (auto [tid, lvl] : unpackTensorLevelRange(c&: tidLvls)) {
      Value t = tensors[tid];
      ExtractIterSpaceOp extractSpaceOp =
          lvl == 0 ? builder.create<ExtractIterSpaceOp>(location: loc, args&: t)
                   : builder.create<ExtractIterSpaceOp>(
                         location: loc, args&: t, args&: spIterVals[tid][lvl - 1], args&: lvl);
      spaces.push_back(Elt: extractSpaceOp.getExtractedSpace());
    }
    auto coIterOp = builder.create<CoIterateOp>(location: loc, args&: spaces, args&: reduc, args&: numCases);
    // The CoIterationOp does not have insertion block nor induction variable.
    // TODO: the `struct LoopInfo` should be simplified after full migration.
    loopStack.emplace_back(args&: tidLvls, args&: coIterOp, /*insertion block*/ args: nullptr,
                           /*induction variable*/ args: nullptr, args&: loopTag);
    return coIterOp;
  }

  // TODO: support multiple return on parallel for?
  tryParallel = tryParallel && reduc.size() <= 1;

  SmallVector<SparseIterator *> raIters;
  SmallVector<SparseIterator *> spIters;
  categorizeIterators(tidLvls, raIters, spIters);

  // Only when there is at least one sparse conditions, do we really need the
  // universal index.
  // TODO: Maybe we should instead requires merger to pass in a valid value at
  // the first place instead of adjusting it in LoopEmitter?
  needsUniv = !spIters.empty() && needsUniv;
  // The TensorLevel used for loop conditions.
  // If there is any sparse level, we need to use the sparse condition.
  // If all levels are dense, we can pick arbitrary one (dense slice-driven loop
  // can be generated using a simple ForOp as well).
  Operation *l = nullptr;
  Value iv = nullptr;
  SmallVector<TensorLevel> tls;

  // Generates loops differently depending on whether we need a slice-driven
  // loop or a simple level traversal loop.
  if (shouldIteratedByForLoop(spIters) && !needsUniv) {
    assert(spIters.size() <= 1);
    SparseIterator &it = spIters.empty() ? *raIters.front() : *spIters.front();
    std::tie(args&: l, args&: iv) =
        emitForLoopOverTensorAtLvl(builder, loc, iter&: it, reduc, isParallel: tryParallel);
    tls.push_back(Elt: makeTensorLevel(t: it.tid, l: it.lvl));
  } else {
    for (auto *it : spIters) {
      tls.push_back(Elt: makeTensorLevel(t: it->tid, l: it->lvl));
    }

    if (needsUniv)
      for (auto *it : raIters)
        tls.push_back(Elt: makeTensorLevel(t: it->tid, l: it->lvl));

    std::tie(args&: l, args&: iv) =
        emitWhileLoopOverTensorsAtLvls(builder, loc, spIters, reduc, needsUniv);
  }

  // Enter dense tensor levels.
  for (SparseIterator *it : raIters)
    it->locate(b&: builder, l: loc, crd: iv);

  // NOTE: we can also prepare for next dim here in advance
  // Pushes the loop into stack.
  loopStack.emplace_back(args&: tls, args&: l, args: builder.getInsertionBlock(), args&: iv, args&: loopTag);
  return l;
}
691
// Positions the (random-accessible) iterator of `tidLvl` at the coordinate
// computed from the affine expression `lvlExpr`.
void LoopEmitter::locateLvlAtAffineAddress(OpBuilder &builder, Location loc,
                                           TensorLevel tidLvl,
                                           AffineExpr lvlExpr) {
  auto [tid, lvl] = unpackTensorLevel(tidLvl);

  // Level 0 has no parent; deeper levels chain to the most recently built
  // iterator of the level above.
  const SparseIterator *parent =
      lvl == 0 ? nullptr : iters[tid][lvl - 1].back().get();
  auto &it = getCurIterator(tid, lvl);
  it.genInit(b&: builder, l: loc, p: parent);

  // Affine addressing only makes sense on trivial random-access levels.
  assert(it.kind == IterKind::kTrivial && it.randomAccessible());
  Value lvlCrd = genAffine(builder, loc, a: lvlExpr);
  it.locate(b&: builder, l: loc, crd: lvlCrd);
}
706
// Initializes the current iterator for (tid, lvl) before a loop is entered;
// random-accessible iterators are additionally positioned at coordinate 0.
void LoopEmitter::prepareLoopOverTensorAtLvl(OpBuilder &builder, Location loc,
                                             TensorId tid, Level lvl) {
  // if this is the first level, there is no parent iterator for the current
  // iterator.
  // If the current iterator is a subsection-based iterator, the parent iterator
  // is memorized by the iterator.
  bool hasParent = lvl == 0 || !dependentLvlMap[tid][lvl].empty();

  const SparseIterator *parent =
      hasParent ? nullptr : iters[tid][lvl - 1].back().get();
  auto &it = getCurIterator(tid, lvl);
  it.genInit(b&: builder, l: loc, p: parent);

  // Locates the random accessible iterator to 0.
  if (it.randomAccessible())
    it.locate(b&: builder, l: loc, C_IDX(0));
}
724
// Closes the innermost for-style loop (iterate, scf.for, or scf.parallel):
// yields the reduction values, moves the insertion point past the loop, and
// rewrites `reduc` in place to the loop's results. For scf.parallel the
// placeholder reduction expression is relocated into an scf.reduce block.
void LoopEmitter::exitForLoop(RewriterBase &rewriter, Location loc,
                              MutableArrayRef<Value> reduc) {
  const LoopInfo &loopInfo = loopStack.back();
  if (emitStrategy == SparseEmitStrategy::kSparseIterator) {
    auto iterateOp = llvm::cast<IterateOp>(Val: loopInfo.loop);
    assert(reduc.size() == iterateOp.getNumResults());
    rewriter.create<sparse_tensor::YieldOp>(location: loc, args&: reduc);
    // Exit the loop.
    rewriter.setInsertionPointAfter(iterateOp);
    // In-place update reduction variables.
    llvm::copy(Range: iterateOp.getResults(), Out: reduc.begin());
    return;
  }
  if (auto forOp = llvm::dyn_cast<scf::ForOp>(Val: loopInfo.loop)) {
    if (!reduc.empty()) {
      assert(reduc.size() == forOp.getNumResults());
      rewriter.create<scf::YieldOp>(location: loc, args&: reduc);
    }
    // Exit the loop.
    rewriter.setInsertionPointAfter(forOp);
    // In-place update reduction variables.
    llvm::copy(Range: forOp.getResults(), Out: reduc.begin());
  } else {
    auto parOp = llvm::cast<scf::ParallelOp>(Val: loopInfo.loop);
    if (!reduc.empty()) {
      assert(reduc.size() == parOp.getInitVals().size() && reduc.size() == 1);
      Operation *redExp = reduc.front().getDefiningOp();
      // Reduction expression should have no use.
      assert(redExp->getUses().empty());
      // This must be a binary operation.
      // NOTE: This is users' responsibility to ensure the operation are
      // commutative.
      assert(redExp->getNumOperands() == 2 && redExp->getNumResults() == 1);

      Value redVal = parOp.getInitVals().front();
      Value curVal;
      if (redExp->getOperand(idx: 0) == redVal)
        curVal = redExp->getOperand(idx: 1);
      else if (redExp->getOperand(idx: 1) == redVal)
        curVal = redExp->getOperand(idx: 0);
      // One of the operands must be the init value (which is also the
      // previous reduction value).
      assert(curVal);
#ifndef NDEBUG
      // The reduction expression should be the only user of the reduction val
      // inside the parallel for.
      unsigned numUsers = 0;
      for (Operation *op : redVal.getUsers()) {
        if (op->getParentOp() == parOp)
          numUsers++;
      }
      assert(numUsers == 1);
#endif // NDEBUG

      rewriter.setInsertionPointAfter(redExp);
      auto redOp = rewriter.create<scf::ReduceOp>(location: loc, args&: curVal);
      // Attach to the reduction op.
      Block *redBlock = &redOp.getReductions().front().front();
      rewriter.setInsertionPointToEnd(redBlock);
      Operation *newRed = rewriter.clone(op&: *redExp);
      // Replaces arguments of the reduction expression by using the block
      // arguments from scf.reduce.
      rewriter.modifyOpInPlace(
          root: newRed, callable: [&]() { newRed->setOperands(redBlock->getArguments()); });
      // Erases the out-dated reduction expression.
      rewriter.eraseOp(op: redExp);
      rewriter.setInsertionPointToEnd(redBlock);
      rewriter.create<scf::ReduceReturnOp>(location: loc, args: newRed->getResult(idx: 0));
    }
    rewriter.setInsertionPointAfter(parOp);
    // In-place update reduction variables.
    for (unsigned i = 0, e = parOp.getResults().size(); i < e; i++)
      reduc[i] = parOp.getResult(i);
  }
}
800
801void LoopEmitter::exitWhileLoop(OpBuilder &builder, Location loc,
802 MutableArrayRef<Value> reduc) {
803 const LoopInfo &loopInfo = loopStack.back();
804 auto whileOp = llvm::cast<scf::WhileOp>(Val: loopInfo.loop);
805 Value iv = loopInfo.iv;
806 Value one = C_IDX(1);
807
808 // Finalize the induction. Note that the induction could be performed
809 // in the individual if-branches to avoid re-evaluating the conditions.
810 // However, that would result in a rather elaborate forest of yield
811 // instructions during code generation. Moreover, performing the induction
812 // after the if-statements more closely resembles code generated by TACO.
813 SmallVector<Value> operands;
814 ValueRange whileRes = whileOp.getResults();
815
816 for (auto [tid, lvl] : unpackTensorLevelRange(c: loopInfo.tidLvls)) {
817 SparseIterator &it = getCurIterator(tid, lvl);
818 if (!it.randomAccessible()) {
819 // Forward the sparse iterator.
820 Value cmp = CMPI(eq, it.getCrd(), iv);
821 it.forwardIf(b&: builder, l: loc, cond: cmp);
822 operands.append(in_start: it.getCursor().begin(), in_end: it.getCursor().end());
823 // const Value newPos = whileOp->getResult(o++);
824 // Following loops continue iteration from the break point of the
825 // current while loop.
826 whileRes = it.linkNewScope(pos: whileRes);
827 } else {
828 // Make sure randomly accessible (dense) iterator is set to the right
829 // position according to the universal index.
830 Value uniIdx = whileOp.getResults().back();
831 it.locate(b&: builder, l: loc, crd: uniIdx);
832 }
833 }
834
835 // Reduction value from users.
836 for (auto &i : reduc) {
837 operands.push_back(Elt: i);
838 // Update user reduction variables.
839 i = whileRes.front();
840 whileRes = whileRes.drop_front();
841 }
842
843 // An (optional) universal index.
844 if (operands.size() < whileOp.getNumResults()) {
845 assert(operands.size() + 1 == whileOp.getNumResults());
846 // The last one is the universial index.
847 operands.push_back(ADDI(iv, one));
848 // update the loop starting point of current loop sequence
849 loopSeqStack.back().first = whileOp->getResults().back();
850 }
851
852 if (!operands.empty())
853 YIELD(operands);
854
855 builder.setInsertionPointAfter(whileOp);
856}
857
858void LoopEmitter::exitCurrentLoop(RewriterBase &rewriter, Location loc,
859 MutableArrayRef<Value> reduc) {
860 // Clean up the values, it would help use to discover potential bug at a
861 // earlier stage (instead of silently using a wrong value).
862 const LoopInfo &loopInfo = loopStack.back();
863 if (emitStrategy == SparseEmitStrategy::kSparseIterator) {
864 Operation *p = loopInfo.loop;
865 if (isa<IterateOp>(Val: p))
866 rewriter.create<sparse_tensor::YieldOp>(location: loc, args&: reduc);
867
868 // Exit the loop.
869 rewriter.setInsertionPointAfter(p);
870 // In-place update reduction variables.
871 llvm::copy(Range: p->getResults(), Out: reduc.begin());
872 loopStack.pop_back();
873 return;
874 }
875
876 // Sets the insertion point to the right position.
877 rewriter.setInsertionPointToEnd(loopInfo.userCodeBlock);
878 if (!loopInfo.userCodeBlock->empty() &&
879 llvm::isa<scf::YieldOp>(Val: &loopInfo.userCodeBlock->back())) {
880 // scf::While/For inserts an implicit yield op when there is no loop
881 // iter args. In this case, we need to insert the code before the yield.
882 assert(loopInfo.userCodeBlock->back().getNumResults() == 0);
883 rewriter.setInsertionPoint(&loopInfo.userCodeBlock->back());
884 }
885
886 if (llvm::isa<scf::WhileOp>(Val: loopInfo.loop)) {
887 exitWhileLoop(builder&: rewriter, loc, reduc);
888 } else {
889 exitForLoop(rewriter, loc, reduc);
890 }
891
892 assert(loopStack.size() == loopSeqStack.size());
893 loopStack.pop_back();
894}
895
896//===----------------------------------------------------------------------===//
897// Loop generation utils
898//===----------------------------------------------------------------------===//
899
900std::pair<Operation *, Value> sparse_tensor::genCoIteration(
901 OpBuilder &builder, Location loc, ArrayRef<SparseIterator *> spIters,
902 MutableArrayRef<Value> reduc, Value uniIdx, bool userReducFirst) {
903 // NOTE: the slice driven tensor-related reduction variable must
904 // appear before normal tensors.
905
906 // The set of induction variables for the while loop.
907 SmallVector<Value> ivs;
908
909 // TODO: remove the flag after full migration. Currently
910 // `sparse_tensor.coiterate` operation (must) put user provided reduction
911 // values at the front of the block list, while direct sparsification to scf
912 // loops put them at the end.
913 if (userReducFirst)
914 ivs.append(in_start: reduc.begin(), in_end: reduc.end());
915
916 // Construct the while-loop with a parameter for each coordinate.
917 for (SparseIterator *it : spIters) {
918 ValueRange itVals = it->getCursor();
919 ivs.append(in_start: itVals.begin(), in_end: itVals.end());
920 }
921
922 if (!userReducFirst)
923 ivs.append(in_start: reduc.begin(), in_end: reduc.end());
924
925 // Update universal index.
926 if (uniIdx)
927 ivs.push_back(Elt: uniIdx);
928
929 // Ensures all operands are valid.
930 assert(!llvm::is_contained(ivs, nullptr));
931 TypeRange types = ValueRange(ivs).getTypes();
932 auto whileOp = builder.create<scf::WhileOp>(location: loc, args&: types, args&: ivs);
933
934 SmallVector<Location> locs(types.size(), loc);
935 Block *before = builder.createBlock(parent: &whileOp.getBefore(), insertPt: {}, argTypes: types, locs);
936 Block *after = builder.createBlock(parent: &whileOp.getAfter(), insertPt: {}, argTypes: types, locs);
937
938 // Generates loop conditions.
939 builder.setInsertionPointToStart(before);
940 ValueRange bArgs = before->getArguments();
941 Value whileCond = nullptr; // bool values for loop condition.
942
943 for (SparseIterator *it : spIters) {
944 auto [cond, remArgs] = it->genWhileCond(b&: builder, l: loc, vs: bArgs);
945 whileCond = !whileCond ? cond : ANDI(whileCond, cond);
946 bArgs = remArgs;
947 }
948 // The remaining block arguments are user-provided reduction values and an
949 // optional universal index. Make sure their sizes match.
950 assert(bArgs.size() == reduc.size() + (uniIdx ? 1 : 0));
951 builder.create<scf::ConditionOp>(location: loc, args&: whileCond, args: before->getArguments());
952
953 // Generates loop body.
954 builder.setInsertionPointToStart(after);
955 ValueRange aArgs = after->getArguments();
956
957 for (SparseIterator *it : spIters) {
958 aArgs = it->linkNewScope(pos: aArgs);
959 // Dereference the iterator to cache the coordinate.
960 it->deref(b&: builder, l: loc);
961 }
962
963 // In-place update on reduction variable.
964 for (unsigned i = 0, e = reduc.size(); i < e; i++)
965 reduc[i] = aArgs[i];
966
967 Value min;
968 // Finds the minimum coordinate
969 if (!uniIdx) {
970 for (SparseIterator *it : spIters) {
971 if (min) {
972 Value cmp = CMPI(ult, it->getCrd(), min);
973 min = SELECT(cmp, it->getCrd(), min);
974 } else {
975 min = it->getCrd();
976 }
977 }
978 } else {
979 // Otherwise, universal index is the minimal pos.
980 min = whileOp.getAfterArguments().back();
981 }
982
983 return {whileOp, min};
984}
985
// Undefine all file-local shorthand macros defined at the top of this file
// so they cannot leak into other translation units in a unity build.
// (REMUI and DIVUI are defined above as well, so undefine them too.)
#undef CMPI
#undef C_IDX
#undef YIELD
#undef ADDI
#undef ANDI
#undef SUBI
#undef MULI
#undef REMUI
#undef DIVUI
#undef SELECT
994

source code of mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp