Browse Source

Propagate request IDs across gRPC calls

rust-volume-server
Chris Lu 4 days ago
parent
commit
c2e938442c
  1. 3
      seaweed-volume/src/main.rs
  2. 73
      seaweed-volume/src/server/grpc_server.rs
  3. 10
      seaweed-volume/src/server/heartbeat.rs
  4. 1
      seaweed-volume/src/server/mod.rs
  5. 138
      seaweed-volume/src/server/request_id.rs
  6. 23
      seaweed-volume/src/server/volume_server.rs

3
seaweed-volume/src/main.rs

@ -14,6 +14,7 @@ use seaweed_volume::server::debug::build_debug_router;
use seaweed_volume::server::grpc_client::load_outgoing_grpc_tls;
use seaweed_volume::server::grpc_server::VolumeGrpcService;
use seaweed_volume::server::profiling::CpuProfileSession;
use seaweed_volume::server::request_id::GrpcRequestIdLayer;
use seaweed_volume::server::volume_server::{
build_metrics_router, RuntimeMetricsConfig, VolumeServerState,
};
@ -569,6 +570,7 @@ async fn run(
.expect("Failed to build gRPC reflection v1alpha service");
info!("gRPC server listening on {} (TLS enabled)", addr);
if let Err(e) = build_grpc_server_builder()
.layer(GrpcRequestIdLayer)
.add_service(reflection_v1)
.add_service(reflection_v1alpha)
.add_service(build_volume_grpc_service(grpc_service))
@ -590,6 +592,7 @@ async fn run(
.expect("Failed to build gRPC reflection v1alpha service");
info!("gRPC server listening on {}", addr);
if let Err(e) = build_grpc_server_builder()
.layer(GrpcRequestIdLayer)
.add_service(reflection_v1)
.add_service(reflection_v1alpha)
.add_service(build_volume_grpc_service(grpc_service))

73
seaweed-volume/src/server/grpc_server.rs

@ -99,7 +99,10 @@ impl VolumeGrpcService {
.connect()
.await
.map_err(|e| Status::internal(format!("connect to master {}: {}", master_url, e)))?;
let mut client = SeaweedClient::new(channel)
let mut client = SeaweedClient::with_interceptor(
channel,
super::request_id::outgoing_request_id_interceptor,
)
.max_decoding_message_size(GRPC_MAX_MESSAGE_SIZE)
.max_encoding_message_size(GRPC_MAX_MESSAGE_SIZE);
client
@ -898,9 +901,12 @@ impl VolumeServer for VolumeGrpcService {
))
})?;
let mut client = volume_server_pb::volume_server_client::VolumeServerClient::new(channel)
.max_decoding_message_size(GRPC_MAX_MESSAGE_SIZE)
.max_encoding_message_size(GRPC_MAX_MESSAGE_SIZE);
let mut client = volume_server_pb::volume_server_client::VolumeServerClient::with_interceptor(
channel,
super::request_id::outgoing_request_id_interceptor,
)
.max_decoding_message_size(GRPC_MAX_MESSAGE_SIZE)
.max_encoding_message_size(GRPC_MAX_MESSAGE_SIZE);
// Get file status from source
let vol_info = client
@ -1630,9 +1636,12 @@ impl VolumeServer for VolumeGrpcService {
.await
.map_err(|e| Status::internal(format!("connect to {}: {}", grpc_addr, e)))?;
let mut client = volume_server_pb::volume_server_client::VolumeServerClient::new(channel)
.max_decoding_message_size(GRPC_MAX_MESSAGE_SIZE)
.max_encoding_message_size(GRPC_MAX_MESSAGE_SIZE);
let mut client = volume_server_pb::volume_server_client::VolumeServerClient::with_interceptor(
channel,
super::request_id::outgoing_request_id_interceptor,
)
.max_decoding_message_size(GRPC_MAX_MESSAGE_SIZE)
.max_encoding_message_size(GRPC_MAX_MESSAGE_SIZE);
// Call VolumeTailSender on source
let mut stream = client
@ -1896,9 +1905,12 @@ impl VolumeServer for VolumeGrpcService {
))
})?;
let mut client = volume_server_pb::volume_server_client::VolumeServerClient::new(channel)
.max_decoding_message_size(GRPC_MAX_MESSAGE_SIZE)
.max_encoding_message_size(GRPC_MAX_MESSAGE_SIZE);
let mut client = volume_server_pb::volume_server_client::VolumeServerClient::with_interceptor(
channel,
super::request_id::outgoing_request_id_interceptor,
)
.max_decoding_message_size(GRPC_MAX_MESSAGE_SIZE)
.max_encoding_message_size(GRPC_MAX_MESSAGE_SIZE);
// Copy each shard
for &shard_id in &req.shard_ids {
@ -3315,9 +3327,12 @@ async fn ping_volume_server_target(
.map_err(|_| "connection timeout".to_string())?
.map_err(|e| e.to_string())?;
let mut client = volume_server_pb::volume_server_client::VolumeServerClient::new(channel)
.max_decoding_message_size(GRPC_MAX_MESSAGE_SIZE)
.max_encoding_message_size(GRPC_MAX_MESSAGE_SIZE);
let mut client = volume_server_pb::volume_server_client::VolumeServerClient::with_interceptor(
channel,
super::request_id::outgoing_request_id_interceptor,
)
.max_decoding_message_size(GRPC_MAX_MESSAGE_SIZE)
.max_encoding_message_size(GRPC_MAX_MESSAGE_SIZE);
let resp = client
.ping(volume_server_pb::PingRequest {
target: String::new(),
@ -3339,9 +3354,12 @@ async fn ping_master_target(
.map_err(|_| "connection timeout".to_string())?
.map_err(|e| e.to_string())?;
let mut client = master_pb::seaweed_client::SeaweedClient::new(channel)
.max_decoding_message_size(GRPC_MAX_MESSAGE_SIZE)
.max_encoding_message_size(GRPC_MAX_MESSAGE_SIZE);
let mut client = master_pb::seaweed_client::SeaweedClient::with_interceptor(
channel,
super::request_id::outgoing_request_id_interceptor,
)
.max_decoding_message_size(GRPC_MAX_MESSAGE_SIZE)
.max_encoding_message_size(GRPC_MAX_MESSAGE_SIZE);
let resp = client
.ping(master_pb::PingRequest {
target: String::new(),
@ -3363,9 +3381,12 @@ async fn ping_filer_target(
.map_err(|_| "connection timeout".to_string())?
.map_err(|e| e.to_string())?;
let mut client = filer_pb::seaweed_filer_client::SeaweedFilerClient::new(channel)
.max_decoding_message_size(GRPC_MAX_MESSAGE_SIZE)
.max_encoding_message_size(GRPC_MAX_MESSAGE_SIZE);
let mut client = filer_pb::seaweed_filer_client::SeaweedFilerClient::with_interceptor(
channel,
super::request_id::outgoing_request_id_interceptor,
)
.max_decoding_message_size(GRPC_MAX_MESSAGE_SIZE)
.max_encoding_message_size(GRPC_MAX_MESSAGE_SIZE);
let resp = client
.ping(filer_pb::PingRequest::default())
.await
@ -3413,10 +3434,8 @@ fn set_file_mtime(path: &str, modified_ts_ns: i64) {
/// Copy a file from a remote volume server via CopyFile streaming RPC.
/// Returns the modified_ts_ns received from the source.
async fn copy_file_from_source(
client: &mut volume_server_pb::volume_server_client::VolumeServerClient<
tonic::transport::Channel,
>,
async fn copy_file_from_source<T>(
client: &mut volume_server_pb::volume_server_client::VolumeServerClient<T>,
is_ec_volume: bool,
collection: &str,
volume_id: u32,
@ -3432,7 +3451,13 @@ async fn copy_file_from_source(
next_report_target: &mut i64,
report_interval: i64,
throttler: &mut WriteThrottler,
) -> Result<i64, String> {
) -> Result<i64, String>
where
T: tonic::client::GrpcService<tonic::body::BoxBody>,
T::Error: Into<tonic::codegen::StdError>,
T::ResponseBody: http_body::Body<Data = bytes::Bytes> + Send + 'static,
<T::ResponseBody as http_body::Body>::Error: Into<tonic::codegen::StdError> + Send,
{
let copy_req = volume_server_pb::CopyFileRequest {
volume_id,
ext: ext.to_string(),

10
seaweed-volume/src/server/heartbeat.rs

@ -222,7 +222,10 @@ async fn try_get_master_configuration(
.timeout(Duration::from_secs(10))
.connect()
.await?;
let mut client = SeaweedClient::new(channel)
let mut client = SeaweedClient::with_interceptor(
channel,
super::request_id::outgoing_request_id_interceptor,
)
.max_decoding_message_size(GRPC_MAX_MESSAGE_SIZE)
.max_encoding_message_size(GRPC_MAX_MESSAGE_SIZE);
let resp = client
@ -249,7 +252,10 @@ async fn do_heartbeat(
.connect()
.await?;
let mut client = SeaweedClient::new(channel)
let mut client = SeaweedClient::with_interceptor(
channel,
super::request_id::outgoing_request_id_interceptor,
)
.max_decoding_message_size(GRPC_MAX_MESSAGE_SIZE)
.max_encoding_message_size(GRPC_MAX_MESSAGE_SIZE);

1
seaweed-volume/src/server/mod.rs

@ -5,5 +5,6 @@ pub mod handlers;
pub mod heartbeat;
pub mod memory_status;
pub mod profiling;
pub mod request_id;
pub mod volume_server;
pub mod write_queue;

138
seaweed-volume/src/server/request_id.rs

@ -0,0 +1,138 @@
use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};
use hyper::http::{self, HeaderValue};
use tonic::metadata::MetadataValue;
use tonic::{Request, Status};
use tower::{Layer, Service};
tokio::task_local! {
static CURRENT_REQUEST_ID: String;
}
#[derive(Clone, Debug, Default)]
pub struct GrpcRequestIdLayer;
#[derive(Clone, Debug)]
pub struct GrpcRequestIdService<S> {
inner: S,
}
impl<S> Layer<S> for GrpcRequestIdLayer {
type Service = GrpcRequestIdService<S>;
fn layer(&self, inner: S) -> Self::Service {
GrpcRequestIdService { inner }
}
}
impl<S, B> Service<http::Request<B>> for GrpcRequestIdService<S>
where
S: Service<http::Request<B>, Response = http::Response<tonic::body::BoxBody>> + Send + 'static,
S::Future: Send + 'static,
B: Send + 'static,
{
type Response = http::Response<tonic::body::BoxBody>;
type Error = S::Error;
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.inner.poll_ready(cx)
}
fn call(&mut self, mut request: http::Request<B>) -> Self::Future {
let request_id = match request.headers().get("x-amz-request-id") {
Some(value) => match value.to_str() {
Ok(value) if !value.is_empty() => value.to_owned(),
_ => generate_grpc_request_id(),
},
None => generate_grpc_request_id(),
};
if let Ok(value) = HeaderValue::from_str(&request_id) {
request.headers_mut().insert("x-amz-request-id", value);
}
let future = self.inner.call(request);
Box::pin(async move {
let mut response: http::Response<tonic::body::BoxBody> =
scope_request_id(request_id.clone(), future).await?;
if let Ok(value) = HeaderValue::from_str(&request_id) {
response.headers_mut().insert("x-amz-request-id", value);
}
Ok(response)
})
}
}
pub async fn scope_request_id<F, T>(request_id: String, future: F) -> T
where
F: Future<Output = T>,
{
CURRENT_REQUEST_ID.scope(request_id, future).await
}
pub fn current_request_id() -> Option<String> {
CURRENT_REQUEST_ID.try_with(Clone::clone).ok()
}
pub fn outgoing_request_id_interceptor(
mut request: Request<()>,
) -> Result<Request<()>, Status> {
if let Some(request_id) = current_request_id() {
let value = MetadataValue::try_from(request_id.as_str())
.map_err(|_| Status::internal("invalid scoped request id"))?;
request.metadata_mut().insert("x-amz-request-id", value);
}
Ok(request)
}
pub fn generate_http_request_id() -> String {
use rand::Rng;
let nanos = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_default()
.as_nanos() as u64;
let rand_val: u32 = rand::thread_rng().gen();
format!("{:X}{:08X}", nanos, rand_val)
}
fn generate_grpc_request_id() -> String {
uuid::Uuid::new_v4().to_string()
}
#[cfg(test)]
mod tests {
use super::{current_request_id, outgoing_request_id_interceptor, scope_request_id};
use tonic::Request;
#[tokio::test]
async fn test_scope_request_id_exposes_current_value() {
let request_id = "req-123".to_string();
let current = scope_request_id(request_id.clone(), async move {
current_request_id().unwrap()
})
.await;
assert_eq!(current, request_id);
}
#[tokio::test]
async fn test_outgoing_request_id_interceptor_propagates_scope() {
let request = scope_request_id("req-456".to_string(), async move {
outgoing_request_id_interceptor(Request::new(())).unwrap()
})
.await;
assert_eq!(
request
.metadata()
.get("x-amz-request-id")
.unwrap()
.to_str()
.unwrap(),
"req-456"
);
}
}

23
seaweed-volume/src/server/volume_server.rs

@ -128,29 +128,20 @@ pub fn normalize_outgoing_http_url(scheme: &str, raw_target: &str) -> Result<Str
/// Middleware: set Server header, echo x-amz-request-id, set CORS if Origin present.
async fn common_headers_middleware(request: Request, next: Next) -> Response {
let origin = request.headers().get("origin").cloned();
let _request_id = request.headers().get("x-amz-request-id").cloned();
let request_id = super::request_id::generate_http_request_id();
let mut response = next.run(request).await;
let mut response = super::request_id::scope_request_id(request_id.clone(), async move {
next.run(request).await
})
.await;
let headers = response.headers_mut();
if let Ok(val) = HeaderValue::from_str(crate::version::server_header()) {
headers.insert("Server", val);
}
// Always generate a server-side request ID (matching Go's request_id.Ensure
// which ignores client-sent headers to prevent spoofing).
// Format: "%X%08X" — nanosecond timestamp hex + 4 random bytes hex.
{
use rand::Rng;
let nanos = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_default()
.as_nanos() as u64;
let rand_val: u32 = rand::thread_rng().gen();
let id = format!("{:X}{:08X}", nanos, rand_val);
if let Ok(val) = HeaderValue::from_str(&id) {
headers.insert("x-amz-request-id", val);
}
if let Ok(val) = HeaderValue::from_str(&request_id) {
headers.insert("x-amz-request-id", val);
}
if origin.is_some() {

Loading…
Cancel
Save