| 1 | //===- GPUOpsLowering.cpp - GPU FuncOp / ReturnOp lowering ----------------===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | |
| 9 | #include "GPUOpsLowering.h" |
| 10 | |
| 11 | #include "mlir/Conversion/GPUCommon/GPUCommonPass.h" |
| 12 | #include "mlir/Conversion/LLVMCommon/VectorPattern.h" |
| 13 | #include "mlir/Dialect/LLVMIR/LLVMDialect.h" |
| 14 | #include "mlir/IR/Attributes.h" |
| 15 | #include "mlir/IR/Builders.h" |
| 16 | #include "mlir/IR/BuiltinTypes.h" |
| 17 | #include "llvm/ADT/SmallVectorExtras.h" |
| 18 | #include "llvm/ADT/StringSet.h" |
| 19 | #include "llvm/Support/FormatVariadic.h" |
| 20 | |
| 21 | using namespace mlir; |
| 22 | |
| 23 | LLVM::LLVMFuncOp mlir::getOrDefineFunction(gpu::GPUModuleOp moduleOp, |
| 24 | Location loc, OpBuilder &b, |
| 25 | StringRef name, |
| 26 | LLVM::LLVMFunctionType type) { |
| 27 | LLVM::LLVMFuncOp ret; |
| 28 | if (!(ret = moduleOp.template lookupSymbol<LLVM::LLVMFuncOp>(name))) { |
| 29 | OpBuilder::InsertionGuard guard(b); |
| 30 | b.setInsertionPointToStart(moduleOp.getBody()); |
| 31 | ret = b.create<LLVM::LLVMFuncOp>(location: loc, args&: name, args&: type, args: LLVM::Linkage::External); |
| 32 | } |
| 33 | return ret; |
| 34 | } |
| 35 | |
| 36 | static SmallString<16> getUniqueSymbolName(gpu::GPUModuleOp moduleOp, |
| 37 | StringRef prefix) { |
| 38 | // Get a unique global name. |
| 39 | unsigned stringNumber = 0; |
| 40 | SmallString<16> stringConstName; |
| 41 | do { |
| 42 | stringConstName.clear(); |
| 43 | (prefix + Twine(stringNumber++)).toStringRef(Out&: stringConstName); |
| 44 | } while (moduleOp.lookupSymbol(name: stringConstName)); |
| 45 | return stringConstName; |
| 46 | } |
| 47 | |
| 48 | LLVM::GlobalOp |
| 49 | mlir::getOrCreateStringConstant(OpBuilder &b, Location loc, |
| 50 | gpu::GPUModuleOp moduleOp, Type llvmI8, |
| 51 | StringRef namePrefix, StringRef str, |
| 52 | uint64_t alignment, unsigned addrSpace) { |
| 53 | llvm::SmallString<20> nullTermStr(str); |
| 54 | nullTermStr.push_back(Elt: '\0'); // Null terminate for C |
| 55 | auto globalType = |
| 56 | LLVM::LLVMArrayType::get(elementType: llvmI8, numElements: nullTermStr.size_in_bytes()); |
| 57 | StringAttr attr = b.getStringAttr(bytes: nullTermStr); |
| 58 | |
| 59 | // Try to find existing global. |
| 60 | for (auto globalOp : moduleOp.getOps<LLVM::GlobalOp>()) |
| 61 | if (globalOp.getGlobalType() == globalType && globalOp.getConstant() && |
| 62 | globalOp.getValueAttr() == attr && |
| 63 | globalOp.getAlignment().value_or(u: 0) == alignment && |
| 64 | globalOp.getAddrSpace() == addrSpace) |
| 65 | return globalOp; |
| 66 | |
| 67 | // Not found: create new global. |
| 68 | OpBuilder::InsertionGuard guard(b); |
| 69 | b.setInsertionPointToStart(moduleOp.getBody()); |
| 70 | SmallString<16> name = getUniqueSymbolName(moduleOp, prefix: namePrefix); |
| 71 | return b.create<LLVM::GlobalOp>(location: loc, args&: globalType, |
| 72 | /*isConstant=*/args: true, args: LLVM::Linkage::Internal, |
| 73 | args&: name, args&: attr, args&: alignment, args&: addrSpace); |
| 74 | } |
| 75 | |
| 76 | LogicalResult |
| 77 | GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor, |
| 78 | ConversionPatternRewriter &rewriter) const { |
| 79 | Location loc = gpuFuncOp.getLoc(); |
| 80 | |
| 81 | SmallVector<LLVM::GlobalOp, 3> workgroupBuffers; |
| 82 | if (encodeWorkgroupAttributionsAsArguments) { |
| 83 | // Append an `llvm.ptr` argument to the function signature to encode |
| 84 | // workgroup attributions. |
| 85 | |
| 86 | ArrayRef<BlockArgument> workgroupAttributions = |
| 87 | gpuFuncOp.getWorkgroupAttributions(); |
| 88 | size_t numAttributions = workgroupAttributions.size(); |
| 89 | |
| 90 | // Insert all arguments at the end. |
| 91 | unsigned index = gpuFuncOp.getNumArguments(); |
| 92 | SmallVector<unsigned> argIndices(numAttributions, index); |
| 93 | |
| 94 | // New arguments will simply be `llvm.ptr` with the correct address space |
| 95 | Type workgroupPtrType = |
| 96 | rewriter.getType<LLVM::LLVMPointerType>(args: workgroupAddrSpace); |
| 97 | SmallVector<Type> argTypes(numAttributions, workgroupPtrType); |
| 98 | |
| 99 | // Attributes: noalias, llvm.mlir.workgroup_attribution(<size>, <type>) |
| 100 | std::array attrs{ |
| 101 | rewriter.getNamedAttr(name: LLVM::LLVMDialect::getNoAliasAttrName(), |
| 102 | val: rewriter.getUnitAttr()), |
| 103 | rewriter.getNamedAttr( |
| 104 | name: getDialect().getWorkgroupAttributionAttrHelper().getName(), |
| 105 | val: rewriter.getUnitAttr()), |
| 106 | }; |
| 107 | SmallVector<DictionaryAttr> argAttrs; |
| 108 | for (BlockArgument attribution : workgroupAttributions) { |
| 109 | auto attributionType = cast<MemRefType>(Val: attribution.getType()); |
| 110 | IntegerAttr numElements = |
| 111 | rewriter.getI64IntegerAttr(value: attributionType.getNumElements()); |
| 112 | Type llvmElementType = |
| 113 | getTypeConverter()->convertType(t: attributionType.getElementType()); |
| 114 | if (!llvmElementType) |
| 115 | return failure(); |
| 116 | TypeAttr type = TypeAttr::get(type: llvmElementType); |
| 117 | attrs.back().setValue( |
| 118 | rewriter.getAttr<LLVM::WorkgroupAttributionAttr>(args&: numElements, args&: type)); |
| 119 | argAttrs.push_back(Elt: rewriter.getDictionaryAttr(value: attrs)); |
| 120 | } |
| 121 | |
| 122 | // Location match function location |
| 123 | SmallVector<Location> argLocs(numAttributions, gpuFuncOp.getLoc()); |
| 124 | |
| 125 | // Perform signature modification |
| 126 | rewriter.modifyOpInPlace( |
| 127 | root: gpuFuncOp, callable: [gpuFuncOp, &argIndices, &argTypes, &argAttrs, &argLocs]() { |
| 128 | LogicalResult inserted = |
| 129 | static_cast<FunctionOpInterface>(gpuFuncOp).insertArguments( |
| 130 | argIndices, argTypes, argAttrs, argLocs); |
| 131 | (void)inserted; |
| 132 | assert(succeeded(inserted) && |
| 133 | "expected GPU funcs to support inserting any argument" ); |
| 134 | }); |
| 135 | } else { |
| 136 | workgroupBuffers.reserve(N: gpuFuncOp.getNumWorkgroupAttributions()); |
| 137 | for (auto [idx, attribution] : |
| 138 | llvm::enumerate(First: gpuFuncOp.getWorkgroupAttributions())) { |
| 139 | auto type = dyn_cast<MemRefType>(Val: attribution.getType()); |
| 140 | assert(type && type.hasStaticShape() && "unexpected type in attribution" ); |
| 141 | |
| 142 | uint64_t numElements = type.getNumElements(); |
| 143 | |
| 144 | auto elementType = |
| 145 | cast<Type>(Val: typeConverter->convertType(t: type.getElementType())); |
| 146 | auto arrayType = LLVM::LLVMArrayType::get(elementType, numElements); |
| 147 | std::string name = |
| 148 | std::string(llvm::formatv(Fmt: "__wg_{0}_{1}" , Vals: gpuFuncOp.getName(), Vals&: idx)); |
| 149 | uint64_t alignment = 0; |
| 150 | if (auto alignAttr = dyn_cast_or_null<IntegerAttr>( |
| 151 | Val: gpuFuncOp.getWorkgroupAttributionAttr( |
| 152 | index: idx, name: LLVM::LLVMDialect::getAlignAttrName()))) |
| 153 | alignment = alignAttr.getInt(); |
| 154 | auto globalOp = rewriter.create<LLVM::GlobalOp>( |
| 155 | location: gpuFuncOp.getLoc(), args&: arrayType, /*isConstant=*/args: false, |
| 156 | args: LLVM::Linkage::Internal, args&: name, /*value=*/args: Attribute(), args&: alignment, |
| 157 | args: workgroupAddrSpace); |
| 158 | workgroupBuffers.push_back(Elt: globalOp); |
| 159 | } |
| 160 | } |
| 161 | |
| 162 | // Remap proper input types. |
| 163 | TypeConverter::SignatureConversion signatureConversion( |
| 164 | gpuFuncOp.front().getNumArguments()); |
| 165 | |
| 166 | Type funcType = getTypeConverter()->convertFunctionSignature( |
| 167 | funcTy: gpuFuncOp.getFunctionType(), /*isVariadic=*/false, |
| 168 | useBarePtrCallConv: getTypeConverter()->getOptions().useBarePtrCallConv, result&: signatureConversion); |
| 169 | if (!funcType) { |
| 170 | return rewriter.notifyMatchFailure(op: gpuFuncOp, reasonCallback: [&](Diagnostic &diag) { |
| 171 | diag << "failed to convert function signature type for: " |
| 172 | << gpuFuncOp.getFunctionType(); |
| 173 | }); |
| 174 | } |
| 175 | |
| 176 | // Create the new function operation. Only copy those attributes that are |
| 177 | // not specific to function modeling. |
| 178 | SmallVector<NamedAttribute, 4> attributes; |
| 179 | ArrayAttr argAttrs; |
| 180 | for (const auto &attr : gpuFuncOp->getAttrs()) { |
| 181 | if (attr.getName() == SymbolTable::getSymbolAttrName() || |
| 182 | attr.getName() == gpuFuncOp.getFunctionTypeAttrName() || |
| 183 | attr.getName() == |
| 184 | gpu::GPUFuncOp::getNumWorkgroupAttributionsAttrName() || |
| 185 | attr.getName() == gpuFuncOp.getWorkgroupAttribAttrsAttrName() || |
| 186 | attr.getName() == gpuFuncOp.getPrivateAttribAttrsAttrName() || |
| 187 | attr.getName() == gpuFuncOp.getKnownBlockSizeAttrName() || |
| 188 | attr.getName() == gpuFuncOp.getKnownGridSizeAttrName()) |
| 189 | continue; |
| 190 | if (attr.getName() == gpuFuncOp.getArgAttrsAttrName()) { |
| 191 | argAttrs = gpuFuncOp.getArgAttrsAttr(); |
| 192 | continue; |
| 193 | } |
| 194 | attributes.push_back(Elt: attr); |
| 195 | } |
| 196 | |
| 197 | DenseI32ArrayAttr knownBlockSize = gpuFuncOp.getKnownBlockSizeAttr(); |
| 198 | DenseI32ArrayAttr knownGridSize = gpuFuncOp.getKnownGridSizeAttr(); |
| 199 | // Ensure we don't lose information if the function is lowered before its |
| 200 | // surrounding context. |
| 201 | auto *gpuDialect = cast<gpu::GPUDialect>(Val: gpuFuncOp->getDialect()); |
| 202 | if (knownBlockSize) |
| 203 | attributes.emplace_back(Args: gpuDialect->getKnownBlockSizeAttrHelper().getName(), |
| 204 | Args&: knownBlockSize); |
| 205 | if (knownGridSize) |
| 206 | attributes.emplace_back(Args: gpuDialect->getKnownGridSizeAttrHelper().getName(), |
| 207 | Args&: knownGridSize); |
| 208 | |
| 209 | // Add a dialect specific kernel attribute in addition to GPU kernel |
| 210 | // attribute. The former is necessary for further translation while the |
| 211 | // latter is expected by gpu.launch_func. |
| 212 | if (gpuFuncOp.isKernel()) { |
| 213 | if (kernelAttributeName) |
| 214 | attributes.emplace_back(Args: kernelAttributeName, Args: rewriter.getUnitAttr()); |
| 215 | // Set the dialect-specific block size attribute if there is one. |
| 216 | if (kernelBlockSizeAttributeName && knownBlockSize) { |
| 217 | attributes.emplace_back(Args: kernelBlockSizeAttributeName, Args&: knownBlockSize); |
| 218 | } |
| 219 | } |
| 220 | LLVM::CConv callingConvention = gpuFuncOp.isKernel() |
| 221 | ? kernelCallingConvention |
| 222 | : nonKernelCallingConvention; |
| 223 | auto llvmFuncOp = rewriter.create<LLVM::LLVMFuncOp>( |
| 224 | location: gpuFuncOp.getLoc(), args: gpuFuncOp.getName(), args&: funcType, |
| 225 | args: LLVM::Linkage::External, /*dsoLocal=*/args: false, args&: callingConvention, |
| 226 | /*comdat=*/args: nullptr, args&: attributes); |
| 227 | |
| 228 | { |
| 229 | // Insert operations that correspond to converted workgroup and private |
| 230 | // memory attributions to the body of the function. This must operate on |
| 231 | // the original function, before the body region is inlined in the new |
| 232 | // function to maintain the relation between block arguments and the |
| 233 | // parent operation that assigns their semantics. |
| 234 | OpBuilder::InsertionGuard guard(rewriter); |
| 235 | |
| 236 | // Rewrite workgroup memory attributions to addresses of global buffers. |
| 237 | rewriter.setInsertionPointToStart(&gpuFuncOp.front()); |
| 238 | unsigned numProperArguments = gpuFuncOp.getNumArguments(); |
| 239 | |
| 240 | if (encodeWorkgroupAttributionsAsArguments) { |
| 241 | // Build a MemRefDescriptor with each of the arguments added above. |
| 242 | |
| 243 | unsigned numAttributions = gpuFuncOp.getNumWorkgroupAttributions(); |
| 244 | assert(numProperArguments >= numAttributions && |
| 245 | "Expecting attributions to be encoded as arguments already" ); |
| 246 | |
| 247 | // Arguments encoding workgroup attributions will be in positions |
| 248 | // [numProperArguments, numProperArguments+numAttributions) |
| 249 | ArrayRef<BlockArgument> attributionArguments = |
| 250 | gpuFuncOp.getArguments().slice(N: numProperArguments - numAttributions, |
| 251 | M: numAttributions); |
| 252 | for (auto [idx, vals] : llvm::enumerate(First: llvm::zip_equal( |
| 253 | t: gpuFuncOp.getWorkgroupAttributions(), u&: attributionArguments))) { |
| 254 | auto [attribution, arg] = vals; |
| 255 | auto type = cast<MemRefType>(Val: attribution.getType()); |
| 256 | |
| 257 | // Arguments are of llvm.ptr type and attributions are of memref type: |
| 258 | // we need to wrap them in memref descriptors. |
| 259 | Value descr = MemRefDescriptor::fromStaticShape( |
| 260 | builder&: rewriter, loc, typeConverter: *getTypeConverter(), type, memory: arg); |
| 261 | |
| 262 | // And remap the arguments |
| 263 | signatureConversion.remapInput(origInputNo: numProperArguments + idx, replacements: descr); |
| 264 | } |
| 265 | } else { |
| 266 | for (const auto [idx, global] : llvm::enumerate(First&: workgroupBuffers)) { |
| 267 | auto ptrType = LLVM::LLVMPointerType::get(context: rewriter.getContext(), |
| 268 | addressSpace: global.getAddrSpace()); |
| 269 | Value address = rewriter.create<LLVM::AddressOfOp>( |
| 270 | location: loc, args&: ptrType, args: global.getSymNameAttr()); |
| 271 | Value memory = |
| 272 | rewriter.create<LLVM::GEPOp>(location: loc, args&: ptrType, args: global.getType(), |
| 273 | args&: address, args: ArrayRef<LLVM::GEPArg>{0, 0}); |
| 274 | |
| 275 | // Build a memref descriptor pointing to the buffer to plug with the |
| 276 | // existing memref infrastructure. This may use more registers than |
| 277 | // otherwise necessary given that memref sizes are fixed, but we can try |
| 278 | // and canonicalize that away later. |
| 279 | Value attribution = gpuFuncOp.getWorkgroupAttributions()[idx]; |
| 280 | auto type = cast<MemRefType>(Val: attribution.getType()); |
| 281 | Value descr = MemRefDescriptor::fromStaticShape( |
| 282 | builder&: rewriter, loc, typeConverter: *getTypeConverter(), type, memory); |
| 283 | signatureConversion.remapInput(origInputNo: numProperArguments + idx, replacements: descr); |
| 284 | } |
| 285 | } |
| 286 | |
| 287 | // Rewrite private memory attributions to alloca'ed buffers. |
| 288 | unsigned numWorkgroupAttributions = gpuFuncOp.getNumWorkgroupAttributions(); |
| 289 | auto int64Ty = IntegerType::get(context: rewriter.getContext(), width: 64); |
| 290 | for (const auto [idx, attribution] : |
| 291 | llvm::enumerate(First: gpuFuncOp.getPrivateAttributions())) { |
| 292 | auto type = cast<MemRefType>(Val: attribution.getType()); |
| 293 | assert(type && type.hasStaticShape() && "unexpected type in attribution" ); |
| 294 | |
| 295 | // Explicitly drop memory space when lowering private memory |
| 296 | // attributions since NVVM models it as `alloca`s in the default |
| 297 | // memory space and does not support `alloca`s with addrspace(5). |
| 298 | Type elementType = typeConverter->convertType(t: type.getElementType()); |
| 299 | auto ptrType = |
| 300 | LLVM::LLVMPointerType::get(context: rewriter.getContext(), addressSpace: allocaAddrSpace); |
| 301 | Value numElements = rewriter.create<LLVM::ConstantOp>( |
| 302 | location: gpuFuncOp.getLoc(), args&: int64Ty, args: type.getNumElements()); |
| 303 | uint64_t alignment = 0; |
| 304 | if (auto alignAttr = |
| 305 | dyn_cast_or_null<IntegerAttr>(Val: gpuFuncOp.getPrivateAttributionAttr( |
| 306 | index: idx, name: LLVM::LLVMDialect::getAlignAttrName()))) |
| 307 | alignment = alignAttr.getInt(); |
| 308 | Value allocated = rewriter.create<LLVM::AllocaOp>( |
| 309 | location: gpuFuncOp.getLoc(), args&: ptrType, args&: elementType, args&: numElements, args&: alignment); |
| 310 | Value descr = MemRefDescriptor::fromStaticShape( |
| 311 | builder&: rewriter, loc, typeConverter: *getTypeConverter(), type, memory: allocated); |
| 312 | signatureConversion.remapInput( |
| 313 | origInputNo: numProperArguments + numWorkgroupAttributions + idx, replacements: descr); |
| 314 | } |
| 315 | } |
| 316 | |
| 317 | // Move the region to the new function, update the entry block signature. |
| 318 | rewriter.inlineRegionBefore(region&: gpuFuncOp.getBody(), parent&: llvmFuncOp.getBody(), |
| 319 | before: llvmFuncOp.end()); |
| 320 | if (failed(Result: rewriter.convertRegionTypes(region: &llvmFuncOp.getBody(), converter: *typeConverter, |
| 321 | entryConversion: &signatureConversion))) |
| 322 | return failure(); |
| 323 | |
| 324 | // Get memref type from function arguments and set the noalias to |
| 325 | // pointer arguments. |
| 326 | for (const auto [idx, argTy] : |
| 327 | llvm::enumerate(First: gpuFuncOp.getArgumentTypes())) { |
| 328 | auto remapping = signatureConversion.getInputMapping(input: idx); |
| 329 | NamedAttrList argAttr = |
| 330 | argAttrs ? cast<DictionaryAttr>(Val: argAttrs[idx]) : NamedAttrList(); |
| 331 | auto copyAttribute = [&](StringRef attrName) { |
| 332 | Attribute attr = argAttr.erase(name: attrName); |
| 333 | if (!attr) |
| 334 | return; |
| 335 | for (size_t i = 0, e = remapping->size; i < e; ++i) |
| 336 | llvmFuncOp.setArgAttr(index: remapping->inputNo + i, name: attrName, value: attr); |
| 337 | }; |
| 338 | auto copyPointerAttribute = [&](StringRef attrName) { |
| 339 | Attribute attr = argAttr.erase(name: attrName); |
| 340 | |
| 341 | if (!attr) |
| 342 | return; |
| 343 | if (remapping->size > 1 && |
| 344 | attrName == LLVM::LLVMDialect::getNoAliasAttrName()) { |
| 345 | emitWarning(loc: llvmFuncOp.getLoc(), |
| 346 | message: "Cannot copy noalias with non-bare pointers.\n" ); |
| 347 | return; |
| 348 | } |
| 349 | for (size_t i = 0, e = remapping->size; i < e; ++i) { |
| 350 | if (isa<LLVM::LLVMPointerType>( |
| 351 | Val: llvmFuncOp.getArgument(idx: remapping->inputNo + i).getType())) { |
| 352 | llvmFuncOp.setArgAttr(index: remapping->inputNo + i, name: attrName, value: attr); |
| 353 | } |
| 354 | } |
| 355 | }; |
| 356 | |
| 357 | if (argAttr.empty()) |
| 358 | continue; |
| 359 | |
| 360 | copyAttribute(LLVM::LLVMDialect::getReturnedAttrName()); |
| 361 | copyAttribute(LLVM::LLVMDialect::getNoUndefAttrName()); |
| 362 | copyAttribute(LLVM::LLVMDialect::getInRegAttrName()); |
| 363 | bool lowersToPointer = false; |
| 364 | for (size_t i = 0, e = remapping->size; i < e; ++i) { |
| 365 | lowersToPointer |= isa<LLVM::LLVMPointerType>( |
| 366 | Val: llvmFuncOp.getArgument(idx: remapping->inputNo + i).getType()); |
| 367 | } |
| 368 | |
| 369 | if (lowersToPointer) { |
| 370 | copyPointerAttribute(LLVM::LLVMDialect::getNoAliasAttrName()); |
| 371 | copyPointerAttribute(LLVM::LLVMDialect::getNoCaptureAttrName()); |
| 372 | copyPointerAttribute(LLVM::LLVMDialect::getNoFreeAttrName()); |
| 373 | copyPointerAttribute(LLVM::LLVMDialect::getAlignAttrName()); |
| 374 | copyPointerAttribute(LLVM::LLVMDialect::getReadonlyAttrName()); |
| 375 | copyPointerAttribute(LLVM::LLVMDialect::getWriteOnlyAttrName()); |
| 376 | copyPointerAttribute(LLVM::LLVMDialect::getReadnoneAttrName()); |
| 377 | copyPointerAttribute(LLVM::LLVMDialect::getNonNullAttrName()); |
| 378 | copyPointerAttribute(LLVM::LLVMDialect::getDereferenceableAttrName()); |
| 379 | copyPointerAttribute( |
| 380 | LLVM::LLVMDialect::getDereferenceableOrNullAttrName()); |
| 381 | copyPointerAttribute( |
| 382 | LLVM::LLVMDialect::WorkgroupAttributionAttrHelper::getNameStr()); |
| 383 | } |
| 384 | } |
| 385 | rewriter.eraseOp(op: gpuFuncOp); |
| 386 | return success(); |
| 387 | } |
| 388 | |
| 389 | LogicalResult GPUPrintfOpToHIPLowering::matchAndRewrite( |
| 390 | gpu::PrintfOp gpuPrintfOp, gpu::PrintfOpAdaptor adaptor, |
| 391 | ConversionPatternRewriter &rewriter) const { |
| 392 | Location loc = gpuPrintfOp->getLoc(); |
| 393 | |
| 394 | mlir::Type llvmI8 = typeConverter->convertType(t: rewriter.getI8Type()); |
| 395 | auto ptrType = LLVM::LLVMPointerType::get(context: rewriter.getContext()); |
| 396 | mlir::Type llvmI32 = typeConverter->convertType(t: rewriter.getI32Type()); |
| 397 | mlir::Type llvmI64 = typeConverter->convertType(t: rewriter.getI64Type()); |
| 398 | // Note: this is the GPUModule op, not the ModuleOp that surrounds it |
| 399 | // This ensures that global constants and declarations are placed within |
| 400 | // the device code, not the host code |
| 401 | auto moduleOp = gpuPrintfOp->getParentOfType<gpu::GPUModuleOp>(); |
| 402 | |
| 403 | auto ocklBegin = |
| 404 | getOrDefineFunction(moduleOp, loc, b&: rewriter, name: "__ockl_printf_begin" , |
| 405 | type: LLVM::LLVMFunctionType::get(result: llvmI64, arguments: {llvmI64})); |
| 406 | LLVM::LLVMFuncOp ocklAppendArgs; |
| 407 | if (!adaptor.getArgs().empty()) { |
| 408 | ocklAppendArgs = getOrDefineFunction( |
| 409 | moduleOp, loc, b&: rewriter, name: "__ockl_printf_append_args" , |
| 410 | type: LLVM::LLVMFunctionType::get( |
| 411 | result: llvmI64, arguments: {llvmI64, /*numArgs*/ llvmI32, llvmI64, llvmI64, llvmI64, |
| 412 | llvmI64, llvmI64, llvmI64, llvmI64, /*isLast*/ llvmI32})); |
| 413 | } |
| 414 | auto ocklAppendStringN = getOrDefineFunction( |
| 415 | moduleOp, loc, b&: rewriter, name: "__ockl_printf_append_string_n" , |
| 416 | type: LLVM::LLVMFunctionType::get( |
| 417 | result: llvmI64, |
| 418 | arguments: {llvmI64, ptrType, /*length (bytes)*/ llvmI64, /*isLast*/ llvmI32})); |
| 419 | |
| 420 | /// Start the printf hostcall |
| 421 | Value zeroI64 = rewriter.create<LLVM::ConstantOp>(location: loc, args&: llvmI64, args: 0); |
| 422 | auto printfBeginCall = rewriter.create<LLVM::CallOp>(location: loc, args&: ocklBegin, args&: zeroI64); |
| 423 | Value printfDesc = printfBeginCall.getResult(); |
| 424 | |
| 425 | // Create the global op or find an existing one. |
| 426 | LLVM::GlobalOp global = getOrCreateStringConstant( |
| 427 | b&: rewriter, loc, moduleOp, llvmI8, namePrefix: "printfFormat_" , str: adaptor.getFormat()); |
| 428 | |
| 429 | // Get a pointer to the format string's first element and pass it to printf() |
| 430 | Value globalPtr = rewriter.create<LLVM::AddressOfOp>( |
| 431 | location: loc, |
| 432 | args: LLVM::LLVMPointerType::get(context: rewriter.getContext(), addressSpace: global.getAddrSpace()), |
| 433 | args: global.getSymNameAttr()); |
| 434 | Value stringStart = |
| 435 | rewriter.create<LLVM::GEPOp>(location: loc, args&: ptrType, args: global.getGlobalType(), |
| 436 | args&: globalPtr, args: ArrayRef<LLVM::GEPArg>{0, 0}); |
| 437 | Value stringLen = rewriter.create<LLVM::ConstantOp>( |
| 438 | location: loc, args&: llvmI64, args: cast<StringAttr>(Val: global.getValueAttr()).size()); |
| 439 | |
| 440 | Value oneI32 = rewriter.create<LLVM::ConstantOp>(location: loc, args&: llvmI32, args: 1); |
| 441 | Value zeroI32 = rewriter.create<LLVM::ConstantOp>(location: loc, args&: llvmI32, args: 0); |
| 442 | |
| 443 | auto appendFormatCall = rewriter.create<LLVM::CallOp>( |
| 444 | location: loc, args&: ocklAppendStringN, |
| 445 | args: ValueRange{printfDesc, stringStart, stringLen, |
| 446 | adaptor.getArgs().empty() ? oneI32 : zeroI32}); |
| 447 | printfDesc = appendFormatCall.getResult(); |
| 448 | |
| 449 | // __ockl_printf_append_args takes 7 values per append call |
| 450 | constexpr size_t argsPerAppend = 7; |
| 451 | size_t nArgs = adaptor.getArgs().size(); |
| 452 | for (size_t group = 0; group < nArgs; group += argsPerAppend) { |
| 453 | size_t bound = std::min(a: group + argsPerAppend, b: nArgs); |
| 454 | size_t numArgsThisCall = bound - group; |
| 455 | |
| 456 | SmallVector<mlir::Value, 2 + argsPerAppend + 1> arguments; |
| 457 | arguments.push_back(Elt: printfDesc); |
| 458 | arguments.push_back( |
| 459 | Elt: rewriter.create<LLVM::ConstantOp>(location: loc, args&: llvmI32, args&: numArgsThisCall)); |
| 460 | for (size_t i = group; i < bound; ++i) { |
| 461 | Value arg = adaptor.getArgs()[i]; |
| 462 | if (auto floatType = dyn_cast<FloatType>(Val: arg.getType())) { |
| 463 | if (!floatType.isF64()) |
| 464 | arg = rewriter.create<LLVM::FPExtOp>( |
| 465 | location: loc, args: typeConverter->convertType(t: rewriter.getF64Type()), args&: arg); |
| 466 | arg = rewriter.create<LLVM::BitcastOp>(location: loc, args&: llvmI64, args&: arg); |
| 467 | } |
| 468 | if (arg.getType().getIntOrFloatBitWidth() != 64) |
| 469 | arg = rewriter.create<LLVM::ZExtOp>(location: loc, args&: llvmI64, args&: arg); |
| 470 | |
| 471 | arguments.push_back(Elt: arg); |
| 472 | } |
| 473 | // Pad out to 7 arguments since the hostcall always needs 7 |
| 474 | for (size_t = numArgsThisCall; extra < argsPerAppend; ++extra) { |
| 475 | arguments.push_back(Elt: zeroI64); |
| 476 | } |
| 477 | |
| 478 | auto isLast = (bound == nArgs) ? oneI32 : zeroI32; |
| 479 | arguments.push_back(Elt: isLast); |
| 480 | auto call = rewriter.create<LLVM::CallOp>(location: loc, args&: ocklAppendArgs, args&: arguments); |
| 481 | printfDesc = call.getResult(); |
| 482 | } |
| 483 | rewriter.eraseOp(op: gpuPrintfOp); |
| 484 | return success(); |
| 485 | } |
| 486 | |
| 487 | LogicalResult GPUPrintfOpToLLVMCallLowering::matchAndRewrite( |
| 488 | gpu::PrintfOp gpuPrintfOp, gpu::PrintfOpAdaptor adaptor, |
| 489 | ConversionPatternRewriter &rewriter) const { |
| 490 | Location loc = gpuPrintfOp->getLoc(); |
| 491 | |
| 492 | mlir::Type llvmI8 = typeConverter->convertType(t: rewriter.getIntegerType(width: 8)); |
| 493 | mlir::Type ptrType = |
| 494 | LLVM::LLVMPointerType::get(context: rewriter.getContext(), addressSpace); |
| 495 | |
| 496 | // Note: this is the GPUModule op, not the ModuleOp that surrounds it |
| 497 | // This ensures that global constants and declarations are placed within |
| 498 | // the device code, not the host code |
| 499 | auto moduleOp = gpuPrintfOp->getParentOfType<gpu::GPUModuleOp>(); |
| 500 | |
| 501 | auto printfType = |
| 502 | LLVM::LLVMFunctionType::get(result: rewriter.getI32Type(), arguments: {ptrType}, |
| 503 | /*isVarArg=*/true); |
| 504 | LLVM::LLVMFuncOp printfDecl = |
| 505 | getOrDefineFunction(moduleOp, loc, b&: rewriter, name: "printf" , type: printfType); |
| 506 | |
| 507 | // Create the global op or find an existing one. |
| 508 | LLVM::GlobalOp global = getOrCreateStringConstant( |
| 509 | b&: rewriter, loc, moduleOp, llvmI8, namePrefix: "printfFormat_" , str: adaptor.getFormat(), |
| 510 | /*alignment=*/0, addrSpace: addressSpace); |
| 511 | |
| 512 | // Get a pointer to the format string's first element |
| 513 | Value globalPtr = rewriter.create<LLVM::AddressOfOp>( |
| 514 | location: loc, |
| 515 | args: LLVM::LLVMPointerType::get(context: rewriter.getContext(), addressSpace: global.getAddrSpace()), |
| 516 | args: global.getSymNameAttr()); |
| 517 | Value stringStart = |
| 518 | rewriter.create<LLVM::GEPOp>(location: loc, args&: ptrType, args: global.getGlobalType(), |
| 519 | args&: globalPtr, args: ArrayRef<LLVM::GEPArg>{0, 0}); |
| 520 | |
| 521 | // Construct arguments and function call |
| 522 | auto argsRange = adaptor.getArgs(); |
| 523 | SmallVector<Value, 4> printfArgs; |
| 524 | printfArgs.reserve(N: argsRange.size() + 1); |
| 525 | printfArgs.push_back(Elt: stringStart); |
| 526 | printfArgs.append(in_start: argsRange.begin(), in_end: argsRange.end()); |
| 527 | |
| 528 | rewriter.create<LLVM::CallOp>(location: loc, args&: printfDecl, args&: printfArgs); |
| 529 | rewriter.eraseOp(op: gpuPrintfOp); |
| 530 | return success(); |
| 531 | } |
| 532 | |
| 533 | LogicalResult GPUPrintfOpToVPrintfLowering::matchAndRewrite( |
| 534 | gpu::PrintfOp gpuPrintfOp, gpu::PrintfOpAdaptor adaptor, |
| 535 | ConversionPatternRewriter &rewriter) const { |
| 536 | Location loc = gpuPrintfOp->getLoc(); |
| 537 | |
| 538 | mlir::Type llvmI8 = typeConverter->convertType(t: rewriter.getIntegerType(width: 8)); |
| 539 | mlir::Type ptrType = LLVM::LLVMPointerType::get(context: rewriter.getContext()); |
| 540 | |
| 541 | // Note: this is the GPUModule op, not the ModuleOp that surrounds it |
| 542 | // This ensures that global constants and declarations are placed within |
| 543 | // the device code, not the host code |
| 544 | auto moduleOp = gpuPrintfOp->getParentOfType<gpu::GPUModuleOp>(); |
| 545 | |
| 546 | // Create a valid global location removing any metadata attached to the |
| 547 | // location as debug info metadata inside of a function cannot be used outside |
| 548 | // of that function. |
| 549 | Location globalLoc = loc->findInstanceOfOrUnknown<FileLineColLoc>(); |
| 550 | |
| 551 | auto vprintfType = |
| 552 | LLVM::LLVMFunctionType::get(result: rewriter.getI32Type(), arguments: {ptrType, ptrType}); |
| 553 | LLVM::LLVMFuncOp vprintfDecl = getOrDefineFunction( |
| 554 | moduleOp, loc: globalLoc, b&: rewriter, name: "vprintf" , type: vprintfType); |
| 555 | |
| 556 | // Create the global op or find an existing one. |
| 557 | LLVM::GlobalOp global = |
| 558 | getOrCreateStringConstant(b&: rewriter, loc: globalLoc, moduleOp, llvmI8, |
| 559 | namePrefix: "printfFormat_" , str: adaptor.getFormat()); |
| 560 | |
| 561 | // Get a pointer to the format string's first element |
| 562 | Value globalPtr = rewriter.create<LLVM::AddressOfOp>(location: loc, args&: global); |
| 563 | Value stringStart = |
| 564 | rewriter.create<LLVM::GEPOp>(location: loc, args&: ptrType, args: global.getGlobalType(), |
| 565 | args&: globalPtr, args: ArrayRef<LLVM::GEPArg>{0, 0}); |
| 566 | SmallVector<Type> types; |
| 567 | SmallVector<Value> args; |
| 568 | // Promote and pack the arguments into a stack allocation. |
| 569 | for (Value arg : adaptor.getArgs()) { |
| 570 | Type type = arg.getType(); |
| 571 | Value promotedArg = arg; |
| 572 | assert(type.isIntOrFloat()); |
| 573 | if (isa<FloatType>(Val: type)) { |
| 574 | type = rewriter.getF64Type(); |
| 575 | promotedArg = rewriter.create<LLVM::FPExtOp>(location: loc, args&: type, args&: arg); |
| 576 | } |
| 577 | types.push_back(Elt: type); |
| 578 | args.push_back(Elt: promotedArg); |
| 579 | } |
| 580 | Type structType = |
| 581 | LLVM::LLVMStructType::getLiteral(context: gpuPrintfOp.getContext(), types); |
| 582 | Value one = rewriter.create<LLVM::ConstantOp>(location: loc, args: rewriter.getI64Type(), |
| 583 | args: rewriter.getIndexAttr(value: 1)); |
| 584 | Value tempAlloc = |
| 585 | rewriter.create<LLVM::AllocaOp>(location: loc, args&: ptrType, args&: structType, args&: one, |
| 586 | /*alignment=*/args: 0); |
| 587 | for (auto [index, arg] : llvm::enumerate(First&: args)) { |
| 588 | Value ptr = rewriter.create<LLVM::GEPOp>( |
| 589 | location: loc, args&: ptrType, args&: structType, args&: tempAlloc, |
| 590 | args: ArrayRef<LLVM::GEPArg>{0, static_cast<int32_t>(index)}); |
| 591 | rewriter.create<LLVM::StoreOp>(location: loc, args&: arg, args&: ptr); |
| 592 | } |
| 593 | std::array<Value, 2> printfArgs = {stringStart, tempAlloc}; |
| 594 | |
| 595 | rewriter.create<LLVM::CallOp>(location: loc, args&: vprintfDecl, args&: printfArgs); |
| 596 | rewriter.eraseOp(op: gpuPrintfOp); |
| 597 | return success(); |
| 598 | } |
| 599 | |
| 600 | /// Helper for impl::scalarizeVectorOp. Scalarizes vectors to elements. |
| 601 | /// Used either directly (for ops on 1D vectors) or as the callback passed to |
| 602 | /// detail::handleMultidimensionalVectors (for ops on higher-rank vectors). |
| 603 | static Value scalarizeVectorOpHelper(Operation *op, ValueRange operands, |
| 604 | Type llvm1DVectorTy, |
| 605 | ConversionPatternRewriter &rewriter, |
| 606 | const LLVMTypeConverter &converter) { |
| 607 | TypeRange operandTypes(operands); |
| 608 | VectorType vectorType = cast<VectorType>(Val&: llvm1DVectorTy); |
| 609 | Location loc = op->getLoc(); |
| 610 | Value result = rewriter.create<LLVM::PoisonOp>(location: loc, args&: vectorType); |
| 611 | Type indexType = converter.convertType(t: rewriter.getIndexType()); |
| 612 | StringAttr name = op->getName().getIdentifier(); |
| 613 | Type elementType = vectorType.getElementType(); |
| 614 | |
| 615 | for (int64_t i = 0; i < vectorType.getNumElements(); ++i) { |
| 616 | Value index = rewriter.create<LLVM::ConstantOp>(location: loc, args&: indexType, args&: i); |
| 617 | auto = [&](Value operand) -> Value { |
| 618 | if (!isa<VectorType>(Val: operand.getType())) |
| 619 | return operand; |
| 620 | return rewriter.create<LLVM::ExtractElementOp>(location: loc, args&: operand, args&: index); |
| 621 | }; |
| 622 | auto scalarOperands = llvm::map_to_vector(C&: operands, F&: extractElement); |
| 623 | Operation *scalarOp = |
| 624 | rewriter.create(loc, opName: name, operands: scalarOperands, types: elementType, attributes: op->getAttrs()); |
| 625 | result = rewriter.create<LLVM::InsertElementOp>( |
| 626 | location: loc, args&: result, args: scalarOp->getResult(idx: 0), args&: index); |
| 627 | } |
| 628 | return result; |
| 629 | } |
| 630 | |
| 631 | /// Unrolls op to array/vector elements. |
| 632 | LogicalResult impl::scalarizeVectorOp(Operation *op, ValueRange operands, |
| 633 | ConversionPatternRewriter &rewriter, |
| 634 | const LLVMTypeConverter &converter) { |
| 635 | TypeRange operandTypes(operands); |
| 636 | if (llvm::any_of(Range&: operandTypes, P: llvm::IsaPred<VectorType>)) { |
| 637 | VectorType vectorType = |
| 638 | cast<VectorType>(Val: converter.convertType(t: op->getResultTypes()[0])); |
| 639 | rewriter.replaceOp(op, newValues: scalarizeVectorOpHelper(op, operands, llvm1DVectorTy: vectorType, |
| 640 | rewriter, converter)); |
| 641 | return success(); |
| 642 | } |
| 643 | |
| 644 | if (llvm::any_of(Range&: operandTypes, P: llvm::IsaPred<LLVM::LLVMArrayType>)) { |
| 645 | return LLVM::detail::handleMultidimensionalVectors( |
| 646 | op, operands, typeConverter: converter, |
| 647 | createOperand: [&](Type llvm1DVectorTy, ValueRange operands) -> Value { |
| 648 | return scalarizeVectorOpHelper(op, operands, llvm1DVectorTy, rewriter, |
| 649 | converter); |
| 650 | }, |
| 651 | rewriter); |
| 652 | } |
| 653 | |
| 654 | return rewriter.notifyMatchFailure(arg&: op, msg: "no llvm.array or vector to unroll" ); |
| 655 | } |
| 656 | |
| 657 | static IntegerAttr wrapNumericMemorySpace(MLIRContext *ctx, unsigned space) { |
| 658 | return IntegerAttr::get(type: IntegerType::get(context: ctx, width: 64), value: space); |
| 659 | } |
| 660 | |
/// Generates a symbol with 0-sized array type for dynamic shared memory usage,
/// or uses existing symbol.
///
/// Reuses an existing llvm.mlir.global only when its address space, zero
/// element count, and alignment all match; otherwise creates a fresh global
/// with a module-unique "__dynamic_shmem_<N>" name at the start of the module.
LLVM::GlobalOp getDynamicSharedMemorySymbol(
    ConversionPatternRewriter &rewriter, gpu::GPUModuleOp moduleOp,
    gpu::DynamicSharedMemoryOp op, const LLVMTypeConverter *typeConverter,
    MemRefType memrefType, unsigned alignmentBit) {
  // Convert the bit alignment to bytes, scaled by the element bit width.
  uint64_t alignmentByte = alignmentBit / memrefType.getElementTypeBitWidth();

  FailureOr<unsigned> addressSpace =
      typeConverter->getMemRefAddressSpace(type: memrefType);
  if (failed(Result: addressSpace)) {
    // NOTE(review): an error is emitted but the function does not bail out,
    // so addressSpace.value() below would assert on failure — confirm callers
    // only reach here with a convertible memory space.
    op->emitError() << "conversion of memref memory space "
                    << memrefType.getMemorySpace()
                    << " to integer address space "
                       "failed. Consider adding memory space conversions.";
  }

  // Step 1. Collect symbol names of LLVM::GlobalOp Ops. Also if any of
  // LLVM::GlobalOp is suitable for shared memory, return it.
  llvm::StringSet<> existingGlobalNames;
  for (auto globalOp : moduleOp.getBody()->getOps<LLVM::GlobalOp>()) {
    existingGlobalNames.insert(key: globalOp.getSymName());
    if (auto arrayType = dyn_cast<LLVM::LLVMArrayType>(Val: globalOp.getType())) {
      // Reusable only if address space, zero size, and alignment all match.
      if (globalOp.getAddrSpace() == addressSpace.value() &&
          arrayType.getNumElements() == 0 &&
          globalOp.getAlignment().value_or(u: 0) == alignmentByte) {
        return globalOp;
      }
    }
  }

  // Step 2. Find a unique symbol name
  unsigned uniquingCounter = 0;
  SmallString<128> symName = SymbolTable::generateSymbolName<128>(
      name: "__dynamic_shmem_" ,
      uniqueChecker: [&](StringRef candidate) {
        return existingGlobalNames.contains(key: candidate);
      },
      uniquingCounter);

  // Step 3. Generate a global op
  // Insert at the start of the module body so the symbol is visible to any
  // function in the module.
  OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPointToStart(moduleOp.getBody());

  // Zero-sized array: the actual size is supplied at kernel launch time.
  auto zeroSizedArrayType = LLVM::LLVMArrayType::get(
      elementType: typeConverter->convertType(t: memrefType.getElementType()), numElements: 0);

  return rewriter.create<LLVM::GlobalOp>(
      location: op->getLoc(), args&: zeroSizedArrayType, /*isConstant=*/args: false,
      args: LLVM::Linkage::Internal, args&: symName, /*value=*/args: Attribute(), args&: alignmentByte,
      args&: addressSpace.value());
}
| 713 | |
/// Lowers gpu.dynamic_shared_memory to an llvm.mlir.global of 0-sized array
/// type plus a memref descriptor wrapping its address.
LogicalResult GPUDynamicSharedMemoryOpLowering::matchAndRewrite(
    gpu::DynamicSharedMemoryOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Location loc = op.getLoc();
  MemRefType memrefType = op.getResultMemref().getType();
  Type elementType = typeConverter->convertType(t: memrefType.getElementType());

  // Step 1: Generate a memref<0xi8> type
  // Zero-sized because the real size is only known at kernel launch.
  MemRefLayoutAttrInterface layout = {};
  auto memrefType0sz =
      MemRefType::get(shape: {0}, elementType, layout, memorySpace: memrefType.getMemorySpace());

  // Step 2: Generate a global symbol or existing for the dynamic shared
  // memory with memref<0xi8> type
  auto moduleOp = op->getParentOfType<gpu::GPUModuleOp>();
  LLVM::GlobalOp shmemOp = getDynamicSharedMemorySymbol(
      rewriter, moduleOp, op, typeConverter: getTypeConverter(), memrefType: memrefType0sz, alignmentBit);

  // Step 3. Get address of the global symbol
  // Reset the insertion point: getDynamicSharedMemorySymbol may have moved it
  // to the module start.
  OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPoint(op);
  auto basePtr = rewriter.create<LLVM::AddressOfOp>(location: loc, args&: shmemOp);
  Type baseType = basePtr->getResultTypes().front();

  // Step 4. Generate GEP using offsets
  // Zero-offset GEP decays the global's array address to an element pointer.
  SmallVector<LLVM::GEPArg> gepArgs = {0};
  Value shmemPtr = rewriter.create<LLVM::GEPOp>(location: loc, args&: baseType, args&: elementType,
                                                args&: basePtr, args&: gepArgs);
  // Step 5. Create a memref descriptor
  SmallVector<Value> shape, strides;
  Value sizeBytes;
  getMemRefDescriptorSizes(loc, memRefType: memrefType0sz, dynamicSizes: {}, rewriter, sizes&: shape, strides,
                           size&: sizeBytes);
  auto memRefDescriptor = this->createMemRefDescriptor(
      loc, memRefType: memrefType0sz, allocatedPtr: shmemPtr, alignedPtr: shmemPtr, sizes: shape, strides, rewriter);

  // Step 6. Replace the op with memref descriptor
  rewriter.replaceOp(op, newValues: {memRefDescriptor});
  return success();
}
| 754 | |
| 755 | LogicalResult GPUReturnOpLowering::matchAndRewrite( |
| 756 | gpu::ReturnOp op, OpAdaptor adaptor, |
| 757 | ConversionPatternRewriter &rewriter) const { |
| 758 | Location loc = op.getLoc(); |
| 759 | unsigned numArguments = op.getNumOperands(); |
| 760 | SmallVector<Value, 4> updatedOperands; |
| 761 | |
| 762 | bool useBarePtrCallConv = getTypeConverter()->getOptions().useBarePtrCallConv; |
| 763 | if (useBarePtrCallConv) { |
| 764 | // For the bare-ptr calling convention, extract the aligned pointer to |
| 765 | // be returned from the memref descriptor. |
| 766 | for (auto it : llvm::zip(t: op->getOperands(), u: adaptor.getOperands())) { |
| 767 | Type oldTy = std::get<0>(t&: it).getType(); |
| 768 | Value newOperand = std::get<1>(t&: it); |
| 769 | if (isa<MemRefType>(Val: oldTy) && getTypeConverter()->canConvertToBarePtr( |
| 770 | type: cast<BaseMemRefType>(Val&: oldTy))) { |
| 771 | MemRefDescriptor memrefDesc(newOperand); |
| 772 | newOperand = memrefDesc.allocatedPtr(builder&: rewriter, loc); |
| 773 | } else if (isa<UnrankedMemRefType>(Val: oldTy)) { |
| 774 | // Unranked memref is not supported in the bare pointer calling |
| 775 | // convention. |
| 776 | return failure(); |
| 777 | } |
| 778 | updatedOperands.push_back(Elt: newOperand); |
| 779 | } |
| 780 | } else { |
| 781 | updatedOperands = llvm::to_vector<4>(Range: adaptor.getOperands()); |
| 782 | (void)copyUnrankedDescriptors(builder&: rewriter, loc, origTypes: op.getOperands().getTypes(), |
| 783 | operands&: updatedOperands, |
| 784 | /*toDynamic=*/true); |
| 785 | } |
| 786 | |
| 787 | // If ReturnOp has 0 or 1 operand, create it and return immediately. |
| 788 | if (numArguments <= 1) { |
| 789 | rewriter.replaceOpWithNewOp<LLVM::ReturnOp>( |
| 790 | op, args: TypeRange(), args&: updatedOperands, args: op->getAttrs()); |
| 791 | return success(); |
| 792 | } |
| 793 | |
| 794 | // Otherwise, we need to pack the arguments into an LLVM struct type before |
| 795 | // returning. |
| 796 | auto packedType = getTypeConverter()->packFunctionResults( |
| 797 | types: op.getOperandTypes(), useBarePointerCallConv: useBarePtrCallConv); |
| 798 | if (!packedType) { |
| 799 | return rewriter.notifyMatchFailure(arg&: op, msg: "could not convert result types" ); |
| 800 | } |
| 801 | |
| 802 | Value packed = rewriter.create<LLVM::PoisonOp>(location: loc, args&: packedType); |
| 803 | for (auto [idx, operand] : llvm::enumerate(First&: updatedOperands)) { |
| 804 | packed = rewriter.create<LLVM::InsertValueOp>(location: loc, args&: packed, args&: operand, args&: idx); |
| 805 | } |
| 806 | rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, args: TypeRange(), args&: packed, |
| 807 | args: op->getAttrs()); |
| 808 | return success(); |
| 809 | } |
| 810 | |
| 811 | void mlir::populateGpuMemorySpaceAttributeConversions( |
| 812 | TypeConverter &typeConverter, const MemorySpaceMapping &mapping) { |
| 813 | typeConverter.addTypeAttributeConversion( |
| 814 | callback: [mapping](BaseMemRefType type, gpu::AddressSpaceAttr memorySpaceAttr) { |
| 815 | gpu::AddressSpace memorySpace = memorySpaceAttr.getValue(); |
| 816 | unsigned addressSpace = mapping(memorySpace); |
| 817 | return wrapNumericMemorySpace(ctx: memorySpaceAttr.getContext(), |
| 818 | space: addressSpace); |
| 819 | }); |
| 820 | } |
| 821 | |