//===- MemRefToSPIRV.cpp - MemRef to SPIR-V Patterns ---------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements patterns to convert MemRef dialect to SPIR-V dialect.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Visitors.h"
#include "mlir/Support/LogicalResult.h"
#include "llvm/Support/Debug.h"
#include <cassert>
#include <optional>

#define DEBUG_TYPE "memref-to-spirv-pattern"

using namespace mlir;

//===----------------------------------------------------------------------===//
// Utility functions
//===----------------------------------------------------------------------===//

/// Returns the offset of the value in `targetBits` representation.
///
/// `srcIdx` is an index into a 1-D array with each element having `sourceBits`.
/// It's assumed to be non-negative.
///
/// When accessing an element in the array treated as having elements of
/// `targetBits`, multiple values are loaded at the same time. The method
/// returns the offset of the `srcIdx` value within the loaded value. For
/// example, if `sourceBits` is 8 and `targetBits` is 32, the x-th element is
/// located at (x % 4) * 8, because one i32 holds four 8-bit elements.
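/// In pseudo-code:
///   offset = (srcIdx % (targetBits / sourceBits)) * sourceBits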
static Value getOffsetForBitwidth(Location loc, Value srcIdx, int sourceBits,
                                  int targetBits, OpBuilder &builder) {
  assert(targetBits % sourceBits == 0);
  Type type = srcIdx.getType();
  IntegerAttr idxAttr = builder.getIntegerAttr(type, targetBits / sourceBits);
  auto idx = builder.createOrFold<spirv::ConstantOp>(loc, type, idxAttr);
  IntegerAttr srcBitsAttr = builder.getIntegerAttr(type, sourceBits);
  auto srcBitsValue =
      builder.createOrFold<spirv::ConstantOp>(loc, type, srcBitsAttr);
  auto m = builder.createOrFold<spirv::UModOp>(loc, srcIdx, idx);
  return builder.createOrFold<spirv::IMulOp>(loc, type, m, srcBitsValue);
}

/// Returns an adjusted spirv::AccessChainOp. Based on the
/// extension/capabilities, certain integer bitwidths `sourceBits` might not be
/// supported. During conversion, if a memref of an unsupported type is used,
/// loads/stores to this memref need to be modified to use a supported higher
/// bitwidth `targetBits` and extract the required bits. When accessing a 1-D
/// array (spirv.array or spirv.rtarray), the last index is modified to load
/// the bits needed. The extraction of the actual bits needed is handled
/// separately. Note that this only works for a 1-D tensor.
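/// For example, with `sourceBits` = 8 and `targetBits` = 32, the last index i
/// is rewritten to i / 4, since four i8 elements pack into one i32 element.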
static Value
adjustAccessChainForBitwidth(const SPIRVTypeConverter &typeConverter,
                             spirv::AccessChainOp op, int sourceBits,
                             int targetBits, OpBuilder &builder) {
  assert(targetBits % sourceBits == 0);
  const auto loc = op.getLoc();
  Value lastDim = op->getOperand(op.getNumOperands() - 1);
  Type type = lastDim.getType();
  IntegerAttr attr = builder.getIntegerAttr(type, targetBits / sourceBits);
  auto idx = builder.createOrFold<spirv::ConstantOp>(loc, type, attr);
  auto indices = llvm::to_vector<4>(op.getIndices());
  // There are two elements if this is a 1-D tensor.
  assert(indices.size() == 2);
  indices.back() = builder.createOrFold<spirv::SDivOp>(loc, lastDim, idx);
  Type t = typeConverter.convertType(op.getComponentPtr().getType());
  return builder.create<spirv::AccessChainOp>(loc, t, op.getBasePtr(), indices);
}

/// Casts the given `srcBool` into an integer of `dstType`.
static Value castBoolToIntN(Location loc, Value srcBool, Type dstType,
                            OpBuilder &builder) {
  assert(srcBool.getType().isInteger(1));
  if (dstType.isInteger(1))
    return srcBool;
  Value zero = spirv::ConstantOp::getZero(dstType, loc, builder);
  Value one = spirv::ConstantOp::getOne(dstType, loc, builder);
  return builder.createOrFold<spirv::SelectOp>(loc, dstType, srcBool, one,
                                               zero);
}

/// Returns the `targetBits`-bit value shifted by the given `offset`, cast to
/// the destination type, and masked.
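/// For example, an i8 value `v` stored at byte offset 1 within an i32 becomes
/// (zext(v) & mask) << 8, where the 0xFF mask is supplied by the caller.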
static Value shiftValue(Location loc, Value value, Value offset, Value mask,
                        OpBuilder &builder) {
  IntegerType dstType = cast<IntegerType>(mask.getType());
  int targetBits = static_cast<int>(dstType.getWidth());
  int valueBits = value.getType().getIntOrFloatBitWidth();
  assert(valueBits <= targetBits);

  if (valueBits == 1) {
    value = castBoolToIntN(loc, value, dstType, builder);
  } else {
    if (valueBits < targetBits) {
      value = builder.create<spirv::UConvertOp>(
          loc, builder.getIntegerType(targetBits), value);
    }

    value = builder.createOrFold<spirv::BitwiseAndOp>(loc, value, mask);
  }
  return builder.createOrFold<spirv::ShiftLeftLogicalOp>(loc, value.getType(),
                                                         value, offset);
}

/// Returns true if the allocations of memref `type` generated from `allocOp`
/// can be lowered to SPIR-V.
static bool isAllocationSupported(Operation *allocOp, MemRefType type) {
  if (isa<memref::AllocOp, memref::DeallocOp>(allocOp)) {
    auto sc = dyn_cast_or_null<spirv::StorageClassAttr>(type.getMemorySpace());
    if (!sc || sc.getValue() != spirv::StorageClass::Workgroup)
      return false;
  } else if (isa<memref::AllocaOp>(allocOp)) {
    auto sc = dyn_cast_or_null<spirv::StorageClassAttr>(type.getMemorySpace());
    if (!sc || sc.getValue() != spirv::StorageClass::Function)
      return false;
  } else {
    return false;
  }

  // Currently only supports static shapes and an int, float, or vector of
  // int/float element type.
  if (!type.hasStaticShape())
    return false;

  Type elementType = type.getElementType();
  if (auto vecType = dyn_cast<VectorType>(elementType))
    elementType = vecType.getElementType();
  return elementType.isIntOrFloat();
}

/// Returns the scope to use for atomic operations used for emulating store
/// operations of unsupported integer bitwidths, based on the memref
/// type. Returns std::nullopt on failure.
static std::optional<spirv::Scope> getAtomicOpScope(MemRefType type) {
  auto sc = dyn_cast_or_null<spirv::StorageClassAttr>(type.getMemorySpace());
  if (!sc)
    return {};
  switch (sc.getValue()) {
  case spirv::StorageClass::StorageBuffer:
    return spirv::Scope::Device;
  case spirv::StorageClass::Workgroup:
    return spirv::Scope::Workgroup;
  default:
    break;
  }
  return {};
}

/// Casts the given `srcInt` into a boolean value.
static Value castIntNToBool(Location loc, Value srcInt, OpBuilder &builder) {
  if (srcInt.getType().isInteger(1))
    return srcInt;

  auto one = spirv::ConstantOp::getOne(srcInt.getType(), loc, builder);
  return builder.createOrFold<spirv::IEqualOp>(loc, srcInt, one);
}

//===----------------------------------------------------------------------===//
// Operation conversion
//===----------------------------------------------------------------------===//

// Note that DRR cannot be used for the patterns in this file: we may need to
// convert types along the way, which requires ConversionPattern. DRR generates
// normal RewritePatterns.

namespace {

/// Converts memref.alloca to SPIR-V Function variables.
class AllocaOpPattern final : public OpConversionPattern<memref::AllocaOp> {
public:
  using OpConversionPattern<memref::AllocaOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::AllocaOp allocaOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Converts an allocation operation to SPIR-V. Currently only supports
/// lowering to Workgroup memory when the size is constant. Note that this
/// pattern needs to be applied in a pass that runs at least at spirv.module
/// scope since it will add global variables into the spirv.module.
class AllocOpPattern final : public OpConversionPattern<memref::AllocOp> {
public:
  using OpConversionPattern<memref::AllocOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::AllocOp operation, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Converts memref.atomic_rmw operations to SPIR-V atomic operations.
class AtomicRMWOpPattern final
    : public OpConversionPattern<memref::AtomicRMWOp> {
public:
  using OpConversionPattern<memref::AtomicRMWOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::AtomicRMWOp atomicOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Removes a deallocation if it is for a supported allocation. Currently only
/// removes deallocations if the memory space is workgroup memory.
class DeallocOpPattern final : public OpConversionPattern<memref::DeallocOp> {
public:
  using OpConversionPattern<memref::DeallocOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::DeallocOp operation, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Converts memref.load to spirv.Load + spirv.AccessChain on integers.
class IntLoadOpPattern final : public OpConversionPattern<memref::LoadOp> {
public:
  using OpConversionPattern<memref::LoadOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Converts memref.load to spirv.Load + spirv.AccessChain.
class LoadOpPattern final : public OpConversionPattern<memref::LoadOp> {
public:
  using OpConversionPattern<memref::LoadOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Converts memref.store to spirv.Store on integers.
class IntStoreOpPattern final : public OpConversionPattern<memref::StoreOp> {
public:
  using OpConversionPattern<memref::StoreOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Converts memref.memory_space_cast to the appropriate spirv cast operations.
class MemorySpaceCastOpPattern final
    : public OpConversionPattern<memref::MemorySpaceCastOp> {
public:
  using OpConversionPattern<memref::MemorySpaceCastOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::MemorySpaceCastOp addrCastOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Converts memref.store to spirv.Store.
class StoreOpPattern final : public OpConversionPattern<memref::StoreOp> {
public:
  using OpConversionPattern<memref::StoreOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

class ReinterpretCastPattern final
    : public OpConversionPattern<memref::ReinterpretCastOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::ReinterpretCastOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

class CastPattern final : public OpConversionPattern<memref::CastOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::CastOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Value src = adaptor.getSource();
    Type srcType = src.getType();

    const TypeConverter *converter = getTypeConverter();
    Type dstType = converter->convertType(op.getType());
    if (srcType != dstType)
      return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
        diag << "types don't match: " << srcType << " and " << dstType;
      });

    rewriter.replaceOp(op, src);
    return success();
  }
};

} // namespace

//===----------------------------------------------------------------------===//
// AllocaOp
//===----------------------------------------------------------------------===//

LogicalResult
AllocaOpPattern::matchAndRewrite(memref::AllocaOp allocaOp, OpAdaptor adaptor,
                                 ConversionPatternRewriter &rewriter) const {
  MemRefType allocType = allocaOp.getType();
  if (!isAllocationSupported(allocaOp, allocType))
    return rewriter.notifyMatchFailure(allocaOp, "unhandled allocation type");

  // Get the SPIR-V type for the allocation.
  Type spirvType = getTypeConverter()->convertType(allocType);
  if (!spirvType)
    return rewriter.notifyMatchFailure(allocaOp, "type conversion failed");

  rewriter.replaceOpWithNewOp<spirv::VariableOp>(allocaOp, spirvType,
                                                 spirv::StorageClass::Function,
                                                 /*initializer=*/nullptr);
  return success();
}

//===----------------------------------------------------------------------===//
// AllocOp
//===----------------------------------------------------------------------===//

LogicalResult
AllocOpPattern::matchAndRewrite(memref::AllocOp operation, OpAdaptor adaptor,
                                ConversionPatternRewriter &rewriter) const {
  MemRefType allocType = operation.getType();
  if (!isAllocationSupported(operation, allocType))
    return rewriter.notifyMatchFailure(operation, "unhandled allocation type");

  // Get the SPIR-V type for the allocation.
  Type spirvType = getTypeConverter()->convertType(allocType);
  if (!spirvType)
    return rewriter.notifyMatchFailure(operation, "type conversion failed");

  // Insert spirv.GlobalVariable for this allocation.
  Operation *parent =
      SymbolTable::getNearestSymbolTable(operation->getParentOp());
  if (!parent)
    return failure();
  Location loc = operation.getLoc();
  spirv::GlobalVariableOp varOp;
  {
    OpBuilder::InsertionGuard guard(rewriter);
    Block &entryBlock = *parent->getRegion(0).begin();
    rewriter.setInsertionPointToStart(&entryBlock);
    auto varOps = entryBlock.getOps<spirv::GlobalVariableOp>();
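    // Name the variable by its ordinal among the module's global variables,
    // e.g. "__workgroup_mem__0" for the first one.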
    std::string varName =
        std::string("__workgroup_mem__") +
        std::to_string(std::distance(varOps.begin(), varOps.end()));
    varOp = rewriter.create<spirv::GlobalVariableOp>(loc, spirvType, varName,
                                                     /*initializer=*/nullptr);
  }

  // Get pointer to global variable at the current scope.
  rewriter.replaceOpWithNewOp<spirv::AddressOfOp>(operation, varOp);
  return success();
}

//===----------------------------------------------------------------------===//
// AtomicRMWOp
//===----------------------------------------------------------------------===//

LogicalResult
AtomicRMWOpPattern::matchAndRewrite(memref::AtomicRMWOp atomicOp,
                                    OpAdaptor adaptor,
                                    ConversionPatternRewriter &rewriter) const {
  if (isa<FloatType>(atomicOp.getType()))
    return rewriter.notifyMatchFailure(atomicOp,
                                       "unimplemented floating-point case");

  auto memrefType = cast<MemRefType>(atomicOp.getMemref().getType());
  std::optional<spirv::Scope> scope = getAtomicOpScope(memrefType);
  if (!scope)
    return rewriter.notifyMatchFailure(atomicOp,
                                       "unsupported memref memory space");

  auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
  Type resultType = typeConverter.convertType(atomicOp.getType());
  if (!resultType)
    return rewriter.notifyMatchFailure(atomicOp,
                                       "failed to convert result type");

  auto loc = atomicOp.getLoc();
  Value ptr =
      spirv::getElementPtr(typeConverter, memrefType, adaptor.getMemref(),
                           adaptor.getIndices(), loc, rewriter);

  if (!ptr)
    return failure();

#define ATOMIC_CASE(kind, spirvOp)                                             \
  case arith::AtomicRMWKind::kind:                                             \
    rewriter.replaceOpWithNewOp<spirv::spirvOp>(                               \
        atomicOp, resultType, ptr, *scope,                                     \
        spirv::MemorySemantics::AcquireRelease, adaptor.getValue());           \
    break

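  // E.g., arith::AtomicRMWKind::addi lowers to spirv.AtomicIAdd at the scope
  // implied by the memref's storage class, with AcquireRelease semantics.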
  switch (atomicOp.getKind()) {
    ATOMIC_CASE(addi, AtomicIAddOp);
    ATOMIC_CASE(maxs, AtomicSMaxOp);
    ATOMIC_CASE(maxu, AtomicUMaxOp);
    ATOMIC_CASE(mins, AtomicSMinOp);
    ATOMIC_CASE(minu, AtomicUMinOp);
    ATOMIC_CASE(ori, AtomicOrOp);
    ATOMIC_CASE(andi, AtomicAndOp);
  default:
    return rewriter.notifyMatchFailure(atomicOp, "unimplemented atomic kind");
  }

#undef ATOMIC_CASE

  return success();
}

//===----------------------------------------------------------------------===//
// DeallocOp
//===----------------------------------------------------------------------===//

LogicalResult
DeallocOpPattern::matchAndRewrite(memref::DeallocOp operation,
                                  OpAdaptor adaptor,
                                  ConversionPatternRewriter &rewriter) const {
  MemRefType deallocType = cast<MemRefType>(operation.getMemref().getType());
  if (!isAllocationSupported(operation, deallocType))
    return rewriter.notifyMatchFailure(operation, "unhandled allocation type");
  rewriter.eraseOp(operation);
  return success();
}

//===----------------------------------------------------------------------===//
// LoadOp
//===----------------------------------------------------------------------===//

struct MemoryRequirements {
  spirv::MemoryAccessAttr memoryAccess;
  IntegerAttr alignment;
};

/// Given an accessed SPIR-V pointer, calculates its alignment requirements, if
/// any.
static FailureOr<MemoryRequirements>
calculateMemoryRequirements(Value accessedPtr, bool isNontemporal) {
  MLIRContext *ctx = accessedPtr.getContext();

  auto memoryAccess = spirv::MemoryAccess::None;
  if (isNontemporal) {
    memoryAccess = spirv::MemoryAccess::Nontemporal;
  }

  auto ptrType = cast<spirv::PointerType>(accessedPtr.getType());
  if (ptrType.getStorageClass() != spirv::StorageClass::PhysicalStorageBuffer) {
    if (memoryAccess == spirv::MemoryAccess::None) {
      return MemoryRequirements{spirv::MemoryAccessAttr{}, IntegerAttr{}};
    }
    return MemoryRequirements{spirv::MemoryAccessAttr::get(ctx, memoryAccess),
                              IntegerAttr{}};
  }

  // PhysicalStorageBuffers require the `Aligned` attribute.
  auto pointeeType = dyn_cast<spirv::ScalarType>(ptrType.getPointeeType());
  if (!pointeeType)
    return failure();

  // For scalar types, the alignment is determined by their size.
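  // E.g., an i32 or f32 pointee yields a 4-byte alignment.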
  std::optional<int64_t> sizeInBytes = pointeeType.getSizeInBytes();
  if (!sizeInBytes.has_value())
    return failure();

  memoryAccess = memoryAccess | spirv::MemoryAccess::Aligned;
  auto memAccessAttr = spirv::MemoryAccessAttr::get(ctx, memoryAccess);
  auto alignment = IntegerAttr::get(IntegerType::get(ctx, 32), *sizeInBytes);
  return MemoryRequirements{memAccessAttr, alignment};
}

/// Given an accessed SPIR-V pointer and the original memref load/store
/// `memAccess` op, calculates the alignment requirements, if any. Takes into
/// account the alignment attributes applied to the load/store op.
template <class LoadOrStoreOp>
static FailureOr<MemoryRequirements>
calculateMemoryRequirements(Value accessedPtr, LoadOrStoreOp loadOrStoreOp) {
  static_assert(
      llvm::is_one_of<LoadOrStoreOp, memref::LoadOp, memref::StoreOp>::value,
      "Must be called on either memref::LoadOp or memref::StoreOp");

  Operation *memrefAccessOp = loadOrStoreOp.getOperation();
  auto memrefMemAccess = memrefAccessOp->getAttrOfType<spirv::MemoryAccessAttr>(
      spirv::attributeName<spirv::MemoryAccess>());
  auto memrefAlignment =
      memrefAccessOp->getAttrOfType<IntegerAttr>("alignment");
  if (memrefMemAccess && memrefAlignment)
    return MemoryRequirements{memrefMemAccess, memrefAlignment};

  return calculateMemoryRequirements(accessedPtr,
                                     loadOrStoreOp.getNontemporal());
}

LogicalResult
IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
                                  ConversionPatternRewriter &rewriter) const {
  auto loc = loadOp.getLoc();
  auto memrefType = cast<MemRefType>(loadOp.getMemref().getType());
  if (!memrefType.getElementType().isSignlessInteger())
    return failure();

  const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
  Value accessChain =
      spirv::getElementPtr(typeConverter, memrefType, adaptor.getMemref(),
                           adaptor.getIndices(), loc, rewriter);

  if (!accessChain)
    return failure();

  int srcBits = memrefType.getElementType().getIntOrFloatBitWidth();
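  // Booleans are emulated as integers of `boolNumBits` width (8 by default in
  // the SPIR-V type converter options).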
  bool isBool = srcBits == 1;
  if (isBool)
    srcBits = typeConverter.getOptions().boolNumBits;

  auto pointerType = typeConverter.convertType<spirv::PointerType>(memrefType);
  if (!pointerType)
    return rewriter.notifyMatchFailure(loadOp, "failed to convert memref type");

  Type pointeeType = pointerType.getPointeeType();
  Type dstType;
  if (typeConverter.allows(spirv::Capability::Kernel)) {
    if (auto arrayType = dyn_cast<spirv::ArrayType>(pointeeType))
      dstType = arrayType.getElementType();
    else
      dstType = pointeeType;
  } else {
    // For Vulkan we need to extract element from wrapping struct and array.
    Type structElemType =
        cast<spirv::StructType>(pointeeType).getElementType(0);
    if (auto arrayType = dyn_cast<spirv::ArrayType>(structElemType))
      dstType = arrayType.getElementType();
    else
      dstType = cast<spirv::RuntimeArrayType>(structElemType).getElementType();
  }
  int dstBits = dstType.getIntOrFloatBitWidth();
  assert(dstBits % srcBits == 0);

  // If the rewritten load op has the same bit width, use the loaded value
  // directly.
  if (srcBits == dstBits) {
    auto memoryRequirements = calculateMemoryRequirements(accessChain, loadOp);
    if (failed(memoryRequirements))
      return rewriter.notifyMatchFailure(
          loadOp, "failed to determine memory requirements");

    auto [memoryAccess, alignment] = *memoryRequirements;
    Value loadVal = rewriter.create<spirv::LoadOp>(loc, accessChain,
                                                   memoryAccess, alignment);
    if (isBool)
      loadVal = castIntNToBool(loc, loadVal, rewriter);
    rewriter.replaceOp(loadOp, loadVal);
    return success();
  }

  // Bitcasting is currently unsupported for Kernel capability /
  // spirv.PtrAccessChain.
  if (typeConverter.allows(spirv::Capability::Kernel))
    return failure();

  auto accessChainOp = accessChain.getDefiningOp<spirv::AccessChainOp>();
  if (!accessChainOp)
    return failure();

  // Assume that getElementPtr() linearizes the access. Even for a scalar, the
  // method returns a linearized access chain. If the access were not
  // linearized, there would be offset issues.
  assert(accessChainOp.getIndices().size() == 2);
  Value adjustedPtr = adjustAccessChainForBitwidth(typeConverter, accessChainOp,
                                                   srcBits, dstBits, rewriter);
  auto memoryRequirements = calculateMemoryRequirements(adjustedPtr, loadOp);
  if (failed(memoryRequirements))
    return rewriter.notifyMatchFailure(
        loadOp, "failed to determine memory requirements");

  auto [memoryAccess, alignment] = *memoryRequirements;
  Value spvLoadOp = rewriter.create<spirv::LoadOp>(loc, dstType, adjustedPtr,
                                                   memoryAccess, alignment);

  // Shift the bits to the rightmost.
  // ____XXXX________ -> ____________XXXX
  Value lastDim = accessChainOp->getOperand(accessChainOp.getNumOperands() - 1);
  Value offset = getOffsetForBitwidth(loc, lastDim, srcBits, dstBits, rewriter);
  Value result = rewriter.createOrFold<spirv::ShiftRightArithmeticOp>(
      loc, spvLoadOp.getType(), spvLoadOp, offset);

  // Apply the mask to extract the corresponding bits.
  Value mask = rewriter.createOrFold<spirv::ConstantOp>(
      loc, dstType, rewriter.getIntegerAttr(dstType, (1 << srcBits) - 1));
  result =
      rewriter.createOrFold<spirv::BitwiseAndOp>(loc, dstType, result, mask);

  // Apply sign extension to the loaded value unconditionally. The signedness
  // semantics are carried in the operations themselves; we rely on other
  // patterns to handle the casting.
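  // E.g., for an i8 extracted from an i32, shift left by 24 and then
  // arithmetic-shift right by 24.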
  IntegerAttr shiftValueAttr =
      rewriter.getIntegerAttr(dstType, dstBits - srcBits);
  Value shiftValue =
      rewriter.createOrFold<spirv::ConstantOp>(loc, dstType, shiftValueAttr);
  result = rewriter.createOrFold<spirv::ShiftLeftLogicalOp>(loc, dstType,
                                                            result, shiftValue);
  result = rewriter.createOrFold<spirv::ShiftRightArithmeticOp>(
      loc, dstType, result, shiftValue);

  rewriter.replaceOp(loadOp, result);

  assert(accessChainOp.use_empty());
  rewriter.eraseOp(accessChainOp);

  return success();
}

LogicalResult
LoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
                               ConversionPatternRewriter &rewriter) const {
  auto memrefType = cast<MemRefType>(loadOp.getMemref().getType());
  if (memrefType.getElementType().isSignlessInteger())
    return failure();
  Value loadPtr = spirv::getElementPtr(
      *getTypeConverter<SPIRVTypeConverter>(), memrefType, adaptor.getMemref(),
      adaptor.getIndices(), loadOp.getLoc(), rewriter);

  if (!loadPtr)
    return failure();

  auto memoryRequirements = calculateMemoryRequirements(loadPtr, loadOp);
  if (failed(memoryRequirements))
    return rewriter.notifyMatchFailure(
        loadOp, "failed to determine memory requirements");

  auto [memoryAccess, alignment] = *memoryRequirements;
  rewriter.replaceOpWithNewOp<spirv::LoadOp>(loadOp, loadPtr, memoryAccess,
                                             alignment);
  return success();
}

LogicalResult
IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
                                   ConversionPatternRewriter &rewriter) const {
  auto memrefType = cast<MemRefType>(storeOp.getMemref().getType());
  if (!memrefType.getElementType().isSignlessInteger())
    return rewriter.notifyMatchFailure(storeOp,
                                       "element type is not a signless int");

  auto loc = storeOp.getLoc();
  auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
  Value accessChain =
      spirv::getElementPtr(typeConverter, memrefType, adaptor.getMemref(),
                           adaptor.getIndices(), loc, rewriter);

  if (!accessChain)
    return rewriter.notifyMatchFailure(
        storeOp, "failed to convert element pointer type");

  int srcBits = memrefType.getElementType().getIntOrFloatBitWidth();

  bool isBool = srcBits == 1;
  if (isBool)
    srcBits = typeConverter.getOptions().boolNumBits;

  auto pointerType = typeConverter.convertType<spirv::PointerType>(memrefType);
  if (!pointerType)
    return rewriter.notifyMatchFailure(storeOp,
                                       "failed to convert memref type");

  Type pointeeType = pointerType.getPointeeType();
  IntegerType dstType;
  if (typeConverter.allows(spirv::Capability::Kernel)) {
    if (auto arrayType = dyn_cast<spirv::ArrayType>(pointeeType))
      dstType = dyn_cast<IntegerType>(arrayType.getElementType());
    else
      dstType = dyn_cast<IntegerType>(pointeeType);
  } else {
    // For Vulkan we need to extract element from wrapping struct and array.
    Type structElemType =
        cast<spirv::StructType>(pointeeType).getElementType(0);
    if (auto arrayType = dyn_cast<spirv::ArrayType>(structElemType))
      dstType = dyn_cast<IntegerType>(arrayType.getElementType());
    else
      dstType = dyn_cast<IntegerType>(
          cast<spirv::RuntimeArrayType>(structElemType).getElementType());
  }

  if (!dstType)
    return rewriter.notifyMatchFailure(
        storeOp, "failed to determine destination element type");

  int dstBits = static_cast<int>(dstType.getWidth());
  assert(dstBits % srcBits == 0);

  if (srcBits == dstBits) {
    auto memoryRequirements = calculateMemoryRequirements(accessChain, storeOp);
    if (failed(memoryRequirements))
      return rewriter.notifyMatchFailure(
          storeOp, "failed to determine memory requirements");

    auto [memoryAccess, alignment] = *memoryRequirements;
    Value storeVal = adaptor.getValue();
    if (isBool)
      storeVal = castBoolToIntN(loc, storeVal, dstType, rewriter);
    rewriter.replaceOpWithNewOp<spirv::StoreOp>(storeOp, accessChain, storeVal,
                                                memoryAccess, alignment);
    return success();
  }

  // Bitcasting is currently unsupported for Kernel capability /
  // spirv.PtrAccessChain.
  if (typeConverter.allows(spirv::Capability::Kernel))
    return failure();

  auto accessChainOp = accessChain.getDefiningOp<spirv::AccessChainOp>();
  if (!accessChainOp)
    return failure();

  // Since multiple threads may be writing concurrently, the emulation is done
  // with atomic operations. E.g., if the stored value is i8, rewrite the
  // StoreOp to:
  // 1) load a 32-bit integer
  // 2) clear 8 bits in the loaded value
  // 3) set 8 bits in the loaded value
  // 4) store the 32-bit value back
  //
  // Step 2 is done with AtomicAnd, and step 3 is done with AtomicOr (of the
  // loaded 32-bit value and the shifted 8-bit store value) as another atomic
  // step.
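  //
  // In SPIR-V pseudo-IR, the emitted sequence is roughly:
  //   %old = spirv.AtomicAnd %ptr, %clearBitsMask
  //   %new = spirv.AtomicOr  %ptr, %shiftedStoreVal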
  assert(accessChainOp.getIndices().size() == 2);
  Value lastDim = accessChainOp->getOperand(accessChainOp.getNumOperands() - 1);
  Value offset = getOffsetForBitwidth(loc, lastDim, srcBits, dstBits, rewriter);

  // Create a mask to clear the destination. E.g., if it is the second i8 in
  // i32, 0xFFFF00FF is created.
  Value mask = rewriter.createOrFold<spirv::ConstantOp>(
      loc, dstType, rewriter.getIntegerAttr(dstType, (1 << srcBits) - 1));
  Value clearBitsMask = rewriter.createOrFold<spirv::ShiftLeftLogicalOp>(
      loc, dstType, mask, offset);
  clearBitsMask =
      rewriter.createOrFold<spirv::NotOp>(loc, dstType, clearBitsMask);

  Value storeVal = shiftValue(loc, adaptor.getValue(), offset, mask, rewriter);
  Value adjustedPtr = adjustAccessChainForBitwidth(typeConverter, accessChainOp,
                                                   srcBits, dstBits, rewriter);
  std::optional<spirv::Scope> scope = getAtomicOpScope(memrefType);
  if (!scope)
    return rewriter.notifyMatchFailure(storeOp, "atomic scope not available");

  Value result = rewriter.create<spirv::AtomicAndOp>(
      loc, dstType, adjustedPtr, *scope, spirv::MemorySemantics::AcquireRelease,
      clearBitsMask);
  result = rewriter.create<spirv::AtomicOrOp>(
      loc, dstType, adjustedPtr, *scope, spirv::MemorySemantics::AcquireRelease,
      storeVal);

  // The emulation ops are already inserted, so we can just remove the original
  // StoreOp. Note that rewriter.replaceOp() doesn't work here because it
  // requires the numbers of results to match, and the StoreOp has none.
  rewriter.eraseOp(storeOp);

  assert(accessChainOp.use_empty());
  rewriter.eraseOp(accessChainOp);

  return success();
}

//===----------------------------------------------------------------------===//
// MemorySpaceCastOp
//===----------------------------------------------------------------------===//

LogicalResult MemorySpaceCastOpPattern::matchAndRewrite(
    memref::MemorySpaceCastOp addrCastOp, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Location loc = addrCastOp.getLoc();
  auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
  if (!typeConverter.allows(spirv::Capability::Kernel))
    return rewriter.notifyMatchFailure(
        loc, "address space casts require kernel capability");

  auto sourceType = dyn_cast<MemRefType>(addrCastOp.getSource().getType());
  if (!sourceType)
    return rewriter.notifyMatchFailure(
        loc, "SPIR-V lowering requires ranked memref types");
  auto resultType = cast<MemRefType>(addrCastOp.getResult().getType());

  auto sourceStorageClassAttr =
      dyn_cast_or_null<spirv::StorageClassAttr>(sourceType.getMemorySpace());
  if (!sourceStorageClassAttr)
    return rewriter.notifyMatchFailure(loc, [sourceType](Diagnostic &diag) {
      diag << "source address space " << sourceType.getMemorySpace()
           << " must be a SPIR-V storage class";
    });
  auto resultStorageClassAttr =
      dyn_cast_or_null<spirv::StorageClassAttr>(resultType.getMemorySpace());
  if (!resultStorageClassAttr)
    return rewriter.notifyMatchFailure(loc, [resultType](Diagnostic &diag) {
      diag << "result address space " << resultType.getMemorySpace()
           << " must be a SPIR-V storage class";
    });

  spirv::StorageClass sourceSc = sourceStorageClassAttr.getValue();
  spirv::StorageClass resultSc = resultStorageClassAttr.getValue();

  Value result = adaptor.getSource();
  Type resultPtrType = typeConverter.convertType(resultType);
  if (!resultPtrType)
    return rewriter.notifyMatchFailure(addrCastOp,
                                       "failed to convert memref type");

  Type genericPtrType = resultPtrType;
  // SPIR-V doesn't have a general address space cast operation. Instead, it has
  // conversions to and from generic pointers. To implement the general case,
  // we use specific-to-generic conversions when the source class is not
  // generic. Then when the result storage class is not generic, we convert the
  // generic pointer (either the input or an intermediate result) to that
  // class. This also means that we'll need the intermediate generic pointer
  // type if neither the source nor the destination have it.
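  // E.g., casting CrossWorkgroup to Function lowers to spirv.PtrCastToGeneric
  // followed by spirv.GenericCastToPtr through a Generic pointer intermediate.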
  if (sourceSc != spirv::StorageClass::Generic &&
      resultSc != spirv::StorageClass::Generic) {
    Type intermediateType =
        MemRefType::get(sourceType.getShape(), sourceType.getElementType(),
                        sourceType.getLayout(),
                        rewriter.getAttr<spirv::StorageClassAttr>(
                            spirv::StorageClass::Generic));
    genericPtrType = typeConverter.convertType(intermediateType);
  }
  if (sourceSc != spirv::StorageClass::Generic) {
    result =
        rewriter.create<spirv::PtrCastToGenericOp>(loc, genericPtrType, result);
  }
  if (resultSc != spirv::StorageClass::Generic) {
    result =
        rewriter.create<spirv::GenericCastToPtrOp>(loc, resultPtrType, result);
  }
  rewriter.replaceOp(addrCastOp, result);
  return success();
}

LogicalResult
StoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
                                ConversionPatternRewriter &rewriter) const {
  auto memrefType = cast<MemRefType>(storeOp.getMemref().getType());
  if (memrefType.getElementType().isSignlessInteger())
    return rewriter.notifyMatchFailure(storeOp, "signless int");
  auto storePtr = spirv::getElementPtr(
      *getTypeConverter<SPIRVTypeConverter>(), memrefType, adaptor.getMemref(),
      adaptor.getIndices(), storeOp.getLoc(), rewriter);

  if (!storePtr)
    return rewriter.notifyMatchFailure(storeOp, "type conversion failed");

  auto memoryRequirements = calculateMemoryRequirements(storePtr, storeOp);
  if (failed(memoryRequirements))
    return rewriter.notifyMatchFailure(
        storeOp, "failed to determine memory requirements");

  auto [memoryAccess, alignment] = *memoryRequirements;
  rewriter.replaceOpWithNewOp<spirv::StoreOp>(
      storeOp, storePtr, adaptor.getValue(), memoryAccess, alignment);
  return success();
}

LogicalResult ReinterpretCastPattern::matchAndRewrite(
    memref::ReinterpretCastOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Value src = adaptor.getSource();
  auto srcType = dyn_cast<spirv::PointerType>(src.getType());

  if (!srcType)
    return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
      diag << "invalid src type " << src.getType();
    });

  const TypeConverter *converter = getTypeConverter();

  auto dstType = converter->convertType<spirv::PointerType>(op.getType());
  if (dstType != srcType)
    return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
      diag << "invalid dst type " << op.getType();
    });

  OpFoldResult offset =
      getMixedValues(adaptor.getStaticOffsets(), adaptor.getOffsets(), rewriter)
          .front();
  if (isConstantIntValue(offset, 0)) {
    rewriter.replaceOp(op, src);
    return success();
  }

  Type intType = converter->convertType(rewriter.getIndexType());
  if (!intType)
    return rewriter.notifyMatchFailure(op, "failed to convert index type");

  Location loc = op.getLoc();
  auto offsetValue = [&]() -> Value {
    if (auto val = dyn_cast<Value>(offset))
      return val;

    int64_t attrVal = cast<IntegerAttr>(offset.get<Attribute>()).getInt();
    Attribute attr = rewriter.getIntegerAttr(intType, attrVal);
    return rewriter.createOrFold<spirv::ConstantOp>(loc, intType, attr);
  }();

  rewriter.replaceOpWithNewOp<spirv::InBoundsPtrAccessChainOp>(
      op, src, offsetValue, std::nullopt);
  return success();
}

//===----------------------------------------------------------------------===//
// Pattern population
//===----------------------------------------------------------------------===//

namespace mlir {
void populateMemRefToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
                                   RewritePatternSet &patterns) {
  patterns.add<AllocaOpPattern, AllocOpPattern, AtomicRMWOpPattern,
               DeallocOpPattern, IntLoadOpPattern, IntStoreOpPattern,
               LoadOpPattern, MemorySpaceCastOpPattern, StoreOpPattern,
               ReinterpretCastPattern, CastPattern>(typeConverter,
                                                    patterns.getContext());
}
} // namespace mlir
939 | |