// SPDX-License-Identifier: GPL-2.0-only
/*
 * Bit sliced AES using NEON instructions
 *
 * Copyright (C) 2016 - 2017 Linaro Ltd <ard.biesheuvel@linaro.org>
 */

#include <asm/neon.h>
#include <asm/simd.h>
#include <crypto/aes.h>
#include <crypto/ctr.h>
#include <crypto/internal/simd.h>
#include <crypto/internal/skcipher.h>
#include <crypto/scatterwalk.h>
#include <crypto/xts.h>
#include <linux/module.h>

MODULE_AUTHOR("Ard Biesheuvel <ard.biesheuvel@linaro.org>");
MODULE_LICENSE("GPL v2");

MODULE_ALIAS_CRYPTO("ecb(aes)");
MODULE_ALIAS_CRYPTO("cbc(aes)");
MODULE_ALIAS_CRYPTO("ctr(aes)");
MODULE_ALIAS_CRYPTO("xts(aes)");

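/*
 * Routines provided by the bit-sliced NEON assembler core. They take the
 * converted (bit-sliced) key schedule and process blocks in batches of up
 * to eight in parallel.
 */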
asmlinkage void aesbs_convert_key(u8 out[], u32 const rk[], int rounds);

asmlinkage void aesbs_ecb_encrypt(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks);
asmlinkage void aesbs_ecb_decrypt(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks);

asmlinkage void aesbs_cbc_decrypt(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks, u8 iv[]);

asmlinkage void aesbs_ctr_encrypt(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks, u8 iv[]);

asmlinkage void aesbs_xts_encrypt(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks, u8 iv[]);
asmlinkage void aesbs_xts_decrypt(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks, u8 iv[]);

/* borrowed from aes-neon-blk.ko */
asmlinkage void neon_aes_ecb_encrypt(u8 out[], u8 const in[], u32 const rk[],
				     int rounds, int blocks);
asmlinkage void neon_aes_cbc_encrypt(u8 out[], u8 const in[], u32 const rk[],
				     int rounds, int blocks, u8 iv[]);
asmlinkage void neon_aes_ctr_encrypt(u8 out[], u8 const in[], u32 const rk[],
				     int rounds, int bytes, u8 ctr[]);
asmlinkage void neon_aes_xts_encrypt(u8 out[], u8 const in[],
				     u32 const rk1[], int rounds, int bytes,
				     u32 const rk2[], u8 iv[], int first);
asmlinkage void neon_aes_xts_decrypt(u8 out[], u8 const in[],
				     u32 const rk1[], int rounds, int bytes,
				     u32 const rk2[], u8 iv[], int first);

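/*
 * The bit-sliced key schedule is considerably larger than the regular one:
 * rk[] below is sized for the converted form of the longest (14-round)
 * schedule that aesbs_convert_key() can emit.
 */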
struct aesbs_ctx {
	u8	rk[13 * (8 * AES_BLOCK_SIZE) + 32];
	int	rounds;
} __aligned(AES_BLOCK_SIZE);

struct aesbs_cbc_ctr_ctx {
	struct aesbs_ctx	key;
	u32			enc[AES_MAX_KEYLENGTH_U32];
};

struct aesbs_xts_ctx {
	struct aesbs_ctx	key;
	u32			twkey[AES_MAX_KEYLENGTH_U32];
	struct crypto_aes_ctx	cts;
};

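/*
 * Expand the user key with the generic library code, then convert the
 * round keys into bit-sliced form. The conversion uses NEON registers, so
 * it must run inside a kernel_neon_begin()/kernel_neon_end() pair.
 */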
static int aesbs_setkey(struct crypto_skcipher *tfm, const u8 *in_key,
			unsigned int key_len)
{
	struct aesbs_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct crypto_aes_ctx rk;
	int err;

	err = aes_expandkey(&rk, in_key, key_len);
	if (err)
		return err;

	ctx->rounds = 6 + key_len / 4;

	kernel_neon_begin();
	aesbs_convert_key(ctx->rk, rk.key_enc, ctx->rounds);
	kernel_neon_end();

	return 0;
}

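/*
 * Walk the request and pass full blocks to the bit-sliced ECB routine. For
 * all but the final step, round the block count down to a multiple of the
 * walk stride (eight blocks) so the asm code sees full batches until the
 * tail of the request.
 */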
static int __ecb_crypt(struct skcipher_request *req,
		       void (*fn)(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks))
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct aesbs_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct skcipher_walk walk;
	int err;

	err = skcipher_walk_virt(&walk, req, false);

	while (walk.nbytes >= AES_BLOCK_SIZE) {
		unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE;

		if (walk.nbytes < walk.total)
			blocks = round_down(blocks,
					    walk.stride / AES_BLOCK_SIZE);

		kernel_neon_begin();
		fn(walk.dst.virt.addr, walk.src.virt.addr, ctx->rk,
		   ctx->rounds, blocks);
		kernel_neon_end();
		err = skcipher_walk_done(&walk,
					 walk.nbytes - blocks * AES_BLOCK_SIZE);
	}

	return err;
}

static int ecb_encrypt(struct skcipher_request *req)
{
	return __ecb_crypt(req, aesbs_ecb_encrypt);
}

static int ecb_decrypt(struct skcipher_request *req)
{
	return __ecb_crypt(req, aesbs_ecb_decrypt);
}

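/*
 * CBC and CTR keep two copies of the key schedule: the bit-sliced form for
 * the eight-way bulk routines, and the regular encryption schedule for the
 * plain NEON fallbacks (CBC encryption and the CTR tail).
 */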
static int aesbs_cbc_ctr_setkey(struct crypto_skcipher *tfm, const u8 *in_key,
				unsigned int key_len)
{
	struct aesbs_cbc_ctr_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct crypto_aes_ctx rk;
	int err;

	err = aes_expandkey(&rk, in_key, key_len);
	if (err)
		return err;

	ctx->key.rounds = 6 + key_len / 4;

	memcpy(ctx->enc, rk.key_enc, sizeof(ctx->enc));

	kernel_neon_begin();
	aesbs_convert_key(ctx->key.rk, rk.key_enc, ctx->key.rounds);
	kernel_neon_end();
	memzero_explicit(&rk, sizeof(rk));

	return 0;
}

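/*
 * CBC encryption is inherently sequential (each block depends on the
 * previous ciphertext block), so bit-slicing eight blocks in parallel buys
 * nothing here; the plain NEON implementation is used instead.
 */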
static int cbc_encrypt(struct skcipher_request *req)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct aesbs_cbc_ctr_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct skcipher_walk walk;
	int err;

	err = skcipher_walk_virt(&walk, req, false);

	while (walk.nbytes >= AES_BLOCK_SIZE) {
		unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE;

		/* fall back to the non-bitsliced NEON implementation */
		kernel_neon_begin();
		neon_aes_cbc_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
				     ctx->enc, ctx->key.rounds, blocks,
				     walk.iv);
		kernel_neon_end();
		err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
	}
	return err;
}

static int cbc_decrypt(struct skcipher_request *req)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct aesbs_cbc_ctr_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct skcipher_walk walk;
	int err;

	err = skcipher_walk_virt(&walk, req, false);

	while (walk.nbytes >= AES_BLOCK_SIZE) {
		unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE;

		if (walk.nbytes < walk.total)
			blocks = round_down(blocks,
					    walk.stride / AES_BLOCK_SIZE);

		kernel_neon_begin();
		aesbs_cbc_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
				  ctx->key.rk, ctx->key.rounds, blocks,
				  walk.iv);
		kernel_neon_end();
		err = skcipher_walk_done(&walk,
					 walk.nbytes - blocks * AES_BLOCK_SIZE);
	}

	return err;
}

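/*
 * CTR mode: full batches of eight blocks go through the bit-sliced code;
 * whatever remains at the very end of the request (including a partial
 * final block, which is bounced through a stack buffer) is handled by the
 * byte-granular plain NEON routine.
 */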
static int ctr_encrypt(struct skcipher_request *req)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct aesbs_cbc_ctr_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct skcipher_walk walk;
	int err;

	err = skcipher_walk_virt(&walk, req, false);

	while (walk.nbytes > 0) {
		int blocks = (walk.nbytes / AES_BLOCK_SIZE) & ~7;
		int nbytes = walk.nbytes % (8 * AES_BLOCK_SIZE);
		const u8 *src = walk.src.virt.addr;
		u8 *dst = walk.dst.virt.addr;

		kernel_neon_begin();
		if (blocks >= 8) {
			aesbs_ctr_encrypt(dst, src, ctx->key.rk, ctx->key.rounds,
					  blocks, walk.iv);
			dst += blocks * AES_BLOCK_SIZE;
			src += blocks * AES_BLOCK_SIZE;
		}
		if (nbytes && walk.nbytes == walk.total) {
			u8 buf[AES_BLOCK_SIZE];
			u8 *d = dst;

			if (unlikely(nbytes < AES_BLOCK_SIZE))
				src = dst = memcpy(buf + sizeof(buf) - nbytes,
						   src, nbytes);

			neon_aes_ctr_encrypt(dst, src, ctx->enc, ctx->key.rounds,
					     nbytes, walk.iv);

			if (unlikely(nbytes < AES_BLOCK_SIZE))
				memcpy(d, dst, nbytes);

			nbytes = 0;
		}
		kernel_neon_end();
		err = skcipher_walk_done(&walk, nbytes);
	}
	return err;
}

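/*
 * XTS takes a double-length key: the first half encrypts the data (both a
 * bit-sliced schedule and a regular one for the ciphertext stealing tail
 * are kept), while the second half generates the tweak.
 */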
static int aesbs_xts_setkey(struct crypto_skcipher *tfm, const u8 *in_key,
			    unsigned int key_len)
{
	struct aesbs_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct crypto_aes_ctx rk;
	int err;

	err = xts_verify_key(tfm, in_key, key_len);
	if (err)
		return err;

	key_len /= 2;
	err = aes_expandkey(&ctx->cts, in_key, key_len);
	if (err)
		return err;

	err = aes_expandkey(&rk, in_key + key_len, key_len);
	if (err)
		return err;

	memcpy(ctx->twkey, rk.key_enc, sizeof(ctx->twkey));

	return aesbs_setkey(tfm, in_key, key_len);
}

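/*
 * If the request length is not a multiple of the block size, ciphertext
 * stealing applies: trim the request so the last two blocks are processed
 * together in a single walk step, run the bulk of the data through the
 * bit-sliced code in eight-block batches, and let the plain NEON XTS
 * routines handle the remainder and the stolen tail.
 */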
static int __xts_crypt(struct skcipher_request *req, bool encrypt,
		       void (*fn)(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks, u8 iv[]))
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct aesbs_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
	int tail = req->cryptlen % (8 * AES_BLOCK_SIZE);
	struct scatterlist sg_src[2], sg_dst[2];
	struct skcipher_request subreq;
	struct scatterlist *src, *dst;
	struct skcipher_walk walk;
	int nbytes, err;
	int first = 1;
	u8 *out, *in;

	if (req->cryptlen < AES_BLOCK_SIZE)
		return -EINVAL;

	/* ensure that the cts tail is covered by a single step */
	if (unlikely(tail > 0 && tail < AES_BLOCK_SIZE)) {
		int xts_blocks = DIV_ROUND_UP(req->cryptlen,
					      AES_BLOCK_SIZE) - 2;

		skcipher_request_set_tfm(&subreq, tfm);
		skcipher_request_set_callback(&subreq,
					      skcipher_request_flags(req),
					      NULL, NULL);
		skcipher_request_set_crypt(&subreq, req->src, req->dst,
					   xts_blocks * AES_BLOCK_SIZE,
					   req->iv);
		req = &subreq;
	} else {
		tail = 0;
	}

	err = skcipher_walk_virt(&walk, req, false);
	if (err)
		return err;

	while (walk.nbytes >= AES_BLOCK_SIZE) {
		int blocks = (walk.nbytes / AES_BLOCK_SIZE) & ~7;
		out = walk.dst.virt.addr;
		in = walk.src.virt.addr;
		nbytes = walk.nbytes;

		kernel_neon_begin();
		if (blocks >= 8) {
			if (first == 1)
				neon_aes_ecb_encrypt(walk.iv, walk.iv,
						     ctx->twkey,
						     ctx->key.rounds, 1);
			first = 2;

			fn(out, in, ctx->key.rk, ctx->key.rounds, blocks,
			   walk.iv);

			out += blocks * AES_BLOCK_SIZE;
			in += blocks * AES_BLOCK_SIZE;
			nbytes -= blocks * AES_BLOCK_SIZE;
		}
		if (walk.nbytes == walk.total && nbytes > 0) {
			if (encrypt)
				neon_aes_xts_encrypt(out, in, ctx->cts.key_enc,
						     ctx->key.rounds, nbytes,
						     ctx->twkey, walk.iv, first);
			else
				neon_aes_xts_decrypt(out, in, ctx->cts.key_dec,
						     ctx->key.rounds, nbytes,
						     ctx->twkey, walk.iv, first);
			nbytes = first = 0;
		}
		kernel_neon_end();
		err = skcipher_walk_done(&walk, nbytes);
	}

	if (err || likely(!tail))
		return err;

	/* handle ciphertext stealing */
	dst = src = scatterwalk_ffwd(sg_src, req->src, req->cryptlen);
	if (req->dst != req->src)
		dst = scatterwalk_ffwd(sg_dst, req->dst, req->cryptlen);

	skcipher_request_set_crypt(req, src, dst, AES_BLOCK_SIZE + tail,
				   req->iv);

	err = skcipher_walk_virt(&walk, req, false);
	if (err)
		return err;

	out = walk.dst.virt.addr;
	in = walk.src.virt.addr;
	nbytes = walk.nbytes;

	kernel_neon_begin();
	if (encrypt)
		neon_aes_xts_encrypt(out, in, ctx->cts.key_enc, ctx->key.rounds,
				     nbytes, ctx->twkey, walk.iv, first);
	else
		neon_aes_xts_decrypt(out, in, ctx->cts.key_dec, ctx->key.rounds,
				     nbytes, ctx->twkey, walk.iv, first);
	kernel_neon_end();

	return skcipher_walk_done(&walk, 0);
}

static int xts_encrypt(struct skcipher_request *req)
{
	return __xts_crypt(req, true, aesbs_xts_encrypt);
}

static int xts_decrypt(struct skcipher_request *req)
{
	return __xts_crypt(req, false, aesbs_xts_decrypt);
}

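/*
 * Priority 250 ranks these implementations above the generic C code
 * (priority 100) while leaving room for faster drivers, such as the ARMv8
 * Crypto Extensions one, to take precedence where available.
 */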
static struct skcipher_alg aes_algs[] = { {
	.base.cra_name		= "ecb(aes)",
	.base.cra_driver_name	= "ecb-aes-neonbs",
	.base.cra_priority	= 250,
	.base.cra_blocksize	= AES_BLOCK_SIZE,
	.base.cra_ctxsize	= sizeof(struct aesbs_ctx),
	.base.cra_module	= THIS_MODULE,

	.min_keysize		= AES_MIN_KEY_SIZE,
	.max_keysize		= AES_MAX_KEY_SIZE,
	.walksize		= 8 * AES_BLOCK_SIZE,
	.setkey			= aesbs_setkey,
	.encrypt		= ecb_encrypt,
	.decrypt		= ecb_decrypt,
}, {
	.base.cra_name		= "cbc(aes)",
	.base.cra_driver_name	= "cbc-aes-neonbs",
	.base.cra_priority	= 250,
	.base.cra_blocksize	= AES_BLOCK_SIZE,
	.base.cra_ctxsize	= sizeof(struct aesbs_cbc_ctr_ctx),
	.base.cra_module	= THIS_MODULE,

	.min_keysize		= AES_MIN_KEY_SIZE,
	.max_keysize		= AES_MAX_KEY_SIZE,
	.walksize		= 8 * AES_BLOCK_SIZE,
	.ivsize			= AES_BLOCK_SIZE,
	.setkey			= aesbs_cbc_ctr_setkey,
	.encrypt		= cbc_encrypt,
	.decrypt		= cbc_decrypt,
}, {
	.base.cra_name		= "ctr(aes)",
	.base.cra_driver_name	= "ctr-aes-neonbs",
	.base.cra_priority	= 250,
	.base.cra_blocksize	= 1,
	.base.cra_ctxsize	= sizeof(struct aesbs_cbc_ctr_ctx),
	.base.cra_module	= THIS_MODULE,

	.min_keysize		= AES_MIN_KEY_SIZE,
	.max_keysize		= AES_MAX_KEY_SIZE,
	.chunksize		= AES_BLOCK_SIZE,
	.walksize		= 8 * AES_BLOCK_SIZE,
	.ivsize			= AES_BLOCK_SIZE,
	.setkey			= aesbs_cbc_ctr_setkey,
	.encrypt		= ctr_encrypt,
	.decrypt		= ctr_encrypt,
}, {
	.base.cra_name		= "xts(aes)",
	.base.cra_driver_name	= "xts-aes-neonbs",
	.base.cra_priority	= 250,
	.base.cra_blocksize	= AES_BLOCK_SIZE,
	.base.cra_ctxsize	= sizeof(struct aesbs_xts_ctx),
	.base.cra_module	= THIS_MODULE,

	.min_keysize		= 2 * AES_MIN_KEY_SIZE,
	.max_keysize		= 2 * AES_MAX_KEY_SIZE,
	.walksize		= 8 * AES_BLOCK_SIZE,
	.ivsize			= AES_BLOCK_SIZE,
	.setkey			= aesbs_xts_setkey,
	.encrypt		= xts_encrypt,
	.decrypt		= xts_decrypt,
} };

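/*
 * Registration is gated on the CPU advertising Advanced SIMD (ASIMD/NEON),
 * which all of the routines above rely on.
 */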
static void aes_exit(void)
{
	crypto_unregister_skciphers(aes_algs, ARRAY_SIZE(aes_algs));
}

static int __init aes_init(void)
{
	if (!cpu_have_named_feature(ASIMD))
		return -ENODEV;

	return crypto_register_skciphers(aes_algs, ARRAY_SIZE(aes_algs));
}

module_init(aes_init);
module_exit(aes_exit);