From f17ec59d460ddac9a97f829c1ce613849a1c8abd Mon Sep 17 00:00:00 2001 From: Chris Lu Date: Wed, 28 Jan 2026 13:25:32 -0800 Subject: [PATCH] s3tables: implement optimistic concurrency for table deletion Added VersionToken validation to handleDeleteTable. Refactored table listing to use request context for accurate ARN generation and fixed cross-namespace pagination issues. --- weed/s3api/s3tables/handler_table.go | 234 ++++++++++++++------------- 1 file changed, 118 insertions(+), 116 deletions(-) diff --git a/weed/s3api/s3tables/handler_table.go b/weed/s3api/s3tables/handler_table.go index 5d8c5502a..605a932c8 100644 --- a/weed/s3api/s3tables/handler_table.go +++ b/weed/s3api/s3tables/handler_table.go @@ -1,7 +1,6 @@ package s3tables import ( - "context" "encoding/json" "errors" "fmt" @@ -18,7 +17,8 @@ import ( func (h *S3TablesHandler) handleCreateTable(w http.ResponseWriter, r *http.Request, filerClient FilerClient) error { // Check permission principal := h.getPrincipalFromRequest(r) - if !CanCreateTable(principal, h.accountID) { + accountID := h.getAccountID(r) + if !CanCreateTable(principal, accountID) { h.writeError(w, http.StatusForbidden, ErrCodeAccessDenied, "not authorized to create table") return NewAuthError("CreateTable", principal, "not authorized to create table") } @@ -111,7 +111,7 @@ func (h *S3TablesHandler) handleCreateTable(w http.ResponseWriter, r *http.Reque Format: req.Format, CreatedAt: now, ModifiedAt: now, - OwnerID: h.accountID, + OwnerID: h.getAccountID(r), VersionToken: versionToken, Schema: req.Metadata, } @@ -158,7 +158,7 @@ func (h *S3TablesHandler) handleCreateTable(w http.ResponseWriter, r *http.Reque return err } - tableARN := h.generateTableARN(bucketName, namespaceName+"/"+tableName) + tableARN := h.generateTableARN(r, bucketName, namespaceName+"/"+tableName) resp := &CreateTableResponse{ TableARN: tableARN, @@ -173,7 +173,8 @@ func (h *S3TablesHandler) handleCreateTable(w http.ResponseWriter, r *http.Reque func (h *S3TablesHandler) handleGetTable(w http.ResponseWriter, r *http.Request, filerClient FilerClient) error { // Check permission principal := h.getPrincipalFromRequest(r) - if !CanGetTable(principal, h.accountID) { + accountID := h.getAccountID(r) + if !CanGetTable(principal, accountID) { h.writeError(w, http.StatusForbidden, ErrCodeAccessDenied, "not authorized to get table") return NewAuthError("GetTable", principal, "not authorized to get table") } @@ -235,7 +236,7 @@ func (h *S3TablesHandler) handleGetTable(w http.ResponseWriter, r *http.Request, return err } - tableARN := h.generateTableARN(bucketName, namespace+"/"+tableName) + tableARN := h.generateTableARN(r, bucketName, namespace+"/"+tableName) resp := &GetTableResponse{ Name: metadata.Name, @@ -257,7 +258,8 @@ func (h *S3TablesHandler) handleGetTable(w http.ResponseWriter, r *http.Request, func (h *S3TablesHandler) handleListTables(w http.ResponseWriter, r *http.Request, filerClient FilerClient) error { // Check permission principal := h.getPrincipalFromRequest(r) - if !CanListTables(principal, h.accountID) { + accountID := h.getAccountID(r) + if !CanListTables(principal, accountID) { h.writeError(w, http.StatusForbidden, ErrCodeAccessDenied, "not authorized to list tables") return NewAuthError("ListTables", principal, "not authorized to list tables") } @@ -285,44 +287,27 @@ func (h *S3TablesHandler) handleListTables(w http.ResponseWriter, r *http.Reques } var tables []TableSummary + var paginationToken string - // If namespace is specified, list tables in that namespace only - if len(req.Namespace) > 0 { - namespaceName, err := validateNamespace(req.Namespace) - if err != nil { - h.writeError(w, http.StatusBadRequest, ErrCodeInvalidRequest, err.Error()) - return err + err = filerClient.WithFilerClient(false, func(client filer_pb.SeaweedFilerClient) error { + var err error + if len(req.Namespace) > 0 { + namespaceName, err := validateNamespace(req.Namespace) + if err != nil { + return err + } + tables, paginationToken, err = h.listTablesInNamespaceWithClient(r, client, bucketName, namespaceName, req.Prefix, req.ContinuationToken, maxTables) + } else { + tables, paginationToken, err = h.listTablesInAllNamespaces(r, client, bucketName, req.Prefix, req.ContinuationToken, maxTables) } - err = filerClient.WithFilerClient(false, func(client filer_pb.SeaweedFilerClient) error { - return h.listTablesInNamespaceWithClient(r.Context(), client, bucketName, namespaceName, req.Prefix, req.ContinuationToken, maxTables, &tables) - }) - } else { - // List tables in all namespaces - err = h.listTablesInAllNamespaces(r.Context(), filerClient, bucketName, req.Prefix, req.ContinuationToken, maxTables, &tables) - } + return err + }) if err != nil { h.writeError(w, http.StatusInternalServerError, ErrCodeInternalError, fmt.Sprintf("failed to list tables: %v", err)) return err } - paginationToken := "" - if len(tables) >= maxTables && len(tables) > 0 { - // This is tricky for cross-namespace listing. For now, we'll store the full path if possible, - // but standard S3 tables usually lists within a namespace. - // If we are listing within a namespace, lastFileName is just the table name. - // If we are listing all namespaces, we might need a more complex token. - // For simplicity, let's assume if we hit the limit, we return the last seen entry's path-related info. - if len(req.Namespace) > 0 { - paginationToken = tables[len(tables)-1].Name - } else { - // For all-namespaces listing, we'd need to encode the namespace too. - // Let's use namespace/name as token. - lastTable := tables[len(tables)-1] - paginationToken = lastTable.Namespace[0] + "/" + lastTable.Name - } - } - resp := &ListTablesResponse{ Tables: tables, ContinuationToken: paginationToken, @@ -332,19 +317,26 @@ func (h *S3TablesHandler) handleListTables(w http.ResponseWriter, r *http.Reques return nil } -func (h *S3TablesHandler) listTablesInNamespaceWithClient(ctx context.Context, client filer_pb.SeaweedFilerClient, bucketName, namespace, prefix, continuationToken string, maxTables int, tables *[]TableSummary) error { - namespacePath := getNamespacePath(bucketName, namespace) +// listTablesInNamespaceWithClient lists tables in a specific namespace +func (h *S3TablesHandler) listTablesInNamespaceWithClient(r *http.Request, client filer_pb.SeaweedFilerClient, bucketName, namespaceName, prefix, continuationToken string, maxTables int) ([]TableSummary, string, error) { + namespacePath := getNamespacePath(bucketName, namespaceName) + return h.listTablesWithClient(r, client, namespacePath, bucketName, namespaceName, prefix, continuationToken, maxTables) +} +func (h *S3TablesHandler) listTablesWithClient(r *http.Request, client filer_pb.SeaweedFilerClient, dirPath, bucketName, namespaceName, prefix, continuationToken string, maxTables int) ([]TableSummary, string, error) { + var tables []TableSummary lastFileName := continuationToken - for len(*tables) < maxTables { + ctx := r.Context() + + for len(tables) < maxTables { resp, err := client.ListEntries(ctx, &filer_pb.ListEntriesRequest{ - Directory: namespacePath, + Directory: dirPath, Limit: uint32(maxTables * 2), StartFromFileName: lastFileName, InclusiveStartFrom: lastFileName == "" || lastFileName == continuationToken, }) if err != nil { - return err + return nil, "", err } hasMore := false @@ -354,14 +346,14 @@ func (h *S3TablesHandler) listTablesInNamespaceWithClient(ctx context.Context, c if respErr == io.EOF { break } - return respErr + return nil, "", respErr } if entry.Entry == nil { continue } // Skip the start item if it was included in the previous page - if len(*tables) == 0 && continuationToken != "" && entry.Entry.Name == continuationToken { + if len(tables) == 0 && continuationToken != "" && entry.Entry.Name == continuationToken { continue } @@ -393,18 +385,18 @@ func (h *S3TablesHandler) listTablesInNamespaceWithClient(ctx context.Context, c continue } - tableARN := h.generateTableARN(bucketName, namespace+"/"+entry.Entry.Name) + tableARN := h.generateTableARN(r, bucketName, namespaceName+"/"+entry.Entry.Name) - *tables = append(*tables, TableSummary{ - Name: metadata.Name, + tables = append(tables, TableSummary{ + Name: entry.Entry.Name, TableARN: tableARN, - Namespace: []string{namespace}, + Namespace: []string{namespaceName}, CreatedAt: metadata.CreatedAt, ModifiedAt: metadata.ModifiedAt, }) - if len(*tables) >= maxTables { - return nil + if len(tables) >= maxTables { + return tables, lastFileName, nil } } @@ -413,11 +405,12 @@ func (h *S3TablesHandler) listTablesInNamespaceWithClient(ctx context.Context, c } } - return nil + return tables, lastFileName, nil } -func (h *S3TablesHandler) listTablesInAllNamespaces(ctx context.Context, filerClient FilerClient, bucketName, prefix, continuationToken string, maxTables int, tables *[]TableSummary) error { +func (h *S3TablesHandler) listTablesInAllNamespaces(r *http.Request, client filer_pb.SeaweedFilerClient, bucketName, prefix, continuationToken string, maxTables int) ([]TableSummary, string, error) { bucketPath := getTableBucketPath(bucketName) + ctx := r.Context() var continuationNamespace string var startTableName string @@ -430,88 +423,82 @@ func (h *S3TablesHandler) listTablesInAllNamespaces(ctx context.Context, filerCl } } + var tables []TableSummary lastNamespace := continuationNamespace - return filerClient.WithFilerClient(false, func(client filer_pb.SeaweedFilerClient) error { - for { - // List namespaces in batches - resp, err := client.ListEntries(ctx, &filer_pb.ListEntriesRequest{ - Directory: bucketPath, - Limit: 100, - StartFromFileName: lastNamespace, - InclusiveStartFrom: lastNamespace == continuationNamespace && continuationNamespace != "" || lastNamespace == "", - }) - if err != nil { - return err - } - - hasMore := false - for { - entry, respErr := resp.Recv() - if respErr != nil { - if respErr == io.EOF { - break - } - return respErr - } - if entry.Entry == nil { - continue - } + for { + // List namespaces in batches + resp, err := client.ListEntries(ctx, &filer_pb.ListEntriesRequest{ + Directory: bucketPath, + Limit: 100, + StartFromFileName: lastNamespace, + InclusiveStartFrom: lastNamespace == continuationNamespace && continuationNamespace != "" || lastNamespace == "", + }) + if err != nil { + return nil, "", err + } - // Skip the start item if it was the continuation namespace but we already processed it - // (handled by the startTableName clearing logic below) - if lastNamespace == continuationNamespace && continuationNamespace != "" && entry.Entry.Name == continuationNamespace && startTableName == "" && len(*tables) > 0 { - continue + hasMore := false + for { + entry, respErr := resp.Recv() + if respErr != nil { + if respErr == io.EOF { + break } + return nil, "", respErr + } + if entry.Entry == nil { + continue + } - hasMore = true - lastNamespace = entry.Entry.Name - - if !entry.Entry.IsDirectory { - continue - } + hasMore = true + lastNamespace = entry.Entry.Name - // Skip hidden entries - if strings.HasPrefix(entry.Entry.Name, ".") { - continue - } + if !entry.Entry.IsDirectory || strings.HasPrefix(entry.Entry.Name, ".") { + continue + } - namespace := entry.Entry.Name + namespace := entry.Entry.Name + tableNameFilter := "" + if namespace == continuationNamespace { + tableNameFilter = startTableName + } - // List tables in this namespace - tableNameFilter := "" - if namespace == continuationNamespace { - tableNameFilter = startTableName - } + nsTables, nsToken, err := h.listTablesInNamespaceWithClient(r, client, bucketName, namespace, prefix, tableNameFilter, maxTables-len(tables)) + if err != nil { + glog.Warningf("S3Tables: failed to list tables in namespace %s/%s: %v", bucketName, namespace, err) + continue + } - if err := h.listTablesInNamespaceWithClient(ctx, client, bucketName, namespace, prefix, tableNameFilter, maxTables-len(*tables), tables); err != nil { - glog.Warningf("S3Tables: failed to list tables in namespace %s/%s: %v", bucketName, namespace, err) - continue - } + tables = append(tables, nsTables...) - // Clear startTableName after the first matching namespace is processed - if namespace == continuationNamespace { - startTableName = "" - } + if namespace == continuationNamespace { + startTableName = "" + } - if len(*tables) >= maxTables { - return nil + if len(tables) >= maxTables { + paginationToken := namespace + "/" + nsToken + if nsToken == "" { + // If we hit the limit exactly at the end of a namespace, the next token should be the next namespace + paginationToken = namespace // This will start from the NEXT namespace in the outer loop } + return tables, paginationToken, nil } + } - if !hasMore { - break - } + if !hasMore { + break } + } - return nil - }) + return tables, "", nil } // handleDeleteTable deletes a table from a namespace func (h *S3TablesHandler) handleDeleteTable(w http.ResponseWriter, r *http.Request, filerClient FilerClient) error { // Check permission principal := h.getPrincipalFromRequest(r) - if !CanDeleteTable(principal, h.accountID) { + accountID := h.getAccountID(r) + if !CanDeleteTable(principal, accountID) { h.writeError(w, http.StatusForbidden, ErrCodeAccessDenied, "not authorized to delete table") return NewAuthError("DeleteTable", principal, "not authorized to delete table") } @@ -547,15 +534,30 @@ func (h *S3TablesHandler) handleDeleteTable(w http.ResponseWriter, r *http.Reque tablePath := getTablePath(bucketName, namespaceName, tableName) - // Check if table exists + // Check if table exists and enforce VersionToken if provided err = filerClient.WithFilerClient(false, func(client filer_pb.SeaweedFilerClient) error { - _, err := h.getExtendedAttribute(r.Context(), client, tablePath, ExtendedKeyMetadata) - return err + data, err := h.getExtendedAttribute(r.Context(), client, tablePath, ExtendedKeyMetadata) + if err != nil { + return err + } + + if req.VersionToken != "" { + var metadata tableMetadataInternal + if err := json.Unmarshal(data, &metadata); err != nil { + return fmt.Errorf("failed to unmarshal table metadata: %w", err) + } + if metadata.VersionToken != req.VersionToken { + return ErrVersionTokenMismatch + } + } + return nil }) if err != nil { if errors.Is(err, filer_pb.ErrNotFound) { h.writeError(w, http.StatusNotFound, ErrCodeNoSuchTable, fmt.Sprintf("table %s not found", tableName)) + } else if errors.Is(err, ErrVersionTokenMismatch) { + h.writeError(w, http.StatusConflict, ErrCodeConflict, "version token mismatch") } else { h.writeError(w, http.StatusInternalServerError, ErrCodeInternalError, fmt.Sprintf("failed to check table: %v", err)) }