net/core/lwt_bpf.c (GNU Linux-libre 4.14.290-gnu1)
/* Copyright (c) 2016 Thomas Graf <tgraf@tgraf.ch>
 *
 * This program is free software; you can redistribute it and/or
 * modify it under the terms of version 2 of the GNU General Public
 * License as published by the Free Software Foundation.
 *
 * This program is distributed in the hope that it will be useful, but
 * WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
 * General Public License for more details.
 */
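/* BPF lightweight tunnel (LWTUNNEL_ENCAP_BPF) infrastructure: lets
 * userspace attach BPF programs to routes so they run on packets at
 * three points in the stack: input (after route lookup on receive),
 * output (after route lookup on transmit), and xmit (just before the
 * packet is handed to the device).
 */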

#include <linux/kernel.h>
#include <linux/module.h>
#include <linux/skbuff.h>
#include <linux/types.h>
#include <linux/bpf.h>
#include <net/lwtunnel.h>

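/* One attached program plus the user-supplied name it was configured
 * with (LWT_BPF_PROG_NAME); the name is echoed back in netlink dumps
 * and used in log messages.
 */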
struct bpf_lwt_prog {
        struct bpf_prog *prog;
        char *name;
};

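/* Per-route state stored in lwtunnel_state->data: up to one program
 * for each of the three hooks, plus the address family the route was
 * configured for.
 */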
struct bpf_lwt {
        struct bpf_lwt_prog in;
        struct bpf_lwt_prog out;
        struct bpf_lwt_prog xmit;
        int family;
};

#define MAX_PROG_NAME 256

static inline struct bpf_lwt *bpf_lwt_lwtunnel(struct lwtunnel_state *lwt)
{
        return (struct bpf_lwt *)lwt->data;
}

#define NO_REDIRECT false
#define CAN_REDIRECT true

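/* Run one BPF program on @skb and translate its verdict: BPF_OK passes
 * the packet on, BPF_DROP frees it and returns -EPERM, and BPF_REDIRECT
 * hands the skb to skb_do_redirect(). A redirect is only honoured when
 * @can_redirect is set (the xmit hook); elsewhere a redirect verdict is
 * downgraded to BPF_OK with a one-time warning.
 */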
static int run_lwt_bpf(struct sk_buff *skb, struct bpf_lwt_prog *lwt,
                       struct dst_entry *dst, bool can_redirect)
{
        int ret;

        /* Preempt disable is needed to protect per-cpu redirect_info between
         * BPF prog and skb_do_redirect(). The call_rcu in bpf_prog_put() and
         * access to maps strictly require a rcu_read_lock() for protection,
         * mixing with BH RCU lock doesn't work.
         */
        preempt_disable();
        rcu_read_lock();
        bpf_compute_data_end(skb);
        ret = bpf_prog_run_save_cb(lwt->prog, skb);
        rcu_read_unlock();

        switch (ret) {
        case BPF_OK:
                break;

        case BPF_REDIRECT:
                if (unlikely(!can_redirect)) {
                        pr_warn_once("Illegal redirect return code in prog %s\n",
                                     lwt->name ? : "<unknown>");
                        ret = BPF_OK;
                } else {
                        skb_reset_mac_header(skb);
                        ret = skb_do_redirect(skb);
                        if (ret == 0)
                                ret = BPF_REDIRECT;
                }
                break;

        case BPF_DROP:
                kfree_skb(skb);
                ret = -EPERM;
                break;

        default:
                pr_warn_once("bpf-lwt: Illegal return value %u, expect packet loss\n", ret);
                kfree_skb(skb);
                ret = -EINVAL;
                break;
        }

        preempt_enable();

        return ret;
}

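/* dst input hook: run the "in" program (redirects are not allowed
 * here), then fall through to the input handler the route originally
 * had.
 */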
static int bpf_input(struct sk_buff *skb)
{
        struct dst_entry *dst = skb_dst(skb);
        struct bpf_lwt *bpf;
        int ret;

        bpf = bpf_lwt_lwtunnel(dst->lwtstate);
        if (bpf->in.prog) {
                ret = run_lwt_bpf(skb, &bpf->in, dst, NO_REDIRECT);
                if (ret < 0)
                        return ret;
        }

        if (unlikely(!dst->lwtstate->orig_input)) {
                pr_warn_once("orig_input not set on dst for prog %s\n",
                             bpf->in.name);
                kfree_skb(skb);
                return -EINVAL;
        }

        return dst->lwtstate->orig_input(skb);
}

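/* dst output hook: same as bpf_input() but for the "out" program, then
 * chain to the route's original output handler.
 */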
static int bpf_output(struct net *net, struct sock *sk, struct sk_buff *skb)
{
        struct dst_entry *dst = skb_dst(skb);
        struct bpf_lwt *bpf;
        int ret;

        bpf = bpf_lwt_lwtunnel(dst->lwtstate);
        if (bpf->out.prog) {
                ret = run_lwt_bpf(skb, &bpf->out, dst, NO_REDIRECT);
                if (ret < 0)
                        return ret;
        }

        if (unlikely(!dst->lwtstate->orig_output)) {
                pr_warn_once("orig_output not set on dst for prog %s\n",
                             bpf->out.name);
                kfree_skb(skb);
                return -EINVAL;
        }

        return dst->lwtstate->orig_output(net, sk, skb);
}

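/* Make sure there is enough headroom in front of the network header
 * for the device's link-layer header, reallocating the skb head if the
 * xmit program consumed the existing headroom.
 */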
static int xmit_check_hhlen(struct sk_buff *skb)
{
        int hh_len = skb_dst(skb)->dev->hard_header_len;

        if (skb_headroom(skb) < hh_len) {
                int nhead = HH_DATA_ALIGN(hh_len - skb_headroom(skb));

                if (pskb_expand_head(skb, nhead, 0, GFP_ATOMIC))
                        return -ENOMEM;
        }

        return 0;
}

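/* dst xmit hook: run the "xmit" program just before the packet is
 * handed to the device. This is the only hook that may redirect; a
 * successful redirect maps to LWTUNNEL_XMIT_DONE so the caller stops
 * processing the packet.
 */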
static int bpf_xmit(struct sk_buff *skb)
{
        struct dst_entry *dst = skb_dst(skb);
        struct bpf_lwt *bpf;

        bpf = bpf_lwt_lwtunnel(dst->lwtstate);
        if (bpf->xmit.prog) {
                int ret;

                ret = run_lwt_bpf(skb, &bpf->xmit, dst, CAN_REDIRECT);
                switch (ret) {
                case BPF_OK:
                        /* If the program expanded the header, the
                         * headroom may now be too small for the L2
                         * header still to be added; expand as needed.
                         */
                        ret = xmit_check_hhlen(skb);
                        if (unlikely(ret))
                                return ret;

                        return LWTUNNEL_XMIT_CONTINUE;
                case BPF_REDIRECT:
                        return LWTUNNEL_XMIT_DONE;
                default:
                        return ret;
                }
        }

        return LWTUNNEL_XMIT_CONTINUE;
}

static void bpf_lwt_prog_destroy(struct bpf_lwt_prog *prog)
{
        if (prog->prog)
                bpf_prog_put(prog->prog);

        kfree(prog->name);
}

static void bpf_destroy_state(struct lwtunnel_state *lwt)
{
        struct bpf_lwt *bpf = bpf_lwt_lwtunnel(lwt);

        bpf_lwt_prog_destroy(&bpf->in);
        bpf_lwt_prog_destroy(&bpf->out);
        bpf_lwt_prog_destroy(&bpf->xmit);
}

static const struct nla_policy bpf_prog_policy[LWT_BPF_PROG_MAX + 1] = {
        [LWT_BPF_PROG_FD]   = { .type = NLA_U32, },
        [LWT_BPF_PROG_NAME] = { .type = NLA_NUL_STRING,
                                .len = MAX_PROG_NAME },
};

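/* Parse one nested LWT_BPF_{IN,OUT,XMIT} attribute: duplicate the
 * program name and take a reference on the BPF program behind the
 * userspace fd, checking it was loaded with the expected program type.
 * On failure the caller unwinds via bpf_destroy_state(), which frees
 * any name already duplicated here.
 */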
static int bpf_parse_prog(struct nlattr *attr, struct bpf_lwt_prog *prog,
                          enum bpf_prog_type type)
{
        struct nlattr *tb[LWT_BPF_PROG_MAX + 1];
        struct bpf_prog *p;
        int ret;
        u32 fd;

        ret = nla_parse_nested(tb, LWT_BPF_PROG_MAX, attr, bpf_prog_policy,
                               NULL);
        if (ret < 0)
                return ret;

        if (!tb[LWT_BPF_PROG_FD] || !tb[LWT_BPF_PROG_NAME])
                return -EINVAL;

        prog->name = nla_memdup(tb[LWT_BPF_PROG_NAME], GFP_ATOMIC);
        if (!prog->name)
                return -ENOMEM;

        fd = nla_get_u32(tb[LWT_BPF_PROG_FD]);
        p = bpf_prog_get_type(fd, type);
        if (IS_ERR(p))
                return PTR_ERR(p);

        prog->prog = p;

        return 0;
}

static const struct nla_policy bpf_nl_policy[LWT_BPF_MAX + 1] = {
        [LWT_BPF_IN]            = { .type = NLA_NESTED, },
        [LWT_BPF_OUT]           = { .type = NLA_NESTED, },
        [LWT_BPF_XMIT]          = { .type = NLA_NESTED, },
        [LWT_BPF_XMIT_HEADROOM] = { .type = NLA_U32 },
};

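/* Build lwtunnel state from the LWT_BPF_* netlink attributes: at least
 * one of the three hook programs must be given, and an optional xmit
 * headroom (capped at LWT_BPF_MAX_HEADROOM) is reserved for headers
 * the xmit program intends to push.
 */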
static int bpf_build_state(struct nlattr *nla,
                           unsigned int family, const void *cfg,
                           struct lwtunnel_state **ts,
                           struct netlink_ext_ack *extack)
{
        struct nlattr *tb[LWT_BPF_MAX + 1];
        struct lwtunnel_state *newts;
        struct bpf_lwt *bpf;
        int ret;

        if (family != AF_INET && family != AF_INET6)
                return -EAFNOSUPPORT;

        ret = nla_parse_nested(tb, LWT_BPF_MAX, nla, bpf_nl_policy, extack);
        if (ret < 0)
                return ret;

        if (!tb[LWT_BPF_IN] && !tb[LWT_BPF_OUT] && !tb[LWT_BPF_XMIT])
                return -EINVAL;

        newts = lwtunnel_state_alloc(sizeof(*bpf));
        if (!newts)
                return -ENOMEM;

        newts->type = LWTUNNEL_ENCAP_BPF;
        bpf = bpf_lwt_lwtunnel(newts);

        if (tb[LWT_BPF_IN]) {
                newts->flags |= LWTUNNEL_STATE_INPUT_REDIRECT;
                ret = bpf_parse_prog(tb[LWT_BPF_IN], &bpf->in,
                                     BPF_PROG_TYPE_LWT_IN);
                if (ret < 0)
                        goto errout;
        }

        if (tb[LWT_BPF_OUT]) {
                newts->flags |= LWTUNNEL_STATE_OUTPUT_REDIRECT;
                ret = bpf_parse_prog(tb[LWT_BPF_OUT], &bpf->out,
                                     BPF_PROG_TYPE_LWT_OUT);
                if (ret < 0)
                        goto errout;
        }

        if (tb[LWT_BPF_XMIT]) {
                newts->flags |= LWTUNNEL_STATE_XMIT_REDIRECT;
                ret = bpf_parse_prog(tb[LWT_BPF_XMIT], &bpf->xmit,
                                     BPF_PROG_TYPE_LWT_XMIT);
                if (ret < 0)
                        goto errout;
        }

        if (tb[LWT_BPF_XMIT_HEADROOM]) {
                u32 headroom = nla_get_u32(tb[LWT_BPF_XMIT_HEADROOM]);

                if (headroom > LWT_BPF_MAX_HEADROOM) {
                        ret = -ERANGE;
                        goto errout;
                }

                newts->headroom = headroom;
        }

        bpf->family = family;
        *ts = newts;

        return 0;

errout:
        bpf_destroy_state(newts);
        kfree(newts);
        return ret;
}

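/* Netlink dump path: report the configured program names back to
 * userspace. The programs themselves cannot be dumped, only the names
 * recorded at configuration time.
 */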
static int bpf_fill_lwt_prog(struct sk_buff *skb, int attr,
                             struct bpf_lwt_prog *prog)
{
        struct nlattr *nest;

        if (!prog->prog)
                return 0;

        nest = nla_nest_start(skb, attr);
        if (!nest)
                return -EMSGSIZE;

        if (prog->name &&
            nla_put_string(skb, LWT_BPF_PROG_NAME, prog->name))
                return -EMSGSIZE;

        return nla_nest_end(skb, nest);
}

static int bpf_fill_encap_info(struct sk_buff *skb, struct lwtunnel_state *lwt)
{
        struct bpf_lwt *bpf = bpf_lwt_lwtunnel(lwt);

        if (bpf_fill_lwt_prog(skb, LWT_BPF_IN, &bpf->in) < 0 ||
            bpf_fill_lwt_prog(skb, LWT_BPF_OUT, &bpf->out) < 0 ||
            bpf_fill_lwt_prog(skb, LWT_BPF_XMIT, &bpf->xmit) < 0)
                return -EMSGSIZE;

        return 0;
}

static int bpf_encap_nlsize(struct lwtunnel_state *lwtstate)
{
        int nest_len = nla_total_size(sizeof(struct nlattr)) +
                       nla_total_size(MAX_PROG_NAME) + /* LWT_BPF_PROG_NAME */
                       0;

        return nest_len + /* LWT_BPF_IN */
               nest_len + /* LWT_BPF_OUT */
               nest_len + /* LWT_BPF_XMIT */
               0;
}

static int bpf_lwt_prog_cmp(struct bpf_lwt_prog *a, struct bpf_lwt_prog *b)
{
        /* FIXME:
         * The LWT state is currently rebuilt for delete requests which
         * results in a new bpf_prog instance. Comparing names for now.
         */
        if (!a->name && !b->name)
                return 0;

        if (!a->name || !b->name)
                return 1;

        return strcmp(a->name, b->name);
}

static int bpf_encap_cmp(struct lwtunnel_state *a, struct lwtunnel_state *b)
{
        struct bpf_lwt *a_bpf = bpf_lwt_lwtunnel(a);
        struct bpf_lwt *b_bpf = bpf_lwt_lwtunnel(b);

        return bpf_lwt_prog_cmp(&a_bpf->in, &b_bpf->in) ||
               bpf_lwt_prog_cmp(&a_bpf->out, &b_bpf->out) ||
               bpf_lwt_prog_cmp(&a_bpf->xmit, &b_bpf->xmit);
}

static const struct lwtunnel_encap_ops bpf_encap_ops = {
        .build_state    = bpf_build_state,
        .destroy_state  = bpf_destroy_state,
        .input          = bpf_input,
        .output         = bpf_output,
        .xmit           = bpf_xmit,
        .fill_encap     = bpf_fill_encap_info,
        .get_encap_size = bpf_encap_nlsize,
        .cmp_encap      = bpf_encap_cmp,
        .owner          = THIS_MODULE,
};

static int __init bpf_lwt_init(void)
{
        return lwtunnel_encap_add_ops(&bpf_encap_ops, LWTUNNEL_ENCAP_BPF);
}

subsys_initcall(bpf_lwt_init)
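
/* Illustrative usage (not part of this file): with iproute2, a route
 * can be configured to run a BPF program at one of these hooks, e.g.
 *
 *   ip route add 10.0.0.0/24 encap bpf xmit obj prog.o section xmit \
 *           dev eth0
 *
 * The route, object file and section name above are hypothetical, and
 * the exact iproute2 syntax may vary by version; see ip-route(8).
 */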