1//===- FoldTensorSubsetOps.cpp - Fold tensor subset ops -------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Fold tensor subset ops with producer / consumers.
10//
11//===----------------------------------------------------------------------===//
12
13#include "mlir/Dialect/Affine/IR/AffineOps.h"
14#include "mlir/Dialect/Affine/ViewLikeInterfaceUtils.h"
15#include "mlir/Dialect/SCF/IR/SCF.h"
16#include "mlir/Dialect/Tensor/IR/Tensor.h"
17#include "mlir/Dialect/Tensor/Transforms/Passes.h"
18#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
19#include "mlir/Dialect/Vector/IR/VectorOps.h"
20#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
21#include "mlir/IR/AffineMap.h"
22#include "mlir/IR/BuiltinAttributes.h"
23#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
24#include <type_traits>
25
namespace mlir {
namespace tensor {
// Pull in the TableGen-generated pass base class
// (FoldTensorSubsetOpsPassBase) used by the pass definition below.
#define GEN_PASS_DEF_FOLDTENSORSUBSETOPSPASS
#include "mlir/Dialect/Tensor/Transforms/Passes.h.inc"
} // namespace tensor
} // namespace mlir
32
33using namespace mlir;
34
35static Value getTensorOperand(vector::TransferReadOp op) {
36 return op.getBase();
37}
38
39static Value getTensorOperand(tensor::InsertSliceOp op) {
40 return op.getSource();
41}
42
43//===----------------------------------------------------------------------===//
44// Patterns
45//===----------------------------------------------------------------------===//
46
namespace {
/// Merge extract_slice operation with load/transferRead operation.
/// Declaration only; the rewrite logic lives in matchAndRewriteMaskableOp
/// below. Uses MaskableOpRewritePattern so the fold also applies to reads
/// wrapped in a vector.mask region.
class TransferReadOfExtractSliceOpFolder final
    : public vector::MaskableOpRewritePattern<vector::TransferReadOp> {
public:
  using MaskableOpRewritePattern::MaskableOpRewritePattern;

  FailureOr<mlir::Value>
  matchAndRewriteMaskableOp(vector::TransferReadOp readOp,
                            vector::MaskingOpInterface maskOp,
                            PatternRewriter &rewriter) const override;
};

/// Merge insert_slice operation with store/transferWriteOp operation.
/// Declaration only; the rewrite logic and the coverage check are defined
/// out-of-line below.
class InsertSliceOfTransferWriteOpFolder final
    : public OpRewritePattern<tensor::InsertSliceOp> {
public:
  using OpRewritePattern<tensor::InsertSliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::InsertSliceOp insertSliceOp,
                                PatternRewriter &rewriter) const override;

private:
  // Returns true when the write is statically known to overwrite the whole
  // tensor that the insert_slice inserts (see definition below).
  static bool
  doesTransferWriteCoverInsertSlice(vector::TransferWriteOp writeOp);
};
} // namespace
74
/// Common preconditions for folding an extract_slice/insert_slice op into a
/// vector transfer op. Emits a match-failure diagnostic on `xferOp` and fails
/// when the transfer has an out-of-bounds dimension, carries a mask, or when
/// the slice op uses non-unit strides.
template <typename XferOp, typename ExtractOrInsertOp>
static LogicalResult preconditionsFoldExtractOrInsertWithTransferOp(
    RewriterBase &rewriter, XferOp xferOp,
    ExtractOrInsertOp extractOrInsertSliceOp) {
  // Out-of-bounds dims would need padding logic after the fold; bail out.
  if (xferOp.hasOutOfBoundsDim())
    return rewriter.notifyMatchFailure(xferOp, "out of bounds transfer dim");
  // Masked transfers are handled by the MaskableOpRewritePattern wrapper, not
  // here.
  if (xferOp.getMask())
    return rewriter.notifyMatchFailure(xferOp, "masked transfer");
  if (!extractOrInsertSliceOp.hasUnitStride()) {
    return rewriter.notifyMatchFailure(
        xferOp, "non-1 stride insert/extract, requires keeping track of "
                "strides, this may result in needing to insert "
                "vector.insert_strided_slice/extract_strided_slice ops");
  }
  return success();
}
91
92FailureOr<mlir::Value>
93TransferReadOfExtractSliceOpFolder::matchAndRewriteMaskableOp(
94 vector::TransferReadOp readOp, vector::MaskingOpInterface maskOp,
95 PatternRewriter &rewriter) const {
96 auto extractSliceOp =
97 getTensorOperand(op: readOp).getDefiningOp<tensor::ExtractSliceOp>();
98 if (!extractSliceOp)
99 return rewriter.notifyMatchFailure(arg&: readOp, msg: "not an extract_slice");
100
101 LogicalResult preconditionResult =
102 preconditionsFoldExtractOrInsertWithTransferOp(rewriter, xferOp: readOp,
103 extractOrInsertSliceOp: extractSliceOp);
104 if (failed(Result: preconditionResult))
105 return rewriter.notifyMatchFailure(arg&: readOp, msg: "Failed preconditions");
106
107 SmallVector<Value> indices(readOp.getIndices().begin(),
108 readOp.getIndices().end());
109 SmallVector<Value> sourceIndices;
110 affine::resolveIndicesIntoOpWithOffsetsAndStrides(
111 rewriter, loc: readOp.getLoc(), mixedSourceOffsets: extractSliceOp.getMixedOffsets(),
112 mixedSourceStrides: extractSliceOp.getMixedStrides(), rankReducedDims: extractSliceOp.getDroppedDims(),
113 consumerIndices: indices, resolvedIndices&: sourceIndices);
114
115 Operation *newOp = rewriter.create<vector::TransferReadOp>(
116 location: readOp.getLoc(), args: readOp.getVectorType(), args: extractSliceOp.getSource(),
117 args&: sourceIndices,
118 args: AffineMapAttr::get(value: expandDimsToRank(
119 map: readOp.getPermutationMap(), rank: extractSliceOp.getSourceType().getRank(),
120 projectedDimensions: extractSliceOp.getDroppedDims())),
121 args: readOp.getPadding(),
122 /*mask=*/args: Value(), args: readOp.getInBoundsAttr());
123 if (maskOp)
124 newOp = mlir::vector::maskOperation(builder&: rewriter, maskableOp: newOp, mask: maskOp.getMask());
125 return newOp->getResults()[0];
126}
127
128LogicalResult InsertSliceOfTransferWriteOpFolder::matchAndRewrite(
129 tensor::InsertSliceOp insertSliceOp, PatternRewriter &rewriter) const {
130 auto writeOp = getTensorOperand(op: insertSliceOp)
131 .template getDefiningOp<vector::TransferWriteOp>();
132 if (!writeOp)
133 return rewriter.notifyMatchFailure(arg&: insertSliceOp, msg: "not a transfer_write");
134
135 LogicalResult preconditionResult =
136 preconditionsFoldExtractOrInsertWithTransferOp(rewriter, xferOp: writeOp,
137 extractOrInsertSliceOp: insertSliceOp);
138 if (failed(Result: preconditionResult))
139 return preconditionResult;
140
141 if (!doesTransferWriteCoverInsertSlice(writeOp))
142 return rewriter.notifyMatchFailure(
143 arg&: insertSliceOp, msg: "transfer_write does not cover insert_slice");
144
145 SmallVector<Value> indices(writeOp.getIndices().begin(),
146 writeOp.getIndices().end());
147 SmallVector<Value> sourceIndices;
148 affine::resolveIndicesIntoOpWithOffsetsAndStrides(
149 rewriter, loc: writeOp.getLoc(), mixedSourceOffsets: insertSliceOp.getMixedOffsets(),
150 mixedSourceStrides: insertSliceOp.getMixedStrides(), rankReducedDims: insertSliceOp.getDroppedDims(), consumerIndices: indices,
151 resolvedIndices&: sourceIndices);
152
153 rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
154 op: insertSliceOp, args: writeOp.getValue(), args: insertSliceOp.getDest(), args&: sourceIndices,
155 args: AffineMapAttr::get(value: expandDimsToRank(map: writeOp.getPermutationMap(),
156 rank: insertSliceOp.getDestType().getRank(),
157 projectedDimensions: insertSliceOp.getDroppedDims())),
158 args: writeOp.getInBoundsAttr());
159
160 return success();
161}
162
163bool InsertSliceOfTransferWriteOpFolder::doesTransferWriteCoverInsertSlice(
164 vector::TransferWriteOp writeOp) {
165 if (writeOp.getShapedType().hasStaticShape())
166 return llvm::equal(LRange: writeOp.getVectorType().getShape(),
167 RRange: writeOp.getShapedType().getShape());
168
169 // TODO: Use ValueBoundsConstraintSet for dynamic shapes.
170
171 return false;
172}
173
174template <typename OpTy>
175struct InsertSliceOfInsertSliceFolder : public OpRewritePattern<OpTy> {
176 using OpRewritePattern<OpTy>::OpRewritePattern;
177
178 LogicalResult matchAndRewrite(OpTy insertSliceOp,
179 PatternRewriter &rewriter) const override {
180 auto sourceInsertSliceOp =
181 insertSliceOp.getSource()
182 .template getDefiningOp<tensor::InsertSliceOp>();
183 if (!sourceInsertSliceOp)
184 return failure();
185
186 // TODO: relax unit stride assumption where possible.
187 if (!insertSliceOp.hasUnitStride()) {
188 return rewriter.notifyMatchFailure(insertSliceOp,
189 "requires unit strides");
190 }
191 if (!sourceInsertSliceOp.hasUnitStride()) {
192 return rewriter.notifyMatchFailure(sourceInsertSliceOp,
193 "requires unit strides");
194 }
195
196 int64_t srcDim = 0;
197 llvm::SmallBitVector droppedDims = insertSliceOp.getDroppedDims();
198 for (int64_t d = 0, e = insertSliceOp.getDestType().getRank(); d < e; ++d) {
199 if (droppedDims[d])
200 continue;
201 if (insertSliceOp.getMixedSizes()[d] !=
202 sourceInsertSliceOp.getMixedSizes()[srcDim++]) {
203 return rewriter.notifyMatchFailure(
204 sourceInsertSliceOp,
205 "requires matching sizes to fold, otherwise a copy is needed");
206 }
207 }
208
209 // Resolve sizes according to dropped dims.
210 SmallVector<OpFoldResult> resolvedSizes;
211 // Note: the "insertSlice" case is symmetrical to the extract/subview case:
212 // `insertSliceOp` is passed as the "source" and `sourceInsertSliceOp` is
213 // passed as the destination to the helper function.
214 affine::resolveSizesIntoOpWithSizes(sourceSizes: insertSliceOp.getMixedSizes(),
215 destSizes: sourceInsertSliceOp.getMixedSizes(),
216 rankReducedSourceDims: droppedDims, resolvedSizes);
217
218 // If we are inside an InParallel region, temporarily set the insertion
219 // point outside: only tensor.parallel_insert_slice ops are allowed in
220 // there.
221 if (std::is_same_v<OpTy, tensor::ParallelInsertSliceOp>) {
222 rewriter.setInsertionPoint(
223 insertSliceOp->template getParentOfType<scf::InParallelOp>());
224 }
225
226 // Resolve offsets according to source offsets and strides.
227 SmallVector<Value> resolvedOffsets;
228 // Note: the "insertSlice" case is symmetrical to the extract/subview case:
229 // `insertSliceOp` is passed as the "source" and `sourceInsertSliceOp` is
230 // passed as the destination to the helper function.
231 affine::resolveIndicesIntoOpWithOffsetsAndStrides(
232 rewriter, insertSliceOp.getLoc(), insertSliceOp.getMixedOffsets(),
233 insertSliceOp.getMixedStrides(), droppedDims,
234 sourceInsertSliceOp.getMixedOffsets(), resolvedOffsets);
235
236 // Reset the insertion point.
237 rewriter.setInsertionPoint(insertSliceOp);
238 // Replace original op.
239 rewriter.replaceOpWithNewOp<OpTy>(
240 insertSliceOp, sourceInsertSliceOp.getSource(), insertSliceOp.getDest(),
241 getAsOpFoldResult(values: resolvedOffsets), resolvedSizes,
242 insertSliceOp.getMixedStrides());
243
244 return success();
245 }
246};
247
248void tensor::populateFoldTensorSubsetOpPatterns(RewritePatternSet &patterns) {
249 populateFoldTensorSubsetIntoVectorTransferPatterns(patterns);
250 patterns.add<InsertSliceOfInsertSliceFolder<tensor::InsertSliceOp>,
251 InsertSliceOfInsertSliceFolder<tensor::ParallelInsertSliceOp>>(
252 arg: patterns.getContext());
253}
254
255void tensor::populateFoldTensorSubsetIntoVectorTransferPatterns(
256 RewritePatternSet &patterns) {
257 patterns.add<TransferReadOfExtractSliceOpFolder,
258 InsertSliceOfTransferWriteOpFolder>(arg: patterns.getContext());
259}
260
261//===----------------------------------------------------------------------===//
262// Pass registration
263//===----------------------------------------------------------------------===//
264
namespace {

/// Pass wrapper that applies all tensor-subset folding patterns greedily over
/// the operation it runs on. The base class is generated from Passes.td (see
/// the GEN_PASS_DEF include above in this file).
struct FoldTensorSubsetOpsPass final
    : public tensor::impl::FoldTensorSubsetOpsPassBase<
          FoldTensorSubsetOpsPass> {
  void runOnOperation() override;
};

} // namespace
274
275void FoldTensorSubsetOpsPass::runOnOperation() {
276 RewritePatternSet patterns(&getContext());
277 tensor::populateFoldTensorSubsetOpPatterns(patterns);
278 (void)applyPatternsGreedily(op: getOperation(), patterns: std::move(patterns));
279}
280

// Source file: mlir/lib/Dialect/Tensor/Transforms/FoldTensorSubsetOps.cpp