//===- AVXTranspose.cpp - Lower Vector transpose to AVX ------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements vector.transpose rewrites as AVX patterns for
// particular sizes of interest.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/Dialect/X86Vector/Transforms.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "llvm/Support/Format.h"
#include "llvm/Support/FormatVariadic.h"

using namespace mlir;
using namespace mlir::vector;
using namespace mlir::x86vector;
using namespace mlir::x86vector::avx2;
using namespace mlir::x86vector::avx2::inline_asm;
using namespace mlir::x86vector::avx2::intrin;

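/// Emit an 8x32-bit blend as hand-written inline assembly (Intel syntax):
///   vblendps %result, %v1, %v2, mask
/// Bit i of `mask` selects f32 lane i from `v2` when set and from `v1` when
/// clear. E.g. mask 0xCC (0b11001100) yields
///   {v1[0], v1[1], v2[2], v2[3], v1[4], v1[5], v2[6], v2[7]}.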
Value mlir::x86vector::avx2::inline_asm::mm256BlendPsAsm(
    ImplicitLocOpBuilder &b, Value v1, Value v2, uint8_t mask) {
  auto asmDialectAttr =
      LLVM::AsmDialectAttr::get(b.getContext(), LLVM::AsmDialect::AD_Intel);
  const auto *asmTp = "vblendps $0, $1, $2, {0}";
  const auto *asmCstr =
      "=x,x,x"; // Careful: constraint parser is very brittle: no ws!
  SmallVector<Value> asmVals{v1, v2};
  auto asmStr =
      llvm::formatv(asmTp, llvm::format_hex(mask, /*width=*/2)).str();
  auto asmOp = b.create<LLVM::InlineAsmOp>(
      v1.getType(), /*operands=*/asmVals, /*asm_string=*/asmStr,
      /*constraints=*/asmCstr, /*has_side_effects=*/false,
      /*is_align_stack=*/false, LLVM::TailCallKind::None,
      /*asm_dialect=*/asmDialectAttr,
      /*operand_attrs=*/ArrayAttr());
  return asmOp.getResult(0);
}

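/// Lower to vector.shuffle implementing _mm256_unpacklo_ps: within each
/// 128-bit half, interleave the two low f32 lanes of `v1` (a) and `v2` (b):
///   {a0, b0, a1, b1, a4, b4, a5, b5}.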
Value mlir::x86vector::avx2::intrin::mm256UnpackLoPs(ImplicitLocOpBuilder &b,
                                                     Value v1, Value v2) {
  return b.create<vector::ShuffleOp>(
      v1, v2, ArrayRef<int64_t>{0, 8, 1, 9, 4, 12, 5, 13});
}

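/// Lower to vector.shuffle implementing _mm256_unpackhi_ps: within each
/// 128-bit half, interleave the two high f32 lanes of `v1` (a) and `v2` (b):
///   {a2, b2, a3, b3, a6, b6, a7, b7}.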
Value mlir::x86vector::avx2::intrin::mm256UnpackHiPs(ImplicitLocOpBuilder &b,
                                                     Value v1, Value v2) {
  return b.create<vector::ShuffleOp>(
      v1, v2, ArrayRef<int64_t>{2, 10, 3, 11, 6, 14, 7, 15});
}
/// Lower to vector.shuffle implementing _mm256_shuffle_ps. The 8-bit `mask`
/// holds four 2-bit selectors b01, b23, b45, b67 (low to high). Within each
/// 128-bit half, the first two output lanes come from `v1` (a) and the last
/// two from `v2` (b):
///   bits 0:127   -> {a[b01],   a[b23],   b[b45],   b[b67]}
///   bits 128:255 -> {a[b01+4], a[b23+4], b[b45+4], b[b67+4]}
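/// E.g. mask 0xEE (b01=2, b23=3, b45=2, b67=3) gives
///   {a2, a3, b2, b3, a6, a7, b6, b7}.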
Value mlir::x86vector::avx2::intrin::mm256ShufflePs(ImplicitLocOpBuilder &b,
                                                    Value v1, Value v2,
                                                    uint8_t mask) {
  uint8_t b01, b23, b45, b67;
  MaskHelper::extractShuffle(mask, b01, b23, b45, b67);
  SmallVector<int64_t> shuffleMask = {
      b01, b23, b45 + 8, b67 + 8, b01 + 4, b23 + 4, b45 + 8 + 4, b67 + 8 + 4};
  return b.create<vector::ShuffleOp>(v1, v2, shuffleMask);
}

/// Lower to vector.shuffle implementing _mm256_permute2f128_ps. The 8-bit
/// `mask` holds two 2-bit selectors: bits 0:1 pick the low 128 bits of the
/// result and bits 4:5 pick the high 128 bits, each choosing among
///   0 -> a[0:127], 1 -> a[128:255], 2 -> b[0:127], 3 -> b[128:255].
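/// E.g. mask 0x20 (low selector 0, high selector 2) yields
///   {a[0:127], b[0:127]}; mask 0x31 yields {a[128:255], b[128:255]}.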
Value mlir::x86vector::avx2::intrin::mm256Permute2f128Ps(
    ImplicitLocOpBuilder &b, Value v1, Value v2, uint8_t mask) {
  SmallVector<int64_t> shuffleMask;
  auto appendToMask = [&](uint8_t control) {
    if (control == 0)
      llvm::append_range(shuffleMask, ArrayRef<int64_t>{0, 1, 2, 3});
    else if (control == 1)
      llvm::append_range(shuffleMask, ArrayRef<int64_t>{4, 5, 6, 7});
    else if (control == 2)
      llvm::append_range(shuffleMask, ArrayRef<int64_t>{8, 9, 10, 11});
    else if (control == 3)
      llvm::append_range(shuffleMask, ArrayRef<int64_t>{12, 13, 14, 15});
    else
      llvm_unreachable("control > 3 : overflow");
  };
  uint8_t b03, b47;
  MaskHelper::extractPermute(mask, b03, b47);
  appendToMask(b03);
  appendToMask(b47);
  return b.create<vector::ShuffleOp>(v1, v2, shuffleMask);
}

/// If bit i of `mask` is zero, take f32@i from v1 else take it from v2.
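/// E.g. mask 0xF0 yields
///   {v1[0], v1[1], v1[2], v1[3], v2[4], v2[5], v2[6], v2[7]}.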
Value mlir::x86vector::avx2::intrin::mm256BlendPs(ImplicitLocOpBuilder &b,
                                                  Value v1, Value v2,
                                                  uint8_t mask) {
  SmallVector<int64_t, 8> shuffleMask;
  for (int i = 0; i < 8; ++i) {
    bool isSet = mask & (1 << i);
    shuffleMask.push_back(!isSet ? i : i + 8);
  }
  return b.create<vector::ShuffleOp>(v1, v2, shuffleMask);
}

/// AVX2 4x8xf32-specific transpose lowering using a "C intrinsics" model.
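/// On entry, `vs` holds the 4 rows of a 4x8 matrix as vector<8xf32> values.
/// On return, `vs[i]` holds rows 2*i and 2*i+1 of the transposed 8x4 matrix,
/// concatenated into one vector<8xf32>.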
void mlir::x86vector::avx2::transpose4x8xf32(ImplicitLocOpBuilder &ib,
                                             MutableArrayRef<Value> vs) {
#ifndef NDEBUG
  auto vt = VectorType::get({8}, Float32Type::get(ib.getContext()));
  assert(vs.size() == 4 && "expects 4 vectors");
  assert(llvm::all_of(ValueRange{vs}.getTypes(),
                      [&](Type t) { return t == vt; }) &&
         "expects all types to be vector<8xf32>");
#endif

  Value t0 = mm256UnpackLoPs(ib, vs[0], vs[1]);
  Value t1 = mm256UnpackHiPs(ib, vs[0], vs[1]);
  Value t2 = mm256UnpackLoPs(ib, vs[2], vs[3]);
  Value t3 = mm256UnpackHiPs(ib, vs[2], vs[3]);
  Value s0 = mm256ShufflePs(ib, t0, t2, MaskHelper::shuffle<1, 0, 1, 0>());
  Value s1 = mm256ShufflePs(ib, t0, t2, MaskHelper::shuffle<3, 2, 3, 2>());
  Value s2 = mm256ShufflePs(ib, t1, t3, MaskHelper::shuffle<1, 0, 1, 0>());
  Value s3 = mm256ShufflePs(ib, t1, t3, MaskHelper::shuffle<3, 2, 3, 2>());
  vs[0] = mm256Permute2f128Ps(ib, s0, s1, MaskHelper::permute<2, 0>());
  vs[1] = mm256Permute2f128Ps(ib, s2, s3, MaskHelper::permute<2, 0>());
  vs[2] = mm256Permute2f128Ps(ib, s0, s1, MaskHelper::permute<3, 1>());
  vs[3] = mm256Permute2f128Ps(ib, s2, s3, MaskHelper::permute<3, 1>());
}

/// AVX2 8x8xf32-specific transpose lowering using a "C intrinsics" model.
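/// On entry, `vs` holds the 8 rows of an 8x8 matrix as vector<8xf32> values;
/// on return, `vs[i]` holds column i of the input, i.e. row i of the
/// transpose.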
void mlir::x86vector::avx2::transpose8x8xf32(ImplicitLocOpBuilder &ib,
                                             MutableArrayRef<Value> vs) {
  auto vt = VectorType::get({8}, Float32Type::get(ib.getContext()));
  (void)vt;
  assert(vs.size() == 8 && "expects 8 vectors");
  assert(llvm::all_of(ValueRange{vs}.getTypes(),
                      [&](Type t) { return t == vt; }) &&
         "expects all types to be vector<8xf32>");

  Value t0 = mm256UnpackLoPs(ib, vs[0], vs[1]);
  Value t1 = mm256UnpackHiPs(ib, vs[0], vs[1]);
  Value t2 = mm256UnpackLoPs(ib, vs[2], vs[3]);
  Value t3 = mm256UnpackHiPs(ib, vs[2], vs[3]);
  Value t4 = mm256UnpackLoPs(ib, vs[4], vs[5]);
  Value t5 = mm256UnpackHiPs(ib, vs[4], vs[5]);
  Value t6 = mm256UnpackLoPs(ib, vs[6], vs[7]);
  Value t7 = mm256UnpackHiPs(ib, vs[6], vs[7]);

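  // Note: the blends below go through mm256BlendPsAsm, i.e. hand-written
  // inline assembly, so they are guaranteed to lower to vblendps; presumably
  // this is because an equivalent generic vector.shuffle is not always
  // matched to a blend by the backend.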
  using inline_asm::mm256BlendPsAsm;
  Value sh0 = mm256ShufflePs(ib, t0, t2, MaskHelper::shuffle<1, 0, 3, 2>());
  Value sh2 = mm256ShufflePs(ib, t1, t3, MaskHelper::shuffle<1, 0, 3, 2>());
  Value sh4 = mm256ShufflePs(ib, t4, t6, MaskHelper::shuffle<1, 0, 3, 2>());
  Value sh6 = mm256ShufflePs(ib, t5, t7, MaskHelper::shuffle<1, 0, 3, 2>());

  Value s0 = mm256BlendPsAsm(ib, t0, sh0,
                             MaskHelper::blend<0, 0, 1, 1, 0, 0, 1, 1>());
  Value s1 = mm256BlendPsAsm(ib, t2, sh0,
                             MaskHelper::blend<1, 1, 0, 0, 1, 1, 0, 0>());
  Value s2 = mm256BlendPsAsm(ib, t1, sh2,
                             MaskHelper::blend<0, 0, 1, 1, 0, 0, 1, 1>());
  Value s3 = mm256BlendPsAsm(ib, t3, sh2,
                             MaskHelper::blend<1, 1, 0, 0, 1, 1, 0, 0>());
  Value s4 = mm256BlendPsAsm(ib, t4, sh4,
                             MaskHelper::blend<0, 0, 1, 1, 0, 0, 1, 1>());
  Value s5 = mm256BlendPsAsm(ib, t6, sh4,
                             MaskHelper::blend<1, 1, 0, 0, 1, 1, 0, 0>());
  Value s6 = mm256BlendPsAsm(ib, t5, sh6,
                             MaskHelper::blend<0, 0, 1, 1, 0, 0, 1, 1>());
  Value s7 = mm256BlendPsAsm(ib, t7, sh6,
                             MaskHelper::blend<1, 1, 0, 0, 1, 1, 0, 0>());

  vs[0] = mm256Permute2f128Ps(ib, s0, s4, MaskHelper::permute<2, 0>());
  vs[1] = mm256Permute2f128Ps(ib, s1, s5, MaskHelper::permute<2, 0>());
  vs[2] = mm256Permute2f128Ps(ib, s2, s6, MaskHelper::permute<2, 0>());
  vs[3] = mm256Permute2f128Ps(ib, s3, s7, MaskHelper::permute<2, 0>());
  vs[4] = mm256Permute2f128Ps(ib, s0, s4, MaskHelper::permute<3, 1>());
  vs[5] = mm256Permute2f128Ps(ib, s1, s5, MaskHelper::permute<3, 1>());
  vs[6] = mm256Permute2f128Ps(ib, s2, s6, MaskHelper::permute<3, 1>());
  vs[7] = mm256Permute2f128Ps(ib, s3, s7, MaskHelper::permute<3, 1>());
}

/// Rewrite AVX2-specific vector.transpose, for the supported cases and
/// depending on the `TransposeLoweringOptions`. The lowering supports 2-D
/// transpose cases and n-D cases that have been decomposed into 2-D
/// transposition slices. For example, a 3-D transpose:
///
///   %0 = vector.transpose %arg0, [2, 0, 1]
///      : vector<1024x2048x4096xf32> to vector<4096x1024x2048xf32>
///
/// could be sliced into 2-D transposes by tiling two of its dimensions to one
/// of the vector lengths supported by the AVX2 patterns (e.g., 4x8):
///
///   %0 = vector.transpose %arg0, [2, 0, 1]
///      : vector<1x4x8xf32> to vector<8x1x4xf32>
///
/// This lowering will analyze the n-D vector.transpose and determine if it's a
/// supported 2-D transposition slice where any of the AVX2 patterns can be
/// applied.
class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
public:
  using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;

  TransposeOpLowering(LoweringOptions loweringOptions, MLIRContext *context,
                      int benefit)
      : OpRewritePattern<vector::TransposeOp>(context, benefit),
        loweringOptions(loweringOptions) {}

  LogicalResult matchAndRewrite(vector::TransposeOp op,
                                PatternRewriter &rewriter) const override {
    auto loc = op.getLoc();

    // Check if the source vector type is supported. AVX2 patterns can only be
    // applied to f32 vector types with two dimensions greater than one.
    VectorType srcType = op.getSourceVectorType();
    if (!srcType.getElementType().isF32())
      return rewriter.notifyMatchFailure(op, "Unsupported vector element type");

    auto srcGtOneDims = mlir::vector::isTranspose2DSlice(op);
    if (failed(srcGtOneDims))
      return rewriter.notifyMatchFailure(
          op, "expected transposition on a 2D slice");

    // Retrieve the sizes of the two dimensions greater than one to be
    // transposed.
    int64_t m = srcType.getDimSize(std::get<0>(srcGtOneDims.value()));
    int64_t n = srcType.getDimSize(std::get<1>(srcGtOneDims.value()));

    auto applyRewrite = [&]() {
      ImplicitLocOpBuilder ib(loc, rewriter);
      SmallVector<Value> vs;

      // Reshape the n-D input vector with only two dimensions greater than one
      // to a 2-D vector.
      auto flattenedType =
          VectorType::get({n * m}, op.getSourceVectorType().getElementType());
      auto reshInputType = VectorType::get({m, n}, srcType.getElementType());
      Value reshInput =
          ib.create<vector::ShapeCastOp>(flattenedType, op.getVector());
      reshInput = ib.create<vector::ShapeCastOp>(reshInputType, reshInput);

      // Extract 1-D vectors from the higher-order dimension of the input
      // vector.
      for (int64_t i = 0; i < m; ++i)
        vs.push_back(ib.create<vector::ExtractOp>(reshInput, i));

      // Transpose set of 1-D vectors.
      if (m == 4)
        transpose4x8xf32(ib, vs);
      if (m == 8)
        transpose8x8xf32(ib, vs);

      // Insert transposed 1-D vectors into the higher-order dimension of the
      // output vector.
      Value res = ib.create<arith::ConstantOp>(reshInputType,
                                               ib.getZeroAttr(reshInputType));
      for (int64_t i = 0; i < m; ++i)
        res = ib.create<vector::InsertOp>(vs[i], res, i);

      // The output vector still has the shape of the input vector (e.g., 4x8).
      // We have to transpose their dimensions and retrieve its original rank
      // (e.g., 1x8x1x4x1).
      res = ib.create<vector::ShapeCastOp>(flattenedType, res);
      res = ib.create<vector::ShapeCastOp>(op.getResultVectorType(), res);
      rewriter.replaceOp(op, res);
      return success();
    };

    if (loweringOptions.transposeOptions.lower4x8xf32_ && m == 4 && n == 8)
      return applyRewrite();
    if (loweringOptions.transposeOptions.lower8x8xf32_ && m == 8 && n == 8)
      return applyRewrite();
    return failure();
  }

private:
  LoweringOptions loweringOptions;
};

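// A minimal usage sketch for the entry point below. The pass boilerplate is
// hypothetical (the driver call and `funcOp`/`ctx` names are assumptions);
// only the populate call and the `LoweringOptions` flags come from this file:
//
//   RewritePatternSet patterns(ctx);
//   LoweringOptions options;
//   options.transposeOptions.lower4x8xf32_ = true;
//   options.transposeOptions.lower8x8xf32_ = true;
//   x86vector::avx2::populateSpecializedTransposeLoweringPatterns(
//       patterns, options, /*benefit=*/10);
//   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));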
void mlir::x86vector::avx2::populateSpecializedTransposeLoweringPatterns(
    RewritePatternSet &patterns, LoweringOptions options, int benefit) {
  patterns.add<TransposeOpLowering>(options, patterns.getContext(), benefit);
}