1 | //===- Utils.cpp - Transform utilities ------------------------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/Dialect/NVGPU/Transforms/Utils.h" |
10 | |
11 | #include "mlir/Dialect/MemRef/IR/MemRef.h" |
12 | #include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h" |
13 | #include "mlir/Dialect/Vector/IR/VectorOps.h" |
14 | |
15 | using namespace mlir; |
16 | using namespace mlir::nvgpu; |
17 | |
18 | Operation::operand_range nvgpu::getIndices(Operation *op) { |
19 | if (auto ldmatrixOp = dyn_cast<LdMatrixOp>(op)) |
20 | return ldmatrixOp.getIndices(); |
21 | if (auto copyOp = dyn_cast<DeviceAsyncCopyOp>(op)) |
22 | return copyOp.getDstIndices(); |
23 | if (auto loadOp = dyn_cast<memref::LoadOp>(op)) |
24 | return loadOp.getIndices(); |
25 | if (auto storeOp = dyn_cast<memref::StoreOp>(op)) |
26 | return storeOp.getIndices(); |
27 | if (auto vectorReadOp = dyn_cast<vector::LoadOp>(op)) |
28 | return vectorReadOp.getIndices(); |
29 | if (auto vectorStoreOp = dyn_cast<vector::StoreOp>(op)) |
30 | return vectorStoreOp.getIndices(); |
31 | if (auto transferReadOp = dyn_cast<vector::TransferReadOp>(op)) |
32 | return transferReadOp.getIndices(); |
33 | if (auto transferWriteOp = dyn_cast<vector::TransferWriteOp>(op)) |
34 | return transferWriteOp.getIndices(); |
35 | llvm_unreachable("unsupported op type" ); |
36 | } |
37 | |
38 | void nvgpu::setIndices(Operation *op, ArrayRef<Value> indices) { |
39 | if (auto ldmatrixOp = dyn_cast<LdMatrixOp>(op)) |
40 | return ldmatrixOp.getIndicesMutable().assign(indices); |
41 | if (auto copyOp = dyn_cast<DeviceAsyncCopyOp>(op)) |
42 | return copyOp.getDstIndicesMutable().assign(indices); |
43 | if (auto loadOp = dyn_cast<memref::LoadOp>(op)) |
44 | return loadOp.getIndicesMutable().assign(indices); |
45 | if (auto storeOp = dyn_cast<memref::StoreOp>(op)) |
46 | return storeOp.getIndicesMutable().assign(indices); |
47 | if (auto vectorReadOp = dyn_cast<vector::LoadOp>(op)) |
48 | return vectorReadOp.getIndicesMutable().assign(indices); |
49 | if (auto vectorStoreOp = dyn_cast<vector::StoreOp>(op)) |
50 | return vectorStoreOp.getIndicesMutable().assign(indices); |
51 | if (auto transferReadOp = dyn_cast<vector::TransferReadOp>(op)) |
52 | return transferReadOp.getIndicesMutable().assign(indices); |
53 | if (auto transferWriteOp = dyn_cast<vector::TransferWriteOp>(op)) |
54 | return transferWriteOp.getIndicesMutable().assign(indices); |
55 | llvm_unreachable("unsupported op type" ); |
56 | } |
57 | |
58 | Value nvgpu::getValueStored(Operation *op) { |
59 | if (auto storeOp = dyn_cast<memref::StoreOp>(op)) |
60 | return storeOp.getValueToStore(); |
61 | if (auto transferWrite = dyn_cast<vector::TransferWriteOp>(op)) |
62 | return transferWrite.getValue(); |
63 | if (auto storeOp = dyn_cast<vector::StoreOp>(op)) |
64 | return storeOp.getValueToStore(); |
65 | llvm_unreachable("unsupported op type" ); |
66 | } |
67 | |
68 | Value nvgpu::getMemrefOperand(Operation *op) { |
69 | if (auto loadOp = dyn_cast<memref::LoadOp>(op)) |
70 | return loadOp.getMemref(); |
71 | if (auto storeOp = dyn_cast<memref::StoreOp>(op)) |
72 | return storeOp.getMemref(); |
73 | if (auto transferWrite = dyn_cast<vector::TransferWriteOp>(op)) |
74 | return transferWrite.getSource(); |
75 | if (auto transferRead = dyn_cast<vector::TransferReadOp>(op)) |
76 | return transferRead.getSource(); |
77 | if (auto storeOp = dyn_cast<vector::StoreOp>(op)) |
78 | return storeOp.getBase(); |
79 | if (auto loadOp = dyn_cast<vector::LoadOp>(op)) |
80 | return loadOp.getBase(); |
81 | llvm_unreachable("unsupported op type" ); |
82 | } |
83 | |