1 | //===- AbstractResult.cpp - Conversion of Abstract Function Result --------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "flang/Optimizer/Builder/FIRBuilder.h" |
10 | #include "flang/Optimizer/Builder/Todo.h" |
11 | #include "flang/Optimizer/Dialect/FIRDialect.h" |
12 | #include "flang/Optimizer/Dialect/FIROps.h" |
13 | #include "flang/Optimizer/Dialect/FIRType.h" |
14 | #include "flang/Optimizer/Dialect/Support/FIRContext.h" |
15 | #include "flang/Optimizer/Transforms/Passes.h" |
16 | #include "mlir/Dialect/Func/IR/FuncOps.h" |
17 | #include "mlir/IR/Diagnostics.h" |
18 | #include "mlir/Pass/Pass.h" |
19 | #include "mlir/Pass/PassManager.h" |
20 | #include "mlir/Transforms/DialectConversion.h" |
21 | #include "llvm/ADT/TypeSwitch.h" |
22 | |
23 | namespace fir { |
24 | #define GEN_PASS_DEF_ABSTRACTRESULTOPT |
25 | #include "flang/Optimizer/Transforms/Passes.h.inc" |
26 | } // namespace fir |
27 | |
28 | #define DEBUG_TYPE "flang-abstract-result-opt" |
29 | |
30 | using namespace mlir; |
31 | |
32 | namespace fir { |
33 | namespace { |
34 | |
35 | static mlir::Type getResultArgumentType(mlir::Type resultType, |
36 | bool shouldBoxResult) { |
37 | return llvm::TypeSwitch<mlir::Type, mlir::Type>(resultType) |
38 | .Case<fir::SequenceType, fir::RecordType>( |
39 | [&](mlir::Type type) -> mlir::Type { |
40 | if (shouldBoxResult) |
41 | return fir::BoxType::get(type); |
42 | return fir::ReferenceType::get(type); |
43 | }) |
44 | .Case<fir::BaseBoxType>([](mlir::Type type) -> mlir::Type { |
45 | return fir::ReferenceType::get(type); |
46 | }) |
47 | .Default([](mlir::Type) -> mlir::Type { |
48 | llvm_unreachable("bad abstract result type" ); |
49 | }); |
50 | } |
51 | |
52 | static mlir::FunctionType getNewFunctionType(mlir::FunctionType funcTy, |
53 | bool shouldBoxResult) { |
54 | auto resultType = funcTy.getResult(0); |
55 | auto argTy = getResultArgumentType(resultType, shouldBoxResult); |
56 | llvm::SmallVector<mlir::Type> newInputTypes = {argTy}; |
57 | newInputTypes.append(funcTy.getInputs().begin(), funcTy.getInputs().end()); |
58 | return mlir::FunctionType::get(funcTy.getContext(), newInputTypes, |
59 | /*resultTypes=*/{}); |
60 | } |
61 | |
62 | /// This is for function result types that are of type C_PTR from ISO_C_BINDING. |
63 | /// Follow the ABI for interoperability with C. |
64 | static mlir::FunctionType getCPtrFunctionType(mlir::FunctionType funcTy) { |
65 | auto resultType = funcTy.getResult(0); |
66 | assert(fir::isa_builtin_cptr_type(resultType)); |
67 | llvm::SmallVector<mlir::Type> outputTypes; |
68 | auto recTy = resultType.dyn_cast<fir::RecordType>(); |
69 | outputTypes.emplace_back(recTy.getTypeList()[0].second); |
70 | return mlir::FunctionType::get(funcTy.getContext(), funcTy.getInputs(), |
71 | outputTypes); |
72 | } |
73 | |
74 | static bool mustEmboxResult(mlir::Type resultType, bool shouldBoxResult) { |
75 | return resultType.isa<fir::SequenceType, fir::RecordType>() && |
76 | shouldBoxResult; |
77 | } |
78 | |
79 | template <typename Op> |
80 | class CallConversion : public mlir::OpRewritePattern<Op> { |
81 | public: |
82 | using mlir::OpRewritePattern<Op>::OpRewritePattern; |
83 | |
84 | CallConversion(mlir::MLIRContext *context, bool shouldBoxResult) |
85 | : OpRewritePattern<Op>(context, 1), shouldBoxResult{shouldBoxResult} {} |
86 | |
87 | mlir::LogicalResult |
88 | matchAndRewrite(Op op, mlir::PatternRewriter &rewriter) const override { |
89 | auto loc = op.getLoc(); |
90 | auto result = op->getResult(0); |
91 | if (!result.hasOneUse()) { |
92 | mlir::emitError(loc, |
93 | "calls with abstract result must have exactly one user" ); |
94 | return mlir::failure(); |
95 | } |
96 | auto saveResult = |
97 | mlir::dyn_cast<fir::SaveResultOp>(result.use_begin().getUser()); |
98 | if (!saveResult) { |
99 | mlir::emitError( |
100 | loc, "calls with abstract result must be used in fir.save_result" ); |
101 | return mlir::failure(); |
102 | } |
103 | auto argType = getResultArgumentType(result.getType(), shouldBoxResult); |
104 | auto buffer = saveResult.getMemref(); |
105 | mlir::Value arg = buffer; |
106 | if (mustEmboxResult(result.getType(), shouldBoxResult)) |
107 | arg = rewriter.create<fir::EmboxOp>( |
108 | loc, argType, buffer, saveResult.getShape(), /*slice*/ mlir::Value{}, |
109 | saveResult.getTypeparams()); |
110 | |
111 | llvm::SmallVector<mlir::Type> newResultTypes; |
112 | // TODO: This should be generalized for derived types, and it is |
113 | // architecture and OS dependent. |
114 | bool isResultBuiltinCPtr = fir::isa_builtin_cptr_type(result.getType()); |
115 | Op newOp; |
116 | if (isResultBuiltinCPtr) { |
117 | auto recTy = result.getType().template dyn_cast<fir::RecordType>(); |
118 | newResultTypes.emplace_back(recTy.getTypeList()[0].second); |
119 | } |
120 | |
121 | // fir::CallOp specific handling. |
122 | if constexpr (std::is_same_v<Op, fir::CallOp>) { |
123 | if (op.getCallee()) { |
124 | llvm::SmallVector<mlir::Value> newOperands; |
125 | if (!isResultBuiltinCPtr) |
126 | newOperands.emplace_back(arg); |
127 | newOperands.append(op.getOperands().begin(), op.getOperands().end()); |
128 | newOp = rewriter.create<fir::CallOp>(loc, *op.getCallee(), |
129 | newResultTypes, newOperands); |
130 | } else { |
131 | // Indirect calls. |
132 | llvm::SmallVector<mlir::Type> newInputTypes; |
133 | if (!isResultBuiltinCPtr) |
134 | newInputTypes.emplace_back(argType); |
135 | for (auto operand : op.getOperands().drop_front()) |
136 | newInputTypes.push_back(operand.getType()); |
137 | auto newFuncTy = mlir::FunctionType::get(op.getContext(), newInputTypes, |
138 | newResultTypes); |
139 | |
140 | llvm::SmallVector<mlir::Value> newOperands; |
141 | newOperands.push_back( |
142 | rewriter.create<fir::ConvertOp>(loc, newFuncTy, op.getOperand(0))); |
143 | if (!isResultBuiltinCPtr) |
144 | newOperands.push_back(arg); |
145 | newOperands.append(op.getOperands().begin() + 1, |
146 | op.getOperands().end()); |
147 | newOp = rewriter.create<fir::CallOp>(loc, mlir::SymbolRefAttr{}, |
148 | newResultTypes, newOperands); |
149 | } |
150 | } |
151 | |
152 | // fir::DispatchOp specific handling. |
153 | if constexpr (std::is_same_v<Op, fir::DispatchOp>) { |
154 | llvm::SmallVector<mlir::Value> newOperands; |
155 | if (!isResultBuiltinCPtr) |
156 | newOperands.emplace_back(arg); |
157 | unsigned passArgShift = newOperands.size(); |
158 | newOperands.append(op.getOperands().begin() + 1, op.getOperands().end()); |
159 | |
160 | fir::DispatchOp newDispatchOp; |
161 | if (op.getPassArgPos()) |
162 | newOp = rewriter.create<fir::DispatchOp>( |
163 | loc, newResultTypes, rewriter.getStringAttr(op.getMethod()), |
164 | op.getOperands()[0], newOperands, |
165 | rewriter.getI32IntegerAttr(*op.getPassArgPos() + passArgShift)); |
166 | else |
167 | newOp = rewriter.create<fir::DispatchOp>( |
168 | loc, newResultTypes, rewriter.getStringAttr(op.getMethod()), |
169 | op.getOperands()[0], newOperands, nullptr); |
170 | } |
171 | |
172 | if (isResultBuiltinCPtr) { |
173 | mlir::Value save = saveResult.getMemref(); |
174 | auto module = op->template getParentOfType<mlir::ModuleOp>(); |
175 | FirOpBuilder builder(rewriter, module); |
176 | mlir::Value saveAddr = fir::factory::genCPtrOrCFunptrAddr( |
177 | builder, loc, save, result.getType()); |
178 | rewriter.create<fir::StoreOp>(loc, newOp->getResult(0), saveAddr); |
179 | } |
180 | op->dropAllReferences(); |
181 | rewriter.eraseOp(op); |
182 | return mlir::success(); |
183 | } |
184 | |
185 | private: |
186 | bool shouldBoxResult; |
187 | }; |
188 | |
189 | class SaveResultOpConversion |
190 | : public mlir::OpRewritePattern<fir::SaveResultOp> { |
191 | public: |
192 | using OpRewritePattern::OpRewritePattern; |
193 | SaveResultOpConversion(mlir::MLIRContext *context) |
194 | : OpRewritePattern(context) {} |
195 | mlir::LogicalResult |
196 | matchAndRewrite(fir::SaveResultOp op, |
197 | mlir::PatternRewriter &rewriter) const override { |
198 | rewriter.eraseOp(op); |
199 | return mlir::success(); |
200 | } |
201 | }; |
202 | |
203 | class ReturnOpConversion : public mlir::OpRewritePattern<mlir::func::ReturnOp> { |
204 | public: |
205 | using OpRewritePattern::OpRewritePattern; |
206 | ReturnOpConversion(mlir::MLIRContext *context, mlir::Value newArg) |
207 | : OpRewritePattern(context), newArg{newArg} {} |
208 | mlir::LogicalResult |
209 | matchAndRewrite(mlir::func::ReturnOp ret, |
210 | mlir::PatternRewriter &rewriter) const override { |
211 | auto loc = ret.getLoc(); |
212 | rewriter.setInsertionPoint(ret); |
213 | auto returnedValue = ret.getOperand(0); |
214 | bool replacedStorage = false; |
215 | if (auto *op = returnedValue.getDefiningOp()) |
216 | if (auto load = mlir::dyn_cast<fir::LoadOp>(op)) { |
217 | auto resultStorage = load.getMemref(); |
218 | // The result alloca may be behind a fir.declare, if any. |
219 | if (auto declare = mlir::dyn_cast_or_null<fir::DeclareOp>( |
220 | resultStorage.getDefiningOp())) |
221 | resultStorage = declare.getMemref(); |
222 | // TODO: This should be generalized for derived types, and it is |
223 | // architecture and OS dependent. |
224 | if (fir::isa_builtin_cptr_type(returnedValue.getType())) { |
225 | rewriter.eraseOp(load); |
226 | auto module = ret->getParentOfType<mlir::ModuleOp>(); |
227 | FirOpBuilder builder(rewriter, module); |
228 | mlir::Value retAddr = fir::factory::genCPtrOrCFunptrAddr( |
229 | builder, loc, resultStorage, returnedValue.getType()); |
230 | mlir::Value retValue = rewriter.create<fir::LoadOp>( |
231 | loc, fir::unwrapRefType(retAddr.getType()), retAddr); |
232 | rewriter.replaceOpWithNewOp<mlir::func::ReturnOp>( |
233 | ret, mlir::ValueRange{retValue}); |
234 | return mlir::success(); |
235 | } |
236 | resultStorage.replaceAllUsesWith(newArg); |
237 | replacedStorage = true; |
238 | if (auto *alloc = resultStorage.getDefiningOp()) |
239 | if (alloc->use_empty()) |
240 | rewriter.eraseOp(alloc); |
241 | } |
242 | // The result storage may have been optimized out by a memory to |
243 | // register pass, this is possible for fir.box results, or fir.record |
244 | // with no length parameters. Simply store the result in the result storage. |
245 | // at the return point. |
246 | if (!replacedStorage) |
247 | rewriter.create<fir::StoreOp>(loc, returnedValue, newArg); |
248 | rewriter.replaceOpWithNewOp<mlir::func::ReturnOp>(ret); |
249 | return mlir::success(); |
250 | } |
251 | |
252 | private: |
253 | mlir::Value newArg; |
254 | }; |
255 | |
256 | class AddrOfOpConversion : public mlir::OpRewritePattern<fir::AddrOfOp> { |
257 | public: |
258 | using OpRewritePattern::OpRewritePattern; |
259 | AddrOfOpConversion(mlir::MLIRContext *context, bool shouldBoxResult) |
260 | : OpRewritePattern(context), shouldBoxResult{shouldBoxResult} {} |
261 | mlir::LogicalResult |
262 | matchAndRewrite(fir::AddrOfOp addrOf, |
263 | mlir::PatternRewriter &rewriter) const override { |
264 | auto oldFuncTy = addrOf.getType().cast<mlir::FunctionType>(); |
265 | mlir::FunctionType newFuncTy; |
266 | // TODO: This should be generalized for derived types, and it is |
267 | // architecture and OS dependent. |
268 | if (oldFuncTy.getNumResults() != 0 && |
269 | fir::isa_builtin_cptr_type(oldFuncTy.getResult(0))) |
270 | newFuncTy = getCPtrFunctionType(oldFuncTy); |
271 | else |
272 | newFuncTy = getNewFunctionType(oldFuncTy, shouldBoxResult); |
273 | auto newAddrOf = rewriter.create<fir::AddrOfOp>(addrOf.getLoc(), newFuncTy, |
274 | addrOf.getSymbol()); |
275 | // Rather than converting all op a function pointer might transit through |
276 | // (e.g calls, stores, loads, converts...), cast new type to the abstract |
277 | // type. A conversion will be added when calling indirect calls of abstract |
278 | // types. |
279 | rewriter.replaceOpWithNewOp<fir::ConvertOp>(addrOf, oldFuncTy, newAddrOf); |
280 | return mlir::success(); |
281 | } |
282 | |
283 | private: |
284 | bool shouldBoxResult; |
285 | }; |
286 | |
287 | class AbstractResultOpt |
288 | : public fir::impl::AbstractResultOptBase<AbstractResultOpt> { |
289 | public: |
290 | using fir::impl::AbstractResultOptBase< |
291 | AbstractResultOpt>::AbstractResultOptBase; |
292 | |
293 | void runOnSpecificOperation(mlir::func::FuncOp func, bool shouldBoxResult, |
294 | mlir::RewritePatternSet &patterns, |
295 | mlir::ConversionTarget &target) { |
296 | auto loc = func.getLoc(); |
297 | auto *context = &getContext(); |
298 | // Convert function type itself if it has an abstract result. |
299 | auto funcTy = func.getFunctionType().cast<mlir::FunctionType>(); |
300 | if (hasAbstractResult(funcTy)) { |
301 | // TODO: This should be generalized for derived types, and it is |
302 | // architecture and OS dependent. |
303 | if (fir::isa_builtin_cptr_type(funcTy.getResult(0))) { |
304 | func.setType(getCPtrFunctionType(funcTy)); |
305 | patterns.insert<ReturnOpConversion>(context, mlir::Value{}); |
306 | target.addDynamicallyLegalOp<mlir::func::ReturnOp>( |
307 | [](mlir::func::ReturnOp ret) { |
308 | mlir::Type retTy = ret.getOperand(0).getType(); |
309 | return !fir::isa_builtin_cptr_type(retTy); |
310 | }); |
311 | return; |
312 | } |
313 | if (!func.empty()) { |
314 | // Insert new argument. |
315 | mlir::OpBuilder rewriter(context); |
316 | auto resultType = funcTy.getResult(0); |
317 | auto argTy = getResultArgumentType(resultType, shouldBoxResult); |
318 | func.insertArgument(0u, argTy, {}, loc); |
319 | func.eraseResult(0u); |
320 | mlir::Value newArg = func.getArgument(0u); |
321 | if (mustEmboxResult(resultType, shouldBoxResult)) { |
322 | auto bufferType = fir::ReferenceType::get(resultType); |
323 | rewriter.setInsertionPointToStart(&func.front()); |
324 | newArg = rewriter.create<fir::BoxAddrOp>(loc, bufferType, newArg); |
325 | } |
326 | patterns.insert<ReturnOpConversion>(context, newArg); |
327 | target.addDynamicallyLegalOp<mlir::func::ReturnOp>( |
328 | [](mlir::func::ReturnOp ret) { return ret.getOperands().empty(); }); |
329 | assert(func.getFunctionType() == |
330 | getNewFunctionType(funcTy, shouldBoxResult)); |
331 | } else { |
332 | llvm::SmallVector<mlir::DictionaryAttr> allArgs; |
333 | func.getAllArgAttrs(allArgs); |
334 | allArgs.insert(allArgs.begin(), |
335 | mlir::DictionaryAttr::get(func->getContext())); |
336 | func.setType(getNewFunctionType(funcTy, shouldBoxResult)); |
337 | func.setAllArgAttrs(allArgs); |
338 | } |
339 | } |
340 | } |
341 | |
342 | inline static bool containsFunctionTypeWithAbstractResult(mlir::Type type) { |
343 | return mlir::TypeSwitch<mlir::Type, bool>(type) |
344 | .Case([](fir::BoxProcType boxProc) { |
345 | return fir::hasAbstractResult( |
346 | boxProc.getEleTy().cast<mlir::FunctionType>()); |
347 | }) |
348 | .Case([](fir::PointerType pointer) { |
349 | return fir::hasAbstractResult( |
350 | pointer.getEleTy().cast<mlir::FunctionType>()); |
351 | }) |
352 | .Default([](auto &&) { return false; }); |
353 | } |
354 | |
355 | void runOnSpecificOperation(fir::GlobalOp global, bool, |
356 | mlir::RewritePatternSet &, |
357 | mlir::ConversionTarget &) { |
358 | if (containsFunctionTypeWithAbstractResult(global.getType())) { |
359 | TODO(global->getLoc(), "support for procedure pointers" ); |
360 | } |
361 | } |
362 | |
363 | /// Run the pass on a ModuleOp. This makes fir-opt --abstract-result work. |
364 | void runOnModule() { |
365 | mlir::ModuleOp mod = mlir::cast<mlir::ModuleOp>(getOperation()); |
366 | |
367 | auto pass = std::make_unique<AbstractResultOpt>(); |
368 | pass->copyOptionValuesFrom(this); |
369 | mlir::OpPassManager pipeline; |
370 | pipeline.addPass(std::unique_ptr<mlir::Pass>{pass.release()}); |
371 | |
372 | // Run the pass on all operations directly nested inside of the ModuleOp |
373 | // we can't just call runOnSpecificOperation here because the pass |
374 | // implementation only works when scoped to a particular func.func or |
375 | // fir.global |
376 | for (mlir::Region ®ion : mod->getRegions()) { |
377 | for (mlir::Block &block : region.getBlocks()) { |
378 | for (mlir::Operation &op : block.getOperations()) { |
379 | if (mlir::failed(runPipeline(pipeline, &op))) { |
380 | mlir::emitError(op.getLoc(), "Failed to run abstract result pass" ); |
381 | signalPassFailure(); |
382 | return; |
383 | } |
384 | } |
385 | } |
386 | } |
387 | } |
388 | |
389 | void runOnOperation() override { |
390 | auto *context = &this->getContext(); |
391 | mlir::Operation *op = this->getOperation(); |
392 | if (mlir::isa<mlir::ModuleOp>(op)) { |
393 | runOnModule(); |
394 | return; |
395 | } |
396 | |
397 | mlir::RewritePatternSet patterns(context); |
398 | mlir::ConversionTarget target = *context; |
399 | const bool shouldBoxResult = this->passResultAsBox.getValue(); |
400 | |
401 | mlir::TypeSwitch<mlir::Operation *, void>(op) |
402 | .Case<mlir::func::FuncOp, fir::GlobalOp>([&](auto op) { |
403 | runOnSpecificOperation(op, shouldBoxResult, patterns, target); |
404 | }); |
405 | |
406 | // Convert the calls and, if needed, the ReturnOp in the function body. |
407 | target.addLegalDialect<fir::FIROpsDialect, mlir::arith::ArithDialect, |
408 | mlir::func::FuncDialect>(); |
409 | target.addIllegalOp<fir::SaveResultOp>(); |
410 | target.addDynamicallyLegalOp<fir::CallOp>([](fir::CallOp call) { |
411 | return !hasAbstractResult(call.getFunctionType()); |
412 | }); |
413 | target.addDynamicallyLegalOp<fir::AddrOfOp>([](fir::AddrOfOp addrOf) { |
414 | if (auto funTy = addrOf.getType().dyn_cast<mlir::FunctionType>()) |
415 | return !hasAbstractResult(funTy); |
416 | return true; |
417 | }); |
418 | target.addDynamicallyLegalOp<fir::DispatchOp>([](fir::DispatchOp dispatch) { |
419 | return !hasAbstractResult(dispatch.getFunctionType()); |
420 | }); |
421 | |
422 | patterns.insert<CallConversion<fir::CallOp>>(context, shouldBoxResult); |
423 | patterns.insert<CallConversion<fir::DispatchOp>>(context, shouldBoxResult); |
424 | patterns.insert<SaveResultOpConversion>(context); |
425 | patterns.insert<AddrOfOpConversion>(context, shouldBoxResult); |
426 | if (mlir::failed( |
427 | mlir::applyPartialConversion(op, target, std::move(patterns)))) { |
428 | mlir::emitError(op->getLoc(), "error in converting abstract results\n" ); |
429 | this->signalPassFailure(); |
430 | } |
431 | } |
432 | }; |
433 | |
434 | } // end anonymous namespace |
435 | } // namespace fir |