1 | //===- WrapInZeroTripCheck.cpp - Loop transforms to add zero-trip-check ---===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/Dialect/SCF/IR/SCF.h" |
10 | #include "mlir/Dialect/SCF/Transforms/Transforms.h" |
11 | #include "mlir/IR/IRMapping.h" |
12 | #include "mlir/IR/PatternMatch.h" |
13 | |
14 | using namespace mlir; |
15 | |
16 | /// Create zero-trip-check around a `while` op and return the new loop op in the |
17 | /// check. The while loop is rotated to avoid evaluating the condition twice. |
18 | /// |
19 | /// Given an example below: |
20 | /// |
21 | /// scf.while (%arg0 = %init) : (i32) -> i64 { |
22 | /// %val = .., %arg0 : i64 |
23 | /// %cond = arith.cmpi .., %arg0 : i32 |
24 | /// scf.condition(%cond) %val : i64 |
25 | /// } do { |
26 | /// ^bb0(%arg1: i64): |
27 | /// %next = .., %arg1 : i32 |
28 | /// scf.yield %next : i32 |
29 | /// } |
30 | /// |
31 | /// First clone before block to the front of the loop: |
32 | /// |
33 | /// %pre_val = .., %init : i64 |
34 | /// %pre_cond = arith.cmpi .., %init : i32 |
35 | /// scf.while (%arg0 = %init) : (i32) -> i64 { |
36 | /// %val = .., %arg0 : i64 |
37 | /// %cond = arith.cmpi .., %arg0 : i32 |
38 | /// scf.condition(%cond) %val : i64 |
39 | /// } do { |
40 | /// ^bb0(%arg1: i64): |
41 | /// %next = .., %arg1 : i32 |
42 | /// scf.yield %next : i32 |
43 | /// } |
44 | /// |
45 | /// Create `if` op with the condition, rotate and move the loop into the else |
46 | /// branch: |
47 | /// |
48 | /// %pre_val = .., %init : i64 |
49 | /// %pre_cond = arith.cmpi .., %init : i32 |
50 | /// scf.if %pre_cond -> i64 { |
51 | /// %res = scf.while (%arg1 = %va0) : (i64) -> i64 { |
52 | /// // Original after block |
53 | /// %next = .., %arg1 : i32 |
54 | /// // Original before block |
55 | /// %val = .., %next : i64 |
56 | /// %cond = arith.cmpi .., %next : i32 |
57 | /// scf.condition(%cond) %val : i64 |
58 | /// } do { |
59 | /// ^bb0(%arg2: i64): |
60 | /// %scf.yield %arg2 : i32 |
61 | /// } |
62 | /// scf.yield %res : i64 |
63 | /// } else { |
64 | /// scf.yield %pre_val : i64 |
65 | /// } |
66 | FailureOr<scf::WhileOp> mlir::scf::wrapWhileLoopInZeroTripCheck( |
67 | scf::WhileOp whileOp, RewriterBase &rewriter, bool forceCreateCheck) { |
68 | // If the loop is in do-while form (after block only passes through values), |
69 | // there is no need to create a zero-trip-check as before block is always run. |
70 | if (!forceCreateCheck && isa<scf::YieldOp>(whileOp.getAfterBody()->front())) { |
71 | return whileOp; |
72 | } |
73 | |
74 | OpBuilder::InsertionGuard insertion_guard(rewriter); |
75 | |
76 | IRMapping mapper; |
77 | Block *beforeBlock = whileOp.getBeforeBody(); |
78 | // Clone before block before the loop for zero-trip-check. |
79 | for (auto [arg, init] : |
80 | llvm::zip_equal(beforeBlock->getArguments(), whileOp.getInits())) { |
81 | mapper.map(arg, init); |
82 | } |
83 | rewriter.setInsertionPoint(whileOp); |
84 | for (auto &op : *beforeBlock) { |
85 | if (isa<scf::ConditionOp>(op)) { |
86 | break; |
87 | } |
88 | // Safe to clone everything as in a single block all defs have been cloned |
89 | // and added to mapper in order. |
90 | rewriter.insert(op.clone(mapper)); |
91 | } |
92 | |
93 | scf::ConditionOp condOp = whileOp.getConditionOp(); |
94 | Value clonedCondition = mapper.lookupOrDefault(condOp.getCondition()); |
95 | SmallVector<Value> clonedCondArgs = llvm::map_to_vector( |
96 | condOp.getArgs(), [&](Value arg) { return mapper.lookupOrDefault(from: arg); }); |
97 | |
98 | // Create rotated while loop. |
99 | auto newLoopOp = rewriter.create<scf::WhileOp>( |
100 | whileOp.getLoc(), whileOp.getResultTypes(), clonedCondArgs, |
101 | [&](OpBuilder &builder, Location loc, ValueRange args) { |
102 | // Rotate and move the loop body into before block. |
103 | auto newBlock = builder.getBlock(); |
104 | rewriter.mergeBlocks(source: whileOp.getAfterBody(), dest: newBlock, argValues: args); |
105 | auto yieldOp = cast<scf::YieldOp>(newBlock->getTerminator()); |
106 | rewriter.mergeBlocks(source: whileOp.getBeforeBody(), dest: newBlock, |
107 | argValues: yieldOp.getResults()); |
108 | rewriter.eraseOp(op: yieldOp); |
109 | }, |
110 | [&](OpBuilder &builder, Location loc, ValueRange args) { |
111 | // Pass through values. |
112 | builder.create<scf::YieldOp>(loc, args); |
113 | }); |
114 | |
115 | // Create zero-trip-check and move the while loop in. |
116 | auto ifOp = rewriter.create<scf::IfOp>( |
117 | whileOp.getLoc(), clonedCondition, |
118 | [&](OpBuilder &builder, Location loc) { |
119 | // Then runs the while loop. |
120 | rewriter.moveOpBefore(newLoopOp, builder.getInsertionBlock(), |
121 | builder.getInsertionPoint()); |
122 | builder.create<scf::YieldOp>(loc, newLoopOp.getResults()); |
123 | }, |
124 | [&](OpBuilder &builder, Location loc) { |
125 | // Else returns the results from precondition. |
126 | builder.create<scf::YieldOp>(loc, clonedCondArgs); |
127 | }); |
128 | |
129 | rewriter.replaceOp(whileOp, ifOp); |
130 | |
131 | return newLoopOp; |
132 | } |
133 | |