net.c: add check of the first received buffer when using memcmp
[netsum.git] / net.c
1 /*
2  * This file is part of netsum.
3  *
4  * Copyright (C) 2014 Simon Guinot <simon.guinot@sequanux.org>
5  *
6  * netsum is free software: you can redistribute it and/or modify
7  * it under the terms of the GNU General Public License as published by
8  * the Free Software Foundation, either version 3 of the License, or
9  * (at your option) any later version.
10  *
11  * netsum is distributed in the hope that it will be useful,
12  * but WITHOUT ANY WARRANTY; without even the implied warranty of
13  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
14  * GNU General Public License for more details.
15  *
16  * You should have received a copy of the GNU General Public License
17  * along with netsum.  If not, see <http://www.gnu.org/licenses/>.
18  */
19
20 #include <stdio.h>
21 #include <stdlib.h>
22 #include <string.h>
23 #include <stdbool.h>
24 #include <unistd.h>
25 #include <errno.h>
26 #include <ifaddrs.h>
27 #include <inttypes.h>
28 #include <arpa/inet.h>
29 #include <sys/time.h>
30
31 #include "netsum.h"
32 #include "buffer.h"
33
34 #define NETSUM_DEFAULT_PORT 3333
35
36 extern bool verbose;
37
38 struct netsum_ctrl_msg {
39         uint32_t count;
40         uint32_t size;
41         uint32_t timeout;
42         uint8_t dir;
43 } __attribute__((packed));
44
45 /*
46  * Connection.
47  */
48
49 static int get_inet_addr(const char *addr, struct sockaddr_in *sin)
50 {
51         char *tmp, *ip_str, *port_str;
52         int ret = -1;
53         uint16_t port = NETSUM_DEFAULT_PORT;
54
55         tmp = strdup(addr);
56         if (!tmp) {
57                 fprintf(stderr, "strdup() failed: %s\n", strerror(errno));
58                 return -1;
59         }
60
61         ip_str = tmp;
62         port_str = strchr(tmp, ':');
63         if (port_str) {
64                 *port_str = '\0';
65                 port_str++;
66                 port = (uint16_t) atoi(port_str);
67         }
68
69         if (!inet_aton(ip_str, &sin->sin_addr))
70                 goto err_free;
71         sin->sin_port = htons(port);
72         sin->sin_family = AF_INET;
73
74         ret = 0;
75 err_free:
76         free(tmp);
77         return ret;
78 }
79
80 static int addr_is_local(struct sockaddr_in *sin)
81 {
82         int ret = 0;
83         struct ifaddrs *ifaddr, *ifa;
84
85         if (getifaddrs(&ifaddr) == -1) {
86                 fprintf(stderr, "getifaddrs() failed: %s\n", strerror(errno));
87                 return -1;
88         }
89
90         for (ifa = ifaddr; ifa != NULL; ifa = ifa->ifa_next) {
91                 struct sockaddr_in *local_sin;
92
93                 if (ifa->ifa_addr->sa_family != AF_INET)
94                         continue;
95
96                 local_sin = (struct sockaddr_in *) ifa->ifa_addr;
97                 if (local_sin->sin_addr.s_addr == sin->sin_addr.s_addr) {
98                         ret = 1;
99                         break;
100                 }
101         }
102         freeifaddrs(ifaddr);
103
104         return ret;
105 }
106
107 static int open_socket(time_t timeout, size_t recvlowat)
108 {
109         int sock;
110         int on = 1;
111         struct timeval tv = {
112                 tv.tv_sec = timeout,
113                 tv.tv_usec = 0,
114         };
115
116         sock = socket(PF_INET, SOCK_STREAM, 0);
117         if (sock == -1) {
118                 fprintf(stderr, "socket() failed: %s\n", strerror(errno));
119                 goto err;
120         }
121         if (setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, &on, sizeof(on)) == -1) {
122                 fprintf(stderr, "setsockopt() SO_REUSEADDR failed: %s\n",
123                         strerror(errno));
124                 goto err_close_sock;
125         }
126         if (setsockopt(sock, SOL_SOCKET, SO_RCVLOWAT,
127                         &recvlowat, sizeof(recvlowat)) == -1) {
128                 fprintf(stderr, "setsockopt() SO_RCVLOWAT failed: %s\n",
129                         strerror(errno));
130                 goto err_close_sock;
131         }
132         if (setsockopt(sock, SOL_SOCKET, SO_RCVTIMEO, &tv, sizeof(tv)) == -1) {
133                 fprintf(stderr, "setsockopt() SO_RCVTIMEOUT failed: %s\n",
134                         strerror(errno));
135                 goto err_close_sock;
136         }
137         if (setsockopt(sock, SOL_SOCKET, SO_SNDTIMEO, &tv, sizeof(tv)) == -1) {
138                 fprintf(stderr, "setsockopt() SO_SNDTIMEOUT failed: %s\n",
139                         strerror(errno));
140                 goto err_close_sock;
141         }
142
143         return sock;
144
145 err_close_sock:
146         close(sock);
147 err:
148         return -1;
149 }
150         
151 static int
152 connect_server(struct sockaddr_in *sin, time_t timeout, size_t recvlowat)
153 {
154         int sock;
155
156         sock = open_socket(timeout, recvlowat);
157         if (sock == -1)
158                 return -1;
159
160         if (connect(sock, (struct sockaddr *) sin, sizeof(*sin)) == -1) {
161                 fprintf(stderr, "connect() failed: %s\n", strerror(errno));
162                 close(sock);
163                 return -1;
164         }
165
166         return sock;
167 }
168
169 static int
170 open_server_sock(struct sockaddr_in *sin, time_t timeout, size_t recvlowat)
171 {
172         int ret;
173         int sock;
174         socklen_t addrlen = sizeof(*sin);
175
176         sock = open_socket(timeout, recvlowat);
177         if (sock == -1)
178                 return -1;
179
180         ret = bind(sock, (struct sockaddr *) sin, sizeof(*sin));
181         if (ret == -1) {
182                 fprintf(stderr, "bind() failed: %s\n", strerror(errno));
183                 goto err_sock;
184         }
185         ret = listen(sock, 1);
186         if (ret == -1) {
187                 fprintf(stderr, "listen() failed: %s\n", strerror(errno));
188                 goto err_sock;
189         }
190         ret = getsockname(sock, (struct sockaddr *) sin, &addrlen);
191         if (ret == -1) {
192                 fprintf(stderr, "getsockname() failed: %s\n", strerror(errno));
193                 goto err_sock;
194         }
195         fprintf(stdout, "Listen at %s:%d\n",
196                 inet_ntoa(sin->sin_addr), ntohs(sin->sin_port));
197
198         return sock;
199
200 err_sock:
201         close(sock);
202         return ret;
203 }
204
205 /*
206  * Messages.
207  */
208
209 static int send_message(int sock, unsigned char *msg, int len,
210                         struct sockaddr *daddr, socklen_t addrlen)
211 {
212         int sent, retry, ret;
213
214         for (sent = 0, retry = 0; sent != len && retry < 3;) {
215                 ret = sendto(sock, msg, len - sent,
216                              MSG_NOSIGNAL, daddr, addrlen);
217                 if (ret == -1) {
218                         if (errno == EAGAIN || errno == EINTR) {
219                                 retry++;
220                                 continue;
221                         }
222                         break;
223                 }
224                 sent += ret;
225                 msg += ret;
226         }
227         return sent;
228 }
229
230 static int recv_message(int sock, unsigned char *msg, int len,
231                         struct sockaddr *saddr, socklen_t *addrlen)
232 {
233         int received, retry, ret;
234
235         for (received = 0, retry = 0; received != len && retry < 3;) {
236                 ret = recvfrom(sock, msg, len - received,
237                                MSG_NOSIGNAL, saddr, addrlen);
238                 if (ret == -1) {
239                         if (errno == EAGAIN || errno == EINTR) {
240                                 retry++;
241                                 continue;
242                         }
243                         break;
244                 } else if (ret == 0) {
245                         break;
246                 }
247                 received += ret;
248                 msg += ret;
249         }
250         return received;
251 }
252
253 static int send_ctrl_message(int sock, struct netsum_ctrl_msg *ctrl_msg,
254                              struct sockaddr *daddr, socklen_t addrlen)
255 {
256         int ret;
257
258         ret = send_message(sock, (unsigned char *) ctrl_msg, sizeof(*ctrl_msg),
259                            daddr, addrlen);
260         if (ret != sizeof(*ctrl_msg))
261                 return -1;
262
263         return 0;
264 }
265
266 static int recv_ctrl_message(int sock, struct netsum_ctrl_msg *ctrl_msg,
267                              struct sockaddr *saddr, socklen_t *addrlen)
268 {
269         int ret;
270
271         ret = recv_message(sock, (unsigned char *) ctrl_msg, sizeof(*ctrl_msg),
272                            saddr, addrlen);
273         if (ret != sizeof(*ctrl_msg))
274                 return -1;
275
276         return 0;
277 }
278
279 static void show_stats(struct netsum_ctrl *ctrl)
280 {
281         int i;
282         uint64_t sum;
283         bool print;
284         char *units[] = { "KB", "MB", "GB", "TB" };
285         double time, time_sum;
286
287         fprintf(stdout, "\33[2K\rTx:");
288         sum = ctrl->tx_sum >> 10;
289         print = false;
290         for (i = ARRAY_SIZE(units) - 1; i >= 0; i--) {
291                 uint64_t val = sum >> (10 * i);
292
293                 if (print || val) {
294                         fprintf(stdout, " %04" PRIu64 "%s", val, units[i]);
295                         sum -= (val << (10 * i));
296                         print = true;
297                 }
298         }
299         /* Avoid division by zero. */
300         time = ctrl->tx_time ? : 1; 
301         time_sum = ctrl->tx_time_sum ? : 1; 
302         fprintf(stdout, " (%3.3fMB/sec %3.3fMB/sec)",
303                 ((ctrl->tx >> 20) / time),
304                 ((ctrl->tx_sum >> 20) / time_sum));
305
306         fprintf(stdout, " - Rx:");
307         sum = ctrl->rx_sum >> 10;
308         print = false;
309         for (i = ARRAY_SIZE(units) - 1; i >= 0; i--) {
310                 uint64_t val = sum >> (10 * i);
311
312                 if (print || val) {
313                         fprintf(stdout, " %04" PRIu64 "%s", val, units[i]);
314                         sum -= (val << (10 * i));
315                         print = true;
316                 }
317         }
318         /* Avoid division by zero. */
319         time = ctrl->rx_time ? : 1; 
320         time_sum = ctrl->rx_time_sum ? : 1; 
321         fprintf(stdout, " (%3.3fMB/sec %3.3fMB/sec)",
322                 ((ctrl->rx >> 20) / time),
323                 ((ctrl->rx_sum >> 20) / time_sum));
324
325         fflush(stdout);
326 }
327
328 static int send_data_messages(int sock, struct netsum_ctrl *ctrl,
329                               struct sockaddr *daddr, socklen_t addrlen)
330 {
331         unsigned char *buff;
332         uint32_t i;
333         int ret;
334         struct timeval start_time, stop_time;
335         double elapsed;
336
337         /* Prepare payload */
338         buff = malloc(ctrl->size);
339         if (!buff) {
340                 fprintf(stderr, "malloc() failed: %s\n", strerror(errno));
341                 return -1;
342         }
343         fill_buffer(buff, ctrl->size);
344
345         gettimeofday(&start_time, NULL);
346
347         for (i = 0; i < ctrl->count; i++) {
348                 ret = send_message(sock, buff, ctrl->size, daddr, addrlen);
349                 if (ret != ctrl->size) {
350                         ret = -1;
351                         goto err_free;
352                 }
353         }
354
355         gettimeofday(&stop_time, NULL);
356         elapsed = stop_time.tv_sec - start_time.tv_sec +
357                   ((double) (stop_time.tv_usec - start_time.tv_usec) / 1000000);
358
359         ctrl->tx = (uint64_t) ctrl->count * (uint64_t) ctrl->size;
360         ctrl->tx_sum += ctrl->tx;
361         ctrl->tx_time = elapsed;
362         ctrl->tx_time_sum += elapsed;
363
364         show_stats(ctrl);
365         ret = 0;
366
367 err_free:
368         free(buff);
369         return ret;
370 }
371
372 static int recv_data_messages(int sock, struct netsum_ctrl *ctrl,
373                               struct sockaddr *saddr, socklen_t *addrlen)
374 {
375         unsigned char *buff;
376         unsigned char *template = NULL;
377         bool use_memcmp = !!(ctrl->flags & USE_MEMCMP);
378         int ret = -1;
379         struct timeval start_time, stop_time;
380         uint32_t i;
381         double elapsed;
382
383         buff = malloc(ctrl->size);
384         if (!buff) {
385                 fprintf(stderr, "malloc() failed: %s\n", strerror(errno));
386                 return -1;
387         }
388
389         template = malloc(ctrl->size);
390         if (!template) {
391                 fprintf(stderr, "malloc() failed: %s\n",
392                         strerror(errno));
393                 goto err_free;
394         }
395
396         gettimeofday(&start_time, NULL);
397
398         for (i = 0; i < ctrl->count; i++) {
399                 ret = recv_message(sock, buff, ctrl->size, saddr, addrlen);
400                 if (ret != ctrl->size) {
401                         ret = -1;
402                         goto err_free;
403                 }
404                 if (use_memcmp) {
405                         /*
406                          * For the first buffer received or on error, fall back
407                          * on buffer_is_valid().
408                          */
409                         if (i == 0)
410                                 memcpy(template, buff, ctrl->size);
411                         else if (!memcmp(buff, template, ctrl->size))
412                                 continue;
413                 }
414                 if (!buffer_is_valid(buff, ctrl->size, true)) {
415                         fprintf(stdout,
416                                 "Reference buffer stored in file: ref\n");
417                         fprintf(stdout,
418                                 "Corrupted buffer stored in file: err\n");
419                         write_buffer_to_file("ref", buff, ctrl->size);
420                         write_buffer_to_file("err", template, ctrl->size);
421                         ret = -1;
422                         goto err_free;
423                 }
424         }
425
426         gettimeofday(&stop_time, NULL);
427         elapsed = stop_time.tv_sec - start_time.tv_sec +
428                   ((double) (stop_time.tv_usec - start_time.tv_usec) / 1000000);
429
430         ctrl->rx = (uint64_t) ctrl->count * (uint64_t) ctrl->size;
431         ctrl->rx_sum += ctrl->rx;
432         ctrl->rx_time = elapsed;
433         ctrl->rx_time_sum += elapsed;
434
435         show_stats(ctrl);
436         ret = 0;
437
438 err_free:
439         free(buff);
440         if (template)
441                 free(template);
442
443         return ret;
444 }
445
446 /*
447  * Base.
448  */
449
450 static int handle_client(int server_sock, struct netsum_ctrl *ctrl)
451 {
452         int sock, ret;
453         struct sockaddr_in sin;
454         socklen_t addrlen = sizeof(sin);
455         struct netsum_ctrl_msg ctrl_msg;
456
457         sock = accept(server_sock, (struct sockaddr *) &sin, &addrlen);
458         if (sock == -1) {
459                 fprintf(stderr, "accept() failed: %s\n", strerror(errno));
460                 return -1;
461         }
462         fprintf(stdout, "\nClient %s:%d\n",
463                 inet_ntoa(sin.sin_addr), ntohs(sin.sin_port));
464
465         ctrl->rx = 0;
466         ctrl->rx_sum = 0;
467         ctrl->rx_time = 0;
468         ctrl->rx_time_sum = 0;
469         ctrl->tx = 0;
470         ctrl->tx_sum = 0;
471         ctrl->tx_time = 0;
472         ctrl->tx_time_sum = 0;
473
474         while (1) {
475                 ret = recv_ctrl_message(sock, &ctrl_msg, NULL, NULL);
476                 if (ret)
477                         break;
478
479                 ctrl->count = ntohl(ctrl_msg.count);
480                 ctrl->size = ntohl(ctrl_msg.size);
481                 ctrl->timeout = ntohl(ctrl_msg.timeout);
482
483                 if (ctrl_msg.dir)
484                         ret = send_data_messages(sock, ctrl, NULL, 0);
485                 else
486                         ret = recv_data_messages(sock, ctrl, NULL, NULL);
487                 if (ret)
488                         break;
489         };
490
491         close(sock);
492
493         return ret;
494 }
495
496 static int run_server(struct sockaddr_in *sin, struct netsum_ctrl *ctrl)
497 {
498         int sock;
499
500         sock = open_server_sock(sin, ctrl->timeout, ctrl->size);
501         if (sock == -1)
502                 return -1;
503
504         while(1) { handle_client(sock, ctrl); };
505
506         close(sock);
507         return -1;
508 }
509
510 static int run_client(struct sockaddr_in *sin, struct netsum_ctrl *ctrl)
511 {
512         int sock, ret;
513         struct netsum_ctrl_msg ctrl_msg;
514
515         sock = connect_server(sin, ctrl->timeout, ctrl->size);
516         if (sock == -1)
517                 return -1;
518
519         ctrl->rx = 0;
520         ctrl->rx_sum = 0;
521         ctrl->rx_time = 0;
522         ctrl->rx_time_sum = 0;
523         ctrl->tx = 0;
524         ctrl->tx_sum = 0;
525         ctrl->tx_time = 0;
526         ctrl->tx_time_sum = 0;
527
528         ctrl_msg.count = htonl(ctrl->count);
529         ctrl_msg.size = htonl(ctrl->size);
530         ctrl_msg.timeout = htonl(ctrl->timeout);
531
532         while (1) {
533                 if (ctrl->flags & DIR_TX) {
534                         ctrl_msg.dir = 0;
535                         ret = send_ctrl_message(sock, &ctrl_msg, NULL, 0);
536                         if (ret)
537                                 break;
538                         ret = send_data_messages(sock, ctrl, NULL, 0);
539                         if (ret)
540                                 break;
541                 }
542                 if (ctrl->flags & DIR_RX) {
543                         ctrl_msg.dir = 1;
544                         ret = send_ctrl_message(sock, &ctrl_msg, NULL, 0);
545                         if (ret)
546                                 break;
547                         ret = recv_data_messages(sock, ctrl, NULL, NULL);
548                         if (ret)
549                                 break;
550                 }
551         };
552
553         close(sock);
554
555         return ret;
556 }
557
558 int run_netsum(char *addr, struct netsum_ctrl *ctrl)
559 {
560         struct sockaddr_in sin;
561         int ret;
562
563         memset(&sin, 0, sizeof(sin));
564
565         if (addr) {
566                 if (get_inet_addr(addr, &sin)) {
567                         fprintf(stderr,  
568                                 "Failed to convert %s into an IP address\n", addr);
569                         return -1;
570                 }
571                 ret = addr_is_local(&sin);
572                 if (ret < 0)
573                         return ret;
574                 if (!ret)
575                         return run_client(&sin, ctrl);
576         } else {
577                 sin.sin_family = AF_INET;
578                 sin.sin_port = htons(NETSUM_DEFAULT_PORT);
579                 sin.sin_addr.s_addr = INADDR_ANY;
580         }
581
582         return run_server(&sin, ctrl);
583 }