//===- LowerVectorTranspose.cpp - Lower 'vector.transpose' operation ------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements target-independent rewrites and utilities to lower the
// 'vector.transpose' operation.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/UB/IR/UBOps.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"

#define DEBUG_TYPE "lower-vector-transpose"

using namespace mlir;
using namespace mlir::vector;

/// Given a 'transpose' pattern, prune the rightmost dimensions that are not
/// transposed.
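///
/// For example (illustrative, derived from the loop below): the pattern
/// [1, 0, 2, 3] leaves dims 2 and 3 in their original positions, so it is
/// pruned to [1, 0].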
static void pruneNonTransposedDims(ArrayRef<int64_t> transpose,
                                   SmallVectorImpl<int64_t> &result) {
  size_t numTransposedDims = transpose.size();
  for (size_t transpDim : llvm::reverse(transpose)) {
    if (transpDim != numTransposedDims - 1)
      break;
    numTransposedDims--;
  }

  result.append(transpose.begin(), transpose.begin() + numTransposedDims);
}

/// Returns true if the lowering option is a vector shuffle based approach.
static bool isShuffleLike(VectorTransposeLowering lowering) {
  return lowering == VectorTransposeLowering::Shuffle1D ||
         lowering == VectorTransposeLowering::Shuffle16x16;
}

/// Returns a shuffle mask that builds on `vals`. `vals` is the offset base of
/// shuffle ops, i.e., the unpack pattern. The method iterates with `vals` to
/// create the mask for a `numBits`-bit vector. `numBits` has to be a multiple
/// of 128. For example, if `vals` is {0, 1, 16, 17} and `numBits` is 512,
/// there should be 16 elements in the final result. It constructs the below
/// mask to get the unpack elements.
///   [0,    1,    16,    17,
///    0+4,  1+4,  16+4,  17+4,
///    0+8,  1+8,  16+8,  17+8,
///    0+12, 1+12, 16+12, 17+12]
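///
/// As another illustration (not in the original comment): with the same
/// `vals` = {0, 1, 16, 17} and `numBits` = 256, `numElem` is 8 and the
/// resulting mask is [0, 1, 16, 17, 4, 5, 20, 21].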
static SmallVector<int64_t>
getUnpackShufflePermFor128Lane(ArrayRef<int64_t> vals, int numBits) {
  assert(numBits % 128 == 0 && "expected numBits to be a multiple of 128");
  int numElem = numBits / 32;
  SmallVector<int64_t> res;
  for (int i = 0; i < numElem; i += 4)
    for (int64_t v : vals)
      res.push_back(v + i);
  return res;
}

/// Lowers to a vector.shuffle on v1 and v2 with the UnpackLoPd shuffle mask.
/// For example, if it is targeting a 512-bit vector, returns
///   vector.shuffle on v1, v2, [0,    1,    16,    17,
///                              0+4,  1+4,  16+4,  17+4,
///                              0+8,  1+8,  16+8,  17+8,
///                              0+12, 1+12, 16+12, 17+12].
static Value createUnpackLoPd(ImplicitLocOpBuilder &b, Value v1, Value v2,
                              int numBits) {
  int numElem = numBits / 32;
  return b.create<vector::ShuffleOp>(
      v1, v2,
      getUnpackShufflePermFor128Lane({0, 1, numElem, numElem + 1}, numBits));
}

/// Lowers to a vector.shuffle on v1 and v2 with the UnpackHiPd shuffle mask.
/// For example, if it is targeting a 512-bit vector, returns
///   vector.shuffle on v1, v2, [2,    3,    18,    19,
///                              2+4,  3+4,  18+4,  19+4,
///                              2+8,  3+8,  18+8,  19+8,
///                              2+12, 3+12, 18+12, 19+12].
static Value createUnpackHiPd(ImplicitLocOpBuilder &b, Value v1, Value v2,
                              int numBits) {
  int numElem = numBits / 32;
  return b.create<vector::ShuffleOp>(
      v1, v2,
      getUnpackShufflePermFor128Lane({2, 3, numElem + 2, numElem + 3},
                                     numBits));
}

/// Lowers to a vector.shuffle on v1 and v2 with the UnpackLoPs shuffle mask.
/// For example, if it is targeting a 512-bit vector, returns
///   vector.shuffle on v1, v2, [0,    16,    1,    17,
///                              0+4,  16+4,  1+4,  17+4,
///                              0+8,  16+8,  1+8,  17+8,
///                              0+12, 16+12, 1+12, 17+12].
static Value createUnpackLoPs(ImplicitLocOpBuilder &b, Value v1, Value v2,
                              int numBits) {
  int numElem = numBits / 32;
  return b.create<vector::ShuffleOp>(
      v1, v2,
      getUnpackShufflePermFor128Lane({0, numElem, 1, numElem + 1}, numBits));
}

/// Lowers to a vector.shuffle on v1 and v2 with the UnpackHiPs shuffle mask.
/// For example, if it is targeting a 512-bit vector, returns
///   vector.shuffle on v1, v2, [2,    18,    3,    19,
///                              2+4,  18+4,  3+4,  19+4,
///                              2+8,  18+8,  3+8,  19+8,
///                              2+12, 18+12, 3+12, 19+12].
static Value createUnpackHiPs(ImplicitLocOpBuilder &b, Value v1, Value v2,
                              int numBits) {
  int numElem = numBits / 32;
  return b.create<vector::ShuffleOp>(
      v1, v2,
      getUnpackShufflePermFor128Lane({2, numElem + 2, 3, numElem + 3},
                                     numBits));
}

/// Returns a vector.shuffle that shuffles 128-bit lanes (composed of 4 32-bit
/// elements) selected by `mask` from `v1` and `v2`. I.e.,
///
///   DEFINE SELECT4(src, control) {
///     CASE(control[1:0]) OF
///       0: tmp[127:0] := src[127:0]
///       1: tmp[127:0] := src[255:128]
///       2: tmp[127:0] := src[383:256]
///       3: tmp[127:0] := src[511:384]
///     ESAC
///     RETURN tmp[127:0]
///   }
///   dst[127:0]   := SELECT4(v1[511:0], mask[1:0])
///   dst[255:128] := SELECT4(v1[511:0], mask[3:2])
///   dst[383:256] := SELECT4(v2[511:0], mask[5:4])
///   dst[511:384] := SELECT4(v2[511:0], mask[7:6])
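///
/// For instance (a worked example, not in the original comment): mask = 0x88
/// decodes to control fields {0, 2, 0, 2}, so for 16-element inputs the
/// resulting shuffle mask is
///   [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27].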
static Value create4x128BitSuffle(ImplicitLocOpBuilder &b, Value v1, Value v2,
                                  uint8_t mask) {
  assert(cast<VectorType>(v1.getType()).getShape()[0] == 16 &&
         "expected a vector with length=16");
  SmallVector<int64_t> shuffleMask;
  auto appendToMask = [&](int64_t base, uint8_t control) {
    switch (control) {
    case 0:
      llvm::append_range(shuffleMask, ArrayRef<int64_t>{base + 0, base + 1,
                                                        base + 2, base + 3});
      break;
    case 1:
      llvm::append_range(shuffleMask, ArrayRef<int64_t>{base + 4, base + 5,
                                                        base + 6, base + 7});
      break;
    case 2:
      llvm::append_range(shuffleMask, ArrayRef<int64_t>{base + 8, base + 9,
                                                        base + 10, base + 11});
      break;
    case 3:
      llvm::append_range(shuffleMask, ArrayRef<int64_t>{base + 12, base + 13,
                                                        base + 14, base + 15});
      break;
    default:
      llvm_unreachable("control > 3 : overflow");
    }
  };
  uint8_t b01 = mask & 0x3;
  uint8_t b23 = (mask >> 2) & 0x3;
  uint8_t b45 = (mask >> 4) & 0x3;
  uint8_t b67 = (mask >> 6) & 0x3;
  appendToMask(0, b01);
  appendToMask(0, b23);
  appendToMask(16, b45);
  appendToMask(16, b67);
  return b.create<vector::ShuffleOp>(v1, v2, shuffleMask);
}

/// Lowers the value to a vector.shuffle op. The `source` is expected to be a
/// 1-D vector and have `m`x`n` elements.
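///
/// For example (illustrative): with m = 3 and n = 2 the mask computed below is
/// [0, 2, 4, 1, 3, 5], which transposes a row-major 3x2 layout into 2x3.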
static Value transposeToShuffle1D(OpBuilder &b, Value source, int m, int n) {
  SmallVector<int64_t> mask;
  mask.reserve(m * n);
  for (int64_t j = 0; j < n; ++j)
    for (int64_t i = 0; i < m; ++i)
      mask.push_back(i * n + j);
  return b.create<vector::ShuffleOp>(source.getLoc(), source, source, mask);
}

/// Lowers the value to a sequence of vector.shuffle ops. The `source` is
/// expected to be a 16x16 vector.
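///
/// Note (derived from the code below): the full 16x16 transpose emits 16
/// vector.extract ops, 64 vector.shuffle ops (four stages of 16), and 16
/// vector.insert ops.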
static Value transposeToShuffle16x16(OpBuilder &builder, Value source, int m,
                                     int n) {
  ImplicitLocOpBuilder b(source.getLoc(), builder);
  SmallVector<Value> vs;
  for (int64_t i = 0; i < m; ++i)
    vs.push_back(b.createOrFold<vector::ExtractOp>(source, i));

  // Interleave 32-bit lanes using
  //   8x _mm512_unpacklo_epi32
  //   8x _mm512_unpackhi_epi32
  Value t0 = createUnpackLoPs(b, vs[0x0], vs[0x1], 512);
  Value t1 = createUnpackHiPs(b, vs[0x0], vs[0x1], 512);
  Value t2 = createUnpackLoPs(b, vs[0x2], vs[0x3], 512);
  Value t3 = createUnpackHiPs(b, vs[0x2], vs[0x3], 512);
  Value t4 = createUnpackLoPs(b, vs[0x4], vs[0x5], 512);
  Value t5 = createUnpackHiPs(b, vs[0x4], vs[0x5], 512);
  Value t6 = createUnpackLoPs(b, vs[0x6], vs[0x7], 512);
  Value t7 = createUnpackHiPs(b, vs[0x6], vs[0x7], 512);
  Value t8 = createUnpackLoPs(b, vs[0x8], vs[0x9], 512);
  Value t9 = createUnpackHiPs(b, vs[0x8], vs[0x9], 512);
  Value ta = createUnpackLoPs(b, vs[0xa], vs[0xb], 512);
  Value tb = createUnpackHiPs(b, vs[0xa], vs[0xb], 512);
  Value tc = createUnpackLoPs(b, vs[0xc], vs[0xd], 512);
  Value td = createUnpackHiPs(b, vs[0xc], vs[0xd], 512);
  Value te = createUnpackLoPs(b, vs[0xe], vs[0xf], 512);
  Value tf = createUnpackHiPs(b, vs[0xe], vs[0xf], 512);

  // Interleave 64-bit lanes using
  //   8x _mm512_unpacklo_epi64
  //   8x _mm512_unpackhi_epi64
  Value r0 = createUnpackLoPd(b, t0, t2, 512);
  Value r1 = createUnpackHiPd(b, t0, t2, 512);
  Value r2 = createUnpackLoPd(b, t1, t3, 512);
  Value r3 = createUnpackHiPd(b, t1, t3, 512);
  Value r4 = createUnpackLoPd(b, t4, t6, 512);
  Value r5 = createUnpackHiPd(b, t4, t6, 512);
  Value r6 = createUnpackLoPd(b, t5, t7, 512);
  Value r7 = createUnpackHiPd(b, t5, t7, 512);
  Value r8 = createUnpackLoPd(b, t8, ta, 512);
  Value r9 = createUnpackHiPd(b, t8, ta, 512);
  Value ra = createUnpackLoPd(b, t9, tb, 512);
  Value rb = createUnpackHiPd(b, t9, tb, 512);
  Value rc = createUnpackLoPd(b, tc, te, 512);
  Value rd = createUnpackHiPd(b, tc, te, 512);
  Value re = createUnpackLoPd(b, td, tf, 512);
  Value rf = createUnpackHiPd(b, td, tf, 512);

  // Permute 128-bit lanes using
  //   16x _mm512_shuffle_i32x4
  t0 = create4x128BitSuffle(b, r0, r4, 0x88);
  t1 = create4x128BitSuffle(b, r1, r5, 0x88);
  t2 = create4x128BitSuffle(b, r2, r6, 0x88);
  t3 = create4x128BitSuffle(b, r3, r7, 0x88);
  t4 = create4x128BitSuffle(b, r0, r4, 0xdd);
  t5 = create4x128BitSuffle(b, r1, r5, 0xdd);
  t6 = create4x128BitSuffle(b, r2, r6, 0xdd);
  t7 = create4x128BitSuffle(b, r3, r7, 0xdd);
  t8 = create4x128BitSuffle(b, r8, rc, 0x88);
  t9 = create4x128BitSuffle(b, r9, rd, 0x88);
  ta = create4x128BitSuffle(b, ra, re, 0x88);
  tb = create4x128BitSuffle(b, rb, rf, 0x88);
  tc = create4x128BitSuffle(b, r8, rc, 0xdd);
  td = create4x128BitSuffle(b, r9, rd, 0xdd);
  te = create4x128BitSuffle(b, ra, re, 0xdd);
  tf = create4x128BitSuffle(b, rb, rf, 0xdd);

  // Permute 256-bit lanes using again
  //   16x _mm512_shuffle_i32x4
  vs[0x0] = create4x128BitSuffle(b, t0, t8, 0x88);
  vs[0x1] = create4x128BitSuffle(b, t1, t9, 0x88);
  vs[0x2] = create4x128BitSuffle(b, t2, ta, 0x88);
  vs[0x3] = create4x128BitSuffle(b, t3, tb, 0x88);
  vs[0x4] = create4x128BitSuffle(b, t4, tc, 0x88);
  vs[0x5] = create4x128BitSuffle(b, t5, td, 0x88);
  vs[0x6] = create4x128BitSuffle(b, t6, te, 0x88);
  vs[0x7] = create4x128BitSuffle(b, t7, tf, 0x88);
  vs[0x8] = create4x128BitSuffle(b, t0, t8, 0xdd);
  vs[0x9] = create4x128BitSuffle(b, t1, t9, 0xdd);
  vs[0xa] = create4x128BitSuffle(b, t2, ta, 0xdd);
  vs[0xb] = create4x128BitSuffle(b, t3, tb, 0xdd);
  vs[0xc] = create4x128BitSuffle(b, t4, tc, 0xdd);
  vs[0xd] = create4x128BitSuffle(b, t5, td, 0xdd);
  vs[0xe] = create4x128BitSuffle(b, t6, te, 0xdd);
  vs[0xf] = create4x128BitSuffle(b, t7, tf, 0xdd);

  auto reshInputType = VectorType::get(
      {m, n}, cast<VectorType>(source.getType()).getElementType());
  Value res = b.create<ub::PoisonOp>(reshInputType);
  for (int64_t i = 0; i < m; ++i)
    res = b.create<vector::InsertOp>(vs[i], res, i);
  return res;
}

namespace {
/// Progressive lowering of TransposeOp.
/// One:
///   %x = vector.transpose %y, [1, 0]
/// is replaced by:
///   %z = ub.poison
///   %0 = vector.extract %y[0, 0]
///   %1 = vector.insert %0, %z [0, 0]
///   ..
///   %x = vector.insert .., .. [.., ..]
class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  TransposeOpLowering(vector::VectorTransposeLowering vectorTransposeLowering,
                      MLIRContext *context, PatternBenefit benefit = 1)
      : OpRewritePattern<vector::TransposeOp>(context, benefit),
        vectorTransposeLowering(vectorTransposeLowering) {}

  LogicalResult matchAndRewrite(vector::TransposeOp op,
                                PatternRewriter &rewriter) const override {
    auto loc = op.getLoc();

    Value input = op.getVector();
    VectorType inputType = op.getSourceVectorType();
    VectorType resType = op.getResultVectorType();

    if (inputType.isScalable())
      return rewriter.notifyMatchFailure(
          op, "This lowering does not support scalable vectors");

    // Set up convenience transposition table.
    ArrayRef<int64_t> transp = op.getPermutation();

    if (isShuffleLike(vectorTransposeLowering) &&
        succeeded(isTranspose2DSlice(op)))
      return rewriter.notifyMatchFailure(
          op, "Options specify lowering to shuffle");

    // Handle a true 2-D matrix transpose differently when requested.
    if (vectorTransposeLowering == vector::VectorTransposeLowering::Flat &&
        resType.getRank() == 2 && transp[0] == 1 && transp[1] == 0) {
      Type flattenedType =
          VectorType::get(resType.getNumElements(), resType.getElementType());
      auto matrix =
          rewriter.create<vector::ShapeCastOp>(loc, flattenedType, input);
      auto rows = rewriter.getI32IntegerAttr(resType.getShape()[0]);
      auto columns = rewriter.getI32IntegerAttr(resType.getShape()[1]);
      Value trans = rewriter.create<vector::FlatTransposeOp>(
          loc, flattenedType, matrix, rows, columns);
      rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, resType, trans);
      return success();
    }

    // Generate unrolled extract/insert ops. We do not unroll the rightmost
    // (i.e., highest-order) dimensions that are not transposed and leave them
    // in vector form to improve performance. Therefore, we prune those
    // dimensions from the shape/transpose data structures used to generate the
    // extract/insert ops.
    SmallVector<int64_t> prunedTransp;
    pruneNonTransposedDims(transp, prunedTransp);
    size_t numPrunedDims = transp.size() - prunedTransp.size();
    auto prunedInShape = inputType.getShape().drop_back(numPrunedDims);
    auto prunedInStrides = computeStrides(prunedInShape);

    // Generate the extract/insert operations for every scalar/vector element
    // of the leftmost transposed dimensions. We traverse every transpose
    // element using a linearized index that we delinearize to generate the
    // appropriate indices for the extract/insert operations.
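    // For example (illustrative, assuming an input of vector<2x3x4xf32> and
    // permutation [1, 0, 2]): the pruned shape is [2, 3] with strides [3, 1];
    // linearIdx = 5 delinearizes to extract indices [1, 2], which the
    // permutation maps to insert indices [2, 1], moving a whole vector<4xf32>
    // slice per iteration.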
    Value result = rewriter.create<ub::PoisonOp>(loc, resType);
    int64_t numTransposedElements = ShapedType::getNumElements(prunedInShape);

    for (int64_t linearIdx = 0; linearIdx < numTransposedElements;
         ++linearIdx) {
      auto extractIdxs = delinearize(linearIdx, prunedInStrides);
      SmallVector<int64_t> insertIdxs(extractIdxs);
      applyPermutationToVector(insertIdxs, prunedTransp);
      Value extractOp =
          rewriter.createOrFold<vector::ExtractOp>(loc, input, extractIdxs);
      result = rewriter.createOrFold<vector::InsertOp>(loc, extractOp, result,
                                                       insertIdxs);
    }

    rewriter.replaceOp(op, result);
    return success();
  }

private:
  /// Options to control the vector patterns.
  vector::VectorTransposeLowering vectorTransposeLowering;
};

/// Rewrites vector.transpose as vector.shape_cast. This pattern is only
/// applied to 2D vectors with at least one unit dim. For example:
///
/// Replace:
///   vector.transpose %0, [1, 0] : vector<4x1xi32> to
///                                 vector<1x4xi32>
/// with:
///   vector.shape_cast %0 : vector<4x1xi32> to vector<1x4xi32>
///
/// The inverse, a source with a leading unit dim, is also replaced. The unit
/// dim must be fixed; the non-unit dim can be scalable.
///
/// TODO: This pattern was introduced specifically to help lower scalable
/// vectors. In hindsight, a more specialised canonicalization (for shape_cast
/// ops to cancel out) would be preferable:
///
///   BEFORE:
///     %0 = some_op
///     %1 = vector.shape_cast %0 : vector<[4]xf32> to vector<[4]x1xf32>
///     %2 = vector.transpose %1 [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32>
///   AFTER:
///     %0 = some_op
///     %1 = vector.shape_cast %0 : vector<[4]xf32> to vector<1x[4]xf32>
///
/// Given the context above, we may want to consider (re-)moving this pattern
/// at some later time. I am leaving it for now in case there are other users
/// that I am not aware of.
class Transpose2DWithUnitDimToShapeCast
    : public OpRewritePattern<vector::TransposeOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  Transpose2DWithUnitDimToShapeCast(MLIRContext *context,
                                    PatternBenefit benefit = 1)
      : OpRewritePattern<vector::TransposeOp>(context, benefit) {}

  LogicalResult matchAndRewrite(vector::TransposeOp op,
                                PatternRewriter &rewriter) const override {
    Value input = op.getVector();
    VectorType resType = op.getResultVectorType();

    // Set up convenience transposition table.
    ArrayRef<int64_t> transp = op.getPermutation();

    if (resType.getRank() == 2 &&
        ((resType.getShape().front() == 1 &&
          !resType.getScalableDims().front()) ||
         (resType.getShape().back() == 1 &&
          !resType.getScalableDims().back())) &&
        transp == ArrayRef<int64_t>({1, 0})) {
      rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, resType, input);
      return success();
    }

    return failure();
  }
};

/// Rewrites a 2-D vector.transpose as a sequence of shuffle ops.
/// If the strategy is Shuffle1D, it will be lowered to:
///   vector.shape_cast 2D -> 1D
///   vector.shuffle
///   vector.shape_cast 1D -> 2D
/// If the strategy is Shuffle16x16, it will be lowered to a sequence of
/// shuffle ops on 16xf32 vectors.
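///
/// For example (a sketch of the Shuffle1D path), transposing a
/// vector<2x3xf32> becomes:
///   %0 = vector.shape_cast %v : vector<2x3xf32> to vector<6xf32>
///   %1 = vector.shuffle %0, %0 [0, 3, 1, 4, 2, 5]
///       : vector<6xf32>, vector<6xf32>
///   %2 = vector.shape_cast %1 : vector<6xf32> to vector<3x2xf32>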
class TransposeOp2DToShuffleLowering
    : public OpRewritePattern<vector::TransposeOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  TransposeOp2DToShuffleLowering(
      vector::VectorTransposeLowering vectorTransposeLowering,
      MLIRContext *context, PatternBenefit benefit = 1)
      : OpRewritePattern<vector::TransposeOp>(context, benefit),
        vectorTransposeLowering(vectorTransposeLowering) {}

  LogicalResult matchAndRewrite(vector::TransposeOp op,
                                PatternRewriter &rewriter) const override {
    if (!isShuffleLike(vectorTransposeLowering))
      return rewriter.notifyMatchFailure(
          op, "not using vector shuffle based lowering");

    if (op.getSourceVectorType().isScalable())
      return rewriter.notifyMatchFailure(
          op, "vector shuffle lowering not supported for scalable vectors");

    auto srcGtOneDims = isTranspose2DSlice(op);
    if (failed(srcGtOneDims))
      return rewriter.notifyMatchFailure(
          op, "expected transposition on a 2D slice");

    VectorType srcType = op.getSourceVectorType();
    int64_t m = srcType.getDimSize(std::get<0>(srcGtOneDims.value()));
    int64_t n = srcType.getDimSize(std::get<1>(srcGtOneDims.value()));

    // Reshape the n-D input vector with only two dimensions greater than one
    // to a 2-D vector.
    Location loc = op.getLoc();
    auto flattenedType = VectorType::get({n * m}, srcType.getElementType());
    auto reshInputType = VectorType::get({m, n}, srcType.getElementType());
    auto reshInput = rewriter.create<vector::ShapeCastOp>(loc, flattenedType,
                                                          op.getVector());

    Value res;
    if (vectorTransposeLowering == VectorTransposeLowering::Shuffle16x16 &&
        m == 16 && n == 16) {
      reshInput =
          rewriter.create<vector::ShapeCastOp>(loc, reshInputType, reshInput);
      res = transposeToShuffle16x16(rewriter, reshInput, m, n);
    } else {
      // Fall back to the shuffle-on-1-D approach.
      res = transposeToShuffle1D(rewriter, reshInput, m, n);
    }

    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
        op, op.getResultVectorType(), res);

    return success();
  }

private:
  /// Options to control the vector patterns.
  vector::VectorTransposeLowering vectorTransposeLowering;
};
} // namespace

void mlir::vector::populateVectorTransposeLoweringPatterns(
    RewritePatternSet &patterns,
    VectorTransposeLowering vectorTransposeLowering, PatternBenefit benefit) {
  patterns.add<Transpose2DWithUnitDimToShapeCast>(patterns.getContext(),
                                                  benefit);
  patterns.add<TransposeOpLowering, TransposeOp2DToShuffleLowering>(
      vectorTransposeLowering, patterns.getContext(), benefit);
}
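
// Example usage (a sketch, not part of the original file; the exact greedy
// driver entry point depends on the MLIR version):
//   RewritePatternSet patterns(ctx);
//   vector::populateVectorTransposeLoweringPatterns(
//       patterns, VectorTransposeLowering::Shuffle16x16);
//   (void)applyPatternsAndFoldGreedily(op, std::move(patterns));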