//===-- SPIRVPrepareFunctions.cpp - modify function signatures --*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass modifies function signatures containing aggregate arguments
// and/or return value before IRTranslator. Information about the original
// signatures is stored in metadata. It is used during call lowering to
// restore correct SPIR-V types of function arguments and return values.
// This pass also substitutes some llvm intrinsic calls with calls to newly
// generated functions (as the Khronos LLVM/SPIR-V Translator does).
//
// NOTE: this pass is a module-level one due to the necessity to modify
// GVs/functions.
//
//===----------------------------------------------------------------------===//
20 | |
#include "SPIRV.h"
#include "SPIRVSubtarget.h"
#include "SPIRVTargetMachine.h"
#include "SPIRVUtils.h"
#include "llvm/CodeGen/IntrinsicLowering.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/IntrinsicsSPIRV.h"
#include "llvm/Transforms/Utils/Cloning.h"
#include "llvm/Transforms/Utils/LowerMemIntrinsics.h"
#include <algorithm>
#include <string>
32 | |
33 | using namespace llvm; |
34 | |
35 | namespace llvm { |
36 | void initializeSPIRVPrepareFunctionsPass(PassRegistry &); |
37 | } |
38 | |
39 | namespace { |
40 | |
41 | class SPIRVPrepareFunctions : public ModulePass { |
42 | const SPIRVTargetMachine &TM; |
43 | bool substituteIntrinsicCalls(Function *F); |
44 | Function *removeAggregateTypesFromSignature(Function *F); |
45 | |
46 | public: |
47 | static char ID; |
48 | SPIRVPrepareFunctions(const SPIRVTargetMachine &TM) : ModulePass(ID), TM(TM) { |
49 | initializeSPIRVPrepareFunctionsPass(*PassRegistry::getPassRegistry()); |
50 | } |
51 | |
52 | bool runOnModule(Module &M) override; |
53 | |
54 | StringRef getPassName() const override { return "SPIRV prepare functions" ; } |
55 | |
56 | void getAnalysisUsage(AnalysisUsage &AU) const override { |
57 | ModulePass::getAnalysisUsage(AU); |
58 | } |
59 | }; |
60 | |
61 | } // namespace |
62 | |
63 | char SPIRVPrepareFunctions::ID = 0; |
64 | |
65 | INITIALIZE_PASS(SPIRVPrepareFunctions, "prepare-functions" , |
66 | "SPIRV prepare functions" , false, false) |
67 | |
68 | std::string lowerLLVMIntrinsicName(IntrinsicInst *II) { |
69 | Function *IntrinsicFunc = II->getCalledFunction(); |
70 | assert(IntrinsicFunc && "Missing function" ); |
71 | std::string FuncName = IntrinsicFunc->getName().str(); |
72 | std::replace(first: FuncName.begin(), last: FuncName.end(), old_value: '.', new_value: '_'); |
73 | FuncName = "spirv." + FuncName; |
74 | return FuncName; |
75 | } |
76 | |
77 | static Function *getOrCreateFunction(Module *M, Type *RetTy, |
78 | ArrayRef<Type *> ArgTypes, |
79 | StringRef Name) { |
80 | FunctionType *FT = FunctionType::get(Result: RetTy, Params: ArgTypes, isVarArg: false); |
81 | Function *F = M->getFunction(Name); |
82 | if (F && F->getFunctionType() == FT) |
83 | return F; |
84 | Function *NewF = Function::Create(Ty: FT, Linkage: GlobalValue::ExternalLinkage, N: Name, M); |
85 | if (F) |
86 | NewF->setDSOLocal(F->isDSOLocal()); |
87 | NewF->setCallingConv(CallingConv::SPIR_FUNC); |
88 | return NewF; |
89 | } |
90 | |
91 | static bool lowerIntrinsicToFunction(IntrinsicInst *Intrinsic) { |
92 | // For @llvm.memset.* intrinsic cases with constant value and length arguments |
93 | // are emulated via "storing" a constant array to the destination. For other |
94 | // cases we wrap the intrinsic in @spirv.llvm_memset_* function and expand the |
95 | // intrinsic to a loop via expandMemSetAsLoop(). |
96 | if (auto *MSI = dyn_cast<MemSetInst>(Val: Intrinsic)) |
97 | if (isa<Constant>(Val: MSI->getValue()) && isa<ConstantInt>(Val: MSI->getLength())) |
98 | return false; // It is handled later using OpCopyMemorySized. |
99 | |
100 | Module *M = Intrinsic->getModule(); |
101 | std::string FuncName = lowerLLVMIntrinsicName(II: Intrinsic); |
102 | if (Intrinsic->isVolatile()) |
103 | FuncName += ".volatile" ; |
104 | // Redirect @llvm.intrinsic.* call to @spirv.llvm_intrinsic_* |
105 | Function *F = M->getFunction(Name: FuncName); |
106 | if (F) { |
107 | Intrinsic->setCalledFunction(F); |
108 | return true; |
109 | } |
110 | // TODO copy arguments attributes: nocapture writeonly. |
111 | FunctionCallee FC = |
112 | M->getOrInsertFunction(Name: FuncName, T: Intrinsic->getFunctionType()); |
113 | auto IntrinsicID = Intrinsic->getIntrinsicID(); |
114 | Intrinsic->setCalledFunction(FC); |
115 | |
116 | F = dyn_cast<Function>(Val: FC.getCallee()); |
117 | assert(F && "Callee must be a function" ); |
118 | |
119 | switch (IntrinsicID) { |
120 | case Intrinsic::memset: { |
121 | auto *MSI = static_cast<MemSetInst *>(Intrinsic); |
122 | Argument *Dest = F->getArg(i: 0); |
123 | Argument *Val = F->getArg(i: 1); |
124 | Argument *Len = F->getArg(i: 2); |
125 | Argument *IsVolatile = F->getArg(i: 3); |
126 | Dest->setName("dest" ); |
127 | Val->setName("val" ); |
128 | Len->setName("len" ); |
129 | IsVolatile->setName("isvolatile" ); |
130 | BasicBlock *EntryBB = BasicBlock::Create(Context&: M->getContext(), Name: "entry" , Parent: F); |
131 | IRBuilder<> IRB(EntryBB); |
132 | auto *MemSet = IRB.CreateMemSet(Ptr: Dest, Val, Size: Len, Align: MSI->getDestAlign(), |
133 | isVolatile: MSI->isVolatile()); |
134 | IRB.CreateRetVoid(); |
135 | expandMemSetAsLoop(MemSet: cast<MemSetInst>(Val: MemSet)); |
136 | MemSet->eraseFromParent(); |
137 | break; |
138 | } |
139 | case Intrinsic::bswap: { |
140 | BasicBlock *EntryBB = BasicBlock::Create(Context&: M->getContext(), Name: "entry" , Parent: F); |
141 | IRBuilder<> IRB(EntryBB); |
142 | auto *BSwap = IRB.CreateIntrinsic(Intrinsic::bswap, Intrinsic->getType(), |
143 | F->getArg(0)); |
144 | IRB.CreateRet(V: BSwap); |
145 | IntrinsicLowering IL(M->getDataLayout()); |
146 | IL.LowerIntrinsicCall(CI: BSwap); |
147 | break; |
148 | } |
149 | default: |
150 | break; |
151 | } |
152 | return true; |
153 | } |
154 | |
155 | static void lowerFunnelShifts(IntrinsicInst *FSHIntrinsic) { |
156 | // Get a separate function - otherwise, we'd have to rework the CFG of the |
157 | // current one. Then simply replace the intrinsic uses with a call to the new |
158 | // function. |
159 | // Generate LLVM IR for i* @spirv.llvm_fsh?_i* (i* %a, i* %b, i* %c) |
160 | Module *M = FSHIntrinsic->getModule(); |
161 | FunctionType *FSHFuncTy = FSHIntrinsic->getFunctionType(); |
162 | Type *FSHRetTy = FSHFuncTy->getReturnType(); |
163 | const std::string FuncName = lowerLLVMIntrinsicName(II: FSHIntrinsic); |
164 | Function *FSHFunc = |
165 | getOrCreateFunction(M, RetTy: FSHRetTy, ArgTypes: FSHFuncTy->params(), Name: FuncName); |
166 | |
167 | if (!FSHFunc->empty()) { |
168 | FSHIntrinsic->setCalledFunction(FSHFunc); |
169 | return; |
170 | } |
171 | BasicBlock *RotateBB = BasicBlock::Create(Context&: M->getContext(), Name: "rotate" , Parent: FSHFunc); |
172 | IRBuilder<> IRB(RotateBB); |
173 | Type *Ty = FSHFunc->getReturnType(); |
174 | // Build the actual funnel shift rotate logic. |
175 | // In the comments, "int" is used interchangeably with "vector of int |
176 | // elements". |
177 | FixedVectorType *VectorTy = dyn_cast<FixedVectorType>(Val: Ty); |
178 | Type *IntTy = VectorTy ? VectorTy->getElementType() : Ty; |
179 | unsigned BitWidth = IntTy->getIntegerBitWidth(); |
180 | ConstantInt *BitWidthConstant = IRB.getInt(AI: {BitWidth, BitWidth}); |
181 | Value *BitWidthForInsts = |
182 | VectorTy |
183 | ? IRB.CreateVectorSplat(NumElts: VectorTy->getNumElements(), V: BitWidthConstant) |
184 | : BitWidthConstant; |
185 | Value *RotateModVal = |
186 | IRB.CreateURem(/*Rotate*/ LHS: FSHFunc->getArg(i: 2), RHS: BitWidthForInsts); |
187 | Value *FirstShift = nullptr, *SecShift = nullptr; |
188 | if (FSHIntrinsic->getIntrinsicID() == Intrinsic::fshr) { |
189 | // Shift the less significant number right, the "rotate" number of bits |
190 | // will be 0-filled on the left as a result of this regular shift. |
191 | FirstShift = IRB.CreateLShr(LHS: FSHFunc->getArg(i: 1), RHS: RotateModVal); |
192 | } else { |
193 | // Shift the more significant number left, the "rotate" number of bits |
194 | // will be 0-filled on the right as a result of this regular shift. |
195 | FirstShift = IRB.CreateShl(LHS: FSHFunc->getArg(i: 0), RHS: RotateModVal); |
196 | } |
197 | // We want the "rotate" number of the more significant int's LSBs (MSBs) to |
198 | // occupy the leftmost (rightmost) "0 space" left by the previous operation. |
199 | // Therefore, subtract the "rotate" number from the integer bitsize... |
200 | Value *SubRotateVal = IRB.CreateSub(LHS: BitWidthForInsts, RHS: RotateModVal); |
201 | if (FSHIntrinsic->getIntrinsicID() == Intrinsic::fshr) { |
202 | // ...and left-shift the more significant int by this number, zero-filling |
203 | // the LSBs. |
204 | SecShift = IRB.CreateShl(LHS: FSHFunc->getArg(i: 0), RHS: SubRotateVal); |
205 | } else { |
206 | // ...and right-shift the less significant int by this number, zero-filling |
207 | // the MSBs. |
208 | SecShift = IRB.CreateLShr(LHS: FSHFunc->getArg(i: 1), RHS: SubRotateVal); |
209 | } |
210 | // A simple binary addition of the shifted ints yields the final result. |
211 | IRB.CreateRet(V: IRB.CreateOr(LHS: FirstShift, RHS: SecShift)); |
212 | |
213 | FSHIntrinsic->setCalledFunction(FSHFunc); |
214 | } |
215 | |
216 | static void buildUMulWithOverflowFunc(Function *UMulFunc) { |
217 | // The function body is already created. |
218 | if (!UMulFunc->empty()) |
219 | return; |
220 | |
221 | BasicBlock *EntryBB = BasicBlock::Create(Context&: UMulFunc->getParent()->getContext(), |
222 | Name: "entry" , Parent: UMulFunc); |
223 | IRBuilder<> IRB(EntryBB); |
224 | // Build the actual unsigned multiplication logic with the overflow |
225 | // indication. Do unsigned multiplication Mul = A * B. Then check |
226 | // if unsigned division Div = Mul / A is not equal to B. If so, |
227 | // then overflow has happened. |
228 | Value *Mul = IRB.CreateNUWMul(LHS: UMulFunc->getArg(i: 0), RHS: UMulFunc->getArg(i: 1)); |
229 | Value *Div = IRB.CreateUDiv(LHS: Mul, RHS: UMulFunc->getArg(i: 0)); |
230 | Value *Overflow = IRB.CreateICmpNE(LHS: UMulFunc->getArg(i: 0), RHS: Div); |
231 | |
232 | // umul.with.overflow intrinsic return a structure, where the first element |
233 | // is the multiplication result, and the second is an overflow bit. |
234 | Type *StructTy = UMulFunc->getReturnType(); |
235 | Value *Agg = IRB.CreateInsertValue(Agg: PoisonValue::get(T: StructTy), Val: Mul, Idxs: {0}); |
236 | Value *Res = IRB.CreateInsertValue(Agg, Val: Overflow, Idxs: {1}); |
237 | IRB.CreateRet(V: Res); |
238 | } |
239 | |
240 | static void lowerExpectAssume(IntrinsicInst *II) { |
241 | // If we cannot use the SPV_KHR_expect_assume extension, then we need to |
242 | // ignore the intrinsic and move on. It should be removed later on by LLVM. |
243 | // Otherwise we should lower the intrinsic to the corresponding SPIR-V |
244 | // instruction. |
245 | // For @llvm.assume we have OpAssumeTrueKHR. |
246 | // For @llvm.expect we have OpExpectKHR. |
247 | // |
248 | // We need to lower this into a builtin and then the builtin into a SPIR-V |
249 | // instruction. |
250 | if (II->getIntrinsicID() == Intrinsic::assume) { |
251 | Function *F = Intrinsic::getDeclaration( |
252 | II->getModule(), Intrinsic::SPVIntrinsics::spv_assume); |
253 | II->setCalledFunction(F); |
254 | } else if (II->getIntrinsicID() == Intrinsic::expect) { |
255 | Function *F = Intrinsic::getDeclaration( |
256 | II->getModule(), Intrinsic::SPVIntrinsics::spv_expect, |
257 | {II->getOperand(0)->getType()}); |
258 | II->setCalledFunction(F); |
259 | } else { |
260 | llvm_unreachable("Unknown intrinsic" ); |
261 | } |
262 | |
263 | return; |
264 | } |
265 | |
266 | static bool toSpvOverloadedIntrinsic(IntrinsicInst *II, Intrinsic::ID NewID, |
267 | ArrayRef<unsigned> OpNos) { |
268 | Function *F = nullptr; |
269 | if (OpNos.empty()) { |
270 | F = Intrinsic::getDeclaration(M: II->getModule(), id: NewID); |
271 | } else { |
272 | SmallVector<Type *, 4> Tys; |
273 | for (unsigned OpNo : OpNos) |
274 | Tys.push_back(Elt: II->getOperand(i_nocapture: OpNo)->getType()); |
275 | F = Intrinsic::getDeclaration(M: II->getModule(), id: NewID, Tys); |
276 | } |
277 | II->setCalledFunction(F); |
278 | return true; |
279 | } |
280 | |
281 | static void lowerUMulWithOverflow(IntrinsicInst *UMulIntrinsic) { |
282 | // Get a separate function - otherwise, we'd have to rework the CFG of the |
283 | // current one. Then simply replace the intrinsic uses with a call to the new |
284 | // function. |
285 | Module *M = UMulIntrinsic->getModule(); |
286 | FunctionType *UMulFuncTy = UMulIntrinsic->getFunctionType(); |
287 | Type *FSHLRetTy = UMulFuncTy->getReturnType(); |
288 | const std::string FuncName = lowerLLVMIntrinsicName(II: UMulIntrinsic); |
289 | Function *UMulFunc = |
290 | getOrCreateFunction(M, RetTy: FSHLRetTy, ArgTypes: UMulFuncTy->params(), Name: FuncName); |
291 | buildUMulWithOverflowFunc(UMulFunc); |
292 | UMulIntrinsic->setCalledFunction(UMulFunc); |
293 | } |
294 | |
295 | // Substitutes calls to LLVM intrinsics with either calls to SPIR-V intrinsics |
296 | // or calls to proper generated functions. Returns True if F was modified. |
297 | bool SPIRVPrepareFunctions::substituteIntrinsicCalls(Function *F) { |
298 | bool Changed = false; |
299 | for (BasicBlock &BB : *F) { |
300 | for (Instruction &I : BB) { |
301 | auto Call = dyn_cast<CallInst>(Val: &I); |
302 | if (!Call) |
303 | continue; |
304 | Function *CF = Call->getCalledFunction(); |
305 | if (!CF || !CF->isIntrinsic()) |
306 | continue; |
307 | auto *II = cast<IntrinsicInst>(Val: Call); |
308 | switch (II->getIntrinsicID()) { |
309 | case Intrinsic::memset: |
310 | case Intrinsic::bswap: |
311 | Changed |= lowerIntrinsicToFunction(Intrinsic: II); |
312 | break; |
313 | case Intrinsic::fshl: |
314 | case Intrinsic::fshr: |
315 | lowerFunnelShifts(FSHIntrinsic: II); |
316 | Changed = true; |
317 | break; |
318 | case Intrinsic::umul_with_overflow: |
319 | lowerUMulWithOverflow(UMulIntrinsic: II); |
320 | Changed = true; |
321 | break; |
322 | case Intrinsic::assume: |
323 | case Intrinsic::expect: { |
324 | const SPIRVSubtarget &STI = TM.getSubtarget<SPIRVSubtarget>(*F); |
325 | if (STI.canUseExtension(SPIRV::Extension::SPV_KHR_expect_assume)) |
326 | lowerExpectAssume(II); |
327 | Changed = true; |
328 | } break; |
329 | case Intrinsic::lifetime_start: |
330 | Changed |= toSpvOverloadedIntrinsic( |
331 | II, Intrinsic::SPVIntrinsics::spv_lifetime_start, {1}); |
332 | break; |
333 | case Intrinsic::lifetime_end: |
334 | Changed |= toSpvOverloadedIntrinsic( |
335 | II, Intrinsic::SPVIntrinsics::spv_lifetime_end, {1}); |
336 | break; |
337 | } |
338 | } |
339 | } |
340 | return Changed; |
341 | } |
342 | |
343 | // Returns F if aggregate argument/return types are not present or cloned F |
344 | // function with the types replaced by i32 types. The change in types is |
345 | // noted in 'spv.cloned_funcs' metadata for later restoration. |
346 | Function * |
347 | SPIRVPrepareFunctions::removeAggregateTypesFromSignature(Function *F) { |
348 | IRBuilder<> B(F->getContext()); |
349 | |
350 | bool IsRetAggr = F->getReturnType()->isAggregateType(); |
351 | bool HasAggrArg = |
352 | std::any_of(first: F->arg_begin(), last: F->arg_end(), pred: [](Argument &Arg) { |
353 | return Arg.getType()->isAggregateType(); |
354 | }); |
355 | bool DoClone = IsRetAggr || HasAggrArg; |
356 | if (!DoClone) |
357 | return F; |
358 | SmallVector<std::pair<int, Type *>, 4> ChangedTypes; |
359 | Type *RetType = IsRetAggr ? B.getInt32Ty() : F->getReturnType(); |
360 | if (IsRetAggr) |
361 | ChangedTypes.push_back(Elt: std::pair<int, Type *>(-1, F->getReturnType())); |
362 | SmallVector<Type *, 4> ArgTypes; |
363 | for (const auto &Arg : F->args()) { |
364 | if (Arg.getType()->isAggregateType()) { |
365 | ArgTypes.push_back(Elt: B.getInt32Ty()); |
366 | ChangedTypes.push_back( |
367 | Elt: std::pair<int, Type *>(Arg.getArgNo(), Arg.getType())); |
368 | } else |
369 | ArgTypes.push_back(Elt: Arg.getType()); |
370 | } |
371 | FunctionType *NewFTy = |
372 | FunctionType::get(Result: RetType, Params: ArgTypes, isVarArg: F->getFunctionType()->isVarArg()); |
373 | Function *NewF = |
374 | Function::Create(Ty: NewFTy, Linkage: F->getLinkage(), N: F->getName(), M&: *F->getParent()); |
375 | |
376 | ValueToValueMapTy VMap; |
377 | auto NewFArgIt = NewF->arg_begin(); |
378 | for (auto &Arg : F->args()) { |
379 | StringRef ArgName = Arg.getName(); |
380 | NewFArgIt->setName(ArgName); |
381 | VMap[&Arg] = &(*NewFArgIt++); |
382 | } |
383 | SmallVector<ReturnInst *, 8> Returns; |
384 | |
385 | CloneFunctionInto(NewFunc: NewF, OldFunc: F, VMap, Changes: CloneFunctionChangeType::LocalChangesOnly, |
386 | Returns); |
387 | NewF->takeName(V: F); |
388 | |
389 | NamedMDNode *FuncMD = |
390 | F->getParent()->getOrInsertNamedMetadata(Name: "spv.cloned_funcs" ); |
391 | SmallVector<Metadata *, 2> MDArgs; |
392 | MDArgs.push_back(Elt: MDString::get(Context&: B.getContext(), Str: NewF->getName())); |
393 | for (auto &ChangedTyP : ChangedTypes) |
394 | MDArgs.push_back(Elt: MDNode::get( |
395 | Context&: B.getContext(), |
396 | MDs: {ConstantAsMetadata::get(C: B.getInt32(C: ChangedTyP.first)), |
397 | ValueAsMetadata::get(V: Constant::getNullValue(Ty: ChangedTyP.second))})); |
398 | MDNode *ThisFuncMD = MDNode::get(Context&: B.getContext(), MDs: MDArgs); |
399 | FuncMD->addOperand(M: ThisFuncMD); |
400 | |
401 | for (auto *U : make_early_inc_range(Range: F->users())) { |
402 | if (auto *CI = dyn_cast<CallInst>(Val: U)) |
403 | CI->mutateFunctionType(FTy: NewF->getFunctionType()); |
404 | U->replaceUsesOfWith(From: F, To: NewF); |
405 | } |
406 | return NewF; |
407 | } |
408 | |
409 | bool SPIRVPrepareFunctions::runOnModule(Module &M) { |
410 | bool Changed = false; |
411 | for (Function &F : M) |
412 | Changed |= substituteIntrinsicCalls(F: &F); |
413 | |
414 | std::vector<Function *> FuncsWorklist; |
415 | for (auto &F : M) |
416 | FuncsWorklist.push_back(x: &F); |
417 | |
418 | for (auto *F : FuncsWorklist) { |
419 | Function *NewF = removeAggregateTypesFromSignature(F); |
420 | |
421 | if (NewF != F) { |
422 | F->eraseFromParent(); |
423 | Changed = true; |
424 | } |
425 | } |
426 | return Changed; |
427 | } |
428 | |
429 | ModulePass * |
430 | llvm::createSPIRVPrepareFunctionsPass(const SPIRVTargetMachine &TM) { |
431 | return new SPIRVPrepareFunctions(TM); |
432 | } |
433 | |