1 | //===- R600OpenCLImageTypeLoweringPass.cpp ------------------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | /// \file |
10 | /// This pass resolves calls to OpenCL image attribute, image resource ID and |
11 | /// sampler resource ID getter functions. |
12 | /// |
13 | /// Image attributes (size and format) are expected to be passed to the kernel |
14 | /// as kernel arguments immediately following the image argument itself, |
15 | /// therefore this pass adds image size and format arguments to the kernel |
16 | /// functions in the module. The kernel functions with image arguments are |
/// re-created using the new signature. The new arguments are added to the
/// kernel metadata with kernel_arg_type and kernel_arg_base_type set to
/// "__llvm_image_size" or "__llvm_image_format".
19 | /// Note: this pass may invalidate pointers to functions. |
20 | /// |
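/// Resource IDs of read-only images, write-only images and samplers are
/// defined to be their index among the kernel arguments of the same kind:
/// read-only images, write-only images and samplers are each numbered
/// separately (2D and 3D images share one numbering per access qualifier).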
24 | // |
25 | //===----------------------------------------------------------------------===// |
26 | |
#include "R600.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Metadata.h"
#include "llvm/Pass.h"
#include "llvm/Transforms/Utils/Cloning.h"
#include <iterator>
36 | |
37 | using namespace llvm; |
38 | |
39 | static StringRef GetImageSizeFunc = "llvm.OpenCL.image.get.size" ; |
40 | static StringRef GetImageFormatFunc = "llvm.OpenCL.image.get.format" ; |
41 | static StringRef GetImageResourceIDFunc = "llvm.OpenCL.image.get.resource.id" ; |
42 | static StringRef GetSamplerResourceIDFunc = |
43 | "llvm.OpenCL.sampler.get.resource.id" ; |
44 | |
45 | static StringRef ImageSizeArgMDType = "__llvm_image_size" ; |
46 | static StringRef ImageFormatArgMDType = "__llvm_image_format" ; |
47 | |
48 | static StringRef KernelsMDNodeName = "opencl.kernels" ; |
49 | static StringRef KernelArgMDNodeNames[] = { |
50 | "kernel_arg_addr_space" , |
51 | "kernel_arg_access_qual" , |
52 | "kernel_arg_type" , |
53 | "kernel_arg_base_type" , |
54 | "kernel_arg_type_qual" }; |
55 | static const unsigned NumKernelArgMDNodes = 5; |
56 | |
namespace {

/// Small list of metadata operands used while assembling kernel MDNodes.
using MDVector = SmallVector<Metadata *, 8>;
/// Per-argument metadata accumulated for a rebuilt kernel node:
/// ArgVector[i] collects the operands of the node named
/// KernelArgMDNodeNames[i] (its name string first, then one entry per
/// argument).
struct KernelArgMD {
  MDVector ArgVector[NumKernelArgMDNodes];
};

} // end anonymous namespace
65 | |
66 | static inline bool |
67 | IsImageType(StringRef TypeString) { |
68 | return TypeString == "image2d_t" || TypeString == "image3d_t" ; |
69 | } |
70 | |
71 | static inline bool |
72 | IsSamplerType(StringRef TypeString) { |
73 | return TypeString == "sampler_t" ; |
74 | } |
75 | |
76 | static Function * |
77 | GetFunctionFromMDNode(MDNode *Node) { |
78 | if (!Node) |
79 | return nullptr; |
80 | |
81 | size_t NumOps = Node->getNumOperands(); |
82 | if (NumOps != NumKernelArgMDNodes + 1) |
83 | return nullptr; |
84 | |
85 | auto F = mdconst::dyn_extract<Function>(MD: Node->getOperand(I: 0)); |
86 | if (!F) |
87 | return nullptr; |
88 | |
89 | // Validation checks. |
90 | size_t ExpectNumArgNodeOps = F->arg_size() + 1; |
91 | for (size_t i = 0; i < NumKernelArgMDNodes; ++i) { |
92 | MDNode *ArgNode = dyn_cast_or_null<MDNode>(Val: Node->getOperand(I: i + 1)); |
93 | if (ArgNode->getNumOperands() != ExpectNumArgNodeOps) |
94 | return nullptr; |
95 | if (!ArgNode->getOperand(I: 0)) |
96 | return nullptr; |
97 | |
98 | // FIXME: It should be possible to do image lowering when some metadata |
99 | // args missing or not in the expected order. |
100 | MDString *StringNode = dyn_cast<MDString>(Val: ArgNode->getOperand(I: 0)); |
101 | if (!StringNode || StringNode->getString() != KernelArgMDNodeNames[i]) |
102 | return nullptr; |
103 | } |
104 | |
105 | return F; |
106 | } |
107 | |
108 | static StringRef |
109 | AccessQualFromMD(MDNode *KernelMDNode, unsigned ArgIdx) { |
110 | MDNode *ArgAQNode = cast<MDNode>(Val: KernelMDNode->getOperand(I: 2)); |
111 | return cast<MDString>(Val: ArgAQNode->getOperand(I: ArgIdx + 1))->getString(); |
112 | } |
113 | |
114 | static StringRef |
115 | ArgTypeFromMD(MDNode *KernelMDNode, unsigned ArgIdx) { |
116 | MDNode *ArgTypeNode = cast<MDNode>(Val: KernelMDNode->getOperand(I: 3)); |
117 | return cast<MDString>(Val: ArgTypeNode->getOperand(I: ArgIdx + 1))->getString(); |
118 | } |
119 | |
120 | static MDVector |
121 | GetArgMD(MDNode *KernelMDNode, unsigned OpIdx) { |
122 | MDVector Res; |
123 | for (unsigned i = 0; i < NumKernelArgMDNodes; ++i) { |
124 | MDNode *Node = cast<MDNode>(Val: KernelMDNode->getOperand(I: i + 1)); |
125 | Res.push_back(Elt: Node->getOperand(I: OpIdx)); |
126 | } |
127 | return Res; |
128 | } |
129 | |
130 | static void |
131 | PushArgMD(KernelArgMD &MD, const MDVector &V) { |
132 | assert(V.size() == NumKernelArgMDNodes); |
133 | for (unsigned i = 0; i < NumKernelArgMDNodes; ++i) { |
134 | MD.ArgVector[i].push_back(Elt: V[i]); |
135 | } |
136 | } |
137 | |
138 | namespace { |
139 | |
140 | class R600OpenCLImageTypeLoweringPass : public ModulePass { |
141 | static char ID; |
142 | |
143 | LLVMContext *Context; |
144 | Type *Int32Type; |
145 | Type *ImageSizeType; |
146 | Type *ImageFormatType; |
147 | SmallVector<Instruction *, 4> InstsToErase; |
148 | |
149 | bool replaceImageUses(Argument &ImageArg, uint32_t ResourceID, |
150 | Argument &ImageSizeArg, |
151 | Argument &ImageFormatArg) { |
152 | bool Modified = false; |
153 | |
154 | for (auto &Use : ImageArg.uses()) { |
155 | auto Inst = dyn_cast<CallInst>(Val: Use.getUser()); |
156 | if (!Inst) { |
157 | continue; |
158 | } |
159 | |
160 | Function *F = Inst->getCalledFunction(); |
161 | if (!F) |
162 | continue; |
163 | |
164 | Value *Replacement = nullptr; |
165 | StringRef Name = F->getName(); |
166 | if (Name.starts_with(Prefix: GetImageResourceIDFunc)) { |
167 | Replacement = ConstantInt::get(Ty: Int32Type, V: ResourceID); |
168 | } else if (Name.starts_with(Prefix: GetImageSizeFunc)) { |
169 | Replacement = &ImageSizeArg; |
170 | } else if (Name.starts_with(Prefix: GetImageFormatFunc)) { |
171 | Replacement = &ImageFormatArg; |
172 | } else { |
173 | continue; |
174 | } |
175 | |
176 | Inst->replaceAllUsesWith(V: Replacement); |
177 | InstsToErase.push_back(Elt: Inst); |
178 | Modified = true; |
179 | } |
180 | |
181 | return Modified; |
182 | } |
183 | |
184 | bool replaceSamplerUses(Argument &SamplerArg, uint32_t ResourceID) { |
185 | bool Modified = false; |
186 | |
187 | for (const auto &Use : SamplerArg.uses()) { |
188 | auto Inst = dyn_cast<CallInst>(Val: Use.getUser()); |
189 | if (!Inst) { |
190 | continue; |
191 | } |
192 | |
193 | Function *F = Inst->getCalledFunction(); |
194 | if (!F) |
195 | continue; |
196 | |
197 | Value *Replacement = nullptr; |
198 | StringRef Name = F->getName(); |
199 | if (Name == GetSamplerResourceIDFunc) { |
200 | Replacement = ConstantInt::get(Ty: Int32Type, V: ResourceID); |
201 | } else { |
202 | continue; |
203 | } |
204 | |
205 | Inst->replaceAllUsesWith(V: Replacement); |
206 | InstsToErase.push_back(Elt: Inst); |
207 | Modified = true; |
208 | } |
209 | |
210 | return Modified; |
211 | } |
212 | |
213 | bool replaceImageAndSamplerUses(Function *F, MDNode *KernelMDNode) { |
214 | uint32_t NumReadOnlyImageArgs = 0; |
215 | uint32_t NumWriteOnlyImageArgs = 0; |
216 | uint32_t NumSamplerArgs = 0; |
217 | |
218 | bool Modified = false; |
219 | InstsToErase.clear(); |
220 | for (auto ArgI = F->arg_begin(); ArgI != F->arg_end(); ++ArgI) { |
221 | Argument &Arg = *ArgI; |
222 | StringRef Type = ArgTypeFromMD(KernelMDNode, ArgIdx: Arg.getArgNo()); |
223 | |
224 | // Handle image types. |
225 | if (IsImageType(TypeString: Type)) { |
226 | StringRef AccessQual = AccessQualFromMD(KernelMDNode, ArgIdx: Arg.getArgNo()); |
227 | uint32_t ResourceID; |
228 | if (AccessQual == "read_only" ) { |
229 | ResourceID = NumReadOnlyImageArgs++; |
230 | } else if (AccessQual == "write_only" ) { |
231 | ResourceID = NumWriteOnlyImageArgs++; |
232 | } else { |
233 | llvm_unreachable("Wrong image access qualifier." ); |
234 | } |
235 | |
236 | Argument &SizeArg = *(++ArgI); |
237 | Argument &FormatArg = *(++ArgI); |
238 | Modified |= replaceImageUses(ImageArg&: Arg, ResourceID, ImageSizeArg&: SizeArg, ImageFormatArg&: FormatArg); |
239 | |
240 | // Handle sampler type. |
241 | } else if (IsSamplerType(TypeString: Type)) { |
242 | uint32_t ResourceID = NumSamplerArgs++; |
243 | Modified |= replaceSamplerUses(SamplerArg&: Arg, ResourceID); |
244 | } |
245 | } |
246 | for (unsigned i = 0; i < InstsToErase.size(); ++i) { |
247 | InstsToErase[i]->eraseFromParent(); |
248 | } |
249 | |
250 | return Modified; |
251 | } |
252 | |
253 | std::tuple<Function *, MDNode *> |
254 | addImplicitArgs(Function *F, MDNode *KernelMDNode) { |
255 | bool Modified = false; |
256 | |
257 | FunctionType *FT = F->getFunctionType(); |
258 | SmallVector<Type *, 8> ArgTypes; |
259 | |
260 | // Metadata operands for new MDNode. |
261 | KernelArgMD NewArgMDs; |
262 | PushArgMD(MD&: NewArgMDs, V: GetArgMD(KernelMDNode, OpIdx: 0)); |
263 | |
264 | // Add implicit arguments to the signature. |
265 | for (unsigned i = 0; i < FT->getNumParams(); ++i) { |
266 | ArgTypes.push_back(Elt: FT->getParamType(i)); |
267 | MDVector ArgMD = GetArgMD(KernelMDNode, OpIdx: i + 1); |
268 | PushArgMD(MD&: NewArgMDs, V: ArgMD); |
269 | |
270 | if (!IsImageType(TypeString: ArgTypeFromMD(KernelMDNode, ArgIdx: i))) |
271 | continue; |
272 | |
273 | // Add size implicit argument. |
274 | ArgTypes.push_back(Elt: ImageSizeType); |
275 | ArgMD[2] = ArgMD[3] = MDString::get(Context&: *Context, Str: ImageSizeArgMDType); |
276 | PushArgMD(MD&: NewArgMDs, V: ArgMD); |
277 | |
278 | // Add format implicit argument. |
279 | ArgTypes.push_back(Elt: ImageFormatType); |
280 | ArgMD[2] = ArgMD[3] = MDString::get(Context&: *Context, Str: ImageFormatArgMDType); |
281 | PushArgMD(MD&: NewArgMDs, V: ArgMD); |
282 | |
283 | Modified = true; |
284 | } |
285 | if (!Modified) { |
286 | return std::tuple(nullptr, nullptr); |
287 | } |
288 | |
289 | // Create function with new signature and clone the old body into it. |
290 | auto NewFT = FunctionType::get(Result: FT->getReturnType(), Params: ArgTypes, isVarArg: false); |
291 | auto NewF = Function::Create(Ty: NewFT, Linkage: F->getLinkage(), N: F->getName()); |
292 | ValueToValueMapTy VMap; |
293 | auto NewFArgIt = NewF->arg_begin(); |
294 | for (auto &Arg: F->args()) { |
295 | auto ArgName = Arg.getName(); |
296 | NewFArgIt->setName(ArgName); |
297 | VMap[&Arg] = &(*NewFArgIt++); |
298 | if (IsImageType(TypeString: ArgTypeFromMD(KernelMDNode, ArgIdx: Arg.getArgNo()))) { |
299 | (NewFArgIt++)->setName(Twine("__size_" ) + ArgName); |
300 | (NewFArgIt++)->setName(Twine("__format_" ) + ArgName); |
301 | } |
302 | } |
303 | SmallVector<ReturnInst*, 8> Returns; |
304 | CloneFunctionInto(NewFunc: NewF, OldFunc: F, VMap, Changes: CloneFunctionChangeType::LocalChangesOnly, |
305 | Returns); |
306 | |
307 | // Build new MDNode. |
308 | SmallVector<Metadata *, 6> KernelMDArgs; |
309 | KernelMDArgs.push_back(Elt: ConstantAsMetadata::get(C: NewF)); |
310 | for (const MDVector &MDV : NewArgMDs.ArgVector) |
311 | KernelMDArgs.push_back(Elt: MDNode::get(Context&: *Context, MDs: MDV)); |
312 | MDNode *NewMDNode = MDNode::get(Context&: *Context, MDs: KernelMDArgs); |
313 | |
314 | return std::tuple(NewF, NewMDNode); |
315 | } |
316 | |
317 | bool transformKernels(Module &M) { |
318 | NamedMDNode *KernelsMDNode = M.getNamedMetadata(Name: KernelsMDNodeName); |
319 | if (!KernelsMDNode) |
320 | return false; |
321 | |
322 | bool Modified = false; |
323 | for (unsigned i = 0; i < KernelsMDNode->getNumOperands(); ++i) { |
324 | MDNode *KernelMDNode = KernelsMDNode->getOperand(i); |
325 | Function *F = GetFunctionFromMDNode(Node: KernelMDNode); |
326 | if (!F) |
327 | continue; |
328 | |
329 | Function *NewF; |
330 | MDNode *NewMDNode; |
331 | std::tie(args&: NewF, args&: NewMDNode) = addImplicitArgs(F, KernelMDNode); |
332 | if (NewF) { |
333 | // Replace old function and metadata with new ones. |
334 | F->eraseFromParent(); |
335 | M.getFunctionList().push_back(val: NewF); |
336 | M.getOrInsertFunction(Name: NewF->getName(), T: NewF->getFunctionType(), |
337 | AttributeList: NewF->getAttributes()); |
338 | KernelsMDNode->setOperand(I: i, New: NewMDNode); |
339 | |
340 | F = NewF; |
341 | KernelMDNode = NewMDNode; |
342 | Modified = true; |
343 | } |
344 | |
345 | Modified |= replaceImageAndSamplerUses(F, KernelMDNode); |
346 | } |
347 | |
348 | return Modified; |
349 | } |
350 | |
351 | public: |
352 | R600OpenCLImageTypeLoweringPass() : ModulePass(ID) {} |
353 | |
354 | bool runOnModule(Module &M) override { |
355 | Context = &M.getContext(); |
356 | Int32Type = Type::getInt32Ty(C&: M.getContext()); |
357 | ImageSizeType = ArrayType::get(ElementType: Int32Type, NumElements: 3); |
358 | ImageFormatType = ArrayType::get(ElementType: Int32Type, NumElements: 2); |
359 | |
360 | return transformKernels(M); |
361 | } |
362 | |
363 | StringRef getPassName() const override { |
364 | return "R600 OpenCL Image Type Pass" ; |
365 | } |
366 | }; |
367 | |
368 | } // end anonymous namespace |
369 | |
370 | char R600OpenCLImageTypeLoweringPass::ID = 0; |
371 | |
372 | ModulePass *llvm::createR600OpenCLImageTypeLoweringPass() { |
373 | return new R600OpenCLImageTypeLoweringPass(); |
374 | } |
375 | |