1//===-- AArch64Arm64ECCallLowering.cpp - Lower Arm64EC calls ----*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8///
9/// \file
10/// This file contains the IR transform to lower external or indirect calls for
11/// the ARM64EC calling convention. Such calls must go through the runtime, so
12/// we can translate the calling convention for calls into the emulator.
13///
14/// This subsumes Control Flow Guard handling.
15///
16//===----------------------------------------------------------------------===//
17
18#include "AArch64.h"
19#include "llvm/ADT/SetVector.h"
20#include "llvm/ADT/SmallString.h"
21#include "llvm/ADT/SmallVector.h"
22#include "llvm/ADT/Statistic.h"
23#include "llvm/IR/CallingConv.h"
24#include "llvm/IR/IRBuilder.h"
25#include "llvm/IR/Instruction.h"
26#include "llvm/IR/Mangler.h"
27#include "llvm/InitializePasses.h"
28#include "llvm/Object/COFF.h"
29#include "llvm/Pass.h"
30#include "llvm/Support/CommandLine.h"
31#include "llvm/TargetParser/Triple.h"
32
33using namespace llvm;
34using namespace llvm::COFF;
35
36using OperandBundleDef = OperandBundleDefT<Value *>;
37
38#define DEBUG_TYPE "arm64eccalllowering"
39
40STATISTIC(Arm64ECCallsLowered, "Number of Arm64EC calls lowered");
41
42static cl::opt<bool> LowerDirectToIndirect("arm64ec-lower-direct-to-indirect",
43 cl::Hidden, cl::init(Val: true));
44static cl::opt<bool> GenerateThunks("arm64ec-generate-thunks", cl::Hidden,
45 cl::init(Val: true));
46
47namespace {
48
49class AArch64Arm64ECCallLowering : public ModulePass {
50public:
51 static char ID;
52 AArch64Arm64ECCallLowering() : ModulePass(ID) {
53 initializeAArch64Arm64ECCallLoweringPass(*PassRegistry::getPassRegistry());
54 }
55
56 Function *buildExitThunk(FunctionType *FnTy, AttributeList Attrs);
57 Function *buildEntryThunk(Function *F);
58 void lowerCall(CallBase *CB);
59 Function *buildGuestExitThunk(Function *F);
60 bool processFunction(Function &F, SetVector<Function *> &DirectCalledFns);
61 bool runOnModule(Module &M) override;
62
63private:
64 int cfguard_module_flag = 0;
65 FunctionType *GuardFnType = nullptr;
66 PointerType *GuardFnPtrType = nullptr;
67 Constant *GuardFnCFGlobal = nullptr;
68 Constant *GuardFnGlobal = nullptr;
69 Module *M = nullptr;
70
71 Type *PtrTy;
72 Type *I64Ty;
73 Type *VoidTy;
74
75 void getThunkType(FunctionType *FT, AttributeList AttrList,
76 Arm64ECThunkType TT, raw_ostream &Out,
77 FunctionType *&Arm64Ty, FunctionType *&X64Ty);
78 void getThunkRetType(FunctionType *FT, AttributeList AttrList,
79 raw_ostream &Out, Type *&Arm64RetTy, Type *&X64RetTy,
80 SmallVectorImpl<Type *> &Arm64ArgTypes,
81 SmallVectorImpl<Type *> &X64ArgTypes, bool &HasSretPtr);
82 void getThunkArgTypes(FunctionType *FT, AttributeList AttrList,
83 Arm64ECThunkType TT, raw_ostream &Out,
84 SmallVectorImpl<Type *> &Arm64ArgTypes,
85 SmallVectorImpl<Type *> &X64ArgTypes, bool HasSretPtr);
86 void canonicalizeThunkType(Type *T, Align Alignment, bool Ret,
87 uint64_t ArgSizeBytes, raw_ostream &Out,
88 Type *&Arm64Ty, Type *&X64Ty);
89};
90
91} // end anonymous namespace
92
93void AArch64Arm64ECCallLowering::getThunkType(
94 FunctionType *FT, AttributeList AttrList, Arm64ECThunkType TT,
95 raw_ostream &Out, FunctionType *&Arm64Ty, FunctionType *&X64Ty) {
96 Out << (TT == Arm64ECThunkType::Entry ? "$ientry_thunk$cdecl$"
97 : "$iexit_thunk$cdecl$");
98
99 Type *Arm64RetTy;
100 Type *X64RetTy;
101
102 SmallVector<Type *> Arm64ArgTypes;
103 SmallVector<Type *> X64ArgTypes;
104
105 // The first argument to a thunk is the called function, stored in x9.
106 // For exit thunks, we pass the called function down to the emulator;
107 // for entry/guest exit thunks, we just call the Arm64 function directly.
108 if (TT == Arm64ECThunkType::Exit)
109 Arm64ArgTypes.push_back(Elt: PtrTy);
110 X64ArgTypes.push_back(Elt: PtrTy);
111
112 bool HasSretPtr = false;
113 getThunkRetType(FT, AttrList, Out, Arm64RetTy, X64RetTy, Arm64ArgTypes,
114 X64ArgTypes, HasSretPtr);
115
116 getThunkArgTypes(FT, AttrList, TT, Out, Arm64ArgTypes, X64ArgTypes,
117 HasSretPtr);
118
119 Arm64Ty = FunctionType::get(Result: Arm64RetTy, Params: Arm64ArgTypes, isVarArg: false);
120
121 X64Ty = FunctionType::get(Result: X64RetTy, Params: X64ArgTypes, isVarArg: false);
122}
123
124void AArch64Arm64ECCallLowering::getThunkArgTypes(
125 FunctionType *FT, AttributeList AttrList, Arm64ECThunkType TT,
126 raw_ostream &Out, SmallVectorImpl<Type *> &Arm64ArgTypes,
127 SmallVectorImpl<Type *> &X64ArgTypes, bool HasSretPtr) {
128
129 Out << "$";
130 if (FT->isVarArg()) {
131 // We treat the variadic function's thunk as a normal function
132 // with the following type on the ARM side:
133 // rettype exitthunk(
134 // ptr x9, ptr x0, i64 x1, i64 x2, i64 x3, ptr x4, i64 x5)
135 //
136 // that can coverage all types of variadic function.
137 // x9 is similar to normal exit thunk, store the called function.
138 // x0-x3 is the arguments be stored in registers.
139 // x4 is the address of the arguments on the stack.
140 // x5 is the size of the arguments on the stack.
141 //
142 // On the x64 side, it's the same except that x5 isn't set.
143 //
144 // If both the ARM and X64 sides are sret, there are only three
145 // arguments in registers.
146 //
147 // If the X64 side is sret, but the ARM side isn't, we pass an extra value
148 // to/from the X64 side, and let SelectionDAG transform it into a memory
149 // location.
150 Out << "varargs";
151
152 // x0-x3
153 for (int i = HasSretPtr ? 1 : 0; i < 4; i++) {
154 Arm64ArgTypes.push_back(Elt: I64Ty);
155 X64ArgTypes.push_back(Elt: I64Ty);
156 }
157
158 // x4
159 Arm64ArgTypes.push_back(Elt: PtrTy);
160 X64ArgTypes.push_back(Elt: PtrTy);
161 // x5
162 Arm64ArgTypes.push_back(Elt: I64Ty);
163 if (TT != Arm64ECThunkType::Entry) {
164 // FIXME: x5 isn't actually used by the x64 side; revisit once we
165 // have proper isel for varargs
166 X64ArgTypes.push_back(Elt: I64Ty);
167 }
168 return;
169 }
170
171 unsigned I = 0;
172 if (HasSretPtr)
173 I++;
174
175 if (I == FT->getNumParams()) {
176 Out << "v";
177 return;
178 }
179
180 for (unsigned E = FT->getNumParams(); I != E; ++I) {
181 Align ParamAlign = AttrList.getParamAlignment(ArgNo: I).valueOrOne();
182#if 0
183 // FIXME: Need more information about argument size; see
184 // https://reviews.llvm.org/D132926
185 uint64_t ArgSizeBytes = AttrList.getParamArm64ECArgSizeBytes(I);
186#else
187 uint64_t ArgSizeBytes = 0;
188#endif
189 Type *Arm64Ty, *X64Ty;
190 canonicalizeThunkType(T: FT->getParamType(i: I), Alignment: ParamAlign,
191 /*Ret*/ false, ArgSizeBytes, Out, Arm64Ty, X64Ty);
192 Arm64ArgTypes.push_back(Elt: Arm64Ty);
193 X64ArgTypes.push_back(Elt: X64Ty);
194 }
195}
196
197void AArch64Arm64ECCallLowering::getThunkRetType(
198 FunctionType *FT, AttributeList AttrList, raw_ostream &Out,
199 Type *&Arm64RetTy, Type *&X64RetTy, SmallVectorImpl<Type *> &Arm64ArgTypes,
200 SmallVectorImpl<Type *> &X64ArgTypes, bool &HasSretPtr) {
201 Type *T = FT->getReturnType();
202#if 0
203 // FIXME: Need more information about argument size; see
204 // https://reviews.llvm.org/D132926
205 uint64_t ArgSizeBytes = AttrList.getRetArm64ECArgSizeBytes();
206#else
207 int64_t ArgSizeBytes = 0;
208#endif
209 if (T->isVoidTy()) {
210 if (FT->getNumParams()) {
211 auto SRetAttr = AttrList.getParamAttr(0, Attribute::StructRet);
212 auto InRegAttr = AttrList.getParamAttr(0, Attribute::InReg);
213 if (SRetAttr.isValid() && InRegAttr.isValid()) {
214 // sret+inreg indicates a call that returns a C++ class value. This is
215 // actually equivalent to just passing and returning a void* pointer
216 // as the first argument. Translate it that way, instead of trying
217 // to model "inreg" in the thunk's calling convention, to simplify
218 // the rest of the code.
219 Out << "i8";
220 Arm64RetTy = I64Ty;
221 X64RetTy = I64Ty;
222 return;
223 }
224 if (SRetAttr.isValid()) {
225 // FIXME: Sanity-check the sret type; if it's an integer or pointer,
226 // we'll get screwy mangling/codegen.
227 // FIXME: For large struct types, mangle as an integer argument and
228 // integer return, so we can reuse more thunks, instead of "m" syntax.
229 // (MSVC mangles this case as an integer return with no argument, but
230 // that's a miscompile.)
231 Type *SRetType = SRetAttr.getValueAsType();
232 Align SRetAlign = AttrList.getParamAlignment(ArgNo: 0).valueOrOne();
233 Type *Arm64Ty, *X64Ty;
234 canonicalizeThunkType(T: SRetType, Alignment: SRetAlign, /*Ret*/ true, ArgSizeBytes,
235 Out, Arm64Ty, X64Ty);
236 Arm64RetTy = VoidTy;
237 X64RetTy = VoidTy;
238 Arm64ArgTypes.push_back(Elt: FT->getParamType(i: 0));
239 X64ArgTypes.push_back(Elt: FT->getParamType(i: 0));
240 HasSretPtr = true;
241 return;
242 }
243 }
244
245 Out << "v";
246 Arm64RetTy = VoidTy;
247 X64RetTy = VoidTy;
248 return;
249 }
250
251 canonicalizeThunkType(T, Alignment: Align(), /*Ret*/ true, ArgSizeBytes, Out, Arm64Ty&: Arm64RetTy,
252 X64Ty&: X64RetTy);
253 if (X64RetTy->isPointerTy()) {
254 // If the X64 type is canonicalized to a pointer, that means it's
255 // passed/returned indirectly. For a return value, that means it's an
256 // sret pointer.
257 X64ArgTypes.push_back(Elt: X64RetTy);
258 X64RetTy = VoidTy;
259 }
260}
261
262void AArch64Arm64ECCallLowering::canonicalizeThunkType(
263 Type *T, Align Alignment, bool Ret, uint64_t ArgSizeBytes, raw_ostream &Out,
264 Type *&Arm64Ty, Type *&X64Ty) {
265 if (T->isFloatTy()) {
266 Out << "f";
267 Arm64Ty = T;
268 X64Ty = T;
269 return;
270 }
271
272 if (T->isDoubleTy()) {
273 Out << "d";
274 Arm64Ty = T;
275 X64Ty = T;
276 return;
277 }
278
279 if (T->isFloatingPointTy()) {
280 report_fatal_error(
281 reason: "Only 32 and 64 bit floating points are supported for ARM64EC thunks");
282 }
283
284 auto &DL = M->getDataLayout();
285
286 if (auto *StructTy = dyn_cast<StructType>(Val: T))
287 if (StructTy->getNumElements() == 1)
288 T = StructTy->getElementType(N: 0);
289
290 if (T->isArrayTy()) {
291 Type *ElementTy = T->getArrayElementType();
292 uint64_t ElementCnt = T->getArrayNumElements();
293 uint64_t ElementSizePerBytes = DL.getTypeSizeInBits(Ty: ElementTy) / 8;
294 uint64_t TotalSizeBytes = ElementCnt * ElementSizePerBytes;
295 if (ElementTy->isFloatTy() || ElementTy->isDoubleTy()) {
296 Out << (ElementTy->isFloatTy() ? "F" : "D") << TotalSizeBytes;
297 if (Alignment.value() >= 8 && !T->isPointerTy())
298 Out << "a" << Alignment.value();
299 Arm64Ty = T;
300 if (TotalSizeBytes <= 8) {
301 // Arm64 returns small structs of float/double in float registers;
302 // X64 uses RAX.
303 X64Ty = llvm::Type::getIntNTy(C&: M->getContext(), N: TotalSizeBytes * 8);
304 } else {
305 // Struct is passed directly on Arm64, but indirectly on X64.
306 X64Ty = PtrTy;
307 }
308 return;
309 } else if (T->isFloatingPointTy()) {
310 report_fatal_error(reason: "Only 32 and 64 bit floating points are supported for "
311 "ARM64EC thunks");
312 }
313 }
314
315 if ((T->isIntegerTy() || T->isPointerTy()) && DL.getTypeSizeInBits(Ty: T) <= 64) {
316 Out << "i8";
317 Arm64Ty = I64Ty;
318 X64Ty = I64Ty;
319 return;
320 }
321
322 unsigned TypeSize = ArgSizeBytes;
323 if (TypeSize == 0)
324 TypeSize = DL.getTypeSizeInBits(Ty: T) / 8;
325 Out << "m";
326 if (TypeSize != 4)
327 Out << TypeSize;
328 if (Alignment.value() >= 8 && !T->isPointerTy())
329 Out << "a" << Alignment.value();
330 // FIXME: Try to canonicalize Arm64Ty more thoroughly?
331 Arm64Ty = T;
332 if (TypeSize == 1 || TypeSize == 2 || TypeSize == 4 || TypeSize == 8) {
333 // Pass directly in an integer register
334 X64Ty = llvm::Type::getIntNTy(C&: M->getContext(), N: TypeSize * 8);
335 } else {
336 // Passed directly on Arm64, but indirectly on X64.
337 X64Ty = PtrTy;
338 }
339}
340
341// This function builds the "exit thunk", a function which translates
342// arguments and return values when calling x64 code from AArch64 code.
343Function *AArch64Arm64ECCallLowering::buildExitThunk(FunctionType *FT,
344 AttributeList Attrs) {
345 SmallString<256> ExitThunkName;
346 llvm::raw_svector_ostream ExitThunkStream(ExitThunkName);
347 FunctionType *Arm64Ty, *X64Ty;
348 getThunkType(FT, AttrList: Attrs, TT: Arm64ECThunkType::Exit, Out&: ExitThunkStream, Arm64Ty,
349 X64Ty);
350 if (Function *F = M->getFunction(Name: ExitThunkName))
351 return F;
352
353 Function *F = Function::Create(Ty: Arm64Ty, Linkage: GlobalValue::LinkOnceODRLinkage, AddrSpace: 0,
354 N: ExitThunkName, M);
355 F->setCallingConv(CallingConv::ARM64EC_Thunk_Native);
356 F->setSection(".wowthk$aa");
357 F->setComdat(M->getOrInsertComdat(Name: ExitThunkName));
358 // Copy MSVC, and always set up a frame pointer. (Maybe this isn't necessary.)
359 F->addFnAttr(Kind: "frame-pointer", Val: "all");
360 // Only copy sret from the first argument. For C++ instance methods, clang can
361 // stick an sret marking on a later argument, but it doesn't actually affect
362 // the ABI, so we can omit it. This avoids triggering a verifier assertion.
363 if (FT->getNumParams()) {
364 auto SRet = Attrs.getParamAttr(0, Attribute::StructRet);
365 auto InReg = Attrs.getParamAttr(0, Attribute::InReg);
366 if (SRet.isValid() && !InReg.isValid())
367 F->addParamAttr(1, SRet);
368 }
369 // FIXME: Copy anything other than sret? Shouldn't be necessary for normal
370 // C ABI, but might show up in other cases.
371 BasicBlock *BB = BasicBlock::Create(Context&: M->getContext(), Name: "", Parent: F);
372 IRBuilder<> IRB(BB);
373 Value *CalleePtr =
374 M->getOrInsertGlobal(Name: "__os_arm64x_dispatch_call_no_redirect", Ty: PtrTy);
375 Value *Callee = IRB.CreateLoad(Ty: PtrTy, Ptr: CalleePtr);
376 auto &DL = M->getDataLayout();
377 SmallVector<Value *> Args;
378
379 // Pass the called function in x9.
380 Args.push_back(Elt: F->arg_begin());
381
382 Type *RetTy = Arm64Ty->getReturnType();
383 if (RetTy != X64Ty->getReturnType()) {
384 // If the return type is an array or struct, translate it. Values of size
385 // 8 or less go into RAX; bigger values go into memory, and we pass a
386 // pointer.
387 if (DL.getTypeStoreSize(Ty: RetTy) > 8) {
388 Args.push_back(Elt: IRB.CreateAlloca(Ty: RetTy));
389 }
390 }
391
392 for (auto &Arg : make_range(x: F->arg_begin() + 1, y: F->arg_end())) {
393 // Translate arguments from AArch64 calling convention to x86 calling
394 // convention.
395 //
396 // For simple types, we don't need to do any translation: they're
397 // represented the same way. (Implicit sign extension is not part of
398 // either convention.)
399 //
400 // The big thing we have to worry about is struct types... but
401 // fortunately AArch64 clang is pretty friendly here: the cases that need
402 // translation are always passed as a struct or array. (If we run into
403 // some cases where this doesn't work, we can teach clang to mark it up
404 // with an attribute.)
405 //
406 // The first argument is the called function, stored in x9.
407 if (Arg.getType()->isArrayTy() || Arg.getType()->isStructTy() ||
408 DL.getTypeStoreSize(Ty: Arg.getType()) > 8) {
409 Value *Mem = IRB.CreateAlloca(Ty: Arg.getType());
410 IRB.CreateStore(Val: &Arg, Ptr: Mem);
411 if (DL.getTypeStoreSize(Ty: Arg.getType()) <= 8) {
412 Type *IntTy = IRB.getIntNTy(N: DL.getTypeStoreSizeInBits(Ty: Arg.getType()));
413 Args.push_back(Elt: IRB.CreateLoad(Ty: IntTy, Ptr: IRB.CreateBitCast(V: Mem, DestTy: PtrTy)));
414 } else
415 Args.push_back(Elt: Mem);
416 } else {
417 Args.push_back(Elt: &Arg);
418 }
419 }
420 // FIXME: Transfer necessary attributes? sret? anything else?
421
422 Callee = IRB.CreateBitCast(V: Callee, DestTy: PtrTy);
423 CallInst *Call = IRB.CreateCall(FTy: X64Ty, Callee, Args);
424 Call->setCallingConv(CallingConv::ARM64EC_Thunk_X64);
425
426 Value *RetVal = Call;
427 if (RetTy != X64Ty->getReturnType()) {
428 // If we rewrote the return type earlier, convert the return value to
429 // the proper type.
430 if (DL.getTypeStoreSize(Ty: RetTy) > 8) {
431 RetVal = IRB.CreateLoad(Ty: RetTy, Ptr: Args[1]);
432 } else {
433 Value *CastAlloca = IRB.CreateAlloca(Ty: RetTy);
434 IRB.CreateStore(Val: Call, Ptr: IRB.CreateBitCast(V: CastAlloca, DestTy: PtrTy));
435 RetVal = IRB.CreateLoad(Ty: RetTy, Ptr: CastAlloca);
436 }
437 }
438
439 if (RetTy->isVoidTy())
440 IRB.CreateRetVoid();
441 else
442 IRB.CreateRet(V: RetVal);
443 return F;
444}
445
446// This function builds the "entry thunk", a function which translates
447// arguments and return values when calling AArch64 code from x64 code.
448Function *AArch64Arm64ECCallLowering::buildEntryThunk(Function *F) {
449 SmallString<256> EntryThunkName;
450 llvm::raw_svector_ostream EntryThunkStream(EntryThunkName);
451 FunctionType *Arm64Ty, *X64Ty;
452 getThunkType(FT: F->getFunctionType(), AttrList: F->getAttributes(),
453 TT: Arm64ECThunkType::Entry, Out&: EntryThunkStream, Arm64Ty, X64Ty);
454 if (Function *F = M->getFunction(Name: EntryThunkName))
455 return F;
456
457 Function *Thunk = Function::Create(Ty: X64Ty, Linkage: GlobalValue::LinkOnceODRLinkage, AddrSpace: 0,
458 N: EntryThunkName, M);
459 Thunk->setCallingConv(CallingConv::ARM64EC_Thunk_X64);
460 Thunk->setSection(".wowthk$aa");
461 Thunk->setComdat(M->getOrInsertComdat(Name: EntryThunkName));
462 // Copy MSVC, and always set up a frame pointer. (Maybe this isn't necessary.)
463 Thunk->addFnAttr(Kind: "frame-pointer", Val: "all");
464
465 auto &DL = M->getDataLayout();
466 BasicBlock *BB = BasicBlock::Create(Context&: M->getContext(), Name: "", Parent: Thunk);
467 IRBuilder<> IRB(BB);
468
469 Type *RetTy = Arm64Ty->getReturnType();
470 Type *X64RetType = X64Ty->getReturnType();
471
472 bool TransformDirectToSRet = X64RetType->isVoidTy() && !RetTy->isVoidTy();
473 unsigned ThunkArgOffset = TransformDirectToSRet ? 2 : 1;
474 unsigned PassthroughArgSize = F->isVarArg() ? 5 : Thunk->arg_size();
475
476 // Translate arguments to call.
477 SmallVector<Value *> Args;
478 for (unsigned i = ThunkArgOffset, e = PassthroughArgSize; i != e; ++i) {
479 Value *Arg = Thunk->getArg(i);
480 Type *ArgTy = Arm64Ty->getParamType(i: i - ThunkArgOffset);
481 if (ArgTy->isArrayTy() || ArgTy->isStructTy() ||
482 DL.getTypeStoreSize(Ty: ArgTy) > 8) {
483 // Translate array/struct arguments to the expected type.
484 if (DL.getTypeStoreSize(Ty: ArgTy) <= 8) {
485 Value *CastAlloca = IRB.CreateAlloca(Ty: ArgTy);
486 IRB.CreateStore(Val: Arg, Ptr: IRB.CreateBitCast(V: CastAlloca, DestTy: PtrTy));
487 Arg = IRB.CreateLoad(Ty: ArgTy, Ptr: CastAlloca);
488 } else {
489 Arg = IRB.CreateLoad(Ty: ArgTy, Ptr: IRB.CreateBitCast(V: Arg, DestTy: PtrTy));
490 }
491 }
492 Args.push_back(Elt: Arg);
493 }
494
495 if (F->isVarArg()) {
496 // The 5th argument to variadic entry thunks is used to model the x64 sp
497 // which is passed to the thunk in x4, this can be passed to the callee as
498 // the variadic argument start address after skipping over the 32 byte
499 // shadow store.
500
501 // The EC thunk CC will assign any argument marked as InReg to x4.
502 Thunk->addParamAttr(5, Attribute::InReg);
503 Value *Arg = Thunk->getArg(i: 5);
504 Arg = IRB.CreatePtrAdd(Ptr: Arg, Offset: IRB.getInt64(C: 0x20));
505 Args.push_back(Elt: Arg);
506
507 // Pass in a zero variadic argument size (in x5).
508 Args.push_back(Elt: IRB.getInt64(C: 0));
509 }
510
511 // Call the function passed to the thunk.
512 Value *Callee = Thunk->getArg(i: 0);
513 Callee = IRB.CreateBitCast(V: Callee, DestTy: PtrTy);
514 Value *Call = IRB.CreateCall(FTy: Arm64Ty, Callee, Args);
515
516 Value *RetVal = Call;
517 if (TransformDirectToSRet) {
518 IRB.CreateStore(Val: RetVal, Ptr: IRB.CreateBitCast(V: Thunk->getArg(i: 1), DestTy: PtrTy));
519 } else if (X64RetType != RetTy) {
520 Value *CastAlloca = IRB.CreateAlloca(Ty: X64RetType);
521 IRB.CreateStore(Val: Call, Ptr: IRB.CreateBitCast(V: CastAlloca, DestTy: PtrTy));
522 RetVal = IRB.CreateLoad(Ty: X64RetType, Ptr: CastAlloca);
523 }
524
525 // Return to the caller. Note that the isel has code to translate this
526 // "ret" to a tail call to __os_arm64x_dispatch_ret. (Alternatively, we
527 // could emit a tail call here, but that would require a dedicated calling
528 // convention, which seems more complicated overall.)
529 if (X64RetType->isVoidTy())
530 IRB.CreateRetVoid();
531 else
532 IRB.CreateRet(V: RetVal);
533
534 return Thunk;
535}
536
537// Builds the "guest exit thunk", a helper to call a function which may or may
538// not be an exit thunk. (We optimistically assume non-dllimport function
539// declarations refer to functions defined in AArch64 code; if the linker
540// can't prove that, we use this routine instead.)
541Function *AArch64Arm64ECCallLowering::buildGuestExitThunk(Function *F) {
542 llvm::raw_null_ostream NullThunkName;
543 FunctionType *Arm64Ty, *X64Ty;
544 getThunkType(FT: F->getFunctionType(), AttrList: F->getAttributes(),
545 TT: Arm64ECThunkType::GuestExit, Out&: NullThunkName, Arm64Ty, X64Ty);
546 auto MangledName = getArm64ECMangledFunctionName(Name: F->getName().str());
547 assert(MangledName && "Can't guest exit to function that's already native");
548 std::string ThunkName = *MangledName;
549 if (ThunkName[0] == '?' && ThunkName.find(s: "@") != std::string::npos) {
550 ThunkName.insert(pos: ThunkName.find(s: "@"), s: "$exit_thunk");
551 } else {
552 ThunkName.append(s: "$exit_thunk");
553 }
554 Function *GuestExit =
555 Function::Create(Ty: Arm64Ty, Linkage: GlobalValue::WeakODRLinkage, AddrSpace: 0, N: ThunkName, M);
556 GuestExit->setComdat(M->getOrInsertComdat(Name: ThunkName));
557 GuestExit->setSection(".wowthk$aa");
558 GuestExit->setMetadata(
559 Kind: "arm64ec_unmangled_name",
560 Node: MDNode::get(Context&: M->getContext(),
561 MDs: MDString::get(Context&: M->getContext(), Str: F->getName())));
562 GuestExit->setMetadata(
563 Kind: "arm64ec_ecmangled_name",
564 Node: MDNode::get(Context&: M->getContext(),
565 MDs: MDString::get(Context&: M->getContext(), Str: *MangledName)));
566 F->setMetadata(Kind: "arm64ec_hasguestexit", Node: MDNode::get(Context&: M->getContext(), MDs: {}));
567 BasicBlock *BB = BasicBlock::Create(Context&: M->getContext(), Name: "", Parent: GuestExit);
568 IRBuilder<> B(BB);
569
570 // Load the global symbol as a pointer to the check function.
571 Value *GuardFn;
572 if (cfguard_module_flag == 2 && !F->hasFnAttribute(Kind: "guard_nocf"))
573 GuardFn = GuardFnCFGlobal;
574 else
575 GuardFn = GuardFnGlobal;
576 LoadInst *GuardCheckLoad = B.CreateLoad(Ty: GuardFnPtrType, Ptr: GuardFn);
577
578 // Create new call instruction. The CFGuard check should always be a call,
579 // even if the original CallBase is an Invoke or CallBr instruction.
580 Function *Thunk = buildExitThunk(FT: F->getFunctionType(), Attrs: F->getAttributes());
581 CallInst *GuardCheck = B.CreateCall(
582 FTy: GuardFnType, Callee: GuardCheckLoad,
583 Args: {B.CreateBitCast(V: F, DestTy: B.getPtrTy()), B.CreateBitCast(V: Thunk, DestTy: B.getPtrTy())});
584
585 // Ensure that the first argument is passed in the correct register.
586 GuardCheck->setCallingConv(CallingConv::CFGuard_Check);
587
588 Value *GuardRetVal = B.CreateBitCast(V: GuardCheck, DestTy: PtrTy);
589 SmallVector<Value *> Args;
590 for (Argument &Arg : GuestExit->args())
591 Args.push_back(Elt: &Arg);
592 CallInst *Call = B.CreateCall(FTy: Arm64Ty, Callee: GuardRetVal, Args);
593 Call->setTailCallKind(llvm::CallInst::TCK_MustTail);
594
595 if (Call->getType()->isVoidTy())
596 B.CreateRetVoid();
597 else
598 B.CreateRet(V: Call);
599
600 auto SRetAttr = F->getAttributes().getParamAttr(0, Attribute::StructRet);
601 auto InRegAttr = F->getAttributes().getParamAttr(0, Attribute::InReg);
602 if (SRetAttr.isValid() && !InRegAttr.isValid()) {
603 GuestExit->addParamAttr(0, SRetAttr);
604 Call->addParamAttr(0, SRetAttr);
605 }
606
607 return GuestExit;
608}
609
610// Lower an indirect call with inline code.
611void AArch64Arm64ECCallLowering::lowerCall(CallBase *CB) {
612 assert(Triple(CB->getModule()->getTargetTriple()).isOSWindows() &&
613 "Only applicable for Windows targets");
614
615 IRBuilder<> B(CB);
616 Value *CalledOperand = CB->getCalledOperand();
617
618 // If the indirect call is called within catchpad or cleanuppad,
619 // we need to copy "funclet" bundle of the call.
620 SmallVector<llvm::OperandBundleDef, 1> Bundles;
621 if (auto Bundle = CB->getOperandBundle(ID: LLVMContext::OB_funclet))
622 Bundles.push_back(Elt: OperandBundleDef(*Bundle));
623
624 // Load the global symbol as a pointer to the check function.
625 Value *GuardFn;
626 if (cfguard_module_flag == 2 && !CB->hasFnAttr(Kind: "guard_nocf"))
627 GuardFn = GuardFnCFGlobal;
628 else
629 GuardFn = GuardFnGlobal;
630 LoadInst *GuardCheckLoad = B.CreateLoad(Ty: GuardFnPtrType, Ptr: GuardFn);
631
632 // Create new call instruction. The CFGuard check should always be a call,
633 // even if the original CallBase is an Invoke or CallBr instruction.
634 Function *Thunk = buildExitThunk(FT: CB->getFunctionType(), Attrs: CB->getAttributes());
635 CallInst *GuardCheck =
636 B.CreateCall(FTy: GuardFnType, Callee: GuardCheckLoad,
637 Args: {B.CreateBitCast(V: CalledOperand, DestTy: B.getPtrTy()),
638 B.CreateBitCast(V: Thunk, DestTy: B.getPtrTy())},
639 OpBundles: Bundles);
640
641 // Ensure that the first argument is passed in the correct register.
642 GuardCheck->setCallingConv(CallingConv::CFGuard_Check);
643
644 Value *GuardRetVal = B.CreateBitCast(V: GuardCheck, DestTy: CalledOperand->getType());
645 CB->setCalledOperand(GuardRetVal);
646}
647
648bool AArch64Arm64ECCallLowering::runOnModule(Module &Mod) {
649 if (!GenerateThunks)
650 return false;
651
652 M = &Mod;
653
654 // Check if this module has the cfguard flag and read its value.
655 if (auto *MD =
656 mdconst::extract_or_null<ConstantInt>(MD: M->getModuleFlag(Key: "cfguard")))
657 cfguard_module_flag = MD->getZExtValue();
658
659 PtrTy = PointerType::getUnqual(C&: M->getContext());
660 I64Ty = Type::getInt64Ty(C&: M->getContext());
661 VoidTy = Type::getVoidTy(C&: M->getContext());
662
663 GuardFnType = FunctionType::get(Result: PtrTy, Params: {PtrTy, PtrTy}, isVarArg: false);
664 GuardFnPtrType = PointerType::get(ElementType: GuardFnType, AddressSpace: 0);
665 GuardFnCFGlobal =
666 M->getOrInsertGlobal(Name: "__os_arm64x_check_icall_cfg", Ty: GuardFnPtrType);
667 GuardFnGlobal =
668 M->getOrInsertGlobal(Name: "__os_arm64x_check_icall", Ty: GuardFnPtrType);
669
670 SetVector<Function *> DirectCalledFns;
671 for (Function &F : Mod)
672 if (!F.isDeclaration() &&
673 F.getCallingConv() != CallingConv::ARM64EC_Thunk_Native &&
674 F.getCallingConv() != CallingConv::ARM64EC_Thunk_X64)
675 processFunction(F, DirectCalledFns);
676
677 struct ThunkInfo {
678 Constant *Src;
679 Constant *Dst;
680 Arm64ECThunkType Kind;
681 };
682 SmallVector<ThunkInfo> ThunkMapping;
683 for (Function &F : Mod) {
684 if (!F.isDeclaration() && (!F.hasLocalLinkage() || F.hasAddressTaken()) &&
685 F.getCallingConv() != CallingConv::ARM64EC_Thunk_Native &&
686 F.getCallingConv() != CallingConv::ARM64EC_Thunk_X64) {
687 if (!F.hasComdat())
688 F.setComdat(Mod.getOrInsertComdat(Name: F.getName()));
689 ThunkMapping.push_back(
690 Elt: {.Src: &F, .Dst: buildEntryThunk(F: &F), .Kind: Arm64ECThunkType::Entry});
691 }
692 }
693 for (Function *F : DirectCalledFns) {
694 ThunkMapping.push_back(
695 Elt: {.Src: F, .Dst: buildExitThunk(FT: F->getFunctionType(), Attrs: F->getAttributes()),
696 .Kind: Arm64ECThunkType::Exit});
697 if (!F->hasDLLImportStorageClass())
698 ThunkMapping.push_back(
699 Elt: {.Src: buildGuestExitThunk(F), .Dst: F, .Kind: Arm64ECThunkType::GuestExit});
700 }
701
702 if (!ThunkMapping.empty()) {
703 SmallVector<Constant *> ThunkMappingArrayElems;
704 for (ThunkInfo &Thunk : ThunkMapping) {
705 ThunkMappingArrayElems.push_back(Elt: ConstantStruct::getAnon(
706 V: {ConstantExpr::getBitCast(C: Thunk.Src, Ty: PtrTy),
707 ConstantExpr::getBitCast(C: Thunk.Dst, Ty: PtrTy),
708 ConstantInt::get(Context&: M->getContext(), V: APInt(32, uint8_t(Thunk.Kind)))}));
709 }
710 Constant *ThunkMappingArray = ConstantArray::get(
711 T: llvm::ArrayType::get(ElementType: ThunkMappingArrayElems[0]->getType(),
712 NumElements: ThunkMappingArrayElems.size()),
713 V: ThunkMappingArrayElems);
714 new GlobalVariable(Mod, ThunkMappingArray->getType(), /*isConstant*/ false,
715 GlobalValue::ExternalLinkage, ThunkMappingArray,
716 "llvm.arm64ec.symbolmap");
717 }
718
719 return true;
720}
721
722bool AArch64Arm64ECCallLowering::processFunction(
723 Function &F, SetVector<Function *> &DirectCalledFns) {
724 SmallVector<CallBase *, 8> IndirectCalls;
725
726 // For ARM64EC targets, a function definition's name is mangled differently
727 // from the normal symbol. We currently have no representation of this sort
728 // of symbol in IR, so we change the name to the mangled name, then store
729 // the unmangled name as metadata. Later passes that need the unmangled
730 // name (emitting the definition) can grab it from the metadata.
731 //
732 // FIXME: Handle functions with weak linkage?
733 if (!F.hasLocalLinkage() || F.hasAddressTaken()) {
734 if (std::optional<std::string> MangledName =
735 getArm64ECMangledFunctionName(Name: F.getName().str())) {
736 F.setMetadata(Kind: "arm64ec_unmangled_name",
737 Node: MDNode::get(Context&: M->getContext(),
738 MDs: MDString::get(Context&: M->getContext(), Str: F.getName())));
739 if (F.hasComdat() && F.getComdat()->getName() == F.getName()) {
740 Comdat *MangledComdat = M->getOrInsertComdat(Name: MangledName.value());
741 SmallVector<GlobalObject *> ComdatUsers =
742 to_vector(Range: F.getComdat()->getUsers());
743 for (GlobalObject *User : ComdatUsers)
744 User->setComdat(MangledComdat);
745 }
746 F.setName(MangledName.value());
747 }
748 }
749
750 // Iterate over the instructions to find all indirect call/invoke/callbr
751 // instructions. Make a separate list of pointers to indirect
752 // call/invoke/callbr instructions because the original instructions will be
753 // deleted as the checks are added.
754 for (BasicBlock &BB : F) {
755 for (Instruction &I : BB) {
756 auto *CB = dyn_cast<CallBase>(Val: &I);
757 if (!CB || CB->getCallingConv() == CallingConv::ARM64EC_Thunk_X64 ||
758 CB->isInlineAsm())
759 continue;
760
761 // We need to instrument any call that isn't directly calling an
762 // ARM64 function.
763 //
764 // FIXME: getCalledFunction() fails if there's a bitcast (e.g.
765 // unprototyped functions in C)
766 if (Function *F = CB->getCalledFunction()) {
767 if (!LowerDirectToIndirect || F->hasLocalLinkage() ||
768 F->isIntrinsic() || !F->isDeclaration())
769 continue;
770
771 DirectCalledFns.insert(X: F);
772 continue;
773 }
774
775 IndirectCalls.push_back(Elt: CB);
776 ++Arm64ECCallsLowered;
777 }
778 }
779
780 if (IndirectCalls.empty())
781 return false;
782
783 for (CallBase *CB : IndirectCalls)
784 lowerCall(CB);
785
786 return true;
787}
788
789char AArch64Arm64ECCallLowering::ID = 0;
790INITIALIZE_PASS(AArch64Arm64ECCallLowering, "Arm64ECCallLowering",
791 "AArch64Arm64ECCallLowering", false, false)
792
793ModulePass *llvm::createAArch64Arm64ECCallLoweringPass() {
794 return new AArch64Arm64ECCallLowering;
795}
796

// Source: llvm/lib/Target/AArch64/AArch64Arm64ECCallLowering.cpp