/* GNU Linux-libre 4.19.286-gnu1: net/tls/tls_sw.c */
/*
 * Copyright (c) 2016-2017, Mellanox Technologies. All rights reserved.
 * Copyright (c) 2016-2017, Dave Watson <davejwatson@fb.com>. All rights reserved.
 * Copyright (c) 2016-2017, Lance Chao <lancerchao@fb.com>. All rights reserved.
 * Copyright (c) 2016, Fridolin Pokorny <fridolin.pokorny@gmail.com>. All rights reserved.
 * Copyright (c) 2016, Nikos Mavrogiannopoulos <nmav@gnutls.org>. All rights reserved.
 *
 * This software is available to you under a choice of one of two
 * licenses.  You may choose to be licensed under the terms of the GNU
 * General Public License (GPL) Version 2, available from the file
 * COPYING in the main directory of this source tree, or the
 * OpenIB.org BSD license below:
 *
 *     Redistribution and use in source and binary forms, with or
 *     without modification, are permitted provided that the following
 *     conditions are met:
 *
 *      - Redistributions of source code must retain the above
 *        copyright notice, this list of conditions and the following
 *        disclaimer.
 *
 *      - Redistributions in binary form must reproduce the above
 *        copyright notice, this list of conditions and the following
 *        disclaimer in the documentation and/or other materials
 *        provided with the distribution.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
 * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
 * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
 * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
 * BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
 * ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
 * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */

#include <linux/sched/signal.h>
#include <linux/module.h>
#include <crypto/aead.h>

#include <net/strparser.h>
#include <net/tls.h>

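/* Only AES-GCM-128 is supported here (see tls_set_sw_offload() below), so
 * this cipher's IV size bounds every on-stack IV/header buffer.
 */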
#define MAX_IV_SIZE     TLS_CIPHER_AES_GCM_128_IV_SIZE

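/* Decrypt one record synchronously.  @sgin carries the AAD followed by the
 * ciphertext and authentication tag; @sgout receives AAD plus plaintext
 * (the two may alias for in-place decryption).  crypto_wait_req() blocks
 * until the AEAD completes, even with an asynchronous implementation.
 */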
static int tls_do_decryption(struct sock *sk,
                             struct scatterlist *sgin,
                             struct scatterlist *sgout,
                             char *iv_recv,
                             size_t data_len,
                             struct aead_request *aead_req)
{
        struct tls_context *tls_ctx = tls_get_ctx(sk);
        struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
        int ret;

        aead_request_set_tfm(aead_req, ctx->aead_recv);
        aead_request_set_ad(aead_req, TLS_AAD_SPACE_SIZE);
        aead_request_set_crypt(aead_req, sgin, sgout,
                               data_len + tls_ctx->rx.tag_size,
                               (u8 *)iv_recv);
        aead_request_set_callback(aead_req, CRYPTO_TFM_REQ_MAY_BACKLOG,
                                  crypto_req_done, &ctx->async_wait);

        ret = crypto_wait_req(crypto_aead_decrypt(aead_req), &ctx->async_wait);
        return ret;
}

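/* Shrink a scatterlist to @target_size bytes: drop whole pages from the
 * tail (uncharging them from the socket), trim the last remaining entry,
 * and update *sg_num_elem and *sg_size to match.
 */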
static void trim_sg(struct sock *sk, struct scatterlist *sg,
                    int *sg_num_elem, unsigned int *sg_size, int target_size)
{
        int i = *sg_num_elem - 1;
        int trim = *sg_size - target_size;

        if (trim <= 0) {
                WARN_ON(trim < 0);
                return;
        }

        *sg_size = target_size;
        while (trim >= sg[i].length) {
                trim -= sg[i].length;
                sk_mem_uncharge(sk, sg[i].length);
                put_page(sg_page(&sg[i]));
                i--;

                if (i < 0)
                        goto out;
        }

        sg[i].length -= trim;
        sk_mem_uncharge(sk, trim);

out:
        *sg_num_elem = i + 1;
}

static void trim_both_sgl(struct sock *sk, int target_size)
{
        struct tls_context *tls_ctx = tls_get_ctx(sk);
        struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);

        trim_sg(sk, ctx->sg_plaintext_data,
                &ctx->sg_plaintext_num_elem,
                &ctx->sg_plaintext_size,
                target_size);

        if (target_size > 0)
                target_size += tls_ctx->tx.overhead_size;

        trim_sg(sk, ctx->sg_encrypted_data,
                &ctx->sg_encrypted_num_elem,
                &ctx->sg_encrypted_size,
                target_size);
}

static int alloc_encrypted_sg(struct sock *sk, int len)
{
        struct tls_context *tls_ctx = tls_get_ctx(sk);
        struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
        int rc = 0;

        rc = sk_alloc_sg(sk, len,
                         ctx->sg_encrypted_data, 0,
                         &ctx->sg_encrypted_num_elem,
                         &ctx->sg_encrypted_size, 0);

        if (rc == -ENOSPC)
                ctx->sg_encrypted_num_elem = ARRAY_SIZE(ctx->sg_encrypted_data);

        return rc;
}

static int alloc_plaintext_sg(struct sock *sk, int len)
{
        struct tls_context *tls_ctx = tls_get_ctx(sk);
        struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
        int rc = 0;

        rc = sk_alloc_sg(sk, len, ctx->sg_plaintext_data, 0,
                         &ctx->sg_plaintext_num_elem, &ctx->sg_plaintext_size,
                         tls_ctx->pending_open_record_frags);

        if (rc == -ENOSPC)
                ctx->sg_plaintext_num_elem = ARRAY_SIZE(ctx->sg_plaintext_data);

        return rc;
}

static void free_sg(struct sock *sk, struct scatterlist *sg,
                    int *sg_num_elem, unsigned int *sg_size)
{
        int i, n = *sg_num_elem;

        for (i = 0; i < n; ++i) {
                sk_mem_uncharge(sk, sg[i].length);
                put_page(sg_page(&sg[i]));
        }
        *sg_num_elem = 0;
        *sg_size = 0;
}

static void tls_free_both_sg(struct sock *sk)
{
        struct tls_context *tls_ctx = tls_get_ctx(sk);
        struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);

        free_sg(sk, ctx->sg_encrypted_data, &ctx->sg_encrypted_num_elem,
                &ctx->sg_encrypted_size);

        free_sg(sk, ctx->sg_plaintext_data, &ctx->sg_plaintext_num_elem,
                &ctx->sg_plaintext_size);
}

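/* Encrypt the pending record.  sg_encrypted_data[0] reserves room for the
 * TLS record header, so its offset/length are temporarily advanced past
 * the header while the AEAD writes the ciphertext, then restored.
 */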
static int tls_do_encryption(struct tls_context *tls_ctx,
                             struct tls_sw_context_tx *ctx,
                             struct aead_request *aead_req,
                             size_t data_len)
{
        int rc;

        ctx->sg_encrypted_data[0].offset += tls_ctx->tx.prepend_size;
        ctx->sg_encrypted_data[0].length -= tls_ctx->tx.prepend_size;

        aead_request_set_tfm(aead_req, ctx->aead_send);
        aead_request_set_ad(aead_req, TLS_AAD_SPACE_SIZE);
        aead_request_set_crypt(aead_req, ctx->sg_aead_in, ctx->sg_aead_out,
                               data_len, tls_ctx->tx.iv);

        aead_request_set_callback(aead_req, CRYPTO_TFM_REQ_MAY_BACKLOG,
                                  crypto_req_done, &ctx->async_wait);

        rc = crypto_wait_req(crypto_aead_encrypt(aead_req), &ctx->async_wait);

        ctx->sg_encrypted_data[0].offset -= tls_ctx->tx.prepend_size;
        ctx->sg_encrypted_data[0].length += tls_ctx->tx.prepend_size;

        return rc;
}

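/* Seal and transmit the pending record: build the AAD and record header,
 * encrypt the plaintext scatterlist, hand the ciphertext to tls_push_sg()
 * and advance the record sequence number.  Transmit errors other than
 * -EAGAIN abort the connection with EBADMSG.
 */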
static int tls_push_record(struct sock *sk, int flags,
                           unsigned char record_type)
{
        struct tls_context *tls_ctx = tls_get_ctx(sk);
        struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
        struct aead_request *req;
        int rc;

        req = aead_request_alloc(ctx->aead_send, sk->sk_allocation);
        if (!req)
                return -ENOMEM;

        sg_mark_end(ctx->sg_plaintext_data + ctx->sg_plaintext_num_elem - 1);
        sg_mark_end(ctx->sg_encrypted_data + ctx->sg_encrypted_num_elem - 1);

        tls_make_aad(ctx->aad_space, ctx->sg_plaintext_size,
                     tls_ctx->tx.rec_seq, tls_ctx->tx.rec_seq_size,
                     record_type);

        tls_fill_prepend(tls_ctx,
                         page_address(sg_page(&ctx->sg_encrypted_data[0])) +
                         ctx->sg_encrypted_data[0].offset,
                         ctx->sg_plaintext_size, record_type);

        tls_ctx->pending_open_record_frags = 0;
        set_bit(TLS_PENDING_CLOSED_RECORD, &tls_ctx->flags);

        rc = tls_do_encryption(tls_ctx, ctx, req, ctx->sg_plaintext_size);
        if (rc < 0) {
                /* If we are called from write_space and we fail, we need to
                 * set SOCK_NOSPACE here to trigger another write_space in
                 * the future.
                 */
                set_bit(SOCK_NOSPACE, &sk->sk_socket->flags);
                goto out_req;
        }

        free_sg(sk, ctx->sg_plaintext_data, &ctx->sg_plaintext_num_elem,
                &ctx->sg_plaintext_size);

        ctx->sg_encrypted_num_elem = 0;
        ctx->sg_encrypted_size = 0;

        /* Only pass through MSG_DONTWAIT and MSG_NOSIGNAL flags */
        rc = tls_push_sg(sk, tls_ctx, ctx->sg_encrypted_data, 0, flags);
        if (rc < 0 && rc != -EAGAIN)
                tls_err_abort(sk, EBADMSG);

        tls_advance_record_sn(sk, &tls_ctx->tx);
out_req:
        aead_request_free(req);
        return rc;
}

static int tls_sw_push_pending_record(struct sock *sk, int flags)
{
        return tls_push_record(sk, flags, TLS_RECORD_TYPE_DATA);
}

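/* Pin up to @length bytes of user memory from @from and map the pages into
 * the @to scatterlist without copying.  On failure the iterator is
 * reverted so the caller can fall back to a copying path.  @charge selects
 * whether the bytes count against the socket's memory accounting.
 */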
static int zerocopy_from_iter(struct sock *sk, struct iov_iter *from,
                              int length, int *pages_used,
                              unsigned int *size_used,
                              struct scatterlist *to, int to_max_pages,
                              bool charge)
{
        struct page *pages[MAX_SKB_FRAGS];

        size_t offset;
        ssize_t copied, use;
        int i = 0;
        unsigned int size = *size_used;
        int num_elem = *pages_used;
        int rc = 0;
        int maxpages;

        while (length > 0) {
                i = 0;
                maxpages = to_max_pages - num_elem;
                if (maxpages == 0) {
                        rc = -EFAULT;
                        goto out;
                }
                copied = iov_iter_get_pages(from, pages,
                                            length,
                                            maxpages, &offset);
                if (copied <= 0) {
                        rc = -EFAULT;
                        goto out;
                }

                iov_iter_advance(from, copied);

                length -= copied;
                size += copied;
                while (copied) {
                        use = min_t(int, copied, PAGE_SIZE - offset);

                        sg_set_page(&to[num_elem],
                                    pages[i], use, offset);
                        sg_unmark_end(&to[num_elem]);
                        if (charge)
                                sk_mem_charge(sk, use);

                        offset = 0;
                        copied -= use;

                        ++i;
                        ++num_elem;
                }
        }

        /* Mark the end in the last sg entry if newly added */
        if (num_elem > *pages_used)
                sg_mark_end(&to[num_elem - 1]);
out:
        if (rc)
                iov_iter_revert(from, size - *size_used);
        *size_used = size;
        *pages_used = num_elem;

        return rc;
}

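/* Copy @bytes from the iterator into the already-allocated plaintext
 * fragments, starting at the first fragment not yet part of the pending
 * open record, bumping pending_open_record_frags as fragments fill up.
 */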
static int memcopy_from_iter(struct sock *sk, struct iov_iter *from,
                             int bytes)
{
        struct tls_context *tls_ctx = tls_get_ctx(sk);
        struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
        struct scatterlist *sg = ctx->sg_plaintext_data;
        int copy, i, rc = 0;

        for (i = tls_ctx->pending_open_record_frags;
             i < ctx->sg_plaintext_num_elem; ++i) {
                copy = sg[i].length;
                if (copy_from_iter(
                                page_address(sg_page(&sg[i])) + sg[i].offset,
                                copy, from) != copy) {
                        rc = -EFAULT;
                        goto out;
                }
                bytes -= copy;

                ++tls_ctx->pending_open_record_frags;

                if (!bytes)
                        break;
        }

out:
        return rc;
}

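/* sendmsg() for a TLS_SW socket.  Data accumulates into records of at most
 * TLS_MAX_PAYLOAD_SIZE bytes.  For a full record, or when the caller did
 * not set MSG_MORE, user pages are mapped zero-copy where possible (not
 * for kernel ITER_KVEC iterators); otherwise data is copied into
 * preallocated plaintext pages and the record stays open for appends.
 */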
int tls_sw_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
{
        struct tls_context *tls_ctx = tls_get_ctx(sk);
        struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
        int ret;
        int required_size;
        long timeo = sock_sndtimeo(sk, msg->msg_flags & MSG_DONTWAIT);
        bool eor = !(msg->msg_flags & MSG_MORE);
        size_t try_to_copy, copied = 0;
        unsigned char record_type = TLS_RECORD_TYPE_DATA;
        int record_room;
        bool full_record;
        int orig_size;
        bool is_kvec = msg->msg_iter.type & ITER_KVEC;

        if (msg->msg_flags & ~(MSG_MORE | MSG_DONTWAIT | MSG_NOSIGNAL))
                return -ENOTSUPP;

        lock_sock(sk);

        ret = tls_complete_pending_work(sk, tls_ctx, msg->msg_flags, &timeo);
        if (ret)
                goto send_end;

        if (unlikely(msg->msg_controllen)) {
                ret = tls_proccess_cmsg(sk, msg, &record_type);
                if (ret)
                        goto send_end;
        }

        while (msg_data_left(msg)) {
                if (sk->sk_err) {
                        ret = -sk->sk_err;
                        goto send_end;
                }

                orig_size = ctx->sg_plaintext_size;
                full_record = false;
                try_to_copy = msg_data_left(msg);
                record_room = TLS_MAX_PAYLOAD_SIZE - ctx->sg_plaintext_size;
                if (try_to_copy >= record_room) {
                        try_to_copy = record_room;
                        full_record = true;
                }

                required_size = ctx->sg_plaintext_size + try_to_copy +
                                tls_ctx->tx.overhead_size;

                if (!sk_stream_memory_free(sk))
                        goto wait_for_sndbuf;
alloc_encrypted:
                ret = alloc_encrypted_sg(sk, required_size);
                if (ret) {
                        if (ret != -ENOSPC)
                                goto wait_for_memory;

                        /* Adjust try_to_copy according to the amount that was
                         * actually allocated. The difference is due
                         * to max sg elements limit
                         */
                        try_to_copy -= required_size - ctx->sg_encrypted_size;
                        full_record = true;
                }
                if (!is_kvec && (full_record || eor)) {
                        ret = zerocopy_from_iter(sk, &msg->msg_iter,
                                try_to_copy, &ctx->sg_plaintext_num_elem,
                                &ctx->sg_plaintext_size,
                                ctx->sg_plaintext_data,
                                ARRAY_SIZE(ctx->sg_plaintext_data),
                                true);
                        if (ret)
                                goto fallback_to_reg_send;

                        copied += try_to_copy;
                        ret = tls_push_record(sk, msg->msg_flags, record_type);
                        if (ret)
                                goto send_end;
                        continue;

fallback_to_reg_send:
                        trim_sg(sk, ctx->sg_plaintext_data,
                                &ctx->sg_plaintext_num_elem,
                                &ctx->sg_plaintext_size,
                                orig_size);
                }

                required_size = ctx->sg_plaintext_size + try_to_copy;
alloc_plaintext:
                ret = alloc_plaintext_sg(sk, required_size);
                if (ret) {
                        if (ret != -ENOSPC)
                                goto wait_for_memory;

                        /* Adjust try_to_copy according to the amount that was
                         * actually allocated. The difference is due
                         * to max sg elements limit
                         */
                        try_to_copy -= required_size - ctx->sg_plaintext_size;
                        full_record = true;

                        trim_sg(sk, ctx->sg_encrypted_data,
                                &ctx->sg_encrypted_num_elem,
                                &ctx->sg_encrypted_size,
                                ctx->sg_plaintext_size +
                                tls_ctx->tx.overhead_size);
                }

                ret = memcopy_from_iter(sk, &msg->msg_iter, try_to_copy);
                if (ret)
                        goto trim_sgl;

                copied += try_to_copy;
                if (full_record || eor) {
push_record:
                        ret = tls_push_record(sk, msg->msg_flags, record_type);
                        if (ret) {
                                if (ret == -ENOMEM)
                                        goto wait_for_memory;

                                goto send_end;
                        }
                }

                continue;

wait_for_sndbuf:
                set_bit(SOCK_NOSPACE, &sk->sk_socket->flags);
wait_for_memory:
                ret = sk_stream_wait_memory(sk, &timeo);
                if (ret) {
trim_sgl:
                        trim_both_sgl(sk, orig_size);
                        goto send_end;
                }

                if (tls_is_pending_closed_record(tls_ctx))
                        goto push_record;

                if (ctx->sg_encrypted_size < required_size)
                        goto alloc_encrypted;

                goto alloc_plaintext;
        }

send_end:
        ret = sk_stream_error(sk, msg->msg_flags, ret);

        release_sock(sk);
        return copied ? copied : ret;
}

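/* sendpage() for a TLS_SW socket.  The caller's page is referenced
 * directly in the plaintext scatterlist (no copy); a record is pushed once
 * it is full, the scatterlist runs out of slots, or no further data was
 * announced via MSG_MORE/MSG_SENDPAGE_NOTLAST.
 */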
int tls_sw_sendpage(struct sock *sk, struct page *page,
                    int offset, size_t size, int flags)
{
        struct tls_context *tls_ctx = tls_get_ctx(sk);
        struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
        int ret;
        long timeo = sock_sndtimeo(sk, flags & MSG_DONTWAIT);
        bool eor;
        size_t orig_size = size;
        unsigned char record_type = TLS_RECORD_TYPE_DATA;
        struct scatterlist *sg;
        bool full_record;
        int record_room;

        if (flags & ~(MSG_MORE | MSG_DONTWAIT | MSG_NOSIGNAL |
                      MSG_SENDPAGE_NOTLAST))
                return -ENOTSUPP;

        /* No MSG_EOR from splice, only look at MSG_MORE */
        eor = !(flags & (MSG_MORE | MSG_SENDPAGE_NOTLAST));

        lock_sock(sk);

        sk_clear_bit(SOCKWQ_ASYNC_NOSPACE, sk);

        ret = tls_complete_pending_work(sk, tls_ctx, flags, &timeo);
        if (ret)
                goto sendpage_end;

        /* Call the sk_stream functions to manage the sndbuf mem. */
        while (size > 0) {
                size_t copy, required_size;

                if (sk->sk_err) {
                        ret = -sk->sk_err;
                        goto sendpage_end;
                }

                full_record = false;
                record_room = TLS_MAX_PAYLOAD_SIZE - ctx->sg_plaintext_size;
                copy = size;
                if (copy >= record_room) {
                        copy = record_room;
                        full_record = true;
                }
                required_size = ctx->sg_plaintext_size + copy +
                                tls_ctx->tx.overhead_size;

                if (!sk_stream_memory_free(sk))
                        goto wait_for_sndbuf;
alloc_payload:
                ret = alloc_encrypted_sg(sk, required_size);
                if (ret) {
                        if (ret != -ENOSPC)
                                goto wait_for_memory;

                        /* Adjust copy according to the amount that was
                         * actually allocated. The difference is due
                         * to max sg elements limit
                         */
                        copy -= required_size - ctx->sg_plaintext_size;
                        full_record = true;
                }

                get_page(page);
                sg = ctx->sg_plaintext_data + ctx->sg_plaintext_num_elem;
                sg_set_page(sg, page, copy, offset);
                sg_unmark_end(sg);

                ctx->sg_plaintext_num_elem++;

                sk_mem_charge(sk, copy);
                offset += copy;
                size -= copy;
                ctx->sg_plaintext_size += copy;
                tls_ctx->pending_open_record_frags = ctx->sg_plaintext_num_elem;

                if (full_record || eor ||
                    ctx->sg_plaintext_num_elem ==
                    ARRAY_SIZE(ctx->sg_plaintext_data)) {
push_record:
                        ret = tls_push_record(sk, flags, record_type);
                        if (ret) {
                                if (ret == -ENOMEM)
                                        goto wait_for_memory;

                                goto sendpage_end;
                        }
                }
                continue;
wait_for_sndbuf:
                set_bit(SOCK_NOSPACE, &sk->sk_socket->flags);
wait_for_memory:
                ret = sk_stream_wait_memory(sk, &timeo);
                if (ret) {
                        trim_both_sgl(sk, ctx->sg_plaintext_size);
                        goto sendpage_end;
                }

                if (tls_is_pending_closed_record(tls_ctx))
                        goto push_record;

                goto alloc_payload;
        }

sendpage_end:
        if (orig_size > size)
                ret = orig_size - size;
        else
                ret = sk_stream_error(sk, flags, ret);

        release_sock(sk);
        return ret;
}

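/* Wait until the strparser has queued a complete record in ctx->recv_pkt.
 * Returns the record's skb, or NULL on socket error, shutdown, EOF,
 * timeout or signal (with *err set for the error cases).
 */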
static struct sk_buff *tls_wait_data(struct sock *sk, int flags,
                                     long timeo, int *err)
{
        struct tls_context *tls_ctx = tls_get_ctx(sk);
        struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
        struct sk_buff *skb;
        DEFINE_WAIT_FUNC(wait, woken_wake_function);

        while (!(skb = ctx->recv_pkt)) {
                if (sk->sk_err) {
                        *err = sock_error(sk);
                        return NULL;
                }

                if (!skb_queue_empty(&sk->sk_receive_queue)) {
                        __strp_unpause(&ctx->strp);
                        if (ctx->recv_pkt)
                                return ctx->recv_pkt;
                }

                if (sk->sk_shutdown & RCV_SHUTDOWN)
                        return NULL;

                if (sock_flag(sk, SOCK_DONE))
                        return NULL;

                if ((flags & MSG_DONTWAIT) || !timeo) {
                        *err = -EAGAIN;
                        return NULL;
                }

                add_wait_queue(sk_sleep(sk), &wait);
                sk_set_bit(SOCKWQ_ASYNC_WAITDATA, sk);
                sk_wait_event(sk, &timeo, ctx->recv_pkt != skb, &wait);
                sk_clear_bit(SOCKWQ_ASYNC_WAITDATA, sk);
                remove_wait_queue(sk_sleep(sk), &wait);

                /* Handle signals */
                if (signal_pending(current)) {
                        *err = sock_intr_errno(timeo);
                        return NULL;
                }
        }

        return skb;
}

/* This function decrypts the input skb into either out_iov or out_sg or
 * into the skb's own buffers. The input parameter 'zc' indicates whether
 * zero-copy mode should be tried. In zero-copy mode, either out_iov or
 * out_sg must be non-NULL. If both out_iov and out_sg are NULL, the
 * decryption happens in place inside the skb buffers, i.e. zero-copy is
 * disabled and '*zc' is updated accordingly.
 */

static int decrypt_internal(struct sock *sk, struct sk_buff *skb,
                            struct iov_iter *out_iov,
                            struct scatterlist *out_sg,
                            int *chunk, bool *zc)
{
        struct tls_context *tls_ctx = tls_get_ctx(sk);
        struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
        struct strp_msg *rxm = strp_msg(skb);
        int n_sgin, n_sgout, nsg, mem_size, aead_size, err, pages = 0;
        struct aead_request *aead_req;
        struct sk_buff *unused;
        u8 *aad, *iv, *mem = NULL;
        struct scatterlist *sgin = NULL;
        struct scatterlist *sgout = NULL;
        const int data_len = rxm->full_len - tls_ctx->rx.overhead_size;

        if (*zc && (out_iov || out_sg)) {
                if (out_iov)
                        n_sgout = iov_iter_npages(out_iov, INT_MAX) + 1;
                else
                        n_sgout = sg_nents(out_sg);
        } else {
                n_sgout = 0;
                *zc = false;
        }

        n_sgin = skb_cow_data(skb, 0, &unused);
        if (n_sgin < 1)
                return -EBADMSG;

        /* Increment to accommodate AAD */
        n_sgin = n_sgin + 1;

        nsg = n_sgin + n_sgout;

        aead_size = sizeof(*aead_req) + crypto_aead_reqsize(ctx->aead_recv);
        mem_size = aead_size + (nsg * sizeof(struct scatterlist));
        mem_size = mem_size + TLS_AAD_SPACE_SIZE;
        mem_size = mem_size + crypto_aead_ivsize(ctx->aead_recv);

        /* Allocate a single block of memory which contains
         * aead_req || sgin[] || sgout[] || aad || iv.
         * This order achieves correct alignment for aead_req, sgin, sgout.
         */
        mem = kmalloc(mem_size, sk->sk_allocation);
        if (!mem)
                return -ENOMEM;

        /* Segment the allocated memory */
        aead_req = (struct aead_request *)mem;
        sgin = (struct scatterlist *)(mem + aead_size);
        sgout = sgin + n_sgin;
        aad = (u8 *)(sgout + n_sgout);
        iv = aad + TLS_AAD_SPACE_SIZE;

        /* Prepare IV */
        err = skb_copy_bits(skb, rxm->offset + TLS_HEADER_SIZE,
                            iv + TLS_CIPHER_AES_GCM_128_SALT_SIZE,
                            tls_ctx->rx.iv_size);
        if (err < 0) {
                kfree(mem);
                return err;
        }
        memcpy(iv, tls_ctx->rx.iv, TLS_CIPHER_AES_GCM_128_SALT_SIZE);

        /* Prepare AAD */
        tls_make_aad(aad, rxm->full_len - tls_ctx->rx.overhead_size,
                     tls_ctx->rx.rec_seq, tls_ctx->rx.rec_seq_size,
                     ctx->control);

        /* Prepare sgin */
        sg_init_table(sgin, n_sgin);
        sg_set_buf(&sgin[0], aad, TLS_AAD_SPACE_SIZE);
        err = skb_to_sgvec(skb, &sgin[1],
                           rxm->offset + tls_ctx->rx.prepend_size,
                           rxm->full_len - tls_ctx->rx.prepend_size);
        if (err < 0) {
                kfree(mem);
                return err;
        }

        if (n_sgout) {
                if (out_iov) {
                        sg_init_table(sgout, n_sgout);
                        sg_set_buf(&sgout[0], aad, TLS_AAD_SPACE_SIZE);

                        *chunk = 0;
                        err = zerocopy_from_iter(sk, out_iov, data_len, &pages,
                                                 chunk, &sgout[1],
                                                 (n_sgout - 1), false);
                        if (err < 0)
                                goto fallback_to_reg_recv;
                } else if (out_sg) {
                        memcpy(sgout, out_sg, n_sgout * sizeof(*sgout));
                } else {
                        goto fallback_to_reg_recv;
                }
        } else {
fallback_to_reg_recv:
                sgout = sgin;
                pages = 0;
                *chunk = 0;
                *zc = false;
        }

        /* Prepare and submit AEAD request */
        err = tls_do_decryption(sk, sgin, sgout, iv, data_len, aead_req);

        /* Release the pages in case iov was mapped to pages */
        for (; pages > 0; pages--)
                put_page(sg_page(&sgout[pages]));

        kfree(mem);
        return err;
}

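/* Decrypt the record in @skb unless device offload already decrypted it,
 * then strip the TLS header and overhead from the strparser message and
 * advance the RX record sequence number.
 */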
static int decrypt_skb_update(struct sock *sk, struct sk_buff *skb,
                              struct iov_iter *dest, int *chunk, bool *zc)
{
        struct tls_context *tls_ctx = tls_get_ctx(sk);
        struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
        struct strp_msg *rxm = strp_msg(skb);
        int err = 0;

#ifdef CONFIG_TLS_DEVICE
        err = tls_device_decrypted(sk, skb);
        if (err < 0)
                return err;
#endif
        if (!ctx->decrypted) {
                err = decrypt_internal(sk, skb, dest, NULL, chunk, zc);
                if (err < 0)
                        return err;
        } else {
                *zc = false;
        }

        rxm->offset += tls_ctx->rx.prepend_size;
        rxm->full_len -= tls_ctx->rx.overhead_size;
        tls_advance_record_sn(sk, &tls_ctx->rx);
        ctx->decrypted = true;
        ctx->saved_data_ready(sk);

        return err;
}

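/* Decrypt the record in @skb into a caller-supplied scatterlist without
 * touching any receive-side state; also used by the device-offload path.
 */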
int decrypt_skb(struct sock *sk, struct sk_buff *skb,
                struct scatterlist *sgout)
{
        bool zc = true;
        int chunk;

        return decrypt_internal(sk, skb, NULL, sgout, &chunk, &zc);
}

static bool tls_sw_advance_skb(struct sock *sk, struct sk_buff *skb,
                               unsigned int len)
{
        struct tls_context *tls_ctx = tls_get_ctx(sk);
        struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
        struct strp_msg *rxm = strp_msg(skb);

        if (len < rxm->full_len) {
                rxm->offset += len;
                rxm->full_len -= len;

                return false;
        }

        /* Finished with message */
        ctx->recv_pkt = NULL;
        kfree_skb(skb);
        __strp_unpause(&ctx->strp);

        return true;
}

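/* recvmsg() for a TLS_SW socket.  The record type is reported once per
 * call via a TLS_GET_RECORD_TYPE control message, and a single call never
 * spans records of different types.  When a whole record fits in the user
 * buffer and MSG_PEEK is unset, decryption writes into the user iterator
 * directly (zero-copy); otherwise the record is decrypted in place and
 * copied out.
 */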
int tls_sw_recvmsg(struct sock *sk,
                   struct msghdr *msg,
                   size_t len,
                   int nonblock,
                   int flags,
                   int *addr_len)
{
        struct tls_context *tls_ctx = tls_get_ctx(sk);
        struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
        unsigned char control;
        struct strp_msg *rxm;
        struct sk_buff *skb;
        ssize_t copied = 0;
        bool cmsg = false;
        int target, err = 0;
        long timeo;
        bool is_kvec = msg->msg_iter.type & ITER_KVEC;

        flags |= nonblock;

        if (unlikely(flags & MSG_ERRQUEUE))
                return sock_recv_errqueue(sk, msg, len, SOL_IP, IP_RECVERR);

        lock_sock(sk);

        target = sock_rcvlowat(sk, flags & MSG_WAITALL, len);
        timeo = sock_rcvtimeo(sk, flags & MSG_DONTWAIT);
        do {
                bool zc = false;
                int chunk = 0;

                skb = tls_wait_data(sk, flags, timeo, &err);
                if (!skb)
                        goto recv_end;

                rxm = strp_msg(skb);
                if (!cmsg) {
                        int cerr;

                        cerr = put_cmsg(msg, SOL_TLS, TLS_GET_RECORD_TYPE,
                                        sizeof(ctx->control), &ctx->control);
                        cmsg = true;
                        control = ctx->control;
                        if (ctx->control != TLS_RECORD_TYPE_DATA) {
                                if (cerr || msg->msg_flags & MSG_CTRUNC) {
                                        err = -EIO;
                                        goto recv_end;
                                }
                        }
                } else if (control != ctx->control) {
                        goto recv_end;
                }

                if (!ctx->decrypted) {
                        int to_copy = rxm->full_len - tls_ctx->rx.overhead_size;

                        if (!is_kvec && to_copy <= len &&
                            likely(!(flags & MSG_PEEK)))
                                zc = true;

                        err = decrypt_skb_update(sk, skb, &msg->msg_iter,
                                                 &chunk, &zc);
                        if (err < 0) {
                                tls_err_abort(sk, EBADMSG);
                                goto recv_end;
                        }
                        ctx->decrypted = true;
                }

                if (!zc) {
                        chunk = min_t(unsigned int, rxm->full_len, len);
                        err = skb_copy_datagram_msg(skb, rxm->offset, msg,
                                                    chunk);
                        if (err < 0)
                                goto recv_end;
                }

                copied += chunk;
                len -= chunk;
                if (likely(!(flags & MSG_PEEK))) {
                        u8 control = ctx->control;

                        if (tls_sw_advance_skb(sk, skb, chunk)) {
                                /* Return full control message to
                                 * userspace before trying to parse
                                 * another message type
                                 */
                                msg->msg_flags |= MSG_EOR;
                                if (control != TLS_RECORD_TYPE_DATA)
                                        goto recv_end;
                        }
                } else {
                        /* MSG_PEEK right now cannot look beyond the current
                         * skb from strparser, meaning we cannot advance the
                         * skb here and thus unpause strparser, since we
                         * would lose the original one.
                         */
                        break;
                }

                /* If we have a new message from strparser, continue now. */
                if (copied >= target && !ctx->recv_pkt)
                        break;
        } while (len);

recv_end:
        release_sock(sk);
        return copied ? : err;
}

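/* splice_read() for a TLS_SW socket.  Only data records can be spliced;
 * a control record fails with -ENOTSUPP.  At most one record is consumed
 * per call.
 */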
ssize_t tls_sw_splice_read(struct socket *sock, loff_t *ppos,
                           struct pipe_inode_info *pipe,
                           size_t len, unsigned int flags)
{
        struct tls_context *tls_ctx = tls_get_ctx(sock->sk);
        struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
        struct strp_msg *rxm = NULL;
        struct sock *sk = sock->sk;
        struct sk_buff *skb;
        ssize_t copied = 0;
        int err = 0;
        long timeo;
        int chunk;
        bool zc = false;

        lock_sock(sk);

        timeo = sock_rcvtimeo(sk, flags & MSG_DONTWAIT);

        skb = tls_wait_data(sk, flags, timeo, &err);
        if (!skb)
                goto splice_read_end;

        /* splice does not support reading control messages */
        if (ctx->control != TLS_RECORD_TYPE_DATA) {
                err = -ENOTSUPP;
                goto splice_read_end;
        }

        if (!ctx->decrypted) {
                err = decrypt_skb_update(sk, skb, NULL, &chunk, &zc);

                if (err < 0) {
                        tls_err_abort(sk, EBADMSG);
                        goto splice_read_end;
                }
                ctx->decrypted = true;
        }
        rxm = strp_msg(skb);

        chunk = min_t(unsigned int, rxm->full_len, len);
        copied = skb_splice_bits(skb, sk, rxm->offset, pipe, chunk, flags);
        if (copied < 0)
                goto splice_read_end;

        if (likely(!(flags & MSG_PEEK)))
                tls_sw_advance_skb(sk, skb, copied);

splice_read_end:
        release_sock(sk);
        return copied ? : err;
}

unsigned int tls_sw_poll(struct file *file, struct socket *sock,
                         struct poll_table_struct *wait)
{
        unsigned int ret;
        struct sock *sk = sock->sk;
        struct tls_context *tls_ctx = tls_get_ctx(sk);
        struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);

        /* Grab POLLOUT and POLLHUP from the underlying socket */
        ret = ctx->sk_poll(file, sock, wait);

        /* Clear POLLIN bits, and set based on recv_pkt */
        ret &= ~(POLLIN | POLLRDNORM);
        if (ctx->recv_pkt)
                ret |= POLLIN | POLLRDNORM;

        return ret;
}

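/* strparser parse_msg callback: inspect the 5-byte TLS record header
 * (1 byte content type, 2 bytes protocol version, 2 bytes big-endian
 * payload length) and return the full record length, 0 if more data is
 * needed, or a negative error, which also aborts the connection.
 */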
static int tls_read_size(struct strparser *strp, struct sk_buff *skb)
{
        struct tls_context *tls_ctx = tls_get_ctx(strp->sk);
        struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
        char header[TLS_HEADER_SIZE + MAX_IV_SIZE];
        struct strp_msg *rxm = strp_msg(skb);
        size_t cipher_overhead;
        size_t data_len = 0;
        int ret;

        /* Verify that we have a full TLS header, or wait for more data */
        if (rxm->offset + tls_ctx->rx.prepend_size > skb->len)
                return 0;

        /* Sanity-check size of on-stack buffer. */
        if (WARN_ON(tls_ctx->rx.prepend_size > sizeof(header))) {
                ret = -EINVAL;
                goto read_failure;
        }

        /* Linearize header to local buffer */
        ret = skb_copy_bits(skb, rxm->offset, header, tls_ctx->rx.prepend_size);

        if (ret < 0)
                goto read_failure;

        ctx->control = header[0];

        data_len = ((header[4] & 0xFF) | (header[3] << 8));

        cipher_overhead = tls_ctx->rx.tag_size + tls_ctx->rx.iv_size;

        if (data_len > TLS_MAX_PAYLOAD_SIZE + cipher_overhead) {
                ret = -EMSGSIZE;
                goto read_failure;
        }
        if (data_len < cipher_overhead) {
                ret = -EBADMSG;
                goto read_failure;
        }

        if (header[1] != TLS_VERSION_MINOR(tls_ctx->crypto_recv.info.version) ||
            header[2] != TLS_VERSION_MAJOR(tls_ctx->crypto_recv.info.version)) {
                ret = -EINVAL;
                goto read_failure;
        }

#ifdef CONFIG_TLS_DEVICE
        handle_device_resync(strp->sk, TCP_SKB_CB(skb)->seq + rxm->offset,
                             *(u64 *)tls_ctx->rx.rec_seq);
#endif
        return data_len + TLS_HEADER_SIZE;

read_failure:
        tls_err_abort(strp->sk, ret);

        return ret;
}

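/* strparser rcv_msg callback: stash the complete record and pause the
 * parser until a reader consumes it, then run the original data_ready
 * handler so blocked readers wake up.
 */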
static void tls_queue(struct strparser *strp, struct sk_buff *skb)
{
        struct tls_context *tls_ctx = tls_get_ctx(strp->sk);
        struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);

        ctx->decrypted = false;

        ctx->recv_pkt = skb;
        strp_pause(strp);

        ctx->saved_data_ready(strp->sk);
}

static void tls_data_ready(struct sock *sk)
{
        struct tls_context *tls_ctx = tls_get_ctx(sk);
        struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);

        strp_data_ready(&ctx->strp);
}

void tls_sw_free_resources_tx(struct sock *sk)
{
        struct tls_context *tls_ctx = tls_get_ctx(sk);
        struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);

        crypto_free_aead(ctx->aead_send);
        tls_free_both_sg(sk);

        kfree(ctx);
}

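/* Tear down the RX side.  The socket lock is dropped around strp_done(),
 * presumably because it flushes strparser work that itself takes the
 * socket lock and would otherwise deadlock.
 */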
void tls_sw_release_resources_rx(struct sock *sk)
{
        struct tls_context *tls_ctx = tls_get_ctx(sk);
        struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);

        kfree(tls_ctx->rx.rec_seq);
        kfree(tls_ctx->rx.iv);

        if (ctx->aead_recv) {
                kfree_skb(ctx->recv_pkt);
                ctx->recv_pkt = NULL;
                crypto_free_aead(ctx->aead_recv);
                strp_stop(&ctx->strp);
                write_lock_bh(&sk->sk_callback_lock);
                sk->sk_data_ready = ctx->saved_data_ready;
                write_unlock_bh(&sk->sk_callback_lock);
                release_sock(sk);
                strp_done(&ctx->strp);
                lock_sock(sk);
        }
}

void tls_sw_free_resources_rx(struct sock *sk)
{
        struct tls_context *tls_ctx = tls_get_ctx(sk);
        struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);

        tls_sw_release_resources_rx(sk);

        kfree(ctx);
}

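/* Enable software TLS for one direction of the socket.  Validates the
 * crypto parameters passed in via setsockopt(), allocates the
 * per-direction context, IV and record-sequence buffers, programs the
 * AEAD transform, and for RX installs the strparser and data_ready hook.
 */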
int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx)
{
        struct tls_crypto_info *crypto_info;
        struct tls12_crypto_info_aes_gcm_128 *gcm_128_info;
        struct tls_sw_context_tx *sw_ctx_tx = NULL;
        struct tls_sw_context_rx *sw_ctx_rx = NULL;
        struct cipher_context *cctx;
        struct crypto_aead **aead;
        struct strp_callbacks cb;
        u16 nonce_size, tag_size, iv_size, rec_seq_size;
        char *iv, *rec_seq;
        int rc = 0;

        if (!ctx) {
                rc = -EINVAL;
                goto out;
        }

        if (tx) {
                if (!ctx->priv_ctx_tx) {
                        sw_ctx_tx = kzalloc(sizeof(*sw_ctx_tx), GFP_KERNEL);
                        if (!sw_ctx_tx) {
                                rc = -ENOMEM;
                                goto out;
                        }
                        ctx->priv_ctx_tx = sw_ctx_tx;
                } else {
                        sw_ctx_tx =
                                (struct tls_sw_context_tx *)ctx->priv_ctx_tx;
                }
        } else {
                if (!ctx->priv_ctx_rx) {
                        sw_ctx_rx = kzalloc(sizeof(*sw_ctx_rx), GFP_KERNEL);
                        if (!sw_ctx_rx) {
                                rc = -ENOMEM;
                                goto out;
                        }
                        ctx->priv_ctx_rx = sw_ctx_rx;
                } else {
                        sw_ctx_rx =
                                (struct tls_sw_context_rx *)ctx->priv_ctx_rx;
                }
        }

        if (tx) {
                crypto_init_wait(&sw_ctx_tx->async_wait);
                crypto_info = &ctx->crypto_send.info;
                cctx = &ctx->tx;
                aead = &sw_ctx_tx->aead_send;
        } else {
                crypto_init_wait(&sw_ctx_rx->async_wait);
                crypto_info = &ctx->crypto_recv.info;
                cctx = &ctx->rx;
                aead = &sw_ctx_rx->aead_recv;
        }

        switch (crypto_info->cipher_type) {
        case TLS_CIPHER_AES_GCM_128: {
                nonce_size = TLS_CIPHER_AES_GCM_128_IV_SIZE;
                tag_size = TLS_CIPHER_AES_GCM_128_TAG_SIZE;
                iv_size = TLS_CIPHER_AES_GCM_128_IV_SIZE;
                iv = ((struct tls12_crypto_info_aes_gcm_128 *)crypto_info)->iv;
                rec_seq_size = TLS_CIPHER_AES_GCM_128_REC_SEQ_SIZE;
                rec_seq =
                 ((struct tls12_crypto_info_aes_gcm_128 *)crypto_info)->rec_seq;
                gcm_128_info =
                        (struct tls12_crypto_info_aes_gcm_128 *)crypto_info;
                break;
        }
        default:
                rc = -EINVAL;
                goto free_priv;
        }

        /* Sanity-check the IV size for stack allocations. */
        if (iv_size > MAX_IV_SIZE || nonce_size > MAX_IV_SIZE) {
                rc = -EINVAL;
                goto free_priv;
        }

        cctx->prepend_size = TLS_HEADER_SIZE + nonce_size;
        cctx->tag_size = tag_size;
        cctx->overhead_size = cctx->prepend_size + cctx->tag_size;
        cctx->iv_size = iv_size;
        cctx->iv = kmalloc(iv_size + TLS_CIPHER_AES_GCM_128_SALT_SIZE,
                           GFP_KERNEL);
        if (!cctx->iv) {
                rc = -ENOMEM;
                goto free_priv;
        }
        memcpy(cctx->iv, gcm_128_info->salt, TLS_CIPHER_AES_GCM_128_SALT_SIZE);
        memcpy(cctx->iv + TLS_CIPHER_AES_GCM_128_SALT_SIZE, iv, iv_size);
        cctx->rec_seq_size = rec_seq_size;
        cctx->rec_seq = kmemdup(rec_seq, rec_seq_size, GFP_KERNEL);
        if (!cctx->rec_seq) {
                rc = -ENOMEM;
                goto free_iv;
        }

        if (sw_ctx_tx) {
                sg_init_table(sw_ctx_tx->sg_encrypted_data,
                              ARRAY_SIZE(sw_ctx_tx->sg_encrypted_data));
                sg_init_table(sw_ctx_tx->sg_plaintext_data,
                              ARRAY_SIZE(sw_ctx_tx->sg_plaintext_data));

                sg_init_table(sw_ctx_tx->sg_aead_in, 2);
                sg_set_buf(&sw_ctx_tx->sg_aead_in[0], sw_ctx_tx->aad_space,
                           sizeof(sw_ctx_tx->aad_space));
                sg_unmark_end(&sw_ctx_tx->sg_aead_in[1]);
                sg_chain(sw_ctx_tx->sg_aead_in, 2,
                         sw_ctx_tx->sg_plaintext_data);
                sg_init_table(sw_ctx_tx->sg_aead_out, 2);
                sg_set_buf(&sw_ctx_tx->sg_aead_out[0], sw_ctx_tx->aad_space,
                           sizeof(sw_ctx_tx->aad_space));
                sg_unmark_end(&sw_ctx_tx->sg_aead_out[1]);
                sg_chain(sw_ctx_tx->sg_aead_out, 2,
                         sw_ctx_tx->sg_encrypted_data);
        }

        if (!*aead) {
                *aead = crypto_alloc_aead("gcm(aes)", 0, 0);
                if (IS_ERR(*aead)) {
                        rc = PTR_ERR(*aead);
                        *aead = NULL;
                        goto free_rec_seq;
                }
        }

        ctx->push_pending_record = tls_sw_push_pending_record;

        rc = crypto_aead_setkey(*aead, gcm_128_info->key,
                                TLS_CIPHER_AES_GCM_128_KEY_SIZE);
        if (rc)
                goto free_aead;

        rc = crypto_aead_setauthsize(*aead, cctx->tag_size);
        if (rc)
                goto free_aead;

        if (sw_ctx_rx) {
                /* Set up strparser */
                memset(&cb, 0, sizeof(cb));
                cb.rcv_msg = tls_queue;
                cb.parse_msg = tls_read_size;

                strp_init(&sw_ctx_rx->strp, sk, &cb);

                write_lock_bh(&sk->sk_callback_lock);
                sw_ctx_rx->saved_data_ready = sk->sk_data_ready;
                sk->sk_data_ready = tls_data_ready;
                write_unlock_bh(&sk->sk_callback_lock);

                sw_ctx_rx->sk_poll = sk->sk_socket->ops->poll;

                strp_check_rcv(&sw_ctx_rx->strp);
        }

        goto out;

free_aead:
        crypto_free_aead(*aead);
        *aead = NULL;
free_rec_seq:
        kfree(cctx->rec_seq);
        cctx->rec_seq = NULL;
free_iv:
        kfree(cctx->iv);
        cctx->iv = NULL;
free_priv:
        if (tx) {
                kfree(ctx->priv_ctx_tx);
                ctx->priv_ctx_tx = NULL;
        } else {
                kfree(ctx->priv_ctx_rx);
                ctx->priv_ctx_rx = NULL;
        }
out:
        return rc;
}