net.c: don't tune SO_RCVLOWAT
[netsum.git] / net.c
1 /*
2  * This file is part of netsum.
3  *
4  * Copyright (C) 2014 Simon Guinot <simon.guinot@sequanux.org>
5  *
6  * netsum is free software: you can redistribute it and/or modify
7  * it under the terms of the GNU General Public License as published by
8  * the Free Software Foundation, either version 3 of the License, or
9  * (at your option) any later version.
10  *
11  * netsum is distributed in the hope that it will be useful,
12  * but WITHOUT ANY WARRANTY; without even the implied warranty of
13  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
14  * GNU General Public License for more details.
15  *
16  * You should have received a copy of the GNU General Public License
17  * along with netsum.  If not, see <http://www.gnu.org/licenses/>.
18  */
19
20 #include <stdio.h>
21 #include <stdlib.h>
22 #include <string.h>
23 #include <stdbool.h>
24 #include <unistd.h>
25 #include <errno.h>
26 #include <ifaddrs.h>
27 #include <inttypes.h>
28 #include <arpa/inet.h>
29 #include <sys/time.h>
30
31 #include "netsum.h"
32 #include "buffer.h"
33
34 #define NETSUM_DEFAULT_PORT 3333
35
36 extern bool verbose;
37
38 struct netsum_ctrl_msg {
39         uint32_t count;
40         uint32_t size;
41         uint32_t timeout;
42         uint8_t dir;
43 } __attribute__((packed));
44
45 /*
46  * Connection.
47  */
48
49 static int get_inet_addr(const char *addr, struct sockaddr_in *sin)
50 {
51         char *tmp, *ip_str, *port_str;
52         int ret = -1;
53         uint16_t port = NETSUM_DEFAULT_PORT;
54
55         tmp = strdup(addr);
56         if (!tmp) {
57                 fprintf(stderr, "strdup() failed: %s\n", strerror(errno));
58                 return -1;
59         }
60
61         ip_str = tmp;
62         port_str = strchr(tmp, ':');
63         if (port_str) {
64                 *port_str = '\0';
65                 port_str++;
66                 port = (uint16_t) atoi(port_str);
67         }
68
69         if (!inet_aton(ip_str, &sin->sin_addr))
70                 goto err_free;
71         sin->sin_port = htons(port);
72         sin->sin_family = AF_INET;
73
74         ret = 0;
75 err_free:
76         free(tmp);
77         return ret;
78 }
79
80 static int addr_is_local(struct sockaddr_in *sin)
81 {
82         int ret = 0;
83         struct ifaddrs *ifaddr, *ifa;
84
85         if (getifaddrs(&ifaddr) == -1) {
86                 fprintf(stderr, "getifaddrs() failed: %s\n", strerror(errno));
87                 return -1;
88         }
89
90         for (ifa = ifaddr; ifa != NULL; ifa = ifa->ifa_next) {
91                 struct sockaddr_in *local_sin;
92
93                 if (ifa->ifa_addr->sa_family != AF_INET)
94                         continue;
95
96                 local_sin = (struct sockaddr_in *) ifa->ifa_addr;
97                 if (local_sin->sin_addr.s_addr == sin->sin_addr.s_addr) {
98                         ret = 1;
99                         break;
100                 }
101         }
102         freeifaddrs(ifaddr);
103
104         return ret;
105 }
106
107 static int open_socket(time_t timeout, size_t recvlowat)
108 {
109         int sock;
110         int on = 1;
111         struct timeval tv = {
112                 tv.tv_sec = timeout,
113                 tv.tv_usec = 0,
114         };
115
116         sock = socket(PF_INET, SOCK_STREAM, 0);
117         if (sock == -1) {
118                 fprintf(stderr, "socket() failed: %s\n", strerror(errno));
119                 goto err;
120         }
121         if (setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, &on, sizeof(on)) == -1) {
122                 fprintf(stderr, "setsockopt() SO_REUSEADDR failed: %s\n",
123                         strerror(errno));
124                 goto err_close_sock;
125         }
126         if (setsockopt(sock, SOL_SOCKET, SO_RCVTIMEO, &tv, sizeof(tv)) == -1) {
127                 fprintf(stderr, "setsockopt() SO_RCVTIMEOUT failed: %s\n",
128                         strerror(errno));
129                 goto err_close_sock;
130         }
131         if (setsockopt(sock, SOL_SOCKET, SO_SNDTIMEO, &tv, sizeof(tv)) == -1) {
132                 fprintf(stderr, "setsockopt() SO_SNDTIMEOUT failed: %s\n",
133                         strerror(errno));
134                 goto err_close_sock;
135         }
136
137         return sock;
138
139 err_close_sock:
140         close(sock);
141 err:
142         return -1;
143 }
144         
145 static int
146 connect_server(struct sockaddr_in *sin, time_t timeout, size_t recvlowat)
147 {
148         int sock;
149
150         sock = open_socket(timeout, recvlowat);
151         if (sock == -1)
152                 return -1;
153
154         if (connect(sock, (struct sockaddr *) sin, sizeof(*sin)) == -1) {
155                 fprintf(stderr, "connect() failed: %s\n", strerror(errno));
156                 close(sock);
157                 return -1;
158         }
159
160         return sock;
161 }
162
163 static int
164 open_server_sock(struct sockaddr_in *sin, time_t timeout, size_t recvlowat)
165 {
166         int ret;
167         int sock;
168         socklen_t addrlen = sizeof(*sin);
169
170         sock = open_socket(timeout, recvlowat);
171         if (sock == -1)
172                 return -1;
173
174         ret = bind(sock, (struct sockaddr *) sin, sizeof(*sin));
175         if (ret == -1) {
176                 fprintf(stderr, "bind() failed: %s\n", strerror(errno));
177                 goto err_sock;
178         }
179         ret = listen(sock, 1);
180         if (ret == -1) {
181                 fprintf(stderr, "listen() failed: %s\n", strerror(errno));
182                 goto err_sock;
183         }
184         ret = getsockname(sock, (struct sockaddr *) sin, &addrlen);
185         if (ret == -1) {
186                 fprintf(stderr, "getsockname() failed: %s\n", strerror(errno));
187                 goto err_sock;
188         }
189         fprintf(stdout, "Listen at %s:%d\n",
190                 inet_ntoa(sin->sin_addr), ntohs(sin->sin_port));
191
192         return sock;
193
194 err_sock:
195         close(sock);
196         return ret;
197 }
198
199 /*
200  * Messages.
201  */
202
203 static int send_message(int sock, unsigned char *msg, int len,
204                         struct sockaddr *daddr, socklen_t addrlen)
205 {
206         int sent, retry, ret;
207
208         for (sent = 0, retry = 0; sent != len && retry < 3;) {
209                 ret = sendto(sock, msg, len - sent,
210                              MSG_NOSIGNAL, daddr, addrlen);
211                 if (ret == -1) {
212                         if (errno == EAGAIN || errno == EINTR) {
213                                 retry++;
214                                 continue;
215                         }
216                         break;
217                 }
218                 sent += ret;
219                 msg += ret;
220         }
221         return sent;
222 }
223
224 static int recv_message(int sock, unsigned char *msg, int len,
225                         struct sockaddr *saddr, socklen_t *addrlen)
226 {
227         int received, retry, ret;
228
229         for (received = 0, retry = 0; received != len && retry < 3;) {
230                 ret = recvfrom(sock, msg, len - received,
231                                MSG_NOSIGNAL, saddr, addrlen);
232                 if (ret == -1) {
233                         if (errno == EAGAIN || errno == EINTR) {
234                                 retry++;
235                                 continue;
236                         }
237                         break;
238                 } else if (ret == 0) {
239                         break;
240                 }
241                 received += ret;
242                 msg += ret;
243         }
244         return received;
245 }
246
247 static int send_ctrl_message(int sock, struct netsum_ctrl_msg *ctrl_msg,
248                              struct sockaddr *daddr, socklen_t addrlen)
249 {
250         int ret;
251
252         ret = send_message(sock, (unsigned char *) ctrl_msg, sizeof(*ctrl_msg),
253                            daddr, addrlen);
254         if (ret != sizeof(*ctrl_msg))
255                 return -1;
256
257         return 0;
258 }
259
260 static int recv_ctrl_message(int sock, struct netsum_ctrl_msg *ctrl_msg,
261                              struct sockaddr *saddr, socklen_t *addrlen)
262 {
263         int ret;
264
265         ret = recv_message(sock, (unsigned char *) ctrl_msg, sizeof(*ctrl_msg),
266                            saddr, addrlen);
267         if (ret != sizeof(*ctrl_msg))
268                 return -1;
269
270         return 0;
271 }
272
273 static void show_stats(struct netsum_ctrl *ctrl)
274 {
275         int i;
276         uint64_t sum;
277         bool print;
278         char *units[] = { "KB", "MB", "GB", "TB" };
279         double time, time_sum;
280
281         fprintf(stdout, "\33[2K\rTx:");
282         sum = ctrl->tx_sum >> 10;
283         print = false;
284         for (i = ARRAY_SIZE(units) - 1; i >= 0; i--) {
285                 uint64_t val = sum >> (10 * i);
286
287                 if (print || val) {
288                         fprintf(stdout, " %04" PRIu64 "%s", val, units[i]);
289                         sum -= (val << (10 * i));
290                         print = true;
291                 }
292         }
293         /* Avoid division by zero. */
294         time = ctrl->tx_time ? : 1; 
295         time_sum = ctrl->tx_time_sum ? : 1; 
296         fprintf(stdout, " (%3.3fMB/sec %3.3fMB/sec)",
297                 ((ctrl->tx >> 20) / time),
298                 ((ctrl->tx_sum >> 20) / time_sum));
299
300         fprintf(stdout, " - Rx:");
301         sum = ctrl->rx_sum >> 10;
302         print = false;
303         for (i = ARRAY_SIZE(units) - 1; i >= 0; i--) {
304                 uint64_t val = sum >> (10 * i);
305
306                 if (print || val) {
307                         fprintf(stdout, " %04" PRIu64 "%s", val, units[i]);
308                         sum -= (val << (10 * i));
309                         print = true;
310                 }
311         }
312         /* Avoid division by zero. */
313         time = ctrl->rx_time ? : 1; 
314         time_sum = ctrl->rx_time_sum ? : 1; 
315         fprintf(stdout, " (%3.3fMB/sec %3.3fMB/sec)",
316                 ((ctrl->rx >> 20) / time),
317                 ((ctrl->rx_sum >> 20) / time_sum));
318
319         fflush(stdout);
320 }
321
322 static int send_data_messages(int sock, struct netsum_ctrl *ctrl,
323                               struct sockaddr *daddr, socklen_t addrlen)
324 {
325         unsigned char *buff;
326         uint32_t i;
327         int ret;
328         struct timeval start_time, stop_time;
329         double elapsed;
330
331         /* Prepare payload */
332         buff = malloc(ctrl->size);
333         if (!buff) {
334                 fprintf(stderr, "malloc() failed: %s\n", strerror(errno));
335                 return -1;
336         }
337         fill_buffer(buff, ctrl->size);
338
339         gettimeofday(&start_time, NULL);
340
341         for (i = 0; i < ctrl->count; i++) {
342                 ret = send_message(sock, buff, ctrl->size, daddr, addrlen);
343                 if (ret != ctrl->size) {
344                         ret = -1;
345                         goto err_free;
346                 }
347         }
348
349         gettimeofday(&stop_time, NULL);
350         elapsed = stop_time.tv_sec - start_time.tv_sec +
351                   ((double) (stop_time.tv_usec - start_time.tv_usec) / 1000000);
352
353         ctrl->tx = (uint64_t) ctrl->count * (uint64_t) ctrl->size;
354         ctrl->tx_sum += ctrl->tx;
355         ctrl->tx_time = elapsed;
356         ctrl->tx_time_sum += elapsed;
357
358         show_stats(ctrl);
359         ret = 0;
360
361 err_free:
362         free(buff);
363         return ret;
364 }
365
366 static int recv_data_messages(int sock, struct netsum_ctrl *ctrl,
367                               struct sockaddr *saddr, socklen_t *addrlen)
368 {
369         unsigned char *buff;
370         unsigned char *template = NULL;
371         bool use_memcmp = !!(ctrl->flags & USE_MEMCMP);
372         int ret = -1;
373         struct timeval start_time, stop_time;
374         uint32_t i;
375         double elapsed;
376
377         buff = malloc(ctrl->size);
378         if (!buff) {
379                 fprintf(stderr, "malloc() failed: %s\n", strerror(errno));
380                 return -1;
381         }
382
383         template = malloc(ctrl->size);
384         if (!template) {
385                 fprintf(stderr, "malloc() failed: %s\n",
386                         strerror(errno));
387                 goto err_free;
388         }
389
390         gettimeofday(&start_time, NULL);
391
392         for (i = 0; i < ctrl->count; i++) {
393                 ret = recv_message(sock, buff, ctrl->size, saddr, addrlen);
394                 if (ret != ctrl->size) {
395                         ret = -1;
396                         goto err_free;
397                 }
398                 if (use_memcmp) {
399                         /*
400                          * For the first buffer received or on error, fall back
401                          * on buffer_is_valid().
402                          */
403                         if (i == 0)
404                                 memcpy(template, buff, ctrl->size);
405                         else if (!memcmp(buff, template, ctrl->size))
406                                 continue;
407                 }
408                 if (!buffer_is_valid(buff, ctrl->size, true)) {
409                         fprintf(stdout,
410                                 "Reference buffer stored in file: ref\n");
411                         fprintf(stdout,
412                                 "Corrupted buffer stored in file: err\n");
413                         write_buffer_to_file("ref", buff, ctrl->size);
414                         write_buffer_to_file("err", template, ctrl->size);
415                         ret = -1;
416                         goto err_free;
417                 }
418         }
419
420         gettimeofday(&stop_time, NULL);
421         elapsed = stop_time.tv_sec - start_time.tv_sec +
422                   ((double) (stop_time.tv_usec - start_time.tv_usec) / 1000000);
423
424         ctrl->rx = (uint64_t) ctrl->count * (uint64_t) ctrl->size;
425         ctrl->rx_sum += ctrl->rx;
426         ctrl->rx_time = elapsed;
427         ctrl->rx_time_sum += elapsed;
428
429         show_stats(ctrl);
430         ret = 0;
431
432 err_free:
433         free(buff);
434         if (template)
435                 free(template);
436
437         return ret;
438 }
439
440 /*
441  * Base.
442  */
443
444 static int handle_client(int server_sock, struct netsum_ctrl *ctrl)
445 {
446         int sock, ret;
447         struct sockaddr_in sin;
448         socklen_t addrlen = sizeof(sin);
449         struct netsum_ctrl_msg ctrl_msg;
450
451         sock = accept(server_sock, (struct sockaddr *) &sin, &addrlen);
452         if (sock == -1) {
453                 fprintf(stderr, "accept() failed: %s\n", strerror(errno));
454                 return -1;
455         }
456         fprintf(stdout, "\nClient %s:%d\n",
457                 inet_ntoa(sin.sin_addr), ntohs(sin.sin_port));
458
459         ctrl->rx = 0;
460         ctrl->rx_sum = 0;
461         ctrl->rx_time = 0;
462         ctrl->rx_time_sum = 0;
463         ctrl->tx = 0;
464         ctrl->tx_sum = 0;
465         ctrl->tx_time = 0;
466         ctrl->tx_time_sum = 0;
467
468         while (1) {
469                 ret = recv_ctrl_message(sock, &ctrl_msg, NULL, NULL);
470                 if (ret)
471                         break;
472
473                 ctrl->count = ntohl(ctrl_msg.count);
474                 ctrl->size = ntohl(ctrl_msg.size);
475                 ctrl->timeout = ntohl(ctrl_msg.timeout);
476
477                 if (ctrl_msg.dir)
478                         ret = send_data_messages(sock, ctrl, NULL, 0);
479                 else
480                         ret = recv_data_messages(sock, ctrl, NULL, NULL);
481                 if (ret)
482                         break;
483         };
484
485         close(sock);
486
487         return ret;
488 }
489
490 static int run_server(struct sockaddr_in *sin, struct netsum_ctrl *ctrl)
491 {
492         int sock;
493
494         sock = open_server_sock(sin, ctrl->timeout, ctrl->size);
495         if (sock == -1)
496                 return -1;
497
498         while(1) { handle_client(sock, ctrl); };
499
500         close(sock);
501         return -1;
502 }
503
504 static int run_client(struct sockaddr_in *sin, struct netsum_ctrl *ctrl)
505 {
506         int sock, ret;
507         struct netsum_ctrl_msg ctrl_msg;
508
509         sock = connect_server(sin, ctrl->timeout, ctrl->size);
510         if (sock == -1)
511                 return -1;
512
513         ctrl->rx = 0;
514         ctrl->rx_sum = 0;
515         ctrl->rx_time = 0;
516         ctrl->rx_time_sum = 0;
517         ctrl->tx = 0;
518         ctrl->tx_sum = 0;
519         ctrl->tx_time = 0;
520         ctrl->tx_time_sum = 0;
521
522         ctrl_msg.count = htonl(ctrl->count);
523         ctrl_msg.size = htonl(ctrl->size);
524         ctrl_msg.timeout = htonl(ctrl->timeout);
525
526         while (1) {
527                 if (ctrl->flags & DIR_TX) {
528                         ctrl_msg.dir = 0;
529                         ret = send_ctrl_message(sock, &ctrl_msg, NULL, 0);
530                         if (ret)
531                                 break;
532                         ret = send_data_messages(sock, ctrl, NULL, 0);
533                         if (ret)
534                                 break;
535                 }
536                 if (ctrl->flags & DIR_RX) {
537                         ctrl_msg.dir = 1;
538                         ret = send_ctrl_message(sock, &ctrl_msg, NULL, 0);
539                         if (ret)
540                                 break;
541                         ret = recv_data_messages(sock, ctrl, NULL, NULL);
542                         if (ret)
543                                 break;
544                 }
545         };
546
547         close(sock);
548
549         return ret;
550 }
551
552 int run_netsum(char *addr, struct netsum_ctrl *ctrl)
553 {
554         struct sockaddr_in sin;
555         int ret;
556
557         memset(&sin, 0, sizeof(sin));
558
559         if (addr) {
560                 if (get_inet_addr(addr, &sin)) {
561                         fprintf(stderr,
562                                 "Failed to convert %s into an IP address\n",
563                                 addr);
564                         return -1;
565                 }
566                 ret = addr_is_local(&sin);
567                 if (ret < 0)
568                         return ret;
569                 if (!ret)
570                         return run_client(&sin, ctrl);
571         } else {
572                 sin.sin_family = AF_INET;
573                 sin.sin_port = htons(NETSUM_DEFAULT_PORT);
574                 sin.sin_addr.s_addr = INADDR_ANY;
575         }
576
577         return run_server(&sin, ctrl);
578 }