1//===- SCF.cpp - Structured Control Flow Operations -----------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/SCF/IR/SCF.h"
10#include "mlir/Conversion/ConvertToEmitC/ToEmitCInterface.h"
11#include "mlir/Dialect/Arith/IR/Arith.h"
12#include "mlir/Dialect/Arith/Utils/Utils.h"
13#include "mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.h"
14#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
15#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
16#include "mlir/Dialect/MemRef/IR/MemRef.h"
17#include "mlir/Dialect/SCF/IR/DeviceMappingInterface.h"
18#include "mlir/Dialect/Tensor/IR/Tensor.h"
19#include "mlir/IR/BuiltinAttributes.h"
20#include "mlir/IR/IRMapping.h"
21#include "mlir/IR/Matchers.h"
22#include "mlir/IR/PatternMatch.h"
23#include "mlir/Interfaces/FunctionInterfaces.h"
24#include "mlir/Interfaces/ValueBoundsOpInterface.h"
25#include "mlir/Transforms/InliningUtils.h"
26#include "llvm/ADT/MapVector.h"
27#include "llvm/ADT/SmallPtrSet.h"
28
29using namespace mlir;
30using namespace mlir::scf;
31
32#include "mlir/Dialect/SCF/IR/SCFOpsDialect.cpp.inc"
33
34//===----------------------------------------------------------------------===//
35// SCFDialect Dialect Interfaces
36//===----------------------------------------------------------------------===//
37
38namespace {
39struct SCFInlinerInterface : public DialectInlinerInterface {
40 using DialectInlinerInterface::DialectInlinerInterface;
41 // We don't have any special restrictions on what can be inlined into
42 // destination regions (e.g. while/conditional bodies). Always allow it.
43 bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
44 IRMapping &valueMapping) const final {
45 return true;
46 }
47 // Operations in scf dialect are always legal to inline since they are
48 // pure.
49 bool isLegalToInline(Operation *, Region *, bool, IRMapping &) const final {
50 return true;
51 }
52 // Handle the given inlined terminator by replacing it with a new operation
53 // as necessary. Required when the region has only one block.
54 void handleTerminator(Operation *op, ValueRange valuesToRepl) const final {
55 auto retValOp = dyn_cast<scf::YieldOp>(Val: op);
56 if (!retValOp)
57 return;
58
59 for (auto retValue : llvm::zip(t&: valuesToRepl, u: retValOp.getOperands())) {
60 std::get<0>(t&: retValue).replaceAllUsesWith(newValue: std::get<1>(t&: retValue));
61 }
62 }
63};
64} // namespace
65
66//===----------------------------------------------------------------------===//
67// SCFDialect
68//===----------------------------------------------------------------------===//
69
/// Registers all SCF operations, the inliner interface, and the external
/// interface implementations that live in other libraries. Promised
/// interfaces are declared here and attached later by the libraries that
/// implement them.
void SCFDialect::initialize() {
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/SCF/IR/SCFOps.cpp.inc"
      >();
  addInterfaces<SCFInlinerInterface>();
  // EmitC conversion patterns are provided by the ConvertToEmitC library.
  declarePromisedInterface<ConvertToEmitCPatternInterface, SCFDialect>();
  // Bufferization-related interfaces are provided by the Bufferization
  // library when it is linked in.
  declarePromisedInterfaces<bufferization::BufferDeallocationOpInterface,
                            InParallelOp, ReduceReturnOp>();
  declarePromisedInterfaces<bufferization::BufferizableOpInterface, ConditionOp,
                            ExecuteRegionOp, ForOp, IfOp, IndexSwitchOp,
                            ForallOp, InParallelOp, WhileOp, YieldOp>();
  declarePromisedInterface<ValueBoundsOpInterface, ForOp>();
}
84
85/// Default callback for IfOp builders. Inserts a yield without arguments.
86void mlir::scf::buildTerminatedBody(OpBuilder &builder, Location loc) {
87 builder.create<scf::YieldOp>(location: loc);
88}
89
90/// Verifies that the first block of the given `region` is terminated by a
91/// TerminatorTy. Reports errors on the given operation if it is not the case.
92template <typename TerminatorTy>
93static TerminatorTy verifyAndGetTerminator(Operation *op, Region &region,
94 StringRef errorMessage) {
95 Operation *terminatorOperation = nullptr;
96 if (!region.empty() && !region.front().empty()) {
97 terminatorOperation = &region.front().back();
98 if (auto yield = dyn_cast_or_null<TerminatorTy>(terminatorOperation))
99 return yield;
100 }
101 auto diag = op->emitOpError(message: errorMessage);
102 if (terminatorOperation)
103 diag.attachNote(noteLoc: terminatorOperation->getLoc()) << "terminator here";
104 return nullptr;
105}
106
107//===----------------------------------------------------------------------===//
108// ExecuteRegionOp
109//===----------------------------------------------------------------------===//
110
111/// Replaces the given op with the contents of the given single-block region,
112/// using the operands of the block terminator to replace operation results.
113static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op,
114 Region &region, ValueRange blockArgs = {}) {
115 assert(llvm::hasSingleElement(region) && "expected single-region block");
116 Block *block = &region.front();
117 Operation *terminator = block->getTerminator();
118 ValueRange results = terminator->getOperands();
119 rewriter.inlineBlockBefore(source: block, op, argValues: blockArgs);
120 rewriter.replaceOp(op, newValues: results);
121 rewriter.eraseOp(op: terminator);
122}
123
124///
125/// (ssa-id `=`)? `execute_region` `->` function-result-type `{`
126/// block+
127/// `}`
128///
129/// Example:
130/// scf.execute_region -> i32 {
131/// %idx = load %rI[%i] : memref<128xi32>
132/// return %idx : i32
133/// }
134///
135ParseResult ExecuteRegionOp::parse(OpAsmParser &parser,
136 OperationState &result) {
137 if (parser.parseOptionalArrowTypeList(result&: result.types))
138 return failure();
139
140 // Introduce the body region and parse it.
141 Region *body = result.addRegion();
142 if (parser.parseRegion(region&: *body, /*arguments=*/{}, /*argTypes=*/enableNameShadowing: {}) ||
143 parser.parseOptionalAttrDict(result&: result.attributes))
144 return failure();
145
146 return success();
147}
148
149void ExecuteRegionOp::print(OpAsmPrinter &p) {
150 p.printOptionalArrowTypeList(types: getResultTypes());
151
152 p << ' ';
153 p.printRegion(blocks&: getRegion(),
154 /*printEntryBlockArgs=*/false,
155 /*printBlockTerminators=*/true);
156
157 p.printOptionalAttrDict(attrs: (*this)->getAttrs());
158}
159
160LogicalResult ExecuteRegionOp::verify() {
161 if (getRegion().empty())
162 return emitOpError(message: "region needs to have at least one block");
163 if (getRegion().front().getNumArguments() > 0)
164 return emitOpError(message: "region cannot have any arguments");
165 return success();
166}
167
168// Inline an ExecuteRegionOp if it only contains one block.
169// "test.foo"() : () -> ()
170// %v = scf.execute_region -> i64 {
171// %x = "test.val"() : () -> i64
172// scf.yield %x : i64
173// }
174// "test.bar"(%v) : (i64) -> ()
175//
176// becomes
177//
178// "test.foo"() : () -> ()
179// %x = "test.val"() : () -> i64
180// "test.bar"(%x) : (i64) -> ()
181//
182struct SingleBlockExecuteInliner : public OpRewritePattern<ExecuteRegionOp> {
183 using OpRewritePattern<ExecuteRegionOp>::OpRewritePattern;
184
185 LogicalResult matchAndRewrite(ExecuteRegionOp op,
186 PatternRewriter &rewriter) const override {
187 if (!llvm::hasSingleElement(C&: op.getRegion()))
188 return failure();
189 replaceOpWithRegion(rewriter, op, region&: op.getRegion());
190 return success();
191 }
192};
193
194// Inline an ExecuteRegionOp if its parent can contain multiple blocks.
195// TODO generalize the conditions for operations which can be inlined into.
196// func @func_execute_region_elim() {
197// "test.foo"() : () -> ()
198// %v = scf.execute_region -> i64 {
199// %c = "test.cmp"() : () -> i1
200// cf.cond_br %c, ^bb2, ^bb3
201// ^bb2:
202// %x = "test.val1"() : () -> i64
203// cf.br ^bb4(%x : i64)
204// ^bb3:
205// %y = "test.val2"() : () -> i64
206// cf.br ^bb4(%y : i64)
207// ^bb4(%z : i64):
208// scf.yield %z : i64
209// }
210// "test.bar"(%v) : (i64) -> ()
211// return
212// }
213//
214// becomes
215//
216// func @func_execute_region_elim() {
217// "test.foo"() : () -> ()
218// %c = "test.cmp"() : () -> i1
219// cf.cond_br %c, ^bb1, ^bb2
220// ^bb1: // pred: ^bb0
221// %x = "test.val1"() : () -> i64
222// cf.br ^bb3(%x : i64)
223// ^bb2: // pred: ^bb0
224// %y = "test.val2"() : () -> i64
225// cf.br ^bb3(%y : i64)
226// ^bb3(%z: i64): // 2 preds: ^bb1, ^bb2
227// "test.bar"(%z) : (i64) -> ()
228// return
229// }
230//
struct MultiBlockExecuteInliner : public OpRewritePattern<ExecuteRegionOp> {
  using OpRewritePattern<ExecuteRegionOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ExecuteRegionOp op,
                                PatternRewriter &rewriter) const override {
    // Only inline when the parent is known to tolerate multiple blocks in its
    // region: functions and other execute_region ops.
    if (!isa<FunctionOpInterface, ExecuteRegionOp>(Val: op->getParentOp()))
      return failure();

    // Split the parent block at the op. Everything after the op moves to
    // `postBlock`, which will later receive the op's results as block args.
    Block *prevBlock = op->getBlock();
    Block *postBlock = rewriter.splitBlock(block: prevBlock, before: op->getIterator());
    rewriter.setInsertionPointToEnd(prevBlock);

    // Branch from the code before the op into the body's entry block.
    rewriter.create<cf::BranchOp>(location: op.getLoc(), args: &op.getRegion().front());

    // Rewrite every scf.yield in the body into a branch to `postBlock`,
    // forwarding the yielded values as branch operands.
    for (Block &blk : op.getRegion()) {
      if (YieldOp yieldOp = dyn_cast<YieldOp>(Val: blk.getTerminator())) {
        rewriter.setInsertionPoint(yieldOp);
        rewriter.create<cf::BranchOp>(location: yieldOp.getLoc(), args&: postBlock,
                                      args: yieldOp.getResults());
        rewriter.eraseOp(op: yieldOp);
      }
    }

    // Splice the (now terminator-fixed) body blocks into the parent region.
    rewriter.inlineRegionBefore(region&: op.getRegion(), before: postBlock);
    SmallVector<Value> blockArgs;

    // One block argument on `postBlock` per op result; these become the
    // replacement values.
    for (auto res : op.getResults())
      blockArgs.push_back(Elt: postBlock->addArgument(type: res.getType(), loc: res.getLoc()));

    rewriter.replaceOp(op, newValues: blockArgs);
    return success();
  }
};
264
/// Registers the two execute_region inlining canonicalization patterns.
void ExecuteRegionOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                  MLIRContext *context) {
  results.add<SingleBlockExecuteInliner, MultiBlockExecuteInliner>(arg&: context);
}
269
270void ExecuteRegionOp::getSuccessorRegions(
271 RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
272 // If the predecessor is the ExecuteRegionOp, branch into the body.
273 if (point.isParent()) {
274 regions.push_back(Elt: RegionSuccessor(&getRegion()));
275 return;
276 }
277
278 // Otherwise, the region branches back to the parent operation.
279 regions.push_back(Elt: RegionSuccessor(getResults()));
280}
281
282//===----------------------------------------------------------------------===//
283// ConditionOp
284//===----------------------------------------------------------------------===//
285
286MutableOperandRange
287ConditionOp::getMutableSuccessorOperands(RegionBranchPoint point) {
288 assert((point.isParent() || point == getParentOp().getAfter()) &&
289 "condition op can only exit the loop or branch to the after"
290 "region");
291 // Pass all operands except the condition to the successor region.
292 return getArgsMutable();
293}
294
295void ConditionOp::getSuccessorRegions(
296 ArrayRef<Attribute> operands, SmallVectorImpl<RegionSuccessor> &regions) {
297 FoldAdaptor adaptor(operands, *this);
298
299 WhileOp whileOp = getParentOp();
300
301 // Condition can either lead to the after region or back to the parent op
302 // depending on whether the condition is true or not.
303 auto boolAttr = dyn_cast_or_null<BoolAttr>(Val: adaptor.getCondition());
304 if (!boolAttr || boolAttr.getValue())
305 regions.emplace_back(Args: &whileOp.getAfter(),
306 Args: whileOp.getAfter().getArguments());
307 if (!boolAttr || !boolAttr.getValue())
308 regions.emplace_back(Args: whileOp.getResults());
309}
310
311//===----------------------------------------------------------------------===//
312// ForOp
313//===----------------------------------------------------------------------===//
314
/// Builds an scf.for with bounds `lb`/`ub`, step `step`, and loop-carried
/// values `initArgs`. The entry block receives the induction variable first,
/// then one argument per init (matching the result types). If `bodyBuilder`
/// is provided it populates the body; otherwise a default terminator is
/// created only when there are no iter args.
void ForOp::build(OpBuilder &builder, OperationState &result, Value lb,
                  Value ub, Value step, ValueRange initArgs,
                  BodyBuilderFn bodyBuilder) {
  OpBuilder::InsertionGuard guard(builder);

  // Operand order is fixed: lb, ub, step, then the loop-carried inits.
  result.addOperands(newOperands: {lb, ub, step});
  result.addOperands(newOperands: initArgs);
  // One result per loop-carried value.
  for (Value v : initArgs)
    result.addTypes(newTypes: v.getType());
  // The induction variable has the same type as the lower bound.
  Type t = lb.getType();
  Region *bodyRegion = result.addRegion();
  Block *bodyBlock = builder.createBlock(parent: bodyRegion);
  bodyBlock->addArgument(type: t, loc: result.location);
  for (Value v : initArgs)
    bodyBlock->addArgument(type: v.getType(), loc: v.getLoc());

  // Create the default terminator if the builder is not provided and if the
  // iteration arguments are not provided. Otherwise, leave this to the caller
  // because we don't know which values to return from the loop.
  if (initArgs.empty() && !bodyBuilder) {
    ForOp::ensureTerminator(region&: *bodyRegion, builder, loc: result.location);
  } else if (bodyBuilder) {
    OpBuilder::InsertionGuard guard(builder);
    builder.setInsertionPointToStart(bodyBlock);
    // Pass the IV separately from the remaining (iter-arg) block arguments.
    bodyBuilder(builder, result.location, bodyBlock->getArgument(i: 0),
                bodyBlock->getArguments().drop_front());
  }
}
343
344LogicalResult ForOp::verify() {
345 // Check that the number of init args and op results is the same.
346 if (getInitArgs().size() != getNumResults())
347 return emitOpError(
348 message: "mismatch in number of loop-carried values and defined values");
349
350 return success();
351}
352
353LogicalResult ForOp::verifyRegions() {
354 // Check that the body defines as single block argument for the induction
355 // variable.
356 if (getInductionVar().getType() != getLowerBound().getType())
357 return emitOpError(
358 message: "expected induction variable to be same type as bounds and step");
359
360 if (getNumRegionIterArgs() != getNumResults())
361 return emitOpError(
362 message: "mismatch in number of basic block args and defined values");
363
364 auto initArgs = getInitArgs();
365 auto iterArgs = getRegionIterArgs();
366 auto opResults = getResults();
367 unsigned i = 0;
368 for (auto e : llvm::zip(t&: initArgs, u&: iterArgs, args&: opResults)) {
369 if (std::get<0>(t&: e).getType() != std::get<2>(t&: e).getType())
370 return emitOpError() << "types mismatch between " << i
371 << "th iter operand and defined value";
372 if (std::get<1>(t&: e).getType() != std::get<2>(t&: e).getType())
373 return emitOpError() << "types mismatch between " << i
374 << "th iter region arg and defined value";
375
376 ++i;
377 }
378 return success();
379}
380
/// An scf.for has a single induction variable.
std::optional<SmallVector<Value>> ForOp::getLoopInductionVars() {
  return SmallVector<Value>{getInductionVar()};
}

/// Single-dimensional loop: one lower bound.
std::optional<SmallVector<OpFoldResult>> ForOp::getLoopLowerBounds() {
  return SmallVector<OpFoldResult>{OpFoldResult(getLowerBound())};
}

/// Single-dimensional loop: one step.
std::optional<SmallVector<OpFoldResult>> ForOp::getLoopSteps() {
  return SmallVector<OpFoldResult>{OpFoldResult(getStep())};
}

/// Single-dimensional loop: one upper bound.
std::optional<SmallVector<OpFoldResult>> ForOp::getLoopUpperBounds() {
  return SmallVector<OpFoldResult>{OpFoldResult(getUpperBound())};
}

/// All op results are loop results.
std::optional<ResultRange> ForOp::getLoopResults() { return getResults(); }
398
/// Promotes the loop body of a forOp to its containing block if it can be
/// determined that the loop has a single iteration.
LogicalResult ForOp::promoteIfSingleIteration(RewriterBase &rewriter) {
  // Only promote when the trip count is statically known to be exactly 1.
  std::optional<int64_t> tripCount =
      constantTripCount(lb: getLowerBound(), ub: getUpperBound(), step: getStep());
  if (!tripCount.has_value() || tripCount != 1)
    return failure();

  // Replace all results with the yielded values.
  auto yieldOp = cast<scf::YieldOp>(Val: getBody()->getTerminator());
  rewriter.replaceAllUsesWith(from: getResults(), to: getYieldedValues());

  // Replace block arguments with lower bound (replacement for IV) and
  // iter_args.
  SmallVector<Value> bbArgReplacements;
  bbArgReplacements.push_back(Elt: getLowerBound());
  llvm::append_range(C&: bbArgReplacements, R: getInitArgs());

  // Move the loop body operations to the loop's containing block.
  rewriter.inlineBlockBefore(source: getBody(), dest: getOperation()->getBlock(),
                             before: getOperation()->getIterator(), argValues: bbArgReplacements);

  // Erase the old terminator and the loop.
  rewriter.eraseOp(op: yieldOp);
  rewriter.eraseOp(op: *this);

  return success();
}
427
428/// Prints the initialization list in the form of
429/// <prefix>(%inner = %outer, %inner2 = %outer2, <...>)
430/// where 'inner' values are assumed to be region arguments and 'outer' values
431/// are regular SSA values.
432static void printInitializationList(OpAsmPrinter &p,
433 Block::BlockArgListType blocksArgs,
434 ValueRange initializers,
435 StringRef prefix = "") {
436 assert(blocksArgs.size() == initializers.size() &&
437 "expected same length of arguments and initializers");
438 if (initializers.empty())
439 return;
440
441 p << prefix << '(';
442 llvm::interleaveComma(c: llvm::zip(t&: blocksArgs, u&: initializers), os&: p, each_fn: [&](auto it) {
443 p << std::get<0>(it) << " = " << std::get<1>(it);
444 });
445 p << ")";
446}
447
/// Prints: %iv = %lb to %ub step %step [iter_args(...) -> (types)] [: type]
/// followed by the body region and the attribute dictionary.
void ForOp::print(OpAsmPrinter &p) {
  p << " " << getInductionVar() << " = " << getLowerBound() << " to "
    << getUpperBound() << " step " << getStep();

  printInitializationList(p, blocksArgs: getRegionIterArgs(), initializers: getInitArgs(), prefix: " iter_args");
  // Result types are printed only when there are loop-carried values.
  if (!getInitArgs().empty())
    p << " -> (" << getInitArgs().getTypes() << ')';
  p << ' ';
  // The IV type is elided when it is the default `index`.
  if (Type t = getInductionVar().getType(); !t.isIndex())
    p << " : " << t << ' ';
  // Terminators are elided for the iter-arg-free form, matching the parser.
  p.printRegion(blocks&: getRegion(),
                /*printEntryBlockArgs=*/false,
                /*printBlockTerminators=*/!getInitArgs().empty());
  p.printOptionalAttrDict(attrs: (*this)->getAttrs());
}
463
/// Parses the custom form printed by ForOp::print above:
///   %iv = %lb to %ub step %step [iter_args(%a = %init, ...) -> (types)]
///   [: type] region [attr-dict]
ParseResult ForOp::parse(OpAsmParser &parser, OperationState &result) {
  auto &builder = parser.getBuilder();
  Type type;

  OpAsmParser::Argument inductionVariable;
  OpAsmParser::UnresolvedOperand lb, ub, step;

  // Parse the induction variable followed by '='.
  if (parser.parseOperand(result&: inductionVariable.ssaName) || parser.parseEqual() ||
      // Parse loop bounds.
      parser.parseOperand(result&: lb) || parser.parseKeyword(keyword: "to") ||
      parser.parseOperand(result&: ub) || parser.parseKeyword(keyword: "step") ||
      parser.parseOperand(result&: step))
    return failure();

  // Parse the optional initial iteration arguments.
  SmallVector<OpAsmParser::Argument, 4> regionArgs;
  SmallVector<OpAsmParser::UnresolvedOperand, 4> operands;
  // The IV is always the first region argument; iter args follow it.
  regionArgs.push_back(Elt: inductionVariable);

  bool hasIterArgs = succeeded(Result: parser.parseOptionalKeyword(keyword: "iter_args"));
  if (hasIterArgs) {
    // Parse assignment list and results type list.
    if (parser.parseAssignmentList(lhs&: regionArgs, rhs&: operands) ||
        parser.parseArrowTypeList(result&: result.types))
      return failure();
  }

  // The "+ 1" accounts for the induction variable, which has no result.
  if (regionArgs.size() != result.types.size() + 1)
    return parser.emitError(
        loc: parser.getNameLoc(),
        message: "mismatch in number of loop-carried values and defined values");

  // Parse optional type, else assume Index.
  if (parser.parseOptionalColon())
    type = builder.getIndexType();
  else if (parser.parseType(result&: type))
    return failure();

  // Set block argument types, so that they are known when parsing the region.
  regionArgs.front().type = type;
  for (auto [iterArg, type] :
       llvm::zip_equal(t: llvm::drop_begin(RangeOrContainer&: regionArgs), u&: result.types))
    iterArg.type = type;

  // Parse the body region.
  Region *body = result.addRegion();
  if (parser.parseRegion(region&: *body, arguments: regionArgs))
    return failure();
  // Add the implicit terminator for the iter-arg-free form, which elides it.
  ForOp::ensureTerminator(region&: *body, builder, loc: result.location);

  // Resolve input operands. This should be done after parsing the region to
  // catch invalid IR where operands were defined inside of the region.
  if (parser.resolveOperand(operand: lb, type, result&: result.operands) ||
      parser.resolveOperand(operand: ub, type, result&: result.operands) ||
      parser.resolveOperand(operand: step, type, result&: result.operands))
    return failure();
  if (hasIterArgs) {
    for (auto argOperandType : llvm::zip_equal(t: llvm::drop_begin(RangeOrContainer&: regionArgs),
                                               u&: operands, args&: result.types)) {
      Type type = std::get<2>(t&: argOperandType);
      std::get<0>(t&: argOperandType).type = type;
      if (parser.resolveOperand(operand: std::get<1>(t&: argOperandType), type,
                                result&: result.operands))
        return failure();
    }
  }

  // Parse the optional attribute list.
  if (parser.parseOptionalAttrDict(result&: result.attributes))
    return failure();

  return success();
}
538
/// The single body region is the only loop region.
SmallVector<Region *> ForOp::getLoopRegions() { return {&getRegion()}; }

/// Returns the body's loop-carried block arguments (everything after the IV).
Block::BlockArgListType ForOp::getRegionIterArgs() {
  return getBody()->getArguments().drop_front(N: getNumInductionVars());
}

/// Inits of an scf.for are its init-args operands.
MutableArrayRef<OpOperand> ForOp::getInitsMutable() {
  return getInitArgsMutable();
}
548
/// Rebuilds this loop with `newInitOperands` appended to the iter args. The
/// extra yielded values are produced by `newYieldValuesFn`; the old loop is
/// replaced and its results are forwarded from the new loop's leading results.
FailureOr<LoopLikeOpInterface>
ForOp::replaceWithAdditionalYields(RewriterBase &rewriter,
                                   ValueRange newInitOperands,
                                   bool replaceInitOperandUsesInLoop,
                                   const NewYieldValuesFn &newYieldValuesFn) {
  // Create a new loop before the existing one, with the extra operands.
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(getOperation());
  auto inits = llvm::to_vector(Range: getInitArgs());
  inits.append(in_start: newInitOperands.begin(), in_end: newInitOperands.end());
  // Empty body builder: the body is transplanted from the old loop below.
  scf::ForOp newLoop = rewriter.create<scf::ForOp>(
      location: getLoc(), args: getLowerBound(), args: getUpperBound(), args: getStep(), args&: inits,
      args: [](OpBuilder &, Location, Value, ValueRange) {});
  newLoop->setAttrs(getPrunedAttributeList(op: getOperation(), elidedAttrs: {}));

  // Generate the new yield values and append them to the scf.yield operation.
  auto yieldOp = cast<scf::YieldOp>(Val: getBody()->getTerminator());
  ArrayRef<BlockArgument> newIterArgs =
      newLoop.getBody()->getArguments().take_back(N: newInitOperands.size());
  {
    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPoint(yieldOp);
    SmallVector<Value> newYieldedValues =
        newYieldValuesFn(rewriter, getLoc(), newIterArgs);
    assert(newInitOperands.size() == newYieldedValues.size() &&
           "expected as many new yield values as new iter operands");
    rewriter.modifyOpInPlace(root: yieldOp, callable: [&]() {
      yieldOp.getResultsMutable().append(values: newYieldedValues);
    });
  }

  // Move the loop body to the new op.
  rewriter.mergeBlocks(source: getBody(), dest: newLoop.getBody(),
                       argValues: newLoop.getBody()->getArguments().take_front(
                           N: getBody()->getNumArguments()));

  if (replaceInitOperandUsesInLoop) {
    // Replace all uses of `newInitOperands` with the corresponding basic block
    // arguments.
    for (auto it : llvm::zip(t&: newInitOperands, u&: newIterArgs)) {
      rewriter.replaceUsesWithIf(from: std::get<0>(t&: it), to: std::get<1>(t&: it),
                                 functor: [&](OpOperand &use) {
                                   Operation *user = use.getOwner();
                                   return newLoop->isProperAncestor(other: user);
                                 });
    }
  }

  // Replace the old loop. Only the leading results correspond to the old
  // loop's results; the trailing ones belong to the new iter args.
  rewriter.replaceOp(op: getOperation(),
                     newValues: newLoop->getResults().take_front(n: getNumResults()));
  return cast<LoopLikeOpInterface>(Val: newLoop.getOperation());
}
602
603ForOp mlir::scf::getForInductionVarOwner(Value val) {
604 auto ivArg = llvm::dyn_cast<BlockArgument>(Val&: val);
605 if (!ivArg)
606 return ForOp();
607 assert(ivArg.getOwner() && "unlinked block argument");
608 auto *containingOp = ivArg.getOwner()->getParentOp();
609 return dyn_cast_or_null<ForOp>(Val: containingOp);
610}
611
/// On first entry into the body, the iter args are seeded from the inits.
OperandRange ForOp::getEntrySuccessorOperands(RegionBranchPoint point) {
  return getInitArgs();
}
615
616void ForOp::getSuccessorRegions(RegionBranchPoint point,
617 SmallVectorImpl<RegionSuccessor> &regions) {
618 // Both the operation itself and the region may be branching into the body or
619 // back into the operation itself. It is possible for loop not to enter the
620 // body.
621 regions.push_back(Elt: RegionSuccessor(&getRegion(), getRegionIterArgs()));
622 regions.push_back(Elt: RegionSuccessor(getResults()));
623}
624
/// The single body region of scf.forall is its only loop region.
SmallVector<Region *> ForallOp::getLoopRegions() { return {&getRegion()}; }
626
627/// Promotes the loop body of a forallOp to its containing block if it can be
628/// determined that the loop has a single iteration.
629LogicalResult scf::ForallOp::promoteIfSingleIteration(RewriterBase &rewriter) {
630 for (auto [lb, ub, step] :
631 llvm::zip(t: getMixedLowerBound(), u: getMixedUpperBound(), args: getMixedStep())) {
632 auto tripCount = constantTripCount(lb, ub, step);
633 if (!tripCount.has_value() || *tripCount != 1)
634 return failure();
635 }
636
637 promote(rewriter, forallOp: *this);
638 return success();
639}
640
/// Returns the body's loop-carried block arguments: everything after the
/// per-dimension induction variables.
Block::BlockArgListType ForallOp::getRegionIterArgs() {
  return getBody()->getArguments().drop_front(N: getRank());
}

/// Inits of an scf.forall are its shared outputs.
MutableArrayRef<OpOperand> ForallOp::getInitsMutable() {
  return getOutputsMutable();
}
648
/// Promotes the loop body of a scf::ForallOp to its containing block.
/// The scf.forall_in_parallel terminator is replaced by plain
/// tensor.insert_slice ops; only tensor-typed sources are supported.
void mlir::scf::promote(RewriterBase &rewriter, scf::ForallOp forallOp) {
  OpBuilder::InsertionGuard g(rewriter);
  scf::InParallelOp terminator = forallOp.getTerminator();

  // Replace block arguments with lower bounds (replacements for IVs) and
  // outputs.
  SmallVector<Value> bbArgReplacements = forallOp.getLowerBound(b&: rewriter);
  bbArgReplacements.append(in_start: forallOp.getOutputs().begin(),
                           in_end: forallOp.getOutputs().end());

  // Move the loop body operations to the loop's containing block.
  rewriter.inlineBlockBefore(source: forallOp.getBody(), dest: forallOp->getBlock(),
                             before: forallOp->getIterator(), argValues: bbArgReplacements);

  // Replace the terminator with tensor.insert_slice ops.
  rewriter.setInsertionPointAfter(forallOp);
  SmallVector<Value> results;
  results.reserve(N: forallOp.getResults().size());
  for (auto &yieldingOp : terminator.getYieldingOps()) {
    auto parallelInsertSliceOp =
        cast<tensor::ParallelInsertSliceOp>(Val&: yieldingOp);

    Value dst = parallelInsertSliceOp.getDest();
    Value src = parallelInsertSliceOp.getSource();
    if (llvm::isa<TensorType>(Val: src.getType())) {
      // Re-create the parallel insert as a sequential insert_slice with the
      // same offsets/sizes/strides (both static and dynamic).
      results.push_back(Elt: rewriter.create<tensor::InsertSliceOp>(
          location: forallOp.getLoc(), args: dst.getType(), args&: src, args&: dst,
          args: parallelInsertSliceOp.getOffsets(), args: parallelInsertSliceOp.getSizes(),
          args: parallelInsertSliceOp.getStrides(),
          args: parallelInsertSliceOp.getStaticOffsets(),
          args: parallelInsertSliceOp.getStaticSizes(),
          args: parallelInsertSliceOp.getStaticStrides()));
    } else {
      // Non-tensor sources have no sequential equivalent here.
      llvm_unreachable("unsupported terminator");
    }
  }
  rewriter.replaceAllUsesWith(from: forallOp.getResults(), to: results);

  // Erase the old terminator and the loop.
  rewriter.eraseOp(op: terminator);
  rewriter.eraseOp(op: forallOp);
}
692
/// Builds a perfect nest of scf.for loops, one per (lb, ub, step) triple,
/// threading `iterArgs` through the nest. `bodyBuilder` is invoked once in
/// the innermost loop with all induction variables and the current iter args,
/// and must return the values to yield.
LoopNest mlir::scf::buildLoopNest(
    OpBuilder &builder, Location loc, ValueRange lbs, ValueRange ubs,
    ValueRange steps, ValueRange iterArgs,
    function_ref<ValueVector(OpBuilder &, Location, ValueRange, ValueRange)>
        bodyBuilder) {
  assert(lbs.size() == ubs.size() &&
         "expected the same number of lower and upper bounds");
  assert(lbs.size() == steps.size() &&
         "expected the same number of lower bounds and steps");

  // If there are no bounds, call the body-building function and return early.
  if (lbs.empty()) {
    ValueVector results =
        bodyBuilder ? bodyBuilder(builder, loc, ValueRange(), iterArgs)
                    : ValueVector();
    assert(results.size() == iterArgs.size() &&
           "loop nest body must return as many values as loop has iteration "
           "arguments");
    return LoopNest{.loops: {}, .results: std::move(results)};
  }

  // First, create the loop structure iteratively using the body-builder
  // callback of `ForOp::build`. Do not create `YieldOp`s yet.
  OpBuilder::InsertionGuard guard(builder);
  SmallVector<scf::ForOp, 4> loops;
  SmallVector<Value, 4> ivs;
  loops.reserve(N: lbs.size());
  ivs.reserve(N: lbs.size());
  ValueRange currentIterArgs = iterArgs;
  Location currentLoc = loc;
  for (unsigned i = 0, e = lbs.size(); i < e; ++i) {
    auto loop = builder.create<scf::ForOp>(
        location: currentLoc, args: lbs[i], args: ubs[i], args: steps[i], args&: currentIterArgs,
        args: [&](OpBuilder &nestedBuilder, Location nestedLoc, Value iv,
            ValueRange args) {
          ivs.push_back(Elt: iv);
          // It is safe to store ValueRange args because it points to block
          // arguments of a loop operation that we also own.
          currentIterArgs = args;
          currentLoc = nestedLoc;
        });
    // Set the builder to point to the body of the newly created loop. We don't
    // do this in the callback because the builder is reset when the callback
    // returns.
    builder.setInsertionPointToStart(loop.getBody());
    loops.push_back(Elt: loop);
  }

  // For all loops but the innermost, yield the results of the nested loop.
  for (unsigned i = 0, e = loops.size() - 1; i < e; ++i) {
    builder.setInsertionPointToEnd(loops[i].getBody());
    builder.create<scf::YieldOp>(location: loc, args: loops[i + 1].getResults());
  }

  // In the body of the innermost loop, call the body building function if any
  // and yield its results.
  builder.setInsertionPointToStart(loops.back().getBody());
  ValueVector results = bodyBuilder
                            ? bodyBuilder(builder, currentLoc, ivs,
                                          loops.back().getRegionIterArgs())
                            : ValueVector();
  assert(results.size() == iterArgs.size() &&
         "loop nest body must return as many values as loop has iteration "
         "arguments");
  builder.setInsertionPointToEnd(loops.back().getBody());
  builder.create<scf::YieldOp>(location: loc, args&: results);

  // Return the loops. The nest's results are those of the outermost loop.
  ValueVector nestResults;
  llvm::append_range(C&: nestResults, R: loops.front().getResults());
  return LoopNest{.loops: std::move(loops), .results: std::move(nestResults)};
}
765
766LoopNest mlir::scf::buildLoopNest(
767 OpBuilder &builder, Location loc, ValueRange lbs, ValueRange ubs,
768 ValueRange steps,
769 function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilder) {
770 // Delegate to the main function by wrapping the body builder.
771 return buildLoopNest(builder, loc, lbs, ubs, steps, iterArgs: {},
772 bodyBuilder: [&bodyBuilder](OpBuilder &nestedBuilder,
773 Location nestedLoc, ValueRange ivs,
774 ValueRange) -> ValueVector {
775 if (bodyBuilder)
776 bodyBuilder(nestedBuilder, nestedLoc, ivs);
777 return {};
778 });
779}
780
/// Rebuilds `forOp` with the iter operand `operand` replaced by `replacement`
/// of a different type. `castFn` is used to insert a cast from the new type
/// back to the old type at the top of the body, from old to new before the
/// yield, and from new to old after the loop, so all existing users keep
/// seeing the old type. Returns the replacement values for the old results.
SmallVector<Value>
mlir::scf::replaceAndCastForOpIterArg(RewriterBase &rewriter, scf::ForOp forOp,
                                      OpOperand &operand, Value replacement,
                                      const ValueTypeCastFnTy &castFn) {
  assert(operand.getOwner() == forOp);
  Type oldType = operand.get().getType(), newType = replacement.getType();

  // 1. Create new iter operands, exactly 1 is replaced.
  assert(operand.getOperandNumber() >= forOp.getNumControlOperands() &&
         "expected an iter OpOperand");
  assert(operand.get().getType() != replacement.getType() &&
         "Expected a different type");
  SmallVector<Value> newIterOperands;
  for (OpOperand &opOperand : forOp.getInitArgsMutable()) {
    if (opOperand.getOperandNumber() == operand.getOperandNumber()) {
      newIterOperands.push_back(Elt: replacement);
      continue;
    }
    newIterOperands.push_back(Elt: opOperand.get());
  }

  // 2. Create the new forOp shell.
  scf::ForOp newForOp = rewriter.create<scf::ForOp>(
      location: forOp.getLoc(), args: forOp.getLowerBound(), args: forOp.getUpperBound(),
      args: forOp.getStep(), args&: newIterOperands);
  newForOp->setAttrs(forOp->getAttrs());
  Block &newBlock = newForOp.getRegion().front();
  SmallVector<Value, 4> newBlockTransferArgs(newBlock.getArguments().begin(),
                                             newBlock.getArguments().end());

  // 3. Inject an incoming cast op at the beginning of the block for the bbArg
  // corresponding to the `replacement` value.
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPointToStart(&newBlock);
  BlockArgument newRegionIterArg = newForOp.getTiedLoopRegionIterArg(
      opOperand: &newForOp->getOpOperand(idx: operand.getOperandNumber()));
  Value castIn = castFn(rewriter, newForOp.getLoc(), oldType, newRegionIterArg);
  // The old body still expects the old type; hand it the casted value.
  newBlockTransferArgs[newRegionIterArg.getArgNumber()] = castIn;

  // 4. Steal the old block ops, mapping to the newBlockTransferArgs.
  Block &oldBlock = forOp.getRegion().front();
  rewriter.mergeBlocks(source: &oldBlock, dest: &newBlock, argValues: newBlockTransferArgs);

  // 5. Inject an outgoing cast op at the end of the block and yield it instead.
  auto clonedYieldOp = cast<scf::YieldOp>(Val: newBlock.getTerminator());
  rewriter.setInsertionPoint(clonedYieldOp);
  // Yield index = block-arg index minus the induction variable(s).
  unsigned yieldIdx =
      newRegionIterArg.getArgNumber() - forOp.getNumInductionVars();
  Value castOut = castFn(rewriter, newForOp.getLoc(), newType,
                         clonedYieldOp.getOperand(i: yieldIdx));
  SmallVector<Value> newYieldOperands = clonedYieldOp.getOperands();
  newYieldOperands[yieldIdx] = castOut;
  rewriter.create<scf::YieldOp>(location: newForOp.getLoc(), args&: newYieldOperands);
  rewriter.eraseOp(op: clonedYieldOp);

  // 6. Inject an outgoing cast op after the forOp.
  rewriter.setInsertionPointAfter(newForOp);
  SmallVector<Value> newResults = newForOp.getResults();
  newResults[yieldIdx] =
      castFn(rewriter, newForOp.getLoc(), oldType, newResults[yieldIdx]);

  return newResults;
}
844
845namespace {
846// Fold away ForOp iter arguments when:
847// 1) The op yields the iter arguments.
848// 2) The argument's corresponding outer region iterators (inputs) are yielded.
849// 3) The iter arguments have no use and the corresponding (operation) results
850// have no use.
851//
852// These arguments must be defined outside of the ForOp region and can just be
853// forwarded after simplifying the op inits, yields and returns.
854//
855// The implementation uses `inlineBlockBefore` to steal the content of the
856// original ForOp and avoid cloning.
/// Canonicalization pattern that drops scf.for iter args that are forwarded
/// unchanged (or unused), and deduplicates iter args that share the same
/// (init, yielded) pair. See the comment block above for the exact criteria.
struct ForOpIterArgsFolder : public OpRewritePattern<scf::ForOp> {
  using OpRewritePattern<scf::ForOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(scf::ForOp forOp,
                                PatternRewriter &rewriter) const final {
    // Set to true as soon as at least one iter arg can be removed; if it
    // stays false the pattern does not apply.
    bool canonicalize = false;

    // An internal flat vector of block transfer
    // arguments `newBlockTransferArgs` keeps the 1-1 mapping of original to
    // transformed block argument mappings. This plays the role of a
    // IRMapping for the particular use case of calling into
    // `inlineBlockBefore`.
    int64_t numResults = forOp.getNumResults();
    // keepMask[i] is true iff the i-th iter arg survives into the new loop.
    SmallVector<bool, 4> keepMask;
    keepMask.reserve(N: numResults);
    SmallVector<Value, 4> newBlockTransferArgs, newIterArgs, newYieldValues,
        newResultValues;
    newBlockTransferArgs.reserve(N: 1 + numResults);
    newBlockTransferArgs.push_back(Elt: Value()); // iv placeholder with null value
    newIterArgs.reserve(N: forOp.getInitArgs().size());
    newYieldValues.reserve(N: numResults);
    newResultValues.reserve(N: numResults);
    // Maps a kept (init, yielded) pair to its (bbArg, result) so later
    // duplicates can reuse them instead of keeping a redundant iter arg.
    DenseMap<std::pair<Value, Value>, std::pair<Value, Value>> initYieldToArg;
    for (auto [init, arg, result, yielded] :
         llvm::zip(t: forOp.getInitArgs(),       // iter from outside
                   u: forOp.getRegionIterArgs(), // iter inside region
                   args: forOp.getResults(),         // op results
                   args: forOp.getYieldedValues()    // iter yield
                   )) {
      // Forwarded is `true` when:
      // 1) The region `iter` argument is yielded.
      // 2) The region `iter` argument the corresponding input is yielded.
      // 3) The region `iter` argument has no use, and the corresponding op
      // result has no use.
      bool forwarded = (arg == yielded) || (init == yielded) ||
                       (arg.use_empty() && result.use_empty());
      if (forwarded) {
        canonicalize = true;
        keepMask.push_back(Elt: false);
        // Uses of the dropped bbArg/result resolve to the init value.
        newBlockTransferArgs.push_back(Elt: init);
        newResultValues.push_back(Elt: init);
        continue;
      }

      // Check if a previous kept argument always has the same values for init
      // and yielded values.
      if (auto it = initYieldToArg.find(Val: {init, yielded});
          it != initYieldToArg.end()) {
        canonicalize = true;
        keepMask.push_back(Elt: false);
        auto [sameArg, sameResult] = it->second;
        // Redirect users of the duplicate to the kept equivalent.
        rewriter.replaceAllUsesWith(from: arg, to: sameArg);
        rewriter.replaceAllUsesWith(from: result, to: sameResult);
        // The replacement value doesn't matter because there are no uses.
        newBlockTransferArgs.push_back(Elt: init);
        newResultValues.push_back(Elt: init);
        continue;
      }

      // This value is kept.
      initYieldToArg.insert(KV: {{init, yielded}, {arg, result}});
      keepMask.push_back(Elt: true);
      newIterArgs.push_back(Elt: init);
      newYieldValues.push_back(Elt: yielded);
      newBlockTransferArgs.push_back(Elt: Value()); // placeholder with null value
      newResultValues.push_back(Elt: Value());      // placeholder with null value
    }

    if (!canonicalize)
      return failure();

    // Build the slimmed-down loop with only the kept iter args.
    scf::ForOp newForOp = rewriter.create<scf::ForOp>(
        location: forOp.getLoc(), args: forOp.getLowerBound(), args: forOp.getUpperBound(),
        args: forOp.getStep(), args&: newIterArgs);
    newForOp->setAttrs(forOp->getAttrs());
    Block &newBlock = newForOp.getRegion().front();

    // Replace the null placeholders with newly constructed values.
    newBlockTransferArgs[0] = newBlock.getArgument(i: 0); // iv
    for (unsigned idx = 0, collapsedIdx = 0, e = newResultValues.size();
         idx != e; ++idx) {
      Value &blockTransferArg = newBlockTransferArgs[1 + idx];
      Value &newResultVal = newResultValues[idx];
      assert((blockTransferArg && newResultVal) ||
             (!blockTransferArg && !newResultVal));
      if (!blockTransferArg) {
        // Kept arg: wire it to the new loop's bbArg/result at the
        // collapsed (kept-only) index.
        blockTransferArg = newForOp.getRegionIterArgs()[collapsedIdx];
        newResultVal = newForOp.getResult(i: collapsedIdx++);
      }
    }

    Block &oldBlock = forOp.getRegion().front();
    assert(oldBlock.getNumArguments() == newBlockTransferArgs.size() &&
           "unexpected argument size mismatch");

    // No results case: the scf::ForOp builder already created a zero
    // result terminator. Merge before this terminator and just get rid of the
    // original terminator that has been merged in.
    if (newIterArgs.empty()) {
      auto newYieldOp = cast<scf::YieldOp>(Val: newBlock.getTerminator());
      rewriter.inlineBlockBefore(source: &oldBlock, op: newYieldOp, argValues: newBlockTransferArgs);
      rewriter.eraseOp(op: newBlock.getTerminator()->getPrevNode());
      rewriter.replaceOp(op: forOp, newValues: newResultValues);
      return success();
    }

    // No terminator case: merge and rewrite the merged terminator.
    auto cloneFilteredTerminator = [&](scf::YieldOp mergedTerminator) {
      OpBuilder::InsertionGuard g(rewriter);
      rewriter.setInsertionPoint(mergedTerminator);
      SmallVector<Value, 4> filteredOperands;
      filteredOperands.reserve(N: newResultValues.size());
      // Keep only the yield operands of surviving iter args.
      for (unsigned idx = 0, e = keepMask.size(); idx < e; ++idx)
        if (keepMask[idx])
          filteredOperands.push_back(Elt: mergedTerminator.getOperand(i: idx));
      rewriter.create<scf::YieldOp>(location: mergedTerminator.getLoc(),
                                    args&: filteredOperands);
    };

    rewriter.mergeBlocks(source: &oldBlock, dest: &newBlock, argValues: newBlockTransferArgs);
    auto mergedYieldOp = cast<scf::YieldOp>(Val: newBlock.getTerminator());
    cloneFilteredTerminator(mergedYieldOp);
    rewriter.eraseOp(op: mergedYieldOp);
    rewriter.replaceOp(op: forOp, newValues: newResultValues);
    return success();
  }
};
984
985/// Util function that tries to compute a constant diff between u and l.
986/// Returns std::nullopt when the difference between two AffineValueMap is
987/// dynamic.
988static std::optional<int64_t> computeConstDiff(Value l, Value u) {
989 IntegerAttr clb, cub;
990 if (matchPattern(value: l, pattern: m_Constant(bind_value: &clb)) && matchPattern(value: u, pattern: m_Constant(bind_value: &cub))) {
991 llvm::APInt lbValue = clb.getValue();
992 llvm::APInt ubValue = cub.getValue();
993 return (ubValue - lbValue).getSExtValue();
994 }
995
996 // Else a simple pattern match for x + c or c + x
997 llvm::APInt diff;
998 if (matchPattern(
999 value: u, pattern: m_Op<arith::AddIOp>(matchers: matchers::m_Val(v: l), matchers: m_ConstantInt(bind_value: &diff))) ||
1000 matchPattern(
1001 value: u, pattern: m_Op<arith::AddIOp>(matchers: m_ConstantInt(bind_value: &diff), matchers: matchers::m_Val(v: l))))
1002 return diff.getSExtValue();
1003 return std::nullopt;
1004}
1005
1006/// Rewriting pattern that erases loops that are known not to iterate, replaces
1007/// single-iteration loops with their bodies, and removes empty loops that
1008/// iterate at least once and only return values defined outside of the loop.
struct SimplifyTrivialLoops : public OpRewritePattern<ForOp> {
  using OpRewritePattern<ForOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ForOp op,
                                PatternRewriter &rewriter) const override {
    // If the upper bound is the same as the lower bound, the loop does not
    // iterate, just remove it.
    if (op.getLowerBound() == op.getUpperBound()) {
      // Results are replaced by the untouched init values.
      rewriter.replaceOp(op, newValues: op.getInitArgs());
      return success();
    }

    // Try to compute ub - lb as a static quantity.
    std::optional<int64_t> diff =
        computeConstDiff(l: op.getLowerBound(), u: op.getUpperBound());
    if (!diff)
      return failure();

    // If the loop is known to have 0 iterations, remove it.
    if (*diff <= 0) {
      rewriter.replaceOp(op, newValues: op.getInitArgs());
      return success();
    }

    std::optional<llvm::APInt> maybeStepValue = op.getConstantStep();
    if (!maybeStepValue)
      return failure();

    // If the loop is known to have 1 iteration, inline its body and remove the
    // loop.
    llvm::APInt stepValue = *maybeStepValue;
    if (stepValue.sge(RHS: *diff)) {
      // Block args are the iv (bound to the lower bound for the single
      // iteration) followed by the init values.
      SmallVector<Value, 4> blockArgs;
      blockArgs.reserve(N: op.getInitArgs().size() + 1);
      blockArgs.push_back(Elt: op.getLowerBound());
      llvm::append_range(C&: blockArgs, R: op.getInitArgs());
      replaceOpWithRegion(rewriter, op, region&: op.getRegion(), blockArgs);
      return success();
    }

    // Now we are left with loops that have more than 1 iterations.
    Block &block = op.getRegion().front();
    // "Single element" means the block holds only its terminator (empty body).
    if (!llvm::hasSingleElement(C&: block))
      return failure();
    // If the loop is empty, iterates at least once, and only returns values
    // defined outside of the loop, remove it and replace it with yield values.
    if (llvm::any_of(Range: op.getYieldedValues(),
                     P: [&](Value v) { return !op.isDefinedOutsideOfLoop(value: v); }))
      return failure();
    rewriter.replaceOp(op, newValues: op.getYieldedValues());
    return success();
  }
};
1061
/// Fold scf.for iter_arg/result pairs that go through an incoming/outgoing
/// tensor.cast op pair so as to pull the tensor.cast inside the scf.for:
1064///
1065/// ```
1066/// %0 = tensor.cast %t0 : tensor<32x1024xf32> to tensor<?x?xf32>
1067/// %1 = scf.for %i = %c0 to %c1024 step %c32 iter_args(%iter_t0 = %0)
1068/// -> (tensor<?x?xf32>) {
1069/// %2 = call @do(%iter_t0) : (tensor<?x?xf32>) -> tensor<?x?xf32>
1070/// scf.yield %2 : tensor<?x?xf32>
1071/// }
1072/// use_of(%1)
1073/// ```
1074///
1075/// folds into:
1076///
1077/// ```
1078/// %0 = scf.for %arg2 = %c0 to %c1024 step %c32 iter_args(%arg3 = %arg0)
1079/// -> (tensor<32x1024xf32>) {
1080/// %2 = tensor.cast %arg3 : tensor<32x1024xf32> to tensor<?x?xf32>
1081/// %3 = call @do(%2) : (tensor<?x?xf32>) -> tensor<?x?xf32>
1082/// %4 = tensor.cast %3 : tensor<?x?xf32> to tensor<32x1024xf32>
1083/// scf.yield %4 : tensor<32x1024xf32>
1084/// }
1085/// %1 = tensor.cast %0 : tensor<32x1024xf32> to tensor<?x?xf32>
1086/// use_of(%1)
1087/// ```
struct ForOpTensorCastFolder : public OpRewritePattern<ForOp> {
  using OpRewritePattern<ForOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ForOp op,
                                PatternRewriter &rewriter) const override {
    // Scan iter operands paired with their loop results; rewrite at most one
    // per pattern application.
    for (auto it : llvm::zip(t: op.getInitArgsMutable(), u: op.getResults())) {
      OpOperand &iterOpOperand = std::get<0>(t&: it);
      auto incomingCast = iterOpOperand.get().getDefiningOp<tensor::CastOp>();
      // Skip no-op casts: they carry no static info to pull into the loop.
      if (!incomingCast ||
          incomingCast.getSource().getType() == incomingCast.getType())
        continue;
      // If the dest type of the cast does not preserve static information in
      // the source type.
      if (!tensor::preservesStaticInformation(
              source: incomingCast.getDest().getType(),
              target: incomingCast.getSource().getType()))
        continue;
      // Only fold when the loop result has a single user, so introducing the
      // outgoing cast does not duplicate work for multiple users.
      if (!std::get<1>(t&: it).hasOneUse())
        continue;

      // Create a new ForOp with that iter operand replaced.
      rewriter.replaceOp(
          op, newValues: replaceAndCastForOpIterArg(
                  rewriter, forOp: op, operand&: iterOpOperand, replacement: incomingCast.getSource(),
                  castFn: [](OpBuilder &b, Location loc, Type type, Value source) {
                    return b.create<tensor::CastOp>(location: loc, args&: type, args&: source);
                  }));
      return success();
    }
    return failure();
  }
};
1120
1121} // namespace
1122
1123void ForOp::getCanonicalizationPatterns(RewritePatternSet &results,
1124 MLIRContext *context) {
1125 results.add<ForOpIterArgsFolder, SimplifyTrivialLoops, ForOpTensorCastFolder>(
1126 arg&: context);
1127}
1128
1129std::optional<APInt> ForOp::getConstantStep() {
1130 IntegerAttr step;
1131 if (matchPattern(value: getStep(), pattern: m_Constant(bind_value: &step)))
1132 return step.getValue();
1133 return {};
1134}
1135
1136std::optional<MutableArrayRef<OpOperand>> ForOp::getYieldedValuesMutable() {
1137 return cast<scf::YieldOp>(Val: getBody()->getTerminator()).getResultsMutable();
1138}
1139
1140Speculation::Speculatability ForOp::getSpeculatability() {
1141 // `scf.for (I = Start; I < End; I += 1)` terminates for all values of Start
1142 // and End.
1143 if (auto constantStep = getConstantStep())
1144 if (*constantStep == 1)
1145 return Speculation::RecursivelySpeculatable;
1146
1147 // For Step != 1, the loop may not terminate. We can add more smarts here if
1148 // needed.
1149 return Speculation::NotSpeculatable;
1150}
1151
1152//===----------------------------------------------------------------------===//
1153// ForallOp
1154//===----------------------------------------------------------------------===//
1155
/// Structural verifier for scf.forall: checks result/output counts, body
/// block argument types, mapping attribute arity, and the static/dynamic
/// bound and step operand lists.
LogicalResult ForallOp::verify() {
  unsigned numLoops = getRank();
  // Check number of outputs.
  if (getNumResults() != getOutputs().size())
    return emitOpError(message: "produces ")
           << getNumResults() << " results, but has only "
           << getOutputs().size() << " outputs";

  // Check that the body defines block arguments for thread indices and outputs.
  auto *body = getBody();
  if (body->getNumArguments() != numLoops + getOutputs().size())
    return emitOpError(message: "region expects ") << numLoops << " arguments";
  // The first `numLoops` block arguments are induction variables of index
  // type.
  for (int64_t i = 0; i < numLoops; ++i)
    if (!body->getArgument(i).getType().isIndex())
      return emitOpError(message: "expects ")
             << i << "-th block argument to be an index";
  // The trailing block arguments mirror the shared_outs operand types 1-1.
  for (unsigned i = 0; i < getOutputs().size(); ++i)
    if (body->getArgument(i: i + numLoops).getType() != getOutputs()[i].getType())
      return emitOpError(message: "type mismatch between ")
             << i << "-th output and corresponding block argument";
  if (getMapping().has_value() && !getMapping()->empty()) {
    // One device-mapping attribute per loop dimension.
    if (getDeviceMappingAttrs().size() != numLoops)
      return emitOpError() << "mapping attribute size must match op rank";
    // At most one masking attribute is allowed in the mapping.
    if (failed(Result: getDeviceMaskingAttr()))
      return emitOpError() << getMappingAttrName()
                           << " supports at most one device masking attribute";
  }

  // Verify mixed static/dynamic control variables.
  Operation *op = getOperation();
  if (failed(Result: verifyListOfOperandsOrIntegers(op, name: "lower bound", expectedNumElements: numLoops,
                                              attr: getStaticLowerBound(),
                                              values: getDynamicLowerBound())))
    return failure();
  if (failed(Result: verifyListOfOperandsOrIntegers(op, name: "upper bound", expectedNumElements: numLoops,
                                              attr: getStaticUpperBound(),
                                              values: getDynamicUpperBound())))
    return failure();
  if (failed(Result: verifyListOfOperandsOrIntegers(op, name: "step", expectedNumElements: numLoops,
                                              attr: getStaticStep(), values: getDynamicStep())))
    return failure();

  return success();
}
1200
/// Print scf.forall. Normalized loops (lb = 0, step = 1) use the compact
/// `(%ivs) in (ubs)` form; otherwise the full `= (lbs) to (ubs) step (steps)`
/// form is printed. Static bound/step attributes are elided from the trailing
/// attr-dict since they are printed inline.
void ForallOp::print(OpAsmPrinter &p) {
  Operation *op = getOperation();
  p << " (" << getInductionVars();
  if (isNormalized()) {
    p << ") in ";
    // NOTE(review): the inlay hints suggest the argument/comment pairing of
    // the trailing empty lists may be swapped vs. the parameter names of
    // printDynamicIndexList — confirm against its current signature.
    printDynamicIndexList(printer&: p, op, values: getDynamicUpperBound(), integers: getStaticUpperBound(),
                          /*valueTypes=*/scalableFlags: {}, /*scalables=*/valueTypes: {},
                          delimiter: OpAsmParser::Delimiter::Paren);
  } else {
    p << ") = ";
    printDynamicIndexList(printer&: p, op, values: getDynamicLowerBound(), integers: getStaticLowerBound(),
                          /*valueTypes=*/scalableFlags: {}, /*scalables=*/valueTypes: {},
                          delimiter: OpAsmParser::Delimiter::Paren);
    p << " to ";
    printDynamicIndexList(printer&: p, op, values: getDynamicUpperBound(), integers: getStaticUpperBound(),
                          /*valueTypes=*/scalableFlags: {}, /*scalables=*/valueTypes: {},
                          delimiter: OpAsmParser::Delimiter::Paren);
    p << " step ";
    printDynamicIndexList(printer&: p, op, values: getDynamicStep(), integers: getStaticStep(),
                          /*valueTypes=*/scalableFlags: {}, /*scalables=*/valueTypes: {},
                          delimiter: OpAsmParser::Delimiter::Paren);
  }
  printInitializationList(p, blocksArgs: getRegionOutArgs(), initializers: getOutputs(), prefix: " shared_outs");
  p << " ";
  if (!getRegionOutArgs().empty())
    p << "-> (" << getResultTypes() << ") ";
  // Terminator is only printed when the op has results (non-empty
  // shared_outs).
  p.printRegion(blocks&: getRegion(),
                /*printEntryBlockArgs=*/false,
                /*printBlockTerminators=*/getNumResults() > 0);
  p.printOptionalAttrDict(attrs: op->getAttrs(), elidedAttrs: {getOperandSegmentSizesAttrName(),
                                           getStaticLowerBoundAttrName(),
                                           getStaticUpperBoundAttrName(),
                                           getStaticStepAttrName()});
}
1235
1236ParseResult ForallOp::parse(OpAsmParser &parser, OperationState &result) {
1237 OpBuilder b(parser.getContext());
1238 auto indexType = b.getIndexType();
1239
1240 // Parse an opening `(` followed by thread index variables followed by `)`
1241 // TODO: when we can refer to such "induction variable"-like handles from the
1242 // declarative assembly format, we can implement the parser as a custom hook.
1243 SmallVector<OpAsmParser::Argument, 4> ivs;
1244 if (parser.parseArgumentList(result&: ivs, delimiter: OpAsmParser::Delimiter::Paren))
1245 return failure();
1246
1247 DenseI64ArrayAttr staticLbs, staticUbs, staticSteps;
1248 SmallVector<OpAsmParser::UnresolvedOperand> dynamicLbs, dynamicUbs,
1249 dynamicSteps;
1250 if (succeeded(Result: parser.parseOptionalKeyword(keyword: "in"))) {
1251 // Parse upper bounds.
1252 if (parseDynamicIndexList(parser, values&: dynamicUbs, integers&: staticUbs,
1253 /*valueTypes=*/nullptr,
1254 delimiter: OpAsmParser::Delimiter::Paren) ||
1255 parser.resolveOperands(operands&: dynamicUbs, type: indexType, result&: result.operands))
1256 return failure();
1257
1258 unsigned numLoops = ivs.size();
1259 staticLbs = b.getDenseI64ArrayAttr(values: SmallVector<int64_t>(numLoops, 0));
1260 staticSteps = b.getDenseI64ArrayAttr(values: SmallVector<int64_t>(numLoops, 1));
1261 } else {
1262 // Parse lower bounds.
1263 if (parser.parseEqual() ||
1264 parseDynamicIndexList(parser, values&: dynamicLbs, integers&: staticLbs,
1265 /*valueTypes=*/nullptr,
1266 delimiter: OpAsmParser::Delimiter::Paren) ||
1267
1268 parser.resolveOperands(operands&: dynamicLbs, type: indexType, result&: result.operands))
1269 return failure();
1270
1271 // Parse upper bounds.
1272 if (parser.parseKeyword(keyword: "to") ||
1273 parseDynamicIndexList(parser, values&: dynamicUbs, integers&: staticUbs,
1274 /*valueTypes=*/nullptr,
1275 delimiter: OpAsmParser::Delimiter::Paren) ||
1276 parser.resolveOperands(operands&: dynamicUbs, type: indexType, result&: result.operands))
1277 return failure();
1278
1279 // Parse step values.
1280 if (parser.parseKeyword(keyword: "step") ||
1281 parseDynamicIndexList(parser, values&: dynamicSteps, integers&: staticSteps,
1282 /*valueTypes=*/nullptr,
1283 delimiter: OpAsmParser::Delimiter::Paren) ||
1284 parser.resolveOperands(operands&: dynamicSteps, type: indexType, result&: result.operands))
1285 return failure();
1286 }
1287
1288 // Parse out operands and results.
1289 SmallVector<OpAsmParser::Argument, 4> regionOutArgs;
1290 SmallVector<OpAsmParser::UnresolvedOperand, 4> outOperands;
1291 SMLoc outOperandsLoc = parser.getCurrentLocation();
1292 if (succeeded(Result: parser.parseOptionalKeyword(keyword: "shared_outs"))) {
1293 if (outOperands.size() != result.types.size())
1294 return parser.emitError(loc: outOperandsLoc,
1295 message: "mismatch between out operands and types");
1296 if (parser.parseAssignmentList(lhs&: regionOutArgs, rhs&: outOperands) ||
1297 parser.parseOptionalArrowTypeList(result&: result.types) ||
1298 parser.resolveOperands(operands&: outOperands, types&: result.types, loc: outOperandsLoc,
1299 result&: result.operands))
1300 return failure();
1301 }
1302
1303 // Parse region.
1304 SmallVector<OpAsmParser::Argument, 4> regionArgs;
1305 std::unique_ptr<Region> region = std::make_unique<Region>();
1306 for (auto &iv : ivs) {
1307 iv.type = b.getIndexType();
1308 regionArgs.push_back(Elt: iv);
1309 }
1310 for (const auto &it : llvm::enumerate(First&: regionOutArgs)) {
1311 auto &out = it.value();
1312 out.type = result.types[it.index()];
1313 regionArgs.push_back(Elt: out);
1314 }
1315 if (parser.parseRegion(region&: *region, arguments: regionArgs))
1316 return failure();
1317
1318 // Ensure terminator and move region.
1319 ForallOp::ensureTerminator(region&: *region, builder&: b, loc: result.location);
1320 result.addRegion(region: std::move(region));
1321
1322 // Parse the optional attribute list.
1323 if (parser.parseOptionalAttrDict(result&: result.attributes))
1324 return failure();
1325
1326 result.addAttribute(name: "staticLowerBound", attr: staticLbs);
1327 result.addAttribute(name: "staticUpperBound", attr: staticUbs);
1328 result.addAttribute(name: "staticStep", attr: staticSteps);
1329 result.addAttribute(name: "operandSegmentSizes",
1330 attr: parser.getBuilder().getDenseI32ArrayAttr(
1331 values: {static_cast<int32_t>(dynamicLbs.size()),
1332 static_cast<int32_t>(dynamicUbs.size()),
1333 static_cast<int32_t>(dynamicSteps.size()),
1334 static_cast<int32_t>(outOperands.size())}));
1335 return success();
1336}
1337
1338// Builder that takes loop bounds.
/// Build an scf.forall with mixed static/dynamic lower bounds, upper bounds
/// and steps, `outputs` as shared_outs (also determining the result types),
/// an optional device mapping attribute, and an optional body builder. When
/// no body builder is given, an empty terminator is inserted instead.
void ForallOp::build(
    mlir::OpBuilder &b, mlir::OperationState &result,
    ArrayRef<OpFoldResult> lbs, ArrayRef<OpFoldResult> ubs,
    ArrayRef<OpFoldResult> steps, ValueRange outputs,
    std::optional<ArrayAttr> mapping,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn) {
  // Split each OpFoldResult list into a dynamic Value list and a static
  // integer list (with sentinels marking the dynamic positions).
  SmallVector<int64_t> staticLbs, staticUbs, staticSteps;
  SmallVector<Value> dynamicLbs, dynamicUbs, dynamicSteps;
  dispatchIndexOpFoldResults(ofrs: lbs, dynamicVec&: dynamicLbs, staticVec&: staticLbs);
  dispatchIndexOpFoldResults(ofrs: ubs, dynamicVec&: dynamicUbs, staticVec&: staticUbs);
  dispatchIndexOpFoldResults(ofrs: steps, dynamicVec&: dynamicSteps, staticVec&: staticSteps);

  // Operand order must match the segment sizes attribute below.
  result.addOperands(newOperands: dynamicLbs);
  result.addOperands(newOperands: dynamicUbs);
  result.addOperands(newOperands: dynamicSteps);
  result.addOperands(newOperands: outputs);
  // One result per shared_out, with the same type.
  result.addTypes(newTypes: TypeRange(outputs));

  result.addAttribute(name: getStaticLowerBoundAttrName(name: result.name),
                      attr: b.getDenseI64ArrayAttr(values: staticLbs));
  result.addAttribute(name: getStaticUpperBoundAttrName(name: result.name),
                      attr: b.getDenseI64ArrayAttr(values: staticUbs));
  result.addAttribute(name: getStaticStepAttrName(name: result.name),
                      attr: b.getDenseI64ArrayAttr(values: staticSteps));
  result.addAttribute(
      name: "operandSegmentSizes",
      attr: b.getDenseI32ArrayAttr(values: {static_cast<int32_t>(dynamicLbs.size()),
                               static_cast<int32_t>(dynamicUbs.size()),
                               static_cast<int32_t>(dynamicSteps.size()),
                               static_cast<int32_t>(outputs.size())}));
  if (mapping.has_value()) {
    result.addAttribute(name: ForallOp::getMappingAttrName(name: result.name),
                        attr: mapping.value());
  }

  Region *bodyRegion = result.addRegion();
  OpBuilder::InsertionGuard g(b);
  b.createBlock(parent: bodyRegion);
  Block &bodyBlock = bodyRegion->front();

  // Add block arguments for indices and outputs.
  bodyBlock.addArguments(
      types: SmallVector<Type>(lbs.size(), b.getIndexType()),
      locs: SmallVector<Location>(staticLbs.size(), result.location));
  bodyBlock.addArguments(
      types: TypeRange(outputs),
      locs: SmallVector<Location>(outputs.size(), result.location));

  b.setInsertionPointToStart(&bodyBlock);
  if (!bodyBuilderFn) {
    // No body builder: just make the region well-formed with a terminator.
    ForallOp::ensureTerminator(region&: *bodyRegion, builder&: b, loc: result.location);
    return;
  }
  // Caller-provided builder is responsible for the terminator.
  bodyBuilderFn(b, result.location, bodyBlock.getArguments());
}
1394
1395// Builder that takes loop bounds.
1396void ForallOp::build(
1397 mlir::OpBuilder &b, mlir::OperationState &result,
1398 ArrayRef<OpFoldResult> ubs, ValueRange outputs,
1399 std::optional<ArrayAttr> mapping,
1400 function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn) {
1401 unsigned numLoops = ubs.size();
1402 SmallVector<OpFoldResult> lbs(numLoops, b.getIndexAttr(value: 0));
1403 SmallVector<OpFoldResult> steps(numLoops, b.getIndexAttr(value: 1));
1404 build(b, result, lbs, ubs, steps, outputs, mapping, bodyBuilderFn);
1405}
1406
1407// Checks if the lbs are zeros and steps are ones.
1408bool ForallOp::isNormalized() {
1409 auto allEqual = [](ArrayRef<OpFoldResult> results, int64_t val) {
1410 return llvm::all_of(Range&: results, P: [&](OpFoldResult ofr) {
1411 auto intValue = getConstantIntValue(ofr);
1412 return intValue.has_value() && intValue == val;
1413 });
1414 };
1415 return allEqual(getMixedLowerBound(), 0) && allEqual(getMixedStep(), 1);
1416}
1417
1418InParallelOp ForallOp::getTerminator() {
1419 return cast<InParallelOp>(Val: getBody()->getTerminator());
1420}
1421
1422SmallVector<Operation *> ForallOp::getCombiningOps(BlockArgument bbArg) {
1423 SmallVector<Operation *> storeOps;
1424 InParallelOp inParallelOp = getTerminator();
1425 for (Operation &yieldOp : inParallelOp.getYieldingOps()) {
1426 if (auto parallelInsertSliceOp =
1427 dyn_cast<tensor::ParallelInsertSliceOp>(Val&: yieldOp);
1428 parallelInsertSliceOp && parallelInsertSliceOp.getDest() == bbArg) {
1429 storeOps.push_back(Elt: parallelInsertSliceOp);
1430 }
1431 }
1432 return storeOps;
1433}
1434
1435SmallVector<DeviceMappingAttrInterface> ForallOp::getDeviceMappingAttrs() {
1436 SmallVector<DeviceMappingAttrInterface> res;
1437 if (!getMapping())
1438 return res;
1439 for (auto attr : getMapping()->getValue()) {
1440 auto m = dyn_cast<DeviceMappingAttrInterface>(Val&: attr);
1441 if (m)
1442 res.push_back(Elt: m);
1443 }
1444 return res;
1445}
1446
1447FailureOr<DeviceMaskingAttrInterface> ForallOp::getDeviceMaskingAttr() {
1448 DeviceMaskingAttrInterface res;
1449 if (!getMapping())
1450 return res;
1451 for (auto attr : getMapping()->getValue()) {
1452 auto m = dyn_cast<DeviceMaskingAttrInterface>(Val&: attr);
1453 if (m && res)
1454 return failure();
1455 if (m)
1456 res = m;
1457 }
1458 return res;
1459}
1460
1461bool ForallOp::usesLinearMapping() {
1462 SmallVector<DeviceMappingAttrInterface> ifaces = getDeviceMappingAttrs();
1463 if (ifaces.empty())
1464 return false;
1465 return ifaces.front().isLinearMapping();
1466}
1467
1468std::optional<SmallVector<Value>> ForallOp::getLoopInductionVars() {
1469 return SmallVector<Value>{getBody()->getArguments().take_front(N: getRank())};
1470}
1471
1472// Get lower bounds as OpFoldResult.
1473std::optional<SmallVector<OpFoldResult>> ForallOp::getLoopLowerBounds() {
1474 Builder b(getOperation()->getContext());
1475 return getMixedValues(staticValues: getStaticLowerBound(), dynamicValues: getDynamicLowerBound(), b);
1476}
1477
1478// Get upper bounds as OpFoldResult.
1479std::optional<SmallVector<OpFoldResult>> ForallOp::getLoopUpperBounds() {
1480 Builder b(getOperation()->getContext());
1481 return getMixedValues(staticValues: getStaticUpperBound(), dynamicValues: getDynamicUpperBound(), b);
1482}
1483
1484// Get steps as OpFoldResult.
1485std::optional<SmallVector<OpFoldResult>> ForallOp::getLoopSteps() {
1486 Builder b(getOperation()->getContext());
1487 return getMixedValues(staticValues: getStaticStep(), dynamicValues: getDynamicStep(), b);
1488}
1489
1490ForallOp mlir::scf::getForallOpThreadIndexOwner(Value val) {
1491 auto tidxArg = llvm::dyn_cast<BlockArgument>(Val&: val);
1492 if (!tidxArg)
1493 return ForallOp();
1494 assert(tidxArg.getOwner() && "unlinked block argument");
1495 auto *containingOp = tidxArg.getOwner()->getParentOp();
1496 return dyn_cast<ForallOp>(Val: containingOp);
1497}
1498
1499namespace {
1500/// Fold tensor.dim(forall shared_outs(... = %t)) to tensor.dim(%t).
1501struct DimOfForallOp : public OpRewritePattern<tensor::DimOp> {
1502 using OpRewritePattern<tensor::DimOp>::OpRewritePattern;
1503
1504 LogicalResult matchAndRewrite(tensor::DimOp dimOp,
1505 PatternRewriter &rewriter) const final {
1506 auto forallOp = dimOp.getSource().getDefiningOp<ForallOp>();
1507 if (!forallOp)
1508 return failure();
1509 Value sharedOut =
1510 forallOp.getTiedOpOperand(opResult: llvm::cast<OpResult>(Val: dimOp.getSource()))
1511 ->get();
1512 rewriter.modifyOpInPlace(
1513 root: dimOp, callable: [&]() { dimOp.getSourceMutable().assign(value: sharedOut); });
1514 return success();
1515 }
1516};
1517
/// Canonicalization that folds constant-producing dynamic bound/step operands
/// of scf.forall into their static attribute counterparts, updating the
/// operand segment sizes accordingly.
class ForallOpControlOperandsFolder : public OpRewritePattern<ForallOp> {
public:
  using OpRewritePattern<ForallOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ForallOp op,
                                PatternRewriter &rewriter) const override {
    SmallVector<OpFoldResult> mixedLowerBound(op.getMixedLowerBound());
    SmallVector<OpFoldResult> mixedUpperBound(op.getMixedUpperBound());
    SmallVector<OpFoldResult> mixedStep(op.getMixedStep());
    // Bail out only when none of the three lists could be folded (note: all
    // three folds are attempted; && short-circuits only on failure).
    if (failed(Result: foldDynamicIndexList(ofrs&: mixedLowerBound)) &&
        failed(Result: foldDynamicIndexList(ofrs&: mixedUpperBound)) &&
        failed(Result: foldDynamicIndexList(ofrs&: mixedStep)))
      return failure();

    rewriter.modifyOpInPlace(root: op, callable: [&]() {
      SmallVector<Value> dynamicLowerBound, dynamicUpperBound, dynamicStep;
      SmallVector<int64_t> staticLowerBound, staticUpperBound, staticStep;
      // Re-split each folded mixed list into dynamic operands and static
      // attribute values, then write both back onto the op.
      dispatchIndexOpFoldResults(ofrs: mixedLowerBound, dynamicVec&: dynamicLowerBound,
                                 staticVec&: staticLowerBound);
      op.getDynamicLowerBoundMutable().assign(values: dynamicLowerBound);
      op.setStaticLowerBound(staticLowerBound);

      dispatchIndexOpFoldResults(ofrs: mixedUpperBound, dynamicVec&: dynamicUpperBound,
                                 staticVec&: staticUpperBound);
      op.getDynamicUpperBoundMutable().assign(values: dynamicUpperBound);
      op.setStaticUpperBound(staticUpperBound);

      dispatchIndexOpFoldResults(ofrs: mixedStep, dynamicVec&: dynamicStep, staticVec&: staticStep);
      op.getDynamicStepMutable().assign(values: dynamicStep);
      op.setStaticStep(staticStep);

      // The dynamic operand counts changed, so the segment sizes attribute
      // must be rebuilt to match.
      op->setAttr(name: ForallOp::getOperandSegmentSizeAttr(),
                  value: rewriter.getDenseI32ArrayAttr(
                      values: {static_cast<int32_t>(dynamicLowerBound.size()),
                       static_cast<int32_t>(dynamicUpperBound.size()),
                       static_cast<int32_t>(dynamicStep.size()),
                       static_cast<int32_t>(op.getNumResults())}));
    });
    return success();
  }
};
1559
1560/// The following canonicalization pattern folds the iter arguments of
1561/// scf.forall op if :-
1562/// 1. The corresponding result has zero uses.
/// 2. The iter argument is NOT being modified within the loop body.
1565///
1566/// Example of first case :-
1567/// INPUT:
1568/// %res:3 = scf.forall ... shared_outs(%arg0 = %a, %arg1 = %b, %arg2 = %c)
1569/// {
1570/// ...
1571/// <SOME USE OF %arg0>
1572/// <SOME USE OF %arg1>
1573/// <SOME USE OF %arg2>
1574/// ...
1575/// scf.forall.in_parallel {
1576/// <STORE OP WITH DESTINATION %arg1>
1577/// <STORE OP WITH DESTINATION %arg0>
1578/// <STORE OP WITH DESTINATION %arg2>
1579/// }
1580/// }
1581/// return %res#1
1582///
1583/// OUTPUT:
/// %res = scf.forall ... shared_outs(%new_arg0 = %b)
1585/// {
1586/// ...
1587/// <SOME USE OF %a>
1588/// <SOME USE OF %new_arg0>
1589/// <SOME USE OF %c>
1590/// ...
1591/// scf.forall.in_parallel {
1592/// <STORE OP WITH DESTINATION %new_arg0>
1593/// }
1594/// }
1595/// return %res
1596///
1597/// NOTE: 1. All uses of the folded shared_outs (iter argument) within the
1598/// scf.forall is replaced by their corresponding operands.
1599/// 2. Even if there are <STORE OP WITH DESTINATION *> ops within the body
1600/// of the scf.forall besides within scf.forall.in_parallel terminator,
1601/// this canonicalization remains valid. For more details, please refer
1602/// to :
1603/// https://github.com/llvm/llvm-project/pull/90189#discussion_r1589011124
1604/// 3. TODO(avarma): Generalize it for other store ops. Currently it
1605/// handles tensor.parallel_insert_slice ops only.
1606///
1607/// Example of second case :-
1608/// INPUT:
1609/// %res:2 = scf.forall ... shared_outs(%arg0 = %a, %arg1 = %b)
1610/// {
1611/// ...
1612/// <SOME USE OF %arg0>
1613/// <SOME USE OF %arg1>
1614/// ...
1615/// scf.forall.in_parallel {
1616/// <STORE OP WITH DESTINATION %arg1>
1617/// }
1618/// }
1619/// return %res#0, %res#1
1620///
1621/// OUTPUT:
1622/// %res = scf.forall ... shared_outs(%new_arg0 = %b)
1623/// {
1624/// ...
1625/// <SOME USE OF %a>
1626/// <SOME USE OF %new_arg0>
1627/// ...
1628/// scf.forall.in_parallel {
1629/// <STORE OP WITH DESTINATION %new_arg0>
1630/// }
1631/// }
1632/// return %a, %res
struct ForallOpIterArgsFolder : public OpRewritePattern<ForallOp> {
  using OpRewritePattern<ForallOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ForallOp forallOp,
                                PatternRewriter &rewriter) const final {
    // Step 1: For a given i-th result of scf.forall, check the following :-
    //         a. If it has any use.
    //         b. If the corresponding iter argument is being modified within
    //            the loop, i.e. has at least one store op with the iter arg as
    //            its destination operand. For this we use
    //            ForallOp::getCombiningOps(iter_arg).
    //
    //         Based on the check we maintain the following :-
    //         a. `resultToDelete` - i-th result of scf.forall that'll be
    //            deleted.
    //         b. `resultToReplace` - i-th result of the old scf.forall
    //            whose uses will be replaced by the new scf.forall.
    //         c. `newOuts` - the shared_outs' operand of the new scf.forall
    //            corresponding to the i-th result with at least one use.
    SetVector<OpResult> resultToDelete;
    SmallVector<Value> resultToReplace;
    SmallVector<Value> newOuts;
    for (OpResult result : forallOp.getResults()) {
      OpOperand *opOperand = forallOp.getTiedOpOperand(opResult: result);
      BlockArgument blockArg = forallOp.getTiedBlockArgument(opOperand);
      if (result.use_empty() || forallOp.getCombiningOps(bbArg: blockArg).empty()) {
        resultToDelete.insert(X: result);
      } else {
        resultToReplace.push_back(Elt: result);
        newOuts.push_back(Elt: opOperand->get());
      }
    }

    // Return early if all results of scf.forall have at least one use and
    // are being modified within the loop.
    if (resultToDelete.empty())
      return failure();

    // Step 2: For the i-th result, do the following :-
    //         a. Fetch the corresponding BlockArgument.
    //         b. Look for store ops (currently tensor.parallel_insert_slice)
    //            with the BlockArgument as its destination operand.
    //         c. Remove the operations fetched in b.
    for (OpResult result : resultToDelete) {
      OpOperand *opOperand = forallOp.getTiedOpOperand(opResult: result);
      BlockArgument blockArg = forallOp.getTiedBlockArgument(opOperand);
      SmallVector<Operation *> combiningOps =
          forallOp.getCombiningOps(bbArg: blockArg);
      for (Operation *combiningOp : combiningOps)
        rewriter.eraseOp(op: combiningOp);
    }

    // Step 3. Create a new scf.forall op with the new shared_outs' operands
    //         fetched earlier
    auto newForallOp = rewriter.create<scf::ForallOp>(
        location: forallOp.getLoc(), args: forallOp.getMixedLowerBound(),
        args: forallOp.getMixedUpperBound(), args: forallOp.getMixedStep(), args&: newOuts,
        args: forallOp.getMapping(),
        /*bodyBuilderFn =*/args: [](OpBuilder &, Location, ValueRange) {});

    // Step 4. Merge the block of the old scf.forall into the newly created
    //         scf.forall using the new set of arguments.
    Block *loopBody = forallOp.getBody();
    Block *newLoopBody = newForallOp.getBody();
    ArrayRef<BlockArgument> newBbArgs = newLoopBody->getArguments();
    // Form initial new bbArg list with just the control operands of the new
    // scf.forall op.
    SmallVector<Value> newBlockArgs =
        llvm::map_to_vector(C: newBbArgs.take_front(N: forallOp.getRank()),
                            F: [](BlockArgument b) -> Value { return b; });
    Block::BlockArgListType newSharedOutsArgs = newForallOp.getRegionOutArgs();
    unsigned index = 0;
    // Take the new corresponding bbArg if the old bbArg was used as a
    // destination in the in_parallel op. For all other bbArgs, use the
    // corresponding init_arg from the old scf.forall op.
    for (OpResult result : forallOp.getResults()) {
      if (resultToDelete.count(key: result)) {
        // Deleted results map straight to the original init operand.
        newBlockArgs.push_back(Elt: forallOp.getTiedOpOperand(opResult: result)->get());
      } else {
        newBlockArgs.push_back(Elt: newSharedOutsArgs[index++]);
      }
    }
    rewriter.mergeBlocks(source: loopBody, dest: newLoopBody, argValues: newBlockArgs);

    // Step 5. Replace the uses of result of old scf.forall with that of the new
    //         scf.forall.
    for (auto &&[oldResult, newResult] :
         llvm::zip(t&: resultToReplace, u: newForallOp->getResults()))
      rewriter.replaceAllUsesWith(from: oldResult, to: newResult);

    // Step 6. Replace the uses of those values that either have no use or are
    //         not being modified within the loop with the corresponding
    //         OpOperand.
    for (OpResult oldResult : resultToDelete)
      rewriter.replaceAllUsesWith(from: oldResult,
                                  to: forallOp.getTiedOpOperand(opResult: oldResult)->get());
    return success();
  }
};
1732
/// Canonicalize `scf.forall` dimensions with a statically-known trip count of
/// zero or one: a zero-trip loop is replaced by its outputs, single-iteration
/// dimensions are dropped (their induction variable mapped to the lower
/// bound), and a loop whose dimensions are all single-iteration is promoted
/// into the parent block.
struct ForallOpSingleOrZeroIterationDimsFolder
    : public OpRewritePattern<ForallOp> {
  using OpRewritePattern<ForallOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ForallOp op,
                                PatternRewriter &rewriter) const override {
    // Do not fold dimensions if they are mapped to processing units.
    if (op.getMapping().has_value() && !op.getMapping()->empty())
      return failure();
    Location loc = op.getLoc();

    // Compute new loop bounds that omit all single-iteration loop dimensions.
    SmallVector<OpFoldResult> newMixedLowerBounds, newMixedUpperBounds,
        newMixedSteps;
    IRMapping mapping;
    for (auto [lb, ub, step, iv] :
         llvm::zip(t: op.getMixedLowerBound(), u: op.getMixedUpperBound(),
                   args: op.getMixedStep(), args: op.getInductionVars())) {
      auto numIterations = constantTripCount(lb, ub, step);
      if (numIterations.has_value()) {
        // Remove the loop if it performs zero iterations.
        if (*numIterations == 0) {
          rewriter.replaceOp(op, newValues: op.getOutputs());
          return success();
        }
        // Replace the loop induction variable by the lower bound if the loop
        // performs a single iteration. Otherwise, copy the loop bounds.
        if (*numIterations == 1) {
          mapping.map(from: iv, to: getValueOrCreateConstantIndexOp(b&: rewriter, loc, ofr: lb));
          continue;
        }
      }
      newMixedLowerBounds.push_back(Elt: lb);
      newMixedUpperBounds.push_back(Elt: ub);
      newMixedSteps.push_back(Elt: step);
    }

    // All of the loop dimensions perform a single iteration. Inline loop body.
    if (newMixedLowerBounds.empty()) {
      promote(rewriter, forallOp: op);
      return success();
    }

    // Exit if none of the loop dimensions perform a single iteration.
    if (newMixedLowerBounds.size() == static_cast<unsigned>(op.getRank())) {
      return rewriter.notifyMatchFailure(
          arg&: op, msg: "no dimensions have 0 or 1 iterations");
    }

    // Replace the loop by a lower-dimensional loop.
    ForallOp newOp;
    newOp = rewriter.create<ForallOp>(location: loc, args&: newMixedLowerBounds,
                                      args&: newMixedUpperBounds, args&: newMixedSteps,
                                      args: op.getOutputs(), args: std::nullopt, args: nullptr);
    newOp.getBodyRegion().getBlocks().clear();
    // The new loop needs to keep all attributes from the old one, except for
    // "operandSegmentSizes" and static loop bound attributes which capture
    // the outdated information of the old iteration domain.
    SmallVector<StringAttr> elidedAttrs{newOp.getOperandSegmentSizesAttrName(),
                                        newOp.getStaticLowerBoundAttrName(),
                                        newOp.getStaticUpperBoundAttrName(),
                                        newOp.getStaticStepAttrName()};
    for (const auto &namedAttr : op->getAttrs()) {
      if (llvm::is_contained(Range&: elidedAttrs, Element: namedAttr.getName()))
        continue;
      rewriter.modifyOpInPlace(root: newOp, callable: [&]() {
        newOp->setAttr(name: namedAttr.getName(), value: namedAttr.getValue());
      });
    }
    // Clone the body, substituting folded induction variables via `mapping`.
    rewriter.cloneRegionBefore(region&: op.getRegion(), parent&: newOp.getRegion(),
                               before: newOp.getRegion().begin(), mapping);
    rewriter.replaceOp(op, newValues: newOp.getResults());
    return success();
  }
};
1808
1809/// Replace all induction vars with a single trip count with their lower bound.
1810struct ForallOpReplaceConstantInductionVar : public OpRewritePattern<ForallOp> {
1811 using OpRewritePattern<ForallOp>::OpRewritePattern;
1812
1813 LogicalResult matchAndRewrite(ForallOp op,
1814 PatternRewriter &rewriter) const override {
1815 Location loc = op.getLoc();
1816 bool changed = false;
1817 for (auto [lb, ub, step, iv] :
1818 llvm::zip(t: op.getMixedLowerBound(), u: op.getMixedUpperBound(),
1819 args: op.getMixedStep(), args: op.getInductionVars())) {
1820 if (iv.hasNUses(n: 0))
1821 continue;
1822 auto numIterations = constantTripCount(lb, ub, step);
1823 if (!numIterations.has_value() || numIterations.value() != 1) {
1824 continue;
1825 }
1826 rewriter.replaceAllUsesWith(
1827 from: iv, to: getValueOrCreateConstantIndexOp(b&: rewriter, loc, ofr: lb));
1828 changed = true;
1829 }
1830 return success(IsSuccess: changed);
1831 }
1832};
1833
/// Fold a `tensor.cast` feeding an `scf.forall` shared_out into the loop when
/// the cast makes the type "more" static, and cast the loop results back to
/// the original types afterwards.
struct FoldTensorCastOfOutputIntoForallOp
    : public OpRewritePattern<scf::ForallOp> {
  using OpRewritePattern<scf::ForallOp>::OpRewritePattern;

  // Source/destination type pair of a folded cast, keyed by output index.
  struct TypeCast {
    Type srcType;
    Type dstType;
  };

  LogicalResult matchAndRewrite(scf::ForallOp forallOp,
                                PatternRewriter &rewriter) const final {
    llvm::SmallMapVector<unsigned, TypeCast, 2> tensorCastProducers;
    llvm::SmallVector<Value> newOutputTensors = forallOp.getOutputs();
    for (auto en : llvm::enumerate(First&: newOutputTensors)) {
      auto castOp = en.value().getDefiningOp<tensor::CastOp>();
      if (!castOp)
        continue;

      // Only casts that preserve static information, i.e. will make the
      // loop result type "more" static than before, will be folded.
      if (!tensor::preservesStaticInformation(source: castOp.getDest().getType(),
                                              target: castOp.getSource().getType())) {
        continue;
      }

      tensorCastProducers[en.index()] =
          TypeCast{.srcType: castOp.getSource().getType(), .dstType: castOp.getType()};
      newOutputTensors[en.index()] = castOp.getSource();
    }

    if (tensorCastProducers.empty())
      return failure();

    // Create new loop.
    Location loc = forallOp.getLoc();
    auto newForallOp = rewriter.create<ForallOp>(
        location: loc, args: forallOp.getMixedLowerBound(), args: forallOp.getMixedUpperBound(),
        args: forallOp.getMixedStep(), args&: newOutputTensors, args: forallOp.getMapping(),
        args: [&](OpBuilder nestedBuilder, Location nestedLoc, ValueRange bbArgs) {
          // Re-introduce the casts inside the body so the merged old body
          // still sees values of the old (destination) types.
          auto castBlockArgs =
              llvm::to_vector(Range: bbArgs.take_back(n: forallOp->getNumResults()));
          for (auto [index, cast] : tensorCastProducers) {
            Value &oldTypeBBArg = castBlockArgs[index];
            oldTypeBBArg = nestedBuilder.create<tensor::CastOp>(
                location: nestedLoc, args&: cast.dstType, args&: oldTypeBBArg);
          }

          // Move old body into new parallel loop.
          SmallVector<Value> ivsBlockArgs =
              llvm::to_vector(Range: bbArgs.take_front(n: forallOp.getRank()));
          ivsBlockArgs.append(RHS: castBlockArgs);
          rewriter.mergeBlocks(source: forallOp.getBody(),
                               dest: bbArgs.front().getParentBlock(), argValues: ivsBlockArgs);
        });

    // After `mergeBlocks` happened, the destinations in the terminator were
    // mapped to the tensor.cast old-typed results of the output bbArgs. The
    // destination have to be updated to point to the output bbArgs directly.
    auto terminator = newForallOp.getTerminator();
    for (auto [yieldingOp, outputBlockArg] : llvm::zip(
             t: terminator.getYieldingOps(), u: newForallOp.getRegionIterArgs())) {
      auto insertSliceOp = cast<tensor::ParallelInsertSliceOp>(Val&: yieldingOp);
      insertSliceOp.getDestMutable().assign(value: outputBlockArg);
    }

    // Cast results back to the original types.
    rewriter.setInsertionPointAfter(newForallOp);
    SmallVector<Value> castResults = newForallOp.getResults();
    for (auto &item : tensorCastProducers) {
      Value &oldTypeResult = castResults[item.first];
      oldTypeResult = rewriter.create<tensor::CastOp>(location: loc, args&: item.second.dstType,
                                                      args&: oldTypeResult);
    }
    rewriter.replaceOp(op: forallOp, newValues: castResults);
    return success();
  }
};
1911
1912} // namespace
1913
/// Register all scf.forall canonicalization patterns defined above.
void ForallOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                           MLIRContext *context) {
  results.add<DimOfForallOp, FoldTensorCastOfOutputIntoForallOp,
              ForallOpControlOperandsFolder, ForallOpIterArgsFolder,
              ForallOpSingleOrZeroIterationDimsFolder,
              ForallOpReplaceConstantInductionVar>(arg&: context);
}
1921
/// Populate `regions` with the successor regions of this op: control may
/// enter the body region or continue past the op (the loop may execute zero
/// times), and the body branches back to the op itself.
void ForallOp::getSuccessorRegions(RegionBranchPoint point,
                                   SmallVectorImpl<RegionSuccessor> &regions) {
  // Both the operation itself and the region may be branching into the body or
  // back into the operation itself. It is possible for the loop not to enter
  // the body.
  regions.push_back(Elt: RegionSuccessor(&getRegion()));
  regions.push_back(Elt: RegionSuccessor());
}
1935
1936//===----------------------------------------------------------------------===//
1937// InParallelOp
1938//===----------------------------------------------------------------------===//
1939
1940// Build a InParallelOp with mixed static and dynamic entries.
1941void InParallelOp::build(OpBuilder &b, OperationState &result) {
1942 OpBuilder::InsertionGuard g(b);
1943 Region *bodyRegion = result.addRegion();
1944 b.createBlock(parent: bodyRegion);
1945}
1946
1947LogicalResult InParallelOp::verify() {
1948 scf::ForallOp forallOp =
1949 dyn_cast<scf::ForallOp>(Val: getOperation()->getParentOp());
1950 if (!forallOp)
1951 return this->emitOpError(message: "expected forall op parent");
1952
1953 // TODO: InParallelOpInterface.
1954 for (Operation &op : getRegion().front().getOperations()) {
1955 if (!isa<tensor::ParallelInsertSliceOp>(Val: op)) {
1956 return this->emitOpError(message: "expected only ")
1957 << tensor::ParallelInsertSliceOp::getOperationName() << " ops";
1958 }
1959
1960 // Verify that inserts are into out block arguments.
1961 Value dest = cast<tensor::ParallelInsertSliceOp>(Val&: op).getDest();
1962 ArrayRef<BlockArgument> regionOutArgs = forallOp.getRegionOutArgs();
1963 if (!llvm::is_contained(Range&: regionOutArgs, Element: dest))
1964 return op.emitOpError(message: "may only insert into an output block argument");
1965 }
1966 return success();
1967}
1968
1969void InParallelOp::print(OpAsmPrinter &p) {
1970 p << " ";
1971 p.printRegion(blocks&: getRegion(),
1972 /*printEntryBlockArgs=*/false,
1973 /*printBlockTerminators=*/false);
1974 p.printOptionalAttrDict(attrs: getOperation()->getAttrs());
1975}
1976
1977ParseResult InParallelOp::parse(OpAsmParser &parser, OperationState &result) {
1978 auto &builder = parser.getBuilder();
1979
1980 SmallVector<OpAsmParser::Argument, 8> regionOperands;
1981 std::unique_ptr<Region> region = std::make_unique<Region>();
1982 if (parser.parseRegion(region&: *region, arguments: regionOperands))
1983 return failure();
1984
1985 if (region->empty())
1986 OpBuilder(builder.getContext()).createBlock(parent: region.get());
1987 result.addRegion(region: std::move(region));
1988
1989 // Parse the optional attribute list.
1990 if (parser.parseOptionalAttrDict(result&: result.attributes))
1991 return failure();
1992 return success();
1993}
1994
1995OpResult InParallelOp::getParentResult(int64_t idx) {
1996 return getOperation()->getParentOp()->getResult(idx);
1997}
1998
1999SmallVector<BlockArgument> InParallelOp::getDests() {
2000 return llvm::to_vector<4>(
2001 Range: llvm::map_range(C: getYieldingOps(), F: [](Operation &op) {
2002 // Add new ops here as needed.
2003 auto insertSliceOp = cast<tensor::ParallelInsertSliceOp>(Val: &op);
2004 return llvm::cast<BlockArgument>(Val: insertSliceOp.getDest());
2005 }));
2006}
2007
2008llvm::iterator_range<Block::iterator> InParallelOp::getYieldingOps() {
2009 return getRegion().front().getOperations();
2010}
2011
2012//===----------------------------------------------------------------------===//
2013// IfOp
2014//===----------------------------------------------------------------------===//
2015
/// Return true if `a` and `b` sit in mutually exclusive branches (then vs.
/// else) of a common enclosing scf.if; false if no such common scf.if exists.
bool mlir::scf::insideMutuallyExclusiveBranches(Operation *a, Operation *b) {
  assert(a && "expected non-empty operation");
  assert(b && "expected non-empty operation");

  // Walk a's enclosing scf.if ops outward until one also contains b.
  IfOp ifOp = a->getParentOfType<IfOp>();
  while (ifOp) {
    // Check if b is inside ifOp. (We already know that a is.)
    if (ifOp->isProperAncestor(other: b))
      // b is contained in ifOp. a and b are in mutually exclusive branches if
      // they are in different blocks of ifOp.
      return static_cast<bool>(ifOp.thenBlock()->findAncestorOpInBlock(op&: *a)) !=
             static_cast<bool>(ifOp.thenBlock()->findAncestorOpInBlock(op&: *b));
    // Check next enclosing IfOp.
    ifOp = ifOp->getParentOfType<IfOp>();
  }

  // Could not find a common IfOp among a's and b's ancestors.
  return false;
}
2035
2036LogicalResult
2037IfOp::inferReturnTypes(MLIRContext *ctx, std::optional<Location> loc,
2038 IfOp::Adaptor adaptor,
2039 SmallVectorImpl<Type> &inferredReturnTypes) {
2040 if (adaptor.getRegions().empty())
2041 return failure();
2042 Region *r = &adaptor.getThenRegion();
2043 if (r->empty())
2044 return failure();
2045 Block &b = r->front();
2046 if (b.empty())
2047 return failure();
2048 auto yieldOp = llvm::dyn_cast<YieldOp>(Val&: b.back());
2049 if (!yieldOp)
2050 return failure();
2051 TypeRange types = yieldOp.getOperandTypes();
2052 llvm::append_range(C&: inferredReturnTypes, R&: types);
2053 return success();
2054}
2055
2056void IfOp::build(OpBuilder &builder, OperationState &result,
2057 TypeRange resultTypes, Value cond) {
2058 return build(odsBuilder&: builder, odsState&: result, resultTypes, cond, /*addThenBlock=*/false,
2059 /*addElseBlock=*/false);
2060}
2061
2062void IfOp::build(OpBuilder &builder, OperationState &result,
2063 TypeRange resultTypes, Value cond, bool addThenBlock,
2064 bool addElseBlock) {
2065 assert((!addElseBlock || addThenBlock) &&
2066 "must not create else block w/o then block");
2067 result.addTypes(newTypes&: resultTypes);
2068 result.addOperands(newOperands: cond);
2069
2070 // Add regions and blocks.
2071 OpBuilder::InsertionGuard guard(builder);
2072 Region *thenRegion = result.addRegion();
2073 if (addThenBlock)
2074 builder.createBlock(parent: thenRegion);
2075 Region *elseRegion = result.addRegion();
2076 if (addElseBlock)
2077 builder.createBlock(parent: elseRegion);
2078}
2079
2080void IfOp::build(OpBuilder &builder, OperationState &result, Value cond,
2081 bool withElseRegion) {
2082 build(odsBuilder&: builder, odsState&: result, resultTypes: TypeRange{}, cond, withElseRegion);
2083}
2084
/// Build an scf.if with the given result types. The 'then' region always
/// gets an entry block; the 'else' region gets one only when
/// `withElseRegion` is set. In the zero-result case an implicit terminator
/// is inserted into each created block.
void IfOp::build(OpBuilder &builder, OperationState &result,
                 TypeRange resultTypes, Value cond, bool withElseRegion) {
  result.addTypes(newTypes&: resultTypes);
  result.addOperands(newOperands: cond);

  // Build then region.
  OpBuilder::InsertionGuard guard(builder);
  Region *thenRegion = result.addRegion();
  builder.createBlock(parent: thenRegion);
  if (resultTypes.empty())
    IfOp::ensureTerminator(region&: *thenRegion, builder, loc: result.location);

  // Build else region.
  Region *elseRegion = result.addRegion();
  if (withElseRegion) {
    builder.createBlock(parent: elseRegion);
    if (resultTypes.empty())
      IfOp::ensureTerminator(region&: *elseRegion, builder, loc: result.location);
  }
}
2105
/// Build an scf.if from callbacks that populate the 'then' (required) and
/// 'else' (optional) regions; result types are then inferred from the
/// regions when possible.
void IfOp::build(OpBuilder &builder, OperationState &result, Value cond,
                 function_ref<void(OpBuilder &, Location)> thenBuilder,
                 function_ref<void(OpBuilder &, Location)> elseBuilder) {
  assert(thenBuilder && "the builder callback for 'then' must be present");
  result.addOperands(newOperands: cond);

  // Build then region.
  OpBuilder::InsertionGuard guard(builder);
  Region *thenRegion = result.addRegion();
  builder.createBlock(parent: thenRegion);
  thenBuilder(builder, result.location);

  // Build else region.
  Region *elseRegion = result.addRegion();
  if (elseBuilder) {
    builder.createBlock(parent: elseRegion);
    elseBuilder(builder, result.location);
  }

  // Infer result types. Best-effort: when inference fails, the caller is
  // responsible for supplying the result types.
  SmallVector<Type> inferredReturnTypes;
  MLIRContext *ctx = builder.getContext();
  auto attrDict = DictionaryAttr::get(context: ctx, value: result.attributes);
  if (succeeded(Result: inferReturnTypes(context: ctx, location: std::nullopt, operands: result.operands, attributes: attrDict,
                                /*properties=*/nullptr, regions: result.regions,
                                inferredReturnTypes))) {
    result.addTypes(newTypes: inferredReturnTypes);
  }
}
2135
2136LogicalResult IfOp::verify() {
2137 if (getNumResults() != 0 && getElseRegion().empty())
2138 return emitOpError(message: "must have an else block if defining values");
2139 return success();
2140}
2141
/// Parse an scf.if: condition operand, optional `-> (types)` list, the
/// 'then' region, an optional 'else' region, and an optional attribute
/// dictionary. Missing terminators are inserted implicitly.
ParseResult IfOp::parse(OpAsmParser &parser, OperationState &result) {
  // Create the regions for 'then'.
  result.regions.reserve(N: 2);
  Region *thenRegion = result.addRegion();
  Region *elseRegion = result.addRegion();

  auto &builder = parser.getBuilder();
  OpAsmParser::UnresolvedOperand cond;
  // The condition is always parsed at i1 type.
  Type i1Type = builder.getIntegerType(width: 1);
  if (parser.parseOperand(result&: cond) ||
      parser.resolveOperand(operand: cond, type: i1Type, result&: result.operands))
    return failure();
  // Parse optional results type list.
  if (parser.parseOptionalArrowTypeList(result&: result.types))
    return failure();
  // Parse the 'then' region.
  if (parser.parseRegion(region&: *thenRegion, /*arguments=*/{}, /*argTypes=*/enableNameShadowing: {}))
    return failure();
  IfOp::ensureTerminator(region&: *thenRegion, builder&: parser.getBuilder(), loc: result.location);

  // If we find an 'else' keyword then parse the 'else' region.
  if (!parser.parseOptionalKeyword(keyword: "else")) {
    if (parser.parseRegion(region&: *elseRegion, /*arguments=*/{}, /*argTypes=*/enableNameShadowing: {}))
      return failure();
    IfOp::ensureTerminator(region&: *elseRegion, builder&: parser.getBuilder(), loc: result.location);
  }

  // Parse the optional attribute list.
  if (parser.parseOptionalAttrDict(result&: result.attributes))
    return failure();
  return success();
}
2174
/// Print an scf.if. Block terminators are elided unless the op defines
/// values, in which case the yields carry meaningful operands.
void IfOp::print(OpAsmPrinter &p) {
  bool printBlockTerminators = false;

  p << " " << getCondition();
  if (!getResults().empty()) {
    p << " -> (" << getResultTypes() << ")";
    // Print yield explicitly if the op defines values.
    printBlockTerminators = true;
  }
  p << ' ';
  p.printRegion(blocks&: getThenRegion(),
                /*printEntryBlockArgs=*/false,
                /*printBlockTerminators=*/printBlockTerminators);

  // Print the 'else' region if it exists and has a block.
  auto &elseRegion = getElseRegion();
  if (!elseRegion.empty()) {
    p << " else ";
    p.printRegion(blocks&: elseRegion,
                  /*printEntryBlockArgs=*/false,
                  /*printBlockTerminators=*/printBlockTerminators);
  }

  p.printOptionalAttrDict(attrs: (*this)->getAttrs());
}
2200
/// Populate `regions` with the possible control-flow successors: from the
/// parent, the 'then' region plus either the 'else' region or (when it is
/// empty) the point after the op; from inside either region, the op's
/// results.
void IfOp::getSuccessorRegions(RegionBranchPoint point,
                               SmallVectorImpl<RegionSuccessor> &regions) {
  // The `then` and the `else` region branch back to the parent operation.
  if (!point.isParent()) {
    regions.push_back(Elt: RegionSuccessor(getResults()));
    return;
  }

  regions.push_back(Elt: RegionSuccessor(&getThenRegion()));

  // Don't consider the else region if it is empty.
  Region *elseRegion = &this->getElseRegion();
  if (elseRegion->empty())
    regions.push_back(Elt: RegionSuccessor());
  else
    regions.push_back(Elt: RegionSuccessor(elseRegion));
}
2218
/// Like getSuccessorRegions, but uses the (possibly constant-folded)
/// condition in `operands` to prune branches that provably cannot be taken.
void IfOp::getEntrySuccessorRegions(ArrayRef<Attribute> operands,
                                    SmallVectorImpl<RegionSuccessor> &regions) {
  FoldAdaptor adaptor(operands, *this);
  auto boolAttr = dyn_cast_or_null<BoolAttr>(Val: adaptor.getCondition());
  // Unknown or true condition: the then region may run.
  if (!boolAttr || boolAttr.getValue())
    regions.emplace_back(Args: &getThenRegion());

  // If the else region is empty, execution continues after the parent op.
  if (!boolAttr || !boolAttr.getValue()) {
    if (!getElseRegion().empty())
      regions.emplace_back(Args: &getElseRegion());
    else
      regions.emplace_back(Args: getResults());
  }
}
2234
/// Fold `if (xor c, true) then A else B` into `if c then B else A` by
/// updating the condition in place and swapping the two region bodies.
LogicalResult IfOp::fold(FoldAdaptor adaptor,
                         SmallVectorImpl<OpFoldResult> &results) {
  // if (!c) then A() else B() -> if c then B() else A()
  if (getElseRegion().empty())
    return failure();

  arith::XOrIOp xorStmt = getCondition().getDefiningOp<arith::XOrIOp>();
  if (!xorStmt)
    return failure();

  // Only `xor c, 1` is recognized as a negation.
  if (!matchPattern(value: xorStmt.getRhs(), pattern: m_One()))
    return failure();

  getConditionMutable().assign(value: xorStmt.getLhs());
  Block *thenBlock = &getThenRegion().front();
  // It would be nicer to use iplist::swap, but that has no implemented
  // callbacks See: https://llvm.org/doxygen/ilist_8h_source.html#l00224
  getThenRegion().getBlocks().splice(where: getThenRegion().getBlocks().begin(),
                                     L2&: getElseRegion().getBlocks());
  getElseRegion().getBlocks().splice(where: getElseRegion().getBlocks().begin(),
                                     L2&: getThenRegion().getBlocks(), N: thenBlock);
  return success();
}
2258
2259void IfOp::getRegionInvocationBounds(
2260 ArrayRef<Attribute> operands,
2261 SmallVectorImpl<InvocationBounds> &invocationBounds) {
2262 if (auto cond = llvm::dyn_cast_or_null<BoolAttr>(Val: operands[0])) {
2263 // If the condition is known, then one region is known to be executed once
2264 // and the other zero times.
2265 invocationBounds.emplace_back(Args: 0, Args: cond.getValue() ? 1 : 0);
2266 invocationBounds.emplace_back(Args: 0, Args: cond.getValue() ? 0 : 1);
2267 } else {
2268 // Non-constant condition. Each region may be executed 0 or 1 times.
2269 invocationBounds.assign(NumElts: 2, Elt: {0, 1});
2270 }
2271}
2272
2273namespace {
// Pattern to remove unused IfOp results.
struct RemoveUnusedResults : public OpRewritePattern<IfOp> {
  using OpRewritePattern<IfOp>::OpRewritePattern;

  /// Move the contents of `source` into `dest` and shrink the terminator's
  /// operand list to the values feeding `usedResults`.
  void transferBody(Block *source, Block *dest, ArrayRef<OpResult> usedResults,
                    PatternRewriter &rewriter) const {
    // Move all operations to the destination block.
    rewriter.mergeBlocks(source, dest);
    // Replace the yield op by one that returns only the used values.
    auto yieldOp = cast<scf::YieldOp>(Val: dest->getTerminator());
    SmallVector<Value, 4> usedOperands;
    llvm::transform(Range&: usedResults, d_first: std::back_inserter(x&: usedOperands),
                    F: [&](OpResult result) {
                      return yieldOp.getOperand(i: result.getResultNumber());
                    });
    rewriter.modifyOpInPlace(root: yieldOp,
                             callable: [&]() { yieldOp->setOperands(usedOperands); });
  }

  LogicalResult matchAndRewrite(IfOp op,
                                PatternRewriter &rewriter) const override {
    // Compute the list of used results.
    SmallVector<OpResult, 4> usedResults;
    llvm::copy_if(Range: op.getResults(), Out: std::back_inserter(x&: usedResults),
                  P: [](OpResult result) { return !result.use_empty(); });

    // Replace the operation if only a subset of its results have uses.
    if (usedResults.size() == op.getNumResults())
      return failure();

    // Compute the result types of the replacement operation.
    SmallVector<Type, 4> newTypes;
    llvm::transform(Range&: usedResults, d_first: std::back_inserter(x&: newTypes),
                    F: [](OpResult result) { return result.getType(); });

    // Create a replacement operation with empty then and else regions.
    auto newOp =
        rewriter.create<IfOp>(location: op.getLoc(), args&: newTypes, args: op.getCondition());
    rewriter.createBlock(parent: &newOp.getThenRegion());
    rewriter.createBlock(parent: &newOp.getElseRegion());

    // Move the bodies and replace the terminators (note there is a then and
    // an else region since the operation returns results).
    transferBody(source: op.getBody(idx: 0), dest: newOp.getBody(idx: 0), usedResults, rewriter);
    transferBody(source: op.getBody(idx: 1), dest: newOp.getBody(idx: 1), usedResults, rewriter);

    // Replace the operation by the new one. Unused old results are left
    // unmapped (null) in `repResults`; they had no uses by construction.
    SmallVector<Value, 4> repResults(op.getNumResults());
    for (const auto &en : llvm::enumerate(First&: usedResults))
      repResults[en.value().getResultNumber()] = newOp.getResult(i: en.index());
    rewriter.replaceOp(op, newValues: repResults);
    return success();
  }
};
2328
/// Replace an scf.if whose condition is a compile-time constant with the
/// contents of the taken region, or erase the op when the false branch is
/// taken and no else region exists.
struct RemoveStaticCondition : public OpRewritePattern<IfOp> {
  using OpRewritePattern<IfOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(IfOp op,
                                PatternRewriter &rewriter) const override {
    BoolAttr condition;
    if (!matchPattern(value: op.getCondition(), pattern: m_Constant(bind_value: &condition)))
      return failure();

    if (condition.getValue())
      replaceOpWithRegion(rewriter, op, region&: op.getThenRegion());
    else if (!op.getElseRegion().empty())
      replaceOpWithRegion(rewriter, op, region&: op.getElseRegion());
    else
      // False condition with no else region: the op produces nothing.
      rewriter.eraseOp(op);

    return success();
  }
};
2348
/// Hoist any yielded results whose operands are defined outside
/// the if, to a select instruction. Yields whose value is produced inside a
/// branch remain results of a smaller replacement scf.if.
struct ConvertTrivialIfToSelect : public OpRewritePattern<IfOp> {
  using OpRewritePattern<IfOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(IfOp op,
                                PatternRewriter &rewriter) const override {
    if (op->getNumResults() == 0)
      return failure();

    auto cond = op.getCondition();
    auto thenYieldArgs = op.thenYield().getOperands();
    auto elseYieldArgs = op.elseYield().getOperands();

    // Collect the types of yielded values that are defined inside one of the
    // two branches; those cannot be hoisted into a select.
    SmallVector<Type> nonHoistable;
    for (auto [trueVal, falseVal] : llvm::zip(t&: thenYieldArgs, u&: elseYieldArgs)) {
      if (&op.getThenRegion() == trueVal.getParentRegion() ||
          &op.getElseRegion() == falseVal.getParentRegion())
        nonHoistable.push_back(Elt: trueVal.getType());
    }
    // Early exit if there aren't any yielded values we can
    // hoist outside the if.
    if (nonHoistable.size() == op->getNumResults())
      return failure();

    IfOp replacement = rewriter.create<IfOp>(location: op.getLoc(), args&: nonHoistable, args&: cond,
                                             /*withElseRegion=*/args: false);
    if (replacement.thenBlock())
      rewriter.eraseBlock(block: replacement.thenBlock());
    // Steal both branch bodies from the original op; their yields will be
    // rewritten below to match the reduced result list.
    replacement.getThenRegion().takeBody(other&: op.getThenRegion());
    replacement.getElseRegion().takeBody(other&: op.getElseRegion());

    SmallVector<Value> results(op->getNumResults());
    assert(thenYieldArgs.size() == results.size());
    assert(elseYieldArgs.size() == results.size());

    // For each original result: keep it as a result of the replacement if
    // (non-hoistable case), forward the value when both branches agree, or
    // materialize an arith.select on the condition ahead of the op.
    SmallVector<Value> trueYields;
    SmallVector<Value> falseYields;
    rewriter.setInsertionPoint(replacement);
    for (const auto &it :
         llvm::enumerate(First: llvm::zip(t&: thenYieldArgs, u&: elseYieldArgs))) {
      Value trueVal = std::get<0>(t&: it.value());
      Value falseVal = std::get<1>(t&: it.value());
      if (&replacement.getThenRegion() == trueVal.getParentRegion() ||
          &replacement.getElseRegion() == falseVal.getParentRegion()) {
        results[it.index()] = replacement.getResult(i: trueYields.size());
        trueYields.push_back(Elt: trueVal);
        falseYields.push_back(Elt: falseVal);
      } else if (trueVal == falseVal)
        results[it.index()] = trueVal;
      else
        results[it.index()] = rewriter.create<arith::SelectOp>(
            location: op.getLoc(), args&: cond, args&: trueVal, args&: falseVal);
    }

    // Replace the terminators of the moved bodies with yields of only the
    // non-hoistable values, matching the replacement op's result list.
    rewriter.setInsertionPointToEnd(replacement.thenBlock());
    rewriter.replaceOpWithNewOp<YieldOp>(op: replacement.thenYield(), args&: trueYields);

    rewriter.setInsertionPointToEnd(replacement.elseBlock());
    rewriter.replaceOpWithNewOp<YieldOp>(op: replacement.elseYield(), args&: falseYields);

    rewriter.replaceOp(op, newValues: results);
    return success();
  }
};
2414
2415/// Allow the true region of an if to assume the condition is true
2416/// and vice versa. For example:
2417///
2418/// scf.if %cmp {
2419/// print(%cmp)
2420/// }
2421///
2422/// becomes
2423///
2424/// scf.if %cmp {
2425/// print(true)
2426/// }
2427///
struct ConditionPropagation : public OpRewritePattern<IfOp> {
  using OpRewritePattern<IfOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(IfOp op,
                                PatternRewriter &rewriter) const override {
    // Early exit if the condition is constant since replacing a constant
    // in the body with another constant isn't a simplification.
    if (matchPattern(value: op.getCondition(), pattern: m_Constant()))
      return failure();

    bool changed = false;
    mlir::Type i1Ty = rewriter.getI1Type();

    // These variables serve to prevent creating duplicate constants
    // and hold constant true or false values.
    Value constantTrue = nullptr;
    Value constantFalse = nullptr;

    // Walk every use of the condition. Uses nested under the then (resp.
    // else) region may assume the condition is true (resp. false). The
    // early-inc range keeps iteration valid while `use.set` rewrites the
    // underlying use list.
    for (OpOperand &use :
         llvm::make_early_inc_range(Range: op.getCondition().getUses())) {
      if (op.getThenRegion().isAncestor(other: use.getOwner()->getParentRegion())) {
        changed = true;

        if (!constantTrue)
          constantTrue = rewriter.create<arith::ConstantOp>(
              location: op.getLoc(), args&: i1Ty, args: rewriter.getIntegerAttr(type: i1Ty, value: 1));

        rewriter.modifyOpInPlace(root: use.getOwner(),
                                 callable: [&]() { use.set(constantTrue); });
      } else if (op.getElseRegion().isAncestor(
                     other: use.getOwner()->getParentRegion())) {
        changed = true;

        if (!constantFalse)
          constantFalse = rewriter.create<arith::ConstantOp>(
              location: op.getLoc(), args&: i1Ty, args: rewriter.getIntegerAttr(type: i1Ty, value: 0));

        rewriter.modifyOpInPlace(root: use.getOwner(),
                                 callable: [&]() { use.set(constantFalse); });
      }
    }

    return success(IsSuccess: changed);
  }
};
2473
2474/// Remove any statements from an if that are equivalent to the condition
2475/// or its negation. For example:
2476///
2477/// %res:2 = scf.if %cmp {
2478/// yield something(), true
2479/// } else {
2480/// yield something2(), false
2481/// }
2482/// print(%res#1)
2483///
2484/// becomes
2485/// %res = scf.if %cmp {
2486/// yield something()
2487/// } else {
2488/// yield something2()
2489/// }
2490/// print(%cmp)
2491///
2492/// Additionally if both branches yield the same value, replace all uses
2493/// of the result with the yielded value.
2494///
2495/// %res:2 = scf.if %cmp {
2496/// yield something(), %arg1
2497/// } else {
2498/// yield something2(), %arg1
2499/// }
2500/// print(%res#1)
2501///
2502/// becomes
2503/// %res = scf.if %cmp {
2504/// yield something()
2505/// } else {
2506/// yield something2()
2507/// }
2508/// print(%arg1)
2509///
struct ReplaceIfYieldWithConditionOrValue : public OpRewritePattern<IfOp> {
  using OpRewritePattern<IfOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(IfOp op,
                                PatternRewriter &rewriter) const override {
    // Early exit if there are no results that could be replaced.
    if (op.getNumResults() == 0)
      return failure();

    auto trueYield =
        cast<scf::YieldOp>(Val: op.getThenRegion().back().getTerminator());
    auto falseYield =
        cast<scf::YieldOp>(Val: op.getElseRegion().back().getTerminator());

    // New ops (e.g. the xor for the negated-condition case) are created just
    // before the scf.if so they dominate all users of the replaced results.
    rewriter.setInsertionPoint(block: op->getBlock(),
                               insertPoint: op.getOperation()->getIterator());
    bool changed = false;
    Type i1Ty = rewriter.getI1Type();
    for (auto [trueResult, falseResult, opResult] :
         llvm::zip(t: trueYield.getResults(), u: falseYield.getResults(),
                   args: op.getResults())) {
      // Both branches yield the same value: forward it directly.
      if (trueResult == falseResult) {
        if (!opResult.use_empty()) {
          opResult.replaceAllUsesWith(newValue: trueResult);
          changed = true;
        }
        continue;
      }

      // NOTE(review): these BoolAttr locals shadow the YieldOps of the same
      // names declared above; consider renaming if this code is revisited.
      BoolAttr trueYield, falseYield;
      if (!matchPattern(value: trueResult, pattern: m_Constant(bind_value: &trueYield)) ||
          !matchPattern(value: falseResult, pattern: m_Constant(bind_value: &falseYield)))
        continue;

      bool trueVal = trueYield.getValue();
      bool falseVal = falseYield.getValue();
      // Yields false/true: the result equals the negated condition, built as
      // `cond xor 1` using the constant dialect's own materializer.
      if (!trueVal && falseVal) {
        if (!opResult.use_empty()) {
          Dialect *constDialect = trueResult.getDefiningOp()->getDialect();
          Value notCond = rewriter.create<arith::XOrIOp>(
              location: op.getLoc(), args: op.getCondition(),
              args: constDialect
                  ->materializeConstant(builder&: rewriter,
                                        value: rewriter.getIntegerAttr(type: i1Ty, value: 1), type: i1Ty,
                                        loc: op.getLoc())
                  ->getResult(idx: 0));
          opResult.replaceAllUsesWith(newValue: notCond);
          changed = true;
        }
      }
      // Yields true/false: the result equals the condition itself.
      if (trueVal && !falseVal) {
        if (!opResult.use_empty()) {
          opResult.replaceAllUsesWith(newValue: op.getCondition());
          changed = true;
        }
      }
    }
    return success(IsSuccess: changed);
  }
};
2570
2571/// Merge any consecutive scf.if's with the same condition.
2572///
2573/// scf.if %cond {
2574/// firstCodeTrue();...
2575/// } else {
2576/// firstCodeFalse();...
2577/// }
2578/// %res = scf.if %cond {
2579/// secondCodeTrue();...
2580/// } else {
2581/// secondCodeFalse();...
2582/// }
2583///
2584/// becomes
2585/// %res = scf.if %cmp {
2586/// firstCodeTrue();...
2587/// secondCodeTrue();...
2588/// } else {
2589/// firstCodeFalse();...
2590/// secondCodeFalse();...
2591/// }
struct CombineIfs : public OpRewritePattern<IfOp> {
  using OpRewritePattern<IfOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(IfOp nextIf,
                                PatternRewriter &rewriter) const override {
    // Only fires when the immediately preceding op is also an scf.if.
    Block *parent = nextIf->getBlock();
    if (nextIf == &parent->front())
      return failure();

    auto prevIf = dyn_cast<IfOp>(Val: nextIf->getPrevNode());
    if (!prevIf)
      return failure();

    // Determine the logical then/else blocks when prevIf's
    // condition is used. Null means the block does not exist
    // in that case (e.g. empty else). If neither of these
    // are set, the two conditions cannot be compared.
    Block *nextThen = nullptr;
    Block *nextElse = nullptr;
    if (nextIf.getCondition() == prevIf.getCondition()) {
      nextThen = nextIf.thenBlock();
      if (!nextIf.getElseRegion().empty())
        nextElse = nextIf.elseBlock();
    }
    // nextIf's condition is `prevIf's condition xor 1`: branches swap roles.
    if (arith::XOrIOp notv =
            nextIf.getCondition().getDefiningOp<arith::XOrIOp>()) {
      if (notv.getLhs() == prevIf.getCondition() &&
          matchPattern(value: notv.getRhs(), pattern: m_One())) {
        nextElse = nextIf.thenBlock();
        if (!nextIf.getElseRegion().empty())
          nextThen = nextIf.elseBlock();
      }
    }
    // Symmetric case: prevIf's condition is `nextIf's condition xor 1`.
    if (arith::XOrIOp notv =
            prevIf.getCondition().getDefiningOp<arith::XOrIOp>()) {
      if (notv.getLhs() == nextIf.getCondition() &&
          matchPattern(value: notv.getRhs(), pattern: m_One())) {
        nextElse = nextIf.thenBlock();
        if (!nextIf.getElseRegion().empty())
          nextThen = nextIf.elseBlock();
      }
    }

    if (!nextThen && !nextElse)
      return failure();

    SmallVector<Value> prevElseYielded;
    if (!prevIf.getElseRegion().empty())
      prevElseYielded = prevIf.elseYield().getOperands();
    // Replace all uses of return values of op within nextIf with the
    // corresponding yields
    for (auto it : llvm::zip(t: prevIf.getResults(),
                             u: prevIf.thenYield().getOperands(), args&: prevElseYielded))
      for (OpOperand &use :
           llvm::make_early_inc_range(Range: std::get<0>(t&: it).getUses())) {
        if (nextThen && nextThen->getParent()->isAncestor(
                            other: use.getOwner()->getParentRegion())) {
          rewriter.startOpModification(op: use.getOwner());
          use.set(std::get<1>(t&: it));
          rewriter.finalizeOpModification(op: use.getOwner());
        } else if (nextElse && nextElse->getParent()->isAncestor(
                                   other: use.getOwner()->getParentRegion())) {
          rewriter.startOpModification(op: use.getOwner());
          use.set(std::get<2>(t&: it));
          rewriter.finalizeOpModification(op: use.getOwner());
        }
      }

    // The combined if returns prevIf's results followed by nextIf's.
    SmallVector<Type> mergedTypes(prevIf.getResultTypes());
    llvm::append_range(C&: mergedTypes, R: nextIf.getResultTypes());

    IfOp combinedIf = rewriter.create<IfOp>(
        location: nextIf.getLoc(), args&: mergedTypes, args: prevIf.getCondition(), /*hasElse=*/args: false);
    rewriter.eraseBlock(block: &combinedIf.getThenRegion().back());

    rewriter.inlineRegionBefore(region&: prevIf.getThenRegion(),
                                parent&: combinedIf.getThenRegion(),
                                before: combinedIf.getThenRegion().begin());

    if (nextThen) {
      // Append nextIf's then-body after prevIf's and concatenate the yields.
      YieldOp thenYield = combinedIf.thenYield();
      YieldOp thenYield2 = cast<YieldOp>(Val: nextThen->getTerminator());
      rewriter.mergeBlocks(source: nextThen, dest: combinedIf.thenBlock());
      rewriter.setInsertionPointToEnd(combinedIf.thenBlock());

      SmallVector<Value> mergedYields(thenYield.getOperands());
      llvm::append_range(C&: mergedYields, R: thenYield2.getOperands());
      rewriter.create<YieldOp>(location: thenYield2.getLoc(), args&: mergedYields);
      rewriter.eraseOp(op: thenYield);
      rewriter.eraseOp(op: thenYield2);
    }

    rewriter.inlineRegionBefore(region&: prevIf.getElseRegion(),
                                parent&: combinedIf.getElseRegion(),
                                before: combinedIf.getElseRegion().begin());

    if (nextElse) {
      if (combinedIf.getElseRegion().empty()) {
        // prevIf had no else region: adopt nextIf's else region wholesale.
        rewriter.inlineRegionBefore(region&: *nextElse->getParent(),
                                    parent&: combinedIf.getElseRegion(),
                                    before: combinedIf.getElseRegion().begin());
      } else {
        // Both had an else: append nextIf's body and concatenate the yields.
        YieldOp elseYield = combinedIf.elseYield();
        YieldOp elseYield2 = cast<YieldOp>(Val: nextElse->getTerminator());
        rewriter.mergeBlocks(source: nextElse, dest: combinedIf.elseBlock());

        rewriter.setInsertionPointToEnd(combinedIf.elseBlock());

        SmallVector<Value> mergedElseYields(elseYield.getOperands());
        llvm::append_range(C&: mergedElseYields, R: elseYield2.getOperands());

        rewriter.create<YieldOp>(location: elseYield2.getLoc(), args&: mergedElseYields);
        rewriter.eraseOp(op: elseYield);
        rewriter.eraseOp(op: elseYield2);
      }
    }

    // Split the combined results back between the two original ops.
    SmallVector<Value> prevValues;
    SmallVector<Value> nextValues;
    for (const auto &pair : llvm::enumerate(First: combinedIf.getResults())) {
      if (pair.index() < prevIf.getNumResults())
        prevValues.push_back(Elt: pair.value());
      else
        nextValues.push_back(Elt: pair.value());
    }
    rewriter.replaceOp(op: prevIf, newValues: prevValues);
    rewriter.replaceOp(op: nextIf, newValues: nextValues);
    return success();
  }
};
2722
2723/// Pattern to remove an empty else branch.
2724struct RemoveEmptyElseBranch : public OpRewritePattern<IfOp> {
2725 using OpRewritePattern<IfOp>::OpRewritePattern;
2726
2727 LogicalResult matchAndRewrite(IfOp ifOp,
2728 PatternRewriter &rewriter) const override {
2729 // Cannot remove else region when there are operation results.
2730 if (ifOp.getNumResults())
2731 return failure();
2732 Block *elseBlock = ifOp.elseBlock();
2733 if (!elseBlock || !llvm::hasSingleElement(C&: *elseBlock))
2734 return failure();
2735 auto newIfOp = rewriter.cloneWithoutRegions(op: ifOp);
2736 rewriter.inlineRegionBefore(region&: ifOp.getThenRegion(), parent&: newIfOp.getThenRegion(),
2737 before: newIfOp.getThenRegion().begin());
2738 rewriter.eraseOp(op: ifOp);
2739 return success();
2740 }
2741};
2742
2743/// Convert nested `if`s into `arith.andi` + single `if`.
2744///
2745/// scf.if %arg0 {
2746/// scf.if %arg1 {
2747/// ...
2748/// scf.yield
2749/// }
2750/// scf.yield
2751/// }
2752/// becomes
2753///
2754/// %0 = arith.andi %arg0, %arg1
2755/// scf.if %0 {
2756/// ...
2757/// scf.yield
2758/// }
struct CombineNestedIfs : public OpRewritePattern<IfOp> {
  using OpRewritePattern<IfOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(IfOp op,
                                PatternRewriter &rewriter) const override {
    auto nestedOps = op.thenBlock()->without_terminator();
    // Nested `if` must be the only op in block.
    if (!llvm::hasSingleElement(C&: nestedOps))
      return failure();

    // If there is an else block, it can only yield
    if (op.elseBlock() && !llvm::hasSingleElement(C&: *op.elseBlock()))
      return failure();

    auto nestedIf = dyn_cast<IfOp>(Val&: *nestedOps.begin());
    if (!nestedIf)
      return failure();

    // Same restriction for the inner if: its else, if any, may only yield.
    if (nestedIf.elseBlock() && !llvm::hasSingleElement(C&: *nestedIf.elseBlock()))
      return failure();

    SmallVector<Value> thenYield(op.thenYield().getOperands());
    SmallVector<Value> elseYield;
    if (op.elseBlock())
      llvm::append_range(C&: elseYield, R: op.elseYield().getOperands());

    // A list of indices for which we should upgrade the value yielded
    // in the else to a select.
    SmallVector<unsigned> elseYieldsToUpgradeToSelect;

    // If the outer scf.if yields a value produced by the inner scf.if,
    // only permit combining if the value yielded when the condition
    // is false in the outer scf.if is the same value yielded when the
    // inner scf.if condition is false.
    // Note that the array access to elseYield will not go out of bounds
    // since it must have the same length as thenYield, since they both
    // come from the same scf.if.
    for (const auto &tup : llvm::enumerate(First&: thenYield)) {
      if (tup.value().getDefiningOp() == nestedIf) {
        auto nestedIdx = llvm::cast<OpResult>(Val&: tup.value()).getResultNumber();
        if (nestedIf.elseYield().getOperand(i: nestedIdx) !=
            elseYield[tup.index()]) {
          return failure();
        }
        // If the correctness test passes, we will yield
        // corresponding value from the inner scf.if
        thenYield[tup.index()] = nestedIf.thenYield().getOperand(i: nestedIdx);
        continue;
      }

      // Otherwise, we need to ensure the else block of the combined
      // condition still returns the same value when the outer condition is
      // true and the inner condition is false. This can be accomplished if
      // the then value is defined outside the outer scf.if and we replace the
      // value with a select that considers just the outer condition. Since
      // the else region contains just the yield, its yielded value is
      // defined outside the scf.if, by definition.

      // If the then value is defined within the scf.if, bail.
      if (tup.value().getParentRegion() == &op.getThenRegion()) {
        return failure();
      }
      elseYieldsToUpgradeToSelect.push_back(Elt: tup.index());
    }

    Location loc = op.getLoc();
    // The combined condition is the conjunction of the two conditions.
    Value newCondition = rewriter.create<arith::AndIOp>(
        location: loc, args: op.getCondition(), args: nestedIf.getCondition());
    auto newIf = rewriter.create<IfOp>(location: loc, args: op.getResultTypes(), args&: newCondition);
    Block *newIfBlock = rewriter.createBlock(parent: &newIf.getThenRegion());

    SmallVector<Value> results;
    llvm::append_range(C&: results, R: newIf.getResults());
    rewriter.setInsertionPoint(newIf);

    // The selects guard on the outer condition alone, per the reasoning in
    // the loop comment above; they are materialized before the new if.
    for (auto idx : elseYieldsToUpgradeToSelect)
      results[idx] = rewriter.create<arith::SelectOp>(
          location: op.getLoc(), args: op.getCondition(), args&: thenYield[idx], args&: elseYield[idx]);

    rewriter.mergeBlocks(source: nestedIf.thenBlock(), dest: newIfBlock);
    rewriter.setInsertionPointToEnd(newIf.thenBlock());
    rewriter.replaceOpWithNewOp<YieldOp>(op: newIf.thenYield(), args&: thenYield);
    if (!elseYield.empty()) {
      rewriter.createBlock(parent: &newIf.getElseRegion());
      rewriter.setInsertionPointToEnd(newIf.elseBlock());
      rewriter.create<YieldOp>(location: loc, args&: elseYield);
    }
    rewriter.replaceOp(op, newValues: results);
    return success();
  }
};
2850
2851} // namespace
2852
/// Registers all scf.if canonicalization patterns defined in the anonymous
/// namespace above.
void IfOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                       MLIRContext *context) {
  results.add<CombineIfs, CombineNestedIfs, ConditionPropagation,
              ConvertTrivialIfToSelect, RemoveEmptyElseBranch,
              RemoveStaticCondition, RemoveUnusedResults,
              ReplaceIfYieldWithConditionOrValue>(arg&: context);
}
2860
// Accessors for the branch blocks and their yield terminators. The then
// block always exists; elseBlock returns nullptr when the else region is
// empty, whereas elseYield dereferences the else block and therefore
// requires it to exist.
Block *IfOp::thenBlock() { return &getThenRegion().back(); }
YieldOp IfOp::thenYield() { return cast<YieldOp>(Val: &thenBlock()->back()); }
Block *IfOp::elseBlock() {
  Region &r = getElseRegion();
  if (r.empty())
    return nullptr;
  return &r.back();
}
YieldOp IfOp::elseYield() { return cast<YieldOp>(Val: &elseBlock()->back()); }
2870
2871//===----------------------------------------------------------------------===//
2872// ParallelOp
2873//===----------------------------------------------------------------------===//
2874
/// Builds an scf.parallel op. `initVals` determine the op's result types;
/// `bodyBuilderFn`, when provided, is invoked with the first `numIVs` block
/// arguments (the induction variables) and the remaining block arguments.
void ParallelOp::build(
    OpBuilder &builder, OperationState &result, ValueRange lowerBounds,
    ValueRange upperBounds, ValueRange steps, ValueRange initVals,
    function_ref<void(OpBuilder &, Location, ValueRange, ValueRange)>
        bodyBuilderFn) {
  result.addOperands(newOperands: lowerBounds);
  result.addOperands(newOperands: upperBounds);
  result.addOperands(newOperands: steps);
  result.addOperands(newOperands: initVals);
  // Record how the flat operand list splits into the four variadic segments.
  result.addAttribute(
      name: ParallelOp::getOperandSegmentSizeAttr(),
      attr: builder.getDenseI32ArrayAttr(values: {static_cast<int32_t>(lowerBounds.size()),
                                    static_cast<int32_t>(upperBounds.size()),
                                    static_cast<int32_t>(steps.size()),
                                    static_cast<int32_t>(initVals.size())}));
  result.addTypes(newTypes: initVals.getTypes());

  // The guard restores the builder's insertion point after the body is built.
  OpBuilder::InsertionGuard guard(builder);
  unsigned numIVs = steps.size();
  SmallVector<Type, 8> argTypes(numIVs, builder.getIndexType());
  SmallVector<Location, 8> argLocs(numIVs, result.location);
  Region *bodyRegion = result.addRegion();
  Block *bodyBlock = builder.createBlock(parent: bodyRegion, insertPt: {}, argTypes, locs: argLocs);

  if (bodyBuilderFn) {
    builder.setInsertionPointToStart(bodyBlock);
    bodyBuilderFn(builder, result.location,
                  bodyBlock->getArguments().take_front(N: numIVs),
                  bodyBlock->getArguments().drop_front(N: numIVs));
  }
  // Add terminator only if there are no reductions.
  if (initVals.empty())
    ParallelOp::ensureTerminator(region&: *bodyRegion, builder, loc: result.location);
}
2909
/// Convenience builder without init values: adapts the 3-argument body
/// builder to the full 4-argument form and delegates to the builder above.
void ParallelOp::build(
    OpBuilder &builder, OperationState &result, ValueRange lowerBounds,
    ValueRange upperBounds, ValueRange steps,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn) {
  // Only pass a non-null wrapper if bodyBuilderFn is non-null itself. Make sure
  // we don't capture a reference to a temporary by constructing the lambda at
  // function level.
  auto wrappedBuilderFn = [&bodyBuilderFn](OpBuilder &nestedBuilder,
                                           Location nestedLoc, ValueRange ivs,
                                           ValueRange) {
    bodyBuilderFn(nestedBuilder, nestedLoc, ivs);
  };
  // function_ref does not own its callee, so `wrappedBuilderFn` must outlive
  // the `build` call below — hence the named local above.
  function_ref<void(OpBuilder &, Location, ValueRange, ValueRange)> wrapper;
  if (bodyBuilderFn)
    wrapper = wrappedBuilderFn;

  build(builder, result, lowerBounds, upperBounds, steps, initVals: ValueRange(),
        bodyBuilderFn: wrapper);
}
2929
/// Verifies scf.parallel invariants: non-empty bound/step lists, positive
/// constant steps, index-typed induction variables matching the bounds in
/// number, an scf.reduce terminator, and agreement between the op's results,
/// the reductions, and the init values (in count and type).
LogicalResult ParallelOp::verify() {
  // Check that there is at least one value in lowerBound, upperBound and step.
  // It is sufficient to test only step, because it is ensured already that the
  // number of elements in lowerBound, upperBound and step are the same.
  Operation::operand_range stepValues = getStep();
  if (stepValues.empty())
    return emitOpError(
        message: "needs at least one tuple element for lowerBound, upperBound and step");

  // Check whether all constant step values are positive.
  for (Value stepValue : stepValues)
    if (auto cst = getConstantIntValue(ofr: stepValue))
      if (*cst <= 0)
        return emitOpError(message: "constant step operand must be positive");

  // Check that the body defines the same number of block arguments as the
  // number of tuple elements in step.
  Block *body = getBody();
  if (body->getNumArguments() != stepValues.size())
    return emitOpError() << "expects the same number of induction variables: "
                         << body->getNumArguments()
                         << " as bound and step values: " << stepValues.size();
  for (auto arg : body->getArguments())
    if (!arg.getType().isIndex())
      return emitOpError(
          message: "expects arguments for the induction variable to be of index type");

  // Check that the terminator is an scf.reduce op.
  auto reduceOp = verifyAndGetTerminator<scf::ReduceOp>(
      op: *this, region&: getRegion(), errorMessage: "expects body to terminate with 'scf.reduce'");
  if (!reduceOp)
    return failure();

  // Check that the number of results is the same as the number of reductions.
  auto resultsSize = getResults().size();
  auto reductionsSize = reduceOp.getReductions().size();
  auto initValsSize = getInitVals().size();
  if (resultsSize != reductionsSize)
    return emitOpError() << "expects number of results: " << resultsSize
                         << " to be the same as number of reductions: "
                         << reductionsSize;
  if (resultsSize != initValsSize)
    return emitOpError() << "expects number of results: " << resultsSize
                         << " to be the same as number of initial values: "
                         << initValsSize;

  // Check that the types of the results and reductions are the same.
  for (int64_t i = 0; i < static_cast<int64_t>(reductionsSize); ++i) {
    auto resultType = getOperation()->getResult(idx: i).getType();
    auto reductionOperandType = reduceOp.getOperands()[i].getType();
    if (resultType != reductionOperandType)
      return reduceOp.emitOpError()
             << "expects type of " << i
             << "-th reduction operand: " << reductionOperandType
             << " to be the same as the " << i
             << "-th result type: " << resultType;
  }
  return success();
}
2989
/// Parses the custom form:
///   `(` ivs `) = (` lower `) to (` upper `) step (` steps `)`
///   [`init (` init-vals `)`] [`->` result-types] region [attr-dict]
/// Mirrored by ParallelOp::print below.
ParseResult ParallelOp::parse(OpAsmParser &parser, OperationState &result) {
  auto &builder = parser.getBuilder();
  // Parse an opening `(` followed by induction variables followed by `)`
  SmallVector<OpAsmParser::Argument, 4> ivs;
  if (parser.parseArgumentList(result&: ivs, delimiter: OpAsmParser::Delimiter::Paren))
    return failure();

  // Parse loop bounds.
  SmallVector<OpAsmParser::UnresolvedOperand, 4> lower;
  if (parser.parseEqual() ||
      parser.parseOperandList(result&: lower, requiredOperandCount: ivs.size(),
                              delimiter: OpAsmParser::Delimiter::Paren) ||
      parser.resolveOperands(operands&: lower, type: builder.getIndexType(), result&: result.operands))
    return failure();

  SmallVector<OpAsmParser::UnresolvedOperand, 4> upper;
  if (parser.parseKeyword(keyword: "to") ||
      parser.parseOperandList(result&: upper, requiredOperandCount: ivs.size(),
                              delimiter: OpAsmParser::Delimiter::Paren) ||
      parser.resolveOperands(operands&: upper, type: builder.getIndexType(), result&: result.operands))
    return failure();

  // Parse step values.
  SmallVector<OpAsmParser::UnresolvedOperand, 4> steps;
  if (parser.parseKeyword(keyword: "step") ||
      parser.parseOperandList(result&: steps, requiredOperandCount: ivs.size(),
                              delimiter: OpAsmParser::Delimiter::Paren) ||
      parser.resolveOperands(operands&: steps, type: builder.getIndexType(), result&: result.operands))
    return failure();

  // Parse init values.
  SmallVector<OpAsmParser::UnresolvedOperand, 4> initVals;
  if (succeeded(Result: parser.parseOptionalKeyword(keyword: "init"))) {
    if (parser.parseOperandList(result&: initVals, delimiter: OpAsmParser::Delimiter::Paren))
      return failure();
  }

  // Parse optional results in case there is a reduce.
  if (parser.parseOptionalArrowTypeList(result&: result.types))
    return failure();

  // Now parse the body.
  Region *body = result.addRegion();
  // Induction variables are always of index type.
  for (auto &iv : ivs)
    iv.type = builder.getIndexType();
  if (parser.parseRegion(region&: *body, arguments: ivs))
    return failure();

  // Set `operandSegmentSizes` attribute.
  result.addAttribute(
      name: ParallelOp::getOperandSegmentSizeAttr(),
      attr: builder.getDenseI32ArrayAttr(values: {static_cast<int32_t>(lower.size()),
                                    static_cast<int32_t>(upper.size()),
                                    static_cast<int32_t>(steps.size()),
                                    static_cast<int32_t>(initVals.size())}));

  // Parse attributes.
  // Init values are resolved against the parsed result types, since each
  // init value corresponds to one result.
  if (parser.parseOptionalAttrDict(result&: result.attributes) ||
      parser.resolveOperands(operands&: initVals, types&: result.types, loc: parser.getNameLoc(),
                             result&: result.operands))
    return failure();

  // Add a terminator if none was parsed.
  ParallelOp::ensureTerminator(region&: *body, builder, loc: result.location);
  return success();
}
3056
/// Prints the custom assembly form; inverse of ParallelOp::parse above.
void ParallelOp::print(OpAsmPrinter &p) {
  p << " (" << getBody()->getArguments() << ") = (" << getLowerBound()
    << ") to (" << getUpperBound() << ") step (" << getStep() << ")";
  if (!getInitVals().empty())
    p << " init (" << getInitVals() << ")";
  p.printOptionalArrowTypeList(types: getResultTypes());
  p << ' ';
  p.printRegion(blocks&: getRegion(), /*printEntryBlockArgs=*/false);
  // The segment-size attribute is implied by the printed operand lists, so
  // elide it from the attribute dictionary.
  p.printOptionalAttrDict(
      attrs: (*this)->getAttrs(),
      /*elidedAttrs=*/ParallelOp::getOperandSegmentSizeAttr());
}
3069
// Loop-information accessors — presumably implementing a loop-like op
// interface (TODO confirm against the interface declaration). The bounds and
// steps are returned as the op's operand ranges, implicitly converted to
// OpFoldResult vectors.
SmallVector<Region *> ParallelOp::getLoopRegions() { return {&getRegion()}; }

std::optional<SmallVector<Value>> ParallelOp::getLoopInductionVars() {
  return SmallVector<Value>{getBody()->getArguments()};
}

std::optional<SmallVector<OpFoldResult>> ParallelOp::getLoopLowerBounds() {
  return getLowerBound();
}

std::optional<SmallVector<OpFoldResult>> ParallelOp::getLoopUpperBounds() {
  return getUpperBound();
}

std::optional<SmallVector<OpFoldResult>> ParallelOp::getLoopSteps() {
  return getStep();
}
3087
3088ParallelOp mlir::scf::getParallelForInductionVarOwner(Value val) {
3089 auto ivArg = llvm::dyn_cast<BlockArgument>(Val&: val);
3090 if (!ivArg)
3091 return ParallelOp();
3092 assert(ivArg.getOwner() && "unlinked block argument");
3093 auto *containingOp = ivArg.getOwner()->getParentOp();
3094 return dyn_cast<ParallelOp>(Val: containingOp);
3095}
3096
3097namespace {
// Collapse loop dimensions that perform a single iteration, and fold away
// loops that are statically known to run zero times or exactly once.
struct ParallelOpSingleOrZeroIterationDimsFolder
    : public OpRewritePattern<ParallelOp> {
  using OpRewritePattern<ParallelOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ParallelOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();

    // Compute new loop bounds that omit all single-iteration loop dimensions.
    SmallVector<Value> newLowerBounds, newUpperBounds, newSteps;
    IRMapping mapping;
    for (auto [lb, ub, step, iv] :
         llvm::zip(t: op.getLowerBound(), u: op.getUpperBound(), args: op.getStep(),
                   args: op.getInductionVars())) {
      auto numIterations = constantTripCount(lb, ub, step);
      if (numIterations.has_value()) {
        // Remove the loop if it performs zero iterations.
        if (*numIterations == 0) {
          // A zero-trip loop simply forwards its init values as results.
          rewriter.replaceOp(op, newValues: op.getInitVals());
          return success();
        }
        // Replace the loop induction variable by the lower bound if the loop
        // performs a single iteration. Otherwise, copy the loop bounds.
        if (*numIterations == 1) {
          mapping.map(from: iv, to: getValueOrCreateConstantIndexOp(b&: rewriter, loc, ofr: lb));
          continue;
        }
      }
      newLowerBounds.push_back(Elt: lb);
      newUpperBounds.push_back(Elt: ub);
      newSteps.push_back(Elt: step);
    }
    // Exit if none of the loop dimensions perform a single iteration.
    if (newLowerBounds.size() == op.getLowerBound().size())
      return failure();

    if (newLowerBounds.empty()) {
      // All of the loop dimensions perform a single iteration. Inline
      // loop body and nested ReduceOp's
      SmallVector<Value> results;
      results.reserve(N: op.getInitVals().size());
      for (auto &bodyOp : op.getBody()->without_terminator())
        rewriter.clone(op&: bodyOp, mapper&: mapping);
      auto reduceOp = cast<ReduceOp>(Val: op.getBody()->getTerminator());
      for (int64_t i = 0, e = reduceOp.getReductions().size(); i < e; ++i) {
        Block &reduceBlock = reduceOp.getReductions()[i].front();
        auto initValIndex = results.size();
        // Evaluate the i-th reduction body once, with the init value bound to
        // the accumulator argument and the (cloned) reduced value bound to
        // the operand argument.
        mapping.map(from: reduceBlock.getArgument(i: 0), to: op.getInitVals()[initValIndex]);
        mapping.map(from: reduceBlock.getArgument(i: 1),
                    to: mapping.lookupOrDefault(from: reduceOp.getOperands()[i]));
        for (auto &reduceBodyOp : reduceBlock.without_terminator())
          rewriter.clone(op&: reduceBodyOp, mapper&: mapping);

        auto result = mapping.lookupOrDefault(
            from: cast<ReduceReturnOp>(Val: reduceBlock.getTerminator()).getResult());
        results.push_back(Elt: result);
      }

      rewriter.replaceOp(op, newValues: results);
      return success();
    }
    // Replace the parallel loop by lower-dimensional parallel loop.
    auto newOp =
        rewriter.create<ParallelOp>(location: op.getLoc(), args&: newLowerBounds, args&: newUpperBounds,
                                    args&: newSteps, args: op.getInitVals(), args: nullptr);
    // Erase the empty block that was inserted by the builder.
    rewriter.eraseBlock(block: newOp.getBody());
    // Clone the loop body and remap the block arguments of the collapsed loops
    // (inlining does not support a cancellable block argument mapping).
    rewriter.cloneRegionBefore(region&: op.getRegion(), parent&: newOp.getRegion(),
                               before: newOp.getRegion().begin(), mapping);
    rewriter.replaceOp(op, newValues: newOp.getResults());
    return success();
  }
};
3174
/// Fuses a perfectly nested pair of `scf.parallel` ops into a single
/// multi-dimensional `scf.parallel` whose iteration space is the outer space
/// concatenated with the inner one. Applies only when the outer body contains
/// nothing but the inner loop, the inner bounds/steps do not depend on the
/// outer induction variables, and neither loop carries reductions.
struct MergeNestedParallelLoops : public OpRewritePattern<ParallelOp> {
  using OpRewritePattern<ParallelOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ParallelOp op,
                                PatternRewriter &rewriter) const override {
    Block &outerBody = *op.getBody();
    // Outer body must consist solely of the inner loop (plus terminator).
    if (!llvm::hasSingleElement(C: outerBody.without_terminator()))
      return failure();

    auto innerOp = dyn_cast<ParallelOp>(Val&: outerBody.front());
    if (!innerOp)
      return failure();

    // The inner bounds/steps must not use the outer induction variables;
    // otherwise they cannot be hoisted into the fused iteration space.
    for (auto val : outerBody.getArguments())
      if (llvm::is_contained(Range: innerOp.getLowerBound(), Element: val) ||
          llvm::is_contained(Range: innerOp.getUpperBound(), Element: val) ||
          llvm::is_contained(Range: innerOp.getStep(), Element: val))
        return failure();

    // Reductions are not supported yet.
    if (!op.getInitVals().empty() || !innerOp.getInitVals().empty())
      return failure();

    // Clones the inner body into the fused loop. The fused induction
    // variables are the outer ones followed by the inner ones.
    auto bodyBuilder = [&](OpBuilder &builder, Location /*loc*/,
                           ValueRange iterVals, ValueRange) {
      Block &innerBody = *innerOp.getBody();
      assert(iterVals.size() ==
             (outerBody.getNumArguments() + innerBody.getNumArguments()));
      IRMapping mapping;
      mapping.map(from: outerBody.getArguments(),
                  to: iterVals.take_front(n: outerBody.getNumArguments()));
      mapping.map(from: innerBody.getArguments(),
                  to: iterVals.take_back(n: innerBody.getNumArguments()));
      for (Operation &op : innerBody.without_terminator())
        builder.clone(op, mapper&: mapping);
    };

    // Concatenates two value ranges into a fresh vector.
    auto concatValues = [](const auto &first, const auto &second) {
      SmallVector<Value> ret;
      ret.reserve(N: first.size() + second.size());
      ret.assign(first.begin(), first.end());
      ret.append(second.begin(), second.end());
      return ret;
    };

    auto newLowerBounds =
        concatValues(op.getLowerBound(), innerOp.getLowerBound());
    auto newUpperBounds =
        concatValues(op.getUpperBound(), innerOp.getUpperBound());
    auto newSteps = concatValues(op.getStep(), innerOp.getStep());

    rewriter.replaceOpWithNewOp<ParallelOp>(op, args&: newLowerBounds, args&: newUpperBounds,
                                            args&: newSteps, args: ValueRange(),
                                            args&: bodyBuilder);
    return success();
  }
};
3232
3233} // namespace
3234
3235void ParallelOp::getCanonicalizationPatterns(RewritePatternSet &results,
3236 MLIRContext *context) {
3237 results
3238 .add<ParallelOpSingleOrZeroIterationDimsFolder, MergeNestedParallelLoops>(
3239 arg&: context);
3240}
3241
/// Given a region branch point (a region of this op, or the parent operation
/// itself), return the successor regions: the regions that may be selected
/// during the flow of control.
3247void ParallelOp::getSuccessorRegions(
3248 RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
3249 // Both the operation itself and the region may be branching into the body or
3250 // back into the operation itself. It is possible for loop not to enter the
3251 // body.
3252 regions.push_back(Elt: RegionSuccessor(&getRegion()));
3253 regions.push_back(Elt: RegionSuccessor());
3254}
3255
3256//===----------------------------------------------------------------------===//
3257// ReduceOp
3258//===----------------------------------------------------------------------===//
3259
3260void ReduceOp::build(OpBuilder &builder, OperationState &result) {}
3261
3262void ReduceOp::build(OpBuilder &builder, OperationState &result,
3263 ValueRange operands) {
3264 result.addOperands(newOperands: operands);
3265 for (Value v : operands) {
3266 OpBuilder::InsertionGuard guard(builder);
3267 Region *bodyRegion = result.addRegion();
3268 builder.createBlock(parent: bodyRegion, insertPt: {},
3269 argTypes: ArrayRef<Type>{v.getType(), v.getType()},
3270 locs: {result.location, result.location});
3271 }
3272}
3273
3274LogicalResult ReduceOp::verifyRegions() {
3275 // The region of a ReduceOp has two arguments of the same type as its
3276 // corresponding operand.
3277 for (int64_t i = 0, e = getReductions().size(); i < e; ++i) {
3278 auto type = getOperands()[i].getType();
3279 Block &block = getReductions()[i].front();
3280 if (block.empty())
3281 return emitOpError() << i << "-th reduction has an empty body";
3282 if (block.getNumArguments() != 2 ||
3283 llvm::any_of(Range: block.getArguments(), P: [&](const BlockArgument &arg) {
3284 return arg.getType() != type;
3285 }))
3286 return emitOpError() << "expected two block arguments with type " << type
3287 << " in the " << i << "-th reduction region";
3288
3289 // Check that the block is terminated by a ReduceReturnOp.
3290 if (!isa<ReduceReturnOp>(Val: block.getTerminator()))
3291 return emitOpError(message: "reduction bodies must be terminated with an "
3292 "'scf.reduce.return' op");
3293 }
3294
3295 return success();
3296}
3297
3298MutableOperandRange
3299ReduceOp::getMutableSuccessorOperands(RegionBranchPoint point) {
3300 // No operands are forwarded to the next iteration.
3301 return MutableOperandRange(getOperation(), /*start=*/0, /*length=*/0);
3302}
3303
3304//===----------------------------------------------------------------------===//
3305// ReduceReturnOp
3306//===----------------------------------------------------------------------===//
3307
3308LogicalResult ReduceReturnOp::verify() {
3309 // The type of the return value should be the same type as the types of the
3310 // block arguments of the reduction body.
3311 Block *reductionBody = getOperation()->getBlock();
3312 // Should already be verified by an op trait.
3313 assert(isa<ReduceOp>(reductionBody->getParentOp()) && "expected scf.reduce");
3314 Type expectedResultType = reductionBody->getArgument(i: 0).getType();
3315 if (expectedResultType != getResult().getType())
3316 return emitOpError() << "must have type " << expectedResultType
3317 << " (the type of the reduction inputs)";
3318 return success();
3319}
3320
3321//===----------------------------------------------------------------------===//
3322// WhileOp
3323//===----------------------------------------------------------------------===//
3324
3325void WhileOp::build(::mlir::OpBuilder &odsBuilder,
3326 ::mlir::OperationState &odsState, TypeRange resultTypes,
3327 ValueRange inits, BodyBuilderFn beforeBuilder,
3328 BodyBuilderFn afterBuilder) {
3329 odsState.addOperands(newOperands: inits);
3330 odsState.addTypes(newTypes&: resultTypes);
3331
3332 OpBuilder::InsertionGuard guard(odsBuilder);
3333
3334 // Build before region.
3335 SmallVector<Location, 4> beforeArgLocs;
3336 beforeArgLocs.reserve(N: inits.size());
3337 for (Value operand : inits) {
3338 beforeArgLocs.push_back(Elt: operand.getLoc());
3339 }
3340
3341 Region *beforeRegion = odsState.addRegion();
3342 Block *beforeBlock = odsBuilder.createBlock(parent: beforeRegion, /*insertPt=*/{},
3343 argTypes: inits.getTypes(), locs: beforeArgLocs);
3344 if (beforeBuilder)
3345 beforeBuilder(odsBuilder, odsState.location, beforeBlock->getArguments());
3346
3347 // Build after region.
3348 SmallVector<Location, 4> afterArgLocs(resultTypes.size(), odsState.location);
3349
3350 Region *afterRegion = odsState.addRegion();
3351 Block *afterBlock = odsBuilder.createBlock(parent: afterRegion, /*insertPt=*/{},
3352 argTypes: resultTypes, locs: afterArgLocs);
3353
3354 if (afterBuilder)
3355 afterBuilder(odsBuilder, odsState.location, afterBlock->getArguments());
3356}
3357
3358ConditionOp WhileOp::getConditionOp() {
3359 return cast<ConditionOp>(Val: getBeforeBody()->getTerminator());
3360}
3361
3362YieldOp WhileOp::getYieldOp() {
3363 return cast<YieldOp>(Val: getAfterBody()->getTerminator());
3364}
3365
3366std::optional<MutableArrayRef<OpOperand>> WhileOp::getYieldedValuesMutable() {
3367 return getYieldOp().getResultsMutable();
3368}
3369
3370Block::BlockArgListType WhileOp::getBeforeArguments() {
3371 return getBeforeBody()->getArguments();
3372}
3373
3374Block::BlockArgListType WhileOp::getAfterArguments() {
3375 return getAfterBody()->getArguments();
3376}
3377
3378Block::BlockArgListType WhileOp::getRegionIterArgs() {
3379 return getBeforeArguments();
3380}
3381
3382OperandRange WhileOp::getEntrySuccessorOperands(RegionBranchPoint point) {
3383 assert(point == getBefore() &&
3384 "WhileOp is expected to branch only to the first region");
3385 return getInits();
3386}
3387
3388void WhileOp::getSuccessorRegions(RegionBranchPoint point,
3389 SmallVectorImpl<RegionSuccessor> &regions) {
3390 // The parent op always branches to the condition region.
3391 if (point.isParent()) {
3392 regions.emplace_back(Args: &getBefore(), Args: getBefore().getArguments());
3393 return;
3394 }
3395
3396 assert(llvm::is_contained({&getAfter(), &getBefore()}, point) &&
3397 "there are only two regions in a WhileOp");
3398 // The body region always branches back to the condition region.
3399 if (point == getAfter()) {
3400 regions.emplace_back(Args: &getBefore(), Args: getBefore().getArguments());
3401 return;
3402 }
3403
3404 regions.emplace_back(Args: getResults());
3405 regions.emplace_back(Args: &getAfter(), Args: getAfter().getArguments());
3406}
3407
3408SmallVector<Region *> WhileOp::getLoopRegions() {
3409 return {&getBefore(), &getAfter()};
3410}
3411
/// Parses a `while` op.
///
/// op ::= `scf.while` assignments `:` function-type region `do` region
///        `attributes` attribute-dict
/// assignments ::= /* empty */ | `(` assignment-list `)`
/// assignment-list ::= assignment | assignment `,` assignment-list
/// assignment ::= ssa-value `=` ssa-value
ParseResult scf::WhileOp::parse(OpAsmParser &parser, OperationState &result) {
  SmallVector<OpAsmParser::Argument, 4> regionArgs;
  SmallVector<OpAsmParser::UnresolvedOperand, 4> operands;
  Region *before = result.addRegion();
  Region *after = result.addRegion();

  // Parse the optional `(%arg = %init, ...)` list; absence means the while
  // op carries no iteration arguments.
  OptionalParseResult listResult =
      parser.parseOptionalAssignmentList(lhs&: regionArgs, rhs&: operands);
  if (listResult.has_value() && failed(Result: listResult.value()))
    return failure();

  // The functional type supplies both the init operand types (inputs) and
  // the op's result types (results).
  FunctionType functionType;
  SMLoc typeLoc = parser.getCurrentLocation();
  if (failed(Result: parser.parseColonType(result&: functionType)))
    return failure();

  result.addTypes(newTypes: functionType.getResults());

  if (functionType.getNumInputs() != operands.size()) {
    return parser.emitError(loc: typeLoc)
           << "expected as many input types as operands "
           << "(expected " << operands.size() << " got "
           << functionType.getNumInputs() << ")";
  }

  // Resolve input operands.
  if (failed(Result: parser.resolveOperands(operands, types: functionType.getInputs(),
                                     loc: parser.getCurrentLocation(),
                                     result&: result.operands)))
    return failure();

  // Propagate the types into the region arguments.
  for (size_t i = 0, e = regionArgs.size(); i != e; ++i)
    regionArgs[i].type = functionType.getInput(i);

  // Parse the "before" region (reusing the typed assignment-list arguments),
  // the `do` keyword, the "after" region, and an optional trailing attribute
  // dictionary.
  return failure(IsFailure: parser.parseRegion(region&: *before, arguments: regionArgs) ||
                 parser.parseKeyword(keyword: "do") || parser.parseRegion(region&: *after) ||
                 parser.parseOptionalAttrDictWithKeyword(result&: result.attributes));
}
3458
3459/// Prints a `while` op.
/// Prints a `while` op.
void scf::WhileOp::print(OpAsmPrinter &p) {
  // `(%arg = %init, ...)` — omitted when there are no iteration arguments.
  printInitializationList(p, blocksArgs: getBeforeArguments(), initializers: getInits(), prefix: " ");
  p << " : ";
  // Functional type: init operand types -> result types.
  p.printFunctionalType(inputs: getInits().getTypes(), results: getResults().getTypes());
  p << ' ';
  // Entry block args were already printed by the initialization list above.
  p.printRegion(blocks&: getBefore(), /*printEntryBlockArgs=*/false);
  p << " do ";
  p.printRegion(blocks&: getAfter());
  p.printOptionalAttrDictWithKeyword(attrs: (*this)->getAttrs());
}
3470
3471/// Verifies that two ranges of types match, i.e. have the same number of
3472/// entries and that types are pairwise equals. Reports errors on the given
3473/// operation in case of mismatch.
3474template <typename OpTy>
3475static LogicalResult verifyTypeRangesMatch(OpTy op, TypeRange left,
3476 TypeRange right, StringRef message) {
3477 if (left.size() != right.size())
3478 return op.emitOpError("expects the same number of ") << message;
3479
3480 for (unsigned i = 0, e = left.size(); i < e; ++i) {
3481 if (left[i] != right[i]) {
3482 InFlightDiagnostic diag = op.emitOpError("expects the same types for ")
3483 << message;
3484 diag.attachNote() << "for argument " << i << ", found " << left[i]
3485 << " and " << right[i];
3486 return diag;
3487 }
3488 }
3489
3490 return success();
3491}
3492
3493LogicalResult scf::WhileOp::verify() {
3494 auto beforeTerminator = verifyAndGetTerminator<scf::ConditionOp>(
3495 op: *this, region&: getBefore(),
3496 errorMessage: "expects the 'before' region to terminate with 'scf.condition'");
3497 if (!beforeTerminator)
3498 return failure();
3499
3500 auto afterTerminator = verifyAndGetTerminator<scf::YieldOp>(
3501 op: *this, region&: getAfter(),
3502 errorMessage: "expects the 'after' region to terminate with 'scf.yield'");
3503 return success(IsSuccess: afterTerminator != nullptr);
3504}
3505
3506namespace {
3507/// Replace uses of the condition within the do block with true, since otherwise
3508/// the block would not be evaluated.
3509///
3510/// scf.while (..) : (i1, ...) -> ... {
3511/// %condition = call @evaluate_condition() : () -> i1
3512/// scf.condition(%condition) %condition : i1, ...
3513/// } do {
3514/// ^bb0(%arg0: i1, ...):
3515/// use(%arg0)
3516/// ...
3517///
3518/// becomes
3519/// scf.while (..) : (i1, ...) -> ... {
3520/// %condition = call @evaluate_condition() : () -> i1
3521/// scf.condition(%condition) %condition : i1, ...
3522/// } do {
3523/// ^bb0(%arg0: i1, ...):
3524/// use(%true)
3525/// ...
struct WhileConditionTruth : public OpRewritePattern<WhileOp> {
  using OpRewritePattern<WhileOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(WhileOp op,
                                PatternRewriter &rewriter) const override {
    auto term = op.getConditionOp();

    // These variables serve to prevent creating duplicate constants
    // and hold constant true or false values.
    Value constantTrue = nullptr;

    bool replaced = false;
    for (auto yieldedAndBlockArgs :
         llvm::zip(t: term.getArgs(), u: op.getAfterArguments())) {
      // If a value forwarded by scf.condition *is* the condition itself, the
      // matching "after" block argument is known to be true whenever the body
      // runs, so its uses can be replaced with a constant true.
      if (std::get<0>(t&: yieldedAndBlockArgs) == term.getCondition()) {
        if (!std::get<1>(t&: yieldedAndBlockArgs).use_empty()) {
          // Materialize the constant lazily, at most once per rewrite.
          if (!constantTrue)
            constantTrue = rewriter.create<arith::ConstantOp>(
                location: op.getLoc(), args: term.getCondition().getType(),
                args: rewriter.getBoolAttr(value: true));

          rewriter.replaceAllUsesWith(from: std::get<1>(t&: yieldedAndBlockArgs),
                                      to: constantTrue);
          replaced = true;
        }
      }
    }
    // Only report success if something actually changed.
    return success(IsSuccess: replaced);
  }
};
3556
3557/// Remove loop invariant arguments from `before` block of scf.while.
3558/// A before block argument is considered loop invariant if :-
3559/// 1. i-th yield operand is equal to the i-th while operand.
3560/// 2. i-th yield operand is k-th after block argument which is (k+1)-th
3561/// condition operand AND this (k+1)-th condition operand is equal to i-th
3562/// iter argument/while operand.
3563/// For the arguments which are removed, their uses inside scf.while
3564/// are replaced with their corresponding initial value.
3565///
3566/// Eg:
3567/// INPUT :-
3568/// %res = scf.while <...> iter_args(%arg0_before = %a, %arg1_before = %b,
3569/// ..., %argN_before = %N)
3570/// {
3571/// ...
3572/// scf.condition(%cond) %arg1_before, %arg0_before,
3573/// %arg2_before, %arg0_before, ...
3574/// } do {
3575/// ^bb0(%arg1_after, %arg0_after_1, %arg2_after, %arg0_after_2,
3576/// ..., %argK_after):
3577/// ...
3578/// scf.yield %arg0_after_2, %b, %arg1_after, ..., %argN
3579/// }
3580///
3581/// OUTPUT :-
3582/// %res = scf.while <...> iter_args(%arg2_before = %c, ..., %argN_before =
3583/// %N)
3584/// {
3585/// ...
3586/// scf.condition(%cond) %b, %a, %arg2_before, %a, ...
3587/// } do {
3588/// ^bb0(%arg1_after, %arg0_after_1, %arg2_after, %arg0_after_2,
3589/// ..., %argK_after):
3590/// ...
3591/// scf.yield %arg1_after, ..., %argN
3592/// }
3593///
3594/// EXPLANATION:
3595/// We iterate over each yield operand.
3596/// 1. 0-th yield operand %arg0_after_2 is 4-th condition operand
3597/// %arg0_before, which in turn is the 0-th iter argument. So we
3598/// remove 0-th before block argument and yield operand, and replace
3599/// all uses of the 0-th before block argument with its initial value
3600/// %a.
3601/// 2. 1-th yield operand %b is equal to the 1-th iter arg's initial
3602/// value. So we remove this operand and the corresponding before
3603/// block argument and replace all uses of 1-th before block argument
3604/// with %b.
struct RemoveLoopInvariantArgsFromBeforeBlock
    : public OpRewritePattern<WhileOp> {
  using OpRewritePattern<WhileOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(WhileOp op,
                                PatternRewriter &rewriter) const override {
    Block &afterBlock = *op.getAfterBody();
    Block::BlockArgListType beforeBlockArgs = op.getBeforeArguments();
    ConditionOp condOp = op.getConditionOp();
    OperandRange condOpArgs = condOp.getArgs();
    Operation *yieldOp = afterBlock.getTerminator();
    ValueRange yieldOpArgs = yieldOp->getOperands();

    // First pass: cheap scan to bail out early when no before-block argument
    // is loop invariant (the common case).
    bool canSimplify = false;
    for (const auto &it :
         llvm::enumerate(First: llvm::zip(t: op.getOperands(), u&: yieldOpArgs))) {
      auto index = static_cast<unsigned>(it.index());
      auto [initVal, yieldOpArg] = it.value();
      // If i-th yield operand is equal to the i-th operand of the scf.while,
      // the i-th before block argument is a loop invariant.
      if (yieldOpArg == initVal) {
        canSimplify = true;
        break;
      }
      // If the i-th yield operand is k-th after block argument, then we check
      // if the (k+1)-th condition op operand is equal to either the i-th before
      // block argument or the initial value of i-th before block argument. If
      // the comparison results `true`, i-th before block argument is a loop
      // invariant.
      auto yieldOpBlockArg = llvm::dyn_cast<BlockArgument>(Val&: yieldOpArg);
      if (yieldOpBlockArg && yieldOpBlockArg.getOwner() == &afterBlock) {
        Value condOpArg = condOpArgs[yieldOpBlockArg.getArgNumber()];
        if (condOpArg == beforeBlockArgs[index] || condOpArg == initVal) {
          canSimplify = true;
          break;
        }
      }
    }

    if (!canSimplify)
      return failure();

    // Second pass: partition the iter args into loop-invariant ones
    // (recorded by index in `beforeBlockInitValMap`) and the survivors.
    SmallVector<Value> newInitArgs, newYieldOpArgs;
    DenseMap<unsigned, Value> beforeBlockInitValMap;
    SmallVector<Location> newBeforeBlockArgLocs;
    for (const auto &it :
         llvm::enumerate(First: llvm::zip(t: op.getOperands(), u&: yieldOpArgs))) {
      auto index = static_cast<unsigned>(it.index());
      auto [initVal, yieldOpArg] = it.value();

      // If i-th yield operand is equal to the i-th operand of the scf.while,
      // the i-th before block argument is a loop invariant.
      if (yieldOpArg == initVal) {
        beforeBlockInitValMap.insert(KV: {index, initVal});
        continue;
      } else {
        // If the i-th yield operand is k-th after block argument, then we check
        // if the (k+1)-th condition op operand is equal to either the i-th
        // before block argument or the initial value of i-th before block
        // argument. If the comparison results `true`, i-th before block
        // argument is a loop invariant.
        auto yieldOpBlockArg = llvm::dyn_cast<BlockArgument>(Val&: yieldOpArg);
        if (yieldOpBlockArg && yieldOpBlockArg.getOwner() == &afterBlock) {
          Value condOpArg = condOpArgs[yieldOpBlockArg.getArgNumber()];
          if (condOpArg == beforeBlockArgs[index] || condOpArg == initVal) {
            beforeBlockInitValMap.insert(KV: {index, initVal});
            continue;
          }
        }
      }
      newInitArgs.emplace_back(Args&: initVal);
      newYieldOpArgs.emplace_back(Args&: yieldOpArg);
      newBeforeBlockArgLocs.emplace_back(Args: beforeBlockArgs[index].getLoc());
    }

    // Rebuild the yield with only the surviving operands.
    {
      OpBuilder::InsertionGuard g(rewriter);
      rewriter.setInsertionPoint(yieldOp);
      rewriter.replaceOpWithNewOp<YieldOp>(op: yieldOp, args&: newYieldOpArgs);
    }

    // Create the slimmed-down while op and its new "before" entry block.
    auto newWhile =
        rewriter.create<WhileOp>(location: op.getLoc(), args: op.getResultTypes(), args&: newInitArgs);

    Block &newBeforeBlock = *rewriter.createBlock(
        parent: &newWhile.getBefore(), /*insertPt*/ {},
        argTypes: ValueRange(newYieldOpArgs).getTypes(), locs: newBeforeBlockArgLocs);

    Block &beforeBlock = *op.getBeforeBody();
    SmallVector<Value> newBeforeBlockArgs(beforeBlock.getNumArguments());
    // For each i-th before block argument we find it's replacement value as :-
    // 1. If i-th before block argument is a loop invariant, we fetch it's
    // initial value from `beforeBlockInitValMap` by querying for key `i`.
    // 2. Else we fetch j-th new before block argument as the replacement
    // value of i-th before block argument.
    for (unsigned i = 0, j = 0, n = beforeBlock.getNumArguments(); i < n; i++) {
      // If the index 'i' argument was a loop invariant we fetch it's initial
      // value from `beforeBlockInitValMap`.
      if (beforeBlockInitValMap.count(Val: i) != 0)
        newBeforeBlockArgs[i] = beforeBlockInitValMap[i];
      else
        newBeforeBlockArgs[i] = newBeforeBlock.getArgument(i: j++);
    }

    // Splice the old bodies into the new op and replace all results.
    rewriter.mergeBlocks(source: &beforeBlock, dest: &newBeforeBlock, argValues: newBeforeBlockArgs);
    rewriter.inlineRegionBefore(region&: op.getAfter(), parent&: newWhile.getAfter(),
                                before: newWhile.getAfter().begin());

    rewriter.replaceOp(op, newValues: newWhile.getResults());
    return success();
  }
};
3717
3718/// Remove loop invariant value from result (condition op) of scf.while.
3719/// A value is considered loop invariant if the final value yielded by
3720/// scf.condition is defined outside of the `before` block. We remove the
3721/// corresponding argument in `after` block and replace the use with the value.
3722/// We also replace the use of the corresponding result of scf.while with the
3723/// value.
3724///
3725/// Eg:
3726/// INPUT :-
3727/// %res_input:K = scf.while <...> iter_args(%arg0_before = , ...,
3728/// %argN_before = %N) {
3729/// ...
3730/// scf.condition(%cond) %arg0_before, %a, %b, %arg1_before, ...
3731/// } do {
3732/// ^bb0(%arg0_after, %arg1_after, %arg2_after, ..., %argK_after):
3733/// ...
3734/// some_func(%arg1_after)
3735/// ...
3736/// scf.yield %arg0_after, %arg2_after, ..., %argN_after
3737/// }
3738///
3739/// OUTPUT :-
3740/// %res_output:M = scf.while <...> iter_args(%arg0 = , ..., %argN = %N) {
3741/// ...
3742/// scf.condition(%cond) %arg0, %arg1, ..., %argM
3743/// } do {
3744/// ^bb0(%arg0, %arg3, ..., %argM):
3745/// ...
3746/// some_func(%a)
3747/// ...
3748/// scf.yield %arg0, %b, ..., %argN
3749/// }
3750///
3751/// EXPLANATION:
3752/// 1. The 1-th and 2-th operand of scf.condition are defined outside the
3753/// before block of scf.while, so they get removed.
3754/// 2. %res_input#1's uses are replaced by %a and %res_input#2's uses are
3755/// replaced by %b.
3756/// 3. The corresponding after block argument %arg1_after's uses are
3757/// replaced by %a and %arg2_after's uses are replaced by %b.
struct RemoveLoopInvariantValueYielded : public OpRewritePattern<WhileOp> {
  using OpRewritePattern<WhileOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(WhileOp op,
                                PatternRewriter &rewriter) const override {
    Block &beforeBlock = *op.getBeforeBody();
    ConditionOp condOp = op.getConditionOp();
    OperandRange condOpArgs = condOp.getArgs();

    // First pass: bail out early unless at least one forwarded value is
    // defined outside the "before" block (i.e. is loop invariant).
    bool canSimplify = false;
    for (Value condOpArg : condOpArgs) {
      // Those values not defined within `before` block will be considered as
      // loop invariant values. We map the corresponding `index` with their
      // value.
      if (condOpArg.getParentBlock() != &beforeBlock) {
        canSimplify = true;
        break;
      }
    }

    if (!canSimplify)
      return failure();

    Block::BlockArgListType afterBlockArgs = op.getAfterArguments();

    // Second pass: partition the forwarded values into invariants (recorded
    // by index in `condOpInitValMap`) and the ones to keep forwarding.
    SmallVector<Value> newCondOpArgs;
    SmallVector<Type> newAfterBlockType;
    DenseMap<unsigned, Value> condOpInitValMap;
    SmallVector<Location> newAfterBlockArgLocs;
    for (const auto &it : llvm::enumerate(First&: condOpArgs)) {
      auto index = static_cast<unsigned>(it.index());
      Value condOpArg = it.value();
      // Those values not defined within `before` block will be considered as
      // loop invariant values. We map the corresponding `index` with their
      // value.
      if (condOpArg.getParentBlock() != &beforeBlock) {
        condOpInitValMap.insert(KV: {index, condOpArg});
      } else {
        newCondOpArgs.emplace_back(Args&: condOpArg);
        newAfterBlockType.emplace_back(Args: condOpArg.getType());
        newAfterBlockArgLocs.emplace_back(Args: afterBlockArgs[index].getLoc());
      }
    }

    // Rebuild scf.condition with only the surviving forwarded values.
    {
      OpBuilder::InsertionGuard g(rewriter);
      rewriter.setInsertionPoint(condOp);
      rewriter.replaceOpWithNewOp<ConditionOp>(op: condOp, args: condOp.getCondition(),
                                               args&: newCondOpArgs);
    }

    // Create the slimmed-down while op and its new "after" entry block.
    auto newWhile = rewriter.create<WhileOp>(location: op.getLoc(), args&: newAfterBlockType,
                                             args: op.getOperands());

    Block &newAfterBlock =
        *rewriter.createBlock(parent: &newWhile.getAfter(), /*insertPt*/ {},
                              argTypes: newAfterBlockType, locs: newAfterBlockArgLocs);

    Block &afterBlock = *op.getAfterBody();
    // Since a new scf.condition op was created, we need to fetch the new
    // `after` block arguments which will be used while replacing operations of
    // previous scf.while's `after` blocks. We'd also be fetching new result
    // values too.
    SmallVector<Value> newAfterBlockArgs(afterBlock.getNumArguments());
    SmallVector<Value> newWhileResults(afterBlock.getNumArguments());
    for (unsigned i = 0, j = 0, n = afterBlock.getNumArguments(); i < n; i++) {
      Value afterBlockArg, result;
      // If index 'i' argument was loop invariant we fetch it's value from the
      // `condOpInitMap` map.
      if (condOpInitValMap.count(Val: i) != 0) {
        afterBlockArg = condOpInitValMap[i];
        result = afterBlockArg;
      } else {
        afterBlockArg = newAfterBlock.getArgument(i: j);
        result = newWhile.getResult(i: j);
        j++;
      }
      newAfterBlockArgs[i] = afterBlockArg;
      newWhileResults[i] = result;
    }

    // Splice the old bodies into the new op and replace all results (some of
    // which are now the invariant values themselves).
    rewriter.mergeBlocks(source: &afterBlock, dest: &newAfterBlock, argValues: newAfterBlockArgs);
    rewriter.inlineRegionBefore(region&: op.getBefore(), parent&: newWhile.getBefore(),
                                before: newWhile.getBefore().begin());

    rewriter.replaceOp(op, newValues: newWhileResults);
    return success();
  }
};
3847
3848/// Remove WhileOp results that are also unused in 'after' block.
3849///
3850/// %0:2 = scf.while () : () -> (i32, i64) {
3851/// %condition = "test.condition"() : () -> i1
3852/// %v1 = "test.get_some_value"() : () -> i32
3853/// %v2 = "test.get_some_value"() : () -> i64
3854/// scf.condition(%condition) %v1, %v2 : i32, i64
3855/// } do {
3856/// ^bb0(%arg0: i32, %arg1: i64):
3857/// "test.use"(%arg0) : (i32) -> ()
3858/// scf.yield
3859/// }
3860/// return %0#0 : i32
3861///
3862/// becomes
3863/// %0 = scf.while () : () -> (i32) {
3864/// %condition = "test.condition"() : () -> i1
3865/// %v1 = "test.get_some_value"() : () -> i32
3866/// %v2 = "test.get_some_value"() : () -> i64
3867/// scf.condition(%condition) %v1 : i32
3868/// } do {
3869/// ^bb0(%arg0: i32):
3870/// "test.use"(%arg0) : (i32) -> ()
3871/// scf.yield
3872/// }
3873/// return %0 : i32
struct WhileUnusedResult : public OpRewritePattern<WhileOp> {
  using OpRewritePattern<WhileOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(WhileOp op,
                                PatternRewriter &rewriter) const override {
    auto term = op.getConditionOp();
    auto afterArgs = op.getAfterArguments();
    auto termArgs = term.getArgs();

    // Collect results mapping, new terminator args and new result types.
    // A result is droppable only if both the op result and the matching
    // "after" block argument are unused.
    SmallVector<unsigned> newResultsIndices;
    SmallVector<Type> newResultTypes;
    SmallVector<Value> newTermArgs;
    SmallVector<Location> newArgLocs;
    bool needUpdate = false;
    for (const auto &it :
         llvm::enumerate(First: llvm::zip(t: op.getResults(), u&: afterArgs, args&: termArgs))) {
      auto i = static_cast<unsigned>(it.index());
      Value result = std::get<0>(t&: it.value());
      Value afterArg = std::get<1>(t&: it.value());
      Value termArg = std::get<2>(t&: it.value());
      if (result.use_empty() && afterArg.use_empty()) {
        needUpdate = true;
      } else {
        newResultsIndices.emplace_back(Args&: i);
        newTermArgs.emplace_back(Args&: termArg);
        newResultTypes.emplace_back(Args: result.getType());
        newArgLocs.emplace_back(Args: result.getLoc());
      }
    }

    if (!needUpdate)
      return failure();

    // Rebuild scf.condition forwarding only the used values.
    {
      OpBuilder::InsertionGuard g(rewriter);
      rewriter.setInsertionPoint(term);
      rewriter.replaceOpWithNewOp<ConditionOp>(op: term, args: term.getCondition(),
                                               args&: newTermArgs);
    }

    // Create the slimmed-down while op and its new "after" entry block.
    auto newWhile =
        rewriter.create<WhileOp>(location: op.getLoc(), args&: newResultTypes, args: op.getInits());

    Block &newAfterBlock = *rewriter.createBlock(
        parent: &newWhile.getAfter(), /*insertPt*/ {}, argTypes: newResultTypes, locs: newArgLocs);

    // Build new results list and new after block args (unused entries will be
    // null).
    SmallVector<Value> newResults(op.getNumResults());
    SmallVector<Value> newAfterBlockArgs(op.getNumResults());
    for (const auto &it : llvm::enumerate(First&: newResultsIndices)) {
      newResults[it.value()] = newWhile.getResult(i: it.index());
      newAfterBlockArgs[it.value()] = newAfterBlock.getArgument(i: it.index());
    }

    // Splice the old bodies into the new op; null replacements are fine
    // because the corresponding results/arguments have no uses.
    rewriter.inlineRegionBefore(region&: op.getBefore(), parent&: newWhile.getBefore(),
                                before: newWhile.getBefore().begin());

    Block &afterBlock = *op.getAfterBody();
    rewriter.mergeBlocks(source: &afterBlock, dest: &newAfterBlock, argValues: newAfterBlockArgs);

    rewriter.replaceOp(op, newValues: newResults);
    return success();
  }
};
3940
3941/// Replace operations equivalent to the condition in the do block with true,
3942/// since otherwise the block would not be evaluated.
3943///
3944/// scf.while (..) : (i32, ...) -> ... {
3945/// %z = ... : i32
3946/// %condition = cmpi pred %z, %a
3947/// scf.condition(%condition) %z : i32, ...
3948/// } do {
3949/// ^bb0(%arg0: i32, ...):
3950/// %condition2 = cmpi pred %arg0, %a
3951/// use(%condition2)
3952/// ...
3953///
3954/// becomes
3955/// scf.while (..) : (i32, ...) -> ... {
3956/// %z = ... : i32
3957/// %condition = cmpi pred %z, %a
3958/// scf.condition(%condition) %z : i32, ...
3959/// } do {
3960/// ^bb0(%arg0: i32, ...):
3961/// use(%true)
3962/// ...
struct WhileCmpCond : public OpRewritePattern<scf::WhileOp> {
  using OpRewritePattern<scf::WhileOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(scf::WhileOp op,
                                PatternRewriter &rewriter) const override {
    using namespace scf;
    auto cond = op.getConditionOp();
    // Only handle loop conditions produced directly by an integer compare.
    auto cmp = cond.getCondition().getDefiningOp<arith::CmpIOp>();
    if (!cmp)
      return failure();
    bool changed = false;
    // Pair each value forwarded by scf.condition with its corresponding
    // after-block argument. If the forwarded value is an operand of the
    // condition compare, then inside the after block the compare is known to
    // have evaluated to true (the after block only runs when it did).
    for (auto tup : llvm::zip(t: cond.getArgs(), u: op.getAfterArguments())) {
      // The compare is binary; the forwarded value may sit on either side.
      for (size_t opIdx = 0; opIdx < 2; opIdx++) {
        if (std::get<0>(t&: tup) != cmp.getOperand(i: opIdx))
          continue;
        // Early-inc range: uses are replaced/erased while iterating.
        for (OpOperand &u :
             llvm::make_early_inc_range(Range: std::get<1>(t&: tup).getUses())) {
          auto cmp2 = dyn_cast<arith::CmpIOp>(Val: u.getOwner());
          if (!cmp2)
            continue;
          // For a binary operator 1-opIdx gets the other side.
          if (cmp2.getOperand(i: 1 - opIdx) != cmp.getOperand(i: 1 - opIdx))
            continue;
          // Identical predicate folds to true, the exact inverse folds to
          // false; any other predicate relation is left alone.
          bool samePredicate;
          if (cmp2.getPredicate() == cmp.getPredicate())
            samePredicate = true;
          else if (cmp2.getPredicate() ==
                   arith::invertPredicate(pred: cmp.getPredicate()))
            samePredicate = false;
          else
            continue;

          rewriter.replaceOpWithNewOp<arith::ConstantIntOp>(op: cmp2, args&: samePredicate,
                                                            args: 1);
          changed = true;
        }
      }
    }
    return success(IsSuccess: changed);
  }
};
4004
4005/// Remove unused init/yield args.
struct WhileRemoveUnusedArgs : public OpRewritePattern<WhileOp> {
  using OpRewritePattern<WhileOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(WhileOp op,
                                PatternRewriter &rewriter) const override {

    // Bail out unless at least one before-block argument is unused.
    if (!llvm::any_of(Range: op.getBeforeArguments(),
                      P: [](Value arg) { return arg.use_empty(); }))
      return rewriter.notifyMatchFailure(arg&: op, msg: "No args to remove");

    YieldOp yield = op.getYieldOp();

    // Collect results mapping, new terminator args and new result types.
    SmallVector<Value> newYields;
    SmallVector<Value> newInits;
    llvm::BitVector argsToErase;

    size_t argsCount = op.getBeforeArguments().size();
    newYields.reserve(N: argsCount);
    newInits.reserve(N: argsCount);
    argsToErase.reserve(N: argsCount);
    // Keep only the (before-arg, yield, init) triples whose before-arg has
    // uses; record dropped positions in `argsToErase`.
    for (auto &&[beforeArg, yieldValue, initValue] : llvm::zip(
             t: op.getBeforeArguments(), u: yield.getOperands(), args: op.getInits())) {
      if (beforeArg.use_empty()) {
        argsToErase.push_back(Val: true);
      } else {
        argsToErase.push_back(Val: false);
        newYields.emplace_back(Args&: yieldValue);
        newInits.emplace_back(Args&: initValue);
      }
    }

    Block &beforeBlock = *op.getBeforeBody();
    Block &afterBlock = *op.getAfterBody();

    // Drop the unused arguments in place before the block is moved into the
    // replacement op.
    beforeBlock.eraseArguments(eraseIndices: argsToErase);

    Location loc = op.getLoc();
    // The result types are unchanged: only init/yield values are pruned, the
    // condition args are untouched.
    auto newWhileOp =
        rewriter.create<WhileOp>(location: loc, args: op.getResultTypes(), args&: newInits,
                                 /*beforeBody*/ args: nullptr, /*afterBody*/ args: nullptr);
    Block &newBeforeBlock = *newWhileOp.getBeforeBody();
    Block &newAfterBlock = *newWhileOp.getAfterBody();

    // Rewrite the yield so it forwards only the kept values.
    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPoint(yield);
    rewriter.replaceOpWithNewOp<YieldOp>(op: yield, args&: newYields);

    // Splice the old bodies into the freshly created loop.
    rewriter.mergeBlocks(source: &beforeBlock, dest: &newBeforeBlock,
                         argValues: newBeforeBlock.getArguments());
    rewriter.mergeBlocks(source: &afterBlock, dest: &newAfterBlock,
                         argValues: newAfterBlock.getArguments());

    rewriter.replaceOp(op, newValues: newWhileOp.getResults());
    return success();
  }
};
4063
4064/// Remove duplicated ConditionOp args.
4065struct WhileRemoveDuplicatedResults : public OpRewritePattern<WhileOp> {
4066 using OpRewritePattern::OpRewritePattern;
4067
4068 LogicalResult matchAndRewrite(WhileOp op,
4069 PatternRewriter &rewriter) const override {
4070 ConditionOp condOp = op.getConditionOp();
4071 ValueRange condOpArgs = condOp.getArgs();
4072
4073 llvm::SmallPtrSet<Value, 8> argsSet(llvm::from_range, condOpArgs);
4074
4075 if (argsSet.size() == condOpArgs.size())
4076 return rewriter.notifyMatchFailure(arg&: op, msg: "No results to remove");
4077
4078 llvm::SmallDenseMap<Value, unsigned> argsMap;
4079 SmallVector<Value> newArgs;
4080 argsMap.reserve(NumEntries: condOpArgs.size());
4081 newArgs.reserve(N: condOpArgs.size());
4082 for (Value arg : condOpArgs) {
4083 if (!argsMap.count(Val: arg)) {
4084 auto pos = static_cast<unsigned>(argsMap.size());
4085 argsMap.insert(KV: {arg, pos});
4086 newArgs.emplace_back(Args&: arg);
4087 }
4088 }
4089
4090 ValueRange argsRange(newArgs);
4091
4092 Location loc = op.getLoc();
4093 auto newWhileOp = rewriter.create<scf::WhileOp>(
4094 location: loc, args: argsRange.getTypes(), args: op.getInits(), /*beforeBody*/ args: nullptr,
4095 /*afterBody*/ args: nullptr);
4096 Block &newBeforeBlock = *newWhileOp.getBeforeBody();
4097 Block &newAfterBlock = *newWhileOp.getAfterBody();
4098
4099 SmallVector<Value> afterArgsMapping;
4100 SmallVector<Value> resultsMapping;
4101 for (auto &&[i, arg] : llvm::enumerate(First&: condOpArgs)) {
4102 auto it = argsMap.find(Val: arg);
4103 assert(it != argsMap.end());
4104 auto pos = it->second;
4105 afterArgsMapping.emplace_back(Args: newAfterBlock.getArgument(i: pos));
4106 resultsMapping.emplace_back(Args: newWhileOp->getResult(idx: pos));
4107 }
4108
4109 OpBuilder::InsertionGuard g(rewriter);
4110 rewriter.setInsertionPoint(condOp);
4111 rewriter.replaceOpWithNewOp<ConditionOp>(op: condOp, args: condOp.getCondition(),
4112 args&: argsRange);
4113
4114 Block &beforeBlock = *op.getBeforeBody();
4115 Block &afterBlock = *op.getAfterBody();
4116
4117 rewriter.mergeBlocks(source: &beforeBlock, dest: &newBeforeBlock,
4118 argValues: newBeforeBlock.getArguments());
4119 rewriter.mergeBlocks(source: &afterBlock, dest: &newAfterBlock, argValues: afterArgsMapping);
4120 rewriter.replaceOp(op, newValues: resultsMapping);
4121 return success();
4122 }
4123};
4124
4125/// If both ranges contain same values return mappping indices from args2 to
4126/// args1. Otherwise return std::nullopt.
4127static std::optional<SmallVector<unsigned>> getArgsMapping(ValueRange args1,
4128 ValueRange args2) {
4129 if (args1.size() != args2.size())
4130 return std::nullopt;
4131
4132 SmallVector<unsigned> ret(args1.size());
4133 for (auto &&[i, arg1] : llvm::enumerate(First&: args1)) {
4134 auto it = llvm::find(Range&: args2, Val: arg1);
4135 if (it == args2.end())
4136 return std::nullopt;
4137
4138 ret[std::distance(first: args2.begin(), last: it)] = static_cast<unsigned>(i);
4139 }
4140
4141 return ret;
4142}
4143
4144static bool hasDuplicates(ValueRange args) {
4145 llvm::SmallDenseSet<Value> set;
4146 for (Value arg : args) {
4147 if (!set.insert(V: arg).second)
4148 return true;
4149 }
4150 return false;
4151}
4152
4153/// If `before` block args are directly forwarded to `scf.condition`, rearrange
4154/// `scf.condition` args into same order as block args. Update `after` block
4155/// args and op result values accordingly.
4156/// Needed to simplify `scf.while` -> `scf.for` uplifting.
struct WhileOpAlignBeforeArgs : public OpRewritePattern<WhileOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(WhileOp loop,
                                PatternRewriter &rewriter) const override {
    auto oldBefore = loop.getBeforeBody();
    ConditionOp oldTerm = loop.getConditionOp();
    ValueRange beforeArgs = oldBefore->getArguments();
    ValueRange termArgs = oldTerm.getArgs();
    // Already aligned: scf.condition forwards block args in order.
    if (beforeArgs == termArgs)
      return failure();

    // The permutation below is only well-defined when the terminator args
    // are pairwise distinct.
    if (hasDuplicates(args: termArgs))
      return failure();

    // mapping[i] is the before-arg index forwarded at terminator position i;
    // bail if the terminator forwards anything but the block args.
    auto mapping = getArgsMapping(args1: beforeArgs, args2: termArgs);
    if (!mapping)
      return failure();

    // Rewrite the terminator to forward the block args in declaration order.
    {
      OpBuilder::InsertionGuard g(rewriter);
      rewriter.setInsertionPoint(oldTerm);
      rewriter.replaceOpWithNewOp<ConditionOp>(op: oldTerm, args: oldTerm.getCondition(),
                                               args&: beforeArgs);
    }

    auto oldAfter = loop.getAfterBody();

    // Permute the result types to match the new forwarding order: old
    // result i moves to position mapping[i].
    SmallVector<Type> newResultTypes(beforeArgs.size());
    for (auto &&[i, j] : llvm::enumerate(First&: *mapping))
      newResultTypes[j] = loop.getResult(i).getType();

    auto newLoop = rewriter.create<WhileOp>(
        location: loop.getLoc(), args&: newResultTypes, args: loop.getInits(),
        /*beforeBuilder=*/args: nullptr, /*afterBuilder=*/args: nullptr);
    auto newBefore = newLoop.getBeforeBody();
    auto newAfter = newLoop.getAfterBody();

    // Old result / after-block argument i now lives at permuted position
    // mapping[i] of the new loop.
    SmallVector<Value> newResults(beforeArgs.size());
    SmallVector<Value> newAfterArgs(beforeArgs.size());
    for (auto &&[i, j] : llvm::enumerate(First&: *mapping)) {
      newResults[i] = newLoop.getResult(i: j);
      newAfterArgs[i] = newAfter->getArgument(i: j);
    }

    // Move both bodies into the new loop, remapping the after-block args.
    rewriter.inlineBlockBefore(source: oldBefore, dest: newBefore, before: newBefore->begin(),
                               argValues: newBefore->getArguments());
    rewriter.inlineBlockBefore(source: oldAfter, dest: newAfter, before: newAfter->begin(),
                               argValues: newAfterArgs);

    rewriter.replaceOp(op: loop, newValues: newResults);
    return success();
  }
};
4211} // namespace
4212
void WhileOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                          MLIRContext *context) {
  // Register all scf.while canonicalization patterns defined above.
  results.add<RemoveLoopInvariantArgsFromBeforeBlock,
              RemoveLoopInvariantValueYielded, WhileConditionTruth,
              WhileCmpCond, WhileUnusedResult, WhileRemoveDuplicatedResults,
              WhileRemoveUnusedArgs, WhileOpAlignBeforeArgs>(arg&: context);
}
4220
4221//===----------------------------------------------------------------------===//
4222// IndexSwitchOp
4223//===----------------------------------------------------------------------===//
4224
4225/// Parse the case regions and values.
4226static ParseResult
4227parseSwitchCases(OpAsmParser &p, DenseI64ArrayAttr &cases,
4228 SmallVectorImpl<std::unique_ptr<Region>> &caseRegions) {
4229 SmallVector<int64_t> caseValues;
4230 while (succeeded(Result: p.parseOptionalKeyword(keyword: "case"))) {
4231 int64_t value;
4232 Region &region = *caseRegions.emplace_back(Args: std::make_unique<Region>());
4233 if (p.parseInteger(result&: value) || p.parseRegion(region, /*arguments=*/{}))
4234 return failure();
4235 caseValues.push_back(Elt: value);
4236 }
4237 cases = p.getBuilder().getDenseI64ArrayAttr(values: caseValues);
4238 return success();
4239}
4240
4241/// Print the case regions and values.
4242static void printSwitchCases(OpAsmPrinter &p, Operation *op,
4243 DenseI64ArrayAttr cases, RegionRange caseRegions) {
4244 for (auto [value, region] : llvm::zip(t: cases.asArrayRef(), u&: caseRegions)) {
4245 p.printNewline();
4246 p << "case " << value << ' ';
4247 p.printRegion(blocks&: *region, /*printEntryBlockArgs=*/false);
4248 }
4249}
4250
LogicalResult scf::IndexSwitchOp::verify() {
  // The number of case values must match the number of case regions.
  if (getCases().size() != getCaseRegions().size()) {
    return emitOpError(message: "has ")
           << getCaseRegions().size() << " case regions but "
           << getCases().size() << " case values";
  }

  // Case values must be pairwise distinct.
  DenseSet<int64_t> valueSet;
  for (int64_t value : getCases())
    if (!valueSet.insert(V: value).second)
      return emitOpError(message: "has duplicate case value: ") << value;
  // Every region must terminate with scf.yield whose operand count and types
  // match this op's results exactly.
  auto verifyRegion = [&](Region &region, const Twine &name) -> LogicalResult {
    auto yield = dyn_cast<YieldOp>(Val&: region.front().back());
    if (!yield)
      return emitOpError(message: "expected region to end with scf.yield, but got ")
             << region.front().back().getName();

    if (yield.getNumOperands() != getNumResults()) {
      return (emitOpError(message: "expected each region to return ")
              << getNumResults() << " values, but " << name << " returns "
              << yield.getNumOperands())
                 .attachNote(noteLoc: yield.getLoc())
             << "see yield operation here";
    }
    for (auto [idx, result, operand] :
         llvm::enumerate(First: getResultTypes(), Rest: yield.getOperands())) {
      if (!operand)
        return yield.emitOpError() << "operand " << idx << " is null\n";
      if (result == operand.getType())
        continue;
      return (emitOpError(message: "expected result #")
              << idx << " of each region to be " << result)
                 .attachNote(noteLoc: yield.getLoc())
             << name << " returns " << operand.getType() << " here";
    }
    return success();
  };

  // Check the default region first, then every case region.
  if (failed(Result: verifyRegion(getDefaultRegion(), "default region")))
    return failure();
  for (auto [idx, caseRegion] : llvm::enumerate(First: getCaseRegions()))
    if (failed(Result: verifyRegion(caseRegion, "case region #" + Twine(idx))))
      return failure();

  return success();
}
4297
/// Returns the number of `case` clauses (the default region is not counted).
unsigned scf::IndexSwitchOp::getNumCases() { return getCases().size(); }
4299
/// Returns the entry block of the default region.
Block &scf::IndexSwitchOp::getDefaultBlock() {
  return getDefaultRegion().front();
}
4303
/// Returns the entry block of the case region at position `idx`.
Block &scf::IndexSwitchOp::getCaseBlock(unsigned idx) {
  assert(idx < getNumCases() && "case index out-of-bounds");
  return getCaseRegions()[idx].front();
}
4308
void IndexSwitchOp::getSuccessorRegions(
    RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &successors) {
  // All regions branch back to the parent op.
  if (!point.isParent()) {
    successors.emplace_back(Args: getResults());
    return;
  }

  // From the parent, control may enter any of the regions.
  llvm::append_range(C&: successors, R: getRegions());
}
4319
4320void IndexSwitchOp::getEntrySuccessorRegions(
4321 ArrayRef<Attribute> operands,
4322 SmallVectorImpl<RegionSuccessor> &successors) {
4323 FoldAdaptor adaptor(operands, *this);
4324
4325 // If a constant was not provided, all regions are possible successors.
4326 auto arg = dyn_cast_or_null<IntegerAttr>(Val: adaptor.getArg());
4327 if (!arg) {
4328 llvm::append_range(C&: successors, R: getRegions());
4329 return;
4330 }
4331
4332 // Otherwise, try to find a case with a matching value. If not, the
4333 // default region is the only successor.
4334 for (auto [caseValue, caseRegion] : llvm::zip(t: getCases(), u: getCaseRegions())) {
4335 if (caseValue == arg.getInt()) {
4336 successors.emplace_back(Args: &caseRegion);
4337 return;
4338 }
4339 }
4340 successors.emplace_back(Args: &getDefaultRegion());
4341}
4342
void IndexSwitchOp::getRegionInvocationBounds(
    ArrayRef<Attribute> operands, SmallVectorImpl<InvocationBounds> &bounds) {
  auto operandValue = llvm::dyn_cast_or_null<IntegerAttr>(Val: operands.front());
  if (!operandValue) {
    // All regions are invoked at most once.
    bounds.append(NumInputs: getNumRegions(), Elt: InvocationBounds(/*lb=*/0, /*ub=*/1));
    return;
  }

  // With a constant selector exactly one region can run: the matching case,
  // otherwise the fallback at index getNumRegions() - 1 (the indexing here
  // treats the last region as the default; case i maps to region i).
  unsigned liveIndex = getNumRegions() - 1;
  const auto *it = llvm::find(Range: getCases(), Val: operandValue.getInt());
  if (it != getCases().end())
    liveIndex = std::distance(first: getCases().begin(), last: it);
  // Live region: [0, 1]; every other region: [0, 0] (never invoked).
  for (unsigned i = 0, e = getNumRegions(); i < e; ++i)
    bounds.emplace_back(/*lb=*/Args: 0, /*ub=*/Args: i == liveIndex);
}
4359
4360struct FoldConstantCase : OpRewritePattern<scf::IndexSwitchOp> {
4361 using OpRewritePattern<scf::IndexSwitchOp>::OpRewritePattern;
4362
4363 LogicalResult matchAndRewrite(scf::IndexSwitchOp op,
4364 PatternRewriter &rewriter) const override {
4365 // If `op.getArg()` is a constant, select the region that matches with
4366 // the constant value. Use the default region if no matche is found.
4367 std::optional<int64_t> maybeCst = getConstantIntValue(ofr: op.getArg());
4368 if (!maybeCst.has_value())
4369 return failure();
4370 int64_t cst = *maybeCst;
4371 int64_t caseIdx, e = op.getNumCases();
4372 for (caseIdx = 0; caseIdx < e; ++caseIdx) {
4373 if (cst == op.getCases()[caseIdx])
4374 break;
4375 }
4376
4377 Region &r = (caseIdx < op.getNumCases()) ? op.getCaseRegions()[caseIdx]
4378 : op.getDefaultRegion();
4379 Block &source = r.front();
4380 Operation *terminator = source.getTerminator();
4381 SmallVector<Value> results = terminator->getOperands();
4382
4383 rewriter.inlineBlockBefore(source: &source, op);
4384 rewriter.eraseOp(op: terminator);
4385 // Replace the operation with a potentially empty list of results.
4386 // Fold mechanism doesn't support the case where the result list is empty.
4387 rewriter.replaceOp(op, newValues: results);
4388
4389 return success();
4390 }
4391};
4392
void IndexSwitchOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                MLIRContext *context) {
  // Fold switches over a constant selector into the selected region.
  results.add<FoldConstantCase>(arg&: context);
}
4397
4398//===----------------------------------------------------------------------===//
4399// TableGen'd op method definitions
4400//===----------------------------------------------------------------------===//
4401
4402#define GET_OP_CLASSES
4403#include "mlir/Dialect/SCF/IR/SCFOps.cpp.inc"
4404

source code of mlir/lib/Dialect/SCF/IR/SCF.cpp