/*
 * aes-ce-glue.c - wrapper code for ARMv8 AES
 *
 * Copyright (C) 2015 Linaro Ltd <ard.biesheuvel@linaro.org>
 *
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License version 2 as
 * published by the Free Software Foundation.
 */

#include <asm/hwcap.h>
#include <asm/neon.h>
#include <crypto/aes.h>
#include <crypto/ablk_helper.h>
#include <crypto/algapi.h>
#include <linux/module.h>
#include <crypto/xts.h>

MODULE_DESCRIPTION("AES-ECB/CBC/CTR/XTS using ARMv8 Crypto Extensions");
MODULE_AUTHOR("Ard Biesheuvel <ard.biesheuvel@linaro.org>");
MODULE_LICENSE("GPL v2");

/* defined in aes-ce-core.S */
asmlinkage u32 ce_aes_sub(u32 input);
asmlinkage void ce_aes_invert(void *dst, void *src);

asmlinkage void ce_aes_ecb_encrypt(u8 out[], u8 const in[], u8 const rk[],
                                   int rounds, int blocks);
asmlinkage void ce_aes_ecb_decrypt(u8 out[], u8 const in[], u8 const rk[],
                                   int rounds, int blocks);

asmlinkage void ce_aes_cbc_encrypt(u8 out[], u8 const in[], u8 const rk[],
                                   int rounds, int blocks, u8 iv[]);
asmlinkage void ce_aes_cbc_decrypt(u8 out[], u8 const in[], u8 const rk[],
                                   int rounds, int blocks, u8 iv[]);

asmlinkage void ce_aes_ctr_encrypt(u8 out[], u8 const in[], u8 const rk[],
                                   int rounds, int blocks, u8 ctr[]);

asmlinkage void ce_aes_xts_encrypt(u8 out[], u8 const in[], u8 const rk1[],
                                   int rounds, int blocks, u8 iv[],
                                   u8 const rk2[], int first);
asmlinkage void ce_aes_xts_decrypt(u8 out[], u8 const in[], u8 const rk1[],
                                   int rounds, int blocks, u8 iv[],
                                   u8 const rk2[], int first);

struct aes_block {
        u8 b[AES_BLOCK_SIZE];
};

static int num_rounds(struct crypto_aes_ctx *ctx)
{
        /*
         * # of rounds specified by AES:
         * 128 bit key          10 rounds
         * 192 bit key          12 rounds
         * 256 bit key          14 rounds
         * => n byte key        => 6 + (n/4) rounds
         */
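        /* e.g. a 32 byte (256 bit) key gives 6 + 32/4 = 14 rounds */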
        return 6 + ctx->key_length / 4;
}

static int ce_aes_expandkey(struct crypto_aes_ctx *ctx, const u8 *in_key,
                            unsigned int key_len)
{
        /*
         * The AES key schedule round constants
         */
        static u8 const rcon[] = {
                0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x1b, 0x36,
        };

        u32 kwords = key_len / sizeof(u32);
        struct aes_block *key_enc, *key_dec;
        int i, j;

        if (key_len != AES_KEYSIZE_128 &&
            key_len != AES_KEYSIZE_192 &&
            key_len != AES_KEYSIZE_256)
                return -EINVAL;

        memcpy(ctx->key_enc, in_key, key_len);
        ctx->key_length = key_len;

        kernel_neon_begin();
        for (i = 0; i < sizeof(rcon); i++) {
                u32 *rki = ctx->key_enc + (i * kwords);
                u32 *rko = rki + kwords;

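                /*
                 * The round key words are kept in CPU native byte order, so
                 * RotWord and the round constant placement differ between
                 * endiannesses: little-endian has the first key byte in the
                 * least significant byte of the word (ror32, rcon in the low
                 * byte), big-endian in the most significant byte (rol32,
                 * rcon shifted into the top byte).
                 */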
#ifndef CONFIG_CPU_BIG_ENDIAN
                rko[0] = ror32(ce_aes_sub(rki[kwords - 1]), 8);
                rko[0] = rko[0] ^ rki[0] ^ rcon[i];
#else
                rko[0] = rol32(ce_aes_sub(rki[kwords - 1]), 8);
                rko[0] = rko[0] ^ rki[0] ^ (rcon[i] << 24);
#endif
                rko[1] = rko[0] ^ rki[1];
                rko[2] = rko[1] ^ rki[2];
                rko[3] = rko[2] ^ rki[3];

                if (key_len == AES_KEYSIZE_192) {
                        if (i >= 7)
                                break;
                        rko[4] = rko[3] ^ rki[4];
                        rko[5] = rko[4] ^ rki[5];
                } else if (key_len == AES_KEYSIZE_256) {
                        if (i >= 6)
                                break;
                        rko[4] = ce_aes_sub(rko[3]) ^ rki[4];
                        rko[5] = rko[4] ^ rki[5];
                        rko[6] = rko[5] ^ rki[6];
                        rko[7] = rko[6] ^ rki[7];
                }
        }

        /*
         * Generate the decryption keys for the Equivalent Inverse Cipher.
         * This involves reversing the order of the round keys, and applying
         * the Inverse Mix Columns transformation on all but the first and
         * the last one.
         */
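        /*
         * e.g. for a 128 bit key (10 rounds): key_dec[0] = key_enc[10],
         * key_dec[1..9] = InvMixColumns(key_enc[9..1]) and
         * key_dec[10] = key_enc[0].
         */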
        key_enc = (struct aes_block *)ctx->key_enc;
        key_dec = (struct aes_block *)ctx->key_dec;
        j = num_rounds(ctx);

        key_dec[0] = key_enc[j];
        for (i = 1, j--; j > 0; i++, j--)
                ce_aes_invert(key_dec + i, key_enc + j);
        key_dec[i] = key_enc[0];

        kernel_neon_end();
        return 0;
}

static int ce_aes_setkey(struct crypto_tfm *tfm, const u8 *in_key,
                         unsigned int key_len)
{
        struct crypto_aes_ctx *ctx = crypto_tfm_ctx(tfm);
        int ret;

        ret = ce_aes_expandkey(ctx, in_key, key_len);
        if (!ret)
                return 0;

        tfm->crt_flags |= CRYPTO_TFM_RES_BAD_KEY_LEN;
        return -EINVAL;
}

struct crypto_aes_xts_ctx {
        struct crypto_aes_ctx key1;
        struct crypto_aes_ctx __aligned(8) key2;
};

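/*
 * XTS takes a double-length key: the first half is expanded into key1 and
 * used for the data blocks, the second half into key2 and used to encrypt
 * the tweak derived from the IV.
 */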
static int xts_set_key(struct crypto_tfm *tfm, const u8 *in_key,
                       unsigned int key_len)
{
        struct crypto_aes_xts_ctx *ctx = crypto_tfm_ctx(tfm);
        int ret;

        ret = xts_check_key(tfm, in_key, key_len);
        if (ret)
                return ret;

        ret = ce_aes_expandkey(&ctx->key1, in_key, key_len / 2);
        if (!ret)
                ret = ce_aes_expandkey(&ctx->key2, &in_key[key_len / 2],
                                       key_len / 2);
        if (!ret)
                return 0;

        tfm->crt_flags |= CRYPTO_TFM_RES_BAD_KEY_LEN;
        return -EINVAL;
}

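/*
 * The mode handlers below all follow the same pattern: clear the MAY_SLEEP
 * flag (the scatterlist walk must not sleep while the NEON unit is held via
 * kernel_neon_begin()), walk the input in virtually mapped chunks, process
 * all whole blocks with the NEON core routines, and hand any remainder back
 * to the walk.
 */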
static int ecb_encrypt(struct blkcipher_desc *desc, struct scatterlist *dst,
                       struct scatterlist *src, unsigned int nbytes)
{
        struct crypto_aes_ctx *ctx = crypto_blkcipher_ctx(desc->tfm);
        struct blkcipher_walk walk;
        unsigned int blocks;
        int err;

        desc->flags &= ~CRYPTO_TFM_REQ_MAY_SLEEP;
        blkcipher_walk_init(&walk, dst, src, nbytes);
        err = blkcipher_walk_virt(desc, &walk);

        kernel_neon_begin();
        while ((blocks = (walk.nbytes / AES_BLOCK_SIZE))) {
                ce_aes_ecb_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
                                   (u8 *)ctx->key_enc, num_rounds(ctx), blocks);
                err = blkcipher_walk_done(desc, &walk,
                                          walk.nbytes % AES_BLOCK_SIZE);
        }
        kernel_neon_end();
        return err;
}

static int ecb_decrypt(struct blkcipher_desc *desc, struct scatterlist *dst,
                       struct scatterlist *src, unsigned int nbytes)
{
        struct crypto_aes_ctx *ctx = crypto_blkcipher_ctx(desc->tfm);
        struct blkcipher_walk walk;
        unsigned int blocks;
        int err;

        desc->flags &= ~CRYPTO_TFM_REQ_MAY_SLEEP;
        blkcipher_walk_init(&walk, dst, src, nbytes);
        err = blkcipher_walk_virt(desc, &walk);

        kernel_neon_begin();
        while ((blocks = (walk.nbytes / AES_BLOCK_SIZE))) {
                ce_aes_ecb_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
                                   (u8 *)ctx->key_dec, num_rounds(ctx), blocks);
                err = blkcipher_walk_done(desc, &walk,
                                          walk.nbytes % AES_BLOCK_SIZE);
        }
        kernel_neon_end();
        return err;
}

static int cbc_encrypt(struct blkcipher_desc *desc, struct scatterlist *dst,
                       struct scatterlist *src, unsigned int nbytes)
{
        struct crypto_aes_ctx *ctx = crypto_blkcipher_ctx(desc->tfm);
        struct blkcipher_walk walk;
        unsigned int blocks;
        int err;

        desc->flags &= ~CRYPTO_TFM_REQ_MAY_SLEEP;
        blkcipher_walk_init(&walk, dst, src, nbytes);
        err = blkcipher_walk_virt(desc, &walk);

        kernel_neon_begin();
        while ((blocks = (walk.nbytes / AES_BLOCK_SIZE))) {
                ce_aes_cbc_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
                                   (u8 *)ctx->key_enc, num_rounds(ctx), blocks,
                                   walk.iv);
                err = blkcipher_walk_done(desc, &walk,
                                          walk.nbytes % AES_BLOCK_SIZE);
        }
        kernel_neon_end();
        return err;
}

static int cbc_decrypt(struct blkcipher_desc *desc, struct scatterlist *dst,
                       struct scatterlist *src, unsigned int nbytes)
{
        struct crypto_aes_ctx *ctx = crypto_blkcipher_ctx(desc->tfm);
        struct blkcipher_walk walk;
        unsigned int blocks;
        int err;

        desc->flags &= ~CRYPTO_TFM_REQ_MAY_SLEEP;
        blkcipher_walk_init(&walk, dst, src, nbytes);
        err = blkcipher_walk_virt(desc, &walk);

        kernel_neon_begin();
        while ((blocks = (walk.nbytes / AES_BLOCK_SIZE))) {
                ce_aes_cbc_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
                                   (u8 *)ctx->key_dec, num_rounds(ctx), blocks,
                                   walk.iv);
                err = blkcipher_walk_done(desc, &walk,
                                          walk.nbytes % AES_BLOCK_SIZE);
        }
        kernel_neon_end();
        return err;
}

static int ctr_encrypt(struct blkcipher_desc *desc, struct scatterlist *dst,
                       struct scatterlist *src, unsigned int nbytes)
{
        struct crypto_aes_ctx *ctx = crypto_blkcipher_ctx(desc->tfm);
        struct blkcipher_walk walk;
        int err, blocks;

        desc->flags &= ~CRYPTO_TFM_REQ_MAY_SLEEP;
        blkcipher_walk_init(&walk, dst, src, nbytes);
        err = blkcipher_walk_virt_block(desc, &walk, AES_BLOCK_SIZE);

        kernel_neon_begin();
        while ((blocks = (walk.nbytes / AES_BLOCK_SIZE))) {
                ce_aes_ctr_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
                                   (u8 *)ctx->key_enc, num_rounds(ctx), blocks,
                                   walk.iv);
                nbytes -= blocks * AES_BLOCK_SIZE;
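                /*
                 * If only a partial block remains, fall through to the tail
                 * handling below instead of completing the walk here.
                 */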
                if (nbytes && nbytes == walk.nbytes % AES_BLOCK_SIZE)
                        break;
                err = blkcipher_walk_done(desc, &walk,
                                          walk.nbytes % AES_BLOCK_SIZE);
        }
        if (walk.nbytes % AES_BLOCK_SIZE) {
                u8 *tdst = walk.dst.virt.addr + blocks * AES_BLOCK_SIZE;
                u8 *tsrc = walk.src.virt.addr + blocks * AES_BLOCK_SIZE;
                u8 __aligned(8) tail[AES_BLOCK_SIZE];
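                /*
                 * Encrypt the final chunk into an aligned stack buffer and
                 * copy out only nbytes of it, so a whole-block store by the
                 * core routine cannot overrun the destination.
                 */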

                /*
                 * Minimum alignment is 8 bytes, so if nbytes is <= 8, we need
                 * to tell aes_ctr_encrypt() to only read half a block.
                 */
                blocks = (nbytes <= 8) ? -1 : 1;

                ce_aes_ctr_encrypt(tail, tsrc, (u8 *)ctx->key_enc,
                                   num_rounds(ctx), blocks, walk.iv);
                memcpy(tdst, tail, nbytes);
                err = blkcipher_walk_done(desc, &walk, 0);
        }
        kernel_neon_end();

        return err;
}

static int xts_encrypt(struct blkcipher_desc *desc, struct scatterlist *dst,
                       struct scatterlist *src, unsigned int nbytes)
{
        struct crypto_aes_xts_ctx *ctx = crypto_blkcipher_ctx(desc->tfm);
        int err, first, rounds = num_rounds(&ctx->key1);
        struct blkcipher_walk walk;
        unsigned int blocks;

        desc->flags &= ~CRYPTO_TFM_REQ_MAY_SLEEP;
        blkcipher_walk_init(&walk, dst, src, nbytes);
        err = blkcipher_walk_virt(desc, &walk);

        kernel_neon_begin();
        for (first = 1; (blocks = (walk.nbytes / AES_BLOCK_SIZE)); first = 0) {
                ce_aes_xts_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
                                   (u8 *)ctx->key1.key_enc, rounds, blocks,
                                   walk.iv, (u8 *)ctx->key2.key_enc, first);
                err = blkcipher_walk_done(desc, &walk,
                                          walk.nbytes % AES_BLOCK_SIZE);
        }
        kernel_neon_end();

        return err;
}

static int xts_decrypt(struct blkcipher_desc *desc, struct scatterlist *dst,
                       struct scatterlist *src, unsigned int nbytes)
{
        struct crypto_aes_xts_ctx *ctx = crypto_blkcipher_ctx(desc->tfm);
        int err, first, rounds = num_rounds(&ctx->key1);
        struct blkcipher_walk walk;
        unsigned int blocks;

        desc->flags &= ~CRYPTO_TFM_REQ_MAY_SLEEP;
        blkcipher_walk_init(&walk, dst, src, nbytes);
        err = blkcipher_walk_virt(desc, &walk);

        kernel_neon_begin();
        for (first = 1; (blocks = (walk.nbytes / AES_BLOCK_SIZE)); first = 0) {
                ce_aes_xts_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
                                   (u8 *)ctx->key1.key_dec, rounds, blocks,
                                   walk.iv, (u8 *)ctx->key2.key_enc, first);
                err = blkcipher_walk_done(desc, &walk,
                                          walk.nbytes % AES_BLOCK_SIZE);
        }
        kernel_neon_end();

        return err;
}

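/*
 * The '__' prefixed blkciphers are marked CRYPTO_ALG_INTERNAL and cannot be
 * selected by users directly; they are reached through the ablk_helper based
 * async algorithms below, which defer the work to cryptd when the NEON unit
 * cannot be used in the calling context.
 */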
static struct crypto_alg aes_algs[] = { {
        .cra_name               = "__ecb-aes-ce",
        .cra_driver_name        = "__driver-ecb-aes-ce",
        .cra_priority           = 0,
        .cra_flags              = CRYPTO_ALG_TYPE_BLKCIPHER |
                                  CRYPTO_ALG_INTERNAL,
        .cra_blocksize          = AES_BLOCK_SIZE,
        .cra_ctxsize            = sizeof(struct crypto_aes_ctx),
        .cra_alignmask          = 7,
        .cra_type               = &crypto_blkcipher_type,
        .cra_module             = THIS_MODULE,
        .cra_blkcipher = {
                .min_keysize    = AES_MIN_KEY_SIZE,
                .max_keysize    = AES_MAX_KEY_SIZE,
                .ivsize         = 0,
                .setkey         = ce_aes_setkey,
                .encrypt        = ecb_encrypt,
                .decrypt        = ecb_decrypt,
        },
}, {
        .cra_name               = "__cbc-aes-ce",
        .cra_driver_name        = "__driver-cbc-aes-ce",
        .cra_priority           = 0,
        .cra_flags              = CRYPTO_ALG_TYPE_BLKCIPHER |
                                  CRYPTO_ALG_INTERNAL,
        .cra_blocksize          = AES_BLOCK_SIZE,
        .cra_ctxsize            = sizeof(struct crypto_aes_ctx),
        .cra_alignmask          = 7,
        .cra_type               = &crypto_blkcipher_type,
        .cra_module             = THIS_MODULE,
        .cra_blkcipher = {
                .min_keysize    = AES_MIN_KEY_SIZE,
                .max_keysize    = AES_MAX_KEY_SIZE,
                .ivsize         = AES_BLOCK_SIZE,
                .setkey         = ce_aes_setkey,
                .encrypt        = cbc_encrypt,
                .decrypt        = cbc_decrypt,
        },
}, {
        .cra_name               = "__ctr-aes-ce",
        .cra_driver_name        = "__driver-ctr-aes-ce",
        .cra_priority           = 0,
        .cra_flags              = CRYPTO_ALG_TYPE_BLKCIPHER |
                                  CRYPTO_ALG_INTERNAL,
        .cra_blocksize          = 1,
        .cra_ctxsize            = sizeof(struct crypto_aes_ctx),
        .cra_alignmask          = 7,
        .cra_type               = &crypto_blkcipher_type,
        .cra_module             = THIS_MODULE,
        .cra_blkcipher = {
                .min_keysize    = AES_MIN_KEY_SIZE,
                .max_keysize    = AES_MAX_KEY_SIZE,
                .ivsize         = AES_BLOCK_SIZE,
                .setkey         = ce_aes_setkey,
                .encrypt        = ctr_encrypt,
                .decrypt        = ctr_encrypt,
        },
}, {
        .cra_name               = "__xts-aes-ce",
        .cra_driver_name        = "__driver-xts-aes-ce",
        .cra_priority           = 0,
        .cra_flags              = CRYPTO_ALG_TYPE_BLKCIPHER |
                                  CRYPTO_ALG_INTERNAL,
        .cra_blocksize          = AES_BLOCK_SIZE,
        .cra_ctxsize            = sizeof(struct crypto_aes_xts_ctx),
        .cra_alignmask          = 7,
        .cra_type               = &crypto_blkcipher_type,
        .cra_module             = THIS_MODULE,
        .cra_blkcipher = {
                .min_keysize    = 2 * AES_MIN_KEY_SIZE,
                .max_keysize    = 2 * AES_MAX_KEY_SIZE,
                .ivsize         = AES_BLOCK_SIZE,
                .setkey         = xts_set_key,
                .encrypt        = xts_encrypt,
                .decrypt        = xts_decrypt,
        },
}, {
        .cra_name               = "ecb(aes)",
        .cra_driver_name        = "ecb-aes-ce",
        .cra_priority           = 300,
        .cra_flags              = CRYPTO_ALG_TYPE_ABLKCIPHER|CRYPTO_ALG_ASYNC,
        .cra_blocksize          = AES_BLOCK_SIZE,
        .cra_ctxsize            = sizeof(struct async_helper_ctx),
        .cra_alignmask          = 7,
        .cra_type               = &crypto_ablkcipher_type,
        .cra_module             = THIS_MODULE,
        .cra_init               = ablk_init,
        .cra_exit               = ablk_exit,
        .cra_ablkcipher = {
                .min_keysize    = AES_MIN_KEY_SIZE,
                .max_keysize    = AES_MAX_KEY_SIZE,
                .ivsize         = 0,
                .setkey         = ablk_set_key,
                .encrypt        = ablk_encrypt,
                .decrypt        = ablk_decrypt,
        }
}, {
        .cra_name               = "cbc(aes)",
        .cra_driver_name        = "cbc-aes-ce",
        .cra_priority           = 300,
        .cra_flags              = CRYPTO_ALG_TYPE_ABLKCIPHER|CRYPTO_ALG_ASYNC,
        .cra_blocksize          = AES_BLOCK_SIZE,
        .cra_ctxsize            = sizeof(struct async_helper_ctx),
        .cra_alignmask          = 7,
        .cra_type               = &crypto_ablkcipher_type,
        .cra_module             = THIS_MODULE,
        .cra_init               = ablk_init,
        .cra_exit               = ablk_exit,
        .cra_ablkcipher = {
                .min_keysize    = AES_MIN_KEY_SIZE,
                .max_keysize    = AES_MAX_KEY_SIZE,
                .ivsize         = AES_BLOCK_SIZE,
                .setkey         = ablk_set_key,
                .encrypt        = ablk_encrypt,
                .decrypt        = ablk_decrypt,
        }
}, {
        .cra_name               = "ctr(aes)",
        .cra_driver_name        = "ctr-aes-ce",
        .cra_priority           = 300,
        .cra_flags              = CRYPTO_ALG_TYPE_ABLKCIPHER|CRYPTO_ALG_ASYNC,
        .cra_blocksize          = 1,
        .cra_ctxsize            = sizeof(struct async_helper_ctx),
        .cra_alignmask          = 7,
        .cra_type               = &crypto_ablkcipher_type,
        .cra_module             = THIS_MODULE,
        .cra_init               = ablk_init,
        .cra_exit               = ablk_exit,
        .cra_ablkcipher = {
                .min_keysize    = AES_MIN_KEY_SIZE,
                .max_keysize    = AES_MAX_KEY_SIZE,
                .ivsize         = AES_BLOCK_SIZE,
                .setkey         = ablk_set_key,
                .encrypt        = ablk_encrypt,
                .decrypt        = ablk_decrypt,
        }
}, {
        .cra_name               = "xts(aes)",
        .cra_driver_name        = "xts-aes-ce",
        .cra_priority           = 300,
        .cra_flags              = CRYPTO_ALG_TYPE_ABLKCIPHER|CRYPTO_ALG_ASYNC,
        .cra_blocksize          = AES_BLOCK_SIZE,
        .cra_ctxsize            = sizeof(struct async_helper_ctx),
        .cra_alignmask          = 7,
        .cra_type               = &crypto_ablkcipher_type,
        .cra_module             = THIS_MODULE,
        .cra_init               = ablk_init,
        .cra_exit               = ablk_exit,
        .cra_ablkcipher = {
                .min_keysize    = 2 * AES_MIN_KEY_SIZE,
                .max_keysize    = 2 * AES_MAX_KEY_SIZE,
                .ivsize         = AES_BLOCK_SIZE,
                .setkey         = ablk_set_key,
                .encrypt        = ablk_encrypt,
                .decrypt        = ablk_decrypt,
        }
} };

static int __init aes_init(void)
{
        if (!(elf_hwcap2 & HWCAP2_AES))
                return -ENODEV;
        return crypto_register_algs(aes_algs, ARRAY_SIZE(aes_algs));
}

static void __exit aes_exit(void)
{
        crypto_unregister_algs(aes_algs, ARRAY_SIZE(aes_algs));
}

module_init(aes_init);
module_exit(aes_exit);
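
/*
 * Usage sketch: on a CPU that advertises HWCAP2_AES, a kernel user asking
 * for e.g. "xts(aes)" will normally get the priority 300 "xts-aes-ce"
 * driver registered above, unless a higher-priority implementation exists:
 *
 *	struct crypto_ablkcipher *tfm;
 *
 *	tfm = crypto_alloc_ablkcipher("xts(aes)", 0, 0);
 *	if (!IS_ERR(tfm)) {
 *		pr_info("driver: %s\n",
 *			crypto_tfm_alg_driver_name(crypto_ablkcipher_tfm(tfm)));
 *		crypto_free_ablkcipher(tfm);
 *	}
 */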