//===- MemRefToSPIRV.cpp - MemRef to SPIR-V Patterns ----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements patterns to convert MemRef dialect to SPIR-V dialect.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Visitors.h"
#include "llvm/Support/Debug.h"
#include <cassert>
#include <optional>

#define DEBUG_TYPE "memref-to-spirv-pattern"

using namespace mlir;

//===----------------------------------------------------------------------===//
// Utility functions
//===----------------------------------------------------------------------===//

/// Returns the offset of the value in `targetBits` representation.
///
/// `srcIdx` is an index into a 1-D array with each element having `sourceBits`.
/// It's assumed to be non-negative.
///
/// When accessing an element in the array while treating it as having elements
/// of `targetBits`, multiple values are loaded at the same time. This method
/// returns the offset at which `srcIdx` is located within the wider value. For
/// example, if `sourceBits` equals 8 and `targetBits` equals 32, the x-th
/// element is located at (x % 4) * 8, because one i32 holds four 8-bit
/// elements.
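///
/// A sketch of the ops this emits for sourceBits = 8, targetBits = 32, and an
/// i32 index type (illustrative SSA names, not taken from real output):
///   %c4  = spirv.Constant 4 : i32   // targetBits / sourceBits
///   %c8  = spirv.Constant 8 : i32   // sourceBits
///   %rem = spirv.UMod %srcIdx, %c4 : i32
///   %off = spirv.IMul %rem, %c8 : i32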
static Value getOffsetForBitwidth(Location loc, Value srcIdx, int sourceBits,
                                  int targetBits, OpBuilder &builder) {
  assert(targetBits % sourceBits == 0);
  Type type = srcIdx.getType();
  IntegerAttr idxAttr = builder.getIntegerAttr(type, targetBits / sourceBits);
  auto idx = builder.createOrFold<spirv::ConstantOp>(loc, type, idxAttr);
  IntegerAttr srcBitsAttr = builder.getIntegerAttr(type, sourceBits);
  auto srcBitsValue =
      builder.createOrFold<spirv::ConstantOp>(loc, type, srcBitsAttr);
  auto m = builder.createOrFold<spirv::UModOp>(loc, srcIdx, idx);
  return builder.createOrFold<spirv::IMulOp>(loc, type, m, srcBitsValue);
}

/// Returns an adjusted spirv::AccessChainOp. Based on the
/// extension/capabilities, certain integer bitwidths `sourceBits` might not be
/// supported. During conversion, if a memref of an unsupported type is used,
/// loads/stores to this memref need to be modified to use a supported higher
/// bitwidth `targetBits` and to extract the required bits. For accesses into a
/// 1-D array (spirv.array or spirv.rtarray), the last index is modified to
/// load the bits needed. The extraction of the actual bits needed is handled
/// separately. Note that this only works for a 1-D tensor.
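///
/// For example, with sourceBits = 8 and targetBits = 32, index 7 into the
/// emulated i8 array becomes index 7 / 4 = 1 into the backing i32 array; the
/// bit offset within that word is computed separately by
/// getOffsetForBitwidth().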
static Value
adjustAccessChainForBitwidth(const SPIRVTypeConverter &typeConverter,
                             spirv::AccessChainOp op, int sourceBits,
                             int targetBits, OpBuilder &builder) {
  assert(targetBits % sourceBits == 0);
  const auto loc = op.getLoc();
  Value lastDim = op->getOperand(op.getNumOperands() - 1);
  Type type = lastDim.getType();
  IntegerAttr attr = builder.getIntegerAttr(type, targetBits / sourceBits);
  auto idx = builder.createOrFold<spirv::ConstantOp>(loc, type, attr);
  auto indices = llvm::to_vector<4>(op.getIndices());
  // There are two elements if this is a 1-D tensor.
  assert(indices.size() == 2);
  indices.back() = builder.createOrFold<spirv::SDivOp>(loc, lastDim, idx);
  Type t = typeConverter.convertType(op.getComponentPtr().getType());
  return builder.create<spirv::AccessChainOp>(loc, t, op.getBasePtr(), indices);
}

/// Casts the given `srcBool` into an integer of `dstType`.
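///
/// A sketch of the ops this emits for dstType = i32 (illustrative names):
///   %zero = spirv.Constant 0 : i32
///   %one  = spirv.Constant 1 : i32
///   %int  = spirv.Select %srcBool, %one, %zero : i1, i32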
static Value castBoolToIntN(Location loc, Value srcBool, Type dstType,
                            OpBuilder &builder) {
  assert(srcBool.getType().isInteger(1));
  if (dstType.isInteger(1))
    return srcBool;
  Value zero = spirv::ConstantOp::getZero(dstType, loc, builder);
  Value one = spirv::ConstantOp::getOne(dstType, loc, builder);
  return builder.createOrFold<spirv::SelectOp>(loc, dstType, srcBool, one,
                                               zero);
}

/// Returns the `targetBits`-bit value obtained by casting `value` to the
/// destination type, masking it, and shifting it left by the given `offset`.
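///
/// A sketch for an i8 `value` and an i32 `mask` (illustrative SSA names):
///   %ext     = spirv.UConvert %value : i8 to i32
///   %masked  = spirv.BitwiseAnd %ext, %mask : i32
///   %shifted = spirv.ShiftLeftLogical %masked, %offset : i32, i32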
static Value shiftValue(Location loc, Value value, Value offset, Value mask,
                        OpBuilder &builder) {
  IntegerType dstType = cast<IntegerType>(mask.getType());
  int targetBits = static_cast<int>(dstType.getWidth());
  int valueBits = value.getType().getIntOrFloatBitWidth();
  assert(valueBits <= targetBits);

  if (valueBits == 1) {
    value = castBoolToIntN(loc, value, dstType, builder);
  } else {
    if (valueBits < targetBits) {
      value = builder.create<spirv::UConvertOp>(
          loc, builder.getIntegerType(targetBits), value);
    }

    value = builder.createOrFold<spirv::BitwiseAndOp>(loc, value, mask);
  }
  return builder.createOrFold<spirv::ShiftLeftLogicalOp>(loc, value.getType(),
                                                         value, offset);
}

/// Returns true if the allocations of memref `type` generated from `allocOp`
/// can be lowered to SPIR-V.
static bool isAllocationSupported(Operation *allocOp, MemRefType type) {
  if (isa<memref::AllocOp, memref::DeallocOp>(allocOp)) {
    auto sc = dyn_cast_or_null<spirv::StorageClassAttr>(type.getMemorySpace());
    if (!sc || sc.getValue() != spirv::StorageClass::Workgroup)
      return false;
  } else if (isa<memref::AllocaOp>(allocOp)) {
    auto sc = dyn_cast_or_null<spirv::StorageClassAttr>(type.getMemorySpace());
    if (!sc || sc.getValue() != spirv::StorageClass::Function)
      return false;
  } else {
    return false;
  }

  // Currently we only support static shapes and element types that are ints,
  // floats, or vectors of ints or floats.
  if (!type.hasStaticShape())
    return false;

  Type elementType = type.getElementType();
  if (auto vecType = dyn_cast<VectorType>(elementType))
    elementType = vecType.getElementType();
  return elementType.isIntOrFloat();
}

/// Returns the scope to use for atomic operations emulating store operations
/// of unsupported integer bitwidths, based on the memref type. Returns
/// std::nullopt on failure.
static std::optional<spirv::Scope> getAtomicOpScope(MemRefType type) {
  auto sc = dyn_cast_or_null<spirv::StorageClassAttr>(type.getMemorySpace());
  if (!sc)
    return std::nullopt;
  switch (sc.getValue()) {
  case spirv::StorageClass::StorageBuffer:
    return spirv::Scope::Device;
  case spirv::StorageClass::Workgroup:
    return spirv::Scope::Workgroup;
  default:
    break;
  }
  return std::nullopt;
}

/// Casts the given `srcInt` into a boolean value by comparing it against zero.
static Value castIntNToBool(Location loc, Value srcInt, OpBuilder &builder) {
  if (srcInt.getType().isInteger(1))
    return srcInt;

  auto zero = spirv::ConstantOp::getZero(srcInt.getType(), loc, builder);
  return builder.createOrFold<spirv::INotEqualOp>(loc, srcInt, zero);
}

//===----------------------------------------------------------------------===//
// Operation conversion
//===----------------------------------------------------------------------===//

// Note that DRR cannot be used for the patterns in this file: we may need to
// convert type along the way, which requires ConversionPattern. DRR generates
// normal RewritePattern.

namespace {

/// Converts memref.alloca to SPIR-V Function variables.
class AllocaOpPattern final : public OpConversionPattern<memref::AllocaOp> {
public:
  using OpConversionPattern<memref::AllocaOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::AllocaOp allocaOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Converts an allocation operation to SPIR-V. Currently only supports
/// lowering to Workgroup memory when the size is constant. Note that this
/// pattern needs to be applied in a pass that runs at least at spirv.module
/// scope since it will add global variables into the spirv.module.
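///
/// A hedged sketch of the rewrite (the exact converted pointee type depends on
/// the SPIRVTypeConverter options):
///
///   %0 = memref.alloc() : memref<4xf32, #spirv.storage_class<Workgroup>>
///
/// becomes roughly
///
///   spirv.GlobalVariable @__workgroup_mem__0 : !spirv.ptr<..., Workgroup>
///   %0 = spirv.mlir.addressof @__workgroup_mem__0 : !spirv.ptr<..., Workgroup>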
class AllocOpPattern final : public OpConversionPattern<memref::AllocOp> {
public:
  using OpConversionPattern<memref::AllocOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::AllocOp operation, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Converts memref.atomic_rmw operations to SPIR-V atomic operations.
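///
/// A sketch (this pattern always uses AcquireRelease memory semantics; the
/// scope is derived from the memref's storage class):
///
///   %0 = memref.atomic_rmw addi %val, %mem[%i]
///          : (i32, memref<8xi32, #spirv.storage_class<StorageBuffer>>) -> i32
///
/// becomes roughly
///
///   %ptr = spirv.AccessChain ...
///   %0 = spirv.AtomicIAdd <Device> <AcquireRelease> %ptr, %val
///          : !spirv.ptr<i32, StorageBuffer>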
class AtomicRMWOpPattern final
    : public OpConversionPattern<memref::AtomicRMWOp> {
public:
  using OpConversionPattern<memref::AtomicRMWOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::AtomicRMWOp atomicOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Removes a deallocation if it is a supported allocation. Currently only
/// removes deallocations if the memory space is workgroup memory.
class DeallocOpPattern final : public OpConversionPattern<memref::DeallocOp> {
public:
  using OpConversionPattern<memref::DeallocOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::DeallocOp operation, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Converts memref.load to spirv.Load + spirv.AccessChain on integers.
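///
/// When the element bitwidth is unsupported on the target (e.g. i8 without
/// 8-bit integer support), the load is emulated on a wider integer. A sketch
/// for i8 emulated on i32 (illustrative names):
///   %ptr    = spirv.AccessChain ...                 // last index divided by 4
///   %word   = spirv.Load "StorageBuffer" %ptr : i32
///   %shr    = spirv.ShiftRightArithmetic %word, %offset : i32, i32
///   %masked = spirv.BitwiseAnd %shr, %mask : i32    // mask = 0xFF
///   // followed by shift left/right by 24 to sign-extend the byte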
class IntLoadOpPattern final : public OpConversionPattern<memref::LoadOp> {
public:
  using OpConversionPattern<memref::LoadOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Converts memref.load to spirv.Load + spirv.AccessChain.
class LoadOpPattern final : public OpConversionPattern<memref::LoadOp> {
public:
  using OpConversionPattern<memref::LoadOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Converts memref.store to spirv.Store on integers.
class IntStoreOpPattern final : public OpConversionPattern<memref::StoreOp> {
public:
  using OpConversionPattern<memref::StoreOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Converts memref.memory_space_cast to the appropriate spirv cast operations.
class MemorySpaceCastOpPattern final
    : public OpConversionPattern<memref::MemorySpaceCastOp> {
public:
  using OpConversionPattern<memref::MemorySpaceCastOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::MemorySpaceCastOp addrCastOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Converts memref.store to spirv.Store.
class StoreOpPattern final : public OpConversionPattern<memref::StoreOp> {
public:
  using OpConversionPattern<memref::StoreOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

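/// Converts memref.reinterpret_cast into the pointer arithmetic it implies:
/// casts with a zero offset fold away to the converted source pointer, and
/// non-zero offsets become a spirv.InBoundsPtrAccessChain.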
class ReinterpretCastPattern final
    : public OpConversionPattern<memref::ReinterpretCastOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::ReinterpretCastOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

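/// Folds memref.cast ops away when the source and converted result types
/// match; all other casts are rejected.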
class CastPattern final : public OpConversionPattern<memref::CastOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::CastOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Value src = adaptor.getSource();
    Type srcType = src.getType();

    const TypeConverter *converter = getTypeConverter();
    Type dstType = converter->convertType(op.getType());
    if (srcType != dstType)
      return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
        diag << "types don't match: " << srcType << " and " << dstType;
      });

    rewriter.replaceOp(op, src);
    return success();
  }
};

/// Converts memref.extract_aligned_pointer_as_index to spirv.ConvertPtrToU.
class ExtractAlignedPointerAsIndexOpPattern final
    : public OpConversionPattern<memref::ExtractAlignedPointerAsIndexOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::ExtractAlignedPointerAsIndexOp extractOp,
                  OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

//===----------------------------------------------------------------------===//
// AllocaOp
//===----------------------------------------------------------------------===//

LogicalResult
AllocaOpPattern::matchAndRewrite(memref::AllocaOp allocaOp, OpAdaptor adaptor,
                                 ConversionPatternRewriter &rewriter) const {
  MemRefType allocType = allocaOp.getType();
  if (!isAllocationSupported(allocaOp, allocType))
    return rewriter.notifyMatchFailure(allocaOp, "unhandled allocation type");

  // Get the SPIR-V type for the allocation.
  Type spirvType = getTypeConverter()->convertType(allocType);
  if (!spirvType)
    return rewriter.notifyMatchFailure(allocaOp, "type conversion failed");

  rewriter.replaceOpWithNewOp<spirv::VariableOp>(allocaOp, spirvType,
                                                 spirv::StorageClass::Function,
                                                 /*initializer=*/nullptr);
  return success();
}

//===----------------------------------------------------------------------===//
// AllocOp
//===----------------------------------------------------------------------===//

LogicalResult
AllocOpPattern::matchAndRewrite(memref::AllocOp operation, OpAdaptor adaptor,
                                ConversionPatternRewriter &rewriter) const {
  MemRefType allocType = operation.getType();
  if (!isAllocationSupported(operation, allocType))
    return rewriter.notifyMatchFailure(operation, "unhandled allocation type");

  // Get the SPIR-V type for the allocation.
  Type spirvType = getTypeConverter()->convertType(allocType);
  if (!spirvType)
    return rewriter.notifyMatchFailure(operation, "type conversion failed");

  // Insert spirv.GlobalVariable for this allocation.
  Operation *parent =
      SymbolTable::getNearestSymbolTable(operation->getParentOp());
  if (!parent)
    return failure();
  Location loc = operation.getLoc();
  spirv::GlobalVariableOp varOp;
  {
    OpBuilder::InsertionGuard guard(rewriter);
    Block &entryBlock = *parent->getRegion(0).begin();
    rewriter.setInsertionPointToStart(&entryBlock);
    auto varOps = entryBlock.getOps<spirv::GlobalVariableOp>();
    std::string varName =
        std::string("__workgroup_mem__") +
        std::to_string(std::distance(varOps.begin(), varOps.end()));
    varOp = rewriter.create<spirv::GlobalVariableOp>(loc, spirvType, varName,
                                                     /*initializer=*/nullptr);
  }

  // Get pointer to global variable at the current scope.
  rewriter.replaceOpWithNewOp<spirv::AddressOfOp>(operation, varOp);
  return success();
}

//===----------------------------------------------------------------------===//
// AtomicRMWOp
//===----------------------------------------------------------------------===//

LogicalResult
AtomicRMWOpPattern::matchAndRewrite(memref::AtomicRMWOp atomicOp,
                                    OpAdaptor adaptor,
                                    ConversionPatternRewriter &rewriter) const {
  if (isa<FloatType>(atomicOp.getType()))
    return rewriter.notifyMatchFailure(atomicOp,
                                       "unimplemented floating-point case");

  auto memrefType = cast<MemRefType>(atomicOp.getMemref().getType());
  std::optional<spirv::Scope> scope = getAtomicOpScope(memrefType);
  if (!scope)
    return rewriter.notifyMatchFailure(atomicOp,
                                       "unsupported memref memory space");

  auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
  Type resultType = typeConverter.convertType(atomicOp.getType());
  if (!resultType)
    return rewriter.notifyMatchFailure(atomicOp,
                                       "failed to convert result type");

  auto loc = atomicOp.getLoc();
  Value ptr =
      spirv::getElementPtr(typeConverter, memrefType, adaptor.getMemref(),
                           adaptor.getIndices(), loc, rewriter);

  if (!ptr)
    return failure();

#define ATOMIC_CASE(kind, spirvOp)                                             \
  case arith::AtomicRMWKind::kind:                                             \
    rewriter.replaceOpWithNewOp<spirv::spirvOp>(                               \
        atomicOp, resultType, ptr, *scope,                                     \
        spirv::MemorySemantics::AcquireRelease, adaptor.getValue());           \
    break

  switch (atomicOp.getKind()) {
    ATOMIC_CASE(addi, AtomicIAddOp);
    ATOMIC_CASE(maxs, AtomicSMaxOp);
    ATOMIC_CASE(maxu, AtomicUMaxOp);
    ATOMIC_CASE(mins, AtomicSMinOp);
    ATOMIC_CASE(minu, AtomicUMinOp);
    ATOMIC_CASE(ori, AtomicOrOp);
    ATOMIC_CASE(andi, AtomicAndOp);
  default:
    return rewriter.notifyMatchFailure(atomicOp, "unimplemented atomic kind");
  }

#undef ATOMIC_CASE

  return success();
}

//===----------------------------------------------------------------------===//
// DeallocOp
//===----------------------------------------------------------------------===//

LogicalResult
DeallocOpPattern::matchAndRewrite(memref::DeallocOp operation,
                                  OpAdaptor adaptor,
                                  ConversionPatternRewriter &rewriter) const {
  MemRefType deallocType = cast<MemRefType>(operation.getMemref().getType());
  if (!isAllocationSupported(operation, deallocType))
    return rewriter.notifyMatchFailure(operation, "unhandled allocation type");
  rewriter.eraseOp(operation);
  return success();
}

//===----------------------------------------------------------------------===//
// LoadOp
//===----------------------------------------------------------------------===//

struct MemoryRequirements {
  spirv::MemoryAccessAttr memoryAccess;
  IntegerAttr alignment;
};

/// Given an accessed SPIR-V pointer, calculates its alignment requirements, if
/// any.
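///
/// For example, a pointer in the PhysicalStorageBuffer storage class pointing
/// at an f32 yields MemoryAccess::Aligned (combined with Nontemporal when
/// requested) and an alignment of 4 bytes; pointers in other storage classes
/// get no alignment attribute.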
static FailureOr<MemoryRequirements>
calculateMemoryRequirements(Value accessedPtr, bool isNontemporal) {
  MLIRContext *ctx = accessedPtr.getContext();

  auto memoryAccess = spirv::MemoryAccess::None;
  if (isNontemporal) {
    memoryAccess = spirv::MemoryAccess::Nontemporal;
  }

  auto ptrType = cast<spirv::PointerType>(accessedPtr.getType());
  if (ptrType.getStorageClass() != spirv::StorageClass::PhysicalStorageBuffer) {
    if (memoryAccess == spirv::MemoryAccess::None) {
      return MemoryRequirements{spirv::MemoryAccessAttr{}, IntegerAttr{}};
    }
    return MemoryRequirements{spirv::MemoryAccessAttr::get(ctx, memoryAccess),
                              IntegerAttr{}};
  }

  // PhysicalStorageBuffers require the `Aligned` attribute.
  auto pointeeType = dyn_cast<spirv::ScalarType>(ptrType.getPointeeType());
  if (!pointeeType)
    return failure();

  // For scalar types, the alignment is determined by their size.
  std::optional<int64_t> sizeInBytes = pointeeType.getSizeInBytes();
  if (!sizeInBytes.has_value())
    return failure();

  memoryAccess = memoryAccess | spirv::MemoryAccess::Aligned;
  auto memAccessAttr = spirv::MemoryAccessAttr::get(ctx, memoryAccess);
  auto alignment = IntegerAttr::get(IntegerType::get(ctx, 32), *sizeInBytes);
  return MemoryRequirements{memAccessAttr, alignment};
}

/// Given an accessed SPIR-V pointer and the original memref load/store
/// `memAccess` op, calculates the alignment requirements, if any. Takes into
/// account the alignment attributes applied to the load/store op.
template <class LoadOrStoreOp>
static FailureOr<MemoryRequirements>
calculateMemoryRequirements(Value accessedPtr, LoadOrStoreOp loadOrStoreOp) {
  static_assert(
      llvm::is_one_of<LoadOrStoreOp, memref::LoadOp, memref::StoreOp>::value,
      "Must be called on either memref::LoadOp or memref::StoreOp");

  Operation *memrefAccessOp = loadOrStoreOp.getOperation();
  auto memrefMemAccess = memrefAccessOp->getAttrOfType<spirv::MemoryAccessAttr>(
      spirv::attributeName<spirv::MemoryAccess>());
  auto memrefAlignment =
      memrefAccessOp->getAttrOfType<IntegerAttr>("alignment");
  if (memrefMemAccess && memrefAlignment)
    return MemoryRequirements{memrefMemAccess, memrefAlignment};

  return calculateMemoryRequirements(accessedPtr,
                                     loadOrStoreOp.getNontemporal());
}

LogicalResult
IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
                                  ConversionPatternRewriter &rewriter) const {
  auto loc = loadOp.getLoc();
  auto memrefType = cast<MemRefType>(loadOp.getMemref().getType());
  if (!memrefType.getElementType().isSignlessInteger())
    return failure();

  const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
  Value accessChain =
      spirv::getElementPtr(typeConverter, memrefType, adaptor.getMemref(),
                           adaptor.getIndices(), loc, rewriter);

  if (!accessChain)
    return failure();

  int srcBits = memrefType.getElementType().getIntOrFloatBitWidth();
  bool isBool = srcBits == 1;
  if (isBool)
    srcBits = typeConverter.getOptions().boolNumBits;

  auto pointerType = typeConverter.convertType<spirv::PointerType>(memrefType);
  if (!pointerType)
    return rewriter.notifyMatchFailure(loadOp, "failed to convert memref type");

  Type pointeeType = pointerType.getPointeeType();
  Type dstType;
  if (typeConverter.allows(spirv::Capability::Kernel)) {
    if (auto arrayType = dyn_cast<spirv::ArrayType>(pointeeType))
      dstType = arrayType.getElementType();
    else
      dstType = pointeeType;
  } else {
    // For Vulkan we need to extract element from wrapping struct and array.
    Type structElemType =
        cast<spirv::StructType>(pointeeType).getElementType(0);
    if (auto arrayType = dyn_cast<spirv::ArrayType>(structElemType))
      dstType = arrayType.getElementType();
    else
      dstType = cast<spirv::RuntimeArrayType>(structElemType).getElementType();
  }
  int dstBits = dstType.getIntOrFloatBitWidth();
  assert(dstBits % srcBits == 0);

  // If the rewritten load op has the same bit width, use the loaded value
  // directly.
  if (srcBits == dstBits) {
    auto memoryRequirements = calculateMemoryRequirements(accessChain, loadOp);
    if (failed(memoryRequirements))
      return rewriter.notifyMatchFailure(
          loadOp, "failed to determine memory requirements");

    auto [memoryAccess, alignment] = *memoryRequirements;
    Value loadVal = rewriter.create<spirv::LoadOp>(loc, accessChain,
                                                   memoryAccess, alignment);
    if (isBool)
      loadVal = castIntNToBool(loc, loadVal, rewriter);
    rewriter.replaceOp(loadOp, loadVal);
    return success();
  }

  // Bitcasting is currently unsupported for Kernel capability /
  // spirv.PtrAccessChain.
  if (typeConverter.allows(spirv::Capability::Kernel))
    return failure();

  auto accessChainOp = accessChain.getDefiningOp<spirv::AccessChainOp>();
  if (!accessChainOp)
    return failure();

  // Assume that getElementPtr() linearizes the access. Even for a scalar, the
  // method still returns a linearized access. If the access were not
  // linearized, the computed offsets would be wrong.
  assert(accessChainOp.getIndices().size() == 2);
  Value adjustedPtr = adjustAccessChainForBitwidth(typeConverter, accessChainOp,
                                                   srcBits, dstBits, rewriter);
  auto memoryRequirements = calculateMemoryRequirements(adjustedPtr, loadOp);
  if (failed(memoryRequirements))
    return rewriter.notifyMatchFailure(
        loadOp, "failed to determine memory requirements");

  auto [memoryAccess, alignment] = *memoryRequirements;
  Value spvLoadOp = rewriter.create<spirv::LoadOp>(loc, dstType, adjustedPtr,
                                                   memoryAccess, alignment);

  // Shift the bits to the rightmost.
  // ____XXXX________ -> ____________XXXX
  Value lastDim = accessChainOp->getOperand(accessChainOp.getNumOperands() - 1);
  Value offset = getOffsetForBitwidth(loc, lastDim, srcBits, dstBits, rewriter);
  Value result = rewriter.createOrFold<spirv::ShiftRightArithmeticOp>(
      loc, spvLoadOp.getType(), spvLoadOp, offset);

  // Apply the mask to extract corresponding bits.
  Value mask = rewriter.createOrFold<spirv::ConstantOp>(
      loc, dstType, rewriter.getIntegerAttr(dstType, (1 << srcBits) - 1));
  result =
      rewriter.createOrFold<spirv::BitwiseAndOp>(loc, dstType, result, mask);

  // Sign-extend the loaded value unconditionally. The signedness semantics are
  // carried by the consuming operations; we rely on other patterns to handle
  // the casting.
  IntegerAttr shiftValueAttr =
      rewriter.getIntegerAttr(dstType, dstBits - srcBits);
  Value shiftValue =
      rewriter.createOrFold<spirv::ConstantOp>(loc, dstType, shiftValueAttr);
  result = rewriter.createOrFold<spirv::ShiftLeftLogicalOp>(loc, dstType,
                                                            result, shiftValue);
  result = rewriter.createOrFold<spirv::ShiftRightArithmeticOp>(
      loc, dstType, result, shiftValue);

  rewriter.replaceOp(loadOp, result);

  assert(accessChainOp.use_empty());
  rewriter.eraseOp(accessChainOp);

  return success();
}

LogicalResult
LoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
                               ConversionPatternRewriter &rewriter) const {
  auto memrefType = cast<MemRefType>(loadOp.getMemref().getType());
  if (memrefType.getElementType().isSignlessInteger())
    return failure();
  Value loadPtr = spirv::getElementPtr(
      *getTypeConverter<SPIRVTypeConverter>(), memrefType, adaptor.getMemref(),
      adaptor.getIndices(), loadOp.getLoc(), rewriter);

  if (!loadPtr)
    return failure();

  auto memoryRequirements = calculateMemoryRequirements(loadPtr, loadOp);
  if (failed(memoryRequirements))
    return rewriter.notifyMatchFailure(
        loadOp, "failed to determine memory requirements");

  auto [memoryAccess, alignment] = *memoryRequirements;
  rewriter.replaceOpWithNewOp<spirv::LoadOp>(loadOp, loadPtr, memoryAccess,
                                             alignment);
  return success();
}

LogicalResult
IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
                                   ConversionPatternRewriter &rewriter) const {
  auto memrefType = cast<MemRefType>(storeOp.getMemref().getType());
  if (!memrefType.getElementType().isSignlessInteger())
    return rewriter.notifyMatchFailure(storeOp,
                                       "element type is not a signless int");

  auto loc = storeOp.getLoc();
  auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
  Value accessChain =
      spirv::getElementPtr(typeConverter, memrefType, adaptor.getMemref(),
                           adaptor.getIndices(), loc, rewriter);

  if (!accessChain)
    return rewriter.notifyMatchFailure(
        storeOp, "failed to convert element pointer type");

  int srcBits = memrefType.getElementType().getIntOrFloatBitWidth();

  bool isBool = srcBits == 1;
  if (isBool)
    srcBits = typeConverter.getOptions().boolNumBits;

  auto pointerType = typeConverter.convertType<spirv::PointerType>(memrefType);
  if (!pointerType)
    return rewriter.notifyMatchFailure(storeOp,
                                       "failed to convert memref type");

  Type pointeeType = pointerType.getPointeeType();
  IntegerType dstType;
  if (typeConverter.allows(spirv::Capability::Kernel)) {
    if (auto arrayType = dyn_cast<spirv::ArrayType>(pointeeType))
      dstType = dyn_cast<IntegerType>(arrayType.getElementType());
    else
      dstType = dyn_cast<IntegerType>(pointeeType);
  } else {
    // For Vulkan we need to extract element from wrapping struct and array.
    Type structElemType =
        cast<spirv::StructType>(pointeeType).getElementType(0);
    if (auto arrayType = dyn_cast<spirv::ArrayType>(structElemType))
      dstType = dyn_cast<IntegerType>(arrayType.getElementType());
    else
      dstType = dyn_cast<IntegerType>(
          cast<spirv::RuntimeArrayType>(structElemType).getElementType());
  }

  if (!dstType)
    return rewriter.notifyMatchFailure(
        storeOp, "failed to determine destination element type");

  int dstBits = static_cast<int>(dstType.getWidth());
  assert(dstBits % srcBits == 0);

  if (srcBits == dstBits) {
    auto memoryRequirements = calculateMemoryRequirements(accessChain, storeOp);
    if (failed(memoryRequirements))
      return rewriter.notifyMatchFailure(
          storeOp, "failed to determine memory requirements");

    auto [memoryAccess, alignment] = *memoryRequirements;
    Value storeVal = adaptor.getValue();
    if (isBool)
      storeVal = castBoolToIntN(loc, storeVal, dstType, rewriter);
    rewriter.replaceOpWithNewOp<spirv::StoreOp>(storeOp, accessChain, storeVal,
                                                memoryAccess, alignment);
    return success();
  }

  // Bitcasting is currently unsupported for Kernel capability /
  // spirv.PtrAccessChain.
  if (typeConverter.allows(spirv::Capability::Kernel))
    return failure();

  auto accessChainOp = accessChain.getDefiningOp<spirv::AccessChainOp>();
  if (!accessChainOp)
    return failure();

  // Since multiple threads may write to the same wider word concurrently, the
  // emulation is done with atomic operations. E.g., if the stored value is i8,
  // rewrite the StoreOp to:
  // 1) load a 32-bit integer
  // 2) clear 8 bits in the loaded value
  // 3) set 8 bits in the loaded value
  // 4) store the 32-bit value back
  //
  // Step 2 is done with AtomicAnd, and step 3 is done with AtomicOr (of the
  // loaded 32-bit value and the shifted 8-bit store value) as another atomic
  // step.
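  //
  // A sketch of the emitted sequence for storing an i8 into a 32-bit word
  // (illustrative SSA names; scope and semantics as chosen below):
  //   %offset = ...                                    // (idx % 4) * 8
  //   %clear  = spirv.Not %shiftedMask : i32           // e.g. 0xFFFF00FF
  //   %val    = spirv.ShiftLeftLogical %byte, %offset  // value into position
  //   %0 = spirv.AtomicAnd <scope> <AcquireRelease> %ptr, %clear
  //   %1 = spirv.AtomicOr  <scope> <AcquireRelease> %ptr, %val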
  assert(accessChainOp.getIndices().size() == 2);
  Value lastDim = accessChainOp->getOperand(accessChainOp.getNumOperands() - 1);
  Value offset = getOffsetForBitwidth(loc, lastDim, srcBits, dstBits, rewriter);

  // Create a mask to clear the destination. E.g., if it is the second i8 in
  // i32, 0xFFFF00FF is created.
  Value mask = rewriter.createOrFold<spirv::ConstantOp>(
      loc, dstType, rewriter.getIntegerAttr(dstType, (1 << srcBits) - 1));
  Value clearBitsMask = rewriter.createOrFold<spirv::ShiftLeftLogicalOp>(
      loc, dstType, mask, offset);
  clearBitsMask =
      rewriter.createOrFold<spirv::NotOp>(loc, dstType, clearBitsMask);

  Value storeVal = shiftValue(loc, adaptor.getValue(), offset, mask, rewriter);
  Value adjustedPtr = adjustAccessChainForBitwidth(typeConverter, accessChainOp,
                                                   srcBits, dstBits, rewriter);
  std::optional<spirv::Scope> scope = getAtomicOpScope(memrefType);
  if (!scope)
    return rewriter.notifyMatchFailure(storeOp, "atomic scope not available");

  // The results of the atomic ops (the old values) are unused.
  rewriter.create<spirv::AtomicAndOp>(loc, dstType, adjustedPtr, *scope,
                                      spirv::MemorySemantics::AcquireRelease,
                                      clearBitsMask);
  rewriter.create<spirv::AtomicOrOp>(loc, dstType, adjustedPtr, *scope,
                                     spirv::MemorySemantics::AcquireRelease,
                                     storeVal);

  // The atomic ops above already perform the store. The original StoreOp has
  // no results, so rewriter.replaceOp() (which requires matching result
  // counts) cannot be used; just erase it.
  rewriter.eraseOp(storeOp);

  assert(accessChainOp.use_empty());
  rewriter.eraseOp(accessChainOp);

  return success();
}

//===----------------------------------------------------------------------===//
// MemorySpaceCastOp
//===----------------------------------------------------------------------===//

LogicalResult MemorySpaceCastOpPattern::matchAndRewrite(
    memref::MemorySpaceCastOp addrCastOp, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Location loc = addrCastOp.getLoc();
  auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
  if (!typeConverter.allows(spirv::Capability::Kernel))
    return rewriter.notifyMatchFailure(
        loc, "address space casts require kernel capability");

  auto sourceType = dyn_cast<MemRefType>(addrCastOp.getSource().getType());
  if (!sourceType)
    return rewriter.notifyMatchFailure(
        loc, "SPIR-V lowering requires ranked memref types");
  auto resultType = cast<MemRefType>(addrCastOp.getResult().getType());

  auto sourceStorageClassAttr =
      dyn_cast_or_null<spirv::StorageClassAttr>(sourceType.getMemorySpace());
  if (!sourceStorageClassAttr)
    return rewriter.notifyMatchFailure(loc, [sourceType](Diagnostic &diag) {
      diag << "source address space " << sourceType.getMemorySpace()
           << " must be a SPIR-V storage class";
    });
  auto resultStorageClassAttr =
      dyn_cast_or_null<spirv::StorageClassAttr>(resultType.getMemorySpace());
  if (!resultStorageClassAttr)
    return rewriter.notifyMatchFailure(loc, [resultType](Diagnostic &diag) {
      diag << "result address space " << resultType.getMemorySpace()
           << " must be a SPIR-V storage class";
    });

  spirv::StorageClass sourceSc = sourceStorageClassAttr.getValue();
  spirv::StorageClass resultSc = resultStorageClassAttr.getValue();

  Value result = adaptor.getSource();
  Type resultPtrType = typeConverter.convertType(resultType);
  if (!resultPtrType)
    return rewriter.notifyMatchFailure(addrCastOp,
                                       "failed to convert memref type");

  Type genericPtrType = resultPtrType;
  // SPIR-V doesn't have a general address space cast operation. Instead, it
  // has conversions to and from generic pointers. To implement the general
  // case, we use a specific-to-generic conversion when the source class is not
  // generic. Then, when the result storage class is not generic, we convert
  // the generic pointer (either the input or an intermediate result) to that
  // class. This also means that we'll need the intermediate generic pointer
  // type if neither the source nor the destination has it.
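  //
  // E.g., a cast from Workgroup to CrossWorkgroup goes through Generic
  // (a sketch; pointee types elided):
  //   %g = spirv.PtrCastToGeneric %src
  //          : !spirv.ptr<..., Workgroup> to !spirv.ptr<..., Generic>
  //   %r = spirv.GenericCastToPtr %g
  //          : !spirv.ptr<..., Generic> to !spirv.ptr<..., CrossWorkgroup>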
  if (sourceSc != spirv::StorageClass::Generic &&
      resultSc != spirv::StorageClass::Generic) {
    Type intermediateType =
        MemRefType::get(sourceType.getShape(), sourceType.getElementType(),
                        sourceType.getLayout(),
                        rewriter.getAttr<spirv::StorageClassAttr>(
                            spirv::StorageClass::Generic));
    genericPtrType = typeConverter.convertType(intermediateType);
  }
  if (sourceSc != spirv::StorageClass::Generic) {
    result =
        rewriter.create<spirv::PtrCastToGenericOp>(loc, genericPtrType, result);
  }
  if (resultSc != spirv::StorageClass::Generic) {
    result =
        rewriter.create<spirv::GenericCastToPtrOp>(loc, resultPtrType, result);
  }
  rewriter.replaceOp(addrCastOp, result);
  return success();
}

LogicalResult
StoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
                                ConversionPatternRewriter &rewriter) const {
  auto memrefType = cast<MemRefType>(storeOp.getMemref().getType());
  if (memrefType.getElementType().isSignlessInteger())
    return rewriter.notifyMatchFailure(storeOp, "signless int");
  auto storePtr = spirv::getElementPtr(
      *getTypeConverter<SPIRVTypeConverter>(), memrefType, adaptor.getMemref(),
      adaptor.getIndices(), storeOp.getLoc(), rewriter);

  if (!storePtr)
    return rewriter.notifyMatchFailure(storeOp, "type conversion failed");

  auto memoryRequirements = calculateMemoryRequirements(storePtr, storeOp);
  if (failed(memoryRequirements))
    return rewriter.notifyMatchFailure(
        storeOp, "failed to determine memory requirements");

  auto [memoryAccess, alignment] = *memoryRequirements;
  rewriter.replaceOpWithNewOp<spirv::StoreOp>(
      storeOp, storePtr, adaptor.getValue(), memoryAccess, alignment);
  return success();
}

LogicalResult ReinterpretCastPattern::matchAndRewrite(
    memref::ReinterpretCastOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Value src = adaptor.getSource();
  auto srcType = dyn_cast<spirv::PointerType>(src.getType());

  if (!srcType)
    return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
      diag << "invalid src type " << src.getType();
    });

  const TypeConverter *converter = getTypeConverter();

  auto dstType = converter->convertType<spirv::PointerType>(op.getType());
  if (dstType != srcType)
    return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
      diag << "invalid dst type " << op.getType();
    });

  OpFoldResult offset =
      getMixedValues(adaptor.getStaticOffsets(), adaptor.getOffsets(), rewriter)
          .front();
  if (isZeroInteger(offset)) {
    rewriter.replaceOp(op, src);
    return success();
  }

  Type intType = converter->convertType(rewriter.getIndexType());
  if (!intType)
    return rewriter.notifyMatchFailure(op, "failed to convert index type");

  Location loc = op.getLoc();
  auto offsetValue = [&]() -> Value {
    if (auto val = dyn_cast<Value>(offset))
      return val;

    int64_t attrVal = cast<IntegerAttr>(cast<Attribute>(offset)).getInt();
    Attribute attr = rewriter.getIntegerAttr(intType, attrVal);
    return rewriter.createOrFold<spirv::ConstantOp>(loc, intType, attr);
  }();

  rewriter.replaceOpWithNewOp<spirv::InBoundsPtrAccessChainOp>(
      op, src, offsetValue, std::nullopt);
  return success();
}

//===----------------------------------------------------------------------===//
// ExtractAlignedPointerAsIndexOp
//===----------------------------------------------------------------------===//

LogicalResult ExtractAlignedPointerAsIndexOpPattern::matchAndRewrite(
    memref::ExtractAlignedPointerAsIndexOp extractOp, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
  Type indexType = typeConverter.getIndexType();
  rewriter.replaceOpWithNewOp<spirv::ConvertPtrToUOp>(extractOp, indexType,
                                                      adaptor.getSource());
  return success();
}

//===----------------------------------------------------------------------===//
// Pattern population
//===----------------------------------------------------------------------===//

namespace mlir {
void populateMemRefToSPIRVPatterns(const SPIRVTypeConverter &typeConverter,
                                   RewritePatternSet &patterns) {
  patterns
      .add<AllocaOpPattern, AllocOpPattern, AtomicRMWOpPattern,
           DeallocOpPattern, IntLoadOpPattern, IntStoreOpPattern, LoadOpPattern,
           MemorySpaceCastOpPattern, StoreOpPattern, ReinterpretCastPattern,
           CastPattern, ExtractAlignedPointerAsIndexOpPattern>(
          typeConverter, patterns.getContext());
}
} // namespace mlir