| 1 | //===- LowerVectorTranspose.cpp - Lower 'vector.transpose' operation ------===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | // |
| 9 | // This file implements target-independent rewrites and utilities to lower the |
| 10 | // 'vector.transpose' operation. |
| 11 | // |
| 12 | //===----------------------------------------------------------------------===// |
| 13 | |
| 14 | #include "mlir/Dialect/Arith/IR/Arith.h" |
| 15 | #include "mlir/Dialect/MemRef/IR/MemRef.h" |
| 16 | #include "mlir/Dialect/UB/IR/UBOps.h" |
| 17 | #include "mlir/Dialect/Utils/IndexingUtils.h" |
| 18 | #include "mlir/Dialect/Utils/StructuredOpsUtils.h" |
| 19 | #include "mlir/Dialect/Vector/IR/VectorOps.h" |
| 20 | #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h" |
| 21 | #include "mlir/Dialect/Vector/Utils/VectorUtils.h" |
| 22 | #include "mlir/IR/BuiltinTypes.h" |
| 23 | #include "mlir/IR/ImplicitLocOpBuilder.h" |
| 24 | #include "mlir/IR/Location.h" |
| 25 | #include "mlir/IR/PatternMatch.h" |
| 26 | #include "mlir/IR/TypeUtilities.h" |
| 27 | |
| 28 | #define DEBUG_TYPE "lower-vector-transpose" |
| 29 | |
| 30 | using namespace mlir; |
| 31 | using namespace mlir::vector; |
| 32 | |
| 33 | /// Given a 'transpose' pattern, prune the rightmost dimensions that are not |
| 34 | /// transposed. |
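/// E.g., the permutation [1, 0, 2, 3] is pruned to [1, 0], since dims 2 and 3
/// already map to themselves.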
| 35 | static void pruneNonTransposedDims(ArrayRef<int64_t> transpose, |
| 36 | SmallVectorImpl<int64_t> &result) { |
| 37 | size_t numTransposedDims = transpose.size(); |
for (size_t transpDim : llvm::reverse(transpose)) {
| 39 | if (transpDim != numTransposedDims - 1) |
| 40 | break; |
| 41 | numTransposedDims--; |
| 42 | } |
| 43 | |
result.append(transpose.begin(), transpose.begin() + numTransposedDims);
| 45 | } |
| 46 | |
/// Returns true if the lowering option uses a vector-shuffle-based approach.
| 48 | static bool isShuffleLike(VectorTransposeLowering lowering) { |
| 49 | return lowering == VectorTransposeLowering::Shuffle1D || |
| 50 | lowering == VectorTransposeLowering::Shuffle16x16; |
| 51 | } |
| 52 | |
/// Returns a shuffle mask that builds on `vals`. `vals` is the offset base of
/// the shuffle ops, i.e., the unpack pattern. The method replicates `vals`
/// across 128-bit lanes to create the mask for a `numBits`-bit vector, where
/// `numBits` has to be a multiple of 128. For example, if `vals` is
/// {0, 1, 16, 17} and `numBits` is 512, there are 16 elements in the final
/// result. It constructs the mask below to get the unpack elements:
| 59 | /// [0, 1, 16, 17, |
| 60 | /// 0+4, 1+4, 16+4, 17+4, |
| 61 | /// 0+8, 1+8, 16+8, 17+8, |
| 62 | /// 0+12, 1+12, 16+12, 17+12] |
| 63 | static SmallVector<int64_t> |
| 64 | getUnpackShufflePermFor128Lane(ArrayRef<int64_t> vals, int numBits) { |
assert(numBits % 128 == 0 && "expected numBits to be a multiple of 128");
| 66 | int numElem = numBits / 32; |
| 67 | SmallVector<int64_t> res; |
| 68 | for (int i = 0; i < numElem; i += 4) |
| 69 | for (int64_t v : vals) |
res.push_back(v + i);
| 71 | return res; |
| 72 | } |
| 73 | |
/// Lower to vector.shuffle on v1 and v2 with the UnpackLoPd shuffle mask. For
/// example, if targeting a 512-bit vector, returns
| 76 | /// vector.shuffle on v1, v2, [0, 1, 16, 17, |
| 77 | /// 0+4, 1+4, 16+4, 17+4, |
| 78 | /// 0+8, 1+8, 16+8, 17+8, |
| 79 | /// 0+12, 1+12, 16+12, 17+12]. |
| 80 | static Value createUnpackLoPd(ImplicitLocOpBuilder &b, Value v1, Value v2, |
| 81 | int numBits) { |
| 82 | int numElem = numBits / 32; |
| 83 | return b.create<vector::ShuffleOp>( |
| 84 | v1, v2, |
| 85 | getUnpackShufflePermFor128Lane({0, 1, numElem, numElem + 1}, numBits)); |
| 86 | } |
| 87 | |
/// Lower to vector.shuffle on v1 and v2 with the UnpackHiPd shuffle mask. For
/// example, if targeting a 512-bit vector, returns
| 90 | /// vector.shuffle, v1, v2, [2, 3, 18, 19, |
| 91 | /// 2+4, 3+4, 18+4, 19+4, |
| 92 | /// 2+8, 3+8, 18+8, 19+8, |
| 93 | /// 2+12, 3+12, 18+12, 19+12]. |
| 94 | static Value createUnpackHiPd(ImplicitLocOpBuilder &b, Value v1, Value v2, |
| 95 | int numBits) { |
| 96 | int numElem = numBits / 32; |
| 97 | return b.create<vector::ShuffleOp>( |
| 98 | v1, v2, |
| 99 | getUnpackShufflePermFor128Lane({2, 3, numElem + 2, numElem + 3}, |
| 100 | numBits)); |
| 101 | } |
| 102 | |
/// Lower to vector.shuffle on v1 and v2 with the UnpackLoPs shuffle mask. For
/// example, if targeting a 512-bit vector, returns
| 105 | /// vector.shuffle, v1, v2, [0, 16, 1, 17, |
| 106 | /// 0+4, 16+4, 1+4, 17+4, |
| 107 | /// 0+8, 16+8, 1+8, 17+8, |
| 108 | /// 0+12, 16+12, 1+12, 17+12]. |
| 109 | static Value createUnpackLoPs(ImplicitLocOpBuilder &b, Value v1, Value v2, |
| 110 | int numBits) { |
| 111 | int numElem = numBits / 32; |
return b.create<vector::ShuffleOp>(
v1, v2,
getUnpackShufflePermFor128Lane({0, numElem, 1, numElem + 1}, numBits));
| 116 | } |
| 117 | |
/// Lower to vector.shuffle on v1 and v2 with the UnpackHiPs shuffle mask. For
/// example, if targeting a 512-bit vector, returns
| 120 | /// vector.shuffle, v1, v2, [2, 18, 3, 19, |
| 121 | /// 2+4, 18+4, 3+4, 19+4, |
| 122 | /// 2+8, 18+8, 3+8, 19+8, |
| 123 | /// 2+12, 18+12, 3+12, 19+12]. |
| 124 | static Value createUnpackHiPs(ImplicitLocOpBuilder &b, Value v1, Value v2, |
| 125 | int numBits) { |
| 126 | int numElem = numBits / 32; |
| 127 | return b.create<vector::ShuffleOp>( |
| 128 | v1, v2, |
| 129 | getUnpackShufflePermFor128Lane({2, numElem + 2, 3, numElem + 3}, |
| 130 | numBits)); |
| 131 | } |
| 132 | |
| 133 | /// Returns a vector.shuffle that shuffles 128-bit lanes (composed of 4 32-bit |
| 134 | /// elements) selected by `mask` from `v1` and `v2`. I.e., |
| 135 | /// |
| 136 | /// DEFINE SELECT4(src, control) { |
| 137 | /// CASE(control[1:0]) OF |
| 138 | /// 0: tmp[127:0] := src[127:0] |
| 139 | /// 1: tmp[127:0] := src[255:128] |
| 140 | /// 2: tmp[127:0] := src[383:256] |
| 141 | /// 3: tmp[127:0] := src[511:384] |
| 142 | /// ESAC |
| 143 | /// RETURN tmp[127:0] |
| 144 | /// } |
| 145 | /// dst[127:0] := SELECT4(v1[511:0], mask[1:0]) |
| 146 | /// dst[255:128] := SELECT4(v1[511:0], mask[3:2]) |
| 147 | /// dst[383:256] := SELECT4(v2[511:0], mask[5:4]) |
| 148 | /// dst[511:384] := SELECT4(v2[511:0], mask[7:6]) |
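/// For example, `mask` 0x88 selects 128-bit lane 0 and lane 2 of each source,
/// i.e., the element shuffle mask
/// [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27].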
static Value create4x128BitShuffle(ImplicitLocOpBuilder &b, Value v1, Value v2,
uint8_t mask) {
assert(cast<VectorType>(v1.getType()).getShape()[0] == 16 &&
"expected a vector with length=16");
| 153 | SmallVector<int64_t> shuffleMask; |
| 154 | auto appendToMask = [&](int64_t base, uint8_t control) { |
| 155 | switch (control) { |
| 156 | case 0: |
llvm::append_range(shuffleMask, ArrayRef<int64_t>{base + 0, base + 1,
base + 2, base + 3});
break;
case 1:
llvm::append_range(shuffleMask, ArrayRef<int64_t>{base + 4, base + 5,
base + 6, base + 7});
break;
case 2:
llvm::append_range(shuffleMask, ArrayRef<int64_t>{base + 8, base + 9,
base + 10, base + 11});
break;
case 3:
llvm::append_range(shuffleMask, ArrayRef<int64_t>{base + 12, base + 13,
base + 14, base + 15});
break;
default:
llvm_unreachable("control > 3 : overflow");
| 174 | } |
| 175 | }; |
| 176 | uint8_t b01 = mask & 0x3; |
| 177 | uint8_t b23 = (mask >> 2) & 0x3; |
| 178 | uint8_t b45 = (mask >> 4) & 0x3; |
| 179 | uint8_t b67 = (mask >> 6) & 0x3; |
| 180 | appendToMask(0, b01); |
| 181 | appendToMask(0, b23); |
| 182 | appendToMask(16, b45); |
| 183 | appendToMask(16, b67); |
| 184 | return b.create<vector::ShuffleOp>(v1, v2, shuffleMask); |
| 185 | } |
| 186 | |
| 187 | /// Lowers the value to a vector.shuffle op. The `source` is expected to be a |
| 188 | /// 1-D vector and have `m`x`n` elements. |
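/// E.g., for m = 2 and n = 3 the emitted shuffle mask is [0, 3, 1, 4, 2, 5],
/// permuting a row-major 2x3 layout into a row-major 3x2 layout.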
| 189 | static Value transposeToShuffle1D(OpBuilder &b, Value source, int m, int n) { |
| 190 | SmallVector<int64_t> mask; |
mask.reserve(m * n);
| 192 | for (int64_t j = 0; j < n; ++j) |
| 193 | for (int64_t i = 0; i < m; ++i) |
| 194 | mask.push_back(Elt: i * n + j); |
| 195 | return b.create<vector::ShuffleOp>(source.getLoc(), source, source, mask); |
| 196 | } |
| 197 | |
| 198 | /// Lowers the value to a sequence of vector.shuffle ops. The `source` is |
| 199 | /// expected to be a 16x16 vector. |
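/// This follows the classic AVX-512 16x16 32-bit transpose recipe: two
/// interleaving rounds (32-bit, then 64-bit elements) followed by two rounds
/// of 128-bit lane permutations, 16 shuffles per round (64 vector.shuffle ops
/// in total).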
| 200 | static Value transposeToShuffle16x16(OpBuilder &builder, Value source, int m, |
| 201 | int n) { |
| 202 | ImplicitLocOpBuilder b(source.getLoc(), builder); |
| 203 | SmallVector<Value> vs; |
| 204 | for (int64_t i = 0; i < m; ++i) |
| 205 | vs.push_back(b.createOrFold<vector::ExtractOp>(source, i)); |
| 206 | |
| 207 | // Interleave 32-bit lanes using |
| 208 | // 8x _mm512_unpacklo_epi32 |
| 209 | // 8x _mm512_unpackhi_epi32 |
Value t0 = createUnpackLoPs(b, vs[0x0], vs[0x1], 512);
Value t1 = createUnpackHiPs(b, vs[0x0], vs[0x1], 512);
Value t2 = createUnpackLoPs(b, vs[0x2], vs[0x3], 512);
Value t3 = createUnpackHiPs(b, vs[0x2], vs[0x3], 512);
Value t4 = createUnpackLoPs(b, vs[0x4], vs[0x5], 512);
Value t5 = createUnpackHiPs(b, vs[0x4], vs[0x5], 512);
Value t6 = createUnpackLoPs(b, vs[0x6], vs[0x7], 512);
Value t7 = createUnpackHiPs(b, vs[0x6], vs[0x7], 512);
Value t8 = createUnpackLoPs(b, vs[0x8], vs[0x9], 512);
Value t9 = createUnpackHiPs(b, vs[0x8], vs[0x9], 512);
Value ta = createUnpackLoPs(b, vs[0xa], vs[0xb], 512);
Value tb = createUnpackHiPs(b, vs[0xa], vs[0xb], 512);
Value tc = createUnpackLoPs(b, vs[0xc], vs[0xd], 512);
Value td = createUnpackHiPs(b, vs[0xc], vs[0xd], 512);
Value te = createUnpackLoPs(b, vs[0xe], vs[0xf], 512);
Value tf = createUnpackHiPs(b, vs[0xe], vs[0xf], 512);
| 226 | |
| 227 | // Interleave 64-bit lanes using |
| 228 | // 8x _mm512_unpacklo_epi64 |
| 229 | // 8x _mm512_unpackhi_epi64 |
Value r0 = createUnpackLoPd(b, t0, t2, 512);
Value r1 = createUnpackHiPd(b, t0, t2, 512);
Value r2 = createUnpackLoPd(b, t1, t3, 512);
Value r3 = createUnpackHiPd(b, t1, t3, 512);
Value r4 = createUnpackLoPd(b, t4, t6, 512);
Value r5 = createUnpackHiPd(b, t4, t6, 512);
Value r6 = createUnpackLoPd(b, t5, t7, 512);
Value r7 = createUnpackHiPd(b, t5, t7, 512);
Value r8 = createUnpackLoPd(b, t8, ta, 512);
Value r9 = createUnpackHiPd(b, t8, ta, 512);
Value ra = createUnpackLoPd(b, t9, tb, 512);
Value rb = createUnpackHiPd(b, t9, tb, 512);
Value rc = createUnpackLoPd(b, tc, te, 512);
Value rd = createUnpackHiPd(b, tc, te, 512);
Value re = createUnpackLoPd(b, td, tf, 512);
Value rf = createUnpackHiPd(b, td, tf, 512);
| 246 | |
| 247 | // Permute 128-bit lanes using |
| 248 | // 16x _mm512_shuffle_i32x4 |
t0 = create4x128BitShuffle(b, r0, r4, 0x88);
t1 = create4x128BitShuffle(b, r1, r5, 0x88);
t2 = create4x128BitShuffle(b, r2, r6, 0x88);
t3 = create4x128BitShuffle(b, r3, r7, 0x88);
t4 = create4x128BitShuffle(b, r0, r4, 0xdd);
t5 = create4x128BitShuffle(b, r1, r5, 0xdd);
t6 = create4x128BitShuffle(b, r2, r6, 0xdd);
t7 = create4x128BitShuffle(b, r3, r7, 0xdd);
t8 = create4x128BitShuffle(b, r8, rc, 0x88);
t9 = create4x128BitShuffle(b, r9, rd, 0x88);
ta = create4x128BitShuffle(b, ra, re, 0x88);
tb = create4x128BitShuffle(b, rb, rf, 0x88);
tc = create4x128BitShuffle(b, r8, rc, 0xdd);
td = create4x128BitShuffle(b, r9, rd, 0xdd);
te = create4x128BitShuffle(b, ra, re, 0xdd);
tf = create4x128BitShuffle(b, rb, rf, 0xdd);
| 265 | |
| 266 | // Permute 256-bit lanes using again |
| 267 | // 16x _mm512_shuffle_i32x4 |
vs[0x0] = create4x128BitShuffle(b, t0, t8, 0x88);
vs[0x1] = create4x128BitShuffle(b, t1, t9, 0x88);
vs[0x2] = create4x128BitShuffle(b, t2, ta, 0x88);
vs[0x3] = create4x128BitShuffle(b, t3, tb, 0x88);
vs[0x4] = create4x128BitShuffle(b, t4, tc, 0x88);
vs[0x5] = create4x128BitShuffle(b, t5, td, 0x88);
vs[0x6] = create4x128BitShuffle(b, t6, te, 0x88);
vs[0x7] = create4x128BitShuffle(b, t7, tf, 0x88);
vs[0x8] = create4x128BitShuffle(b, t0, t8, 0xdd);
vs[0x9] = create4x128BitShuffle(b, t1, t9, 0xdd);
vs[0xa] = create4x128BitShuffle(b, t2, ta, 0xdd);
vs[0xb] = create4x128BitShuffle(b, t3, tb, 0xdd);
vs[0xc] = create4x128BitShuffle(b, t4, tc, 0xdd);
vs[0xd] = create4x128BitShuffle(b, t5, td, 0xdd);
vs[0xe] = create4x128BitShuffle(b, t6, te, 0xdd);
vs[0xf] = create4x128BitShuffle(b, t7, tf, 0xdd);
| 284 | |
| 285 | auto reshInputType = VectorType::get( |
| 286 | {m, n}, cast<VectorType>(source.getType()).getElementType()); |
| 287 | Value res = b.create<ub::PoisonOp>(reshInputType); |
| 288 | for (int64_t i = 0; i < m; ++i) |
| 289 | res = b.create<vector::InsertOp>(vs[i], res, i); |
| 290 | return res; |
| 291 | } |
| 292 | |
| 293 | namespace { |
| 294 | /// Progressive lowering of TransposeOp. |
| 295 | /// One: |
| 296 | /// %x = vector.transpose %y, [1, 0] |
| 297 | /// is replaced by: |
/// %z = ub.poison
| 299 | /// %0 = vector.extract %y[0, 0] |
| 300 | /// %1 = vector.insert %0, %z [0, 0] |
| 301 | /// .. |
| 302 | /// %x = vector.insert .., .. [.., ..] |
| 303 | class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> { |
| 304 | public: |
| 305 | using OpRewritePattern::OpRewritePattern; |
| 306 | |
| 307 | TransposeOpLowering(vector::VectorTransposeLowering vectorTransposeLowering, |
| 308 | MLIRContext *context, PatternBenefit benefit = 1) |
| 309 | : OpRewritePattern<vector::TransposeOp>(context, benefit), |
| 310 | vectorTransposeLowering(vectorTransposeLowering) {} |
| 311 | |
| 312 | LogicalResult matchAndRewrite(vector::TransposeOp op, |
| 313 | PatternRewriter &rewriter) const override { |
| 314 | auto loc = op.getLoc(); |
| 315 | |
| 316 | Value input = op.getVector(); |
| 317 | VectorType inputType = op.getSourceVectorType(); |
| 318 | VectorType resType = op.getResultVectorType(); |
| 319 | |
| 320 | if (inputType.isScalable()) |
| 321 | return rewriter.notifyMatchFailure( |
| 322 | op, "This lowering does not support scalable vectors" ); |
| 323 | |
| 324 | // Set up convenience transposition table. |
| 325 | ArrayRef<int64_t> transp = op.getPermutation(); |
| 326 | |
| 327 | if (isShuffleLike(vectorTransposeLowering) && |
| 328 | succeeded(isTranspose2DSlice(op))) |
| 329 | return rewriter.notifyMatchFailure( |
| 330 | op, "Options specifies lowering to shuffle" ); |
| 331 | |
| 332 | // Handle a true 2-D matrix transpose differently when requested. |
| 333 | if (vectorTransposeLowering == vector::VectorTransposeLowering::Flat && |
| 334 | resType.getRank() == 2 && transp[0] == 1 && transp[1] == 0) { |
| 335 | Type flattenedType = |
| 336 | VectorType::get(resType.getNumElements(), resType.getElementType()); |
| 337 | auto matrix = |
| 338 | rewriter.create<vector::ShapeCastOp>(loc, flattenedType, input); |
auto rows = rewriter.getI32IntegerAttr(resType.getShape()[0]);
auto columns = rewriter.getI32IntegerAttr(resType.getShape()[1]);
| 341 | Value trans = rewriter.create<vector::FlatTransposeOp>( |
| 342 | loc, flattenedType, matrix, rows, columns); |
| 343 | rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, resType, trans); |
| 344 | return success(); |
| 345 | } |
| 346 | |
| 347 | // Generate unrolled extract/insert ops. We do not unroll the rightmost |
// (i.e., innermost) dimensions that are not transposed and leave them
| 349 | // in vector form to improve performance. Therefore, we prune those |
| 350 | // dimensions from the shape/transpose data structures used to generate the |
| 351 | // extract/insert ops. |
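// For example, a [1, 0, 2, 3] transpose of vector<2x3x4x5xf32> only unrolls
// the leading 2x3 positions; each extract/insert then moves a whole
// vector<4x5xf32>.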
| 352 | SmallVector<int64_t> prunedTransp; |
pruneNonTransposedDims(transp, prunedTransp);
| 354 | size_t numPrunedDims = transp.size() - prunedTransp.size(); |
| 355 | auto prunedInShape = inputType.getShape().drop_back(numPrunedDims); |
| 356 | auto prunedInStrides = computeStrides(prunedInShape); |
| 357 | |
| 358 | // Generates the extract/insert operations for every scalar/vector element |
| 359 | // of the leftmost transposed dimensions. We traverse every transpose |
| 360 | // element using a linearized index that we delinearize to generate the |
| 361 | // appropriate indices for the extract/insert operations. |
| 362 | Value result = rewriter.create<ub::PoisonOp>(loc, resType); |
| 363 | int64_t numTransposedElements = ShapedType::getNumElements(prunedInShape); |
| 364 | |
| 365 | for (int64_t linearIdx = 0; linearIdx < numTransposedElements; |
| 366 | ++linearIdx) { |
auto extractIdxs = delinearize(linearIdx, prunedInStrides);
SmallVector<int64_t> insertIdxs(extractIdxs);
applyPermutationToVector(insertIdxs, prunedTransp);
Value extractOp =
rewriter.createOrFold<vector::ExtractOp>(loc, input, extractIdxs);
| 372 | result = rewriter.createOrFold<vector::InsertOp>(loc, extractOp, result, |
| 373 | insertIdxs); |
| 374 | } |
| 375 | |
| 376 | rewriter.replaceOp(op, result); |
| 377 | return success(); |
| 378 | } |
| 379 | |
| 380 | private: |
| 381 | /// Options to control the vector patterns. |
| 382 | vector::VectorTransposeLowering vectorTransposeLowering; |
| 383 | }; |
| 384 | |
| 385 | /// Rewrites vector.transpose as vector.shape_cast. This pattern is only applied |
| 386 | /// to 2D vectors with at least one unit dim. For example: |
| 387 | /// |
| 388 | /// Replace: |
/// vector.transpose %0, [1, 0] : vector<4x1xi32> to vector<1x4xi32>
| 391 | /// with: |
| 392 | /// vector.shape_cast %0 : vector<4x1xi32> to vector<1x4xi32> |
| 393 | /// |
/// The inverse, a source with a leading unit dim, is also replaced. The unit
/// dim must be fixed-size; the non-unit dim may be scalable.
| 396 | /// |
| 397 | /// TODO: This pattern was introduced specifically to help lower scalable |
/// vectors. In hindsight, a more specialised canonicalization (for shape_cast
/// ops to cancel out) would be preferable:
| 400 | /// |
| 401 | /// BEFORE: |
| 402 | /// %0 = some_op |
| 403 | /// %1 = vector.shape_cast %0 : vector<[4]xf32> to vector<[4]x1xf32> |
| 404 | /// %2 = vector.transpose %1 [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32> |
| 405 | /// AFTER: |
| 406 | /// %0 = some_op |
| 407 | /// %1 = vector.shape_cast %0 : vector<[4]xf32> to vector<1x[4]xf32> |
| 408 | /// |
| 409 | /// Given the context above, we may want to consider (re-)moving this pattern |
| 410 | /// at some later time. I am leaving it for now in case there are other users |
| 411 | /// that I am not aware of. |
| 412 | class Transpose2DWithUnitDimToShapeCast |
| 413 | : public OpRewritePattern<vector::TransposeOp> { |
| 414 | public: |
| 415 | using OpRewritePattern::OpRewritePattern; |
| 416 | |
| 417 | Transpose2DWithUnitDimToShapeCast(MLIRContext *context, |
| 418 | PatternBenefit benefit = 1) |
| 419 | : OpRewritePattern<vector::TransposeOp>(context, benefit) {} |
| 420 | |
| 421 | LogicalResult matchAndRewrite(vector::TransposeOp op, |
| 422 | PatternRewriter &rewriter) const override { |
| 423 | Value input = op.getVector(); |
| 424 | VectorType resType = op.getResultVectorType(); |
| 425 | |
| 426 | // Set up convenience transposition table. |
| 427 | ArrayRef<int64_t> transp = op.getPermutation(); |
| 428 | |
| 429 | if (resType.getRank() == 2 && |
| 430 | ((resType.getShape().front() == 1 && |
| 431 | !resType.getScalableDims().front()) || |
| 432 | (resType.getShape().back() == 1 && |
| 433 | !resType.getScalableDims().back())) && |
| 434 | transp == ArrayRef<int64_t>({1, 0})) { |
| 435 | rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, resType, input); |
| 436 | return success(); |
| 437 | } |
| 438 | |
| 439 | return failure(); |
| 440 | } |
| 441 | }; |
| 442 | |
| 443 | /// Rewrite a 2-D vector.transpose as a sequence of shuffle ops. |
| 444 | /// If the strategy is Shuffle1D, it will be lowered to: |
| 445 | /// vector.shape_cast 2D -> 1D |
| 446 | /// vector.shuffle |
| 447 | /// vector.shape_cast 1D -> 2D |
| 448 | /// If the strategy is Shuffle16x16, it will be lowered to a sequence of shuffle |
| 449 | /// ops on 16xf32 vectors. |
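/// For example, with Shuffle1D (hypothetical SSA names):
/// %0 = vector.shape_cast %arg0 : vector<2x3xf32> to vector<6xf32>
/// %1 = vector.shuffle %0, %0 [0, 3, 1, 4, 2, 5]
/// : vector<6xf32>, vector<6xf32>
/// %2 = vector.shape_cast %1 : vector<6xf32> to vector<3x2xf32>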
| 450 | class TransposeOp2DToShuffleLowering |
| 451 | : public OpRewritePattern<vector::TransposeOp> { |
| 452 | public: |
| 453 | using OpRewritePattern::OpRewritePattern; |
| 454 | |
| 455 | TransposeOp2DToShuffleLowering( |
| 456 | vector::VectorTransposeLowering vectorTransposeLowering, |
| 457 | MLIRContext *context, PatternBenefit benefit = 1) |
| 458 | : OpRewritePattern<vector::TransposeOp>(context, benefit), |
| 459 | vectorTransposeLowering(vectorTransposeLowering) {} |
| 460 | |
| 461 | LogicalResult matchAndRewrite(vector::TransposeOp op, |
| 462 | PatternRewriter &rewriter) const override { |
| 463 | if (!isShuffleLike(vectorTransposeLowering)) |
| 464 | return rewriter.notifyMatchFailure( |
| 465 | op, "not using vector shuffle based lowering" ); |
| 466 | |
| 467 | if (op.getSourceVectorType().isScalable()) |
| 468 | return rewriter.notifyMatchFailure( |
| 469 | op, "vector shuffle lowering not supported for scalable vectors" ); |
| 470 | |
| 471 | auto srcGtOneDims = isTranspose2DSlice(op); |
| 472 | if (failed(srcGtOneDims)) |
| 473 | return rewriter.notifyMatchFailure( |
| 474 | op, "expected transposition on a 2D slice" ); |
| 475 | |
| 476 | VectorType srcType = op.getSourceVectorType(); |
| 477 | int64_t m = srcType.getDimSize(std::get<0>(srcGtOneDims.value())); |
| 478 | int64_t n = srcType.getDimSize(std::get<1>(srcGtOneDims.value())); |
| 479 | |
| 480 | // Reshape the n-D input vector with only two dimensions greater than one |
| 481 | // to a 2-D vector. |
| 482 | Location loc = op.getLoc(); |
| 483 | auto flattenedType = VectorType::get({n * m}, srcType.getElementType()); |
| 484 | auto reshInputType = VectorType::get({m, n}, srcType.getElementType()); |
| 485 | auto reshInput = rewriter.create<vector::ShapeCastOp>(loc, flattenedType, |
| 486 | op.getVector()); |
| 487 | |
| 488 | Value res; |
| 489 | if (vectorTransposeLowering == VectorTransposeLowering::Shuffle16x16 && |
| 490 | m == 16 && n == 16) { |
| 491 | reshInput = |
| 492 | rewriter.create<vector::ShapeCastOp>(loc, reshInputType, reshInput); |
| 493 | res = transposeToShuffle16x16(rewriter, reshInput, m, n); |
| 494 | } else { |
| 495 | // Fallback to shuffle on 1D approach. |
| 496 | res = transposeToShuffle1D(rewriter, reshInput, m, n); |
| 497 | } |
| 498 | |
| 499 | rewriter.replaceOpWithNewOp<vector::ShapeCastOp>( |
| 500 | op, op.getResultVectorType(), res); |
| 501 | |
| 502 | return success(); |
| 503 | } |
| 504 | |
| 505 | private: |
| 506 | /// Options to control the vector patterns. |
| 507 | vector::VectorTransposeLowering vectorTransposeLowering; |
| 508 | }; |
| 509 | } // namespace |
| 510 | |
| 511 | void mlir::vector::populateVectorTransposeLoweringPatterns( |
| 512 | RewritePatternSet &patterns, |
| 513 | VectorTransposeLowering vectorTransposeLowering, PatternBenefit benefit) { |
patterns.add<Transpose2DWithUnitDimToShapeCast>(patterns.getContext(),
benefit);
| 516 | patterns.add<TransposeOpLowering, TransposeOp2DToShuffleLowering>( |
| 517 | vectorTransposeLowering, patterns.getContext(), benefit); |
| 518 | } |
| 519 | |