//===- WrapInZeroTripCheck.cpp - Loop transforms to add zero-trip-check ---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SCF/Transforms/Transforms.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/PatternMatch.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVectorExtras.h"

using namespace mlir;

16/// Create zero-trip-check around a `while` op and return the new loop op in the
17/// check. The while loop is rotated to avoid evaluating the condition twice.
18///
19/// Given an example below:
20///
21/// scf.while (%arg0 = %init) : (i32) -> i64 {
22/// %val = .., %arg0 : i64
23/// %cond = arith.cmpi .., %arg0 : i32
24/// scf.condition(%cond) %val : i64
25/// } do {
26/// ^bb0(%arg1: i64):
27/// %next = .., %arg1 : i32
28/// scf.yield %next : i32
29/// }
30///
31/// First clone before block to the front of the loop:
32///
33/// %pre_val = .., %init : i64
34/// %pre_cond = arith.cmpi .., %init : i32
35/// scf.while (%arg0 = %init) : (i32) -> i64 {
36/// %val = .., %arg0 : i64
37/// %cond = arith.cmpi .., %arg0 : i32
38/// scf.condition(%cond) %val : i64
39/// } do {
40/// ^bb0(%arg1: i64):
41/// %next = .., %arg1 : i32
42/// scf.yield %next : i32
43/// }
44///
45/// Create `if` op with the condition, rotate and move the loop into the else
46/// branch:
47///
48/// %pre_val = .., %init : i64
49/// %pre_cond = arith.cmpi .., %init : i32
50/// scf.if %pre_cond -> i64 {
51/// %res = scf.while (%arg1 = %va0) : (i64) -> i64 {
52/// // Original after block
53/// %next = .., %arg1 : i32
54/// // Original before block
55/// %val = .., %next : i64
56/// %cond = arith.cmpi .., %next : i32
57/// scf.condition(%cond) %val : i64
58/// } do {
59/// ^bb0(%arg2: i64):
60/// %scf.yield %arg2 : i32
61/// }
62/// scf.yield %res : i64
63/// } else {
64/// scf.yield %pre_val : i64
65/// }
66FailureOr<scf::WhileOp> mlir::scf::wrapWhileLoopInZeroTripCheck(
67 scf::WhileOp whileOp, RewriterBase &rewriter, bool forceCreateCheck) {
68 // If the loop is in do-while form (after block only passes through values),
69 // there is no need to create a zero-trip-check as before block is always run.
70 if (!forceCreateCheck && isa<scf::YieldOp>(whileOp.getAfterBody()->front())) {
71 return whileOp;
72 }
73
74 OpBuilder::InsertionGuard insertion_guard(rewriter);
75
76 IRMapping mapper;
77 Block *beforeBlock = whileOp.getBeforeBody();
78 // Clone before block before the loop for zero-trip-check.
79 for (auto [arg, init] :
80 llvm::zip_equal(beforeBlock->getArguments(), whileOp.getInits())) {
81 mapper.map(arg, init);
82 }
83 rewriter.setInsertionPoint(whileOp);
84 for (auto &op : *beforeBlock) {
85 if (isa<scf::ConditionOp>(op)) {
86 break;
87 }
88 // Safe to clone everything as in a single block all defs have been cloned
89 // and added to mapper in order.
90 rewriter.insert(op.clone(mapper));
91 }
92
93 scf::ConditionOp condOp = whileOp.getConditionOp();
94 Value clonedCondition = mapper.lookupOrDefault(condOp.getCondition());
95 SmallVector<Value> clonedCondArgs = llvm::map_to_vector(
96 condOp.getArgs(), [&](Value arg) { return mapper.lookupOrDefault(from: arg); });
97
98 // Create rotated while loop.
99 auto newLoopOp = rewriter.create<scf::WhileOp>(
100 whileOp.getLoc(), whileOp.getResultTypes(), clonedCondArgs,
101 [&](OpBuilder &builder, Location loc, ValueRange args) {
102 // Rotate and move the loop body into before block.
103 auto newBlock = builder.getBlock();
104 rewriter.mergeBlocks(source: whileOp.getAfterBody(), dest: newBlock, argValues: args);
105 auto yieldOp = cast<scf::YieldOp>(newBlock->getTerminator());
106 rewriter.mergeBlocks(source: whileOp.getBeforeBody(), dest: newBlock,
107 argValues: yieldOp.getResults());
108 rewriter.eraseOp(op: yieldOp);
109 },
110 [&](OpBuilder &builder, Location loc, ValueRange args) {
111 // Pass through values.
112 builder.create<scf::YieldOp>(loc, args);
113 });
114
115 // Create zero-trip-check and move the while loop in.
116 auto ifOp = rewriter.create<scf::IfOp>(
117 whileOp.getLoc(), clonedCondition,
118 [&](OpBuilder &builder, Location loc) {
119 // Then runs the while loop.
120 rewriter.moveOpBefore(newLoopOp, builder.getInsertionBlock(),
121 builder.getInsertionPoint());
122 builder.create<scf::YieldOp>(loc, newLoopOp.getResults());
123 },
124 [&](OpBuilder &builder, Location loc) {
125 // Else returns the results from precondition.
126 builder.create<scf::YieldOp>(loc, clonedCondArgs);
127 });
128
129 rewriter.replaceOp(whileOp, ifOp);
130
131 return newLoopOp;
132}

// Source: mlir/lib/Dialect/SCF/Transforms/WrapInZeroTripCheck.cpp