// SPDX-License-Identifier: GPL-2.0-only
/*
 * Bit sliced AES using NEON instructions
 *
 * Copyright (C) 2016 - 2017 Linaro Ltd <ard.biesheuvel@linaro.org>
 */

#include <asm/neon.h>
#include <asm/simd.h>
#include <crypto/aes.h>
#include <crypto/ctr.h>
#include <crypto/internal/simd.h>
#include <crypto/internal/skcipher.h>
#include <crypto/scatterwalk.h>
#include <crypto/xts.h>
#include <linux/module.h>

MODULE_AUTHOR("Ard Biesheuvel <ard.biesheuvel@linaro.org>");
MODULE_DESCRIPTION("Bit sliced AES using NEON instructions");
MODULE_LICENSE("GPL v2");

MODULE_ALIAS_CRYPTO("ecb(aes)");
MODULE_ALIAS_CRYPTO("cbc(aes)");
MODULE_ALIAS_CRYPTO("ctr(aes)");
MODULE_ALIAS_CRYPTO("xts(aes)");

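/*
 * The transforms below are implemented by the accompanying bit-sliced NEON
 * assembly (aes-neonbs-core.S), which processes up to eight blocks at a time.
 */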
asmlinkage void aesbs_convert_key(u8 out[], u32 const rk[], int rounds);

asmlinkage void aesbs_ecb_encrypt(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks);
asmlinkage void aesbs_ecb_decrypt(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks);

asmlinkage void aesbs_cbc_decrypt(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks, u8 iv[]);

asmlinkage void aesbs_ctr_encrypt(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks, u8 iv[]);

asmlinkage void aesbs_xts_encrypt(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks, u8 iv[]);
asmlinkage void aesbs_xts_decrypt(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks, u8 iv[]);

/* borrowed from aes-neon-blk.ko */
asmlinkage void neon_aes_ecb_encrypt(u8 out[], u8 const in[], u32 const rk[],
				     int rounds, int blocks);
asmlinkage void neon_aes_cbc_encrypt(u8 out[], u8 const in[], u32 const rk[],
				     int rounds, int blocks, u8 iv[]);
asmlinkage void neon_aes_ctr_encrypt(u8 out[], u8 const in[], u32 const rk[],
				     int rounds, int bytes, u8 ctr[]);
asmlinkage void neon_aes_xts_encrypt(u8 out[], u8 const in[],
				     u32 const rk1[], int rounds, int bytes,
				     u32 const rk2[], u8 iv[], int first);
asmlinkage void neon_aes_xts_decrypt(u8 out[], u8 const in[],
				     u32 const rk1[], int rounds, int bytes,
				     u32 const rk2[], u8 iv[], int first);

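/*
 * rk holds the round keys in the bit-sliced layout produced by
 * aesbs_convert_key().
 */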
struct aesbs_ctx {
	u8	rk[13 * (8 * AES_BLOCK_SIZE) + 32];
	int	rounds;
} __aligned(AES_BLOCK_SIZE);

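/*
 * CBC and CTR also keep a copy of the regular AES encryption key schedule
 * (enc), so that CBC encryption and CTR tail handling can fall back to the
 * non-bitsliced NEON implementation.
 */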
struct aesbs_cbc_ctr_ctx {
	struct aesbs_ctx	key;
	u32			enc[AES_MAX_KEYLENGTH_U32];
};

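/*
 * XTS additionally keeps the expanded tweak key (twkey) and a regular key
 * schedule (cts) that is used for the ciphertext stealing tail.
 */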
struct aesbs_xts_ctx {
	struct aesbs_ctx	key;
	u32			twkey[AES_MAX_KEYLENGTH_U32];
	struct crypto_aes_ctx	cts;
};

static int aesbs_setkey(struct crypto_skcipher *tfm, const u8 *in_key,
			unsigned int key_len)
{
	struct aesbs_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct crypto_aes_ctx rk;
	int err;

	err = aes_expandkey(&rk, in_key, key_len);
	if (err)
		return err;

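	/* AES uses 10, 12 or 14 rounds for 128, 192 or 256 bit keys */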
	ctx->rounds = 6 + key_len / 4;

	kernel_neon_begin();
	aesbs_convert_key(ctx->rk, rk.key_enc, ctx->rounds);
	kernel_neon_end();

	return 0;
}

static int __ecb_crypt(struct skcipher_request *req,
		       void (*fn)(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks))
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct aesbs_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct skcipher_walk walk;
	int err;

	err = skcipher_walk_virt(&walk, req, false);

	while (walk.nbytes >= AES_BLOCK_SIZE) {
		unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE;

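		/*
		 * On anything but the final chunk, only process whole
		 * multiples of the walk stride (eight blocks).
		 */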
		if (walk.nbytes < walk.total)
			blocks = round_down(blocks,
					    walk.stride / AES_BLOCK_SIZE);

		kernel_neon_begin();
		fn(walk.dst.virt.addr, walk.src.virt.addr, ctx->rk,
		   ctx->rounds, blocks);
		kernel_neon_end();
		err = skcipher_walk_done(&walk,
					 walk.nbytes - blocks * AES_BLOCK_SIZE);
	}

	return err;
}

static int ecb_encrypt(struct skcipher_request *req)
{
	return __ecb_crypt(req, aesbs_ecb_encrypt);
}

static int ecb_decrypt(struct skcipher_request *req)
{
	return __ecb_crypt(req, aesbs_ecb_decrypt);
}

static int aesbs_cbc_ctr_setkey(struct crypto_skcipher *tfm, const u8 *in_key,
				unsigned int key_len)
{
	struct aesbs_cbc_ctr_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct crypto_aes_ctx rk;
	int err;

	err = aes_expandkey(&rk, in_key, key_len);
	if (err)
		return err;

	ctx->key.rounds = 6 + key_len / 4;

	memcpy(ctx->enc, rk.key_enc, sizeof(ctx->enc));

	kernel_neon_begin();
	aesbs_convert_key(ctx->key.rk, rk.key_enc, ctx->key.rounds);
	kernel_neon_end();
	memzero_explicit(&rk, sizeof(rk));

	return 0;
}

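/*
 * CBC encryption is inherently sequential and cannot use the eight-way
 * bit-sliced code, so it is handled entirely by the plain NEON routine.
 */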
static int cbc_encrypt(struct skcipher_request *req)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct aesbs_cbc_ctr_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct skcipher_walk walk;
	int err;

	err = skcipher_walk_virt(&walk, req, false);

	while (walk.nbytes >= AES_BLOCK_SIZE) {
		unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE;

		/* fall back to the non-bitsliced NEON implementation */
		kernel_neon_begin();
		neon_aes_cbc_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
				     ctx->enc, ctx->key.rounds, blocks,
				     walk.iv);
		kernel_neon_end();
		err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
	}
	return err;
}

static int cbc_decrypt(struct skcipher_request *req)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct aesbs_cbc_ctr_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct skcipher_walk walk;
	int err;

	err = skcipher_walk_virt(&walk, req, false);

	while (walk.nbytes >= AES_BLOCK_SIZE) {
		unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE;

		if (walk.nbytes < walk.total)
			blocks = round_down(blocks,
					    walk.stride / AES_BLOCK_SIZE);

		kernel_neon_begin();
		aesbs_cbc_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
				  ctx->key.rk, ctx->key.rounds, blocks,
				  walk.iv);
		kernel_neon_end();
		err = skcipher_walk_done(&walk,
					 walk.nbytes - blocks * AES_BLOCK_SIZE);
	}

	return err;
}

static int ctr_encrypt(struct skcipher_request *req)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct aesbs_cbc_ctr_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct skcipher_walk walk;
	int err;

	err = skcipher_walk_virt(&walk, req, false);

	while (walk.nbytes > 0) {
		int blocks = (walk.nbytes / AES_BLOCK_SIZE) & ~7;
		int nbytes = walk.nbytes % (8 * AES_BLOCK_SIZE);
		const u8 *src = walk.src.virt.addr;
		u8 *dst = walk.dst.virt.addr;

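		/*
		 * Whole groups of eight blocks go through the bit-sliced
		 * code; any remainder at the end of the request (including
		 * a partial final block, staged in an on-stack buffer) is
		 * handled by the plain NEON CTR routine.
		 */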
		kernel_neon_begin();
		if (blocks >= 8) {
			aesbs_ctr_encrypt(dst, src, ctx->key.rk, ctx->key.rounds,
					  blocks, walk.iv);
			dst += blocks * AES_BLOCK_SIZE;
			src += blocks * AES_BLOCK_SIZE;
		}
		if (nbytes && walk.nbytes == walk.total) {
			u8 buf[AES_BLOCK_SIZE];
			u8 *d = dst;

			if (unlikely(nbytes < AES_BLOCK_SIZE))
				src = dst = memcpy(buf + sizeof(buf) - nbytes,
						   src, nbytes);

			neon_aes_ctr_encrypt(dst, src, ctx->enc, ctx->key.rounds,
					     nbytes, walk.iv);

			if (unlikely(nbytes < AES_BLOCK_SIZE))
				memcpy(d, dst, nbytes);

			nbytes = 0;
		}
		kernel_neon_end();
		err = skcipher_walk_done(&walk, nbytes);
	}
	return err;
}

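/*
 * The XTS key consists of two AES keys of equal size: the first half is the
 * data key (converted to bit-sliced form, and also kept as a regular key
 * schedule for the ciphertext stealing tail), the second half is expanded
 * into the tweak key.
 */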
static int aesbs_xts_setkey(struct crypto_skcipher *tfm, const u8 *in_key,
			    unsigned int key_len)
{
	struct aesbs_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct crypto_aes_ctx rk;
	int err;

	err = xts_verify_key(tfm, in_key, key_len);
	if (err)
		return err;

	key_len /= 2;
	err = aes_expandkey(&ctx->cts, in_key, key_len);
	if (err)
		return err;

	err = aes_expandkey(&rk, in_key + key_len, key_len);
	if (err)
		return err;

	memcpy(ctx->twkey, rk.key_enc, sizeof(ctx->twkey));

	return aesbs_setkey(tfm, in_key, key_len);
}

static int __xts_crypt(struct skcipher_request *req, bool encrypt,
		       void (*fn)(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks, u8 iv[]))
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct aesbs_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
	int tail = req->cryptlen % (8 * AES_BLOCK_SIZE);
	struct scatterlist sg_src[2], sg_dst[2];
	struct skcipher_request subreq;
	struct scatterlist *src, *dst;
	struct skcipher_walk walk;
	int nbytes, err;
	int first = 1;
	const u8 *in;
	u8 *out;

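	/*
	 * The bulk of the input is processed eight blocks at a time by the
	 * bit-sliced code. When the length is not a multiple of the block
	 * size, the last two blocks are split off and handled separately
	 * with ciphertext stealing by the plain NEON XTS routines.
	 */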
	if (req->cryptlen < AES_BLOCK_SIZE)
		return -EINVAL;

	/* ensure that the cts tail is covered by a single step */
	if (unlikely(tail > 0 && tail < AES_BLOCK_SIZE)) {
		int xts_blocks = DIV_ROUND_UP(req->cryptlen,
					      AES_BLOCK_SIZE) - 2;

		skcipher_request_set_tfm(&subreq, tfm);
		skcipher_request_set_callback(&subreq,
					      skcipher_request_flags(req),
					      NULL, NULL);
		skcipher_request_set_crypt(&subreq, req->src, req->dst,
					   xts_blocks * AES_BLOCK_SIZE,
					   req->iv);
		req = &subreq;
	} else {
		tail = 0;
	}

	err = skcipher_walk_virt(&walk, req, false);
	if (err)
		return err;

	while (walk.nbytes >= AES_BLOCK_SIZE) {
		int blocks = (walk.nbytes / AES_BLOCK_SIZE) & ~7;
		out = walk.dst.virt.addr;
		in = walk.src.virt.addr;
		nbytes = walk.nbytes;

		kernel_neon_begin();
		if (blocks >= 8) {
			if (first == 1)
				neon_aes_ecb_encrypt(walk.iv, walk.iv,
						     ctx->twkey,
						     ctx->key.rounds, 1);
			first = 2;

			fn(out, in, ctx->key.rk, ctx->key.rounds, blocks,
			   walk.iv);

			out += blocks * AES_BLOCK_SIZE;
			in += blocks * AES_BLOCK_SIZE;
			nbytes -= blocks * AES_BLOCK_SIZE;
		}
		if (walk.nbytes == walk.total && nbytes > 0) {
			if (encrypt)
				neon_aes_xts_encrypt(out, in, ctx->cts.key_enc,
						     ctx->key.rounds, nbytes,
						     ctx->twkey, walk.iv, first);
			else
				neon_aes_xts_decrypt(out, in, ctx->cts.key_dec,
						     ctx->key.rounds, nbytes,
						     ctx->twkey, walk.iv, first);
			nbytes = first = 0;
		}
		kernel_neon_end();
		err = skcipher_walk_done(&walk, nbytes);
	}

	if (err || likely(!tail))
		return err;

	/* handle ciphertext stealing */
	dst = src = scatterwalk_ffwd(sg_src, req->src, req->cryptlen);
	if (req->dst != req->src)
		dst = scatterwalk_ffwd(sg_dst, req->dst, req->cryptlen);

	skcipher_request_set_crypt(req, src, dst, AES_BLOCK_SIZE + tail,
				   req->iv);

	err = skcipher_walk_virt(&walk, req, false);
	if (err)
		return err;

	out = walk.dst.virt.addr;
	in = walk.src.virt.addr;
	nbytes = walk.nbytes;

	kernel_neon_begin();
	if (encrypt)
		neon_aes_xts_encrypt(out, in, ctx->cts.key_enc, ctx->key.rounds,
				     nbytes, ctx->twkey, walk.iv, first);
	else
		neon_aes_xts_decrypt(out, in, ctx->cts.key_dec, ctx->key.rounds,
				     nbytes, ctx->twkey, walk.iv, first);
	kernel_neon_end();

	return skcipher_walk_done(&walk, 0);
}

static int xts_encrypt(struct skcipher_request *req)
{
	return __xts_crypt(req, true, aesbs_xts_encrypt);
}

static int xts_decrypt(struct skcipher_request *req)
{
	return __xts_crypt(req, false, aesbs_xts_decrypt);
}

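/*
 * Priority 250 keeps these algorithms below the (higher priority) ARMv8
 * Crypto Extensions based drivers, so the bit-sliced code is only selected
 * when those are not available.
 */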
static struct skcipher_alg aes_algs[] = { {
	.base.cra_name		= "ecb(aes)",
	.base.cra_driver_name	= "ecb-aes-neonbs",
	.base.cra_priority	= 250,
	.base.cra_blocksize	= AES_BLOCK_SIZE,
	.base.cra_ctxsize	= sizeof(struct aesbs_ctx),
	.base.cra_module	= THIS_MODULE,

	.min_keysize		= AES_MIN_KEY_SIZE,
	.max_keysize		= AES_MAX_KEY_SIZE,
	.walksize		= 8 * AES_BLOCK_SIZE,
	.setkey			= aesbs_setkey,
	.encrypt		= ecb_encrypt,
	.decrypt		= ecb_decrypt,
}, {
	.base.cra_name		= "cbc(aes)",
	.base.cra_driver_name	= "cbc-aes-neonbs",
	.base.cra_priority	= 250,
	.base.cra_blocksize	= AES_BLOCK_SIZE,
	.base.cra_ctxsize	= sizeof(struct aesbs_cbc_ctr_ctx),
	.base.cra_module	= THIS_MODULE,

	.min_keysize		= AES_MIN_KEY_SIZE,
	.max_keysize		= AES_MAX_KEY_SIZE,
	.walksize		= 8 * AES_BLOCK_SIZE,
	.ivsize			= AES_BLOCK_SIZE,
	.setkey			= aesbs_cbc_ctr_setkey,
	.encrypt		= cbc_encrypt,
	.decrypt		= cbc_decrypt,
}, {
	.base.cra_name		= "ctr(aes)",
	.base.cra_driver_name	= "ctr-aes-neonbs",
	.base.cra_priority	= 250,
	.base.cra_blocksize	= 1,
	.base.cra_ctxsize	= sizeof(struct aesbs_cbc_ctr_ctx),
	.base.cra_module	= THIS_MODULE,

	.min_keysize		= AES_MIN_KEY_SIZE,
	.max_keysize		= AES_MAX_KEY_SIZE,
	.chunksize		= AES_BLOCK_SIZE,
	.walksize		= 8 * AES_BLOCK_SIZE,
	.ivsize			= AES_BLOCK_SIZE,
	.setkey			= aesbs_cbc_ctr_setkey,
	.encrypt		= ctr_encrypt,
	.decrypt		= ctr_encrypt,
}, {
	.base.cra_name		= "xts(aes)",
	.base.cra_driver_name	= "xts-aes-neonbs",
	.base.cra_priority	= 250,
	.base.cra_blocksize	= AES_BLOCK_SIZE,
	.base.cra_ctxsize	= sizeof(struct aesbs_xts_ctx),
	.base.cra_module	= THIS_MODULE,

	.min_keysize		= 2 * AES_MIN_KEY_SIZE,
	.max_keysize		= 2 * AES_MAX_KEY_SIZE,
	.walksize		= 8 * AES_BLOCK_SIZE,
	.ivsize			= AES_BLOCK_SIZE,
	.setkey			= aesbs_xts_setkey,
	.encrypt		= xts_encrypt,
	.decrypt		= xts_decrypt,
} };

static void aes_exit(void)
{
	crypto_unregister_skciphers(aes_algs, ARRAY_SIZE(aes_algs));
}

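/* the bit-sliced NEON code requires Advanced SIMD (ASIMD) support */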
static int __init aes_init(void)
{
	if (!cpu_have_named_feature(ASIMD))
		return -ENODEV;

	return crypto_register_skciphers(aes_algs, ARRAY_SIZE(aes_algs));
}

module_init(aes_init);
module_exit(aes_exit);