//===- FoldTensorSubsetOps.cpp - Fold tensor subset ops -------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Fold tensor subset ops with producer / consumers.
//
//===----------------------------------------------------------------------===//
12
13#include "mlir/Dialect/Affine/IR/AffineOps.h"
14#include "mlir/Dialect/Affine/ViewLikeInterfaceUtils.h"
15#include "mlir/Dialect/SCF/IR/SCF.h"
16#include "mlir/Dialect/Tensor/IR/Tensor.h"
17#include "mlir/Dialect/Tensor/Transforms/Passes.h"
18#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
19#include "mlir/Dialect/Utils/IndexingUtils.h"
20#include "mlir/Dialect/Vector/IR/VectorOps.h"
21#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
22#include "mlir/IR/AffineMap.h"
23#include "mlir/IR/BuiltinAttributes.h"
24#include "mlir/Interfaces/ValueBoundsOpInterface.h"
25#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
26#include "llvm/ADT/TypeSwitch.h"
27#include <type_traits>
28
29namespace mlir {
30namespace tensor {
31#define GEN_PASS_DEF_FOLDTENSORSUBSETOPSPASS
32#include "mlir/Dialect/Tensor/Transforms/Passes.h.inc"
33} // namespace tensor
34} // namespace mlir
35
36using namespace mlir;
37
38static Value getTensorOperand(vector::TransferReadOp op) {
39 return op.getBase();
40}
41
42static Value getTensorOperand(tensor::InsertSliceOp op) {
43 return op.getSource();
44}
45
//===----------------------------------------------------------------------===//
// Patterns
//===----------------------------------------------------------------------===//
49
namespace {
/// Merge extract_slice operation with load/transferRead operation.
/// Rewrites `vector.transfer_read(tensor.extract_slice)` into a transfer_read
/// directly on the slice's source tensor, with rebased indices. Declared as a
/// MaskableOpRewritePattern so a surrounding vector.mask is re-applied to the
/// rewritten read.
class TransferReadOfExtractSliceOpFolder final
    : public vector::MaskableOpRewritePattern<vector::TransferReadOp> {
public:
  using MaskableOpRewritePattern::MaskableOpRewritePattern;

  FailureOr<mlir::Value>
  matchAndRewriteMaskableOp(vector::TransferReadOp readOp,
                            vector::MaskingOpInterface maskOp,
                            PatternRewriter &rewriter) const override;
};

/// Merge insert_slice operation with store/transferWriteOp operation.
/// Rewrites `tensor.insert_slice(vector.transfer_write)` into a transfer_write
/// directly on the insert's destination tensor, with rebased indices.
class InsertSliceOfTransferWriteOpFolder final
    : public OpRewritePattern<tensor::InsertSliceOp> {
public:
  using OpRewritePattern<tensor::InsertSliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::InsertSliceOp insertSliceOp,
                                PatternRewriter &rewriter) const override;

private:
  /// Returns true when the transfer_write provably writes every element of
  /// the region that the insert_slice inserts (see definition below).
  static bool
  doesTransferWriteCoverInsertSlice(vector::TransferWriteOp writeOp);
};
} // namespace
77
78template <typename XferOp, typename ExtractOrInsertOp>
79static LogicalResult preconditionsFoldExtractOrInsertWithTransferOp(
80 RewriterBase &rewriter, XferOp xferOp,
81 ExtractOrInsertOp extractOrInsertSliceOp) {
82 if (xferOp.hasOutOfBoundsDim())
83 return rewriter.notifyMatchFailure(xferOp, "out of bounds transfer dim");
84 if (xferOp.getMask())
85 return rewriter.notifyMatchFailure(xferOp, "masked transfer");
86 if (!extractOrInsertSliceOp.hasUnitStride()) {
87 return rewriter.notifyMatchFailure(
88 xferOp, "non-1 stride insert/extract, requires keeping track of "
89 "strides, this may result in needing to insert "
90 "vector.insert_strided_slice/extract_strided_slice ops");
91 }
92 return success();
93}
94
95FailureOr<mlir::Value>
96TransferReadOfExtractSliceOpFolder::matchAndRewriteMaskableOp(
97 vector::TransferReadOp readOp, vector::MaskingOpInterface maskOp,
98 PatternRewriter &rewriter) const {
99 auto extractSliceOp =
100 getTensorOperand(readOp).getDefiningOp<tensor::ExtractSliceOp>();
101 if (!extractSliceOp)
102 return rewriter.notifyMatchFailure(readOp, "not an extract_slice");
103
104 LogicalResult preconditionResult =
105 preconditionsFoldExtractOrInsertWithTransferOp(rewriter, readOp,
106 extractSliceOp);
107 if (failed(Result: preconditionResult))
108 return rewriter.notifyMatchFailure(readOp, "Failed preconditions");
109
110 SmallVector<Value> indices(readOp.getIndices().begin(),
111 readOp.getIndices().end());
112 SmallVector<Value> sourceIndices;
113 affine::resolveIndicesIntoOpWithOffsetsAndStrides(
114 rewriter, readOp.getLoc(), extractSliceOp.getMixedOffsets(),
115 extractSliceOp.getMixedStrides(), extractSliceOp.getDroppedDims(),
116 indices, sourceIndices);
117
118 Operation *newOp = rewriter.create<vector::TransferReadOp>(
119 readOp.getLoc(), readOp.getVectorType(), extractSliceOp.getSource(),
120 sourceIndices,
121 AffineMapAttr::get(expandDimsToRank(
122 readOp.getPermutationMap(), extractSliceOp.getSourceType().getRank(),
123 extractSliceOp.getDroppedDims())),
124 readOp.getPadding(),
125 /*mask=*/Value(), readOp.getInBoundsAttr());
126 if (maskOp)
127 newOp = mlir::vector::maskOperation(builder&: rewriter, maskableOp: newOp, mask: maskOp.getMask());
128 return newOp->getResults()[0];
129}
130
131LogicalResult InsertSliceOfTransferWriteOpFolder::matchAndRewrite(
132 tensor::InsertSliceOp insertSliceOp, PatternRewriter &rewriter) const {
133 auto writeOp = getTensorOperand(insertSliceOp)
134 .template getDefiningOp<vector::TransferWriteOp>();
135 if (!writeOp)
136 return rewriter.notifyMatchFailure(insertSliceOp, "not a transfer_write");
137
138 LogicalResult preconditionResult =
139 preconditionsFoldExtractOrInsertWithTransferOp(rewriter, writeOp,
140 insertSliceOp);
141 if (failed(Result: preconditionResult))
142 return preconditionResult;
143
144 if (!doesTransferWriteCoverInsertSlice(writeOp: writeOp))
145 return rewriter.notifyMatchFailure(
146 insertSliceOp, "transfer_write does not cover insert_slice");
147
148 SmallVector<Value> indices(writeOp.getIndices().begin(),
149 writeOp.getIndices().end());
150 SmallVector<Value> sourceIndices;
151 affine::resolveIndicesIntoOpWithOffsetsAndStrides(
152 rewriter, writeOp.getLoc(), insertSliceOp.getMixedOffsets(),
153 insertSliceOp.getMixedStrides(), insertSliceOp.getDroppedDims(), indices,
154 sourceIndices);
155
156 rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
157 insertSliceOp, writeOp.getValue(), insertSliceOp.getDest(), sourceIndices,
158 AffineMapAttr::get(expandDimsToRank(writeOp.getPermutationMap(),
159 insertSliceOp.getDestType().getRank(),
160 insertSliceOp.getDroppedDims())),
161 writeOp.getInBoundsAttr());
162
163 return success();
164}
165
166bool InsertSliceOfTransferWriteOpFolder::doesTransferWriteCoverInsertSlice(
167 vector::TransferWriteOp writeOp) {
168 if (writeOp.getShapedType().hasStaticShape())
169 return llvm::equal(writeOp.getVectorType().getShape(),
170 writeOp.getShapedType().getShape());
171
172 // TODO: Use ValueBoundsConstraintSet for dynamic shapes.
173
174 return false;
175}
176
177template <typename OpTy>
178struct InsertSliceOfInsertSliceFolder : public OpRewritePattern<OpTy> {
179 using OpRewritePattern<OpTy>::OpRewritePattern;
180
181 LogicalResult matchAndRewrite(OpTy insertSliceOp,
182 PatternRewriter &rewriter) const override {
183 auto sourceInsertSliceOp =
184 insertSliceOp.getSource()
185 .template getDefiningOp<tensor::InsertSliceOp>();
186 if (!sourceInsertSliceOp)
187 return failure();
188
189 // TODO: relax unit stride assumption where possible.
190 if (!insertSliceOp.hasUnitStride()) {
191 return rewriter.notifyMatchFailure(insertSliceOp,
192 "requires unit strides");
193 }
194 if (!sourceInsertSliceOp.hasUnitStride()) {
195 return rewriter.notifyMatchFailure(sourceInsertSliceOp,
196 "requires unit strides");
197 }
198
199 int64_t srcDim = 0;
200 llvm::SmallBitVector droppedDims = insertSliceOp.getDroppedDims();
201 for (int64_t d = 0, e = insertSliceOp.getDestType().getRank(); d < e; ++d) {
202 if (droppedDims[d])
203 continue;
204 if (insertSliceOp.getMixedSizes()[d] !=
205 sourceInsertSliceOp.getMixedSizes()[srcDim++]) {
206 return rewriter.notifyMatchFailure(
207 sourceInsertSliceOp,
208 "requires matching sizes to fold, otherwise a copy is needed");
209 }
210 }
211
212 // Resolve sizes according to dropped dims.
213 SmallVector<OpFoldResult> resolvedSizes;
214 // Note: the "insertSlice" case is symmetrical to the extract/subview case:
215 // `insertSliceOp` is passed as the "source" and `sourceInsertSliceOp` is
216 // passed as the destination to the helper function.
217 affine::resolveSizesIntoOpWithSizes(sourceSizes: insertSliceOp.getMixedSizes(),
218 destSizes: sourceInsertSliceOp.getMixedSizes(),
219 rankReducedSourceDims: droppedDims, resolvedSizes);
220
221 // If we are inside an InParallel region, temporarily set the insertion
222 // point outside: only tensor.parallel_insert_slice ops are allowed in
223 // there.
224 if (std::is_same_v<OpTy, tensor::ParallelInsertSliceOp>) {
225 rewriter.setInsertionPoint(
226 insertSliceOp->template getParentOfType<scf::InParallelOp>());
227 }
228
229 // Resolve offsets according to source offsets and strides.
230 SmallVector<Value> resolvedOffsets;
231 // Note: the "insertSlice" case is symmetrical to the extract/subview case:
232 // `insertSliceOp` is passed as the "source" and `sourceInsertSliceOp` is
233 // passed as the destination to the helper function.
234 affine::resolveIndicesIntoOpWithOffsetsAndStrides(
235 rewriter, insertSliceOp.getLoc(), insertSliceOp.getMixedOffsets(),
236 insertSliceOp.getMixedStrides(), droppedDims,
237 sourceInsertSliceOp.getMixedOffsets(), resolvedOffsets);
238
239 // Reset the insertion point.
240 rewriter.setInsertionPoint(insertSliceOp);
241 // Replace original op.
242 rewriter.replaceOpWithNewOp<OpTy>(
243 insertSliceOp, sourceInsertSliceOp.getSource(), insertSliceOp.getDest(),
244 getAsOpFoldResult(values: resolvedOffsets), resolvedSizes,
245 insertSliceOp.getMixedStrides());
246
247 return success();
248 }
249};
250
251void tensor::populateFoldTensorSubsetOpPatterns(RewritePatternSet &patterns) {
252 populateFoldTensorSubsetIntoVectorTransferPatterns(patterns);
253 patterns.add<InsertSliceOfInsertSliceFolder<tensor::InsertSliceOp>,
254 InsertSliceOfInsertSliceFolder<tensor::ParallelInsertSliceOp>>(
255 patterns.getContext());
256}
257
258void tensor::populateFoldTensorSubsetIntoVectorTransferPatterns(
259 RewritePatternSet &patterns) {
260 patterns.add<TransferReadOfExtractSliceOpFolder,
261 InsertSliceOfTransferWriteOpFolder>(arg: patterns.getContext());
262}
263
//===----------------------------------------------------------------------===//
// Pass registration
//===----------------------------------------------------------------------===//
267
namespace {

/// Pass that greedily applies the tensor subset folding patterns to the
/// operation it is run on. The pass base class is generated from the
/// GEN_PASS_DEF_FOLDTENSORSUBSETOPSPASS tablegen include above.
struct FoldTensorSubsetOpsPass final
    : public tensor::impl::FoldTensorSubsetOpsPassBase<
          FoldTensorSubsetOpsPass> {
  void runOnOperation() override;
};

} // namespace
277
278void FoldTensorSubsetOpsPass::runOnOperation() {
279 RewritePatternSet patterns(&getContext());
280 tensor::populateFoldTensorSubsetOpPatterns(patterns);
281 (void)applyPatternsGreedily(getOperation(), std::move(patterns));
282}
283
