From 53abaae3b93330e6174cb2a0472584c1d4ccf084 Mon Sep 17 00:00:00 2001 From: Chris Lu Date: Mon, 16 Mar 2026 15:22:29 -0700 Subject: [PATCH] Enforce Go HTTP whitelist guards --- seaweed-volume/src/main.rs | 4 +- seaweed-volume/src/server/volume_server.rs | 59 +++++++++-- seaweed-volume/tests/http_integration.rs | 109 ++++++++++++++++++++- 3 files changed, 160 insertions(+), 12 deletions(-) diff --git a/seaweed-volume/src/main.rs b/seaweed-volume/src/main.rs index b9c766988..9fa3c4c51 100644 --- a/seaweed-volume/src/main.rs +++ b/seaweed-volume/src/main.rs @@ -812,7 +812,7 @@ async fn serve_http( use hyper_util::service::TowerToHyperService; use tower::Service; - let mut make_svc = app.into_make_service(); + let mut make_svc = app.into_make_service_with_connect_info::(); tokio::pin!(shutdown_signal); @@ -858,7 +858,7 @@ async fn serve_https( use hyper_util::service::TowerToHyperService; use tower::Service; - let mut make_svc = app.into_make_service(); + let mut make_svc = app.into_make_service_with_connect_info::(); tokio::pin!(shutdown_signal); diff --git a/seaweed-volume/src/server/volume_server.rs b/seaweed-volume/src/server/volume_server.rs index ca890a76d..4595a3213 100644 --- a/seaweed-volume/src/server/volume_server.rs +++ b/seaweed-volume/src/server/volume_server.rs @@ -9,11 +9,12 @@ //! //! Matches Go's server/volume_server.go. +use std::net::SocketAddr; use std::sync::atomic::{AtomicBool, AtomicI64, AtomicU32, Ordering}; use std::sync::{Arc, RwLock}; use axum::{ - extract::{Request, State}, + extract::{connect_info::ConnectInfo, Request, State}, http::{header, HeaderValue, Method, StatusCode}, middleware::{self, Next}, response::{IntoResponse, Response}, @@ -127,15 +128,47 @@ pub fn normalize_outgoing_http_url(scheme: &str, raw_target: &str) -> Result Option { + request + .extensions() + .get::>() + .map(|info| info.0) +} + +fn request_is_whitelisted(state: &VolumeServerState, request: &Request) -> bool { + request_remote_addr(request) + .map(|remote_addr| { + state + .guard + .read() + .unwrap() + .check_whitelist(&remote_addr.to_string()) + }) + .unwrap_or(true) +} + +async fn whitelist_guard_middleware( + State(state): State>, + request: Request, + next: Next, +) -> Response { + if !request_is_whitelisted(&state, &request) { + return StatusCode::UNAUTHORIZED.into_response(); + } + next.run(request).await +} + /// Middleware: set Server header, echo x-amz-request-id, set CORS if Origin present. async fn common_headers_middleware(request: Request, next: Next) -> Response { let origin = request.headers().get("origin").cloned(); let request_id = super::request_id::generate_http_request_id(); - let mut response = super::request_id::scope_request_id(request_id.clone(), async move { - next.run(request).await - }) - .await; + let mut response = + super::request_id::scope_request_id( + request_id.clone(), + async move { next.run(request).await }, + ) + .await; let headers = response.headers_mut(); if let Ok(val) = HeaderValue::from_str(crate::version::server_header()) { @@ -175,7 +208,10 @@ async fn admin_store_handler(state: State>, request: Requ crate::metrics::INFLIGHT_REQUESTS_GAUGE .with_label_values(&[&method_str]) .inc(); + let whitelist_rejected = matches!(method, Method::POST | Method::PUT | Method::DELETE) + && !request_is_whitelisted(&state, &request); let response = match method.clone() { + _ if whitelist_rejected => StatusCode::UNAUTHORIZED.into_response(), Method::GET | Method::HEAD => { super::server_stats::record_read_request(); handlers::get_or_head_handler_from_request(state, request).await @@ -330,11 +366,18 @@ pub fn build_admin_router_with_ui(state: Arc, ui_enabled: boo .route("/:vid/:fid/:filename", any(admin_store_handler)) .fallback(admin_store_handler); if ui_enabled { - router = router - .route("/ui/index.html", get(handlers::ui_handler)) + let whitelist_state = state.clone(); + let stats_router = Router::new() .route("/stats/disk", get(handlers::stats_disk_handler)) .route("/stats/counter", get(handlers::stats_counter_handler)) - .route("/stats/memory", get(handlers::stats_memory_handler)); + .route("/stats/memory", get(handlers::stats_memory_handler)) + .route_layer(middleware::from_fn_with_state( + whitelist_state, + whitelist_guard_middleware, + )); + router = router + .route("/ui/index.html", get(handlers::ui_handler)) + .merge(stats_router); } router .layer(middleware::from_fn(common_headers_middleware)) diff --git a/seaweed-volume/tests/http_integration.rs b/seaweed-volume/tests/http_integration.rs index 478df3e76..6899a9165 100644 --- a/seaweed-volume/tests/http_integration.rs +++ b/seaweed-volume/tests/http_integration.rs @@ -6,6 +6,7 @@ use std::sync::{Arc, RwLock}; use axum::body::Body; +use axum::extract::connect_info::ConnectInfo; use axum::http::{Request, StatusCode}; use tower::ServiceExt; // for `oneshot` @@ -23,10 +24,21 @@ use tempfile::TempDir; /// Create a test VolumeServerState with a temp directory, a single disk /// location, and one pre-created volume (VolumeId 1). fn test_state() -> (Arc, TempDir) { - test_state_with_signing_key(Vec::new()) + test_state_with_guard(Vec::new(), Vec::new()) } fn test_state_with_signing_key(signing_key: Vec) -> (Arc, TempDir) { + test_state_with_guard(Vec::new(), signing_key) +} + +fn test_state_with_whitelist(whitelist: Vec) -> (Arc, TempDir) { + test_state_with_guard(whitelist, Vec::new()) +} + +fn test_state_with_guard( + whitelist: Vec, + signing_key: Vec, +) -> (Arc, TempDir) { let tmp = TempDir::new().expect("failed to create temp dir"); let dir = tmp.path().to_str().unwrap(); @@ -53,7 +65,13 @@ fn test_state_with_signing_key(signing_key: Vec) -> (Arc, ) .expect("failed to create volume"); - let guard = Guard::new(&[], SigningKey(signing_key), 0, SigningKey(vec![]), 0); + let guard = Guard::new( + &whitelist, + SigningKey(signing_key), + 0, + SigningKey(vec![]), + 0, + ); let state = Arc::new(VolumeServerState { store: RwLock::new(store), guard: RwLock::new(guard), @@ -109,6 +127,15 @@ async fn body_bytes(response: axum::response::Response) -> Vec { .to_vec() } +fn with_remote_addr(request: Request, remote_addr: &str) -> Request { + let mut request = request; + let remote_addr = remote_addr + .parse::() + .expect("invalid socket address"); + request.extensions_mut().insert(ConnectInfo(remote_addr)); + request +} + // ============================================================================ // 1. GET /healthz returns 200 when server is running // ============================================================================ @@ -226,6 +253,84 @@ async fn metrics_router_serves_metrics() { assert_eq!(response.status(), StatusCode::OK); } +#[tokio::test] +async fn admin_router_rejects_non_whitelisted_uploads() { + let (state, _tmp) = test_state_with_whitelist(vec!["127.0.0.1".to_string()]); + let app = build_admin_router(state); + + let response = app + .oneshot(with_remote_addr( + Request::builder() + .method("POST") + .uri("/1,000000000000000001") + .body(Body::from("blocked")) + .unwrap(), + "10.0.0.9:12345", + )) + .await + .unwrap(); + + assert_eq!(response.status(), StatusCode::UNAUTHORIZED); +} + +#[tokio::test] +async fn admin_router_rejects_non_whitelisted_deletes() { + let (state, _tmp) = test_state_with_whitelist(vec!["127.0.0.1".to_string()]); + let app = build_admin_router(state); + + let response = app + .oneshot(with_remote_addr( + Request::builder() + .method("DELETE") + .uri("/1,000000000000000001") + .body(Body::empty()) + .unwrap(), + "10.0.0.9:12345", + )) + .await + .unwrap(); + + assert_eq!(response.status(), StatusCode::UNAUTHORIZED); +} + +#[tokio::test] +async fn admin_router_rejects_non_whitelisted_stats_routes() { + let (state, _tmp) = test_state_with_whitelist(vec!["127.0.0.1".to_string()]); + let app = build_admin_router_with_ui(state, true); + + let response = app + .oneshot(with_remote_addr( + Request::builder() + .uri("/stats/counter") + .body(Body::empty()) + .unwrap(), + "10.0.0.9:12345", + )) + .await + .unwrap(); + + assert_eq!(response.status(), StatusCode::UNAUTHORIZED); +} + +#[tokio::test] +async fn admin_router_allows_whitelisted_stats_routes() { + let (state, _tmp) = test_state_with_whitelist(vec!["127.0.0.1".to_string()]); + let app = build_admin_router_with_ui(state, true); + + let response = app + .oneshot(with_remote_addr( + Request::builder() + .uri("/stats/counter") + .body(Body::empty()) + .unwrap(), + "127.0.0.1:12345", + )) + .await + .unwrap(); + + assert_eq!(response.status(), StatusCode::OK); +} + // ============================================================================ // 4. POST writes data, then GET reads it back // ============================================================================