1 | //===- FoldTensorSubsetOps.cpp - Fold tensor subset ops -------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
// Fold tensor subset ops with their producers / consumers.
10 | // |
11 | //===----------------------------------------------------------------------===// |
12 | |
13 | #include "mlir/Dialect/Affine/IR/AffineOps.h" |
14 | #include "mlir/Dialect/Affine/ViewLikeInterfaceUtils.h" |
15 | #include "mlir/Dialect/SCF/IR/SCF.h" |
16 | #include "mlir/Dialect/Tensor/IR/Tensor.h" |
17 | #include "mlir/Dialect/Tensor/Transforms/Passes.h" |
18 | #include "mlir/Dialect/Tensor/Transforms/Transforms.h" |
19 | #include "mlir/Dialect/Utils/IndexingUtils.h" |
20 | #include "mlir/Dialect/Vector/IR/VectorOps.h" |
21 | #include "mlir/IR/AffineMap.h" |
22 | #include "mlir/IR/BuiltinAttributes.h" |
23 | #include "mlir/Transforms/GreedyPatternRewriteDriver.h" |
24 | #include "llvm/ADT/TypeSwitch.h" |
25 | #include <type_traits> |
26 | |
27 | namespace mlir { |
28 | namespace tensor { |
29 | #define GEN_PASS_DEF_FOLDTENSORSUBSETOPS |
30 | #include "mlir/Dialect/Tensor/Transforms/Passes.h.inc" |
31 | } // namespace tensor |
32 | } // namespace mlir |
33 | |
34 | using namespace mlir; |
35 | |
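/// Return the tensor operand whose defining op the patterns below attempt to
/// fold with (the read source for transfer_read, the inserted source for
/// insert_slice).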
36 | static Value getTensorOperand(vector::TransferReadOp op) { |
37 | return op.getSource(); |
38 | } |
39 | |
40 | static Value getTensorOperand(tensor::InsertSliceOp op) { |
41 | return op.getSource(); |
42 | } |
43 | |
44 | //===----------------------------------------------------------------------===// |
45 | // Patterns |
46 | //===----------------------------------------------------------------------===// |
47 | |
48 | namespace { |
/// Merge an extract_slice operation with a load/transfer_read operation.
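/// For example (illustrative IR, not taken from an actual test), a read of a
/// slice:
///
///   %0 = tensor.extract_slice %t[%a, %b] [4, 8] [1, 1]
///       : tensor<16x16xf32> to tensor<4x8xf32>
///   %1 = vector.transfer_read %0[%c0, %c0], %cst
///       : tensor<4x8xf32>, vector<4x8xf32>
///
/// folds to a single read from the original tensor at combined indices:
///
///   %1 = vector.transfer_read %t[%a, %b], %cst
///       : tensor<16x16xf32>, vector<4x8xf32>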
class TransferReadOfExtractSliceOpFolder final
    : public OpRewritePattern<vector::TransferReadOp> {
52 | public: |
53 | using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern; |
54 | |
55 | LogicalResult matchAndRewrite(vector::TransferReadOp readOp, |
56 | PatternRewriter &rewriter) const override; |
57 | }; |
58 | |
/// Merge an insert_slice operation with a store/transfer_write operation.
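/// For example (illustrative IR, not taken from an actual test), a write that
/// is then inserted:
///
///   %w = vector.transfer_write %v, %dest[%c0, %c0]
///       : vector<4x8xf32>, tensor<4x8xf32>
///   %1 = tensor.insert_slice %w into %t[%a, %b] [4, 8] [1, 1]
///       : tensor<4x8xf32> into tensor<16x16xf32>
///
/// folds to a single write into the larger tensor at combined indices:
///
///   %1 = vector.transfer_write %v, %t[%a, %b]
///       : vector<4x8xf32>, tensor<16x16xf32>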
60 | class InsertSliceOfTransferWriteOpFolder final |
61 | : public OpRewritePattern<tensor::InsertSliceOp> { |
62 | public: |
63 | using OpRewritePattern<tensor::InsertSliceOp>::OpRewritePattern; |
64 | |
65 | LogicalResult matchAndRewrite(tensor::InsertSliceOp insertSliceOp, |
66 | PatternRewriter &rewriter) const override; |
67 | }; |
68 | } // namespace |
69 | |
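/// Common preconditions for folding a tensor slice op into a vector transfer
/// op: the transfer must not have out-of-bounds dimensions or a mask, and the
/// slice op must have unit strides so that new source indices can be obtained
/// by a pure offset computation (no strided index arithmetic).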
70 | template <typename XferOp, typename ExtractOrInsertOp> |
static LogicalResult preconditionsFoldExtractOrInsertWithTransferOp(
    RewriterBase &rewriter, XferOp xferOp,
    ExtractOrInsertOp extractOrInsertSliceOp) {
74 | if (xferOp.hasOutOfBoundsDim()) |
75 | return rewriter.notifyMatchFailure(xferOp, "out of bounds transfer dim" ); |
76 | if (xferOp.getMask()) |
77 | return rewriter.notifyMatchFailure(xferOp, "masked transfer" ); |
78 | if (!extractOrInsertSliceOp.hasUnitStride()) { |
79 | return rewriter.notifyMatchFailure( |
80 | xferOp, "non-1 stride insert/extract, requires keeping track of " |
81 | "strides, this may result in needing to insert " |
82 | "vector.insert_strided_slice/extract_strided_slice ops" ); |
83 | } |
84 | return success(); |
85 | } |
86 | |
87 | LogicalResult TransferReadOfExtractSliceOpFolder::matchAndRewrite( |
88 | vector::TransferReadOp readOp, PatternRewriter &rewriter) const { |
  auto extractSliceOp =
      getTensorOperand(readOp).getDefiningOp<tensor::ExtractSliceOp>();
  if (!extractSliceOp)
    return rewriter.notifyMatchFailure(readOp, "not an extract_slice");
93 | |
94 | LogicalResult preconditionResult = |
95 | preconditionsFoldExtractOrInsertWithTransferOp(rewriter, readOp, |
96 | extractSliceOp); |
  if (failed(preconditionResult))
98 | return preconditionResult; |
99 | |
100 | SmallVector<Value> indices(readOp.getIndices().begin(), |
101 | readOp.getIndices().end()); |
102 | SmallVector<Value> sourceIndices; |
103 | affine::resolveIndicesIntoOpWithOffsetsAndStrides( |
104 | rewriter, readOp.getLoc(), extractSliceOp.getMixedOffsets(), |
105 | extractSliceOp.getMixedStrides(), extractSliceOp.getDroppedDims(), |
106 | indices, sourceIndices); |
107 | |
108 | rewriter.replaceOpWithNewOp<vector::TransferReadOp>( |
109 | readOp, readOp.getVectorType(), extractSliceOp.getSource(), sourceIndices, |
110 | AffineMapAttr::get(expandDimsToRank( |
111 | readOp.getPermutationMap(), extractSliceOp.getSourceType().getRank(), |
112 | extractSliceOp.getDroppedDims())), |
113 | readOp.getPadding(), |
114 | /*mask=*/Value(), readOp.getInBoundsAttr()); |
115 | |
116 | return success(); |
117 | } |
118 | |
119 | LogicalResult InsertSliceOfTransferWriteOpFolder::matchAndRewrite( |
120 | tensor::InsertSliceOp insertSliceOp, PatternRewriter &rewriter) const { |
121 | auto writeOp = getTensorOperand(insertSliceOp) |
122 | .template getDefiningOp<vector::TransferWriteOp>(); |
123 | if (!writeOp) |
124 | return rewriter.notifyMatchFailure(insertSliceOp, "not a transfer_write" ); |
125 | |
126 | LogicalResult preconditionResult = |
127 | preconditionsFoldExtractOrInsertWithTransferOp(rewriter, writeOp, |
128 | insertSliceOp); |
  if (failed(preconditionResult))
130 | return preconditionResult; |
131 | |
132 | SmallVector<Value> indices(writeOp.getIndices().begin(), |
133 | writeOp.getIndices().end()); |
134 | SmallVector<Value> sourceIndices; |
135 | affine::resolveIndicesIntoOpWithOffsetsAndStrides( |
136 | rewriter, writeOp.getLoc(), insertSliceOp.getMixedOffsets(), |
137 | insertSliceOp.getMixedStrides(), insertSliceOp.getDroppedDims(), indices, |
138 | sourceIndices); |
139 | |
140 | rewriter.replaceOpWithNewOp<vector::TransferWriteOp>( |
141 | insertSliceOp, writeOp.getValue(), insertSliceOp.getDest(), sourceIndices, |
142 | AffineMapAttr::get(expandDimsToRank(writeOp.getPermutationMap(), |
143 | insertSliceOp.getDestType().getRank(), |
144 | insertSliceOp.getDroppedDims())), |
145 | writeOp.getInBoundsAttr()); |
146 | |
147 | return success(); |
148 | } |
149 | |
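/// Fold an insert_slice whose source is itself an insert_slice that fully
/// overwrites its own destination, composing the two ops' offsets. This
/// handles both tensor.insert_slice and tensor.parallel_insert_slice. A
/// sketch of the rewrite (illustrative IR, not taken from an actual test):
///
///   %0 = tensor.insert_slice %src into %a[0] [4] [1]
///       : tensor<4xf32> into tensor<4xf32>
///   %1 = tensor.insert_slice %0 into %b[%j] [4] [1]
///       : tensor<4xf32> into tensor<16xf32>
///
/// Since the sizes match, %src entirely overwrites %a and the pair folds to:
///
///   %1 = tensor.insert_slice %src into %b[%j] [4] [1]
///       : tensor<4xf32> into tensor<16xf32>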
150 | template <typename OpTy> |
151 | struct InsertSliceOfInsertSliceFolder : public OpRewritePattern<OpTy> { |
152 | using OpRewritePattern<OpTy>::OpRewritePattern; |
153 | |
154 | LogicalResult matchAndRewrite(OpTy insertSliceOp, |
155 | PatternRewriter &rewriter) const override { |
156 | auto sourceInsertSliceOp = |
157 | insertSliceOp.getSource() |
158 | .template getDefiningOp<tensor::InsertSliceOp>(); |
159 | if (!sourceInsertSliceOp) |
160 | return failure(); |
161 | |
162 | // TODO: relax unit stride assumption where possible. |
163 | if (!insertSliceOp.hasUnitStride()) { |
164 | return rewriter.notifyMatchFailure(insertSliceOp, |
165 | "requires unit strides" ); |
166 | } |
167 | if (!sourceInsertSliceOp.hasUnitStride()) { |
168 | return rewriter.notifyMatchFailure(sourceInsertSliceOp, |
169 | "requires unit strides" ); |
170 | } |
171 | |
172 | int64_t srcDim = 0; |
173 | llvm::SmallBitVector droppedDims = insertSliceOp.getDroppedDims(); |
174 | for (int64_t d = 0, e = insertSliceOp.getDestType().getRank(); d < e; ++d) { |
175 | if (droppedDims[d]) |
176 | continue; |
177 | if (insertSliceOp.getMixedSizes()[d] != |
178 | sourceInsertSliceOp.getMixedSizes()[srcDim++]) { |
179 | return rewriter.notifyMatchFailure( |
180 | sourceInsertSliceOp, |
181 | "requires matching sizes to fold, otherwise a copy is needed" ); |
182 | } |
183 | } |
184 | |
185 | // Resolve sizes according to dropped dims. |
186 | SmallVector<OpFoldResult> resolvedSizes; |
187 | // Note: the "insertSlice" case is symmetrical to the extract/subview case: |
188 | // `insertSliceOp` is passed as the "source" and `sourceInsertSliceOp` is |
189 | // passed as the destination to the helper function. |
    affine::resolveSizesIntoOpWithSizes(insertSliceOp.getMixedSizes(),
                                        sourceInsertSliceOp.getMixedSizes(),
                                        droppedDims, resolvedSizes);
193 | |
194 | // If we are inside an InParallel region, temporarily set the insertion |
195 | // point outside: only tensor.parallel_insert_slice ops are allowed in |
196 | // there. |
197 | if (std::is_same_v<OpTy, tensor::ParallelInsertSliceOp>) { |
198 | rewriter.setInsertionPoint( |
199 | insertSliceOp->template getParentOfType<scf::InParallelOp>()); |
200 | } |
201 | |
202 | // Resolve offsets according to source offsets and strides. |
203 | SmallVector<Value> resolvedOffsets; |
204 | // Note: the "insertSlice" case is symmetrical to the extract/subview case: |
205 | // `insertSliceOp` is passed as the "source" and `sourceInsertSliceOp` is |
206 | // passed as the destination to the helper function. |
207 | affine::resolveIndicesIntoOpWithOffsetsAndStrides( |
208 | rewriter, insertSliceOp.getLoc(), insertSliceOp.getMixedOffsets(), |
209 | insertSliceOp.getMixedStrides(), droppedDims, |
210 | sourceInsertSliceOp.getMixedOffsets(), resolvedOffsets); |
211 | |
212 | // Reset the insertion point. |
213 | rewriter.setInsertionPoint(insertSliceOp); |
214 | // Replace original op. |
215 | rewriter.replaceOpWithNewOp<OpTy>( |
216 | insertSliceOp, sourceInsertSliceOp.getSource(), insertSliceOp.getDest(), |
        getAsOpFoldResult(resolvedOffsets), resolvedSizes,
218 | insertSliceOp.getMixedStrides()); |
219 | |
220 | return success(); |
221 | } |
222 | }; |
223 | |
224 | void tensor::populateFoldTensorSubsetOpPatterns(RewritePatternSet &patterns) { |
225 | populateFoldTensorSubsetIntoVectorTransferPatterns(patterns); |
226 | patterns.add<InsertSliceOfInsertSliceFolder<tensor::InsertSliceOp>, |
227 | InsertSliceOfInsertSliceFolder<tensor::ParallelInsertSliceOp>>( |
228 | patterns.getContext()); |
229 | } |
230 | |
231 | void tensor::populateFoldTensorSubsetIntoVectorTransferPatterns( |
232 | RewritePatternSet &patterns) { |
233 | patterns.add<TransferReadOfExtractSliceOpFolder, |
               InsertSliceOfTransferWriteOpFolder>(patterns.getContext());
235 | } |
236 | |
237 | //===----------------------------------------------------------------------===// |
238 | // Pass registration |
239 | //===----------------------------------------------------------------------===// |
240 | |
241 | namespace { |
242 | |
243 | struct FoldTensorSubsetOpsPass final |
244 | : public tensor::impl::FoldTensorSubsetOpsBase<FoldTensorSubsetOpsPass> { |
245 | void runOnOperation() override; |
246 | }; |
247 | |
248 | } // namespace |
249 | |
250 | void FoldTensorSubsetOpsPass::runOnOperation() { |
251 | RewritePatternSet patterns(&getContext()); |
252 | tensor::populateFoldTensorSubsetOpPatterns(patterns); |
253 | (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); |
254 | } |
255 | |
256 | std::unique_ptr<Pass> tensor::createFoldTensorSubsetOpsPass() { |
257 | return std::make_unique<FoldTensorSubsetOpsPass>(); |
258 | } |
259 | |