From e48764be75d285fe7edfbd88ca7ff419580c48d1 Mon Sep 17 00:00:00 2001
From: chrislu
Date: Wed, 23 Mar 2022 01:05:14 -0700
Subject: [PATCH] s3: multipart upload verifies uploaded parts

---
 weed/s3api/filer_multipart.go                 |  39 ++++++-
 weed/s3api/filer_multipart_test.go            | 110 ++++++++++++++++++
 weed/s3api/s3api_object_multipart_handlers.go |  43 ++++++-
 3 files changed, 183 insertions(+), 9 deletions(-)

diff --git a/weed/s3api/filer_multipart.go b/weed/s3api/filer_multipart.go
index 1514e2aa8..5a039382b 100644
--- a/weed/s3api/filer_multipart.go
+++ b/weed/s3api/filer_multipart.go
@@ -5,6 +5,7 @@ import (
 	"fmt"
 	"github.com/chrislusf/seaweedfs/weed/s3api/s3err"
 	"path/filepath"
+	"sort"
 	"strconv"
 	"strings"
 	"time"
@@ -62,10 +63,15 @@ type CompleteMultipartUploadResult struct {
 	s3.CompleteMultipartUploadOutput
 }
 
-func (s3a *S3ApiServer) completeMultipartUpload(input *s3.CompleteMultipartUploadInput) (output *CompleteMultipartUploadResult, code s3err.ErrorCode) {
+func (s3a *S3ApiServer) completeMultipartUpload(input *s3.CompleteMultipartUploadInput, parts *CompleteMultipartUpload) (output *CompleteMultipartUploadResult, code s3err.ErrorCode) {
 
 	glog.V(2).Infof("completeMultipartUpload input %v", input)
 
+	completedParts := parts.Parts
+	sort.Slice(completedParts, func(i, j int) bool {
+		return completedParts[i].PartNumber < completedParts[j].PartNumber
+	})
+
 	uploadDirectory := s3a.genUploadsFolder(*input.Bucket) + "/" + *input.UploadId
 
 	entries, _, err := s3a.list(uploadDirectory, "", "", false, maxPartsList)
@@ -80,14 +86,16 @@ func (s3a *S3ApiServer) completeMultipartUpload(input *s3.CompleteMultipartUploa
 		return nil, s3err.ErrNoSuchUpload
 	}
 
+	mime := pentry.Attributes.Mime
+
 	var finalParts []*filer_pb.FileChunk
 	var offset int64
-	var mime string
 
 	for _, entry := range entries {
 		if strings.HasSuffix(entry.Name, ".part") && !entry.IsDirectory {
-			if entry.Name == "0001.part" && entry.Attributes.Mime != "" {
-				mime = entry.Attributes.Mime
+			_, found := findByPartNumber(entry.Name, completedParts)
+			if !found {
+				continue
 			}
 			for _, chunk := range entry.Chunks {
 				p := &filer_pb.FileChunk{
@@ -156,6 +164,29 @@ func (s3a *S3ApiServer) completeMultipartUpload(input *s3.CompleteMultipartUploa
 	return
 }
 
+// findByPartNumber looks up the completed part that corresponds to an uploaded
+// part file. fileName is expected to start with a 4-digit part number (for
+// example "0001.part"), and parts must be sorted by ascending PartNumber.
+// It returns the matching part's ETag and true, or "" and false when the
+// file name is malformed or no part with that number was listed.
+func findByPartNumber(fileName string, parts []CompletedPart) (etag string, found bool) {
+	if len(fileName) < 4 {
+		return "", false
+	}
+	partNumber, formatErr := strconv.Atoi(fileName[:4])
+	if formatErr != nil {
+		return "", false
+	}
+	x := sort.Search(len(parts), func(i int) bool {
+		return parts[i].PartNumber >= partNumber
+	})
+	// sort.Search returns len(parts) when every listed part number is smaller.
+	if x == len(parts) || parts[x].PartNumber != partNumber {
+		return "", false
+	}
+	return parts[x].ETag, true
+}
+
 func (s3a *S3ApiServer) abortMultipartUpload(input *s3.AbortMultipartUploadInput) (output *s3.AbortMultipartUploadOutput, code s3err.ErrorCode) {
 
 	glog.V(2).Infof("abortMultipartUpload input %v", input)
diff --git a/weed/s3api/filer_multipart_test.go b/weed/s3api/filer_multipart_test.go
index 9e1d2307b..52425b5b2 100644
--- a/weed/s3api/filer_multipart_test.go
+++ b/weed/s3api/filer_multipart_test.go
@@ -4,6 +4,7 @@ import (
 	"github.com/aws/aws-sdk-go/aws"
 	"github.com/aws/aws-sdk-go/service/s3"
 	"github.com/chrislusf/seaweedfs/weed/s3api/s3err"
+	"github.com/stretchr/testify/assert"
 	"testing"
 	"time"
 )
@@ -48,3 +49,112 @@ func TestListPartsResult(t *testing.T) {
 	}
 
 }
+
+func Test_findByPartNumber(t *testing.T) {
+	type args struct {
+		fileName string
+		parts    []CompletedPart
+	}
+
+	parts := []CompletedPart{
+		CompletedPart{
+			ETag:       "xxx",
+			PartNumber: 1,
+		},
+		CompletedPart{
+			ETag:       "yyy",
+			PartNumber: 3,
+		},
+		CompletedPart{
+			ETag:       "zzz",
+			PartNumber: 5,
+		},
+	}
+
+	tests := []struct {
+		name      string
+		args      args
+		wantEtag  string
+		wantFound bool
+	}{
+		{
+			"first",
+			args{
+				"0001.part",
+				parts,
+			},
+			"xxx",
+			true,
+		},
+		{
+			"second",
+			args{
+				"0002.part",
+				parts,
+			},
+			"",
+			false,
+		},
+		{
+			"third",
+			args{
+				"0003.part",
+				parts,
+			},
+			"yyy",
+			true,
+		},
+		{
+			"fourth",
+			args{
+				"0004.part",
+				parts,
+			},
+			"",
+			false,
+		},
+		{
+			"fifth",
+			args{
+				"0005.part",
+				parts,
+			},
+			"zzz",
+			true,
+		},
+		{
+			"beyond last",
+			args{
+				"0006.part",
+				parts,
+			},
+			"",
+			false,
+		},
+		{
+			"no parts",
+			args{
+				"0001.part",
+				nil,
+			},
+			"",
+			false,
+		},
+		{
+			"short name",
+			args{
+				"1",
+				parts,
+			},
+			"",
+			false,
+		},
+	}
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			gotEtag, gotFound := findByPartNumber(tt.args.fileName, tt.args.parts)
+			assert.Equalf(t, tt.wantEtag, gotEtag, "findByPartNumber(%v, %v)", tt.args.fileName, tt.args.parts)
+			assert.Equalf(t, tt.wantFound, gotFound, "findByPartNumber(%v, %v)", tt.args.fileName, tt.args.parts)
+		})
+	}
+}
diff --git a/weed/s3api/s3api_object_multipart_handlers.go b/weed/s3api/s3api_object_multipart_handlers.go
index 99c280e13..35bc174c8 100644
--- a/weed/s3api/s3api_object_multipart_handlers.go
+++ b/weed/s3api/s3api_object_multipart_handlers.go
@@ -1,11 +1,13 @@
 package s3api
 
 import (
+	"encoding/xml"
 	"fmt"
 	"github.com/chrislusf/seaweedfs/weed/glog"
 	xhttp "github.com/chrislusf/seaweedfs/weed/s3api/http"
 	"github.com/chrislusf/seaweedfs/weed/s3api/s3err"
 	weed_server "github.com/chrislusf/seaweedfs/weed/server"
+	"io"
 	"net/http"
 	"net/url"
 	"strconv"
@@ -56,8 +58,16 @@ func (s3a *S3ApiServer) NewMultipartUploadHandler(w http.ResponseWriter, r *http
 
 // CompleteMultipartUploadHandler - Completes multipart upload.
 func (s3a *S3ApiServer) CompleteMultipartUploadHandler(w http.ResponseWriter, r *http.Request) {
+	// https://docs.aws.amazon.com/AmazonS3/latest/API/API_CompleteMultipartUpload.html
+
 	bucket, object := xhttp.GetBucketAndObject(r)
 
+	parts := &CompleteMultipartUpload{}
+	if err := xmlDecoder(r.Body, parts, r.ContentLength); err != nil {
+		s3err.WriteErrorResponse(w, r, s3err.ErrMalformedXML)
+		return
+	}
+
 	// Get upload id.
 	uploadID, _, _, _ := getObjectResources(r.URL.Query())
 
@@ -65,7 +75,7 @@ func (s3a *S3ApiServer) CompleteMultipartUploadHandler(w http.ResponseWriter, r
 		Bucket:   aws.String(bucket),
 		Key:      objectKey(aws.String(object)),
 		UploadId: aws.String(uploadID),
-	})
+	}, parts)
 
 	glog.V(2).Info("CompleteMultipartUploadHandler", string(s3err.EncodeXMLResponse(response)), errCode)
 
@@ -268,8 +278,31 @@ func getObjectResources(values url.Values) (uploadID string, partNumberMarker, m
 	return
 }
 
-type byCompletedPartNumber []*s3.CompletedPart
+// xmlDecoder decodes the XML document read from body into v. When size is
+// positive, at most size bytes are read from body. The CharsetReader passes
+// input through unchanged, so the declared encoding is not converted.
+func xmlDecoder(body io.Reader, v interface{}, size int64) error {
+	var lbody io.Reader
+	if size > 0 {
+		lbody = io.LimitReader(body, size)
+	} else {
+		lbody = body
+	}
+	d := xml.NewDecoder(lbody)
+	d.CharsetReader = func(label string, input io.Reader) (io.Reader, error) {
+		return input, nil
+	}
+	return d.Decode(v)
+}
 
-func (a byCompletedPartNumber) Len() int           { return len(a) }
-func (a byCompletedPartNumber) Swap(i, j int)      { a[i], a[j] = a[j], a[i] }
-func (a byCompletedPartNumber) Less(i, j int) bool { return *a[i].PartNumber < *a[j].PartNumber }
+// CompleteMultipartUpload is the decoded body of a CompleteMultipartUpload
+// request: the list of parts the client asks to assemble.
+type CompleteMultipartUpload struct {
+	Parts []CompletedPart `xml:"Part"`
+}
+
+// CompletedPart identifies one uploaded part by its number and ETag.
+type CompletedPart struct {
+	ETag       string
+	PartNumber int
+}