diff --git a/seaweed-volume/Cargo.lock b/seaweed-volume/Cargo.lock index ce365ba8d..03da9aae7 100644 --- a/seaweed-volume/Cargo.lock +++ b/seaweed-volume/Cargo.lock @@ -4120,6 +4120,16 @@ dependencies = [ "windows-sys 0.61.2", ] +[[package]] +name = "tokio-io-timeout" +version = "1.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0bd86198d9ee903fedd2f9a2e72014287c0d9167e4ae43b5853007205dda1b76" +dependencies = [ + "pin-project-lite", + "tokio", +] + [[package]] name = "tokio-macros" version = "2.6.1" @@ -4752,6 +4762,7 @@ dependencies = [ "tempfile", "thiserror 1.0.69", "tokio", + "tokio-io-timeout", "tokio-rustls 0.26.4", "tokio-stream", "toml", diff --git a/seaweed-volume/Cargo.toml b/seaweed-volume/Cargo.toml index 699a9f9d1..d6428eb2a 100644 --- a/seaweed-volume/Cargo.toml +++ b/seaweed-volume/Cargo.toml @@ -22,6 +22,7 @@ default = [] # Async runtime tokio = { version = "1", features = ["full"] } tokio-stream = "0.1" +tokio-io-timeout = "1" # gRPC + protobuf tonic = { version = "0.12", features = ["tls"] } diff --git a/seaweed-volume/src/main.rs b/seaweed-volume/src/main.rs index df9cbcc12..db3c62c0a 100644 --- a/seaweed-volume/src/main.rs +++ b/seaweed-volume/src/main.rs @@ -230,6 +230,21 @@ fn build_volume_grpc_service( .max_encoding_message_size(GRPC_MAX_MESSAGE_SIZE) } +fn apply_idle_timeout( + stream: S, + idle_timeout: std::time::Duration, +) -> std::pin::Pin>> +where + S: tokio::io::AsyncRead + tokio::io::AsyncWrite, +{ + let mut stream = tokio_io_timeout::TimeoutStream::new(stream); + if !idle_timeout.is_zero() { + stream.set_read_timeout(Some(idle_timeout)); + stream.set_write_timeout(Some(idle_timeout)); + } + Box::pin(stream) +} + async fn run(config: VolumeServerConfig) -> Result<(), Box> { // Initialize the store let mut store = Store::new(config.index_type); @@ -374,6 +389,7 @@ async fn run(config: VolumeServerConfig) -> Result<(), Box Result<(), Box Result<(), Box( + tcp_listener: tokio::net::TcpListener, + app: axum::Router, + idle_timeout: std::time::Duration, + shutdown_signal: F, +) where + F: std::future::Future + Send + 'static, +{ + use hyper_util::rt::{TokioExecutor, TokioIo}; + use hyper_util::server::conn::auto::Builder as HttpBuilder; + use hyper_util::service::TowerToHyperService; + use tower::Service; + + let mut make_svc = app.into_make_service(); + + tokio::pin!(shutdown_signal); + + loop { + tokio::select! { + _ = &mut shutdown_signal => { + info!("HTTP server shutting down"); + break; + } + result = tcp_listener.accept() => { + match result { + Ok((tcp_stream, remote_addr)) => { + let tower_svc = make_svc.call(remote_addr).await.expect("infallible"); + let hyper_svc = TowerToHyperService::new(tower_svc); + tokio::spawn(async move { + let io = TokioIo::new(apply_idle_timeout(tcp_stream, idle_timeout)); + let builder = HttpBuilder::new(TokioExecutor::new()); + if let Err(e) = builder.serve_connection(io, hyper_svc).await { + tracing::debug!("HTTP connection error: {}", e); + } + }); + } + Err(e) => { + error!("Failed to accept TCP connection: {}", e); + } + } + } + } + } +} + async fn serve_https( tcp_listener: tokio::net::TcpListener, app: axum::Router, tls_acceptor: TlsAcceptor, + idle_timeout: std::time::Duration, shutdown_signal: F, ) where F: std::future::Future + Send + 'static, @@ -790,7 +855,7 @@ async fn serve_https( tokio::spawn(async move { match tls_acceptor.accept(tcp_stream).await { Ok(tls_stream) => { - let io = TokioIo::new(tls_stream); + let io = TokioIo::new(apply_idle_timeout(tls_stream, idle_timeout)); let builder = HttpBuilder::new(TokioExecutor::new()); if let Err(e) = builder.serve_connection(io, hyper_svc).await { tracing::debug!("HTTPS connection error: {}", e);