1//===- MemRefToSPIRV.cpp - MemRef to SPIR-V Patterns ----------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements patterns to convert MemRef dialect to SPIR-V dialect.
10//
11//===----------------------------------------------------------------------===//
12
13#include "mlir/Dialect/Arith/IR/Arith.h"
14#include "mlir/Dialect/MemRef/IR/MemRef.h"
15#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
16#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
17#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
18#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
19#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
20#include "mlir/IR/BuiltinAttributes.h"
21#include "mlir/IR/BuiltinTypes.h"
22#include "mlir/IR/MLIRContext.h"
23#include "mlir/IR/Visitors.h"
24#include <cassert>
25#include <optional>
26
27#define DEBUG_TYPE "memref-to-spirv-pattern"
28
29using namespace mlir;
30
31//===----------------------------------------------------------------------===//
32// Utility functions
33//===----------------------------------------------------------------------===//
34
35/// Returns the offset of the value in `targetBits` representation.
36///
37/// `srcIdx` is an index into a 1-D array with each element having `sourceBits`.
38/// It's assumed to be non-negative.
39///
40/// When accessing an element in the array treating as having elements of
41/// `targetBits`, multiple values are loaded in the same time. The method
42/// returns the offset where the `srcIdx` locates in the value. For example, if
43/// `sourceBits` equals to 8 and `targetBits` equals to 32, the x-th element is
44/// located at (x % 4) * 8. Because there are four elements in one i32, and one
45/// element has 8 bits.
46static Value getOffsetForBitwidth(Location loc, Value srcIdx, int sourceBits,
47 int targetBits, OpBuilder &builder) {
48 assert(targetBits % sourceBits == 0);
49 Type type = srcIdx.getType();
50 IntegerAttr idxAttr = builder.getIntegerAttr(type, value: targetBits / sourceBits);
51 auto idx = builder.createOrFold<spirv::ConstantOp>(location: loc, args&: type, args&: idxAttr);
52 IntegerAttr srcBitsAttr = builder.getIntegerAttr(type, value: sourceBits);
53 auto srcBitsValue =
54 builder.createOrFold<spirv::ConstantOp>(location: loc, args&: type, args&: srcBitsAttr);
55 auto m = builder.createOrFold<spirv::UModOp>(location: loc, args&: srcIdx, args&: idx);
56 return builder.createOrFold<spirv::IMulOp>(location: loc, args&: type, args&: m, args&: srcBitsValue);
57}
58
59/// Returns an adjusted spirv::AccessChainOp. Based on the
60/// extension/capabilities, certain integer bitwidths `sourceBits` might not be
61/// supported. During conversion if a memref of an unsupported type is used,
62/// load/stores to this memref need to be modified to use a supported higher
63/// bitwidth `targetBits` and extracting the required bits. For an accessing a
64/// 1D array (spirv.array or spirv.rtarray), the last index is modified to load
65/// the bits needed. The extraction of the actual bits needed are handled
66/// separately. Note that this only works for a 1-D tensor.
67static Value
68adjustAccessChainForBitwidth(const SPIRVTypeConverter &typeConverter,
69 spirv::AccessChainOp op, int sourceBits,
70 int targetBits, OpBuilder &builder) {
71 assert(targetBits % sourceBits == 0);
72 const auto loc = op.getLoc();
73 Value lastDim = op->getOperand(idx: op.getNumOperands() - 1);
74 Type type = lastDim.getType();
75 IntegerAttr attr = builder.getIntegerAttr(type, value: targetBits / sourceBits);
76 auto idx = builder.createOrFold<spirv::ConstantOp>(location: loc, args&: type, args&: attr);
77 auto indices = llvm::to_vector<4>(Range: op.getIndices());
78 // There are two elements if this is a 1-D tensor.
79 assert(indices.size() == 2);
80 indices.back() = builder.createOrFold<spirv::SDivOp>(location: loc, args&: lastDim, args&: idx);
81 Type t = typeConverter.convertType(t: op.getComponentPtr().getType());
82 return builder.create<spirv::AccessChainOp>(location: loc, args&: t, args: op.getBasePtr(), args&: indices);
83}
84
85/// Casts the given `srcBool` into an integer of `dstType`.
86static Value castBoolToIntN(Location loc, Value srcBool, Type dstType,
87 OpBuilder &builder) {
88 assert(srcBool.getType().isInteger(1));
89 if (dstType.isInteger(width: 1))
90 return srcBool;
91 Value zero = spirv::ConstantOp::getZero(type: dstType, loc, builder);
92 Value one = spirv::ConstantOp::getOne(type: dstType, loc, builder);
93 return builder.createOrFold<spirv::SelectOp>(location: loc, args&: dstType, args&: srcBool, args&: one,
94 args&: zero);
95}
96
97/// Returns the `targetBits`-bit value shifted by the given `offset`, and cast
98/// to the type destination type, and masked.
99static Value shiftValue(Location loc, Value value, Value offset, Value mask,
100 OpBuilder &builder) {
101 IntegerType dstType = cast<IntegerType>(Val: mask.getType());
102 int targetBits = static_cast<int>(dstType.getWidth());
103 int valueBits = value.getType().getIntOrFloatBitWidth();
104 assert(valueBits <= targetBits);
105
106 if (valueBits == 1) {
107 value = castBoolToIntN(loc, srcBool: value, dstType, builder);
108 } else {
109 if (valueBits < targetBits) {
110 value = builder.create<spirv::UConvertOp>(
111 location: loc, args: builder.getIntegerType(width: targetBits), args&: value);
112 }
113
114 value = builder.createOrFold<spirv::BitwiseAndOp>(location: loc, args&: value, args&: mask);
115 }
116 return builder.createOrFold<spirv::ShiftLeftLogicalOp>(location: loc, args: value.getType(),
117 args&: value, args&: offset);
118}
119
120/// Returns true if the allocations of memref `type` generated from `allocOp`
121/// can be lowered to SPIR-V.
122static bool isAllocationSupported(Operation *allocOp, MemRefType type) {
123 if (isa<memref::AllocOp, memref::DeallocOp>(Val: allocOp)) {
124 auto sc = dyn_cast_or_null<spirv::StorageClassAttr>(Val: type.getMemorySpace());
125 if (!sc || sc.getValue() != spirv::StorageClass::Workgroup)
126 return false;
127 } else if (isa<memref::AllocaOp>(Val: allocOp)) {
128 auto sc = dyn_cast_or_null<spirv::StorageClassAttr>(Val: type.getMemorySpace());
129 if (!sc || sc.getValue() != spirv::StorageClass::Function)
130 return false;
131 } else {
132 return false;
133 }
134
135 // Currently only support static shape and int or float or vector of int or
136 // float element type.
137 if (!type.hasStaticShape())
138 return false;
139
140 Type elementType = type.getElementType();
141 if (auto vecType = dyn_cast<VectorType>(Val&: elementType))
142 elementType = vecType.getElementType();
143 return elementType.isIntOrFloat();
144}
145
146/// Returns the scope to use for atomic operations use for emulating store
147/// operations of unsupported integer bitwidths, based on the memref
148/// type. Returns std::nullopt on failure.
149static std::optional<spirv::Scope> getAtomicOpScope(MemRefType type) {
150 auto sc = dyn_cast_or_null<spirv::StorageClassAttr>(Val: type.getMemorySpace());
151 switch (sc.getValue()) {
152 case spirv::StorageClass::StorageBuffer:
153 return spirv::Scope::Device;
154 case spirv::StorageClass::Workgroup:
155 return spirv::Scope::Workgroup;
156 default:
157 break;
158 }
159 return {};
160}
161
162/// Casts the given `srcInt` into a boolean value.
163static Value castIntNToBool(Location loc, Value srcInt, OpBuilder &builder) {
164 if (srcInt.getType().isInteger(width: 1))
165 return srcInt;
166
167 auto one = spirv::ConstantOp::getZero(type: srcInt.getType(), loc, builder);
168 return builder.createOrFold<spirv::INotEqualOp>(location: loc, args&: srcInt, args&: one);
169}
170
171//===----------------------------------------------------------------------===//
172// Operation conversion
173//===----------------------------------------------------------------------===//
174
175// Note that DRR cannot be used for the patterns in this file: we may need to
176// convert type along the way, which requires ConversionPattern. DRR generates
177// normal RewritePattern.
178
179namespace {
180
/// Converts memref.alloca to SPIR-V Function variables (spirv.Variable with
/// Function storage class). Only allocations accepted by
/// isAllocationSupported() are handled.
class AllocaOpPattern final : public OpConversionPattern<memref::AllocaOp> {
public:
  using OpConversionPattern<memref::AllocaOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::AllocaOp allocaOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
190
/// Converts an allocation operation to SPIR-V. Currently only supports lowering
/// to Workgroup memory when the size is constant. Note that this pattern needs
/// to be applied in a pass that runs at least at spirv.module scope since it
/// will add global variables into the spirv.module.
class AllocOpPattern final : public OpConversionPattern<memref::AllocOp> {
public:
  using OpConversionPattern<memref::AllocOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::AllocOp operation, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
203
/// Converts memref.atomic_rmw operations to SPIR-V atomic operations.
class AtomicRMWOpPattern final
    : public OpConversionPattern<memref::AtomicRMWOp> {
public:
  using OpConversionPattern<memref::AtomicRMWOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::AtomicRMWOp atomicOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
214
/// Removes a deallocation if it is a supported allocation. Currently only
/// removes deallocation if the memory space is workgroup memory.
class DeallocOpPattern final : public OpConversionPattern<memref::DeallocOp> {
public:
  using OpConversionPattern<memref::DeallocOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::DeallocOp operation, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
225
/// Converts memref.load to spirv.Load + spirv.AccessChain on signless
/// integers, emulating unsupported source bitwidths with a wider load plus
/// shift/mask extraction.
class IntLoadOpPattern final : public OpConversionPattern<memref::LoadOp> {
public:
  using OpConversionPattern<memref::LoadOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
235
/// Converts memref.load to spirv.Load + spirv.AccessChain for non-integer
/// (e.g. floating-point) element types; integer loads are handled by
/// IntLoadOpPattern.
class LoadOpPattern final : public OpConversionPattern<memref::LoadOp> {
public:
  using OpConversionPattern<memref::LoadOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
245
/// Converts memref.store to spirv.Store on signless integers, emulating
/// unsupported source bitwidths with atomic read-modify-write of a wider
/// word.
class IntStoreOpPattern final : public OpConversionPattern<memref::StoreOp> {
public:
  using OpConversionPattern<memref::StoreOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
255
/// Converts memref.memory_space_cast to the appropriate spirv cast operations.
class MemorySpaceCastOpPattern final
    : public OpConversionPattern<memref::MemorySpaceCastOp> {
public:
  using OpConversionPattern<memref::MemorySpaceCastOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::MemorySpaceCastOp addrCastOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
266
/// Converts memref.store to spirv.Store for non-integer element types;
/// integer stores are handled by IntStoreOpPattern.
class StoreOpPattern final : public OpConversionPattern<memref::StoreOp> {
public:
  using OpConversionPattern<memref::StoreOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
276
/// Converts memref.reinterpret_cast to SPIR-V. The conversion logic lives in
/// the out-of-line matchAndRewrite definition.
class ReinterpretCastPattern final
    : public OpConversionPattern<memref::ReinterpretCastOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::ReinterpretCastOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
286
287class CastPattern final : public OpConversionPattern<memref::CastOp> {
288public:
289 using OpConversionPattern::OpConversionPattern;
290
291 LogicalResult
292 matchAndRewrite(memref::CastOp op, OpAdaptor adaptor,
293 ConversionPatternRewriter &rewriter) const override {
294 Value src = adaptor.getSource();
295 Type srcType = src.getType();
296
297 const TypeConverter *converter = getTypeConverter();
298 Type dstType = converter->convertType(t: op.getType());
299 if (srcType != dstType)
300 return rewriter.notifyMatchFailure(op, reasonCallback: [&](Diagnostic &diag) {
301 diag << "types doesn't match: " << srcType << " and " << dstType;
302 });
303
304 rewriter.replaceOp(op, newValues: src);
305 return success();
306 }
307};
308
/// Converts memref.extract_aligned_pointer_as_index to spirv.ConvertPtrToU.
class ExtractAlignedPointerAsIndexOpPattern final
    : public OpConversionPattern<memref::ExtractAlignedPointerAsIndexOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::ExtractAlignedPointerAsIndexOp extractOp,
                  OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
320} // namespace
321
322//===----------------------------------------------------------------------===//
323// AllocaOp
324//===----------------------------------------------------------------------===//
325
326LogicalResult
327AllocaOpPattern::matchAndRewrite(memref::AllocaOp allocaOp, OpAdaptor adaptor,
328 ConversionPatternRewriter &rewriter) const {
329 MemRefType allocType = allocaOp.getType();
330 if (!isAllocationSupported(allocOp: allocaOp, type: allocType))
331 return rewriter.notifyMatchFailure(arg&: allocaOp, msg: "unhandled allocation type");
332
333 // Get the SPIR-V type for the allocation.
334 Type spirvType = getTypeConverter()->convertType(t: allocType);
335 if (!spirvType)
336 return rewriter.notifyMatchFailure(arg&: allocaOp, msg: "type conversion failed");
337
338 rewriter.replaceOpWithNewOp<spirv::VariableOp>(op: allocaOp, args&: spirvType,
339 args: spirv::StorageClass::Function,
340 /*initializer=*/args: nullptr);
341 return success();
342}
343
344//===----------------------------------------------------------------------===//
345// AllocOp
346//===----------------------------------------------------------------------===//
347
348LogicalResult
349AllocOpPattern::matchAndRewrite(memref::AllocOp operation, OpAdaptor adaptor,
350 ConversionPatternRewriter &rewriter) const {
351 MemRefType allocType = operation.getType();
352 if (!isAllocationSupported(allocOp: operation, type: allocType))
353 return rewriter.notifyMatchFailure(arg&: operation, msg: "unhandled allocation type");
354
355 // Get the SPIR-V type for the allocation.
356 Type spirvType = getTypeConverter()->convertType(t: allocType);
357 if (!spirvType)
358 return rewriter.notifyMatchFailure(arg&: operation, msg: "type conversion failed");
359
360 // Insert spirv.GlobalVariable for this allocation.
361 Operation *parent =
362 SymbolTable::getNearestSymbolTable(from: operation->getParentOp());
363 if (!parent)
364 return failure();
365 Location loc = operation.getLoc();
366 spirv::GlobalVariableOp varOp;
367 {
368 OpBuilder::InsertionGuard guard(rewriter);
369 Block &entryBlock = *parent->getRegion(index: 0).begin();
370 rewriter.setInsertionPointToStart(&entryBlock);
371 auto varOps = entryBlock.getOps<spirv::GlobalVariableOp>();
372 std::string varName =
373 std::string("__workgroup_mem__") +
374 std::to_string(val: std::distance(first: varOps.begin(), last: varOps.end()));
375 varOp = rewriter.create<spirv::GlobalVariableOp>(location: loc, args&: spirvType, args&: varName,
376 /*initializer=*/args: nullptr);
377 }
378
379 // Get pointer to global variable at the current scope.
380 rewriter.replaceOpWithNewOp<spirv::AddressOfOp>(op: operation, args&: varOp);
381 return success();
382}
383
384//===----------------------------------------------------------------------===//
// AtomicRMWOp
386//===----------------------------------------------------------------------===//
387
388LogicalResult
389AtomicRMWOpPattern::matchAndRewrite(memref::AtomicRMWOp atomicOp,
390 OpAdaptor adaptor,
391 ConversionPatternRewriter &rewriter) const {
392 if (isa<FloatType>(Val: atomicOp.getType()))
393 return rewriter.notifyMatchFailure(arg&: atomicOp,
394 msg: "unimplemented floating-point case");
395
396 auto memrefType = cast<MemRefType>(Val: atomicOp.getMemref().getType());
397 std::optional<spirv::Scope> scope = getAtomicOpScope(type: memrefType);
398 if (!scope)
399 return rewriter.notifyMatchFailure(arg&: atomicOp,
400 msg: "unsupported memref memory space");
401
402 auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
403 Type resultType = typeConverter.convertType(t: atomicOp.getType());
404 if (!resultType)
405 return rewriter.notifyMatchFailure(arg&: atomicOp,
406 msg: "failed to convert result type");
407
408 auto loc = atomicOp.getLoc();
409 Value ptr =
410 spirv::getElementPtr(typeConverter, baseType: memrefType, basePtr: adaptor.getMemref(),
411 indices: adaptor.getIndices(), loc, builder&: rewriter);
412
413 if (!ptr)
414 return failure();
415
416#define ATOMIC_CASE(kind, spirvOp) \
417 case arith::AtomicRMWKind::kind: \
418 rewriter.replaceOpWithNewOp<spirv::spirvOp>( \
419 atomicOp, resultType, ptr, *scope, \
420 spirv::MemorySemantics::AcquireRelease, adaptor.getValue()); \
421 break
422
423 switch (atomicOp.getKind()) {
424 ATOMIC_CASE(addi, AtomicIAddOp);
425 ATOMIC_CASE(maxs, AtomicSMaxOp);
426 ATOMIC_CASE(maxu, AtomicUMaxOp);
427 ATOMIC_CASE(mins, AtomicSMinOp);
428 ATOMIC_CASE(minu, AtomicUMinOp);
429 ATOMIC_CASE(ori, AtomicOrOp);
430 ATOMIC_CASE(andi, AtomicAndOp);
431 default:
432 return rewriter.notifyMatchFailure(arg&: atomicOp, msg: "unimplemented atomic kind");
433 }
434
435#undef ATOMIC_CASE
436
437 return success();
438}
439
440//===----------------------------------------------------------------------===//
441// DeallocOp
442//===----------------------------------------------------------------------===//
443
444LogicalResult
445DeallocOpPattern::matchAndRewrite(memref::DeallocOp operation,
446 OpAdaptor adaptor,
447 ConversionPatternRewriter &rewriter) const {
448 MemRefType deallocType = cast<MemRefType>(Val: operation.getMemref().getType());
449 if (!isAllocationSupported(allocOp: operation, type: deallocType))
450 return rewriter.notifyMatchFailure(arg&: operation, msg: "unhandled allocation type");
451 rewriter.eraseOp(op: operation);
452 return success();
453}
454
455//===----------------------------------------------------------------------===//
456// LoadOp
457//===----------------------------------------------------------------------===//
458
/// Bundles the memory-access decoration and the optional alignment that a
/// SPIR-V load/store must carry for a given accessed pointer. Either member
/// may be null when no decoration/alignment is required.
struct MemoryRequirements {
  spirv::MemoryAccessAttr memoryAccess;
  IntegerAttr alignment;
};
463
464/// Given an accessed SPIR-V pointer, calculates its alignment requirements, if
465/// any.
466static FailureOr<MemoryRequirements>
467calculateMemoryRequirements(Value accessedPtr, bool isNontemporal) {
468 MLIRContext *ctx = accessedPtr.getContext();
469
470 auto memoryAccess = spirv::MemoryAccess::None;
471 if (isNontemporal) {
472 memoryAccess = spirv::MemoryAccess::Nontemporal;
473 }
474
475 auto ptrType = cast<spirv::PointerType>(Val: accessedPtr.getType());
476 if (ptrType.getStorageClass() != spirv::StorageClass::PhysicalStorageBuffer) {
477 if (memoryAccess == spirv::MemoryAccess::None) {
478 return MemoryRequirements{.memoryAccess: spirv::MemoryAccessAttr{}, .alignment: IntegerAttr{}};
479 }
480 return MemoryRequirements{.memoryAccess: spirv::MemoryAccessAttr::get(context: ctx, value: memoryAccess),
481 .alignment: IntegerAttr{}};
482 }
483
484 // PhysicalStorageBuffers require the `Aligned` attribute.
485 auto pointeeType = dyn_cast<spirv::ScalarType>(Val: ptrType.getPointeeType());
486 if (!pointeeType)
487 return failure();
488
489 // For scalar types, the alignment is determined by their size.
490 std::optional<int64_t> sizeInBytes = pointeeType.getSizeInBytes();
491 if (!sizeInBytes.has_value())
492 return failure();
493
494 memoryAccess = memoryAccess | spirv::MemoryAccess::Aligned;
495 auto memAccessAttr = spirv::MemoryAccessAttr::get(context: ctx, value: memoryAccess);
496 auto alignment = IntegerAttr::get(type: IntegerType::get(context: ctx, width: 32), value: *sizeInBytes);
497 return MemoryRequirements{.memoryAccess: memAccessAttr, .alignment: alignment};
498}
499
500/// Given an accessed SPIR-V pointer and the original memref load/store
501/// `memAccess` op, calculates the alignment requirements, if any. Takes into
502/// account the alignment attributes applied to the load/store op.
503template <class LoadOrStoreOp>
504static FailureOr<MemoryRequirements>
505calculateMemoryRequirements(Value accessedPtr, LoadOrStoreOp loadOrStoreOp) {
506 static_assert(
507 llvm::is_one_of<LoadOrStoreOp, memref::LoadOp, memref::StoreOp>::value,
508 "Must be called on either memref::LoadOp or memref::StoreOp");
509
510 Operation *memrefAccessOp = loadOrStoreOp.getOperation();
511 auto memrefMemAccess = memrefAccessOp->getAttrOfType<spirv::MemoryAccessAttr>(
512 name: spirv::attributeName<spirv::MemoryAccess>());
513 auto memrefAlignment =
514 memrefAccessOp->getAttrOfType<IntegerAttr>(name: "alignment");
515 if (memrefMemAccess && memrefAlignment)
516 return MemoryRequirements{.memoryAccess: memrefMemAccess, .alignment: memrefAlignment};
517
518 return calculateMemoryRequirements(accessedPtr,
519 loadOrStoreOp.getNontemporal());
520}
521
// Lowers memref.load of signless integers. When the source bitwidth matches
// the converted destination bitwidth, this emits a plain spirv.Load;
// otherwise it loads the wider `dstBits` word and extracts the requested
// `srcBits` via shift/mask, sign-extending the result.
LogicalResult
IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
                                  ConversionPatternRewriter &rewriter) const {
  auto loc = loadOp.getLoc();
  auto memrefType = cast<MemRefType>(Val: loadOp.getMemref().getType());
  // Non-integer element types are handled by LoadOpPattern instead.
  if (!memrefType.getElementType().isSignlessInteger())
    return failure();

  const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
  Value accessChain =
      spirv::getElementPtr(typeConverter, baseType: memrefType, basePtr: adaptor.getMemref(),
                           indices: adaptor.getIndices(), loc, builder&: rewriter);

  if (!accessChain)
    return failure();

  int srcBits = memrefType.getElementType().getIntOrFloatBitWidth();
  // i1 is emulated as a boolNumBits-wide integer (see the type converter
  // options); remember that a bool cast is needed after loading.
  bool isBool = srcBits == 1;
  if (isBool)
    srcBits = typeConverter.getOptions().boolNumBits;

  auto pointerType = typeConverter.convertType<spirv::PointerType>(t: memrefType);
  if (!pointerType)
    return rewriter.notifyMatchFailure(arg&: loadOp, msg: "failed to convert memref type");

  // Determine the element type actually stored in the converted pointee:
  // Kernel-flavored conversion uses a bare (array of) scalar, Vulkan wraps it
  // in a struct containing an array or runtime array.
  Type pointeeType = pointerType.getPointeeType();
  Type dstType;
  if (typeConverter.allows(capability: spirv::Capability::Kernel)) {
    if (auto arrayType = dyn_cast<spirv::ArrayType>(Val&: pointeeType))
      dstType = arrayType.getElementType();
    else
      dstType = pointeeType;
  } else {
    // For Vulkan we need to extract element from wrapping struct and array.
    Type structElemType =
        cast<spirv::StructType>(Val&: pointeeType).getElementType(0);
    if (auto arrayType = dyn_cast<spirv::ArrayType>(Val&: structElemType))
      dstType = arrayType.getElementType();
    else
      dstType = cast<spirv::RuntimeArrayType>(Val&: structElemType).getElementType();
  }
  int dstBits = dstType.getIntOrFloatBitWidth();
  assert(dstBits % srcBits == 0);

  // If the rewritten load op has the same bit width, use the loading value
  // directly.
  if (srcBits == dstBits) {
    auto memoryRequirements = calculateMemoryRequirements(accessedPtr: accessChain, loadOrStoreOp: loadOp);
    if (failed(Result: memoryRequirements))
      return rewriter.notifyMatchFailure(
          arg&: loadOp, msg: "failed to determine memory requirements");

    auto [memoryAccess, alignment] = *memoryRequirements;
    Value loadVal = rewriter.create<spirv::LoadOp>(location: loc, args&: accessChain,
                                                   args&: memoryAccess, args&: alignment);
    if (isBool)
      loadVal = castIntNToBool(loc, srcInt: loadVal, builder&: rewriter);
    rewriter.replaceOp(op: loadOp, newValues: loadVal);
    return success();
  }

  // Bitcasting is currently unsupported for Kernel capability /
  // spirv.PtrAccessChain.
  if (typeConverter.allows(capability: spirv::Capability::Kernel))
    return failure();

  auto accessChainOp = accessChain.getDefiningOp<spirv::AccessChainOp>();
  if (!accessChainOp)
    return failure();

  // Assume that getElementPtr() works linearly. If it's a scalar, the method
  // still returns a linearized accessing. If the accessing is not linearized,
  // there will be offset issues.
  assert(accessChainOp.getIndices().size() == 2);
  Value adjustedPtr = adjustAccessChainForBitwidth(typeConverter, op: accessChainOp,
                                                  sourceBits: srcBits, targetBits: dstBits, builder&: rewriter);
  auto memoryRequirements = calculateMemoryRequirements(accessedPtr: adjustedPtr, loadOrStoreOp: loadOp);
  if (failed(Result: memoryRequirements))
    return rewriter.notifyMatchFailure(
        arg&: loadOp, msg: "failed to determine memory requirements");

  auto [memoryAccess, alignment] = *memoryRequirements;
  Value spvLoadOp = rewriter.create<spirv::LoadOp>(location: loc, args&: dstType, args&: adjustedPtr,
                                                   args&: memoryAccess, args&: alignment);

  // Shift the bits to the rightmost.
  // ____XXXX________ -> ____________XXXX
  Value lastDim = accessChainOp->getOperand(idx: accessChainOp.getNumOperands() - 1);
  Value offset = getOffsetForBitwidth(loc, srcIdx: lastDim, sourceBits: srcBits, targetBits: dstBits, builder&: rewriter);
  Value result = rewriter.createOrFold<spirv::ShiftRightArithmeticOp>(
      location: loc, args: spvLoadOp.getType(), args&: spvLoadOp, args&: offset);

  // Apply the mask to extract corresponding bits.
  Value mask = rewriter.createOrFold<spirv::ConstantOp>(
      location: loc, args&: dstType, args: rewriter.getIntegerAttr(type: dstType, value: (1 << srcBits) - 1));
  result =
      rewriter.createOrFold<spirv::BitwiseAndOp>(location: loc, args&: dstType, args&: result, args&: mask);

  // Apply sign extension on the loaded value unconditionally: shift left so
  // the source's sign bit lands in the word's sign bit, then arithmetic-shift
  // back. The signedness semantic is carried in the operator itself; we rely
  // on other patterns to handle the casting.
  IntegerAttr shiftValueAttr =
      rewriter.getIntegerAttr(type: dstType, value: dstBits - srcBits);
  Value shiftValue =
      rewriter.createOrFold<spirv::ConstantOp>(location: loc, args&: dstType, args&: shiftValueAttr);
  result = rewriter.createOrFold<spirv::ShiftLeftLogicalOp>(location: loc, args&: dstType,
                                                            args&: result, args&: shiftValue);
  result = rewriter.createOrFold<spirv::ShiftRightArithmeticOp>(
      location: loc, args&: dstType, args&: result, args&: shiftValue);

  rewriter.replaceOp(op: loadOp, newValues: result);

  // The original (unadjusted) access chain is now dead; erase it explicitly.
  assert(accessChainOp.use_empty());
  rewriter.eraseOp(op: accessChainOp);

  return success();
}
639
640LogicalResult
641LoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
642 ConversionPatternRewriter &rewriter) const {
643 auto memrefType = cast<MemRefType>(Val: loadOp.getMemref().getType());
644 if (memrefType.getElementType().isSignlessInteger())
645 return failure();
646 Value loadPtr = spirv::getElementPtr(
647 typeConverter: *getTypeConverter<SPIRVTypeConverter>(), baseType: memrefType, basePtr: adaptor.getMemref(),
648 indices: adaptor.getIndices(), loc: loadOp.getLoc(), builder&: rewriter);
649
650 if (!loadPtr)
651 return failure();
652
653 auto memoryRequirements = calculateMemoryRequirements(accessedPtr: loadPtr, loadOrStoreOp: loadOp);
654 if (failed(Result: memoryRequirements))
655 return rewriter.notifyMatchFailure(
656 arg&: loadOp, msg: "failed to determine memory requirements");
657
658 auto [memoryAccess, alignment] = *memoryRequirements;
659 rewriter.replaceOpWithNewOp<spirv::LoadOp>(op: loadOp, args&: loadPtr, args&: memoryAccess,
660 args&: alignment);
661 return success();
662}
663
664LogicalResult
665IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
666 ConversionPatternRewriter &rewriter) const {
667 auto memrefType = cast<MemRefType>(Val: storeOp.getMemref().getType());
668 if (!memrefType.getElementType().isSignlessInteger())
669 return rewriter.notifyMatchFailure(arg&: storeOp,
670 msg: "element type is not a signless int");
671
672 auto loc = storeOp.getLoc();
673 auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
674 Value accessChain =
675 spirv::getElementPtr(typeConverter, baseType: memrefType, basePtr: adaptor.getMemref(),
676 indices: adaptor.getIndices(), loc, builder&: rewriter);
677
678 if (!accessChain)
679 return rewriter.notifyMatchFailure(
680 arg&: storeOp, msg: "failed to convert element pointer type");
681
682 int srcBits = memrefType.getElementType().getIntOrFloatBitWidth();
683
684 bool isBool = srcBits == 1;
685 if (isBool)
686 srcBits = typeConverter.getOptions().boolNumBits;
687
688 auto pointerType = typeConverter.convertType<spirv::PointerType>(t: memrefType);
689 if (!pointerType)
690 return rewriter.notifyMatchFailure(arg&: storeOp,
691 msg: "failed to convert memref type");
692
693 Type pointeeType = pointerType.getPointeeType();
694 IntegerType dstType;
695 if (typeConverter.allows(capability: spirv::Capability::Kernel)) {
696 if (auto arrayType = dyn_cast<spirv::ArrayType>(Val&: pointeeType))
697 dstType = dyn_cast<IntegerType>(Val: arrayType.getElementType());
698 else
699 dstType = dyn_cast<IntegerType>(Val&: pointeeType);
700 } else {
701 // For Vulkan we need to extract element from wrapping struct and array.
702 Type structElemType =
703 cast<spirv::StructType>(Val&: pointeeType).getElementType(0);
704 if (auto arrayType = dyn_cast<spirv::ArrayType>(Val&: structElemType))
705 dstType = dyn_cast<IntegerType>(Val: arrayType.getElementType());
706 else
707 dstType = dyn_cast<IntegerType>(
708 Val: cast<spirv::RuntimeArrayType>(Val&: structElemType).getElementType());
709 }
710
711 if (!dstType)
712 return rewriter.notifyMatchFailure(
713 arg&: storeOp, msg: "failed to determine destination element type");
714
715 int dstBits = static_cast<int>(dstType.getWidth());
716 assert(dstBits % srcBits == 0);
717
718 if (srcBits == dstBits) {
719 auto memoryRequirements = calculateMemoryRequirements(accessedPtr: accessChain, loadOrStoreOp: storeOp);
720 if (failed(Result: memoryRequirements))
721 return rewriter.notifyMatchFailure(
722 arg&: storeOp, msg: "failed to determine memory requirements");
723
724 auto [memoryAccess, alignment] = *memoryRequirements;
725 Value storeVal = adaptor.getValue();
726 if (isBool)
727 storeVal = castBoolToIntN(loc, srcBool: storeVal, dstType, builder&: rewriter);
728 rewriter.replaceOpWithNewOp<spirv::StoreOp>(op: storeOp, args&: accessChain, args&: storeVal,
729 args&: memoryAccess, args&: alignment);
730 return success();
731 }
732
733 // Bitcasting is currently unsupported for Kernel capability /
734 // spirv.PtrAccessChain.
735 if (typeConverter.allows(capability: spirv::Capability::Kernel))
736 return failure();
737
738 auto accessChainOp = accessChain.getDefiningOp<spirv::AccessChainOp>();
739 if (!accessChainOp)
740 return failure();
741
742 // Since there are multiple threads in the processing, the emulation will be
743 // done with atomic operations. E.g., if the stored value is i8, rewrite the
744 // StoreOp to:
745 // 1) load a 32-bit integer
746 // 2) clear 8 bits in the loaded value
747 // 3) set 8 bits in the loaded value
748 // 4) store 32-bit value back
749 //
750 // Step 2 is done with AtomicAnd, and step 3 is done with AtomicOr (of the
751 // loaded 32-bit value and the shifted 8-bit store value) as another atomic
752 // step.
753 assert(accessChainOp.getIndices().size() == 2);
754 Value lastDim = accessChainOp->getOperand(idx: accessChainOp.getNumOperands() - 1);
755 Value offset = getOffsetForBitwidth(loc, srcIdx: lastDim, sourceBits: srcBits, targetBits: dstBits, builder&: rewriter);
756
757 // Create a mask to clear the destination. E.g., if it is the second i8 in
758 // i32, 0xFFFF00FF is created.
759 Value mask = rewriter.createOrFold<spirv::ConstantOp>(
760 location: loc, args&: dstType, args: rewriter.getIntegerAttr(type: dstType, value: (1 << srcBits) - 1));
761 Value clearBitsMask = rewriter.createOrFold<spirv::ShiftLeftLogicalOp>(
762 location: loc, args&: dstType, args&: mask, args&: offset);
763 clearBitsMask =
764 rewriter.createOrFold<spirv::NotOp>(location: loc, args&: dstType, args&: clearBitsMask);
765
766 Value storeVal = shiftValue(loc, value: adaptor.getValue(), offset, mask, builder&: rewriter);
767 Value adjustedPtr = adjustAccessChainForBitwidth(typeConverter, op: accessChainOp,
768 sourceBits: srcBits, targetBits: dstBits, builder&: rewriter);
769 std::optional<spirv::Scope> scope = getAtomicOpScope(type: memrefType);
770 if (!scope)
771 return rewriter.notifyMatchFailure(arg&: storeOp, msg: "atomic scope not available");
772
773 Value result = rewriter.create<spirv::AtomicAndOp>(
774 location: loc, args&: dstType, args&: adjustedPtr, args&: *scope, args: spirv::MemorySemantics::AcquireRelease,
775 args&: clearBitsMask);
776 result = rewriter.create<spirv::AtomicOrOp>(
777 location: loc, args&: dstType, args&: adjustedPtr, args&: *scope, args: spirv::MemorySemantics::AcquireRelease,
778 args&: storeVal);
779
780 // The AtomicOrOp has no side effect. Since it is already inserted, we can
781 // just remove the original StoreOp. Note that rewriter.replaceOp()
782 // doesn't work because it only accepts that the numbers of result are the
783 // same.
784 rewriter.eraseOp(op: storeOp);
785
786 assert(accessChainOp.use_empty());
787 rewriter.eraseOp(op: accessChainOp);
788
789 return success();
790}
791
792//===----------------------------------------------------------------------===//
793// MemorySpaceCastOp
794//===----------------------------------------------------------------------===//
795
/// Lowers memref.memory_space_cast by bridging through SPIR-V's Generic
/// storage class.
///
/// SPIR-V has no direct cast between two arbitrary storage classes; it only
/// provides PtrCastToGeneric (specific -> Generic) and GenericCastToPtr
/// (Generic -> specific). A cast between two non-Generic classes is therefore
/// emitted as a round trip through an intermediate Generic pointer.
LogicalResult MemorySpaceCastOpPattern::matchAndRewrite(
    memref::MemorySpaceCastOp addrCastOp, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Location loc = addrCastOp.getLoc();
  auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
  // The Generic storage class (and the two cast ops used below) is only
  // available with the Kernel capability; Shader-flavored targets bail out.
  if (!typeConverter.allows(capability: spirv::Capability::Kernel))
    return rewriter.notifyMatchFailure(
        arg&: loc, msg: "address space casts require kernel capability");

  auto sourceType = dyn_cast<MemRefType>(Val: addrCastOp.getSource().getType());
  if (!sourceType)
    return rewriter.notifyMatchFailure(
        arg&: loc, msg: "SPIR-V lowering requires ranked memref types");
  auto resultType = cast<MemRefType>(Val: addrCastOp.getResult().getType());

  // Both memory spaces must already be expressed as SPIR-V storage-class
  // attributes (i.e. the storage-class mapping has run before this pattern).
  auto sourceStorageClassAttr =
      dyn_cast_or_null<spirv::StorageClassAttr>(Val: sourceType.getMemorySpace());
  if (!sourceStorageClassAttr)
    return rewriter.notifyMatchFailure(loc, reasonCallback: [sourceType](Diagnostic &diag) {
      diag << "source address space " << sourceType.getMemorySpace()
           << " must be a SPIR-V storage class";
    });
  auto resultStorageClassAttr =
      dyn_cast_or_null<spirv::StorageClassAttr>(Val: resultType.getMemorySpace());
  if (!resultStorageClassAttr)
    return rewriter.notifyMatchFailure(loc, reasonCallback: [resultType](Diagnostic &diag) {
      diag << "result address space " << resultType.getMemorySpace()
           << " must be a SPIR-V storage class";
    });

  spirv::StorageClass sourceSc = sourceStorageClassAttr.getValue();
  spirv::StorageClass resultSc = resultStorageClassAttr.getValue();

  Value result = adaptor.getSource();
  Type resultPtrType = typeConverter.convertType(t: resultType);
  if (!resultPtrType)
    return rewriter.notifyMatchFailure(arg&: addrCastOp,
                                       msg: "failed to convert memref type");

  Type genericPtrType = resultPtrType;
  // SPIR-V doesn't have a general address space cast operation. Instead, it has
  // conversions to and from generic pointers. To implement the general case,
  // we use specific-to-generic conversions when the source class is not
  // generic. Then when the result storage class is not generic, we convert the
  // generic pointer (either the input or an intermediate result) to that
  // class. This also means that we'll need the intermediate generic pointer
  // type if neither the source nor the destination have it.
  if (sourceSc != spirv::StorageClass::Generic &&
      resultSc != spirv::StorageClass::Generic) {
    // Same shape/element/layout as the source, but in the Generic space, so
    // the intermediate pointer type line up with the source pointee.
    Type intermediateType =
        MemRefType::get(shape: sourceType.getShape(), elementType: sourceType.getElementType(),
                        layout: sourceType.getLayout(),
                        memorySpace: rewriter.getAttr<spirv::StorageClassAttr>(
                            args: spirv::StorageClass::Generic));
    genericPtrType = typeConverter.convertType(t: intermediateType);
  }
  if (sourceSc != spirv::StorageClass::Generic) {
    result =
        rewriter.create<spirv::PtrCastToGenericOp>(location: loc, args&: genericPtrType, args&: result);
  }
  if (resultSc != spirv::StorageClass::Generic) {
    result =
        rewriter.create<spirv::GenericCastToPtrOp>(location: loc, args&: resultPtrType, args&: result);
  }
  // Note: if both source and result are already Generic, this degenerates to
  // replacing the cast with its (converted) input.
  rewriter.replaceOp(op: addrCastOp, newValues: result);
  return success();
}
863
864LogicalResult
865StoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
866 ConversionPatternRewriter &rewriter) const {
867 auto memrefType = cast<MemRefType>(Val: storeOp.getMemref().getType());
868 if (memrefType.getElementType().isSignlessInteger())
869 return rewriter.notifyMatchFailure(arg&: storeOp, msg: "signless int");
870 auto storePtr = spirv::getElementPtr(
871 typeConverter: *getTypeConverter<SPIRVTypeConverter>(), baseType: memrefType, basePtr: adaptor.getMemref(),
872 indices: adaptor.getIndices(), loc: storeOp.getLoc(), builder&: rewriter);
873
874 if (!storePtr)
875 return rewriter.notifyMatchFailure(arg&: storeOp, msg: "type conversion failed");
876
877 auto memoryRequirements = calculateMemoryRequirements(accessedPtr: storePtr, loadOrStoreOp: storeOp);
878 if (failed(Result: memoryRequirements))
879 return rewriter.notifyMatchFailure(
880 arg&: storeOp, msg: "failed to determine memory requirements");
881
882 auto [memoryAccess, alignment] = *memoryRequirements;
883 rewriter.replaceOpWithNewOp<spirv::StoreOp>(
884 op: storeOp, args&: storePtr, args: adaptor.getValue(), args&: memoryAccess, args&: alignment);
885 return success();
886}
887
/// Lowers memref.reinterpret_cast to pointer arithmetic.
///
/// Since the source and converted result pointer types must match, only the
/// (single, possibly static) offset matters: a zero offset folds away the op
/// entirely, and a nonzero offset becomes an InBoundsPtrAccessChain that
/// advances the base pointer by that many elements.
LogicalResult ReinterpretCastPattern::matchAndRewrite(
    memref::ReinterpretCastOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Value src = adaptor.getSource();
  auto srcType = dyn_cast<spirv::PointerType>(Val: src.getType());

  if (!srcType)
    return rewriter.notifyMatchFailure(op, reasonCallback: [&](Diagnostic &diag) {
      diag << "invalid src type " << src.getType();
    });

  const TypeConverter *converter = getTypeConverter();

  // If the result converts to a different pointer type (or fails to convert,
  // yielding a null dstType), the cast cannot be expressed as plain pointer
  // arithmetic and the match fails.
  auto dstType = converter->convertType<spirv::PointerType>(t: op.getType());
  if (dstType != srcType)
    return rewriter.notifyMatchFailure(op, reasonCallback: [&](Diagnostic &diag) {
      diag << "invalid dst type " << op.getType();
    });

  // Only the offset is relevant; sizes/strides are encoded in the type.
  OpFoldResult offset =
      getMixedValues(staticValues: adaptor.getStaticOffsets(), dynamicValues: adaptor.getOffsets(), b&: rewriter)
          .front();
  if (isZeroInteger(v: offset)) {
    // Zero offset: the cast is a no-op on the pointer.
    rewriter.replaceOp(op, newValues: src);
    return success();
  }

  Type intType = converter->convertType(t: rewriter.getIndexType());
  if (!intType)
    return rewriter.notifyMatchFailure(arg&: op, msg: "failed to convert index type");

  Location loc = op.getLoc();
  // Turn the OpFoldResult into an SSA value: pass dynamic offsets through,
  // materialize static offsets as a spirv.Constant of the index type.
  auto offsetValue = [&]() -> Value {
    if (auto val = dyn_cast<Value>(Val&: offset))
      return val;

    int64_t attrVal = cast<IntegerAttr>(Val: cast<Attribute>(Val&: offset)).getInt();
    Attribute attr = rewriter.getIntegerAttr(type: intType, value: attrVal);
    return rewriter.createOrFold<spirv::ConstantOp>(location: loc, args&: intType, args&: attr);
  }();

  rewriter.replaceOpWithNewOp<spirv::InBoundsPtrAccessChainOp>(
      op, args&: src, args&: offsetValue, args: ValueRange());
  return success();
}
933
934//===----------------------------------------------------------------------===//
935// ExtractAlignedPointerAsIndexOp
936//===----------------------------------------------------------------------===//
937
938LogicalResult ExtractAlignedPointerAsIndexOpPattern::matchAndRewrite(
939 memref::ExtractAlignedPointerAsIndexOp extractOp, OpAdaptor adaptor,
940 ConversionPatternRewriter &rewriter) const {
941 auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
942 Type indexType = typeConverter.getIndexType();
943 rewriter.replaceOpWithNewOp<spirv::ConvertPtrToUOp>(op: extractOp, args&: indexType,
944 args: adaptor.getSource());
945 return success();
946}
947
948//===----------------------------------------------------------------------===//
949// Pattern population
950//===----------------------------------------------------------------------===//
951
952namespace mlir {
953void populateMemRefToSPIRVPatterns(const SPIRVTypeConverter &typeConverter,
954 RewritePatternSet &patterns) {
955 patterns
956 .add<AllocaOpPattern, AllocOpPattern, AtomicRMWOpPattern,
957 DeallocOpPattern, IntLoadOpPattern, IntStoreOpPattern, LoadOpPattern,
958 MemorySpaceCastOpPattern, StoreOpPattern, ReinterpretCastPattern,
959 CastPattern, ExtractAlignedPointerAsIndexOpPattern>(
960 arg: typeConverter, args: patterns.getContext());
961}
962} // namespace mlir
963

// Source: mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp