1 | //===-- BoxedProcedure.cpp ------------------------------------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "flang/Optimizer/CodeGen/CodeGen.h" |
10 | |
11 | #include "flang/Optimizer/Builder/FIRBuilder.h" |
12 | #include "flang/Optimizer/Builder/LowLevelIntrinsics.h" |
13 | #include "flang/Optimizer/Dialect/FIRDialect.h" |
14 | #include "flang/Optimizer/Dialect/FIROps.h" |
15 | #include "flang/Optimizer/Dialect/FIRType.h" |
16 | #include "flang/Optimizer/Dialect/Support/FIRContext.h" |
17 | #include "flang/Optimizer/Support/FatalError.h" |
18 | #include "flang/Optimizer/Support/InternalNames.h" |
19 | #include "mlir/IR/PatternMatch.h" |
20 | #include "mlir/Pass/Pass.h" |
21 | #include "mlir/Transforms/DialectConversion.h" |
22 | #include "llvm/ADT/DenseMap.h" |
23 | |
24 | namespace fir { |
25 | #define GEN_PASS_DEF_BOXEDPROCEDUREPASS |
26 | #include "flang/Optimizer/CodeGen/CGPasses.h.inc" |
27 | } // namespace fir |
28 | |
29 | #define DEBUG_TYPE "flang-procedure-pointer" |
30 | |
31 | using namespace fir; |
32 | |
33 | namespace { |
/// Configuration knobs for the boxed procedure (`!fir.boxproc`) lowering pass.
struct BoxedProcedureOptions {
  /// When set (the default), `!fir.boxproc` values are lowered to function
  /// pointers, with thunks created where a host frame must be captured.
  bool useThunks{true};
};
40 | |
41 | /// This type converter rewrites all `!fir.boxproc<Func>` types to `Func` types. |
42 | class BoxprocTypeRewriter : public mlir::TypeConverter { |
43 | public: |
44 | using mlir::TypeConverter::convertType; |
45 | |
46 | /// Does the type \p ty need to be converted? |
47 | /// Any type that is a `!fir.boxproc` in whole or in part will need to be |
48 | /// converted to a function type to lower the IR to function pointer form in |
49 | /// the default implementation performed in this pass. Other implementations |
50 | /// are possible, so those may convert `!fir.boxproc` to some other type or |
51 | /// not at all depending on the implementation target's characteristics and |
52 | /// preference. |
53 | bool needsConversion(mlir::Type ty) { |
54 | if (ty.isa<BoxProcType>()) |
55 | return true; |
56 | if (auto funcTy = ty.dyn_cast<mlir::FunctionType>()) { |
57 | for (auto t : funcTy.getInputs()) |
58 | if (needsConversion(t)) |
59 | return true; |
60 | for (auto t : funcTy.getResults()) |
61 | if (needsConversion(t)) |
62 | return true; |
63 | return false; |
64 | } |
65 | if (auto tupleTy = ty.dyn_cast<mlir::TupleType>()) { |
66 | for (auto t : tupleTy.getTypes()) |
67 | if (needsConversion(t)) |
68 | return true; |
69 | return false; |
70 | } |
71 | if (auto recTy = ty.dyn_cast<RecordType>()) { |
72 | auto visited = visitedTypes.find(ty); |
73 | if (visited != visitedTypes.end()) |
74 | return visited->second; |
75 | [[maybe_unused]] auto newIt = visitedTypes.try_emplace(ty, false); |
76 | assert(newIt.second && "expected ty to not be in the map" ); |
77 | bool wasAlreadyVisitingRecordType = needConversionIsVisitingRecordType; |
78 | needConversionIsVisitingRecordType = true; |
79 | bool result = false; |
80 | for (auto t : recTy.getTypeList()) { |
81 | if (needsConversion(t.second)) { |
82 | result = true; |
83 | break; |
84 | } |
85 | } |
86 | // Only keep the result cached if the fir.type visited was a "top-level |
87 | // type". Nested types with a recursive reference to the "top-level type" |
88 | // may incorrectly have been resolved as not needed conversions because it |
89 | // had not been determined yet if the "top-level type" needed conversion. |
90 | // This is not an issue to determine the "top-level type" need of |
91 | // conversion, but the result should not be kept and later used in other |
92 | // contexts. |
93 | needConversionIsVisitingRecordType = wasAlreadyVisitingRecordType; |
94 | if (needConversionIsVisitingRecordType) |
95 | visitedTypes.erase(ty); |
96 | else |
97 | visitedTypes.find(ty)->second = result; |
98 | return result; |
99 | } |
100 | if (auto boxTy = ty.dyn_cast<BaseBoxType>()) |
101 | return needsConversion(boxTy.getEleTy()); |
102 | if (isa_ref_type(ty)) |
103 | return needsConversion(unwrapRefType(ty)); |
104 | if (auto t = ty.dyn_cast<SequenceType>()) |
105 | return needsConversion(unwrapSequenceType(ty)); |
106 | return false; |
107 | } |
108 | |
109 | BoxprocTypeRewriter(mlir::Location location) : loc{location} { |
110 | addConversion([](mlir::Type ty) { return ty; }); |
111 | addConversion( |
112 | [&](BoxProcType boxproc) { return convertType(boxproc.getEleTy()); }); |
113 | addConversion([&](mlir::TupleType tupTy) { |
114 | llvm::SmallVector<mlir::Type> memTys; |
115 | for (auto ty : tupTy.getTypes()) |
116 | memTys.push_back(convertType(ty)); |
117 | return mlir::TupleType::get(tupTy.getContext(), memTys); |
118 | }); |
119 | addConversion([&](mlir::FunctionType funcTy) { |
120 | llvm::SmallVector<mlir::Type> inTys; |
121 | llvm::SmallVector<mlir::Type> resTys; |
122 | for (auto ty : funcTy.getInputs()) |
123 | inTys.push_back(convertType(ty)); |
124 | for (auto ty : funcTy.getResults()) |
125 | resTys.push_back(convertType(ty)); |
126 | return mlir::FunctionType::get(funcTy.getContext(), inTys, resTys); |
127 | }); |
128 | addConversion([&](ReferenceType ty) { |
129 | return ReferenceType::get(convertType(ty.getEleTy())); |
130 | }); |
131 | addConversion([&](PointerType ty) { |
132 | return PointerType::get(convertType(ty.getEleTy())); |
133 | }); |
134 | addConversion( |
135 | [&](HeapType ty) { return HeapType::get(convertType(ty.getEleTy())); }); |
136 | addConversion([&](fir::LLVMPointerType ty) { |
137 | return fir::LLVMPointerType::get(convertType(ty.getEleTy())); |
138 | }); |
139 | addConversion( |
140 | [&](BoxType ty) { return BoxType::get(convertType(ty.getEleTy())); }); |
141 | addConversion([&](ClassType ty) { |
142 | return ClassType::get(convertType(ty.getEleTy())); |
143 | }); |
144 | addConversion([&](SequenceType ty) { |
145 | // TODO: add ty.getLayoutMap() as needed. |
146 | return SequenceType::get(ty.getShape(), convertType(ty.getEleTy())); |
147 | }); |
148 | addConversion([&](RecordType ty) -> mlir::Type { |
149 | if (!needsConversion(ty)) |
150 | return ty; |
151 | if (auto converted = convertedTypes.lookup(ty)) |
152 | return converted; |
153 | auto rec = RecordType::get(ty.getContext(), |
154 | ty.getName().str() + boxprocSuffix.str()); |
155 | if (rec.isFinalized()) |
156 | return rec; |
157 | [[maybe_unused]] auto it = convertedTypes.try_emplace(ty, rec); |
158 | assert(it.second && "expected ty to not be in the map" ); |
159 | std::vector<RecordType::TypePair> ps = ty.getLenParamList(); |
160 | std::vector<RecordType::TypePair> cs; |
161 | for (auto t : ty.getTypeList()) { |
162 | if (needsConversion(t.second)) |
163 | cs.emplace_back(t.first, convertType(t.second)); |
164 | else |
165 | cs.emplace_back(t.first, t.second); |
166 | } |
167 | rec.finalize(ps, cs); |
168 | return rec; |
169 | }); |
170 | addArgumentMaterialization(materializeProcedure); |
171 | addSourceMaterialization(materializeProcedure); |
172 | addTargetMaterialization(materializeProcedure); |
173 | } |
174 | |
175 | static mlir::Value materializeProcedure(mlir::OpBuilder &builder, |
176 | BoxProcType type, |
177 | mlir::ValueRange inputs, |
178 | mlir::Location loc) { |
179 | assert(inputs.size() == 1); |
180 | return builder.create<ConvertOp>(loc, unwrapRefType(type.getEleTy()), |
181 | inputs[0]); |
182 | } |
183 | |
184 | void setLocation(mlir::Location location) { loc = location; } |
185 | |
186 | private: |
187 | // Maps to deal with recursive derived types (avoid infinite loops). |
188 | // Caching is also beneficial for apps with big types (dozens of |
189 | // components and or parent types), so the lifetime of the cache |
190 | // is the whole pass. |
191 | llvm::DenseMap<mlir::Type, bool> visitedTypes; |
192 | bool needConversionIsVisitingRecordType = false; |
193 | llvm::DenseMap<mlir::Type, mlir::Type> convertedTypes; |
194 | mlir::Location loc; |
195 | }; |
196 | |
197 | /// A `boxproc` is an abstraction for a Fortran procedure reference. Typically, |
198 | /// Fortran procedures can be referenced directly through a function pointer. |
199 | /// However, Fortran has one-level dynamic scoping between a host procedure and |
200 | /// its internal procedures. This allows internal procedures to directly access |
201 | /// and modify the state of the host procedure's variables. |
202 | /// |
203 | /// There are any number of possible implementations possible. |
204 | /// |
205 | /// The implementation used here is to convert `boxproc` values to function |
206 | /// pointers everywhere. If a `boxproc` value includes a frame pointer to the |
207 | /// host procedure's data, then a thunk will be created at runtime to capture |
208 | /// the frame pointer during execution. In LLVM IR, the frame pointer is |
209 | /// designated with the `nest` attribute. The thunk's address will then be used |
210 | /// as the call target instead of the original function's address directly. |
211 | class BoxedProcedurePass |
212 | : public fir::impl::BoxedProcedurePassBase<BoxedProcedurePass> { |
213 | public: |
214 | BoxedProcedurePass() { options = {.useThunks: true}; } |
215 | BoxedProcedurePass(bool useThunks) { options = {.useThunks: useThunks}; } |
216 | |
217 | inline mlir::ModuleOp getModule() { return getOperation(); } |
218 | |
219 | void runOnOperation() override final { |
220 | if (options.useThunks) { |
221 | auto *context = &getContext(); |
222 | mlir::IRRewriter rewriter(context); |
223 | BoxprocTypeRewriter typeConverter(mlir::UnknownLoc::get(context)); |
224 | mlir::Dialect *firDialect = context->getLoadedDialect("fir" ); |
225 | getModule().walk([&](mlir::Operation *op) { |
226 | bool opIsValid = true; |
227 | typeConverter.setLocation(op->getLoc()); |
228 | if (auto addr = mlir::dyn_cast<BoxAddrOp>(op)) { |
229 | mlir::Type ty = addr.getVal().getType(); |
230 | mlir::Type resTy = addr.getResult().getType(); |
231 | if (llvm::isa<mlir::FunctionType>(ty) || |
232 | llvm::isa<fir::BoxProcType>(ty)) { |
233 | // Rewrite all `fir.box_addr` ops on values of type `!fir.boxproc` |
234 | // or function type to be `fir.convert` ops. |
235 | rewriter.setInsertionPoint(addr); |
236 | rewriter.replaceOpWithNewOp<ConvertOp>( |
237 | addr, typeConverter.convertType(addr.getType()), addr.getVal()); |
238 | opIsValid = false; |
239 | } else if (typeConverter.needsConversion(resTy)) { |
240 | rewriter.startOpModification(op); |
241 | op->getResult(0).setType(typeConverter.convertType(resTy)); |
242 | rewriter.finalizeOpModification(op); |
243 | } |
244 | } else if (auto func = mlir::dyn_cast<mlir::func::FuncOp>(op)) { |
245 | mlir::FunctionType ty = func.getFunctionType(); |
246 | if (typeConverter.needsConversion(ty)) { |
247 | rewriter.startOpModification(func); |
248 | auto toTy = |
249 | typeConverter.convertType(ty).cast<mlir::FunctionType>(); |
250 | if (!func.empty()) |
251 | for (auto e : llvm::enumerate(toTy.getInputs())) { |
252 | unsigned i = e.index(); |
253 | auto &block = func.front(); |
254 | block.insertArgument(i, e.value(), func.getLoc()); |
255 | block.getArgument(i + 1).replaceAllUsesWith( |
256 | block.getArgument(i)); |
257 | block.eraseArgument(i + 1); |
258 | } |
259 | func.setType(toTy); |
260 | rewriter.finalizeOpModification(func); |
261 | } |
262 | } else if (auto embox = mlir::dyn_cast<EmboxProcOp>(op)) { |
263 | // Rewrite all `fir.emboxproc` ops to either `fir.convert` or a thunk |
264 | // as required. |
265 | mlir::Type toTy = typeConverter.convertType( |
266 | embox.getType().cast<BoxProcType>().getEleTy()); |
267 | rewriter.setInsertionPoint(embox); |
268 | if (embox.getHost()) { |
269 | // Create the thunk. |
270 | auto module = embox->getParentOfType<mlir::ModuleOp>(); |
271 | FirOpBuilder builder(rewriter, module); |
272 | auto loc = embox.getLoc(); |
273 | mlir::Type i8Ty = builder.getI8Type(); |
274 | mlir::Type i8Ptr = builder.getRefType(i8Ty); |
275 | mlir::Type buffTy = SequenceType::get({32}, i8Ty); |
276 | auto buffer = builder.create<AllocaOp>(loc, buffTy); |
277 | mlir::Value closure = |
278 | builder.createConvert(loc, i8Ptr, embox.getHost()); |
279 | mlir::Value tramp = builder.createConvert(loc, i8Ptr, buffer); |
280 | mlir::Value func = |
281 | builder.createConvert(loc, i8Ptr, embox.getFunc()); |
282 | builder.create<fir::CallOp>( |
283 | loc, factory::getLlvmInitTrampoline(builder), |
284 | llvm::ArrayRef<mlir::Value>{tramp, func, closure}); |
285 | auto adjustCall = builder.create<fir::CallOp>( |
286 | loc, factory::getLlvmAdjustTrampoline(builder), |
287 | llvm::ArrayRef<mlir::Value>{tramp}); |
288 | rewriter.replaceOpWithNewOp<ConvertOp>(embox, toTy, |
289 | adjustCall.getResult(0)); |
290 | opIsValid = false; |
291 | } else { |
292 | // Just forward the function as a pointer. |
293 | rewriter.replaceOpWithNewOp<ConvertOp>(embox, toTy, |
294 | embox.getFunc()); |
295 | opIsValid = false; |
296 | } |
297 | } else if (auto global = mlir::dyn_cast<GlobalOp>(op)) { |
298 | auto ty = global.getType(); |
299 | if (typeConverter.needsConversion(ty)) { |
300 | rewriter.startOpModification(global); |
301 | auto toTy = typeConverter.convertType(ty); |
302 | global.setType(toTy); |
303 | rewriter.finalizeOpModification(global); |
304 | } |
305 | } else if (auto mem = mlir::dyn_cast<AllocaOp>(op)) { |
306 | auto ty = mem.getType(); |
307 | if (typeConverter.needsConversion(ty)) { |
308 | rewriter.setInsertionPoint(mem); |
309 | auto toTy = typeConverter.convertType(unwrapRefType(ty)); |
310 | bool isPinned = mem.getPinned(); |
311 | llvm::StringRef uniqName = |
312 | mem.getUniqName().value_or(llvm::StringRef()); |
313 | llvm::StringRef bindcName = |
314 | mem.getBindcName().value_or(llvm::StringRef()); |
315 | rewriter.replaceOpWithNewOp<AllocaOp>( |
316 | mem, toTy, uniqName, bindcName, isPinned, mem.getTypeparams(), |
317 | mem.getShape()); |
318 | opIsValid = false; |
319 | } |
320 | } else if (auto mem = mlir::dyn_cast<AllocMemOp>(op)) { |
321 | auto ty = mem.getType(); |
322 | if (typeConverter.needsConversion(ty)) { |
323 | rewriter.setInsertionPoint(mem); |
324 | auto toTy = typeConverter.convertType(unwrapRefType(ty)); |
325 | llvm::StringRef uniqName = |
326 | mem.getUniqName().value_or(llvm::StringRef()); |
327 | llvm::StringRef bindcName = |
328 | mem.getBindcName().value_or(llvm::StringRef()); |
329 | rewriter.replaceOpWithNewOp<AllocMemOp>( |
330 | mem, toTy, uniqName, bindcName, mem.getTypeparams(), |
331 | mem.getShape()); |
332 | opIsValid = false; |
333 | } |
334 | } else if (auto coor = mlir::dyn_cast<CoordinateOp>(op)) { |
335 | auto ty = coor.getType(); |
336 | mlir::Type baseTy = coor.getBaseType(); |
337 | if (typeConverter.needsConversion(ty) || |
338 | typeConverter.needsConversion(baseTy)) { |
339 | rewriter.setInsertionPoint(coor); |
340 | auto toTy = typeConverter.convertType(ty); |
341 | auto toBaseTy = typeConverter.convertType(baseTy); |
342 | rewriter.replaceOpWithNewOp<CoordinateOp>(coor, toTy, coor.getRef(), |
343 | coor.getCoor(), toBaseTy); |
344 | opIsValid = false; |
345 | } |
346 | } else if (auto index = mlir::dyn_cast<FieldIndexOp>(op)) { |
347 | auto ty = index.getType(); |
348 | mlir::Type onTy = index.getOnType(); |
349 | if (typeConverter.needsConversion(ty) || |
350 | typeConverter.needsConversion(onTy)) { |
351 | rewriter.setInsertionPoint(index); |
352 | auto toTy = typeConverter.convertType(ty); |
353 | auto toOnTy = typeConverter.convertType(onTy); |
354 | rewriter.replaceOpWithNewOp<FieldIndexOp>( |
355 | index, toTy, index.getFieldId(), toOnTy, index.getTypeparams()); |
356 | opIsValid = false; |
357 | } |
358 | } else if (auto index = mlir::dyn_cast<LenParamIndexOp>(op)) { |
359 | auto ty = index.getType(); |
360 | mlir::Type onTy = index.getOnType(); |
361 | if (typeConverter.needsConversion(ty) || |
362 | typeConverter.needsConversion(onTy)) { |
363 | rewriter.setInsertionPoint(index); |
364 | auto toTy = typeConverter.convertType(ty); |
365 | auto toOnTy = typeConverter.convertType(onTy); |
366 | rewriter.replaceOpWithNewOp<LenParamIndexOp>( |
367 | index, toTy, index.getFieldId(), toOnTy, index.getTypeparams()); |
368 | opIsValid = false; |
369 | } |
370 | } else if (op->getDialect() == firDialect) { |
371 | rewriter.startOpModification(op); |
372 | for (auto i : llvm::enumerate(op->getResultTypes())) |
373 | if (typeConverter.needsConversion(i.value())) { |
374 | auto toTy = typeConverter.convertType(i.value()); |
375 | op->getResult(i.index()).setType(toTy); |
376 | } |
377 | rewriter.finalizeOpModification(op); |
378 | } |
379 | // Ensure block arguments are updated if needed. |
380 | if (opIsValid && op->getNumRegions() != 0) { |
381 | rewriter.startOpModification(op); |
382 | for (mlir::Region ®ion : op->getRegions()) |
383 | for (mlir::Block &block : region.getBlocks()) |
384 | for (mlir::BlockArgument blockArg : block.getArguments()) |
385 | if (typeConverter.needsConversion(blockArg.getType())) { |
386 | mlir::Type toTy = |
387 | typeConverter.convertType(blockArg.getType()); |
388 | blockArg.setType(toTy); |
389 | } |
390 | rewriter.finalizeOpModification(op); |
391 | } |
392 | }); |
393 | } |
394 | } |
395 | |
396 | private: |
397 | BoxedProcedureOptions options; |
398 | }; |
399 | } // namespace |
400 | |
401 | std::unique_ptr<mlir::Pass> fir::createBoxedProcedurePass() { |
402 | return std::make_unique<BoxedProcedurePass>(); |
403 | } |
404 | |
405 | std::unique_ptr<mlir::Pass> fir::createBoxedProcedurePass(bool useThunks) { |
406 | return std::make_unique<BoxedProcedurePass>(useThunks); |
407 | } |
408 | |