1//===- FoldTensorSubsetOps.cpp - Fold tensor subset ops -------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Fold tensor subset ops with producer / consumers.
10//
11//===----------------------------------------------------------------------===//
12
13#include "mlir/Dialect/Affine/IR/AffineOps.h"
14#include "mlir/Dialect/Affine/ViewLikeInterfaceUtils.h"
15#include "mlir/Dialect/SCF/IR/SCF.h"
16#include "mlir/Dialect/Tensor/IR/Tensor.h"
17#include "mlir/Dialect/Tensor/Transforms/Passes.h"
18#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
19#include "mlir/Dialect/Utils/IndexingUtils.h"
20#include "mlir/Dialect/Vector/IR/VectorOps.h"
21#include "mlir/IR/AffineMap.h"
22#include "mlir/IR/BuiltinAttributes.h"
23#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
24#include "llvm/ADT/TypeSwitch.h"
25#include <type_traits>
26
27namespace mlir {
28namespace tensor {
29#define GEN_PASS_DEF_FOLDTENSORSUBSETOPS
30#include "mlir/Dialect/Tensor/Transforms/Passes.h.inc"
31} // namespace tensor
32} // namespace mlir
33
34using namespace mlir;
35
/// Return the tensor that the transfer_read reads from (its source operand).
static Value getTensorOperand(vector::TransferReadOp op) {
  return op.getSource();
}
39
/// Return the tensor that the insert_slice copies from (its source operand).
static Value getTensorOperand(tensor::InsertSliceOp op) {
  return op.getSource();
}
43
44//===----------------------------------------------------------------------===//
45// Patterns
46//===----------------------------------------------------------------------===//
47
namespace {
/// Merge extract_slice operation with load/transferRead operation.
/// Rewrites `vector.transfer_read(tensor.extract_slice)` so the read happens
/// directly on the larger source tensor, with indices adjusted by the slice's
/// offsets (see matchAndRewrite below).
class TransferReadOfExtractSliceOpFolder final
    : public OpRewritePattern<vector::TransferReadOp> {
public:
  using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
                                PatternRewriter &rewriter) const override;
};

/// Merge insert_slice operation with store/transferWriteOp operation.
/// Rewrites `tensor.insert_slice(vector.transfer_write)` so the write happens
/// directly into the insert_slice's destination tensor, with indices adjusted
/// by the slice's offsets (see matchAndRewrite below).
class InsertSliceOfTransferWriteOpFolder final
    : public OpRewritePattern<tensor::InsertSliceOp> {
public:
  using OpRewritePattern<tensor::InsertSliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::InsertSliceOp insertSliceOp,
                                PatternRewriter &rewriter) const override;
};
} // namespace
69
/// Common preconditions for folding an extract_slice/insert_slice op into a
/// vector transfer op. Returns success only if the transfer is fully
/// in-bounds, unmasked, and the slice op has unit strides; otherwise returns
/// a match failure with a diagnostic attached to `xferOp`.
template <typename XferOp, typename ExtractOrInsertOp>
static LogicalResult preconditionsFoldExtractOrInsertWithTransferOp(
    RewriterBase &rewriter, XferOp xferOp,
    ExtractOrInsertOp extractOrInsertSliceOp) {
  // Out-of-bounds dims would need padding logic that the fold cannot express.
  if (xferOp.hasOutOfBoundsDim())
    return rewriter.notifyMatchFailure(xferOp, "out of bounds transfer dim");
  // A mask is defined relative to the slice's shape; rebasing it onto the
  // larger tensor is not handled here.
  if (xferOp.getMask())
    return rewriter.notifyMatchFailure(xferOp, "masked transfer");
  // Non-unit strides would require strided vector ops to stay equivalent.
  if (!extractOrInsertSliceOp.hasUnitStride()) {
    return rewriter.notifyMatchFailure(
        xferOp, "non-1 stride insert/extract, requires keeping track of "
                "strides, this may result in needing to insert "
                "vector.insert_strided_slice/extract_strided_slice ops");
  }
  return success();
}
86
87LogicalResult TransferReadOfExtractSliceOpFolder::matchAndRewrite(
88 vector::TransferReadOp readOp, PatternRewriter &rewriter) const {
89 auto extractSliceOp =
90 getTensorOperand(readOp).getDefiningOp<tensor::ExtractSliceOp>();
91 if (!extractSliceOp)
92 return rewriter.notifyMatchFailure(readOp, "not an extract_slice");
93
94 LogicalResult preconditionResult =
95 preconditionsFoldExtractOrInsertWithTransferOp(rewriter, readOp,
96 extractSliceOp);
97 if (failed(result: preconditionResult))
98 return preconditionResult;
99
100 SmallVector<Value> indices(readOp.getIndices().begin(),
101 readOp.getIndices().end());
102 SmallVector<Value> sourceIndices;
103 affine::resolveIndicesIntoOpWithOffsetsAndStrides(
104 rewriter, readOp.getLoc(), extractSliceOp.getMixedOffsets(),
105 extractSliceOp.getMixedStrides(), extractSliceOp.getDroppedDims(),
106 indices, sourceIndices);
107
108 rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
109 readOp, readOp.getVectorType(), extractSliceOp.getSource(), sourceIndices,
110 AffineMapAttr::get(expandDimsToRank(
111 readOp.getPermutationMap(), extractSliceOp.getSourceType().getRank(),
112 extractSliceOp.getDroppedDims())),
113 readOp.getPadding(),
114 /*mask=*/Value(), readOp.getInBoundsAttr());
115
116 return success();
117}
118
119LogicalResult InsertSliceOfTransferWriteOpFolder::matchAndRewrite(
120 tensor::InsertSliceOp insertSliceOp, PatternRewriter &rewriter) const {
121 auto writeOp = getTensorOperand(insertSliceOp)
122 .template getDefiningOp<vector::TransferWriteOp>();
123 if (!writeOp)
124 return rewriter.notifyMatchFailure(insertSliceOp, "not a transfer_write");
125
126 LogicalResult preconditionResult =
127 preconditionsFoldExtractOrInsertWithTransferOp(rewriter, writeOp,
128 insertSliceOp);
129 if (failed(result: preconditionResult))
130 return preconditionResult;
131
132 SmallVector<Value> indices(writeOp.getIndices().begin(),
133 writeOp.getIndices().end());
134 SmallVector<Value> sourceIndices;
135 affine::resolveIndicesIntoOpWithOffsetsAndStrides(
136 rewriter, writeOp.getLoc(), insertSliceOp.getMixedOffsets(),
137 insertSliceOp.getMixedStrides(), insertSliceOp.getDroppedDims(), indices,
138 sourceIndices);
139
140 rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
141 insertSliceOp, writeOp.getValue(), insertSliceOp.getDest(), sourceIndices,
142 AffineMapAttr::get(expandDimsToRank(writeOp.getPermutationMap(),
143 insertSliceOp.getDestType().getRank(),
144 insertSliceOp.getDroppedDims())),
145 writeOp.getInBoundsAttr());
146
147 return success();
148}
149
150template <typename OpTy>
151struct InsertSliceOfInsertSliceFolder : public OpRewritePattern<OpTy> {
152 using OpRewritePattern<OpTy>::OpRewritePattern;
153
154 LogicalResult matchAndRewrite(OpTy insertSliceOp,
155 PatternRewriter &rewriter) const override {
156 auto sourceInsertSliceOp =
157 insertSliceOp.getSource()
158 .template getDefiningOp<tensor::InsertSliceOp>();
159 if (!sourceInsertSliceOp)
160 return failure();
161
162 // TODO: relax unit stride assumption where possible.
163 if (!insertSliceOp.hasUnitStride()) {
164 return rewriter.notifyMatchFailure(insertSliceOp,
165 "requires unit strides");
166 }
167 if (!sourceInsertSliceOp.hasUnitStride()) {
168 return rewriter.notifyMatchFailure(sourceInsertSliceOp,
169 "requires unit strides");
170 }
171
172 int64_t srcDim = 0;
173 llvm::SmallBitVector droppedDims = insertSliceOp.getDroppedDims();
174 for (int64_t d = 0, e = insertSliceOp.getDestType().getRank(); d < e; ++d) {
175 if (droppedDims[d])
176 continue;
177 if (insertSliceOp.getMixedSizes()[d] !=
178 sourceInsertSliceOp.getMixedSizes()[srcDim++]) {
179 return rewriter.notifyMatchFailure(
180 sourceInsertSliceOp,
181 "requires matching sizes to fold, otherwise a copy is needed");
182 }
183 }
184
185 // Resolve sizes according to dropped dims.
186 SmallVector<OpFoldResult> resolvedSizes;
187 // Note: the "insertSlice" case is symmetrical to the extract/subview case:
188 // `insertSliceOp` is passed as the "source" and `sourceInsertSliceOp` is
189 // passed as the destination to the helper function.
190 affine::resolveSizesIntoOpWithSizes(sourceSizes: insertSliceOp.getMixedSizes(),
191 destSizes: sourceInsertSliceOp.getMixedSizes(),
192 rankReducedSourceDims: droppedDims, resolvedSizes);
193
194 // If we are inside an InParallel region, temporarily set the insertion
195 // point outside: only tensor.parallel_insert_slice ops are allowed in
196 // there.
197 if (std::is_same_v<OpTy, tensor::ParallelInsertSliceOp>) {
198 rewriter.setInsertionPoint(
199 insertSliceOp->template getParentOfType<scf::InParallelOp>());
200 }
201
202 // Resolve offsets according to source offsets and strides.
203 SmallVector<Value> resolvedOffsets;
204 // Note: the "insertSlice" case is symmetrical to the extract/subview case:
205 // `insertSliceOp` is passed as the "source" and `sourceInsertSliceOp` is
206 // passed as the destination to the helper function.
207 affine::resolveIndicesIntoOpWithOffsetsAndStrides(
208 rewriter, insertSliceOp.getLoc(), insertSliceOp.getMixedOffsets(),
209 insertSliceOp.getMixedStrides(), droppedDims,
210 sourceInsertSliceOp.getMixedOffsets(), resolvedOffsets);
211
212 // Reset the insertion point.
213 rewriter.setInsertionPoint(insertSliceOp);
214 // Replace original op.
215 rewriter.replaceOpWithNewOp<OpTy>(
216 insertSliceOp, sourceInsertSliceOp.getSource(), insertSliceOp.getDest(),
217 getAsOpFoldResult(values: resolvedOffsets), resolvedSizes,
218 insertSliceOp.getMixedStrides());
219
220 return success();
221 }
222};
223
224void tensor::populateFoldTensorSubsetOpPatterns(RewritePatternSet &patterns) {
225 populateFoldTensorSubsetIntoVectorTransferPatterns(patterns);
226 patterns.add<InsertSliceOfInsertSliceFolder<tensor::InsertSliceOp>,
227 InsertSliceOfInsertSliceFolder<tensor::ParallelInsertSliceOp>>(
228 patterns.getContext());
229}
230
231void tensor::populateFoldTensorSubsetIntoVectorTransferPatterns(
232 RewritePatternSet &patterns) {
233 patterns.add<TransferReadOfExtractSliceOpFolder,
234 InsertSliceOfTransferWriteOpFolder>(arg: patterns.getContext());
235}
236
237//===----------------------------------------------------------------------===//
238// Pass registration
239//===----------------------------------------------------------------------===//
240
namespace {

/// Pass wrapper that greedily applies all tensor-subset folding patterns to
/// the operation it runs on. The base class is generated from the pass
/// tablegen definition (see GEN_PASS_DEF_FOLDTENSORSUBSETOPS above).
struct FoldTensorSubsetOpsPass final
    : public tensor::impl::FoldTensorSubsetOpsBase<FoldTensorSubsetOpsPass> {
  void runOnOperation() override;
};

} // namespace
249
250void FoldTensorSubsetOpsPass::runOnOperation() {
251 RewritePatternSet patterns(&getContext());
252 tensor::populateFoldTensorSubsetOpPatterns(patterns);
253 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
254}
255
/// Factory for the FoldTensorSubsetOps pass (public pass-creation entry
/// point declared in the Tensor transforms headers).
std::unique_ptr<Pass> tensor::createFoldTensorSubsetOpsPass() {
  return std::make_unique<FoldTensorSubsetOpsPass>();
}
259

// Source: mlir/lib/Dialect/Tensor/Transforms/FoldTensorSubsetOps.cpp