1 | //===- SubsetOpInterfaceImpl.cpp - Tensor subsets -------------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/Dialect/Vector/Transforms/SubsetOpInterfaceImpl.h" |
10 | |
11 | #include "mlir/Dialect/Vector/IR/VectorOps.h" |
12 | #include "mlir/Interfaces/SubsetOpInterface.h" |
13 | |
14 | using namespace mlir; |
15 | using namespace mlir::vector; |
16 | |
17 | namespace { |
18 | |
19 | template <typename OpTy> |
20 | struct XferOpSubsetOpInterface |
21 | : public SubsetOpInterface::ExternalModel<XferOpSubsetOpInterface<OpTy>, |
22 | OpTy> { |
23 | FailureOr<HyperrectangularSlice> |
24 | getAccessedHyperrectangularSlice(Operation *op) const { |
25 | auto xferOp = cast<OpTy>(op); |
26 | Builder b(xferOp->getContext()); |
27 | SmallVector<OpFoldResult> offsets = llvm::map_to_vector( |
28 | xferOp.getIndices(), [](Value v) -> OpFoldResult { return v; }); |
29 | SmallVector<OpFoldResult> sizes = llvm::map_to_vector( |
30 | xferOp.getTransferChunkAccessed(), |
31 | [&](int64_t sz) -> OpFoldResult { return b.getIndexAttr(sz); }); |
32 | return HyperrectangularSlice(offsets, sizes); |
33 | } |
34 | }; |
35 | |
36 | struct |
37 | : public SubsetExtractionOpInterface::ExternalModel< |
38 | TransferReadOpSubsetExtractionOpInterface, vector::TransferReadOp> { |
39 | OpOperand &getSourceOperand(Operation *op) const { |
40 | return cast<vector::TransferReadOp>(op).getSourceMutable(); |
41 | } |
42 | }; |
43 | |
44 | struct TransferWriteOpSubsetInsertionOpInterface |
45 | : public SubsetInsertionOpInterface::ExternalModel< |
46 | TransferWriteOpSubsetInsertionOpInterface, vector::TransferWriteOp> { |
47 | OpOperand &getSourceOperand(Operation *op) const { |
48 | return cast<vector::TransferWriteOp>(op).getVectorMutable(); |
49 | } |
50 | |
51 | OpOperand &getDestinationOperand(Operation *op) const { |
52 | return cast<vector::TransferWriteOp>(op).getSourceMutable(); |
53 | } |
54 | |
55 | Value (Operation *op, OpBuilder &builder, |
56 | Location loc) const { |
57 | // TODO: Implement when needed. |
58 | return Value(); |
59 | } |
60 | |
61 | SmallVector<Value> |
62 | (Operation *op) const { |
63 | // TODO: Implement when needed. |
64 | return {}; |
65 | } |
66 | }; |
67 | |
68 | } // namespace |
69 | |
70 | void mlir::vector::registerSubsetOpInterfaceExternalModels( |
71 | DialectRegistry ®istry) { |
72 | registry.addExtension(extensionFn: +[](MLIRContext *ctx, vector::VectorDialect *dialect) { |
73 | TransferReadOp::attachInterface<XferOpSubsetOpInterface<TransferReadOp>>( |
74 | *ctx); |
75 | TransferReadOp::attachInterface<TransferReadOpSubsetExtractionOpInterface>( |
76 | *ctx); |
77 | TransferWriteOp::attachInterface<XferOpSubsetOpInterface<TransferWriteOp>>( |
78 | *ctx); |
79 | TransferWriteOp::attachInterface<TransferWriteOpSubsetInsertionOpInterface>( |
80 | *ctx); |
81 | }); |
82 | } |
83 | |