diff --git a/weed/s3api/s3tables/handler.go b/weed/s3api/s3tables/handler.go index b69b6b662..1ead82309 100644 --- a/weed/s3api/s3tables/handler.go +++ b/weed/s3api/s3tables/handler.go @@ -260,20 +260,10 @@ func normalizePrincipalID(id string) string { // getIdentityActions extracts the action list from the identity object in the request context. // Uses reflection to avoid import cycles with s3api package. func getIdentityActions(r *http.Request) []string { - identityRaw := s3_constants.GetIdentityFromContext(r) - if identityRaw == nil { - return nil - } - - // Use reflection to access the Actions field to avoid import cycle - val := reflect.ValueOf(identityRaw) - if val.Kind() == reflect.Ptr { - val = val.Elem() - } - if val.Kind() != reflect.Struct { + val, ok := getIdentityStructValue(r) + if !ok { return nil } - actionsField := val.FieldByName("Actions") if !actionsField.IsValid() || actionsField.Kind() != reflect.Slice { return nil diff --git a/weed/s3api/s3tables/handler_bucket_create.go b/weed/s3api/s3tables/handler_bucket_create.go index b2791448a..75d44062e 100644 --- a/weed/s3api/s3tables/handler_bucket_create.go +++ b/weed/s3api/s3tables/handler_bucket_create.go @@ -29,8 +29,8 @@ func (h *S3TablesHandler) handleCreateTableBucket(w http.ResponseWriter, r *http principal := h.getAccountID(r) identityActions := getIdentityActions(r) identityPolicyNames := getIdentityPolicyNames(r) - if h.shouldUseIAM(r, identityActions, identityPolicyNames) && !h.defaultAllow { - ownerAccountID := h.getAccountID(r) + if h.shouldUseIAM(r, identityActions, identityPolicyNames) { + ownerAccountID := principal tableBucketARN := h.generateTableBucketARN(ownerAccountID, req.Name) s3BucketARN := fmt.Sprintf("arn:aws:s3:::%s", req.Name) allowed, err := h.authorizeIAMAction(r, identityPolicyNames, "s3tables:CreateTableBucket", tableBucketARN, s3BucketARN) diff --git a/weed/s3api/s3tables/iam.go b/weed/s3api/s3tables/iam.go index 9a0be69dd..214e51efe 100644 --- a/weed/s3api/s3tables/iam.go +++ b/weed/s3api/s3tables/iam.go @@ -36,13 +36,17 @@ func (h *S3TablesHandler) shouldUseIAM(r *http.Request, identityActions, identit } func hasSessionToken(r *http.Request) bool { - if r.Header.Get("X-SeaweedFS-Session-Token") != "" { - return true + return extractSessionToken(r) != "" +} + +func extractSessionToken(r *http.Request) string { + if token := r.Header.Get("X-SeaweedFS-Session-Token"); token != "" { + return token } - if r.Header.Get("X-Amz-Security-Token") != "" { - return true + if token := r.Header.Get("X-Amz-Security-Token"); token != "" { + return token } - return r.URL.Query().Get("X-Amz-Security-Token") != "" + return r.URL.Query().Get("X-Amz-Security-Token") } func (h *S3TablesHandler) authorizeIAMAction(r *http.Request, identityPolicyNames []string, action string, resources ...string) (bool, error) { @@ -61,13 +65,7 @@ func (h *S3TablesHandler) authorizeIAMAction(r *http.Request, identityPolicyName action = "s3tables:" + action } - sessionToken := r.Header.Get("X-SeaweedFS-Session-Token") - if sessionToken == "" { - sessionToken = r.Header.Get("X-Amz-Security-Token") - if sessionToken == "" { - sessionToken = r.URL.Query().Get("X-Amz-Security-Token") - } - } + sessionToken := extractSessionToken(r) requestContext := buildIAMRequestContext(r, getIdentityClaims(r)) policyNames := identityPolicyNames @@ -101,15 +99,8 @@ func (h *S3TablesHandler) authorizeIAMAction(r *http.Request, identityPolicyName } func getIdentityPrincipalArn(r *http.Request) string { - identityRaw := s3_constants.GetIdentityFromContext(r) - if identityRaw == nil { - return "" - } - val := reflect.ValueOf(identityRaw) - if val.Kind() == reflect.Ptr { - val = val.Elem() - } - if val.Kind() != reflect.Struct { + val, ok := getIdentityStructValue(r) + if !ok { return "" } field := val.FieldByName("PrincipalArn") @@ -120,15 +111,8 @@ func getIdentityPrincipalArn(r *http.Request) string { } func getIdentityPolicyNames(r *http.Request) []string { - identityRaw := s3_constants.GetIdentityFromContext(r) - if identityRaw == nil { - return nil - } - val := reflect.ValueOf(identityRaw) - if val.Kind() == reflect.Ptr { - val = val.Elem() - } - if val.Kind() != reflect.Struct { + val, ok := getIdentityStructValue(r) + if !ok { return nil } field := val.FieldByName("PolicyNames") @@ -151,15 +135,8 @@ func getIdentityPolicyNames(r *http.Request) []string { } func getIdentityClaims(r *http.Request) map[string]interface{} { - identityRaw := s3_constants.GetIdentityFromContext(r) - if identityRaw == nil { - return nil - } - val := reflect.ValueOf(identityRaw) - if val.Kind() == reflect.Ptr { - val = val.Elem() - } - if val.Kind() != reflect.Struct { + val, ok := getIdentityStructValue(r) + if !ok { return nil } field := val.FieldByName("Claims") @@ -194,9 +171,6 @@ func buildIAMRequestContext(r *http.Request, claims map[string]interface{}) map[ if referer := r.Header.Get("Referer"); referer != "" { ctx["referer"] = referer } - if requestTime := r.Context().Value("requestTime"); requestTime != nil { - ctx["requestTime"] = requestTime - } for k, v := range claims { if _, exists := ctx[k]; !exists { ctx[k] = v @@ -213,3 +187,18 @@ func buildIAMRequestContext(r *http.Request, claims map[string]interface{}) map[ } return ctx } + +func getIdentityStructValue(r *http.Request) (reflect.Value, bool) { + identityRaw := s3_constants.GetIdentityFromContext(r) + if identityRaw == nil { + return reflect.Value{}, false + } + val := reflect.ValueOf(identityRaw) + if val.Kind() == reflect.Ptr { + val = val.Elem() + } + if val.Kind() != reflect.Struct { + return reflect.Value{}, false + } + return val, true +}