GNU Linux-libre 4.19.286-gnu1
[releases.git] / net / rxrpc / local_object.c
1 /* Local endpoint object management
2  *
3  * Copyright (C) 2016 Red Hat, Inc. All Rights Reserved.
4  * Written by David Howells (dhowells@redhat.com)
5  *
6  * This program is free software; you can redistribute it and/or
7  * modify it under the terms of the GNU General Public Licence
8  * as published by the Free Software Foundation; either version
9  * 2 of the Licence, or (at your option) any later version.
10  */
11
12 #define pr_fmt(fmt) KBUILD_MODNAME ": " fmt
13
14 #include <linux/module.h>
15 #include <linux/net.h>
16 #include <linux/skbuff.h>
17 #include <linux/slab.h>
18 #include <linux/udp.h>
19 #include <linux/ip.h>
20 #include <linux/hashtable.h>
21 #include <net/sock.h>
22 #include <net/udp.h>
23 #include <net/af_rxrpc.h>
24 #include "ar-internal.h"
25
26 static void rxrpc_local_processor(struct work_struct *);
27 static void rxrpc_local_rcu(struct rcu_head *);
28
29 /*
30  * Compare a local to an address.  Return -ve, 0 or +ve to indicate less than,
31  * same or greater than.
32  *
33  * We explicitly don't compare the RxRPC service ID as we want to reject
34  * conflicting uses by differing services.  Further, we don't want to share
35  * addresses with different options (IPv6), so we don't compare those bits
36  * either.
37  */
38 static long rxrpc_local_cmp_key(const struct rxrpc_local *local,
39                                 const struct sockaddr_rxrpc *srx)
40 {
41         long diff;
42
43         diff = ((local->srx.transport_type - srx->transport_type) ?:
44                 (local->srx.transport_len - srx->transport_len) ?:
45                 (local->srx.transport.family - srx->transport.family));
46         if (diff != 0)
47                 return diff;
48
49         switch (srx->transport.family) {
50         case AF_INET:
51                 /* If the choice of UDP port is left up to the transport, then
52                  * the endpoint record doesn't match.
53                  */
54                 return ((u16 __force)local->srx.transport.sin.sin_port -
55                         (u16 __force)srx->transport.sin.sin_port) ?:
56                         memcmp(&local->srx.transport.sin.sin_addr,
57                                &srx->transport.sin.sin_addr,
58                                sizeof(struct in_addr));
59 #ifdef CONFIG_AF_RXRPC_IPV6
60         case AF_INET6:
61                 /* If the choice of UDP6 port is left up to the transport, then
62                  * the endpoint record doesn't match.
63                  */
64                 return ((u16 __force)local->srx.transport.sin6.sin6_port -
65                         (u16 __force)srx->transport.sin6.sin6_port) ?:
66                         memcmp(&local->srx.transport.sin6.sin6_addr,
67                                &srx->transport.sin6.sin6_addr,
68                                sizeof(struct in6_addr));
69 #endif
70         default:
71                 BUG();
72         }
73 }
74
75 /*
76  * Allocate a new local endpoint.
77  */
78 static struct rxrpc_local *rxrpc_alloc_local(struct rxrpc_net *rxnet,
79                                              const struct sockaddr_rxrpc *srx)
80 {
81         struct rxrpc_local *local;
82
83         local = kzalloc(sizeof(struct rxrpc_local), GFP_KERNEL);
84         if (local) {
85                 atomic_set(&local->usage, 1);
86                 atomic_set(&local->active_users, 1);
87                 local->rxnet = rxnet;
88                 INIT_LIST_HEAD(&local->link);
89                 INIT_WORK(&local->processor, rxrpc_local_processor);
90                 init_rwsem(&local->defrag_sem);
91                 skb_queue_head_init(&local->reject_queue);
92                 skb_queue_head_init(&local->event_queue);
93                 local->client_conns = RB_ROOT;
94                 spin_lock_init(&local->client_conns_lock);
95                 spin_lock_init(&local->lock);
96                 rwlock_init(&local->services_lock);
97                 local->debug_id = atomic_inc_return(&rxrpc_debug_id);
98                 memcpy(&local->srx, srx, sizeof(*srx));
99                 local->srx.srx_service = 0;
100                 trace_rxrpc_local(local->debug_id, rxrpc_local_new, 1, NULL);
101         }
102
103         _leave(" = %p", local);
104         return local;
105 }
106
107 /*
108  * create the local socket
109  * - must be called with rxrpc_local_mutex locked
110  */
111 static int rxrpc_open_socket(struct rxrpc_local *local, struct net *net)
112 {
113         struct sock *usk;
114         int ret, opt;
115
116         _enter("%p{%d,%d}",
117                local, local->srx.transport_type, local->srx.transport.family);
118
119         /* create a socket to represent the local endpoint */
120         ret = sock_create_kern(net, local->srx.transport.family,
121                                local->srx.transport_type, 0, &local->socket);
122         if (ret < 0) {
123                 _leave(" = %d [socket]", ret);
124                 return ret;
125         }
126
127         /* set the socket up */
128         usk = local->socket->sk;
129         inet_sk(usk)->mc_loop = 0;
130
131         /* Enable CHECKSUM_UNNECESSARY to CHECKSUM_COMPLETE conversion */
132         inet_inc_convert_csum(usk);
133
134         rcu_assign_sk_user_data(usk, local);
135
136         udp_sk(usk)->encap_type = UDP_ENCAP_RXRPC;
137         udp_sk(usk)->encap_rcv = rxrpc_input_packet;
138         udp_sk(usk)->encap_destroy = NULL;
139         udp_sk(usk)->gro_receive = NULL;
140         udp_sk(usk)->gro_complete = NULL;
141
142         udp_encap_enable();
143 #if IS_ENABLED(CONFIG_AF_RXRPC_IPV6)
144         if (local->srx.transport.family == AF_INET6)
145                 udpv6_encap_enable();
146 #endif
147         usk->sk_error_report = rxrpc_error_report;
148
149         /* if a local address was supplied then bind it */
150         if (local->srx.transport_len > sizeof(sa_family_t)) {
151                 _debug("bind");
152                 ret = kernel_bind(local->socket,
153                                   (struct sockaddr *)&local->srx.transport,
154                                   local->srx.transport_len);
155                 if (ret < 0) {
156                         _debug("bind failed %d", ret);
157                         goto error;
158                 }
159         }
160
161         switch (local->srx.transport.family) {
162         case AF_INET6:
163                 /* we want to receive ICMPv6 errors */
164                 opt = 1;
165                 ret = kernel_setsockopt(local->socket, SOL_IPV6, IPV6_RECVERR,
166                                         (char *) &opt, sizeof(opt));
167                 if (ret < 0) {
168                         _debug("setsockopt failed");
169                         goto error;
170                 }
171
172                 /* Fall through and set IPv4 options too otherwise we don't get
173                  * errors from IPv4 packets sent through the IPv6 socket.
174                  */
175
176         case AF_INET:
177                 /* we want to receive ICMP errors */
178                 opt = 1;
179                 ret = kernel_setsockopt(local->socket, SOL_IP, IP_RECVERR,
180                                         (char *) &opt, sizeof(opt));
181                 if (ret < 0) {
182                         _debug("setsockopt failed");
183                         goto error;
184                 }
185
186                 /* we want to set the don't fragment bit */
187                 opt = IP_PMTUDISC_DO;
188                 ret = kernel_setsockopt(local->socket, SOL_IP, IP_MTU_DISCOVER,
189                                         (char *) &opt, sizeof(opt));
190                 if (ret < 0) {
191                         _debug("setsockopt failed");
192                         goto error;
193                 }
194
195                 /* We want receive timestamps. */
196                 opt = 1;
197                 ret = kernel_setsockopt(local->socket, SOL_SOCKET, SO_TIMESTAMPNS,
198                                         (char *)&opt, sizeof(opt));
199                 if (ret < 0) {
200                         _debug("setsockopt failed");
201                         goto error;
202                 }
203                 break;
204
205         default:
206                 BUG();
207         }
208
209         _leave(" = 0");
210         return 0;
211
212 error:
213         kernel_sock_shutdown(local->socket, SHUT_RDWR);
214         local->socket->sk->sk_user_data = NULL;
215         sock_release(local->socket);
216         local->socket = NULL;
217
218         _leave(" = %d", ret);
219         return ret;
220 }
221
222 /*
223  * Look up or create a new local endpoint using the specified local address.
224  */
225 struct rxrpc_local *rxrpc_lookup_local(struct net *net,
226                                        const struct sockaddr_rxrpc *srx)
227 {
228         struct rxrpc_local *local;
229         struct rxrpc_net *rxnet = rxrpc_net(net);
230         struct list_head *cursor;
231         const char *age;
232         long diff;
233         int ret;
234
235         _enter("{%d,%d,%pISp}",
236                srx->transport_type, srx->transport.family, &srx->transport);
237
238         mutex_lock(&rxnet->local_mutex);
239
240         for (cursor = rxnet->local_endpoints.next;
241              cursor != &rxnet->local_endpoints;
242              cursor = cursor->next) {
243                 local = list_entry(cursor, struct rxrpc_local, link);
244
245                 diff = rxrpc_local_cmp_key(local, srx);
246                 if (diff < 0)
247                         continue;
248                 if (diff > 0)
249                         break;
250
251                 /* Services aren't allowed to share transport sockets, so
252                  * reject that here.  It is possible that the object is dying -
253                  * but it may also still have the local transport address that
254                  * we want bound.
255                  */
256                 if (srx->srx_service) {
257                         local = NULL;
258                         goto addr_in_use;
259                 }
260
261                 /* Found a match.  We replace a dying object.  Attempting to
262                  * bind the transport socket may still fail if we're attempting
263                  * to use a local address that the dying object is still using.
264                  */
265                 if (!rxrpc_use_local(local))
266                         break;
267
268                 age = "old";
269                 goto found;
270         }
271
272         local = rxrpc_alloc_local(rxnet, srx);
273         if (!local)
274                 goto nomem;
275
276         ret = rxrpc_open_socket(local, net);
277         if (ret < 0)
278                 goto sock_error;
279
280         if (cursor != &rxnet->local_endpoints)
281                 list_replace_init(cursor, &local->link);
282         else
283                 list_add_tail(&local->link, cursor);
284         age = "new";
285
286 found:
287         mutex_unlock(&rxnet->local_mutex);
288
289         _net("LOCAL %s %d {%pISp}",
290              age, local->debug_id, &local->srx.transport);
291
292         _leave(" = %p", local);
293         return local;
294
295 nomem:
296         ret = -ENOMEM;
297 sock_error:
298         mutex_unlock(&rxnet->local_mutex);
299         if (local)
300                 call_rcu(&local->rcu, rxrpc_local_rcu);
301         _leave(" = %d", ret);
302         return ERR_PTR(ret);
303
304 addr_in_use:
305         mutex_unlock(&rxnet->local_mutex);
306         _leave(" = -EADDRINUSE");
307         return ERR_PTR(-EADDRINUSE);
308 }
309
310 /*
311  * Get a ref on a local endpoint.
312  */
313 struct rxrpc_local *rxrpc_get_local(struct rxrpc_local *local)
314 {
315         const void *here = __builtin_return_address(0);
316         int n;
317
318         n = atomic_inc_return(&local->usage);
319         trace_rxrpc_local(local->debug_id, rxrpc_local_got, n, here);
320         return local;
321 }
322
323 /*
324  * Get a ref on a local endpoint unless its usage has already reached 0.
325  */
326 struct rxrpc_local *rxrpc_get_local_maybe(struct rxrpc_local *local)
327 {
328         const void *here = __builtin_return_address(0);
329
330         if (local) {
331                 int n = atomic_fetch_add_unless(&local->usage, 1, 0);
332                 if (n > 0)
333                         trace_rxrpc_local(local->debug_id, rxrpc_local_got,
334                                           n + 1, here);
335                 else
336                         local = NULL;
337         }
338         return local;
339 }
340
341 /*
342  * Queue a local endpoint and pass the caller's reference to the work item.
343  */
344 void rxrpc_queue_local(struct rxrpc_local *local)
345 {
346         const void *here = __builtin_return_address(0);
347         unsigned int debug_id = local->debug_id;
348         int n = atomic_read(&local->usage);
349
350         if (rxrpc_queue_work(&local->processor))
351                 trace_rxrpc_local(debug_id, rxrpc_local_queued, n, here);
352         else
353                 rxrpc_put_local(local);
354 }
355
356 /*
357  * Drop a ref on a local endpoint.
358  */
359 void rxrpc_put_local(struct rxrpc_local *local)
360 {
361         const void *here = __builtin_return_address(0);
362         unsigned int debug_id;
363         int n;
364
365         if (local) {
366                 debug_id = local->debug_id;
367
368                 n = atomic_dec_return(&local->usage);
369                 trace_rxrpc_local(debug_id, rxrpc_local_put, n, here);
370
371                 if (n == 0)
372                         call_rcu(&local->rcu, rxrpc_local_rcu);
373         }
374 }
375
376 /*
377  * Start using a local endpoint.
378  */
379 struct rxrpc_local *rxrpc_use_local(struct rxrpc_local *local)
380 {
381         local = rxrpc_get_local_maybe(local);
382         if (!local)
383                 return NULL;
384
385         if (!__rxrpc_use_local(local)) {
386                 rxrpc_put_local(local);
387                 return NULL;
388         }
389
390         return local;
391 }
392
393 /*
394  * Cease using a local endpoint.  Once the number of active users reaches 0, we
395  * start the closure of the transport in the work processor.
396  */
397 void rxrpc_unuse_local(struct rxrpc_local *local)
398 {
399         if (local) {
400                 if (__rxrpc_unuse_local(local)) {
401                         rxrpc_get_local(local);
402                         rxrpc_queue_local(local);
403                 }
404         }
405 }
406
407 /*
408  * Destroy a local endpoint's socket and then hand the record to RCU to dispose
409  * of.
410  *
411  * Closing the socket cannot be done from bottom half context or RCU callback
412  * context because it might sleep.
413  */
414 static void rxrpc_local_destroyer(struct rxrpc_local *local)
415 {
416         struct socket *socket = local->socket;
417         struct rxrpc_net *rxnet = local->rxnet;
418
419         _enter("%d", local->debug_id);
420
421         local->dead = true;
422
423         mutex_lock(&rxnet->local_mutex);
424         list_del_init(&local->link);
425         mutex_unlock(&rxnet->local_mutex);
426
427         rxrpc_clean_up_local_conns(local);
428         rxrpc_service_connection_reaper(&rxnet->service_conn_reaper);
429         ASSERT(!local->service);
430
431         if (socket) {
432                 local->socket = NULL;
433                 kernel_sock_shutdown(socket, SHUT_RDWR);
434                 socket->sk->sk_user_data = NULL;
435                 sock_release(socket);
436         }
437
438         /* At this point, there should be no more packets coming in to the
439          * local endpoint.
440          */
441         rxrpc_purge_queue(&local->reject_queue);
442         rxrpc_purge_queue(&local->event_queue);
443 }
444
445 /*
446  * Process events on an endpoint.  The work item carries a ref which
447  * we must release.
448  */
449 static void rxrpc_local_processor(struct work_struct *work)
450 {
451         struct rxrpc_local *local =
452                 container_of(work, struct rxrpc_local, processor);
453         bool again;
454
455         if (local->dead)
456                 return;
457
458         trace_rxrpc_local(local->debug_id, rxrpc_local_processing,
459                           atomic_read(&local->usage), NULL);
460
461         do {
462                 again = false;
463                 if (!__rxrpc_use_local(local)) {
464                         rxrpc_local_destroyer(local);
465                         break;
466                 }
467
468                 if (!skb_queue_empty(&local->reject_queue)) {
469                         rxrpc_reject_packets(local);
470                         again = true;
471                 }
472
473                 if (!skb_queue_empty(&local->event_queue)) {
474                         rxrpc_process_local_events(local);
475                         again = true;
476                 }
477
478                 __rxrpc_unuse_local(local);
479         } while (again);
480
481         rxrpc_put_local(local);
482 }
483
484 /*
485  * Destroy a local endpoint after the RCU grace period expires.
486  */
487 static void rxrpc_local_rcu(struct rcu_head *rcu)
488 {
489         struct rxrpc_local *local = container_of(rcu, struct rxrpc_local, rcu);
490
491         _enter("%d", local->debug_id);
492
493         ASSERT(!work_pending(&local->processor));
494
495         _net("DESTROY LOCAL %d", local->debug_id);
496         kfree(local);
497         _leave("");
498 }
499
500 /*
501  * Verify the local endpoint list is empty by this point.
502  */
503 void rxrpc_destroy_all_locals(struct rxrpc_net *rxnet)
504 {
505         struct rxrpc_local *local;
506
507         _enter("");
508
509         flush_workqueue(rxrpc_workqueue);
510
511         if (!list_empty(&rxnet->local_endpoints)) {
512                 mutex_lock(&rxnet->local_mutex);
513                 list_for_each_entry(local, &rxnet->local_endpoints, link) {
514                         pr_err("AF_RXRPC: Leaked local %p {%d}\n",
515                                local, atomic_read(&local->usage));
516                 }
517                 mutex_unlock(&rxnet->local_mutex);
518                 BUG();
519         }
520 }