GNU Linux-libre 4.19.286-gnu1
[releases.git] / kernel / bpf / sockmap.c
1 /* Copyright (c) 2017 Covalent IO, Inc. http://covalent.io
2  *
3  * This program is free software; you can redistribute it and/or
4  * modify it under the terms of version 2 of the GNU General Public
5  * License as published by the Free Software Foundation.
6  *
7  * This program is distributed in the hope that it will be useful, but
8  * WITHOUT ANY WARRANTY; without even the implied warranty of
9  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
10  * General Public License for more details.
11  */
12
13 /* A BPF sock_map is used to store sock objects. This is primarly used
14  * for doing socket redirect with BPF helper routines.
15  *
16  * A sock map may have BPF programs attached to it, currently a program
17  * used to parse packets and a program to provide a verdict and redirect
18  * decision on the packet are supported. Any programs attached to a sock
19  * map are inherited by sock objects when they are added to the map. If
20  * no BPF programs are attached the sock object may only be used for sock
21  * redirect.
22  *
23  * A sock object may be in multiple maps, but can only inherit a single
24  * parse or verdict program. If adding a sock object to a map would result
25  * in having multiple parsing programs the update will return an EBUSY error.
26  *
27  * For reference this program is similar to devmap used in XDP context
28  * reviewing these together may be useful. For an example please review
29  * ./samples/bpf/sockmap/.
30  */
31 #include <linux/bpf.h>
32 #include <net/sock.h>
33 #include <linux/filter.h>
34 #include <linux/errno.h>
35 #include <linux/file.h>
36 #include <linux/kernel.h>
37 #include <linux/net.h>
38 #include <linux/skbuff.h>
39 #include <linux/workqueue.h>
40 #include <linux/list.h>
41 #include <linux/mm.h>
42 #include <net/strparser.h>
43 #include <net/tcp.h>
44 #include <linux/ptr_ring.h>
45 #include <net/inet_common.h>
46 #include <linux/sched/signal.h>
47
48 #define SOCK_CREATE_FLAG_MASK \
49         (BPF_F_NUMA_NODE | BPF_F_RDONLY | BPF_F_WRONLY)
50
51 struct bpf_sock_progs {
52         struct bpf_prog *bpf_tx_msg;
53         struct bpf_prog *bpf_parse;
54         struct bpf_prog *bpf_verdict;
55 };
56
57 struct bpf_stab {
58         struct bpf_map map;
59         struct sock **sock_map;
60         struct bpf_sock_progs progs;
61         raw_spinlock_t lock;
62 };
63
64 struct bucket {
65         struct hlist_head head;
66         raw_spinlock_t lock;
67 };
68
69 struct bpf_htab {
70         struct bpf_map map;
71         struct bucket *buckets;
72         atomic_t count;
73         u32 n_buckets;
74         u32 elem_size;
75         struct bpf_sock_progs progs;
76         struct rcu_head rcu;
77 };
78
79 struct htab_elem {
80         struct rcu_head rcu;
81         struct hlist_node hash_node;
82         u32 hash;
83         struct sock *sk;
84         char key[0];
85 };
86
87 enum smap_psock_state {
88         SMAP_TX_RUNNING,
89 };
90
91 struct smap_psock_map_entry {
92         struct list_head list;
93         struct bpf_map *map;
94         struct sock **entry;
95         struct htab_elem __rcu *hash_link;
96 };
97
98 struct smap_psock {
99         struct rcu_head rcu;
100         refcount_t refcnt;
101
102         /* datapath variables */
103         struct sk_buff_head rxqueue;
104         bool strp_enabled;
105
106         /* datapath error path cache across tx work invocations */
107         int save_rem;
108         int save_off;
109         struct sk_buff *save_skb;
110
111         /* datapath variables for tx_msg ULP */
112         struct sock *sk_redir;
113         int apply_bytes;
114         int cork_bytes;
115         int sg_size;
116         int eval;
117         struct sk_msg_buff *cork;
118         struct list_head ingress;
119
120         struct strparser strp;
121         struct bpf_prog *bpf_tx_msg;
122         struct bpf_prog *bpf_parse;
123         struct bpf_prog *bpf_verdict;
124         struct list_head maps;
125         spinlock_t maps_lock;
126
127         /* Back reference used when sock callback trigger sockmap operations */
128         struct sock *sock;
129         unsigned long state;
130
131         struct work_struct tx_work;
132         struct work_struct gc_work;
133
134         struct proto *sk_proto;
135         void (*save_unhash)(struct sock *sk);
136         void (*save_close)(struct sock *sk, long timeout);
137         void (*save_data_ready)(struct sock *sk);
138         void (*save_write_space)(struct sock *sk);
139 };
140
141 static void smap_release_sock(struct smap_psock *psock, struct sock *sock);
142 static int bpf_tcp_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
143                            int nonblock, int flags, int *addr_len);
144 static int bpf_tcp_sendmsg(struct sock *sk, struct msghdr *msg, size_t size);
145 static int bpf_tcp_sendpage(struct sock *sk, struct page *page,
146                             int offset, size_t size, int flags);
147 static void bpf_tcp_unhash(struct sock *sk);
148 static void bpf_tcp_close(struct sock *sk, long timeout);
149
150 static inline struct smap_psock *smap_psock_sk(const struct sock *sk)
151 {
152         return rcu_dereference_sk_user_data(sk);
153 }
154
155 static bool bpf_tcp_stream_read(const struct sock *sk)
156 {
157         struct smap_psock *psock;
158         bool empty = true;
159
160         rcu_read_lock();
161         psock = smap_psock_sk(sk);
162         if (unlikely(!psock))
163                 goto out;
164         empty = list_empty(&psock->ingress);
165 out:
166         rcu_read_unlock();
167         return !empty;
168 }
169
170 enum {
171         SOCKMAP_IPV4,
172         SOCKMAP_IPV6,
173         SOCKMAP_NUM_PROTS,
174 };
175
176 enum {
177         SOCKMAP_BASE,
178         SOCKMAP_TX,
179         SOCKMAP_NUM_CONFIGS,
180 };
181
182 static struct proto *saved_tcpv6_prot __read_mostly;
183 static DEFINE_SPINLOCK(tcpv6_prot_lock);
184 static struct proto bpf_tcp_prots[SOCKMAP_NUM_PROTS][SOCKMAP_NUM_CONFIGS];
185 static void build_protos(struct proto prot[SOCKMAP_NUM_CONFIGS],
186                          struct proto *base)
187 {
188         prot[SOCKMAP_BASE]                      = *base;
189         prot[SOCKMAP_BASE].unhash               = bpf_tcp_unhash;
190         prot[SOCKMAP_BASE].close                = bpf_tcp_close;
191         prot[SOCKMAP_BASE].recvmsg              = bpf_tcp_recvmsg;
192         prot[SOCKMAP_BASE].stream_memory_read   = bpf_tcp_stream_read;
193
194         prot[SOCKMAP_TX]                        = prot[SOCKMAP_BASE];
195         prot[SOCKMAP_TX].sendmsg                = bpf_tcp_sendmsg;
196         prot[SOCKMAP_TX].sendpage               = bpf_tcp_sendpage;
197 }
198
199 static void update_sk_prot(struct sock *sk, struct smap_psock *psock)
200 {
201         int family = sk->sk_family == AF_INET6 ? SOCKMAP_IPV6 : SOCKMAP_IPV4;
202         int conf = psock->bpf_tx_msg ? SOCKMAP_TX : SOCKMAP_BASE;
203
204         sk->sk_prot = &bpf_tcp_prots[family][conf];
205 }
206
207 static int bpf_tcp_init(struct sock *sk)
208 {
209         struct smap_psock *psock;
210
211         rcu_read_lock();
212         psock = smap_psock_sk(sk);
213         if (unlikely(!psock)) {
214                 rcu_read_unlock();
215                 return -EINVAL;
216         }
217
218         if (unlikely(psock->sk_proto)) {
219                 rcu_read_unlock();
220                 return -EBUSY;
221         }
222
223         psock->save_unhash = sk->sk_prot->unhash;
224         psock->save_close = sk->sk_prot->close;
225         psock->sk_proto = sk->sk_prot;
226
227         /* Build IPv6 sockmap whenever the address of tcpv6_prot changes */
228         if (sk->sk_family == AF_INET6 &&
229             unlikely(sk->sk_prot != smp_load_acquire(&saved_tcpv6_prot))) {
230                 spin_lock_bh(&tcpv6_prot_lock);
231                 if (likely(sk->sk_prot != saved_tcpv6_prot)) {
232                         build_protos(bpf_tcp_prots[SOCKMAP_IPV6], sk->sk_prot);
233                         smp_store_release(&saved_tcpv6_prot, sk->sk_prot);
234                 }
235                 spin_unlock_bh(&tcpv6_prot_lock);
236         }
237         update_sk_prot(sk, psock);
238         rcu_read_unlock();
239         return 0;
240 }
241
242 static void smap_release_sock(struct smap_psock *psock, struct sock *sock);
243 static int free_start_sg(struct sock *sk, struct sk_msg_buff *md, bool charge);
244
245 static void bpf_tcp_release(struct sock *sk)
246 {
247         struct smap_psock *psock;
248
249         rcu_read_lock();
250         psock = smap_psock_sk(sk);
251         if (unlikely(!psock))
252                 goto out;
253
254         if (psock->cork) {
255                 free_start_sg(psock->sock, psock->cork, true);
256                 kfree(psock->cork);
257                 psock->cork = NULL;
258         }
259
260         if (psock->sk_proto) {
261                 sk->sk_prot = psock->sk_proto;
262                 psock->sk_proto = NULL;
263         }
264 out:
265         rcu_read_unlock();
266 }
267
268 static struct htab_elem *lookup_elem_raw(struct hlist_head *head,
269                                          u32 hash, void *key, u32 key_size)
270 {
271         struct htab_elem *l;
272
273         hlist_for_each_entry_rcu(l, head, hash_node) {
274                 if (l->hash == hash && !memcmp(&l->key, key, key_size))
275                         return l;
276         }
277
278         return NULL;
279 }
280
281 static inline struct bucket *__select_bucket(struct bpf_htab *htab, u32 hash)
282 {
283         return &htab->buckets[hash & (htab->n_buckets - 1)];
284 }
285
286 static inline struct hlist_head *select_bucket(struct bpf_htab *htab, u32 hash)
287 {
288         return &__select_bucket(htab, hash)->head;
289 }
290
291 static void free_htab_elem(struct bpf_htab *htab, struct htab_elem *l)
292 {
293         atomic_dec(&htab->count);
294         kfree_rcu(l, rcu);
295 }
296
297 static struct smap_psock_map_entry *psock_map_pop(struct sock *sk,
298                                                   struct smap_psock *psock)
299 {
300         struct smap_psock_map_entry *e;
301
302         spin_lock_bh(&psock->maps_lock);
303         e = list_first_entry_or_null(&psock->maps,
304                                      struct smap_psock_map_entry,
305                                      list);
306         if (e)
307                 list_del(&e->list);
308         spin_unlock_bh(&psock->maps_lock);
309         return e;
310 }
311
312 static void bpf_tcp_remove(struct sock *sk, struct smap_psock *psock)
313 {
314         struct smap_psock_map_entry *e;
315         struct sk_msg_buff *md, *mtmp;
316         struct sock *osk;
317
318         if (psock->cork) {
319                 free_start_sg(psock->sock, psock->cork, true);
320                 kfree(psock->cork);
321                 psock->cork = NULL;
322         }
323
324         list_for_each_entry_safe(md, mtmp, &psock->ingress, list) {
325                 list_del(&md->list);
326                 free_start_sg(psock->sock, md, true);
327                 kfree(md);
328         }
329
330         e = psock_map_pop(sk, psock);
331         while (e) {
332                 if (e->entry) {
333                         struct bpf_stab *stab = container_of(e->map, struct bpf_stab, map);
334
335                         raw_spin_lock_bh(&stab->lock);
336                         osk = *e->entry;
337                         if (osk == sk) {
338                                 *e->entry = NULL;
339                                 smap_release_sock(psock, sk);
340                         }
341                         raw_spin_unlock_bh(&stab->lock);
342                 } else {
343                         struct htab_elem *link = rcu_dereference(e->hash_link);
344                         struct bpf_htab *htab = container_of(e->map, struct bpf_htab, map);
345                         struct hlist_head *head;
346                         struct htab_elem *l;
347                         struct bucket *b;
348
349                         b = __select_bucket(htab, link->hash);
350                         head = &b->head;
351                         raw_spin_lock_bh(&b->lock);
352                         l = lookup_elem_raw(head,
353                                             link->hash, link->key,
354                                             htab->map.key_size);
355                         /* If another thread deleted this object skip deletion.
356                          * The refcnt on psock may or may not be zero.
357                          */
358                         if (l && l == link) {
359                                 hlist_del_rcu(&link->hash_node);
360                                 smap_release_sock(psock, link->sk);
361                                 free_htab_elem(htab, link);
362                         }
363                         raw_spin_unlock_bh(&b->lock);
364                 }
365                 kfree(e);
366                 e = psock_map_pop(sk, psock);
367         }
368 }
369
370 static void bpf_tcp_unhash(struct sock *sk)
371 {
372         void (*unhash_fun)(struct sock *sk);
373         struct smap_psock *psock;
374
375         rcu_read_lock();
376         psock = smap_psock_sk(sk);
377         if (unlikely(!psock)) {
378                 rcu_read_unlock();
379                 if (sk->sk_prot->unhash)
380                         sk->sk_prot->unhash(sk);
381                 return;
382         }
383         unhash_fun = psock->save_unhash;
384         bpf_tcp_remove(sk, psock);
385         rcu_read_unlock();
386         unhash_fun(sk);
387 }
388
389 static void bpf_tcp_close(struct sock *sk, long timeout)
390 {
391         void (*close_fun)(struct sock *sk, long timeout);
392         struct smap_psock *psock;
393
394         lock_sock(sk);
395         rcu_read_lock();
396         psock = smap_psock_sk(sk);
397         if (unlikely(!psock)) {
398                 rcu_read_unlock();
399                 release_sock(sk);
400                 return sk->sk_prot->close(sk, timeout);
401         }
402         close_fun = psock->save_close;
403         bpf_tcp_remove(sk, psock);
404         rcu_read_unlock();
405         release_sock(sk);
406         close_fun(sk, timeout);
407 }
408
409 enum __sk_action {
410         __SK_DROP = 0,
411         __SK_PASS,
412         __SK_REDIRECT,
413         __SK_NONE,
414 };
415
416 static struct tcp_ulp_ops bpf_tcp_ulp_ops __read_mostly = {
417         .name           = "bpf_tcp",
418         .uid            = TCP_ULP_BPF,
419         .user_visible   = false,
420         .owner          = NULL,
421         .init           = bpf_tcp_init,
422         .release        = bpf_tcp_release,
423 };
424
425 static int memcopy_from_iter(struct sock *sk,
426                              struct sk_msg_buff *md,
427                              struct iov_iter *from, int bytes)
428 {
429         struct scatterlist *sg = md->sg_data;
430         int i = md->sg_curr, rc = -ENOSPC;
431
432         do {
433                 int copy;
434                 char *to;
435
436                 if (md->sg_copybreak >= sg[i].length) {
437                         md->sg_copybreak = 0;
438
439                         if (++i == MAX_SKB_FRAGS)
440                                 i = 0;
441
442                         if (i == md->sg_end)
443                                 break;
444                 }
445
446                 copy = sg[i].length - md->sg_copybreak;
447                 to = sg_virt(&sg[i]) + md->sg_copybreak;
448                 md->sg_copybreak += copy;
449
450                 if (sk->sk_route_caps & NETIF_F_NOCACHE_COPY)
451                         rc = copy_from_iter_nocache(to, copy, from);
452                 else
453                         rc = copy_from_iter(to, copy, from);
454
455                 if (rc != copy) {
456                         rc = -EFAULT;
457                         goto out;
458                 }
459
460                 bytes -= copy;
461                 if (!bytes)
462                         break;
463
464                 md->sg_copybreak = 0;
465                 if (++i == MAX_SKB_FRAGS)
466                         i = 0;
467         } while (i != md->sg_end);
468 out:
469         md->sg_curr = i;
470         return rc;
471 }
472
473 static int bpf_tcp_push(struct sock *sk, int apply_bytes,
474                         struct sk_msg_buff *md,
475                         int flags, bool uncharge)
476 {
477         bool apply = apply_bytes;
478         struct scatterlist *sg;
479         int offset, ret = 0;
480         struct page *p;
481         size_t size;
482
483         while (1) {
484                 sg = md->sg_data + md->sg_start;
485                 size = (apply && apply_bytes < sg->length) ?
486                         apply_bytes : sg->length;
487                 offset = sg->offset;
488
489                 tcp_rate_check_app_limited(sk);
490                 p = sg_page(sg);
491 retry:
492                 ret = do_tcp_sendpages(sk, p, offset, size, flags);
493                 if (ret != size) {
494                         if (ret > 0) {
495                                 if (apply)
496                                         apply_bytes -= ret;
497
498                                 sg->offset += ret;
499                                 sg->length -= ret;
500                                 size -= ret;
501                                 offset += ret;
502                                 if (uncharge)
503                                         sk_mem_uncharge(sk, ret);
504                                 goto retry;
505                         }
506
507                         return ret;
508                 }
509
510                 if (apply)
511                         apply_bytes -= ret;
512                 sg->offset += ret;
513                 sg->length -= ret;
514                 if (uncharge)
515                         sk_mem_uncharge(sk, ret);
516
517                 if (!sg->length) {
518                         put_page(p);
519                         md->sg_start++;
520                         if (md->sg_start == MAX_SKB_FRAGS)
521                                 md->sg_start = 0;
522                         sg_init_table(sg, 1);
523
524                         if (md->sg_start == md->sg_end)
525                                 break;
526                 }
527
528                 if (apply && !apply_bytes)
529                         break;
530         }
531         return 0;
532 }
533
534 static inline void bpf_compute_data_pointers_sg(struct sk_msg_buff *md)
535 {
536         struct scatterlist *sg = md->sg_data + md->sg_start;
537
538         if (md->sg_copy[md->sg_start]) {
539                 md->data = md->data_end = 0;
540         } else {
541                 md->data = sg_virt(sg);
542                 md->data_end = md->data + sg->length;
543         }
544 }
545
546 static void return_mem_sg(struct sock *sk, int bytes, struct sk_msg_buff *md)
547 {
548         struct scatterlist *sg = md->sg_data;
549         int i = md->sg_start;
550
551         do {
552                 int uncharge = (bytes < sg[i].length) ? bytes : sg[i].length;
553
554                 sk_mem_uncharge(sk, uncharge);
555                 bytes -= uncharge;
556                 if (!bytes)
557                         break;
558                 i++;
559                 if (i == MAX_SKB_FRAGS)
560                         i = 0;
561         } while (i != md->sg_end);
562 }
563
564 static void free_bytes_sg(struct sock *sk, int bytes,
565                           struct sk_msg_buff *md, bool charge)
566 {
567         struct scatterlist *sg = md->sg_data;
568         int i = md->sg_start, free;
569
570         while (bytes && sg[i].length) {
571                 free = sg[i].length;
572                 if (bytes < free) {
573                         sg[i].length -= bytes;
574                         sg[i].offset += bytes;
575                         if (charge)
576                                 sk_mem_uncharge(sk, bytes);
577                         break;
578                 }
579
580                 if (charge)
581                         sk_mem_uncharge(sk, sg[i].length);
582                 put_page(sg_page(&sg[i]));
583                 bytes -= sg[i].length;
584                 sg[i].length = 0;
585                 sg[i].page_link = 0;
586                 sg[i].offset = 0;
587                 i++;
588
589                 if (i == MAX_SKB_FRAGS)
590                         i = 0;
591         }
592         md->sg_start = i;
593 }
594
595 static int free_sg(struct sock *sk, int start,
596                    struct sk_msg_buff *md, bool charge)
597 {
598         struct scatterlist *sg = md->sg_data;
599         int i = start, free = 0;
600
601         while (sg[i].length) {
602                 free += sg[i].length;
603                 if (charge)
604                         sk_mem_uncharge(sk, sg[i].length);
605                 if (!md->skb)
606                         put_page(sg_page(&sg[i]));
607                 sg[i].length = 0;
608                 sg[i].page_link = 0;
609                 sg[i].offset = 0;
610                 i++;
611
612                 if (i == MAX_SKB_FRAGS)
613                         i = 0;
614         }
615         if (md->skb)
616                 consume_skb(md->skb);
617
618         return free;
619 }
620
621 static int free_start_sg(struct sock *sk, struct sk_msg_buff *md, bool charge)
622 {
623         int free = free_sg(sk, md->sg_start, md, charge);
624
625         md->sg_start = md->sg_end;
626         return free;
627 }
628
629 static int free_curr_sg(struct sock *sk, struct sk_msg_buff *md)
630 {
631         return free_sg(sk, md->sg_curr, md, true);
632 }
633
634 static int bpf_map_msg_verdict(int _rc, struct sk_msg_buff *md)
635 {
636         return ((_rc == SK_PASS) ?
637                (md->sk_redir ? __SK_REDIRECT : __SK_PASS) :
638                __SK_DROP);
639 }
640
641 static unsigned int smap_do_tx_msg(struct sock *sk,
642                                    struct smap_psock *psock,
643                                    struct sk_msg_buff *md)
644 {
645         struct bpf_prog *prog;
646         unsigned int rc, _rc;
647
648         preempt_disable();
649         rcu_read_lock();
650
651         /* If the policy was removed mid-send then default to 'accept' */
652         prog = READ_ONCE(psock->bpf_tx_msg);
653         if (unlikely(!prog)) {
654                 _rc = SK_PASS;
655                 goto verdict;
656         }
657
658         bpf_compute_data_pointers_sg(md);
659         md->sk = sk;
660         rc = (*prog->bpf_func)(md, prog->insnsi);
661         psock->apply_bytes = md->apply_bytes;
662
663         /* Moving return codes from UAPI namespace into internal namespace */
664         _rc = bpf_map_msg_verdict(rc, md);
665
666         /* The psock has a refcount on the sock but not on the map and because
667          * we need to drop rcu read lock here its possible the map could be
668          * removed between here and when we need it to execute the sock
669          * redirect. So do the map lookup now for future use.
670          */
671         if (_rc == __SK_REDIRECT) {
672                 if (psock->sk_redir)
673                         sock_put(psock->sk_redir);
674                 psock->sk_redir = do_msg_redirect_map(md);
675                 if (!psock->sk_redir) {
676                         _rc = __SK_DROP;
677                         goto verdict;
678                 }
679                 sock_hold(psock->sk_redir);
680         }
681 verdict:
682         rcu_read_unlock();
683         preempt_enable();
684
685         return _rc;
686 }
687
688 static int bpf_tcp_ingress(struct sock *sk, int apply_bytes,
689                            struct smap_psock *psock,
690                            struct sk_msg_buff *md, int flags)
691 {
692         bool apply = apply_bytes;
693         size_t size, copied = 0;
694         struct sk_msg_buff *r;
695         int err = 0, i;
696
697         r = kzalloc(sizeof(struct sk_msg_buff), __GFP_NOWARN | GFP_KERNEL);
698         if (unlikely(!r))
699                 return -ENOMEM;
700
701         lock_sock(sk);
702         r->sg_start = md->sg_start;
703         i = md->sg_start;
704
705         do {
706                 size = (apply && apply_bytes < md->sg_data[i].length) ?
707                         apply_bytes : md->sg_data[i].length;
708
709                 if (!sk_wmem_schedule(sk, size)) {
710                         if (!copied)
711                                 err = -ENOMEM;
712                         break;
713                 }
714
715                 sk_mem_charge(sk, size);
716                 r->sg_data[i] = md->sg_data[i];
717                 r->sg_data[i].length = size;
718                 md->sg_data[i].length -= size;
719                 md->sg_data[i].offset += size;
720                 copied += size;
721
722                 if (md->sg_data[i].length) {
723                         get_page(sg_page(&r->sg_data[i]));
724                         r->sg_end = (i + 1) == MAX_SKB_FRAGS ? 0 : i + 1;
725                 } else {
726                         i++;
727                         if (i == MAX_SKB_FRAGS)
728                                 i = 0;
729                         r->sg_end = i;
730                 }
731
732                 if (apply) {
733                         apply_bytes -= size;
734                         if (!apply_bytes)
735                                 break;
736                 }
737         } while (i != md->sg_end);
738
739         md->sg_start = i;
740
741         if (!err) {
742                 list_add_tail(&r->list, &psock->ingress);
743                 sk->sk_data_ready(sk);
744         } else {
745                 free_start_sg(sk, r, true);
746                 kfree(r);
747         }
748
749         release_sock(sk);
750         return err;
751 }
752
753 static int bpf_tcp_sendmsg_do_redirect(struct sock *sk, int send,
754                                        struct sk_msg_buff *md,
755                                        int flags)
756 {
757         bool ingress = !!(md->flags & BPF_F_INGRESS);
758         struct smap_psock *psock;
759         int err = 0;
760
761         rcu_read_lock();
762         psock = smap_psock_sk(sk);
763         if (unlikely(!psock))
764                 goto out_rcu;
765
766         if (!refcount_inc_not_zero(&psock->refcnt))
767                 goto out_rcu;
768
769         rcu_read_unlock();
770
771         if (ingress) {
772                 err = bpf_tcp_ingress(sk, send, psock, md, flags);
773         } else {
774                 lock_sock(sk);
775                 err = bpf_tcp_push(sk, send, md, flags, false);
776                 release_sock(sk);
777         }
778         smap_release_sock(psock, sk);
779         return err;
780 out_rcu:
781         rcu_read_unlock();
782         return 0;
783 }
784
785 static inline void bpf_md_init(struct smap_psock *psock)
786 {
787         if (!psock->apply_bytes) {
788                 psock->eval =  __SK_NONE;
789                 if (psock->sk_redir) {
790                         sock_put(psock->sk_redir);
791                         psock->sk_redir = NULL;
792                 }
793         }
794 }
795
796 static void apply_bytes_dec(struct smap_psock *psock, int i)
797 {
798         if (psock->apply_bytes) {
799                 if (psock->apply_bytes < i)
800                         psock->apply_bytes = 0;
801                 else
802                         psock->apply_bytes -= i;
803         }
804 }
805
806 static int bpf_exec_tx_verdict(struct smap_psock *psock,
807                                struct sk_msg_buff *m,
808                                struct sock *sk,
809                                int *copied, int flags)
810 {
811         bool cork = false, enospc = (m->sg_start == m->sg_end);
812         struct sock *redir;
813         int err = 0;
814         int send;
815
816 more_data:
817         if (psock->eval == __SK_NONE)
818                 psock->eval = smap_do_tx_msg(sk, psock, m);
819
820         if (m->cork_bytes &&
821             m->cork_bytes > psock->sg_size && !enospc) {
822                 psock->cork_bytes = m->cork_bytes - psock->sg_size;
823                 if (!psock->cork) {
824                         psock->cork = kcalloc(1,
825                                         sizeof(struct sk_msg_buff),
826                                         GFP_ATOMIC | __GFP_NOWARN);
827
828                         if (!psock->cork) {
829                                 err = -ENOMEM;
830                                 goto out_err;
831                         }
832                 }
833                 memcpy(psock->cork, m, sizeof(*m));
834                 goto out_err;
835         }
836
837         send = psock->sg_size;
838         if (psock->apply_bytes && psock->apply_bytes < send)
839                 send = psock->apply_bytes;
840
841         switch (psock->eval) {
842         case __SK_PASS:
843                 err = bpf_tcp_push(sk, send, m, flags, true);
844                 if (unlikely(err)) {
845                         *copied -= free_start_sg(sk, m, true);
846                         break;
847                 }
848
849                 apply_bytes_dec(psock, send);
850                 psock->sg_size -= send;
851                 break;
852         case __SK_REDIRECT:
853                 redir = psock->sk_redir;
854                 apply_bytes_dec(psock, send);
855
856                 if (psock->cork) {
857                         cork = true;
858                         psock->cork = NULL;
859                 }
860
861                 return_mem_sg(sk, send, m);
862                 release_sock(sk);
863
864                 err = bpf_tcp_sendmsg_do_redirect(redir, send, m, flags);
865                 lock_sock(sk);
866
867                 if (unlikely(err < 0)) {
868                         int free = free_start_sg(sk, m, false);
869
870                         psock->sg_size = 0;
871                         if (!cork)
872                                 *copied -= free;
873                 } else {
874                         psock->sg_size -= send;
875                 }
876
877                 if (cork) {
878                         free_start_sg(sk, m, true);
879                         psock->sg_size = 0;
880                         kfree(m);
881                         m = NULL;
882                         err = 0;
883                 }
884                 break;
885         case __SK_DROP:
886         default:
887                 free_bytes_sg(sk, send, m, true);
888                 apply_bytes_dec(psock, send);
889                 *copied -= send;
890                 psock->sg_size -= send;
891                 err = -EACCES;
892                 break;
893         }
894
895         if (likely(!err)) {
896                 bpf_md_init(psock);
897                 if (m &&
898                     m->sg_data[m->sg_start].page_link &&
899                     m->sg_data[m->sg_start].length)
900                         goto more_data;
901         }
902
903 out_err:
904         return err;
905 }
906
907 static int bpf_wait_data(struct sock *sk,
908                          struct smap_psock *psk, int flags,
909                          long timeo, int *err)
910 {
911         int rc;
912
913         DEFINE_WAIT_FUNC(wait, woken_wake_function);
914
915         add_wait_queue(sk_sleep(sk), &wait);
916         sk_set_bit(SOCKWQ_ASYNC_WAITDATA, sk);
917         rc = sk_wait_event(sk, &timeo,
918                            !list_empty(&psk->ingress) ||
919                            !skb_queue_empty(&sk->sk_receive_queue),
920                            &wait);
921         sk_clear_bit(SOCKWQ_ASYNC_WAITDATA, sk);
922         remove_wait_queue(sk_sleep(sk), &wait);
923
924         return rc;
925 }
926
927 static int bpf_tcp_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
928                            int nonblock, int flags, int *addr_len)
929 {
930         struct iov_iter *iter = &msg->msg_iter;
931         struct smap_psock *psock;
932         int copied = 0;
933
934         if (unlikely(flags & MSG_ERRQUEUE))
935                 return inet_recv_error(sk, msg, len, addr_len);
936         if (!skb_queue_empty(&sk->sk_receive_queue))
937                 return tcp_recvmsg(sk, msg, len, nonblock, flags, addr_len);
938
939         rcu_read_lock();
940         psock = smap_psock_sk(sk);
941         if (unlikely(!psock))
942                 goto out;
943
944         if (unlikely(!refcount_inc_not_zero(&psock->refcnt)))
945                 goto out;
946         rcu_read_unlock();
947
948         lock_sock(sk);
949 bytes_ready:
950         while (copied != len) {
951                 struct scatterlist *sg;
952                 struct sk_msg_buff *md;
953                 int i;
954
955                 md = list_first_entry_or_null(&psock->ingress,
956                                               struct sk_msg_buff, list);
957                 if (unlikely(!md))
958                         break;
959                 i = md->sg_start;
960                 do {
961                         struct page *page;
962                         int n, copy;
963
964                         sg = &md->sg_data[i];
965                         copy = sg->length;
966                         page = sg_page(sg);
967
968                         if (copied + copy > len)
969                                 copy = len - copied;
970
971                         n = copy_page_to_iter(page, sg->offset, copy, iter);
972                         if (n != copy) {
973                                 md->sg_start = i;
974                                 release_sock(sk);
975                                 smap_release_sock(psock, sk);
976                                 return -EFAULT;
977                         }
978
979                         copied += copy;
980                         sg->offset += copy;
981                         sg->length -= copy;
982                         sk_mem_uncharge(sk, copy);
983
984                         if (!sg->length) {
985                                 i++;
986                                 if (i == MAX_SKB_FRAGS)
987                                         i = 0;
988                                 if (!md->skb)
989                                         put_page(page);
990                         }
991                         if (copied == len)
992                                 break;
993                 } while (i != md->sg_end);
994                 md->sg_start = i;
995
996                 if (!sg->length && md->sg_start == md->sg_end) {
997                         list_del(&md->list);
998                         if (md->skb)
999                                 consume_skb(md->skb);
1000                         kfree(md);
1001                 }
1002         }
1003
1004         if (!copied) {
1005                 long timeo;
1006                 int data;
1007                 int err = 0;
1008
1009                 timeo = sock_rcvtimeo(sk, nonblock);
1010                 data = bpf_wait_data(sk, psock, flags, timeo, &err);
1011
1012                 if (data) {
1013                         if (!skb_queue_empty(&sk->sk_receive_queue)) {
1014                                 release_sock(sk);
1015                                 smap_release_sock(psock, sk);
1016                                 copied = tcp_recvmsg(sk, msg, len, nonblock, flags, addr_len);
1017                                 return copied;
1018                         }
1019                         goto bytes_ready;
1020                 }
1021
1022                 if (err)
1023                         copied = err;
1024         }
1025
1026         release_sock(sk);
1027         smap_release_sock(psock, sk);
1028         return copied;
1029 out:
1030         rcu_read_unlock();
1031         return tcp_recvmsg(sk, msg, len, nonblock, flags, addr_len);
1032 }
1033
1034
1035 static int bpf_tcp_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
1036 {
1037         int flags = msg->msg_flags | MSG_NO_SHARED_FRAGS;
1038         struct sk_msg_buff md = {0};
1039         unsigned int sg_copy = 0;
1040         struct smap_psock *psock;
1041         int copied = 0, err = 0;
1042         struct scatterlist *sg;
1043         long timeo;
1044
1045         /* Its possible a sock event or user removed the psock _but_ the ops
1046          * have not been reprogrammed yet so we get here. In this case fallback
1047          * to tcp_sendmsg. Note this only works because we _only_ ever allow
1048          * a single ULP there is no hierarchy here.
1049          */
1050         rcu_read_lock();
1051         psock = smap_psock_sk(sk);
1052         if (unlikely(!psock)) {
1053                 rcu_read_unlock();
1054                 return tcp_sendmsg(sk, msg, size);
1055         }
1056
1057         /* Increment the psock refcnt to ensure its not released while sending a
1058          * message. Required because sk lookup and bpf programs are used in
1059          * separate rcu critical sections. Its OK if we lose the map entry
1060          * but we can't lose the sock reference.
1061          */
1062         if (!refcount_inc_not_zero(&psock->refcnt)) {
1063                 rcu_read_unlock();
1064                 return tcp_sendmsg(sk, msg, size);
1065         }
1066
1067         sg = md.sg_data;
1068         sg_init_marker(sg, MAX_SKB_FRAGS);
1069         rcu_read_unlock();
1070
1071         lock_sock(sk);
1072         timeo = sock_sndtimeo(sk, msg->msg_flags & MSG_DONTWAIT);
1073
1074         while (msg_data_left(msg)) {
1075                 struct sk_msg_buff *m = NULL;
1076                 bool enospc = false;
1077                 int copy;
1078
1079                 if (sk->sk_err) {
1080                         err = -sk->sk_err;
1081                         goto out_err;
1082                 }
1083
1084                 copy = msg_data_left(msg);
1085                 if (!sk_stream_memory_free(sk))
1086                         goto wait_for_sndbuf;
1087
1088                 m = psock->cork_bytes ? psock->cork : &md;
1089                 m->sg_curr = m->sg_copybreak ? m->sg_curr : m->sg_end;
1090                 err = sk_alloc_sg(sk, copy, m->sg_data,
1091                                   m->sg_start, &m->sg_end, &sg_copy,
1092                                   m->sg_end - 1);
1093                 if (err) {
1094                         if (err != -ENOSPC)
1095                                 goto wait_for_memory;
1096                         enospc = true;
1097                         copy = sg_copy;
1098                 }
1099
1100                 err = memcopy_from_iter(sk, m, &msg->msg_iter, copy);
1101                 if (err < 0) {
1102                         free_curr_sg(sk, m);
1103                         goto out_err;
1104                 }
1105
1106                 psock->sg_size += copy;
1107                 copied += copy;
1108                 sg_copy = 0;
1109
1110                 /* When bytes are being corked skip running BPF program and
1111                  * applying verdict unless there is no more buffer space. In
1112                  * the ENOSPC case simply run BPF prorgram with currently
1113                  * accumulated data. We don't have much choice at this point
1114                  * we could try extending the page frags or chaining complex
1115                  * frags but even in these cases _eventually_ we will hit an
1116                  * OOM scenario. More complex recovery schemes may be
1117                  * implemented in the future, but BPF programs must handle
1118                  * the case where apply_cork requests are not honored. The
1119                  * canonical method to verify this is to check data length.
1120                  */
1121                 if (psock->cork_bytes) {
1122                         if (copy > psock->cork_bytes)
1123                                 psock->cork_bytes = 0;
1124                         else
1125                                 psock->cork_bytes -= copy;
1126
1127                         if (psock->cork_bytes && !enospc)
1128                                 goto out_cork;
1129
1130                         /* All cork bytes accounted for re-run filter */
1131                         psock->eval = __SK_NONE;
1132                         psock->cork_bytes = 0;
1133                 }
1134
1135                 err = bpf_exec_tx_verdict(psock, m, sk, &copied, flags);
1136                 if (unlikely(err < 0))
1137                         goto out_err;
1138                 continue;
1139 wait_for_sndbuf:
1140                 set_bit(SOCK_NOSPACE, &sk->sk_socket->flags);
1141 wait_for_memory:
1142                 err = sk_stream_wait_memory(sk, &timeo);
1143                 if (err) {
1144                         if (m && m != psock->cork)
1145                                 free_start_sg(sk, m, true);
1146                         goto out_err;
1147                 }
1148         }
1149 out_err:
1150         if (err < 0)
1151                 err = sk_stream_error(sk, msg->msg_flags, err);
1152 out_cork:
1153         release_sock(sk);
1154         smap_release_sock(psock, sk);
1155         return copied ? copied : err;
1156 }
1157
1158 static int bpf_tcp_sendpage(struct sock *sk, struct page *page,
1159                             int offset, size_t size, int flags)
1160 {
1161         struct sk_msg_buff md = {0}, *m = NULL;
1162         int err = 0, copied = 0;
1163         struct smap_psock *psock;
1164         struct scatterlist *sg;
1165         bool enospc = false;
1166
1167         rcu_read_lock();
1168         psock = smap_psock_sk(sk);
1169         if (unlikely(!psock))
1170                 goto accept;
1171
1172         if (!refcount_inc_not_zero(&psock->refcnt))
1173                 goto accept;
1174         rcu_read_unlock();
1175
1176         lock_sock(sk);
1177
1178         if (psock->cork_bytes) {
1179                 m = psock->cork;
1180                 sg = &m->sg_data[m->sg_end];
1181         } else {
1182                 m = &md;
1183                 sg = m->sg_data;
1184                 sg_init_marker(sg, MAX_SKB_FRAGS);
1185         }
1186
1187         /* Catch case where ring is full and sendpage is stalled. */
1188         if (unlikely(m->sg_end == m->sg_start &&
1189             m->sg_data[m->sg_end].length))
1190                 goto out_err;
1191
1192         psock->sg_size += size;
1193         sg_set_page(sg, page, size, offset);
1194         get_page(page);
1195         m->sg_copy[m->sg_end] = true;
1196         sk_mem_charge(sk, size);
1197         m->sg_end++;
1198         copied = size;
1199
1200         if (m->sg_end == MAX_SKB_FRAGS)
1201                 m->sg_end = 0;
1202
1203         if (m->sg_end == m->sg_start)
1204                 enospc = true;
1205
1206         if (psock->cork_bytes) {
1207                 if (size > psock->cork_bytes)
1208                         psock->cork_bytes = 0;
1209                 else
1210                         psock->cork_bytes -= size;
1211
1212                 if (psock->cork_bytes && !enospc)
1213                         goto out_err;
1214
1215                 /* All cork bytes accounted for re-run filter */
1216                 psock->eval = __SK_NONE;
1217                 psock->cork_bytes = 0;
1218         }
1219
1220         err = bpf_exec_tx_verdict(psock, m, sk, &copied, flags);
1221 out_err:
1222         release_sock(sk);
1223         smap_release_sock(psock, sk);
1224         return copied ? copied : err;
1225 accept:
1226         rcu_read_unlock();
1227         return tcp_sendpage(sk, page, offset, size, flags);
1228 }
1229
1230 static void bpf_tcp_msg_add(struct smap_psock *psock,
1231                             struct sock *sk,
1232                             struct bpf_prog *tx_msg)
1233 {
1234         struct bpf_prog *orig_tx_msg;
1235
1236         orig_tx_msg = xchg(&psock->bpf_tx_msg, tx_msg);
1237         if (orig_tx_msg)
1238                 bpf_prog_put(orig_tx_msg);
1239 }
1240
1241 static int bpf_tcp_ulp_register(void)
1242 {
1243         build_protos(bpf_tcp_prots[SOCKMAP_IPV4], &tcp_prot);
1244         /* Once BPF TX ULP is registered it is never unregistered. It
1245          * will be in the ULP list for the lifetime of the system. Doing
1246          * duplicate registers is not a problem.
1247          */
1248         return tcp_register_ulp(&bpf_tcp_ulp_ops);
1249 }
1250
1251 static int smap_verdict_func(struct smap_psock *psock, struct sk_buff *skb)
1252 {
1253         struct bpf_prog *prog = READ_ONCE(psock->bpf_verdict);
1254         int rc;
1255
1256         if (unlikely(!prog))
1257                 return __SK_DROP;
1258
1259         skb_orphan(skb);
1260         /* We need to ensure that BPF metadata for maps is also cleared
1261          * when we orphan the skb so that we don't have the possibility
1262          * to reference a stale map.
1263          */
1264         TCP_SKB_CB(skb)->bpf.sk_redir = NULL;
1265         skb->sk = psock->sock;
1266         bpf_compute_data_end_sk_skb(skb);
1267         preempt_disable();
1268         rc = (*prog->bpf_func)(skb, prog->insnsi);
1269         preempt_enable();
1270         skb->sk = NULL;
1271
1272         /* Moving return codes from UAPI namespace into internal namespace */
1273         return rc == SK_PASS ?
1274                 (TCP_SKB_CB(skb)->bpf.sk_redir ? __SK_REDIRECT : __SK_PASS) :
1275                 __SK_DROP;
1276 }
1277
1278 static int smap_do_ingress(struct smap_psock *psock, struct sk_buff *skb)
1279 {
1280         struct sock *sk = psock->sock;
1281         int copied = 0, num_sg;
1282         struct sk_msg_buff *r;
1283
1284         r = kzalloc(sizeof(struct sk_msg_buff), __GFP_NOWARN | GFP_ATOMIC);
1285         if (unlikely(!r))
1286                 return -EAGAIN;
1287
1288         if (!sk_rmem_schedule(sk, skb, skb->len)) {
1289                 kfree(r);
1290                 return -EAGAIN;
1291         }
1292
1293         sg_init_table(r->sg_data, MAX_SKB_FRAGS);
1294         num_sg = skb_to_sgvec(skb, r->sg_data, 0, skb->len);
1295         if (unlikely(num_sg < 0)) {
1296                 kfree(r);
1297                 return num_sg;
1298         }
1299         sk_mem_charge(sk, skb->len);
1300         copied = skb->len;
1301         r->sg_start = 0;
1302         r->sg_end = num_sg == MAX_SKB_FRAGS ? 0 : num_sg;
1303         r->skb = skb;
1304         list_add_tail(&r->list, &psock->ingress);
1305         sk->sk_data_ready(sk);
1306         return copied;
1307 }
1308
1309 static void smap_do_verdict(struct smap_psock *psock, struct sk_buff *skb)
1310 {
1311         struct smap_psock *peer;
1312         struct sock *sk;
1313         __u32 in;
1314         int rc;
1315
1316         rc = smap_verdict_func(psock, skb);
1317         switch (rc) {
1318         case __SK_REDIRECT:
1319                 sk = do_sk_redirect_map(skb);
1320                 if (!sk) {
1321                         kfree_skb(skb);
1322                         break;
1323                 }
1324
1325                 peer = smap_psock_sk(sk);
1326                 in = (TCP_SKB_CB(skb)->bpf.flags) & BPF_F_INGRESS;
1327
1328                 if (unlikely(!peer || sock_flag(sk, SOCK_DEAD) ||
1329                              !test_bit(SMAP_TX_RUNNING, &peer->state))) {
1330                         kfree_skb(skb);
1331                         break;
1332                 }
1333
1334                 if (!in && sock_writeable(sk)) {
1335                         skb_set_owner_w(skb, sk);
1336                         skb_queue_tail(&peer->rxqueue, skb);
1337                         schedule_work(&peer->tx_work);
1338                         break;
1339                 } else if (in &&
1340                            atomic_read(&sk->sk_rmem_alloc) <= sk->sk_rcvbuf) {
1341                         skb_queue_tail(&peer->rxqueue, skb);
1342                         schedule_work(&peer->tx_work);
1343                         break;
1344                 }
1345         /* Fall through and free skb otherwise */
1346         case __SK_DROP:
1347         default:
1348                 kfree_skb(skb);
1349         }
1350 }
1351
1352 static void smap_report_sk_error(struct smap_psock *psock, int err)
1353 {
1354         struct sock *sk = psock->sock;
1355
1356         sk->sk_err = err;
1357         sk->sk_error_report(sk);
1358 }
1359
1360 static void smap_read_sock_strparser(struct strparser *strp,
1361                                      struct sk_buff *skb)
1362 {
1363         struct smap_psock *psock;
1364
1365         rcu_read_lock();
1366         psock = container_of(strp, struct smap_psock, strp);
1367         smap_do_verdict(psock, skb);
1368         rcu_read_unlock();
1369 }
1370
1371 /* Called with lock held on socket */
1372 static void smap_data_ready(struct sock *sk)
1373 {
1374         struct smap_psock *psock;
1375
1376         rcu_read_lock();
1377         psock = smap_psock_sk(sk);
1378         if (likely(psock)) {
1379                 write_lock_bh(&sk->sk_callback_lock);
1380                 strp_data_ready(&psock->strp);
1381                 write_unlock_bh(&sk->sk_callback_lock);
1382         }
1383         rcu_read_unlock();
1384 }
1385
1386 static void smap_tx_work(struct work_struct *w)
1387 {
1388         struct smap_psock *psock;
1389         struct sk_buff *skb;
1390         int rem, off, n;
1391
1392         psock = container_of(w, struct smap_psock, tx_work);
1393
1394         /* lock sock to avoid losing sk_socket at some point during loop */
1395         lock_sock(psock->sock);
1396         if (psock->save_skb) {
1397                 skb = psock->save_skb;
1398                 rem = psock->save_rem;
1399                 off = psock->save_off;
1400                 psock->save_skb = NULL;
1401                 goto start;
1402         }
1403
1404         while ((skb = skb_dequeue(&psock->rxqueue))) {
1405                 __u32 flags;
1406
1407                 rem = skb->len;
1408                 off = 0;
1409 start:
1410                 flags = (TCP_SKB_CB(skb)->bpf.flags) & BPF_F_INGRESS;
1411                 do {
1412                         if (likely(psock->sock->sk_socket)) {
1413                                 if (flags)
1414                                         n = smap_do_ingress(psock, skb);
1415                                 else
1416                                         n = skb_send_sock_locked(psock->sock,
1417                                                                  skb, off, rem);
1418                         } else {
1419                                 n = -EINVAL;
1420                         }
1421
1422                         if (n <= 0) {
1423                                 if (n == -EAGAIN) {
1424                                         /* Retry when space is available */
1425                                         psock->save_skb = skb;
1426                                         psock->save_rem = rem;
1427                                         psock->save_off = off;
1428                                         goto out;
1429                                 }
1430                                 /* Hard errors break pipe and stop xmit */
1431                                 smap_report_sk_error(psock, n ? -n : EPIPE);
1432                                 clear_bit(SMAP_TX_RUNNING, &psock->state);
1433                                 kfree_skb(skb);
1434                                 goto out;
1435                         }
1436                         rem -= n;
1437                         off += n;
1438                 } while (rem);
1439
1440                 if (!flags)
1441                         kfree_skb(skb);
1442         }
1443 out:
1444         release_sock(psock->sock);
1445 }
1446
1447 static void smap_write_space(struct sock *sk)
1448 {
1449         struct smap_psock *psock;
1450         void (*write_space)(struct sock *sk);
1451
1452         rcu_read_lock();
1453         psock = smap_psock_sk(sk);
1454         if (likely(psock && test_bit(SMAP_TX_RUNNING, &psock->state)))
1455                 schedule_work(&psock->tx_work);
1456         write_space = psock->save_write_space;
1457         rcu_read_unlock();
1458         write_space(sk);
1459 }
1460
1461 static void smap_stop_sock(struct smap_psock *psock, struct sock *sk)
1462 {
1463         if (!psock->strp_enabled)
1464                 return;
1465         sk->sk_data_ready = psock->save_data_ready;
1466         sk->sk_write_space = psock->save_write_space;
1467         psock->save_data_ready = NULL;
1468         psock->save_write_space = NULL;
1469         strp_stop(&psock->strp);
1470         psock->strp_enabled = false;
1471 }
1472
1473 static void smap_destroy_psock(struct rcu_head *rcu)
1474 {
1475         struct smap_psock *psock = container_of(rcu,
1476                                                   struct smap_psock, rcu);
1477
1478         /* Now that a grace period has passed there is no longer
1479          * any reference to this sock in the sockmap so we can
1480          * destroy the psock, strparser, and bpf programs. But,
1481          * because we use workqueue sync operations we can not
1482          * do it in rcu context
1483          */
1484         schedule_work(&psock->gc_work);
1485 }
1486
1487 static bool psock_is_smap_sk(struct sock *sk)
1488 {
1489         return inet_csk(sk)->icsk_ulp_ops == &bpf_tcp_ulp_ops;
1490 }
1491
1492 static void smap_release_sock(struct smap_psock *psock, struct sock *sock)
1493 {
1494         if (refcount_dec_and_test(&psock->refcnt)) {
1495                 if (psock_is_smap_sk(sock))
1496                         tcp_cleanup_ulp(sock);
1497                 write_lock_bh(&sock->sk_callback_lock);
1498                 smap_stop_sock(psock, sock);
1499                 write_unlock_bh(&sock->sk_callback_lock);
1500                 clear_bit(SMAP_TX_RUNNING, &psock->state);
1501                 rcu_assign_sk_user_data(sock, NULL);
1502                 call_rcu_sched(&psock->rcu, smap_destroy_psock);
1503         }
1504 }
1505
1506 static int smap_parse_func_strparser(struct strparser *strp,
1507                                        struct sk_buff *skb)
1508 {
1509         struct smap_psock *psock;
1510         struct bpf_prog *prog;
1511         int rc;
1512
1513         rcu_read_lock();
1514         psock = container_of(strp, struct smap_psock, strp);
1515         prog = READ_ONCE(psock->bpf_parse);
1516
1517         if (unlikely(!prog)) {
1518                 rcu_read_unlock();
1519                 return skb->len;
1520         }
1521
1522         /* Attach socket for bpf program to use if needed we can do this
1523          * because strparser clones the skb before handing it to a upper
1524          * layer, meaning skb_orphan has been called. We NULL sk on the
1525          * way out to ensure we don't trigger a BUG_ON in skb/sk operations
1526          * later and because we are not charging the memory of this skb to
1527          * any socket yet.
1528          */
1529         skb->sk = psock->sock;
1530         bpf_compute_data_end_sk_skb(skb);
1531         rc = (*prog->bpf_func)(skb, prog->insnsi);
1532         skb->sk = NULL;
1533         rcu_read_unlock();
1534         return rc;
1535 }
1536
1537 static int smap_read_sock_done(struct strparser *strp, int err)
1538 {
1539         return err;
1540 }
1541
1542 static int smap_init_sock(struct smap_psock *psock,
1543                           struct sock *sk)
1544 {
1545         static const struct strp_callbacks cb = {
1546                 .rcv_msg = smap_read_sock_strparser,
1547                 .parse_msg = smap_parse_func_strparser,
1548                 .read_sock_done = smap_read_sock_done,
1549         };
1550
1551         return strp_init(&psock->strp, sk, &cb);
1552 }
1553
1554 static void smap_init_progs(struct smap_psock *psock,
1555                             struct bpf_prog *verdict,
1556                             struct bpf_prog *parse)
1557 {
1558         struct bpf_prog *orig_parse, *orig_verdict;
1559
1560         orig_parse = xchg(&psock->bpf_parse, parse);
1561         orig_verdict = xchg(&psock->bpf_verdict, verdict);
1562
1563         if (orig_verdict)
1564                 bpf_prog_put(orig_verdict);
1565         if (orig_parse)
1566                 bpf_prog_put(orig_parse);
1567 }
1568
1569 static void smap_start_sock(struct smap_psock *psock, struct sock *sk)
1570 {
1571         if (sk->sk_data_ready == smap_data_ready)
1572                 return;
1573         psock->save_data_ready = sk->sk_data_ready;
1574         psock->save_write_space = sk->sk_write_space;
1575         sk->sk_data_ready = smap_data_ready;
1576         sk->sk_write_space = smap_write_space;
1577         psock->strp_enabled = true;
1578 }
1579
1580 static void sock_map_remove_complete(struct bpf_stab *stab)
1581 {
1582         bpf_map_area_free(stab->sock_map);
1583         kfree(stab);
1584 }
1585
1586 static void smap_gc_work(struct work_struct *w)
1587 {
1588         struct smap_psock_map_entry *e, *tmp;
1589         struct sk_msg_buff *md, *mtmp;
1590         struct smap_psock *psock;
1591
1592         psock = container_of(w, struct smap_psock, gc_work);
1593
1594         /* no callback lock needed because we already detached sockmap ops */
1595         if (psock->strp_enabled)
1596                 strp_done(&psock->strp);
1597
1598         cancel_work_sync(&psock->tx_work);
1599         __skb_queue_purge(&psock->rxqueue);
1600
1601         /* At this point all strparser and xmit work must be complete */
1602         if (psock->bpf_parse)
1603                 bpf_prog_put(psock->bpf_parse);
1604         if (psock->bpf_verdict)
1605                 bpf_prog_put(psock->bpf_verdict);
1606         if (psock->bpf_tx_msg)
1607                 bpf_prog_put(psock->bpf_tx_msg);
1608
1609         if (psock->cork) {
1610                 free_start_sg(psock->sock, psock->cork, true);
1611                 kfree(psock->cork);
1612         }
1613
1614         list_for_each_entry_safe(md, mtmp, &psock->ingress, list) {
1615                 list_del(&md->list);
1616                 free_start_sg(psock->sock, md, true);
1617                 kfree(md);
1618         }
1619
1620         list_for_each_entry_safe(e, tmp, &psock->maps, list) {
1621                 list_del(&e->list);
1622                 kfree(e);
1623         }
1624
1625         if (psock->sk_redir)
1626                 sock_put(psock->sk_redir);
1627
1628         sock_put(psock->sock);
1629         kfree(psock);
1630 }
1631
1632 static struct smap_psock *smap_init_psock(struct sock *sock, int node)
1633 {
1634         struct smap_psock *psock;
1635
1636         psock = kzalloc_node(sizeof(struct smap_psock),
1637                              GFP_ATOMIC | __GFP_NOWARN,
1638                              node);
1639         if (!psock)
1640                 return ERR_PTR(-ENOMEM);
1641
1642         psock->eval =  __SK_NONE;
1643         psock->sock = sock;
1644         skb_queue_head_init(&psock->rxqueue);
1645         INIT_WORK(&psock->tx_work, smap_tx_work);
1646         INIT_WORK(&psock->gc_work, smap_gc_work);
1647         INIT_LIST_HEAD(&psock->maps);
1648         INIT_LIST_HEAD(&psock->ingress);
1649         refcount_set(&psock->refcnt, 1);
1650         spin_lock_init(&psock->maps_lock);
1651
1652         rcu_assign_sk_user_data(sock, psock);
1653         sock_hold(sock);
1654         return psock;
1655 }
1656
1657 static struct bpf_map *sock_map_alloc(union bpf_attr *attr)
1658 {
1659         struct bpf_stab *stab;
1660         u64 cost;
1661         int err;
1662
1663         if (!capable(CAP_NET_ADMIN))
1664                 return ERR_PTR(-EPERM);
1665
1666         /* check sanity of attributes */
1667         if (attr->max_entries == 0 || attr->key_size != 4 ||
1668             attr->value_size != 4 || attr->map_flags & ~SOCK_CREATE_FLAG_MASK)
1669                 return ERR_PTR(-EINVAL);
1670
1671         err = bpf_tcp_ulp_register();
1672         if (err && err != -EEXIST)
1673                 return ERR_PTR(err);
1674
1675         stab = kzalloc(sizeof(*stab), GFP_USER);
1676         if (!stab)
1677                 return ERR_PTR(-ENOMEM);
1678
1679         bpf_map_init_from_attr(&stab->map, attr);
1680         raw_spin_lock_init(&stab->lock);
1681
1682         /* make sure page count doesn't overflow */
1683         cost = (u64) stab->map.max_entries * sizeof(struct sock *);
1684         err = -EINVAL;
1685         if (cost >= U32_MAX - PAGE_SIZE)
1686                 goto free_stab;
1687
1688         stab->map.pages = round_up(cost, PAGE_SIZE) >> PAGE_SHIFT;
1689
1690         /* if map size is larger than memlock limit, reject it early */
1691         err = bpf_map_precharge_memlock(stab->map.pages);
1692         if (err)
1693                 goto free_stab;
1694
1695         err = -ENOMEM;
1696         stab->sock_map = bpf_map_area_alloc(stab->map.max_entries *
1697                                             sizeof(struct sock *),
1698                                             stab->map.numa_node);
1699         if (!stab->sock_map)
1700                 goto free_stab;
1701
1702         return &stab->map;
1703 free_stab:
1704         kfree(stab);
1705         return ERR_PTR(err);
1706 }
1707
1708 static void smap_list_map_remove(struct smap_psock *psock,
1709                                  struct sock **entry)
1710 {
1711         struct smap_psock_map_entry *e, *tmp;
1712
1713         spin_lock_bh(&psock->maps_lock);
1714         list_for_each_entry_safe(e, tmp, &psock->maps, list) {
1715                 if (e->entry == entry) {
1716                         list_del(&e->list);
1717                         kfree(e);
1718                 }
1719         }
1720         spin_unlock_bh(&psock->maps_lock);
1721 }
1722
1723 static void smap_list_hash_remove(struct smap_psock *psock,
1724                                   struct htab_elem *hash_link)
1725 {
1726         struct smap_psock_map_entry *e, *tmp;
1727
1728         spin_lock_bh(&psock->maps_lock);
1729         list_for_each_entry_safe(e, tmp, &psock->maps, list) {
1730                 struct htab_elem *c = rcu_dereference(e->hash_link);
1731
1732                 if (c == hash_link) {
1733                         list_del(&e->list);
1734                         kfree(e);
1735                 }
1736         }
1737         spin_unlock_bh(&psock->maps_lock);
1738 }
1739
1740 static void sock_map_free(struct bpf_map *map)
1741 {
1742         struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
1743         int i;
1744
1745         synchronize_rcu();
1746
1747         /* At this point no update, lookup or delete operations can happen.
1748          * However, be aware we can still get a socket state event updates,
1749          * and data ready callabacks that reference the psock from sk_user_data
1750          * Also psock worker threads are still in-flight. So smap_release_sock
1751          * will only free the psock after cancel_sync on the worker threads
1752          * and a grace period expire to ensure psock is really safe to remove.
1753          */
1754         rcu_read_lock();
1755         raw_spin_lock_bh(&stab->lock);
1756         for (i = 0; i < stab->map.max_entries; i++) {
1757                 struct smap_psock *psock;
1758                 struct sock *sock;
1759
1760                 sock = stab->sock_map[i];
1761                 if (!sock)
1762                         continue;
1763                 stab->sock_map[i] = NULL;
1764                 psock = smap_psock_sk(sock);
1765                 /* This check handles a racing sock event that can get the
1766                  * sk_callback_lock before this case but after xchg happens
1767                  * causing the refcnt to hit zero and sock user data (psock)
1768                  * to be null and queued for garbage collection.
1769                  */
1770                 if (likely(psock)) {
1771                         smap_list_map_remove(psock, &stab->sock_map[i]);
1772                         smap_release_sock(psock, sock);
1773                 }
1774         }
1775         raw_spin_unlock_bh(&stab->lock);
1776         rcu_read_unlock();
1777
1778         sock_map_remove_complete(stab);
1779 }
1780
1781 static int sock_map_get_next_key(struct bpf_map *map, void *key, void *next_key)
1782 {
1783         struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
1784         u32 i = key ? *(u32 *)key : U32_MAX;
1785         u32 *next = (u32 *)next_key;
1786
1787         if (i >= stab->map.max_entries) {
1788                 *next = 0;
1789                 return 0;
1790         }
1791
1792         if (i == stab->map.max_entries - 1)
1793                 return -ENOENT;
1794
1795         *next = i + 1;
1796         return 0;
1797 }
1798
1799 struct sock  *__sock_map_lookup_elem(struct bpf_map *map, u32 key)
1800 {
1801         struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
1802
1803         if (key >= map->max_entries)
1804                 return NULL;
1805
1806         return READ_ONCE(stab->sock_map[key]);
1807 }
1808
1809 static int sock_map_delete_elem(struct bpf_map *map, void *key)
1810 {
1811         struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
1812         struct smap_psock *psock;
1813         int k = *(u32 *)key;
1814         struct sock *sock;
1815
1816         if (k >= map->max_entries)
1817                 return -EINVAL;
1818
1819         raw_spin_lock_bh(&stab->lock);
1820         sock = stab->sock_map[k];
1821         stab->sock_map[k] = NULL;
1822         raw_spin_unlock_bh(&stab->lock);
1823         if (!sock)
1824                 return -EINVAL;
1825
1826         psock = smap_psock_sk(sock);
1827         if (!psock)
1828                 return 0;
1829         if (psock->bpf_parse) {
1830                 write_lock_bh(&sock->sk_callback_lock);
1831                 smap_stop_sock(psock, sock);
1832                 write_unlock_bh(&sock->sk_callback_lock);
1833         }
1834         smap_list_map_remove(psock, &stab->sock_map[k]);
1835         smap_release_sock(psock, sock);
1836         return 0;
1837 }
1838
1839 /* Locking notes: Concurrent updates, deletes, and lookups are allowed and are
1840  * done inside rcu critical sections. This ensures on updates that the psock
1841  * will not be released via smap_release_sock() until concurrent updates/deletes
1842  * complete. All operations operate on sock_map using cmpxchg and xchg
1843  * operations to ensure we do not get stale references. Any reads into the
1844  * map must be done with READ_ONCE() because of this.
1845  *
1846  * A psock is destroyed via call_rcu and after any worker threads are cancelled
1847  * and syncd so we are certain all references from the update/lookup/delete
1848  * operations as well as references in the data path are no longer in use.
1849  *
1850  * Psocks may exist in multiple maps, but only a single set of parse/verdict
1851  * programs may be inherited from the maps it belongs to. A reference count
1852  * is kept with the total number of references to the psock from all maps. The
1853  * psock will not be released until this reaches zero. The psock and sock
1854  * user data data use the sk_callback_lock to protect critical data structures
1855  * from concurrent access. This allows us to avoid two updates from modifying
1856  * the user data in sock and the lock is required anyways for modifying
1857  * callbacks, we simply increase its scope slightly.
1858  *
1859  * Rules to follow,
1860  *  - psock must always be read inside RCU critical section
1861  *  - sk_user_data must only be modified inside sk_callback_lock and read
1862  *    inside RCU critical section.
1863  *  - psock->maps list must only be read & modified inside sk_callback_lock
1864  *  - sock_map must use READ_ONCE and (cmp)xchg operations
1865  *  - BPF verdict/parse programs must use READ_ONCE and xchg operations
1866  */
1867
1868 static int __sock_map_ctx_update_elem(struct bpf_map *map,
1869                                       struct bpf_sock_progs *progs,
1870                                       struct sock *sock,
1871                                       void *key)
1872 {
1873         struct bpf_prog *verdict, *parse, *tx_msg;
1874         struct smap_psock *psock;
1875         bool new = false;
1876         int err = 0;
1877
1878         /* 1. If sock map has BPF programs those will be inherited by the
1879          * sock being added. If the sock is already attached to BPF programs
1880          * this results in an error.
1881          */
1882         verdict = READ_ONCE(progs->bpf_verdict);
1883         parse = READ_ONCE(progs->bpf_parse);
1884         tx_msg = READ_ONCE(progs->bpf_tx_msg);
1885
1886         if (parse && verdict) {
1887                 /* bpf prog refcnt may be zero if a concurrent attach operation
1888                  * removes the program after the above READ_ONCE() but before
1889                  * we increment the refcnt. If this is the case abort with an
1890                  * error.
1891                  */
1892                 verdict = bpf_prog_inc_not_zero(verdict);
1893                 if (IS_ERR(verdict))
1894                         return PTR_ERR(verdict);
1895
1896                 parse = bpf_prog_inc_not_zero(parse);
1897                 if (IS_ERR(parse)) {
1898                         bpf_prog_put(verdict);
1899                         return PTR_ERR(parse);
1900                 }
1901         }
1902
1903         if (tx_msg) {
1904                 tx_msg = bpf_prog_inc_not_zero(tx_msg);
1905                 if (IS_ERR(tx_msg)) {
1906                         if (parse && verdict) {
1907                                 bpf_prog_put(parse);
1908                                 bpf_prog_put(verdict);
1909                         }
1910                         return PTR_ERR(tx_msg);
1911                 }
1912         }
1913
1914         psock = smap_psock_sk(sock);
1915
1916         /* 2. Do not allow inheriting programs if psock exists and has
1917          * already inherited programs. This would create confusion on
1918          * which parser/verdict program is running. If no psock exists
1919          * create one. Inside sk_callback_lock to ensure concurrent create
1920          * doesn't update user data.
1921          */
1922         if (psock) {
1923                 if (!psock_is_smap_sk(sock)) {
1924                         err = -EBUSY;
1925                         goto out_progs;
1926                 }
1927                 if (READ_ONCE(psock->bpf_parse) && parse) {
1928                         err = -EBUSY;
1929                         goto out_progs;
1930                 }
1931                 if (READ_ONCE(psock->bpf_tx_msg) && tx_msg) {
1932                         err = -EBUSY;
1933                         goto out_progs;
1934                 }
1935                 if (!refcount_inc_not_zero(&psock->refcnt)) {
1936                         err = -EAGAIN;
1937                         goto out_progs;
1938                 }
1939         } else {
1940                 psock = smap_init_psock(sock, map->numa_node);
1941                 if (IS_ERR(psock)) {
1942                         err = PTR_ERR(psock);
1943                         goto out_progs;
1944                 }
1945
1946                 set_bit(SMAP_TX_RUNNING, &psock->state);
1947                 new = true;
1948         }
1949
1950         /* 3. At this point we have a reference to a valid psock that is
1951          * running. Attach any BPF programs needed.
1952          */
1953         if (tx_msg)
1954                 bpf_tcp_msg_add(psock, sock, tx_msg);
1955         if (new) {
1956                 err = tcp_set_ulp_id(sock, TCP_ULP_BPF);
1957                 if (err)
1958                         goto out_free;
1959         }
1960
1961         if (parse && verdict && !psock->strp_enabled) {
1962                 err = smap_init_sock(psock, sock);
1963                 if (err)
1964                         goto out_free;
1965                 smap_init_progs(psock, verdict, parse);
1966                 write_lock_bh(&sock->sk_callback_lock);
1967                 smap_start_sock(psock, sock);
1968                 write_unlock_bh(&sock->sk_callback_lock);
1969         }
1970
1971         return err;
1972 out_free:
1973         smap_release_sock(psock, sock);
1974 out_progs:
1975         if (parse && verdict) {
1976                 bpf_prog_put(parse);
1977                 bpf_prog_put(verdict);
1978         }
1979         if (tx_msg)
1980                 bpf_prog_put(tx_msg);
1981         return err;
1982 }
1983
1984 static int sock_map_ctx_update_elem(struct bpf_sock_ops_kern *skops,
1985                                     struct bpf_map *map,
1986                                     void *key, u64 flags)
1987 {
1988         struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
1989         struct bpf_sock_progs *progs = &stab->progs;
1990         struct sock *osock, *sock = skops->sk;
1991         struct smap_psock_map_entry *e;
1992         struct smap_psock *psock;
1993         u32 i = *(u32 *)key;
1994         int err;
1995
1996         if (unlikely(flags > BPF_EXIST))
1997                 return -EINVAL;
1998         if (unlikely(i >= stab->map.max_entries))
1999                 return -E2BIG;
2000
2001         e = kzalloc(sizeof(*e), GFP_ATOMIC | __GFP_NOWARN);
2002         if (!e)
2003                 return -ENOMEM;
2004
2005         err = __sock_map_ctx_update_elem(map, progs, sock, key);
2006         if (err)
2007                 goto out;
2008
2009         /* psock guaranteed to be present. */
2010         psock = smap_psock_sk(sock);
2011         raw_spin_lock_bh(&stab->lock);
2012         osock = stab->sock_map[i];
2013         if (osock && flags == BPF_NOEXIST) {
2014                 err = -EEXIST;
2015                 goto out_unlock;
2016         }
2017         if (!osock && flags == BPF_EXIST) {
2018                 err = -ENOENT;
2019                 goto out_unlock;
2020         }
2021
2022         e->entry = &stab->sock_map[i];
2023         e->map = map;
2024         spin_lock_bh(&psock->maps_lock);
2025         list_add_tail(&e->list, &psock->maps);
2026         spin_unlock_bh(&psock->maps_lock);
2027
2028         stab->sock_map[i] = sock;
2029         if (osock) {
2030                 psock = smap_psock_sk(osock);
2031                 smap_list_map_remove(psock, &stab->sock_map[i]);
2032                 smap_release_sock(psock, osock);
2033         }
2034         raw_spin_unlock_bh(&stab->lock);
2035         return 0;
2036 out_unlock:
2037         smap_release_sock(psock, sock);
2038         raw_spin_unlock_bh(&stab->lock);
2039 out:
2040         kfree(e);
2041         return err;
2042 }
2043
2044 int sock_map_prog(struct bpf_map *map, struct bpf_prog *prog, u32 type)
2045 {
2046         struct bpf_sock_progs *progs;
2047         struct bpf_prog *orig;
2048
2049         if (map->map_type == BPF_MAP_TYPE_SOCKMAP) {
2050                 struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
2051
2052                 progs = &stab->progs;
2053         } else if (map->map_type == BPF_MAP_TYPE_SOCKHASH) {
2054                 struct bpf_htab *htab = container_of(map, struct bpf_htab, map);
2055
2056                 progs = &htab->progs;
2057         } else {
2058                 return -EINVAL;
2059         }
2060
2061         switch (type) {
2062         case BPF_SK_MSG_VERDICT:
2063                 orig = xchg(&progs->bpf_tx_msg, prog);
2064                 break;
2065         case BPF_SK_SKB_STREAM_PARSER:
2066                 orig = xchg(&progs->bpf_parse, prog);
2067                 break;
2068         case BPF_SK_SKB_STREAM_VERDICT:
2069                 orig = xchg(&progs->bpf_verdict, prog);
2070                 break;
2071         default:
2072                 return -EOPNOTSUPP;
2073         }
2074
2075         if (orig)
2076                 bpf_prog_put(orig);
2077
2078         return 0;
2079 }
2080
2081 int sockmap_get_from_fd(const union bpf_attr *attr, int type,
2082                         struct bpf_prog *prog)
2083 {
2084         int ufd = attr->target_fd;
2085         struct bpf_map *map;
2086         struct fd f;
2087         int err;
2088
2089         f = fdget(ufd);
2090         map = __bpf_map_get(f);
2091         if (IS_ERR(map))
2092                 return PTR_ERR(map);
2093
2094         err = sock_map_prog(map, prog, attr->attach_type);
2095         fdput(f);
2096         return err;
2097 }
2098
2099 static void *sock_map_lookup(struct bpf_map *map, void *key)
2100 {
2101         return NULL;
2102 }
2103
2104 static int sock_map_update_elem(struct bpf_map *map,
2105                                 void *key, void *value, u64 flags)
2106 {
2107         struct bpf_sock_ops_kern skops;
2108         u32 fd = *(u32 *)value;
2109         struct socket *socket;
2110         int err;
2111
2112         socket = sockfd_lookup(fd, &err);
2113         if (!socket)
2114                 return err;
2115
2116         skops.sk = socket->sk;
2117         if (!skops.sk) {
2118                 fput(socket->file);
2119                 return -EINVAL;
2120         }
2121
2122         /* ULPs are currently supported only for TCP sockets in ESTABLISHED
2123          * state.
2124          */
2125         if (skops.sk->sk_type != SOCK_STREAM ||
2126             skops.sk->sk_protocol != IPPROTO_TCP ||
2127             skops.sk->sk_state != TCP_ESTABLISHED) {
2128                 fput(socket->file);
2129                 return -EOPNOTSUPP;
2130         }
2131
2132         lock_sock(skops.sk);
2133         preempt_disable();
2134         rcu_read_lock();
2135         err = sock_map_ctx_update_elem(&skops, map, key, flags);
2136         rcu_read_unlock();
2137         preempt_enable();
2138         release_sock(skops.sk);
2139         fput(socket->file);
2140         return err;
2141 }
2142
2143 static void sock_map_release(struct bpf_map *map)
2144 {
2145         struct bpf_sock_progs *progs;
2146         struct bpf_prog *orig;
2147
2148         if (map->map_type == BPF_MAP_TYPE_SOCKMAP) {
2149                 struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
2150
2151                 progs = &stab->progs;
2152         } else {
2153                 struct bpf_htab *htab = container_of(map, struct bpf_htab, map);
2154
2155                 progs = &htab->progs;
2156         }
2157
2158         orig = xchg(&progs->bpf_parse, NULL);
2159         if (orig)
2160                 bpf_prog_put(orig);
2161         orig = xchg(&progs->bpf_verdict, NULL);
2162         if (orig)
2163                 bpf_prog_put(orig);
2164
2165         orig = xchg(&progs->bpf_tx_msg, NULL);
2166         if (orig)
2167                 bpf_prog_put(orig);
2168 }
2169
2170 static struct bpf_map *sock_hash_alloc(union bpf_attr *attr)
2171 {
2172         struct bpf_htab *htab;
2173         int i, err;
2174         u64 cost;
2175
2176         if (!capable(CAP_NET_ADMIN))
2177                 return ERR_PTR(-EPERM);
2178
2179         /* check sanity of attributes */
2180         if (attr->max_entries == 0 ||
2181             attr->key_size == 0 ||
2182             attr->value_size != 4 ||
2183             attr->map_flags & ~SOCK_CREATE_FLAG_MASK)
2184                 return ERR_PTR(-EINVAL);
2185
2186         if (attr->key_size > MAX_BPF_STACK)
2187                 /* eBPF programs initialize keys on stack, so they cannot be
2188                  * larger than max stack size
2189                  */
2190                 return ERR_PTR(-E2BIG);
2191
2192         err = bpf_tcp_ulp_register();
2193         if (err && err != -EEXIST)
2194                 return ERR_PTR(err);
2195
2196         htab = kzalloc(sizeof(*htab), GFP_USER);
2197         if (!htab)
2198                 return ERR_PTR(-ENOMEM);
2199
2200         bpf_map_init_from_attr(&htab->map, attr);
2201
2202         htab->n_buckets = roundup_pow_of_two(htab->map.max_entries);
2203         htab->elem_size = sizeof(struct htab_elem) +
2204                           round_up(htab->map.key_size, 8);
2205         err = -EINVAL;
2206         if (htab->n_buckets == 0 ||
2207             htab->n_buckets > U32_MAX / sizeof(struct bucket))
2208                 goto free_htab;
2209
2210         cost = (u64) htab->n_buckets * sizeof(struct bucket) +
2211                (u64) htab->elem_size * htab->map.max_entries;
2212
2213         if (cost >= U32_MAX - PAGE_SIZE)
2214                 goto free_htab;
2215
2216         htab->map.pages = round_up(cost, PAGE_SIZE) >> PAGE_SHIFT;
2217         err = bpf_map_precharge_memlock(htab->map.pages);
2218         if (err)
2219                 goto free_htab;
2220
2221         err = -ENOMEM;
2222         htab->buckets = bpf_map_area_alloc(
2223                                 htab->n_buckets * sizeof(struct bucket),
2224                                 htab->map.numa_node);
2225         if (!htab->buckets)
2226                 goto free_htab;
2227
2228         for (i = 0; i < htab->n_buckets; i++) {
2229                 INIT_HLIST_HEAD(&htab->buckets[i].head);
2230                 raw_spin_lock_init(&htab->buckets[i].lock);
2231         }
2232
2233         return &htab->map;
2234 free_htab:
2235         kfree(htab);
2236         return ERR_PTR(err);
2237 }
2238
2239 static void __bpf_htab_free(struct rcu_head *rcu)
2240 {
2241         struct bpf_htab *htab;
2242
2243         htab = container_of(rcu, struct bpf_htab, rcu);
2244         bpf_map_area_free(htab->buckets);
2245         kfree(htab);
2246 }
2247
2248 static void sock_hash_free(struct bpf_map *map)
2249 {
2250         struct bpf_htab *htab = container_of(map, struct bpf_htab, map);
2251         int i;
2252
2253         synchronize_rcu();
2254
2255         /* At this point no update, lookup or delete operations can happen.
2256          * However, be aware we can still get a socket state event updates,
2257          * and data ready callabacks that reference the psock from sk_user_data
2258          * Also psock worker threads are still in-flight. So smap_release_sock
2259          * will only free the psock after cancel_sync on the worker threads
2260          * and a grace period expire to ensure psock is really safe to remove.
2261          */
2262         rcu_read_lock();
2263         for (i = 0; i < htab->n_buckets; i++) {
2264                 struct bucket *b = __select_bucket(htab, i);
2265                 struct hlist_head *head;
2266                 struct hlist_node *n;
2267                 struct htab_elem *l;
2268
2269                 raw_spin_lock_bh(&b->lock);
2270                 head = &b->head;
2271                 hlist_for_each_entry_safe(l, n, head, hash_node) {
2272                         struct sock *sock = l->sk;
2273                         struct smap_psock *psock;
2274
2275                         hlist_del_rcu(&l->hash_node);
2276                         psock = smap_psock_sk(sock);
2277                         /* This check handles a racing sock event that can get
2278                          * the sk_callback_lock before this case but after xchg
2279                          * causing the refcnt to hit zero and sock user data
2280                          * (psock) to be null and queued for garbage collection.
2281                          */
2282                         if (likely(psock)) {
2283                                 smap_list_hash_remove(psock, l);
2284                                 smap_release_sock(psock, sock);
2285                         }
2286                         free_htab_elem(htab, l);
2287                 }
2288                 raw_spin_unlock_bh(&b->lock);
2289         }
2290         rcu_read_unlock();
2291         call_rcu(&htab->rcu, __bpf_htab_free);
2292 }
2293
2294 static struct htab_elem *alloc_sock_hash_elem(struct bpf_htab *htab,
2295                                               void *key, u32 key_size, u32 hash,
2296                                               struct sock *sk,
2297                                               struct htab_elem *old_elem)
2298 {
2299         struct htab_elem *l_new;
2300
2301         if (atomic_inc_return(&htab->count) > htab->map.max_entries) {
2302                 if (!old_elem) {
2303                         atomic_dec(&htab->count);
2304                         return ERR_PTR(-E2BIG);
2305                 }
2306         }
2307         l_new = kmalloc_node(htab->elem_size, GFP_ATOMIC | __GFP_NOWARN,
2308                              htab->map.numa_node);
2309         if (!l_new) {
2310                 atomic_dec(&htab->count);
2311                 return ERR_PTR(-ENOMEM);
2312         }
2313
2314         memcpy(l_new->key, key, key_size);
2315         l_new->sk = sk;
2316         l_new->hash = hash;
2317         return l_new;
2318 }
2319
2320 static inline u32 htab_map_hash(const void *key, u32 key_len)
2321 {
2322         return jhash(key, key_len, 0);
2323 }
2324
2325 static int sock_hash_get_next_key(struct bpf_map *map,
2326                                   void *key, void *next_key)
2327 {
2328         struct bpf_htab *htab = container_of(map, struct bpf_htab, map);
2329         struct htab_elem *l, *next_l;
2330         struct hlist_head *h;
2331         u32 hash, key_size;
2332         int i = 0;
2333
2334         WARN_ON_ONCE(!rcu_read_lock_held());
2335
2336         key_size = map->key_size;
2337         if (!key)
2338                 goto find_first_elem;
2339         hash = htab_map_hash(key, key_size);
2340         h = select_bucket(htab, hash);
2341
2342         l = lookup_elem_raw(h, hash, key, key_size);
2343         if (!l)
2344                 goto find_first_elem;
2345         next_l = hlist_entry_safe(
2346                      rcu_dereference_raw(hlist_next_rcu(&l->hash_node)),
2347                      struct htab_elem, hash_node);
2348         if (next_l) {
2349                 memcpy(next_key, next_l->key, key_size);
2350                 return 0;
2351         }
2352
2353         /* no more elements in this hash list, go to the next bucket */
2354         i = hash & (htab->n_buckets - 1);
2355         i++;
2356
2357 find_first_elem:
2358         /* iterate over buckets */
2359         for (; i < htab->n_buckets; i++) {
2360                 h = select_bucket(htab, i);
2361
2362                 /* pick first element in the bucket */
2363                 next_l = hlist_entry_safe(
2364                                 rcu_dereference_raw(hlist_first_rcu(h)),
2365                                 struct htab_elem, hash_node);
2366                 if (next_l) {
2367                         /* if it's not empty, just return it */
2368                         memcpy(next_key, next_l->key, key_size);
2369                         return 0;
2370                 }
2371         }
2372
2373         /* iterated over all buckets and all elements */
2374         return -ENOENT;
2375 }
2376
2377 static int sock_hash_ctx_update_elem(struct bpf_sock_ops_kern *skops,
2378                                      struct bpf_map *map,
2379                                      void *key, u64 map_flags)
2380 {
2381         struct bpf_htab *htab = container_of(map, struct bpf_htab, map);
2382         struct bpf_sock_progs *progs = &htab->progs;
2383         struct htab_elem *l_new = NULL, *l_old;
2384         struct smap_psock_map_entry *e = NULL;
2385         struct hlist_head *head;
2386         struct smap_psock *psock;
2387         u32 key_size, hash;
2388         struct sock *sock;
2389         struct bucket *b;
2390         int err;
2391
2392         sock = skops->sk;
2393
2394         if (sock->sk_type != SOCK_STREAM ||
2395             sock->sk_protocol != IPPROTO_TCP)
2396                 return -EOPNOTSUPP;
2397
2398         if (unlikely(map_flags > BPF_EXIST))
2399                 return -EINVAL;
2400
2401         e = kzalloc(sizeof(*e), GFP_ATOMIC | __GFP_NOWARN);
2402         if (!e)
2403                 return -ENOMEM;
2404
2405         WARN_ON_ONCE(!rcu_read_lock_held());
2406         key_size = map->key_size;
2407         hash = htab_map_hash(key, key_size);
2408         b = __select_bucket(htab, hash);
2409         head = &b->head;
2410
2411         err = __sock_map_ctx_update_elem(map, progs, sock, key);
2412         if (err)
2413                 goto err;
2414
2415         /* psock is valid here because otherwise above *ctx_update_elem would
2416          * have thrown an error. It is safe to skip error check.
2417          */
2418         psock = smap_psock_sk(sock);
2419         raw_spin_lock_bh(&b->lock);
2420         l_old = lookup_elem_raw(head, hash, key, key_size);
2421         if (l_old && map_flags == BPF_NOEXIST) {
2422                 err = -EEXIST;
2423                 goto bucket_err;
2424         }
2425         if (!l_old && map_flags == BPF_EXIST) {
2426                 err = -ENOENT;
2427                 goto bucket_err;
2428         }
2429
2430         l_new = alloc_sock_hash_elem(htab, key, key_size, hash, sock, l_old);
2431         if (IS_ERR(l_new)) {
2432                 err = PTR_ERR(l_new);
2433                 goto bucket_err;
2434         }
2435
2436         rcu_assign_pointer(e->hash_link, l_new);
2437         e->map = map;
2438         spin_lock_bh(&psock->maps_lock);
2439         list_add_tail(&e->list, &psock->maps);
2440         spin_unlock_bh(&psock->maps_lock);
2441
2442         /* add new element to the head of the list, so that
2443          * concurrent search will find it before old elem
2444          */
2445         hlist_add_head_rcu(&l_new->hash_node, head);
2446         if (l_old) {
2447                 psock = smap_psock_sk(l_old->sk);
2448
2449                 hlist_del_rcu(&l_old->hash_node);
2450                 smap_list_hash_remove(psock, l_old);
2451                 smap_release_sock(psock, l_old->sk);
2452                 free_htab_elem(htab, l_old);
2453         }
2454         raw_spin_unlock_bh(&b->lock);
2455         return 0;
2456 bucket_err:
2457         smap_release_sock(psock, sock);
2458         raw_spin_unlock_bh(&b->lock);
2459 err:
2460         kfree(e);
2461         return err;
2462 }
2463
2464 static int sock_hash_update_elem(struct bpf_map *map,
2465                                 void *key, void *value, u64 flags)
2466 {
2467         struct bpf_sock_ops_kern skops;
2468         u32 fd = *(u32 *)value;
2469         struct socket *socket;
2470         int err;
2471
2472         socket = sockfd_lookup(fd, &err);
2473         if (!socket)
2474                 return err;
2475
2476         skops.sk = socket->sk;
2477         if (!skops.sk) {
2478                 fput(socket->file);
2479                 return -EINVAL;
2480         }
2481
2482         /* ULPs are currently supported only for TCP sockets in ESTABLISHED
2483          * state.
2484          */
2485         if (skops.sk->sk_type != SOCK_STREAM ||
2486             skops.sk->sk_protocol != IPPROTO_TCP ||
2487             skops.sk->sk_state != TCP_ESTABLISHED) {
2488                 fput(socket->file);
2489                 return -EOPNOTSUPP;
2490         }
2491
2492         lock_sock(skops.sk);
2493         preempt_disable();
2494         rcu_read_lock();
2495         err = sock_hash_ctx_update_elem(&skops, map, key, flags);
2496         rcu_read_unlock();
2497         preempt_enable();
2498         release_sock(skops.sk);
2499         fput(socket->file);
2500         return err;
2501 }
2502
2503 static int sock_hash_delete_elem(struct bpf_map *map, void *key)
2504 {
2505         struct bpf_htab *htab = container_of(map, struct bpf_htab, map);
2506         struct hlist_head *head;
2507         struct bucket *b;
2508         struct htab_elem *l;
2509         u32 hash, key_size;
2510         int ret = -ENOENT;
2511
2512         key_size = map->key_size;
2513         hash = htab_map_hash(key, key_size);
2514         b = __select_bucket(htab, hash);
2515         head = &b->head;
2516
2517         raw_spin_lock_bh(&b->lock);
2518         l = lookup_elem_raw(head, hash, key, key_size);
2519         if (l) {
2520                 struct sock *sock = l->sk;
2521                 struct smap_psock *psock;
2522
2523                 hlist_del_rcu(&l->hash_node);
2524                 psock = smap_psock_sk(sock);
2525                 /* This check handles a racing sock event that can get the
2526                  * sk_callback_lock before this case but after xchg happens
2527                  * causing the refcnt to hit zero and sock user data (psock)
2528                  * to be null and queued for garbage collection.
2529                  */
2530                 if (likely(psock)) {
2531                         smap_list_hash_remove(psock, l);
2532                         smap_release_sock(psock, sock);
2533                 }
2534                 free_htab_elem(htab, l);
2535                 ret = 0;
2536         }
2537         raw_spin_unlock_bh(&b->lock);
2538         return ret;
2539 }
2540
2541 struct sock  *__sock_hash_lookup_elem(struct bpf_map *map, void *key)
2542 {
2543         struct bpf_htab *htab = container_of(map, struct bpf_htab, map);
2544         struct hlist_head *head;
2545         struct htab_elem *l;
2546         u32 key_size, hash;
2547         struct bucket *b;
2548         struct sock *sk;
2549
2550         key_size = map->key_size;
2551         hash = htab_map_hash(key, key_size);
2552         b = __select_bucket(htab, hash);
2553         head = &b->head;
2554
2555         l = lookup_elem_raw(head, hash, key, key_size);
2556         sk = l ? l->sk : NULL;
2557         return sk;
2558 }
2559
2560 const struct bpf_map_ops sock_map_ops = {
2561         .map_alloc = sock_map_alloc,
2562         .map_free = sock_map_free,
2563         .map_lookup_elem = sock_map_lookup,
2564         .map_get_next_key = sock_map_get_next_key,
2565         .map_update_elem = sock_map_update_elem,
2566         .map_delete_elem = sock_map_delete_elem,
2567         .map_release_uref = sock_map_release,
2568         .map_check_btf = map_check_no_btf,
2569 };
2570
2571 const struct bpf_map_ops sock_hash_ops = {
2572         .map_alloc = sock_hash_alloc,
2573         .map_free = sock_hash_free,
2574         .map_lookup_elem = sock_map_lookup,
2575         .map_get_next_key = sock_hash_get_next_key,
2576         .map_update_elem = sock_hash_update_elem,
2577         .map_delete_elem = sock_hash_delete_elem,
2578         .map_release_uref = sock_map_release,
2579         .map_check_btf = map_check_no_btf,
2580 };
2581
2582 static bool bpf_is_valid_sock_op(struct bpf_sock_ops_kern *ops)
2583 {
2584         return ops->op == BPF_SOCK_OPS_PASSIVE_ESTABLISHED_CB ||
2585                ops->op == BPF_SOCK_OPS_ACTIVE_ESTABLISHED_CB;
2586 }
2587 BPF_CALL_4(bpf_sock_map_update, struct bpf_sock_ops_kern *, bpf_sock,
2588            struct bpf_map *, map, void *, key, u64, flags)
2589 {
2590         WARN_ON_ONCE(!rcu_read_lock_held());
2591
2592         /* ULPs are currently supported only for TCP sockets in ESTABLISHED
2593          * state. This checks that the sock ops triggering the update is
2594          * one indicating we are (or will be soon) in an ESTABLISHED state.
2595          */
2596         if (!bpf_is_valid_sock_op(bpf_sock))
2597                 return -EOPNOTSUPP;
2598         return sock_map_ctx_update_elem(bpf_sock, map, key, flags);
2599 }
2600
2601 const struct bpf_func_proto bpf_sock_map_update_proto = {
2602         .func           = bpf_sock_map_update,
2603         .gpl_only       = false,
2604         .pkt_access     = true,
2605         .ret_type       = RET_INTEGER,
2606         .arg1_type      = ARG_PTR_TO_CTX,
2607         .arg2_type      = ARG_CONST_MAP_PTR,
2608         .arg3_type      = ARG_PTR_TO_MAP_KEY,
2609         .arg4_type      = ARG_ANYTHING,
2610 };
2611
2612 BPF_CALL_4(bpf_sock_hash_update, struct bpf_sock_ops_kern *, bpf_sock,
2613            struct bpf_map *, map, void *, key, u64, flags)
2614 {
2615         WARN_ON_ONCE(!rcu_read_lock_held());
2616
2617         if (!bpf_is_valid_sock_op(bpf_sock))
2618                 return -EOPNOTSUPP;
2619         return sock_hash_ctx_update_elem(bpf_sock, map, key, flags);
2620 }
2621
2622 const struct bpf_func_proto bpf_sock_hash_update_proto = {
2623         .func           = bpf_sock_hash_update,
2624         .gpl_only       = false,
2625         .pkt_access     = true,
2626         .ret_type       = RET_INTEGER,
2627         .arg1_type      = ARG_PTR_TO_CTX,
2628         .arg2_type      = ARG_CONST_MAP_PTR,
2629         .arg3_type      = ARG_PTR_TO_MAP_KEY,
2630         .arg4_type      = ARG_ANYTHING,
2631 };