diff --git a/src/transport/tcp.c b/src/transport/tcp.c index e408865..7f9922a 100644 --- a/src/transport/tcp.c +++ b/src/transport/tcp.c @@ -60,6 +60,8 @@ typedef struct _TCPSocket AppTransportClient * client; int fd; + struct sockaddr * sa; + socklen_t sa_size; /* input queue */ char * bufin; @@ -75,6 +77,7 @@ struct _AppTransportPlugin AppTransportMode mode; struct addrinfo * ai; + socklen_t ai_addrlen; union { @@ -113,8 +116,10 @@ static int _tcp_server_add_client(TCP * tcp, TCPSocket * client); /* sockets */ static int _tcp_socket_init(TCPSocket * tcpsocket, int domain, TCP * tcp); -static void _tcp_socket_init_fd(TCPSocket * tcpsocket, TCP * tcp, int fd); -static TCPSocket * _tcp_socket_new_fd(TCP * tcp, int fd); +static void _tcp_socket_init_fd(TCPSocket * tcpsocket, TCP * tcp, int fd, + struct sockaddr * sa, socklen_t sa_size); +static TCPSocket * _tcp_socket_new_fd(TCP * tcp, int fd, struct sockaddr * sa, + socklen_t sa_size); static void _tcp_socket_delete(TCPSocket * tcpsocket); static void _tcp_socket_destroy(TCPSocket * tcpsocket); @@ -281,6 +286,7 @@ static int _init_client(TCP * tcp, char const * name) tcp->u.client.fd, (EventIOFunc)_tcp_socket_callback_read, &tcp->u.client); + tcp->ai_addrlen = aip->ai_addrlen; break; } freeaddrinfo(tcp->ai); @@ -337,6 +343,7 @@ static int _init_server(TCP * tcp, char const * name) tcp->u.server.fd = -1; continue; } + tcp->ai_addrlen = aip->ai_addrlen; event_register_io_read(tcp->helper->event, tcp->u.server.fd, (EventIOFunc)_tcp_callback_accept, tcp); break; @@ -428,7 +435,7 @@ static int _tcp_socket_init(TCPSocket * tcpsocket, int domain, TCP * tcp) if((tcpsocket->fd = socket(domain, SOCK_STREAM, 0)) < 0) return -_tcp_error("socket", 1); - _tcp_socket_init_fd(tcpsocket, tcp, tcpsocket->fd); + _tcp_socket_init_fd(tcpsocket, tcp, tcpsocket->fd, NULL, 0); /* set the socket as non-blocking */ if((flags = fcntl(tcpsocket->fd, F_GETFL)) == -1) return -_tcp_error("fcntl", 1); @@ -440,11 +447,14 @@ static int _tcp_socket_init(TCPSocket * tcpsocket, int domain, TCP * tcp) /* tcp_socket_init_fd */ -static void _tcp_socket_init_fd(TCPSocket * tcpsocket, TCP * tcp, int fd) +static void _tcp_socket_init_fd(TCPSocket * tcpsocket, TCP * tcp, int fd, + struct sockaddr * sa, socklen_t sa_size) { tcpsocket->tcp = tcp; tcpsocket->client = NULL; tcpsocket->fd = fd; + tcpsocket->sa = sa; + tcpsocket->sa_size = sa_size; tcpsocket->bufin = NULL; tcpsocket->bufin_cnt = 0; tcpsocket->bufout = NULL; @@ -453,13 +463,14 @@ static void _tcp_socket_init_fd(TCPSocket * tcpsocket, TCP * tcp, int fd) /* tcp_socket_new_fd */ -static TCPSocket * _tcp_socket_new_fd(TCP * tcp, int fd) +static TCPSocket * _tcp_socket_new_fd(TCP * tcp, int fd, struct sockaddr * sa, + socklen_t sa_size) { TCPSocket * tcpsocket; if((tcpsocket = object_new(sizeof(*tcpsocket))) == NULL) return NULL; - _tcp_socket_init_fd(tcpsocket, tcp, fd); + _tcp_socket_init_fd(tcpsocket, tcp, fd, sa, sa_size); return tcpsocket; } @@ -514,13 +525,13 @@ static int _tcp_socket_queue(TCPSocket * tcpsocket, Buffer * buffer) /* callbacks */ /* tcp_callback_accept */ -static int _accept_client(TCP * tcp, int fd); +static int _accept_client(TCP * tcp, int fd, struct sockaddr * sa, + socklen_t sa_size); static int _tcp_callback_accept(int fd, TCP * tcp) { - /* FIXME may not be the right type of struct sockaddr */ - struct sockaddr_in sa; - socklen_t sa_size = sizeof(sa); + struct sockaddr * sa; + socklen_t sa_size = tcp->ai_addrlen; #ifdef DEBUG fprintf(stderr, "DEBUG: %s(%d)\n", __func__, fd); @@ -528,12 +539,21 @@ static int _tcp_callback_accept(int fd, TCP * tcp) /* check parameters */ if(tcp->u.server.fd != fd) return -1; - if((fd = accept(fd, (struct sockaddr *)&sa, &sa_size)) < 0) + if((sa = malloc(sa_size)) == NULL) + /* XXX this may not be enough to recover */ + sa_size = 0; + if((fd = accept(fd, sa, &sa_size)) < 0) + { + free(sa); return _tcp_error("accept", 1); - if(_accept_client(tcp, fd) != 0) + } + if(_accept_client(tcp, fd, sa, sa_size) != 0) + { /* just close the connection and keep serving */ /* FIXME report error */ close(fd); + free(sa); + } #ifdef DEBUG else fprintf(stderr, "DEBUG: %s() %d\n", __func__, fd); @@ -541,7 +561,8 @@ static int _tcp_callback_accept(int fd, TCP * tcp) return 0; } -static int _accept_client(TCP * tcp, int fd) +static int _accept_client(TCP * tcp, int fd, struct sockaddr * sa, + socklen_t sa_size) { TCPSocket * tcpsocket; Buffer * buf; @@ -550,7 +571,7 @@ static int _accept_client(TCP * tcp, int fd) #ifdef DEBUG fprintf(stderr, "DEBUG: %s(%d)\n", __func__, fd); #endif - if((tcpsocket = _tcp_socket_new_fd(tcp, fd)) == NULL) + if((tcpsocket = _tcp_socket_new_fd(tcp, fd, sa, sa_size)) == NULL) return -1; /* send the banner */ if((buf = buffer_new(sizeof(banner), banner)) == NULL @@ -587,6 +608,7 @@ static void _tcp_socket_delete(TCPSocket * tcpsocket) /* tcp_socket_destroy */ static void _tcp_socket_destroy(TCPSocket * tcpsocket) { + free(tcpsocket->sa); if(tcpsocket->fd >= 0) { event_unregister_io_read(tcpsocket->tcp->helper->event,