1 | //===-- BoxedProcedure.cpp ------------------------------------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "flang/Optimizer/CodeGen/CodeGen.h" |
10 | |
11 | #include "flang/Optimizer/Builder/FIRBuilder.h" |
12 | #include "flang/Optimizer/Builder/LowLevelIntrinsics.h" |
13 | #include "flang/Optimizer/Dialect/FIRDialect.h" |
14 | #include "flang/Optimizer/Dialect/FIROps.h" |
15 | #include "flang/Optimizer/Dialect/FIRType.h" |
16 | #include "flang/Optimizer/Dialect/Support/FIRContext.h" |
17 | #include "flang/Optimizer/Support/FatalError.h" |
18 | #include "flang/Optimizer/Support/InternalNames.h" |
19 | #include "mlir/IR/PatternMatch.h" |
20 | #include "mlir/Pass/Pass.h" |
21 | #include "mlir/Transforms/DialectConversion.h" |
22 | #include "llvm/ADT/DenseMap.h" |
23 | |
24 | namespace fir { |
25 | #define GEN_PASS_DEF_BOXEDPROCEDUREPASS |
26 | #include "flang/Optimizer/CodeGen/CGPasses.h.inc" |
27 | } // namespace fir |
28 | |
29 | #define DEBUG_TYPE "flang-procedure-pointer" |
30 | |
31 | using namespace fir; |
32 | |
33 | namespace { |
/// Knobs controlling how the boxed procedure pass lowers `!fir.boxproc`.
struct BoxedProcedureOptions {
  /// When true (the default), `!fir.boxproc` values are lowered to plain
  /// function pointers, with thunks built where a host link must be captured.
  bool useThunks{true};
};
40 | |
41 | /// This type converter rewrites all `!fir.boxproc<Func>` types to `Func` types. |
42 | class BoxprocTypeRewriter : public mlir::TypeConverter { |
43 | public: |
44 | using mlir::TypeConverter::convertType; |
45 | |
46 | /// Does the type \p ty need to be converted? |
47 | /// Any type that is a `!fir.boxproc` in whole or in part will need to be |
48 | /// converted to a function type to lower the IR to function pointer form in |
49 | /// the default implementation performed in this pass. Other implementations |
50 | /// are possible, so those may convert `!fir.boxproc` to some other type or |
51 | /// not at all depending on the implementation target's characteristics and |
52 | /// preference. |
53 | bool needsConversion(mlir::Type ty) { |
54 | if (mlir::isa<BoxProcType>(ty)) |
55 | return true; |
56 | if (auto funcTy = mlir::dyn_cast<mlir::FunctionType>(ty)) { |
57 | for (auto t : funcTy.getInputs()) |
58 | if (needsConversion(t)) |
59 | return true; |
60 | for (auto t : funcTy.getResults()) |
61 | if (needsConversion(t)) |
62 | return true; |
63 | return false; |
64 | } |
65 | if (auto tupleTy = mlir::dyn_cast<mlir::TupleType>(ty)) { |
66 | for (auto t : tupleTy.getTypes()) |
67 | if (needsConversion(t)) |
68 | return true; |
69 | return false; |
70 | } |
71 | if (auto recTy = mlir::dyn_cast<RecordType>(ty)) { |
72 | auto [visited, inserted] = visitedTypes.try_emplace(ty, false); |
73 | if (!inserted) |
74 | return visited->second; |
75 | bool wasAlreadyVisitingRecordType = needConversionIsVisitingRecordType; |
76 | needConversionIsVisitingRecordType = true; |
77 | bool result = false; |
78 | for (auto t : recTy.getTypeList()) { |
79 | if (needsConversion(t.second)) { |
80 | result = true; |
81 | break; |
82 | } |
83 | } |
84 | // Only keep the result cached if the fir.type visited was a "top-level |
85 | // type". Nested types with a recursive reference to the "top-level type" |
86 | // may incorrectly have been resolved as not needed conversions because it |
87 | // had not been determined yet if the "top-level type" needed conversion. |
88 | // This is not an issue to determine the "top-level type" need of |
89 | // conversion, but the result should not be kept and later used in other |
90 | // contexts. |
91 | needConversionIsVisitingRecordType = wasAlreadyVisitingRecordType; |
92 | if (needConversionIsVisitingRecordType) |
93 | visitedTypes.erase(ty); |
94 | else |
95 | visitedTypes.find(ty)->second = result; |
96 | return result; |
97 | } |
98 | if (auto boxTy = mlir::dyn_cast<BaseBoxType>(ty)) |
99 | return needsConversion(boxTy.getEleTy()); |
100 | if (isa_ref_type(ty)) |
101 | return needsConversion(unwrapRefType(ty)); |
102 | if (auto t = mlir::dyn_cast<SequenceType>(ty)) |
103 | return needsConversion(unwrapSequenceType(ty)); |
104 | if (auto t = mlir::dyn_cast<TypeDescType>(ty)) |
105 | return needsConversion(t.getOfTy()); |
106 | return false; |
107 | } |
108 | |
109 | BoxprocTypeRewriter(mlir::Location location) : loc{location} { |
110 | addConversion([](mlir::Type ty) { return ty; }); |
111 | addConversion( |
112 | [&](BoxProcType boxproc) { return convertType(boxproc.getEleTy()); }); |
113 | addConversion([&](mlir::TupleType tupTy) { |
114 | llvm::SmallVector<mlir::Type> memTys; |
115 | for (auto ty : tupTy.getTypes()) |
116 | memTys.push_back(convertType(ty)); |
117 | return mlir::TupleType::get(tupTy.getContext(), memTys); |
118 | }); |
119 | addConversion([&](mlir::FunctionType funcTy) { |
120 | llvm::SmallVector<mlir::Type> inTys; |
121 | llvm::SmallVector<mlir::Type> resTys; |
122 | for (auto ty : funcTy.getInputs()) |
123 | inTys.push_back(convertType(ty)); |
124 | for (auto ty : funcTy.getResults()) |
125 | resTys.push_back(convertType(ty)); |
126 | return mlir::FunctionType::get(funcTy.getContext(), inTys, resTys); |
127 | }); |
128 | addConversion([&](ReferenceType ty) { |
129 | return ReferenceType::get(convertType(ty.getEleTy())); |
130 | }); |
131 | addConversion([&](PointerType ty) { |
132 | return PointerType::get(convertType(ty.getEleTy())); |
133 | }); |
134 | addConversion( |
135 | [&](HeapType ty) { return HeapType::get(convertType(ty.getEleTy())); }); |
136 | addConversion([&](fir::LLVMPointerType ty) { |
137 | return fir::LLVMPointerType::get(convertType(ty.getEleTy())); |
138 | }); |
139 | addConversion( |
140 | [&](BoxType ty) { return BoxType::get(convertType(ty.getEleTy())); }); |
141 | addConversion([&](ClassType ty) { |
142 | return ClassType::get(convertType(ty.getEleTy())); |
143 | }); |
144 | addConversion([&](SequenceType ty) { |
145 | // TODO: add ty.getLayoutMap() as needed. |
146 | return SequenceType::get(ty.getShape(), convertType(ty.getEleTy())); |
147 | }); |
148 | addConversion([&](RecordType ty) -> mlir::Type { |
149 | if (!needsConversion(ty)) |
150 | return ty; |
151 | if (auto converted = convertedTypes.lookup(ty)) |
152 | return converted; |
153 | auto rec = RecordType::get(ty.getContext(), |
154 | ty.getName().str() + boxprocSuffix.str()); |
155 | if (rec.isFinalized()) |
156 | return rec; |
157 | [[maybe_unused]] auto it = convertedTypes.try_emplace(ty, rec); |
158 | assert(it.second && "expected ty to not be in the map" ); |
159 | std::vector<RecordType::TypePair> ps = ty.getLenParamList(); |
160 | std::vector<RecordType::TypePair> cs; |
161 | for (auto t : ty.getTypeList()) { |
162 | if (needsConversion(t.second)) |
163 | cs.emplace_back(t.first, convertType(t.second)); |
164 | else |
165 | cs.emplace_back(t.first, t.second); |
166 | } |
167 | rec.finalize(ps, cs); |
168 | rec.pack(ty.isPacked()); |
169 | return rec; |
170 | }); |
171 | addConversion([&](TypeDescType ty) { |
172 | return TypeDescType::get(convertType(ty.getOfTy())); |
173 | }); |
174 | addSourceMaterialization(materializeProcedure); |
175 | addTargetMaterialization(materializeProcedure); |
176 | } |
177 | |
178 | static mlir::Value materializeProcedure(mlir::OpBuilder &builder, |
179 | BoxProcType type, |
180 | mlir::ValueRange inputs, |
181 | mlir::Location loc) { |
182 | assert(inputs.size() == 1); |
183 | return builder.create<ConvertOp>(loc, unwrapRefType(type.getEleTy()), |
184 | inputs[0]); |
185 | } |
186 | |
187 | void setLocation(mlir::Location location) { loc = location; } |
188 | |
189 | private: |
190 | // Maps to deal with recursive derived types (avoid infinite loops). |
191 | // Caching is also beneficial for apps with big types (dozens of |
192 | // components and or parent types), so the lifetime of the cache |
193 | // is the whole pass. |
194 | llvm::DenseMap<mlir::Type, bool> visitedTypes; |
195 | bool needConversionIsVisitingRecordType = false; |
196 | llvm::DenseMap<mlir::Type, mlir::Type> convertedTypes; |
197 | mlir::Location loc; |
198 | }; |
199 | |
200 | /// A `boxproc` is an abstraction for a Fortran procedure reference. Typically, |
201 | /// Fortran procedures can be referenced directly through a function pointer. |
202 | /// However, Fortran has one-level dynamic scoping between a host procedure and |
203 | /// its internal procedures. This allows internal procedures to directly access |
204 | /// and modify the state of the host procedure's variables. |
205 | /// |
206 | /// There are any number of possible implementations possible. |
207 | /// |
208 | /// The implementation used here is to convert `boxproc` values to function |
209 | /// pointers everywhere. If a `boxproc` value includes a frame pointer to the |
210 | /// host procedure's data, then a thunk will be created at runtime to capture |
211 | /// the frame pointer during execution. In LLVM IR, the frame pointer is |
212 | /// designated with the `nest` attribute. The thunk's address will then be used |
213 | /// as the call target instead of the original function's address directly. |
214 | class BoxedProcedurePass |
215 | : public fir::impl::BoxedProcedurePassBase<BoxedProcedurePass> { |
216 | public: |
217 | using BoxedProcedurePassBase<BoxedProcedurePass>::BoxedProcedurePassBase; |
218 | |
219 | inline mlir::ModuleOp getModule() { return getOperation(); } |
220 | |
221 | void runOnOperation() override final { |
222 | if (options.useThunks) { |
223 | auto *context = &getContext(); |
224 | mlir::IRRewriter rewriter(context); |
225 | BoxprocTypeRewriter typeConverter(mlir::UnknownLoc::get(context)); |
226 | getModule().walk([&](mlir::Operation *op) { |
227 | bool opIsValid = true; |
228 | typeConverter.setLocation(op->getLoc()); |
229 | if (auto addr = mlir::dyn_cast<BoxAddrOp>(op)) { |
230 | mlir::Type ty = addr.getVal().getType(); |
231 | mlir::Type resTy = addr.getResult().getType(); |
232 | if (llvm::isa<mlir::FunctionType>(ty) || |
233 | llvm::isa<fir::BoxProcType>(ty)) { |
234 | // Rewrite all `fir.box_addr` ops on values of type `!fir.boxproc` |
235 | // or function type to be `fir.convert` ops. |
236 | rewriter.setInsertionPoint(addr); |
237 | rewriter.replaceOpWithNewOp<ConvertOp>( |
238 | addr, typeConverter.convertType(addr.getType()), addr.getVal()); |
239 | opIsValid = false; |
240 | } else if (typeConverter.needsConversion(resTy)) { |
241 | rewriter.startOpModification(op); |
242 | op->getResult(0).setType(typeConverter.convertType(resTy)); |
243 | rewriter.finalizeOpModification(op); |
244 | } |
245 | } else if (auto func = mlir::dyn_cast<mlir::func::FuncOp>(op)) { |
246 | mlir::FunctionType ty = func.getFunctionType(); |
247 | if (typeConverter.needsConversion(ty)) { |
248 | rewriter.startOpModification(func); |
249 | auto toTy = |
250 | mlir::cast<mlir::FunctionType>(typeConverter.convertType(ty)); |
251 | if (!func.empty()) |
252 | for (auto e : llvm::enumerate(toTy.getInputs())) { |
253 | unsigned i = e.index(); |
254 | auto &block = func.front(); |
255 | block.insertArgument(i, e.value(), func.getLoc()); |
256 | block.getArgument(i + 1).replaceAllUsesWith( |
257 | block.getArgument(i)); |
258 | block.eraseArgument(i + 1); |
259 | } |
260 | func.setType(toTy); |
261 | rewriter.finalizeOpModification(func); |
262 | } |
263 | } else if (auto embox = mlir::dyn_cast<EmboxProcOp>(op)) { |
264 | // Rewrite all `fir.emboxproc` ops to either `fir.convert` or a thunk |
265 | // as required. |
266 | mlir::Type toTy = typeConverter.convertType( |
267 | mlir::cast<BoxProcType>(embox.getType()).getEleTy()); |
268 | rewriter.setInsertionPoint(embox); |
269 | if (embox.getHost()) { |
270 | // Create the thunk. |
271 | auto module = embox->getParentOfType<mlir::ModuleOp>(); |
272 | FirOpBuilder builder(rewriter, module); |
273 | const auto triple{fir::getTargetTriple(module)}; |
274 | auto loc = embox.getLoc(); |
275 | mlir::Type i8Ty = builder.getI8Type(); |
276 | mlir::Type i8Ptr = builder.getRefType(i8Ty); |
277 | // For AArch64, PPC32 and PPC64, the thunk is populated by a call to |
278 | // __trampoline_setup, which is defined in |
279 | // compiler-rt/lib/builtins/trampoline_setup.c and requires the |
280 | // thunk size greater than 32 bytes. For RISCV and x86_64, the |
281 | // thunk setup doesn't go through __trampoline_setup and fits in 32 |
282 | // bytes. |
283 | fir::SequenceType::Extent thunkSize = triple.getTrampolineSize(); |
284 | mlir::Type buffTy = SequenceType::get({thunkSize}, i8Ty); |
285 | auto buffer = builder.create<AllocaOp>(loc, buffTy); |
286 | mlir::Value closure = |
287 | builder.createConvert(loc, i8Ptr, embox.getHost()); |
288 | mlir::Value tramp = builder.createConvert(loc, i8Ptr, buffer); |
289 | mlir::Value func = |
290 | builder.createConvert(loc, i8Ptr, embox.getFunc()); |
291 | builder.create<fir::CallOp>( |
292 | loc, factory::getLlvmInitTrampoline(builder), |
293 | llvm::ArrayRef<mlir::Value>{tramp, func, closure}); |
294 | auto adjustCall = builder.create<fir::CallOp>( |
295 | loc, factory::getLlvmAdjustTrampoline(builder), |
296 | llvm::ArrayRef<mlir::Value>{tramp}); |
297 | rewriter.replaceOpWithNewOp<ConvertOp>(embox, toTy, |
298 | adjustCall.getResult(0)); |
299 | opIsValid = false; |
300 | } else { |
301 | // Just forward the function as a pointer. |
302 | rewriter.replaceOpWithNewOp<ConvertOp>(embox, toTy, |
303 | embox.getFunc()); |
304 | opIsValid = false; |
305 | } |
306 | } else if (auto global = mlir::dyn_cast<GlobalOp>(op)) { |
307 | auto ty = global.getType(); |
308 | if (typeConverter.needsConversion(ty)) { |
309 | rewriter.startOpModification(global); |
310 | auto toTy = typeConverter.convertType(ty); |
311 | global.setType(toTy); |
312 | rewriter.finalizeOpModification(global); |
313 | } |
314 | } else if (auto mem = mlir::dyn_cast<AllocaOp>(op)) { |
315 | auto ty = mem.getType(); |
316 | if (typeConverter.needsConversion(ty)) { |
317 | rewriter.setInsertionPoint(mem); |
318 | auto toTy = typeConverter.convertType(unwrapRefType(ty)); |
319 | bool isPinned = mem.getPinned(); |
320 | llvm::StringRef uniqName = |
321 | mem.getUniqName().value_or(llvm::StringRef()); |
322 | llvm::StringRef bindcName = |
323 | mem.getBindcName().value_or(llvm::StringRef()); |
324 | rewriter.replaceOpWithNewOp<AllocaOp>( |
325 | mem, toTy, uniqName, bindcName, isPinned, mem.getTypeparams(), |
326 | mem.getShape()); |
327 | opIsValid = false; |
328 | } |
329 | } else if (auto mem = mlir::dyn_cast<AllocMemOp>(op)) { |
330 | auto ty = mem.getType(); |
331 | if (typeConverter.needsConversion(ty)) { |
332 | rewriter.setInsertionPoint(mem); |
333 | auto toTy = typeConverter.convertType(unwrapRefType(ty)); |
334 | llvm::StringRef uniqName = |
335 | mem.getUniqName().value_or(llvm::StringRef()); |
336 | llvm::StringRef bindcName = |
337 | mem.getBindcName().value_or(llvm::StringRef()); |
338 | rewriter.replaceOpWithNewOp<AllocMemOp>( |
339 | mem, toTy, uniqName, bindcName, mem.getTypeparams(), |
340 | mem.getShape()); |
341 | opIsValid = false; |
342 | } |
343 | } else if (auto coor = mlir::dyn_cast<CoordinateOp>(op)) { |
344 | auto ty = coor.getType(); |
345 | mlir::Type baseTy = coor.getBaseType(); |
346 | if (typeConverter.needsConversion(ty) || |
347 | typeConverter.needsConversion(baseTy)) { |
348 | rewriter.setInsertionPoint(coor); |
349 | auto toTy = typeConverter.convertType(ty); |
350 | auto toBaseTy = typeConverter.convertType(baseTy); |
351 | rewriter.replaceOpWithNewOp<CoordinateOp>( |
352 | coor, toTy, coor.getRef(), coor.getCoor(), toBaseTy, |
353 | coor.getFieldIndicesAttr()); |
354 | opIsValid = false; |
355 | } |
356 | } else if (auto index = mlir::dyn_cast<FieldIndexOp>(op)) { |
357 | auto ty = index.getType(); |
358 | mlir::Type onTy = index.getOnType(); |
359 | if (typeConverter.needsConversion(ty) || |
360 | typeConverter.needsConversion(onTy)) { |
361 | rewriter.setInsertionPoint(index); |
362 | auto toTy = typeConverter.convertType(ty); |
363 | auto toOnTy = typeConverter.convertType(onTy); |
364 | rewriter.replaceOpWithNewOp<FieldIndexOp>( |
365 | index, toTy, index.getFieldId(), toOnTy, index.getTypeparams()); |
366 | opIsValid = false; |
367 | } |
368 | } else if (auto index = mlir::dyn_cast<LenParamIndexOp>(op)) { |
369 | auto ty = index.getType(); |
370 | mlir::Type onTy = index.getOnType(); |
371 | if (typeConverter.needsConversion(ty) || |
372 | typeConverter.needsConversion(onTy)) { |
373 | rewriter.setInsertionPoint(index); |
374 | auto toTy = typeConverter.convertType(ty); |
375 | auto toOnTy = typeConverter.convertType(onTy); |
376 | rewriter.replaceOpWithNewOp<LenParamIndexOp>( |
377 | index, toTy, index.getFieldId(), toOnTy, index.getTypeparams()); |
378 | opIsValid = false; |
379 | } |
380 | } else { |
381 | rewriter.startOpModification(op); |
382 | // Convert the operands if needed |
383 | for (auto i : llvm::enumerate(op->getResultTypes())) |
384 | if (typeConverter.needsConversion(i.value())) { |
385 | auto toTy = typeConverter.convertType(i.value()); |
386 | op->getResult(i.index()).setType(toTy); |
387 | } |
388 | |
389 | // Convert the type attributes if needed |
390 | for (const mlir::NamedAttribute &attr : op->getAttrDictionary()) |
391 | if (auto tyAttr = llvm::dyn_cast<mlir::TypeAttr>(attr.getValue())) |
392 | if (typeConverter.needsConversion(tyAttr.getValue())) { |
393 | auto toTy = typeConverter.convertType(tyAttr.getValue()); |
394 | op->setAttr(attr.getName(), mlir::TypeAttr::get(toTy)); |
395 | } |
396 | rewriter.finalizeOpModification(op); |
397 | } |
398 | // Ensure block arguments are updated if needed. |
399 | if (opIsValid && op->getNumRegions() != 0) { |
400 | rewriter.startOpModification(op); |
401 | for (mlir::Region ®ion : op->getRegions()) |
402 | for (mlir::Block &block : region.getBlocks()) |
403 | for (mlir::BlockArgument blockArg : block.getArguments()) |
404 | if (typeConverter.needsConversion(blockArg.getType())) { |
405 | mlir::Type toTy = |
406 | typeConverter.convertType(blockArg.getType()); |
407 | blockArg.setType(toTy); |
408 | } |
409 | rewriter.finalizeOpModification(op); |
410 | } |
411 | }); |
412 | } |
413 | } |
414 | |
415 | private: |
416 | BoxedProcedureOptions options; |
417 | }; |
418 | } // namespace |
419 | |