//===- LoopEmitter.cpp ----------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "LoopEmitter.h"
#include "CodegenUtils.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"

using namespace mlir;
using namespace mlir::sparse_tensor;

//===----------------------------------------------------------------------===//
// File local shorthand macros
//===----------------------------------------------------------------------===//

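// The following shorthands assume that an `OpBuilder builder` and a
// `Location loc` are available in the enclosing scope.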
#define CMPI(p, l, r)                                                          \
  (builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::p, (l), (r))       \
       .getResult())

#define C_IDX(v) (constantIndex(builder, loc, (v)))
#define YIELD(vs) (builder.create<scf::YieldOp>(loc, (vs)))
#define ADDI(lhs, rhs) (builder.create<arith::AddIOp>(loc, (lhs), (rhs)))
#define ANDI(lhs, rhs) (builder.create<arith::AndIOp>(loc, (lhs), (rhs)))
#define SUBI(lhs, rhs) (builder.create<arith::SubIOp>(loc, (lhs), (rhs)))
#define MULI(lhs, rhs) (builder.create<arith::MulIOp>(loc, (lhs), (rhs)))
#define REMUI(lhs, rhs) (builder.create<arith::RemUIOp>(loc, (lhs), (rhs)))
#define DIVUI(lhs, rhs) (builder.create<arith::DivUIOp>(loc, (lhs), (rhs)))
#define SELECT(c, l, r) (builder.create<arith::SelectOp>(loc, (c), (l), (r)))

//===----------------------------------------------------------------------===//
// Debugging utils
//===----------------------------------------------------------------------===//

#ifndef NDEBUG
LLVM_ATTRIBUTE_UNUSED static void dumpIndexMemRef(OpBuilder &builder,
                                                  Location loc, Value memref) {
  memref = builder.create<memref::CastOp>(
      loc, UnrankedMemRefType::get(builder.getIndexType(), 0), memref);
  createFuncCall(builder, loc, "printMemrefInd", TypeRange{},
                 ValueRange{memref}, EmitCInterface::On);
}
#endif

//===----------------------------------------------------------------------===//
// File local helper functions.
//===----------------------------------------------------------------------===//

// For index reduction loops, since the tensors are sliced into non-contiguous
// fragments, we need a triple [pLo, pHi, pPtr], in which the pair (pLo, pHi)
// specifies the range of the fragment, and pPtr specifies the index of the
// corresponding fragment in the child level (i.e., a pointer to the sliced
// position array).
static Value genSliceOffset(OpBuilder &builder, Location loc, Value tensor,
                            Level lvl) {
  auto enc = getSparseTensorEncoding(tensor.getType());
  return createOrFoldSliceOffsetOp(builder, loc, tensor, toDim(enc, lvl));
}

static Value genSliceStride(OpBuilder &builder, Location loc, Value tensor,
                            Level lvl) {
  auto enc = getSparseTensorEncoding(tensor.getType());
  return createOrFoldSliceStrideOp(builder, loc, tensor, toDim(enc, lvl));
}

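// Returns true iff the given attribute is an integer or floating-point
// constant with value zero.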
static bool isIntOrFPZero(Attribute attr) {
  if (auto f = llvm::dyn_cast<FloatAttr>(attr); f && f.getValue().isZero())
    return true;
  if (auto i = llvm::dyn_cast<IntegerAttr>(attr); i && i.getValue().isZero())
    return true;
  return false;
}

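// Materializes an OpFoldResult as a Value, creating a constant index op when
// the result folds to a compile-time constant.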
static Value unFoldOpIntResult(OpBuilder &builder, Location loc,
                               OpFoldResult ofr) {
  if (std::optional<int64_t> i = getConstantIntValue(ofr); i.has_value())
    return constantIndex(builder, loc, *i);
  return cast<Value>(ofr);
}

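// Attempts to "fold" away a tensor.pad whose padding value is a constant zero
// and whose source shares the same identity-mapped sparse encoding, returning
// the pad source so that the padding can instead be handled by the (padded)
// level iterators; returns the original value otherwise.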
static Value tryFoldTensors(Value t) {
  // TODO: this should be done through a folding pass after switching to
  // `sparse_tensor.iterate`-based sparsification.
  auto stt = tryGetSparseTensorType(t);
  auto padOp = t.getDefiningOp<tensor::PadOp>();
  if (padOp && stt.has_value() && stt->hasEncoding() &&
      padOp.getSourceType().getEncoding() == stt->getEncoding() &&
      stt->getEncoding().isIdentity()) {
    // Try fusing padOp with zeros.
    Attribute padCst;
    if (matchPattern(padOp.getBody()->getTerminator(),
                     m_Op<tensor::YieldOp>(m_Constant(&padCst))) &&
        isIntOrFPZero(padCst)) {
      return padOp.getSource();
    }
  }
  return t;
}

//===----------------------------------------------------------------------===//
// Sparse tensor loop emitter class implementations
//===----------------------------------------------------------------------===//

LoopEmitter::LoopEmitter(ValueRange tensors, StringAttr loopTag, bool hasOutput,
                         bool isSparseOut, unsigned numLoops,
                         DependentLvlGetter dimGetter,
                         SparseEmitStrategy emitStrategy) {
  initialize(tensors, loopTag, hasOutput, isSparseOut, numLoops, dimGetter,
             emitStrategy);
}

void LoopEmitter::initialize(ValueRange ts, StringAttr loopTag, bool hasOutput,
                             bool isSparseOut, unsigned numLoops,
                             DependentLvlGetter dimGetter,
                             SparseEmitStrategy emitStrategy) {
  // First initialize the top-level type of the fields.
  this->loopTag = loopTag;
  this->hasOutput = hasOutput;
  this->isSparseOut = isSparseOut;
  this->emitStrategy = emitStrategy;

  const unsigned numManifestTensors = ts.size();
  const unsigned synTensorId = numManifestTensors;
  const unsigned numTensors = numManifestTensors + 1;
  // tensors array (len == numManifestTensors).
  this->tensors.assign(ts.begin(), ts.end());
  // Arrays with len == numTensors.
  this->valBuffer.assign(numTensors, nullptr);
  this->lvls.resize(numTensors);
  this->iters.resize(numTensors);
  this->spIterVals.resize(numTensors);

  // These zeros will be overwritten below, but we need to initialize
  // them to something since we'll need random-access assignment.
  this->loopStack.reserve(numLoops);
  this->loopSeqStack.reserve(numLoops);

  // Index-reduction related fields.
  this->dependentLvlMap.assign(
      numTensors, std::vector<std::vector<std::pair<TensorLevel, unsigned>>>());
  this->sliceMeta.assign(
      numTensors, std::vector<std::vector<std::pair<Value, unsigned>>>());
  this->levelReducedDep.assign(numTensors, std::vector<unsigned>());

  // Initialize nested types of `TensorId`-indexed fields.
  for (TensorId tid = 0; tid < numTensors; tid++) {
    Level lvlRank;
    if (tid == synTensorId) {
      // Synthetic tensor (conceptually) is an all-dense tensor with rank equal
      // to the total number of loops (each level can potentially be mapped to
      // one of the loops being generated).
      lvlRank = numLoops;
    } else {
      const Value t = tensors[tid];
      // A scalar or 0-dimension tensor.
      if (isZeroRankedTensorOrScalar(t.getType()))
        continue;

      auto rtp = getRankedTensorType(t);
      const SparseTensorType stt(rtp);
      lvlRank = stt.getLvlRank();
    }

    lvls[tid].resize(lvlRank);
    iters[tid].resize(lvlRank);
    spIterVals[tid].resize(lvlRank);
    loopHighs.assign(numLoops, nullptr);

    // Slice-driven loops related initialization.
    levelReducedDep[tid].assign(lvlRank, 0);
    dependentLvlMap[tid].assign(
        lvlRank, std::vector<std::pair<TensorLevel, unsigned>>());
    sliceMeta[tid].assign(lvlRank, std::vector<std::pair<Value, unsigned>>());
    if (dimGetter && !isSynTensor(tid)) {
      for (Level l = 0; l < lvlRank; l++) {
        std::vector<std::pair<LoopId, unsigned>> deps = dimGetter(tid, l);
        // Sort the dependent loops in loop order.
        llvm::sort(deps, llvm::less_first());

        dependentLvlMap[tid][l] = std::move(deps);
        unsigned depends = dependentLvlMap[tid][l].size();
        if (depends == 0)
          continue;
        sliceMeta[tid][l].reserve(depends);
      }
    }
  }
}

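// Creates a simple iterator over the given tensor level, additionally wrapping
// it in a padded or sliced iterator when the tensor stems from a foldable
// tensor.pad or carries a slice encoding.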
std::unique_ptr<SparseIterator>
LoopEmitter::makeLevelIterator(OpBuilder &builder, Location loc, TensorId t,
                               Level l) {
  Value tensor = tensors[t];
  auto stt = getSparseTensorType(tensor);
  auto it = makeSimpleIterator(*lvls[t][l], emitStrategy);

  Value folded = tryFoldTensors(tensor);
  if (folded != tensor) {
    auto padOp = tensor.getDefiningOp<tensor::PadOp>();
    assert(padOp);
    if (padOp.getPaddedDims().test(l)) {
      Value low = unFoldOpIntResult(builder, loc, padOp.getMixedLowPad()[l]);
      Value high = unFoldOpIntResult(builder, loc, padOp.getMixedHighPad()[l]);
      auto padIt = makePaddedIterator(std::move(it), low, high, emitStrategy);
      return padIt;
    }
  }

  if (stt.hasEncoding() && stt.getEncoding().isSlice()) {
    Value offset = genSliceOffset(builder, loc, tensor, l);
    Value stride = genSliceStride(builder, loc, tensor, l);
    auto slicedIt = makeSlicedLevelIterator(
        std::move(it), offset, stride, lvls[t][l]->getSize(), emitStrategy);
    return slicedIt;
  }

  return it;
}

void LoopEmitter::initializeLoopEmit(
    OpBuilder &builder, Location loc, LoopEmitter::OutputUpdater updater,
    LoopEmitter::SynTensorBoundSetter synSetter) {

  // For every manifest tensor, set up the values buffer.
  for (TensorId t = 0, numTensors = getNumManifestTensors(); t < numTensors;
       t++) {
    // TODO: this should be done through a folding pass after switching to
    // `sparse_tensor.iterate`-based sparsification.
    const Value tensor = tryFoldTensors(tensors[t]);
    const auto rtp = dyn_cast<RankedTensorType>(tensor.getType());
    // Skips only scalars; zero-ranked tensors still need to be bufferized and
    // (probably) filled with zeros by users.
    if (!rtp)
      continue;

    auto stt = getSparseTensorType(tensor);
    const auto shape = rtp.getShape();

    // Perform the required bufferization. Dense inputs materialize from the
    // input tensors. Sparse inputs use sparse primitives to obtain the values.
    // Delegates extra output initialization to clients.
    bool isOutput = isOutputTensor(t);
    Type elementType = stt.getElementType();
    if (!stt.hasEncoding()) {
      // Non-annotated dense tensors.
      BaseMemRefType denseTp = MemRefType::get(shape, elementType);

      // TODO: if we unconditionally use fully dynamic layout here, it breaks
      // some vectorization passes which require static stride = 1.
      // Is it possible to call the vectorization pass after bufferization?
      if (llvm::isa_and_nonnull<tensor::ExtractSliceOp>(tensor.getDefiningOp()))
        denseTp = bufferization::getMemRefTypeWithFullyDynamicLayout(rtp);

      Value denseVal =
          builder.create<bufferization::ToBufferOp>(loc, denseTp, tensor);
      // Dense outputs need special handling.
      if (isOutput && updater)
        denseVal = updater(builder, loc, denseVal, tensor);

      valBuffer[t] = denseVal;
    } else {
      // Annotated sparse tensors.
      // We also need the value buffer for all-dense annotated "sparse"
      // tensors.
      valBuffer[t] = builder.create<ToValuesOp>(loc, tensor);
    }
  }

  // The sparse iterator values will only be available after the loop is
  // constructed.
  if (emitStrategy == SparseEmitStrategy::kSparseIterator)
    return;

  // For every synthetic tensor, set the high bound by calling the callback.
  if (synSetter) {
    TensorId synId = getSynTensorId();
    for (unsigned i = 0, e = loopHighs.size(); i < e; i++) {
      Value sz = loopHighs[i] = synSetter(builder, loc, i);
      auto [stl, it] = makeSynLevelAndIterator(sz, synId, i, emitStrategy);
      lvls[synId][i] = std::move(stl);
      iters[synId][i].emplace_back(std::move(it));
    }
  }

  // For every manifest tensor:
  // * For every level:
  //   * get the positions and coordinates buffers
  //   * get/compute the level-size, which is also used as the upper-bound
  //     on positions.
  for (TensorId t = 0, numTensors = getNumManifestTensors(); t < numTensors;
       t++) {
    // TODO: this should be done through a folding pass after switching to
    // `sparse_tensor.iterate`-based sparsification.
    const Value tensor = tryFoldTensors(tensors[t]);
    const auto rtp = dyn_cast<RankedTensorType>(tensor.getType());
    if (!rtp)
      // Skips only scalars; zero-ranked tensors still need to be bufferized
      // and (probably) filled with zeros by users.
      continue;

    auto stt = getSparseTensorType(tensor);
    const Level lvlRank = stt.getLvlRank();

    // Scan all levels of current tensor.
    for (Level l = 0; l < lvlRank; l++) {
      // Find upper bound in current dimension.
      lvls[t][l] = makeSparseTensorLevel(builder, loc, tensor, t, l);
      if (!dependentLvlMap[t][l].empty())
        continue;

      auto it = makeLevelIterator(builder, loc, t, l);
      iters[t][l].emplace_back(std::move(it));
    }
    // NOTE: we can also prepare for the 0th level here in advance; this will
    // hoist some loop preparation from tensor iteration, but will also
    // (undesirably) hoist the code outside if-conditions.
  }
  // TODO: avoid treating subsection iterator as a special case.
  initSubSectIterator(builder, loc);
}

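// Sets up the subsection iterators needed for levels with non-trivial index
// expressions (i.e., levels whose coordinates depend on more than one loop),
// processing the dependences in loop order.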
void LoopEmitter::initSubSectIterator(OpBuilder &builder, Location loc) {
  Value c0 = C_IDX(0);
  for (TensorId t = 0, e = tensors.size(); t < e; t++) {
    auto rtp = dyn_cast<RankedTensorType>(tensors[t].getType());
    if (!rtp)
      continue;

    Level lvlRank = SparseTensorType(rtp).getLvlRank();

    // Compute the dependency reduction order.
    auto remDepStack = dependentLvlMap;
    std::vector<std::tuple<LoopId, TensorId, Level>> depRedOrder;
    for (Level lvl = 0; lvl < lvlRank; lvl++) {
      // Reverse queue into a stack.
      std::reverse(remDepStack[t][lvl].begin(), remDepStack[t][lvl].end());
      for (auto [loop, coeff] : dependentLvlMap[t][lvl])
        depRedOrder.emplace_back(std::make_tuple(loop, t, lvl));
    }

    if (depRedOrder.empty())
      continue;

    llvm::sort(depRedOrder, llvm::less_first());

    SmallVector<SparseIterator *> lastIter(tensors.size(), nullptr);
    for (auto [loop, t, lvl] : depRedOrder) {
      std::pair<LoopId, unsigned> curDep = remDepStack[t][lvl].back();
      assert(curDep.first == loop);
      remDepStack[t][lvl].pop_back();

      auto lvlIt = makeLevelIterator(builder, loc, t, lvl);
      const SparseIterator *parent = lastIter[t];
      if (!parent && lvl > 0) {
        if (dependentLvlMap[t][lvl - 1].empty()) {
          parent = iters[t][lvl - 1].back().get();
        }
      }

      std::unique_ptr<SparseIterator> it;
      if (!remDepStack[t][lvl].empty()) {
        // Compute the subsection size.
        Value size = c0;
        for (auto [loop, stride] : remDepStack[t][lvl]) {
          Value idxMax = SUBI(loopHighs[loop], C_IDX(1));
          size = ADDI(size, ADDI(MULI(idxMax, C_IDX(stride)), C_IDX(1)));
        }
        it = makeNonEmptySubSectIterator(builder, loc, parent, loopHighs[loop],
                                         std::move(lvlIt), size, curDep.second,
                                         emitStrategy);
      } else {
        const SparseIterator &subSectIter = *iters[t][lvl].back();
        it = makeTraverseSubSectIterator(builder, loc, subSectIter, *parent,
                                         std::move(lvlIt), loopHighs[loop],
                                         curDep.second, emitStrategy);
      }
      lastIter[t] = it.get();
      iters[t][lvl].emplace_back(std::move(it));
    }
  }
}

void LoopEmitter::categorizeIterators(
    ArrayRef<TensorLevel> tidLvls, SmallVectorImpl<SparseIterator *> &raIters,
    SmallVectorImpl<SparseIterator *> &spIters) {
  // Finds out the tensor level that we should use to generate loops. Among all
  // the tensor levels, there is at most one sparse tensor level.
  for (auto [t, l] : unpackTensorLevelRange(tidLvls)) {
    SparseIterator *it = &getCurIterator(t, l);
    if (it->randomAccessible())
      raIters.push_back(it);
    else
      spIters.push_back(it);
  }

  llvm::stable_sort(spIters, [](auto lhs, auto rhs) {
    // AffineUnRed > Affine > Slice > Trivial
    return static_cast<uint8_t>(lhs->kind) > static_cast<uint8_t>(rhs->kind);
  });
}

void LoopEmitter::enterNewLoopSeq(OpBuilder &builder, Location loc,
                                  ArrayRef<TensorLevel> tidLvls) {
  // TODO: sort
  assert(loopSeqStack.size() == loopStack.size());

  if (emitStrategy != SparseEmitStrategy::kSparseIterator) {
    // Prepares for all the tensors used in the current loop sequence.
    for (auto [tid, lvl] : unpackTensorLevelRange(tidLvls)) {
      levelReducedDep[tid][lvl]++;
      prepareLoopOverTensorAtLvl(builder, loc, tid, lvl);
    }
  }

  // Universal Index starts from 0.
  loopSeqStack.emplace_back(C_IDX(0), tidLvls.vec());
}

void LoopEmitter::exitCurrentLoopSeq(OpBuilder &builder, Location loc) {
  assert(loopSeqStack.size() == loopStack.size() + 1);

  // Depending on whether the slice is resolved or not at the current loop
  // sequence, end them in different ways.
  for (auto [tid, lvl] : unpackTensorLevelRange(loopSeqStack.back().second))
    levelReducedDep[tid][lvl]--;

  loopSeqStack.pop_back();
}

Value LoopEmitter::genAffine(OpBuilder &builder, Location loc, AffineExpr a) {
  switch (a.getKind()) {
  case AffineExprKind::DimId: {
    // FIXME: since the one callsite in Sparsification passes in a
    // level-expression, the `getPosition` must in fact be a `Dimension`.
    // However, elsewhere we have been led to expect that `loopIdToOrd`
    // should be indexed by `LoopId`...
    const auto loopId = cast<AffineDimExpr>(a).getPosition();
    return loopStack[loopId].iv;
  }
  case AffineExprKind::Add: {
    auto binOp = cast<AffineBinaryOpExpr>(a);
    return ADDI(genAffine(builder, loc, binOp.getLHS()),
                genAffine(builder, loc, binOp.getRHS()));
  }
  case AffineExprKind::Mul: {
    auto binOp = cast<AffineBinaryOpExpr>(a);
    return MULI(genAffine(builder, loc, binOp.getLHS()),
                genAffine(builder, loc, binOp.getRHS()));
  }
  case AffineExprKind::Constant: {
    int64_t c = cast<AffineConstantExpr>(a).getValue();
    return C_IDX(c);
  }
  default:
    llvm_unreachable("unexpected affine subscript");
  }
}

std::pair<Operation *, Value> LoopEmitter::emitForLoopOverTensorAtLvl(
    OpBuilder &builder, Location loc, SparseIterator &iter,
    MutableArrayRef<Value> reduc, bool isParallel) {

  // TODO: support dynamic slices.
  // Uses the first dimension here to build the loop bound (which is also the
  // biggest range).

  Value step = C_IDX(1);
  auto [lo, hi] = iter.genForCond(builder, loc);
  Operation *loop = nullptr;
  Value iv;
  if (isParallel) {
    scf::ParallelOp parOp =
        builder.create<scf::ParallelOp>(loc, lo, hi, step, reduc);
    builder.setInsertionPointToStart(parOp.getBody());
    assert(parOp.getNumReductions() == reduc.size());
    iv = parOp.getInductionVars()[0];

    // In-place update on the reduction variable vector.
    // Note that the init vals are not the actual reduction variables but are
    // instead used as a "special handle" to (temporarily) represent them. The
    // expression on init vals will be moved into scf.reduce and replaced with
    // the block arguments when exiting the loop (see exitForLoop). This is
    // needed as we cannot build the actual reduction block and get the actual
    // reduction variable before users fill the parallel loop body.
    for (int i = 0, e = reduc.size(); i < e; i++)
      reduc[i] = parOp.getInitVals()[i];
    loop = parOp;
  } else {
    scf::ForOp forOp = builder.create<scf::ForOp>(loc, lo, hi, step, reduc);
    builder.setInsertionPointToStart(forOp.getBody());
    iv = forOp.getInductionVar();

    // In-place update on the reduction variable vector.
    assert(forOp.getNumRegionIterArgs() == reduc.size());
    for (int i = 0, e = reduc.size(); i < e; i++)
      reduc[i] = forOp.getRegionIterArg(i);
    loop = forOp;
  }
  assert(loop && iv);

  Value crd = iv;
  if (!iter.randomAccessible()) {
    iter.linkNewScope(iv);
    crd = iter.deref(builder, loc);
  } else {
    iter.locate(builder, loc, iv);
  }

  return {loop, crd};
}

std::pair<Operation *, Value> LoopEmitter::emitWhileLoopOverTensorsAtLvls(
    OpBuilder &builder, Location loc, ArrayRef<SparseIterator *> spIters,
    MutableArrayRef<Value> reduc, bool needsUniv) {
  return genCoIteration(builder, loc, spIters, reduc,
                        needsUniv ? loopSeqStack.back().first : nullptr);
}

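// Returns true if the given set of sparse iterators can be lowered to a single
// scf.for loop, i.e., when there is at most one sparse iterator and it
// supports for-style iteration.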
bool LoopEmitter::shouldIteratedByForLoop(ArrayRef<SparseIterator *> spIters) {
  // If we need to co-iterate over two sparse tensors, we need a while loop.
  if (spIters.size() > 1)
    return false;

  if (spIters.size() == 1)
    return spIters.front()->iteratableByFor();

  return true;
}

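// Populates the case region of the enclosing sparse_tensor.coiterate op that
// corresponds to `caseBit`: records the case bits, creates the block with the
// expected arguments (coordinates, user reduction values, iterators), and
// updates the loop and iterator bookkeeping accordingly.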
Region *LoopEmitter::enterCurrentCoIterationCase(OpBuilder &builder,
                                                 Location loc,
                                                 I64BitSet caseBit,
                                                 unsigned caseIdx,
                                                 MutableArrayRef<Value> reduc) {
  auto coIterOp = cast<CoIterateOp>(loopStack.back().loop);
  SmallVector<Attribute> cases(coIterOp.getCases().getAsRange<Attribute>());
  cases[caseIdx] = builder.getI64IntegerAttr(caseBit);

  coIterOp.setCasesAttr(builder.getArrayAttr(cases));
  Region &caseRegion = coIterOp.getRegion(caseIdx);
  assert(caseRegion.getBlocks().empty() &&
         "re-initialize the same coiteration case region.");

  // Each block starts with a list of user-provided iteration arguments.
  TypeRange iterArgsTps = coIterOp.getInitArgs().getTypes();
  // Followed by a list of used coordinates of index type.
  SmallVector<Type> blockArgTps(coIterOp.getCrdUsedLvls().count(),
                                builder.getIndexType());

  blockArgTps.append(iterArgsTps.begin(), iterArgsTps.end());
  // Ends with a set of iterators that defines the actual iteration space.
  for (auto i : caseBit.bits()) {
    blockArgTps.push_back(
        cast<IterSpaceType>(coIterOp.getIterSpaces()[i].getType())
            .getIteratorType());
  }
  SmallVector<Location> locs(blockArgTps.size(), loc);
  caseRegion.emplaceBlock().addArguments(blockArgTps, locs);

  // Entering the new region scope, updating the SSA chain.
  builder.setInsertionPointToStart(&caseRegion.front());
  // Update the coordinates.
  loopStack.back().iv = coIterOp.getCrds(caseIdx).front();
  // Updates loop iteration arguments.
  ValueRange iterArgs = coIterOp.getRegionIterArgs(caseIdx);
  llvm::copy(iterArgs, reduc.begin());
  // Updates sparse iterator values.
  ValueRange iters = coIterOp.getRegionIterators(caseIdx);
  ArrayRef<TensorLevel> tidLvls = loopStack.back().tidLvls;
  for (auto [i, tl] : llvm::enumerate(unpackTensorLevelRange(tidLvls))) {
    if (caseBit[i]) {
      spIterVals[tl.first][tl.second] = iters.front();
      iters = iters.drop_front();
    } else {
      spIterVals[tl.first][tl.second] = nullptr;
    }
  }
  // Must have consumed all iterator SSA values.
  assert(iters.empty());
  return &caseRegion;
}

Operation *LoopEmitter::enterCoIterationOverTensorsAtLvls(
    OpBuilder &builder, Location loc, ArrayRef<TensorLevel> tidLvls,
    unsigned numCases, MutableArrayRef<Value> reduc, bool tryParallel,
    bool needsUniv) {
  // TODO: Argument `numCases` is only used when generating iterator-based
  // sparse loops. Simplify the code upon feature completion.
  // TODO: handle coiteration with sparse iterator.
  if (emitStrategy == SparseEmitStrategy::kSparseIterator) {
    if (tidLvls.size() == 1) {
      auto [tid, lvl] = unpackTensorLevel(tidLvls.front());
      Value t = tensors[tid];

      // Extract and iterate over the iteration space.
      ExtractIterSpaceOp extractSpaceOp =
          lvl == 0 ? builder.create<ExtractIterSpaceOp>(loc, t)
                   : builder.create<ExtractIterSpaceOp>(
                         loc, t, spIterVals[tid][lvl - 1], lvl);

      IterateOp iterOp = builder.create<IterateOp>(
          loc, extractSpaceOp.getExtractedSpace(), reduc);
      spIterVals[tid][lvl] = iterOp.getIterator();

      // Update the reduction variables.
      llvm::copy(iterOp.getRegionIterArgs(), reduc.begin());
      // Set the insertion point to loop body.
      builder.setInsertionPointToStart(iterOp.getBody());
      loopStack.emplace_back(tidLvls, iterOp, builder.getInsertionBlock(),
                             iterOp.getCrds().front(), loopTag);
      return iterOp;
    }

    // CoIteration Loops.
    SmallVector<Value> spaces;
    for (auto [tid, lvl] : unpackTensorLevelRange(tidLvls)) {
      Value t = tensors[tid];
      ExtractIterSpaceOp extractSpaceOp =
          lvl == 0 ? builder.create<ExtractIterSpaceOp>(loc, t)
                   : builder.create<ExtractIterSpaceOp>(
                         loc, t, spIterVals[tid][lvl - 1], lvl);
      spaces.push_back(extractSpaceOp.getExtractedSpace());
    }
    auto coIterOp = builder.create<CoIterateOp>(loc, spaces, reduc, numCases);
    // The CoIterateOp does not have an insertion block nor an induction
    // variable.
    // TODO: the `struct LoopInfo` should be simplified after full migration.
    loopStack.emplace_back(tidLvls, coIterOp, /*insertion block*/ nullptr,
                           /*induction variable*/ nullptr, loopTag);
    return coIterOp;
  }

  // TODO: support multiple return on parallel for?
  tryParallel = tryParallel && reduc.size() <= 1;

  SmallVector<SparseIterator *> raIters;
  SmallVector<SparseIterator *> spIters;
  categorizeIterators(tidLvls, raIters, spIters);

  // Only when there is at least one sparse condition do we really need the
  // universal index.
  // TODO: Maybe we should instead require the merger to pass in a valid value
  // in the first place instead of adjusting it in LoopEmitter?
  needsUniv = !spIters.empty() && needsUniv;
  // The TensorLevel used for loop conditions.
  // If there is any sparse level, we need to use the sparse condition.
  // If all levels are dense, we can pick an arbitrary one (a dense
  // slice-driven loop can be generated using a simple ForOp as well).
  Operation *l = nullptr;
  Value iv = nullptr;
  SmallVector<TensorLevel> tls;

  // Generates loops differently depending on whether we need a slice-driven
  // loop or a simple level traversal loop.
  if (shouldIteratedByForLoop(spIters) && !needsUniv) {
    assert(spIters.size() <= 1);
    SparseIterator &it = spIters.empty() ? *raIters.front() : *spIters.front();
    std::tie(l, iv) =
        emitForLoopOverTensorAtLvl(builder, loc, it, reduc, tryParallel);
    tls.push_back(makeTensorLevel(it.tid, it.lvl));
  } else {
    for (auto *it : spIters) {
      tls.push_back(makeTensorLevel(it->tid, it->lvl));
    }

    if (needsUniv)
      for (auto *it : raIters)
        tls.push_back(makeTensorLevel(it->tid, it->lvl));

    std::tie(l, iv) =
        emitWhileLoopOverTensorsAtLvls(builder, loc, spIters, reduc, needsUniv);
  }

  // Enter dense tensor levels.
  for (SparseIterator *it : raIters)
    it->locate(builder, loc, iv);

  // NOTE: we can also prepare for the next dim here in advance.
  // Pushes the loop into stack.
  loopStack.emplace_back(tls, l, builder.getInsertionBlock(), iv, loopTag);
  return l;
}

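// Positions the (random-access) iterator of the given tensor level at the
// coordinate computed from the affine expression.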
void LoopEmitter::locateLvlAtAffineAddress(OpBuilder &builder, Location loc,
                                           TensorLevel tidLvl,
                                           AffineExpr lvlExpr) {
  auto [tid, lvl] = unpackTensorLevel(tidLvl);

  const SparseIterator *parent =
      lvl == 0 ? nullptr : iters[tid][lvl - 1].back().get();
  auto &it = getCurIterator(tid, lvl);
  it.genInit(builder, loc, parent);

  assert(it.kind == IterKind::kTrivial && it.randomAccessible());
  Value lvlCrd = genAffine(builder, loc, lvlExpr);
  it.locate(builder, loc, lvlCrd);
}

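// Initializes the iterator of the given tensor level (hooking up its parent
// iterator when there is one) before it is used in a loop sequence.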
void LoopEmitter::prepareLoopOverTensorAtLvl(OpBuilder &builder, Location loc,
                                             TensorId tid, Level lvl) {
  // If this is the first level, there is no parent iterator for the current
  // iterator.
  // If the current iterator is a subsection-based iterator, the parent
  // iterator is memorized by the iterator.
  bool hasParent = lvl == 0 || !dependentLvlMap[tid][lvl].empty();

  const SparseIterator *parent =
      hasParent ? nullptr : iters[tid][lvl - 1].back().get();
  auto &it = getCurIterator(tid, lvl);
  it.genInit(builder, loc, parent);

  // Locates the random-accessible iterator to 0.
  if (it.randomAccessible())
    it.locate(builder, loc, C_IDX(0));
}

void LoopEmitter::exitForLoop(RewriterBase &rewriter, Location loc,
                              MutableArrayRef<Value> reduc) {
  const LoopInfo &loopInfo = loopStack.back();
  if (emitStrategy == SparseEmitStrategy::kSparseIterator) {
    auto iterateOp = llvm::cast<IterateOp>(loopInfo.loop);
    assert(reduc.size() == iterateOp.getNumResults());
    rewriter.create<sparse_tensor::YieldOp>(loc, reduc);
    // Exit the loop.
    rewriter.setInsertionPointAfter(iterateOp);
    // In-place update reduction variables.
    llvm::copy(iterateOp.getResults(), reduc.begin());
    return;
  }
  if (auto forOp = llvm::dyn_cast<scf::ForOp>(loopInfo.loop)) {
    if (!reduc.empty()) {
      assert(reduc.size() == forOp.getNumResults());
      rewriter.create<scf::YieldOp>(loc, reduc);
    }
    // Exit the loop.
    rewriter.setInsertionPointAfter(forOp);
    // In-place update reduction variables.
    llvm::copy(forOp.getResults(), reduc.begin());
  } else {
    auto parOp = llvm::cast<scf::ParallelOp>(loopInfo.loop);
    if (!reduc.empty()) {
      assert(reduc.size() == parOp.getInitVals().size() && reduc.size() == 1);
      Operation *redExp = reduc.front().getDefiningOp();
      // Reduction expression should have no use.
      assert(redExp->getUses().empty());
      // This must be a binary operation.
      // NOTE: It is the users' responsibility to ensure the operation is
      // commutative.
      assert(redExp->getNumOperands() == 2 && redExp->getNumResults() == 1);

      Value redVal = parOp.getInitVals().front();
      Value curVal;
      if (redExp->getOperand(0) == redVal)
        curVal = redExp->getOperand(1);
      else if (redExp->getOperand(1) == redVal)
        curVal = redExp->getOperand(0);
      // One of the operands must be the init value (which is also the
      // previous reduction value).
      assert(curVal);
#ifndef NDEBUG
      // The reduction expression should be the only user of the reduction val
      // inside the parallel for.
      unsigned numUsers = 0;
      for (Operation *op : redVal.getUsers()) {
        if (op->getParentOp() == parOp)
          numUsers++;
      }
      assert(numUsers == 1);
#endif // NDEBUG

      rewriter.setInsertionPointAfter(redExp);
      auto redOp = rewriter.create<scf::ReduceOp>(loc, curVal);
      // Attach to the reduction op.
      Block *redBlock = &redOp.getReductions().front().front();
      rewriter.setInsertionPointToEnd(redBlock);
      Operation *newRed = rewriter.clone(*redExp);
      // Replaces arguments of the reduction expression by using the block
      // arguments from scf.reduce.
      rewriter.modifyOpInPlace(
          newRed, [&]() { newRed->setOperands(redBlock->getArguments()); });
      // Erases the out-dated reduction expression.
      rewriter.eraseOp(redExp);
      rewriter.setInsertionPointToEnd(redBlock);
      rewriter.create<scf::ReduceReturnOp>(loc, newRed->getResult(0));
    }
    rewriter.setInsertionPointAfter(parOp);
    // In-place update reduction variables.
    for (unsigned i = 0, e = parOp.getResults().size(); i < e; i++)
      reduc[i] = parOp.getResult(i);
  }
}

void LoopEmitter::exitWhileLoop(OpBuilder &builder, Location loc,
                                MutableArrayRef<Value> reduc) {
  const LoopInfo &loopInfo = loopStack.back();
  auto whileOp = llvm::cast<scf::WhileOp>(loopInfo.loop);
  Value iv = loopInfo.iv;
  Value one = C_IDX(1);

  // Finalize the induction. Note that the induction could be performed
  // in the individual if-branches to avoid re-evaluating the conditions.
  // However, that would result in a rather elaborate forest of yield
  // instructions during code generation. Moreover, performing the induction
  // after the if-statements more closely resembles code generated by TACO.
  SmallVector<Value> operands;
  ValueRange whileRes = whileOp.getResults();

  for (auto [tid, lvl] : unpackTensorLevelRange(loopInfo.tidLvls)) {
    SparseIterator &it = getCurIterator(tid, lvl);
    if (!it.randomAccessible()) {
      // Forward the sparse iterator.
      Value cmp = CMPI(eq, it.getCrd(), iv);
      it.forwardIf(builder, loc, cmp);
      operands.append(it.getCursor().begin(), it.getCursor().end());
      // const Value newPos = whileOp->getResult(o++);
      // Following loops continue iteration from the break point of the
      // current while loop.
      whileRes = it.linkNewScope(whileRes);
    } else {
      // Make sure the randomly accessible (dense) iterator is set to the right
      // position according to the universal index.
      Value uniIdx = whileOp.getResults().back();
      it.locate(builder, loc, uniIdx);
    }
  }

  // Reduction value from users.
  for (auto &i : reduc) {
    operands.push_back(i);
    // Update user reduction variables.
    i = whileRes.front();
    whileRes = whileRes.drop_front();
  }

  // An (optional) universal index.
  if (operands.size() < whileOp.getNumResults()) {
    assert(operands.size() + 1 == whileOp.getNumResults());
    // The last one is the universal index.
    operands.push_back(ADDI(iv, one));
    // Update the loop starting point of the current loop sequence.
    loopSeqStack.back().first = whileOp->getResults().back();
  }

  if (!operands.empty())
    YIELD(operands);

  builder.setInsertionPointAfter(whileOp);
}

void LoopEmitter::exitCurrentLoop(RewriterBase &rewriter, Location loc,
                                  MutableArrayRef<Value> reduc) {
  // Clean up the values; it would help us discover potential bugs at an
  // earlier stage (instead of silently using a wrong value).
  const LoopInfo &loopInfo = loopStack.back();
  if (emitStrategy == SparseEmitStrategy::kSparseIterator) {
    Operation *p = loopInfo.loop;
    if (isa<IterateOp>(p))
      rewriter.create<sparse_tensor::YieldOp>(loc, reduc);

    // Exit the loop.
    rewriter.setInsertionPointAfter(p);
    // In-place update reduction variables.
    llvm::copy(p->getResults(), reduc.begin());
    loopStack.pop_back();
    return;
  }

  // Sets the insertion point to the right position.
  rewriter.setInsertionPointToEnd(loopInfo.userCodeBlock);
  if (!loopInfo.userCodeBlock->empty() &&
      llvm::isa<scf::YieldOp>(&loopInfo.userCodeBlock->back())) {
    // scf::While/For inserts an implicit yield op when there are no loop
    // iter args. In this case, we need to insert the code before the yield.
    assert(loopInfo.userCodeBlock->back().getNumResults() == 0);
    rewriter.setInsertionPoint(&loopInfo.userCodeBlock->back());
  }

  if (llvm::isa<scf::WhileOp>(loopInfo.loop)) {
    exitWhileLoop(rewriter, loc, reduc);
  } else {
    exitForLoop(rewriter, loc, reduc);
  }

  assert(loopStack.size() == loopSeqStack.size());
  loopStack.pop_back();
}

//===----------------------------------------------------------------------===//
// Loop generation utils
//===----------------------------------------------------------------------===//

std::pair<Operation *, Value> sparse_tensor::genCoIteration(
    OpBuilder &builder, Location loc, ArrayRef<SparseIterator *> spIters,
    MutableArrayRef<Value> reduc, Value uniIdx, bool userReducFirst) {
  // NOTE: the slice driven tensor-related reduction variable must
  // appear before normal tensors.

  // The set of induction variables for the while loop.
  SmallVector<Value> ivs;

  // TODO: remove the flag after full migration. Currently the
  // `sparse_tensor.coiterate` operation (must) put user-provided reduction
  // values at the front of the block argument list, while direct
  // sparsification to scf loops puts them at the end.
  if (userReducFirst)
    ivs.append(reduc.begin(), reduc.end());

  // Construct the while-loop with a parameter for each coordinate.
  for (SparseIterator *it : spIters) {
    ValueRange itVals = it->getCursor();
    ivs.append(itVals.begin(), itVals.end());
  }

  if (!userReducFirst)
    ivs.append(reduc.begin(), reduc.end());

  // Update universal index.
  if (uniIdx)
    ivs.push_back(uniIdx);

  // Ensures all operands are valid.
  assert(llvm::all_of(ivs, [](Value v) { return v != nullptr; }));
  TypeRange types = ValueRange(ivs).getTypes();
  auto whileOp = builder.create<scf::WhileOp>(loc, types, ivs);

  SmallVector<Location> locs(types.size(), loc);
  Block *before = builder.createBlock(&whileOp.getBefore(), {}, types, locs);
  Block *after = builder.createBlock(&whileOp.getAfter(), {}, types, locs);

  // Generates loop conditions.
  builder.setInsertionPointToStart(before);
  ValueRange bArgs = before->getArguments();
  Value whileCond = nullptr; // bool values for loop condition.

  for (SparseIterator *it : spIters) {
    auto [cond, remArgs] = it->genWhileCond(builder, loc, bArgs);
    whileCond = !whileCond ? cond : ANDI(whileCond, cond);
    bArgs = remArgs;
  }
  // The remaining block arguments are user-provided reduction values and an
  // optional universal index. Make sure their sizes match.
  assert(bArgs.size() == reduc.size() + (uniIdx ? 1 : 0));
  builder.create<scf::ConditionOp>(loc, whileCond, before->getArguments());

  // Generates loop body.
  builder.setInsertionPointToStart(after);
  ValueRange aArgs = after->getArguments();

  for (SparseIterator *it : spIters) {
    aArgs = it->linkNewScope(aArgs);
    // Dereference the iterator to cache the coordinate.
    it->deref(builder, loc);
  }

  // In-place update on reduction variable.
  for (unsigned i = 0, e = reduc.size(); i < e; i++)
    reduc[i] = aArgs[i];

  Value min;
  // Finds the minimum coordinate.
  if (!uniIdx) {
    for (SparseIterator *it : spIters) {
      if (min) {
        Value cmp = CMPI(ult, it->getCrd(), min);
        min = SELECT(cmp, it->getCrd(), min);
      } else {
        min = it->getCrd();
      }
    }
  } else {
    // Otherwise, the universal index is the minimal pos.
    min = whileOp.getAfterArguments().back();
  }

  return {whileOp, min};
}

#undef CMPI
#undef C_IDX
#undef YIELD
#undef ADDI
#undef ANDI
#undef SUBI
#undef MULI
#undef REMUI
#undef DIVUI
#undef SELECT