diff --git a/test/volume_server/http/admin_test.go b/test/volume_server/http/admin_test.go index be4445ebc..6dde9c20d 100644 --- a/test/volume_server/http/admin_test.go +++ b/test/volume_server/http/admin_test.go @@ -23,8 +23,6 @@ func TestAdminStatusAndHealthz(t *testing.T) { if err != nil { t.Fatalf("create status request: %v", err) } - statusReq.Header.Set(request_id.AmzRequestIDHeader, "test-request-id-1") - statusResp := framework.DoRequest(t, client, statusReq) statusBody := framework.ReadAllAndClose(t, statusResp) @@ -34,8 +32,8 @@ func TestAdminStatusAndHealthz(t *testing.T) { if got := statusResp.Header.Get("Server"); !strings.Contains(got, "SeaweedFS Volume") { t.Fatalf("expected /status Server header to contain SeaweedFS Volume, got %q", got) } - if got := statusResp.Header.Get(request_id.AmzRequestIDHeader); got != "test-request-id-1" { - t.Fatalf("expected echoed request id, got %q", got) + if got := statusResp.Header.Get(request_id.AmzRequestIDHeader); got == "" { + t.Fatal("expected server-generated request id in response header") } var payload map[string]interface{} @@ -49,7 +47,6 @@ func TestAdminStatusAndHealthz(t *testing.T) { } healthReq := mustNewRequest(t, http.MethodGet, cluster.VolumeAdminURL()+"/healthz") - healthReq.Header.Set(request_id.AmzRequestIDHeader, "test-request-id-2") healthResp := framework.DoRequest(t, client, healthReq) _ = framework.ReadAllAndClose(t, healthResp) if healthResp.StatusCode != http.StatusOK { @@ -58,8 +55,8 @@ func TestAdminStatusAndHealthz(t *testing.T) { if got := healthResp.Header.Get("Server"); !strings.Contains(got, "SeaweedFS Volume") { t.Fatalf("expected /healthz Server header to contain SeaweedFS Volume, got %q", got) } - if got := healthResp.Header.Get(request_id.AmzRequestIDHeader); got != "test-request-id-2" { - t.Fatalf("expected /healthz echoed request id, got %q", got) + if got := healthResp.Header.Get(request_id.AmzRequestIDHeader); got == "" { + t.Fatal("expected /healthz server-generated request id in response header") } uiResp := framework.DoRequest(t, client, mustNewRequest(t, http.MethodGet, cluster.VolumeAdminURL()+"/ui/index.html")) diff --git a/weed/iam/error_response_test.go b/weed/iam/error_response_test.go deleted file mode 100644 index 4e008c157..000000000 --- a/weed/iam/error_response_test.go +++ /dev/null @@ -1,34 +0,0 @@ -package iam - -import ( - "encoding/xml" - "strings" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestErrorResponseXMLUsesTopLevelRequestId(t *testing.T) { - errCode := "NoSuchEntity" - errMsg := "the requested IAM entity does not exist" - - resp := ErrorResponse{ - RequestId: "request-123", - } - resp.Error.Type = "Sender" - resp.Error.Code = &errCode - resp.Error.Message = &errMsg - - output, err := xml.Marshal(resp) - require.NoError(t, err) - - xmlString := string(output) - errorIndex := strings.Index(xmlString, "") - requestIDIndex := strings.Index(xmlString, "request-123") - - assert.NotEqual(t, -1, errorIndex, "Error should be present") - assert.NotEqual(t, -1, requestIDIndex, "RequestId should be present") - assert.NotContains(t, xmlString, "") - assert.Less(t, errorIndex, requestIDIndex, "RequestId should appear after Error at the root level") -} diff --git a/weed/iam/responses.go b/weed/iam/responses.go index ac78c5ba7..c64b3ce23 100644 --- a/weed/iam/responses.go +++ b/weed/iam/responses.go @@ -2,8 +2,6 @@ package iam import ( "encoding/xml" - "fmt" - "time" "github.com/aws/aws-sdk-go/service/iam" ) @@ -15,9 +13,14 @@ type CommonResponse struct { } `xml:"ResponseMetadata"` } -// SetRequestId sets a unique request ID based on current timestamp. -func (r *CommonResponse) SetRequestId() { - r.ResponseMetadata.RequestId = newRequestID() +// SetRequestId stores the request ID generated for the current HTTP request. +func (r *CommonResponse) SetRequestId(requestID string) { + r.ResponseMetadata.RequestId = requestID +} + +// RequestIDSetter is implemented by IAM responses that can carry a RequestId. +type RequestIDSetter interface { + SetRequestId(string) } // ListUsersResponse is the response for ListUsers action. @@ -187,6 +190,7 @@ type GetUserPolicyResponse struct { } // ErrorResponse is the IAM error response format. +// AWS IAM uses a bare at root level for errors, not . type ErrorResponse struct { XMLName xml.Name `xml:"https://iam.amazonaws.com/doc/2010-05-08/ ErrorResponse"` Error struct { @@ -196,13 +200,9 @@ type ErrorResponse struct { RequestId string `xml:"RequestId"` } -// SetRequestId sets a unique request ID based on current timestamp. -func (r *ErrorResponse) SetRequestId() { - r.RequestId = newRequestID() -} - -func newRequestID() string { - return fmt.Sprintf("%d", time.Now().UnixNano()) +// SetRequestId stores the request ID generated for the current HTTP request. +func (r *ErrorResponse) SetRequestId(requestID string) { + r.RequestId = requestID } // Error represents an IAM API error with code and underlying error. diff --git a/weed/iam/responses_test.go b/weed/iam/responses_test.go index daf2afc85..d5ecdfbba 100644 --- a/weed/iam/responses_test.go +++ b/weed/iam/responses_test.go @@ -11,7 +11,7 @@ import ( func TestListUsersResponseXMLOrdering(t *testing.T) { resp := ListUsersResponse{} - resp.SetRequestId() + resp.SetRequestId("test-req-id") output, err := xml.Marshal(resp) require.NoError(t, err) @@ -22,5 +22,31 @@ func TestListUsersResponseXMLOrdering(t *testing.T) { assert.NotEqual(t, -1, listUsersResultIndex, "ListUsersResult should be present") assert.NotEqual(t, -1, responseMetadataIndex, "ResponseMetadata should be present") - assert.Less(t, listUsersResultIndex, responseMetadataIndex, "ListUsersResult should appear before ResponseMetadata") + assert.Less(t, listUsersResultIndex, responseMetadataIndex, + "ListUsersResult should appear before ResponseMetadata") +} + +func TestErrorResponseXMLUsesTopLevelRequestId(t *testing.T) { + errCode := "NoSuchEntity" + errMsg := "the requested IAM entity does not exist" + + resp := ErrorResponse{} + resp.Error.Type = "Sender" + resp.Error.Code = &errCode + resp.Error.Message = &errMsg + resp.SetRequestId("request-123") + + output, err := xml.Marshal(resp) + require.NoError(t, err) + + xmlString := string(output) + errorIndex := strings.Index(xmlString, "") + requestIDIndex := strings.Index(xmlString, "request-123") + + assert.NotEqual(t, -1, errorIndex, "Error should be present") + assert.NotEqual(t, -1, requestIDIndex, "RequestId should be present") + assert.NotContains(t, xmlString, "", + "ErrorResponse should use bare RequestId, not ResponseMetadata wrapper") + assert.Less(t, errorIndex, requestIDIndex, + "RequestId should appear after Error at the root level") } diff --git a/weed/iamapi/iamapi_handlers.go b/weed/iamapi/iamapi_handlers.go index f187c9067..1ab3c1282 100644 --- a/weed/iamapi/iamapi_handlers.go +++ b/weed/iamapi/iamapi_handlers.go @@ -8,20 +8,20 @@ import ( "github.com/seaweedfs/seaweedfs/weed/s3api/s3err" ) -func newErrorResponse(errCode string, errMsg string) ErrorResponse { +func newErrorResponse(errCode string, errMsg string, requestID string) ErrorResponse { errorResp := ErrorResponse{} errorResp.Error.Type = "Sender" errorResp.Error.Code = &errCode errorResp.Error.Message = &errMsg - errorResp.SetRequestId() + errorResp.SetRequestId(requestID) return errorResp } -func writeIamErrorResponse(w http.ResponseWriter, r *http.Request, iamError *IamError) { - +func writeIamErrorResponse(w http.ResponseWriter, r *http.Request, reqID string, iamError *IamError) { if iamError == nil { - // Do nothing if there is no error - glog.Errorf("No error found") + glog.Errorf("writeIamErrorResponse called with nil error") + internalResp := newErrorResponse(iam.ErrCodeServiceFailureException, "Internal server error", reqID) + s3err.WriteXMLResponse(w, r, http.StatusInternalServerError, internalResp) return } @@ -29,8 +29,8 @@ func writeIamErrorResponse(w http.ResponseWriter, r *http.Request, iamError *Iam errMsg := iamError.Error.Error() glog.Errorf("Response %+v", errMsg) - errorResp := newErrorResponse(errCode, errMsg) - internalErrorResponse := newErrorResponse(iam.ErrCodeServiceFailureException, "Internal server error") + errorResp := newErrorResponse(errCode, errMsg, reqID) + internalErrorResponse := newErrorResponse(iam.ErrCodeServiceFailureException, "Internal server error", reqID) switch errCode { case iam.ErrCodeNoSuchEntityException: diff --git a/weed/iamapi/iamapi_management_handlers.go b/weed/iamapi/iamapi_management_handlers.go index e56ab45b2..c2141e6cd 100644 --- a/weed/iamapi/iamapi_management_handlers.go +++ b/weed/iamapi/iamapi_management_handlers.go @@ -19,6 +19,7 @@ import ( "github.com/seaweedfs/seaweedfs/weed/pb/iam_pb" "github.com/seaweedfs/seaweedfs/weed/s3api/policy_engine" "github.com/seaweedfs/seaweedfs/weed/s3api/s3err" + "github.com/seaweedfs/seaweedfs/weed/util/request_id" ) // Constants from shared package @@ -162,14 +163,16 @@ func validateAccessKeyStatus(status string) error { } } -func (iama *IamApiServer) ListUsers(s3cfg *iam_pb.S3ApiConfiguration, values url.Values) (resp ListUsersResponse) { +func (iama *IamApiServer) ListUsers(s3cfg *iam_pb.S3ApiConfiguration, values url.Values) (resp *ListUsersResponse) { + resp = &ListUsersResponse{} for _, ident := range s3cfg.Identities { resp.ListUsersResult.Users = append(resp.ListUsersResult.Users, &iam.User{UserName: &ident.Name}) } return resp } -func (iama *IamApiServer) ListAccessKeys(s3cfg *iam_pb.S3ApiConfiguration, values url.Values) (resp ListAccessKeysResponse) { +func (iama *IamApiServer) ListAccessKeys(s3cfg *iam_pb.S3ApiConfiguration, values url.Values) (resp *ListAccessKeysResponse) { + resp = &ListAccessKeysResponse{} userName := values.Get("UserName") for _, ident := range s3cfg.Identities { if userName != "" && userName != ident.Name { @@ -190,16 +193,31 @@ func (iama *IamApiServer) ListAccessKeys(s3cfg *iam_pb.S3ApiConfiguration, value return resp } -func (iama *IamApiServer) CreateUser(s3cfg *iam_pb.S3ApiConfiguration, values url.Values) (resp CreateUserResponse) { +func (iama *IamApiServer) CreateUser(s3cfg *iam_pb.S3ApiConfiguration, values url.Values) (resp *CreateUserResponse) { + resp = &CreateUserResponse{} userName := values.Get("UserName") resp.CreateUserResult.User.UserName = &userName s3cfg.Identities = append(s3cfg.Identities, &iam_pb.Identity{Name: userName}) return resp } -func (iama *IamApiServer) DeleteUser(s3cfg *iam_pb.S3ApiConfiguration, userName string) (resp DeleteUserResponse, err *IamError) { +func (iama *IamApiServer) DeleteUser(s3cfg *iam_pb.S3ApiConfiguration, userName string) (resp *DeleteUserResponse, err *IamError) { + resp = &DeleteUserResponse{} for i, ident := range s3cfg.Identities { if userName == ident.Name { + // Clean up any inline policies stored for this user + policies := Policies{} + if pErr := iama.s3ApiConfig.GetPolicies(&policies); pErr != nil && !errors.Is(pErr, filer_pb.ErrNotFound) { + return resp, &IamError{Code: iam.ErrCodeServiceFailureException, Error: pErr} + } + if policies.InlinePolicies != nil { + if _, exists := policies.InlinePolicies[userName]; exists { + delete(policies.InlinePolicies, userName) + if pErr := iama.s3ApiConfig.PutPolicies(&policies); pErr != nil { + return resp, &IamError{Code: iam.ErrCodeServiceFailureException, Error: pErr} + } + } + } s3cfg.Identities = append(s3cfg.Identities[:i], s3cfg.Identities[i+1:]...) return resp, nil } @@ -207,7 +225,8 @@ func (iama *IamApiServer) DeleteUser(s3cfg *iam_pb.S3ApiConfiguration, userName return resp, &IamError{Code: iam.ErrCodeNoSuchEntityException, Error: fmt.Errorf(USER_DOES_NOT_EXIST, userName)} } -func (iama *IamApiServer) GetUser(s3cfg *iam_pb.S3ApiConfiguration, userName string) (resp GetUserResponse, err *IamError) { +func (iama *IamApiServer) GetUser(s3cfg *iam_pb.S3ApiConfiguration, userName string) (resp *GetUserResponse, err *IamError) { + resp = &GetUserResponse{} for _, ident := range s3cfg.Identities { if userName == ident.Name { resp.GetUserResult.User = iam.User{UserName: &ident.Name} @@ -217,13 +236,28 @@ func (iama *IamApiServer) GetUser(s3cfg *iam_pb.S3ApiConfiguration, userName str return resp, &IamError{Code: iam.ErrCodeNoSuchEntityException, Error: fmt.Errorf(USER_DOES_NOT_EXIST, userName)} } -func (iama *IamApiServer) UpdateUser(s3cfg *iam_pb.S3ApiConfiguration, values url.Values) (resp UpdateUserResponse, err *IamError) { +func (iama *IamApiServer) UpdateUser(s3cfg *iam_pb.S3ApiConfiguration, values url.Values) (resp *UpdateUserResponse, err *IamError) { + resp = &UpdateUserResponse{} userName := values.Get("UserName") newUserName := values.Get("NewUserName") if newUserName != "" { for _, ident := range s3cfg.Identities { if userName == ident.Name { ident.Name = newUserName + // Move any inline policies from old username to new username + policies := Policies{} + if pErr := iama.s3ApiConfig.GetPolicies(&policies); pErr != nil && !errors.Is(pErr, filer_pb.ErrNotFound) { + return resp, &IamError{Code: iam.ErrCodeServiceFailureException, Error: pErr} + } + if policies.InlinePolicies != nil { + if userPolicies, exists := policies.InlinePolicies[userName]; exists { + delete(policies.InlinePolicies, userName) + policies.InlinePolicies[newUserName] = userPolicies + if pErr := iama.s3ApiConfig.PutPolicies(&policies); pErr != nil { + return resp, &IamError{Code: iam.ErrCodeServiceFailureException, Error: pErr} + } + } + } return resp, nil } } @@ -241,12 +275,13 @@ func GetPolicyDocument(policy *string) (policy_engine.PolicyDocument, error) { return policyDocument, nil } -func (iama *IamApiServer) CreatePolicy(s3cfg *iam_pb.S3ApiConfiguration, values url.Values) (resp CreatePolicyResponse, iamError *IamError) { +func (iama *IamApiServer) CreatePolicy(s3cfg *iam_pb.S3ApiConfiguration, values url.Values) (resp *CreatePolicyResponse, iamError *IamError) { + resp = &CreatePolicyResponse{} policyName := values.Get("PolicyName") policyDocumentString := values.Get("PolicyDocument") policyDocument, err := GetPolicyDocument(&policyDocumentString) if err != nil { - return CreatePolicyResponse{}, &IamError{Code: iam.ErrCodeMalformedPolicyDocumentException, Error: err} + return resp, &IamError{Code: iam.ErrCodeMalformedPolicyDocumentException, Error: err} } policyId := Hash(&policyName) arn := fmt.Sprintf("arn:aws:iam:::policy/%s", policyName) @@ -274,16 +309,17 @@ type IamError struct { } // https://docs.aws.amazon.com/IAM/latest/APIReference/API_PutUserPolicy.html -func (iama *IamApiServer) PutUserPolicy(s3cfg *iam_pb.S3ApiConfiguration, values url.Values) (resp PutUserPolicyResponse, iamError *IamError) { +func (iama *IamApiServer) PutUserPolicy(s3cfg *iam_pb.S3ApiConfiguration, values url.Values) (resp *PutUserPolicyResponse, iamError *IamError) { + resp = &PutUserPolicyResponse{} userName := values.Get("UserName") policyName := values.Get("PolicyName") policyDocumentString := values.Get("PolicyDocument") policyDocument, err := GetPolicyDocument(&policyDocumentString) if err != nil { - return PutUserPolicyResponse{}, &IamError{Code: iam.ErrCodeMalformedPolicyDocumentException, Error: err} + return resp, &IamError{Code: iam.ErrCodeMalformedPolicyDocumentException, Error: err} } if _, err := GetActions(&policyDocument); err != nil { - return PutUserPolicyResponse{}, &IamError{Code: iam.ErrCodeMalformedPolicyDocumentException, Error: err} + return resp, &IamError{Code: iam.ErrCodeMalformedPolicyDocumentException, Error: err} } // Verify the user exists before persisting the policy @@ -295,20 +331,20 @@ func (iama *IamApiServer) PutUserPolicy(s3cfg *iam_pb.S3ApiConfiguration, values } } if targetIdent == nil { - return PutUserPolicyResponse{}, &IamError{Code: iam.ErrCodeNoSuchEntityException, Error: fmt.Errorf("the user with name %s cannot be found", userName)} + return resp, &IamError{Code: iam.ErrCodeNoSuchEntityException, Error: fmt.Errorf("the user with name %s cannot be found", userName)} } // Persist inline policy to storage using per-user indexed structure policies := Policies{} if err = iama.s3ApiConfig.GetPolicies(&policies); err != nil && !errors.Is(err, filer_pb.ErrNotFound) { - return PutUserPolicyResponse{}, &IamError{Code: iam.ErrCodeServiceFailureException, Error: err} + return resp, &IamError{Code: iam.ErrCodeServiceFailureException, Error: err} } userPolicies := policies.getOrCreateUserPolicies(userName) userPolicies[policyName] = policyDocument if err = iama.s3ApiConfig.PutPolicies(&policies); err != nil { - return PutUserPolicyResponse{}, &IamError{Code: iam.ErrCodeServiceFailureException, Error: err} + return resp, &IamError{Code: iam.ErrCodeServiceFailureException, Error: err} } // Recompute aggregated actions (inline + managed) @@ -321,7 +357,8 @@ func (iama *IamApiServer) PutUserPolicy(s3cfg *iam_pb.S3ApiConfiguration, values return resp, nil } -func (iama *IamApiServer) GetUserPolicy(s3cfg *iam_pb.S3ApiConfiguration, values url.Values) (resp GetUserPolicyResponse, err *IamError) { +func (iama *IamApiServer) GetUserPolicy(s3cfg *iam_pb.S3ApiConfiguration, values url.Values) (resp *GetUserPolicyResponse, err *IamError) { + resp = &GetUserPolicyResponse{} userName := values.Get("UserName") policyName := values.Get("PolicyName") for _, ident := range s3cfg.Identities { @@ -404,7 +441,8 @@ func (iama *IamApiServer) GetUserPolicy(s3cfg *iam_pb.S3ApiConfiguration, values } // DeleteUserPolicy removes the inline policy from a user (clears their actions). -func (iama *IamApiServer) DeleteUserPolicy(s3cfg *iam_pb.S3ApiConfiguration, values url.Values) (resp DeleteUserPolicyResponse, err *IamError) { +func (iama *IamApiServer) DeleteUserPolicy(s3cfg *iam_pb.S3ApiConfiguration, values url.Values) (resp *DeleteUserPolicyResponse, err *IamError) { + resp = &DeleteUserPolicyResponse{} userName := values.Get("UserName") policyName := values.Get("PolicyName") @@ -447,7 +485,8 @@ func (iama *IamApiServer) DeleteUserPolicy(s3cfg *iam_pb.S3ApiConfiguration, val } // GetPolicy retrieves a managed policy by ARN. -func (iama *IamApiServer) GetPolicy(s3cfg *iam_pb.S3ApiConfiguration, values url.Values) (resp GetPolicyResponse, iamError *IamError) { +func (iama *IamApiServer) GetPolicy(s3cfg *iam_pb.S3ApiConfiguration, values url.Values) (resp *GetPolicyResponse, iamError *IamError) { + resp = &GetPolicyResponse{} policyArn := values.Get("PolicyArn") policyName, iamError := parsePolicyArn(policyArn) if iamError != nil { @@ -472,7 +511,8 @@ func (iama *IamApiServer) GetPolicy(s3cfg *iam_pb.S3ApiConfiguration, values url // DeletePolicy removes a managed policy. Rejects deletion if the policy is still attached to any user // (matching AWS IAM behavior: must detach before deleting). -func (iama *IamApiServer) DeletePolicy(s3cfg *iam_pb.S3ApiConfiguration, values url.Values) (resp DeletePolicyResponse, iamError *IamError) { +func (iama *IamApiServer) DeletePolicy(s3cfg *iam_pb.S3ApiConfiguration, values url.Values) (resp *DeletePolicyResponse, iamError *IamError) { + resp = &DeletePolicyResponse{} policyArn := values.Get("PolicyArn") policyName, iamError := parsePolicyArn(policyArn) if iamError != nil { @@ -509,7 +549,8 @@ func (iama *IamApiServer) DeletePolicy(s3cfg *iam_pb.S3ApiConfiguration, values } // ListPolicies lists all managed policies. -func (iama *IamApiServer) ListPolicies(s3cfg *iam_pb.S3ApiConfiguration, values url.Values) (resp ListPoliciesResponse, iamError *IamError) { +func (iama *IamApiServer) ListPolicies(s3cfg *iam_pb.S3ApiConfiguration, values url.Values) (resp *ListPoliciesResponse, iamError *IamError) { + resp = &ListPoliciesResponse{} policies := Policies{} if err := iama.s3ApiConfig.GetPolicies(&policies); err != nil && !errors.Is(err, filer_pb.ErrNotFound) { return resp, &IamError{Code: iam.ErrCodeServiceFailureException, Error: err} @@ -529,7 +570,8 @@ func (iama *IamApiServer) ListPolicies(s3cfg *iam_pb.S3ApiConfiguration, values } // AttachUserPolicy attaches a managed policy to a user. -func (iama *IamApiServer) AttachUserPolicy(s3cfg *iam_pb.S3ApiConfiguration, values url.Values) (resp AttachUserPolicyResponse, iamError *IamError) { +func (iama *IamApiServer) AttachUserPolicy(s3cfg *iam_pb.S3ApiConfiguration, values url.Values) (resp *AttachUserPolicyResponse, iamError *IamError) { + resp = &AttachUserPolicyResponse{} userName := values.Get("UserName") policyArn := values.Get("PolicyArn") policyName, iamError := parsePolicyArn(policyArn) @@ -574,7 +616,8 @@ func (iama *IamApiServer) AttachUserPolicy(s3cfg *iam_pb.S3ApiConfiguration, val } // DetachUserPolicy detaches a managed policy from a user. -func (iama *IamApiServer) DetachUserPolicy(s3cfg *iam_pb.S3ApiConfiguration, values url.Values) (resp DetachUserPolicyResponse, iamError *IamError) { +func (iama *IamApiServer) DetachUserPolicy(s3cfg *iam_pb.S3ApiConfiguration, values url.Values) (resp *DetachUserPolicyResponse, iamError *IamError) { + resp = &DetachUserPolicyResponse{} userName := values.Get("UserName") policyArn := values.Get("PolicyArn") policyName, iamError := parsePolicyArn(policyArn) @@ -622,7 +665,8 @@ func (iama *IamApiServer) DetachUserPolicy(s3cfg *iam_pb.S3ApiConfiguration, val } // ListAttachedUserPolicies lists the managed policies attached to a user. -func (iama *IamApiServer) ListAttachedUserPolicies(s3cfg *iam_pb.S3ApiConfiguration, values url.Values) (resp ListAttachedUserPoliciesResponse, iamError *IamError) { +func (iama *IamApiServer) ListAttachedUserPolicies(s3cfg *iam_pb.S3ApiConfiguration, values url.Values) (resp *ListAttachedUserPoliciesResponse, iamError *IamError) { + resp = &ListAttachedUserPoliciesResponse{} userName := values.Get("UserName") for _, ident := range s3cfg.Identities { if ident.Name != userName { @@ -714,7 +758,8 @@ func GetActions(policy *policy_engine.PolicyDocument) ([]string, error) { return actions, nil } -func (iama *IamApiServer) CreateAccessKey(s3cfg *iam_pb.S3ApiConfiguration, values url.Values) (resp CreateAccessKeyResponse, iamErr *IamError) { +func (iama *IamApiServer) CreateAccessKey(s3cfg *iam_pb.S3ApiConfiguration, values url.Values) (resp *CreateAccessKeyResponse, iamErr *IamError) { + resp = &CreateAccessKeyResponse{} userName := values.Get("UserName") status := iam.StatusTypeActive @@ -758,7 +803,8 @@ func (iama *IamApiServer) CreateAccessKey(s3cfg *iam_pb.S3ApiConfiguration, valu } // UpdateAccessKey updates the status of an access key (Active or Inactive). -func (iama *IamApiServer) UpdateAccessKey(s3cfg *iam_pb.S3ApiConfiguration, values url.Values) (resp UpdateAccessKeyResponse, err *IamError) { +func (iama *IamApiServer) UpdateAccessKey(s3cfg *iam_pb.S3ApiConfiguration, values url.Values) (resp *UpdateAccessKeyResponse, err *IamError) { + resp = &UpdateAccessKeyResponse{} userName := values.Get("UserName") accessKeyId := values.Get("AccessKeyId") status := values.Get("Status") @@ -788,7 +834,8 @@ func (iama *IamApiServer) UpdateAccessKey(s3cfg *iam_pb.S3ApiConfiguration, valu return resp, &IamError{Code: iam.ErrCodeNoSuchEntityException, Error: fmt.Errorf(USER_DOES_NOT_EXIST, userName)} } -func (iama *IamApiServer) DeleteAccessKey(s3cfg *iam_pb.S3ApiConfiguration, values url.Values) (resp DeleteAccessKeyResponse) { +func (iama *IamApiServer) DeleteAccessKey(s3cfg *iam_pb.S3ApiConfiguration, values url.Values) (resp *DeleteAccessKeyResponse) { + resp = &DeleteAccessKeyResponse{} userName := values.Get("UserName") accessKeyId := values.Get("AccessKeyId") for _, ident := range s3cfg.Identities { @@ -859,6 +906,8 @@ func (iama *IamApiServer) DoActions(w http.ResponseWriter, r *http.Request) { policyLock.Lock() defer policyLock.Unlock() + r, reqID := request_id.Ensure(r) + if err := r.ParseForm(); err != nil { s3err.WriteErrorResponse(w, r, s3err.ErrInvalidRequest) return @@ -871,8 +920,7 @@ func (iama *IamApiServer) DoActions(w http.ResponseWriter, r *http.Request) { } glog.V(4).Infof("DoActions: %+v", values) - var response interface{} - var iamError *IamError + var response iamlib.RequestIDSetter changed := true switch r.Form.Get("Action") { case "ListUsers": @@ -886,32 +934,35 @@ func (iama *IamApiServer) DoActions(w http.ResponseWriter, r *http.Request) { response = iama.CreateUser(s3cfg, values) case "GetUser": userName := values.Get("UserName") - response, iamError = iama.GetUser(s3cfg, userName) - if iamError != nil { - writeIamErrorResponse(w, r, iamError) + var err *IamError + response, err = iama.GetUser(s3cfg, userName) + if err != nil { + writeIamErrorResponse(w, r, reqID, err) return } changed = false case "UpdateUser": - response, iamError = iama.UpdateUser(s3cfg, values) - if iamError != nil { - glog.Errorf("UpdateUser: %+v", iamError.Error) - s3err.WriteErrorResponse(w, r, s3err.ErrInvalidRequest) + var err *IamError + response, err = iama.UpdateUser(s3cfg, values) + if err != nil { + writeIamErrorResponse(w, r, reqID, err) return } case "DeleteUser": userName := values.Get("UserName") - response, iamError = iama.DeleteUser(s3cfg, userName) - if iamError != nil { - writeIamErrorResponse(w, r, iamError) + var err *IamError + response, err = iama.DeleteUser(s3cfg, userName) + if err != nil { + writeIamErrorResponse(w, r, reqID, err) return } case "CreateAccessKey": iama.handleImplicitUsername(r, values) - response, iamError = iama.CreateAccessKey(s3cfg, values) - if iamError != nil { - glog.Errorf("CreateAccessKey: %+v", iamError.Error) - writeIamErrorResponse(w, r, iamError) + var err *IamError + response, err = iama.CreateAccessKey(s3cfg, values) + if err != nil { + glog.Errorf("CreateAccessKey: %+v", err.Error) + writeIamErrorResponse(w, r, reqID, err) return } case "DeleteAccessKey": @@ -919,85 +970,94 @@ func (iama *IamApiServer) DoActions(w http.ResponseWriter, r *http.Request) { response = iama.DeleteAccessKey(s3cfg, values) case "UpdateAccessKey": iama.handleImplicitUsername(r, values) - response, iamError = iama.UpdateAccessKey(s3cfg, values) - if iamError != nil { - writeIamErrorResponse(w, r, iamError) + var err *IamError + response, err = iama.UpdateAccessKey(s3cfg, values) + if err != nil { + writeIamErrorResponse(w, r, reqID, err) return } case "CreatePolicy": - response, iamError = iama.CreatePolicy(s3cfg, values) - if iamError != nil { - glog.Errorf("CreatePolicy: %+v", iamError.Error) - s3err.WriteErrorResponse(w, r, s3err.ErrInvalidRequest) + var err *IamError + response, err = iama.CreatePolicy(s3cfg, values) + if err != nil { + writeIamErrorResponse(w, r, reqID, err) return } // CreatePolicy persists the policy document via iama.s3ApiConfig.PutPolicies(). // The `changed` flag is false because this does not modify the main s3cfg.Identities configuration. changed = false case "PutUserPolicy": - var iamError *IamError - response, iamError = iama.PutUserPolicy(s3cfg, values) - if iamError != nil { - glog.Errorf("PutUserPolicy: %+v", iamError.Error) - - writeIamErrorResponse(w, r, iamError) + var err *IamError + response, err = iama.PutUserPolicy(s3cfg, values) + if err != nil { + glog.Errorf("PutUserPolicy: %+v", err.Error) + writeIamErrorResponse(w, r, reqID, err) return } case "GetUserPolicy": - response, iamError = iama.GetUserPolicy(s3cfg, values) - if iamError != nil { - writeIamErrorResponse(w, r, iamError) + var err *IamError + response, err = iama.GetUserPolicy(s3cfg, values) + if err != nil { + writeIamErrorResponse(w, r, reqID, err) return } changed = false case "DeleteUserPolicy": - if response, iamError = iama.DeleteUserPolicy(s3cfg, values); iamError != nil { - writeIamErrorResponse(w, r, iamError) + var err *IamError + response, err = iama.DeleteUserPolicy(s3cfg, values) + if err != nil { + writeIamErrorResponse(w, r, reqID, err) return } case "GetPolicy": - response, iamError = iama.GetPolicy(s3cfg, values) - if iamError != nil { - writeIamErrorResponse(w, r, iamError) + var err *IamError + response, err = iama.GetPolicy(s3cfg, values) + if err != nil { + writeIamErrorResponse(w, r, reqID, err) return } changed = false case "DeletePolicy": - response, iamError = iama.DeletePolicy(s3cfg, values) - if iamError != nil { - writeIamErrorResponse(w, r, iamError) + var err *IamError + response, err = iama.DeletePolicy(s3cfg, values) + if err != nil { + writeIamErrorResponse(w, r, reqID, err) return } changed = false case "ListPolicies": - response, iamError = iama.ListPolicies(s3cfg, values) - if iamError != nil { - writeIamErrorResponse(w, r, iamError) + var err *IamError + response, err = iama.ListPolicies(s3cfg, values) + if err != nil { + writeIamErrorResponse(w, r, reqID, err) return } changed = false case "AttachUserPolicy": - response, iamError = iama.AttachUserPolicy(s3cfg, values) - if iamError != nil { - writeIamErrorResponse(w, r, iamError) + var err *IamError + response, err = iama.AttachUserPolicy(s3cfg, values) + if err != nil { + writeIamErrorResponse(w, r, reqID, err) return } case "DetachUserPolicy": - response, iamError = iama.DetachUserPolicy(s3cfg, values) - if iamError != nil { - writeIamErrorResponse(w, r, iamError) + var err *IamError + response, err = iama.DetachUserPolicy(s3cfg, values) + if err != nil { + writeIamErrorResponse(w, r, reqID, err) return } case "ListAttachedUserPolicies": - response, iamError = iama.ListAttachedUserPolicies(s3cfg, values) - if iamError != nil { - writeIamErrorResponse(w, r, iamError) + var err *IamError + response, err = iama.ListAttachedUserPolicies(s3cfg, values) + if err != nil { + writeIamErrorResponse(w, r, reqID, err) return } changed = false default: errNotImplemented := s3err.GetAPIError(s3err.ErrNotImplemented) - errorResponse := newErrorResponse(errNotImplemented.Code, errNotImplemented.Description) + errorResponse := newErrorResponse(errNotImplemented.Code, errNotImplemented.Description, reqID) s3err.WriteXMLResponse(w, r, errNotImplemented.HTTPStatusCode, errorResponse) return } @@ -1005,7 +1065,7 @@ func (iama *IamApiServer) DoActions(w http.ResponseWriter, r *http.Request) { err := iama.s3ApiConfig.PutS3ApiConfiguration(s3cfg) if err != nil { var iamError = IamError{Code: iam.ErrCodeServiceFailureException, Error: err} - writeIamErrorResponse(w, r, &iamError) + writeIamErrorResponse(w, r, reqID, &iamError) return } // Reload in-memory identity maps so subsequent LookupByAccessKey calls @@ -1017,5 +1077,6 @@ func (iama *IamApiServer) DoActions(w http.ResponseWriter, r *http.Request) { } } } + response.SetRequestId(reqID) s3err.WriteXMLResponse(w, r, http.StatusOK, response) } diff --git a/weed/iamapi/iamapi_server.go b/weed/iamapi/iamapi_server.go index a7fe6da07..26eaf289f 100644 --- a/weed/iamapi/iamapi_server.go +++ b/weed/iamapi/iamapi_server.go @@ -21,6 +21,7 @@ import ( . "github.com/seaweedfs/seaweedfs/weed/s3api/s3_constants" "github.com/seaweedfs/seaweedfs/weed/s3api/s3err" "github.com/seaweedfs/seaweedfs/weed/util" + "github.com/seaweedfs/seaweedfs/weed/util/request_id" "github.com/seaweedfs/seaweedfs/weed/wdclient" "google.golang.org/grpc" "google.golang.org/protobuf/proto" @@ -117,6 +118,7 @@ func NewIamApiServerWithStore(router *mux.Router, option *IamServerOption, expli func (iama *IamApiServer) registerRouter(router *mux.Router) { // API Router apiRouter := router.PathPrefix("/").Subrouter() + apiRouter.Use(request_id.Middleware) // ListBuckets // apiRouter.Methods("GET").Path("/").HandlerFunc(track(s3a.iam.Auth(s3a.ListBucketsHandler, ACTION_ADMIN), "LIST")) diff --git a/weed/iamapi/iamapi_test.go b/weed/iamapi/iamapi_test.go index 9177c7d4a..509fc5745 100644 --- a/weed/iamapi/iamapi_test.go +++ b/weed/iamapi/iamapi_test.go @@ -17,6 +17,7 @@ import ( "github.com/seaweedfs/seaweedfs/weed/pb/iam_pb" "github.com/seaweedfs/seaweedfs/weed/s3api" "github.com/seaweedfs/seaweedfs/weed/s3api/policy_engine" + "github.com/seaweedfs/seaweedfs/weed/util/request_id" "github.com/stretchr/testify/assert" ) @@ -73,6 +74,21 @@ func TestListUsers(t *testing.T) { assert.Equal(t, http.StatusOK, response.Code) } +func TestListUsersRequestIdMatchesResponseHeader(t *testing.T) { + params := &iam.ListUsersInput{} + req, _ := iam.New(session.New()).ListUsersRequest(params) + _ = req.Build() + + out := ListUsersResponse{} + response, err := executeRequest(req.HTTPRequest, out) + assert.Equal(t, nil, err) + assert.Equal(t, http.StatusOK, response.Code) + + headerRequestID := response.Header().Get(request_id.AmzRequestIDHeader) + assert.NotEmpty(t, headerRequestID) + assert.Equal(t, headerRequestID, extractRequestID(response)) +} + func TestListAccessKeys(t *testing.T) { svc := iam.New(session.New()) params := &iam.ListAccessKeysInput{} @@ -246,6 +262,7 @@ func TestPutUserPolicyError(t *testing.T) { assert.Equal(t, expectedCode, code) assert.Contains(t, response.Body.String(), "") assert.NotContains(t, response.Body.String(), "") + assert.Equal(t, response.Header().Get(request_id.AmzRequestIDHeader), extractRequestID(response)) } func extractErrorCodeAndMessage(response *httptest.ResponseRecorder) (string, string) { @@ -257,6 +274,15 @@ func extractErrorCodeAndMessage(response *httptest.ResponseRecorder) (string, st return code, message } +func extractRequestID(response *httptest.ResponseRecorder) string { + re := regexp.MustCompile(`([^<]+)`) + matches := re.FindStringSubmatch(response.Body.String()) + if len(matches) < 2 { + return "" + } + return matches[1] +} + func TestGetUserPolicy(t *testing.T) { userName := aws.String("Test") params := &iam.GetUserPolicyInput{UserName: userName, PolicyName: aws.String("S3-read-only-example-bucket")} diff --git a/weed/s3api/s3api_embedded_iam.go b/weed/s3api/s3api_embedded_iam.go index 97123e413..5b65382be 100644 --- a/weed/s3api/s3api_embedded_iam.go +++ b/weed/s3api/s3api_embedded_iam.go @@ -26,6 +26,7 @@ import ( "github.com/seaweedfs/seaweedfs/weed/s3api/s3_constants" . "github.com/seaweedfs/seaweedfs/weed/s3api/s3_constants" "github.com/seaweedfs/seaweedfs/weed/s3api/s3err" + "github.com/seaweedfs/seaweedfs/weed/util/request_id" "google.golang.org/protobuf/proto" ) @@ -157,18 +158,20 @@ const ( iamAccessKeyStatusInactive = iamlib.AccessKeyStatusInactive ) -func newIamErrorResponse(errCode string, errMsg string) iamErrorResponse { +func newIamErrorResponse(errCode string, errMsg string, requestID string) iamErrorResponse { errorResp := iamErrorResponse{} errorResp.Error.Type = "Sender" errorResp.Error.Code = &errCode errorResp.Error.Message = &errMsg - errorResp.SetRequestId() + errorResp.SetRequestId(requestID) return errorResp } -func (e *EmbeddedIamApi) writeIamErrorResponse(w http.ResponseWriter, r *http.Request, iamErr *iamError) { +func (e *EmbeddedIamApi) writeIamErrorResponse(w http.ResponseWriter, r *http.Request, reqID string, iamErr *iamError) { if iamErr == nil { - glog.Errorf("No error found") + glog.Errorf("writeIamErrorResponse called with nil error") + internalResp := newIamErrorResponse(iam.ErrCodeServiceFailureException, "Internal server error", reqID) + s3err.WriteXMLResponse(w, r, http.StatusInternalServerError, internalResp) return } @@ -176,8 +179,8 @@ func (e *EmbeddedIamApi) writeIamErrorResponse(w http.ResponseWriter, r *http.Re errMsg := iamErr.Error.Error() glog.Errorf("IAM Response %+v", errMsg) - errorResp := newIamErrorResponse(errCode, errMsg) - internalErrorResponse := newIamErrorResponse(iam.ErrCodeServiceFailureException, "Internal server error") + errorResp := newIamErrorResponse(errCode, errMsg, reqID) + internalErrorResponse := newIamErrorResponse(iam.ErrCodeServiceFailureException, "Internal server error", reqID) switch errCode { case iam.ErrCodeNoSuchEntityException: @@ -230,8 +233,8 @@ func (e *EmbeddedIamApi) ReloadConfiguration() error { } // ListUsers lists all IAM users. -func (e *EmbeddedIamApi) ListUsers(s3cfg *iam_pb.S3ApiConfiguration, values url.Values) iamListUsersResponse { - var resp iamListUsersResponse +func (e *EmbeddedIamApi) ListUsers(s3cfg *iam_pb.S3ApiConfiguration, values url.Values) *iamListUsersResponse { + resp := &iamListUsersResponse{} for _, ident := range s3cfg.Identities { resp.ListUsersResult.Users = append(resp.ListUsersResult.Users, &iam.User{UserName: &ident.Name}) } @@ -239,8 +242,8 @@ func (e *EmbeddedIamApi) ListUsers(s3cfg *iam_pb.S3ApiConfiguration, values url. } // ListAccessKeys lists access keys for a user. -func (e *EmbeddedIamApi) ListAccessKeys(s3cfg *iam_pb.S3ApiConfiguration, values url.Values) iamListAccessKeysResponse { - var resp iamListAccessKeysResponse +func (e *EmbeddedIamApi) ListAccessKeys(s3cfg *iam_pb.S3ApiConfiguration, values url.Values) *iamListAccessKeysResponse { + resp := &iamListAccessKeysResponse{} userName := values.Get("UserName") for _, ident := range s3cfg.Identities { if userName != "" && userName != ident.Name { @@ -265,8 +268,8 @@ func (e *EmbeddedIamApi) ListAccessKeys(s3cfg *iam_pb.S3ApiConfiguration, values } // CreateUser creates a new IAM user. -func (e *EmbeddedIamApi) CreateUser(s3cfg *iam_pb.S3ApiConfiguration, values url.Values) (iamCreateUserResponse, *iamError) { - var resp iamCreateUserResponse +func (e *EmbeddedIamApi) CreateUser(s3cfg *iam_pb.S3ApiConfiguration, values url.Values) (*iamCreateUserResponse, *iamError) { + resp := &iamCreateUserResponse{} userName := values.Get("UserName") // Validate UserName is not empty @@ -287,8 +290,8 @@ func (e *EmbeddedIamApi) CreateUser(s3cfg *iam_pb.S3ApiConfiguration, values url } // DeleteUser deletes an IAM user. -func (e *EmbeddedIamApi) DeleteUser(s3cfg *iam_pb.S3ApiConfiguration, userName string) (iamDeleteUserResponse, *iamError) { - var resp iamDeleteUserResponse +func (e *EmbeddedIamApi) DeleteUser(s3cfg *iam_pb.S3ApiConfiguration, userName string) (*iamDeleteUserResponse, *iamError) { + resp := &iamDeleteUserResponse{} for i, ident := range s3cfg.Identities { if userName == ident.Name { // AWS IAM behavior: prevent deletion if user has service accounts @@ -308,8 +311,8 @@ func (e *EmbeddedIamApi) DeleteUser(s3cfg *iam_pb.S3ApiConfiguration, userName s } // GetUser gets an IAM user. -func (e *EmbeddedIamApi) GetUser(s3cfg *iam_pb.S3ApiConfiguration, userName string) (iamGetUserResponse, *iamError) { - var resp iamGetUserResponse +func (e *EmbeddedIamApi) GetUser(s3cfg *iam_pb.S3ApiConfiguration, userName string) (*iamGetUserResponse, *iamError) { + resp := &iamGetUserResponse{} for _, ident := range s3cfg.Identities { if userName == ident.Name { resp.GetUserResult.User = iam.User{UserName: &ident.Name} @@ -320,8 +323,8 @@ func (e *EmbeddedIamApi) GetUser(s3cfg *iam_pb.S3ApiConfiguration, userName stri } // UpdateUser updates an IAM user. -func (e *EmbeddedIamApi) UpdateUser(s3cfg *iam_pb.S3ApiConfiguration, values url.Values) (iamUpdateUserResponse, *iamError) { - var resp iamUpdateUserResponse +func (e *EmbeddedIamApi) UpdateUser(s3cfg *iam_pb.S3ApiConfiguration, values url.Values) (*iamUpdateUserResponse, *iamError) { + resp := &iamUpdateUserResponse{} userName := values.Get("UserName") newUserName := values.Get("NewUserName") if newUserName != "" { @@ -338,8 +341,8 @@ func (e *EmbeddedIamApi) UpdateUser(s3cfg *iam_pb.S3ApiConfiguration, values url } // CreateAccessKey creates an access key for a user. -func (e *EmbeddedIamApi) CreateAccessKey(s3cfg *iam_pb.S3ApiConfiguration, values url.Values) (iamCreateAccessKeyResponse, *iamError) { - var resp iamCreateAccessKeyResponse +func (e *EmbeddedIamApi) CreateAccessKey(s3cfg *iam_pb.S3ApiConfiguration, values url.Values) (*iamCreateAccessKeyResponse, *iamError) { + resp := &iamCreateAccessKeyResponse{} userName := values.Get("UserName") status := iam.StatusTypeActive @@ -372,8 +375,8 @@ func (e *EmbeddedIamApi) CreateAccessKey(s3cfg *iam_pb.S3ApiConfiguration, value } // DeleteAccessKey deletes an access key for a user. -func (e *EmbeddedIamApi) DeleteAccessKey(s3cfg *iam_pb.S3ApiConfiguration, values url.Values) iamDeleteAccessKeyResponse { - var resp iamDeleteAccessKeyResponse +func (e *EmbeddedIamApi) DeleteAccessKey(s3cfg *iam_pb.S3ApiConfiguration, values url.Values) *iamDeleteAccessKeyResponse { + resp := &iamDeleteAccessKeyResponse{} userName := values.Get("UserName") accessKeyId := values.Get("AccessKeyId") for _, ident := range s3cfg.Identities { @@ -400,8 +403,8 @@ func (e *EmbeddedIamApi) GetPolicyDocument(policy *string) (policy_engine.Policy } // CreatePolicy validates and creates a new IAM managed policy. -func (e *EmbeddedIamApi) CreatePolicy(ctx context.Context, values url.Values) (iamCreatePolicyResponse, *iamError) { - var resp iamCreatePolicyResponse +func (e *EmbeddedIamApi) CreatePolicy(ctx context.Context, values url.Values) (*iamCreatePolicyResponse, *iamError) { + resp := &iamCreatePolicyResponse{} policyName := values.Get("PolicyName") policyDocumentString := values.Get("PolicyDocument") if policyName == "" { @@ -443,8 +446,8 @@ func (e *EmbeddedIamApi) CreatePolicy(ctx context.Context, values url.Values) (i } // DeletePolicy deletes a managed policy by ARN. -func (e *EmbeddedIamApi) DeletePolicy(ctx context.Context, values url.Values) (iamDeletePolicyResponse, *iamError) { - var resp iamDeletePolicyResponse +func (e *EmbeddedIamApi) DeletePolicy(ctx context.Context, values url.Values) (*iamDeletePolicyResponse, *iamError) { + resp := &iamDeletePolicyResponse{} policyArn := values.Get("PolicyArn") policyName, err := iamPolicyNameFromArn(policyArn) if err != nil { @@ -485,8 +488,8 @@ func (e *EmbeddedIamApi) DeletePolicy(ctx context.Context, values url.Values) (i } // ListPolicies lists managed policies. -func (e *EmbeddedIamApi) ListPolicies(ctx context.Context, values url.Values) (iamListPoliciesResponse, *iamError) { - var resp iamListPoliciesResponse +func (e *EmbeddedIamApi) ListPolicies(ctx context.Context, values url.Values) (*iamListPoliciesResponse, *iamError) { + resp := &iamListPoliciesResponse{} pathPrefix := values.Get("PathPrefix") if pathPrefix == "" { pathPrefix = "/" @@ -558,8 +561,8 @@ func (e *EmbeddedIamApi) ListPolicies(ctx context.Context, values url.Values) (i } // GetPolicy returns metadata for a managed policy. -func (e *EmbeddedIamApi) GetPolicy(ctx context.Context, values url.Values) (iamGetPolicyResponse, *iamError) { - var resp iamGetPolicyResponse +func (e *EmbeddedIamApi) GetPolicy(ctx context.Context, values url.Values) (*iamGetPolicyResponse, *iamError) { + resp := &iamGetPolicyResponse{} policyArn := values.Get("PolicyArn") policyName, err := iamPolicyNameFromArn(policyArn) if err != nil { @@ -595,8 +598,8 @@ func (e *EmbeddedIamApi) GetPolicy(ctx context.Context, values url.Values) (iamG // ListPolicyVersions lists versions for a managed policy. // Current SeaweedFS implementation stores one version per policy (v1). -func (e *EmbeddedIamApi) ListPolicyVersions(ctx context.Context, values url.Values) (iamListPolicyVersionsResponse, *iamError) { - var resp iamListPolicyVersionsResponse +func (e *EmbeddedIamApi) ListPolicyVersions(ctx context.Context, values url.Values) (*iamListPolicyVersionsResponse, *iamError) { + resp := &iamListPolicyVersionsResponse{} policyArn := values.Get("PolicyArn") policyName, err := iamPolicyNameFromArn(policyArn) if err != nil { @@ -625,8 +628,8 @@ func (e *EmbeddedIamApi) ListPolicyVersions(ctx context.Context, values url.Valu // GetPolicyVersion returns the document for a specific policy version. // Current SeaweedFS implementation stores one version per policy (v1). -func (e *EmbeddedIamApi) GetPolicyVersion(ctx context.Context, values url.Values) (iamGetPolicyVersionResponse, *iamError) { - var resp iamGetPolicyVersionResponse +func (e *EmbeddedIamApi) GetPolicyVersion(ctx context.Context, values url.Values) (*iamGetPolicyVersionResponse, *iamError) { + resp := &iamGetPolicyVersionResponse{} policyArn := values.Get("PolicyArn") versionID := values.Get("VersionId") if versionID == "" { @@ -754,8 +757,8 @@ func (e *EmbeddedIamApi) getActions(policy *policy_engine.PolicyDocument) ([]str } // PutUserPolicy attaches a policy to a user. -func (e *EmbeddedIamApi) PutUserPolicy(s3cfg *iam_pb.S3ApiConfiguration, values url.Values) (iamPutUserPolicyResponse, *iamError) { - var resp iamPutUserPolicyResponse +func (e *EmbeddedIamApi) PutUserPolicy(s3cfg *iam_pb.S3ApiConfiguration, values url.Values) (*iamPutUserPolicyResponse, *iamError) { + resp := &iamPutUserPolicyResponse{} userName := values.Get("UserName") policyDocumentString := values.Get("PolicyDocument") policyDocument, err := e.GetPolicyDocument(&policyDocumentString) @@ -778,8 +781,8 @@ func (e *EmbeddedIamApi) PutUserPolicy(s3cfg *iam_pb.S3ApiConfiguration, values } // GetUserPolicy gets the policy attached to a user. -func (e *EmbeddedIamApi) GetUserPolicy(s3cfg *iam_pb.S3ApiConfiguration, values url.Values) (iamGetUserPolicyResponse, *iamError) { - var resp iamGetUserPolicyResponse +func (e *EmbeddedIamApi) GetUserPolicy(s3cfg *iam_pb.S3ApiConfiguration, values url.Values) (*iamGetUserPolicyResponse, *iamError) { + resp := &iamGetUserPolicyResponse{} userName := values.Get("UserName") policyName := values.Get("PolicyName") for _, ident := range s3cfg.Identities { @@ -845,8 +848,8 @@ func (e *EmbeddedIamApi) GetUserPolicy(s3cfg *iam_pb.S3ApiConfiguration, values } // DeleteUserPolicy removes the inline policy from a user (clears their actions). -func (e *EmbeddedIamApi) DeleteUserPolicy(s3cfg *iam_pb.S3ApiConfiguration, values url.Values) (iamDeleteUserPolicyResponse, *iamError) { - var resp iamDeleteUserPolicyResponse +func (e *EmbeddedIamApi) DeleteUserPolicy(s3cfg *iam_pb.S3ApiConfiguration, values url.Values) (*iamDeleteUserPolicyResponse, *iamError) { + resp := &iamDeleteUserPolicyResponse{} userName := values.Get("UserName") for _, ident := range s3cfg.Identities { if ident.Name == userName { @@ -858,8 +861,8 @@ func (e *EmbeddedIamApi) DeleteUserPolicy(s3cfg *iam_pb.S3ApiConfiguration, valu } // AttachUserPolicy attaches a managed policy to a user. -func (e *EmbeddedIamApi) AttachUserPolicy(ctx context.Context, values url.Values) (iamAttachUserPolicyResponse, *iamError) { - var resp iamAttachUserPolicyResponse +func (e *EmbeddedIamApi) AttachUserPolicy(ctx context.Context, values url.Values) (*iamAttachUserPolicyResponse, *iamError) { + resp := &iamAttachUserPolicyResponse{} userName := values.Get("UserName") if userName == "" { @@ -926,8 +929,8 @@ func (e *EmbeddedIamApi) AttachUserPolicy(ctx context.Context, values url.Values } // DetachUserPolicy detaches a managed policy from a user. -func (e *EmbeddedIamApi) DetachUserPolicy(ctx context.Context, values url.Values) (iamDetachUserPolicyResponse, *iamError) { - var resp iamDetachUserPolicyResponse +func (e *EmbeddedIamApi) DetachUserPolicy(ctx context.Context, values url.Values) (*iamDetachUserPolicyResponse, *iamError) { + resp := &iamDetachUserPolicyResponse{} userName := values.Get("UserName") if userName == "" { @@ -971,8 +974,8 @@ func (e *EmbeddedIamApi) DetachUserPolicy(ctx context.Context, values url.Values } // ListAttachedUserPolicies lists managed policies attached to a user. -func (e *EmbeddedIamApi) ListAttachedUserPolicies(ctx context.Context, values url.Values) (iamListAttachedUserPoliciesResponse, *iamError) { - var resp iamListAttachedUserPoliciesResponse +func (e *EmbeddedIamApi) ListAttachedUserPolicies(ctx context.Context, values url.Values) (*iamListAttachedUserPoliciesResponse, *iamError) { + resp := &iamListAttachedUserPoliciesResponse{} userName := values.Get("UserName") if userName == "" { @@ -1056,8 +1059,8 @@ func (e *EmbeddedIamApi) ListAttachedUserPolicies(ctx context.Context, values ur // SetUserStatus enables or disables a user without deleting them. // This is a SeaweedFS extension for temporary user suspension, offboarding, etc. // When a user is disabled, all API requests using their credentials will return AccessDenied. -func (e *EmbeddedIamApi) SetUserStatus(s3cfg *iam_pb.S3ApiConfiguration, values url.Values) (iamSetUserStatusResponse, *iamError) { - var resp iamSetUserStatusResponse +func (e *EmbeddedIamApi) SetUserStatus(s3cfg *iam_pb.S3ApiConfiguration, values url.Values) (*iamSetUserStatusResponse, *iamError) { + resp := &iamSetUserStatusResponse{} userName := values.Get("UserName") status := values.Get("Status") @@ -1083,8 +1086,8 @@ func (e *EmbeddedIamApi) SetUserStatus(s3cfg *iam_pb.S3ApiConfiguration, values // UpdateAccessKey updates the status of an access key (Active or Inactive). // This allows key rotation workflows where old keys are deactivated before deletion. -func (e *EmbeddedIamApi) UpdateAccessKey(s3cfg *iam_pb.S3ApiConfiguration, values url.Values) (iamUpdateAccessKeyResponse, *iamError) { - var resp iamUpdateAccessKeyResponse +func (e *EmbeddedIamApi) UpdateAccessKey(s3cfg *iam_pb.S3ApiConfiguration, values url.Values) (*iamUpdateAccessKeyResponse, *iamError) { + resp := &iamUpdateAccessKeyResponse{} userName := values.Get("UserName") accessKeyId := values.Get("AccessKeyId") status := values.Get("Status") @@ -1129,8 +1132,8 @@ func findIdentityByName(s3cfg *iam_pb.S3ApiConfiguration, name string) *iam_pb.I } // CreateServiceAccount creates a new service account for a user. -func (e *EmbeddedIamApi) CreateServiceAccount(s3cfg *iam_pb.S3ApiConfiguration, values url.Values, createdBy string) (iamCreateServiceAccountResponse, *iamError) { - var resp iamCreateServiceAccountResponse +func (e *EmbeddedIamApi) CreateServiceAccount(s3cfg *iam_pb.S3ApiConfiguration, values url.Values, createdBy string) (*iamCreateServiceAccountResponse, *iamError) { + resp := &iamCreateServiceAccountResponse{} parentUser := values.Get("ParentUser") description := values.Get("Description") expirationStr := values.Get("Expiration") // Unix timestamp as string @@ -1239,8 +1242,8 @@ func (e *EmbeddedIamApi) CreateServiceAccount(s3cfg *iam_pb.S3ApiConfiguration, } // DeleteServiceAccount deletes a service account. -func (e *EmbeddedIamApi) DeleteServiceAccount(s3cfg *iam_pb.S3ApiConfiguration, values url.Values) (iamDeleteServiceAccountResponse, *iamError) { - var resp iamDeleteServiceAccountResponse +func (e *EmbeddedIamApi) DeleteServiceAccount(s3cfg *iam_pb.S3ApiConfiguration, values url.Values) (*iamDeleteServiceAccountResponse, *iamError) { + resp := &iamDeleteServiceAccountResponse{} saId := values.Get("ServiceAccountId") if saId == "" { @@ -1272,8 +1275,8 @@ func (e *EmbeddedIamApi) DeleteServiceAccount(s3cfg *iam_pb.S3ApiConfiguration, } // ListServiceAccounts lists service accounts, optionally filtered by parent user. -func (e *EmbeddedIamApi) ListServiceAccounts(s3cfg *iam_pb.S3ApiConfiguration, values url.Values) iamListServiceAccountsResponse { - var resp iamListServiceAccountsResponse +func (e *EmbeddedIamApi) ListServiceAccounts(s3cfg *iam_pb.S3ApiConfiguration, values url.Values) *iamListServiceAccountsResponse { + resp := &iamListServiceAccountsResponse{} parentUser := values.Get("ParentUser") // Optional filter for _, sa := range s3cfg.ServiceAccounts { @@ -1307,8 +1310,8 @@ func (e *EmbeddedIamApi) ListServiceAccounts(s3cfg *iam_pb.S3ApiConfiguration, v } // GetServiceAccount retrieves a service account by ID. -func (e *EmbeddedIamApi) GetServiceAccount(s3cfg *iam_pb.S3ApiConfiguration, values url.Values) (iamGetServiceAccountResponse, *iamError) { - var resp iamGetServiceAccountResponse +func (e *EmbeddedIamApi) GetServiceAccount(s3cfg *iam_pb.S3ApiConfiguration, values url.Values) (*iamGetServiceAccountResponse, *iamError) { + resp := &iamGetServiceAccountResponse{} saId := values.Get("ServiceAccountId") if saId == "" { @@ -1344,8 +1347,8 @@ func (e *EmbeddedIamApi) GetServiceAccount(s3cfg *iam_pb.S3ApiConfiguration, val } // UpdateServiceAccount updates a service account's status, description, or expiration. -func (e *EmbeddedIamApi) UpdateServiceAccount(s3cfg *iam_pb.S3ApiConfiguration, values url.Values) (iamUpdateServiceAccountResponse, *iamError) { - var resp iamUpdateServiceAccountResponse +func (e *EmbeddedIamApi) UpdateServiceAccount(s3cfg *iam_pb.S3ApiConfiguration, values url.Values) (*iamUpdateServiceAccountResponse, *iamError) { + resp := &iamUpdateServiceAccountResponse{} saId := values.Get("ServiceAccountId") newStatus := values.Get("Status") newDescription := values.Get("Description") @@ -1543,7 +1546,11 @@ func (e *EmbeddedIamApi) AuthIam(f http.HandlerFunc, _ Action) http.HandlerFunc // ExecuteAction executes an IAM action with the given values. // If skipPersist is true, the changed configuration is not saved to the persistent store. -func (e *EmbeddedIamApi) ExecuteAction(ctx context.Context, values url.Values, skipPersist bool) (interface{}, *iamError) { +// reqID is set on the response; if empty, a new request ID is generated. +func (e *EmbeddedIamApi) ExecuteAction(ctx context.Context, values url.Values, skipPersist bool, reqID string) (iamlib.RequestIDSetter, *iamError) { + if reqID == "" { + reqID = request_id.New() + } // Lock to prevent concurrent read-modify-write race conditions e.policyLock.Lock() defer e.policyLock.Unlock() @@ -1564,8 +1571,7 @@ func (e *EmbeddedIamApi) ExecuteAction(ctx context.Context, values url.Values, s } glog.V(4).Infof("IAM ExecuteAction: %+v", values) - var response interface{} - var iamErr *iamError + var response iamlib.RequestIDSetter changed := true switch values.Get("Action") { case "ListUsers": @@ -1577,29 +1583,34 @@ func (e *EmbeddedIamApi) ExecuteAction(ctx context.Context, values url.Values, s response = e.ListAccessKeys(s3cfg, values) changed = false case "CreateUser": + var iamErr *iamError response, iamErr = e.CreateUser(s3cfg, values) if iamErr != nil { return nil, iamErr } case "GetUser": userName := values.Get("UserName") + var iamErr *iamError response, iamErr = e.GetUser(s3cfg, userName) if iamErr != nil { return nil, iamErr } changed = false case "UpdateUser": + var iamErr *iamError response, iamErr = e.UpdateUser(s3cfg, values) if iamErr != nil { return nil, iamErr } case "DeleteUser": userName := values.Get("UserName") + var iamErr *iamError response, iamErr = e.DeleteUser(s3cfg, userName) if iamErr != nil { return nil, iamErr } case "CreateAccessKey": + var iamErr *iamError response, iamErr = e.CreateAccessKey(s3cfg, values) if iamErr != nil { glog.Errorf("CreateAccessKey: %+v", iamErr.Error) @@ -1608,6 +1619,7 @@ func (e *EmbeddedIamApi) ExecuteAction(ctx context.Context, values url.Values, s case "DeleteAccessKey": response = e.DeleteAccessKey(s3cfg, values) case "CreatePolicy": + var iamErr *iamError response, iamErr = e.CreatePolicy(ctx, values) if iamErr != nil { glog.Errorf("CreatePolicy: %+v", iamErr.Error) @@ -1615,6 +1627,7 @@ func (e *EmbeddedIamApi) ExecuteAction(ctx context.Context, values url.Values, s } changed = false case "DeletePolicy": + var iamErr *iamError response, iamErr = e.DeletePolicy(ctx, values) if iamErr != nil { glog.Errorf("DeletePolicy: %+v", iamErr.Error) @@ -1622,70 +1635,82 @@ func (e *EmbeddedIamApi) ExecuteAction(ctx context.Context, values url.Values, s } changed = false case "PutUserPolicy": + var iamErr *iamError response, iamErr = e.PutUserPolicy(s3cfg, values) if iamErr != nil { glog.Errorf("PutUserPolicy: %+v", iamErr.Error) return nil, iamErr } case "GetUserPolicy": + var iamErr *iamError response, iamErr = e.GetUserPolicy(s3cfg, values) if iamErr != nil { return nil, iamErr } changed = false case "DeleteUserPolicy": + var iamErr *iamError response, iamErr = e.DeleteUserPolicy(s3cfg, values) if iamErr != nil { return nil, iamErr } case "AttachUserPolicy": + var iamErr *iamError response, iamErr = e.AttachUserPolicy(ctx, values) if iamErr != nil { return nil, iamErr } changed = false case "DetachUserPolicy": + var iamErr *iamError response, iamErr = e.DetachUserPolicy(ctx, values) if iamErr != nil { return nil, iamErr } changed = false case "ListAttachedUserPolicies": + var iamErr *iamError response, iamErr = e.ListAttachedUserPolicies(ctx, values) if iamErr != nil { return nil, iamErr } changed = false case "ListPolicies": + var iamErr *iamError response, iamErr = e.ListPolicies(ctx, values) if iamErr != nil { return nil, iamErr } changed = false case "GetPolicy": + var iamErr *iamError response, iamErr = e.GetPolicy(ctx, values) if iamErr != nil { return nil, iamErr } changed = false case "ListPolicyVersions": + var iamErr *iamError response, iamErr = e.ListPolicyVersions(ctx, values) if iamErr != nil { return nil, iamErr } changed = false case "GetPolicyVersion": + var iamErr *iamError response, iamErr = e.GetPolicyVersion(ctx, values) if iamErr != nil { return nil, iamErr } changed = false case "SetUserStatus": + var iamErr *iamError response, iamErr = e.SetUserStatus(s3cfg, values) if iamErr != nil { return nil, iamErr } case "UpdateAccessKey": + var iamErr *iamError response, iamErr = e.UpdateAccessKey(s3cfg, values) if iamErr != nil { return nil, iamErr @@ -1693,11 +1718,13 @@ func (e *EmbeddedIamApi) ExecuteAction(ctx context.Context, values url.Values, s // Service Account actions case "CreateServiceAccount": createdBy := values.Get("CreatedBy") + var iamErr *iamError response, iamErr = e.CreateServiceAccount(s3cfg, values, createdBy) if iamErr != nil { return nil, iamErr } case "DeleteServiceAccount": + var iamErr *iamError response, iamErr = e.DeleteServiceAccount(s3cfg, values) if iamErr != nil { return nil, iamErr @@ -1706,12 +1733,14 @@ func (e *EmbeddedIamApi) ExecuteAction(ctx context.Context, values url.Values, s response = e.ListServiceAccounts(s3cfg, values) changed = false case "GetServiceAccount": + var iamErr *iamError response, iamErr = e.GetServiceAccount(s3cfg, values) if iamErr != nil { return nil, iamErr } changed = false case "UpdateServiceAccount": + var iamErr *iamError response, iamErr = e.UpdateServiceAccount(s3cfg, values) if iamErr != nil { return nil, iamErr @@ -1722,8 +1751,7 @@ func (e *EmbeddedIamApi) ExecuteAction(ctx context.Context, values url.Values, s if changed { if !skipPersist { if err := e.PutS3ApiConfiguration(s3cfg); err != nil { - iamErr = &iamError{Code: iam.ErrCodeServiceFailureException, Error: err} - return nil, iamErr + return nil, &iamError{Code: iam.ErrCodeServiceFailureException, Error: err} } } // Reload in-memory identity maps so subsequent LookupByAccessKey calls @@ -1739,11 +1767,13 @@ func (e *EmbeddedIamApi) ExecuteAction(ctx context.Context, values url.Values, s glog.Errorf("Failed to reload IAM configuration after managed policy mutation: %v", err) } } - return response, iamErr + response.SetRequestId(reqID) + return response, nil } // DoActions handles IAM API actions. func (e *EmbeddedIamApi) DoActions(w http.ResponseWriter, r *http.Request) { + r, reqID := request_id.Ensure(r) if err := r.ParseForm(); err != nil { s3err.WriteErrorResponse(w, r, s3err.ErrInvalidRequest) return @@ -1759,15 +1789,11 @@ func (e *EmbeddedIamApi) DoActions(w http.ResponseWriter, r *http.Request) { values.Set("CreatedBy", createdBy) } - response, iamErr := e.ExecuteAction(r.Context(), values, false) + response, iamErr := e.ExecuteAction(r.Context(), values, false, reqID) if iamErr != nil { - e.writeIamErrorResponse(w, r, iamErr) + e.writeIamErrorResponse(w, r, reqID, iamErr) return } - // Set RequestId for AWS compatibility - if r, ok := response.(interface{ SetRequestId() }); ok { - r.SetRequestId() - } s3err.WriteXMLResponse(w, r, http.StatusOK, response) } diff --git a/weed/s3api/s3api_embedded_iam_test.go b/weed/s3api/s3api_embedded_iam_test.go index 10429a236..5ad18db8a 100644 --- a/weed/s3api/s3api_embedded_iam_test.go +++ b/weed/s3api/s3api_embedded_iam_test.go @@ -8,6 +8,7 @@ import ( "net/http" "net/http/httptest" "net/url" + "regexp" "strings" "sync" "testing" @@ -23,6 +24,7 @@ import ( "github.com/seaweedfs/seaweedfs/weed/s3api/policy_engine" . "github.com/seaweedfs/seaweedfs/weed/s3api/s3_constants" "github.com/seaweedfs/seaweedfs/weed/s3api/s3err" + "github.com/seaweedfs/seaweedfs/weed/util/request_id" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "google.golang.org/protobuf/proto" @@ -156,6 +158,15 @@ func extractEmbeddedIamErrorCodeAndMessage(response *httptest.ResponseRecorder) return "", "" } +func extractEmbeddedIamRequestID(response *httptest.ResponseRecorder) string { + re := regexp.MustCompile(`([^<]+)`) + matches := re.FindStringSubmatch(response.Body.String()) + if len(matches) < 2 { + return "" + } + return matches[1] +} + // TestEmbeddedIamCreateUser tests creating a user via the embedded IAM API func TestEmbeddedIamCreateUser(t *testing.T) { api := NewEmbeddedIamApiForTest() @@ -199,6 +210,8 @@ func TestEmbeddedIamListUsers(t *testing.T) { // Verify response contains the users assert.Len(t, out.ListUsersResult.Users, 2) + assert.NotEmpty(t, response.Header().Get(request_id.AmzRequestIDHeader)) + assert.Equal(t, response.Header().Get(request_id.AmzRequestIDHeader), out.ResponseMetadata.RequestId) } // TestEmbeddedIamListAccessKeys tests listing access keys via the embedded IAM API @@ -1216,6 +1229,7 @@ func TestEmbeddedIamNotImplementedAction(t *testing.T) { assert.Equal(t, http.StatusNotImplemented, rr.Code) assert.Contains(t, rr.Body.String(), "") assert.NotContains(t, rr.Body.String(), "") + assert.Equal(t, rr.Header().Get(request_id.AmzRequestIDHeader), extractEmbeddedIamRequestID(rr)) } // TestGetPolicyDocument tests parsing of policy documents @@ -1900,11 +1914,11 @@ func TestEmbeddedIamExecuteAction(t *testing.T) { vals.Set("Action", "CreateUser") vals.Set("UserName", "ExecuteActionUser") - resp, iamErr := api.ExecuteAction(context.Background(), vals, false) + resp, iamErr := api.ExecuteAction(context.Background(), vals, false, "") assert.Nil(t, iamErr) // Verify response type - createResp, ok := resp.(iamCreateUserResponse) + createResp, ok := resp.(*iamCreateUserResponse) assert.True(t, ok) assert.Equal(t, "ExecuteActionUser", *createResp.CreateUserResult.User.UserName) diff --git a/weed/s3api/s3api_server.go b/weed/s3api/s3api_server.go index c88a1ea0b..8d2fe0f41 100644 --- a/weed/s3api/s3api_server.go +++ b/weed/s3api/s3api_server.go @@ -34,6 +34,7 @@ import ( "github.com/seaweedfs/seaweedfs/weed/util/grace" util_http "github.com/seaweedfs/seaweedfs/weed/util/http" util_http_client "github.com/seaweedfs/seaweedfs/weed/util/http/client" + "github.com/seaweedfs/seaweedfs/weed/util/request_id" "github.com/seaweedfs/seaweedfs/weed/wdclient" ) @@ -540,6 +541,7 @@ func (s3a *S3ApiServer) UnifiedPostHandler(w http.ResponseWriter, r *http.Reques func (s3a *S3ApiServer) registerRouter(router *mux.Router) { // API Router apiRouter := router.PathPrefix("/").Subrouter() + apiRouter.Use(request_id.Middleware) // S3 Tables API endpoint // POST / with X-Amz-Target: S3Tables. diff --git a/weed/s3api/s3api_sts.go b/weed/s3api/s3api_sts.go index 780a39141..8015f3595 100644 --- a/weed/s3api/s3api_sts.go +++ b/weed/s3api/s3api_sts.go @@ -19,6 +19,7 @@ import ( "github.com/seaweedfs/seaweedfs/weed/iam/sts" "github.com/seaweedfs/seaweedfs/weed/iam/utils" "github.com/seaweedfs/seaweedfs/weed/s3api/s3err" + "github.com/seaweedfs/seaweedfs/weed/util/request_id" ) // STS API constants matching AWS STS specification @@ -97,6 +98,7 @@ func (h *STSHandlers) getAccountID() string { // HandleSTSRequest is the main entry point for STS requests // It routes requests based on the Action parameter func (h *STSHandlers) HandleSTSRequest(w http.ResponseWriter, r *http.Request) { + r, _ = request_id.Ensure(r) if err := r.ParseForm(); err != nil { h.writeSTSErrorResponse(w, r, STSErrInvalidParameterValue, err) return @@ -224,7 +226,7 @@ func (h *STSHandlers) handleAssumeRoleWithWebIdentity(w http.ResponseWriter, r * SubjectFromWebIdentityToken: response.AssumedRoleUser.Subject, }, } - xmlResponse.ResponseMetadata.RequestId = fmt.Sprintf("%d", time.Now().UnixNano()) + xmlResponse.ResponseMetadata.RequestId = request_id.GetFromRequest(r) s3err.WriteXMLResponse(w, r, http.StatusOK, xmlResponse) } @@ -354,7 +356,7 @@ func (h *STSHandlers) handleAssumeRole(w http.ResponseWriter, r *http.Request) { AssumedRoleUser: assumedUser, }, } - xmlResponse.ResponseMetadata.RequestId = fmt.Sprintf("%d", time.Now().UnixNano()) + xmlResponse.ResponseMetadata.RequestId = request_id.GetFromRequest(r) s3err.WriteXMLResponse(w, r, http.StatusOK, xmlResponse) } @@ -495,7 +497,7 @@ func (h *STSHandlers) handleAssumeRoleWithLDAPIdentity(w http.ResponseWriter, r AssumedRoleUser: assumedUser, }, } - xmlResponse.ResponseMetadata.RequestId = fmt.Sprintf("%d", time.Now().UnixNano()) + xmlResponse.ResponseMetadata.RequestId = request_id.GetFromRequest(r) s3err.WriteXMLResponse(w, r, http.StatusOK, xmlResponse) } @@ -731,7 +733,7 @@ func (h *STSHandlers) writeSTSErrorResponse(w http.ResponseWriter, r *http.Reque } response := STSErrorResponse{ - RequestId: fmt.Sprintf("%d", time.Now().UnixNano()), + RequestId: request_id.GetFromRequest(r), } // Server-side errors use "Receiver" type per AWS spec diff --git a/weed/s3api/s3err/audit_fluent.go b/weed/s3api/s3err/audit_fluent.go index 5d617ce1c..b63533f1c 100644 --- a/weed/s3api/s3err/audit_fluent.go +++ b/weed/s3api/s3err/audit_fluent.go @@ -10,6 +10,7 @@ import ( "github.com/fluent/fluent-logger-golang/fluent" "github.com/seaweedfs/seaweedfs/weed/glog" "github.com/seaweedfs/seaweedfs/weed/s3api/s3_constants" + "github.com/seaweedfs/seaweedfs/weed/util/request_id" ) type AccessLogExtend struct { @@ -150,7 +151,7 @@ func GetAccessLog(r *http.Request, HTTPStatusCode int, s3errCode ErrorCode) *Acc } return &AccessLog{ HostHeader: hostHeader, - RequestID: r.Header.Get("X-Request-ID"), + RequestID: request_id.GetFromRequest(r), RemoteIP: remoteIP, Requester: s3_constants.GetIdentityNameFromContext(r), // Get from context, not header (secure) SignatureVersion: r.Header.Get(s3_constants.AmzAuthType), diff --git a/weed/s3api/s3err/audit_fluent_test.go b/weed/s3api/s3err/audit_fluent_test.go new file mode 100644 index 000000000..bfe2e788f --- /dev/null +++ b/weed/s3api/s3err/audit_fluent_test.go @@ -0,0 +1,19 @@ +package s3err + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/seaweedfs/seaweedfs/weed/util/request_id" + "github.com/stretchr/testify/assert" +) + +func TestGetAccessLogUsesAmzRequestID(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/bucket/object", nil) + req = req.WithContext(request_id.Set(req.Context(), "req-123")) + + log := GetAccessLog(req, http.StatusOK, ErrNone) + + assert.Equal(t, "req-123", log.RequestID) +} diff --git a/weed/s3api/s3err/error_handler.go b/weed/s3api/s3err/error_handler.go index 4f96b4ffb..ec819dfc8 100644 --- a/weed/s3api/s3err/error_handler.go +++ b/weed/s3api/s3err/error_handler.go @@ -3,15 +3,14 @@ package s3err import ( "bytes" "encoding/xml" - "fmt" "net/http" "strconv" "strings" - "time" "github.com/aws/aws-sdk-go/private/protocol/xml/xmlutil" "github.com/gorilla/mux" "github.com/seaweedfs/seaweedfs/weed/glog" + "github.com/seaweedfs/seaweedfs/weed/util/request_id" ) type mimeType string @@ -41,6 +40,7 @@ func WriteEmptyResponse(w http.ResponseWriter, r *http.Request, statusCode int) } func WriteErrorResponse(w http.ResponseWriter, r *http.Request, errorCode ErrorCode) { + r, reqID := request_id.Ensure(r) vars := mux.Vars(r) bucket := vars["bucket"] object := vars["object"] @@ -49,19 +49,19 @@ func WriteErrorResponse(w http.ResponseWriter, r *http.Request, errorCode ErrorC } apiError := GetAPIError(errorCode) - errorResponse := getRESTErrorResponse(apiError, r.URL.Path, bucket, object) + errorResponse := getRESTErrorResponse(apiError, r.URL.Path, bucket, object, reqID) WriteXMLResponse(w, r, apiError.HTTPStatusCode, errorResponse) PostLog(r, apiError.HTTPStatusCode, errorCode) } -func getRESTErrorResponse(err APIError, resource string, bucket, object string) RESTErrorResponse { +func getRESTErrorResponse(err APIError, resource string, bucket, object, requestID string) RESTErrorResponse { return RESTErrorResponse{ Code: err.Code, BucketName: bucket, Key: object, Message: err.Description, Resource: resource, - RequestID: fmt.Sprintf("%d", time.Now().UnixNano()), + RequestID: requestID, } } @@ -75,7 +75,8 @@ func EncodeXMLResponse(response interface{}) []byte { } func setCommonHeaders(w http.ResponseWriter, r *http.Request) { - w.Header().Set("x-amz-request-id", fmt.Sprintf("%d", time.Now().UnixNano())) + _, reqID := request_id.Ensure(r) + w.Header().Set(request_id.AmzRequestIDHeader, reqID) w.Header().Set("Accept-Ranges", "bytes") // Handle CORS headers for requests with Origin header diff --git a/weed/s3api/s3err/error_handler_test.go b/weed/s3api/s3err/error_handler_test.go new file mode 100644 index 000000000..b2108e195 --- /dev/null +++ b/weed/s3api/s3err/error_handler_test.go @@ -0,0 +1,36 @@ +package s3err + +import ( + "net/http" + "net/http/httptest" + "regexp" + "testing" + + "github.com/gorilla/mux" + "github.com/seaweedfs/seaweedfs/weed/util/request_id" + "github.com/stretchr/testify/assert" +) + +func TestWriteErrorResponseReusesRequestID(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/bucket/object", nil) + req = mux.SetURLVars(req, map[string]string{ + "bucket": "bucket", + "object": "object", + }) + req = req.WithContext(request_id.Set(req.Context(), "req-123")) + + rr := httptest.NewRecorder() + WriteErrorResponse(rr, req, ErrNoSuchKey) + + assert.Equal(t, "req-123", rr.Header().Get(request_id.AmzRequestIDHeader)) + assert.Equal(t, "req-123", extractRequestIDFromBody(rr.Body.String())) +} + +func extractRequestIDFromBody(body string) string { + re := regexp.MustCompile(`([^<]+)`) + matches := re.FindStringSubmatch(body) + if len(matches) < 2 { + return "" + } + return matches[1] +} diff --git a/weed/s3api/sts_params_test.go b/weed/s3api/sts_params_test.go index 600dbb963..6aca08ca2 100644 --- a/weed/s3api/sts_params_test.go +++ b/weed/s3api/sts_params_test.go @@ -5,6 +5,7 @@ import ( "net/http" "net/http/httptest" "net/url" + "regexp" "strings" "sync" "testing" @@ -14,6 +15,7 @@ import ( "github.com/seaweedfs/seaweedfs/weed/pb" "github.com/seaweedfs/seaweedfs/weed/pb/iam_pb" "github.com/seaweedfs/seaweedfs/weed/s3api/s3err" + "github.com/seaweedfs/seaweedfs/weed/util/request_id" "github.com/stretchr/testify/assert" ) @@ -114,6 +116,7 @@ func TestSTSAssumeRolePostBody(t *testing.T) { assert.NotEqual(t, http.StatusNotImplemented, rr.Code, "Should not return 501 (IAM handler)") assert.Equal(t, http.StatusBadRequest, rr.Code, "Should return 400 (STS handler) for missing params") + assert.Equal(t, rr.Header().Get(request_id.AmzRequestIDHeader), extractSTSRequestID(rr.Body.String())) }) // Test Case 2: STS Action in Body (Should FAIL current implementation - routed to IAM) @@ -155,6 +158,7 @@ func TestSTSAssumeRolePostBody(t *testing.T) { } // Confirm it routed to STS assert.Equal(t, http.StatusServiceUnavailable, rr.Code, "Fixed behavior: Should return 503 from STS handler (service not ready)") + assert.Equal(t, rr.Header().Get(request_id.AmzRequestIDHeader), extractSTSRequestID(rr.Body.String())) }) // Test Case 3: STS Action in Body with SigV4-style Authorization (Real-world scenario) @@ -199,5 +203,15 @@ func TestSTSAssumeRolePostBody(t *testing.T) { assert.NotEqual(t, http.StatusNotImplemented, rr.Code, "Should not return 501 (IAM handler)") assert.Contains(t, []int{http.StatusServiceUnavailable, http.StatusForbidden}, rr.Code, "Should return 503 (STS unavailable) or 403 (auth failed), confirming STS routing") + assert.Equal(t, rr.Header().Get(request_id.AmzRequestIDHeader), extractSTSRequestID(rr.Body.String())) }) } + +func extractSTSRequestID(body string) string { + re := regexp.MustCompile(`([^<]+)`) + matches := re.FindStringSubmatch(body) + if len(matches) < 2 { + return "" + } + return matches[1] +} diff --git a/weed/server/common.go b/weed/server/common.go index dfed891b4..32662ada9 100644 --- a/weed/server/common.go +++ b/weed/server/common.go @@ -3,7 +3,6 @@ package weed_server import ( "bufio" "bytes" - "context" "encoding/json" "errors" "fmt" @@ -18,7 +17,6 @@ import ( "sync" "time" - "github.com/google/uuid" "github.com/seaweedfs/seaweedfs/weed/s3api/s3_constants" "github.com/seaweedfs/seaweedfs/weed/util/request_id" "github.com/seaweedfs/seaweedfs/weed/util/version" @@ -432,18 +430,12 @@ func ProcessRangeRequest(r *http.Request, w http.ResponseWriter, totalSize int64 func requestIDMiddleware(h http.HandlerFunc) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { - reqID := r.Header.Get(request_id.AmzRequestIDHeader) - if reqID == "" { - reqID = uuid.New().String() - } - - ctx := context.WithValue(r.Context(), request_id.AmzRequestIDHeader, reqID) - ctx = metadata.NewOutgoingContext(ctx, - metadata.New(map[string]string{ - request_id.AmzRequestIDHeader: reqID, - })) - - w.Header().Set(request_id.AmzRequestIDHeader, reqID) - h(w, r.WithContext(ctx)) + request_id.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ctx := metadata.NewOutgoingContext(r.Context(), + metadata.New(map[string]string{ + request_id.AmzRequestIDHeader: request_id.Get(r.Context()), + })) + h(w, r.WithContext(ctx)) + })).ServeHTTP(w, r) } } diff --git a/weed/util/request_id/request_id.go b/weed/util/request_id/request_id.go index 0550cb58b..5379aea74 100644 --- a/weed/util/request_id/request_id.go +++ b/weed/util/request_id/request_id.go @@ -2,20 +2,25 @@ package request_id import ( "context" + "crypto/rand" + "fmt" "net/http" + "time" ) const AmzRequestIDHeader = "x-amz-request-id" +type contextKey struct{} + func Set(ctx context.Context, id string) context.Context { - return context.WithValue(ctx, AmzRequestIDHeader, id) + return context.WithValue(ctx, contextKey{}, id) } func Get(ctx context.Context) string { if ctx == nil { return "" } - id, _ := ctx.Value(AmzRequestIDHeader).(string) + id, _ := ctx.Value(contextKey{}).(string) return id } @@ -24,3 +29,42 @@ func InjectToRequest(ctx context.Context, req *http.Request) { req.Header.Set(AmzRequestIDHeader, Get(ctx)) } } + +func New() string { + var buf [4]byte + rand.Read(buf[:]) + return fmt.Sprintf("%X%08X", time.Now().UTC().UnixNano(), buf) +} + +// GetFromRequest returns the server-generated request ID from the context. +func GetFromRequest(r *http.Request) string { + if r == nil { + return "" + } + return Get(r.Context()) +} + +// Ensure guarantees a server-generated request ID exists in the context. +// It always generates a new ID if one is not already present in the context, +// ignoring any client-sent x-amz-request-id header to prevent spoofing. +func Ensure(r *http.Request) (*http.Request, string) { + if r == nil { + return nil, "" + } + if id := Get(r.Context()); id != "" { + return r, id + } + id := New() + r = r.WithContext(Set(r.Context(), id)) + return r, id +} + +func Middleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + r, reqID := Ensure(r) + if w.Header().Get(AmzRequestIDHeader) == "" { + w.Header().Set(AmzRequestIDHeader, reqID) + } + next.ServeHTTP(w, r) + }) +} diff --git a/weed/util/request_id/request_id_test.go b/weed/util/request_id/request_id_test.go new file mode 100644 index 000000000..ac5f6bfe8 --- /dev/null +++ b/weed/util/request_id/request_id_test.go @@ -0,0 +1,50 @@ +package request_id + +import ( + "net/http/httptest" + "regexp" + "testing" +) + +var requestIDPattern = regexp.MustCompile(`^[0-9A-F]+$`) + +func TestNewUsesUppercaseHexFormat(t *testing.T) { + id := New() + if !requestIDPattern.MatchString(id) { + t.Fatalf("expected uppercase hex request id, got %q", id) + } + if len(id) < 24 { + t.Fatalf("expected request id to be at least 24 characters, got %q (len=%d)", id, len(id)) + } +} + +func TestNewIsUnique(t *testing.T) { + a := New() + b := New() + if a == b { + t.Fatalf("expected unique request ids, got %q twice", a) + } +} + +func TestEnsureIgnoresClientHeader(t *testing.T) { + req := httptest.NewRequest("GET", "/", nil) + req.Header.Set(AmzRequestIDHeader, "spoofed-id") + + req, id := Ensure(req) + if id == "spoofed-id" { + t.Fatal("Ensure should not trust client-sent x-amz-request-id header") + } + if !requestIDPattern.MatchString(id) { + t.Fatalf("expected server-generated hex id, got %q", id) + } +} + +func TestEnsureReusesContextID(t *testing.T) { + req := httptest.NewRequest("GET", "/", nil) + req = req.WithContext(Set(req.Context(), "ctx-id-123")) + + req, id := Ensure(req) + if id != "ctx-id-123" { + t.Fatalf("expected context id ctx-id-123, got %q", id) + } +}