1//! [AVX512BF16 intrinsics].
2//!
3//! [AVX512BF16 intrinsics]: https://software.intel.com/sites/landingpage/IntrinsicsGuide/#expand=1769&avx512techs=AVX512_BF16
4
5use crate::core_arch::{simd::*, x86::*};
6use crate::intrinsics::simd::*;
7
8#[cfg(test)]
9use stdarch_test::assert_instr;
10
11#[allow(improper_ctypes)]
12extern "C" {
13 #[link_name = "llvm.x86.avx512bf16.cvtne2ps2bf16.128"]
14 fn cvtne2ps2bf16(a: f32x4, b: f32x4) -> i16x8;
15 #[link_name = "llvm.x86.avx512bf16.cvtne2ps2bf16.256"]
16 fn cvtne2ps2bf16_256(a: f32x8, b: f32x8) -> i16x16;
17 #[link_name = "llvm.x86.avx512bf16.cvtne2ps2bf16.512"]
18 fn cvtne2ps2bf16_512(a: f32x16, b: f32x16) -> i16x32;
19 #[link_name = "llvm.x86.avx512bf16.cvtneps2bf16.256"]
20 fn cvtneps2bf16_256(a: f32x8) -> i16x8;
21 #[link_name = "llvm.x86.avx512bf16.cvtneps2bf16.512"]
22 fn cvtneps2bf16_512(a: f32x16) -> i16x16;
23 #[link_name = "llvm.x86.avx512bf16.dpbf16ps.128"]
24 fn dpbf16ps(a: f32x4, b: i32x4, c: i32x4) -> f32x4;
25 #[link_name = "llvm.x86.avx512bf16.dpbf16ps.256"]
26 fn dpbf16ps_256(a: f32x8, b: i32x8, c: i32x8) -> f32x8;
27 #[link_name = "llvm.x86.avx512bf16.dpbf16ps.512"]
28 fn dpbf16ps_512(a: f32x16, b: i32x16, c: i32x16) -> f32x16;
29}
30
31/// Convert packed single-precision (32-bit) floating-point elements in two 128-bit vectors
32/// a and b to packed BF16 (16-bit) floating-point elements, and store the results in a
33/// 128-bit wide vector.
34/// [Intel's documentation](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#expand=1769,1651&avx512techs=AVX512_BF16&text=_mm_cvtne2ps_pbh)
35#[inline]
36#[target_feature(enable = "avx512bf16,avx512vl")]
37#[unstable(feature = "stdarch_x86_avx512", issue = "111137")]
38#[cfg_attr(test, assert_instr("vcvtne2ps2bf16"))]
39pub unsafe fn _mm_cvtne2ps_pbh(a: __m128, b: __m128) -> __m128bh {
40 transmute(src:cvtne2ps2bf16(a:a.as_f32x4(), b:b.as_f32x4()))
41}
42
43/// Convert packed single-precision (32-bit) floating-point elements in two vectors
44/// a and b to packed BF16 (16-bit) floating-point elements, and store the results
45/// in single vector dst using writemask k (elements are copied from src when the
46/// corresponding mask bit is not set).
47/// [Intel's documentation](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#expand=1769,1651&avx512techs=AVX512_BF16&text=_mm_mask_cvtne2ps_pbh)
48#[inline]
49#[target_feature(enable = "avx512bf16,avx512vl")]
50#[unstable(feature = "stdarch_x86_avx512", issue = "111137")]
51#[cfg_attr(test, assert_instr("vcvtne2ps2bf16"))]
52pub unsafe fn _mm_mask_cvtne2ps_pbh(src: __m128bh, k: __mmask8, a: __m128, b: __m128) -> __m128bh {
53 let cvt: u16x8 = _mm_cvtne2ps_pbh(a, b).as_u16x8();
54 transmute(src:simd_select_bitmask(m:k, yes:cvt, no:src.as_u16x8()))
55}
56
57/// Convert packed single-precision (32-bit) floating-point elements in two vectors
58/// a and b to packed BF16 (16-bit) floating-point elements, and store the results
59/// in single vector dst using zeromask k (elements are zeroed out when the corresponding
60/// mask bit is not set).
61/// [Intel's documentation](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#expand=1769,1651&avx512techs=AVX512_BF16&text=_mm_maskz_cvtne2ps_pbh)
62#[inline]
63#[target_feature(enable = "avx512bf16,avx512vl")]
64#[unstable(feature = "stdarch_x86_avx512", issue = "111137")]
65#[cfg_attr(test, assert_instr("vcvtne2ps2bf16"))]
66pub unsafe fn _mm_maskz_cvtne2ps_pbh(k: __mmask8, a: __m128, b: __m128) -> __m128bh {
67 let cvt: u16x8 = _mm_cvtne2ps_pbh(a, b).as_u16x8();
68 let zero: u16x8 = _mm_setzero_si128().as_u16x8();
69 transmute(src:simd_select_bitmask(m:k, yes:cvt, no:zero))
70}
71
72/// Convert packed single-precision (32-bit) floating-point elements in two 256-bit vectors
73/// a and b to packed BF16 (16-bit) floating-point elements, and store the results in a
74/// 256-bit wide vector.
75/// [Intel's documentation](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#expand=1769,1651,1654&avx512techs=AVX512_BF16&text=_mm256_cvtne2ps_pbh)
76#[inline]
77#[target_feature(enable = "avx512bf16,avx512vl")]
78#[unstable(feature = "stdarch_x86_avx512", issue = "111137")]
79#[cfg_attr(test, assert_instr("vcvtne2ps2bf16"))]
80pub unsafe fn _mm256_cvtne2ps_pbh(a: __m256, b: __m256) -> __m256bh {
81 transmute(src:cvtne2ps2bf16_256(a:a.as_f32x8(), b:b.as_f32x8()))
82}
83
84/// Convert packed single-precision (32-bit) floating-point elements in two vectors a and b
85/// to packed BF16 (16-bit) floating-point elements and store the results in single vector
86/// dst using writemask k (elements are copied from src when the corresponding mask bit is not set).
87/// [Intel's documentation](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#expand=1769,1651,1654&avx512techs=AVX512_BF16&text=_mm256_mask_cvtne2ps_pbh)
88#[inline]
89#[target_feature(enable = "avx512bf16,avx512vl")]
90#[unstable(feature = "stdarch_x86_avx512", issue = "111137")]
91#[cfg_attr(test, assert_instr("vcvtne2ps2bf16"))]
92pub unsafe fn _mm256_mask_cvtne2ps_pbh(
93 src: __m256bh,
94 k: __mmask16,
95 a: __m256,
96 b: __m256,
97) -> __m256bh {
98 let cvt: u16x16 = _mm256_cvtne2ps_pbh(a, b).as_u16x16();
99 transmute(src:simd_select_bitmask(m:k, yes:cvt, no:src.as_u16x16()))
100}
101
102/// Convert packed single-precision (32-bit) floating-point elements in two vectors a and b
103/// to packed BF16 (16-bit) floating-point elements, and store the results in single vector
104/// dst using zeromask k (elements are zeroed out when the corresponding mask bit is not set).
105/// [Intel's documentation](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#expand=1769,1651,1654&avx512techs=AVX512_BF16&text=_mm256_maskz_cvtne2ps_pbh)
106#[inline]
107#[target_feature(enable = "avx512bf16,avx512vl")]
108#[unstable(feature = "stdarch_x86_avx512", issue = "111137")]
109#[cfg_attr(test, assert_instr("vcvtne2ps2bf16"))]
110pub unsafe fn _mm256_maskz_cvtne2ps_pbh(k: __mmask16, a: __m256, b: __m256) -> __m256bh {
111 let cvt: u16x16 = _mm256_cvtne2ps_pbh(a, b).as_u16x16();
112 let zero: u16x16 = _mm256_setzero_si256().as_u16x16();
113 transmute(src:simd_select_bitmask(m:k, yes:cvt, no:zero))
114}
115
116/// Convert packed single-precision (32-bit) floating-point elements in two 512-bit vectors
117/// a and b to packed BF16 (16-bit) floating-point elements, and store the results in a
118/// 512-bit wide vector.
119/// [Intel's documentation](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#expand=1769,1651,1654,1657&avx512techs=AVX512_BF16&text=_mm512_cvtne2ps_pbh)
120#[inline]
121#[target_feature(enable = "avx512bf16,avx512f")]
122#[unstable(feature = "stdarch_x86_avx512", issue = "111137")]
123#[cfg_attr(test, assert_instr("vcvtne2ps2bf16"))]
124pub unsafe fn _mm512_cvtne2ps_pbh(a: __m512, b: __m512) -> __m512bh {
125 transmute(src:cvtne2ps2bf16_512(a:a.as_f32x16(), b:b.as_f32x16()))
126}
127
128/// Convert packed single-precision (32-bit) floating-point elements in two vectors
129/// a and b to packed BF16 (16-bit) floating-point elements, and store the results
130/// in single vector dst using writemask k (elements are copied from src when the
131/// corresponding mask bit is not set).
132/// [Intel's documentation](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#expand=1769,1651,1654,1657&avx512techs=AVX512_BF16&text=_mm512_mask_cvtne2ps_pbh)
133#[inline]
134#[target_feature(enable = "avx512bf16,avx512f")]
135#[unstable(feature = "stdarch_x86_avx512", issue = "111137")]
136#[cfg_attr(test, assert_instr("vcvtne2ps2bf16"))]
137pub unsafe fn _mm512_mask_cvtne2ps_pbh(
138 src: __m512bh,
139 k: __mmask32,
140 a: __m512,
141 b: __m512,
142) -> __m512bh {
143 let cvt: u16x32 = _mm512_cvtne2ps_pbh(a, b).as_u16x32();
144 transmute(src:simd_select_bitmask(m:k, yes:cvt, no:src.as_u16x32()))
145}
146
147/// Convert packed single-precision (32-bit) floating-point elements in two vectors
148/// a and b to packed BF16 (16-bit) floating-point elements, and store the results
149/// in single vector dst using zeromask k (elements are zeroed out when the corresponding
150/// mask bit is not set).
151/// [Intel's documentation](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#expand=1769,1651,1654,1657&avx512techs=AVX512_BF16&text=_mm512_maskz_cvtne2ps_pbh)
152#[inline]
153#[target_feature(enable = "avx512bf16,avx512f")]
154#[unstable(feature = "stdarch_x86_avx512", issue = "111137")]
155#[cfg_attr(test, assert_instr("vcvtne2ps2bf16"))]
156pub unsafe fn _mm512_maskz_cvtne2ps_pbh(k: __mmask32, a: __m512, b: __m512) -> __m512bh {
157 let cvt: u16x32 = _mm512_cvtne2ps_pbh(a, b).as_u16x32();
158 let zero: u16x32 = _mm512_setzero_si512().as_u16x32();
159 transmute(src:simd_select_bitmask(m:k, yes:cvt, no:zero))
160}
161
162/// Convert packed single-precision (32-bit) floating-point elements in a to packed BF16 (16-bit)
163/// floating-point elements, and store the results in dst.
164/// [Intel's documentation](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#expand=1769,1651,1654,1657,1660&avx512techs=AVX512_BF16&text=_mm256_cvtneps_pbh)
165#[inline]
166#[target_feature(enable = "avx512bf16,avx512vl")]
167#[unstable(feature = "stdarch_x86_avx512", issue = "111137")]
168#[cfg_attr(test, assert_instr("vcvtneps2bf16"))]
169pub unsafe fn _mm256_cvtneps_pbh(a: __m256) -> __m128bh {
170 transmute(src:cvtneps2bf16_256(a.as_f32x8()))
171}
172
173/// Convert packed single-precision (32-bit) floating-point elements in a to packed BF16 (16-bit)
174/// floating-point elements, and store the results in dst using writemask k
175/// (elements are copied from src when the corresponding mask bit is not set).
176/// [Intel's documentation](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#expand=1769,1651,1654,1657,1660&avx512techs=AVX512_BF16&text=_mm256_mask_cvtneps_pbh)
177#[inline]
178#[target_feature(enable = "avx512bf16,avx512vl")]
179#[unstable(feature = "stdarch_x86_avx512", issue = "111137")]
180#[cfg_attr(test, assert_instr("vcvtneps2bf16"))]
181pub unsafe fn _mm256_mask_cvtneps_pbh(src: __m128bh, k: __mmask8, a: __m256) -> __m128bh {
182 let cvt: u16x8 = _mm256_cvtneps_pbh(a).as_u16x8();
183 transmute(src:simd_select_bitmask(m:k, yes:cvt, no:src.as_u16x8()))
184}
185
186/// Convert packed single-precision (32-bit) floating-point elements in a to packed BF16 (16-bit)
187/// floating-point elements, and store the results in dst using zeromask k
188/// (elements are zeroed out when the corresponding mask bit is not set).
189/// [Intel's documentation](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#expand=1769,1651,1654,1657,1660&avx512techs=AVX512_BF16&text=_mm256_maskz_cvtneps_pbh)
190#[inline]
191#[target_feature(enable = "avx512bf16,avx512vl")]
192#[unstable(feature = "stdarch_x86_avx512", issue = "111137")]
193#[cfg_attr(test, assert_instr("vcvtneps2bf16"))]
194pub unsafe fn _mm256_maskz_cvtneps_pbh(k: __mmask8, a: __m256) -> __m128bh {
195 let cvt: u16x8 = _mm256_cvtneps_pbh(a).as_u16x8();
196 let zero: u16x8 = _mm_setzero_si128().as_u16x8();
197 transmute(src:simd_select_bitmask(m:k, yes:cvt, no:zero))
198}
199
200/// Convert packed single-precision (32-bit) floating-point elements in a to packed BF16 (16-bit)
201/// floating-point elements, and store the results in dst.
202/// [Intel's documentation](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#expand=1769,1651,1654,1657,1660&avx512techs=AVX512_BF16&text=_mm512_cvtneps_pbh)
203#[inline]
204#[target_feature(enable = "avx512bf16,avx512f")]
205#[unstable(feature = "stdarch_x86_avx512", issue = "111137")]
206#[cfg_attr(test, assert_instr("vcvtneps2bf16"))]
207pub unsafe fn _mm512_cvtneps_pbh(a: __m512) -> __m256bh {
208 transmute(src:cvtneps2bf16_512(a.as_f32x16()))
209}
210
211/// Convert packed single-precision (32-bit) floating-point elements in a to packed BF16 (16-bit)
212/// floating-point elements, and store the results in dst using writemask k
213/// (elements are copied from src when the corresponding mask bit is not set).
214/// [Intel's documentation](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#expand=1769,1651,1654,1657,1660&avx512techs=AVX512_BF16&text=_mm512_mask_cvtneps_pbh)
215#[inline]
216#[target_feature(enable = "avx512bf16,avx512f")]
217#[unstable(feature = "stdarch_x86_avx512", issue = "111137")]
218#[cfg_attr(test, assert_instr("vcvtneps2bf16"))]
219pub unsafe fn _mm512_mask_cvtneps_pbh(src: __m256bh, k: __mmask16, a: __m512) -> __m256bh {
220 let cvt: u16x16 = _mm512_cvtneps_pbh(a).as_u16x16();
221 transmute(src:simd_select_bitmask(m:k, yes:cvt, no:src.as_u16x16()))
222}
223
224/// Convert packed single-precision (32-bit) floating-point elements in a to packed BF16 (16-bit)
225/// floating-point elements, and store the results in dst using zeromask k
226/// (elements are zeroed out when the corresponding mask bit is not set).
227/// [Intel's documentation](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#expand=1769,1651,1654,1657,1660&avx512techs=AVX512_BF16&text=_mm512_maskz_cvtneps_pbh)
228#[inline]
229#[target_feature(enable = "avx512bf16,avx512f")]
230#[unstable(feature = "stdarch_x86_avx512", issue = "111137")]
231#[cfg_attr(test, assert_instr("vcvtneps2bf16"))]
232pub unsafe fn _mm512_maskz_cvtneps_pbh(k: __mmask16, a: __m512) -> __m256bh {
233 let cvt: u16x16 = _mm512_cvtneps_pbh(a).as_u16x16();
234 let zero: u16x16 = _mm256_setzero_si256().as_u16x16();
235 transmute(src:simd_select_bitmask(m:k, yes:cvt, no:zero))
236}
237
238/// Compute dot-product of BF16 (16-bit) floating-point pairs in a and b,
239/// accumulating the intermediate single-precision (32-bit) floating-point elements
240/// with elements in src, and store the results in dst.
241/// [Intel's documentation](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#expand=1769,1651,1654,1657,1660&avx512techs=AVX512_BF16&text=_mm_dpbf16_ps)
242#[inline]
243#[target_feature(enable = "avx512bf16,avx512vl")]
244#[unstable(feature = "stdarch_x86_avx512", issue = "111137")]
245#[cfg_attr(test, assert_instr("vdpbf16ps"))]
246pub unsafe fn _mm_dpbf16_ps(src: __m128, a: __m128bh, b: __m128bh) -> __m128 {
247 transmute(src:dpbf16ps(a:src.as_f32x4(), b:a.as_i32x4(), c:b.as_i32x4()))
248}
249
250/// Compute dot-product of BF16 (16-bit) floating-point pairs in a and b,
251/// accumulating the intermediate single-precision (32-bit) floating-point elements
252/// with elements in src, and store the results in dst using writemask k
253/// (elements are copied from src when the corresponding mask bit is not set).
254/// [Intel's documentation](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#expand=1769,1651,1654,1657,1660&avx512techs=AVX512_BF16&text=_mm_mask_dpbf16_ps)
255#[inline]
256#[target_feature(enable = "avx512bf16,avx512vl")]
257#[unstable(feature = "stdarch_x86_avx512", issue = "111137")]
258#[cfg_attr(test, assert_instr("vdpbf16ps"))]
259pub unsafe fn _mm_mask_dpbf16_ps(src: __m128, k: __mmask8, a: __m128bh, b: __m128bh) -> __m128 {
260 let rst: f32x4 = _mm_dpbf16_ps(src, a, b).as_f32x4();
261 transmute(src:simd_select_bitmask(m:k, yes:rst, no:src.as_f32x4()))
262}
263
264/// Compute dot-product of BF16 (16-bit) floating-point pairs in a and b,
265/// accumulating the intermediate single-precision (32-bit) floating-point elements
266/// with elements in src, and store the results in dst using zeromask k
267/// (elements are zeroed out when the corresponding mask bit is not set).
268/// [Intel's documentation](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#expand=1769,1651,1654,1657,1660&avx512techs=AVX512_BF16&text=_mm_maskz_dpbf16_ps)
269#[inline]
270#[target_feature(enable = "avx512bf16,avx512vl")]
271#[unstable(feature = "stdarch_x86_avx512", issue = "111137")]
272#[cfg_attr(test, assert_instr("vdpbf16ps"))]
273pub unsafe fn _mm_maskz_dpbf16_ps(k: __mmask8, src: __m128, a: __m128bh, b: __m128bh) -> __m128 {
274 let rst: f32x4 = _mm_dpbf16_ps(src, a, b).as_f32x4();
275 let zero: f32x4 = _mm_set1_ps(0.0_f32).as_f32x4();
276 transmute(src:simd_select_bitmask(m:k, yes:rst, no:zero))
277}
278
279/// Compute dot-product of BF16 (16-bit) floating-point pairs in a and b,
280/// accumulating the intermediate single-precision (32-bit) floating-point elements
281/// with elements in src, and store the results in dst.
282/// [Intel's documentation](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#expand=1769,1651,1654,1657,1660&avx512techs=AVX512_BF16&text=_mm256_dpbf16_ps)
283#[inline]
284#[target_feature(enable = "avx512bf16,avx512vl")]
285#[unstable(feature = "stdarch_x86_avx512", issue = "111137")]
286#[cfg_attr(test, assert_instr("vdpbf16ps"))]
287pub unsafe fn _mm256_dpbf16_ps(src: __m256, a: __m256bh, b: __m256bh) -> __m256 {
288 transmute(src:dpbf16ps_256(a:src.as_f32x8(), b:a.as_i32x8(), c:b.as_i32x8()))
289}
290
291/// Compute dot-product of BF16 (16-bit) floating-point pairs in a and b,
292/// accumulating the intermediate single-precision (32-bit) floating-point elements
293/// with elements in src, and store the results in dst using writemask k
294/// (elements are copied from src when the corresponding mask bit is not set).
295/// [Intel's documentation](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#expand=1769,1651,1654,1657,1660&avx512techs=AVX512_BF16&text=_mm256_mask_dpbf16_ps)
296#[inline]
297#[target_feature(enable = "avx512bf16,avx512vl")]
298#[unstable(feature = "stdarch_x86_avx512", issue = "111137")]
299#[cfg_attr(test, assert_instr("vdpbf16ps"))]
300pub unsafe fn _mm256_mask_dpbf16_ps(src: __m256, k: __mmask8, a: __m256bh, b: __m256bh) -> __m256 {
301 let rst: f32x8 = _mm256_dpbf16_ps(src, a, b).as_f32x8();
302 transmute(src:simd_select_bitmask(m:k, yes:rst, no:src.as_f32x8()))
303}
304
305/// Compute dot-product of BF16 (16-bit) floating-point pairs in a and b,
306/// accumulating the intermediate single-precision (32-bit) floating-point elements
307/// with elements in src, and store the results in dst using zeromask k
308/// (elements are zeroed out when the corresponding mask bit is not set).
309/// [Intel's documentation](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#expand=1769,1651,1654,1657,1660&avx512techs=AVX512_BF16&text=_mm256_maskz_dpbf16_ps)
310#[inline]
311#[target_feature(enable = "avx512bf16,avx512vl")]
312#[unstable(feature = "stdarch_x86_avx512", issue = "111137")]
313#[cfg_attr(test, assert_instr("vdpbf16ps"))]
314pub unsafe fn _mm256_maskz_dpbf16_ps(k: __mmask8, src: __m256, a: __m256bh, b: __m256bh) -> __m256 {
315 let rst: f32x8 = _mm256_dpbf16_ps(src, a, b).as_f32x8();
316 let zero: f32x8 = _mm256_setzero_ps().as_f32x8();
317 transmute(src:simd_select_bitmask(m:k, yes:rst, no:zero))
318}
319
320/// Compute dot-product of BF16 (16-bit) floating-point pairs in a and b,
321/// accumulating the intermediate single-precision (32-bit) floating-point elements
322/// with elements in src, and store the results in dst.Compute dot-product of BF16 (16-bit)
323/// floating-point pairs in a and b, accumulating the intermediate single-precision (32-bit)
324/// floating-point elements with elements in src, and store the results in dst.
325/// [Intel's documentation](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#expand=1769,1651,1654,1657,1660&avx512techs=AVX512_BF16&text=_mm512_dpbf16_ps)
326#[inline]
327#[target_feature(enable = "avx512bf16,avx512f")]
328#[unstable(feature = "stdarch_x86_avx512", issue = "111137")]
329#[cfg_attr(test, assert_instr("vdpbf16ps"))]
330pub unsafe fn _mm512_dpbf16_ps(src: __m512, a: __m512bh, b: __m512bh) -> __m512 {
331 transmute(src:dpbf16ps_512(a:src.as_f32x16(), b:a.as_i32x16(), c:b.as_i32x16()))
332}
333
334/// Compute dot-product of BF16 (16-bit) floating-point pairs in a and b,
335/// accumulating the intermediate single-precision (32-bit) floating-point elements
336/// with elements in src, and store the results in dst using writemask k
337/// (elements are copied from src when the corresponding mask bit is not set).
338/// [Intel's documentation](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#expand=1769,1651,1654,1657,1660&avx512techs=AVX512_BF16&text=_mm512_mask_dpbf16_ps)
339#[inline]
340#[target_feature(enable = "avx512bf16,avx512f")]
341#[unstable(feature = "stdarch_x86_avx512", issue = "111137")]
342#[cfg_attr(test, assert_instr("vdpbf16ps"))]
343pub unsafe fn _mm512_mask_dpbf16_ps(src: __m512, k: __mmask16, a: __m512bh, b: __m512bh) -> __m512 {
344 let rst: f32x16 = _mm512_dpbf16_ps(src, a, b).as_f32x16();
345 transmute(src:simd_select_bitmask(m:k, yes:rst, no:src.as_f32x16()))
346}
347
348/// Compute dot-product of BF16 (16-bit) floating-point pairs in a and b,
349/// accumulating the intermediate single-precision (32-bit) floating-point elements
350/// with elements in src, and store the results in dst using zeromask k
351/// (elements are zeroed out when the corresponding mask bit is not set).
352/// [Intel's documentation](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#expand=1769,1651,1654,1657,1660&avx512techs=AVX512_BF16&text=_mm512_maskz_dpbf16_ps)
353#[inline]
354#[target_feature(enable = "avx512bf16,avx512f")]
355#[unstable(feature = "stdarch_x86_avx512", issue = "111137")]
356#[cfg_attr(test, assert_instr("vdpbf16ps"))]
357pub unsafe fn _mm512_maskz_dpbf16_ps(
358 k: __mmask16,
359 src: __m512,
360 a: __m512bh,
361 b: __m512bh,
362) -> __m512 {
363 let rst: f32x16 = _mm512_dpbf16_ps(src, a, b).as_f32x16();
364 let zero: f32x16 = _mm512_setzero_ps().as_f32x16();
365 transmute(src:simd_select_bitmask(m:k, yes:rst, no:zero))
366}
367
368#[cfg(test)]
369mod tests {
370 use crate::{core_arch::x86::*, mem::transmute};
371 use stdarch_test::simd_test;
372
373 #[simd_test(enable = "avx512bf16,avx512vl")]
374 unsafe fn test_mm_cvtne2ps_pbh() {
375 let a_array = [178.125_f32, 10.5_f32, 3.75_f32, 50.25_f32];
376 let b_array = [-178.125_f32, -10.5_f32, -3.75_f32, -50.25_f32];
377 let a: __m128 = transmute(a_array);
378 let b: __m128 = transmute(b_array);
379 let c: __m128bh = _mm_cvtne2ps_pbh(a, b);
380 let result: [u16; 8] = transmute(c.as_u16x8());
381 #[rustfmt::skip]
382 let expected_result: [u16; 8] = [
383 0b1_10000110_0110010,
384 0b1_10000010_0101000,
385 0b1_10000000_1110000,
386 0b1_10000100_1001001,
387 0b0_10000110_0110010,
388 0b0_10000010_0101000,
389 0b0_10000000_1110000,
390 0b0_10000100_1001001,
391 ];
392 assert_eq!(result, expected_result);
393 }
394
395 #[simd_test(enable = "avx512bf16,avx512vl")]
396 unsafe fn test_mm_mask_cvtne2ps_pbh() {
397 let a_array = [178.125_f32, 10.5_f32, 3.75_f32, 50.25_f32];
398 let b_array = [-178.125_f32, -10.5_f32, -3.75_f32, -50.25_f32];
399 #[rustfmt::skip]
400 let src_array: [u16; 8] = [
401 0b0_10000110_0110010,
402 0b0_10000010_0101000,
403 0b0_10000000_1110000,
404 0b0_10000100_1001001,
405 0b0_10000110_0110010,
406 0b0_10000010_0101000,
407 0b0_10000000_1110000,
408 0b0_10000100_1001001,
409 ];
410 let src: __m128bh = transmute(src_array);
411 let a: __m128 = transmute(a_array);
412 let b: __m128 = transmute(b_array);
413 let k: __mmask8 = 0b1111_1111;
414 let c: __m128bh = _mm_mask_cvtne2ps_pbh(src, k, a, b);
415 let result: [u16; 8] = transmute(c.as_u16x8());
416 #[rustfmt::skip]
417 let expected_result: [u16; 8] = [
418 0b1_10000110_0110010,
419 0b1_10000010_0101000,
420 0b1_10000000_1110000,
421 0b1_10000100_1001001,
422 0b0_10000110_0110010,
423 0b0_10000010_0101000,
424 0b0_10000000_1110000,
425 0b0_10000100_1001001,
426 ];
427 assert_eq!(result, expected_result);
428 let k = 0b0000_0000;
429 let c = _mm_mask_cvtne2ps_pbh(src, k, a, b);
430 let result: [u16; 8] = transmute(c.as_u16x8());
431 let expected_result = src_array;
432 assert_eq!(result, expected_result);
433 }
434
435 #[simd_test(enable = "avx512bf16,avx512vl")]
436 unsafe fn test_mm_maskz_cvtne2ps_pbh() {
437 let a_array = [178.125_f32, 10.5_f32, 3.75_f32, 50.25_f32];
438 let b_array = [-178.125_f32, -10.5_f32, -3.75_f32, -50.25_f32];
439 let a: __m128 = transmute(a_array);
440 let b: __m128 = transmute(b_array);
441 let k: __mmask8 = 0b1111_1111;
442 let c: __m128bh = _mm_maskz_cvtne2ps_pbh(k, a, b);
443 let result: [u16; 8] = transmute(c.as_u16x8());
444 #[rustfmt::skip]
445 let expected_result: [u16; 8] = [
446 0b1_10000110_0110010,
447 0b1_10000010_0101000,
448 0b1_10000000_1110000,
449 0b1_10000100_1001001,
450 0b0_10000110_0110010,
451 0b0_10000010_0101000,
452 0b0_10000000_1110000,
453 0b0_10000100_1001001,
454 ];
455 assert_eq!(result, expected_result);
456 let k = 0b0011_1100;
457 let c = _mm_maskz_cvtne2ps_pbh(k, a, b);
458 let result: [u16; 8] = transmute(c.as_u16x8());
459 #[rustfmt::skip]
460 let expected_result: [u16; 8] = [
461 0,
462 0,
463 0b1_10000000_1110000,
464 0b1_10000100_1001001,
465 0b0_10000110_0110010,
466 0b0_10000010_0101000,
467 0,
468 0,
469 ];
470 assert_eq!(result, expected_result);
471 }
472
473 #[simd_test(enable = "avx512bf16,avx512vl")]
474 unsafe fn test_mm256_cvtne2ps_pbh() {
475 #[rustfmt::skip]
476 let a_array = [
477 178.125_f32,
478 10.5_f32,
479 3.75_f32,
480 50.25_f32,
481 16.5_f32,
482 255.11_f32,
483 1000.158_f32,
484 575.575_f32,
485 ];
486 let b_array = [
487 -178.125_f32,
488 -10.5_f32,
489 -3.75_f32,
490 -50.25_f32,
491 -16.5_f32,
492 -255.11_f32,
493 -1000.158_f32,
494 -575.575_f32,
495 ];
496 let a: __m256 = transmute(a_array);
497 let b: __m256 = transmute(b_array);
498 let c: __m256bh = _mm256_cvtne2ps_pbh(a, b);
499 let result: [u16; 16] = transmute(c.as_u16x16());
500 #[rustfmt::skip]
501 let expected_result: [u16; 16] = [
502 0b1_10000110_0110010,
503 0b1_10000010_0101000,
504 0b1_10000000_1110000,
505 0b1_10000100_1001001,
506 0b1_10000011_0000100,
507 0b1_10000110_1111111,
508 0b1_10001000_1111010,
509 0b1_10001000_0010000,
510 0b0_10000110_0110010,
511 0b0_10000010_0101000,
512 0b0_10000000_1110000,
513 0b0_10000100_1001001,
514 0b0_10000011_0000100,
515 0b0_10000110_1111111,
516 0b0_10001000_1111010,
517 0b0_10001000_0010000,
518 ];
519 assert_eq!(result, expected_result);
520 }
521
522 #[simd_test(enable = "avx512bf16,avx512vl")]
523 unsafe fn test_mm256_mask_cvtne2ps_pbh() {
524 #[rustfmt::skip]
525 let a_array = [
526 178.125_f32,
527 10.5_f32,
528 3.75_f32,
529 50.25_f32,
530 16.5_f32,
531 255.11_f32,
532 1000.158_f32,
533 575.575_f32,
534 ];
535 let b_array = [
536 -178.125_f32,
537 -10.5_f32,
538 -3.75_f32,
539 -50.25_f32,
540 -16.5_f32,
541 -255.11_f32,
542 -1000.158_f32,
543 -575.575_f32,
544 ];
545 let src_array: [u16; 16] = [
546 0b0_10000110_0110010,
547 0b0_10000010_0101000,
548 0b0_10000000_1110000,
549 0b0_10000100_1001001,
550 0b0_10000110_0110010,
551 0b0_10000010_0101000,
552 0b0_10000000_1110000,
553 0b0_10000100_1001001,
554 0b0_10000110_0110010,
555 0b0_10000010_0101000,
556 0b0_10000000_1110000,
557 0b0_10000100_1001001,
558 0b0_10000110_0110010,
559 0b0_10000010_0101000,
560 0b0_10000000_1110000,
561 0b0_10000100_1001001,
562 ];
563 let src: __m256bh = transmute(src_array);
564 let a: __m256 = transmute(a_array);
565 let b: __m256 = transmute(b_array);
566 let k: __mmask16 = 0xffff;
567 let c: __m256bh = _mm256_mask_cvtne2ps_pbh(src, k, a, b);
568 let result: [u16; 16] = transmute(c.as_u16x16());
569 #[rustfmt::skip]
570 let expected_result: [u16; 16] = [
571 0b1_10000110_0110010,
572 0b1_10000010_0101000,
573 0b1_10000000_1110000,
574 0b1_10000100_1001001,
575 0b1_10000011_0000100,
576 0b1_10000110_1111111,
577 0b1_10001000_1111010,
578 0b1_10001000_0010000,
579 0b0_10000110_0110010,
580 0b0_10000010_0101000,
581 0b0_10000000_1110000,
582 0b0_10000100_1001001,
583 0b0_10000011_0000100,
584 0b0_10000110_1111111,
585 0b0_10001000_1111010,
586 0b0_10001000_0010000,
587 ];
588 assert_eq!(result, expected_result);
589 let k: __mmask16 = 0;
590 let c: __m256bh = _mm256_mask_cvtne2ps_pbh(src, k, a, b);
591 let result: [u16; 16] = transmute(c.as_u16x16());
592 let expected_result = src_array;
593 assert_eq!(result, expected_result);
594 }
595
596 #[simd_test(enable = "avx512bf16,avx512vl")]
597 unsafe fn test_mm256_maskz_cvtne2ps_pbh() {
598 #[rustfmt::skip]
599 let a_array = [
600 178.125_f32,
601 10.5_f32,
602 3.75_f32,
603 50.25_f32,
604 16.5_f32,
605 255.11_f32,
606 1000.158_f32,
607 575.575_f32,
608 ];
609 let b_array = [
610 -178.125_f32,
611 -10.5_f32,
612 -3.75_f32,
613 -50.25_f32,
614 -16.5_f32,
615 -255.11_f32,
616 -1000.158_f32,
617 -575.575_f32,
618 ];
619 let a: __m256 = transmute(a_array);
620 let b: __m256 = transmute(b_array);
621 let k: __mmask16 = 0xffff;
622 let c: __m256bh = _mm256_maskz_cvtne2ps_pbh(k, a, b);
623 let result: [u16; 16] = transmute(c.as_u16x16());
624 #[rustfmt::skip]
625 let expected_result: [u16; 16] = [
626 0b1_10000110_0110010,
627 0b1_10000010_0101000,
628 0b1_10000000_1110000,
629 0b1_10000100_1001001,
630 0b1_10000011_0000100,
631 0b1_10000110_1111111,
632 0b1_10001000_1111010,
633 0b1_10001000_0010000,
634 0b0_10000110_0110010,
635 0b0_10000010_0101000,
636 0b0_10000000_1110000,
637 0b0_10000100_1001001,
638 0b0_10000011_0000100,
639 0b0_10000110_1111111,
640 0b0_10001000_1111010,
641 0b0_10001000_0010000,
642 ];
643 assert_eq!(result, expected_result);
644 let k: __mmask16 = 0b0110_1100_0011_0110;
645 let c: __m256bh = _mm256_maskz_cvtne2ps_pbh(k, a, b);
646 let result: [u16; 16] = transmute(c.as_u16x16());
647 #[rustfmt::skip]
648 let expected_result: [u16; 16] = [
649 0,
650 0b1_10000010_0101000,
651 0b1_10000000_1110000,
652 0,
653 0b1_10000011_0000100,
654 0b1_10000110_1111111,
655 0,
656 0,
657 0,
658 0,
659 0b0_10000000_1110000,
660 0b0_10000100_1001001,
661 0,
662 0b0_10000110_1111111,
663 0b0_10001000_1111010,
664 0,
665 ];
666 assert_eq!(result, expected_result);
667 }
668
669 #[simd_test(enable = "avx512bf16,avx512f")]
670 unsafe fn test_mm512_cvtne2ps_pbh() {
671 #[rustfmt::skip]
672 let a_array = [
673 178.125_f32,
674 10.5_f32,
675 3.75_f32,
676 50.25_f32,
677 16.5_f32,
678 255.11_f32,
679 1000.158_f32,
680 575.575_f32,
681 178.125_f32,
682 10.5_f32,
683 3.75_f32,
684 50.25_f32,
685 16.5_f32,
686 255.11_f32,
687 1000.158_f32,
688 575.575_f32,
689 ];
690 let b_array = [
691 -178.125_f32,
692 -10.5_f32,
693 -3.75_f32,
694 -50.25_f32,
695 -16.5_f32,
696 -255.11_f32,
697 -1000.158_f32,
698 -575.575_f32,
699 -178.125_f32,
700 -10.5_f32,
701 -3.75_f32,
702 -50.25_f32,
703 -16.5_f32,
704 -255.11_f32,
705 -1000.158_f32,
706 -575.575_f32,
707 ];
708 let a: __m512 = transmute(a_array);
709 let b: __m512 = transmute(b_array);
710 let c: __m512bh = _mm512_cvtne2ps_pbh(a, b);
711 let result: [u16; 32] = transmute(c.as_u16x32());
712 #[rustfmt::skip]
713 let expected_result: [u16; 32] = [
714 0b1_10000110_0110010,
715 0b1_10000010_0101000,
716 0b1_10000000_1110000,
717 0b1_10000100_1001001,
718 0b1_10000011_0000100,
719 0b1_10000110_1111111,
720 0b1_10001000_1111010,
721 0b1_10001000_0010000,
722 0b1_10000110_0110010,
723 0b1_10000010_0101000,
724 0b1_10000000_1110000,
725 0b1_10000100_1001001,
726 0b1_10000011_0000100,
727 0b1_10000110_1111111,
728 0b1_10001000_1111010,
729 0b1_10001000_0010000,
730 0b0_10000110_0110010,
731 0b0_10000010_0101000,
732 0b0_10000000_1110000,
733 0b0_10000100_1001001,
734 0b0_10000011_0000100,
735 0b0_10000110_1111111,
736 0b0_10001000_1111010,
737 0b0_10001000_0010000,
738 0b0_10000110_0110010,
739 0b0_10000010_0101000,
740 0b0_10000000_1110000,
741 0b0_10000100_1001001,
742 0b0_10000011_0000100,
743 0b0_10000110_1111111,
744 0b0_10001000_1111010,
745 0b0_10001000_0010000,
746 ];
747 assert_eq!(result, expected_result);
748 }
749
750 #[simd_test(enable = "avx512bf16,avx512f")]
751 unsafe fn test_mm512_mask_cvtne2ps_pbh() {
752 #[rustfmt::skip]
753 let a_array = [
754 178.125_f32,
755 10.5_f32,
756 3.75_f32,
757 50.25_f32,
758 16.5_f32,
759 255.11_f32,
760 1000.158_f32,
761 575.575_f32,
762 178.125_f32,
763 10.5_f32,
764 3.75_f32,
765 50.25_f32,
766 16.5_f32,
767 255.11_f32,
768 1000.158_f32,
769 575.575_f32,
770 ];
771 let b_array = [
772 -178.125_f32,
773 -10.5_f32,
774 -3.75_f32,
775 -50.25_f32,
776 -16.5_f32,
777 -255.11_f32,
778 -1000.158_f32,
779 -575.575_f32,
780 -178.125_f32,
781 -10.5_f32,
782 -3.75_f32,
783 -50.25_f32,
784 -16.5_f32,
785 -255.11_f32,
786 -1000.158_f32,
787 -575.575_f32,
788 ];
789 let src_array: [u16; 32] = [
790 0b0_10000110_0110010,
791 0b0_10000010_0101000,
792 0b0_10000000_1110000,
793 0b0_10000100_1001001,
794 0b0_10000110_0110010,
795 0b0_10000010_0101000,
796 0b0_10000000_1110000,
797 0b0_10000100_1001001,
798 0b0_10000110_0110010,
799 0b0_10000010_0101000,
800 0b0_10000000_1110000,
801 0b0_10000100_1001001,
802 0b0_10000110_0110010,
803 0b0_10000010_0101000,
804 0b0_10000000_1110000,
805 0b0_10000100_1001001,
806 0b0_10000110_0110010,
807 0b0_10000010_0101000,
808 0b0_10000000_1110000,
809 0b0_10000100_1001001,
810 0b0_10000110_0110010,
811 0b0_10000010_0101000,
812 0b0_10000000_1110000,
813 0b0_10000100_1001001,
814 0b0_10000110_0110010,
815 0b0_10000010_0101000,
816 0b0_10000000_1110000,
817 0b0_10000100_1001001,
818 0b0_10000110_0110010,
819 0b0_10000010_0101000,
820 0b0_10000000_1110000,
821 0b0_10000100_1001001,
822 ];
823 let src: __m512bh = transmute(src_array);
824 let a: __m512 = transmute(a_array);
825 let b: __m512 = transmute(b_array);
826 let k: __mmask32 = 0xffffffff;
827 let c: __m512bh = _mm512_mask_cvtne2ps_pbh(src, k, a, b);
828 let result: [u16; 32] = transmute(c.as_u16x32());
829 #[rustfmt::skip]
830 let expected_result: [u16; 32] = [
831 0b1_10000110_0110010,
832 0b1_10000010_0101000,
833 0b1_10000000_1110000,
834 0b1_10000100_1001001,
835 0b1_10000011_0000100,
836 0b1_10000110_1111111,
837 0b1_10001000_1111010,
838 0b1_10001000_0010000,
839 0b1_10000110_0110010,
840 0b1_10000010_0101000,
841 0b1_10000000_1110000,
842 0b1_10000100_1001001,
843 0b1_10000011_0000100,
844 0b1_10000110_1111111,
845 0b1_10001000_1111010,
846 0b1_10001000_0010000,
847 0b0_10000110_0110010,
848 0b0_10000010_0101000,
849 0b0_10000000_1110000,
850 0b0_10000100_1001001,
851 0b0_10000011_0000100,
852 0b0_10000110_1111111,
853 0b0_10001000_1111010,
854 0b0_10001000_0010000,
855 0b0_10000110_0110010,
856 0b0_10000010_0101000,
857 0b0_10000000_1110000,
858 0b0_10000100_1001001,
859 0b0_10000011_0000100,
860 0b0_10000110_1111111,
861 0b0_10001000_1111010,
862 0b0_10001000_0010000,
863 ];
864 assert_eq!(result, expected_result);
865 let k: __mmask32 = 0;
866 let c: __m512bh = _mm512_mask_cvtne2ps_pbh(src, k, a, b);
867 let result: [u16; 32] = transmute(c.as_u16x32());
868 let expected_result = src_array;
869 assert_eq!(result, expected_result);
870 }
871
872 #[simd_test(enable = "avx512bf16,avx512f")]
873 unsafe fn test_mm512_maskz_cvtne2ps_pbh() {
874 #[rustfmt::skip]
875 let a_array = [
876 178.125_f32,
877 10.5_f32,
878 3.75_f32,
879 50.25_f32,
880 16.5_f32,
881 255.11_f32,
882 1000.158_f32,
883 575.575_f32,
884 178.125_f32,
885 10.5_f32,
886 3.75_f32,
887 50.25_f32,
888 16.5_f32,
889 255.11_f32,
890 1000.158_f32,
891 575.575_f32,
892 ];
893 let b_array = [
894 -178.125_f32,
895 -10.5_f32,
896 -3.75_f32,
897 -50.25_f32,
898 -16.5_f32,
899 -255.11_f32,
900 -1000.158_f32,
901 -575.575_f32,
902 -178.125_f32,
903 -10.5_f32,
904 -3.75_f32,
905 -50.25_f32,
906 -16.5_f32,
907 -255.11_f32,
908 -1000.158_f32,
909 -575.575_f32,
910 ];
911 let a: __m512 = transmute(a_array);
912 let b: __m512 = transmute(b_array);
913 let k: __mmask32 = 0xffffffff;
914 let c: __m512bh = _mm512_maskz_cvtne2ps_pbh(k, a, b);
915 let result: [u16; 32] = transmute(c.as_u16x32());
916 #[rustfmt::skip]
917 let expected_result: [u16; 32] = [
918 0b1_10000110_0110010,
919 0b1_10000010_0101000,
920 0b1_10000000_1110000,
921 0b1_10000100_1001001,
922 0b1_10000011_0000100,
923 0b1_10000110_1111111,
924 0b1_10001000_1111010,
925 0b1_10001000_0010000,
926 0b1_10000110_0110010,
927 0b1_10000010_0101000,
928 0b1_10000000_1110000,
929 0b1_10000100_1001001,
930 0b1_10000011_0000100,
931 0b1_10000110_1111111,
932 0b1_10001000_1111010,
933 0b1_10001000_0010000,
934 0b0_10000110_0110010,
935 0b0_10000010_0101000,
936 0b0_10000000_1110000,
937 0b0_10000100_1001001,
938 0b0_10000011_0000100,
939 0b0_10000110_1111111,
940 0b0_10001000_1111010,
941 0b0_10001000_0010000,
942 0b0_10000110_0110010,
943 0b0_10000010_0101000,
944 0b0_10000000_1110000,
945 0b0_10000100_1001001,
946 0b0_10000011_0000100,
947 0b0_10000110_1111111,
948 0b0_10001000_1111010,
949 0b0_10001000_0010000,
950 ];
951 assert_eq!(result, expected_result);
952 let k: __mmask32 = 0b1100_1010_1001_0110_1010_0011_0101_0110;
953 let c: __m512bh = _mm512_maskz_cvtne2ps_pbh(k, a, b);
954 let result: [u16; 32] = transmute(c.as_u16x32());
955 #[rustfmt::skip]
956 let expected_result: [u16; 32] = [
957 0,
958 0b1_10000010_0101000,
959 0b1_10000000_1110000,
960 0,
961 0b1_10000011_0000100,
962 0,
963 0b1_10001000_1111010,
964 0,
965 0b1_10000110_0110010,
966 0b1_10000010_0101000,
967 0,
968 0,
969 0,
970 0b1_10000110_1111111,
971 0,
972 0b1_10001000_0010000,
973 0,
974 0b0_10000010_0101000,
975 0b0_10000000_1110000,
976 0,
977 0b0_10000011_0000100,
978 0,
979 0,
980 0b0_10001000_0010000,
981 0,
982 0b0_10000010_0101000,
983 0,
984 0b0_10000100_1001001,
985 0,
986 0,
987 0b0_10001000_1111010,
988 0b0_10001000_0010000,
989 ];
990 assert_eq!(result, expected_result);
991 }
992
993 #[simd_test(enable = "avx512bf16,avx512vl")]
994 unsafe fn test_mm256_cvtneps_pbh() {
995 #[rustfmt::skip]
996 let a_array = [
997 178.125_f32,
998 10.5_f32,
999 3.75_f32,
1000 50.25_f32,
1001 16.5_f32,
1002 255.11_f32,
1003 1000.158_f32,
1004 575.575_f32,
1005 ];
1006 let a: __m256 = transmute(a_array);
1007 let c: __m128bh = _mm256_cvtneps_pbh(a);
1008 let result: [u16; 8] = transmute(c.as_u16x8());
1009 #[rustfmt::skip]
1010 let expected_result: [u16; 8] = [
1011 0b0_10000110_0110010,
1012 0b0_10000010_0101000,
1013 0b0_10000000_1110000,
1014 0b0_10000100_1001001,
1015 0b0_10000011_0000100,
1016 0b0_10000110_1111111,
1017 0b0_10001000_1111010,
1018 0b0_10001000_0010000,
1019 ];
1020 assert_eq!(result, expected_result);
1021 }
1022
1023 #[simd_test(enable = "avx512bf16,avx512vl")]
1024 unsafe fn test_mm256_mask_cvtneps_pbh() {
1025 #[rustfmt::skip]
1026 let a_array = [
1027 178.125_f32,
1028 10.5_f32,
1029 3.75_f32,
1030 50.25_f32,
1031 16.5_f32,
1032 255.11_f32,
1033 1000.158_f32,
1034 575.575_f32,
1035 ];
1036 let src_array: [u16; 8] = [
1037 0b1_10000110_0110010,
1038 0b1_10000010_0101000,
1039 0b1_10000000_1110000,
1040 0b1_10000100_1001001,
1041 0b1_10000011_0000100,
1042 0b1_10000110_1111111,
1043 0b1_10001000_1111010,
1044 0b1_10001000_0010000,
1045 ];
1046 let src: __m128bh = transmute(src_array);
1047 let a: __m256 = transmute(a_array);
1048 let k: __mmask8 = 0xff;
1049 let b = _mm256_mask_cvtneps_pbh(src, k, a);
1050 let result: [u16; 8] = transmute(b.as_u16x8());
1051 #[rustfmt::skip]
1052 let expected_result: [u16; 8] = [
1053 0b0_10000110_0110010,
1054 0b0_10000010_0101000,
1055 0b0_10000000_1110000,
1056 0b0_10000100_1001001,
1057 0b0_10000011_0000100,
1058 0b0_10000110_1111111,
1059 0b0_10001000_1111010,
1060 0b0_10001000_0010000,
1061 ];
1062 assert_eq!(result, expected_result);
1063 let k: __mmask8 = 0x0;
1064 let b: __m128bh = _mm256_mask_cvtneps_pbh(src, k, a);
1065 let result: [u16; 8] = transmute(b.as_u16x8());
1066 let expected_result: [u16; 8] = src_array;
1067 assert_eq!(result, expected_result);
1068 }
1069
1070 #[simd_test(enable = "avx512bf16,avx512vl")]
1071 unsafe fn test_mm256_maskz_cvtneps_pbh() {
1072 #[rustfmt::skip]
1073 let a_array = [
1074 178.125_f32,
1075 10.5_f32,
1076 3.75_f32,
1077 50.25_f32,
1078 16.5_f32,
1079 255.11_f32,
1080 1000.158_f32,
1081 575.575_f32,
1082 ];
1083 let a: __m256 = transmute(a_array);
1084 let k: __mmask8 = 0xff;
1085 let b = _mm256_maskz_cvtneps_pbh(k, a);
1086 let result: [u16; 8] = transmute(b.as_u16x8());
1087 #[rustfmt::skip]
1088 let expected_result: [u16; 8] = [
1089 0b0_10000110_0110010,
1090 0b0_10000010_0101000,
1091 0b0_10000000_1110000,
1092 0b0_10000100_1001001,
1093 0b0_10000011_0000100,
1094 0b0_10000110_1111111,
1095 0b0_10001000_1111010,
1096 0b0_10001000_0010000,
1097 ];
1098 assert_eq!(result, expected_result);
1099 let k: __mmask8 = 0x6;
1100 let b: __m128bh = _mm256_maskz_cvtneps_pbh(k, a);
1101 let result: [u16; 8] = transmute(b.as_u16x8());
1102 let expected_result: [u16; 8] =
1103 [0, 0b0_10000010_0101000, 0b0_10000000_1110000, 0, 0, 0, 0, 0];
1104 assert_eq!(result, expected_result);
1105 }
1106
1107 #[simd_test(enable = "avx512bf16,avx512f")]
1108 unsafe fn test_mm512_cvtneps_pbh() {
1109 #[rustfmt::skip]
1110 let a_array = [
1111 178.125_f32,
1112 10.5_f32,
1113 3.75_f32,
1114 50.25_f32,
1115 16.5_f32,
1116 255.11_f32,
1117 1000.158_f32,
1118 575.575_f32,
1119 178.125_f32,
1120 10.5_f32,
1121 3.75_f32,
1122 50.25_f32,
1123 16.5_f32,
1124 255.11_f32,
1125 1000.158_f32,
1126 575.575_f32,
1127 ];
1128 let a: __m512 = transmute(a_array);
1129 let c: __m256bh = _mm512_cvtneps_pbh(a);
1130 let result: [u16; 16] = transmute(c.as_u16x16());
1131 #[rustfmt::skip]
1132 let expected_result: [u16; 16] = [
1133 0b0_10000110_0110010,
1134 0b0_10000010_0101000,
1135 0b0_10000000_1110000,
1136 0b0_10000100_1001001,
1137 0b0_10000011_0000100,
1138 0b0_10000110_1111111,
1139 0b0_10001000_1111010,
1140 0b0_10001000_0010000,
1141 0b0_10000110_0110010,
1142 0b0_10000010_0101000,
1143 0b0_10000000_1110000,
1144 0b0_10000100_1001001,
1145 0b0_10000011_0000100,
1146 0b0_10000110_1111111,
1147 0b0_10001000_1111010,
1148 0b0_10001000_0010000,
1149 ];
1150 assert_eq!(result, expected_result);
1151 }
1152
1153 #[simd_test(enable = "avx512bf16,avx512f")]
1154 unsafe fn test_mm512_mask_cvtneps_pbh() {
1155 #[rustfmt::skip]
1156 let a_array = [
1157 178.125_f32,
1158 10.5_f32,
1159 3.75_f32,
1160 50.25_f32,
1161 16.5_f32,
1162 255.11_f32,
1163 1000.158_f32,
1164 575.575_f32,
1165 178.125_f32,
1166 10.5_f32,
1167 3.75_f32,
1168 50.25_f32,
1169 16.5_f32,
1170 255.11_f32,
1171 1000.158_f32,
1172 575.575_f32,
1173 ];
1174 let src_array: [u16; 16] = [
1175 0b1_10000110_0110010,
1176 0b1_10000010_0101000,
1177 0b1_10000000_1110000,
1178 0b1_10000100_1001001,
1179 0b1_10000011_0000100,
1180 0b1_10000110_1111111,
1181 0b1_10001000_1111010,
1182 0b1_10001000_0010000,
1183 0b1_10000110_0110010,
1184 0b1_10000010_0101000,
1185 0b1_10000000_1110000,
1186 0b1_10000100_1001001,
1187 0b1_10000011_0000100,
1188 0b1_10000110_1111111,
1189 0b1_10001000_1111010,
1190 0b1_10001000_0010000,
1191 ];
1192 let src: __m256bh = transmute(src_array);
1193 let a: __m512 = transmute(a_array);
1194 let k: __mmask16 = 0xffff;
1195 let c: __m256bh = _mm512_mask_cvtneps_pbh(src, k, a);
1196 let result: [u16; 16] = transmute(c.as_u16x16());
1197 #[rustfmt::skip]
1198 let expected_result: [u16; 16] = [
1199 0b0_10000110_0110010,
1200 0b0_10000010_0101000,
1201 0b0_10000000_1110000,
1202 0b0_10000100_1001001,
1203 0b0_10000011_0000100,
1204 0b0_10000110_1111111,
1205 0b0_10001000_1111010,
1206 0b0_10001000_0010000,
1207 0b0_10000110_0110010,
1208 0b0_10000010_0101000,
1209 0b0_10000000_1110000,
1210 0b0_10000100_1001001,
1211 0b0_10000011_0000100,
1212 0b0_10000110_1111111,
1213 0b0_10001000_1111010,
1214 0b0_10001000_0010000,
1215 ];
1216 assert_eq!(result, expected_result);
1217 let k: __mmask16 = 0;
1218 let c: __m256bh = _mm512_mask_cvtneps_pbh(src, k, a);
1219 let result: [u16; 16] = transmute(c.as_u16x16());
1220 let expected_result = src_array;
1221 assert_eq!(result, expected_result);
1222 }
1223
1224 #[simd_test(enable = "avx512bf16,avx512f")]
1225 unsafe fn test_mm512_maskz_cvtneps_pbh() {
1226 #[rustfmt::skip]
1227 let a_array = [
1228 178.125_f32,
1229 10.5_f32,
1230 3.75_f32,
1231 50.25_f32,
1232 16.5_f32,
1233 255.11_f32,
1234 1000.158_f32,
1235 575.575_f32,
1236 178.125_f32,
1237 10.5_f32,
1238 3.75_f32,
1239 50.25_f32,
1240 16.5_f32,
1241 255.11_f32,
1242 1000.158_f32,
1243 575.575_f32,
1244 ];
1245 let a: __m512 = transmute(a_array);
1246 let k: __mmask16 = 0xffff;
1247 let c: __m256bh = _mm512_maskz_cvtneps_pbh(k, a);
1248 let result: [u16; 16] = transmute(c.as_u16x16());
1249 #[rustfmt::skip]
1250 let expected_result: [u16; 16] = [
1251 0b0_10000110_0110010,
1252 0b0_10000010_0101000,
1253 0b0_10000000_1110000,
1254 0b0_10000100_1001001,
1255 0b0_10000011_0000100,
1256 0b0_10000110_1111111,
1257 0b0_10001000_1111010,
1258 0b0_10001000_0010000,
1259 0b0_10000110_0110010,
1260 0b0_10000010_0101000,
1261 0b0_10000000_1110000,
1262 0b0_10000100_1001001,
1263 0b0_10000011_0000100,
1264 0b0_10000110_1111111,
1265 0b0_10001000_1111010,
1266 0b0_10001000_0010000,
1267 ];
1268 assert_eq!(result, expected_result);
1269 let k: __mmask16 = 0x653a;
1270 let c: __m256bh = _mm512_maskz_cvtneps_pbh(k, a);
1271 let result: [u16; 16] = transmute(c.as_u16x16());
1272 #[rustfmt::skip]
1273 let expected_result: [u16; 16] = [
1274 0,
1275 0b0_10000010_0101000,
1276 0,
1277 0b0_10000100_1001001,
1278 0b0_10000011_0000100,
1279 0b0_10000110_1111111,
1280 0,
1281 0,
1282 0b0_10000110_0110010,
1283 0,
1284 0b0_10000000_1110000,
1285 0,
1286 0,
1287 0b0_10000110_1111111,
1288 0b0_10001000_1111010,
1289 0,
1290 ];
1291 assert_eq!(result, expected_result);
1292 }
1293
1294 #[simd_test(enable = "avx512bf16,avx512vl")]
1295 unsafe fn test_mm_dpbf16_ps() {
1296 let a_array = [8.5_f32, 10.5_f32, 3.75_f32, 50.25_f32];
1297 let b_array = [-1.0_f32, -1.0_f32, -1.0_f32, -1.0_f32];
1298 let a1: __m128 = transmute(a_array);
1299 let b1: __m128 = transmute(b_array);
1300 let src: __m128 = transmute([1.0_f32, 2.0_f32, 3.0_f32, 4.0_f32]);
1301 let a: __m128bh = _mm_cvtne2ps_pbh(a1, a1);
1302 let b: __m128bh = _mm_cvtne2ps_pbh(b1, b1);
1303 let c: __m128 = _mm_dpbf16_ps(src, a, b);
1304 let result: [f32; 4] = transmute(c.as_f32x4());
1305 let expected_result: [f32; 4] = [-18.0_f32, -52.0_f32, -16.0_f32, -50.0_f32];
1306 assert_eq!(result, expected_result);
1307 }
1308
1309 #[simd_test(enable = "avx512bf16,avx512vl")]
1310 unsafe fn test_mm_mask_dpbf16_ps() {
1311 let a_array = [8.5_f32, 10.5_f32, 3.75_f32, 50.25_f32];
1312 let b_array = [-1.0_f32, -1.0_f32, -1.0_f32, -1.0_f32];
1313 let a1: __m128 = transmute(a_array);
1314 let b1: __m128 = transmute(b_array);
1315 let k: __mmask8 = 0xf3;
1316 let src: __m128 = transmute([1.0_f32, 2.0_f32, 3.0_f32, 4.0_f32]);
1317 let a: __m128bh = _mm_cvtne2ps_pbh(a1, a1);
1318 let b: __m128bh = _mm_cvtne2ps_pbh(b1, b1);
1319 let c: __m128 = _mm_mask_dpbf16_ps(src, k, a, b);
1320 let result: [f32; 4] = transmute(c.as_f32x4());
1321 let expected_result: [f32; 4] = [-18.0_f32, -52.0_f32, 3.0_f32, 4.0_f32];
1322 assert_eq!(result, expected_result);
1323 let k: __mmask8 = 0xff;
1324 let c: __m128 = _mm_mask_dpbf16_ps(src, k, a, b);
1325 let result: [f32; 4] = transmute(c.as_f32x4());
1326 let expected_result: [f32; 4] = [-18.0_f32, -52.0_f32, -16.0_f32, -50.0_f32];
1327 assert_eq!(result, expected_result);
1328 let k: __mmask8 = 0;
1329 let c: __m128 = _mm_mask_dpbf16_ps(src, k, a, b);
1330 let result: [f32; 4] = transmute(c.as_f32x4());
1331 let expected_result: [f32; 4] = [1.0_f32, 2.0_f32, 3.0_f32, 4.0_f32];
1332 assert_eq!(result, expected_result);
1333 }
1334
1335 #[simd_test(enable = "avx512bf16,avx512vl")]
1336 unsafe fn test_mm_maskz_dpbf16_ps() {
1337 let a_array = [8.5_f32, 10.5_f32, 3.75_f32, 50.25_f32];
1338 let b_array = [-1.0_f32, -1.0_f32, -1.0_f32, -1.0_f32];
1339 let a1: __m128 = transmute(a_array);
1340 let b1: __m128 = transmute(b_array);
1341 let k: __mmask8 = 0xf3;
1342 let src: __m128 = transmute([1.0_f32, 2.0_f32, 3.0_f32, 4.0_f32]);
1343 let a: __m128bh = _mm_cvtne2ps_pbh(a1, a1);
1344 let b: __m128bh = _mm_cvtne2ps_pbh(b1, b1);
1345 let c: __m128 = _mm_maskz_dpbf16_ps(k, src, a, b);
1346 let result: [f32; 4] = transmute(c.as_f32x4());
1347 let expected_result: [f32; 4] = [-18.0_f32, -52.0_f32, 0.0, 0.0];
1348 assert_eq!(result, expected_result);
1349 let k: __mmask8 = 0xff;
1350 let c: __m128 = _mm_maskz_dpbf16_ps(k, src, a, b);
1351 let result: [f32; 4] = transmute(c.as_f32x4());
1352 let expected_result: [f32; 4] = [-18.0_f32, -52.0_f32, -16.0_f32, -50.0_f32];
1353 assert_eq!(result, expected_result);
1354 let k: __mmask8 = 0;
1355 let c: __m128 = _mm_maskz_dpbf16_ps(k, src, a, b);
1356 let result: [f32; 4] = transmute(c.as_f32x4());
1357 let expected_result: [f32; 4] = [0.0, 0.0, 0.0, 0.0];
1358 assert_eq!(result, expected_result);
1359 }
1360
1361 #[simd_test(enable = "avx512bf16,avx512vl")]
1362 unsafe fn test_mm256_dpbf16_ps() {
1363 #[rustfmt::skip]
1364 let a_array = [
1365 8.5_f32, 10.5_f32, 3.75_f32, 50.25_f32, 8.5_f32, 10.5_f32, 3.75_f32, 50.25_f32,
1366 ];
1367 let b_array = [
1368 -1.0_f32, -1.0_f32, -1.0_f32, -1.0_f32, -1.0_f32, -1.0_f32, -1.0_f32, -1.0_f32,
1369 ];
1370 let a1: __m256 = transmute(a_array);
1371 let b1: __m256 = transmute(b_array);
1372 #[rustfmt::skip]
1373 let src: __m256 = transmute([
1374 1.0_f32, 2.0_f32, 3.0_f32, 4.0_f32, 1.0_f32, 2.0_f32, 3.0_f32, 4.0_f32,
1375 ]);
1376 let a: __m256bh = _mm256_cvtne2ps_pbh(a1, a1);
1377 let b: __m256bh = _mm256_cvtne2ps_pbh(b1, b1);
1378 let c: __m256 = _mm256_dpbf16_ps(src, a, b);
1379 let result: [f32; 8] = transmute(c.as_f32x8());
1380 #[rustfmt::skip]
1381 let expected_result: [f32; 8] = [
1382 -18.0_f32, -52.0_f32, -16.0_f32, -50.0_f32, -18.0_f32, -52.0_f32, -16.0_f32, -50.0_f32,
1383 ];
1384 assert_eq!(result, expected_result);
1385 }
1386
1387 #[simd_test(enable = "avx512bf16,avx512vl")]
1388 unsafe fn test_mm256_mask_dpbf16_ps() {
1389 #[rustfmt::skip]
1390 let a_array = [
1391 8.5_f32, 10.5_f32, 3.75_f32, 50.25_f32, 8.5_f32, 10.5_f32, 3.75_f32, 50.25_f32,
1392 ];
1393 let b_array = [
1394 -1.0_f32, -1.0_f32, -1.0_f32, -1.0_f32, -1.0_f32, -1.0_f32, -1.0_f32, -1.0_f32,
1395 ];
1396 let a1: __m256 = transmute(a_array);
1397 let b1: __m256 = transmute(b_array);
1398 let k: __mmask8 = 0x33;
1399 #[rustfmt::skip]
1400 let src: __m256 = transmute([
1401 1.0_f32, 2.0_f32, 3.0_f32, 4.0_f32, 1.0_f32, 2.0_f32, 3.0_f32, 4.0_f32,
1402 ]);
1403 let a: __m256bh = _mm256_cvtne2ps_pbh(a1, a1);
1404 let b: __m256bh = _mm256_cvtne2ps_pbh(b1, b1);
1405 let c: __m256 = _mm256_mask_dpbf16_ps(src, k, a, b);
1406 let result: [f32; 8] = transmute(c.as_f32x8());
1407 #[rustfmt::skip]
1408 let expected_result: [f32; 8] = [
1409 -18.0_f32, -52.0_f32, 3.0_f32, 4.0_f32, -18.0_f32, -52.0_f32, 3.0_f32, 4.0_f32,
1410 ];
1411 assert_eq!(result, expected_result);
1412 let k: __mmask8 = 0xff;
1413 let c: __m256 = _mm256_mask_dpbf16_ps(src, k, a, b);
1414 let result: [f32; 8] = transmute(c.as_f32x8());
1415 #[rustfmt::skip]
1416 let expected_result: [f32; 8] = [
1417 -18.0_f32, -52.0_f32, -16.0_f32, -50.0_f32, -18.0_f32, -52.0_f32, -16.0_f32, -50.0_f32,
1418 ];
1419 assert_eq!(result, expected_result);
1420 let k: __mmask8 = 0;
1421 let c: __m256 = _mm256_mask_dpbf16_ps(src, k, a, b);
1422 let result: [f32; 8] = transmute(c.as_f32x8());
1423 #[rustfmt::skip]
1424 let expected_result: [f32; 8] = [
1425 1.0_f32, 2.0_f32, 3.0_f32, 4.0_f32, 1.0_f32, 2.0_f32, 3.0_f32, 4.0_f32,
1426 ];
1427 assert_eq!(result, expected_result);
1428 }
1429
1430 #[simd_test(enable = "avx512bf16,avx512vl")]
1431 unsafe fn test_mm256_maskz_dpbf16_ps() {
1432 #[rustfmt::skip]
1433 let a_array = [
1434 8.5_f32, 10.5_f32, 3.75_f32, 50.25_f32, 8.5_f32, 10.5_f32, 3.75_f32, 50.25_f32,
1435 ];
1436 let b_array = [
1437 -1.0_f32, -1.0_f32, -1.0_f32, -1.0_f32, -1.0_f32, -1.0_f32, -1.0_f32, -1.0_f32,
1438 ];
1439 let a1: __m256 = transmute(a_array);
1440 let b1: __m256 = transmute(b_array);
1441 let k: __mmask8 = 0x33;
1442 #[rustfmt::skip]
1443 let src: __m256 = transmute([
1444 1.0_f32, 2.0_f32, 3.0_f32, 4.0_f32, 1.0_f32, 2.0_f32, 3.0_f32, 4.0_f32,
1445 ]);
1446 let a: __m256bh = _mm256_cvtne2ps_pbh(a1, a1);
1447 let b: __m256bh = _mm256_cvtne2ps_pbh(b1, b1);
1448 let c: __m256 = _mm256_maskz_dpbf16_ps(k, src, a, b);
1449 let result: [f32; 8] = transmute(c.as_f32x8());
1450 #[rustfmt::skip]
1451 let expected_result: [f32; 8] = [
1452 -18.0_f32, -52.0_f32, 0.0, 0.0, -18.0_f32, -52.0_f32, 0.0, 0.0,
1453 ];
1454 assert_eq!(result, expected_result);
1455 let k: __mmask8 = 0xff;
1456 let c: __m256 = _mm256_maskz_dpbf16_ps(k, src, a, b);
1457 let result: [f32; 8] = transmute(c.as_f32x8());
1458 #[rustfmt::skip]
1459 let expected_result: [f32; 8] = [
1460 -18.0_f32, -52.0_f32, -16.0_f32, -50.0_f32, -18.0_f32, -52.0_f32, -16.0_f32, -50.0_f32,
1461 ];
1462 assert_eq!(result, expected_result);
1463 let k: __mmask8 = 0;
1464 let c: __m256 = _mm256_maskz_dpbf16_ps(k, src, a, b);
1465 let result: [f32; 8] = transmute(c.as_f32x8());
1466 let expected_result: [f32; 8] = [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0];
1467 assert_eq!(result, expected_result);
1468 }
1469
1470 #[simd_test(enable = "avx512bf16,avx512f")]
1471 unsafe fn test_mm512_dpbf16_ps() {
1472 #[rustfmt::skip]
1473 let a_array = [
1474 8.5_f32, 10.5_f32, 3.75_f32, 50.25_f32, 8.5_f32, 10.5_f32, 3.75_f32, 50.25_f32,
1475 8.5_f32, 10.5_f32, 3.75_f32, 50.25_f32, 8.5_f32, 10.5_f32, 3.75_f32, 50.25_f32,
1476 ];
1477 let b_array = [
1478 -1.0_f32, -1.0_f32, -1.0_f32, -1.0_f32, -1.0_f32, -1.0_f32, -1.0_f32, -1.0_f32,
1479 -1.0_f32, -1.0_f32, -1.0_f32, -1.0_f32, -1.0_f32, -1.0_f32, -1.0_f32, -1.0_f32,
1480 ];
1481 let a1: __m512 = transmute(a_array);
1482 let b1: __m512 = transmute(b_array);
1483 let src: __m512 = transmute([
1484 1.0_f32, 2.0_f32, 3.0_f32, 4.0_f32, 1.0_f32, 2.0_f32, 3.0_f32, 4.0_f32, 1.0_f32,
1485 2.0_f32, 3.0_f32, 4.0_f32, 1.0_f32, 2.0_f32, 3.0_f32, 4.0_f32,
1486 ]);
1487 let a: __m512bh = _mm512_cvtne2ps_pbh(a1, a1);
1488 let b: __m512bh = _mm512_cvtne2ps_pbh(b1, b1);
1489 let c: __m512 = _mm512_dpbf16_ps(src, a, b);
1490 let result: [f32; 16] = transmute(c.as_f32x16());
1491 #[rustfmt::skip]
1492 let expected_result: [f32; 16] = [
1493 -18.0_f32, -52.0_f32, -16.0_f32, -50.0_f32, -18.0_f32, -52.0_f32, -16.0_f32, -50.0_f32,
1494 -18.0_f32, -52.0_f32, -16.0_f32, -50.0_f32, -18.0_f32, -52.0_f32, -16.0_f32, -50.0_f32,
1495 ];
1496 assert_eq!(result, expected_result);
1497 }
1498
1499 #[simd_test(enable = "avx512bf16,avx512f")]
1500 unsafe fn test_mm512_mask_dpbf16_ps() {
1501 #[rustfmt::skip]
1502 let a_array = [
1503 8.5_f32, 10.5_f32, 3.75_f32, 50.25_f32, 8.5_f32, 10.5_f32, 3.75_f32, 50.25_f32,
1504 8.5_f32, 10.5_f32, 3.75_f32, 50.25_f32, 8.5_f32, 10.5_f32, 3.75_f32, 50.25_f32,
1505 ];
1506 let b_array = [
1507 -1.0_f32, -1.0_f32, -1.0_f32, -1.0_f32, -1.0_f32, -1.0_f32, -1.0_f32, -1.0_f32,
1508 -1.0_f32, -1.0_f32, -1.0_f32, -1.0_f32, -1.0_f32, -1.0_f32, -1.0_f32, -1.0_f32,
1509 ];
1510 let a1: __m512 = transmute(a_array);
1511 let b1: __m512 = transmute(b_array);
1512 let k: __mmask16 = 0x3333;
1513 #[rustfmt::skip]
1514 let src: __m512 = transmute([
1515 1.0_f32, 2.0_f32, 3.0_f32, 4.0_f32, 1.0_f32, 2.0_f32, 3.0_f32, 4.0_f32, 1.0_f32,
1516 2.0_f32, 3.0_f32, 4.0_f32, 1.0_f32, 2.0_f32, 3.0_f32, 4.0_f32,
1517 ]);
1518 let a: __m512bh = _mm512_cvtne2ps_pbh(a1, a1);
1519 let b: __m512bh = _mm512_cvtne2ps_pbh(b1, b1);
1520 let c: __m512 = _mm512_mask_dpbf16_ps(src, k, a, b);
1521 let result: [f32; 16] = transmute(c.as_f32x16());
1522 #[rustfmt::skip]
1523 let expected_result: [f32; 16] = [
1524 -18.0_f32, -52.0_f32, 3.0_f32, 4.0_f32, -18.0_f32, -52.0_f32, 3.0_f32, 4.0_f32,
1525 -18.0_f32, -52.0_f32, 3.0_f32, 4.0_f32, -18.0_f32, -52.0_f32, 3.0_f32, 4.0_f32,
1526 ];
1527 assert_eq!(result, expected_result);
1528 let k: __mmask16 = 0xffff;
1529 let c: __m512 = _mm512_mask_dpbf16_ps(src, k, a, b);
1530 let result: [f32; 16] = transmute(c.as_f32x16());
1531 #[rustfmt::skip]
1532 let expected_result: [f32; 16] = [
1533 -18.0_f32, -52.0_f32, -16.0_f32, -50.0_f32, -18.0_f32, -52.0_f32, -16.0_f32, -50.0_f32,
1534 -18.0_f32, -52.0_f32, -16.0_f32, -50.0_f32, -18.0_f32, -52.0_f32, -16.0_f32, -50.0_f32,
1535 ];
1536 assert_eq!(result, expected_result);
1537 let k: __mmask16 = 0;
1538 let c: __m512 = _mm512_mask_dpbf16_ps(src, k, a, b);
1539 let result: [f32; 16] = transmute(c.as_f32x16());
1540 #[rustfmt::skip]
1541 let expected_result: [f32; 16] = [
1542 1.0_f32, 2.0_f32, 3.0_f32, 4.0_f32, 1.0_f32, 2.0_f32, 3.0_f32, 4.0_f32, 1.0_f32,
1543 2.0_f32, 3.0_f32, 4.0_f32, 1.0_f32, 2.0_f32, 3.0_f32, 4.0_f32,
1544 ];
1545 assert_eq!(result, expected_result);
1546 }
1547
1548 #[simd_test(enable = "avx512bf16,avx512f")]
1549 unsafe fn test_mm512_maskz_dpbf16_ps() {
1550 #[rustfmt::skip]
1551 let a_array = [
1552 8.5_f32, 10.5_f32, 3.75_f32, 50.25_f32, 8.5_f32, 10.5_f32, 3.75_f32, 50.25_f32,
1553 8.5_f32, 10.5_f32, 3.75_f32, 50.25_f32, 8.5_f32, 10.5_f32, 3.75_f32, 50.25_f32,
1554 ];
1555 let b_array = [
1556 -1.0_f32, -1.0_f32, -1.0_f32, -1.0_f32, -1.0_f32, -1.0_f32, -1.0_f32, -1.0_f32,
1557 -1.0_f32, -1.0_f32, -1.0_f32, -1.0_f32, -1.0_f32, -1.0_f32, -1.0_f32, -1.0_f32,
1558 ];
1559 let a1: __m512 = transmute(a_array);
1560 let b1: __m512 = transmute(b_array);
1561 let k: __mmask16 = 0x3333;
1562 #[rustfmt::skip]
1563 let src: __m512 = transmute([
1564 1.0_f32, 2.0_f32, 3.0_f32, 4.0_f32, 1.0_f32, 2.0_f32, 3.0_f32, 4.0_f32, 1.0_f32,
1565 2.0_f32, 3.0_f32, 4.0_f32, 1.0_f32, 2.0_f32, 3.0_f32, 4.0_f32,
1566 ]);
1567 let a: __m512bh = _mm512_cvtne2ps_pbh(a1, a1);
1568 let b: __m512bh = _mm512_cvtne2ps_pbh(b1, b1);
1569 let c: __m512 = _mm512_maskz_dpbf16_ps(k, src, a, b);
1570 let result: [f32; 16] = transmute(c.as_f32x16());
1571 #[rustfmt::skip]
1572 let expected_result: [f32; 16] = [
1573 -18.0_f32, -52.0_f32, 0.0, 0.0, -18.0_f32, -52.0_f32, 0.0, 0.0, -18.0_f32, -52.0_f32,
1574 0.0, 0.0, -18.0_f32, -52.0_f32, 0.0, 0.0,
1575 ];
1576 assert_eq!(result, expected_result);
1577 let k: __mmask16 = 0xffff;
1578 let c: __m512 = _mm512_maskz_dpbf16_ps(k, src, a, b);
1579 let result: [f32; 16] = transmute(c.as_f32x16());
1580 #[rustfmt::skip]
1581 let expected_result: [f32; 16] = [
1582 -18.0_f32, -52.0_f32, -16.0_f32, -50.0_f32, -18.0_f32, -52.0_f32, -16.0_f32, -50.0_f32,
1583 -18.0_f32, -52.0_f32, -16.0_f32, -50.0_f32, -18.0_f32, -52.0_f32, -16.0_f32, -50.0_f32,
1584 ];
1585 assert_eq!(result, expected_result);
1586 let k: __mmask16 = 0;
1587 let c: __m512 = _mm512_maskz_dpbf16_ps(k, src, a, b);
1588 let result: [f32; 16] = transmute(c.as_f32x16());
1589 #[rustfmt::skip]
1590 let expected_result: [f32; 16] = [
1591 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
1592 ];
1593 assert_eq!(result, expected_result);
1594 }
1595}
1596