Browse Source

s3/iam: reuse one request id per request (#8538)

* request_id: add shared request middleware

* s3err: preserve request ids in responses and logs

* iam: reuse request ids in XML responses

* sts: reuse request ids in XML responses

* request_id: drop legacy header fallback

* request_id: use AWS-style request id format

* iam: fix AWS-compatible XML format for ErrorResponse and field ordering

- ErrorResponse uses bare <RequestId> at root level instead of
  <ResponseMetadata> wrapper, matching the AWS IAM error response spec
- Move CommonResponse to last field in success response structs so
  <ResponseMetadata> serializes after result elements
- Add randomness to request ID generation to avoid collisions
- Add tests for XML ordering and ErrorResponse format

* iam: remove duplicate error_response_test.go

Test is already covered by responses_test.go.

* address PR review comments

- Guard against typed nil pointers in SetResponseRequestID before
  interface assertion (CodeRabbit)
- Use regexp instead of strings.Index in test helpers for extracting
  request IDs (Gemini)

* request_id: prevent spoofing, fix nil-error branch, thread reqID to error writers

- Ensure() now always generates a server-side ID, ignoring client-sent
  x-amz-request-id headers to prevent request ID spoofing. Uses a
  private context key (contextKey{}) instead of the header string.
- writeIamErrorResponse in both iamapi and embedded IAM now accepts
  reqID as a parameter instead of calling Ensure() internally, ensuring
  a single request ID per request lifecycle.
- The nil-iamError branch in writeIamErrorResponse now writes a 500
  Internal Server Error response instead of returning silently.
- Updated tests to set request IDs via context (not headers) and added
  tests for spoofing prevention and context reuse.

* sts: add request-id consistency assertions to ActionInBody tests

* test: update admin test to expect server-generated request IDs

The test previously sent a client x-amz-request-id header and expected
it echoed back. Since Ensure() now ignores client headers to prevent
spoofing, update the test to verify the server returns a non-empty
server-generated request ID instead.

* iam: add generic WithRequestID helper alongside reflection-based fallback

Add WithRequestID[T] that uses generics to take the address of a value
type, satisfying the pointer receiver on SetRequestId without reflection.

The existing SetResponseRequestID is kept for the two call sites that
operate on interface{} (from large action switches where the concrete
type varies at runtime). Generics cannot replace reflection there since
Go cannot infer type parameters from interface{}.

* Remove reflection and generics from request ID setting

Call SetRequestId directly on concrete response types in each switch
branch before boxing into interface{}, eliminating the need for
WithRequestID (generics) and SetResponseRequestID (reflection).

* iam: return pointer responses in action dispatch

* Fix IAM error handling consistency and ensure request IDs on all responses

- UpdateUser/CreatePolicy error branches: use writeIamErrorResponse instead
  of s3err.WriteErrorResponse to preserve IAM formatting and request ID
- ExecuteAction: accept reqID parameter and generate one if empty, ensuring
  every response carries a RequestId regardless of caller

* Clean up inline policies on DeleteUser and UpdateUser rename

DeleteUser: remove InlinePolicies[userName] from policy storage before
removing the identity, so policies are not orphaned.

UpdateUser: move InlinePolicies[userName] to InlinePolicies[newUserName]
when renaming, so GetUserPolicy/DeleteUserPolicy work under the new name.

Both operations persist the updated policies and return an error if
the storage write fails, preventing partial state.
pull/8541/head
Chris Lu 3 days ago
committed by GitHub
parent
commit
540fc97e00
No known key found for this signature in database GPG Key ID: B5690EEEBB952194
  1. 11
      test/volume_server/http/admin_test.go
  2. 34
      weed/iam/error_response_test.go
  3. 24
      weed/iam/responses.go
  4. 30
      weed/iam/responses_test.go
  5. 16
      weed/iamapi/iamapi_handlers.go
  6. 219
      weed/iamapi/iamapi_management_handlers.go
  7. 2
      weed/iamapi/iamapi_server.go
  8. 26
      weed/iamapi/iamapi_test.go
  9. 170
      weed/s3api/s3api_embedded_iam.go
  10. 18
      weed/s3api/s3api_embedded_iam_test.go
  11. 2
      weed/s3api/s3api_server.go
  12. 10
      weed/s3api/s3api_sts.go
  13. 3
      weed/s3api/s3err/audit_fluent.go
  14. 19
      weed/s3api/s3err/audit_fluent_test.go
  15. 13
      weed/s3api/s3err/error_handler.go
  16. 36
      weed/s3api/s3err/error_handler_test.go
  17. 14
      weed/s3api/sts_params_test.go
  18. 22
      weed/server/common.go
  19. 48
      weed/util/request_id/request_id.go
  20. 50
      weed/util/request_id/request_id_test.go

11
test/volume_server/http/admin_test.go

@ -23,8 +23,6 @@ func TestAdminStatusAndHealthz(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("create status request: %v", err) t.Fatalf("create status request: %v", err)
} }
statusReq.Header.Set(request_id.AmzRequestIDHeader, "test-request-id-1")
statusResp := framework.DoRequest(t, client, statusReq) statusResp := framework.DoRequest(t, client, statusReq)
statusBody := framework.ReadAllAndClose(t, statusResp) statusBody := framework.ReadAllAndClose(t, statusResp)
@ -34,8 +32,8 @@ func TestAdminStatusAndHealthz(t *testing.T) {
if got := statusResp.Header.Get("Server"); !strings.Contains(got, "SeaweedFS Volume") { if got := statusResp.Header.Get("Server"); !strings.Contains(got, "SeaweedFS Volume") {
t.Fatalf("expected /status Server header to contain SeaweedFS Volume, got %q", got) t.Fatalf("expected /status Server header to contain SeaweedFS Volume, got %q", got)
} }
if got := statusResp.Header.Get(request_id.AmzRequestIDHeader); got != "test-request-id-1" {
t.Fatalf("expected echoed request id, got %q", got)
if got := statusResp.Header.Get(request_id.AmzRequestIDHeader); got == "" {
t.Fatal("expected server-generated request id in response header")
} }
var payload map[string]interface{} var payload map[string]interface{}
@ -49,7 +47,6 @@ func TestAdminStatusAndHealthz(t *testing.T) {
} }
healthReq := mustNewRequest(t, http.MethodGet, cluster.VolumeAdminURL()+"/healthz") healthReq := mustNewRequest(t, http.MethodGet, cluster.VolumeAdminURL()+"/healthz")
healthReq.Header.Set(request_id.AmzRequestIDHeader, "test-request-id-2")
healthResp := framework.DoRequest(t, client, healthReq) healthResp := framework.DoRequest(t, client, healthReq)
_ = framework.ReadAllAndClose(t, healthResp) _ = framework.ReadAllAndClose(t, healthResp)
if healthResp.StatusCode != http.StatusOK { if healthResp.StatusCode != http.StatusOK {
@ -58,8 +55,8 @@ func TestAdminStatusAndHealthz(t *testing.T) {
if got := healthResp.Header.Get("Server"); !strings.Contains(got, "SeaweedFS Volume") { if got := healthResp.Header.Get("Server"); !strings.Contains(got, "SeaweedFS Volume") {
t.Fatalf("expected /healthz Server header to contain SeaweedFS Volume, got %q", got) t.Fatalf("expected /healthz Server header to contain SeaweedFS Volume, got %q", got)
} }
if got := healthResp.Header.Get(request_id.AmzRequestIDHeader); got != "test-request-id-2" {
t.Fatalf("expected /healthz echoed request id, got %q", got)
if got := healthResp.Header.Get(request_id.AmzRequestIDHeader); got == "" {
t.Fatal("expected /healthz server-generated request id in response header")
} }
uiResp := framework.DoRequest(t, client, mustNewRequest(t, http.MethodGet, cluster.VolumeAdminURL()+"/ui/index.html")) uiResp := framework.DoRequest(t, client, mustNewRequest(t, http.MethodGet, cluster.VolumeAdminURL()+"/ui/index.html"))

34
weed/iam/error_response_test.go

@ -1,34 +0,0 @@
package iam
import (
"encoding/xml"
"strings"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestErrorResponseXMLUsesTopLevelRequestId(t *testing.T) {
errCode := "NoSuchEntity"
errMsg := "the requested IAM entity does not exist"
resp := ErrorResponse{
RequestId: "request-123",
}
resp.Error.Type = "Sender"
resp.Error.Code = &errCode
resp.Error.Message = &errMsg
output, err := xml.Marshal(resp)
require.NoError(t, err)
xmlString := string(output)
errorIndex := strings.Index(xmlString, "<Error>")
requestIDIndex := strings.Index(xmlString, "<RequestId>request-123</RequestId>")
assert.NotEqual(t, -1, errorIndex, "Error should be present")
assert.NotEqual(t, -1, requestIDIndex, "RequestId should be present")
assert.NotContains(t, xmlString, "<ResponseMetadata>")
assert.Less(t, errorIndex, requestIDIndex, "RequestId should appear after Error at the root level")
}

24
weed/iam/responses.go

@ -2,8 +2,6 @@ package iam
import ( import (
"encoding/xml" "encoding/xml"
"fmt"
"time"
"github.com/aws/aws-sdk-go/service/iam" "github.com/aws/aws-sdk-go/service/iam"
) )
@ -15,9 +13,14 @@ type CommonResponse struct {
} `xml:"ResponseMetadata"` } `xml:"ResponseMetadata"`
} }
// SetRequestId sets a unique request ID based on current timestamp.
func (r *CommonResponse) SetRequestId() {
r.ResponseMetadata.RequestId = newRequestID()
// SetRequestId stores the request ID generated for the current HTTP request.
func (r *CommonResponse) SetRequestId(requestID string) {
r.ResponseMetadata.RequestId = requestID
}
// RequestIDSetter is implemented by IAM responses that can carry a RequestId.
type RequestIDSetter interface {
SetRequestId(string)
} }
// ListUsersResponse is the response for ListUsers action. // ListUsersResponse is the response for ListUsers action.
@ -187,6 +190,7 @@ type GetUserPolicyResponse struct {
} }
// ErrorResponse is the IAM error response format. // ErrorResponse is the IAM error response format.
// AWS IAM uses a bare <RequestId> at root level for errors, not <ResponseMetadata>.
type ErrorResponse struct { type ErrorResponse struct {
XMLName xml.Name `xml:"https://iam.amazonaws.com/doc/2010-05-08/ ErrorResponse"` XMLName xml.Name `xml:"https://iam.amazonaws.com/doc/2010-05-08/ ErrorResponse"`
Error struct { Error struct {
@ -196,13 +200,9 @@ type ErrorResponse struct {
RequestId string `xml:"RequestId"` RequestId string `xml:"RequestId"`
} }
// SetRequestId sets a unique request ID based on current timestamp.
func (r *ErrorResponse) SetRequestId() {
r.RequestId = newRequestID()
}
func newRequestID() string {
return fmt.Sprintf("%d", time.Now().UnixNano())
// SetRequestId stores the request ID generated for the current HTTP request.
func (r *ErrorResponse) SetRequestId(requestID string) {
r.RequestId = requestID
} }
// Error represents an IAM API error with code and underlying error. // Error represents an IAM API error with code and underlying error.

30
weed/iam/responses_test.go

@ -11,7 +11,7 @@ import (
func TestListUsersResponseXMLOrdering(t *testing.T) { func TestListUsersResponseXMLOrdering(t *testing.T) {
resp := ListUsersResponse{} resp := ListUsersResponse{}
resp.SetRequestId()
resp.SetRequestId("test-req-id")
output, err := xml.Marshal(resp) output, err := xml.Marshal(resp)
require.NoError(t, err) require.NoError(t, err)
@ -22,5 +22,31 @@ func TestListUsersResponseXMLOrdering(t *testing.T) {
assert.NotEqual(t, -1, listUsersResultIndex, "ListUsersResult should be present") assert.NotEqual(t, -1, listUsersResultIndex, "ListUsersResult should be present")
assert.NotEqual(t, -1, responseMetadataIndex, "ResponseMetadata should be present") assert.NotEqual(t, -1, responseMetadataIndex, "ResponseMetadata should be present")
assert.Less(t, listUsersResultIndex, responseMetadataIndex, "ListUsersResult should appear before ResponseMetadata")
assert.Less(t, listUsersResultIndex, responseMetadataIndex,
"ListUsersResult should appear before ResponseMetadata")
}
func TestErrorResponseXMLUsesTopLevelRequestId(t *testing.T) {
errCode := "NoSuchEntity"
errMsg := "the requested IAM entity does not exist"
resp := ErrorResponse{}
resp.Error.Type = "Sender"
resp.Error.Code = &errCode
resp.Error.Message = &errMsg
resp.SetRequestId("request-123")
output, err := xml.Marshal(resp)
require.NoError(t, err)
xmlString := string(output)
errorIndex := strings.Index(xmlString, "<Error>")
requestIDIndex := strings.Index(xmlString, "<RequestId>request-123</RequestId>")
assert.NotEqual(t, -1, errorIndex, "Error should be present")
assert.NotEqual(t, -1, requestIDIndex, "RequestId should be present")
assert.NotContains(t, xmlString, "<ResponseMetadata>",
"ErrorResponse should use bare RequestId, not ResponseMetadata wrapper")
assert.Less(t, errorIndex, requestIDIndex,
"RequestId should appear after Error at the root level")
} }

16
weed/iamapi/iamapi_handlers.go

@ -8,20 +8,20 @@ import (
"github.com/seaweedfs/seaweedfs/weed/s3api/s3err" "github.com/seaweedfs/seaweedfs/weed/s3api/s3err"
) )
func newErrorResponse(errCode string, errMsg string) ErrorResponse {
func newErrorResponse(errCode string, errMsg string, requestID string) ErrorResponse {
errorResp := ErrorResponse{} errorResp := ErrorResponse{}
errorResp.Error.Type = "Sender" errorResp.Error.Type = "Sender"
errorResp.Error.Code = &errCode errorResp.Error.Code = &errCode
errorResp.Error.Message = &errMsg errorResp.Error.Message = &errMsg
errorResp.SetRequestId()
errorResp.SetRequestId(requestID)
return errorResp return errorResp
} }
func writeIamErrorResponse(w http.ResponseWriter, r *http.Request, iamError *IamError) {
func writeIamErrorResponse(w http.ResponseWriter, r *http.Request, reqID string, iamError *IamError) {
if iamError == nil { if iamError == nil {
// Do nothing if there is no error
glog.Errorf("No error found")
glog.Errorf("writeIamErrorResponse called with nil error")
internalResp := newErrorResponse(iam.ErrCodeServiceFailureException, "Internal server error", reqID)
s3err.WriteXMLResponse(w, r, http.StatusInternalServerError, internalResp)
return return
} }
@ -29,8 +29,8 @@ func writeIamErrorResponse(w http.ResponseWriter, r *http.Request, iamError *Iam
errMsg := iamError.Error.Error() errMsg := iamError.Error.Error()
glog.Errorf("Response %+v", errMsg) glog.Errorf("Response %+v", errMsg)
errorResp := newErrorResponse(errCode, errMsg)
internalErrorResponse := newErrorResponse(iam.ErrCodeServiceFailureException, "Internal server error")
errorResp := newErrorResponse(errCode, errMsg, reqID)
internalErrorResponse := newErrorResponse(iam.ErrCodeServiceFailureException, "Internal server error", reqID)
switch errCode { switch errCode {
case iam.ErrCodeNoSuchEntityException: case iam.ErrCodeNoSuchEntityException:

219
weed/iamapi/iamapi_management_handlers.go

@ -19,6 +19,7 @@ import (
"github.com/seaweedfs/seaweedfs/weed/pb/iam_pb" "github.com/seaweedfs/seaweedfs/weed/pb/iam_pb"
"github.com/seaweedfs/seaweedfs/weed/s3api/policy_engine" "github.com/seaweedfs/seaweedfs/weed/s3api/policy_engine"
"github.com/seaweedfs/seaweedfs/weed/s3api/s3err" "github.com/seaweedfs/seaweedfs/weed/s3api/s3err"
"github.com/seaweedfs/seaweedfs/weed/util/request_id"
) )
// Constants from shared package // Constants from shared package
@ -162,14 +163,16 @@ func validateAccessKeyStatus(status string) error {
} }
} }
func (iama *IamApiServer) ListUsers(s3cfg *iam_pb.S3ApiConfiguration, values url.Values) (resp ListUsersResponse) {
func (iama *IamApiServer) ListUsers(s3cfg *iam_pb.S3ApiConfiguration, values url.Values) (resp *ListUsersResponse) {
resp = &ListUsersResponse{}
for _, ident := range s3cfg.Identities { for _, ident := range s3cfg.Identities {
resp.ListUsersResult.Users = append(resp.ListUsersResult.Users, &iam.User{UserName: &ident.Name}) resp.ListUsersResult.Users = append(resp.ListUsersResult.Users, &iam.User{UserName: &ident.Name})
} }
return resp return resp
} }
func (iama *IamApiServer) ListAccessKeys(s3cfg *iam_pb.S3ApiConfiguration, values url.Values) (resp ListAccessKeysResponse) {
func (iama *IamApiServer) ListAccessKeys(s3cfg *iam_pb.S3ApiConfiguration, values url.Values) (resp *ListAccessKeysResponse) {
resp = &ListAccessKeysResponse{}
userName := values.Get("UserName") userName := values.Get("UserName")
for _, ident := range s3cfg.Identities { for _, ident := range s3cfg.Identities {
if userName != "" && userName != ident.Name { if userName != "" && userName != ident.Name {
@ -190,16 +193,31 @@ func (iama *IamApiServer) ListAccessKeys(s3cfg *iam_pb.S3ApiConfiguration, value
return resp return resp
} }
func (iama *IamApiServer) CreateUser(s3cfg *iam_pb.S3ApiConfiguration, values url.Values) (resp CreateUserResponse) {
func (iama *IamApiServer) CreateUser(s3cfg *iam_pb.S3ApiConfiguration, values url.Values) (resp *CreateUserResponse) {
resp = &CreateUserResponse{}
userName := values.Get("UserName") userName := values.Get("UserName")
resp.CreateUserResult.User.UserName = &userName resp.CreateUserResult.User.UserName = &userName
s3cfg.Identities = append(s3cfg.Identities, &iam_pb.Identity{Name: userName}) s3cfg.Identities = append(s3cfg.Identities, &iam_pb.Identity{Name: userName})
return resp return resp
} }
func (iama *IamApiServer) DeleteUser(s3cfg *iam_pb.S3ApiConfiguration, userName string) (resp DeleteUserResponse, err *IamError) {
func (iama *IamApiServer) DeleteUser(s3cfg *iam_pb.S3ApiConfiguration, userName string) (resp *DeleteUserResponse, err *IamError) {
resp = &DeleteUserResponse{}
for i, ident := range s3cfg.Identities { for i, ident := range s3cfg.Identities {
if userName == ident.Name { if userName == ident.Name {
// Clean up any inline policies stored for this user
policies := Policies{}
if pErr := iama.s3ApiConfig.GetPolicies(&policies); pErr != nil && !errors.Is(pErr, filer_pb.ErrNotFound) {
return resp, &IamError{Code: iam.ErrCodeServiceFailureException, Error: pErr}
}
if policies.InlinePolicies != nil {
if _, exists := policies.InlinePolicies[userName]; exists {
delete(policies.InlinePolicies, userName)
if pErr := iama.s3ApiConfig.PutPolicies(&policies); pErr != nil {
return resp, &IamError{Code: iam.ErrCodeServiceFailureException, Error: pErr}
}
}
}
s3cfg.Identities = append(s3cfg.Identities[:i], s3cfg.Identities[i+1:]...) s3cfg.Identities = append(s3cfg.Identities[:i], s3cfg.Identities[i+1:]...)
return resp, nil return resp, nil
} }
@ -207,7 +225,8 @@ func (iama *IamApiServer) DeleteUser(s3cfg *iam_pb.S3ApiConfiguration, userName
return resp, &IamError{Code: iam.ErrCodeNoSuchEntityException, Error: fmt.Errorf(USER_DOES_NOT_EXIST, userName)} return resp, &IamError{Code: iam.ErrCodeNoSuchEntityException, Error: fmt.Errorf(USER_DOES_NOT_EXIST, userName)}
} }
func (iama *IamApiServer) GetUser(s3cfg *iam_pb.S3ApiConfiguration, userName string) (resp GetUserResponse, err *IamError) {
func (iama *IamApiServer) GetUser(s3cfg *iam_pb.S3ApiConfiguration, userName string) (resp *GetUserResponse, err *IamError) {
resp = &GetUserResponse{}
for _, ident := range s3cfg.Identities { for _, ident := range s3cfg.Identities {
if userName == ident.Name { if userName == ident.Name {
resp.GetUserResult.User = iam.User{UserName: &ident.Name} resp.GetUserResult.User = iam.User{UserName: &ident.Name}
@ -217,13 +236,28 @@ func (iama *IamApiServer) GetUser(s3cfg *iam_pb.S3ApiConfiguration, userName str
return resp, &IamError{Code: iam.ErrCodeNoSuchEntityException, Error: fmt.Errorf(USER_DOES_NOT_EXIST, userName)} return resp, &IamError{Code: iam.ErrCodeNoSuchEntityException, Error: fmt.Errorf(USER_DOES_NOT_EXIST, userName)}
} }
func (iama *IamApiServer) UpdateUser(s3cfg *iam_pb.S3ApiConfiguration, values url.Values) (resp UpdateUserResponse, err *IamError) {
func (iama *IamApiServer) UpdateUser(s3cfg *iam_pb.S3ApiConfiguration, values url.Values) (resp *UpdateUserResponse, err *IamError) {
resp = &UpdateUserResponse{}
userName := values.Get("UserName") userName := values.Get("UserName")
newUserName := values.Get("NewUserName") newUserName := values.Get("NewUserName")
if newUserName != "" { if newUserName != "" {
for _, ident := range s3cfg.Identities { for _, ident := range s3cfg.Identities {
if userName == ident.Name { if userName == ident.Name {
ident.Name = newUserName ident.Name = newUserName
// Move any inline policies from old username to new username
policies := Policies{}
if pErr := iama.s3ApiConfig.GetPolicies(&policies); pErr != nil && !errors.Is(pErr, filer_pb.ErrNotFound) {
return resp, &IamError{Code: iam.ErrCodeServiceFailureException, Error: pErr}
}
if policies.InlinePolicies != nil {
if userPolicies, exists := policies.InlinePolicies[userName]; exists {
delete(policies.InlinePolicies, userName)
policies.InlinePolicies[newUserName] = userPolicies
if pErr := iama.s3ApiConfig.PutPolicies(&policies); pErr != nil {
return resp, &IamError{Code: iam.ErrCodeServiceFailureException, Error: pErr}
}
}
}
return resp, nil return resp, nil
} }
} }
@ -241,12 +275,13 @@ func GetPolicyDocument(policy *string) (policy_engine.PolicyDocument, error) {
return policyDocument, nil return policyDocument, nil
} }
func (iama *IamApiServer) CreatePolicy(s3cfg *iam_pb.S3ApiConfiguration, values url.Values) (resp CreatePolicyResponse, iamError *IamError) {
func (iama *IamApiServer) CreatePolicy(s3cfg *iam_pb.S3ApiConfiguration, values url.Values) (resp *CreatePolicyResponse, iamError *IamError) {
resp = &CreatePolicyResponse{}
policyName := values.Get("PolicyName") policyName := values.Get("PolicyName")
policyDocumentString := values.Get("PolicyDocument") policyDocumentString := values.Get("PolicyDocument")
policyDocument, err := GetPolicyDocument(&policyDocumentString) policyDocument, err := GetPolicyDocument(&policyDocumentString)
if err != nil { if err != nil {
return CreatePolicyResponse{}, &IamError{Code: iam.ErrCodeMalformedPolicyDocumentException, Error: err}
return resp, &IamError{Code: iam.ErrCodeMalformedPolicyDocumentException, Error: err}
} }
policyId := Hash(&policyName) policyId := Hash(&policyName)
arn := fmt.Sprintf("arn:aws:iam:::policy/%s", policyName) arn := fmt.Sprintf("arn:aws:iam:::policy/%s", policyName)
@ -274,16 +309,17 @@ type IamError struct {
} }
// https://docs.aws.amazon.com/IAM/latest/APIReference/API_PutUserPolicy.html // https://docs.aws.amazon.com/IAM/latest/APIReference/API_PutUserPolicy.html
func (iama *IamApiServer) PutUserPolicy(s3cfg *iam_pb.S3ApiConfiguration, values url.Values) (resp PutUserPolicyResponse, iamError *IamError) {
func (iama *IamApiServer) PutUserPolicy(s3cfg *iam_pb.S3ApiConfiguration, values url.Values) (resp *PutUserPolicyResponse, iamError *IamError) {
resp = &PutUserPolicyResponse{}
userName := values.Get("UserName") userName := values.Get("UserName")
policyName := values.Get("PolicyName") policyName := values.Get("PolicyName")
policyDocumentString := values.Get("PolicyDocument") policyDocumentString := values.Get("PolicyDocument")
policyDocument, err := GetPolicyDocument(&policyDocumentString) policyDocument, err := GetPolicyDocument(&policyDocumentString)
if err != nil { if err != nil {
return PutUserPolicyResponse{}, &IamError{Code: iam.ErrCodeMalformedPolicyDocumentException, Error: err}
return resp, &IamError{Code: iam.ErrCodeMalformedPolicyDocumentException, Error: err}
} }
if _, err := GetActions(&policyDocument); err != nil { if _, err := GetActions(&policyDocument); err != nil {
return PutUserPolicyResponse{}, &IamError{Code: iam.ErrCodeMalformedPolicyDocumentException, Error: err}
return resp, &IamError{Code: iam.ErrCodeMalformedPolicyDocumentException, Error: err}
} }
// Verify the user exists before persisting the policy // Verify the user exists before persisting the policy
@ -295,20 +331,20 @@ func (iama *IamApiServer) PutUserPolicy(s3cfg *iam_pb.S3ApiConfiguration, values
} }
} }
if targetIdent == nil { if targetIdent == nil {
return PutUserPolicyResponse{}, &IamError{Code: iam.ErrCodeNoSuchEntityException, Error: fmt.Errorf("the user with name %s cannot be found", userName)}
return resp, &IamError{Code: iam.ErrCodeNoSuchEntityException, Error: fmt.Errorf("the user with name %s cannot be found", userName)}
} }
// Persist inline policy to storage using per-user indexed structure // Persist inline policy to storage using per-user indexed structure
policies := Policies{} policies := Policies{}
if err = iama.s3ApiConfig.GetPolicies(&policies); err != nil && !errors.Is(err, filer_pb.ErrNotFound) { if err = iama.s3ApiConfig.GetPolicies(&policies); err != nil && !errors.Is(err, filer_pb.ErrNotFound) {
return PutUserPolicyResponse{}, &IamError{Code: iam.ErrCodeServiceFailureException, Error: err}
return resp, &IamError{Code: iam.ErrCodeServiceFailureException, Error: err}
} }
userPolicies := policies.getOrCreateUserPolicies(userName) userPolicies := policies.getOrCreateUserPolicies(userName)
userPolicies[policyName] = policyDocument userPolicies[policyName] = policyDocument
if err = iama.s3ApiConfig.PutPolicies(&policies); err != nil { if err = iama.s3ApiConfig.PutPolicies(&policies); err != nil {
return PutUserPolicyResponse{}, &IamError{Code: iam.ErrCodeServiceFailureException, Error: err}
return resp, &IamError{Code: iam.ErrCodeServiceFailureException, Error: err}
} }
// Recompute aggregated actions (inline + managed) // Recompute aggregated actions (inline + managed)
@ -321,7 +357,8 @@ func (iama *IamApiServer) PutUserPolicy(s3cfg *iam_pb.S3ApiConfiguration, values
return resp, nil return resp, nil
} }
func (iama *IamApiServer) GetUserPolicy(s3cfg *iam_pb.S3ApiConfiguration, values url.Values) (resp GetUserPolicyResponse, err *IamError) {
func (iama *IamApiServer) GetUserPolicy(s3cfg *iam_pb.S3ApiConfiguration, values url.Values) (resp *GetUserPolicyResponse, err *IamError) {
resp = &GetUserPolicyResponse{}
userName := values.Get("UserName") userName := values.Get("UserName")
policyName := values.Get("PolicyName") policyName := values.Get("PolicyName")
for _, ident := range s3cfg.Identities { for _, ident := range s3cfg.Identities {
@ -404,7 +441,8 @@ func (iama *IamApiServer) GetUserPolicy(s3cfg *iam_pb.S3ApiConfiguration, values
} }
// DeleteUserPolicy removes the inline policy from a user (clears their actions). // DeleteUserPolicy removes the inline policy from a user (clears their actions).
func (iama *IamApiServer) DeleteUserPolicy(s3cfg *iam_pb.S3ApiConfiguration, values url.Values) (resp DeleteUserPolicyResponse, err *IamError) {
func (iama *IamApiServer) DeleteUserPolicy(s3cfg *iam_pb.S3ApiConfiguration, values url.Values) (resp *DeleteUserPolicyResponse, err *IamError) {
resp = &DeleteUserPolicyResponse{}
userName := values.Get("UserName") userName := values.Get("UserName")
policyName := values.Get("PolicyName") policyName := values.Get("PolicyName")
@ -447,7 +485,8 @@ func (iama *IamApiServer) DeleteUserPolicy(s3cfg *iam_pb.S3ApiConfiguration, val
} }
// GetPolicy retrieves a managed policy by ARN. // GetPolicy retrieves a managed policy by ARN.
func (iama *IamApiServer) GetPolicy(s3cfg *iam_pb.S3ApiConfiguration, values url.Values) (resp GetPolicyResponse, iamError *IamError) {
func (iama *IamApiServer) GetPolicy(s3cfg *iam_pb.S3ApiConfiguration, values url.Values) (resp *GetPolicyResponse, iamError *IamError) {
resp = &GetPolicyResponse{}
policyArn := values.Get("PolicyArn") policyArn := values.Get("PolicyArn")
policyName, iamError := parsePolicyArn(policyArn) policyName, iamError := parsePolicyArn(policyArn)
if iamError != nil { if iamError != nil {
@ -472,7 +511,8 @@ func (iama *IamApiServer) GetPolicy(s3cfg *iam_pb.S3ApiConfiguration, values url
// DeletePolicy removes a managed policy. Rejects deletion if the policy is still attached to any user // DeletePolicy removes a managed policy. Rejects deletion if the policy is still attached to any user
// (matching AWS IAM behavior: must detach before deleting). // (matching AWS IAM behavior: must detach before deleting).
func (iama *IamApiServer) DeletePolicy(s3cfg *iam_pb.S3ApiConfiguration, values url.Values) (resp DeletePolicyResponse, iamError *IamError) {
func (iama *IamApiServer) DeletePolicy(s3cfg *iam_pb.S3ApiConfiguration, values url.Values) (resp *DeletePolicyResponse, iamError *IamError) {
resp = &DeletePolicyResponse{}
policyArn := values.Get("PolicyArn") policyArn := values.Get("PolicyArn")
policyName, iamError := parsePolicyArn(policyArn) policyName, iamError := parsePolicyArn(policyArn)
if iamError != nil { if iamError != nil {
@ -509,7 +549,8 @@ func (iama *IamApiServer) DeletePolicy(s3cfg *iam_pb.S3ApiConfiguration, values
} }
// ListPolicies lists all managed policies. // ListPolicies lists all managed policies.
func (iama *IamApiServer) ListPolicies(s3cfg *iam_pb.S3ApiConfiguration, values url.Values) (resp ListPoliciesResponse, iamError *IamError) {
func (iama *IamApiServer) ListPolicies(s3cfg *iam_pb.S3ApiConfiguration, values url.Values) (resp *ListPoliciesResponse, iamError *IamError) {
resp = &ListPoliciesResponse{}
policies := Policies{} policies := Policies{}
if err := iama.s3ApiConfig.GetPolicies(&policies); err != nil && !errors.Is(err, filer_pb.ErrNotFound) { if err := iama.s3ApiConfig.GetPolicies(&policies); err != nil && !errors.Is(err, filer_pb.ErrNotFound) {
return resp, &IamError{Code: iam.ErrCodeServiceFailureException, Error: err} return resp, &IamError{Code: iam.ErrCodeServiceFailureException, Error: err}
@ -529,7 +570,8 @@ func (iama *IamApiServer) ListPolicies(s3cfg *iam_pb.S3ApiConfiguration, values
} }
// AttachUserPolicy attaches a managed policy to a user. // AttachUserPolicy attaches a managed policy to a user.
func (iama *IamApiServer) AttachUserPolicy(s3cfg *iam_pb.S3ApiConfiguration, values url.Values) (resp AttachUserPolicyResponse, iamError *IamError) {
func (iama *IamApiServer) AttachUserPolicy(s3cfg *iam_pb.S3ApiConfiguration, values url.Values) (resp *AttachUserPolicyResponse, iamError *IamError) {
resp = &AttachUserPolicyResponse{}
userName := values.Get("UserName") userName := values.Get("UserName")
policyArn := values.Get("PolicyArn") policyArn := values.Get("PolicyArn")
policyName, iamError := parsePolicyArn(policyArn) policyName, iamError := parsePolicyArn(policyArn)
@ -574,7 +616,8 @@ func (iama *IamApiServer) AttachUserPolicy(s3cfg *iam_pb.S3ApiConfiguration, val
} }
// DetachUserPolicy detaches a managed policy from a user. // DetachUserPolicy detaches a managed policy from a user.
func (iama *IamApiServer) DetachUserPolicy(s3cfg *iam_pb.S3ApiConfiguration, values url.Values) (resp DetachUserPolicyResponse, iamError *IamError) {
func (iama *IamApiServer) DetachUserPolicy(s3cfg *iam_pb.S3ApiConfiguration, values url.Values) (resp *DetachUserPolicyResponse, iamError *IamError) {
resp = &DetachUserPolicyResponse{}
userName := values.Get("UserName") userName := values.Get("UserName")
policyArn := values.Get("PolicyArn") policyArn := values.Get("PolicyArn")
policyName, iamError := parsePolicyArn(policyArn) policyName, iamError := parsePolicyArn(policyArn)
@ -622,7 +665,8 @@ func (iama *IamApiServer) DetachUserPolicy(s3cfg *iam_pb.S3ApiConfiguration, val
} }
// ListAttachedUserPolicies lists the managed policies attached to a user. // ListAttachedUserPolicies lists the managed policies attached to a user.
func (iama *IamApiServer) ListAttachedUserPolicies(s3cfg *iam_pb.S3ApiConfiguration, values url.Values) (resp ListAttachedUserPoliciesResponse, iamError *IamError) {
func (iama *IamApiServer) ListAttachedUserPolicies(s3cfg *iam_pb.S3ApiConfiguration, values url.Values) (resp *ListAttachedUserPoliciesResponse, iamError *IamError) {
resp = &ListAttachedUserPoliciesResponse{}
userName := values.Get("UserName") userName := values.Get("UserName")
for _, ident := range s3cfg.Identities { for _, ident := range s3cfg.Identities {
if ident.Name != userName { if ident.Name != userName {
@ -714,7 +758,8 @@ func GetActions(policy *policy_engine.PolicyDocument) ([]string, error) {
return actions, nil return actions, nil
} }
func (iama *IamApiServer) CreateAccessKey(s3cfg *iam_pb.S3ApiConfiguration, values url.Values) (resp CreateAccessKeyResponse, iamErr *IamError) {
func (iama *IamApiServer) CreateAccessKey(s3cfg *iam_pb.S3ApiConfiguration, values url.Values) (resp *CreateAccessKeyResponse, iamErr *IamError) {
resp = &CreateAccessKeyResponse{}
userName := values.Get("UserName") userName := values.Get("UserName")
status := iam.StatusTypeActive status := iam.StatusTypeActive
@ -758,7 +803,8 @@ func (iama *IamApiServer) CreateAccessKey(s3cfg *iam_pb.S3ApiConfiguration, valu
} }
// UpdateAccessKey updates the status of an access key (Active or Inactive). // UpdateAccessKey updates the status of an access key (Active or Inactive).
func (iama *IamApiServer) UpdateAccessKey(s3cfg *iam_pb.S3ApiConfiguration, values url.Values) (resp UpdateAccessKeyResponse, err *IamError) {
func (iama *IamApiServer) UpdateAccessKey(s3cfg *iam_pb.S3ApiConfiguration, values url.Values) (resp *UpdateAccessKeyResponse, err *IamError) {
resp = &UpdateAccessKeyResponse{}
userName := values.Get("UserName") userName := values.Get("UserName")
accessKeyId := values.Get("AccessKeyId") accessKeyId := values.Get("AccessKeyId")
status := values.Get("Status") status := values.Get("Status")
@ -788,7 +834,8 @@ func (iama *IamApiServer) UpdateAccessKey(s3cfg *iam_pb.S3ApiConfiguration, valu
return resp, &IamError{Code: iam.ErrCodeNoSuchEntityException, Error: fmt.Errorf(USER_DOES_NOT_EXIST, userName)} return resp, &IamError{Code: iam.ErrCodeNoSuchEntityException, Error: fmt.Errorf(USER_DOES_NOT_EXIST, userName)}
} }
func (iama *IamApiServer) DeleteAccessKey(s3cfg *iam_pb.S3ApiConfiguration, values url.Values) (resp DeleteAccessKeyResponse) {
func (iama *IamApiServer) DeleteAccessKey(s3cfg *iam_pb.S3ApiConfiguration, values url.Values) (resp *DeleteAccessKeyResponse) {
resp = &DeleteAccessKeyResponse{}
userName := values.Get("UserName") userName := values.Get("UserName")
accessKeyId := values.Get("AccessKeyId") accessKeyId := values.Get("AccessKeyId")
for _, ident := range s3cfg.Identities { for _, ident := range s3cfg.Identities {
@ -859,6 +906,8 @@ func (iama *IamApiServer) DoActions(w http.ResponseWriter, r *http.Request) {
policyLock.Lock() policyLock.Lock()
defer policyLock.Unlock() defer policyLock.Unlock()
r, reqID := request_id.Ensure(r)
if err := r.ParseForm(); err != nil { if err := r.ParseForm(); err != nil {
s3err.WriteErrorResponse(w, r, s3err.ErrInvalidRequest) s3err.WriteErrorResponse(w, r, s3err.ErrInvalidRequest)
return return
@ -871,8 +920,7 @@ func (iama *IamApiServer) DoActions(w http.ResponseWriter, r *http.Request) {
} }
glog.V(4).Infof("DoActions: %+v", values) glog.V(4).Infof("DoActions: %+v", values)
var response interface{}
var iamError *IamError
var response iamlib.RequestIDSetter
changed := true changed := true
switch r.Form.Get("Action") { switch r.Form.Get("Action") {
case "ListUsers": case "ListUsers":
@ -886,32 +934,35 @@ func (iama *IamApiServer) DoActions(w http.ResponseWriter, r *http.Request) {
response = iama.CreateUser(s3cfg, values) response = iama.CreateUser(s3cfg, values)
case "GetUser": case "GetUser":
userName := values.Get("UserName") userName := values.Get("UserName")
response, iamError = iama.GetUser(s3cfg, userName)
if iamError != nil {
writeIamErrorResponse(w, r, iamError)
var err *IamError
response, err = iama.GetUser(s3cfg, userName)
if err != nil {
writeIamErrorResponse(w, r, reqID, err)
return return
} }
changed = false changed = false
case "UpdateUser": case "UpdateUser":
response, iamError = iama.UpdateUser(s3cfg, values)
if iamError != nil {
glog.Errorf("UpdateUser: %+v", iamError.Error)
s3err.WriteErrorResponse(w, r, s3err.ErrInvalidRequest)
var err *IamError
response, err = iama.UpdateUser(s3cfg, values)
if err != nil {
writeIamErrorResponse(w, r, reqID, err)
return return
} }
case "DeleteUser": case "DeleteUser":
userName := values.Get("UserName") userName := values.Get("UserName")
response, iamError = iama.DeleteUser(s3cfg, userName)
if iamError != nil {
writeIamErrorResponse(w, r, iamError)
var err *IamError
response, err = iama.DeleteUser(s3cfg, userName)
if err != nil {
writeIamErrorResponse(w, r, reqID, err)
return return
} }
case "CreateAccessKey": case "CreateAccessKey":
iama.handleImplicitUsername(r, values) iama.handleImplicitUsername(r, values)
response, iamError = iama.CreateAccessKey(s3cfg, values)
if iamError != nil {
glog.Errorf("CreateAccessKey: %+v", iamError.Error)
writeIamErrorResponse(w, r, iamError)
var err *IamError
response, err = iama.CreateAccessKey(s3cfg, values)
if err != nil {
glog.Errorf("CreateAccessKey: %+v", err.Error)
writeIamErrorResponse(w, r, reqID, err)
return return
} }
case "DeleteAccessKey": case "DeleteAccessKey":
@ -919,85 +970,94 @@ func (iama *IamApiServer) DoActions(w http.ResponseWriter, r *http.Request) {
response = iama.DeleteAccessKey(s3cfg, values) response = iama.DeleteAccessKey(s3cfg, values)
case "UpdateAccessKey": case "UpdateAccessKey":
iama.handleImplicitUsername(r, values) iama.handleImplicitUsername(r, values)
response, iamError = iama.UpdateAccessKey(s3cfg, values)
if iamError != nil {
writeIamErrorResponse(w, r, iamError)
var err *IamError
response, err = iama.UpdateAccessKey(s3cfg, values)
if err != nil {
writeIamErrorResponse(w, r, reqID, err)
return return
} }
case "CreatePolicy": case "CreatePolicy":
response, iamError = iama.CreatePolicy(s3cfg, values)
if iamError != nil {
glog.Errorf("CreatePolicy: %+v", iamError.Error)
s3err.WriteErrorResponse(w, r, s3err.ErrInvalidRequest)
var err *IamError
response, err = iama.CreatePolicy(s3cfg, values)
if err != nil {
writeIamErrorResponse(w, r, reqID, err)
return return
} }
// CreatePolicy persists the policy document via iama.s3ApiConfig.PutPolicies(). // CreatePolicy persists the policy document via iama.s3ApiConfig.PutPolicies().
// The `changed` flag is false because this does not modify the main s3cfg.Identities configuration. // The `changed` flag is false because this does not modify the main s3cfg.Identities configuration.
changed = false changed = false
case "PutUserPolicy": case "PutUserPolicy":
var iamError *IamError
response, iamError = iama.PutUserPolicy(s3cfg, values)
if iamError != nil {
glog.Errorf("PutUserPolicy: %+v", iamError.Error)
writeIamErrorResponse(w, r, iamError)
var err *IamError
response, err = iama.PutUserPolicy(s3cfg, values)
if err != nil {
glog.Errorf("PutUserPolicy: %+v", err.Error)
writeIamErrorResponse(w, r, reqID, err)
return return
} }
case "GetUserPolicy": case "GetUserPolicy":
response, iamError = iama.GetUserPolicy(s3cfg, values)
if iamError != nil {
writeIamErrorResponse(w, r, iamError)
var err *IamError
response, err = iama.GetUserPolicy(s3cfg, values)
if err != nil {
writeIamErrorResponse(w, r, reqID, err)
return return
} }
changed = false changed = false
case "DeleteUserPolicy": case "DeleteUserPolicy":
if response, iamError = iama.DeleteUserPolicy(s3cfg, values); iamError != nil {
writeIamErrorResponse(w, r, iamError)
var err *IamError
response, err = iama.DeleteUserPolicy(s3cfg, values)
if err != nil {
writeIamErrorResponse(w, r, reqID, err)
return return
} }
case "GetPolicy": case "GetPolicy":
response, iamError = iama.GetPolicy(s3cfg, values)
if iamError != nil {
writeIamErrorResponse(w, r, iamError)
var err *IamError
response, err = iama.GetPolicy(s3cfg, values)
if err != nil {
writeIamErrorResponse(w, r, reqID, err)
return return
} }
changed = false changed = false
case "DeletePolicy": case "DeletePolicy":
response, iamError = iama.DeletePolicy(s3cfg, values)
if iamError != nil {
writeIamErrorResponse(w, r, iamError)
var err *IamError
response, err = iama.DeletePolicy(s3cfg, values)
if err != nil {
writeIamErrorResponse(w, r, reqID, err)
return return
} }
changed = false changed = false
case "ListPolicies": case "ListPolicies":
response, iamError = iama.ListPolicies(s3cfg, values)
if iamError != nil {
writeIamErrorResponse(w, r, iamError)
var err *IamError
response, err = iama.ListPolicies(s3cfg, values)
if err != nil {
writeIamErrorResponse(w, r, reqID, err)
return return
} }
changed = false changed = false
case "AttachUserPolicy": case "AttachUserPolicy":
response, iamError = iama.AttachUserPolicy(s3cfg, values)
if iamError != nil {
writeIamErrorResponse(w, r, iamError)
var err *IamError
response, err = iama.AttachUserPolicy(s3cfg, values)
if err != nil {
writeIamErrorResponse(w, r, reqID, err)
return return
} }
case "DetachUserPolicy": case "DetachUserPolicy":
response, iamError = iama.DetachUserPolicy(s3cfg, values)
if iamError != nil {
writeIamErrorResponse(w, r, iamError)
var err *IamError
response, err = iama.DetachUserPolicy(s3cfg, values)
if err != nil {
writeIamErrorResponse(w, r, reqID, err)
return return
} }
case "ListAttachedUserPolicies": case "ListAttachedUserPolicies":
response, iamError = iama.ListAttachedUserPolicies(s3cfg, values)
if iamError != nil {
writeIamErrorResponse(w, r, iamError)
var err *IamError
response, err = iama.ListAttachedUserPolicies(s3cfg, values)
if err != nil {
writeIamErrorResponse(w, r, reqID, err)
return return
} }
changed = false changed = false
default: default:
errNotImplemented := s3err.GetAPIError(s3err.ErrNotImplemented) errNotImplemented := s3err.GetAPIError(s3err.ErrNotImplemented)
errorResponse := newErrorResponse(errNotImplemented.Code, errNotImplemented.Description)
errorResponse := newErrorResponse(errNotImplemented.Code, errNotImplemented.Description, reqID)
s3err.WriteXMLResponse(w, r, errNotImplemented.HTTPStatusCode, errorResponse) s3err.WriteXMLResponse(w, r, errNotImplemented.HTTPStatusCode, errorResponse)
return return
} }
@ -1005,7 +1065,7 @@ func (iama *IamApiServer) DoActions(w http.ResponseWriter, r *http.Request) {
err := iama.s3ApiConfig.PutS3ApiConfiguration(s3cfg) err := iama.s3ApiConfig.PutS3ApiConfiguration(s3cfg)
if err != nil { if err != nil {
var iamError = IamError{Code: iam.ErrCodeServiceFailureException, Error: err} var iamError = IamError{Code: iam.ErrCodeServiceFailureException, Error: err}
writeIamErrorResponse(w, r, &iamError)
writeIamErrorResponse(w, r, reqID, &iamError)
return return
} }
// Reload in-memory identity maps so subsequent LookupByAccessKey calls // Reload in-memory identity maps so subsequent LookupByAccessKey calls
@ -1017,5 +1077,6 @@ func (iama *IamApiServer) DoActions(w http.ResponseWriter, r *http.Request) {
} }
} }
} }
response.SetRequestId(reqID)
s3err.WriteXMLResponse(w, r, http.StatusOK, response) s3err.WriteXMLResponse(w, r, http.StatusOK, response)
} }

2
weed/iamapi/iamapi_server.go

@ -21,6 +21,7 @@ import (
. "github.com/seaweedfs/seaweedfs/weed/s3api/s3_constants" . "github.com/seaweedfs/seaweedfs/weed/s3api/s3_constants"
"github.com/seaweedfs/seaweedfs/weed/s3api/s3err" "github.com/seaweedfs/seaweedfs/weed/s3api/s3err"
"github.com/seaweedfs/seaweedfs/weed/util" "github.com/seaweedfs/seaweedfs/weed/util"
"github.com/seaweedfs/seaweedfs/weed/util/request_id"
"github.com/seaweedfs/seaweedfs/weed/wdclient" "github.com/seaweedfs/seaweedfs/weed/wdclient"
"google.golang.org/grpc" "google.golang.org/grpc"
"google.golang.org/protobuf/proto" "google.golang.org/protobuf/proto"
@ -117,6 +118,7 @@ func NewIamApiServerWithStore(router *mux.Router, option *IamServerOption, expli
func (iama *IamApiServer) registerRouter(router *mux.Router) { func (iama *IamApiServer) registerRouter(router *mux.Router) {
// API Router // API Router
apiRouter := router.PathPrefix("/").Subrouter() apiRouter := router.PathPrefix("/").Subrouter()
apiRouter.Use(request_id.Middleware)
// ListBuckets // ListBuckets
// apiRouter.Methods("GET").Path("/").HandlerFunc(track(s3a.iam.Auth(s3a.ListBucketsHandler, ACTION_ADMIN), "LIST")) // apiRouter.Methods("GET").Path("/").HandlerFunc(track(s3a.iam.Auth(s3a.ListBucketsHandler, ACTION_ADMIN), "LIST"))

26
weed/iamapi/iamapi_test.go

@ -17,6 +17,7 @@ import (
"github.com/seaweedfs/seaweedfs/weed/pb/iam_pb" "github.com/seaweedfs/seaweedfs/weed/pb/iam_pb"
"github.com/seaweedfs/seaweedfs/weed/s3api" "github.com/seaweedfs/seaweedfs/weed/s3api"
"github.com/seaweedfs/seaweedfs/weed/s3api/policy_engine" "github.com/seaweedfs/seaweedfs/weed/s3api/policy_engine"
"github.com/seaweedfs/seaweedfs/weed/util/request_id"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
@ -73,6 +74,21 @@ func TestListUsers(t *testing.T) {
assert.Equal(t, http.StatusOK, response.Code) assert.Equal(t, http.StatusOK, response.Code)
} }
func TestListUsersRequestIdMatchesResponseHeader(t *testing.T) {
params := &iam.ListUsersInput{}
req, _ := iam.New(session.New()).ListUsersRequest(params)
_ = req.Build()
out := ListUsersResponse{}
response, err := executeRequest(req.HTTPRequest, out)
assert.Equal(t, nil, err)
assert.Equal(t, http.StatusOK, response.Code)
headerRequestID := response.Header().Get(request_id.AmzRequestIDHeader)
assert.NotEmpty(t, headerRequestID)
assert.Equal(t, headerRequestID, extractRequestID(response))
}
func TestListAccessKeys(t *testing.T) { func TestListAccessKeys(t *testing.T) {
svc := iam.New(session.New()) svc := iam.New(session.New())
params := &iam.ListAccessKeysInput{} params := &iam.ListAccessKeysInput{}
@ -246,6 +262,7 @@ func TestPutUserPolicyError(t *testing.T) {
assert.Equal(t, expectedCode, code) assert.Equal(t, expectedCode, code)
assert.Contains(t, response.Body.String(), "<RequestId>") assert.Contains(t, response.Body.String(), "<RequestId>")
assert.NotContains(t, response.Body.String(), "<ResponseMetadata>") assert.NotContains(t, response.Body.String(), "<ResponseMetadata>")
assert.Equal(t, response.Header().Get(request_id.AmzRequestIDHeader), extractRequestID(response))
} }
func extractErrorCodeAndMessage(response *httptest.ResponseRecorder) (string, string) { func extractErrorCodeAndMessage(response *httptest.ResponseRecorder) (string, string) {
@ -257,6 +274,15 @@ func extractErrorCodeAndMessage(response *httptest.ResponseRecorder) (string, st
return code, message return code, message
} }
func extractRequestID(response *httptest.ResponseRecorder) string {
re := regexp.MustCompile(`<RequestId>([^<]+)</RequestId>`)
matches := re.FindStringSubmatch(response.Body.String())
if len(matches) < 2 {
return ""
}
return matches[1]
}
func TestGetUserPolicy(t *testing.T) { func TestGetUserPolicy(t *testing.T) {
userName := aws.String("Test") userName := aws.String("Test")
params := &iam.GetUserPolicyInput{UserName: userName, PolicyName: aws.String("S3-read-only-example-bucket")} params := &iam.GetUserPolicyInput{UserName: userName, PolicyName: aws.String("S3-read-only-example-bucket")}

170
weed/s3api/s3api_embedded_iam.go

@ -26,6 +26,7 @@ import (
"github.com/seaweedfs/seaweedfs/weed/s3api/s3_constants" "github.com/seaweedfs/seaweedfs/weed/s3api/s3_constants"
. "github.com/seaweedfs/seaweedfs/weed/s3api/s3_constants" . "github.com/seaweedfs/seaweedfs/weed/s3api/s3_constants"
"github.com/seaweedfs/seaweedfs/weed/s3api/s3err" "github.com/seaweedfs/seaweedfs/weed/s3api/s3err"
"github.com/seaweedfs/seaweedfs/weed/util/request_id"
"google.golang.org/protobuf/proto" "google.golang.org/protobuf/proto"
) )
@ -157,18 +158,20 @@ const (
iamAccessKeyStatusInactive = iamlib.AccessKeyStatusInactive iamAccessKeyStatusInactive = iamlib.AccessKeyStatusInactive
) )
func newIamErrorResponse(errCode string, errMsg string) iamErrorResponse {
func newIamErrorResponse(errCode string, errMsg string, requestID string) iamErrorResponse {
errorResp := iamErrorResponse{} errorResp := iamErrorResponse{}
errorResp.Error.Type = "Sender" errorResp.Error.Type = "Sender"
errorResp.Error.Code = &errCode errorResp.Error.Code = &errCode
errorResp.Error.Message = &errMsg errorResp.Error.Message = &errMsg
errorResp.SetRequestId()
errorResp.SetRequestId(requestID)
return errorResp return errorResp
} }
func (e *EmbeddedIamApi) writeIamErrorResponse(w http.ResponseWriter, r *http.Request, iamErr *iamError) {
func (e *EmbeddedIamApi) writeIamErrorResponse(w http.ResponseWriter, r *http.Request, reqID string, iamErr *iamError) {
if iamErr == nil { if iamErr == nil {
glog.Errorf("No error found")
glog.Errorf("writeIamErrorResponse called with nil error")
internalResp := newIamErrorResponse(iam.ErrCodeServiceFailureException, "Internal server error", reqID)
s3err.WriteXMLResponse(w, r, http.StatusInternalServerError, internalResp)
return return
} }
@ -176,8 +179,8 @@ func (e *EmbeddedIamApi) writeIamErrorResponse(w http.ResponseWriter, r *http.Re
errMsg := iamErr.Error.Error() errMsg := iamErr.Error.Error()
glog.Errorf("IAM Response %+v", errMsg) glog.Errorf("IAM Response %+v", errMsg)
errorResp := newIamErrorResponse(errCode, errMsg)
internalErrorResponse := newIamErrorResponse(iam.ErrCodeServiceFailureException, "Internal server error")
errorResp := newIamErrorResponse(errCode, errMsg, reqID)
internalErrorResponse := newIamErrorResponse(iam.ErrCodeServiceFailureException, "Internal server error", reqID)
switch errCode { switch errCode {
case iam.ErrCodeNoSuchEntityException: case iam.ErrCodeNoSuchEntityException:
@ -230,8 +233,8 @@ func (e *EmbeddedIamApi) ReloadConfiguration() error {
} }
// ListUsers lists all IAM users. // ListUsers lists all IAM users.
func (e *EmbeddedIamApi) ListUsers(s3cfg *iam_pb.S3ApiConfiguration, values url.Values) iamListUsersResponse {
var resp iamListUsersResponse
func (e *EmbeddedIamApi) ListUsers(s3cfg *iam_pb.S3ApiConfiguration, values url.Values) *iamListUsersResponse {
resp := &iamListUsersResponse{}
for _, ident := range s3cfg.Identities { for _, ident := range s3cfg.Identities {
resp.ListUsersResult.Users = append(resp.ListUsersResult.Users, &iam.User{UserName: &ident.Name}) resp.ListUsersResult.Users = append(resp.ListUsersResult.Users, &iam.User{UserName: &ident.Name})
} }
@ -239,8 +242,8 @@ func (e *EmbeddedIamApi) ListUsers(s3cfg *iam_pb.S3ApiConfiguration, values url.
} }
// ListAccessKeys lists access keys for a user. // ListAccessKeys lists access keys for a user.
func (e *EmbeddedIamApi) ListAccessKeys(s3cfg *iam_pb.S3ApiConfiguration, values url.Values) iamListAccessKeysResponse {
var resp iamListAccessKeysResponse
func (e *EmbeddedIamApi) ListAccessKeys(s3cfg *iam_pb.S3ApiConfiguration, values url.Values) *iamListAccessKeysResponse {
resp := &iamListAccessKeysResponse{}
userName := values.Get("UserName") userName := values.Get("UserName")
for _, ident := range s3cfg.Identities { for _, ident := range s3cfg.Identities {
if userName != "" && userName != ident.Name { if userName != "" && userName != ident.Name {
@ -265,8 +268,8 @@ func (e *EmbeddedIamApi) ListAccessKeys(s3cfg *iam_pb.S3ApiConfiguration, values
} }
// CreateUser creates a new IAM user. // CreateUser creates a new IAM user.
func (e *EmbeddedIamApi) CreateUser(s3cfg *iam_pb.S3ApiConfiguration, values url.Values) (iamCreateUserResponse, *iamError) {
var resp iamCreateUserResponse
func (e *EmbeddedIamApi) CreateUser(s3cfg *iam_pb.S3ApiConfiguration, values url.Values) (*iamCreateUserResponse, *iamError) {
resp := &iamCreateUserResponse{}
userName := values.Get("UserName") userName := values.Get("UserName")
// Validate UserName is not empty // Validate UserName is not empty
@ -287,8 +290,8 @@ func (e *EmbeddedIamApi) CreateUser(s3cfg *iam_pb.S3ApiConfiguration, values url
} }
// DeleteUser deletes an IAM user. // DeleteUser deletes an IAM user.
func (e *EmbeddedIamApi) DeleteUser(s3cfg *iam_pb.S3ApiConfiguration, userName string) (iamDeleteUserResponse, *iamError) {
var resp iamDeleteUserResponse
func (e *EmbeddedIamApi) DeleteUser(s3cfg *iam_pb.S3ApiConfiguration, userName string) (*iamDeleteUserResponse, *iamError) {
resp := &iamDeleteUserResponse{}
for i, ident := range s3cfg.Identities { for i, ident := range s3cfg.Identities {
if userName == ident.Name { if userName == ident.Name {
// AWS IAM behavior: prevent deletion if user has service accounts // AWS IAM behavior: prevent deletion if user has service accounts
@ -308,8 +311,8 @@ func (e *EmbeddedIamApi) DeleteUser(s3cfg *iam_pb.S3ApiConfiguration, userName s
} }
// GetUser gets an IAM user. // GetUser gets an IAM user.
func (e *EmbeddedIamApi) GetUser(s3cfg *iam_pb.S3ApiConfiguration, userName string) (iamGetUserResponse, *iamError) {
var resp iamGetUserResponse
func (e *EmbeddedIamApi) GetUser(s3cfg *iam_pb.S3ApiConfiguration, userName string) (*iamGetUserResponse, *iamError) {
resp := &iamGetUserResponse{}
for _, ident := range s3cfg.Identities { for _, ident := range s3cfg.Identities {
if userName == ident.Name { if userName == ident.Name {
resp.GetUserResult.User = iam.User{UserName: &ident.Name} resp.GetUserResult.User = iam.User{UserName: &ident.Name}
@ -320,8 +323,8 @@ func (e *EmbeddedIamApi) GetUser(s3cfg *iam_pb.S3ApiConfiguration, userName stri
} }
// UpdateUser updates an IAM user. // UpdateUser updates an IAM user.
func (e *EmbeddedIamApi) UpdateUser(s3cfg *iam_pb.S3ApiConfiguration, values url.Values) (iamUpdateUserResponse, *iamError) {
var resp iamUpdateUserResponse
func (e *EmbeddedIamApi) UpdateUser(s3cfg *iam_pb.S3ApiConfiguration, values url.Values) (*iamUpdateUserResponse, *iamError) {
resp := &iamUpdateUserResponse{}
userName := values.Get("UserName") userName := values.Get("UserName")
newUserName := values.Get("NewUserName") newUserName := values.Get("NewUserName")
if newUserName != "" { if newUserName != "" {
@ -338,8 +341,8 @@ func (e *EmbeddedIamApi) UpdateUser(s3cfg *iam_pb.S3ApiConfiguration, values url
} }
// CreateAccessKey creates an access key for a user. // CreateAccessKey creates an access key for a user.
func (e *EmbeddedIamApi) CreateAccessKey(s3cfg *iam_pb.S3ApiConfiguration, values url.Values) (iamCreateAccessKeyResponse, *iamError) {
var resp iamCreateAccessKeyResponse
func (e *EmbeddedIamApi) CreateAccessKey(s3cfg *iam_pb.S3ApiConfiguration, values url.Values) (*iamCreateAccessKeyResponse, *iamError) {
resp := &iamCreateAccessKeyResponse{}
userName := values.Get("UserName") userName := values.Get("UserName")
status := iam.StatusTypeActive status := iam.StatusTypeActive
@ -372,8 +375,8 @@ func (e *EmbeddedIamApi) CreateAccessKey(s3cfg *iam_pb.S3ApiConfiguration, value
} }
// DeleteAccessKey deletes an access key for a user. // DeleteAccessKey deletes an access key for a user.
func (e *EmbeddedIamApi) DeleteAccessKey(s3cfg *iam_pb.S3ApiConfiguration, values url.Values) iamDeleteAccessKeyResponse {
var resp iamDeleteAccessKeyResponse
func (e *EmbeddedIamApi) DeleteAccessKey(s3cfg *iam_pb.S3ApiConfiguration, values url.Values) *iamDeleteAccessKeyResponse {
resp := &iamDeleteAccessKeyResponse{}
userName := values.Get("UserName") userName := values.Get("UserName")
accessKeyId := values.Get("AccessKeyId") accessKeyId := values.Get("AccessKeyId")
for _, ident := range s3cfg.Identities { for _, ident := range s3cfg.Identities {
@ -400,8 +403,8 @@ func (e *EmbeddedIamApi) GetPolicyDocument(policy *string) (policy_engine.Policy
} }
// CreatePolicy validates and creates a new IAM managed policy. // CreatePolicy validates and creates a new IAM managed policy.
func (e *EmbeddedIamApi) CreatePolicy(ctx context.Context, values url.Values) (iamCreatePolicyResponse, *iamError) {
var resp iamCreatePolicyResponse
func (e *EmbeddedIamApi) CreatePolicy(ctx context.Context, values url.Values) (*iamCreatePolicyResponse, *iamError) {
resp := &iamCreatePolicyResponse{}
policyName := values.Get("PolicyName") policyName := values.Get("PolicyName")
policyDocumentString := values.Get("PolicyDocument") policyDocumentString := values.Get("PolicyDocument")
if policyName == "" { if policyName == "" {
@ -443,8 +446,8 @@ func (e *EmbeddedIamApi) CreatePolicy(ctx context.Context, values url.Values) (i
} }
// DeletePolicy deletes a managed policy by ARN. // DeletePolicy deletes a managed policy by ARN.
func (e *EmbeddedIamApi) DeletePolicy(ctx context.Context, values url.Values) (iamDeletePolicyResponse, *iamError) {
var resp iamDeletePolicyResponse
func (e *EmbeddedIamApi) DeletePolicy(ctx context.Context, values url.Values) (*iamDeletePolicyResponse, *iamError) {
resp := &iamDeletePolicyResponse{}
policyArn := values.Get("PolicyArn") policyArn := values.Get("PolicyArn")
policyName, err := iamPolicyNameFromArn(policyArn) policyName, err := iamPolicyNameFromArn(policyArn)
if err != nil { if err != nil {
@ -485,8 +488,8 @@ func (e *EmbeddedIamApi) DeletePolicy(ctx context.Context, values url.Values) (i
} }
// ListPolicies lists managed policies. // ListPolicies lists managed policies.
func (e *EmbeddedIamApi) ListPolicies(ctx context.Context, values url.Values) (iamListPoliciesResponse, *iamError) {
var resp iamListPoliciesResponse
func (e *EmbeddedIamApi) ListPolicies(ctx context.Context, values url.Values) (*iamListPoliciesResponse, *iamError) {
resp := &iamListPoliciesResponse{}
pathPrefix := values.Get("PathPrefix") pathPrefix := values.Get("PathPrefix")
if pathPrefix == "" { if pathPrefix == "" {
pathPrefix = "/" pathPrefix = "/"
@ -558,8 +561,8 @@ func (e *EmbeddedIamApi) ListPolicies(ctx context.Context, values url.Values) (i
} }
// GetPolicy returns metadata for a managed policy. // GetPolicy returns metadata for a managed policy.
func (e *EmbeddedIamApi) GetPolicy(ctx context.Context, values url.Values) (iamGetPolicyResponse, *iamError) {
var resp iamGetPolicyResponse
func (e *EmbeddedIamApi) GetPolicy(ctx context.Context, values url.Values) (*iamGetPolicyResponse, *iamError) {
resp := &iamGetPolicyResponse{}
policyArn := values.Get("PolicyArn") policyArn := values.Get("PolicyArn")
policyName, err := iamPolicyNameFromArn(policyArn) policyName, err := iamPolicyNameFromArn(policyArn)
if err != nil { if err != nil {
@ -595,8 +598,8 @@ func (e *EmbeddedIamApi) GetPolicy(ctx context.Context, values url.Values) (iamG
// ListPolicyVersions lists versions for a managed policy. // ListPolicyVersions lists versions for a managed policy.
// Current SeaweedFS implementation stores one version per policy (v1). // Current SeaweedFS implementation stores one version per policy (v1).
func (e *EmbeddedIamApi) ListPolicyVersions(ctx context.Context, values url.Values) (iamListPolicyVersionsResponse, *iamError) {
var resp iamListPolicyVersionsResponse
func (e *EmbeddedIamApi) ListPolicyVersions(ctx context.Context, values url.Values) (*iamListPolicyVersionsResponse, *iamError) {
resp := &iamListPolicyVersionsResponse{}
policyArn := values.Get("PolicyArn") policyArn := values.Get("PolicyArn")
policyName, err := iamPolicyNameFromArn(policyArn) policyName, err := iamPolicyNameFromArn(policyArn)
if err != nil { if err != nil {
@ -625,8 +628,8 @@ func (e *EmbeddedIamApi) ListPolicyVersions(ctx context.Context, values url.Valu
// GetPolicyVersion returns the document for a specific policy version. // GetPolicyVersion returns the document for a specific policy version.
// Current SeaweedFS implementation stores one version per policy (v1). // Current SeaweedFS implementation stores one version per policy (v1).
func (e *EmbeddedIamApi) GetPolicyVersion(ctx context.Context, values url.Values) (iamGetPolicyVersionResponse, *iamError) {
var resp iamGetPolicyVersionResponse
func (e *EmbeddedIamApi) GetPolicyVersion(ctx context.Context, values url.Values) (*iamGetPolicyVersionResponse, *iamError) {
resp := &iamGetPolicyVersionResponse{}
policyArn := values.Get("PolicyArn") policyArn := values.Get("PolicyArn")
versionID := values.Get("VersionId") versionID := values.Get("VersionId")
if versionID == "" { if versionID == "" {
@ -754,8 +757,8 @@ func (e *EmbeddedIamApi) getActions(policy *policy_engine.PolicyDocument) ([]str
} }
// PutUserPolicy attaches a policy to a user. // PutUserPolicy attaches a policy to a user.
func (e *EmbeddedIamApi) PutUserPolicy(s3cfg *iam_pb.S3ApiConfiguration, values url.Values) (iamPutUserPolicyResponse, *iamError) {
var resp iamPutUserPolicyResponse
func (e *EmbeddedIamApi) PutUserPolicy(s3cfg *iam_pb.S3ApiConfiguration, values url.Values) (*iamPutUserPolicyResponse, *iamError) {
resp := &iamPutUserPolicyResponse{}
userName := values.Get("UserName") userName := values.Get("UserName")
policyDocumentString := values.Get("PolicyDocument") policyDocumentString := values.Get("PolicyDocument")
policyDocument, err := e.GetPolicyDocument(&policyDocumentString) policyDocument, err := e.GetPolicyDocument(&policyDocumentString)
@ -778,8 +781,8 @@ func (e *EmbeddedIamApi) PutUserPolicy(s3cfg *iam_pb.S3ApiConfiguration, values
} }
// GetUserPolicy gets the policy attached to a user. // GetUserPolicy gets the policy attached to a user.
func (e *EmbeddedIamApi) GetUserPolicy(s3cfg *iam_pb.S3ApiConfiguration, values url.Values) (iamGetUserPolicyResponse, *iamError) {
var resp iamGetUserPolicyResponse
func (e *EmbeddedIamApi) GetUserPolicy(s3cfg *iam_pb.S3ApiConfiguration, values url.Values) (*iamGetUserPolicyResponse, *iamError) {
resp := &iamGetUserPolicyResponse{}
userName := values.Get("UserName") userName := values.Get("UserName")
policyName := values.Get("PolicyName") policyName := values.Get("PolicyName")
for _, ident := range s3cfg.Identities { for _, ident := range s3cfg.Identities {
@ -845,8 +848,8 @@ func (e *EmbeddedIamApi) GetUserPolicy(s3cfg *iam_pb.S3ApiConfiguration, values
} }
// DeleteUserPolicy removes the inline policy from a user (clears their actions). // DeleteUserPolicy removes the inline policy from a user (clears their actions).
func (e *EmbeddedIamApi) DeleteUserPolicy(s3cfg *iam_pb.S3ApiConfiguration, values url.Values) (iamDeleteUserPolicyResponse, *iamError) {
var resp iamDeleteUserPolicyResponse
func (e *EmbeddedIamApi) DeleteUserPolicy(s3cfg *iam_pb.S3ApiConfiguration, values url.Values) (*iamDeleteUserPolicyResponse, *iamError) {
resp := &iamDeleteUserPolicyResponse{}
userName := values.Get("UserName") userName := values.Get("UserName")
for _, ident := range s3cfg.Identities { for _, ident := range s3cfg.Identities {
if ident.Name == userName { if ident.Name == userName {
@ -858,8 +861,8 @@ func (e *EmbeddedIamApi) DeleteUserPolicy(s3cfg *iam_pb.S3ApiConfiguration, valu
} }
// AttachUserPolicy attaches a managed policy to a user. // AttachUserPolicy attaches a managed policy to a user.
func (e *EmbeddedIamApi) AttachUserPolicy(ctx context.Context, values url.Values) (iamAttachUserPolicyResponse, *iamError) {
var resp iamAttachUserPolicyResponse
func (e *EmbeddedIamApi) AttachUserPolicy(ctx context.Context, values url.Values) (*iamAttachUserPolicyResponse, *iamError) {
resp := &iamAttachUserPolicyResponse{}
userName := values.Get("UserName") userName := values.Get("UserName")
if userName == "" { if userName == "" {
@ -926,8 +929,8 @@ func (e *EmbeddedIamApi) AttachUserPolicy(ctx context.Context, values url.Values
} }
// DetachUserPolicy detaches a managed policy from a user. // DetachUserPolicy detaches a managed policy from a user.
func (e *EmbeddedIamApi) DetachUserPolicy(ctx context.Context, values url.Values) (iamDetachUserPolicyResponse, *iamError) {
var resp iamDetachUserPolicyResponse
func (e *EmbeddedIamApi) DetachUserPolicy(ctx context.Context, values url.Values) (*iamDetachUserPolicyResponse, *iamError) {
resp := &iamDetachUserPolicyResponse{}
userName := values.Get("UserName") userName := values.Get("UserName")
if userName == "" { if userName == "" {
@ -971,8 +974,8 @@ func (e *EmbeddedIamApi) DetachUserPolicy(ctx context.Context, values url.Values
} }
// ListAttachedUserPolicies lists managed policies attached to a user. // ListAttachedUserPolicies lists managed policies attached to a user.
func (e *EmbeddedIamApi) ListAttachedUserPolicies(ctx context.Context, values url.Values) (iamListAttachedUserPoliciesResponse, *iamError) {
var resp iamListAttachedUserPoliciesResponse
func (e *EmbeddedIamApi) ListAttachedUserPolicies(ctx context.Context, values url.Values) (*iamListAttachedUserPoliciesResponse, *iamError) {
resp := &iamListAttachedUserPoliciesResponse{}
userName := values.Get("UserName") userName := values.Get("UserName")
if userName == "" { if userName == "" {
@ -1056,8 +1059,8 @@ func (e *EmbeddedIamApi) ListAttachedUserPolicies(ctx context.Context, values ur
// SetUserStatus enables or disables a user without deleting them. // SetUserStatus enables or disables a user without deleting them.
// This is a SeaweedFS extension for temporary user suspension, offboarding, etc. // This is a SeaweedFS extension for temporary user suspension, offboarding, etc.
// When a user is disabled, all API requests using their credentials will return AccessDenied. // When a user is disabled, all API requests using their credentials will return AccessDenied.
func (e *EmbeddedIamApi) SetUserStatus(s3cfg *iam_pb.S3ApiConfiguration, values url.Values) (iamSetUserStatusResponse, *iamError) {
var resp iamSetUserStatusResponse
func (e *EmbeddedIamApi) SetUserStatus(s3cfg *iam_pb.S3ApiConfiguration, values url.Values) (*iamSetUserStatusResponse, *iamError) {
resp := &iamSetUserStatusResponse{}
userName := values.Get("UserName") userName := values.Get("UserName")
status := values.Get("Status") status := values.Get("Status")
@ -1083,8 +1086,8 @@ func (e *EmbeddedIamApi) SetUserStatus(s3cfg *iam_pb.S3ApiConfiguration, values
// UpdateAccessKey updates the status of an access key (Active or Inactive). // UpdateAccessKey updates the status of an access key (Active or Inactive).
// This allows key rotation workflows where old keys are deactivated before deletion. // This allows key rotation workflows where old keys are deactivated before deletion.
func (e *EmbeddedIamApi) UpdateAccessKey(s3cfg *iam_pb.S3ApiConfiguration, values url.Values) (iamUpdateAccessKeyResponse, *iamError) {
var resp iamUpdateAccessKeyResponse
func (e *EmbeddedIamApi) UpdateAccessKey(s3cfg *iam_pb.S3ApiConfiguration, values url.Values) (*iamUpdateAccessKeyResponse, *iamError) {
resp := &iamUpdateAccessKeyResponse{}
userName := values.Get("UserName") userName := values.Get("UserName")
accessKeyId := values.Get("AccessKeyId") accessKeyId := values.Get("AccessKeyId")
status := values.Get("Status") status := values.Get("Status")
@ -1129,8 +1132,8 @@ func findIdentityByName(s3cfg *iam_pb.S3ApiConfiguration, name string) *iam_pb.I
} }
// CreateServiceAccount creates a new service account for a user. // CreateServiceAccount creates a new service account for a user.
func (e *EmbeddedIamApi) CreateServiceAccount(s3cfg *iam_pb.S3ApiConfiguration, values url.Values, createdBy string) (iamCreateServiceAccountResponse, *iamError) {
var resp iamCreateServiceAccountResponse
func (e *EmbeddedIamApi) CreateServiceAccount(s3cfg *iam_pb.S3ApiConfiguration, values url.Values, createdBy string) (*iamCreateServiceAccountResponse, *iamError) {
resp := &iamCreateServiceAccountResponse{}
parentUser := values.Get("ParentUser") parentUser := values.Get("ParentUser")
description := values.Get("Description") description := values.Get("Description")
expirationStr := values.Get("Expiration") // Unix timestamp as string expirationStr := values.Get("Expiration") // Unix timestamp as string
@ -1239,8 +1242,8 @@ func (e *EmbeddedIamApi) CreateServiceAccount(s3cfg *iam_pb.S3ApiConfiguration,
} }
// DeleteServiceAccount deletes a service account. // DeleteServiceAccount deletes a service account.
func (e *EmbeddedIamApi) DeleteServiceAccount(s3cfg *iam_pb.S3ApiConfiguration, values url.Values) (iamDeleteServiceAccountResponse, *iamError) {
var resp iamDeleteServiceAccountResponse
func (e *EmbeddedIamApi) DeleteServiceAccount(s3cfg *iam_pb.S3ApiConfiguration, values url.Values) (*iamDeleteServiceAccountResponse, *iamError) {
resp := &iamDeleteServiceAccountResponse{}
saId := values.Get("ServiceAccountId") saId := values.Get("ServiceAccountId")
if saId == "" { if saId == "" {
@ -1272,8 +1275,8 @@ func (e *EmbeddedIamApi) DeleteServiceAccount(s3cfg *iam_pb.S3ApiConfiguration,
} }
// ListServiceAccounts lists service accounts, optionally filtered by parent user. // ListServiceAccounts lists service accounts, optionally filtered by parent user.
func (e *EmbeddedIamApi) ListServiceAccounts(s3cfg *iam_pb.S3ApiConfiguration, values url.Values) iamListServiceAccountsResponse {
var resp iamListServiceAccountsResponse
func (e *EmbeddedIamApi) ListServiceAccounts(s3cfg *iam_pb.S3ApiConfiguration, values url.Values) *iamListServiceAccountsResponse {
resp := &iamListServiceAccountsResponse{}
parentUser := values.Get("ParentUser") // Optional filter parentUser := values.Get("ParentUser") // Optional filter
for _, sa := range s3cfg.ServiceAccounts { for _, sa := range s3cfg.ServiceAccounts {
@ -1307,8 +1310,8 @@ func (e *EmbeddedIamApi) ListServiceAccounts(s3cfg *iam_pb.S3ApiConfiguration, v
} }
// GetServiceAccount retrieves a service account by ID. // GetServiceAccount retrieves a service account by ID.
func (e *EmbeddedIamApi) GetServiceAccount(s3cfg *iam_pb.S3ApiConfiguration, values url.Values) (iamGetServiceAccountResponse, *iamError) {
var resp iamGetServiceAccountResponse
func (e *EmbeddedIamApi) GetServiceAccount(s3cfg *iam_pb.S3ApiConfiguration, values url.Values) (*iamGetServiceAccountResponse, *iamError) {
resp := &iamGetServiceAccountResponse{}
saId := values.Get("ServiceAccountId") saId := values.Get("ServiceAccountId")
if saId == "" { if saId == "" {
@ -1344,8 +1347,8 @@ func (e *EmbeddedIamApi) GetServiceAccount(s3cfg *iam_pb.S3ApiConfiguration, val
} }
// UpdateServiceAccount updates a service account's status, description, or expiration. // UpdateServiceAccount updates a service account's status, description, or expiration.
func (e *EmbeddedIamApi) UpdateServiceAccount(s3cfg *iam_pb.S3ApiConfiguration, values url.Values) (iamUpdateServiceAccountResponse, *iamError) {
var resp iamUpdateServiceAccountResponse
func (e *EmbeddedIamApi) UpdateServiceAccount(s3cfg *iam_pb.S3ApiConfiguration, values url.Values) (*iamUpdateServiceAccountResponse, *iamError) {
resp := &iamUpdateServiceAccountResponse{}
saId := values.Get("ServiceAccountId") saId := values.Get("ServiceAccountId")
newStatus := values.Get("Status") newStatus := values.Get("Status")
newDescription := values.Get("Description") newDescription := values.Get("Description")
@ -1543,7 +1546,11 @@ func (e *EmbeddedIamApi) AuthIam(f http.HandlerFunc, _ Action) http.HandlerFunc
// ExecuteAction executes an IAM action with the given values. // ExecuteAction executes an IAM action with the given values.
// If skipPersist is true, the changed configuration is not saved to the persistent store. // If skipPersist is true, the changed configuration is not saved to the persistent store.
func (e *EmbeddedIamApi) ExecuteAction(ctx context.Context, values url.Values, skipPersist bool) (interface{}, *iamError) {
// reqID is set on the response; if empty, a new request ID is generated.
func (e *EmbeddedIamApi) ExecuteAction(ctx context.Context, values url.Values, skipPersist bool, reqID string) (iamlib.RequestIDSetter, *iamError) {
if reqID == "" {
reqID = request_id.New()
}
// Lock to prevent concurrent read-modify-write race conditions // Lock to prevent concurrent read-modify-write race conditions
e.policyLock.Lock() e.policyLock.Lock()
defer e.policyLock.Unlock() defer e.policyLock.Unlock()
@ -1564,8 +1571,7 @@ func (e *EmbeddedIamApi) ExecuteAction(ctx context.Context, values url.Values, s
} }
glog.V(4).Infof("IAM ExecuteAction: %+v", values) glog.V(4).Infof("IAM ExecuteAction: %+v", values)
var response interface{}
var iamErr *iamError
var response iamlib.RequestIDSetter
changed := true changed := true
switch values.Get("Action") { switch values.Get("Action") {
case "ListUsers": case "ListUsers":
@ -1577,29 +1583,34 @@ func (e *EmbeddedIamApi) ExecuteAction(ctx context.Context, values url.Values, s
response = e.ListAccessKeys(s3cfg, values) response = e.ListAccessKeys(s3cfg, values)
changed = false changed = false
case "CreateUser": case "CreateUser":
var iamErr *iamError
response, iamErr = e.CreateUser(s3cfg, values) response, iamErr = e.CreateUser(s3cfg, values)
if iamErr != nil { if iamErr != nil {
return nil, iamErr return nil, iamErr
} }
case "GetUser": case "GetUser":
userName := values.Get("UserName") userName := values.Get("UserName")
var iamErr *iamError
response, iamErr = e.GetUser(s3cfg, userName) response, iamErr = e.GetUser(s3cfg, userName)
if iamErr != nil { if iamErr != nil {
return nil, iamErr return nil, iamErr
} }
changed = false changed = false
case "UpdateUser": case "UpdateUser":
var iamErr *iamError
response, iamErr = e.UpdateUser(s3cfg, values) response, iamErr = e.UpdateUser(s3cfg, values)
if iamErr != nil { if iamErr != nil {
return nil, iamErr return nil, iamErr
} }
case "DeleteUser": case "DeleteUser":
userName := values.Get("UserName") userName := values.Get("UserName")
var iamErr *iamError
response, iamErr = e.DeleteUser(s3cfg, userName) response, iamErr = e.DeleteUser(s3cfg, userName)
if iamErr != nil { if iamErr != nil {
return nil, iamErr return nil, iamErr
} }
case "CreateAccessKey": case "CreateAccessKey":
var iamErr *iamError
response, iamErr = e.CreateAccessKey(s3cfg, values) response, iamErr = e.CreateAccessKey(s3cfg, values)
if iamErr != nil { if iamErr != nil {
glog.Errorf("CreateAccessKey: %+v", iamErr.Error) glog.Errorf("CreateAccessKey: %+v", iamErr.Error)
@ -1608,6 +1619,7 @@ func (e *EmbeddedIamApi) ExecuteAction(ctx context.Context, values url.Values, s
case "DeleteAccessKey": case "DeleteAccessKey":
response = e.DeleteAccessKey(s3cfg, values) response = e.DeleteAccessKey(s3cfg, values)
case "CreatePolicy": case "CreatePolicy":
var iamErr *iamError
response, iamErr = e.CreatePolicy(ctx, values) response, iamErr = e.CreatePolicy(ctx, values)
if iamErr != nil { if iamErr != nil {
glog.Errorf("CreatePolicy: %+v", iamErr.Error) glog.Errorf("CreatePolicy: %+v", iamErr.Error)
@ -1615,6 +1627,7 @@ func (e *EmbeddedIamApi) ExecuteAction(ctx context.Context, values url.Values, s
} }
changed = false changed = false
case "DeletePolicy": case "DeletePolicy":
var iamErr *iamError
response, iamErr = e.DeletePolicy(ctx, values) response, iamErr = e.DeletePolicy(ctx, values)
if iamErr != nil { if iamErr != nil {
glog.Errorf("DeletePolicy: %+v", iamErr.Error) glog.Errorf("DeletePolicy: %+v", iamErr.Error)
@ -1622,70 +1635,82 @@ func (e *EmbeddedIamApi) ExecuteAction(ctx context.Context, values url.Values, s
} }
changed = false changed = false
case "PutUserPolicy": case "PutUserPolicy":
var iamErr *iamError
response, iamErr = e.PutUserPolicy(s3cfg, values) response, iamErr = e.PutUserPolicy(s3cfg, values)
if iamErr != nil { if iamErr != nil {
glog.Errorf("PutUserPolicy: %+v", iamErr.Error) glog.Errorf("PutUserPolicy: %+v", iamErr.Error)
return nil, iamErr return nil, iamErr
} }
case "GetUserPolicy": case "GetUserPolicy":
var iamErr *iamError
response, iamErr = e.GetUserPolicy(s3cfg, values) response, iamErr = e.GetUserPolicy(s3cfg, values)
if iamErr != nil { if iamErr != nil {
return nil, iamErr return nil, iamErr
} }
changed = false changed = false
case "DeleteUserPolicy": case "DeleteUserPolicy":
var iamErr *iamError
response, iamErr = e.DeleteUserPolicy(s3cfg, values) response, iamErr = e.DeleteUserPolicy(s3cfg, values)
if iamErr != nil { if iamErr != nil {
return nil, iamErr return nil, iamErr
} }
case "AttachUserPolicy": case "AttachUserPolicy":
var iamErr *iamError
response, iamErr = e.AttachUserPolicy(ctx, values) response, iamErr = e.AttachUserPolicy(ctx, values)
if iamErr != nil { if iamErr != nil {
return nil, iamErr return nil, iamErr
} }
changed = false changed = false
case "DetachUserPolicy": case "DetachUserPolicy":
var iamErr *iamError
response, iamErr = e.DetachUserPolicy(ctx, values) response, iamErr = e.DetachUserPolicy(ctx, values)
if iamErr != nil { if iamErr != nil {
return nil, iamErr return nil, iamErr
} }
changed = false changed = false
case "ListAttachedUserPolicies": case "ListAttachedUserPolicies":
var iamErr *iamError
response, iamErr = e.ListAttachedUserPolicies(ctx, values) response, iamErr = e.ListAttachedUserPolicies(ctx, values)
if iamErr != nil { if iamErr != nil {
return nil, iamErr return nil, iamErr
} }
changed = false changed = false
case "ListPolicies": case "ListPolicies":
var iamErr *iamError
response, iamErr = e.ListPolicies(ctx, values) response, iamErr = e.ListPolicies(ctx, values)
if iamErr != nil { if iamErr != nil {
return nil, iamErr return nil, iamErr
} }
changed = false changed = false
case "GetPolicy": case "GetPolicy":
var iamErr *iamError
response, iamErr = e.GetPolicy(ctx, values) response, iamErr = e.GetPolicy(ctx, values)
if iamErr != nil { if iamErr != nil {
return nil, iamErr return nil, iamErr
} }
changed = false changed = false
case "ListPolicyVersions": case "ListPolicyVersions":
var iamErr *iamError
response, iamErr = e.ListPolicyVersions(ctx, values) response, iamErr = e.ListPolicyVersions(ctx, values)
if iamErr != nil { if iamErr != nil {
return nil, iamErr return nil, iamErr
} }
changed = false changed = false
case "GetPolicyVersion": case "GetPolicyVersion":
var iamErr *iamError
response, iamErr = e.GetPolicyVersion(ctx, values) response, iamErr = e.GetPolicyVersion(ctx, values)
if iamErr != nil { if iamErr != nil {
return nil, iamErr return nil, iamErr
} }
changed = false changed = false
case "SetUserStatus": case "SetUserStatus":
var iamErr *iamError
response, iamErr = e.SetUserStatus(s3cfg, values) response, iamErr = e.SetUserStatus(s3cfg, values)
if iamErr != nil { if iamErr != nil {
return nil, iamErr return nil, iamErr
} }
case "UpdateAccessKey": case "UpdateAccessKey":
var iamErr *iamError
response, iamErr = e.UpdateAccessKey(s3cfg, values) response, iamErr = e.UpdateAccessKey(s3cfg, values)
if iamErr != nil { if iamErr != nil {
return nil, iamErr return nil, iamErr
@ -1693,11 +1718,13 @@ func (e *EmbeddedIamApi) ExecuteAction(ctx context.Context, values url.Values, s
// Service Account actions // Service Account actions
case "CreateServiceAccount": case "CreateServiceAccount":
createdBy := values.Get("CreatedBy") createdBy := values.Get("CreatedBy")
var iamErr *iamError
response, iamErr = e.CreateServiceAccount(s3cfg, values, createdBy) response, iamErr = e.CreateServiceAccount(s3cfg, values, createdBy)
if iamErr != nil { if iamErr != nil {
return nil, iamErr return nil, iamErr
} }
case "DeleteServiceAccount": case "DeleteServiceAccount":
var iamErr *iamError
response, iamErr = e.DeleteServiceAccount(s3cfg, values) response, iamErr = e.DeleteServiceAccount(s3cfg, values)
if iamErr != nil { if iamErr != nil {
return nil, iamErr return nil, iamErr
@ -1706,12 +1733,14 @@ func (e *EmbeddedIamApi) ExecuteAction(ctx context.Context, values url.Values, s
response = e.ListServiceAccounts(s3cfg, values) response = e.ListServiceAccounts(s3cfg, values)
changed = false changed = false
case "GetServiceAccount": case "GetServiceAccount":
var iamErr *iamError
response, iamErr = e.GetServiceAccount(s3cfg, values) response, iamErr = e.GetServiceAccount(s3cfg, values)
if iamErr != nil { if iamErr != nil {
return nil, iamErr return nil, iamErr
} }
changed = false changed = false
case "UpdateServiceAccount": case "UpdateServiceAccount":
var iamErr *iamError
response, iamErr = e.UpdateServiceAccount(s3cfg, values) response, iamErr = e.UpdateServiceAccount(s3cfg, values)
if iamErr != nil { if iamErr != nil {
return nil, iamErr return nil, iamErr
@ -1722,8 +1751,7 @@ func (e *EmbeddedIamApi) ExecuteAction(ctx context.Context, values url.Values, s
if changed { if changed {
if !skipPersist { if !skipPersist {
if err := e.PutS3ApiConfiguration(s3cfg); err != nil { if err := e.PutS3ApiConfiguration(s3cfg); err != nil {
iamErr = &iamError{Code: iam.ErrCodeServiceFailureException, Error: err}
return nil, iamErr
return nil, &iamError{Code: iam.ErrCodeServiceFailureException, Error: err}
} }
} }
// Reload in-memory identity maps so subsequent LookupByAccessKey calls // Reload in-memory identity maps so subsequent LookupByAccessKey calls
@ -1739,11 +1767,13 @@ func (e *EmbeddedIamApi) ExecuteAction(ctx context.Context, values url.Values, s
glog.Errorf("Failed to reload IAM configuration after managed policy mutation: %v", err) glog.Errorf("Failed to reload IAM configuration after managed policy mutation: %v", err)
} }
} }
return response, iamErr
response.SetRequestId(reqID)
return response, nil
} }
// DoActions handles IAM API actions. // DoActions handles IAM API actions.
func (e *EmbeddedIamApi) DoActions(w http.ResponseWriter, r *http.Request) { func (e *EmbeddedIamApi) DoActions(w http.ResponseWriter, r *http.Request) {
r, reqID := request_id.Ensure(r)
if err := r.ParseForm(); err != nil { if err := r.ParseForm(); err != nil {
s3err.WriteErrorResponse(w, r, s3err.ErrInvalidRequest) s3err.WriteErrorResponse(w, r, s3err.ErrInvalidRequest)
return return
@ -1759,15 +1789,11 @@ func (e *EmbeddedIamApi) DoActions(w http.ResponseWriter, r *http.Request) {
values.Set("CreatedBy", createdBy) values.Set("CreatedBy", createdBy)
} }
response, iamErr := e.ExecuteAction(r.Context(), values, false)
response, iamErr := e.ExecuteAction(r.Context(), values, false, reqID)
if iamErr != nil { if iamErr != nil {
e.writeIamErrorResponse(w, r, iamErr)
e.writeIamErrorResponse(w, r, reqID, iamErr)
return return
} }
// Set RequestId for AWS compatibility
if r, ok := response.(interface{ SetRequestId() }); ok {
r.SetRequestId()
}
s3err.WriteXMLResponse(w, r, http.StatusOK, response) s3err.WriteXMLResponse(w, r, http.StatusOK, response)
} }

18
weed/s3api/s3api_embedded_iam_test.go

@ -8,6 +8,7 @@ import (
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"net/url" "net/url"
"regexp"
"strings" "strings"
"sync" "sync"
"testing" "testing"
@ -23,6 +24,7 @@ import (
"github.com/seaweedfs/seaweedfs/weed/s3api/policy_engine" "github.com/seaweedfs/seaweedfs/weed/s3api/policy_engine"
. "github.com/seaweedfs/seaweedfs/weed/s3api/s3_constants" . "github.com/seaweedfs/seaweedfs/weed/s3api/s3_constants"
"github.com/seaweedfs/seaweedfs/weed/s3api/s3err" "github.com/seaweedfs/seaweedfs/weed/s3api/s3err"
"github.com/seaweedfs/seaweedfs/weed/util/request_id"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"google.golang.org/protobuf/proto" "google.golang.org/protobuf/proto"
@ -156,6 +158,15 @@ func extractEmbeddedIamErrorCodeAndMessage(response *httptest.ResponseRecorder)
return "", "" return "", ""
} }
func extractEmbeddedIamRequestID(response *httptest.ResponseRecorder) string {
re := regexp.MustCompile(`<RequestId>([^<]+)</RequestId>`)
matches := re.FindStringSubmatch(response.Body.String())
if len(matches) < 2 {
return ""
}
return matches[1]
}
// TestEmbeddedIamCreateUser tests creating a user via the embedded IAM API // TestEmbeddedIamCreateUser tests creating a user via the embedded IAM API
func TestEmbeddedIamCreateUser(t *testing.T) { func TestEmbeddedIamCreateUser(t *testing.T) {
api := NewEmbeddedIamApiForTest() api := NewEmbeddedIamApiForTest()
@ -199,6 +210,8 @@ func TestEmbeddedIamListUsers(t *testing.T) {
// Verify response contains the users // Verify response contains the users
assert.Len(t, out.ListUsersResult.Users, 2) assert.Len(t, out.ListUsersResult.Users, 2)
assert.NotEmpty(t, response.Header().Get(request_id.AmzRequestIDHeader))
assert.Equal(t, response.Header().Get(request_id.AmzRequestIDHeader), out.ResponseMetadata.RequestId)
} }
// TestEmbeddedIamListAccessKeys tests listing access keys via the embedded IAM API // TestEmbeddedIamListAccessKeys tests listing access keys via the embedded IAM API
@ -1216,6 +1229,7 @@ func TestEmbeddedIamNotImplementedAction(t *testing.T) {
assert.Equal(t, http.StatusNotImplemented, rr.Code) assert.Equal(t, http.StatusNotImplemented, rr.Code)
assert.Contains(t, rr.Body.String(), "<RequestId>") assert.Contains(t, rr.Body.String(), "<RequestId>")
assert.NotContains(t, rr.Body.String(), "<ResponseMetadata>") assert.NotContains(t, rr.Body.String(), "<ResponseMetadata>")
assert.Equal(t, rr.Header().Get(request_id.AmzRequestIDHeader), extractEmbeddedIamRequestID(rr))
} }
// TestGetPolicyDocument tests parsing of policy documents // TestGetPolicyDocument tests parsing of policy documents
@ -1900,11 +1914,11 @@ func TestEmbeddedIamExecuteAction(t *testing.T) {
vals.Set("Action", "CreateUser") vals.Set("Action", "CreateUser")
vals.Set("UserName", "ExecuteActionUser") vals.Set("UserName", "ExecuteActionUser")
resp, iamErr := api.ExecuteAction(context.Background(), vals, false)
resp, iamErr := api.ExecuteAction(context.Background(), vals, false, "")
assert.Nil(t, iamErr) assert.Nil(t, iamErr)
// Verify response type // Verify response type
createResp, ok := resp.(iamCreateUserResponse)
createResp, ok := resp.(*iamCreateUserResponse)
assert.True(t, ok) assert.True(t, ok)
assert.Equal(t, "ExecuteActionUser", *createResp.CreateUserResult.User.UserName) assert.Equal(t, "ExecuteActionUser", *createResp.CreateUserResult.User.UserName)

2
weed/s3api/s3api_server.go

@ -34,6 +34,7 @@ import (
"github.com/seaweedfs/seaweedfs/weed/util/grace" "github.com/seaweedfs/seaweedfs/weed/util/grace"
util_http "github.com/seaweedfs/seaweedfs/weed/util/http" util_http "github.com/seaweedfs/seaweedfs/weed/util/http"
util_http_client "github.com/seaweedfs/seaweedfs/weed/util/http/client" util_http_client "github.com/seaweedfs/seaweedfs/weed/util/http/client"
"github.com/seaweedfs/seaweedfs/weed/util/request_id"
"github.com/seaweedfs/seaweedfs/weed/wdclient" "github.com/seaweedfs/seaweedfs/weed/wdclient"
) )
@ -540,6 +541,7 @@ func (s3a *S3ApiServer) UnifiedPostHandler(w http.ResponseWriter, r *http.Reques
func (s3a *S3ApiServer) registerRouter(router *mux.Router) { func (s3a *S3ApiServer) registerRouter(router *mux.Router) {
// API Router // API Router
apiRouter := router.PathPrefix("/").Subrouter() apiRouter := router.PathPrefix("/").Subrouter()
apiRouter.Use(request_id.Middleware)
// S3 Tables API endpoint // S3 Tables API endpoint
// POST / with X-Amz-Target: S3Tables.<OperationName> // POST / with X-Amz-Target: S3Tables.<OperationName>

10
weed/s3api/s3api_sts.go

@ -19,6 +19,7 @@ import (
"github.com/seaweedfs/seaweedfs/weed/iam/sts" "github.com/seaweedfs/seaweedfs/weed/iam/sts"
"github.com/seaweedfs/seaweedfs/weed/iam/utils" "github.com/seaweedfs/seaweedfs/weed/iam/utils"
"github.com/seaweedfs/seaweedfs/weed/s3api/s3err" "github.com/seaweedfs/seaweedfs/weed/s3api/s3err"
"github.com/seaweedfs/seaweedfs/weed/util/request_id"
) )
// STS API constants matching AWS STS specification // STS API constants matching AWS STS specification
@ -97,6 +98,7 @@ func (h *STSHandlers) getAccountID() string {
// HandleSTSRequest is the main entry point for STS requests // HandleSTSRequest is the main entry point for STS requests
// It routes requests based on the Action parameter // It routes requests based on the Action parameter
func (h *STSHandlers) HandleSTSRequest(w http.ResponseWriter, r *http.Request) { func (h *STSHandlers) HandleSTSRequest(w http.ResponseWriter, r *http.Request) {
r, _ = request_id.Ensure(r)
if err := r.ParseForm(); err != nil { if err := r.ParseForm(); err != nil {
h.writeSTSErrorResponse(w, r, STSErrInvalidParameterValue, err) h.writeSTSErrorResponse(w, r, STSErrInvalidParameterValue, err)
return return
@ -224,7 +226,7 @@ func (h *STSHandlers) handleAssumeRoleWithWebIdentity(w http.ResponseWriter, r *
SubjectFromWebIdentityToken: response.AssumedRoleUser.Subject, SubjectFromWebIdentityToken: response.AssumedRoleUser.Subject,
}, },
} }
xmlResponse.ResponseMetadata.RequestId = fmt.Sprintf("%d", time.Now().UnixNano())
xmlResponse.ResponseMetadata.RequestId = request_id.GetFromRequest(r)
s3err.WriteXMLResponse(w, r, http.StatusOK, xmlResponse) s3err.WriteXMLResponse(w, r, http.StatusOK, xmlResponse)
} }
@ -354,7 +356,7 @@ func (h *STSHandlers) handleAssumeRole(w http.ResponseWriter, r *http.Request) {
AssumedRoleUser: assumedUser, AssumedRoleUser: assumedUser,
}, },
} }
xmlResponse.ResponseMetadata.RequestId = fmt.Sprintf("%d", time.Now().UnixNano())
xmlResponse.ResponseMetadata.RequestId = request_id.GetFromRequest(r)
s3err.WriteXMLResponse(w, r, http.StatusOK, xmlResponse) s3err.WriteXMLResponse(w, r, http.StatusOK, xmlResponse)
} }
@ -495,7 +497,7 @@ func (h *STSHandlers) handleAssumeRoleWithLDAPIdentity(w http.ResponseWriter, r
AssumedRoleUser: assumedUser, AssumedRoleUser: assumedUser,
}, },
} }
xmlResponse.ResponseMetadata.RequestId = fmt.Sprintf("%d", time.Now().UnixNano())
xmlResponse.ResponseMetadata.RequestId = request_id.GetFromRequest(r)
s3err.WriteXMLResponse(w, r, http.StatusOK, xmlResponse) s3err.WriteXMLResponse(w, r, http.StatusOK, xmlResponse)
} }
@ -731,7 +733,7 @@ func (h *STSHandlers) writeSTSErrorResponse(w http.ResponseWriter, r *http.Reque
} }
response := STSErrorResponse{ response := STSErrorResponse{
RequestId: fmt.Sprintf("%d", time.Now().UnixNano()),
RequestId: request_id.GetFromRequest(r),
} }
// Server-side errors use "Receiver" type per AWS spec // Server-side errors use "Receiver" type per AWS spec

3
weed/s3api/s3err/audit_fluent.go

@ -10,6 +10,7 @@ import (
"github.com/fluent/fluent-logger-golang/fluent" "github.com/fluent/fluent-logger-golang/fluent"
"github.com/seaweedfs/seaweedfs/weed/glog" "github.com/seaweedfs/seaweedfs/weed/glog"
"github.com/seaweedfs/seaweedfs/weed/s3api/s3_constants" "github.com/seaweedfs/seaweedfs/weed/s3api/s3_constants"
"github.com/seaweedfs/seaweedfs/weed/util/request_id"
) )
type AccessLogExtend struct { type AccessLogExtend struct {
@ -150,7 +151,7 @@ func GetAccessLog(r *http.Request, HTTPStatusCode int, s3errCode ErrorCode) *Acc
} }
return &AccessLog{ return &AccessLog{
HostHeader: hostHeader, HostHeader: hostHeader,
RequestID: r.Header.Get("X-Request-ID"),
RequestID: request_id.GetFromRequest(r),
RemoteIP: remoteIP, RemoteIP: remoteIP,
Requester: s3_constants.GetIdentityNameFromContext(r), // Get from context, not header (secure) Requester: s3_constants.GetIdentityNameFromContext(r), // Get from context, not header (secure)
SignatureVersion: r.Header.Get(s3_constants.AmzAuthType), SignatureVersion: r.Header.Get(s3_constants.AmzAuthType),

19
weed/s3api/s3err/audit_fluent_test.go

@ -0,0 +1,19 @@
package s3err
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/seaweedfs/seaweedfs/weed/util/request_id"
"github.com/stretchr/testify/assert"
)
func TestGetAccessLogUsesAmzRequestID(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "/bucket/object", nil)
req = req.WithContext(request_id.Set(req.Context(), "req-123"))
log := GetAccessLog(req, http.StatusOK, ErrNone)
assert.Equal(t, "req-123", log.RequestID)
}

13
weed/s3api/s3err/error_handler.go

@ -3,15 +3,14 @@ package s3err
import ( import (
"bytes" "bytes"
"encoding/xml" "encoding/xml"
"fmt"
"net/http" "net/http"
"strconv" "strconv"
"strings" "strings"
"time"
"github.com/aws/aws-sdk-go/private/protocol/xml/xmlutil" "github.com/aws/aws-sdk-go/private/protocol/xml/xmlutil"
"github.com/gorilla/mux" "github.com/gorilla/mux"
"github.com/seaweedfs/seaweedfs/weed/glog" "github.com/seaweedfs/seaweedfs/weed/glog"
"github.com/seaweedfs/seaweedfs/weed/util/request_id"
) )
type mimeType string type mimeType string
@ -41,6 +40,7 @@ func WriteEmptyResponse(w http.ResponseWriter, r *http.Request, statusCode int)
} }
func WriteErrorResponse(w http.ResponseWriter, r *http.Request, errorCode ErrorCode) { func WriteErrorResponse(w http.ResponseWriter, r *http.Request, errorCode ErrorCode) {
r, reqID := request_id.Ensure(r)
vars := mux.Vars(r) vars := mux.Vars(r)
bucket := vars["bucket"] bucket := vars["bucket"]
object := vars["object"] object := vars["object"]
@ -49,19 +49,19 @@ func WriteErrorResponse(w http.ResponseWriter, r *http.Request, errorCode ErrorC
} }
apiError := GetAPIError(errorCode) apiError := GetAPIError(errorCode)
errorResponse := getRESTErrorResponse(apiError, r.URL.Path, bucket, object)
errorResponse := getRESTErrorResponse(apiError, r.URL.Path, bucket, object, reqID)
WriteXMLResponse(w, r, apiError.HTTPStatusCode, errorResponse) WriteXMLResponse(w, r, apiError.HTTPStatusCode, errorResponse)
PostLog(r, apiError.HTTPStatusCode, errorCode) PostLog(r, apiError.HTTPStatusCode, errorCode)
} }
func getRESTErrorResponse(err APIError, resource string, bucket, object string) RESTErrorResponse {
func getRESTErrorResponse(err APIError, resource string, bucket, object, requestID string) RESTErrorResponse {
return RESTErrorResponse{ return RESTErrorResponse{
Code: err.Code, Code: err.Code,
BucketName: bucket, BucketName: bucket,
Key: object, Key: object,
Message: err.Description, Message: err.Description,
Resource: resource, Resource: resource,
RequestID: fmt.Sprintf("%d", time.Now().UnixNano()),
RequestID: requestID,
} }
} }
@ -75,7 +75,8 @@ func EncodeXMLResponse(response interface{}) []byte {
} }
func setCommonHeaders(w http.ResponseWriter, r *http.Request) { func setCommonHeaders(w http.ResponseWriter, r *http.Request) {
w.Header().Set("x-amz-request-id", fmt.Sprintf("%d", time.Now().UnixNano()))
_, reqID := request_id.Ensure(r)
w.Header().Set(request_id.AmzRequestIDHeader, reqID)
w.Header().Set("Accept-Ranges", "bytes") w.Header().Set("Accept-Ranges", "bytes")
// Handle CORS headers for requests with Origin header // Handle CORS headers for requests with Origin header

36
weed/s3api/s3err/error_handler_test.go

@ -0,0 +1,36 @@
package s3err
import (
"net/http"
"net/http/httptest"
"regexp"
"testing"
"github.com/gorilla/mux"
"github.com/seaweedfs/seaweedfs/weed/util/request_id"
"github.com/stretchr/testify/assert"
)
func TestWriteErrorResponseReusesRequestID(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "/bucket/object", nil)
req = mux.SetURLVars(req, map[string]string{
"bucket": "bucket",
"object": "object",
})
req = req.WithContext(request_id.Set(req.Context(), "req-123"))
rr := httptest.NewRecorder()
WriteErrorResponse(rr, req, ErrNoSuchKey)
assert.Equal(t, "req-123", rr.Header().Get(request_id.AmzRequestIDHeader))
assert.Equal(t, "req-123", extractRequestIDFromBody(rr.Body.String()))
}
func extractRequestIDFromBody(body string) string {
re := regexp.MustCompile(`<RequestId>([^<]+)</RequestId>`)
matches := re.FindStringSubmatch(body)
if len(matches) < 2 {
return ""
}
return matches[1]
}

14
weed/s3api/sts_params_test.go

@ -5,6 +5,7 @@ import (
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"net/url" "net/url"
"regexp"
"strings" "strings"
"sync" "sync"
"testing" "testing"
@ -14,6 +15,7 @@ import (
"github.com/seaweedfs/seaweedfs/weed/pb" "github.com/seaweedfs/seaweedfs/weed/pb"
"github.com/seaweedfs/seaweedfs/weed/pb/iam_pb" "github.com/seaweedfs/seaweedfs/weed/pb/iam_pb"
"github.com/seaweedfs/seaweedfs/weed/s3api/s3err" "github.com/seaweedfs/seaweedfs/weed/s3api/s3err"
"github.com/seaweedfs/seaweedfs/weed/util/request_id"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
@ -114,6 +116,7 @@ func TestSTSAssumeRolePostBody(t *testing.T) {
assert.NotEqual(t, http.StatusNotImplemented, rr.Code, "Should not return 501 (IAM handler)") assert.NotEqual(t, http.StatusNotImplemented, rr.Code, "Should not return 501 (IAM handler)")
assert.Equal(t, http.StatusBadRequest, rr.Code, "Should return 400 (STS handler) for missing params") assert.Equal(t, http.StatusBadRequest, rr.Code, "Should return 400 (STS handler) for missing params")
assert.Equal(t, rr.Header().Get(request_id.AmzRequestIDHeader), extractSTSRequestID(rr.Body.String()))
}) })
// Test Case 2: STS Action in Body (Should FAIL current implementation - routed to IAM) // Test Case 2: STS Action in Body (Should FAIL current implementation - routed to IAM)
@ -155,6 +158,7 @@ func TestSTSAssumeRolePostBody(t *testing.T) {
} }
// Confirm it routed to STS // Confirm it routed to STS
assert.Equal(t, http.StatusServiceUnavailable, rr.Code, "Fixed behavior: Should return 503 from STS handler (service not ready)") assert.Equal(t, http.StatusServiceUnavailable, rr.Code, "Fixed behavior: Should return 503 from STS handler (service not ready)")
assert.Equal(t, rr.Header().Get(request_id.AmzRequestIDHeader), extractSTSRequestID(rr.Body.String()))
}) })
// Test Case 3: STS Action in Body with SigV4-style Authorization (Real-world scenario) // Test Case 3: STS Action in Body with SigV4-style Authorization (Real-world scenario)
@ -199,5 +203,15 @@ func TestSTSAssumeRolePostBody(t *testing.T) {
assert.NotEqual(t, http.StatusNotImplemented, rr.Code, "Should not return 501 (IAM handler)") assert.NotEqual(t, http.StatusNotImplemented, rr.Code, "Should not return 501 (IAM handler)")
assert.Contains(t, []int{http.StatusServiceUnavailable, http.StatusForbidden}, rr.Code, assert.Contains(t, []int{http.StatusServiceUnavailable, http.StatusForbidden}, rr.Code,
"Should return 503 (STS unavailable) or 403 (auth failed), confirming STS routing") "Should return 503 (STS unavailable) or 403 (auth failed), confirming STS routing")
assert.Equal(t, rr.Header().Get(request_id.AmzRequestIDHeader), extractSTSRequestID(rr.Body.String()))
}) })
} }
func extractSTSRequestID(body string) string {
re := regexp.MustCompile(`<RequestId>([^<]+)</RequestId>`)
matches := re.FindStringSubmatch(body)
if len(matches) < 2 {
return ""
}
return matches[1]
}

22
weed/server/common.go

@ -3,7 +3,6 @@ package weed_server
import ( import (
"bufio" "bufio"
"bytes" "bytes"
"context"
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
@ -18,7 +17,6 @@ import (
"sync" "sync"
"time" "time"
"github.com/google/uuid"
"github.com/seaweedfs/seaweedfs/weed/s3api/s3_constants" "github.com/seaweedfs/seaweedfs/weed/s3api/s3_constants"
"github.com/seaweedfs/seaweedfs/weed/util/request_id" "github.com/seaweedfs/seaweedfs/weed/util/request_id"
"github.com/seaweedfs/seaweedfs/weed/util/version" "github.com/seaweedfs/seaweedfs/weed/util/version"
@ -432,18 +430,12 @@ func ProcessRangeRequest(r *http.Request, w http.ResponseWriter, totalSize int64
func requestIDMiddleware(h http.HandlerFunc) http.HandlerFunc { func requestIDMiddleware(h http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
reqID := r.Header.Get(request_id.AmzRequestIDHeader)
if reqID == "" {
reqID = uuid.New().String()
}
ctx := context.WithValue(r.Context(), request_id.AmzRequestIDHeader, reqID)
ctx = metadata.NewOutgoingContext(ctx,
metadata.New(map[string]string{
request_id.AmzRequestIDHeader: reqID,
}))
w.Header().Set(request_id.AmzRequestIDHeader, reqID)
h(w, r.WithContext(ctx))
request_id.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx := metadata.NewOutgoingContext(r.Context(),
metadata.New(map[string]string{
request_id.AmzRequestIDHeader: request_id.Get(r.Context()),
}))
h(w, r.WithContext(ctx))
})).ServeHTTP(w, r)
} }
} }

48
weed/util/request_id/request_id.go

@ -2,20 +2,25 @@ package request_id
import ( import (
"context" "context"
"crypto/rand"
"fmt"
"net/http" "net/http"
"time"
) )
const AmzRequestIDHeader = "x-amz-request-id" const AmzRequestIDHeader = "x-amz-request-id"
type contextKey struct{}
func Set(ctx context.Context, id string) context.Context { func Set(ctx context.Context, id string) context.Context {
return context.WithValue(ctx, AmzRequestIDHeader, id)
return context.WithValue(ctx, contextKey{}, id)
} }
func Get(ctx context.Context) string { func Get(ctx context.Context) string {
if ctx == nil { if ctx == nil {
return "" return ""
} }
id, _ := ctx.Value(AmzRequestIDHeader).(string)
id, _ := ctx.Value(contextKey{}).(string)
return id return id
} }
@ -24,3 +29,42 @@ func InjectToRequest(ctx context.Context, req *http.Request) {
req.Header.Set(AmzRequestIDHeader, Get(ctx)) req.Header.Set(AmzRequestIDHeader, Get(ctx))
} }
} }
func New() string {
var buf [4]byte
rand.Read(buf[:])
return fmt.Sprintf("%X%08X", time.Now().UTC().UnixNano(), buf)
}
// GetFromRequest returns the server-generated request ID from the context.
func GetFromRequest(r *http.Request) string {
if r == nil {
return ""
}
return Get(r.Context())
}
// Ensure guarantees a server-generated request ID exists in the context.
// It always generates a new ID if one is not already present in the context,
// ignoring any client-sent x-amz-request-id header to prevent spoofing.
func Ensure(r *http.Request) (*http.Request, string) {
if r == nil {
return nil, ""
}
if id := Get(r.Context()); id != "" {
return r, id
}
id := New()
r = r.WithContext(Set(r.Context(), id))
return r, id
}
func Middleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
r, reqID := Ensure(r)
if w.Header().Get(AmzRequestIDHeader) == "" {
w.Header().Set(AmzRequestIDHeader, reqID)
}
next.ServeHTTP(w, r)
})
}

50
weed/util/request_id/request_id_test.go

@ -0,0 +1,50 @@
package request_id
import (
"net/http/httptest"
"regexp"
"testing"
)
var requestIDPattern = regexp.MustCompile(`^[0-9A-F]+$`)
func TestNewUsesUppercaseHexFormat(t *testing.T) {
id := New()
if !requestIDPattern.MatchString(id) {
t.Fatalf("expected uppercase hex request id, got %q", id)
}
if len(id) < 24 {
t.Fatalf("expected request id to be at least 24 characters, got %q (len=%d)", id, len(id))
}
}
func TestNewIsUnique(t *testing.T) {
a := New()
b := New()
if a == b {
t.Fatalf("expected unique request ids, got %q twice", a)
}
}
func TestEnsureIgnoresClientHeader(t *testing.T) {
req := httptest.NewRequest("GET", "/", nil)
req.Header.Set(AmzRequestIDHeader, "spoofed-id")
req, id := Ensure(req)
if id == "spoofed-id" {
t.Fatal("Ensure should not trust client-sent x-amz-request-id header")
}
if !requestIDPattern.MatchString(id) {
t.Fatalf("expected server-generated hex id, got %q", id)
}
}
func TestEnsureReusesContextID(t *testing.T) {
req := httptest.NewRequest("GET", "/", nil)
req = req.WithContext(Set(req.Context(), "ctx-id-123"))
req, id := Ensure(req)
if id != "ctx-id-123" {
t.Fatalf("expected context id ctx-id-123, got %q", id)
}
}
Loading…
Cancel
Save