| 1 | //===- GPUOpsLowering.cpp - GPU FuncOp / ReturnOp lowering ----------------===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | |
| 9 | #include "GPUOpsLowering.h" |
| 10 | |
| 11 | #include "mlir/Conversion/GPUCommon/GPUCommonPass.h" |
| 12 | #include "mlir/Conversion/LLVMCommon/VectorPattern.h" |
| 13 | #include "mlir/Dialect/LLVMIR/LLVMDialect.h" |
| 14 | #include "mlir/IR/Attributes.h" |
| 15 | #include "mlir/IR/Builders.h" |
| 16 | #include "mlir/IR/BuiltinTypes.h" |
| 17 | #include "llvm/ADT/SmallVectorExtras.h" |
| 18 | #include "llvm/ADT/StringSet.h" |
| 19 | #include "llvm/Support/FormatVariadic.h" |
| 20 | |
| 21 | using namespace mlir; |
| 22 | |
| 23 | LLVM::LLVMFuncOp mlir::getOrDefineFunction(gpu::GPUModuleOp moduleOp, |
| 24 | Location loc, OpBuilder &b, |
| 25 | StringRef name, |
| 26 | LLVM::LLVMFunctionType type) { |
| 27 | LLVM::LLVMFuncOp ret; |
| 28 | if (!(ret = moduleOp.template lookupSymbol<LLVM::LLVMFuncOp>(name))) { |
| 29 | OpBuilder::InsertionGuard guard(b); |
| 30 | b.setInsertionPointToStart(moduleOp.getBody()); |
| 31 | ret = b.create<LLVM::LLVMFuncOp>(loc, name, type, LLVM::Linkage::External); |
| 32 | } |
| 33 | return ret; |
| 34 | } |
| 35 | |
| 36 | static SmallString<16> getUniqueSymbolName(gpu::GPUModuleOp moduleOp, |
| 37 | StringRef prefix) { |
| 38 | // Get a unique global name. |
| 39 | unsigned stringNumber = 0; |
| 40 | SmallString<16> stringConstName; |
| 41 | do { |
| 42 | stringConstName.clear(); |
| 43 | (prefix + Twine(stringNumber++)).toStringRef(Out&: stringConstName); |
| 44 | } while (moduleOp.lookupSymbol(stringConstName)); |
| 45 | return stringConstName; |
| 46 | } |
| 47 | |
| 48 | LLVM::GlobalOp |
| 49 | mlir::getOrCreateStringConstant(OpBuilder &b, Location loc, |
| 50 | gpu::GPUModuleOp moduleOp, Type llvmI8, |
| 51 | StringRef namePrefix, StringRef str, |
| 52 | uint64_t alignment, unsigned addrSpace) { |
| 53 | llvm::SmallString<20> nullTermStr(str); |
| 54 | nullTermStr.push_back(Elt: '\0'); // Null terminate for C |
| 55 | auto globalType = |
| 56 | LLVM::LLVMArrayType::get(llvmI8, nullTermStr.size_in_bytes()); |
| 57 | StringAttr attr = b.getStringAttr(nullTermStr); |
| 58 | |
| 59 | // Try to find existing global. |
| 60 | for (auto globalOp : moduleOp.getOps<LLVM::GlobalOp>()) |
| 61 | if (globalOp.getGlobalType() == globalType && globalOp.getConstant() && |
| 62 | globalOp.getValueAttr() == attr && |
| 63 | globalOp.getAlignment().value_or(0) == alignment && |
| 64 | globalOp.getAddrSpace() == addrSpace) |
| 65 | return globalOp; |
| 66 | |
| 67 | // Not found: create new global. |
| 68 | OpBuilder::InsertionGuard guard(b); |
| 69 | b.setInsertionPointToStart(moduleOp.getBody()); |
| 70 | SmallString<16> name = getUniqueSymbolName(moduleOp, namePrefix); |
| 71 | return b.create<LLVM::GlobalOp>(loc, globalType, |
| 72 | /*isConstant=*/true, LLVM::Linkage::Internal, |
| 73 | name, attr, alignment, addrSpace); |
| 74 | } |
| 75 | |
| 76 | LogicalResult |
| 77 | GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor, |
| 78 | ConversionPatternRewriter &rewriter) const { |
| 79 | Location loc = gpuFuncOp.getLoc(); |
| 80 | |
| 81 | SmallVector<LLVM::GlobalOp, 3> workgroupBuffers; |
| 82 | if (encodeWorkgroupAttributionsAsArguments) { |
| 83 | // Append an `llvm.ptr` argument to the function signature to encode |
| 84 | // workgroup attributions. |
| 85 | |
| 86 | ArrayRef<BlockArgument> workgroupAttributions = |
| 87 | gpuFuncOp.getWorkgroupAttributions(); |
| 88 | size_t numAttributions = workgroupAttributions.size(); |
| 89 | |
| 90 | // Insert all arguments at the end. |
| 91 | unsigned index = gpuFuncOp.getNumArguments(); |
| 92 | SmallVector<unsigned> argIndices(numAttributions, index); |
| 93 | |
| 94 | // New arguments will simply be `llvm.ptr` with the correct address space |
| 95 | Type workgroupPtrType = |
| 96 | rewriter.getType<LLVM::LLVMPointerType>(workgroupAddrSpace); |
| 97 | SmallVector<Type> argTypes(numAttributions, workgroupPtrType); |
| 98 | |
| 99 | // Attributes: noalias, llvm.mlir.workgroup_attribution(<size>, <type>) |
| 100 | std::array attrs{ |
| 101 | rewriter.getNamedAttr(LLVM::LLVMDialect::getNoAliasAttrName(), |
| 102 | rewriter.getUnitAttr()), |
| 103 | rewriter.getNamedAttr( |
| 104 | getDialect().getWorkgroupAttributionAttrHelper().getName(), |
| 105 | rewriter.getUnitAttr()), |
| 106 | }; |
| 107 | SmallVector<DictionaryAttr> argAttrs; |
| 108 | for (BlockArgument attribution : workgroupAttributions) { |
| 109 | auto attributionType = cast<MemRefType>(attribution.getType()); |
| 110 | IntegerAttr numElements = |
| 111 | rewriter.getI64IntegerAttr(attributionType.getNumElements()); |
| 112 | Type llvmElementType = |
| 113 | getTypeConverter()->convertType(attributionType.getElementType()); |
| 114 | if (!llvmElementType) |
| 115 | return failure(); |
| 116 | TypeAttr type = TypeAttr::get(llvmElementType); |
| 117 | attrs.back().setValue( |
| 118 | rewriter.getAttr<LLVM::WorkgroupAttributionAttr>(numElements, type)); |
| 119 | argAttrs.push_back(rewriter.getDictionaryAttr(attrs)); |
| 120 | } |
| 121 | |
| 122 | // Location match function location |
| 123 | SmallVector<Location> argLocs(numAttributions, gpuFuncOp.getLoc()); |
| 124 | |
| 125 | // Perform signature modification |
| 126 | rewriter.modifyOpInPlace( |
| 127 | gpuFuncOp, [gpuFuncOp, &argIndices, &argTypes, &argAttrs, &argLocs]() { |
| 128 | LogicalResult inserted = |
| 129 | static_cast<FunctionOpInterface>(gpuFuncOp).insertArguments( |
| 130 | argIndices, argTypes, argAttrs, argLocs); |
| 131 | (void)inserted; |
| 132 | assert(succeeded(inserted) && |
| 133 | "expected GPU funcs to support inserting any argument" ); |
| 134 | }); |
| 135 | } else { |
| 136 | workgroupBuffers.reserve(gpuFuncOp.getNumWorkgroupAttributions()); |
| 137 | for (auto [idx, attribution] : |
| 138 | llvm::enumerate(gpuFuncOp.getWorkgroupAttributions())) { |
| 139 | auto type = dyn_cast<MemRefType>(attribution.getType()); |
| 140 | assert(type && type.hasStaticShape() && "unexpected type in attribution" ); |
| 141 | |
| 142 | uint64_t numElements = type.getNumElements(); |
| 143 | |
| 144 | auto elementType = |
| 145 | cast<Type>(typeConverter->convertType(type.getElementType())); |
| 146 | auto arrayType = LLVM::LLVMArrayType::get(elementType, numElements); |
| 147 | std::string name = |
| 148 | std::string(llvm::formatv("__wg_{0}_{1}" , gpuFuncOp.getName(), idx)); |
| 149 | uint64_t alignment = 0; |
| 150 | if (auto alignAttr = dyn_cast_or_null<IntegerAttr>( |
| 151 | gpuFuncOp.getWorkgroupAttributionAttr( |
| 152 | idx, LLVM::LLVMDialect::getAlignAttrName()))) |
| 153 | alignment = alignAttr.getInt(); |
| 154 | auto globalOp = rewriter.create<LLVM::GlobalOp>( |
| 155 | gpuFuncOp.getLoc(), arrayType, /*isConstant=*/false, |
| 156 | LLVM::Linkage::Internal, name, /*value=*/Attribute(), alignment, |
| 157 | workgroupAddrSpace); |
| 158 | workgroupBuffers.push_back(globalOp); |
| 159 | } |
| 160 | } |
| 161 | |
| 162 | // Remap proper input types. |
| 163 | TypeConverter::SignatureConversion signatureConversion( |
| 164 | gpuFuncOp.front().getNumArguments()); |
| 165 | |
| 166 | Type funcType = getTypeConverter()->convertFunctionSignature( |
| 167 | gpuFuncOp.getFunctionType(), /*isVariadic=*/false, |
| 168 | getTypeConverter()->getOptions().useBarePtrCallConv, signatureConversion); |
| 169 | if (!funcType) { |
| 170 | return rewriter.notifyMatchFailure(gpuFuncOp, [&](Diagnostic &diag) { |
| 171 | diag << "failed to convert function signature type for: " |
| 172 | << gpuFuncOp.getFunctionType(); |
| 173 | }); |
| 174 | } |
| 175 | |
| 176 | // Create the new function operation. Only copy those attributes that are |
| 177 | // not specific to function modeling. |
| 178 | SmallVector<NamedAttribute, 4> attributes; |
| 179 | ArrayAttr argAttrs; |
| 180 | for (const auto &attr : gpuFuncOp->getAttrs()) { |
| 181 | if (attr.getName() == SymbolTable::getSymbolAttrName() || |
| 182 | attr.getName() == gpuFuncOp.getFunctionTypeAttrName() || |
| 183 | attr.getName() == |
| 184 | gpu::GPUFuncOp::getNumWorkgroupAttributionsAttrName() || |
| 185 | attr.getName() == gpuFuncOp.getWorkgroupAttribAttrsAttrName() || |
| 186 | attr.getName() == gpuFuncOp.getPrivateAttribAttrsAttrName() || |
| 187 | attr.getName() == gpuFuncOp.getKnownBlockSizeAttrName() || |
| 188 | attr.getName() == gpuFuncOp.getKnownGridSizeAttrName()) |
| 189 | continue; |
| 190 | if (attr.getName() == gpuFuncOp.getArgAttrsAttrName()) { |
| 191 | argAttrs = gpuFuncOp.getArgAttrsAttr(); |
| 192 | continue; |
| 193 | } |
| 194 | attributes.push_back(attr); |
| 195 | } |
| 196 | |
| 197 | DenseI32ArrayAttr knownBlockSize = gpuFuncOp.getKnownBlockSizeAttr(); |
| 198 | DenseI32ArrayAttr knownGridSize = gpuFuncOp.getKnownGridSizeAttr(); |
| 199 | // Ensure we don't lose information if the function is lowered before its |
| 200 | // surrounding context. |
| 201 | auto *gpuDialect = cast<gpu::GPUDialect>(gpuFuncOp->getDialect()); |
| 202 | if (knownBlockSize) |
| 203 | attributes.emplace_back(gpuDialect->getKnownBlockSizeAttrHelper().getName(), |
| 204 | knownBlockSize); |
| 205 | if (knownGridSize) |
| 206 | attributes.emplace_back(gpuDialect->getKnownGridSizeAttrHelper().getName(), |
| 207 | knownGridSize); |
| 208 | |
| 209 | // Add a dialect specific kernel attribute in addition to GPU kernel |
| 210 | // attribute. The former is necessary for further translation while the |
| 211 | // latter is expected by gpu.launch_func. |
| 212 | if (gpuFuncOp.isKernel()) { |
| 213 | if (kernelAttributeName) |
| 214 | attributes.emplace_back(kernelAttributeName, rewriter.getUnitAttr()); |
| 215 | // Set the dialect-specific block size attribute if there is one. |
| 216 | if (kernelBlockSizeAttributeName && knownBlockSize) { |
| 217 | attributes.emplace_back(kernelBlockSizeAttributeName, knownBlockSize); |
| 218 | } |
| 219 | } |
| 220 | LLVM::CConv callingConvention = gpuFuncOp.isKernel() |
| 221 | ? kernelCallingConvention |
| 222 | : nonKernelCallingConvention; |
| 223 | auto llvmFuncOp = rewriter.create<LLVM::LLVMFuncOp>( |
| 224 | gpuFuncOp.getLoc(), gpuFuncOp.getName(), funcType, |
| 225 | LLVM::Linkage::External, /*dsoLocal=*/false, callingConvention, |
| 226 | /*comdat=*/nullptr, attributes); |
| 227 | |
| 228 | { |
| 229 | // Insert operations that correspond to converted workgroup and private |
| 230 | // memory attributions to the body of the function. This must operate on |
| 231 | // the original function, before the body region is inlined in the new |
| 232 | // function to maintain the relation between block arguments and the |
| 233 | // parent operation that assigns their semantics. |
| 234 | OpBuilder::InsertionGuard guard(rewriter); |
| 235 | |
| 236 | // Rewrite workgroup memory attributions to addresses of global buffers. |
| 237 | rewriter.setInsertionPointToStart(&gpuFuncOp.front()); |
| 238 | unsigned numProperArguments = gpuFuncOp.getNumArguments(); |
| 239 | |
| 240 | if (encodeWorkgroupAttributionsAsArguments) { |
| 241 | // Build a MemRefDescriptor with each of the arguments added above. |
| 242 | |
| 243 | unsigned numAttributions = gpuFuncOp.getNumWorkgroupAttributions(); |
| 244 | assert(numProperArguments >= numAttributions && |
| 245 | "Expecting attributions to be encoded as arguments already" ); |
| 246 | |
| 247 | // Arguments encoding workgroup attributions will be in positions |
| 248 | // [numProperArguments, numProperArguments+numAttributions) |
| 249 | ArrayRef<BlockArgument> attributionArguments = |
| 250 | gpuFuncOp.getArguments().slice(numProperArguments - numAttributions, |
| 251 | numAttributions); |
| 252 | for (auto [idx, vals] : llvm::enumerate(llvm::zip_equal( |
| 253 | gpuFuncOp.getWorkgroupAttributions(), attributionArguments))) { |
| 254 | auto [attribution, arg] = vals; |
| 255 | auto type = cast<MemRefType>(attribution.getType()); |
| 256 | |
| 257 | // Arguments are of llvm.ptr type and attributions are of memref type: |
| 258 | // we need to wrap them in memref descriptors. |
| 259 | Value descr = MemRefDescriptor::fromStaticShape( |
| 260 | rewriter, loc, *getTypeConverter(), type, arg); |
| 261 | |
| 262 | // And remap the arguments |
| 263 | signatureConversion.remapInput(numProperArguments + idx, descr); |
| 264 | } |
| 265 | } else { |
| 266 | for (const auto [idx, global] : llvm::enumerate(workgroupBuffers)) { |
| 267 | auto ptrType = LLVM::LLVMPointerType::get(rewriter.getContext(), |
| 268 | global.getAddrSpace()); |
| 269 | Value address = rewriter.create<LLVM::AddressOfOp>( |
| 270 | loc, ptrType, global.getSymNameAttr()); |
| 271 | Value memory = |
| 272 | rewriter.create<LLVM::GEPOp>(loc, ptrType, global.getType(), |
| 273 | address, ArrayRef<LLVM::GEPArg>{0, 0}); |
| 274 | |
| 275 | // Build a memref descriptor pointing to the buffer to plug with the |
| 276 | // existing memref infrastructure. This may use more registers than |
| 277 | // otherwise necessary given that memref sizes are fixed, but we can try |
| 278 | // and canonicalize that away later. |
| 279 | Value attribution = gpuFuncOp.getWorkgroupAttributions()[idx]; |
| 280 | auto type = cast<MemRefType>(attribution.getType()); |
| 281 | Value descr = MemRefDescriptor::fromStaticShape( |
| 282 | rewriter, loc, *getTypeConverter(), type, memory); |
| 283 | signatureConversion.remapInput(numProperArguments + idx, descr); |
| 284 | } |
| 285 | } |
| 286 | |
| 287 | // Rewrite private memory attributions to alloca'ed buffers. |
| 288 | unsigned numWorkgroupAttributions = gpuFuncOp.getNumWorkgroupAttributions(); |
| 289 | auto int64Ty = IntegerType::get(rewriter.getContext(), 64); |
| 290 | for (const auto [idx, attribution] : |
| 291 | llvm::enumerate(gpuFuncOp.getPrivateAttributions())) { |
| 292 | auto type = cast<MemRefType>(attribution.getType()); |
| 293 | assert(type && type.hasStaticShape() && "unexpected type in attribution" ); |
| 294 | |
| 295 | // Explicitly drop memory space when lowering private memory |
| 296 | // attributions since NVVM models it as `alloca`s in the default |
| 297 | // memory space and does not support `alloca`s with addrspace(5). |
| 298 | Type elementType = typeConverter->convertType(type.getElementType()); |
| 299 | auto ptrType = |
| 300 | LLVM::LLVMPointerType::get(rewriter.getContext(), allocaAddrSpace); |
| 301 | Value numElements = rewriter.create<LLVM::ConstantOp>( |
| 302 | gpuFuncOp.getLoc(), int64Ty, type.getNumElements()); |
| 303 | uint64_t alignment = 0; |
| 304 | if (auto alignAttr = |
| 305 | dyn_cast_or_null<IntegerAttr>(gpuFuncOp.getPrivateAttributionAttr( |
| 306 | idx, LLVM::LLVMDialect::getAlignAttrName()))) |
| 307 | alignment = alignAttr.getInt(); |
| 308 | Value allocated = rewriter.create<LLVM::AllocaOp>( |
| 309 | gpuFuncOp.getLoc(), ptrType, elementType, numElements, alignment); |
| 310 | Value descr = MemRefDescriptor::fromStaticShape( |
| 311 | rewriter, loc, *getTypeConverter(), type, allocated); |
| 312 | signatureConversion.remapInput( |
| 313 | numProperArguments + numWorkgroupAttributions + idx, descr); |
| 314 | } |
| 315 | } |
| 316 | |
| 317 | // Move the region to the new function, update the entry block signature. |
| 318 | rewriter.inlineRegionBefore(gpuFuncOp.getBody(), llvmFuncOp.getBody(), |
| 319 | llvmFuncOp.end()); |
| 320 | if (failed(rewriter.convertRegionTypes(&llvmFuncOp.getBody(), *typeConverter, |
| 321 | &signatureConversion))) |
| 322 | return failure(); |
| 323 | |
| 324 | // Get memref type from function arguments and set the noalias to |
| 325 | // pointer arguments. |
| 326 | for (const auto [idx, argTy] : |
| 327 | llvm::enumerate(gpuFuncOp.getArgumentTypes())) { |
| 328 | auto remapping = signatureConversion.getInputMapping(idx); |
| 329 | NamedAttrList argAttr = |
| 330 | argAttrs ? cast<DictionaryAttr>(argAttrs[idx]) : NamedAttrList(); |
| 331 | auto copyAttribute = [&](StringRef attrName) { |
| 332 | Attribute attr = argAttr.erase(attrName); |
| 333 | if (!attr) |
| 334 | return; |
| 335 | for (size_t i = 0, e = remapping->size; i < e; ++i) |
| 336 | llvmFuncOp.setArgAttr(remapping->inputNo + i, attrName, attr); |
| 337 | }; |
| 338 | auto copyPointerAttribute = [&](StringRef attrName) { |
| 339 | Attribute attr = argAttr.erase(attrName); |
| 340 | |
| 341 | if (!attr) |
| 342 | return; |
| 343 | if (remapping->size > 1 && |
| 344 | attrName == LLVM::LLVMDialect::getNoAliasAttrName()) { |
| 345 | emitWarning(llvmFuncOp.getLoc(), |
| 346 | "Cannot copy noalias with non-bare pointers.\n" ); |
| 347 | return; |
| 348 | } |
| 349 | for (size_t i = 0, e = remapping->size; i < e; ++i) { |
| 350 | if (isa<LLVM::LLVMPointerType>( |
| 351 | llvmFuncOp.getArgument(remapping->inputNo + i).getType())) { |
| 352 | llvmFuncOp.setArgAttr(remapping->inputNo + i, attrName, attr); |
| 353 | } |
| 354 | } |
| 355 | }; |
| 356 | |
| 357 | if (argAttr.empty()) |
| 358 | continue; |
| 359 | |
| 360 | copyAttribute(LLVM::LLVMDialect::getReturnedAttrName()); |
| 361 | copyAttribute(LLVM::LLVMDialect::getNoUndefAttrName()); |
| 362 | copyAttribute(LLVM::LLVMDialect::getInRegAttrName()); |
| 363 | bool lowersToPointer = false; |
| 364 | for (size_t i = 0, e = remapping->size; i < e; ++i) { |
| 365 | lowersToPointer |= isa<LLVM::LLVMPointerType>( |
| 366 | llvmFuncOp.getArgument(remapping->inputNo + i).getType()); |
| 367 | } |
| 368 | |
| 369 | if (lowersToPointer) { |
| 370 | copyPointerAttribute(LLVM::LLVMDialect::getNoAliasAttrName()); |
| 371 | copyPointerAttribute(LLVM::LLVMDialect::getNoCaptureAttrName()); |
| 372 | copyPointerAttribute(LLVM::LLVMDialect::getNoFreeAttrName()); |
| 373 | copyPointerAttribute(LLVM::LLVMDialect::getAlignAttrName()); |
| 374 | copyPointerAttribute(LLVM::LLVMDialect::getReadonlyAttrName()); |
| 375 | copyPointerAttribute(LLVM::LLVMDialect::getWriteOnlyAttrName()); |
| 376 | copyPointerAttribute(LLVM::LLVMDialect::getReadnoneAttrName()); |
| 377 | copyPointerAttribute(LLVM::LLVMDialect::getNonNullAttrName()); |
| 378 | copyPointerAttribute(LLVM::LLVMDialect::getDereferenceableAttrName()); |
| 379 | copyPointerAttribute( |
| 380 | LLVM::LLVMDialect::getDereferenceableOrNullAttrName()); |
| 381 | copyPointerAttribute( |
| 382 | LLVM::LLVMDialect::WorkgroupAttributionAttrHelper::getNameStr()); |
| 383 | } |
| 384 | } |
| 385 | rewriter.eraseOp(op: gpuFuncOp); |
| 386 | return success(); |
| 387 | } |
| 388 | |
| 389 | LogicalResult GPUPrintfOpToHIPLowering::matchAndRewrite( |
| 390 | gpu::PrintfOp gpuPrintfOp, gpu::PrintfOpAdaptor adaptor, |
| 391 | ConversionPatternRewriter &rewriter) const { |
| 392 | Location loc = gpuPrintfOp->getLoc(); |
| 393 | |
| 394 | mlir::Type llvmI8 = typeConverter->convertType(rewriter.getI8Type()); |
| 395 | auto ptrType = LLVM::LLVMPointerType::get(rewriter.getContext()); |
| 396 | mlir::Type llvmI32 = typeConverter->convertType(rewriter.getI32Type()); |
| 397 | mlir::Type llvmI64 = typeConverter->convertType(rewriter.getI64Type()); |
| 398 | // Note: this is the GPUModule op, not the ModuleOp that surrounds it |
| 399 | // This ensures that global constants and declarations are placed within |
| 400 | // the device code, not the host code |
| 401 | auto moduleOp = gpuPrintfOp->getParentOfType<gpu::GPUModuleOp>(); |
| 402 | |
| 403 | auto ocklBegin = |
| 404 | getOrDefineFunction(moduleOp, loc, rewriter, "__ockl_printf_begin" , |
| 405 | LLVM::LLVMFunctionType::get(llvmI64, {llvmI64})); |
| 406 | LLVM::LLVMFuncOp ocklAppendArgs; |
| 407 | if (!adaptor.getArgs().empty()) { |
| 408 | ocklAppendArgs = getOrDefineFunction( |
| 409 | moduleOp, loc, rewriter, "__ockl_printf_append_args" , |
| 410 | LLVM::LLVMFunctionType::get( |
| 411 | llvmI64, {llvmI64, /*numArgs*/ llvmI32, llvmI64, llvmI64, llvmI64, |
| 412 | llvmI64, llvmI64, llvmI64, llvmI64, /*isLast*/ llvmI32})); |
| 413 | } |
| 414 | auto ocklAppendStringN = getOrDefineFunction( |
| 415 | moduleOp, loc, rewriter, "__ockl_printf_append_string_n" , |
| 416 | LLVM::LLVMFunctionType::get( |
| 417 | llvmI64, |
| 418 | {llvmI64, ptrType, /*length (bytes)*/ llvmI64, /*isLast*/ llvmI32})); |
| 419 | |
| 420 | /// Start the printf hostcall |
| 421 | Value zeroI64 = rewriter.create<LLVM::ConstantOp>(loc, llvmI64, 0); |
| 422 | auto printfBeginCall = rewriter.create<LLVM::CallOp>(loc, ocklBegin, zeroI64); |
| 423 | Value printfDesc = printfBeginCall.getResult(); |
| 424 | |
| 425 | // Create the global op or find an existing one. |
| 426 | LLVM::GlobalOp global = getOrCreateStringConstant( |
| 427 | rewriter, loc, moduleOp, llvmI8, "printfFormat_" , adaptor.getFormat()); |
| 428 | |
| 429 | // Get a pointer to the format string's first element and pass it to printf() |
| 430 | Value globalPtr = rewriter.create<LLVM::AddressOfOp>( |
| 431 | loc, |
| 432 | LLVM::LLVMPointerType::get(rewriter.getContext(), global.getAddrSpace()), |
| 433 | global.getSymNameAttr()); |
| 434 | Value stringStart = |
| 435 | rewriter.create<LLVM::GEPOp>(loc, ptrType, global.getGlobalType(), |
| 436 | globalPtr, ArrayRef<LLVM::GEPArg>{0, 0}); |
| 437 | Value stringLen = rewriter.create<LLVM::ConstantOp>( |
| 438 | loc, llvmI64, cast<StringAttr>(global.getValueAttr()).size()); |
| 439 | |
| 440 | Value oneI32 = rewriter.create<LLVM::ConstantOp>(loc, llvmI32, 1); |
| 441 | Value zeroI32 = rewriter.create<LLVM::ConstantOp>(loc, llvmI32, 0); |
| 442 | |
| 443 | auto appendFormatCall = rewriter.create<LLVM::CallOp>( |
| 444 | loc, ocklAppendStringN, |
| 445 | ValueRange{printfDesc, stringStart, stringLen, |
| 446 | adaptor.getArgs().empty() ? oneI32 : zeroI32}); |
| 447 | printfDesc = appendFormatCall.getResult(); |
| 448 | |
| 449 | // __ockl_printf_append_args takes 7 values per append call |
| 450 | constexpr size_t argsPerAppend = 7; |
| 451 | size_t nArgs = adaptor.getArgs().size(); |
| 452 | for (size_t group = 0; group < nArgs; group += argsPerAppend) { |
| 453 | size_t bound = std::min(a: group + argsPerAppend, b: nArgs); |
| 454 | size_t numArgsThisCall = bound - group; |
| 455 | |
| 456 | SmallVector<mlir::Value, 2 + argsPerAppend + 1> arguments; |
| 457 | arguments.push_back(Elt: printfDesc); |
| 458 | arguments.push_back( |
| 459 | rewriter.create<LLVM::ConstantOp>(loc, llvmI32, numArgsThisCall)); |
| 460 | for (size_t i = group; i < bound; ++i) { |
| 461 | Value arg = adaptor.getArgs()[i]; |
| 462 | if (auto floatType = dyn_cast<FloatType>(arg.getType())) { |
| 463 | if (!floatType.isF64()) |
| 464 | arg = rewriter.create<LLVM::FPExtOp>( |
| 465 | loc, typeConverter->convertType(rewriter.getF64Type()), arg); |
| 466 | arg = rewriter.create<LLVM::BitcastOp>(loc, llvmI64, arg); |
| 467 | } |
| 468 | if (arg.getType().getIntOrFloatBitWidth() != 64) |
| 469 | arg = rewriter.create<LLVM::ZExtOp>(loc, llvmI64, arg); |
| 470 | |
| 471 | arguments.push_back(Elt: arg); |
| 472 | } |
| 473 | // Pad out to 7 arguments since the hostcall always needs 7 |
| 474 | for (size_t = numArgsThisCall; extra < argsPerAppend; ++extra) { |
| 475 | arguments.push_back(Elt: zeroI64); |
| 476 | } |
| 477 | |
| 478 | auto isLast = (bound == nArgs) ? oneI32 : zeroI32; |
| 479 | arguments.push_back(Elt: isLast); |
| 480 | auto call = rewriter.create<LLVM::CallOp>(loc, ocklAppendArgs, arguments); |
| 481 | printfDesc = call.getResult(); |
| 482 | } |
| 483 | rewriter.eraseOp(op: gpuPrintfOp); |
| 484 | return success(); |
| 485 | } |
| 486 | |
| 487 | LogicalResult GPUPrintfOpToLLVMCallLowering::matchAndRewrite( |
| 488 | gpu::PrintfOp gpuPrintfOp, gpu::PrintfOpAdaptor adaptor, |
| 489 | ConversionPatternRewriter &rewriter) const { |
| 490 | Location loc = gpuPrintfOp->getLoc(); |
| 491 | |
| 492 | mlir::Type llvmI8 = typeConverter->convertType(rewriter.getIntegerType(8)); |
| 493 | mlir::Type ptrType = |
| 494 | LLVM::LLVMPointerType::get(rewriter.getContext(), addressSpace); |
| 495 | |
| 496 | // Note: this is the GPUModule op, not the ModuleOp that surrounds it |
| 497 | // This ensures that global constants and declarations are placed within |
| 498 | // the device code, not the host code |
| 499 | auto moduleOp = gpuPrintfOp->getParentOfType<gpu::GPUModuleOp>(); |
| 500 | |
| 501 | auto printfType = |
| 502 | LLVM::LLVMFunctionType::get(rewriter.getI32Type(), {ptrType}, |
| 503 | /*isVarArg=*/true); |
| 504 | LLVM::LLVMFuncOp printfDecl = |
| 505 | getOrDefineFunction(moduleOp, loc, rewriter, "printf" , printfType); |
| 506 | |
| 507 | // Create the global op or find an existing one. |
| 508 | LLVM::GlobalOp global = getOrCreateStringConstant( |
| 509 | rewriter, loc, moduleOp, llvmI8, "printfFormat_" , adaptor.getFormat(), |
| 510 | /*alignment=*/0, addressSpace); |
| 511 | |
| 512 | // Get a pointer to the format string's first element |
| 513 | Value globalPtr = rewriter.create<LLVM::AddressOfOp>( |
| 514 | loc, |
| 515 | LLVM::LLVMPointerType::get(rewriter.getContext(), global.getAddrSpace()), |
| 516 | global.getSymNameAttr()); |
| 517 | Value stringStart = |
| 518 | rewriter.create<LLVM::GEPOp>(loc, ptrType, global.getGlobalType(), |
| 519 | globalPtr, ArrayRef<LLVM::GEPArg>{0, 0}); |
| 520 | |
| 521 | // Construct arguments and function call |
| 522 | auto argsRange = adaptor.getArgs(); |
| 523 | SmallVector<Value, 4> printfArgs; |
| 524 | printfArgs.reserve(N: argsRange.size() + 1); |
| 525 | printfArgs.push_back(Elt: stringStart); |
| 526 | printfArgs.append(argsRange.begin(), argsRange.end()); |
| 527 | |
| 528 | rewriter.create<LLVM::CallOp>(loc, printfDecl, printfArgs); |
| 529 | rewriter.eraseOp(op: gpuPrintfOp); |
| 530 | return success(); |
| 531 | } |
| 532 | |
| 533 | LogicalResult GPUPrintfOpToVPrintfLowering::matchAndRewrite( |
| 534 | gpu::PrintfOp gpuPrintfOp, gpu::PrintfOpAdaptor adaptor, |
| 535 | ConversionPatternRewriter &rewriter) const { |
| 536 | Location loc = gpuPrintfOp->getLoc(); |
| 537 | |
| 538 | mlir::Type llvmI8 = typeConverter->convertType(rewriter.getIntegerType(8)); |
| 539 | mlir::Type ptrType = LLVM::LLVMPointerType::get(rewriter.getContext()); |
| 540 | |
| 541 | // Note: this is the GPUModule op, not the ModuleOp that surrounds it |
| 542 | // This ensures that global constants and declarations are placed within |
| 543 | // the device code, not the host code |
| 544 | auto moduleOp = gpuPrintfOp->getParentOfType<gpu::GPUModuleOp>(); |
| 545 | |
| 546 | // Create a valid global location removing any metadata attached to the |
| 547 | // location as debug info metadata inside of a function cannot be used outside |
| 548 | // of that function. |
| 549 | Location globalLoc = loc->findInstanceOfOrUnknown<FileLineColLoc>(); |
| 550 | |
| 551 | auto vprintfType = |
| 552 | LLVM::LLVMFunctionType::get(rewriter.getI32Type(), {ptrType, ptrType}); |
| 553 | LLVM::LLVMFuncOp vprintfDecl = getOrDefineFunction( |
| 554 | moduleOp, globalLoc, rewriter, "vprintf" , vprintfType); |
| 555 | |
| 556 | // Create the global op or find an existing one. |
| 557 | LLVM::GlobalOp global = |
| 558 | getOrCreateStringConstant(rewriter, globalLoc, moduleOp, llvmI8, |
| 559 | "printfFormat_" , adaptor.getFormat()); |
| 560 | |
| 561 | // Get a pointer to the format string's first element |
| 562 | Value globalPtr = rewriter.create<LLVM::AddressOfOp>(loc, global); |
| 563 | Value stringStart = |
| 564 | rewriter.create<LLVM::GEPOp>(loc, ptrType, global.getGlobalType(), |
| 565 | globalPtr, ArrayRef<LLVM::GEPArg>{0, 0}); |
| 566 | SmallVector<Type> types; |
| 567 | SmallVector<Value> args; |
| 568 | // Promote and pack the arguments into a stack allocation. |
| 569 | for (Value arg : adaptor.getArgs()) { |
| 570 | Type type = arg.getType(); |
| 571 | Value promotedArg = arg; |
| 572 | assert(type.isIntOrFloat()); |
| 573 | if (isa<FloatType>(type)) { |
| 574 | type = rewriter.getF64Type(); |
| 575 | promotedArg = rewriter.create<LLVM::FPExtOp>(loc, type, arg); |
| 576 | } |
| 577 | types.push_back(type); |
| 578 | args.push_back(promotedArg); |
| 579 | } |
| 580 | Type structType = |
| 581 | LLVM::LLVMStructType::getLiteral(gpuPrintfOp.getContext(), types); |
| 582 | Value one = rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI64Type(), |
| 583 | rewriter.getIndexAttr(1)); |
| 584 | Value tempAlloc = |
| 585 | rewriter.create<LLVM::AllocaOp>(loc, ptrType, structType, one, |
| 586 | /*alignment=*/0); |
| 587 | for (auto [index, arg] : llvm::enumerate(First&: args)) { |
| 588 | Value ptr = rewriter.create<LLVM::GEPOp>( |
| 589 | loc, ptrType, structType, tempAlloc, |
| 590 | ArrayRef<LLVM::GEPArg>{0, static_cast<int32_t>(index)}); |
| 591 | rewriter.create<LLVM::StoreOp>(loc, arg, ptr); |
| 592 | } |
| 593 | std::array<Value, 2> printfArgs = {stringStart, tempAlloc}; |
| 594 | |
| 595 | rewriter.create<LLVM::CallOp>(loc, vprintfDecl, printfArgs); |
| 596 | rewriter.eraseOp(op: gpuPrintfOp); |
| 597 | return success(); |
| 598 | } |
| 599 | |
| 600 | /// Helper for impl::scalarizeVectorOp. Scalarizes vectors to elements. |
| 601 | /// Used either directly (for ops on 1D vectors) or as the callback passed to |
| 602 | /// detail::handleMultidimensionalVectors (for ops on higher-rank vectors). |
| 603 | static Value scalarizeVectorOpHelper(Operation *op, ValueRange operands, |
| 604 | Type llvm1DVectorTy, |
| 605 | ConversionPatternRewriter &rewriter, |
| 606 | const LLVMTypeConverter &converter) { |
| 607 | TypeRange operandTypes(operands); |
| 608 | VectorType vectorType = cast<VectorType>(llvm1DVectorTy); |
| 609 | Location loc = op->getLoc(); |
| 610 | Value result = rewriter.create<LLVM::PoisonOp>(loc, vectorType); |
| 611 | Type indexType = converter.convertType(rewriter.getIndexType()); |
| 612 | StringAttr name = op->getName().getIdentifier(); |
| 613 | Type elementType = vectorType.getElementType(); |
| 614 | |
| 615 | for (int64_t i = 0; i < vectorType.getNumElements(); ++i) { |
| 616 | Value index = rewriter.create<LLVM::ConstantOp>(loc, indexType, i); |
| 617 | auto = [&](Value operand) -> Value { |
| 618 | if (!isa<VectorType>(Val: operand.getType())) |
| 619 | return operand; |
| 620 | return rewriter.create<LLVM::ExtractElementOp>(loc, operand, index); |
| 621 | }; |
| 622 | auto scalarOperands = llvm::map_to_vector(C&: operands, F&: extractElement); |
| 623 | Operation *scalarOp = |
| 624 | rewriter.create(loc, name, scalarOperands, elementType, op->getAttrs()); |
| 625 | result = rewriter.create<LLVM::InsertElementOp>( |
| 626 | loc, result, scalarOp->getResult(0), index); |
| 627 | } |
| 628 | return result; |
| 629 | } |
| 630 | |
| 631 | /// Unrolls op to array/vector elements. |
| 632 | LogicalResult impl::scalarizeVectorOp(Operation *op, ValueRange operands, |
| 633 | ConversionPatternRewriter &rewriter, |
| 634 | const LLVMTypeConverter &converter) { |
| 635 | TypeRange operandTypes(operands); |
| 636 | if (llvm::any_of(Range&: operandTypes, P: llvm::IsaPred<VectorType>)) { |
| 637 | VectorType vectorType = |
| 638 | cast<VectorType>(converter.convertType(op->getResultTypes()[0])); |
| 639 | rewriter.replaceOp(op, scalarizeVectorOpHelper(op, operands, vectorType, |
| 640 | rewriter, converter)); |
| 641 | return success(); |
| 642 | } |
| 643 | |
| 644 | if (llvm::any_of(operandTypes, llvm::IsaPred<LLVM::LLVMArrayType>)) { |
| 645 | return LLVM::detail::handleMultidimensionalVectors( |
| 646 | op, operands, typeConverter: converter, |
| 647 | createOperand: [&](Type llvm1DVectorTy, ValueRange operands) -> Value { |
| 648 | return scalarizeVectorOpHelper(op, operands, llvm1DVectorTy, rewriter, |
| 649 | converter); |
| 650 | }, |
| 651 | rewriter); |
| 652 | } |
| 653 | |
| 654 | return rewriter.notifyMatchFailure(arg&: op, msg: "no llvm.array or vector to unroll" ); |
| 655 | } |
| 656 | |
| 657 | static IntegerAttr wrapNumericMemorySpace(MLIRContext *ctx, unsigned space) { |
| 658 | return IntegerAttr::get(IntegerType::get(ctx, 64), space); |
| 659 | } |
| 660 | |
| 661 | /// Generates a symbol with 0-sized array type for dynamic shared memory usage, |
| 662 | /// or uses existing symbol. |
| 663 | LLVM::GlobalOp getDynamicSharedMemorySymbol( |
| 664 | ConversionPatternRewriter &rewriter, gpu::GPUModuleOp moduleOp, |
| 665 | gpu::DynamicSharedMemoryOp op, const LLVMTypeConverter *typeConverter, |
| 666 | MemRefType memrefType, unsigned alignmentBit) { |
| 667 | uint64_t alignmentByte = alignmentBit / memrefType.getElementTypeBitWidth(); |
| 668 | |
| 669 | FailureOr<unsigned> addressSpace = |
| 670 | typeConverter->getMemRefAddressSpace(type: memrefType); |
| 671 | if (failed(Result: addressSpace)) { |
| 672 | op->emitError() << "conversion of memref memory space " |
| 673 | << memrefType.getMemorySpace() |
| 674 | << " to integer address space " |
| 675 | "failed. Consider adding memory space conversions." ; |
| 676 | } |
| 677 | |
| 678 | // Step 1. Collect symbol names of LLVM::GlobalOp Ops. Also if any of |
| 679 | // LLVM::GlobalOp is suitable for shared memory, return it. |
| 680 | llvm::StringSet<> existingGlobalNames; |
| 681 | for (auto globalOp : moduleOp.getBody()->getOps<LLVM::GlobalOp>()) { |
| 682 | existingGlobalNames.insert(globalOp.getSymName()); |
| 683 | if (auto arrayType = dyn_cast<LLVM::LLVMArrayType>(globalOp.getType())) { |
| 684 | if (globalOp.getAddrSpace() == addressSpace.value() && |
| 685 | arrayType.getNumElements() == 0 && |
| 686 | globalOp.getAlignment().value_or(0) == alignmentByte) { |
| 687 | return globalOp; |
| 688 | } |
| 689 | } |
| 690 | } |
| 691 | |
| 692 | // Step 2. Find a unique symbol name |
| 693 | unsigned uniquingCounter = 0; |
| 694 | SmallString<128> symName = SymbolTable::generateSymbolName<128>( |
| 695 | name: "__dynamic_shmem_" , |
| 696 | uniqueChecker: [&](StringRef candidate) { |
| 697 | return existingGlobalNames.contains(key: candidate); |
| 698 | }, |
| 699 | uniquingCounter); |
| 700 | |
| 701 | // Step 3. Generate a global op |
| 702 | OpBuilder::InsertionGuard guard(rewriter); |
| 703 | rewriter.setInsertionPointToStart(moduleOp.getBody()); |
| 704 | |
| 705 | auto zeroSizedArrayType = LLVM::LLVMArrayType::get( |
| 706 | typeConverter->convertType(memrefType.getElementType()), 0); |
| 707 | |
| 708 | return rewriter.create<LLVM::GlobalOp>( |
| 709 | op->getLoc(), zeroSizedArrayType, /*isConstant=*/false, |
| 710 | LLVM::Linkage::Internal, symName, /*value=*/Attribute(), alignmentByte, |
| 711 | addressSpace.value()); |
| 712 | } |
| 713 | |
/// Lowers gpu.dynamic_shared_memory to an llvm.mlir.global zero-sized array
/// plus a memref descriptor pointing at its base address.
LogicalResult GPUDynamicSharedMemoryOpLowering::matchAndRewrite(
    gpu::DynamicSharedMemoryOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Location loc = op.getLoc();
  MemRefType memrefType = op.getResultMemref().getType();
  Type elementType = typeConverter->convertType(memrefType.getElementType());

  // Step 1: Generate a memref<0xi8> type
  MemRefLayoutAttrInterface layout = {};
  auto memrefType0sz =
      MemRefType::get({0}, elementType, layout, memrefType.getMemorySpace());

  // Step 2: Generate a global symbol or existing for the dynamic shared
  // memory with memref<0xi8> type
  auto moduleOp = op->getParentOfType<gpu::GPUModuleOp>();
  LLVM::GlobalOp shmemOp = getDynamicSharedMemorySymbol(
      rewriter, moduleOp, op, getTypeConverter(), memrefType0sz, alignmentBit);

  // Step 3. Get address of the global symbol
  // (getDynamicSharedMemorySymbol may have moved the insertion point to the
  // module start; restore it to the op being rewritten.)
  OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPoint(op);
  auto basePtr = rewriter.create<LLVM::AddressOfOp>(loc, shmemOp);
  Type baseType = basePtr->getResultTypes().front();

  // Step 4. Generate GEP using offsets
  SmallVector<LLVM::GEPArg> gepArgs = {0};
  Value shmemPtr = rewriter.create<LLVM::GEPOp>(loc, baseType, elementType,
                                                basePtr, gepArgs);
  // Step 5. Create a memref descriptor (shape/strides for the 0-sized type;
  // the same pointer serves as both allocated and aligned pointer)
  SmallVector<Value> shape, strides;
  Value sizeBytes;
  getMemRefDescriptorSizes(loc, memrefType0sz, {}, rewriter, shape, strides,
                           sizeBytes);
  auto memRefDescriptor = this->createMemRefDescriptor(
      loc, memrefType0sz, shmemPtr, shmemPtr, shape, strides, rewriter);

  // Step 6. Replace the op with memref descriptor
  rewriter.replaceOp(op, {memRefDescriptor});
  return success();
}
| 754 | |
| 755 | LogicalResult GPUReturnOpLowering::matchAndRewrite( |
| 756 | gpu::ReturnOp op, OpAdaptor adaptor, |
| 757 | ConversionPatternRewriter &rewriter) const { |
| 758 | Location loc = op.getLoc(); |
| 759 | unsigned numArguments = op.getNumOperands(); |
| 760 | SmallVector<Value, 4> updatedOperands; |
| 761 | |
| 762 | bool useBarePtrCallConv = getTypeConverter()->getOptions().useBarePtrCallConv; |
| 763 | if (useBarePtrCallConv) { |
| 764 | // For the bare-ptr calling convention, extract the aligned pointer to |
| 765 | // be returned from the memref descriptor. |
| 766 | for (auto it : llvm::zip(op->getOperands(), adaptor.getOperands())) { |
| 767 | Type oldTy = std::get<0>(it).getType(); |
| 768 | Value newOperand = std::get<1>(it); |
| 769 | if (isa<MemRefType>(oldTy) && getTypeConverter()->canConvertToBarePtr( |
| 770 | cast<BaseMemRefType>(oldTy))) { |
| 771 | MemRefDescriptor memrefDesc(newOperand); |
| 772 | newOperand = memrefDesc.allocatedPtr(rewriter, loc); |
| 773 | } else if (isa<UnrankedMemRefType>(oldTy)) { |
| 774 | // Unranked memref is not supported in the bare pointer calling |
| 775 | // convention. |
| 776 | return failure(); |
| 777 | } |
| 778 | updatedOperands.push_back(newOperand); |
| 779 | } |
| 780 | } else { |
| 781 | updatedOperands = llvm::to_vector<4>(adaptor.getOperands()); |
| 782 | (void)copyUnrankedDescriptors(rewriter, loc, op.getOperands().getTypes(), |
| 783 | updatedOperands, |
| 784 | /*toDynamic=*/true); |
| 785 | } |
| 786 | |
| 787 | // If ReturnOp has 0 or 1 operand, create it and return immediately. |
| 788 | if (numArguments <= 1) { |
| 789 | rewriter.replaceOpWithNewOp<LLVM::ReturnOp>( |
| 790 | op, TypeRange(), updatedOperands, op->getAttrs()); |
| 791 | return success(); |
| 792 | } |
| 793 | |
| 794 | // Otherwise, we need to pack the arguments into an LLVM struct type before |
| 795 | // returning. |
| 796 | auto packedType = getTypeConverter()->packFunctionResults( |
| 797 | op.getOperandTypes(), useBarePtrCallConv); |
| 798 | if (!packedType) { |
| 799 | return rewriter.notifyMatchFailure(op, "could not convert result types" ); |
| 800 | } |
| 801 | |
| 802 | Value packed = rewriter.create<LLVM::PoisonOp>(loc, packedType); |
| 803 | for (auto [idx, operand] : llvm::enumerate(First&: updatedOperands)) { |
| 804 | packed = rewriter.create<LLVM::InsertValueOp>(loc, packed, operand, idx); |
| 805 | } |
| 806 | rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, TypeRange(), packed, |
| 807 | op->getAttrs()); |
| 808 | return success(); |
| 809 | } |
| 810 | |
| 811 | void mlir::populateGpuMemorySpaceAttributeConversions( |
| 812 | TypeConverter &typeConverter, const MemorySpaceMapping &mapping) { |
| 813 | typeConverter.addTypeAttributeConversion( |
| 814 | callback: [mapping](BaseMemRefType type, gpu::AddressSpaceAttr memorySpaceAttr) { |
| 815 | gpu::AddressSpace memorySpace = memorySpaceAttr.getValue(); |
| 816 | unsigned addressSpace = mapping(memorySpace); |
| 817 | return wrapNumericMemorySpace(memorySpaceAttr.getContext(), |
| 818 | addressSpace); |
| 819 | }); |
| 820 | } |
| 821 | |