//===- MemRefToSPIRV.cpp - MemRef to SPIR-V Patterns ---------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements patterns to convert MemRef dialect to SPIR-V dialect.
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/MemRefToSPIRV/MemRefToSPIRV.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Visitors.h"
#include <cassert>
#include <optional>

#define DEBUG_TYPE "memref-to-spirv-pattern"

using namespace mlir;

//===----------------------------------------------------------------------===//
// Utility functions
//===----------------------------------------------------------------------===//

/// Returns the offset of the value in `targetBits` representation.
///
/// `srcIdx` is an index into a 1-D array with each element having `sourceBits`.
/// It's assumed to be non-negative.
///
/// When accessing an element in the array as if it had elements of
/// `targetBits`, multiple values are loaded at the same time. The method
/// returns the bit offset of element `srcIdx` within the loaded value. For
/// example, if `sourceBits` is 8 and `targetBits` is 32, the x-th element is
/// located at bit (x % 4) * 8, because one i32 holds four 8-bit elements.
static Value getOffsetForBitwidth(Location loc, Value srcIdx, int sourceBits,
                                  int targetBits, OpBuilder &builder) {
  assert(targetBits % sourceBits == 0);
  Type type = srcIdx.getType();
  IntegerAttr idxAttr = builder.getIntegerAttr(type, targetBits / sourceBits);
  auto idx = builder.createOrFold<spirv::ConstantOp>(loc, type, idxAttr);
  IntegerAttr srcBitsAttr = builder.getIntegerAttr(type, sourceBits);
  auto srcBitsValue =
      builder.createOrFold<spirv::ConstantOp>(loc, type, srcBitsAttr);
  auto m = builder.createOrFold<spirv::UModOp>(loc, srcIdx, idx);
  return builder.createOrFold<spirv::IMulOp>(loc, type, m, srcBitsValue);
}

/// Returns an adjusted spirv::AccessChainOp. Based on the
/// extensions/capabilities, certain integer bitwidths `sourceBits` might not
/// be supported. During conversion, if a memref of an unsupported type is
/// used, loads/stores to this memref need to be modified to use a supported
/// higher bitwidth `targetBits` and extract the required bits. When accessing
/// a 1-D array (spirv.array or spirv.rtarray), the last index is modified to
/// load the bits needed. The extraction of the actual bits needed is handled
/// separately. Note that this only works for a 1-D tensor.
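///
/// For example (illustrative), when emulating i8 with i32 (`sourceBits` = 8,
/// `targetBits` = 32), an access chain whose last index is %i is rewritten so
/// that its last index becomes %i / 4, i.e. it addresses the 32-bit word
/// containing the requested byte.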
static Value
adjustAccessChainForBitwidth(const SPIRVTypeConverter &typeConverter,
                             spirv::AccessChainOp op, int sourceBits,
                             int targetBits, OpBuilder &builder) {
  assert(targetBits % sourceBits == 0);
  const auto loc = op.getLoc();
  Value lastDim = op->getOperand(op.getNumOperands() - 1);
  Type type = lastDim.getType();
  IntegerAttr attr = builder.getIntegerAttr(type, targetBits / sourceBits);
  auto idx = builder.createOrFold<spirv::ConstantOp>(loc, type, attr);
  auto indices = llvm::to_vector<4>(op.getIndices());
  // There are two elements if this is a 1-D tensor.
  assert(indices.size() == 2);
  indices.back() = builder.createOrFold<spirv::SDivOp>(loc, lastDim, idx);
  Type t = typeConverter.convertType(op.getComponentPtr().getType());
  return builder.create<spirv::AccessChainOp>(loc, t, op.getBasePtr(), indices);
}

/// Casts the given `srcBool` into an integer of `dstType`.
static Value castBoolToIntN(Location loc, Value srcBool, Type dstType,
                            OpBuilder &builder) {
  assert(srcBool.getType().isInteger(1));
  if (dstType.isInteger(1))
    return srcBool;
  Value zero = spirv::ConstantOp::getZero(dstType, loc, builder);
  Value one = spirv::ConstantOp::getOne(dstType, loc, builder);
  return builder.createOrFold<spirv::SelectOp>(loc, dstType, srcBool, one,
                                               zero);
}

/// Returns the `targetBits`-bit value shifted by the given `offset`, cast to
/// the destination type, and masked.
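///
/// For example (illustrative): with an i8 value 0xAB, a 32-bit mask 0xFF, and
/// an offset of 16, the result is 0x00AB0000.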
static Value shiftValue(Location loc, Value value, Value offset, Value mask,
                        OpBuilder &builder) {
  IntegerType dstType = cast<IntegerType>(mask.getType());
  int targetBits = static_cast<int>(dstType.getWidth());
  int valueBits = value.getType().getIntOrFloatBitWidth();
  assert(valueBits <= targetBits);

  if (valueBits == 1) {
    value = castBoolToIntN(loc, value, dstType, builder);
  } else {
    if (valueBits < targetBits) {
      value = builder.create<spirv::UConvertOp>(
          loc, builder.getIntegerType(targetBits), value);
    }

    value = builder.createOrFold<spirv::BitwiseAndOp>(loc, value, mask);
  }
  return builder.createOrFold<spirv::ShiftLeftLogicalOp>(loc, value.getType(),
                                                          value, offset);
}

/// Returns true if the allocations of memref `type` generated from `allocOp`
/// can be lowered to SPIR-V.
static bool isAllocationSupported(Operation *allocOp, MemRefType type) {
  if (isa<memref::AllocOp, memref::DeallocOp>(allocOp)) {
    auto sc = dyn_cast_or_null<spirv::StorageClassAttr>(type.getMemorySpace());
    if (!sc || sc.getValue() != spirv::StorageClass::Workgroup)
      return false;
  } else if (isa<memref::AllocaOp>(allocOp)) {
    auto sc = dyn_cast_or_null<spirv::StorageClassAttr>(type.getMemorySpace());
    if (!sc || sc.getValue() != spirv::StorageClass::Function)
      return false;
  } else {
    return false;
  }

  // Currently only supports static shapes, with an element type that is an
  // int, a float, or a vector of int or float.
  if (!type.hasStaticShape())
    return false;

  Type elementType = type.getElementType();
  if (auto vecType = dyn_cast<VectorType>(elementType))
    elementType = vecType.getElementType();
  return elementType.isIntOrFloat();
}

/// Returns the scope to use for atomic operations used to emulate store
/// operations of unsupported integer bitwidths, based on the memref
/// type. Returns std::nullopt on failure.
static std::optional<spirv::Scope> getAtomicOpScope(MemRefType type) {
  auto sc = dyn_cast_or_null<spirv::StorageClassAttr>(type.getMemorySpace());
  switch (sc.getValue()) {
  case spirv::StorageClass::StorageBuffer:
    return spirv::Scope::Device;
  case spirv::StorageClass::Workgroup:
    return spirv::Scope::Workgroup;
  default:
    break;
  }
  return {};
}

/// Casts the given `srcInt` into a boolean value.
static Value castIntNToBool(Location loc, Value srcInt, OpBuilder &builder) {
  if (srcInt.getType().isInteger(1))
    return srcInt;

  auto zero = spirv::ConstantOp::getZero(srcInt.getType(), loc, builder);
  return builder.createOrFold<spirv::INotEqualOp>(loc, srcInt, zero);
}

//===----------------------------------------------------------------------===//
// Operation conversion
//===----------------------------------------------------------------------===//

// Note that DRR cannot be used for the patterns in this file: we may need to
// convert type along the way, which requires ConversionPattern. DRR generates
// normal RewritePattern.

namespace {

/// Converts memref.alloca to SPIR-V Function variables.
class AllocaOpPattern final : public OpConversionPattern<memref::AllocaOp> {
public:
  using OpConversionPattern<memref::AllocaOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::AllocaOp allocaOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Converts an allocation operation to SPIR-V. Currently only supports
/// lowering to Workgroup memory when the size is constant. Note that this
/// pattern needs to be applied in a pass that runs at least at spirv.module
/// scope since it will add global variables into the spirv.module.
class AllocOpPattern final : public OpConversionPattern<memref::AllocOp> {
public:
  using OpConversionPattern<memref::AllocOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::AllocOp operation, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Converts memref.atomic_rmw operations to SPIR-V atomic operations.
class AtomicRMWOpPattern final
    : public OpConversionPattern<memref::AtomicRMWOp> {
public:
  using OpConversionPattern<memref::AtomicRMWOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::AtomicRMWOp atomicOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Removes a deallocation if it is a supported allocation. Currently only
/// removes deallocation if the memory space is workgroup memory.
class DeallocOpPattern final : public OpConversionPattern<memref::DeallocOp> {
public:
  using OpConversionPattern<memref::DeallocOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::DeallocOp operation, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Converts memref.load to spirv.Load + spirv.AccessChain on integers.
class IntLoadOpPattern final : public OpConversionPattern<memref::LoadOp> {
public:
  using OpConversionPattern<memref::LoadOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Converts memref.load to spirv.Load + spirv.AccessChain.
class LoadOpPattern final : public OpConversionPattern<memref::LoadOp> {
public:
  using OpConversionPattern<memref::LoadOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Converts memref.store to spirv.Store on integers.
class IntStoreOpPattern final : public OpConversionPattern<memref::StoreOp> {
public:
  using OpConversionPattern<memref::StoreOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Converts memref.memory_space_cast to the appropriate spirv cast operations.
class MemorySpaceCastOpPattern final
    : public OpConversionPattern<memref::MemorySpaceCastOp> {
public:
  using OpConversionPattern<memref::MemorySpaceCastOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::MemorySpaceCastOp addrCastOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Converts memref.store to spirv.Store.
class StoreOpPattern final : public OpConversionPattern<memref::StoreOp> {
public:
  using OpConversionPattern<memref::StoreOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

class ReinterpretCastPattern final
    : public OpConversionPattern<memref::ReinterpretCastOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::ReinterpretCastOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

class CastPattern final : public OpConversionPattern<memref::CastOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::CastOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Value src = adaptor.getSource();
    Type srcType = src.getType();

    const TypeConverter *converter = getTypeConverter();
    Type dstType = converter->convertType(op.getType());
    if (srcType != dstType)
      return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
        diag << "types don't match: " << srcType << " and " << dstType;
      });

    rewriter.replaceOp(op, src);
    return success();
  }
};

/// Converts memref.extract_aligned_pointer_as_index to spirv.ConvertPtrToU.
class ExtractAlignedPointerAsIndexOpPattern final
    : public OpConversionPattern<memref::ExtractAlignedPointerAsIndexOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::ExtractAlignedPointerAsIndexOp extractOp,
                  OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

//===----------------------------------------------------------------------===//
// AllocaOp
//===----------------------------------------------------------------------===//

LogicalResult
AllocaOpPattern::matchAndRewrite(memref::AllocaOp allocaOp, OpAdaptor adaptor,
                                 ConversionPatternRewriter &rewriter) const {
  MemRefType allocType = allocaOp.getType();
  if (!isAllocationSupported(allocaOp, allocType))
    return rewriter.notifyMatchFailure(allocaOp, "unhandled allocation type");

  // Get the SPIR-V type for the allocation.
  Type spirvType = getTypeConverter()->convertType(allocType);
  if (!spirvType)
    return rewriter.notifyMatchFailure(allocaOp, "type conversion failed");

  rewriter.replaceOpWithNewOp<spirv::VariableOp>(allocaOp, spirvType,
                                                 spirv::StorageClass::Function,
                                                 /*initializer=*/nullptr);
  return success();
}

//===----------------------------------------------------------------------===//
// AllocOp
//===----------------------------------------------------------------------===//

LogicalResult
AllocOpPattern::matchAndRewrite(memref::AllocOp operation, OpAdaptor adaptor,
                                ConversionPatternRewriter &rewriter) const {
  MemRefType allocType = operation.getType();
  if (!isAllocationSupported(operation, allocType))
    return rewriter.notifyMatchFailure(operation, "unhandled allocation type");

  // Get the SPIR-V type for the allocation.
  Type spirvType = getTypeConverter()->convertType(allocType);
  if (!spirvType)
    return rewriter.notifyMatchFailure(operation, "type conversion failed");

  // Insert spirv.GlobalVariable for this allocation.
  Operation *parent =
      SymbolTable::getNearestSymbolTable(operation->getParentOp());
  if (!parent)
    return failure();
  Location loc = operation.getLoc();
  spirv::GlobalVariableOp varOp;
  {
    OpBuilder::InsertionGuard guard(rewriter);
    Block &entryBlock = *parent->getRegion(0).begin();
    rewriter.setInsertionPointToStart(&entryBlock);
    auto varOps = entryBlock.getOps<spirv::GlobalVariableOp>();
    std::string varName =
        std::string("__workgroup_mem__") +
        std::to_string(std::distance(varOps.begin(), varOps.end()));
    varOp = rewriter.create<spirv::GlobalVariableOp>(loc, spirvType, varName,
                                                     /*initializer=*/nullptr);
  }

  // Get pointer to global variable at the current scope.
  rewriter.replaceOpWithNewOp<spirv::AddressOfOp>(operation, varOp);
  return success();
}

//===----------------------------------------------------------------------===//
// AtomicRMWOp
//===----------------------------------------------------------------------===//

LogicalResult
AtomicRMWOpPattern::matchAndRewrite(memref::AtomicRMWOp atomicOp,
                                    OpAdaptor adaptor,
                                    ConversionPatternRewriter &rewriter) const {
  if (isa<FloatType>(atomicOp.getType()))
    return rewriter.notifyMatchFailure(atomicOp,
                                       "unimplemented floating-point case");

  auto memrefType = cast<MemRefType>(atomicOp.getMemref().getType());
  std::optional<spirv::Scope> scope = getAtomicOpScope(memrefType);
  if (!scope)
    return rewriter.notifyMatchFailure(atomicOp,
                                       "unsupported memref memory space");

  auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
  Type resultType = typeConverter.convertType(atomicOp.getType());
  if (!resultType)
    return rewriter.notifyMatchFailure(atomicOp,
                                       "failed to convert result type");

  auto loc = atomicOp.getLoc();
  Value ptr =
      spirv::getElementPtr(typeConverter, memrefType, adaptor.getMemref(),
                           adaptor.getIndices(), loc, rewriter);

  if (!ptr)
    return failure();

#define ATOMIC_CASE(kind, spirvOp)                                             \
  case arith::AtomicRMWKind::kind:                                             \
    rewriter.replaceOpWithNewOp<spirv::spirvOp>(                               \
        atomicOp, resultType, ptr, *scope,                                     \
        spirv::MemorySemantics::AcquireRelease, adaptor.getValue());           \
    break

  switch (atomicOp.getKind()) {
    ATOMIC_CASE(addi, AtomicIAddOp);
    ATOMIC_CASE(maxs, AtomicSMaxOp);
    ATOMIC_CASE(maxu, AtomicUMaxOp);
    ATOMIC_CASE(mins, AtomicSMinOp);
    ATOMIC_CASE(minu, AtomicUMinOp);
    ATOMIC_CASE(ori, AtomicOrOp);
    ATOMIC_CASE(andi, AtomicAndOp);
  default:
    return rewriter.notifyMatchFailure(atomicOp, "unimplemented atomic kind");
  }

#undef ATOMIC_CASE

  return success();
}

//===----------------------------------------------------------------------===//
// DeallocOp
//===----------------------------------------------------------------------===//

LogicalResult
DeallocOpPattern::matchAndRewrite(memref::DeallocOp operation,
                                  OpAdaptor adaptor,
                                  ConversionPatternRewriter &rewriter) const {
  MemRefType deallocType = cast<MemRefType>(operation.getMemref().getType());
  if (!isAllocationSupported(operation, deallocType))
    return rewriter.notifyMatchFailure(operation, "unhandled allocation type");
  rewriter.eraseOp(operation);
  return success();
}

//===----------------------------------------------------------------------===//
// LoadOp
//===----------------------------------------------------------------------===//

struct MemoryRequirements {
  spirv::MemoryAccessAttr memoryAccess;
  IntegerAttr alignment;
};

/// Given an accessed SPIR-V pointer, calculates its alignment requirements, if
/// any.
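///
/// For example (illustrative), accessing an f32 value through a
/// PhysicalStorageBuffer pointer requires the `Aligned` memory operand with an
/// alignment of 4 bytes, whereas pointers in other storage classes need no
/// explicit alignment.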
static FailureOr<MemoryRequirements>
calculateMemoryRequirements(Value accessedPtr, bool isNontemporal) {
  MLIRContext *ctx = accessedPtr.getContext();

  auto memoryAccess = spirv::MemoryAccess::None;
  if (isNontemporal) {
    memoryAccess = spirv::MemoryAccess::Nontemporal;
  }

  auto ptrType = cast<spirv::PointerType>(accessedPtr.getType());
  if (ptrType.getStorageClass() != spirv::StorageClass::PhysicalStorageBuffer) {
    if (memoryAccess == spirv::MemoryAccess::None) {
      return MemoryRequirements{spirv::MemoryAccessAttr{}, IntegerAttr{}};
    }
    return MemoryRequirements{spirv::MemoryAccessAttr::get(ctx, memoryAccess),
                              IntegerAttr{}};
  }

  // PhysicalStorageBuffers require the `Aligned` attribute.
  auto pointeeType = dyn_cast<spirv::ScalarType>(ptrType.getPointeeType());
  if (!pointeeType)
    return failure();

  // For scalar types, the alignment is determined by their size.
  std::optional<int64_t> sizeInBytes = pointeeType.getSizeInBytes();
  if (!sizeInBytes.has_value())
    return failure();

  memoryAccess = memoryAccess | spirv::MemoryAccess::Aligned;
  auto memAccessAttr = spirv::MemoryAccessAttr::get(ctx, memoryAccess);
  auto alignment = IntegerAttr::get(IntegerType::get(ctx, 32), *sizeInBytes);
  return MemoryRequirements{memAccessAttr, alignment};
}

/// Given an accessed SPIR-V pointer and the original memref load/store
/// `memAccess` op, calculates the alignment requirements, if any. Takes into
/// account the alignment attributes applied to the load/store op.
template <class LoadOrStoreOp>
static FailureOr<MemoryRequirements>
calculateMemoryRequirements(Value accessedPtr, LoadOrStoreOp loadOrStoreOp) {
  static_assert(
      llvm::is_one_of<LoadOrStoreOp, memref::LoadOp, memref::StoreOp>::value,
      "Must be called on either memref::LoadOp or memref::StoreOp");

  Operation *memrefAccessOp = loadOrStoreOp.getOperation();
  auto memrefMemAccess = memrefAccessOp->getAttrOfType<spirv::MemoryAccessAttr>(
      spirv::attributeName<spirv::MemoryAccess>());
  auto memrefAlignment =
      memrefAccessOp->getAttrOfType<IntegerAttr>("alignment");
  if (memrefMemAccess && memrefAlignment)
    return MemoryRequirements{memrefMemAccess, memrefAlignment};

  return calculateMemoryRequirements(accessedPtr,
                                     loadOrStoreOp.getNontemporal());
}

LogicalResult
IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
                                  ConversionPatternRewriter &rewriter) const {
  auto loc = loadOp.getLoc();
  auto memrefType = cast<MemRefType>(loadOp.getMemref().getType());
  if (!memrefType.getElementType().isSignlessInteger())
    return failure();

  const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
  Value accessChain =
      spirv::getElementPtr(typeConverter, memrefType, adaptor.getMemref(),
                           adaptor.getIndices(), loc, rewriter);

  if (!accessChain)
    return failure();

  int srcBits = memrefType.getElementType().getIntOrFloatBitWidth();
  bool isBool = srcBits == 1;
  if (isBool)
    srcBits = typeConverter.getOptions().boolNumBits;

  auto pointerType = typeConverter.convertType<spirv::PointerType>(memrefType);
  if (!pointerType)
    return rewriter.notifyMatchFailure(loadOp, "failed to convert memref type");

  Type pointeeType = pointerType.getPointeeType();
  Type dstType;
  if (typeConverter.allows(spirv::Capability::Kernel)) {
    if (auto arrayType = dyn_cast<spirv::ArrayType>(pointeeType))
      dstType = arrayType.getElementType();
    else
      dstType = pointeeType;
  } else {
    // For Vulkan we need to extract element from wrapping struct and array.
    Type structElemType =
        cast<spirv::StructType>(pointeeType).getElementType(0);
    if (auto arrayType = dyn_cast<spirv::ArrayType>(structElemType))
      dstType = arrayType.getElementType();
    else
      dstType = cast<spirv::RuntimeArrayType>(structElemType).getElementType();
  }
  int dstBits = dstType.getIntOrFloatBitWidth();
  assert(dstBits % srcBits == 0);

  // If the rewritten load op has the same bit width, use the loaded value
  // directly.
  if (srcBits == dstBits) {
    auto memoryRequirements = calculateMemoryRequirements(accessChain, loadOp);
    if (failed(memoryRequirements))
      return rewriter.notifyMatchFailure(
          loadOp, "failed to determine memory requirements");

    auto [memoryAccess, alignment] = *memoryRequirements;
    Value loadVal = rewriter.create<spirv::LoadOp>(loc, accessChain,
                                                   memoryAccess, alignment);
    if (isBool)
      loadVal = castIntNToBool(loc, loadVal, rewriter);
    rewriter.replaceOp(loadOp, loadVal);
    return success();
  }

  // Bitcasting is currently unsupported for Kernel capability /
  // spirv.PtrAccessChain.
  if (typeConverter.allows(spirv::Capability::Kernel))
    return failure();

  auto accessChainOp = accessChain.getDefiningOp<spirv::AccessChainOp>();
  if (!accessChainOp)
    return failure();

  // Assume that getElementPtr() produces linearized accesses. Even for a
  // scalar, the method still returns a linearized access. If the access is not
  // linearized, there will be offset issues.
  assert(accessChainOp.getIndices().size() == 2);
  Value adjustedPtr = adjustAccessChainForBitwidth(typeConverter, accessChainOp,
                                                   srcBits, dstBits, rewriter);
  auto memoryRequirements = calculateMemoryRequirements(adjustedPtr, loadOp);
  if (failed(memoryRequirements))
    return rewriter.notifyMatchFailure(
        loadOp, "failed to determine memory requirements");

  auto [memoryAccess, alignment] = *memoryRequirements;
  Value spvLoadOp = rewriter.create<spirv::LoadOp>(loc, dstType, adjustedPtr,
                                                   memoryAccess, alignment);

  // Shift the bits to the rightmost.
  // ____XXXX________ -> ____________XXXX
  Value lastDim = accessChainOp->getOperand(accessChainOp.getNumOperands() - 1);
  Value offset = getOffsetForBitwidth(loc, lastDim, srcBits, dstBits, rewriter);
  Value result = rewriter.createOrFold<spirv::ShiftRightArithmeticOp>(
      loc, spvLoadOp.getType(), spvLoadOp, offset);

  // Apply the mask to extract corresponding bits.
  Value mask = rewriter.createOrFold<spirv::ConstantOp>(
      loc, dstType, rewriter.getIntegerAttr(dstType, (1 << srcBits) - 1));
  result =
      rewriter.createOrFold<spirv::BitwiseAndOp>(loc, dstType, result, mask);

  // Apply sign extension to the loaded value unconditionally. The signedness
  // semantics are carried by the operations themselves; we rely on other
  // patterns to handle the casting.
  IntegerAttr shiftValueAttr =
      rewriter.getIntegerAttr(dstType, dstBits - srcBits);
  Value shiftValue =
      rewriter.createOrFold<spirv::ConstantOp>(loc, dstType, shiftValueAttr);
  result = rewriter.createOrFold<spirv::ShiftLeftLogicalOp>(loc, dstType,
                                                            result, shiftValue);
  result = rewriter.createOrFold<spirv::ShiftRightArithmeticOp>(
      loc, dstType, result, shiftValue);

  rewriter.replaceOp(loadOp, result);

  assert(accessChainOp.use_empty());
  rewriter.eraseOp(accessChainOp);

  return success();
}

LogicalResult
LoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
                               ConversionPatternRewriter &rewriter) const {
  auto memrefType = cast<MemRefType>(loadOp.getMemref().getType());
  if (memrefType.getElementType().isSignlessInteger())
    return failure();
  Value loadPtr = spirv::getElementPtr(
      *getTypeConverter<SPIRVTypeConverter>(), memrefType, adaptor.getMemref(),
      adaptor.getIndices(), loadOp.getLoc(), rewriter);

  if (!loadPtr)
    return failure();

  auto memoryRequirements = calculateMemoryRequirements(loadPtr, loadOp);
  if (failed(memoryRequirements))
    return rewriter.notifyMatchFailure(
        loadOp, "failed to determine memory requirements");

  auto [memoryAccess, alignment] = *memoryRequirements;
  rewriter.replaceOpWithNewOp<spirv::LoadOp>(loadOp, loadPtr, memoryAccess,
                                             alignment);
  return success();
}

LogicalResult
IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
                                   ConversionPatternRewriter &rewriter) const {
  auto memrefType = cast<MemRefType>(storeOp.getMemref().getType());
  if (!memrefType.getElementType().isSignlessInteger())
    return rewriter.notifyMatchFailure(storeOp,
                                       "element type is not a signless int");

  auto loc = storeOp.getLoc();
  auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
  Value accessChain =
      spirv::getElementPtr(typeConverter, memrefType, adaptor.getMemref(),
                           adaptor.getIndices(), loc, rewriter);

  if (!accessChain)
    return rewriter.notifyMatchFailure(
        storeOp, "failed to convert element pointer type");

  int srcBits = memrefType.getElementType().getIntOrFloatBitWidth();

  bool isBool = srcBits == 1;
  if (isBool)
    srcBits = typeConverter.getOptions().boolNumBits;

  auto pointerType = typeConverter.convertType<spirv::PointerType>(memrefType);
  if (!pointerType)
    return rewriter.notifyMatchFailure(storeOp,
                                       "failed to convert memref type");

  Type pointeeType = pointerType.getPointeeType();
  IntegerType dstType;
  if (typeConverter.allows(spirv::Capability::Kernel)) {
    if (auto arrayType = dyn_cast<spirv::ArrayType>(pointeeType))
      dstType = dyn_cast<IntegerType>(arrayType.getElementType());
    else
      dstType = dyn_cast<IntegerType>(pointeeType);
  } else {
    // For Vulkan we need to extract element from wrapping struct and array.
    Type structElemType =
        cast<spirv::StructType>(pointeeType).getElementType(0);
    if (auto arrayType = dyn_cast<spirv::ArrayType>(structElemType))
      dstType = dyn_cast<IntegerType>(arrayType.getElementType());
    else
      dstType = dyn_cast<IntegerType>(
          cast<spirv::RuntimeArrayType>(structElemType).getElementType());
  }

  if (!dstType)
    return rewriter.notifyMatchFailure(
        storeOp, "failed to determine destination element type");

  int dstBits = static_cast<int>(dstType.getWidth());
  assert(dstBits % srcBits == 0);

  if (srcBits == dstBits) {
    auto memoryRequirements = calculateMemoryRequirements(accessChain, storeOp);
    if (failed(memoryRequirements))
      return rewriter.notifyMatchFailure(
          storeOp, "failed to determine memory requirements");

    auto [memoryAccess, alignment] = *memoryRequirements;
    Value storeVal = adaptor.getValue();
    if (isBool)
      storeVal = castBoolToIntN(loc, storeVal, dstType, rewriter);
    rewriter.replaceOpWithNewOp<spirv::StoreOp>(storeOp, accessChain, storeVal,
                                                memoryAccess, alignment);
    return success();
  }

  // Bitcasting is currently unsupported for Kernel capability /
  // spirv.PtrAccessChain.
  if (typeConverter.allows(spirv::Capability::Kernel))
    return failure();

  auto accessChainOp = accessChain.getDefiningOp<spirv::AccessChainOp>();
  if (!accessChainOp)
    return failure();

  // Since multiple threads may be writing to adjacent elements packed into the
  // same word, the emulation is done with atomic operations. E.g., if the
  // stored value is i8, rewrite the StoreOp to:
  // 1) load a 32-bit integer
  // 2) clear 8 bits in the loaded value
  // 3) set 8 bits in the loaded value
  // 4) store 32-bit value back
  //
  // Step 2 is done with AtomicAnd, and step 3 is done with AtomicOr (of the
  // loaded 32-bit value and the shifted 8-bit store value) as another atomic
  // step.
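  //
  // As an illustrative sketch (the scope and exact constants depend on the
  // memref type), storing an i8 value %v into the second byte of a 32-bit
  // word lowers roughly to:
  //   %old = spirv.AtomicAnd "Device" "AcquireRelease" %ptr, %c0xFFFF00FF
  //   %new = spirv.AtomicOr  "Device" "AcquireRelease" %ptr, %shifted_v
  // where %shifted_v is %v zero-extended to i32, masked, and shifted left by
  // 8 bits.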
  assert(accessChainOp.getIndices().size() == 2);
  Value lastDim = accessChainOp->getOperand(accessChainOp.getNumOperands() - 1);
  Value offset = getOffsetForBitwidth(loc, lastDim, srcBits, dstBits, rewriter);

  // Create a mask to clear the destination. E.g., if it is the second i8 in
  // i32, 0xFFFF00FF is created.
  Value mask = rewriter.createOrFold<spirv::ConstantOp>(
      loc, dstType, rewriter.getIntegerAttr(dstType, (1 << srcBits) - 1));
  Value clearBitsMask = rewriter.createOrFold<spirv::ShiftLeftLogicalOp>(
      loc, dstType, mask, offset);
  clearBitsMask =
      rewriter.createOrFold<spirv::NotOp>(loc, dstType, clearBitsMask);

  Value storeVal = shiftValue(loc, adaptor.getValue(), offset, mask, rewriter);
  Value adjustedPtr = adjustAccessChainForBitwidth(typeConverter, accessChainOp,
                                                   srcBits, dstBits, rewriter);
  std::optional<spirv::Scope> scope = getAtomicOpScope(memrefType);
  if (!scope)
    return rewriter.notifyMatchFailure(storeOp, "atomic scope not available");

  Value result = rewriter.create<spirv::AtomicAndOp>(
      loc, dstType, adjustedPtr, *scope, spirv::MemorySemantics::AcquireRelease,
      clearBitsMask);
  result = rewriter.create<spirv::AtomicOrOp>(
      loc, dstType, adjustedPtr, *scope, spirv::MemorySemantics::AcquireRelease,
      storeVal);

  // The atomic ops above already perform the store, so we can simply erase the
  // original StoreOp. Note that rewriter.replaceOp() doesn't work here because
  // it requires replacement values matching the number of results of the
  // original op.
  rewriter.eraseOp(storeOp);

  assert(accessChainOp.use_empty());
  rewriter.eraseOp(accessChainOp);

  return success();
}

//===----------------------------------------------------------------------===//
// MemorySpaceCastOp
//===----------------------------------------------------------------------===//

LogicalResult MemorySpaceCastOpPattern::matchAndRewrite(
    memref::MemorySpaceCastOp addrCastOp, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Location loc = addrCastOp.getLoc();
  auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
  if (!typeConverter.allows(spirv::Capability::Kernel))
    return rewriter.notifyMatchFailure(
        loc, "address space casts require kernel capability");

  auto sourceType = dyn_cast<MemRefType>(addrCastOp.getSource().getType());
  if (!sourceType)
    return rewriter.notifyMatchFailure(
        loc, "SPIR-V lowering requires ranked memref types");
  auto resultType = cast<MemRefType>(addrCastOp.getResult().getType());

  auto sourceStorageClassAttr =
      dyn_cast_or_null<spirv::StorageClassAttr>(sourceType.getMemorySpace());
  if (!sourceStorageClassAttr)
    return rewriter.notifyMatchFailure(loc, [sourceType](Diagnostic &diag) {
      diag << "source address space " << sourceType.getMemorySpace()
           << " must be a SPIR-V storage class";
    });
  auto resultStorageClassAttr =
      dyn_cast_or_null<spirv::StorageClassAttr>(resultType.getMemorySpace());
  if (!resultStorageClassAttr)
    return rewriter.notifyMatchFailure(loc, [resultType](Diagnostic &diag) {
      diag << "result address space " << resultType.getMemorySpace()
           << " must be a SPIR-V storage class";
    });

  spirv::StorageClass sourceSc = sourceStorageClassAttr.getValue();
  spirv::StorageClass resultSc = resultStorageClassAttr.getValue();

  Value result = adaptor.getSource();
  Type resultPtrType = typeConverter.convertType(resultType);
  if (!resultPtrType)
    return rewriter.notifyMatchFailure(addrCastOp,
                                       "failed to convert memref type");

  Type genericPtrType = resultPtrType;
  // SPIR-V doesn't have a general address space cast operation. Instead, it
  // has conversions to and from generic pointers. To implement the general
  // case, we use a specific-to-generic conversion when the source class is
  // not generic. Then, when the result storage class is not generic, we
  // convert the generic pointer (either the input or an intermediate result)
  // to that class. This also means that we'll need the intermediate generic
  // pointer type if neither the source nor the destination has it.
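  //
  // For example (illustrative), casting a Workgroup pointer to CrossWorkgroup
  // lowers to spirv.PtrCastToGeneric followed by spirv.GenericCastToPtr, going
  // through the intermediate Generic pointer type.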
  if (sourceSc != spirv::StorageClass::Generic &&
      resultSc != spirv::StorageClass::Generic) {
    Type intermediateType =
        MemRefType::get(sourceType.getShape(), sourceType.getElementType(),
                        sourceType.getLayout(),
                        rewriter.getAttr<spirv::StorageClassAttr>(
                            spirv::StorageClass::Generic));
    genericPtrType = typeConverter.convertType(intermediateType);
  }
  if (sourceSc != spirv::StorageClass::Generic) {
    result =
        rewriter.create<spirv::PtrCastToGenericOp>(loc, genericPtrType, result);
  }
  if (resultSc != spirv::StorageClass::Generic) {
    result =
        rewriter.create<spirv::GenericCastToPtrOp>(loc, resultPtrType, result);
  }
  rewriter.replaceOp(addrCastOp, result);
  return success();
}

LogicalResult
StoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
                                ConversionPatternRewriter &rewriter) const {
  auto memrefType = cast<MemRefType>(storeOp.getMemref().getType());
  if (memrefType.getElementType().isSignlessInteger())
    return rewriter.notifyMatchFailure(storeOp, "signless int");
  auto storePtr = spirv::getElementPtr(
      *getTypeConverter<SPIRVTypeConverter>(), memrefType, adaptor.getMemref(),
      adaptor.getIndices(), storeOp.getLoc(), rewriter);

  if (!storePtr)
    return rewriter.notifyMatchFailure(storeOp, "type conversion failed");

  auto memoryRequirements = calculateMemoryRequirements(storePtr, storeOp);
  if (failed(memoryRequirements))
    return rewriter.notifyMatchFailure(
        storeOp, "failed to determine memory requirements");

  auto [memoryAccess, alignment] = *memoryRequirements;
  rewriter.replaceOpWithNewOp<spirv::StoreOp>(
      storeOp, storePtr, adaptor.getValue(), memoryAccess, alignment);
  return success();
}

LogicalResult ReinterpretCastPattern::matchAndRewrite(
    memref::ReinterpretCastOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Value src = adaptor.getSource();
  auto srcType = dyn_cast<spirv::PointerType>(src.getType());

  if (!srcType)
    return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
      diag << "invalid src type " << src.getType();
    });

  const TypeConverter *converter = getTypeConverter();

  auto dstType = converter->convertType<spirv::PointerType>(op.getType());
  if (dstType != srcType)
    return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
      diag << "invalid dst type " << op.getType();
    });

  OpFoldResult offset =
      getMixedValues(adaptor.getStaticOffsets(), adaptor.getOffsets(), rewriter)
          .front();
  if (isZeroInteger(offset)) {
    rewriter.replaceOp(op, src);
    return success();
  }

  Type intType = converter->convertType(rewriter.getIndexType());
  if (!intType)
    return rewriter.notifyMatchFailure(op, "failed to convert index type");

  Location loc = op.getLoc();
  auto offsetValue = [&]() -> Value {
    if (auto val = dyn_cast<Value>(offset))
      return val;

    int64_t attrVal = cast<IntegerAttr>(cast<Attribute>(offset)).getInt();
    Attribute attr = rewriter.getIntegerAttr(intType, attrVal);
    return rewriter.createOrFold<spirv::ConstantOp>(loc, intType, attr);
  }();

  rewriter.replaceOpWithNewOp<spirv::InBoundsPtrAccessChainOp>(
      op, src, offsetValue, ValueRange());
  return success();
}

//===----------------------------------------------------------------------===//
// ExtractAlignedPointerAsIndexOp
//===----------------------------------------------------------------------===//

LogicalResult ExtractAlignedPointerAsIndexOpPattern::matchAndRewrite(
    memref::ExtractAlignedPointerAsIndexOp extractOp, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
  Type indexType = typeConverter.getIndexType();
  rewriter.replaceOpWithNewOp<spirv::ConvertPtrToUOp>(extractOp, indexType,
                                                      adaptor.getSource());
  return success();
}

//===----------------------------------------------------------------------===//
// Pattern population
//===----------------------------------------------------------------------===//

namespace mlir {
void populateMemRefToSPIRVPatterns(const SPIRVTypeConverter &typeConverter,
                                   RewritePatternSet &patterns) {
  patterns
      .add<AllocaOpPattern, AllocOpPattern, AtomicRMWOpPattern,
           DeallocOpPattern, IntLoadOpPattern, IntStoreOpPattern, LoadOpPattern,
           MemorySpaceCastOpPattern, StoreOpPattern, ReinterpretCastPattern,
           CastPattern, ExtractAlignedPointerAsIndexOpPattern>(
          typeConverter, patterns.getContext());
}
} // namespace mlir
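
// A minimal usage sketch (illustrative only, not part of this file): a
// conversion pass would typically pair these patterns with a
// SPIRVTypeConverter and a SPIRVConversionTarget before running dialect
// conversion. The pass name and surrounding boilerplate below are assumptions
// for illustration.
//
//   void ConvertMemRefToSPIRVPass::runOnOperation() {
//     Operation *module = getOperation();
//     spirv::TargetEnvAttr targetAttr =
//         spirv::lookupTargetEnvOrDefault(module);
//     std::unique_ptr<ConversionTarget> target =
//         SPIRVConversionTarget::get(targetAttr);
//     SPIRVTypeConverter typeConverter(targetAttr);
//     RewritePatternSet patterns(module->getContext());
//     populateMemRefToSPIRVPatterns(typeConverter, patterns);
//     if (failed(applyPartialConversion(module, *target, std::move(patterns))))
//       signalPassFailure();
//   }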