//===- MemRefToSPIRV.cpp - MemRef to SPIR-V Patterns ----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements patterns to convert MemRef dialect to SPIR-V dialect.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Visitors.h"
#include "llvm/Support/Debug.h"
#include <cassert>
#include <optional>

#define DEBUG_TYPE "memref-to-spirv-pattern"

using namespace mlir;

//===----------------------------------------------------------------------===//
// Utility functions
//===----------------------------------------------------------------------===//

/// Returns the offset of the value in `targetBits` representation.
///
/// `srcIdx` is an index into a 1-D array with each element having `sourceBits`.
/// It's assumed to be non-negative.
///
/// When accessing an element in the array treated as having elements of
/// `targetBits`, multiple values are loaded at the same time. The method
/// returns the offset where the `srcIdx` locates in the value. For example, if
/// `sourceBits` is 8 and `targetBits` is 32, the x-th element is located at
/// (x % 4) * 8, because there are four 8-bit elements in one i32.
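///
/// Illustrative emission for sourceBits = 8 and targetBits = 32 (a sketch;
/// SSA names are hypothetical, and createOrFold may constant-fold them):
///   %c4     = spirv.Constant 4 : i32
///   %c8     = spirv.Constant 8 : i32
///   %rem    = spirv.UMod %srcIdx, %c4 : i32
///   %offset = spirv.IMul %rem, %c8 : i32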
static Value getOffsetForBitwidth(Location loc, Value srcIdx, int sourceBits,
                                  int targetBits, OpBuilder &builder) {
  assert(targetBits % sourceBits == 0);
  Type type = srcIdx.getType();
  IntegerAttr idxAttr = builder.getIntegerAttr(type, targetBits / sourceBits);
  auto idx = builder.createOrFold<spirv::ConstantOp>(loc, type, idxAttr);
  IntegerAttr srcBitsAttr = builder.getIntegerAttr(type, sourceBits);
  auto srcBitsValue =
      builder.createOrFold<spirv::ConstantOp>(loc, type, srcBitsAttr);
  auto m = builder.createOrFold<spirv::UModOp>(loc, srcIdx, idx);
  return builder.createOrFold<spirv::IMulOp>(loc, type, m, srcBitsValue);
}

/// Returns an adjusted spirv::AccessChainOp. Based on the
/// extension/capabilities, certain integer bitwidths `sourceBits` might not be
/// supported. During conversion, if a memref of an unsupported type is used,
/// loads/stores to this memref need to be modified to use a supported higher
/// bitwidth `targetBits` and to extract the required bits. When accessing a
/// 1-D array (spirv.array or spirv.rtarray), the last index is modified to
/// load the bits needed. The extraction of the actual bits needed is handled
/// separately. Note that this only works for a 1-D tensor.
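///
/// Illustrative rewrite for sourceBits = 8 and targetBits = 32 (a sketch; SSA
/// names are hypothetical and the pointer type depends on the type converter):
///   %c4  = spirv.Constant 4 : i32
///   %q   = spirv.SDiv %lastIdx, %c4 : i32
///   %ptr = spirv.AccessChain %base[%zero, %q] : ...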
static Value
adjustAccessChainForBitwidth(const SPIRVTypeConverter &typeConverter,
                             spirv::AccessChainOp op, int sourceBits,
                             int targetBits, OpBuilder &builder) {
  assert(targetBits % sourceBits == 0);
  const auto loc = op.getLoc();
  Value lastDim = op->getOperand(op.getNumOperands() - 1);
  Type type = lastDim.getType();
  IntegerAttr attr = builder.getIntegerAttr(type, targetBits / sourceBits);
  auto idx = builder.createOrFold<spirv::ConstantOp>(loc, type, attr);
  auto indices = llvm::to_vector<4>(op.getIndices());
  // There are two elements if this is a 1-D tensor.
  assert(indices.size() == 2);
  indices.back() = builder.createOrFold<spirv::SDivOp>(loc, lastDim, idx);
  Type t = typeConverter.convertType(op.getComponentPtr().getType());
  return builder.create<spirv::AccessChainOp>(loc, t, op.getBasePtr(), indices);
}

/// Casts the given `srcBool` into an integer of `dstType`.
static Value castBoolToIntN(Location loc, Value srcBool, Type dstType,
                            OpBuilder &builder) {
  assert(srcBool.getType().isInteger(1));
  if (dstType.isInteger(1))
    return srcBool;
  Value zero = spirv::ConstantOp::getZero(dstType, loc, builder);
  Value one = spirv::ConstantOp::getOne(dstType, loc, builder);
  return builder.createOrFold<spirv::SelectOp>(loc, dstType, srcBool, one,
                                               zero);
}

/// Returns the `targetBits`-bit value shifted by the given `offset`, cast to
/// the destination type, and masked.
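///
/// For example (a sketch), storing the i8 value 0xAB at bit offset 8 within
/// an i32 word yields (0xAB & 0xFF) << 8 = 0x0000AB00; the caller is expected
/// to merge this with the destination word after clearing the target bits.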
static Value shiftValue(Location loc, Value value, Value offset, Value mask,
                        OpBuilder &builder) {
  IntegerType dstType = cast<IntegerType>(mask.getType());
  int targetBits = static_cast<int>(dstType.getWidth());
  int valueBits = value.getType().getIntOrFloatBitWidth();
  assert(valueBits <= targetBits);

  if (valueBits == 1) {
    value = castBoolToIntN(loc, value, dstType, builder);
  } else {
    if (valueBits < targetBits) {
      value = builder.create<spirv::UConvertOp>(
          loc, builder.getIntegerType(targetBits), value);
    }

    value = builder.createOrFold<spirv::BitwiseAndOp>(loc, value, mask);
  }
  return builder.createOrFold<spirv::ShiftLeftLogicalOp>(loc, value.getType(),
                                                         value, offset);
}

/// Returns true if the allocations of memref `type` generated from `allocOp`
/// can be lowered to SPIR-V.
static bool isAllocationSupported(Operation *allocOp, MemRefType type) {
  if (isa<memref::AllocOp, memref::DeallocOp>(allocOp)) {
    auto sc = dyn_cast_or_null<spirv::StorageClassAttr>(type.getMemorySpace());
    if (!sc || sc.getValue() != spirv::StorageClass::Workgroup)
      return false;
  } else if (isa<memref::AllocaOp>(allocOp)) {
    auto sc = dyn_cast_or_null<spirv::StorageClassAttr>(type.getMemorySpace());
    if (!sc || sc.getValue() != spirv::StorageClass::Function)
      return false;
  } else {
    return false;
  }

  // Currently only support static shape and int or float or vector of int or
  // float element type.
  if (!type.hasStaticShape())
    return false;

  Type elementType = type.getElementType();
  if (auto vecType = dyn_cast<VectorType>(elementType))
    elementType = vecType.getElementType();
  return elementType.isIntOrFloat();
}

/// Returns the scope to use for atomic operations used for emulating store
/// operations of unsupported integer bitwidths, based on the memref
/// type. Returns std::nullopt on failure.
static std::optional<spirv::Scope> getAtomicOpScope(MemRefType type) {
  auto sc = dyn_cast_or_null<spirv::StorageClassAttr>(type.getMemorySpace());
  switch (sc.getValue()) {
  case spirv::StorageClass::StorageBuffer:
    return spirv::Scope::Device;
  case spirv::StorageClass::Workgroup:
    return spirv::Scope::Workgroup;
  default:
    break;
  }
  return {};
}

/// Casts the given `srcInt` into a boolean value.
static Value castIntNToBool(Location loc, Value srcInt, OpBuilder &builder) {
  if (srcInt.getType().isInteger(1))
    return srcInt;

  auto zero = spirv::ConstantOp::getZero(srcInt.getType(), loc, builder);
  return builder.createOrFold<spirv::INotEqualOp>(loc, srcInt, zero);
}

//===----------------------------------------------------------------------===//
// Operation conversion
//===----------------------------------------------------------------------===//

// Note that DRR cannot be used for the patterns in this file: we may need to
// convert type along the way, which requires ConversionPattern. DRR generates
// normal RewritePattern.

namespace {

/// Converts memref.alloca to SPIR-V Function variables.
class AllocaOpPattern final : public OpConversionPattern<memref::AllocaOp> {
public:
  using OpConversionPattern<memref::AllocaOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::AllocaOp allocaOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Converts an allocation operation to SPIR-V. Currently only supports lowering
/// to Workgroup memory when the size is constant. Note that this pattern needs
/// to be applied in a pass that runs at least at spirv.module scope since it
/// will add global variables into the spirv.module.
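///
/// Illustrative rewrite (a sketch; the variable name is generated at
/// conversion time and the pointee type depends on the type converter):
///   %0 = memref.alloc() : memref<4xf32, #spirv.storage_class<Workgroup>>
/// becomes
///   spirv.GlobalVariable @__workgroup_mem__0 : !spirv.ptr<..., Workgroup>
///   ...
///   %0 = spirv.mlir.addressof @__workgroup_mem__0 : !spirv.ptr<..., Workgroup>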
class AllocOpPattern final : public OpConversionPattern<memref::AllocOp> {
public:
  using OpConversionPattern<memref::AllocOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::AllocOp operation, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Converts memref.atomic_rmw operations to SPIR-V atomic operations.
class AtomicRMWOpPattern final
    : public OpConversionPattern<memref::AtomicRMWOp> {
public:
  using OpConversionPattern<memref::AtomicRMWOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::AtomicRMWOp atomicOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Removes a deallocation if it is for a supported allocation. Currently only
/// removes the deallocation if the memory space is workgroup memory.
class DeallocOpPattern final : public OpConversionPattern<memref::DeallocOp> {
public:
  using OpConversionPattern<memref::DeallocOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::DeallocOp operation, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Converts memref.load to spirv.Load + spirv.AccessChain on integers.
class IntLoadOpPattern final : public OpConversionPattern<memref::LoadOp> {
public:
  using OpConversionPattern<memref::LoadOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Converts memref.load to spirv.Load + spirv.AccessChain.
class LoadOpPattern final : public OpConversionPattern<memref::LoadOp> {
public:
  using OpConversionPattern<memref::LoadOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Converts memref.store to spirv.Store on integers.
class IntStoreOpPattern final : public OpConversionPattern<memref::StoreOp> {
public:
  using OpConversionPattern<memref::StoreOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Converts memref.memory_space_cast to the appropriate spirv cast operations.
class MemorySpaceCastOpPattern final
    : public OpConversionPattern<memref::MemorySpaceCastOp> {
public:
  using OpConversionPattern<memref::MemorySpaceCastOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::MemorySpaceCastOp addrCastOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Converts memref.store to spirv.Store.
class StoreOpPattern final : public OpConversionPattern<memref::StoreOp> {
public:
  using OpConversionPattern<memref::StoreOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

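/// Converts memref.reinterpret_cast to spirv.InBoundsPtrAccessChain when the
/// element offset is non-zero, and folds the cast away when it is zero.
///
/// Illustrative non-zero-offset rewrite (a sketch; SSA names are
/// hypothetical):
///   %p = spirv.InBoundsPtrAccessChain %src[%offset]
///        : !spirv.ptr<f32, CrossWorkgroup>, i32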
class ReinterpretCastPattern final
    : public OpConversionPattern<memref::ReinterpretCastOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::ReinterpretCastOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

class CastPattern final : public OpConversionPattern<memref::CastOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::CastOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Value src = adaptor.getSource();
    Type srcType = src.getType();

    const TypeConverter *converter = getTypeConverter();
    Type dstType = converter->convertType(op.getType());
    if (srcType != dstType)
      return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
        diag << "types don't match: " << srcType << " and " << dstType;
      });

    rewriter.replaceOp(op, src);
    return success();
  }
};

/// Converts memref.extract_aligned_pointer_as_index to spirv.ConvertPtrToU.
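///
/// Illustrative result (a sketch; the pointer and index types depend on the
/// type converter):
///   %idx = spirv.ConvertPtrToU %ptr
///          : !spirv.ptr<f32, PhysicalStorageBuffer> to i64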
class ExtractAlignedPointerAsIndexOpPattern final
    : public OpConversionPattern<memref::ExtractAlignedPointerAsIndexOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::ExtractAlignedPointerAsIndexOp extractOp,
                  OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

//===----------------------------------------------------------------------===//
// AllocaOp
//===----------------------------------------------------------------------===//

LogicalResult
AllocaOpPattern::matchAndRewrite(memref::AllocaOp allocaOp, OpAdaptor adaptor,
                                 ConversionPatternRewriter &rewriter) const {
  MemRefType allocType = allocaOp.getType();
  if (!isAllocationSupported(allocaOp, allocType))
    return rewriter.notifyMatchFailure(allocaOp, "unhandled allocation type");

  // Get the SPIR-V type for the allocation.
  Type spirvType = getTypeConverter()->convertType(allocType);
  if (!spirvType)
    return rewriter.notifyMatchFailure(allocaOp, "type conversion failed");

  rewriter.replaceOpWithNewOp<spirv::VariableOp>(allocaOp, spirvType,
                                                 spirv::StorageClass::Function,
                                                 /*initializer=*/nullptr);
  return success();
}

//===----------------------------------------------------------------------===//
// AllocOp
//===----------------------------------------------------------------------===//

LogicalResult
AllocOpPattern::matchAndRewrite(memref::AllocOp operation, OpAdaptor adaptor,
                                ConversionPatternRewriter &rewriter) const {
  MemRefType allocType = operation.getType();
  if (!isAllocationSupported(operation, allocType))
    return rewriter.notifyMatchFailure(operation, "unhandled allocation type");

  // Get the SPIR-V type for the allocation.
  Type spirvType = getTypeConverter()->convertType(allocType);
  if (!spirvType)
    return rewriter.notifyMatchFailure(operation, "type conversion failed");

  // Insert spirv.GlobalVariable for this allocation.
  Operation *parent =
      SymbolTable::getNearestSymbolTable(operation->getParentOp());
  if (!parent)
    return failure();
  Location loc = operation.getLoc();
  spirv::GlobalVariableOp varOp;
  {
    OpBuilder::InsertionGuard guard(rewriter);
    Block &entryBlock = *parent->getRegion(0).begin();
    rewriter.setInsertionPointToStart(&entryBlock);
    auto varOps = entryBlock.getOps<spirv::GlobalVariableOp>();
    std::string varName =
        std::string("__workgroup_mem__") +
        std::to_string(std::distance(varOps.begin(), varOps.end()));
    varOp = rewriter.create<spirv::GlobalVariableOp>(loc, spirvType, varName,
                                                     /*initializer=*/nullptr);
  }

  // Get pointer to global variable at the current scope.
  rewriter.replaceOpWithNewOp<spirv::AddressOfOp>(operation, varOp);
  return success();
}

//===----------------------------------------------------------------------===//
// AtomicRMWOp
//===----------------------------------------------------------------------===//

LogicalResult
AtomicRMWOpPattern::matchAndRewrite(memref::AtomicRMWOp atomicOp,
                                    OpAdaptor adaptor,
                                    ConversionPatternRewriter &rewriter) const {
  if (isa<FloatType>(atomicOp.getType()))
    return rewriter.notifyMatchFailure(atomicOp,
                                       "unimplemented floating-point case");

  auto memrefType = cast<MemRefType>(atomicOp.getMemref().getType());
  std::optional<spirv::Scope> scope = getAtomicOpScope(memrefType);
  if (!scope)
    return rewriter.notifyMatchFailure(atomicOp,
                                       "unsupported memref memory space");

  auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
  Type resultType = typeConverter.convertType(atomicOp.getType());
  if (!resultType)
    return rewriter.notifyMatchFailure(atomicOp,
                                       "failed to convert result type");

  auto loc = atomicOp.getLoc();
  Value ptr =
      spirv::getElementPtr(typeConverter, memrefType, adaptor.getMemref(),
                           adaptor.getIndices(), loc, rewriter);

  if (!ptr)
    return failure();

#define ATOMIC_CASE(kind, spirvOp)                                            \
  case arith::AtomicRMWKind::kind:                                            \
    rewriter.replaceOpWithNewOp<spirv::spirvOp>(                              \
        atomicOp, resultType, ptr, *scope,                                    \
        spirv::MemorySemantics::AcquireRelease, adaptor.getValue());          \
    break

  switch (atomicOp.getKind()) {
    ATOMIC_CASE(addi, AtomicIAddOp);
    ATOMIC_CASE(maxs, AtomicSMaxOp);
    ATOMIC_CASE(maxu, AtomicUMaxOp);
    ATOMIC_CASE(mins, AtomicSMinOp);
    ATOMIC_CASE(minu, AtomicUMinOp);
    ATOMIC_CASE(ori, AtomicOrOp);
    ATOMIC_CASE(andi, AtomicAndOp);
  default:
    return rewriter.notifyMatchFailure(atomicOp, "unimplemented atomic kind");
  }

#undef ATOMIC_CASE

  return success();
}

//===----------------------------------------------------------------------===//
// DeallocOp
//===----------------------------------------------------------------------===//

LogicalResult
DeallocOpPattern::matchAndRewrite(memref::DeallocOp operation,
                                  OpAdaptor adaptor,
                                  ConversionPatternRewriter &rewriter) const {
  MemRefType deallocType = cast<MemRefType>(operation.getMemref().getType());
  if (!isAllocationSupported(operation, deallocType))
    return rewriter.notifyMatchFailure(operation, "unhandled allocation type");
  rewriter.eraseOp(operation);
  return success();
}

//===----------------------------------------------------------------------===//
// LoadOp
//===----------------------------------------------------------------------===//

struct MemoryRequirements {
  spirv::MemoryAccessAttr memoryAccess;
  IntegerAttr alignment;
};

/// Given an accessed SPIR-V pointer, calculates its alignment requirements, if
/// any.
static FailureOr<MemoryRequirements>
calculateMemoryRequirements(Value accessedPtr, bool isNontemporal) {
  MLIRContext *ctx = accessedPtr.getContext();

  auto memoryAccess = spirv::MemoryAccess::None;
  if (isNontemporal) {
    memoryAccess = spirv::MemoryAccess::Nontemporal;
  }

  auto ptrType = cast<spirv::PointerType>(accessedPtr.getType());
  if (ptrType.getStorageClass() != spirv::StorageClass::PhysicalStorageBuffer) {
    if (memoryAccess == spirv::MemoryAccess::None) {
      return MemoryRequirements{spirv::MemoryAccessAttr{}, IntegerAttr{}};
    }
    return MemoryRequirements{spirv::MemoryAccessAttr::get(ctx, memoryAccess),
                              IntegerAttr{}};
  }

  // PhysicalStorageBuffers require the `Aligned` attribute.
  auto pointeeType = dyn_cast<spirv::ScalarType>(ptrType.getPointeeType());
  if (!pointeeType)
    return failure();

  // For scalar types, the alignment is determined by their size.
  std::optional<int64_t> sizeInBytes = pointeeType.getSizeInBytes();
  if (!sizeInBytes.has_value())
    return failure();

  memoryAccess = memoryAccess | spirv::MemoryAccess::Aligned;
  auto memAccessAttr = spirv::MemoryAccessAttr::get(ctx, memoryAccess);
  auto alignment = IntegerAttr::get(IntegerType::get(ctx, 32), *sizeInBytes);
  return MemoryRequirements{memAccessAttr, alignment};
}

/// Given an accessed SPIR-V pointer and the original memref load/store
/// `memAccess` op, calculates the alignment requirements, if any. Takes into
/// account the alignment attributes applied to the load/store op.
template <class LoadOrStoreOp>
static FailureOr<MemoryRequirements>
calculateMemoryRequirements(Value accessedPtr, LoadOrStoreOp loadOrStoreOp) {
  static_assert(
      llvm::is_one_of<LoadOrStoreOp, memref::LoadOp, memref::StoreOp>::value,
      "Must be called on either memref::LoadOp or memref::StoreOp");

  Operation *memrefAccessOp = loadOrStoreOp.getOperation();
  auto memrefMemAccess = memrefAccessOp->getAttrOfType<spirv::MemoryAccessAttr>(
      spirv::attributeName<spirv::MemoryAccess>());
  auto memrefAlignment =
      memrefAccessOp->getAttrOfType<IntegerAttr>("alignment");
  if (memrefMemAccess && memrefAlignment)
    return MemoryRequirements{memrefMemAccess, memrefAlignment};

  return calculateMemoryRequirements(accessedPtr,
                                     loadOrStoreOp.getNontemporal());
}

LogicalResult
IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
                                  ConversionPatternRewriter &rewriter) const {
  auto loc = loadOp.getLoc();
  auto memrefType = cast<MemRefType>(loadOp.getMemref().getType());
  if (!memrefType.getElementType().isSignlessInteger())
    return failure();

  const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
  Value accessChain =
      spirv::getElementPtr(typeConverter, memrefType, adaptor.getMemref(),
                           adaptor.getIndices(), loc, rewriter);

  if (!accessChain)
    return failure();

  int srcBits = memrefType.getElementType().getIntOrFloatBitWidth();
  bool isBool = srcBits == 1;
  if (isBool)
    srcBits = typeConverter.getOptions().boolNumBits;

  auto pointerType = typeConverter.convertType<spirv::PointerType>(memrefType);
  if (!pointerType)
    return rewriter.notifyMatchFailure(loadOp, "failed to convert memref type");

  Type pointeeType = pointerType.getPointeeType();
  Type dstType;
  if (typeConverter.allows(spirv::Capability::Kernel)) {
    if (auto arrayType = dyn_cast<spirv::ArrayType>(pointeeType))
      dstType = arrayType.getElementType();
    else
      dstType = pointeeType;
  } else {
    // For Vulkan we need to extract element from wrapping struct and array.
    Type structElemType =
        cast<spirv::StructType>(pointeeType).getElementType(0);
    if (auto arrayType = dyn_cast<spirv::ArrayType>(structElemType))
      dstType = arrayType.getElementType();
    else
      dstType = cast<spirv::RuntimeArrayType>(structElemType).getElementType();
  }
  int dstBits = dstType.getIntOrFloatBitWidth();
  assert(dstBits % srcBits == 0);

  // If the rewritten load op has the same bit width, use the loading value
  // directly.
  if (srcBits == dstBits) {
    auto memoryRequirements = calculateMemoryRequirements(accessChain, loadOp);
    if (failed(memoryRequirements))
      return rewriter.notifyMatchFailure(
          loadOp, "failed to determine memory requirements");

    auto [memoryAccess, alignment] = *memoryRequirements;
    Value loadVal = rewriter.create<spirv::LoadOp>(loc, accessChain,
                                                   memoryAccess, alignment);
    if (isBool)
      loadVal = castIntNToBool(loc, loadVal, rewriter);
    rewriter.replaceOp(loadOp, loadVal);
    return success();
  }

  // Bitcasting is currently unsupported for Kernel capability /
  // spirv.PtrAccessChain.
  if (typeConverter.allows(spirv::Capability::Kernel))
    return failure();

  auto accessChainOp = accessChain.getDefiningOp<spirv::AccessChainOp>();
  if (!accessChainOp)
    return failure();

  // Assume that getElementPtr() produces a linearized access: even for a
  // scalar memref, the method still returns a linearized access chain. If the
  // access were not linearized, the offset computation below would be wrong.
  assert(accessChainOp.getIndices().size() == 2);
  Value adjustedPtr = adjustAccessChainForBitwidth(typeConverter, accessChainOp,
                                                   srcBits, dstBits, rewriter);
  auto memoryRequirements = calculateMemoryRequirements(adjustedPtr, loadOp);
  if (failed(memoryRequirements))
    return rewriter.notifyMatchFailure(
        loadOp, "failed to determine memory requirements");

  auto [memoryAccess, alignment] = *memoryRequirements;
  Value spvLoadOp = rewriter.create<spirv::LoadOp>(loc, dstType, adjustedPtr,
                                                   memoryAccess, alignment);

  // Shift the bits to the rightmost.
  // ____XXXX________ -> ____________XXXX
  Value lastDim = accessChainOp->getOperand(accessChainOp.getNumOperands() - 1);
  Value offset = getOffsetForBitwidth(loc, lastDim, srcBits, dstBits, rewriter);
  Value result = rewriter.createOrFold<spirv::ShiftRightArithmeticOp>(
      loc, spvLoadOp.getType(), spvLoadOp, offset);

  // Apply the mask to extract corresponding bits.
  Value mask = rewriter.createOrFold<spirv::ConstantOp>(
      loc, dstType, rewriter.getIntegerAttr(dstType, (1 << srcBits) - 1));
  result =
      rewriter.createOrFold<spirv::BitwiseAndOp>(loc, dstType, result, mask);

  // Apply sign extension on the loaded value unconditionally. The signedness
  // semantic is carried in the operator itself; we rely on other patterns to
  // handle the casting.
  IntegerAttr shiftValueAttr =
      rewriter.getIntegerAttr(dstType, dstBits - srcBits);
  Value shiftValue =
      rewriter.createOrFold<spirv::ConstantOp>(loc, dstType, shiftValueAttr);
  result = rewriter.createOrFold<spirv::ShiftLeftLogicalOp>(loc, dstType,
                                                            result, shiftValue);
  result = rewriter.createOrFold<spirv::ShiftRightArithmeticOp>(
      loc, dstType, result, shiftValue);

  rewriter.replaceOp(loadOp, result);

  assert(accessChainOp.use_empty());
  rewriter.eraseOp(accessChainOp);

  return success();
}

LogicalResult
LoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
                               ConversionPatternRewriter &rewriter) const {
  auto memrefType = cast<MemRefType>(loadOp.getMemref().getType());
  if (memrefType.getElementType().isSignlessInteger())
    return failure();
  Value loadPtr = spirv::getElementPtr(
      *getTypeConverter<SPIRVTypeConverter>(), memrefType, adaptor.getMemref(),
      adaptor.getIndices(), loadOp.getLoc(), rewriter);

  if (!loadPtr)
    return failure();

  auto memoryRequirements = calculateMemoryRequirements(loadPtr, loadOp);
  if (failed(memoryRequirements))
    return rewriter.notifyMatchFailure(
        loadOp, "failed to determine memory requirements");

  auto [memoryAccess, alignment] = *memoryRequirements;
  rewriter.replaceOpWithNewOp<spirv::LoadOp>(loadOp, loadPtr, memoryAccess,
                                             alignment);
  return success();
}

LogicalResult
IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
                                   ConversionPatternRewriter &rewriter) const {
  auto memrefType = cast<MemRefType>(storeOp.getMemref().getType());
  if (!memrefType.getElementType().isSignlessInteger())
    return rewriter.notifyMatchFailure(storeOp,
                                       "element type is not a signless int");

  auto loc = storeOp.getLoc();
  auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
  Value accessChain =
      spirv::getElementPtr(typeConverter, memrefType, adaptor.getMemref(),
                           adaptor.getIndices(), loc, rewriter);

  if (!accessChain)
    return rewriter.notifyMatchFailure(
        storeOp, "failed to convert element pointer type");

  int srcBits = memrefType.getElementType().getIntOrFloatBitWidth();

  bool isBool = srcBits == 1;
  if (isBool)
    srcBits = typeConverter.getOptions().boolNumBits;

  auto pointerType = typeConverter.convertType<spirv::PointerType>(memrefType);
  if (!pointerType)
    return rewriter.notifyMatchFailure(storeOp,
                                       "failed to convert memref type");

  Type pointeeType = pointerType.getPointeeType();
  IntegerType dstType;
  if (typeConverter.allows(spirv::Capability::Kernel)) {
    if (auto arrayType = dyn_cast<spirv::ArrayType>(pointeeType))
      dstType = dyn_cast<IntegerType>(arrayType.getElementType());
    else
      dstType = dyn_cast<IntegerType>(pointeeType);
  } else {
    // For Vulkan we need to extract element from wrapping struct and array.
    Type structElemType =
        cast<spirv::StructType>(pointeeType).getElementType(0);
    if (auto arrayType = dyn_cast<spirv::ArrayType>(structElemType))
      dstType = dyn_cast<IntegerType>(arrayType.getElementType());
    else
      dstType = dyn_cast<IntegerType>(
          cast<spirv::RuntimeArrayType>(structElemType).getElementType());
  }

  if (!dstType)
    return rewriter.notifyMatchFailure(
        storeOp, "failed to determine destination element type");

  int dstBits = static_cast<int>(dstType.getWidth());
  assert(dstBits % srcBits == 0);

  if (srcBits == dstBits) {
    auto memoryRequirements = calculateMemoryRequirements(accessChain, storeOp);
    if (failed(memoryRequirements))
      return rewriter.notifyMatchFailure(
          storeOp, "failed to determine memory requirements");

    auto [memoryAccess, alignment] = *memoryRequirements;
    Value storeVal = adaptor.getValue();
    if (isBool)
      storeVal = castBoolToIntN(loc, storeVal, dstType, rewriter);
    rewriter.replaceOpWithNewOp<spirv::StoreOp>(storeOp, accessChain, storeVal,
                                                memoryAccess, alignment);
    return success();
  }

  // Bitcasting is currently unsupported for Kernel capability /
  // spirv.PtrAccessChain.
  if (typeConverter.allows(spirv::Capability::Kernel))
    return failure();

  auto accessChainOp = accessChain.getDefiningOp<spirv::AccessChainOp>();
  if (!accessChainOp)
    return failure();

  // Since multiple threads may be writing concurrently, the emulation is done
  // with atomic operations. E.g., if the stored value is i8, rewrite the
  // StoreOp to:
  // 1) load a 32-bit integer
  // 2) clear 8 bits in the loaded value
  // 3) set 8 bits in the loaded value
  // 4) store the 32-bit value back
  //
  // Step 2 is done with AtomicAnd, and step 3 is done with AtomicOr (of the
  // loaded 32-bit value and the shifted 8-bit store value) as another atomic
  // step.
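  //
  // Illustrative sequence for an i8 store into an i32 word (a sketch; SSA
  // names are hypothetical and the exact assembly may differ):
  //   %old = spirv.AtomicAnd <Device> <AcquireRelease> %ptr, %clearBitsMask
  //   %new = spirv.AtomicOr <Device> <AcquireRelease> %ptr, %storeVal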
  assert(accessChainOp.getIndices().size() == 2);
  Value lastDim = accessChainOp->getOperand(accessChainOp.getNumOperands() - 1);
  Value offset = getOffsetForBitwidth(loc, lastDim, srcBits, dstBits, rewriter);

  // Create a mask to clear the destination. E.g., if it is the second i8 in
  // i32, 0xFFFF00FF is created.
  Value mask = rewriter.createOrFold<spirv::ConstantOp>(
      loc, dstType, rewriter.getIntegerAttr(dstType, (1 << srcBits) - 1));
  Value clearBitsMask = rewriter.createOrFold<spirv::ShiftLeftLogicalOp>(
      loc, dstType, mask, offset);
  clearBitsMask =
      rewriter.createOrFold<spirv::NotOp>(loc, dstType, clearBitsMask);

  Value storeVal = shiftValue(loc, adaptor.getValue(), offset, mask, rewriter);
  Value adjustedPtr = adjustAccessChainForBitwidth(typeConverter, accessChainOp,
                                                   srcBits, dstBits, rewriter);
  std::optional<spirv::Scope> scope = getAtomicOpScope(memrefType);
  if (!scope)
    return rewriter.notifyMatchFailure(storeOp, "atomic scope not available");

  Value result = rewriter.create<spirv::AtomicAndOp>(
      loc, dstType, adjustedPtr, *scope, spirv::MemorySemantics::AcquireRelease,
      clearBitsMask);
  result = rewriter.create<spirv::AtomicOrOp>(
      loc, dstType, adjustedPtr, *scope, spirv::MemorySemantics::AcquireRelease,
      storeVal);

  // The result of the AtomicOrOp is unused. Since the ops performing the
  // store are already inserted, we can just erase the original StoreOp. Note
  // that rewriter.replaceOp() doesn't work here because it requires the
  // replacement to produce the same number of results as the original op.
  rewriter.eraseOp(storeOp);

  assert(accessChainOp.use_empty());
  rewriter.eraseOp(accessChainOp);

  return success();
}

//===----------------------------------------------------------------------===//
// MemorySpaceCastOp
//===----------------------------------------------------------------------===//

LogicalResult MemorySpaceCastOpPattern::matchAndRewrite(
    memref::MemorySpaceCastOp addrCastOp, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Location loc = addrCastOp.getLoc();
  auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
  if (!typeConverter.allows(spirv::Capability::Kernel))
    return rewriter.notifyMatchFailure(
        loc, "address space casts require kernel capability");

  auto sourceType = dyn_cast<MemRefType>(addrCastOp.getSource().getType());
  if (!sourceType)
    return rewriter.notifyMatchFailure(
        loc, "SPIR-V lowering requires ranked memref types");
  auto resultType = cast<MemRefType>(addrCastOp.getResult().getType());

  auto sourceStorageClassAttr =
      dyn_cast_or_null<spirv::StorageClassAttr>(sourceType.getMemorySpace());
  if (!sourceStorageClassAttr)
    return rewriter.notifyMatchFailure(loc, [sourceType](Diagnostic &diag) {
      diag << "source address space " << sourceType.getMemorySpace()
           << " must be a SPIR-V storage class";
    });
  auto resultStorageClassAttr =
      dyn_cast_or_null<spirv::StorageClassAttr>(resultType.getMemorySpace());
  if (!resultStorageClassAttr)
    return rewriter.notifyMatchFailure(loc, [resultType](Diagnostic &diag) {
      diag << "result address space " << resultType.getMemorySpace()
           << " must be a SPIR-V storage class";
    });

  spirv::StorageClass sourceSc = sourceStorageClassAttr.getValue();
  spirv::StorageClass resultSc = resultStorageClassAttr.getValue();

  Value result = adaptor.getSource();
  Type resultPtrType = typeConverter.convertType(resultType);
  if (!resultPtrType)
    return rewriter.notifyMatchFailure(addrCastOp,
                                       "failed to convert memref type");

  Type genericPtrType = resultPtrType;
  // SPIR-V doesn't have a general address space cast operation. Instead, it
  // has conversions to and from generic pointers. To implement the general
  // case, we use specific-to-generic conversions when the source class is not
  // generic. Then, when the result storage class is not generic, we convert
  // the generic pointer (either the input or an intermediate result) to that
  // class. This also means that we'll need the intermediate generic pointer
  // type if neither the source nor the destination has it.
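  //
  // Illustrative chain for a CrossWorkgroup-to-Function cast (a sketch):
  //   %g = spirv.PtrCastToGeneric %src
  //        : !spirv.ptr<f32, CrossWorkgroup> to !spirv.ptr<f32, Generic>
  //   %r = spirv.GenericCastToPtr %g
  //        : !spirv.ptr<f32, Generic> to !spirv.ptr<f32, Function>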
  if (sourceSc != spirv::StorageClass::Generic &&
      resultSc != spirv::StorageClass::Generic) {
    Type intermediateType =
        MemRefType::get(sourceType.getShape(), sourceType.getElementType(),
                        sourceType.getLayout(),
                        rewriter.getAttr<spirv::StorageClassAttr>(
                            spirv::StorageClass::Generic));
    genericPtrType = typeConverter.convertType(intermediateType);
  }
  if (sourceSc != spirv::StorageClass::Generic) {
    result =
        rewriter.create<spirv::PtrCastToGenericOp>(loc, genericPtrType, result);
  }
  if (resultSc != spirv::StorageClass::Generic) {
    result =
        rewriter.create<spirv::GenericCastToPtrOp>(loc, resultPtrType, result);
  }
  rewriter.replaceOp(addrCastOp, result);
  return success();
}

LogicalResult
StoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
                                ConversionPatternRewriter &rewriter) const {
  auto memrefType = cast<MemRefType>(storeOp.getMemref().getType());
  if (memrefType.getElementType().isSignlessInteger())
    return rewriter.notifyMatchFailure(storeOp, "signless int");
  auto storePtr = spirv::getElementPtr(
      *getTypeConverter<SPIRVTypeConverter>(), memrefType, adaptor.getMemref(),
      adaptor.getIndices(), storeOp.getLoc(), rewriter);

  if (!storePtr)
    return rewriter.notifyMatchFailure(storeOp, "type conversion failed");

  auto memoryRequirements = calculateMemoryRequirements(storePtr, storeOp);
  if (failed(memoryRequirements))
    return rewriter.notifyMatchFailure(
        storeOp, "failed to determine memory requirements");

  auto [memoryAccess, alignment] = *memoryRequirements;
  rewriter.replaceOpWithNewOp<spirv::StoreOp>(
      storeOp, storePtr, adaptor.getValue(), memoryAccess, alignment);
  return success();
}

LogicalResult ReinterpretCastPattern::matchAndRewrite(
    memref::ReinterpretCastOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Value src = adaptor.getSource();
  auto srcType = dyn_cast<spirv::PointerType>(src.getType());

  if (!srcType)
    return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
      diag << "invalid src type " << src.getType();
    });

  const TypeConverter *converter = getTypeConverter();

  auto dstType = converter->convertType<spirv::PointerType>(op.getType());
  if (dstType != srcType)
    return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
      diag << "invalid dst type " << op.getType();
    });

  OpFoldResult offset =
      getMixedValues(adaptor.getStaticOffsets(), adaptor.getOffsets(), rewriter)
          .front();
  if (isZeroInteger(offset)) {
    rewriter.replaceOp(op, src);
    return success();
  }

  Type intType = converter->convertType(rewriter.getIndexType());
  if (!intType)
    return rewriter.notifyMatchFailure(op, "failed to convert index type");

  Location loc = op.getLoc();
  auto offsetValue = [&]() -> Value {
    if (auto val = dyn_cast<Value>(offset))
      return val;

    int64_t attrVal = cast<IntegerAttr>(cast<Attribute>(offset)).getInt();
    Attribute attr = rewriter.getIntegerAttr(intType, attrVal);
    return rewriter.createOrFold<spirv::ConstantOp>(loc, intType, attr);
  }();

  rewriter.replaceOpWithNewOp<spirv::InBoundsPtrAccessChainOp>(
      op, src, offsetValue, std::nullopt);
  return success();
}

//===----------------------------------------------------------------------===//
// ExtractAlignedPointerAsIndexOp
//===----------------------------------------------------------------------===//

LogicalResult ExtractAlignedPointerAsIndexOpPattern::matchAndRewrite(
    memref::ExtractAlignedPointerAsIndexOp extractOp, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
  Type indexType = typeConverter.getIndexType();
  rewriter.replaceOpWithNewOp<spirv::ConvertPtrToUOp>(extractOp, indexType,
                                                      adaptor.getSource());
  return success();
}

//===----------------------------------------------------------------------===//
// Pattern population
//===----------------------------------------------------------------------===//

namespace mlir {
void populateMemRefToSPIRVPatterns(const SPIRVTypeConverter &typeConverter,
                                   RewritePatternSet &patterns) {
  patterns
      .add<AllocaOpPattern, AllocOpPattern, AtomicRMWOpPattern,
           DeallocOpPattern, IntLoadOpPattern, IntStoreOpPattern, LoadOpPattern,
           MemorySpaceCastOpPattern, StoreOpPattern, ReinterpretCastPattern,
           CastPattern, ExtractAlignedPointerAsIndexOpPattern>(
          typeConverter, patterns.getContext());
}
} // namespace mlir
| 964 | |