1//===- ConstantArgumentGlobalisation.cpp ----------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "flang/Optimizer/Builder/FIRBuilder.h"
10#include "flang/Optimizer/Dialect/FIRDialect.h"
11#include "flang/Optimizer/Dialect/FIROps.h"
12#include "flang/Optimizer/Dialect/FIRType.h"
13#include "flang/Optimizer/Transforms/Passes.h"
14#include "mlir/Dialect/Func/IR/FuncOps.h"
15#include "mlir/IR/Diagnostics.h"
16#include "mlir/IR/Dominance.h"
17#include "mlir/Pass/Pass.h"
18#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
19
20namespace fir {
21#define GEN_PASS_DEF_CONSTANTARGUMENTGLOBALISATIONOPT
22#include "flang/Optimizer/Transforms/Passes.h.inc"
23} // namespace fir
24
25#define DEBUG_TYPE "flang-constant-argument-globalisation-opt"
26
27namespace {
// Counter appended to the names of the global constants generated by this
// pass ("_global_const_.<N>"). It is a plain namespace-scope variable, so the
// numbering is shared by every run of the pass within the process and is not
// synchronized.
unsigned uniqueLitId{1};
29
30class CallOpRewriter : public mlir::OpRewritePattern<fir::CallOp> {
31protected:
32 const mlir::DominanceInfo &di;
33
34public:
35 using OpRewritePattern::OpRewritePattern;
36
37 CallOpRewriter(mlir::MLIRContext *ctx, const mlir::DominanceInfo &_di)
38 : OpRewritePattern(ctx), di(_di) {}
39
40 llvm::LogicalResult
41 matchAndRewrite(fir::CallOp callOp,
42 mlir::PatternRewriter &rewriter) const override {
43 LLVM_DEBUG(llvm::dbgs() << "Processing call op: " << callOp << "\n");
44 auto module = callOp->getParentOfType<mlir::ModuleOp>();
45 bool needUpdate = false;
46 fir::FirOpBuilder builder(rewriter, module);
47 llvm::SmallVector<mlir::Value> newOperands;
48 llvm::SmallVector<std::pair<mlir::Operation *, mlir::Operation *>> allocas;
49 for (const mlir::Value &a : callOp.getArgs()) {
50 auto alloca = mlir::dyn_cast_or_null<fir::AllocaOp>(a.getDefiningOp());
51 // We can convert arguments that are alloca, and that has
52 // the value by reference attribute. All else is just added
53 // to the argument list.
54 if (!alloca || !alloca->hasAttr(fir::getAdaptToByRefAttrName())) {
55 newOperands.push_back(a);
56 continue;
57 }
58
59 mlir::Type varTy = alloca.getInType();
60 assert(!fir::hasDynamicSize(varTy) &&
61 "only expect statically sized scalars to be by value");
62
63 // Find immediate store with const argument
64 mlir::Operation *store = nullptr;
65 for (mlir::Operation *s : alloca->getUsers()) {
66 if (mlir::isa<fir::StoreOp>(s) && di.dominates(s, callOp)) {
67 // We can only deal with ONE store - if already found one,
68 // set to nullptr and exit the loop.
69 if (store) {
70 store = nullptr;
71 break;
72 }
73 store = s;
74 }
75 }
76
77 // If we didn't find any store, or multiple stores, add argument as is
78 // and move on.
79 if (!store) {
80 newOperands.push_back(a);
81 continue;
82 }
83
84 LLVM_DEBUG(llvm::dbgs() << " found store " << *store << "\n");
85
86 mlir::Operation *definingOp = store->getOperand(0).getDefiningOp();
87 // If not a constant, add to operands and move on.
88 if (!mlir::isa<mlir::arith::ConstantOp>(definingOp)) {
89 // Unable to remove alloca arg
90 newOperands.push_back(a);
91 continue;
92 }
93
94 LLVM_DEBUG(llvm::dbgs() << " found define " << *definingOp << "\n");
95
96 std::string globalName =
97 "_global_const_." + std::to_string(uniqueLitId++);
98 assert(!builder.getNamedGlobal(globalName) &&
99 "We should have a unique name here");
100
101 if (llvm::none_of(allocas,
102 [alloca](auto x) { return x.first == alloca; })) {
103 allocas.push_back(std::make_pair(alloca, store));
104 }
105
106 auto loc = callOp.getLoc();
107 fir::GlobalOp global = builder.createGlobalConstant(
108 loc, varTy, globalName,
109 [&](fir::FirOpBuilder &builder) {
110 mlir::Operation *cln = definingOp->clone();
111 builder.insert(cln);
112 mlir::Value val =
113 builder.createConvert(loc, varTy, cln->getResult(0));
114 builder.create<fir::HasValueOp>(loc, val);
115 },
116 builder.createInternalLinkage());
117 mlir::Value addr = builder.create<fir::AddrOfOp>(loc, global.resultType(),
118 global.getSymbol());
119 newOperands.push_back(addr);
120 needUpdate = true;
121 }
122
123 if (needUpdate) {
124 auto loc = callOp.getLoc();
125 llvm::SmallVector<mlir::Type> newResultTypes;
126 newResultTypes.append(callOp.getResultTypes().begin(),
127 callOp.getResultTypes().end());
128 fir::CallOp newOp = builder.create<fir::CallOp>(
129 loc,
130 callOp.getCallee().has_value() ? callOp.getCallee().value()
131 : mlir::SymbolRefAttr{},
132 newResultTypes, newOperands);
133 // Copy all the attributes from the old to new op.
134 newOp->setAttrs(callOp->getAttrs());
135 rewriter.replaceOp(callOp, newOp);
136
137 for (auto a : allocas) {
138 if (a.first->hasOneUse()) {
139 // If the alloca is only used for a store and the call operand, the
140 // store is no longer required.
141 rewriter.eraseOp(a.second);
142 rewriter.eraseOp(a.first);
143 }
144 }
145 LLVM_DEBUG(llvm::dbgs() << "global constant for " << callOp << " as "
146 << newOp << '\n');
147 return mlir::success();
148 }
149
150 // Failure here just means "we couldn't do the conversion", which is
151 // perfectly acceptable to the upper layers of this function.
152 return mlir::failure();
153 }
154};
155
156// this pass attempts to convert immediate scalar literals in function calls
157// to global constants to allow transformations such as Dead Argument
158// Elimination
159class ConstantArgumentGlobalisationOpt
160 : public fir::impl::ConstantArgumentGlobalisationOptBase<
161 ConstantArgumentGlobalisationOpt> {
162public:
163 ConstantArgumentGlobalisationOpt() = default;
164
165 void runOnOperation() override {
166 mlir::ModuleOp mod = getOperation();
167 mlir::DominanceInfo *di = &getAnalysis<mlir::DominanceInfo>();
168 auto *context = &getContext();
169 mlir::RewritePatternSet patterns(context);
170 mlir::GreedyRewriteConfig config;
171 config.setRegionSimplificationLevel(
172 mlir::GreedySimplifyRegionLevel::Disabled);
173 config.setStrictness(mlir::GreedyRewriteStrictness::ExistingOps);
174
175 patterns.insert<CallOpRewriter>(context, *di);
176 if (mlir::failed(
177 mlir::applyPatternsGreedily(mod, std::move(patterns), config))) {
178 mlir::emitError(mod.getLoc(),
179 "error in constant globalisation optimization\n");
180 signalPassFailure();
181 }
182 }
183};
184} // namespace
185

source code of flang/lib/Optimizer/Transforms/ConstantArgumentGlobalisation.cpp