//===- MemRefToSPIRV.cpp - MemRef to SPIR-V Patterns ----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements patterns to convert MemRef dialect to SPIR-V dialect.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Visitors.h"
#include "mlir/Support/LogicalResult.h"
#include "llvm/Support/Debug.h"
#include <cassert>
#include <optional>

#define DEBUG_TYPE "memref-to-spirv-pattern"

using namespace mlir;

//===----------------------------------------------------------------------===//
// Utility functions
//===----------------------------------------------------------------------===//

/// Returns the offset of the value in `targetBits` representation.
///
/// `srcIdx` is an index into a 1-D array with each element having `sourceBits`.
/// It's assumed to be non-negative.
///
/// When accessing an element in the array, treating it as having elements of
/// `targetBits`, multiple values are loaded at the same time. The method
/// returns the offset of `srcIdx` within the loaded value. For example, if
/// `sourceBits` is 8 and `targetBits` is 32, the x-th element is located at
/// (x % 4) * 8, because there are four elements in one i32 and each element
/// has 8 bits.
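///
/// A sketch of the SPIR-V ops this emits for that example (illustrative only;
/// `%srcIdx` is the incoming index):
///   %four   = spirv.Constant 4 : i32   // targetBits / sourceBits
///   %eight  = spirv.Constant 8 : i32   // sourceBits
///   %rem    = spirv.UMod %srcIdx, %four : i32
///   %offset = spirv.IMul %rem, %eight : i32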
static Value getOffsetForBitwidth(Location loc, Value srcIdx, int sourceBits,
                                  int targetBits, OpBuilder &builder) {
  assert(targetBits % sourceBits == 0);
  Type type = srcIdx.getType();
  IntegerAttr idxAttr = builder.getIntegerAttr(type, targetBits / sourceBits);
  auto idx = builder.createOrFold<spirv::ConstantOp>(loc, type, idxAttr);
  IntegerAttr srcBitsAttr = builder.getIntegerAttr(type, sourceBits);
  auto srcBitsValue =
      builder.createOrFold<spirv::ConstantOp>(loc, type, srcBitsAttr);
  auto m = builder.createOrFold<spirv::UModOp>(loc, srcIdx, idx);
  return builder.createOrFold<spirv::IMulOp>(loc, type, m, srcBitsValue);
}

/// Returns an adjusted spirv::AccessChainOp. Based on the
/// extension/capabilities, certain integer bitwidths `sourceBits` might not be
/// supported. During conversion, if a memref of an unsupported type is used,
/// loads/stores to this memref need to be modified to use a supported higher
/// bitwidth `targetBits` and extract the required bits. For accessing a 1-D
/// array (spirv.array or spirv.rtarray), the last index is modified to load
/// the bits needed. The extraction of the actual bits needed is handled
/// separately. Note that this only works for a 1-D tensor.
static Value
adjustAccessChainForBitwidth(const SPIRVTypeConverter &typeConverter,
                             spirv::AccessChainOp op, int sourceBits,
                             int targetBits, OpBuilder &builder) {
  assert(targetBits % sourceBits == 0);
  const auto loc = op.getLoc();
  Value lastDim = op->getOperand(op.getNumOperands() - 1);
  Type type = lastDim.getType();
  IntegerAttr attr = builder.getIntegerAttr(type, targetBits / sourceBits);
  auto idx = builder.createOrFold<spirv::ConstantOp>(loc, type, attr);
  auto indices = llvm::to_vector<4>(op.getIndices());
  // There are two elements if this is a 1-D tensor.
  assert(indices.size() == 2);
  indices.back() = builder.createOrFold<spirv::SDivOp>(loc, lastDim, idx);
  Type t = typeConverter.convertType(op.getComponentPtr().getType());
  return builder.create<spirv::AccessChainOp>(loc, t, op.getBasePtr(), indices);
}

/// Casts the given `srcBool` into an integer of `dstType`.
static Value castBoolToIntN(Location loc, Value srcBool, Type dstType,
                            OpBuilder &builder) {
  assert(srcBool.getType().isInteger(1));
  if (dstType.isInteger(1))
    return srcBool;
  Value zero = spirv::ConstantOp::getZero(dstType, loc, builder);
  Value one = spirv::ConstantOp::getOne(dstType, loc, builder);
  return builder.createOrFold<spirv::SelectOp>(loc, dstType, srcBool, one,
                                               zero);
}

/// Returns the `targetBits`-bit value shifted by the given `offset`, cast to
/// the destination type, and masked.
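///
/// A worked example (the concrete values are assumptions for illustration):
/// placing the i8 value 0xAB at bit offset 8 of an i32 word computes
///   (0xAB & 0xFF) << 8 == 0xAB00.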
static Value shiftValue(Location loc, Value value, Value offset, Value mask,
                        OpBuilder &builder) {
  IntegerType dstType = cast<IntegerType>(mask.getType());
  int targetBits = static_cast<int>(dstType.getWidth());
  int valueBits = value.getType().getIntOrFloatBitWidth();
  assert(valueBits <= targetBits);

  if (valueBits == 1) {
    value = castBoolToIntN(loc, value, dstType, builder);
  } else {
    if (valueBits < targetBits) {
      value = builder.create<spirv::UConvertOp>(
          loc, builder.getIntegerType(targetBits), value);
    }

    value = builder.createOrFold<spirv::BitwiseAndOp>(loc, value, mask);
  }
  return builder.createOrFold<spirv::ShiftLeftLogicalOp>(loc, value.getType(),
                                                         value, offset);
}

/// Returns true if the allocations of memref `type` generated from `allocOp`
/// can be lowered to SPIR-V.
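///
/// For example (illustrative types), the following are supported:
///   memref.alloc()  : memref<4x16xf32, #spirv.storage_class<Workgroup>>
///   memref.alloca() : memref<8xi32, #spirv.storage_class<Function>>
/// while dynamic shapes and other storage classes are rejected.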
static bool isAllocationSupported(Operation *allocOp, MemRefType type) {
  if (isa<memref::AllocOp, memref::DeallocOp>(allocOp)) {
    auto sc = dyn_cast_or_null<spirv::StorageClassAttr>(type.getMemorySpace());
    if (!sc || sc.getValue() != spirv::StorageClass::Workgroup)
      return false;
  } else if (isa<memref::AllocaOp>(allocOp)) {
    auto sc = dyn_cast_or_null<spirv::StorageClassAttr>(type.getMemorySpace());
    if (!sc || sc.getValue() != spirv::StorageClass::Function)
      return false;
  } else {
    return false;
  }

  // Currently only support static shape and int or float or vector of int or
  // float element type.
  if (!type.hasStaticShape())
    return false;

  Type elementType = type.getElementType();
  if (auto vecType = dyn_cast<VectorType>(elementType))
    elementType = vecType.getElementType();
  return elementType.isIntOrFloat();
}

/// Returns the scope to use for atomic operations used for emulating store
/// operations of unsupported integer bitwidths, based on the memref type.
/// Returns std::nullopt on failure.
static std::optional<spirv::Scope> getAtomicOpScope(MemRefType type) {
  auto sc = dyn_cast_or_null<spirv::StorageClassAttr>(type.getMemorySpace());
  switch (sc.getValue()) {
  case spirv::StorageClass::StorageBuffer:
    return spirv::Scope::Device;
  case spirv::StorageClass::Workgroup:
    return spirv::Scope::Workgroup;
  default:
    break;
  }
  return {};
}

/// Casts the given `srcInt` into a boolean value.
static Value castIntNToBool(Location loc, Value srcInt, OpBuilder &builder) {
  if (srcInt.getType().isInteger(1))
    return srcInt;

  auto one = spirv::ConstantOp::getOne(srcInt.getType(), loc, builder);
  return builder.createOrFold<spirv::IEqualOp>(loc, srcInt, one);
}

//===----------------------------------------------------------------------===//
// Operation conversion
//===----------------------------------------------------------------------===//

// Note that DRR cannot be used for the patterns in this file: we may need to
// convert type along the way, which requires ConversionPattern. DRR generates
// normal RewritePattern.

namespace {

/// Converts memref.alloca to SPIR-V Function variables.
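///
/// For example (illustrative IR; the exact converted pointer type depends on
/// the type converter):
///   %0 = memref.alloca() : memref<4xf32, #spirv.storage_class<Function>>
/// becomes
///   %0 = spirv.Variable : !spirv.ptr<!spirv.array<4 x f32>, Function>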
class AllocaOpPattern final : public OpConversionPattern<memref::AllocaOp> {
public:
  using OpConversionPattern<memref::AllocaOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::AllocaOp allocaOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Converts an allocation operation to SPIR-V. Currently only supports lowering
/// to Workgroup memory when the size is constant. Note that this pattern needs
/// to be applied in a pass that runs at least at spirv.module scope since it
/// will add global variables into the spirv.module.
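///
/// For example (illustrative IR):
///   %0 = memref.alloc() : memref<16xi32, #spirv.storage_class<Workgroup>>
/// becomes a module-level
///   spirv.GlobalVariable @__workgroup_mem__0 : !spirv.ptr<..., Workgroup>
/// plus, at the original location,
///   %0 = spirv.mlir.addressof @__workgroup_mem__0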
class AllocOpPattern final : public OpConversionPattern<memref::AllocOp> {
public:
  using OpConversionPattern<memref::AllocOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::AllocOp operation, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Converts memref.atomic_rmw operations to SPIR-V atomic operations.
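///
/// For example (an illustrative sketch; the exact assembly format may differ):
///   %0 = memref.atomic_rmw addi %val, %mem[%i]
///       : (i32, memref<8xi32, #spirv.storage_class<StorageBuffer>>) -> i32
/// becomes
///   %0 = spirv.AtomicIAdd "Device" "AcquireRelease" %ptr, %val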
class AtomicRMWOpPattern final
    : public OpConversionPattern<memref::AtomicRMWOp> {
public:
  using OpConversionPattern<memref::AtomicRMWOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::AtomicRMWOp atomicOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Removes a deallocation if it is a supported allocation. Currently only
/// removes deallocations if the memory space is workgroup memory.
class DeallocOpPattern final : public OpConversionPattern<memref::DeallocOp> {
public:
  using OpConversionPattern<memref::DeallocOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::DeallocOp operation, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Converts memref.load to spirv.Load + spirv.AccessChain on integers.
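///
/// When the source bitwidth is unsupported, the load is emulated on a wider
/// type. A sketch, assuming i8 emulated in i32 (illustrative only):
///   %v = memref.load %mem[%i] : memref<?xi8, ...>
/// becomes roughly
///   %ptr  = spirv.AccessChain %base[%c0, %i / 4]
///   %word = spirv.Load %ptr : i32
///   %v    = sign_extend((%word >> ((%i % 4) * 8)) & 0xFF)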
class IntLoadOpPattern final : public OpConversionPattern<memref::LoadOp> {
public:
  using OpConversionPattern<memref::LoadOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Converts memref.load to spirv.Load + spirv.AccessChain.
class LoadOpPattern final : public OpConversionPattern<memref::LoadOp> {
public:
  using OpConversionPattern<memref::LoadOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Converts memref.store to spirv.Store on integers.
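///
/// When the source bitwidth is unsupported, the store is emulated with atomics
/// on a wider type. A sketch, assuming i8 emulated in i32 (illustrative only),
/// storing %v at index %i:
///   %off = (%i % 4) * 8
///   %0 = spirv.AtomicAnd %ptr, ~(0xFF << %off)      // clear the byte
///   %1 = spirv.AtomicOr  %ptr, (%v & 0xFF) << %off  // set the byte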
class IntStoreOpPattern final : public OpConversionPattern<memref::StoreOp> {
public:
  using OpConversionPattern<memref::StoreOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Converts memref.memory_space_cast to the appropriate spirv cast operations.
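///
/// SPIR-V has no direct address-space cast; the conversion goes through the
/// Generic storage class. For example (illustrative IR):
///   %dst = memref.memory_space_cast %src
///       : memref<4xf32, #spirv.storage_class<CrossWorkgroup>>
///         to memref<4xf32, #spirv.storage_class<Function>>
/// becomes roughly
///   %g   = spirv.PtrCastToGeneric %src
///   %dst = spirv.GenericCastToPtr %g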
class MemorySpaceCastOpPattern final
    : public OpConversionPattern<memref::MemorySpaceCastOp> {
public:
  using OpConversionPattern<memref::MemorySpaceCastOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::MemorySpaceCastOp addrCastOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Converts memref.store to spirv.Store.
class StoreOpPattern final : public OpConversionPattern<memref::StoreOp> {
public:
  using OpConversionPattern<memref::StoreOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

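/// Converts memref.reinterpret_cast to a spirv.InBoundsPtrAccessChain applying
/// the (possibly dynamic) offset; zero-offset casts fold away entirely. For
/// example (illustrative IR):
///   %0 = memref.reinterpret_cast %src to offset: [4], ...
/// becomes roughly
///   %0 = spirv.InBoundsPtrAccessChain %src[%c4]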
class ReinterpretCastPattern final
    : public OpConversionPattern<memref::ReinterpretCastOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::ReinterpretCastOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

class CastPattern final : public OpConversionPattern<memref::CastOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::CastOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Value src = adaptor.getSource();
    Type srcType = src.getType();

    const TypeConverter *converter = getTypeConverter();
    Type dstType = converter->convertType(op.getType());
    if (srcType != dstType)
      return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
        diag << "types don't match: " << srcType << " and " << dstType;
      });

    rewriter.replaceOp(op, src);
    return success();
  }
};

} // namespace

//===----------------------------------------------------------------------===//
// AllocaOp
//===----------------------------------------------------------------------===//

LogicalResult
AllocaOpPattern::matchAndRewrite(memref::AllocaOp allocaOp, OpAdaptor adaptor,
                                 ConversionPatternRewriter &rewriter) const {
  MemRefType allocType = allocaOp.getType();
  if (!isAllocationSupported(allocaOp, allocType))
    return rewriter.notifyMatchFailure(allocaOp, "unhandled allocation type");

  // Get the SPIR-V type for the allocation.
  Type spirvType = getTypeConverter()->convertType(allocType);
  if (!spirvType)
    return rewriter.notifyMatchFailure(allocaOp, "type conversion failed");

  rewriter.replaceOpWithNewOp<spirv::VariableOp>(allocaOp, spirvType,
                                                 spirv::StorageClass::Function,
                                                 /*initializer=*/nullptr);
  return success();
}

//===----------------------------------------------------------------------===//
// AllocOp
//===----------------------------------------------------------------------===//

LogicalResult
AllocOpPattern::matchAndRewrite(memref::AllocOp operation, OpAdaptor adaptor,
                                ConversionPatternRewriter &rewriter) const {
  MemRefType allocType = operation.getType();
  if (!isAllocationSupported(operation, allocType))
    return rewriter.notifyMatchFailure(operation, "unhandled allocation type");

  // Get the SPIR-V type for the allocation.
  Type spirvType = getTypeConverter()->convertType(allocType);
  if (!spirvType)
    return rewriter.notifyMatchFailure(operation, "type conversion failed");

  // Insert spirv.GlobalVariable for this allocation.
  Operation *parent =
      SymbolTable::getNearestSymbolTable(operation->getParentOp());
  if (!parent)
    return failure();
  Location loc = operation.getLoc();
  spirv::GlobalVariableOp varOp;
  {
    OpBuilder::InsertionGuard guard(rewriter);
    Block &entryBlock = *parent->getRegion(0).begin();
    rewriter.setInsertionPointToStart(&entryBlock);
    auto varOps = entryBlock.getOps<spirv::GlobalVariableOp>();
    std::string varName =
        std::string("__workgroup_mem__") +
        std::to_string(std::distance(varOps.begin(), varOps.end()));
    varOp = rewriter.create<spirv::GlobalVariableOp>(loc, spirvType, varName,
                                                     /*initializer=*/nullptr);
  }

  // Get pointer to global variable at the current scope.
  rewriter.replaceOpWithNewOp<spirv::AddressOfOp>(operation, varOp);
  return success();
}

//===----------------------------------------------------------------------===//
// AtomicRMWOp
//===----------------------------------------------------------------------===//

LogicalResult
AtomicRMWOpPattern::matchAndRewrite(memref::AtomicRMWOp atomicOp,
                                    OpAdaptor adaptor,
                                    ConversionPatternRewriter &rewriter) const {
  if (isa<FloatType>(atomicOp.getType()))
    return rewriter.notifyMatchFailure(atomicOp,
                                       "unimplemented floating-point case");

  auto memrefType = cast<MemRefType>(atomicOp.getMemref().getType());
  std::optional<spirv::Scope> scope = getAtomicOpScope(memrefType);
  if (!scope)
    return rewriter.notifyMatchFailure(atomicOp,
                                       "unsupported memref memory space");

  auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
  Type resultType = typeConverter.convertType(atomicOp.getType());
  if (!resultType)
    return rewriter.notifyMatchFailure(atomicOp,
                                       "failed to convert result type");

  auto loc = atomicOp.getLoc();
  Value ptr =
      spirv::getElementPtr(typeConverter, memrefType, adaptor.getMemref(),
                           adaptor.getIndices(), loc, rewriter);

  if (!ptr)
    return failure();

#define ATOMIC_CASE(kind, spirvOp)                                             \
  case arith::AtomicRMWKind::kind:                                             \
    rewriter.replaceOpWithNewOp<spirv::spirvOp>(                               \
        atomicOp, resultType, ptr, *scope,                                     \
        spirv::MemorySemantics::AcquireRelease, adaptor.getValue());           \
    break

  switch (atomicOp.getKind()) {
    ATOMIC_CASE(addi, AtomicIAddOp);
    ATOMIC_CASE(maxs, AtomicSMaxOp);
    ATOMIC_CASE(maxu, AtomicUMaxOp);
    ATOMIC_CASE(mins, AtomicSMinOp);
    ATOMIC_CASE(minu, AtomicUMinOp);
    ATOMIC_CASE(ori, AtomicOrOp);
    ATOMIC_CASE(andi, AtomicAndOp);
  default:
    return rewriter.notifyMatchFailure(atomicOp, "unimplemented atomic kind");
  }

#undef ATOMIC_CASE

  return success();
}

//===----------------------------------------------------------------------===//
// DeallocOp
//===----------------------------------------------------------------------===//

LogicalResult
DeallocOpPattern::matchAndRewrite(memref::DeallocOp operation,
                                  OpAdaptor adaptor,
                                  ConversionPatternRewriter &rewriter) const {
  MemRefType deallocType = cast<MemRefType>(operation.getMemref().getType());
  if (!isAllocationSupported(operation, deallocType))
    return rewriter.notifyMatchFailure(operation, "unhandled allocation type");
  rewriter.eraseOp(operation);
  return success();
}

//===----------------------------------------------------------------------===//
// LoadOp
//===----------------------------------------------------------------------===//

struct MemoryRequirements {
  spirv::MemoryAccessAttr memoryAccess;
  IntegerAttr alignment;
};

/// Given an accessed SPIR-V pointer, calculates its alignment requirements, if
/// any.
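///
/// For example (illustrative), a pointer to f32 in the PhysicalStorageBuffer
/// storage class yields the `Aligned` memory access attribute with a 4-byte
/// alignment; pointers in other storage classes need no alignment attribute.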
static FailureOr<MemoryRequirements>
calculateMemoryRequirements(Value accessedPtr, bool isNontemporal) {
  MLIRContext *ctx = accessedPtr.getContext();

  auto memoryAccess = spirv::MemoryAccess::None;
  if (isNontemporal) {
    memoryAccess = spirv::MemoryAccess::Nontemporal;
  }

  auto ptrType = cast<spirv::PointerType>(accessedPtr.getType());
  if (ptrType.getStorageClass() !=
      spirv::StorageClass::PhysicalStorageBuffer) {
    if (memoryAccess == spirv::MemoryAccess::None) {
      return MemoryRequirements{spirv::MemoryAccessAttr{}, IntegerAttr{}};
    }
    return MemoryRequirements{spirv::MemoryAccessAttr::get(ctx, memoryAccess),
                              IntegerAttr{}};
  }

  // PhysicalStorageBuffers require the `Aligned` attribute.
  auto pointeeType = dyn_cast<spirv::ScalarType>(ptrType.getPointeeType());
  if (!pointeeType)
    return failure();

  // For scalar types, the alignment is determined by their size.
  std::optional<int64_t> sizeInBytes = pointeeType.getSizeInBytes();
  if (!sizeInBytes.has_value())
    return failure();

  memoryAccess = memoryAccess | spirv::MemoryAccess::Aligned;
  auto memAccessAttr = spirv::MemoryAccessAttr::get(ctx, memoryAccess);
  auto alignment = IntegerAttr::get(IntegerType::get(ctx, 32), *sizeInBytes);
  return MemoryRequirements{memAccessAttr, alignment};
}

/// Given an accessed SPIR-V pointer and the original memref load/store
/// `memAccess` op, calculates the alignment requirements, if any. Takes into
/// account the alignment attributes applied to the load/store op.
template <class LoadOrStoreOp>
static FailureOr<MemoryRequirements>
calculateMemoryRequirements(Value accessedPtr, LoadOrStoreOp loadOrStoreOp) {
  static_assert(
      llvm::is_one_of<LoadOrStoreOp, memref::LoadOp, memref::StoreOp>::value,
      "Must be called on either memref::LoadOp or memref::StoreOp");

  Operation *memrefAccessOp = loadOrStoreOp.getOperation();
  auto memrefMemAccess = memrefAccessOp->getAttrOfType<spirv::MemoryAccessAttr>(
      spirv::attributeName<spirv::MemoryAccess>());
  auto memrefAlignment =
      memrefAccessOp->getAttrOfType<IntegerAttr>("alignment");
  if (memrefMemAccess && memrefAlignment)
    return MemoryRequirements{memrefMemAccess, memrefAlignment};

  return calculateMemoryRequirements(accessedPtr,
                                     loadOrStoreOp.getNontemporal());
}
LogicalResult
IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
                                  ConversionPatternRewriter &rewriter) const {
  auto loc = loadOp.getLoc();
  auto memrefType = cast<MemRefType>(loadOp.getMemref().getType());
  if (!memrefType.getElementType().isSignlessInteger())
    return failure();

  const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
  Value accessChain =
      spirv::getElementPtr(typeConverter, memrefType, adaptor.getMemref(),
                           adaptor.getIndices(), loc, rewriter);

  if (!accessChain)
    return failure();

  int srcBits = memrefType.getElementType().getIntOrFloatBitWidth();
  bool isBool = srcBits == 1;
  if (isBool)
    srcBits = typeConverter.getOptions().boolNumBits;

  auto pointerType = typeConverter.convertType<spirv::PointerType>(memrefType);
  if (!pointerType)
    return rewriter.notifyMatchFailure(loadOp, "failed to convert memref type");

  Type pointeeType = pointerType.getPointeeType();
  Type dstType;
  if (typeConverter.allows(spirv::Capability::Kernel)) {
    if (auto arrayType = dyn_cast<spirv::ArrayType>(pointeeType))
      dstType = arrayType.getElementType();
    else
      dstType = pointeeType;
  } else {
    // For Vulkan we need to extract element from wrapping struct and array.
    Type structElemType =
        cast<spirv::StructType>(pointeeType).getElementType(0);
    if (auto arrayType = dyn_cast<spirv::ArrayType>(structElemType))
      dstType = arrayType.getElementType();
    else
      dstType = cast<spirv::RuntimeArrayType>(structElemType).getElementType();
  }
  int dstBits = dstType.getIntOrFloatBitWidth();
  assert(dstBits % srcBits == 0);

  // If the rewritten load op has the same bit width, use the loaded value
  // directly.
  if (srcBits == dstBits) {
    auto memoryRequirements = calculateMemoryRequirements(accessChain, loadOp);
    if (failed(memoryRequirements))
      return rewriter.notifyMatchFailure(
          loadOp, "failed to determine memory requirements");

    auto [memoryAccess, alignment] = *memoryRequirements;
    Value loadVal = rewriter.create<spirv::LoadOp>(loc, accessChain,
                                                   memoryAccess, alignment);
    if (isBool)
      loadVal = castIntNToBool(loc, loadVal, rewriter);
    rewriter.replaceOp(loadOp, loadVal);
    return success();
  }

  // Bitcasting is currently unsupported for Kernel capability /
  // spirv.PtrAccessChain.
  if (typeConverter.allows(spirv::Capability::Kernel))
    return failure();

  auto accessChainOp = accessChain.getDefiningOp<spirv::AccessChainOp>();
  if (!accessChainOp)
    return failure();

  // Assume that getElementPtr() returns a linearized access. Even for a
  // scalar, the method still returns a linearized access. If the access is
  // not linearized, there will be offset issues.
  assert(accessChainOp.getIndices().size() == 2);
  Value adjustedPtr = adjustAccessChainForBitwidth(typeConverter, accessChainOp,
                                                   srcBits, dstBits, rewriter);
  auto memoryRequirements = calculateMemoryRequirements(adjustedPtr, loadOp);
  if (failed(memoryRequirements))
    return rewriter.notifyMatchFailure(
        loadOp, "failed to determine memory requirements");

  auto [memoryAccess, alignment] = *memoryRequirements;
  Value spvLoadOp = rewriter.create<spirv::LoadOp>(loc, dstType, adjustedPtr,
                                                   memoryAccess, alignment);

  // Shift the bits to the rightmost.
  // ____XXXX________ -> ____________XXXX
  Value lastDim = accessChainOp->getOperand(accessChainOp.getNumOperands() - 1);
  Value offset = getOffsetForBitwidth(loc, lastDim, srcBits, dstBits, rewriter);
  Value result = rewriter.createOrFold<spirv::ShiftRightArithmeticOp>(
      loc, spvLoadOp.getType(), spvLoadOp, offset);

  // Apply the mask to extract the corresponding bits.
  Value mask = rewriter.createOrFold<spirv::ConstantOp>(
      loc, dstType, rewriter.getIntegerAttr(dstType, (1 << srcBits) - 1));
  result =
      rewriter.createOrFold<spirv::BitwiseAndOp>(loc, dstType, result, mask);

  // Apply sign extension on the loaded value unconditionally. The signedness
  // semantics are carried in the operator itself; we rely on other patterns
  // to handle the casting.
  IntegerAttr shiftValueAttr =
      rewriter.getIntegerAttr(dstType, dstBits - srcBits);
  Value shiftValue =
      rewriter.createOrFold<spirv::ConstantOp>(loc, dstType, shiftValueAttr);
  result = rewriter.createOrFold<spirv::ShiftLeftLogicalOp>(loc, dstType,
                                                            result, shiftValue);
  result = rewriter.createOrFold<spirv::ShiftRightArithmeticOp>(
      loc, dstType, result, shiftValue);

  rewriter.replaceOp(loadOp, result);

  assert(accessChainOp.use_empty());
  rewriter.eraseOp(accessChainOp);

  return success();
}

LogicalResult
LoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
                               ConversionPatternRewriter &rewriter) const {
  auto memrefType = cast<MemRefType>(loadOp.getMemref().getType());
  if (memrefType.getElementType().isSignlessInteger())
    return failure();
  Value loadPtr = spirv::getElementPtr(
      *getTypeConverter<SPIRVTypeConverter>(), memrefType, adaptor.getMemref(),
      adaptor.getIndices(), loadOp.getLoc(), rewriter);

  if (!loadPtr)
    return failure();

  auto memoryRequirements = calculateMemoryRequirements(loadPtr, loadOp);
  if (failed(memoryRequirements))
    return rewriter.notifyMatchFailure(
        loadOp, "failed to determine memory requirements");

  auto [memoryAccess, alignment] = *memoryRequirements;
  rewriter.replaceOpWithNewOp<spirv::LoadOp>(loadOp, loadPtr, memoryAccess,
                                             alignment);
  return success();
}

LogicalResult
IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
                                   ConversionPatternRewriter &rewriter) const {
  auto memrefType = cast<MemRefType>(storeOp.getMemref().getType());
  if (!memrefType.getElementType().isSignlessInteger())
    return rewriter.notifyMatchFailure(storeOp,
                                       "element type is not a signless int");

  auto loc = storeOp.getLoc();
  auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
  Value accessChain =
      spirv::getElementPtr(typeConverter, memrefType, adaptor.getMemref(),
                           adaptor.getIndices(), loc, rewriter);

  if (!accessChain)
    return rewriter.notifyMatchFailure(
        storeOp, "failed to convert element pointer type");

  int srcBits = memrefType.getElementType().getIntOrFloatBitWidth();

  bool isBool = srcBits == 1;
  if (isBool)
    srcBits = typeConverter.getOptions().boolNumBits;

  auto pointerType = typeConverter.convertType<spirv::PointerType>(memrefType);
  if (!pointerType)
    return rewriter.notifyMatchFailure(storeOp,
                                       "failed to convert memref type");

  Type pointeeType = pointerType.getPointeeType();
  IntegerType dstType;
  if (typeConverter.allows(spirv::Capability::Kernel)) {
    if (auto arrayType = dyn_cast<spirv::ArrayType>(pointeeType))
      dstType = dyn_cast<IntegerType>(arrayType.getElementType());
    else
      dstType = dyn_cast<IntegerType>(pointeeType);
  } else {
    // For Vulkan we need to extract element from wrapping struct and array.
    Type structElemType =
        cast<spirv::StructType>(pointeeType).getElementType(0);
    if (auto arrayType = dyn_cast<spirv::ArrayType>(structElemType))
      dstType = dyn_cast<IntegerType>(arrayType.getElementType());
    else
      dstType = dyn_cast<IntegerType>(
          cast<spirv::RuntimeArrayType>(structElemType).getElementType());
  }

  if (!dstType)
    return rewriter.notifyMatchFailure(
        storeOp, "failed to determine destination element type");

  int dstBits = static_cast<int>(dstType.getWidth());
  assert(dstBits % srcBits == 0);

  if (srcBits == dstBits) {
    auto memoryRequirements = calculateMemoryRequirements(accessChain, storeOp);
    if (failed(memoryRequirements))
      return rewriter.notifyMatchFailure(
          storeOp, "failed to determine memory requirements");

    auto [memoryAccess, alignment] = *memoryRequirements;
    Value storeVal = adaptor.getValue();
    if (isBool)
      storeVal = castBoolToIntN(loc, storeVal, dstType, rewriter);
    rewriter.replaceOpWithNewOp<spirv::StoreOp>(storeOp, accessChain, storeVal,
                                                memoryAccess, alignment);
    return success();
  }

  // Bitcasting is currently unsupported for Kernel capability /
  // spirv.PtrAccessChain.
  if (typeConverter.allows(spirv::Capability::Kernel))
    return failure();

  auto accessChainOp = accessChain.getDefiningOp<spirv::AccessChainOp>();
  if (!accessChainOp)
    return failure();

  // Since multiple threads may be writing to adjacent elements within the
  // same wider word, the emulation is done with atomic operations. E.g., if
  // the stored value is i8, rewrite the StoreOp to:
  // 1) load a 32-bit integer
  // 2) clear 8 bits in the loaded value
  // 3) set 8 bits in the loaded value
  // 4) store the 32-bit value back
  //
  // Step 2 is done with AtomicAnd, and step 3 is done with AtomicOr (of the
  // loaded 32-bit value and the shifted 8-bit store value) as another atomic
  // step.
  assert(accessChainOp.getIndices().size() == 2);
  Value lastDim = accessChainOp->getOperand(accessChainOp.getNumOperands() - 1);
  Value offset = getOffsetForBitwidth(loc, lastDim, srcBits, dstBits, rewriter);

  // Create a mask to clear the destination. E.g., if it is the second i8 in
  // i32, 0xFFFF00FF is created.
  Value mask = rewriter.createOrFold<spirv::ConstantOp>(
      loc, dstType, rewriter.getIntegerAttr(dstType, (1 << srcBits) - 1));
  Value clearBitsMask = rewriter.createOrFold<spirv::ShiftLeftLogicalOp>(
      loc, dstType, mask, offset);
  clearBitsMask =
      rewriter.createOrFold<spirv::NotOp>(loc, dstType, clearBitsMask);

  Value storeVal = shiftValue(loc, adaptor.getValue(), offset, mask, rewriter);
  Value adjustedPtr = adjustAccessChainForBitwidth(typeConverter, accessChainOp,
                                                   srcBits, dstBits, rewriter);
  std::optional<spirv::Scope> scope = getAtomicOpScope(memrefType);
  if (!scope)
    return rewriter.notifyMatchFailure(storeOp, "atomic scope not available");

  Value result = rewriter.create<spirv::AtomicAndOp>(
      loc, dstType, adjustedPtr, *scope, spirv::MemorySemantics::AcquireRelease,
      clearBitsMask);
  result = rewriter.create<spirv::AtomicOrOp>(
      loc, dstType, adjustedPtr, *scope, spirv::MemorySemantics::AcquireRelease,
      storeVal);

  // The atomic ops above already perform the store, so we can just remove the
  // original StoreOp. Note that rewriter.replaceOp() doesn't work here
  // because it requires the number of results to be the same.
  rewriter.eraseOp(storeOp);

  assert(accessChainOp.use_empty());
  rewriter.eraseOp(accessChainOp);

  return success();
}

//===----------------------------------------------------------------------===//
// MemorySpaceCastOp
//===----------------------------------------------------------------------===//

LogicalResult MemorySpaceCastOpPattern::matchAndRewrite(
    memref::MemorySpaceCastOp addrCastOp, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Location loc = addrCastOp.getLoc();
  auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
  if (!typeConverter.allows(spirv::Capability::Kernel))
    return rewriter.notifyMatchFailure(
        loc, "address space casts require kernel capability");

  auto sourceType = dyn_cast<MemRefType>(addrCastOp.getSource().getType());
  if (!sourceType)
    return rewriter.notifyMatchFailure(
        loc, "SPIR-V lowering requires ranked memref types");
  auto resultType = cast<MemRefType>(addrCastOp.getResult().getType());

  auto sourceStorageClassAttr =
      dyn_cast_or_null<spirv::StorageClassAttr>(sourceType.getMemorySpace());
  if (!sourceStorageClassAttr)
    return rewriter.notifyMatchFailure(loc, [sourceType](Diagnostic &diag) {
      diag << "source address space " << sourceType.getMemorySpace()
           << " must be a SPIR-V storage class";
    });
  auto resultStorageClassAttr =
      dyn_cast_or_null<spirv::StorageClassAttr>(resultType.getMemorySpace());
  if (!resultStorageClassAttr)
    return rewriter.notifyMatchFailure(loc, [resultType](Diagnostic &diag) {
      diag << "result address space " << resultType.getMemorySpace()
           << " must be a SPIR-V storage class";
    });

  spirv::StorageClass sourceSc = sourceStorageClassAttr.getValue();
  spirv::StorageClass resultSc = resultStorageClassAttr.getValue();

  Value result = adaptor.getSource();
  Type resultPtrType = typeConverter.convertType(resultType);
  if (!resultPtrType)
    return rewriter.notifyMatchFailure(addrCastOp,
                                       "failed to convert memref type");

  Type genericPtrType = resultPtrType;
  // SPIR-V doesn't have a general address space cast operation. Instead, it
  // has conversions to and from generic pointers. To implement the general
  // case, we use specific-to-generic conversions when the source class is not
  // generic. Then, when the result storage class is not generic, we convert
  // the generic pointer (either the input or an intermediate result) to that
  // class. This also means that we'll need the intermediate generic pointer
  // type if neither the source nor the destination have it.
  if (sourceSc != spirv::StorageClass::Generic &&
      resultSc != spirv::StorageClass::Generic) {
    Type intermediateType =
        MemRefType::get(sourceType.getShape(), sourceType.getElementType(),
                        sourceType.getLayout(),
                        rewriter.getAttr<spirv::StorageClassAttr>(
                            spirv::StorageClass::Generic));
    genericPtrType = typeConverter.convertType(intermediateType);
  }
  if (sourceSc != spirv::StorageClass::Generic) {
    result =
        rewriter.create<spirv::PtrCastToGenericOp>(loc, genericPtrType, result);
  }
  if (resultSc != spirv::StorageClass::Generic) {
    result =
        rewriter.create<spirv::GenericCastToPtrOp>(loc, resultPtrType, result);
  }
  rewriter.replaceOp(addrCastOp, result);
  return success();
}

LogicalResult
StoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
                                ConversionPatternRewriter &rewriter) const {
  auto memrefType = cast<MemRefType>(storeOp.getMemref().getType());
  if (memrefType.getElementType().isSignlessInteger())
    return rewriter.notifyMatchFailure(storeOp, "signless int");
  auto storePtr = spirv::getElementPtr(
      *getTypeConverter<SPIRVTypeConverter>(), memrefType, adaptor.getMemref(),
      adaptor.getIndices(), storeOp.getLoc(), rewriter);

  if (!storePtr)
    return rewriter.notifyMatchFailure(storeOp, "type conversion failed");

  auto memoryRequirements = calculateMemoryRequirements(storePtr, storeOp);
  if (failed(memoryRequirements))
    return rewriter.notifyMatchFailure(
        storeOp, "failed to determine memory requirements");

  auto [memoryAccess, alignment] = *memoryRequirements;
  rewriter.replaceOpWithNewOp<spirv::StoreOp>(
      storeOp, storePtr, adaptor.getValue(), memoryAccess, alignment);
  return success();
}

LogicalResult ReinterpretCastPattern::matchAndRewrite(
    memref::ReinterpretCastOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Value src = adaptor.getSource();
  auto srcType = dyn_cast<spirv::PointerType>(src.getType());

  if (!srcType)
    return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
      diag << "invalid src type " << src.getType();
    });

  const TypeConverter *converter = getTypeConverter();

  auto dstType = converter->convertType<spirv::PointerType>(op.getType());
  if (dstType != srcType)
    return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
      diag << "invalid dst type " << op.getType();
    });

  OpFoldResult offset =
      getMixedValues(adaptor.getStaticOffsets(), adaptor.getOffsets(), rewriter)
          .front();
  if (isConstantIntValue(offset, 0)) {
    rewriter.replaceOp(op, src);
    return success();
  }

  Type intType = converter->convertType(rewriter.getIndexType());
  if (!intType)
    return rewriter.notifyMatchFailure(op, "failed to convert index type");

  Location loc = op.getLoc();
  auto offsetValue = [&]() -> Value {
    if (auto val = dyn_cast<Value>(offset))
      return val;

    int64_t attrVal = cast<IntegerAttr>(offset.get<Attribute>()).getInt();
    Attribute attr = rewriter.getIntegerAttr(intType, attrVal);
    return rewriter.createOrFold<spirv::ConstantOp>(loc, intType, attr);
  }();

  rewriter.replaceOpWithNewOp<spirv::InBoundsPtrAccessChainOp>(
      op, src, offsetValue, std::nullopt);
  return success();
}

//===----------------------------------------------------------------------===//
// Pattern population
//===----------------------------------------------------------------------===//

namespace mlir {
void populateMemRefToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
                                   RewritePatternSet &patterns) {
  patterns.add<AllocaOpPattern, AllocOpPattern, AtomicRMWOpPattern,
               DeallocOpPattern, IntLoadOpPattern, IntStoreOpPattern,
               LoadOpPattern, MemorySpaceCastOpPattern, StoreOpPattern,
               ReinterpretCastPattern, CastPattern>(typeConverter,
                                                    patterns.getContext());
}
} // namespace mlir
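
// A minimal sketch of how a conversion pass might drive these patterns (an
// assumption for illustration, not part of this file; the pass class is
// hypothetical):
//
//   void ConvertMemRefToSPIRVPass::runOnOperation() {
//     Operation *op = getOperation();
//     spirv::TargetEnvAttr targetAttr = spirv::lookupTargetEnvOrDefault(op);
//     std::unique_ptr<ConversionTarget> target =
//         SPIRVConversionTarget::get(targetAttr);
//     SPIRVTypeConverter typeConverter(targetAttr);
//     RewritePatternSet patterns(op->getContext());
//     populateMemRefToSPIRVPatterns(typeConverter, patterns);
//     if (failed(applyPartialConversion(op, *target, std::move(patterns))))
//       signalPassFailure();
//   }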