1 | //===- LowerVectorTranspose.cpp - Lower 'vector.transpose' operation ------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This file implements target-independent rewrites and utilities to lower the |
10 | // 'vector.transpose' operation. |
11 | // |
12 | //===----------------------------------------------------------------------===// |
13 | |
14 | #include "mlir/Dialect/Affine/IR/AffineOps.h" |
15 | #include "mlir/Dialect/Arith/IR/Arith.h" |
16 | #include "mlir/Dialect/Arith/Utils/Utils.h" |
17 | #include "mlir/Dialect/Linalg/IR/Linalg.h" |
18 | #include "mlir/Dialect/MemRef/IR/MemRef.h" |
19 | #include "mlir/Dialect/SCF/IR/SCF.h" |
20 | #include "mlir/Dialect/Tensor/IR/Tensor.h" |
21 | #include "mlir/Dialect/Utils/IndexingUtils.h" |
22 | #include "mlir/Dialect/Utils/StructuredOpsUtils.h" |
23 | #include "mlir/Dialect/Vector/IR/VectorOps.h" |
24 | #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h" |
25 | #include "mlir/Dialect/Vector/Utils/VectorUtils.h" |
26 | #include "mlir/IR/BuiltinAttributeInterfaces.h" |
27 | #include "mlir/IR/BuiltinTypes.h" |
28 | #include "mlir/IR/ImplicitLocOpBuilder.h" |
29 | #include "mlir/IR/Location.h" |
30 | #include "mlir/IR/Matchers.h" |
31 | #include "mlir/IR/PatternMatch.h" |
32 | #include "mlir/IR/TypeUtilities.h" |
33 | #include "mlir/Interfaces/VectorInterfaces.h" |
34 | #include "mlir/Support/LogicalResult.h" |
35 | |
36 | #define DEBUG_TYPE "lower-vector-transpose" |
37 | |
38 | using namespace mlir; |
39 | using namespace mlir::vector; |
40 | |
41 | /// Given a 'transpose' pattern, prune the rightmost dimensions that are not |
42 | /// transposed. |
43 | static void pruneNonTransposedDims(ArrayRef<int64_t> transpose, |
44 | SmallVectorImpl<int64_t> &result) { |
45 | size_t numTransposedDims = transpose.size(); |
46 | for (size_t transpDim : llvm::reverse(C&: transpose)) { |
47 | if (transpDim != numTransposedDims - 1) |
48 | break; |
49 | numTransposedDims--; |
50 | } |
51 | |
52 | result.append(in_start: transpose.begin(), in_end: transpose.begin() + numTransposedDims); |
53 | } |
54 | |
55 | /// Returns true if the lowering option is a vector shuffle based approach. |
56 | static bool isShuffleLike(VectorTransposeLowering lowering) { |
57 | return lowering == VectorTransposeLowering::Shuffle1D || |
58 | lowering == VectorTransposeLowering::Shuffle16x16; |
59 | } |
60 | |
61 | /// Returns a shuffle mask that builds on `vals`. `vals` is the offset base of |
62 | /// shuffle ops, i.e., the unpack pattern. The method iterates with `vals` to |
63 | /// create the mask for `numBits` bits vector. The `numBits` have to be a |
64 | /// multiple of 128. For example, if `vals` is {0, 1, 16, 17} and `numBits` is |
65 | /// 512, there should be 16 elements in the final result. It constructs the |
66 | /// below mask to get the unpack elements. |
67 | /// [0, 1, 16, 17, |
68 | /// 0+4, 1+4, 16+4, 17+4, |
69 | /// 0+8, 1+8, 16+8, 17+8, |
70 | /// 0+12, 1+12, 16+12, 17+12] |
71 | static SmallVector<int64_t> |
72 | getUnpackShufflePermFor128Lane(ArrayRef<int64_t> vals, int numBits) { |
73 | assert(numBits % 128 == 0 && "expected numBits is a multiple of 128" ); |
74 | int numElem = numBits / 32; |
75 | SmallVector<int64_t> res; |
76 | for (int i = 0; i < numElem; i += 4) |
77 | for (int64_t v : vals) |
78 | res.push_back(Elt: v + i); |
79 | return res; |
80 | } |
81 | |
82 | /// Lower to vector.shuffle on v1 and v2 with UnpackLoPd shuffle mask. For |
83 | /// example, if it is targeting 512 bit vector, returns |
84 | /// vector.shuffle on v1, v2, [0, 1, 16, 17, |
85 | /// 0+4, 1+4, 16+4, 17+4, |
86 | /// 0+8, 1+8, 16+8, 17+8, |
87 | /// 0+12, 1+12, 16+12, 17+12]. |
88 | static Value createUnpackLoPd(ImplicitLocOpBuilder &b, Value v1, Value v2, |
89 | int numBits) { |
90 | int numElem = numBits / 32; |
91 | return b.create<vector::ShuffleOp>( |
92 | v1, v2, |
93 | getUnpackShufflePermFor128Lane({0, 1, numElem, numElem + 1}, numBits)); |
94 | } |
95 | |
96 | /// Lower to vector.shuffle on v1 and v2 with UnpackHiPd shuffle mask. For |
97 | /// example, if it is targeting 512 bit vector, returns |
98 | /// vector.shuffle, v1, v2, [2, 3, 18, 19, |
99 | /// 2+4, 3+4, 18+4, 19+4, |
100 | /// 2+8, 3+8, 18+8, 19+8, |
101 | /// 2+12, 3+12, 18+12, 19+12]. |
102 | static Value createUnpackHiPd(ImplicitLocOpBuilder &b, Value v1, Value v2, |
103 | int numBits) { |
104 | int numElem = numBits / 32; |
105 | return b.create<vector::ShuffleOp>( |
106 | v1, v2, |
107 | getUnpackShufflePermFor128Lane({2, 3, numElem + 2, numElem + 3}, |
108 | numBits)); |
109 | } |
110 | |
111 | /// Lower to vector.shuffle on v1 and v2 with UnpackLoPs shuffle mask. For |
112 | /// example, if it is targeting 512 bit vector, returns |
113 | /// vector.shuffle, v1, v2, [0, 16, 1, 17, |
114 | /// 0+4, 16+4, 1+4, 17+4, |
115 | /// 0+8, 16+8, 1+8, 17+8, |
116 | /// 0+12, 16+12, 1+12, 17+12]. |
117 | static Value createUnpackLoPs(ImplicitLocOpBuilder &b, Value v1, Value v2, |
118 | int numBits) { |
119 | int numElem = numBits / 32; |
120 | auto shuffle = b.create<vector::ShuffleOp>( |
121 | v1, v2, |
122 | getUnpackShufflePermFor128Lane({0, numElem, 1, numElem + 1}, numBits)); |
123 | return shuffle; |
124 | } |
125 | |
126 | /// Lower to vector.shuffle on v1 and v2 with UnpackHiPs shuffle mask. For |
127 | /// example, if it is targeting 512 bit vector, returns |
128 | /// vector.shuffle, v1, v2, [2, 18, 3, 19, |
129 | /// 2+4, 18+4, 3+4, 19+4, |
130 | /// 2+8, 18+8, 3+8, 19+8, |
131 | /// 2+12, 18+12, 3+12, 19+12]. |
132 | static Value createUnpackHiPs(ImplicitLocOpBuilder &b, Value v1, Value v2, |
133 | int numBits) { |
134 | int numElem = numBits / 32; |
135 | return b.create<vector::ShuffleOp>( |
136 | v1, v2, |
137 | getUnpackShufflePermFor128Lane({2, numElem + 2, 3, numElem + 3}, |
138 | numBits)); |
139 | } |
140 | |
141 | /// Returns a vector.shuffle that shuffles 128-bit lanes (composed of 4 32-bit |
142 | /// elements) selected by `mask` from `v1` and `v2`. I.e., |
143 | /// |
144 | /// DEFINE SELECT4(src, control) { |
145 | /// CASE(control[1:0]) OF |
146 | /// 0: tmp[127:0] := src[127:0] |
147 | /// 1: tmp[127:0] := src[255:128] |
148 | /// 2: tmp[127:0] := src[383:256] |
149 | /// 3: tmp[127:0] := src[511:384] |
150 | /// ESAC |
151 | /// RETURN tmp[127:0] |
152 | /// } |
153 | /// dst[127:0] := SELECT4(v1[511:0], mask[1:0]) |
154 | /// dst[255:128] := SELECT4(v1[511:0], mask[3:2]) |
155 | /// dst[383:256] := SELECT4(v2[511:0], mask[5:4]) |
156 | /// dst[511:384] := SELECT4(v2[511:0], mask[7:6]) |
157 | static Value create4x128BitSuffle(ImplicitLocOpBuilder &b, Value v1, Value v2, |
158 | uint8_t mask) { |
159 | assert(cast<VectorType>(v1.getType()).getShape()[0] == 16 && |
160 | "expected a vector with length=16" ); |
161 | SmallVector<int64_t> shuffleMask; |
162 | auto appendToMask = [&](int64_t base, uint8_t control) { |
163 | switch (control) { |
164 | case 0: |
165 | llvm::append_range(C&: shuffleMask, R: ArrayRef<int64_t>{base + 0, base + 1, |
166 | base + 2, base + 3}); |
167 | break; |
168 | case 1: |
169 | llvm::append_range(C&: shuffleMask, R: ArrayRef<int64_t>{base + 4, base + 5, |
170 | base + 6, base + 7}); |
171 | break; |
172 | case 2: |
173 | llvm::append_range(C&: shuffleMask, R: ArrayRef<int64_t>{base + 8, base + 9, |
174 | base + 10, base + 11}); |
175 | break; |
176 | case 3: |
177 | llvm::append_range(C&: shuffleMask, R: ArrayRef<int64_t>{base + 12, base + 13, |
178 | base + 14, base + 15}); |
179 | break; |
180 | default: |
181 | llvm_unreachable("control > 3 : overflow" ); |
182 | } |
183 | }; |
184 | uint8_t b01 = mask & 0x3; |
185 | uint8_t b23 = (mask >> 2) & 0x3; |
186 | uint8_t b45 = (mask >> 4) & 0x3; |
187 | uint8_t b67 = (mask >> 6) & 0x3; |
188 | appendToMask(0, b01); |
189 | appendToMask(0, b23); |
190 | appendToMask(16, b45); |
191 | appendToMask(16, b67); |
192 | return b.create<vector::ShuffleOp>(v1, v2, shuffleMask); |
193 | } |
194 | |
195 | /// Lowers the value to a vector.shuffle op. The `source` is expected to be a |
196 | /// 1-D vector and have `m`x`n` elements. |
197 | static Value transposeToShuffle1D(OpBuilder &b, Value source, int m, int n) { |
198 | SmallVector<int64_t> mask; |
199 | mask.reserve(N: m * n); |
200 | for (int64_t j = 0; j < n; ++j) |
201 | for (int64_t i = 0; i < m; ++i) |
202 | mask.push_back(Elt: i * n + j); |
203 | return b.create<vector::ShuffleOp>(source.getLoc(), source, source, mask); |
204 | } |
205 | |
206 | /// Lowers the value to a sequence of vector.shuffle ops. The `source` is |
207 | /// expected to be a 16x16 vector. |
208 | static Value transposeToShuffle16x16(OpBuilder &builder, Value source, int m, |
209 | int n) { |
210 | ImplicitLocOpBuilder b(source.getLoc(), builder); |
211 | SmallVector<Value> vs; |
212 | for (int64_t i = 0; i < m; ++i) |
213 | vs.push_back(b.create<vector::ExtractOp>(source, i)); |
214 | |
215 | // Interleave 32-bit lanes using |
216 | // 8x _mm512_unpacklo_epi32 |
217 | // 8x _mm512_unpackhi_epi32 |
218 | Value t0 = createUnpackLoPs(b, v1: vs[0x0], v2: vs[0x1], numBits: 512); |
219 | Value t1 = createUnpackHiPs(b, v1: vs[0x0], v2: vs[0x1], numBits: 512); |
220 | Value t2 = createUnpackLoPs(b, v1: vs[0x2], v2: vs[0x3], numBits: 512); |
221 | Value t3 = createUnpackHiPs(b, v1: vs[0x2], v2: vs[0x3], numBits: 512); |
222 | Value t4 = createUnpackLoPs(b, v1: vs[0x4], v2: vs[0x5], numBits: 512); |
223 | Value t5 = createUnpackHiPs(b, v1: vs[0x4], v2: vs[0x5], numBits: 512); |
224 | Value t6 = createUnpackLoPs(b, v1: vs[0x6], v2: vs[0x7], numBits: 512); |
225 | Value t7 = createUnpackHiPs(b, v1: vs[0x6], v2: vs[0x7], numBits: 512); |
226 | Value t8 = createUnpackLoPs(b, v1: vs[0x8], v2: vs[0x9], numBits: 512); |
227 | Value t9 = createUnpackHiPs(b, v1: vs[0x8], v2: vs[0x9], numBits: 512); |
228 | Value ta = createUnpackLoPs(b, v1: vs[0xa], v2: vs[0xb], numBits: 512); |
229 | Value tb = createUnpackHiPs(b, v1: vs[0xa], v2: vs[0xb], numBits: 512); |
230 | Value tc = createUnpackLoPs(b, v1: vs[0xc], v2: vs[0xd], numBits: 512); |
231 | Value td = createUnpackHiPs(b, v1: vs[0xc], v2: vs[0xd], numBits: 512); |
232 | Value te = createUnpackLoPs(b, v1: vs[0xe], v2: vs[0xf], numBits: 512); |
233 | Value tf = createUnpackHiPs(b, v1: vs[0xe], v2: vs[0xf], numBits: 512); |
234 | |
235 | // Interleave 64-bit lanes using |
236 | // 8x _mm512_unpacklo_epi64 |
237 | // 8x _mm512_unpackhi_epi64 |
238 | Value r0 = createUnpackLoPd(b, v1: t0, v2: t2, numBits: 512); |
239 | Value r1 = createUnpackHiPd(b, v1: t0, v2: t2, numBits: 512); |
240 | Value r2 = createUnpackLoPd(b, v1: t1, v2: t3, numBits: 512); |
241 | Value r3 = createUnpackHiPd(b, v1: t1, v2: t3, numBits: 512); |
242 | Value r4 = createUnpackLoPd(b, v1: t4, v2: t6, numBits: 512); |
243 | Value r5 = createUnpackHiPd(b, v1: t4, v2: t6, numBits: 512); |
244 | Value r6 = createUnpackLoPd(b, v1: t5, v2: t7, numBits: 512); |
245 | Value r7 = createUnpackHiPd(b, v1: t5, v2: t7, numBits: 512); |
246 | Value r8 = createUnpackLoPd(b, v1: t8, v2: ta, numBits: 512); |
247 | Value r9 = createUnpackHiPd(b, v1: t8, v2: ta, numBits: 512); |
248 | Value ra = createUnpackLoPd(b, v1: t9, v2: tb, numBits: 512); |
249 | Value rb = createUnpackHiPd(b, v1: t9, v2: tb, numBits: 512); |
250 | Value rc = createUnpackLoPd(b, v1: tc, v2: te, numBits: 512); |
251 | Value rd = createUnpackHiPd(b, v1: tc, v2: te, numBits: 512); |
252 | Value re = createUnpackLoPd(b, v1: td, v2: tf, numBits: 512); |
253 | Value rf = createUnpackHiPd(b, v1: td, v2: tf, numBits: 512); |
254 | |
255 | // Permute 128-bit lanes using |
256 | // 16x _mm512_shuffle_i32x4 |
257 | t0 = create4x128BitSuffle(b, v1: r0, v2: r4, mask: 0x88); |
258 | t1 = create4x128BitSuffle(b, v1: r1, v2: r5, mask: 0x88); |
259 | t2 = create4x128BitSuffle(b, v1: r2, v2: r6, mask: 0x88); |
260 | t3 = create4x128BitSuffle(b, v1: r3, v2: r7, mask: 0x88); |
261 | t4 = create4x128BitSuffle(b, v1: r0, v2: r4, mask: 0xdd); |
262 | t5 = create4x128BitSuffle(b, v1: r1, v2: r5, mask: 0xdd); |
263 | t6 = create4x128BitSuffle(b, v1: r2, v2: r6, mask: 0xdd); |
264 | t7 = create4x128BitSuffle(b, v1: r3, v2: r7, mask: 0xdd); |
265 | t8 = create4x128BitSuffle(b, v1: r8, v2: rc, mask: 0x88); |
266 | t9 = create4x128BitSuffle(b, v1: r9, v2: rd, mask: 0x88); |
267 | ta = create4x128BitSuffle(b, v1: ra, v2: re, mask: 0x88); |
268 | tb = create4x128BitSuffle(b, v1: rb, v2: rf, mask: 0x88); |
269 | tc = create4x128BitSuffle(b, v1: r8, v2: rc, mask: 0xdd); |
270 | td = create4x128BitSuffle(b, v1: r9, v2: rd, mask: 0xdd); |
271 | te = create4x128BitSuffle(b, v1: ra, v2: re, mask: 0xdd); |
272 | tf = create4x128BitSuffle(b, v1: rb, v2: rf, mask: 0xdd); |
273 | |
274 | // Permute 256-bit lanes using again |
275 | // 16x _mm512_shuffle_i32x4 |
276 | vs[0x0] = create4x128BitSuffle(b, v1: t0, v2: t8, mask: 0x88); |
277 | vs[0x1] = create4x128BitSuffle(b, v1: t1, v2: t9, mask: 0x88); |
278 | vs[0x2] = create4x128BitSuffle(b, v1: t2, v2: ta, mask: 0x88); |
279 | vs[0x3] = create4x128BitSuffle(b, v1: t3, v2: tb, mask: 0x88); |
280 | vs[0x4] = create4x128BitSuffle(b, v1: t4, v2: tc, mask: 0x88); |
281 | vs[0x5] = create4x128BitSuffle(b, v1: t5, v2: td, mask: 0x88); |
282 | vs[0x6] = create4x128BitSuffle(b, v1: t6, v2: te, mask: 0x88); |
283 | vs[0x7] = create4x128BitSuffle(b, v1: t7, v2: tf, mask: 0x88); |
284 | vs[0x8] = create4x128BitSuffle(b, v1: t0, v2: t8, mask: 0xdd); |
285 | vs[0x9] = create4x128BitSuffle(b, v1: t1, v2: t9, mask: 0xdd); |
286 | vs[0xa] = create4x128BitSuffle(b, v1: t2, v2: ta, mask: 0xdd); |
287 | vs[0xb] = create4x128BitSuffle(b, v1: t3, v2: tb, mask: 0xdd); |
288 | vs[0xc] = create4x128BitSuffle(b, v1: t4, v2: tc, mask: 0xdd); |
289 | vs[0xd] = create4x128BitSuffle(b, v1: t5, v2: td, mask: 0xdd); |
290 | vs[0xe] = create4x128BitSuffle(b, v1: t6, v2: te, mask: 0xdd); |
291 | vs[0xf] = create4x128BitSuffle(b, v1: t7, v2: tf, mask: 0xdd); |
292 | |
293 | auto reshInputType = VectorType::get( |
294 | {m, n}, cast<VectorType>(source.getType()).getElementType()); |
295 | Value res = |
296 | b.create<arith::ConstantOp>(reshInputType, b.getZeroAttr(reshInputType)); |
297 | for (int64_t i = 0; i < m; ++i) |
298 | res = b.create<vector::InsertOp>(vs[i], res, i); |
299 | return res; |
300 | } |
301 | |
302 | namespace { |
303 | /// Progressive lowering of TransposeOp. |
304 | /// One: |
305 | /// %x = vector.transpose %y, [1, 0] |
306 | /// is replaced by: |
307 | /// %z = arith.constant dense<0.000000e+00> |
308 | /// %0 = vector.extract %y[0, 0] |
309 | /// %1 = vector.insert %0, %z [0, 0] |
310 | /// .. |
311 | /// %x = vector.insert .., .. [.., ..] |
312 | class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> { |
313 | public: |
314 | using OpRewritePattern::OpRewritePattern; |
315 | |
316 | TransposeOpLowering(vector::VectorTransformsOptions vectorTransformOptions, |
317 | MLIRContext *context, PatternBenefit benefit = 1) |
318 | : OpRewritePattern<vector::TransposeOp>(context, benefit), |
319 | vectorTransformOptions(vectorTransformOptions) {} |
320 | |
321 | LogicalResult matchAndRewrite(vector::TransposeOp op, |
322 | PatternRewriter &rewriter) const override { |
323 | auto loc = op.getLoc(); |
324 | |
325 | Value input = op.getVector(); |
326 | VectorType inputType = op.getSourceVectorType(); |
327 | VectorType resType = op.getResultVectorType(); |
328 | |
329 | // Set up convenience transposition table. |
330 | ArrayRef<int64_t> transp = op.getPermutation(); |
331 | |
332 | if (isShuffleLike(vectorTransformOptions.vectorTransposeLowering) && |
333 | succeeded(isTranspose2DSlice(op))) |
334 | return rewriter.notifyMatchFailure( |
335 | op, "Options specifies lowering to shuffle" ); |
336 | |
337 | // Replace: |
338 | // vector.transpose %0, [1, 0] : vector<nx1x<eltty>> to |
339 | // vector<1xnxelty> |
340 | // with: |
341 | // vector.shape_cast %0 : vector<nx1x<eltty>> to vector<1xnxelty> |
342 | // |
343 | // Source with leading unit dim (inverse) is also replaced. Unit dim must |
344 | // be fixed. Non-unit can be scalable. |
345 | if (resType.getRank() == 2 && |
346 | ((resType.getShape().front() == 1 && |
347 | !resType.getScalableDims().front()) || |
348 | (resType.getShape().back() == 1 && |
349 | !resType.getScalableDims().back())) && |
350 | transp == ArrayRef<int64_t>({1, 0})) { |
351 | rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, resType, input); |
352 | return success(); |
353 | } |
354 | |
355 | if (inputType.isScalable()) |
356 | return failure(); |
357 | |
358 | // Handle a true 2-D matrix transpose differently when requested. |
359 | if (vectorTransformOptions.vectorTransposeLowering == |
360 | vector::VectorTransposeLowering::Flat && |
361 | resType.getRank() == 2 && transp[0] == 1 && transp[1] == 0) { |
362 | Type flattenedType = |
363 | VectorType::get(resType.getNumElements(), resType.getElementType()); |
364 | auto matrix = |
365 | rewriter.create<vector::ShapeCastOp>(loc, flattenedType, input); |
366 | auto rows = rewriter.getI32IntegerAttr(value: resType.getShape()[0]); |
367 | auto columns = rewriter.getI32IntegerAttr(value: resType.getShape()[1]); |
368 | Value trans = rewriter.create<vector::FlatTransposeOp>( |
369 | loc, flattenedType, matrix, rows, columns); |
370 | rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, resType, trans); |
371 | return success(); |
372 | } |
373 | |
374 | // Generate unrolled extract/insert ops. We do not unroll the rightmost |
375 | // (i.e., highest-order) dimensions that are not transposed and leave them |
376 | // in vector form to improve performance. Therefore, we prune those |
377 | // dimensions from the shape/transpose data structures used to generate the |
378 | // extract/insert ops. |
379 | SmallVector<int64_t> prunedTransp; |
380 | pruneNonTransposedDims(transpose: transp, result&: prunedTransp); |
381 | size_t numPrunedDims = transp.size() - prunedTransp.size(); |
382 | auto prunedInShape = inputType.getShape().drop_back(numPrunedDims); |
383 | auto prunedInStrides = computeStrides(prunedInShape); |
384 | |
385 | // Generates the extract/insert operations for every scalar/vector element |
386 | // of the leftmost transposed dimensions. We traverse every transpose |
387 | // element using a linearized index that we delinearize to generate the |
388 | // appropriate indices for the extract/insert operations. |
389 | Value result = rewriter.create<arith::ConstantOp>( |
390 | loc, resType, rewriter.getZeroAttr(resType)); |
391 | int64_t numTransposedElements = ShapedType::getNumElements(prunedInShape); |
392 | |
393 | for (int64_t linearIdx = 0; linearIdx < numTransposedElements; |
394 | ++linearIdx) { |
395 | auto = delinearize(linearIdx, prunedInStrides); |
396 | SmallVector<int64_t> insertIdxs(extractIdxs); |
397 | applyPermutationToVector(inVec&: insertIdxs, permutation: prunedTransp); |
398 | Value = |
399 | rewriter.create<vector::ExtractOp>(loc, input, extractIdxs); |
400 | result = |
401 | rewriter.create<vector::InsertOp>(loc, extractOp, result, insertIdxs); |
402 | } |
403 | |
404 | rewriter.replaceOp(op, result); |
405 | return success(); |
406 | } |
407 | |
408 | private: |
409 | /// Options to control the vector patterns. |
410 | vector::VectorTransformsOptions vectorTransformOptions; |
411 | }; |
412 | |
413 | /// Rewrite a 2-D vector.transpose as a sequence of shuffle ops. |
414 | /// If the strategy is Shuffle1D, it will be lowered to: |
415 | /// vector.shape_cast 2D -> 1D |
416 | /// vector.shuffle |
417 | /// vector.shape_cast 1D -> 2D |
418 | /// If the strategy is Shuffle16x16, it will be lowered to a sequence of shuffle |
419 | /// ops on 16xf32 vectors. |
420 | class TransposeOp2DToShuffleLowering |
421 | : public OpRewritePattern<vector::TransposeOp> { |
422 | public: |
423 | using OpRewritePattern::OpRewritePattern; |
424 | |
425 | TransposeOp2DToShuffleLowering( |
426 | vector::VectorTransformsOptions vectorTransformOptions, |
427 | MLIRContext *context, PatternBenefit benefit = 1) |
428 | : OpRewritePattern<vector::TransposeOp>(context, benefit), |
429 | vectorTransformOptions(vectorTransformOptions) {} |
430 | |
431 | LogicalResult matchAndRewrite(vector::TransposeOp op, |
432 | PatternRewriter &rewriter) const override { |
433 | if (!isShuffleLike(vectorTransformOptions.vectorTransposeLowering)) |
434 | return rewriter.notifyMatchFailure( |
435 | op, "not using vector shuffle based lowering" ); |
436 | |
437 | if (op.getSourceVectorType().isScalable()) |
438 | return rewriter.notifyMatchFailure( |
439 | op, "vector shuffle lowering not supported for scalable vectors" ); |
440 | |
441 | auto srcGtOneDims = isTranspose2DSlice(op); |
442 | if (failed(srcGtOneDims)) |
443 | return rewriter.notifyMatchFailure( |
444 | op, "expected transposition on a 2D slice" ); |
445 | |
446 | VectorType srcType = op.getSourceVectorType(); |
447 | int64_t m = srcType.getDimSize(std::get<0>(srcGtOneDims.value())); |
448 | int64_t n = srcType.getDimSize(std::get<1>(srcGtOneDims.value())); |
449 | |
450 | // Reshape the n-D input vector with only two dimensions greater than one |
451 | // to a 2-D vector. |
452 | Location loc = op.getLoc(); |
453 | auto flattenedType = VectorType::get({n * m}, srcType.getElementType()); |
454 | auto reshInputType = VectorType::get({m, n}, srcType.getElementType()); |
455 | auto reshInput = rewriter.create<vector::ShapeCastOp>(loc, flattenedType, |
456 | op.getVector()); |
457 | |
458 | Value res; |
459 | if (vectorTransformOptions.vectorTransposeLowering == |
460 | VectorTransposeLowering::Shuffle16x16 && |
461 | m == 16 && n == 16) { |
462 | reshInput = |
463 | rewriter.create<vector::ShapeCastOp>(loc, reshInputType, reshInput); |
464 | res = transposeToShuffle16x16(rewriter, reshInput, m, n); |
465 | } else { |
466 | // Fallback to shuffle on 1D approach. |
467 | res = transposeToShuffle1D(rewriter, reshInput, m, n); |
468 | } |
469 | |
470 | rewriter.replaceOpWithNewOp<vector::ShapeCastOp>( |
471 | op, op.getResultVectorType(), res); |
472 | |
473 | return success(); |
474 | } |
475 | |
476 | private: |
477 | /// Options to control the vector patterns. |
478 | vector::VectorTransformsOptions vectorTransformOptions; |
479 | }; |
480 | } // namespace |
481 | |
482 | void mlir::vector::populateVectorTransposeLoweringPatterns( |
483 | RewritePatternSet &patterns, VectorTransformsOptions options, |
484 | PatternBenefit benefit) { |
485 | patterns.add<TransposeOpLowering, TransposeOp2DToShuffleLowering>( |
486 | arg&: options, args: patterns.getContext(), args&: benefit); |
487 | } |
488 | |