1//===- LoopRangeFolding.cpp - Code to perform loop range folding-----------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements loop range folding.
10//
11//===----------------------------------------------------------------------===//
12
13#include "mlir/Dialect/SCF/Transforms/Passes.h"
14
15#include "mlir/Dialect/Arith/IR/Arith.h"
16#include "mlir/Dialect/SCF/IR/SCF.h"
17#include "mlir/Dialect/SCF/Transforms/Transforms.h"
18#include "mlir/Dialect/SCF/Utils/Utils.h"
19#include "mlir/IR/IRMapping.h"
20
21namespace mlir {
22#define GEN_PASS_DEF_SCFFORLOOPRANGEFOLDING
23#include "mlir/Dialect/SCF/Transforms/Passes.h.inc"
24} // namespace mlir
25
26using namespace mlir;
27using namespace mlir::scf;
28
29namespace {
30struct ForLoopRangeFolding
31 : public impl::SCFForLoopRangeFoldingBase<ForLoopRangeFolding> {
32 void runOnOperation() override;
33};
34} // namespace
35
36void ForLoopRangeFolding::runOnOperation() {
37 getOperation()->walk([&](ForOp op) {
38 Value indVar = op.getInductionVar();
39
40 auto canBeFolded = [&](Value value) {
41 return op.isDefinedOutsideOfLoop(value) || value == indVar;
42 };
43
44 // Fold until a fixed point is reached
45 while (true) {
46
47 // If the induction variable is used more than once, we can't fold its
48 // arith ops into the loop range
49 if (!indVar.hasOneUse())
50 break;
51
52 Operation *user = *indVar.getUsers().begin();
53 if (!isa<arith::AddIOp, arith::MulIOp>(user))
54 break;
55
56 if (!llvm::all_of(Range: user->getOperands(), P: canBeFolded))
57 break;
58
59 OpBuilder b(op);
60 IRMapping lbMap;
61 lbMap.map(indVar, op.getLowerBound());
62 IRMapping ubMap;
63 ubMap.map(indVar, op.getUpperBound());
64 IRMapping stepMap;
65 stepMap.map(indVar, op.getStep());
66
67 if (isa<arith::AddIOp>(user)) {
68 Operation *lbFold = b.clone(op&: *user, mapper&: lbMap);
69 Operation *ubFold = b.clone(op&: *user, mapper&: ubMap);
70
71 op.setLowerBound(lbFold->getResult(idx: 0));
72 op.setUpperBound(ubFold->getResult(idx: 0));
73
74 } else if (isa<arith::MulIOp>(user)) {
75 Operation *ubFold = b.clone(op&: *user, mapper&: ubMap);
76 Operation *stepFold = b.clone(op&: *user, mapper&: stepMap);
77
78 op.setUpperBound(ubFold->getResult(idx: 0));
79 op.setStep(stepFold->getResult(idx: 0));
80 }
81
82 ValueRange wrapIndvar(indVar);
83 user->replaceAllUsesWith(values&: wrapIndvar);
84 user->erase();
85 }
86 });
87}
88
89std::unique_ptr<Pass> mlir::createForLoopRangeFoldingPass() {
90 return std::make_unique<ForLoopRangeFolding>();
91}
92

source code of mlir/lib/Dialect/SCF/Transforms/LoopRangeFolding.cpp