| 1 | //===- LowerVectorTranspose.cpp - Lower 'vector.transpose' operation ------===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | // |
| 9 | // This file implements target-independent rewrites and utilities to lower the |
| 10 | // 'vector.transpose' operation. |
| 11 | // |
| 12 | //===----------------------------------------------------------------------===// |
| 13 | |
| 14 | #include "mlir/Dialect/MemRef/IR/MemRef.h" |
| 15 | #include "mlir/Dialect/UB/IR/UBOps.h" |
| 16 | #include "mlir/Dialect/Utils/IndexingUtils.h" |
| 17 | #include "mlir/Dialect/Utils/StructuredOpsUtils.h" |
| 18 | #include "mlir/Dialect/Vector/IR/VectorOps.h" |
| 19 | #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h" |
| 20 | #include "mlir/Dialect/Vector/Utils/VectorUtils.h" |
| 21 | #include "mlir/IR/BuiltinTypes.h" |
| 22 | #include "mlir/IR/ImplicitLocOpBuilder.h" |
| 23 | #include "mlir/IR/Location.h" |
| 24 | #include "mlir/IR/PatternMatch.h" |
| 25 | #include "mlir/IR/TypeUtilities.h" |
| 26 | |
| 27 | #define DEBUG_TYPE "lower-vector-transpose" |
| 28 | |
| 29 | using namespace mlir; |
| 30 | using namespace mlir::vector; |
| 31 | |
| 32 | /// Given a 'transpose' pattern, prune the rightmost dimensions that are not |
| 33 | /// transposed. |
| 34 | static void pruneNonTransposedDims(ArrayRef<int64_t> transpose, |
| 35 | SmallVectorImpl<int64_t> &result) { |
| 36 | size_t numTransposedDims = transpose.size(); |
| 37 | for (size_t transpDim : llvm::reverse(C&: transpose)) { |
| 38 | if (transpDim != numTransposedDims - 1) |
| 39 | break; |
| 40 | numTransposedDims--; |
| 41 | } |
| 42 | |
| 43 | result.append(in_start: transpose.begin(), in_end: transpose.begin() + numTransposedDims); |
| 44 | } |
| 45 | |
| 46 | /// Returns true if the lowering option is a vector shuffle based approach. |
| 47 | static bool isShuffleLike(VectorTransposeLowering lowering) { |
| 48 | return lowering == VectorTransposeLowering::Shuffle1D || |
| 49 | lowering == VectorTransposeLowering::Shuffle16x16; |
| 50 | } |
| 51 | |
| 52 | /// Returns a shuffle mask that builds on `vals`. `vals` is the offset base of |
| 53 | /// shuffle ops, i.e., the unpack pattern. The method iterates with `vals` to |
| 54 | /// create the mask for `numBits` bits vector. The `numBits` have to be a |
| 55 | /// multiple of 128. For example, if `vals` is {0, 1, 16, 17} and `numBits` is |
| 56 | /// 512, there should be 16 elements in the final result. It constructs the |
| 57 | /// below mask to get the unpack elements. |
| 58 | /// [0, 1, 16, 17, |
| 59 | /// 0+4, 1+4, 16+4, 17+4, |
| 60 | /// 0+8, 1+8, 16+8, 17+8, |
| 61 | /// 0+12, 1+12, 16+12, 17+12] |
| 62 | static SmallVector<int64_t> |
| 63 | getUnpackShufflePermFor128Lane(ArrayRef<int64_t> vals, int numBits) { |
| 64 | assert(numBits % 128 == 0 && "expected numBits is a multiple of 128" ); |
| 65 | int numElem = numBits / 32; |
| 66 | SmallVector<int64_t> res; |
| 67 | for (int i = 0; i < numElem; i += 4) |
| 68 | for (int64_t v : vals) |
| 69 | res.push_back(Elt: v + i); |
| 70 | return res; |
| 71 | } |
| 72 | |
| 73 | /// Lower to vector.shuffle on v1 and v2 with UnpackLoPd shuffle mask. For |
| 74 | /// example, if it is targeting 512 bit vector, returns |
| 75 | /// vector.shuffle on v1, v2, [0, 1, 16, 17, |
| 76 | /// 0+4, 1+4, 16+4, 17+4, |
| 77 | /// 0+8, 1+8, 16+8, 17+8, |
| 78 | /// 0+12, 1+12, 16+12, 17+12]. |
| 79 | static Value createUnpackLoPd(ImplicitLocOpBuilder &b, Value v1, Value v2, |
| 80 | int numBits) { |
| 81 | int numElem = numBits / 32; |
| 82 | return b.create<vector::ShuffleOp>( |
| 83 | args&: v1, args&: v2, |
| 84 | args: getUnpackShufflePermFor128Lane(vals: {0, 1, numElem, numElem + 1}, numBits)); |
| 85 | } |
| 86 | |
| 87 | /// Lower to vector.shuffle on v1 and v2 with UnpackHiPd shuffle mask. For |
| 88 | /// example, if it is targeting 512 bit vector, returns |
| 89 | /// vector.shuffle, v1, v2, [2, 3, 18, 19, |
| 90 | /// 2+4, 3+4, 18+4, 19+4, |
| 91 | /// 2+8, 3+8, 18+8, 19+8, |
| 92 | /// 2+12, 3+12, 18+12, 19+12]. |
| 93 | static Value createUnpackHiPd(ImplicitLocOpBuilder &b, Value v1, Value v2, |
| 94 | int numBits) { |
| 95 | int numElem = numBits / 32; |
| 96 | return b.create<vector::ShuffleOp>( |
| 97 | args&: v1, args&: v2, |
| 98 | args: getUnpackShufflePermFor128Lane(vals: {2, 3, numElem + 2, numElem + 3}, |
| 99 | numBits)); |
| 100 | } |
| 101 | |
| 102 | /// Lower to vector.shuffle on v1 and v2 with UnpackLoPs shuffle mask. For |
| 103 | /// example, if it is targeting 512 bit vector, returns |
| 104 | /// vector.shuffle, v1, v2, [0, 16, 1, 17, |
| 105 | /// 0+4, 16+4, 1+4, 17+4, |
| 106 | /// 0+8, 16+8, 1+8, 17+8, |
| 107 | /// 0+12, 16+12, 1+12, 17+12]. |
| 108 | static Value createUnpackLoPs(ImplicitLocOpBuilder &b, Value v1, Value v2, |
| 109 | int numBits) { |
| 110 | int numElem = numBits / 32; |
| 111 | auto shuffle = b.create<vector::ShuffleOp>( |
| 112 | args&: v1, args&: v2, |
| 113 | args: getUnpackShufflePermFor128Lane(vals: {0, numElem, 1, numElem + 1}, numBits)); |
| 114 | return shuffle; |
| 115 | } |
| 116 | |
| 117 | /// Lower to vector.shuffle on v1 and v2 with UnpackHiPs shuffle mask. For |
| 118 | /// example, if it is targeting 512 bit vector, returns |
| 119 | /// vector.shuffle, v1, v2, [2, 18, 3, 19, |
| 120 | /// 2+4, 18+4, 3+4, 19+4, |
| 121 | /// 2+8, 18+8, 3+8, 19+8, |
| 122 | /// 2+12, 18+12, 3+12, 19+12]. |
| 123 | static Value createUnpackHiPs(ImplicitLocOpBuilder &b, Value v1, Value v2, |
| 124 | int numBits) { |
| 125 | int numElem = numBits / 32; |
| 126 | return b.create<vector::ShuffleOp>( |
| 127 | args&: v1, args&: v2, |
| 128 | args: getUnpackShufflePermFor128Lane(vals: {2, numElem + 2, 3, numElem + 3}, |
| 129 | numBits)); |
| 130 | } |
| 131 | |
| 132 | /// Returns a vector.shuffle that shuffles 128-bit lanes (composed of 4 32-bit |
| 133 | /// elements) selected by `mask` from `v1` and `v2`. I.e., |
| 134 | /// |
| 135 | /// DEFINE SELECT4(src, control) { |
| 136 | /// CASE(control[1:0]) OF |
| 137 | /// 0: tmp[127:0] := src[127:0] |
| 138 | /// 1: tmp[127:0] := src[255:128] |
| 139 | /// 2: tmp[127:0] := src[383:256] |
| 140 | /// 3: tmp[127:0] := src[511:384] |
| 141 | /// ESAC |
| 142 | /// RETURN tmp[127:0] |
| 143 | /// } |
| 144 | /// dst[127:0] := SELECT4(v1[511:0], mask[1:0]) |
| 145 | /// dst[255:128] := SELECT4(v1[511:0], mask[3:2]) |
| 146 | /// dst[383:256] := SELECT4(v2[511:0], mask[5:4]) |
| 147 | /// dst[511:384] := SELECT4(v2[511:0], mask[7:6]) |
| 148 | static Value create4x128BitSuffle(ImplicitLocOpBuilder &b, Value v1, Value v2, |
| 149 | uint8_t mask) { |
| 150 | assert(cast<VectorType>(v1.getType()).getShape()[0] == 16 && |
| 151 | "expected a vector with length=16" ); |
| 152 | SmallVector<int64_t> shuffleMask; |
| 153 | auto appendToMask = [&](int64_t base, uint8_t control) { |
| 154 | switch (control) { |
| 155 | case 0: |
| 156 | llvm::append_range(C&: shuffleMask, R: ArrayRef<int64_t>{base + 0, base + 1, |
| 157 | base + 2, base + 3}); |
| 158 | break; |
| 159 | case 1: |
| 160 | llvm::append_range(C&: shuffleMask, R: ArrayRef<int64_t>{base + 4, base + 5, |
| 161 | base + 6, base + 7}); |
| 162 | break; |
| 163 | case 2: |
| 164 | llvm::append_range(C&: shuffleMask, R: ArrayRef<int64_t>{base + 8, base + 9, |
| 165 | base + 10, base + 11}); |
| 166 | break; |
| 167 | case 3: |
| 168 | llvm::append_range(C&: shuffleMask, R: ArrayRef<int64_t>{base + 12, base + 13, |
| 169 | base + 14, base + 15}); |
| 170 | break; |
| 171 | default: |
| 172 | llvm_unreachable("control > 3 : overflow" ); |
| 173 | } |
| 174 | }; |
| 175 | uint8_t b01 = mask & 0x3; |
| 176 | uint8_t b23 = (mask >> 2) & 0x3; |
| 177 | uint8_t b45 = (mask >> 4) & 0x3; |
| 178 | uint8_t b67 = (mask >> 6) & 0x3; |
| 179 | appendToMask(0, b01); |
| 180 | appendToMask(0, b23); |
| 181 | appendToMask(16, b45); |
| 182 | appendToMask(16, b67); |
| 183 | return b.create<vector::ShuffleOp>(args&: v1, args&: v2, args&: shuffleMask); |
| 184 | } |
| 185 | |
| 186 | /// Lowers the value to a vector.shuffle op. The `source` is expected to be a |
| 187 | /// 1-D vector and have `m`x`n` elements. |
| 188 | static Value transposeToShuffle1D(OpBuilder &b, Value source, int m, int n) { |
| 189 | SmallVector<int64_t> mask; |
| 190 | mask.reserve(N: m * n); |
| 191 | for (int64_t j = 0; j < n; ++j) |
| 192 | for (int64_t i = 0; i < m; ++i) |
| 193 | mask.push_back(Elt: i * n + j); |
| 194 | return b.create<vector::ShuffleOp>(location: source.getLoc(), args&: source, args&: source, args&: mask); |
| 195 | } |
| 196 | |
| 197 | /// Lowers the value to a sequence of vector.shuffle ops. The `source` is |
| 198 | /// expected to be a 16x16 vector. |
/// Lowers the value to a sequence of vector.shuffle ops. The `source` is
/// expected to be a 16x16 vector. Mirrors the classic AVX512 16x16 32-bit
/// transpose: unpack 32-bit lanes, then 64-bit lanes, then permute 128-bit
/// lanes twice. `m` and `n` are both expected to be 16 here.
static Value transposeToShuffle16x16(OpBuilder &builder, Value source, int m,
                                     int n) {
  ImplicitLocOpBuilder b(source.getLoc(), builder);
  // Extract the 16 rows of the 16x16 source as 1-D vectors.
  SmallVector<Value> vs;
  for (int64_t i = 0; i < m; ++i)
    vs.push_back(b.createOrFold<vector::ExtractOp>(source, i));

  // Interleave 32-bit lanes using
  //   8x _mm512_unpacklo_epi32
  //   8x _mm512_unpackhi_epi32
  Value t0 = createUnpackLoPs(b, vs[0x0], vs[0x1], 512);
  Value t1 = createUnpackHiPs(b, vs[0x0], vs[0x1], 512);
  Value t2 = createUnpackLoPs(b, vs[0x2], vs[0x3], 512);
  Value t3 = createUnpackHiPs(b, vs[0x2], vs[0x3], 512);
  Value t4 = createUnpackLoPs(b, vs[0x4], vs[0x5], 512);
  Value t5 = createUnpackHiPs(b, vs[0x4], vs[0x5], 512);
  Value t6 = createUnpackLoPs(b, vs[0x6], vs[0x7], 512);
  Value t7 = createUnpackHiPs(b, vs[0x6], vs[0x7], 512);
  Value t8 = createUnpackLoPs(b, vs[0x8], vs[0x9], 512);
  Value t9 = createUnpackHiPs(b, vs[0x8], vs[0x9], 512);
  Value ta = createUnpackLoPs(b, vs[0xa], vs[0xb], 512);
  Value tb = createUnpackHiPs(b, vs[0xa], vs[0xb], 512);
  Value tc = createUnpackLoPs(b, vs[0xc], vs[0xd], 512);
  Value td = createUnpackHiPs(b, vs[0xc], vs[0xd], 512);
  Value te = createUnpackLoPs(b, vs[0xe], vs[0xf], 512);
  Value tf = createUnpackHiPs(b, vs[0xe], vs[0xf], 512);

  // Interleave 64-bit lanes using
  //   8x _mm512_unpacklo_epi64
  //   8x _mm512_unpackhi_epi64
  Value r0 = createUnpackLoPd(b, t0, t2, 512);
  Value r1 = createUnpackHiPd(b, t0, t2, 512);
  Value r2 = createUnpackLoPd(b, t1, t3, 512);
  Value r3 = createUnpackHiPd(b, t1, t3, 512);
  Value r4 = createUnpackLoPd(b, t4, t6, 512);
  Value r5 = createUnpackHiPd(b, t4, t6, 512);
  Value r6 = createUnpackLoPd(b, t5, t7, 512);
  Value r7 = createUnpackHiPd(b, t5, t7, 512);
  Value r8 = createUnpackLoPd(b, t8, ta, 512);
  Value r9 = createUnpackHiPd(b, t8, ta, 512);
  Value ra = createUnpackLoPd(b, t9, tb, 512);
  Value rb = createUnpackHiPd(b, t9, tb, 512);
  Value rc = createUnpackLoPd(b, tc, te, 512);
  Value rd = createUnpackHiPd(b, tc, te, 512);
  Value re = createUnpackLoPd(b, td, tf, 512);
  Value rf = createUnpackHiPd(b, td, tf, 512);

  // Permute 128-bit lanes using
  //   16x _mm512_shuffle_i32x4
  // Mask 0x88 picks lanes {0,2} of each operand; 0xdd picks lanes {1,3}.
  t0 = create4x128BitSuffle(b, r0, r4, 0x88);
  t1 = create4x128BitSuffle(b, r1, r5, 0x88);
  t2 = create4x128BitSuffle(b, r2, r6, 0x88);
  t3 = create4x128BitSuffle(b, r3, r7, 0x88);
  t4 = create4x128BitSuffle(b, r0, r4, 0xdd);
  t5 = create4x128BitSuffle(b, r1, r5, 0xdd);
  t6 = create4x128BitSuffle(b, r2, r6, 0xdd);
  t7 = create4x128BitSuffle(b, r3, r7, 0xdd);
  t8 = create4x128BitSuffle(b, r8, rc, 0x88);
  t9 = create4x128BitSuffle(b, r9, rd, 0x88);
  ta = create4x128BitSuffle(b, ra, re, 0x88);
  tb = create4x128BitSuffle(b, rb, rf, 0x88);
  tc = create4x128BitSuffle(b, r8, rc, 0xdd);
  td = create4x128BitSuffle(b, r9, rd, 0xdd);
  te = create4x128BitSuffle(b, ra, re, 0xdd);
  tf = create4x128BitSuffle(b, rb, rf, 0xdd);

  // Permute 256-bit lanes using again
  //   16x _mm512_shuffle_i32x4
  vs[0x0] = create4x128BitSuffle(b, t0, t8, 0x88);
  vs[0x1] = create4x128BitSuffle(b, t1, t9, 0x88);
  vs[0x2] = create4x128BitSuffle(b, t2, ta, 0x88);
  vs[0x3] = create4x128BitSuffle(b, t3, tb, 0x88);
  vs[0x4] = create4x128BitSuffle(b, t4, tc, 0x88);
  vs[0x5] = create4x128BitSuffle(b, t5, td, 0x88);
  vs[0x6] = create4x128BitSuffle(b, t6, te, 0x88);
  vs[0x7] = create4x128BitSuffle(b, t7, tf, 0x88);
  vs[0x8] = create4x128BitSuffle(b, t0, t8, 0xdd);
  vs[0x9] = create4x128BitSuffle(b, t1, t9, 0xdd);
  vs[0xa] = create4x128BitSuffle(b, t2, ta, 0xdd);
  vs[0xb] = create4x128BitSuffle(b, t3, tb, 0xdd);
  vs[0xc] = create4x128BitSuffle(b, t4, tc, 0xdd);
  vs[0xd] = create4x128BitSuffle(b, t5, td, 0xdd);
  vs[0xe] = create4x128BitSuffle(b, t6, te, 0xdd);
  vs[0xf] = create4x128BitSuffle(b, t7, tf, 0xdd);

  // Reassemble the transposed rows into a 2-D (m x n) vector, starting from
  // a poison value since every element is overwritten.
  auto reshInputType = VectorType::get(
      {m, n}, cast<VectorType>(source.getType()).getElementType());
  Value res = b.create<ub::PoisonOp>(reshInputType);
  for (int64_t i = 0; i < m; ++i)
    res = b.create<vector::InsertOp>(vs[i], res, i);
  return res;
}
| 291 | |
| 292 | namespace { |
| 293 | /// Progressive lowering of TransposeOp. |
| 294 | /// One: |
| 295 | /// %x = vector.transpose %y, [1, 0] |
| 296 | /// is replaced by: |
| 297 | /// %z = arith.constant dense<0.000000e+00> |
| 298 | /// %0 = vector.extract %y[0, 0] |
| 299 | /// %1 = vector.insert %0, %z [0, 0] |
| 300 | /// .. |
| 301 | /// %x = vector.insert .., .. [.., ..] |
| 302 | class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> { |
| 303 | public: |
| 304 | using OpRewritePattern::OpRewritePattern; |
| 305 | |
| 306 | TransposeOpLowering(vector::VectorTransposeLowering vectorTransposeLowering, |
| 307 | MLIRContext *context, PatternBenefit benefit = 1) |
| 308 | : OpRewritePattern<vector::TransposeOp>(context, benefit), |
| 309 | vectorTransposeLowering(vectorTransposeLowering) {} |
| 310 | |
| 311 | LogicalResult matchAndRewrite(vector::TransposeOp op, |
| 312 | PatternRewriter &rewriter) const override { |
| 313 | auto loc = op.getLoc(); |
| 314 | |
| 315 | Value input = op.getVector(); |
| 316 | VectorType inputType = op.getSourceVectorType(); |
| 317 | VectorType resType = op.getResultVectorType(); |
| 318 | |
| 319 | if (inputType.isScalable()) |
| 320 | return rewriter.notifyMatchFailure( |
| 321 | arg&: op, msg: "This lowering does not support scalable vectors" ); |
| 322 | |
| 323 | // Set up convenience transposition table. |
| 324 | ArrayRef<int64_t> transp = op.getPermutation(); |
| 325 | |
| 326 | if (isShuffleLike(lowering: vectorTransposeLowering) && |
| 327 | succeeded(Result: isTranspose2DSlice(op))) |
| 328 | return rewriter.notifyMatchFailure( |
| 329 | arg&: op, msg: "Options specifies lowering to shuffle" ); |
| 330 | |
| 331 | // Handle a true 2-D matrix transpose differently when requested. |
| 332 | if (vectorTransposeLowering == vector::VectorTransposeLowering::Flat && |
| 333 | resType.getRank() == 2 && transp[0] == 1 && transp[1] == 0) { |
| 334 | Type flattenedType = |
| 335 | VectorType::get(shape: resType.getNumElements(), elementType: resType.getElementType()); |
| 336 | auto matrix = |
| 337 | rewriter.create<vector::ShapeCastOp>(location: loc, args&: flattenedType, args&: input); |
| 338 | auto rows = rewriter.getI32IntegerAttr(value: resType.getShape()[0]); |
| 339 | auto columns = rewriter.getI32IntegerAttr(value: resType.getShape()[1]); |
| 340 | Value trans = rewriter.create<vector::FlatTransposeOp>( |
| 341 | location: loc, args&: flattenedType, args&: matrix, args&: rows, args&: columns); |
| 342 | rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, args&: resType, args&: trans); |
| 343 | return success(); |
| 344 | } |
| 345 | |
| 346 | // Generate unrolled extract/insert ops. We do not unroll the rightmost |
| 347 | // (i.e., highest-order) dimensions that are not transposed and leave them |
| 348 | // in vector form to improve performance. Therefore, we prune those |
| 349 | // dimensions from the shape/transpose data structures used to generate the |
| 350 | // extract/insert ops. |
| 351 | SmallVector<int64_t> prunedTransp; |
| 352 | pruneNonTransposedDims(transpose: transp, result&: prunedTransp); |
| 353 | size_t numPrunedDims = transp.size() - prunedTransp.size(); |
| 354 | auto prunedInShape = inputType.getShape().drop_back(N: numPrunedDims); |
| 355 | auto prunedInStrides = computeStrides(sizes: prunedInShape); |
| 356 | |
| 357 | // Generates the extract/insert operations for every scalar/vector element |
| 358 | // of the leftmost transposed dimensions. We traverse every transpose |
| 359 | // element using a linearized index that we delinearize to generate the |
| 360 | // appropriate indices for the extract/insert operations. |
| 361 | Value result = rewriter.create<ub::PoisonOp>(location: loc, args&: resType); |
| 362 | int64_t numTransposedElements = ShapedType::getNumElements(shape: prunedInShape); |
| 363 | |
| 364 | for (int64_t linearIdx = 0; linearIdx < numTransposedElements; |
| 365 | ++linearIdx) { |
| 366 | auto = delinearize(linearIndex: linearIdx, strides: prunedInStrides); |
| 367 | SmallVector<int64_t> insertIdxs(extractIdxs); |
| 368 | applyPermutationToVector(inVec&: insertIdxs, permutation: prunedTransp); |
| 369 | Value = |
| 370 | rewriter.createOrFold<vector::ExtractOp>(location: loc, args&: input, args&: extractIdxs); |
| 371 | result = rewriter.createOrFold<vector::InsertOp>(location: loc, args&: extractOp, args&: result, |
| 372 | args&: insertIdxs); |
| 373 | } |
| 374 | |
| 375 | rewriter.replaceOp(op, newValues: result); |
| 376 | return success(); |
| 377 | } |
| 378 | |
| 379 | private: |
| 380 | /// Options to control the vector patterns. |
| 381 | vector::VectorTransposeLowering vectorTransposeLowering; |
| 382 | }; |
| 383 | |
| 384 | /// Rewrites vector.transpose as vector.shape_cast. This pattern is only applied |
| 385 | /// to 2D vectors with at least one unit dim. For example: |
| 386 | /// |
| 387 | /// Replace: |
| 388 | /// vector.transpose %0, [1, 0] : vector<4x1xi32>> to |
| 389 | /// vector<1x4xi32> |
| 390 | /// with: |
| 391 | /// vector.shape_cast %0 : vector<4x1xi32> to vector<1x4xi32> |
| 392 | /// |
| 393 | /// Source with leading unit dim (inverse) is also replaced. Unit dim must |
| 394 | /// be fixed. Non-unit dim can be scalable. |
| 395 | /// |
| 396 | /// TODO: This pattern was introduced specifically to help lower scalable |
| 397 | /// vectors. In hindsight, a more specialised canonicalization (for shape_cast's |
| 398 | /// to cancel out) would be preferable: |
| 399 | /// |
| 400 | /// BEFORE: |
| 401 | /// %0 = some_op |
| 402 | /// %1 = vector.shape_cast %0 : vector<[4]xf32> to vector<[4]x1xf32> |
| 403 | /// %2 = vector.transpose %1 [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32> |
| 404 | /// AFTER: |
| 405 | /// %0 = some_op |
| 406 | /// %1 = vector.shape_cast %0 : vector<[4]xf32> to vector<1x[4]xf32> |
| 407 | /// |
| 408 | /// Given the context above, we may want to consider (re-)moving this pattern |
| 409 | /// at some later time. I am leaving it for now in case there are other users |
| 410 | /// that I am not aware of. |
| 411 | class Transpose2DWithUnitDimToShapeCast |
| 412 | : public OpRewritePattern<vector::TransposeOp> { |
| 413 | public: |
| 414 | using OpRewritePattern::OpRewritePattern; |
| 415 | |
| 416 | Transpose2DWithUnitDimToShapeCast(MLIRContext *context, |
| 417 | PatternBenefit benefit = 1) |
| 418 | : OpRewritePattern<vector::TransposeOp>(context, benefit) {} |
| 419 | |
| 420 | LogicalResult matchAndRewrite(vector::TransposeOp op, |
| 421 | PatternRewriter &rewriter) const override { |
| 422 | Value input = op.getVector(); |
| 423 | VectorType resType = op.getResultVectorType(); |
| 424 | |
| 425 | // Set up convenience transposition table. |
| 426 | ArrayRef<int64_t> transp = op.getPermutation(); |
| 427 | |
| 428 | if (resType.getRank() == 2 && |
| 429 | ((resType.getShape().front() == 1 && |
| 430 | !resType.getScalableDims().front()) || |
| 431 | (resType.getShape().back() == 1 && |
| 432 | !resType.getScalableDims().back())) && |
| 433 | transp == ArrayRef<int64_t>({1, 0})) { |
| 434 | rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, args&: resType, args&: input); |
| 435 | return success(); |
| 436 | } |
| 437 | |
| 438 | return failure(); |
| 439 | } |
| 440 | }; |
| 441 | |
| 442 | /// Rewrite a 2-D vector.transpose as a sequence of shuffle ops. |
| 443 | /// If the strategy is Shuffle1D, it will be lowered to: |
| 444 | /// vector.shape_cast 2D -> 1D |
| 445 | /// vector.shuffle |
| 446 | /// vector.shape_cast 1D -> 2D |
| 447 | /// If the strategy is Shuffle16x16, it will be lowered to a sequence of shuffle |
| 448 | /// ops on 16xf32 vectors. |
| 449 | class TransposeOp2DToShuffleLowering |
| 450 | : public OpRewritePattern<vector::TransposeOp> { |
| 451 | public: |
| 452 | using OpRewritePattern::OpRewritePattern; |
| 453 | |
| 454 | TransposeOp2DToShuffleLowering( |
| 455 | vector::VectorTransposeLowering vectorTransposeLowering, |
| 456 | MLIRContext *context, PatternBenefit benefit = 1) |
| 457 | : OpRewritePattern<vector::TransposeOp>(context, benefit), |
| 458 | vectorTransposeLowering(vectorTransposeLowering) {} |
| 459 | |
| 460 | LogicalResult matchAndRewrite(vector::TransposeOp op, |
| 461 | PatternRewriter &rewriter) const override { |
| 462 | if (!isShuffleLike(lowering: vectorTransposeLowering)) |
| 463 | return rewriter.notifyMatchFailure( |
| 464 | arg&: op, msg: "not using vector shuffle based lowering" ); |
| 465 | |
| 466 | if (op.getSourceVectorType().isScalable()) |
| 467 | return rewriter.notifyMatchFailure( |
| 468 | arg&: op, msg: "vector shuffle lowering not supported for scalable vectors" ); |
| 469 | |
| 470 | auto srcGtOneDims = isTranspose2DSlice(op); |
| 471 | if (failed(Result: srcGtOneDims)) |
| 472 | return rewriter.notifyMatchFailure( |
| 473 | arg&: op, msg: "expected transposition on a 2D slice" ); |
| 474 | |
| 475 | VectorType srcType = op.getSourceVectorType(); |
| 476 | int64_t m = srcType.getDimSize(idx: std::get<0>(in&: srcGtOneDims.value())); |
| 477 | int64_t n = srcType.getDimSize(idx: std::get<1>(in&: srcGtOneDims.value())); |
| 478 | |
| 479 | // Reshape the n-D input vector with only two dimensions greater than one |
| 480 | // to a 2-D vector. |
| 481 | Location loc = op.getLoc(); |
| 482 | auto flattenedType = VectorType::get(shape: {n * m}, elementType: srcType.getElementType()); |
| 483 | auto reshInputType = VectorType::get(shape: {m, n}, elementType: srcType.getElementType()); |
| 484 | auto reshInput = rewriter.create<vector::ShapeCastOp>(location: loc, args&: flattenedType, |
| 485 | args: op.getVector()); |
| 486 | |
| 487 | Value res; |
| 488 | if (vectorTransposeLowering == VectorTransposeLowering::Shuffle16x16 && |
| 489 | m == 16 && n == 16) { |
| 490 | reshInput = |
| 491 | rewriter.create<vector::ShapeCastOp>(location: loc, args&: reshInputType, args&: reshInput); |
| 492 | res = transposeToShuffle16x16(builder&: rewriter, source: reshInput, m, n); |
| 493 | } else { |
| 494 | // Fallback to shuffle on 1D approach. |
| 495 | res = transposeToShuffle1D(b&: rewriter, source: reshInput, m, n); |
| 496 | } |
| 497 | |
| 498 | rewriter.replaceOpWithNewOp<vector::ShapeCastOp>( |
| 499 | op, args: op.getResultVectorType(), args&: res); |
| 500 | |
| 501 | return success(); |
| 502 | } |
| 503 | |
| 504 | private: |
| 505 | /// Options to control the vector patterns. |
| 506 | vector::VectorTransposeLowering vectorTransposeLowering; |
| 507 | }; |
| 508 | } // namespace |
| 509 | |
| 510 | void mlir::vector::populateVectorTransposeLoweringPatterns( |
| 511 | RewritePatternSet &patterns, |
| 512 | VectorTransposeLowering vectorTransposeLowering, PatternBenefit benefit) { |
| 513 | patterns.add<Transpose2DWithUnitDimToShapeCast>(arg: patterns.getContext(), |
| 514 | args&: benefit); |
| 515 | patterns.add<TransposeOpLowering, TransposeOp2DToShuffleLowering>( |
| 516 | arg&: vectorTransposeLowering, args: patterns.getContext(), args&: benefit); |
| 517 | } |
| 518 | |