1//===- VectorInsertExtractStridedSliceRewritePatterns.cpp - Rewrites ------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/Arith/IR/Arith.h"
10#include "mlir/Dialect/MemRef/IR/MemRef.h"
11#include "mlir/Dialect/Utils/IndexingUtils.h"
12#include "mlir/Dialect/Vector/IR/VectorOps.h"
13#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
14#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
15#include "mlir/IR/BuiltinTypes.h"
16#include "mlir/IR/PatternMatch.h"
17
18using namespace mlir;
19using namespace mlir::vector;
20
21// Helper that picks the proper sequence for inserting.
22static Value insertOne(PatternRewriter &rewriter, Location loc, Value from,
23 Value into, int64_t offset) {
24 auto vectorType = cast<VectorType>(into.getType());
25 if (vectorType.getRank() > 1)
26 return rewriter.create<InsertOp>(loc, from, into, offset);
27 return rewriter.create<vector::InsertElementOp>(
28 loc, vectorType, from, into,
29 rewriter.create<arith::ConstantIndexOp>(loc, offset));
30}
31
32// Helper that picks the proper sequence for extracting.
33static Value extractOne(PatternRewriter &rewriter, Location loc, Value vector,
34 int64_t offset) {
35 auto vectorType = cast<VectorType>(vector.getType());
36 if (vectorType.getRank() > 1)
37 return rewriter.create<ExtractOp>(loc, vector, offset);
38 return rewriter.create<vector::ExtractElementOp>(
39 loc, vectorType.getElementType(), vector,
40 rewriter.create<arith::ConstantIndexOp>(loc, offset));
41}
42
43/// RewritePattern for InsertStridedSliceOp where source and destination vectors
44/// have different ranks.
45///
46/// When ranks are different, InsertStridedSlice needs to extract a properly
47/// ranked vector from the destination vector into which to insert. This pattern
48/// only takes care of this extraction part and forwards the rest to
49/// [ConvertSameRankInsertStridedSliceIntoShuffle].
50///
51/// For a k-D source and n-D destination vector (k < n), we emit:
52/// 1. ExtractOp to extract the (unique) (n-1)-D subvector into which to
53/// insert the k-D source.
54/// 2. k-D -> (n-1)-D InsertStridedSlice op
55/// 3. InsertOp that is the reverse of 1.
56class DecomposeDifferentRankInsertStridedSlice
57 : public OpRewritePattern<InsertStridedSliceOp> {
58public:
59 using OpRewritePattern<InsertStridedSliceOp>::OpRewritePattern;
60
61 LogicalResult matchAndRewrite(InsertStridedSliceOp op,
62 PatternRewriter &rewriter) const override {
63 auto srcType = op.getSourceVectorType();
64 auto dstType = op.getDestVectorType();
65
66 if (op.getOffsets().getValue().empty())
67 return failure();
68
69 auto loc = op.getLoc();
70 int64_t rankDiff = dstType.getRank() - srcType.getRank();
71 assert(rankDiff >= 0);
72 if (rankDiff == 0)
73 return failure();
74
75 int64_t rankRest = dstType.getRank() - rankDiff;
76 // Extract / insert the subvector of matching rank and InsertStridedSlice
77 // on it.
78 Value extracted = rewriter.create<ExtractOp>(
79 loc, op.getDest(),
80 getI64SubArray(op.getOffsets(), /*dropFront=*/0,
81 /*dropBack=*/rankRest));
82
83 // A different pattern will kick in for InsertStridedSlice with matching
84 // ranks.
85 auto stridedSliceInnerOp = rewriter.create<InsertStridedSliceOp>(
86 loc, op.getSource(), extracted,
87 getI64SubArray(op.getOffsets(), /*dropFront=*/rankDiff),
88 getI64SubArray(op.getStrides(), /*dropFront=*/0));
89
90 rewriter.replaceOpWithNewOp<InsertOp>(
91 op, stridedSliceInnerOp.getResult(), op.getDest(),
92 getI64SubArray(op.getOffsets(), /*dropFront=*/0,
93 /*dropBack=*/rankRest));
94 return success();
95 }
96};
97
98/// RewritePattern for InsertStridedSliceOp where source and destination vectors
99/// have the same rank. For each outermost index in the slice:
100/// begin end stride
101/// [offset : offset+size*stride : stride]
102/// 1. ExtractOp one (k-1)-D source subvector and one (n-1)-D dest subvector.
103/// 2. InsertStridedSlice (k-1)-D into (n-1)-D
104/// 3. the destination subvector is inserted back in the proper place
105/// 3. InsertOp that is the reverse of 1.
106class ConvertSameRankInsertStridedSliceIntoShuffle
107 : public OpRewritePattern<InsertStridedSliceOp> {
108public:
109 using OpRewritePattern<InsertStridedSliceOp>::OpRewritePattern;
110
111 void initialize() {
112 // This pattern creates recursive InsertStridedSliceOp, but the recursion is
113 // bounded as the rank is strictly decreasing.
114 setHasBoundedRewriteRecursion();
115 }
116
117 LogicalResult matchAndRewrite(InsertStridedSliceOp op,
118 PatternRewriter &rewriter) const override {
119 auto srcType = op.getSourceVectorType();
120 auto dstType = op.getDestVectorType();
121
122 if (op.getOffsets().getValue().empty())
123 return failure();
124
125 int64_t srcRank = srcType.getRank();
126 int64_t dstRank = dstType.getRank();
127 assert(dstRank >= srcRank);
128 if (dstRank != srcRank)
129 return failure();
130
131 if (srcType == dstType) {
132 rewriter.replaceOp(op, op.getSource());
133 return success();
134 }
135
136 int64_t offset =
137 cast<IntegerAttr>(op.getOffsets().getValue().front()).getInt();
138 int64_t size = srcType.getShape().front();
139 int64_t stride =
140 cast<IntegerAttr>(op.getStrides().getValue().front()).getInt();
141
142 auto loc = op.getLoc();
143 Value res = op.getDest();
144
145 if (srcRank == 1) {
146 int nSrc = srcType.getShape().front();
147 int nDest = dstType.getShape().front();
148 // 1. Scale source to destType so we can shufflevector them together.
149 SmallVector<int64_t> offsets(nDest, 0);
150 for (int64_t i = 0; i < nSrc; ++i)
151 offsets[i] = i;
152 Value scaledSource = rewriter.create<ShuffleOp>(loc, op.getSource(),
153 op.getSource(), offsets);
154
155 // 2. Create a mask where we take the value from scaledSource of dest
156 // depending on the offset.
157 offsets.clear();
158 for (int64_t i = 0, e = offset + size * stride; i < nDest; ++i) {
159 if (i < offset || i >= e || (i - offset) % stride != 0)
160 offsets.push_back(Elt: nDest + i);
161 else
162 offsets.push_back(Elt: (i - offset) / stride);
163 }
164
165 // 3. Replace with a ShuffleOp.
166 rewriter.replaceOpWithNewOp<ShuffleOp>(op, scaledSource, op.getDest(),
167 offsets);
168
169 return success();
170 }
171
172 // For each slice of the source vector along the most major dimension.
173 for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e;
174 off += stride, ++idx) {
175 // 1. extract the proper subvector (or element) from source
176 Value extractedSource = extractOne(rewriter, loc, op.getSource(), idx);
177 if (isa<VectorType>(Val: extractedSource.getType())) {
178 // 2. If we have a vector, extract the proper subvector from destination
179 // Otherwise we are at the element level and no need to recurse.
180 Value extractedDest = extractOne(rewriter, loc, op.getDest(), off);
181 // 3. Reduce the problem to lowering a new InsertStridedSlice op with
182 // smaller rank.
183 extractedSource = rewriter.create<InsertStridedSliceOp>(
184 loc, extractedSource, extractedDest,
185 getI64SubArray(op.getOffsets(), /* dropFront=*/1),
186 getI64SubArray(op.getStrides(), /* dropFront=*/1));
187 }
188 // 4. Insert the extractedSource into the res vector.
189 res = insertOne(rewriter, loc, extractedSource, res, off);
190 }
191
192 rewriter.replaceOp(op, res);
193 return success();
194 }
195};
196
197/// RewritePattern for ExtractStridedSliceOp where source and destination
198/// vectors are 1-D. For such cases, we can lower it to a ShuffleOp.
199class Convert1DExtractStridedSliceIntoShuffle
200 : public OpRewritePattern<ExtractStridedSliceOp> {
201public:
202 using OpRewritePattern<ExtractStridedSliceOp>::OpRewritePattern;
203
204 LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
205 PatternRewriter &rewriter) const override {
206 auto dstType = op.getType();
207
208 assert(!op.getOffsets().getValue().empty() && "Unexpected empty offsets");
209
210 int64_t offset =
211 cast<IntegerAttr>(op.getOffsets().getValue().front()).getInt();
212 int64_t size = cast<IntegerAttr>(op.getSizes().getValue().front()).getInt();
213 int64_t stride =
214 cast<IntegerAttr>(op.getStrides().getValue().front()).getInt();
215
216 assert(dstType.getElementType().isSignlessIntOrIndexOrFloat());
217
218 // Single offset can be more efficiently shuffled.
219 if (op.getOffsets().getValue().size() != 1)
220 return failure();
221
222 SmallVector<int64_t, 4> offsets;
223 offsets.reserve(N: size);
224 for (int64_t off = offset, e = offset + size * stride; off < e;
225 off += stride)
226 offsets.push_back(Elt: off);
227 rewriter.replaceOpWithNewOp<ShuffleOp>(op, dstType, op.getVector(),
228 op.getVector(),
229 rewriter.getI64ArrayAttr(offsets));
230 return success();
231 }
232};
233
234/// For a 1-D ExtractStridedSlice, breaks it down into a chain of Extract ops
235/// to extract each element from the source, and then a chain of Insert ops
236/// to insert to the target vector.
237class Convert1DExtractStridedSliceIntoExtractInsertChain final
238 : public OpRewritePattern<ExtractStridedSliceOp> {
239public:
240 Convert1DExtractStridedSliceIntoExtractInsertChain(
241 MLIRContext *context,
242 std::function<bool(ExtractStridedSliceOp)> controlFn,
243 PatternBenefit benefit)
244 : OpRewritePattern(context, benefit), controlFn(std::move(controlFn)) {}
245
246 LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
247 PatternRewriter &rewriter) const override {
248 if (controlFn && !controlFn(op))
249 return failure();
250
251 // Only handle 1-D cases.
252 if (op.getOffsets().getValue().size() != 1)
253 return failure();
254
255 int64_t offset =
256 cast<IntegerAttr>(op.getOffsets().getValue().front()).getInt();
257 int64_t size = cast<IntegerAttr>(op.getSizes().getValue().front()).getInt();
258 int64_t stride =
259 cast<IntegerAttr>(op.getStrides().getValue().front()).getInt();
260
261 Location loc = op.getLoc();
262 SmallVector<Value> elements;
263 elements.reserve(N: size);
264 for (int64_t i = offset, e = offset + size * stride; i < e; i += stride)
265 elements.push_back(rewriter.create<ExtractOp>(loc, op.getVector(), i));
266
267 Value result = rewriter.create<arith::ConstantOp>(
268 loc, rewriter.getZeroAttr(op.getType()));
269 for (int64_t i = 0; i < size; ++i)
270 result = rewriter.create<InsertOp>(loc, elements[i], result, i);
271
272 rewriter.replaceOp(op, result);
273 return success();
274 }
275
276private:
277 std::function<bool(ExtractStridedSliceOp)> controlFn;
278};
279
280/// RewritePattern for ExtractStridedSliceOp where the source vector is n-D.
281/// For such cases, we can rewrite it to ExtractOp/ExtractElementOp + lower
282/// rank ExtractStridedSliceOp + InsertOp/InsertElementOp for the n-D case.
283class DecomposeNDExtractStridedSlice
284 : public OpRewritePattern<ExtractStridedSliceOp> {
285public:
286 using OpRewritePattern<ExtractStridedSliceOp>::OpRewritePattern;
287
288 void initialize() {
289 // This pattern creates recursive ExtractStridedSliceOp, but the recursion
290 // is bounded as the rank is strictly decreasing.
291 setHasBoundedRewriteRecursion();
292 }
293
294 LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
295 PatternRewriter &rewriter) const override {
296 auto dstType = op.getType();
297
298 assert(!op.getOffsets().getValue().empty() && "Unexpected empty offsets");
299
300 int64_t offset =
301 cast<IntegerAttr>(op.getOffsets().getValue().front()).getInt();
302 int64_t size = cast<IntegerAttr>(op.getSizes().getValue().front()).getInt();
303 int64_t stride =
304 cast<IntegerAttr>(op.getStrides().getValue().front()).getInt();
305
306 auto loc = op.getLoc();
307 auto elemType = dstType.getElementType();
308 assert(elemType.isSignlessIntOrIndexOrFloat());
309
310 // Single offset can be more efficiently shuffled. It's handled in
311 // Convert1DExtractStridedSliceIntoShuffle.
312 if (op.getOffsets().getValue().size() == 1)
313 return failure();
314
315 // Extract/insert on a lower ranked extract strided slice op.
316 Value zero = rewriter.create<arith::ConstantOp>(
317 loc, elemType, rewriter.getZeroAttr(elemType));
318 Value res = rewriter.create<SplatOp>(loc, dstType, zero);
319 for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e;
320 off += stride, ++idx) {
321 Value one = extractOne(rewriter, loc, op.getVector(), off);
322 Value extracted = rewriter.create<ExtractStridedSliceOp>(
323 loc, one, getI64SubArray(op.getOffsets(), /* dropFront=*/1),
324 getI64SubArray(op.getSizes(), /* dropFront=*/1),
325 getI64SubArray(op.getStrides(), /* dropFront=*/1));
326 res = insertOne(rewriter, loc, extracted, res, idx);
327 }
328 rewriter.replaceOp(op, res);
329 return success();
330 }
331};
332
333void vector::populateVectorInsertExtractStridedSliceDecompositionPatterns(
334 RewritePatternSet &patterns, PatternBenefit benefit) {
335 patterns.add<DecomposeDifferentRankInsertStridedSlice,
336 DecomposeNDExtractStridedSlice>(arg: patterns.getContext(), args&: benefit);
337}
338
339void vector::populateVectorExtractStridedSliceToExtractInsertChainPatterns(
340 RewritePatternSet &patterns,
341 std::function<bool(ExtractStridedSliceOp)> controlFn,
342 PatternBenefit benefit) {
343 patterns.add<Convert1DExtractStridedSliceIntoExtractInsertChain>(
344 patterns.getContext(), std::move(controlFn), benefit);
345}
346
347/// Populate the given list with patterns that convert from Vector to LLVM.
348void vector::populateVectorInsertExtractStridedSliceTransforms(
349 RewritePatternSet &patterns, PatternBenefit benefit) {
350 populateVectorInsertExtractStridedSliceDecompositionPatterns(patterns,
351 benefit);
352 patterns.add<ConvertSameRankInsertStridedSliceIntoShuffle,
353 Convert1DExtractStridedSliceIntoShuffle>(arg: patterns.getContext(),
354 args&: benefit);
355}
356

source code of mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp