| 1 | //===- SCF.cpp - Structured Control Flow Operations -----------------------===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | |
| 9 | #include "mlir/Dialect/SCF/IR/SCF.h" |
| 10 | #include "mlir/Conversion/ConvertToEmitC/ToEmitCInterface.h" |
| 11 | #include "mlir/Dialect/Arith/IR/Arith.h" |
| 12 | #include "mlir/Dialect/Arith/Utils/Utils.h" |
| 13 | #include "mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.h" |
| 14 | #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" |
| 15 | #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" |
| 16 | #include "mlir/Dialect/MemRef/IR/MemRef.h" |
| 17 | #include "mlir/Dialect/SCF/IR/DeviceMappingInterface.h" |
| 18 | #include "mlir/Dialect/Tensor/IR/Tensor.h" |
| 19 | #include "mlir/IR/BuiltinAttributes.h" |
| 20 | #include "mlir/IR/IRMapping.h" |
| 21 | #include "mlir/IR/Matchers.h" |
| 22 | #include "mlir/IR/PatternMatch.h" |
| 23 | #include "mlir/Interfaces/FunctionInterfaces.h" |
| 24 | #include "mlir/Interfaces/ValueBoundsOpInterface.h" |
| 25 | #include "mlir/Transforms/InliningUtils.h" |
| 26 | #include "llvm/ADT/MapVector.h" |
| 27 | #include "llvm/ADT/SmallPtrSet.h" |
| 28 | |
| 29 | using namespace mlir; |
| 30 | using namespace mlir::scf; |
| 31 | |
| 32 | #include "mlir/Dialect/SCF/IR/SCFOpsDialect.cpp.inc" |
| 33 | |
| 34 | //===----------------------------------------------------------------------===// |
| 35 | // SCFDialect Dialect Interfaces |
| 36 | //===----------------------------------------------------------------------===// |
| 37 | |
| 38 | namespace { |
| 39 | struct SCFInlinerInterface : public DialectInlinerInterface { |
| 40 | using DialectInlinerInterface::DialectInlinerInterface; |
| 41 | // We don't have any special restrictions on what can be inlined into |
| 42 | // destination regions (e.g. while/conditional bodies). Always allow it. |
| 43 | bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned, |
| 44 | IRMapping &valueMapping) const final { |
| 45 | return true; |
| 46 | } |
| 47 | // Operations in scf dialect are always legal to inline since they are |
| 48 | // pure. |
| 49 | bool isLegalToInline(Operation *, Region *, bool, IRMapping &) const final { |
| 50 | return true; |
| 51 | } |
| 52 | // Handle the given inlined terminator by replacing it with a new operation |
| 53 | // as necessary. Required when the region has only one block. |
| 54 | void handleTerminator(Operation *op, ValueRange valuesToRepl) const final { |
| 55 | auto retValOp = dyn_cast<scf::YieldOp>(Val: op); |
| 56 | if (!retValOp) |
| 57 | return; |
| 58 | |
| 59 | for (auto retValue : llvm::zip(t&: valuesToRepl, u: retValOp.getOperands())) { |
| 60 | std::get<0>(t&: retValue).replaceAllUsesWith(newValue: std::get<1>(t&: retValue)); |
| 61 | } |
| 62 | } |
| 63 | }; |
| 64 | } // namespace |
| 65 | |
| 66 | //===----------------------------------------------------------------------===// |
| 67 | // SCFDialect |
| 68 | //===----------------------------------------------------------------------===// |
| 69 | |
/// Registers all SCF operations, the inliner interface, and the promised
/// interfaces whose implementations live in other libraries and are attached
/// lazily when those libraries are loaded.
void SCFDialect::initialize() {
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/SCF/IR/SCFOps.cpp.inc"
      >();
  addInterfaces<SCFInlinerInterface>();
  // Promised interfaces: EmitC conversion, bufferization, and value-bounds
  // analysis register the actual implementations on demand.
  declarePromisedInterface<ConvertToEmitCPatternInterface, SCFDialect>();
  declarePromisedInterfaces<bufferization::BufferDeallocationOpInterface,
                            InParallelOp, ReduceReturnOp>();
  declarePromisedInterfaces<bufferization::BufferizableOpInterface, ConditionOp,
                            ExecuteRegionOp, ForOp, IfOp, IndexSwitchOp,
                            ForallOp, InParallelOp, WhileOp, YieldOp>();
  declarePromisedInterface<ValueBoundsOpInterface, ForOp>();
}
| 84 | |
/// Default callback for IfOp builders. Inserts a yield without arguments,
/// terminating the freshly created block at the builder's insertion point.
void mlir::scf::buildTerminatedBody(OpBuilder &builder, Location loc) {
  builder.create<scf::YieldOp>(loc);
}
| 89 | |
| 90 | /// Verifies that the first block of the given `region` is terminated by a |
| 91 | /// TerminatorTy. Reports errors on the given operation if it is not the case. |
| 92 | template <typename TerminatorTy> |
| 93 | static TerminatorTy verifyAndGetTerminator(Operation *op, Region ®ion, |
| 94 | StringRef errorMessage) { |
| 95 | Operation *terminatorOperation = nullptr; |
| 96 | if (!region.empty() && !region.front().empty()) { |
| 97 | terminatorOperation = ®ion.front().back(); |
| 98 | if (auto yield = dyn_cast_or_null<TerminatorTy>(terminatorOperation)) |
| 99 | return yield; |
| 100 | } |
| 101 | auto diag = op->emitOpError(message: errorMessage); |
| 102 | if (terminatorOperation) |
| 103 | diag.attachNote(noteLoc: terminatorOperation->getLoc()) << "terminator here" ; |
| 104 | return nullptr; |
| 105 | } |
| 106 | |
| 107 | //===----------------------------------------------------------------------===// |
| 108 | // ExecuteRegionOp |
| 109 | //===----------------------------------------------------------------------===// |
| 110 | |
| 111 | /// Replaces the given op with the contents of the given single-block region, |
| 112 | /// using the operands of the block terminator to replace operation results. |
| 113 | static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op, |
| 114 | Region ®ion, ValueRange blockArgs = {}) { |
| 115 | assert(llvm::hasSingleElement(region) && "expected single-region block" ); |
| 116 | Block *block = ®ion.front(); |
| 117 | Operation *terminator = block->getTerminator(); |
| 118 | ValueRange results = terminator->getOperands(); |
| 119 | rewriter.inlineBlockBefore(source: block, op, argValues: blockArgs); |
| 120 | rewriter.replaceOp(op, newValues: results); |
| 121 | rewriter.eraseOp(op: terminator); |
| 122 | } |
| 123 | |
| 124 | /// |
| 125 | /// (ssa-id `=`)? `execute_region` `->` function-result-type `{` |
| 126 | /// block+ |
| 127 | /// `}` |
| 128 | /// |
| 129 | /// Example: |
| 130 | /// scf.execute_region -> i32 { |
| 131 | /// %idx = load %rI[%i] : memref<128xi32> |
| 132 | /// return %idx : i32 |
| 133 | /// } |
| 134 | /// |
| 135 | ParseResult ExecuteRegionOp::parse(OpAsmParser &parser, |
| 136 | OperationState &result) { |
| 137 | if (parser.parseOptionalArrowTypeList(result&: result.types)) |
| 138 | return failure(); |
| 139 | |
| 140 | // Introduce the body region and parse it. |
| 141 | Region *body = result.addRegion(); |
| 142 | if (parser.parseRegion(region&: *body, /*arguments=*/{}, /*argTypes=*/enableNameShadowing: {}) || |
| 143 | parser.parseOptionalAttrDict(result&: result.attributes)) |
| 144 | return failure(); |
| 145 | |
| 146 | return success(); |
| 147 | } |
| 148 | |
| 149 | void ExecuteRegionOp::print(OpAsmPrinter &p) { |
| 150 | p.printOptionalArrowTypeList(types: getResultTypes()); |
| 151 | |
| 152 | p << ' '; |
| 153 | p.printRegion(blocks&: getRegion(), |
| 154 | /*printEntryBlockArgs=*/false, |
| 155 | /*printBlockTerminators=*/true); |
| 156 | |
| 157 | p.printOptionalAttrDict(attrs: (*this)->getAttrs()); |
| 158 | } |
| 159 | |
| 160 | LogicalResult ExecuteRegionOp::verify() { |
| 161 | if (getRegion().empty()) |
| 162 | return emitOpError(message: "region needs to have at least one block" ); |
| 163 | if (getRegion().front().getNumArguments() > 0) |
| 164 | return emitOpError(message: "region cannot have any arguments" ); |
| 165 | return success(); |
| 166 | } |
| 167 | |
| 168 | // Inline an ExecuteRegionOp if it only contains one block. |
| 169 | // "test.foo"() : () -> () |
| 170 | // %v = scf.execute_region -> i64 { |
| 171 | // %x = "test.val"() : () -> i64 |
| 172 | // scf.yield %x : i64 |
| 173 | // } |
| 174 | // "test.bar"(%v) : (i64) -> () |
| 175 | // |
| 176 | // becomes |
| 177 | // |
| 178 | // "test.foo"() : () -> () |
| 179 | // %x = "test.val"() : () -> i64 |
| 180 | // "test.bar"(%x) : (i64) -> () |
| 181 | // |
| 182 | struct SingleBlockExecuteInliner : public OpRewritePattern<ExecuteRegionOp> { |
| 183 | using OpRewritePattern<ExecuteRegionOp>::OpRewritePattern; |
| 184 | |
| 185 | LogicalResult matchAndRewrite(ExecuteRegionOp op, |
| 186 | PatternRewriter &rewriter) const override { |
| 187 | if (!llvm::hasSingleElement(C&: op.getRegion())) |
| 188 | return failure(); |
| 189 | replaceOpWithRegion(rewriter, op, region&: op.getRegion()); |
| 190 | return success(); |
| 191 | } |
| 192 | }; |
| 193 | |
| 194 | // Inline an ExecuteRegionOp if its parent can contain multiple blocks. |
| 195 | // TODO generalize the conditions for operations which can be inlined into. |
| 196 | // func @func_execute_region_elim() { |
| 197 | // "test.foo"() : () -> () |
| 198 | // %v = scf.execute_region -> i64 { |
| 199 | // %c = "test.cmp"() : () -> i1 |
| 200 | // cf.cond_br %c, ^bb2, ^bb3 |
| 201 | // ^bb2: |
| 202 | // %x = "test.val1"() : () -> i64 |
| 203 | // cf.br ^bb4(%x : i64) |
| 204 | // ^bb3: |
| 205 | // %y = "test.val2"() : () -> i64 |
| 206 | // cf.br ^bb4(%y : i64) |
| 207 | // ^bb4(%z : i64): |
| 208 | // scf.yield %z : i64 |
| 209 | // } |
| 210 | // "test.bar"(%v) : (i64) -> () |
| 211 | // return |
| 212 | // } |
| 213 | // |
| 214 | // becomes |
| 215 | // |
| 216 | // func @func_execute_region_elim() { |
| 217 | // "test.foo"() : () -> () |
| 218 | // %c = "test.cmp"() : () -> i1 |
| 219 | // cf.cond_br %c, ^bb1, ^bb2 |
| 220 | // ^bb1: // pred: ^bb0 |
| 221 | // %x = "test.val1"() : () -> i64 |
| 222 | // cf.br ^bb3(%x : i64) |
| 223 | // ^bb2: // pred: ^bb0 |
| 224 | // %y = "test.val2"() : () -> i64 |
| 225 | // cf.br ^bb3(%y : i64) |
| 226 | // ^bb3(%z: i64): // 2 preds: ^bb1, ^bb2 |
| 227 | // "test.bar"(%z) : (i64) -> () |
| 228 | // return |
| 229 | // } |
| 230 | // |
struct MultiBlockExecuteInliner : public OpRewritePattern<ExecuteRegionOp> {
  using OpRewritePattern<ExecuteRegionOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ExecuteRegionOp op,
                                PatternRewriter &rewriter) const override {
    // Only inline when the parent is known to tolerate multi-block regions
    // with unstructured (cf) control flow.
    if (!isa<FunctionOpInterface, ExecuteRegionOp>(op->getParentOp()))
      return failure();

    // Split the enclosing block at the op: everything after it moves to
    // `postBlock`; `prevBlock` will branch into the inlined region.
    Block *prevBlock = op->getBlock();
    Block *postBlock = rewriter.splitBlock(prevBlock, op->getIterator());
    rewriter.setInsertionPointToEnd(prevBlock);

    rewriter.create<cf::BranchOp>(op.getLoc(), &op.getRegion().front());

    // Turn each scf.yield into a branch to `postBlock`, forwarding the
    // yielded values as branch operands.
    for (Block &blk : op.getRegion()) {
      if (YieldOp yieldOp = dyn_cast<YieldOp>(blk.getTerminator())) {
        rewriter.setInsertionPoint(yieldOp);
        rewriter.create<cf::BranchOp>(yieldOp.getLoc(), postBlock,
                                      yieldOp.getResults());
        rewriter.eraseOp(yieldOp);
      }
    }

    rewriter.inlineRegionBefore(op.getRegion(), postBlock);

    // The op's results become block arguments of `postBlock`, fed by the
    // branches created above.
    SmallVector<Value> blockArgs;
    for (auto res : op.getResults())
      blockArgs.push_back(postBlock->addArgument(res.getType(), res.getLoc()));

    rewriter.replaceOp(op, blockArgs);
    return success();
  }
};
| 264 | |
/// Registers both execute_region inlining canonicalizations.
void ExecuteRegionOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                  MLIRContext *context) {
  results.add<SingleBlockExecuteInliner, MultiBlockExecuteInliner>(context);
}
| 269 | |
| 270 | void ExecuteRegionOp::getSuccessorRegions( |
| 271 | RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> ®ions) { |
| 272 | // If the predecessor is the ExecuteRegionOp, branch into the body. |
| 273 | if (point.isParent()) { |
| 274 | regions.push_back(Elt: RegionSuccessor(&getRegion())); |
| 275 | return; |
| 276 | } |
| 277 | |
| 278 | // Otherwise, the region branches back to the parent operation. |
| 279 | regions.push_back(Elt: RegionSuccessor(getResults())); |
| 280 | } |
| 281 | |
| 282 | //===----------------------------------------------------------------------===// |
| 283 | // ConditionOp |
| 284 | //===----------------------------------------------------------------------===// |
| 285 | |
| 286 | MutableOperandRange |
| 287 | ConditionOp::getMutableSuccessorOperands(RegionBranchPoint point) { |
| 288 | assert((point.isParent() || point == getParentOp().getAfter()) && |
| 289 | "condition op can only exit the loop or branch to the after" |
| 290 | "region" ); |
| 291 | // Pass all operands except the condition to the successor region. |
| 292 | return getArgsMutable(); |
| 293 | } |
| 294 | |
| 295 | void ConditionOp::getSuccessorRegions( |
| 296 | ArrayRef<Attribute> operands, SmallVectorImpl<RegionSuccessor> ®ions) { |
| 297 | FoldAdaptor adaptor(operands, *this); |
| 298 | |
| 299 | WhileOp whileOp = getParentOp(); |
| 300 | |
| 301 | // Condition can either lead to the after region or back to the parent op |
| 302 | // depending on whether the condition is true or not. |
| 303 | auto boolAttr = dyn_cast_or_null<BoolAttr>(Val: adaptor.getCondition()); |
| 304 | if (!boolAttr || boolAttr.getValue()) |
| 305 | regions.emplace_back(Args: &whileOp.getAfter(), |
| 306 | Args: whileOp.getAfter().getArguments()); |
| 307 | if (!boolAttr || !boolAttr.getValue()) |
| 308 | regions.emplace_back(Args: whileOp.getResults()); |
| 309 | } |
| 310 | |
| 311 | //===----------------------------------------------------------------------===// |
| 312 | // ForOp |
| 313 | //===----------------------------------------------------------------------===// |
| 314 | |
/// Builds an scf.for with bounds `lb`/`ub` and step `step`, carrying
/// `initArgs` through the loop. The entry block receives the induction
/// variable (same type as `lb`) followed by one argument per init value.
void ForOp::build(OpBuilder &builder, OperationState &result, Value lb,
                  Value ub, Value step, ValueRange initArgs,
                  BodyBuilderFn bodyBuilder) {
  OpBuilder::InsertionGuard guard(builder);

  result.addOperands({lb, ub, step});
  result.addOperands(initArgs);
  // One op result per loop-carried value.
  for (Value v : initArgs)
    result.addTypes(v.getType());
  Type t = lb.getType();
  Region *bodyRegion = result.addRegion();
  Block *bodyBlock = builder.createBlock(bodyRegion);
  // Entry block arguments: induction variable first, then the iter args.
  bodyBlock->addArgument(t, result.location);
  for (Value v : initArgs)
    bodyBlock->addArgument(v.getType(), v.getLoc());

  // Create the default terminator if the builder is not provided and if the
  // iteration arguments are not provided. Otherwise, leave this to the caller
  // because we don't know which values to return from the loop.
  if (initArgs.empty() && !bodyBuilder) {
    ForOp::ensureTerminator(*bodyRegion, builder, result.location);
  } else if (bodyBuilder) {
    OpBuilder::InsertionGuard guard(builder);
    builder.setInsertionPointToStart(bodyBlock);
    bodyBuilder(builder, result.location, bodyBlock->getArgument(0),
                bodyBlock->getArguments().drop_front());
  }
}
| 343 | |
| 344 | LogicalResult ForOp::verify() { |
| 345 | // Check that the number of init args and op results is the same. |
| 346 | if (getInitArgs().size() != getNumResults()) |
| 347 | return emitOpError( |
| 348 | message: "mismatch in number of loop-carried values and defined values" ); |
| 349 | |
| 350 | return success(); |
| 351 | } |
| 352 | |
| 353 | LogicalResult ForOp::verifyRegions() { |
| 354 | // Check that the body defines as single block argument for the induction |
| 355 | // variable. |
| 356 | if (getInductionVar().getType() != getLowerBound().getType()) |
| 357 | return emitOpError( |
| 358 | message: "expected induction variable to be same type as bounds and step" ); |
| 359 | |
| 360 | if (getNumRegionIterArgs() != getNumResults()) |
| 361 | return emitOpError( |
| 362 | message: "mismatch in number of basic block args and defined values" ); |
| 363 | |
| 364 | auto initArgs = getInitArgs(); |
| 365 | auto iterArgs = getRegionIterArgs(); |
| 366 | auto opResults = getResults(); |
| 367 | unsigned i = 0; |
| 368 | for (auto e : llvm::zip(t&: initArgs, u&: iterArgs, args&: opResults)) { |
| 369 | if (std::get<0>(t&: e).getType() != std::get<2>(t&: e).getType()) |
| 370 | return emitOpError() << "types mismatch between " << i |
| 371 | << "th iter operand and defined value" ; |
| 372 | if (std::get<1>(t&: e).getType() != std::get<2>(t&: e).getType()) |
| 373 | return emitOpError() << "types mismatch between " << i |
| 374 | << "th iter region arg and defined value" ; |
| 375 | |
| 376 | ++i; |
| 377 | } |
| 378 | return success(); |
| 379 | } |
| 380 | |
/// scf.for is a single-dimensional loop: each LoopLikeOpInterface query
/// below returns a one-element vector (or the op's result range).
std::optional<SmallVector<Value>> ForOp::getLoopInductionVars() {
  return SmallVector<Value>{getInductionVar()};
}

std::optional<SmallVector<OpFoldResult>> ForOp::getLoopLowerBounds() {
  return SmallVector<OpFoldResult>{OpFoldResult(getLowerBound())};
}

std::optional<SmallVector<OpFoldResult>> ForOp::getLoopSteps() {
  return SmallVector<OpFoldResult>{OpFoldResult(getStep())};
}

std::optional<SmallVector<OpFoldResult>> ForOp::getLoopUpperBounds() {
  return SmallVector<OpFoldResult>{OpFoldResult(getUpperBound())};
}

std::optional<ResultRange> ForOp::getLoopResults() { return getResults(); }
| 398 | |
/// Promotes the loop body of a forOp to its containing block if the forOp
/// it can be determined that the loop has a single iteration.
LogicalResult ForOp::promoteIfSingleIteration(RewriterBase &rewriter) {
  // Only promote when the trip count is statically known to be exactly one.
  std::optional<int64_t> tripCount =
      constantTripCount(getLowerBound(), getUpperBound(), getStep());
  if (!tripCount.has_value() || tripCount != 1)
    return failure();

  // Replace all results with the yielded values.
  auto yieldOp = cast<scf::YieldOp>(getBody()->getTerminator());
  rewriter.replaceAllUsesWith(getResults(), getYieldedValues());

  // Replace block arguments with lower bound (replacement for IV) and
  // iter_args.
  SmallVector<Value> bbArgReplacements;
  bbArgReplacements.push_back(getLowerBound());
  llvm::append_range(bbArgReplacements, getInitArgs());

  // Move the loop body operations to the loop's containing block.
  rewriter.inlineBlockBefore(getBody(), getOperation()->getBlock(),
                             getOperation()->getIterator(), bbArgReplacements);

  // Erase the old terminator and the loop.
  rewriter.eraseOp(yieldOp);
  rewriter.eraseOp(*this);

  return success();
}
| 427 | |
| 428 | /// Prints the initialization list in the form of |
| 429 | /// <prefix>(%inner = %outer, %inner2 = %outer2, <...>) |
| 430 | /// where 'inner' values are assumed to be region arguments and 'outer' values |
| 431 | /// are regular SSA values. |
| 432 | static void printInitializationList(OpAsmPrinter &p, |
| 433 | Block::BlockArgListType blocksArgs, |
| 434 | ValueRange initializers, |
| 435 | StringRef prefix = "" ) { |
| 436 | assert(blocksArgs.size() == initializers.size() && |
| 437 | "expected same length of arguments and initializers" ); |
| 438 | if (initializers.empty()) |
| 439 | return; |
| 440 | |
| 441 | p << prefix << '('; |
| 442 | llvm::interleaveComma(c: llvm::zip(t&: blocksArgs, u&: initializers), os&: p, each_fn: [&](auto it) { |
| 443 | p << std::get<0>(it) << " = " << std::get<1>(it); |
| 444 | }); |
| 445 | p << ")" ; |
| 446 | } |
| 447 | |
void ForOp::print(OpAsmPrinter &p) {
  // %iv = %lb to %ub step %step
  p << " " << getInductionVar() << " = " << getLowerBound() << " to "
    << getUpperBound() << " step " << getStep();

  // Optional `iter_args(...)` followed by the result types.
  printInitializationList(p, getRegionIterArgs(), getInitArgs(), " iter_args");
  if (!getInitArgs().empty())
    p << " -> (" << getInitArgs().getTypes() << ')';
  p << ' ';
  // The induction-variable type is elided when it is the default `index`.
  if (Type t = getInductionVar().getType(); !t.isIndex())
    p << " : " << t << ' ';
  // The implicit empty scf.yield carries no information, so terminators are
  // printed only when there are iter args.
  p.printRegion(getRegion(),
                /*printEntryBlockArgs=*/false,
                /*printBlockTerminators=*/!getInitArgs().empty());
  p.printOptionalAttrDict((*this)->getAttrs());
}
| 463 | |
ParseResult ForOp::parse(OpAsmParser &parser, OperationState &result) {
  auto &builder = parser.getBuilder();
  Type type;

  OpAsmParser::Argument inductionVariable;
  OpAsmParser::UnresolvedOperand lb, ub, step;

  // Parse the induction variable followed by '='.
  if (parser.parseOperand(inductionVariable.ssaName) || parser.parseEqual() ||
      // Parse loop bounds.
      parser.parseOperand(lb) || parser.parseKeyword("to") ||
      parser.parseOperand(ub) || parser.parseKeyword("step") ||
      parser.parseOperand(step))
    return failure();

  // Parse the optional initial iteration arguments.
  SmallVector<OpAsmParser::Argument, 4> regionArgs;
  SmallVector<OpAsmParser::UnresolvedOperand, 4> operands;
  regionArgs.push_back(inductionVariable);

  bool hasIterArgs = succeeded(parser.parseOptionalKeyword("iter_args"));
  if (hasIterArgs) {
    // Parse assignment list and results type list.
    if (parser.parseAssignmentList(regionArgs, operands) ||
        parser.parseArrowTypeList(result.types))
      return failure();
  }

  // `regionArgs` holds the IV plus one entry per iter_arg, so it must be
  // exactly one longer than the result type list.
  if (regionArgs.size() != result.types.size() + 1)
    return parser.emitError(
        parser.getNameLoc(),
        "mismatch in number of loop-carried values and defined values");

  // Parse optional type, else assume Index.
  if (parser.parseOptionalColon())
    type = builder.getIndexType();
  else if (parser.parseType(type))
    return failure();

  // Set block argument types, so that they are known when parsing the region.
  regionArgs.front().type = type;
  for (auto [iterArg, type] :
       llvm::zip_equal(llvm::drop_begin(regionArgs), result.types))
    iterArg.type = type;

  // Parse the body region.
  Region *body = result.addRegion();
  if (parser.parseRegion(*body, regionArgs))
    return failure();
  // Insert the implicit empty scf.yield if the body did not provide one.
  ForOp::ensureTerminator(*body, builder, result.location);

  // Resolve input operands. This should be done after parsing the region to
  // catch invalid IR where operands were defined inside of the region.
  if (parser.resolveOperand(lb, type, result.operands) ||
      parser.resolveOperand(ub, type, result.operands) ||
      parser.resolveOperand(step, type, result.operands))
    return failure();
  if (hasIterArgs) {
    for (auto argOperandType : llvm::zip_equal(llvm::drop_begin(regionArgs),
                                               operands, result.types)) {
      Type type = std::get<2>(argOperandType);
      std::get<0>(argOperandType).type = type;
      if (parser.resolveOperand(std::get<1>(argOperandType), type,
                                result.operands))
        return failure();
    }
  }

  // Parse the optional attribute list.
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();

  return success();
}
| 538 | |
SmallVector<Region *> ForOp::getLoopRegions() { return {&getRegion()}; }

/// Returns the loop-carried block arguments: all entry-block arguments after
/// the induction variable(s).
Block::BlockArgListType ForOp::getRegionIterArgs() {
  return getBody()->getArguments().drop_front(getNumInductionVars());
}

MutableArrayRef<OpOperand> ForOp::getInitsMutable() {
  return getInitArgsMutable();
}
| 548 | |
/// Rebuilds this loop with `newInitOperands` appended to the iter args,
/// using `newYieldValuesFn` to produce the extra yielded values. Returns the
/// replacement loop; the original loop is erased.
FailureOr<LoopLikeOpInterface>
ForOp::replaceWithAdditionalYields(RewriterBase &rewriter,
                                   ValueRange newInitOperands,
                                   bool replaceInitOperandUsesInLoop,
                                   const NewYieldValuesFn &newYieldValuesFn) {
  // Create a new loop before the existing one, with the extra operands.
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(getOperation());
  auto inits = llvm::to_vector(getInitArgs());
  inits.append(newInitOperands.begin(), newInitOperands.end());
  // Empty body builder: the body is transplanted from the old loop below.
  scf::ForOp newLoop = rewriter.create<scf::ForOp>(
      getLoc(), getLowerBound(), getUpperBound(), getStep(), inits,
      [](OpBuilder &, Location, Value, ValueRange) {});
  newLoop->setAttrs(getPrunedAttributeList(getOperation(), {}));

  // Generate the new yield values and append them to the scf.yield operation.
  auto yieldOp = cast<scf::YieldOp>(getBody()->getTerminator());
  ArrayRef<BlockArgument> newIterArgs =
      newLoop.getBody()->getArguments().take_back(newInitOperands.size());
  {
    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPoint(yieldOp);
    SmallVector<Value> newYieldedValues =
        newYieldValuesFn(rewriter, getLoc(), newIterArgs);
    assert(newInitOperands.size() == newYieldedValues.size() &&
           "expected as many new yield values as new iter operands");
    rewriter.modifyOpInPlace(yieldOp, [&]() {
      yieldOp.getResultsMutable().append(newYieldedValues);
    });
  }

  // Move the loop body to the new op.
  rewriter.mergeBlocks(getBody(), newLoop.getBody(),
                       newLoop.getBody()->getArguments().take_front(
                           getBody()->getNumArguments()));

  if (replaceInitOperandUsesInLoop) {
    // Replace all uses of `newInitOperands` with the corresponding basic block
    // arguments, but only for uses nested inside the new loop.
    for (auto it : llvm::zip(newInitOperands, newIterArgs)) {
      rewriter.replaceUsesWithIf(std::get<0>(it), std::get<1>(it),
                                 [&](OpOperand &use) {
                                   Operation *user = use.getOwner();
                                   return newLoop->isProperAncestor(user);
                                 });
    }
  }

  // Replace the old loop. Results for the appended iter args are left for
  // the caller to wire up.
  rewriter.replaceOp(getOperation(),
                     newLoop->getResults().take_front(getNumResults()));
  return cast<LoopLikeOpInterface>(newLoop.getOperation());
}
| 602 | |
| 603 | ForOp mlir::scf::getForInductionVarOwner(Value val) { |
| 604 | auto ivArg = llvm::dyn_cast<BlockArgument>(Val&: val); |
| 605 | if (!ivArg) |
| 606 | return ForOp(); |
| 607 | assert(ivArg.getOwner() && "unlinked block argument" ); |
| 608 | auto *containingOp = ivArg.getOwner()->getParentOp(); |
| 609 | return dyn_cast_or_null<ForOp>(Val: containingOp); |
| 610 | } |
| 611 | |
/// On entry into the body, the region iter args are fed from the init
/// operands.
OperandRange ForOp::getEntrySuccessorOperands(RegionBranchPoint point) {
  return getInitArgs();
}
| 615 | |
| 616 | void ForOp::getSuccessorRegions(RegionBranchPoint point, |
| 617 | SmallVectorImpl<RegionSuccessor> ®ions) { |
| 618 | // Both the operation itself and the region may be branching into the body or |
| 619 | // back into the operation itself. It is possible for loop not to enter the |
| 620 | // body. |
| 621 | regions.push_back(Elt: RegionSuccessor(&getRegion(), getRegionIterArgs())); |
| 622 | regions.push_back(Elt: RegionSuccessor(getResults())); |
| 623 | } |
| 624 | |
SmallVector<Region *> ForallOp::getLoopRegions() { return {&getRegion()}; }

/// Promotes the loop body of a forallOp to its containing block if it can be
/// determined that the loop has a single iteration.
LogicalResult scf::ForallOp::promoteIfSingleIteration(RewriterBase &rewriter) {
  // Every dimension must have a static trip count of exactly one.
  for (auto [lb, ub, step] :
       llvm::zip(getMixedLowerBound(), getMixedUpperBound(), getMixedStep())) {
    auto tripCount = constantTripCount(lb, ub, step);
    if (!tripCount.has_value() || *tripCount != 1)
      return failure();
  }

  promote(rewriter, *this);
  return success();
}

/// Returns the shared-outs block arguments: the entry-block arguments that
/// follow the induction variables.
Block::BlockArgListType ForallOp::getRegionIterArgs() {
  return getBody()->getArguments().drop_front(getRank());
}

MutableArrayRef<OpOperand> ForallOp::getInitsMutable() {
  return getOutputsMutable();
}
| 648 | |
/// Promotes the loop body of a scf::ForallOp to its containing block.
void mlir::scf::promote(RewriterBase &rewriter, scf::ForallOp forallOp) {
  OpBuilder::InsertionGuard g(rewriter);
  scf::InParallelOp terminator = forallOp.getTerminator();

  // Replace block arguments with lower bounds (replacements for IVs) and
  // outputs.
  SmallVector<Value> bbArgReplacements = forallOp.getLowerBound(rewriter);
  bbArgReplacements.append(forallOp.getOutputs().begin(),
                           forallOp.getOutputs().end());

  // Move the loop body operations to the loop's containing block.
  rewriter.inlineBlockBefore(forallOp.getBody(), forallOp->getBlock(),
                             forallOp->getIterator(), bbArgReplacements);

  // Replace the terminator with tensor.insert_slice ops.
  rewriter.setInsertionPointAfter(forallOp);
  SmallVector<Value> results;
  results.reserve(forallOp.getResults().size());
  for (auto &yieldingOp : terminator.getYieldingOps()) {
    auto parallelInsertSliceOp =
        cast<tensor::ParallelInsertSliceOp>(yieldingOp);

    // Each parallel_insert_slice becomes a sequential insert_slice with the
    // same offsets/sizes/strides.
    Value dst = parallelInsertSliceOp.getDest();
    Value src = parallelInsertSliceOp.getSource();
    if (llvm::isa<TensorType>(src.getType())) {
      results.push_back(rewriter.create<tensor::InsertSliceOp>(
          forallOp.getLoc(), dst.getType(), src, dst,
          parallelInsertSliceOp.getOffsets(), parallelInsertSliceOp.getSizes(),
          parallelInsertSliceOp.getStrides(),
          parallelInsertSliceOp.getStaticOffsets(),
          parallelInsertSliceOp.getStaticSizes(),
          parallelInsertSliceOp.getStaticStrides()));
    } else {
      llvm_unreachable("unsupported terminator");
    }
  }
  rewriter.replaceAllUsesWith(forallOp.getResults(), results);

  // Erase the old terminator and the loop.
  rewriter.eraseOp(terminator);
  rewriter.eraseOp(forallOp);
}
| 692 | |
LoopNest mlir::scf::buildLoopNest(
    OpBuilder &builder, Location loc, ValueRange lbs, ValueRange ubs,
    ValueRange steps, ValueRange iterArgs,
    function_ref<ValueVector(OpBuilder &, Location, ValueRange, ValueRange)>
        bodyBuilder) {
  // Builds a perfect nest of scf.for loops, one loop per (lb, ub, step)
  // triple, threading `iterArgs` through every nesting level. `bodyBuilder`
  // produces the innermost body and receives all induction variables at once;
  // it must return one value per iter arg (the values to yield).
  assert(lbs.size() == ubs.size() &&
         "expected the same number of lower and upper bounds" );
  assert(lbs.size() == steps.size() &&
         "expected the same number of lower bounds and steps" );

  // If there are no bounds, call the body-building function and return early.
  if (lbs.empty()) {
    ValueVector results =
        bodyBuilder ? bodyBuilder(builder, loc, ValueRange(), iterArgs)
                    : ValueVector();
    assert(results.size() == iterArgs.size() &&
           "loop nest body must return as many values as loop has iteration "
           "arguments" );
    return LoopNest{.loops: {}, .results: std::move(results)};
  }

  // First, create the loop structure iteratively using the body-builder
  // callback of `ForOp::build`. Do not create `YieldOp`s yet.
  OpBuilder::InsertionGuard guard(builder);
  SmallVector<scf::ForOp, 4> loops;
  SmallVector<Value, 4> ivs;
  loops.reserve(N: lbs.size());
  ivs.reserve(N: lbs.size());
  // Each created loop re-binds `currentIterArgs` to its region iter args so
  // the next (inner) loop is initialized from them.
  ValueRange currentIterArgs = iterArgs;
  Location currentLoc = loc;
  for (unsigned i = 0, e = lbs.size(); i < e; ++i) {
    auto loop = builder.create<scf::ForOp>(
        location: currentLoc, args: lbs[i], args: ubs[i], args: steps[i], args&: currentIterArgs,
        args: [&](OpBuilder &nestedBuilder, Location nestedLoc, Value iv,
            ValueRange args) {
          ivs.push_back(Elt: iv);
          // It is safe to store ValueRange args because it points to block
          // arguments of a loop operation that we also own.
          currentIterArgs = args;
          currentLoc = nestedLoc;
        });
    // Set the builder to point to the body of the newly created loop. We don't
    // do this in the callback because the builder is reset when the callback
    // returns.
    builder.setInsertionPointToStart(loop.getBody());
    loops.push_back(Elt: loop);
  }

  // For all loops but the innermost, yield the results of the nested loop.
  for (unsigned i = 0, e = loops.size() - 1; i < e; ++i) {
    builder.setInsertionPointToEnd(loops[i].getBody());
    builder.create<scf::YieldOp>(location: loc, args: loops[i + 1].getResults());
  }

  // In the body of the innermost loop, call the body building function if any
  // and yield its results.
  builder.setInsertionPointToStart(loops.back().getBody());
  ValueVector results = bodyBuilder
                            ? bodyBuilder(builder, currentLoc, ivs,
                                          loops.back().getRegionIterArgs())
                            : ValueVector();
  assert(results.size() == iterArgs.size() &&
         "loop nest body must return as many values as loop has iteration "
         "arguments" );
  builder.setInsertionPointToEnd(loops.back().getBody());
  builder.create<scf::YieldOp>(location: loc, args&: results);

  // Return the loops.
  ValueVector nestResults;
  llvm::append_range(C&: nestResults, R: loops.front().getResults());
  return LoopNest{.loops: std::move(loops), .results: std::move(nestResults)};
}
| 765 | |
| 766 | LoopNest mlir::scf::buildLoopNest( |
| 767 | OpBuilder &builder, Location loc, ValueRange lbs, ValueRange ubs, |
| 768 | ValueRange steps, |
| 769 | function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilder) { |
| 770 | // Delegate to the main function by wrapping the body builder. |
| 771 | return buildLoopNest(builder, loc, lbs, ubs, steps, iterArgs: {}, |
| 772 | bodyBuilder: [&bodyBuilder](OpBuilder &nestedBuilder, |
| 773 | Location nestedLoc, ValueRange ivs, |
| 774 | ValueRange) -> ValueVector { |
| 775 | if (bodyBuilder) |
| 776 | bodyBuilder(nestedBuilder, nestedLoc, ivs); |
| 777 | return {}; |
| 778 | }); |
| 779 | } |
| 780 | |
/// Replaces the iter operand `operand` of `forOp` by `replacement`, which has
/// a different type. A fresh scf.for is built and `castFn` is invoked three
/// times to keep types consistent: on the new region iter arg (back to the old
/// type for the stolen body), on the yielded value (to the new type), and on
/// the loop result (back to the old type for external users). Returns the new
/// loop's results with that final cast applied.
SmallVector<Value>
mlir::scf::replaceAndCastForOpIterArg(RewriterBase &rewriter, scf::ForOp forOp,
                                      OpOperand &operand, Value replacement,
                                      const ValueTypeCastFnTy &castFn) {
  assert(operand.getOwner() == forOp);
  Type oldType = operand.get().getType(), newType = replacement.getType();

  // 1. Create new iter operands, exactly 1 is replaced.
  assert(operand.getOperandNumber() >= forOp.getNumControlOperands() &&
         "expected an iter OpOperand" );
  assert(operand.get().getType() != replacement.getType() &&
         "Expected a different type" );
  SmallVector<Value> newIterOperands;
  for (OpOperand &opOperand : forOp.getInitArgsMutable()) {
    if (opOperand.getOperandNumber() == operand.getOperandNumber()) {
      newIterOperands.push_back(Elt: replacement);
      continue;
    }
    newIterOperands.push_back(Elt: opOperand.get());
  }

  // 2. Create the new forOp shell.
  scf::ForOp newForOp = rewriter.create<scf::ForOp>(
      location: forOp.getLoc(), args: forOp.getLowerBound(), args: forOp.getUpperBound(),
      args: forOp.getStep(), args&: newIterOperands);
  newForOp->setAttrs(forOp->getAttrs());
  Block &newBlock = newForOp.getRegion().front();
  SmallVector<Value, 4> newBlockTransferArgs(newBlock.getArguments().begin(),
                                             newBlock.getArguments().end());

  // 3. Inject an incoming cast op at the beginning of the block for the bbArg
  // corresponding to the `replacement` value.
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPointToStart(&newBlock);
  BlockArgument newRegionIterArg = newForOp.getTiedLoopRegionIterArg(
      opOperand: &newForOp->getOpOperand(idx: operand.getOperandNumber()));
  Value castIn = castFn(rewriter, newForOp.getLoc(), oldType, newRegionIterArg);
  newBlockTransferArgs[newRegionIterArg.getArgNumber()] = castIn;

  // 4. Steal the old block ops, mapping to the newBlockTransferArgs.
  Block &oldBlock = forOp.getRegion().front();
  rewriter.mergeBlocks(source: &oldBlock, dest: &newBlock, argValues: newBlockTransferArgs);

  // 5. Inject an outgoing cast op at the end of the block and yield it instead.
  auto clonedYieldOp = cast<scf::YieldOp>(Val: newBlock.getTerminator());
  rewriter.setInsertionPoint(clonedYieldOp);
  // Region args are [iv..., iter_args...]; subtract the induction vars to get
  // the yield operand index for this iter arg.
  unsigned yieldIdx =
      newRegionIterArg.getArgNumber() - forOp.getNumInductionVars();
  Value castOut = castFn(rewriter, newForOp.getLoc(), newType,
                         clonedYieldOp.getOperand(i: yieldIdx));
  SmallVector<Value> newYieldOperands = clonedYieldOp.getOperands();
  newYieldOperands[yieldIdx] = castOut;
  rewriter.create<scf::YieldOp>(location: newForOp.getLoc(), args&: newYieldOperands);
  rewriter.eraseOp(op: clonedYieldOp);

  // 6. Inject an outgoing cast op after the forOp.
  rewriter.setInsertionPointAfter(newForOp);
  SmallVector<Value> newResults = newForOp.getResults();
  newResults[yieldIdx] =
      castFn(rewriter, newForOp.getLoc(), oldType, newResults[yieldIdx]);

  return newResults;
}
| 844 | |
| 845 | namespace { |
// Fold away ForOp iter arguments when:
// 1) The op yields the iter arguments.
// 2) The argument's corresponding outer region iterators (inputs) are yielded.
// 3) The iter arguments have no use and the corresponding (operation) results
// have no use.
//
// These arguments must be defined outside of the ForOp region and can just be
// forwarded after simplifying the op inits, yields and returns.
//
// The implementation uses `inlineBlockBefore` to steal the content of the
// original ForOp and avoid cloning.
struct ForOpIterArgsFolder : public OpRewritePattern<scf::ForOp> {
  using OpRewritePattern<scf::ForOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(scf::ForOp forOp,
                                PatternRewriter &rewriter) const final {
    // Set to true as soon as at least one iter arg can be removed or
    // deduplicated; if it stays false the pattern does not apply.
    bool canonicalize = false;

    // An internal flat vector of block transfer
    // arguments `newBlockTransferArgs` keeps the 1-1 mapping of original to
    // transformed block argument mappings. This plays the role of a
    // IRMapping for the particular use case of calling into
    // `inlineBlockBefore`.
    int64_t numResults = forOp.getNumResults();
    SmallVector<bool, 4> keepMask;
    keepMask.reserve(N: numResults);
    SmallVector<Value, 4> newBlockTransferArgs, newIterArgs, newYieldValues,
        newResultValues;
    newBlockTransferArgs.reserve(N: 1 + numResults);
    newBlockTransferArgs.push_back(Elt: Value()); // iv placeholder with null value
    newIterArgs.reserve(N: forOp.getInitArgs().size());
    newYieldValues.reserve(N: numResults);
    newResultValues.reserve(N: numResults);
    // Maps an (init, yielded) pair to the (bbArg, result) of the first kept
    // iter arg with the same pair, so later duplicates can reuse it.
    DenseMap<std::pair<Value, Value>, std::pair<Value, Value>> initYieldToArg;
    for (auto [init, arg, result, yielded] :
         llvm::zip(t: forOp.getInitArgs(), // iter from outside
                   u: forOp.getRegionIterArgs(), // iter inside region
                   args: forOp.getResults(), // op results
                   args: forOp.getYieldedValues() // iter yield
                   )) {
      // Forwarded is `true` when:
      // 1) The region `iter` argument is yielded.
      // 2) The region `iter` argument the corresponding input is yielded.
      // 3) The region `iter` argument has no use, and the corresponding op
      // result has no use.
      bool forwarded = (arg == yielded) || (init == yielded) ||
                       (arg.use_empty() && result.use_empty());
      if (forwarded) {
        canonicalize = true;
        keepMask.push_back(Elt: false);
        newBlockTransferArgs.push_back(Elt: init);
        newResultValues.push_back(Elt: init);
        continue;
      }

      // Check if a previous kept argument always has the same values for init
      // and yielded values.
      if (auto it = initYieldToArg.find(Val: {init, yielded});
          it != initYieldToArg.end()) {
        canonicalize = true;
        keepMask.push_back(Elt: false);
        auto [sameArg, sameResult] = it->second;
        rewriter.replaceAllUsesWith(from: arg, to: sameArg);
        rewriter.replaceAllUsesWith(from: result, to: sameResult);
        // The replacement value doesn't matter because there are no uses.
        newBlockTransferArgs.push_back(Elt: init);
        newResultValues.push_back(Elt: init);
        continue;
      }

      // This value is kept.
      initYieldToArg.insert(KV: {{init, yielded}, {arg, result}});
      keepMask.push_back(Elt: true);
      newIterArgs.push_back(Elt: init);
      newYieldValues.push_back(Elt: yielded);
      newBlockTransferArgs.push_back(Elt: Value()); // placeholder with null value
      newResultValues.push_back(Elt: Value()); // placeholder with null value
    }

    if (!canonicalize)
      return failure();

    // Build a replacement loop that only carries the kept iter args.
    scf::ForOp newForOp = rewriter.create<scf::ForOp>(
        location: forOp.getLoc(), args: forOp.getLowerBound(), args: forOp.getUpperBound(),
        args: forOp.getStep(), args&: newIterArgs);
    newForOp->setAttrs(forOp->getAttrs());
    Block &newBlock = newForOp.getRegion().front();

    // Replace the null placeholders with newly constructed values.
    newBlockTransferArgs[0] = newBlock.getArgument(i: 0); // iv
    for (unsigned idx = 0, collapsedIdx = 0, e = newResultValues.size();
         idx != e; ++idx) {
      Value &blockTransferArg = newBlockTransferArgs[1 + idx];
      Value &newResultVal = newResultValues[idx];
      assert((blockTransferArg && newResultVal) ||
             (!blockTransferArg && !newResultVal));
      if (!blockTransferArg) {
        blockTransferArg = newForOp.getRegionIterArgs()[collapsedIdx];
        newResultVal = newForOp.getResult(i: collapsedIdx++);
      }
    }

    Block &oldBlock = forOp.getRegion().front();
    assert(oldBlock.getNumArguments() == newBlockTransferArgs.size() &&
           "unexpected argument size mismatch" );

    // No results case: the scf::ForOp builder already created a zero
    // result terminator. Merge before this terminator and just get rid of the
    // original terminator that has been merged in.
    if (newIterArgs.empty()) {
      auto newYieldOp = cast<scf::YieldOp>(Val: newBlock.getTerminator());
      rewriter.inlineBlockBefore(source: &oldBlock, op: newYieldOp, argValues: newBlockTransferArgs);
      // The inlined block's own yield now sits right before the builder-made
      // terminator; drop it.
      rewriter.eraseOp(op: newBlock.getTerminator()->getPrevNode());
      rewriter.replaceOp(op: forOp, newValues: newResultValues);
      return success();
    }

    // No terminator case: merge and rewrite the merged terminator.
    auto cloneFilteredTerminator = [&](scf::YieldOp mergedTerminator) {
      OpBuilder::InsertionGuard g(rewriter);
      rewriter.setInsertionPoint(mergedTerminator);
      SmallVector<Value, 4> filteredOperands;
      filteredOperands.reserve(N: newResultValues.size());
      // Keep only the yield operands for the iter args that survived.
      for (unsigned idx = 0, e = keepMask.size(); idx < e; ++idx)
        if (keepMask[idx])
          filteredOperands.push_back(Elt: mergedTerminator.getOperand(i: idx));
      rewriter.create<scf::YieldOp>(location: mergedTerminator.getLoc(),
                                    args&: filteredOperands);
    };

    rewriter.mergeBlocks(source: &oldBlock, dest: &newBlock, argValues: newBlockTransferArgs);
    auto mergedYieldOp = cast<scf::YieldOp>(Val: newBlock.getTerminator());
    cloneFilteredTerminator(mergedYieldOp);
    rewriter.eraseOp(op: mergedYieldOp);
    rewriter.replaceOp(op: forOp, newValues: newResultValues);
    return success();
  }
};
| 984 | |
| 985 | /// Util function that tries to compute a constant diff between u and l. |
| 986 | /// Returns std::nullopt when the difference between two AffineValueMap is |
| 987 | /// dynamic. |
| 988 | static std::optional<int64_t> computeConstDiff(Value l, Value u) { |
| 989 | IntegerAttr clb, cub; |
| 990 | if (matchPattern(value: l, pattern: m_Constant(bind_value: &clb)) && matchPattern(value: u, pattern: m_Constant(bind_value: &cub))) { |
| 991 | llvm::APInt lbValue = clb.getValue(); |
| 992 | llvm::APInt ubValue = cub.getValue(); |
| 993 | return (ubValue - lbValue).getSExtValue(); |
| 994 | } |
| 995 | |
| 996 | // Else a simple pattern match for x + c or c + x |
| 997 | llvm::APInt diff; |
| 998 | if (matchPattern( |
| 999 | value: u, pattern: m_Op<arith::AddIOp>(matchers: matchers::m_Val(v: l), matchers: m_ConstantInt(bind_value: &diff))) || |
| 1000 | matchPattern( |
| 1001 | value: u, pattern: m_Op<arith::AddIOp>(matchers: m_ConstantInt(bind_value: &diff), matchers: matchers::m_Val(v: l)))) |
| 1002 | return diff.getSExtValue(); |
| 1003 | return std::nullopt; |
| 1004 | } |
| 1005 | |
/// Rewriting pattern that erases loops that are known not to iterate, replaces
/// single-iteration loops with their bodies, and removes empty loops that
/// iterate at least once and only return values defined outside of the loop.
struct SimplifyTrivialLoops : public OpRewritePattern<ForOp> {
  using OpRewritePattern<ForOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ForOp op,
                                PatternRewriter &rewriter) const override {
    // If the upper bound is the same as the lower bound, the loop does not
    // iterate, just remove it.
    if (op.getLowerBound() == op.getUpperBound()) {
      rewriter.replaceOp(op, newValues: op.getInitArgs());
      return success();
    }

    // Try to compute a constant `ub - lb`; all remaining cases need it.
    std::optional<int64_t> diff =
        computeConstDiff(l: op.getLowerBound(), u: op.getUpperBound());
    if (!diff)
      return failure();

    // If the loop is known to have 0 iterations, remove it.
    if (*diff <= 0) {
      rewriter.replaceOp(op, newValues: op.getInitArgs());
      return success();
    }

    std::optional<llvm::APInt> maybeStepValue = op.getConstantStep();
    if (!maybeStepValue)
      return failure();

    // If the loop is known to have 1 iteration, inline its body and remove the
    // loop.
    llvm::APInt stepValue = *maybeStepValue;
    if (stepValue.sge(RHS: *diff)) {
      // The region entry block expects [iv, iter_args...]; the single
      // iteration sees iv == lower bound and iter args == init values.
      SmallVector<Value, 4> blockArgs;
      blockArgs.reserve(N: op.getInitArgs().size() + 1);
      blockArgs.push_back(Elt: op.getLowerBound());
      llvm::append_range(C&: blockArgs, R: op.getInitArgs());
      replaceOpWithRegion(rewriter, op, region&: op.getRegion(), blockArgs);
      return success();
    }

    // Now we are left with loops that have more than 1 iterations.
    // A body with a single operation contains only the terminator.
    Block &block = op.getRegion().front();
    if (!llvm::hasSingleElement(C&: block))
      return failure();
    // If the loop is empty, iterates at least once, and only returns values
    // defined outside of the loop, remove it and replace it with yield values.
    if (llvm::any_of(Range: op.getYieldedValues(),
                     P: [&](Value v) { return !op.isDefinedOutsideOfLoop(value: v); }))
      return failure();
    rewriter.replaceOp(op, newValues: op.getYieldedValues());
    return success();
  }
};
| 1061 | |
| 1062 | /// Fold scf.for iter_arg/result pairs that go through incoming/ougoing |
| 1063 | /// a tensor.cast op pair so as to pull the tensor.cast inside the scf.for: |
| 1064 | /// |
| 1065 | /// ``` |
| 1066 | /// %0 = tensor.cast %t0 : tensor<32x1024xf32> to tensor<?x?xf32> |
| 1067 | /// %1 = scf.for %i = %c0 to %c1024 step %c32 iter_args(%iter_t0 = %0) |
| 1068 | /// -> (tensor<?x?xf32>) { |
| 1069 | /// %2 = call @do(%iter_t0) : (tensor<?x?xf32>) -> tensor<?x?xf32> |
| 1070 | /// scf.yield %2 : tensor<?x?xf32> |
| 1071 | /// } |
| 1072 | /// use_of(%1) |
| 1073 | /// ``` |
| 1074 | /// |
| 1075 | /// folds into: |
| 1076 | /// |
| 1077 | /// ``` |
| 1078 | /// %0 = scf.for %arg2 = %c0 to %c1024 step %c32 iter_args(%arg3 = %arg0) |
| 1079 | /// -> (tensor<32x1024xf32>) { |
| 1080 | /// %2 = tensor.cast %arg3 : tensor<32x1024xf32> to tensor<?x?xf32> |
| 1081 | /// %3 = call @do(%2) : (tensor<?x?xf32>) -> tensor<?x?xf32> |
| 1082 | /// %4 = tensor.cast %3 : tensor<?x?xf32> to tensor<32x1024xf32> |
| 1083 | /// scf.yield %4 : tensor<32x1024xf32> |
| 1084 | /// } |
| 1085 | /// %1 = tensor.cast %0 : tensor<32x1024xf32> to tensor<?x?xf32> |
| 1086 | /// use_of(%1) |
| 1087 | /// ``` |
| 1088 | struct ForOpTensorCastFolder : public OpRewritePattern<ForOp> { |
| 1089 | using OpRewritePattern<ForOp>::OpRewritePattern; |
| 1090 | |
| 1091 | LogicalResult matchAndRewrite(ForOp op, |
| 1092 | PatternRewriter &rewriter) const override { |
| 1093 | for (auto it : llvm::zip(t: op.getInitArgsMutable(), u: op.getResults())) { |
| 1094 | OpOperand &iterOpOperand = std::get<0>(t&: it); |
| 1095 | auto incomingCast = iterOpOperand.get().getDefiningOp<tensor::CastOp>(); |
| 1096 | if (!incomingCast || |
| 1097 | incomingCast.getSource().getType() == incomingCast.getType()) |
| 1098 | continue; |
| 1099 | // If the dest type of the cast does not preserve static information in |
| 1100 | // the source type. |
| 1101 | if (!tensor::preservesStaticInformation( |
| 1102 | source: incomingCast.getDest().getType(), |
| 1103 | target: incomingCast.getSource().getType())) |
| 1104 | continue; |
| 1105 | if (!std::get<1>(t&: it).hasOneUse()) |
| 1106 | continue; |
| 1107 | |
| 1108 | // Create a new ForOp with that iter operand replaced. |
| 1109 | rewriter.replaceOp( |
| 1110 | op, newValues: replaceAndCastForOpIterArg( |
| 1111 | rewriter, forOp: op, operand&: iterOpOperand, replacement: incomingCast.getSource(), |
| 1112 | castFn: [](OpBuilder &b, Location loc, Type type, Value source) { |
| 1113 | return b.create<tensor::CastOp>(location: loc, args&: type, args&: source); |
| 1114 | })); |
| 1115 | return success(); |
| 1116 | } |
| 1117 | return failure(); |
| 1118 | } |
| 1119 | }; |
| 1120 | |
| 1121 | } // namespace |
| 1122 | |
| 1123 | void ForOp::getCanonicalizationPatterns(RewritePatternSet &results, |
| 1124 | MLIRContext *context) { |
| 1125 | results.add<ForOpIterArgsFolder, SimplifyTrivialLoops, ForOpTensorCastFolder>( |
| 1126 | arg&: context); |
| 1127 | } |
| 1128 | |
| 1129 | std::optional<APInt> ForOp::getConstantStep() { |
| 1130 | IntegerAttr step; |
| 1131 | if (matchPattern(value: getStep(), pattern: m_Constant(bind_value: &step))) |
| 1132 | return step.getValue(); |
| 1133 | return {}; |
| 1134 | } |
| 1135 | |
| 1136 | std::optional<MutableArrayRef<OpOperand>> ForOp::getYieldedValuesMutable() { |
| 1137 | return cast<scf::YieldOp>(Val: getBody()->getTerminator()).getResultsMutable(); |
| 1138 | } |
| 1139 | |
| 1140 | Speculation::Speculatability ForOp::getSpeculatability() { |
| 1141 | // `scf.for (I = Start; I < End; I += 1)` terminates for all values of Start |
| 1142 | // and End. |
| 1143 | if (auto constantStep = getConstantStep()) |
| 1144 | if (*constantStep == 1) |
| 1145 | return Speculation::RecursivelySpeculatable; |
| 1146 | |
| 1147 | // For Step != 1, the loop may not terminate. We can add more smarts here if |
| 1148 | // needed. |
| 1149 | return Speculation::NotSpeculatable; |
| 1150 | } |
| 1151 | |
| 1152 | //===----------------------------------------------------------------------===// |
| 1153 | // ForallOp |
| 1154 | //===----------------------------------------------------------------------===// |
| 1155 | |
/// Verifies scf.forall structural invariants: output/result counts, region
/// block-argument signature, the optional mapping attribute, and the mixed
/// static/dynamic bound lists.
LogicalResult ForallOp::verify() {
  unsigned numLoops = getRank();
  // Check number of outputs.
  if (getNumResults() != getOutputs().size())
    return emitOpError(message: "produces " )
           << getNumResults() << " results, but has only "
           << getOutputs().size() << " outputs" ;

  // Check that the body defines block arguments for thread indices and outputs.
  // Expected layout: [index args (one per loop), then one arg per output].
  auto *body = getBody();
  if (body->getNumArguments() != numLoops + getOutputs().size())
    return emitOpError(message: "region expects " ) << numLoops << " arguments" ;
  for (int64_t i = 0; i < numLoops; ++i)
    if (!body->getArgument(i).getType().isIndex())
      return emitOpError(message: "expects " )
             << i << "-th block argument to be an index" ;
  for (unsigned i = 0; i < getOutputs().size(); ++i)
    if (body->getArgument(i: i + numLoops).getType() != getOutputs()[i].getType())
      return emitOpError(message: "type mismatch between " )
             << i << "-th output and corresponding block argument" ;
  // When present, the mapping attribute must carry one device mapping
  // attribute per loop and at most one masking attribute.
  if (getMapping().has_value() && !getMapping()->empty()) {
    if (getDeviceMappingAttrs().size() != numLoops)
      return emitOpError() << "mapping attribute size must match op rank" ;
    if (failed(Result: getDeviceMaskingAttr()))
      return emitOpError() << getMappingAttrName()
                           << " supports at most one device masking attribute" ;
  }

  // Verify mixed static/dynamic control variables.
  Operation *op = getOperation();
  if (failed(Result: verifyListOfOperandsOrIntegers(op, name: "lower bound" , expectedNumElements: numLoops,
                                              attr: getStaticLowerBound(),
                                              values: getDynamicLowerBound())))
    return failure();
  if (failed(Result: verifyListOfOperandsOrIntegers(op, name: "upper bound" , expectedNumElements: numLoops,
                                              attr: getStaticUpperBound(),
                                              values: getDynamicUpperBound())))
    return failure();
  if (failed(Result: verifyListOfOperandsOrIntegers(op, name: "step" , expectedNumElements: numLoops,
                                              attr: getStaticStep(), values: getDynamicStep())))
    return failure();

  return success();
}
| 1200 | |
/// Prints scf.forall in either the compact normalized form
/// `(ivs) in (ubs)` or the general form `(ivs) = (lbs) to (ubs) step (steps)`,
/// followed by the optional `shared_outs` clause, result types, region, and
/// remaining attributes.
void ForallOp::print(OpAsmPrinter &p) {
  Operation *op = getOperation();
  p << " (" << getInductionVars();
  if (isNormalized()) {
    // Normalized loops elide the implicit zero lower bounds and unit steps.
    p << ") in " ;
    printDynamicIndexList(printer&: p, op, values: getDynamicUpperBound(), integers: getStaticUpperBound(),
                          /*valueTypes=*/scalableFlags: {}, /*scalables=*/valueTypes: {},
                          delimiter: OpAsmParser::Delimiter::Paren);
  } else {
    p << ") = " ;
    printDynamicIndexList(printer&: p, op, values: getDynamicLowerBound(), integers: getStaticLowerBound(),
                          /*valueTypes=*/scalableFlags: {}, /*scalables=*/valueTypes: {},
                          delimiter: OpAsmParser::Delimiter::Paren);
    p << " to " ;
    printDynamicIndexList(printer&: p, op, values: getDynamicUpperBound(), integers: getStaticUpperBound(),
                          /*valueTypes=*/scalableFlags: {}, /*scalables=*/valueTypes: {},
                          delimiter: OpAsmParser::Delimiter::Paren);
    p << " step " ;
    printDynamicIndexList(printer&: p, op, values: getDynamicStep(), integers: getStaticStep(),
                          /*valueTypes=*/scalableFlags: {}, /*scalables=*/valueTypes: {},
                          delimiter: OpAsmParser::Delimiter::Paren);
  }
  printInitializationList(p, blocksArgs: getRegionOutArgs(), initializers: getOutputs(), prefix: " shared_outs" );
  p << " " ;
  if (!getRegionOutArgs().empty())
    p << "-> (" << getResultTypes() << ") " ;
  // Block arguments were already printed alongside the bounds/outputs; only
  // print terminators when the op has results.
  p.printRegion(blocks&: getRegion(),
                /*printEntryBlockArgs=*/false,
                /*printBlockTerminators=*/getNumResults() > 0);
  // Elide the attributes that the custom syntax above already encodes.
  p.printOptionalAttrDict(attrs: op->getAttrs(), elidedAttrs: {getOperandSegmentSizesAttrName(),
                                           getStaticLowerBoundAttrName(),
                                           getStaticUpperBoundAttrName(),
                                           getStaticStepAttrName()});
}
| 1235 | |
| 1236 | ParseResult ForallOp::parse(OpAsmParser &parser, OperationState &result) { |
| 1237 | OpBuilder b(parser.getContext()); |
| 1238 | auto indexType = b.getIndexType(); |
| 1239 | |
| 1240 | // Parse an opening `(` followed by thread index variables followed by `)` |
| 1241 | // TODO: when we can refer to such "induction variable"-like handles from the |
| 1242 | // declarative assembly format, we can implement the parser as a custom hook. |
| 1243 | SmallVector<OpAsmParser::Argument, 4> ivs; |
| 1244 | if (parser.parseArgumentList(result&: ivs, delimiter: OpAsmParser::Delimiter::Paren)) |
| 1245 | return failure(); |
| 1246 | |
| 1247 | DenseI64ArrayAttr staticLbs, staticUbs, staticSteps; |
| 1248 | SmallVector<OpAsmParser::UnresolvedOperand> dynamicLbs, dynamicUbs, |
| 1249 | dynamicSteps; |
| 1250 | if (succeeded(Result: parser.parseOptionalKeyword(keyword: "in" ))) { |
| 1251 | // Parse upper bounds. |
| 1252 | if (parseDynamicIndexList(parser, values&: dynamicUbs, integers&: staticUbs, |
| 1253 | /*valueTypes=*/nullptr, |
| 1254 | delimiter: OpAsmParser::Delimiter::Paren) || |
| 1255 | parser.resolveOperands(operands&: dynamicUbs, type: indexType, result&: result.operands)) |
| 1256 | return failure(); |
| 1257 | |
| 1258 | unsigned numLoops = ivs.size(); |
| 1259 | staticLbs = b.getDenseI64ArrayAttr(values: SmallVector<int64_t>(numLoops, 0)); |
| 1260 | staticSteps = b.getDenseI64ArrayAttr(values: SmallVector<int64_t>(numLoops, 1)); |
| 1261 | } else { |
| 1262 | // Parse lower bounds. |
| 1263 | if (parser.parseEqual() || |
| 1264 | parseDynamicIndexList(parser, values&: dynamicLbs, integers&: staticLbs, |
| 1265 | /*valueTypes=*/nullptr, |
| 1266 | delimiter: OpAsmParser::Delimiter::Paren) || |
| 1267 | |
| 1268 | parser.resolveOperands(operands&: dynamicLbs, type: indexType, result&: result.operands)) |
| 1269 | return failure(); |
| 1270 | |
| 1271 | // Parse upper bounds. |
| 1272 | if (parser.parseKeyword(keyword: "to" ) || |
| 1273 | parseDynamicIndexList(parser, values&: dynamicUbs, integers&: staticUbs, |
| 1274 | /*valueTypes=*/nullptr, |
| 1275 | delimiter: OpAsmParser::Delimiter::Paren) || |
| 1276 | parser.resolveOperands(operands&: dynamicUbs, type: indexType, result&: result.operands)) |
| 1277 | return failure(); |
| 1278 | |
| 1279 | // Parse step values. |
| 1280 | if (parser.parseKeyword(keyword: "step" ) || |
| 1281 | parseDynamicIndexList(parser, values&: dynamicSteps, integers&: staticSteps, |
| 1282 | /*valueTypes=*/nullptr, |
| 1283 | delimiter: OpAsmParser::Delimiter::Paren) || |
| 1284 | parser.resolveOperands(operands&: dynamicSteps, type: indexType, result&: result.operands)) |
| 1285 | return failure(); |
| 1286 | } |
| 1287 | |
| 1288 | // Parse out operands and results. |
| 1289 | SmallVector<OpAsmParser::Argument, 4> regionOutArgs; |
| 1290 | SmallVector<OpAsmParser::UnresolvedOperand, 4> outOperands; |
| 1291 | SMLoc outOperandsLoc = parser.getCurrentLocation(); |
| 1292 | if (succeeded(Result: parser.parseOptionalKeyword(keyword: "shared_outs" ))) { |
| 1293 | if (outOperands.size() != result.types.size()) |
| 1294 | return parser.emitError(loc: outOperandsLoc, |
| 1295 | message: "mismatch between out operands and types" ); |
| 1296 | if (parser.parseAssignmentList(lhs&: regionOutArgs, rhs&: outOperands) || |
| 1297 | parser.parseOptionalArrowTypeList(result&: result.types) || |
| 1298 | parser.resolveOperands(operands&: outOperands, types&: result.types, loc: outOperandsLoc, |
| 1299 | result&: result.operands)) |
| 1300 | return failure(); |
| 1301 | } |
| 1302 | |
| 1303 | // Parse region. |
| 1304 | SmallVector<OpAsmParser::Argument, 4> regionArgs; |
| 1305 | std::unique_ptr<Region> region = std::make_unique<Region>(); |
| 1306 | for (auto &iv : ivs) { |
| 1307 | iv.type = b.getIndexType(); |
| 1308 | regionArgs.push_back(Elt: iv); |
| 1309 | } |
| 1310 | for (const auto &it : llvm::enumerate(First&: regionOutArgs)) { |
| 1311 | auto &out = it.value(); |
| 1312 | out.type = result.types[it.index()]; |
| 1313 | regionArgs.push_back(Elt: out); |
| 1314 | } |
| 1315 | if (parser.parseRegion(region&: *region, arguments: regionArgs)) |
| 1316 | return failure(); |
| 1317 | |
| 1318 | // Ensure terminator and move region. |
| 1319 | ForallOp::ensureTerminator(region&: *region, builder&: b, loc: result.location); |
| 1320 | result.addRegion(region: std::move(region)); |
| 1321 | |
| 1322 | // Parse the optional attribute list. |
| 1323 | if (parser.parseOptionalAttrDict(result&: result.attributes)) |
| 1324 | return failure(); |
| 1325 | |
| 1326 | result.addAttribute(name: "staticLowerBound" , attr: staticLbs); |
| 1327 | result.addAttribute(name: "staticUpperBound" , attr: staticUbs); |
| 1328 | result.addAttribute(name: "staticStep" , attr: staticSteps); |
| 1329 | result.addAttribute(name: "operandSegmentSizes" , |
| 1330 | attr: parser.getBuilder().getDenseI32ArrayAttr( |
| 1331 | values: {static_cast<int32_t>(dynamicLbs.size()), |
| 1332 | static_cast<int32_t>(dynamicUbs.size()), |
| 1333 | static_cast<int32_t>(dynamicSteps.size()), |
| 1334 | static_cast<int32_t>(outOperands.size())})); |
| 1335 | return success(); |
| 1336 | } |
| 1337 | |
// Builder that takes loop bounds.
//
// Decomposes the OpFoldResult bound/step lists into dynamic SSA operands and
// static integer attributes, registers operands, result types, and attributes
// on `result`, and materializes the body block with one index-typed argument
// per loop dimension followed by one argument per shared output. If
// `bodyBuilderFn` is null, the implicit scf.in_parallel terminator is
// created; otherwise the callback is responsible for creating it.
void ForallOp::build(
    mlir::OpBuilder &b, mlir::OperationState &result,
    ArrayRef<OpFoldResult> lbs, ArrayRef<OpFoldResult> ubs,
    ArrayRef<OpFoldResult> steps, ValueRange outputs,
    std::optional<ArrayAttr> mapping,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn) {
  // Partition each mixed list into static values and dynamic values.
  SmallVector<int64_t> staticLbs, staticUbs, staticSteps;
  SmallVector<Value> dynamicLbs, dynamicUbs, dynamicSteps;
  dispatchIndexOpFoldResults(ofrs: lbs, dynamicVec&: dynamicLbs, staticVec&: staticLbs);
  dispatchIndexOpFoldResults(ofrs: ubs, dynamicVec&: dynamicUbs, staticVec&: staticUbs);
  dispatchIndexOpFoldResults(ofrs: steps, dynamicVec&: dynamicSteps, staticVec&: staticSteps);

  // Operand order must match the operandSegmentSizes attribute set below.
  result.addOperands(newOperands: dynamicLbs);
  result.addOperands(newOperands: dynamicUbs);
  result.addOperands(newOperands: dynamicSteps);
  result.addOperands(newOperands: outputs);
  result.addTypes(newTypes: TypeRange(outputs));

  result.addAttribute(name: getStaticLowerBoundAttrName(name: result.name),
                      attr: b.getDenseI64ArrayAttr(values: staticLbs));
  result.addAttribute(name: getStaticUpperBoundAttrName(name: result.name),
                      attr: b.getDenseI64ArrayAttr(values: staticUbs));
  result.addAttribute(name: getStaticStepAttrName(name: result.name),
                      attr: b.getDenseI64ArrayAttr(values: staticSteps));
  result.addAttribute(
      name: "operandSegmentSizes",
      attr: b.getDenseI32ArrayAttr(values: {static_cast<int32_t>(dynamicLbs.size()),
                               static_cast<int32_t>(dynamicUbs.size()),
                               static_cast<int32_t>(dynamicSteps.size()),
                               static_cast<int32_t>(outputs.size())}));
  if (mapping.has_value()) {
    result.addAttribute(name: ForallOp::getMappingAttrName(name: result.name),
                        attr: mapping.value());
  }

  Region *bodyRegion = result.addRegion();
  OpBuilder::InsertionGuard g(b);
  b.createBlock(parent: bodyRegion);
  Block &bodyBlock = bodyRegion->front();

  // Add block arguments for indices and outputs.
  bodyBlock.addArguments(
      types: SmallVector<Type>(lbs.size(), b.getIndexType()),
      locs: SmallVector<Location>(staticLbs.size(), result.location));
  bodyBlock.addArguments(
      types: TypeRange(outputs),
      locs: SmallVector<Location>(outputs.size(), result.location));

  b.setInsertionPointToStart(&bodyBlock);
  if (!bodyBuilderFn) {
    // No body callback: ensure the implicit terminator exists.
    ForallOp::ensureTerminator(region&: *bodyRegion, builder&: b, loc: result.location);
    return;
  }
  // The callback must create the terminator itself.
  bodyBuilderFn(b, result.location, bodyBlock.getArguments());
}
| 1394 | |
| 1395 | // Builder that takes loop bounds. |
| 1396 | void ForallOp::build( |
| 1397 | mlir::OpBuilder &b, mlir::OperationState &result, |
| 1398 | ArrayRef<OpFoldResult> ubs, ValueRange outputs, |
| 1399 | std::optional<ArrayAttr> mapping, |
| 1400 | function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn) { |
| 1401 | unsigned numLoops = ubs.size(); |
| 1402 | SmallVector<OpFoldResult> lbs(numLoops, b.getIndexAttr(value: 0)); |
| 1403 | SmallVector<OpFoldResult> steps(numLoops, b.getIndexAttr(value: 1)); |
| 1404 | build(b, result, lbs, ubs, steps, outputs, mapping, bodyBuilderFn); |
| 1405 | } |
| 1406 | |
| 1407 | // Checks if the lbs are zeros and steps are ones. |
| 1408 | bool ForallOp::isNormalized() { |
| 1409 | auto allEqual = [](ArrayRef<OpFoldResult> results, int64_t val) { |
| 1410 | return llvm::all_of(Range&: results, P: [&](OpFoldResult ofr) { |
| 1411 | auto intValue = getConstantIntValue(ofr); |
| 1412 | return intValue.has_value() && intValue == val; |
| 1413 | }); |
| 1414 | }; |
| 1415 | return allEqual(getMixedLowerBound(), 0) && allEqual(getMixedStep(), 1); |
| 1416 | } |
| 1417 | |
| 1418 | InParallelOp ForallOp::getTerminator() { |
| 1419 | return cast<InParallelOp>(Val: getBody()->getTerminator()); |
| 1420 | } |
| 1421 | |
| 1422 | SmallVector<Operation *> ForallOp::getCombiningOps(BlockArgument bbArg) { |
| 1423 | SmallVector<Operation *> storeOps; |
| 1424 | InParallelOp inParallelOp = getTerminator(); |
| 1425 | for (Operation &yieldOp : inParallelOp.getYieldingOps()) { |
| 1426 | if (auto parallelInsertSliceOp = |
| 1427 | dyn_cast<tensor::ParallelInsertSliceOp>(Val&: yieldOp); |
| 1428 | parallelInsertSliceOp && parallelInsertSliceOp.getDest() == bbArg) { |
| 1429 | storeOps.push_back(Elt: parallelInsertSliceOp); |
| 1430 | } |
| 1431 | } |
| 1432 | return storeOps; |
| 1433 | } |
| 1434 | |
| 1435 | SmallVector<DeviceMappingAttrInterface> ForallOp::getDeviceMappingAttrs() { |
| 1436 | SmallVector<DeviceMappingAttrInterface> res; |
| 1437 | if (!getMapping()) |
| 1438 | return res; |
| 1439 | for (auto attr : getMapping()->getValue()) { |
| 1440 | auto m = dyn_cast<DeviceMappingAttrInterface>(Val&: attr); |
| 1441 | if (m) |
| 1442 | res.push_back(Elt: m); |
| 1443 | } |
| 1444 | return res; |
| 1445 | } |
| 1446 | |
| 1447 | FailureOr<DeviceMaskingAttrInterface> ForallOp::getDeviceMaskingAttr() { |
| 1448 | DeviceMaskingAttrInterface res; |
| 1449 | if (!getMapping()) |
| 1450 | return res; |
| 1451 | for (auto attr : getMapping()->getValue()) { |
| 1452 | auto m = dyn_cast<DeviceMaskingAttrInterface>(Val&: attr); |
| 1453 | if (m && res) |
| 1454 | return failure(); |
| 1455 | if (m) |
| 1456 | res = m; |
| 1457 | } |
| 1458 | return res; |
| 1459 | } |
| 1460 | |
| 1461 | bool ForallOp::usesLinearMapping() { |
| 1462 | SmallVector<DeviceMappingAttrInterface> ifaces = getDeviceMappingAttrs(); |
| 1463 | if (ifaces.empty()) |
| 1464 | return false; |
| 1465 | return ifaces.front().isLinearMapping(); |
| 1466 | } |
| 1467 | |
| 1468 | std::optional<SmallVector<Value>> ForallOp::getLoopInductionVars() { |
| 1469 | return SmallVector<Value>{getBody()->getArguments().take_front(N: getRank())}; |
| 1470 | } |
| 1471 | |
| 1472 | // Get lower bounds as OpFoldResult. |
| 1473 | std::optional<SmallVector<OpFoldResult>> ForallOp::getLoopLowerBounds() { |
| 1474 | Builder b(getOperation()->getContext()); |
| 1475 | return getMixedValues(staticValues: getStaticLowerBound(), dynamicValues: getDynamicLowerBound(), b); |
| 1476 | } |
| 1477 | |
| 1478 | // Get upper bounds as OpFoldResult. |
| 1479 | std::optional<SmallVector<OpFoldResult>> ForallOp::getLoopUpperBounds() { |
| 1480 | Builder b(getOperation()->getContext()); |
| 1481 | return getMixedValues(staticValues: getStaticUpperBound(), dynamicValues: getDynamicUpperBound(), b); |
| 1482 | } |
| 1483 | |
| 1484 | // Get steps as OpFoldResult. |
| 1485 | std::optional<SmallVector<OpFoldResult>> ForallOp::getLoopSteps() { |
| 1486 | Builder b(getOperation()->getContext()); |
| 1487 | return getMixedValues(staticValues: getStaticStep(), dynamicValues: getDynamicStep(), b); |
| 1488 | } |
| 1489 | |
| 1490 | ForallOp mlir::scf::getForallOpThreadIndexOwner(Value val) { |
| 1491 | auto tidxArg = llvm::dyn_cast<BlockArgument>(Val&: val); |
| 1492 | if (!tidxArg) |
| 1493 | return ForallOp(); |
| 1494 | assert(tidxArg.getOwner() && "unlinked block argument" ); |
| 1495 | auto *containingOp = tidxArg.getOwner()->getParentOp(); |
| 1496 | return dyn_cast<ForallOp>(Val: containingOp); |
| 1497 | } |
| 1498 | |
| 1499 | namespace { |
| 1500 | /// Fold tensor.dim(forall shared_outs(... = %t)) to tensor.dim(%t). |
| 1501 | struct DimOfForallOp : public OpRewritePattern<tensor::DimOp> { |
| 1502 | using OpRewritePattern<tensor::DimOp>::OpRewritePattern; |
| 1503 | |
| 1504 | LogicalResult matchAndRewrite(tensor::DimOp dimOp, |
| 1505 | PatternRewriter &rewriter) const final { |
| 1506 | auto forallOp = dimOp.getSource().getDefiningOp<ForallOp>(); |
| 1507 | if (!forallOp) |
| 1508 | return failure(); |
| 1509 | Value sharedOut = |
| 1510 | forallOp.getTiedOpOperand(opResult: llvm::cast<OpResult>(Val: dimOp.getSource())) |
| 1511 | ->get(); |
| 1512 | rewriter.modifyOpInPlace( |
| 1513 | root: dimOp, callable: [&]() { dimOp.getSourceMutable().assign(value: sharedOut); }); |
| 1514 | return success(); |
| 1515 | } |
| 1516 | }; |
| 1517 | |
/// Fold dynamic lower/upper bounds and steps of an scf.forall to constants
/// where possible, rewriting both the dynamic operand lists and the static
/// bound attributes in place.
class ForallOpControlOperandsFolder : public OpRewritePattern<ForallOp> {
public:
  using OpRewritePattern<ForallOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ForallOp op,
                                PatternRewriter &rewriter) const override {
    SmallVector<OpFoldResult> mixedLowerBound(op.getMixedLowerBound());
    SmallVector<OpFoldResult> mixedUpperBound(op.getMixedUpperBound());
    SmallVector<OpFoldResult> mixedStep(op.getMixedStep());
    // Bail out only when none of the three lists had a foldable entry.
    if (failed(Result: foldDynamicIndexList(ofrs&: mixedLowerBound)) &&
        failed(Result: foldDynamicIndexList(ofrs&: mixedUpperBound)) &&
        failed(Result: foldDynamicIndexList(ofrs&: mixedStep)))
      return failure();

    rewriter.modifyOpInPlace(root: op, callable: [&]() {
      // Re-split each (now partially folded) mixed list into dynamic operands
      // and static attribute values, and write both back to the op.
      SmallVector<Value> dynamicLowerBound, dynamicUpperBound, dynamicStep;
      SmallVector<int64_t> staticLowerBound, staticUpperBound, staticStep;
      dispatchIndexOpFoldResults(ofrs: mixedLowerBound, dynamicVec&: dynamicLowerBound,
                                 staticVec&: staticLowerBound);
      op.getDynamicLowerBoundMutable().assign(values: dynamicLowerBound);
      op.setStaticLowerBound(staticLowerBound);

      dispatchIndexOpFoldResults(ofrs: mixedUpperBound, dynamicVec&: dynamicUpperBound,
                                 staticVec&: staticUpperBound);
      op.getDynamicUpperBoundMutable().assign(values: dynamicUpperBound);
      op.setStaticUpperBound(staticUpperBound);

      dispatchIndexOpFoldResults(ofrs: mixedStep, dynamicVec&: dynamicStep, staticVec&: staticStep);
      op.getDynamicStepMutable().assign(values: dynamicStep);
      op.setStaticStep(staticStep);

      // The number of dynamic operands changed; refresh the segment sizes so
      // the variadic operand groups stay consistent.
      op->setAttr(name: ForallOp::getOperandSegmentSizeAttr(),
                  value: rewriter.getDenseI32ArrayAttr(
                      values: {static_cast<int32_t>(dynamicLowerBound.size()),
                       static_cast<int32_t>(dynamicUpperBound.size()),
                       static_cast<int32_t>(dynamicStep.size()),
                       static_cast<int32_t>(op.getNumResults())}));
    });
    return success();
  }
};
| 1559 | |
| 1560 | /// The following canonicalization pattern folds the iter arguments of |
| 1561 | /// scf.forall op if :- |
| 1562 | /// 1. The corresponding result has zero uses. |
///    2. The iter argument is NOT being modified within the loop body.
| 1565 | /// |
| 1566 | /// Example of first case :- |
| 1567 | /// INPUT: |
| 1568 | /// %res:3 = scf.forall ... shared_outs(%arg0 = %a, %arg1 = %b, %arg2 = %c) |
| 1569 | /// { |
| 1570 | /// ... |
| 1571 | /// <SOME USE OF %arg0> |
| 1572 | /// <SOME USE OF %arg1> |
| 1573 | /// <SOME USE OF %arg2> |
| 1574 | /// ... |
| 1575 | /// scf.forall.in_parallel { |
| 1576 | /// <STORE OP WITH DESTINATION %arg1> |
| 1577 | /// <STORE OP WITH DESTINATION %arg0> |
| 1578 | /// <STORE OP WITH DESTINATION %arg2> |
| 1579 | /// } |
| 1580 | /// } |
| 1581 | /// return %res#1 |
| 1582 | /// |
| 1583 | /// OUTPUT: |
///   %res = scf.forall ... shared_outs(%new_arg0 = %b)
| 1585 | /// { |
| 1586 | /// ... |
| 1587 | /// <SOME USE OF %a> |
| 1588 | /// <SOME USE OF %new_arg0> |
| 1589 | /// <SOME USE OF %c> |
| 1590 | /// ... |
| 1591 | /// scf.forall.in_parallel { |
| 1592 | /// <STORE OP WITH DESTINATION %new_arg0> |
| 1593 | /// } |
| 1594 | /// } |
| 1595 | /// return %res |
| 1596 | /// |
| 1597 | /// NOTE: 1. All uses of the folded shared_outs (iter argument) within the |
| 1598 | /// scf.forall is replaced by their corresponding operands. |
| 1599 | /// 2. Even if there are <STORE OP WITH DESTINATION *> ops within the body |
| 1600 | /// of the scf.forall besides within scf.forall.in_parallel terminator, |
| 1601 | /// this canonicalization remains valid. For more details, please refer |
| 1602 | /// to : |
| 1603 | /// https://github.com/llvm/llvm-project/pull/90189#discussion_r1589011124 |
| 1604 | /// 3. TODO(avarma): Generalize it for other store ops. Currently it |
| 1605 | /// handles tensor.parallel_insert_slice ops only. |
| 1606 | /// |
| 1607 | /// Example of second case :- |
| 1608 | /// INPUT: |
| 1609 | /// %res:2 = scf.forall ... shared_outs(%arg0 = %a, %arg1 = %b) |
| 1610 | /// { |
| 1611 | /// ... |
| 1612 | /// <SOME USE OF %arg0> |
| 1613 | /// <SOME USE OF %arg1> |
| 1614 | /// ... |
| 1615 | /// scf.forall.in_parallel { |
| 1616 | /// <STORE OP WITH DESTINATION %arg1> |
| 1617 | /// } |
| 1618 | /// } |
| 1619 | /// return %res#0, %res#1 |
| 1620 | /// |
| 1621 | /// OUTPUT: |
| 1622 | /// %res = scf.forall ... shared_outs(%new_arg0 = %b) |
| 1623 | /// { |
| 1624 | /// ... |
| 1625 | /// <SOME USE OF %a> |
| 1626 | /// <SOME USE OF %new_arg0> |
| 1627 | /// ... |
| 1628 | /// scf.forall.in_parallel { |
| 1629 | /// <STORE OP WITH DESTINATION %new_arg0> |
| 1630 | /// } |
| 1631 | /// } |
| 1632 | /// return %a, %res |
struct ForallOpIterArgsFolder : public OpRewritePattern<ForallOp> {
  using OpRewritePattern<ForallOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ForallOp forallOp,
                                PatternRewriter &rewriter) const final {
    // Step 1: For a given i-th result of scf.forall, check the following :-
    //   a. If it has any use.
    //   b. If the corresponding iter argument is being modified within
    //      the loop, i.e. has at least one store op with the iter arg as
    //      its destination operand. For this we use
    //      ForallOp::getCombiningOps(iter_arg).
    //
    // Based on the check we maintain the following :-
    //   a. `resultToDelete` - i-th result of scf.forall that'll be
    //      deleted.
    //   b. `resultToReplace` - i-th result of the old scf.forall
    //      whose uses will be replaced by the new scf.forall.
    //   c. `newOuts` - the shared_outs' operand of the new scf.forall
    //      corresponding to the i-th result with at least one use.
    SetVector<OpResult> resultToDelete;
    SmallVector<Value> resultToReplace;
    SmallVector<Value> newOuts;
    for (OpResult result : forallOp.getResults()) {
      OpOperand *opOperand = forallOp.getTiedOpOperand(opResult: result);
      BlockArgument blockArg = forallOp.getTiedBlockArgument(opOperand);
      if (result.use_empty() || forallOp.getCombiningOps(bbArg: blockArg).empty()) {
        resultToDelete.insert(X: result);
      } else {
        resultToReplace.push_back(Elt: result);
        newOuts.push_back(Elt: opOperand->get());
      }
    }

    // Return early if all results of scf.forall have at least one use and
    // are being modified within the loop.
    if (resultToDelete.empty())
      return failure();

    // Step 2: For the i-th result to delete, do the following :-
    //   a. Fetch the corresponding BlockArgument.
    //   b. Look for store ops (currently tensor.parallel_insert_slice)
    //      with the BlockArgument as its destination operand.
    //   c. Remove the operations fetched in b.
    for (OpResult result : resultToDelete) {
      OpOperand *opOperand = forallOp.getTiedOpOperand(opResult: result);
      BlockArgument blockArg = forallOp.getTiedBlockArgument(opOperand);
      SmallVector<Operation *> combiningOps =
          forallOp.getCombiningOps(bbArg: blockArg);
      for (Operation *combiningOp : combiningOps)
        rewriter.eraseOp(op: combiningOp);
    }

    // Step 3. Create a new scf.forall op with the new shared_outs' operands
    //         fetched earlier.
    auto newForallOp = rewriter.create<scf::ForallOp>(
        location: forallOp.getLoc(), args: forallOp.getMixedLowerBound(),
        args: forallOp.getMixedUpperBound(), args: forallOp.getMixedStep(), args&: newOuts,
        args: forallOp.getMapping(),
        /*bodyBuilderFn =*/args: [](OpBuilder &, Location, ValueRange) {});

    // Step 4. Merge the block of the old scf.forall into the newly created
    //         scf.forall using the new set of arguments.
    Block *loopBody = forallOp.getBody();
    Block *newLoopBody = newForallOp.getBody();
    ArrayRef<BlockArgument> newBbArgs = newLoopBody->getArguments();
    // Form initial new bbArg list with just the control operands of the new
    // scf.forall op.
    SmallVector<Value> newBlockArgs =
        llvm::map_to_vector(C: newBbArgs.take_front(N: forallOp.getRank()),
                            F: [](BlockArgument b) -> Value { return b; });
    Block::BlockArgListType newSharedOutsArgs = newForallOp.getRegionOutArgs();
    unsigned index = 0;
    // Take the new corresponding bbArg if the old bbArg was used as a
    // destination in the in_parallel op. For all other bbArgs, use the
    // corresponding init_arg from the old scf.forall op.
    for (OpResult result : forallOp.getResults()) {
      if (resultToDelete.count(key: result)) {
        newBlockArgs.push_back(Elt: forallOp.getTiedOpOperand(opResult: result)->get());
      } else {
        newBlockArgs.push_back(Elt: newSharedOutsArgs[index++]);
      }
    }
    rewriter.mergeBlocks(source: loopBody, dest: newLoopBody, argValues: newBlockArgs);

    // Step 5. Replace the uses of result of old scf.forall with that of the
    //         new scf.forall.
    for (auto &&[oldResult, newResult] :
         llvm::zip(t&: resultToReplace, u: newForallOp->getResults()))
      rewriter.replaceAllUsesWith(from: oldResult, to: newResult);

    // Step 6. Replace the uses of those values that either have no use or
    //         are not being modified within the loop with the corresponding
    //         OpOperand.
    for (OpResult oldResult : resultToDelete)
      rewriter.replaceAllUsesWith(from: oldResult,
                                  to: forallOp.getTiedOpOperand(opResult: oldResult)->get());
    return success();
  }
};
| 1732 | |
/// Canonicalize scf.forall dimensions that perform zero or one iteration:
/// a zero-trip loop is replaced by its shared outputs, single-iteration
/// dimensions are dropped (their induction variable is mapped to the lower
/// bound), and a loop whose dimensions are all single-iteration is promoted,
/// i.e. its body is inlined.
struct ForallOpSingleOrZeroIterationDimsFolder
    : public OpRewritePattern<ForallOp> {
  using OpRewritePattern<ForallOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ForallOp op,
                                PatternRewriter &rewriter) const override {
    // Do not fold dimensions if they are mapped to processing units.
    if (op.getMapping().has_value() && !op.getMapping()->empty())
      return failure();
    Location loc = op.getLoc();

    // Compute new loop bounds that omit all single-iteration loop dimensions.
    SmallVector<OpFoldResult> newMixedLowerBounds, newMixedUpperBounds,
        newMixedSteps;
    IRMapping mapping;
    for (auto [lb, ub, step, iv] :
         llvm::zip(t: op.getMixedLowerBound(), u: op.getMixedUpperBound(),
                   args: op.getMixedStep(), args: op.getInductionVars())) {
      auto numIterations = constantTripCount(lb, ub, step);
      if (numIterations.has_value()) {
        // Remove the loop if it performs zero iterations.
        if (*numIterations == 0) {
          rewriter.replaceOp(op, newValues: op.getOutputs());
          return success();
        }
        // Replace the loop induction variable by the lower bound if the loop
        // performs a single iteration. Otherwise, copy the loop bounds.
        if (*numIterations == 1) {
          mapping.map(from: iv, to: getValueOrCreateConstantIndexOp(b&: rewriter, loc, ofr: lb));
          continue;
        }
      }
      newMixedLowerBounds.push_back(Elt: lb);
      newMixedUpperBounds.push_back(Elt: ub);
      newMixedSteps.push_back(Elt: step);
    }

    // All of the loop dimensions perform a single iteration. Inline loop body.
    if (newMixedLowerBounds.empty()) {
      promote(rewriter, forallOp: op);
      return success();
    }

    // Exit if none of the loop dimensions perform a single iteration.
    if (newMixedLowerBounds.size() == static_cast<unsigned>(op.getRank())) {
      return rewriter.notifyMatchFailure(
          arg&: op, msg: "no dimensions have 0 or 1 iterations");
    }

    // Replace the loop by a lower-dimensional loop.
    ForallOp newOp;
    newOp = rewriter.create<ForallOp>(location: loc, args&: newMixedLowerBounds,
                                      args&: newMixedUpperBounds, args&: newMixedSteps,
                                      args: op.getOutputs(), args: std::nullopt, args: nullptr);
    // Drop the auto-created body; the old region is cloned in below.
    newOp.getBodyRegion().getBlocks().clear();
    // The new loop needs to keep all attributes from the old one, except for
    // "operandSegmentSizes" and static loop bound attributes which capture
    // the outdated information of the old iteration domain.
    SmallVector<StringAttr> elidedAttrs{newOp.getOperandSegmentSizesAttrName(),
                                        newOp.getStaticLowerBoundAttrName(),
                                        newOp.getStaticUpperBoundAttrName(),
                                        newOp.getStaticStepAttrName()};
    for (const auto &namedAttr : op->getAttrs()) {
      if (llvm::is_contained(Range&: elidedAttrs, Element: namedAttr.getName()))
        continue;
      rewriter.modifyOpInPlace(root: newOp, callable: [&]() {
        newOp->setAttr(name: namedAttr.getName(), value: namedAttr.getValue());
      });
    }
    // Clone the old body (with folded ivs remapped) into the new loop.
    rewriter.cloneRegionBefore(region&: op.getRegion(), parent&: newOp.getRegion(),
                               before: newOp.getRegion().begin(), mapping);
    rewriter.replaceOp(op, newValues: newOp.getResults());
    return success();
  }
};
| 1808 | |
| 1809 | /// Replace all induction vars with a single trip count with their lower bound. |
| 1810 | struct ForallOpReplaceConstantInductionVar : public OpRewritePattern<ForallOp> { |
| 1811 | using OpRewritePattern<ForallOp>::OpRewritePattern; |
| 1812 | |
| 1813 | LogicalResult matchAndRewrite(ForallOp op, |
| 1814 | PatternRewriter &rewriter) const override { |
| 1815 | Location loc = op.getLoc(); |
| 1816 | bool changed = false; |
| 1817 | for (auto [lb, ub, step, iv] : |
| 1818 | llvm::zip(t: op.getMixedLowerBound(), u: op.getMixedUpperBound(), |
| 1819 | args: op.getMixedStep(), args: op.getInductionVars())) { |
| 1820 | if (iv.hasNUses(n: 0)) |
| 1821 | continue; |
| 1822 | auto numIterations = constantTripCount(lb, ub, step); |
| 1823 | if (!numIterations.has_value() || numIterations.value() != 1) { |
| 1824 | continue; |
| 1825 | } |
| 1826 | rewriter.replaceAllUsesWith( |
| 1827 | from: iv, to: getValueOrCreateConstantIndexOp(b&: rewriter, loc, ofr: lb)); |
| 1828 | changed = true; |
| 1829 | } |
| 1830 | return success(IsSuccess: changed); |
| 1831 | } |
| 1832 | }; |
| 1833 | |
/// Fold tensor.cast ops feeding the shared outputs of an scf.forall into the
/// loop: the loop then iterates on the (more static) cast source type, casts
/// are re-introduced inside the body for the output block arguments, and the
/// results are cast back after the loop so external users keep their types.
struct FoldTensorCastOfOutputIntoForallOp
    : public OpRewritePattern<scf::ForallOp> {
  using OpRewritePattern<scf::ForallOp>::OpRewritePattern;

  // Source/destination type pair of a folded tensor.cast.
  struct TypeCast {
    Type srcType;
    Type dstType;
  };

  LogicalResult matchAndRewrite(scf::ForallOp forallOp,
                                PatternRewriter &rewriter) const final {
    // Maps output index -> the cast folded away at that position.
    llvm::SmallMapVector<unsigned, TypeCast, 2> tensorCastProducers;
    llvm::SmallVector<Value> newOutputTensors = forallOp.getOutputs();
    for (auto en : llvm::enumerate(First&: newOutputTensors)) {
      auto castOp = en.value().getDefiningOp<tensor::CastOp>();
      if (!castOp)
        continue;

      // Only casts that preserve static information, i.e. will make the
      // loop result type "more" static than before, will be folded.
      if (!tensor::preservesStaticInformation(source: castOp.getDest().getType(),
                                              target: castOp.getSource().getType())) {
        continue;
      }

      tensorCastProducers[en.index()] =
          TypeCast{.srcType: castOp.getSource().getType(), .dstType: castOp.getType()};
      newOutputTensors[en.index()] = castOp.getSource();
    }

    if (tensorCastProducers.empty())
      return failure();

    // Create new loop.
    Location loc = forallOp.getLoc();
    auto newForallOp = rewriter.create<ForallOp>(
        location: loc, args: forallOp.getMixedLowerBound(), args: forallOp.getMixedUpperBound(),
        args: forallOp.getMixedStep(), args&: newOutputTensors, args: forallOp.getMapping(),
        args: [&](OpBuilder nestedBuilder, Location nestedLoc, ValueRange bbArgs) {
          // Cast the new (source-typed) output bbArgs back to the old types
          // so the moved body sees the types it was built with.
          auto castBlockArgs =
              llvm::to_vector(Range: bbArgs.take_back(n: forallOp->getNumResults()));
          for (auto [index, cast] : tensorCastProducers) {
            Value &oldTypeBBArg = castBlockArgs[index];
            oldTypeBBArg = nestedBuilder.create<tensor::CastOp>(
                location: nestedLoc, args&: cast.dstType, args&: oldTypeBBArg);
          }

          // Move old body into new parallel loop.
          SmallVector<Value> ivsBlockArgs =
              llvm::to_vector(Range: bbArgs.take_front(n: forallOp.getRank()));
          ivsBlockArgs.append(RHS: castBlockArgs);
          rewriter.mergeBlocks(source: forallOp.getBody(),
                               dest: bbArgs.front().getParentBlock(), argValues: ivsBlockArgs);
        });

    // After `mergeBlocks` happened, the destinations in the terminator were
    // mapped to the tensor.cast old-typed results of the output bbArgs. The
    // destination have to be updated to point to the output bbArgs directly.
    auto terminator = newForallOp.getTerminator();
    for (auto [yieldingOp, outputBlockArg] : llvm::zip(
             t: terminator.getYieldingOps(), u: newForallOp.getRegionIterArgs())) {
      auto insertSliceOp = cast<tensor::ParallelInsertSliceOp>(Val&: yieldingOp);
      insertSliceOp.getDestMutable().assign(value: outputBlockArg);
    }

    // Cast results back to the original types.
    rewriter.setInsertionPointAfter(newForallOp);
    SmallVector<Value> castResults = newForallOp.getResults();
    for (auto &item : tensorCastProducers) {
      Value &oldTypeResult = castResults[item.first];
      oldTypeResult = rewriter.create<tensor::CastOp>(location: loc, args&: item.second.dstType,
                                                      args&: oldTypeResult);
    }
    rewriter.replaceOp(op: forallOp, newValues: castResults);
    return success();
  }
};
| 1911 | |
| 1912 | } // namespace |
| 1913 | |
| 1914 | void ForallOp::getCanonicalizationPatterns(RewritePatternSet &results, |
| 1915 | MLIRContext *context) { |
| 1916 | results.add<DimOfForallOp, FoldTensorCastOfOutputIntoForallOp, |
| 1917 | ForallOpControlOperandsFolder, ForallOpIterArgsFolder, |
| 1918 | ForallOpSingleOrZeroIterationDimsFolder, |
| 1919 | ForallOpReplaceConstantInductionVar>(arg&: context); |
| 1920 | } |
| 1921 | |
| 1922 | /// Given the region at `index`, or the parent operation if `index` is None, |
| 1923 | /// return the successor regions. These are the regions that may be selected |
| 1924 | /// during the flow of control. `operands` is a set of optional attributes that |
| 1925 | /// correspond to a constant value for each operand, or null if that operand is |
| 1926 | /// not a constant. |
| 1927 | void ForallOp::getSuccessorRegions(RegionBranchPoint point, |
| 1928 | SmallVectorImpl<RegionSuccessor> ®ions) { |
| 1929 | // Both the operation itself and the region may be branching into the body or |
| 1930 | // back into the operation itself. It is possible for loop not to enter the |
| 1931 | // body. |
| 1932 | regions.push_back(Elt: RegionSuccessor(&getRegion())); |
| 1933 | regions.push_back(Elt: RegionSuccessor()); |
| 1934 | } |
| 1935 | |
| 1936 | //===----------------------------------------------------------------------===// |
| 1937 | // InParallelOp |
| 1938 | //===----------------------------------------------------------------------===// |
| 1939 | |
| 1940 | // Build a InParallelOp with mixed static and dynamic entries. |
| 1941 | void InParallelOp::build(OpBuilder &b, OperationState &result) { |
| 1942 | OpBuilder::InsertionGuard g(b); |
| 1943 | Region *bodyRegion = result.addRegion(); |
| 1944 | b.createBlock(parent: bodyRegion); |
| 1945 | } |
| 1946 | |
| 1947 | LogicalResult InParallelOp::verify() { |
| 1948 | scf::ForallOp forallOp = |
| 1949 | dyn_cast<scf::ForallOp>(Val: getOperation()->getParentOp()); |
| 1950 | if (!forallOp) |
| 1951 | return this->emitOpError(message: "expected forall op parent" ); |
| 1952 | |
| 1953 | // TODO: InParallelOpInterface. |
| 1954 | for (Operation &op : getRegion().front().getOperations()) { |
| 1955 | if (!isa<tensor::ParallelInsertSliceOp>(Val: op)) { |
| 1956 | return this->emitOpError(message: "expected only " ) |
| 1957 | << tensor::ParallelInsertSliceOp::getOperationName() << " ops" ; |
| 1958 | } |
| 1959 | |
| 1960 | // Verify that inserts are into out block arguments. |
| 1961 | Value dest = cast<tensor::ParallelInsertSliceOp>(Val&: op).getDest(); |
| 1962 | ArrayRef<BlockArgument> regionOutArgs = forallOp.getRegionOutArgs(); |
| 1963 | if (!llvm::is_contained(Range&: regionOutArgs, Element: dest)) |
| 1964 | return op.emitOpError(message: "may only insert into an output block argument" ); |
| 1965 | } |
| 1966 | return success(); |
| 1967 | } |
| 1968 | |
| 1969 | void InParallelOp::print(OpAsmPrinter &p) { |
| 1970 | p << " " ; |
| 1971 | p.printRegion(blocks&: getRegion(), |
| 1972 | /*printEntryBlockArgs=*/false, |
| 1973 | /*printBlockTerminators=*/false); |
| 1974 | p.printOptionalAttrDict(attrs: getOperation()->getAttrs()); |
| 1975 | } |
| 1976 | |
// Custom parser mirroring InParallelOp::print: parse the (argument-less)
// region followed by an optional attribute dictionary.
ParseResult InParallelOp::parse(OpAsmParser &parser, OperationState &result) {
  auto &builder = parser.getBuilder();

  // The region declares no block arguments.
  SmallVector<OpAsmParser::Argument, 8> regionOperands;
  std::unique_ptr<Region> region = std::make_unique<Region>();
  if (parser.parseRegion(region&: *region, arguments: regionOperands))
    return failure();

  // An elided empty body still needs a block to keep the op well-formed.
  if (region->empty())
    OpBuilder(builder.getContext()).createBlock(parent: region.get());
  result.addRegion(region: std::move(region));

  // Parse the optional attribute list.
  if (parser.parseOptionalAttrDict(result&: result.attributes))
    return failure();
  return success();
}
| 1994 | |
| 1995 | OpResult InParallelOp::getParentResult(int64_t idx) { |
| 1996 | return getOperation()->getParentOp()->getResult(idx); |
| 1997 | } |
| 1998 | |
| 1999 | SmallVector<BlockArgument> InParallelOp::getDests() { |
| 2000 | return llvm::to_vector<4>( |
| 2001 | Range: llvm::map_range(C: getYieldingOps(), F: [](Operation &op) { |
| 2002 | // Add new ops here as needed. |
| 2003 | auto insertSliceOp = cast<tensor::ParallelInsertSliceOp>(Val: &op); |
| 2004 | return llvm::cast<BlockArgument>(Val: insertSliceOp.getDest()); |
| 2005 | })); |
| 2006 | } |
| 2007 | |
| 2008 | llvm::iterator_range<Block::iterator> InParallelOp::getYieldingOps() { |
| 2009 | return getRegion().front().getOperations(); |
| 2010 | } |
| 2011 | |
| 2012 | //===----------------------------------------------------------------------===// |
| 2013 | // IfOp |
| 2014 | //===----------------------------------------------------------------------===// |
| 2015 | |
// Return true if `a` and `b` sit in mutually exclusive branches (then vs.
// else) of a common enclosing scf.if.
bool mlir::scf::insideMutuallyExclusiveBranches(Operation *a, Operation *b) {
  assert(a && "expected non-empty operation");
  assert(b && "expected non-empty operation");

  // Walk up the scf.if ancestors of `a` until one also contains `b`.
  IfOp ifOp = a->getParentOfType<IfOp>();
  while (ifOp) {
    // Check if b is inside ifOp. (We already know that a is.)
    if (ifOp->isProperAncestor(other: b))
      // b is contained in ifOp. a and b are in mutually exclusive branches if
      // they are in different blocks of ifOp: exactly one of them has an
      // ancestor in the then-block.
      return static_cast<bool>(ifOp.thenBlock()->findAncestorOpInBlock(op&: *a)) !=
             static_cast<bool>(ifOp.thenBlock()->findAncestorOpInBlock(op&: *b));
    // Check next enclosing IfOp.
    ifOp = ifOp->getParentOfType<IfOp>();
  }

  // Could not find a common IfOp among a's and b's ancestors.
  return false;
}
| 2035 | |
| 2036 | LogicalResult |
| 2037 | IfOp::inferReturnTypes(MLIRContext *ctx, std::optional<Location> loc, |
| 2038 | IfOp::Adaptor adaptor, |
| 2039 | SmallVectorImpl<Type> &inferredReturnTypes) { |
| 2040 | if (adaptor.getRegions().empty()) |
| 2041 | return failure(); |
| 2042 | Region *r = &adaptor.getThenRegion(); |
| 2043 | if (r->empty()) |
| 2044 | return failure(); |
| 2045 | Block &b = r->front(); |
| 2046 | if (b.empty()) |
| 2047 | return failure(); |
| 2048 | auto yieldOp = llvm::dyn_cast<YieldOp>(Val&: b.back()); |
| 2049 | if (!yieldOp) |
| 2050 | return failure(); |
| 2051 | TypeRange types = yieldOp.getOperandTypes(); |
| 2052 | llvm::append_range(C&: inferredReturnTypes, R&: types); |
| 2053 | return success(); |
| 2054 | } |
| 2055 | |
| 2056 | void IfOp::build(OpBuilder &builder, OperationState &result, |
| 2057 | TypeRange resultTypes, Value cond) { |
| 2058 | return build(odsBuilder&: builder, odsState&: result, resultTypes, cond, /*addThenBlock=*/false, |
| 2059 | /*addElseBlock=*/false); |
| 2060 | } |
| 2061 | |
| 2062 | void IfOp::build(OpBuilder &builder, OperationState &result, |
| 2063 | TypeRange resultTypes, Value cond, bool addThenBlock, |
| 2064 | bool addElseBlock) { |
| 2065 | assert((!addElseBlock || addThenBlock) && |
| 2066 | "must not create else block w/o then block" ); |
| 2067 | result.addTypes(newTypes&: resultTypes); |
| 2068 | result.addOperands(newOperands: cond); |
| 2069 | |
| 2070 | // Add regions and blocks. |
| 2071 | OpBuilder::InsertionGuard guard(builder); |
| 2072 | Region *thenRegion = result.addRegion(); |
| 2073 | if (addThenBlock) |
| 2074 | builder.createBlock(parent: thenRegion); |
| 2075 | Region *elseRegion = result.addRegion(); |
| 2076 | if (addElseBlock) |
| 2077 | builder.createBlock(parent: elseRegion); |
| 2078 | } |
| 2079 | |
| 2080 | void IfOp::build(OpBuilder &builder, OperationState &result, Value cond, |
| 2081 | bool withElseRegion) { |
| 2082 | build(odsBuilder&: builder, odsState&: result, resultTypes: TypeRange{}, cond, withElseRegion); |
| 2083 | } |
| 2084 | |
| 2085 | void IfOp::build(OpBuilder &builder, OperationState &result, |
| 2086 | TypeRange resultTypes, Value cond, bool withElseRegion) { |
| 2087 | result.addTypes(newTypes&: resultTypes); |
| 2088 | result.addOperands(newOperands: cond); |
| 2089 | |
| 2090 | // Build then region. |
| 2091 | OpBuilder::InsertionGuard guard(builder); |
| 2092 | Region *thenRegion = result.addRegion(); |
| 2093 | builder.createBlock(parent: thenRegion); |
| 2094 | if (resultTypes.empty()) |
| 2095 | IfOp::ensureTerminator(region&: *thenRegion, builder, loc: result.location); |
| 2096 | |
| 2097 | // Build else region. |
| 2098 | Region *elseRegion = result.addRegion(); |
| 2099 | if (withElseRegion) { |
| 2100 | builder.createBlock(parent: elseRegion); |
| 2101 | if (resultTypes.empty()) |
| 2102 | IfOp::ensureTerminator(region&: *elseRegion, builder, loc: result.location); |
| 2103 | } |
| 2104 | } |
| 2105 | |
| 2106 | void IfOp::build(OpBuilder &builder, OperationState &result, Value cond, |
| 2107 | function_ref<void(OpBuilder &, Location)> thenBuilder, |
| 2108 | function_ref<void(OpBuilder &, Location)> elseBuilder) { |
| 2109 | assert(thenBuilder && "the builder callback for 'then' must be present" ); |
| 2110 | result.addOperands(newOperands: cond); |
| 2111 | |
| 2112 | // Build then region. |
| 2113 | OpBuilder::InsertionGuard guard(builder); |
| 2114 | Region *thenRegion = result.addRegion(); |
| 2115 | builder.createBlock(parent: thenRegion); |
| 2116 | thenBuilder(builder, result.location); |
| 2117 | |
| 2118 | // Build else region. |
| 2119 | Region *elseRegion = result.addRegion(); |
| 2120 | if (elseBuilder) { |
| 2121 | builder.createBlock(parent: elseRegion); |
| 2122 | elseBuilder(builder, result.location); |
| 2123 | } |
| 2124 | |
| 2125 | // Infer result types. |
| 2126 | SmallVector<Type> inferredReturnTypes; |
| 2127 | MLIRContext *ctx = builder.getContext(); |
| 2128 | auto attrDict = DictionaryAttr::get(context: ctx, value: result.attributes); |
| 2129 | if (succeeded(Result: inferReturnTypes(context: ctx, location: std::nullopt, operands: result.operands, attributes: attrDict, |
| 2130 | /*properties=*/nullptr, regions: result.regions, |
| 2131 | inferredReturnTypes))) { |
| 2132 | result.addTypes(newTypes: inferredReturnTypes); |
| 2133 | } |
| 2134 | } |
| 2135 | |
| 2136 | LogicalResult IfOp::verify() { |
| 2137 | if (getNumResults() != 0 && getElseRegion().empty()) |
| 2138 | return emitOpError(message: "must have an else block if defining values" ); |
| 2139 | return success(); |
| 2140 | } |
| 2141 | |
/// Parse an scf.if: condition, optional `-> (types)`, 'then' region, an
/// optional `else` region, and an optional attribute dictionary.
ParseResult IfOp::parse(OpAsmParser &parser, OperationState &result) {
  // Create the regions for 'then'.
  result.regions.reserve(N: 2);
  Region *thenRegion = result.addRegion();
  Region *elseRegion = result.addRegion();

  auto &builder = parser.getBuilder();
  OpAsmParser::UnresolvedOperand cond;
  Type i1Type = builder.getIntegerType(width: 1);
  // The condition is always i1, so its type is not spelled in the syntax.
  if (parser.parseOperand(result&: cond) ||
      parser.resolveOperand(operand: cond, type: i1Type, result&: result.operands))
    return failure();
  // Parse optional results type list.
  if (parser.parseOptionalArrowTypeList(result&: result.types))
    return failure();
  // Parse the 'then' region.
  if (parser.parseRegion(region&: *thenRegion, /*arguments=*/{}, /*argTypes=*/enableNameShadowing: {}))
    return failure();
  // Insert the implicit scf.yield if the user omitted the terminator.
  IfOp::ensureTerminator(region&: *thenRegion, builder&: parser.getBuilder(), loc: result.location);

  // If we find an 'else' keyword then parse the 'else' region.
  if (!parser.parseOptionalKeyword(keyword: "else" )) {
    if (parser.parseRegion(region&: *elseRegion, /*arguments=*/{}, /*argTypes=*/enableNameShadowing: {}))
      return failure();
    IfOp::ensureTerminator(region&: *elseRegion, builder&: parser.getBuilder(), loc: result.location);
  }

  // Parse the optional attribute list.
  if (parser.parseOptionalAttrDict(result&: result.attributes))
    return failure();
  return success();
}
| 2174 | |
| 2175 | void IfOp::print(OpAsmPrinter &p) { |
| 2176 | bool printBlockTerminators = false; |
| 2177 | |
| 2178 | p << " " << getCondition(); |
| 2179 | if (!getResults().empty()) { |
| 2180 | p << " -> (" << getResultTypes() << ")" ; |
| 2181 | // Print yield explicitly if the op defines values. |
| 2182 | printBlockTerminators = true; |
| 2183 | } |
| 2184 | p << ' '; |
| 2185 | p.printRegion(blocks&: getThenRegion(), |
| 2186 | /*printEntryBlockArgs=*/false, |
| 2187 | /*printBlockTerminators=*/printBlockTerminators); |
| 2188 | |
| 2189 | // Print the 'else' regions if it exists and has a block. |
| 2190 | auto &elseRegion = getElseRegion(); |
| 2191 | if (!elseRegion.empty()) { |
| 2192 | p << " else " ; |
| 2193 | p.printRegion(blocks&: elseRegion, |
| 2194 | /*printEntryBlockArgs=*/false, |
| 2195 | /*printBlockTerminators=*/printBlockTerminators); |
| 2196 | } |
| 2197 | |
| 2198 | p.printOptionalAttrDict(attrs: (*this)->getAttrs()); |
| 2199 | } |
| 2200 | |
| 2201 | void IfOp::getSuccessorRegions(RegionBranchPoint point, |
| 2202 | SmallVectorImpl<RegionSuccessor> ®ions) { |
| 2203 | // The `then` and the `else` region branch back to the parent operation. |
| 2204 | if (!point.isParent()) { |
| 2205 | regions.push_back(Elt: RegionSuccessor(getResults())); |
| 2206 | return; |
| 2207 | } |
| 2208 | |
| 2209 | regions.push_back(Elt: RegionSuccessor(&getThenRegion())); |
| 2210 | |
| 2211 | // Don't consider the else region if it is empty. |
| 2212 | Region *elseRegion = &this->getElseRegion(); |
| 2213 | if (elseRegion->empty()) |
| 2214 | regions.push_back(Elt: RegionSuccessor()); |
| 2215 | else |
| 2216 | regions.push_back(Elt: RegionSuccessor(elseRegion)); |
| 2217 | } |
| 2218 | |
/// Populate the regions reachable on entry; a constant condition narrows the
/// set to the single branch that will actually execute.
void IfOp::getEntrySuccessorRegions(ArrayRef<Attribute> operands,
                                    SmallVectorImpl<RegionSuccessor> &regions) {
  FoldAdaptor adaptor(operands, *this);
  auto boolAttr = dyn_cast_or_null<BoolAttr>(Val: adaptor.getCondition());
  // The 'then' region is reachable unless the condition is known false.
  if (!boolAttr || boolAttr.getValue())
    regions.emplace_back(Args: &getThenRegion());

  // If the else region is empty, execution continues after the parent op.
  if (!boolAttr || !boolAttr.getValue()) {
    if (!getElseRegion().empty())
      regions.emplace_back(Args: &getElseRegion());
    else
      regions.emplace_back(Args: getResults());
  }
}
| 2234 | |
/// In-place fold: when the condition is a boolean negation (`xor %c, true`),
/// strip the negation and swap the two branch bodies.
LogicalResult IfOp::fold(FoldAdaptor adaptor,
                         SmallVectorImpl<OpFoldResult> &results) {
  // if (!c) then A() else B() -> if c then B() else A()
  if (getElseRegion().empty())
    return failure();

  arith::XOrIOp xorStmt = getCondition().getDefiningOp<arith::XOrIOp>();
  if (!xorStmt)
    return failure();

  // `xor %c, true` is the canonical form of boolean negation.
  if (!matchPattern(value: xorStmt.getRhs(), pattern: m_One()))
    return failure();

  // Use the un-negated condition and exchange the branch bodies by splicing
  // block lists (three-way move via the remembered `thenBlock`).
  getConditionMutable().assign(value: xorStmt.getLhs());
  Block *thenBlock = &getThenRegion().front();
  // It would be nicer to use iplist::swap, but that has no implemented
  // callbacks See: https://llvm.org/doxygen/ilist_8h_source.html#l00224
  getThenRegion().getBlocks().splice(where: getThenRegion().getBlocks().begin(),
                                     L2&: getElseRegion().getBlocks());
  getElseRegion().getBlocks().splice(where: getElseRegion().getBlocks().begin(),
                                     L2&: getThenRegion().getBlocks(), N: thenBlock);
  return success();
}
| 2258 | |
| 2259 | void IfOp::getRegionInvocationBounds( |
| 2260 | ArrayRef<Attribute> operands, |
| 2261 | SmallVectorImpl<InvocationBounds> &invocationBounds) { |
| 2262 | if (auto cond = llvm::dyn_cast_or_null<BoolAttr>(Val: operands[0])) { |
| 2263 | // If the condition is known, then one region is known to be executed once |
| 2264 | // and the other zero times. |
| 2265 | invocationBounds.emplace_back(Args: 0, Args: cond.getValue() ? 1 : 0); |
| 2266 | invocationBounds.emplace_back(Args: 0, Args: cond.getValue() ? 0 : 1); |
| 2267 | } else { |
| 2268 | // Non-constant condition. Each region may be executed 0 or 1 times. |
| 2269 | invocationBounds.assign(NumElts: 2, Elt: {0, 1}); |
| 2270 | } |
| 2271 | } |
| 2272 | |
| 2273 | namespace { |
| 2274 | // Pattern to remove unused IfOp results. |
// Pattern to remove unused IfOp results: rebuild the scf.if with only the
// results that have uses, trimming both branches' yields to match.
struct RemoveUnusedResults : public OpRewritePattern<IfOp> {
  using OpRewritePattern<IfOp>::OpRewritePattern;

  // Move the body of `source` into `dest`, then shrink dest's terminator so
  // it yields only the operands corresponding to `usedResults`.
  void transferBody(Block *source, Block *dest, ArrayRef<OpResult> usedResults,
                    PatternRewriter &rewriter) const {
    // Move all operations to the destination block.
    rewriter.mergeBlocks(source, dest);
    // Replace the yield op by one that returns only the used values.
    auto yieldOp = cast<scf::YieldOp>(Val: dest->getTerminator());
    SmallVector<Value, 4> usedOperands;
    llvm::transform(Range&: usedResults, d_first: std::back_inserter(x&: usedOperands),
                    F: [&](OpResult result) {
                      return yieldOp.getOperand(i: result.getResultNumber());
                    });
    rewriter.modifyOpInPlace(root: yieldOp,
                             callable: [&]() { yieldOp->setOperands(usedOperands); });
  }

  LogicalResult matchAndRewrite(IfOp op,
                                PatternRewriter &rewriter) const override {
    // Compute the list of used results.
    SmallVector<OpResult, 4> usedResults;
    llvm::copy_if(Range: op.getResults(), Out: std::back_inserter(x&: usedResults),
                  P: [](OpResult result) { return !result.use_empty(); });

    // Replace the operation if only a subset of its results have uses.
    if (usedResults.size() == op.getNumResults())
      return failure();

    // Compute the result types of the replacement operation.
    SmallVector<Type, 4> newTypes;
    llvm::transform(Range&: usedResults, d_first: std::back_inserter(x&: newTypes),
                    F: [](OpResult result) { return result.getType(); });

    // Create a replacement operation with empty then and else regions.
    auto newOp =
        rewriter.create<IfOp>(location: op.getLoc(), args&: newTypes, args: op.getCondition());
    rewriter.createBlock(parent: &newOp.getThenRegion());
    rewriter.createBlock(parent: &newOp.getElseRegion());

    // Move the bodies and replace the terminators (note there is a then and
    // an else region since the operation returns results).
    transferBody(source: op.getBody(idx: 0), dest: newOp.getBody(idx: 0), usedResults, rewriter);
    transferBody(source: op.getBody(idx: 1), dest: newOp.getBody(idx: 1), usedResults, rewriter);

    // Replace the operation by the new one. Unused results are left as null
    // values in `repResults`, which is fine since they have no uses.
    SmallVector<Value, 4> repResults(op.getNumResults());
    for (const auto &en : llvm::enumerate(First&: usedResults))
      repResults[en.value().getResultNumber()] = newOp.getResult(i: en.index());
    rewriter.replaceOp(op, newValues: repResults);
    return success();
  }
};
| 2328 | |
| 2329 | struct RemoveStaticCondition : public OpRewritePattern<IfOp> { |
| 2330 | using OpRewritePattern<IfOp>::OpRewritePattern; |
| 2331 | |
| 2332 | LogicalResult matchAndRewrite(IfOp op, |
| 2333 | PatternRewriter &rewriter) const override { |
| 2334 | BoolAttr condition; |
| 2335 | if (!matchPattern(value: op.getCondition(), pattern: m_Constant(bind_value: &condition))) |
| 2336 | return failure(); |
| 2337 | |
| 2338 | if (condition.getValue()) |
| 2339 | replaceOpWithRegion(rewriter, op, region&: op.getThenRegion()); |
| 2340 | else if (!op.getElseRegion().empty()) |
| 2341 | replaceOpWithRegion(rewriter, op, region&: op.getElseRegion()); |
| 2342 | else |
| 2343 | rewriter.eraseOp(op); |
| 2344 | |
| 2345 | return success(); |
| 2346 | } |
| 2347 | }; |
| 2348 | |
| 2349 | /// Hoist any yielded results whose operands are defined outside |
| 2350 | /// the if, to a select instruction. |
struct ConvertTrivialIfToSelect : public OpRewritePattern<IfOp> {
  using OpRewritePattern<IfOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(IfOp op,
                                PatternRewriter &rewriter) const override {
    if (op->getNumResults() == 0)
      return failure();

    auto cond = op.getCondition();
    auto thenYieldArgs = op.thenYield().getOperands();
    auto elseYieldArgs = op.elseYield().getOperands();

    // Result types that must stay on the scf.if because at least one of the
    // corresponding yielded values is defined inside a branch region.
    SmallVector<Type> nonHoistable;
    for (auto [trueVal, falseVal] : llvm::zip(t&: thenYieldArgs, u&: elseYieldArgs)) {
      if (&op.getThenRegion() == trueVal.getParentRegion() ||
          &op.getElseRegion() == falseVal.getParentRegion())
        nonHoistable.push_back(Elt: trueVal.getType());
    }
    // Early exit if there aren't any yielded values we can
    // hoist outside the if.
    if (nonHoistable.size() == op->getNumResults())
      return failure();

    // Rebuild the scf.if with only the non-hoistable results, reusing the
    // original branch bodies (the placeholder then-block is discarded).
    IfOp replacement = rewriter.create<IfOp>(location: op.getLoc(), args&: nonHoistable, args&: cond,
                                             /*withElseRegion=*/args: false);
    if (replacement.thenBlock())
      rewriter.eraseBlock(block: replacement.thenBlock());
    replacement.getThenRegion().takeBody(other&: op.getThenRegion());
    replacement.getElseRegion().takeBody(other&: op.getElseRegion());

    SmallVector<Value> results(op->getNumResults());
    assert(thenYieldArgs.size() == results.size());
    assert(elseYieldArgs.size() == results.size());

    // Partition each original result into one of three cases: kept on the
    // new scf.if, trivially equal in both branches, or hoisted to a select.
    SmallVector<Value> trueYields;
    SmallVector<Value> falseYields;
    rewriter.setInsertionPoint(replacement);
    for (const auto &it :
         llvm::enumerate(First: llvm::zip(t&: thenYieldArgs, u&: elseYieldArgs))) {
      Value trueVal = std::get<0>(t&: it.value());
      Value falseVal = std::get<1>(t&: it.value());
      if (&replacement.getThenRegion() == trueVal.getParentRegion() ||
          &replacement.getElseRegion() == falseVal.getParentRegion()) {
        // Defined inside a branch: stays a result of the new scf.if.
        results[it.index()] = replacement.getResult(i: trueYields.size());
        trueYields.push_back(Elt: trueVal);
        falseYields.push_back(Elt: falseVal);
      } else if (trueVal == falseVal)
        // Same value on both paths: no select needed.
        results[it.index()] = trueVal;
      else
        // Both defined outside: fold into an arith.select before the op.
        results[it.index()] = rewriter.create<arith::SelectOp>(
            location: op.getLoc(), args&: cond, args&: trueVal, args&: falseVal);
    }

    // Rewrite both terminators to yield only the kept values.
    rewriter.setInsertionPointToEnd(replacement.thenBlock());
    rewriter.replaceOpWithNewOp<YieldOp>(op: replacement.thenYield(), args&: trueYields);

    rewriter.setInsertionPointToEnd(replacement.elseBlock());
    rewriter.replaceOpWithNewOp<YieldOp>(op: replacement.elseYield(), args&: falseYields);

    rewriter.replaceOp(op, newValues: results);
    return success();
  }
};
| 2414 | |
| 2415 | /// Allow the true region of an if to assume the condition is true |
| 2416 | /// and vice versa. For example: |
| 2417 | /// |
| 2418 | /// scf.if %cmp { |
| 2419 | /// print(%cmp) |
| 2420 | /// } |
| 2421 | /// |
| 2422 | /// becomes |
| 2423 | /// |
| 2424 | /// scf.if %cmp { |
| 2425 | /// print(true) |
| 2426 | /// } |
| 2427 | /// |
struct ConditionPropagation : public OpRewritePattern<IfOp> {
  using OpRewritePattern<IfOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(IfOp op,
                                PatternRewriter &rewriter) const override {
    // Early exit if the condition is constant since replacing a constant
    // in the body with another constant isn't a simplification.
    if (matchPattern(value: op.getCondition(), pattern: m_Constant()))
      return failure();

    bool changed = false;
    mlir::Type i1Ty = rewriter.getI1Type();

    // These variables serve to prevent creating duplicate constants
    // and hold constant true or false values.
    Value constantTrue = nullptr;
    Value constantFalse = nullptr;

    // Rewrite each use of the condition that is nested inside one of the
    // branch regions; early-inc iteration keeps the use list stable.
    for (OpOperand &use :
         llvm::make_early_inc_range(Range: op.getCondition().getUses())) {
      if (op.getThenRegion().isAncestor(other: use.getOwner()->getParentRegion())) {
        // Inside the 'then' region the condition is known true.
        changed = true;

        if (!constantTrue)
          constantTrue = rewriter.create<arith::ConstantOp>(
              location: op.getLoc(), args&: i1Ty, args: rewriter.getIntegerAttr(type: i1Ty, value: 1));

        rewriter.modifyOpInPlace(root: use.getOwner(),
                                 callable: [&]() { use.set(constantTrue); });
      } else if (op.getElseRegion().isAncestor(
                     other: use.getOwner()->getParentRegion())) {
        // Inside the 'else' region the condition is known false.
        changed = true;

        if (!constantFalse)
          constantFalse = rewriter.create<arith::ConstantOp>(
              location: op.getLoc(), args&: i1Ty, args: rewriter.getIntegerAttr(type: i1Ty, value: 0));

        rewriter.modifyOpInPlace(root: use.getOwner(),
                                 callable: [&]() { use.set(constantFalse); });
      }
    }

    return success(IsSuccess: changed);
  }
};
| 2473 | |
| 2474 | /// Remove any statements from an if that are equivalent to the condition |
| 2475 | /// or its negation. For example: |
| 2476 | /// |
| 2477 | /// %res:2 = scf.if %cmp { |
| 2478 | /// yield something(), true |
| 2479 | /// } else { |
| 2480 | /// yield something2(), false |
| 2481 | /// } |
| 2482 | /// print(%res#1) |
| 2483 | /// |
| 2484 | /// becomes |
| 2485 | /// %res = scf.if %cmp { |
| 2486 | /// yield something() |
| 2487 | /// } else { |
| 2488 | /// yield something2() |
| 2489 | /// } |
| 2490 | /// print(%cmp) |
| 2491 | /// |
| 2492 | /// Additionally if both branches yield the same value, replace all uses |
| 2493 | /// of the result with the yielded value. |
| 2494 | /// |
| 2495 | /// %res:2 = scf.if %cmp { |
| 2496 | /// yield something(), %arg1 |
| 2497 | /// } else { |
| 2498 | /// yield something2(), %arg1 |
| 2499 | /// } |
| 2500 | /// print(%res#1) |
| 2501 | /// |
| 2502 | /// becomes |
| 2503 | /// %res = scf.if %cmp { |
| 2504 | /// yield something() |
| 2505 | /// } else { |
| 2506 | /// yield something2() |
| 2507 | /// } |
| 2508 | /// print(%arg1) |
| 2509 | /// |
| 2510 | struct ReplaceIfYieldWithConditionOrValue : public OpRewritePattern<IfOp> { |
| 2511 | using OpRewritePattern<IfOp>::OpRewritePattern; |
| 2512 | |
| 2513 | LogicalResult matchAndRewrite(IfOp op, |
| 2514 | PatternRewriter &rewriter) const override { |
| 2515 | // Early exit if there are no results that could be replaced. |
| 2516 | if (op.getNumResults() == 0) |
| 2517 | return failure(); |
| 2518 | |
| 2519 | auto trueYield = |
| 2520 | cast<scf::YieldOp>(Val: op.getThenRegion().back().getTerminator()); |
| 2521 | auto falseYield = |
| 2522 | cast<scf::YieldOp>(Val: op.getElseRegion().back().getTerminator()); |
| 2523 | |
| 2524 | rewriter.setInsertionPoint(block: op->getBlock(), |
| 2525 | insertPoint: op.getOperation()->getIterator()); |
| 2526 | bool changed = false; |
| 2527 | Type i1Ty = rewriter.getI1Type(); |
| 2528 | for (auto [trueResult, falseResult, opResult] : |
| 2529 | llvm::zip(t: trueYield.getResults(), u: falseYield.getResults(), |
| 2530 | args: op.getResults())) { |
| 2531 | if (trueResult == falseResult) { |
| 2532 | if (!opResult.use_empty()) { |
| 2533 | opResult.replaceAllUsesWith(newValue: trueResult); |
| 2534 | changed = true; |
| 2535 | } |
| 2536 | continue; |
| 2537 | } |
| 2538 | |
| 2539 | BoolAttr trueYield, falseYield; |
| 2540 | if (!matchPattern(value: trueResult, pattern: m_Constant(bind_value: &trueYield)) || |
| 2541 | !matchPattern(value: falseResult, pattern: m_Constant(bind_value: &falseYield))) |
| 2542 | continue; |
| 2543 | |
| 2544 | bool trueVal = trueYield.getValue(); |
| 2545 | bool falseVal = falseYield.getValue(); |
| 2546 | if (!trueVal && falseVal) { |
| 2547 | if (!opResult.use_empty()) { |
| 2548 | Dialect *constDialect = trueResult.getDefiningOp()->getDialect(); |
| 2549 | Value notCond = rewriter.create<arith::XOrIOp>( |
| 2550 | location: op.getLoc(), args: op.getCondition(), |
| 2551 | args: constDialect |
| 2552 | ->materializeConstant(builder&: rewriter, |
| 2553 | value: rewriter.getIntegerAttr(type: i1Ty, value: 1), type: i1Ty, |
| 2554 | loc: op.getLoc()) |
| 2555 | ->getResult(idx: 0)); |
| 2556 | opResult.replaceAllUsesWith(newValue: notCond); |
| 2557 | changed = true; |
| 2558 | } |
| 2559 | } |
| 2560 | if (trueVal && !falseVal) { |
| 2561 | if (!opResult.use_empty()) { |
| 2562 | opResult.replaceAllUsesWith(newValue: op.getCondition()); |
| 2563 | changed = true; |
| 2564 | } |
| 2565 | } |
| 2566 | } |
| 2567 | return success(IsSuccess: changed); |
| 2568 | } |
| 2569 | }; |
| 2570 | |
| 2571 | /// Merge any consecutive scf.if's with the same condition. |
| 2572 | /// |
| 2573 | /// scf.if %cond { |
| 2574 | /// firstCodeTrue();... |
| 2575 | /// } else { |
| 2576 | /// firstCodeFalse();... |
| 2577 | /// } |
| 2578 | /// %res = scf.if %cond { |
| 2579 | /// secondCodeTrue();... |
| 2580 | /// } else { |
| 2581 | /// secondCodeFalse();... |
| 2582 | /// } |
| 2583 | /// |
| 2584 | /// becomes |
| 2585 | /// %res = scf.if %cmp { |
| 2586 | /// firstCodeTrue();... |
| 2587 | /// secondCodeTrue();... |
| 2588 | /// } else { |
| 2589 | /// firstCodeFalse();... |
| 2590 | /// secondCodeFalse();... |
| 2591 | /// } |
struct CombineIfs : public OpRewritePattern<IfOp> {
  using OpRewritePattern<IfOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(IfOp nextIf,
                                PatternRewriter &rewriter) const override {
    // Only fires when the immediately preceding op is another scf.if.
    Block *parent = nextIf->getBlock();
    if (nextIf == &parent->front())
      return failure();

    auto prevIf = dyn_cast<IfOp>(Val: nextIf->getPrevNode());
    if (!prevIf)
      return failure();

    // Determine the logical then/else blocks when prevIf's
    // condition is used. Null means the block does not exist
    // in that case (e.g. empty else). If neither of these
    // are set, the two conditions cannot be compared.
    Block *nextThen = nullptr;
    Block *nextElse = nullptr;
    if (nextIf.getCondition() == prevIf.getCondition()) {
      nextThen = nextIf.thenBlock();
      if (!nextIf.getElseRegion().empty())
        nextElse = nextIf.elseBlock();
    }
    // nextIf's condition is `xor prevCond, true` (negation): branches swap.
    if (arith::XOrIOp notv =
            nextIf.getCondition().getDefiningOp<arith::XOrIOp>()) {
      if (notv.getLhs() == prevIf.getCondition() &&
          matchPattern(value: notv.getRhs(), pattern: m_One())) {
        nextElse = nextIf.thenBlock();
        if (!nextIf.getElseRegion().empty())
          nextThen = nextIf.elseBlock();
      }
    }
    // prevIf's condition is the negation of nextIf's: same swap applies.
    if (arith::XOrIOp notv =
            prevIf.getCondition().getDefiningOp<arith::XOrIOp>()) {
      if (notv.getLhs() == nextIf.getCondition() &&
          matchPattern(value: notv.getRhs(), pattern: m_One())) {
        nextElse = nextIf.thenBlock();
        if (!nextIf.getElseRegion().empty())
          nextThen = nextIf.elseBlock();
      }
    }

    if (!nextThen && !nextElse)
      return failure();

    SmallVector<Value> prevElseYielded;
    if (!prevIf.getElseRegion().empty())
      prevElseYielded = prevIf.elseYield().getOperands();
    // Replace all uses of return values of op within nextIf with the
    // corresponding yields
    for (auto it : llvm::zip(t: prevIf.getResults(),
                             u: prevIf.thenYield().getOperands(), args&: prevElseYielded))
      for (OpOperand &use :
           llvm::make_early_inc_range(Range: std::get<0>(t&: it).getUses())) {
        if (nextThen && nextThen->getParent()->isAncestor(
                            other: use.getOwner()->getParentRegion())) {
          rewriter.startOpModification(op: use.getOwner());
          use.set(std::get<1>(t&: it));
          rewriter.finalizeOpModification(op: use.getOwner());
        } else if (nextElse && nextElse->getParent()->isAncestor(
                                   other: use.getOwner()->getParentRegion())) {
          rewriter.startOpModification(op: use.getOwner());
          use.set(std::get<2>(t&: it));
          rewriter.finalizeOpModification(op: use.getOwner());
        }
      }

    // The combined op yields prevIf's results followed by nextIf's.
    SmallVector<Type> mergedTypes(prevIf.getResultTypes());
    llvm::append_range(C&: mergedTypes, R: nextIf.getResultTypes());

    IfOp combinedIf = rewriter.create<IfOp>(
        location: nextIf.getLoc(), args&: mergedTypes, args: prevIf.getCondition(), /*hasElse=*/args: false);
    rewriter.eraseBlock(block: &combinedIf.getThenRegion().back());

    // Start the combined then body from prevIf's then body.
    rewriter.inlineRegionBefore(region&: prevIf.getThenRegion(),
                                parent&: combinedIf.getThenRegion(),
                                before: combinedIf.getThenRegion().begin());

    if (nextThen) {
      // Concatenate nextIf's (logical) then body and merge the two yields.
      YieldOp thenYield = combinedIf.thenYield();
      YieldOp thenYield2 = cast<YieldOp>(Val: nextThen->getTerminator());
      rewriter.mergeBlocks(source: nextThen, dest: combinedIf.thenBlock());
      rewriter.setInsertionPointToEnd(combinedIf.thenBlock());

      SmallVector<Value> mergedYields(thenYield.getOperands());
      llvm::append_range(C&: mergedYields, R: thenYield2.getOperands());
      rewriter.create<YieldOp>(location: thenYield2.getLoc(), args&: mergedYields);
      rewriter.eraseOp(op: thenYield);
      rewriter.eraseOp(op: thenYield2);
    }

    rewriter.inlineRegionBefore(region&: prevIf.getElseRegion(),
                                parent&: combinedIf.getElseRegion(),
                                before: combinedIf.getElseRegion().begin());

    if (nextElse) {
      if (combinedIf.getElseRegion().empty()) {
        // prevIf had no else: adopt nextIf's (logical) else region wholesale.
        rewriter.inlineRegionBefore(region&: *nextElse->getParent(),
                                    parent&: combinedIf.getElseRegion(),
                                    before: combinedIf.getElseRegion().begin());
      } else {
        // Concatenate the two else bodies and merge their yields.
        YieldOp elseYield = combinedIf.elseYield();
        YieldOp elseYield2 = cast<YieldOp>(Val: nextElse->getTerminator());
        rewriter.mergeBlocks(source: nextElse, dest: combinedIf.elseBlock());

        rewriter.setInsertionPointToEnd(combinedIf.elseBlock());

        SmallVector<Value> mergedElseYields(elseYield.getOperands());
        llvm::append_range(C&: mergedElseYields, R: elseYield2.getOperands());

        rewriter.create<YieldOp>(location: elseYield2.getLoc(), args&: mergedElseYields);
        rewriter.eraseOp(op: elseYield);
        rewriter.eraseOp(op: elseYield2);
      }
    }

    // Split the combined results back between the two original ops.
    SmallVector<Value> prevValues;
    SmallVector<Value> nextValues;
    for (const auto &pair : llvm::enumerate(First: combinedIf.getResults())) {
      if (pair.index() < prevIf.getNumResults())
        prevValues.push_back(Elt: pair.value());
      else
        nextValues.push_back(Elt: pair.value());
    }
    rewriter.replaceOp(op: prevIf, newValues: prevValues);
    rewriter.replaceOp(op: nextIf, newValues: nextValues);
    return success();
  }
};
| 2722 | |
| 2723 | /// Pattern to remove an empty else branch. |
| 2724 | struct RemoveEmptyElseBranch : public OpRewritePattern<IfOp> { |
| 2725 | using OpRewritePattern<IfOp>::OpRewritePattern; |
| 2726 | |
| 2727 | LogicalResult matchAndRewrite(IfOp ifOp, |
| 2728 | PatternRewriter &rewriter) const override { |
| 2729 | // Cannot remove else region when there are operation results. |
| 2730 | if (ifOp.getNumResults()) |
| 2731 | return failure(); |
| 2732 | Block *elseBlock = ifOp.elseBlock(); |
| 2733 | if (!elseBlock || !llvm::hasSingleElement(C&: *elseBlock)) |
| 2734 | return failure(); |
| 2735 | auto newIfOp = rewriter.cloneWithoutRegions(op: ifOp); |
| 2736 | rewriter.inlineRegionBefore(region&: ifOp.getThenRegion(), parent&: newIfOp.getThenRegion(), |
| 2737 | before: newIfOp.getThenRegion().begin()); |
| 2738 | rewriter.eraseOp(op: ifOp); |
| 2739 | return success(); |
| 2740 | } |
| 2741 | }; |
| 2742 | |
/// Convert nested `if`s into `arith.andi` + single `if`.
///
///    scf.if %arg0 {
///      scf.if %arg1 {
///        ...
///        scf.yield
///      }
///      scf.yield
///    }
/// becomes
///
///    %0 = arith.andi %arg0, %arg1
///    scf.if %0 {
///      ...
///      scf.yield
///    }
struct CombineNestedIfs : public OpRewritePattern<IfOp> {
  using OpRewritePattern<IfOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(IfOp op,
                                PatternRewriter &rewriter) const override {
    auto nestedOps = op.thenBlock()->without_terminator();
    // Nested `if` must be the only op in block.
    if (!llvm::hasSingleElement(C&: nestedOps))
      return failure();

    // If there is an else block, it can only yield (no other ops), since
    // combining the conditions leaves no place to run other else-side code.
    if (op.elseBlock() && !llvm::hasSingleElement(C&: *op.elseBlock()))
      return failure();

    auto nestedIf = dyn_cast<IfOp>(Val&: *nestedOps.begin());
    if (!nestedIf)
      return failure();

    // Same restriction applies to the inner if's else block.
    if (nestedIf.elseBlock() && !llvm::hasSingleElement(C&: *nestedIf.elseBlock()))
      return failure();

    // Snapshot the yielded operands; thenYield entries may be rewritten below
    // to refer to the inner if's yields.
    SmallVector<Value> thenYield(op.thenYield().getOperands());
    SmallVector<Value> elseYield;
    if (op.elseBlock())
      llvm::append_range(C&: elseYield, R: op.elseYield().getOperands());

    // A list of indices for which we should upgrade the value yielded
    // in the else to a select.
    SmallVector<unsigned> elseYieldsToUpgradeToSelect;

    // If the outer scf.if yields a value produced by the inner scf.if,
    // only permit combining if the value yielded when the condition
    // is false in the outer scf.if is the same value yielded when the
    // inner scf.if condition is false.
    // Note that the array access to elseYield will not go out of bounds
    // since it must have the same length as thenYield, since they both
    // come from the same scf.if.
    for (const auto &tup : llvm::enumerate(First&: thenYield)) {
      if (tup.value().getDefiningOp() == nestedIf) {
        auto nestedIdx = llvm::cast<OpResult>(Val&: tup.value()).getResultNumber();
        if (nestedIf.elseYield().getOperand(i: nestedIdx) !=
            elseYield[tup.index()]) {
          return failure();
        }
        // If the correctness test passes, we will yield
        // corresponding value from the inner scf.if
        thenYield[tup.index()] = nestedIf.thenYield().getOperand(i: nestedIdx);
        continue;
      }

      // Otherwise, we need to ensure the else block of the combined
      // condition still returns the same value when the outer condition is
      // true and the inner condition is false. This can be accomplished if
      // the then value is defined outside the outer scf.if and we replace the
      // value with a select that considers just the outer condition. Since
      // the else region contains just the yield, its yielded value is
      // defined outside the scf.if, by definition.

      // If the then value is defined within the scf.if, bail.
      if (tup.value().getParentRegion() == &op.getThenRegion()) {
        return failure();
      }
      elseYieldsToUpgradeToSelect.push_back(Elt: tup.index());
    }

    Location loc = op.getLoc();
    // Fuse both conditions into a single one; the new if fires only when
    // both the outer and the inner condition hold.
    Value newCondition = rewriter.create<arith::AndIOp>(
        location: loc, args: op.getCondition(), args: nestedIf.getCondition());
    auto newIf = rewriter.create<IfOp>(location: loc, args: op.getResultTypes(), args&: newCondition);
    Block *newIfBlock = rewriter.createBlock(parent: &newIf.getThenRegion());

    SmallVector<Value> results;
    llvm::append_range(C&: results, R: newIf.getResults());
    // Selects must be created before the new if so they dominate its users.
    rewriter.setInsertionPoint(newIf);

    for (auto idx : elseYieldsToUpgradeToSelect)
      results[idx] = rewriter.create<arith::SelectOp>(
          location: op.getLoc(), args: op.getCondition(), args&: thenYield[idx], args&: elseYield[idx]);

    // Move the inner then-body into the combined if and re-terminate it with
    // the adjusted yield operands.
    rewriter.mergeBlocks(source: nestedIf.thenBlock(), dest: newIfBlock);
    rewriter.setInsertionPointToEnd(newIf.thenBlock());
    rewriter.replaceOpWithNewOp<YieldOp>(op: newIf.thenYield(), args&: thenYield);
    if (!elseYield.empty()) {
      rewriter.createBlock(parent: &newIf.getElseRegion());
      rewriter.setInsertionPointToEnd(newIf.elseBlock());
      rewriter.create<YieldOp>(location: loc, args&: elseYield);
    }
    rewriter.replaceOp(op, newValues: results);
    return success();
  }
};
| 2850 | |
| 2851 | } // namespace |
| 2852 | |
/// Registers all scf.if canonicalization patterns defined above.
void IfOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                       MLIRContext *context) {
  results.add<CombineIfs, CombineNestedIfs, ConditionPropagation,
              ConvertTrivialIfToSelect, RemoveEmptyElseBranch,
              RemoveStaticCondition, RemoveUnusedResults,
              ReplaceIfYieldWithConditionOrValue>(arg&: context);
}
| 2860 | |
/// Returns the block of the (always present) then region.
Block *IfOp::thenBlock() { return &getThenRegion().back(); }
/// Returns the scf.yield terminating the then block.
YieldOp IfOp::thenYield() { return cast<YieldOp>(Val: &thenBlock()->back()); }
| 2863 | Block *IfOp::elseBlock() { |
| 2864 | Region &r = getElseRegion(); |
| 2865 | if (r.empty()) |
| 2866 | return nullptr; |
| 2867 | return &r.back(); |
| 2868 | } |
/// Returns the scf.yield terminating the else block; callers must ensure the
/// else region exists (see elseBlock()).
YieldOp IfOp::elseYield() { return cast<YieldOp>(Val: &elseBlock()->back()); }
| 2870 | |
| 2871 | //===----------------------------------------------------------------------===// |
| 2872 | // ParallelOp |
| 2873 | //===----------------------------------------------------------------------===// |
| 2874 | |
/// Builds an scf.parallel op from bounds, steps, and reduction init values.
/// `bodyBuilderFn`, if provided, is invoked with the induction variables
/// followed by the region arguments corresponding to `initVals`.
void ParallelOp::build(
    OpBuilder &builder, OperationState &result, ValueRange lowerBounds,
    ValueRange upperBounds, ValueRange steps, ValueRange initVals,
    function_ref<void(OpBuilder &, Location, ValueRange, ValueRange)>
        bodyBuilderFn) {
  result.addOperands(newOperands: lowerBounds);
  result.addOperands(newOperands: upperBounds);
  result.addOperands(newOperands: steps);
  result.addOperands(newOperands: initVals);
  // Record how the flat operand list splits into the four variadic groups.
  result.addAttribute(
      name: ParallelOp::getOperandSegmentSizeAttr(),
      attr: builder.getDenseI32ArrayAttr(values: {static_cast<int32_t>(lowerBounds.size()),
                                    static_cast<int32_t>(upperBounds.size()),
                                    static_cast<int32_t>(steps.size()),
                                    static_cast<int32_t>(initVals.size())}));
  // One op result per reduction initial value, with matching types.
  result.addTypes(newTypes: initVals.getTypes());

  OpBuilder::InsertionGuard guard(builder);
  // The body block takes one index-typed argument per induction variable.
  unsigned numIVs = steps.size();
  SmallVector<Type, 8> argTypes(numIVs, builder.getIndexType());
  SmallVector<Location, 8> argLocs(numIVs, result.location);
  Region *bodyRegion = result.addRegion();
  Block *bodyBlock = builder.createBlock(parent: bodyRegion, insertPt: {}, argTypes, locs: argLocs);

  if (bodyBuilderFn) {
    builder.setInsertionPointToStart(bodyBlock);
    bodyBuilderFn(builder, result.location,
                  bodyBlock->getArguments().take_front(N: numIVs),
                  bodyBlock->getArguments().drop_front(N: numIVs));
  }
  // Add terminator only if there are no reductions.
  if (initVals.empty())
    ParallelOp::ensureTerminator(region&: *bodyRegion, builder, loc: result.location);
}
| 2909 | |
| 2910 | void ParallelOp::build( |
| 2911 | OpBuilder &builder, OperationState &result, ValueRange lowerBounds, |
| 2912 | ValueRange upperBounds, ValueRange steps, |
| 2913 | function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn) { |
| 2914 | // Only pass a non-null wrapper if bodyBuilderFn is non-null itself. Make sure |
| 2915 | // we don't capture a reference to a temporary by constructing the lambda at |
| 2916 | // function level. |
| 2917 | auto wrappedBuilderFn = [&bodyBuilderFn](OpBuilder &nestedBuilder, |
| 2918 | Location nestedLoc, ValueRange ivs, |
| 2919 | ValueRange) { |
| 2920 | bodyBuilderFn(nestedBuilder, nestedLoc, ivs); |
| 2921 | }; |
| 2922 | function_ref<void(OpBuilder &, Location, ValueRange, ValueRange)> wrapper; |
| 2923 | if (bodyBuilderFn) |
| 2924 | wrapper = wrappedBuilderFn; |
| 2925 | |
| 2926 | build(builder, result, lowerBounds, upperBounds, steps, initVals: ValueRange(), |
| 2927 | bodyBuilderFn: wrapper); |
| 2928 | } |
| 2929 | |
/// Verifies the scf.parallel op: non-empty bounds/steps, positive constant
/// steps, index-typed induction variables, an scf.reduce terminator, and
/// agreement between results, reductions, and init values.
LogicalResult ParallelOp::verify() {
  // Check that there is at least one value in lowerBound, upperBound and step.
  // It is sufficient to test only step, because it is ensured already that the
  // number of elements in lowerBound, upperBound and step are the same.
  Operation::operand_range stepValues = getStep();
  if (stepValues.empty())
    return emitOpError(
        message: "needs at least one tuple element for lowerBound, upperBound and step" );

  // Check whether all constant step values are positive.
  for (Value stepValue : stepValues)
    if (auto cst = getConstantIntValue(ofr: stepValue))
      if (*cst <= 0)
        return emitOpError(message: "constant step operand must be positive" );

  // Check that the body defines the same number of block arguments as the
  // number of tuple elements in step.
  Block *body = getBody();
  if (body->getNumArguments() != stepValues.size())
    return emitOpError() << "expects the same number of induction variables: "
                         << body->getNumArguments()
                         << " as bound and step values: " << stepValues.size();
  for (auto arg : body->getArguments())
    if (!arg.getType().isIndex())
      return emitOpError(
          message: "expects arguments for the induction variable to be of index type" );

  // Check that the terminator is an scf.reduce op.
  auto reduceOp = verifyAndGetTerminator<scf::ReduceOp>(
      op: *this, region&: getRegion(), errorMessage: "expects body to terminate with 'scf.reduce'" );
  if (!reduceOp)
    return failure();

  // Check that the number of results is the same as the number of reductions.
  auto resultsSize = getResults().size();
  auto reductionsSize = reduceOp.getReductions().size();
  auto initValsSize = getInitVals().size();
  if (resultsSize != reductionsSize)
    return emitOpError() << "expects number of results: " << resultsSize
                         << " to be the same as number of reductions: "
                         << reductionsSize;
  if (resultsSize != initValsSize)
    return emitOpError() << "expects number of results: " << resultsSize
                         << " to be the same as number of initial values: "
                         << initValsSize;

  // Check that the types of the results and reductions are the same.
  for (int64_t i = 0; i < static_cast<int64_t>(reductionsSize); ++i) {
    auto resultType = getOperation()->getResult(idx: i).getType();
    auto reductionOperandType = reduceOp.getOperands()[i].getType();
    if (resultType != reductionOperandType)
      return reduceOp.emitOpError()
             << "expects type of " << i
             << "-th reduction operand: " << reductionOperandType
             << " to be the same as the " << i
             << "-th result type: " << resultType;
  }
  return success();
}
| 2989 | |
/// Parses the custom scf.parallel assembly:
///   `(` ivs `)` `=` `(` lower `)` `to` `(` upper `)` `step` `(` steps `)`
///   [ `init` `(` init-vals `)` ] [ `->` result-types ] region attr-dict
ParseResult ParallelOp::parse(OpAsmParser &parser, OperationState &result) {
  auto &builder = parser.getBuilder();
  // Parse an opening `(` followed by induction variables followed by `)`
  SmallVector<OpAsmParser::Argument, 4> ivs;
  if (parser.parseArgumentList(result&: ivs, delimiter: OpAsmParser::Delimiter::Paren))
    return failure();

  // Parse loop bounds. Each list must have exactly one operand per IV.
  SmallVector<OpAsmParser::UnresolvedOperand, 4> lower;
  if (parser.parseEqual() ||
      parser.parseOperandList(result&: lower, requiredOperandCount: ivs.size(),
                              delimiter: OpAsmParser::Delimiter::Paren) ||
      parser.resolveOperands(operands&: lower, type: builder.getIndexType(), result&: result.operands))
    return failure();

  SmallVector<OpAsmParser::UnresolvedOperand, 4> upper;
  if (parser.parseKeyword(keyword: "to" ) ||
      parser.parseOperandList(result&: upper, requiredOperandCount: ivs.size(),
                              delimiter: OpAsmParser::Delimiter::Paren) ||
      parser.resolveOperands(operands&: upper, type: builder.getIndexType(), result&: result.operands))
    return failure();

  // Parse step values.
  SmallVector<OpAsmParser::UnresolvedOperand, 4> steps;
  if (parser.parseKeyword(keyword: "step" ) ||
      parser.parseOperandList(result&: steps, requiredOperandCount: ivs.size(),
                              delimiter: OpAsmParser::Delimiter::Paren) ||
      parser.resolveOperands(operands&: steps, type: builder.getIndexType(), result&: result.operands))
    return failure();

  // Parse init values.
  SmallVector<OpAsmParser::UnresolvedOperand, 4> initVals;
  if (succeeded(Result: parser.parseOptionalKeyword(keyword: "init" ))) {
    if (parser.parseOperandList(result&: initVals, delimiter: OpAsmParser::Delimiter::Paren))
      return failure();
  }

  // Parse optional results in case there is a reduce.
  if (parser.parseOptionalArrowTypeList(result&: result.types))
    return failure();

  // Now parse the body. Induction variables are always of index type.
  Region *body = result.addRegion();
  for (auto &iv : ivs)
    iv.type = builder.getIndexType();
  if (parser.parseRegion(region&: *body, arguments: ivs))
    return failure();

  // Set `operandSegmentSizes` attribute.
  result.addAttribute(
      name: ParallelOp::getOperandSegmentSizeAttr(),
      attr: builder.getDenseI32ArrayAttr(values: {static_cast<int32_t>(lower.size()),
                                    static_cast<int32_t>(upper.size()),
                                    static_cast<int32_t>(steps.size()),
                                    static_cast<int32_t>(initVals.size())}));

  // Parse attributes. Init values are resolved against the parsed result
  // types: each reduction result has the type of its initial value.
  if (parser.parseOptionalAttrDict(result&: result.attributes) ||
      parser.resolveOperands(operands&: initVals, types&: result.types, loc: parser.getNameLoc(),
                             result&: result.operands))
    return failure();

  // Add a terminator if none was parsed.
  ParallelOp::ensureTerminator(region&: *body, builder, loc: result.location);
  return success();
}
| 3056 | |
/// Prints the custom scf.parallel assembly: IVs, bounds, steps, optional
/// init values, optional result types, the region, then the attribute
/// dictionary. The operand-segment-sizes attribute is elided because it is
/// implied by the printed operand lists.
void ParallelOp::print(OpAsmPrinter &p) {
  p << " (" << getBody()->getArguments() << ") = (" << getLowerBound()
    << ") to (" << getUpperBound() << ") step (" << getStep() << ")" ;
  if (!getInitVals().empty())
    p << " init (" << getInitVals() << ")" ;
  p.printOptionalArrowTypeList(types: getResultTypes());
  p << ' ';
  // Entry block args (the IVs) were already printed in the header above.
  p.printRegion(blocks&: getRegion(), /*printEntryBlockArgs=*/false);
  p.printOptionalAttrDict(
      attrs: (*this)->getAttrs(),
      /*elidedAttrs=*/ParallelOp::getOperandSegmentSizeAttr());
}
| 3069 | |
/// LoopLikeOpInterface: the single body region is the only loop region.
SmallVector<Region *> ParallelOp::getLoopRegions() { return {&getRegion()}; }
| 3071 | |
/// LoopLikeOpInterface: the induction variables are the body's block args.
std::optional<SmallVector<Value>> ParallelOp::getLoopInductionVars() {
  return SmallVector<Value>{getBody()->getArguments()};
}
| 3075 | |
/// LoopLikeOpInterface: lower bounds, one per induction variable.
std::optional<SmallVector<OpFoldResult>> ParallelOp::getLoopLowerBounds() {
  return getLowerBound();
}
| 3079 | |
/// LoopLikeOpInterface: upper bounds, one per induction variable.
std::optional<SmallVector<OpFoldResult>> ParallelOp::getLoopUpperBounds() {
  return getUpperBound();
}
| 3083 | |
/// LoopLikeOpInterface: steps, one per induction variable.
std::optional<SmallVector<OpFoldResult>> ParallelOp::getLoopSteps() {
  return getStep();
}
| 3087 | |
| 3088 | ParallelOp mlir::scf::getParallelForInductionVarOwner(Value val) { |
| 3089 | auto ivArg = llvm::dyn_cast<BlockArgument>(Val&: val); |
| 3090 | if (!ivArg) |
| 3091 | return ParallelOp(); |
| 3092 | assert(ivArg.getOwner() && "unlinked block argument" ); |
| 3093 | auto *containingOp = ivArg.getOwner()->getParentOp(); |
| 3094 | return dyn_cast<ParallelOp>(Val: containingOp); |
| 3095 | } |
| 3096 | |
| 3097 | namespace { |
// Collapse loop dimensions that perform a single iteration, and fold away
// loops that perform zero or one total iterations.
struct ParallelOpSingleOrZeroIterationDimsFolder
    : public OpRewritePattern<ParallelOp> {
  using OpRewritePattern<ParallelOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ParallelOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();

    // Compute new loop bounds that omit all single-iteration loop dimensions.
    SmallVector<Value> newLowerBounds, newUpperBounds, newSteps;
    IRMapping mapping;
    for (auto [lb, ub, step, iv] :
         llvm::zip(t: op.getLowerBound(), u: op.getUpperBound(), args: op.getStep(),
                   args: op.getInductionVars())) {
      // Trip count is only known when lb, ub, and step are all constant.
      auto numIterations = constantTripCount(lb, ub, step);
      if (numIterations.has_value()) {
        // Remove the loop if it performs zero iterations.
        if (*numIterations == 0) {
          // With zero iterations the results are just the init values.
          rewriter.replaceOp(op, newValues: op.getInitVals());
          return success();
        }
        // Replace the loop induction variable by the lower bound if the loop
        // performs a single iteration. Otherwise, copy the loop bounds.
        if (*numIterations == 1) {
          mapping.map(from: iv, to: getValueOrCreateConstantIndexOp(b&: rewriter, loc, ofr: lb));
          continue;
        }
      }
      newLowerBounds.push_back(Elt: lb);
      newUpperBounds.push_back(Elt: ub);
      newSteps.push_back(Elt: step);
    }
    // Exit if none of the loop dimensions perform a single iteration.
    if (newLowerBounds.size() == op.getLowerBound().size())
      return failure();

    if (newLowerBounds.empty()) {
      // All of the loop dimensions perform a single iteration. Inline
      // loop body and nested ReduceOp's
      SmallVector<Value> results;
      results.reserve(N: op.getInitVals().size());
      for (auto &bodyOp : op.getBody()->without_terminator())
        rewriter.clone(op&: bodyOp, mapper&: mapping);
      // Each reduction region is inlined with its first argument bound to the
      // init value and its second bound to the (cloned) reduced operand.
      auto reduceOp = cast<ReduceOp>(Val: op.getBody()->getTerminator());
      for (int64_t i = 0, e = reduceOp.getReductions().size(); i < e; ++i) {
        Block &reduceBlock = reduceOp.getReductions()[i].front();
        auto initValIndex = results.size();
        mapping.map(from: reduceBlock.getArgument(i: 0), to: op.getInitVals()[initValIndex]);
        mapping.map(from: reduceBlock.getArgument(i: 1),
                    to: mapping.lookupOrDefault(from: reduceOp.getOperands()[i]));
        for (auto &reduceBodyOp : reduceBlock.without_terminator())
          rewriter.clone(op&: reduceBodyOp, mapper&: mapping);

        // The value fed to scf.reduce.return becomes the loop result.
        auto result = mapping.lookupOrDefault(
            from: cast<ReduceReturnOp>(Val: reduceBlock.getTerminator()).getResult());
        results.push_back(Elt: result);
      }

      rewriter.replaceOp(op, newValues: results);
      return success();
    }
    // Replace the parallel loop by lower-dimensional parallel loop.
    auto newOp =
        rewriter.create<ParallelOp>(location: op.getLoc(), args&: newLowerBounds, args&: newUpperBounds,
                                    args&: newSteps, args: op.getInitVals(), args: nullptr);
    // Erase the empty block that was inserted by the builder.
    rewriter.eraseBlock(block: newOp.getBody());
    // Clone the loop body and remap the block arguments of the collapsed loops
    // (inlining does not support a cancellable block argument mapping).
    rewriter.cloneRegionBefore(region&: op.getRegion(), parent&: newOp.getRegion(),
                               before: newOp.getRegion().begin(), mapping);
    rewriter.replaceOp(op, newValues: newOp.getResults());
    return success();
  }
};
| 3174 | |
/// Merge an scf.parallel whose body contains only another scf.parallel into a
/// single multi-dimensional parallel loop over the concatenated bounds.
struct MergeNestedParallelLoops : public OpRewritePattern<ParallelOp> {
  using OpRewritePattern<ParallelOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ParallelOp op,
                                PatternRewriter &rewriter) const override {
    // The inner loop must be the only op in the outer body (besides the
    // terminator).
    Block &outerBody = *op.getBody();
    if (!llvm::hasSingleElement(C: outerBody.without_terminator()))
      return failure();

    auto innerOp = dyn_cast<ParallelOp>(Val&: outerBody.front());
    if (!innerOp)
      return failure();

    // Inner bounds/steps must not depend on the outer induction variables,
    // otherwise they cannot be hoisted into the merged loop header.
    for (auto val : outerBody.getArguments())
      if (llvm::is_contained(Range: innerOp.getLowerBound(), Element: val) ||
          llvm::is_contained(Range: innerOp.getUpperBound(), Element: val) ||
          llvm::is_contained(Range: innerOp.getStep(), Element: val))
        return failure();

    // Reductions are not supported yet.
    if (!op.getInitVals().empty() || !innerOp.getInitVals().empty())
      return failure();

    // Clones the inner body into the merged loop, remapping both the outer
    // and the inner induction variables onto the merged loop's IVs (outer
    // IVs come first, matching the concatenated bound lists below).
    auto bodyBuilder = [&](OpBuilder &builder, Location /*loc*/,
                           ValueRange iterVals, ValueRange) {
      Block &innerBody = *innerOp.getBody();
      assert(iterVals.size() ==
             (outerBody.getNumArguments() + innerBody.getNumArguments()));
      IRMapping mapping;
      mapping.map(from: outerBody.getArguments(),
                  to: iterVals.take_front(n: outerBody.getNumArguments()));
      mapping.map(from: innerBody.getArguments(),
                  to: iterVals.take_back(n: innerBody.getNumArguments()));
      for (Operation &op : innerBody.without_terminator())
        builder.clone(op, mapper&: mapping);
    };

    // Helper to concatenate two operand ranges into one vector.
    auto concatValues = [](const auto &first, const auto &second) {
      SmallVector<Value> ret;
      ret.reserve(N: first.size() + second.size());
      ret.assign(first.begin(), first.end());
      ret.append(second.begin(), second.end());
      return ret;
    };

    auto newLowerBounds =
        concatValues(op.getLowerBound(), innerOp.getLowerBound());
    auto newUpperBounds =
        concatValues(op.getUpperBound(), innerOp.getUpperBound());
    auto newSteps = concatValues(op.getStep(), innerOp.getStep());

    rewriter.replaceOpWithNewOp<ParallelOp>(op, args&: newLowerBounds, args&: newUpperBounds,
                                            args&: newSteps, args: ValueRange(),
                                            args&: bodyBuilder);
    return success();
  }
};
| 3232 | |
| 3233 | } // namespace |
| 3234 | |
/// Registers the scf.parallel canonicalization patterns defined above.
void ParallelOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
  results
      .add<ParallelOpSingleOrZeroIterationDimsFolder, MergeNestedParallelLoops>(
          arg&: context);
}
| 3241 | |
| 3242 | /// Given the region at `index`, or the parent operation if `index` is None, |
| 3243 | /// return the successor regions. These are the regions that may be selected |
| 3244 | /// during the flow of control. `operands` is a set of optional attributes that |
| 3245 | /// correspond to a constant value for each operand, or null if that operand is |
| 3246 | /// not a constant. |
| 3247 | void ParallelOp::getSuccessorRegions( |
| 3248 | RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> ®ions) { |
| 3249 | // Both the operation itself and the region may be branching into the body or |
| 3250 | // back into the operation itself. It is possible for loop not to enter the |
| 3251 | // body. |
| 3252 | regions.push_back(Elt: RegionSuccessor(&getRegion())); |
| 3253 | regions.push_back(Elt: RegionSuccessor()); |
| 3254 | } |
| 3255 | |
| 3256 | //===----------------------------------------------------------------------===// |
| 3257 | // ReduceOp |
| 3258 | //===----------------------------------------------------------------------===// |
| 3259 | |
/// Builds an scf.reduce with no operands and no reduction regions.
void ReduceOp::build(OpBuilder &builder, OperationState &result) {}
| 3261 | |
| 3262 | void ReduceOp::build(OpBuilder &builder, OperationState &result, |
| 3263 | ValueRange operands) { |
| 3264 | result.addOperands(newOperands: operands); |
| 3265 | for (Value v : operands) { |
| 3266 | OpBuilder::InsertionGuard guard(builder); |
| 3267 | Region *bodyRegion = result.addRegion(); |
| 3268 | builder.createBlock(parent: bodyRegion, insertPt: {}, |
| 3269 | argTypes: ArrayRef<Type>{v.getType(), v.getType()}, |
| 3270 | locs: {result.location, result.location}); |
| 3271 | } |
| 3272 | } |
| 3273 | |
/// Verifies each reduction region: non-empty, exactly two block arguments of
/// the corresponding operand's type, and an scf.reduce.return terminator.
LogicalResult ReduceOp::verifyRegions() {
  // The region of a ReduceOp has two arguments of the same type as its
  // corresponding operand.
  for (int64_t i = 0, e = getReductions().size(); i < e; ++i) {
    auto type = getOperands()[i].getType();
    Block &block = getReductions()[i].front();
    if (block.empty())
      return emitOpError() << i << "-th reduction has an empty body" ;
    if (block.getNumArguments() != 2 ||
        llvm::any_of(Range: block.getArguments(), P: [&](const BlockArgument &arg) {
          return arg.getType() != type;
        }))
      return emitOpError() << "expected two block arguments with type " << type
                           << " in the " << i << "-th reduction region" ;

    // Check that the block is terminated by a ReduceReturnOp.
    if (!isa<ReduceReturnOp>(Val: block.getTerminator()))
      return emitOpError(message: "reduction bodies must be terminated with an "
                         "'scf.reduce.return' op" );
  }

  return success();
}
| 3297 | |
/// RegionBranchTerminatorOpInterface: scf.reduce forwards nothing to its
/// successors; its operands feed the reduction blocks instead.
MutableOperandRange
ReduceOp::getMutableSuccessorOperands(RegionBranchPoint point) {
  // No operands are forwarded to the next iteration.
  return MutableOperandRange(getOperation(), /*start=*/0, /*length=*/0);
}
| 3303 | |
| 3304 | //===----------------------------------------------------------------------===// |
| 3305 | // ReduceReturnOp |
| 3306 | //===----------------------------------------------------------------------===// |
| 3307 | |
/// Verifies that the returned value's type matches the type of the enclosing
/// reduction block's arguments (i.e., the reduced element type).
LogicalResult ReduceReturnOp::verify() {
  // The type of the return value should be the same type as the types of the
  // block arguments of the reduction body.
  Block *reductionBody = getOperation()->getBlock();
  // Should already be verified by an op trait.
  assert(isa<ReduceOp>(reductionBody->getParentOp()) && "expected scf.reduce" );
  Type expectedResultType = reductionBody->getArgument(i: 0).getType();
  if (expectedResultType != getResult().getType())
    return emitOpError() << "must have type " << expectedResultType
                         << " (the type of the reduction inputs)" ;
  return success();
}
| 3320 | |
| 3321 | //===----------------------------------------------------------------------===// |
| 3322 | // WhileOp |
| 3323 | //===----------------------------------------------------------------------===// |
| 3324 | |
/// Builds an scf.while with the given result types and init operands. The
/// "before" block takes the init operand types as arguments; the "after"
/// block takes the result types. The optional callbacks populate each region.
void WhileOp::build(::mlir::OpBuilder &odsBuilder,
                    ::mlir::OperationState &odsState, TypeRange resultTypes,
                    ValueRange inits, BodyBuilderFn beforeBuilder,
                    BodyBuilderFn afterBuilder) {
  odsState.addOperands(newOperands: inits);
  odsState.addTypes(newTypes&: resultTypes);

  OpBuilder::InsertionGuard guard(odsBuilder);

  // Build before region. Use each init operand's location for its block arg.
  SmallVector<Location, 4> beforeArgLocs;
  beforeArgLocs.reserve(N: inits.size());
  for (Value operand : inits) {
    beforeArgLocs.push_back(Elt: operand.getLoc());
  }

  Region *beforeRegion = odsState.addRegion();
  Block *beforeBlock = odsBuilder.createBlock(parent: beforeRegion, /*insertPt=*/{},
                                              argTypes: inits.getTypes(), locs: beforeArgLocs);
  if (beforeBuilder)
    beforeBuilder(odsBuilder, odsState.location, beforeBlock->getArguments());

  // Build after region. No per-value locations are available here, so the
  // op's own location is used for all block arguments.
  SmallVector<Location, 4> afterArgLocs(resultTypes.size(), odsState.location);

  Region *afterRegion = odsState.addRegion();
  Block *afterBlock = odsBuilder.createBlock(parent: afterRegion, /*insertPt=*/{},
                                             argTypes: resultTypes, locs: afterArgLocs);

  if (afterBuilder)
    afterBuilder(odsBuilder, odsState.location, afterBlock->getArguments());
}
| 3357 | |
/// Returns the scf.condition terminating the "before" region.
ConditionOp WhileOp::getConditionOp() {
  return cast<ConditionOp>(Val: getBeforeBody()->getTerminator());
}
| 3361 | |
/// Returns the scf.yield terminating the "after" region.
YieldOp WhileOp::getYieldOp() {
  return cast<YieldOp>(Val: getAfterBody()->getTerminator());
}
| 3365 | |
/// LoopLikeOpInterface: the values yielded back to the "before" region are
/// the operands of the after-region's scf.yield.
std::optional<MutableArrayRef<OpOperand>> WhileOp::getYieldedValuesMutable() {
  return getYieldOp().getResultsMutable();
}
| 3369 | |
/// Returns the block arguments of the "before" region.
Block::BlockArgListType WhileOp::getBeforeArguments() {
  return getBeforeBody()->getArguments();
}
| 3373 | |
/// Returns the block arguments of the "after" region.
Block::BlockArgListType WhileOp::getAfterArguments() {
  return getAfterBody()->getArguments();
}
| 3377 | |
/// LoopLikeOpInterface: the iteration-carried values are the "before"
/// region's block arguments (they receive the inits and then each yield).
Block::BlockArgListType WhileOp::getRegionIterArgs() {
  return getBeforeArguments();
}
| 3381 | |
/// RegionBranchOpInterface: on entry the op forwards its init operands into
/// the "before" region, which is the only region reachable from the parent.
OperandRange WhileOp::getEntrySuccessorOperands(RegionBranchPoint point) {
  assert(point == getBefore() &&
         "WhileOp is expected to branch only to the first region" );
  return getInits();
}
| 3387 | |
/// RegionBranchOpInterface: parent -> before; after -> before; and from
/// before, control either exits the op (condition false) or enters after.
void WhileOp::getSuccessorRegions(RegionBranchPoint point,
                                  SmallVectorImpl<RegionSuccessor> &regions) {
  // The parent op always branches to the condition region.
  if (point.isParent()) {
    regions.emplace_back(Args: &getBefore(), Args: getBefore().getArguments());
    return;
  }

  assert(llvm::is_contained({&getAfter(), &getBefore()}, point) &&
         "there are only two regions in a WhileOp" );
  // The body region always branches back to the condition region.
  if (point == getAfter()) {
    regions.emplace_back(Args: &getBefore(), Args: getBefore().getArguments());
    return;
  }

  // From the "before" region, control either leaves the op (producing its
  // results) or proceeds into the "after" region.
  regions.emplace_back(Args: getResults());
  regions.emplace_back(Args: &getAfter(), Args: getAfter().getArguments());
}
| 3407 | |
/// LoopLikeOpInterface: both regions participate in the loop.
SmallVector<Region *> WhileOp::getLoopRegions() {
  return {&getBefore(), &getAfter()};
}
| 3411 | |
/// Parses a `while` op.
///
/// op ::= `scf.while` assignments `:` function-type region `do` region
///         `attributes` attribute-dict
/// initializer ::= /* empty */ | `(` assignment-list `)`
/// assignment-list ::= assignment | assignment `,` assignment-list
/// assignment ::= ssa-value `=` ssa-value
ParseResult scf::WhileOp::parse(OpAsmParser &parser, OperationState &result) {
  SmallVector<OpAsmParser::Argument, 4> regionArgs;
  SmallVector<OpAsmParser::UnresolvedOperand, 4> operands;
  // An scf.while always has exactly two regions: 'before' (condition) and
  // 'after' (body).
  Region *before = result.addRegion();
  Region *after = result.addRegion();

  // The `(%arg = %init, ...)` assignment list is optional; a while op with no
  // iteration arguments omits it entirely.
  OptionalParseResult listResult =
      parser.parseOptionalAssignmentList(lhs&: regionArgs, rhs&: operands);
  if (listResult.has_value() && failed(Result: listResult.value()))
    return failure();

  // The function type supplies both the init operand types (inputs) and the
  // op result types (results).
  FunctionType functionType;
  SMLoc typeLoc = parser.getCurrentLocation();
  if (failed(Result: parser.parseColonType(result&: functionType)))
    return failure();

  result.addTypes(newTypes: functionType.getResults());

  if (functionType.getNumInputs() != operands.size()) {
    return parser.emitError(loc: typeLoc)
           << "expected as many input types as operands "
           << "(expected " << operands.size() << " got "
           << functionType.getNumInputs() << ")" ;
  }

  // Resolve input operands.
  if (failed(Result: parser.resolveOperands(operands, types: functionType.getInputs(),
                                      loc: parser.getCurrentLocation(),
                                      result&: result.operands)))
    return failure();

  // Propagate the types into the region arguments.
  for (size_t i = 0, e = regionArgs.size(); i != e; ++i)
    regionArgs[i].type = functionType.getInput(i);

  // Parse the 'before' region (its entry block arguments are the iter args),
  // the `do` keyword, the 'after' region, and an optional trailing attribute
  // dictionary.
  return failure(IsFailure: parser.parseRegion(region&: *before, arguments: regionArgs) ||
                 parser.parseKeyword(keyword: "do" ) || parser.parseRegion(region&: *after) ||
                 parser.parseOptionalAttrDictWithKeyword(result&: result.attributes));
}
| 3458 | |
/// Prints a `while` op.
void scf::WhileOp::print(OpAsmPrinter &p) {
  // Print the `(%arg = %init, ...)` assignment list, if any.
  printInitializationList(p, blocksArgs: getBeforeArguments(), initializers: getInits(), prefix: " " );
  // Print the functional type mapping init types to result types.
  p << " : " ;
  p.printFunctionalType(inputs: getInits().getTypes(), results: getResults().getTypes());
  p << ' ';
  // The 'before' entry block arguments were already printed as the
  // assignment list above, so elide them here.
  p.printRegion(blocks&: getBefore(), /*printEntryBlockArgs=*/false);
  p << " do " ;
  p.printRegion(blocks&: getAfter());
  p.printOptionalAttrDictWithKeyword(attrs: (*this)->getAttrs());
}
| 3470 | |
| 3471 | /// Verifies that two ranges of types match, i.e. have the same number of |
| 3472 | /// entries and that types are pairwise equals. Reports errors on the given |
| 3473 | /// operation in case of mismatch. |
| 3474 | template <typename OpTy> |
| 3475 | static LogicalResult verifyTypeRangesMatch(OpTy op, TypeRange left, |
| 3476 | TypeRange right, StringRef message) { |
| 3477 | if (left.size() != right.size()) |
| 3478 | return op.emitOpError("expects the same number of " ) << message; |
| 3479 | |
| 3480 | for (unsigned i = 0, e = left.size(); i < e; ++i) { |
| 3481 | if (left[i] != right[i]) { |
| 3482 | InFlightDiagnostic diag = op.emitOpError("expects the same types for " ) |
| 3483 | << message; |
| 3484 | diag.attachNote() << "for argument " << i << ", found " << left[i] |
| 3485 | << " and " << right[i]; |
| 3486 | return diag; |
| 3487 | } |
| 3488 | } |
| 3489 | |
| 3490 | return success(); |
| 3491 | } |
| 3492 | |
| 3493 | LogicalResult scf::WhileOp::verify() { |
| 3494 | auto beforeTerminator = verifyAndGetTerminator<scf::ConditionOp>( |
| 3495 | op: *this, region&: getBefore(), |
| 3496 | errorMessage: "expects the 'before' region to terminate with 'scf.condition'" ); |
| 3497 | if (!beforeTerminator) |
| 3498 | return failure(); |
| 3499 | |
| 3500 | auto afterTerminator = verifyAndGetTerminator<scf::YieldOp>( |
| 3501 | op: *this, region&: getAfter(), |
| 3502 | errorMessage: "expects the 'after' region to terminate with 'scf.yield'" ); |
| 3503 | return success(IsSuccess: afterTerminator != nullptr); |
| 3504 | } |
| 3505 | |
| 3506 | namespace { |
/// Replace uses of the condition within the do block with true, since otherwise
/// the block would not be evaluated.
///
/// scf.while (..) : (i1, ...) -> ... {
///   %condition = call @evaluate_condition() : () -> i1
///   scf.condition(%condition) %condition : i1, ...
/// } do {
/// ^bb0(%arg0: i1, ...):
///   use(%arg0)
///   ...
///
/// becomes
/// scf.while (..) : (i1, ...) -> ... {
///   %condition = call @evaluate_condition() : () -> i1
///   scf.condition(%condition) %condition : i1, ...
/// } do {
/// ^bb0(%arg0: i1, ...):
///   use(%true)
///   ...
struct WhileConditionTruth : public OpRewritePattern<WhileOp> {
  using OpRewritePattern<WhileOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(WhileOp op,
                                PatternRewriter &rewriter) const override {
    auto term = op.getConditionOp();

    // Lazily-created constant "true", shared across all replacements so we
    // do not create duplicate constants.
    Value constantTrue = nullptr;

    bool replaced = false;
    // Walk the (condition operand, after-block argument) pairs; wherever the
    // forwarded value *is* the branch condition itself, the corresponding
    // after-block argument is known to be true inside the body.
    for (auto yieldedAndBlockArgs :
         llvm::zip(t: term.getArgs(), u: op.getAfterArguments())) {
      if (std::get<0>(t&: yieldedAndBlockArgs) == term.getCondition()) {
        if (!std::get<1>(t&: yieldedAndBlockArgs).use_empty()) {
          if (!constantTrue)
            constantTrue = rewriter.create<arith::ConstantOp>(
                location: op.getLoc(), args: term.getCondition().getType(),
                args: rewriter.getBoolAttr(value: true));

          rewriter.replaceAllUsesWith(from: std::get<1>(t&: yieldedAndBlockArgs),
                                      to: constantTrue);
          replaced = true;
        }
      }
    }
    // Only report success if at least one use was rewritten.
    return success(IsSuccess: replaced);
  }
};
| 3556 | |
/// Remove loop invariant arguments from `before` block of scf.while.
/// A before block argument is considered loop invariant if :-
///   1. i-th yield operand is equal to the i-th while operand.
///   2. i-th yield operand is k-th after block argument which is (k+1)-th
///      condition operand AND this (k+1)-th condition operand is equal to i-th
///      iter argument/while operand.
/// For the arguments which are removed, their uses inside scf.while
/// are replaced with their corresponding initial value.
///
/// Eg:
///   INPUT :-
///   %res = scf.while <...> iter_args(%arg0_before = %a, %arg1_before = %b,
///                                    ..., %argN_before = %N)
///   {
///     ...
///     scf.condition(%cond) %arg1_before, %arg0_before,
///                          %arg2_before, %arg0_before, ...
///   } do {
///   ^bb0(%arg1_after, %arg0_after_1, %arg2_after, %arg0_after_2,
///        ..., %argK_after):
///     ...
///     scf.yield %arg0_after_2, %b, %arg1_after, ..., %argN
///   }
///
///   OUTPUT :-
///   %res = scf.while <...> iter_args(%arg2_before = %c, ..., %argN_before =
///                                    %N)
///   {
///     ...
///     scf.condition(%cond) %b, %a, %arg2_before, %a, ...
///   } do {
///   ^bb0(%arg1_after, %arg0_after_1, %arg2_after, %arg0_after_2,
///        ..., %argK_after):
///     ...
///     scf.yield %arg1_after, ..., %argN
///   }
///
/// EXPLANATION:
///   We iterate over each yield operand.
///     1. 0-th yield operand %arg0_after_2 is 4-th condition operand
///        %arg0_before, which in turn is the 0-th iter argument. So we
///        remove 0-th before block argument and yield operand, and replace
///        all uses of the 0-th before block argument with its initial value
///        %a.
///     2. 1-th yield operand %b is equal to the 1-th iter arg's initial
///        value. So we remove this operand and the corresponding before
///        block argument and replace all uses of 1-th before block argument
///        with %b.
struct RemoveLoopInvariantArgsFromBeforeBlock
    : public OpRewritePattern<WhileOp> {
  using OpRewritePattern<WhileOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(WhileOp op,
                                PatternRewriter &rewriter) const override {
    Block &afterBlock = *op.getAfterBody();
    Block::BlockArgListType beforeBlockArgs = op.getBeforeArguments();
    ConditionOp condOp = op.getConditionOp();
    OperandRange condOpArgs = condOp.getArgs();
    Operation *yieldOp = afterBlock.getTerminator();
    ValueRange yieldOpArgs = yieldOp->getOperands();

    // First pass: cheap scan that bails out early when nothing is loop
    // invariant (the common case); no IR is modified here.
    bool canSimplify = false;
    for (const auto &it :
         llvm::enumerate(First: llvm::zip(t: op.getOperands(), u&: yieldOpArgs))) {
      auto index = static_cast<unsigned>(it.index());
      auto [initVal, yieldOpArg] = it.value();
      // If i-th yield operand is equal to the i-th operand of the scf.while,
      // the i-th before block argument is a loop invariant.
      if (yieldOpArg == initVal) {
        canSimplify = true;
        break;
      }
      // If the i-th yield operand is k-th after block argument, then we check
      // if the (k+1)-th condition op operand is equal to either the i-th before
      // block argument or the initial value of i-th before block argument. If
      // the comparison results `true`, i-th before block argument is a loop
      // invariant.
      auto yieldOpBlockArg = llvm::dyn_cast<BlockArgument>(Val&: yieldOpArg);
      if (yieldOpBlockArg && yieldOpBlockArg.getOwner() == &afterBlock) {
        Value condOpArg = condOpArgs[yieldOpBlockArg.getArgNumber()];
        if (condOpArg == beforeBlockArgs[index] || condOpArg == initVal) {
          canSimplify = true;
          break;
        }
      }
    }

    if (!canSimplify)
      return failure();

    // Second pass: partition iter args into invariant ones (recorded in
    // `beforeBlockInitValMap`, keyed by original index) and surviving ones
    // (collected into the new init/yield operand lists). The matching logic
    // mirrors the scan above.
    SmallVector<Value> newInitArgs, newYieldOpArgs;
    DenseMap<unsigned, Value> beforeBlockInitValMap;
    SmallVector<Location> newBeforeBlockArgLocs;
    for (const auto &it :
         llvm::enumerate(First: llvm::zip(t: op.getOperands(), u&: yieldOpArgs))) {
      auto index = static_cast<unsigned>(it.index());
      auto [initVal, yieldOpArg] = it.value();

      // If i-th yield operand is equal to the i-th operand of the scf.while,
      // the i-th before block argument is a loop invariant.
      if (yieldOpArg == initVal) {
        beforeBlockInitValMap.insert(KV: {index, initVal});
        continue;
      } else {
        // If the i-th yield operand is k-th after block argument, then we check
        // if the (k+1)-th condition op operand is equal to either the i-th
        // before block argument or the initial value of i-th before block
        // argument. If the comparison results `true`, i-th before block
        // argument is a loop invariant.
        auto yieldOpBlockArg = llvm::dyn_cast<BlockArgument>(Val&: yieldOpArg);
        if (yieldOpBlockArg && yieldOpBlockArg.getOwner() == &afterBlock) {
          Value condOpArg = condOpArgs[yieldOpBlockArg.getArgNumber()];
          if (condOpArg == beforeBlockArgs[index] || condOpArg == initVal) {
            beforeBlockInitValMap.insert(KV: {index, initVal});
            continue;
          }
        }
      }
      newInitArgs.emplace_back(Args&: initVal);
      newYieldOpArgs.emplace_back(Args&: yieldOpArg);
      newBeforeBlockArgLocs.emplace_back(Args: beforeBlockArgs[index].getLoc());
    }

    // Rebuild the yield with only the surviving operands before creating the
    // replacement loop.
    {
      OpBuilder::InsertionGuard g(rewriter);
      rewriter.setInsertionPoint(yieldOp);
      rewriter.replaceOpWithNewOp<YieldOp>(op: yieldOp, args&: newYieldOpArgs);
    }

    // The result types are unchanged: only 'before' block arguments (and the
    // corresponding inits/yields) are removed, not condition operands.
    auto newWhile =
        rewriter.create<WhileOp>(location: op.getLoc(), args: op.getResultTypes(), args&: newInitArgs);

    Block &newBeforeBlock = *rewriter.createBlock(
        parent: &newWhile.getBefore(), /*insertPt*/ {},
        argTypes: ValueRange(newYieldOpArgs).getTypes(), locs: newBeforeBlockArgLocs);

    Block &beforeBlock = *op.getBeforeBody();
    SmallVector<Value> newBeforeBlockArgs(beforeBlock.getNumArguments());
    // For each i-th before block argument we find it's replacement value as :-
    //   1. If i-th before block argument is a loop invariant, we fetch it's
    //      initial value from `beforeBlockInitValMap` by querying for key `i`.
    //   2. Else we fetch j-th new before block argument as the replacement
    //      value of i-th before block argument.
    for (unsigned i = 0, j = 0, n = beforeBlock.getNumArguments(); i < n; i++) {
      // If the index 'i' argument was a loop invariant we fetch it's initial
      // value from `beforeBlockInitValMap`.
      if (beforeBlockInitValMap.count(Val: i) != 0)
        newBeforeBlockArgs[i] = beforeBlockInitValMap[i];
      else
        newBeforeBlockArgs[i] = newBeforeBlock.getArgument(i: j++);
    }

    // Move the old region bodies into the new op and replace the original.
    rewriter.mergeBlocks(source: &beforeBlock, dest: &newBeforeBlock, argValues: newBeforeBlockArgs);
    rewriter.inlineRegionBefore(region&: op.getAfter(), parent&: newWhile.getAfter(),
                                before: newWhile.getAfter().begin());

    rewriter.replaceOp(op, newValues: newWhile.getResults());
    return success();
  }
};
| 3717 | |
/// Remove loop invariant value from result (condition op) of scf.while.
/// A value is considered loop invariant if the final value yielded by
/// scf.condition is defined outside of the `before` block. We remove the
/// corresponding argument in `after` block and replace the use with the value.
/// We also replace the use of the corresponding result of scf.while with the
/// value.
///
/// Eg:
///   INPUT :-
///   %res_input:K = scf.while <...> iter_args(%arg0_before = , ...,
///                                            %argN_before = %N) {
///     ...
///     scf.condition(%cond) %arg0_before, %a, %b, %arg1_before, ...
///   } do {
///   ^bb0(%arg0_after, %arg1_after, %arg2_after, ..., %argK_after):
///     ...
///     some_func(%arg1_after)
///     ...
///     scf.yield %arg0_after, %arg2_after, ..., %argN_after
///   }
///
///   OUTPUT :-
///   %res_output:M = scf.while <...> iter_args(%arg0 = , ..., %argN = %N) {
///     ...
///     scf.condition(%cond) %arg0, %arg1, ..., %argM
///   } do {
///   ^bb0(%arg0, %arg3, ..., %argM):
///     ...
///     some_func(%a)
///     ...
///     scf.yield %arg0, %b, ..., %argN
///   }
///
/// EXPLANATION:
///   1. The 1-th and 2-th operand of scf.condition are defined outside the
///      before block of scf.while, so they get removed.
///   2. %res_input#1's uses are replaced by %a and %res_input#2's uses are
///      replaced by %b.
///   3. The corresponding after block argument %arg1_after's uses are
///      replaced by %a and %arg2_after's uses are replaced by %b.
struct RemoveLoopInvariantValueYielded : public OpRewritePattern<WhileOp> {
  using OpRewritePattern<WhileOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(WhileOp op,
                                PatternRewriter &rewriter) const override {
    Block &beforeBlock = *op.getBeforeBody();
    ConditionOp condOp = op.getConditionOp();
    OperandRange condOpArgs = condOp.getArgs();

    // First pass: bail out early unless at least one condition operand is
    // defined outside the 'before' block.
    bool canSimplify = false;
    for (Value condOpArg : condOpArgs) {
      // Those values not defined within `before` block will be considered as
      // loop invariant values. We map the corresponding `index` with their
      // value.
      if (condOpArg.getParentBlock() != &beforeBlock) {
        canSimplify = true;
        break;
      }
    }

    if (!canSimplify)
      return failure();

    Block::BlockArgListType afterBlockArgs = op.getAfterArguments();

    // Second pass: partition condition operands into invariant ones
    // (recorded in `condOpInitValMap`, keyed by original index) and
    // surviving ones (which keep an 'after' block argument and an op result).
    SmallVector<Value> newCondOpArgs;
    SmallVector<Type> newAfterBlockType;
    DenseMap<unsigned, Value> condOpInitValMap;
    SmallVector<Location> newAfterBlockArgLocs;
    for (const auto &it : llvm::enumerate(First&: condOpArgs)) {
      auto index = static_cast<unsigned>(it.index());
      Value condOpArg = it.value();
      // Those values not defined within `before` block will be considered as
      // loop invariant values. We map the corresponding `index` with their
      // value.
      if (condOpArg.getParentBlock() != &beforeBlock) {
        condOpInitValMap.insert(KV: {index, condOpArg});
      } else {
        newCondOpArgs.emplace_back(Args&: condOpArg);
        newAfterBlockType.emplace_back(Args: condOpArg.getType());
        newAfterBlockArgLocs.emplace_back(Args: afterBlockArgs[index].getLoc());
      }
    }

    // Rebuild scf.condition with only the surviving forwarded values.
    {
      OpBuilder::InsertionGuard g(rewriter);
      rewriter.setInsertionPoint(condOp);
      rewriter.replaceOpWithNewOp<ConditionOp>(op: condOp, args: condOp.getCondition(),
                                               args&: newCondOpArgs);
    }

    // The new loop has fewer results (one per surviving condition operand);
    // the init operands are unchanged.
    auto newWhile = rewriter.create<WhileOp>(location: op.getLoc(), args&: newAfterBlockType,
                                             args: op.getOperands());

    Block &newAfterBlock =
        *rewriter.createBlock(parent: &newWhile.getAfter(), /*insertPt*/ {},
                              argTypes: newAfterBlockType, locs: newAfterBlockArgLocs);

    Block &afterBlock = *op.getAfterBody();
    // Since a new scf.condition op was created, we need to fetch the new
    // `after` block arguments which will be used while replacing operations of
    // previous scf.while's `after` blocks. We'd also be fetching new result
    // values too.
    SmallVector<Value> newAfterBlockArgs(afterBlock.getNumArguments());
    SmallVector<Value> newWhileResults(afterBlock.getNumArguments());
    for (unsigned i = 0, j = 0, n = afterBlock.getNumArguments(); i < n; i++) {
      Value afterBlockArg, result;
      // If index 'i' argument was loop invariant we fetch it's value from the
      // `condOpInitMap` map.
      if (condOpInitValMap.count(Val: i) != 0) {
        afterBlockArg = condOpInitValMap[i];
        result = afterBlockArg;
      } else {
        afterBlockArg = newAfterBlock.getArgument(i: j);
        result = newWhile.getResult(i: j);
        j++;
      }
      newAfterBlockArgs[i] = afterBlockArg;
      newWhileResults[i] = result;
    }

    // Move the old region bodies into the new op and replace the original.
    rewriter.mergeBlocks(source: &afterBlock, dest: &newAfterBlock, argValues: newAfterBlockArgs);
    rewriter.inlineRegionBefore(region&: op.getBefore(), parent&: newWhile.getBefore(),
                                before: newWhile.getBefore().begin());

    rewriter.replaceOp(op, newValues: newWhileResults);
    return success();
  }
};
| 3847 | |
/// Remove WhileOp results that are also unused in 'after' block.
///
///  %0:2 = scf.while () : () -> (i32, i64) {
///    %condition = "test.condition"() : () -> i1
///    %v1 = "test.get_some_value"() : () -> i32
///    %v2 = "test.get_some_value"() : () -> i64
///    scf.condition(%condition) %v1, %v2 : i32, i64
///  } do {
///  ^bb0(%arg0: i32, %arg1: i64):
///    "test.use"(%arg0) : (i32) -> ()
///    scf.yield
///  }
///  return %0#0 : i32
///
/// becomes
///  %0 = scf.while () : () -> (i32) {
///    %condition = "test.condition"() : () -> i1
///    %v1 = "test.get_some_value"() : () -> i32
///    %v2 = "test.get_some_value"() : () -> i64
///    scf.condition(%condition) %v1 : i32
///  } do {
///  ^bb0(%arg0: i32):
///    "test.use"(%arg0) : (i32) -> ()
///    scf.yield
///  }
///  return %0 : i32
struct WhileUnusedResult : public OpRewritePattern<WhileOp> {
  using OpRewritePattern<WhileOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(WhileOp op,
                                PatternRewriter &rewriter) const override {
    auto term = op.getConditionOp();
    auto afterArgs = op.getAfterArguments();
    auto termArgs = term.getArgs();

    // Collect results mapping, new terminator args and new result types.
    // A forwarded value is removable only if BOTH its op result and its
    // 'after' block argument are unused.
    SmallVector<unsigned> newResultsIndices;
    SmallVector<Type> newResultTypes;
    SmallVector<Value> newTermArgs;
    SmallVector<Location> newArgLocs;
    bool needUpdate = false;
    for (const auto &it :
         llvm::enumerate(First: llvm::zip(t: op.getResults(), u&: afterArgs, args&: termArgs))) {
      auto i = static_cast<unsigned>(it.index());
      Value result = std::get<0>(t&: it.value());
      Value afterArg = std::get<1>(t&: it.value());
      Value termArg = std::get<2>(t&: it.value());
      if (result.use_empty() && afterArg.use_empty()) {
        needUpdate = true;
      } else {
        newResultsIndices.emplace_back(Args&: i);
        newTermArgs.emplace_back(Args&: termArg);
        newResultTypes.emplace_back(Args: result.getType());
        newArgLocs.emplace_back(Args: result.getLoc());
      }
    }

    if (!needUpdate)
      return failure();

    // Shrink scf.condition to the surviving forwarded values first.
    {
      OpBuilder::InsertionGuard g(rewriter);
      rewriter.setInsertionPoint(term);
      rewriter.replaceOpWithNewOp<ConditionOp>(op: term, args: term.getCondition(),
                                               args&: newTermArgs);
    }

    // Inits are untouched; only the result list (and matching 'after' block
    // arguments) shrinks.
    auto newWhile =
        rewriter.create<WhileOp>(location: op.getLoc(), args&: newResultTypes, args: op.getInits());

    Block &newAfterBlock = *rewriter.createBlock(
        parent: &newWhile.getAfter(), /*insertPt*/ {}, argTypes: newResultTypes, locs: newArgLocs);

    // Build new results list and new after block args (unused entries will be
    // null).
    SmallVector<Value> newResults(op.getNumResults());
    SmallVector<Value> newAfterBlockArgs(op.getNumResults());
    for (const auto &it : llvm::enumerate(First&: newResultsIndices)) {
      newResults[it.value()] = newWhile.getResult(i: it.index());
      newAfterBlockArgs[it.value()] = newAfterBlock.getArgument(i: it.index());
    }

    rewriter.inlineRegionBefore(region&: op.getBefore(), parent&: newWhile.getBefore(),
                                before: newWhile.getBefore().begin());

    Block &afterBlock = *op.getAfterBody();
    rewriter.mergeBlocks(source: &afterBlock, dest: &newAfterBlock, argValues: newAfterBlockArgs);

    rewriter.replaceOp(op, newValues: newResults);
    return success();
  }
};
| 3940 | |
/// Replace operations equivalent to the condition in the do block with true,
/// since otherwise the block would not be evaluated.
///
/// scf.while (..) : (i32, ...) -> ... {
///  %z = ... : i32
///  %condition = cmpi pred %z, %a
///  scf.condition(%condition) %z : i32, ...
/// } do {
/// ^bb0(%arg0: i32, ...):
///   %condition2 = cmpi pred %arg0, %a
///   use(%condition2)
///   ...
///
/// becomes
/// scf.while (..) : (i32, ...) -> ... {
///  %z = ... : i32
///  %condition = cmpi pred %z, %a
///  scf.condition(%condition) %z : i32, ...
/// } do {
/// ^bb0(%arg0: i32, ...):
///   use(%true)
///   ...
struct WhileCmpCond : public OpRewritePattern<scf::WhileOp> {
  using OpRewritePattern<scf::WhileOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(scf::WhileOp op,
                                PatternRewriter &rewriter) const override {
    using namespace scf;
    auto cond = op.getConditionOp();
    // Only handles conditions produced directly by an integer compare.
    auto cmp = cond.getCondition().getDefiningOp<arith::CmpIOp>();
    if (!cmp)
      return failure();
    bool changed = false;
    // For each value forwarded into the body, check whether it is one of the
    // compare's operands; if so, any compare in the body of that forwarded
    // value against the *same* other operand is statically known.
    for (auto tup : llvm::zip(t: cond.getArgs(), u: op.getAfterArguments())) {
      for (size_t opIdx = 0; opIdx < 2; opIdx++) {
        if (std::get<0>(t&: tup) != cmp.getOperand(i: opIdx))
          continue;
        // Early-inc range: replaceOpWithNewOp below mutates the use list.
        for (OpOperand &u :
             llvm::make_early_inc_range(Range: std::get<1>(t&: tup).getUses())) {
          auto cmp2 = dyn_cast<arith::CmpIOp>(Val: u.getOwner());
          if (!cmp2)
            continue;
          // For a binary operator 1-opIdx gets the other side.
          if (cmp2.getOperand(i: 1 - opIdx) != cmp.getOperand(i: 1 - opIdx))
            continue;
          // Same predicate folds to true (the condition held on entry);
          // the inverted predicate folds to false; anything else is unknown.
          bool samePredicate;
          if (cmp2.getPredicate() == cmp.getPredicate())
            samePredicate = true;
          else if (cmp2.getPredicate() ==
                   arith::invertPredicate(pred: cmp.getPredicate()))
            samePredicate = false;
          else
            continue;

          rewriter.replaceOpWithNewOp<arith::ConstantIntOp>(op: cmp2, args&: samePredicate,
                                                            args: 1);
          changed = true;
        }
      }
    }
    return success(IsSuccess: changed);
  }
};
| 4004 | |
/// Remove unused init/yield args.
struct WhileRemoveUnusedArgs : public OpRewritePattern<WhileOp> {
  using OpRewritePattern<WhileOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(WhileOp op,
                                PatternRewriter &rewriter) const override {

    // A 'before' block argument with no uses can be dropped together with its
    // init operand and yield operand; bail out if there is none.
    if (!llvm::any_of(Range: op.getBeforeArguments(),
                      P: [](Value arg) { return arg.use_empty(); }))
      return rewriter.notifyMatchFailure(arg&: op, msg: "No args to remove" );

    YieldOp yield = op.getYieldOp();

    // Collect results mapping, new terminator args and new result types.
    SmallVector<Value> newYields;
    SmallVector<Value> newInits;
    llvm::BitVector argsToErase;

    size_t argsCount = op.getBeforeArguments().size();
    newYields.reserve(N: argsCount);
    newInits.reserve(N: argsCount);
    argsToErase.reserve(N: argsCount);
    for (auto &&[beforeArg, yieldValue, initValue] : llvm::zip(
             t: op.getBeforeArguments(), u: yield.getOperands(), args: op.getInits())) {
      if (beforeArg.use_empty()) {
        argsToErase.push_back(Val: true);
      } else {
        argsToErase.push_back(Val: false);
        newYields.emplace_back(Args&: yieldValue);
        newInits.emplace_back(Args&: initValue);
      }
    }

    Block &beforeBlock = *op.getBeforeBody();
    Block &afterBlock = *op.getAfterBody();

    // Erase the dead arguments in place (all uses are empty), then move the
    // whole block into the new op below.
    beforeBlock.eraseArguments(eraseIndices: argsToErase);

    // Result types are unchanged; only the init/yield lists shrink.
    Location loc = op.getLoc();
    auto newWhileOp =
        rewriter.create<WhileOp>(location: loc, args: op.getResultTypes(), args&: newInits,
                                 /*beforeBody*/ args: nullptr, /*afterBody*/ args: nullptr);
    Block &newBeforeBlock = *newWhileOp.getBeforeBody();
    Block &newAfterBlock = *newWhileOp.getAfterBody();

    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPoint(yield);
    rewriter.replaceOpWithNewOp<YieldOp>(op: yield, args&: newYields);

    // The argument counts now line up one-to-one, so the merges map old block
    // arguments directly onto the new blocks' arguments.
    rewriter.mergeBlocks(source: &beforeBlock, dest: &newBeforeBlock,
                         argValues: newBeforeBlock.getArguments());
    rewriter.mergeBlocks(source: &afterBlock, dest: &newAfterBlock,
                         argValues: newAfterBlock.getArguments());

    rewriter.replaceOp(op, newValues: newWhileOp.getResults());
    return success();
  }
};
| 4063 | |
/// Remove duplicated ConditionOp args.
struct WhileRemoveDuplicatedResults : public OpRewritePattern<WhileOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(WhileOp op,
                                PatternRewriter &rewriter) const override {
    ConditionOp condOp = op.getConditionOp();
    ValueRange condOpArgs = condOp.getArgs();

    // Quick set-based check: no duplicates means nothing to do.
    llvm::SmallPtrSet<Value, 8> argsSet(llvm::from_range, condOpArgs);

    if (argsSet.size() == condOpArgs.size())
      return rewriter.notifyMatchFailure(arg&: op, msg: "No results to remove" );

    // Deduplicate while preserving first-occurrence order; `argsMap` records
    // the position each unique value takes in the new argument list.
    llvm::SmallDenseMap<Value, unsigned> argsMap;
    SmallVector<Value> newArgs;
    argsMap.reserve(NumEntries: condOpArgs.size());
    newArgs.reserve(N: condOpArgs.size());
    for (Value arg : condOpArgs) {
      if (!argsMap.count(Val: arg)) {
        auto pos = static_cast<unsigned>(argsMap.size());
        argsMap.insert(KV: {arg, pos});
        newArgs.emplace_back(Args&: arg);
      }
    }

    ValueRange argsRange(newArgs);

    // New loop with one result per unique forwarded value.
    Location loc = op.getLoc();
    auto newWhileOp = rewriter.create<scf::WhileOp>(
        location: loc, args: argsRange.getTypes(), args: op.getInits(), /*beforeBody*/ args: nullptr,
        /*afterBody*/ args: nullptr);
    Block &newBeforeBlock = *newWhileOp.getBeforeBody();
    Block &newAfterBlock = *newWhileOp.getAfterBody();

    // Map every original position (including duplicates) to the deduplicated
    // 'after' argument and op result.
    SmallVector<Value> afterArgsMapping;
    SmallVector<Value> resultsMapping;
    for (auto &&[i, arg] : llvm::enumerate(First&: condOpArgs)) {
      auto it = argsMap.find(Val: arg);
      assert(it != argsMap.end());
      auto pos = it->second;
      afterArgsMapping.emplace_back(Args: newAfterBlock.getArgument(i: pos));
      resultsMapping.emplace_back(Args: newWhileOp->getResult(idx: pos));
    }

    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPoint(condOp);
    rewriter.replaceOpWithNewOp<ConditionOp>(op: condOp, args: condOp.getCondition(),
                                             args&: argsRange);

    Block &beforeBlock = *op.getBeforeBody();
    Block &afterBlock = *op.getAfterBody();

    rewriter.mergeBlocks(source: &beforeBlock, dest: &newBeforeBlock,
                         argValues: newBeforeBlock.getArguments());
    rewriter.mergeBlocks(source: &afterBlock, dest: &newAfterBlock, argValues: afterArgsMapping);
    rewriter.replaceOp(op, newValues: resultsMapping);
    return success();
  }
};
| 4124 | |
| 4125 | /// If both ranges contain same values return mappping indices from args2 to |
| 4126 | /// args1. Otherwise return std::nullopt. |
| 4127 | static std::optional<SmallVector<unsigned>> getArgsMapping(ValueRange args1, |
| 4128 | ValueRange args2) { |
| 4129 | if (args1.size() != args2.size()) |
| 4130 | return std::nullopt; |
| 4131 | |
| 4132 | SmallVector<unsigned> ret(args1.size()); |
| 4133 | for (auto &&[i, arg1] : llvm::enumerate(First&: args1)) { |
| 4134 | auto it = llvm::find(Range&: args2, Val: arg1); |
| 4135 | if (it == args2.end()) |
| 4136 | return std::nullopt; |
| 4137 | |
| 4138 | ret[std::distance(first: args2.begin(), last: it)] = static_cast<unsigned>(i); |
| 4139 | } |
| 4140 | |
| 4141 | return ret; |
| 4142 | } |
| 4143 | |
| 4144 | static bool hasDuplicates(ValueRange args) { |
| 4145 | llvm::SmallDenseSet<Value> set; |
| 4146 | for (Value arg : args) { |
| 4147 | if (!set.insert(V: arg).second) |
| 4148 | return true; |
| 4149 | } |
| 4150 | return false; |
| 4151 | } |
| 4152 | |
| 4153 | /// If `before` block args are directly forwarded to `scf.condition`, rearrange |
| 4154 | /// `scf.condition` args into same order as block args. Update `after` block |
| 4155 | /// args and op result values accordingly. |
| 4156 | /// Needed to simplify `scf.while` -> `scf.for` uplifting. |
struct WhileOpAlignBeforeArgs : public OpRewritePattern<WhileOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(WhileOp loop,
                                PatternRewriter &rewriter) const override {
    auto oldBefore = loop.getBeforeBody();
    ConditionOp oldTerm = loop.getConditionOp();
    ValueRange beforeArgs = oldBefore->getArguments();
    ValueRange termArgs = oldTerm.getArgs();
    // Already aligned: the condition forwards the block args in order.
    if (beforeArgs == termArgs)
      return failure();

    // A duplicated condition arg would make the permutation ambiguous.
    if (hasDuplicates(args: termArgs))
      return failure();

    // mapping[i] == j means the i-th condition arg is the j-th before-block
    // arg; bail out unless the condition args are exactly a permutation of
    // the block args.
    auto mapping = getArgsMapping(args1: beforeArgs, args2: termArgs);
    if (!mapping)
      return failure();

    {
      // Rebuild the terminator so it forwards the block args in their
      // natural order.
      OpBuilder::InsertionGuard g(rewriter);
      rewriter.setInsertionPoint(oldTerm);
      rewriter.replaceOpWithNewOp<ConditionOp>(op: oldTerm, args: oldTerm.getCondition(),
                                               args&: beforeArgs);
    }

    auto oldAfter = loop.getAfterBody();

    // Old result i becomes result mapping[i] of the new loop; permute the
    // result types accordingly.
    SmallVector<Type> newResultTypes(beforeArgs.size());
    for (auto &&[i, j] : llvm::enumerate(First&: *mapping))
      newResultTypes[j] = loop.getResult(i).getType();

    auto newLoop = rewriter.create<WhileOp>(
        location: loop.getLoc(), args&: newResultTypes, args: loop.getInits(),
        /*beforeBuilder=*/args: nullptr, /*afterBuilder=*/args: nullptr);
    auto newBefore = newLoop.getBeforeBody();
    auto newAfter = newLoop.getAfterBody();

    // Map each old result / old after-block arg to its permuted counterpart
    // on the new loop.
    SmallVector<Value> newResults(beforeArgs.size());
    SmallVector<Value> newAfterArgs(beforeArgs.size());
    for (auto &&[i, j] : llvm::enumerate(First&: *mapping)) {
      newResults[i] = newLoop.getResult(i: j);
      newAfterArgs[i] = newAfter->getArgument(i: j);
    }

    // Move both bodies into the new loop, rewiring block arguments.
    rewriter.inlineBlockBefore(source: oldBefore, dest: newBefore, before: newBefore->begin(),
                               argValues: newBefore->getArguments());
    rewriter.inlineBlockBefore(source: oldAfter, dest: newAfter, before: newAfter->begin(),
                               argValues: newAfterArgs);

    rewriter.replaceOp(op: loop, newValues: newResults);
    return success();
  }
};
| 4211 | } // namespace |
| 4212 | |
// Registers all canonicalization patterns for `scf.while` declared above.
void WhileOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                          MLIRContext *context) {
  results.add<RemoveLoopInvariantArgsFromBeforeBlock,
              RemoveLoopInvariantValueYielded, WhileConditionTruth,
              WhileCmpCond, WhileUnusedResult, WhileRemoveDuplicatedResults,
              WhileRemoveUnusedArgs, WhileOpAlignBeforeArgs>(arg&: context);
}
| 4220 | |
| 4221 | //===----------------------------------------------------------------------===// |
| 4222 | // IndexSwitchOp |
| 4223 | //===----------------------------------------------------------------------===// |
| 4224 | |
| 4225 | /// Parse the case regions and values. |
| 4226 | static ParseResult |
| 4227 | parseSwitchCases(OpAsmParser &p, DenseI64ArrayAttr &cases, |
| 4228 | SmallVectorImpl<std::unique_ptr<Region>> &caseRegions) { |
| 4229 | SmallVector<int64_t> caseValues; |
| 4230 | while (succeeded(Result: p.parseOptionalKeyword(keyword: "case" ))) { |
| 4231 | int64_t value; |
| 4232 | Region ®ion = *caseRegions.emplace_back(Args: std::make_unique<Region>()); |
| 4233 | if (p.parseInteger(result&: value) || p.parseRegion(region, /*arguments=*/{})) |
| 4234 | return failure(); |
| 4235 | caseValues.push_back(Elt: value); |
| 4236 | } |
| 4237 | cases = p.getBuilder().getDenseI64ArrayAttr(values: caseValues); |
| 4238 | return success(); |
| 4239 | } |
| 4240 | |
| 4241 | /// Print the case regions and values. |
| 4242 | static void printSwitchCases(OpAsmPrinter &p, Operation *op, |
| 4243 | DenseI64ArrayAttr cases, RegionRange caseRegions) { |
| 4244 | for (auto [value, region] : llvm::zip(t: cases.asArrayRef(), u&: caseRegions)) { |
| 4245 | p.printNewline(); |
| 4246 | p << "case " << value << ' '; |
| 4247 | p.printRegion(blocks&: *region, /*printEntryBlockArgs=*/false); |
| 4248 | } |
| 4249 | } |
| 4250 | |
// Verifies structural invariants of `scf.index_switch`: one region per case
// value, distinct case values, and every region yielding values matching the
// op's result count and types.
LogicalResult scf::IndexSwitchOp::verify() {
  // Each case value must have exactly one corresponding region.
  if (getCases().size() != getCaseRegions().size()) {
    return emitOpError(message: "has " )
           << getCaseRegions().size() << " case regions but "
           << getCases().size() << " case values" ;
  }

  // Case values must be pairwise distinct.
  DenseSet<int64_t> valueSet;
  for (int64_t value : getCases())
    if (!valueSet.insert(V: value).second)
      return emitOpError(message: "has duplicate case value: " ) << value;
  // Checks that `region` terminates in an `scf.yield` whose operands match
  // this op's results; `name` is used only for diagnostics.
  auto verifyRegion = [&](Region &region, const Twine &name) -> LogicalResult {
    auto yield = dyn_cast<YieldOp>(Val&: region.front().back());
    if (!yield)
      return emitOpError(message: "expected region to end with scf.yield, but got " )
             << region.front().back().getName();

    if (yield.getNumOperands() != getNumResults()) {
      return (emitOpError(message: "expected each region to return " )
              << getNumResults() << " values, but " << name << " returns "
              << yield.getNumOperands())
                 .attachNote(noteLoc: yield.getLoc())
             << "see yield operation here" ;
    }
    // Yielded operand types must match the op's result types positionally.
    for (auto [idx, result, operand] :
         llvm::enumerate(First: getResultTypes(), Rest: yield.getOperands())) {
      if (!operand)
        return yield.emitOpError() << "operand " << idx << " is null\n" ;
      if (result == operand.getType())
        continue;
      return (emitOpError(message: "expected result #" )
              << idx << " of each region to be " << result)
                 .attachNote(noteLoc: yield.getLoc())
             << name << " returns " << operand.getType() << " here" ;
    }
    return success();
  };

  // Check the default region, then every case region.
  if (failed(Result: verifyRegion(getDefaultRegion(), "default region" )))
    return failure();
  for (auto [idx, caseRegion] : llvm::enumerate(First: getCaseRegions()))
    if (failed(Result: verifyRegion(caseRegion, "case region #" + Twine(idx))))
      return failure();

  return success();
}
| 4297 | |
// Returns the number of case values (equal to the number of case regions).
unsigned scf::IndexSwitchOp::getNumCases() { return getCases().size(); }
| 4299 | |
// Returns the entry block of the default region.
Block &scf::IndexSwitchOp::getDefaultBlock() {
  return getDefaultRegion().front();
}
| 4303 | |
// Returns the entry block of the `idx`-th case region.
Block &scf::IndexSwitchOp::getCaseBlock(unsigned idx) {
  assert(idx < getNumCases() && "case index out-of-bounds" );
  return getCaseRegions()[idx].front();
}
| 4308 | |
| 4309 | void IndexSwitchOp::getSuccessorRegions( |
| 4310 | RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &successors) { |
| 4311 | // All regions branch back to the parent op. |
| 4312 | if (!point.isParent()) { |
| 4313 | successors.emplace_back(Args: getResults()); |
| 4314 | return; |
| 4315 | } |
| 4316 | |
| 4317 | llvm::append_range(C&: successors, R: getRegions()); |
| 4318 | } |
| 4319 | |
| 4320 | void IndexSwitchOp::getEntrySuccessorRegions( |
| 4321 | ArrayRef<Attribute> operands, |
| 4322 | SmallVectorImpl<RegionSuccessor> &successors) { |
| 4323 | FoldAdaptor adaptor(operands, *this); |
| 4324 | |
| 4325 | // If a constant was not provided, all regions are possible successors. |
| 4326 | auto arg = dyn_cast_or_null<IntegerAttr>(Val: adaptor.getArg()); |
| 4327 | if (!arg) { |
| 4328 | llvm::append_range(C&: successors, R: getRegions()); |
| 4329 | return; |
| 4330 | } |
| 4331 | |
| 4332 | // Otherwise, try to find a case with a matching value. If not, the |
| 4333 | // default region is the only successor. |
| 4334 | for (auto [caseValue, caseRegion] : llvm::zip(t: getCases(), u: getCaseRegions())) { |
| 4335 | if (caseValue == arg.getInt()) { |
| 4336 | successors.emplace_back(Args: &caseRegion); |
| 4337 | return; |
| 4338 | } |
| 4339 | } |
| 4340 | successors.emplace_back(Args: &getDefaultRegion()); |
| 4341 | } |
| 4342 | |
| 4343 | void IndexSwitchOp::getRegionInvocationBounds( |
| 4344 | ArrayRef<Attribute> operands, SmallVectorImpl<InvocationBounds> &bounds) { |
| 4345 | auto operandValue = llvm::dyn_cast_or_null<IntegerAttr>(Val: operands.front()); |
| 4346 | if (!operandValue) { |
| 4347 | // All regions are invoked at most once. |
| 4348 | bounds.append(NumInputs: getNumRegions(), Elt: InvocationBounds(/*lb=*/0, /*ub=*/1)); |
| 4349 | return; |
| 4350 | } |
| 4351 | |
| 4352 | unsigned liveIndex = getNumRegions() - 1; |
| 4353 | const auto *it = llvm::find(Range: getCases(), Val: operandValue.getInt()); |
| 4354 | if (it != getCases().end()) |
| 4355 | liveIndex = std::distance(first: getCases().begin(), last: it); |
| 4356 | for (unsigned i = 0, e = getNumRegions(); i < e; ++i) |
| 4357 | bounds.emplace_back(/*lb=*/Args: 0, /*ub=*/Args: i == liveIndex); |
| 4358 | } |
| 4359 | |
| 4360 | struct FoldConstantCase : OpRewritePattern<scf::IndexSwitchOp> { |
| 4361 | using OpRewritePattern<scf::IndexSwitchOp>::OpRewritePattern; |
| 4362 | |
| 4363 | LogicalResult matchAndRewrite(scf::IndexSwitchOp op, |
| 4364 | PatternRewriter &rewriter) const override { |
| 4365 | // If `op.getArg()` is a constant, select the region that matches with |
| 4366 | // the constant value. Use the default region if no matche is found. |
| 4367 | std::optional<int64_t> maybeCst = getConstantIntValue(ofr: op.getArg()); |
| 4368 | if (!maybeCst.has_value()) |
| 4369 | return failure(); |
| 4370 | int64_t cst = *maybeCst; |
| 4371 | int64_t caseIdx, e = op.getNumCases(); |
| 4372 | for (caseIdx = 0; caseIdx < e; ++caseIdx) { |
| 4373 | if (cst == op.getCases()[caseIdx]) |
| 4374 | break; |
| 4375 | } |
| 4376 | |
| 4377 | Region &r = (caseIdx < op.getNumCases()) ? op.getCaseRegions()[caseIdx] |
| 4378 | : op.getDefaultRegion(); |
| 4379 | Block &source = r.front(); |
| 4380 | Operation *terminator = source.getTerminator(); |
| 4381 | SmallVector<Value> results = terminator->getOperands(); |
| 4382 | |
| 4383 | rewriter.inlineBlockBefore(source: &source, op); |
| 4384 | rewriter.eraseOp(op: terminator); |
| 4385 | // Replace the operation with a potentially empty list of results. |
| 4386 | // Fold mechanism doesn't support the case where the result list is empty. |
| 4387 | rewriter.replaceOp(op, newValues: results); |
| 4388 | |
| 4389 | return success(); |
| 4390 | } |
| 4391 | }; |
| 4392 | |
// Registers canonicalization patterns for `scf.index_switch`.
void IndexSwitchOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                MLIRContext *context) {
  results.add<FoldConstantCase>(arg&: context);
}
| 4397 | |
| 4398 | //===----------------------------------------------------------------------===// |
| 4399 | // TableGen'd op method definitions |
| 4400 | //===----------------------------------------------------------------------===// |
| 4401 | |
| 4402 | #define GET_OP_CLASSES |
| 4403 | #include "mlir/Dialect/SCF/IR/SCFOps.cpp.inc" |
| 4404 | |