1//===- VectorInsertExtractStridedSliceRewritePatterns.cpp - Rewrites ------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/Arith/IR/Arith.h"
10#include "mlir/Dialect/MemRef/IR/MemRef.h"
11#include "mlir/Dialect/Utils/IndexingUtils.h"
12#include "mlir/Dialect/Vector/IR/VectorOps.h"
13#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
14#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
15#include "mlir/IR/BuiltinTypes.h"
16#include "mlir/IR/PatternMatch.h"
17
18using namespace mlir;
19using namespace mlir::vector;
20
21/// RewritePattern for InsertStridedSliceOp where source and destination vectors
22/// have different ranks.
23///
24/// When ranks are different, InsertStridedSlice needs to extract a properly
25/// ranked vector from the destination vector into which to insert. This pattern
26/// only takes care of this extraction part and forwards the rest to
27/// [ConvertSameRankInsertStridedSliceIntoShuffle].
28///
29/// For a k-D source and n-D destination vector (k < n), we emit:
30/// 1. ExtractOp to extract the (unique) (n-1)-D subvector into which to
31/// insert the k-D source.
32/// 2. k-D -> (n-1)-D InsertStridedSlice op
33/// 3. InsertOp that is the reverse of 1.
34class DecomposeDifferentRankInsertStridedSlice
35 : public OpRewritePattern<InsertStridedSliceOp> {
36public:
37 using OpRewritePattern<InsertStridedSliceOp>::OpRewritePattern;
38
39 LogicalResult matchAndRewrite(InsertStridedSliceOp op,
40 PatternRewriter &rewriter) const override {
41 auto srcType = op.getSourceVectorType();
42 auto dstType = op.getDestVectorType();
43
44 if (op.getOffsets().getValue().empty())
45 return failure();
46
47 auto loc = op.getLoc();
48 int64_t rankDiff = dstType.getRank() - srcType.getRank();
49 assert(rankDiff >= 0);
50 if (rankDiff == 0)
51 return failure();
52
53 int64_t rankRest = dstType.getRank() - rankDiff;
54 // Extract / insert the subvector of matching rank and InsertStridedSlice
55 // on it.
56 Value extracted = rewriter.create<ExtractOp>(
57 loc, op.getDest(),
58 getI64SubArray(op.getOffsets(), /*dropFront=*/0,
59 /*dropBack=*/rankRest));
60
61 // A different pattern will kick in for InsertStridedSlice with matching
62 // ranks.
63 auto stridedSliceInnerOp = rewriter.create<InsertStridedSliceOp>(
64 loc, op.getValueToStore(), extracted,
65 getI64SubArray(op.getOffsets(), /*dropFront=*/rankDiff),
66 getI64SubArray(op.getStrides(), /*dropFront=*/0));
67
68 rewriter.replaceOpWithNewOp<InsertOp>(
69 op, stridedSliceInnerOp.getResult(), op.getDest(),
70 getI64SubArray(op.getOffsets(), /*dropFront=*/0,
71 /*dropBack=*/rankRest));
72 return success();
73 }
74};
75
76/// RewritePattern for InsertStridedSliceOp where source and destination vectors
77/// have the same rank. For each outermost index in the slice:
78/// begin end stride
79/// [offset : offset+size*stride : stride]
80/// 1. ExtractOp one (k-1)-D source subvector and one (n-1)-D dest subvector.
81/// 2. InsertStridedSlice (k-1)-D into (n-1)-D
82/// 3. the destination subvector is inserted back in the proper place
83/// 3. InsertOp that is the reverse of 1.
84class ConvertSameRankInsertStridedSliceIntoShuffle
85 : public OpRewritePattern<InsertStridedSliceOp> {
86public:
87 using OpRewritePattern<InsertStridedSliceOp>::OpRewritePattern;
88
89 void initialize() {
90 // This pattern creates recursive InsertStridedSliceOp, but the recursion is
91 // bounded as the rank is strictly decreasing.
92 setHasBoundedRewriteRecursion();
93 }
94
95 LogicalResult matchAndRewrite(InsertStridedSliceOp op,
96 PatternRewriter &rewriter) const override {
97 auto srcType = op.getSourceVectorType();
98 auto dstType = op.getDestVectorType();
99 int64_t srcRank = srcType.getRank();
100
101 // Scalable vectors are not supported by vector shuffle.
102 if ((srcType.isScalable() || dstType.isScalable()) && srcRank == 1)
103 return failure();
104
105 if (op.getOffsets().getValue().empty())
106 return failure();
107
108 int64_t dstRank = dstType.getRank();
109 assert(dstRank >= srcRank);
110 if (dstRank != srcRank)
111 return failure();
112
113 if (srcType == dstType) {
114 rewriter.replaceOp(op, op.getValueToStore());
115 return success();
116 }
117
118 int64_t offset =
119 cast<IntegerAttr>(op.getOffsets().getValue().front()).getInt();
120 int64_t size = srcType.getShape().front();
121 int64_t stride =
122 cast<IntegerAttr>(op.getStrides().getValue().front()).getInt();
123
124 auto loc = op.getLoc();
125 Value res = op.getDest();
126
127 if (srcRank == 1) {
128 int nSrc = srcType.getShape().front();
129 int nDest = dstType.getShape().front();
130 // 1. Scale source to destType so we can shufflevector them together.
131 SmallVector<int64_t> offsets(nDest, 0);
132 for (int64_t i = 0; i < nSrc; ++i)
133 offsets[i] = i;
134 Value scaledSource = rewriter.create<ShuffleOp>(
135 loc, op.getValueToStore(), op.getValueToStore(), offsets);
136
137 // 2. Create a mask where we take the value from scaledSource of dest
138 // depending on the offset.
139 offsets.clear();
140 for (int64_t i = 0, e = offset + size * stride; i < nDest; ++i) {
141 if (i < offset || i >= e || (i - offset) % stride != 0)
142 offsets.push_back(Elt: nDest + i);
143 else
144 offsets.push_back(Elt: (i - offset) / stride);
145 }
146
147 // 3. Replace with a ShuffleOp.
148 rewriter.replaceOpWithNewOp<ShuffleOp>(op, scaledSource, op.getDest(),
149 offsets);
150
151 return success();
152 }
153
154 // For each slice of the source vector along the most major dimension.
155 for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e;
156 off += stride, ++idx) {
157 // 1. extract the proper subvector (or element) from source
158 Value extractedSource =
159 rewriter.create<ExtractOp>(loc, op.getValueToStore(), idx);
160 if (isa<VectorType>(Val: extractedSource.getType())) {
161 // 2. If we have a vector, extract the proper subvector from destination
162 // Otherwise we are at the element level and no need to recurse.
163 Value extractedDest =
164 rewriter.create<ExtractOp>(loc, op.getDest(), off);
165 // 3. Reduce the problem to lowering a new InsertStridedSlice op with
166 // smaller rank.
167 extractedSource = rewriter.create<InsertStridedSliceOp>(
168 loc, extractedSource, extractedDest,
169 getI64SubArray(op.getOffsets(), /* dropFront=*/1),
170 getI64SubArray(op.getStrides(), /* dropFront=*/1));
171 }
172 // 4. Insert the extractedSource into the res vector.
173 res = rewriter.create<InsertOp>(loc, extractedSource, res, off);
174 }
175
176 rewriter.replaceOp(op, res);
177 return success();
178 }
179};
180
181/// RewritePattern for ExtractStridedSliceOp where source and destination
182/// vectors are 1-D. For such cases, we can lower it to a ShuffleOp.
183class Convert1DExtractStridedSliceIntoShuffle
184 : public OpRewritePattern<ExtractStridedSliceOp> {
185public:
186 using OpRewritePattern<ExtractStridedSliceOp>::OpRewritePattern;
187
188 LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
189 PatternRewriter &rewriter) const override {
190 auto dstType = op.getType();
191 auto srcType = op.getSourceVectorType();
192
193 // Scalable vectors are not supported by vector shuffle.
194 if (dstType.isScalable() || srcType.isScalable())
195 return failure();
196
197 assert(!op.getOffsets().getValue().empty() && "Unexpected empty offsets");
198
199 int64_t offset =
200 cast<IntegerAttr>(op.getOffsets().getValue().front()).getInt();
201 int64_t size = cast<IntegerAttr>(op.getSizes().getValue().front()).getInt();
202 int64_t stride =
203 cast<IntegerAttr>(op.getStrides().getValue().front()).getInt();
204
205 assert(dstType.getElementType().isSignlessIntOrIndexOrFloat());
206
207 // Single offset can be more efficiently shuffled.
208 if (op.getOffsets().getValue().size() != 1)
209 return failure();
210
211 SmallVector<int64_t, 4> offsets;
212 offsets.reserve(N: size);
213 for (int64_t off = offset, e = offset + size * stride; off < e;
214 off += stride)
215 offsets.push_back(Elt: off);
216 rewriter.replaceOpWithNewOp<ShuffleOp>(op, dstType, op.getVector(),
217 op.getVector(), offsets);
218 return success();
219 }
220};
221
222/// For a 1-D ExtractStridedSlice, breaks it down into a chain of Extract ops
223/// to extract each element from the source, and then a chain of Insert ops
224/// to insert to the target vector.
225class Convert1DExtractStridedSliceIntoExtractInsertChain final
226 : public OpRewritePattern<ExtractStridedSliceOp> {
227public:
228 Convert1DExtractStridedSliceIntoExtractInsertChain(
229 MLIRContext *context,
230 std::function<bool(ExtractStridedSliceOp)> controlFn,
231 PatternBenefit benefit)
232 : OpRewritePattern(context, benefit), controlFn(std::move(controlFn)) {}
233
234 LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
235 PatternRewriter &rewriter) const override {
236 if (controlFn && !controlFn(op))
237 return failure();
238
239 // Only handle 1-D cases.
240 if (op.getOffsets().getValue().size() != 1)
241 return failure();
242
243 int64_t offset =
244 cast<IntegerAttr>(op.getOffsets().getValue().front()).getInt();
245 int64_t size = cast<IntegerAttr>(op.getSizes().getValue().front()).getInt();
246 int64_t stride =
247 cast<IntegerAttr>(op.getStrides().getValue().front()).getInt();
248
249 Location loc = op.getLoc();
250 SmallVector<Value> elements;
251 elements.reserve(N: size);
252 for (int64_t i = offset, e = offset + size * stride; i < e; i += stride)
253 elements.push_back(rewriter.create<ExtractOp>(loc, op.getVector(), i));
254
255 Value result = rewriter.create<arith::ConstantOp>(
256 loc, rewriter.getZeroAttr(op.getType()));
257 for (int64_t i = 0; i < size; ++i)
258 result = rewriter.create<InsertOp>(loc, elements[i], result, i);
259
260 rewriter.replaceOp(op, result);
261 return success();
262 }
263
264private:
265 std::function<bool(ExtractStridedSliceOp)> controlFn;
266};
267
268/// RewritePattern for ExtractStridedSliceOp where the source vector is n-D.
269/// For such cases, we can rewrite it to ExtractOp + lower rank
270/// ExtractStridedSliceOp + InsertOp for the n-D case.
271class DecomposeNDExtractStridedSlice
272 : public OpRewritePattern<ExtractStridedSliceOp> {
273public:
274 using OpRewritePattern<ExtractStridedSliceOp>::OpRewritePattern;
275
276 void initialize() {
277 // This pattern creates recursive ExtractStridedSliceOp, but the recursion
278 // is bounded as the rank is strictly decreasing.
279 setHasBoundedRewriteRecursion();
280 }
281
282 LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
283 PatternRewriter &rewriter) const override {
284 auto dstType = op.getType();
285
286 assert(!op.getOffsets().getValue().empty() && "Unexpected empty offsets");
287
288 int64_t offset =
289 cast<IntegerAttr>(op.getOffsets().getValue().front()).getInt();
290 int64_t size = cast<IntegerAttr>(op.getSizes().getValue().front()).getInt();
291 int64_t stride =
292 cast<IntegerAttr>(op.getStrides().getValue().front()).getInt();
293
294 auto loc = op.getLoc();
295 auto elemType = dstType.getElementType();
296 assert(elemType.isSignlessIntOrIndexOrFloat());
297
298 // Single offset can be more efficiently shuffled. It's handled in
299 // Convert1DExtractStridedSliceIntoShuffle.
300 if (op.getOffsets().getValue().size() == 1)
301 return failure();
302
303 // Extract/insert on a lower ranked extract strided slice op.
304 Value zero = rewriter.create<arith::ConstantOp>(
305 loc, elemType, rewriter.getZeroAttr(elemType));
306 Value res = rewriter.create<SplatOp>(loc, dstType, zero);
307 for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e;
308 off += stride, ++idx) {
309 Value one = rewriter.create<ExtractOp>(loc, op.getVector(), off);
310 Value extracted = rewriter.create<ExtractStridedSliceOp>(
311 loc, one, getI64SubArray(op.getOffsets(), /* dropFront=*/1),
312 getI64SubArray(op.getSizes(), /* dropFront=*/1),
313 getI64SubArray(op.getStrides(), /* dropFront=*/1));
314 res = rewriter.create<InsertOp>(loc, extracted, res, idx);
315 }
316 rewriter.replaceOp(op, res);
317 return success();
318 }
319};
320
321// TODO: Make sure these `populate*` patterns are tested in isolation.
322
323void vector::populateVectorInsertExtractStridedSliceDecompositionPatterns(
324 RewritePatternSet &patterns, PatternBenefit benefit) {
325 patterns.add<DecomposeDifferentRankInsertStridedSlice,
326 DecomposeNDExtractStridedSlice>(arg: patterns.getContext(), args&: benefit);
327}
328
329void vector::populateVectorExtractStridedSliceToExtractInsertChainPatterns(
330 RewritePatternSet &patterns,
331 std::function<bool(ExtractStridedSliceOp)> controlFn,
332 PatternBenefit benefit) {
333 patterns.add<Convert1DExtractStridedSliceIntoExtractInsertChain>(
334 patterns.getContext(), std::move(controlFn), benefit);
335}
336
337/// Populate the given list with patterns that convert from Vector to LLVM.
338void vector::populateVectorInsertExtractStridedSliceTransforms(
339 RewritePatternSet &patterns, PatternBenefit benefit) {
340 populateVectorInsertExtractStridedSliceDecompositionPatterns(patterns,
341 benefit);
342 patterns.add<ConvertSameRankInsertStridedSliceIntoShuffle,
343 Convert1DExtractStridedSliceIntoShuffle>(arg: patterns.getContext(),
344 args&: benefit);
345 // Generate chains of extract/insert ops for scalable vectors only as they
346 // can't be lowered to vector shuffles.
347 populateVectorExtractStridedSliceToExtractInsertChainPatterns(
348 patterns,
349 /*controlFn=*/
350 [](ExtractStridedSliceOp op) {
351 return op.getType().isScalable() ||
352 op.getSourceVectorType().isScalable();
353 },
354 benefit);
355}
356

Provided by KDAB

Privacy Policy
Update your C++ knowledge – Modern C++11/14/17 Training
Find out more

source code of mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp