//===- AbstractResult.cpp - Conversion of Abstract Function Result --------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
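//
// This pass rewrites functions that return an "abstract" result (a fir.array,
// derived type, or fir.box value) so that the result is instead passed through
// a hidden first argument, and updates the related fir.call, fir.dispatch,
// fir.addr_of, fir.save_result, and return operations. Results of
// C_PTR/C_FUNPTR type are special-cased to be returned as void* to match the
// C ABI. Schematically (illustrative example only, FIR syntax simplified):
//
//   func.func @f(%x: i32) -> !fir.array<10xf32>
//
// becomes
//
//   func.func @f(%buf: !fir.ref<!fir.array<10xf32>>, %x: i32)
//
// With the passResultAsBox option, the hidden argument is a
// !fir.box<!fir.array<10xf32>> instead of a raw reference.
//
//===----------------------------------------------------------------------===//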

#include "flang/Optimizer/Builder/FIRBuilder.h"
#include "flang/Optimizer/Builder/Todo.h"
#include "flang/Optimizer/Dialect/FIRDialect.h"
#include "flang/Optimizer/Dialect/FIROps.h"
#include "flang/Optimizer/Dialect/FIRType.h"
#include "flang/Optimizer/Dialect/Support/FIRContext.h"
#include "flang/Optimizer/Transforms/Passes.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/TypeSwitch.h"

namespace fir {
#define GEN_PASS_DEF_ABSTRACTRESULTOPT
#include "flang/Optimizer/Transforms/Passes.h.inc"
} // namespace fir

#define DEBUG_TYPE "flang-abstract-result-opt"

using namespace mlir;

namespace fir {
namespace {

// Helper that only builds the symbol table if needed, because its build time
// is linear in the number of symbols in the module.
struct LazySymbolTable {
  LazySymbolTable(mlir::Operation *op)
      : module{op->getParentOfType<mlir::ModuleOp>()} {}
  void build() {
    if (table)
      return;
    table = std::make_unique<mlir::SymbolTable>(module);
  }

  template <typename T>
  T lookup(llvm::StringRef name) {
    build();
    return table->lookup<T>(name);
  }

private:
  std::unique_ptr<mlir::SymbolTable> table;
  mlir::ModuleOp module;
};

bool hasScalarDerivedResult(mlir::FunctionType funTy) {
  // C_PTR/C_FUNPTR results are rewritten to void* results in this pass; do
  // not consider them as normal derived types.
  return funTy.getNumResults() == 1 &&
         mlir::isa<fir::RecordType>(funTy.getResult(0)) &&
         !fir::isa_builtin_cptr_type(funTy.getResult(0));
}

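/// Return the type of the hidden argument that will carry an abstract result
/// of type \p resultType: arrays and derived types are passed by reference
/// (or boxed when \p shouldBoxResult is set), and box results are always
/// passed as a reference to the box.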
static mlir::Type getResultArgumentType(mlir::Type resultType,
                                        bool shouldBoxResult) {
  return llvm::TypeSwitch<mlir::Type, mlir::Type>(resultType)
      .Case<fir::SequenceType, fir::RecordType>(
          [&](mlir::Type type) -> mlir::Type {
            if (shouldBoxResult)
              return fir::BoxType::get(type);
            return fir::ReferenceType::get(type);
          })
      .Case<fir::BaseBoxType>([](mlir::Type type) -> mlir::Type {
        return fir::ReferenceType::get(type);
      })
      .Default([](mlir::Type) -> mlir::Type {
        llvm_unreachable("bad abstract result type");
      });
}

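/// Rewrite a function type with an abstract result into the hidden-argument
/// form: the result argument is prepended and the result is dropped. For
/// example (illustrative only), (i32) -> !fir.array<?xf32> would become
/// (!fir.ref<!fir.array<?xf32>>, i32) -> ().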
static mlir::FunctionType getNewFunctionType(mlir::FunctionType funcTy,
                                             bool shouldBoxResult) {
  auto resultType = funcTy.getResult(0);
  auto argTy = getResultArgumentType(resultType, shouldBoxResult);
  llvm::SmallVector<mlir::Type> newInputTypes = {argTy};
  newInputTypes.append(funcTy.getInputs().begin(), funcTy.getInputs().end());
  return mlir::FunctionType::get(funcTy.getContext(), newInputTypes,
                                 /*resultTypes=*/{});
}

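/// FIR equivalent of void*: an untyped reference (!fir.ref<none>), used for
/// rewritten C_PTR/C_FUNPTR results.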
static mlir::Type getVoidPtrType(mlir::MLIRContext *context) {
  return fir::ReferenceType::get(mlir::NoneType::get(context));
}

/// This is for function results of type C_PTR/C_FUNPTR from ISO_C_BINDING.
/// Follow the C ABI for interoperability: such functions return a void*.
static mlir::FunctionType getCPtrFunctionType(mlir::FunctionType funcTy) {
  assert(fir::isa_builtin_cptr_type(funcTy.getResult(0)));
  llvm::SmallVector<mlir::Type> outputTypes{
      getVoidPtrType(funcTy.getContext())};
  return mlir::FunctionType::get(funcTy.getContext(), funcTy.getInputs(),
                                 outputTypes);
}

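/// Should the result be wrapped in a descriptor before being passed through
/// the hidden argument? Only array and derived type results are boxed, and
/// only when the pass is configured to pass results as boxes.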
static bool mustEmboxResult(mlir::Type resultType, bool shouldBoxResult) {
  return mlir::isa<fir::SequenceType, fir::RecordType>(resultType) &&
         shouldBoxResult;
}

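/// Rewrite fir.call and fir.dispatch operations that have an abstract result.
/// The buffer provided by the paired fir.save_result becomes the hidden first
/// argument (boxed first if required) and the original call result is dropped.
/// Roughly (illustrative only, FIR syntax simplified):
///
///   %res = fir.call @f() : () -> !fir.array<10xf32>
///   fir.save_result %res to %buf ...
///
/// becomes:
///
///   fir.call @f(%buf) : (!fir.ref<!fir.array<10xf32>>) -> ()
///
/// C_PTR/C_FUNPTR results are handled differently: the new call returns a
/// void* that is stored into the address component of the save_result buffer.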
template <typename Op>
class CallConversion : public mlir::OpRewritePattern<Op> {
public:
  using mlir::OpRewritePattern<Op>::OpRewritePattern;

  CallConversion(mlir::MLIRContext *context, bool shouldBoxResult)
      : OpRewritePattern<Op>(context, 1), shouldBoxResult{shouldBoxResult} {}

  llvm::LogicalResult
  matchAndRewrite(Op op, mlir::PatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    auto result = op->getResult(0);
    if (!result.hasOneUse()) {
      mlir::emitError(loc,
                      "calls with abstract result must have exactly one user");
      return mlir::failure();
    }
    auto saveResult =
        mlir::dyn_cast<fir::SaveResultOp>(result.use_begin().getUser());
    if (!saveResult) {
      mlir::emitError(
          loc, "calls with abstract result must be used in fir.save_result");
      return mlir::failure();
    }
    auto argType = getResultArgumentType(result.getType(), shouldBoxResult);
    auto buffer = saveResult.getMemref();
    mlir::Value arg = buffer;
    if (mustEmboxResult(result.getType(), shouldBoxResult))
      arg = rewriter.create<fir::EmboxOp>(
          loc, argType, buffer, saveResult.getShape(), /*slice*/ mlir::Value{},
          saveResult.getTypeparams());

    llvm::SmallVector<mlir::Type> newResultTypes;
    bool isResultBuiltinCPtr = fir::isa_builtin_cptr_type(result.getType());
    if (isResultBuiltinCPtr)
      newResultTypes.emplace_back(getVoidPtrType(result.getContext()));

    Op newOp;
    // TODO: propagate argument and result attributes (need to be shifted).
    // fir::CallOp specific handling.
    if constexpr (std::is_same_v<Op, fir::CallOp>) {
      if (op.getCallee()) {
        llvm::SmallVector<mlir::Value> newOperands;
        if (!isResultBuiltinCPtr)
          newOperands.emplace_back(arg);
        newOperands.append(op.getOperands().begin(), op.getOperands().end());
        newOp = rewriter.create<fir::CallOp>(loc, *op.getCallee(),
                                             newResultTypes, newOperands);
      } else {
        // Indirect calls.
        llvm::SmallVector<mlir::Type> newInputTypes;
        if (!isResultBuiltinCPtr)
          newInputTypes.emplace_back(argType);
        for (auto operand : op.getOperands().drop_front())
          newInputTypes.push_back(operand.getType());
        auto newFuncTy = mlir::FunctionType::get(op.getContext(), newInputTypes,
                                                 newResultTypes);

        llvm::SmallVector<mlir::Value> newOperands;
        newOperands.push_back(
            rewriter.create<fir::ConvertOp>(loc, newFuncTy, op.getOperand(0)));
        if (!isResultBuiltinCPtr)
          newOperands.push_back(arg);
        newOperands.append(op.getOperands().begin() + 1,
                           op.getOperands().end());
        newOp = rewriter.create<fir::CallOp>(loc, mlir::SymbolRefAttr{},
                                             newResultTypes, newOperands);
      }
    }

    // fir::DispatchOp specific handling.
    if constexpr (std::is_same_v<Op, fir::DispatchOp>) {
      llvm::SmallVector<mlir::Value> newOperands;
      if (!isResultBuiltinCPtr)
        newOperands.emplace_back(arg);
      unsigned passArgShift = newOperands.size();
      newOperands.append(op.getOperands().begin() + 1, op.getOperands().end());
      mlir::IntegerAttr passArgPos;
      if (op.getPassArgPos())
        passArgPos =
            rewriter.getI32IntegerAttr(*op.getPassArgPos() + passArgShift);
      // TODO: propagate argument and result attributes (need to be shifted).
      newOp = rewriter.create<fir::DispatchOp>(
          loc, newResultTypes, rewriter.getStringAttr(op.getMethod()),
          op.getOperands()[0], newOperands, passArgPos,
          /*arg_attrs=*/nullptr, /*res_attrs=*/nullptr,
          op.getProcedureAttrsAttr());
    }

    if (isResultBuiltinCPtr) {
      mlir::Value save = saveResult.getMemref();
      auto module = op->template getParentOfType<mlir::ModuleOp>();
      FirOpBuilder builder(rewriter, module);
      mlir::Value saveAddr = fir::factory::genCPtrOrCFunptrAddr(
          builder, loc, save, result.getType());
      builder.createStoreWithConvert(loc, newOp->getResult(0), saveAddr);
    }
    op->dropAllReferences();
    rewriter.eraseOp(op);
    return mlir::success();
  }

private:
  bool shouldBoxResult;
};

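/// Erase fir.save_result operations: once the calls are rewritten, the result
/// is already written into the buffer through the hidden argument. The one
/// exception is a BIND(C) call returning a derived type by value (other than
/// C_PTR/C_FUNPTR), for which an explicit store into the buffer is emitted.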
class SaveResultOpConversion
    : public mlir::OpRewritePattern<fir::SaveResultOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  SaveResultOpConversion(mlir::MLIRContext *context)
      : OpRewritePattern(context) {}
  llvm::LogicalResult
  matchAndRewrite(fir::SaveResultOp op,
                  mlir::PatternRewriter &rewriter) const override {
    mlir::Operation *call = op.getValue().getDefiningOp();
    mlir::Type type = op.getValue().getType();
    if (mlir::isa<fir::RecordType>(type) && call && fir::hasBindcAttr(call) &&
        !fir::isa_builtin_cptr_type(type)) {
      rewriter.replaceOpWithNewOp<fir::StoreOp>(op, op.getValue(),
                                                op.getMemref());
    } else {
      rewriter.eraseOp(op);
    }
    return mlir::success();
  }
};

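/// Rewrite a return-like operation of a converted function: the local result
/// storage, if any, is replaced by the hidden argument and the operation is
/// rewritten to return nothing. For C_PTR/C_FUNPTR results, the address
/// component is extracted instead and returned as void*.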
template <typename OpTy>
static mlir::LogicalResult
processReturnLikeOp(OpTy ret, mlir::Value newArg,
                    mlir::PatternRewriter &rewriter) {
  auto loc = ret.getLoc();
  rewriter.setInsertionPoint(ret);
  mlir::Value resultValue = ret.getOperand(0);
  fir::LoadOp resultLoad;
  mlir::Value resultStorage;
  // Identify result local storage.
  if (auto load = resultValue.getDefiningOp<fir::LoadOp>()) {
    resultLoad = load;
    resultStorage = load.getMemref();
    // The result alloca may be behind a fir.declare, if any.
    if (auto declare = resultStorage.getDefiningOp<fir::DeclareOp>())
      resultStorage = declare.getMemref();
  }
  // Replace the old local storage with the new storage argument, unless the
  // derived type is C_PTR/C_FUNPTR, in which case the return type is updated
  // to return void* (no new argument is passed).
  if (fir::isa_builtin_cptr_type(resultValue.getType())) {
    auto module = ret->template getParentOfType<mlir::ModuleOp>();
    FirOpBuilder builder(rewriter, module);
    mlir::Value cptr = resultValue;
    if (resultLoad) {
      // Replace whole derived type load by component load.
      cptr = resultLoad.getMemref();
      rewriter.setInsertionPoint(resultLoad);
    }
    mlir::Value newResultValue =
        fir::factory::genCPtrOrCFunptrValue(builder, loc, cptr);
    newResultValue = builder.createConvert(
        loc, getVoidPtrType(ret.getContext()), newResultValue);
    rewriter.setInsertionPoint(ret);
    rewriter.replaceOpWithNewOp<OpTy>(ret, mlir::ValueRange{newResultValue});
  } else if (resultStorage) {
    resultStorage.replaceAllUsesWith(newArg);
    rewriter.replaceOpWithNewOp<OpTy>(ret);
  } else {
    // The result storage may have been optimized out by a memory-to-register
    // pass; this is possible for fir.box results or fir.record results with
    // no length parameters. Simply store the result into the new storage
    // argument at the return point.
    rewriter.create<fir::StoreOp>(loc, resultValue, newArg);
    rewriter.replaceOpWithNewOp<OpTy>(ret);
  }
  // Delete the result's old local storage if it is now unused.
  if (resultStorage)
    if (auto alloc = resultStorage.getDefiningOp<fir::AllocaOp>())
      if (alloc->use_empty())
        rewriter.eraseOp(alloc);
  return mlir::success();
}

class ReturnOpConversion : public mlir::OpRewritePattern<mlir::func::ReturnOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  ReturnOpConversion(mlir::MLIRContext *context, mlir::Value newArg)
      : OpRewritePattern(context), newArg{newArg} {}
  llvm::LogicalResult
  matchAndRewrite(mlir::func::ReturnOp ret,
                  mlir::PatternRewriter &rewriter) const override {
    return processReturnLikeOp(ret, newArg, rewriter);
  }

private:
  mlir::Value newArg;
};

class GPUReturnOpConversion
    : public mlir::OpRewritePattern<mlir::gpu::ReturnOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  GPUReturnOpConversion(mlir::MLIRContext *context, mlir::Value newArg)
      : OpRewritePattern(context), newArg{newArg} {}
  llvm::LogicalResult
  matchAndRewrite(mlir::gpu::ReturnOp ret,
                  mlir::PatternRewriter &rewriter) const override {
    return processReturnLikeOp(ret, newArg, rewriter);
  }

private:
  mlir::Value newArg;
};

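/// Rewrite fir.addr_of of a function with an abstract result: the address is
/// taken with the rewritten function type and immediately converted back to
/// the original abstract type, so other uses of the function pointer are left
/// untouched.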
class AddrOfOpConversion : public mlir::OpRewritePattern<fir::AddrOfOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  AddrOfOpConversion(mlir::MLIRContext *context, bool shouldBoxResult)
      : OpRewritePattern(context), shouldBoxResult{shouldBoxResult} {}
  llvm::LogicalResult
  matchAndRewrite(fir::AddrOfOp addrOf,
                  mlir::PatternRewriter &rewriter) const override {
    auto oldFuncTy = mlir::cast<mlir::FunctionType>(addrOf.getType());
    mlir::FunctionType newFuncTy;
    if (oldFuncTy.getNumResults() != 0 &&
        fir::isa_builtin_cptr_type(oldFuncTy.getResult(0)))
      newFuncTy = getCPtrFunctionType(oldFuncTy);
    else
      newFuncTy = getNewFunctionType(oldFuncTy, shouldBoxResult);
    auto newAddrOf = rewriter.create<fir::AddrOfOp>(addrOf.getLoc(), newFuncTy,
                                                    addrOf.getSymbol());
    // Rather than converting all ops a function pointer might transit through
    // (e.g., calls, stores, loads, converts...), cast the new type back to
    // the abstract type. A conversion is inserted when rewriting indirect
    // calls with abstract result types.
    rewriter.replaceOpWithNewOp<fir::ConvertOp>(addrOf, oldFuncTy, newAddrOf);
    return mlir::success();
  }

private:
  bool shouldBoxResult;
};

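/// Pass rewriting a single function-like operation (func.func, gpu.func) or
/// fir.global with an abstract result. When scheduled on a ModuleOp, it
/// re-runs itself on every operation directly nested inside the module.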
class AbstractResultOpt
    : public fir::impl::AbstractResultOptBase<AbstractResultOpt> {
public:
  using fir::impl::AbstractResultOptBase<
      AbstractResultOpt>::AbstractResultOptBase;

  template <typename OpTy>
  void runOnFunctionLikeOperation(OpTy func, bool shouldBoxResult,
                                  mlir::RewritePatternSet &patterns,
                                  mlir::ConversionTarget &target) {
    auto loc = func.getLoc();
    auto *context = &getContext();
    // Convert function type itself if it has an abstract result.
    auto funcTy = mlir::cast<mlir::FunctionType>(func.getFunctionType());
    // A scalar derived type result of a BIND(C) function must be returned
    // according to the C struct return ABI, which is target dependent and
    // implemented in the target-rewrite pass.
    if (hasScalarDerivedResult(funcTy) &&
        fir::hasBindcAttr(func.getOperation()))
      return;
    if (hasAbstractResult(funcTy)) {
      if (fir::isa_builtin_cptr_type(funcTy.getResult(0))) {
        func.setType(getCPtrFunctionType(funcTy));
        patterns.insert<ReturnOpConversion>(context, mlir::Value{});
        target.addDynamicallyLegalOp<mlir::func::ReturnOp>(
            [](mlir::func::ReturnOp ret) {
              mlir::Type retTy = ret.getOperand(0).getType();
              return !fir::isa_builtin_cptr_type(retTy);
            });
        return;
      }
      if (!func.empty()) {
        // Insert new argument.
        mlir::OpBuilder rewriter(context);
        auto resultType = funcTy.getResult(0);
        auto argTy = getResultArgumentType(resultType, shouldBoxResult);
        llvm::LogicalResult res = func.insertArgument(0u, argTy, {}, loc);
        (void)res;
        assert(llvm::succeeded(res) && "failed to insert function argument");
        res = func.eraseResult(0u);
        (void)res;
        assert(llvm::succeeded(res) && "failed to erase function result");
        mlir::Value newArg = func.getArgument(0u);
        if (mustEmboxResult(resultType, shouldBoxResult)) {
          auto bufferType = fir::ReferenceType::get(resultType);
          rewriter.setInsertionPointToStart(&func.front());
          newArg = rewriter.create<fir::BoxAddrOp>(loc, bufferType, newArg);
        }
        patterns.insert<ReturnOpConversion>(context, newArg);
        target.addDynamicallyLegalOp<mlir::func::ReturnOp>(
            [](mlir::func::ReturnOp ret) { return ret.getOperands().empty(); });
        patterns.insert<GPUReturnOpConversion>(context, newArg);
        target.addDynamicallyLegalOp<mlir::gpu::ReturnOp>(
            [](mlir::gpu::ReturnOp ret) { return ret.getOperands().empty(); });
        assert(func.getFunctionType() ==
               getNewFunctionType(funcTy, shouldBoxResult));
      } else {
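        // Function declaration only (no body): just update the type and shift
        // the argument attributes to account for the new hidden argument.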
        llvm::SmallVector<mlir::DictionaryAttr> allArgs;
        func.getAllArgAttrs(allArgs);
        allArgs.insert(allArgs.begin(),
                       mlir::DictionaryAttr::get(func->getContext()));
        func.setType(getNewFunctionType(funcTy, shouldBoxResult));
        func.setAllArgAttrs(allArgs);
      }
    }
  }

  void runOnSpecificOperation(mlir::func::FuncOp func, bool shouldBoxResult,
                              mlir::RewritePatternSet &patterns,
                              mlir::ConversionTarget &target) {
    runOnFunctionLikeOperation(func, shouldBoxResult, patterns, target);
  }

  void runOnSpecificOperation(mlir::gpu::GPUFuncOp func, bool shouldBoxResult,
                              mlir::RewritePatternSet &patterns,
                              mlir::ConversionTarget &target) {
    runOnFunctionLikeOperation(func, shouldBoxResult, patterns, target);
  }

  inline static bool containsFunctionTypeWithAbstractResult(mlir::Type type) {
    return mlir::TypeSwitch<mlir::Type, bool>(type)
        .Case([](fir::BoxProcType boxProc) {
          return fir::hasAbstractResult(
              mlir::cast<mlir::FunctionType>(boxProc.getEleTy()));
        })
        .Case([](fir::PointerType pointer) {
          return fir::hasAbstractResult(
              mlir::cast<mlir::FunctionType>(pointer.getEleTy()));
        })
        .Default([](auto &&) { return false; });
  }

  void runOnSpecificOperation(fir::GlobalOp global, bool,
                              mlir::RewritePatternSet &,
                              mlir::ConversionTarget &) {
    if (containsFunctionTypeWithAbstractResult(global.getType())) {
      TODO(global->getLoc(), "support for procedure pointers");
    }
  }

  /// Run the pass on a ModuleOp. This makes fir-opt --abstract-result work.
  void runOnModule() {
    mlir::ModuleOp mod = mlir::cast<mlir::ModuleOp>(getOperation());

    auto pass = std::make_unique<AbstractResultOpt>();
    pass->copyOptionValuesFrom(this);
    mlir::OpPassManager pipeline;
    pipeline.addPass(std::unique_ptr<mlir::Pass>{pass.release()});

    // Run the pass on all operations directly nested inside the ModuleOp. We
    // can't just call runOnSpecificOperation here because the pass
    // implementation only works when scoped to a particular func.func or
    // fir.global.
    for (mlir::Region &region : mod->getRegions()) {
      for (mlir::Block &block : region.getBlocks()) {
        for (mlir::Operation &op : block.getOperations()) {
          if (mlir::failed(runPipeline(pipeline, &op))) {
            mlir::emitError(op.getLoc(), "Failed to run abstract result pass");
            signalPassFailure();
            return;
          }
        }
      }
    }
  }

  void runOnOperation() override {
    auto *context = &this->getContext();
    mlir::Operation *op = this->getOperation();
    if (mlir::isa<mlir::ModuleOp>(op)) {
      runOnModule();
      return;
    }

    LazySymbolTable symbolTable(op);

    mlir::RewritePatternSet patterns(context);
    mlir::ConversionTarget target = *context;
    const bool shouldBoxResult = this->passResultAsBox.getValue();

    mlir::TypeSwitch<mlir::Operation *, void>(op)
        .Case<mlir::func::FuncOp, fir::GlobalOp, mlir::gpu::GPUFuncOp>(
            [&](auto op) {
              runOnSpecificOperation(op, shouldBoxResult, patterns, target);
            });

    // Convert the calls and, if needed, the ReturnOp in the function body.
    target.addLegalDialect<fir::FIROpsDialect, mlir::arith::ArithDialect,
                           mlir::func::FuncDialect>();
    target.addIllegalOp<fir::SaveResultOp>();
    target.addDynamicallyLegalOp<fir::CallOp>([](fir::CallOp call) {
      mlir::FunctionType funTy = call.getFunctionType();
      if (hasScalarDerivedResult(funTy) &&
          fir::hasBindcAttr(call.getOperation()))
        return true;
      return !hasAbstractResult(funTy);
    });
    target.addDynamicallyLegalOp<fir::AddrOfOp>([&symbolTable](
                                                    fir::AddrOfOp addrOf) {
      if (auto funTy = mlir::dyn_cast<mlir::FunctionType>(addrOf.getType())) {
        if (hasScalarDerivedResult(funTy)) {
          auto func = symbolTable.lookup<mlir::func::FuncOp>(
              addrOf.getSymbol().getRootReference().getValue());
          return func && fir::hasBindcAttr(func.getOperation());
        }
        return !hasAbstractResult(funTy);
      }
      return true;
    });
    target.addDynamicallyLegalOp<fir::DispatchOp>([](fir::DispatchOp dispatch) {
      mlir::FunctionType funTy = dispatch.getFunctionType();
      if (hasScalarDerivedResult(funTy) &&
          fir::hasBindcAttr(dispatch.getOperation()))
        return true;
      return !hasAbstractResult(dispatch.getFunctionType());
    });

    patterns.insert<CallConversion<fir::CallOp>>(context, shouldBoxResult);
    patterns.insert<CallConversion<fir::DispatchOp>>(context, shouldBoxResult);
    patterns.insert<SaveResultOpConversion>(context);
    patterns.insert<AddrOfOpConversion>(context, shouldBoxResult);
    if (mlir::failed(
            mlir::applyPartialConversion(op, target, std::move(patterns)))) {
      mlir::emitError(op->getLoc(), "error in converting abstract results\n");
      this->signalPassFailure();
    }
  }
};

} // end anonymous namespace
} // namespace fir