//===- AbstractResult.cpp - Conversion of Abstract Function Result --------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
8
9#include "flang/Optimizer/Builder/FIRBuilder.h"
10#include "flang/Optimizer/Builder/Todo.h"
11#include "flang/Optimizer/Dialect/FIRDialect.h"
12#include "flang/Optimizer/Dialect/FIROps.h"
13#include "flang/Optimizer/Dialect/FIRType.h"
14#include "flang/Optimizer/Dialect/Support/FIRContext.h"
15#include "flang/Optimizer/Transforms/Passes.h"
16#include "mlir/Dialect/Func/IR/FuncOps.h"
17#include "mlir/IR/Diagnostics.h"
18#include "mlir/Pass/Pass.h"
19#include "mlir/Pass/PassManager.h"
20#include "mlir/Transforms/DialectConversion.h"
21#include "llvm/ADT/TypeSwitch.h"
22
23namespace fir {
24#define GEN_PASS_DEF_ABSTRACTRESULTOPT
25#include "flang/Optimizer/Transforms/Passes.h.inc"
26} // namespace fir
27
28#define DEBUG_TYPE "flang-abstract-result-opt"
29
30using namespace mlir;
31
32namespace fir {
33namespace {
34
35static mlir::Type getResultArgumentType(mlir::Type resultType,
36 bool shouldBoxResult) {
37 return llvm::TypeSwitch<mlir::Type, mlir::Type>(resultType)
38 .Case<fir::SequenceType, fir::RecordType>(
39 [&](mlir::Type type) -> mlir::Type {
40 if (shouldBoxResult)
41 return fir::BoxType::get(type);
42 return fir::ReferenceType::get(type);
43 })
44 .Case<fir::BaseBoxType>([](mlir::Type type) -> mlir::Type {
45 return fir::ReferenceType::get(type);
46 })
47 .Default([](mlir::Type) -> mlir::Type {
48 llvm_unreachable("bad abstract result type");
49 });
50}
51
52static mlir::FunctionType getNewFunctionType(mlir::FunctionType funcTy,
53 bool shouldBoxResult) {
54 auto resultType = funcTy.getResult(0);
55 auto argTy = getResultArgumentType(resultType, shouldBoxResult);
56 llvm::SmallVector<mlir::Type> newInputTypes = {argTy};
57 newInputTypes.append(funcTy.getInputs().begin(), funcTy.getInputs().end());
58 return mlir::FunctionType::get(funcTy.getContext(), newInputTypes,
59 /*resultTypes=*/{});
60}
61
62/// This is for function result types that are of type C_PTR from ISO_C_BINDING.
63/// Follow the ABI for interoperability with C.
64static mlir::FunctionType getCPtrFunctionType(mlir::FunctionType funcTy) {
65 auto resultType = funcTy.getResult(0);
66 assert(fir::isa_builtin_cptr_type(resultType));
67 llvm::SmallVector<mlir::Type> outputTypes;
68 auto recTy = resultType.dyn_cast<fir::RecordType>();
69 outputTypes.emplace_back(recTy.getTypeList()[0].second);
70 return mlir::FunctionType::get(funcTy.getContext(), funcTy.getInputs(),
71 outputTypes);
72}
73
74static bool mustEmboxResult(mlir::Type resultType, bool shouldBoxResult) {
75 return resultType.isa<fir::SequenceType, fir::RecordType>() &&
76 shouldBoxResult;
77}
78
79template <typename Op>
80class CallConversion : public mlir::OpRewritePattern<Op> {
81public:
82 using mlir::OpRewritePattern<Op>::OpRewritePattern;
83
84 CallConversion(mlir::MLIRContext *context, bool shouldBoxResult)
85 : OpRewritePattern<Op>(context, 1), shouldBoxResult{shouldBoxResult} {}
86
87 mlir::LogicalResult
88 matchAndRewrite(Op op, mlir::PatternRewriter &rewriter) const override {
89 auto loc = op.getLoc();
90 auto result = op->getResult(0);
91 if (!result.hasOneUse()) {
92 mlir::emitError(loc,
93 "calls with abstract result must have exactly one user");
94 return mlir::failure();
95 }
96 auto saveResult =
97 mlir::dyn_cast<fir::SaveResultOp>(result.use_begin().getUser());
98 if (!saveResult) {
99 mlir::emitError(
100 loc, "calls with abstract result must be used in fir.save_result");
101 return mlir::failure();
102 }
103 auto argType = getResultArgumentType(result.getType(), shouldBoxResult);
104 auto buffer = saveResult.getMemref();
105 mlir::Value arg = buffer;
106 if (mustEmboxResult(result.getType(), shouldBoxResult))
107 arg = rewriter.create<fir::EmboxOp>(
108 loc, argType, buffer, saveResult.getShape(), /*slice*/ mlir::Value{},
109 saveResult.getTypeparams());
110
111 llvm::SmallVector<mlir::Type> newResultTypes;
112 // TODO: This should be generalized for derived types, and it is
113 // architecture and OS dependent.
114 bool isResultBuiltinCPtr = fir::isa_builtin_cptr_type(result.getType());
115 Op newOp;
116 if (isResultBuiltinCPtr) {
117 auto recTy = result.getType().template dyn_cast<fir::RecordType>();
118 newResultTypes.emplace_back(recTy.getTypeList()[0].second);
119 }
120
121 // fir::CallOp specific handling.
122 if constexpr (std::is_same_v<Op, fir::CallOp>) {
123 if (op.getCallee()) {
124 llvm::SmallVector<mlir::Value> newOperands;
125 if (!isResultBuiltinCPtr)
126 newOperands.emplace_back(arg);
127 newOperands.append(op.getOperands().begin(), op.getOperands().end());
128 newOp = rewriter.create<fir::CallOp>(loc, *op.getCallee(),
129 newResultTypes, newOperands);
130 } else {
131 // Indirect calls.
132 llvm::SmallVector<mlir::Type> newInputTypes;
133 if (!isResultBuiltinCPtr)
134 newInputTypes.emplace_back(argType);
135 for (auto operand : op.getOperands().drop_front())
136 newInputTypes.push_back(operand.getType());
137 auto newFuncTy = mlir::FunctionType::get(op.getContext(), newInputTypes,
138 newResultTypes);
139
140 llvm::SmallVector<mlir::Value> newOperands;
141 newOperands.push_back(
142 rewriter.create<fir::ConvertOp>(loc, newFuncTy, op.getOperand(0)));
143 if (!isResultBuiltinCPtr)
144 newOperands.push_back(arg);
145 newOperands.append(op.getOperands().begin() + 1,
146 op.getOperands().end());
147 newOp = rewriter.create<fir::CallOp>(loc, mlir::SymbolRefAttr{},
148 newResultTypes, newOperands);
149 }
150 }
151
152 // fir::DispatchOp specific handling.
153 if constexpr (std::is_same_v<Op, fir::DispatchOp>) {
154 llvm::SmallVector<mlir::Value> newOperands;
155 if (!isResultBuiltinCPtr)
156 newOperands.emplace_back(arg);
157 unsigned passArgShift = newOperands.size();
158 newOperands.append(op.getOperands().begin() + 1, op.getOperands().end());
159
160 fir::DispatchOp newDispatchOp;
161 if (op.getPassArgPos())
162 newOp = rewriter.create<fir::DispatchOp>(
163 loc, newResultTypes, rewriter.getStringAttr(op.getMethod()),
164 op.getOperands()[0], newOperands,
165 rewriter.getI32IntegerAttr(*op.getPassArgPos() + passArgShift));
166 else
167 newOp = rewriter.create<fir::DispatchOp>(
168 loc, newResultTypes, rewriter.getStringAttr(op.getMethod()),
169 op.getOperands()[0], newOperands, nullptr);
170 }
171
172 if (isResultBuiltinCPtr) {
173 mlir::Value save = saveResult.getMemref();
174 auto module = op->template getParentOfType<mlir::ModuleOp>();
175 FirOpBuilder builder(rewriter, module);
176 mlir::Value saveAddr = fir::factory::genCPtrOrCFunptrAddr(
177 builder, loc, save, result.getType());
178 rewriter.create<fir::StoreOp>(loc, newOp->getResult(0), saveAddr);
179 }
180 op->dropAllReferences();
181 rewriter.eraseOp(op);
182 return mlir::success();
183 }
184
185private:
186 bool shouldBoxResult;
187};
188
189class SaveResultOpConversion
190 : public mlir::OpRewritePattern<fir::SaveResultOp> {
191public:
192 using OpRewritePattern::OpRewritePattern;
193 SaveResultOpConversion(mlir::MLIRContext *context)
194 : OpRewritePattern(context) {}
195 mlir::LogicalResult
196 matchAndRewrite(fir::SaveResultOp op,
197 mlir::PatternRewriter &rewriter) const override {
198 rewriter.eraseOp(op);
199 return mlir::success();
200 }
201};
202
203class ReturnOpConversion : public mlir::OpRewritePattern<mlir::func::ReturnOp> {
204public:
205 using OpRewritePattern::OpRewritePattern;
206 ReturnOpConversion(mlir::MLIRContext *context, mlir::Value newArg)
207 : OpRewritePattern(context), newArg{newArg} {}
208 mlir::LogicalResult
209 matchAndRewrite(mlir::func::ReturnOp ret,
210 mlir::PatternRewriter &rewriter) const override {
211 auto loc = ret.getLoc();
212 rewriter.setInsertionPoint(ret);
213 auto returnedValue = ret.getOperand(0);
214 bool replacedStorage = false;
215 if (auto *op = returnedValue.getDefiningOp())
216 if (auto load = mlir::dyn_cast<fir::LoadOp>(op)) {
217 auto resultStorage = load.getMemref();
218 // The result alloca may be behind a fir.declare, if any.
219 if (auto declare = mlir::dyn_cast_or_null<fir::DeclareOp>(
220 resultStorage.getDefiningOp()))
221 resultStorage = declare.getMemref();
222 // TODO: This should be generalized for derived types, and it is
223 // architecture and OS dependent.
224 if (fir::isa_builtin_cptr_type(returnedValue.getType())) {
225 rewriter.eraseOp(load);
226 auto module = ret->getParentOfType<mlir::ModuleOp>();
227 FirOpBuilder builder(rewriter, module);
228 mlir::Value retAddr = fir::factory::genCPtrOrCFunptrAddr(
229 builder, loc, resultStorage, returnedValue.getType());
230 mlir::Value retValue = rewriter.create<fir::LoadOp>(
231 loc, fir::unwrapRefType(retAddr.getType()), retAddr);
232 rewriter.replaceOpWithNewOp<mlir::func::ReturnOp>(
233 ret, mlir::ValueRange{retValue});
234 return mlir::success();
235 }
236 resultStorage.replaceAllUsesWith(newArg);
237 replacedStorage = true;
238 if (auto *alloc = resultStorage.getDefiningOp())
239 if (alloc->use_empty())
240 rewriter.eraseOp(alloc);
241 }
242 // The result storage may have been optimized out by a memory to
243 // register pass, this is possible for fir.box results, or fir.record
244 // with no length parameters. Simply store the result in the result storage.
245 // at the return point.
246 if (!replacedStorage)
247 rewriter.create<fir::StoreOp>(loc, returnedValue, newArg);
248 rewriter.replaceOpWithNewOp<mlir::func::ReturnOp>(ret);
249 return mlir::success();
250 }
251
252private:
253 mlir::Value newArg;
254};
255
256class AddrOfOpConversion : public mlir::OpRewritePattern<fir::AddrOfOp> {
257public:
258 using OpRewritePattern::OpRewritePattern;
259 AddrOfOpConversion(mlir::MLIRContext *context, bool shouldBoxResult)
260 : OpRewritePattern(context), shouldBoxResult{shouldBoxResult} {}
261 mlir::LogicalResult
262 matchAndRewrite(fir::AddrOfOp addrOf,
263 mlir::PatternRewriter &rewriter) const override {
264 auto oldFuncTy = addrOf.getType().cast<mlir::FunctionType>();
265 mlir::FunctionType newFuncTy;
266 // TODO: This should be generalized for derived types, and it is
267 // architecture and OS dependent.
268 if (oldFuncTy.getNumResults() != 0 &&
269 fir::isa_builtin_cptr_type(oldFuncTy.getResult(0)))
270 newFuncTy = getCPtrFunctionType(oldFuncTy);
271 else
272 newFuncTy = getNewFunctionType(oldFuncTy, shouldBoxResult);
273 auto newAddrOf = rewriter.create<fir::AddrOfOp>(addrOf.getLoc(), newFuncTy,
274 addrOf.getSymbol());
275 // Rather than converting all op a function pointer might transit through
276 // (e.g calls, stores, loads, converts...), cast new type to the abstract
277 // type. A conversion will be added when calling indirect calls of abstract
278 // types.
279 rewriter.replaceOpWithNewOp<fir::ConvertOp>(addrOf, oldFuncTy, newAddrOf);
280 return mlir::success();
281 }
282
283private:
284 bool shouldBoxResult;
285};
286
287class AbstractResultOpt
288 : public fir::impl::AbstractResultOptBase<AbstractResultOpt> {
289public:
290 using fir::impl::AbstractResultOptBase<
291 AbstractResultOpt>::AbstractResultOptBase;
292
293 void runOnSpecificOperation(mlir::func::FuncOp func, bool shouldBoxResult,
294 mlir::RewritePatternSet &patterns,
295 mlir::ConversionTarget &target) {
296 auto loc = func.getLoc();
297 auto *context = &getContext();
298 // Convert function type itself if it has an abstract result.
299 auto funcTy = func.getFunctionType().cast<mlir::FunctionType>();
300 if (hasAbstractResult(funcTy)) {
301 // TODO: This should be generalized for derived types, and it is
302 // architecture and OS dependent.
303 if (fir::isa_builtin_cptr_type(funcTy.getResult(0))) {
304 func.setType(getCPtrFunctionType(funcTy));
305 patterns.insert<ReturnOpConversion>(context, mlir::Value{});
306 target.addDynamicallyLegalOp<mlir::func::ReturnOp>(
307 [](mlir::func::ReturnOp ret) {
308 mlir::Type retTy = ret.getOperand(0).getType();
309 return !fir::isa_builtin_cptr_type(retTy);
310 });
311 return;
312 }
313 if (!func.empty()) {
314 // Insert new argument.
315 mlir::OpBuilder rewriter(context);
316 auto resultType = funcTy.getResult(0);
317 auto argTy = getResultArgumentType(resultType, shouldBoxResult);
318 func.insertArgument(0u, argTy, {}, loc);
319 func.eraseResult(0u);
320 mlir::Value newArg = func.getArgument(0u);
321 if (mustEmboxResult(resultType, shouldBoxResult)) {
322 auto bufferType = fir::ReferenceType::get(resultType);
323 rewriter.setInsertionPointToStart(&func.front());
324 newArg = rewriter.create<fir::BoxAddrOp>(loc, bufferType, newArg);
325 }
326 patterns.insert<ReturnOpConversion>(context, newArg);
327 target.addDynamicallyLegalOp<mlir::func::ReturnOp>(
328 [](mlir::func::ReturnOp ret) { return ret.getOperands().empty(); });
329 assert(func.getFunctionType() ==
330 getNewFunctionType(funcTy, shouldBoxResult));
331 } else {
332 llvm::SmallVector<mlir::DictionaryAttr> allArgs;
333 func.getAllArgAttrs(allArgs);
334 allArgs.insert(allArgs.begin(),
335 mlir::DictionaryAttr::get(func->getContext()));
336 func.setType(getNewFunctionType(funcTy, shouldBoxResult));
337 func.setAllArgAttrs(allArgs);
338 }
339 }
340 }
341
342 inline static bool containsFunctionTypeWithAbstractResult(mlir::Type type) {
343 return mlir::TypeSwitch<mlir::Type, bool>(type)
344 .Case([](fir::BoxProcType boxProc) {
345 return fir::hasAbstractResult(
346 boxProc.getEleTy().cast<mlir::FunctionType>());
347 })
348 .Case([](fir::PointerType pointer) {
349 return fir::hasAbstractResult(
350 pointer.getEleTy().cast<mlir::FunctionType>());
351 })
352 .Default([](auto &&) { return false; });
353 }
354
355 void runOnSpecificOperation(fir::GlobalOp global, bool,
356 mlir::RewritePatternSet &,
357 mlir::ConversionTarget &) {
358 if (containsFunctionTypeWithAbstractResult(global.getType())) {
359 TODO(global->getLoc(), "support for procedure pointers");
360 }
361 }
362
363 /// Run the pass on a ModuleOp. This makes fir-opt --abstract-result work.
364 void runOnModule() {
365 mlir::ModuleOp mod = mlir::cast<mlir::ModuleOp>(getOperation());
366
367 auto pass = std::make_unique<AbstractResultOpt>();
368 pass->copyOptionValuesFrom(this);
369 mlir::OpPassManager pipeline;
370 pipeline.addPass(std::unique_ptr<mlir::Pass>{pass.release()});
371
372 // Run the pass on all operations directly nested inside of the ModuleOp
373 // we can't just call runOnSpecificOperation here because the pass
374 // implementation only works when scoped to a particular func.func or
375 // fir.global
376 for (mlir::Region &region : mod->getRegions()) {
377 for (mlir::Block &block : region.getBlocks()) {
378 for (mlir::Operation &op : block.getOperations()) {
379 if (mlir::failed(runPipeline(pipeline, &op))) {
380 mlir::emitError(op.getLoc(), "Failed to run abstract result pass");
381 signalPassFailure();
382 return;
383 }
384 }
385 }
386 }
387 }
388
389 void runOnOperation() override {
390 auto *context = &this->getContext();
391 mlir::Operation *op = this->getOperation();
392 if (mlir::isa<mlir::ModuleOp>(op)) {
393 runOnModule();
394 return;
395 }
396
397 mlir::RewritePatternSet patterns(context);
398 mlir::ConversionTarget target = *context;
399 const bool shouldBoxResult = this->passResultAsBox.getValue();
400
401 mlir::TypeSwitch<mlir::Operation *, void>(op)
402 .Case<mlir::func::FuncOp, fir::GlobalOp>([&](auto op) {
403 runOnSpecificOperation(op, shouldBoxResult, patterns, target);
404 });
405
406 // Convert the calls and, if needed, the ReturnOp in the function body.
407 target.addLegalDialect<fir::FIROpsDialect, mlir::arith::ArithDialect,
408 mlir::func::FuncDialect>();
409 target.addIllegalOp<fir::SaveResultOp>();
410 target.addDynamicallyLegalOp<fir::CallOp>([](fir::CallOp call) {
411 return !hasAbstractResult(call.getFunctionType());
412 });
413 target.addDynamicallyLegalOp<fir::AddrOfOp>([](fir::AddrOfOp addrOf) {
414 if (auto funTy = addrOf.getType().dyn_cast<mlir::FunctionType>())
415 return !hasAbstractResult(funTy);
416 return true;
417 });
418 target.addDynamicallyLegalOp<fir::DispatchOp>([](fir::DispatchOp dispatch) {
419 return !hasAbstractResult(dispatch.getFunctionType());
420 });
421
422 patterns.insert<CallConversion<fir::CallOp>>(context, shouldBoxResult);
423 patterns.insert<CallConversion<fir::DispatchOp>>(context, shouldBoxResult);
424 patterns.insert<SaveResultOpConversion>(context);
425 patterns.insert<AddrOfOpConversion>(context, shouldBoxResult);
426 if (mlir::failed(
427 mlir::applyPartialConversion(op, target, std::move(patterns)))) {
428 mlir::emitError(op->getLoc(), "error in converting abstract results\n");
429 this->signalPassFailure();
430 }
431 }
432};
433
434} // end anonymous namespace
435} // namespace fir

// Source: flang/lib/Optimizer/Transforms/AbstractResult.cpp