1//===- R600OpenCLImageTypeLoweringPass.cpp ------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9/// \file
10/// This pass resolves calls to OpenCL image attribute, image resource ID and
11/// sampler resource ID getter functions.
12///
13/// Image attributes (size and format) are expected to be passed to the kernel
14/// as kernel arguments immediately following the image argument itself,
15/// therefore this pass adds image size and format arguments to the kernel
16/// functions in the module. The kernel functions with image arguments are
/// re-created using the new signature. The new arguments are added to the
/// kernel metadata with kernel_arg_type and kernel_arg_base_type set to
/// "__llvm_image_size" or "__llvm_image_format".
19/// Note: this pass may invalidate pointers to functions.
20///
21/// Resource IDs of read-only images, write-only images and samplers are
22/// defined to be their index among the kernel arguments of the same
23/// type and access qualifier.
24//
25//===----------------------------------------------------------------------===//
26
27#include "R600.h"
28#include "llvm/ADT/SmallVector.h"
29#include "llvm/ADT/StringRef.h"
30#include "llvm/IR/Constants.h"
31#include "llvm/IR/Function.h"
32#include "llvm/IR/Instructions.h"
33#include "llvm/IR/Metadata.h"
34#include "llvm/Pass.h"
35#include "llvm/Transforms/Utils/Cloning.h"
36
37using namespace llvm;
38
39static StringRef GetImageSizeFunc = "llvm.OpenCL.image.get.size";
40static StringRef GetImageFormatFunc = "llvm.OpenCL.image.get.format";
41static StringRef GetImageResourceIDFunc = "llvm.OpenCL.image.get.resource.id";
42static StringRef GetSamplerResourceIDFunc =
43 "llvm.OpenCL.sampler.get.resource.id";
44
45static StringRef ImageSizeArgMDType = "__llvm_image_size";
46static StringRef ImageFormatArgMDType = "__llvm_image_format";
47
48static StringRef KernelsMDNodeName = "opencl.kernels";
49static StringRef KernelArgMDNodeNames[] = {
50 "kernel_arg_addr_space",
51 "kernel_arg_access_qual",
52 "kernel_arg_type",
53 "kernel_arg_base_type",
54 "kernel_arg_type_qual"};
55static const unsigned NumKernelArgMDNodes = 5;
56
57namespace {
58
59using MDVector = SmallVector<Metadata *, 8>;
60struct KernelArgMD {
61 MDVector ArgVector[NumKernelArgMDNodes];
62};
63
64} // end anonymous namespace
65
66static inline bool
67IsImageType(StringRef TypeString) {
68 return TypeString == "image2d_t" || TypeString == "image3d_t";
69}
70
71static inline bool
72IsSamplerType(StringRef TypeString) {
73 return TypeString == "sampler_t";
74}
75
76static Function *
77GetFunctionFromMDNode(MDNode *Node) {
78 if (!Node)
79 return nullptr;
80
81 size_t NumOps = Node->getNumOperands();
82 if (NumOps != NumKernelArgMDNodes + 1)
83 return nullptr;
84
85 auto F = mdconst::dyn_extract<Function>(MD: Node->getOperand(I: 0));
86 if (!F)
87 return nullptr;
88
89 // Validation checks.
90 size_t ExpectNumArgNodeOps = F->arg_size() + 1;
91 for (size_t i = 0; i < NumKernelArgMDNodes; ++i) {
92 MDNode *ArgNode = dyn_cast_or_null<MDNode>(Val: Node->getOperand(I: i + 1));
93 if (ArgNode->getNumOperands() != ExpectNumArgNodeOps)
94 return nullptr;
95 if (!ArgNode->getOperand(I: 0))
96 return nullptr;
97
98 // FIXME: It should be possible to do image lowering when some metadata
99 // args missing or not in the expected order.
100 MDString *StringNode = dyn_cast<MDString>(Val: ArgNode->getOperand(I: 0));
101 if (!StringNode || StringNode->getString() != KernelArgMDNodeNames[i])
102 return nullptr;
103 }
104
105 return F;
106}
107
108static StringRef
109AccessQualFromMD(MDNode *KernelMDNode, unsigned ArgIdx) {
110 MDNode *ArgAQNode = cast<MDNode>(Val: KernelMDNode->getOperand(I: 2));
111 return cast<MDString>(Val: ArgAQNode->getOperand(I: ArgIdx + 1))->getString();
112}
113
114static StringRef
115ArgTypeFromMD(MDNode *KernelMDNode, unsigned ArgIdx) {
116 MDNode *ArgTypeNode = cast<MDNode>(Val: KernelMDNode->getOperand(I: 3));
117 return cast<MDString>(Val: ArgTypeNode->getOperand(I: ArgIdx + 1))->getString();
118}
119
120static MDVector
121GetArgMD(MDNode *KernelMDNode, unsigned OpIdx) {
122 MDVector Res;
123 for (unsigned i = 0; i < NumKernelArgMDNodes; ++i) {
124 MDNode *Node = cast<MDNode>(Val: KernelMDNode->getOperand(I: i + 1));
125 Res.push_back(Elt: Node->getOperand(I: OpIdx));
126 }
127 return Res;
128}
129
130static void
131PushArgMD(KernelArgMD &MD, const MDVector &V) {
132 assert(V.size() == NumKernelArgMDNodes);
133 for (unsigned i = 0; i < NumKernelArgMDNodes; ++i) {
134 MD.ArgVector[i].push_back(Elt: V[i]);
135 }
136}
137
138namespace {
139
140class R600OpenCLImageTypeLoweringPass : public ModulePass {
141 static char ID;
142
143 LLVMContext *Context;
144 Type *Int32Type;
145 Type *ImageSizeType;
146 Type *ImageFormatType;
147 SmallVector<Instruction *, 4> InstsToErase;
148
149 bool replaceImageUses(Argument &ImageArg, uint32_t ResourceID,
150 Argument &ImageSizeArg,
151 Argument &ImageFormatArg) {
152 bool Modified = false;
153
154 for (auto &Use : ImageArg.uses()) {
155 auto Inst = dyn_cast<CallInst>(Val: Use.getUser());
156 if (!Inst) {
157 continue;
158 }
159
160 Function *F = Inst->getCalledFunction();
161 if (!F)
162 continue;
163
164 Value *Replacement = nullptr;
165 StringRef Name = F->getName();
166 if (Name.starts_with(Prefix: GetImageResourceIDFunc)) {
167 Replacement = ConstantInt::get(Ty: Int32Type, V: ResourceID);
168 } else if (Name.starts_with(Prefix: GetImageSizeFunc)) {
169 Replacement = &ImageSizeArg;
170 } else if (Name.starts_with(Prefix: GetImageFormatFunc)) {
171 Replacement = &ImageFormatArg;
172 } else {
173 continue;
174 }
175
176 Inst->replaceAllUsesWith(V: Replacement);
177 InstsToErase.push_back(Elt: Inst);
178 Modified = true;
179 }
180
181 return Modified;
182 }
183
184 bool replaceSamplerUses(Argument &SamplerArg, uint32_t ResourceID) {
185 bool Modified = false;
186
187 for (const auto &Use : SamplerArg.uses()) {
188 auto Inst = dyn_cast<CallInst>(Val: Use.getUser());
189 if (!Inst) {
190 continue;
191 }
192
193 Function *F = Inst->getCalledFunction();
194 if (!F)
195 continue;
196
197 Value *Replacement = nullptr;
198 StringRef Name = F->getName();
199 if (Name == GetSamplerResourceIDFunc) {
200 Replacement = ConstantInt::get(Ty: Int32Type, V: ResourceID);
201 } else {
202 continue;
203 }
204
205 Inst->replaceAllUsesWith(V: Replacement);
206 InstsToErase.push_back(Elt: Inst);
207 Modified = true;
208 }
209
210 return Modified;
211 }
212
213 bool replaceImageAndSamplerUses(Function *F, MDNode *KernelMDNode) {
214 uint32_t NumReadOnlyImageArgs = 0;
215 uint32_t NumWriteOnlyImageArgs = 0;
216 uint32_t NumSamplerArgs = 0;
217
218 bool Modified = false;
219 InstsToErase.clear();
220 for (auto ArgI = F->arg_begin(); ArgI != F->arg_end(); ++ArgI) {
221 Argument &Arg = *ArgI;
222 StringRef Type = ArgTypeFromMD(KernelMDNode, ArgIdx: Arg.getArgNo());
223
224 // Handle image types.
225 if (IsImageType(TypeString: Type)) {
226 StringRef AccessQual = AccessQualFromMD(KernelMDNode, ArgIdx: Arg.getArgNo());
227 uint32_t ResourceID;
228 if (AccessQual == "read_only") {
229 ResourceID = NumReadOnlyImageArgs++;
230 } else if (AccessQual == "write_only") {
231 ResourceID = NumWriteOnlyImageArgs++;
232 } else {
233 llvm_unreachable("Wrong image access qualifier.");
234 }
235
236 Argument &SizeArg = *(++ArgI);
237 Argument &FormatArg = *(++ArgI);
238 Modified |= replaceImageUses(ImageArg&: Arg, ResourceID, ImageSizeArg&: SizeArg, ImageFormatArg&: FormatArg);
239
240 // Handle sampler type.
241 } else if (IsSamplerType(TypeString: Type)) {
242 uint32_t ResourceID = NumSamplerArgs++;
243 Modified |= replaceSamplerUses(SamplerArg&: Arg, ResourceID);
244 }
245 }
246 for (unsigned i = 0; i < InstsToErase.size(); ++i) {
247 InstsToErase[i]->eraseFromParent();
248 }
249
250 return Modified;
251 }
252
253 std::tuple<Function *, MDNode *>
254 addImplicitArgs(Function *F, MDNode *KernelMDNode) {
255 bool Modified = false;
256
257 FunctionType *FT = F->getFunctionType();
258 SmallVector<Type *, 8> ArgTypes;
259
260 // Metadata operands for new MDNode.
261 KernelArgMD NewArgMDs;
262 PushArgMD(MD&: NewArgMDs, V: GetArgMD(KernelMDNode, OpIdx: 0));
263
264 // Add implicit arguments to the signature.
265 for (unsigned i = 0; i < FT->getNumParams(); ++i) {
266 ArgTypes.push_back(Elt: FT->getParamType(i));
267 MDVector ArgMD = GetArgMD(KernelMDNode, OpIdx: i + 1);
268 PushArgMD(MD&: NewArgMDs, V: ArgMD);
269
270 if (!IsImageType(TypeString: ArgTypeFromMD(KernelMDNode, ArgIdx: i)))
271 continue;
272
273 // Add size implicit argument.
274 ArgTypes.push_back(Elt: ImageSizeType);
275 ArgMD[2] = ArgMD[3] = MDString::get(Context&: *Context, Str: ImageSizeArgMDType);
276 PushArgMD(MD&: NewArgMDs, V: ArgMD);
277
278 // Add format implicit argument.
279 ArgTypes.push_back(Elt: ImageFormatType);
280 ArgMD[2] = ArgMD[3] = MDString::get(Context&: *Context, Str: ImageFormatArgMDType);
281 PushArgMD(MD&: NewArgMDs, V: ArgMD);
282
283 Modified = true;
284 }
285 if (!Modified) {
286 return std::tuple(nullptr, nullptr);
287 }
288
289 // Create function with new signature and clone the old body into it.
290 auto NewFT = FunctionType::get(Result: FT->getReturnType(), Params: ArgTypes, isVarArg: false);
291 auto NewF = Function::Create(Ty: NewFT, Linkage: F->getLinkage(), N: F->getName());
292 ValueToValueMapTy VMap;
293 auto NewFArgIt = NewF->arg_begin();
294 for (auto &Arg: F->args()) {
295 auto ArgName = Arg.getName();
296 NewFArgIt->setName(ArgName);
297 VMap[&Arg] = &(*NewFArgIt++);
298 if (IsImageType(TypeString: ArgTypeFromMD(KernelMDNode, ArgIdx: Arg.getArgNo()))) {
299 (NewFArgIt++)->setName(Twine("__size_") + ArgName);
300 (NewFArgIt++)->setName(Twine("__format_") + ArgName);
301 }
302 }
303 SmallVector<ReturnInst*, 8> Returns;
304 CloneFunctionInto(NewFunc: NewF, OldFunc: F, VMap, Changes: CloneFunctionChangeType::LocalChangesOnly,
305 Returns);
306
307 // Build new MDNode.
308 SmallVector<Metadata *, 6> KernelMDArgs;
309 KernelMDArgs.push_back(Elt: ConstantAsMetadata::get(C: NewF));
310 for (const MDVector &MDV : NewArgMDs.ArgVector)
311 KernelMDArgs.push_back(Elt: MDNode::get(Context&: *Context, MDs: MDV));
312 MDNode *NewMDNode = MDNode::get(Context&: *Context, MDs: KernelMDArgs);
313
314 return std::tuple(NewF, NewMDNode);
315 }
316
317 bool transformKernels(Module &M) {
318 NamedMDNode *KernelsMDNode = M.getNamedMetadata(Name: KernelsMDNodeName);
319 if (!KernelsMDNode)
320 return false;
321
322 bool Modified = false;
323 for (unsigned i = 0; i < KernelsMDNode->getNumOperands(); ++i) {
324 MDNode *KernelMDNode = KernelsMDNode->getOperand(i);
325 Function *F = GetFunctionFromMDNode(Node: KernelMDNode);
326 if (!F)
327 continue;
328
329 Function *NewF;
330 MDNode *NewMDNode;
331 std::tie(args&: NewF, args&: NewMDNode) = addImplicitArgs(F, KernelMDNode);
332 if (NewF) {
333 // Replace old function and metadata with new ones.
334 F->eraseFromParent();
335 M.getFunctionList().push_back(val: NewF);
336 M.getOrInsertFunction(Name: NewF->getName(), T: NewF->getFunctionType(),
337 AttributeList: NewF->getAttributes());
338 KernelsMDNode->setOperand(I: i, New: NewMDNode);
339
340 F = NewF;
341 KernelMDNode = NewMDNode;
342 Modified = true;
343 }
344
345 Modified |= replaceImageAndSamplerUses(F, KernelMDNode);
346 }
347
348 return Modified;
349 }
350
351public:
352 R600OpenCLImageTypeLoweringPass() : ModulePass(ID) {}
353
354 bool runOnModule(Module &M) override {
355 Context = &M.getContext();
356 Int32Type = Type::getInt32Ty(C&: M.getContext());
357 ImageSizeType = ArrayType::get(ElementType: Int32Type, NumElements: 3);
358 ImageFormatType = ArrayType::get(ElementType: Int32Type, NumElements: 2);
359
360 return transformKernels(M);
361 }
362
363 StringRef getPassName() const override {
364 return "R600 OpenCL Image Type Pass";
365 }
366};
367
368} // end anonymous namespace
369
370char R600OpenCLImageTypeLoweringPass::ID = 0;
371
372ModulePass *llvm::createR600OpenCLImageTypeLoweringPass() {
373 return new R600OpenCLImageTypeLoweringPass();
374}
375

// Source: llvm/lib/Target/AMDGPU/R600OpenCLImageTypeLoweringPass.cpp