1/* SPDX-License-Identifier: GPL-2.0-or-later */
2/*
3 * SM4 Cipher Algorithm, AES-NI/AVX optimized.
4 * as specified in
5 * https://tools.ietf.org/id/draft-ribose-cfrg-sm4-10.html
6 *
7 * Copyright (c) 2021, Alibaba Group.
8 * Copyright (c) 2021 Tianjia Zhang <tianjia.zhang@linux.alibaba.com>
9 */
10
11#include <linux/module.h>
12#include <linux/crypto.h>
13#include <linux/kernel.h>
14#include <asm/simd.h>
15#include <crypto/internal/simd.h>
16#include <crypto/internal/skcipher.h>
17#include <crypto/sm4.h>
18#include "sm4-avx.h"
19
/* Number of bytes covered by one call of the 8-block helpers below. */
#define SM4_CRYPT8_BLOCK_SIZE (SM4_BLOCK_SIZE * 8)

/*
 * SM4 AES-NI/AVX helpers, declared asmlinkage and implemented outside this
 * file.
 *
 * sm4_aesni_avx_crypt4 / sm4_aesni_avx_crypt8: run the SM4 rounds with the
 * round-key schedule @rk over @nblocks 16-byte blocks from @src into @dst.
 * Callers in this file pass at most 4 and 8 blocks respectively, and also
 * call crypt8 with @dst == @src to encrypt a keystream buffer in place.
 *
 * The *_blk8 helpers handle exactly SM4_CRYPT8_BLOCK_SIZE bytes in the named
 * mode and take the chaining value / counter in @iv.  Callers pass walk.iv
 * and continue from it afterwards, so the helpers are presumably expected to
 * update @iv for the next block — NOTE(review): confirm against the assembly.
 * All of them must be called between kernel_fpu_begin()/kernel_fpu_end().
 */
asmlinkage void sm4_aesni_avx_crypt4(const u32 *rk, u8 *dst,
				const u8 *src, int nblocks);
asmlinkage void sm4_aesni_avx_crypt8(const u32 *rk, u8 *dst,
				const u8 *src, int nblocks);
asmlinkage void sm4_aesni_avx_ctr_enc_blk8(const u32 *rk, u8 *dst,
				const u8 *src, u8 *iv);
asmlinkage void sm4_aesni_avx_cbc_dec_blk8(const u32 *rk, u8 *dst,
				const u8 *src, u8 *iv);
asmlinkage void sm4_aesni_avx_cfb_dec_blk8(const u32 *rk, u8 *dst,
				const u8 *src, u8 *iv);
32
33static int sm4_skcipher_setkey(struct crypto_skcipher *tfm, const u8 *key,
34 unsigned int key_len)
35{
36 struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);
37
38 return sm4_expandkey(ctx, in_key: key, key_len);
39}
40
41static int ecb_do_crypt(struct skcipher_request *req, const u32 *rkey)
42{
43 struct skcipher_walk walk;
44 unsigned int nbytes;
45 int err;
46
47 err = skcipher_walk_virt(walk: &walk, req, atomic: false);
48
49 while ((nbytes = walk.nbytes) > 0) {
50 const u8 *src = walk.src.virt.addr;
51 u8 *dst = walk.dst.virt.addr;
52
53 kernel_fpu_begin();
54 while (nbytes >= SM4_CRYPT8_BLOCK_SIZE) {
55 sm4_aesni_avx_crypt8(rk: rkey, dst, src, nblocks: 8);
56 dst += SM4_CRYPT8_BLOCK_SIZE;
57 src += SM4_CRYPT8_BLOCK_SIZE;
58 nbytes -= SM4_CRYPT8_BLOCK_SIZE;
59 }
60 while (nbytes >= SM4_BLOCK_SIZE) {
61 unsigned int nblocks = min(nbytes >> 4, 4u);
62 sm4_aesni_avx_crypt4(rk: rkey, dst, src, nblocks);
63 dst += nblocks * SM4_BLOCK_SIZE;
64 src += nblocks * SM4_BLOCK_SIZE;
65 nbytes -= nblocks * SM4_BLOCK_SIZE;
66 }
67 kernel_fpu_end();
68
69 err = skcipher_walk_done(walk: &walk, err: nbytes);
70 }
71
72 return err;
73}
74
75int sm4_avx_ecb_encrypt(struct skcipher_request *req)
76{
77 struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
78 struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);
79
80 return ecb_do_crypt(req, rkey: ctx->rkey_enc);
81}
82EXPORT_SYMBOL_GPL(sm4_avx_ecb_encrypt);
83
84int sm4_avx_ecb_decrypt(struct skcipher_request *req)
85{
86 struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
87 struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);
88
89 return ecb_do_crypt(req, rkey: ctx->rkey_dec);
90}
91EXPORT_SYMBOL_GPL(sm4_avx_ecb_decrypt);
92
93int sm4_cbc_encrypt(struct skcipher_request *req)
94{
95 struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
96 struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);
97 struct skcipher_walk walk;
98 unsigned int nbytes;
99 int err;
100
101 err = skcipher_walk_virt(walk: &walk, req, atomic: false);
102
103 while ((nbytes = walk.nbytes) > 0) {
104 const u8 *iv = walk.iv;
105 const u8 *src = walk.src.virt.addr;
106 u8 *dst = walk.dst.virt.addr;
107
108 while (nbytes >= SM4_BLOCK_SIZE) {
109 crypto_xor_cpy(dst, src1: src, src2: iv, SM4_BLOCK_SIZE);
110 sm4_crypt_block(rk: ctx->rkey_enc, out: dst, in: dst);
111 iv = dst;
112 src += SM4_BLOCK_SIZE;
113 dst += SM4_BLOCK_SIZE;
114 nbytes -= SM4_BLOCK_SIZE;
115 }
116 if (iv != walk.iv)
117 memcpy(walk.iv, iv, SM4_BLOCK_SIZE);
118
119 err = skcipher_walk_done(walk: &walk, err: nbytes);
120 }
121
122 return err;
123}
124EXPORT_SYMBOL_GPL(sm4_cbc_encrypt);
125
126int sm4_avx_cbc_decrypt(struct skcipher_request *req,
127 unsigned int bsize, sm4_crypt_func func)
128{
129 struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
130 struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);
131 struct skcipher_walk walk;
132 unsigned int nbytes;
133 int err;
134
135 err = skcipher_walk_virt(walk: &walk, req, atomic: false);
136
137 while ((nbytes = walk.nbytes) > 0) {
138 const u8 *src = walk.src.virt.addr;
139 u8 *dst = walk.dst.virt.addr;
140
141 kernel_fpu_begin();
142
143 while (nbytes >= bsize) {
144 func(ctx->rkey_dec, dst, src, walk.iv);
145 dst += bsize;
146 src += bsize;
147 nbytes -= bsize;
148 }
149
150 while (nbytes >= SM4_BLOCK_SIZE) {
151 u8 keystream[SM4_BLOCK_SIZE * 8];
152 u8 iv[SM4_BLOCK_SIZE];
153 unsigned int nblocks = min(nbytes >> 4, 8u);
154 int i;
155
156 sm4_aesni_avx_crypt8(rk: ctx->rkey_dec, dst: keystream,
157 src, nblocks);
158
159 src += ((int)nblocks - 2) * SM4_BLOCK_SIZE;
160 dst += (nblocks - 1) * SM4_BLOCK_SIZE;
161 memcpy(iv, src + SM4_BLOCK_SIZE, SM4_BLOCK_SIZE);
162
163 for (i = nblocks - 1; i > 0; i--) {
164 crypto_xor_cpy(dst, src1: src,
165 src2: &keystream[i * SM4_BLOCK_SIZE],
166 SM4_BLOCK_SIZE);
167 src -= SM4_BLOCK_SIZE;
168 dst -= SM4_BLOCK_SIZE;
169 }
170 crypto_xor_cpy(dst, src1: walk.iv, src2: keystream, SM4_BLOCK_SIZE);
171 memcpy(walk.iv, iv, SM4_BLOCK_SIZE);
172 dst += nblocks * SM4_BLOCK_SIZE;
173 src += (nblocks + 1) * SM4_BLOCK_SIZE;
174 nbytes -= nblocks * SM4_BLOCK_SIZE;
175 }
176
177 kernel_fpu_end();
178 err = skcipher_walk_done(walk: &walk, err: nbytes);
179 }
180
181 return err;
182}
183EXPORT_SYMBOL_GPL(sm4_avx_cbc_decrypt);
184
185static int cbc_decrypt(struct skcipher_request *req)
186{
187 return sm4_avx_cbc_decrypt(req, SM4_CRYPT8_BLOCK_SIZE,
188 sm4_aesni_avx_cbc_dec_blk8);
189}
190
191int sm4_cfb_encrypt(struct skcipher_request *req)
192{
193 struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
194 struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);
195 struct skcipher_walk walk;
196 unsigned int nbytes;
197 int err;
198
199 err = skcipher_walk_virt(walk: &walk, req, atomic: false);
200
201 while ((nbytes = walk.nbytes) > 0) {
202 u8 keystream[SM4_BLOCK_SIZE];
203 const u8 *iv = walk.iv;
204 const u8 *src = walk.src.virt.addr;
205 u8 *dst = walk.dst.virt.addr;
206
207 while (nbytes >= SM4_BLOCK_SIZE) {
208 sm4_crypt_block(rk: ctx->rkey_enc, out: keystream, in: iv);
209 crypto_xor_cpy(dst, src1: src, src2: keystream, SM4_BLOCK_SIZE);
210 iv = dst;
211 src += SM4_BLOCK_SIZE;
212 dst += SM4_BLOCK_SIZE;
213 nbytes -= SM4_BLOCK_SIZE;
214 }
215 if (iv != walk.iv)
216 memcpy(walk.iv, iv, SM4_BLOCK_SIZE);
217
218 /* tail */
219 if (walk.nbytes == walk.total && nbytes > 0) {
220 sm4_crypt_block(rk: ctx->rkey_enc, out: keystream, in: walk.iv);
221 crypto_xor_cpy(dst, src1: src, src2: keystream, size: nbytes);
222 nbytes = 0;
223 }
224
225 err = skcipher_walk_done(walk: &walk, err: nbytes);
226 }
227
228 return err;
229}
230EXPORT_SYMBOL_GPL(sm4_cfb_encrypt);
231
232int sm4_avx_cfb_decrypt(struct skcipher_request *req,
233 unsigned int bsize, sm4_crypt_func func)
234{
235 struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
236 struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);
237 struct skcipher_walk walk;
238 unsigned int nbytes;
239 int err;
240
241 err = skcipher_walk_virt(walk: &walk, req, atomic: false);
242
243 while ((nbytes = walk.nbytes) > 0) {
244 const u8 *src = walk.src.virt.addr;
245 u8 *dst = walk.dst.virt.addr;
246
247 kernel_fpu_begin();
248
249 while (nbytes >= bsize) {
250 func(ctx->rkey_enc, dst, src, walk.iv);
251 dst += bsize;
252 src += bsize;
253 nbytes -= bsize;
254 }
255
256 while (nbytes >= SM4_BLOCK_SIZE) {
257 u8 keystream[SM4_BLOCK_SIZE * 8];
258 unsigned int nblocks = min(nbytes >> 4, 8u);
259
260 memcpy(keystream, walk.iv, SM4_BLOCK_SIZE);
261 if (nblocks > 1)
262 memcpy(&keystream[SM4_BLOCK_SIZE], src,
263 (nblocks - 1) * SM4_BLOCK_SIZE);
264 memcpy(walk.iv, src + (nblocks - 1) * SM4_BLOCK_SIZE,
265 SM4_BLOCK_SIZE);
266
267 sm4_aesni_avx_crypt8(rk: ctx->rkey_enc, dst: keystream,
268 src: keystream, nblocks);
269
270 crypto_xor_cpy(dst, src1: src, src2: keystream,
271 size: nblocks * SM4_BLOCK_SIZE);
272 dst += nblocks * SM4_BLOCK_SIZE;
273 src += nblocks * SM4_BLOCK_SIZE;
274 nbytes -= nblocks * SM4_BLOCK_SIZE;
275 }
276
277 kernel_fpu_end();
278
279 /* tail */
280 if (walk.nbytes == walk.total && nbytes > 0) {
281 u8 keystream[SM4_BLOCK_SIZE];
282
283 sm4_crypt_block(rk: ctx->rkey_enc, out: keystream, in: walk.iv);
284 crypto_xor_cpy(dst, src1: src, src2: keystream, size: nbytes);
285 nbytes = 0;
286 }
287
288 err = skcipher_walk_done(walk: &walk, err: nbytes);
289 }
290
291 return err;
292}
293EXPORT_SYMBOL_GPL(sm4_avx_cfb_decrypt);
294
295static int cfb_decrypt(struct skcipher_request *req)
296{
297 return sm4_avx_cfb_decrypt(req, SM4_CRYPT8_BLOCK_SIZE,
298 sm4_aesni_avx_cfb_dec_blk8);
299}
300
301int sm4_avx_ctr_crypt(struct skcipher_request *req,
302 unsigned int bsize, sm4_crypt_func func)
303{
304 struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
305 struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);
306 struct skcipher_walk walk;
307 unsigned int nbytes;
308 int err;
309
310 err = skcipher_walk_virt(walk: &walk, req, atomic: false);
311
312 while ((nbytes = walk.nbytes) > 0) {
313 const u8 *src = walk.src.virt.addr;
314 u8 *dst = walk.dst.virt.addr;
315
316 kernel_fpu_begin();
317
318 while (nbytes >= bsize) {
319 func(ctx->rkey_enc, dst, src, walk.iv);
320 dst += bsize;
321 src += bsize;
322 nbytes -= bsize;
323 }
324
325 while (nbytes >= SM4_BLOCK_SIZE) {
326 u8 keystream[SM4_BLOCK_SIZE * 8];
327 unsigned int nblocks = min(nbytes >> 4, 8u);
328 int i;
329
330 for (i = 0; i < nblocks; i++) {
331 memcpy(&keystream[i * SM4_BLOCK_SIZE],
332 walk.iv, SM4_BLOCK_SIZE);
333 crypto_inc(a: walk.iv, SM4_BLOCK_SIZE);
334 }
335 sm4_aesni_avx_crypt8(rk: ctx->rkey_enc, dst: keystream,
336 src: keystream, nblocks);
337
338 crypto_xor_cpy(dst, src1: src, src2: keystream,
339 size: nblocks * SM4_BLOCK_SIZE);
340 dst += nblocks * SM4_BLOCK_SIZE;
341 src += nblocks * SM4_BLOCK_SIZE;
342 nbytes -= nblocks * SM4_BLOCK_SIZE;
343 }
344
345 kernel_fpu_end();
346
347 /* tail */
348 if (walk.nbytes == walk.total && nbytes > 0) {
349 u8 keystream[SM4_BLOCK_SIZE];
350
351 memcpy(keystream, walk.iv, SM4_BLOCK_SIZE);
352 crypto_inc(a: walk.iv, SM4_BLOCK_SIZE);
353
354 sm4_crypt_block(rk: ctx->rkey_enc, out: keystream, in: keystream);
355
356 crypto_xor_cpy(dst, src1: src, src2: keystream, size: nbytes);
357 dst += nbytes;
358 src += nbytes;
359 nbytes = 0;
360 }
361
362 err = skcipher_walk_done(walk: &walk, err: nbytes);
363 }
364
365 return err;
366}
367EXPORT_SYMBOL_GPL(sm4_avx_ctr_crypt);
368
369static int ctr_crypt(struct skcipher_request *req)
370{
371 return sm4_avx_ctr_crypt(req, SM4_CRYPT8_BLOCK_SIZE,
372 sm4_aesni_avx_ctr_enc_blk8);
373}
374
/*
 * Internal skcipher implementations (CRYPTO_ALG_INTERNAL, "__"-prefixed
 * names).  sm4_init() registers them through
 * simd_register_skciphers_compat(), which also creates the SIMD wrapper
 * algorithms stored in simd_sm4_aesni_avx_skciphers[].
 *
 * walksize is 8 blocks to match the 8-way AVX helpers.  CFB and CTR are
 * stream-like modes, so they advertise cra_blocksize = 1 with a chunksize of
 * one SM4 block and handle a trailing partial block themselves.
 */
static struct skcipher_alg sm4_aesni_avx_skciphers[] = {
	{
		.base = {
			.cra_name		= "__ecb(sm4)",
			.cra_driver_name	= "__ecb-sm4-aesni-avx",
			.cra_priority		= 400,
			.cra_flags		= CRYPTO_ALG_INTERNAL,
			.cra_blocksize		= SM4_BLOCK_SIZE,
			.cra_ctxsize		= sizeof(struct sm4_ctx),
			.cra_module		= THIS_MODULE,
		},
		.min_keysize	= SM4_KEY_SIZE,
		.max_keysize	= SM4_KEY_SIZE,
		.walksize	= 8 * SM4_BLOCK_SIZE,
		.setkey		= sm4_skcipher_setkey,
		.encrypt	= sm4_avx_ecb_encrypt,
		.decrypt	= sm4_avx_ecb_decrypt,
	}, {
		.base = {
			.cra_name		= "__cbc(sm4)",
			.cra_driver_name	= "__cbc-sm4-aesni-avx",
			.cra_priority		= 400,
			.cra_flags		= CRYPTO_ALG_INTERNAL,
			.cra_blocksize		= SM4_BLOCK_SIZE,
			.cra_ctxsize		= sizeof(struct sm4_ctx),
			.cra_module		= THIS_MODULE,
		},
		.min_keysize	= SM4_KEY_SIZE,
		.max_keysize	= SM4_KEY_SIZE,
		.ivsize		= SM4_BLOCK_SIZE,
		.walksize	= 8 * SM4_BLOCK_SIZE,
		.setkey		= sm4_skcipher_setkey,
		/* CBC encryption is serial; only decryption uses the 8-way path. */
		.encrypt	= sm4_cbc_encrypt,
		.decrypt	= cbc_decrypt,
	}, {
		.base = {
			.cra_name		= "__cfb(sm4)",
			.cra_driver_name	= "__cfb-sm4-aesni-avx",
			.cra_priority		= 400,
			.cra_flags		= CRYPTO_ALG_INTERNAL,
			.cra_blocksize		= 1,
			.cra_ctxsize		= sizeof(struct sm4_ctx),
			.cra_module		= THIS_MODULE,
		},
		.min_keysize	= SM4_KEY_SIZE,
		.max_keysize	= SM4_KEY_SIZE,
		.ivsize		= SM4_BLOCK_SIZE,
		.chunksize	= SM4_BLOCK_SIZE,
		.walksize	= 8 * SM4_BLOCK_SIZE,
		.setkey		= sm4_skcipher_setkey,
		/* CFB encryption is serial; only decryption uses the 8-way path. */
		.encrypt	= sm4_cfb_encrypt,
		.decrypt	= cfb_decrypt,
	}, {
		.base = {
			.cra_name		= "__ctr(sm4)",
			.cra_driver_name	= "__ctr-sm4-aesni-avx",
			.cra_priority		= 400,
			.cra_flags		= CRYPTO_ALG_INTERNAL,
			.cra_blocksize		= 1,
			.cra_ctxsize		= sizeof(struct sm4_ctx),
			.cra_module		= THIS_MODULE,
		},
		.min_keysize	= SM4_KEY_SIZE,
		.max_keysize	= SM4_KEY_SIZE,
		.ivsize		= SM4_BLOCK_SIZE,
		.chunksize	= SM4_BLOCK_SIZE,
		.walksize	= 8 * SM4_BLOCK_SIZE,
		.setkey		= sm4_skcipher_setkey,
		/* CTR is symmetric: the same routine encrypts and decrypts. */
		.encrypt	= ctr_crypt,
		.decrypt	= ctr_crypt,
	}
};

/*
 * SIMD wrapper handles, one per entry above.  They are filled in by
 * simd_register_skciphers_compat() in sm4_init() and released by
 * simd_unregister_skciphers() in sm4_exit().
 */
static struct simd_skcipher_alg *
simd_sm4_aesni_avx_skciphers[ARRAY_SIZE(sm4_aesni_avx_skciphers)];
450
451static int __init sm4_init(void)
452{
453 const char *feature_name;
454
455 if (!boot_cpu_has(X86_FEATURE_AVX) ||
456 !boot_cpu_has(X86_FEATURE_AES) ||
457 !boot_cpu_has(X86_FEATURE_OSXSAVE)) {
458 pr_info("AVX or AES-NI instructions are not detected.\n");
459 return -ENODEV;
460 }
461
462 if (!cpu_has_xfeatures(XFEATURE_MASK_SSE | XFEATURE_MASK_YMM,
463 feature_name: &feature_name)) {
464 pr_info("CPU feature '%s' is not supported.\n", feature_name);
465 return -ENODEV;
466 }
467
468 return simd_register_skciphers_compat(algs: sm4_aesni_avx_skciphers,
469 ARRAY_SIZE(sm4_aesni_avx_skciphers),
470 simd_algs: simd_sm4_aesni_avx_skciphers);
471}
472
473static void __exit sm4_exit(void)
474{
475 simd_unregister_skciphers(algs: sm4_aesni_avx_skciphers,
476 ARRAY_SIZE(sm4_aesni_avx_skciphers),
477 simd_algs: simd_sm4_aesni_avx_skciphers);
478}
479
/* Module entry/exit hooks. */
module_init(sm4_init);
module_exit(sm4_exit);

/*
 * Module metadata.  The MODULE_ALIAS_CRYPTO entries let the crypto API load
 * this module on demand when "sm4" or "sm4-aesni-avx" is requested.
 */
MODULE_LICENSE("GPL v2");
MODULE_AUTHOR("Tianjia Zhang <tianjia.zhang@linux.alibaba.com>");
MODULE_DESCRIPTION("SM4 Cipher Algorithm, AES-NI/AVX optimized");
MODULE_ALIAS_CRYPTO("sm4");
MODULE_ALIAS_CRYPTO("sm4-aesni-avx");
488

source code of linux/arch/x86/crypto/sm4_aesni_avx_glue.c