1//===- LowerVectorTranspose.cpp - Lower 'vector.transpose' operation ------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements target-independent rewrites and utilities to lower the
10// 'vector.transpose' operation.
11//
12//===----------------------------------------------------------------------===//
13
14#include "mlir/Dialect/Affine/IR/AffineOps.h"
15#include "mlir/Dialect/Arith/IR/Arith.h"
16#include "mlir/Dialect/Arith/Utils/Utils.h"
17#include "mlir/Dialect/Linalg/IR/Linalg.h"
18#include "mlir/Dialect/MemRef/IR/MemRef.h"
19#include "mlir/Dialect/SCF/IR/SCF.h"
20#include "mlir/Dialect/Tensor/IR/Tensor.h"
21#include "mlir/Dialect/Utils/IndexingUtils.h"
22#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
23#include "mlir/Dialect/Vector/IR/VectorOps.h"
24#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
25#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
26#include "mlir/IR/BuiltinAttributeInterfaces.h"
27#include "mlir/IR/BuiltinTypes.h"
28#include "mlir/IR/ImplicitLocOpBuilder.h"
29#include "mlir/IR/Location.h"
30#include "mlir/IR/Matchers.h"
31#include "mlir/IR/PatternMatch.h"
32#include "mlir/IR/TypeUtilities.h"
33#include "mlir/Interfaces/VectorInterfaces.h"
34#include "mlir/Support/LogicalResult.h"
35
36#define DEBUG_TYPE "lower-vector-transpose"
37
38using namespace mlir;
39using namespace mlir::vector;
40
41/// Given a 'transpose' pattern, prune the rightmost dimensions that are not
42/// transposed.
43static void pruneNonTransposedDims(ArrayRef<int64_t> transpose,
44 SmallVectorImpl<int64_t> &result) {
45 size_t numTransposedDims = transpose.size();
46 for (size_t transpDim : llvm::reverse(C&: transpose)) {
47 if (transpDim != numTransposedDims - 1)
48 break;
49 numTransposedDims--;
50 }
51
52 result.append(in_start: transpose.begin(), in_end: transpose.begin() + numTransposedDims);
53}
54
55/// Returns true if the lowering option is a vector shuffle based approach.
56static bool isShuffleLike(VectorTransposeLowering lowering) {
57 return lowering == VectorTransposeLowering::Shuffle1D ||
58 lowering == VectorTransposeLowering::Shuffle16x16;
59}
60
61/// Returns a shuffle mask that builds on `vals`. `vals` is the offset base of
62/// shuffle ops, i.e., the unpack pattern. The method iterates with `vals` to
63/// create the mask for `numBits` bits vector. The `numBits` have to be a
64/// multiple of 128. For example, if `vals` is {0, 1, 16, 17} and `numBits` is
65/// 512, there should be 16 elements in the final result. It constructs the
66/// below mask to get the unpack elements.
67/// [0, 1, 16, 17,
68/// 0+4, 1+4, 16+4, 17+4,
69/// 0+8, 1+8, 16+8, 17+8,
70/// 0+12, 1+12, 16+12, 17+12]
71static SmallVector<int64_t>
72getUnpackShufflePermFor128Lane(ArrayRef<int64_t> vals, int numBits) {
73 assert(numBits % 128 == 0 && "expected numBits is a multiple of 128");
74 int numElem = numBits / 32;
75 SmallVector<int64_t> res;
76 for (int i = 0; i < numElem; i += 4)
77 for (int64_t v : vals)
78 res.push_back(Elt: v + i);
79 return res;
80}
81
82/// Lower to vector.shuffle on v1 and v2 with UnpackLoPd shuffle mask. For
83/// example, if it is targeting 512 bit vector, returns
84/// vector.shuffle on v1, v2, [0, 1, 16, 17,
85/// 0+4, 1+4, 16+4, 17+4,
86/// 0+8, 1+8, 16+8, 17+8,
87/// 0+12, 1+12, 16+12, 17+12].
88static Value createUnpackLoPd(ImplicitLocOpBuilder &b, Value v1, Value v2,
89 int numBits) {
90 int numElem = numBits / 32;
91 return b.create<vector::ShuffleOp>(
92 v1, v2,
93 getUnpackShufflePermFor128Lane({0, 1, numElem, numElem + 1}, numBits));
94}
95
96/// Lower to vector.shuffle on v1 and v2 with UnpackHiPd shuffle mask. For
97/// example, if it is targeting 512 bit vector, returns
98/// vector.shuffle, v1, v2, [2, 3, 18, 19,
99/// 2+4, 3+4, 18+4, 19+4,
100/// 2+8, 3+8, 18+8, 19+8,
101/// 2+12, 3+12, 18+12, 19+12].
102static Value createUnpackHiPd(ImplicitLocOpBuilder &b, Value v1, Value v2,
103 int numBits) {
104 int numElem = numBits / 32;
105 return b.create<vector::ShuffleOp>(
106 v1, v2,
107 getUnpackShufflePermFor128Lane({2, 3, numElem + 2, numElem + 3},
108 numBits));
109}
110
111/// Lower to vector.shuffle on v1 and v2 with UnpackLoPs shuffle mask. For
112/// example, if it is targeting 512 bit vector, returns
113/// vector.shuffle, v1, v2, [0, 16, 1, 17,
114/// 0+4, 16+4, 1+4, 17+4,
115/// 0+8, 16+8, 1+8, 17+8,
116/// 0+12, 16+12, 1+12, 17+12].
117static Value createUnpackLoPs(ImplicitLocOpBuilder &b, Value v1, Value v2,
118 int numBits) {
119 int numElem = numBits / 32;
120 auto shuffle = b.create<vector::ShuffleOp>(
121 v1, v2,
122 getUnpackShufflePermFor128Lane({0, numElem, 1, numElem + 1}, numBits));
123 return shuffle;
124}
125
126/// Lower to vector.shuffle on v1 and v2 with UnpackHiPs shuffle mask. For
127/// example, if it is targeting 512 bit vector, returns
128/// vector.shuffle, v1, v2, [2, 18, 3, 19,
129/// 2+4, 18+4, 3+4, 19+4,
130/// 2+8, 18+8, 3+8, 19+8,
131/// 2+12, 18+12, 3+12, 19+12].
132static Value createUnpackHiPs(ImplicitLocOpBuilder &b, Value v1, Value v2,
133 int numBits) {
134 int numElem = numBits / 32;
135 return b.create<vector::ShuffleOp>(
136 v1, v2,
137 getUnpackShufflePermFor128Lane({2, numElem + 2, 3, numElem + 3},
138 numBits));
139}
140
141/// Returns a vector.shuffle that shuffles 128-bit lanes (composed of 4 32-bit
142/// elements) selected by `mask` from `v1` and `v2`. I.e.,
143///
144/// DEFINE SELECT4(src, control) {
145/// CASE(control[1:0]) OF
146/// 0: tmp[127:0] := src[127:0]
147/// 1: tmp[127:0] := src[255:128]
148/// 2: tmp[127:0] := src[383:256]
149/// 3: tmp[127:0] := src[511:384]
150/// ESAC
151/// RETURN tmp[127:0]
152/// }
153/// dst[127:0] := SELECT4(v1[511:0], mask[1:0])
154/// dst[255:128] := SELECT4(v1[511:0], mask[3:2])
155/// dst[383:256] := SELECT4(v2[511:0], mask[5:4])
156/// dst[511:384] := SELECT4(v2[511:0], mask[7:6])
157static Value create4x128BitSuffle(ImplicitLocOpBuilder &b, Value v1, Value v2,
158 uint8_t mask) {
159 assert(cast<VectorType>(v1.getType()).getShape()[0] == 16 &&
160 "expected a vector with length=16");
161 SmallVector<int64_t> shuffleMask;
162 auto appendToMask = [&](int64_t base, uint8_t control) {
163 switch (control) {
164 case 0:
165 llvm::append_range(C&: shuffleMask, R: ArrayRef<int64_t>{base + 0, base + 1,
166 base + 2, base + 3});
167 break;
168 case 1:
169 llvm::append_range(C&: shuffleMask, R: ArrayRef<int64_t>{base + 4, base + 5,
170 base + 6, base + 7});
171 break;
172 case 2:
173 llvm::append_range(C&: shuffleMask, R: ArrayRef<int64_t>{base + 8, base + 9,
174 base + 10, base + 11});
175 break;
176 case 3:
177 llvm::append_range(C&: shuffleMask, R: ArrayRef<int64_t>{base + 12, base + 13,
178 base + 14, base + 15});
179 break;
180 default:
181 llvm_unreachable("control > 3 : overflow");
182 }
183 };
184 uint8_t b01 = mask & 0x3;
185 uint8_t b23 = (mask >> 2) & 0x3;
186 uint8_t b45 = (mask >> 4) & 0x3;
187 uint8_t b67 = (mask >> 6) & 0x3;
188 appendToMask(0, b01);
189 appendToMask(0, b23);
190 appendToMask(16, b45);
191 appendToMask(16, b67);
192 return b.create<vector::ShuffleOp>(v1, v2, shuffleMask);
193}
194
195/// Lowers the value to a vector.shuffle op. The `source` is expected to be a
196/// 1-D vector and have `m`x`n` elements.
197static Value transposeToShuffle1D(OpBuilder &b, Value source, int m, int n) {
198 SmallVector<int64_t> mask;
199 mask.reserve(N: m * n);
200 for (int64_t j = 0; j < n; ++j)
201 for (int64_t i = 0; i < m; ++i)
202 mask.push_back(Elt: i * n + j);
203 return b.create<vector::ShuffleOp>(source.getLoc(), source, source, mask);
204}
205
206/// Lowers the value to a sequence of vector.shuffle ops. The `source` is
207/// expected to be a 16x16 vector.
static Value transposeToShuffle16x16(OpBuilder &builder, Value source, int m,
                                     int n) {
  ImplicitLocOpBuilder b(source.getLoc(), builder);
  // Extract the 16 rows of the source matrix as individual 1-D vectors.
  SmallVector<Value> vs;
  for (int64_t i = 0; i < m; ++i)
    vs.push_back(b.create<vector::ExtractOp>(source, i));

  // Interleave 32-bit lanes using
  // 8x _mm512_unpacklo_epi32
  // 8x _mm512_unpackhi_epi32
  Value t0 = createUnpackLoPs(b, vs[0x0], vs[0x1], 512);
  Value t1 = createUnpackHiPs(b, vs[0x0], vs[0x1], 512);
  Value t2 = createUnpackLoPs(b, vs[0x2], vs[0x3], 512);
  Value t3 = createUnpackHiPs(b, vs[0x2], vs[0x3], 512);
  Value t4 = createUnpackLoPs(b, vs[0x4], vs[0x5], 512);
  Value t5 = createUnpackHiPs(b, vs[0x4], vs[0x5], 512);
  Value t6 = createUnpackLoPs(b, vs[0x6], vs[0x7], 512);
  Value t7 = createUnpackHiPs(b, vs[0x6], vs[0x7], 512);
  Value t8 = createUnpackLoPs(b, vs[0x8], vs[0x9], 512);
  Value t9 = createUnpackHiPs(b, vs[0x8], vs[0x9], 512);
  Value ta = createUnpackLoPs(b, vs[0xa], vs[0xb], 512);
  Value tb = createUnpackHiPs(b, vs[0xa], vs[0xb], 512);
  Value tc = createUnpackLoPs(b, vs[0xc], vs[0xd], 512);
  Value td = createUnpackHiPs(b, vs[0xc], vs[0xd], 512);
  Value te = createUnpackLoPs(b, vs[0xe], vs[0xf], 512);
  Value tf = createUnpackHiPs(b, vs[0xe], vs[0xf], 512);

  // Interleave 64-bit lanes using
  // 8x _mm512_unpacklo_epi64
  // 8x _mm512_unpackhi_epi64
  Value r0 = createUnpackLoPd(b, t0, t2, 512);
  Value r1 = createUnpackHiPd(b, t0, t2, 512);
  Value r2 = createUnpackLoPd(b, t1, t3, 512);
  Value r3 = createUnpackHiPd(b, t1, t3, 512);
  Value r4 = createUnpackLoPd(b, t4, t6, 512);
  Value r5 = createUnpackHiPd(b, t4, t6, 512);
  Value r6 = createUnpackLoPd(b, t5, t7, 512);
  Value r7 = createUnpackHiPd(b, t5, t7, 512);
  Value r8 = createUnpackLoPd(b, t8, ta, 512);
  Value r9 = createUnpackHiPd(b, t8, ta, 512);
  Value ra = createUnpackLoPd(b, t9, tb, 512);
  Value rb = createUnpackHiPd(b, t9, tb, 512);
  Value rc = createUnpackLoPd(b, tc, te, 512);
  Value rd = createUnpackHiPd(b, tc, te, 512);
  Value re = createUnpackLoPd(b, td, tf, 512);
  Value rf = createUnpackHiPd(b, td, tf, 512);

  // Permute 128-bit lanes using
  // 16x _mm512_shuffle_i32x4
  // Mask 0x88 selects the even 128-bit lanes of each operand, 0xdd the odd
  // ones (see create4x128BitSuffle for the 2-bit selector encoding).
  t0 = create4x128BitSuffle(b, r0, r4, 0x88);
  t1 = create4x128BitSuffle(b, r1, r5, 0x88);
  t2 = create4x128BitSuffle(b, r2, r6, 0x88);
  t3 = create4x128BitSuffle(b, r3, r7, 0x88);
  t4 = create4x128BitSuffle(b, r0, r4, 0xdd);
  t5 = create4x128BitSuffle(b, r1, r5, 0xdd);
  t6 = create4x128BitSuffle(b, r2, r6, 0xdd);
  t7 = create4x128BitSuffle(b, r3, r7, 0xdd);
  t8 = create4x128BitSuffle(b, r8, rc, 0x88);
  t9 = create4x128BitSuffle(b, r9, rd, 0x88);
  ta = create4x128BitSuffle(b, ra, re, 0x88);
  tb = create4x128BitSuffle(b, rb, rf, 0x88);
  tc = create4x128BitSuffle(b, r8, rc, 0xdd);
  td = create4x128BitSuffle(b, r9, rd, 0xdd);
  te = create4x128BitSuffle(b, ra, re, 0xdd);
  tf = create4x128BitSuffle(b, rb, rf, 0xdd);

  // Permute 256-bit lanes using again
  // 16x _mm512_shuffle_i32x4
  vs[0x0] = create4x128BitSuffle(b, t0, t8, 0x88);
  vs[0x1] = create4x128BitSuffle(b, t1, t9, 0x88);
  vs[0x2] = create4x128BitSuffle(b, t2, ta, 0x88);
  vs[0x3] = create4x128BitSuffle(b, t3, tb, 0x88);
  vs[0x4] = create4x128BitSuffle(b, t4, tc, 0x88);
  vs[0x5] = create4x128BitSuffle(b, t5, td, 0x88);
  vs[0x6] = create4x128BitSuffle(b, t6, te, 0x88);
  vs[0x7] = create4x128BitSuffle(b, t7, tf, 0x88);
  vs[0x8] = create4x128BitSuffle(b, t0, t8, 0xdd);
  vs[0x9] = create4x128BitSuffle(b, t1, t9, 0xdd);
  vs[0xa] = create4x128BitSuffle(b, t2, ta, 0xdd);
  vs[0xb] = create4x128BitSuffle(b, t3, tb, 0xdd);
  vs[0xc] = create4x128BitSuffle(b, t4, tc, 0xdd);
  vs[0xd] = create4x128BitSuffle(b, t5, td, 0xdd);
  vs[0xe] = create4x128BitSuffle(b, t6, te, 0xdd);
  vs[0xf] = create4x128BitSuffle(b, t7, tf, 0xdd);

  // Reassemble the transposed rows into an m x n vector, inserting each row
  // into an all-zero constant of the result shape.
  auto reshInputType = VectorType::get(
      {m, n}, cast<VectorType>(source.getType()).getElementType());
  Value res =
      b.create<arith::ConstantOp>(reshInputType, b.getZeroAttr(reshInputType));
  for (int64_t i = 0; i < m; ++i)
    res = b.create<vector::InsertOp>(vs[i], res, i);
  return res;
}
301
302namespace {
303/// Progressive lowering of TransposeOp.
304/// One:
305/// %x = vector.transpose %y, [1, 0]
306/// is replaced by:
307/// %z = arith.constant dense<0.000000e+00>
308/// %0 = vector.extract %y[0, 0]
309/// %1 = vector.insert %0, %z [0, 0]
310/// ..
311/// %x = vector.insert .., .. [.., ..]
class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  /// Construct the pattern; `vectorTransformOptions` selects which transpose
  /// lowering strategies this pattern may apply or must defer on.
  TransposeOpLowering(vector::VectorTransformsOptions vectorTransformOptions,
                      MLIRContext *context, PatternBenefit benefit = 1)
      : OpRewritePattern<vector::TransposeOp>(context, benefit),
        vectorTransformOptions(vectorTransformOptions) {}

  LogicalResult matchAndRewrite(vector::TransposeOp op,
                                PatternRewriter &rewriter) const override {
    auto loc = op.getLoc();

    Value input = op.getVector();
    VectorType inputType = op.getSourceVectorType();
    VectorType resType = op.getResultVectorType();

    // Set up convenience transposition table.
    ArrayRef<int64_t> transp = op.getPermutation();

    // If a shuffle-based strategy was requested and the op is a 2-D-slice
    // transpose, bail out so TransposeOp2DToShuffleLowering can handle it.
    if (isShuffleLike(vectorTransformOptions.vectorTransposeLowering) &&
        succeeded(isTranspose2DSlice(op)))
      return rewriter.notifyMatchFailure(
          op, "Options specifies lowering to shuffle");

    // Replace:
    //   vector.transpose %0, [1, 0] : vector<nx1x<eltty>> to
    //                                 vector<1xnxelty>
    // with:
    //   vector.shape_cast %0 : vector<nx1x<eltty>> to vector<1xnxelty>
    //
    // Source with leading unit dim (inverse) is also replaced. Unit dim must
    // be fixed. Non-unit can be scalable.
    if (resType.getRank() == 2 &&
        ((resType.getShape().front() == 1 &&
          !resType.getScalableDims().front()) ||
         (resType.getShape().back() == 1 &&
          !resType.getScalableDims().back())) &&
        transp == ArrayRef<int64_t>({1, 0})) {
      rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, resType, input);
      return success();
    }

    // The unrolled extract/insert lowering below requires fixed-size vectors.
    if (inputType.isScalable())
      return failure();

    // Handle a true 2-D matrix transpose differently when requested.
    if (vectorTransformOptions.vectorTransposeLowering ==
            vector::VectorTransposeLowering::Flat &&
        resType.getRank() == 2 && transp[0] == 1 && transp[1] == 0) {
      // Flatten to 1-D, emit vector.flat_transpose, and restore the 2-D shape.
      Type flattenedType =
          VectorType::get(resType.getNumElements(), resType.getElementType());
      auto matrix =
          rewriter.create<vector::ShapeCastOp>(loc, flattenedType, input);
      auto rows = rewriter.getI32IntegerAttr(resType.getShape()[0]);
      auto columns = rewriter.getI32IntegerAttr(resType.getShape()[1]);
      Value trans = rewriter.create<vector::FlatTransposeOp>(
          loc, flattenedType, matrix, rows, columns);
      rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, resType, trans);
      return success();
    }

    // Generate unrolled extract/insert ops. We do not unroll the rightmost
    // (i.e., highest-order) dimensions that are not transposed and leave them
    // in vector form to improve performance. Therefore, we prune those
    // dimensions from the shape/transpose data structures used to generate the
    // extract/insert ops.
    SmallVector<int64_t> prunedTransp;
    pruneNonTransposedDims(transp, prunedTransp);
    size_t numPrunedDims = transp.size() - prunedTransp.size();
    auto prunedInShape = inputType.getShape().drop_back(numPrunedDims);
    auto prunedInStrides = computeStrides(prunedInShape);

    // Generates the extract/insert operations for every scalar/vector element
    // of the leftmost transposed dimensions. We traverse every transpose
    // element using a linearized index that we delinearize to generate the
    // appropriate indices for the extract/insert operations.
    Value result = rewriter.create<arith::ConstantOp>(
        loc, resType, rewriter.getZeroAttr(resType));
    int64_t numTransposedElements = ShapedType::getNumElements(prunedInShape);

    for (int64_t linearIdx = 0; linearIdx < numTransposedElements;
         ++linearIdx) {
      auto extractIdxs = delinearize(linearIdx, prunedInStrides);
      SmallVector<int64_t> insertIdxs(extractIdxs);
      // The insert position is the extract position run through the pruned
      // permutation.
      applyPermutationToVector(insertIdxs, prunedTransp);
      Value extractOp =
          rewriter.create<vector::ExtractOp>(loc, input, extractIdxs);
      result =
          rewriter.create<vector::InsertOp>(loc, extractOp, result, insertIdxs);
    }

    rewriter.replaceOp(op, result);
    return success();
  }

private:
  /// Options to control the vector patterns.
  vector::VectorTransformsOptions vectorTransformOptions;
};
412
413/// Rewrite a 2-D vector.transpose as a sequence of shuffle ops.
414/// If the strategy is Shuffle1D, it will be lowered to:
415/// vector.shape_cast 2D -> 1D
416/// vector.shuffle
417/// vector.shape_cast 1D -> 2D
418/// If the strategy is Shuffle16x16, it will be lowered to a sequence of shuffle
419/// ops on 16xf32 vectors.
420class TransposeOp2DToShuffleLowering
421 : public OpRewritePattern<vector::TransposeOp> {
422public:
423 using OpRewritePattern::OpRewritePattern;
424
425 TransposeOp2DToShuffleLowering(
426 vector::VectorTransformsOptions vectorTransformOptions,
427 MLIRContext *context, PatternBenefit benefit = 1)
428 : OpRewritePattern<vector::TransposeOp>(context, benefit),
429 vectorTransformOptions(vectorTransformOptions) {}
430
431 LogicalResult matchAndRewrite(vector::TransposeOp op,
432 PatternRewriter &rewriter) const override {
433 if (!isShuffleLike(vectorTransformOptions.vectorTransposeLowering))
434 return rewriter.notifyMatchFailure(
435 op, "not using vector shuffle based lowering");
436
437 if (op.getSourceVectorType().isScalable())
438 return rewriter.notifyMatchFailure(
439 op, "vector shuffle lowering not supported for scalable vectors");
440
441 auto srcGtOneDims = isTranspose2DSlice(op);
442 if (failed(srcGtOneDims))
443 return rewriter.notifyMatchFailure(
444 op, "expected transposition on a 2D slice");
445
446 VectorType srcType = op.getSourceVectorType();
447 int64_t m = srcType.getDimSize(std::get<0>(srcGtOneDims.value()));
448 int64_t n = srcType.getDimSize(std::get<1>(srcGtOneDims.value()));
449
450 // Reshape the n-D input vector with only two dimensions greater than one
451 // to a 2-D vector.
452 Location loc = op.getLoc();
453 auto flattenedType = VectorType::get({n * m}, srcType.getElementType());
454 auto reshInputType = VectorType::get({m, n}, srcType.getElementType());
455 auto reshInput = rewriter.create<vector::ShapeCastOp>(loc, flattenedType,
456 op.getVector());
457
458 Value res;
459 if (vectorTransformOptions.vectorTransposeLowering ==
460 VectorTransposeLowering::Shuffle16x16 &&
461 m == 16 && n == 16) {
462 reshInput =
463 rewriter.create<vector::ShapeCastOp>(loc, reshInputType, reshInput);
464 res = transposeToShuffle16x16(rewriter, reshInput, m, n);
465 } else {
466 // Fallback to shuffle on 1D approach.
467 res = transposeToShuffle1D(rewriter, reshInput, m, n);
468 }
469
470 rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
471 op, op.getResultVectorType(), res);
472
473 return success();
474 }
475
476private:
477 /// Options to control the vector patterns.
478 vector::VectorTransformsOptions vectorTransformOptions;
479};
480} // namespace
481
482void mlir::vector::populateVectorTransposeLoweringPatterns(
483 RewritePatternSet &patterns, VectorTransformsOptions options,
484 PatternBenefit benefit) {
485 patterns.add<TransposeOpLowering, TransposeOp2DToShuffleLowering>(
486 arg&: options, args: patterns.getContext(), args&: benefit);
487}
488

// Source: mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp