diff --git a/lib/vtls/rustls.c b/lib/vtls/rustls.c index c4fb432dd7..9a7538b62e 100644 --- a/lib/vtls/rustls.c +++ b/lib/vtls/rustls.c @@ -46,7 +46,8 @@ struct rustls_ssl_backend_data { const struct rustls_client_config *config; struct rustls_connection *conn; - bool data_pending; + size_t plain_out_buffered; + BIT(data_in_pending); }; /* For a given rustls_result error code, return the best-matching CURLcode. */ @@ -74,7 +75,7 @@ cr_data_pending(struct Curl_cfilter *cf, const struct Curl_easy *data) (void)data; DEBUGASSERT(ctx && ctx->backend); backend = (struct rustls_ssl_backend_data *)ctx->backend; - return backend->data_pending; + return backend->data_in_pending; } struct io_ctx { @@ -101,6 +102,8 @@ read_cb(void *userdata, uint8_t *buf, uintptr_t len, uintptr_t *out_n) else if(nread == 0) connssl->peer_closed = TRUE; *out_n = (int)nread; + CURL_TRC_CF(io_ctx->data, io_ctx->cf, "cf->next recv(len=%zu) -> %zd, %d", + len, nread, result); return ret; } @@ -120,10 +123,8 @@ write_cb(void *userdata, const uint8_t *buf, uintptr_t len, uintptr_t *out_n) ret = EINVAL; } *out_n = (int)nwritten; - /* - CURL_TRC_CFX(io_ctx->data, io_ctx->cf, "cf->next send(len=%zu) -> %zd, %d", - len, nwritten, result)); - */ + CURL_TRC_CF(io_ctx->data, io_ctx->cf, "cf->next send(len=%zu) -> %zd, %d", + len, nwritten, result); return ret; } @@ -165,7 +166,7 @@ static ssize_t tls_recv_more(struct Curl_cfilter *cf, return -1; } - backend->data_pending = TRUE; + backend->data_in_pending = TRUE; *err = CURLE_OK; return (ssize_t)tls_bytes_read; } @@ -200,7 +201,7 @@ cr_recv(struct Curl_cfilter *cf, struct Curl_easy *data, rconn = backend->conn; while(plain_bytes_copied < plainlen) { - if(!backend->data_pending) { + if(!backend->data_in_pending) { if(tls_recv_more(cf, data, err) < 0) { if(*err != CURLE_AGAIN) { nread = -1; @@ -215,7 +216,7 @@ cr_recv(struct Curl_cfilter *cf, struct Curl_easy *data, plainlen - plain_bytes_copied, &n); if(rresult == RUSTLS_RESULT_PLAINTEXT_EMPTY) { - backend->data_pending = FALSE; + backend->data_in_pending = FALSE; } else if(rresult == RUSTLS_RESULT_UNEXPECTED_EOF) { failf(data, "rustls: peer closed TCP connection " @@ -265,6 +266,42 @@ out: return nread; } +static CURLcode cr_flush_out(struct Curl_cfilter *cf, struct Curl_easy *data, + struct rustls_connection *rconn) +{ + struct io_ctx io_ctx; + rustls_io_result io_error; + size_t tlswritten = 0; + size_t tlswritten_total = 0; + CURLcode result = CURLE_OK; + + io_ctx.cf = cf; + io_ctx.data = data; + + while(rustls_connection_wants_write(rconn)) { + io_error = rustls_connection_write_tls(rconn, write_cb, &io_ctx, + &tlswritten); + if(io_error == EAGAIN || io_error == EWOULDBLOCK) { + CURL_TRC_CF(data, cf, "cf_send: EAGAIN after %zu bytes", + tlswritten_total); + return CURLE_AGAIN; + } + else if(io_error) { + char buffer[STRERROR_LEN]; + failf(data, "writing to socket: %s", + Curl_strerror(io_error, buffer, sizeof(buffer))); + return CURLE_SEND_ERROR; + } + if(tlswritten == 0) { + failf(data, "EOF in swrite"); + return CURLE_SEND_ERROR; + } + CURL_TRC_CF(data, cf, "cf_send: wrote %zu TLS bytes", tlswritten); + tlswritten_total += tlswritten; + } + return result; +} + /* * On each call: * - Copy `plainlen` bytes into rustls' plaintext input buffer (if > 0). @@ -283,26 +320,42 @@ cr_send(struct Curl_cfilter *cf, struct Curl_easy *data, struct rustls_ssl_backend_data *const backend = (struct rustls_ssl_backend_data *)connssl->backend; struct rustls_connection *rconn = NULL; - struct io_ctx io_ctx; size_t plainwritten = 0; - size_t tlswritten = 0; - size_t tlswritten_total = 0; rustls_result rresult; - rustls_io_result io_error; char errorbuf[256]; size_t errorlen; + const unsigned char *buf = plainbuf; + size_t blen = plainlen; + ssize_t nwritten = 0; DEBUGASSERT(backend); rconn = backend->conn; - CURL_TRC_CF(data, cf, "cf_send: %zu plain bytes", plainlen); + CURL_TRC_CF(data, cf, "cf_send(len=%zu)", plainlen); - io_ctx.cf = cf; - io_ctx.data = data; + /* If a previous send blocked, we already added its plain bytes + * to rustsls and must not do that again. Flush the TLS bytes and, + * if successful, deduct the previous plain bytes from the current + * send. */ + if(backend->plain_out_buffered) { + *err = cr_flush_out(cf, data, rconn); + CURL_TRC_CF(data, cf, "cf_send: flushing %zu previously added bytes -> %d", + backend->plain_out_buffered, *err); + if(*err) + return -1; + if(blen > backend->plain_out_buffered) { + blen -= backend->plain_out_buffered; + buf += backend->plain_out_buffered; + } + else + blen = 0; + nwritten += (ssize_t)backend->plain_out_buffered; + backend->plain_out_buffered = 0; + } - if(plainlen > 0) { - rresult = rustls_connection_write(rconn, plainbuf, plainlen, - &plainwritten); + if(blen > 0) { + CURL_TRC_CF(data, cf, "cf_send: adding %zu plain bytes to rustls", blen); + rresult = rustls_connection_write(rconn, buf, blen, &plainwritten); if(rresult != RUSTLS_RESULT_OK) { rustls_error(rresult, errorbuf, sizeof(errorbuf), &errorlen); failf(data, "rustls_connection_write: %.*s", (int)errorlen, errorbuf); @@ -316,32 +369,27 @@ cr_send(struct Curl_cfilter *cf, struct Curl_easy *data, } } - while(rustls_connection_wants_write(rconn)) { - io_error = rustls_connection_write_tls(rconn, write_cb, &io_ctx, - &tlswritten); - if(io_error == EAGAIN || io_error == EWOULDBLOCK) { - CURL_TRC_CF(data, cf, "cf_send: EAGAIN after %zu bytes", - tlswritten_total); - *err = CURLE_AGAIN; - return -1; + *err = cr_flush_out(cf, data, rconn); + if(*err) { + if(CURLE_AGAIN == *err) { + /* The TLS bytes may have been partially written, but we fail the + * complete send() and remember how much we already added to rustls. */ + CURL_TRC_CF(data, cf, "cf_send: EAGAIN, remember we added %zu plain" + " bytes already to rustls", blen); + backend->plain_out_buffered = plainwritten; + if(nwritten) { + *err = CURLE_OK; + return (ssize_t)nwritten; + } } - else if(io_error) { - char buffer[STRERROR_LEN]; - failf(data, "writing to socket: %s", - Curl_strerror(io_error, buffer, sizeof(buffer))); - *err = CURLE_WRITE_ERROR; - return -1; - } - if(tlswritten == 0) { - failf(data, "EOF in swrite"); - *err = CURLE_WRITE_ERROR; - return -1; - } - CURL_TRC_CF(data, cf, "cf_send: wrote %zu TLS bytes", tlswritten); - tlswritten_total += tlswritten; + return -1; } + else + nwritten += (ssize_t)plainwritten; - return plainwritten; + CURL_TRC_CF(data, cf, "cf_send(len=%zu) -> %d, %zd", + plainlen, *err, nwritten); + return nwritten; } /* A server certificate verify callback for rustls that always returns