1//===-- BoxedProcedure.cpp ------------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "flang/Optimizer/CodeGen/CodeGen.h"
10
11#include "flang/Optimizer/Builder/FIRBuilder.h"
12#include "flang/Optimizer/Builder/LowLevelIntrinsics.h"
13#include "flang/Optimizer/Dialect/FIRDialect.h"
14#include "flang/Optimizer/Dialect/FIROps.h"
15#include "flang/Optimizer/Dialect/FIRType.h"
16#include "flang/Optimizer/Dialect/Support/FIRContext.h"
17#include "flang/Optimizer/Support/FatalError.h"
18#include "flang/Optimizer/Support/InternalNames.h"
19#include "mlir/IR/PatternMatch.h"
20#include "mlir/Pass/Pass.h"
21#include "mlir/Transforms/DialectConversion.h"
22#include "llvm/ADT/DenseMap.h"
23
24namespace fir {
25#define GEN_PASS_DEF_BOXEDPROCEDUREPASS
26#include "flang/Optimizer/CodeGen/CGPasses.h.inc"
27} // namespace fir
28
29#define DEBUG_TYPE "flang-procedure-pointer"
30
31using namespace fir;
32
33namespace {
/// Configuration knobs for the boxed procedure (procedure pointer) pass.
struct BoxedProcedureOptions {
  /// When set, the `boxproc` abstraction is lowered to plain function
  /// pointers, with thunks generated where they are required. Enabled by
  /// default.
  bool useThunks{true};
};
40
41/// This type converter rewrites all `!fir.boxproc<Func>` types to `Func` types.
42class BoxprocTypeRewriter : public mlir::TypeConverter {
43public:
44 using mlir::TypeConverter::convertType;
45
46 /// Does the type \p ty need to be converted?
47 /// Any type that is a `!fir.boxproc` in whole or in part will need to be
48 /// converted to a function type to lower the IR to function pointer form in
49 /// the default implementation performed in this pass. Other implementations
50 /// are possible, so those may convert `!fir.boxproc` to some other type or
51 /// not at all depending on the implementation target's characteristics and
52 /// preference.
53 bool needsConversion(mlir::Type ty) {
54 if (ty.isa<BoxProcType>())
55 return true;
56 if (auto funcTy = ty.dyn_cast<mlir::FunctionType>()) {
57 for (auto t : funcTy.getInputs())
58 if (needsConversion(t))
59 return true;
60 for (auto t : funcTy.getResults())
61 if (needsConversion(t))
62 return true;
63 return false;
64 }
65 if (auto tupleTy = ty.dyn_cast<mlir::TupleType>()) {
66 for (auto t : tupleTy.getTypes())
67 if (needsConversion(t))
68 return true;
69 return false;
70 }
71 if (auto recTy = ty.dyn_cast<RecordType>()) {
72 auto visited = visitedTypes.find(ty);
73 if (visited != visitedTypes.end())
74 return visited->second;
75 [[maybe_unused]] auto newIt = visitedTypes.try_emplace(ty, false);
76 assert(newIt.second && "expected ty to not be in the map");
77 bool wasAlreadyVisitingRecordType = needConversionIsVisitingRecordType;
78 needConversionIsVisitingRecordType = true;
79 bool result = false;
80 for (auto t : recTy.getTypeList()) {
81 if (needsConversion(t.second)) {
82 result = true;
83 break;
84 }
85 }
86 // Only keep the result cached if the fir.type visited was a "top-level
87 // type". Nested types with a recursive reference to the "top-level type"
88 // may incorrectly have been resolved as not needed conversions because it
89 // had not been determined yet if the "top-level type" needed conversion.
90 // This is not an issue to determine the "top-level type" need of
91 // conversion, but the result should not be kept and later used in other
92 // contexts.
93 needConversionIsVisitingRecordType = wasAlreadyVisitingRecordType;
94 if (needConversionIsVisitingRecordType)
95 visitedTypes.erase(ty);
96 else
97 visitedTypes.find(ty)->second = result;
98 return result;
99 }
100 if (auto boxTy = ty.dyn_cast<BaseBoxType>())
101 return needsConversion(boxTy.getEleTy());
102 if (isa_ref_type(ty))
103 return needsConversion(unwrapRefType(ty));
104 if (auto t = ty.dyn_cast<SequenceType>())
105 return needsConversion(unwrapSequenceType(ty));
106 return false;
107 }
108
109 BoxprocTypeRewriter(mlir::Location location) : loc{location} {
110 addConversion([](mlir::Type ty) { return ty; });
111 addConversion(
112 [&](BoxProcType boxproc) { return convertType(boxproc.getEleTy()); });
113 addConversion([&](mlir::TupleType tupTy) {
114 llvm::SmallVector<mlir::Type> memTys;
115 for (auto ty : tupTy.getTypes())
116 memTys.push_back(convertType(ty));
117 return mlir::TupleType::get(tupTy.getContext(), memTys);
118 });
119 addConversion([&](mlir::FunctionType funcTy) {
120 llvm::SmallVector<mlir::Type> inTys;
121 llvm::SmallVector<mlir::Type> resTys;
122 for (auto ty : funcTy.getInputs())
123 inTys.push_back(convertType(ty));
124 for (auto ty : funcTy.getResults())
125 resTys.push_back(convertType(ty));
126 return mlir::FunctionType::get(funcTy.getContext(), inTys, resTys);
127 });
128 addConversion([&](ReferenceType ty) {
129 return ReferenceType::get(convertType(ty.getEleTy()));
130 });
131 addConversion([&](PointerType ty) {
132 return PointerType::get(convertType(ty.getEleTy()));
133 });
134 addConversion(
135 [&](HeapType ty) { return HeapType::get(convertType(ty.getEleTy())); });
136 addConversion([&](fir::LLVMPointerType ty) {
137 return fir::LLVMPointerType::get(convertType(ty.getEleTy()));
138 });
139 addConversion(
140 [&](BoxType ty) { return BoxType::get(convertType(ty.getEleTy())); });
141 addConversion([&](ClassType ty) {
142 return ClassType::get(convertType(ty.getEleTy()));
143 });
144 addConversion([&](SequenceType ty) {
145 // TODO: add ty.getLayoutMap() as needed.
146 return SequenceType::get(ty.getShape(), convertType(ty.getEleTy()));
147 });
148 addConversion([&](RecordType ty) -> mlir::Type {
149 if (!needsConversion(ty))
150 return ty;
151 if (auto converted = convertedTypes.lookup(ty))
152 return converted;
153 auto rec = RecordType::get(ty.getContext(),
154 ty.getName().str() + boxprocSuffix.str());
155 if (rec.isFinalized())
156 return rec;
157 [[maybe_unused]] auto it = convertedTypes.try_emplace(ty, rec);
158 assert(it.second && "expected ty to not be in the map");
159 std::vector<RecordType::TypePair> ps = ty.getLenParamList();
160 std::vector<RecordType::TypePair> cs;
161 for (auto t : ty.getTypeList()) {
162 if (needsConversion(t.second))
163 cs.emplace_back(t.first, convertType(t.second));
164 else
165 cs.emplace_back(t.first, t.second);
166 }
167 rec.finalize(ps, cs);
168 return rec;
169 });
170 addArgumentMaterialization(materializeProcedure);
171 addSourceMaterialization(materializeProcedure);
172 addTargetMaterialization(materializeProcedure);
173 }
174
175 static mlir::Value materializeProcedure(mlir::OpBuilder &builder,
176 BoxProcType type,
177 mlir::ValueRange inputs,
178 mlir::Location loc) {
179 assert(inputs.size() == 1);
180 return builder.create<ConvertOp>(loc, unwrapRefType(type.getEleTy()),
181 inputs[0]);
182 }
183
184 void setLocation(mlir::Location location) { loc = location; }
185
186private:
187 // Maps to deal with recursive derived types (avoid infinite loops).
188 // Caching is also beneficial for apps with big types (dozens of
189 // components and or parent types), so the lifetime of the cache
190 // is the whole pass.
191 llvm::DenseMap<mlir::Type, bool> visitedTypes;
192 bool needConversionIsVisitingRecordType = false;
193 llvm::DenseMap<mlir::Type, mlir::Type> convertedTypes;
194 mlir::Location loc;
195};
196
197/// A `boxproc` is an abstraction for a Fortran procedure reference. Typically,
198/// Fortran procedures can be referenced directly through a function pointer.
199/// However, Fortran has one-level dynamic scoping between a host procedure and
200/// its internal procedures. This allows internal procedures to directly access
201/// and modify the state of the host procedure's variables.
202///
203/// There are any number of possible implementations possible.
204///
205/// The implementation used here is to convert `boxproc` values to function
206/// pointers everywhere. If a `boxproc` value includes a frame pointer to the
207/// host procedure's data, then a thunk will be created at runtime to capture
208/// the frame pointer during execution. In LLVM IR, the frame pointer is
209/// designated with the `nest` attribute. The thunk's address will then be used
210/// as the call target instead of the original function's address directly.
211class BoxedProcedurePass
212 : public fir::impl::BoxedProcedurePassBase<BoxedProcedurePass> {
213public:
214 BoxedProcedurePass() { options = {.useThunks: true}; }
215 BoxedProcedurePass(bool useThunks) { options = {.useThunks: useThunks}; }
216
217 inline mlir::ModuleOp getModule() { return getOperation(); }
218
219 void runOnOperation() override final {
220 if (options.useThunks) {
221 auto *context = &getContext();
222 mlir::IRRewriter rewriter(context);
223 BoxprocTypeRewriter typeConverter(mlir::UnknownLoc::get(context));
224 mlir::Dialect *firDialect = context->getLoadedDialect("fir");
225 getModule().walk([&](mlir::Operation *op) {
226 bool opIsValid = true;
227 typeConverter.setLocation(op->getLoc());
228 if (auto addr = mlir::dyn_cast<BoxAddrOp>(op)) {
229 mlir::Type ty = addr.getVal().getType();
230 mlir::Type resTy = addr.getResult().getType();
231 if (llvm::isa<mlir::FunctionType>(ty) ||
232 llvm::isa<fir::BoxProcType>(ty)) {
233 // Rewrite all `fir.box_addr` ops on values of type `!fir.boxproc`
234 // or function type to be `fir.convert` ops.
235 rewriter.setInsertionPoint(addr);
236 rewriter.replaceOpWithNewOp<ConvertOp>(
237 addr, typeConverter.convertType(addr.getType()), addr.getVal());
238 opIsValid = false;
239 } else if (typeConverter.needsConversion(resTy)) {
240 rewriter.startOpModification(op);
241 op->getResult(0).setType(typeConverter.convertType(resTy));
242 rewriter.finalizeOpModification(op);
243 }
244 } else if (auto func = mlir::dyn_cast<mlir::func::FuncOp>(op)) {
245 mlir::FunctionType ty = func.getFunctionType();
246 if (typeConverter.needsConversion(ty)) {
247 rewriter.startOpModification(func);
248 auto toTy =
249 typeConverter.convertType(ty).cast<mlir::FunctionType>();
250 if (!func.empty())
251 for (auto e : llvm::enumerate(toTy.getInputs())) {
252 unsigned i = e.index();
253 auto &block = func.front();
254 block.insertArgument(i, e.value(), func.getLoc());
255 block.getArgument(i + 1).replaceAllUsesWith(
256 block.getArgument(i));
257 block.eraseArgument(i + 1);
258 }
259 func.setType(toTy);
260 rewriter.finalizeOpModification(func);
261 }
262 } else if (auto embox = mlir::dyn_cast<EmboxProcOp>(op)) {
263 // Rewrite all `fir.emboxproc` ops to either `fir.convert` or a thunk
264 // as required.
265 mlir::Type toTy = typeConverter.convertType(
266 embox.getType().cast<BoxProcType>().getEleTy());
267 rewriter.setInsertionPoint(embox);
268 if (embox.getHost()) {
269 // Create the thunk.
270 auto module = embox->getParentOfType<mlir::ModuleOp>();
271 FirOpBuilder builder(rewriter, module);
272 auto loc = embox.getLoc();
273 mlir::Type i8Ty = builder.getI8Type();
274 mlir::Type i8Ptr = builder.getRefType(i8Ty);
275 mlir::Type buffTy = SequenceType::get({32}, i8Ty);
276 auto buffer = builder.create<AllocaOp>(loc, buffTy);
277 mlir::Value closure =
278 builder.createConvert(loc, i8Ptr, embox.getHost());
279 mlir::Value tramp = builder.createConvert(loc, i8Ptr, buffer);
280 mlir::Value func =
281 builder.createConvert(loc, i8Ptr, embox.getFunc());
282 builder.create<fir::CallOp>(
283 loc, factory::getLlvmInitTrampoline(builder),
284 llvm::ArrayRef<mlir::Value>{tramp, func, closure});
285 auto adjustCall = builder.create<fir::CallOp>(
286 loc, factory::getLlvmAdjustTrampoline(builder),
287 llvm::ArrayRef<mlir::Value>{tramp});
288 rewriter.replaceOpWithNewOp<ConvertOp>(embox, toTy,
289 adjustCall.getResult(0));
290 opIsValid = false;
291 } else {
292 // Just forward the function as a pointer.
293 rewriter.replaceOpWithNewOp<ConvertOp>(embox, toTy,
294 embox.getFunc());
295 opIsValid = false;
296 }
297 } else if (auto global = mlir::dyn_cast<GlobalOp>(op)) {
298 auto ty = global.getType();
299 if (typeConverter.needsConversion(ty)) {
300 rewriter.startOpModification(global);
301 auto toTy = typeConverter.convertType(ty);
302 global.setType(toTy);
303 rewriter.finalizeOpModification(global);
304 }
305 } else if (auto mem = mlir::dyn_cast<AllocaOp>(op)) {
306 auto ty = mem.getType();
307 if (typeConverter.needsConversion(ty)) {
308 rewriter.setInsertionPoint(mem);
309 auto toTy = typeConverter.convertType(unwrapRefType(ty));
310 bool isPinned = mem.getPinned();
311 llvm::StringRef uniqName =
312 mem.getUniqName().value_or(llvm::StringRef());
313 llvm::StringRef bindcName =
314 mem.getBindcName().value_or(llvm::StringRef());
315 rewriter.replaceOpWithNewOp<AllocaOp>(
316 mem, toTy, uniqName, bindcName, isPinned, mem.getTypeparams(),
317 mem.getShape());
318 opIsValid = false;
319 }
320 } else if (auto mem = mlir::dyn_cast<AllocMemOp>(op)) {
321 auto ty = mem.getType();
322 if (typeConverter.needsConversion(ty)) {
323 rewriter.setInsertionPoint(mem);
324 auto toTy = typeConverter.convertType(unwrapRefType(ty));
325 llvm::StringRef uniqName =
326 mem.getUniqName().value_or(llvm::StringRef());
327 llvm::StringRef bindcName =
328 mem.getBindcName().value_or(llvm::StringRef());
329 rewriter.replaceOpWithNewOp<AllocMemOp>(
330 mem, toTy, uniqName, bindcName, mem.getTypeparams(),
331 mem.getShape());
332 opIsValid = false;
333 }
334 } else if (auto coor = mlir::dyn_cast<CoordinateOp>(op)) {
335 auto ty = coor.getType();
336 mlir::Type baseTy = coor.getBaseType();
337 if (typeConverter.needsConversion(ty) ||
338 typeConverter.needsConversion(baseTy)) {
339 rewriter.setInsertionPoint(coor);
340 auto toTy = typeConverter.convertType(ty);
341 auto toBaseTy = typeConverter.convertType(baseTy);
342 rewriter.replaceOpWithNewOp<CoordinateOp>(coor, toTy, coor.getRef(),
343 coor.getCoor(), toBaseTy);
344 opIsValid = false;
345 }
346 } else if (auto index = mlir::dyn_cast<FieldIndexOp>(op)) {
347 auto ty = index.getType();
348 mlir::Type onTy = index.getOnType();
349 if (typeConverter.needsConversion(ty) ||
350 typeConverter.needsConversion(onTy)) {
351 rewriter.setInsertionPoint(index);
352 auto toTy = typeConverter.convertType(ty);
353 auto toOnTy = typeConverter.convertType(onTy);
354 rewriter.replaceOpWithNewOp<FieldIndexOp>(
355 index, toTy, index.getFieldId(), toOnTy, index.getTypeparams());
356 opIsValid = false;
357 }
358 } else if (auto index = mlir::dyn_cast<LenParamIndexOp>(op)) {
359 auto ty = index.getType();
360 mlir::Type onTy = index.getOnType();
361 if (typeConverter.needsConversion(ty) ||
362 typeConverter.needsConversion(onTy)) {
363 rewriter.setInsertionPoint(index);
364 auto toTy = typeConverter.convertType(ty);
365 auto toOnTy = typeConverter.convertType(onTy);
366 rewriter.replaceOpWithNewOp<LenParamIndexOp>(
367 index, toTy, index.getFieldId(), toOnTy, index.getTypeparams());
368 opIsValid = false;
369 }
370 } else if (op->getDialect() == firDialect) {
371 rewriter.startOpModification(op);
372 for (auto i : llvm::enumerate(op->getResultTypes()))
373 if (typeConverter.needsConversion(i.value())) {
374 auto toTy = typeConverter.convertType(i.value());
375 op->getResult(i.index()).setType(toTy);
376 }
377 rewriter.finalizeOpModification(op);
378 }
379 // Ensure block arguments are updated if needed.
380 if (opIsValid && op->getNumRegions() != 0) {
381 rewriter.startOpModification(op);
382 for (mlir::Region &region : op->getRegions())
383 for (mlir::Block &block : region.getBlocks())
384 for (mlir::BlockArgument blockArg : block.getArguments())
385 if (typeConverter.needsConversion(blockArg.getType())) {
386 mlir::Type toTy =
387 typeConverter.convertType(blockArg.getType());
388 blockArg.setType(toTy);
389 }
390 rewriter.finalizeOpModification(op);
391 }
392 });
393 }
394 }
395
396private:
397 BoxedProcedureOptions options;
398};
399} // namespace
400
401std::unique_ptr<mlir::Pass> fir::createBoxedProcedurePass() {
402 return std::make_unique<BoxedProcedurePass>();
403}
404
405std::unique_ptr<mlir::Pass> fir::createBoxedProcedurePass(bool useThunks) {
406 return std::make_unique<BoxedProcedurePass>(useThunks);
407}
408

source code of flang/lib/Optimizer/CodeGen/BoxedProcedure.cpp