1 | //===- LoopRangeFolding.cpp - Code to perform loop range folding-----------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This file implements loop range folding. |
10 | // |
11 | //===----------------------------------------------------------------------===// |
12 | |
13 | #include "mlir/Dialect/SCF/Transforms/Passes.h" |
14 | |
15 | #include "mlir/Dialect/Arith/IR/Arith.h" |
16 | #include "mlir/Dialect/SCF/IR/SCF.h" |
17 | #include "mlir/Dialect/SCF/Transforms/Transforms.h" |
18 | #include "mlir/Dialect/SCF/Utils/Utils.h" |
19 | #include "mlir/IR/IRMapping.h" |
20 | |
21 | namespace mlir { |
22 | #define GEN_PASS_DEF_SCFFORLOOPRANGEFOLDING |
23 | #include "mlir/Dialect/SCF/Transforms/Passes.h.inc" |
24 | } // namespace mlir |
25 | |
26 | using namespace mlir; |
27 | using namespace mlir::scf; |
28 | |
29 | namespace { |
30 | struct ForLoopRangeFolding |
31 | : public impl::SCFForLoopRangeFoldingBase<ForLoopRangeFolding> { |
32 | void runOnOperation() override; |
33 | }; |
34 | } // namespace |
35 | |
36 | void ForLoopRangeFolding::runOnOperation() { |
37 | getOperation()->walk([&](ForOp op) { |
38 | Value indVar = op.getInductionVar(); |
39 | |
40 | auto canBeFolded = [&](Value value) { |
41 | return op.isDefinedOutsideOfLoop(value) || value == indVar; |
42 | }; |
43 | |
44 | // Fold until a fixed point is reached |
45 | while (true) { |
46 | |
47 | // If the induction variable is used more than once, we can't fold its |
48 | // arith ops into the loop range |
49 | if (!indVar.hasOneUse()) |
50 | break; |
51 | |
52 | Operation *user = *indVar.getUsers().begin(); |
53 | if (!isa<arith::AddIOp, arith::MulIOp>(user)) |
54 | break; |
55 | |
56 | if (!llvm::all_of(Range: user->getOperands(), P: canBeFolded)) |
57 | break; |
58 | |
59 | OpBuilder b(op); |
60 | IRMapping lbMap; |
61 | lbMap.map(indVar, op.getLowerBound()); |
62 | IRMapping ubMap; |
63 | ubMap.map(indVar, op.getUpperBound()); |
64 | IRMapping stepMap; |
65 | stepMap.map(indVar, op.getStep()); |
66 | |
67 | if (isa<arith::AddIOp>(user)) { |
68 | Operation *lbFold = b.clone(op&: *user, mapper&: lbMap); |
69 | Operation *ubFold = b.clone(op&: *user, mapper&: ubMap); |
70 | |
71 | op.setLowerBound(lbFold->getResult(idx: 0)); |
72 | op.setUpperBound(ubFold->getResult(idx: 0)); |
73 | |
74 | } else if (isa<arith::MulIOp>(user)) { |
75 | Operation *ubFold = b.clone(op&: *user, mapper&: ubMap); |
76 | Operation *stepFold = b.clone(op&: *user, mapper&: stepMap); |
77 | |
78 | op.setUpperBound(ubFold->getResult(idx: 0)); |
79 | op.setStep(stepFold->getResult(idx: 0)); |
80 | } |
81 | |
82 | ValueRange wrapIndvar(indVar); |
83 | user->replaceAllUsesWith(values&: wrapIndvar); |
84 | user->erase(); |
85 | } |
86 | }); |
87 | } |
88 | |
89 | std::unique_ptr<Pass> mlir::createForLoopRangeFoldingPass() { |
90 | return std::make_unique<ForLoopRangeFolding>(); |
91 | } |
92 | |