//===- LowerVectorTranspose.cpp - Lower 'vector.transpose' operation ------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements target-independent rewrites and utilities to lower the
// 'vector.transpose' operation.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/UB/IR/UBOps.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"

#define DEBUG_TYPE "lower-vector-transpose"

using namespace mlir;
using namespace mlir::vector;

/// Given a 'transpose' pattern, prune the rightmost dimensions that are not
/// transposed.
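/// For example, transpose = [1, 0, 2, 3] keeps dims 2 and 3 in their identity
/// (rightmost) positions, so the pruned result is [1, 0].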
static void pruneNonTransposedDims(ArrayRef<int64_t> transpose,
                                   SmallVectorImpl<int64_t> &result) {
  size_t numTransposedDims = transpose.size();
  for (size_t transpDim : llvm::reverse(transpose)) {
    if (transpDim != numTransposedDims - 1)
      break;
    numTransposedDims--;
  }

  result.append(transpose.begin(), transpose.begin() + numTransposedDims);
}

/// Returns true if the lowering option is a vector shuffle based approach.
static bool isShuffleLike(VectorTransposeLowering lowering) {
  return lowering == VectorTransposeLowering::Shuffle1D ||
         lowering == VectorTransposeLowering::Shuffle16x16;
}

/// Returns a shuffle mask that builds on `vals`. `vals` is the offset base of
/// shuffle ops, i.e., the unpack pattern. The method replicates `vals` across
/// each 128-bit lane to create the mask for a `numBits`-bit vector. `numBits`
/// must be a multiple of 128. For example, if `vals` is {0, 1, 16, 17} and
/// `numBits` is 512, there are 16 elements in the final result. It constructs
/// the below mask to get the unpack elements.
///   [0, 1, 16, 17,
///    0+4, 1+4, 16+4, 17+4,
///    0+8, 1+8, 16+8, 17+8,
///    0+12, 1+12, 16+12, 17+12]
static SmallVector<int64_t>
getUnpackShufflePermFor128Lane(ArrayRef<int64_t> vals, int numBits) {
  assert(numBits % 128 == 0 && "expected numBits to be a multiple of 128");
  int numElem = numBits / 32;
  SmallVector<int64_t> res;
  for (int i = 0; i < numElem; i += 4)
    for (int64_t v : vals)
      res.push_back(v + i);
  return res;
}

/// Lower to vector.shuffle on v1 and v2 with UnpackLoPd shuffle mask. For
/// example, when targeting a 512-bit vector, returns
///   vector.shuffle on v1, v2, [0, 1, 16, 17,
///                              0+4, 1+4, 16+4, 17+4,
///                              0+8, 1+8, 16+8, 17+8,
///                              0+12, 1+12, 16+12, 17+12].
static Value createUnpackLoPd(ImplicitLocOpBuilder &b, Value v1, Value v2,
                              int numBits) {
  int numElem = numBits / 32;
  return b.create<vector::ShuffleOp>(
      v1, v2,
      getUnpackShufflePermFor128Lane({0, 1, numElem, numElem + 1}, numBits));
}

/// Lower to vector.shuffle on v1 and v2 with UnpackHiPd shuffle mask. For
/// example, when targeting a 512-bit vector, returns
///   vector.shuffle on v1, v2, [2, 3, 18, 19,
///                              2+4, 3+4, 18+4, 19+4,
///                              2+8, 3+8, 18+8, 19+8,
///                              2+12, 3+12, 18+12, 19+12].
static Value createUnpackHiPd(ImplicitLocOpBuilder &b, Value v1, Value v2,
                              int numBits) {
  int numElem = numBits / 32;
  return b.create<vector::ShuffleOp>(
      v1, v2,
      getUnpackShufflePermFor128Lane({2, 3, numElem + 2, numElem + 3},
                                     numBits));
}

/// Lower to vector.shuffle on v1 and v2 with UnpackLoPs shuffle mask. For
/// example, when targeting a 512-bit vector, returns
///   vector.shuffle on v1, v2, [0, 16, 1, 17,
///                              0+4, 16+4, 1+4, 17+4,
///                              0+8, 16+8, 1+8, 17+8,
///                              0+12, 16+12, 1+12, 17+12].
static Value createUnpackLoPs(ImplicitLocOpBuilder &b, Value v1, Value v2,
                              int numBits) {
  int numElem = numBits / 32;
  return b.create<vector::ShuffleOp>(
      v1, v2,
      getUnpackShufflePermFor128Lane({0, numElem, 1, numElem + 1}, numBits));
}

/// Lower to vector.shuffle on v1 and v2 with UnpackHiPs shuffle mask. For
/// example, when targeting a 512-bit vector, returns
///   vector.shuffle on v1, v2, [2, 18, 3, 19,
///                              2+4, 18+4, 3+4, 19+4,
///                              2+8, 18+8, 3+8, 19+8,
///                              2+12, 18+12, 3+12, 19+12].
static Value createUnpackHiPs(ImplicitLocOpBuilder &b, Value v1, Value v2,
                              int numBits) {
  int numElem = numBits / 32;
  return b.create<vector::ShuffleOp>(
      v1, v2,
      getUnpackShufflePermFor128Lane({2, numElem + 2, 3, numElem + 3},
                                     numBits));
}

/// Returns a vector.shuffle that shuffles 128-bit lanes (composed of 4 32-bit
/// elements) selected by `mask` from `v1` and `v2`. I.e.,
///
/// DEFINE SELECT4(src, control) {
///   CASE(control[1:0]) OF
///     0: tmp[127:0] := src[127:0]
///     1: tmp[127:0] := src[255:128]
///     2: tmp[127:0] := src[383:256]
///     3: tmp[127:0] := src[511:384]
///   ESAC
///   RETURN tmp[127:0]
/// }
/// dst[127:0]   := SELECT4(v1[511:0], mask[1:0])
/// dst[255:128] := SELECT4(v1[511:0], mask[3:2])
/// dst[383:256] := SELECT4(v2[511:0], mask[5:4])
/// dst[511:384] := SELECT4(v2[511:0], mask[7:6])
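///
/// For example, mask = 0x88 (0b10'00'10'00) selects lanes 0 and 2 of both v1
/// and v2, i.e., the shuffle indices
///   [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27].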
static Value create4x128BitShuffle(ImplicitLocOpBuilder &b, Value v1, Value v2,
                                   uint8_t mask) {
  assert(cast<VectorType>(v1.getType()).getShape()[0] == 16 &&
         "expected a vector with length=16");
  SmallVector<int64_t> shuffleMask;
  auto appendToMask = [&](int64_t base, uint8_t control) {
    switch (control) {
    case 0:
      llvm::append_range(shuffleMask, ArrayRef<int64_t>{base + 0, base + 1,
                                                        base + 2, base + 3});
      break;
    case 1:
      llvm::append_range(shuffleMask, ArrayRef<int64_t>{base + 4, base + 5,
                                                        base + 6, base + 7});
      break;
    case 2:
      llvm::append_range(shuffleMask, ArrayRef<int64_t>{base + 8, base + 9,
                                                        base + 10, base + 11});
      break;
    case 3:
      llvm::append_range(shuffleMask, ArrayRef<int64_t>{base + 12, base + 13,
                                                        base + 14, base + 15});
      break;
    default:
      llvm_unreachable("control > 3 : overflow");
    }
  };
  uint8_t b01 = mask & 0x3;
  uint8_t b23 = (mask >> 2) & 0x3;
  uint8_t b45 = (mask >> 4) & 0x3;
  uint8_t b67 = (mask >> 6) & 0x3;
  appendToMask(0, b01);
  appendToMask(0, b23);
  appendToMask(16, b45);
  appendToMask(16, b67);
  return b.create<vector::ShuffleOp>(v1, v2, shuffleMask);
}

/// Lowers the value to a vector.shuffle op. The `source` is expected to be a
/// 1-D vector with `m` x `n` elements.
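/// For example, for m = 2 and n = 3 the emitted shuffle mask is
/// [0, 3, 1, 4, 2, 5], which turns the row-major 2x3 layout into 3x2.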
static Value transposeToShuffle1D(OpBuilder &b, Value source, int m, int n) {
  SmallVector<int64_t> mask;
  mask.reserve(m * n);
  for (int64_t j = 0; j < n; ++j)
    for (int64_t i = 0; i < m; ++i)
      mask.push_back(i * n + j);
  return b.create<vector::ShuffleOp>(source.getLoc(), source, source, mask);
}

/// Lowers the value to a sequence of vector.shuffle ops. The `source` is
/// expected to be a 16x16 vector.
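/// The lowering mirrors the classic AVX512 16x16 32-bit transpose recipe:
/// four rounds of 16 shuffles each, interleaving at 32-bit, 64-bit, 128-bit,
/// and finally 256-bit granularity.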
static Value transposeToShuffle16x16(OpBuilder &builder, Value source, int m,
                                     int n) {
  ImplicitLocOpBuilder b(source.getLoc(), builder);
  SmallVector<Value> vs;
  for (int64_t i = 0; i < m; ++i)
    vs.push_back(b.createOrFold<vector::ExtractOp>(source, i));

  // Interleave 32-bit lanes using
  //   8x _mm512_unpacklo_epi32
  //   8x _mm512_unpackhi_epi32
  Value t0 = createUnpackLoPs(b, vs[0x0], vs[0x1], 512);
  Value t1 = createUnpackHiPs(b, vs[0x0], vs[0x1], 512);
  Value t2 = createUnpackLoPs(b, vs[0x2], vs[0x3], 512);
  Value t3 = createUnpackHiPs(b, vs[0x2], vs[0x3], 512);
  Value t4 = createUnpackLoPs(b, vs[0x4], vs[0x5], 512);
  Value t5 = createUnpackHiPs(b, vs[0x4], vs[0x5], 512);
  Value t6 = createUnpackLoPs(b, vs[0x6], vs[0x7], 512);
  Value t7 = createUnpackHiPs(b, vs[0x6], vs[0x7], 512);
  Value t8 = createUnpackLoPs(b, vs[0x8], vs[0x9], 512);
  Value t9 = createUnpackHiPs(b, vs[0x8], vs[0x9], 512);
  Value ta = createUnpackLoPs(b, vs[0xa], vs[0xb], 512);
  Value tb = createUnpackHiPs(b, vs[0xa], vs[0xb], 512);
  Value tc = createUnpackLoPs(b, vs[0xc], vs[0xd], 512);
  Value td = createUnpackHiPs(b, vs[0xc], vs[0xd], 512);
  Value te = createUnpackLoPs(b, vs[0xe], vs[0xf], 512);
  Value tf = createUnpackHiPs(b, vs[0xe], vs[0xf], 512);

  // Interleave 64-bit lanes using
  //   8x _mm512_unpacklo_epi64
  //   8x _mm512_unpackhi_epi64
  Value r0 = createUnpackLoPd(b, t0, t2, 512);
  Value r1 = createUnpackHiPd(b, t0, t2, 512);
  Value r2 = createUnpackLoPd(b, t1, t3, 512);
  Value r3 = createUnpackHiPd(b, t1, t3, 512);
  Value r4 = createUnpackLoPd(b, t4, t6, 512);
  Value r5 = createUnpackHiPd(b, t4, t6, 512);
  Value r6 = createUnpackLoPd(b, t5, t7, 512);
  Value r7 = createUnpackHiPd(b, t5, t7, 512);
  Value r8 = createUnpackLoPd(b, t8, ta, 512);
  Value r9 = createUnpackHiPd(b, t8, ta, 512);
  Value ra = createUnpackLoPd(b, t9, tb, 512);
  Value rb = createUnpackHiPd(b, t9, tb, 512);
  Value rc = createUnpackLoPd(b, tc, te, 512);
  Value rd = createUnpackHiPd(b, tc, te, 512);
  Value re = createUnpackLoPd(b, td, tf, 512);
  Value rf = createUnpackHiPd(b, td, tf, 512);

  // Permute 128-bit lanes using
  //   16x _mm512_shuffle_i32x4
  t0 = create4x128BitShuffle(b, r0, r4, 0x88);
  t1 = create4x128BitShuffle(b, r1, r5, 0x88);
  t2 = create4x128BitShuffle(b, r2, r6, 0x88);
  t3 = create4x128BitShuffle(b, r3, r7, 0x88);
  t4 = create4x128BitShuffle(b, r0, r4, 0xdd);
  t5 = create4x128BitShuffle(b, r1, r5, 0xdd);
  t6 = create4x128BitShuffle(b, r2, r6, 0xdd);
  t7 = create4x128BitShuffle(b, r3, r7, 0xdd);
  t8 = create4x128BitShuffle(b, r8, rc, 0x88);
  t9 = create4x128BitShuffle(b, r9, rd, 0x88);
  ta = create4x128BitShuffle(b, ra, re, 0x88);
  tb = create4x128BitShuffle(b, rb, rf, 0x88);
  tc = create4x128BitShuffle(b, r8, rc, 0xdd);
  td = create4x128BitShuffle(b, r9, rd, 0xdd);
  te = create4x128BitShuffle(b, ra, re, 0xdd);
  tf = create4x128BitShuffle(b, rb, rf, 0xdd);

  // Permute 256-bit lanes again using
  //   16x _mm512_shuffle_i32x4
  vs[0x0] = create4x128BitShuffle(b, t0, t8, 0x88);
  vs[0x1] = create4x128BitShuffle(b, t1, t9, 0x88);
  vs[0x2] = create4x128BitShuffle(b, t2, ta, 0x88);
  vs[0x3] = create4x128BitShuffle(b, t3, tb, 0x88);
  vs[0x4] = create4x128BitShuffle(b, t4, tc, 0x88);
  vs[0x5] = create4x128BitShuffle(b, t5, td, 0x88);
  vs[0x6] = create4x128BitShuffle(b, t6, te, 0x88);
  vs[0x7] = create4x128BitShuffle(b, t7, tf, 0x88);
  vs[0x8] = create4x128BitShuffle(b, t0, t8, 0xdd);
  vs[0x9] = create4x128BitShuffle(b, t1, t9, 0xdd);
  vs[0xa] = create4x128BitShuffle(b, t2, ta, 0xdd);
  vs[0xb] = create4x128BitShuffle(b, t3, tb, 0xdd);
  vs[0xc] = create4x128BitShuffle(b, t4, tc, 0xdd);
  vs[0xd] = create4x128BitShuffle(b, t5, td, 0xdd);
  vs[0xe] = create4x128BitShuffle(b, t6, te, 0xdd);
  vs[0xf] = create4x128BitShuffle(b, t7, tf, 0xdd);

  auto reshInputType = VectorType::get(
      {m, n}, cast<VectorType>(source.getType()).getElementType());
  Value res = b.create<ub::PoisonOp>(reshInputType);
  for (int64_t i = 0; i < m; ++i)
    res = b.create<vector::InsertOp>(vs[i], res, i);
  return res;
}

namespace {
/// Progressive lowering of TransposeOp.
/// One:
///   %x = vector.transpose %y, [1, 0]
/// is replaced by:
///   %z = ub.poison
///   %0 = vector.extract %y[0, 0]
///   %1 = vector.insert %0, %z [0, 0]
///   ..
///   %x = vector.insert .., .. [.., ..]
class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  TransposeOpLowering(vector::VectorTransposeLowering vectorTransposeLowering,
                      MLIRContext *context, PatternBenefit benefit = 1)
      : OpRewritePattern<vector::TransposeOp>(context, benefit),
        vectorTransposeLowering(vectorTransposeLowering) {}

  LogicalResult matchAndRewrite(vector::TransposeOp op,
                                PatternRewriter &rewriter) const override {
    auto loc = op.getLoc();

    Value input = op.getVector();
    VectorType inputType = op.getSourceVectorType();
    VectorType resType = op.getResultVectorType();

    if (inputType.isScalable())
      return rewriter.notifyMatchFailure(
          op, "this lowering does not support scalable vectors");

    // Set up convenience transposition table.
    ArrayRef<int64_t> transp = op.getPermutation();

    if (isShuffleLike(vectorTransposeLowering) &&
        succeeded(isTranspose2DSlice(op)))
      return rewriter.notifyMatchFailure(
          op, "options specify lowering to shuffle");

    // Handle a true 2-D matrix transpose differently when requested.
    if (vectorTransposeLowering == vector::VectorTransposeLowering::Flat &&
        resType.getRank() == 2 && transp[0] == 1 && transp[1] == 0) {
      Type flattenedType =
          VectorType::get(resType.getNumElements(), resType.getElementType());
      auto matrix =
          rewriter.create<vector::ShapeCastOp>(loc, flattenedType, input);
      auto rows = rewriter.getI32IntegerAttr(resType.getShape()[0]);
      auto columns = rewriter.getI32IntegerAttr(resType.getShape()[1]);
      Value trans = rewriter.create<vector::FlatTransposeOp>(
          loc, flattenedType, matrix, rows, columns);
      rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, resType, trans);
      return success();
    }

    // Generate unrolled extract/insert ops. We do not unroll the rightmost
    // (i.e., highest-order) dimensions that are not transposed and leave them
    // in vector form to improve performance. Therefore, we prune those
    // dimensions from the shape/transpose data structures used to generate the
    // extract/insert ops.
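    // For example, vector.transpose %v, [1, 0, 2] : vector<2x3x4xf32> to
    // vector<3x2x4xf32> prunes dim 2, so the loop below moves 2*3 = 6
    // vector<4xf32> elements instead of 24 scalars.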
    SmallVector<int64_t> prunedTransp;
    pruneNonTransposedDims(transp, prunedTransp);
    size_t numPrunedDims = transp.size() - prunedTransp.size();
    auto prunedInShape = inputType.getShape().drop_back(numPrunedDims);
    auto prunedInStrides = computeStrides(prunedInShape);

    // Generates the extract/insert operations for every scalar/vector element
    // of the leftmost transposed dimensions. We traverse every transpose
    // element using a linearized index that we delinearize to generate the
    // appropriate indices for the extract/insert operations.
    Value result = rewriter.create<ub::PoisonOp>(loc, resType);
    int64_t numTransposedElements = ShapedType::getNumElements(prunedInShape);

    for (int64_t linearIdx = 0; linearIdx < numTransposedElements;
         ++linearIdx) {
      auto extractIdxs = delinearize(linearIdx, prunedInStrides);
      SmallVector<int64_t> insertIdxs(extractIdxs);
      applyPermutationToVector(insertIdxs, prunedTransp);
      Value extractOp =
          rewriter.createOrFold<vector::ExtractOp>(loc, input, extractIdxs);
      result = rewriter.createOrFold<vector::InsertOp>(loc, extractOp, result,
                                                       insertIdxs);
    }

    rewriter.replaceOp(op, result);
    return success();
  }

private:
  /// Options to control the vector patterns.
  vector::VectorTransposeLowering vectorTransposeLowering;
};

/// Rewrites vector.transpose as vector.shape_cast. This pattern is only
/// applied to 2D vectors with at least one unit dim. For example:
///
/// Replace:
///   vector.transpose %0, [1, 0] : vector<4x1xi32> to vector<1x4xi32>
/// with:
///   vector.shape_cast %0 : vector<4x1xi32> to vector<1x4xi32>
///
/// A source with a leading unit dim (the inverse) is also replaced. The unit
/// dim must be fixed; the non-unit dim can be scalable.
///
/// TODO: This pattern was introduced specifically to help lower scalable
/// vectors. In hindsight, a more specialised canonicalization (for shape_casts
/// to cancel out) would be preferable:
///
///   BEFORE:
///     %0 = some_op
///     %1 = vector.shape_cast %0 : vector<[4]xf32> to vector<[4]x1xf32>
///     %2 = vector.transpose %1 [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32>
///   AFTER:
///     %0 = some_op
///     %1 = vector.shape_cast %0 : vector<[4]xf32> to vector<1x[4]xf32>
///
/// Given the context above, we may want to consider (re-)moving this pattern
/// at some later time. I am leaving it for now in case there are other users
/// that I am not aware of.
class Transpose2DWithUnitDimToShapeCast
    : public OpRewritePattern<vector::TransposeOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  Transpose2DWithUnitDimToShapeCast(MLIRContext *context,
                                    PatternBenefit benefit = 1)
      : OpRewritePattern<vector::TransposeOp>(context, benefit) {}

  LogicalResult matchAndRewrite(vector::TransposeOp op,
                                PatternRewriter &rewriter) const override {
    Value input = op.getVector();
    VectorType resType = op.getResultVectorType();

    // Set up convenience transposition table.
    ArrayRef<int64_t> transp = op.getPermutation();

    if (resType.getRank() == 2 &&
        ((resType.getShape().front() == 1 &&
          !resType.getScalableDims().front()) ||
         (resType.getShape().back() == 1 &&
          !resType.getScalableDims().back())) &&
        transp == ArrayRef<int64_t>({1, 0})) {
      rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, resType, input);
      return success();
    }

    return failure();
  }
};

/// Rewrite a 2-D vector.transpose as a sequence of shuffle ops.
/// If the strategy is Shuffle1D, it will be lowered to:
///   vector.shape_cast 2D -> 1D
///   vector.shuffle
///   vector.shape_cast 1D -> 2D
/// If the strategy is Shuffle16x16, it will be lowered to a sequence of
/// shuffle ops on 16xf32 vectors.
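/// For example, with Shuffle1D, a [1, 0] transpose of vector<2x3xf32> becomes
/// (roughly):
///   %0 = vector.shape_cast %v : vector<2x3xf32> to vector<6xf32>
///   %1 = vector.shuffle %0, %0 [0, 3, 1, 4, 2, 5] : vector<6xf32>, vector<6xf32>
///   %2 = vector.shape_cast %1 : vector<6xf32> to vector<3x2xf32>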
class TransposeOp2DToShuffleLowering
    : public OpRewritePattern<vector::TransposeOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  TransposeOp2DToShuffleLowering(
      vector::VectorTransposeLowering vectorTransposeLowering,
      MLIRContext *context, PatternBenefit benefit = 1)
      : OpRewritePattern<vector::TransposeOp>(context, benefit),
        vectorTransposeLowering(vectorTransposeLowering) {}

  LogicalResult matchAndRewrite(vector::TransposeOp op,
                                PatternRewriter &rewriter) const override {
    if (!isShuffleLike(vectorTransposeLowering))
      return rewriter.notifyMatchFailure(
          op, "not using vector shuffle based lowering");

    if (op.getSourceVectorType().isScalable())
      return rewriter.notifyMatchFailure(
          op, "vector shuffle lowering not supported for scalable vectors");

    auto srcGtOneDims = isTranspose2DSlice(op);
    if (failed(srcGtOneDims))
      return rewriter.notifyMatchFailure(
          op, "expected transposition on a 2D slice");

    VectorType srcType = op.getSourceVectorType();
    int64_t m = srcType.getDimSize(std::get<0>(srcGtOneDims.value()));
    int64_t n = srcType.getDimSize(std::get<1>(srcGtOneDims.value()));

    // Reshape the n-D input vector with only two dimensions greater than one
    // to a 2-D vector.
    Location loc = op.getLoc();
    auto flattenedType = VectorType::get({n * m}, srcType.getElementType());
    auto reshInputType = VectorType::get({m, n}, srcType.getElementType());
    auto reshInput = rewriter.create<vector::ShapeCastOp>(loc, flattenedType,
                                                          op.getVector());

    Value res;
    if (vectorTransposeLowering == VectorTransposeLowering::Shuffle16x16 &&
        m == 16 && n == 16) {
      reshInput =
          rewriter.create<vector::ShapeCastOp>(loc, reshInputType, reshInput);
      res = transposeToShuffle16x16(rewriter, reshInput, m, n);
    } else {
      // Fall back to the shuffle-on-1D approach.
      res = transposeToShuffle1D(rewriter, reshInput, m, n);
    }

    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
        op, op.getResultVectorType(), res);

    return success();
  }

private:
  /// Options to control the vector patterns.
  vector::VectorTransposeLowering vectorTransposeLowering;
};
} // namespace
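
// Example use of this entry point (a sketch; the pass and driver setup are up
// to the caller and not defined in this file; `ctx` and `funcOp` are
// hypothetical):
//   RewritePatternSet patterns(ctx);
//   populateVectorTransposeLoweringPatterns(
//       patterns, VectorTransposeLowering::Shuffle16x16);
//   (void)applyPatternsGreedily(funcOp, std::move(patterns));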
void mlir::vector::populateVectorTransposeLoweringPatterns(
    RewritePatternSet &patterns,
    VectorTransposeLowering vectorTransposeLowering, PatternBenefit benefit) {
  patterns.add<Transpose2DWithUnitDimToShapeCast>(patterns.getContext(),
                                                  benefit);
  patterns.add<TransposeOpLowering, TransposeOp2DToShuffleLowering>(
      vectorTransposeLowering, patterns.getContext(), benefit);
}