1//===-- SPIRVPrepareFunctions.cpp - modify function signatures --*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This pass modifies function signatures containing aggregate arguments
10// and/or return value before IRTranslator. Information about the original
11// signatures is stored in metadata. It is used during call lowering to
12// restore correct SPIR-V types of function arguments and return values.
13// This pass also substitutes some llvm intrinsic calls with calls to newly
14// generated functions (as the Khronos LLVM/SPIR-V Translator does).
15//
16// NOTE: this pass is a module-level one due to the necessity to modify
17// GVs/functions.
18//
19//===----------------------------------------------------------------------===//
20
21#include "SPIRV.h"
22#include "SPIRVSubtarget.h"
23#include "SPIRVTargetMachine.h"
24#include "SPIRVUtils.h"
25#include "llvm/CodeGen/IntrinsicLowering.h"
26#include "llvm/IR/IRBuilder.h"
27#include "llvm/IR/IntrinsicInst.h"
28#include "llvm/IR/Intrinsics.h"
29#include "llvm/IR/IntrinsicsSPIRV.h"
30#include "llvm/Transforms/Utils/Cloning.h"
31#include "llvm/Transforms/Utils/LowerMemIntrinsics.h"
32
33using namespace llvm;
34
namespace llvm {
// Forward declaration of the pass initializer; the definition is generated by
// the INITIALIZE_PASS macro below.
void initializeSPIRVPrepareFunctionsPass(PassRegistry &);
}
38
39namespace {
40
// Module pass that (1) substitutes selected llvm.* intrinsic calls with calls
// to generated emulation functions or SPIR-V intrinsics and (2) rewrites
// function signatures that contain aggregate types, recording the original
// types in metadata for later restoration during call lowering.
class SPIRVPrepareFunctions : public ModulePass {
  const SPIRVTargetMachine &TM;
  // Replaces supported llvm.* intrinsic calls in F. Returns true if F changed.
  bool substituteIntrinsicCalls(Function *F);
  // Returns F unchanged, or a clone of F with aggregate argument/return types
  // replaced by i32 (original types noted in "spv.cloned_funcs" metadata).
  Function *removeAggregateTypesFromSignature(Function *F);

public:
  static char ID;
  SPIRVPrepareFunctions(const SPIRVTargetMachine &TM) : ModulePass(ID), TM(TM) {
    initializeSPIRVPrepareFunctionsPass(*PassRegistry::getPassRegistry());
  }

  bool runOnModule(Module &M) override;

  StringRef getPassName() const override { return "SPIRV prepare functions"; }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    ModulePass::getAnalysisUsage(AU);
  }
};
60
61} // namespace
62
// Pass identification: the address of ID uniquely identifies this pass.
char SPIRVPrepareFunctions::ID = 0;

INITIALIZE_PASS(SPIRVPrepareFunctions, "prepare-functions",
                "SPIRV prepare functions", false, false)
67
68std::string lowerLLVMIntrinsicName(IntrinsicInst *II) {
69 Function *IntrinsicFunc = II->getCalledFunction();
70 assert(IntrinsicFunc && "Missing function");
71 std::string FuncName = IntrinsicFunc->getName().str();
72 std::replace(first: FuncName.begin(), last: FuncName.end(), old_value: '.', new_value: '_');
73 FuncName = "spirv." + FuncName;
74 return FuncName;
75}
76
77static Function *getOrCreateFunction(Module *M, Type *RetTy,
78 ArrayRef<Type *> ArgTypes,
79 StringRef Name) {
80 FunctionType *FT = FunctionType::get(Result: RetTy, Params: ArgTypes, isVarArg: false);
81 Function *F = M->getFunction(Name);
82 if (F && F->getFunctionType() == FT)
83 return F;
84 Function *NewF = Function::Create(Ty: FT, Linkage: GlobalValue::ExternalLinkage, N: Name, M);
85 if (F)
86 NewF->setDSOLocal(F->isDSOLocal());
87 NewF->setCallingConv(CallingConv::SPIR_FUNC);
88 return NewF;
89}
90
91static bool lowerIntrinsicToFunction(IntrinsicInst *Intrinsic) {
92 // For @llvm.memset.* intrinsic cases with constant value and length arguments
93 // are emulated via "storing" a constant array to the destination. For other
94 // cases we wrap the intrinsic in @spirv.llvm_memset_* function and expand the
95 // intrinsic to a loop via expandMemSetAsLoop().
96 if (auto *MSI = dyn_cast<MemSetInst>(Val: Intrinsic))
97 if (isa<Constant>(Val: MSI->getValue()) && isa<ConstantInt>(Val: MSI->getLength()))
98 return false; // It is handled later using OpCopyMemorySized.
99
100 Module *M = Intrinsic->getModule();
101 std::string FuncName = lowerLLVMIntrinsicName(II: Intrinsic);
102 if (Intrinsic->isVolatile())
103 FuncName += ".volatile";
104 // Redirect @llvm.intrinsic.* call to @spirv.llvm_intrinsic_*
105 Function *F = M->getFunction(Name: FuncName);
106 if (F) {
107 Intrinsic->setCalledFunction(F);
108 return true;
109 }
110 // TODO copy arguments attributes: nocapture writeonly.
111 FunctionCallee FC =
112 M->getOrInsertFunction(Name: FuncName, T: Intrinsic->getFunctionType());
113 auto IntrinsicID = Intrinsic->getIntrinsicID();
114 Intrinsic->setCalledFunction(FC);
115
116 F = dyn_cast<Function>(Val: FC.getCallee());
117 assert(F && "Callee must be a function");
118
119 switch (IntrinsicID) {
120 case Intrinsic::memset: {
121 auto *MSI = static_cast<MemSetInst *>(Intrinsic);
122 Argument *Dest = F->getArg(i: 0);
123 Argument *Val = F->getArg(i: 1);
124 Argument *Len = F->getArg(i: 2);
125 Argument *IsVolatile = F->getArg(i: 3);
126 Dest->setName("dest");
127 Val->setName("val");
128 Len->setName("len");
129 IsVolatile->setName("isvolatile");
130 BasicBlock *EntryBB = BasicBlock::Create(Context&: M->getContext(), Name: "entry", Parent: F);
131 IRBuilder<> IRB(EntryBB);
132 auto *MemSet = IRB.CreateMemSet(Ptr: Dest, Val, Size: Len, Align: MSI->getDestAlign(),
133 isVolatile: MSI->isVolatile());
134 IRB.CreateRetVoid();
135 expandMemSetAsLoop(MemSet: cast<MemSetInst>(Val: MemSet));
136 MemSet->eraseFromParent();
137 break;
138 }
139 case Intrinsic::bswap: {
140 BasicBlock *EntryBB = BasicBlock::Create(Context&: M->getContext(), Name: "entry", Parent: F);
141 IRBuilder<> IRB(EntryBB);
142 auto *BSwap = IRB.CreateIntrinsic(Intrinsic::bswap, Intrinsic->getType(),
143 F->getArg(0));
144 IRB.CreateRet(V: BSwap);
145 IntrinsicLowering IL(M->getDataLayout());
146 IL.LowerIntrinsicCall(CI: BSwap);
147 break;
148 }
149 default:
150 break;
151 }
152 return true;
153}
154
155static void lowerFunnelShifts(IntrinsicInst *FSHIntrinsic) {
156 // Get a separate function - otherwise, we'd have to rework the CFG of the
157 // current one. Then simply replace the intrinsic uses with a call to the new
158 // function.
159 // Generate LLVM IR for i* @spirv.llvm_fsh?_i* (i* %a, i* %b, i* %c)
160 Module *M = FSHIntrinsic->getModule();
161 FunctionType *FSHFuncTy = FSHIntrinsic->getFunctionType();
162 Type *FSHRetTy = FSHFuncTy->getReturnType();
163 const std::string FuncName = lowerLLVMIntrinsicName(II: FSHIntrinsic);
164 Function *FSHFunc =
165 getOrCreateFunction(M, RetTy: FSHRetTy, ArgTypes: FSHFuncTy->params(), Name: FuncName);
166
167 if (!FSHFunc->empty()) {
168 FSHIntrinsic->setCalledFunction(FSHFunc);
169 return;
170 }
171 BasicBlock *RotateBB = BasicBlock::Create(Context&: M->getContext(), Name: "rotate", Parent: FSHFunc);
172 IRBuilder<> IRB(RotateBB);
173 Type *Ty = FSHFunc->getReturnType();
174 // Build the actual funnel shift rotate logic.
175 // In the comments, "int" is used interchangeably with "vector of int
176 // elements".
177 FixedVectorType *VectorTy = dyn_cast<FixedVectorType>(Val: Ty);
178 Type *IntTy = VectorTy ? VectorTy->getElementType() : Ty;
179 unsigned BitWidth = IntTy->getIntegerBitWidth();
180 ConstantInt *BitWidthConstant = IRB.getInt(AI: {BitWidth, BitWidth});
181 Value *BitWidthForInsts =
182 VectorTy
183 ? IRB.CreateVectorSplat(NumElts: VectorTy->getNumElements(), V: BitWidthConstant)
184 : BitWidthConstant;
185 Value *RotateModVal =
186 IRB.CreateURem(/*Rotate*/ LHS: FSHFunc->getArg(i: 2), RHS: BitWidthForInsts);
187 Value *FirstShift = nullptr, *SecShift = nullptr;
188 if (FSHIntrinsic->getIntrinsicID() == Intrinsic::fshr) {
189 // Shift the less significant number right, the "rotate" number of bits
190 // will be 0-filled on the left as a result of this regular shift.
191 FirstShift = IRB.CreateLShr(LHS: FSHFunc->getArg(i: 1), RHS: RotateModVal);
192 } else {
193 // Shift the more significant number left, the "rotate" number of bits
194 // will be 0-filled on the right as a result of this regular shift.
195 FirstShift = IRB.CreateShl(LHS: FSHFunc->getArg(i: 0), RHS: RotateModVal);
196 }
197 // We want the "rotate" number of the more significant int's LSBs (MSBs) to
198 // occupy the leftmost (rightmost) "0 space" left by the previous operation.
199 // Therefore, subtract the "rotate" number from the integer bitsize...
200 Value *SubRotateVal = IRB.CreateSub(LHS: BitWidthForInsts, RHS: RotateModVal);
201 if (FSHIntrinsic->getIntrinsicID() == Intrinsic::fshr) {
202 // ...and left-shift the more significant int by this number, zero-filling
203 // the LSBs.
204 SecShift = IRB.CreateShl(LHS: FSHFunc->getArg(i: 0), RHS: SubRotateVal);
205 } else {
206 // ...and right-shift the less significant int by this number, zero-filling
207 // the MSBs.
208 SecShift = IRB.CreateLShr(LHS: FSHFunc->getArg(i: 1), RHS: SubRotateVal);
209 }
210 // A simple binary addition of the shifted ints yields the final result.
211 IRB.CreateRet(V: IRB.CreateOr(LHS: FirstShift, RHS: SecShift));
212
213 FSHIntrinsic->setCalledFunction(FSHFunc);
214}
215
216static void buildUMulWithOverflowFunc(Function *UMulFunc) {
217 // The function body is already created.
218 if (!UMulFunc->empty())
219 return;
220
221 BasicBlock *EntryBB = BasicBlock::Create(Context&: UMulFunc->getParent()->getContext(),
222 Name: "entry", Parent: UMulFunc);
223 IRBuilder<> IRB(EntryBB);
224 // Build the actual unsigned multiplication logic with the overflow
225 // indication. Do unsigned multiplication Mul = A * B. Then check
226 // if unsigned division Div = Mul / A is not equal to B. If so,
227 // then overflow has happened.
228 Value *Mul = IRB.CreateNUWMul(LHS: UMulFunc->getArg(i: 0), RHS: UMulFunc->getArg(i: 1));
229 Value *Div = IRB.CreateUDiv(LHS: Mul, RHS: UMulFunc->getArg(i: 0));
230 Value *Overflow = IRB.CreateICmpNE(LHS: UMulFunc->getArg(i: 0), RHS: Div);
231
232 // umul.with.overflow intrinsic return a structure, where the first element
233 // is the multiplication result, and the second is an overflow bit.
234 Type *StructTy = UMulFunc->getReturnType();
235 Value *Agg = IRB.CreateInsertValue(Agg: PoisonValue::get(T: StructTy), Val: Mul, Idxs: {0});
236 Value *Res = IRB.CreateInsertValue(Agg, Val: Overflow, Idxs: {1});
237 IRB.CreateRet(V: Res);
238}
239
240static void lowerExpectAssume(IntrinsicInst *II) {
241 // If we cannot use the SPV_KHR_expect_assume extension, then we need to
242 // ignore the intrinsic and move on. It should be removed later on by LLVM.
243 // Otherwise we should lower the intrinsic to the corresponding SPIR-V
244 // instruction.
245 // For @llvm.assume we have OpAssumeTrueKHR.
246 // For @llvm.expect we have OpExpectKHR.
247 //
248 // We need to lower this into a builtin and then the builtin into a SPIR-V
249 // instruction.
250 if (II->getIntrinsicID() == Intrinsic::assume) {
251 Function *F = Intrinsic::getDeclaration(
252 II->getModule(), Intrinsic::SPVIntrinsics::spv_assume);
253 II->setCalledFunction(F);
254 } else if (II->getIntrinsicID() == Intrinsic::expect) {
255 Function *F = Intrinsic::getDeclaration(
256 II->getModule(), Intrinsic::SPVIntrinsics::spv_expect,
257 {II->getOperand(0)->getType()});
258 II->setCalledFunction(F);
259 } else {
260 llvm_unreachable("Unknown intrinsic");
261 }
262
263 return;
264}
265
266static bool toSpvOverloadedIntrinsic(IntrinsicInst *II, Intrinsic::ID NewID,
267 ArrayRef<unsigned> OpNos) {
268 Function *F = nullptr;
269 if (OpNos.empty()) {
270 F = Intrinsic::getDeclaration(M: II->getModule(), id: NewID);
271 } else {
272 SmallVector<Type *, 4> Tys;
273 for (unsigned OpNo : OpNos)
274 Tys.push_back(Elt: II->getOperand(i_nocapture: OpNo)->getType());
275 F = Intrinsic::getDeclaration(M: II->getModule(), id: NewID, Tys);
276 }
277 II->setCalledFunction(F);
278 return true;
279}
280
281static void lowerUMulWithOverflow(IntrinsicInst *UMulIntrinsic) {
282 // Get a separate function - otherwise, we'd have to rework the CFG of the
283 // current one. Then simply replace the intrinsic uses with a call to the new
284 // function.
285 Module *M = UMulIntrinsic->getModule();
286 FunctionType *UMulFuncTy = UMulIntrinsic->getFunctionType();
287 Type *FSHLRetTy = UMulFuncTy->getReturnType();
288 const std::string FuncName = lowerLLVMIntrinsicName(II: UMulIntrinsic);
289 Function *UMulFunc =
290 getOrCreateFunction(M, RetTy: FSHLRetTy, ArgTypes: UMulFuncTy->params(), Name: FuncName);
291 buildUMulWithOverflowFunc(UMulFunc);
292 UMulIntrinsic->setCalledFunction(UMulFunc);
293}
294
295// Substitutes calls to LLVM intrinsics with either calls to SPIR-V intrinsics
296// or calls to proper generated functions. Returns True if F was modified.
297bool SPIRVPrepareFunctions::substituteIntrinsicCalls(Function *F) {
298 bool Changed = false;
299 for (BasicBlock &BB : *F) {
300 for (Instruction &I : BB) {
301 auto Call = dyn_cast<CallInst>(Val: &I);
302 if (!Call)
303 continue;
304 Function *CF = Call->getCalledFunction();
305 if (!CF || !CF->isIntrinsic())
306 continue;
307 auto *II = cast<IntrinsicInst>(Val: Call);
308 switch (II->getIntrinsicID()) {
309 case Intrinsic::memset:
310 case Intrinsic::bswap:
311 Changed |= lowerIntrinsicToFunction(Intrinsic: II);
312 break;
313 case Intrinsic::fshl:
314 case Intrinsic::fshr:
315 lowerFunnelShifts(FSHIntrinsic: II);
316 Changed = true;
317 break;
318 case Intrinsic::umul_with_overflow:
319 lowerUMulWithOverflow(UMulIntrinsic: II);
320 Changed = true;
321 break;
322 case Intrinsic::assume:
323 case Intrinsic::expect: {
324 const SPIRVSubtarget &STI = TM.getSubtarget<SPIRVSubtarget>(*F);
325 if (STI.canUseExtension(SPIRV::Extension::SPV_KHR_expect_assume))
326 lowerExpectAssume(II);
327 Changed = true;
328 } break;
329 case Intrinsic::lifetime_start:
330 Changed |= toSpvOverloadedIntrinsic(
331 II, Intrinsic::SPVIntrinsics::spv_lifetime_start, {1});
332 break;
333 case Intrinsic::lifetime_end:
334 Changed |= toSpvOverloadedIntrinsic(
335 II, Intrinsic::SPVIntrinsics::spv_lifetime_end, {1});
336 break;
337 }
338 }
339 }
340 return Changed;
341}
342
343// Returns F if aggregate argument/return types are not present or cloned F
344// function with the types replaced by i32 types. The change in types is
345// noted in 'spv.cloned_funcs' metadata for later restoration.
346Function *
347SPIRVPrepareFunctions::removeAggregateTypesFromSignature(Function *F) {
348 IRBuilder<> B(F->getContext());
349
350 bool IsRetAggr = F->getReturnType()->isAggregateType();
351 bool HasAggrArg =
352 std::any_of(first: F->arg_begin(), last: F->arg_end(), pred: [](Argument &Arg) {
353 return Arg.getType()->isAggregateType();
354 });
355 bool DoClone = IsRetAggr || HasAggrArg;
356 if (!DoClone)
357 return F;
358 SmallVector<std::pair<int, Type *>, 4> ChangedTypes;
359 Type *RetType = IsRetAggr ? B.getInt32Ty() : F->getReturnType();
360 if (IsRetAggr)
361 ChangedTypes.push_back(Elt: std::pair<int, Type *>(-1, F->getReturnType()));
362 SmallVector<Type *, 4> ArgTypes;
363 for (const auto &Arg : F->args()) {
364 if (Arg.getType()->isAggregateType()) {
365 ArgTypes.push_back(Elt: B.getInt32Ty());
366 ChangedTypes.push_back(
367 Elt: std::pair<int, Type *>(Arg.getArgNo(), Arg.getType()));
368 } else
369 ArgTypes.push_back(Elt: Arg.getType());
370 }
371 FunctionType *NewFTy =
372 FunctionType::get(Result: RetType, Params: ArgTypes, isVarArg: F->getFunctionType()->isVarArg());
373 Function *NewF =
374 Function::Create(Ty: NewFTy, Linkage: F->getLinkage(), N: F->getName(), M&: *F->getParent());
375
376 ValueToValueMapTy VMap;
377 auto NewFArgIt = NewF->arg_begin();
378 for (auto &Arg : F->args()) {
379 StringRef ArgName = Arg.getName();
380 NewFArgIt->setName(ArgName);
381 VMap[&Arg] = &(*NewFArgIt++);
382 }
383 SmallVector<ReturnInst *, 8> Returns;
384
385 CloneFunctionInto(NewFunc: NewF, OldFunc: F, VMap, Changes: CloneFunctionChangeType::LocalChangesOnly,
386 Returns);
387 NewF->takeName(V: F);
388
389 NamedMDNode *FuncMD =
390 F->getParent()->getOrInsertNamedMetadata(Name: "spv.cloned_funcs");
391 SmallVector<Metadata *, 2> MDArgs;
392 MDArgs.push_back(Elt: MDString::get(Context&: B.getContext(), Str: NewF->getName()));
393 for (auto &ChangedTyP : ChangedTypes)
394 MDArgs.push_back(Elt: MDNode::get(
395 Context&: B.getContext(),
396 MDs: {ConstantAsMetadata::get(C: B.getInt32(C: ChangedTyP.first)),
397 ValueAsMetadata::get(V: Constant::getNullValue(Ty: ChangedTyP.second))}));
398 MDNode *ThisFuncMD = MDNode::get(Context&: B.getContext(), MDs: MDArgs);
399 FuncMD->addOperand(M: ThisFuncMD);
400
401 for (auto *U : make_early_inc_range(Range: F->users())) {
402 if (auto *CI = dyn_cast<CallInst>(Val: U))
403 CI->mutateFunctionType(FTy: NewF->getFunctionType());
404 U->replaceUsesOfWith(From: F, To: NewF);
405 }
406 return NewF;
407}
408
409bool SPIRVPrepareFunctions::runOnModule(Module &M) {
410 bool Changed = false;
411 for (Function &F : M)
412 Changed |= substituteIntrinsicCalls(F: &F);
413
414 std::vector<Function *> FuncsWorklist;
415 for (auto &F : M)
416 FuncsWorklist.push_back(x: &F);
417
418 for (auto *F : FuncsWorklist) {
419 Function *NewF = removeAggregateTypesFromSignature(F);
420
421 if (NewF != F) {
422 F->eraseFromParent();
423 Changed = true;
424 }
425 }
426 return Changed;
427}
428
// Factory used by the SPIR-V target's pass pipeline setup.
ModulePass *
llvm::createSPIRVPrepareFunctionsPass(const SPIRVTargetMachine &TM) {
  return new SPIRVPrepareFunctions(TM);
}
433

// Source: llvm/lib/Target/SPIRV/SPIRVPrepareFunctions.cpp