diff --git a/weed/s3api/s3api_object_retention.go b/weed/s3api/s3api_object_retention.go index fa7eb6856..fa06732ee 100644 --- a/weed/s3api/s3api_object_retention.go +++ b/weed/s3api/s3api_object_retention.go @@ -91,60 +91,48 @@ func (or *ObjectRetention) UnmarshalXML(d *xml.Decoder, start xml.StartElement) return nil } -// parseObjectRetention parses XML retention configuration from request body -func parseObjectRetention(r *http.Request) (*ObjectRetention, error) { +// parseXML is a generic helper function to parse XML from request body +func parseXML[T any](r *http.Request, result *T) error { if r.Body == nil { - return nil, fmt.Errorf("empty request body") + return fmt.Errorf("empty request body") } body, err := io.ReadAll(r.Body) if err != nil { - return nil, fmt.Errorf("error reading request body: %v", err) + return fmt.Errorf("error reading request body: %v", err) } - var retention ObjectRetention - if err := xml.Unmarshal(body, &retention); err != nil { - return nil, fmt.Errorf("error parsing XML: %v", err) + if err := xml.Unmarshal(body, result); err != nil { + return fmt.Errorf("error parsing XML: %v", err) } + return nil +} + +// parseObjectRetention parses XML retention configuration from request body +func parseObjectRetention(r *http.Request) (*ObjectRetention, error) { + var retention ObjectRetention + if err := parseXML(r, &retention); err != nil { + return nil, err + } return &retention, nil } // parseObjectLegalHold parses XML legal hold configuration from request body func parseObjectLegalHold(r *http.Request) (*ObjectLegalHold, error) { - if r.Body == nil { - return nil, fmt.Errorf("empty request body") - } - - body, err := io.ReadAll(r.Body) - if err != nil { - return nil, fmt.Errorf("error reading request body: %v", err) - } - var legalHold ObjectLegalHold - if err := xml.Unmarshal(body, &legalHold); err != nil { - return nil, fmt.Errorf("error parsing XML: %v", err) + if err := parseXML(r, &legalHold); err != nil { + return nil, err } - return &legalHold, nil } // parseObjectLockConfiguration parses XML object lock configuration from request body func parseObjectLockConfiguration(r *http.Request) (*ObjectLockConfiguration, error) { - if r.Body == nil { - return nil, fmt.Errorf("empty request body") - } - - body, err := io.ReadAll(r.Body) - if err != nil { - return nil, fmt.Errorf("error reading request body: %v", err) - } - var config ObjectLockConfiguration - if err := xml.Unmarshal(body, &config); err != nil { - return nil, fmt.Errorf("error parsing XML: %v", err) + if err := parseXML(r, &config); err != nil { + return nil, err } - return &config, nil } @@ -527,31 +515,6 @@ func (s3a *S3ApiServer) checkLegacyWormEnforcement(bucket, object, versionId str return nil } -// integrateWithWormSystem ensures compatibility between S3 retention and legacy WORM -func (s3a *S3ApiServer) integrateWithWormSystem(entry *filer_pb.Entry, retention *ObjectRetention) { - if retention == nil || retention.RetainUntilDate == nil { - return - } - - // Set the legacy WORM timestamp for backward compatibility - if entry.WormEnforcedAtTsNs == 0 { - entry.WormEnforcedAtTsNs = time.Now().UnixNano() - } - - // Store additional S3 retention metadata in extended attributes - if entry.Extended == nil { - entry.Extended = make(map[string][]byte) - } - - if retention.Mode != "" { - entry.Extended[s3_constants.ExtRetentionModeKey] = []byte(retention.Mode) - } - - if retention.RetainUntilDate != nil { - entry.Extended[s3_constants.ExtRetentionUntilDateKey] = []byte(strconv.FormatInt(retention.RetainUntilDate.Unix(), 10)) - } -} - // isObjectWormProtected checks both S3 retention and legacy WORM for complete protection status func (s3a *S3ApiServer) isObjectWormProtected(bucket, object, versionId string) (bool, error) { // Check S3 object retention diff --git a/weed/s3api/s3api_object_retention_test.go b/weed/s3api/s3api_object_retention_test.go index ef194317f..ebe55fbfd 100644 --- a/weed/s3api/s3api_object_retention_test.go +++ b/weed/s3api/s3api_object_retention_test.go @@ -1,6 +1,9 @@ package s3api import ( + "io" + "net/http" + "strings" "testing" "time" @@ -167,3 +170,205 @@ func TestValidateLegalHold(t *testing.T) { }) } } + +func TestParseObjectLockConfiguration(t *testing.T) { + tests := []struct { + name string + xmlBody string + expectError bool + errorMsg string + expectedConfig *ObjectLockConfiguration + }{ + { + name: "Valid configuration with days", + xmlBody: ` + + Enabled + + + GOVERNANCE + 30 + + +`, + expectError: false, + expectedConfig: &ObjectLockConfiguration{ + ObjectLockEnabled: s3_constants.ObjectLockEnabled, + Rule: &ObjectLockRule{ + DefaultRetention: &DefaultRetention{ + Mode: s3_constants.RetentionModeGovernance, + Days: 30, + }, + }, + }, + }, + { + name: "Valid configuration with years", + xmlBody: ` + + Enabled + + + COMPLIANCE + 1 + + +`, + expectError: false, + expectedConfig: &ObjectLockConfiguration{ + ObjectLockEnabled: s3_constants.ObjectLockEnabled, + Rule: &ObjectLockRule{ + DefaultRetention: &DefaultRetention{ + Mode: s3_constants.RetentionModeCompliance, + Years: 1, + }, + }, + }, + }, + { + name: "Configuration with ObjectLockEnabled only", + xmlBody: ` + + Enabled +`, + expectError: false, + expectedConfig: &ObjectLockConfiguration{ + ObjectLockEnabled: s3_constants.ObjectLockEnabled, + }, + }, + { + name: "Empty body", + xmlBody: "", + expectError: true, + errorMsg: "empty request body", + }, + { + name: "Invalid XML", + xmlBody: "", + expectError: true, + errorMsg: "error parsing XML", + }, + { + name: "Malformed XML structure", + xmlBody: ` + + Invalid +`, + expectError: true, + errorMsg: "error parsing XML", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var req *http.Request + if tt.xmlBody == "" { + req = &http.Request{Body: nil} + } else { + req = &http.Request{ + Body: io.NopCloser(strings.NewReader(tt.xmlBody)), + } + } + + config, err := parseObjectLockConfiguration(req) + + if tt.expectError { + if err == nil { + t.Errorf("Expected error but got none") + } else if !strings.Contains(err.Error(), tt.errorMsg) { + t.Errorf("Expected error message to contain '%s', got '%s'", tt.errorMsg, err.Error()) + } + } else { + if err != nil { + t.Errorf("Expected no error but got: %v", err) + } + if config == nil { + t.Errorf("Expected config but got nil") + } + if tt.expectedConfig != nil && config != nil { + if config.ObjectLockEnabled != tt.expectedConfig.ObjectLockEnabled { + t.Errorf("Expected ObjectLockEnabled '%s', got '%s'", tt.expectedConfig.ObjectLockEnabled, config.ObjectLockEnabled) + } + if (config.Rule == nil) != (tt.expectedConfig.Rule == nil) { + t.Errorf("Rule presence mismatch") + } + if config.Rule != nil && tt.expectedConfig.Rule != nil { + if (config.Rule.DefaultRetention == nil) != (tt.expectedConfig.Rule.DefaultRetention == nil) { + t.Errorf("DefaultRetention presence mismatch") + } + if config.Rule.DefaultRetention != nil && tt.expectedConfig.Rule.DefaultRetention != nil { + if config.Rule.DefaultRetention.Mode != tt.expectedConfig.Rule.DefaultRetention.Mode { + t.Errorf("Expected Mode '%s', got '%s'", tt.expectedConfig.Rule.DefaultRetention.Mode, config.Rule.DefaultRetention.Mode) + } + if config.Rule.DefaultRetention.Days != tt.expectedConfig.Rule.DefaultRetention.Days { + t.Errorf("Expected Days %d, got %d", tt.expectedConfig.Rule.DefaultRetention.Days, config.Rule.DefaultRetention.Days) + } + if config.Rule.DefaultRetention.Years != tt.expectedConfig.Rule.DefaultRetention.Years { + t.Errorf("Expected Years %d, got %d", tt.expectedConfig.Rule.DefaultRetention.Years, config.Rule.DefaultRetention.Years) + } + } + } + } + } + }) + } +} + +func TestParseXMLGeneric(t *testing.T) { + tests := []struct { + name string + xmlBody string + expectError bool + errorMsg string + }{ + { + name: "Valid retention XML", + xmlBody: ` + + GOVERNANCE + 2024-12-31T23:59:59Z +`, + expectError: false, + }, + { + name: "Empty body", + xmlBody: "", + expectError: true, + errorMsg: "empty request body", + }, + { + name: "Invalid XML", + xmlBody: "", + expectError: true, + errorMsg: "error parsing XML", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var req *http.Request + if tt.xmlBody == "" { + req = &http.Request{Body: nil} + } else { + req = &http.Request{ + Body: io.NopCloser(strings.NewReader(tt.xmlBody)), + } + } + + var retention ObjectRetention + err := parseXML(req, &retention) + + if tt.expectError { + if err == nil { + t.Errorf("Expected error but got none") + } else if !strings.Contains(err.Error(), tt.errorMsg) { + t.Errorf("Expected error message to contain '%s', got '%s'", tt.errorMsg, err.Error()) + } + } else { + if err != nil { + t.Errorf("Expected no error but got: %v", err) + } + } + }) + } +}