1//===- SCF.cpp - Structured Control Flow Operations -----------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/SCF/IR/SCF.h"
10#include "mlir/Conversion/ConvertToEmitC/ToEmitCInterface.h"
11#include "mlir/Dialect/Arith/IR/Arith.h"
12#include "mlir/Dialect/Arith/Utils/Utils.h"
13#include "mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.h"
14#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
15#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
16#include "mlir/Dialect/MemRef/IR/MemRef.h"
17#include "mlir/Dialect/SCF/IR/DeviceMappingInterface.h"
18#include "mlir/Dialect/Tensor/IR/Tensor.h"
19#include "mlir/IR/BuiltinAttributes.h"
20#include "mlir/IR/IRMapping.h"
21#include "mlir/IR/Matchers.h"
22#include "mlir/IR/PatternMatch.h"
23#include "mlir/Interfaces/FunctionInterfaces.h"
24#include "mlir/Interfaces/ValueBoundsOpInterface.h"
25#include "mlir/Transforms/InliningUtils.h"
26#include "llvm/ADT/MapVector.h"
27#include "llvm/ADT/SmallPtrSet.h"
28#include "llvm/ADT/TypeSwitch.h"
29
30using namespace mlir;
31using namespace mlir::scf;
32
33#include "mlir/Dialect/SCF/IR/SCFOpsDialect.cpp.inc"
34
35//===----------------------------------------------------------------------===//
36// SCFDialect Dialect Interfaces
37//===----------------------------------------------------------------------===//
38
39namespace {
40struct SCFInlinerInterface : public DialectInlinerInterface {
41 using DialectInlinerInterface::DialectInlinerInterface;
42 // We don't have any special restrictions on what can be inlined into
43 // destination regions (e.g. while/conditional bodies). Always allow it.
44 bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
45 IRMapping &valueMapping) const final {
46 return true;
47 }
48 // Operations in scf dialect are always legal to inline since they are
49 // pure.
50 bool isLegalToInline(Operation *, Region *, bool, IRMapping &) const final {
51 return true;
52 }
53 // Handle the given inlined terminator by replacing it with a new operation
54 // as necessary. Required when the region has only one block.
55 void handleTerminator(Operation *op, ValueRange valuesToRepl) const final {
56 auto retValOp = dyn_cast<scf::YieldOp>(op);
57 if (!retValOp)
58 return;
59
60 for (auto retValue : llvm::zip(valuesToRepl, retValOp.getOperands())) {
61 std::get<0>(retValue).replaceAllUsesWith(std::get<1>(retValue));
62 }
63 }
64};
65} // namespace
66
67//===----------------------------------------------------------------------===//
68// SCFDialect
69//===----------------------------------------------------------------------===//
70
71void SCFDialect::initialize() {
72 addOperations<
73#define GET_OP_LIST
74#include "mlir/Dialect/SCF/IR/SCFOps.cpp.inc"
75 >();
76 addInterfaces<SCFInlinerInterface>();
77 declarePromisedInterface<ConvertToEmitCPatternInterface, SCFDialect>();
78 declarePromisedInterfaces<bufferization::BufferDeallocationOpInterface,
79 InParallelOp, ReduceReturnOp>();
80 declarePromisedInterfaces<bufferization::BufferizableOpInterface, ConditionOp,
81 ExecuteRegionOp, ForOp, IfOp, IndexSwitchOp,
82 ForallOp, InParallelOp, WhileOp, YieldOp>();
83 declarePromisedInterface<ValueBoundsOpInterface, ForOp>();
84}
85
86/// Default callback for IfOp builders. Inserts a yield without arguments.
87void mlir::scf::buildTerminatedBody(OpBuilder &builder, Location loc) {
88 builder.create<scf::YieldOp>(loc);
89}
90
91/// Verifies that the first block of the given `region` is terminated by a
92/// TerminatorTy. Reports errors on the given operation if it is not the case.
93template <typename TerminatorTy>
94static TerminatorTy verifyAndGetTerminator(Operation *op, Region &region,
95 StringRef errorMessage) {
96 Operation *terminatorOperation = nullptr;
97 if (!region.empty() && !region.front().empty()) {
98 terminatorOperation = &region.front().back();
99 if (auto yield = dyn_cast_or_null<TerminatorTy>(terminatorOperation))
100 return yield;
101 }
  auto diag = op->emitOpError(errorMessage);
  if (terminatorOperation)
    diag.attachNote(terminatorOperation->getLoc()) << "terminator here";
105 return nullptr;
106}
107
108//===----------------------------------------------------------------------===//
109// ExecuteRegionOp
110//===----------------------------------------------------------------------===//
111
112/// Replaces the given op with the contents of the given single-block region,
113/// using the operands of the block terminator to replace operation results.
114static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op,
115 Region &region, ValueRange blockArgs = {}) {
  assert(llvm::hasSingleElement(region) && "expected single-block region");
117 Block *block = &region.front();
118 Operation *terminator = block->getTerminator();
119 ValueRange results = terminator->getOperands();
  rewriter.inlineBlockBefore(block, op, blockArgs);
  rewriter.replaceOp(op, results);
  rewriter.eraseOp(terminator);
123}
124
///
/// (ssa-id `=`)? `execute_region` `->` function-result-type `{`
///    block+
/// `}`
///
/// Example:
///   scf.execute_region -> i32 {
///     %idx = memref.load %rI[%i] : memref<128xi32>
///     scf.yield %idx : i32
///   }
///
136ParseResult ExecuteRegionOp::parse(OpAsmParser &parser,
137 OperationState &result) {
138 if (parser.parseOptionalArrowTypeList(result.types))
139 return failure();
140
141 // Introduce the body region and parse it.
142 Region *body = result.addRegion();
143 if (parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}) ||
144 parser.parseOptionalAttrDict(result.attributes))
145 return failure();
146
147 return success();
148}
149
150void ExecuteRegionOp::print(OpAsmPrinter &p) {
151 p.printOptionalArrowTypeList(getResultTypes());
152
153 p << ' ';
154 p.printRegion(getRegion(),
155 /*printEntryBlockArgs=*/false,
156 /*printBlockTerminators=*/true);
157
158 p.printOptionalAttrDict((*this)->getAttrs());
159}
160
161LogicalResult ExecuteRegionOp::verify() {
162 if (getRegion().empty())
163 return emitOpError("region needs to have at least one block");
164 if (getRegion().front().getNumArguments() > 0)
165 return emitOpError("region cannot have any arguments");
166 return success();
167}
168
169// Inline an ExecuteRegionOp if it only contains one block.
170// "test.foo"() : () -> ()
171// %v = scf.execute_region -> i64 {
172// %x = "test.val"() : () -> i64
173// scf.yield %x : i64
174// }
175// "test.bar"(%v) : (i64) -> ()
176//
177// becomes
178//
179// "test.foo"() : () -> ()
180// %x = "test.val"() : () -> i64
181// "test.bar"(%x) : (i64) -> ()
182//
183struct SingleBlockExecuteInliner : public OpRewritePattern<ExecuteRegionOp> {
184 using OpRewritePattern<ExecuteRegionOp>::OpRewritePattern;
185
186 LogicalResult matchAndRewrite(ExecuteRegionOp op,
187 PatternRewriter &rewriter) const override {
188 if (!llvm::hasSingleElement(op.getRegion()))
189 return failure();
190 replaceOpWithRegion(rewriter, op, op.getRegion());
191 return success();
192 }
193};
194
195// Inline an ExecuteRegionOp if its parent can contain multiple blocks.
// TODO: generalize the conditions on which parent operations the region can
// be inlined into.
197// func @func_execute_region_elim() {
198// "test.foo"() : () -> ()
199// %v = scf.execute_region -> i64 {
200// %c = "test.cmp"() : () -> i1
201// cf.cond_br %c, ^bb2, ^bb3
202// ^bb2:
203// %x = "test.val1"() : () -> i64
204// cf.br ^bb4(%x : i64)
205// ^bb3:
206// %y = "test.val2"() : () -> i64
207// cf.br ^bb4(%y : i64)
208// ^bb4(%z : i64):
209// scf.yield %z : i64
210// }
211// "test.bar"(%v) : (i64) -> ()
212// return
213// }
214//
215// becomes
216//
217// func @func_execute_region_elim() {
218// "test.foo"() : () -> ()
219// %c = "test.cmp"() : () -> i1
220// cf.cond_br %c, ^bb1, ^bb2
221// ^bb1: // pred: ^bb0
222// %x = "test.val1"() : () -> i64
223// cf.br ^bb3(%x : i64)
224// ^bb2: // pred: ^bb0
225// %y = "test.val2"() : () -> i64
226// cf.br ^bb3(%y : i64)
227// ^bb3(%z: i64): // 2 preds: ^bb1, ^bb2
228// "test.bar"(%z) : (i64) -> ()
229// return
230// }
231//
232struct MultiBlockExecuteInliner : public OpRewritePattern<ExecuteRegionOp> {
233 using OpRewritePattern<ExecuteRegionOp>::OpRewritePattern;
234
235 LogicalResult matchAndRewrite(ExecuteRegionOp op,
236 PatternRewriter &rewriter) const override {
237 if (!isa<FunctionOpInterface, ExecuteRegionOp>(op->getParentOp()))
238 return failure();
239
240 Block *prevBlock = op->getBlock();
    Block *postBlock = rewriter.splitBlock(prevBlock, op->getIterator());
242 rewriter.setInsertionPointToEnd(prevBlock);
243
244 rewriter.create<cf::BranchOp>(op.getLoc(), &op.getRegion().front());
245
246 for (Block &blk : op.getRegion()) {
247 if (YieldOp yieldOp = dyn_cast<YieldOp>(blk.getTerminator())) {
248 rewriter.setInsertionPoint(yieldOp);
249 rewriter.create<cf::BranchOp>(yieldOp.getLoc(), postBlock,
250 yieldOp.getResults());
251 rewriter.eraseOp(yieldOp);
252 }
253 }
254
255 rewriter.inlineRegionBefore(op.getRegion(), postBlock);
256 SmallVector<Value> blockArgs;
257
258 for (auto res : op.getResults())
259 blockArgs.push_back(postBlock->addArgument(res.getType(), res.getLoc()));
260
261 rewriter.replaceOp(op, blockArgs);
262 return success();
263 }
264};
265
266void ExecuteRegionOp::getCanonicalizationPatterns(RewritePatternSet &results,
267 MLIRContext *context) {
268 results.add<SingleBlockExecuteInliner, MultiBlockExecuteInliner>(context);
269}
270
271void ExecuteRegionOp::getSuccessorRegions(
272 RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
273 // If the predecessor is the ExecuteRegionOp, branch into the body.
274 if (point.isParent()) {
275 regions.push_back(RegionSuccessor(&getRegion()));
276 return;
277 }
278
279 // Otherwise, the region branches back to the parent operation.
280 regions.push_back(RegionSuccessor(getResults()));
281}
282
283//===----------------------------------------------------------------------===//
284// ConditionOp
285//===----------------------------------------------------------------------===//
286
287MutableOperandRange
288ConditionOp::getMutableSuccessorOperands(RegionBranchPoint point) {
  assert((point.isParent() || point == getParentOp().getAfter()) &&
         "condition op can only exit the loop or branch to the after "
         "region");
292 // Pass all operands except the condition to the successor region.
293 return getArgsMutable();
294}
295
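/// For intuition, an illustrative sketch (the "test.*" ops are placeholders):
/// the scf.condition terminator of the "before" region either forwards its
/// trailing arguments into the "after" region or makes them the results of
/// the enclosing scf.while, depending on the runtime value of the condition:
///
///   %res = scf.while (%arg = %init) : (i32) -> i32 {
///     %cond = "test.keep_going"(%arg) : (i32) -> i1
///     scf.condition(%cond) %arg : i32
///   } do {
///   ^bb0(%arg2: i32):
///     %next = "test.step"(%arg2) : (i32) -> i32
///     scf.yield %next : i32
///   }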
296void ConditionOp::getSuccessorRegions(
297 ArrayRef<Attribute> operands, SmallVectorImpl<RegionSuccessor> &regions) {
298 FoldAdaptor adaptor(operands, *this);
299
300 WhileOp whileOp = getParentOp();
301
302 // Condition can either lead to the after region or back to the parent op
303 // depending on whether the condition is true or not.
304 auto boolAttr = dyn_cast_or_null<BoolAttr>(adaptor.getCondition());
305 if (!boolAttr || boolAttr.getValue())
306 regions.emplace_back(&whileOp.getAfter(),
307 whileOp.getAfter().getArguments());
308 if (!boolAttr || !boolAttr.getValue())
309 regions.emplace_back(whileOp.getResults());
310}
311
312//===----------------------------------------------------------------------===//
313// ForOp
314//===----------------------------------------------------------------------===//
315
316void ForOp::build(OpBuilder &builder, OperationState &result, Value lb,
317 Value ub, Value step, ValueRange initArgs,
318 BodyBuilderFn bodyBuilder) {
319 OpBuilder::InsertionGuard guard(builder);
320
321 result.addOperands({lb, ub, step});
322 result.addOperands(initArgs);
323 for (Value v : initArgs)
324 result.addTypes(v.getType());
325 Type t = lb.getType();
326 Region *bodyRegion = result.addRegion();
327 Block *bodyBlock = builder.createBlock(bodyRegion);
328 bodyBlock->addArgument(t, result.location);
329 for (Value v : initArgs)
330 bodyBlock->addArgument(v.getType(), v.getLoc());
331
332 // Create the default terminator if the builder is not provided and if the
333 // iteration arguments are not provided. Otherwise, leave this to the caller
334 // because we don't know which values to return from the loop.
335 if (initArgs.empty() && !bodyBuilder) {
336 ForOp::ensureTerminator(*bodyRegion, builder, result.location);
337 } else if (bodyBuilder) {
338 OpBuilder::InsertionGuard guard(builder);
339 builder.setInsertionPointToStart(bodyBlock);
340 bodyBuilder(builder, result.location, bodyBlock->getArgument(0),
341 bodyBlock->getArguments().drop_front());
342 }
343}
344
345LogicalResult ForOp::verify() {
346 // Check that the number of init args and op results is the same.
347 if (getInitArgs().size() != getNumResults())
348 return emitOpError(
349 "mismatch in number of loop-carried values and defined values");
350
351 return success();
352}
353
354LogicalResult ForOp::verifyRegions() {
  // Check that the body defines a single block argument for the induction
356 // variable.
357 if (getInductionVar().getType() != getLowerBound().getType())
358 return emitOpError(
359 "expected induction variable to be same type as bounds and step");
360
361 if (getNumRegionIterArgs() != getNumResults())
362 return emitOpError(
363 "mismatch in number of basic block args and defined values");
364
365 auto initArgs = getInitArgs();
366 auto iterArgs = getRegionIterArgs();
367 auto opResults = getResults();
368 unsigned i = 0;
369 for (auto e : llvm::zip(initArgs, iterArgs, opResults)) {
370 if (std::get<0>(e).getType() != std::get<2>(e).getType())
371 return emitOpError() << "types mismatch between " << i
372 << "th iter operand and defined value";
373 if (std::get<1>(e).getType() != std::get<2>(e).getType())
374 return emitOpError() << "types mismatch between " << i
375 << "th iter region arg and defined value";
376
377 ++i;
378 }
379 return success();
380}
381
382std::optional<SmallVector<Value>> ForOp::getLoopInductionVars() {
383 return SmallVector<Value>{getInductionVar()};
384}
385
386std::optional<SmallVector<OpFoldResult>> ForOp::getLoopLowerBounds() {
387 return SmallVector<OpFoldResult>{OpFoldResult(getLowerBound())};
388}
389
390std::optional<SmallVector<OpFoldResult>> ForOp::getLoopSteps() {
391 return SmallVector<OpFoldResult>{OpFoldResult(getStep())};
392}
393
394std::optional<SmallVector<OpFoldResult>> ForOp::getLoopUpperBounds() {
395 return SmallVector<OpFoldResult>{OpFoldResult(getUpperBound())};
396}
397
398std::optional<ResultRange> ForOp::getLoopResults() { return getResults(); }
399
/// Promotes the loop body of a forOp to its containing block if it can be
/// determined that the loop has a single iteration.
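///
/// For example (an illustrative sketch; "test.op" is a placeholder), a loop
/// whose constant bounds and step imply exactly one iteration:
///
///   %r = scf.for %iv = %c0 to %c1 step %c1 iter_args(%a = %init) -> (i32) {
///     %v = "test.op"(%iv, %a) : (index, i32) -> i32
///     scf.yield %v : i32
///   }
///
/// is replaced by its body, with %iv substituted by the lower bound %c0, %a
/// by %init, and %r by the yielded %v.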
402LogicalResult ForOp::promoteIfSingleIteration(RewriterBase &rewriter) {
403 std::optional<int64_t> tripCount =
404 constantTripCount(getLowerBound(), getUpperBound(), getStep());
405 if (!tripCount.has_value() || tripCount != 1)
406 return failure();
407
408 // Replace all results with the yielded values.
409 auto yieldOp = cast<scf::YieldOp>(getBody()->getTerminator());
410 rewriter.replaceAllUsesWith(getResults(), getYieldedValues());
411
412 // Replace block arguments with lower bound (replacement for IV) and
413 // iter_args.
414 SmallVector<Value> bbArgReplacements;
415 bbArgReplacements.push_back(getLowerBound());
416 llvm::append_range(bbArgReplacements, getInitArgs());
417
418 // Move the loop body operations to the loop's containing block.
419 rewriter.inlineBlockBefore(getBody(), getOperation()->getBlock(),
420 getOperation()->getIterator(), bbArgReplacements);
421
422 // Erase the old terminator and the loop.
423 rewriter.eraseOp(yieldOp);
424 rewriter.eraseOp(*this);
425
426 return success();
427}
428
429/// Prints the initialization list in the form of
430/// <prefix>(%inner = %outer, %inner2 = %outer2, <...>)
431/// where 'inner' values are assumed to be region arguments and 'outer' values
432/// are regular SSA values.
433static void printInitializationList(OpAsmPrinter &p,
434 Block::BlockArgListType blocksArgs,
435 ValueRange initializers,
436 StringRef prefix = "") {
437 assert(blocksArgs.size() == initializers.size() &&
438 "expected same length of arguments and initializers");
439 if (initializers.empty())
440 return;
441
442 p << prefix << '(';
  llvm::interleaveComma(llvm::zip(blocksArgs, initializers), p, [&](auto it) {
444 p << std::get<0>(it) << " = " << std::get<1>(it);
445 });
446 p << ")";
447}
448
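/// The custom assembly form handled by the print/parse methods below looks
/// roughly as follows (a sketch; value names are placeholders):
///
///   scf.for %iv = %lb to %ub step %s iter_args(%arg = %init) -> (f32) {
///     ...
///     scf.yield %next : f32
///   }
///
/// The `iter_args(...) -> (...)` part is omitted when the loop carries no
/// values, and a `: type` annotation is printed before the region when the
/// induction variable is not of index type.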
449void ForOp::print(OpAsmPrinter &p) {
450 p << " " << getInductionVar() << " = " << getLowerBound() << " to "
451 << getUpperBound() << " step " << getStep();
452
453 printInitializationList(p, getRegionIterArgs(), getInitArgs(), " iter_args");
454 if (!getInitArgs().empty())
455 p << " -> (" << getInitArgs().getTypes() << ')';
456 p << ' ';
457 if (Type t = getInductionVar().getType(); !t.isIndex())
458 p << " : " << t << ' ';
459 p.printRegion(getRegion(),
460 /*printEntryBlockArgs=*/false,
461 /*printBlockTerminators=*/!getInitArgs().empty());
462 p.printOptionalAttrDict((*this)->getAttrs());
463}
464
465ParseResult ForOp::parse(OpAsmParser &parser, OperationState &result) {
466 auto &builder = parser.getBuilder();
467 Type type;
468
469 OpAsmParser::Argument inductionVariable;
470 OpAsmParser::UnresolvedOperand lb, ub, step;
471
472 // Parse the induction variable followed by '='.
473 if (parser.parseOperand(inductionVariable.ssaName) || parser.parseEqual() ||
474 // Parse loop bounds.
475 parser.parseOperand(lb) || parser.parseKeyword("to") ||
476 parser.parseOperand(ub) || parser.parseKeyword("step") ||
477 parser.parseOperand(step))
478 return failure();
479
480 // Parse the optional initial iteration arguments.
481 SmallVector<OpAsmParser::Argument, 4> regionArgs;
482 SmallVector<OpAsmParser::UnresolvedOperand, 4> operands;
483 regionArgs.push_back(inductionVariable);
484
485 bool hasIterArgs = succeeded(parser.parseOptionalKeyword("iter_args"));
486 if (hasIterArgs) {
487 // Parse assignment list and results type list.
488 if (parser.parseAssignmentList(regionArgs, operands) ||
489 parser.parseArrowTypeList(result.types))
490 return failure();
491 }
492
493 if (regionArgs.size() != result.types.size() + 1)
494 return parser.emitError(
495 parser.getNameLoc(),
496 "mismatch in number of loop-carried values and defined values");
497
498 // Parse optional type, else assume Index.
499 if (parser.parseOptionalColon())
500 type = builder.getIndexType();
501 else if (parser.parseType(type))
502 return failure();
503
504 // Set block argument types, so that they are known when parsing the region.
505 regionArgs.front().type = type;
506 for (auto [iterArg, type] :
507 llvm::zip_equal(llvm::drop_begin(regionArgs), result.types))
508 iterArg.type = type;
509
510 // Parse the body region.
511 Region *body = result.addRegion();
512 if (parser.parseRegion(*body, regionArgs))
513 return failure();
514 ForOp::ensureTerminator(*body, builder, result.location);
515
516 // Resolve input operands. This should be done after parsing the region to
517 // catch invalid IR where operands were defined inside of the region.
518 if (parser.resolveOperand(lb, type, result.operands) ||
519 parser.resolveOperand(ub, type, result.operands) ||
520 parser.resolveOperand(step, type, result.operands))
521 return failure();
522 if (hasIterArgs) {
523 for (auto argOperandType : llvm::zip_equal(llvm::drop_begin(regionArgs),
524 operands, result.types)) {
525 Type type = std::get<2>(argOperandType);
526 std::get<0>(argOperandType).type = type;
527 if (parser.resolveOperand(std::get<1>(argOperandType), type,
528 result.operands))
529 return failure();
530 }
531 }
532
533 // Parse the optional attribute list.
534 if (parser.parseOptionalAttrDict(result.attributes))
535 return failure();
536
537 return success();
538}
539
540SmallVector<Region *> ForOp::getLoopRegions() { return {&getRegion()}; }
541
542Block::BlockArgListType ForOp::getRegionIterArgs() {
543 return getBody()->getArguments().drop_front(getNumInductionVars());
544}
545
546MutableArrayRef<OpOperand> ForOp::getInitsMutable() {
547 return getInitArgsMutable();
548}
549
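/// A minimal usage sketch for the method below (hypothetical caller code; the
/// value and variable names are placeholders):
///
///   FailureOr<LoopLikeOpInterface> newLoop =
///       forOp.replaceWithAdditionalYields(
///           rewriter, /*newInitOperands=*/{extraInit},
///           /*replaceInitOperandUsesInLoop=*/true,
///           [&](OpBuilder &b, Location loc, ArrayRef<BlockArgument> newArgs) {
///             return SmallVector<Value>{newArgs[0]};
///           });
///
/// The returned loop carries one extra iter_arg and yields one extra result.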
550FailureOr<LoopLikeOpInterface>
551ForOp::replaceWithAdditionalYields(RewriterBase &rewriter,
552 ValueRange newInitOperands,
553 bool replaceInitOperandUsesInLoop,
554 const NewYieldValuesFn &newYieldValuesFn) {
555 // Create a new loop before the existing one, with the extra operands.
556 OpBuilder::InsertionGuard g(rewriter);
557 rewriter.setInsertionPoint(getOperation());
558 auto inits = llvm::to_vector(getInitArgs());
559 inits.append(newInitOperands.begin(), newInitOperands.end());
560 scf::ForOp newLoop = rewriter.create<scf::ForOp>(
561 getLoc(), getLowerBound(), getUpperBound(), getStep(), inits,
562 [](OpBuilder &, Location, Value, ValueRange) {});
563 newLoop->setAttrs(getPrunedAttributeList(getOperation(), {}));
564
565 // Generate the new yield values and append them to the scf.yield operation.
566 auto yieldOp = cast<scf::YieldOp>(getBody()->getTerminator());
567 ArrayRef<BlockArgument> newIterArgs =
568 newLoop.getBody()->getArguments().take_back(newInitOperands.size());
569 {
570 OpBuilder::InsertionGuard g(rewriter);
571 rewriter.setInsertionPoint(yieldOp);
572 SmallVector<Value> newYieldedValues =
573 newYieldValuesFn(rewriter, getLoc(), newIterArgs);
574 assert(newInitOperands.size() == newYieldedValues.size() &&
575 "expected as many new yield values as new iter operands");
576 rewriter.modifyOpInPlace(yieldOp, [&]() {
577 yieldOp.getResultsMutable().append(newYieldedValues);
578 });
579 }
580
581 // Move the loop body to the new op.
582 rewriter.mergeBlocks(getBody(), newLoop.getBody(),
583 newLoop.getBody()->getArguments().take_front(
584 getBody()->getNumArguments()));
585
586 if (replaceInitOperandUsesInLoop) {
587 // Replace all uses of `newInitOperands` with the corresponding basic block
588 // arguments.
589 for (auto it : llvm::zip(newInitOperands, newIterArgs)) {
590 rewriter.replaceUsesWithIf(std::get<0>(it), std::get<1>(it),
591 [&](OpOperand &use) {
592 Operation *user = use.getOwner();
593 return newLoop->isProperAncestor(user);
594 });
595 }
596 }
597
598 // Replace the old loop.
599 rewriter.replaceOp(getOperation(),
600 newLoop->getResults().take_front(getNumResults()));
601 return cast<LoopLikeOpInterface>(newLoop.getOperation());
602}
603
604ForOp mlir::scf::getForInductionVarOwner(Value val) {
  auto ivArg = llvm::dyn_cast<BlockArgument>(val);
606 if (!ivArg)
607 return ForOp();
608 assert(ivArg.getOwner() && "unlinked block argument");
609 auto *containingOp = ivArg.getOwner()->getParentOp();
610 return dyn_cast_or_null<ForOp>(containingOp);
611}
612
613OperandRange ForOp::getEntrySuccessorOperands(RegionBranchPoint point) {
614 return getInitArgs();
615}
616
617void ForOp::getSuccessorRegions(RegionBranchPoint point,
618 SmallVectorImpl<RegionSuccessor> &regions) {
  // Both the operation itself and the region may be branching into the body
  // or back into the operation itself. It is possible for the loop not to
  // enter the body.
622 regions.push_back(RegionSuccessor(&getRegion(), getRegionIterArgs()));
623 regions.push_back(RegionSuccessor(getResults()));
624}
625
626SmallVector<Region *> ForallOp::getLoopRegions() { return {&getRegion()}; }
627
628/// Promotes the loop body of a forallOp to its containing block if it can be
629/// determined that the loop has a single iteration.
630LogicalResult scf::ForallOp::promoteIfSingleIteration(RewriterBase &rewriter) {
631 for (auto [lb, ub, step] :
632 llvm::zip(getMixedLowerBound(), getMixedUpperBound(), getMixedStep())) {
633 auto tripCount = constantTripCount(lb, ub, step);
634 if (!tripCount.has_value() || *tripCount != 1)
635 return failure();
636 }
637
638 promote(rewriter, *this);
639 return success();
640}
641
642Block::BlockArgListType ForallOp::getRegionIterArgs() {
643 return getBody()->getArguments().drop_front(getRank());
644}
645
646MutableArrayRef<OpOperand> ForallOp::getInitsMutable() {
647 return getOutputsMutable();
648}
649
650/// Promotes the loop body of a scf::ForallOp to its containing block.
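///
/// For example (illustrative IR; "test.make_slice" is a placeholder), a
/// single-iteration forall over a shared tensor output:
///
///   %res = scf.forall (%i) in (1) shared_outs(%o = %t) -> (tensor<8xf32>) {
///     %s = "test.make_slice"(%i) : (index) -> tensor<1xf32>
///     scf.forall.in_parallel {
///       tensor.parallel_insert_slice %s into %o[%i] [1] [1]
///           : tensor<1xf32> into tensor<8xf32>
///     }
///   }
///
/// is promoted so that its body runs once in the enclosing block and the
/// terminator is rewritten into a plain tensor.insert_slice whose result
/// replaces %res.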
651void mlir::scf::promote(RewriterBase &rewriter, scf::ForallOp forallOp) {
652 OpBuilder::InsertionGuard g(rewriter);
653 scf::InParallelOp terminator = forallOp.getTerminator();
654
655 // Replace block arguments with lower bounds (replacements for IVs) and
656 // outputs.
657 SmallVector<Value> bbArgReplacements = forallOp.getLowerBound(rewriter);
658 bbArgReplacements.append(forallOp.getOutputs().begin(),
659 forallOp.getOutputs().end());
660
661 // Move the loop body operations to the loop's containing block.
662 rewriter.inlineBlockBefore(forallOp.getBody(), forallOp->getBlock(),
663 forallOp->getIterator(), bbArgReplacements);
664
665 // Replace the terminator with tensor.insert_slice ops.
666 rewriter.setInsertionPointAfter(forallOp);
667 SmallVector<Value> results;
  results.reserve(forallOp.getResults().size());
669 for (auto &yieldingOp : terminator.getYieldingOps()) {
670 auto parallelInsertSliceOp =
671 cast<tensor::ParallelInsertSliceOp>(yieldingOp);
672
673 Value dst = parallelInsertSliceOp.getDest();
674 Value src = parallelInsertSliceOp.getSource();
675 if (llvm::isa<TensorType>(src.getType())) {
676 results.push_back(rewriter.create<tensor::InsertSliceOp>(
677 forallOp.getLoc(), dst.getType(), src, dst,
678 parallelInsertSliceOp.getOffsets(), parallelInsertSliceOp.getSizes(),
679 parallelInsertSliceOp.getStrides(),
680 parallelInsertSliceOp.getStaticOffsets(),
681 parallelInsertSliceOp.getStaticSizes(),
682 parallelInsertSliceOp.getStaticStrides()));
683 } else {
684 llvm_unreachable("unsupported terminator");
685 }
686 }
687 rewriter.replaceAllUsesWith(forallOp.getResults(), results);
688
689 // Erase the old terminator and the loop.
  rewriter.eraseOp(terminator);
  rewriter.eraseOp(forallOp);
692}
693
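/// Builds a perfect nest of scf.for loops, one per (lb, ub, step) triple, and
/// threads `iterArgs` through the whole nest. A minimal usage sketch
/// (hypothetical caller code; the value names are placeholders):
///
///   scf::LoopNest nest = scf::buildLoopNest(
///       builder, loc, /*lbs=*/ValueRange{c0, c0},
///       /*ubs=*/ValueRange{dimX, dimY}, /*steps=*/ValueRange{c1, c1},
///       /*iterArgs=*/ValueRange{acc},
///       [&](OpBuilder &b, Location l, ValueRange ivs, ValueRange args)
///           -> scf::ValueVector {
///         Value next = b.create<arith::AddFOp>(l, args[0], args[0]);
///         return {next};
///       });
///
/// `nest.loops` then holds the created loops and `nest.results` the results
/// of the outermost loop.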
694LoopNest mlir::scf::buildLoopNest(
695 OpBuilder &builder, Location loc, ValueRange lbs, ValueRange ubs,
696 ValueRange steps, ValueRange iterArgs,
697 function_ref<ValueVector(OpBuilder &, Location, ValueRange, ValueRange)>
698 bodyBuilder) {
699 assert(lbs.size() == ubs.size() &&
700 "expected the same number of lower and upper bounds");
701 assert(lbs.size() == steps.size() &&
702 "expected the same number of lower bounds and steps");
703
704 // If there are no bounds, call the body-building function and return early.
705 if (lbs.empty()) {
706 ValueVector results =
707 bodyBuilder ? bodyBuilder(builder, loc, ValueRange(), iterArgs)
708 : ValueVector();
709 assert(results.size() == iterArgs.size() &&
710 "loop nest body must return as many values as loop has iteration "
711 "arguments");
712 return LoopNest{{}, std::move(results)};
713 }
714
715 // First, create the loop structure iteratively using the body-builder
716 // callback of `ForOp::build`. Do not create `YieldOp`s yet.
717 OpBuilder::InsertionGuard guard(builder);
718 SmallVector<scf::ForOp, 4> loops;
719 SmallVector<Value, 4> ivs;
720 loops.reserve(lbs.size());
  ivs.reserve(lbs.size());
722 ValueRange currentIterArgs = iterArgs;
723 Location currentLoc = loc;
724 for (unsigned i = 0, e = lbs.size(); i < e; ++i) {
725 auto loop = builder.create<scf::ForOp>(
726 currentLoc, lbs[i], ubs[i], steps[i], currentIterArgs,
727 [&](OpBuilder &nestedBuilder, Location nestedLoc, Value iv,
728 ValueRange args) {
729 ivs.push_back(iv);
730 // It is safe to store ValueRange args because it points to block
731 // arguments of a loop operation that we also own.
732 currentIterArgs = args;
733 currentLoc = nestedLoc;
734 });
735 // Set the builder to point to the body of the newly created loop. We don't
736 // do this in the callback because the builder is reset when the callback
737 // returns.
738 builder.setInsertionPointToStart(loop.getBody());
739 loops.push_back(loop);
740 }
741
742 // For all loops but the innermost, yield the results of the nested loop.
743 for (unsigned i = 0, e = loops.size() - 1; i < e; ++i) {
744 builder.setInsertionPointToEnd(loops[i].getBody());
745 builder.create<scf::YieldOp>(loc, loops[i + 1].getResults());
746 }
747
748 // In the body of the innermost loop, call the body building function if any
749 // and yield its results.
750 builder.setInsertionPointToStart(loops.back().getBody());
751 ValueVector results = bodyBuilder
752 ? bodyBuilder(builder, currentLoc, ivs,
753 loops.back().getRegionIterArgs())
754 : ValueVector();
755 assert(results.size() == iterArgs.size() &&
756 "loop nest body must return as many values as loop has iteration "
757 "arguments");
758 builder.setInsertionPointToEnd(loops.back().getBody());
759 builder.create<scf::YieldOp>(loc, results);
760
761 // Return the loops.
762 ValueVector nestResults;
763 llvm::append_range(nestResults, loops.front().getResults());
764 return LoopNest{std::move(loops), std::move(nestResults)};
765}
766
767LoopNest mlir::scf::buildLoopNest(
768 OpBuilder &builder, Location loc, ValueRange lbs, ValueRange ubs,
769 ValueRange steps,
770 function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilder) {
771 // Delegate to the main function by wrapping the body builder.
  return buildLoopNest(builder, loc, lbs, ubs, steps, std::nullopt,
                       [&bodyBuilder](OpBuilder &nestedBuilder,
                                      Location nestedLoc, ValueRange ivs,
                                      ValueRange) -> ValueVector {
                         if (bodyBuilder)
                           bodyBuilder(nestedBuilder, nestedLoc, ivs);
                         return {};
                       });
780}
781
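/// Replaces the init value identified by `operand` with `replacement`, which
/// may have a different type, and rebuilds the loop so that `castFn` converts
/// the value on entry to the body, before the yield, and after the loop. The
/// ForOpTensorCastFolder pattern further below shows the resulting cast
/// placement on a concrete tensor.cast example.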
782SmallVector<Value>
783mlir::scf::replaceAndCastForOpIterArg(RewriterBase &rewriter, scf::ForOp forOp,
784 OpOperand &operand, Value replacement,
785 const ValueTypeCastFnTy &castFn) {
786 assert(operand.getOwner() == forOp);
787 Type oldType = operand.get().getType(), newType = replacement.getType();
788
789 // 1. Create new iter operands, exactly 1 is replaced.
790 assert(operand.getOperandNumber() >= forOp.getNumControlOperands() &&
791 "expected an iter OpOperand");
792 assert(operand.get().getType() != replacement.getType() &&
793 "Expected a different type");
794 SmallVector<Value> newIterOperands;
795 for (OpOperand &opOperand : forOp.getInitArgsMutable()) {
796 if (opOperand.getOperandNumber() == operand.getOperandNumber()) {
797 newIterOperands.push_back(replacement);
798 continue;
799 }
800 newIterOperands.push_back(opOperand.get());
801 }
802
803 // 2. Create the new forOp shell.
804 scf::ForOp newForOp = rewriter.create<scf::ForOp>(
805 forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
806 forOp.getStep(), newIterOperands);
807 newForOp->setAttrs(forOp->getAttrs());
808 Block &newBlock = newForOp.getRegion().front();
809 SmallVector<Value, 4> newBlockTransferArgs(newBlock.getArguments().begin(),
810 newBlock.getArguments().end());
811
812 // 3. Inject an incoming cast op at the beginning of the block for the bbArg
813 // corresponding to the `replacement` value.
814 OpBuilder::InsertionGuard g(rewriter);
815 rewriter.setInsertionPointToStart(&newBlock);
816 BlockArgument newRegionIterArg = newForOp.getTiedLoopRegionIterArg(
817 &newForOp->getOpOperand(operand.getOperandNumber()));
818 Value castIn = castFn(rewriter, newForOp.getLoc(), oldType, newRegionIterArg);
819 newBlockTransferArgs[newRegionIterArg.getArgNumber()] = castIn;
820
821 // 4. Steal the old block ops, mapping to the newBlockTransferArgs.
822 Block &oldBlock = forOp.getRegion().front();
  rewriter.mergeBlocks(&oldBlock, &newBlock, newBlockTransferArgs);
824
825 // 5. Inject an outgoing cast op at the end of the block and yield it instead.
826 auto clonedYieldOp = cast<scf::YieldOp>(newBlock.getTerminator());
827 rewriter.setInsertionPoint(clonedYieldOp);
828 unsigned yieldIdx =
829 newRegionIterArg.getArgNumber() - forOp.getNumInductionVars();
830 Value castOut = castFn(rewriter, newForOp.getLoc(), newType,
831 clonedYieldOp.getOperand(yieldIdx));
832 SmallVector<Value> newYieldOperands = clonedYieldOp.getOperands();
833 newYieldOperands[yieldIdx] = castOut;
834 rewriter.create<scf::YieldOp>(newForOp.getLoc(), newYieldOperands);
  rewriter.eraseOp(clonedYieldOp);
836
837 // 6. Inject an outgoing cast op after the forOp.
838 rewriter.setInsertionPointAfter(newForOp);
839 SmallVector<Value> newResults = newForOp.getResults();
840 newResults[yieldIdx] =
841 castFn(rewriter, newForOp.getLoc(), oldType, newResults[yieldIdx]);
842
843 return newResults;
844}
845
846namespace {
847// Fold away ForOp iter arguments when:
848// 1) The op yields the iter arguments.
849// 2) The argument's corresponding outer region iterators (inputs) are yielded.
850// 3) The iter arguments have no use and the corresponding (operation) results
851// have no use.
852//
853// These arguments must be defined outside of the ForOp region and can just be
854// forwarded after simplifying the op inits, yields and returns.
855//
856// The implementation uses `inlineBlockBefore` to steal the content of the
857// original ForOp and avoid cloning.
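//
// For example (illustrative IR; the "test.*" ops are placeholders):
//
//   %0:2 = scf.for %i = %lb to %ub step %s
//       iter_args(%a = %x, %b = %y) -> (i32, i32) {
//     %v = "test.compute"(%b) : (i32) -> i32
//     scf.yield %a, %v : i32, i32
//   }
//   "test.use"(%0#0, %0#1) : (i32, i32) -> ()
//
// Here %a is yielded unchanged, so the first iter argument is folded away:
// uses of %a inside the loop and of %0#0 outside it are replaced by %x, and
// the rewritten loop keeps only the (%b, %v) pair.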
858struct ForOpIterArgsFolder : public OpRewritePattern<scf::ForOp> {
859 using OpRewritePattern<scf::ForOp>::OpRewritePattern;
860
861 LogicalResult matchAndRewrite(scf::ForOp forOp,
862 PatternRewriter &rewriter) const final {
863 bool canonicalize = false;
864
    // An internal flat vector of block transfer arguments
    // `newBlockTransferArgs` keeps the 1-1 mapping of original to transformed
    // block arguments. It plays the role of an IRMapping for the particular
    // use case of calling into `inlineBlockBefore`.
    int64_t numResults = forOp.getNumResults();
    SmallVector<bool, 4> keepMask;
    keepMask.reserve(numResults);
    SmallVector<Value, 4> newBlockTransferArgs, newIterArgs, newYieldValues,
        newResultValues;
    newBlockTransferArgs.reserve(1 + numResults);
    newBlockTransferArgs.push_back(Value()); // iv placeholder with null value
    newIterArgs.reserve(forOp.getInitArgs().size());
    newYieldValues.reserve(numResults);
    newResultValues.reserve(numResults);
880 DenseMap<std::pair<Value, Value>, std::pair<Value, Value>> initYieldToArg;
881 for (auto [init, arg, result, yielded] :
882 llvm::zip(forOp.getInitArgs(), // iter from outside
883 forOp.getRegionIterArgs(), // iter inside region
884 forOp.getResults(), // op results
885 forOp.getYieldedValues() // iter yield
886 )) {
887 // Forwarded is `true` when:
888 // 1) The region `iter` argument is yielded.
      // 2) The corresponding init value (the input from outside the loop) is
      //    yielded.
890 // 3) The region `iter` argument has no use, and the corresponding op
891 // result has no use.
892 bool forwarded = (arg == yielded) || (init == yielded) ||
893 (arg.use_empty() && result.use_empty());
894 if (forwarded) {
895 canonicalize = true;
896 keepMask.push_back(false);
897 newBlockTransferArgs.push_back(init);
898 newResultValues.push_back(init);
899 continue;
900 }
901
902 // Check if a previous kept argument always has the same values for init
903 // and yielded values.
904 if (auto it = initYieldToArg.find({init, yielded});
905 it != initYieldToArg.end()) {
906 canonicalize = true;
907 keepMask.push_back(false);
908 auto [sameArg, sameResult] = it->second;
909 rewriter.replaceAllUsesWith(arg, sameArg);
910 rewriter.replaceAllUsesWith(result, sameResult);
911 // The replacement value doesn't matter because there are no uses.
912 newBlockTransferArgs.push_back(init);
913 newResultValues.push_back(init);
914 continue;
915 }
916
917 // This value is kept.
918 initYieldToArg.insert({{init, yielded}, {arg, result}});
919 keepMask.push_back(true);
920 newIterArgs.push_back(init);
921 newYieldValues.push_back(yielded);
922 newBlockTransferArgs.push_back(Value()); // placeholder with null value
923 newResultValues.push_back(Value()); // placeholder with null value
924 }
925
926 if (!canonicalize)
927 return failure();
928
929 scf::ForOp newForOp = rewriter.create<scf::ForOp>(
930 forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
931 forOp.getStep(), newIterArgs);
932 newForOp->setAttrs(forOp->getAttrs());
933 Block &newBlock = newForOp.getRegion().front();
934
935 // Replace the null placeholders with newly constructed values.
    newBlockTransferArgs[0] = newBlock.getArgument(0); // iv
937 for (unsigned idx = 0, collapsedIdx = 0, e = newResultValues.size();
938 idx != e; ++idx) {
939 Value &blockTransferArg = newBlockTransferArgs[1 + idx];
940 Value &newResultVal = newResultValues[idx];
941 assert((blockTransferArg && newResultVal) ||
942 (!blockTransferArg && !newResultVal));
943 if (!blockTransferArg) {
944 blockTransferArg = newForOp.getRegionIterArgs()[collapsedIdx];
945 newResultVal = newForOp.getResult(collapsedIdx++);
946 }
947 }
948
949 Block &oldBlock = forOp.getRegion().front();
950 assert(oldBlock.getNumArguments() == newBlockTransferArgs.size() &&
951 "unexpected argument size mismatch");
952
953 // No results case: the scf::ForOp builder already created a zero
954 // result terminator. Merge before this terminator and just get rid of the
955 // original terminator that has been merged in.
956 if (newIterArgs.empty()) {
957 auto newYieldOp = cast<scf::YieldOp>(newBlock.getTerminator());
958 rewriter.inlineBlockBefore(&oldBlock, newYieldOp, newBlockTransferArgs);
      rewriter.eraseOp(newBlock.getTerminator()->getPrevNode());
960 rewriter.replaceOp(forOp, newResultValues);
961 return success();
962 }
963
964 // No terminator case: merge and rewrite the merged terminator.
965 auto cloneFilteredTerminator = [&](scf::YieldOp mergedTerminator) {
966 OpBuilder::InsertionGuard g(rewriter);
967 rewriter.setInsertionPoint(mergedTerminator);
968 SmallVector<Value, 4> filteredOperands;
      filteredOperands.reserve(newResultValues.size());
970 for (unsigned idx = 0, e = keepMask.size(); idx < e; ++idx)
971 if (keepMask[idx])
          filteredOperands.push_back(mergedTerminator.getOperand(idx));
973 rewriter.create<scf::YieldOp>(mergedTerminator.getLoc(),
974 filteredOperands);
975 };
976
    rewriter.mergeBlocks(&oldBlock, &newBlock, newBlockTransferArgs);
978 auto mergedYieldOp = cast<scf::YieldOp>(newBlock.getTerminator());
979 cloneFilteredTerminator(mergedYieldOp);
    rewriter.eraseOp(mergedYieldOp);
981 rewriter.replaceOp(forOp, newResultValues);
982 return success();
983 }
984};
985
/// Util function that tries to compute a constant diff between u and l.
/// Returns std::nullopt when the difference cannot be determined statically.
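/// For example, with %l = arith.constant 4 and %u = arith.constant 20 the
/// result is 16, and with %u = arith.addi %l, %c8 (where %c8 is a constant 8)
/// the result is 8.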
989static std::optional<int64_t> computeConstDiff(Value l, Value u) {
990 IntegerAttr clb, cub;
991 if (matchPattern(l, m_Constant(&clb)) && matchPattern(u, m_Constant(&cub))) {
992 llvm::APInt lbValue = clb.getValue();
993 llvm::APInt ubValue = cub.getValue();
994 return (ubValue - lbValue).getSExtValue();
995 }
996
997 // Else a simple pattern match for x + c or c + x
998 llvm::APInt diff;
999 if (matchPattern(
1000 u, m_Op<arith::AddIOp>(matchers::m_Val(l), m_ConstantInt(&diff))) ||
1001 matchPattern(
1002 u, m_Op<arith::AddIOp>(m_ConstantInt(&diff), matchers::m_Val(l))))
1003 return diff.getSExtValue();
1004 return std::nullopt;
1005}
1006
1007/// Rewriting pattern that erases loops that are known not to iterate, replaces
1008/// single-iteration loops with their bodies, and removes empty loops that
1009/// iterate at least once and only return values defined outside of the loop.
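///
/// For example, `scf.for %i = %c0 to %c0 step %c1` (and, more generally, any
/// loop whose constant bounds imply a non-positive trip count) folds to its
/// init values, while `scf.for %i = %c0 to %c1 step %c1` is replaced by a
/// single inlined copy of its body.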
1010struct SimplifyTrivialLoops : public OpRewritePattern<ForOp> {
1011 using OpRewritePattern<ForOp>::OpRewritePattern;
1012
1013 LogicalResult matchAndRewrite(ForOp op,
1014 PatternRewriter &rewriter) const override {
1015 // If the upper bound is the same as the lower bound, the loop does not
1016 // iterate, just remove it.
1017 if (op.getLowerBound() == op.getUpperBound()) {
1018 rewriter.replaceOp(op, op.getInitArgs());
1019 return success();
1020 }
1021
1022 std::optional<int64_t> diff =
1023 computeConstDiff(op.getLowerBound(), op.getUpperBound());
1024 if (!diff)
1025 return failure();
1026
1027 // If the loop is known to have 0 iterations, remove it.
1028 if (*diff <= 0) {
1029 rewriter.replaceOp(op, op.getInitArgs());
1030 return success();
1031 }
1032
1033 std::optional<llvm::APInt> maybeStepValue = op.getConstantStep();
1034 if (!maybeStepValue)
1035 return failure();
1036
1037 // If the loop is known to have 1 iteration, inline its body and remove the
1038 // loop.
1039 llvm::APInt stepValue = *maybeStepValue;
    if (stepValue.sge(*diff)) {
      SmallVector<Value, 4> blockArgs;
      blockArgs.reserve(op.getInitArgs().size() + 1);
      blockArgs.push_back(op.getLowerBound());
1044 llvm::append_range(blockArgs, op.getInitArgs());
1045 replaceOpWithRegion(rewriter, op, op.getRegion(), blockArgs);
1046 return success();
1047 }
1048
    // Now we are left with loops that have more than one iteration.
    Block &block = op.getRegion().front();
    if (!llvm::hasSingleElement(block))
1052 return failure();
1053 // If the loop is empty, iterates at least once, and only returns values
1054 // defined outside of the loop, remove it and replace it with yield values.
1055 if (llvm::any_of(op.getYieldedValues(),
1056 [&](Value v) { return !op.isDefinedOutsideOfLoop(v); }))
1057 return failure();
1058 rewriter.replaceOp(op, op.getYieldedValues());
1059 return success();
1060 }
1061};
1062
/// Fold scf.for iter_arg/result pairs that go through an incoming/outgoing
/// tensor.cast op pair so as to pull the tensor.cast inside the scf.for:
1065///
1066/// ```
1067/// %0 = tensor.cast %t0 : tensor<32x1024xf32> to tensor<?x?xf32>
1068/// %1 = scf.for %i = %c0 to %c1024 step %c32 iter_args(%iter_t0 = %0)
1069/// -> (tensor<?x?xf32>) {
1070/// %2 = call @do(%iter_t0) : (tensor<?x?xf32>) -> tensor<?x?xf32>
1071/// scf.yield %2 : tensor<?x?xf32>
1072/// }
1073/// use_of(%1)
1074/// ```
1075///
1076/// folds into:
1077///
1078/// ```
1079/// %0 = scf.for %arg2 = %c0 to %c1024 step %c32 iter_args(%arg3 = %arg0)
1080/// -> (tensor<32x1024xf32>) {
1081/// %2 = tensor.cast %arg3 : tensor<32x1024xf32> to tensor<?x?xf32>
1082/// %3 = call @do(%2) : (tensor<?x?xf32>) -> tensor<?x?xf32>
1083/// %4 = tensor.cast %3 : tensor<?x?xf32> to tensor<32x1024xf32>
1084/// scf.yield %4 : tensor<32x1024xf32>
1085/// }
1086/// %1 = tensor.cast %0 : tensor<32x1024xf32> to tensor<?x?xf32>
1087/// use_of(%1)
1088/// ```
1089struct ForOpTensorCastFolder : public OpRewritePattern<ForOp> {
1090 using OpRewritePattern<ForOp>::OpRewritePattern;
1091
1092 LogicalResult matchAndRewrite(ForOp op,
1093 PatternRewriter &rewriter) const override {
1094 for (auto it : llvm::zip(op.getInitArgsMutable(), op.getResults())) {
1095 OpOperand &iterOpOperand = std::get<0>(it);
1096 auto incomingCast = iterOpOperand.get().getDefiningOp<tensor::CastOp>();
1097 if (!incomingCast ||
1098 incomingCast.getSource().getType() == incomingCast.getType())
1099 continue;
      // Skip if the dest type of the cast does not preserve static
      // information present in the source type.
1102 if (!tensor::preservesStaticInformation(
1103 incomingCast.getDest().getType(),
1104 incomingCast.getSource().getType()))
1105 continue;
1106 if (!std::get<1>(it).hasOneUse())
1107 continue;
1108
1109 // Create a new ForOp with that iter operand replaced.
1110 rewriter.replaceOp(
1111 op, replaceAndCastForOpIterArg(
1112 rewriter, op, iterOpOperand, incomingCast.getSource(),
1113 [](OpBuilder &b, Location loc, Type type, Value source) {
1114 return b.create<tensor::CastOp>(loc, type, source);
1115 }));
1116 return success();
1117 }
1118 return failure();
1119 }
1120};
1121
1122} // namespace
1123
1124void ForOp::getCanonicalizationPatterns(RewritePatternSet &results,
1125 MLIRContext *context) {
1126 results.add<ForOpIterArgsFolder, SimplifyTrivialLoops, ForOpTensorCastFolder>(
1127 context);
1128}
1129
1130std::optional<APInt> ForOp::getConstantStep() {
1131 IntegerAttr step;
1132 if (matchPattern(getStep(), m_Constant(&step)))
1133 return step.getValue();
1134 return {};
1135}
1136
1137std::optional<MutableArrayRef<OpOperand>> ForOp::getYieldedValuesMutable() {
1138 return cast<scf::YieldOp>(getBody()->getTerminator()).getResultsMutable();
1139}
1140
1141Speculation::Speculatability ForOp::getSpeculatability() {
1142 // `scf.for (I = Start; I < End; I += 1)` terminates for all values of Start
1143 // and End.
1144 if (auto constantStep = getConstantStep())
1145 if (*constantStep == 1)
1146 return Speculation::RecursivelySpeculatable;
1147
1148 // For Step != 1, the loop may not terminate. We can add more smarts here if
1149 // needed.
1150 return Speculation::NotSpeculatable;
1151}
1152
1153//===----------------------------------------------------------------------===//
1154// ForallOp
1155//===----------------------------------------------------------------------===//
1156
1157LogicalResult ForallOp::verify() {
1158 unsigned numLoops = getRank();
1159 // Check number of outputs.
1160 if (getNumResults() != getOutputs().size())
1161 return emitOpError("produces ")
1162 << getNumResults() << " results, but has only "
1163 << getOutputs().size() << " outputs";
1164
1165 // Check that the body defines block arguments for thread indices and outputs.
1166 auto *body = getBody();
1167 if (body->getNumArguments() != numLoops + getOutputs().size())
1168 return emitOpError("region expects ") << numLoops << " arguments";
1169 for (int64_t i = 0; i < numLoops; ++i)
1170 if (!body->getArgument(i).getType().isIndex())
1171 return emitOpError("expects ")
1172 << i << "-th block argument to be an index";
1173 for (unsigned i = 0; i < getOutputs().size(); ++i)
1174 if (body->getArgument(i + numLoops).getType() != getOutputs()[i].getType())
1175 return emitOpError("type mismatch between ")
1176 << i << "-th output and corresponding block argument";
1177 if (getMapping().has_value() && !getMapping()->empty()) {
1178 if (static_cast<int64_t>(getMapping()->size()) != numLoops)
1179 return emitOpError() << "mapping attribute size must match op rank";
1180 for (auto map : getMapping()->getValue()) {
1181 if (!isa<DeviceMappingAttrInterface>(map))
1182 return emitOpError()
1183 << getMappingAttrName() << " is not device mapping attribute";
1184 }
1185 }
1186
1187 // Verify mixed static/dynamic control variables.
1188 Operation *op = getOperation();
1189 if (failed(verifyListOfOperandsOrIntegers(op, "lower bound", numLoops,
1190 getStaticLowerBound(),
1191 getDynamicLowerBound())))
1192 return failure();
1193 if (failed(verifyListOfOperandsOrIntegers(op, "upper bound", numLoops,
1194 getStaticUpperBound(),
1195 getDynamicUpperBound())))
1196 return failure();
1197 if (failed(verifyListOfOperandsOrIntegers(op, "step", numLoops,
1198 getStaticStep(), getDynamicStep())))
1199 return failure();
1200
1201 return success();
1202}
1203
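/// The custom assembly form handled by the print/parse methods below looks
/// roughly as follows (sketches; value names are placeholders). Normalized
/// form (zero lower bounds and unit steps):
///
///   scf.forall (%i, %j) in (%ub0, %ub1) shared_outs(%o = %t)
///       -> (tensor<?x?xf32>) {
///     ...
///     scf.forall.in_parallel { ... }
///   }
///
/// General form:
///
///   scf.forall (%i) = (%lb) to (%ub) step (%s) shared_outs(%o = %t)
///       -> (tensor<?xf32>) {
///     ...
///   }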
1204void ForallOp::print(OpAsmPrinter &p) {
1205 Operation *op = getOperation();
1206 p << " (" << getInductionVars();
1207 if (isNormalized()) {
1208 p << ") in ";
1209 printDynamicIndexList(p, op, getDynamicUpperBound(), getStaticUpperBound(),
1210 /*valueTypes=*/{}, /*scalables=*/{},
1211 OpAsmParser::Delimiter::Paren);
1212 } else {
1213 p << ") = ";
1214 printDynamicIndexList(p, op, getDynamicLowerBound(), getStaticLowerBound(),
1215 /*valueTypes=*/{}, /*scalables=*/{},
1216 OpAsmParser::Delimiter::Paren);
1217 p << " to ";
1218 printDynamicIndexList(p, op, getDynamicUpperBound(), getStaticUpperBound(),
1219 /*valueTypes=*/{}, /*scalables=*/{},
1220 OpAsmParser::Delimiter::Paren);
1221 p << " step ";
1222 printDynamicIndexList(p, op, getDynamicStep(), getStaticStep(),
1223 /*valueTypes=*/{}, /*scalables=*/{},
1224 OpAsmParser::Delimiter::Paren);
1225 }
1226 printInitializationList(p, getRegionOutArgs(), getOutputs(), " shared_outs");
1227 p << " ";
1228 if (!getRegionOutArgs().empty())
1229 p << "-> (" << getResultTypes() << ") ";
1230 p.printRegion(getRegion(),
1231 /*printEntryBlockArgs=*/false,
1232 /*printBlockTerminators=*/getNumResults() > 0);
1233 p.printOptionalAttrDict(op->getAttrs(), {getOperandSegmentSizesAttrName(),
1234 getStaticLowerBoundAttrName(),
1235 getStaticUpperBoundAttrName(),
1236 getStaticStepAttrName()});
1237}
1238
1239ParseResult ForallOp::parse(OpAsmParser &parser, OperationState &result) {
1240 OpBuilder b(parser.getContext());
1241 auto indexType = b.getIndexType();
1242
1243 // Parse an opening `(` followed by thread index variables followed by `)`
1244 // TODO: when we can refer to such "induction variable"-like handles from the
1245 // declarative assembly format, we can implement the parser as a custom hook.
1246 SmallVector<OpAsmParser::Argument, 4> ivs;
1247 if (parser.parseArgumentList(ivs, OpAsmParser::Delimiter::Paren))
1248 return failure();
1249
1250 DenseI64ArrayAttr staticLbs, staticUbs, staticSteps;
1251 SmallVector<OpAsmParser::UnresolvedOperand> dynamicLbs, dynamicUbs,
1252 dynamicSteps;
1253 if (succeeded(parser.parseOptionalKeyword("in"))) {
1254 // Parse upper bounds.
1255 if (parseDynamicIndexList(parser, dynamicUbs, staticUbs,
1256 /*valueTypes=*/nullptr,
1257 OpAsmParser::Delimiter::Paren) ||
1258 parser.resolveOperands(dynamicUbs, indexType, result.operands))
1259 return failure();
1260
1261 unsigned numLoops = ivs.size();
1262 staticLbs = b.getDenseI64ArrayAttr(SmallVector<int64_t>(numLoops, 0));
1263 staticSteps = b.getDenseI64ArrayAttr(SmallVector<int64_t>(numLoops, 1));
1264 } else {
1265 // Parse lower bounds.
1266 if (parser.parseEqual() ||
1267 parseDynamicIndexList(parser, dynamicLbs, staticLbs,
1268 /*valueTypes=*/nullptr,
1269 OpAsmParser::Delimiter::Paren) ||
1270
1271 parser.resolveOperands(dynamicLbs, indexType, result.operands))
1272 return failure();
1273
1274 // Parse upper bounds.
1275 if (parser.parseKeyword("to") ||
1276 parseDynamicIndexList(parser, dynamicUbs, staticUbs,
1277 /*valueTypes=*/nullptr,
1278 OpAsmParser::Delimiter::Paren) ||
1279 parser.resolveOperands(dynamicUbs, indexType, result.operands))
1280 return failure();
1281
1282 // Parse step values.
1283 if (parser.parseKeyword("step") ||
1284 parseDynamicIndexList(parser, dynamicSteps, staticSteps,
1285 /*valueTypes=*/nullptr,
1286 OpAsmParser::Delimiter::Paren) ||
1287 parser.resolveOperands(dynamicSteps, indexType, result.operands))
1288 return failure();
1289 }
1290
1291 // Parse out operands and results.
1292 SmallVector<OpAsmParser::Argument, 4> regionOutArgs;
1293 SmallVector<OpAsmParser::UnresolvedOperand, 4> outOperands;
1294 SMLoc outOperandsLoc = parser.getCurrentLocation();
1295 if (succeeded(parser.parseOptionalKeyword("shared_outs"))) {
1296 if (outOperands.size() != result.types.size())
1297 return parser.emitError(outOperandsLoc,
1298 "mismatch between out operands and types");
1299 if (parser.parseAssignmentList(regionOutArgs, outOperands) ||
1300 parser.parseOptionalArrowTypeList(result.types) ||
1301 parser.resolveOperands(outOperands, result.types, outOperandsLoc,
1302 result.operands))
1303 return failure();
1304 }
1305
1306 // Parse region.
1307 SmallVector<OpAsmParser::Argument, 4> regionArgs;
1308 std::unique_ptr<Region> region = std::make_unique<Region>();
1309 for (auto &iv : ivs) {
1310 iv.type = b.getIndexType();
1311 regionArgs.push_back(iv);
1312 }
1313 for (const auto &it : llvm::enumerate(regionOutArgs)) {
1314 auto &out = it.value();
1315 out.type = result.types[it.index()];
1316 regionArgs.push_back(out);
1317 }
1318 if (parser.parseRegion(*region, regionArgs))
1319 return failure();
1320
1321 // Ensure terminator and move region.
1322 ForallOp::ensureTerminator(*region, b, result.location);
1323 result.addRegion(std::move(region));
1324
1325 // Parse the optional attribute list.
1326 if (parser.parseOptionalAttrDict(result.attributes))
1327 return failure();
1328
1329 result.addAttribute("staticLowerBound", staticLbs);
1330 result.addAttribute("staticUpperBound", staticUbs);
1331 result.addAttribute("staticStep", staticSteps);
1332 result.addAttribute("operandSegmentSizes",
1333 parser.getBuilder().getDenseI32ArrayAttr(
1334 {static_cast<int32_t>(dynamicLbs.size()),
1335 static_cast<int32_t>(dynamicUbs.size()),
1336 static_cast<int32_t>(dynamicSteps.size()),
1337 static_cast<int32_t>(outOperands.size())}));
1338 return success();
1339}
1340
1341// Builder that takes loop bounds.
1342void ForallOp::build(
1343 mlir::OpBuilder &b, mlir::OperationState &result,
1344 ArrayRef<OpFoldResult> lbs, ArrayRef<OpFoldResult> ubs,
1345 ArrayRef<OpFoldResult> steps, ValueRange outputs,
1346 std::optional<ArrayAttr> mapping,
1347 function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn) {
1348 SmallVector<int64_t> staticLbs, staticUbs, staticSteps;
1349 SmallVector<Value> dynamicLbs, dynamicUbs, dynamicSteps;
1350 dispatchIndexOpFoldResults(lbs, dynamicLbs, staticLbs);
1351 dispatchIndexOpFoldResults(ubs, dynamicUbs, staticUbs);
1352 dispatchIndexOpFoldResults(steps, dynamicSteps, staticSteps);
1353
1354 result.addOperands(dynamicLbs);
1355 result.addOperands(dynamicUbs);
1356 result.addOperands(dynamicSteps);
1357 result.addOperands(outputs);
1358 result.addTypes(TypeRange(outputs));
1359
1360 result.addAttribute(getStaticLowerBoundAttrName(result.name),
1361 b.getDenseI64ArrayAttr(staticLbs));
1362 result.addAttribute(getStaticUpperBoundAttrName(result.name),
1363 b.getDenseI64ArrayAttr(staticUbs));
1364 result.addAttribute(getStaticStepAttrName(result.name),
1365 b.getDenseI64ArrayAttr(staticSteps));
1366 result.addAttribute(
1367 "operandSegmentSizes",
1368 b.getDenseI32ArrayAttr({static_cast<int32_t>(dynamicLbs.size()),
1369 static_cast<int32_t>(dynamicUbs.size()),
1370 static_cast<int32_t>(dynamicSteps.size()),
1371 static_cast<int32_t>(outputs.size())}));
1372 if (mapping.has_value()) {
1373 result.addAttribute(ForallOp::getMappingAttrName(result.name),
1374 mapping.value());
1375 }
1376
1377 Region *bodyRegion = result.addRegion();
1378 OpBuilder::InsertionGuard g(b);
1379 b.createBlock(bodyRegion);
1380 Block &bodyBlock = bodyRegion->front();
1381
1382 // Add block arguments for indices and outputs.
1383 bodyBlock.addArguments(
1384 SmallVector<Type>(lbs.size(), b.getIndexType()),
1385 SmallVector<Location>(staticLbs.size(), result.location));
1386 bodyBlock.addArguments(
1387 TypeRange(outputs),
1388 SmallVector<Location>(outputs.size(), result.location));
1389
1390 b.setInsertionPointToStart(&bodyBlock);
1391 if (!bodyBuilderFn) {
1392 ForallOp::ensureTerminator(*bodyRegion, b, result.location);
1393 return;
1394 }
1395 bodyBuilderFn(b, result.location, bodyBlock.getArguments());
1396}
1397
1398// Builder that takes loop bounds.
1399void ForallOp::build(
1400 mlir::OpBuilder &b, mlir::OperationState &result,
1401 ArrayRef<OpFoldResult> ubs, ValueRange outputs,
1402 std::optional<ArrayAttr> mapping,
1403 function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn) {
1404 unsigned numLoops = ubs.size();
1405 SmallVector<OpFoldResult> lbs(numLoops, b.getIndexAttr(0));
1406 SmallVector<OpFoldResult> steps(numLoops, b.getIndexAttr(1));
1407 build(b, result, lbs, ubs, steps, outputs, mapping, bodyBuilderFn);
1408}
1409
1410// Checks if the lbs are zeros and steps are ones.
1411bool ForallOp::isNormalized() {
1412 auto allEqual = [](ArrayRef<OpFoldResult> results, int64_t val) {
1413 return llvm::all_of(results, [&](OpFoldResult ofr) {
1414 auto intValue = getConstantIntValue(ofr);
1415 return intValue.has_value() && intValue == val;
1416 });
1417 };
1418 return allEqual(getMixedLowerBound(), 0) && allEqual(getMixedStep(), 1);
1419}
1420
1421InParallelOp ForallOp::getTerminator() {
1422 return cast<InParallelOp>(getBody()->getTerminator());
1423}
1424
1425SmallVector<Operation *> ForallOp::getCombiningOps(BlockArgument bbArg) {
1426 SmallVector<Operation *> storeOps;
1427 InParallelOp inParallelOp = getTerminator();
1428 for (Operation &yieldOp : inParallelOp.getYieldingOps()) {
1429 if (auto parallelInsertSliceOp =
1430 dyn_cast<tensor::ParallelInsertSliceOp>(yieldOp);
1431 parallelInsertSliceOp && parallelInsertSliceOp.getDest() == bbArg) {
1432 storeOps.push_back(parallelInsertSliceOp);
1433 }
1434 }
1435 return storeOps;
1436}
1437
1438std::optional<SmallVector<Value>> ForallOp::getLoopInductionVars() {
1439 return SmallVector<Value>{getBody()->getArguments().take_front(getRank())};
1440}
1441
1442// Get lower bounds as OpFoldResult.
1443std::optional<SmallVector<OpFoldResult>> ForallOp::getLoopLowerBounds() {
1444 Builder b(getOperation()->getContext());
1445 return getMixedValues(getStaticLowerBound(), getDynamicLowerBound(), b);
1446}
1447
1448// Get upper bounds as OpFoldResult.
1449std::optional<SmallVector<OpFoldResult>> ForallOp::getLoopUpperBounds() {
1450 Builder b(getOperation()->getContext());
1451 return getMixedValues(getStaticUpperBound(), getDynamicUpperBound(), b);
1452}
1453
1454// Get steps as OpFoldResult.
1455std::optional<SmallVector<OpFoldResult>> ForallOp::getLoopSteps() {
1456 Builder b(getOperation()->getContext());
1457 return getMixedValues(getStaticStep(), getDynamicStep(), b);
1458}
1459
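/// Return the enclosing scf.forall op if `val` is a block argument of its
/// body, and a null ForallOp otherwise.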
1460ForallOp mlir::scf::getForallOpThreadIndexOwner(Value val) {
  auto tidxArg = llvm::dyn_cast<BlockArgument>(val);
1462 if (!tidxArg)
1463 return ForallOp();
1464 assert(tidxArg.getOwner() && "unlinked block argument");
1465 auto *containingOp = tidxArg.getOwner()->getParentOp();
1466 return dyn_cast<ForallOp>(containingOp);
1467}
1468
1469namespace {
1470/// Fold tensor.dim(forall shared_outs(... = %t)) to tensor.dim(%t).
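///
/// For example (illustrative):
///   %0 = scf.forall ... shared_outs(%arg = %t) -> (tensor<?xf32>) { ... }
///   %d = tensor.dim %0, %c0 : tensor<?xf32>
/// becomes
///   %d = tensor.dim %t, %c0 : tensor<?xf32>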
1471struct DimOfForallOp : public OpRewritePattern<tensor::DimOp> {
1472 using OpRewritePattern<tensor::DimOp>::OpRewritePattern;
1473
1474 LogicalResult matchAndRewrite(tensor::DimOp dimOp,
1475 PatternRewriter &rewriter) const final {
1476 auto forallOp = dimOp.getSource().getDefiningOp<ForallOp>();
1477 if (!forallOp)
1478 return failure();
1479 Value sharedOut =
1480 forallOp.getTiedOpOperand(llvm::cast<OpResult>(dimOp.getSource()))
1481 ->get();
1482 rewriter.modifyOpInPlace(
1483 dimOp, [&]() { dimOp.getSourceMutable().assign(sharedOut); });
1484 return success();
1485 }
1486};
1487
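/// Fold dynamic lower/upper bounds and steps of an scf.forall that are
/// produced by constants into their static (attribute) form. For example
/// (illustrative):
///   %c16 = arith.constant 16 : index
///   scf.forall (%i) in (%c16) { ... }
/// becomes
///   scf.forall (%i) in (16) { ... }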
1488class ForallOpControlOperandsFolder : public OpRewritePattern<ForallOp> {
1489public:
1490 using OpRewritePattern<ForallOp>::OpRewritePattern;
1491
1492 LogicalResult matchAndRewrite(ForallOp op,
1493 PatternRewriter &rewriter) const override {
1494 SmallVector<OpFoldResult> mixedLowerBound(op.getMixedLowerBound());
1495 SmallVector<OpFoldResult> mixedUpperBound(op.getMixedUpperBound());
1496 SmallVector<OpFoldResult> mixedStep(op.getMixedStep());
    if (failed(foldDynamicIndexList(mixedLowerBound)) &&
        failed(foldDynamicIndexList(mixedUpperBound)) &&
        failed(foldDynamicIndexList(mixedStep)))
1500 return failure();
1501
1502 rewriter.modifyOpInPlace(op, [&]() {
1503 SmallVector<Value> dynamicLowerBound, dynamicUpperBound, dynamicStep;
1504 SmallVector<int64_t> staticLowerBound, staticUpperBound, staticStep;
      dispatchIndexOpFoldResults(mixedLowerBound, dynamicLowerBound,
                                 staticLowerBound);
1507 op.getDynamicLowerBoundMutable().assign(dynamicLowerBound);
1508 op.setStaticLowerBound(staticLowerBound);
1509
      dispatchIndexOpFoldResults(mixedUpperBound, dynamicUpperBound,
                                 staticUpperBound);
1512 op.getDynamicUpperBoundMutable().assign(dynamicUpperBound);
1513 op.setStaticUpperBound(staticUpperBound);
1514
      dispatchIndexOpFoldResults(mixedStep, dynamicStep, staticStep);
1516 op.getDynamicStepMutable().assign(dynamicStep);
1517 op.setStaticStep(staticStep);
1518
1519 op->setAttr(ForallOp::getOperandSegmentSizeAttr(),
1520 rewriter.getDenseI32ArrayAttr(
1521 {static_cast<int32_t>(dynamicLowerBound.size()),
1522 static_cast<int32_t>(dynamicUpperBound.size()),
1523 static_cast<int32_t>(dynamicStep.size()),
1524 static_cast<int32_t>(op.getNumResults())}));
1525 });
1526 return success();
1527 }
1528};
1529
1530/// The following canonicalization pattern folds the iter arguments of
1531/// scf.forall op if :-
1532/// 1. The corresponding result has zero uses.
/// 2. The iter argument is NOT being modified within the loop body.
1535///
1536/// Example of first case :-
1537/// INPUT:
1538/// %res:3 = scf.forall ... shared_outs(%arg0 = %a, %arg1 = %b, %arg2 = %c)
1539/// {
1540/// ...
1541/// <SOME USE OF %arg0>
1542/// <SOME USE OF %arg1>
1543/// <SOME USE OF %arg2>
1544/// ...
1545/// scf.forall.in_parallel {
1546/// <STORE OP WITH DESTINATION %arg1>
1547/// <STORE OP WITH DESTINATION %arg0>
1548/// <STORE OP WITH DESTINATION %arg2>
1549/// }
1550/// }
1551/// return %res#1
1552///
1553/// OUTPUT:
1554/// %res:3 = scf.forall ... shared_outs(%new_arg0 = %b)
1555/// {
1556/// ...
1557/// <SOME USE OF %a>
1558/// <SOME USE OF %new_arg0>
1559/// <SOME USE OF %c>
1560/// ...
1561/// scf.forall.in_parallel {
1562/// <STORE OP WITH DESTINATION %new_arg0>
1563/// }
1564/// }
1565/// return %res
1566///
/// NOTE: 1. All uses of the folded shared_outs (iter argument) within the
///          scf.forall are replaced by their corresponding operands.
1569/// 2. Even if there are <STORE OP WITH DESTINATION *> ops within the body
1570/// of the scf.forall besides within scf.forall.in_parallel terminator,
1571/// this canonicalization remains valid. For more details, please refer
1572/// to :
1573/// https://github.com/llvm/llvm-project/pull/90189#discussion_r1589011124
1574/// 3. TODO(avarma): Generalize it for other store ops. Currently it
1575/// handles tensor.parallel_insert_slice ops only.
1576///
1577/// Example of second case :-
1578/// INPUT:
1579/// %res:2 = scf.forall ... shared_outs(%arg0 = %a, %arg1 = %b)
1580/// {
1581/// ...
1582/// <SOME USE OF %arg0>
1583/// <SOME USE OF %arg1>
1584/// ...
1585/// scf.forall.in_parallel {
1586/// <STORE OP WITH DESTINATION %arg1>
1587/// }
1588/// }
1589/// return %res#0, %res#1
1590///
1591/// OUTPUT:
1592/// %res = scf.forall ... shared_outs(%new_arg0 = %b)
1593/// {
1594/// ...
1595/// <SOME USE OF %a>
1596/// <SOME USE OF %new_arg0>
1597/// ...
1598/// scf.forall.in_parallel {
1599/// <STORE OP WITH DESTINATION %new_arg0>
1600/// }
1601/// }
1602/// return %a, %res
1603struct ForallOpIterArgsFolder : public OpRewritePattern<ForallOp> {
1604 using OpRewritePattern<ForallOp>::OpRewritePattern;
1605
1606 LogicalResult matchAndRewrite(ForallOp forallOp,
1607 PatternRewriter &rewriter) const final {
1608 // Step 1: For a given i-th result of scf.forall, check the following :-
1609 // a. If it has any use.
1610 // b. If the corresponding iter argument is being modified within
1611 // the loop, i.e. has at least one store op with the iter arg as
1612 // its destination operand. For this we use
1613 // ForallOp::getCombiningOps(iter_arg).
1614 //
1615 // Based on the check we maintain the following :-
1616 // a. `resultToDelete` - i-th result of scf.forall that'll be
1617 // deleted.
1618 // b. `resultToReplace` - i-th result of the old scf.forall
1619 // whose uses will be replaced by the new scf.forall.
1620 // c. `newOuts` - the shared_outs' operand of the new scf.forall
1621 // corresponding to the i-th result with at least one use.
1622 SetVector<OpResult> resultToDelete;
1623 SmallVector<Value> resultToReplace;
1624 SmallVector<Value> newOuts;
1625 for (OpResult result : forallOp.getResults()) {
1626 OpOperand *opOperand = forallOp.getTiedOpOperand(result);
1627 BlockArgument blockArg = forallOp.getTiedBlockArgument(opOperand);
1628 if (result.use_empty() || forallOp.getCombiningOps(blockArg).empty()) {
1629 resultToDelete.insert(result);
1630 } else {
1631 resultToReplace.push_back(result);
1632 newOuts.push_back(opOperand->get());
1633 }
1634 }
1635
    // Return early if all results of the scf.forall have at least one use and
    // are being modified within the loop.
1638 if (resultToDelete.empty())
1639 return failure();
1640
    // Step 2: For the i-th result, do the following :-
1642 // a. Fetch the corresponding BlockArgument.
1643 // b. Look for store ops (currently tensor.parallel_insert_slice)
1644 // with the BlockArgument as its destination operand.
1645 // c. Remove the operations fetched in b.
1646 for (OpResult result : resultToDelete) {
1647 OpOperand *opOperand = forallOp.getTiedOpOperand(result);
1648 BlockArgument blockArg = forallOp.getTiedBlockArgument(opOperand);
1649 SmallVector<Operation *> combiningOps =
1650 forallOp.getCombiningOps(blockArg);
1651 for (Operation *combiningOp : combiningOps)
1652 rewriter.eraseOp(combiningOp);
1653 }
1654
1655 // Step 3. Create a new scf.forall op with the new shared_outs' operands
1656 // fetched earlier
1657 auto newForallOp = rewriter.create<scf::ForallOp>(
1658 forallOp.getLoc(), forallOp.getMixedLowerBound(),
1659 forallOp.getMixedUpperBound(), forallOp.getMixedStep(), newOuts,
1660 forallOp.getMapping(),
1661 /*bodyBuilderFn =*/[](OpBuilder &, Location, ValueRange) {});
1662
1663 // Step 4. Merge the block of the old scf.forall into the newly created
1664 // scf.forall using the new set of arguments.
1665 Block *loopBody = forallOp.getBody();
1666 Block *newLoopBody = newForallOp.getBody();
1667 ArrayRef<BlockArgument> newBbArgs = newLoopBody->getArguments();
1668 // Form initial new bbArg list with just the control operands of the new
1669 // scf.forall op.
1670 SmallVector<Value> newBlockArgs =
        llvm::map_to_vector(newBbArgs.take_front(forallOp.getRank()),
1672 [](BlockArgument b) -> Value { return b; });
1673 Block::BlockArgListType newSharedOutsArgs = newForallOp.getRegionOutArgs();
1674 unsigned index = 0;
1675 // Take the new corresponding bbArg if the old bbArg was used as a
1676 // destination in the in_parallel op. For all other bbArgs, use the
1677 // corresponding init_arg from the old scf.forall op.
1678 for (OpResult result : forallOp.getResults()) {
1679 if (resultToDelete.count(result)) {
1680 newBlockArgs.push_back(forallOp.getTiedOpOperand(result)->get());
1681 } else {
1682 newBlockArgs.push_back(newSharedOutsArgs[index++]);
1683 }
1684 }
    rewriter.mergeBlocks(loopBody, newLoopBody, newBlockArgs);
1686
    // Step 5. Replace the uses of the results of the old scf.forall with those
    //         of the new scf.forall.
1689 for (auto &&[oldResult, newResult] :
1690 llvm::zip(resultToReplace, newForallOp->getResults()))
1691 rewriter.replaceAllUsesWith(oldResult, newResult);
1692
    // Step 6. Replace the uses of those results that either have no use or
    //         are not being modified within the loop with the corresponding
    //         OpOperand.
1696 for (OpResult oldResult : resultToDelete)
1697 rewriter.replaceAllUsesWith(oldResult,
1698 forallOp.getTiedOpOperand(oldResult)->get());
1699 return success();
1700 }
1701};
1702
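/// Fold away scf.forall dimensions that perform zero or one iteration, unless
/// the loop is mapped to processing units. A sketch of the single-iteration
/// case:
///   scf.forall (%i, %j) in (1, %n) { ... uses of %i ... }
/// becomes
///   scf.forall (%j) in (%n) { ... uses of the lower bound 0 ... }
/// A loop with a zero-iteration dimension is replaced by its shared_outs, and
/// a loop where every dimension performs a single iteration is inlined.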
1703struct ForallOpSingleOrZeroIterationDimsFolder
1704 : public OpRewritePattern<ForallOp> {
1705 using OpRewritePattern<ForallOp>::OpRewritePattern;
1706
1707 LogicalResult matchAndRewrite(ForallOp op,
1708 PatternRewriter &rewriter) const override {
1709 // Do not fold dimensions if they are mapped to processing units.
1710 if (op.getMapping().has_value() && !op.getMapping()->empty())
1711 return failure();
1712 Location loc = op.getLoc();
1713
1714 // Compute new loop bounds that omit all single-iteration loop dimensions.
1715 SmallVector<OpFoldResult> newMixedLowerBounds, newMixedUpperBounds,
1716 newMixedSteps;
1717 IRMapping mapping;
1718 for (auto [lb, ub, step, iv] :
1719 llvm::zip(op.getMixedLowerBound(), op.getMixedUpperBound(),
1720 op.getMixedStep(), op.getInductionVars())) {
1721 auto numIterations = constantTripCount(lb, ub, step);
1722 if (numIterations.has_value()) {
1723 // Remove the loop if it performs zero iterations.
1724 if (*numIterations == 0) {
1725 rewriter.replaceOp(op, op.getOutputs());
1726 return success();
1727 }
1728 // Replace the loop induction variable by the lower bound if the loop
1729 // performs a single iteration. Otherwise, copy the loop bounds.
1730 if (*numIterations == 1) {
1731 mapping.map(iv, getValueOrCreateConstantIndexOp(rewriter, loc, lb));
1732 continue;
1733 }
1734 }
1735 newMixedLowerBounds.push_back(lb);
1736 newMixedUpperBounds.push_back(ub);
1737 newMixedSteps.push_back(step);
1738 }
1739
1740 // All of the loop dimensions perform a single iteration. Inline loop body.
1741 if (newMixedLowerBounds.empty()) {
1742 promote(rewriter, op);
1743 return success();
1744 }
1745
1746 // Exit if none of the loop dimensions perform a single iteration.
1747 if (newMixedLowerBounds.size() == static_cast<unsigned>(op.getRank())) {
1748 return rewriter.notifyMatchFailure(
1749 op, "no dimensions have 0 or 1 iterations");
1750 }
1751
1752 // Replace the loop by a lower-dimensional loop.
1753 ForallOp newOp;
1754 newOp = rewriter.create<ForallOp>(loc, newMixedLowerBounds,
1755 newMixedUpperBounds, newMixedSteps,
1756 op.getOutputs(), std::nullopt, nullptr);
1757 newOp.getBodyRegion().getBlocks().clear();
1758 // The new loop needs to keep all attributes from the old one, except for
1759 // "operandSegmentSizes" and static loop bound attributes which capture
1760 // the outdated information of the old iteration domain.
1761 SmallVector<StringAttr> elidedAttrs{newOp.getOperandSegmentSizesAttrName(),
1762 newOp.getStaticLowerBoundAttrName(),
1763 newOp.getStaticUpperBoundAttrName(),
1764 newOp.getStaticStepAttrName()};
1765 for (const auto &namedAttr : op->getAttrs()) {
1766 if (llvm::is_contained(elidedAttrs, namedAttr.getName()))
1767 continue;
1768 rewriter.modifyOpInPlace(newOp, [&]() {
1769 newOp->setAttr(namedAttr.getName(), namedAttr.getValue());
1770 });
1771 }
1772 rewriter.cloneRegionBefore(op.getRegion(), newOp.getRegion(),
1773 newOp.getRegion().begin(), mapping);
1774 rewriter.replaceOp(op, newOp.getResults());
1775 return success();
1776 }
1777};
1778
1779/// Replace all induction vars with a single trip count with their lower bound.
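/// For example (illustrative), for a dimension with lb = 4, ub = 5, step = 1,
/// all uses of its induction variable are replaced by the constant 4.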
1780struct ForallOpReplaceConstantInductionVar : public OpRewritePattern<ForallOp> {
1781 using OpRewritePattern<ForallOp>::OpRewritePattern;
1782
1783 LogicalResult matchAndRewrite(ForallOp op,
1784 PatternRewriter &rewriter) const override {
1785 Location loc = op.getLoc();
1786 bool changed = false;
1787 for (auto [lb, ub, step, iv] :
1788 llvm::zip(op.getMixedLowerBound(), op.getMixedUpperBound(),
1789 op.getMixedStep(), op.getInductionVars())) {
1790 if (iv.hasNUses(0))
1791 continue;
1792 auto numIterations = constantTripCount(lb, ub, step);
1793 if (!numIterations.has_value() || numIterations.value() != 1) {
1794 continue;
1795 }
1796 rewriter.replaceAllUsesWith(
1797 iv, getValueOrCreateConstantIndexOp(rewriter, loc, lb));
1798 changed = true;
1799 }
    return success(changed);
1801 }
1802};
1803
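/// Fold a tensor.cast on a shared_outs operand into the scf.forall when the
/// cast source carries more static information, casting the results back to
/// the original types afterwards. A rough sketch:
///   %c = tensor.cast %t : tensor<4xf32> to tensor<?xf32>
///   %r = scf.forall ... shared_outs(%o = %c) -> (tensor<?xf32>) { ... }
/// becomes
///   %r0 = scf.forall ... shared_outs(%o = %t) -> (tensor<4xf32>) { ... }
///   %r = tensor.cast %r0 : tensor<4xf32> to tensor<?xf32>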
1804struct FoldTensorCastOfOutputIntoForallOp
1805 : public OpRewritePattern<scf::ForallOp> {
1806 using OpRewritePattern<scf::ForallOp>::OpRewritePattern;
1807
1808 struct TypeCast {
1809 Type srcType;
1810 Type dstType;
1811 };
1812
1813 LogicalResult matchAndRewrite(scf::ForallOp forallOp,
1814 PatternRewriter &rewriter) const final {
1815 llvm::SmallMapVector<unsigned, TypeCast, 2> tensorCastProducers;
1816 llvm::SmallVector<Value> newOutputTensors = forallOp.getOutputs();
1817 for (auto en : llvm::enumerate(newOutputTensors)) {
1818 auto castOp = en.value().getDefiningOp<tensor::CastOp>();
1819 if (!castOp)
1820 continue;
1821
      // Only casts that preserve static information, i.e. that will make the
      // loop result type "more" static than before, will be folded.
1824 if (!tensor::preservesStaticInformation(castOp.getDest().getType(),
1825 castOp.getSource().getType())) {
1826 continue;
1827 }
1828
1829 tensorCastProducers[en.index()] =
1830 TypeCast{castOp.getSource().getType(), castOp.getType()};
1831 newOutputTensors[en.index()] = castOp.getSource();
1832 }
1833
1834 if (tensorCastProducers.empty())
1835 return failure();
1836
1837 // Create new loop.
1838 Location loc = forallOp.getLoc();
1839 auto newForallOp = rewriter.create<ForallOp>(
1840 loc, forallOp.getMixedLowerBound(), forallOp.getMixedUpperBound(),
1841 forallOp.getMixedStep(), newOutputTensors, forallOp.getMapping(),
1842 [&](OpBuilder nestedBuilder, Location nestedLoc, ValueRange bbArgs) {
1843 auto castBlockArgs =
1844 llvm::to_vector(bbArgs.take_back(forallOp->getNumResults()));
1845 for (auto [index, cast] : tensorCastProducers) {
1846 Value &oldTypeBBArg = castBlockArgs[index];
1847 oldTypeBBArg = nestedBuilder.create<tensor::CastOp>(
1848 nestedLoc, cast.dstType, oldTypeBBArg);
1849 }
1850
1851 // Move old body into new parallel loop.
1852 SmallVector<Value> ivsBlockArgs =
1853 llvm::to_vector(bbArgs.take_front(forallOp.getRank()));
1854 ivsBlockArgs.append(castBlockArgs);
1855 rewriter.mergeBlocks(forallOp.getBody(),
1856 bbArgs.front().getParentBlock(), ivsBlockArgs);
1857 });
1858
    // After `mergeBlocks` happened, the destinations in the terminator were
    // mapped to the tensor.cast old-typed results of the output bbArgs. The
    // destinations have to be updated to point to the output bbArgs directly.
1862 auto terminator = newForallOp.getTerminator();
1863 for (auto [yieldingOp, outputBlockArg] : llvm::zip(
1864 terminator.getYieldingOps(), newForallOp.getRegionIterArgs())) {
1865 auto insertSliceOp = cast<tensor::ParallelInsertSliceOp>(yieldingOp);
1866 insertSliceOp.getDestMutable().assign(outputBlockArg);
1867 }
1868
1869 // Cast results back to the original types.
1870 rewriter.setInsertionPointAfter(newForallOp);
1871 SmallVector<Value> castResults = newForallOp.getResults();
1872 for (auto &item : tensorCastProducers) {
1873 Value &oldTypeResult = castResults[item.first];
1874 oldTypeResult = rewriter.create<tensor::CastOp>(loc, item.second.dstType,
1875 oldTypeResult);
1876 }
1877 rewriter.replaceOp(forallOp, castResults);
1878 return success();
1879 }
1880};
1881
1882} // namespace
1883
1884void ForallOp::getCanonicalizationPatterns(RewritePatternSet &results,
1885 MLIRContext *context) {
1886 results.add<DimOfForallOp, FoldTensorCastOfOutputIntoForallOp,
1887 ForallOpControlOperandsFolder, ForallOpIterArgsFolder,
1888 ForallOpSingleOrZeroIterationDimsFolder,
1889 ForallOpReplaceConstantInductionVar>(context);
1890}
1891
/// Populate `regions` with the successor regions of this op, i.e. the regions
/// that may be selected during the flow of control: the loop body and the
/// parent operation itself (the body may not be entered at all).
1897void ForallOp::getSuccessorRegions(RegionBranchPoint point,
1898 SmallVectorImpl<RegionSuccessor> &regions) {
  // Both the operation itself and the region may be branching into the body
  // or back into the operation itself. It is possible for the loop not to
  // enter the body.
1902 regions.push_back(RegionSuccessor(&getRegion()));
1903 regions.push_back(RegionSuccessor());
1904}
1905
1906//===----------------------------------------------------------------------===//
1907// InParallelOp
1908//===----------------------------------------------------------------------===//
1909
// Build an InParallelOp with an empty body.
1911void InParallelOp::build(OpBuilder &b, OperationState &result) {
1912 OpBuilder::InsertionGuard g(b);
1913 Region *bodyRegion = result.addRegion();
1914 b.createBlock(bodyRegion);
1915}
1916
1917LogicalResult InParallelOp::verify() {
1918 scf::ForallOp forallOp =
1919 dyn_cast<scf::ForallOp>(getOperation()->getParentOp());
1920 if (!forallOp)
1921 return this->emitOpError("expected forall op parent");
1922
1923 // TODO: InParallelOpInterface.
1924 for (Operation &op : getRegion().front().getOperations()) {
1925 if (!isa<tensor::ParallelInsertSliceOp>(op)) {
1926 return this->emitOpError("expected only ")
1927 << tensor::ParallelInsertSliceOp::getOperationName() << " ops";
1928 }
1929
1930 // Verify that inserts are into out block arguments.
1931 Value dest = cast<tensor::ParallelInsertSliceOp>(op).getDest();
1932 ArrayRef<BlockArgument> regionOutArgs = forallOp.getRegionOutArgs();
1933 if (!llvm::is_contained(regionOutArgs, dest))
1934 return op.emitOpError("may only insert into an output block argument");
1935 }
1936 return success();
1937}
1938
1939void InParallelOp::print(OpAsmPrinter &p) {
1940 p << " ";
1941 p.printRegion(getRegion(),
1942 /*printEntryBlockArgs=*/false,
1943 /*printBlockTerminators=*/false);
1944 p.printOptionalAttrDict(getOperation()->getAttrs());
1945}
1946
1947ParseResult InParallelOp::parse(OpAsmParser &parser, OperationState &result) {
1948 auto &builder = parser.getBuilder();
1949
1950 SmallVector<OpAsmParser::Argument, 8> regionOperands;
1951 std::unique_ptr<Region> region = std::make_unique<Region>();
1952 if (parser.parseRegion(*region, regionOperands))
1953 return failure();
1954
1955 if (region->empty())
1956 OpBuilder(builder.getContext()).createBlock(region.get());
1957 result.addRegion(std::move(region));
1958
1959 // Parse the optional attribute list.
1960 if (parser.parseOptionalAttrDict(result.attributes))
1961 return failure();
1962 return success();
1963}
1964
1965OpResult InParallelOp::getParentResult(int64_t idx) {
1966 return getOperation()->getParentOp()->getResult(idx);
1967}
1968
1969SmallVector<BlockArgument> InParallelOp::getDests() {
1970 return llvm::to_vector<4>(
1971 llvm::map_range(getYieldingOps(), [](Operation &op) {
1972 // Add new ops here as needed.
1973 auto insertSliceOp = cast<tensor::ParallelInsertSliceOp>(&op);
1974 return llvm::cast<BlockArgument>(insertSliceOp.getDest());
1975 }));
1976}
1977
1978llvm::iterator_range<Block::iterator> InParallelOp::getYieldingOps() {
1979 return getRegion().front().getOperations();
1980}
1981
1982//===----------------------------------------------------------------------===//
1983// IfOp
1984//===----------------------------------------------------------------------===//
1985
1986bool mlir::scf::insideMutuallyExclusiveBranches(Operation *a, Operation *b) {
1987 assert(a && "expected non-empty operation");
1988 assert(b && "expected non-empty operation");
1989
1990 IfOp ifOp = a->getParentOfType<IfOp>();
1991 while (ifOp) {
1992 // Check if b is inside ifOp. (We already know that a is.)
1993 if (ifOp->isProperAncestor(b))
1994 // b is contained in ifOp. a and b are in mutually exclusive branches if
1995 // they are in different blocks of ifOp.
1996 return static_cast<bool>(ifOp.thenBlock()->findAncestorOpInBlock(*a)) !=
1997 static_cast<bool>(ifOp.thenBlock()->findAncestorOpInBlock(*b));
1998 // Check next enclosing IfOp.
1999 ifOp = ifOp->getParentOfType<IfOp>();
2000 }
2001
2002 // Could not find a common IfOp among a's and b's ancestors.
2003 return false;
2004}
2005
2006LogicalResult
2007IfOp::inferReturnTypes(MLIRContext *ctx, std::optional<Location> loc,
2008 IfOp::Adaptor adaptor,
2009 SmallVectorImpl<Type> &inferredReturnTypes) {
2010 if (adaptor.getRegions().empty())
2011 return failure();
2012 Region *r = &adaptor.getThenRegion();
2013 if (r->empty())
2014 return failure();
2015 Block &b = r->front();
2016 if (b.empty())
2017 return failure();
2018 auto yieldOp = llvm::dyn_cast<YieldOp>(b.back());
2019 if (!yieldOp)
2020 return failure();
2021 TypeRange types = yieldOp.getOperandTypes();
2022 llvm::append_range(inferredReturnTypes, types);
2023 return success();
2024}
2025
2026void IfOp::build(OpBuilder &builder, OperationState &result,
2027 TypeRange resultTypes, Value cond) {
2028 return build(builder, result, resultTypes, cond, /*addThenBlock=*/false,
2029 /*addElseBlock=*/false);
2030}
2031
2032void IfOp::build(OpBuilder &builder, OperationState &result,
2033 TypeRange resultTypes, Value cond, bool addThenBlock,
2034 bool addElseBlock) {
2035 assert((!addElseBlock || addThenBlock) &&
2036 "must not create else block w/o then block");
2037 result.addTypes(resultTypes);
2038 result.addOperands(cond);
2039
2040 // Add regions and blocks.
2041 OpBuilder::InsertionGuard guard(builder);
2042 Region *thenRegion = result.addRegion();
2043 if (addThenBlock)
2044 builder.createBlock(thenRegion);
2045 Region *elseRegion = result.addRegion();
2046 if (addElseBlock)
2047 builder.createBlock(elseRegion);
2048}
2049
2050void IfOp::build(OpBuilder &builder, OperationState &result, Value cond,
2051 bool withElseRegion) {
2052 build(builder, result, TypeRange{}, cond, withElseRegion);
2053}
2054
2055void IfOp::build(OpBuilder &builder, OperationState &result,
2056 TypeRange resultTypes, Value cond, bool withElseRegion) {
2057 result.addTypes(resultTypes);
2058 result.addOperands(cond);
2059
2060 // Build then region.
2061 OpBuilder::InsertionGuard guard(builder);
2062 Region *thenRegion = result.addRegion();
2063 builder.createBlock(thenRegion);
2064 if (resultTypes.empty())
2065 IfOp::ensureTerminator(*thenRegion, builder, result.location);
2066
2067 // Build else region.
2068 Region *elseRegion = result.addRegion();
2069 if (withElseRegion) {
2070 builder.createBlock(elseRegion);
2071 if (resultTypes.empty())
2072 IfOp::ensureTerminator(*elseRegion, builder, result.location);
2073 }
2074}
2075
2076void IfOp::build(OpBuilder &builder, OperationState &result, Value cond,
2077 function_ref<void(OpBuilder &, Location)> thenBuilder,
2078 function_ref<void(OpBuilder &, Location)> elseBuilder) {
2079 assert(thenBuilder && "the builder callback for 'then' must be present");
2080 result.addOperands(cond);
2081
2082 // Build then region.
2083 OpBuilder::InsertionGuard guard(builder);
2084 Region *thenRegion = result.addRegion();
2085 builder.createBlock(thenRegion);
2086 thenBuilder(builder, result.location);
2087
2088 // Build else region.
2089 Region *elseRegion = result.addRegion();
2090 if (elseBuilder) {
2091 builder.createBlock(elseRegion);
2092 elseBuilder(builder, result.location);
2093 }
2094
2095 // Infer result types.
2096 SmallVector<Type> inferredReturnTypes;
2097 MLIRContext *ctx = builder.getContext();
2098 auto attrDict = DictionaryAttr::get(ctx, result.attributes);
2099 if (succeeded(inferReturnTypes(ctx, std::nullopt, result.operands, attrDict,
2100 /*properties=*/nullptr, result.regions,
2101 inferredReturnTypes))) {
2102 result.addTypes(inferredReturnTypes);
2103 }
2104}
2105
2106LogicalResult IfOp::verify() {
2107 if (getNumResults() != 0 && getElseRegion().empty())
2108 return emitOpError("must have an else block if defining values");
2109 return success();
2110}
2111
2112ParseResult IfOp::parse(OpAsmParser &parser, OperationState &result) {
2113 // Create the regions for 'then'.
2114 result.regions.reserve(2);
2115 Region *thenRegion = result.addRegion();
2116 Region *elseRegion = result.addRegion();
2117
2118 auto &builder = parser.getBuilder();
2119 OpAsmParser::UnresolvedOperand cond;
2120 Type i1Type = builder.getIntegerType(1);
2121 if (parser.parseOperand(cond) ||
2122 parser.resolveOperand(cond, i1Type, result.operands))
2123 return failure();
2124 // Parse optional results type list.
2125 if (parser.parseOptionalArrowTypeList(result.types))
2126 return failure();
2127 // Parse the 'then' region.
2128 if (parser.parseRegion(*thenRegion, /*arguments=*/{}, /*argTypes=*/{}))
2129 return failure();
2130 IfOp::ensureTerminator(*thenRegion, parser.getBuilder(), result.location);
2131
2132 // If we find an 'else' keyword then parse the 'else' region.
2133 if (!parser.parseOptionalKeyword("else")) {
2134 if (parser.parseRegion(*elseRegion, /*arguments=*/{}, /*argTypes=*/{}))
2135 return failure();
2136 IfOp::ensureTerminator(*elseRegion, parser.getBuilder(), result.location);
2137 }
2138
2139 // Parse the optional attribute list.
2140 if (parser.parseOptionalAttrDict(result.attributes))
2141 return failure();
2142 return success();
2143}
2144
2145void IfOp::print(OpAsmPrinter &p) {
2146 bool printBlockTerminators = false;
2147
2148 p << " " << getCondition();
2149 if (!getResults().empty()) {
2150 p << " -> (" << getResultTypes() << ")";
2151 // Print yield explicitly if the op defines values.
2152 printBlockTerminators = true;
2153 }
2154 p << ' ';
2155 p.printRegion(getThenRegion(),
2156 /*printEntryBlockArgs=*/false,
2157 /*printBlockTerminators=*/printBlockTerminators);
2158
  // Print the 'else' region if it exists and has a block.
2160 auto &elseRegion = getElseRegion();
2161 if (!elseRegion.empty()) {
2162 p << " else ";
2163 p.printRegion(elseRegion,
2164 /*printEntryBlockArgs=*/false,
2165 /*printBlockTerminators=*/printBlockTerminators);
2166 }
2167
2168 p.printOptionalAttrDict((*this)->getAttrs());
2169}
2170
2171void IfOp::getSuccessorRegions(RegionBranchPoint point,
2172 SmallVectorImpl<RegionSuccessor> &regions) {
2173 // The `then` and the `else` region branch back to the parent operation.
2174 if (!point.isParent()) {
2175 regions.push_back(RegionSuccessor(getResults()));
2176 return;
2177 }
2178
2179 regions.push_back(RegionSuccessor(&getThenRegion()));
2180
2181 // Don't consider the else region if it is empty.
2182 Region *elseRegion = &this->getElseRegion();
2183 if (elseRegion->empty())
2184 regions.push_back(RegionSuccessor());
2185 else
2186 regions.push_back(RegionSuccessor(elseRegion));
2187}
2188
2189void IfOp::getEntrySuccessorRegions(ArrayRef<Attribute> operands,
2190 SmallVectorImpl<RegionSuccessor> &regions) {
2191 FoldAdaptor adaptor(operands, *this);
2192 auto boolAttr = dyn_cast_or_null<BoolAttr>(adaptor.getCondition());
2193 if (!boolAttr || boolAttr.getValue())
2194 regions.emplace_back(&getThenRegion());
2195
2196 // If the else region is empty, execution continues after the parent op.
2197 if (!boolAttr || !boolAttr.getValue()) {
2198 if (!getElseRegion().empty())
2199 regions.emplace_back(&getElseRegion());
2200 else
2201 regions.emplace_back(getResults());
2202 }
2203}
2204
2205LogicalResult IfOp::fold(FoldAdaptor adaptor,
2206 SmallVectorImpl<OpFoldResult> &results) {
2207 // if (!c) then A() else B() -> if c then B() else A()
2208 if (getElseRegion().empty())
2209 return failure();
2210
2211 arith::XOrIOp xorStmt = getCondition().getDefiningOp<arith::XOrIOp>();
2212 if (!xorStmt)
2213 return failure();
2214
2215 if (!matchPattern(xorStmt.getRhs(), m_One()))
2216 return failure();
2217
2218 getConditionMutable().assign(xorStmt.getLhs());
2219 Block *thenBlock = &getThenRegion().front();
  // It would be nicer to use iplist::swap, but that has no implemented
  // callbacks. See: https://llvm.org/doxygen/ilist_8h_source.html#l00224
2222 getThenRegion().getBlocks().splice(getThenRegion().getBlocks().begin(),
2223 getElseRegion().getBlocks());
2224 getElseRegion().getBlocks().splice(getElseRegion().getBlocks().begin(),
2225 getThenRegion().getBlocks(), thenBlock);
2226 return success();
2227}
2228
2229void IfOp::getRegionInvocationBounds(
2230 ArrayRef<Attribute> operands,
2231 SmallVectorImpl<InvocationBounds> &invocationBounds) {
2232 if (auto cond = llvm::dyn_cast_or_null<BoolAttr>(operands[0])) {
2233 // If the condition is known, then one region is known to be executed once
2234 // and the other zero times.
2235 invocationBounds.emplace_back(0, cond.getValue() ? 1 : 0);
2236 invocationBounds.emplace_back(0, cond.getValue() ? 0 : 1);
2237 } else {
2238 // Non-constant condition. Each region may be executed 0 or 1 times.
2239 invocationBounds.assign(2, {0, 1});
2240 }
2241}
2242
2243namespace {
2244// Pattern to remove unused IfOp results.
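// For example (illustrative), if only %r#0 of
//   %r:2 = scf.if %cond -> (i32, i32) { ... } else { ... }
// is used, the op is rebuilt to yield and return just that first value.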
2245struct RemoveUnusedResults : public OpRewritePattern<IfOp> {
2246 using OpRewritePattern<IfOp>::OpRewritePattern;
2247
2248 void transferBody(Block *source, Block *dest, ArrayRef<OpResult> usedResults,
2249 PatternRewriter &rewriter) const {
2250 // Move all operations to the destination block.
2251 rewriter.mergeBlocks(source, dest);
2252 // Replace the yield op by one that returns only the used values.
2253 auto yieldOp = cast<scf::YieldOp>(dest->getTerminator());
2254 SmallVector<Value, 4> usedOperands;
    llvm::transform(usedResults, std::back_inserter(usedOperands),
                    [&](OpResult result) {
                      return yieldOp.getOperand(result.getResultNumber());
                    });
2259 rewriter.modifyOpInPlace(yieldOp,
2260 [&]() { yieldOp->setOperands(usedOperands); });
2261 }
2262
2263 LogicalResult matchAndRewrite(IfOp op,
2264 PatternRewriter &rewriter) const override {
2265 // Compute the list of used results.
2266 SmallVector<OpResult, 4> usedResults;
    llvm::copy_if(op.getResults(), std::back_inserter(usedResults),
2268 [](OpResult result) { return !result.use_empty(); });
2269
2270 // Replace the operation if only a subset of its results have uses.
2271 if (usedResults.size() == op.getNumResults())
2272 return failure();
2273
2274 // Compute the result types of the replacement operation.
2275 SmallVector<Type, 4> newTypes;
    llvm::transform(usedResults, std::back_inserter(newTypes),
                    [](OpResult result) { return result.getType(); });
2278
2279 // Create a replacement operation with empty then and else regions.
2280 auto newOp =
2281 rewriter.create<IfOp>(op.getLoc(), newTypes, op.getCondition());
2282 rewriter.createBlock(&newOp.getThenRegion());
2283 rewriter.createBlock(&newOp.getElseRegion());
2284
2285 // Move the bodies and replace the terminators (note there is a then and
2286 // an else region since the operation returns results).
    transferBody(op.getBody(0), newOp.getBody(0), usedResults, rewriter);
    transferBody(op.getBody(1), newOp.getBody(1), usedResults, rewriter);
2289
2290 // Replace the operation by the new one.
2291 SmallVector<Value, 4> repResults(op.getNumResults());
    for (const auto &en : llvm::enumerate(usedResults))
2293 repResults[en.value().getResultNumber()] = newOp.getResult(en.index());
2294 rewriter.replaceOp(op, repResults);
2295 return success();
2296 }
2297};
2298
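/// Replace an scf.if whose condition is a constant by the contents of its
/// taken region, e.g. (illustrative) `scf.if %true { A() }` becomes `A()`.
/// If the condition is false and there is no else region, the op is erased.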
2299struct RemoveStaticCondition : public OpRewritePattern<IfOp> {
2300 using OpRewritePattern<IfOp>::OpRewritePattern;
2301
2302 LogicalResult matchAndRewrite(IfOp op,
2303 PatternRewriter &rewriter) const override {
2304 BoolAttr condition;
    if (!matchPattern(op.getCondition(), m_Constant(&condition)))
2306 return failure();
2307
2308 if (condition.getValue())
2309 replaceOpWithRegion(rewriter, op, op.getThenRegion());
2310 else if (!op.getElseRegion().empty())
2311 replaceOpWithRegion(rewriter, op, op.getElseRegion());
2312 else
      rewriter.eraseOp(op);
2314
2315 return success();
2316 }
2317};
2318
2319/// Hoist any yielded results whose operands are defined outside
2320/// the if, to a select instruction.
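///
/// For example (illustrative), when %a and %b are defined outside the scf.if:
///   %r = scf.if %cond -> (i32) {
///     scf.yield %a : i32
///   } else {
///     scf.yield %b : i32
///   }
/// becomes
///   %r = arith.select %cond, %a, %b : i32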
2321struct ConvertTrivialIfToSelect : public OpRewritePattern<IfOp> {
2322 using OpRewritePattern<IfOp>::OpRewritePattern;
2323
2324 LogicalResult matchAndRewrite(IfOp op,
2325 PatternRewriter &rewriter) const override {
2326 if (op->getNumResults() == 0)
2327 return failure();
2328
2329 auto cond = op.getCondition();
2330 auto thenYieldArgs = op.thenYield().getOperands();
2331 auto elseYieldArgs = op.elseYield().getOperands();
2332
2333 SmallVector<Type> nonHoistable;
2334 for (auto [trueVal, falseVal] : llvm::zip(thenYieldArgs, elseYieldArgs)) {
2335 if (&op.getThenRegion() == trueVal.getParentRegion() ||
2336 &op.getElseRegion() == falseVal.getParentRegion())
2337 nonHoistable.push_back(trueVal.getType());
2338 }
2339 // Early exit if there aren't any yielded values we can
2340 // hoist outside the if.
2341 if (nonHoistable.size() == op->getNumResults())
2342 return failure();
2343
2344 IfOp replacement = rewriter.create<IfOp>(op.getLoc(), nonHoistable, cond,
2345 /*withElseRegion=*/false);
2346 if (replacement.thenBlock())
      rewriter.eraseBlock(replacement.thenBlock());
2348 replacement.getThenRegion().takeBody(op.getThenRegion());
2349 replacement.getElseRegion().takeBody(op.getElseRegion());
2350
2351 SmallVector<Value> results(op->getNumResults());
2352 assert(thenYieldArgs.size() == results.size());
2353 assert(elseYieldArgs.size() == results.size());
2354
2355 SmallVector<Value> trueYields;
2356 SmallVector<Value> falseYields;
2357 rewriter.setInsertionPoint(replacement);
2358 for (const auto &it :
2359 llvm::enumerate(llvm::zip(thenYieldArgs, elseYieldArgs))) {
2360 Value trueVal = std::get<0>(it.value());
2361 Value falseVal = std::get<1>(it.value());
2362 if (&replacement.getThenRegion() == trueVal.getParentRegion() ||
2363 &replacement.getElseRegion() == falseVal.getParentRegion()) {
2364 results[it.index()] = replacement.getResult(trueYields.size());
2365 trueYields.push_back(trueVal);
2366 falseYields.push_back(falseVal);
2367 } else if (trueVal == falseVal)
2368 results[it.index()] = trueVal;
2369 else
2370 results[it.index()] = rewriter.create<arith::SelectOp>(
2371 op.getLoc(), cond, trueVal, falseVal);
2372 }
2373
2374 rewriter.setInsertionPointToEnd(replacement.thenBlock());
2375 rewriter.replaceOpWithNewOp<YieldOp>(replacement.thenYield(), trueYields);
2376
2377 rewriter.setInsertionPointToEnd(replacement.elseBlock());
2378 rewriter.replaceOpWithNewOp<YieldOp>(replacement.elseYield(), falseYields);
2379
2380 rewriter.replaceOp(op, results);
2381 return success();
2382 }
2383};
2384
2385/// Allow the true region of an if to assume the condition is true
2386/// and vice versa. For example:
2387///
2388/// scf.if %cmp {
2389/// print(%cmp)
2390/// }
2391///
2392/// becomes
2393///
2394/// scf.if %cmp {
2395/// print(true)
2396/// }
2397///
2398struct ConditionPropagation : public OpRewritePattern<IfOp> {
2399 using OpRewritePattern<IfOp>::OpRewritePattern;
2400
2401 LogicalResult matchAndRewrite(IfOp op,
2402 PatternRewriter &rewriter) const override {
2403 // Early exit if the condition is constant since replacing a constant
2404 // in the body with another constant isn't a simplification.
2405 if (matchPattern(op.getCondition(), m_Constant()))
2406 return failure();
2407
2408 bool changed = false;
2409 mlir::Type i1Ty = rewriter.getI1Type();
2410
2411 // These variables serve to prevent creating duplicate constants
2412 // and hold constant true or false values.
2413 Value constantTrue = nullptr;
2414 Value constantFalse = nullptr;
2415
2416 for (OpOperand &use :
2417 llvm::make_early_inc_range(op.getCondition().getUses())) {
2418 if (op.getThenRegion().isAncestor(use.getOwner()->getParentRegion())) {
2419 changed = true;
2420
2421 if (!constantTrue)
2422 constantTrue = rewriter.create<arith::ConstantOp>(
2423 op.getLoc(), i1Ty, rewriter.getIntegerAttr(i1Ty, 1));
2424
2425 rewriter.modifyOpInPlace(use.getOwner(),
2426 [&]() { use.set(constantTrue); });
2427 } else if (op.getElseRegion().isAncestor(
2428 use.getOwner()->getParentRegion())) {
2429 changed = true;
2430
2431 if (!constantFalse)
2432 constantFalse = rewriter.create<arith::ConstantOp>(
2433 op.getLoc(), i1Ty, rewriter.getIntegerAttr(i1Ty, 0));
2434
2435 rewriter.modifyOpInPlace(use.getOwner(),
2436 [&]() { use.set(constantFalse); });
2437 }
2438 }
2439
    return success(changed);
2441 }
2442};
2443
2444/// Remove any statements from an if that are equivalent to the condition
2445/// or its negation. For example:
2446///
2447/// %res:2 = scf.if %cmp {
2448/// yield something(), true
2449/// } else {
2450/// yield something2(), false
2451/// }
2452/// print(%res#1)
2453///
2454/// becomes
2455/// %res = scf.if %cmp {
2456/// yield something()
2457/// } else {
2458/// yield something2()
2459/// }
2460/// print(%cmp)
2461///
2462/// Additionally if both branches yield the same value, replace all uses
2463/// of the result with the yielded value.
2464///
2465/// %res:2 = scf.if %cmp {
2466/// yield something(), %arg1
2467/// } else {
2468/// yield something2(), %arg1
2469/// }
2470/// print(%res#1)
2471///
2472/// becomes
2473/// %res = scf.if %cmp {
2474/// yield something()
2475/// } else {
2476/// yield something2()
2477/// }
2478/// print(%arg1)
2479///
2480struct ReplaceIfYieldWithConditionOrValue : public OpRewritePattern<IfOp> {
2481 using OpRewritePattern<IfOp>::OpRewritePattern;
2482
2483 LogicalResult matchAndRewrite(IfOp op,
2484 PatternRewriter &rewriter) const override {
2485 // Early exit if there are no results that could be replaced.
2486 if (op.getNumResults() == 0)
2487 return failure();
2488
2489 auto trueYield =
2490 cast<scf::YieldOp>(op.getThenRegion().back().getTerminator());
2491 auto falseYield =
2492 cast<scf::YieldOp>(op.getElseRegion().back().getTerminator());
2493
2494 rewriter.setInsertionPoint(op->getBlock(),
2495 op.getOperation()->getIterator());
2496 bool changed = false;
2497 Type i1Ty = rewriter.getI1Type();
2498 for (auto [trueResult, falseResult, opResult] :
2499 llvm::zip(trueYield.getResults(), falseYield.getResults(),
2500 op.getResults())) {
2501 if (trueResult == falseResult) {
2502 if (!opResult.use_empty()) {
2503 opResult.replaceAllUsesWith(trueResult);
2504 changed = true;
2505 }
2506 continue;
2507 }
2508
2509 BoolAttr trueYield, falseYield;
2510 if (!matchPattern(trueResult, m_Constant(&trueYield)) ||
2511 !matchPattern(falseResult, m_Constant(&falseYield)))
2512 continue;
2513
2514 bool trueVal = trueYield.getValue();
2515 bool falseVal = falseYield.getValue();
2516 if (!trueVal && falseVal) {
2517 if (!opResult.use_empty()) {
2518 Dialect *constDialect = trueResult.getDefiningOp()->getDialect();
2519 Value notCond = rewriter.create<arith::XOrIOp>(
2520 op.getLoc(), op.getCondition(),
2521 constDialect
2522 ->materializeConstant(rewriter,
2523 rewriter.getIntegerAttr(i1Ty, 1), i1Ty,
2524 op.getLoc())
2525 ->getResult(0));
2526 opResult.replaceAllUsesWith(notCond);
2527 changed = true;
2528 }
2529 }
2530 if (trueVal && !falseVal) {
2531 if (!opResult.use_empty()) {
2532 opResult.replaceAllUsesWith(op.getCondition());
2533 changed = true;
2534 }
2535 }
2536 }
    return success(changed);
2538 }
2539};
2540
2541/// Merge any consecutive scf.if's with the same condition.
2542///
2543/// scf.if %cond {
2544/// firstCodeTrue();...
2545/// } else {
2546/// firstCodeFalse();...
2547/// }
2548/// %res = scf.if %cond {
2549/// secondCodeTrue();...
2550/// } else {
2551/// secondCodeFalse();...
2552/// }
2553///
2554/// becomes
2555/// %res = scf.if %cmp {
2556/// firstCodeTrue();...
2557/// secondCodeTrue();...
2558/// } else {
2559/// firstCodeFalse();...
2560/// secondCodeFalse();...
2561/// }
2562struct CombineIfs : public OpRewritePattern<IfOp> {
2563 using OpRewritePattern<IfOp>::OpRewritePattern;
2564
2565 LogicalResult matchAndRewrite(IfOp nextIf,
2566 PatternRewriter &rewriter) const override {
2567 Block *parent = nextIf->getBlock();
2568 if (nextIf == &parent->front())
2569 return failure();
2570
2571 auto prevIf = dyn_cast<IfOp>(nextIf->getPrevNode());
2572 if (!prevIf)
2573 return failure();
2574
2575 // Determine the logical then/else blocks when prevIf's
2576 // condition is used. Null means the block does not exist
2577 // in that case (e.g. empty else). If neither of these
    // is set, the two conditions cannot be compared.
2579 Block *nextThen = nullptr;
2580 Block *nextElse = nullptr;
2581 if (nextIf.getCondition() == prevIf.getCondition()) {
2582 nextThen = nextIf.thenBlock();
2583 if (!nextIf.getElseRegion().empty())
2584 nextElse = nextIf.elseBlock();
2585 }
2586 if (arith::XOrIOp notv =
2587 nextIf.getCondition().getDefiningOp<arith::XOrIOp>()) {
2588 if (notv.getLhs() == prevIf.getCondition() &&
2589 matchPattern(notv.getRhs(), m_One())) {
2590 nextElse = nextIf.thenBlock();
2591 if (!nextIf.getElseRegion().empty())
2592 nextThen = nextIf.elseBlock();
2593 }
2594 }
2595 if (arith::XOrIOp notv =
2596 prevIf.getCondition().getDefiningOp<arith::XOrIOp>()) {
2597 if (notv.getLhs() == nextIf.getCondition() &&
2598 matchPattern(notv.getRhs(), m_One())) {
2599 nextElse = nextIf.thenBlock();
2600 if (!nextIf.getElseRegion().empty())
2601 nextThen = nextIf.elseBlock();
2602 }
2603 }
2604
2605 if (!nextThen && !nextElse)
2606 return failure();
2607
2608 SmallVector<Value> prevElseYielded;
2609 if (!prevIf.getElseRegion().empty())
2610 prevElseYielded = prevIf.elseYield().getOperands();
2611 // Replace all uses of return values of op within nextIf with the
2612 // corresponding yields
2613 for (auto it : llvm::zip(prevIf.getResults(),
2614 prevIf.thenYield().getOperands(), prevElseYielded))
2615 for (OpOperand &use :
2616 llvm::make_early_inc_range(std::get<0>(it).getUses())) {
2617 if (nextThen && nextThen->getParent()->isAncestor(
2618 use.getOwner()->getParentRegion())) {
2619 rewriter.startOpModification(use.getOwner());
2620 use.set(std::get<1>(it));
2621 rewriter.finalizeOpModification(use.getOwner());
2622 } else if (nextElse && nextElse->getParent()->isAncestor(
2623 use.getOwner()->getParentRegion())) {
2624 rewriter.startOpModification(use.getOwner());
2625 use.set(std::get<2>(it));
2626 rewriter.finalizeOpModification(use.getOwner());
2627 }
2628 }
2629
2630 SmallVector<Type> mergedTypes(prevIf.getResultTypes());
2631 llvm::append_range(mergedTypes, nextIf.getResultTypes());
2632
2633 IfOp combinedIf = rewriter.create<IfOp>(
2634 nextIf.getLoc(), mergedTypes, prevIf.getCondition(), /*hasElse=*/false);
    rewriter.eraseBlock(&combinedIf.getThenRegion().back());
2636
2637 rewriter.inlineRegionBefore(prevIf.getThenRegion(),
2638 combinedIf.getThenRegion(),
2639 combinedIf.getThenRegion().begin());
2640
2641 if (nextThen) {
2642 YieldOp thenYield = combinedIf.thenYield();
2643 YieldOp thenYield2 = cast<YieldOp>(nextThen->getTerminator());
      rewriter.mergeBlocks(nextThen, combinedIf.thenBlock());
2645 rewriter.setInsertionPointToEnd(combinedIf.thenBlock());
2646
2647 SmallVector<Value> mergedYields(thenYield.getOperands());
2648 llvm::append_range(mergedYields, thenYield2.getOperands());
2649 rewriter.create<YieldOp>(thenYield2.getLoc(), mergedYields);
      rewriter.eraseOp(thenYield);
      rewriter.eraseOp(thenYield2);
2652 }
2653
2654 rewriter.inlineRegionBefore(prevIf.getElseRegion(),
2655 combinedIf.getElseRegion(),
2656 combinedIf.getElseRegion().begin());
2657
2658 if (nextElse) {
2659 if (combinedIf.getElseRegion().empty()) {
2660 rewriter.inlineRegionBefore(*nextElse->getParent(),
2661 combinedIf.getElseRegion(),
2662 combinedIf.getElseRegion().begin());
2663 } else {
2664 YieldOp elseYield = combinedIf.elseYield();
2665 YieldOp elseYield2 = cast<YieldOp>(nextElse->getTerminator());
        rewriter.mergeBlocks(nextElse, combinedIf.elseBlock());
2667
2668 rewriter.setInsertionPointToEnd(combinedIf.elseBlock());
2669
2670 SmallVector<Value> mergedElseYields(elseYield.getOperands());
2671 llvm::append_range(mergedElseYields, elseYield2.getOperands());
2672
2673 rewriter.create<YieldOp>(elseYield2.getLoc(), mergedElseYields);
        rewriter.eraseOp(elseYield);
        rewriter.eraseOp(elseYield2);
2676 }
2677 }
2678
2679 SmallVector<Value> prevValues;
2680 SmallVector<Value> nextValues;
2681 for (const auto &pair : llvm::enumerate(combinedIf.getResults())) {
2682 if (pair.index() < prevIf.getNumResults())
2683 prevValues.push_back(pair.value());
2684 else
2685 nextValues.push_back(pair.value());
2686 }
2687 rewriter.replaceOp(prevIf, prevValues);
2688 rewriter.replaceOp(nextIf, nextValues);
2689 return success();
2690 }
2691};
2692
2693/// Pattern to remove an empty else branch.
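/// For example (illustrative), `scf.if %c { ... } else { }` with no results
/// becomes `scf.if %c { ... }`.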
2694struct RemoveEmptyElseBranch : public OpRewritePattern<IfOp> {
2695 using OpRewritePattern<IfOp>::OpRewritePattern;
2696
2697 LogicalResult matchAndRewrite(IfOp ifOp,
2698 PatternRewriter &rewriter) const override {
2699 // Cannot remove else region when there are operation results.
2700 if (ifOp.getNumResults())
2701 return failure();
2702 Block *elseBlock = ifOp.elseBlock();
    if (!elseBlock || !llvm::hasSingleElement(*elseBlock))
2704 return failure();
2705 auto newIfOp = rewriter.cloneWithoutRegions(ifOp);
2706 rewriter.inlineRegionBefore(ifOp.getThenRegion(), newIfOp.getThenRegion(),
2707 newIfOp.getThenRegion().begin());
    rewriter.eraseOp(ifOp);
2709 return success();
2710 }
2711};
2712
2713/// Convert nested `if`s into `arith.andi` + single `if`.
2714///
2715/// scf.if %arg0 {
2716/// scf.if %arg1 {
2717/// ...
2718/// scf.yield
2719/// }
2720/// scf.yield
2721/// }
2722/// becomes
2723///
2724/// %0 = arith.andi %arg0, %arg1
2725/// scf.if %0 {
2726/// ...
2727/// scf.yield
2728/// }
2729struct CombineNestedIfs : public OpRewritePattern<IfOp> {
2730 using OpRewritePattern<IfOp>::OpRewritePattern;
2731
2732 LogicalResult matchAndRewrite(IfOp op,
2733 PatternRewriter &rewriter) const override {
2734 auto nestedOps = op.thenBlock()->without_terminator();
2735 // Nested `if` must be the only op in block.
2736 if (!llvm::hasSingleElement(nestedOps))
2737 return failure();
2738
    // If there is an else block, it can only contain a yield.
2740 if (op.elseBlock() && !llvm::hasSingleElement(*op.elseBlock()))
2741 return failure();
2742
2743 auto nestedIf = dyn_cast<IfOp>(*nestedOps.begin());
2744 if (!nestedIf)
2745 return failure();
2746
2747 if (nestedIf.elseBlock() && !llvm::hasSingleElement(*nestedIf.elseBlock()))
2748 return failure();
2749
2750 SmallVector<Value> thenYield(op.thenYield().getOperands());
2751 SmallVector<Value> elseYield;
2752 if (op.elseBlock())
2753 llvm::append_range(elseYield, op.elseYield().getOperands());
2754
2755 // A list of indices for which we should upgrade the value yielded
2756 // in the else to a select.
2757 SmallVector<unsigned> elseYieldsToUpgradeToSelect;
2758
2759 // If the outer scf.if yields a value produced by the inner scf.if,
2760 // only permit combining if the value yielded when the condition
2761 // is false in the outer scf.if is the same value yielded when the
2762 // inner scf.if condition is false.
2763 // Note that the array access to elseYield will not go out of bounds
2764 // since it must have the same length as thenYield, since they both
2765 // come from the same scf.if.
2766 for (const auto &tup : llvm::enumerate(thenYield)) {
2767 if (tup.value().getDefiningOp() == nestedIf) {
2768 auto nestedIdx = llvm::cast<OpResult>(tup.value()).getResultNumber();
2769 if (nestedIf.elseYield().getOperand(nestedIdx) !=
2770 elseYield[tup.index()]) {
2771 return failure();
2772 }
2773 // If the correctness test passes, we will yield
2774 // corresponding value from the inner scf.if
2775 thenYield[tup.index()] = nestedIf.thenYield().getOperand(nestedIdx);
2776 continue;
2777 }
2778
2779 // Otherwise, we need to ensure the else block of the combined
2780 // condition still returns the same value when the outer condition is
2781 // true and the inner condition is false. This can be accomplished if
2782 // the then value is defined outside the outer scf.if and we replace the
2783 // value with a select that considers just the outer condition. Since
2784 // the else region contains just the yield, its yielded value is
2785 // defined outside the scf.if, by definition.
2786
2787 // If the then value is defined within the scf.if, bail.
2788 if (tup.value().getParentRegion() == &op.getThenRegion()) {
2789 return failure();
2790 }
2791 elseYieldsToUpgradeToSelect.push_back(tup.index());
2792 }
2793
2794 Location loc = op.getLoc();
2795 Value newCondition = rewriter.create<arith::AndIOp>(
2796 loc, op.getCondition(), nestedIf.getCondition());
2797 auto newIf = rewriter.create<IfOp>(loc, op.getResultTypes(), newCondition);
2798 Block *newIfBlock = rewriter.createBlock(&newIf.getThenRegion());
2799
2800 SmallVector<Value> results;
2801 llvm::append_range(results, newIf.getResults());
2802 rewriter.setInsertionPoint(newIf);
2803
2804 for (auto idx : elseYieldsToUpgradeToSelect)
2805 results[idx] = rewriter.create<arith::SelectOp>(
2806 op.getLoc(), op.getCondition(), thenYield[idx], elseYield[idx]);
2807
    rewriter.mergeBlocks(nestedIf.thenBlock(), newIfBlock);
2809 rewriter.setInsertionPointToEnd(newIf.thenBlock());
2810 rewriter.replaceOpWithNewOp<YieldOp>(newIf.thenYield(), thenYield);
2811 if (!elseYield.empty()) {
2812 rewriter.createBlock(&newIf.getElseRegion());
2813 rewriter.setInsertionPointToEnd(newIf.elseBlock());
2814 rewriter.create<YieldOp>(loc, elseYield);
2815 }
2816 rewriter.replaceOp(op, results);
2817 return success();
2818 }
2819};
2820
2821} // namespace
2822
2823void IfOp::getCanonicalizationPatterns(RewritePatternSet &results,
2824 MLIRContext *context) {
2825 results.add<CombineIfs, CombineNestedIfs, ConditionPropagation,
2826 ConvertTrivialIfToSelect, RemoveEmptyElseBranch,
2827 RemoveStaticCondition, RemoveUnusedResults,
2828 ReplaceIfYieldWithConditionOrValue>(context);
2829}
2830
2831Block *IfOp::thenBlock() { return &getThenRegion().back(); }
2832YieldOp IfOp::thenYield() { return cast<YieldOp>(&thenBlock()->back()); }
2833Block *IfOp::elseBlock() {
2834 Region &r = getElseRegion();
2835 if (r.empty())
2836 return nullptr;
2837 return &r.back();
2838}
2839YieldOp IfOp::elseYield() { return cast<YieldOp>(&elseBlock()->back()); }
2840
2841//===----------------------------------------------------------------------===//
2842// ParallelOp
2843//===----------------------------------------------------------------------===//
2844
2845void ParallelOp::build(
2846 OpBuilder &builder, OperationState &result, ValueRange lowerBounds,
2847 ValueRange upperBounds, ValueRange steps, ValueRange initVals,
2848 function_ref<void(OpBuilder &, Location, ValueRange, ValueRange)>
2849 bodyBuilderFn) {
2850 result.addOperands(lowerBounds);
2851 result.addOperands(upperBounds);
2852 result.addOperands(steps);
2853 result.addOperands(initVals);
2854 result.addAttribute(
2855 ParallelOp::getOperandSegmentSizeAttr(),
2856 builder.getDenseI32ArrayAttr({static_cast<int32_t>(lowerBounds.size()),
2857 static_cast<int32_t>(upperBounds.size()),
2858 static_cast<int32_t>(steps.size()),
2859 static_cast<int32_t>(initVals.size())}));
2860 result.addTypes(initVals.getTypes());
2861
2862 OpBuilder::InsertionGuard guard(builder);
2863 unsigned numIVs = steps.size();
2864 SmallVector<Type, 8> argTypes(numIVs, builder.getIndexType());
2865 SmallVector<Location, 8> argLocs(numIVs, result.location);
2866 Region *bodyRegion = result.addRegion();
2867 Block *bodyBlock = builder.createBlock(bodyRegion, {}, argTypes, argLocs);
2868
2869 if (bodyBuilderFn) {
2870 builder.setInsertionPointToStart(bodyBlock);
2871 bodyBuilderFn(builder, result.location,
2872 bodyBlock->getArguments().take_front(numIVs),
2873 bodyBlock->getArguments().drop_front(numIVs));
2874 }
2875 // Add terminator only if there are no reductions.
2876 if (initVals.empty())
2877 ParallelOp::ensureTerminator(*bodyRegion, builder, result.location);
2878}
2879
2880void ParallelOp::build(
2881 OpBuilder &builder, OperationState &result, ValueRange lowerBounds,
2882 ValueRange upperBounds, ValueRange steps,
2883 function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn) {
2884 // Only pass a non-null wrapper if bodyBuilderFn is non-null itself. Make sure
2885 // we don't capture a reference to a temporary by constructing the lambda at
2886 // function level.
2887 auto wrappedBuilderFn = [&bodyBuilderFn](OpBuilder &nestedBuilder,
2888 Location nestedLoc, ValueRange ivs,
2889 ValueRange) {
2890 bodyBuilderFn(nestedBuilder, nestedLoc, ivs);
2891 };
2892 function_ref<void(OpBuilder &, Location, ValueRange, ValueRange)> wrapper;
2893 if (bodyBuilderFn)
2894 wrapper = wrappedBuilderFn;
2895
2896 build(builder, result, lowerBounds, upperBounds, steps, ValueRange(),
2897 wrapper);
2898}
2899
2900LogicalResult ParallelOp::verify() {
2901 // Check that there is at least one value in lowerBound, upperBound and step.
  // It is sufficient to test only step, because it is already ensured that the
  // number of elements in lowerBound, upperBound and step is the same.
2904 Operation::operand_range stepValues = getStep();
2905 if (stepValues.empty())
2906 return emitOpError(
2907 "needs at least one tuple element for lowerBound, upperBound and step");
2908
2909 // Check whether all constant step values are positive.
2910 for (Value stepValue : stepValues)
2911 if (auto cst = getConstantIntValue(stepValue))
2912 if (*cst <= 0)
2913 return emitOpError("constant step operand must be positive");
2914
2915 // Check that the body defines the same number of block arguments as the
2916 // number of tuple elements in step.
2917 Block *body = getBody();
2918 if (body->getNumArguments() != stepValues.size())
2919 return emitOpError() << "expects the same number of induction variables: "
2920 << body->getNumArguments()
2921 << " as bound and step values: " << stepValues.size();
2922 for (auto arg : body->getArguments())
2923 if (!arg.getType().isIndex())
2924 return emitOpError(
2925 "expects arguments for the induction variable to be of index type");
2926
2927 // Check that the terminator is an scf.reduce op.
2928 auto reduceOp = verifyAndGetTerminator<scf::ReduceOp>(
2929 *this, getRegion(), "expects body to terminate with 'scf.reduce'");
2930 if (!reduceOp)
2931 return failure();
2932
2933 // Check that the number of results is the same as the number of reductions.
2934 auto resultsSize = getResults().size();
2935 auto reductionsSize = reduceOp.getReductions().size();
2936 auto initValsSize = getInitVals().size();
2937 if (resultsSize != reductionsSize)
2938 return emitOpError() << "expects number of results: " << resultsSize
2939 << " to be the same as number of reductions: "
2940 << reductionsSize;
2941 if (resultsSize != initValsSize)
2942 return emitOpError() << "expects number of results: " << resultsSize
2943 << " to be the same as number of initial values: "
2944 << initValsSize;
2945
2946 // Check that the types of the results and reductions are the same.
2947 for (int64_t i = 0; i < static_cast<int64_t>(reductionsSize); ++i) {
2948 auto resultType = getOperation()->getResult(i).getType();
2949 auto reductionOperandType = reduceOp.getOperands()[i].getType();
2950 if (resultType != reductionOperandType)
2951 return reduceOp.emitOpError()
2952 << "expects type of " << i
2953 << "-th reduction operand: " << reductionOperandType
2954 << " to be the same as the " << i
2955 << "-th result type: " << resultType;
2956 }
2957 return success();
2958}
2959
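// For reference, the custom assembly form handled by the parser below looks
// roughly like this (illustrative example with placeholder SSA names; the
// `init` clause and the arrow type list are optional and only present when the
// loop carries reductions):
//
//   %sum = scf.parallel (%i, %j) = (%lb0, %lb1) to (%ub0, %ub1)
//              step (%s0, %s1) init (%zero) -> f32 {
//     %v = memref.load %buf[%i, %j] : memref<?x?xf32>
//     scf.reduce(%v : f32) {
//     ^bb0(%lhs: f32, %rhs: f32):
//       %r = arith.addf %lhs, %rhs : f32
//       scf.reduce.return %r : f32
//     }
//   }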
2960ParseResult ParallelOp::parse(OpAsmParser &parser, OperationState &result) {
2961 auto &builder = parser.getBuilder();
2962 // Parse an opening `(` followed by induction variables followed by `)`
2963 SmallVector<OpAsmParser::Argument, 4> ivs;
2964 if (parser.parseArgumentList(ivs, OpAsmParser::Delimiter::Paren))
2965 return failure();
2966
2967 // Parse loop bounds.
2968 SmallVector<OpAsmParser::UnresolvedOperand, 4> lower;
2969 if (parser.parseEqual() ||
2970 parser.parseOperandList(lower, ivs.size(),
2971 OpAsmParser::Delimiter::Paren) ||
2972 parser.resolveOperands(lower, builder.getIndexType(), result.operands))
2973 return failure();
2974
2975 SmallVector<OpAsmParser::UnresolvedOperand, 4> upper;
2976 if (parser.parseKeyword("to") ||
2977 parser.parseOperandList(upper, ivs.size(),
2978 OpAsmParser::Delimiter::Paren) ||
2979 parser.resolveOperands(upper, builder.getIndexType(), result.operands))
2980 return failure();
2981
2982 // Parse step values.
2983 SmallVector<OpAsmParser::UnresolvedOperand, 4> steps;
2984 if (parser.parseKeyword("step") ||
2985 parser.parseOperandList(steps, ivs.size(),
2986 OpAsmParser::Delimiter::Paren) ||
2987 parser.resolveOperands(steps, builder.getIndexType(), result.operands))
2988 return failure();
2989
2990 // Parse init values.
2991 SmallVector<OpAsmParser::UnresolvedOperand, 4> initVals;
2992 if (succeeded(parser.parseOptionalKeyword("init"))) {
2993 if (parser.parseOperandList(initVals, OpAsmParser::Delimiter::Paren))
2994 return failure();
2995 }
2996
2997 // Parse optional results in case there is a reduce.
2998 if (parser.parseOptionalArrowTypeList(result.types))
2999 return failure();
3000
3001 // Now parse the body.
3002 Region *body = result.addRegion();
3003 for (auto &iv : ivs)
3004 iv.type = builder.getIndexType();
3005 if (parser.parseRegion(*body, ivs))
3006 return failure();
3007
3008 // Set `operandSegmentSizes` attribute.
3009 result.addAttribute(
3010 ParallelOp::getOperandSegmentSizeAttr(),
3011 builder.getDenseI32ArrayAttr({static_cast<int32_t>(lower.size()),
3012 static_cast<int32_t>(upper.size()),
3013 static_cast<int32_t>(steps.size()),
3014 static_cast<int32_t>(initVals.size())}));
3015
3016 // Parse attributes.
3017 if (parser.parseOptionalAttrDict(result.attributes) ||
3018 parser.resolveOperands(initVals, result.types, parser.getNameLoc(),
3019 result.operands))
3020 return failure();
3021
3022 // Add a terminator if none was parsed.
3023 ParallelOp::ensureTerminator(*body, builder, result.location);
3024 return success();
3025}
3026
3027void ParallelOp::print(OpAsmPrinter &p) {
3028 p << " (" << getBody()->getArguments() << ") = (" << getLowerBound()
3029 << ") to (" << getUpperBound() << ") step (" << getStep() << ")";
3030 if (!getInitVals().empty())
3031 p << " init (" << getInitVals() << ")";
3032 p.printOptionalArrowTypeList(getResultTypes());
3033 p << ' ';
3034 p.printRegion(getRegion(), /*printEntryBlockArgs=*/false);
3035 p.printOptionalAttrDict(
3036 (*this)->getAttrs(),
3037 /*elidedAttrs=*/ParallelOp::getOperandSegmentSizeAttr());
3038}
3039
3040SmallVector<Region *> ParallelOp::getLoopRegions() { return {&getRegion()}; }
3041
3042std::optional<SmallVector<Value>> ParallelOp::getLoopInductionVars() {
3043 return SmallVector<Value>{getBody()->getArguments()};
3044}
3045
3046std::optional<SmallVector<OpFoldResult>> ParallelOp::getLoopLowerBounds() {
3047 return getLowerBound();
3048}
3049
3050std::optional<SmallVector<OpFoldResult>> ParallelOp::getLoopUpperBounds() {
3051 return getUpperBound();
3052}
3053
3054std::optional<SmallVector<OpFoldResult>> ParallelOp::getLoopSteps() {
3055 return getStep();
3056}
3057
3058ParallelOp mlir::scf::getParallelForInductionVarOwner(Value val) {
3059  auto ivArg = llvm::dyn_cast<BlockArgument>(val);
3060 if (!ivArg)
3061 return ParallelOp();
3062 assert(ivArg.getOwner() && "unlinked block argument");
3063 auto *containingOp = ivArg.getOwner()->getParentOp();
3064 return dyn_cast<ParallelOp>(containingOp);
3065}
3066
3067namespace {
3068// Collapse loop dimensions that perform a single iteration.
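// Illustrative effect (placeholder IR): a dimension whose constant trip count
// is 1 is dropped and its induction variable is replaced by the lower bound,
// e.g.
//
//   scf.parallel (%i, %j) = (%c0, %lb) to (%c1, %ub) step (%c1, %s) { ... }
//
// becomes
//
//   scf.parallel (%j) = (%lb) to (%ub) step (%s) { ... }  // %i -> %c0
//
// and a loop with a zero-trip-count dimension is replaced by its init values.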
3069struct ParallelOpSingleOrZeroIterationDimsFolder
3070 : public OpRewritePattern<ParallelOp> {
3071 using OpRewritePattern<ParallelOp>::OpRewritePattern;
3072
3073 LogicalResult matchAndRewrite(ParallelOp op,
3074 PatternRewriter &rewriter) const override {
3075 Location loc = op.getLoc();
3076
3077 // Compute new loop bounds that omit all single-iteration loop dimensions.
3078 SmallVector<Value> newLowerBounds, newUpperBounds, newSteps;
3079 IRMapping mapping;
3080 for (auto [lb, ub, step, iv] :
3081 llvm::zip(op.getLowerBound(), op.getUpperBound(), op.getStep(),
3082 op.getInductionVars())) {
3083 auto numIterations = constantTripCount(lb, ub, step);
3084 if (numIterations.has_value()) {
3085 // Remove the loop if it performs zero iterations.
3086 if (*numIterations == 0) {
3087 rewriter.replaceOp(op, op.getInitVals());
3088 return success();
3089 }
3090 // Replace the loop induction variable by the lower bound if the loop
3091 // performs a single iteration. Otherwise, copy the loop bounds.
3092 if (*numIterations == 1) {
3093 mapping.map(iv, getValueOrCreateConstantIndexOp(rewriter, loc, lb));
3094 continue;
3095 }
3096 }
3097 newLowerBounds.push_back(lb);
3098 newUpperBounds.push_back(ub);
3099 newSteps.push_back(step);
3100 }
3101 // Exit if none of the loop dimensions perform a single iteration.
3102 if (newLowerBounds.size() == op.getLowerBound().size())
3103 return failure();
3104
3105 if (newLowerBounds.empty()) {
3106      // All of the loop dimensions perform a single iteration. Inline the
3107      // loop body and the nested ReduceOps.
3108 SmallVector<Value> results;
3109      results.reserve(op.getInitVals().size());
3110 for (auto &bodyOp : op.getBody()->without_terminator())
3111 rewriter.clone(bodyOp, mapping);
3112 auto reduceOp = cast<ReduceOp>(op.getBody()->getTerminator());
3113 for (int64_t i = 0, e = reduceOp.getReductions().size(); i < e; ++i) {
3114 Block &reduceBlock = reduceOp.getReductions()[i].front();
3115 auto initValIndex = results.size();
3116        mapping.map(reduceBlock.getArgument(0), op.getInitVals()[initValIndex]);
3117        mapping.map(reduceBlock.getArgument(1),
3118 mapping.lookupOrDefault(reduceOp.getOperands()[i]));
3119 for (auto &reduceBodyOp : reduceBlock.without_terminator())
3120 rewriter.clone(reduceBodyOp, mapping);
3121
3122 auto result = mapping.lookupOrDefault(
3123 cast<ReduceReturnOp>(reduceBlock.getTerminator()).getResult());
3124        results.push_back(result);
3125 }
3126
3127 rewriter.replaceOp(op, results);
3128 return success();
3129 }
3130    // Replace the parallel loop with a lower-dimensional parallel loop.
3131 auto newOp =
3132 rewriter.create<ParallelOp>(op.getLoc(), newLowerBounds, newUpperBounds,
3133 newSteps, op.getInitVals(), nullptr);
3134 // Erase the empty block that was inserted by the builder.
3135    rewriter.eraseBlock(newOp.getBody());
3136 // Clone the loop body and remap the block arguments of the collapsed loops
3137 // (inlining does not support a cancellable block argument mapping).
3138 rewriter.cloneRegionBefore(op.getRegion(), newOp.getRegion(),
3139 newOp.getRegion().begin(), mapping);
3140 rewriter.replaceOp(op, newOp.getResults());
3141 return success();
3142 }
3143};
3144
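// Merge a perfectly nested scf.parallel loop into its parent by concatenating
// the bounds, steps, and induction variables. Illustrative effect (placeholder
// IR; loops with reductions are not handled):
//
//   scf.parallel (%i) = (%lb0) to (%ub0) step (%s0) {
//     scf.parallel (%j) = (%lb1) to (%ub1) step (%s1) {
//       ...
//     }
//   }
//
// becomes
//
//   scf.parallel (%i, %j) = (%lb0, %lb1) to (%ub0, %ub1) step (%s0, %s1) {
//     ...
//   }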
3145struct MergeNestedParallelLoops : public OpRewritePattern<ParallelOp> {
3146 using OpRewritePattern<ParallelOp>::OpRewritePattern;
3147
3148 LogicalResult matchAndRewrite(ParallelOp op,
3149 PatternRewriter &rewriter) const override {
3150 Block &outerBody = *op.getBody();
3151    if (!llvm::hasSingleElement(outerBody.without_terminator()))
3152 return failure();
3153
3154 auto innerOp = dyn_cast<ParallelOp>(outerBody.front());
3155 if (!innerOp)
3156 return failure();
3157
3158 for (auto val : outerBody.getArguments())
3159 if (llvm::is_contained(innerOp.getLowerBound(), val) ||
3160 llvm::is_contained(innerOp.getUpperBound(), val) ||
3161 llvm::is_contained(innerOp.getStep(), val))
3162 return failure();
3163
3164 // Reductions are not supported yet.
3165 if (!op.getInitVals().empty() || !innerOp.getInitVals().empty())
3166 return failure();
3167
3168 auto bodyBuilder = [&](OpBuilder &builder, Location /*loc*/,
3169 ValueRange iterVals, ValueRange) {
3170 Block &innerBody = *innerOp.getBody();
3171 assert(iterVals.size() ==
3172 (outerBody.getNumArguments() + innerBody.getNumArguments()));
3173 IRMapping mapping;
3174      mapping.map(outerBody.getArguments(),
3175                  iterVals.take_front(outerBody.getNumArguments()));
3176      mapping.map(innerBody.getArguments(),
3177                  iterVals.take_back(innerBody.getNumArguments()));
3178 for (Operation &op : innerBody.without_terminator())
3179 builder.clone(op, mapping);
3180 };
3181
3182 auto concatValues = [](const auto &first, const auto &second) {
3183 SmallVector<Value> ret;
3184      ret.reserve(first.size() + second.size());
3185 ret.assign(first.begin(), first.end());
3186 ret.append(second.begin(), second.end());
3187 return ret;
3188 };
3189
3190 auto newLowerBounds =
3191 concatValues(op.getLowerBound(), innerOp.getLowerBound());
3192 auto newUpperBounds =
3193 concatValues(op.getUpperBound(), innerOp.getUpperBound());
3194 auto newSteps = concatValues(op.getStep(), innerOp.getStep());
3195
3196 rewriter.replaceOpWithNewOp<ParallelOp>(op, newLowerBounds, newUpperBounds,
3197 newSteps, std::nullopt,
3198 bodyBuilder);
3199 return success();
3200 }
3201};
3202
3203} // namespace
3204
3205void ParallelOp::getCanonicalizationPatterns(RewritePatternSet &results,
3206 MLIRContext *context) {
3207 results
3208 .add<ParallelOpSingleOrZeroIterationDimsFolder, MergeNestedParallelLoops>(
3209 context);
3210}
3211
3212/// Given a region branch point, return the successor regions. These are the
3213/// regions that may be selected during the flow of control: the parent
3214/// operation may branch into the loop body, and the body may branch back to
3215/// the parent operation or the loop may be skipped entirely.
3217void ParallelOp::getSuccessorRegions(
3218 RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
3219  // Both the operation itself and the region may branch into the body or back
3220  // into the operation itself. It is possible for the loop not to enter the
3221  // body.
3222 regions.push_back(RegionSuccessor(&getRegion()));
3223 regions.push_back(RegionSuccessor());
3224}
3225
3226//===----------------------------------------------------------------------===//
3227// ReduceOp
3228//===----------------------------------------------------------------------===//
3229
3230void ReduceOp::build(OpBuilder &builder, OperationState &result) {}
3231
3232void ReduceOp::build(OpBuilder &builder, OperationState &result,
3233 ValueRange operands) {
3234 result.addOperands(operands);
3235 for (Value v : operands) {
3236 OpBuilder::InsertionGuard guard(builder);
3237 Region *bodyRegion = result.addRegion();
3238 builder.createBlock(bodyRegion, {},
3239 ArrayRef<Type>{v.getType(), v.getType()},
3240 {result.location, result.location});
3241 }
3242}
3243
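// For reference, a reduction region as verified below has the following shape
// (illustrative example with placeholder SSA names):
//
//   scf.reduce(%operand : f32) {
//   ^bb0(%lhs: f32, %rhs: f32):
//     %res = arith.addf %lhs, %rhs : f32
//     scf.reduce.return %res : f32
//   }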
3244LogicalResult ReduceOp::verifyRegions() {
3245 // The region of a ReduceOp has two arguments of the same type as its
3246 // corresponding operand.
3247 for (int64_t i = 0, e = getReductions().size(); i < e; ++i) {
3248 auto type = getOperands()[i].getType();
3249 Block &block = getReductions()[i].front();
3250 if (block.empty())
3251 return emitOpError() << i << "-th reduction has an empty body";
3252 if (block.getNumArguments() != 2 ||
3253 llvm::any_of(block.getArguments(), [&](const BlockArgument &arg) {
3254 return arg.getType() != type;
3255 }))
3256 return emitOpError() << "expected two block arguments with type " << type
3257 << " in the " << i << "-th reduction region";
3258
3259 // Check that the block is terminated by a ReduceReturnOp.
3260 if (!isa<ReduceReturnOp>(block.getTerminator()))
3261 return emitOpError("reduction bodies must be terminated with an "
3262 "'scf.reduce.return' op");
3263 }
3264
3265 return success();
3266}
3267
3268MutableOperandRange
3269ReduceOp::getMutableSuccessorOperands(RegionBranchPoint point) {
3270 // No operands are forwarded to the next iteration.
3271 return MutableOperandRange(getOperation(), /*start=*/0, /*length=*/0);
3272}
3273
3274//===----------------------------------------------------------------------===//
3275// ReduceReturnOp
3276//===----------------------------------------------------------------------===//
3277
3278LogicalResult ReduceReturnOp::verify() {
3279 // The type of the return value should be the same type as the types of the
3280 // block arguments of the reduction body.
3281 Block *reductionBody = getOperation()->getBlock();
3282 // Should already be verified by an op trait.
3283 assert(isa<ReduceOp>(reductionBody->getParentOp()) && "expected scf.reduce");
3284 Type expectedResultType = reductionBody->getArgument(0).getType();
3285 if (expectedResultType != getResult().getType())
3286 return emitOpError() << "must have type " << expectedResultType
3287 << " (the type of the reduction inputs)";
3288 return success();
3289}
3290
3291//===----------------------------------------------------------------------===//
3292// WhileOp
3293//===----------------------------------------------------------------------===//
3294
3295void WhileOp::build(::mlir::OpBuilder &odsBuilder,
3296 ::mlir::OperationState &odsState, TypeRange resultTypes,
3297 ValueRange inits, BodyBuilderFn beforeBuilder,
3298 BodyBuilderFn afterBuilder) {
3299 odsState.addOperands(inits);
3300 odsState.addTypes(resultTypes);
3301
3302 OpBuilder::InsertionGuard guard(odsBuilder);
3303
3304 // Build before region.
3305 SmallVector<Location, 4> beforeArgLocs;
3306 beforeArgLocs.reserve(inits.size());
3307 for (Value operand : inits) {
3308 beforeArgLocs.push_back(operand.getLoc());
3309 }
3310
3311 Region *beforeRegion = odsState.addRegion();
3312 Block *beforeBlock = odsBuilder.createBlock(beforeRegion, /*insertPt=*/{},
3313 inits.getTypes(), beforeArgLocs);
3314 if (beforeBuilder)
3315 beforeBuilder(odsBuilder, odsState.location, beforeBlock->getArguments());
3316
3317 // Build after region.
3318 SmallVector<Location, 4> afterArgLocs(resultTypes.size(), odsState.location);
3319
3320 Region *afterRegion = odsState.addRegion();
3321 Block *afterBlock = odsBuilder.createBlock(afterRegion, /*insertPt=*/{},
3322 resultTypes, afterArgLocs);
3323
3324 if (afterBuilder)
3325 afterBuilder(odsBuilder, odsState.location, afterBlock->getArguments());
3326}
3327
3328ConditionOp WhileOp::getConditionOp() {
3329 return cast<ConditionOp>(getBeforeBody()->getTerminator());
3330}
3331
3332YieldOp WhileOp::getYieldOp() {
3333 return cast<YieldOp>(getAfterBody()->getTerminator());
3334}
3335
3336std::optional<MutableArrayRef<OpOperand>> WhileOp::getYieldedValuesMutable() {
3337 return getYieldOp().getResultsMutable();
3338}
3339
3340Block::BlockArgListType WhileOp::getBeforeArguments() {
3341 return getBeforeBody()->getArguments();
3342}
3343
3344Block::BlockArgListType WhileOp::getAfterArguments() {
3345 return getAfterBody()->getArguments();
3346}
3347
3348Block::BlockArgListType WhileOp::getRegionIterArgs() {
3349 return getBeforeArguments();
3350}
3351
3352OperandRange WhileOp::getEntrySuccessorOperands(RegionBranchPoint point) {
3353 assert(point == getBefore() &&
3354 "WhileOp is expected to branch only to the first region");
3355 return getInits();
3356}
3357
3358void WhileOp::getSuccessorRegions(RegionBranchPoint point,
3359 SmallVectorImpl<RegionSuccessor> &regions) {
3360 // The parent op always branches to the condition region.
3361 if (point.isParent()) {
3362 regions.emplace_back(&getBefore(), getBefore().getArguments());
3363 return;
3364 }
3365
3366 assert(llvm::is_contained({&getAfter(), &getBefore()}, point) &&
3367 "there are only two regions in a WhileOp");
3368 // The body region always branches back to the condition region.
3369 if (point == getAfter()) {
3370 regions.emplace_back(&getBefore(), getBefore().getArguments());
3371 return;
3372 }
3373
3374 regions.emplace_back(getResults());
3375 regions.emplace_back(&getAfter(), getAfter().getArguments());
3376}
3377
3378SmallVector<Region *> WhileOp::getLoopRegions() {
3379 return {&getBefore(), &getAfter()};
3380}
3381
3382/// Parses a `while` op.
3383///
3384/// op ::= `scf.while` assignments `:` function-type region `do` region
3385/// `attributes` attribute-dict
3386/// initializer ::= /* empty */ | `(` assignment-list `)`
3387/// assignment-list ::= assignment | assignment `,` assignment-list
3388/// assignment ::= ssa-value `=` ssa-value
3389ParseResult scf::WhileOp::parse(OpAsmParser &parser, OperationState &result) {
3390 SmallVector<OpAsmParser::Argument, 4> regionArgs;
3391 SmallVector<OpAsmParser::UnresolvedOperand, 4> operands;
3392 Region *before = result.addRegion();
3393 Region *after = result.addRegion();
3394
3395 OptionalParseResult listResult =
3396 parser.parseOptionalAssignmentList(regionArgs, operands);
3397 if (listResult.has_value() && failed(listResult.value()))
3398 return failure();
3399
3400 FunctionType functionType;
3401 SMLoc typeLoc = parser.getCurrentLocation();
3402 if (failed(parser.parseColonType(functionType)))
3403 return failure();
3404
3405 result.addTypes(functionType.getResults());
3406
3407 if (functionType.getNumInputs() != operands.size()) {
3408 return parser.emitError(typeLoc)
3409 << "expected as many input types as operands "
3410 << "(expected " << operands.size() << " got "
3411 << functionType.getNumInputs() << ")";
3412 }
3413
3414 // Resolve input operands.
3415 if (failed(parser.resolveOperands(operands, functionType.getInputs(),
3416 parser.getCurrentLocation(),
3417 result.operands)))
3418 return failure();
3419
3420 // Propagate the types into the region arguments.
3421 for (size_t i = 0, e = regionArgs.size(); i != e; ++i)
3422 regionArgs[i].type = functionType.getInput(i);
3423
3424 return failure(parser.parseRegion(*before, regionArgs) ||
3425 parser.parseKeyword("do") || parser.parseRegion(*after) ||
3426 parser.parseOptionalAttrDictWithKeyword(result.attributes));
3427}
3428
3429/// Prints a `while` op.
3430void scf::WhileOp::print(OpAsmPrinter &p) {
3431 printInitializationList(p, getBeforeArguments(), getInits(), " ");
3432 p << " : ";
3433 p.printFunctionalType(getInits().getTypes(), getResults().getTypes());
3434 p << ' ';
3435 p.printRegion(getBefore(), /*printEntryBlockArgs=*/false);
3436 p << " do ";
3437 p.printRegion(getAfter());
3438 p.printOptionalAttrDictWithKeyword((*this)->getAttrs());
3439}
3440
3441/// Verifies that two ranges of types match, i.e. have the same number of
3442/// entries and that types are pairwise equals. Reports errors on the given
3443/// operation in case of mismatch.
3444template <typename OpTy>
3445static LogicalResult verifyTypeRangesMatch(OpTy op, TypeRange left,
3446 TypeRange right, StringRef message) {
3447 if (left.size() != right.size())
3448 return op.emitOpError("expects the same number of ") << message;
3449
3450 for (unsigned i = 0, e = left.size(); i < e; ++i) {
3451 if (left[i] != right[i]) {
3452 InFlightDiagnostic diag = op.emitOpError("expects the same types for ")
3453 << message;
3454 diag.attachNote() << "for argument " << i << ", found " << left[i]
3455 << " and " << right[i];
3456 return diag;
3457 }
3458 }
3459
3460 return success();
3461}
3462
3463LogicalResult scf::WhileOp::verify() {
3464 auto beforeTerminator = verifyAndGetTerminator<scf::ConditionOp>(
3465 *this, getBefore(),
3466 "expects the 'before' region to terminate with 'scf.condition'");
3467 if (!beforeTerminator)
3468 return failure();
3469
3470 auto afterTerminator = verifyAndGetTerminator<scf::YieldOp>(
3471 *this, getAfter(),
3472 "expects the 'after' region to terminate with 'scf.yield'");
3473 return success(afterTerminator != nullptr);
3474}
3475
3476namespace {
3477/// Replace uses of the condition within the do block with true, since otherwise
3478/// the block would not be evaluated.
3479///
3480/// scf.while (..) : (i1, ...) -> ... {
3481/// %condition = call @evaluate_condition() : () -> i1
3482/// scf.condition(%condition) %condition : i1, ...
3483/// } do {
3484/// ^bb0(%arg0: i1, ...):
3485/// use(%arg0)
3486/// ...
3487///
3488/// becomes
3489/// scf.while (..) : (i1, ...) -> ... {
3490/// %condition = call @evaluate_condition() : () -> i1
3491/// scf.condition(%condition) %condition : i1, ...
3492/// } do {
3493/// ^bb0(%arg0: i1, ...):
3494/// use(%true)
3495/// ...
3496struct WhileConditionTruth : public OpRewritePattern<WhileOp> {
3497 using OpRewritePattern<WhileOp>::OpRewritePattern;
3498
3499 LogicalResult matchAndRewrite(WhileOp op,
3500 PatternRewriter &rewriter) const override {
3501 auto term = op.getConditionOp();
3502
3503    // This variable serves to prevent creating a duplicate constant
3504    // and holds the constant true value.
3505 Value constantTrue = nullptr;
3506
3507 bool replaced = false;
3508 for (auto yieldedAndBlockArgs :
3509 llvm::zip(term.getArgs(), op.getAfterArguments())) {
3510 if (std::get<0>(yieldedAndBlockArgs) == term.getCondition()) {
3511 if (!std::get<1>(yieldedAndBlockArgs).use_empty()) {
3512 if (!constantTrue)
3513 constantTrue = rewriter.create<arith::ConstantOp>(
3514 op.getLoc(), term.getCondition().getType(),
3515 rewriter.getBoolAttr(true));
3516
3517 rewriter.replaceAllUsesWith(std::get<1>(yieldedAndBlockArgs),
3518 constantTrue);
3519 replaced = true;
3520 }
3521 }
3522 }
3523    return success(replaced);
3524 }
3525};
3526
3527/// Remove loop invariant arguments from `before` block of scf.while.
3528/// A before block argument is considered loop invariant if:
3529/// 1. i-th yield operand is equal to the i-th while operand.
3530/// 2. i-th yield operand is k-th after block argument which is (k+1)-th
3531/// condition operand AND this (k+1)-th condition operand is equal to i-th
3532/// iter argument/while operand.
3533/// For the arguments which are removed, their uses inside scf.while
3534/// are replaced with their corresponding initial value.
3535///
3536/// Eg:
3537/// INPUT :-
3538/// %res = scf.while <...> iter_args(%arg0_before = %a, %arg1_before = %b,
3539/// ..., %argN_before = %N)
3540/// {
3541/// ...
3542/// scf.condition(%cond) %arg1_before, %arg0_before,
3543/// %arg2_before, %arg0_before, ...
3544/// } do {
3545/// ^bb0(%arg1_after, %arg0_after_1, %arg2_after, %arg0_after_2,
3546/// ..., %argK_after):
3547/// ...
3548/// scf.yield %arg0_after_2, %b, %arg1_after, ..., %argN
3549/// }
3550///
3551/// OUTPUT :-
3552/// %res = scf.while <...> iter_args(%arg2_before = %c, ..., %argN_before =
3553/// %N)
3554/// {
3555/// ...
3556/// scf.condition(%cond) %b, %a, %arg2_before, %a, ...
3557/// } do {
3558/// ^bb0(%arg1_after, %arg0_after_1, %arg2_after, %arg0_after_2,
3559/// ..., %argK_after):
3560/// ...
3561/// scf.yield %arg1_after, ..., %argN
3562/// }
3563///
3564/// EXPLANATION:
3565/// We iterate over each yield operand.
3566/// 1. 0-th yield operand %arg0_after_2 is 4-th condition operand
3567/// %arg0_before, which in turn is the 0-th iter argument. So we
3568/// remove 0-th before block argument and yield operand, and replace
3569/// all uses of the 0-th before block argument with its initial value
3570/// %a.
3571/// 2. 1-th yield operand %b is equal to the 1-th iter arg's initial
3572/// value. So we remove this operand and the corresponding before
3573/// block argument and replace all uses of 1-th before block argument
3574/// with %b.
3575struct RemoveLoopInvariantArgsFromBeforeBlock
3576 : public OpRewritePattern<WhileOp> {
3577 using OpRewritePattern<WhileOp>::OpRewritePattern;
3578
3579 LogicalResult matchAndRewrite(WhileOp op,
3580 PatternRewriter &rewriter) const override {
3581 Block &afterBlock = *op.getAfterBody();
3582 Block::BlockArgListType beforeBlockArgs = op.getBeforeArguments();
3583 ConditionOp condOp = op.getConditionOp();
3584 OperandRange condOpArgs = condOp.getArgs();
3585 Operation *yieldOp = afterBlock.getTerminator();
3586 ValueRange yieldOpArgs = yieldOp->getOperands();
3587
3588 bool canSimplify = false;
3589 for (const auto &it :
3590 llvm::enumerate(llvm::zip(op.getOperands(), yieldOpArgs))) {
3591 auto index = static_cast<unsigned>(it.index());
3592 auto [initVal, yieldOpArg] = it.value();
3593 // If i-th yield operand is equal to the i-th operand of the scf.while,
3594 // the i-th before block argument is a loop invariant.
3595 if (yieldOpArg == initVal) {
3596 canSimplify = true;
3597 break;
3598 }
3599      // If the i-th yield operand is the k-th after block argument, we check
3600      // whether the (k+1)-th condition op operand is equal to either the i-th
3601      // before block argument or the initial value of the i-th before block
3602      // argument. If so, the i-th before block argument is a loop invariant.
3604 auto yieldOpBlockArg = llvm::dyn_cast<BlockArgument>(yieldOpArg);
3605 if (yieldOpBlockArg && yieldOpBlockArg.getOwner() == &afterBlock) {
3606 Value condOpArg = condOpArgs[yieldOpBlockArg.getArgNumber()];
3607 if (condOpArg == beforeBlockArgs[index] || condOpArg == initVal) {
3608 canSimplify = true;
3609 break;
3610 }
3611 }
3612 }
3613
3614 if (!canSimplify)
3615 return failure();
3616
3617 SmallVector<Value> newInitArgs, newYieldOpArgs;
3618 DenseMap<unsigned, Value> beforeBlockInitValMap;
3619 SmallVector<Location> newBeforeBlockArgLocs;
3620 for (const auto &it :
3621 llvm::enumerate(llvm::zip(op.getOperands(), yieldOpArgs))) {
3622 auto index = static_cast<unsigned>(it.index());
3623 auto [initVal, yieldOpArg] = it.value();
3624
3625 // If i-th yield operand is equal to the i-th operand of the scf.while,
3626 // the i-th before block argument is a loop invariant.
3627 if (yieldOpArg == initVal) {
3628 beforeBlockInitValMap.insert({index, initVal});
3629 continue;
3630 } else {
3631        // If the i-th yield operand is the k-th after block argument, we check
3632        // whether the (k+1)-th condition op operand is equal to either the
3633        // i-th before block argument or the initial value of the i-th before
3634        // block argument. If so, the i-th before block argument is a loop
3635        // invariant.
3636 auto yieldOpBlockArg = llvm::dyn_cast<BlockArgument>(yieldOpArg);
3637 if (yieldOpBlockArg && yieldOpBlockArg.getOwner() == &afterBlock) {
3638 Value condOpArg = condOpArgs[yieldOpBlockArg.getArgNumber()];
3639 if (condOpArg == beforeBlockArgs[index] || condOpArg == initVal) {
3640 beforeBlockInitValMap.insert({index, initVal});
3641 continue;
3642 }
3643 }
3644 }
3645 newInitArgs.emplace_back(initVal);
3646 newYieldOpArgs.emplace_back(yieldOpArg);
3647 newBeforeBlockArgLocs.emplace_back(beforeBlockArgs[index].getLoc());
3648 }
3649
3650 {
3651 OpBuilder::InsertionGuard g(rewriter);
3652 rewriter.setInsertionPoint(yieldOp);
3653 rewriter.replaceOpWithNewOp<YieldOp>(yieldOp, newYieldOpArgs);
3654 }
3655
3656 auto newWhile =
3657 rewriter.create<WhileOp>(op.getLoc(), op.getResultTypes(), newInitArgs);
3658
3659 Block &newBeforeBlock = *rewriter.createBlock(
3660 &newWhile.getBefore(), /*insertPt*/ {},
3661 ValueRange(newYieldOpArgs).getTypes(), newBeforeBlockArgLocs);
3662
3663 Block &beforeBlock = *op.getBeforeBody();
3664 SmallVector<Value> newBeforeBlockArgs(beforeBlock.getNumArguments());
3665    // For each i-th before block argument we find its replacement value:
3666    // 1. If the i-th before block argument is a loop invariant, we fetch its
3667    //    initial value from `beforeBlockInitValMap` by querying for key `i`.
3668    // 2. Else we fetch the j-th new before block argument as the replacement
3669    //    value of the i-th before block argument.
3670 for (unsigned i = 0, j = 0, n = beforeBlock.getNumArguments(); i < n; i++) {
3671      // If the index 'i' argument was a loop invariant, we fetch its initial
3672      // value from `beforeBlockInitValMap`.
3673      if (beforeBlockInitValMap.count(i) != 0)
3674        newBeforeBlockArgs[i] = beforeBlockInitValMap[i];
3675      else
3676        newBeforeBlockArgs[i] = newBeforeBlock.getArgument(j++);
3677 }
3678
3679    rewriter.mergeBlocks(&beforeBlock, &newBeforeBlock, newBeforeBlockArgs);
3680 rewriter.inlineRegionBefore(op.getAfter(), newWhile.getAfter(),
3681 newWhile.getAfter().begin());
3682
3683 rewriter.replaceOp(op, newWhile.getResults());
3684 return success();
3685 }
3686};
3687
3688/// Remove loop invariant value from result (condition op) of scf.while.
3689/// A value is considered loop invariant if the final value yielded by
3690/// scf.condition is defined outside of the `before` block. We remove the
3691/// corresponding argument in `after` block and replace the use with the value.
3692/// We also replace the use of the corresponding result of scf.while with the
3693/// value.
3694///
3695/// Eg:
3696/// INPUT :-
3697/// %res_input:K = scf.while <...> iter_args(%arg0_before = , ...,
3698/// %argN_before = %N) {
3699/// ...
3700/// scf.condition(%cond) %arg0_before, %a, %b, %arg1_before, ...
3701/// } do {
3702/// ^bb0(%arg0_after, %arg1_after, %arg2_after, ..., %argK_after):
3703/// ...
3704/// some_func(%arg1_after)
3705/// ...
3706/// scf.yield %arg0_after, %arg2_after, ..., %argN_after
3707/// }
3708///
3709/// OUTPUT :-
3710/// %res_output:M = scf.while <...> iter_args(%arg0 = , ..., %argN = %N) {
3711/// ...
3712/// scf.condition(%cond) %arg0, %arg1, ..., %argM
3713/// } do {
3714/// ^bb0(%arg0, %arg3, ..., %argM):
3715/// ...
3716/// some_func(%a)
3717/// ...
3718/// scf.yield %arg0, %b, ..., %argN
3719/// }
3720///
3721/// EXPLANATION:
3722/// 1. The 1-th and 2-th operand of scf.condition are defined outside the
3723/// before block of scf.while, so they get removed.
3724/// 2. %res_input#1's uses are replaced by %a and %res_input#2's uses are
3725/// replaced by %b.
3726/// 3. The corresponding after block argument %arg1_after's uses are
3727/// replaced by %a and %arg2_after's uses are replaced by %b.
3728struct RemoveLoopInvariantValueYielded : public OpRewritePattern<WhileOp> {
3729 using OpRewritePattern<WhileOp>::OpRewritePattern;
3730
3731 LogicalResult matchAndRewrite(WhileOp op,
3732 PatternRewriter &rewriter) const override {
3733 Block &beforeBlock = *op.getBeforeBody();
3734 ConditionOp condOp = op.getConditionOp();
3735 OperandRange condOpArgs = condOp.getArgs();
3736
3737 bool canSimplify = false;
3738 for (Value condOpArg : condOpArgs) {
3739      // Values not defined within the `before` block are considered loop
3740      // invariant.
3742 if (condOpArg.getParentBlock() != &beforeBlock) {
3743 canSimplify = true;
3744 break;
3745 }
3746 }
3747
3748 if (!canSimplify)
3749 return failure();
3750
3751 Block::BlockArgListType afterBlockArgs = op.getAfterArguments();
3752
3753 SmallVector<Value> newCondOpArgs;
3754 SmallVector<Type> newAfterBlockType;
3755 DenseMap<unsigned, Value> condOpInitValMap;
3756 SmallVector<Location> newAfterBlockArgLocs;
3757 for (const auto &it : llvm::enumerate(condOpArgs)) {
3758 auto index = static_cast<unsigned>(it.index());
3759 Value condOpArg = it.value();
3760      // Values not defined within the `before` block are considered loop
3761      // invariant; we map the corresponding `index` to that value.
3763 if (condOpArg.getParentBlock() != &beforeBlock) {
3764 condOpInitValMap.insert({index, condOpArg});
3765 } else {
3766 newCondOpArgs.emplace_back(condOpArg);
3767 newAfterBlockType.emplace_back(condOpArg.getType());
3768 newAfterBlockArgLocs.emplace_back(afterBlockArgs[index].getLoc());
3769 }
3770 }
3771
3772 {
3773 OpBuilder::InsertionGuard g(rewriter);
3774 rewriter.setInsertionPoint(condOp);
3775 rewriter.replaceOpWithNewOp<ConditionOp>(condOp, condOp.getCondition(),
3776 newCondOpArgs);
3777 }
3778
3779 auto newWhile = rewriter.create<WhileOp>(op.getLoc(), newAfterBlockType,
3780 op.getOperands());
3781
3782 Block &newAfterBlock =
3783 *rewriter.createBlock(&newWhile.getAfter(), /*insertPt*/ {},
3784 newAfterBlockType, newAfterBlockArgLocs);
3785
3786 Block &afterBlock = *op.getAfterBody();
3787    // Since a new scf.condition op was created, we need to fetch the new
3788    // `after` block arguments, which are used when replacing the operations of
3789    // the previous scf.while's `after` block. We also fetch the new result
3790    // values.
3791 SmallVector<Value> newAfterBlockArgs(afterBlock.getNumArguments());
3792 SmallVector<Value> newWhileResults(afterBlock.getNumArguments());
3793 for (unsigned i = 0, j = 0, n = afterBlock.getNumArguments(); i < n; i++) {
3794 Value afterBlockArg, result;
3795      // If the index 'i' argument was loop invariant, we fetch its value from
3796      // the `condOpInitValMap` map.
3797      if (condOpInitValMap.count(i) != 0) {
3798        afterBlockArg = condOpInitValMap[i];
3799        result = afterBlockArg;
3800      } else {
3801        afterBlockArg = newAfterBlock.getArgument(j);
3802 result = newWhile.getResult(j);
3803 j++;
3804 }
3805 newAfterBlockArgs[i] = afterBlockArg;
3806 newWhileResults[i] = result;
3807 }
3808
3809    rewriter.mergeBlocks(&afterBlock, &newAfterBlock, newAfterBlockArgs);
3810 rewriter.inlineRegionBefore(op.getBefore(), newWhile.getBefore(),
3811 newWhile.getBefore().begin());
3812
3813 rewriter.replaceOp(op, newWhileResults);
3814 return success();
3815 }
3816};
3817
3818/// Remove WhileOp results that are also unused in the 'after' block.
3819///
3820/// %0:2 = scf.while () : () -> (i32, i64) {
3821/// %condition = "test.condition"() : () -> i1
3822/// %v1 = "test.get_some_value"() : () -> i32
3823/// %v2 = "test.get_some_value"() : () -> i64
3824/// scf.condition(%condition) %v1, %v2 : i32, i64
3825/// } do {
3826/// ^bb0(%arg0: i32, %arg1: i64):
3827/// "test.use"(%arg0) : (i32) -> ()
3828/// scf.yield
3829/// }
3830/// return %0#0 : i32
3831///
3832/// becomes
3833/// %0 = scf.while () : () -> (i32) {
3834/// %condition = "test.condition"() : () -> i1
3835/// %v1 = "test.get_some_value"() : () -> i32
3836/// %v2 = "test.get_some_value"() : () -> i64
3837/// scf.condition(%condition) %v1 : i32
3838/// } do {
3839/// ^bb0(%arg0: i32):
3840/// "test.use"(%arg0) : (i32) -> ()
3841/// scf.yield
3842/// }
3843/// return %0 : i32
3844struct WhileUnusedResult : public OpRewritePattern<WhileOp> {
3845 using OpRewritePattern<WhileOp>::OpRewritePattern;
3846
3847 LogicalResult matchAndRewrite(WhileOp op,
3848 PatternRewriter &rewriter) const override {
3849 auto term = op.getConditionOp();
3850 auto afterArgs = op.getAfterArguments();
3851 auto termArgs = term.getArgs();
3852
3853 // Collect results mapping, new terminator args and new result types.
3854 SmallVector<unsigned> newResultsIndices;
3855 SmallVector<Type> newResultTypes;
3856 SmallVector<Value> newTermArgs;
3857 SmallVector<Location> newArgLocs;
3858 bool needUpdate = false;
3859 for (const auto &it :
3860 llvm::enumerate(llvm::zip(op.getResults(), afterArgs, termArgs))) {
3861 auto i = static_cast<unsigned>(it.index());
3862 Value result = std::get<0>(it.value());
3863 Value afterArg = std::get<1>(it.value());
3864 Value termArg = std::get<2>(it.value());
3865 if (result.use_empty() && afterArg.use_empty()) {
3866 needUpdate = true;
3867 } else {
3868 newResultsIndices.emplace_back(i);
3869 newTermArgs.emplace_back(termArg);
3870 newResultTypes.emplace_back(result.getType());
3871 newArgLocs.emplace_back(result.getLoc());
3872 }
3873 }
3874
3875 if (!needUpdate)
3876 return failure();
3877
3878 {
3879 OpBuilder::InsertionGuard g(rewriter);
3880 rewriter.setInsertionPoint(term);
3881 rewriter.replaceOpWithNewOp<ConditionOp>(term, term.getCondition(),
3882 newTermArgs);
3883 }
3884
3885 auto newWhile =
3886 rewriter.create<WhileOp>(op.getLoc(), newResultTypes, op.getInits());
3887
3888 Block &newAfterBlock = *rewriter.createBlock(
3889 &newWhile.getAfter(), /*insertPt*/ {}, newResultTypes, newArgLocs);
3890
3891 // Build new results list and new after block args (unused entries will be
3892 // null).
3893 SmallVector<Value> newResults(op.getNumResults());
3894 SmallVector<Value> newAfterBlockArgs(op.getNumResults());
3895    for (const auto &it : llvm::enumerate(newResultsIndices)) {
3896      newResults[it.value()] = newWhile.getResult(it.index());
3897      newAfterBlockArgs[it.value()] = newAfterBlock.getArgument(it.index());
3898 }
3899
3900 rewriter.inlineRegionBefore(op.getBefore(), newWhile.getBefore(),
3901 newWhile.getBefore().begin());
3902
3903 Block &afterBlock = *op.getAfterBody();
3904    rewriter.mergeBlocks(&afterBlock, &newAfterBlock, newAfterBlockArgs);
3905
3906 rewriter.replaceOp(op, newResults);
3907 return success();
3908 }
3909};
3910
3911/// Replace operations equivalent to the condition in the do block with true,
3912/// since otherwise the block would not be evaluated.
3913///
3914/// scf.while (..) : (i32, ...) -> ... {
3915/// %z = ... : i32
3916/// %condition = cmpi pred %z, %a
3917/// scf.condition(%condition) %z : i32, ...
3918/// } do {
3919/// ^bb0(%arg0: i32, ...):
3920/// %condition2 = cmpi pred %arg0, %a
3921/// use(%condition2)
3922/// ...
3923///
3924/// becomes
3925/// scf.while (..) : (i32, ...) -> ... {
3926/// %z = ... : i32
3927/// %condition = cmpi pred %z, %a
3928/// scf.condition(%condition) %z : i32, ...
3929/// } do {
3930/// ^bb0(%arg0: i32, ...):
3931/// use(%true)
3932/// ...
3933struct WhileCmpCond : public OpRewritePattern<scf::WhileOp> {
3934 using OpRewritePattern<scf::WhileOp>::OpRewritePattern;
3935
3936 LogicalResult matchAndRewrite(scf::WhileOp op,
3937 PatternRewriter &rewriter) const override {
3938 using namespace scf;
3939 auto cond = op.getConditionOp();
3940 auto cmp = cond.getCondition().getDefiningOp<arith::CmpIOp>();
3941 if (!cmp)
3942 return failure();
3943 bool changed = false;
3944 for (auto tup : llvm::zip(cond.getArgs(), op.getAfterArguments())) {
3945 for (size_t opIdx = 0; opIdx < 2; opIdx++) {
3946 if (std::get<0>(tup) != cmp.getOperand(opIdx))
3947 continue;
3948 for (OpOperand &u :
3949 llvm::make_early_inc_range(std::get<1>(tup).getUses())) {
3950 auto cmp2 = dyn_cast<arith::CmpIOp>(u.getOwner());
3951 if (!cmp2)
3952 continue;
3953          // For a binary operator, 1 - opIdx selects the other operand.
3954 if (cmp2.getOperand(1 - opIdx) != cmp.getOperand(1 - opIdx))
3955 continue;
3956 bool samePredicate;
3957 if (cmp2.getPredicate() == cmp.getPredicate())
3958 samePredicate = true;
3959 else if (cmp2.getPredicate() ==
3960 arith::invertPredicate(cmp.getPredicate()))
3961 samePredicate = false;
3962 else
3963 continue;
3964
3965 rewriter.replaceOpWithNewOp<arith::ConstantIntOp>(cmp2, samePredicate,
3966 1);
3967 changed = true;
3968 }
3969 }
3970 }
3971    return success(changed);
3972 }
3973};
3974
3975/// Remove unused init/yield args.
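/// Illustrative effect (placeholder IR): if %arg1 below has no uses in the
/// `before` region, it is dropped together with its init and yield values:
///
///   scf.while (%arg0 = %a, %arg1 = %b) : (i32, i64) -> i32 { ... }
///   do { ... scf.yield %v0, %v1 : i32, i64 }
///
/// becomes
///
///   scf.while (%arg0 = %a) : (i32) -> i32 { ... }
///   do { ... scf.yield %v0 : i32 }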
3976struct WhileRemoveUnusedArgs : public OpRewritePattern<WhileOp> {
3977 using OpRewritePattern<WhileOp>::OpRewritePattern;
3978
3979 LogicalResult matchAndRewrite(WhileOp op,
3980 PatternRewriter &rewriter) const override {
3981
3982 if (!llvm::any_of(op.getBeforeArguments(),
3983 [](Value arg) { return arg.use_empty(); }))
3984 return rewriter.notifyMatchFailure(op, "No args to remove");
3985
3986 YieldOp yield = op.getYieldOp();
3987
3988 // Collect results mapping, new terminator args and new result types.
3989 SmallVector<Value> newYields;
3990 SmallVector<Value> newInits;
3991 llvm::BitVector argsToErase;
3992
3993 size_t argsCount = op.getBeforeArguments().size();
3994    newYields.reserve(argsCount);
3995    newInits.reserve(argsCount);
3996    argsToErase.reserve(argsCount);
3997 for (auto &&[beforeArg, yieldValue, initValue] : llvm::zip(
3998 op.getBeforeArguments(), yield.getOperands(), op.getInits())) {
3999 if (beforeArg.use_empty()) {
4000 argsToErase.push_back(true);
4001 } else {
4002 argsToErase.push_back(false);
4003 newYields.emplace_back(yieldValue);
4004 newInits.emplace_back(initValue);
4005 }
4006 }
4007
4008 Block &beforeBlock = *op.getBeforeBody();
4009 Block &afterBlock = *op.getAfterBody();
4010
4011    beforeBlock.eraseArguments(argsToErase);
4012
4013 Location loc = op.getLoc();
4014 auto newWhileOp =
4015 rewriter.create<WhileOp>(loc, op.getResultTypes(), newInits,
4016 /*beforeBody*/ nullptr, /*afterBody*/ nullptr);
4017 Block &newBeforeBlock = *newWhileOp.getBeforeBody();
4018 Block &newAfterBlock = *newWhileOp.getAfterBody();
4019
4020 OpBuilder::InsertionGuard g(rewriter);
4021 rewriter.setInsertionPoint(yield);
4022 rewriter.replaceOpWithNewOp<YieldOp>(yield, newYields);
4023
4024    rewriter.mergeBlocks(&beforeBlock, &newBeforeBlock,
4025                         newBeforeBlock.getArguments());
4026    rewriter.mergeBlocks(&afterBlock, &newAfterBlock,
4027                         newAfterBlock.getArguments());
4028
4029 rewriter.replaceOp(op, newWhileOp.getResults());
4030 return success();
4031 }
4032};
4033
4034/// Remove duplicated ConditionOp args.
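/// Illustrative effect (placeholder IR): condition args that repeat the same
/// value collapse into a single result and `after` block argument, e.g.
///
///   scf.condition(%cond) %v, %v : i32, i32
///
/// becomes
///
///   scf.condition(%cond) %v : i32
///
/// with uses of the removed result and `after` block argument redirected to
/// the remaining ones.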
4035struct WhileRemoveDuplicatedResults : public OpRewritePattern<WhileOp> {
4036 using OpRewritePattern::OpRewritePattern;
4037
4038 LogicalResult matchAndRewrite(WhileOp op,
4039 PatternRewriter &rewriter) const override {
4040 ConditionOp condOp = op.getConditionOp();
4041 ValueRange condOpArgs = condOp.getArgs();
4042
4043 llvm::SmallPtrSet<Value, 8> argsSet(llvm::from_range, condOpArgs);
4044
4045 if (argsSet.size() == condOpArgs.size())
4046 return rewriter.notifyMatchFailure(op, "No results to remove");
4047
4048 llvm::SmallDenseMap<Value, unsigned> argsMap;
4049 SmallVector<Value> newArgs;
4050    argsMap.reserve(condOpArgs.size());
4051    newArgs.reserve(condOpArgs.size());
4052 for (Value arg : condOpArgs) {
4053 if (!argsMap.count(arg)) {
4054 auto pos = static_cast<unsigned>(argsMap.size());
4055 argsMap.insert({arg, pos});
4056 newArgs.emplace_back(arg);
4057 }
4058 }
4059
4060 ValueRange argsRange(newArgs);
4061
4062 Location loc = op.getLoc();
4063 auto newWhileOp = rewriter.create<scf::WhileOp>(
4064 loc, argsRange.getTypes(), op.getInits(), /*beforeBody*/ nullptr,
4065 /*afterBody*/ nullptr);
4066 Block &newBeforeBlock = *newWhileOp.getBeforeBody();
4067 Block &newAfterBlock = *newWhileOp.getAfterBody();
4068
4069 SmallVector<Value> afterArgsMapping;
4070 SmallVector<Value> resultsMapping;
4071 for (auto &&[i, arg] : llvm::enumerate(condOpArgs)) {
4072 auto it = argsMap.find(arg);
4073 assert(it != argsMap.end());
4074 auto pos = it->second;
4075 afterArgsMapping.emplace_back(newAfterBlock.getArgument(pos));
4076 resultsMapping.emplace_back(newWhileOp->getResult(pos));
4077 }
4078
4079 OpBuilder::InsertionGuard g(rewriter);
4080 rewriter.setInsertionPoint(condOp);
4081 rewriter.replaceOpWithNewOp<ConditionOp>(condOp, condOp.getCondition(),
4082 argsRange);
4083
4084 Block &beforeBlock = *op.getBeforeBody();
4085 Block &afterBlock = *op.getAfterBody();
4086
4087    rewriter.mergeBlocks(&beforeBlock, &newBeforeBlock,
4088                         newBeforeBlock.getArguments());
4089    rewriter.mergeBlocks(&afterBlock, &newAfterBlock, afterArgsMapping);
4090 rewriter.replaceOp(op, resultsMapping);
4091 return success();
4092 }
4093};
4094
4095/// If both ranges contain the same values, return mapping indices from args2
4096/// to args1. Otherwise return std::nullopt.
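/// For example (placeholder values), with args1 = [a, b, c] and
/// args2 = [c, a, b] the result is [2, 0, 1], i.e. args2[j] == args1[ret[j]].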
4097static std::optional<SmallVector<unsigned>> getArgsMapping(ValueRange args1,
4098 ValueRange args2) {
4099 if (args1.size() != args2.size())
4100 return std::nullopt;
4101
4102 SmallVector<unsigned> ret(args1.size());
4103  for (auto &&[i, arg1] : llvm::enumerate(args1)) {
4104    auto it = llvm::find(args2, arg1);
4105 if (it == args2.end())
4106 return std::nullopt;
4107
4108    ret[std::distance(args2.begin(), it)] = static_cast<unsigned>(i);
4109 }
4110
4111 return ret;
4112}
4113
4114static bool hasDuplicates(ValueRange args) {
4115 llvm::SmallDenseSet<Value> set;
4116 for (Value arg : args) {
4117    if (!set.insert(arg).second)
4118 return true;
4119 }
4120 return false;
4121}
4122
4123/// If `before` block args are directly forwarded to `scf.condition`, rearrange
4124/// `scf.condition` args into same order as block args. Update `after` block
4125/// args and op result values accordingly.
4126/// Needed to simplify `scf.while` -> `scf.for` uplifting.
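/// Illustrative effect (placeholder IR):
///
///   scf.while (%arg0 = %a, %arg1 = %b) : ... {
///     ...
///     scf.condition(%cond) %arg1, %arg0 : ...
///   } do { ... }
///
/// becomes
///
///   scf.while (%arg0 = %a, %arg1 = %b) : ... {
///     ...
///     scf.condition(%cond) %arg0, %arg1 : ...
///   } do { ... }
///
/// with the `after` block arguments and the op results permuted accordingly.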
4127struct WhileOpAlignBeforeArgs : public OpRewritePattern<WhileOp> {
4128 using OpRewritePattern::OpRewritePattern;
4129
4130 LogicalResult matchAndRewrite(WhileOp loop,
4131 PatternRewriter &rewriter) const override {
4132 auto oldBefore = loop.getBeforeBody();
4133 ConditionOp oldTerm = loop.getConditionOp();
4134 ValueRange beforeArgs = oldBefore->getArguments();
4135 ValueRange termArgs = oldTerm.getArgs();
4136 if (beforeArgs == termArgs)
4137 return failure();
4138
4139    if (hasDuplicates(termArgs))
4140 return failure();
4141
4142    auto mapping = getArgsMapping(beforeArgs, termArgs);
4143 if (!mapping)
4144 return failure();
4145
4146 {
4147 OpBuilder::InsertionGuard g(rewriter);
4148 rewriter.setInsertionPoint(oldTerm);
4149 rewriter.replaceOpWithNewOp<ConditionOp>(oldTerm, oldTerm.getCondition(),
4150 beforeArgs);
4151 }
4152
4153 auto oldAfter = loop.getAfterBody();
4154
4155 SmallVector<Type> newResultTypes(beforeArgs.size());
4156 for (auto &&[i, j] : llvm::enumerate(*mapping))
4157 newResultTypes[j] = loop.getResult(i).getType();
4158
4159 auto newLoop = rewriter.create<WhileOp>(
4160 loop.getLoc(), newResultTypes, loop.getInits(),
4161 /*beforeBuilder=*/nullptr, /*afterBuilder=*/nullptr);
4162 auto newBefore = newLoop.getBeforeBody();
4163 auto newAfter = newLoop.getAfterBody();
4164
4165 SmallVector<Value> newResults(beforeArgs.size());
4166 SmallVector<Value> newAfterArgs(beforeArgs.size());
4167 for (auto &&[i, j] : llvm::enumerate(*mapping)) {
4168 newResults[i] = newLoop.getResult(j);
4169 newAfterArgs[i] = newAfter->getArgument(j);
4170 }
4171
4172 rewriter.inlineBlockBefore(oldBefore, newBefore, newBefore->begin(),
4173 newBefore->getArguments());
4174 rewriter.inlineBlockBefore(oldAfter, newAfter, newAfter->begin(),
4175 newAfterArgs);
4176
4177 rewriter.replaceOp(loop, newResults);
4178 return success();
4179 }
4180};
4181} // namespace
4182
4183void WhileOp::getCanonicalizationPatterns(RewritePatternSet &results,
4184 MLIRContext *context) {
4185 results.add<RemoveLoopInvariantArgsFromBeforeBlock,
4186 RemoveLoopInvariantValueYielded, WhileConditionTruth,
4187 WhileCmpCond, WhileUnusedResult, WhileRemoveDuplicatedResults,
4188 WhileRemoveUnusedArgs, WhileOpAlignBeforeArgs>(context);
4189}
4190
4191//===----------------------------------------------------------------------===//
4192// IndexSwitchOp
4193//===----------------------------------------------------------------------===//
4194
4195/// Parse the case regions and values.
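/// For reference, the syntax handled here looks roughly like this
/// (illustrative example with placeholder SSA names):
///
///   %r = scf.index_switch %idx -> i32
///   case 2 {
///     scf.yield %a : i32
///   }
///   case 5 {
///     scf.yield %b : i32
///   }
///   default {
///     scf.yield %c : i32
///   }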
4196static ParseResult
4197parseSwitchCases(OpAsmParser &p, DenseI64ArrayAttr &cases,
4198 SmallVectorImpl<std::unique_ptr<Region>> &caseRegions) {
4199 SmallVector<int64_t> caseValues;
4200  while (succeeded(p.parseOptionalKeyword("case"))) {
4201    int64_t value;
4202    Region &region = *caseRegions.emplace_back(std::make_unique<Region>());
4203    if (p.parseInteger(value) || p.parseRegion(region, /*arguments=*/{}))
4204      return failure();
4205    caseValues.push_back(value);
4206 }
4207 cases = p.getBuilder().getDenseI64ArrayAttr(caseValues);
4208 return success();
4209}
4210
4211/// Print the case regions and values.
4212static void printSwitchCases(OpAsmPrinter &p, Operation *op,
4213 DenseI64ArrayAttr cases, RegionRange caseRegions) {
4214 for (auto [value, region] : llvm::zip(cases.asArrayRef(), caseRegions)) {
4215 p.printNewline();
4216 p << "case " << value << ' ';
4217 p.printRegion(*region, /*printEntryBlockArgs=*/false);
4218 }
4219}
4220
4221LogicalResult scf::IndexSwitchOp::verify() {
4222 if (getCases().size() != getCaseRegions().size()) {
4223 return emitOpError("has ")
4224 << getCaseRegions().size() << " case regions but "
4225 << getCases().size() << " case values";
4226 }
4227
4228 DenseSet<int64_t> valueSet;
4229 for (int64_t value : getCases())
4230 if (!valueSet.insert(value).second)
4231 return emitOpError("has duplicate case value: ") << value;
4232 auto verifyRegion = [&](Region &region, const Twine &name) -> LogicalResult {
4233 auto yield = dyn_cast<YieldOp>(region.front().back());
4234 if (!yield)
4235 return emitOpError("expected region to end with scf.yield, but got ")
4236 << region.front().back().getName();
4237
4238 if (yield.getNumOperands() != getNumResults()) {
4239 return (emitOpError("expected each region to return ")
4240 << getNumResults() << " values, but " << name << " returns "
4241 << yield.getNumOperands())
4242 .attachNote(yield.getLoc())
4243 << "see yield operation here";
4244 }
4245 for (auto [idx, result, operand] :
4246 llvm::zip(llvm::seq<unsigned>(0, getNumResults()), getResultTypes(),
4247 yield.getOperandTypes())) {
4248 if (result == operand)
4249 continue;
4250 return (emitOpError("expected result #")
4251 << idx << " of each region to be " << result)
4252 .attachNote(yield.getLoc())
4253 << name << " returns " << operand << " here";
4254 }
4255 return success();
4256 };
4257
4258 if (failed(verifyRegion(getDefaultRegion(), "default region")))
4259 return failure();
4260 for (auto [idx, caseRegion] : llvm::enumerate(getCaseRegions()))
4261 if (failed(verifyRegion(caseRegion, "case region #" + Twine(idx))))
4262 return failure();
4263
4264 return success();
4265}
4266
4267unsigned scf::IndexSwitchOp::getNumCases() { return getCases().size(); }
4268
4269Block &scf::IndexSwitchOp::getDefaultBlock() {
4270 return getDefaultRegion().front();
4271}
4272
4273Block &scf::IndexSwitchOp::getCaseBlock(unsigned idx) {
4274 assert(idx < getNumCases() && "case index out-of-bounds");
4275 return getCaseRegions()[idx].front();
4276}
4277
4278void IndexSwitchOp::getSuccessorRegions(
4279 RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &successors) {
4280 // All regions branch back to the parent op.
4281 if (!point.isParent()) {
4282 successors.emplace_back(getResults());
4283 return;
4284 }
4285
4286 llvm::append_range(successors, getRegions());
4287}
4288
4289void IndexSwitchOp::getEntrySuccessorRegions(
4290 ArrayRef<Attribute> operands,
4291 SmallVectorImpl<RegionSuccessor> &successors) {
4292 FoldAdaptor adaptor(operands, *this);
4293
4294 // If a constant was not provided, all regions are possible successors.
4295 auto arg = dyn_cast_or_null<IntegerAttr>(adaptor.getArg());
4296 if (!arg) {
4297 llvm::append_range(successors, getRegions());
4298 return;
4299 }
4300
4301 // Otherwise, try to find a case with a matching value. If not, the
4302 // default region is the only successor.
4303 for (auto [caseValue, caseRegion] : llvm::zip(getCases(), getCaseRegions())) {
4304 if (caseValue == arg.getInt()) {
4305 successors.emplace_back(&caseRegion);
4306 return;
4307 }
4308 }
4309 successors.emplace_back(&getDefaultRegion());
4310}
4311
4312void IndexSwitchOp::getRegionInvocationBounds(
4313 ArrayRef<Attribute> operands, SmallVectorImpl<InvocationBounds> &bounds) {
4314 auto operandValue = llvm::dyn_cast_or_null<IntegerAttr>(operands.front());
4315 if (!operandValue) {
4316 // All regions are invoked at most once.
4317 bounds.append(getNumRegions(), InvocationBounds(/*lb=*/0, /*ub=*/1));
4318 return;
4319 }
4320
4321 unsigned liveIndex = getNumRegions() - 1;
4322 const auto *it = llvm::find(getCases(), operandValue.getInt());
4323 if (it != getCases().end())
4324 liveIndex = std::distance(getCases().begin(), it);
4325 for (unsigned i = 0, e = getNumRegions(); i < e; ++i)
4326 bounds.emplace_back(/*lb=*/0, /*ub=*/i == liveIndex);
4327}
4328
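/// Fold an scf.index_switch whose argument is a known constant by inlining the
/// selected region. Illustrative effect (placeholder IR):
///
///   %c2 = arith.constant 2 : index
///   %r = scf.index_switch %c2 -> i32
///   case 2 { scf.yield %a : i32 }
///   default { scf.yield %b : i32 }
///
/// is replaced by the body of `case 2`, so %r becomes %a.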
4329struct FoldConstantCase : OpRewritePattern<scf::IndexSwitchOp> {
4330 using OpRewritePattern<scf::IndexSwitchOp>::OpRewritePattern;
4331
4332 LogicalResult matchAndRewrite(scf::IndexSwitchOp op,
4333 PatternRewriter &rewriter) const override {
4334    // If `op.getArg()` is a constant, select the region that matches the
4335    // constant value. Use the default region if no match is found.
4336 std::optional<int64_t> maybeCst = getConstantIntValue(op.getArg());
4337 if (!maybeCst.has_value())
4338 return failure();
4339 int64_t cst = *maybeCst;
4340 int64_t caseIdx, e = op.getNumCases();
4341 for (caseIdx = 0; caseIdx < e; ++caseIdx) {
4342 if (cst == op.getCases()[caseIdx])
4343 break;
4344 }
4345
4346 Region &r = (caseIdx < op.getNumCases()) ? op.getCaseRegions()[caseIdx]
4347 : op.getDefaultRegion();
4348 Block &source = r.front();
4349 Operation *terminator = source.getTerminator();
4350 SmallVector<Value> results = terminator->getOperands();
4351
4352 rewriter.inlineBlockBefore(&source, op);
4353    rewriter.eraseOp(terminator);
4354    // Replace the operation with a potentially empty list of results.
4355    // The fold mechanism doesn't support an empty result list.
4356 rewriter.replaceOp(op, results);
4357
4358 return success();
4359 }
4360};
4361
4362void IndexSwitchOp::getCanonicalizationPatterns(RewritePatternSet &results,
4363 MLIRContext *context) {
4364 results.add<FoldConstantCase>(context);
4365}
4366
4367//===----------------------------------------------------------------------===//
4368// TableGen'd op method definitions
4369//===----------------------------------------------------------------------===//
4370
4371#define GET_OP_CLASSES
4372#include "mlir/Dialect/SCF/IR/SCFOps.cpp.inc"
4373
