1//===- GPUOpsLowering.cpp - GPU FuncOp / ReturnOp lowering ----------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "GPUOpsLowering.h"
10
11#include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
12#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
13#include "mlir/IR/Attributes.h"
14#include "mlir/IR/Builders.h"
15#include "mlir/IR/BuiltinTypes.h"
16#include "llvm/ADT/SmallVectorExtras.h"
17#include "llvm/ADT/StringSet.h"
18#include "llvm/Support/FormatVariadic.h"
19
20using namespace mlir;
21
22LogicalResult
23GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
24 ConversionPatternRewriter &rewriter) const {
25 Location loc = gpuFuncOp.getLoc();
26
27 SmallVector<LLVM::GlobalOp, 3> workgroupBuffers;
28 workgroupBuffers.reserve(gpuFuncOp.getNumWorkgroupAttributions());
29 for (const auto [idx, attribution] :
30 llvm::enumerate(gpuFuncOp.getWorkgroupAttributions())) {
31 auto type = dyn_cast<MemRefType>(attribution.getType());
32 assert(type && type.hasStaticShape() && "unexpected type in attribution");
33
34 uint64_t numElements = type.getNumElements();
35
36 auto elementType =
37 cast<Type>(typeConverter->convertType(type.getElementType()));
38 auto arrayType = LLVM::LLVMArrayType::get(elementType, numElements);
39 std::string name =
40 std::string(llvm::formatv("__wg_{0}_{1}", gpuFuncOp.getName(), idx));
41 uint64_t alignment = 0;
42 if (auto alignAttr =
43 dyn_cast_or_null<IntegerAttr>(gpuFuncOp.getWorkgroupAttributionAttr(
44 idx, LLVM::LLVMDialect::getAlignAttrName())))
45 alignment = alignAttr.getInt();
46 auto globalOp = rewriter.create<LLVM::GlobalOp>(
47 gpuFuncOp.getLoc(), arrayType, /*isConstant=*/false,
48 LLVM::Linkage::Internal, name, /*value=*/Attribute(), alignment,
49 workgroupAddrSpace);
50 workgroupBuffers.push_back(globalOp);
51 }
52
53 // Remap proper input types.
54 TypeConverter::SignatureConversion signatureConversion(
55 gpuFuncOp.front().getNumArguments());
56
57 Type funcType = getTypeConverter()->convertFunctionSignature(
58 gpuFuncOp.getFunctionType(), /*isVariadic=*/false,
59 getTypeConverter()->getOptions().useBarePtrCallConv, signatureConversion);
60 if (!funcType) {
61 return rewriter.notifyMatchFailure(gpuFuncOp, [&](Diagnostic &diag) {
62 diag << "failed to convert function signature type for: "
63 << gpuFuncOp.getFunctionType();
64 });
65 }
66
67 // Create the new function operation. Only copy those attributes that are
68 // not specific to function modeling.
69 SmallVector<NamedAttribute, 4> attributes;
70 ArrayAttr argAttrs;
71 for (const auto &attr : gpuFuncOp->getAttrs()) {
72 if (attr.getName() == SymbolTable::getSymbolAttrName() ||
73 attr.getName() == gpuFuncOp.getFunctionTypeAttrName() ||
74 attr.getName() ==
75 gpu::GPUFuncOp::getNumWorkgroupAttributionsAttrName() ||
76 attr.getName() == gpuFuncOp.getWorkgroupAttribAttrsAttrName() ||
77 attr.getName() == gpuFuncOp.getPrivateAttribAttrsAttrName())
78 continue;
79 if (attr.getName() == gpuFuncOp.getArgAttrsAttrName()) {
80 argAttrs = gpuFuncOp.getArgAttrsAttr();
81 continue;
82 }
83 attributes.push_back(attr);
84 }
85 // Add a dialect specific kernel attribute in addition to GPU kernel
86 // attribute. The former is necessary for further translation while the
87 // latter is expected by gpu.launch_func.
88 if (gpuFuncOp.isKernel()) {
89 attributes.emplace_back(kernelAttributeName, rewriter.getUnitAttr());
90
91 // Set the block size attribute if it is present.
92 if (kernelBlockSizeAttributeName.has_value()) {
93 std::optional<int32_t> dimX =
94 gpuFuncOp.getKnownBlockSize(gpu::Dimension::x);
95 std::optional<int32_t> dimY =
96 gpuFuncOp.getKnownBlockSize(gpu::Dimension::y);
97 std::optional<int32_t> dimZ =
98 gpuFuncOp.getKnownBlockSize(gpu::Dimension::z);
99 if (dimX.has_value() || dimY.has_value() || dimZ.has_value()) {
100 // If any of the dimensions are missing, fill them in with 1.
101 attributes.emplace_back(
102 kernelBlockSizeAttributeName.value(),
103 rewriter.getDenseI32ArrayAttr(
104 {dimX.value_or(1), dimY.value_or(1), dimZ.value_or(1)}));
105 }
106 }
107 }
108 auto llvmFuncOp = rewriter.create<LLVM::LLVMFuncOp>(
109 gpuFuncOp.getLoc(), gpuFuncOp.getName(), funcType,
110 LLVM::Linkage::External, /*dsoLocal=*/false, /*cconv=*/LLVM::CConv::C,
111 /*comdat=*/nullptr, attributes);
112
113 {
114 // Insert operations that correspond to converted workgroup and private
115 // memory attributions to the body of the function. This must operate on
116 // the original function, before the body region is inlined in the new
117 // function to maintain the relation between block arguments and the
118 // parent operation that assigns their semantics.
119 OpBuilder::InsertionGuard guard(rewriter);
120
121 // Rewrite workgroup memory attributions to addresses of global buffers.
122 rewriter.setInsertionPointToStart(&gpuFuncOp.front());
123 unsigned numProperArguments = gpuFuncOp.getNumArguments();
124
125 for (const auto [idx, global] : llvm::enumerate(workgroupBuffers)) {
126 auto ptrType = LLVM::LLVMPointerType::get(rewriter.getContext(),
127 global.getAddrSpace());
128 Value address = rewriter.create<LLVM::AddressOfOp>(
129 loc, ptrType, global.getSymNameAttr());
130 Value memory =
131 rewriter.create<LLVM::GEPOp>(loc, ptrType, global.getType(), address,
132 ArrayRef<LLVM::GEPArg>{0, 0});
133
134 // Build a memref descriptor pointing to the buffer to plug with the
135 // existing memref infrastructure. This may use more registers than
136 // otherwise necessary given that memref sizes are fixed, but we can try
137 // and canonicalize that away later.
138 Value attribution = gpuFuncOp.getWorkgroupAttributions()[idx];
139 auto type = cast<MemRefType>(attribution.getType());
140 auto descr = MemRefDescriptor::fromStaticShape(
141 rewriter, loc, *getTypeConverter(), type, memory);
142 signatureConversion.remapInput(numProperArguments + idx, descr);
143 }
144
145 // Rewrite private memory attributions to alloca'ed buffers.
146 unsigned numWorkgroupAttributions = gpuFuncOp.getNumWorkgroupAttributions();
147 auto int64Ty = IntegerType::get(rewriter.getContext(), 64);
148 for (const auto [idx, attribution] :
149 llvm::enumerate(gpuFuncOp.getPrivateAttributions())) {
150 auto type = cast<MemRefType>(attribution.getType());
151 assert(type && type.hasStaticShape() && "unexpected type in attribution");
152
153 // Explicitly drop memory space when lowering private memory
154 // attributions since NVVM models it as `alloca`s in the default
155 // memory space and does not support `alloca`s with addrspace(5).
156 Type elementType = typeConverter->convertType(type.getElementType());
157 auto ptrType =
158 LLVM::LLVMPointerType::get(rewriter.getContext(), allocaAddrSpace);
159 Value numElements = rewriter.create<LLVM::ConstantOp>(
160 gpuFuncOp.getLoc(), int64Ty, type.getNumElements());
161 uint64_t alignment = 0;
162 if (auto alignAttr =
163 dyn_cast_or_null<IntegerAttr>(gpuFuncOp.getPrivateAttributionAttr(
164 idx, LLVM::LLVMDialect::getAlignAttrName())))
165 alignment = alignAttr.getInt();
166 Value allocated = rewriter.create<LLVM::AllocaOp>(
167 gpuFuncOp.getLoc(), ptrType, elementType, numElements, alignment);
168 auto descr = MemRefDescriptor::fromStaticShape(
169 rewriter, loc, *getTypeConverter(), type, allocated);
170 signatureConversion.remapInput(
171 numProperArguments + numWorkgroupAttributions + idx, descr);
172 }
173 }
174
175 // Move the region to the new function, update the entry block signature.
176 rewriter.inlineRegionBefore(gpuFuncOp.getBody(), llvmFuncOp.getBody(),
177 llvmFuncOp.end());
178 if (failed(rewriter.convertRegionTypes(&llvmFuncOp.getBody(), *typeConverter,
179 &signatureConversion)))
180 return failure();
181
182 // If bare memref pointers are being used, remap them back to memref
183 // descriptors This must be done after signature conversion to get rid of the
184 // unrealized casts.
185 if (getTypeConverter()->getOptions().useBarePtrCallConv) {
186 OpBuilder::InsertionGuard guard(rewriter);
187 rewriter.setInsertionPointToStart(&llvmFuncOp.getBody().front());
188 for (const auto [idx, argTy] :
189 llvm::enumerate(gpuFuncOp.getArgumentTypes())) {
190 auto memrefTy = dyn_cast<MemRefType>(argTy);
191 if (!memrefTy)
192 continue;
193 assert(memrefTy.hasStaticShape() &&
194 "Bare pointer convertion used with dynamically-shaped memrefs");
195 // Use a placeholder when replacing uses of the memref argument to prevent
196 // circular replacements.
197 auto remapping = signatureConversion.getInputMapping(idx);
198 assert(remapping && remapping->size == 1 &&
199 "Type converter should produce 1-to-1 mapping for bare memrefs");
200 BlockArgument newArg =
201 llvmFuncOp.getBody().getArgument(remapping->inputNo);
202 auto placeholder = rewriter.create<LLVM::UndefOp>(
203 loc, getTypeConverter()->convertType(memrefTy));
204 rewriter.replaceUsesOfBlockArgument(newArg, placeholder);
205 Value desc = MemRefDescriptor::fromStaticShape(
206 rewriter, loc, *getTypeConverter(), memrefTy, newArg);
207 rewriter.replaceOp(placeholder, {desc});
208 }
209 }
210
211 // Get memref type from function arguments and set the noalias to
212 // pointer arguments.
213 for (const auto [idx, argTy] :
214 llvm::enumerate(gpuFuncOp.getArgumentTypes())) {
215 auto remapping = signatureConversion.getInputMapping(idx);
216 NamedAttrList argAttr =
217 argAttrs ? cast<DictionaryAttr>(argAttrs[idx]) : NamedAttrList();
218 auto copyAttribute = [&](StringRef attrName) {
219 Attribute attr = argAttr.erase(attrName);
220 if (!attr)
221 return;
222 for (size_t i = 0, e = remapping->size; i < e; ++i)
223 llvmFuncOp.setArgAttr(remapping->inputNo + i, attrName, attr);
224 };
225 auto copyPointerAttribute = [&](StringRef attrName) {
226 Attribute attr = argAttr.erase(attrName);
227
228 if (!attr)
229 return;
230 if (remapping->size > 1 &&
231 attrName == LLVM::LLVMDialect::getNoAliasAttrName()) {
232 emitWarning(llvmFuncOp.getLoc(),
233 "Cannot copy noalias with non-bare pointers.\n");
234 return;
235 }
236 for (size_t i = 0, e = remapping->size; i < e; ++i) {
237 if (isa<LLVM::LLVMPointerType>(
238 llvmFuncOp.getArgument(remapping->inputNo + i).getType())) {
239 llvmFuncOp.setArgAttr(remapping->inputNo + i, attrName, attr);
240 }
241 }
242 };
243
244 if (argAttr.empty())
245 continue;
246
247 copyAttribute(LLVM::LLVMDialect::getReturnedAttrName());
248 copyAttribute(LLVM::LLVMDialect::getNoUndefAttrName());
249 copyAttribute(LLVM::LLVMDialect::getInRegAttrName());
250 bool lowersToPointer = false;
251 for (size_t i = 0, e = remapping->size; i < e; ++i) {
252 lowersToPointer |= isa<LLVM::LLVMPointerType>(
253 llvmFuncOp.getArgument(remapping->inputNo + i).getType());
254 }
255
256 if (lowersToPointer) {
257 copyPointerAttribute(LLVM::LLVMDialect::getNoAliasAttrName());
258 copyPointerAttribute(LLVM::LLVMDialect::getNoCaptureAttrName());
259 copyPointerAttribute(LLVM::LLVMDialect::getNoFreeAttrName());
260 copyPointerAttribute(LLVM::LLVMDialect::getAlignAttrName());
261 copyPointerAttribute(LLVM::LLVMDialect::getReadonlyAttrName());
262 copyPointerAttribute(LLVM::LLVMDialect::getWriteOnlyAttrName());
263 copyPointerAttribute(LLVM::LLVMDialect::getReadnoneAttrName());
264 copyPointerAttribute(LLVM::LLVMDialect::getNonNullAttrName());
265 copyPointerAttribute(LLVM::LLVMDialect::getDereferenceableAttrName());
266 copyPointerAttribute(
267 LLVM::LLVMDialect::getDereferenceableOrNullAttrName());
268 }
269 }
270 rewriter.eraseOp(op: gpuFuncOp);
271 return success();
272}
273
274static SmallString<16> getUniqueFormatGlobalName(gpu::GPUModuleOp moduleOp) {
275 const char formatStringPrefix[] = "printfFormat_";
276 // Get a unique global name.
277 unsigned stringNumber = 0;
278 SmallString<16> stringConstName;
279 do {
280 stringConstName.clear();
281 (formatStringPrefix + Twine(stringNumber++)).toStringRef(Out&: stringConstName);
282 } while (moduleOp.lookupSymbol(stringConstName));
283 return stringConstName;
284}
285
286template <typename T>
287static LLVM::LLVMFuncOp getOrDefineFunction(T &moduleOp, const Location loc,
288 ConversionPatternRewriter &rewriter,
289 StringRef name,
290 LLVM::LLVMFunctionType type) {
291 LLVM::LLVMFuncOp ret;
292 if (!(ret = moduleOp.template lookupSymbol<LLVM::LLVMFuncOp>(name))) {
293 ConversionPatternRewriter::InsertionGuard guard(rewriter);
294 rewriter.setInsertionPointToStart(moduleOp.getBody());
295 ret = rewriter.create<LLVM::LLVMFuncOp>(loc, name, type,
296 LLVM::Linkage::External);
297 }
298 return ret;
299}
300
301LogicalResult GPUPrintfOpToHIPLowering::matchAndRewrite(
302 gpu::PrintfOp gpuPrintfOp, gpu::PrintfOpAdaptor adaptor,
303 ConversionPatternRewriter &rewriter) const {
304 Location loc = gpuPrintfOp->getLoc();
305
306 mlir::Type llvmI8 = typeConverter->convertType(rewriter.getI8Type());
307 auto ptrType = LLVM::LLVMPointerType::get(rewriter.getContext());
308 mlir::Type llvmI32 = typeConverter->convertType(rewriter.getI32Type());
309 mlir::Type llvmI64 = typeConverter->convertType(rewriter.getI64Type());
310 // Note: this is the GPUModule op, not the ModuleOp that surrounds it
311 // This ensures that global constants and declarations are placed within
312 // the device code, not the host code
313 auto moduleOp = gpuPrintfOp->getParentOfType<gpu::GPUModuleOp>();
314
315 auto ocklBegin =
316 getOrDefineFunction(moduleOp, loc, rewriter, "__ockl_printf_begin",
317 LLVM::LLVMFunctionType::get(llvmI64, {llvmI64}));
318 LLVM::LLVMFuncOp ocklAppendArgs;
319 if (!adaptor.getArgs().empty()) {
320 ocklAppendArgs = getOrDefineFunction(
321 moduleOp, loc, rewriter, "__ockl_printf_append_args",
322 LLVM::LLVMFunctionType::get(
323 llvmI64, {llvmI64, /*numArgs*/ llvmI32, llvmI64, llvmI64, llvmI64,
324 llvmI64, llvmI64, llvmI64, llvmI64, /*isLast*/ llvmI32}));
325 }
326 auto ocklAppendStringN = getOrDefineFunction(
327 moduleOp, loc, rewriter, "__ockl_printf_append_string_n",
328 LLVM::LLVMFunctionType::get(
329 llvmI64,
330 {llvmI64, ptrType, /*length (bytes)*/ llvmI64, /*isLast*/ llvmI32}));
331
332 /// Start the printf hostcall
333 Value zeroI64 = rewriter.create<LLVM::ConstantOp>(loc, llvmI64, 0);
334 auto printfBeginCall = rewriter.create<LLVM::CallOp>(loc, ocklBegin, zeroI64);
335 Value printfDesc = printfBeginCall.getResult();
336
337 // Get a unique global name for the format.
338 SmallString<16> stringConstName = getUniqueFormatGlobalName(moduleOp);
339
340 llvm::SmallString<20> formatString(adaptor.getFormat());
341 formatString.push_back(Elt: '\0'); // Null terminate for C
342 size_t formatStringSize = formatString.size_in_bytes();
343
344 auto globalType = LLVM::LLVMArrayType::get(llvmI8, formatStringSize);
345 LLVM::GlobalOp global;
346 {
347 ConversionPatternRewriter::InsertionGuard guard(rewriter);
348 rewriter.setInsertionPointToStart(moduleOp.getBody());
349 global = rewriter.create<LLVM::GlobalOp>(
350 loc, globalType,
351 /*isConstant=*/true, LLVM::Linkage::Internal, stringConstName,
352 rewriter.getStringAttr(formatString));
353 }
354
355 // Get a pointer to the format string's first element and pass it to printf()
356 Value globalPtr = rewriter.create<LLVM::AddressOfOp>(
357 loc,
358 LLVM::LLVMPointerType::get(rewriter.getContext(), global.getAddrSpace()),
359 global.getSymNameAttr());
360 Value stringStart = rewriter.create<LLVM::GEPOp>(
361 loc, ptrType, globalType, globalPtr, ArrayRef<LLVM::GEPArg>{0, 0});
362 Value stringLen =
363 rewriter.create<LLVM::ConstantOp>(loc, llvmI64, formatStringSize);
364
365 Value oneI32 = rewriter.create<LLVM::ConstantOp>(loc, llvmI32, 1);
366 Value zeroI32 = rewriter.create<LLVM::ConstantOp>(loc, llvmI32, 0);
367
368 auto appendFormatCall = rewriter.create<LLVM::CallOp>(
369 loc, ocklAppendStringN,
370 ValueRange{printfDesc, stringStart, stringLen,
371 adaptor.getArgs().empty() ? oneI32 : zeroI32});
372 printfDesc = appendFormatCall.getResult();
373
374 // __ockl_printf_append_args takes 7 values per append call
375 constexpr size_t argsPerAppend = 7;
376 size_t nArgs = adaptor.getArgs().size();
377 for (size_t group = 0; group < nArgs; group += argsPerAppend) {
378 size_t bound = std::min(a: group + argsPerAppend, b: nArgs);
379 size_t numArgsThisCall = bound - group;
380
381 SmallVector<mlir::Value, 2 + argsPerAppend + 1> arguments;
382 arguments.push_back(Elt: printfDesc);
383 arguments.push_back(
384 rewriter.create<LLVM::ConstantOp>(loc, llvmI32, numArgsThisCall));
385 for (size_t i = group; i < bound; ++i) {
386 Value arg = adaptor.getArgs()[i];
387 if (auto floatType = dyn_cast<FloatType>(arg.getType())) {
388 if (!floatType.isF64())
389 arg = rewriter.create<LLVM::FPExtOp>(
390 loc, typeConverter->convertType(rewriter.getF64Type()), arg);
391 arg = rewriter.create<LLVM::BitcastOp>(loc, llvmI64, arg);
392 }
393 if (arg.getType().getIntOrFloatBitWidth() != 64)
394 arg = rewriter.create<LLVM::ZExtOp>(loc, llvmI64, arg);
395
396 arguments.push_back(Elt: arg);
397 }
398 // Pad out to 7 arguments since the hostcall always needs 7
399 for (size_t extra = numArgsThisCall; extra < argsPerAppend; ++extra) {
400 arguments.push_back(Elt: zeroI64);
401 }
402
403 auto isLast = (bound == nArgs) ? oneI32 : zeroI32;
404 arguments.push_back(Elt: isLast);
405 auto call = rewriter.create<LLVM::CallOp>(loc, ocklAppendArgs, arguments);
406 printfDesc = call.getResult();
407 }
408 rewriter.eraseOp(op: gpuPrintfOp);
409 return success();
410}
411
412LogicalResult GPUPrintfOpToLLVMCallLowering::matchAndRewrite(
413 gpu::PrintfOp gpuPrintfOp, gpu::PrintfOpAdaptor adaptor,
414 ConversionPatternRewriter &rewriter) const {
415 Location loc = gpuPrintfOp->getLoc();
416
417 mlir::Type llvmI8 = typeConverter->convertType(rewriter.getIntegerType(8));
418 mlir::Type ptrType =
419 LLVM::LLVMPointerType::get(rewriter.getContext(), addressSpace);
420
421 // Note: this is the GPUModule op, not the ModuleOp that surrounds it
422 // This ensures that global constants and declarations are placed within
423 // the device code, not the host code
424 auto moduleOp = gpuPrintfOp->getParentOfType<gpu::GPUModuleOp>();
425
426 auto printfType =
427 LLVM::LLVMFunctionType::get(rewriter.getI32Type(), {ptrType},
428 /*isVarArg=*/true);
429 LLVM::LLVMFuncOp printfDecl =
430 getOrDefineFunction(moduleOp, loc, rewriter, "printf", printfType);
431
432 // Get a unique global name for the format.
433 SmallString<16> stringConstName = getUniqueFormatGlobalName(moduleOp);
434
435 llvm::SmallString<20> formatString(adaptor.getFormat());
436 formatString.push_back(Elt: '\0'); // Null terminate for C
437 auto globalType =
438 LLVM::LLVMArrayType::get(llvmI8, formatString.size_in_bytes());
439 LLVM::GlobalOp global;
440 {
441 ConversionPatternRewriter::InsertionGuard guard(rewriter);
442 rewriter.setInsertionPointToStart(moduleOp.getBody());
443 global = rewriter.create<LLVM::GlobalOp>(
444 loc, globalType,
445 /*isConstant=*/true, LLVM::Linkage::Internal, stringConstName,
446 rewriter.getStringAttr(formatString), /*allignment=*/0, addressSpace);
447 }
448
449 // Get a pointer to the format string's first element
450 Value globalPtr = rewriter.create<LLVM::AddressOfOp>(
451 loc,
452 LLVM::LLVMPointerType::get(rewriter.getContext(), global.getAddrSpace()),
453 global.getSymNameAttr());
454 Value stringStart = rewriter.create<LLVM::GEPOp>(
455 loc, ptrType, globalType, globalPtr, ArrayRef<LLVM::GEPArg>{0, 0});
456
457 // Construct arguments and function call
458 auto argsRange = adaptor.getArgs();
459 SmallVector<Value, 4> printfArgs;
460 printfArgs.reserve(N: argsRange.size() + 1);
461 printfArgs.push_back(Elt: stringStart);
462 printfArgs.append(argsRange.begin(), argsRange.end());
463
464 rewriter.create<LLVM::CallOp>(loc, printfDecl, printfArgs);
465 rewriter.eraseOp(op: gpuPrintfOp);
466 return success();
467}
468
469LogicalResult GPUPrintfOpToVPrintfLowering::matchAndRewrite(
470 gpu::PrintfOp gpuPrintfOp, gpu::PrintfOpAdaptor adaptor,
471 ConversionPatternRewriter &rewriter) const {
472 Location loc = gpuPrintfOp->getLoc();
473
474 mlir::Type llvmI8 = typeConverter->convertType(rewriter.getIntegerType(8));
475 mlir::Type ptrType = LLVM::LLVMPointerType::get(rewriter.getContext());
476
477 // Note: this is the GPUModule op, not the ModuleOp that surrounds it
478 // This ensures that global constants and declarations are placed within
479 // the device code, not the host code
480 auto moduleOp = gpuPrintfOp->getParentOfType<gpu::GPUModuleOp>();
481
482 auto vprintfType =
483 LLVM::LLVMFunctionType::get(rewriter.getI32Type(), {ptrType, ptrType});
484 LLVM::LLVMFuncOp vprintfDecl =
485 getOrDefineFunction(moduleOp, loc, rewriter, "vprintf", vprintfType);
486
487 // Get a unique global name for the format.
488 SmallString<16> stringConstName = getUniqueFormatGlobalName(moduleOp);
489
490 llvm::SmallString<20> formatString(adaptor.getFormat());
491 formatString.push_back(Elt: '\0'); // Null terminate for C
492 auto globalType =
493 LLVM::LLVMArrayType::get(llvmI8, formatString.size_in_bytes());
494 LLVM::GlobalOp global;
495 {
496 ConversionPatternRewriter::InsertionGuard guard(rewriter);
497 rewriter.setInsertionPointToStart(moduleOp.getBody());
498 global = rewriter.create<LLVM::GlobalOp>(
499 loc, globalType,
500 /*isConstant=*/true, LLVM::Linkage::Internal, stringConstName,
501 rewriter.getStringAttr(formatString), /*allignment=*/0);
502 }
503
504 // Get a pointer to the format string's first element
505 Value globalPtr = rewriter.create<LLVM::AddressOfOp>(loc, global);
506 Value stringStart = rewriter.create<LLVM::GEPOp>(
507 loc, ptrType, globalType, globalPtr, ArrayRef<LLVM::GEPArg>{0, 0});
508 SmallVector<Type> types;
509 SmallVector<Value> args;
510 // Promote and pack the arguments into a stack allocation.
511 for (Value arg : adaptor.getArgs()) {
512 Type type = arg.getType();
513 Value promotedArg = arg;
514 assert(type.isIntOrFloat());
515 if (isa<FloatType>(type)) {
516 type = rewriter.getF64Type();
517 promotedArg = rewriter.create<LLVM::FPExtOp>(loc, type, arg);
518 }
519 types.push_back(type);
520 args.push_back(promotedArg);
521 }
522 Type structType =
523 LLVM::LLVMStructType::getLiteral(context: gpuPrintfOp.getContext(), types);
524 Value one = rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI64Type(),
525 rewriter.getIndexAttr(1));
526 Value tempAlloc =
527 rewriter.create<LLVM::AllocaOp>(loc, ptrType, structType, one,
528 /*alignment=*/0);
529 for (auto [index, arg] : llvm::enumerate(First&: args)) {
530 Value ptr = rewriter.create<LLVM::GEPOp>(
531 loc, ptrType, structType, tempAlloc,
532 ArrayRef<LLVM::GEPArg>{0, static_cast<int32_t>(index)});
533 rewriter.create<LLVM::StoreOp>(loc, arg, ptr);
534 }
535 std::array<Value, 2> printfArgs = {stringStart, tempAlloc};
536
537 rewriter.create<LLVM::CallOp>(loc, vprintfDecl, printfArgs);
538 rewriter.eraseOp(op: gpuPrintfOp);
539 return success();
540}
541
542/// Unrolls op if it's operating on vectors.
543LogicalResult impl::scalarizeVectorOp(Operation *op, ValueRange operands,
544 ConversionPatternRewriter &rewriter,
545 const LLVMTypeConverter &converter) {
546 TypeRange operandTypes(operands);
547 if (llvm::none_of(Range&: operandTypes, P: llvm::IsaPred<VectorType>)) {
548 return rewriter.notifyMatchFailure(arg&: op, msg: "expected vector operand");
549 }
550 if (op->getNumRegions() != 0 || op->getNumSuccessors() != 0)
551 return rewriter.notifyMatchFailure(arg&: op, msg: "expected no region/successor");
552 if (op->getNumResults() != 1)
553 return rewriter.notifyMatchFailure(arg&: op, msg: "expected single result");
554 VectorType vectorType = dyn_cast<VectorType>(op->getResult(0).getType());
555 if (!vectorType)
556 return rewriter.notifyMatchFailure(arg&: op, msg: "expected vector result");
557
558 Location loc = op->getLoc();
559 Value result = rewriter.create<LLVM::UndefOp>(loc, vectorType);
560 Type indexType = converter.convertType(rewriter.getIndexType());
561 StringAttr name = op->getName().getIdentifier();
562 Type elementType = vectorType.getElementType();
563
564 for (int64_t i = 0; i < vectorType.getNumElements(); ++i) {
565 Value index = rewriter.create<LLVM::ConstantOp>(loc, indexType, i);
566 auto extractElement = [&](Value operand) -> Value {
567 if (!isa<VectorType>(Val: operand.getType()))
568 return operand;
569 return rewriter.create<LLVM::ExtractElementOp>(loc, operand, index);
570 };
571 auto scalarOperands = llvm::map_to_vector(C&: operands, F&: extractElement);
572 Operation *scalarOp =
573 rewriter.create(loc, name, scalarOperands, elementType, op->getAttrs());
574 result = rewriter.create<LLVM::InsertElementOp>(
575 loc, result, scalarOp->getResult(0), index);
576 }
577
578 rewriter.replaceOp(op, newValues: result);
579 return success();
580}
581
582static IntegerAttr wrapNumericMemorySpace(MLIRContext *ctx, unsigned space) {
583 return IntegerAttr::get(IntegerType::get(ctx, 64), space);
584}
585
586/// Generates a symbol with 0-sized array type for dynamic shared memory usage,
587/// or uses existing symbol.
588LLVM::GlobalOp
589getDynamicSharedMemorySymbol(ConversionPatternRewriter &rewriter,
590 Operation *moduleOp, gpu::DynamicSharedMemoryOp op,
591 const LLVMTypeConverter *typeConverter,
592 MemRefType memrefType, unsigned alignmentBit) {
593 uint64_t alignmentByte = alignmentBit / memrefType.getElementTypeBitWidth();
594
595 FailureOr<unsigned> addressSpace =
596 typeConverter->getMemRefAddressSpace(type: memrefType);
597 if (failed(result: addressSpace)) {
598 op->emitError() << "conversion of memref memory space "
599 << memrefType.getMemorySpace()
600 << " to integer address space "
601 "failed. Consider adding memory space conversions.";
602 }
603
604 // Step 1. Collect symbol names of LLVM::GlobalOp Ops. Also if any of
605 // LLVM::GlobalOp is suitable for shared memory, return it.
606 llvm::StringSet<> existingGlobalNames;
607 for (auto globalOp :
608 moduleOp->getRegion(0).front().getOps<LLVM::GlobalOp>()) {
609 existingGlobalNames.insert(globalOp.getSymName());
610 if (auto arrayType = dyn_cast<LLVM::LLVMArrayType>(globalOp.getType())) {
611 if (globalOp.getAddrSpace() == addressSpace.value() &&
612 arrayType.getNumElements() == 0 &&
613 globalOp.getAlignment().value_or(0) == alignmentByte) {
614 return globalOp;
615 }
616 }
617 }
618
619 // Step 2. Find a unique symbol name
620 unsigned uniquingCounter = 0;
621 SmallString<128> symName = SymbolTable::generateSymbolName<128>(
622 name: "__dynamic_shmem_",
623 uniqueChecker: [&](StringRef candidate) {
624 return existingGlobalNames.contains(key: candidate);
625 },
626 uniquingCounter);
627
628 // Step 3. Generate a global op
629 OpBuilder::InsertionGuard guard(rewriter);
630 rewriter.setInsertionPoint(&moduleOp->getRegion(index: 0).front().front());
631
632 auto zeroSizedArrayType = LLVM::LLVMArrayType::get(
633 typeConverter->convertType(memrefType.getElementType()), 0);
634
635 return rewriter.create<LLVM::GlobalOp>(
636 op->getLoc(), zeroSizedArrayType, /*isConstant=*/false,
637 LLVM::Linkage::Internal, symName, /*value=*/Attribute(), alignmentByte,
638 addressSpace.value());
639}
640
641LogicalResult GPUDynamicSharedMemoryOpLowering::matchAndRewrite(
642 gpu::DynamicSharedMemoryOp op, OpAdaptor adaptor,
643 ConversionPatternRewriter &rewriter) const {
644 Location loc = op.getLoc();
645 MemRefType memrefType = op.getResultMemref().getType();
646 Type elementType = typeConverter->convertType(memrefType.getElementType());
647
648 // Step 1: Generate a memref<0xi8> type
649 MemRefLayoutAttrInterface layout = {};
650 auto memrefType0sz =
651 MemRefType::get({0}, elementType, layout, memrefType.getMemorySpace());
652
653 // Step 2: Generate a global symbol or existing for the dynamic shared
654 // memory with memref<0xi8> type
655 LLVM::LLVMFuncOp funcOp = op->getParentOfType<LLVM::LLVMFuncOp>();
656 LLVM::GlobalOp shmemOp = {};
657 Operation *moduleOp = funcOp->getParentWithTrait<OpTrait::SymbolTable>();
658 shmemOp = getDynamicSharedMemorySymbol(
659 rewriter, moduleOp, op, getTypeConverter(), memrefType0sz, alignmentBit);
660
661 // Step 3. Get address of the global symbol
662 OpBuilder::InsertionGuard guard(rewriter);
663 rewriter.setInsertionPoint(op);
664 auto basePtr = rewriter.create<LLVM::AddressOfOp>(loc, shmemOp);
665 Type baseType = basePtr->getResultTypes().front();
666
667 // Step 4. Generate GEP using offsets
668 SmallVector<LLVM::GEPArg> gepArgs = {0};
669 Value shmemPtr = rewriter.create<LLVM::GEPOp>(loc, baseType, elementType,
670 basePtr, gepArgs);
671 // Step 5. Create a memref descriptor
672 SmallVector<Value> shape, strides;
673 Value sizeBytes;
674 getMemRefDescriptorSizes(loc, memrefType0sz, {}, rewriter, shape, strides,
675 sizeBytes);
676 auto memRefDescriptor = this->createMemRefDescriptor(
677 loc, memrefType0sz, shmemPtr, shmemPtr, shape, strides, rewriter);
678
679 // Step 5. Replace the op with memref descriptor
680 rewriter.replaceOp(op, {memRefDescriptor});
681 return success();
682}
683
684void mlir::populateGpuMemorySpaceAttributeConversions(
685 TypeConverter &typeConverter, const MemorySpaceMapping &mapping) {
686 typeConverter.addTypeAttributeConversion(
687 callback: [mapping](BaseMemRefType type, gpu::AddressSpaceAttr memorySpaceAttr) {
688 gpu::AddressSpace memorySpace = memorySpaceAttr.getValue();
689 unsigned addressSpace = mapping(memorySpace);
690 return wrapNumericMemorySpace(memorySpaceAttr.getContext(),
691 addressSpace);
692 });
693}
694

// Source: mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp