GNU Linux-libre 4.19.286-gnu1
[releases.git] / tools / testing / vsock / vsock_diag_test.c
1 /*
2  * vsock_diag_test - vsock_diag.ko test suite
3  *
4  * Copyright (C) 2017 Red Hat, Inc.
5  *
6  * Author: Stefan Hajnoczi <stefanha@redhat.com>
7  *
8  * This program is free software; you can redistribute it and/or
9  * modify it under the terms of the GNU General Public License
10  * as published by the Free Software Foundation; version 2
11  * of the License.
12  */
13
14 #include <getopt.h>
15 #include <stdio.h>
16 #include <stdbool.h>
17 #include <stdlib.h>
18 #include <string.h>
19 #include <errno.h>
20 #include <unistd.h>
21 #include <signal.h>
22 #include <sys/socket.h>
23 #include <sys/stat.h>
24 #include <sys/types.h>
25 #include <linux/list.h>
26 #include <linux/net.h>
27 #include <linux/netlink.h>
28 #include <linux/sock_diag.h>
29 #include <netinet/tcp.h>
30
31 #include "../../../include/uapi/linux/vm_sockets.h"
32 #include "../../../include/uapi/linux/vm_sockets_diag.h"
33
34 #include "timeout.h"
35 #include "control.h"
36
/* Role of this process in the test; stays UNSET until --mode is parsed. */
enum test_mode {
	TEST_MODE_UNSET = 0,
	TEST_MODE_CLIENT = 1,
	TEST_MODE_SERVER = 2,
};
42
43 /* Per-socket status */
44 struct vsock_stat {
45         struct list_head list;
46         struct vsock_diag_msg msg;
47 };
48
/* Return a printable name for a socket type (SOCK_DGRAM or SOCK_STREAM). */
static const char *sock_type_str(int type)
{
	if (type == SOCK_DGRAM)
		return "DGRAM";
	if (type == SOCK_STREAM)
		return "STREAM";

	return "INVALID TYPE";
}
60
/*
 * Return a printable name for a socket state.  The states are expressed
 * with the TCP_* constants; any value without an entry in the table below
 * yields "INVALID STATE".
 */
static const char *sock_state_str(int state)
{
	static const char * const names[] = {
		[TCP_CLOSE]		= "UNCONNECTED",
		[TCP_SYN_SENT]		= "CONNECTING",
		[TCP_ESTABLISHED]	= "CONNECTED",
		[TCP_CLOSING]		= "DISCONNECTING",
		[TCP_LISTEN]		= "LISTEN",
	};
	const int count = sizeof(names) / sizeof(names[0]);

	/* Unlisted slots inside the table are NULL and count as invalid. */
	if (state >= 0 && state < count && names[state])
		return names[state];

	return "INVALID STATE";
}
78
/*
 * Return a printable name for a shutdown value.  Values 1, 2 and 3 have
 * names; every other value is printed as "0".
 */
static const char *sock_shutdown_str(int shutdown)
{
	static const char * const names[] = {
		"0",
		"RCV_SHUTDOWN",
		"SEND_SHUTDOWN",
		"RCV_SHUTDOWN | SEND_SHUTDOWN",
	};

	if (shutdown >= 1 && shutdown <= 3)
		return names[shutdown];

	return names[0];
}
92
93 static void print_vsock_addr(FILE *fp, unsigned int cid, unsigned int port)
94 {
95         if (cid == VMADDR_CID_ANY)
96                 fprintf(fp, "*:");
97         else
98                 fprintf(fp, "%u:", cid);
99
100         if (port == VMADDR_PORT_ANY)
101                 fprintf(fp, "*");
102         else
103                 fprintf(fp, "%u", port);
104 }
105
106 static void print_vsock_stat(FILE *fp, struct vsock_stat *st)
107 {
108         print_vsock_addr(fp, st->msg.vdiag_src_cid, st->msg.vdiag_src_port);
109         fprintf(fp, " ");
110         print_vsock_addr(fp, st->msg.vdiag_dst_cid, st->msg.vdiag_dst_port);
111         fprintf(fp, " %s %s %s %u\n",
112                 sock_type_str(st->msg.vdiag_type),
113                 sock_state_str(st->msg.vdiag_state),
114                 sock_shutdown_str(st->msg.vdiag_shutdown),
115                 st->msg.vdiag_ino);
116 }
117
118 static void print_vsock_stats(FILE *fp, struct list_head *head)
119 {
120         struct vsock_stat *st;
121
122         list_for_each_entry(st, head, list)
123                 print_vsock_stat(fp, st);
124 }
125
126 static struct vsock_stat *find_vsock_stat(struct list_head *head, int fd)
127 {
128         struct vsock_stat *st;
129         struct stat stat;
130
131         if (fstat(fd, &stat) < 0) {
132                 perror("fstat");
133                 exit(EXIT_FAILURE);
134         }
135
136         list_for_each_entry(st, head, list)
137                 if (st->msg.vdiag_ino == stat.st_ino)
138                         return st;
139
140         fprintf(stderr, "cannot find fd %d\n", fd);
141         exit(EXIT_FAILURE);
142 }
143
/*
 * Assert that the socket list is empty.  Otherwise print the unexpected
 * sockets to stderr and exit with EXIT_FAILURE, like the other checks.
 */
static void check_no_sockets(struct list_head *head)
{
	if (list_empty(head))
		return;

	fprintf(stderr, "expected no sockets\n");
	print_vsock_stats(stderr, head);
	exit(EXIT_FAILURE);
}
152
153 static void check_num_sockets(struct list_head *head, int expected)
154 {
155         struct list_head *node;
156         int n = 0;
157
158         list_for_each(node, head)
159                 n++;
160
161         if (n != expected) {
162                 fprintf(stderr, "expected %d sockets, found %d\n",
163                         expected, n);
164                 print_vsock_stats(stderr, head);
165                 exit(EXIT_FAILURE);
166         }
167 }
168
169 static void check_socket_state(struct vsock_stat *st, __u8 state)
170 {
171         if (st->msg.vdiag_state != state) {
172                 fprintf(stderr, "expected socket state %#x, got %#x\n",
173                         state, st->msg.vdiag_state);
174                 exit(EXIT_FAILURE);
175         }
176 }
177
178 static void send_req(int fd)
179 {
180         struct sockaddr_nl nladdr = {
181                 .nl_family = AF_NETLINK,
182         };
183         struct {
184                 struct nlmsghdr nlh;
185                 struct vsock_diag_req vreq;
186         } req = {
187                 .nlh = {
188                         .nlmsg_len = sizeof(req),
189                         .nlmsg_type = SOCK_DIAG_BY_FAMILY,
190                         .nlmsg_flags = NLM_F_REQUEST | NLM_F_DUMP,
191                 },
192                 .vreq = {
193                         .sdiag_family = AF_VSOCK,
194                         .vdiag_states = ~(__u32)0,
195                 },
196         };
197         struct iovec iov = {
198                 .iov_base = &req,
199                 .iov_len = sizeof(req),
200         };
201         struct msghdr msg = {
202                 .msg_name = &nladdr,
203                 .msg_namelen = sizeof(nladdr),
204                 .msg_iov = &iov,
205                 .msg_iovlen = 1,
206         };
207
208         for (;;) {
209                 if (sendmsg(fd, &msg, 0) < 0) {
210                         if (errno == EINTR)
211                                 continue;
212
213                         perror("sendmsg");
214                         exit(EXIT_FAILURE);
215                 }
216
217                 return;
218         }
219 }
220
221 static ssize_t recv_resp(int fd, void *buf, size_t len)
222 {
223         struct sockaddr_nl nladdr = {
224                 .nl_family = AF_NETLINK,
225         };
226         struct iovec iov = {
227                 .iov_base = buf,
228                 .iov_len = len,
229         };
230         struct msghdr msg = {
231                 .msg_name = &nladdr,
232                 .msg_namelen = sizeof(nladdr),
233                 .msg_iov = &iov,
234                 .msg_iovlen = 1,
235         };
236         ssize_t ret;
237
238         do {
239                 ret = recvmsg(fd, &msg, 0);
240         } while (ret < 0 && errno == EINTR);
241
242         if (ret < 0) {
243                 perror("recvmsg");
244                 exit(EXIT_FAILURE);
245         }
246
247         return ret;
248 }
249
250 static void add_vsock_stat(struct list_head *sockets,
251                            const struct vsock_diag_msg *resp)
252 {
253         struct vsock_stat *st;
254
255         st = malloc(sizeof(*st));
256         if (!st) {
257                 perror("malloc");
258                 exit(EXIT_FAILURE);
259         }
260
261         st->msg = *resp;
262         list_add_tail(&st->list, sockets);
263 }
264
265 /*
266  * Read vsock stats into a list.
267  */
268 static void read_vsock_stat(struct list_head *sockets)
269 {
270         long buf[8192 / sizeof(long)];
271         int fd;
272
273         fd = socket(AF_NETLINK, SOCK_RAW, NETLINK_SOCK_DIAG);
274         if (fd < 0) {
275                 perror("socket");
276                 exit(EXIT_FAILURE);
277         }
278
279         send_req(fd);
280
281         for (;;) {
282                 const struct nlmsghdr *h;
283                 ssize_t ret;
284
285                 ret = recv_resp(fd, buf, sizeof(buf));
286                 if (ret == 0)
287                         goto done;
288                 if (ret < sizeof(*h)) {
289                         fprintf(stderr, "short read of %zd bytes\n", ret);
290                         exit(EXIT_FAILURE);
291                 }
292
293                 h = (struct nlmsghdr *)buf;
294
295                 while (NLMSG_OK(h, ret)) {
296                         if (h->nlmsg_type == NLMSG_DONE)
297                                 goto done;
298
299                         if (h->nlmsg_type == NLMSG_ERROR) {
300                                 const struct nlmsgerr *err = NLMSG_DATA(h);
301
302                                 if (h->nlmsg_len < NLMSG_LENGTH(sizeof(*err)))
303                                         fprintf(stderr, "NLMSG_ERROR\n");
304                                 else {
305                                         errno = -err->error;
306                                         perror("NLMSG_ERROR");
307                                 }
308
309                                 exit(EXIT_FAILURE);
310                         }
311
312                         if (h->nlmsg_type != SOCK_DIAG_BY_FAMILY) {
313                                 fprintf(stderr, "unexpected nlmsg_type %#x\n",
314                                         h->nlmsg_type);
315                                 exit(EXIT_FAILURE);
316                         }
317                         if (h->nlmsg_len <
318                             NLMSG_LENGTH(sizeof(struct vsock_diag_msg))) {
319                                 fprintf(stderr, "short vsock_diag_msg\n");
320                                 exit(EXIT_FAILURE);
321                         }
322
323                         add_vsock_stat(sockets, NLMSG_DATA(h));
324
325                         h = NLMSG_NEXT(h, ret);
326                 }
327         }
328
329 done:
330         close(fd);
331 }
332
333 static void free_sock_stat(struct list_head *sockets)
334 {
335         struct vsock_stat *st;
336         struct vsock_stat *next;
337
338         list_for_each_entry_safe(st, next, sockets, list)
339                 free(st);
340 }
341
342 static void test_no_sockets(unsigned int peer_cid)
343 {
344         LIST_HEAD(sockets);
345
346         read_vsock_stat(&sockets);
347
348         check_no_sockets(&sockets);
349
350         free_sock_stat(&sockets);
351 }
352
353 static void test_listen_socket_server(unsigned int peer_cid)
354 {
355         union {
356                 struct sockaddr sa;
357                 struct sockaddr_vm svm;
358         } addr = {
359                 .svm = {
360                         .svm_family = AF_VSOCK,
361                         .svm_port = 1234,
362                         .svm_cid = VMADDR_CID_ANY,
363                 },
364         };
365         LIST_HEAD(sockets);
366         struct vsock_stat *st;
367         int fd;
368
369         fd = socket(AF_VSOCK, SOCK_STREAM, 0);
370
371         if (bind(fd, &addr.sa, sizeof(addr.svm)) < 0) {
372                 perror("bind");
373                 exit(EXIT_FAILURE);
374         }
375
376         if (listen(fd, 1) < 0) {
377                 perror("listen");
378                 exit(EXIT_FAILURE);
379         }
380
381         read_vsock_stat(&sockets);
382
383         check_num_sockets(&sockets, 1);
384         st = find_vsock_stat(&sockets, fd);
385         check_socket_state(st, TCP_LISTEN);
386
387         close(fd);
388         free_sock_stat(&sockets);
389 }
390
391 static void test_connect_client(unsigned int peer_cid)
392 {
393         union {
394                 struct sockaddr sa;
395                 struct sockaddr_vm svm;
396         } addr = {
397                 .svm = {
398                         .svm_family = AF_VSOCK,
399                         .svm_port = 1234,
400                         .svm_cid = peer_cid,
401                 },
402         };
403         int fd;
404         int ret;
405         LIST_HEAD(sockets);
406         struct vsock_stat *st;
407
408         control_expectln("LISTENING");
409
410         fd = socket(AF_VSOCK, SOCK_STREAM, 0);
411
412         timeout_begin(TIMEOUT);
413         do {
414                 ret = connect(fd, &addr.sa, sizeof(addr.svm));
415                 timeout_check("connect");
416         } while (ret < 0 && errno == EINTR);
417         timeout_end();
418
419         if (ret < 0) {
420                 perror("connect");
421                 exit(EXIT_FAILURE);
422         }
423
424         read_vsock_stat(&sockets);
425
426         check_num_sockets(&sockets, 1);
427         st = find_vsock_stat(&sockets, fd);
428         check_socket_state(st, TCP_ESTABLISHED);
429
430         control_expectln("DONE");
431         control_writeln("DONE");
432
433         close(fd);
434         free_sock_stat(&sockets);
435 }
436
437 static void test_connect_server(unsigned int peer_cid)
438 {
439         union {
440                 struct sockaddr sa;
441                 struct sockaddr_vm svm;
442         } addr = {
443                 .svm = {
444                         .svm_family = AF_VSOCK,
445                         .svm_port = 1234,
446                         .svm_cid = VMADDR_CID_ANY,
447                 },
448         };
449         union {
450                 struct sockaddr sa;
451                 struct sockaddr_vm svm;
452         } clientaddr;
453         socklen_t clientaddr_len = sizeof(clientaddr.svm);
454         LIST_HEAD(sockets);
455         struct vsock_stat *st;
456         int fd;
457         int client_fd;
458
459         fd = socket(AF_VSOCK, SOCK_STREAM, 0);
460
461         if (bind(fd, &addr.sa, sizeof(addr.svm)) < 0) {
462                 perror("bind");
463                 exit(EXIT_FAILURE);
464         }
465
466         if (listen(fd, 1) < 0) {
467                 perror("listen");
468                 exit(EXIT_FAILURE);
469         }
470
471         control_writeln("LISTENING");
472
473         timeout_begin(TIMEOUT);
474         do {
475                 client_fd = accept(fd, &clientaddr.sa, &clientaddr_len);
476                 timeout_check("accept");
477         } while (client_fd < 0 && errno == EINTR);
478         timeout_end();
479
480         if (client_fd < 0) {
481                 perror("accept");
482                 exit(EXIT_FAILURE);
483         }
484         if (clientaddr.sa.sa_family != AF_VSOCK) {
485                 fprintf(stderr, "expected AF_VSOCK from accept(2), got %d\n",
486                         clientaddr.sa.sa_family);
487                 exit(EXIT_FAILURE);
488         }
489         if (clientaddr.svm.svm_cid != peer_cid) {
490                 fprintf(stderr, "expected peer CID %u from accept(2), got %u\n",
491                         peer_cid, clientaddr.svm.svm_cid);
492                 exit(EXIT_FAILURE);
493         }
494
495         read_vsock_stat(&sockets);
496
497         check_num_sockets(&sockets, 2);
498         find_vsock_stat(&sockets, fd);
499         st = find_vsock_stat(&sockets, client_fd);
500         check_socket_state(st, TCP_ESTABLISHED);
501
502         control_writeln("DONE");
503         control_expectln("DONE");
504
505         close(client_fd);
506         close(fd);
507         free_sock_stat(&sockets);
508 }
509
510 static struct {
511         const char *name;
512         void (*run_client)(unsigned int peer_cid);
513         void (*run_server)(unsigned int peer_cid);
514 } test_cases[] = {
515         {
516                 .name = "No sockets",
517                 .run_server = test_no_sockets,
518         },
519         {
520                 .name = "Listen socket",
521                 .run_server = test_listen_socket_server,
522         },
523         {
524                 .name = "Connect",
525                 .run_client = test_connect_client,
526                 .run_server = test_connect_server,
527         },
528         {},
529 };
530
531 static void init_signals(void)
532 {
533         struct sigaction act = {
534                 .sa_handler = sigalrm,
535         };
536
537         sigaction(SIGALRM, &act, NULL);
538         signal(SIGPIPE, SIG_IGN);
539 }
540
/*
 * Parse a decimal CID from a command-line argument.
 *
 * The whole string must be a number: an empty string, trailing characters,
 * a conversion error reported through errno, or a value that does not fit
 * in an unsigned int all print a diagnostic and exit with failure.
 *
 * Returns the parsed CID.
 */
static unsigned int parse_cid(const char *str)
{
	char *endptr = NULL;
	unsigned long int n;

	errno = 0;
	n = strtoul(str, &endptr, 10);
	if (errno || endptr == str || *endptr != '\0' ||
	    n != (unsigned int)n) {
		/*
		 * endptr == str: no digits were consumed (e.g. "").
		 * n != (unsigned int)n: the value would be truncated on return.
		 */
		fprintf(stderr, "malformed CID \"%s\"\n", str);
		exit(EXIT_FAILURE);
	}
	return (unsigned int)n;
}
554
/* No short options are accepted; everything is a long option. */
static const char optstring[] = "";

/* Long options; .val is the value getopt_long() returns for each one. */
static const struct option longopts[] = {
	{ .name = "control-host", .has_arg = required_argument, .val = 'H' },
	{ .name = "control-port", .has_arg = required_argument, .val = 'P' },
	{ .name = "mode",         .has_arg = required_argument, .val = 'm' },
	{ .name = "peer-cid",     .has_arg = required_argument, .val = 'p' },
	{ .name = "help",         .has_arg = no_argument,       .val = '?' },
	{ .name = NULL },	/* terminator */
};
584
/* Print the command-line help to stderr and exit with failure. */
static void usage(void)
{
	static const char text[] =
		"Usage: vsock_diag_test [--help] [--control-host=<host>] --control-port=<port> --mode=client|server --peer-cid=<cid>\n"
		"\n"
		"  Server: vsock_diag_test --control-port=1234 --mode=server --peer-cid=3\n"
		"  Client: vsock_diag_test --control-host=192.168.0.1 --control-port=1234 --mode=client --peer-cid=2\n"
		"\n"
		"Run vsock_diag.ko tests.  Must be launched in both\n"
		"guest and host.  One side must use --mode=client and\n"
		"the other side must use --mode=server.\n"
		"\n"
		"A TCP control socket connection is used to coordinate tests\n"
		"between the client and the server.  The server requires a\n"
		"listen address and the client requires an address to\n"
		"connect to.\n"
		"\n"
		"The CID of the other side must be given with --peer-cid=<cid>.\n";

	fputs(text, stderr);
	exit(EXIT_FAILURE);
}
604
605 int main(int argc, char **argv)
606 {
607         const char *control_host = NULL;
608         const char *control_port = NULL;
609         int mode = TEST_MODE_UNSET;
610         unsigned int peer_cid = VMADDR_CID_ANY;
611         int i;
612
613         init_signals();
614
615         for (;;) {
616                 int opt = getopt_long(argc, argv, optstring, longopts, NULL);
617
618                 if (opt == -1)
619                         break;
620
621                 switch (opt) {
622                 case 'H':
623                         control_host = optarg;
624                         break;
625                 case 'm':
626                         if (strcmp(optarg, "client") == 0)
627                                 mode = TEST_MODE_CLIENT;
628                         else if (strcmp(optarg, "server") == 0)
629                                 mode = TEST_MODE_SERVER;
630                         else {
631                                 fprintf(stderr, "--mode must be \"client\" or \"server\"\n");
632                                 return EXIT_FAILURE;
633                         }
634                         break;
635                 case 'p':
636                         peer_cid = parse_cid(optarg);
637                         break;
638                 case 'P':
639                         control_port = optarg;
640                         break;
641                 case '?':
642                 default:
643                         usage();
644                 }
645         }
646
647         if (!control_port)
648                 usage();
649         if (mode == TEST_MODE_UNSET)
650                 usage();
651         if (peer_cid == VMADDR_CID_ANY)
652                 usage();
653
654         if (!control_host) {
655                 if (mode != TEST_MODE_SERVER)
656                         usage();
657                 control_host = "0.0.0.0";
658         }
659
660         control_init(control_host, control_port, mode == TEST_MODE_SERVER);
661
662         for (i = 0; test_cases[i].name; i++) {
663                 void (*run)(unsigned int peer_cid);
664
665                 printf("%s...", test_cases[i].name);
666                 fflush(stdout);
667
668                 if (mode == TEST_MODE_CLIENT)
669                         run = test_cases[i].run_client;
670                 else
671                         run = test_cases[i].run_server;
672
673                 if (run)
674                         run(peer_cid);
675
676                 printf("ok\n");
677         }
678
679         control_cleanup();
680         return EXIT_SUCCESS;
681 }