| 1 | //===- LoopEmitter.cpp ----------------------------------------------------===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | |
| 9 | #include "LoopEmitter.h" |
| 10 | #include "CodegenUtils.h" |
| 11 | |
| 12 | #include "mlir/Dialect/Arith/IR/Arith.h" |
| 13 | #include "mlir/Dialect/Bufferization/IR/Bufferization.h" |
| 14 | #include "mlir/Dialect/Linalg/Utils/Utils.h" |
| 15 | #include "mlir/Dialect/MemRef/IR/MemRef.h" |
| 16 | #include "mlir/Dialect/SCF/IR/SCF.h" |
| 17 | #include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h" |
| 18 | #include "mlir/Dialect/Tensor/IR/Tensor.h" |
| 19 | |
| 20 | using namespace mlir; |
| 21 | using namespace mlir::sparse_tensor; |
| 22 | |
| 23 | //===----------------------------------------------------------------------===// |
| 24 | // File local shorthand macros |
| 25 | //===----------------------------------------------------------------------===// |
| 26 | |
// Shorthand macros for emitting arith/scf operations. All of them implicitly
// capture `builder` (an OpBuilder) and `loc` (a Location), which therefore
// must be in scope at every expansion site.
#define CMPI(p, l, r)                                                          \
  (builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::p, (l), (r))       \
       .getResult())

#define C_IDX(v) (constantIndex(builder, loc, (v)))
#define YIELD(vs) (builder.create<scf::YieldOp>(loc, (vs)))
#define ADDI(lhs, rhs) (builder.create<arith::AddIOp>(loc, (lhs), (rhs)))
#define ANDI(lhs, rhs) (builder.create<arith::AndIOp>(loc, (lhs), (rhs)))
#define SUBI(lhs, rhs) (builder.create<arith::SubIOp>(loc, (lhs), (rhs)))
#define MULI(lhs, rhs) (builder.create<arith::MulIOp>(loc, (lhs), (rhs)))
#define REMUI(lhs, rhs) (builder.create<arith::RemUIOp>(loc, (lhs), (rhs)))
#define DIVUI(lhs, rhs) (builder.create<arith::DivUIOp>(loc, (lhs), (rhs)))
#define SELECT(c, l, r) (builder.create<arith::SelectOp>(loc, (c), (l), (r)))
| 40 | |
| 41 | //===----------------------------------------------------------------------===// |
| 42 | // Debugging utils |
| 43 | //===----------------------------------------------------------------------===// |
| 44 | |
| 45 | #ifndef NDEBUG |
// Debug-only helper: prints the contents of an index-typed memref via the
// runtime support library.
LLVM_ATTRIBUTE_UNUSED static void dumpIndexMemRef(OpBuilder &builder,
                                                  Location loc, Value memref) {
  // Cast to an unranked memref so one runtime entry point handles any rank.
  memref = builder.create<memref::CastOp>(
      loc, UnrankedMemRefType::get(builder.getIndexType(), 0), memref);
  createFuncCall(builder, loc, "printMemrefInd", TypeRange{},
                 ValueRange{memref}, EmitCInterface::On);
}
| 53 | #endif |
| 54 | |
| 55 | //===----------------------------------------------------------------------===// |
| 56 | // File local helper functions. |
| 57 | //===----------------------------------------------------------------------===// |
| 58 | |
// For index-reduction loops, the tensor is sliced into non-contiguous
// fragments, so we need a triple [pLo, pHi, pPtr], in which the pair
// (pLo, pHi) specifies the range of the fragment, and pPtr specifies the
// index of the corresponding fragment in the child level (i.e., a pointer
// to the sliced position array).
| 64 | static Value genSliceOffset(OpBuilder &builder, Location loc, Value tensor, |
| 65 | Level lvl) { |
| 66 | auto enc = getSparseTensorEncoding(type: tensor.getType()); |
| 67 | return createOrFoldSliceOffsetOp(builder, loc, tensor, dim: toDim(enc, l: lvl)); |
| 68 | } |
| 69 | |
| 70 | static Value genSliceStride(OpBuilder &builder, Location loc, Value tensor, |
| 71 | Level lvl) { |
| 72 | auto enc = getSparseTensorEncoding(type: tensor.getType()); |
| 73 | return createOrFoldSliceStrideOp(builder, loc, tensor, dim: toDim(enc, l: lvl)); |
| 74 | } |
| 75 | |
| 76 | static bool isIntOrFPZero(Attribute attr) { |
| 77 | if (auto f = llvm::dyn_cast<FloatAttr>(Val&: attr); f && f.getValue().isZero()) |
| 78 | return true; |
| 79 | if (auto i = llvm::dyn_cast<IntegerAttr>(Val&: attr); i && i.getValue().isZero()) |
| 80 | return true; |
| 81 | return false; |
| 82 | } |
| 83 | |
| 84 | static Value unFoldOpIntResult(OpBuilder &builder, Location loc, |
| 85 | OpFoldResult ofr) { |
| 86 | if (std::optional<int64_t> i = getConstantIntValue(ofr); i.has_value()) |
| 87 | return constantIndex(builder, loc, i: *i); |
| 88 | return cast<Value>(Val&: ofr); |
| 89 | } |
| 90 | |
// Attempts to fold the operation defining `t`, returning the folded operand
// on success and `t` itself otherwise. Currently the only fold is a
// `tensor.pad` with a constant all-zero padding value on an identity-mapped
// sparse tensor, which is replaced by its source.
static Value tryFoldTensors(Value t) {
  // TODO: this should be done through a folding pass after switching to
  // `sparse_tensor.iterate`-based sparsification.
  auto stt = tryGetSparseTensorType(t);
  auto padOp = t.getDefiningOp<tensor::PadOp>();
  // Only fold when the pad source and result share the same identity encoding.
  if (padOp && stt.has_value() && stt->hasEncoding() &&
      padOp.getSourceType().getEncoding() == stt->getEncoding() &&
      stt->getEncoding().isIdentity()) {
    // Try fusing padOp with zeros.
    Attribute padCst;
    if (matchPattern(padOp.getBody()->getTerminator(),
                     m_Op<tensor::YieldOp>(m_Constant(&padCst))) &&
        isIntOrFPZero(padCst)) {
      return padOp.getSource();
    }
  }
  return t;
}
| 109 | |
| 110 | //===----------------------------------------------------------------------===// |
| 111 | // Sparse tensor loop emitter class implementations |
| 112 | //===----------------------------------------------------------------------===// |
| 113 | |
| 114 | LoopEmitter::LoopEmitter(ValueRange tensors, StringAttr loopTag, bool hasOutput, |
| 115 | bool isSparseOut, unsigned numLoops, |
| 116 | DependentLvlGetter dimGetter, |
| 117 | SparseEmitStrategy emitStrategy) { |
| 118 | initialize(tensors, loopTag, hasOutput, isSparseOut, numLoops, getter: dimGetter); |
| 119 | } |
| 120 | |
void LoopEmitter::initialize(ValueRange ts, StringAttr loopTag, bool hasOutput,
                             bool isSparseOut, unsigned numLoops,
                             DependentLvlGetter dimGetter,
                             SparseEmitStrategy emitStrategy) {
  // First initialize the top-level type of the fields.
  this->loopTag = loopTag;
  this->hasOutput = hasOutput;
  this->isSparseOut = isSparseOut;
  this->emitStrategy = emitStrategy;

  // One extra "synthetic" tensor is appended after the manifest tensors; its
  // id is `numManifestTensors` and it occupies the last slot of every
  // `TensorId`-indexed array below.
  const unsigned numManifestTensors = ts.size();
  const unsigned synTensorId = numManifestTensors;
  const unsigned numTensors = numManifestTensors + 1;
  // tensors array (len == numManifestTensor).
  this->tensors.assign(ts.begin(), ts.end());
  // Arrays with len == numTensor.
  this->valBuffer.assign(numTensors, nullptr);
  this->lvls.resize(numTensors);
  this->iters.resize(numTensors);
  this->spIterVals.resize(numTensors);

  // These zeros will be overwritten below, but we need to initialize
  // them to something since we'll need random-access assignment.
  this->loopStack.reserve(numLoops);
  this->loopSeqStack.reserve(numLoops);

  // Index-reduction related fields.
  this->dependentLvlMap.assign(
      numTensors, std::vector<std::vector<std::pair<TensorLevel, unsigned>>>());
  this->sliceMeta.assign(
      numTensors, std::vector<std::vector<std::pair<Value, unsigned>>>());
  this->levelReducedDep.assign(numTensors, std::vector<unsigned>());

  // Initialize nested types of `TensorId`-indexed fields.
  for (TensorId tid = 0; tid < numTensors; tid++) {
    Level lvlRank;
    if (tid == synTensorId) {
      // Synthetic tensor (conceptually) is an all-dense tensor with rank equal
      // to the total number of loops (each level can potentially be mapped to
      // one of the loop being generated).
      lvlRank = numLoops;
    } else {
      const Value t = tensors[tid];
      // a scalar or 0-dimension tensors
      if (isZeroRankedTensorOrScalar(t.getType()))
        continue;

      auto rtp = getRankedTensorType(t);
      const SparseTensorType stt(rtp);
      lvlRank = stt.getLvlRank();
    }

    lvls[tid].resize(lvlRank);
    iters[tid].resize(lvlRank);
    spIterVals[tid].resize(lvlRank);
    // NOTE(review): re-assigned each iteration with the same value; hoisting
    // it above the loop looks equivalent — confirm before changing.
    loopHighs.assign(numLoops, nullptr);

    // Slice-driven loops related initialization.
    levelReducedDep[tid].assign(lvlRank, 0);
    dependentLvlMap[tid].assign(
        lvlRank, std::vector<std::pair<TensorLevel, unsigned>>());
    sliceMeta[tid].assign(lvlRank, std::vector<std::pair<Value, unsigned>>());
    if (dimGetter && !isSynTensor(tid)) {
      for (Level l = 0; l < lvlRank; l++) {
        std::vector<std::pair<LoopId, unsigned>> deps = dimGetter(tid, l);
        // Sort the loop by order.
        llvm::sort(deps, llvm::less_first());

        dependentLvlMap[tid][l] = std::move(deps);
        unsigned depends = dependentLvlMap[tid][l].size();
        if (depends == 0)
          continue;
        sliceMeta[tid][l].reserve(depends);
      }
    }
  }
}
| 198 | |
// Creates the iterator over level `l` of tensor `t`, wrapping the plain level
// iterator for padding or slicing when the tensor requires it.
std::unique_ptr<SparseIterator>
LoopEmitter::makeLevelIterator(OpBuilder &builder, Location loc, TensorId t,
                               Level l) {
  Value tensor = tensors[t];
  auto stt = getSparseTensorType(tensor);
  auto it = makeSimpleIterator(*lvls[t][l], emitStrategy);

  // If the tensor was produced by a foldable zero-pad (see tryFoldTensors),
  // wrap the iterator to account for the low/high padding on this level.
  Value folded = tryFoldTensors(tensor);
  if (folded != tensor) {
    auto padOp = tensor.getDefiningOp<tensor::PadOp>();
    assert(padOp);
    if (padOp.getPaddedDims().test(l)) {
      Value low = unFoldOpIntResult(builder, loc, padOp.getMixedLowPad()[l]);
      Value high = unFoldOpIntResult(builder, loc, padOp.getMixedHighPad()[l]);
      auto padIt = makePaddedIterator(std::move(it), low, high, emitStrategy);
      return padIt;
    }
  }

  // For tensor slices, wrap the iterator to translate through the slice
  // offset and stride.
  if (stt.hasEncoding() && stt.getEncoding().isSlice()) {
    Value offset = genSliceOffset(builder, loc, tensor, l);
    Value stride = genSliceStride(builder, loc, tensor, l);
    auto slicedIt = makeSlicedLevelIterator(
        std::move(it), offset, stride, lvls[t][l]->getSize(), emitStrategy);
    return slicedIt;
  }

  return it;
}
| 228 | |
void LoopEmitter::initializeLoopEmit(
    OpBuilder &builder, Location loc, LoopEmitter::OutputUpdater updater,
    LoopEmitter::SynTensorBoundSetter synSetter) {

  // For every manifest tensor, set up the values buffer.
  for (TensorId t = 0, numTensors = getNumManifestTensors(); t < numTensors;
       t++) {
    // TODO: this should be done through a folding pass after switching to
    // `sparse_tensor.iterate`-based sparsification.
    const Value tensor = tryFoldTensors(tensors[t]);
    const auto rtp = dyn_cast<RankedTensorType>(tensor.getType());
    // Skips only scalar, zero ranked tensor still need to be bufferized and
    // (probably) filled with zeros by users.
    if (!rtp)
      continue;

    auto stt = getSparseTensorType(tensor);
    const auto shape = rtp.getShape();

    // Perform the required bufferization. Dense inputs materialize from the
    // input tensors. Sparse inputs use sparse primitives to obtain the values.
    // Delegates extra output initialization to clients.
    bool isOutput = isOutputTensor(t);
    Type elementType = stt.getElementType();
    if (!stt.hasEncoding()) {
      // Non-annotated dense tensors.
      BaseMemRefType denseTp = MemRefType::get(shape, elementType);

      // TODO: if we unconditionally use fully dynamic layout here, it breaks
      // some vectorization passes which requires static stride = 1.
      // Is it possible to call vectorization pass after bufferization?
      if (llvm::isa_and_nonnull<tensor::ExtractSliceOp>(tensor.getDefiningOp()))
        denseTp = bufferization::getMemRefTypeWithFullyDynamicLayout(rtp);

      Value denseVal =
          builder.create<bufferization::ToBufferOp>(loc, denseTp, tensor);
      // Dense outputs need special handling.
      if (isOutput && updater)
        denseVal = updater(builder, loc, denseVal, tensor);

      valBuffer[t] = denseVal;
    } else {
      // Annotated sparse tensors.
      // We also need the value buffer for all-dense annotated "sparse"
      // tensors.
      valBuffer[t] = builder.create<ToValuesOp>(loc, tensor);
    }
  }

  // The sparse iterator values will only be available after the loop is
  // constructed.
  if (emitStrategy == SparseEmitStrategy::kSparseIterator)
    return;

  // For every synthetic tensor, set the high bound by calling the callback.
  if (synSetter) {
    TensorId synId = getSynTensorId();
    for (unsigned i = 0, e = loopHighs.size(); i < e; i++) {
      Value sz = loopHighs[i] = synSetter(builder, loc, i);
      auto [stl, it] = makeSynLevelAndIterator(sz, synId, i, emitStrategy);
      lvls[synId][i] = std::move(stl);
      iters[synId][i].emplace_back(std::move(it));
    }
  }

  // For every manifest tensor:
  // * For every level:
  //   * get the positions and coordinates buffers
  //   * get/compute the level-size, which is also used as the upper-bound
  //     on positions.
  for (TensorId t = 0, numTensors = getNumManifestTensors(); t < numTensors;
       t++) {
    // TODO: this should be done through a folding pass after switching to
    // `sparse_tensor.iterate`-based sparsification.
    const Value tensor = tryFoldTensors(tensors[t]);
    const auto rtp = dyn_cast<RankedTensorType>(tensor.getType());
    if (!rtp)
      // Skips only scalar, zero ranked tensor still need to be bufferized and
      // (probably) filled with zeros by users.
      continue;

    auto stt = getSparseTensorType(tensor);
    const Level lvlRank = stt.getLvlRank();

    // Scan all levels of current tensor.
    for (Level l = 0; l < lvlRank; l++) {
      // Find upper bound in current dimension.
      lvls[t][l] = makeSparseTensorLevel(builder, loc, tensor, t, l);
      // Levels with loop dependencies get their (subsection) iterators
      // created later, in initSubSectIterator.
      if (!dependentLvlMap[t][l].empty())
        continue;

      auto it = makeLevelIterator(builder, loc, t, l);
      iters[t][l].emplace_back(std::move(it));
    }
    // NOTE: we can also prepare for 0 lvl here in advance, this will hoist
    // some loop preparation from tensor iteration, but will also (undesirably)
    // hoist the code outside if-conditions.
  }
  // TODO: avoid treating subsection iterator as a special case.
  initSubSectIterator(builder, loc);
}
| 330 | |
void LoopEmitter::initSubSectIterator(OpBuilder &builder, Location loc) {
  Value c0 = C_IDX(0);
  for (TensorId t = 0, e = tensors.size(); t < e; t++) {
    auto rtp = dyn_cast<RankedTensorType>(tensors[t].getType());
    if (!rtp)
      continue;

    Level lvlRank = SparseTensorType(rtp).getLvlRank();

    // Compute the dependency reduction order.
    // `remDepStack[t][lvl]` holds the not-yet-reduced dependencies for each
    // level, consumed back-to-front as subsection iterators are created.
    auto remDepStack = dependentLvlMap;
    std::vector<std::tuple<LoopId, TensorId, Level>> depRedOrder;
    for (Level lvl = 0; lvl < lvlRank; lvl++) {
      // Reverse queue into a stack.
      std::reverse(remDepStack[t][lvl].begin(), remDepStack[t][lvl].end());
      for (auto [loop, coeff] : dependentLvlMap[t][lvl])
        depRedOrder.emplace_back(std::make_tuple(loop, t, lvl));
    }

    if (depRedOrder.empty())
      continue;

    // Reduce dependencies in increasing loop order.
    llvm::sort(depRedOrder, llvm::less_first());

    SmallVector<SparseIterator *> lastIter(tensors.size(), nullptr);
    // NOTE: the loop bindings deliberately shadow the outer `t`; entries for
    // other tensors were not added to this tensor's depRedOrder.
    for (auto [loop, t, lvl] : depRedOrder) {
      std::pair<LoopId, unsigned> curDep = remDepStack[t][lvl].back();
      assert(curDep.first == loop);
      remDepStack[t][lvl].pop_back();

      auto lvlIt = makeLevelIterator(builder, loc, t, lvl);
      const SparseIterator *parent = lastIter[t];
      if (!parent && lvl > 0) {
        if (dependentLvlMap[t][lvl - 1].empty()) {
          parent = iters[t][lvl - 1].back().get();
        }
      }

      std::unique_ptr<SparseIterator> it;
      if (!remDepStack[t][lvl].empty()) {
        // Compute the subsection size.
        Value size = c0;
        for (auto [loop, stride] : remDepStack[t][lvl]) {
          Value idxMax = SUBI(loopHighs[loop], C_IDX(1));
          size = ADDI(size, ADDI(MULI(idxMax, C_IDX(stride)), C_IDX(1)));
        }
        it = makeNonEmptySubSectIterator(builder, loc, parent, loopHighs[loop],
                                         std::move(lvlIt), size, curDep.second,
                                         emitStrategy);
      } else {
        // Last dependency on this level: traverse the subsection built above.
        const SparseIterator &subSectIter = *iters[t][lvl].back();
        it = makeTraverseSubSectIterator(builder, loc, subSectIter, *parent,
                                         std::move(lvlIt), loopHighs[loop],
                                         curDep.second, emitStrategy);
      }
      lastIter[t] = it.get();
      iters[t][lvl].emplace_back(std::move(it));
    }
  }
}
| 391 | |
| 392 | void LoopEmitter::categorizeIterators( |
| 393 | ArrayRef<TensorLevel> tidLvls, SmallVectorImpl<SparseIterator *> &raIters, |
| 394 | SmallVectorImpl<SparseIterator *> &spIters) { |
| 395 | // Finds out the tensor level that we should use to generate loops. Amongs all |
| 396 | // the tensor levels, there is at most one sparse tensor level. |
| 397 | for (auto [t, l] : unpackTensorLevelRange(c&: tidLvls)) { |
| 398 | SparseIterator *it = &getCurIterator(tid: t, lvl: l); |
| 399 | if (it->randomAccessible()) |
| 400 | raIters.push_back(Elt: it); |
| 401 | else |
| 402 | spIters.push_back(Elt: it); |
| 403 | } |
| 404 | |
| 405 | llvm::stable_sort(Range&: spIters, C: [](auto lhs, auto rhs) { |
| 406 | // AffineUnRed > Affine > Slice > Trivial |
| 407 | return static_cast<uint8_t>(lhs->kind) > static_cast<uint8_t>(rhs->kind); |
| 408 | }); |
| 409 | } |
| 410 | |
| 411 | void LoopEmitter::enterNewLoopSeq(OpBuilder &builder, Location loc, |
| 412 | ArrayRef<TensorLevel> tidLvls) { |
| 413 | // TODO: sort |
| 414 | assert(loopSeqStack.size() == loopStack.size()); |
| 415 | |
| 416 | if (emitStrategy != SparseEmitStrategy::kSparseIterator) { |
| 417 | // Prepares for all the tensors used in the current loop sequence. |
| 418 | for (auto [tid, lvl] : unpackTensorLevelRange(c&: tidLvls)) { |
| 419 | levelReducedDep[tid][lvl]++; |
| 420 | prepareLoopOverTensorAtLvl(builder, loc, tid, lvl); |
| 421 | } |
| 422 | } |
| 423 | |
| 424 | // Universal Index starts from 0. |
| 425 | loopSeqStack.emplace_back(C_IDX(0), args: tidLvls.vec()); |
| 426 | } |
| 427 | |
| 428 | void LoopEmitter::exitCurrentLoopSeq(OpBuilder &builder, Location loc) { |
| 429 | assert(loopSeqStack.size() == loopStack.size() + 1); |
| 430 | |
| 431 | // Depending on whether the slice is resolved or not at current loop sequence, |
| 432 | // end them in different ways. |
| 433 | for (auto [tid, lvl] : unpackTensorLevelRange(c&: loopSeqStack.back().second)) |
| 434 | levelReducedDep[tid][lvl]--; |
| 435 | |
| 436 | loopSeqStack.pop_back(); |
| 437 | } |
| 438 | |
| 439 | Value LoopEmitter::genAffine(OpBuilder &builder, Location loc, AffineExpr a) { |
| 440 | switch (a.getKind()) { |
| 441 | case AffineExprKind::DimId: { |
| 442 | // FIXME: since the one callsite in Sparsification passes in a |
| 443 | // level-expression, the `getPosition` must in fact be a `Dimension`. |
| 444 | // However, elsewhere we have been lead to expect that `loopIdToOrd` |
| 445 | // should be indexed by `LoopId`... |
| 446 | const auto loopId = cast<AffineDimExpr>(Val&: a).getPosition(); |
| 447 | return loopStack[loopId].iv; |
| 448 | } |
| 449 | case AffineExprKind::Add: { |
| 450 | auto binOp = cast<AffineBinaryOpExpr>(Val&: a); |
| 451 | return ADDI(genAffine(builder, loc, binOp.getLHS()), |
| 452 | genAffine(builder, loc, binOp.getRHS())); |
| 453 | } |
| 454 | case AffineExprKind::Mul: { |
| 455 | auto binOp = cast<AffineBinaryOpExpr>(Val&: a); |
| 456 | return MULI(genAffine(builder, loc, binOp.getLHS()), |
| 457 | genAffine(builder, loc, binOp.getRHS())); |
| 458 | } |
| 459 | case AffineExprKind::Constant: { |
| 460 | int64_t c = cast<AffineConstantExpr>(Val&: a).getValue(); |
| 461 | return C_IDX(c); |
| 462 | } |
| 463 | default: |
| 464 | llvm_unreachable("unexpected affine subscript" ); |
| 465 | } |
| 466 | } |
| 467 | |
std::pair<Operation *, Value> LoopEmitter::emitForLoopOverTensorAtLvl(
    OpBuilder &builder, Location loc, SparseIterator &iter,
    MutableArrayRef<Value> reduc, bool isParallel) {

  // TODO: support dynamic slices.
  // Uses the first dimension here to build the loop bound (which is also the
  // biggest range).

  Value step = C_IDX(1);
  auto [lo, hi] = iter.genForCond(builder, loc);
  Operation *loop = nullptr;
  Value iv;
  if (isParallel) {
    scf::ParallelOp parOp =
        builder.create<scf::ParallelOp>(loc, lo, hi, step, reduc);
    builder.setInsertionPointToStart(parOp.getBody());
    assert(parOp.getNumReductions() == reduc.size());
    iv = parOp.getInductionVars()[0];

    // In-place update on the reduction variable vector.
    // Note that the init vals is not the actual reduction variables but instead
    // used as a "special handle" to (temporarily) represent them. The
    // expression on init vals will be moved into scf.reduce and replaced with
    // the block arguments when exiting the loop (see exitForLoop). This is
    // needed as we can not build the actual reduction block and get the actual
    // reduction variable before users fill parallel loop body.
    for (int i = 0, e = reduc.size(); i < e; i++)
      reduc[i] = parOp.getInitVals()[i];
    loop = parOp;
  } else {
    scf::ForOp forOp = builder.create<scf::ForOp>(loc, lo, hi, step, reduc);
    builder.setInsertionPointToStart(forOp.getBody());
    iv = forOp.getInductionVar();

    // In-place update on the reduction variable vector.
    assert(forOp.getNumRegionIterArgs() == reduc.size());
    for (int i = 0, e = reduc.size(); i < e; i++)
      reduc[i] = forOp.getRegionIterArg(i);
    loop = forOp;
  }
  assert(loop && iv);

  // For a non-random-access iterator the IV is a position that must be
  // dereferenced into a coordinate; for a random-access one the IV already is
  // the coordinate, so just position the iterator there.
  Value crd = iv;
  if (!iter.randomAccessible()) {
    iter.linkNewScope(iv);
    crd = iter.deref(builder, loc);
  } else {
    iter.locate(builder, loc, iv);
  }

  return {loop, crd};
}
| 520 | |
| 521 | std::pair<Operation *, Value> LoopEmitter::emitWhileLoopOverTensorsAtLvls( |
| 522 | OpBuilder &builder, Location loc, ArrayRef<SparseIterator *> spIters, |
| 523 | MutableArrayRef<Value> reduc, bool needsUniv) { |
| 524 | return genCoIteration(builder, loc, iters: spIters, reduc, |
| 525 | uniIdx: needsUniv ? loopSeqStack.back().first : nullptr); |
| 526 | } |
| 527 | |
| 528 | bool LoopEmitter::shouldIteratedByForLoop(ArrayRef<SparseIterator *> spIters) { |
| 529 | // If we need to co-iterate over two sparse tensors, we need a while loop |
| 530 | if (spIters.size() > 1) |
| 531 | return false; |
| 532 | |
| 533 | if (spIters.size() == 1) |
| 534 | return spIters.front()->iteratableByFor(); |
| 535 | |
| 536 | return true; |
| 537 | } |
| 538 | |
Region *LoopEmitter::enterCurrentCoIterationCase(OpBuilder &builder,
                                                Location loc,
                                                I64BitSet caseBit,
                                                unsigned caseIdx,
                                                MutableArrayRef<Value> reduc) {
  auto coIterOp = cast<CoIterateOp>(loopStack.back().loop);
  // Record which iteration spaces participate in this case.
  SmallVector<Attribute> cases(coIterOp.getCases().getAsRange<Attribute>());
  cases[caseIdx] = builder.getI64IntegerAttr(caseBit);

  coIterOp.setCasesAttr(builder.getArrayAttr(cases));
  Region &caseRegion = coIterOp.getRegion(caseIdx);
  assert(caseRegion.getBlocks().empty() &&
         "re-initialize the same coiteration case region.");

  // Each block starts with a list of user-provided iteration arguments.
  TypeRange iterArgsTps = coIterOp.getInitArgs().getTypes();
  // Followed by a list of used coordinates of index type.
  SmallVector<Type> blockArgTps(coIterOp.getCrdUsedLvls().count(),
                                builder.getIndexType());

  blockArgTps.append(iterArgsTps.begin(), iterArgsTps.end());
  // Ends with a set of iterators that defines the actually iteration space.
  for (auto i : caseBit.bits()) {
    blockArgTps.push_back(
        cast<IterSpaceType>(coIterOp.getIterSpaces()[i].getType())
            .getIteratorType());
  }
  SmallVector<Location> locs(blockArgTps.size(), loc);
  caseRegion.emplaceBlock().addArguments(blockArgTps, locs);

  // Entering the new region scope, updating the SSA chain.
  builder.setInsertionPointToStart(&caseRegion.front());
  // Update the coordinates.
  loopStack.back().iv = coIterOp.getCrds(caseIdx).front();
  // Updates loop iteration arguments.
  ValueRange iterArgs = coIterOp.getRegionIterArgs(caseIdx);
  llvm::copy(iterArgs, reduc.begin());
  // Updates sparse iterator values: only levels whose bit is set in
  // `caseBit` carry an iterator SSA value in this case; the rest are cleared.
  ValueRange iters = coIterOp.getRegionIterators(caseIdx);
  ArrayRef<TensorLevel> tidLvls = loopStack.back().tidLvls;
  for (auto [i, tl] : llvm::enumerate(unpackTensorLevelRange(tidLvls))) {
    if (caseBit[i]) {
      spIterVals[tl.first][tl.second] = iters.front();
      iters = iters.drop_front();
    } else {
      spIterVals[tl.first][tl.second] = nullptr;
    }
  }
  // Must have consumed all iterator SSA values.
  assert(iters.empty());
  return &caseRegion;
}
| 591 | |
| 592 | Operation *LoopEmitter::enterCoIterationOverTensorsAtLvls( |
| 593 | OpBuilder &builder, Location loc, ArrayRef<TensorLevel> tidLvls, |
| 594 | unsigned numCases, MutableArrayRef<Value> reduc, bool tryParallel, |
| 595 | bool needsUniv) { |
| 596 | // TODO: Argument `numCases` only used when generating iterator-based sparse |
| 597 | // loops. Simplify the code upon feature complete. |
| 598 | // TODO: handle coiteration with sparse iterator. |
| 599 | if (emitStrategy == SparseEmitStrategy::kSparseIterator) { |
| 600 | if (tidLvls.size() == 1) { |
| 601 | auto [tid, lvl] = unpackTensorLevel(tidLvl: tidLvls.front()); |
| 602 | Value t = tensors[tid]; |
| 603 | |
| 604 | // Extract and iterate over the iteration space. |
| 605 | ExtractIterSpaceOp = |
| 606 | lvl == 0 ? builder.create<ExtractIterSpaceOp>(location: loc, args&: t) |
| 607 | : builder.create<ExtractIterSpaceOp>( |
| 608 | location: loc, args&: t, args&: spIterVals[tid][lvl - 1], args&: lvl); |
| 609 | |
| 610 | IterateOp iterOp = builder.create<IterateOp>( |
| 611 | location: loc, args: extractSpaceOp.getExtractedSpace(), args&: reduc); |
| 612 | spIterVals[tid][lvl] = iterOp.getIterator(); |
| 613 | |
| 614 | // Update the reduction varaibles. |
| 615 | llvm::copy(Range: iterOp.getRegionIterArgs(), Out: reduc.begin()); |
| 616 | // Set the insertion point to loop body. |
| 617 | builder.setInsertionPointToStart(iterOp.getBody()); |
| 618 | loopStack.emplace_back(args&: tidLvls, args&: iterOp, args: builder.getInsertionBlock(), |
| 619 | args&: iterOp.getCrds().front(), args&: loopTag); |
| 620 | return iterOp; |
| 621 | } |
| 622 | |
| 623 | // CoIteration Loops. |
| 624 | SmallVector<Value> spaces; |
| 625 | for (auto [tid, lvl] : unpackTensorLevelRange(c&: tidLvls)) { |
| 626 | Value t = tensors[tid]; |
| 627 | ExtractIterSpaceOp = |
| 628 | lvl == 0 ? builder.create<ExtractIterSpaceOp>(location: loc, args&: t) |
| 629 | : builder.create<ExtractIterSpaceOp>( |
| 630 | location: loc, args&: t, args&: spIterVals[tid][lvl - 1], args&: lvl); |
| 631 | spaces.push_back(Elt: extractSpaceOp.getExtractedSpace()); |
| 632 | } |
| 633 | auto coIterOp = builder.create<CoIterateOp>(location: loc, args&: spaces, args&: reduc, args&: numCases); |
| 634 | // The CoIterationOp does not have insertion block nor induction variable. |
| 635 | // TODO: the `struct LoopInfo` should be simplied after full migration. |
| 636 | loopStack.emplace_back(args&: tidLvls, args&: coIterOp, /*insertion block*/ args: nullptr, |
| 637 | /*induction variable*/ args: nullptr, args&: loopTag); |
| 638 | return coIterOp; |
| 639 | } |
| 640 | |
| 641 | // TODO: support multiple return on parallel for? |
| 642 | tryParallel = tryParallel && reduc.size() <= 1; |
| 643 | |
| 644 | SmallVector<SparseIterator *> raIters; |
| 645 | SmallVector<SparseIterator *> spIters; |
| 646 | categorizeIterators(tidLvls, raIters, spIters); |
| 647 | |
| 648 | // Only when there is at least one sparse conditions, do we really need the |
| 649 | // universal index. |
| 650 | // TODO: Maybe we should instead requires merger to pass in a valid value at |
| 651 | // the first place instead of adjusting it in LoopEmitter? |
| 652 | needsUniv = !spIters.empty() && needsUniv; |
| 653 | // The TensorLevel used for loop conditions. |
| 654 | // If there is any sparse level, we need to use the sparse condition. |
| 655 | // If all levels are dense, we can pick arbitrary one (dense slice-driven loop |
| 656 | // can be generated using a simple ForOp as well). |
| 657 | Operation *l = nullptr; |
| 658 | Value iv = nullptr; |
| 659 | SmallVector<TensorLevel> tls; |
| 660 | |
| 661 | // Generates loops differently depending on whether we need a slice-driven |
| 662 | // loop or a simple level traversal loop. |
| 663 | if (shouldIteratedByForLoop(spIters) && !needsUniv) { |
| 664 | assert(spIters.size() <= 1); |
| 665 | SparseIterator &it = spIters.empty() ? *raIters.front() : *spIters.front(); |
| 666 | std::tie(args&: l, args&: iv) = |
| 667 | emitForLoopOverTensorAtLvl(builder, loc, iter&: it, reduc, isParallel: tryParallel); |
| 668 | tls.push_back(Elt: makeTensorLevel(t: it.tid, l: it.lvl)); |
| 669 | } else { |
| 670 | for (auto *it : spIters) { |
| 671 | tls.push_back(Elt: makeTensorLevel(t: it->tid, l: it->lvl)); |
| 672 | } |
| 673 | |
| 674 | if (needsUniv) |
| 675 | for (auto *it : raIters) |
| 676 | tls.push_back(Elt: makeTensorLevel(t: it->tid, l: it->lvl)); |
| 677 | |
| 678 | std::tie(args&: l, args&: iv) = |
| 679 | emitWhileLoopOverTensorsAtLvls(builder, loc, spIters, reduc, needsUniv); |
| 680 | } |
| 681 | |
| 682 | // Enter dense tensor levels. |
| 683 | for (SparseIterator *it : raIters) |
| 684 | it->locate(b&: builder, l: loc, crd: iv); |
| 685 | |
| 686 | // NOTE: we can also prepare for next dim here in advance |
| 687 | // Pushes the loop into stack. |
| 688 | loopStack.emplace_back(args&: tls, args&: l, args: builder.getInsertionBlock(), args&: iv, args&: loopTag); |
| 689 | return l; |
| 690 | } |
| 691 | |
| 692 | void LoopEmitter::locateLvlAtAffineAddress(OpBuilder &builder, Location loc, |
| 693 | TensorLevel tidLvl, |
| 694 | AffineExpr lvlExpr) { |
| 695 | auto [tid, lvl] = unpackTensorLevel(tidLvl); |
| 696 | |
| 697 | const SparseIterator *parent = |
| 698 | lvl == 0 ? nullptr : iters[tid][lvl - 1].back().get(); |
| 699 | auto &it = getCurIterator(tid, lvl); |
| 700 | it.genInit(b&: builder, l: loc, p: parent); |
| 701 | |
| 702 | assert(it.kind == IterKind::kTrivial && it.randomAccessible()); |
| 703 | Value lvlCrd = genAffine(builder, loc, a: lvlExpr); |
| 704 | it.locate(b&: builder, l: loc, crd: lvlCrd); |
| 705 | } |
| 706 | |
| 707 | void LoopEmitter::prepareLoopOverTensorAtLvl(OpBuilder &builder, Location loc, |
| 708 | TensorId tid, Level lvl) { |
| 709 | // if this is the first level, there is no parent iterator for the current |
| 710 | // iterator. |
| 711 | // If the current iterator is a subsection-based iterator, the parent iterator |
| 712 | // is memorized by the iterator. |
| 713 | bool hasParent = lvl == 0 || !dependentLvlMap[tid][lvl].empty(); |
| 714 | |
| 715 | const SparseIterator *parent = |
| 716 | hasParent ? nullptr : iters[tid][lvl - 1].back().get(); |
| 717 | auto &it = getCurIterator(tid, lvl); |
| 718 | it.genInit(b&: builder, l: loc, p: parent); |
| 719 | |
| 720 | // Locates the randon accessible iterator to 0. |
| 721 | if (it.randomAccessible()) |
| 722 | it.locate(b&: builder, l: loc, C_IDX(0)); |
| 723 | } |
| 724 | |
// Emits the terminator for the innermost for-style loop (scf.for,
// scf.parallel, or sparse_tensor.iterate), moves the insertion point past the
// loop, and writes the loop results back into `reduc` in place so callers see
// the post-loop reduction values.
void LoopEmitter::exitForLoop(RewriterBase &rewriter, Location loc,
                              MutableArrayRef<Value> reduc) {
  const LoopInfo &loopInfo = loopStack.back();
  // When lowering to `sparse_tensor.iterate`, yield the reductions directly.
  if (emitStrategy == SparseEmitStrategy::kSparseIterator) {
    auto iterateOp = llvm::cast<IterateOp>(Val: loopInfo.loop);
    assert(reduc.size() == iterateOp.getNumResults());
    rewriter.create<sparse_tensor::YieldOp>(location: loc, args&: reduc);
    // Exit the loop.
    rewriter.setInsertionPointAfter(iterateOp);
    // In-place update reduction variables.
    llvm::copy(Range: iterateOp.getResults(), Out: reduc.begin());
    return;
  }
  if (auto forOp = llvm::dyn_cast<scf::ForOp>(Val: loopInfo.loop)) {
    // A plain scf.for: yield the reduction values (if any) and step out.
    if (!reduc.empty()) {
      assert(reduc.size() == forOp.getNumResults());
      rewriter.create<scf::YieldOp>(location: loc, args&: reduc);
    }
    // Exit the loop.
    rewriter.setInsertionPointAfter(forOp);
    // In-place update reduction variables.
    llvm::copy(Range: forOp.getResults(), Out: reduc.begin());
  } else {
    // Otherwise this must be an scf.parallel, whose reductions are expressed
    // through scf.reduce regions rather than a simple yield.
    auto parOp = llvm::cast<scf::ParallelOp>(Val: loopInfo.loop);
    if (!reduc.empty()) {
      assert(reduc.size() == parOp.getInitVals().size() && reduc.size() == 1);
      // The value currently in `reduc` was computed inside the loop body by a
      // binary op combining the previous reduction value with a new operand.
      Operation *redExp = reduc.front().getDefiningOp();
      // Reduction expression should have no use.
      assert(redExp->getUses().empty());
      // This must be a binary operation.
      // NOTE: This is users' responsibility to ensure the operation are
      // commutative.
      assert(redExp->getNumOperands() == 2 && redExp->getNumResults() == 1);

      Value redVal = parOp.getInitVals().front();
      Value curVal;
      if (redExp->getOperand(idx: 0) == redVal)
        curVal = redExp->getOperand(idx: 1);
      else if (redExp->getOperand(idx: 1) == redVal)
        curVal = redExp->getOperand(idx: 0);
      // One of the operands must be the init value (which is also the
      // previous reduction value).
      assert(curVal);
#ifndef NDEBUG
      // The reduction expression should be the only user of the reduction val
      // inside the parallel for.
      unsigned numUsers = 0;
      for (Operation *op : redVal.getUsers()) {
        if (op->getParentOp() == parOp)
          numUsers++;
      }
      assert(numUsers == 1);
#endif // NDEBUG

      // Rebuild the reduction as an scf.reduce: feed the new operand into the
      // reduce op, then clone the original binary op into its region.
      rewriter.setInsertionPointAfter(redExp);
      auto redOp = rewriter.create<scf::ReduceOp>(location: loc, args&: curVal);
      // Attach to the reduction op.
      Block *redBlock = &redOp.getReductions().front().front();
      rewriter.setInsertionPointToEnd(redBlock);
      Operation *newRed = rewriter.clone(op&: *redExp);
      // Replaces arguments of the reduction expression by using the block
      // arguments from scf.reduce.
      rewriter.modifyOpInPlace(
          root: newRed, callable: [&]() { newRed->setOperands(redBlock->getArguments()); });
      // Erases the out-dated reduction expression.
      rewriter.eraseOp(op: redExp);
      rewriter.setInsertionPointToEnd(redBlock);
      rewriter.create<scf::ReduceReturnOp>(location: loc, args: newRed->getResult(idx: 0));
    }
    rewriter.setInsertionPointAfter(parOp);
    // In-place update reduction variables.
    for (unsigned i = 0, e = parOp.getResults().size(); i < e; i++)
      reduc[i] = parOp.getResult(i);
  }
}
| 800 | |
// Emits the yield that forwards all co-iterated sparse iterators (and the
// optional universal index) for the innermost scf.while loop, updates `reduc`
// with the loop's reduction results in place, and moves the insertion point
// past the loop.
void LoopEmitter::exitWhileLoop(OpBuilder &builder, Location loc,
                                MutableArrayRef<Value> reduc) {
  const LoopInfo &loopInfo = loopStack.back();
  auto whileOp = llvm::cast<scf::WhileOp>(Val: loopInfo.loop);
  Value iv = loopInfo.iv;
  Value one = C_IDX(1);

  // Finalize the induction. Note that the induction could be performed
  // in the individual if-branches to avoid re-evaluating the conditions.
  // However, that would result in a rather elaborate forest of yield
  // instructions during code generation. Moreover, performing the induction
  // after the if-statements more closely resembles code generated by TACO.
  SmallVector<Value> operands;
  // `whileRes` is consumed front-to-back as each iterator links its slice of
  // the loop results.
  ValueRange whileRes = whileOp.getResults();

  for (auto [tid, lvl] : unpackTensorLevelRange(c: loopInfo.tidLvls)) {
    SparseIterator &it = getCurIterator(tid, lvl);
    if (!it.randomAccessible()) {
      // Forward the sparse iterator, but only if its coordinate matched the
      // induction variable (i.e., it participated in this iteration).
      Value cmp = CMPI(eq, it.getCrd(), iv);
      it.forwardIf(b&: builder, l: loc, cond: cmp);
      operands.append(in_start: it.getCursor().begin(), in_end: it.getCursor().end());
      // const Value newPos = whileOp->getResult(o++);
      // Following loops continue iteration from the break point of the
      // current while loop.
      whileRes = it.linkNewScope(pos: whileRes);
    } else {
      // Make sure randomly accessible (dense) iterator is set to the right
      // position according to the universal index.
      Value uniIdx = whileOp.getResults().back();
      it.locate(b&: builder, l: loc, crd: uniIdx);
    }
  }

  // Reduction value from users.
  for (auto &i : reduc) {
    operands.push_back(Elt: i);
    // Update user reduction variables.
    i = whileRes.front();
    whileRes = whileRes.drop_front();
  }

  // An (optional) universal index.
  if (operands.size() < whileOp.getNumResults()) {
    assert(operands.size() + 1 == whileOp.getNumResults());
    // The last one is the universal index; advance it by one.
    operands.push_back(ADDI(iv, one));
    // Update the loop starting point of the current loop sequence.
    loopSeqStack.back().first = whileOp->getResults().back();
  }

  if (!operands.empty())
    YIELD(operands);

  builder.setInsertionPointAfter(whileOp);
}
| 857 | |
| 858 | void LoopEmitter::exitCurrentLoop(RewriterBase &rewriter, Location loc, |
| 859 | MutableArrayRef<Value> reduc) { |
| 860 | // Clean up the values, it would help use to discover potential bug at a |
| 861 | // earlier stage (instead of silently using a wrong value). |
| 862 | const LoopInfo &loopInfo = loopStack.back(); |
| 863 | if (emitStrategy == SparseEmitStrategy::kSparseIterator) { |
| 864 | Operation *p = loopInfo.loop; |
| 865 | if (isa<IterateOp>(Val: p)) |
| 866 | rewriter.create<sparse_tensor::YieldOp>(location: loc, args&: reduc); |
| 867 | |
| 868 | // Exit the loop. |
| 869 | rewriter.setInsertionPointAfter(p); |
| 870 | // In-place update reduction variables. |
| 871 | llvm::copy(Range: p->getResults(), Out: reduc.begin()); |
| 872 | loopStack.pop_back(); |
| 873 | return; |
| 874 | } |
| 875 | |
| 876 | // Sets the insertion point to the right position. |
| 877 | rewriter.setInsertionPointToEnd(loopInfo.userCodeBlock); |
| 878 | if (!loopInfo.userCodeBlock->empty() && |
| 879 | llvm::isa<scf::YieldOp>(Val: &loopInfo.userCodeBlock->back())) { |
| 880 | // scf::While/For inserts an implicit yield op when there is no loop |
| 881 | // iter args. In this case, we need to insert the code before the yield. |
| 882 | assert(loopInfo.userCodeBlock->back().getNumResults() == 0); |
| 883 | rewriter.setInsertionPoint(&loopInfo.userCodeBlock->back()); |
| 884 | } |
| 885 | |
| 886 | if (llvm::isa<scf::WhileOp>(Val: loopInfo.loop)) { |
| 887 | exitWhileLoop(builder&: rewriter, loc, reduc); |
| 888 | } else { |
| 889 | exitForLoop(rewriter, loc, reduc); |
| 890 | } |
| 891 | |
| 892 | assert(loopStack.size() == loopSeqStack.size()); |
| 893 | loopStack.pop_back(); |
| 894 | } |
| 895 | |
| 896 | //===----------------------------------------------------------------------===// |
| 897 | // Loop generation utils |
| 898 | //===----------------------------------------------------------------------===// |
| 899 | |
// Builds an scf.while loop that co-iterates all iterators in `spIters`.
// The while-loop carries every iterator's cursor values, the user-provided
// reduction values `reduc` (updated in place to the loop-body block
// arguments), and an optional universal index `uniIdx`. Returns the created
// loop together with the coordinate value for the current iteration (the
// minimum coordinate across iterators, or the universal index if present).
std::pair<Operation *, Value> sparse_tensor::genCoIteration(
    OpBuilder &builder, Location loc, ArrayRef<SparseIterator *> spIters,
    MutableArrayRef<Value> reduc, Value uniIdx, bool userReducFirst) {
  // NOTE: the slice driven tensor-related reduction variable must
  // appear before normal tensors.

  // The set of induction variables for the while loop.
  SmallVector<Value> ivs;

  // TODO: remove the flag after full migration. Currently
  // `sparse_tensor.coiterate` operation (must) put user provided reduction
  // values at the front of the block list, while direct sparsification to scf
  // loops put them at the end.
  if (userReducFirst)
    ivs.append(in_start: reduc.begin(), in_end: reduc.end());

  // Construct the while-loop with a parameter for each coordinate.
  for (SparseIterator *it : spIters) {
    ValueRange itVals = it->getCursor();
    ivs.append(in_start: itVals.begin(), in_end: itVals.end());
  }

  if (!userReducFirst)
    ivs.append(in_start: reduc.begin(), in_end: reduc.end());

  // Update universal index.
  if (uniIdx)
    ivs.push_back(Elt: uniIdx);

  // Ensures all operands are valid.
  assert(!llvm::is_contained(ivs, nullptr));
  TypeRange types = ValueRange(ivs).getTypes();
  auto whileOp = builder.create<scf::WhileOp>(location: loc, args&: types, args&: ivs);

  // Create the "before" (condition) and "after" (body) regions, both with one
  // block argument per induction variable.
  SmallVector<Location> locs(types.size(), loc);
  Block *before = builder.createBlock(parent: &whileOp.getBefore(), insertPt: {}, argTypes: types, locs);
  Block *after = builder.createBlock(parent: &whileOp.getAfter(), insertPt: {}, argTypes: types, locs);

  // Generates loop conditions: the conjunction of every iterator's
  // "not exhausted" condition. Each iterator consumes its own slice of the
  // block arguments and returns the remainder.
  builder.setInsertionPointToStart(before);
  ValueRange bArgs = before->getArguments();
  Value whileCond = nullptr; // bool values for loop condition.

  for (SparseIterator *it : spIters) {
    auto [cond, remArgs] = it->genWhileCond(b&: builder, l: loc, vs: bArgs);
    whileCond = !whileCond ? cond : ANDI(whileCond, cond);
    bArgs = remArgs;
  }
  // The remaining block arguments are user-provided reduction values and an
  // optional universal index. Make sure their sizes match.
  assert(bArgs.size() == reduc.size() + (uniIdx ? 1 : 0));
  builder.create<scf::ConditionOp>(location: loc, args&: whileCond, args: before->getArguments());

  // Generates loop body.
  builder.setInsertionPointToStart(after);
  ValueRange aArgs = after->getArguments();

  // Rebind each iterator's cursor to the body block arguments.
  for (SparseIterator *it : spIters) {
    aArgs = it->linkNewScope(pos: aArgs);
    // Dereference the iterator to cache the coordinate.
    it->deref(b&: builder, l: loc);
  }

  // In-place update on reduction variable.
  for (unsigned i = 0, e = reduc.size(); i < e; i++)
    reduc[i] = aArgs[i];

  Value min;
  // Finds the minimum coordinate across all iterators.
  if (!uniIdx) {
    for (SparseIterator *it : spIters) {
      if (min) {
        Value cmp = CMPI(ult, it->getCrd(), min);
        min = SELECT(cmp, it->getCrd(), min);
      } else {
        min = it->getCrd();
      }
    }
  } else {
    // Otherwise, universal index is the minimal pos.
    min = whileOp.getAfterArguments().back();
  }

  return {whileOp, min};
}
| 985 | |
| 986 | #undef CMPI |
| 987 | #undef C_IDX |
| 988 | #undef YIELD |
| 989 | #undef ADDI |
| 990 | #undef ANDI |
| 991 | #undef SUBI |
| 992 | #undef MULI |
| 993 | #undef SELECT |
| 994 | |