//===- AVXTranspose.cpp - Lower Vector transpose to AVX ------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements vector.transpose rewrites as AVX patterns for particular
// sizes of interest.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/Dialect/X86Vector/Transforms.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "llvm/Support/Format.h"
#include "llvm/Support/FormatVariadic.h"

using namespace mlir;
using namespace mlir::vector;
using namespace mlir::x86vector;
using namespace mlir::x86vector::avx2;
using namespace mlir::x86vector::avx2::inline_asm;
using namespace mlir::x86vector::avx2::intrin;

Value mlir::x86vector::avx2::inline_asm::mm256BlendPsAsm(
    ImplicitLocOpBuilder &b, Value v1, Value v2, uint8_t mask) {
  auto asmDialectAttr =
      LLVM::AsmDialectAttr::get(b.getContext(), LLVM::AsmDialect::AD_Intel);
  const auto *asmTp = "vblendps $0, $1, $2, {0}";
  const auto *asmCstr =
      "=x,x,x"; // Careful: constraint parser is very brittle: no ws!
  SmallVector<Value> asmVals{v1, v2};
  auto asmStr = llvm::formatv(asmTp, llvm::format_hex(mask, /*width=*/2)).str();
  auto asmOp = b.create<LLVM::InlineAsmOp>(
      v1.getType(), /*operands=*/asmVals, /*asm_string=*/asmStr,
      /*constraints=*/asmCstr, /*has_side_effects=*/false,
      /*is_align_stack=*/false, LLVM::TailCallKind::None,
      /*asm_dialect=*/asmDialectAttr,
      /*operand_attrs=*/ArrayAttr());
  return asmOp.getResult(0);
}
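
// Example of the op built above (added for illustration): with mask = 0xCC the
// format string expands to "vblendps $0, $1, $2, 0xcc", so the emitted op is
// roughly:
//
//   %r = llvm.inline_asm asm_dialect = intel
//            "vblendps $0, $1, $2, 0xcc", "=x,x,x" %v1, %v2
//        : (vector<8xf32>, vector<8xf32>) -> vector<8xf32>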

Value mlir::x86vector::avx2::intrin::mm256UnpackLoPs(ImplicitLocOpBuilder &b,
                                                     Value v1, Value v2) {
  return b.create<vector::ShuffleOp>(
      v1, v2, ArrayRef<int64_t>{0, 8, 1, 9, 4, 12, 5, 13});
}

Value mlir::x86vector::avx2::intrin::mm256UnpackHiPs(ImplicitLocOpBuilder &b,
                                                     Value v1, Value v2) {
  return b.create<vector::ShuffleOp>(
      v1, v2, ArrayRef<int64_t>{2, 10, 3, 11, 6, 14, 7, 15});
}
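
// Worked example for the two unpack helpers above (added for illustration):
// with v1 = [a0..a7] and v2 = [b0..b7], the shuffles interleave within each
// 128-bit lane, mirroring _mm256_unpacklo_ps / _mm256_unpackhi_ps:
//   mm256UnpackLoPs -> [a0 b0 a1 b1 | a4 b4 a5 b5]
//   mm256UnpackHiPs -> [a2 b2 a3 b3 | a6 b6 a7 b7]
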
/// Takes an 8-bit mask, 2 bits for each of the 4 output positions within a
/// 128-bit lane: the two low positions select from `v1` and the two high
/// positions select from `v2`, with the same pattern applied to both lanes:
///                0:127              |               128:255
///   v1[b01] v1[b23] v2[b45] v2[b67] | v1[b01+4] v1[b23+4] v2[b45+4] v2[b67+4]
Value mlir::x86vector::avx2::intrin::mm256ShufflePs(ImplicitLocOpBuilder &b,
                                                    Value v1, Value v2,
                                                    uint8_t mask) {
  uint8_t b01, b23, b45, b67;
  MaskHelper::extractShuffle(mask, b01, b23, b45, b67);
  SmallVector<int64_t> shuffleMask = {
      b01, b23, b45 + 8, b67 + 8, b01 + 4, b23 + 4, b45 + 8 + 4, b67 + 8 + 4};
  return b.create<vector::ShuffleOp>(v1, v2, shuffleMask);
}
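
// Worked example (added for illustration, assuming MaskHelper::shuffle packs
// its template arguments like Intel's _MM_SHUFFLE, most significant field
// first): MaskHelper::shuffle<1, 0, 1, 0>() yields 0x44, i.e. b01 = 0,
// b23 = 1, b45 = 0, b67 = 1, so the vector.shuffle mask becomes
// {0, 1, 8, 9, 4, 5, 12, 13}: elements 0-1 of each lane of v1 followed by
// elements 0-1 of the matching lane of v2, as _mm256_shuffle_ps(v1, v2, 0x44).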

// imm[1:0] (low two bits of the low nibble) selects the source of the low 128
// bits of the result, and imm[5:4] (low two bits of the high nibble) selects
// the source of the high 128 bits:
//   0 -> a[0:127], 1 -> a[128:255], 2 -> b[0:127], 3 -> b[128:255]
Value mlir::x86vector::avx2::intrin::mm256Permute2f128Ps(
    ImplicitLocOpBuilder &b, Value v1, Value v2, uint8_t mask) {
  SmallVector<int64_t> shuffleMask;
  auto appendToMask = [&](uint8_t control) {
    if (control == 0)
      llvm::append_range(shuffleMask, ArrayRef<int64_t>{0, 1, 2, 3});
    else if (control == 1)
      llvm::append_range(shuffleMask, ArrayRef<int64_t>{4, 5, 6, 7});
    else if (control == 2)
      llvm::append_range(shuffleMask, ArrayRef<int64_t>{8, 9, 10, 11});
    else if (control == 3)
      llvm::append_range(shuffleMask, ArrayRef<int64_t>{12, 13, 14, 15});
    else
      llvm_unreachable("control > 3 : overflow");
  };
  uint8_t b03, b47;
  MaskHelper::extractPermute(mask, b03, b47);
  appendToMask(b03);
  appendToMask(b47);
  return b.create<vector::ShuffleOp>(v1, v2, shuffleMask);
}
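
// Worked example (added for illustration): MaskHelper::permute<2, 0>(), as
// used in the transposes below, encodes b03 = 0 and b47 = 2, so the mask is
// {0, 1, 2, 3, 8, 9, 10, 11}: the low 128 bits of v1 followed by the low 128
// bits of v2, matching _mm256_permute2f128_ps(v1, v2, 0x20).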

/// If bit i of `mask` is zero, take f32@i from v1 else take it from v2.
Value mlir::x86vector::avx2::intrin::mm256BlendPs(ImplicitLocOpBuilder &b,
                                                  Value v1, Value v2,
                                                  uint8_t mask) {
  SmallVector<int64_t, 8> shuffleMask;
  for (int i = 0; i < 8; ++i) {
    bool isSet = mask & (1 << i);
    shuffleMask.push_back(!isSet ? i : i + 8);
  }
  return b.create<vector::ShuffleOp>(v1, v2, shuffleMask);
}
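
// Worked example (added for illustration): mask = 0xCC (0b11001100) keeps
// elements 0, 1, 4, 5 from v1 and takes elements 2, 3, 6, 7 from v2, so the
// vector.shuffle mask is {0, 1, 10, 11, 4, 5, 14, 15}, matching
// _mm256_blend_ps(v1, v2, 0xCC).
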
114
115/// AVX2 4x8xf32-specific transpose lowering using a "C intrinsics" model.
116void mlir::x86vector::avx2::transpose4x8xf32(ImplicitLocOpBuilder &ib,
117 MutableArrayRef<Value> vs) {
118#ifndef NDEBUG
119 auto vt = VectorType::get({8}, Float32Type::get(ib.getContext()));
120 assert(vs.size() == 4 && "expects 4 vectors");
121 assert(llvm::all_of(ValueRange{vs}.getTypes(),
122 [&](Type t) { return t == vt; }) &&
123 "expects all types to be vector<8xf32>");
124#endif
125
126 Value t0 = mm256UnpackLoPs(b&: ib, v1: vs[0], v2: vs[1]);
127 Value t1 = mm256UnpackHiPs(b&: ib, v1: vs[0], v2: vs[1]);
128 Value t2 = mm256UnpackLoPs(b&: ib, v1: vs[2], v2: vs[3]);
129 Value t3 = mm256UnpackHiPs(b&: ib, v1: vs[2], v2: vs[3]);
130 Value s0 = mm256ShufflePs(b&: ib, v1: t0, v2: t2, mask: MaskHelper::shuffle<1, 0, 1, 0>());
131 Value s1 = mm256ShufflePs(b&: ib, v1: t0, v2: t2, mask: MaskHelper::shuffle<3, 2, 3, 2>());
132 Value s2 = mm256ShufflePs(b&: ib, v1: t1, v2: t3, mask: MaskHelper::shuffle<1, 0, 1, 0>());
133 Value s3 = mm256ShufflePs(b&: ib, v1: t1, v2: t3, mask: MaskHelper::shuffle<3, 2, 3, 2>());
134 vs[0] = mm256Permute2f128Ps(b&: ib, v1: s0, v2: s1, mask: MaskHelper::permute<2, 0>());
135 vs[1] = mm256Permute2f128Ps(b&: ib, v1: s2, v2: s3, mask: MaskHelper::permute<2, 0>());
136 vs[2] = mm256Permute2f128Ps(b&: ib, v1: s0, v2: s1, mask: MaskHelper::permute<3, 1>());
137 vs[3] = mm256Permute2f128Ps(b&: ib, v1: s2, v2: s3, mask: MaskHelper::permute<3, 1>());
138}
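
// Data-flow sketch for the above (added for illustration), with input rows
// A = vs[0] .. D = vs[3] of the 4x8 matrix:
//   t0    = [A0 B0 A1 B1 | A4 B4 A5 B5]   (unpacklo of A and B)
//   s0    = [A0 B0 C0 D0 | A4 B4 C4 D4]   (columns 0 and 4)
//   vs[0] = [A0 B0 C0 D0 | A1 B1 C1 D1]   (columns 0 and 1 of the transpose)
// Each output vector thus packs two consecutive columns of the 8x4 result.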

/// AVX2 8x8xf32-specific transpose lowering using a "C intrinsics" model.
void mlir::x86vector::avx2::transpose8x8xf32(ImplicitLocOpBuilder &ib,
                                             MutableArrayRef<Value> vs) {
  auto vt = VectorType::get({8}, Float32Type::get(ib.getContext()));
  (void)vt;
  assert(vs.size() == 8 && "expects 8 vectors");
  assert(llvm::all_of(ValueRange{vs}.getTypes(),
                      [&](Type t) { return t == vt; }) &&
         "expects all types to be vector<8xf32>");

  Value t0 = mm256UnpackLoPs(ib, vs[0], vs[1]);
  Value t1 = mm256UnpackHiPs(ib, vs[0], vs[1]);
  Value t2 = mm256UnpackLoPs(ib, vs[2], vs[3]);
  Value t3 = mm256UnpackHiPs(ib, vs[2], vs[3]);
  Value t4 = mm256UnpackLoPs(ib, vs[4], vs[5]);
  Value t5 = mm256UnpackHiPs(ib, vs[4], vs[5]);
  Value t6 = mm256UnpackLoPs(ib, vs[6], vs[7]);
  Value t7 = mm256UnpackHiPs(ib, vs[6], vs[7]);

  using inline_asm::mm256BlendPsAsm;
  Value sh0 = mm256ShufflePs(ib, t0, t2, MaskHelper::shuffle<1, 0, 3, 2>());
  Value sh2 = mm256ShufflePs(ib, t1, t3, MaskHelper::shuffle<1, 0, 3, 2>());
  Value sh4 = mm256ShufflePs(ib, t4, t6, MaskHelper::shuffle<1, 0, 3, 2>());
  Value sh6 = mm256ShufflePs(ib, t5, t7, MaskHelper::shuffle<1, 0, 3, 2>());

  Value s0 =
      mm256BlendPsAsm(ib, t0, sh0, MaskHelper::blend<0, 0, 1, 1, 0, 0, 1, 1>());
  Value s1 =
      mm256BlendPsAsm(ib, t2, sh0, MaskHelper::blend<1, 1, 0, 0, 1, 1, 0, 0>());
  Value s2 =
      mm256BlendPsAsm(ib, t1, sh2, MaskHelper::blend<0, 0, 1, 1, 0, 0, 1, 1>());
  Value s3 =
      mm256BlendPsAsm(ib, t3, sh2, MaskHelper::blend<1, 1, 0, 0, 1, 1, 0, 0>());
  Value s4 =
      mm256BlendPsAsm(ib, t4, sh4, MaskHelper::blend<0, 0, 1, 1, 0, 0, 1, 1>());
  Value s5 =
      mm256BlendPsAsm(ib, t6, sh4, MaskHelper::blend<1, 1, 0, 0, 1, 1, 0, 0>());
  Value s6 =
      mm256BlendPsAsm(ib, t5, sh6, MaskHelper::blend<0, 0, 1, 1, 0, 0, 1, 1>());
  Value s7 =
      mm256BlendPsAsm(ib, t7, sh6, MaskHelper::blend<1, 1, 0, 0, 1, 1, 0, 0>());

  vs[0] = mm256Permute2f128Ps(ib, s0, s4, MaskHelper::permute<2, 0>());
  vs[1] = mm256Permute2f128Ps(ib, s1, s5, MaskHelper::permute<2, 0>());
  vs[2] = mm256Permute2f128Ps(ib, s2, s6, MaskHelper::permute<2, 0>());
  vs[3] = mm256Permute2f128Ps(ib, s3, s7, MaskHelper::permute<2, 0>());
  vs[4] = mm256Permute2f128Ps(ib, s0, s4, MaskHelper::permute<3, 1>());
  vs[5] = mm256Permute2f128Ps(ib, s1, s5, MaskHelper::permute<3, 1>());
  vs[6] = mm256Permute2f128Ps(ib, s2, s6, MaskHelper::permute<3, 1>());
  vs[7] = mm256Permute2f128Ps(ib, s3, s7, MaskHelper::permute<3, 1>());
}
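
// For reference (added for illustration, assuming the MaskHelper encodings
// correspond to the Intel immediates as in the worked examples above), the
// same step sequence written against the AVX2 intrinsics from <immintrin.h>,
// for the first output row only (r0..r7 are the eight input rows):
//
//   __m256 t0   = _mm256_unpacklo_ps(r0, r1);
//   __m256 t2   = _mm256_unpacklo_ps(r2, r3);
//   __m256 t4   = _mm256_unpacklo_ps(r4, r5);
//   __m256 t6   = _mm256_unpacklo_ps(r6, r7);
//   __m256 sh0  = _mm256_shuffle_ps(t0, t2, _MM_SHUFFLE(1, 0, 3, 2));
//   __m256 sh4  = _mm256_shuffle_ps(t4, t6, _MM_SHUFFLE(1, 0, 3, 2));
//   __m256 s0   = _mm256_blend_ps(t0, sh0, 0xCC);
//   __m256 s4   = _mm256_blend_ps(t4, sh4, 0xCC);
//   __m256 out0 = _mm256_permute2f128_ps(s0, s4, 0x20);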

/// Rewrite AVX2-specific vector.transpose, for the supported cases and
/// depending on the `TransposeLoweringOptions`. The lowering supports 2-D
/// transpose cases and n-D cases that have been decomposed into 2-D
/// transposition slices. For example, a 3-D transpose:
///
///   %0 = vector.transpose %arg0, [2, 0, 1]
///      : vector<1024x2048x4096xf32> to vector<4096x1024x2048xf32>
///
/// could be sliced into 2-D transposes by tiling two of its dimensions to one
/// of the vector lengths supported by the AVX2 patterns (e.g., 4x8):
///
///   %0 = vector.transpose %arg0, [2, 0, 1]
///      : vector<1x4x8xf32> to vector<8x1x4xf32>
///
/// This lowering will analyze the n-D vector.transpose and determine if it's a
/// supported 2-D transposition slice where any of the AVX2 patterns can be
/// applied.
class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
public:
  using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;

  TransposeOpLowering(LoweringOptions loweringOptions, MLIRContext *context,
                      int benefit)
      : OpRewritePattern<vector::TransposeOp>(context, benefit),
        loweringOptions(loweringOptions) {}

  LogicalResult matchAndRewrite(vector::TransposeOp op,
                                PatternRewriter &rewriter) const override {
    auto loc = op.getLoc();

    // Check if the source vector type is supported. AVX2 patterns can only be
    // applied to f32 vector types with two dimensions greater than one.
    VectorType srcType = op.getSourceVectorType();
    if (!srcType.getElementType().isF32())
      return rewriter.notifyMatchFailure(op, "Unsupported vector element type");

    auto srcGtOneDims = mlir::vector::isTranspose2DSlice(op);
    if (failed(srcGtOneDims))
      return rewriter.notifyMatchFailure(
          op, "expected transposition on a 2D slice");

    // Retrieve the sizes of the two dimensions greater than one to be
    // transposed.
    int64_t m = srcType.getDimSize(std::get<0>(srcGtOneDims.value()));
    int64_t n = srcType.getDimSize(std::get<1>(srcGtOneDims.value()));

    auto applyRewrite = [&]() {
      ImplicitLocOpBuilder ib(loc, rewriter);
      SmallVector<Value> vs;

      // Reshape the n-D input vector with only two dimensions greater than one
      // to a 2-D vector.
      auto flattenedType =
          VectorType::get({n * m}, op.getSourceVectorType().getElementType());
      auto reshInputType = VectorType::get({m, n}, srcType.getElementType());
      auto reshInput =
          ib.create<vector::ShapeCastOp>(flattenedType, op.getVector());
      reshInput = ib.create<vector::ShapeCastOp>(reshInputType, reshInput);

      // Extract 1-D vectors from the higher-order dimension of the input
      // vector.
      for (int64_t i = 0; i < m; ++i)
        vs.push_back(ib.create<vector::ExtractOp>(reshInput, i));

      // Transpose set of 1-D vectors.
      if (m == 4)
        transpose4x8xf32(ib, vs);
      if (m == 8)
        transpose8x8xf32(ib, vs);

      // Insert transposed 1-D vectors into the higher-order dimension of the
      // output vector.
      Value res = ib.create<arith::ConstantOp>(reshInputType,
                                               ib.getZeroAttr(reshInputType));
      for (int64_t i = 0; i < m; ++i)
        res = ib.create<vector::InsertOp>(vs[i], res, i);
      // The output vector still has the shape of the input vector (e.g., 4x8).
      // We have to transpose its dimensions and restore its original rank
      // (e.g., 1x8x1x4x1).
      res = ib.create<vector::ShapeCastOp>(flattenedType, res);
      res = ib.create<vector::ShapeCastOp>(op.getResultVectorType(), res);
      rewriter.replaceOp(op, res);
      return success();
    };

    if (loweringOptions.transposeOptions.lower4x8xf32_ && m == 4 && n == 8)
      return applyRewrite();
    if (loweringOptions.transposeOptions.lower8x8xf32_ && m == 8 && n == 8)
      return applyRewrite();
    return failure();
  }

private:
  LoweringOptions loweringOptions;
};
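
// Rewrite sketch (added for illustration): given
//
//   %0 = vector.transpose %arg0, [1, 0] : vector<4x8xf32> to vector<8x4xf32>
//
// applyRewrite above produces, in essence:
//
//   %flat = vector.shape_cast %arg0 : vector<4x8xf32> to vector<32xf32>
//   %in   = vector.shape_cast %flat : vector<32xf32> to vector<4x8xf32>
//   ... 4 x vector.extract, the unpack/shuffle/permute sequence,
//       4 x vector.insert into a zero-initialized vector<4x8xf32> ...
//   %res  = vector.shape_cast ... : vector<32xf32> to vector<8x4xf32>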

void mlir::x86vector::avx2::populateSpecializedTransposeLoweringPatterns(
    RewritePatternSet &patterns, LoweringOptions options, int benefit) {
  patterns.add<TransposeOpLowering>(options, patterns.getContext(), benefit);
}
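
// Usage sketch (added for illustration; the surrounding pass code is
// hypothetical, but the option fields are the ones read by the pattern above):
//
//   RewritePatternSet patterns(ctx);
//   x86vector::avx2::LoweringOptions options;
//   options.transposeOptions.lower4x8xf32_ = true;
//   options.transposeOptions.lower8x8xf32_ = true;
//   x86vector::avx2::populateSpecializedTransposeLoweringPatterns(
//       patterns, options, /*benefit=*/10);
//   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));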
