GNU Linux-libre 4.4.288-gnu1
[releases.git] / net / vmw_vsock / af_vsock.c
1 /*
2  * VMware vSockets Driver
3  *
4  * Copyright (C) 2007-2013 VMware, Inc. All rights reserved.
5  *
6  * This program is free software; you can redistribute it and/or modify it
7  * under the terms of the GNU General Public License as published by the Free
8  * Software Foundation version 2 and no later version.
9  *
10  * This program is distributed in the hope that it will be useful, but WITHOUT
11  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
12  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License for
13  * more details.
14  */
15
16 /* Implementation notes:
17  *
18  * - There are two kinds of sockets: those created by user action (such as
19  * calling socket(2)) and those created by incoming connection request packets.
20  *
21  * - There are two "global" tables, one for bound sockets (sockets that have
22  * specified an address that they are responsible for) and one for connected
23  * sockets (sockets that have established a connection with another socket).
24  * These tables are "global" in that all sockets on the system are placed
25  * within them.  Note, though, that the bound table contains an extra entry
26  * for a list of unbound sockets and SOCK_DGRAM sockets will always remain in
27  * that list. The bound table is used solely for lookup of sockets when packets
28  * are received and that's not necessary for SOCK_DGRAM sockets since we create
29  * a datagram handle for each and need not perform a lookup.  Keeping SOCK_DGRAM
30  * sockets out of the bound hash buckets will reduce the chance of collisions
31  * when looking for SOCK_STREAM sockets and prevents us from having to check the
32  * socket type in the hash table lookups.
33  *
34  * - Sockets created by user action will either be "client" sockets that
35  * initiate a connection or "server" sockets that listen for connections; we do
36  * not support simultaneous connects (two "client" sockets connecting).
37  *
38  * - "Server" sockets are referred to as listener sockets throughout this
39  * implementation because they are in the VSOCK_SS_LISTEN state.  When a
40  * connection request is received (the second kind of socket mentioned above),
41  * we create a new socket and refer to it as a pending socket.  These pending
42  * sockets are placed on the pending connection list of the listener socket.
43  * When future packets are received for the address the listener socket is
44  * bound to, we check if the source of the packet is from one that has an
45  * existing pending connection.  If it does, we process the packet for the
46  * pending socket.  When that socket reaches the connected state, it is removed
47  * from the listener socket's pending list and enqueued in the listener
48  * socket's accept queue.  Callers of accept(2) will accept connected sockets
49  * from the listener socket's accept queue.  If the socket cannot be accepted
50  * for some reason then it is marked rejected.  Once the connection is
51  * accepted, it is owned by the user process and the responsibility for cleanup
52  * falls with that user process.
53  *
54  * - It is possible that these pending sockets will never reach the connected
55  * state; in fact, we may never receive another packet after the connection
56  * request.  Because of this, we must schedule a cleanup function to run in the
57  * future, after some amount of time passes where a connection should have been
58  * established.  This function ensures that the socket is off all lists so it
59  * cannot be retrieved, then drops all references to the socket so it is cleaned
60  * up (sock_put() -> sk_free() -> our sk_destruct implementation).  Note this
61  * function will also clean up rejected sockets, those that reach the connected
62  * state but leave it before they have been accepted.
63  *
64  * - Sockets created by user action will be cleaned up when the user process
65  * calls close(2), causing our release implementation to be called. Our release
66  * implementation will perform some cleanup then drop the last reference so our
67  * sk_destruct implementation is invoked.  Our sk_destruct implementation will
68  * perform additional cleanup that's common for both types of sockets.
69  *
70  * - A socket's reference count is what ensures that the structure won't be
71  * freed.  Each entry in a list (such as the "global" bound and connected tables
72  * and the listener socket's pending list and connected queue) ensures a
73  * reference.  When we defer work until process context and pass a socket as our
74  * argument, we must ensure the reference count is increased to ensure the
75  * socket isn't freed before the function is run; the deferred function will
76  * then drop the reference.
77  */
78
79 #include <linux/types.h>
80 #include <linux/bitops.h>
81 #include <linux/cred.h>
82 #include <linux/init.h>
83 #include <linux/io.h>
84 #include <linux/kernel.h>
85 #include <linux/kmod.h>
86 #include <linux/list.h>
87 #include <linux/miscdevice.h>
88 #include <linux/module.h>
89 #include <linux/mutex.h>
90 #include <linux/net.h>
91 #include <linux/poll.h>
92 #include <linux/random.h>
93 #include <linux/skbuff.h>
94 #include <linux/smp.h>
95 #include <linux/socket.h>
96 #include <linux/stddef.h>
97 #include <linux/unistd.h>
98 #include <linux/wait.h>
99 #include <linux/workqueue.h>
100 #include <net/sock.h>
101 #include <net/af_vsock.h>
102
/* Forward declarations for helpers that are defined further down. */
static void vsock_sk_destruct(struct sock *sk);
static int vsock_queue_rcv_skb(struct sock *sk, struct sk_buff *skb);
static int __vsock_bind(struct sock *sk, struct sockaddr_vm *addr);
106
107 /* Protocol family. */
108 static struct proto vsock_proto = {
109         .name = "AF_VSOCK",
110         .owner = THIS_MODULE,
111         .obj_size = sizeof(struct vsock_sock),
112 };
113
/* Default peer timeout: how long we wait for the peer to answer a control
 * message.  New sockets without a parent start with this value.
 */
#define VSOCK_DEFAULT_CONNECT_TIMEOUT (2 * HZ)

static DEFINE_MUTEX(vsock_register_mutex);

/* Transport whose callbacks the socket operations in this file invoke. */
static const struct vsock_transport *transport;
121
122 /**** EXPORTS ****/
123
124 /* Get the ID of the local context.  This is transport dependent. */
125
126 int vm_sockets_get_local_cid(void)
127 {
128         return transport->get_local_cid();
129 }
130 EXPORT_SYMBOL_GPL(vm_sockets_get_local_cid);
131
132 /**** UTILS ****/
133
134 /* Each bound VSocket is stored in the bind hash table and each connected
135  * VSocket is stored in the connected hash table.
136  *
137  * Unbound sockets are all put on the same list attached to the end of the hash
138  * table (vsock_unbound_sockets).  Bound sockets are added to the hash table in
139  * the bucket that their local address hashes to (vsock_bound_sockets(addr)
140  * represents the list that addr hashes to).
141  *
142  * Specifically, we initialize the vsock_bind_table array to a size of
143  * VSOCK_HASH_SIZE + 1 so that vsock_bind_table[0] through
144  * vsock_bind_table[VSOCK_HASH_SIZE - 1] are for bound sockets and
145  * vsock_bind_table[VSOCK_HASH_SIZE] is for unbound sockets.  The hash function
146  * mods with VSOCK_HASH_SIZE to ensure this.
147  */
148 #define VSOCK_HASH_SIZE         251
149 #define MAX_PORT_RETRIES        24
150
151 #define VSOCK_HASH(addr)        ((addr)->svm_port % VSOCK_HASH_SIZE)
152 #define vsock_bound_sockets(addr) (&vsock_bind_table[VSOCK_HASH(addr)])
153 #define vsock_unbound_sockets     (&vsock_bind_table[VSOCK_HASH_SIZE])
154
155 /* XXX This can probably be implemented in a better way. */
156 #define VSOCK_CONN_HASH(src, dst)                               \
157         (((src)->svm_cid ^ (dst)->svm_port) % VSOCK_HASH_SIZE)
158 #define vsock_connected_sockets(src, dst)               \
159         (&vsock_connected_table[VSOCK_CONN_HASH(src, dst)])
160 #define vsock_connected_sockets_vsk(vsk)                                \
161         vsock_connected_sockets(&(vsk)->remote_addr, &(vsk)->local_addr)
162
163 static struct list_head vsock_bind_table[VSOCK_HASH_SIZE + 1];
164 static struct list_head vsock_connected_table[VSOCK_HASH_SIZE];
165 static DEFINE_SPINLOCK(vsock_table_lock);
166
167 /* Autobind this socket to the local address if necessary. */
168 static int vsock_auto_bind(struct vsock_sock *vsk)
169 {
170         struct sock *sk = sk_vsock(vsk);
171         struct sockaddr_vm local_addr;
172
173         if (vsock_addr_bound(&vsk->local_addr))
174                 return 0;
175         vsock_addr_init(&local_addr, VMADDR_CID_ANY, VMADDR_PORT_ANY);
176         return __vsock_bind(sk, &local_addr);
177 }
178
179 static void vsock_init_tables(void)
180 {
181         int i;
182
183         for (i = 0; i < ARRAY_SIZE(vsock_bind_table); i++)
184                 INIT_LIST_HEAD(&vsock_bind_table[i]);
185
186         for (i = 0; i < ARRAY_SIZE(vsock_connected_table); i++)
187                 INIT_LIST_HEAD(&vsock_connected_table[i]);
188 }
189
190 static void __vsock_insert_bound(struct list_head *list,
191                                  struct vsock_sock *vsk)
192 {
193         sock_hold(&vsk->sk);
194         list_add(&vsk->bound_table, list);
195 }
196
197 static void __vsock_insert_connected(struct list_head *list,
198                                      struct vsock_sock *vsk)
199 {
200         sock_hold(&vsk->sk);
201         list_add(&vsk->connected_table, list);
202 }
203
204 static void __vsock_remove_bound(struct vsock_sock *vsk)
205 {
206         list_del_init(&vsk->bound_table);
207         sock_put(&vsk->sk);
208 }
209
210 static void __vsock_remove_connected(struct vsock_sock *vsk)
211 {
212         list_del_init(&vsk->connected_table);
213         sock_put(&vsk->sk);
214 }
215
216 static struct sock *__vsock_find_bound_socket(struct sockaddr_vm *addr)
217 {
218         struct vsock_sock *vsk;
219
220         list_for_each_entry(vsk, vsock_bound_sockets(addr), bound_table)
221                 if (addr->svm_port == vsk->local_addr.svm_port)
222                         return sk_vsock(vsk);
223
224         return NULL;
225 }
226
227 static struct sock *__vsock_find_connected_socket(struct sockaddr_vm *src,
228                                                   struct sockaddr_vm *dst)
229 {
230         struct vsock_sock *vsk;
231
232         list_for_each_entry(vsk, vsock_connected_sockets(src, dst),
233                             connected_table) {
234                 if (vsock_addr_equals_addr(src, &vsk->remote_addr) &&
235                     dst->svm_port == vsk->local_addr.svm_port) {
236                         return sk_vsock(vsk);
237                 }
238         }
239
240         return NULL;
241 }
242
243 static bool __vsock_in_bound_table(struct vsock_sock *vsk)
244 {
245         return !list_empty(&vsk->bound_table);
246 }
247
248 static bool __vsock_in_connected_table(struct vsock_sock *vsk)
249 {
250         return !list_empty(&vsk->connected_table);
251 }
252
253 static void vsock_insert_unbound(struct vsock_sock *vsk)
254 {
255         spin_lock_bh(&vsock_table_lock);
256         __vsock_insert_bound(vsock_unbound_sockets, vsk);
257         spin_unlock_bh(&vsock_table_lock);
258 }
259
260 void vsock_insert_connected(struct vsock_sock *vsk)
261 {
262         struct list_head *list = vsock_connected_sockets(
263                 &vsk->remote_addr, &vsk->local_addr);
264
265         spin_lock_bh(&vsock_table_lock);
266         __vsock_insert_connected(list, vsk);
267         spin_unlock_bh(&vsock_table_lock);
268 }
269 EXPORT_SYMBOL_GPL(vsock_insert_connected);
270
271 void vsock_remove_bound(struct vsock_sock *vsk)
272 {
273         spin_lock_bh(&vsock_table_lock);
274         __vsock_remove_bound(vsk);
275         spin_unlock_bh(&vsock_table_lock);
276 }
277 EXPORT_SYMBOL_GPL(vsock_remove_bound);
278
279 void vsock_remove_connected(struct vsock_sock *vsk)
280 {
281         spin_lock_bh(&vsock_table_lock);
282         __vsock_remove_connected(vsk);
283         spin_unlock_bh(&vsock_table_lock);
284 }
285 EXPORT_SYMBOL_GPL(vsock_remove_connected);
286
287 struct sock *vsock_find_bound_socket(struct sockaddr_vm *addr)
288 {
289         struct sock *sk;
290
291         spin_lock_bh(&vsock_table_lock);
292         sk = __vsock_find_bound_socket(addr);
293         if (sk)
294                 sock_hold(sk);
295
296         spin_unlock_bh(&vsock_table_lock);
297
298         return sk;
299 }
300 EXPORT_SYMBOL_GPL(vsock_find_bound_socket);
301
302 struct sock *vsock_find_connected_socket(struct sockaddr_vm *src,
303                                          struct sockaddr_vm *dst)
304 {
305         struct sock *sk;
306
307         spin_lock_bh(&vsock_table_lock);
308         sk = __vsock_find_connected_socket(src, dst);
309         if (sk)
310                 sock_hold(sk);
311
312         spin_unlock_bh(&vsock_table_lock);
313
314         return sk;
315 }
316 EXPORT_SYMBOL_GPL(vsock_find_connected_socket);
317
318 static bool vsock_in_bound_table(struct vsock_sock *vsk)
319 {
320         bool ret;
321
322         spin_lock_bh(&vsock_table_lock);
323         ret = __vsock_in_bound_table(vsk);
324         spin_unlock_bh(&vsock_table_lock);
325
326         return ret;
327 }
328
329 static bool vsock_in_connected_table(struct vsock_sock *vsk)
330 {
331         bool ret;
332
333         spin_lock_bh(&vsock_table_lock);
334         ret = __vsock_in_connected_table(vsk);
335         spin_unlock_bh(&vsock_table_lock);
336
337         return ret;
338 }
339
340 void vsock_for_each_connected_socket(void (*fn)(struct sock *sk))
341 {
342         int i;
343
344         spin_lock_bh(&vsock_table_lock);
345
346         for (i = 0; i < ARRAY_SIZE(vsock_connected_table); i++) {
347                 struct vsock_sock *vsk;
348                 list_for_each_entry(vsk, &vsock_connected_table[i],
349                                     connected_table)
350                         fn(sk_vsock(vsk));
351         }
352
353         spin_unlock_bh(&vsock_table_lock);
354 }
355 EXPORT_SYMBOL_GPL(vsock_for_each_connected_socket);
356
357 void vsock_add_pending(struct sock *listener, struct sock *pending)
358 {
359         struct vsock_sock *vlistener;
360         struct vsock_sock *vpending;
361
362         vlistener = vsock_sk(listener);
363         vpending = vsock_sk(pending);
364
365         sock_hold(pending);
366         sock_hold(listener);
367         list_add_tail(&vpending->pending_links, &vlistener->pending_links);
368 }
369 EXPORT_SYMBOL_GPL(vsock_add_pending);
370
371 void vsock_remove_pending(struct sock *listener, struct sock *pending)
372 {
373         struct vsock_sock *vpending = vsock_sk(pending);
374
375         list_del_init(&vpending->pending_links);
376         sock_put(listener);
377         sock_put(pending);
378 }
379 EXPORT_SYMBOL_GPL(vsock_remove_pending);
380
381 void vsock_enqueue_accept(struct sock *listener, struct sock *connected)
382 {
383         struct vsock_sock *vlistener;
384         struct vsock_sock *vconnected;
385
386         vlistener = vsock_sk(listener);
387         vconnected = vsock_sk(connected);
388
389         sock_hold(connected);
390         sock_hold(listener);
391         list_add_tail(&vconnected->accept_queue, &vlistener->accept_queue);
392 }
393 EXPORT_SYMBOL_GPL(vsock_enqueue_accept);
394
395 static struct sock *vsock_dequeue_accept(struct sock *listener)
396 {
397         struct vsock_sock *vlistener;
398         struct vsock_sock *vconnected;
399
400         vlistener = vsock_sk(listener);
401
402         if (list_empty(&vlistener->accept_queue))
403                 return NULL;
404
405         vconnected = list_entry(vlistener->accept_queue.next,
406                                 struct vsock_sock, accept_queue);
407
408         list_del_init(&vconnected->accept_queue);
409         sock_put(listener);
410         /* The caller will need a reference on the connected socket so we let
411          * it call sock_put().
412          */
413
414         return sk_vsock(vconnected);
415 }
416
417 static bool vsock_is_accept_queue_empty(struct sock *sk)
418 {
419         struct vsock_sock *vsk = vsock_sk(sk);
420         return list_empty(&vsk->accept_queue);
421 }
422
423 static bool vsock_is_pending(struct sock *sk)
424 {
425         struct vsock_sock *vsk = vsock_sk(sk);
426         return !list_empty(&vsk->pending_links);
427 }
428
429 static int vsock_send_shutdown(struct sock *sk, int mode)
430 {
431         return transport->shutdown(vsock_sk(sk), mode);
432 }
433
434 static void vsock_pending_work(struct work_struct *work)
435 {
436         struct sock *sk;
437         struct sock *listener;
438         struct vsock_sock *vsk;
439         bool cleanup;
440
441         vsk = container_of(work, struct vsock_sock, pending_work.work);
442         sk = sk_vsock(vsk);
443         listener = vsk->listener;
444         cleanup = true;
445
446         lock_sock(listener);
447         lock_sock(sk);
448
449         if (vsock_is_pending(sk)) {
450                 vsock_remove_pending(listener, sk);
451         } else if (!vsk->rejected) {
452                 /* We are not on the pending list and accept() did not reject
453                  * us, so we must have been accepted by our user process.  We
454                  * just need to drop our references to the sockets and be on
455                  * our way.
456                  */
457                 cleanup = false;
458                 goto out;
459         }
460
461         listener->sk_ack_backlog--;
462
463         /* We need to remove ourself from the global connected sockets list so
464          * incoming packets can't find this socket, and to reduce the reference
465          * count.
466          */
467         if (vsock_in_connected_table(vsk))
468                 vsock_remove_connected(vsk);
469
470         sk->sk_state = SS_FREE;
471
472 out:
473         release_sock(sk);
474         release_sock(listener);
475         if (cleanup)
476                 sock_put(sk);
477
478         sock_put(sk);
479         sock_put(listener);
480 }
481
482 /**** SOCKET OPERATIONS ****/
483
484 static int __vsock_bind_stream(struct vsock_sock *vsk,
485                                struct sockaddr_vm *addr)
486 {
487         static u32 port = 0;
488         struct sockaddr_vm new_addr;
489
490         if (!port)
491                 port = LAST_RESERVED_PORT + 1 +
492                         prandom_u32_max(U32_MAX - LAST_RESERVED_PORT);
493
494         vsock_addr_init(&new_addr, addr->svm_cid, addr->svm_port);
495
496         if (addr->svm_port == VMADDR_PORT_ANY) {
497                 bool found = false;
498                 unsigned int i;
499
500                 for (i = 0; i < MAX_PORT_RETRIES; i++) {
501                         if (port <= LAST_RESERVED_PORT)
502                                 port = LAST_RESERVED_PORT + 1;
503
504                         new_addr.svm_port = port++;
505
506                         if (!__vsock_find_bound_socket(&new_addr)) {
507                                 found = true;
508                                 break;
509                         }
510                 }
511
512                 if (!found)
513                         return -EADDRNOTAVAIL;
514         } else {
515                 /* If port is in reserved range, ensure caller
516                  * has necessary privileges.
517                  */
518                 if (addr->svm_port <= LAST_RESERVED_PORT &&
519                     !capable(CAP_NET_BIND_SERVICE)) {
520                         return -EACCES;
521                 }
522
523                 if (__vsock_find_bound_socket(&new_addr))
524                         return -EADDRINUSE;
525         }
526
527         vsock_addr_init(&vsk->local_addr, new_addr.svm_cid, new_addr.svm_port);
528
529         /* Remove stream sockets from the unbound list and add them to the hash
530          * table for easy lookup by its address.  The unbound list is simply an
531          * extra entry at the end of the hash table, a trick used by AF_UNIX.
532          */
533         __vsock_remove_bound(vsk);
534         __vsock_insert_bound(vsock_bound_sockets(&vsk->local_addr), vsk);
535
536         return 0;
537 }
538
539 static int __vsock_bind_dgram(struct vsock_sock *vsk,
540                               struct sockaddr_vm *addr)
541 {
542         return transport->dgram_bind(vsk, addr);
543 }
544
545 static int __vsock_bind(struct sock *sk, struct sockaddr_vm *addr)
546 {
547         struct vsock_sock *vsk = vsock_sk(sk);
548         u32 cid;
549         int retval;
550
551         /* First ensure this socket isn't already bound. */
552         if (vsock_addr_bound(&vsk->local_addr))
553                 return -EINVAL;
554
555         /* Now bind to the provided address or select appropriate values if
556          * none are provided (VMADDR_CID_ANY and VMADDR_PORT_ANY).  Note that
557          * like AF_INET prevents binding to a non-local IP address (in most
558          * cases), we only allow binding to the local CID.
559          */
560         cid = transport->get_local_cid();
561         if (addr->svm_cid != cid && addr->svm_cid != VMADDR_CID_ANY)
562                 return -EADDRNOTAVAIL;
563
564         switch (sk->sk_socket->type) {
565         case SOCK_STREAM:
566                 spin_lock_bh(&vsock_table_lock);
567                 retval = __vsock_bind_stream(vsk, addr);
568                 spin_unlock_bh(&vsock_table_lock);
569                 break;
570
571         case SOCK_DGRAM:
572                 retval = __vsock_bind_dgram(vsk, addr);
573                 break;
574
575         default:
576                 retval = -EINVAL;
577                 break;
578         }
579
580         return retval;
581 }
582
583 static void vsock_connect_timeout(struct work_struct *work);
584
585 struct sock *__vsock_create(struct net *net,
586                             struct socket *sock,
587                             struct sock *parent,
588                             gfp_t priority,
589                             unsigned short type,
590                             int kern)
591 {
592         struct sock *sk;
593         struct vsock_sock *psk;
594         struct vsock_sock *vsk;
595
596         sk = sk_alloc(net, AF_VSOCK, priority, &vsock_proto, kern);
597         if (!sk)
598                 return NULL;
599
600         sock_init_data(sock, sk);
601
602         /* sk->sk_type is normally set in sock_init_data, but only if sock is
603          * non-NULL. We make sure that our sockets always have a type by
604          * setting it here if needed.
605          */
606         if (!sock)
607                 sk->sk_type = type;
608
609         vsk = vsock_sk(sk);
610         vsock_addr_init(&vsk->local_addr, VMADDR_CID_ANY, VMADDR_PORT_ANY);
611         vsock_addr_init(&vsk->remote_addr, VMADDR_CID_ANY, VMADDR_PORT_ANY);
612
613         sk->sk_destruct = vsock_sk_destruct;
614         sk->sk_backlog_rcv = vsock_queue_rcv_skb;
615         sk->sk_state = 0;
616         sock_reset_flag(sk, SOCK_DONE);
617
618         INIT_LIST_HEAD(&vsk->bound_table);
619         INIT_LIST_HEAD(&vsk->connected_table);
620         vsk->listener = NULL;
621         INIT_LIST_HEAD(&vsk->pending_links);
622         INIT_LIST_HEAD(&vsk->accept_queue);
623         vsk->rejected = false;
624         vsk->sent_request = false;
625         vsk->ignore_connecting_rst = false;
626         vsk->peer_shutdown = 0;
627         INIT_DELAYED_WORK(&vsk->connect_work, vsock_connect_timeout);
628         INIT_DELAYED_WORK(&vsk->pending_work, vsock_pending_work);
629
630         psk = parent ? vsock_sk(parent) : NULL;
631         if (parent) {
632                 vsk->trusted = psk->trusted;
633                 vsk->owner = get_cred(psk->owner);
634                 vsk->connect_timeout = psk->connect_timeout;
635                 security_sk_clone(parent, sk);
636         } else {
637                 vsk->trusted = ns_capable_noaudit(&init_user_ns, CAP_NET_ADMIN);
638                 vsk->owner = get_current_cred();
639                 vsk->connect_timeout = VSOCK_DEFAULT_CONNECT_TIMEOUT;
640         }
641
642         if (transport->init(vsk, psk) < 0) {
643                 sk_free(sk);
644                 return NULL;
645         }
646
647         if (sock)
648                 vsock_insert_unbound(vsk);
649
650         return sk;
651 }
652 EXPORT_SYMBOL_GPL(__vsock_create);
653
654 static void __vsock_release(struct sock *sk)
655 {
656         if (sk) {
657                 struct sk_buff *skb;
658                 struct sock *pending;
659                 struct vsock_sock *vsk;
660
661                 vsk = vsock_sk(sk);
662                 pending = NULL; /* Compiler warning. */
663
664                 if (vsock_in_bound_table(vsk))
665                         vsock_remove_bound(vsk);
666
667                 if (vsock_in_connected_table(vsk))
668                         vsock_remove_connected(vsk);
669
670                 transport->release(vsk);
671
672                 lock_sock(sk);
673                 sock_orphan(sk);
674                 sk->sk_shutdown = SHUTDOWN_MASK;
675
676                 while ((skb = skb_dequeue(&sk->sk_receive_queue)))
677                         kfree_skb(skb);
678
679                 /* Clean up any sockets that never were accepted. */
680                 while ((pending = vsock_dequeue_accept(sk)) != NULL) {
681                         __vsock_release(pending);
682                         sock_put(pending);
683                 }
684
685                 release_sock(sk);
686                 sock_put(sk);
687         }
688 }
689
690 static void vsock_sk_destruct(struct sock *sk)
691 {
692         struct vsock_sock *vsk = vsock_sk(sk);
693
694         transport->destruct(vsk);
695
696         /* When clearing these addresses, there's no need to set the family and
697          * possibly register the address family with the kernel.
698          */
699         vsock_addr_init(&vsk->local_addr, VMADDR_CID_ANY, VMADDR_PORT_ANY);
700         vsock_addr_init(&vsk->remote_addr, VMADDR_CID_ANY, VMADDR_PORT_ANY);
701
702         put_cred(vsk->owner);
703 }
704
/* Queue @skb on @sk's receive queue via sock_queue_rcv_skb().  If queueing
 * fails the buffer is freed here and the error is returned; returns 0 on
 * success.
 */
static int vsock_queue_rcv_skb(struct sock *sk, struct sk_buff *skb)
{
	int err = sock_queue_rcv_skb(sk, skb);

	if (!err)
		return 0;

	kfree_skb(skb);
	return err;
}
715
716 s64 vsock_stream_has_data(struct vsock_sock *vsk)
717 {
718         return transport->stream_has_data(vsk);
719 }
720 EXPORT_SYMBOL_GPL(vsock_stream_has_data);
721
722 s64 vsock_stream_has_space(struct vsock_sock *vsk)
723 {
724         return transport->stream_has_space(vsk);
725 }
726 EXPORT_SYMBOL_GPL(vsock_stream_has_space);
727
728 static int vsock_release(struct socket *sock)
729 {
730         __vsock_release(sock->sk);
731         sock->sk = NULL;
732         sock->state = SS_FREE;
733
734         return 0;
735 }
736
737 static int
738 vsock_bind(struct socket *sock, struct sockaddr *addr, int addr_len)
739 {
740         int err;
741         struct sock *sk;
742         struct sockaddr_vm *vm_addr;
743
744         sk = sock->sk;
745
746         if (vsock_addr_cast(addr, addr_len, &vm_addr) != 0)
747                 return -EINVAL;
748
749         lock_sock(sk);
750         err = __vsock_bind(sk, vm_addr);
751         release_sock(sk);
752
753         return err;
754 }
755
756 static int vsock_getname(struct socket *sock,
757                          struct sockaddr *addr, int *addr_len, int peer)
758 {
759         int err;
760         struct sock *sk;
761         struct vsock_sock *vsk;
762         struct sockaddr_vm *vm_addr;
763
764         sk = sock->sk;
765         vsk = vsock_sk(sk);
766         err = 0;
767
768         lock_sock(sk);
769
770         if (peer) {
771                 if (sock->state != SS_CONNECTED) {
772                         err = -ENOTCONN;
773                         goto out;
774                 }
775                 vm_addr = &vsk->remote_addr;
776         } else {
777                 vm_addr = &vsk->local_addr;
778         }
779
780         if (!vm_addr) {
781                 err = -EINVAL;
782                 goto out;
783         }
784
785         /* sys_getsockname() and sys_getpeername() pass us a
786          * MAX_SOCK_ADDR-sized buffer and don't set addr_len.  Unfortunately
787          * that macro is defined in socket.c instead of .h, so we hardcode its
788          * value here.
789          */
790         BUILD_BUG_ON(sizeof(*vm_addr) > 128);
791         memcpy(addr, vm_addr, sizeof(*vm_addr));
792         *addr_len = sizeof(*vm_addr);
793
794 out:
795         release_sock(sk);
796         return err;
797 }
798
799 static int vsock_shutdown(struct socket *sock, int mode)
800 {
801         int err;
802         struct sock *sk;
803
804         /* User level uses SHUT_RD (0) and SHUT_WR (1), but the kernel uses
805          * RCV_SHUTDOWN (1) and SEND_SHUTDOWN (2), so we must increment mode
806          * here like the other address families do.  Note also that the
807          * increment makes SHUT_RDWR (2) into RCV_SHUTDOWN | SEND_SHUTDOWN (3),
808          * which is what we want.
809          */
810         mode++;
811
812         if ((mode & ~SHUTDOWN_MASK) || !mode)
813                 return -EINVAL;
814
815         /* If this is a STREAM socket and it is not connected then bail out
816          * immediately.  If it is a DGRAM socket then we must first kick the
817          * socket so that it wakes up from any sleeping calls, for example
818          * recv(), and then afterwards return the error.
819          */
820
821         sk = sock->sk;
822
823         lock_sock(sk);
824         if (sock->state == SS_UNCONNECTED) {
825                 err = -ENOTCONN;
826                 if (sk->sk_type == SOCK_STREAM)
827                         goto out;
828         } else {
829                 sock->state = SS_DISCONNECTING;
830                 err = 0;
831         }
832
833         /* Receive and send shutdowns are treated alike. */
834         mode = mode & (RCV_SHUTDOWN | SEND_SHUTDOWN);
835         if (mode) {
836                 sk->sk_shutdown |= mode;
837                 sk->sk_state_change(sk);
838
839                 if (sk->sk_type == SOCK_STREAM) {
840                         sock_reset_flag(sk, SOCK_DONE);
841                         vsock_send_shutdown(sk, mode);
842                 }
843         }
844
845 out:
846         release_sock(sk);
847         return err;
848 }
849
850 static unsigned int vsock_poll(struct file *file, struct socket *sock,
851                                poll_table *wait)
852 {
853         struct sock *sk;
854         unsigned int mask;
855         struct vsock_sock *vsk;
856
857         sk = sock->sk;
858         vsk = vsock_sk(sk);
859
860         poll_wait(file, sk_sleep(sk), wait);
861         mask = 0;
862
863         if (sk->sk_err)
864                 /* Signify that there has been an error on this socket. */
865                 mask |= POLLERR;
866
867         /* INET sockets treat local write shutdown and peer write shutdown as a
868          * case of POLLHUP set.
869          */
870         if ((sk->sk_shutdown == SHUTDOWN_MASK) ||
871             ((sk->sk_shutdown & SEND_SHUTDOWN) &&
872              (vsk->peer_shutdown & SEND_SHUTDOWN))) {
873                 mask |= POLLHUP;
874         }
875
876         if (sk->sk_shutdown & RCV_SHUTDOWN ||
877             vsk->peer_shutdown & SEND_SHUTDOWN) {
878                 mask |= POLLRDHUP;
879         }
880
881         if (sock->type == SOCK_DGRAM) {
882                 /* For datagram sockets we can read if there is something in
883                  * the queue and write as long as the socket isn't shutdown for
884                  * sending.
885                  */
886                 if (!skb_queue_empty(&sk->sk_receive_queue) ||
887                     (sk->sk_shutdown & RCV_SHUTDOWN)) {
888                         mask |= POLLIN | POLLRDNORM;
889                 }
890
891                 if (!(sk->sk_shutdown & SEND_SHUTDOWN))
892                         mask |= POLLOUT | POLLWRNORM | POLLWRBAND;
893
894         } else if (sock->type == SOCK_STREAM) {
895                 lock_sock(sk);
896
897                 /* Listening sockets that have connections in their accept
898                  * queue can be read.
899                  */
900                 if (sk->sk_state == VSOCK_SS_LISTEN
901                     && !vsock_is_accept_queue_empty(sk))
902                         mask |= POLLIN | POLLRDNORM;
903
904                 /* If there is something in the queue then we can read. */
905                 if (transport->stream_is_active(vsk) &&
906                     !(sk->sk_shutdown & RCV_SHUTDOWN)) {
907                         bool data_ready_now = false;
908                         int ret = transport->notify_poll_in(
909                                         vsk, 1, &data_ready_now);
910                         if (ret < 0) {
911                                 mask |= POLLERR;
912                         } else {
913                                 if (data_ready_now)
914                                         mask |= POLLIN | POLLRDNORM;
915
916                         }
917                 }
918
919                 /* Sockets whose connections have been closed, reset, or
920                  * terminated should also be considered read, and we check the
921                  * shutdown flag for that.
922                  */
923                 if (sk->sk_shutdown & RCV_SHUTDOWN ||
924                     vsk->peer_shutdown & SEND_SHUTDOWN) {
925                         mask |= POLLIN | POLLRDNORM;
926                 }
927
928                 /* Connected sockets that can produce data can be written. */
929                 if (sk->sk_state == SS_CONNECTED) {
930                         if (!(sk->sk_shutdown & SEND_SHUTDOWN)) {
931                                 bool space_avail_now = false;
932                                 int ret = transport->notify_poll_out(
933                                                 vsk, 1, &space_avail_now);
934                                 if (ret < 0) {
935                                         mask |= POLLERR;
936                                 } else {
937                                         if (space_avail_now)
938                                                 /* Remove POLLWRBAND since INET
939                                                  * sockets are not setting it.
940                                                  */
941                                                 mask |= POLLOUT | POLLWRNORM;
942
943                                 }
944                         }
945                 }
946
947                 /* Simulate INET socket poll behaviors, which sets
948                  * POLLOUT|POLLWRNORM when peer is closed and nothing to read,
949                  * but local send is not shutdown.
950                  */
951                 if (sk->sk_state == SS_UNCONNECTED) {
952                         if (!(sk->sk_shutdown & SEND_SHUTDOWN))
953                                 mask |= POLLOUT | POLLWRNORM;
954
955                 }
956
957                 release_sock(sk);
958         }
959
960         return mask;
961 }
962
963 static int vsock_dgram_sendmsg(struct socket *sock, struct msghdr *msg,
964                                size_t len)
965 {
966         int err;
967         struct sock *sk;
968         struct vsock_sock *vsk;
969         struct sockaddr_vm *remote_addr;
970
971         if (msg->msg_flags & MSG_OOB)
972                 return -EOPNOTSUPP;
973
974         /* For now, MSG_DONTWAIT is always assumed... */
975         err = 0;
976         sk = sock->sk;
977         vsk = vsock_sk(sk);
978
979         lock_sock(sk);
980
981         err = vsock_auto_bind(vsk);
982         if (err)
983                 goto out;
984
985
986         /* If the provided message contains an address, use that.  Otherwise
987          * fall back on the socket's remote handle (if it has been connected).
988          */
989         if (msg->msg_name &&
990             vsock_addr_cast(msg->msg_name, msg->msg_namelen,
991                             &remote_addr) == 0) {
992                 /* Ensure this address is of the right type and is a valid
993                  * destination.
994                  */
995
996                 if (remote_addr->svm_cid == VMADDR_CID_ANY)
997                         remote_addr->svm_cid = transport->get_local_cid();
998
999                 if (!vsock_addr_bound(remote_addr)) {
1000                         err = -EINVAL;
1001                         goto out;
1002                 }
1003         } else if (sock->state == SS_CONNECTED) {
1004                 remote_addr = &vsk->remote_addr;
1005
1006                 if (remote_addr->svm_cid == VMADDR_CID_ANY)
1007                         remote_addr->svm_cid = transport->get_local_cid();
1008
1009                 /* XXX Should connect() or this function ensure remote_addr is
1010                  * bound?
1011                  */
1012                 if (!vsock_addr_bound(&vsk->remote_addr)) {
1013                         err = -EINVAL;
1014                         goto out;
1015                 }
1016         } else {
1017                 err = -EINVAL;
1018                 goto out;
1019         }
1020
1021         if (!transport->dgram_allow(remote_addr->svm_cid,
1022                                     remote_addr->svm_port)) {
1023                 err = -EINVAL;
1024                 goto out;
1025         }
1026
1027         err = transport->dgram_enqueue(vsk, remote_addr, msg, len);
1028
1029 out:
1030         release_sock(sk);
1031         return err;
1032 }
1033
1034 static int vsock_dgram_connect(struct socket *sock,
1035                                struct sockaddr *addr, int addr_len, int flags)
1036 {
1037         int err;
1038         struct sock *sk;
1039         struct vsock_sock *vsk;
1040         struct sockaddr_vm *remote_addr;
1041
1042         sk = sock->sk;
1043         vsk = vsock_sk(sk);
1044
1045         err = vsock_addr_cast(addr, addr_len, &remote_addr);
1046         if (err == -EAFNOSUPPORT && remote_addr->svm_family == AF_UNSPEC) {
1047                 lock_sock(sk);
1048                 vsock_addr_init(&vsk->remote_addr, VMADDR_CID_ANY,
1049                                 VMADDR_PORT_ANY);
1050                 sock->state = SS_UNCONNECTED;
1051                 release_sock(sk);
1052                 return 0;
1053         } else if (err != 0)
1054                 return -EINVAL;
1055
1056         lock_sock(sk);
1057
1058         err = vsock_auto_bind(vsk);
1059         if (err)
1060                 goto out;
1061
1062         if (!transport->dgram_allow(remote_addr->svm_cid,
1063                                     remote_addr->svm_port)) {
1064                 err = -EINVAL;
1065                 goto out;
1066         }
1067
1068         memcpy(&vsk->remote_addr, remote_addr, sizeof(vsk->remote_addr));
1069         sock->state = SS_CONNECTED;
1070
1071 out:
1072         release_sock(sk);
1073         return err;
1074 }
1075
1076 static int vsock_dgram_recvmsg(struct socket *sock, struct msghdr *msg,
1077                                size_t len, int flags)
1078 {
1079         return transport->dgram_dequeue(vsock_sk(sock->sk), msg, len, flags);
1080 }
1081
1082 static const struct proto_ops vsock_dgram_ops = {
1083         .family = PF_VSOCK,
1084         .owner = THIS_MODULE,
1085         .release = vsock_release,
1086         .bind = vsock_bind,
1087         .connect = vsock_dgram_connect,
1088         .socketpair = sock_no_socketpair,
1089         .accept = sock_no_accept,
1090         .getname = vsock_getname,
1091         .poll = vsock_poll,
1092         .ioctl = sock_no_ioctl,
1093         .listen = sock_no_listen,
1094         .shutdown = vsock_shutdown,
1095         .setsockopt = sock_no_setsockopt,
1096         .getsockopt = sock_no_getsockopt,
1097         .sendmsg = vsock_dgram_sendmsg,
1098         .recvmsg = vsock_dgram_recvmsg,
1099         .mmap = sock_no_mmap,
1100         .sendpage = sock_no_sendpage,
1101 };
1102
1103 static void vsock_connect_timeout(struct work_struct *work)
1104 {
1105         struct sock *sk;
1106         struct vsock_sock *vsk;
1107
1108         vsk = container_of(work, struct vsock_sock, connect_work.work);
1109         sk = sk_vsock(vsk);
1110
1111         lock_sock(sk);
1112         if (sk->sk_state == SS_CONNECTING &&
1113             (sk->sk_shutdown != SHUTDOWN_MASK)) {
1114                 sk->sk_state = SS_UNCONNECTED;
1115                 sk->sk_err = ETIMEDOUT;
1116                 sk->sk_error_report(sk);
1117         }
1118         release_sock(sk);
1119
1120         sock_put(sk);
1121 }
1122
1123 static int vsock_stream_connect(struct socket *sock, struct sockaddr *addr,
1124                                 int addr_len, int flags)
1125 {
1126         int err;
1127         struct sock *sk;
1128         struct vsock_sock *vsk;
1129         struct sockaddr_vm *remote_addr;
1130         long timeout;
1131         DEFINE_WAIT(wait);
1132
1133         err = 0;
1134         sk = sock->sk;
1135         vsk = vsock_sk(sk);
1136
1137         lock_sock(sk);
1138
1139         /* XXX AF_UNSPEC should make us disconnect like AF_INET. */
1140         switch (sock->state) {
1141         case SS_CONNECTED:
1142                 err = -EISCONN;
1143                 goto out;
1144         case SS_DISCONNECTING:
1145                 err = -EINVAL;
1146                 goto out;
1147         case SS_CONNECTING:
1148                 /* This continues on so we can move sock into the SS_CONNECTED
1149                  * state once the connection has completed (at which point err
1150                  * will be set to zero also).  Otherwise, we will either wait
1151                  * for the connection or return -EALREADY should this be a
1152                  * non-blocking call.
1153                  */
1154                 err = -EALREADY;
1155                 break;
1156         default:
1157                 if ((sk->sk_state == VSOCK_SS_LISTEN) ||
1158                     vsock_addr_cast(addr, addr_len, &remote_addr) != 0) {
1159                         err = -EINVAL;
1160                         goto out;
1161                 }
1162
1163                 /* The hypervisor and well-known contexts do not have socket
1164                  * endpoints.
1165                  */
1166                 if (!transport->stream_allow(remote_addr->svm_cid,
1167                                              remote_addr->svm_port)) {
1168                         err = -ENETUNREACH;
1169                         goto out;
1170                 }
1171
1172                 /* Set the remote address that we are connecting to. */
1173                 memcpy(&vsk->remote_addr, remote_addr,
1174                        sizeof(vsk->remote_addr));
1175
1176                 err = vsock_auto_bind(vsk);
1177                 if (err)
1178                         goto out;
1179
1180                 sk->sk_state = SS_CONNECTING;
1181
1182                 err = transport->connect(vsk);
1183                 if (err < 0)
1184                         goto out;
1185
1186                 /* Mark sock as connecting and set the error code to in
1187                  * progress in case this is a non-blocking connect.
1188                  */
1189                 sock->state = SS_CONNECTING;
1190                 err = -EINPROGRESS;
1191         }
1192
1193         /* The receive path will handle all communication until we are able to
1194          * enter the connected state.  Here we wait for the connection to be
1195          * completed or a notification of an error.
1196          */
1197         timeout = vsk->connect_timeout;
1198         prepare_to_wait(sk_sleep(sk), &wait, TASK_INTERRUPTIBLE);
1199
1200         while (sk->sk_state != SS_CONNECTED && sk->sk_err == 0) {
1201                 if (flags & O_NONBLOCK) {
1202                         /* If we're not going to block, we schedule a timeout
1203                          * function to generate a timeout on the connection
1204                          * attempt, in case the peer doesn't respond in a
1205                          * timely manner. We hold on to the socket until the
1206                          * timeout fires.
1207                          */
1208                         sock_hold(sk);
1209                         schedule_delayed_work(&vsk->connect_work, timeout);
1210
1211                         /* Skip ahead to preserve error code set above. */
1212                         goto out_wait;
1213                 }
1214
1215                 release_sock(sk);
1216                 timeout = schedule_timeout(timeout);
1217                 lock_sock(sk);
1218
1219                 if (signal_pending(current)) {
1220                         err = sock_intr_errno(timeout);
1221                         sk->sk_state = SS_UNCONNECTED;
1222                         sock->state = SS_UNCONNECTED;
1223                         goto out_wait;
1224                 } else if (timeout == 0) {
1225                         err = -ETIMEDOUT;
1226                         sk->sk_state = SS_UNCONNECTED;
1227                         sock->state = SS_UNCONNECTED;
1228                         goto out_wait;
1229                 }
1230
1231                 prepare_to_wait(sk_sleep(sk), &wait, TASK_INTERRUPTIBLE);
1232         }
1233
1234         if (sk->sk_err) {
1235                 err = -sk->sk_err;
1236                 sk->sk_state = SS_UNCONNECTED;
1237                 sock->state = SS_UNCONNECTED;
1238         } else {
1239                 err = 0;
1240         }
1241
1242 out_wait:
1243         finish_wait(sk_sleep(sk), &wait);
1244 out:
1245         release_sock(sk);
1246         return err;
1247 }
1248
1249 static int vsock_accept(struct socket *sock, struct socket *newsock, int flags)
1250 {
1251         struct sock *listener;
1252         int err;
1253         struct sock *connected;
1254         struct vsock_sock *vconnected;
1255         long timeout;
1256         DEFINE_WAIT(wait);
1257
1258         err = 0;
1259         listener = sock->sk;
1260
1261         lock_sock(listener);
1262
1263         if (sock->type != SOCK_STREAM) {
1264                 err = -EOPNOTSUPP;
1265                 goto out;
1266         }
1267
1268         if (listener->sk_state != VSOCK_SS_LISTEN) {
1269                 err = -EINVAL;
1270                 goto out;
1271         }
1272
1273         /* Wait for children sockets to appear; these are the new sockets
1274          * created upon connection establishment.
1275          */
1276         timeout = sock_rcvtimeo(listener, flags & O_NONBLOCK);
1277         prepare_to_wait(sk_sleep(listener), &wait, TASK_INTERRUPTIBLE);
1278
1279         while ((connected = vsock_dequeue_accept(listener)) == NULL &&
1280                listener->sk_err == 0) {
1281                 release_sock(listener);
1282                 timeout = schedule_timeout(timeout);
1283                 finish_wait(sk_sleep(listener), &wait);
1284                 lock_sock(listener);
1285
1286                 if (signal_pending(current)) {
1287                         err = sock_intr_errno(timeout);
1288                         goto out;
1289                 } else if (timeout == 0) {
1290                         err = -EAGAIN;
1291                         goto out;
1292                 }
1293
1294                 prepare_to_wait(sk_sleep(listener), &wait, TASK_INTERRUPTIBLE);
1295         }
1296         finish_wait(sk_sleep(listener), &wait);
1297
1298         if (listener->sk_err)
1299                 err = -listener->sk_err;
1300
1301         if (connected) {
1302                 listener->sk_ack_backlog--;
1303
1304                 lock_sock(connected);
1305                 vconnected = vsock_sk(connected);
1306
1307                 /* If the listener socket has received an error, then we should
1308                  * reject this socket and return.  Note that we simply mark the
1309                  * socket rejected, drop our reference, and let the cleanup
1310                  * function handle the cleanup; the fact that we found it in
1311                  * the listener's accept queue guarantees that the cleanup
1312                  * function hasn't run yet.
1313                  */
1314                 if (err) {
1315                         vconnected->rejected = true;
1316                 } else {
1317                         newsock->state = SS_CONNECTED;
1318                         sock_graft(connected, newsock);
1319                 }
1320
1321                 release_sock(connected);
1322                 sock_put(connected);
1323         }
1324
1325 out:
1326         release_sock(listener);
1327         return err;
1328 }
1329
1330 static int vsock_listen(struct socket *sock, int backlog)
1331 {
1332         int err;
1333         struct sock *sk;
1334         struct vsock_sock *vsk;
1335
1336         sk = sock->sk;
1337
1338         lock_sock(sk);
1339
1340         if (sock->type != SOCK_STREAM) {
1341                 err = -EOPNOTSUPP;
1342                 goto out;
1343         }
1344
1345         if (sock->state != SS_UNCONNECTED) {
1346                 err = -EINVAL;
1347                 goto out;
1348         }
1349
1350         vsk = vsock_sk(sk);
1351
1352         if (!vsock_addr_bound(&vsk->local_addr)) {
1353                 err = -EINVAL;
1354                 goto out;
1355         }
1356
1357         sk->sk_max_ack_backlog = backlog;
1358         sk->sk_state = VSOCK_SS_LISTEN;
1359
1360         err = 0;
1361
1362 out:
1363         release_sock(sk);
1364         return err;
1365 }
1366
1367 static int vsock_stream_setsockopt(struct socket *sock,
1368                                    int level,
1369                                    int optname,
1370                                    char __user *optval,
1371                                    unsigned int optlen)
1372 {
1373         int err;
1374         struct sock *sk;
1375         struct vsock_sock *vsk;
1376         u64 val;
1377
1378         if (level != AF_VSOCK)
1379                 return -ENOPROTOOPT;
1380
1381 #define COPY_IN(_v)                                       \
1382         do {                                              \
1383                 if (optlen < sizeof(_v)) {                \
1384                         err = -EINVAL;                    \
1385                         goto exit;                        \
1386                 }                                         \
1387                 if (copy_from_user(&_v, optval, sizeof(_v)) != 0) {     \
1388                         err = -EFAULT;                                  \
1389                         goto exit;                                      \
1390                 }                                                       \
1391         } while (0)
1392
1393         err = 0;
1394         sk = sock->sk;
1395         vsk = vsock_sk(sk);
1396
1397         lock_sock(sk);
1398
1399         switch (optname) {
1400         case SO_VM_SOCKETS_BUFFER_SIZE:
1401                 COPY_IN(val);
1402                 transport->set_buffer_size(vsk, val);
1403                 break;
1404
1405         case SO_VM_SOCKETS_BUFFER_MAX_SIZE:
1406                 COPY_IN(val);
1407                 transport->set_max_buffer_size(vsk, val);
1408                 break;
1409
1410         case SO_VM_SOCKETS_BUFFER_MIN_SIZE:
1411                 COPY_IN(val);
1412                 transport->set_min_buffer_size(vsk, val);
1413                 break;
1414
1415         case SO_VM_SOCKETS_CONNECT_TIMEOUT: {
1416                 struct timeval tv;
1417                 COPY_IN(tv);
1418                 if (tv.tv_sec >= 0 && tv.tv_usec < USEC_PER_SEC &&
1419                     tv.tv_sec < (MAX_SCHEDULE_TIMEOUT / HZ - 1)) {
1420                         vsk->connect_timeout = tv.tv_sec * HZ +
1421                             DIV_ROUND_UP(tv.tv_usec, (1000000 / HZ));
1422                         if (vsk->connect_timeout == 0)
1423                                 vsk->connect_timeout =
1424                                     VSOCK_DEFAULT_CONNECT_TIMEOUT;
1425
1426                 } else {
1427                         err = -ERANGE;
1428                 }
1429                 break;
1430         }
1431
1432         default:
1433                 err = -ENOPROTOOPT;
1434                 break;
1435         }
1436
1437 #undef COPY_IN
1438
1439 exit:
1440         release_sock(sk);
1441         return err;
1442 }
1443
1444 static int vsock_stream_getsockopt(struct socket *sock,
1445                                    int level, int optname,
1446                                    char __user *optval,
1447                                    int __user *optlen)
1448 {
1449         int err;
1450         int len;
1451         struct sock *sk;
1452         struct vsock_sock *vsk;
1453         u64 val;
1454
1455         if (level != AF_VSOCK)
1456                 return -ENOPROTOOPT;
1457
1458         err = get_user(len, optlen);
1459         if (err != 0)
1460                 return err;
1461
1462 #define COPY_OUT(_v)                            \
1463         do {                                    \
1464                 if (len < sizeof(_v))           \
1465                         return -EINVAL;         \
1466                                                 \
1467                 len = sizeof(_v);               \
1468                 if (copy_to_user(optval, &_v, len) != 0)        \
1469                         return -EFAULT;                         \
1470                                                                 \
1471         } while (0)
1472
1473         err = 0;
1474         sk = sock->sk;
1475         vsk = vsock_sk(sk);
1476
1477         switch (optname) {
1478         case SO_VM_SOCKETS_BUFFER_SIZE:
1479                 val = transport->get_buffer_size(vsk);
1480                 COPY_OUT(val);
1481                 break;
1482
1483         case SO_VM_SOCKETS_BUFFER_MAX_SIZE:
1484                 val = transport->get_max_buffer_size(vsk);
1485                 COPY_OUT(val);
1486                 break;
1487
1488         case SO_VM_SOCKETS_BUFFER_MIN_SIZE:
1489                 val = transport->get_min_buffer_size(vsk);
1490                 COPY_OUT(val);
1491                 break;
1492
1493         case SO_VM_SOCKETS_CONNECT_TIMEOUT: {
1494                 struct timeval tv;
1495                 tv.tv_sec = vsk->connect_timeout / HZ;
1496                 tv.tv_usec =
1497                     (vsk->connect_timeout -
1498                      tv.tv_sec * HZ) * (1000000 / HZ);
1499                 COPY_OUT(tv);
1500                 break;
1501         }
1502         default:
1503                 return -ENOPROTOOPT;
1504         }
1505
1506         err = put_user(len, optlen);
1507         if (err != 0)
1508                 return -EFAULT;
1509
1510 #undef COPY_OUT
1511
1512         return 0;
1513 }
1514
1515 static int vsock_stream_sendmsg(struct socket *sock, struct msghdr *msg,
1516                                 size_t len)
1517 {
1518         struct sock *sk;
1519         struct vsock_sock *vsk;
1520         ssize_t total_written;
1521         long timeout;
1522         int err;
1523         struct vsock_transport_send_notify_data send_data;
1524         DEFINE_WAIT_FUNC(wait, woken_wake_function);
1525
1526         sk = sock->sk;
1527         vsk = vsock_sk(sk);
1528         total_written = 0;
1529         err = 0;
1530
1531         if (msg->msg_flags & MSG_OOB)
1532                 return -EOPNOTSUPP;
1533
1534         lock_sock(sk);
1535
1536         /* Callers should not provide a destination with stream sockets. */
1537         if (msg->msg_namelen) {
1538                 err = sk->sk_state == SS_CONNECTED ? -EISCONN : -EOPNOTSUPP;
1539                 goto out;
1540         }
1541
1542         /* Send data only if both sides are not shutdown in the direction. */
1543         if (sk->sk_shutdown & SEND_SHUTDOWN ||
1544             vsk->peer_shutdown & RCV_SHUTDOWN) {
1545                 err = -EPIPE;
1546                 goto out;
1547         }
1548
1549         if (sk->sk_state != SS_CONNECTED ||
1550             !vsock_addr_bound(&vsk->local_addr)) {
1551                 err = -ENOTCONN;
1552                 goto out;
1553         }
1554
1555         if (!vsock_addr_bound(&vsk->remote_addr)) {
1556                 err = -EDESTADDRREQ;
1557                 goto out;
1558         }
1559
1560         /* Wait for room in the produce queue to enqueue our user's data. */
1561         timeout = sock_sndtimeo(sk, msg->msg_flags & MSG_DONTWAIT);
1562
1563         err = transport->notify_send_init(vsk, &send_data);
1564         if (err < 0)
1565                 goto out;
1566
1567         while (total_written < len) {
1568                 ssize_t written;
1569
1570                 add_wait_queue(sk_sleep(sk), &wait);
1571                 while (vsock_stream_has_space(vsk) == 0 &&
1572                        sk->sk_err == 0 &&
1573                        !(sk->sk_shutdown & SEND_SHUTDOWN) &&
1574                        !(vsk->peer_shutdown & RCV_SHUTDOWN)) {
1575
1576                         /* Don't wait for non-blocking sockets. */
1577                         if (timeout == 0) {
1578                                 err = -EAGAIN;
1579                                 remove_wait_queue(sk_sleep(sk), &wait);
1580                                 goto out_err;
1581                         }
1582
1583                         err = transport->notify_send_pre_block(vsk, &send_data);
1584                         if (err < 0) {
1585                                 remove_wait_queue(sk_sleep(sk), &wait);
1586                                 goto out_err;
1587                         }
1588
1589                         release_sock(sk);
1590                         timeout = wait_woken(&wait, TASK_INTERRUPTIBLE, timeout);
1591                         lock_sock(sk);
1592                         if (signal_pending(current)) {
1593                                 err = sock_intr_errno(timeout);
1594                                 remove_wait_queue(sk_sleep(sk), &wait);
1595                                 goto out_err;
1596                         } else if (timeout == 0) {
1597                                 err = -EAGAIN;
1598                                 remove_wait_queue(sk_sleep(sk), &wait);
1599                                 goto out_err;
1600                         }
1601                 }
1602                 remove_wait_queue(sk_sleep(sk), &wait);
1603
1604                 /* These checks occur both as part of and after the loop
1605                  * conditional since we need to check before and after
1606                  * sleeping.
1607                  */
1608                 if (sk->sk_err) {
1609                         err = -sk->sk_err;
1610                         goto out_err;
1611                 } else if ((sk->sk_shutdown & SEND_SHUTDOWN) ||
1612                            (vsk->peer_shutdown & RCV_SHUTDOWN)) {
1613                         err = -EPIPE;
1614                         goto out_err;
1615                 }
1616
1617                 err = transport->notify_send_pre_enqueue(vsk, &send_data);
1618                 if (err < 0)
1619                         goto out_err;
1620
1621                 /* Note that enqueue will only write as many bytes as are free
1622                  * in the produce queue, so we don't need to ensure len is
1623                  * smaller than the queue size.  It is the caller's
1624                  * responsibility to check how many bytes we were able to send.
1625                  */
1626
1627                 written = transport->stream_enqueue(
1628                                 vsk, msg,
1629                                 len - total_written);
1630                 if (written < 0) {
1631                         err = -ENOMEM;
1632                         goto out_err;
1633                 }
1634
1635                 total_written += written;
1636
1637                 err = transport->notify_send_post_enqueue(
1638                                 vsk, written, &send_data);
1639                 if (err < 0)
1640                         goto out_err;
1641
1642         }
1643
1644 out_err:
1645         if (total_written > 0)
1646                 err = total_written;
1647 out:
1648         release_sock(sk);
1649         return err;
1650 }
1651
1652
1653 static int
1654 vsock_stream_recvmsg(struct socket *sock, struct msghdr *msg, size_t len,
1655                      int flags)
1656 {
1657         struct sock *sk;
1658         struct vsock_sock *vsk;
1659         int err;
1660         size_t target;
1661         ssize_t copied;
1662         long timeout;
1663         struct vsock_transport_recv_notify_data recv_data;
1664
1665         DEFINE_WAIT(wait);
1666
1667         sk = sock->sk;
1668         vsk = vsock_sk(sk);
1669         err = 0;
1670
1671         lock_sock(sk);
1672
1673         if (sk->sk_state != SS_CONNECTED) {
1674                 /* Recvmsg is supposed to return 0 if a peer performs an
1675                  * orderly shutdown. Differentiate between that case and when a
1676                  * peer has not connected or a local shutdown occured with the
1677                  * SOCK_DONE flag.
1678                  */
1679                 if (sock_flag(sk, SOCK_DONE))
1680                         err = 0;
1681                 else
1682                         err = -ENOTCONN;
1683
1684                 goto out;
1685         }
1686
1687         if (flags & MSG_OOB) {
1688                 err = -EOPNOTSUPP;
1689                 goto out;
1690         }
1691
1692         /* We don't check peer_shutdown flag here since peer may actually shut
1693          * down, but there can be data in the queue that a local socket can
1694          * receive.
1695          */
1696         if (sk->sk_shutdown & RCV_SHUTDOWN) {
1697                 err = 0;
1698                 goto out;
1699         }
1700
1701         /* It is valid on Linux to pass in a zero-length receive buffer.  This
1702          * is not an error.  We may as well bail out now.
1703          */
1704         if (!len) {
1705                 err = 0;
1706                 goto out;
1707         }
1708
1709         /* We must not copy less than target bytes into the user's buffer
1710          * before returning successfully, so we wait for the consume queue to
1711          * have that much data to consume before dequeueing.  Note that this
1712          * makes it impossible to handle cases where target is greater than the
1713          * queue size.
1714          */
1715         target = sock_rcvlowat(sk, flags & MSG_WAITALL, len);
1716         if (target >= transport->stream_rcvhiwat(vsk)) {
1717                 err = -ENOMEM;
1718                 goto out;
1719         }
1720         timeout = sock_rcvtimeo(sk, flags & MSG_DONTWAIT);
1721         copied = 0;
1722
1723         err = transport->notify_recv_init(vsk, target, &recv_data);
1724         if (err < 0)
1725                 goto out;
1726
1727
1728         while (1) {
1729                 s64 ready;
1730
1731                 prepare_to_wait(sk_sleep(sk), &wait, TASK_INTERRUPTIBLE);
1732                 ready = vsock_stream_has_data(vsk);
1733
1734                 if (ready == 0) {
1735                         if (sk->sk_err != 0 ||
1736                             (sk->sk_shutdown & RCV_SHUTDOWN) ||
1737                             (vsk->peer_shutdown & SEND_SHUTDOWN)) {
1738                                 finish_wait(sk_sleep(sk), &wait);
1739                                 break;
1740                         }
1741                         /* Don't wait for non-blocking sockets. */
1742                         if (timeout == 0) {
1743                                 err = -EAGAIN;
1744                                 finish_wait(sk_sleep(sk), &wait);
1745                                 break;
1746                         }
1747
1748                         err = transport->notify_recv_pre_block(
1749                                         vsk, target, &recv_data);
1750                         if (err < 0) {
1751                                 finish_wait(sk_sleep(sk), &wait);
1752                                 break;
1753                         }
1754                         release_sock(sk);
1755                         timeout = schedule_timeout(timeout);
1756                         lock_sock(sk);
1757
1758                         if (signal_pending(current)) {
1759                                 err = sock_intr_errno(timeout);
1760                                 finish_wait(sk_sleep(sk), &wait);
1761                                 break;
1762                         } else if (timeout == 0) {
1763                                 err = -EAGAIN;
1764                                 finish_wait(sk_sleep(sk), &wait);
1765                                 break;
1766                         }
1767                 } else {
1768                         ssize_t read;
1769
1770                         finish_wait(sk_sleep(sk), &wait);
1771
1772                         if (ready < 0) {
1773                                 /* Invalid queue pair content. XXX This should
1774                                 * be changed to a connection reset in a later
1775                                 * change.
1776                                 */
1777
1778                                 err = -ENOMEM;
1779                                 goto out;
1780                         }
1781
1782                         err = transport->notify_recv_pre_dequeue(
1783                                         vsk, target, &recv_data);
1784                         if (err < 0)
1785                                 break;
1786
1787                         read = transport->stream_dequeue(
1788                                         vsk, msg,
1789                                         len - copied, flags);
1790                         if (read < 0) {
1791                                 err = -ENOMEM;
1792                                 break;
1793                         }
1794
1795                         copied += read;
1796
1797                         err = transport->notify_recv_post_dequeue(
1798                                         vsk, target, read,
1799                                         !(flags & MSG_PEEK), &recv_data);
1800                         if (err < 0)
1801                                 goto out;
1802
1803                         if (read >= target || flags & MSG_PEEK)
1804                                 break;
1805
1806                         target -= read;
1807                 }
1808         }
1809
1810         if (sk->sk_err)
1811                 err = -sk->sk_err;
1812         else if (sk->sk_shutdown & RCV_SHUTDOWN)
1813                 err = 0;
1814
1815         if (copied > 0)
1816                 err = copied;
1817
1818 out:
1819         release_sock(sk);
1820         return err;
1821 }
1822
1823 static const struct proto_ops vsock_stream_ops = {
1824         .family = PF_VSOCK,
1825         .owner = THIS_MODULE,
1826         .release = vsock_release,
1827         .bind = vsock_bind,
1828         .connect = vsock_stream_connect,
1829         .socketpair = sock_no_socketpair,
1830         .accept = vsock_accept,
1831         .getname = vsock_getname,
1832         .poll = vsock_poll,
1833         .ioctl = sock_no_ioctl,
1834         .listen = vsock_listen,
1835         .shutdown = vsock_shutdown,
1836         .setsockopt = vsock_stream_setsockopt,
1837         .getsockopt = vsock_stream_getsockopt,
1838         .sendmsg = vsock_stream_sendmsg,
1839         .recvmsg = vsock_stream_recvmsg,
1840         .mmap = sock_no_mmap,
1841         .sendpage = sock_no_sendpage,
1842 };
1843
1844 static int vsock_create(struct net *net, struct socket *sock,
1845                         int protocol, int kern)
1846 {
1847         if (!sock)
1848                 return -EINVAL;
1849
1850         if (protocol && protocol != PF_VSOCK)
1851                 return -EPROTONOSUPPORT;
1852
1853         switch (sock->type) {
1854         case SOCK_DGRAM:
1855                 sock->ops = &vsock_dgram_ops;
1856                 break;
1857         case SOCK_STREAM:
1858                 sock->ops = &vsock_stream_ops;
1859                 break;
1860         default:
1861                 return -ESOCKTNOSUPPORT;
1862         }
1863
1864         sock->state = SS_UNCONNECTED;
1865
1866         return __vsock_create(net, sock, NULL, GFP_KERNEL, 0, kern) ? 0 : -ENOMEM;
1867 }
1868
1869 static const struct net_proto_family vsock_family_ops = {
1870         .family = AF_VSOCK,
1871         .create = vsock_create,
1872         .owner = THIS_MODULE,
1873 };
1874
1875 static long vsock_dev_do_ioctl(struct file *filp,
1876                                unsigned int cmd, void __user *ptr)
1877 {
1878         u32 __user *p = ptr;
1879         int retval = 0;
1880
1881         switch (cmd) {
1882         case IOCTL_VM_SOCKETS_GET_LOCAL_CID:
1883                 if (put_user(transport->get_local_cid(), p) != 0)
1884                         retval = -EFAULT;
1885                 break;
1886
1887         default:
1888                 pr_err("Unknown ioctl %d\n", cmd);
1889                 retval = -EINVAL;
1890         }
1891
1892         return retval;
1893 }
1894
1895 static long vsock_dev_ioctl(struct file *filp,
1896                             unsigned int cmd, unsigned long arg)
1897 {
1898         return vsock_dev_do_ioctl(filp, cmd, (void __user *)arg);
1899 }
1900
#ifdef CONFIG_COMPAT
/* Compat ioctl entry point for the vsock misc device: converts @arg with
 * compat_ptr() and forwards to vsock_dev_do_ioctl().
 */
static long vsock_dev_compat_ioctl(struct file *filp,
				   unsigned int cmd, unsigned long arg)
{
	void __user *uarg = compat_ptr(arg);

	return vsock_dev_do_ioctl(filp, cmd, uarg);
}
#endif
1908
1909 static const struct file_operations vsock_device_ops = {
1910         .owner          = THIS_MODULE,
1911         .unlocked_ioctl = vsock_dev_ioctl,
1912 #ifdef CONFIG_COMPAT
1913         .compat_ioctl   = vsock_dev_compat_ioctl,
1914 #endif
1915         .open           = nonseekable_open,
1916 };
1917
1918 static struct miscdevice vsock_device = {
1919         .name           = "vsock",
1920         .fops           = &vsock_device_ops,
1921 };
1922
1923 int __vsock_core_init(const struct vsock_transport *t, struct module *owner)
1924 {
1925         int err = mutex_lock_interruptible(&vsock_register_mutex);
1926
1927         if (err)
1928                 return err;
1929
1930         if (transport) {
1931                 err = -EBUSY;
1932                 goto err_busy;
1933         }
1934
1935         /* Transport must be the owner of the protocol so that it can't
1936          * unload while there are open sockets.
1937          */
1938         vsock_proto.owner = owner;
1939         transport = t;
1940
1941         vsock_init_tables();
1942
1943         vsock_device.minor = MISC_DYNAMIC_MINOR;
1944         err = misc_register(&vsock_device);
1945         if (err) {
1946                 pr_err("Failed to register misc device\n");
1947                 goto err_reset_transport;
1948         }
1949
1950         err = proto_register(&vsock_proto, 1);  /* we want our slab */
1951         if (err) {
1952                 pr_err("Cannot register vsock protocol\n");
1953                 goto err_deregister_misc;
1954         }
1955
1956         err = sock_register(&vsock_family_ops);
1957         if (err) {
1958                 pr_err("could not register af_vsock (%d) address family: %d\n",
1959                        AF_VSOCK, err);
1960                 goto err_unregister_proto;
1961         }
1962
1963         mutex_unlock(&vsock_register_mutex);
1964         return 0;
1965
1966 err_unregister_proto:
1967         proto_unregister(&vsock_proto);
1968 err_deregister_misc:
1969         misc_deregister(&vsock_device);
1970 err_reset_transport:
1971         transport = NULL;
1972 err_busy:
1973         mutex_unlock(&vsock_register_mutex);
1974         return err;
1975 }
1976 EXPORT_SYMBOL_GPL(__vsock_core_init);
1977
1978 void vsock_core_exit(void)
1979 {
1980         mutex_lock(&vsock_register_mutex);
1981
1982         misc_deregister(&vsock_device);
1983         sock_unregister(AF_VSOCK);
1984         proto_unregister(&vsock_proto);
1985
1986         /* We do not want the assignment below re-ordered. */
1987         mb();
1988         transport = NULL;
1989
1990         mutex_unlock(&vsock_register_mutex);
1991 }
1992 EXPORT_SYMBOL_GPL(vsock_core_exit);
1993
1994 MODULE_AUTHOR("VMware, Inc.");
1995 MODULE_DESCRIPTION("VMware Virtual Socket Family");
1996 MODULE_VERSION("1.0.1.0-k");
1997 MODULE_LICENSE("GPL v2");