| 1 | //===-- BoxedProcedure.cpp ------------------------------------------------===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | |
| 9 | #include "flang/Optimizer/CodeGen/CodeGen.h" |
| 10 | |
| 11 | #include "flang/Optimizer/Builder/FIRBuilder.h" |
| 12 | #include "flang/Optimizer/Builder/LowLevelIntrinsics.h" |
| 13 | #include "flang/Optimizer/Dialect/FIRDialect.h" |
| 14 | #include "flang/Optimizer/Dialect/FIROps.h" |
| 15 | #include "flang/Optimizer/Dialect/FIRType.h" |
| 16 | #include "flang/Optimizer/Dialect/Support/FIRContext.h" |
| 17 | #include "flang/Optimizer/Support/FatalError.h" |
| 18 | #include "flang/Optimizer/Support/InternalNames.h" |
| 19 | #include "mlir/IR/PatternMatch.h" |
| 20 | #include "mlir/Pass/Pass.h" |
| 21 | #include "mlir/Transforms/DialectConversion.h" |
| 22 | #include "llvm/ADT/DenseMap.h" |
| 23 | |
| 24 | namespace fir { |
| 25 | #define GEN_PASS_DEF_BOXEDPROCEDUREPASS |
| 26 | #include "flang/Optimizer/CodeGen/CGPasses.h.inc" |
| 27 | } // namespace fir |
| 28 | |
| 29 | #define DEBUG_TYPE "flang-procedure-pointer" |
| 30 | |
| 31 | using namespace fir; |
| 32 | |
| 33 | namespace { |
/// Knobs controlling how the procedure pointer (boxproc) pass lowers
/// `!fir.boxproc` values.
struct BoxedProcedureOptions {
  /// When set, `!fir.boxproc` values are lowered to plain function pointers,
  /// with runtime thunks created wherever a host frame must be captured.
  bool useThunks{true};
};
| 40 | |
/// This type converter rewrites all `!fir.boxproc<Func>` types to `Func` types.
/// Types that contain a `!fir.boxproc` somewhere inside them (function
/// signatures, tuples, derived types, boxes, reference-like types, arrays and
/// type descriptors) are rebuilt with those parts converted; every other type
/// is returned unchanged.
class BoxprocTypeRewriter : public mlir::TypeConverter {
public:
  using mlir::TypeConverter::convertType;

  /// Does the type \p ty need to be converted?
  /// Any type that is a `!fir.boxproc` in whole or in part will need to be
  /// converted to a function type to lower the IR to function pointer form in
  /// the default implementation performed in this pass. Other implementations
  /// are possible, so those may convert `!fir.boxproc` to some other type or
  /// not at all depending on the implementation target's characteristics and
  /// preference.
  ///
  /// Returns true when \p ty is a `!fir.boxproc`, or transitively contains one
  /// through function inputs/results, tuple members, derived-type components,
  /// box element types, reference-like element types (per `isa_ref_type`),
  /// sequence element types, or the type described by a type descriptor.
  bool needsConversion(mlir::Type ty) {
    if (mlir::isa<BoxProcType>(ty))
      return true;
    if (auto funcTy = mlir::dyn_cast<mlir::FunctionType>(ty)) {
      for (auto t : funcTy.getInputs())
        if (needsConversion(t))
          return true;
      for (auto t : funcTy.getResults())
        if (needsConversion(t))
          return true;
      return false;
    }
    if (auto tupleTy = mlir::dyn_cast<mlir::TupleType>(ty)) {
      for (auto t : tupleTy.getTypes())
        if (needsConversion(t))
          return true;
      return false;
    }
    if (auto recTy = mlir::dyn_cast<RecordType>(ty)) {
      // Provisionally record `ty` as "no conversion needed" so that a
      // recursive reference back to it terminates instead of recursing
      // forever. An already-present entry is either a cached final answer
      // (for a previously visited top-level record) or this provisional
      // value (for a record still being visited further up the stack).
      auto [visited, inserted] = visitedTypes.try_emplace(ty, false);
      if (!inserted)
        return visited->second;
      bool wasAlreadyVisitingRecordType = needConversionIsVisitingRecordType;
      needConversionIsVisitingRecordType = true;
      bool result = false;
      for (auto t : recTy.getTypeList()) {
        if (needsConversion(t.second)) {
          result = true;
          break;
        }
      }
      // Only keep the result cached if the fir.type visited was a "top-level
      // type". Nested types with a recursive reference to the "top-level type"
      // may incorrectly have been resolved as not needed conversions because it
      // had not been determined yet if the "top-level type" needed conversion.
      // This is not an issue to determine the "top-level type" need of
      // conversion, but the result should not be kept and later used in other
      // contexts.
      needConversionIsVisitingRecordType = wasAlreadyVisitingRecordType;
      if (needConversionIsVisitingRecordType)
        visitedTypes.erase(ty);
      else
        // Fresh lookup: the recursive calls above may have grown the map and
        // invalidated the iterator obtained from try_emplace.
        visitedTypes.find(ty)->second = result;
      return result;
    }
    if (auto boxTy = mlir::dyn_cast<BaseBoxType>(ty))
      return needsConversion(boxTy.getEleTy());
    if (isa_ref_type(ty))
      return needsConversion(unwrapRefType(ty));
    if (auto t = mlir::dyn_cast<SequenceType>(ty))
      return needsConversion(unwrapSequenceType(ty));
    if (auto t = mlir::dyn_cast<TypeDescType>(ty))
      return needsConversion(t.getOfTy());
    return false;
  }

  /// Registers the type conversions and materializations of this converter.
  /// \p location is stored as the current location (see setLocation).
  BoxprocTypeRewriter(mlir::Location location) : loc{location} {
    // Identity fallback, registered first. NOTE: this relies on MLIR's
    // TypeConverter trying conversions in reverse registration order, so the
    // more specific conversions below take precedence over this one.
    addConversion([](mlir::Type ty) { return ty; });
    // `!fir.boxproc<F>` becomes (the converted form of) F.
    addConversion(
        [&](BoxProcType boxproc) { return convertType(boxproc.getEleTy()); });
    // Aggregate and wrapper types are rebuilt around their converted members.
    addConversion([&](mlir::TupleType tupTy) {
      llvm::SmallVector<mlir::Type> memTys;
      for (auto ty : tupTy.getTypes())
        memTys.push_back(convertType(ty));
      return mlir::TupleType::get(tupTy.getContext(), memTys);
    });
    addConversion([&](mlir::FunctionType funcTy) {
      llvm::SmallVector<mlir::Type> inTys;
      llvm::SmallVector<mlir::Type> resTys;
      for (auto ty : funcTy.getInputs())
        inTys.push_back(convertType(ty));
      for (auto ty : funcTy.getResults())
        resTys.push_back(convertType(ty));
      return mlir::FunctionType::get(funcTy.getContext(), inTys, resTys);
    });
    addConversion([&](ReferenceType ty) {
      return ReferenceType::get(convertType(ty.getEleTy()));
    });
    addConversion([&](PointerType ty) {
      return PointerType::get(convertType(ty.getEleTy()));
    });
    addConversion(
        [&](HeapType ty) { return HeapType::get(convertType(ty.getEleTy())); });
    addConversion([&](fir::LLVMPointerType ty) {
      return fir::LLVMPointerType::get(convertType(ty.getEleTy()));
    });
    addConversion(
        [&](BoxType ty) { return BoxType::get(convertType(ty.getEleTy())); });
    addConversion([&](ClassType ty) {
      return ClassType::get(convertType(ty.getEleTy()));
    });
    addConversion([&](SequenceType ty) {
      // TODO: add ty.getLayoutMap() as needed.
      return SequenceType::get(ty.getShape(), convertType(ty.getEleTy()));
    });
    // Derived types containing a boxproc are cloned into a new record type
    // whose name is the original name followed by `boxprocSuffix` (presumably
    // a constant defined elsewhere in the file/project).
    addConversion([&](RecordType ty) -> mlir::Type {
      if (!needsConversion(ty))
        return ty;
      if (auto converted = convertedTypes.lookup(ty))
        return converted;
      auto rec = RecordType::get(ty.getContext(),
                                 ty.getName().str() + boxprocSuffix.str());
      // A finalized record with that name already has its component list;
      // reuse it as is.
      if (rec.isFinalized())
        return rec;
      // Register the (not yet finalized) result before converting the
      // components, so a component referring back to `ty` resolves to `rec`
      // instead of recursing forever.
      [[maybe_unused]] auto it = convertedTypes.try_emplace(ty, rec);
      assert(it.second && "expected ty to not be in the map" );
      std::vector<RecordType::TypePair> ps = ty.getLenParamList();
      std::vector<RecordType::TypePair> cs;
      for (auto t : ty.getTypeList()) {
        if (needsConversion(t.second))
          cs.emplace_back(t.first, convertType(t.second));
        else
          cs.emplace_back(t.first, t.second);
      }
      rec.finalize(ps, cs);
      rec.pack(ty.isPacked());
      return rec;
    });
    addConversion([&](TypeDescType ty) {
      return TypeDescType::get(convertType(ty.getOfTy()));
    });
    addSourceMaterialization(materializeProcedure);
    addTargetMaterialization(materializeProcedure);
  }

  /// Materialization callback: bridges a single value to the procedure type
  /// by emitting a `fir.convert` to the boxproc's element type (with any
  /// reference wrapper stripped).
  /// NOTE(review): the created value has the element type of \p type, not
  /// \p type itself; confirm this is what the source/target materialization
  /// contract expects here.
  static mlir::Value materializeProcedure(mlir::OpBuilder &builder,
                                          BoxProcType type,
                                          mlir::ValueRange inputs,
                                          mlir::Location loc) {
    assert(inputs.size() == 1);
    return builder.create<ConvertOp>(loc, unwrapRefType(type.getEleTy()),
                                     inputs[0]);
  }

  /// Records the location of the operation currently being rewritten.
  /// NOTE(review): `loc` is not read anywhere in this class as shown.
  void setLocation(mlir::Location location) { loc = location; }

private:
  // Maps to deal with recursive derived types (avoid infinite loops).
  // Caching is also beneficial for apps with big types (dozens of
  // components and or parent types), so the lifetime of the cache
  // is the whole pass.
  // `visitedTypes`: record type -> whether it needs conversion.
  llvm::DenseMap<mlir::Type, bool> visitedTypes;
  // True while needsConversion is inside the component walk of a record.
  bool needConversionIsVisitingRecordType = false;
  // `convertedTypes`: original record type -> its converted record type.
  llvm::DenseMap<mlir::Type, mlir::Type> convertedTypes;
  mlir::Location loc;
};
| 199 | |
| 200 | /// A `boxproc` is an abstraction for a Fortran procedure reference. Typically, |
| 201 | /// Fortran procedures can be referenced directly through a function pointer. |
| 202 | /// However, Fortran has one-level dynamic scoping between a host procedure and |
| 203 | /// its internal procedures. This allows internal procedures to directly access |
| 204 | /// and modify the state of the host procedure's variables. |
| 205 | /// |
| 206 | /// There are any number of possible implementations possible. |
| 207 | /// |
| 208 | /// The implementation used here is to convert `boxproc` values to function |
| 209 | /// pointers everywhere. If a `boxproc` value includes a frame pointer to the |
| 210 | /// host procedure's data, then a thunk will be created at runtime to capture |
| 211 | /// the frame pointer during execution. In LLVM IR, the frame pointer is |
| 212 | /// designated with the `nest` attribute. The thunk's address will then be used |
| 213 | /// as the call target instead of the original function's address directly. |
| 214 | class BoxedProcedurePass |
| 215 | : public fir::impl::BoxedProcedurePassBase<BoxedProcedurePass> { |
| 216 | public: |
| 217 | using BoxedProcedurePassBase<BoxedProcedurePass>::BoxedProcedurePassBase; |
| 218 | |
| 219 | inline mlir::ModuleOp getModule() { return getOperation(); } |
| 220 | |
| 221 | void runOnOperation() override final { |
| 222 | if (options.useThunks) { |
| 223 | auto *context = &getContext(); |
| 224 | mlir::IRRewriter rewriter(context); |
| 225 | BoxprocTypeRewriter typeConverter(mlir::UnknownLoc::get(context)); |
| 226 | getModule().walk([&](mlir::Operation *op) { |
| 227 | bool opIsValid = true; |
| 228 | typeConverter.setLocation(op->getLoc()); |
| 229 | if (auto addr = mlir::dyn_cast<BoxAddrOp>(op)) { |
| 230 | mlir::Type ty = addr.getVal().getType(); |
| 231 | mlir::Type resTy = addr.getResult().getType(); |
| 232 | if (llvm::isa<mlir::FunctionType>(ty) || |
| 233 | llvm::isa<fir::BoxProcType>(ty)) { |
| 234 | // Rewrite all `fir.box_addr` ops on values of type `!fir.boxproc` |
| 235 | // or function type to be `fir.convert` ops. |
| 236 | rewriter.setInsertionPoint(addr); |
| 237 | rewriter.replaceOpWithNewOp<ConvertOp>( |
| 238 | addr, typeConverter.convertType(addr.getType()), addr.getVal()); |
| 239 | opIsValid = false; |
| 240 | } else if (typeConverter.needsConversion(resTy)) { |
| 241 | rewriter.startOpModification(op); |
| 242 | op->getResult(0).setType(typeConverter.convertType(resTy)); |
| 243 | rewriter.finalizeOpModification(op); |
| 244 | } |
| 245 | } else if (auto func = mlir::dyn_cast<mlir::func::FuncOp>(op)) { |
| 246 | mlir::FunctionType ty = func.getFunctionType(); |
| 247 | if (typeConverter.needsConversion(ty)) { |
| 248 | rewriter.startOpModification(func); |
| 249 | auto toTy = |
| 250 | mlir::cast<mlir::FunctionType>(typeConverter.convertType(ty)); |
| 251 | if (!func.empty()) |
| 252 | for (auto e : llvm::enumerate(toTy.getInputs())) { |
| 253 | unsigned i = e.index(); |
| 254 | auto &block = func.front(); |
| 255 | block.insertArgument(i, e.value(), func.getLoc()); |
| 256 | block.getArgument(i + 1).replaceAllUsesWith( |
| 257 | block.getArgument(i)); |
| 258 | block.eraseArgument(i + 1); |
| 259 | } |
| 260 | func.setType(toTy); |
| 261 | rewriter.finalizeOpModification(func); |
| 262 | } |
| 263 | } else if (auto embox = mlir::dyn_cast<EmboxProcOp>(op)) { |
| 264 | // Rewrite all `fir.emboxproc` ops to either `fir.convert` or a thunk |
| 265 | // as required. |
| 266 | mlir::Type toTy = typeConverter.convertType( |
| 267 | mlir::cast<BoxProcType>(embox.getType()).getEleTy()); |
| 268 | rewriter.setInsertionPoint(embox); |
| 269 | if (embox.getHost()) { |
| 270 | // Create the thunk. |
| 271 | auto module = embox->getParentOfType<mlir::ModuleOp>(); |
| 272 | FirOpBuilder builder(rewriter, module); |
| 273 | const auto triple{fir::getTargetTriple(module)}; |
| 274 | auto loc = embox.getLoc(); |
| 275 | mlir::Type i8Ty = builder.getI8Type(); |
| 276 | mlir::Type i8Ptr = builder.getRefType(i8Ty); |
| 277 | // For AArch64, PPC32 and PPC64, the thunk is populated by a call to |
| 278 | // __trampoline_setup, which is defined in |
| 279 | // compiler-rt/lib/builtins/trampoline_setup.c and requires the |
| 280 | // thunk size greater than 32 bytes. For RISCV and x86_64, the |
| 281 | // thunk setup doesn't go through __trampoline_setup and fits in 32 |
| 282 | // bytes. |
| 283 | fir::SequenceType::Extent thunkSize = triple.getTrampolineSize(); |
| 284 | mlir::Type buffTy = SequenceType::get({thunkSize}, i8Ty); |
| 285 | auto buffer = builder.create<AllocaOp>(loc, buffTy); |
| 286 | mlir::Value closure = |
| 287 | builder.createConvert(loc, i8Ptr, embox.getHost()); |
| 288 | mlir::Value tramp = builder.createConvert(loc, i8Ptr, buffer); |
| 289 | mlir::Value func = |
| 290 | builder.createConvert(loc, i8Ptr, embox.getFunc()); |
| 291 | builder.create<fir::CallOp>( |
| 292 | loc, factory::getLlvmInitTrampoline(builder), |
| 293 | llvm::ArrayRef<mlir::Value>{tramp, func, closure}); |
| 294 | auto adjustCall = builder.create<fir::CallOp>( |
| 295 | loc, factory::getLlvmAdjustTrampoline(builder), |
| 296 | llvm::ArrayRef<mlir::Value>{tramp}); |
| 297 | rewriter.replaceOpWithNewOp<ConvertOp>(embox, toTy, |
| 298 | adjustCall.getResult(0)); |
| 299 | opIsValid = false; |
| 300 | } else { |
| 301 | // Just forward the function as a pointer. |
| 302 | rewriter.replaceOpWithNewOp<ConvertOp>(embox, toTy, |
| 303 | embox.getFunc()); |
| 304 | opIsValid = false; |
| 305 | } |
| 306 | } else if (auto global = mlir::dyn_cast<GlobalOp>(op)) { |
| 307 | auto ty = global.getType(); |
| 308 | if (typeConverter.needsConversion(ty)) { |
| 309 | rewriter.startOpModification(global); |
| 310 | auto toTy = typeConverter.convertType(ty); |
| 311 | global.setType(toTy); |
| 312 | rewriter.finalizeOpModification(global); |
| 313 | } |
| 314 | } else if (auto mem = mlir::dyn_cast<AllocaOp>(op)) { |
| 315 | auto ty = mem.getType(); |
| 316 | if (typeConverter.needsConversion(ty)) { |
| 317 | rewriter.setInsertionPoint(mem); |
| 318 | auto toTy = typeConverter.convertType(unwrapRefType(ty)); |
| 319 | bool isPinned = mem.getPinned(); |
| 320 | llvm::StringRef uniqName = |
| 321 | mem.getUniqName().value_or(llvm::StringRef()); |
| 322 | llvm::StringRef bindcName = |
| 323 | mem.getBindcName().value_or(llvm::StringRef()); |
| 324 | rewriter.replaceOpWithNewOp<AllocaOp>( |
| 325 | mem, toTy, uniqName, bindcName, isPinned, mem.getTypeparams(), |
| 326 | mem.getShape()); |
| 327 | opIsValid = false; |
| 328 | } |
| 329 | } else if (auto mem = mlir::dyn_cast<AllocMemOp>(op)) { |
| 330 | auto ty = mem.getType(); |
| 331 | if (typeConverter.needsConversion(ty)) { |
| 332 | rewriter.setInsertionPoint(mem); |
| 333 | auto toTy = typeConverter.convertType(unwrapRefType(ty)); |
| 334 | llvm::StringRef uniqName = |
| 335 | mem.getUniqName().value_or(llvm::StringRef()); |
| 336 | llvm::StringRef bindcName = |
| 337 | mem.getBindcName().value_or(llvm::StringRef()); |
| 338 | rewriter.replaceOpWithNewOp<AllocMemOp>( |
| 339 | mem, toTy, uniqName, bindcName, mem.getTypeparams(), |
| 340 | mem.getShape()); |
| 341 | opIsValid = false; |
| 342 | } |
| 343 | } else if (auto coor = mlir::dyn_cast<CoordinateOp>(op)) { |
| 344 | auto ty = coor.getType(); |
| 345 | mlir::Type baseTy = coor.getBaseType(); |
| 346 | if (typeConverter.needsConversion(ty) || |
| 347 | typeConverter.needsConversion(baseTy)) { |
| 348 | rewriter.setInsertionPoint(coor); |
| 349 | auto toTy = typeConverter.convertType(ty); |
| 350 | auto toBaseTy = typeConverter.convertType(baseTy); |
| 351 | rewriter.replaceOpWithNewOp<CoordinateOp>( |
| 352 | coor, toTy, coor.getRef(), coor.getCoor(), toBaseTy, |
| 353 | coor.getFieldIndicesAttr()); |
| 354 | opIsValid = false; |
| 355 | } |
| 356 | } else if (auto index = mlir::dyn_cast<FieldIndexOp>(op)) { |
| 357 | auto ty = index.getType(); |
| 358 | mlir::Type onTy = index.getOnType(); |
| 359 | if (typeConverter.needsConversion(ty) || |
| 360 | typeConverter.needsConversion(onTy)) { |
| 361 | rewriter.setInsertionPoint(index); |
| 362 | auto toTy = typeConverter.convertType(ty); |
| 363 | auto toOnTy = typeConverter.convertType(onTy); |
| 364 | rewriter.replaceOpWithNewOp<FieldIndexOp>( |
| 365 | index, toTy, index.getFieldId(), toOnTy, index.getTypeparams()); |
| 366 | opIsValid = false; |
| 367 | } |
| 368 | } else if (auto index = mlir::dyn_cast<LenParamIndexOp>(op)) { |
| 369 | auto ty = index.getType(); |
| 370 | mlir::Type onTy = index.getOnType(); |
| 371 | if (typeConverter.needsConversion(ty) || |
| 372 | typeConverter.needsConversion(onTy)) { |
| 373 | rewriter.setInsertionPoint(index); |
| 374 | auto toTy = typeConverter.convertType(ty); |
| 375 | auto toOnTy = typeConverter.convertType(onTy); |
| 376 | rewriter.replaceOpWithNewOp<LenParamIndexOp>( |
| 377 | index, toTy, index.getFieldId(), toOnTy, index.getTypeparams()); |
| 378 | opIsValid = false; |
| 379 | } |
| 380 | } else { |
| 381 | rewriter.startOpModification(op); |
| 382 | // Convert the operands if needed |
| 383 | for (auto i : llvm::enumerate(op->getResultTypes())) |
| 384 | if (typeConverter.needsConversion(i.value())) { |
| 385 | auto toTy = typeConverter.convertType(i.value()); |
| 386 | op->getResult(i.index()).setType(toTy); |
| 387 | } |
| 388 | |
| 389 | // Convert the type attributes if needed |
| 390 | for (const mlir::NamedAttribute &attr : op->getAttrDictionary()) |
| 391 | if (auto tyAttr = llvm::dyn_cast<mlir::TypeAttr>(attr.getValue())) |
| 392 | if (typeConverter.needsConversion(tyAttr.getValue())) { |
| 393 | auto toTy = typeConverter.convertType(tyAttr.getValue()); |
| 394 | op->setAttr(attr.getName(), mlir::TypeAttr::get(toTy)); |
| 395 | } |
| 396 | rewriter.finalizeOpModification(op); |
| 397 | } |
| 398 | // Ensure block arguments are updated if needed. |
| 399 | if (opIsValid && op->getNumRegions() != 0) { |
| 400 | rewriter.startOpModification(op); |
| 401 | for (mlir::Region ®ion : op->getRegions()) |
| 402 | for (mlir::Block &block : region.getBlocks()) |
| 403 | for (mlir::BlockArgument blockArg : block.getArguments()) |
| 404 | if (typeConverter.needsConversion(blockArg.getType())) { |
| 405 | mlir::Type toTy = |
| 406 | typeConverter.convertType(blockArg.getType()); |
| 407 | blockArg.setType(toTy); |
| 408 | } |
| 409 | rewriter.finalizeOpModification(op); |
| 410 | } |
| 411 | }); |
| 412 | } |
| 413 | } |
| 414 | |
| 415 | private: |
| 416 | BoxedProcedureOptions options; |
| 417 | }; |
| 418 | } // namespace |
| 419 | |