1 | //===- MemRefToSPIRV.cpp - MemRef to SPIR-V Patterns ----------------------===// |
---|---|
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This file implements patterns to convert MemRef dialect to SPIR-V dialect. |
10 | // |
11 | //===----------------------------------------------------------------------===// |
12 | |
13 | #include "mlir/Dialect/Arith/IR/Arith.h" |
14 | #include "mlir/Dialect/MemRef/IR/MemRef.h" |
15 | #include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h" |
16 | #include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h" |
17 | #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" |
18 | #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h" |
19 | #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h" |
20 | #include "mlir/IR/BuiltinAttributes.h" |
21 | #include "mlir/IR/BuiltinTypes.h" |
22 | #include "mlir/IR/MLIRContext.h" |
23 | #include "mlir/IR/Visitors.h" |
24 | #include "llvm/Support/Debug.h" |
25 | #include <cassert> |
26 | #include <optional> |
27 | |
28 | #define DEBUG_TYPE "memref-to-spirv-pattern" |
29 | |
30 | using namespace mlir; |
31 | |
32 | //===----------------------------------------------------------------------===// |
33 | // Utility functions |
34 | //===----------------------------------------------------------------------===// |
35 | |
36 | /// Returns the offset of the value in `targetBits` representation. |
37 | /// |
38 | /// `srcIdx` is an index into a 1-D array with each element having `sourceBits`. |
39 | /// It's assumed to be non-negative. |
40 | /// |
41 | /// When accessing an element in the array treating as having elements of |
42 | /// `targetBits`, multiple values are loaded in the same time. The method |
43 | /// returns the offset where the `srcIdx` locates in the value. For example, if |
44 | /// `sourceBits` equals to 8 and `targetBits` equals to 32, the x-th element is |
45 | /// located at (x % 4) * 8. Because there are four elements in one i32, and one |
46 | /// element has 8 bits. |
47 | static Value getOffsetForBitwidth(Location loc, Value srcIdx, int sourceBits, |
48 | int targetBits, OpBuilder &builder) { |
49 | assert(targetBits % sourceBits == 0); |
50 | Type type = srcIdx.getType(); |
51 | IntegerAttr idxAttr = builder.getIntegerAttr(type, targetBits / sourceBits); |
52 | auto idx = builder.createOrFold<spirv::ConstantOp>(loc, type, idxAttr); |
53 | IntegerAttr srcBitsAttr = builder.getIntegerAttr(type, sourceBits); |
54 | auto srcBitsValue = |
55 | builder.createOrFold<spirv::ConstantOp>(loc, type, srcBitsAttr); |
56 | auto m = builder.createOrFold<spirv::UModOp>(loc, srcIdx, idx); |
57 | return builder.createOrFold<spirv::IMulOp>(loc, type, m, srcBitsValue); |
58 | } |
59 | |
60 | /// Returns an adjusted spirv::AccessChainOp. Based on the |
61 | /// extension/capabilities, certain integer bitwidths `sourceBits` might not be |
62 | /// supported. During conversion if a memref of an unsupported type is used, |
63 | /// load/stores to this memref need to be modified to use a supported higher |
64 | /// bitwidth `targetBits` and extracting the required bits. For an accessing a |
65 | /// 1D array (spirv.array or spirv.rtarray), the last index is modified to load |
66 | /// the bits needed. The extraction of the actual bits needed are handled |
67 | /// separately. Note that this only works for a 1-D tensor. |
68 | static Value |
69 | adjustAccessChainForBitwidth(const SPIRVTypeConverter &typeConverter, |
70 | spirv::AccessChainOp op, int sourceBits, |
71 | int targetBits, OpBuilder &builder) { |
72 | assert(targetBits % sourceBits == 0); |
73 | const auto loc = op.getLoc(); |
74 | Value lastDim = op->getOperand(op.getNumOperands() - 1); |
75 | Type type = lastDim.getType(); |
76 | IntegerAttr attr = builder.getIntegerAttr(type, targetBits / sourceBits); |
77 | auto idx = builder.createOrFold<spirv::ConstantOp>(loc, type, attr); |
78 | auto indices = llvm::to_vector<4>(op.getIndices()); |
79 | // There are two elements if this is a 1-D tensor. |
80 | assert(indices.size() == 2); |
81 | indices.back() = builder.createOrFold<spirv::SDivOp>(loc, lastDim, idx); |
82 | Type t = typeConverter.convertType(op.getComponentPtr().getType()); |
83 | return builder.create<spirv::AccessChainOp>(loc, t, op.getBasePtr(), indices); |
84 | } |
85 | |
86 | /// Casts the given `srcBool` into an integer of `dstType`. |
87 | static Value castBoolToIntN(Location loc, Value srcBool, Type dstType, |
88 | OpBuilder &builder) { |
89 | assert(srcBool.getType().isInteger(1)); |
90 | if (dstType.isInteger(width: 1)) |
91 | return srcBool; |
92 | Value zero = spirv::ConstantOp::getZero(dstType, loc, builder); |
93 | Value one = spirv::ConstantOp::getOne(dstType, loc, builder); |
94 | return builder.createOrFold<spirv::SelectOp>(loc, dstType, srcBool, one, |
95 | zero); |
96 | } |
97 | |
98 | /// Returns the `targetBits`-bit value shifted by the given `offset`, and cast |
99 | /// to the type destination type, and masked. |
100 | static Value shiftValue(Location loc, Value value, Value offset, Value mask, |
101 | OpBuilder &builder) { |
102 | IntegerType dstType = cast<IntegerType>(mask.getType()); |
103 | int targetBits = static_cast<int>(dstType.getWidth()); |
104 | int valueBits = value.getType().getIntOrFloatBitWidth(); |
105 | assert(valueBits <= targetBits); |
106 | |
107 | if (valueBits == 1) { |
108 | value = castBoolToIntN(loc, value, dstType, builder); |
109 | } else { |
110 | if (valueBits < targetBits) { |
111 | value = builder.create<spirv::UConvertOp>( |
112 | loc, builder.getIntegerType(targetBits), value); |
113 | } |
114 | |
115 | value = builder.createOrFold<spirv::BitwiseAndOp>(loc, value, mask); |
116 | } |
117 | return builder.createOrFold<spirv::ShiftLeftLogicalOp>(loc, value.getType(), |
118 | value, offset); |
119 | } |
120 | |
121 | /// Returns true if the allocations of memref `type` generated from `allocOp` |
122 | /// can be lowered to SPIR-V. |
123 | static bool isAllocationSupported(Operation *allocOp, MemRefType type) { |
124 | if (isa<memref::AllocOp, memref::DeallocOp>(allocOp)) { |
125 | auto sc = dyn_cast_or_null<spirv::StorageClassAttr>(type.getMemorySpace()); |
126 | if (!sc || sc.getValue() != spirv::StorageClass::Workgroup) |
127 | return false; |
128 | } else if (isa<memref::AllocaOp>(allocOp)) { |
129 | auto sc = dyn_cast_or_null<spirv::StorageClassAttr>(type.getMemorySpace()); |
130 | if (!sc || sc.getValue() != spirv::StorageClass::Function) |
131 | return false; |
132 | } else { |
133 | return false; |
134 | } |
135 | |
136 | // Currently only support static shape and int or float or vector of int or |
137 | // float element type. |
138 | if (!type.hasStaticShape()) |
139 | return false; |
140 | |
141 | Type elementType = type.getElementType(); |
142 | if (auto vecType = dyn_cast<VectorType>(elementType)) |
143 | elementType = vecType.getElementType(); |
144 | return elementType.isIntOrFloat(); |
145 | } |
146 | |
147 | /// Returns the scope to use for atomic operations use for emulating store |
148 | /// operations of unsupported integer bitwidths, based on the memref |
149 | /// type. Returns std::nullopt on failure. |
150 | static std::optional<spirv::Scope> getAtomicOpScope(MemRefType type) { |
151 | auto sc = dyn_cast_or_null<spirv::StorageClassAttr>(type.getMemorySpace()); |
152 | switch (sc.getValue()) { |
153 | case spirv::StorageClass::StorageBuffer: |
154 | return spirv::Scope::Device; |
155 | case spirv::StorageClass::Workgroup: |
156 | return spirv::Scope::Workgroup; |
157 | default: |
158 | break; |
159 | } |
160 | return {}; |
161 | } |
162 | |
163 | /// Casts the given `srcInt` into a boolean value. |
164 | static Value castIntNToBool(Location loc, Value srcInt, OpBuilder &builder) { |
165 | if (srcInt.getType().isInteger(width: 1)) |
166 | return srcInt; |
167 | |
168 | auto one = spirv::ConstantOp::getZero(srcInt.getType(), loc, builder); |
169 | return builder.createOrFold<spirv::INotEqualOp>(loc, srcInt, one); |
170 | } |
171 | |
172 | //===----------------------------------------------------------------------===// |
173 | // Operation conversion |
174 | //===----------------------------------------------------------------------===// |
175 | |
176 | // Note that DRR cannot be used for the patterns in this file: we may need to |
177 | // convert type along the way, which requires ConversionPattern. DRR generates |
178 | // normal RewritePattern. |
179 | |
180 | namespace { |
181 | |
182 | /// Converts memref.alloca to SPIR-V Function variables. |
183 | class AllocaOpPattern final : public OpConversionPattern<memref::AllocaOp> { |
184 | public: |
185 | using OpConversionPattern<memref::AllocaOp>::OpConversionPattern; |
186 | |
187 | LogicalResult |
188 | matchAndRewrite(memref::AllocaOp allocaOp, OpAdaptor adaptor, |
189 | ConversionPatternRewriter &rewriter) const override; |
190 | }; |
191 | |
192 | /// Converts an allocation operation to SPIR-V. Currently only supports lowering |
193 | /// to Workgroup memory when the size is constant. Note that this pattern needs |
194 | /// to be applied in a pass that runs at least at spirv.module scope since it |
195 | /// wil ladd global variables into the spirv.module. |
196 | class AllocOpPattern final : public OpConversionPattern<memref::AllocOp> { |
197 | public: |
198 | using OpConversionPattern<memref::AllocOp>::OpConversionPattern; |
199 | |
200 | LogicalResult |
201 | matchAndRewrite(memref::AllocOp operation, OpAdaptor adaptor, |
202 | ConversionPatternRewriter &rewriter) const override; |
203 | }; |
204 | |
205 | /// Converts memref.automic_rmw operations to SPIR-V atomic operations. |
206 | class AtomicRMWOpPattern final |
207 | : public OpConversionPattern<memref::AtomicRMWOp> { |
208 | public: |
209 | using OpConversionPattern<memref::AtomicRMWOp>::OpConversionPattern; |
210 | |
211 | LogicalResult |
212 | matchAndRewrite(memref::AtomicRMWOp atomicOp, OpAdaptor adaptor, |
213 | ConversionPatternRewriter &rewriter) const override; |
214 | }; |
215 | |
216 | /// Removed a deallocation if it is a supported allocation. Currently only |
217 | /// removes deallocation if the memory space is workgroup memory. |
218 | class DeallocOpPattern final : public OpConversionPattern<memref::DeallocOp> { |
219 | public: |
220 | using OpConversionPattern<memref::DeallocOp>::OpConversionPattern; |
221 | |
222 | LogicalResult |
223 | matchAndRewrite(memref::DeallocOp operation, OpAdaptor adaptor, |
224 | ConversionPatternRewriter &rewriter) const override; |
225 | }; |
226 | |
227 | /// Converts memref.load to spirv.Load + spirv.AccessChain on integers. |
228 | class IntLoadOpPattern final : public OpConversionPattern<memref::LoadOp> { |
229 | public: |
230 | using OpConversionPattern<memref::LoadOp>::OpConversionPattern; |
231 | |
232 | LogicalResult |
233 | matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor, |
234 | ConversionPatternRewriter &rewriter) const override; |
235 | }; |
236 | |
237 | /// Converts memref.load to spirv.Load + spirv.AccessChain. |
238 | class LoadOpPattern final : public OpConversionPattern<memref::LoadOp> { |
239 | public: |
240 | using OpConversionPattern<memref::LoadOp>::OpConversionPattern; |
241 | |
242 | LogicalResult |
243 | matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor, |
244 | ConversionPatternRewriter &rewriter) const override; |
245 | }; |
246 | |
247 | /// Converts memref.store to spirv.Store on integers. |
248 | class IntStoreOpPattern final : public OpConversionPattern<memref::StoreOp> { |
249 | public: |
250 | using OpConversionPattern<memref::StoreOp>::OpConversionPattern; |
251 | |
252 | LogicalResult |
253 | matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor, |
254 | ConversionPatternRewriter &rewriter) const override; |
255 | }; |
256 | |
257 | /// Converts memref.memory_space_cast to the appropriate spirv cast operations. |
258 | class MemorySpaceCastOpPattern final |
259 | : public OpConversionPattern<memref::MemorySpaceCastOp> { |
260 | public: |
261 | using OpConversionPattern<memref::MemorySpaceCastOp>::OpConversionPattern; |
262 | |
263 | LogicalResult |
264 | matchAndRewrite(memref::MemorySpaceCastOp addrCastOp, OpAdaptor adaptor, |
265 | ConversionPatternRewriter &rewriter) const override; |
266 | }; |
267 | |
268 | /// Converts memref.store to spirv.Store. |
269 | class StoreOpPattern final : public OpConversionPattern<memref::StoreOp> { |
270 | public: |
271 | using OpConversionPattern<memref::StoreOp>::OpConversionPattern; |
272 | |
273 | LogicalResult |
274 | matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor, |
275 | ConversionPatternRewriter &rewriter) const override; |
276 | }; |
277 | |
278 | class ReinterpretCastPattern final |
279 | : public OpConversionPattern<memref::ReinterpretCastOp> { |
280 | public: |
281 | using OpConversionPattern::OpConversionPattern; |
282 | |
283 | LogicalResult |
284 | matchAndRewrite(memref::ReinterpretCastOp op, OpAdaptor adaptor, |
285 | ConversionPatternRewriter &rewriter) const override; |
286 | }; |
287 | |
288 | class CastPattern final : public OpConversionPattern<memref::CastOp> { |
289 | public: |
290 | using OpConversionPattern::OpConversionPattern; |
291 | |
292 | LogicalResult |
293 | matchAndRewrite(memref::CastOp op, OpAdaptor adaptor, |
294 | ConversionPatternRewriter &rewriter) const override { |
295 | Value src = adaptor.getSource(); |
296 | Type srcType = src.getType(); |
297 | |
298 | const TypeConverter *converter = getTypeConverter(); |
299 | Type dstType = converter->convertType(op.getType()); |
300 | if (srcType != dstType) |
301 | return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) { |
302 | diag << "types doesn't match: "<< srcType << " and "<< dstType; |
303 | }); |
304 | |
305 | rewriter.replaceOp(op, src); |
306 | return success(); |
307 | } |
308 | }; |
309 | |
310 | /// Converts memref.extract_aligned_pointer_as_index to spirv.ConvertPtrToU. |
311 | class ExtractAlignedPointerAsIndexOpPattern final |
312 | : public OpConversionPattern<memref::ExtractAlignedPointerAsIndexOp> { |
313 | public: |
314 | using OpConversionPattern::OpConversionPattern; |
315 | |
316 | LogicalResult |
317 | matchAndRewrite(memref::ExtractAlignedPointerAsIndexOp extractOp, |
318 | OpAdaptor adaptor, |
319 | ConversionPatternRewriter &rewriter) const override; |
320 | }; |
321 | } // namespace |
322 | |
323 | //===----------------------------------------------------------------------===// |
324 | // AllocaOp |
325 | //===----------------------------------------------------------------------===// |
326 | |
327 | LogicalResult |
328 | AllocaOpPattern::matchAndRewrite(memref::AllocaOp allocaOp, OpAdaptor adaptor, |
329 | ConversionPatternRewriter &rewriter) const { |
330 | MemRefType allocType = allocaOp.getType(); |
331 | if (!isAllocationSupported(allocaOp, allocType)) |
332 | return rewriter.notifyMatchFailure(allocaOp, "unhandled allocation type"); |
333 | |
334 | // Get the SPIR-V type for the allocation. |
335 | Type spirvType = getTypeConverter()->convertType(allocType); |
336 | if (!spirvType) |
337 | return rewriter.notifyMatchFailure(allocaOp, "type conversion failed"); |
338 | |
339 | rewriter.replaceOpWithNewOp<spirv::VariableOp>(allocaOp, spirvType, |
340 | spirv::StorageClass::Function, |
341 | /*initializer=*/nullptr); |
342 | return success(); |
343 | } |
344 | |
345 | //===----------------------------------------------------------------------===// |
346 | // AllocOp |
347 | //===----------------------------------------------------------------------===// |
348 | |
349 | LogicalResult |
350 | AllocOpPattern::matchAndRewrite(memref::AllocOp operation, OpAdaptor adaptor, |
351 | ConversionPatternRewriter &rewriter) const { |
352 | MemRefType allocType = operation.getType(); |
353 | if (!isAllocationSupported(operation, allocType)) |
354 | return rewriter.notifyMatchFailure(operation, "unhandled allocation type"); |
355 | |
356 | // Get the SPIR-V type for the allocation. |
357 | Type spirvType = getTypeConverter()->convertType(allocType); |
358 | if (!spirvType) |
359 | return rewriter.notifyMatchFailure(operation, "type conversion failed"); |
360 | |
361 | // Insert spirv.GlobalVariable for this allocation. |
362 | Operation *parent = |
363 | SymbolTable::getNearestSymbolTable(from: operation->getParentOp()); |
364 | if (!parent) |
365 | return failure(); |
366 | Location loc = operation.getLoc(); |
367 | spirv::GlobalVariableOp varOp; |
368 | { |
369 | OpBuilder::InsertionGuard guard(rewriter); |
370 | Block &entryBlock = *parent->getRegion(index: 0).begin(); |
371 | rewriter.setInsertionPointToStart(&entryBlock); |
372 | auto varOps = entryBlock.getOps<spirv::GlobalVariableOp>(); |
373 | std::string varName = |
374 | std::string("__workgroup_mem__") + |
375 | std::to_string(std::distance(varOps.begin(), varOps.end())); |
376 | varOp = rewriter.create<spirv::GlobalVariableOp>(loc, spirvType, varName, |
377 | /*initializer=*/nullptr); |
378 | } |
379 | |
380 | // Get pointer to global variable at the current scope. |
381 | rewriter.replaceOpWithNewOp<spirv::AddressOfOp>(operation, varOp); |
382 | return success(); |
383 | } |
384 | |
385 | //===----------------------------------------------------------------------===// |
386 | // AtomicRMWOp |
387 | //===----------------------------------------------------------------------===// |
388 | |
389 | LogicalResult |
390 | AtomicRMWOpPattern::matchAndRewrite(memref::AtomicRMWOp atomicOp, |
391 | OpAdaptor adaptor, |
392 | ConversionPatternRewriter &rewriter) const { |
393 | if (isa<FloatType>(atomicOp.getType())) |
394 | return rewriter.notifyMatchFailure(atomicOp, |
395 | "unimplemented floating-point case"); |
396 | |
397 | auto memrefType = cast<MemRefType>(atomicOp.getMemref().getType()); |
398 | std::optional<spirv::Scope> scope = getAtomicOpScope(memrefType); |
399 | if (!scope) |
400 | return rewriter.notifyMatchFailure(atomicOp, |
401 | "unsupported memref memory space"); |
402 | |
403 | auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>(); |
404 | Type resultType = typeConverter.convertType(atomicOp.getType()); |
405 | if (!resultType) |
406 | return rewriter.notifyMatchFailure(atomicOp, |
407 | "failed to convert result type"); |
408 | |
409 | auto loc = atomicOp.getLoc(); |
410 | Value ptr = |
411 | spirv::getElementPtr(typeConverter: typeConverter, baseType: memrefType, basePtr: adaptor.getMemref(), |
412 | indices: adaptor.getIndices(), loc: loc, builder&: rewriter); |
413 | |
414 | if (!ptr) |
415 | return failure(); |
416 | |
417 | #define ATOMIC_CASE(kind, spirvOp) \ |
418 | case arith::AtomicRMWKind::kind: \ |
419 | rewriter.replaceOpWithNewOp<spirv::spirvOp>( \ |
420 | atomicOp, resultType, ptr, *scope, \ |
421 | spirv::MemorySemantics::AcquireRelease, adaptor.getValue()); \ |
422 | break |
423 | |
424 | switch (atomicOp.getKind()) { |
425 | ATOMIC_CASE(addi, AtomicIAddOp); |
426 | ATOMIC_CASE(maxs, AtomicSMaxOp); |
427 | ATOMIC_CASE(maxu, AtomicUMaxOp); |
428 | ATOMIC_CASE(mins, AtomicSMinOp); |
429 | ATOMIC_CASE(minu, AtomicUMinOp); |
430 | ATOMIC_CASE(ori, AtomicOrOp); |
431 | ATOMIC_CASE(andi, AtomicAndOp); |
432 | default: |
433 | return rewriter.notifyMatchFailure(atomicOp, "unimplemented atomic kind"); |
434 | } |
435 | |
436 | #undef ATOMIC_CASE |
437 | |
438 | return success(); |
439 | } |
440 | |
441 | //===----------------------------------------------------------------------===// |
442 | // DeallocOp |
443 | //===----------------------------------------------------------------------===// |
444 | |
445 | LogicalResult |
446 | DeallocOpPattern::matchAndRewrite(memref::DeallocOp operation, |
447 | OpAdaptor adaptor, |
448 | ConversionPatternRewriter &rewriter) const { |
449 | MemRefType deallocType = cast<MemRefType>(operation.getMemref().getType()); |
450 | if (!isAllocationSupported(operation, deallocType)) |
451 | return rewriter.notifyMatchFailure(operation, "unhandled allocation type"); |
452 | rewriter.eraseOp(op: operation); |
453 | return success(); |
454 | } |
455 | |
456 | //===----------------------------------------------------------------------===// |
457 | // LoadOp |
458 | //===----------------------------------------------------------------------===// |
459 | |
460 | struct MemoryRequirements { |
461 | spirv::MemoryAccessAttr memoryAccess; |
462 | IntegerAttr alignment; |
463 | }; |
464 | |
465 | /// Given an accessed SPIR-V pointer, calculates its alignment requirements, if |
466 | /// any. |
467 | static FailureOr<MemoryRequirements> |
468 | calculateMemoryRequirements(Value accessedPtr, bool isNontemporal) { |
469 | MLIRContext *ctx = accessedPtr.getContext(); |
470 | |
471 | auto memoryAccess = spirv::MemoryAccess::None; |
472 | if (isNontemporal) { |
473 | memoryAccess = spirv::MemoryAccess::Nontemporal; |
474 | } |
475 | |
476 | auto ptrType = cast<spirv::PointerType>(Val: accessedPtr.getType()); |
477 | if (ptrType.getStorageClass() != spirv::StorageClass::PhysicalStorageBuffer) { |
478 | if (memoryAccess == spirv::MemoryAccess::None) { |
479 | return MemoryRequirements{spirv::MemoryAccessAttr{}, IntegerAttr{}}; |
480 | } |
481 | return MemoryRequirements{spirv::MemoryAccessAttr::get(ctx, memoryAccess), |
482 | IntegerAttr{}}; |
483 | } |
484 | |
485 | // PhysicalStorageBuffers require the `Aligned` attribute. |
486 | auto pointeeType = dyn_cast<spirv::ScalarType>(Val: ptrType.getPointeeType()); |
487 | if (!pointeeType) |
488 | return failure(); |
489 | |
490 | // For scalar types, the alignment is determined by their size. |
491 | std::optional<int64_t> sizeInBytes = pointeeType.getSizeInBytes(); |
492 | if (!sizeInBytes.has_value()) |
493 | return failure(); |
494 | |
495 | memoryAccess = memoryAccess | spirv::MemoryAccess::Aligned; |
496 | auto memAccessAttr = spirv::MemoryAccessAttr::get(ctx, memoryAccess); |
497 | auto alignment = IntegerAttr::get(IntegerType::get(ctx, 32), *sizeInBytes); |
498 | return MemoryRequirements{memAccessAttr, alignment}; |
499 | } |
500 | |
501 | /// Given an accessed SPIR-V pointer and the original memref load/store |
502 | /// `memAccess` op, calculates the alignment requirements, if any. Takes into |
503 | /// account the alignment attributes applied to the load/store op. |
504 | template <class LoadOrStoreOp> |
505 | static FailureOr<MemoryRequirements> |
506 | calculateMemoryRequirements(Value accessedPtr, LoadOrStoreOp loadOrStoreOp) { |
507 | static_assert( |
508 | llvm::is_one_of<LoadOrStoreOp, memref::LoadOp, memref::StoreOp>::value, |
509 | "Must be called on either memref::LoadOp or memref::StoreOp"); |
510 | |
511 | Operation *memrefAccessOp = loadOrStoreOp.getOperation(); |
512 | auto memrefMemAccess = memrefAccessOp->getAttrOfType<spirv::MemoryAccessAttr>( |
513 | spirv::attributeName<spirv::MemoryAccess>()); |
514 | auto memrefAlignment = |
515 | memrefAccessOp->getAttrOfType<IntegerAttr>("alignment"); |
516 | if (memrefMemAccess && memrefAlignment) |
517 | return MemoryRequirements{memrefMemAccess, memrefAlignment}; |
518 | |
519 | return calculateMemoryRequirements(accessedPtr, |
520 | loadOrStoreOp.getNontemporal()); |
521 | } |
522 | |
523 | LogicalResult |
524 | IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor, |
525 | ConversionPatternRewriter &rewriter) const { |
526 | auto loc = loadOp.getLoc(); |
527 | auto memrefType = cast<MemRefType>(loadOp.getMemref().getType()); |
528 | if (!memrefType.getElementType().isSignlessInteger()) |
529 | return failure(); |
530 | |
531 | const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>(); |
532 | Value accessChain = |
533 | spirv::getElementPtr(typeConverter: typeConverter, baseType: memrefType, basePtr: adaptor.getMemref(), |
534 | indices: adaptor.getIndices(), loc: loc, builder&: rewriter); |
535 | |
536 | if (!accessChain) |
537 | return failure(); |
538 | |
539 | int srcBits = memrefType.getElementType().getIntOrFloatBitWidth(); |
540 | bool isBool = srcBits == 1; |
541 | if (isBool) |
542 | srcBits = typeConverter.getOptions().boolNumBits; |
543 | |
544 | auto pointerType = typeConverter.convertType<spirv::PointerType>(memrefType); |
545 | if (!pointerType) |
546 | return rewriter.notifyMatchFailure(loadOp, "failed to convert memref type"); |
547 | |
548 | Type pointeeType = pointerType.getPointeeType(); |
549 | Type dstType; |
550 | if (typeConverter.allows(spirv::Capability::Kernel)) { |
551 | if (auto arrayType = dyn_cast<spirv::ArrayType>(pointeeType)) |
552 | dstType = arrayType.getElementType(); |
553 | else |
554 | dstType = pointeeType; |
555 | } else { |
556 | // For Vulkan we need to extract element from wrapping struct and array. |
557 | Type structElemType = |
558 | cast<spirv::StructType>(Val&: pointeeType).getElementType(0); |
559 | if (auto arrayType = dyn_cast<spirv::ArrayType>(structElemType)) |
560 | dstType = arrayType.getElementType(); |
561 | else |
562 | dstType = cast<spirv::RuntimeArrayType>(Val&: structElemType).getElementType(); |
563 | } |
564 | int dstBits = dstType.getIntOrFloatBitWidth(); |
565 | assert(dstBits % srcBits == 0); |
566 | |
567 | // If the rewritten load op has the same bit width, use the loading value |
568 | // directly. |
569 | if (srcBits == dstBits) { |
570 | auto memoryRequirements = calculateMemoryRequirements(accessChain, loadOp); |
571 | if (failed(memoryRequirements)) |
572 | return rewriter.notifyMatchFailure( |
573 | loadOp, "failed to determine memory requirements"); |
574 | |
575 | auto [memoryAccess, alignment] = *memoryRequirements; |
576 | Value loadVal = rewriter.create<spirv::LoadOp>(loc, accessChain, |
577 | memoryAccess, alignment); |
578 | if (isBool) |
579 | loadVal = castIntNToBool(loc, loadVal, rewriter); |
580 | rewriter.replaceOp(loadOp, loadVal); |
581 | return success(); |
582 | } |
583 | |
584 | // Bitcasting is currently unsupported for Kernel capability / |
585 | // spirv.PtrAccessChain. |
586 | if (typeConverter.allows(spirv::Capability::Kernel)) |
587 | return failure(); |
588 | |
589 | auto accessChainOp = accessChain.getDefiningOp<spirv::AccessChainOp>(); |
590 | if (!accessChainOp) |
591 | return failure(); |
592 | |
593 | // Assume that getElementPtr() works linearizely. If it's a scalar, the method |
594 | // still returns a linearized accessing. If the accessing is not linearized, |
595 | // there will be offset issues. |
596 | assert(accessChainOp.getIndices().size() == 2); |
597 | Value adjustedPtr = adjustAccessChainForBitwidth(typeConverter, accessChainOp, |
598 | srcBits, dstBits, rewriter); |
599 | auto memoryRequirements = calculateMemoryRequirements(adjustedPtr, loadOp); |
600 | if (failed(memoryRequirements)) |
601 | return rewriter.notifyMatchFailure( |
602 | loadOp, "failed to determine memory requirements"); |
603 | |
604 | auto [memoryAccess, alignment] = *memoryRequirements; |
605 | Value spvLoadOp = rewriter.create<spirv::LoadOp>(loc, dstType, adjustedPtr, |
606 | memoryAccess, alignment); |
607 | |
608 | // Shift the bits to the rightmost. |
609 | // ____XXXX________ -> ____________XXXX |
610 | Value lastDim = accessChainOp->getOperand(accessChainOp.getNumOperands() - 1); |
611 | Value offset = getOffsetForBitwidth(loc, lastDim, srcBits, dstBits, rewriter); |
612 | Value result = rewriter.createOrFold<spirv::ShiftRightArithmeticOp>( |
613 | loc, spvLoadOp.getType(), spvLoadOp, offset); |
614 | |
615 | // Apply the mask to extract corresponding bits. |
616 | Value mask = rewriter.createOrFold<spirv::ConstantOp>( |
617 | loc, dstType, rewriter.getIntegerAttr(dstType, (1 << srcBits) - 1)); |
618 | result = |
619 | rewriter.createOrFold<spirv::BitwiseAndOp>(loc, dstType, result, mask); |
620 | |
621 | // Apply sign extension on the loading value unconditionally. The signedness |
622 | // semantic is carried in the operator itself, we relies other pattern to |
623 | // handle the casting. |
624 | IntegerAttr shiftValueAttr = |
625 | rewriter.getIntegerAttr(dstType, dstBits - srcBits); |
626 | Value shiftValue = |
627 | rewriter.createOrFold<spirv::ConstantOp>(loc, dstType, shiftValueAttr); |
628 | result = rewriter.createOrFold<spirv::ShiftLeftLogicalOp>(loc, dstType, |
629 | result, shiftValue); |
630 | result = rewriter.createOrFold<spirv::ShiftRightArithmeticOp>( |
631 | loc, dstType, result, shiftValue); |
632 | |
633 | rewriter.replaceOp(loadOp, result); |
634 | |
635 | assert(accessChainOp.use_empty()); |
636 | rewriter.eraseOp(op: accessChainOp); |
637 | |
638 | return success(); |
639 | } |
640 | |
641 | LogicalResult |
642 | LoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor, |
643 | ConversionPatternRewriter &rewriter) const { |
644 | auto memrefType = cast<MemRefType>(loadOp.getMemref().getType()); |
645 | if (memrefType.getElementType().isSignlessInteger()) |
646 | return failure(); |
647 | Value loadPtr = spirv::getElementPtr( |
648 | typeConverter: *getTypeConverter<SPIRVTypeConverter>(), baseType: memrefType, basePtr: adaptor.getMemref(), |
649 | indices: adaptor.getIndices(), loc: loadOp.getLoc(), builder&: rewriter); |
650 | |
651 | if (!loadPtr) |
652 | return failure(); |
653 | |
654 | auto memoryRequirements = calculateMemoryRequirements(loadPtr, loadOp); |
655 | if (failed(memoryRequirements)) |
656 | return rewriter.notifyMatchFailure( |
657 | loadOp, "failed to determine memory requirements"); |
658 | |
659 | auto [memoryAccess, alignment] = *memoryRequirements; |
660 | rewriter.replaceOpWithNewOp<spirv::LoadOp>(loadOp, loadPtr, memoryAccess, |
661 | alignment); |
662 | return success(); |
663 | } |
664 | |
665 | LogicalResult |
666 | IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor, |
667 | ConversionPatternRewriter &rewriter) const { |
668 | auto memrefType = cast<MemRefType>(storeOp.getMemref().getType()); |
669 | if (!memrefType.getElementType().isSignlessInteger()) |
670 | return rewriter.notifyMatchFailure(storeOp, |
671 | "element type is not a signless int"); |
672 | |
673 | auto loc = storeOp.getLoc(); |
674 | auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>(); |
675 | Value accessChain = |
676 | spirv::getElementPtr(typeConverter: typeConverter, baseType: memrefType, basePtr: adaptor.getMemref(), |
677 | indices: adaptor.getIndices(), loc: loc, builder&: rewriter); |
678 | |
679 | if (!accessChain) |
680 | return rewriter.notifyMatchFailure( |
681 | storeOp, "failed to convert element pointer type"); |
682 | |
683 | int srcBits = memrefType.getElementType().getIntOrFloatBitWidth(); |
684 | |
685 | bool isBool = srcBits == 1; |
686 | if (isBool) |
687 | srcBits = typeConverter.getOptions().boolNumBits; |
688 | |
689 | auto pointerType = typeConverter.convertType<spirv::PointerType>(memrefType); |
690 | if (!pointerType) |
691 | return rewriter.notifyMatchFailure(storeOp, |
692 | "failed to convert memref type"); |
693 | |
694 | Type pointeeType = pointerType.getPointeeType(); |
695 | IntegerType dstType; |
696 | if (typeConverter.allows(spirv::Capability::Kernel)) { |
697 | if (auto arrayType = dyn_cast<spirv::ArrayType>(pointeeType)) |
698 | dstType = dyn_cast<IntegerType>(arrayType.getElementType()); |
699 | else |
700 | dstType = dyn_cast<IntegerType>(pointeeType); |
701 | } else { |
702 | // For Vulkan we need to extract element from wrapping struct and array. |
703 | Type structElemType = |
704 | cast<spirv::StructType>(Val&: pointeeType).getElementType(0); |
705 | if (auto arrayType = dyn_cast<spirv::ArrayType>(structElemType)) |
706 | dstType = dyn_cast<IntegerType>(arrayType.getElementType()); |
707 | else |
708 | dstType = dyn_cast<IntegerType>( |
709 | cast<spirv::RuntimeArrayType>(Val&: structElemType).getElementType()); |
710 | } |
711 | |
712 | if (!dstType) |
713 | return rewriter.notifyMatchFailure( |
714 | storeOp, "failed to determine destination element type"); |
715 | |
716 | int dstBits = static_cast<int>(dstType.getWidth()); |
717 | assert(dstBits % srcBits == 0); |
718 | |
719 | if (srcBits == dstBits) { |
720 | auto memoryRequirements = calculateMemoryRequirements(accessChain, storeOp); |
721 | if (failed(memoryRequirements)) |
722 | return rewriter.notifyMatchFailure( |
723 | storeOp, "failed to determine memory requirements"); |
724 | |
725 | auto [memoryAccess, alignment] = *memoryRequirements; |
726 | Value storeVal = adaptor.getValue(); |
727 | if (isBool) |
728 | storeVal = castBoolToIntN(loc, storeVal, dstType, rewriter); |
729 | rewriter.replaceOpWithNewOp<spirv::StoreOp>(storeOp, accessChain, storeVal, |
730 | memoryAccess, alignment); |
731 | return success(); |
732 | } |
733 | |
734 | // Bitcasting is currently unsupported for Kernel capability / |
735 | // spirv.PtrAccessChain. |
736 | if (typeConverter.allows(spirv::Capability::Kernel)) |
737 | return failure(); |
738 | |
739 | auto accessChainOp = accessChain.getDefiningOp<spirv::AccessChainOp>(); |
740 | if (!accessChainOp) |
741 | return failure(); |
742 | |
743 | // Since there are multiple threads in the processing, the emulation will be |
744 | // done with atomic operations. E.g., if the stored value is i8, rewrite the |
745 | // StoreOp to: |
746 | // 1) load a 32-bit integer |
747 | // 2) clear 8 bits in the loaded value |
748 | // 3) set 8 bits in the loaded value |
749 | // 4) store 32-bit value back |
750 | // |
751 | // Step 2 is done with AtomicAnd, and step 3 is done with AtomicOr (of the |
752 | // loaded 32-bit value and the shifted 8-bit store value) as another atomic |
753 | // step. |
754 | assert(accessChainOp.getIndices().size() == 2); |
755 | Value lastDim = accessChainOp->getOperand(accessChainOp.getNumOperands() - 1); |
756 | Value offset = getOffsetForBitwidth(loc, lastDim, srcBits, dstBits, rewriter); |
757 | |
758 | // Create a mask to clear the destination. E.g., if it is the second i8 in |
759 | // i32, 0xFFFF00FF is created. |
760 | Value mask = rewriter.createOrFold<spirv::ConstantOp>( |
761 | loc, dstType, rewriter.getIntegerAttr(dstType, (1 << srcBits) - 1)); |
762 | Value clearBitsMask = rewriter.createOrFold<spirv::ShiftLeftLogicalOp>( |
763 | loc, dstType, mask, offset); |
764 | clearBitsMask = |
765 | rewriter.createOrFold<spirv::NotOp>(loc, dstType, clearBitsMask); |
766 | |
767 | Value storeVal = shiftValue(loc, adaptor.getValue(), offset, mask, rewriter); |
768 | Value adjustedPtr = adjustAccessChainForBitwidth(typeConverter, accessChainOp, |
769 | srcBits, dstBits, rewriter); |
770 | std::optional<spirv::Scope> scope = getAtomicOpScope(memrefType); |
771 | if (!scope) |
772 | return rewriter.notifyMatchFailure(storeOp, "atomic scope not available"); |
773 | |
774 | Value result = rewriter.create<spirv::AtomicAndOp>( |
775 | loc, dstType, adjustedPtr, *scope, spirv::MemorySemantics::AcquireRelease, |
776 | clearBitsMask); |
777 | result = rewriter.create<spirv::AtomicOrOp>( |
778 | loc, dstType, adjustedPtr, *scope, spirv::MemorySemantics::AcquireRelease, |
779 | storeVal); |
780 | |
781 | // The AtomicOrOp has no side effect. Since it is already inserted, we can |
782 | // just remove the original StoreOp. Note that rewriter.replaceOp() |
783 | // doesn't work because it only accepts that the numbers of result are the |
784 | // same. |
785 | rewriter.eraseOp(op: storeOp); |
786 | |
787 | assert(accessChainOp.use_empty()); |
788 | rewriter.eraseOp(op: accessChainOp); |
789 | |
790 | return success(); |
791 | } |
792 | |
793 | //===----------------------------------------------------------------------===// |
794 | // MemorySpaceCastOp |
795 | //===----------------------------------------------------------------------===// |
796 | |
797 | LogicalResult MemorySpaceCastOpPattern::matchAndRewrite( |
798 | memref::MemorySpaceCastOp addrCastOp, OpAdaptor adaptor, |
799 | ConversionPatternRewriter &rewriter) const { |
800 | Location loc = addrCastOp.getLoc(); |
801 | auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>(); |
802 | if (!typeConverter.allows(spirv::Capability::Kernel)) |
803 | return rewriter.notifyMatchFailure( |
804 | arg&: loc, msg: "address space casts require kernel capability"); |
805 | |
806 | auto sourceType = dyn_cast<MemRefType>(addrCastOp.getSource().getType()); |
807 | if (!sourceType) |
808 | return rewriter.notifyMatchFailure( |
809 | arg&: loc, msg: "SPIR-V lowering requires ranked memref types"); |
810 | auto resultType = cast<MemRefType>(addrCastOp.getResult().getType()); |
811 | |
812 | auto sourceStorageClassAttr = |
813 | dyn_cast_or_null<spirv::StorageClassAttr>(sourceType.getMemorySpace()); |
814 | if (!sourceStorageClassAttr) |
815 | return rewriter.notifyMatchFailure(loc, reasonCallback: [sourceType](Diagnostic &diag) { |
816 | diag << "source address space "<< sourceType.getMemorySpace() |
817 | << " must be a SPIR-V storage class"; |
818 | }); |
819 | auto resultStorageClassAttr = |
820 | dyn_cast_or_null<spirv::StorageClassAttr>(resultType.getMemorySpace()); |
821 | if (!resultStorageClassAttr) |
822 | return rewriter.notifyMatchFailure(loc, reasonCallback: [resultType](Diagnostic &diag) { |
823 | diag << "result address space "<< resultType.getMemorySpace() |
824 | << " must be a SPIR-V storage class"; |
825 | }); |
826 | |
827 | spirv::StorageClass sourceSc = sourceStorageClassAttr.getValue(); |
828 | spirv::StorageClass resultSc = resultStorageClassAttr.getValue(); |
829 | |
830 | Value result = adaptor.getSource(); |
831 | Type resultPtrType = typeConverter.convertType(resultType); |
832 | if (!resultPtrType) |
833 | return rewriter.notifyMatchFailure(addrCastOp, |
834 | "failed to convert memref type"); |
835 | |
836 | Type genericPtrType = resultPtrType; |
837 | // SPIR-V doesn't have a general address space cast operation. Instead, it has |
838 | // conversions to and from generic pointers. To implement the general case, |
839 | // we use specific-to-generic conversions when the source class is not |
840 | // generic. Then when the result storage class is not generic, we convert the |
841 | // generic pointer (either the input on ar intermediate result) to that |
842 | // class. This also means that we'll need the intermediate generic pointer |
843 | // type if neither the source or destination have it. |
844 | if (sourceSc != spirv::StorageClass::Generic && |
845 | resultSc != spirv::StorageClass::Generic) { |
846 | Type intermediateType = |
847 | MemRefType::get(sourceType.getShape(), sourceType.getElementType(), |
848 | sourceType.getLayout(), |
849 | rewriter.getAttr<spirv::StorageClassAttr>( |
850 | spirv::StorageClass::Generic)); |
851 | genericPtrType = typeConverter.convertType(intermediateType); |
852 | } |
853 | if (sourceSc != spirv::StorageClass::Generic) { |
854 | result = |
855 | rewriter.create<spirv::PtrCastToGenericOp>(loc, genericPtrType, result); |
856 | } |
857 | if (resultSc != spirv::StorageClass::Generic) { |
858 | result = |
859 | rewriter.create<spirv::GenericCastToPtrOp>(loc, resultPtrType, result); |
860 | } |
861 | rewriter.replaceOp(addrCastOp, result); |
862 | return success(); |
863 | } |
864 | |
865 | LogicalResult |
866 | StoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor, |
867 | ConversionPatternRewriter &rewriter) const { |
868 | auto memrefType = cast<MemRefType>(storeOp.getMemref().getType()); |
869 | if (memrefType.getElementType().isSignlessInteger()) |
870 | return rewriter.notifyMatchFailure(storeOp, "signless int"); |
871 | auto storePtr = spirv::getElementPtr( |
872 | typeConverter: *getTypeConverter<SPIRVTypeConverter>(), baseType: memrefType, basePtr: adaptor.getMemref(), |
873 | indices: adaptor.getIndices(), loc: storeOp.getLoc(), builder&: rewriter); |
874 | |
875 | if (!storePtr) |
876 | return rewriter.notifyMatchFailure(storeOp, "type conversion failed"); |
877 | |
878 | auto memoryRequirements = calculateMemoryRequirements(storePtr, storeOp); |
879 | if (failed(memoryRequirements)) |
880 | return rewriter.notifyMatchFailure( |
881 | storeOp, "failed to determine memory requirements"); |
882 | |
883 | auto [memoryAccess, alignment] = *memoryRequirements; |
884 | rewriter.replaceOpWithNewOp<spirv::StoreOp>( |
885 | storeOp, storePtr, adaptor.getValue(), memoryAccess, alignment); |
886 | return success(); |
887 | } |
888 | |
889 | LogicalResult ReinterpretCastPattern::matchAndRewrite( |
890 | memref::ReinterpretCastOp op, OpAdaptor adaptor, |
891 | ConversionPatternRewriter &rewriter) const { |
892 | Value src = adaptor.getSource(); |
893 | auto srcType = dyn_cast<spirv::PointerType>(Val: src.getType()); |
894 | |
895 | if (!srcType) |
896 | return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) { |
897 | diag << "invalid src type "<< src.getType(); |
898 | }); |
899 | |
900 | const TypeConverter *converter = getTypeConverter(); |
901 | |
902 | auto dstType = converter->convertType<spirv::PointerType>(op.getType()); |
903 | if (dstType != srcType) |
904 | return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) { |
905 | diag << "invalid dst type "<< op.getType(); |
906 | }); |
907 | |
908 | OpFoldResult offset = |
909 | getMixedValues(adaptor.getStaticOffsets(), adaptor.getOffsets(), rewriter) |
910 | .front(); |
911 | if (isZeroInteger(v: offset)) { |
912 | rewriter.replaceOp(op, src); |
913 | return success(); |
914 | } |
915 | |
916 | Type intType = converter->convertType(rewriter.getIndexType()); |
917 | if (!intType) |
918 | return rewriter.notifyMatchFailure(op, "failed to convert index type"); |
919 | |
920 | Location loc = op.getLoc(); |
921 | auto offsetValue = [&]() -> Value { |
922 | if (auto val = dyn_cast<Value>(offset)) |
923 | return val; |
924 | |
925 | int64_t attrVal = cast<IntegerAttr>(cast<Attribute>(Val&: offset)).getInt(); |
926 | Attribute attr = rewriter.getIntegerAttr(intType, attrVal); |
927 | return rewriter.createOrFold<spirv::ConstantOp>(loc, intType, attr); |
928 | }(); |
929 | |
930 | rewriter.replaceOpWithNewOp<spirv::InBoundsPtrAccessChainOp>( |
931 | op, src, offsetValue, std::nullopt); |
932 | return success(); |
933 | } |
934 | |
935 | //===----------------------------------------------------------------------===// |
936 | // ExtractAlignedPointerAsIndexOp |
937 | //===----------------------------------------------------------------------===// |
938 | |
939 | LogicalResult ExtractAlignedPointerAsIndexOpPattern::matchAndRewrite( |
940 | memref::ExtractAlignedPointerAsIndexOp extractOp, OpAdaptor adaptor, |
941 | ConversionPatternRewriter &rewriter) const { |
942 | auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>(); |
943 | Type indexType = typeConverter.getIndexType(); |
944 | rewriter.replaceOpWithNewOp<spirv::ConvertPtrToUOp>(extractOp, indexType, |
945 | adaptor.getSource()); |
946 | return success(); |
947 | } |
948 | |
949 | //===----------------------------------------------------------------------===// |
950 | // Pattern population |
951 | //===----------------------------------------------------------------------===// |
952 | |
953 | namespace mlir { |
954 | void populateMemRefToSPIRVPatterns(const SPIRVTypeConverter &typeConverter, |
955 | RewritePatternSet &patterns) { |
956 | patterns |
957 | .add<AllocaOpPattern, AllocOpPattern, AtomicRMWOpPattern, |
958 | DeallocOpPattern, IntLoadOpPattern, IntStoreOpPattern, LoadOpPattern, |
959 | MemorySpaceCastOpPattern, StoreOpPattern, ReinterpretCastPattern, |
960 | CastPattern, ExtractAlignedPointerAsIndexOpPattern>( |
961 | arg: typeConverter, args: patterns.getContext()); |
962 | } |
963 | } // namespace mlir |
964 |
Definitions
- getOffsetForBitwidth
- adjustAccessChainForBitwidth
- castBoolToIntN
- shiftValue
- isAllocationSupported
- getAtomicOpScope
- castIntNToBool
- AllocaOpPattern
- AllocOpPattern
- AtomicRMWOpPattern
- DeallocOpPattern
- IntLoadOpPattern
- LoadOpPattern
- IntStoreOpPattern
- MemorySpaceCastOpPattern
- StoreOpPattern
- ReinterpretCastPattern
- CastPattern
- matchAndRewrite
- ExtractAlignedPointerAsIndexOpPattern
- matchAndRewrite
- matchAndRewrite
- matchAndRewrite
- matchAndRewrite
- MemoryRequirements
- calculateMemoryRequirements
- calculateMemoryRequirements
- matchAndRewrite
- matchAndRewrite
- matchAndRewrite
- matchAndRewrite
- matchAndRewrite
- matchAndRewrite
- matchAndRewrite
Improve your Profiling and Debugging skills
Find out more