diff --git a/seaweed-volume/src/main.rs b/seaweed-volume/src/main.rs index 9cac9d46f..750f7b6b1 100644 --- a/seaweed-volume/src/main.rs +++ b/seaweed-volume/src/main.rs @@ -14,6 +14,7 @@ use seaweed_volume::server::debug::build_debug_router; use seaweed_volume::server::grpc_client::load_outgoing_grpc_tls; use seaweed_volume::server::grpc_server::VolumeGrpcService; use seaweed_volume::server::profiling::CpuProfileSession; +use seaweed_volume::server::request_id::GrpcRequestIdLayer; use seaweed_volume::server::volume_server::{ build_metrics_router, RuntimeMetricsConfig, VolumeServerState, }; @@ -569,6 +570,7 @@ async fn run( .expect("Failed to build gRPC reflection v1alpha service"); info!("gRPC server listening on {} (TLS enabled)", addr); if let Err(e) = build_grpc_server_builder() + .layer(GrpcRequestIdLayer) .add_service(reflection_v1) .add_service(reflection_v1alpha) .add_service(build_volume_grpc_service(grpc_service)) @@ -590,6 +592,7 @@ async fn run( .expect("Failed to build gRPC reflection v1alpha service"); info!("gRPC server listening on {}", addr); if let Err(e) = build_grpc_server_builder() + .layer(GrpcRequestIdLayer) .add_service(reflection_v1) .add_service(reflection_v1alpha) .add_service(build_volume_grpc_service(grpc_service)) diff --git a/seaweed-volume/src/server/grpc_server.rs b/seaweed-volume/src/server/grpc_server.rs index e8aab191d..ce973dc03 100644 --- a/seaweed-volume/src/server/grpc_server.rs +++ b/seaweed-volume/src/server/grpc_server.rs @@ -99,7 +99,10 @@ impl VolumeGrpcService { .connect() .await .map_err(|e| Status::internal(format!("connect to master {}: {}", master_url, e)))?; - let mut client = SeaweedClient::new(channel) + let mut client = SeaweedClient::with_interceptor( + channel, + super::request_id::outgoing_request_id_interceptor, + ) .max_decoding_message_size(GRPC_MAX_MESSAGE_SIZE) .max_encoding_message_size(GRPC_MAX_MESSAGE_SIZE); client @@ -898,9 +901,12 @@ impl VolumeServer for VolumeGrpcService { )) })?; - let mut client = volume_server_pb::volume_server_client::VolumeServerClient::new(channel) - .max_decoding_message_size(GRPC_MAX_MESSAGE_SIZE) - .max_encoding_message_size(GRPC_MAX_MESSAGE_SIZE); + let mut client = volume_server_pb::volume_server_client::VolumeServerClient::with_interceptor( + channel, + super::request_id::outgoing_request_id_interceptor, + ) + .max_decoding_message_size(GRPC_MAX_MESSAGE_SIZE) + .max_encoding_message_size(GRPC_MAX_MESSAGE_SIZE); // Get file status from source let vol_info = client @@ -1630,9 +1636,12 @@ impl VolumeServer for VolumeGrpcService { .await .map_err(|e| Status::internal(format!("connect to {}: {}", grpc_addr, e)))?; - let mut client = volume_server_pb::volume_server_client::VolumeServerClient::new(channel) - .max_decoding_message_size(GRPC_MAX_MESSAGE_SIZE) - .max_encoding_message_size(GRPC_MAX_MESSAGE_SIZE); + let mut client = volume_server_pb::volume_server_client::VolumeServerClient::with_interceptor( + channel, + super::request_id::outgoing_request_id_interceptor, + ) + .max_decoding_message_size(GRPC_MAX_MESSAGE_SIZE) + .max_encoding_message_size(GRPC_MAX_MESSAGE_SIZE); // Call VolumeTailSender on source let mut stream = client @@ -1896,9 +1905,12 @@ impl VolumeServer for VolumeGrpcService { )) })?; - let mut client = volume_server_pb::volume_server_client::VolumeServerClient::new(channel) - .max_decoding_message_size(GRPC_MAX_MESSAGE_SIZE) - .max_encoding_message_size(GRPC_MAX_MESSAGE_SIZE); + let mut client = volume_server_pb::volume_server_client::VolumeServerClient::with_interceptor( + channel, + super::request_id::outgoing_request_id_interceptor, + ) + .max_decoding_message_size(GRPC_MAX_MESSAGE_SIZE) + .max_encoding_message_size(GRPC_MAX_MESSAGE_SIZE); // Copy each shard for &shard_id in &req.shard_ids { @@ -3315,9 +3327,12 @@ async fn ping_volume_server_target( .map_err(|_| "connection timeout".to_string())? .map_err(|e| e.to_string())?; - let mut client = volume_server_pb::volume_server_client::VolumeServerClient::new(channel) - .max_decoding_message_size(GRPC_MAX_MESSAGE_SIZE) - .max_encoding_message_size(GRPC_MAX_MESSAGE_SIZE); + let mut client = volume_server_pb::volume_server_client::VolumeServerClient::with_interceptor( + channel, + super::request_id::outgoing_request_id_interceptor, + ) + .max_decoding_message_size(GRPC_MAX_MESSAGE_SIZE) + .max_encoding_message_size(GRPC_MAX_MESSAGE_SIZE); let resp = client .ping(volume_server_pb::PingRequest { target: String::new(), @@ -3339,9 +3354,12 @@ async fn ping_master_target( .map_err(|_| "connection timeout".to_string())? .map_err(|e| e.to_string())?; - let mut client = master_pb::seaweed_client::SeaweedClient::new(channel) - .max_decoding_message_size(GRPC_MAX_MESSAGE_SIZE) - .max_encoding_message_size(GRPC_MAX_MESSAGE_SIZE); + let mut client = master_pb::seaweed_client::SeaweedClient::with_interceptor( + channel, + super::request_id::outgoing_request_id_interceptor, + ) + .max_decoding_message_size(GRPC_MAX_MESSAGE_SIZE) + .max_encoding_message_size(GRPC_MAX_MESSAGE_SIZE); let resp = client .ping(master_pb::PingRequest { target: String::new(), @@ -3363,9 +3381,12 @@ async fn ping_filer_target( .map_err(|_| "connection timeout".to_string())? .map_err(|e| e.to_string())?; - let mut client = filer_pb::seaweed_filer_client::SeaweedFilerClient::new(channel) - .max_decoding_message_size(GRPC_MAX_MESSAGE_SIZE) - .max_encoding_message_size(GRPC_MAX_MESSAGE_SIZE); + let mut client = filer_pb::seaweed_filer_client::SeaweedFilerClient::with_interceptor( + channel, + super::request_id::outgoing_request_id_interceptor, + ) + .max_decoding_message_size(GRPC_MAX_MESSAGE_SIZE) + .max_encoding_message_size(GRPC_MAX_MESSAGE_SIZE); let resp = client .ping(filer_pb::PingRequest::default()) .await @@ -3413,10 +3434,8 @@ fn set_file_mtime(path: &str, modified_ts_ns: i64) { /// Copy a file from a remote volume server via CopyFile streaming RPC. /// Returns the modified_ts_ns received from the source. -async fn copy_file_from_source( - client: &mut volume_server_pb::volume_server_client::VolumeServerClient< - tonic::transport::Channel, - >, +async fn copy_file_from_source( + client: &mut volume_server_pb::volume_server_client::VolumeServerClient, is_ec_volume: bool, collection: &str, volume_id: u32, @@ -3432,7 +3451,13 @@ async fn copy_file_from_source( next_report_target: &mut i64, report_interval: i64, throttler: &mut WriteThrottler, -) -> Result { +) -> Result +where + T: tonic::client::GrpcService, + T::Error: Into, + T::ResponseBody: http_body::Body + Send + 'static, + ::Error: Into + Send, +{ let copy_req = volume_server_pb::CopyFileRequest { volume_id, ext: ext.to_string(), diff --git a/seaweed-volume/src/server/heartbeat.rs b/seaweed-volume/src/server/heartbeat.rs index 13bda113c..29682e698 100644 --- a/seaweed-volume/src/server/heartbeat.rs +++ b/seaweed-volume/src/server/heartbeat.rs @@ -222,7 +222,10 @@ async fn try_get_master_configuration( .timeout(Duration::from_secs(10)) .connect() .await?; - let mut client = SeaweedClient::new(channel) + let mut client = SeaweedClient::with_interceptor( + channel, + super::request_id::outgoing_request_id_interceptor, + ) .max_decoding_message_size(GRPC_MAX_MESSAGE_SIZE) .max_encoding_message_size(GRPC_MAX_MESSAGE_SIZE); let resp = client @@ -249,7 +252,10 @@ async fn do_heartbeat( .connect() .await?; - let mut client = SeaweedClient::new(channel) + let mut client = SeaweedClient::with_interceptor( + channel, + super::request_id::outgoing_request_id_interceptor, + ) .max_decoding_message_size(GRPC_MAX_MESSAGE_SIZE) .max_encoding_message_size(GRPC_MAX_MESSAGE_SIZE); diff --git a/seaweed-volume/src/server/mod.rs b/seaweed-volume/src/server/mod.rs index 37135b549..e0b21f1f9 100644 --- a/seaweed-volume/src/server/mod.rs +++ b/seaweed-volume/src/server/mod.rs @@ -5,5 +5,6 @@ pub mod handlers; pub mod heartbeat; pub mod memory_status; pub mod profiling; +pub mod request_id; pub mod volume_server; pub mod write_queue; diff --git a/seaweed-volume/src/server/request_id.rs b/seaweed-volume/src/server/request_id.rs new file mode 100644 index 000000000..6af2008e6 --- /dev/null +++ b/seaweed-volume/src/server/request_id.rs @@ -0,0 +1,138 @@ +use std::future::Future; +use std::pin::Pin; +use std::task::{Context, Poll}; + +use hyper::http::{self, HeaderValue}; +use tonic::metadata::MetadataValue; +use tonic::{Request, Status}; +use tower::{Layer, Service}; + +tokio::task_local! { + static CURRENT_REQUEST_ID: String; +} + +#[derive(Clone, Debug, Default)] +pub struct GrpcRequestIdLayer; + +#[derive(Clone, Debug)] +pub struct GrpcRequestIdService { + inner: S, +} + +impl Layer for GrpcRequestIdLayer { + type Service = GrpcRequestIdService; + + fn layer(&self, inner: S) -> Self::Service { + GrpcRequestIdService { inner } + } +} + +impl Service> for GrpcRequestIdService +where + S: Service, Response = http::Response> + Send + 'static, + S::Future: Send + 'static, + B: Send + 'static, +{ + type Response = http::Response; + type Error = S::Error; + type Future = Pin> + Send>>; + + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + self.inner.poll_ready(cx) + } + + fn call(&mut self, mut request: http::Request) -> Self::Future { + let request_id = match request.headers().get("x-amz-request-id") { + Some(value) => match value.to_str() { + Ok(value) if !value.is_empty() => value.to_owned(), + _ => generate_grpc_request_id(), + }, + None => generate_grpc_request_id(), + }; + + if let Ok(value) = HeaderValue::from_str(&request_id) { + request.headers_mut().insert("x-amz-request-id", value); + } + + let future = self.inner.call(request); + + Box::pin(async move { + let mut response: http::Response = + scope_request_id(request_id.clone(), future).await?; + if let Ok(value) = HeaderValue::from_str(&request_id) { + response.headers_mut().insert("x-amz-request-id", value); + } + Ok(response) + }) + } +} + +pub async fn scope_request_id(request_id: String, future: F) -> T +where + F: Future, +{ + CURRENT_REQUEST_ID.scope(request_id, future).await +} + +pub fn current_request_id() -> Option { + CURRENT_REQUEST_ID.try_with(Clone::clone).ok() +} + +pub fn outgoing_request_id_interceptor( + mut request: Request<()>, +) -> Result, Status> { + if let Some(request_id) = current_request_id() { + let value = MetadataValue::try_from(request_id.as_str()) + .map_err(|_| Status::internal("invalid scoped request id"))?; + request.metadata_mut().insert("x-amz-request-id", value); + } + Ok(request) +} + +pub fn generate_http_request_id() -> String { + use rand::Rng; + + let nanos = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap_or_default() + .as_nanos() as u64; + let rand_val: u32 = rand::thread_rng().gen(); + format!("{:X}{:08X}", nanos, rand_val) +} + +fn generate_grpc_request_id() -> String { + uuid::Uuid::new_v4().to_string() +} + +#[cfg(test)] +mod tests { + use super::{current_request_id, outgoing_request_id_interceptor, scope_request_id}; + use tonic::Request; + + #[tokio::test] + async fn test_scope_request_id_exposes_current_value() { + let request_id = "req-123".to_string(); + let current = scope_request_id(request_id.clone(), async move { + current_request_id().unwrap() + }) + .await; + assert_eq!(current, request_id); + } + + #[tokio::test] + async fn test_outgoing_request_id_interceptor_propagates_scope() { + let request = scope_request_id("req-456".to_string(), async move { + outgoing_request_id_interceptor(Request::new(())).unwrap() + }) + .await; + assert_eq!( + request + .metadata() + .get("x-amz-request-id") + .unwrap() + .to_str() + .unwrap(), + "req-456" + ); + } +} diff --git a/seaweed-volume/src/server/volume_server.rs b/seaweed-volume/src/server/volume_server.rs index 11cc42328..a6e057014 100644 --- a/seaweed-volume/src/server/volume_server.rs +++ b/seaweed-volume/src/server/volume_server.rs @@ -128,29 +128,20 @@ pub fn normalize_outgoing_http_url(scheme: &str, raw_target: &str) -> Result Response { let origin = request.headers().get("origin").cloned(); - let _request_id = request.headers().get("x-amz-request-id").cloned(); + let request_id = super::request_id::generate_http_request_id(); - let mut response = next.run(request).await; + let mut response = super::request_id::scope_request_id(request_id.clone(), async move { + next.run(request).await + }) + .await; let headers = response.headers_mut(); if let Ok(val) = HeaderValue::from_str(crate::version::server_header()) { headers.insert("Server", val); } - // Always generate a server-side request ID (matching Go's request_id.Ensure - // which ignores client-sent headers to prevent spoofing). - // Format: "%X%08X" — nanosecond timestamp hex + 4 random bytes hex. - { - use rand::Rng; - let nanos = std::time::SystemTime::now() - .duration_since(std::time::UNIX_EPOCH) - .unwrap_or_default() - .as_nanos() as u64; - let rand_val: u32 = rand::thread_rng().gen(); - let id = format!("{:X}{:08X}", nanos, rand_val); - if let Ok(val) = HeaderValue::from_str(&id) { - headers.insert("x-amz-request-id", val); - } + if let Ok(val) = HeaderValue::from_str(&request_id) { + headers.insert("x-amz-request-id", val); } if origin.is_some() {