1 | //===- LowerVectorTranspose.cpp - Lower 'vector.transpose' operation ------===// |
---|---|
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This file implements target-independent rewrites and utilities to lower the |
10 | // 'vector.transpose' operation. |
11 | // |
12 | //===----------------------------------------------------------------------===// |
13 | |
14 | #include "mlir/Dialect/Arith/IR/Arith.h" |
15 | #include "mlir/Dialect/MemRef/IR/MemRef.h" |
16 | #include "mlir/Dialect/UB/IR/UBOps.h" |
17 | #include "mlir/Dialect/Utils/IndexingUtils.h" |
18 | #include "mlir/Dialect/Utils/StructuredOpsUtils.h" |
19 | #include "mlir/Dialect/Vector/IR/VectorOps.h" |
20 | #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h" |
21 | #include "mlir/Dialect/Vector/Utils/VectorUtils.h" |
22 | #include "mlir/IR/BuiltinTypes.h" |
23 | #include "mlir/IR/ImplicitLocOpBuilder.h" |
24 | #include "mlir/IR/Location.h" |
25 | #include "mlir/IR/PatternMatch.h" |
26 | #include "mlir/IR/TypeUtilities.h" |
27 | |
28 | #define DEBUG_TYPE "lower-vector-transpose" |
29 | |
30 | using namespace mlir; |
31 | using namespace mlir::vector; |
32 | |
33 | /// Given a 'transpose' pattern, prune the rightmost dimensions that are not |
34 | /// transposed. |
35 | static void pruneNonTransposedDims(ArrayRef<int64_t> transpose, |
36 | SmallVectorImpl<int64_t> &result) { |
37 | size_t numTransposedDims = transpose.size(); |
38 | for (size_t transpDim : llvm::reverse(C&: transpose)) { |
39 | if (transpDim != numTransposedDims - 1) |
40 | break; |
41 | numTransposedDims--; |
42 | } |
43 | |
44 | result.append(in_start: transpose.begin(), in_end: transpose.begin() + numTransposedDims); |
45 | } |
46 | |
47 | /// Returns true if the lowering option is a vector shuffle based approach. |
48 | static bool isShuffleLike(VectorTransposeLowering lowering) { |
49 | return lowering == VectorTransposeLowering::Shuffle1D || |
50 | lowering == VectorTransposeLowering::Shuffle16x16; |
51 | } |
52 | |
53 | /// Returns a shuffle mask that builds on `vals`. `vals` is the offset base of |
54 | /// shuffle ops, i.e., the unpack pattern. The method iterates with `vals` to |
55 | /// create the mask for `numBits` bits vector. The `numBits` have to be a |
56 | /// multiple of 128. For example, if `vals` is {0, 1, 16, 17} and `numBits` is |
57 | /// 512, there should be 16 elements in the final result. It constructs the |
58 | /// below mask to get the unpack elements. |
59 | /// [0, 1, 16, 17, |
60 | /// 0+4, 1+4, 16+4, 17+4, |
61 | /// 0+8, 1+8, 16+8, 17+8, |
62 | /// 0+12, 1+12, 16+12, 17+12] |
63 | static SmallVector<int64_t> |
64 | getUnpackShufflePermFor128Lane(ArrayRef<int64_t> vals, int numBits) { |
65 | assert(numBits % 128 == 0 && "expected numBits is a multiple of 128"); |
66 | int numElem = numBits / 32; |
67 | SmallVector<int64_t> res; |
68 | for (int i = 0; i < numElem; i += 4) |
69 | for (int64_t v : vals) |
70 | res.push_back(Elt: v + i); |
71 | return res; |
72 | } |
73 | |
74 | /// Lower to vector.shuffle on v1 and v2 with UnpackLoPd shuffle mask. For |
75 | /// example, if it is targeting 512 bit vector, returns |
76 | /// vector.shuffle on v1, v2, [0, 1, 16, 17, |
77 | /// 0+4, 1+4, 16+4, 17+4, |
78 | /// 0+8, 1+8, 16+8, 17+8, |
79 | /// 0+12, 1+12, 16+12, 17+12]. |
80 | static Value createUnpackLoPd(ImplicitLocOpBuilder &b, Value v1, Value v2, |
81 | int numBits) { |
82 | int numElem = numBits / 32; |
83 | return b.create<vector::ShuffleOp>( |
84 | v1, v2, |
85 | getUnpackShufflePermFor128Lane({0, 1, numElem, numElem + 1}, numBits)); |
86 | } |
87 | |
88 | /// Lower to vector.shuffle on v1 and v2 with UnpackHiPd shuffle mask. For |
89 | /// example, if it is targeting 512 bit vector, returns |
90 | /// vector.shuffle, v1, v2, [2, 3, 18, 19, |
91 | /// 2+4, 3+4, 18+4, 19+4, |
92 | /// 2+8, 3+8, 18+8, 19+8, |
93 | /// 2+12, 3+12, 18+12, 19+12]. |
94 | static Value createUnpackHiPd(ImplicitLocOpBuilder &b, Value v1, Value v2, |
95 | int numBits) { |
96 | int numElem = numBits / 32; |
97 | return b.create<vector::ShuffleOp>( |
98 | v1, v2, |
99 | getUnpackShufflePermFor128Lane({2, 3, numElem + 2, numElem + 3}, |
100 | numBits)); |
101 | } |
102 | |
103 | /// Lower to vector.shuffle on v1 and v2 with UnpackLoPs shuffle mask. For |
104 | /// example, if it is targeting 512 bit vector, returns |
105 | /// vector.shuffle, v1, v2, [0, 16, 1, 17, |
106 | /// 0+4, 16+4, 1+4, 17+4, |
107 | /// 0+8, 16+8, 1+8, 17+8, |
108 | /// 0+12, 16+12, 1+12, 17+12]. |
109 | static Value createUnpackLoPs(ImplicitLocOpBuilder &b, Value v1, Value v2, |
110 | int numBits) { |
111 | int numElem = numBits / 32; |
112 | auto shuffle = b.create<vector::ShuffleOp>( |
113 | v1, v2, |
114 | getUnpackShufflePermFor128Lane({0, numElem, 1, numElem + 1}, numBits)); |
115 | return shuffle; |
116 | } |
117 | |
118 | /// Lower to vector.shuffle on v1 and v2 with UnpackHiPs shuffle mask. For |
119 | /// example, if it is targeting 512 bit vector, returns |
120 | /// vector.shuffle, v1, v2, [2, 18, 3, 19, |
121 | /// 2+4, 18+4, 3+4, 19+4, |
122 | /// 2+8, 18+8, 3+8, 19+8, |
123 | /// 2+12, 18+12, 3+12, 19+12]. |
124 | static Value createUnpackHiPs(ImplicitLocOpBuilder &b, Value v1, Value v2, |
125 | int numBits) { |
126 | int numElem = numBits / 32; |
127 | return b.create<vector::ShuffleOp>( |
128 | v1, v2, |
129 | getUnpackShufflePermFor128Lane({2, numElem + 2, 3, numElem + 3}, |
130 | numBits)); |
131 | } |
132 | |
133 | /// Returns a vector.shuffle that shuffles 128-bit lanes (composed of 4 32-bit |
134 | /// elements) selected by `mask` from `v1` and `v2`. I.e., |
135 | /// |
136 | /// DEFINE SELECT4(src, control) { |
137 | /// CASE(control[1:0]) OF |
138 | /// 0: tmp[127:0] := src[127:0] |
139 | /// 1: tmp[127:0] := src[255:128] |
140 | /// 2: tmp[127:0] := src[383:256] |
141 | /// 3: tmp[127:0] := src[511:384] |
142 | /// ESAC |
143 | /// RETURN tmp[127:0] |
144 | /// } |
145 | /// dst[127:0] := SELECT4(v1[511:0], mask[1:0]) |
146 | /// dst[255:128] := SELECT4(v1[511:0], mask[3:2]) |
147 | /// dst[383:256] := SELECT4(v2[511:0], mask[5:4]) |
148 | /// dst[511:384] := SELECT4(v2[511:0], mask[7:6]) |
149 | static Value create4x128BitSuffle(ImplicitLocOpBuilder &b, Value v1, Value v2, |
150 | uint8_t mask) { |
151 | assert(cast<VectorType>(v1.getType()).getShape()[0] == 16 && |
152 | "expected a vector with length=16"); |
153 | SmallVector<int64_t> shuffleMask; |
154 | auto appendToMask = [&](int64_t base, uint8_t control) { |
155 | switch (control) { |
156 | case 0: |
157 | llvm::append_range(C&: shuffleMask, R: ArrayRef<int64_t>{base + 0, base + 1, |
158 | base + 2, base + 3}); |
159 | break; |
160 | case 1: |
161 | llvm::append_range(C&: shuffleMask, R: ArrayRef<int64_t>{base + 4, base + 5, |
162 | base + 6, base + 7}); |
163 | break; |
164 | case 2: |
165 | llvm::append_range(C&: shuffleMask, R: ArrayRef<int64_t>{base + 8, base + 9, |
166 | base + 10, base + 11}); |
167 | break; |
168 | case 3: |
169 | llvm::append_range(C&: shuffleMask, R: ArrayRef<int64_t>{base + 12, base + 13, |
170 | base + 14, base + 15}); |
171 | break; |
172 | default: |
173 | llvm_unreachable("control > 3 : overflow"); |
174 | } |
175 | }; |
176 | uint8_t b01 = mask & 0x3; |
177 | uint8_t b23 = (mask >> 2) & 0x3; |
178 | uint8_t b45 = (mask >> 4) & 0x3; |
179 | uint8_t b67 = (mask >> 6) & 0x3; |
180 | appendToMask(0, b01); |
181 | appendToMask(0, b23); |
182 | appendToMask(16, b45); |
183 | appendToMask(16, b67); |
184 | return b.create<vector::ShuffleOp>(v1, v2, shuffleMask); |
185 | } |
186 | |
187 | /// Lowers the value to a vector.shuffle op. The `source` is expected to be a |
188 | /// 1-D vector and have `m`x`n` elements. |
189 | static Value transposeToShuffle1D(OpBuilder &b, Value source, int m, int n) { |
190 | SmallVector<int64_t> mask; |
191 | mask.reserve(N: m * n); |
192 | for (int64_t j = 0; j < n; ++j) |
193 | for (int64_t i = 0; i < m; ++i) |
194 | mask.push_back(Elt: i * n + j); |
195 | return b.create<vector::ShuffleOp>(source.getLoc(), source, source, mask); |
196 | } |
197 | |
198 | /// Lowers the value to a sequence of vector.shuffle ops. The `source` is |
199 | /// expected to be a 16x16 vector. |
200 | static Value transposeToShuffle16x16(OpBuilder &builder, Value source, int m, |
201 | int n) { |
202 | ImplicitLocOpBuilder b(source.getLoc(), builder); |
203 | SmallVector<Value> vs; |
204 | for (int64_t i = 0; i < m; ++i) |
205 | vs.push_back(b.createOrFold<vector::ExtractOp>(source, i)); |
206 | |
207 | // Interleave 32-bit lanes using |
208 | // 8x _mm512_unpacklo_epi32 |
209 | // 8x _mm512_unpackhi_epi32 |
210 | Value t0 = createUnpackLoPs(b, v1: vs[0x0], v2: vs[0x1], numBits: 512); |
211 | Value t1 = createUnpackHiPs(b, v1: vs[0x0], v2: vs[0x1], numBits: 512); |
212 | Value t2 = createUnpackLoPs(b, v1: vs[0x2], v2: vs[0x3], numBits: 512); |
213 | Value t3 = createUnpackHiPs(b, v1: vs[0x2], v2: vs[0x3], numBits: 512); |
214 | Value t4 = createUnpackLoPs(b, v1: vs[0x4], v2: vs[0x5], numBits: 512); |
215 | Value t5 = createUnpackHiPs(b, v1: vs[0x4], v2: vs[0x5], numBits: 512); |
216 | Value t6 = createUnpackLoPs(b, v1: vs[0x6], v2: vs[0x7], numBits: 512); |
217 | Value t7 = createUnpackHiPs(b, v1: vs[0x6], v2: vs[0x7], numBits: 512); |
218 | Value t8 = createUnpackLoPs(b, v1: vs[0x8], v2: vs[0x9], numBits: 512); |
219 | Value t9 = createUnpackHiPs(b, v1: vs[0x8], v2: vs[0x9], numBits: 512); |
220 | Value ta = createUnpackLoPs(b, v1: vs[0xa], v2: vs[0xb], numBits: 512); |
221 | Value tb = createUnpackHiPs(b, v1: vs[0xa], v2: vs[0xb], numBits: 512); |
222 | Value tc = createUnpackLoPs(b, v1: vs[0xc], v2: vs[0xd], numBits: 512); |
223 | Value td = createUnpackHiPs(b, v1: vs[0xc], v2: vs[0xd], numBits: 512); |
224 | Value te = createUnpackLoPs(b, v1: vs[0xe], v2: vs[0xf], numBits: 512); |
225 | Value tf = createUnpackHiPs(b, v1: vs[0xe], v2: vs[0xf], numBits: 512); |
226 | |
227 | // Interleave 64-bit lanes using |
228 | // 8x _mm512_unpacklo_epi64 |
229 | // 8x _mm512_unpackhi_epi64 |
230 | Value r0 = createUnpackLoPd(b, v1: t0, v2: t2, numBits: 512); |
231 | Value r1 = createUnpackHiPd(b, v1: t0, v2: t2, numBits: 512); |
232 | Value r2 = createUnpackLoPd(b, v1: t1, v2: t3, numBits: 512); |
233 | Value r3 = createUnpackHiPd(b, v1: t1, v2: t3, numBits: 512); |
234 | Value r4 = createUnpackLoPd(b, v1: t4, v2: t6, numBits: 512); |
235 | Value r5 = createUnpackHiPd(b, v1: t4, v2: t6, numBits: 512); |
236 | Value r6 = createUnpackLoPd(b, v1: t5, v2: t7, numBits: 512); |
237 | Value r7 = createUnpackHiPd(b, v1: t5, v2: t7, numBits: 512); |
238 | Value r8 = createUnpackLoPd(b, v1: t8, v2: ta, numBits: 512); |
239 | Value r9 = createUnpackHiPd(b, v1: t8, v2: ta, numBits: 512); |
240 | Value ra = createUnpackLoPd(b, v1: t9, v2: tb, numBits: 512); |
241 | Value rb = createUnpackHiPd(b, v1: t9, v2: tb, numBits: 512); |
242 | Value rc = createUnpackLoPd(b, v1: tc, v2: te, numBits: 512); |
243 | Value rd = createUnpackHiPd(b, v1: tc, v2: te, numBits: 512); |
244 | Value re = createUnpackLoPd(b, v1: td, v2: tf, numBits: 512); |
245 | Value rf = createUnpackHiPd(b, v1: td, v2: tf, numBits: 512); |
246 | |
247 | // Permute 128-bit lanes using |
248 | // 16x _mm512_shuffle_i32x4 |
249 | t0 = create4x128BitSuffle(b, v1: r0, v2: r4, mask: 0x88); |
250 | t1 = create4x128BitSuffle(b, v1: r1, v2: r5, mask: 0x88); |
251 | t2 = create4x128BitSuffle(b, v1: r2, v2: r6, mask: 0x88); |
252 | t3 = create4x128BitSuffle(b, v1: r3, v2: r7, mask: 0x88); |
253 | t4 = create4x128BitSuffle(b, v1: r0, v2: r4, mask: 0xdd); |
254 | t5 = create4x128BitSuffle(b, v1: r1, v2: r5, mask: 0xdd); |
255 | t6 = create4x128BitSuffle(b, v1: r2, v2: r6, mask: 0xdd); |
256 | t7 = create4x128BitSuffle(b, v1: r3, v2: r7, mask: 0xdd); |
257 | t8 = create4x128BitSuffle(b, v1: r8, v2: rc, mask: 0x88); |
258 | t9 = create4x128BitSuffle(b, v1: r9, v2: rd, mask: 0x88); |
259 | ta = create4x128BitSuffle(b, v1: ra, v2: re, mask: 0x88); |
260 | tb = create4x128BitSuffle(b, v1: rb, v2: rf, mask: 0x88); |
261 | tc = create4x128BitSuffle(b, v1: r8, v2: rc, mask: 0xdd); |
262 | td = create4x128BitSuffle(b, v1: r9, v2: rd, mask: 0xdd); |
263 | te = create4x128BitSuffle(b, v1: ra, v2: re, mask: 0xdd); |
264 | tf = create4x128BitSuffle(b, v1: rb, v2: rf, mask: 0xdd); |
265 | |
266 | // Permute 256-bit lanes using again |
267 | // 16x _mm512_shuffle_i32x4 |
268 | vs[0x0] = create4x128BitSuffle(b, v1: t0, v2: t8, mask: 0x88); |
269 | vs[0x1] = create4x128BitSuffle(b, v1: t1, v2: t9, mask: 0x88); |
270 | vs[0x2] = create4x128BitSuffle(b, v1: t2, v2: ta, mask: 0x88); |
271 | vs[0x3] = create4x128BitSuffle(b, v1: t3, v2: tb, mask: 0x88); |
272 | vs[0x4] = create4x128BitSuffle(b, v1: t4, v2: tc, mask: 0x88); |
273 | vs[0x5] = create4x128BitSuffle(b, v1: t5, v2: td, mask: 0x88); |
274 | vs[0x6] = create4x128BitSuffle(b, v1: t6, v2: te, mask: 0x88); |
275 | vs[0x7] = create4x128BitSuffle(b, v1: t7, v2: tf, mask: 0x88); |
276 | vs[0x8] = create4x128BitSuffle(b, v1: t0, v2: t8, mask: 0xdd); |
277 | vs[0x9] = create4x128BitSuffle(b, v1: t1, v2: t9, mask: 0xdd); |
278 | vs[0xa] = create4x128BitSuffle(b, v1: t2, v2: ta, mask: 0xdd); |
279 | vs[0xb] = create4x128BitSuffle(b, v1: t3, v2: tb, mask: 0xdd); |
280 | vs[0xc] = create4x128BitSuffle(b, v1: t4, v2: tc, mask: 0xdd); |
281 | vs[0xd] = create4x128BitSuffle(b, v1: t5, v2: td, mask: 0xdd); |
282 | vs[0xe] = create4x128BitSuffle(b, v1: t6, v2: te, mask: 0xdd); |
283 | vs[0xf] = create4x128BitSuffle(b, v1: t7, v2: tf, mask: 0xdd); |
284 | |
285 | auto reshInputType = VectorType::get( |
286 | {m, n}, cast<VectorType>(source.getType()).getElementType()); |
287 | Value res = b.create<ub::PoisonOp>(reshInputType); |
288 | for (int64_t i = 0; i < m; ++i) |
289 | res = b.create<vector::InsertOp>(vs[i], res, i); |
290 | return res; |
291 | } |
292 | |
293 | namespace { |
294 | /// Progressive lowering of TransposeOp. |
295 | /// One: |
296 | /// %x = vector.transpose %y, [1, 0] |
297 | /// is replaced by: |
298 | /// %z = arith.constant dense<0.000000e+00> |
299 | /// %0 = vector.extract %y[0, 0] |
300 | /// %1 = vector.insert %0, %z [0, 0] |
301 | /// .. |
302 | /// %x = vector.insert .., .. [.., ..] |
303 | class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> { |
304 | public: |
305 | using OpRewritePattern::OpRewritePattern; |
306 | |
307 | TransposeOpLowering(vector::VectorTransposeLowering vectorTransposeLowering, |
308 | MLIRContext *context, PatternBenefit benefit = 1) |
309 | : OpRewritePattern<vector::TransposeOp>(context, benefit), |
310 | vectorTransposeLowering(vectorTransposeLowering) {} |
311 | |
312 | LogicalResult matchAndRewrite(vector::TransposeOp op, |
313 | PatternRewriter &rewriter) const override { |
314 | auto loc = op.getLoc(); |
315 | |
316 | Value input = op.getVector(); |
317 | VectorType inputType = op.getSourceVectorType(); |
318 | VectorType resType = op.getResultVectorType(); |
319 | |
320 | if (inputType.isScalable()) |
321 | return rewriter.notifyMatchFailure( |
322 | op, "This lowering does not support scalable vectors"); |
323 | |
324 | // Set up convenience transposition table. |
325 | ArrayRef<int64_t> transp = op.getPermutation(); |
326 | |
327 | if (isShuffleLike(vectorTransposeLowering) && |
328 | succeeded(isTranspose2DSlice(op))) |
329 | return rewriter.notifyMatchFailure( |
330 | op, "Options specifies lowering to shuffle"); |
331 | |
332 | // Handle a true 2-D matrix transpose differently when requested. |
333 | if (vectorTransposeLowering == vector::VectorTransposeLowering::Flat && |
334 | resType.getRank() == 2 && transp[0] == 1 && transp[1] == 0) { |
335 | Type flattenedType = |
336 | VectorType::get(resType.getNumElements(), resType.getElementType()); |
337 | auto matrix = |
338 | rewriter.create<vector::ShapeCastOp>(loc, flattenedType, input); |
339 | auto rows = rewriter.getI32IntegerAttr(value: resType.getShape()[0]); |
340 | auto columns = rewriter.getI32IntegerAttr(value: resType.getShape()[1]); |
341 | Value trans = rewriter.create<vector::FlatTransposeOp>( |
342 | loc, flattenedType, matrix, rows, columns); |
343 | rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, resType, trans); |
344 | return success(); |
345 | } |
346 | |
347 | // Generate unrolled extract/insert ops. We do not unroll the rightmost |
348 | // (i.e., highest-order) dimensions that are not transposed and leave them |
349 | // in vector form to improve performance. Therefore, we prune those |
350 | // dimensions from the shape/transpose data structures used to generate the |
351 | // extract/insert ops. |
352 | SmallVector<int64_t> prunedTransp; |
353 | pruneNonTransposedDims(transpose: transp, result&: prunedTransp); |
354 | size_t numPrunedDims = transp.size() - prunedTransp.size(); |
355 | auto prunedInShape = inputType.getShape().drop_back(numPrunedDims); |
356 | auto prunedInStrides = computeStrides(prunedInShape); |
357 | |
358 | // Generates the extract/insert operations for every scalar/vector element |
359 | // of the leftmost transposed dimensions. We traverse every transpose |
360 | // element using a linearized index that we delinearize to generate the |
361 | // appropriate indices for the extract/insert operations. |
362 | Value result = rewriter.create<ub::PoisonOp>(loc, resType); |
363 | int64_t numTransposedElements = ShapedType::getNumElements(prunedInShape); |
364 | |
365 | for (int64_t linearIdx = 0; linearIdx < numTransposedElements; |
366 | ++linearIdx) { |
367 | auto extractIdxs = delinearize(linearIdx, prunedInStrides); |
368 | SmallVector<int64_t> insertIdxs(extractIdxs); |
369 | applyPermutationToVector(inVec&: insertIdxs, permutation: prunedTransp); |
370 | Value extractOp = |
371 | rewriter.createOrFold<vector::ExtractOp>(loc, input, extractIdxs); |
372 | result = rewriter.createOrFold<vector::InsertOp>(loc, extractOp, result, |
373 | insertIdxs); |
374 | } |
375 | |
376 | rewriter.replaceOp(op, result); |
377 | return success(); |
378 | } |
379 | |
380 | private: |
381 | /// Options to control the vector patterns. |
382 | vector::VectorTransposeLowering vectorTransposeLowering; |
383 | }; |
384 | |
385 | /// Rewrites vector.transpose as vector.shape_cast. This pattern is only applied |
386 | /// to 2D vectors with at least one unit dim. For example: |
387 | /// |
388 | /// Replace: |
389 | /// vector.transpose %0, [1, 0] : vector<4x1xi32>> to |
390 | /// vector<1x4xi32> |
391 | /// with: |
392 | /// vector.shape_cast %0 : vector<4x1xi32> to vector<1x4xi32> |
393 | /// |
394 | /// Source with leading unit dim (inverse) is also replaced. Unit dim must |
395 | /// be fixed. Non-unit dim can be scalable. |
396 | /// |
397 | /// TODO: This pattern was introduced specifically to help lower scalable |
398 | /// vectors. In hindsight, a more specialised canonicalization (for shape_cast's |
399 | /// to cancel out) would be preferable: |
400 | /// |
401 | /// BEFORE: |
402 | /// %0 = some_op |
403 | /// %1 = vector.shape_cast %0 : vector<[4]xf32> to vector<[4]x1xf32> |
404 | /// %2 = vector.transpose %1 [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32> |
405 | /// AFTER: |
406 | /// %0 = some_op |
407 | /// %1 = vector.shape_cast %0 : vector<[4]xf32> to vector<1x[4]xf32> |
408 | /// |
409 | /// Given the context above, we may want to consider (re-)moving this pattern |
410 | /// at some later time. I am leaving it for now in case there are other users |
411 | /// that I am not aware of. |
412 | class Transpose2DWithUnitDimToShapeCast |
413 | : public OpRewritePattern<vector::TransposeOp> { |
414 | public: |
415 | using OpRewritePattern::OpRewritePattern; |
416 | |
417 | Transpose2DWithUnitDimToShapeCast(MLIRContext *context, |
418 | PatternBenefit benefit = 1) |
419 | : OpRewritePattern<vector::TransposeOp>(context, benefit) {} |
420 | |
421 | LogicalResult matchAndRewrite(vector::TransposeOp op, |
422 | PatternRewriter &rewriter) const override { |
423 | Value input = op.getVector(); |
424 | VectorType resType = op.getResultVectorType(); |
425 | |
426 | // Set up convenience transposition table. |
427 | ArrayRef<int64_t> transp = op.getPermutation(); |
428 | |
429 | if (resType.getRank() == 2 && |
430 | ((resType.getShape().front() == 1 && |
431 | !resType.getScalableDims().front()) || |
432 | (resType.getShape().back() == 1 && |
433 | !resType.getScalableDims().back())) && |
434 | transp == ArrayRef<int64_t>({1, 0})) { |
435 | rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, resType, input); |
436 | return success(); |
437 | } |
438 | |
439 | return failure(); |
440 | } |
441 | }; |
442 | |
443 | /// Rewrite a 2-D vector.transpose as a sequence of shuffle ops. |
444 | /// If the strategy is Shuffle1D, it will be lowered to: |
445 | /// vector.shape_cast 2D -> 1D |
446 | /// vector.shuffle |
447 | /// vector.shape_cast 1D -> 2D |
448 | /// If the strategy is Shuffle16x16, it will be lowered to a sequence of shuffle |
449 | /// ops on 16xf32 vectors. |
450 | class TransposeOp2DToShuffleLowering |
451 | : public OpRewritePattern<vector::TransposeOp> { |
452 | public: |
453 | using OpRewritePattern::OpRewritePattern; |
454 | |
455 | TransposeOp2DToShuffleLowering( |
456 | vector::VectorTransposeLowering vectorTransposeLowering, |
457 | MLIRContext *context, PatternBenefit benefit = 1) |
458 | : OpRewritePattern<vector::TransposeOp>(context, benefit), |
459 | vectorTransposeLowering(vectorTransposeLowering) {} |
460 | |
461 | LogicalResult matchAndRewrite(vector::TransposeOp op, |
462 | PatternRewriter &rewriter) const override { |
463 | if (!isShuffleLike(vectorTransposeLowering)) |
464 | return rewriter.notifyMatchFailure( |
465 | op, "not using vector shuffle based lowering"); |
466 | |
467 | if (op.getSourceVectorType().isScalable()) |
468 | return rewriter.notifyMatchFailure( |
469 | op, "vector shuffle lowering not supported for scalable vectors"); |
470 | |
471 | auto srcGtOneDims = isTranspose2DSlice(op); |
472 | if (failed(srcGtOneDims)) |
473 | return rewriter.notifyMatchFailure( |
474 | op, "expected transposition on a 2D slice"); |
475 | |
476 | VectorType srcType = op.getSourceVectorType(); |
477 | int64_t m = srcType.getDimSize(std::get<0>(srcGtOneDims.value())); |
478 | int64_t n = srcType.getDimSize(std::get<1>(srcGtOneDims.value())); |
479 | |
480 | // Reshape the n-D input vector with only two dimensions greater than one |
481 | // to a 2-D vector. |
482 | Location loc = op.getLoc(); |
483 | auto flattenedType = VectorType::get({n * m}, srcType.getElementType()); |
484 | auto reshInputType = VectorType::get({m, n}, srcType.getElementType()); |
485 | auto reshInput = rewriter.create<vector::ShapeCastOp>(loc, flattenedType, |
486 | op.getVector()); |
487 | |
488 | Value res; |
489 | if (vectorTransposeLowering == VectorTransposeLowering::Shuffle16x16 && |
490 | m == 16 && n == 16) { |
491 | reshInput = |
492 | rewriter.create<vector::ShapeCastOp>(loc, reshInputType, reshInput); |
493 | res = transposeToShuffle16x16(rewriter, reshInput, m, n); |
494 | } else { |
495 | // Fallback to shuffle on 1D approach. |
496 | res = transposeToShuffle1D(rewriter, reshInput, m, n); |
497 | } |
498 | |
499 | rewriter.replaceOpWithNewOp<vector::ShapeCastOp>( |
500 | op, op.getResultVectorType(), res); |
501 | |
502 | return success(); |
503 | } |
504 | |
505 | private: |
506 | /// Options to control the vector patterns. |
507 | vector::VectorTransposeLowering vectorTransposeLowering; |
508 | }; |
509 | } // namespace |
510 | |
511 | void mlir::vector::populateVectorTransposeLoweringPatterns( |
512 | RewritePatternSet &patterns, |
513 | VectorTransposeLowering vectorTransposeLowering, PatternBenefit benefit) { |
514 | patterns.add<Transpose2DWithUnitDimToShapeCast>(arg: patterns.getContext(), |
515 | args&: benefit); |
516 | patterns.add<TransposeOpLowering, TransposeOp2DToShuffleLowering>( |
517 | vectorTransposeLowering, patterns.getContext(), benefit); |
518 | } |
519 |