Browse Source

Enforce Go HTTP whitelist guards

rust-volume-server
Chris Lu 4 days ago
parent
commit
53abaae3b9
  1. 4
      seaweed-volume/src/main.rs
  2. 59
      seaweed-volume/src/server/volume_server.rs
  3. 109
      seaweed-volume/tests/http_integration.rs

4
seaweed-volume/src/main.rs

@ -812,7 +812,7 @@ async fn serve_http<F>(
use hyper_util::service::TowerToHyperService;
use tower::Service;
let mut make_svc = app.into_make_service();
let mut make_svc = app.into_make_service_with_connect_info::<std::net::SocketAddr>();
tokio::pin!(shutdown_signal);
@ -858,7 +858,7 @@ async fn serve_https<F>(
use hyper_util::service::TowerToHyperService;
use tower::Service;
let mut make_svc = app.into_make_service();
let mut make_svc = app.into_make_service_with_connect_info::<std::net::SocketAddr>();
tokio::pin!(shutdown_signal);

59
seaweed-volume/src/server/volume_server.rs

@ -9,11 +9,12 @@
//!
//! Matches Go's server/volume_server.go.
use std::net::SocketAddr;
use std::sync::atomic::{AtomicBool, AtomicI64, AtomicU32, Ordering};
use std::sync::{Arc, RwLock};
use axum::{
extract::{Request, State},
extract::{connect_info::ConnectInfo, Request, State},
http::{header, HeaderValue, Method, StatusCode},
middleware::{self, Next},
response::{IntoResponse, Response},
@ -127,15 +128,47 @@ pub fn normalize_outgoing_http_url(scheme: &str, raw_target: &str) -> Result<Str
Ok(format!("{}://{}", scheme, raw_target))
}
fn request_remote_addr(request: &Request) -> Option<SocketAddr> {
request
.extensions()
.get::<ConnectInfo<SocketAddr>>()
.map(|info| info.0)
}
fn request_is_whitelisted(state: &VolumeServerState, request: &Request) -> bool {
request_remote_addr(request)
.map(|remote_addr| {
state
.guard
.read()
.unwrap()
.check_whitelist(&remote_addr.to_string())
})
.unwrap_or(true)
}
async fn whitelist_guard_middleware(
State(state): State<Arc<VolumeServerState>>,
request: Request,
next: Next,
) -> Response {
if !request_is_whitelisted(&state, &request) {
return StatusCode::UNAUTHORIZED.into_response();
}
next.run(request).await
}
/// Middleware: set Server header, echo x-amz-request-id, set CORS if Origin present.
async fn common_headers_middleware(request: Request, next: Next) -> Response {
let origin = request.headers().get("origin").cloned();
let request_id = super::request_id::generate_http_request_id();
let mut response = super::request_id::scope_request_id(request_id.clone(), async move {
next.run(request).await
})
.await;
let mut response =
super::request_id::scope_request_id(
request_id.clone(),
async move { next.run(request).await },
)
.await;
let headers = response.headers_mut();
if let Ok(val) = HeaderValue::from_str(crate::version::server_header()) {
@ -175,7 +208,10 @@ async fn admin_store_handler(state: State<Arc<VolumeServerState>>, request: Requ
crate::metrics::INFLIGHT_REQUESTS_GAUGE
.with_label_values(&[&method_str])
.inc();
let whitelist_rejected = matches!(method, Method::POST | Method::PUT | Method::DELETE)
&& !request_is_whitelisted(&state, &request);
let response = match method.clone() {
_ if whitelist_rejected => StatusCode::UNAUTHORIZED.into_response(),
Method::GET | Method::HEAD => {
super::server_stats::record_read_request();
handlers::get_or_head_handler_from_request(state, request).await
@ -330,11 +366,18 @@ pub fn build_admin_router_with_ui(state: Arc<VolumeServerState>, ui_enabled: boo
.route("/:vid/:fid/:filename", any(admin_store_handler))
.fallback(admin_store_handler);
if ui_enabled {
router = router
.route("/ui/index.html", get(handlers::ui_handler))
let whitelist_state = state.clone();
let stats_router = Router::new()
.route("/stats/disk", get(handlers::stats_disk_handler))
.route("/stats/counter", get(handlers::stats_counter_handler))
.route("/stats/memory", get(handlers::stats_memory_handler));
.route("/stats/memory", get(handlers::stats_memory_handler))
.route_layer(middleware::from_fn_with_state(
whitelist_state,
whitelist_guard_middleware,
));
router = router
.route("/ui/index.html", get(handlers::ui_handler))
.merge(stats_router);
}
router
.layer(middleware::from_fn(common_headers_middleware))

109
seaweed-volume/tests/http_integration.rs

@ -6,6 +6,7 @@
use std::sync::{Arc, RwLock};
use axum::body::Body;
use axum::extract::connect_info::ConnectInfo;
use axum::http::{Request, StatusCode};
use tower::ServiceExt; // for `oneshot`
@ -23,10 +24,21 @@ use tempfile::TempDir;
/// Create a test VolumeServerState with a temp directory, a single disk
/// location, and one pre-created volume (VolumeId 1).
fn test_state() -> (Arc<VolumeServerState>, TempDir) {
test_state_with_signing_key(Vec::new())
test_state_with_guard(Vec::new(), Vec::new())
}
fn test_state_with_signing_key(signing_key: Vec<u8>) -> (Arc<VolumeServerState>, TempDir) {
test_state_with_guard(Vec::new(), signing_key)
}
fn test_state_with_whitelist(whitelist: Vec<String>) -> (Arc<VolumeServerState>, TempDir) {
test_state_with_guard(whitelist, Vec::new())
}
fn test_state_with_guard(
whitelist: Vec<String>,
signing_key: Vec<u8>,
) -> (Arc<VolumeServerState>, TempDir) {
let tmp = TempDir::new().expect("failed to create temp dir");
let dir = tmp.path().to_str().unwrap();
@ -53,7 +65,13 @@ fn test_state_with_signing_key(signing_key: Vec<u8>) -> (Arc<VolumeServerState>,
)
.expect("failed to create volume");
let guard = Guard::new(&[], SigningKey(signing_key), 0, SigningKey(vec![]), 0);
let guard = Guard::new(
&whitelist,
SigningKey(signing_key),
0,
SigningKey(vec![]),
0,
);
let state = Arc::new(VolumeServerState {
store: RwLock::new(store),
guard: RwLock::new(guard),
@ -109,6 +127,15 @@ async fn body_bytes(response: axum::response::Response) -> Vec<u8> {
.to_vec()
}
fn with_remote_addr(request: Request<Body>, remote_addr: &str) -> Request<Body> {
let mut request = request;
let remote_addr = remote_addr
.parse::<std::net::SocketAddr>()
.expect("invalid socket address");
request.extensions_mut().insert(ConnectInfo(remote_addr));
request
}
// ============================================================================
// 1. GET /healthz returns 200 when server is running
// ============================================================================
@ -226,6 +253,84 @@ async fn metrics_router_serves_metrics() {
assert_eq!(response.status(), StatusCode::OK);
}
#[tokio::test]
async fn admin_router_rejects_non_whitelisted_uploads() {
let (state, _tmp) = test_state_with_whitelist(vec!["127.0.0.1".to_string()]);
let app = build_admin_router(state);
let response = app
.oneshot(with_remote_addr(
Request::builder()
.method("POST")
.uri("/1,000000000000000001")
.body(Body::from("blocked"))
.unwrap(),
"10.0.0.9:12345",
))
.await
.unwrap();
assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
}
#[tokio::test]
async fn admin_router_rejects_non_whitelisted_deletes() {
let (state, _tmp) = test_state_with_whitelist(vec!["127.0.0.1".to_string()]);
let app = build_admin_router(state);
let response = app
.oneshot(with_remote_addr(
Request::builder()
.method("DELETE")
.uri("/1,000000000000000001")
.body(Body::empty())
.unwrap(),
"10.0.0.9:12345",
))
.await
.unwrap();
assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
}
#[tokio::test]
async fn admin_router_rejects_non_whitelisted_stats_routes() {
let (state, _tmp) = test_state_with_whitelist(vec!["127.0.0.1".to_string()]);
let app = build_admin_router_with_ui(state, true);
let response = app
.oneshot(with_remote_addr(
Request::builder()
.uri("/stats/counter")
.body(Body::empty())
.unwrap(),
"10.0.0.9:12345",
))
.await
.unwrap();
assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
}
#[tokio::test]
async fn admin_router_allows_whitelisted_stats_routes() {
let (state, _tmp) = test_state_with_whitelist(vec!["127.0.0.1".to_string()]);
let app = build_admin_router_with_ui(state, true);
let response = app
.oneshot(with_remote_addr(
Request::builder()
.uri("/stats/counter")
.body(Body::empty())
.unwrap(),
"127.0.0.1:12345",
))
.await
.unwrap();
assert_eq!(response.status(), StatusCode::OK);
}
// ============================================================================
// 4. POST writes data, then GET reads it back
// ============================================================================

Loading…
Cancel
Save