// SPDX-License-Identifier: GPL-2.0-only
/*
 * Bit sliced AES using NEON instructions
 *
 * Copyright (C) 2017 Linaro Ltd <ard.biesheuvel@linaro.org>
 */
7
#include <asm/neon.h>
#include <asm/simd.h>
#include <crypto/aes.h>
#include <crypto/internal/skcipher.h>
#include <crypto/scatterwalk.h>
#include <crypto/xts.h>
#include <linux/module.h>
#include <linux/string.h>
#include "aes-cipher.h"
16
17MODULE_AUTHOR("Ard Biesheuvel <ard.biesheuvel@linaro.org>");
18MODULE_DESCRIPTION("Bit sliced AES using NEON instructions");
19MODULE_LICENSE("GPL v2");
20
21MODULE_ALIAS_CRYPTO("ecb(aes)");
22MODULE_ALIAS_CRYPTO("cbc(aes)");
23MODULE_ALIAS_CRYPTO("ctr(aes)");
24MODULE_ALIAS_CRYPTO("xts(aes)");
25
26asmlinkage void aesbs_convert_key(u8 out[], u32 const rk[], int rounds);
27
28asmlinkage void aesbs_ecb_encrypt(u8 out[], u8 const in[], u8 const rk[],
29 int rounds, int blocks);
30asmlinkage void aesbs_ecb_decrypt(u8 out[], u8 const in[], u8 const rk[],
31 int rounds, int blocks);
32
33asmlinkage void aesbs_cbc_decrypt(u8 out[], u8 const in[], u8 const rk[],
34 int rounds, int blocks, u8 iv[]);
35
36asmlinkage void aesbs_ctr_encrypt(u8 out[], u8 const in[], u8 const rk[],
37 int rounds, int blocks, u8 ctr[]);
38
39asmlinkage void aesbs_xts_encrypt(u8 out[], u8 const in[], u8 const rk[],
40 int rounds, int blocks, u8 iv[], int);
41asmlinkage void aesbs_xts_decrypt(u8 out[], u8 const in[], u8 const rk[],
42 int rounds, int blocks, u8 iv[], int);
43
44struct aesbs_ctx {
45 int rounds;
46 u8 rk[13 * (8 * AES_BLOCK_SIZE) + 32] __aligned(AES_BLOCK_SIZE);
47};
48
49struct aesbs_cbc_ctx {
50 struct aesbs_ctx key;
51 struct crypto_aes_ctx fallback;
52};
53
54struct aesbs_xts_ctx {
55 struct aesbs_ctx key;
56 struct crypto_aes_ctx fallback;
57 struct crypto_aes_ctx tweak_key;
58};
59
60static int aesbs_setkey(struct crypto_skcipher *tfm, const u8 *in_key,
61 unsigned int key_len)
62{
63 struct aesbs_ctx *ctx = crypto_skcipher_ctx(tfm);
64 struct crypto_aes_ctx rk;
65 int err;
66
67 err = aes_expandkey(ctx: &rk, in_key, key_len);
68 if (err)
69 return err;
70
71 ctx->rounds = 6 + key_len / 4;
72
73 kernel_neon_begin();
74 aesbs_convert_key(out: ctx->rk, rk: rk.key_enc, rounds: ctx->rounds);
75 kernel_neon_end();
76
77 return 0;
78}
79
80static int __ecb_crypt(struct skcipher_request *req,
81 void (*fn)(u8 out[], u8 const in[], u8 const rk[],
82 int rounds, int blocks))
83{
84 struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
85 struct aesbs_ctx *ctx = crypto_skcipher_ctx(tfm);
86 struct skcipher_walk walk;
87 int err;
88
89 err = skcipher_walk_virt(walk: &walk, req, atomic: false);
90
91 while (walk.nbytes >= AES_BLOCK_SIZE) {
92 unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE;
93
94 if (walk.nbytes < walk.total)
95 blocks = round_down(blocks,
96 walk.stride / AES_BLOCK_SIZE);
97
98 kernel_neon_begin();
99 fn(walk.dst.virt.addr, walk.src.virt.addr, ctx->rk,
100 ctx->rounds, blocks);
101 kernel_neon_end();
102 err = skcipher_walk_done(walk: &walk,
103 res: walk.nbytes - blocks * AES_BLOCK_SIZE);
104 }
105
106 return err;
107}
108
109static int ecb_encrypt(struct skcipher_request *req)
110{
111 return __ecb_crypt(req, fn: aesbs_ecb_encrypt);
112}
113
114static int ecb_decrypt(struct skcipher_request *req)
115{
116 return __ecb_crypt(req, fn: aesbs_ecb_decrypt);
117}
118
119static int aesbs_cbc_setkey(struct crypto_skcipher *tfm, const u8 *in_key,
120 unsigned int key_len)
121{
122 struct aesbs_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);
123 int err;
124
125 err = aes_expandkey(ctx: &ctx->fallback, in_key, key_len);
126 if (err)
127 return err;
128
129 ctx->key.rounds = 6 + key_len / 4;
130
131 kernel_neon_begin();
132 aesbs_convert_key(out: ctx->key.rk, rk: ctx->fallback.key_enc, rounds: ctx->key.rounds);
133 kernel_neon_end();
134
135 return 0;
136}
137
138static int cbc_encrypt(struct skcipher_request *req)
139{
140 struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
141 const struct aesbs_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);
142 struct skcipher_walk walk;
143 unsigned int nbytes;
144 int err;
145
146 err = skcipher_walk_virt(walk: &walk, req, atomic: false);
147
148 while ((nbytes = walk.nbytes) >= AES_BLOCK_SIZE) {
149 const u8 *src = walk.src.virt.addr;
150 u8 *dst = walk.dst.virt.addr;
151 u8 *prev = walk.iv;
152
153 do {
154 crypto_xor_cpy(dst, src1: src, src2: prev, AES_BLOCK_SIZE);
155 __aes_arm_encrypt(rk: ctx->fallback.key_enc,
156 rounds: ctx->key.rounds, in: dst, out: dst);
157 prev = dst;
158 src += AES_BLOCK_SIZE;
159 dst += AES_BLOCK_SIZE;
160 nbytes -= AES_BLOCK_SIZE;
161 } while (nbytes >= AES_BLOCK_SIZE);
162 memcpy(walk.iv, prev, AES_BLOCK_SIZE);
163 err = skcipher_walk_done(walk: &walk, res: nbytes);
164 }
165 return err;
166}
167
168static int cbc_decrypt(struct skcipher_request *req)
169{
170 struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
171 struct aesbs_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);
172 struct skcipher_walk walk;
173 int err;
174
175 err = skcipher_walk_virt(walk: &walk, req, atomic: false);
176
177 while (walk.nbytes >= AES_BLOCK_SIZE) {
178 unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE;
179
180 if (walk.nbytes < walk.total)
181 blocks = round_down(blocks,
182 walk.stride / AES_BLOCK_SIZE);
183
184 kernel_neon_begin();
185 aesbs_cbc_decrypt(out: walk.dst.virt.addr, in: walk.src.virt.addr,
186 rk: ctx->key.rk, rounds: ctx->key.rounds, blocks,
187 iv: walk.iv);
188 kernel_neon_end();
189 err = skcipher_walk_done(walk: &walk,
190 res: walk.nbytes - blocks * AES_BLOCK_SIZE);
191 }
192
193 return err;
194}
195
196static int ctr_encrypt(struct skcipher_request *req)
197{
198 struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
199 struct aesbs_ctx *ctx = crypto_skcipher_ctx(tfm);
200 struct skcipher_walk walk;
201 u8 buf[AES_BLOCK_SIZE];
202 int err;
203
204 err = skcipher_walk_virt(walk: &walk, req, atomic: false);
205
206 while (walk.nbytes > 0) {
207 const u8 *src = walk.src.virt.addr;
208 u8 *dst = walk.dst.virt.addr;
209 unsigned int bytes = walk.nbytes;
210
211 if (unlikely(bytes < AES_BLOCK_SIZE))
212 src = dst = memcpy(buf + sizeof(buf) - bytes,
213 src, bytes);
214 else if (walk.nbytes < walk.total)
215 bytes &= ~(8 * AES_BLOCK_SIZE - 1);
216
217 kernel_neon_begin();
218 aesbs_ctr_encrypt(out: dst, in: src, rk: ctx->rk, rounds: ctx->rounds, blocks: bytes, ctr: walk.iv);
219 kernel_neon_end();
220
221 if (unlikely(bytes < AES_BLOCK_SIZE))
222 memcpy(walk.dst.virt.addr,
223 buf + sizeof(buf) - bytes, bytes);
224
225 err = skcipher_walk_done(walk: &walk, res: walk.nbytes - bytes);
226 }
227
228 return err;
229}
230
231static int aesbs_xts_setkey(struct crypto_skcipher *tfm, const u8 *in_key,
232 unsigned int key_len)
233{
234 struct aesbs_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
235 int err;
236
237 err = xts_verify_key(tfm, key: in_key, keylen: key_len);
238 if (err)
239 return err;
240
241 key_len /= 2;
242 err = aes_expandkey(ctx: &ctx->fallback, in_key, key_len);
243 if (err)
244 return err;
245 err = aes_expandkey(ctx: &ctx->tweak_key, in_key: in_key + key_len, key_len);
246 if (err)
247 return err;
248
249 return aesbs_setkey(tfm, in_key, key_len);
250}
251
252static int __xts_crypt(struct skcipher_request *req, bool encrypt,
253 void (*fn)(u8 out[], u8 const in[], u8 const rk[],
254 int rounds, int blocks, u8 iv[], int))
255{
256 struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
257 struct aesbs_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
258 const int rounds = ctx->key.rounds;
259 int tail = req->cryptlen % AES_BLOCK_SIZE;
260 struct skcipher_request subreq;
261 u8 buf[2 * AES_BLOCK_SIZE];
262 struct skcipher_walk walk;
263 int err;
264
265 if (req->cryptlen < AES_BLOCK_SIZE)
266 return -EINVAL;
267
268 if (unlikely(tail)) {
269 skcipher_request_set_tfm(req: &subreq, tfm);
270 skcipher_request_set_callback(req: &subreq,
271 flags: skcipher_request_flags(req),
272 NULL, NULL);
273 skcipher_request_set_crypt(req: &subreq, src: req->src, dst: req->dst,
274 cryptlen: req->cryptlen - tail, iv: req->iv);
275 req = &subreq;
276 }
277
278 err = skcipher_walk_virt(walk: &walk, req, atomic: true);
279 if (err)
280 return err;
281
282 __aes_arm_encrypt(rk: ctx->tweak_key.key_enc, rounds, in: walk.iv, out: walk.iv);
283
284 while (walk.nbytes >= AES_BLOCK_SIZE) {
285 unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE;
286 int reorder_last_tweak = !encrypt && tail > 0;
287
288 if (walk.nbytes < walk.total) {
289 blocks = round_down(blocks,
290 walk.stride / AES_BLOCK_SIZE);
291 reorder_last_tweak = 0;
292 }
293
294 kernel_neon_begin();
295 fn(walk.dst.virt.addr, walk.src.virt.addr, ctx->key.rk,
296 rounds, blocks, walk.iv, reorder_last_tweak);
297 kernel_neon_end();
298 err = skcipher_walk_done(walk: &walk,
299 res: walk.nbytes - blocks * AES_BLOCK_SIZE);
300 }
301
302 if (err || likely(!tail))
303 return err;
304
305 /* handle ciphertext stealing */
306 scatterwalk_map_and_copy(buf, sg: req->dst, start: req->cryptlen - AES_BLOCK_SIZE,
307 AES_BLOCK_SIZE, out: 0);
308 memcpy(buf + AES_BLOCK_SIZE, buf, tail);
309 scatterwalk_map_and_copy(buf, sg: req->src, start: req->cryptlen, nbytes: tail, out: 0);
310
311 crypto_xor(dst: buf, src: req->iv, AES_BLOCK_SIZE);
312
313 if (encrypt)
314 __aes_arm_encrypt(rk: ctx->fallback.key_enc, rounds, in: buf, out: buf);
315 else
316 __aes_arm_decrypt(rk: ctx->fallback.key_dec, rounds, in: buf, out: buf);
317
318 crypto_xor(dst: buf, src: req->iv, AES_BLOCK_SIZE);
319
320 scatterwalk_map_and_copy(buf, sg: req->dst, start: req->cryptlen - AES_BLOCK_SIZE,
321 AES_BLOCK_SIZE + tail, out: 1);
322 return 0;
323}
324
325static int xts_encrypt(struct skcipher_request *req)
326{
327 return __xts_crypt(req, encrypt: true, fn: aesbs_xts_encrypt);
328}
329
330static int xts_decrypt(struct skcipher_request *req)
331{
332 return __xts_crypt(req, encrypt: false, fn: aesbs_xts_decrypt);
333}
334
335static struct skcipher_alg aes_algs[] = { {
336 .base.cra_name = "ecb(aes)",
337 .base.cra_driver_name = "ecb-aes-neonbs",
338 .base.cra_priority = 250,
339 .base.cra_blocksize = AES_BLOCK_SIZE,
340 .base.cra_ctxsize = sizeof(struct aesbs_ctx),
341 .base.cra_module = THIS_MODULE,
342
343 .min_keysize = AES_MIN_KEY_SIZE,
344 .max_keysize = AES_MAX_KEY_SIZE,
345 .walksize = 8 * AES_BLOCK_SIZE,
346 .setkey = aesbs_setkey,
347 .encrypt = ecb_encrypt,
348 .decrypt = ecb_decrypt,
349}, {
350 .base.cra_name = "cbc(aes)",
351 .base.cra_driver_name = "cbc-aes-neonbs",
352 .base.cra_priority = 250,
353 .base.cra_blocksize = AES_BLOCK_SIZE,
354 .base.cra_ctxsize = sizeof(struct aesbs_cbc_ctx),
355 .base.cra_module = THIS_MODULE,
356
357 .min_keysize = AES_MIN_KEY_SIZE,
358 .max_keysize = AES_MAX_KEY_SIZE,
359 .walksize = 8 * AES_BLOCK_SIZE,
360 .ivsize = AES_BLOCK_SIZE,
361 .setkey = aesbs_cbc_setkey,
362 .encrypt = cbc_encrypt,
363 .decrypt = cbc_decrypt,
364}, {
365 .base.cra_name = "ctr(aes)",
366 .base.cra_driver_name = "ctr-aes-neonbs",
367 .base.cra_priority = 250,
368 .base.cra_blocksize = 1,
369 .base.cra_ctxsize = sizeof(struct aesbs_ctx),
370 .base.cra_module = THIS_MODULE,
371
372 .min_keysize = AES_MIN_KEY_SIZE,
373 .max_keysize = AES_MAX_KEY_SIZE,
374 .chunksize = AES_BLOCK_SIZE,
375 .walksize = 8 * AES_BLOCK_SIZE,
376 .ivsize = AES_BLOCK_SIZE,
377 .setkey = aesbs_setkey,
378 .encrypt = ctr_encrypt,
379 .decrypt = ctr_encrypt,
380}, {
381 .base.cra_name = "xts(aes)",
382 .base.cra_driver_name = "xts-aes-neonbs",
383 .base.cra_priority = 250,
384 .base.cra_blocksize = AES_BLOCK_SIZE,
385 .base.cra_ctxsize = sizeof(struct aesbs_xts_ctx),
386 .base.cra_module = THIS_MODULE,
387
388 .min_keysize = 2 * AES_MIN_KEY_SIZE,
389 .max_keysize = 2 * AES_MAX_KEY_SIZE,
390 .walksize = 8 * AES_BLOCK_SIZE,
391 .ivsize = AES_BLOCK_SIZE,
392 .setkey = aesbs_xts_setkey,
393 .encrypt = xts_encrypt,
394 .decrypt = xts_decrypt,
395} };
396
397static void aes_exit(void)
398{
399 crypto_unregister_skciphers(algs: aes_algs, ARRAY_SIZE(aes_algs));
400}
401
402static int __init aes_init(void)
403{
404 if (!(elf_hwcap & HWCAP_NEON))
405 return -ENODEV;
406
407 return crypto_register_skciphers(algs: aes_algs, ARRAY_SIZE(aes_algs));
408}
409
/* Hook the load/unload handlers into the module machinery. */
module_init(aes_init);
module_exit(aes_exit);
412

/* Source: linux/arch/arm/crypto/aes-neonbs-glue.c */