| 1 | |
| 2 | #include "Utils/CodegenUtils.h" |
| 3 | #include "Utils/LoopEmitter.h" |
| 4 | #include "Utils/SparseTensorIterator.h" |
| 5 | |
| 6 | #include "mlir/Dialect/MemRef/IR/MemRef.h" |
| 7 | #include "mlir/Dialect/SCF/IR/SCF.h" |
| 8 | #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" |
| 9 | #include "mlir/Dialect/SparseTensor/Transforms/Passes.h" |
| 10 | #include "mlir/Transforms/DialectConversion.h" |
| 11 | |
| 12 | using namespace mlir; |
| 13 | using namespace mlir::sparse_tensor; |
| 14 | |
| 15 | static void convertLevelType(SparseTensorEncodingAttr enc, Level lvl, |
| 16 | SmallVectorImpl<Type> &fields) { |
| 17 | // Position and coordinate buffer in the sparse structure. |
| 18 | if (enc.getLvlType(lvl).isWithPosLT()) |
| 19 | fields.push_back(Elt: enc.getPosMemRefType()); |
| 20 | if (enc.getLvlType(lvl).isWithCrdLT()) |
| 21 | fields.push_back(Elt: enc.getCrdMemRefType()); |
| 22 | // One index for shape bound (result from lvlOp). |
| 23 | fields.push_back(IndexType::get(enc.getContext())); |
| 24 | } |
| 25 | |
| 26 | static std::optional<LogicalResult> |
| 27 | convertIterSpaceType(IterSpaceType itSp, SmallVectorImpl<Type> &fields) { |
| 28 | |
| 29 | auto idxTp = IndexType::get(itSp.getContext()); |
| 30 | for (Level l = itSp.getLoLvl(); l < itSp.getHiLvl(); l++) |
| 31 | convertLevelType(itSp.getEncoding(), l, fields); |
| 32 | |
| 33 | // Two indices for lower and upper bound (we only need one pair for the last |
| 34 | // iteration space). |
| 35 | fields.append({idxTp, idxTp}); |
| 36 | return success(); |
| 37 | } |
| 38 | |
| 39 | static std::optional<LogicalResult> |
| 40 | convertIteratorType(IteratorType itTp, SmallVectorImpl<Type> &fields) { |
| 41 | // The actually Iterator Values (that are updated every iteration). |
| 42 | auto idxTp = IndexType::get(itTp.getContext()); |
| 43 | // TODO: handle batch dimension. |
| 44 | assert(itTp.getEncoding().getBatchLvlRank() == 0); |
| 45 | if (!itTp.isUnique()) { |
| 46 | // Segment high for non-unique iterator. |
| 47 | fields.push_back(Elt: idxTp); |
| 48 | } |
| 49 | fields.push_back(Elt: idxTp); |
| 50 | return success(); |
| 51 | } |
| 52 | |
| 53 | static ValueRange |
| 54 | genCoIterateBranchNest(PatternRewriter &rewriter, Location loc, CoIterateOp op, |
| 55 | Value loopCrd, |
| 56 | ArrayRef<std::unique_ptr<SparseIterator>> iters, |
| 57 | ArrayRef<Block *> newBlocks, ArrayRef<Block *> oldBlocks, |
| 58 | ArrayRef<Value> userReduc) { |
| 59 | if (newBlocks.empty()) |
| 60 | return userReduc; |
| 61 | |
| 62 | // The current branch that we are handling. |
| 63 | Block *newBlock = newBlocks.front(); |
| 64 | Block *oldBlock = oldBlocks.front(); |
| 65 | Value casePred = constantI1(builder&: rewriter, loc, b: true); |
| 66 | I64BitSet caseBits = |
| 67 | op.getRegionDefinedSpace(newBlock->getParent()->getRegionNumber()); |
| 68 | for (unsigned i : caseBits.bits()) { |
| 69 | SparseIterator *it = iters[i].get(); |
| 70 | Value pred = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, |
| 71 | it->getCrd(), loopCrd); |
| 72 | casePred = rewriter.create<arith::AndIOp>(loc, casePred, pred); |
| 73 | } |
| 74 | scf::IfOp ifOp = rewriter.create<scf::IfOp>( |
| 75 | loc, ValueRange(userReduc).getTypes(), casePred, /*else=*/true); |
| 76 | rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front()); |
| 77 | |
| 78 | // Erase the empty block. |
| 79 | rewriter.eraseBlock(block: &ifOp.getThenRegion().front()); |
| 80 | // Set up block arguments: user-provided values -> loop coord -> iterators. |
| 81 | SmallVector<Value> blockArgs(userReduc); |
| 82 | blockArgs.push_back(Elt: loopCrd); |
| 83 | for (unsigned idx : caseBits.bits()) |
| 84 | llvm::append_range(blockArgs, iters[idx]->getCursor()); |
| 85 | |
| 86 | // Map the old block arguments, because the dialect conversion driver does |
| 87 | // not immediately perform SSA value replacements. This function is still |
| 88 | // seeing the old uses. |
| 89 | IRMapping mapping; |
| 90 | for (auto [from, to] : llvm::zip_equal(t: oldBlock->getArguments(), u&: blockArgs)) { |
| 91 | mapping.map(from, to); |
| 92 | } |
| 93 | |
| 94 | // Clone the region, we can not erase the region now because the same region |
| 95 | // might be a subcase for multiple lattice point. |
| 96 | rewriter.cloneRegionBefore(*newBlock->getParent(), ifOp.getThenRegion(), |
| 97 | ifOp.getThenRegion().begin(), mapping); |
| 98 | // Remove the block arguments, they were already replaced via `mapping`. |
| 99 | ifOp.getThenRegion().front().eraseArguments(0, blockArgs.size()); |
| 100 | |
| 101 | // replace sparse_tensor::YieldOp -> scf::YieldOp |
| 102 | auto spY = cast<sparse_tensor::YieldOp>(&ifOp.getThenRegion().front().back()); |
| 103 | ValueRange yields = spY.getResults(); |
| 104 | rewriter.eraseOp(op: spY); |
| 105 | rewriter.setInsertionPointToEnd(&ifOp.getThenRegion().front()); |
| 106 | rewriter.create<scf::YieldOp>(loc, yields); |
| 107 | |
| 108 | // Generates remaining case recursively. |
| 109 | rewriter.setInsertionPointToStart(&ifOp.getElseRegion().front()); |
| 110 | ValueRange res = genCoIterateBranchNest(rewriter, loc, op, loopCrd, iters, |
| 111 | newBlocks.drop_front(), |
| 112 | oldBlocks.drop_front(), userReduc); |
| 113 | if (!res.empty()) |
| 114 | rewriter.create<scf::YieldOp>(loc, res); |
| 115 | |
| 116 | rewriter.setInsertionPointAfter(ifOp); |
| 117 | return ifOp.getResults(); |
| 118 | } |
| 119 | |
| 120 | static ValueRange genLoopWithIterator( |
| 121 | PatternRewriter &rewriter, Location loc, SparseIterator *it, |
| 122 | ValueRange reduc, |
| 123 | function_ref<SmallVector<Value>(PatternRewriter &rewriter, Location loc, |
| 124 | Region &loopBody, SparseIterator *it, |
| 125 | ValueRange reduc)> |
| 126 | bodyBuilder) { |
| 127 | if (it->iteratableByFor()) { |
| 128 | auto [lo, hi] = it->genForCond(b&: rewriter, l: loc); |
| 129 | Value step = constantIndex(builder&: rewriter, loc, i: 1); |
| 130 | scf::ForOp forOp = rewriter.create<scf::ForOp>( |
| 131 | loc, lo, hi, step, reduc, |
| 132 | [&](OpBuilder &b, Location loc, Value iv, ValueRange iterArgs) { |
| 133 | // Empty builder function to ensure that no terminator is created. |
| 134 | }); |
| 135 | { |
| 136 | OpBuilder::InsertionGuard guard(rewriter); |
| 137 | it->linkNewScope(pos: forOp.getInductionVar()); |
| 138 | rewriter.setInsertionPointToStart(forOp.getBody()); |
| 139 | SmallVector<Value> ret = bodyBuilder(rewriter, loc, forOp.getBodyRegion(), |
| 140 | it, forOp.getRegionIterArgs()); |
| 141 | |
| 142 | rewriter.setInsertionPointToEnd(forOp.getBody()); |
| 143 | rewriter.create<scf::YieldOp>(loc, ret); |
| 144 | } |
| 145 | return forOp.getResults(); |
| 146 | } |
| 147 | |
| 148 | SmallVector<Value> ivs(reduc); |
| 149 | llvm::append_range(C&: ivs, R: it->getCursor()); |
| 150 | |
| 151 | TypeRange types = ValueRange(ivs).getTypes(); |
| 152 | auto whileOp = rewriter.create<scf::WhileOp>(loc, types, ivs); |
| 153 | { |
| 154 | OpBuilder::InsertionGuard guard(rewriter); |
| 155 | // Generates loop conditions. |
| 156 | SmallVector<Location> l(types.size(), loc); |
| 157 | Block *before = rewriter.createBlock(&whileOp.getBefore(), {}, types, l); |
| 158 | rewriter.setInsertionPointToStart(before); |
| 159 | ValueRange bArgs = before->getArguments(); |
| 160 | auto [whileCond, remArgs] = it->genWhileCond(b&: rewriter, l: loc, vs: bArgs); |
| 161 | rewriter.create<scf::ConditionOp>(loc, whileCond, before->getArguments()); |
| 162 | |
| 163 | // Delegates loop body generation. |
| 164 | Region &dstRegion = whileOp.getAfter(); |
| 165 | Block *after = rewriter.createBlock(parent: &dstRegion, insertPt: {}, argTypes: types, locs: l); |
| 166 | ValueRange aArgs = whileOp.getAfterArguments(); |
| 167 | it->linkNewScope(pos: aArgs.drop_front(n: reduc.size())); |
| 168 | aArgs = aArgs.take_front(n: reduc.size()); |
| 169 | |
| 170 | rewriter.setInsertionPointToStart(after); |
| 171 | SmallVector<Value> ret = bodyBuilder(rewriter, loc, dstRegion, it, aArgs); |
| 172 | rewriter.setInsertionPointToEnd(after); |
| 173 | |
| 174 | // Forward loops |
| 175 | SmallVector<Value> yields; |
| 176 | llvm::append_range(C&: yields, R&: ret); |
| 177 | llvm::append_range(C&: yields, R: it->forward(b&: rewriter, l: loc)); |
| 178 | rewriter.create<scf::YieldOp>(loc, yields); |
| 179 | } |
| 180 | return whileOp.getResults().drop_front(it->getCursor().size()); |
| 181 | } |
| 182 | |
| 183 | namespace { |
| 184 | |
| 185 | /// Sparse codegen rule for number of entries operator. |
| 186 | class |
| 187 | : public OpConversionPattern<ExtractIterSpaceOp> { |
| 188 | public: |
| 189 | using OpConversionPattern::OpConversionPattern; |
| 190 | LogicalResult |
| 191 | matchAndRewrite(ExtractIterSpaceOp op, OneToNOpAdaptor adaptor, |
| 192 | ConversionPatternRewriter &rewriter) const override { |
| 193 | Location loc = op.getLoc(); |
| 194 | |
| 195 | // Construct the iteration space. |
| 196 | SparseIterationSpace space(loc, rewriter, |
| 197 | llvm::getSingleElement(adaptor.getTensor()), 0, |
| 198 | op.getLvlRange(), adaptor.getParentIter()); |
| 199 | |
| 200 | SmallVector<Value> result = space.toValues(); |
| 201 | rewriter.replaceOpWithMultiple(op, {result}); |
| 202 | return success(); |
| 203 | } |
| 204 | }; |
| 205 | |
| 206 | /// Sparse codegen rule for number of entries operator. |
| 207 | class : public OpConversionPattern<ExtractValOp> { |
| 208 | public: |
| 209 | using OpConversionPattern::OpConversionPattern; |
| 210 | LogicalResult |
| 211 | matchAndRewrite(ExtractValOp op, OneToNOpAdaptor adaptor, |
| 212 | ConversionPatternRewriter &rewriter) const override { |
| 213 | Location loc = op.getLoc(); |
| 214 | Value pos = adaptor.getIterator().back(); |
| 215 | Value valBuf = rewriter.create<ToValuesOp>( |
| 216 | loc, llvm::getSingleElement(adaptor.getTensor())); |
| 217 | rewriter.replaceOpWithNewOp<memref::LoadOp>(op, valBuf, pos); |
| 218 | return success(); |
| 219 | } |
| 220 | }; |
| 221 | |
| 222 | class SparseIterateOpConverter : public OpConversionPattern<IterateOp> { |
| 223 | public: |
| 224 | using OpConversionPattern::OpConversionPattern; |
| 225 | LogicalResult |
| 226 | matchAndRewrite(IterateOp op, OneToNOpAdaptor adaptor, |
| 227 | ConversionPatternRewriter &rewriter) const override { |
| 228 | if (!op.getCrdUsedLvls().empty()) |
| 229 | return rewriter.notifyMatchFailure( |
| 230 | op, "non-empty coordinates list not implemented." ); |
| 231 | |
| 232 | Location loc = op.getLoc(); |
| 233 | |
| 234 | auto iterSpace = SparseIterationSpace::fromValues( |
| 235 | op.getIterSpace().getType(), adaptor.getIterSpace(), 0); |
| 236 | |
| 237 | std::unique_ptr<SparseIterator> it = |
| 238 | iterSpace.extractIterator(rewriter, loc); |
| 239 | |
| 240 | SmallVector<Value> ivs; |
| 241 | for (ValueRange inits : adaptor.getInitArgs()) |
| 242 | llvm::append_range(ivs, inits); |
| 243 | |
| 244 | // Type conversion on iterate op block. |
| 245 | unsigned numOrigArgs = op.getBody()->getArgumentTypes().size(); |
| 246 | TypeConverter::SignatureConversion signatureConversion(numOrigArgs); |
| 247 | if (failed(typeConverter->convertSignatureArgs( |
| 248 | op.getBody()->getArgumentTypes(), signatureConversion))) |
| 249 | return rewriter.notifyMatchFailure( |
| 250 | op, "failed to convert iterate region argurment types" ); |
| 251 | |
| 252 | Block *block = rewriter.applySignatureConversion( |
| 253 | block: op.getBody(), conversion&: signatureConversion, converter: getTypeConverter()); |
| 254 | ValueRange ret = genLoopWithIterator( |
| 255 | rewriter, loc, it: it.get(), reduc: ivs, |
| 256 | bodyBuilder: [block](PatternRewriter &rewriter, Location loc, Region &loopBody, |
| 257 | SparseIterator *it, ValueRange reduc) -> SmallVector<Value> { |
| 258 | SmallVector<Value> blockArgs(reduc); |
| 259 | // TODO: Also appends coordinates if used. |
| 260 | // blockArgs.push_back(it->deref(rewriter, loc)); |
| 261 | llvm::append_range(C&: blockArgs, R: it->getCursor()); |
| 262 | |
| 263 | Block *dstBlock = &loopBody.getBlocks().front(); |
| 264 | rewriter.inlineBlockBefore(source: block, dest: dstBlock, before: dstBlock->end(), |
| 265 | argValues: blockArgs); |
| 266 | auto yield = llvm::cast<sparse_tensor::YieldOp>(dstBlock->back()); |
| 267 | // We can not use ValueRange as the operation holding the values will |
| 268 | // be destroyed. |
| 269 | SmallVector<Value> result(yield.getResults()); |
| 270 | rewriter.eraseOp(op: yield); |
| 271 | return result; |
| 272 | }); |
| 273 | |
| 274 | rewriter.replaceOp(op, ret); |
| 275 | return success(); |
| 276 | } |
| 277 | }; |
| 278 | |
| 279 | class SparseCoIterateOpConverter : public OpConversionPattern<CoIterateOp> { |
| 280 | using OpConversionPattern::OpConversionPattern; |
| 281 | |
| 282 | LogicalResult |
| 283 | matchAndRewrite(CoIterateOp op, OneToNOpAdaptor adaptor, |
| 284 | ConversionPatternRewriter &rewriter) const override { |
| 285 | assert(op.getSpaceDim() == 1 && "Not implemented" ); |
| 286 | Location loc = op.getLoc(); |
| 287 | |
| 288 | I64BitSet denseBits(0); |
| 289 | for (auto [idx, spaceTp] : llvm::enumerate(op.getIterSpaces().getTypes())) |
| 290 | if (all_of(cast<IterSpaceType>(spaceTp).getLvlTypes(), isDenseLT)) |
| 291 | denseBits.set(idx); |
| 292 | |
| 293 | // If there exists a case that only contains dense spaces. I.e., case |
| 294 | // bits is a subset of dense bits, or when there is a full empty case (due |
| 295 | // to complements), we need a universal pointer to forward the coiteration |
| 296 | // loop. |
| 297 | bool needUniv = |
| 298 | any_of(op.getRegionDefinedSpaces(), [denseBits](I64BitSet caseBits) { |
| 299 | // A case for complement. |
| 300 | if (caseBits.count() == 0) |
| 301 | return true; |
| 302 | // An all-dense case. |
| 303 | return caseBits.isSubSetOf(p: denseBits); |
| 304 | }); |
| 305 | assert(!needUniv && "Not implemented" ); |
| 306 | (void)needUniv; |
| 307 | |
| 308 | SmallVector<Block *> newBlocks; |
| 309 | DenseMap<Block *, Block *> newToOldBlockMap; |
| 310 | for (Region ®ion : op.getCaseRegions()) { |
| 311 | // Do a one-shot type conversion on all region blocks, since the same |
| 312 | // region might be used multiple time. |
| 313 | Block *block = ®ion.getBlocks().front(); |
| 314 | TypeConverter::SignatureConversion blockTypeMapping( |
| 315 | block->getArgumentTypes().size()); |
| 316 | if (failed(typeConverter->convertSignatureArgs(block->getArgumentTypes(), |
| 317 | blockTypeMapping))) { |
| 318 | return rewriter.notifyMatchFailure( |
| 319 | op, "failed to convert coiterate region argurment types" ); |
| 320 | } |
| 321 | |
| 322 | newBlocks.push_back(rewriter.applySignatureConversion( |
| 323 | block, blockTypeMapping, getTypeConverter())); |
| 324 | newToOldBlockMap[newBlocks.back()] = block; |
| 325 | } |
| 326 | |
| 327 | SmallVector<SparseIterationSpace> spaces; |
| 328 | SmallVector<std::unique_ptr<SparseIterator>> iters; |
| 329 | for (auto [spaceTp, spaceVals] : llvm::zip_equal( |
| 330 | op.getIterSpaces().getTypes(), adaptor.getIterSpaces())) { |
| 331 | // TODO: do we really need tid? |
| 332 | spaces.push_back(SparseIterationSpace::fromValues( |
| 333 | cast<IterSpaceType>(spaceTp), spaceVals, /*tid=*/0)); |
| 334 | // Extract the iterator. |
| 335 | iters.push_back(spaces.back().extractIterator(rewriter, loc)); |
| 336 | } |
| 337 | |
| 338 | auto getFilteredIters = [&iters](I64BitSet caseBits) { |
| 339 | // Retrives a vector of pointers to the iterators used in the case. |
| 340 | SmallVector<SparseIterator *> validIters; |
| 341 | for (auto idx : caseBits.bits()) |
| 342 | validIters.push_back(Elt: iters[idx].get()); |
| 343 | return validIters; |
| 344 | }; |
| 345 | |
| 346 | // Get a flattened user-provided loop reduction values. |
| 347 | SmallVector<Value> userReduc; |
| 348 | for (ValueRange r : adaptor.getInitArgs()) |
| 349 | llvm::append_range(userReduc, r); |
| 350 | |
| 351 | // TODO: we need to sort the cases such that they appears in lexical order. |
| 352 | // Although sparsification always generates cases in that order, it might |
| 353 | // not be the case for human-written code. |
| 354 | |
| 355 | // Generates a loop sequence, one loop per case. |
| 356 | for (auto [r, caseBits] : |
| 357 | llvm::zip_equal(newBlocks, op.getRegionDefinedSpaces())) { |
| 358 | assert(caseBits.count() > 0 && "Complement space not implemented" ); |
| 359 | |
| 360 | // Retrives a vector of pointers to the iterators used in the case. |
| 361 | SmallVector<SparseIterator *> validIters = getFilteredIters(caseBits); |
| 362 | |
| 363 | if (validIters.size() > 1) { |
| 364 | auto [loop, loopCrd] = |
| 365 | genCoIteration(rewriter, loc, validIters, userReduc, |
| 366 | /*uniIdx=*/nullptr, /*userReducFirst=*/true); |
| 367 | |
| 368 | // 1st. find all the cases that is a strict subset of the current case |
| 369 | // condition, for which we generate one branch per case inside the loop. |
| 370 | // The subcases are never empty, it must contains at least the current |
| 371 | // region itself. |
| 372 | // TODO: these cases should be sorted. |
| 373 | SmallVector<Region *> subCases = |
| 374 | op.getSubCasesOf(r->getParent()->getRegionNumber()); |
| 375 | SmallVector<Block *> newBlocks, oldBlocks; |
| 376 | for (Region *r : subCases) { |
| 377 | newBlocks.push_back(&r->front()); |
| 378 | oldBlocks.push_back(newToOldBlockMap[newBlocks.back()]); |
| 379 | } |
| 380 | assert(!subCases.empty()); |
| 381 | |
| 382 | ValueRange res = genCoIterateBranchNest( |
| 383 | rewriter, loc, op, loopCrd, iters, newBlocks, oldBlocks, userReduc); |
| 384 | |
| 385 | SmallVector<Value> nextIterYields(res); |
| 386 | // 2nd. foward the loop. |
| 387 | for (SparseIterator *it : validIters) { |
| 388 | Value cmp = rewriter.create<arith::CmpIOp>( |
| 389 | loc, arith::CmpIPredicate::eq, it->getCrd(), loopCrd); |
| 390 | it->forwardIf(rewriter, loc, cmp); |
| 391 | llvm::append_range(nextIterYields, it->getCursor()); |
| 392 | } |
| 393 | rewriter.create<scf::YieldOp>(loc, nextIterYields); |
| 394 | |
| 395 | // Exit the loop, relink the iterator SSA value. |
| 396 | rewriter.setInsertionPointAfter(loop); |
| 397 | ValueRange iterVals = loop->getResults().drop_front(userReduc.size()); |
| 398 | for (SparseIterator *it : validIters) |
| 399 | iterVals = it->linkNewScope(iterVals); |
| 400 | assert(iterVals.empty()); |
| 401 | |
| 402 | ValueRange curResult = loop->getResults().take_front(userReduc.size()); |
| 403 | userReduc.assign(curResult.begin(), curResult.end()); |
| 404 | } else { |
| 405 | // This is a simple iteration loop. |
| 406 | assert(caseBits.count() == 1); |
| 407 | |
| 408 | Block *block = r; |
| 409 | ValueRange curResult = genLoopWithIterator( |
| 410 | rewriter, loc, validIters.front(), userReduc, |
| 411 | /*bodyBuilder=*/ |
| 412 | [block](PatternRewriter &rewriter, Location loc, Region &dstRegion, |
| 413 | SparseIterator *it, |
| 414 | ValueRange reduc) -> SmallVector<Value> { |
| 415 | SmallVector<Value> blockArgs(reduc); |
| 416 | blockArgs.push_back(it->deref(rewriter, loc)); |
| 417 | llvm::append_range(blockArgs, it->getCursor()); |
| 418 | |
| 419 | Block *dstBlock = &dstRegion.getBlocks().front(); |
| 420 | rewriter.inlineBlockBefore( |
| 421 | block, dstBlock, rewriter.getInsertionPoint(), blockArgs); |
| 422 | auto yield = llvm::cast<sparse_tensor::YieldOp>(dstBlock->back()); |
| 423 | SmallVector<Value> result(yield.getResults()); |
| 424 | rewriter.eraseOp(yield); |
| 425 | return result; |
| 426 | }); |
| 427 | |
| 428 | userReduc.assign(curResult.begin(), curResult.end()); |
| 429 | } |
| 430 | } |
| 431 | |
| 432 | rewriter.replaceOp(op, userReduc); |
| 433 | return success(); |
| 434 | } |
| 435 | }; |
| 436 | |
| 437 | } // namespace |
| 438 | |
| 439 | mlir::SparseIterationTypeConverter::SparseIterationTypeConverter() { |
| 440 | addConversion(callback: [](Type type) { return type; }); |
| 441 | addConversion(convertIteratorType); |
| 442 | addConversion(convertIterSpaceType); |
| 443 | |
| 444 | addSourceMaterialization(callback: [](OpBuilder &builder, IterSpaceType spTp, |
| 445 | ValueRange inputs, Location loc) -> Value { |
| 446 | return builder |
| 447 | .create<UnrealizedConversionCastOp>(loc, TypeRange(spTp), inputs) |
| 448 | .getResult(0); |
| 449 | }); |
| 450 | } |
| 451 | |
| 452 | void mlir::populateLowerSparseIterationToSCFPatterns( |
| 453 | const TypeConverter &converter, RewritePatternSet &patterns) { |
| 454 | |
| 455 | IterateOp::getCanonicalizationPatterns(patterns, patterns.getContext()); |
| 456 | patterns.add<ExtractIterSpaceConverter, ExtractValOpConverter, |
| 457 | SparseIterateOpConverter, SparseCoIterateOpConverter>( |
| 458 | arg: converter, args: patterns.getContext()); |
| 459 | } |
| 460 | |