diff --git a/ol-win.c b/ol-win.c index 3b05b08c..b46c2469 100644 --- a/ol-win.c +++ b/ol-win.c @@ -69,21 +69,26 @@ LPTRANSMIT_FILE_BUFFERS lpTransmitBuffers, DWORD dwFlags); #endif - /* - * Private ol_req flags. + * Private ol_handle flags + */ +#define OL_HANDLE_CLOSING 0x01 +#define OL_HANDLE_CLOSED 0x03 + +/* + * Private ol_req flags. */ /* The request is currently queued. */ -#define OL_REQ_PENDING 0x01 +#define OL_REQ_PENDING 0x01 /* When STRAY is set, that means that the handle owning the ol_req */ /* struct was destroyed while the old_req was queued to an iocp */ -#define OL_REQ_STRAY 0x02 +#define OL_REQ_STRAY 0x02 /* When INTERNAL is set that means that the ol_req struct was */ /* allocated by libol, so libol also needs to free it again */ -#define OL_REQ_INTERNAL 0x04 +#define OL_REQ_INTERNAL 0x04 /* * Pointers to winsock extension functions that have to be retrieved dynamically @@ -95,7 +100,7 @@ LPFN_DISCONNECTEX pDisconnectEx; LPFN_TRANSMITFILE pTransmitFile; /* - * Global I/O completion port + * Global I/O completion port */ HANDLE ol_iocp_; @@ -104,6 +109,10 @@ HANDLE ol_iocp_; int ol_errno_; +/* Reference count that keeps the event loop alive */ +int ol_refs_ = 0; + + /* * Display an error message and abort the event loop. */ @@ -120,7 +129,7 @@ void ol_fatal_error(const int errorno, const char *syscall) { } else { errmsg = "Unknown error"; } - + /* FormatMessage messages include a newline character already, */ /* so don't add another. */ if (syscall) { @@ -219,10 +228,10 @@ int ol_set_socket_options(ol_handle *handle) { /* Set the SO_REUSEADDR option on the socket */ /* If it fails, soit. */ - setsockopt(handle->_.socket, - SOL_SOCKET, - SO_REUSEADDR, - (char*)&yes, + setsockopt(handle->_.socket, + SOL_SOCKET, + SO_REUSEADDR, + (char*)&yes, sizeof(int)); /* Make the socket non-inheritable */ @@ -233,9 +242,9 @@ int ol_set_socket_options(ol_handle *handle) { /* Associate it with the I/O completion port. */ /* Use ol_handle pointer as completion key. */ - if (CreateIoCompletionPort(handle->_.handle, - ol_iocp_, - (ULONG_PTR)handle, + if (CreateIoCompletionPort(handle->_.handle, + ol_iocp_, + (ULONG_PTR)handle, 0) == NULL) { ol_errno_ = GetLastError(); return -1; @@ -252,7 +261,9 @@ ol_handle* ol_tcp_handle_new(ol_close_cb close_cb, void* data) { handle->close_cb = close_cb; handle->data = data; handle->type = OL_TCP; - handle->_.accept_req = NULL; + handle->_.flags = 0; + handle->_.reqs_pending = 0; + handle->_.error = 0; handle->_.socket = socket(AF_INET, SOCK_STREAM, 0); if (handle->_.socket == INVALID_SOCKET) { @@ -266,23 +277,33 @@ ol_handle* ol_tcp_handle_new(ol_close_cb close_cb, void* data) { free(handle); return NULL; } - + + ol_refs_++; + return handle; } -int ol_close_error(ol_handle* handle, ol_err error) { +int ol_close_error(ol_handle* handle, ol_err e) { + if (handle->_.flags & OL_HANDLE_CLOSING) + return 0; + + handle->_.error = e; + switch (handle->type) { case OL_TCP: - if (handle->_.accept_req) { - if (handle->_.accept_req->_.flags & OL_REQ_PENDING) { - handle->_.accept_req->_.flags |= OL_REQ_STRAY; - } else { - free(handle->_.accept_req); - } + closesocket(handle->_.socket); + if (handle->_.reqs_pending != 0) { + /* Cannot free the handle right now because there are queued */ + /* operations. Close the socket, wait for for all packets to come */ + /* out, then have ol_poll call close_cb. */ + handle->_.flags |= OL_HANDLE_CLOSING; + } else { + /* There are no pending operations. Call the close callback now. */ + handle->_.flags |= OL_HANDLE_CLOSED; + if (handle->close_cb) + handle->close_cb(handle, e); } - if (closesocket(handle->_.socket) == SOCKET_ERROR) - return -1; return 0; default: @@ -300,6 +321,7 @@ int ol_close(ol_handle* handle) { void ol_free(ol_handle* handle) { free(handle); + ol_refs_--; } @@ -335,10 +357,10 @@ int ol_bind(ol_handle* handle, struct sockaddr* addr) { } -void ol_queue_accept(ol_handle *handle) { +void ol_queue_accept(ol_handle *handle, ol_req *req) { ol_handle* peer; void *buffer; - ol_req *req; + BOOL success; DWORD bytes; peer = ol_tcp_handle_new(NULL, NULL); @@ -353,31 +375,31 @@ void ol_queue_accept(ol_handle *handle) { buffer = malloc(sizeof(struct sockaddr_storage) * 2 + 32); /* Prepare the ol_req and OVERLAPPED structures. */ - req = handle->_.accept_req; assert(!(req->_.flags & OL_REQ_PENDING)); req->_.flags |= OL_REQ_PENDING; req->data = (void*)peer; memset(&req->_.overlapped, 0, sizeof(req->_.overlapped)); - if (!pAcceptEx(handle->_.socket, - peer->_.socket, - buffer, - 0, - sizeof(struct sockaddr_storage), - sizeof(struct sockaddr_storage), - &bytes, - &req->_.overlapped)) { - if (WSAGetLastError() != ERROR_IO_PENDING) { - ol_errno_ = WSAGetLastError(); - /* destroy the preallocated client handle */ - ol_close(peer); - ol_free(peer); - /* destroy ourselves */ - ol_close_error(handle, ol_errno_); - return; - } + success = pAcceptEx(handle->_.socket, + peer->_.socket, + buffer, + 0, + sizeof(struct sockaddr_storage), + sizeof(struct sockaddr_storage), + &bytes, + &req->_.overlapped); + + if (!success && WSAGetLastError() != ERROR_IO_PENDING) { + ol_errno_ = WSAGetLastError(); + /* destroy the preallocated client handle */ + ol_close(peer); + ol_free(peer); + /* destroy ourselves */ + ol_close_error(handle, ol_errno_); + return; } + handle->_.reqs_pending++; req->_.flags |= OL_REQ_PENDING; } @@ -389,20 +411,20 @@ int ol_listen(ol_handle* handle, int backlog, ol_accept_cb cb) { return -1; handle->accept_cb = cb; - req = (ol_req*)malloc(sizeof(handle->_.accept_req)); - handle->_.accept_req = req; - handle->_.accept_req->type = OL_ACCEPT; - handle->_.accept_req->_.flags = OL_REQ_INTERNAL; - - ol_queue_accept(handle); + req = (ol_req*)malloc(sizeof(*req)); + req->type = OL_ACCEPT; + req->handle = handle; + req->_.flags = OL_REQ_INTERNAL; + + ol_queue_accept(handle, req); return 0; } int ol_connect(ol_handle* handle, ol_req *req, struct sockaddr* addr) { - int addrsize; - int result; + int addrsize; + BOOL success; DWORD bytes; assert(!(req->_.flags & OL_REQ_PENDING)); @@ -420,20 +442,21 @@ int ol_connect(ol_handle* handle, ol_req *req, struct sockaddr* addr) { req->handle = handle; req->type = OL_CONNECT; - result = pConnectEx(handle->_.socket, - addr, - addrsize, - NULL, - 0, - &bytes, - &req->_.overlapped); + success = pConnectEx(handle->_.socket, + addr, + addrsize, + NULL, + 0, + &bytes, + &req->_.overlapped); - if (result != 0 && WSAGetLastError() != ERROR_IO_PENDING) { + if (!success && WSAGetLastError() != ERROR_IO_PENDING) { ol_errno_ = WSAGetLastError(); return -1; } req->_.flags |= OL_REQ_PENDING; + handle->_.reqs_pending++; return 0; } @@ -449,12 +472,12 @@ int ol_write(ol_handle* handle, ol_req *req, ol_buf* bufs, int bufcnt) { req->handle = handle; req->type = OL_WRITE; - result = WSASend(handle->_.socket, - (WSABUF*)bufs, - bufcnt, - &bytes, - 0, - &req->_.overlapped, + result = WSASend(handle->_.socket, + (WSABUF*)bufs, + bufcnt, + &bytes, + 0, + &req->_.overlapped, NULL); if (result != 0 && WSAGetLastError() != ERROR_IO_PENDING) { ol_errno_ = WSAGetLastError(); @@ -462,10 +485,12 @@ int ol_write(ol_handle* handle, ol_req *req, ol_buf* bufs, int bufcnt) { } req->_.flags |= OL_REQ_PENDING; + handle->_.reqs_pending++; return 0; } + int ol_read(ol_handle* handle, ol_req *req, ol_buf* bufs, int bufcnt) { int result; DWORD bytes, flags; @@ -475,14 +500,14 @@ int ol_read(ol_handle* handle, ol_req *req, ol_buf* bufs, int bufcnt) { memset(&req->_.overlapped, 0, sizeof(req->_.overlapped)); req->handle = handle; req->type = OL_READ; - + flags = 0; - result = WSARecv(handle->_.socket, - (WSABUF*)bufs, - bufcnt, - &bytes, - &flags, - &req->_.overlapped, + result = WSARecv(handle->_.socket, + (WSABUF*)bufs, + bufcnt, + &bytes, + &flags, + &req->_.overlapped, NULL); if (result != 0 && WSAGetLastError() != ERROR_IO_PENDING) { ol_errno_ = WSAGetLastError(); @@ -490,6 +515,7 @@ int ol_read(ol_handle* handle, ol_req *req, ol_buf* bufs, int bufcnt) { } req->_.flags |= OL_REQ_PENDING; + handle->_.reqs_pending++; return 0; } @@ -505,11 +531,16 @@ int ol_write2(ol_handle* handle, const char* msg) { buf.base = (char*)msg; buf.len = strlen(msg); - + return ol_write(handle, req, &buf, 1); } +ol_err ol_last_error() { + return ol_errno_; +} + + void ol_poll() { BOOL success; DWORD bytes; @@ -518,7 +549,8 @@ void ol_poll() { ol_req* req; ol_handle* handle; ol_handle *peer; - + int free_req; + success = GetQueuedCompletionStatus(ol_iocp_, &bytes, &key, @@ -529,35 +561,47 @@ void ol_poll() { ol_fatal_error(GetLastError(), "GetQueuedCompletionStatus"); req = ol_overlapped_to_req(overlapped); + handle = req->handle; /* Mark the request non-pending */ req->_.flags &= ~OL_REQ_PENDING; + handle->_.reqs_pending--; + + /* Cache this value, because when the req is not internal the callback */ + /* might free the req structure, so we cannot look at the flags field */ + /* after the callback has been called. */ + free_req = req->_.flags & OL_REQ_INTERNAL; /* If the related socket got closed in the meantime, disregard this */ - /* result. If necessary free the request */ - if (req->_.flags & OL_REQ_STRAY) { + /* result. If it is an internal request, free it. If this is the last */ + /* request pending, close the handle's close callback. */ + if (handle->_.flags & OL_HANDLE_CLOSING) { if (req->type == OL_ACCEPT) { peer = (ol_handle*)req->data; ol_close(peer); ol_free(peer); } - if (req->_.flags & OL_REQ_INTERNAL) { - /* Free it */ + if (free_req) { free(req); } + if (handle->_.reqs_pending == 0) { + handle->_.flags |= OL_HANDLE_CLOSED; + if (handle->close_cb) + handle->close_cb(handle, handle->_.error); + ol_refs_--; + } return; } switch (req->type) { case OL_WRITE: - handle = (ol_handle*)key; success = GetOverlappedResult(handle->_.handle, overlapped, &bytes, FALSE); if (!success) { ol_close_error(handle, GetLastError()); } else if (req->cb) { ((ol_write_cb)req->cb)(req); - } - if (req->_.flags & OL_REQ_INTERNAL) { + } + if (free_req) { free(req); } return; @@ -570,7 +614,7 @@ void ol_poll() { } else if (req->cb) { ((ol_read_cb)req->cb)(req, bytes); } - if (req->_.flags & OL_REQ_INTERNAL) { + if (free_req) { free(req); } break; @@ -588,7 +632,7 @@ void ol_poll() { } /* Queue another accept */ - ol_queue_accept(handle); + ol_queue_accept(handle, req); return; case OL_CONNECT: @@ -601,7 +645,7 @@ void ol_poll() { ((ol_connect_cb)req->cb)(req, GetLastError()); } } - if (req->_.flags & OL_REQ_INTERNAL) { + if (free_req) { free(req); } return; @@ -610,8 +654,9 @@ void ol_poll() { int ol_run() { - for (;;) { + while (ol_refs_ > 0) { ol_poll(); } + assert(ol_refs_ == 0); return 0; } diff --git a/ol-win.h b/ol-win.h index a18f220c..edaad484 100644 --- a/ol-win.h +++ b/ol-win.h @@ -25,6 +25,8 @@ typedef struct { SOCKET socket; HANDLE handle; }; - struct ol_req_s* accept_req; + unsigned int flags; + unsigned int reqs_pending; + ol_err error; } ol_handle_private;