1 | //===- ConstantArgumentGlobalisation.cpp ----------------------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "flang/Optimizer/Builder/FIRBuilder.h" |
10 | #include "flang/Optimizer/Dialect/FIRDialect.h" |
11 | #include "flang/Optimizer/Dialect/FIROps.h" |
12 | #include "flang/Optimizer/Dialect/FIRType.h" |
13 | #include "flang/Optimizer/Transforms/Passes.h" |
14 | #include "mlir/Dialect/Func/IR/FuncOps.h" |
15 | #include "mlir/IR/Diagnostics.h" |
16 | #include "mlir/IR/Dominance.h" |
17 | #include "mlir/Pass/Pass.h" |
18 | #include "mlir/Transforms/GreedyPatternRewriteDriver.h" |
19 | |
20 | namespace fir { |
21 | #define GEN_PASS_DEF_CONSTANTARGUMENTGLOBALISATIONOPT |
22 | #include "flang/Optimizer/Transforms/Passes.h.inc" |
23 | } // namespace fir |
24 | |
25 | #define DEBUG_TYPE "flang-constant-argument-globalisation-opt" |
26 | |
27 | namespace { |
28 | unsigned uniqueLitId = 1; |
29 | |
30 | class CallOpRewriter : public mlir::OpRewritePattern<fir::CallOp> { |
31 | protected: |
32 | const mlir::DominanceInfo &di; |
33 | |
34 | public: |
35 | using OpRewritePattern::OpRewritePattern; |
36 | |
37 | CallOpRewriter(mlir::MLIRContext *ctx, const mlir::DominanceInfo &_di) |
38 | : OpRewritePattern(ctx), di(_di) {} |
39 | |
40 | llvm::LogicalResult |
41 | matchAndRewrite(fir::CallOp callOp, |
42 | mlir::PatternRewriter &rewriter) const override { |
43 | LLVM_DEBUG(llvm::dbgs() << "Processing call op: " << callOp << "\n" ); |
44 | auto module = callOp->getParentOfType<mlir::ModuleOp>(); |
45 | bool needUpdate = false; |
46 | fir::FirOpBuilder builder(rewriter, module); |
47 | llvm::SmallVector<mlir::Value> newOperands; |
48 | llvm::SmallVector<std::pair<mlir::Operation *, mlir::Operation *>> allocas; |
49 | for (const mlir::Value &a : callOp.getArgs()) { |
50 | auto alloca = mlir::dyn_cast_or_null<fir::AllocaOp>(a.getDefiningOp()); |
51 | // We can convert arguments that are alloca, and that has |
52 | // the value by reference attribute. All else is just added |
53 | // to the argument list. |
54 | if (!alloca || !alloca->hasAttr(fir::getAdaptToByRefAttrName())) { |
55 | newOperands.push_back(a); |
56 | continue; |
57 | } |
58 | |
59 | mlir::Type varTy = alloca.getInType(); |
60 | assert(!fir::hasDynamicSize(varTy) && |
61 | "only expect statically sized scalars to be by value" ); |
62 | |
63 | // Find immediate store with const argument |
64 | mlir::Operation *store = nullptr; |
65 | for (mlir::Operation *s : alloca->getUsers()) { |
66 | if (mlir::isa<fir::StoreOp>(s) && di.dominates(s, callOp)) { |
67 | // We can only deal with ONE store - if already found one, |
68 | // set to nullptr and exit the loop. |
69 | if (store) { |
70 | store = nullptr; |
71 | break; |
72 | } |
73 | store = s; |
74 | } |
75 | } |
76 | |
77 | // If we didn't find any store, or multiple stores, add argument as is |
78 | // and move on. |
79 | if (!store) { |
80 | newOperands.push_back(a); |
81 | continue; |
82 | } |
83 | |
84 | LLVM_DEBUG(llvm::dbgs() << " found store " << *store << "\n" ); |
85 | |
86 | mlir::Operation *definingOp = store->getOperand(0).getDefiningOp(); |
87 | // If not a constant, add to operands and move on. |
88 | if (!mlir::isa<mlir::arith::ConstantOp>(definingOp)) { |
89 | // Unable to remove alloca arg |
90 | newOperands.push_back(a); |
91 | continue; |
92 | } |
93 | |
94 | LLVM_DEBUG(llvm::dbgs() << " found define " << *definingOp << "\n" ); |
95 | |
96 | std::string globalName = |
97 | "_global_const_." + std::to_string(uniqueLitId++); |
98 | assert(!builder.getNamedGlobal(globalName) && |
99 | "We should have a unique name here" ); |
100 | |
101 | if (llvm::none_of(allocas, |
102 | [alloca](auto x) { return x.first == alloca; })) { |
103 | allocas.push_back(std::make_pair(alloca, store)); |
104 | } |
105 | |
106 | auto loc = callOp.getLoc(); |
107 | fir::GlobalOp global = builder.createGlobalConstant( |
108 | loc, varTy, globalName, |
109 | [&](fir::FirOpBuilder &builder) { |
110 | mlir::Operation *cln = definingOp->clone(); |
111 | builder.insert(cln); |
112 | mlir::Value val = |
113 | builder.createConvert(loc, varTy, cln->getResult(0)); |
114 | builder.create<fir::HasValueOp>(loc, val); |
115 | }, |
116 | builder.createInternalLinkage()); |
117 | mlir::Value addr = builder.create<fir::AddrOfOp>(loc, global.resultType(), |
118 | global.getSymbol()); |
119 | newOperands.push_back(addr); |
120 | needUpdate = true; |
121 | } |
122 | |
123 | if (needUpdate) { |
124 | auto loc = callOp.getLoc(); |
125 | llvm::SmallVector<mlir::Type> newResultTypes; |
126 | newResultTypes.append(callOp.getResultTypes().begin(), |
127 | callOp.getResultTypes().end()); |
128 | fir::CallOp newOp = builder.create<fir::CallOp>( |
129 | loc, |
130 | callOp.getCallee().has_value() ? callOp.getCallee().value() |
131 | : mlir::SymbolRefAttr{}, |
132 | newResultTypes, newOperands); |
133 | // Copy all the attributes from the old to new op. |
134 | newOp->setAttrs(callOp->getAttrs()); |
135 | rewriter.replaceOp(callOp, newOp); |
136 | |
137 | for (auto a : allocas) { |
138 | if (a.first->hasOneUse()) { |
139 | // If the alloca is only used for a store and the call operand, the |
140 | // store is no longer required. |
141 | rewriter.eraseOp(a.second); |
142 | rewriter.eraseOp(a.first); |
143 | } |
144 | } |
145 | LLVM_DEBUG(llvm::dbgs() << "global constant for " << callOp << " as " |
146 | << newOp << '\n'); |
147 | return mlir::success(); |
148 | } |
149 | |
150 | // Failure here just means "we couldn't do the conversion", which is |
151 | // perfectly acceptable to the upper layers of this function. |
152 | return mlir::failure(); |
153 | } |
154 | }; |
155 | |
156 | // this pass attempts to convert immediate scalar literals in function calls |
157 | // to global constants to allow transformations such as Dead Argument |
158 | // Elimination |
159 | class ConstantArgumentGlobalisationOpt |
160 | : public fir::impl::ConstantArgumentGlobalisationOptBase< |
161 | ConstantArgumentGlobalisationOpt> { |
162 | public: |
163 | ConstantArgumentGlobalisationOpt() = default; |
164 | |
165 | void runOnOperation() override { |
166 | mlir::ModuleOp mod = getOperation(); |
167 | mlir::DominanceInfo *di = &getAnalysis<mlir::DominanceInfo>(); |
168 | auto *context = &getContext(); |
169 | mlir::RewritePatternSet patterns(context); |
170 | mlir::GreedyRewriteConfig config; |
171 | config.setRegionSimplificationLevel( |
172 | mlir::GreedySimplifyRegionLevel::Disabled); |
173 | config.setStrictness(mlir::GreedyRewriteStrictness::ExistingOps); |
174 | |
175 | patterns.insert<CallOpRewriter>(context, *di); |
176 | if (mlir::failed( |
177 | mlir::applyPatternsGreedily(mod, std::move(patterns), config))) { |
178 | mlir::emitError(mod.getLoc(), |
179 | "error in constant globalisation optimization\n" ); |
180 | signalPassFailure(); |
181 | } |
182 | } |
183 | }; |
184 | } // namespace |
185 | |