1 | //===- GPUOpsLowering.cpp - GPU FuncOp / ReturnOp lowering ----------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "GPUOpsLowering.h" |
10 | |
11 | #include "mlir/Conversion/GPUCommon/GPUCommonPass.h" |
12 | #include "mlir/Dialect/LLVMIR/LLVMDialect.h" |
13 | #include "mlir/IR/Attributes.h" |
14 | #include "mlir/IR/Builders.h" |
15 | #include "mlir/IR/BuiltinTypes.h" |
16 | #include "llvm/ADT/SmallVectorExtras.h" |
17 | #include "llvm/ADT/StringSet.h" |
18 | #include "llvm/Support/FormatVariadic.h" |
19 | |
20 | using namespace mlir; |
21 | |
22 | LogicalResult |
23 | GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor, |
24 | ConversionPatternRewriter &rewriter) const { |
25 | Location loc = gpuFuncOp.getLoc(); |
26 | |
27 | SmallVector<LLVM::GlobalOp, 3> workgroupBuffers; |
28 | workgroupBuffers.reserve(gpuFuncOp.getNumWorkgroupAttributions()); |
29 | for (const auto [idx, attribution] : |
30 | llvm::enumerate(gpuFuncOp.getWorkgroupAttributions())) { |
31 | auto type = dyn_cast<MemRefType>(attribution.getType()); |
32 | assert(type && type.hasStaticShape() && "unexpected type in attribution" ); |
33 | |
34 | uint64_t numElements = type.getNumElements(); |
35 | |
36 | auto elementType = |
37 | cast<Type>(typeConverter->convertType(type.getElementType())); |
38 | auto arrayType = LLVM::LLVMArrayType::get(elementType, numElements); |
39 | std::string name = |
40 | std::string(llvm::formatv("__wg_{0}_{1}" , gpuFuncOp.getName(), idx)); |
41 | uint64_t alignment = 0; |
42 | if (auto alignAttr = |
43 | dyn_cast_or_null<IntegerAttr>(gpuFuncOp.getWorkgroupAttributionAttr( |
44 | idx, LLVM::LLVMDialect::getAlignAttrName()))) |
45 | alignment = alignAttr.getInt(); |
46 | auto globalOp = rewriter.create<LLVM::GlobalOp>( |
47 | gpuFuncOp.getLoc(), arrayType, /*isConstant=*/false, |
48 | LLVM::Linkage::Internal, name, /*value=*/Attribute(), alignment, |
49 | workgroupAddrSpace); |
50 | workgroupBuffers.push_back(globalOp); |
51 | } |
52 | |
53 | // Remap proper input types. |
54 | TypeConverter::SignatureConversion signatureConversion( |
55 | gpuFuncOp.front().getNumArguments()); |
56 | |
57 | Type funcType = getTypeConverter()->convertFunctionSignature( |
58 | gpuFuncOp.getFunctionType(), /*isVariadic=*/false, |
59 | getTypeConverter()->getOptions().useBarePtrCallConv, signatureConversion); |
60 | if (!funcType) { |
61 | return rewriter.notifyMatchFailure(gpuFuncOp, [&](Diagnostic &diag) { |
62 | diag << "failed to convert function signature type for: " |
63 | << gpuFuncOp.getFunctionType(); |
64 | }); |
65 | } |
66 | |
67 | // Create the new function operation. Only copy those attributes that are |
68 | // not specific to function modeling. |
69 | SmallVector<NamedAttribute, 4> attributes; |
70 | ArrayAttr argAttrs; |
71 | for (const auto &attr : gpuFuncOp->getAttrs()) { |
72 | if (attr.getName() == SymbolTable::getSymbolAttrName() || |
73 | attr.getName() == gpuFuncOp.getFunctionTypeAttrName() || |
74 | attr.getName() == |
75 | gpu::GPUFuncOp::getNumWorkgroupAttributionsAttrName() || |
76 | attr.getName() == gpuFuncOp.getWorkgroupAttribAttrsAttrName() || |
77 | attr.getName() == gpuFuncOp.getPrivateAttribAttrsAttrName()) |
78 | continue; |
79 | if (attr.getName() == gpuFuncOp.getArgAttrsAttrName()) { |
80 | argAttrs = gpuFuncOp.getArgAttrsAttr(); |
81 | continue; |
82 | } |
83 | attributes.push_back(attr); |
84 | } |
85 | // Add a dialect specific kernel attribute in addition to GPU kernel |
86 | // attribute. The former is necessary for further translation while the |
87 | // latter is expected by gpu.launch_func. |
88 | if (gpuFuncOp.isKernel()) { |
89 | attributes.emplace_back(kernelAttributeName, rewriter.getUnitAttr()); |
90 | |
91 | // Set the block size attribute if it is present. |
92 | if (kernelBlockSizeAttributeName.has_value()) { |
93 | std::optional<int32_t> dimX = |
94 | gpuFuncOp.getKnownBlockSize(gpu::Dimension::x); |
95 | std::optional<int32_t> dimY = |
96 | gpuFuncOp.getKnownBlockSize(gpu::Dimension::y); |
97 | std::optional<int32_t> dimZ = |
98 | gpuFuncOp.getKnownBlockSize(gpu::Dimension::z); |
99 | if (dimX.has_value() || dimY.has_value() || dimZ.has_value()) { |
100 | // If any of the dimensions are missing, fill them in with 1. |
101 | attributes.emplace_back( |
102 | kernelBlockSizeAttributeName.value(), |
103 | rewriter.getDenseI32ArrayAttr( |
104 | {dimX.value_or(1), dimY.value_or(1), dimZ.value_or(1)})); |
105 | } |
106 | } |
107 | } |
108 | auto llvmFuncOp = rewriter.create<LLVM::LLVMFuncOp>( |
109 | gpuFuncOp.getLoc(), gpuFuncOp.getName(), funcType, |
110 | LLVM::Linkage::External, /*dsoLocal=*/false, /*cconv=*/LLVM::CConv::C, |
111 | /*comdat=*/nullptr, attributes); |
112 | |
113 | { |
114 | // Insert operations that correspond to converted workgroup and private |
115 | // memory attributions to the body of the function. This must operate on |
116 | // the original function, before the body region is inlined in the new |
117 | // function to maintain the relation between block arguments and the |
118 | // parent operation that assigns their semantics. |
119 | OpBuilder::InsertionGuard guard(rewriter); |
120 | |
121 | // Rewrite workgroup memory attributions to addresses of global buffers. |
122 | rewriter.setInsertionPointToStart(&gpuFuncOp.front()); |
123 | unsigned numProperArguments = gpuFuncOp.getNumArguments(); |
124 | |
125 | for (const auto [idx, global] : llvm::enumerate(workgroupBuffers)) { |
126 | auto ptrType = LLVM::LLVMPointerType::get(rewriter.getContext(), |
127 | global.getAddrSpace()); |
128 | Value address = rewriter.create<LLVM::AddressOfOp>( |
129 | loc, ptrType, global.getSymNameAttr()); |
130 | Value memory = |
131 | rewriter.create<LLVM::GEPOp>(loc, ptrType, global.getType(), address, |
132 | ArrayRef<LLVM::GEPArg>{0, 0}); |
133 | |
134 | // Build a memref descriptor pointing to the buffer to plug with the |
135 | // existing memref infrastructure. This may use more registers than |
136 | // otherwise necessary given that memref sizes are fixed, but we can try |
137 | // and canonicalize that away later. |
138 | Value attribution = gpuFuncOp.getWorkgroupAttributions()[idx]; |
139 | auto type = cast<MemRefType>(attribution.getType()); |
140 | auto descr = MemRefDescriptor::fromStaticShape( |
141 | rewriter, loc, *getTypeConverter(), type, memory); |
142 | signatureConversion.remapInput(numProperArguments + idx, descr); |
143 | } |
144 | |
145 | // Rewrite private memory attributions to alloca'ed buffers. |
146 | unsigned numWorkgroupAttributions = gpuFuncOp.getNumWorkgroupAttributions(); |
147 | auto int64Ty = IntegerType::get(rewriter.getContext(), 64); |
148 | for (const auto [idx, attribution] : |
149 | llvm::enumerate(gpuFuncOp.getPrivateAttributions())) { |
150 | auto type = cast<MemRefType>(attribution.getType()); |
151 | assert(type && type.hasStaticShape() && "unexpected type in attribution" ); |
152 | |
153 | // Explicitly drop memory space when lowering private memory |
154 | // attributions since NVVM models it as `alloca`s in the default |
155 | // memory space and does not support `alloca`s with addrspace(5). |
156 | Type elementType = typeConverter->convertType(type.getElementType()); |
157 | auto ptrType = |
158 | LLVM::LLVMPointerType::get(rewriter.getContext(), allocaAddrSpace); |
159 | Value numElements = rewriter.create<LLVM::ConstantOp>( |
160 | gpuFuncOp.getLoc(), int64Ty, type.getNumElements()); |
161 | uint64_t alignment = 0; |
162 | if (auto alignAttr = |
163 | dyn_cast_or_null<IntegerAttr>(gpuFuncOp.getPrivateAttributionAttr( |
164 | idx, LLVM::LLVMDialect::getAlignAttrName()))) |
165 | alignment = alignAttr.getInt(); |
166 | Value allocated = rewriter.create<LLVM::AllocaOp>( |
167 | gpuFuncOp.getLoc(), ptrType, elementType, numElements, alignment); |
168 | auto descr = MemRefDescriptor::fromStaticShape( |
169 | rewriter, loc, *getTypeConverter(), type, allocated); |
170 | signatureConversion.remapInput( |
171 | numProperArguments + numWorkgroupAttributions + idx, descr); |
172 | } |
173 | } |
174 | |
175 | // Move the region to the new function, update the entry block signature. |
176 | rewriter.inlineRegionBefore(gpuFuncOp.getBody(), llvmFuncOp.getBody(), |
177 | llvmFuncOp.end()); |
178 | if (failed(rewriter.convertRegionTypes(&llvmFuncOp.getBody(), *typeConverter, |
179 | &signatureConversion))) |
180 | return failure(); |
181 | |
182 | // If bare memref pointers are being used, remap them back to memref |
183 | // descriptors This must be done after signature conversion to get rid of the |
184 | // unrealized casts. |
185 | if (getTypeConverter()->getOptions().useBarePtrCallConv) { |
186 | OpBuilder::InsertionGuard guard(rewriter); |
187 | rewriter.setInsertionPointToStart(&llvmFuncOp.getBody().front()); |
188 | for (const auto [idx, argTy] : |
189 | llvm::enumerate(gpuFuncOp.getArgumentTypes())) { |
190 | auto memrefTy = dyn_cast<MemRefType>(argTy); |
191 | if (!memrefTy) |
192 | continue; |
193 | assert(memrefTy.hasStaticShape() && |
194 | "Bare pointer convertion used with dynamically-shaped memrefs" ); |
195 | // Use a placeholder when replacing uses of the memref argument to prevent |
196 | // circular replacements. |
197 | auto remapping = signatureConversion.getInputMapping(idx); |
198 | assert(remapping && remapping->size == 1 && |
199 | "Type converter should produce 1-to-1 mapping for bare memrefs" ); |
200 | BlockArgument newArg = |
201 | llvmFuncOp.getBody().getArgument(remapping->inputNo); |
202 | auto placeholder = rewriter.create<LLVM::UndefOp>( |
203 | loc, getTypeConverter()->convertType(memrefTy)); |
204 | rewriter.replaceUsesOfBlockArgument(newArg, placeholder); |
205 | Value desc = MemRefDescriptor::fromStaticShape( |
206 | rewriter, loc, *getTypeConverter(), memrefTy, newArg); |
207 | rewriter.replaceOp(placeholder, {desc}); |
208 | } |
209 | } |
210 | |
211 | // Get memref type from function arguments and set the noalias to |
212 | // pointer arguments. |
213 | for (const auto [idx, argTy] : |
214 | llvm::enumerate(gpuFuncOp.getArgumentTypes())) { |
215 | auto remapping = signatureConversion.getInputMapping(idx); |
216 | NamedAttrList argAttr = |
217 | argAttrs ? cast<DictionaryAttr>(argAttrs[idx]) : NamedAttrList(); |
218 | auto copyAttribute = [&](StringRef attrName) { |
219 | Attribute attr = argAttr.erase(attrName); |
220 | if (!attr) |
221 | return; |
222 | for (size_t i = 0, e = remapping->size; i < e; ++i) |
223 | llvmFuncOp.setArgAttr(remapping->inputNo + i, attrName, attr); |
224 | }; |
225 | auto copyPointerAttribute = [&](StringRef attrName) { |
226 | Attribute attr = argAttr.erase(attrName); |
227 | |
228 | if (!attr) |
229 | return; |
230 | if (remapping->size > 1 && |
231 | attrName == LLVM::LLVMDialect::getNoAliasAttrName()) { |
232 | emitWarning(llvmFuncOp.getLoc(), |
233 | "Cannot copy noalias with non-bare pointers.\n" ); |
234 | return; |
235 | } |
236 | for (size_t i = 0, e = remapping->size; i < e; ++i) { |
237 | if (isa<LLVM::LLVMPointerType>( |
238 | llvmFuncOp.getArgument(remapping->inputNo + i).getType())) { |
239 | llvmFuncOp.setArgAttr(remapping->inputNo + i, attrName, attr); |
240 | } |
241 | } |
242 | }; |
243 | |
244 | if (argAttr.empty()) |
245 | continue; |
246 | |
247 | copyAttribute(LLVM::LLVMDialect::getReturnedAttrName()); |
248 | copyAttribute(LLVM::LLVMDialect::getNoUndefAttrName()); |
249 | copyAttribute(LLVM::LLVMDialect::getInRegAttrName()); |
250 | bool lowersToPointer = false; |
251 | for (size_t i = 0, e = remapping->size; i < e; ++i) { |
252 | lowersToPointer |= isa<LLVM::LLVMPointerType>( |
253 | llvmFuncOp.getArgument(remapping->inputNo + i).getType()); |
254 | } |
255 | |
256 | if (lowersToPointer) { |
257 | copyPointerAttribute(LLVM::LLVMDialect::getNoAliasAttrName()); |
258 | copyPointerAttribute(LLVM::LLVMDialect::getNoCaptureAttrName()); |
259 | copyPointerAttribute(LLVM::LLVMDialect::getNoFreeAttrName()); |
260 | copyPointerAttribute(LLVM::LLVMDialect::getAlignAttrName()); |
261 | copyPointerAttribute(LLVM::LLVMDialect::getReadonlyAttrName()); |
262 | copyPointerAttribute(LLVM::LLVMDialect::getWriteOnlyAttrName()); |
263 | copyPointerAttribute(LLVM::LLVMDialect::getReadnoneAttrName()); |
264 | copyPointerAttribute(LLVM::LLVMDialect::getNonNullAttrName()); |
265 | copyPointerAttribute(LLVM::LLVMDialect::getDereferenceableAttrName()); |
266 | copyPointerAttribute( |
267 | LLVM::LLVMDialect::getDereferenceableOrNullAttrName()); |
268 | } |
269 | } |
270 | rewriter.eraseOp(op: gpuFuncOp); |
271 | return success(); |
272 | } |
273 | |
274 | static SmallString<16> getUniqueFormatGlobalName(gpu::GPUModuleOp moduleOp) { |
275 | const char formatStringPrefix[] = "printfFormat_" ; |
276 | // Get a unique global name. |
277 | unsigned stringNumber = 0; |
278 | SmallString<16> stringConstName; |
279 | do { |
280 | stringConstName.clear(); |
281 | (formatStringPrefix + Twine(stringNumber++)).toStringRef(Out&: stringConstName); |
282 | } while (moduleOp.lookupSymbol(stringConstName)); |
283 | return stringConstName; |
284 | } |
285 | |
286 | template <typename T> |
287 | static LLVM::LLVMFuncOp getOrDefineFunction(T &moduleOp, const Location loc, |
288 | ConversionPatternRewriter &rewriter, |
289 | StringRef name, |
290 | LLVM::LLVMFunctionType type) { |
291 | LLVM::LLVMFuncOp ret; |
292 | if (!(ret = moduleOp.template lookupSymbol<LLVM::LLVMFuncOp>(name))) { |
293 | ConversionPatternRewriter::InsertionGuard guard(rewriter); |
294 | rewriter.setInsertionPointToStart(moduleOp.getBody()); |
295 | ret = rewriter.create<LLVM::LLVMFuncOp>(loc, name, type, |
296 | LLVM::Linkage::External); |
297 | } |
298 | return ret; |
299 | } |
300 | |
301 | LogicalResult GPUPrintfOpToHIPLowering::matchAndRewrite( |
302 | gpu::PrintfOp gpuPrintfOp, gpu::PrintfOpAdaptor adaptor, |
303 | ConversionPatternRewriter &rewriter) const { |
304 | Location loc = gpuPrintfOp->getLoc(); |
305 | |
306 | mlir::Type llvmI8 = typeConverter->convertType(rewriter.getI8Type()); |
307 | auto ptrType = LLVM::LLVMPointerType::get(rewriter.getContext()); |
308 | mlir::Type llvmI32 = typeConverter->convertType(rewriter.getI32Type()); |
309 | mlir::Type llvmI64 = typeConverter->convertType(rewriter.getI64Type()); |
310 | // Note: this is the GPUModule op, not the ModuleOp that surrounds it |
311 | // This ensures that global constants and declarations are placed within |
312 | // the device code, not the host code |
313 | auto moduleOp = gpuPrintfOp->getParentOfType<gpu::GPUModuleOp>(); |
314 | |
315 | auto ocklBegin = |
316 | getOrDefineFunction(moduleOp, loc, rewriter, "__ockl_printf_begin" , |
317 | LLVM::LLVMFunctionType::get(llvmI64, {llvmI64})); |
318 | LLVM::LLVMFuncOp ocklAppendArgs; |
319 | if (!adaptor.getArgs().empty()) { |
320 | ocklAppendArgs = getOrDefineFunction( |
321 | moduleOp, loc, rewriter, "__ockl_printf_append_args" , |
322 | LLVM::LLVMFunctionType::get( |
323 | llvmI64, {llvmI64, /*numArgs*/ llvmI32, llvmI64, llvmI64, llvmI64, |
324 | llvmI64, llvmI64, llvmI64, llvmI64, /*isLast*/ llvmI32})); |
325 | } |
326 | auto ocklAppendStringN = getOrDefineFunction( |
327 | moduleOp, loc, rewriter, "__ockl_printf_append_string_n" , |
328 | LLVM::LLVMFunctionType::get( |
329 | llvmI64, |
330 | {llvmI64, ptrType, /*length (bytes)*/ llvmI64, /*isLast*/ llvmI32})); |
331 | |
332 | /// Start the printf hostcall |
333 | Value zeroI64 = rewriter.create<LLVM::ConstantOp>(loc, llvmI64, 0); |
334 | auto printfBeginCall = rewriter.create<LLVM::CallOp>(loc, ocklBegin, zeroI64); |
335 | Value printfDesc = printfBeginCall.getResult(); |
336 | |
337 | // Get a unique global name for the format. |
338 | SmallString<16> stringConstName = getUniqueFormatGlobalName(moduleOp); |
339 | |
340 | llvm::SmallString<20> formatString(adaptor.getFormat()); |
341 | formatString.push_back(Elt: '\0'); // Null terminate for C |
342 | size_t formatStringSize = formatString.size_in_bytes(); |
343 | |
344 | auto globalType = LLVM::LLVMArrayType::get(llvmI8, formatStringSize); |
345 | LLVM::GlobalOp global; |
346 | { |
347 | ConversionPatternRewriter::InsertionGuard guard(rewriter); |
348 | rewriter.setInsertionPointToStart(moduleOp.getBody()); |
349 | global = rewriter.create<LLVM::GlobalOp>( |
350 | loc, globalType, |
351 | /*isConstant=*/true, LLVM::Linkage::Internal, stringConstName, |
352 | rewriter.getStringAttr(formatString)); |
353 | } |
354 | |
355 | // Get a pointer to the format string's first element and pass it to printf() |
356 | Value globalPtr = rewriter.create<LLVM::AddressOfOp>( |
357 | loc, |
358 | LLVM::LLVMPointerType::get(rewriter.getContext(), global.getAddrSpace()), |
359 | global.getSymNameAttr()); |
360 | Value stringStart = rewriter.create<LLVM::GEPOp>( |
361 | loc, ptrType, globalType, globalPtr, ArrayRef<LLVM::GEPArg>{0, 0}); |
362 | Value stringLen = |
363 | rewriter.create<LLVM::ConstantOp>(loc, llvmI64, formatStringSize); |
364 | |
365 | Value oneI32 = rewriter.create<LLVM::ConstantOp>(loc, llvmI32, 1); |
366 | Value zeroI32 = rewriter.create<LLVM::ConstantOp>(loc, llvmI32, 0); |
367 | |
368 | auto appendFormatCall = rewriter.create<LLVM::CallOp>( |
369 | loc, ocklAppendStringN, |
370 | ValueRange{printfDesc, stringStart, stringLen, |
371 | adaptor.getArgs().empty() ? oneI32 : zeroI32}); |
372 | printfDesc = appendFormatCall.getResult(); |
373 | |
374 | // __ockl_printf_append_args takes 7 values per append call |
375 | constexpr size_t argsPerAppend = 7; |
376 | size_t nArgs = adaptor.getArgs().size(); |
377 | for (size_t group = 0; group < nArgs; group += argsPerAppend) { |
378 | size_t bound = std::min(a: group + argsPerAppend, b: nArgs); |
379 | size_t numArgsThisCall = bound - group; |
380 | |
381 | SmallVector<mlir::Value, 2 + argsPerAppend + 1> arguments; |
382 | arguments.push_back(Elt: printfDesc); |
383 | arguments.push_back( |
384 | rewriter.create<LLVM::ConstantOp>(loc, llvmI32, numArgsThisCall)); |
385 | for (size_t i = group; i < bound; ++i) { |
386 | Value arg = adaptor.getArgs()[i]; |
387 | if (auto floatType = dyn_cast<FloatType>(arg.getType())) { |
388 | if (!floatType.isF64()) |
389 | arg = rewriter.create<LLVM::FPExtOp>( |
390 | loc, typeConverter->convertType(rewriter.getF64Type()), arg); |
391 | arg = rewriter.create<LLVM::BitcastOp>(loc, llvmI64, arg); |
392 | } |
393 | if (arg.getType().getIntOrFloatBitWidth() != 64) |
394 | arg = rewriter.create<LLVM::ZExtOp>(loc, llvmI64, arg); |
395 | |
396 | arguments.push_back(Elt: arg); |
397 | } |
398 | // Pad out to 7 arguments since the hostcall always needs 7 |
399 | for (size_t = numArgsThisCall; extra < argsPerAppend; ++extra) { |
400 | arguments.push_back(Elt: zeroI64); |
401 | } |
402 | |
403 | auto isLast = (bound == nArgs) ? oneI32 : zeroI32; |
404 | arguments.push_back(Elt: isLast); |
405 | auto call = rewriter.create<LLVM::CallOp>(loc, ocklAppendArgs, arguments); |
406 | printfDesc = call.getResult(); |
407 | } |
408 | rewriter.eraseOp(op: gpuPrintfOp); |
409 | return success(); |
410 | } |
411 | |
412 | LogicalResult GPUPrintfOpToLLVMCallLowering::matchAndRewrite( |
413 | gpu::PrintfOp gpuPrintfOp, gpu::PrintfOpAdaptor adaptor, |
414 | ConversionPatternRewriter &rewriter) const { |
415 | Location loc = gpuPrintfOp->getLoc(); |
416 | |
417 | mlir::Type llvmI8 = typeConverter->convertType(rewriter.getIntegerType(8)); |
418 | mlir::Type ptrType = |
419 | LLVM::LLVMPointerType::get(rewriter.getContext(), addressSpace); |
420 | |
421 | // Note: this is the GPUModule op, not the ModuleOp that surrounds it |
422 | // This ensures that global constants and declarations are placed within |
423 | // the device code, not the host code |
424 | auto moduleOp = gpuPrintfOp->getParentOfType<gpu::GPUModuleOp>(); |
425 | |
426 | auto printfType = |
427 | LLVM::LLVMFunctionType::get(rewriter.getI32Type(), {ptrType}, |
428 | /*isVarArg=*/true); |
429 | LLVM::LLVMFuncOp printfDecl = |
430 | getOrDefineFunction(moduleOp, loc, rewriter, "printf" , printfType); |
431 | |
432 | // Get a unique global name for the format. |
433 | SmallString<16> stringConstName = getUniqueFormatGlobalName(moduleOp); |
434 | |
435 | llvm::SmallString<20> formatString(adaptor.getFormat()); |
436 | formatString.push_back(Elt: '\0'); // Null terminate for C |
437 | auto globalType = |
438 | LLVM::LLVMArrayType::get(llvmI8, formatString.size_in_bytes()); |
439 | LLVM::GlobalOp global; |
440 | { |
441 | ConversionPatternRewriter::InsertionGuard guard(rewriter); |
442 | rewriter.setInsertionPointToStart(moduleOp.getBody()); |
443 | global = rewriter.create<LLVM::GlobalOp>( |
444 | loc, globalType, |
445 | /*isConstant=*/true, LLVM::Linkage::Internal, stringConstName, |
446 | rewriter.getStringAttr(formatString), /*allignment=*/0, addressSpace); |
447 | } |
448 | |
449 | // Get a pointer to the format string's first element |
450 | Value globalPtr = rewriter.create<LLVM::AddressOfOp>( |
451 | loc, |
452 | LLVM::LLVMPointerType::get(rewriter.getContext(), global.getAddrSpace()), |
453 | global.getSymNameAttr()); |
454 | Value stringStart = rewriter.create<LLVM::GEPOp>( |
455 | loc, ptrType, globalType, globalPtr, ArrayRef<LLVM::GEPArg>{0, 0}); |
456 | |
457 | // Construct arguments and function call |
458 | auto argsRange = adaptor.getArgs(); |
459 | SmallVector<Value, 4> printfArgs; |
460 | printfArgs.reserve(N: argsRange.size() + 1); |
461 | printfArgs.push_back(Elt: stringStart); |
462 | printfArgs.append(argsRange.begin(), argsRange.end()); |
463 | |
464 | rewriter.create<LLVM::CallOp>(loc, printfDecl, printfArgs); |
465 | rewriter.eraseOp(op: gpuPrintfOp); |
466 | return success(); |
467 | } |
468 | |
469 | LogicalResult GPUPrintfOpToVPrintfLowering::matchAndRewrite( |
470 | gpu::PrintfOp gpuPrintfOp, gpu::PrintfOpAdaptor adaptor, |
471 | ConversionPatternRewriter &rewriter) const { |
472 | Location loc = gpuPrintfOp->getLoc(); |
473 | |
474 | mlir::Type llvmI8 = typeConverter->convertType(rewriter.getIntegerType(8)); |
475 | mlir::Type ptrType = LLVM::LLVMPointerType::get(rewriter.getContext()); |
476 | |
477 | // Note: this is the GPUModule op, not the ModuleOp that surrounds it |
478 | // This ensures that global constants and declarations are placed within |
479 | // the device code, not the host code |
480 | auto moduleOp = gpuPrintfOp->getParentOfType<gpu::GPUModuleOp>(); |
481 | |
482 | auto vprintfType = |
483 | LLVM::LLVMFunctionType::get(rewriter.getI32Type(), {ptrType, ptrType}); |
484 | LLVM::LLVMFuncOp vprintfDecl = |
485 | getOrDefineFunction(moduleOp, loc, rewriter, "vprintf" , vprintfType); |
486 | |
487 | // Get a unique global name for the format. |
488 | SmallString<16> stringConstName = getUniqueFormatGlobalName(moduleOp); |
489 | |
490 | llvm::SmallString<20> formatString(adaptor.getFormat()); |
491 | formatString.push_back(Elt: '\0'); // Null terminate for C |
492 | auto globalType = |
493 | LLVM::LLVMArrayType::get(llvmI8, formatString.size_in_bytes()); |
494 | LLVM::GlobalOp global; |
495 | { |
496 | ConversionPatternRewriter::InsertionGuard guard(rewriter); |
497 | rewriter.setInsertionPointToStart(moduleOp.getBody()); |
498 | global = rewriter.create<LLVM::GlobalOp>( |
499 | loc, globalType, |
500 | /*isConstant=*/true, LLVM::Linkage::Internal, stringConstName, |
501 | rewriter.getStringAttr(formatString), /*allignment=*/0); |
502 | } |
503 | |
504 | // Get a pointer to the format string's first element |
505 | Value globalPtr = rewriter.create<LLVM::AddressOfOp>(loc, global); |
506 | Value stringStart = rewriter.create<LLVM::GEPOp>( |
507 | loc, ptrType, globalType, globalPtr, ArrayRef<LLVM::GEPArg>{0, 0}); |
508 | SmallVector<Type> types; |
509 | SmallVector<Value> args; |
510 | // Promote and pack the arguments into a stack allocation. |
511 | for (Value arg : adaptor.getArgs()) { |
512 | Type type = arg.getType(); |
513 | Value promotedArg = arg; |
514 | assert(type.isIntOrFloat()); |
515 | if (isa<FloatType>(type)) { |
516 | type = rewriter.getF64Type(); |
517 | promotedArg = rewriter.create<LLVM::FPExtOp>(loc, type, arg); |
518 | } |
519 | types.push_back(type); |
520 | args.push_back(promotedArg); |
521 | } |
522 | Type structType = |
523 | LLVM::LLVMStructType::getLiteral(context: gpuPrintfOp.getContext(), types); |
524 | Value one = rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI64Type(), |
525 | rewriter.getIndexAttr(1)); |
526 | Value tempAlloc = |
527 | rewriter.create<LLVM::AllocaOp>(loc, ptrType, structType, one, |
528 | /*alignment=*/0); |
529 | for (auto [index, arg] : llvm::enumerate(First&: args)) { |
530 | Value ptr = rewriter.create<LLVM::GEPOp>( |
531 | loc, ptrType, structType, tempAlloc, |
532 | ArrayRef<LLVM::GEPArg>{0, static_cast<int32_t>(index)}); |
533 | rewriter.create<LLVM::StoreOp>(loc, arg, ptr); |
534 | } |
535 | std::array<Value, 2> printfArgs = {stringStart, tempAlloc}; |
536 | |
537 | rewriter.create<LLVM::CallOp>(loc, vprintfDecl, printfArgs); |
538 | rewriter.eraseOp(op: gpuPrintfOp); |
539 | return success(); |
540 | } |
541 | |
542 | /// Unrolls op if it's operating on vectors. |
543 | LogicalResult impl::scalarizeVectorOp(Operation *op, ValueRange operands, |
544 | ConversionPatternRewriter &rewriter, |
545 | const LLVMTypeConverter &converter) { |
546 | TypeRange operandTypes(operands); |
547 | if (llvm::none_of(Range&: operandTypes, P: llvm::IsaPred<VectorType>)) { |
548 | return rewriter.notifyMatchFailure(arg&: op, msg: "expected vector operand" ); |
549 | } |
550 | if (op->getNumRegions() != 0 || op->getNumSuccessors() != 0) |
551 | return rewriter.notifyMatchFailure(arg&: op, msg: "expected no region/successor" ); |
552 | if (op->getNumResults() != 1) |
553 | return rewriter.notifyMatchFailure(arg&: op, msg: "expected single result" ); |
554 | VectorType vectorType = dyn_cast<VectorType>(op->getResult(0).getType()); |
555 | if (!vectorType) |
556 | return rewriter.notifyMatchFailure(arg&: op, msg: "expected vector result" ); |
557 | |
558 | Location loc = op->getLoc(); |
559 | Value result = rewriter.create<LLVM::UndefOp>(loc, vectorType); |
560 | Type indexType = converter.convertType(rewriter.getIndexType()); |
561 | StringAttr name = op->getName().getIdentifier(); |
562 | Type elementType = vectorType.getElementType(); |
563 | |
564 | for (int64_t i = 0; i < vectorType.getNumElements(); ++i) { |
565 | Value index = rewriter.create<LLVM::ConstantOp>(loc, indexType, i); |
566 | auto = [&](Value operand) -> Value { |
567 | if (!isa<VectorType>(Val: operand.getType())) |
568 | return operand; |
569 | return rewriter.create<LLVM::ExtractElementOp>(loc, operand, index); |
570 | }; |
571 | auto scalarOperands = llvm::map_to_vector(C&: operands, F&: extractElement); |
572 | Operation *scalarOp = |
573 | rewriter.create(loc, name, scalarOperands, elementType, op->getAttrs()); |
574 | result = rewriter.create<LLVM::InsertElementOp>( |
575 | loc, result, scalarOp->getResult(0), index); |
576 | } |
577 | |
578 | rewriter.replaceOp(op, newValues: result); |
579 | return success(); |
580 | } |
581 | |
582 | static IntegerAttr wrapNumericMemorySpace(MLIRContext *ctx, unsigned space) { |
583 | return IntegerAttr::get(IntegerType::get(ctx, 64), space); |
584 | } |
585 | |
586 | /// Generates a symbol with 0-sized array type for dynamic shared memory usage, |
587 | /// or uses existing symbol. |
588 | LLVM::GlobalOp |
589 | getDynamicSharedMemorySymbol(ConversionPatternRewriter &rewriter, |
590 | Operation *moduleOp, gpu::DynamicSharedMemoryOp op, |
591 | const LLVMTypeConverter *typeConverter, |
592 | MemRefType memrefType, unsigned alignmentBit) { |
593 | uint64_t alignmentByte = alignmentBit / memrefType.getElementTypeBitWidth(); |
594 | |
595 | FailureOr<unsigned> addressSpace = |
596 | typeConverter->getMemRefAddressSpace(type: memrefType); |
597 | if (failed(result: addressSpace)) { |
598 | op->emitError() << "conversion of memref memory space " |
599 | << memrefType.getMemorySpace() |
600 | << " to integer address space " |
601 | "failed. Consider adding memory space conversions." ; |
602 | } |
603 | |
604 | // Step 1. Collect symbol names of LLVM::GlobalOp Ops. Also if any of |
605 | // LLVM::GlobalOp is suitable for shared memory, return it. |
606 | llvm::StringSet<> existingGlobalNames; |
607 | for (auto globalOp : |
608 | moduleOp->getRegion(0).front().getOps<LLVM::GlobalOp>()) { |
609 | existingGlobalNames.insert(globalOp.getSymName()); |
610 | if (auto arrayType = dyn_cast<LLVM::LLVMArrayType>(globalOp.getType())) { |
611 | if (globalOp.getAddrSpace() == addressSpace.value() && |
612 | arrayType.getNumElements() == 0 && |
613 | globalOp.getAlignment().value_or(0) == alignmentByte) { |
614 | return globalOp; |
615 | } |
616 | } |
617 | } |
618 | |
619 | // Step 2. Find a unique symbol name |
620 | unsigned uniquingCounter = 0; |
621 | SmallString<128> symName = SymbolTable::generateSymbolName<128>( |
622 | name: "__dynamic_shmem_" , |
623 | uniqueChecker: [&](StringRef candidate) { |
624 | return existingGlobalNames.contains(key: candidate); |
625 | }, |
626 | uniquingCounter); |
627 | |
628 | // Step 3. Generate a global op |
629 | OpBuilder::InsertionGuard guard(rewriter); |
630 | rewriter.setInsertionPoint(&moduleOp->getRegion(index: 0).front().front()); |
631 | |
632 | auto zeroSizedArrayType = LLVM::LLVMArrayType::get( |
633 | typeConverter->convertType(memrefType.getElementType()), 0); |
634 | |
635 | return rewriter.create<LLVM::GlobalOp>( |
636 | op->getLoc(), zeroSizedArrayType, /*isConstant=*/false, |
637 | LLVM::Linkage::Internal, symName, /*value=*/Attribute(), alignmentByte, |
638 | addressSpace.value()); |
639 | } |
640 | |
641 | LogicalResult GPUDynamicSharedMemoryOpLowering::matchAndRewrite( |
642 | gpu::DynamicSharedMemoryOp op, OpAdaptor adaptor, |
643 | ConversionPatternRewriter &rewriter) const { |
644 | Location loc = op.getLoc(); |
645 | MemRefType memrefType = op.getResultMemref().getType(); |
646 | Type elementType = typeConverter->convertType(memrefType.getElementType()); |
647 | |
648 | // Step 1: Generate a memref<0xi8> type |
649 | MemRefLayoutAttrInterface layout = {}; |
650 | auto memrefType0sz = |
651 | MemRefType::get({0}, elementType, layout, memrefType.getMemorySpace()); |
652 | |
653 | // Step 2: Generate a global symbol or existing for the dynamic shared |
654 | // memory with memref<0xi8> type |
655 | LLVM::LLVMFuncOp funcOp = op->getParentOfType<LLVM::LLVMFuncOp>(); |
656 | LLVM::GlobalOp shmemOp = {}; |
657 | Operation *moduleOp = funcOp->getParentWithTrait<OpTrait::SymbolTable>(); |
658 | shmemOp = getDynamicSharedMemorySymbol( |
659 | rewriter, moduleOp, op, getTypeConverter(), memrefType0sz, alignmentBit); |
660 | |
661 | // Step 3. Get address of the global symbol |
662 | OpBuilder::InsertionGuard guard(rewriter); |
663 | rewriter.setInsertionPoint(op); |
664 | auto basePtr = rewriter.create<LLVM::AddressOfOp>(loc, shmemOp); |
665 | Type baseType = basePtr->getResultTypes().front(); |
666 | |
667 | // Step 4. Generate GEP using offsets |
668 | SmallVector<LLVM::GEPArg> gepArgs = {0}; |
669 | Value shmemPtr = rewriter.create<LLVM::GEPOp>(loc, baseType, elementType, |
670 | basePtr, gepArgs); |
671 | // Step 5. Create a memref descriptor |
672 | SmallVector<Value> shape, strides; |
673 | Value sizeBytes; |
674 | getMemRefDescriptorSizes(loc, memrefType0sz, {}, rewriter, shape, strides, |
675 | sizeBytes); |
676 | auto memRefDescriptor = this->createMemRefDescriptor( |
677 | loc, memrefType0sz, shmemPtr, shmemPtr, shape, strides, rewriter); |
678 | |
679 | // Step 5. Replace the op with memref descriptor |
680 | rewriter.replaceOp(op, {memRefDescriptor}); |
681 | return success(); |
682 | } |
683 | |
684 | void mlir::populateGpuMemorySpaceAttributeConversions( |
685 | TypeConverter &typeConverter, const MemorySpaceMapping &mapping) { |
686 | typeConverter.addTypeAttributeConversion( |
687 | callback: [mapping](BaseMemRefType type, gpu::AddressSpaceAttr memorySpaceAttr) { |
688 | gpu::AddressSpace memorySpace = memorySpaceAttr.getValue(); |
689 | unsigned addressSpace = mapping(memorySpace); |
690 | return wrapNumericMemorySpace(memorySpaceAttr.getContext(), |
691 | addressSpace); |
692 | }); |
693 | } |
694 | |