| 1 | //===- SCF.cpp - Structured Control Flow Operations -----------------------===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | |
| 9 | #include "mlir/Dialect/SCF/IR/SCF.h" |
| 10 | #include "mlir/Conversion/ConvertToEmitC/ToEmitCInterface.h" |
| 11 | #include "mlir/Dialect/Arith/IR/Arith.h" |
| 12 | #include "mlir/Dialect/Arith/Utils/Utils.h" |
| 13 | #include "mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.h" |
| 14 | #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" |
| 15 | #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" |
| 16 | #include "mlir/Dialect/MemRef/IR/MemRef.h" |
| 17 | #include "mlir/Dialect/SCF/IR/DeviceMappingInterface.h" |
| 18 | #include "mlir/Dialect/Tensor/IR/Tensor.h" |
| 19 | #include "mlir/IR/BuiltinAttributes.h" |
| 20 | #include "mlir/IR/IRMapping.h" |
| 21 | #include "mlir/IR/Matchers.h" |
| 22 | #include "mlir/IR/PatternMatch.h" |
| 23 | #include "mlir/Interfaces/FunctionInterfaces.h" |
| 24 | #include "mlir/Interfaces/ValueBoundsOpInterface.h" |
| 25 | #include "mlir/Transforms/InliningUtils.h" |
| 26 | #include "llvm/ADT/MapVector.h" |
| 27 | #include "llvm/ADT/SmallPtrSet.h" |
| 28 | #include "llvm/ADT/TypeSwitch.h" |
| 29 | |
| 30 | using namespace mlir; |
| 31 | using namespace mlir::scf; |
| 32 | |
| 33 | #include "mlir/Dialect/SCF/IR/SCFOpsDialect.cpp.inc" |
| 34 | |
| 35 | //===----------------------------------------------------------------------===// |
| 36 | // SCFDialect Dialect Interfaces |
| 37 | //===----------------------------------------------------------------------===// |
| 38 | |
| 39 | namespace { |
| 40 | struct SCFInlinerInterface : public DialectInlinerInterface { |
| 41 | using DialectInlinerInterface::DialectInlinerInterface; |
| 42 | // We don't have any special restrictions on what can be inlined into |
| 43 | // destination regions (e.g. while/conditional bodies). Always allow it. |
| 44 | bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned, |
| 45 | IRMapping &valueMapping) const final { |
| 46 | return true; |
| 47 | } |
| 48 | // Operations in scf dialect are always legal to inline since they are |
| 49 | // pure. |
| 50 | bool isLegalToInline(Operation *, Region *, bool, IRMapping &) const final { |
| 51 | return true; |
| 52 | } |
| 53 | // Handle the given inlined terminator by replacing it with a new operation |
| 54 | // as necessary. Required when the region has only one block. |
| 55 | void handleTerminator(Operation *op, ValueRange valuesToRepl) const final { |
| 56 | auto retValOp = dyn_cast<scf::YieldOp>(op); |
| 57 | if (!retValOp) |
| 58 | return; |
| 59 | |
| 60 | for (auto retValue : llvm::zip(valuesToRepl, retValOp.getOperands())) { |
| 61 | std::get<0>(retValue).replaceAllUsesWith(std::get<1>(retValue)); |
| 62 | } |
| 63 | } |
| 64 | }; |
| 65 | } // namespace |
| 66 | |
| 67 | //===----------------------------------------------------------------------===// |
| 68 | // SCFDialect |
| 69 | //===----------------------------------------------------------------------===// |
| 70 | |
| 71 | void SCFDialect::initialize() { |
| 72 | addOperations< |
| 73 | #define GET_OP_LIST |
| 74 | #include "mlir/Dialect/SCF/IR/SCFOps.cpp.inc" |
| 75 | >(); |
| 76 | addInterfaces<SCFInlinerInterface>(); |
| 77 | declarePromisedInterface<ConvertToEmitCPatternInterface, SCFDialect>(); |
| 78 | declarePromisedInterfaces<bufferization::BufferDeallocationOpInterface, |
| 79 | InParallelOp, ReduceReturnOp>(); |
| 80 | declarePromisedInterfaces<bufferization::BufferizableOpInterface, ConditionOp, |
| 81 | ExecuteRegionOp, ForOp, IfOp, IndexSwitchOp, |
| 82 | ForallOp, InParallelOp, WhileOp, YieldOp>(); |
| 83 | declarePromisedInterface<ValueBoundsOpInterface, ForOp>(); |
| 84 | } |
| 85 | |
| 86 | /// Default callback for IfOp builders. Inserts a yield without arguments. |
| 87 | void mlir::scf::buildTerminatedBody(OpBuilder &builder, Location loc) { |
| 88 | builder.create<scf::YieldOp>(loc); |
| 89 | } |
| 90 | |
| 91 | /// Verifies that the first block of the given `region` is terminated by a |
| 92 | /// TerminatorTy. Reports errors on the given operation if it is not the case. |
| 93 | template <typename TerminatorTy> |
| 94 | static TerminatorTy verifyAndGetTerminator(Operation *op, Region ®ion, |
| 95 | StringRef errorMessage) { |
| 96 | Operation *terminatorOperation = nullptr; |
| 97 | if (!region.empty() && !region.front().empty()) { |
| 98 | terminatorOperation = ®ion.front().back(); |
| 99 | if (auto yield = dyn_cast_or_null<TerminatorTy>(terminatorOperation)) |
| 100 | return yield; |
| 101 | } |
  auto diag = op->emitOpError(errorMessage);
  if (terminatorOperation)
    diag.attachNote(terminatorOperation->getLoc()) << "terminator here";
| 105 | return nullptr; |
| 106 | } |
| 107 | |
| 108 | //===----------------------------------------------------------------------===// |
| 109 | // ExecuteRegionOp |
| 110 | //===----------------------------------------------------------------------===// |
| 111 | |
| 112 | /// Replaces the given op with the contents of the given single-block region, |
| 113 | /// using the operands of the block terminator to replace operation results. |
| 114 | static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op, |
| 115 | Region ®ion, ValueRange blockArgs = {}) { |
  assert(llvm::hasSingleElement(region) && "expected single-block region");
  Block *block = &region.front();
  Operation *terminator = block->getTerminator();
  ValueRange results = terminator->getOperands();
  rewriter.inlineBlockBefore(block, op, blockArgs);
  rewriter.replaceOp(op, results);
  rewriter.eraseOp(terminator);
| 123 | } |
| 124 | |
| 125 | /// |
| 126 | /// (ssa-id `=`)? `execute_region` `->` function-result-type `{` |
| 127 | /// block+ |
| 128 | /// `}` |
| 129 | /// |
| 130 | /// Example: |
| 131 | /// scf.execute_region -> i32 { |
///     %idx = memref.load %rI[%i] : memref<128xi32>
///     scf.yield %idx : i32
| 134 | /// } |
| 135 | /// |
| 136 | ParseResult ExecuteRegionOp::parse(OpAsmParser &parser, |
| 137 | OperationState &result) { |
| 138 | if (parser.parseOptionalArrowTypeList(result.types)) |
| 139 | return failure(); |
| 140 | |
| 141 | // Introduce the body region and parse it. |
| 142 | Region *body = result.addRegion(); |
| 143 | if (parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}) || |
| 144 | parser.parseOptionalAttrDict(result.attributes)) |
| 145 | return failure(); |
| 146 | |
| 147 | return success(); |
| 148 | } |
| 149 | |
| 150 | void ExecuteRegionOp::print(OpAsmPrinter &p) { |
| 151 | p.printOptionalArrowTypeList(getResultTypes()); |
| 152 | |
| 153 | p << ' '; |
| 154 | p.printRegion(getRegion(), |
| 155 | /*printEntryBlockArgs=*/false, |
| 156 | /*printBlockTerminators=*/true); |
| 157 | |
| 158 | p.printOptionalAttrDict((*this)->getAttrs()); |
| 159 | } |
| 160 | |
| 161 | LogicalResult ExecuteRegionOp::verify() { |
| 162 | if (getRegion().empty()) |
    return emitOpError("region needs to have at least one block");
  if (getRegion().front().getNumArguments() > 0)
    return emitOpError("region cannot have any arguments");
| 166 | return success(); |
| 167 | } |
| 168 | |
| 169 | // Inline an ExecuteRegionOp if it only contains one block. |
| 170 | // "test.foo"() : () -> () |
| 171 | // %v = scf.execute_region -> i64 { |
| 172 | // %x = "test.val"() : () -> i64 |
| 173 | // scf.yield %x : i64 |
| 174 | // } |
| 175 | // "test.bar"(%v) : (i64) -> () |
| 176 | // |
| 177 | // becomes |
| 178 | // |
| 179 | // "test.foo"() : () -> () |
| 180 | // %x = "test.val"() : () -> i64 |
| 181 | // "test.bar"(%x) : (i64) -> () |
| 182 | // |
| 183 | struct SingleBlockExecuteInliner : public OpRewritePattern<ExecuteRegionOp> { |
| 184 | using OpRewritePattern<ExecuteRegionOp>::OpRewritePattern; |
| 185 | |
| 186 | LogicalResult matchAndRewrite(ExecuteRegionOp op, |
| 187 | PatternRewriter &rewriter) const override { |
| 188 | if (!llvm::hasSingleElement(op.getRegion())) |
| 189 | return failure(); |
| 190 | replaceOpWithRegion(rewriter, op, op.getRegion()); |
| 191 | return success(); |
| 192 | } |
| 193 | }; |
| 194 | |
| 195 | // Inline an ExecuteRegionOp if its parent can contain multiple blocks. |
| 196 | // TODO generalize the conditions for operations which can be inlined into. |
| 197 | // func @func_execute_region_elim() { |
| 198 | // "test.foo"() : () -> () |
| 199 | // %v = scf.execute_region -> i64 { |
| 200 | // %c = "test.cmp"() : () -> i1 |
| 201 | // cf.cond_br %c, ^bb2, ^bb3 |
| 202 | // ^bb2: |
| 203 | // %x = "test.val1"() : () -> i64 |
| 204 | // cf.br ^bb4(%x : i64) |
| 205 | // ^bb3: |
| 206 | // %y = "test.val2"() : () -> i64 |
| 207 | // cf.br ^bb4(%y : i64) |
| 208 | // ^bb4(%z : i64): |
| 209 | // scf.yield %z : i64 |
| 210 | // } |
| 211 | // "test.bar"(%v) : (i64) -> () |
| 212 | // return |
| 213 | // } |
| 214 | // |
| 215 | // becomes |
| 216 | // |
| 217 | // func @func_execute_region_elim() { |
| 218 | // "test.foo"() : () -> () |
| 219 | // %c = "test.cmp"() : () -> i1 |
| 220 | // cf.cond_br %c, ^bb1, ^bb2 |
| 221 | // ^bb1: // pred: ^bb0 |
| 222 | // %x = "test.val1"() : () -> i64 |
| 223 | // cf.br ^bb3(%x : i64) |
| 224 | // ^bb2: // pred: ^bb0 |
| 225 | // %y = "test.val2"() : () -> i64 |
| 226 | // cf.br ^bb3(%y : i64) |
| 227 | // ^bb3(%z: i64): // 2 preds: ^bb1, ^bb2 |
| 228 | // "test.bar"(%z) : (i64) -> () |
| 229 | // return |
| 230 | // } |
| 231 | // |
| 232 | struct MultiBlockExecuteInliner : public OpRewritePattern<ExecuteRegionOp> { |
| 233 | using OpRewritePattern<ExecuteRegionOp>::OpRewritePattern; |
| 234 | |
| 235 | LogicalResult matchAndRewrite(ExecuteRegionOp op, |
| 236 | PatternRewriter &rewriter) const override { |
| 237 | if (!isa<FunctionOpInterface, ExecuteRegionOp>(op->getParentOp())) |
| 238 | return failure(); |
| 239 | |
| 240 | Block *prevBlock = op->getBlock(); |
    Block *postBlock = rewriter.splitBlock(prevBlock, op->getIterator());
| 242 | rewriter.setInsertionPointToEnd(prevBlock); |
| 243 | |
| 244 | rewriter.create<cf::BranchOp>(op.getLoc(), &op.getRegion().front()); |
| 245 | |
| 246 | for (Block &blk : op.getRegion()) { |
| 247 | if (YieldOp yieldOp = dyn_cast<YieldOp>(blk.getTerminator())) { |
| 248 | rewriter.setInsertionPoint(yieldOp); |
| 249 | rewriter.create<cf::BranchOp>(yieldOp.getLoc(), postBlock, |
| 250 | yieldOp.getResults()); |
| 251 | rewriter.eraseOp(yieldOp); |
| 252 | } |
| 253 | } |
| 254 | |
| 255 | rewriter.inlineRegionBefore(op.getRegion(), postBlock); |
| 256 | SmallVector<Value> blockArgs; |
| 257 | |
| 258 | for (auto res : op.getResults()) |
| 259 | blockArgs.push_back(postBlock->addArgument(res.getType(), res.getLoc())); |
| 260 | |
| 261 | rewriter.replaceOp(op, blockArgs); |
| 262 | return success(); |
| 263 | } |
| 264 | }; |
| 265 | |
| 266 | void ExecuteRegionOp::getCanonicalizationPatterns(RewritePatternSet &results, |
| 267 | MLIRContext *context) { |
| 268 | results.add<SingleBlockExecuteInliner, MultiBlockExecuteInliner>(context); |
| 269 | } |
| 270 | |
| 271 | void ExecuteRegionOp::getSuccessorRegions( |
| 272 | RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> ®ions) { |
| 273 | // If the predecessor is the ExecuteRegionOp, branch into the body. |
| 274 | if (point.isParent()) { |
| 275 | regions.push_back(RegionSuccessor(&getRegion())); |
| 276 | return; |
| 277 | } |
| 278 | |
| 279 | // Otherwise, the region branches back to the parent operation. |
| 280 | regions.push_back(RegionSuccessor(getResults())); |
| 281 | } |
| 282 | |
| 283 | //===----------------------------------------------------------------------===// |
| 284 | // ConditionOp |
| 285 | //===----------------------------------------------------------------------===// |
| 286 | |
| 287 | MutableOperandRange |
| 288 | ConditionOp::getMutableSuccessorOperands(RegionBranchPoint point) { |
| 289 | assert((point.isParent() || point == getParentOp().getAfter()) && |
         "condition op can only exit the loop or branch to the after"
         " region");
| 292 | // Pass all operands except the condition to the successor region. |
| 293 | return getArgsMutable(); |
| 294 | } |
| 295 | |
| 296 | void ConditionOp::getSuccessorRegions( |
| 297 | ArrayRef<Attribute> operands, SmallVectorImpl<RegionSuccessor> ®ions) { |
| 298 | FoldAdaptor adaptor(operands, *this); |
| 299 | |
| 300 | WhileOp whileOp = getParentOp(); |
| 301 | |
| 302 | // Condition can either lead to the after region or back to the parent op |
| 303 | // depending on whether the condition is true or not. |
| 304 | auto boolAttr = dyn_cast_or_null<BoolAttr>(adaptor.getCondition()); |
| 305 | if (!boolAttr || boolAttr.getValue()) |
| 306 | regions.emplace_back(&whileOp.getAfter(), |
| 307 | whileOp.getAfter().getArguments()); |
| 308 | if (!boolAttr || !boolAttr.getValue()) |
| 309 | regions.emplace_back(whileOp.getResults()); |
| 310 | } |
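
// For illustration (the `test.*` ops below are hypothetical placeholders), in
//
//   %res = scf.while (%arg = %init) : (i32) -> i32 {
//     %cond = "test.check"(%arg) : (i32) -> i1
//     scf.condition(%cond) %arg : i32
//   } do {
//   ^bb0(%arg2: i32):
//     %next = "test.step"(%arg2) : (i32) -> i32
//     scf.yield %next : i32
//   }
//
// a true %cond forwards %arg into the "after" region, while a false %cond
// makes %arg the result of the enclosing scf.while.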
| 311 | |
| 312 | //===----------------------------------------------------------------------===// |
| 313 | // ForOp |
| 314 | //===----------------------------------------------------------------------===// |
| 315 | |
| 316 | void ForOp::build(OpBuilder &builder, OperationState &result, Value lb, |
| 317 | Value ub, Value step, ValueRange initArgs, |
| 318 | BodyBuilderFn bodyBuilder) { |
| 319 | OpBuilder::InsertionGuard guard(builder); |
| 320 | |
| 321 | result.addOperands({lb, ub, step}); |
| 322 | result.addOperands(initArgs); |
| 323 | for (Value v : initArgs) |
| 324 | result.addTypes(v.getType()); |
| 325 | Type t = lb.getType(); |
| 326 | Region *bodyRegion = result.addRegion(); |
| 327 | Block *bodyBlock = builder.createBlock(bodyRegion); |
| 328 | bodyBlock->addArgument(t, result.location); |
| 329 | for (Value v : initArgs) |
| 330 | bodyBlock->addArgument(v.getType(), v.getLoc()); |
| 331 | |
| 332 | // Create the default terminator if the builder is not provided and if the |
| 333 | // iteration arguments are not provided. Otherwise, leave this to the caller |
| 334 | // because we don't know which values to return from the loop. |
| 335 | if (initArgs.empty() && !bodyBuilder) { |
| 336 | ForOp::ensureTerminator(*bodyRegion, builder, result.location); |
| 337 | } else if (bodyBuilder) { |
| 338 | OpBuilder::InsertionGuard guard(builder); |
| 339 | builder.setInsertionPointToStart(bodyBlock); |
| 340 | bodyBuilder(builder, result.location, bodyBlock->getArgument(0), |
| 341 | bodyBlock->getArguments().drop_front()); |
| 342 | } |
| 343 | } |
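
// A minimal usage sketch (builder, loc and the `lb`/`ub`/`step`/`init` values
// are assumed to exist at the call site); when iteration arguments are
// present, the callback is responsible for creating the terminator:
//
//   auto loop = builder.create<scf::ForOp>(
//       loc, lb, ub, step, ValueRange{init},
//       [](OpBuilder &b, Location l, Value iv, ValueRange iterArgs) {
//         // ... build the body using `iv` and `iterArgs` ...
//         b.create<scf::YieldOp>(l, iterArgs);
//       });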
| 344 | |
| 345 | LogicalResult ForOp::verify() { |
| 346 | // Check that the number of init args and op results is the same. |
| 347 | if (getInitArgs().size() != getNumResults()) |
    return emitOpError(
        "mismatch in number of loop-carried values and defined values");
| 350 | |
| 351 | return success(); |
| 352 | } |
| 353 | |
| 354 | LogicalResult ForOp::verifyRegions() { |
  // Check that the body defines a single block argument for the induction
  // variable.
  if (getInductionVar().getType() != getLowerBound().getType())
    return emitOpError(
        "expected induction variable to be same type as bounds and step");
| 360 | |
| 361 | if (getNumRegionIterArgs() != getNumResults()) |
    return emitOpError(
        "mismatch in number of basic block args and defined values");
| 364 | |
| 365 | auto initArgs = getInitArgs(); |
| 366 | auto iterArgs = getRegionIterArgs(); |
| 367 | auto opResults = getResults(); |
| 368 | unsigned i = 0; |
| 369 | for (auto e : llvm::zip(initArgs, iterArgs, opResults)) { |
| 370 | if (std::get<0>(e).getType() != std::get<2>(e).getType()) |
      return emitOpError() << "types mismatch between " << i
                           << "th iter operand and defined value";
    if (std::get<1>(e).getType() != std::get<2>(e).getType())
      return emitOpError() << "types mismatch between " << i
                           << "th iter region arg and defined value";
| 376 | |
| 377 | ++i; |
| 378 | } |
| 379 | return success(); |
| 380 | } |
| 381 | |
| 382 | std::optional<SmallVector<Value>> ForOp::getLoopInductionVars() { |
| 383 | return SmallVector<Value>{getInductionVar()}; |
| 384 | } |
| 385 | |
| 386 | std::optional<SmallVector<OpFoldResult>> ForOp::getLoopLowerBounds() { |
| 387 | return SmallVector<OpFoldResult>{OpFoldResult(getLowerBound())}; |
| 388 | } |
| 389 | |
| 390 | std::optional<SmallVector<OpFoldResult>> ForOp::getLoopSteps() { |
| 391 | return SmallVector<OpFoldResult>{OpFoldResult(getStep())}; |
| 392 | } |
| 393 | |
| 394 | std::optional<SmallVector<OpFoldResult>> ForOp::getLoopUpperBounds() { |
| 395 | return SmallVector<OpFoldResult>{OpFoldResult(getUpperBound())}; |
| 396 | } |
| 397 | |
| 398 | std::optional<ResultRange> ForOp::getLoopResults() { return getResults(); } |
| 399 | |
/// Promotes the loop body of a forOp to its containing block if it can be
/// determined that the loop has a single iteration.
| 402 | LogicalResult ForOp::promoteIfSingleIteration(RewriterBase &rewriter) { |
| 403 | std::optional<int64_t> tripCount = |
| 404 | constantTripCount(getLowerBound(), getUpperBound(), getStep()); |
| 405 | if (!tripCount.has_value() || tripCount != 1) |
| 406 | return failure(); |
| 407 | |
| 408 | // Replace all results with the yielded values. |
| 409 | auto yieldOp = cast<scf::YieldOp>(getBody()->getTerminator()); |
| 410 | rewriter.replaceAllUsesWith(getResults(), getYieldedValues()); |
| 411 | |
| 412 | // Replace block arguments with lower bound (replacement for IV) and |
| 413 | // iter_args. |
| 414 | SmallVector<Value> bbArgReplacements; |
| 415 | bbArgReplacements.push_back(getLowerBound()); |
| 416 | llvm::append_range(bbArgReplacements, getInitArgs()); |
| 417 | |
| 418 | // Move the loop body operations to the loop's containing block. |
| 419 | rewriter.inlineBlockBefore(getBody(), getOperation()->getBlock(), |
| 420 | getOperation()->getIterator(), bbArgReplacements); |
| 421 | |
| 422 | // Erase the old terminator and the loop. |
| 423 | rewriter.eraseOp(yieldOp); |
| 424 | rewriter.eraseOp(*this); |
| 425 | |
| 426 | return success(); |
| 427 | } |
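
// For illustration (constant bounds, hypothetical `test.compute` op), a
// single-iteration loop such as
//
//   %r = scf.for %i = %c0 to %c1 step %c1 iter_args(%a = %init) -> (f32) {
//     %v = "test.compute"(%i, %a) : (index, f32) -> f32
//     scf.yield %v : f32
//   }
//
// is promoted to
//
//   %v = "test.compute"(%c0, %init) : (index, f32) -> f32
//
// with all uses of %r replaced by %v.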
| 428 | |
| 429 | /// Prints the initialization list in the form of |
| 430 | /// <prefix>(%inner = %outer, %inner2 = %outer2, <...>) |
| 431 | /// where 'inner' values are assumed to be region arguments and 'outer' values |
| 432 | /// are regular SSA values. |
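///
/// For instance (names are illustrative), a loop with one iteration argument
/// is printed with the suffix:
///   iter_args(%arg1 = %init)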
| 433 | static void printInitializationList(OpAsmPrinter &p, |
| 434 | Block::BlockArgListType blocksArgs, |
| 435 | ValueRange initializers, |
                                    StringRef prefix = "") {
  assert(blocksArgs.size() == initializers.size() &&
         "expected same length of arguments and initializers");
  if (initializers.empty())
    return;

  p << prefix << '(';
  llvm::interleaveComma(llvm::zip(blocksArgs, initializers), p, [&](auto it) {
    p << std::get<0>(it) << " = " << std::get<1>(it);
  });
  p << ")";
| 447 | } |
| 448 | |
| 449 | void ForOp::print(OpAsmPrinter &p) { |
| 450 | p << " " << getInductionVar() << " = " << getLowerBound() << " to " |
| 451 | << getUpperBound() << " step " << getStep(); |
| 452 | |
  printInitializationList(p, getRegionIterArgs(), getInitArgs(), " iter_args");
| 454 | if (!getInitArgs().empty()) |
| 455 | p << " -> (" << getInitArgs().getTypes() << ')'; |
| 456 | p << ' '; |
| 457 | if (Type t = getInductionVar().getType(); !t.isIndex()) |
| 458 | p << " : " << t << ' '; |
| 459 | p.printRegion(getRegion(), |
| 460 | /*printEntryBlockArgs=*/false, |
| 461 | /*printBlockTerminators=*/!getInitArgs().empty()); |
| 462 | p.printOptionalAttrDict((*this)->getAttrs()); |
| 463 | } |
| 464 | |
| 465 | ParseResult ForOp::parse(OpAsmParser &parser, OperationState &result) { |
| 466 | auto &builder = parser.getBuilder(); |
| 467 | Type type; |
| 468 | |
| 469 | OpAsmParser::Argument inductionVariable; |
| 470 | OpAsmParser::UnresolvedOperand lb, ub, step; |
| 471 | |
| 472 | // Parse the induction variable followed by '='. |
| 473 | if (parser.parseOperand(inductionVariable.ssaName) || parser.parseEqual() || |
| 474 | // Parse loop bounds. |
      parser.parseOperand(lb) || parser.parseKeyword("to") ||
      parser.parseOperand(ub) || parser.parseKeyword("step") ||
| 477 | parser.parseOperand(step)) |
| 478 | return failure(); |
| 479 | |
| 480 | // Parse the optional initial iteration arguments. |
| 481 | SmallVector<OpAsmParser::Argument, 4> regionArgs; |
| 482 | SmallVector<OpAsmParser::UnresolvedOperand, 4> operands; |
| 483 | regionArgs.push_back(inductionVariable); |
| 484 | |
  bool hasIterArgs = succeeded(parser.parseOptionalKeyword("iter_args"));
| 486 | if (hasIterArgs) { |
| 487 | // Parse assignment list and results type list. |
| 488 | if (parser.parseAssignmentList(regionArgs, operands) || |
| 489 | parser.parseArrowTypeList(result.types)) |
| 490 | return failure(); |
| 491 | } |
| 492 | |
| 493 | if (regionArgs.size() != result.types.size() + 1) |
| 494 | return parser.emitError( |
| 495 | parser.getNameLoc(), |
        "mismatch in number of loop-carried values and defined values");
| 497 | |
| 498 | // Parse optional type, else assume Index. |
| 499 | if (parser.parseOptionalColon()) |
| 500 | type = builder.getIndexType(); |
| 501 | else if (parser.parseType(type)) |
| 502 | return failure(); |
| 503 | |
| 504 | // Set block argument types, so that they are known when parsing the region. |
| 505 | regionArgs.front().type = type; |
| 506 | for (auto [iterArg, type] : |
| 507 | llvm::zip_equal(llvm::drop_begin(regionArgs), result.types)) |
| 508 | iterArg.type = type; |
| 509 | |
| 510 | // Parse the body region. |
| 511 | Region *body = result.addRegion(); |
| 512 | if (parser.parseRegion(*body, regionArgs)) |
| 513 | return failure(); |
| 514 | ForOp::ensureTerminator(*body, builder, result.location); |
| 515 | |
| 516 | // Resolve input operands. This should be done after parsing the region to |
| 517 | // catch invalid IR where operands were defined inside of the region. |
| 518 | if (parser.resolveOperand(lb, type, result.operands) || |
| 519 | parser.resolveOperand(ub, type, result.operands) || |
| 520 | parser.resolveOperand(step, type, result.operands)) |
| 521 | return failure(); |
| 522 | if (hasIterArgs) { |
| 523 | for (auto argOperandType : llvm::zip_equal(llvm::drop_begin(regionArgs), |
| 524 | operands, result.types)) { |
| 525 | Type type = std::get<2>(argOperandType); |
| 526 | std::get<0>(argOperandType).type = type; |
| 527 | if (parser.resolveOperand(std::get<1>(argOperandType), type, |
| 528 | result.operands)) |
| 529 | return failure(); |
| 530 | } |
| 531 | } |
| 532 | |
| 533 | // Parse the optional attribute list. |
| 534 | if (parser.parseOptionalAttrDict(result.attributes)) |
| 535 | return failure(); |
| 536 | |
| 537 | return success(); |
| 538 | } |
| 539 | |
| 540 | SmallVector<Region *> ForOp::getLoopRegions() { return {&getRegion()}; } |
| 541 | |
| 542 | Block::BlockArgListType ForOp::getRegionIterArgs() { |
| 543 | return getBody()->getArguments().drop_front(getNumInductionVars()); |
| 544 | } |
| 545 | |
| 546 | MutableArrayRef<OpOperand> ForOp::getInitsMutable() { |
| 547 | return getInitArgsMutable(); |
| 548 | } |
| 549 | |
| 550 | FailureOr<LoopLikeOpInterface> |
| 551 | ForOp::replaceWithAdditionalYields(RewriterBase &rewriter, |
| 552 | ValueRange newInitOperands, |
| 553 | bool replaceInitOperandUsesInLoop, |
| 554 | const NewYieldValuesFn &newYieldValuesFn) { |
| 555 | // Create a new loop before the existing one, with the extra operands. |
| 556 | OpBuilder::InsertionGuard g(rewriter); |
| 557 | rewriter.setInsertionPoint(getOperation()); |
| 558 | auto inits = llvm::to_vector(getInitArgs()); |
| 559 | inits.append(newInitOperands.begin(), newInitOperands.end()); |
| 560 | scf::ForOp newLoop = rewriter.create<scf::ForOp>( |
| 561 | getLoc(), getLowerBound(), getUpperBound(), getStep(), inits, |
| 562 | [](OpBuilder &, Location, Value, ValueRange) {}); |
| 563 | newLoop->setAttrs(getPrunedAttributeList(getOperation(), {})); |
| 564 | |
| 565 | // Generate the new yield values and append them to the scf.yield operation. |
| 566 | auto yieldOp = cast<scf::YieldOp>(getBody()->getTerminator()); |
| 567 | ArrayRef<BlockArgument> newIterArgs = |
| 568 | newLoop.getBody()->getArguments().take_back(newInitOperands.size()); |
| 569 | { |
| 570 | OpBuilder::InsertionGuard g(rewriter); |
| 571 | rewriter.setInsertionPoint(yieldOp); |
| 572 | SmallVector<Value> newYieldedValues = |
| 573 | newYieldValuesFn(rewriter, getLoc(), newIterArgs); |
| 574 | assert(newInitOperands.size() == newYieldedValues.size() && |
           "expected as many new yield values as new iter operands");
| 576 | rewriter.modifyOpInPlace(yieldOp, [&]() { |
| 577 | yieldOp.getResultsMutable().append(newYieldedValues); |
| 578 | }); |
| 579 | } |
| 580 | |
| 581 | // Move the loop body to the new op. |
| 582 | rewriter.mergeBlocks(getBody(), newLoop.getBody(), |
| 583 | newLoop.getBody()->getArguments().take_front( |
| 584 | getBody()->getNumArguments())); |
| 585 | |
| 586 | if (replaceInitOperandUsesInLoop) { |
| 587 | // Replace all uses of `newInitOperands` with the corresponding basic block |
| 588 | // arguments. |
| 589 | for (auto it : llvm::zip(newInitOperands, newIterArgs)) { |
| 590 | rewriter.replaceUsesWithIf(std::get<0>(it), std::get<1>(it), |
| 591 | [&](OpOperand &use) { |
| 592 | Operation *user = use.getOwner(); |
| 593 | return newLoop->isProperAncestor(user); |
| 594 | }); |
| 595 | } |
| 596 | } |
| 597 | |
| 598 | // Replace the old loop. |
| 599 | rewriter.replaceOp(getOperation(), |
| 600 | newLoop->getResults().take_front(getNumResults())); |
| 601 | return cast<LoopLikeOpInterface>(newLoop.getOperation()); |
| 602 | } |
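
// Sketch of the effect (names are illustrative): given
//
//   %r = scf.for %i = %lb to %ub step %s iter_args(%a = %x) -> (f32) {...}
//
// adding one new init operand %y produces a loop of the form
//
//   %r:2 = scf.for %i = %lb to %ub step %s
//            iter_args(%a = %x, %b = %y) -> (f32, f32) {...}
//
// where the extra yielded value comes from `newYieldValuesFn` and only the
// leading results replace the old loop's results.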
| 603 | |
| 604 | ForOp mlir::scf::getForInductionVarOwner(Value val) { |
  auto ivArg = llvm::dyn_cast<BlockArgument>(val);
  if (!ivArg)
    return ForOp();
  assert(ivArg.getOwner() && "unlinked block argument");
| 609 | auto *containingOp = ivArg.getOwner()->getParentOp(); |
| 610 | return dyn_cast_or_null<ForOp>(containingOp); |
| 611 | } |
| 612 | |
| 613 | OperandRange ForOp::getEntrySuccessorOperands(RegionBranchPoint point) { |
| 614 | return getInitArgs(); |
| 615 | } |
| 616 | |
| 617 | void ForOp::getSuccessorRegions(RegionBranchPoint point, |
| 618 | SmallVectorImpl<RegionSuccessor> ®ions) { |
  // Both the operation itself and the region may be branching into the body
  // or back into the operation itself. It is possible for the loop not to
  // enter the body.
| 622 | regions.push_back(RegionSuccessor(&getRegion(), getRegionIterArgs())); |
| 623 | regions.push_back(RegionSuccessor(getResults())); |
| 624 | } |
| 625 | |
| 626 | SmallVector<Region *> ForallOp::getLoopRegions() { return {&getRegion()}; } |
| 627 | |
| 628 | /// Promotes the loop body of a forallOp to its containing block if it can be |
| 629 | /// determined that the loop has a single iteration. |
| 630 | LogicalResult scf::ForallOp::promoteIfSingleIteration(RewriterBase &rewriter) { |
| 631 | for (auto [lb, ub, step] : |
| 632 | llvm::zip(getMixedLowerBound(), getMixedUpperBound(), getMixedStep())) { |
| 633 | auto tripCount = constantTripCount(lb, ub, step); |
| 634 | if (!tripCount.has_value() || *tripCount != 1) |
| 635 | return failure(); |
| 636 | } |
| 637 | |
| 638 | promote(rewriter, *this); |
| 639 | return success(); |
| 640 | } |
| 641 | |
| 642 | Block::BlockArgListType ForallOp::getRegionIterArgs() { |
| 643 | return getBody()->getArguments().drop_front(getRank()); |
| 644 | } |
| 645 | |
| 646 | MutableArrayRef<OpOperand> ForallOp::getInitsMutable() { |
| 647 | return getOutputsMutable(); |
| 648 | } |
| 649 | |
| 650 | /// Promotes the loop body of a scf::ForallOp to its containing block. |
| 651 | void mlir::scf::promote(RewriterBase &rewriter, scf::ForallOp forallOp) { |
| 652 | OpBuilder::InsertionGuard g(rewriter); |
| 653 | scf::InParallelOp terminator = forallOp.getTerminator(); |
| 654 | |
| 655 | // Replace block arguments with lower bounds (replacements for IVs) and |
| 656 | // outputs. |
| 657 | SmallVector<Value> bbArgReplacements = forallOp.getLowerBound(rewriter); |
| 658 | bbArgReplacements.append(forallOp.getOutputs().begin(), |
| 659 | forallOp.getOutputs().end()); |
| 660 | |
| 661 | // Move the loop body operations to the loop's containing block. |
| 662 | rewriter.inlineBlockBefore(forallOp.getBody(), forallOp->getBlock(), |
| 663 | forallOp->getIterator(), bbArgReplacements); |
| 664 | |
| 665 | // Replace the terminator with tensor.insert_slice ops. |
| 666 | rewriter.setInsertionPointAfter(forallOp); |
| 667 | SmallVector<Value> results; |
  results.reserve(forallOp.getResults().size());
| 669 | for (auto &yieldingOp : terminator.getYieldingOps()) { |
| 670 | auto parallelInsertSliceOp = |
| 671 | cast<tensor::ParallelInsertSliceOp>(yieldingOp); |
| 672 | |
| 673 | Value dst = parallelInsertSliceOp.getDest(); |
| 674 | Value src = parallelInsertSliceOp.getSource(); |
| 675 | if (llvm::isa<TensorType>(src.getType())) { |
| 676 | results.push_back(rewriter.create<tensor::InsertSliceOp>( |
| 677 | forallOp.getLoc(), dst.getType(), src, dst, |
| 678 | parallelInsertSliceOp.getOffsets(), parallelInsertSliceOp.getSizes(), |
| 679 | parallelInsertSliceOp.getStrides(), |
| 680 | parallelInsertSliceOp.getStaticOffsets(), |
| 681 | parallelInsertSliceOp.getStaticSizes(), |
| 682 | parallelInsertSliceOp.getStaticStrides())); |
| 683 | } else { |
      llvm_unreachable("unsupported terminator");
| 685 | } |
| 686 | } |
| 687 | rewriter.replaceAllUsesWith(forallOp.getResults(), results); |
| 688 | |
| 689 | // Erase the old terminator and the loop. |
  rewriter.eraseOp(terminator);
  rewriter.eraseOp(forallOp);
| 692 | } |
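
// For illustration (hypothetical `test.produce` op), a forall with a unit
// trip count such as
//
//   %r = scf.forall (%i) in (1) shared_outs(%o = %t) -> (tensor<4xf32>) {
//     %s = "test.produce"() : () -> tensor<4xf32>
//     scf.forall.in_parallel {
//       tensor.parallel_insert_slice %s into %o[0] [4] [1]
//           : tensor<4xf32> into tensor<4xf32>
//     }
//   }
//
// is promoted to
//
//   %s = "test.produce"() : () -> tensor<4xf32>
//   %r = tensor.insert_slice %s into %t[0] [4] [1]
//       : tensor<4xf32> into tensor<4xf32>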
| 693 | |
| 694 | LoopNest mlir::scf::buildLoopNest( |
| 695 | OpBuilder &builder, Location loc, ValueRange lbs, ValueRange ubs, |
| 696 | ValueRange steps, ValueRange iterArgs, |
| 697 | function_ref<ValueVector(OpBuilder &, Location, ValueRange, ValueRange)> |
| 698 | bodyBuilder) { |
  assert(lbs.size() == ubs.size() &&
         "expected the same number of lower and upper bounds");
  assert(lbs.size() == steps.size() &&
         "expected the same number of lower bounds and steps");
| 703 | |
| 704 | // If there are no bounds, call the body-building function and return early. |
| 705 | if (lbs.empty()) { |
| 706 | ValueVector results = |
| 707 | bodyBuilder ? bodyBuilder(builder, loc, ValueRange(), iterArgs) |
| 708 | : ValueVector(); |
| 709 | assert(results.size() == iterArgs.size() && |
| 710 | "loop nest body must return as many values as loop has iteration " |
           "arguments");
| 712 | return LoopNest{{}, std::move(results)}; |
| 713 | } |
| 714 | |
| 715 | // First, create the loop structure iteratively using the body-builder |
| 716 | // callback of `ForOp::build`. Do not create `YieldOp`s yet. |
| 717 | OpBuilder::InsertionGuard guard(builder); |
| 718 | SmallVector<scf::ForOp, 4> loops; |
| 719 | SmallVector<Value, 4> ivs; |
  loops.reserve(lbs.size());
  ivs.reserve(lbs.size());
| 722 | ValueRange currentIterArgs = iterArgs; |
| 723 | Location currentLoc = loc; |
| 724 | for (unsigned i = 0, e = lbs.size(); i < e; ++i) { |
| 725 | auto loop = builder.create<scf::ForOp>( |
| 726 | currentLoc, lbs[i], ubs[i], steps[i], currentIterArgs, |
| 727 | [&](OpBuilder &nestedBuilder, Location nestedLoc, Value iv, |
| 728 | ValueRange args) { |
| 729 | ivs.push_back(iv); |
| 730 | // It is safe to store ValueRange args because it points to block |
| 731 | // arguments of a loop operation that we also own. |
| 732 | currentIterArgs = args; |
| 733 | currentLoc = nestedLoc; |
| 734 | }); |
| 735 | // Set the builder to point to the body of the newly created loop. We don't |
| 736 | // do this in the callback because the builder is reset when the callback |
| 737 | // returns. |
| 738 | builder.setInsertionPointToStart(loop.getBody()); |
| 739 | loops.push_back(loop); |
| 740 | } |
| 741 | |
| 742 | // For all loops but the innermost, yield the results of the nested loop. |
| 743 | for (unsigned i = 0, e = loops.size() - 1; i < e; ++i) { |
| 744 | builder.setInsertionPointToEnd(loops[i].getBody()); |
| 745 | builder.create<scf::YieldOp>(loc, loops[i + 1].getResults()); |
| 746 | } |
| 747 | |
| 748 | // In the body of the innermost loop, call the body building function if any |
| 749 | // and yield its results. |
| 750 | builder.setInsertionPointToStart(loops.back().getBody()); |
| 751 | ValueVector results = bodyBuilder |
| 752 | ? bodyBuilder(builder, currentLoc, ivs, |
| 753 | loops.back().getRegionIterArgs()) |
| 754 | : ValueVector(); |
| 755 | assert(results.size() == iterArgs.size() && |
| 756 | "loop nest body must return as many values as loop has iteration " |
         "arguments");
| 758 | builder.setInsertionPointToEnd(loops.back().getBody()); |
| 759 | builder.create<scf::YieldOp>(loc, results); |
| 760 | |
| 761 | // Return the loops. |
| 762 | ValueVector nestResults; |
| 763 | llvm::append_range(nestResults, loops.front().getResults()); |
| 764 | return LoopNest{std::move(loops), std::move(nestResults)}; |
| 765 | } |
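
// A minimal usage sketch (all SSA values and the hypothetical `makeBody`
// helper are assumed to exist at the call site); it builds lbs.size()
// perfectly nested loops threading one accumulator through every level:
//
//   scf::LoopNest nest = scf::buildLoopNest(
//       builder, loc, lbs, ubs, steps, ValueRange{acc},
//       [&](OpBuilder &b, Location l, ValueRange ivs,
//           ValueRange iterArgs) -> scf::ValueVector {
//         Value next = makeBody(b, l, ivs, iterArgs.front());
//         return {next};
//       });
//   // nest.loops.front() is the outermost scf.for; nest.results are the
//   // values to use below the nest.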
| 766 | |
| 767 | LoopNest mlir::scf::buildLoopNest( |
| 768 | OpBuilder &builder, Location loc, ValueRange lbs, ValueRange ubs, |
| 769 | ValueRange steps, |
| 770 | function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilder) { |
| 771 | // Delegate to the main function by wrapping the body builder. |
  return buildLoopNest(builder, loc, lbs, ubs, steps, std::nullopt,
                       [&bodyBuilder](OpBuilder &nestedBuilder,
                                      Location nestedLoc, ValueRange ivs,
                                      ValueRange) -> ValueVector {
| 776 | if (bodyBuilder) |
| 777 | bodyBuilder(nestedBuilder, nestedLoc, ivs); |
| 778 | return {}; |
| 779 | }); |
| 780 | } |
| 781 | |
| 782 | SmallVector<Value> |
| 783 | mlir::scf::replaceAndCastForOpIterArg(RewriterBase &rewriter, scf::ForOp forOp, |
| 784 | OpOperand &operand, Value replacement, |
| 785 | const ValueTypeCastFnTy &castFn) { |
| 786 | assert(operand.getOwner() == forOp); |
| 787 | Type oldType = operand.get().getType(), newType = replacement.getType(); |
| 788 | |
| 789 | // 1. Create new iter operands, exactly 1 is replaced. |
  assert(operand.getOperandNumber() >= forOp.getNumControlOperands() &&
         "expected an iter OpOperand");
  assert(operand.get().getType() != replacement.getType() &&
         "Expected a different type");
| 794 | SmallVector<Value> newIterOperands; |
| 795 | for (OpOperand &opOperand : forOp.getInitArgsMutable()) { |
| 796 | if (opOperand.getOperandNumber() == operand.getOperandNumber()) { |
| 797 | newIterOperands.push_back(replacement); |
| 798 | continue; |
| 799 | } |
| 800 | newIterOperands.push_back(opOperand.get()); |
| 801 | } |
| 802 | |
| 803 | // 2. Create the new forOp shell. |
| 804 | scf::ForOp newForOp = rewriter.create<scf::ForOp>( |
| 805 | forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(), |
| 806 | forOp.getStep(), newIterOperands); |
| 807 | newForOp->setAttrs(forOp->getAttrs()); |
| 808 | Block &newBlock = newForOp.getRegion().front(); |
| 809 | SmallVector<Value, 4> newBlockTransferArgs(newBlock.getArguments().begin(), |
| 810 | newBlock.getArguments().end()); |
| 811 | |
| 812 | // 3. Inject an incoming cast op at the beginning of the block for the bbArg |
| 813 | // corresponding to the `replacement` value. |
| 814 | OpBuilder::InsertionGuard g(rewriter); |
| 815 | rewriter.setInsertionPointToStart(&newBlock); |
| 816 | BlockArgument newRegionIterArg = newForOp.getTiedLoopRegionIterArg( |
| 817 | &newForOp->getOpOperand(operand.getOperandNumber())); |
| 818 | Value castIn = castFn(rewriter, newForOp.getLoc(), oldType, newRegionIterArg); |
| 819 | newBlockTransferArgs[newRegionIterArg.getArgNumber()] = castIn; |
| 820 | |
| 821 | // 4. Steal the old block ops, mapping to the newBlockTransferArgs. |
| 822 | Block &oldBlock = forOp.getRegion().front(); |
  rewriter.mergeBlocks(&oldBlock, &newBlock, newBlockTransferArgs);
| 824 | |
| 825 | // 5. Inject an outgoing cast op at the end of the block and yield it instead. |
| 826 | auto clonedYieldOp = cast<scf::YieldOp>(newBlock.getTerminator()); |
| 827 | rewriter.setInsertionPoint(clonedYieldOp); |
| 828 | unsigned yieldIdx = |
| 829 | newRegionIterArg.getArgNumber() - forOp.getNumInductionVars(); |
| 830 | Value castOut = castFn(rewriter, newForOp.getLoc(), newType, |
| 831 | clonedYieldOp.getOperand(yieldIdx)); |
| 832 | SmallVector<Value> newYieldOperands = clonedYieldOp.getOperands(); |
| 833 | newYieldOperands[yieldIdx] = castOut; |
| 834 | rewriter.create<scf::YieldOp>(newForOp.getLoc(), newYieldOperands); |
  rewriter.eraseOp(clonedYieldOp);
| 836 | |
| 837 | // 6. Inject an outgoing cast op after the forOp. |
| 838 | rewriter.setInsertionPointAfter(newForOp); |
| 839 | SmallVector<Value> newResults = newForOp.getResults(); |
| 840 | newResults[yieldIdx] = |
| 841 | castFn(rewriter, newForOp.getLoc(), oldType, newResults[yieldIdx]); |
| 842 | |
| 843 | return newResults; |
| 844 | } |
| 845 | |
| 846 | namespace { |
| 847 | // Fold away ForOp iter arguments when: |
| 848 | // 1) The op yields the iter arguments. |
| 849 | // 2) The argument's corresponding outer region iterators (inputs) are yielded. |
| 850 | // 3) The iter arguments have no use and the corresponding (operation) results |
| 851 | // have no use. |
| 852 | // |
| 853 | // These arguments must be defined outside of the ForOp region and can just be |
| 854 | // forwarded after simplifying the op inits, yields and returns. |
| 855 | // |
| 856 | // The implementation uses `inlineBlockBefore` to steal the content of the |
| 857 | // original ForOp and avoid cloning. |
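//
// For illustration (hypothetical `test.use` op), in
//
//   %r = scf.for %i = %lb to %ub step %s iter_args(%a = %x) -> (f32) {
//     "test.use"(%a) : (f32) -> ()
//     scf.yield %a : f32
//   }
//
// the iter argument is forwarded (case 1 above), so the loop is rebuilt
// without it and %r is replaced by %x.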
| 858 | struct ForOpIterArgsFolder : public OpRewritePattern<scf::ForOp> { |
| 859 | using OpRewritePattern<scf::ForOp>::OpRewritePattern; |
| 860 | |
| 861 | LogicalResult matchAndRewrite(scf::ForOp forOp, |
| 862 | PatternRewriter &rewriter) const final { |
| 863 | bool canonicalize = false; |
| 864 | |
    // An internal flat vector of block transfer arguments
    // `newBlockTransferArgs` keeps the 1-1 mapping of original to transformed
    // block arguments. It plays the role of an IRMapping for the particular
    // use case of calling into `inlineBlockBefore`.
| 870 | int64_t numResults = forOp.getNumResults(); |
| 871 | SmallVector<bool, 4> keepMask; |
    keepMask.reserve(numResults);
    SmallVector<Value, 4> newBlockTransferArgs, newIterArgs, newYieldValues,
        newResultValues;
    newBlockTransferArgs.reserve(1 + numResults);
    newBlockTransferArgs.push_back(Value()); // iv placeholder with null value
    newIterArgs.reserve(forOp.getInitArgs().size());
    newYieldValues.reserve(numResults);
    newResultValues.reserve(numResults);
| 880 | DenseMap<std::pair<Value, Value>, std::pair<Value, Value>> initYieldToArg; |
| 881 | for (auto [init, arg, result, yielded] : |
| 882 | llvm::zip(forOp.getInitArgs(), // iter from outside |
| 883 | forOp.getRegionIterArgs(), // iter inside region |
| 884 | forOp.getResults(), // op results |
| 885 | forOp.getYieldedValues() // iter yield |
| 886 | )) { |
| 887 | // Forwarded is `true` when: |
| 888 | // 1) The region `iter` argument is yielded. |
      // 2) The corresponding `init` input is yielded.
| 890 | // 3) The region `iter` argument has no use, and the corresponding op |
| 891 | // result has no use. |
| 892 | bool forwarded = (arg == yielded) || (init == yielded) || |
| 893 | (arg.use_empty() && result.use_empty()); |
| 894 | if (forwarded) { |
| 895 | canonicalize = true; |
| 896 | keepMask.push_back(false); |
| 897 | newBlockTransferArgs.push_back(init); |
| 898 | newResultValues.push_back(init); |
| 899 | continue; |
| 900 | } |
| 901 | |
| 902 | // Check if a previous kept argument always has the same values for init |
| 903 | // and yielded values. |
| 904 | if (auto it = initYieldToArg.find({init, yielded}); |
| 905 | it != initYieldToArg.end()) { |
| 906 | canonicalize = true; |
| 907 | keepMask.push_back(false); |
| 908 | auto [sameArg, sameResult] = it->second; |
| 909 | rewriter.replaceAllUsesWith(arg, sameArg); |
| 910 | rewriter.replaceAllUsesWith(result, sameResult); |
| 911 | // The replacement value doesn't matter because there are no uses. |
| 912 | newBlockTransferArgs.push_back(init); |
| 913 | newResultValues.push_back(init); |
| 914 | continue; |
| 915 | } |
| 916 | |
| 917 | // This value is kept. |
| 918 | initYieldToArg.insert({{init, yielded}, {arg, result}}); |
| 919 | keepMask.push_back(true); |
| 920 | newIterArgs.push_back(init); |
| 921 | newYieldValues.push_back(yielded); |
| 922 | newBlockTransferArgs.push_back(Value()); // placeholder with null value |
| 923 | newResultValues.push_back(Value()); // placeholder with null value |
| 924 | } |
| 925 | |
| 926 | if (!canonicalize) |
| 927 | return failure(); |
| 928 | |
| 929 | scf::ForOp newForOp = rewriter.create<scf::ForOp>( |
| 930 | forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(), |
| 931 | forOp.getStep(), newIterArgs); |
| 932 | newForOp->setAttrs(forOp->getAttrs()); |
| 933 | Block &newBlock = newForOp.getRegion().front(); |
| 934 | |
| 935 | // Replace the null placeholders with newly constructed values. |
    newBlockTransferArgs[0] = newBlock.getArgument(0); // iv
| 937 | for (unsigned idx = 0, collapsedIdx = 0, e = newResultValues.size(); |
| 938 | idx != e; ++idx) { |
| 939 | Value &blockTransferArg = newBlockTransferArgs[1 + idx]; |
| 940 | Value &newResultVal = newResultValues[idx]; |
| 941 | assert((blockTransferArg && newResultVal) || |
| 942 | (!blockTransferArg && !newResultVal)); |
| 943 | if (!blockTransferArg) { |
| 944 | blockTransferArg = newForOp.getRegionIterArgs()[collapsedIdx]; |
| 945 | newResultVal = newForOp.getResult(collapsedIdx++); |
| 946 | } |
| 947 | } |
| 948 | |
| 949 | Block &oldBlock = forOp.getRegion().front(); |
    assert(oldBlock.getNumArguments() == newBlockTransferArgs.size() &&
           "unexpected argument size mismatch");
| 952 | |
| 953 | // No results case: the scf::ForOp builder already created a zero |
| 954 | // result terminator. Merge before this terminator and just get rid of the |
| 955 | // original terminator that has been merged in. |
| 956 | if (newIterArgs.empty()) { |
| 957 | auto newYieldOp = cast<scf::YieldOp>(newBlock.getTerminator()); |
| 958 | rewriter.inlineBlockBefore(&oldBlock, newYieldOp, newBlockTransferArgs); |
      rewriter.eraseOp(newBlock.getTerminator()->getPrevNode());
| 960 | rewriter.replaceOp(forOp, newResultValues); |
| 961 | return success(); |
| 962 | } |
| 963 | |
    // General case: merge the blocks and rewrite the merged terminator to
    // yield only the values that are kept.
| 965 | auto cloneFilteredTerminator = [&](scf::YieldOp mergedTerminator) { |
| 966 | OpBuilder::InsertionGuard g(rewriter); |
| 967 | rewriter.setInsertionPoint(mergedTerminator); |
| 968 | SmallVector<Value, 4> filteredOperands; |
      filteredOperands.reserve(newResultValues.size());
      for (unsigned idx = 0, e = keepMask.size(); idx < e; ++idx)
        if (keepMask[idx])
          filteredOperands.push_back(mergedTerminator.getOperand(idx));
| 973 | rewriter.create<scf::YieldOp>(mergedTerminator.getLoc(), |
| 974 | filteredOperands); |
| 975 | }; |
| 976 | |
    rewriter.mergeBlocks(&oldBlock, &newBlock, newBlockTransferArgs);
    auto mergedYieldOp = cast<scf::YieldOp>(newBlock.getTerminator());
    cloneFilteredTerminator(mergedYieldOp);
    rewriter.eraseOp(mergedYieldOp);
| 981 | rewriter.replaceOp(forOp, newResultValues); |
| 982 | return success(); |
| 983 | } |
| 984 | }; |
| 985 | |
/// Util function that tries to compute a constant diff between u and l.
/// Returns std::nullopt when the difference is not a compile-time constant.
| 989 | static std::optional<int64_t> computeConstDiff(Value l, Value u) { |
| 990 | IntegerAttr clb, cub; |
| 991 | if (matchPattern(l, m_Constant(&clb)) && matchPattern(u, m_Constant(&cub))) { |
| 992 | llvm::APInt lbValue = clb.getValue(); |
| 993 | llvm::APInt ubValue = cub.getValue(); |
| 994 | return (ubValue - lbValue).getSExtValue(); |
| 995 | } |
| 996 | |
| 997 | // Else a simple pattern match for x + c or c + x |
| 998 | llvm::APInt diff; |
| 999 | if (matchPattern( |
| 1000 | u, m_Op<arith::AddIOp>(matchers::m_Val(l), m_ConstantInt(&diff))) || |
| 1001 | matchPattern( |
| 1002 | u, m_Op<arith::AddIOp>(m_ConstantInt(&diff), matchers::m_Val(l)))) |
| 1003 | return diff.getSExtValue(); |
| 1004 | return std::nullopt; |
| 1005 | } |
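
// For example (assuming index-typed values and %c4 an arith.constant of
// value 4), with
//
//   %u = arith.addi %l, %c4 : index
//
// computeConstDiff(%l, %u) returns 4; it returns std::nullopt when neither
// the constant nor the add pattern applies.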
| 1006 | |
| 1007 | /// Rewriting pattern that erases loops that are known not to iterate, replaces |
| 1008 | /// single-iteration loops with their bodies, and removes empty loops that |
| 1009 | /// iterate at least once and only return values defined outside of the loop. |
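///
/// For illustration (constant bounds), the zero-trip loop
///   %r = scf.for %i = %c4 to %c4 step %c1 iter_args(%a = %x) -> (f32) {...}
/// folds to %x, and a loop whose constant step is at least the constant
/// upper-minus-lower difference runs exactly once and has its body inlined
/// in place of the loop.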
| 1010 | struct SimplifyTrivialLoops : public OpRewritePattern<ForOp> { |
| 1011 | using OpRewritePattern<ForOp>::OpRewritePattern; |
| 1012 | |
| 1013 | LogicalResult matchAndRewrite(ForOp op, |
| 1014 | PatternRewriter &rewriter) const override { |
| 1015 | // If the upper bound is the same as the lower bound, the loop does not |
| 1016 | // iterate, just remove it. |
| 1017 | if (op.getLowerBound() == op.getUpperBound()) { |
| 1018 | rewriter.replaceOp(op, op.getInitArgs()); |
| 1019 | return success(); |
| 1020 | } |
| 1021 | |
| 1022 | std::optional<int64_t> diff = |
| 1023 | computeConstDiff(op.getLowerBound(), op.getUpperBound()); |
| 1024 | if (!diff) |
| 1025 | return failure(); |
| 1026 | |
| 1027 | // If the loop is known to have 0 iterations, remove it. |
| 1028 | if (*diff <= 0) { |
| 1029 | rewriter.replaceOp(op, op.getInitArgs()); |
| 1030 | return success(); |
| 1031 | } |
| 1032 | |
| 1033 | std::optional<llvm::APInt> maybeStepValue = op.getConstantStep(); |
| 1034 | if (!maybeStepValue) |
| 1035 | return failure(); |
| 1036 | |
| 1037 | // If the loop is known to have 1 iteration, inline its body and remove the |
| 1038 | // loop. |
| 1039 | llvm::APInt stepValue = *maybeStepValue; |
    if (stepValue.sge(*diff)) {
      SmallVector<Value, 4> blockArgs;
      blockArgs.reserve(op.getInitArgs().size() + 1);
      blockArgs.push_back(op.getLowerBound());
| 1044 | llvm::append_range(blockArgs, op.getInitArgs()); |
| 1045 | replaceOpWithRegion(rewriter, op, op.getRegion(), blockArgs); |
| 1046 | return success(); |
| 1047 | } |
| 1048 | |
| 1049 | // Now we are left with loops that have more than 1 iterations. |
| 1050 | Block &block = op.getRegion().front(); |
    if (!llvm::hasSingleElement(block))
| 1052 | return failure(); |
| 1053 | // If the loop is empty, iterates at least once, and only returns values |
| 1054 | // defined outside of the loop, remove it and replace it with yield values. |
| 1055 | if (llvm::any_of(op.getYieldedValues(), |
| 1056 | [&](Value v) { return !op.isDefinedOutsideOfLoop(v); })) |
| 1057 | return failure(); |
| 1058 | rewriter.replaceOp(op, op.getYieldedValues()); |
| 1059 | return success(); |
| 1060 | } |
| 1061 | }; |
| 1062 | |
/// Fold scf.for iter_arg/result pairs that go through an incoming/outgoing
/// tensor.cast op pair so as to pull the tensor.cast inside the scf.for:
| 1065 | /// |
| 1066 | /// ``` |
| 1067 | /// %0 = tensor.cast %t0 : tensor<32x1024xf32> to tensor<?x?xf32> |
| 1068 | /// %1 = scf.for %i = %c0 to %c1024 step %c32 iter_args(%iter_t0 = %0) |
| 1069 | /// -> (tensor<?x?xf32>) { |
| 1070 | /// %2 = call @do(%iter_t0) : (tensor<?x?xf32>) -> tensor<?x?xf32> |
| 1071 | /// scf.yield %2 : tensor<?x?xf32> |
| 1072 | /// } |
| 1073 | /// use_of(%1) |
| 1074 | /// ``` |
| 1075 | /// |
| 1076 | /// folds into: |
| 1077 | /// |
| 1078 | /// ``` |
| 1079 | /// %0 = scf.for %arg2 = %c0 to %c1024 step %c32 iter_args(%arg3 = %arg0) |
| 1080 | /// -> (tensor<32x1024xf32>) { |
| 1081 | /// %2 = tensor.cast %arg3 : tensor<32x1024xf32> to tensor<?x?xf32> |
| 1082 | /// %3 = call @do(%2) : (tensor<?x?xf32>) -> tensor<?x?xf32> |
| 1083 | /// %4 = tensor.cast %3 : tensor<?x?xf32> to tensor<32x1024xf32> |
| 1084 | /// scf.yield %4 : tensor<32x1024xf32> |
| 1085 | /// } |
| 1086 | /// %1 = tensor.cast %0 : tensor<32x1024xf32> to tensor<?x?xf32> |
| 1087 | /// use_of(%1) |
| 1088 | /// ``` |
| 1089 | struct ForOpTensorCastFolder : public OpRewritePattern<ForOp> { |
| 1090 | using OpRewritePattern<ForOp>::OpRewritePattern; |
| 1091 | |
| 1092 | LogicalResult matchAndRewrite(ForOp op, |
| 1093 | PatternRewriter &rewriter) const override { |
| 1094 | for (auto it : llvm::zip(op.getInitArgsMutable(), op.getResults())) { |
| 1095 | OpOperand &iterOpOperand = std::get<0>(it); |
| 1096 | auto incomingCast = iterOpOperand.get().getDefiningOp<tensor::CastOp>(); |
| 1097 | if (!incomingCast || |
| 1098 | incomingCast.getSource().getType() == incomingCast.getType()) |
| 1099 | continue; |
      // Skip if the dest type of the cast does not preserve the static
      // information of the source type.
| 1102 | if (!tensor::preservesStaticInformation( |
| 1103 | incomingCast.getDest().getType(), |
| 1104 | incomingCast.getSource().getType())) |
| 1105 | continue; |
| 1106 | if (!std::get<1>(it).hasOneUse()) |
| 1107 | continue; |
| 1108 | |
| 1109 | // Create a new ForOp with that iter operand replaced. |
| 1110 | rewriter.replaceOp( |
| 1111 | op, replaceAndCastForOpIterArg( |
| 1112 | rewriter, op, iterOpOperand, incomingCast.getSource(), |
| 1113 | [](OpBuilder &b, Location loc, Type type, Value source) { |
| 1114 | return b.create<tensor::CastOp>(loc, type, source); |
| 1115 | })); |
| 1116 | return success(); |
| 1117 | } |
| 1118 | return failure(); |
| 1119 | } |
| 1120 | }; |
| 1121 | |
| 1122 | } // namespace |
| 1123 | |
| 1124 | void ForOp::getCanonicalizationPatterns(RewritePatternSet &results, |
| 1125 | MLIRContext *context) { |
| 1126 | results.add<ForOpIterArgsFolder, SimplifyTrivialLoops, ForOpTensorCastFolder>( |
| 1127 | context); |
| 1128 | } |
| 1129 | |
| 1130 | std::optional<APInt> ForOp::getConstantStep() { |
| 1131 | IntegerAttr step; |
| 1132 | if (matchPattern(getStep(), m_Constant(&step))) |
| 1133 | return step.getValue(); |
| 1134 | return {}; |
| 1135 | } |
| 1136 | |
| 1137 | std::optional<MutableArrayRef<OpOperand>> ForOp::getYieldedValuesMutable() { |
| 1138 | return cast<scf::YieldOp>(getBody()->getTerminator()).getResultsMutable(); |
| 1139 | } |
| 1140 | |
| 1141 | Speculation::Speculatability ForOp::getSpeculatability() { |
| 1142 | // `scf.for (I = Start; I < End; I += 1)` terminates for all values of Start |
| 1143 | // and End. |
| 1144 | if (auto constantStep = getConstantStep()) |
| 1145 | if (*constantStep == 1) |
| 1146 | return Speculation::RecursivelySpeculatable; |
| 1147 | |
| 1148 | // For Step != 1, the loop may not terminate. We can add more smarts here if |
| 1149 | // needed. |
| 1150 | return Speculation::NotSpeculatable; |
| 1151 | } |
| 1152 | |
| 1153 | //===----------------------------------------------------------------------===// |
| 1154 | // ForallOp |
| 1155 | //===----------------------------------------------------------------------===// |
| 1156 | |
| 1157 | LogicalResult ForallOp::verify() { |
| 1158 | unsigned numLoops = getRank(); |
| 1159 | // Check number of outputs. |
| 1160 | if (getNumResults() != getOutputs().size()) |
    return emitOpError("produces ")
           << getNumResults() << " results, but has only "
           << getOutputs().size() << " outputs";
| 1164 | |
| 1165 | // Check that the body defines block arguments for thread indices and outputs. |
| 1166 | auto *body = getBody(); |
| 1167 | if (body->getNumArguments() != numLoops + getOutputs().size()) |
    return emitOpError("region expects ") << numLoops << " arguments";
| 1169 | for (int64_t i = 0; i < numLoops; ++i) |
| 1170 | if (!body->getArgument(i).getType().isIndex()) |
      return emitOpError("expects ")
             << i << "-th block argument to be an index";
| 1173 | for (unsigned i = 0; i < getOutputs().size(); ++i) |
| 1174 | if (body->getArgument(i + numLoops).getType() != getOutputs()[i].getType()) |
      return emitOpError("type mismatch between ")
             << i << "-th output and corresponding block argument";
| 1177 | if (getMapping().has_value() && !getMapping()->empty()) { |
| 1178 | if (static_cast<int64_t>(getMapping()->size()) != numLoops) |
      return emitOpError() << "mapping attribute size must match op rank";
| 1180 | for (auto map : getMapping()->getValue()) { |
| 1181 | if (!isa<DeviceMappingAttrInterface>(map)) |
| 1182 | return emitOpError() |
               << getMappingAttrName() << " is not a device mapping attribute";
| 1184 | } |
| 1185 | } |
| 1186 | |
| 1187 | // Verify mixed static/dynamic control variables. |
| 1188 | Operation *op = getOperation(); |
  if (failed(verifyListOfOperandsOrIntegers(op, "lower bound", numLoops,
| 1190 | getStaticLowerBound(), |
| 1191 | getDynamicLowerBound()))) |
| 1192 | return failure(); |
  if (failed(verifyListOfOperandsOrIntegers(op, "upper bound", numLoops,
| 1194 | getStaticUpperBound(), |
| 1195 | getDynamicUpperBound()))) |
| 1196 | return failure(); |
  if (failed(verifyListOfOperandsOrIntegers(op, "step", numLoops,
| 1198 | getStaticStep(), getDynamicStep()))) |
| 1199 | return failure(); |
| 1200 | |
| 1201 | return success(); |
| 1202 | } |
| 1203 | |
| 1204 | void ForallOp::print(OpAsmPrinter &p) { |
| 1205 | Operation *op = getOperation(); |
| 1206 | p << " (" << getInductionVars(); |
| 1207 | if (isNormalized()) { |
    p << ") in ";
| 1209 | printDynamicIndexList(p, op, getDynamicUpperBound(), getStaticUpperBound(), |
| 1210 | /*valueTypes=*/{}, /*scalables=*/{}, |
| 1211 | OpAsmParser::Delimiter::Paren); |
| 1212 | } else { |
    p << ") = ";
| 1214 | printDynamicIndexList(p, op, getDynamicLowerBound(), getStaticLowerBound(), |
| 1215 | /*valueTypes=*/{}, /*scalables=*/{}, |
| 1216 | OpAsmParser::Delimiter::Paren); |
    p << " to ";
| 1218 | printDynamicIndexList(p, op, getDynamicUpperBound(), getStaticUpperBound(), |
| 1219 | /*valueTypes=*/{}, /*scalables=*/{}, |
| 1220 | OpAsmParser::Delimiter::Paren); |
    p << " step ";
| 1222 | printDynamicIndexList(p, op, getDynamicStep(), getStaticStep(), |
| 1223 | /*valueTypes=*/{}, /*scalables=*/{}, |
| 1224 | OpAsmParser::Delimiter::Paren); |
| 1225 | } |
  printInitializationList(p, getRegionOutArgs(), getOutputs(), " shared_outs");
  p << " ";
  if (!getRegionOutArgs().empty())
    p << "-> (" << getResultTypes() << ") ";
| 1230 | p.printRegion(getRegion(), |
| 1231 | /*printEntryBlockArgs=*/false, |
| 1232 | /*printBlockTerminators=*/getNumResults() > 0); |
| 1233 | p.printOptionalAttrDict(op->getAttrs(), {getOperandSegmentSizesAttrName(), |
| 1234 | getStaticLowerBoundAttrName(), |
| 1235 | getStaticUpperBoundAttrName(), |
| 1236 | getStaticStepAttrName()}); |
| 1237 | } |
| 1238 | |
| 1239 | ParseResult ForallOp::parse(OpAsmParser &parser, OperationState &result) { |
| 1240 | OpBuilder b(parser.getContext()); |
| 1241 | auto indexType = b.getIndexType(); |
| 1242 | |
| 1243 | // Parse an opening `(` followed by thread index variables followed by `)` |
| 1244 | // TODO: when we can refer to such "induction variable"-like handles from the |
| 1245 | // declarative assembly format, we can implement the parser as a custom hook. |
| 1246 | SmallVector<OpAsmParser::Argument, 4> ivs; |
| 1247 | if (parser.parseArgumentList(ivs, OpAsmParser::Delimiter::Paren)) |
| 1248 | return failure(); |
| 1249 | |
| 1250 | DenseI64ArrayAttr staticLbs, staticUbs, staticSteps; |
| 1251 | SmallVector<OpAsmParser::UnresolvedOperand> dynamicLbs, dynamicUbs, |
| 1252 | dynamicSteps; |
  if (succeeded(parser.parseOptionalKeyword("in"))) {
| 1254 | // Parse upper bounds. |
| 1255 | if (parseDynamicIndexList(parser, dynamicUbs, staticUbs, |
| 1256 | /*valueTypes=*/nullptr, |
| 1257 | OpAsmParser::Delimiter::Paren) || |
| 1258 | parser.resolveOperands(dynamicUbs, indexType, result.operands)) |
| 1259 | return failure(); |
| 1260 | |
| 1261 | unsigned numLoops = ivs.size(); |
| 1262 | staticLbs = b.getDenseI64ArrayAttr(SmallVector<int64_t>(numLoops, 0)); |
| 1263 | staticSteps = b.getDenseI64ArrayAttr(SmallVector<int64_t>(numLoops, 1)); |
| 1264 | } else { |
| 1265 | // Parse lower bounds. |
| 1266 | if (parser.parseEqual() || |
| 1267 | parseDynamicIndexList(parser, dynamicLbs, staticLbs, |
| 1268 | /*valueTypes=*/nullptr, |
| 1269 | OpAsmParser::Delimiter::Paren) || |
| 1271 | parser.resolveOperands(dynamicLbs, indexType, result.operands)) |
| 1272 | return failure(); |
| 1273 | |
| 1274 | // Parse upper bounds. |
| 1275 | if (parser.parseKeyword("to" ) || |
| 1276 | parseDynamicIndexList(parser, dynamicUbs, staticUbs, |
| 1277 | /*valueTypes=*/nullptr, |
| 1278 | OpAsmParser::Delimiter::Paren) || |
| 1279 | parser.resolveOperands(dynamicUbs, indexType, result.operands)) |
| 1280 | return failure(); |
| 1281 | |
| 1282 | // Parse step values. |
| 1283 | if (parser.parseKeyword("step" ) || |
| 1284 | parseDynamicIndexList(parser, dynamicSteps, staticSteps, |
| 1285 | /*valueTypes=*/nullptr, |
| 1286 | OpAsmParser::Delimiter::Paren) || |
| 1287 | parser.resolveOperands(dynamicSteps, indexType, result.operands)) |
| 1288 | return failure(); |
| 1289 | } |
| 1290 | |
| 1291 | // Parse out operands and results. |
| 1292 | SmallVector<OpAsmParser::Argument, 4> regionOutArgs; |
| 1293 | SmallVector<OpAsmParser::UnresolvedOperand, 4> outOperands; |
| 1294 | SMLoc outOperandsLoc = parser.getCurrentLocation(); |
| 1295 |   if (succeeded(parser.parseOptionalKeyword("shared_outs"))) {
| 1296 |     if (parser.parseAssignmentList(regionOutArgs, outOperands) ||
| 1297 |         parser.parseOptionalArrowTypeList(result.types) ||
| 1298 |         parser.resolveOperands(outOperands, result.types, outOperandsLoc,
| 1299 |                                result.operands))
| 1300 |       return failure();
| 1301 |   }
| 1302 |   if (outOperands.size() != result.types.size())
| 1303 |     return parser.emitError(outOperandsLoc,
| 1304 |                             "mismatch between out operands and types");
| 1305 | |
| 1306 | // Parse region. |
| 1307 | SmallVector<OpAsmParser::Argument, 4> regionArgs; |
| 1308 | std::unique_ptr<Region> region = std::make_unique<Region>(); |
| 1309 | for (auto &iv : ivs) { |
| 1310 | iv.type = b.getIndexType(); |
| 1311 | regionArgs.push_back(iv); |
| 1312 | } |
| 1313 | for (const auto &it : llvm::enumerate(regionOutArgs)) { |
| 1314 | auto &out = it.value(); |
| 1315 | out.type = result.types[it.index()]; |
| 1316 | regionArgs.push_back(out); |
| 1317 | } |
| 1318 | if (parser.parseRegion(*region, regionArgs)) |
| 1319 | return failure(); |
| 1320 | |
| 1321 | // Ensure terminator and move region. |
| 1322 | ForallOp::ensureTerminator(*region, b, result.location); |
| 1323 | result.addRegion(std::move(region)); |
| 1324 | |
| 1325 | // Parse the optional attribute list. |
| 1326 | if (parser.parseOptionalAttrDict(result.attributes)) |
| 1327 | return failure(); |
| 1328 | |
| 1329 | result.addAttribute("staticLowerBound" , staticLbs); |
| 1330 | result.addAttribute("staticUpperBound" , staticUbs); |
| 1331 | result.addAttribute("staticStep" , staticSteps); |
| 1332 | result.addAttribute("operandSegmentSizes" , |
| 1333 | parser.getBuilder().getDenseI32ArrayAttr( |
| 1334 | {static_cast<int32_t>(dynamicLbs.size()), |
| 1335 | static_cast<int32_t>(dynamicUbs.size()), |
| 1336 | static_cast<int32_t>(dynamicSteps.size()), |
| 1337 | static_cast<int32_t>(outOperands.size())})); |
| 1338 | return success(); |
| 1339 | } |
| 1340 | |
| 1341 | // Builder that takes loop bounds. |
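// A minimal usage sketch (assumes `builder`, `loc`, `lbs`, `ubs`, `steps` and
// `outputs` exist in the caller; the callback receives the induction variables
// followed by the shared_outs block arguments):
//   builder.create<scf::ForallOp>(
//       loc, lbs, ubs, steps, outputs, /*mapping=*/std::nullopt,
//       [](OpBuilder &nested, Location nestedLoc, ValueRange args) {
//         // Body construction goes here.
//       });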
| 1342 | void ForallOp::build( |
| 1343 | mlir::OpBuilder &b, mlir::OperationState &result, |
| 1344 | ArrayRef<OpFoldResult> lbs, ArrayRef<OpFoldResult> ubs, |
| 1345 | ArrayRef<OpFoldResult> steps, ValueRange outputs, |
| 1346 | std::optional<ArrayAttr> mapping, |
| 1347 | function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn) { |
| 1348 | SmallVector<int64_t> staticLbs, staticUbs, staticSteps; |
| 1349 | SmallVector<Value> dynamicLbs, dynamicUbs, dynamicSteps; |
| 1350 | dispatchIndexOpFoldResults(lbs, dynamicLbs, staticLbs); |
| 1351 | dispatchIndexOpFoldResults(ubs, dynamicUbs, staticUbs); |
| 1352 | dispatchIndexOpFoldResults(steps, dynamicSteps, staticSteps); |
| 1353 | |
| 1354 | result.addOperands(dynamicLbs); |
| 1355 | result.addOperands(dynamicUbs); |
| 1356 | result.addOperands(dynamicSteps); |
| 1357 | result.addOperands(outputs); |
| 1358 | result.addTypes(TypeRange(outputs)); |
| 1359 | |
| 1360 | result.addAttribute(getStaticLowerBoundAttrName(result.name), |
| 1361 | b.getDenseI64ArrayAttr(staticLbs)); |
| 1362 | result.addAttribute(getStaticUpperBoundAttrName(result.name), |
| 1363 | b.getDenseI64ArrayAttr(staticUbs)); |
| 1364 | result.addAttribute(getStaticStepAttrName(result.name), |
| 1365 | b.getDenseI64ArrayAttr(staticSteps)); |
| 1366 | result.addAttribute( |
| 1367 | "operandSegmentSizes" , |
| 1368 | b.getDenseI32ArrayAttr({static_cast<int32_t>(dynamicLbs.size()), |
| 1369 | static_cast<int32_t>(dynamicUbs.size()), |
| 1370 | static_cast<int32_t>(dynamicSteps.size()), |
| 1371 | static_cast<int32_t>(outputs.size())})); |
| 1372 | if (mapping.has_value()) { |
| 1373 | result.addAttribute(ForallOp::getMappingAttrName(result.name), |
| 1374 | mapping.value()); |
| 1375 | } |
| 1376 | |
| 1377 | Region *bodyRegion = result.addRegion(); |
| 1378 | OpBuilder::InsertionGuard g(b); |
| 1379 | b.createBlock(bodyRegion); |
| 1380 | Block &bodyBlock = bodyRegion->front(); |
| 1381 | |
| 1382 | // Add block arguments for indices and outputs. |
| 1383 | bodyBlock.addArguments( |
| 1384 | SmallVector<Type>(lbs.size(), b.getIndexType()), |
| 1385 | SmallVector<Location>(staticLbs.size(), result.location)); |
| 1386 | bodyBlock.addArguments( |
| 1387 | TypeRange(outputs), |
| 1388 | SmallVector<Location>(outputs.size(), result.location)); |
| 1389 | |
| 1390 | b.setInsertionPointToStart(&bodyBlock); |
| 1391 | if (!bodyBuilderFn) { |
| 1392 | ForallOp::ensureTerminator(*bodyRegion, b, result.location); |
| 1393 | return; |
| 1394 | } |
| 1395 | bodyBuilderFn(b, result.location, bodyBlock.getArguments()); |
| 1396 | } |
| 1397 | |
| 1398 | // Builder that takes loop bounds. |
| 1399 | void ForallOp::build( |
| 1400 | mlir::OpBuilder &b, mlir::OperationState &result, |
| 1401 | ArrayRef<OpFoldResult> ubs, ValueRange outputs, |
| 1402 | std::optional<ArrayAttr> mapping, |
| 1403 | function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn) { |
| 1404 | unsigned numLoops = ubs.size(); |
| 1405 | SmallVector<OpFoldResult> lbs(numLoops, b.getIndexAttr(0)); |
| 1406 | SmallVector<OpFoldResult> steps(numLoops, b.getIndexAttr(1)); |
| 1407 | build(b, result, lbs, ubs, steps, outputs, mapping, bodyBuilderFn); |
| 1408 | } |
| 1409 | |
| 1410 | // Checks if the lbs are zeros and steps are ones. |
| 1411 | bool ForallOp::isNormalized() { |
| 1412 | auto allEqual = [](ArrayRef<OpFoldResult> results, int64_t val) { |
| 1413 | return llvm::all_of(results, [&](OpFoldResult ofr) { |
| 1414 | auto intValue = getConstantIntValue(ofr); |
| 1415 | return intValue.has_value() && intValue == val; |
| 1416 | }); |
| 1417 | }; |
| 1418 | return allEqual(getMixedLowerBound(), 0) && allEqual(getMixedStep(), 1); |
| 1419 | } |
| 1420 | |
| 1421 | InParallelOp ForallOp::getTerminator() { |
| 1422 | return cast<InParallelOp>(getBody()->getTerminator()); |
| 1423 | } |
| 1424 | |
| 1425 | SmallVector<Operation *> ForallOp::getCombiningOps(BlockArgument bbArg) { |
| 1426 | SmallVector<Operation *> storeOps; |
| 1427 | InParallelOp inParallelOp = getTerminator(); |
| 1428 | for (Operation &yieldOp : inParallelOp.getYieldingOps()) { |
| 1429 | if (auto parallelInsertSliceOp = |
| 1430 | dyn_cast<tensor::ParallelInsertSliceOp>(yieldOp); |
| 1431 | parallelInsertSliceOp && parallelInsertSliceOp.getDest() == bbArg) { |
| 1432 | storeOps.push_back(parallelInsertSliceOp); |
| 1433 | } |
| 1434 | } |
| 1435 | return storeOps; |
| 1436 | } |
| 1437 | |
| 1438 | std::optional<SmallVector<Value>> ForallOp::getLoopInductionVars() { |
| 1439 | return SmallVector<Value>{getBody()->getArguments().take_front(getRank())}; |
| 1440 | } |
| 1441 | |
| 1442 | // Get lower bounds as OpFoldResult. |
| 1443 | std::optional<SmallVector<OpFoldResult>> ForallOp::getLoopLowerBounds() { |
| 1444 | Builder b(getOperation()->getContext()); |
| 1445 | return getMixedValues(getStaticLowerBound(), getDynamicLowerBound(), b); |
| 1446 | } |
| 1447 | |
| 1448 | // Get upper bounds as OpFoldResult. |
| 1449 | std::optional<SmallVector<OpFoldResult>> ForallOp::getLoopUpperBounds() { |
| 1450 | Builder b(getOperation()->getContext()); |
| 1451 | return getMixedValues(getStaticUpperBound(), getDynamicUpperBound(), b); |
| 1452 | } |
| 1453 | |
| 1454 | // Get steps as OpFoldResult. |
| 1455 | std::optional<SmallVector<OpFoldResult>> ForallOp::getLoopSteps() { |
| 1456 | Builder b(getOperation()->getContext()); |
| 1457 | return getMixedValues(getStaticStep(), getDynamicStep(), b); |
| 1458 | } |
| 1459 | |
| 1460 | ForallOp mlir::scf::getForallOpThreadIndexOwner(Value val) { |
| 1461 |   auto tidxArg = llvm::dyn_cast<BlockArgument>(val);
| 1462 | if (!tidxArg) |
| 1463 | return ForallOp(); |
| 1464 | assert(tidxArg.getOwner() && "unlinked block argument" ); |
| 1465 | auto *containingOp = tidxArg.getOwner()->getParentOp(); |
| 1466 | return dyn_cast<ForallOp>(containingOp); |
| 1467 | } |
| 1468 | |
| 1469 | namespace { |
| 1470 | /// Fold tensor.dim(forall shared_outs(... = %t)) to tensor.dim(%t). |
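/// Example (a sketch; the loop body and slicing details are elided):
///   %0 = scf.forall ... shared_outs(%arg = %t) -> (tensor<?xf32>) { ... }
///   %d = tensor.dim %0, %c0 : tensor<?xf32>
/// The dim op is updated in place to read from %t instead of %0, since the
/// shared_outs operand and the corresponding result have the same shape.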
| 1471 | struct DimOfForallOp : public OpRewritePattern<tensor::DimOp> { |
| 1472 | using OpRewritePattern<tensor::DimOp>::OpRewritePattern; |
| 1473 | |
| 1474 | LogicalResult matchAndRewrite(tensor::DimOp dimOp, |
| 1475 | PatternRewriter &rewriter) const final { |
| 1476 | auto forallOp = dimOp.getSource().getDefiningOp<ForallOp>(); |
| 1477 | if (!forallOp) |
| 1478 | return failure(); |
| 1479 | Value sharedOut = |
| 1480 | forallOp.getTiedOpOperand(llvm::cast<OpResult>(dimOp.getSource())) |
| 1481 | ->get(); |
| 1482 | rewriter.modifyOpInPlace( |
| 1483 | dimOp, [&]() { dimOp.getSourceMutable().assign(sharedOut); }); |
| 1484 | return success(); |
| 1485 | } |
| 1486 | }; |
| 1487 | |
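/// Fold dynamic loop bounds and steps that are defined by constants into their
/// static (attribute) counterparts. For example (a sketch):
///   %c64 = arith.constant 64 : index
///   scf.forall (%i) in (%c64) { ... }
/// becomes
///   scf.forall (%i) in (64) { ... }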
| 1488 | class ForallOpControlOperandsFolder : public OpRewritePattern<ForallOp> { |
| 1489 | public: |
| 1490 | using OpRewritePattern<ForallOp>::OpRewritePattern; |
| 1491 | |
| 1492 | LogicalResult matchAndRewrite(ForallOp op, |
| 1493 | PatternRewriter &rewriter) const override { |
| 1494 | SmallVector<OpFoldResult> mixedLowerBound(op.getMixedLowerBound()); |
| 1495 | SmallVector<OpFoldResult> mixedUpperBound(op.getMixedUpperBound()); |
| 1496 | SmallVector<OpFoldResult> mixedStep(op.getMixedStep()); |
| 1497 |     if (failed(foldDynamicIndexList(mixedLowerBound)) &&
| 1498 |         failed(foldDynamicIndexList(mixedUpperBound)) &&
| 1499 |         failed(foldDynamicIndexList(mixedStep)))
| 1500 | return failure(); |
| 1501 | |
| 1502 | rewriter.modifyOpInPlace(op, [&]() { |
| 1503 | SmallVector<Value> dynamicLowerBound, dynamicUpperBound, dynamicStep; |
| 1504 | SmallVector<int64_t> staticLowerBound, staticUpperBound, staticStep; |
| 1505 |       dispatchIndexOpFoldResults(mixedLowerBound, dynamicLowerBound,
| 1506 |                                  staticLowerBound);
| 1507 | op.getDynamicLowerBoundMutable().assign(dynamicLowerBound); |
| 1508 | op.setStaticLowerBound(staticLowerBound); |
| 1509 | |
| 1510 |       dispatchIndexOpFoldResults(mixedUpperBound, dynamicUpperBound,
| 1511 |                                  staticUpperBound);
| 1512 | op.getDynamicUpperBoundMutable().assign(dynamicUpperBound); |
| 1513 | op.setStaticUpperBound(staticUpperBound); |
| 1514 | |
| 1515 |       dispatchIndexOpFoldResults(mixedStep, dynamicStep, staticStep);
| 1516 | op.getDynamicStepMutable().assign(dynamicStep); |
| 1517 | op.setStaticStep(staticStep); |
| 1518 | |
| 1519 | op->setAttr(ForallOp::getOperandSegmentSizeAttr(), |
| 1520 | rewriter.getDenseI32ArrayAttr( |
| 1521 | {static_cast<int32_t>(dynamicLowerBound.size()), |
| 1522 | static_cast<int32_t>(dynamicUpperBound.size()), |
| 1523 | static_cast<int32_t>(dynamicStep.size()), |
| 1524 | static_cast<int32_t>(op.getNumResults())})); |
| 1525 | }); |
| 1526 | return success(); |
| 1527 | } |
| 1528 | }; |
| 1529 | |
| 1530 | /// The following canonicalization pattern folds the iter arguments of |
| 1531 | /// scf.forall op if :- |
| 1532 | /// 1. The corresponding result has zero uses. |
| 1533 | /// 2. The iter argument is NOT being modified within the loop body. |
| 1535 | /// |
| 1536 | /// Example of first case :- |
| 1537 | /// INPUT: |
| 1538 | /// %res:3 = scf.forall ... shared_outs(%arg0 = %a, %arg1 = %b, %arg2 = %c) |
| 1539 | /// { |
| 1540 | /// ... |
| 1541 | /// <SOME USE OF %arg0> |
| 1542 | /// <SOME USE OF %arg1> |
| 1543 | /// <SOME USE OF %arg2> |
| 1544 | /// ... |
| 1545 | /// scf.forall.in_parallel { |
| 1546 | /// <STORE OP WITH DESTINATION %arg1> |
| 1547 | /// <STORE OP WITH DESTINATION %arg0> |
| 1548 | /// <STORE OP WITH DESTINATION %arg2> |
| 1549 | /// } |
| 1550 | /// } |
| 1551 | /// return %res#1 |
| 1552 | /// |
| 1553 | /// OUTPUT: |
| 1554 | /// %res:3 = scf.forall ... shared_outs(%new_arg0 = %b) |
| 1555 | /// { |
| 1556 | /// ... |
| 1557 | /// <SOME USE OF %a> |
| 1558 | /// <SOME USE OF %new_arg0> |
| 1559 | /// <SOME USE OF %c> |
| 1560 | /// ... |
| 1561 | /// scf.forall.in_parallel { |
| 1562 | /// <STORE OP WITH DESTINATION %new_arg0> |
| 1563 | /// } |
| 1564 | /// } |
| 1565 | /// return %res |
| 1566 | /// |
| 1567 | /// NOTE: 1. All uses of the folded shared_outs (iter argument) within the |
| 1568 | ///          scf.forall are replaced by their corresponding operands.
| 1569 | /// 2. Even if there are <STORE OP WITH DESTINATION *> ops within the body |
| 1570 | /// of the scf.forall besides within scf.forall.in_parallel terminator, |
| 1571 | /// this canonicalization remains valid. For more details, please refer |
| 1572 | /// to : |
| 1573 | /// https://github.com/llvm/llvm-project/pull/90189#discussion_r1589011124 |
| 1574 | /// 3. TODO(avarma): Generalize it for other store ops. Currently it |
| 1575 | /// handles tensor.parallel_insert_slice ops only. |
| 1576 | /// |
| 1577 | /// Example of second case :- |
| 1578 | /// INPUT: |
| 1579 | /// %res:2 = scf.forall ... shared_outs(%arg0 = %a, %arg1 = %b) |
| 1580 | /// { |
| 1581 | /// ... |
| 1582 | /// <SOME USE OF %arg0> |
| 1583 | /// <SOME USE OF %arg1> |
| 1584 | /// ... |
| 1585 | /// scf.forall.in_parallel { |
| 1586 | /// <STORE OP WITH DESTINATION %arg1> |
| 1587 | /// } |
| 1588 | /// } |
| 1589 | /// return %res#0, %res#1 |
| 1590 | /// |
| 1591 | /// OUTPUT: |
| 1592 | /// %res = scf.forall ... shared_outs(%new_arg0 = %b) |
| 1593 | /// { |
| 1594 | /// ... |
| 1595 | /// <SOME USE OF %a> |
| 1596 | /// <SOME USE OF %new_arg0> |
| 1597 | /// ... |
| 1598 | /// scf.forall.in_parallel { |
| 1599 | /// <STORE OP WITH DESTINATION %new_arg0> |
| 1600 | /// } |
| 1601 | /// } |
| 1602 | /// return %a, %res |
| 1603 | struct ForallOpIterArgsFolder : public OpRewritePattern<ForallOp> { |
| 1604 | using OpRewritePattern<ForallOp>::OpRewritePattern; |
| 1605 | |
| 1606 | LogicalResult matchAndRewrite(ForallOp forallOp, |
| 1607 | PatternRewriter &rewriter) const final { |
| 1608 | // Step 1: For a given i-th result of scf.forall, check the following :- |
| 1609 | // a. If it has any use. |
| 1610 | // b. If the corresponding iter argument is being modified within |
| 1611 | // the loop, i.e. has at least one store op with the iter arg as |
| 1612 | // its destination operand. For this we use |
| 1613 | // ForallOp::getCombiningOps(iter_arg). |
| 1614 | // |
| 1615 | // Based on the check we maintain the following :- |
| 1616 | // a. `resultToDelete` - i-th result of scf.forall that'll be |
| 1617 | // deleted. |
| 1618 | // b. `resultToReplace` - i-th result of the old scf.forall |
| 1619 | // whose uses will be replaced by the new scf.forall. |
| 1620 | // c. `newOuts` - the shared_outs' operand of the new scf.forall |
| 1621 | // corresponding to the i-th result with at least one use. |
| 1622 | SetVector<OpResult> resultToDelete; |
| 1623 | SmallVector<Value> resultToReplace; |
| 1624 | SmallVector<Value> newOuts; |
| 1625 | for (OpResult result : forallOp.getResults()) { |
| 1626 | OpOperand *opOperand = forallOp.getTiedOpOperand(result); |
| 1627 | BlockArgument blockArg = forallOp.getTiedBlockArgument(opOperand); |
| 1628 | if (result.use_empty() || forallOp.getCombiningOps(blockArg).empty()) { |
| 1629 | resultToDelete.insert(result); |
| 1630 | } else { |
| 1631 | resultToReplace.push_back(result); |
| 1632 | newOuts.push_back(opOperand->get()); |
| 1633 | } |
| 1634 | } |
| 1635 | |
| 1636 |     // Return early if all results of scf.forall have at least one use and are
| 1637 |     // being modified within the loop.
| 1638 | if (resultToDelete.empty()) |
| 1639 | return failure(); |
| 1640 | |
| 1641 |     // Step 2: For the i-th result, do the following :-
| 1642 | // a. Fetch the corresponding BlockArgument. |
| 1643 | // b. Look for store ops (currently tensor.parallel_insert_slice) |
| 1644 | // with the BlockArgument as its destination operand. |
| 1645 | // c. Remove the operations fetched in b. |
| 1646 | for (OpResult result : resultToDelete) { |
| 1647 | OpOperand *opOperand = forallOp.getTiedOpOperand(result); |
| 1648 | BlockArgument blockArg = forallOp.getTiedBlockArgument(opOperand); |
| 1649 | SmallVector<Operation *> combiningOps = |
| 1650 | forallOp.getCombiningOps(blockArg); |
| 1651 | for (Operation *combiningOp : combiningOps) |
| 1652 | rewriter.eraseOp(combiningOp); |
| 1653 | } |
| 1654 | |
| 1655 | // Step 3. Create a new scf.forall op with the new shared_outs' operands |
| 1656 | // fetched earlier |
| 1657 | auto newForallOp = rewriter.create<scf::ForallOp>( |
| 1658 | forallOp.getLoc(), forallOp.getMixedLowerBound(), |
| 1659 | forallOp.getMixedUpperBound(), forallOp.getMixedStep(), newOuts, |
| 1660 | forallOp.getMapping(), |
| 1661 | /*bodyBuilderFn =*/[](OpBuilder &, Location, ValueRange) {}); |
| 1662 | |
| 1663 | // Step 4. Merge the block of the old scf.forall into the newly created |
| 1664 | // scf.forall using the new set of arguments. |
| 1665 | Block *loopBody = forallOp.getBody(); |
| 1666 | Block *newLoopBody = newForallOp.getBody(); |
| 1667 | ArrayRef<BlockArgument> newBbArgs = newLoopBody->getArguments(); |
| 1668 | // Form initial new bbArg list with just the control operands of the new |
| 1669 | // scf.forall op. |
| 1670 | SmallVector<Value> newBlockArgs = |
| 1671 |         llvm::map_to_vector(newBbArgs.take_front(forallOp.getRank()),
| 1672 | [](BlockArgument b) -> Value { return b; }); |
| 1673 | Block::BlockArgListType newSharedOutsArgs = newForallOp.getRegionOutArgs(); |
| 1674 | unsigned index = 0; |
| 1675 | // Take the new corresponding bbArg if the old bbArg was used as a |
| 1676 | // destination in the in_parallel op. For all other bbArgs, use the |
| 1677 | // corresponding init_arg from the old scf.forall op. |
| 1678 | for (OpResult result : forallOp.getResults()) { |
| 1679 | if (resultToDelete.count(result)) { |
| 1680 | newBlockArgs.push_back(forallOp.getTiedOpOperand(result)->get()); |
| 1681 | } else { |
| 1682 | newBlockArgs.push_back(newSharedOutsArgs[index++]); |
| 1683 | } |
| 1684 | } |
| 1685 |     rewriter.mergeBlocks(loopBody, newLoopBody, newBlockArgs);
| 1686 | |
| 1687 |     // Step 5. Replace the uses of the results of the old scf.forall with those
| 1688 |     //         of the new scf.forall.
| 1689 | for (auto &&[oldResult, newResult] : |
| 1690 | llvm::zip(resultToReplace, newForallOp->getResults())) |
| 1691 | rewriter.replaceAllUsesWith(oldResult, newResult); |
| 1692 | |
| 1693 |     // Step 6. Replace the uses of those results that either have no use or are
| 1694 |     //         not being modified within the loop with the corresponding
| 1695 |     //         OpOperand.
| 1696 | for (OpResult oldResult : resultToDelete) |
| 1697 | rewriter.replaceAllUsesWith(oldResult, |
| 1698 | forallOp.getTiedOpOperand(oldResult)->get()); |
| 1699 | return success(); |
| 1700 | } |
| 1701 | }; |
| 1702 | |
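/// Fold away loop dimensions that perform at most one iteration. A dimension
/// with a zero trip count lets the whole op fold to its outputs; a dimension
/// with a single iteration is dropped and its induction variable is replaced
/// by the lower bound. For example (a sketch):
///   scf.forall (%i, %j) in (1, %n) { ... uses of %i ... }
/// becomes
///   scf.forall (%j) in (%n) { ... uses of %c0 ... }
/// Loops carrying a mapping attribute are not modified.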
| 1703 | struct ForallOpSingleOrZeroIterationDimsFolder |
| 1704 | : public OpRewritePattern<ForallOp> { |
| 1705 | using OpRewritePattern<ForallOp>::OpRewritePattern; |
| 1706 | |
| 1707 | LogicalResult matchAndRewrite(ForallOp op, |
| 1708 | PatternRewriter &rewriter) const override { |
| 1709 | // Do not fold dimensions if they are mapped to processing units. |
| 1710 | if (op.getMapping().has_value() && !op.getMapping()->empty()) |
| 1711 | return failure(); |
| 1712 | Location loc = op.getLoc(); |
| 1713 | |
| 1714 | // Compute new loop bounds that omit all single-iteration loop dimensions. |
| 1715 | SmallVector<OpFoldResult> newMixedLowerBounds, newMixedUpperBounds, |
| 1716 | newMixedSteps; |
| 1717 | IRMapping mapping; |
| 1718 | for (auto [lb, ub, step, iv] : |
| 1719 | llvm::zip(op.getMixedLowerBound(), op.getMixedUpperBound(), |
| 1720 | op.getMixedStep(), op.getInductionVars())) { |
| 1721 | auto numIterations = constantTripCount(lb, ub, step); |
| 1722 | if (numIterations.has_value()) { |
| 1723 | // Remove the loop if it performs zero iterations. |
| 1724 | if (*numIterations == 0) { |
| 1725 | rewriter.replaceOp(op, op.getOutputs()); |
| 1726 | return success(); |
| 1727 | } |
| 1728 | // Replace the loop induction variable by the lower bound if the loop |
| 1729 | // performs a single iteration. Otherwise, copy the loop bounds. |
| 1730 | if (*numIterations == 1) { |
| 1731 | mapping.map(iv, getValueOrCreateConstantIndexOp(rewriter, loc, lb)); |
| 1732 | continue; |
| 1733 | } |
| 1734 | } |
| 1735 | newMixedLowerBounds.push_back(lb); |
| 1736 | newMixedUpperBounds.push_back(ub); |
| 1737 | newMixedSteps.push_back(step); |
| 1738 | } |
| 1739 | |
| 1740 | // All of the loop dimensions perform a single iteration. Inline loop body. |
| 1741 | if (newMixedLowerBounds.empty()) { |
| 1742 | promote(rewriter, op); |
| 1743 | return success(); |
| 1744 | } |
| 1745 | |
| 1746 | // Exit if none of the loop dimensions perform a single iteration. |
| 1747 | if (newMixedLowerBounds.size() == static_cast<unsigned>(op.getRank())) { |
| 1748 | return rewriter.notifyMatchFailure( |
| 1749 | op, "no dimensions have 0 or 1 iterations" ); |
| 1750 | } |
| 1751 | |
| 1752 | // Replace the loop by a lower-dimensional loop. |
| 1753 | ForallOp newOp; |
| 1754 | newOp = rewriter.create<ForallOp>(loc, newMixedLowerBounds, |
| 1755 | newMixedUpperBounds, newMixedSteps, |
| 1756 | op.getOutputs(), std::nullopt, nullptr); |
| 1757 | newOp.getBodyRegion().getBlocks().clear(); |
| 1758 | // The new loop needs to keep all attributes from the old one, except for |
| 1759 | // "operandSegmentSizes" and static loop bound attributes which capture |
| 1760 | // the outdated information of the old iteration domain. |
| 1761 | SmallVector<StringAttr> elidedAttrs{newOp.getOperandSegmentSizesAttrName(), |
| 1762 | newOp.getStaticLowerBoundAttrName(), |
| 1763 | newOp.getStaticUpperBoundAttrName(), |
| 1764 | newOp.getStaticStepAttrName()}; |
| 1765 | for (const auto &namedAttr : op->getAttrs()) { |
| 1766 | if (llvm::is_contained(elidedAttrs, namedAttr.getName())) |
| 1767 | continue; |
| 1768 | rewriter.modifyOpInPlace(newOp, [&]() { |
| 1769 | newOp->setAttr(namedAttr.getName(), namedAttr.getValue()); |
| 1770 | }); |
| 1771 | } |
| 1772 | rewriter.cloneRegionBefore(op.getRegion(), newOp.getRegion(), |
| 1773 | newOp.getRegion().begin(), mapping); |
| 1774 | rewriter.replaceOp(op, newOp.getResults()); |
| 1775 | return success(); |
| 1776 | } |
| 1777 | }; |
| 1778 | |
| 1779 | /// Replace induction variables of unit-trip-count dimensions with their lower bound.
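/// For example (a sketch), given
///   scf.forall (%i, %j) in (1, %n) { ... uses of %i ... }
/// every use of %i in the body is replaced by a constant 0 while the loop
/// structure itself is left untouched.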
| 1780 | struct ForallOpReplaceConstantInductionVar : public OpRewritePattern<ForallOp> { |
| 1781 | using OpRewritePattern<ForallOp>::OpRewritePattern; |
| 1782 | |
| 1783 | LogicalResult matchAndRewrite(ForallOp op, |
| 1784 | PatternRewriter &rewriter) const override { |
| 1785 | Location loc = op.getLoc(); |
| 1786 | bool changed = false; |
| 1787 | for (auto [lb, ub, step, iv] : |
| 1788 | llvm::zip(op.getMixedLowerBound(), op.getMixedUpperBound(), |
| 1789 | op.getMixedStep(), op.getInductionVars())) { |
| 1790 | if (iv.hasNUses(0)) |
| 1791 | continue; |
| 1792 | auto numIterations = constantTripCount(lb, ub, step); |
| 1793 | if (!numIterations.has_value() || numIterations.value() != 1) { |
| 1794 | continue; |
| 1795 | } |
| 1796 | rewriter.replaceAllUsesWith( |
| 1797 | iv, getValueOrCreateConstantIndexOp(rewriter, loc, lb)); |
| 1798 | changed = true; |
| 1799 | } |
| 1800 |     return success(changed);
| 1801 | } |
| 1802 | }; |
| 1803 | |
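/// Fold a tensor.cast feeding a shared_outs operand into the loop when the
/// cast source is "more static" than its result. For example (a sketch):
///   %cast = tensor.cast %src : tensor<4xf32> to tensor<?xf32>
///   %res = scf.forall ... shared_outs(%o = %cast) -> (tensor<?xf32>) { ... }
/// is rewritten to iterate on %src directly, with compensating casts inserted
/// on the block argument inside the body and on the results, so all existing
/// users keep seeing the original types.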
| 1804 | struct FoldTensorCastOfOutputIntoForallOp |
| 1805 | : public OpRewritePattern<scf::ForallOp> { |
| 1806 | using OpRewritePattern<scf::ForallOp>::OpRewritePattern; |
| 1807 | |
| 1808 | struct TypeCast { |
| 1809 | Type srcType; |
| 1810 | Type dstType; |
| 1811 | }; |
| 1812 | |
| 1813 | LogicalResult matchAndRewrite(scf::ForallOp forallOp, |
| 1814 | PatternRewriter &rewriter) const final { |
| 1815 | llvm::SmallMapVector<unsigned, TypeCast, 2> tensorCastProducers; |
| 1816 | llvm::SmallVector<Value> newOutputTensors = forallOp.getOutputs(); |
| 1817 | for (auto en : llvm::enumerate(newOutputTensors)) { |
| 1818 | auto castOp = en.value().getDefiningOp<tensor::CastOp>(); |
| 1819 | if (!castOp) |
| 1820 | continue; |
| 1821 | |
| 1822 |       // Only casts that preserve static information, i.e. casts that make the
| 1823 |       // loop result type "more" static than before, will be folded.
| 1824 | if (!tensor::preservesStaticInformation(castOp.getDest().getType(), |
| 1825 | castOp.getSource().getType())) { |
| 1826 | continue; |
| 1827 | } |
| 1828 | |
| 1829 | tensorCastProducers[en.index()] = |
| 1830 | TypeCast{castOp.getSource().getType(), castOp.getType()}; |
| 1831 | newOutputTensors[en.index()] = castOp.getSource(); |
| 1832 | } |
| 1833 | |
| 1834 | if (tensorCastProducers.empty()) |
| 1835 | return failure(); |
| 1836 | |
| 1837 | // Create new loop. |
| 1838 | Location loc = forallOp.getLoc(); |
| 1839 | auto newForallOp = rewriter.create<ForallOp>( |
| 1840 | loc, forallOp.getMixedLowerBound(), forallOp.getMixedUpperBound(), |
| 1841 | forallOp.getMixedStep(), newOutputTensors, forallOp.getMapping(), |
| 1842 | [&](OpBuilder nestedBuilder, Location nestedLoc, ValueRange bbArgs) { |
| 1843 | auto castBlockArgs = |
| 1844 | llvm::to_vector(bbArgs.take_back(forallOp->getNumResults())); |
| 1845 | for (auto [index, cast] : tensorCastProducers) { |
| 1846 | Value &oldTypeBBArg = castBlockArgs[index]; |
| 1847 | oldTypeBBArg = nestedBuilder.create<tensor::CastOp>( |
| 1848 | nestedLoc, cast.dstType, oldTypeBBArg); |
| 1849 | } |
| 1850 | |
| 1851 | // Move old body into new parallel loop. |
| 1852 | SmallVector<Value> ivsBlockArgs = |
| 1853 | llvm::to_vector(bbArgs.take_front(forallOp.getRank())); |
| 1854 | ivsBlockArgs.append(castBlockArgs); |
| 1855 | rewriter.mergeBlocks(forallOp.getBody(), |
| 1856 | bbArgs.front().getParentBlock(), ivsBlockArgs); |
| 1857 | }); |
| 1858 | |
| 1859 |     // After `mergeBlocks`, the destinations in the terminator are mapped to the
| 1860 |     // old-typed tensor.cast values of the output bbArgs. The destinations have
| 1861 |     // to be updated to point to the output bbArgs directly.
| 1862 | auto terminator = newForallOp.getTerminator(); |
| 1863 | for (auto [yieldingOp, outputBlockArg] : llvm::zip( |
| 1864 | terminator.getYieldingOps(), newForallOp.getRegionIterArgs())) { |
| 1865 | auto insertSliceOp = cast<tensor::ParallelInsertSliceOp>(yieldingOp); |
| 1866 | insertSliceOp.getDestMutable().assign(outputBlockArg); |
| 1867 | } |
| 1868 | |
| 1869 | // Cast results back to the original types. |
| 1870 | rewriter.setInsertionPointAfter(newForallOp); |
| 1871 | SmallVector<Value> castResults = newForallOp.getResults(); |
| 1872 | for (auto &item : tensorCastProducers) { |
| 1873 | Value &oldTypeResult = castResults[item.first]; |
| 1874 | oldTypeResult = rewriter.create<tensor::CastOp>(loc, item.second.dstType, |
| 1875 | oldTypeResult); |
| 1876 | } |
| 1877 | rewriter.replaceOp(forallOp, castResults); |
| 1878 | return success(); |
| 1879 | } |
| 1880 | }; |
| 1881 | |
| 1882 | } // namespace |
| 1883 | |
| 1884 | void ForallOp::getCanonicalizationPatterns(RewritePatternSet &results, |
| 1885 | MLIRContext *context) { |
| 1886 | results.add<DimOfForallOp, FoldTensorCastOfOutputIntoForallOp, |
| 1887 | ForallOpControlOperandsFolder, ForallOpIterArgsFolder, |
| 1888 | ForallOpSingleOrZeroIterationDimsFolder, |
| 1889 | ForallOpReplaceConstantInductionVar>(context); |
| 1890 | } |
| 1891 | |
| 1892 | /// Given the region at `index`, or the parent operation if `index` is None, |
| 1893 | /// return the successor regions. These are the regions that may be selected |
| 1894 | /// during the flow of control. `operands` is a set of optional attributes that |
| 1895 | /// correspond to a constant value for each operand, or null if that operand is |
| 1896 | /// not a constant. |
| 1897 | void ForallOp::getSuccessorRegions(RegionBranchPoint point, |
| 1898 | SmallVectorImpl<RegionSuccessor> ®ions) { |
| 1899 | // Both the operation itself and the region may be branching into the body or |
| 1900 |   // back into the operation itself. It is possible for the loop not to enter
| 1901 |   // the body.
| 1902 | regions.push_back(RegionSuccessor(&getRegion())); |
| 1903 | regions.push_back(RegionSuccessor()); |
| 1904 | } |
| 1905 | |
| 1906 | //===----------------------------------------------------------------------===// |
| 1907 | // InParallelOp |
| 1908 | //===----------------------------------------------------------------------===// |
| 1909 | |
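// Illustrative context for the terminator (a sketch; names are made up):
//   scf.forall (%i) in (%n) shared_outs(%o = %init) -> (tensor<?xf32>) {
//     %slice = ...
//     scf.forall.in_parallel {
//       tensor.parallel_insert_slice %slice into %o[%i] [1] [1]
//           : tensor<1xf32> into tensor<?xf32>
//     }
//   }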
| 1910 | // Build an InParallelOp with an empty body block.
| 1911 | void InParallelOp::build(OpBuilder &b, OperationState &result) { |
| 1912 | OpBuilder::InsertionGuard g(b); |
| 1913 | Region *bodyRegion = result.addRegion(); |
| 1914 | b.createBlock(bodyRegion); |
| 1915 | } |
| 1916 | |
| 1917 | LogicalResult InParallelOp::verify() { |
| 1918 | scf::ForallOp forallOp = |
| 1919 | dyn_cast<scf::ForallOp>(getOperation()->getParentOp()); |
| 1920 | if (!forallOp) |
| 1921 | return this->emitOpError("expected forall op parent" ); |
| 1922 | |
| 1923 | // TODO: InParallelOpInterface. |
| 1924 | for (Operation &op : getRegion().front().getOperations()) { |
| 1925 | if (!isa<tensor::ParallelInsertSliceOp>(op)) { |
| 1926 | return this->emitOpError("expected only " ) |
| 1927 | << tensor::ParallelInsertSliceOp::getOperationName() << " ops" ; |
| 1928 | } |
| 1929 | |
| 1930 | // Verify that inserts are into out block arguments. |
| 1931 | Value dest = cast<tensor::ParallelInsertSliceOp>(op).getDest(); |
| 1932 | ArrayRef<BlockArgument> regionOutArgs = forallOp.getRegionOutArgs(); |
| 1933 | if (!llvm::is_contained(regionOutArgs, dest)) |
| 1934 | return op.emitOpError("may only insert into an output block argument" ); |
| 1935 | } |
| 1936 | return success(); |
| 1937 | } |
| 1938 | |
| 1939 | void InParallelOp::print(OpAsmPrinter &p) { |
| 1940 | p << " " ; |
| 1941 | p.printRegion(getRegion(), |
| 1942 | /*printEntryBlockArgs=*/false, |
| 1943 | /*printBlockTerminators=*/false); |
| 1944 | p.printOptionalAttrDict(getOperation()->getAttrs()); |
| 1945 | } |
| 1946 | |
| 1947 | ParseResult InParallelOp::parse(OpAsmParser &parser, OperationState &result) { |
| 1948 | auto &builder = parser.getBuilder(); |
| 1949 | |
| 1950 | SmallVector<OpAsmParser::Argument, 8> regionOperands; |
| 1951 | std::unique_ptr<Region> region = std::make_unique<Region>(); |
| 1952 | if (parser.parseRegion(*region, regionOperands)) |
| 1953 | return failure(); |
| 1954 | |
| 1955 | if (region->empty()) |
| 1956 | OpBuilder(builder.getContext()).createBlock(region.get()); |
| 1957 | result.addRegion(std::move(region)); |
| 1958 | |
| 1959 | // Parse the optional attribute list. |
| 1960 | if (parser.parseOptionalAttrDict(result.attributes)) |
| 1961 | return failure(); |
| 1962 | return success(); |
| 1963 | } |
| 1964 | |
| 1965 | OpResult InParallelOp::getParentResult(int64_t idx) { |
| 1966 | return getOperation()->getParentOp()->getResult(idx); |
| 1967 | } |
| 1968 | |
| 1969 | SmallVector<BlockArgument> InParallelOp::getDests() { |
| 1970 | return llvm::to_vector<4>( |
| 1971 | llvm::map_range(getYieldingOps(), [](Operation &op) { |
| 1972 | // Add new ops here as needed. |
| 1973 | auto insertSliceOp = cast<tensor::ParallelInsertSliceOp>(&op); |
| 1974 | return llvm::cast<BlockArgument>(insertSliceOp.getDest()); |
| 1975 | })); |
| 1976 | } |
| 1977 | |
| 1978 | llvm::iterator_range<Block::iterator> InParallelOp::getYieldingOps() { |
| 1979 | return getRegion().front().getOperations(); |
| 1980 | } |
| 1981 | |
| 1982 | //===----------------------------------------------------------------------===// |
| 1983 | // IfOp |
| 1984 | //===----------------------------------------------------------------------===// |
| 1985 | |
| 1986 | bool mlir::scf::insideMutuallyExclusiveBranches(Operation *a, Operation *b) { |
| 1987 | assert(a && "expected non-empty operation" ); |
| 1988 | assert(b && "expected non-empty operation" ); |
| 1989 | |
| 1990 | IfOp ifOp = a->getParentOfType<IfOp>(); |
| 1991 | while (ifOp) { |
| 1992 | // Check if b is inside ifOp. (We already know that a is.) |
| 1993 | if (ifOp->isProperAncestor(b)) |
| 1994 | // b is contained in ifOp. a and b are in mutually exclusive branches if |
| 1995 | // they are in different blocks of ifOp. |
| 1996 | return static_cast<bool>(ifOp.thenBlock()->findAncestorOpInBlock(*a)) != |
| 1997 | static_cast<bool>(ifOp.thenBlock()->findAncestorOpInBlock(*b)); |
| 1998 | // Check next enclosing IfOp. |
| 1999 | ifOp = ifOp->getParentOfType<IfOp>(); |
| 2000 | } |
| 2001 | |
| 2002 | // Could not find a common IfOp among a's and b's ancestors. |
| 2003 | return false; |
| 2004 | } |
| 2005 | |
| 2006 | LogicalResult |
| 2007 | IfOp::inferReturnTypes(MLIRContext *ctx, std::optional<Location> loc, |
| 2008 | IfOp::Adaptor adaptor, |
| 2009 | SmallVectorImpl<Type> &inferredReturnTypes) { |
| 2010 | if (adaptor.getRegions().empty()) |
| 2011 | return failure(); |
| 2012 | Region *r = &adaptor.getThenRegion(); |
| 2013 | if (r->empty()) |
| 2014 | return failure(); |
| 2015 | Block &b = r->front(); |
| 2016 | if (b.empty()) |
| 2017 | return failure(); |
| 2018 | auto yieldOp = llvm::dyn_cast<YieldOp>(b.back()); |
| 2019 | if (!yieldOp) |
| 2020 | return failure(); |
| 2021 | TypeRange types = yieldOp.getOperandTypes(); |
| 2022 | llvm::append_range(inferredReturnTypes, types); |
| 2023 | return success(); |
| 2024 | } |
| 2025 | |
| 2026 | void IfOp::build(OpBuilder &builder, OperationState &result, |
| 2027 | TypeRange resultTypes, Value cond) { |
| 2028 | return build(builder, result, resultTypes, cond, /*addThenBlock=*/false, |
| 2029 | /*addElseBlock=*/false); |
| 2030 | } |
| 2031 | |
| 2032 | void IfOp::build(OpBuilder &builder, OperationState &result, |
| 2033 | TypeRange resultTypes, Value cond, bool addThenBlock, |
| 2034 | bool addElseBlock) { |
| 2035 | assert((!addElseBlock || addThenBlock) && |
| 2036 | "must not create else block w/o then block" ); |
| 2037 | result.addTypes(resultTypes); |
| 2038 | result.addOperands(cond); |
| 2039 | |
| 2040 | // Add regions and blocks. |
| 2041 | OpBuilder::InsertionGuard guard(builder); |
| 2042 | Region *thenRegion = result.addRegion(); |
| 2043 | if (addThenBlock) |
| 2044 | builder.createBlock(thenRegion); |
| 2045 | Region *elseRegion = result.addRegion(); |
| 2046 | if (addElseBlock) |
| 2047 | builder.createBlock(elseRegion); |
| 2048 | } |
| 2049 | |
| 2050 | void IfOp::build(OpBuilder &builder, OperationState &result, Value cond, |
| 2051 | bool withElseRegion) { |
| 2052 | build(builder, result, TypeRange{}, cond, withElseRegion); |
| 2053 | } |
| 2054 | |
| 2055 | void IfOp::build(OpBuilder &builder, OperationState &result, |
| 2056 | TypeRange resultTypes, Value cond, bool withElseRegion) { |
| 2057 | result.addTypes(resultTypes); |
| 2058 | result.addOperands(cond); |
| 2059 | |
| 2060 | // Build then region. |
| 2061 | OpBuilder::InsertionGuard guard(builder); |
| 2062 | Region *thenRegion = result.addRegion(); |
| 2063 | builder.createBlock(thenRegion); |
| 2064 | if (resultTypes.empty()) |
| 2065 | IfOp::ensureTerminator(*thenRegion, builder, result.location); |
| 2066 | |
| 2067 | // Build else region. |
| 2068 | Region *elseRegion = result.addRegion(); |
| 2069 | if (withElseRegion) { |
| 2070 | builder.createBlock(elseRegion); |
| 2071 | if (resultTypes.empty()) |
| 2072 | IfOp::ensureTerminator(*elseRegion, builder, result.location); |
| 2073 | } |
| 2074 | } |
| 2075 | |
| 2076 | void IfOp::build(OpBuilder &builder, OperationState &result, Value cond, |
| 2077 | function_ref<void(OpBuilder &, Location)> thenBuilder, |
| 2078 | function_ref<void(OpBuilder &, Location)> elseBuilder) { |
| 2079 | assert(thenBuilder && "the builder callback for 'then' must be present" ); |
| 2080 | result.addOperands(cond); |
| 2081 | |
| 2082 | // Build then region. |
| 2083 | OpBuilder::InsertionGuard guard(builder); |
| 2084 | Region *thenRegion = result.addRegion(); |
| 2085 | builder.createBlock(thenRegion); |
| 2086 | thenBuilder(builder, result.location); |
| 2087 | |
| 2088 | // Build else region. |
| 2089 | Region *elseRegion = result.addRegion(); |
| 2090 | if (elseBuilder) { |
| 2091 | builder.createBlock(elseRegion); |
| 2092 | elseBuilder(builder, result.location); |
| 2093 | } |
| 2094 | |
| 2095 | // Infer result types. |
| 2096 | SmallVector<Type> inferredReturnTypes; |
| 2097 | MLIRContext *ctx = builder.getContext(); |
| 2098 | auto attrDict = DictionaryAttr::get(ctx, result.attributes); |
| 2099 | if (succeeded(inferReturnTypes(ctx, std::nullopt, result.operands, attrDict, |
| 2100 | /*properties=*/nullptr, result.regions, |
| 2101 | inferredReturnTypes))) { |
| 2102 | result.addTypes(inferredReturnTypes); |
| 2103 | } |
| 2104 | } |
| 2105 | |
| 2106 | LogicalResult IfOp::verify() { |
| 2107 | if (getNumResults() != 0 && getElseRegion().empty()) |
| 2108 | return emitOpError("must have an else block if defining values" ); |
| 2109 | return success(); |
| 2110 | } |
| 2111 | |
| 2112 | ParseResult IfOp::parse(OpAsmParser &parser, OperationState &result) { |
| 2113 | // Create the regions for 'then'. |
| 2114 | result.regions.reserve(2); |
| 2115 | Region *thenRegion = result.addRegion(); |
| 2116 | Region *elseRegion = result.addRegion(); |
| 2117 | |
| 2118 | auto &builder = parser.getBuilder(); |
| 2119 | OpAsmParser::UnresolvedOperand cond; |
| 2120 | Type i1Type = builder.getIntegerType(1); |
| 2121 | if (parser.parseOperand(cond) || |
| 2122 | parser.resolveOperand(cond, i1Type, result.operands)) |
| 2123 | return failure(); |
| 2124 | // Parse optional results type list. |
| 2125 | if (parser.parseOptionalArrowTypeList(result.types)) |
| 2126 | return failure(); |
| 2127 | // Parse the 'then' region. |
| 2128 | if (parser.parseRegion(*thenRegion, /*arguments=*/{}, /*argTypes=*/{})) |
| 2129 | return failure(); |
| 2130 | IfOp::ensureTerminator(*thenRegion, parser.getBuilder(), result.location); |
| 2131 | |
| 2132 | // If we find an 'else' keyword then parse the 'else' region. |
| 2133 | if (!parser.parseOptionalKeyword("else" )) { |
| 2134 | if (parser.parseRegion(*elseRegion, /*arguments=*/{}, /*argTypes=*/{})) |
| 2135 | return failure(); |
| 2136 | IfOp::ensureTerminator(*elseRegion, parser.getBuilder(), result.location); |
| 2137 | } |
| 2138 | |
| 2139 | // Parse the optional attribute list. |
| 2140 | if (parser.parseOptionalAttrDict(result.attributes)) |
| 2141 | return failure(); |
| 2142 | return success(); |
| 2143 | } |
| 2144 | |
| 2145 | void IfOp::print(OpAsmPrinter &p) { |
| 2146 | bool printBlockTerminators = false; |
| 2147 | |
| 2148 | p << " " << getCondition(); |
| 2149 | if (!getResults().empty()) { |
| 2150 | p << " -> (" << getResultTypes() << ")" ; |
| 2151 | // Print yield explicitly if the op defines values. |
| 2152 | printBlockTerminators = true; |
| 2153 | } |
| 2154 | p << ' '; |
| 2155 | p.printRegion(getThenRegion(), |
| 2156 | /*printEntryBlockArgs=*/false, |
| 2157 | /*printBlockTerminators=*/printBlockTerminators); |
| 2158 | |
| 2159 |   // Print the 'else' region if it exists and has a block.
| 2160 | auto &elseRegion = getElseRegion(); |
| 2161 | if (!elseRegion.empty()) { |
| 2162 | p << " else " ; |
| 2163 | p.printRegion(elseRegion, |
| 2164 | /*printEntryBlockArgs=*/false, |
| 2165 | /*printBlockTerminators=*/printBlockTerminators); |
| 2166 | } |
| 2167 | |
| 2168 | p.printOptionalAttrDict((*this)->getAttrs()); |
| 2169 | } |
| 2170 | |
| 2171 | void IfOp::getSuccessorRegions(RegionBranchPoint point, |
| 2172 | SmallVectorImpl<RegionSuccessor> ®ions) { |
| 2173 | // The `then` and the `else` region branch back to the parent operation. |
| 2174 | if (!point.isParent()) { |
| 2175 | regions.push_back(RegionSuccessor(getResults())); |
| 2176 | return; |
| 2177 | } |
| 2178 | |
| 2179 | regions.push_back(RegionSuccessor(&getThenRegion())); |
| 2180 | |
| 2181 | // Don't consider the else region if it is empty. |
| 2182 | Region *elseRegion = &this->getElseRegion(); |
| 2183 | if (elseRegion->empty()) |
| 2184 | regions.push_back(RegionSuccessor()); |
| 2185 | else |
| 2186 | regions.push_back(RegionSuccessor(elseRegion)); |
| 2187 | } |
| 2188 | |
| 2189 | void IfOp::getEntrySuccessorRegions(ArrayRef<Attribute> operands, |
| 2190 | SmallVectorImpl<RegionSuccessor> ®ions) { |
| 2191 | FoldAdaptor adaptor(operands, *this); |
| 2192 | auto boolAttr = dyn_cast_or_null<BoolAttr>(adaptor.getCondition()); |
| 2193 | if (!boolAttr || boolAttr.getValue()) |
| 2194 | regions.emplace_back(&getThenRegion()); |
| 2195 | |
| 2196 | // If the else region is empty, execution continues after the parent op. |
| 2197 | if (!boolAttr || !boolAttr.getValue()) { |
| 2198 | if (!getElseRegion().empty()) |
| 2199 | regions.emplace_back(&getElseRegion()); |
| 2200 | else |
| 2201 | regions.emplace_back(getResults()); |
| 2202 | } |
| 2203 | } |
| 2204 | |
| 2205 | LogicalResult IfOp::fold(FoldAdaptor adaptor, |
| 2206 | SmallVectorImpl<OpFoldResult> &results) { |
| 2207 | // if (!c) then A() else B() -> if c then B() else A() |
| 2208 | if (getElseRegion().empty()) |
| 2209 | return failure(); |
| 2210 | |
| 2211 | arith::XOrIOp xorStmt = getCondition().getDefiningOp<arith::XOrIOp>(); |
| 2212 | if (!xorStmt) |
| 2213 | return failure(); |
| 2214 | |
| 2215 | if (!matchPattern(xorStmt.getRhs(), m_One())) |
| 2216 | return failure(); |
| 2217 | |
| 2218 | getConditionMutable().assign(xorStmt.getLhs()); |
| 2219 | Block *thenBlock = &getThenRegion().front(); |
| 2220 |   // It would be nicer to use iplist::swap, but that has no implemented
| 2221 |   // callbacks. See: https://llvm.org/doxygen/ilist_8h_source.html#l00224
| 2222 | getThenRegion().getBlocks().splice(getThenRegion().getBlocks().begin(), |
| 2223 | getElseRegion().getBlocks()); |
| 2224 | getElseRegion().getBlocks().splice(getElseRegion().getBlocks().begin(), |
| 2225 | getThenRegion().getBlocks(), thenBlock); |
| 2226 | return success(); |
| 2227 | } |
| 2228 | |
| 2229 | void IfOp::getRegionInvocationBounds( |
| 2230 | ArrayRef<Attribute> operands, |
| 2231 | SmallVectorImpl<InvocationBounds> &invocationBounds) { |
| 2232 | if (auto cond = llvm::dyn_cast_or_null<BoolAttr>(operands[0])) { |
| 2233 | // If the condition is known, then one region is known to be executed once |
| 2234 | // and the other zero times. |
| 2235 | invocationBounds.emplace_back(0, cond.getValue() ? 1 : 0); |
| 2236 | invocationBounds.emplace_back(0, cond.getValue() ? 0 : 1); |
| 2237 | } else { |
| 2238 | // Non-constant condition. Each region may be executed 0 or 1 times. |
| 2239 | invocationBounds.assign(2, {0, 1}); |
| 2240 | } |
| 2241 | } |
| 2242 | |
| 2243 | namespace { |
| 2244 | // Pattern to remove unused IfOp results. |
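// Illustrative example (a sketch): for
//   %0:2 = scf.if %cond -> (i32, i32) { ... } else { ... }
// where only %0#1 has uses, the op is rebuilt with a single i32 result and the
// yields in both regions are trimmed to the used value.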
| 2245 | struct RemoveUnusedResults : public OpRewritePattern<IfOp> { |
| 2246 | using OpRewritePattern<IfOp>::OpRewritePattern; |
| 2247 | |
| 2248 | void transferBody(Block *source, Block *dest, ArrayRef<OpResult> usedResults, |
| 2249 | PatternRewriter &rewriter) const { |
| 2250 | // Move all operations to the destination block. |
| 2251 | rewriter.mergeBlocks(source, dest); |
| 2252 | // Replace the yield op by one that returns only the used values. |
| 2253 | auto yieldOp = cast<scf::YieldOp>(dest->getTerminator()); |
| 2254 | SmallVector<Value, 4> usedOperands; |
| 2255 |     llvm::transform(usedResults, std::back_inserter(usedOperands),
| 2256 |                     [&](OpResult result) {
| 2257 | return yieldOp.getOperand(result.getResultNumber()); |
| 2258 | }); |
| 2259 | rewriter.modifyOpInPlace(yieldOp, |
| 2260 | [&]() { yieldOp->setOperands(usedOperands); }); |
| 2261 | } |
| 2262 | |
| 2263 | LogicalResult matchAndRewrite(IfOp op, |
| 2264 | PatternRewriter &rewriter) const override { |
| 2265 | // Compute the list of used results. |
| 2266 | SmallVector<OpResult, 4> usedResults; |
| 2267 |     llvm::copy_if(op.getResults(), std::back_inserter(usedResults),
| 2268 | [](OpResult result) { return !result.use_empty(); }); |
| 2269 | |
| 2270 | // Replace the operation if only a subset of its results have uses. |
| 2271 | if (usedResults.size() == op.getNumResults()) |
| 2272 | return failure(); |
| 2273 | |
| 2274 | // Compute the result types of the replacement operation. |
| 2275 | SmallVector<Type, 4> newTypes; |
| 2276 |     llvm::transform(usedResults, std::back_inserter(newTypes),
| 2277 |                     [](OpResult result) { return result.getType(); });
| 2278 | |
| 2279 | // Create a replacement operation with empty then and else regions. |
| 2280 | auto newOp = |
| 2281 | rewriter.create<IfOp>(op.getLoc(), newTypes, op.getCondition()); |
| 2282 | rewriter.createBlock(&newOp.getThenRegion()); |
| 2283 | rewriter.createBlock(&newOp.getElseRegion()); |
| 2284 | |
| 2285 | // Move the bodies and replace the terminators (note there is a then and |
| 2286 | // an else region since the operation returns results). |
| 2287 |     transferBody(op.getBody(0), newOp.getBody(0), usedResults, rewriter);
| 2288 |     transferBody(op.getBody(1), newOp.getBody(1), usedResults, rewriter);
| 2289 | |
| 2290 | // Replace the operation by the new one. |
| 2291 | SmallVector<Value, 4> repResults(op.getNumResults()); |
| 2292 |     for (const auto &en : llvm::enumerate(usedResults))
| 2293 | repResults[en.value().getResultNumber()] = newOp.getResult(en.index()); |
| 2294 | rewriter.replaceOp(op, repResults); |
| 2295 | return success(); |
| 2296 | } |
| 2297 | }; |
| 2298 | |
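/// Inline the region of an scf.if whose condition is a known constant. For
/// example (a sketch):
///   %true = arith.constant true
///   scf.if %true { "foo"() : () -> () } else { "bar"() : () -> () }
/// keeps only the contents of the then region; with a false condition and no
/// else region the op is simply erased.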
| 2299 | struct RemoveStaticCondition : public OpRewritePattern<IfOp> { |
| 2300 | using OpRewritePattern<IfOp>::OpRewritePattern; |
| 2301 | |
| 2302 | LogicalResult matchAndRewrite(IfOp op, |
| 2303 | PatternRewriter &rewriter) const override { |
| 2304 | BoolAttr condition; |
| 2305 |     if (!matchPattern(op.getCondition(), m_Constant(&condition)))
| 2306 | return failure(); |
| 2307 | |
| 2308 | if (condition.getValue()) |
| 2309 | replaceOpWithRegion(rewriter, op, op.getThenRegion()); |
| 2310 | else if (!op.getElseRegion().empty()) |
| 2311 | replaceOpWithRegion(rewriter, op, op.getElseRegion()); |
| 2312 | else |
| 2313 |       rewriter.eraseOp(op);
| 2314 | |
| 2315 | return success(); |
| 2316 | } |
| 2317 | }; |
| 2318 | |
| 2319 | /// Hoist any yielded results whose operands are defined outside |
| 2320 | /// the if, to a select instruction. |
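/// For example (a sketch, with %a and %b defined above the scf.if):
///   %r = scf.if %cond -> (i32) {
///     scf.yield %a : i32
///   } else {
///     scf.yield %b : i32
///   }
/// becomes
///   %r = arith.select %cond, %a, %b : i32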
| 2321 | struct ConvertTrivialIfToSelect : public OpRewritePattern<IfOp> { |
| 2322 | using OpRewritePattern<IfOp>::OpRewritePattern; |
| 2323 | |
| 2324 | LogicalResult matchAndRewrite(IfOp op, |
| 2325 | PatternRewriter &rewriter) const override { |
| 2326 | if (op->getNumResults() == 0) |
| 2327 | return failure(); |
| 2328 | |
| 2329 | auto cond = op.getCondition(); |
| 2330 | auto thenYieldArgs = op.thenYield().getOperands(); |
| 2331 | auto elseYieldArgs = op.elseYield().getOperands(); |
| 2332 | |
| 2333 | SmallVector<Type> nonHoistable; |
| 2334 | for (auto [trueVal, falseVal] : llvm::zip(thenYieldArgs, elseYieldArgs)) { |
| 2335 | if (&op.getThenRegion() == trueVal.getParentRegion() || |
| 2336 | &op.getElseRegion() == falseVal.getParentRegion()) |
| 2337 | nonHoistable.push_back(trueVal.getType()); |
| 2338 | } |
| 2339 | // Early exit if there aren't any yielded values we can |
| 2340 | // hoist outside the if. |
| 2341 | if (nonHoistable.size() == op->getNumResults()) |
| 2342 | return failure(); |
| 2343 | |
| 2344 | IfOp replacement = rewriter.create<IfOp>(op.getLoc(), nonHoistable, cond, |
| 2345 | /*withElseRegion=*/false); |
| 2346 | if (replacement.thenBlock()) |
| 2347 |       rewriter.eraseBlock(replacement.thenBlock());
| 2348 | replacement.getThenRegion().takeBody(op.getThenRegion()); |
| 2349 | replacement.getElseRegion().takeBody(op.getElseRegion()); |
| 2350 | |
| 2351 | SmallVector<Value> results(op->getNumResults()); |
| 2352 | assert(thenYieldArgs.size() == results.size()); |
| 2353 | assert(elseYieldArgs.size() == results.size()); |
| 2354 | |
| 2355 | SmallVector<Value> trueYields; |
| 2356 | SmallVector<Value> falseYields; |
| 2357 | rewriter.setInsertionPoint(replacement); |
| 2358 | for (const auto &it : |
| 2359 | llvm::enumerate(llvm::zip(thenYieldArgs, elseYieldArgs))) { |
| 2360 | Value trueVal = std::get<0>(it.value()); |
| 2361 | Value falseVal = std::get<1>(it.value()); |
| 2362 | if (&replacement.getThenRegion() == trueVal.getParentRegion() || |
| 2363 | &replacement.getElseRegion() == falseVal.getParentRegion()) { |
| 2364 | results[it.index()] = replacement.getResult(trueYields.size()); |
| 2365 | trueYields.push_back(trueVal); |
| 2366 | falseYields.push_back(falseVal); |
| 2367 | } else if (trueVal == falseVal) |
| 2368 | results[it.index()] = trueVal; |
| 2369 | else |
| 2370 | results[it.index()] = rewriter.create<arith::SelectOp>( |
| 2371 | op.getLoc(), cond, trueVal, falseVal); |
| 2372 | } |
| 2373 | |
| 2374 | rewriter.setInsertionPointToEnd(replacement.thenBlock()); |
| 2375 | rewriter.replaceOpWithNewOp<YieldOp>(replacement.thenYield(), trueYields); |
| 2376 | |
| 2377 | rewriter.setInsertionPointToEnd(replacement.elseBlock()); |
| 2378 | rewriter.replaceOpWithNewOp<YieldOp>(replacement.elseYield(), falseYields); |
| 2379 | |
| 2380 | rewriter.replaceOp(op, results); |
| 2381 | return success(); |
| 2382 | } |
| 2383 | }; |
| 2384 | |
| 2385 | /// Allow the true region of an if to assume the condition is true |
| 2386 | /// and vice versa. For example: |
| 2387 | /// |
| 2388 | /// scf.if %cmp { |
| 2389 | /// print(%cmp) |
| 2390 | /// } |
| 2391 | /// |
| 2392 | /// becomes |
| 2393 | /// |
| 2394 | /// scf.if %cmp { |
| 2395 | /// print(true) |
| 2396 | /// } |
| 2397 | /// |
| 2398 | struct ConditionPropagation : public OpRewritePattern<IfOp> { |
| 2399 | using OpRewritePattern<IfOp>::OpRewritePattern; |
| 2400 | |
| 2401 | LogicalResult matchAndRewrite(IfOp op, |
| 2402 | PatternRewriter &rewriter) const override { |
| 2403 | // Early exit if the condition is constant since replacing a constant |
| 2404 | // in the body with another constant isn't a simplification. |
| 2405 | if (matchPattern(op.getCondition(), m_Constant())) |
| 2406 | return failure(); |
| 2407 | |
| 2408 | bool changed = false; |
| 2409 | mlir::Type i1Ty = rewriter.getI1Type(); |
| 2410 | |
| 2411 | // These variables serve to prevent creating duplicate constants |
| 2412 | // and hold constant true or false values. |
| 2413 | Value constantTrue = nullptr; |
| 2414 | Value constantFalse = nullptr; |
| 2415 | |
| 2416 | for (OpOperand &use : |
| 2417 | llvm::make_early_inc_range(op.getCondition().getUses())) { |
| 2418 | if (op.getThenRegion().isAncestor(use.getOwner()->getParentRegion())) { |
| 2419 | changed = true; |
| 2420 | |
| 2421 | if (!constantTrue) |
| 2422 | constantTrue = rewriter.create<arith::ConstantOp>( |
| 2423 | op.getLoc(), i1Ty, rewriter.getIntegerAttr(i1Ty, 1)); |
| 2424 | |
| 2425 | rewriter.modifyOpInPlace(use.getOwner(), |
| 2426 | [&]() { use.set(constantTrue); }); |
| 2427 | } else if (op.getElseRegion().isAncestor( |
| 2428 | use.getOwner()->getParentRegion())) { |
| 2429 | changed = true; |
| 2430 | |
| 2431 | if (!constantFalse) |
| 2432 | constantFalse = rewriter.create<arith::ConstantOp>( |
| 2433 | op.getLoc(), i1Ty, rewriter.getIntegerAttr(i1Ty, 0)); |
| 2434 | |
| 2435 | rewriter.modifyOpInPlace(use.getOwner(), |
| 2436 | [&]() { use.set(constantFalse); }); |
| 2437 | } |
| 2438 | } |
| 2439 | |
| 2440 |     return success(changed);
| 2441 | } |
| 2442 | }; |
| 2443 | |
| 2444 | /// Remove any statements from an if that are equivalent to the condition |
| 2445 | /// or its negation. For example: |
| 2446 | /// |
| 2447 | /// %res:2 = scf.if %cmp { |
| 2448 | /// yield something(), true |
| 2449 | /// } else { |
| 2450 | /// yield something2(), false |
| 2451 | /// } |
| 2452 | /// print(%res#1) |
| 2453 | /// |
| 2454 | /// becomes |
| 2455 | /// %res = scf.if %cmp { |
| 2456 | /// yield something() |
| 2457 | /// } else { |
| 2458 | /// yield something2() |
| 2459 | /// } |
| 2460 | /// print(%cmp) |
| 2461 | /// |
| 2462 | /// Additionally if both branches yield the same value, replace all uses |
| 2463 | /// of the result with the yielded value. |
| 2464 | /// |
| 2465 | /// %res:2 = scf.if %cmp { |
| 2466 | /// yield something(), %arg1 |
| 2467 | /// } else { |
| 2468 | /// yield something2(), %arg1 |
| 2469 | /// } |
| 2470 | /// print(%res#1) |
| 2471 | /// |
| 2472 | /// becomes |
| 2473 | /// %res = scf.if %cmp { |
| 2474 | /// yield something() |
| 2475 | /// } else { |
| 2476 | /// yield something2() |
| 2477 | /// } |
| 2478 | /// print(%arg1) |
| 2479 | /// |
| 2480 | struct ReplaceIfYieldWithConditionOrValue : public OpRewritePattern<IfOp> { |
| 2481 | using OpRewritePattern<IfOp>::OpRewritePattern; |
| 2482 | |
| 2483 | LogicalResult matchAndRewrite(IfOp op, |
| 2484 | PatternRewriter &rewriter) const override { |
| 2485 | // Early exit if there are no results that could be replaced. |
| 2486 | if (op.getNumResults() == 0) |
| 2487 | return failure(); |
| 2488 | |
| 2489 | auto trueYield = |
| 2490 | cast<scf::YieldOp>(op.getThenRegion().back().getTerminator()); |
| 2491 | auto falseYield = |
| 2492 | cast<scf::YieldOp>(op.getElseRegion().back().getTerminator()); |
| 2493 | |
| 2494 | rewriter.setInsertionPoint(op->getBlock(), |
| 2495 | op.getOperation()->getIterator()); |
| 2496 | bool changed = false; |
| 2497 | Type i1Ty = rewriter.getI1Type(); |
| 2498 | for (auto [trueResult, falseResult, opResult] : |
| 2499 | llvm::zip(trueYield.getResults(), falseYield.getResults(), |
| 2500 | op.getResults())) { |
| 2501 | if (trueResult == falseResult) { |
| 2502 | if (!opResult.use_empty()) { |
| 2503 | opResult.replaceAllUsesWith(trueResult); |
| 2504 | changed = true; |
| 2505 | } |
| 2506 | continue; |
| 2507 | } |
| 2508 | |
| 2509 | BoolAttr trueYield, falseYield; |
| 2510 | if (!matchPattern(trueResult, m_Constant(&trueYield)) || |
| 2511 | !matchPattern(falseResult, m_Constant(&falseYield))) |
| 2512 | continue; |
| 2513 | |
| 2514 | bool trueVal = trueYield.getValue(); |
| 2515 | bool falseVal = falseYield.getValue(); |
| 2516 | if (!trueVal && falseVal) { |
| 2517 | if (!opResult.use_empty()) { |
| 2518 | Dialect *constDialect = trueResult.getDefiningOp()->getDialect(); |
| 2519 | Value notCond = rewriter.create<arith::XOrIOp>( |
| 2520 | op.getLoc(), op.getCondition(), |
| 2521 | constDialect |
| 2522 | ->materializeConstant(rewriter, |
| 2523 | rewriter.getIntegerAttr(i1Ty, 1), i1Ty, |
| 2524 | op.getLoc()) |
| 2525 | ->getResult(0)); |
| 2526 | opResult.replaceAllUsesWith(notCond); |
| 2527 | changed = true; |
| 2528 | } |
| 2529 | } |
| 2530 | if (trueVal && !falseVal) { |
| 2531 | if (!opResult.use_empty()) { |
| 2532 | opResult.replaceAllUsesWith(op.getCondition()); |
| 2533 | changed = true; |
| 2534 | } |
| 2535 | } |
| 2536 | } |
| 2537 |     return success(changed);
| 2538 | } |
| 2539 | }; |
| 2540 | |
| 2541 | /// Merge any consecutive scf.if's with the same condition. |
| 2542 | /// |
| 2543 | /// scf.if %cond { |
| 2544 | /// firstCodeTrue();... |
| 2545 | /// } else { |
| 2546 | /// firstCodeFalse();... |
| 2547 | /// } |
| 2548 | /// %res = scf.if %cond { |
| 2549 | /// secondCodeTrue();... |
| 2550 | /// } else { |
| 2551 | /// secondCodeFalse();... |
| 2552 | /// } |
| 2553 | /// |
| 2554 | /// becomes |
| 2555 | /// %res = scf.if %cmp { |
| 2556 | /// firstCodeTrue();... |
| 2557 | /// secondCodeTrue();... |
| 2558 | /// } else { |
| 2559 | /// firstCodeFalse();... |
| 2560 | /// secondCodeFalse();... |
| 2561 | /// } |
| 2562 | struct CombineIfs : public OpRewritePattern<IfOp> { |
| 2563 | using OpRewritePattern<IfOp>::OpRewritePattern; |
| 2564 | |
| 2565 | LogicalResult matchAndRewrite(IfOp nextIf, |
| 2566 | PatternRewriter &rewriter) const override { |
| 2567 | Block *parent = nextIf->getBlock(); |
| 2568 | if (nextIf == &parent->front()) |
| 2569 | return failure(); |
| 2570 | |
| 2571 | auto prevIf = dyn_cast<IfOp>(nextIf->getPrevNode()); |
| 2572 | if (!prevIf) |
| 2573 | return failure(); |
| 2574 | |
| 2575 | // Determine the logical then/else blocks when prevIf's |
| 2576 | // condition is used. Null means the block does not exist |
| 2577 | // in that case (e.g. empty else). If neither of these |
    // is set, the two conditions cannot be compared.
| 2579 | Block *nextThen = nullptr; |
| 2580 | Block *nextElse = nullptr; |
| 2581 | if (nextIf.getCondition() == prevIf.getCondition()) { |
| 2582 | nextThen = nextIf.thenBlock(); |
| 2583 | if (!nextIf.getElseRegion().empty()) |
| 2584 | nextElse = nextIf.elseBlock(); |
| 2585 | } |
| 2586 | if (arith::XOrIOp notv = |
| 2587 | nextIf.getCondition().getDefiningOp<arith::XOrIOp>()) { |
| 2588 | if (notv.getLhs() == prevIf.getCondition() && |
| 2589 | matchPattern(notv.getRhs(), m_One())) { |
| 2590 | nextElse = nextIf.thenBlock(); |
| 2591 | if (!nextIf.getElseRegion().empty()) |
| 2592 | nextThen = nextIf.elseBlock(); |
| 2593 | } |
| 2594 | } |
| 2595 | if (arith::XOrIOp notv = |
| 2596 | prevIf.getCondition().getDefiningOp<arith::XOrIOp>()) { |
| 2597 | if (notv.getLhs() == nextIf.getCondition() && |
| 2598 | matchPattern(notv.getRhs(), m_One())) { |
| 2599 | nextElse = nextIf.thenBlock(); |
| 2600 | if (!nextIf.getElseRegion().empty()) |
| 2601 | nextThen = nextIf.elseBlock(); |
| 2602 | } |
| 2603 | } |
| 2604 | |
| 2605 | if (!nextThen && !nextElse) |
| 2606 | return failure(); |
| 2607 | |
| 2608 | SmallVector<Value> prevElseYielded; |
| 2609 | if (!prevIf.getElseRegion().empty()) |
| 2610 | prevElseYielded = prevIf.elseYield().getOperands(); |
    // Replace all uses of prevIf's results within nextIf with the
    // corresponding yielded values.
| 2613 | for (auto it : llvm::zip(prevIf.getResults(), |
| 2614 | prevIf.thenYield().getOperands(), prevElseYielded)) |
| 2615 | for (OpOperand &use : |
| 2616 | llvm::make_early_inc_range(std::get<0>(it).getUses())) { |
| 2617 | if (nextThen && nextThen->getParent()->isAncestor( |
| 2618 | use.getOwner()->getParentRegion())) { |
| 2619 | rewriter.startOpModification(use.getOwner()); |
| 2620 | use.set(std::get<1>(it)); |
| 2621 | rewriter.finalizeOpModification(use.getOwner()); |
| 2622 | } else if (nextElse && nextElse->getParent()->isAncestor( |
| 2623 | use.getOwner()->getParentRegion())) { |
| 2624 | rewriter.startOpModification(use.getOwner()); |
| 2625 | use.set(std::get<2>(it)); |
| 2626 | rewriter.finalizeOpModification(use.getOwner()); |
| 2627 | } |
| 2628 | } |
| 2629 | |
| 2630 | SmallVector<Type> mergedTypes(prevIf.getResultTypes()); |
| 2631 | llvm::append_range(mergedTypes, nextIf.getResultTypes()); |
| 2632 | |
| 2633 | IfOp combinedIf = rewriter.create<IfOp>( |
| 2634 | nextIf.getLoc(), mergedTypes, prevIf.getCondition(), /*hasElse=*/false); |
    rewriter.eraseBlock(&combinedIf.getThenRegion().back());
| 2636 | |
| 2637 | rewriter.inlineRegionBefore(prevIf.getThenRegion(), |
| 2638 | combinedIf.getThenRegion(), |
| 2639 | combinedIf.getThenRegion().begin()); |
| 2640 | |
| 2641 | if (nextThen) { |
| 2642 | YieldOp thenYield = combinedIf.thenYield(); |
| 2643 | YieldOp thenYield2 = cast<YieldOp>(nextThen->getTerminator()); |
      rewriter.mergeBlocks(nextThen, combinedIf.thenBlock());
| 2645 | rewriter.setInsertionPointToEnd(combinedIf.thenBlock()); |
| 2646 | |
| 2647 | SmallVector<Value> mergedYields(thenYield.getOperands()); |
| 2648 | llvm::append_range(mergedYields, thenYield2.getOperands()); |
| 2649 | rewriter.create<YieldOp>(thenYield2.getLoc(), mergedYields); |
      rewriter.eraseOp(thenYield);
      rewriter.eraseOp(thenYield2);
| 2652 | } |
| 2653 | |
| 2654 | rewriter.inlineRegionBefore(prevIf.getElseRegion(), |
| 2655 | combinedIf.getElseRegion(), |
| 2656 | combinedIf.getElseRegion().begin()); |
| 2657 | |
| 2658 | if (nextElse) { |
| 2659 | if (combinedIf.getElseRegion().empty()) { |
| 2660 | rewriter.inlineRegionBefore(*nextElse->getParent(), |
| 2661 | combinedIf.getElseRegion(), |
| 2662 | combinedIf.getElseRegion().begin()); |
| 2663 | } else { |
| 2664 | YieldOp elseYield = combinedIf.elseYield(); |
| 2665 | YieldOp elseYield2 = cast<YieldOp>(nextElse->getTerminator()); |
        rewriter.mergeBlocks(nextElse, combinedIf.elseBlock());
| 2667 | |
| 2668 | rewriter.setInsertionPointToEnd(combinedIf.elseBlock()); |
| 2669 | |
| 2670 | SmallVector<Value> mergedElseYields(elseYield.getOperands()); |
| 2671 | llvm::append_range(mergedElseYields, elseYield2.getOperands()); |
| 2672 | |
| 2673 | rewriter.create<YieldOp>(elseYield2.getLoc(), mergedElseYields); |
        rewriter.eraseOp(elseYield);
        rewriter.eraseOp(elseYield2);
| 2676 | } |
| 2677 | } |
| 2678 | |
| 2679 | SmallVector<Value> prevValues; |
| 2680 | SmallVector<Value> nextValues; |
| 2681 | for (const auto &pair : llvm::enumerate(combinedIf.getResults())) { |
| 2682 | if (pair.index() < prevIf.getNumResults()) |
| 2683 | prevValues.push_back(pair.value()); |
| 2684 | else |
| 2685 | nextValues.push_back(pair.value()); |
| 2686 | } |
| 2687 | rewriter.replaceOp(prevIf, prevValues); |
| 2688 | rewriter.replaceOp(nextIf, nextValues); |
| 2689 | return success(); |
| 2690 | } |
| 2691 | }; |
| 2692 | |
| 2693 | /// Pattern to remove an empty else branch. |
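/// The pattern only applies when the op has no results and the else block
/// contains nothing besides its terminator, e.g. (illustrative):
///
///    scf.if %cond {
///      "test.op"() : () -> ()
///    } else {
///    }
///
///  becomes
///
///    scf.if %cond {
///      "test.op"() : () -> ()
///    }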
| 2694 | struct RemoveEmptyElseBranch : public OpRewritePattern<IfOp> { |
| 2695 | using OpRewritePattern<IfOp>::OpRewritePattern; |
| 2696 | |
| 2697 | LogicalResult matchAndRewrite(IfOp ifOp, |
| 2698 | PatternRewriter &rewriter) const override { |
| 2699 | // Cannot remove else region when there are operation results. |
| 2700 | if (ifOp.getNumResults()) |
| 2701 | return failure(); |
| 2702 | Block *elseBlock = ifOp.elseBlock(); |
    if (!elseBlock || !llvm::hasSingleElement(*elseBlock))
| 2704 | return failure(); |
| 2705 | auto newIfOp = rewriter.cloneWithoutRegions(ifOp); |
| 2706 | rewriter.inlineRegionBefore(ifOp.getThenRegion(), newIfOp.getThenRegion(), |
| 2707 | newIfOp.getThenRegion().begin()); |
    rewriter.eraseOp(ifOp);
| 2709 | return success(); |
| 2710 | } |
| 2711 | }; |
| 2712 | |
| 2713 | /// Convert nested `if`s into `arith.andi` + single `if`. |
| 2714 | /// |
| 2715 | /// scf.if %arg0 { |
| 2716 | /// scf.if %arg1 { |
| 2717 | /// ... |
| 2718 | /// scf.yield |
| 2719 | /// } |
| 2720 | /// scf.yield |
| 2721 | /// } |
| 2722 | /// becomes |
| 2723 | /// |
| 2724 | /// %0 = arith.andi %arg0, %arg1 |
| 2725 | /// scf.if %0 { |
| 2726 | /// ... |
| 2727 | /// scf.yield |
| 2728 | /// } |
| 2729 | struct CombineNestedIfs : public OpRewritePattern<IfOp> { |
| 2730 | using OpRewritePattern<IfOp>::OpRewritePattern; |
| 2731 | |
| 2732 | LogicalResult matchAndRewrite(IfOp op, |
| 2733 | PatternRewriter &rewriter) const override { |
| 2734 | auto nestedOps = op.thenBlock()->without_terminator(); |
| 2735 | // Nested `if` must be the only op in block. |
| 2736 | if (!llvm::hasSingleElement(nestedOps)) |
| 2737 | return failure(); |
| 2738 | |
| 2739 | // If there is an else block, it can only yield |
| 2740 | if (op.elseBlock() && !llvm::hasSingleElement(*op.elseBlock())) |
| 2741 | return failure(); |
| 2742 | |
| 2743 | auto nestedIf = dyn_cast<IfOp>(*nestedOps.begin()); |
| 2744 | if (!nestedIf) |
| 2745 | return failure(); |
| 2746 | |
| 2747 | if (nestedIf.elseBlock() && !llvm::hasSingleElement(*nestedIf.elseBlock())) |
| 2748 | return failure(); |
| 2749 | |
| 2750 | SmallVector<Value> thenYield(op.thenYield().getOperands()); |
| 2751 | SmallVector<Value> elseYield; |
| 2752 | if (op.elseBlock()) |
| 2753 | llvm::append_range(elseYield, op.elseYield().getOperands()); |
| 2754 | |
| 2755 | // A list of indices for which we should upgrade the value yielded |
| 2756 | // in the else to a select. |
| 2757 | SmallVector<unsigned> elseYieldsToUpgradeToSelect; |
| 2758 | |
| 2759 | // If the outer scf.if yields a value produced by the inner scf.if, |
| 2760 | // only permit combining if the value yielded when the condition |
| 2761 | // is false in the outer scf.if is the same value yielded when the |
| 2762 | // inner scf.if condition is false. |
| 2763 | // Note that the array access to elseYield will not go out of bounds |
| 2764 | // since it must have the same length as thenYield, since they both |
| 2765 | // come from the same scf.if. |
| 2766 | for (const auto &tup : llvm::enumerate(thenYield)) { |
| 2767 | if (tup.value().getDefiningOp() == nestedIf) { |
| 2768 | auto nestedIdx = llvm::cast<OpResult>(tup.value()).getResultNumber(); |
| 2769 | if (nestedIf.elseYield().getOperand(nestedIdx) != |
| 2770 | elseYield[tup.index()]) { |
| 2771 | return failure(); |
| 2772 | } |
| 2773 | // If the correctness test passes, we will yield |
| 2774 | // corresponding value from the inner scf.if |
| 2775 | thenYield[tup.index()] = nestedIf.thenYield().getOperand(nestedIdx); |
| 2776 | continue; |
| 2777 | } |
| 2778 | |
| 2779 | // Otherwise, we need to ensure the else block of the combined |
| 2780 | // condition still returns the same value when the outer condition is |
| 2781 | // true and the inner condition is false. This can be accomplished if |
| 2782 | // the then value is defined outside the outer scf.if and we replace the |
| 2783 | // value with a select that considers just the outer condition. Since |
| 2784 | // the else region contains just the yield, its yielded value is |
| 2785 | // defined outside the scf.if, by definition. |
| 2786 | |
| 2787 | // If the then value is defined within the scf.if, bail. |
| 2788 | if (tup.value().getParentRegion() == &op.getThenRegion()) { |
| 2789 | return failure(); |
| 2790 | } |
| 2791 | elseYieldsToUpgradeToSelect.push_back(tup.index()); |
| 2792 | } |
| 2793 | |
| 2794 | Location loc = op.getLoc(); |
| 2795 | Value newCondition = rewriter.create<arith::AndIOp>( |
| 2796 | loc, op.getCondition(), nestedIf.getCondition()); |
| 2797 | auto newIf = rewriter.create<IfOp>(loc, op.getResultTypes(), newCondition); |
| 2798 | Block *newIfBlock = rewriter.createBlock(&newIf.getThenRegion()); |
| 2799 | |
| 2800 | SmallVector<Value> results; |
| 2801 | llvm::append_range(results, newIf.getResults()); |
| 2802 | rewriter.setInsertionPoint(newIf); |
| 2803 | |
| 2804 | for (auto idx : elseYieldsToUpgradeToSelect) |
| 2805 | results[idx] = rewriter.create<arith::SelectOp>( |
| 2806 | op.getLoc(), op.getCondition(), thenYield[idx], elseYield[idx]); |
| 2807 | |
    rewriter.mergeBlocks(nestedIf.thenBlock(), newIfBlock);
| 2809 | rewriter.setInsertionPointToEnd(newIf.thenBlock()); |
| 2810 | rewriter.replaceOpWithNewOp<YieldOp>(newIf.thenYield(), thenYield); |
| 2811 | if (!elseYield.empty()) { |
| 2812 | rewriter.createBlock(&newIf.getElseRegion()); |
| 2813 | rewriter.setInsertionPointToEnd(newIf.elseBlock()); |
| 2814 | rewriter.create<YieldOp>(loc, elseYield); |
| 2815 | } |
| 2816 | rewriter.replaceOp(op, results); |
| 2817 | return success(); |
| 2818 | } |
| 2819 | }; |
| 2820 | |
| 2821 | } // namespace |
| 2822 | |
| 2823 | void IfOp::getCanonicalizationPatterns(RewritePatternSet &results, |
| 2824 | MLIRContext *context) { |
| 2825 | results.add<CombineIfs, CombineNestedIfs, ConditionPropagation, |
| 2826 | ConvertTrivialIfToSelect, RemoveEmptyElseBranch, |
| 2827 | RemoveStaticCondition, RemoveUnusedResults, |
| 2828 | ReplaceIfYieldWithConditionOrValue>(context); |
| 2829 | } |
| 2830 | |
| 2831 | Block *IfOp::thenBlock() { return &getThenRegion().back(); } |
| 2832 | YieldOp IfOp::thenYield() { return cast<YieldOp>(&thenBlock()->back()); } |
| 2833 | Block *IfOp::elseBlock() { |
| 2834 | Region &r = getElseRegion(); |
| 2835 | if (r.empty()) |
| 2836 | return nullptr; |
| 2837 | return &r.back(); |
| 2838 | } |
| 2839 | YieldOp IfOp::elseYield() { return cast<YieldOp>(&elseBlock()->back()); } |
| 2840 | |
| 2841 | //===----------------------------------------------------------------------===// |
| 2842 | // ParallelOp |
| 2843 | //===----------------------------------------------------------------------===// |
| 2844 | |
| 2845 | void ParallelOp::build( |
| 2846 | OpBuilder &builder, OperationState &result, ValueRange lowerBounds, |
| 2847 | ValueRange upperBounds, ValueRange steps, ValueRange initVals, |
| 2848 | function_ref<void(OpBuilder &, Location, ValueRange, ValueRange)> |
| 2849 | bodyBuilderFn) { |
| 2850 | result.addOperands(lowerBounds); |
| 2851 | result.addOperands(upperBounds); |
| 2852 | result.addOperands(steps); |
| 2853 | result.addOperands(initVals); |
| 2854 | result.addAttribute( |
| 2855 | ParallelOp::getOperandSegmentSizeAttr(), |
| 2856 | builder.getDenseI32ArrayAttr({static_cast<int32_t>(lowerBounds.size()), |
| 2857 | static_cast<int32_t>(upperBounds.size()), |
| 2858 | static_cast<int32_t>(steps.size()), |
| 2859 | static_cast<int32_t>(initVals.size())})); |
| 2860 | result.addTypes(initVals.getTypes()); |
| 2861 | |
| 2862 | OpBuilder::InsertionGuard guard(builder); |
| 2863 | unsigned numIVs = steps.size(); |
| 2864 | SmallVector<Type, 8> argTypes(numIVs, builder.getIndexType()); |
| 2865 | SmallVector<Location, 8> argLocs(numIVs, result.location); |
| 2866 | Region *bodyRegion = result.addRegion(); |
| 2867 | Block *bodyBlock = builder.createBlock(bodyRegion, {}, argTypes, argLocs); |
| 2868 | |
| 2869 | if (bodyBuilderFn) { |
| 2870 | builder.setInsertionPointToStart(bodyBlock); |
| 2871 | bodyBuilderFn(builder, result.location, |
| 2872 | bodyBlock->getArguments().take_front(numIVs), |
| 2873 | bodyBlock->getArguments().drop_front(numIVs)); |
| 2874 | } |
| 2875 | // Add terminator only if there are no reductions. |
| 2876 | if (initVals.empty()) |
| 2877 | ParallelOp::ensureTerminator(*bodyRegion, builder, result.location); |
| 2878 | } |
| 2879 | |
| 2880 | void ParallelOp::build( |
| 2881 | OpBuilder &builder, OperationState &result, ValueRange lowerBounds, |
| 2882 | ValueRange upperBounds, ValueRange steps, |
| 2883 | function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn) { |
| 2884 | // Only pass a non-null wrapper if bodyBuilderFn is non-null itself. Make sure |
| 2885 | // we don't capture a reference to a temporary by constructing the lambda at |
| 2886 | // function level. |
| 2887 | auto wrappedBuilderFn = [&bodyBuilderFn](OpBuilder &nestedBuilder, |
| 2888 | Location nestedLoc, ValueRange ivs, |
| 2889 | ValueRange) { |
| 2890 | bodyBuilderFn(nestedBuilder, nestedLoc, ivs); |
| 2891 | }; |
| 2892 | function_ref<void(OpBuilder &, Location, ValueRange, ValueRange)> wrapper; |
| 2893 | if (bodyBuilderFn) |
| 2894 | wrapper = wrappedBuilderFn; |
| 2895 | |
| 2896 | build(builder, result, lowerBounds, upperBounds, steps, ValueRange(), |
| 2897 | wrapper); |
| 2898 | } |
| 2899 | |
| 2900 | LogicalResult ParallelOp::verify() { |
| 2901 | // Check that there is at least one value in lowerBound, upperBound and step. |
| 2902 | // It is sufficient to test only step, because it is ensured already that the |
| 2903 | // number of elements in lowerBound, upperBound and step are the same. |
| 2904 | Operation::operand_range stepValues = getStep(); |
| 2905 | if (stepValues.empty()) |
| 2906 | return emitOpError( |
| 2907 | "needs at least one tuple element for lowerBound, upperBound and step" ); |
| 2908 | |
| 2909 | // Check whether all constant step values are positive. |
| 2910 | for (Value stepValue : stepValues) |
| 2911 | if (auto cst = getConstantIntValue(stepValue)) |
| 2912 | if (*cst <= 0) |
| 2913 | return emitOpError("constant step operand must be positive" ); |
| 2914 | |
| 2915 | // Check that the body defines the same number of block arguments as the |
| 2916 | // number of tuple elements in step. |
| 2917 | Block *body = getBody(); |
| 2918 | if (body->getNumArguments() != stepValues.size()) |
| 2919 | return emitOpError() << "expects the same number of induction variables: " |
| 2920 | << body->getNumArguments() |
| 2921 | << " as bound and step values: " << stepValues.size(); |
| 2922 | for (auto arg : body->getArguments()) |
| 2923 | if (!arg.getType().isIndex()) |
| 2924 | return emitOpError( |
| 2925 | "expects arguments for the induction variable to be of index type" ); |
| 2926 | |
| 2927 | // Check that the terminator is an scf.reduce op. |
| 2928 | auto reduceOp = verifyAndGetTerminator<scf::ReduceOp>( |
      *this, getRegion(), "expects body to terminate with 'scf.reduce'");
| 2930 | if (!reduceOp) |
| 2931 | return failure(); |
| 2932 | |
| 2933 | // Check that the number of results is the same as the number of reductions. |
| 2934 | auto resultsSize = getResults().size(); |
| 2935 | auto reductionsSize = reduceOp.getReductions().size(); |
| 2936 | auto initValsSize = getInitVals().size(); |
| 2937 | if (resultsSize != reductionsSize) |
| 2938 | return emitOpError() << "expects number of results: " << resultsSize |
| 2939 | << " to be the same as number of reductions: " |
| 2940 | << reductionsSize; |
| 2941 | if (resultsSize != initValsSize) |
| 2942 | return emitOpError() << "expects number of results: " << resultsSize |
| 2943 | << " to be the same as number of initial values: " |
| 2944 | << initValsSize; |
| 2945 | |
| 2946 | // Check that the types of the results and reductions are the same. |
| 2947 | for (int64_t i = 0; i < static_cast<int64_t>(reductionsSize); ++i) { |
| 2948 | auto resultType = getOperation()->getResult(i).getType(); |
| 2949 | auto reductionOperandType = reduceOp.getOperands()[i].getType(); |
| 2950 | if (resultType != reductionOperandType) |
| 2951 | return reduceOp.emitOpError() |
| 2952 | << "expects type of " << i |
| 2953 | << "-th reduction operand: " << reductionOperandType |
| 2954 | << " to be the same as the " << i |
| 2955 | << "-th result type: " << resultType; |
| 2956 | } |
| 2957 | return success(); |
| 2958 | } |
| 2959 | |
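/// Parses a `parallel` op. The custom assembly form looks roughly as follows
/// (illustrative; SSA names and types are placeholders):
///
///   %sum = scf.parallel (%iv) = (%lb) to (%ub) step (%s)
///            init (%init) -> f32 {
///     %elem = ...
///     scf.reduce(%elem : f32) {
///     ^bb0(%lhs: f32, %rhs: f32):
///       %r = arith.addf %lhs, %rhs : f32
///       scf.reduce.return %r : f32
///     }
///   }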
| 2960 | ParseResult ParallelOp::parse(OpAsmParser &parser, OperationState &result) { |
| 2961 | auto &builder = parser.getBuilder(); |
| 2962 | // Parse an opening `(` followed by induction variables followed by `)` |
| 2963 | SmallVector<OpAsmParser::Argument, 4> ivs; |
| 2964 | if (parser.parseArgumentList(ivs, OpAsmParser::Delimiter::Paren)) |
| 2965 | return failure(); |
| 2966 | |
| 2967 | // Parse loop bounds. |
| 2968 | SmallVector<OpAsmParser::UnresolvedOperand, 4> lower; |
| 2969 | if (parser.parseEqual() || |
| 2970 | parser.parseOperandList(lower, ivs.size(), |
| 2971 | OpAsmParser::Delimiter::Paren) || |
| 2972 | parser.resolveOperands(lower, builder.getIndexType(), result.operands)) |
| 2973 | return failure(); |
| 2974 | |
| 2975 | SmallVector<OpAsmParser::UnresolvedOperand, 4> upper; |
| 2976 | if (parser.parseKeyword("to" ) || |
| 2977 | parser.parseOperandList(upper, ivs.size(), |
| 2978 | OpAsmParser::Delimiter::Paren) || |
| 2979 | parser.resolveOperands(upper, builder.getIndexType(), result.operands)) |
| 2980 | return failure(); |
| 2981 | |
| 2982 | // Parse step values. |
| 2983 | SmallVector<OpAsmParser::UnresolvedOperand, 4> steps; |
| 2984 | if (parser.parseKeyword("step" ) || |
| 2985 | parser.parseOperandList(steps, ivs.size(), |
| 2986 | OpAsmParser::Delimiter::Paren) || |
| 2987 | parser.resolveOperands(steps, builder.getIndexType(), result.operands)) |
| 2988 | return failure(); |
| 2989 | |
| 2990 | // Parse init values. |
| 2991 | SmallVector<OpAsmParser::UnresolvedOperand, 4> initVals; |
| 2992 | if (succeeded(parser.parseOptionalKeyword("init" ))) { |
| 2993 | if (parser.parseOperandList(initVals, OpAsmParser::Delimiter::Paren)) |
| 2994 | return failure(); |
| 2995 | } |
| 2996 | |
| 2997 | // Parse optional results in case there is a reduce. |
| 2998 | if (parser.parseOptionalArrowTypeList(result.types)) |
| 2999 | return failure(); |
| 3000 | |
| 3001 | // Now parse the body. |
| 3002 | Region *body = result.addRegion(); |
| 3003 | for (auto &iv : ivs) |
| 3004 | iv.type = builder.getIndexType(); |
| 3005 | if (parser.parseRegion(*body, ivs)) |
| 3006 | return failure(); |
| 3007 | |
| 3008 | // Set `operandSegmentSizes` attribute. |
| 3009 | result.addAttribute( |
| 3010 | ParallelOp::getOperandSegmentSizeAttr(), |
| 3011 | builder.getDenseI32ArrayAttr({static_cast<int32_t>(lower.size()), |
| 3012 | static_cast<int32_t>(upper.size()), |
| 3013 | static_cast<int32_t>(steps.size()), |
| 3014 | static_cast<int32_t>(initVals.size())})); |
| 3015 | |
| 3016 | // Parse attributes. |
| 3017 | if (parser.parseOptionalAttrDict(result.attributes) || |
| 3018 | parser.resolveOperands(initVals, result.types, parser.getNameLoc(), |
| 3019 | result.operands)) |
| 3020 | return failure(); |
| 3021 | |
| 3022 | // Add a terminator if none was parsed. |
| 3023 | ParallelOp::ensureTerminator(*body, builder, result.location); |
| 3024 | return success(); |
| 3025 | } |
| 3026 | |
| 3027 | void ParallelOp::print(OpAsmPrinter &p) { |
| 3028 | p << " (" << getBody()->getArguments() << ") = (" << getLowerBound() |
| 3029 | << ") to (" << getUpperBound() << ") step (" << getStep() << ")" ; |
| 3030 | if (!getInitVals().empty()) |
| 3031 | p << " init (" << getInitVals() << ")" ; |
| 3032 | p.printOptionalArrowTypeList(getResultTypes()); |
| 3033 | p << ' '; |
| 3034 | p.printRegion(getRegion(), /*printEntryBlockArgs=*/false); |
| 3035 | p.printOptionalAttrDict( |
| 3036 | (*this)->getAttrs(), |
| 3037 | /*elidedAttrs=*/ParallelOp::getOperandSegmentSizeAttr()); |
| 3038 | } |
| 3039 | |
| 3040 | SmallVector<Region *> ParallelOp::getLoopRegions() { return {&getRegion()}; } |
| 3041 | |
| 3042 | std::optional<SmallVector<Value>> ParallelOp::getLoopInductionVars() { |
| 3043 | return SmallVector<Value>{getBody()->getArguments()}; |
| 3044 | } |
| 3045 | |
| 3046 | std::optional<SmallVector<OpFoldResult>> ParallelOp::getLoopLowerBounds() { |
  return getAsOpFoldResult(getLowerBound());
| 3048 | } |
| 3049 | |
| 3050 | std::optional<SmallVector<OpFoldResult>> ParallelOp::getLoopUpperBounds() { |
  return getAsOpFoldResult(getUpperBound());
| 3052 | } |
| 3053 | |
| 3054 | std::optional<SmallVector<OpFoldResult>> ParallelOp::getLoopSteps() { |
  return getAsOpFoldResult(getStep());
| 3056 | } |
| 3057 | |
| 3058 | ParallelOp mlir::scf::getParallelForInductionVarOwner(Value val) { |
  auto ivArg = llvm::dyn_cast<BlockArgument>(val);
| 3060 | if (!ivArg) |
| 3061 | return ParallelOp(); |
| 3062 | assert(ivArg.getOwner() && "unlinked block argument" ); |
| 3063 | auto *containingOp = ivArg.getOwner()->getParentOp(); |
| 3064 | return dyn_cast<ParallelOp>(containingOp); |
| 3065 | } |
| 3066 | |
| 3067 | namespace { |
| 3068 | // Collapse loop dimensions that perform a single iteration. |
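//
// For example (illustrative), if the second dimension below has a constant
// trip count of one:
//
//   scf.parallel (%i, %j) = (%c0, %c0) to (%c8, %c1) step (%c1, %c1) { ... }
//
// becomes
//
//   scf.parallel (%i) = (%c0) to (%c8) step (%c1) { ... }
//
// with uses of %j replaced by its lower bound %c0. A dimension with a constant
// trip count of zero replaces the whole loop with its init values.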
| 3069 | struct ParallelOpSingleOrZeroIterationDimsFolder |
| 3070 | : public OpRewritePattern<ParallelOp> { |
| 3071 | using OpRewritePattern<ParallelOp>::OpRewritePattern; |
| 3072 | |
| 3073 | LogicalResult matchAndRewrite(ParallelOp op, |
| 3074 | PatternRewriter &rewriter) const override { |
| 3075 | Location loc = op.getLoc(); |
| 3076 | |
| 3077 | // Compute new loop bounds that omit all single-iteration loop dimensions. |
| 3078 | SmallVector<Value> newLowerBounds, newUpperBounds, newSteps; |
| 3079 | IRMapping mapping; |
| 3080 | for (auto [lb, ub, step, iv] : |
| 3081 | llvm::zip(op.getLowerBound(), op.getUpperBound(), op.getStep(), |
| 3082 | op.getInductionVars())) { |
| 3083 | auto numIterations = constantTripCount(lb, ub, step); |
| 3084 | if (numIterations.has_value()) { |
| 3085 | // Remove the loop if it performs zero iterations. |
| 3086 | if (*numIterations == 0) { |
| 3087 | rewriter.replaceOp(op, op.getInitVals()); |
| 3088 | return success(); |
| 3089 | } |
| 3090 | // Replace the loop induction variable by the lower bound if the loop |
| 3091 | // performs a single iteration. Otherwise, copy the loop bounds. |
| 3092 | if (*numIterations == 1) { |
| 3093 | mapping.map(iv, getValueOrCreateConstantIndexOp(rewriter, loc, lb)); |
| 3094 | continue; |
| 3095 | } |
| 3096 | } |
| 3097 | newLowerBounds.push_back(lb); |
| 3098 | newUpperBounds.push_back(ub); |
| 3099 | newSteps.push_back(step); |
| 3100 | } |
| 3101 | // Exit if none of the loop dimensions perform a single iteration. |
| 3102 | if (newLowerBounds.size() == op.getLowerBound().size()) |
| 3103 | return failure(); |
| 3104 | |
| 3105 | if (newLowerBounds.empty()) { |
| 3106 | // All of the loop dimensions perform a single iteration. Inline |
| 3107 | // loop body and nested ReduceOp's |
| 3108 | SmallVector<Value> results; |
      results.reserve(op.getInitVals().size());
| 3110 | for (auto &bodyOp : op.getBody()->without_terminator()) |
| 3111 | rewriter.clone(bodyOp, mapping); |
| 3112 | auto reduceOp = cast<ReduceOp>(op.getBody()->getTerminator()); |
| 3113 | for (int64_t i = 0, e = reduceOp.getReductions().size(); i < e; ++i) { |
| 3114 | Block &reduceBlock = reduceOp.getReductions()[i].front(); |
| 3115 | auto initValIndex = results.size(); |
        mapping.map(reduceBlock.getArgument(0), op.getInitVals()[initValIndex]);
        mapping.map(reduceBlock.getArgument(1),
| 3118 | mapping.lookupOrDefault(reduceOp.getOperands()[i])); |
| 3119 | for (auto &reduceBodyOp : reduceBlock.without_terminator()) |
| 3120 | rewriter.clone(reduceBodyOp, mapping); |
| 3121 | |
| 3122 | auto result = mapping.lookupOrDefault( |
| 3123 | cast<ReduceReturnOp>(reduceBlock.getTerminator()).getResult()); |
        results.push_back(result);
| 3125 | } |
| 3126 | |
| 3127 | rewriter.replaceOp(op, results); |
| 3128 | return success(); |
| 3129 | } |
    // Replace the parallel loop with a lower-dimensional parallel loop.
| 3131 | auto newOp = |
| 3132 | rewriter.create<ParallelOp>(op.getLoc(), newLowerBounds, newUpperBounds, |
| 3133 | newSteps, op.getInitVals(), nullptr); |
| 3134 | // Erase the empty block that was inserted by the builder. |
    rewriter.eraseBlock(newOp.getBody());
| 3136 | // Clone the loop body and remap the block arguments of the collapsed loops |
| 3137 | // (inlining does not support a cancellable block argument mapping). |
| 3138 | rewriter.cloneRegionBefore(op.getRegion(), newOp.getRegion(), |
| 3139 | newOp.getRegion().begin(), mapping); |
| 3140 | rewriter.replaceOp(op, newOp.getResults()); |
| 3141 | return success(); |
| 3142 | } |
| 3143 | }; |
| 3144 | |
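// Merge a perfectly nested pair of parallel loops without reductions into a
// single multi-dimensional parallel loop, e.g. (illustrative):
//
//   scf.parallel (%i) = (%lb0) to (%ub0) step (%s0) {
//     scf.parallel (%j) = (%lb1) to (%ub1) step (%s1) { ... }
//   }
//
// becomes
//
//   scf.parallel (%i, %j) = (%lb0, %lb1) to (%ub0, %ub1) step (%s0, %s1) {
//     ...
//   }
//
// The inner bounds and steps must not depend on the outer induction variables.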
| 3145 | struct MergeNestedParallelLoops : public OpRewritePattern<ParallelOp> { |
| 3146 | using OpRewritePattern<ParallelOp>::OpRewritePattern; |
| 3147 | |
| 3148 | LogicalResult matchAndRewrite(ParallelOp op, |
| 3149 | PatternRewriter &rewriter) const override { |
| 3150 | Block &outerBody = *op.getBody(); |
    if (!llvm::hasSingleElement(outerBody.without_terminator()))
| 3152 | return failure(); |
| 3153 | |
| 3154 | auto innerOp = dyn_cast<ParallelOp>(outerBody.front()); |
| 3155 | if (!innerOp) |
| 3156 | return failure(); |
| 3157 | |
| 3158 | for (auto val : outerBody.getArguments()) |
| 3159 | if (llvm::is_contained(innerOp.getLowerBound(), val) || |
| 3160 | llvm::is_contained(innerOp.getUpperBound(), val) || |
| 3161 | llvm::is_contained(innerOp.getStep(), val)) |
| 3162 | return failure(); |
| 3163 | |
| 3164 | // Reductions are not supported yet. |
| 3165 | if (!op.getInitVals().empty() || !innerOp.getInitVals().empty()) |
| 3166 | return failure(); |
| 3167 | |
| 3168 | auto bodyBuilder = [&](OpBuilder &builder, Location /*loc*/, |
| 3169 | ValueRange iterVals, ValueRange) { |
| 3170 | Block &innerBody = *innerOp.getBody(); |
| 3171 | assert(iterVals.size() == |
| 3172 | (outerBody.getNumArguments() + innerBody.getNumArguments())); |
| 3173 | IRMapping mapping; |
      mapping.map(outerBody.getArguments(),
                  iterVals.take_front(outerBody.getNumArguments()));
      mapping.map(innerBody.getArguments(),
                  iterVals.take_back(innerBody.getNumArguments()));
| 3178 | for (Operation &op : innerBody.without_terminator()) |
| 3179 | builder.clone(op, mapping); |
| 3180 | }; |
| 3181 | |
| 3182 | auto concatValues = [](const auto &first, const auto &second) { |
| 3183 | SmallVector<Value> ret; |
      ret.reserve(first.size() + second.size());
| 3185 | ret.assign(first.begin(), first.end()); |
| 3186 | ret.append(second.begin(), second.end()); |
| 3187 | return ret; |
| 3188 | }; |
| 3189 | |
| 3190 | auto newLowerBounds = |
| 3191 | concatValues(op.getLowerBound(), innerOp.getLowerBound()); |
| 3192 | auto newUpperBounds = |
| 3193 | concatValues(op.getUpperBound(), innerOp.getUpperBound()); |
| 3194 | auto newSteps = concatValues(op.getStep(), innerOp.getStep()); |
| 3195 | |
| 3196 | rewriter.replaceOpWithNewOp<ParallelOp>(op, newLowerBounds, newUpperBounds, |
| 3197 | newSteps, std::nullopt, |
| 3198 | bodyBuilder); |
| 3199 | return success(); |
| 3200 | } |
| 3201 | }; |
| 3202 | |
| 3203 | } // namespace |
| 3204 | |
| 3205 | void ParallelOp::getCanonicalizationPatterns(RewritePatternSet &results, |
| 3206 | MLIRContext *context) { |
| 3207 | results |
| 3208 | .add<ParallelOpSingleOrZeroIterationDimsFolder, MergeNestedParallelLoops>( |
| 3209 | context); |
| 3210 | } |
| 3211 | |
/// Return the successor regions for the given region branch point. These are
/// the regions that may be selected during the flow of control.
| 3217 | void ParallelOp::getSuccessorRegions( |
| 3218 | RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> ®ions) { |
| 3219 | // Both the operation itself and the region may be branching into the body or |
  // back into the operation itself. It is possible for the loop not to enter
  // the body.
| 3222 | regions.push_back(RegionSuccessor(&getRegion())); |
| 3223 | regions.push_back(RegionSuccessor()); |
| 3224 | } |
| 3225 | |
| 3226 | //===----------------------------------------------------------------------===// |
| 3227 | // ReduceOp |
| 3228 | //===----------------------------------------------------------------------===// |
| 3229 | |
| 3230 | void ReduceOp::build(OpBuilder &builder, OperationState &result) {} |
| 3231 | |
| 3232 | void ReduceOp::build(OpBuilder &builder, OperationState &result, |
| 3233 | ValueRange operands) { |
| 3234 | result.addOperands(operands); |
| 3235 | for (Value v : operands) { |
| 3236 | OpBuilder::InsertionGuard guard(builder); |
| 3237 | Region *bodyRegion = result.addRegion(); |
| 3238 | builder.createBlock(bodyRegion, {}, |
| 3239 | ArrayRef<Type>{v.getType(), v.getType()}, |
| 3240 | {result.location, result.location}); |
| 3241 | } |
| 3242 | } |
| 3243 | |
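// A well-formed reduction region takes two arguments of the operand type and
// terminates with scf.reduce.return, e.g. (illustrative):
//
//   scf.reduce(%operand : f32) {
//   ^bb0(%lhs: f32, %rhs: f32):
//     %res = arith.addf %lhs, %rhs : f32
//     scf.reduce.return %res : f32
//   }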
| 3244 | LogicalResult ReduceOp::verifyRegions() { |
| 3245 | // The region of a ReduceOp has two arguments of the same type as its |
| 3246 | // corresponding operand. |
| 3247 | for (int64_t i = 0, e = getReductions().size(); i < e; ++i) { |
| 3248 | auto type = getOperands()[i].getType(); |
| 3249 | Block &block = getReductions()[i].front(); |
| 3250 | if (block.empty()) |
| 3251 | return emitOpError() << i << "-th reduction has an empty body" ; |
| 3252 | if (block.getNumArguments() != 2 || |
| 3253 | llvm::any_of(block.getArguments(), [&](const BlockArgument &arg) { |
| 3254 | return arg.getType() != type; |
| 3255 | })) |
| 3256 | return emitOpError() << "expected two block arguments with type " << type |
| 3257 | << " in the " << i << "-th reduction region" ; |
| 3258 | |
| 3259 | // Check that the block is terminated by a ReduceReturnOp. |
| 3260 | if (!isa<ReduceReturnOp>(block.getTerminator())) |
| 3261 | return emitOpError("reduction bodies must be terminated with an " |
| 3262 | "'scf.reduce.return' op" ); |
| 3263 | } |
| 3264 | |
| 3265 | return success(); |
| 3266 | } |
| 3267 | |
| 3268 | MutableOperandRange |
| 3269 | ReduceOp::getMutableSuccessorOperands(RegionBranchPoint point) { |
| 3270 | // No operands are forwarded to the next iteration. |
| 3271 | return MutableOperandRange(getOperation(), /*start=*/0, /*length=*/0); |
| 3272 | } |
| 3273 | |
| 3274 | //===----------------------------------------------------------------------===// |
| 3275 | // ReduceReturnOp |
| 3276 | //===----------------------------------------------------------------------===// |
| 3277 | |
| 3278 | LogicalResult ReduceReturnOp::verify() { |
| 3279 | // The type of the return value should be the same type as the types of the |
| 3280 | // block arguments of the reduction body. |
| 3281 | Block *reductionBody = getOperation()->getBlock(); |
| 3282 | // Should already be verified by an op trait. |
| 3283 | assert(isa<ReduceOp>(reductionBody->getParentOp()) && "expected scf.reduce" ); |
| 3284 | Type expectedResultType = reductionBody->getArgument(0).getType(); |
| 3285 | if (expectedResultType != getResult().getType()) |
| 3286 | return emitOpError() << "must have type " << expectedResultType |
| 3287 | << " (the type of the reduction inputs)" ; |
| 3288 | return success(); |
| 3289 | } |
| 3290 | |
| 3291 | //===----------------------------------------------------------------------===// |
| 3292 | // WhileOp |
| 3293 | //===----------------------------------------------------------------------===// |
| 3294 | |
| 3295 | void WhileOp::build(::mlir::OpBuilder &odsBuilder, |
| 3296 | ::mlir::OperationState &odsState, TypeRange resultTypes, |
| 3297 | ValueRange inits, BodyBuilderFn beforeBuilder, |
| 3298 | BodyBuilderFn afterBuilder) { |
| 3299 | odsState.addOperands(inits); |
| 3300 | odsState.addTypes(resultTypes); |
| 3301 | |
| 3302 | OpBuilder::InsertionGuard guard(odsBuilder); |
| 3303 | |
| 3304 | // Build before region. |
| 3305 | SmallVector<Location, 4> beforeArgLocs; |
| 3306 | beforeArgLocs.reserve(inits.size()); |
| 3307 | for (Value operand : inits) { |
| 3308 | beforeArgLocs.push_back(operand.getLoc()); |
| 3309 | } |
| 3310 | |
| 3311 | Region *beforeRegion = odsState.addRegion(); |
| 3312 | Block *beforeBlock = odsBuilder.createBlock(beforeRegion, /*insertPt=*/{}, |
| 3313 | inits.getTypes(), beforeArgLocs); |
| 3314 | if (beforeBuilder) |
| 3315 | beforeBuilder(odsBuilder, odsState.location, beforeBlock->getArguments()); |
| 3316 | |
| 3317 | // Build after region. |
| 3318 | SmallVector<Location, 4> afterArgLocs(resultTypes.size(), odsState.location); |
| 3319 | |
| 3320 | Region *afterRegion = odsState.addRegion(); |
| 3321 | Block *afterBlock = odsBuilder.createBlock(afterRegion, /*insertPt=*/{}, |
| 3322 | resultTypes, afterArgLocs); |
| 3323 | |
| 3324 | if (afterBuilder) |
| 3325 | afterBuilder(odsBuilder, odsState.location, afterBlock->getArguments()); |
| 3326 | } |
| 3327 | |
| 3328 | ConditionOp WhileOp::getConditionOp() { |
| 3329 | return cast<ConditionOp>(getBeforeBody()->getTerminator()); |
| 3330 | } |
| 3331 | |
| 3332 | YieldOp WhileOp::getYieldOp() { |
| 3333 | return cast<YieldOp>(getAfterBody()->getTerminator()); |
| 3334 | } |
| 3335 | |
| 3336 | std::optional<MutableArrayRef<OpOperand>> WhileOp::getYieldedValuesMutable() { |
| 3337 | return getYieldOp().getResultsMutable(); |
| 3338 | } |
| 3339 | |
| 3340 | Block::BlockArgListType WhileOp::getBeforeArguments() { |
| 3341 | return getBeforeBody()->getArguments(); |
| 3342 | } |
| 3343 | |
| 3344 | Block::BlockArgListType WhileOp::getAfterArguments() { |
| 3345 | return getAfterBody()->getArguments(); |
| 3346 | } |
| 3347 | |
| 3348 | Block::BlockArgListType WhileOp::getRegionIterArgs() { |
| 3349 | return getBeforeArguments(); |
| 3350 | } |
| 3351 | |
| 3352 | OperandRange WhileOp::getEntrySuccessorOperands(RegionBranchPoint point) { |
| 3353 | assert(point == getBefore() && |
| 3354 | "WhileOp is expected to branch only to the first region" ); |
| 3355 | return getInits(); |
| 3356 | } |
| 3357 | |
| 3358 | void WhileOp::getSuccessorRegions(RegionBranchPoint point, |
| 3359 | SmallVectorImpl<RegionSuccessor> ®ions) { |
| 3360 | // The parent op always branches to the condition region. |
| 3361 | if (point.isParent()) { |
| 3362 | regions.emplace_back(&getBefore(), getBefore().getArguments()); |
| 3363 | return; |
| 3364 | } |
| 3365 | |
| 3366 | assert(llvm::is_contained({&getAfter(), &getBefore()}, point) && |
| 3367 | "there are only two regions in a WhileOp" ); |
| 3368 | // The body region always branches back to the condition region. |
| 3369 | if (point == getAfter()) { |
| 3370 | regions.emplace_back(&getBefore(), getBefore().getArguments()); |
| 3371 | return; |
| 3372 | } |
| 3373 | |
| 3374 | regions.emplace_back(getResults()); |
| 3375 | regions.emplace_back(&getAfter(), getAfter().getArguments()); |
| 3376 | } |
| 3377 | |
| 3378 | SmallVector<Region *> WhileOp::getLoopRegions() { |
| 3379 | return {&getBefore(), &getAfter()}; |
| 3380 | } |
| 3381 | |
| 3382 | /// Parses a `while` op. |
| 3383 | /// |
| 3384 | /// op ::= `scf.while` assignments `:` function-type region `do` region |
| 3385 | /// `attributes` attribute-dict |
| 3386 | /// initializer ::= /* empty */ | `(` assignment-list `)` |
| 3387 | /// assignment-list ::= assignment | assignment `,` assignment-list |
| 3388 | /// assignment ::= ssa-value `=` ssa-value |
| 3389 | ParseResult scf::WhileOp::parse(OpAsmParser &parser, OperationState &result) { |
| 3390 | SmallVector<OpAsmParser::Argument, 4> regionArgs; |
| 3391 | SmallVector<OpAsmParser::UnresolvedOperand, 4> operands; |
| 3392 | Region *before = result.addRegion(); |
| 3393 | Region *after = result.addRegion(); |
| 3394 | |
| 3395 | OptionalParseResult listResult = |
| 3396 | parser.parseOptionalAssignmentList(regionArgs, operands); |
| 3397 | if (listResult.has_value() && failed(listResult.value())) |
| 3398 | return failure(); |
| 3399 | |
| 3400 | FunctionType functionType; |
| 3401 | SMLoc typeLoc = parser.getCurrentLocation(); |
| 3402 | if (failed(parser.parseColonType(functionType))) |
| 3403 | return failure(); |
| 3404 | |
| 3405 | result.addTypes(functionType.getResults()); |
| 3406 | |
| 3407 | if (functionType.getNumInputs() != operands.size()) { |
| 3408 | return parser.emitError(typeLoc) |
| 3409 | << "expected as many input types as operands " |
| 3410 | << "(expected " << operands.size() << " got " |
| 3411 | << functionType.getNumInputs() << ")" ; |
| 3412 | } |
| 3413 | |
| 3414 | // Resolve input operands. |
| 3415 | if (failed(parser.resolveOperands(operands, functionType.getInputs(), |
| 3416 | parser.getCurrentLocation(), |
| 3417 | result.operands))) |
| 3418 | return failure(); |
| 3419 | |
| 3420 | // Propagate the types into the region arguments. |
| 3421 | for (size_t i = 0, e = regionArgs.size(); i != e; ++i) |
| 3422 | regionArgs[i].type = functionType.getInput(i); |
| 3423 | |
| 3424 | return failure(parser.parseRegion(*before, regionArgs) || |
| 3425 | parser.parseKeyword("do" ) || parser.parseRegion(*after) || |
| 3426 | parser.parseOptionalAttrDictWithKeyword(result.attributes)); |
| 3427 | } |
| 3428 | |
| 3429 | /// Prints a `while` op. |
| 3430 | void scf::WhileOp::print(OpAsmPrinter &p) { |
  printInitializationList(p, getBeforeArguments(), getInits(), " ");
  p << " : ";
| 3433 | p.printFunctionalType(getInits().getTypes(), getResults().getTypes()); |
| 3434 | p << ' '; |
| 3435 | p.printRegion(getBefore(), /*printEntryBlockArgs=*/false); |
| 3436 | p << " do " ; |
| 3437 | p.printRegion(getAfter()); |
| 3438 | p.printOptionalAttrDictWithKeyword((*this)->getAttrs()); |
| 3439 | } |
| 3440 | |
| 3441 | /// Verifies that two ranges of types match, i.e. have the same number of |
| 3442 | /// entries and that types are pairwise equals. Reports errors on the given |
| 3443 | /// operation in case of mismatch. |
| 3444 | template <typename OpTy> |
| 3445 | static LogicalResult verifyTypeRangesMatch(OpTy op, TypeRange left, |
| 3446 | TypeRange right, StringRef message) { |
| 3447 | if (left.size() != right.size()) |
| 3448 | return op.emitOpError("expects the same number of " ) << message; |
| 3449 | |
| 3450 | for (unsigned i = 0, e = left.size(); i < e; ++i) { |
| 3451 | if (left[i] != right[i]) { |
      InFlightDiagnostic diag = op.emitOpError("expects the same types for ")
| 3453 | << message; |
| 3454 | diag.attachNote() << "for argument " << i << ", found " << left[i] |
| 3455 | << " and " << right[i]; |
| 3456 | return diag; |
| 3457 | } |
| 3458 | } |
| 3459 | |
| 3460 | return success(); |
| 3461 | } |
| 3462 | |
| 3463 | LogicalResult scf::WhileOp::verify() { |
| 3464 | auto beforeTerminator = verifyAndGetTerminator<scf::ConditionOp>( |
| 3465 | *this, getBefore(), |
| 3466 | "expects the 'before' region to terminate with 'scf.condition'" ); |
| 3467 | if (!beforeTerminator) |
| 3468 | return failure(); |
| 3469 | |
| 3470 | auto afterTerminator = verifyAndGetTerminator<scf::YieldOp>( |
| 3471 | *this, getAfter(), |
| 3472 | "expects the 'after' region to terminate with 'scf.yield'" ); |
| 3473 | return success(afterTerminator != nullptr); |
| 3474 | } |
| 3475 | |
| 3476 | namespace { |
| 3477 | /// Replace uses of the condition within the do block with true, since otherwise |
| 3478 | /// the block would not be evaluated. |
| 3479 | /// |
| 3480 | /// scf.while (..) : (i1, ...) -> ... { |
| 3481 | /// %condition = call @evaluate_condition() : () -> i1 |
| 3482 | /// scf.condition(%condition) %condition : i1, ... |
| 3483 | /// } do { |
| 3484 | /// ^bb0(%arg0: i1, ...): |
| 3485 | /// use(%arg0) |
| 3486 | /// ... |
| 3487 | /// |
| 3488 | /// becomes |
| 3489 | /// scf.while (..) : (i1, ...) -> ... { |
| 3490 | /// %condition = call @evaluate_condition() : () -> i1 |
| 3491 | /// scf.condition(%condition) %condition : i1, ... |
| 3492 | /// } do { |
| 3493 | /// ^bb0(%arg0: i1, ...): |
| 3494 | /// use(%true) |
| 3495 | /// ... |
| 3496 | struct WhileConditionTruth : public OpRewritePattern<WhileOp> { |
| 3497 | using OpRewritePattern<WhileOp>::OpRewritePattern; |
| 3498 | |
| 3499 | LogicalResult matchAndRewrite(WhileOp op, |
| 3500 | PatternRewriter &rewriter) const override { |
| 3501 | auto term = op.getConditionOp(); |
| 3502 | |
    // This caches a constant `true` value so that at most one such constant is
    // created even when multiple block arguments are replaced.
| 3505 | Value constantTrue = nullptr; |
| 3506 | |
| 3507 | bool replaced = false; |
| 3508 | for (auto yieldedAndBlockArgs : |
| 3509 | llvm::zip(term.getArgs(), op.getAfterArguments())) { |
| 3510 | if (std::get<0>(yieldedAndBlockArgs) == term.getCondition()) { |
| 3511 | if (!std::get<1>(yieldedAndBlockArgs).use_empty()) { |
| 3512 | if (!constantTrue) |
| 3513 | constantTrue = rewriter.create<arith::ConstantOp>( |
| 3514 | op.getLoc(), term.getCondition().getType(), |
| 3515 | rewriter.getBoolAttr(true)); |
| 3516 | |
| 3517 | rewriter.replaceAllUsesWith(std::get<1>(yieldedAndBlockArgs), |
| 3518 | constantTrue); |
| 3519 | replaced = true; |
| 3520 | } |
| 3521 | } |
| 3522 | } |
    return success(replaced);
| 3524 | } |
| 3525 | }; |
| 3526 | |
| 3527 | /// Remove loop invariant arguments from `before` block of scf.while. |
| 3528 | /// A before block argument is considered loop invariant if :- |
| 3529 | /// 1. i-th yield operand is equal to the i-th while operand. |
| 3530 | /// 2. i-th yield operand is k-th after block argument which is (k+1)-th |
| 3531 | /// condition operand AND this (k+1)-th condition operand is equal to i-th |
| 3532 | /// iter argument/while operand. |
| 3533 | /// For the arguments which are removed, their uses inside scf.while |
| 3534 | /// are replaced with their corresponding initial value. |
| 3535 | /// |
| 3536 | /// Eg: |
| 3537 | /// INPUT :- |
| 3538 | /// %res = scf.while <...> iter_args(%arg0_before = %a, %arg1_before = %b, |
| 3539 | /// ..., %argN_before = %N) |
| 3540 | /// { |
| 3541 | /// ... |
| 3542 | /// scf.condition(%cond) %arg1_before, %arg0_before, |
| 3543 | /// %arg2_before, %arg0_before, ... |
| 3544 | /// } do { |
| 3545 | /// ^bb0(%arg1_after, %arg0_after_1, %arg2_after, %arg0_after_2, |
| 3546 | /// ..., %argK_after): |
| 3547 | /// ... |
| 3548 | /// scf.yield %arg0_after_2, %b, %arg1_after, ..., %argN |
| 3549 | /// } |
| 3550 | /// |
| 3551 | /// OUTPUT :- |
| 3552 | /// %res = scf.while <...> iter_args(%arg2_before = %c, ..., %argN_before = |
| 3553 | /// %N) |
| 3554 | /// { |
| 3555 | /// ... |
| 3556 | /// scf.condition(%cond) %b, %a, %arg2_before, %a, ... |
| 3557 | /// } do { |
| 3558 | /// ^bb0(%arg1_after, %arg0_after_1, %arg2_after, %arg0_after_2, |
| 3559 | /// ..., %argK_after): |
| 3560 | /// ... |
| 3561 | /// scf.yield %arg1_after, ..., %argN |
| 3562 | /// } |
| 3563 | /// |
| 3564 | /// EXPLANATION: |
| 3565 | /// We iterate over each yield operand. |
| 3566 | /// 1. 0-th yield operand %arg0_after_2 is 4-th condition operand |
| 3567 | /// %arg0_before, which in turn is the 0-th iter argument. So we |
| 3568 | /// remove 0-th before block argument and yield operand, and replace |
| 3569 | /// all uses of the 0-th before block argument with its initial value |
| 3570 | /// %a. |
| 3571 | /// 2. 1-th yield operand %b is equal to the 1-th iter arg's initial |
| 3572 | /// value. So we remove this operand and the corresponding before |
| 3573 | /// block argument and replace all uses of 1-th before block argument |
| 3574 | /// with %b. |
| 3575 | struct RemoveLoopInvariantArgsFromBeforeBlock |
| 3576 | : public OpRewritePattern<WhileOp> { |
| 3577 | using OpRewritePattern<WhileOp>::OpRewritePattern; |
| 3578 | |
| 3579 | LogicalResult matchAndRewrite(WhileOp op, |
| 3580 | PatternRewriter &rewriter) const override { |
| 3581 | Block &afterBlock = *op.getAfterBody(); |
| 3582 | Block::BlockArgListType beforeBlockArgs = op.getBeforeArguments(); |
| 3583 | ConditionOp condOp = op.getConditionOp(); |
| 3584 | OperandRange condOpArgs = condOp.getArgs(); |
| 3585 | Operation *yieldOp = afterBlock.getTerminator(); |
| 3586 | ValueRange yieldOpArgs = yieldOp->getOperands(); |
| 3587 | |
| 3588 | bool canSimplify = false; |
| 3589 | for (const auto &it : |
| 3590 | llvm::enumerate(llvm::zip(op.getOperands(), yieldOpArgs))) { |
| 3591 | auto index = static_cast<unsigned>(it.index()); |
| 3592 | auto [initVal, yieldOpArg] = it.value(); |
| 3593 | // If i-th yield operand is equal to the i-th operand of the scf.while, |
| 3594 | // the i-th before block argument is a loop invariant. |
| 3595 | if (yieldOpArg == initVal) { |
| 3596 | canSimplify = true; |
| 3597 | break; |
| 3598 | } |
| 3599 | // If the i-th yield operand is k-th after block argument, then we check |
| 3600 | // if the (k+1)-th condition op operand is equal to either the i-th before |
| 3601 | // block argument or the initial value of i-th before block argument. If |
| 3602 | // the comparison results `true`, i-th before block argument is a loop |
| 3603 | // invariant. |
| 3604 | auto yieldOpBlockArg = llvm::dyn_cast<BlockArgument>(yieldOpArg); |
| 3605 | if (yieldOpBlockArg && yieldOpBlockArg.getOwner() == &afterBlock) { |
| 3606 | Value condOpArg = condOpArgs[yieldOpBlockArg.getArgNumber()]; |
| 3607 | if (condOpArg == beforeBlockArgs[index] || condOpArg == initVal) { |
| 3608 | canSimplify = true; |
| 3609 | break; |
| 3610 | } |
| 3611 | } |
| 3612 | } |
| 3613 | |
| 3614 | if (!canSimplify) |
| 3615 | return failure(); |
| 3616 | |
| 3617 | SmallVector<Value> newInitArgs, newYieldOpArgs; |
| 3618 | DenseMap<unsigned, Value> beforeBlockInitValMap; |
| 3619 | SmallVector<Location> newBeforeBlockArgLocs; |
| 3620 | for (const auto &it : |
| 3621 | llvm::enumerate(llvm::zip(op.getOperands(), yieldOpArgs))) { |
| 3622 | auto index = static_cast<unsigned>(it.index()); |
| 3623 | auto [initVal, yieldOpArg] = it.value(); |
| 3624 | |
| 3625 | // If i-th yield operand is equal to the i-th operand of the scf.while, |
| 3626 | // the i-th before block argument is a loop invariant. |
| 3627 | if (yieldOpArg == initVal) { |
| 3628 | beforeBlockInitValMap.insert({index, initVal}); |
| 3629 | continue; |
| 3630 | } else { |
| 3631 | // If the i-th yield operand is k-th after block argument, then we check |
| 3632 | // if the (k+1)-th condition op operand is equal to either the i-th |
| 3633 | // before block argument or the initial value of i-th before block |
| 3634 | // argument. If the comparison results `true`, i-th before block |
| 3635 | // argument is a loop invariant. |
| 3636 | auto yieldOpBlockArg = llvm::dyn_cast<BlockArgument>(yieldOpArg); |
| 3637 | if (yieldOpBlockArg && yieldOpBlockArg.getOwner() == &afterBlock) { |
| 3638 | Value condOpArg = condOpArgs[yieldOpBlockArg.getArgNumber()]; |
| 3639 | if (condOpArg == beforeBlockArgs[index] || condOpArg == initVal) { |
| 3640 | beforeBlockInitValMap.insert({index, initVal}); |
| 3641 | continue; |
| 3642 | } |
| 3643 | } |
| 3644 | } |
| 3645 | newInitArgs.emplace_back(initVal); |
| 3646 | newYieldOpArgs.emplace_back(yieldOpArg); |
| 3647 | newBeforeBlockArgLocs.emplace_back(beforeBlockArgs[index].getLoc()); |
| 3648 | } |
| 3649 | |
| 3650 | { |
| 3651 | OpBuilder::InsertionGuard g(rewriter); |
| 3652 | rewriter.setInsertionPoint(yieldOp); |
| 3653 | rewriter.replaceOpWithNewOp<YieldOp>(yieldOp, newYieldOpArgs); |
| 3654 | } |
| 3655 | |
| 3656 | auto newWhile = |
| 3657 | rewriter.create<WhileOp>(op.getLoc(), op.getResultTypes(), newInitArgs); |
| 3658 | |
| 3659 | Block &newBeforeBlock = *rewriter.createBlock( |
| 3660 | &newWhile.getBefore(), /*insertPt*/ {}, |
| 3661 | ValueRange(newYieldOpArgs).getTypes(), newBeforeBlockArgLocs); |
| 3662 | |
| 3663 | Block &beforeBlock = *op.getBeforeBody(); |
| 3664 | SmallVector<Value> newBeforeBlockArgs(beforeBlock.getNumArguments()); |
    // For each i-th before block argument we find its replacement value as :-
    // 1. If i-th before block argument is a loop invariant, we fetch its
| 3667 | // initial value from `beforeBlockInitValMap` by querying for key `i`. |
| 3668 | // 2. Else we fetch j-th new before block argument as the replacement |
| 3669 | // value of i-th before block argument. |
| 3670 | for (unsigned i = 0, j = 0, n = beforeBlock.getNumArguments(); i < n; i++) { |
      // If the index 'i' argument was a loop invariant we fetch its initial
      // value from `beforeBlockInitValMap`.
      if (beforeBlockInitValMap.count(i) != 0)
| 3674 | newBeforeBlockArgs[i] = beforeBlockInitValMap[i]; |
| 3675 | else |
        newBeforeBlockArgs[i] = newBeforeBlock.getArgument(j++);
| 3677 | } |
| 3678 | |
    rewriter.mergeBlocks(&beforeBlock, &newBeforeBlock, newBeforeBlockArgs);
| 3680 | rewriter.inlineRegionBefore(op.getAfter(), newWhile.getAfter(), |
| 3681 | newWhile.getAfter().begin()); |
| 3682 | |
| 3683 | rewriter.replaceOp(op, newWhile.getResults()); |
| 3684 | return success(); |
| 3685 | } |
| 3686 | }; |
| 3687 | |
| 3688 | /// Remove loop invariant value from result (condition op) of scf.while. |
| 3689 | /// A value is considered loop invariant if the final value yielded by |
| 3690 | /// scf.condition is defined outside of the `before` block. We remove the |
| 3691 | /// corresponding argument in `after` block and replace the use with the value. |
| 3692 | /// We also replace the use of the corresponding result of scf.while with the |
| 3693 | /// value. |
| 3694 | /// |
| 3695 | /// Eg: |
| 3696 | /// INPUT :- |
| 3697 | /// %res_input:K = scf.while <...> iter_args(%arg0_before = , ..., |
| 3698 | /// %argN_before = %N) { |
| 3699 | /// ... |
| 3700 | /// scf.condition(%cond) %arg0_before, %a, %b, %arg1_before, ... |
| 3701 | /// } do { |
| 3702 | /// ^bb0(%arg0_after, %arg1_after, %arg2_after, ..., %argK_after): |
| 3703 | /// ... |
| 3704 | /// some_func(%arg1_after) |
| 3705 | /// ... |
| 3706 | /// scf.yield %arg0_after, %arg2_after, ..., %argN_after |
| 3707 | /// } |
| 3708 | /// |
| 3709 | /// OUTPUT :- |
| 3710 | /// %res_output:M = scf.while <...> iter_args(%arg0 = , ..., %argN = %N) { |
| 3711 | /// ... |
| 3712 | /// scf.condition(%cond) %arg0, %arg1, ..., %argM |
| 3713 | /// } do { |
| 3714 | /// ^bb0(%arg0, %arg3, ..., %argM): |
| 3715 | /// ... |
| 3716 | /// some_func(%a) |
| 3717 | /// ... |
| 3718 | /// scf.yield %arg0, %b, ..., %argN |
| 3719 | /// } |
| 3720 | /// |
| 3721 | /// EXPLANATION: |
| 3722 | /// 1. The 1-th and 2-th operand of scf.condition are defined outside the |
| 3723 | /// before block of scf.while, so they get removed. |
| 3724 | /// 2. %res_input#1's uses are replaced by %a and %res_input#2's uses are |
| 3725 | /// replaced by %b. |
| 3726 | /// 3. The corresponding after block argument %arg1_after's uses are |
| 3727 | /// replaced by %a and %arg2_after's uses are replaced by %b. |
| 3728 | struct RemoveLoopInvariantValueYielded : public OpRewritePattern<WhileOp> { |
| 3729 | using OpRewritePattern<WhileOp>::OpRewritePattern; |
| 3730 | |
| 3731 | LogicalResult matchAndRewrite(WhileOp op, |
| 3732 | PatternRewriter &rewriter) const override { |
| 3733 | Block &beforeBlock = *op.getBeforeBody(); |
| 3734 | ConditionOp condOp = op.getConditionOp(); |
| 3735 | OperandRange condOpArgs = condOp.getArgs(); |
| 3736 | |
| 3737 | bool canSimplify = false; |
| 3738 | for (Value condOpArg : condOpArgs) { |
| 3739 | // Those values not defined within `before` block will be considered as |
| 3740 | // loop invariant values. We map the corresponding `index` with their |
| 3741 | // value. |
| 3742 | if (condOpArg.getParentBlock() != &beforeBlock) { |
| 3743 | canSimplify = true; |
| 3744 | break; |
| 3745 | } |
| 3746 | } |
| 3747 | |
| 3748 | if (!canSimplify) |
| 3749 | return failure(); |
| 3750 | |
| 3751 | Block::BlockArgListType afterBlockArgs = op.getAfterArguments(); |
| 3752 | |
| 3753 | SmallVector<Value> newCondOpArgs; |
| 3754 | SmallVector<Type> newAfterBlockType; |
| 3755 | DenseMap<unsigned, Value> condOpInitValMap; |
| 3756 | SmallVector<Location> newAfterBlockArgLocs; |
| 3757 | for (const auto &it : llvm::enumerate(condOpArgs)) { |
| 3758 | auto index = static_cast<unsigned>(it.index()); |
| 3759 | Value condOpArg = it.value(); |
| 3760 | // Those values not defined within `before` block will be considered as |
| 3761 | // loop invariant values. We map the corresponding `index` with their |
| 3762 | // value. |
| 3763 | if (condOpArg.getParentBlock() != &beforeBlock) { |
| 3764 | condOpInitValMap.insert({index, condOpArg}); |
| 3765 | } else { |
| 3766 | newCondOpArgs.emplace_back(condOpArg); |
| 3767 | newAfterBlockType.emplace_back(condOpArg.getType()); |
| 3768 | newAfterBlockArgLocs.emplace_back(afterBlockArgs[index].getLoc()); |
| 3769 | } |
| 3770 | } |
| 3771 | |
| 3772 | { |
| 3773 | OpBuilder::InsertionGuard g(rewriter); |
| 3774 | rewriter.setInsertionPoint(condOp); |
| 3775 | rewriter.replaceOpWithNewOp<ConditionOp>(condOp, condOp.getCondition(), |
| 3776 | newCondOpArgs); |
| 3777 | } |
| 3778 | |
| 3779 | auto newWhile = rewriter.create<WhileOp>(op.getLoc(), newAfterBlockType, |
| 3780 | op.getOperands()); |
| 3781 | |
| 3782 | Block &newAfterBlock = |
| 3783 | *rewriter.createBlock(&newWhile.getAfter(), /*insertPt*/ {}, |
| 3784 | newAfterBlockType, newAfterBlockArgLocs); |
| 3785 | |
| 3786 | Block &afterBlock = *op.getAfterBody(); |
| 3787 | // Since a new scf.condition op was created, we need to fetch the new |
| 3788 | // `after` block arguments which will be used while replacing operations of |
    // previous scf.while's `after` blocks. We also fetch the new result
    // values.
| 3791 | SmallVector<Value> newAfterBlockArgs(afterBlock.getNumArguments()); |
| 3792 | SmallVector<Value> newWhileResults(afterBlock.getNumArguments()); |
| 3793 | for (unsigned i = 0, j = 0, n = afterBlock.getNumArguments(); i < n; i++) { |
| 3794 | Value afterBlockArg, result; |
      // If the index 'i' argument was loop invariant, we fetch its value from
      // the `condOpInitValMap` map.
      if (condOpInitValMap.count(i) != 0) {
| 3798 | afterBlockArg = condOpInitValMap[i]; |
| 3799 | result = afterBlockArg; |
| 3800 | } else { |
| 3801 | afterBlockArg = newAfterBlock.getArgument(i: j); |
| 3802 | result = newWhile.getResult(j); |
| 3803 | j++; |
| 3804 | } |
| 3805 | newAfterBlockArgs[i] = afterBlockArg; |
| 3806 | newWhileResults[i] = result; |
| 3807 | } |
| 3808 | |
    rewriter.mergeBlocks(&afterBlock, &newAfterBlock, newAfterBlockArgs);
| 3810 | rewriter.inlineRegionBefore(op.getBefore(), newWhile.getBefore(), |
| 3811 | newWhile.getBefore().begin()); |
| 3812 | |
| 3813 | rewriter.replaceOp(op, newWhileResults); |
| 3814 | return success(); |
| 3815 | } |
| 3816 | }; |
| 3817 | |
/// Remove WhileOp results that are unused and whose corresponding 'after'
/// block arguments are also unused.
| 3819 | /// |
| 3820 | /// %0:2 = scf.while () : () -> (i32, i64) { |
| 3821 | /// %condition = "test.condition"() : () -> i1 |
| 3822 | /// %v1 = "test.get_some_value"() : () -> i32 |
| 3823 | /// %v2 = "test.get_some_value"() : () -> i64 |
| 3824 | /// scf.condition(%condition) %v1, %v2 : i32, i64 |
| 3825 | /// } do { |
| 3826 | /// ^bb0(%arg0: i32, %arg1: i64): |
| 3827 | /// "test.use"(%arg0) : (i32) -> () |
| 3828 | /// scf.yield |
| 3829 | /// } |
| 3830 | /// return %0#0 : i32 |
| 3831 | /// |
| 3832 | /// becomes |
| 3833 | /// %0 = scf.while () : () -> (i32) { |
| 3834 | /// %condition = "test.condition"() : () -> i1 |
| 3835 | /// %v1 = "test.get_some_value"() : () -> i32 |
| 3836 | /// %v2 = "test.get_some_value"() : () -> i64 |
| 3837 | /// scf.condition(%condition) %v1 : i32 |
| 3838 | /// } do { |
| 3839 | /// ^bb0(%arg0: i32): |
| 3840 | /// "test.use"(%arg0) : (i32) -> () |
| 3841 | /// scf.yield |
| 3842 | /// } |
| 3843 | /// return %0 : i32 |
| 3844 | struct WhileUnusedResult : public OpRewritePattern<WhileOp> { |
| 3845 | using OpRewritePattern<WhileOp>::OpRewritePattern; |
| 3846 | |
| 3847 | LogicalResult matchAndRewrite(WhileOp op, |
| 3848 | PatternRewriter &rewriter) const override { |
| 3849 | auto term = op.getConditionOp(); |
| 3850 | auto afterArgs = op.getAfterArguments(); |
| 3851 | auto termArgs = term.getArgs(); |
| 3852 | |
| 3853 | // Collect results mapping, new terminator args and new result types. |
| 3854 | SmallVector<unsigned> newResultsIndices; |
| 3855 | SmallVector<Type> newResultTypes; |
| 3856 | SmallVector<Value> newTermArgs; |
| 3857 | SmallVector<Location> newArgLocs; |
| 3858 | bool needUpdate = false; |
| 3859 | for (const auto &it : |
| 3860 | llvm::enumerate(llvm::zip(op.getResults(), afterArgs, termArgs))) { |
| 3861 | auto i = static_cast<unsigned>(it.index()); |
| 3862 | Value result = std::get<0>(it.value()); |
| 3863 | Value afterArg = std::get<1>(it.value()); |
| 3864 | Value termArg = std::get<2>(it.value()); |
| 3865 | if (result.use_empty() && afterArg.use_empty()) { |
| 3866 | needUpdate = true; |
| 3867 | } else { |
| 3868 | newResultsIndices.emplace_back(i); |
| 3869 | newTermArgs.emplace_back(termArg); |
| 3870 | newResultTypes.emplace_back(result.getType()); |
| 3871 | newArgLocs.emplace_back(result.getLoc()); |
| 3872 | } |
| 3873 | } |
| 3874 | |
| 3875 | if (!needUpdate) |
| 3876 | return failure(); |
| 3877 | |
| 3878 | { |
| 3879 | OpBuilder::InsertionGuard g(rewriter); |
| 3880 | rewriter.setInsertionPoint(term); |
| 3881 | rewriter.replaceOpWithNewOp<ConditionOp>(term, term.getCondition(), |
| 3882 | newTermArgs); |
| 3883 | } |
| 3884 | |
| 3885 | auto newWhile = |
| 3886 | rewriter.create<WhileOp>(op.getLoc(), newResultTypes, op.getInits()); |
| 3887 | |
| 3888 | Block &newAfterBlock = *rewriter.createBlock( |
| 3889 | &newWhile.getAfter(), /*insertPt*/ {}, newResultTypes, newArgLocs); |
| 3890 | |
| 3891 | // Build new results list and new after block args (unused entries will be |
| 3892 | // null). |
| 3893 | SmallVector<Value> newResults(op.getNumResults()); |
| 3894 | SmallVector<Value> newAfterBlockArgs(op.getNumResults()); |
    for (const auto &it : llvm::enumerate(newResultsIndices)) {
      newResults[it.value()] = newWhile.getResult(it.index());
      newAfterBlockArgs[it.value()] = newAfterBlock.getArgument(it.index());
| 3898 | } |
| 3899 | |
| 3900 | rewriter.inlineRegionBefore(op.getBefore(), newWhile.getBefore(), |
| 3901 | newWhile.getBefore().begin()); |
| 3902 | |
| 3903 | Block &afterBlock = *op.getAfterBody(); |
    rewriter.mergeBlocks(&afterBlock, &newAfterBlock, newAfterBlockArgs);
| 3905 | |
| 3906 | rewriter.replaceOp(op, newResults); |
| 3907 | return success(); |
| 3908 | } |
| 3909 | }; |
| 3910 | |
/// Replace operations in the 'after' (do) block that are equivalent to the
/// loop condition with true, since the block only executes when the condition
/// holds.
| 3913 | /// |
| 3914 | /// scf.while (..) : (i32, ...) -> ... { |
| 3915 | /// %z = ... : i32 |
| 3916 | /// %condition = cmpi pred %z, %a |
| 3917 | /// scf.condition(%condition) %z : i32, ... |
| 3918 | /// } do { |
| 3919 | /// ^bb0(%arg0: i32, ...): |
| 3920 | /// %condition2 = cmpi pred %arg0, %a |
| 3921 | /// use(%condition2) |
| 3922 | /// ... |
| 3923 | /// |
| 3924 | /// becomes |
| 3925 | /// scf.while (..) : (i32, ...) -> ... { |
| 3926 | /// %z = ... : i32 |
| 3927 | /// %condition = cmpi pred %z, %a |
| 3928 | /// scf.condition(%condition) %z : i32, ... |
| 3929 | /// } do { |
| 3930 | /// ^bb0(%arg0: i32, ...): |
| 3931 | /// use(%true) |
| 3932 | /// ... |
| 3933 | struct WhileCmpCond : public OpRewritePattern<scf::WhileOp> { |
| 3934 | using OpRewritePattern<scf::WhileOp>::OpRewritePattern; |
| 3935 | |
| 3936 | LogicalResult matchAndRewrite(scf::WhileOp op, |
| 3937 | PatternRewriter &rewriter) const override { |
| 3938 | using namespace scf; |
| 3939 | auto cond = op.getConditionOp(); |
| 3940 | auto cmp = cond.getCondition().getDefiningOp<arith::CmpIOp>(); |
| 3941 | if (!cmp) |
| 3942 | return failure(); |
| 3943 | bool changed = false; |
| 3944 | for (auto tup : llvm::zip(cond.getArgs(), op.getAfterArguments())) { |
| 3945 | for (size_t opIdx = 0; opIdx < 2; opIdx++) { |
| 3946 | if (std::get<0>(tup) != cmp.getOperand(opIdx)) |
| 3947 | continue; |
| 3948 | for (OpOperand &u : |
| 3949 | llvm::make_early_inc_range(std::get<1>(tup).getUses())) { |
| 3950 | auto cmp2 = dyn_cast<arith::CmpIOp>(u.getOwner()); |
| 3951 | if (!cmp2) |
| 3952 | continue; |
| 3953 | // For a binary operator 1-opIdx gets the other side. |
| 3954 | if (cmp2.getOperand(1 - opIdx) != cmp.getOperand(1 - opIdx)) |
| 3955 | continue; |
| 3956 | bool samePredicate; |
| 3957 | if (cmp2.getPredicate() == cmp.getPredicate()) |
| 3958 | samePredicate = true; |
| 3959 | else if (cmp2.getPredicate() == |
| 3960 | arith::invertPredicate(cmp.getPredicate())) |
| 3961 | samePredicate = false; |
| 3962 | else |
| 3963 | continue; |
| 3964 | |
| 3965 | rewriter.replaceOpWithNewOp<arith::ConstantIntOp>(cmp2, samePredicate, |
| 3966 | 1); |
| 3967 | changed = true; |
| 3968 | } |
| 3969 | } |
| 3970 | } |
    return success(changed);
| 3972 | } |
| 3973 | }; |
| 3974 | |
| 3975 | /// Remove unused init/yield args. |
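///
/// For illustration, a hypothetical example (op names follow the test dialect
/// used in the examples above); %arg0 is unused in the `before` block, so its
/// init value and the corresponding yield operand can be dropped:
///
/// %0 = scf.while (%arg0 = %a, %arg1 = %b) : (i32, i64) -> i64 {
///   %condition = "test.condition"() : () -> i1
///   scf.condition(%condition) %arg1 : i64
/// } do {
/// ^bb0(%arg2: i64):
///   %next = "test.get_some_value"() : () -> i64
///   scf.yield %a, %next : i32, i64
/// }
///
/// becomes
/// %0 = scf.while (%arg1 = %b) : (i64) -> i64 {
///   %condition = "test.condition"() : () -> i1
///   scf.condition(%condition) %arg1 : i64
/// } do {
/// ^bb0(%arg2: i64):
///   %next = "test.get_some_value"() : () -> i64
///   scf.yield %next : i64
/// }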
| 3976 | struct WhileRemoveUnusedArgs : public OpRewritePattern<WhileOp> { |
| 3977 | using OpRewritePattern<WhileOp>::OpRewritePattern; |
| 3978 | |
| 3979 | LogicalResult matchAndRewrite(WhileOp op, |
| 3980 | PatternRewriter &rewriter) const override { |
| 3981 | |
| 3982 | if (!llvm::any_of(op.getBeforeArguments(), |
| 3983 | [](Value arg) { return arg.use_empty(); })) |
| 3984 | return rewriter.notifyMatchFailure(op, "No args to remove" ); |
| 3985 | |
| 3986 | YieldOp yield = op.getYieldOp(); |
| 3987 | |
    // Collect the new yield values, the new init values, and the set of block
    // arguments to erase.
| 3989 | SmallVector<Value> newYields; |
| 3990 | SmallVector<Value> newInits; |
| 3991 | llvm::BitVector argsToErase; |
| 3992 | |
| 3993 | size_t argsCount = op.getBeforeArguments().size(); |
    newYields.reserve(argsCount);
    newInits.reserve(argsCount);
    argsToErase.reserve(argsCount);
| 3997 | for (auto &&[beforeArg, yieldValue, initValue] : llvm::zip( |
| 3998 | op.getBeforeArguments(), yield.getOperands(), op.getInits())) { |
| 3999 | if (beforeArg.use_empty()) { |
| 4000 | argsToErase.push_back(true); |
| 4001 | } else { |
| 4002 | argsToErase.push_back(false); |
| 4003 | newYields.emplace_back(yieldValue); |
| 4004 | newInits.emplace_back(initValue); |
| 4005 | } |
| 4006 | } |
| 4007 | |
| 4008 | Block &beforeBlock = *op.getBeforeBody(); |
| 4009 | Block &afterBlock = *op.getAfterBody(); |
| 4010 | |
    beforeBlock.eraseArguments(argsToErase);
| 4012 | |
| 4013 | Location loc = op.getLoc(); |
| 4014 | auto newWhileOp = |
| 4015 | rewriter.create<WhileOp>(loc, op.getResultTypes(), newInits, |
| 4016 | /*beforeBody*/ nullptr, /*afterBody*/ nullptr); |
| 4017 | Block &newBeforeBlock = *newWhileOp.getBeforeBody(); |
| 4018 | Block &newAfterBlock = *newWhileOp.getAfterBody(); |
| 4019 | |
| 4020 | OpBuilder::InsertionGuard g(rewriter); |
| 4021 | rewriter.setInsertionPoint(yield); |
| 4022 | rewriter.replaceOpWithNewOp<YieldOp>(yield, newYields); |
| 4023 | |
    rewriter.mergeBlocks(&beforeBlock, &newBeforeBlock,
                         newBeforeBlock.getArguments());
    rewriter.mergeBlocks(&afterBlock, &newAfterBlock,
                         newAfterBlock.getArguments());
| 4028 | |
| 4029 | rewriter.replaceOp(op, newWhileOp.getResults()); |
| 4030 | return success(); |
| 4031 | } |
| 4032 | }; |
| 4033 | |
| 4034 | /// Remove duplicated ConditionOp args. |
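///
/// For illustration, a hypothetical example where the same value is forwarded
/// twice by scf.condition:
///
/// %0:2 = scf.while () : () -> (i32, i32) {
///   %condition = "test.condition"() : () -> i1
///   %v = "test.get_some_value"() : () -> i32
///   scf.condition(%condition) %v, %v : i32, i32
/// } do {
/// ^bb0(%arg0: i32, %arg1: i32):
///   "test.use"(%arg0, %arg1) : (i32, i32) -> ()
///   scf.yield
/// }
///
/// becomes
/// %0 = scf.while () : () -> i32 {
///   %condition = "test.condition"() : () -> i1
///   %v = "test.get_some_value"() : () -> i32
///   scf.condition(%condition) %v : i32
/// } do {
/// ^bb0(%arg0: i32):
///   "test.use"(%arg0, %arg0) : (i32, i32) -> ()
///   scf.yield
/// }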
| 4035 | struct WhileRemoveDuplicatedResults : public OpRewritePattern<WhileOp> { |
| 4036 | using OpRewritePattern::OpRewritePattern; |
| 4037 | |
| 4038 | LogicalResult matchAndRewrite(WhileOp op, |
| 4039 | PatternRewriter &rewriter) const override { |
| 4040 | ConditionOp condOp = op.getConditionOp(); |
| 4041 | ValueRange condOpArgs = condOp.getArgs(); |
| 4042 | |
| 4043 | llvm::SmallPtrSet<Value, 8> argsSet(llvm::from_range, condOpArgs); |
| 4044 | |
| 4045 | if (argsSet.size() == condOpArgs.size()) |
| 4046 | return rewriter.notifyMatchFailure(op, "No results to remove" ); |
| 4047 | |
| 4048 | llvm::SmallDenseMap<Value, unsigned> argsMap; |
| 4049 | SmallVector<Value> newArgs; |
    argsMap.reserve(condOpArgs.size());
    newArgs.reserve(condOpArgs.size());
| 4052 | for (Value arg : condOpArgs) { |
| 4053 | if (!argsMap.count(arg)) { |
| 4054 | auto pos = static_cast<unsigned>(argsMap.size()); |
| 4055 | argsMap.insert({arg, pos}); |
| 4056 | newArgs.emplace_back(arg); |
| 4057 | } |
| 4058 | } |
| 4059 | |
| 4060 | ValueRange argsRange(newArgs); |
| 4061 | |
| 4062 | Location loc = op.getLoc(); |
| 4063 | auto newWhileOp = rewriter.create<scf::WhileOp>( |
| 4064 | loc, argsRange.getTypes(), op.getInits(), /*beforeBody*/ nullptr, |
| 4065 | /*afterBody*/ nullptr); |
| 4066 | Block &newBeforeBlock = *newWhileOp.getBeforeBody(); |
| 4067 | Block &newAfterBlock = *newWhileOp.getAfterBody(); |
| 4068 | |
| 4069 | SmallVector<Value> afterArgsMapping; |
| 4070 | SmallVector<Value> resultsMapping; |
| 4071 | for (auto &&[i, arg] : llvm::enumerate(condOpArgs)) { |
| 4072 | auto it = argsMap.find(arg); |
| 4073 | assert(it != argsMap.end()); |
| 4074 | auto pos = it->second; |
| 4075 | afterArgsMapping.emplace_back(newAfterBlock.getArgument(pos)); |
| 4076 | resultsMapping.emplace_back(newWhileOp->getResult(pos)); |
| 4077 | } |
| 4078 | |
| 4079 | OpBuilder::InsertionGuard g(rewriter); |
| 4080 | rewriter.setInsertionPoint(condOp); |
| 4081 | rewriter.replaceOpWithNewOp<ConditionOp>(condOp, condOp.getCondition(), |
| 4082 | argsRange); |
| 4083 | |
| 4084 | Block &beforeBlock = *op.getBeforeBody(); |
| 4085 | Block &afterBlock = *op.getAfterBody(); |
| 4086 | |
    rewriter.mergeBlocks(&beforeBlock, &newBeforeBlock,
                         newBeforeBlock.getArguments());
    rewriter.mergeBlocks(&afterBlock, &newAfterBlock, afterArgsMapping);
| 4090 | rewriter.replaceOp(op, resultsMapping); |
| 4091 | return success(); |
| 4092 | } |
| 4093 | }; |
| 4094 | |
/// If both ranges contain the same values, return the mapping of indices from
/// args2 to args1. Otherwise return std::nullopt.
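/// For example (illustrative): args1 = [a, b, c] and args2 = [c, a, b] give
/// the mapping [2, 0, 1], i.e. args2[i] == args1[mapping[i]].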
| 4097 | static std::optional<SmallVector<unsigned>> getArgsMapping(ValueRange args1, |
| 4098 | ValueRange args2) { |
| 4099 | if (args1.size() != args2.size()) |
| 4100 | return std::nullopt; |
| 4101 | |
| 4102 | SmallVector<unsigned> ret(args1.size()); |
  for (auto &&[i, arg1] : llvm::enumerate(args1)) {
    auto it = llvm::find(args2, arg1);
    if (it == args2.end())
      return std::nullopt;

    ret[std::distance(args2.begin(), it)] = static_cast<unsigned>(i);
| 4109 | } |
| 4110 | |
| 4111 | return ret; |
| 4112 | } |
| 4113 | |
| 4114 | static bool hasDuplicates(ValueRange args) { |
| 4115 | llvm::SmallDenseSet<Value> set; |
| 4116 | for (Value arg : args) { |
    if (!set.insert(arg).second)
| 4118 | return true; |
| 4119 | } |
| 4120 | return false; |
| 4121 | } |
| 4122 | |
/// If the `before` block args are directly forwarded to `scf.condition`,
/// rearrange the `scf.condition` args into the same order as the block args.
/// Update the `after` block args and the op result values accordingly.
/// This is needed to simplify `scf.while` -> `scf.for` uplifting.
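///
/// For illustration, a hypothetical example where the condition forwards the
/// block args in reversed order:
///
/// scf.while (%arg0 = %x, %arg1 = %y) : (i32, i64) -> (i64, i32) {
///   %condition = "test.condition"() : () -> i1
///   scf.condition(%condition) %arg1, %arg0 : i64, i32
/// } do {
/// ^bb0(%a: i64, %b: i32):
///   ...
///
/// becomes
/// scf.while (%arg0 = %x, %arg1 = %y) : (i32, i64) -> (i32, i64) {
///   %condition = "test.condition"() : () -> i1
///   scf.condition(%condition) %arg0, %arg1 : i32, i64
/// } do {
/// ^bb0(%b: i32, %a: i64):
///   ...
///
/// with uses of the old results and `after` block args remapped accordingly.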
| 4127 | struct WhileOpAlignBeforeArgs : public OpRewritePattern<WhileOp> { |
| 4128 | using OpRewritePattern::OpRewritePattern; |
| 4129 | |
| 4130 | LogicalResult matchAndRewrite(WhileOp loop, |
| 4131 | PatternRewriter &rewriter) const override { |
| 4132 | auto oldBefore = loop.getBeforeBody(); |
| 4133 | ConditionOp oldTerm = loop.getConditionOp(); |
| 4134 | ValueRange beforeArgs = oldBefore->getArguments(); |
| 4135 | ValueRange termArgs = oldTerm.getArgs(); |
| 4136 | if (beforeArgs == termArgs) |
| 4137 | return failure(); |
| 4138 | |
    if (hasDuplicates(termArgs))
| 4140 | return failure(); |
| 4141 | |
    auto mapping = getArgsMapping(beforeArgs, termArgs);
| 4143 | if (!mapping) |
| 4144 | return failure(); |
| 4145 | |
| 4146 | { |
| 4147 | OpBuilder::InsertionGuard g(rewriter); |
| 4148 | rewriter.setInsertionPoint(oldTerm); |
| 4149 | rewriter.replaceOpWithNewOp<ConditionOp>(oldTerm, oldTerm.getCondition(), |
| 4150 | beforeArgs); |
| 4151 | } |
| 4152 | |
| 4153 | auto oldAfter = loop.getAfterBody(); |
| 4154 | |
| 4155 | SmallVector<Type> newResultTypes(beforeArgs.size()); |
| 4156 | for (auto &&[i, j] : llvm::enumerate(*mapping)) |
| 4157 | newResultTypes[j] = loop.getResult(i).getType(); |
| 4158 | |
| 4159 | auto newLoop = rewriter.create<WhileOp>( |
| 4160 | loop.getLoc(), newResultTypes, loop.getInits(), |
| 4161 | /*beforeBuilder=*/nullptr, /*afterBuilder=*/nullptr); |
| 4162 | auto newBefore = newLoop.getBeforeBody(); |
| 4163 | auto newAfter = newLoop.getAfterBody(); |
| 4164 | |
| 4165 | SmallVector<Value> newResults(beforeArgs.size()); |
| 4166 | SmallVector<Value> newAfterArgs(beforeArgs.size()); |
| 4167 | for (auto &&[i, j] : llvm::enumerate(*mapping)) { |
| 4168 | newResults[i] = newLoop.getResult(j); |
| 4169 | newAfterArgs[i] = newAfter->getArgument(j); |
| 4170 | } |
| 4171 | |
| 4172 | rewriter.inlineBlockBefore(oldBefore, newBefore, newBefore->begin(), |
| 4173 | newBefore->getArguments()); |
| 4174 | rewriter.inlineBlockBefore(oldAfter, newAfter, newAfter->begin(), |
| 4175 | newAfterArgs); |
| 4176 | |
| 4177 | rewriter.replaceOp(loop, newResults); |
| 4178 | return success(); |
| 4179 | } |
| 4180 | }; |
| 4181 | } // namespace |
| 4182 | |
| 4183 | void WhileOp::getCanonicalizationPatterns(RewritePatternSet &results, |
| 4184 | MLIRContext *context) { |
| 4185 | results.add<RemoveLoopInvariantArgsFromBeforeBlock, |
| 4186 | RemoveLoopInvariantValueYielded, WhileConditionTruth, |
| 4187 | WhileCmpCond, WhileUnusedResult, WhileRemoveDuplicatedResults, |
| 4188 | WhileRemoveUnusedArgs, WhileOpAlignBeforeArgs>(context); |
| 4189 | } |
| 4190 | |
| 4191 | //===----------------------------------------------------------------------===// |
| 4192 | // IndexSwitchOp |
| 4193 | //===----------------------------------------------------------------------===// |
| 4194 | |
| 4195 | /// Parse the case regions and values. |
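/// The textual form handled here is a sequence of `case <value> <region>`
/// entries, e.g. `case 2 { ... scf.yield ... }`.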
| 4196 | static ParseResult |
| 4197 | parseSwitchCases(OpAsmParser &p, DenseI64ArrayAttr &cases, |
| 4198 | SmallVectorImpl<std::unique_ptr<Region>> &caseRegions) { |
| 4199 | SmallVector<int64_t> caseValues; |
  while (succeeded(p.parseOptionalKeyword("case"))) {
    int64_t value;
    Region &region = *caseRegions.emplace_back(std::make_unique<Region>());
    if (p.parseInteger(value) || p.parseRegion(region, /*arguments=*/{}))
      return failure();
    caseValues.push_back(value);
| 4206 | } |
| 4207 | cases = p.getBuilder().getDenseI64ArrayAttr(caseValues); |
| 4208 | return success(); |
| 4209 | } |
| 4210 | |
| 4211 | /// Print the case regions and values. |
| 4212 | static void printSwitchCases(OpAsmPrinter &p, Operation *op, |
| 4213 | DenseI64ArrayAttr cases, RegionRange caseRegions) { |
| 4214 | for (auto [value, region] : llvm::zip(cases.asArrayRef(), caseRegions)) { |
| 4215 | p.printNewline(); |
| 4216 | p << "case " << value << ' '; |
| 4217 | p.printRegion(*region, /*printEntryBlockArgs=*/false); |
| 4218 | } |
| 4219 | } |
| 4220 | |
| 4221 | LogicalResult scf::IndexSwitchOp::verify() { |
| 4222 | if (getCases().size() != getCaseRegions().size()) { |
| 4223 | return emitOpError("has " ) |
| 4224 | << getCaseRegions().size() << " case regions but " |
| 4225 | << getCases().size() << " case values" ; |
| 4226 | } |
| 4227 | |
| 4228 | DenseSet<int64_t> valueSet; |
| 4229 | for (int64_t value : getCases()) |
| 4230 | if (!valueSet.insert(value).second) |
| 4231 | return emitOpError("has duplicate case value: " ) << value; |
| 4232 | auto verifyRegion = [&](Region ®ion, const Twine &name) -> LogicalResult { |
| 4233 | auto yield = dyn_cast<YieldOp>(region.front().back()); |
| 4234 | if (!yield) |
| 4235 | return emitOpError("expected region to end with scf.yield, but got " ) |
| 4236 | << region.front().back().getName(); |
| 4237 | |
| 4238 | if (yield.getNumOperands() != getNumResults()) { |
| 4239 | return (emitOpError("expected each region to return " ) |
| 4240 | << getNumResults() << " values, but " << name << " returns " |
| 4241 | << yield.getNumOperands()) |
| 4242 | .attachNote(yield.getLoc()) |
| 4243 | << "see yield operation here" ; |
| 4244 | } |
| 4245 | for (auto [idx, result, operand] : |
| 4246 | llvm::zip(llvm::seq<unsigned>(0, getNumResults()), getResultTypes(), |
| 4247 | yield.getOperandTypes())) { |
| 4248 | if (result == operand) |
| 4249 | continue; |
| 4250 | return (emitOpError("expected result #" ) |
| 4251 | << idx << " of each region to be " << result) |
| 4252 | .attachNote(yield.getLoc()) |
| 4253 | << name << " returns " << operand << " here" ; |
| 4254 | } |
| 4255 | return success(); |
| 4256 | }; |
| 4257 | |
| 4258 | if (failed(verifyRegion(getDefaultRegion(), "default region" ))) |
| 4259 | return failure(); |
| 4260 | for (auto [idx, caseRegion] : llvm::enumerate(getCaseRegions())) |
| 4261 | if (failed(verifyRegion(caseRegion, "case region #" + Twine(idx)))) |
| 4262 | return failure(); |
| 4263 | |
| 4264 | return success(); |
| 4265 | } |
| 4266 | |
| 4267 | unsigned scf::IndexSwitchOp::getNumCases() { return getCases().size(); } |
| 4268 | |
| 4269 | Block &scf::IndexSwitchOp::getDefaultBlock() { |
| 4270 | return getDefaultRegion().front(); |
| 4271 | } |
| 4272 | |
| 4273 | Block &scf::IndexSwitchOp::getCaseBlock(unsigned idx) { |
| 4274 | assert(idx < getNumCases() && "case index out-of-bounds" ); |
| 4275 | return getCaseRegions()[idx].front(); |
| 4276 | } |
| 4277 | |
| 4278 | void IndexSwitchOp::getSuccessorRegions( |
| 4279 | RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &successors) { |
| 4280 | // All regions branch back to the parent op. |
| 4281 | if (!point.isParent()) { |
| 4282 | successors.emplace_back(getResults()); |
| 4283 | return; |
| 4284 | } |
| 4285 | |
| 4286 | llvm::append_range(successors, getRegions()); |
| 4287 | } |
| 4288 | |
| 4289 | void IndexSwitchOp::getEntrySuccessorRegions( |
| 4290 | ArrayRef<Attribute> operands, |
| 4291 | SmallVectorImpl<RegionSuccessor> &successors) { |
| 4292 | FoldAdaptor adaptor(operands, *this); |
| 4293 | |
| 4294 | // If a constant was not provided, all regions are possible successors. |
| 4295 | auto arg = dyn_cast_or_null<IntegerAttr>(adaptor.getArg()); |
| 4296 | if (!arg) { |
| 4297 | llvm::append_range(successors, getRegions()); |
| 4298 | return; |
| 4299 | } |
| 4300 | |
| 4301 | // Otherwise, try to find a case with a matching value. If not, the |
| 4302 | // default region is the only successor. |
| 4303 | for (auto [caseValue, caseRegion] : llvm::zip(getCases(), getCaseRegions())) { |
| 4304 | if (caseValue == arg.getInt()) { |
| 4305 | successors.emplace_back(&caseRegion); |
| 4306 | return; |
| 4307 | } |
| 4308 | } |
| 4309 | successors.emplace_back(&getDefaultRegion()); |
| 4310 | } |
| 4311 | |
| 4312 | void IndexSwitchOp::getRegionInvocationBounds( |
| 4313 | ArrayRef<Attribute> operands, SmallVectorImpl<InvocationBounds> &bounds) { |
| 4314 | auto operandValue = llvm::dyn_cast_or_null<IntegerAttr>(operands.front()); |
| 4315 | if (!operandValue) { |
| 4316 | // All regions are invoked at most once. |
| 4317 | bounds.append(getNumRegions(), InvocationBounds(/*lb=*/0, /*ub=*/1)); |
| 4318 | return; |
| 4319 | } |
| 4320 | |
| 4321 | unsigned liveIndex = getNumRegions() - 1; |
| 4322 | const auto *it = llvm::find(getCases(), operandValue.getInt()); |
| 4323 | if (it != getCases().end()) |
| 4324 | liveIndex = std::distance(getCases().begin(), it); |
| 4325 | for (unsigned i = 0, e = getNumRegions(); i < e; ++i) |
| 4326 | bounds.emplace_back(/*lb=*/0, /*ub=*/i == liveIndex); |
| 4327 | } |
| 4328 | |
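/// Fold away an `scf.index_switch` whose argument is a known constant by
/// inlining the body of the matching case region (or the default region when
/// no case matches).
///
/// For illustration, a hypothetical example:
///
/// %c1 = arith.constant 1 : index
/// %0 = scf.index_switch %c1 -> i32
/// case 1 {
///   scf.yield %a : i32
/// }
/// default {
///   scf.yield %b : i32
/// }
///
/// folds so that uses of %0 are replaced by %a and the switch is removed.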
| 4329 | struct FoldConstantCase : OpRewritePattern<scf::IndexSwitchOp> { |
| 4330 | using OpRewritePattern<scf::IndexSwitchOp>::OpRewritePattern; |
| 4331 | |
| 4332 | LogicalResult matchAndRewrite(scf::IndexSwitchOp op, |
| 4333 | PatternRewriter &rewriter) const override { |
    // If `op.getArg()` is a constant, select the region that matches the
    // constant value. Use the default region if no match is found.
| 4336 | std::optional<int64_t> maybeCst = getConstantIntValue(op.getArg()); |
| 4337 | if (!maybeCst.has_value()) |
| 4338 | return failure(); |
| 4339 | int64_t cst = *maybeCst; |
| 4340 | int64_t caseIdx, e = op.getNumCases(); |
| 4341 | for (caseIdx = 0; caseIdx < e; ++caseIdx) { |
| 4342 | if (cst == op.getCases()[caseIdx]) |
| 4343 | break; |
| 4344 | } |
| 4345 | |
| 4346 | Region &r = (caseIdx < op.getNumCases()) ? op.getCaseRegions()[caseIdx] |
| 4347 | : op.getDefaultRegion(); |
| 4348 | Block &source = r.front(); |
| 4349 | Operation *terminator = source.getTerminator(); |
| 4350 | SmallVector<Value> results = terminator->getOperands(); |
| 4351 | |
| 4352 | rewriter.inlineBlockBefore(&source, op); |
    rewriter.eraseOp(terminator);
    // Replace the operation with a potentially empty list of results; the
    // fold mechanism doesn't support the case where the result list is empty.
| 4356 | rewriter.replaceOp(op, results); |
| 4357 | |
| 4358 | return success(); |
| 4359 | } |
| 4360 | }; |
| 4361 | |
| 4362 | void IndexSwitchOp::getCanonicalizationPatterns(RewritePatternSet &results, |
| 4363 | MLIRContext *context) { |
| 4364 | results.add<FoldConstantCase>(context); |
| 4365 | } |
| 4366 | |
| 4367 | //===----------------------------------------------------------------------===// |
| 4368 | // TableGen'd op method definitions |
| 4369 | //===----------------------------------------------------------------------===// |
| 4370 | |
| 4371 | #define GET_OP_CLASSES |
| 4372 | #include "mlir/Dialect/SCF/IR/SCFOps.cpp.inc" |
| 4373 | |