1//===-- TargetRewrite.cpp -------------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Target rewrite: rewriting of ops to make target-specific lowerings manifest.
10// LLVM expects different lowering idioms to be used for distinct target
11// triples. These distinctions are handled by this pass.
12//
13// Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/
14//
15//===----------------------------------------------------------------------===//
16
17#include "flang/Optimizer/CodeGen/CodeGen.h"
18
19#include "flang/Optimizer/Builder/Character.h"
20#include "flang/Optimizer/Builder/FIRBuilder.h"
21#include "flang/Optimizer/Builder/Todo.h"
22#include "flang/Optimizer/CodeGen/Target.h"
23#include "flang/Optimizer/Dialect/FIRDialect.h"
24#include "flang/Optimizer/Dialect/FIROps.h"
25#include "flang/Optimizer/Dialect/FIROpsSupport.h"
26#include "flang/Optimizer/Dialect/FIRType.h"
27#include "flang/Optimizer/Dialect/Support/FIRContext.h"
28#include "flang/Optimizer/Support/DataLayout.h"
29#include "mlir/Dialect/DLTI/DLTI.h"
30#include "mlir/Transforms/DialectConversion.h"
31#include "llvm/ADT/STLExtras.h"
32#include "llvm/ADT/TypeSwitch.h"
33#include "llvm/Support/Debug.h"
34#include <optional>
35
36namespace fir {
37#define GEN_PASS_DEF_TARGETREWRITEPASS
38#include "flang/Optimizer/CodeGen/CGPasses.h.inc"
39} // namespace fir
40
41#define DEBUG_TYPE "flang-target-rewrite"
42
43namespace {
44
45/// Fixups for updating a FuncOp's arguments and return values.
46struct FixupTy {
47 enum class Codes {
48 ArgumentAsLoad,
49 ArgumentType,
50 CharPair,
51 ReturnAsStore,
52 ReturnType,
53 Split,
54 Trailing,
55 TrailingCharProc
56 };
57
58 FixupTy(Codes code, std::size_t index, std::size_t second = 0)
59 : code{code}, index{index}, second{second} {}
60 FixupTy(Codes code, std::size_t index,
61 std::function<void(mlir::func::FuncOp)> &&finalizer)
62 : code{code}, index{index}, finalizer{finalizer} {}
63 FixupTy(Codes code, std::size_t index, std::size_t second,
64 std::function<void(mlir::func::FuncOp)> &&finalizer)
65 : code{code}, index{index}, second{second}, finalizer{finalizer} {}
66
67 Codes code;
68 std::size_t index;
69 std::size_t second{};
70 std::optional<std::function<void(mlir::func::FuncOp)>> finalizer{};
71}; // namespace
72
73/// Target-specific rewriting of the FIR. This is a prerequisite pass to code
74/// generation that traverses the FIR and modifies types and operations to a
75/// form that is appropriate for the specific target. LLVM IR has specific
76/// idioms that are used for distinct target processor and ABI combinations.
77class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
78public:
79 TargetRewrite(const fir::TargetRewriteOptions &options) {
80 noCharacterConversion = options.noCharacterConversion;
81 noComplexConversion = options.noComplexConversion;
82 noStructConversion = options.noStructConversion;
83 }
84
85 void runOnOperation() override final {
86 auto &context = getContext();
87 mlir::OpBuilder rewriter(&context);
88
89 auto mod = getModule();
90 if (!forcedTargetTriple.empty())
91 fir::setTargetTriple(mod, forcedTargetTriple);
92
93 if (!forcedTargetCPU.empty())
94 fir::setTargetCPU(mod, forcedTargetCPU);
95
96 if (!forcedTargetFeatures.empty())
97 fir::setTargetFeatures(mod, forcedTargetFeatures);
98
99 // TargetRewrite will require querying the type storage sizes, if it was
100 // not set already, create a DataLayoutSpec for the ModuleOp now.
101 std::optional<mlir::DataLayout> dl =
102 fir::support::getOrSetDataLayout(mod, /*allowDefaultLayout=*/true);
103 if (!dl) {
104 mlir::emitError(mod.getLoc(),
105 "module operation must carry a data layout attribute "
106 "to perform target ABI rewrites on FIR");
107 signalPassFailure();
108 return;
109 }
110
111 auto specifics = fir::CodeGenSpecifics::get(
112 mod.getContext(), fir::getTargetTriple(mod), fir::getKindMapping(mod),
113 fir::getTargetCPU(mod), fir::getTargetFeatures(mod), *dl);
114
115 setMembers(specifics.get(), &rewriter, &*dl);
116
117 // We may need to call stacksave/stackrestore later, so
118 // create the FuncOps beforehand.
119 fir::FirOpBuilder builder(rewriter, mod);
120 builder.setInsertionPointToStart(mod.getBody());
121 stackSaveFn = fir::factory::getLlvmStackSave(builder);
122 stackRestoreFn = fir::factory::getLlvmStackRestore(builder);
123
124 // Perform type conversion on signatures and call sites.
125 if (mlir::failed(convertTypes(mod))) {
126 mlir::emitError(mlir::UnknownLoc::get(&context),
127 "error in converting types to target abi");
128 signalPassFailure();
129 }
130
131 // Convert ops in target-specific patterns.
132 mod.walk([&](mlir::Operation *op) {
133 if (auto call = mlir::dyn_cast<fir::CallOp>(op)) {
134 if (!hasPortableSignature(call.getFunctionType(), op))
135 convertCallOp(call);
136 } else if (auto dispatch = mlir::dyn_cast<fir::DispatchOp>(op)) {
137 if (!hasPortableSignature(dispatch.getFunctionType(), op))
138 convertCallOp(dispatch);
139 } else if (auto addr = mlir::dyn_cast<fir::AddrOfOp>(op)) {
140 if (addr.getType().isa<mlir::FunctionType>() &&
141 !hasPortableSignature(addr.getType(), op))
142 convertAddrOp(addr);
143 }
144 });
145
146 clearMembers();
147 }
148
149 mlir::ModuleOp getModule() { return getOperation(); }
150
151 template <typename A, typename B, typename C>
152 std::optional<std::function<mlir::Value(mlir::Operation *)>>
153 rewriteCallComplexResultType(
154 mlir::Location loc, A ty, B &newResTys,
155 fir::CodeGenSpecifics::Marshalling &newInTyAndAttrs, C &newOpers,
156 mlir::Value &savedStackPtr) {
157 if (noComplexConversion) {
158 newResTys.push_back(ty);
159 return std::nullopt;
160 }
161 auto m = specifics->complexReturnType(loc, ty.getElementType());
162 // Currently targets mandate COMPLEX is a single aggregate or packed
163 // scalar, including the sret case.
164 assert(m.size() == 1 && "target of complex return not supported");
165 auto resTy = std::get<mlir::Type>(m[0]);
166 auto attr = std::get<fir::CodeGenSpecifics::Attributes>(m[0]);
167 if (attr.isSRet()) {
168 assert(fir::isa_ref_type(resTy) && "must be a memory reference type");
169 // Save the stack pointer, if it has not been saved for this call yet.
170 // We will need to restore it after the call, because the alloca
171 // needs to be deallocated.
172 if (!savedStackPtr)
173 savedStackPtr = genStackSave(loc);
174 mlir::Value stack =
175 rewriter->create<fir::AllocaOp>(loc, fir::dyn_cast_ptrEleTy(resTy));
176 newInTyAndAttrs.push_back(m[0]);
177 newOpers.push_back(stack);
178 return [=](mlir::Operation *) -> mlir::Value {
179 auto memTy = fir::ReferenceType::get(ty);
180 auto cast = rewriter->create<fir::ConvertOp>(loc, memTy, stack);
181 return rewriter->create<fir::LoadOp>(loc, cast);
182 };
183 }
184 newResTys.push_back(resTy);
185 return [=, &savedStackPtr](mlir::Operation *call) -> mlir::Value {
186 // We are going to generate an alloca, so save the stack pointer.
187 if (!savedStackPtr)
188 savedStackPtr = genStackSave(loc);
189 return this->convertValueInMemory(loc, call->getResult(0), ty,
190 /*inputMayBeBigger=*/true);
191 };
192 }
193
194 void passArgumentOnStackOrWithNewType(
195 mlir::Location loc, fir::CodeGenSpecifics::TypeAndAttr newTypeAndAttr,
196 mlir::Type oldType, mlir::Value oper,
197 llvm::SmallVectorImpl<mlir::Value> &newOpers,
198 mlir::Value &savedStackPtr) {
199 auto resTy = std::get<mlir::Type>(newTypeAndAttr);
200 auto attr = std::get<fir::CodeGenSpecifics::Attributes>(newTypeAndAttr);
201 // We are going to generate an alloca, so save the stack pointer.
202 if (!savedStackPtr)
203 savedStackPtr = genStackSave(loc);
204 if (attr.isByVal()) {
205 mlir::Value mem = rewriter->create<fir::AllocaOp>(loc, oldType);
206 rewriter->create<fir::StoreOp>(loc, oper, mem);
207 if (mem.getType() != resTy)
208 mem = rewriter->create<fir::ConvertOp>(loc, resTy, mem);
209 newOpers.push_back(mem);
210 } else {
211 mlir::Value bitcast =
212 convertValueInMemory(loc, oper, resTy, /*inputMayBeBigger=*/false);
213 newOpers.push_back(bitcast);
214 }
215 }
216
217 // Do a bitcast (convert a value via its memory representation).
218 // The input and output types may have different storage sizes,
219 // "inputMayBeBigger" should be set to indicate which of the input or
220 // output type may be bigger in order for the load/store to be safe.
221 // The mismatch comes from the fact that the LLVM register used for passing
222 // may be bigger than the value being passed (e.g., passing
223 // a `!fir.type<t{fir.array<3xi8>}>` into an i32 LLVM register).
224 mlir::Value convertValueInMemory(mlir::Location loc, mlir::Value value,
225 mlir::Type newType, bool inputMayBeBigger) {
226 if (inputMayBeBigger) {
227 auto newRefTy = fir::ReferenceType::get(newType);
228 auto mem = rewriter->create<fir::AllocaOp>(loc, value.getType());
229 rewriter->create<fir::StoreOp>(loc, value, mem);
230 auto cast = rewriter->create<fir::ConvertOp>(loc, newRefTy, mem);
231 return rewriter->create<fir::LoadOp>(loc, cast);
232 } else {
233 auto oldRefTy = fir::ReferenceType::get(value.getType());
234 auto mem = rewriter->create<fir::AllocaOp>(loc, newType);
235 auto cast = rewriter->create<fir::ConvertOp>(loc, oldRefTy, mem);
236 rewriter->create<fir::StoreOp>(loc, value, cast);
237 return rewriter->create<fir::LoadOp>(loc, mem);
238 }
239 }
240
241 void passSplitArgument(mlir::Location loc,
242 fir::CodeGenSpecifics::Marshalling splitArgs,
243 mlir::Type oldType, mlir::Value oper,
244 llvm::SmallVectorImpl<mlir::Value> &newOpers,
245 mlir::Value &savedStackPtr) {
246 // COMPLEX or struct argument split into separate arguments
247 if (!fir::isa_complex(oldType)) {
248 // Cast original operand to a tuple of the new arguments
249 // via memory.
250 llvm::SmallVector<mlir::Type> partTypes;
251 for (auto argPart : splitArgs)
252 partTypes.push_back(std::get<mlir::Type>(argPart));
253 mlir::Type tupleType =
254 mlir::TupleType::get(oldType.getContext(), partTypes);
255 if (!savedStackPtr)
256 savedStackPtr = genStackSave(loc);
257 oper = convertValueInMemory(loc, oper, tupleType,
258 /*inputMayBeBigger=*/false);
259 }
260 auto iTy = rewriter->getIntegerType(32);
261 for (auto e : llvm::enumerate(splitArgs)) {
262 auto &tup = e.value();
263 auto ty = std::get<mlir::Type>(tup);
264 auto index = e.index();
265 auto idx = rewriter->getIntegerAttr(iTy, index);
266 auto val = rewriter->create<fir::ExtractValueOp>(
267 loc, ty, oper, rewriter->getArrayAttr(idx));
268 newOpers.push_back(val);
269 }
270 }
271
272 void rewriteCallOperands(
273 mlir::Location loc, fir::CodeGenSpecifics::Marshalling passArgAs,
274 mlir::Type originalArgTy, mlir::Value oper,
275 llvm::SmallVectorImpl<mlir::Value> &newOpers, mlir::Value &savedStackPtr,
276 fir::CodeGenSpecifics::Marshalling &newInTyAndAttrs) {
277 if (passArgAs.size() == 1) {
278 // COMPLEX or derived type is passed as a single argument.
279 passArgumentOnStackOrWithNewType(loc, passArgAs[0], originalArgTy, oper,
280 newOpers, savedStackPtr);
281 } else {
282 // COMPLEX or derived type is split into separate arguments
283 passSplitArgument(loc, passArgAs, originalArgTy, oper, newOpers,
284 savedStackPtr);
285 }
286 newInTyAndAttrs.insert(newInTyAndAttrs.end(), passArgAs.begin(),
287 passArgAs.end());
288 }
289
290 template <typename CPLX>
291 void rewriteCallComplexInputType(
292 mlir::Location loc, CPLX ty, mlir::Value oper,
293 fir::CodeGenSpecifics::Marshalling &newInTyAndAttrs,
294 llvm::SmallVectorImpl<mlir::Value> &newOpers,
295 mlir::Value &savedStackPtr) {
296 if (noComplexConversion) {
297 newInTyAndAttrs.push_back(fir::CodeGenSpecifics::getTypeAndAttr(ty));
298 newOpers.push_back(oper);
299 return;
300 }
301 auto m = specifics->complexArgumentType(loc, ty.getElementType());
302 rewriteCallOperands(loc, m, ty, oper, newOpers, savedStackPtr,
303 newInTyAndAttrs);
304 }
305
306 void rewriteCallStructInputType(
307 mlir::Location loc, fir::RecordType recTy, mlir::Value oper,
308 fir::CodeGenSpecifics::Marshalling &newInTyAndAttrs,
309 llvm::SmallVectorImpl<mlir::Value> &newOpers,
310 mlir::Value &savedStackPtr) {
311 if (noStructConversion) {
312 newInTyAndAttrs.push_back(fir::CodeGenSpecifics::getTypeAndAttr(recTy));
313 newOpers.push_back(oper);
314 return;
315 }
316 auto structArgs =
317 specifics->structArgumentType(loc, recTy, newInTyAndAttrs);
318 rewriteCallOperands(loc, structArgs, recTy, oper, newOpers, savedStackPtr,
319 newInTyAndAttrs);
320 }
321
322 static bool hasByValOrSRetArgs(
323 const fir::CodeGenSpecifics::Marshalling &newInTyAndAttrs) {
324 return llvm::any_of(newInTyAndAttrs, [](auto arg) {
325 const auto &attr = std::get<fir::CodeGenSpecifics::Attributes>(arg);
326 return attr.isByVal() || attr.isSRet();
327 });
328 }
329
330 // Convert fir.call and fir.dispatch Ops.
331 template <typename A>
332 void convertCallOp(A callOp) {
333 auto fnTy = callOp.getFunctionType();
334 auto loc = callOp.getLoc();
335 rewriter->setInsertionPoint(callOp);
336 llvm::SmallVector<mlir::Type> newResTys;
337 fir::CodeGenSpecifics::Marshalling newInTyAndAttrs;
338 llvm::SmallVector<mlir::Value> newOpers;
339 mlir::Value savedStackPtr = nullptr;
340
341 // If the call is indirect, the first argument must still be the function
342 // to call.
343 int dropFront = 0;
344 if constexpr (std::is_same_v<std::decay_t<A>, fir::CallOp>) {
345 if (!callOp.getCallee()) {
346 newInTyAndAttrs.push_back(
347 fir::CodeGenSpecifics::getTypeAndAttr(fnTy.getInput(0)));
348 newOpers.push_back(callOp.getOperand(0));
349 dropFront = 1;
350 }
351 } else {
352 dropFront = 1; // First operand is the polymorphic object.
353 }
354
355 // Determine the rewrite function, `wrap`, for the result value.
356 std::optional<std::function<mlir::Value(mlir::Operation *)>> wrap;
357 if (fnTy.getResults().size() == 1) {
358 mlir::Type ty = fnTy.getResult(0);
359 llvm::TypeSwitch<mlir::Type>(ty)
360 .template Case<fir::ComplexType>([&](fir::ComplexType cmplx) {
361 wrap = rewriteCallComplexResultType(loc, cmplx, newResTys,
362 newInTyAndAttrs, newOpers,
363 savedStackPtr);
364 })
365 .template Case<mlir::ComplexType>([&](mlir::ComplexType cmplx) {
366 wrap = rewriteCallComplexResultType(loc, cmplx, newResTys,
367 newInTyAndAttrs, newOpers,
368 savedStackPtr);
369 })
370 .Default([&](mlir::Type ty) { newResTys.push_back(ty); });
371 } else if (fnTy.getResults().size() > 1) {
372 TODO(loc, "multiple results not supported yet");
373 }
374
375 llvm::SmallVector<mlir::Type> trailingInTys;
376 llvm::SmallVector<mlir::Value> trailingOpers;
377 unsigned passArgShift = 0;
378 for (auto e : llvm::enumerate(
379 llvm::zip(fnTy.getInputs().drop_front(dropFront),
380 callOp.getOperands().drop_front(dropFront)))) {
381 mlir::Type ty = std::get<0>(e.value());
382 mlir::Value oper = std::get<1>(e.value());
383 unsigned index = e.index();
384 llvm::TypeSwitch<mlir::Type>(ty)
385 .template Case<fir::BoxCharType>([&](fir::BoxCharType boxTy) {
386 bool sret;
387 if constexpr (std::is_same_v<std::decay_t<A>, fir::CallOp>) {
388 if (noCharacterConversion) {
389 newInTyAndAttrs.push_back(
390 fir::CodeGenSpecifics::getTypeAndAttr(boxTy));
391 newOpers.push_back(oper);
392 return;
393 }
394 sret = callOp.getCallee() &&
395 functionArgIsSRet(
396 index, getModule().lookupSymbol<mlir::func::FuncOp>(
397 *callOp.getCallee()));
398 } else {
399 // TODO: dispatch case; how do we put arguments on a call?
400 // We cannot put both an sret and the dispatch object first.
401 sret = false;
402 TODO(loc, "dispatch + sret not supported yet");
403 }
404 auto m = specifics->boxcharArgumentType(boxTy.getEleTy(), sret);
405 auto unbox = rewriter->create<fir::UnboxCharOp>(
406 loc, std::get<mlir::Type>(m[0]), std::get<mlir::Type>(m[1]),
407 oper);
408 // unboxed CHARACTER arguments
409 for (auto e : llvm::enumerate(m)) {
410 unsigned idx = e.index();
411 auto attr =
412 std::get<fir::CodeGenSpecifics::Attributes>(e.value());
413 auto argTy = std::get<mlir::Type>(e.value());
414 if (attr.isAppend()) {
415 trailingInTys.push_back(argTy);
416 trailingOpers.push_back(unbox.getResult(idx));
417 } else {
418 newInTyAndAttrs.push_back(e.value());
419 newOpers.push_back(unbox.getResult(idx));
420 }
421 }
422 })
423 .template Case<fir::ComplexType>([&](fir::ComplexType cmplx) {
424 rewriteCallComplexInputType(loc, cmplx, oper, newInTyAndAttrs,
425 newOpers, savedStackPtr);
426 })
427 .template Case<mlir::ComplexType>([&](mlir::ComplexType cmplx) {
428 rewriteCallComplexInputType(loc, cmplx, oper, newInTyAndAttrs,
429 newOpers, savedStackPtr);
430 })
431 .template Case<fir::RecordType>([&](fir::RecordType recTy) {
432 rewriteCallStructInputType(loc, recTy, oper, newInTyAndAttrs,
433 newOpers, savedStackPtr);
434 })
435 .template Case<mlir::TupleType>([&](mlir::TupleType tuple) {
436 if (fir::isCharacterProcedureTuple(tuple)) {
437 mlir::ModuleOp module = getModule();
438 if constexpr (std::is_same_v<std::decay_t<A>, fir::CallOp>) {
439 if (callOp.getCallee()) {
440 llvm::StringRef charProcAttr =
441 fir::getCharacterProcedureDummyAttrName();
442 // The charProcAttr attribute is only used as a safety to
443 // confirm that this is a dummy procedure and should be split.
444 // It cannot be used to match because attributes are not
445 // available in case of indirect calls.
446 auto funcOp = module.lookupSymbol<mlir::func::FuncOp>(
447 *callOp.getCallee());
448 if (funcOp &&
449 !funcOp.template getArgAttrOfType<mlir::UnitAttr>(
450 index, charProcAttr))
451 mlir::emitError(loc, "tuple argument will be split even "
452 "though it does not have the `" +
453 charProcAttr + "` attribute");
454 }
455 }
456 mlir::Type funcPointerType = tuple.getType(0);
457 mlir::Type lenType = tuple.getType(1);
458 fir::FirOpBuilder builder(*rewriter, module);
459 auto [funcPointer, len] =
460 fir::factory::extractCharacterProcedureTuple(builder, loc,
461 oper);
462 newInTyAndAttrs.push_back(
463 fir::CodeGenSpecifics::getTypeAndAttr(funcPointerType));
464 newOpers.push_back(funcPointer);
465 trailingInTys.push_back(lenType);
466 trailingOpers.push_back(len);
467 } else {
468 newInTyAndAttrs.push_back(
469 fir::CodeGenSpecifics::getTypeAndAttr(tuple));
470 newOpers.push_back(oper);
471 }
472 })
473 .Default([&](mlir::Type ty) {
474 if constexpr (std::is_same_v<std::decay_t<A>, fir::DispatchOp>) {
475 if (callOp.getPassArgPos() && *callOp.getPassArgPos() == index)
476 passArgShift = newOpers.size() - *callOp.getPassArgPos();
477 }
478 newInTyAndAttrs.push_back(
479 fir::CodeGenSpecifics::getTypeAndAttr(ty));
480 newOpers.push_back(oper);
481 });
482 }
483
484 llvm::SmallVector<mlir::Type> newInTypes = toTypeList(newInTyAndAttrs);
485 newInTypes.insert(newInTypes.end(), trailingInTys.begin(),
486 trailingInTys.end());
487 newOpers.insert(newOpers.end(), trailingOpers.begin(), trailingOpers.end());
488
489 llvm::SmallVector<mlir::Value, 1> newCallResults;
490 if constexpr (std::is_same_v<std::decay_t<A>, fir::CallOp>) {
491 fir::CallOp newCall;
492 if (callOp.getCallee()) {
493 newCall =
494 rewriter->create<A>(loc, *callOp.getCallee(), newResTys, newOpers);
495 } else {
496 // TODO: llvm dialect must be updated to propagate argument on
497 // attributes for indirect calls. See:
498 // https://discourse.llvm.org/t/should-llvm-callop-be-able-to-carry-argument-attributes-for-indirect-calls/75431
499 if (hasByValOrSRetArgs(newInTyAndAttrs))
500 TODO(loc,
501 "passing argument or result on the stack in indirect calls");
502 newOpers[0].setType(mlir::FunctionType::get(
503 callOp.getContext(),
504 mlir::TypeRange{newInTypes}.drop_front(dropFront), newResTys));
505 newCall = rewriter->create<A>(loc, newResTys, newOpers);
506 }
507 LLVM_DEBUG(llvm::dbgs() << "replacing call with " << newCall << '\n');
508 if (wrap)
509 newCallResults.push_back((*wrap)(newCall.getOperation()));
510 else
511 newCallResults.append(newCall.result_begin(), newCall.result_end());
512 } else {
513 fir::DispatchOp dispatchOp = rewriter->create<A>(
514 loc, newResTys, rewriter->getStringAttr(callOp.getMethod()),
515 callOp.getOperands()[0], newOpers,
516 rewriter->getI32IntegerAttr(*callOp.getPassArgPos() + passArgShift));
517 if (wrap)
518 newCallResults.push_back((*wrap)(dispatchOp.getOperation()));
519 else
520 newCallResults.append(dispatchOp.result_begin(),
521 dispatchOp.result_end());
522 }
523
524 if (newCallResults.size() <= 1) {
525 if (savedStackPtr) {
526 if (newCallResults.size() == 1) {
527 // We assume that all the allocas are inserted before
528 // the operation that defines the new call result.
529 rewriter->setInsertionPointAfterValue(newCallResults[0]);
530 } else {
531 // If the call does not have results, then insert
532 // stack restore after the original call operation.
533 rewriter->setInsertionPointAfter(callOp);
534 }
535 genStackRestore(loc, savedStackPtr);
536 }
537 replaceOp(callOp, newCallResults);
538 } else {
539 // The TODO is duplicated here to make sure this part
540 // handles the stackrestore insertion properly, if
541 // we add support for multiple call results.
542 TODO(loc, "multiple results not supported yet");
543 }
544 }
545
546 // Result type fixup for fir::ComplexType and mlir::ComplexType
547 template <typename A, typename B>
548 void lowerComplexSignatureRes(
549 mlir::Location loc, A cmplx, B &newResTys,
550 fir::CodeGenSpecifics::Marshalling &newInTyAndAttrs) {
551 if (noComplexConversion) {
552 newResTys.push_back(cmplx);
553 return;
554 }
555 for (auto &tup :
556 specifics->complexReturnType(loc, cmplx.getElementType())) {
557 auto argTy = std::get<mlir::Type>(tup);
558 if (std::get<fir::CodeGenSpecifics::Attributes>(tup).isSRet())
559 newInTyAndAttrs.push_back(tup);
560 else
561 newResTys.push_back(argTy);
562 }
563 }
564
565 // Argument type fixup for fir::ComplexType and mlir::ComplexType
566 template <typename A>
567 void lowerComplexSignatureArg(
568 mlir::Location loc, A cmplx,
569 fir::CodeGenSpecifics::Marshalling &newInTyAndAttrs) {
570 if (noComplexConversion) {
571 newInTyAndAttrs.push_back(fir::CodeGenSpecifics::getTypeAndAttr(cmplx));
572 } else {
573 auto cplxArgs =
574 specifics->complexArgumentType(loc, cmplx.getElementType());
575 newInTyAndAttrs.insert(newInTyAndAttrs.end(), cplxArgs.begin(),
576 cplxArgs.end());
577 }
578 }
579
580 void
581 lowerStructSignatureArg(mlir::Location loc, fir::RecordType recTy,
582 fir::CodeGenSpecifics::Marshalling &newInTyAndAttrs) {
583 if (noStructConversion) {
584 newInTyAndAttrs.push_back(fir::CodeGenSpecifics::getTypeAndAttr(recTy));
585 return;
586 }
587 auto structArgs =
588 specifics->structArgumentType(loc, recTy, newInTyAndAttrs);
589 newInTyAndAttrs.insert(newInTyAndAttrs.end(), structArgs.begin(),
590 structArgs.end());
591 }
592
593 llvm::SmallVector<mlir::Type>
594 toTypeList(const fir::CodeGenSpecifics::Marshalling &marshalled) {
595 llvm::SmallVector<mlir::Type> typeList;
596 for (auto &typeAndAttr : marshalled)
597 typeList.emplace_back(std::get<mlir::Type>(typeAndAttr));
598 return typeList;
599 }
600
601 /// Taking the address of a function. Modify the signature as needed.
602 void convertAddrOp(fir::AddrOfOp addrOp) {
603 rewriter->setInsertionPoint(addrOp);
604 auto addrTy = addrOp.getType().cast<mlir::FunctionType>();
605 fir::CodeGenSpecifics::Marshalling newInTyAndAttrs;
606 llvm::SmallVector<mlir::Type> newResTys;
607 auto loc = addrOp.getLoc();
608 for (mlir::Type ty : addrTy.getResults()) {
609 llvm::TypeSwitch<mlir::Type>(ty)
610 .Case<fir::ComplexType>([&](fir::ComplexType ty) {
611 lowerComplexSignatureRes(loc, ty, newResTys, newInTyAndAttrs);
612 })
613 .Case<mlir::ComplexType>([&](mlir::ComplexType ty) {
614 lowerComplexSignatureRes(loc, ty, newResTys, newInTyAndAttrs);
615 })
616 .Default([&](mlir::Type ty) { newResTys.push_back(ty); });
617 }
618 llvm::SmallVector<mlir::Type> trailingInTys;
619 for (mlir::Type ty : addrTy.getInputs()) {
620 llvm::TypeSwitch<mlir::Type>(ty)
621 .Case<fir::BoxCharType>([&](auto box) {
622 if (noCharacterConversion) {
623 newInTyAndAttrs.push_back(
624 fir::CodeGenSpecifics::getTypeAndAttr(box));
625 } else {
626 for (auto &tup : specifics->boxcharArgumentType(box.getEleTy())) {
627 auto attr = std::get<fir::CodeGenSpecifics::Attributes>(tup);
628 auto argTy = std::get<mlir::Type>(tup);
629 if (attr.isAppend())
630 trailingInTys.push_back(argTy);
631 else
632 newInTyAndAttrs.push_back(tup);
633 }
634 }
635 })
636 .Case<fir::ComplexType>([&](fir::ComplexType ty) {
637 lowerComplexSignatureArg(loc, ty, newInTyAndAttrs);
638 })
639 .Case<mlir::ComplexType>([&](mlir::ComplexType ty) {
640 lowerComplexSignatureArg(loc, ty, newInTyAndAttrs);
641 })
642 .Case<mlir::TupleType>([&](mlir::TupleType tuple) {
643 if (fir::isCharacterProcedureTuple(tuple)) {
644 newInTyAndAttrs.push_back(
645 fir::CodeGenSpecifics::getTypeAndAttr(tuple.getType(0)));
646 trailingInTys.push_back(tuple.getType(1));
647 } else {
648 newInTyAndAttrs.push_back(
649 fir::CodeGenSpecifics::getTypeAndAttr(ty));
650 }
651 })
652 .template Case<fir::RecordType>([&](fir::RecordType recTy) {
653 lowerStructSignatureArg(loc, recTy, newInTyAndAttrs);
654 })
655 .Default([&](mlir::Type ty) {
656 newInTyAndAttrs.push_back(
657 fir::CodeGenSpecifics::getTypeAndAttr(ty));
658 });
659 }
660 llvm::SmallVector<mlir::Type> newInTypes = toTypeList(newInTyAndAttrs);
661 // append trailing input types
662 newInTypes.insert(newInTypes.end(), trailingInTys.begin(),
663 trailingInTys.end());
664 // replace this op with a new one with the updated signature
665 auto newTy = rewriter->getFunctionType(newInTypes, newResTys);
666 auto newOp = rewriter->create<fir::AddrOfOp>(addrOp.getLoc(), newTy,
667 addrOp.getSymbol());
668 replaceOp(addrOp, newOp.getResult());
669 }
670
671 /// Convert the type signatures on all the functions present in the module.
672 /// As the type signature is being changed, this must also update the
673 /// function itself to use any new arguments, etc.
674 mlir::LogicalResult convertTypes(mlir::ModuleOp mod) {
675 mlir::MLIRContext *ctx = mod->getContext();
676 auto targetCPU = specifics->getTargetCPU();
677 mlir::StringAttr targetCPUAttr =
678 targetCPU.empty() ? nullptr : mlir::StringAttr::get(ctx, targetCPU);
679 auto targetFeaturesAttr = specifics->getTargetFeatures();
680
681 for (auto fn : mod.getOps<mlir::func::FuncOp>()) {
682 if (targetCPUAttr)
683 fn->setAttr("target_cpu", targetCPUAttr);
684
685 if (targetFeaturesAttr)
686 fn->setAttr("target_features", targetFeaturesAttr);
687
688 convertSignature(fn);
689 }
690 return mlir::success();
691 }
692
693 // Returns true if the function should be interoperable with C.
694 static bool isFuncWithCCallingConvention(mlir::Operation *op) {
695 auto funcOp = mlir::dyn_cast<mlir::func::FuncOp>(op);
696 if (!funcOp)
697 return false;
698 return op->hasAttrOfType<mlir::UnitAttr>(
699 fir::FIROpsDialect::getFirRuntimeAttrName()) ||
700 op->hasAttrOfType<mlir::StringAttr>(fir::getSymbolAttrName());
701 }
702
703 /// If the signature does not need any special target-specific conversions,
704 /// then it is considered portable for any target, and this function will
705 /// return `true`. Otherwise, the signature is not portable and `false` is
706 /// returned.
707 bool hasPortableSignature(mlir::Type signature, mlir::Operation *op) {
708 assert(signature.isa<mlir::FunctionType>());
709 auto func = signature.dyn_cast<mlir::FunctionType>();
710 bool hasCCallingConv = isFuncWithCCallingConvention(op);
711 for (auto ty : func.getResults())
712 if ((ty.isa<fir::BoxCharType>() && !noCharacterConversion) ||
713 (fir::isa_complex(ty) && !noComplexConversion) ||
714 (ty.isa<mlir::IntegerType>() && hasCCallingConv)) {
715 LLVM_DEBUG(llvm::dbgs() << "rewrite " << signature << " for target\n");
716 return false;
717 }
718 for (auto ty : func.getInputs())
719 if (((ty.isa<fir::BoxCharType>() || fir::isCharacterProcedureTuple(ty)) &&
720 !noCharacterConversion) ||
721 (fir::isa_complex(ty) && !noComplexConversion) ||
722 (ty.isa<mlir::IntegerType>() && hasCCallingConv) ||
723 (ty.isa<fir::RecordType>() && !noStructConversion)) {
724 LLVM_DEBUG(llvm::dbgs() << "rewrite " << signature << " for target\n");
725 return false;
726 }
727 return true;
728 }
729
730 /// Determine if the signature has host associations. The host association
731 /// argument may need special target specific rewriting.
732 static bool hasHostAssociations(mlir::func::FuncOp func) {
733 std::size_t end = func.getFunctionType().getInputs().size();
734 for (std::size_t i = 0; i < end; ++i)
735 if (func.getArgAttrOfType<mlir::UnitAttr>(i, fir::getHostAssocAttrName()))
736 return true;
737 return false;
738 }
739
740 /// Rewrite the signatures and body of the `FuncOp`s in the module for
741 /// the immediately subsequent target code gen.
742 void convertSignature(mlir::func::FuncOp func) {
743 auto funcTy = func.getFunctionType().cast<mlir::FunctionType>();
744 if (hasPortableSignature(funcTy, func) && !hasHostAssociations(func))
745 return;
746 llvm::SmallVector<mlir::Type> newResTys;
747 fir::CodeGenSpecifics::Marshalling newInTyAndAttrs;
748 llvm::SmallVector<std::pair<unsigned, mlir::NamedAttribute>> savedAttrs;
749 llvm::SmallVector<std::pair<unsigned, mlir::NamedAttribute>> extraAttrs;
750 llvm::SmallVector<FixupTy> fixups;
751 llvm::SmallVector<std::pair<unsigned, mlir::NamedAttrList>, 1> resultAttrs;
752
753 // Save argument attributes in case there is a shift so we can replace them
754 // correctly.
755 for (auto e : llvm::enumerate(funcTy.getInputs())) {
756 unsigned index = e.index();
757 llvm::ArrayRef<mlir::NamedAttribute> attrs =
758 mlir::function_interface_impl::getArgAttrs(func, index);
759 for (mlir::NamedAttribute attr : attrs) {
760 savedAttrs.push_back({index, attr});
761 }
762 }
763
764 // Convert return value(s)
765 for (auto ty : funcTy.getResults())
766 llvm::TypeSwitch<mlir::Type>(ty)
767 .Case<fir::ComplexType>([&](fir::ComplexType cmplx) {
768 if (noComplexConversion)
769 newResTys.push_back(cmplx);
770 else
771 doComplexReturn(func, cmplx, newResTys, newInTyAndAttrs, fixups);
772 })
773 .Case<mlir::ComplexType>([&](mlir::ComplexType cmplx) {
774 if (noComplexConversion)
775 newResTys.push_back(cmplx);
776 else
777 doComplexReturn(func, cmplx, newResTys, newInTyAndAttrs, fixups);
778 })
779 .Case<mlir::IntegerType>([&](mlir::IntegerType intTy) {
780 auto m = specifics->integerArgumentType(func.getLoc(), intTy);
781 assert(m.size() == 1);
782 auto attr = std::get<fir::CodeGenSpecifics::Attributes>(m[0]);
783 auto retTy = std::get<mlir::Type>(m[0]);
784 std::size_t resId = newResTys.size();
785 llvm::StringRef extensionAttrName = attr.getIntExtensionAttrName();
786 if (!extensionAttrName.empty() &&
787 isFuncWithCCallingConvention(func))
788 resultAttrs.emplace_back(
789 resId, rewriter->getNamedAttr(extensionAttrName,
790 rewriter->getUnitAttr()));
791 newResTys.push_back(retTy);
792 })
793 .Default([&](mlir::Type ty) { newResTys.push_back(ty); });
794
795 // Saved potential shift in argument. Handling of result can add arguments
796 // at the beginning of the function signature.
797 unsigned argumentShift = newInTyAndAttrs.size();
798
799 // Convert arguments
800 llvm::SmallVector<mlir::Type> trailingTys;
801 for (auto e : llvm::enumerate(funcTy.getInputs())) {
802 auto ty = e.value();
803 unsigned index = e.index();
804 llvm::TypeSwitch<mlir::Type>(ty)
805 .Case<fir::BoxCharType>([&](fir::BoxCharType boxTy) {
806 if (noCharacterConversion) {
807 newInTyAndAttrs.push_back(
808 fir::CodeGenSpecifics::getTypeAndAttr(boxTy));
809 } else {
810 // Convert a CHARACTER argument type. This can involve separating
811 // the pointer and the LEN into two arguments and moving the LEN
812 // argument to the end of the arg list.
813 bool sret = functionArgIsSRet(index, func);
814 for (auto e : llvm::enumerate(specifics->boxcharArgumentType(
815 boxTy.getEleTy(), sret))) {
816 auto &tup = e.value();
817 auto index = e.index();
818 auto attr = std::get<fir::CodeGenSpecifics::Attributes>(tup);
819 auto argTy = std::get<mlir::Type>(tup);
820 if (attr.isAppend()) {
821 trailingTys.push_back(argTy);
822 } else {
823 if (sret) {
824 fixups.emplace_back(FixupTy::Codes::CharPair,
825 newInTyAndAttrs.size(), index);
826 } else {
827 fixups.emplace_back(FixupTy::Codes::Trailing,
828 newInTyAndAttrs.size(),
829 trailingTys.size());
830 }
831 newInTyAndAttrs.push_back(tup);
832 }
833 }
834 }
835 })
836 .Case<fir::ComplexType>([&](fir::ComplexType cmplx) {
837 doComplexArg(func, cmplx, newInTyAndAttrs, fixups);
838 })
839 .Case<mlir::ComplexType>([&](mlir::ComplexType cmplx) {
840 doComplexArg(func, cmplx, newInTyAndAttrs, fixups);
841 })
842 .Case<mlir::TupleType>([&](mlir::TupleType tuple) {
843 if (fir::isCharacterProcedureTuple(tuple)) {
844 fixups.emplace_back(FixupTy::Codes::TrailingCharProc,
845 newInTyAndAttrs.size(), trailingTys.size());
846 newInTyAndAttrs.push_back(
847 fir::CodeGenSpecifics::getTypeAndAttr(tuple.getType(0)));
848 trailingTys.push_back(tuple.getType(1));
849 } else {
850 newInTyAndAttrs.push_back(
851 fir::CodeGenSpecifics::getTypeAndAttr(ty));
852 }
853 })
854 .Case<mlir::IntegerType>([&](mlir::IntegerType intTy) {
855 auto m = specifics->integerArgumentType(func.getLoc(), intTy);
856 assert(m.size() == 1);
857 auto attr = std::get<fir::CodeGenSpecifics::Attributes>(m[0]);
858 auto argNo = newInTyAndAttrs.size();
859 llvm::StringRef extensionAttrName = attr.getIntExtensionAttrName();
860 if (!extensionAttrName.empty() &&
861 isFuncWithCCallingConvention(func))
862 fixups.emplace_back(FixupTy::Codes::ArgumentType, argNo,
863 [=](mlir::func::FuncOp func) {
864 func.setArgAttr(
865 argNo, extensionAttrName,
866 mlir::UnitAttr::get(func.getContext()));
867 });
868
869 newInTyAndAttrs.push_back(m[0]);
870 })
871 .template Case<fir::RecordType>([&](fir::RecordType recTy) {
872 doStructArg(func, recTy, newInTyAndAttrs, fixups);
873 })
874 .Default([&](mlir::Type ty) {
875 newInTyAndAttrs.push_back(
876 fir::CodeGenSpecifics::getTypeAndAttr(ty));
877 });
878
879 if (func.getArgAttrOfType<mlir::UnitAttr>(index,
880 fir::getHostAssocAttrName())) {
881 extraAttrs.push_back(
882 {newInTyAndAttrs.size() - 1,
883 rewriter->getNamedAttr("llvm.nest", rewriter->getUnitAttr())});
884 }
885 }
886
887 if (!func.empty()) {
888 // If the function has a body, then apply the fixups to the arguments and
889 // return ops as required. These fixups are done in place.
890 auto loc = func.getLoc();
891 const auto fixupSize = fixups.size();
892 const auto oldArgTys = func.getFunctionType().getInputs();
893 int offset = 0;
894 for (std::remove_const_t<decltype(fixupSize)> i = 0; i < fixupSize; ++i) {
895 const auto &fixup = fixups[i];
896 mlir::Type fixupType =
897 fixup.index < newInTyAndAttrs.size()
898 ? std::get<mlir::Type>(newInTyAndAttrs[fixup.index])
899 : mlir::Type{};
900 switch (fixup.code) {
901 case FixupTy::Codes::ArgumentAsLoad: {
902 // Argument was pass-by-value, but is now pass-by-reference and
903 // possibly with a different element type.
904 auto newArg =
905 func.front().insertArgument(fixup.index, fixupType, loc);
906 rewriter->setInsertionPointToStart(&func.front());
907 auto oldArgTy =
908 fir::ReferenceType::get(oldArgTys[fixup.index - offset]);
909 auto cast = rewriter->create<fir::ConvertOp>(loc, oldArgTy, newArg);
910 auto load = rewriter->create<fir::LoadOp>(loc, cast);
911 func.getArgument(fixup.index + 1).replaceAllUsesWith(load);
912 func.front().eraseArgument(fixup.index + 1);
913 } break;
914 case FixupTy::Codes::ArgumentType: {
915 // Argument is pass-by-value, but its type has likely been modified to
916 // suit the target ABI convention.
917 auto oldArgTy = oldArgTys[fixup.index - offset];
918 // If type did not change, keep the original argument.
919 if (fixupType == oldArgTy)
920 break;
921
922 auto newArg =
923 func.front().insertArgument(fixup.index, fixupType, loc);
924 rewriter->setInsertionPointToStart(&func.front());
925 mlir::Value bitcast = convertValueInMemory(loc, newArg, oldArgTy,
926 /*inputMayBeBigger=*/true);
927 func.getArgument(fixup.index + 1).replaceAllUsesWith(bitcast);
928 func.front().eraseArgument(fixup.index + 1);
929 LLVM_DEBUG(llvm::dbgs()
930 << "old argument: " << oldArgTy << ", repl: " << bitcast
931 << ", new argument: "
932 << func.getArgument(fixup.index).getType() << '\n');
933 } break;
934 case FixupTy::Codes::CharPair: {
935 // The FIR boxchar argument has been split into a pair of distinct
936 // arguments that are in juxtaposition to each other.
937 auto newArg =
938 func.front().insertArgument(fixup.index, fixupType, loc);
939 if (fixup.second == 1) {
940 rewriter->setInsertionPointToStart(&func.front());
941 auto boxTy = oldArgTys[fixup.index - offset - fixup.second];
942 auto box = rewriter->create<fir::EmboxCharOp>(
943 loc, boxTy, func.front().getArgument(fixup.index - 1), newArg);
944 func.getArgument(fixup.index + 1).replaceAllUsesWith(box);
945 func.front().eraseArgument(fixup.index + 1);
946 offset++;
947 }
948 } break;
949 case FixupTy::Codes::ReturnAsStore: {
950 // The value being returned is now being returned in memory (callee
951 // stack space) through a hidden reference argument.
952 auto newArg =
953 func.front().insertArgument(fixup.index, fixupType, loc);
954 offset++;
955 func.walk([&](mlir::func::ReturnOp ret) {
956 rewriter->setInsertionPoint(ret);
957 auto oldOper = ret.getOperand(0);
958 auto oldOperTy = fir::ReferenceType::get(oldOper.getType());
959 auto cast =
960 rewriter->create<fir::ConvertOp>(loc, oldOperTy, newArg);
961 rewriter->create<fir::StoreOp>(loc, oldOper, cast);
962 rewriter->create<mlir::func::ReturnOp>(loc);
963 ret.erase();
964 });
965 } break;
966 case FixupTy::Codes::ReturnType: {
967 // The function is still returning a value, but its type has likely
968 // changed to suit the target ABI convention.
969 func.walk([&](mlir::func::ReturnOp ret) {
970 rewriter->setInsertionPoint(ret);
971 auto oldOper = ret.getOperand(0);
972 mlir::Value bitcast =
973 convertValueInMemory(loc, oldOper, newResTys[fixup.index],
974 /*inputMayBeBigger=*/false);
975 rewriter->create<mlir::func::ReturnOp>(loc, bitcast);
976 ret.erase();
977 });
978 } break;
979 case FixupTy::Codes::Split: {
980 // The FIR argument has been split into a pair of distinct arguments
981 // that are in juxtaposition to each other. (For COMPLEX value or
982 // derived type passed with VALUE in BIND(C) context).
983 auto newArg =
984 func.front().insertArgument(fixup.index, fixupType, loc);
985 if (fixup.second == 1) {
986 rewriter->setInsertionPointToStart(&func.front());
987 mlir::Value firstArg = func.front().getArgument(fixup.index - 1);
988 mlir::Type originalTy =
989 oldArgTys[fixup.index - offset - fixup.second];
990 mlir::Type pairTy = originalTy;
991 if (!fir::isa_complex(originalTy)) {
992 pairTy = mlir::TupleType::get(
993 originalTy.getContext(),
994 mlir::TypeRange{firstArg.getType(), newArg.getType()});
995 }
996 auto undef = rewriter->create<fir::UndefOp>(loc, pairTy);
997 auto iTy = rewriter->getIntegerType(32);
998 auto zero = rewriter->getIntegerAttr(iTy, 0);
999 auto one = rewriter->getIntegerAttr(iTy, 1);
1000 mlir::Value pair1 = rewriter->create<fir::InsertValueOp>(
1001 loc, pairTy, undef, firstArg, rewriter->getArrayAttr(zero));
1002 mlir::Value pair = rewriter->create<fir::InsertValueOp>(
1003 loc, pairTy, pair1, newArg, rewriter->getArrayAttr(one));
1004 // Cast local argument tuple to original type via memory if needed.
1005 if (pairTy != originalTy)
1006 pair = convertValueInMemory(loc, pair, originalTy,
1007 /*inputMayBeBigger=*/true);
1008 func.getArgument(fixup.index + 1).replaceAllUsesWith(pair);
1009 func.front().eraseArgument(fixup.index + 1);
1010 offset++;
1011 }
1012 } break;
1013 case FixupTy::Codes::Trailing: {
1014 // The FIR argument has been split into a pair of distinct arguments.
1015 // The first part of the pair appears in the original argument
1016 // position. The second part of the pair is appended after all the
1017 // original arguments. (Boxchar arguments.)
1018 auto newBufArg =
1019 func.front().insertArgument(fixup.index, fixupType, loc);
1020 auto newLenArg =
1021 func.front().addArgument(trailingTys[fixup.second], loc);
1022 auto boxTy = oldArgTys[fixup.index - offset];
1023 rewriter->setInsertionPointToStart(&func.front());
1024 auto box = rewriter->create<fir::EmboxCharOp>(loc, boxTy, newBufArg,
1025 newLenArg);
1026 func.getArgument(fixup.index + 1).replaceAllUsesWith(box);
1027 func.front().eraseArgument(fixup.index + 1);
1028 } break;
1029 case FixupTy::Codes::TrailingCharProc: {
1030 // The FIR character procedure argument tuple must be split into a
1031 // pair of distinct arguments. The first part of the pair appears in
1032 // the original argument position. The second part of the pair is
1033 // appended after all the original arguments.
1034 auto newProcPointerArg =
1035 func.front().insertArgument(fixup.index, fixupType, loc);
1036 auto newLenArg =
1037 func.front().addArgument(trailingTys[fixup.second], loc);
1038 auto tupleType = oldArgTys[fixup.index - offset];
1039 rewriter->setInsertionPointToStart(&func.front());
1040 fir::FirOpBuilder builder(*rewriter, getModule());
1041 auto tuple = fir::factory::createCharacterProcedureTuple(
1042 builder, loc, tupleType, newProcPointerArg, newLenArg);
1043 func.getArgument(fixup.index + 1).replaceAllUsesWith(tuple);
1044 func.front().eraseArgument(fixup.index + 1);
1045 } break;
1046 }
1047 }
1048 }
1049
1050 llvm::SmallVector<mlir::Type> newInTypes = toTypeList(newInTyAndAttrs);
1051 // Set the new type and finalize the arguments, etc.
1052 newInTypes.insert(newInTypes.end(), trailingTys.begin(), trailingTys.end());
1053 auto newFuncTy =
1054 mlir::FunctionType::get(func.getContext(), newInTypes, newResTys);
1055 LLVM_DEBUG(llvm::dbgs() << "new func: " << newFuncTy << '\n');
1056 func.setType(newFuncTy);
1057
1058 for (std::pair<unsigned, mlir::NamedAttribute> extraAttr : extraAttrs)
1059 func.setArgAttr(extraAttr.first, extraAttr.second.getName(),
1060 extraAttr.second.getValue());
1061
1062 for (auto [resId, resAttrList] : resultAttrs)
1063 for (mlir::NamedAttribute resAttr : resAttrList)
1064 func.setResultAttr(resId, resAttr.getName(), resAttr.getValue());
1065
1066 // Replace attributes to the correct argument if there was an argument shift
1067 // to the right.
1068 if (argumentShift > 0) {
1069 for (std::pair<unsigned, mlir::NamedAttribute> savedAttr : savedAttrs) {
1070 func.removeArgAttr(savedAttr.first, savedAttr.second.getName());
1071 func.setArgAttr(savedAttr.first + argumentShift,
1072 savedAttr.second.getName(),
1073 savedAttr.second.getValue());
1074 }
1075 }
1076
1077 for (auto &fixup : fixups)
1078 if (fixup.finalizer)
1079 (*fixup.finalizer)(func);
1080 }
1081
1082 inline bool functionArgIsSRet(unsigned index, mlir::func::FuncOp func) {
1083 if (auto attr = func.getArgAttrOfType<mlir::TypeAttr>(index, "llvm.sret"))
1084 return true;
1085 return false;
1086 }
1087
1088 /// Convert a complex return value. This can involve converting the return
1089 /// value to a "hidden" first argument or packing the complex into a wide
1090 /// GPR.
1091 template <typename A, typename B, typename C>
1092 void doComplexReturn(mlir::func::FuncOp func, A cmplx, B &newResTys,
1093 fir::CodeGenSpecifics::Marshalling &newInTyAndAttrs,
1094 C &fixups) {
1095 if (noComplexConversion) {
1096 newResTys.push_back(cmplx);
1097 return;
1098 }
1099 auto m =
1100 specifics->complexReturnType(func.getLoc(), cmplx.getElementType());
1101 assert(m.size() == 1);
1102 auto &tup = m[0];
1103 auto attr = std::get<fir::CodeGenSpecifics::Attributes>(tup);
1104 auto argTy = std::get<mlir::Type>(tup);
1105 if (attr.isSRet()) {
1106 unsigned argNo = newInTyAndAttrs.size();
1107 if (auto align = attr.getAlignment())
1108 fixups.emplace_back(
1109 FixupTy::Codes::ReturnAsStore, argNo, [=](mlir::func::FuncOp func) {
1110 auto elemType = fir::dyn_cast_ptrOrBoxEleTy(
1111 func.getFunctionType().getInput(argNo));
1112 func.setArgAttr(argNo, "llvm.sret",
1113 mlir::TypeAttr::get(elemType));
1114 func.setArgAttr(argNo, "llvm.align",
1115 rewriter->getIntegerAttr(
1116 rewriter->getIntegerType(32), align));
1117 });
1118 else
1119 fixups.emplace_back(FixupTy::Codes::ReturnAsStore, argNo,
1120 [=](mlir::func::FuncOp func) {
1121 auto elemType = fir::dyn_cast_ptrOrBoxEleTy(
1122 func.getFunctionType().getInput(argNo));
1123 func.setArgAttr(argNo, "llvm.sret",
1124 mlir::TypeAttr::get(elemType));
1125 });
1126 newInTyAndAttrs.push_back(tup);
1127 return;
1128 }
1129 if (auto align = attr.getAlignment())
1130 fixups.emplace_back(
1131 FixupTy::Codes::ReturnType, newResTys.size(),
1132 [=](mlir::func::FuncOp func) {
1133 func.setArgAttr(
1134 newResTys.size(), "llvm.align",
1135 rewriter->getIntegerAttr(rewriter->getIntegerType(32), align));
1136 });
1137 else
1138 fixups.emplace_back(FixupTy::Codes::ReturnType, newResTys.size());
1139 newResTys.push_back(argTy);
1140 }
1141
1142 template <typename FIXUPS>
1143 void
1144 createFuncOpArgFixups(mlir::func::FuncOp func,
1145 fir::CodeGenSpecifics::Marshalling &newInTyAndAttrs,
1146 fir::CodeGenSpecifics::Marshalling &argsInTys,
1147 FIXUPS &fixups) {
1148 const auto fixupCode = argsInTys.size() > 1 ? FixupTy::Codes::Split
1149 : FixupTy::Codes::ArgumentType;
1150 for (auto e : llvm::enumerate(argsInTys)) {
1151 auto &tup = e.value();
1152 auto index = e.index();
1153 auto attr = std::get<fir::CodeGenSpecifics::Attributes>(tup);
1154 auto argNo = newInTyAndAttrs.size();
1155 if (attr.isByVal()) {
1156 if (auto align = attr.getAlignment())
1157 fixups.emplace_back(FixupTy::Codes::ArgumentAsLoad, argNo,
1158 [=](mlir::func::FuncOp func) {
1159 auto elemType = fir::dyn_cast_ptrOrBoxEleTy(
1160 func.getFunctionType().getInput(argNo));
1161 func.setArgAttr(argNo, "llvm.byval",
1162 mlir::TypeAttr::get(elemType));
1163 func.setArgAttr(
1164 argNo, "llvm.align",
1165 rewriter->getIntegerAttr(
1166 rewriter->getIntegerType(32), align));
1167 });
1168 else
1169 fixups.emplace_back(FixupTy::Codes::ArgumentAsLoad,
1170 newInTyAndAttrs.size(),
1171 [=](mlir::func::FuncOp func) {
1172 auto elemType = fir::dyn_cast_ptrOrBoxEleTy(
1173 func.getFunctionType().getInput(argNo));
1174 func.setArgAttr(argNo, "llvm.byval",
1175 mlir::TypeAttr::get(elemType));
1176 });
1177 } else {
1178 if (auto align = attr.getAlignment())
1179 fixups.emplace_back(
1180 fixupCode, argNo, index, [=](mlir::func::FuncOp func) {
1181 func.setArgAttr(argNo, "llvm.align",
1182 rewriter->getIntegerAttr(
1183 rewriter->getIntegerType(32), align));
1184 });
1185 else
1186 fixups.emplace_back(fixupCode, argNo, index);
1187 }
1188 newInTyAndAttrs.push_back(tup);
1189 }
1190 }
1191
1192 /// Convert a complex argument value. This can involve storing the value to
1193 /// a temporary memory location or factoring the value into two distinct
1194 /// arguments.
1195 template <typename A, typename B>
1196 void doComplexArg(mlir::func::FuncOp func, A cmplx,
1197 fir::CodeGenSpecifics::Marshalling &newInTyAndAttrs,
1198 B &fixups) {
1199 if (noComplexConversion) {
1200 newInTyAndAttrs.push_back(fir::CodeGenSpecifics::getTypeAndAttr(cmplx));
1201 return;
1202 }
1203 auto cplxArgs =
1204 specifics->complexArgumentType(func.getLoc(), cmplx.getElementType());
1205 createFuncOpArgFixups(func, newInTyAndAttrs, cplxArgs, fixups);
1206 }
1207
1208 template <typename FIXUPS>
1209 void doStructArg(mlir::func::FuncOp func, fir::RecordType recTy,
1210 fir::CodeGenSpecifics::Marshalling &newInTyAndAttrs,
1211 FIXUPS &fixups) {
1212 if (noStructConversion) {
1213 newInTyAndAttrs.push_back(fir::CodeGenSpecifics::getTypeAndAttr(recTy));
1214 return;
1215 }
1216 auto structArgs =
1217 specifics->structArgumentType(func.getLoc(), recTy, newInTyAndAttrs);
1218 createFuncOpArgFixups(func, newInTyAndAttrs, structArgs, fixups);
1219 }
1220
1221private:
1222 // Replace `op` and remove it.
1223 void replaceOp(mlir::Operation *op, mlir::ValueRange newValues) {
1224 op->replaceAllUsesWith(newValues);
1225 op->dropAllReferences();
1226 op->erase();
1227 }
1228
1229 inline void setMembers(fir::CodeGenSpecifics *s, mlir::OpBuilder *r,
1230 mlir::DataLayout *dl) {
1231 specifics = s;
1232 rewriter = r;
1233 dataLayout = dl;
1234 }
1235
1236 inline void clearMembers() { setMembers(nullptr, nullptr, nullptr); }
1237
1238 // Inserts a call to llvm.stacksave at the current insertion
1239 // point and the given location. Returns the call's result Value.
1240 inline mlir::Value genStackSave(mlir::Location loc) {
1241 return rewriter->create<fir::CallOp>(loc, stackSaveFn).getResult(0);
1242 }
1243
1244 // Inserts a call to llvm.stackrestore at the current insertion
1245 // point and the given location and argument.
1246 inline void genStackRestore(mlir::Location loc, mlir::Value sp) {
1247 rewriter->create<fir::CallOp>(loc, stackRestoreFn, mlir::ValueRange{sp});
1248 }
1249
1250 fir::CodeGenSpecifics *specifics = nullptr;
1251 mlir::OpBuilder *rewriter = nullptr;
1252 mlir::DataLayout *dataLayout = nullptr;
1253 mlir::func::FuncOp stackSaveFn = nullptr;
1254 mlir::func::FuncOp stackRestoreFn = nullptr;
1255};
1256} // namespace
1257
1258std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
1259fir::createFirTargetRewritePass(const fir::TargetRewriteOptions &options) {
1260 return std::make_unique<TargetRewrite>(options);
1261}
1262

source code of flang/lib/Optimizer/CodeGen/TargetRewrite.cpp