diff --git a/oio-win.c b/oio-win.c index 8e97c68c..c26071f1 100644 --- a/oio-win.c +++ b/oio-win.c @@ -105,30 +105,26 @@ # define SO_UPDATE_CONNECT_CONTEXT 0x7010 #endif -/* - * Described in MSDN but apparently not defined in the SDK. - */ -#ifndef ERROR_SUCCESS -# define ERROR_SUCCESS 0 -#endif - /* * Pointers to winsock extension functions to be retrieved dynamically */ -static LPFN_CONNECTEX pConnectEx; -static LPFN_ACCEPTEX pAcceptEx; -static LPFN_GETACCEPTEXSOCKADDRS pGetAcceptExSockAddrs; -static LPFN_DISCONNECTEX pDisconnectEx; -static LPFN_TRANSMITFILE pTransmitFile; +static LPFN_CONNECTEX pConnectEx; +static LPFN_ACCEPTEX pAcceptEx; +static LPFN_GETACCEPTEXSOCKADDRS pGetAcceptExSockAddrs; +static LPFN_DISCONNECTEX pDisconnectEx; +static LPFN_TRANSMITFILE pTransmitFile; /* * Private oio_handle flags */ -#define OIO_HANDLE_CLOSING 0x01 -#define OIO_HANDLE_CLOSED 0x02 -#define OIO_HANDLE_BOUND 0x04 +#define OIO_HANDLE_CLOSING 0x01 +#define OIO_HANDLE_CLOSED 0x02 +#define OIO_HANDLE_BOUND 0x04 +#define OIO_HANDLE_READING 0x08 +#define OIO_HANDLE_LISTENING 0x10 +#define OIO_HANDLE_BIND_ERROR 0x20 /* * Private oio_req flags. @@ -137,19 +133,6 @@ static LPFN_TRANSMITFILE pTransmitFile; #define OIO_REQ_PENDING 0x01 -/* - * Special oio_req type used by AcceptEx calls - */ -typedef struct oio_accept_req_s { - struct oio_req_s req; - SOCKET socket; - - /* AcceptEx specifies that the buffer must be big enough to at least hold */ - /* two socket addresses plus 32 bytes. */ - char buffer[sizeof(struct sockaddr_storage) * 2 + 32]; -} oio_accept_req; - - /* Binary tree used to keep the list of timers sorted. */ static int oio_timer_compare(oio_req* t1, oio_req* t2); RB_HEAD(oio_timer_s, oio_req_s); @@ -179,6 +162,10 @@ static const oio_err oio_ok_ = { OIO_OK, ERROR_SUCCESS }; static oio_err oio_last_error_ = { OIO_OK, ERROR_SUCCESS }; +/* Global alloc function */ +oio_alloc_cb oio_alloc_ = NULL; + + /* Reference count that keeps the event loop alive */ static int oio_refs_ = 0; @@ -187,6 +174,10 @@ static int oio_refs_ = 0; static struct sockaddr_in oio_addr_ip4_any_; +/* A zero-size buffer for use by oio_read */ +static char oio_zero_[] = ""; + + /* * Display an error message and abort the event loop. */ @@ -285,7 +276,7 @@ static void oio_get_extension_function(SOCKET socket, GUID guid, } -void oio_init() { +void oio_init(oio_alloc_cb alloc_cb) { const GUID wsaid_connectex = WSAID_CONNECTEX; const GUID wsaid_acceptex = WSAID_ACCEPTEX; const GUID wsaid_getacceptexsockaddrs = WSAID_GETACCEPTEXSOCKADDRS; @@ -297,6 +288,8 @@ void oio_init() { LARGE_INTEGER timer_frequency; SOCKET dummy; + oio_alloc_ = alloc_cb; + /* Initialize winsock */ errorno = WSAStartup(MAKEWORD(2, 2), &wsa_data); if (errorno != 0) { @@ -365,12 +358,19 @@ static int oio_set_socket_options(SOCKET socket) { /* Set the SO_REUSEADDR option on the socket */ /* If it fails, soit. */ - setsockopt(socket, + 0&&setsockopt(socket, SOL_SOCKET, SO_REUSEADDR, (char*)&yes, sizeof(int)); + /* Set the socket to nonblocking mode */ + if (ioctlsocket(socket, FIONBIO, &yes) == SOCKET_ERROR) { + oio_set_sys_error(WSAGetLastError()); + return -1; + } + + /* Make the socket non-inheritable */ if (!SetHandleInformation((HANDLE)socket, HANDLE_FLAG_INHERIT, 0)) { oio_set_sys_error(GetLastError()); @@ -399,8 +399,9 @@ int oio_tcp_init(oio_handle* handle, oio_close_cb close_cb, handle->flags = 0; handle->reqs_pending = 0; handle->error = oio_ok_; - handle->accept_reqs = NULL; - handle->accepted_socket = INVALID_SOCKET; + handle->accept_socket = INVALID_SOCKET; + + oio_req_init(&(handle->read_accept_req), handle, NULL); handle->socket = socket(AF_INET, SOCK_STREAM, 0); if (handle->socket == INVALID_SOCKET) { @@ -421,7 +422,7 @@ int oio_tcp_init(oio_handle* handle, oio_close_cb close_cb, int oio_accept(oio_handle* server, oio_handle* client, oio_close_cb close_cb, void* data) { - if (server->accepted_socket == INVALID_SOCKET) { + if (server->accept_socket == INVALID_SOCKET) { oio_set_sys_error(WSAENOTCONN); return -1; } @@ -429,16 +430,18 @@ int oio_accept(oio_handle* server, oio_handle* client, client->close_cb = close_cb; client->data = data; client->type = OIO_TCP; - client->socket = server->accepted_socket; + client->socket = server->accept_socket; client->flags = 0; client->reqs_pending = 0; client->error = oio_ok_; - client->accepted_socket = INVALID_SOCKET; - client->accept_reqs = NULL; + client->accept_socket = INVALID_SOCKET; + + oio_req_init(&(client->read_accept_req), client, NULL); - server->accepted_socket = INVALID_SOCKET; oio_refs_++; + server->accept_socket = INVALID_SOCKET; + return 0; } @@ -500,12 +503,9 @@ static void oio_call_close_cbs() { handle->flags |= OIO_HANDLE_CLOSED; oio_refs_--; - if (handle->accept_reqs) { - free(handle->accept_reqs); - } if (handle->close_cb) { - oio_last_error_ = handle->error; - handle->close_cb(handle, handle->error.code == OIO_OK ? 0 : 1); + oio_last_error_ = handle->error; + handle->close_cb(handle, handle->error.code == OIO_OK ? 0 : 1); } } } @@ -524,6 +524,7 @@ struct sockaddr_in oio_ip4_addr(char* ip, int port) { int oio_bind(oio_handle* handle, struct sockaddr* addr) { int addrsize; + DWORD err; if (addr->sa_family == AF_INET) { addrsize = sizeof(struct sockaddr_in); @@ -535,8 +536,15 @@ int oio_bind(oio_handle* handle, struct sockaddr* addr) { } if (bind(handle->socket, addr, addrsize) == SOCKET_ERROR) { - oio_set_sys_error(WSAGetLastError()); - return -1; + err = WSAGetLastError(); + if (err == WSAEADDRINUSE) { + /* Some errors are not to be reported until connect() or listen() */ + handle->error = oio_new_sys_error(err); + handle->flags |= OIO_HANDLE_BIND_ERROR; + } else { + oio_set_sys_error(err); + return -1; + } } handle->flags |= OIO_HANDLE_BOUND; @@ -545,83 +553,145 @@ int oio_bind(oio_handle* handle, struct sockaddr* addr) { } -static void oio_queue_accept(oio_accept_req* areq, oio_handle* handle) { +static void oio_queue_accept(oio_handle* handle) { + oio_req* req; BOOL success; DWORD bytes; + SOCKET accept_socket; - areq->socket = socket(AF_INET, SOCK_STREAM, 0); - if (areq->socket == INVALID_SOCKET) { + assert(handle->flags & OIO_HANDLE_LISTENING); + assert(handle->accept_socket == INVALID_SOCKET); + + accept_socket = socket(AF_INET, SOCK_STREAM, 0); + if (accept_socket == INVALID_SOCKET) { oio_close_error(handle, oio_new_sys_error(WSAGetLastError())); return; } - if (oio_set_socket_options(areq->socket) != 0) { - closesocket(areq->socket); + if (oio_set_socket_options(accept_socket) != 0) { + closesocket(accept_socket); oio_close_error(handle, oio_last_error_); return; } /* Prepare the oio_req and OVERLAPPED structures. */ - assert(!(areq->req.flags & OIO_REQ_PENDING)); - areq->req.flags |= OIO_REQ_PENDING; - memset(&areq->req.overlapped, 0, sizeof(areq->req.overlapped)); + req = &handle->read_accept_req; + assert(!(req->flags & OIO_REQ_PENDING)); + req->type = OIO_ACCEPT; + req->flags |= OIO_REQ_PENDING; + memset(&(req->overlapped), 0, sizeof(req->overlapped)); success = pAcceptEx(handle->socket, - areq->socket, - (void*)&areq->buffer, + accept_socket, + (void*)&handle->accept_buffer, 0, sizeof(struct sockaddr_storage), sizeof(struct sockaddr_storage), &bytes, - &areq->req.overlapped); + &req->overlapped); if (!success && WSAGetLastError() != ERROR_IO_PENDING) { oio_set_sys_error(WSAGetLastError()); /* destroy the preallocated client handle */ - closesocket(areq->socket); + closesocket(accept_socket); /* destroy ourselves */ oio_close_error(handle, oio_last_error_); return; } + handle->accept_socket = accept_socket; + + handle->reqs_pending++; + req->flags |= OIO_REQ_PENDING; +} + + +void oio_queue_read(oio_handle* handle) { + oio_req *req; + oio_buf buf; + int result; + DWORD bytes, flags; + + assert(handle->flags & OIO_HANDLE_READING); + + req = &handle->read_accept_req; + assert(!(req->flags & OIO_REQ_PENDING)); + memset(&req->overlapped, 0, sizeof(req->overlapped)); + req->type = OIO_READ; + + buf.base = (char*) &oio_zero_; + buf.len = 0; + + flags = 0; + result = WSARecv(handle->socket, + (WSABUF*)&buf, + 1, + &bytes, + &flags, + &req->overlapped, + NULL); + if (result != 0 && WSAGetLastError() != ERROR_IO_PENDING) { + oio_set_sys_error(WSAGetLastError()); + oio_close_error(handle, oio_last_error_); + return; + } + + req->flags |= OIO_REQ_PENDING; handle->reqs_pending++; - areq->req.flags |= OIO_REQ_PENDING; } int oio_listen(oio_handle* handle, int backlog, oio_accept_cb cb) { - oio_accept_req* areq; - oio_accept_req* reqs; - int i; - assert(backlog > 0); - if (handle->accept_reqs != NULL) { + if (handle->flags & OIO_HANDLE_BIND_ERROR) { + oio_last_error_ = handle->error; + return -1; + } + + if (handle->flags & OIO_HANDLE_LISTENING || + handle->flags & OIO_HANDLE_READING) { /* Already listening. */ oio_set_sys_error(WSAEALREADY); return -1; } - reqs = (oio_accept_req*)malloc(sizeof(oio_accept_req) * backlog); - if (!reqs) { - oio_set_sys_error(ERROR_OUTOFMEMORY); - return -1; - } - if (listen(handle->socket, backlog) == SOCKET_ERROR) { oio_set_sys_error(WSAGetLastError()); - free(reqs); return -1; } - for (i = backlog, areq = reqs; i > 0; i--, areq++) { - areq->socket = INVALID_SOCKET; - oio_req_init((oio_req*)areq, handle, (void*)cb); - areq->req.type = OIO_ACCEPT; - oio_queue_accept(areq, handle); + handle->flags |= OIO_HANDLE_LISTENING; + handle->accept_cb = cb; + + oio_queue_accept(handle); + + return 0; +} + + +int oio_read_start(oio_handle* handle, oio_read_cb cb) { + if (handle->flags & OIO_HANDLE_LISTENING || + handle->flags & OIO_HANDLE_READING) { + /* Already listening. */ + oio_set_sys_error(WSAEALREADY); + return -1; } - handle->accept_reqs = (oio_accept_req*)reqs; + handle->flags |= OIO_HANDLE_READING; + handle->read_cb = cb; + + /* If reading was stopped and then started again, there could stell be a */ + /* read request pending. */ + if (!handle->read_accept_req.flags & OIO_REQ_PENDING) + oio_queue_read(handle); + + return 0; +} + + +int oio_read_stop(oio_handle* handle) { + handle->flags &= ~OIO_HANDLE_READING; return 0; } @@ -635,6 +705,11 @@ int oio_connect(oio_req* req, struct sockaddr* addr) { assert(!(req->flags & OIO_REQ_PENDING)); + if (handle->flags & OIO_HANDLE_BIND_ERROR) { + oio_last_error_ = handle->error; + return -1; + } + if (addr->sa_family == AF_INET) { addrsize = sizeof(struct sockaddr_in); if (!(handle->flags & OIO_HANDLE_BOUND) && @@ -701,36 +776,6 @@ int oio_write(oio_req* req, oio_buf* bufs, int bufcnt) { } -int oio_read(oio_req* req, oio_buf* bufs, int bufcnt) { - int result; - DWORD bytes, flags; - oio_handle* handle = req->handle; - - assert(!(req->flags & OIO_REQ_PENDING)); - - memset(&req->overlapped, 0, sizeof(req->overlapped)); - req->type = OIO_READ; - - flags = 0; - result = WSARecv(handle->socket, - (WSABUF*)bufs, - bufcnt, - &bytes, - &flags, - &req->overlapped, - NULL); - if (result != 0 && WSAGetLastError() != ERROR_IO_PENDING) { - oio_set_sys_error(WSAGetLastError()); - return -1; - } - - req->flags |= OIO_REQ_PENDING; - handle->reqs_pending++; - - return 0; -} - - static int oio_timer_compare(oio_req* a, oio_req* b) { if (a->due < b->due) return -1; @@ -785,9 +830,11 @@ static void oio_poll() { ULONG_PTR key; OVERLAPPED* overlapped; oio_req* req; - oio_accept_req* accept_req; oio_handle* handle; + oio_buf buf; DWORD timeout; + DWORD flags; + DWORD err; int64_t delta; /* Call all pending close callbacks. */ @@ -860,42 +907,78 @@ static void oio_poll() { oio_set_sys_error(GetLastError()); oio_close_error(handle, oio_last_error_); } - if (req->cb) { - ((oio_read_cb)req->cb)(req, bytes, success ? 0 : -1); + while (handle->flags & OIO_HANDLE_READING) { + buf = oio_alloc_(handle, 65536); + assert(buf.len > 0); + flags = 0; + if (WSARecv(handle->socket, + (WSABUF*)&buf, + 1, + &bytes, + &flags, + NULL, + NULL) != SOCKET_ERROR) { + if (bytes > 0) { + /* Successful read */ + ((oio_read_cb)handle->read_cb)(handle, bytes, buf); + /* Read again only if bytes == buf.len */ + if (bytes < buf.len) { + break; + } + } else { + /* Connection closed */ + handle->flags &= ~OIO_HANDLE_READING; + oio_last_error_.code = OIO_EOF; + oio_last_error_.sys_errno_ = ERROR_SUCCESS; + ((oio_read_cb)handle->read_cb)(handle, -1, buf); + } + } else { + err = WSAGetLastError(); + if (err == WSAEWOULDBLOCK) { + /* 0-byte read */ + ((oio_read_cb)handle->read_cb)(handle, 0, buf); + } else { + /* Ouch! serious error. */ + oio_set_sys_error(err); + oio_close_error(handle, oio_last_error_); + } + break; + } + } + /* Post another 0-read if still reading and not closing */ + if (!(handle->flags & OIO_HANDLE_CLOSING) && + handle->flags & OIO_HANDLE_READING) { + oio_queue_read(handle); } break; case OIO_ACCEPT: - accept_req = (oio_accept_req*)req; - assert(accept_req->socket != INVALID_SOCKET); - assert(handle->accepted_socket == INVALID_SOCKET); - - handle->accepted_socket = accept_req->socket; + assert(handle->accept_socket != INVALID_SOCKET); success = GetOverlappedResult(handle->handle, overlapped, &bytes, FALSE); if (success) { - if (setsockopt(handle->accepted_socket, + if (setsockopt(handle->accept_socket, SOL_SOCKET, SO_UPDATE_ACCEPT_CONTEXT, (char*)&handle->socket, sizeof(handle->socket)) == 0) { - if (req->cb) { - ((oio_accept_cb)req->cb)(handle); + if (handle->accept_cb) { + ((oio_accept_cb)handle->accept_cb)(handle); } } } - /* accept_cb should call oio_accept_handle which sets data->socket */ + /* accept_cb should call oio_accept which sets handle->accept_socket */ /* to INVALID_SOCKET. */ /* Errorneous accept is ignored if the listen socket is still healthy. */ - if (handle->accepted_socket != INVALID_SOCKET) { - closesocket(handle->accepted_socket); - handle->accepted_socket = INVALID_SOCKET; + if (handle->accept_socket != INVALID_SOCKET) { + closesocket(handle->accept_socket); + handle->accept_socket = INVALID_SOCKET; } /* Queue another accept */ - if (!handle->flags & OIO_HANDLE_CLOSING) - oio_queue_accept(accept_req, handle); + if (!(handle->flags & OIO_HANDLE_CLOSING)) + oio_queue_accept(handle); break; case OIO_CONNECT: diff --git a/oio-win.h b/oio-win.h index e5cf66d1..4cd580d6 100644 --- a/oio-win.h +++ b/oio-win.h @@ -59,8 +59,13 @@ typedef struct oio_buf { SOCKET socket; \ HANDLE handle; \ }; \ - SOCKET accepted_socket; \ - struct oio_accept_req_s* accept_reqs; \ + union { \ + void *read_cb; \ + void *accept_cb; \ + }; \ + struct oio_req_s read_accept_req; \ + char accept_buffer[sizeof(struct sockaddr_storage) * 2 + 32]; \ + SOCKET accept_socket; \ unsigned int flags; \ unsigned int reqs_pending; \ oio_err error;