1//===- AbstractResult.cpp - Conversion of Abstract Function Result --------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "flang/Optimizer/Builder/FIRBuilder.h"
10#include "flang/Optimizer/Builder/Todo.h"
11#include "flang/Optimizer/Dialect/FIRDialect.h"
12#include "flang/Optimizer/Dialect/FIROps.h"
13#include "flang/Optimizer/Dialect/FIRType.h"
14#include "flang/Optimizer/Dialect/Support/FIRContext.h"
15#include "flang/Optimizer/Transforms/Passes.h"
16#include "mlir/Dialect/Func/IR/FuncOps.h"
17#include "mlir/Dialect/GPU/IR/GPUDialect.h"
18#include "mlir/IR/Diagnostics.h"
19#include "mlir/Pass/Pass.h"
20#include "mlir/Pass/PassManager.h"
21#include "mlir/Transforms/DialectConversion.h"
22#include "llvm/ADT/TypeSwitch.h"
23
24namespace fir {
25#define GEN_PASS_DEF_ABSTRACTRESULTOPT
26#include "flang/Optimizer/Transforms/Passes.h.inc"
27} // namespace fir
28
29#define DEBUG_TYPE "flang-abstract-result-opt"
30
31using namespace mlir;
32
33namespace fir {
34namespace {
35
36// Helper to only build the symbol table if needed because its build time is
37// linear on the number of symbols in the module.
38struct LazySymbolTable {
39 LazySymbolTable(mlir::Operation *op)
40 : module{op->getParentOfType<mlir::ModuleOp>()} {}
41 void build() {
42 if (table)
43 return;
44 table = std::make_unique<mlir::SymbolTable>(module);
45 }
46
47 template <typename T>
48 T lookup(llvm::StringRef name) {
49 build();
50 return table->lookup<T>(name);
51 }
52
53private:
54 std::unique_ptr<mlir::SymbolTable> table;
55 mlir::ModuleOp module;
56};
57
58bool hasScalarDerivedResult(mlir::FunctionType funTy) {
59 // C_PTR/C_FUNPTR are results to void* in this pass, do not consider
60 // them as normal derived types.
61 return funTy.getNumResults() == 1 &&
62 mlir::isa<fir::RecordType>(funTy.getResult(0)) &&
63 !fir::isa_builtin_cptr_type(funTy.getResult(0));
64}
65
66static mlir::Type getResultArgumentType(mlir::Type resultType,
67 bool shouldBoxResult) {
68 return llvm::TypeSwitch<mlir::Type, mlir::Type>(resultType)
69 .Case<fir::SequenceType, fir::RecordType>(
70 [&](mlir::Type type) -> mlir::Type {
71 if (shouldBoxResult)
72 return fir::BoxType::get(type);
73 return fir::ReferenceType::get(type);
74 })
75 .Case<fir::BaseBoxType>([](mlir::Type type) -> mlir::Type {
76 return fir::ReferenceType::get(type);
77 })
78 .Default([](mlir::Type) -> mlir::Type {
79 llvm_unreachable("bad abstract result type");
80 });
81}
82
83static mlir::FunctionType getNewFunctionType(mlir::FunctionType funcTy,
84 bool shouldBoxResult) {
85 auto resultType = funcTy.getResult(0);
86 auto argTy = getResultArgumentType(resultType, shouldBoxResult);
87 llvm::SmallVector<mlir::Type> newInputTypes = {argTy};
88 newInputTypes.append(funcTy.getInputs().begin(), funcTy.getInputs().end());
89 return mlir::FunctionType::get(funcTy.getContext(), newInputTypes,
90 /*resultTypes=*/{});
91}
92
93static mlir::Type getVoidPtrType(mlir::MLIRContext *context) {
94 return fir::ReferenceType::get(mlir::NoneType::get(context));
95}
96
97/// This is for function result types that are of type C_PTR from ISO_C_BINDING.
98/// Follow the ABI for interoperability with C.
99static mlir::FunctionType getCPtrFunctionType(mlir::FunctionType funcTy) {
100 assert(fir::isa_builtin_cptr_type(funcTy.getResult(0)));
101 llvm::SmallVector<mlir::Type> outputTypes{
102 getVoidPtrType(funcTy.getContext())};
103 return mlir::FunctionType::get(funcTy.getContext(), funcTy.getInputs(),
104 outputTypes);
105}
106
107static bool mustEmboxResult(mlir::Type resultType, bool shouldBoxResult) {
108 return mlir::isa<fir::SequenceType, fir::RecordType>(resultType) &&
109 shouldBoxResult;
110}
111
112template <typename Op>
113class CallConversion : public mlir::OpRewritePattern<Op> {
114public:
115 using mlir::OpRewritePattern<Op>::OpRewritePattern;
116
117 CallConversion(mlir::MLIRContext *context, bool shouldBoxResult)
118 : OpRewritePattern<Op>(context, 1), shouldBoxResult{shouldBoxResult} {}
119
120 llvm::LogicalResult
121 matchAndRewrite(Op op, mlir::PatternRewriter &rewriter) const override {
122 auto loc = op.getLoc();
123 auto result = op->getResult(0);
124 if (!result.hasOneUse()) {
125 mlir::emitError(loc,
126 "calls with abstract result must have exactly one user");
127 return mlir::failure();
128 }
129 auto saveResult =
130 mlir::dyn_cast<fir::SaveResultOp>(result.use_begin().getUser());
131 if (!saveResult) {
132 mlir::emitError(
133 loc, "calls with abstract result must be used in fir.save_result");
134 return mlir::failure();
135 }
136 auto argType = getResultArgumentType(result.getType(), shouldBoxResult);
137 auto buffer = saveResult.getMemref();
138 mlir::Value arg = buffer;
139 if (mustEmboxResult(result.getType(), shouldBoxResult))
140 arg = rewriter.create<fir::EmboxOp>(
141 loc, argType, buffer, saveResult.getShape(), /*slice*/ mlir::Value{},
142 saveResult.getTypeparams());
143
144 llvm::SmallVector<mlir::Type> newResultTypes;
145 bool isResultBuiltinCPtr = fir::isa_builtin_cptr_type(result.getType());
146 if (isResultBuiltinCPtr)
147 newResultTypes.emplace_back(getVoidPtrType(result.getContext()));
148
149 Op newOp;
150 // TODO: propagate argument and result attributes (need to be shifted).
151 // fir::CallOp specific handling.
152 if constexpr (std::is_same_v<Op, fir::CallOp>) {
153 if (op.getCallee()) {
154 llvm::SmallVector<mlir::Value> newOperands;
155 if (!isResultBuiltinCPtr)
156 newOperands.emplace_back(arg);
157 newOperands.append(op.getOperands().begin(), op.getOperands().end());
158 newOp = rewriter.create<fir::CallOp>(loc, *op.getCallee(),
159 newResultTypes, newOperands);
160 } else {
161 // Indirect calls.
162 llvm::SmallVector<mlir::Type> newInputTypes;
163 if (!isResultBuiltinCPtr)
164 newInputTypes.emplace_back(argType);
165 for (auto operand : op.getOperands().drop_front())
166 newInputTypes.push_back(operand.getType());
167 auto newFuncTy = mlir::FunctionType::get(op.getContext(), newInputTypes,
168 newResultTypes);
169
170 llvm::SmallVector<mlir::Value> newOperands;
171 newOperands.push_back(
172 rewriter.create<fir::ConvertOp>(loc, newFuncTy, op.getOperand(0)));
173 if (!isResultBuiltinCPtr)
174 newOperands.push_back(arg);
175 newOperands.append(op.getOperands().begin() + 1,
176 op.getOperands().end());
177 newOp = rewriter.create<fir::CallOp>(loc, mlir::SymbolRefAttr{},
178 newResultTypes, newOperands);
179 }
180 }
181
182 // fir::DispatchOp specific handling.
183 if constexpr (std::is_same_v<Op, fir::DispatchOp>) {
184 llvm::SmallVector<mlir::Value> newOperands;
185 if (!isResultBuiltinCPtr)
186 newOperands.emplace_back(arg);
187 unsigned passArgShift = newOperands.size();
188 newOperands.append(op.getOperands().begin() + 1, op.getOperands().end());
189 mlir::IntegerAttr passArgPos;
190 if (op.getPassArgPos())
191 passArgPos =
192 rewriter.getI32IntegerAttr(*op.getPassArgPos() + passArgShift);
193 // TODO: propagate argument and result attributes (need to be shifted).
194 newOp = rewriter.create<fir::DispatchOp>(
195 loc, newResultTypes, rewriter.getStringAttr(op.getMethod()),
196 op.getOperands()[0], newOperands, passArgPos,
197 /*arg_attrs=*/nullptr, /*res_attrs=*/nullptr,
198 op.getProcedureAttrsAttr());
199 }
200
201 if (isResultBuiltinCPtr) {
202 mlir::Value save = saveResult.getMemref();
203 auto module = op->template getParentOfType<mlir::ModuleOp>();
204 FirOpBuilder builder(rewriter, module);
205 mlir::Value saveAddr = fir::factory::genCPtrOrCFunptrAddr(
206 builder, loc, save, result.getType());
207 builder.createStoreWithConvert(loc, newOp->getResult(0), saveAddr);
208 }
209 op->dropAllReferences();
210 rewriter.eraseOp(op);
211 return mlir::success();
212 }
213
214private:
215 bool shouldBoxResult;
216};
217
218class SaveResultOpConversion
219 : public mlir::OpRewritePattern<fir::SaveResultOp> {
220public:
221 using OpRewritePattern::OpRewritePattern;
222 SaveResultOpConversion(mlir::MLIRContext *context)
223 : OpRewritePattern(context) {}
224 llvm::LogicalResult
225 matchAndRewrite(fir::SaveResultOp op,
226 mlir::PatternRewriter &rewriter) const override {
227 mlir::Operation *call = op.getValue().getDefiningOp();
228 mlir::Type type = op.getValue().getType();
229 if (mlir::isa<fir::RecordType>(type) && call && fir::hasBindcAttr(call) &&
230 !fir::isa_builtin_cptr_type(type)) {
231 rewriter.replaceOpWithNewOp<fir::StoreOp>(op, op.getValue(),
232 op.getMemref());
233 } else {
234 rewriter.eraseOp(op);
235 }
236 return mlir::success();
237 }
238};
239
240template <typename OpTy>
241static mlir::LogicalResult
242processReturnLikeOp(OpTy ret, mlir::Value newArg,
243 mlir::PatternRewriter &rewriter) {
244 auto loc = ret.getLoc();
245 rewriter.setInsertionPoint(ret);
246 mlir::Value resultValue = ret.getOperand(0);
247 fir::LoadOp resultLoad;
248 mlir::Value resultStorage;
249 // Identify result local storage.
250 if (auto load = resultValue.getDefiningOp<fir::LoadOp>()) {
251 resultLoad = load;
252 resultStorage = load.getMemref();
253 // The result alloca may be behind a fir.declare, if any.
254 if (auto declare = resultStorage.getDefiningOp<fir::DeclareOp>())
255 resultStorage = declare.getMemref();
256 }
257 // Replace old local storage with new storage argument, unless
258 // the derived type is C_PTR/C_FUN_PTR, in which case the return
259 // type is updated to return void* (no new argument is passed).
260 if (fir::isa_builtin_cptr_type(resultValue.getType())) {
261 auto module = ret->template getParentOfType<mlir::ModuleOp>();
262 FirOpBuilder builder(rewriter, module);
263 mlir::Value cptr = resultValue;
264 if (resultLoad) {
265 // Replace whole derived type load by component load.
266 cptr = resultLoad.getMemref();
267 rewriter.setInsertionPoint(resultLoad);
268 }
269 mlir::Value newResultValue =
270 fir::factory::genCPtrOrCFunptrValue(builder, loc, cptr);
271 newResultValue = builder.createConvert(
272 loc, getVoidPtrType(ret.getContext()), newResultValue);
273 rewriter.setInsertionPoint(ret);
274 rewriter.replaceOpWithNewOp<OpTy>(ret, mlir::ValueRange{newResultValue});
275 } else if (resultStorage) {
276 resultStorage.replaceAllUsesWith(newArg);
277 rewriter.replaceOpWithNewOp<OpTy>(ret);
278 } else {
279 // The result storage may have been optimized out by a memory to
280 // register pass, this is possible for fir.box results, or fir.record
281 // with no length parameters. Simply store the result in the result
282 // storage. at the return point.
283 rewriter.create<fir::StoreOp>(loc, resultValue, newArg);
284 rewriter.replaceOpWithNewOp<OpTy>(ret);
285 }
286 // Delete result old local storage if unused.
287 if (resultStorage)
288 if (auto alloc = resultStorage.getDefiningOp<fir::AllocaOp>())
289 if (alloc->use_empty())
290 rewriter.eraseOp(alloc);
291 return mlir::success();
292}
293
294class ReturnOpConversion : public mlir::OpRewritePattern<mlir::func::ReturnOp> {
295public:
296 using OpRewritePattern::OpRewritePattern;
297 ReturnOpConversion(mlir::MLIRContext *context, mlir::Value newArg)
298 : OpRewritePattern(context), newArg{newArg} {}
299 llvm::LogicalResult
300 matchAndRewrite(mlir::func::ReturnOp ret,
301 mlir::PatternRewriter &rewriter) const override {
302 return processReturnLikeOp(ret, newArg, rewriter);
303 }
304
305private:
306 mlir::Value newArg;
307};
308
309class GPUReturnOpConversion
310 : public mlir::OpRewritePattern<mlir::gpu::ReturnOp> {
311public:
312 using OpRewritePattern::OpRewritePattern;
313 GPUReturnOpConversion(mlir::MLIRContext *context, mlir::Value newArg)
314 : OpRewritePattern(context), newArg{newArg} {}
315 llvm::LogicalResult
316 matchAndRewrite(mlir::gpu::ReturnOp ret,
317 mlir::PatternRewriter &rewriter) const override {
318 return processReturnLikeOp(ret, newArg, rewriter);
319 }
320
321private:
322 mlir::Value newArg;
323};
324
325class AddrOfOpConversion : public mlir::OpRewritePattern<fir::AddrOfOp> {
326public:
327 using OpRewritePattern::OpRewritePattern;
328 AddrOfOpConversion(mlir::MLIRContext *context, bool shouldBoxResult)
329 : OpRewritePattern(context), shouldBoxResult{shouldBoxResult} {}
330 llvm::LogicalResult
331 matchAndRewrite(fir::AddrOfOp addrOf,
332 mlir::PatternRewriter &rewriter) const override {
333 auto oldFuncTy = mlir::cast<mlir::FunctionType>(addrOf.getType());
334 mlir::FunctionType newFuncTy;
335 if (oldFuncTy.getNumResults() != 0 &&
336 fir::isa_builtin_cptr_type(oldFuncTy.getResult(0)))
337 newFuncTy = getCPtrFunctionType(oldFuncTy);
338 else
339 newFuncTy = getNewFunctionType(oldFuncTy, shouldBoxResult);
340 auto newAddrOf = rewriter.create<fir::AddrOfOp>(addrOf.getLoc(), newFuncTy,
341 addrOf.getSymbol());
342 // Rather than converting all op a function pointer might transit through
343 // (e.g calls, stores, loads, converts...), cast new type to the abstract
344 // type. A conversion will be added when calling indirect calls of abstract
345 // types.
346 rewriter.replaceOpWithNewOp<fir::ConvertOp>(addrOf, oldFuncTy, newAddrOf);
347 return mlir::success();
348 }
349
350private:
351 bool shouldBoxResult;
352};
353
354class AbstractResultOpt
355 : public fir::impl::AbstractResultOptBase<AbstractResultOpt> {
356public:
357 using fir::impl::AbstractResultOptBase<
358 AbstractResultOpt>::AbstractResultOptBase;
359
360 template <typename OpTy>
361 void runOnFunctionLikeOperation(OpTy func, bool shouldBoxResult,
362 mlir::RewritePatternSet &patterns,
363 mlir::ConversionTarget &target) {
364 auto loc = func.getLoc();
365 auto *context = &getContext();
366 // Convert function type itself if it has an abstract result.
367 auto funcTy = mlir::cast<mlir::FunctionType>(func.getFunctionType());
368 // Scalar derived result of BIND(C) function must be returned according
369 // to the C struct return ABI which is target dependent and implemented in
370 // the target-rewrite pass.
371 if (hasScalarDerivedResult(funcTy) &&
372 fir::hasBindcAttr(func.getOperation()))
373 return;
374 if (hasAbstractResult(funcTy)) {
375 if (fir::isa_builtin_cptr_type(funcTy.getResult(0))) {
376 func.setType(getCPtrFunctionType(funcTy));
377 patterns.insert<ReturnOpConversion>(context, mlir::Value{});
378 target.addDynamicallyLegalOp<mlir::func::ReturnOp>(
379 [](mlir::func::ReturnOp ret) {
380 mlir::Type retTy = ret.getOperand(0).getType();
381 return !fir::isa_builtin_cptr_type(retTy);
382 });
383 return;
384 }
385 if (!func.empty()) {
386 // Insert new argument.
387 mlir::OpBuilder rewriter(context);
388 auto resultType = funcTy.getResult(0);
389 auto argTy = getResultArgumentType(resultType, shouldBoxResult);
390 llvm::LogicalResult res = func.insertArgument(0u, argTy, {}, loc);
391 (void)res;
392 assert(llvm::succeeded(res) && "failed to insert function argument");
393 res = func.eraseResult(0u);
394 (void)res;
395 assert(llvm::succeeded(res) && "failed to erase function result");
396 mlir::Value newArg = func.getArgument(0u);
397 if (mustEmboxResult(resultType, shouldBoxResult)) {
398 auto bufferType = fir::ReferenceType::get(resultType);
399 rewriter.setInsertionPointToStart(&func.front());
400 newArg = rewriter.create<fir::BoxAddrOp>(loc, bufferType, newArg);
401 }
402 patterns.insert<ReturnOpConversion>(context, newArg);
403 target.addDynamicallyLegalOp<mlir::func::ReturnOp>(
404 [](mlir::func::ReturnOp ret) { return ret.getOperands().empty(); });
405 patterns.insert<GPUReturnOpConversion>(context, newArg);
406 target.addDynamicallyLegalOp<mlir::gpu::ReturnOp>(
407 [](mlir::gpu::ReturnOp ret) { return ret.getOperands().empty(); });
408 assert(func.getFunctionType() ==
409 getNewFunctionType(funcTy, shouldBoxResult));
410 } else {
411 llvm::SmallVector<mlir::DictionaryAttr> allArgs;
412 func.getAllArgAttrs(allArgs);
413 allArgs.insert(allArgs.begin(),
414 mlir::DictionaryAttr::get(func->getContext()));
415 func.setType(getNewFunctionType(funcTy, shouldBoxResult));
416 func.setAllArgAttrs(allArgs);
417 }
418 }
419 }
420
421 void runOnSpecificOperation(mlir::func::FuncOp func, bool shouldBoxResult,
422 mlir::RewritePatternSet &patterns,
423 mlir::ConversionTarget &target) {
424 runOnFunctionLikeOperation(func, shouldBoxResult, patterns, target);
425 }
426
427 void runOnSpecificOperation(mlir::gpu::GPUFuncOp func, bool shouldBoxResult,
428 mlir::RewritePatternSet &patterns,
429 mlir::ConversionTarget &target) {
430 runOnFunctionLikeOperation(func, shouldBoxResult, patterns, target);
431 }
432
433 inline static bool containsFunctionTypeWithAbstractResult(mlir::Type type) {
434 return mlir::TypeSwitch<mlir::Type, bool>(type)
435 .Case([](fir::BoxProcType boxProc) {
436 return fir::hasAbstractResult(
437 mlir::cast<mlir::FunctionType>(boxProc.getEleTy()));
438 })
439 .Case([](fir::PointerType pointer) {
440 return fir::hasAbstractResult(
441 mlir::cast<mlir::FunctionType>(pointer.getEleTy()));
442 })
443 .Default([](auto &&) { return false; });
444 }
445
446 void runOnSpecificOperation(fir::GlobalOp global, bool,
447 mlir::RewritePatternSet &,
448 mlir::ConversionTarget &) {
449 if (containsFunctionTypeWithAbstractResult(global.getType())) {
450 TODO(global->getLoc(), "support for procedure pointers");
451 }
452 }
453
454 /// Run the pass on a ModuleOp. This makes fir-opt --abstract-result work.
455 void runOnModule() {
456 mlir::ModuleOp mod = mlir::cast<mlir::ModuleOp>(getOperation());
457
458 auto pass = std::make_unique<AbstractResultOpt>();
459 pass->copyOptionValuesFrom(this);
460 mlir::OpPassManager pipeline;
461 pipeline.addPass(std::unique_ptr<mlir::Pass>{pass.release()});
462
463 // Run the pass on all operations directly nested inside of the ModuleOp
464 // we can't just call runOnSpecificOperation here because the pass
465 // implementation only works when scoped to a particular func.func or
466 // fir.global
467 for (mlir::Region &region : mod->getRegions()) {
468 for (mlir::Block &block : region.getBlocks()) {
469 for (mlir::Operation &op : block.getOperations()) {
470 if (mlir::failed(runPipeline(pipeline, &op))) {
471 mlir::emitError(op.getLoc(), "Failed to run abstract result pass");
472 signalPassFailure();
473 return;
474 }
475 }
476 }
477 }
478 }
479
480 void runOnOperation() override {
481 auto *context = &this->getContext();
482 mlir::Operation *op = this->getOperation();
483 if (mlir::isa<mlir::ModuleOp>(op)) {
484 runOnModule();
485 return;
486 }
487
488 LazySymbolTable symbolTable(op);
489
490 mlir::RewritePatternSet patterns(context);
491 mlir::ConversionTarget target = *context;
492 const bool shouldBoxResult = this->passResultAsBox.getValue();
493
494 mlir::TypeSwitch<mlir::Operation *, void>(op)
495 .Case<mlir::func::FuncOp, fir::GlobalOp, mlir::gpu::GPUFuncOp>(
496 [&](auto op) {
497 runOnSpecificOperation(op, shouldBoxResult, patterns, target);
498 });
499
500 // Convert the calls and, if needed, the ReturnOp in the function body.
501 target.addLegalDialect<fir::FIROpsDialect, mlir::arith::ArithDialect,
502 mlir::func::FuncDialect>();
503 target.addIllegalOp<fir::SaveResultOp>();
504 target.addDynamicallyLegalOp<fir::CallOp>([](fir::CallOp call) {
505 mlir::FunctionType funTy = call.getFunctionType();
506 if (hasScalarDerivedResult(funTy) &&
507 fir::hasBindcAttr(call.getOperation()))
508 return true;
509 return !hasAbstractResult(funTy);
510 });
511 target.addDynamicallyLegalOp<fir::AddrOfOp>([&symbolTable](
512 fir::AddrOfOp addrOf) {
513 if (auto funTy = mlir::dyn_cast<mlir::FunctionType>(addrOf.getType())) {
514 if (hasScalarDerivedResult(funTy)) {
515 auto func = symbolTable.lookup<mlir::func::FuncOp>(
516 addrOf.getSymbol().getRootReference().getValue());
517 return func && fir::hasBindcAttr(func.getOperation());
518 }
519 return !hasAbstractResult(funTy);
520 }
521 return true;
522 });
523 target.addDynamicallyLegalOp<fir::DispatchOp>([](fir::DispatchOp dispatch) {
524 mlir::FunctionType funTy = dispatch.getFunctionType();
525 if (hasScalarDerivedResult(funTy) &&
526 fir::hasBindcAttr(dispatch.getOperation()))
527 return true;
528 return !hasAbstractResult(dispatch.getFunctionType());
529 });
530
531 patterns.insert<CallConversion<fir::CallOp>>(context, shouldBoxResult);
532 patterns.insert<CallConversion<fir::DispatchOp>>(context, shouldBoxResult);
533 patterns.insert<SaveResultOpConversion>(context);
534 patterns.insert<AddrOfOpConversion>(context, shouldBoxResult);
535 if (mlir::failed(
536 mlir::applyPartialConversion(op, target, std::move(patterns)))) {
537 mlir::emitError(op->getLoc(), "error in converting abstract results\n");
538 this->signalPassFailure();
539 }
540 }
541};
542
543} // end anonymous namespace
544} // namespace fir
545

source code of flang/lib/Optimizer/Transforms/AbstractResult.cpp