From 47ef8c3cce175400cb22eebd98a4b2074f03ee25 Mon Sep 17 00:00:00 2001 From: Chris Lu Date: Wed, 28 Jan 2026 12:30:32 -0800 Subject: [PATCH] s3tables: add table name validation and cleanup duplicated logic in table handlers --- weed/s3api/s3tables/handler_table.go | 105 ++++++++++++++++++--------- 1 file changed, 72 insertions(+), 33 deletions(-) diff --git a/weed/s3api/s3tables/handler_table.go b/weed/s3api/s3tables/handler_table.go index b8b3540a3..af514c991 100644 --- a/weed/s3api/s3tables/handler_table.go +++ b/weed/s3api/s3tables/handler_table.go @@ -63,20 +63,10 @@ func (h *S3TablesHandler) handleCreateTable(w http.ResponseWriter, r *http.Reque } // Validate table name - if len(req.Name) < 1 || len(req.Name) > 255 { - h.writeError(w, http.StatusBadRequest, ErrCodeInvalidRequest, "table name must be between 1 and 255 characters") - return fmt.Errorf("invalid table name length") - } - if req.Name == "." || req.Name == ".." || strings.Contains(req.Name, "/") { - h.writeError(w, http.StatusBadRequest, ErrCodeInvalidRequest, "invalid table name: cannot be '.', '..' or contain '/'") - return fmt.Errorf("invalid table name") - } - for _, ch := range req.Name { - if (ch >= 'a' && ch <= 'z') || (ch >= '0' && ch <= '9') || ch == '_' { - continue - } - h.writeError(w, http.StatusBadRequest, ErrCodeInvalidRequest, "invalid table name: only 'a-z', '0-9', and '_' are allowed") - return fmt.Errorf("invalid table name") + tableName, err := validateTableName(req.Name) + if err != nil { + h.writeError(w, http.StatusBadRequest, ErrCodeInvalidRequest, err.Error()) + return err } // Check if namespace exists @@ -95,7 +85,7 @@ func (h *S3TablesHandler) handleCreateTable(w http.ResponseWriter, r *http.Reque return err } - tablePath := getTablePath(bucketName, namespaceName, req.Name) + tablePath := getTablePath(bucketName, namespaceName, tableName) // Check if table already exists err = filerClient.WithFilerClient(false, func(client filer_pb.SeaweedFilerClient) error { @@ -104,7 +94,7 @@ func (h *S3TablesHandler) handleCreateTable(w http.ResponseWriter, r *http.Reque }) if err == nil { - h.writeError(w, http.StatusConflict, ErrCodeTableAlreadyExists, fmt.Sprintf("table %s already exists", req.Name)) + h.writeError(w, http.StatusConflict, ErrCodeTableAlreadyExists, fmt.Sprintf("table %s already exists", tableName)) return fmt.Errorf("table already exists") } else if !errors.Is(err, filer_pb.ErrNotFound) { h.writeError(w, http.StatusInternalServerError, ErrCodeInternalError, fmt.Sprintf("failed to check table: %v", err)) @@ -116,7 +106,7 @@ func (h *S3TablesHandler) handleCreateTable(w http.ResponseWriter, r *http.Reque versionToken := generateVersionToken() metadata := &tableMetadataInternal{ - Name: req.Name, + Name: tableName, Namespace: namespaceName, Format: req.Format, CreatedAt: now, @@ -168,7 +158,7 @@ func (h *S3TablesHandler) handleCreateTable(w http.ResponseWriter, r *http.Reque return err } - tableARN := h.generateTableARN(bucketName, namespaceName+"/"+req.Name) + tableARN := h.generateTableARN(bucketName, namespaceName+"/"+tableName) resp := &CreateTableResponse{ TableARN: tableARN, @@ -215,7 +205,11 @@ func (h *S3TablesHandler) handleGetTable(w http.ResponseWriter, r *http.Request, h.writeError(w, http.StatusBadRequest, ErrCodeInvalidRequest, err.Error()) return err } - tableName = req.Name + tableName, err = validateTableName(req.Name) + if err != nil { + h.writeError(w, http.StatusBadRequest, ErrCodeInvalidRequest, err.Error()) + return err + } } else { h.writeError(w, http.StatusBadRequest, ErrCodeInvalidRequest, "either tableARN or (tableBucketARN, namespace, name) is required") return fmt.Errorf("missing required parameters") @@ -300,11 +294,11 @@ func (h *S3TablesHandler) handleListTables(w http.ResponseWriter, r *http.Reques return err } err = filerClient.WithFilerClient(false, func(client filer_pb.SeaweedFilerClient) error { - return h.listTablesInNamespaceWithClient(r.Context(), client, bucketName, namespaceName, req.Prefix, maxTables, &tables) + return h.listTablesInNamespaceWithClient(r.Context(), client, bucketName, namespaceName, req.Prefix, req.ContinuationToken, maxTables, &tables) }) } else { // List tables in all namespaces - err = h.listTablesInAllNamespaces(r.Context(), filerClient, bucketName, req.Prefix, maxTables, &tables) + err = h.listTablesInAllNamespaces(r.Context(), filerClient, bucketName, req.Prefix, req.ContinuationToken, maxTables, &tables) } if err != nil { @@ -312,24 +306,42 @@ func (h *S3TablesHandler) handleListTables(w http.ResponseWriter, r *http.Reques return err } + paginationToken := "" + if len(tables) >= maxTables && len(tables) > 0 { + // This is tricky for cross-namespace listing. For now, we'll store the full path if possible, + // but standard S3 tables usually lists within a namespace. + // If we are listing within a namespace, lastFileName is just the table name. + // If we are listing all namespaces, we might need a more complex token. + // For simplicity, let's assume if we hit the limit, we return the last seen entry's path-related info. + if len(req.Namespace) > 0 { + paginationToken = tables[len(tables)-1].Name + } else { + // For all-namespaces listing, we'd need to encode the namespace too. + // Let's use namespace/name as token. + lastTable := tables[len(tables)-1] + paginationToken = lastTable.Namespace[0] + "/" + lastTable.Name + } + } + resp := &ListTablesResponse{ - Tables: tables, + Tables: tables, + ContinuationToken: paginationToken, } h.writeJSON(w, http.StatusOK, resp) return nil } -func (h *S3TablesHandler) listTablesInNamespaceWithClient(ctx context.Context, client filer_pb.SeaweedFilerClient, bucketName, namespace, prefix string, maxTables int, tables *[]TableSummary) error { +func (h *S3TablesHandler) listTablesInNamespaceWithClient(ctx context.Context, client filer_pb.SeaweedFilerClient, bucketName, namespace, prefix, continuationToken string, maxTables int, tables *[]TableSummary) error { namespacePath := getNamespacePath(bucketName, namespace) - var lastFileName string + lastFileName := continuationToken for len(*tables) < maxTables { resp, err := client.ListEntries(ctx, &filer_pb.ListEntriesRequest{ Directory: namespacePath, Limit: uint32(maxTables * 2), StartFromFileName: lastFileName, - InclusiveStartFrom: lastFileName == "", + InclusiveStartFrom: lastFileName == "" || lastFileName == continuationToken, }) if err != nil { return err @@ -347,6 +359,12 @@ func (h *S3TablesHandler) listTablesInNamespaceWithClient(ctx context.Context, c if entry.Entry == nil { continue } + + // Skip the start item if it was included in the previous page + if len(*tables) == 0 && continuationToken != "" && entry.Entry.Name == continuationToken { + continue + } + hasMore = true lastFileName = entry.Entry.Name @@ -398,18 +416,28 @@ func (h *S3TablesHandler) listTablesInNamespaceWithClient(ctx context.Context, c return nil } -func (h *S3TablesHandler) listTablesInAllNamespaces(ctx context.Context, filerClient FilerClient, bucketName, prefix string, maxTables int, tables *[]TableSummary) error { +func (h *S3TablesHandler) listTablesInAllNamespaces(ctx context.Context, filerClient FilerClient, bucketName, prefix, continuationToken string, maxTables int, tables *[]TableSummary) error { bucketPath := getTableBucketPath(bucketName) - var lastFileName string + var lastNamespace string + var startTableName string + if continuationToken != "" { + if parts := strings.SplitN(continuationToken, "/", 2); len(parts) == 2 { + lastNamespace = parts[0] + startTableName = parts[1] + } else { + lastNamespace = continuationToken + } + } + return filerClient.WithFilerClient(false, func(client filer_pb.SeaweedFilerClient) error { for { // List namespaces in batches resp, err := client.ListEntries(ctx, &filer_pb.ListEntriesRequest{ Directory: bucketPath, Limit: 100, - StartFromFileName: lastFileName, - InclusiveStartFrom: lastFileName == "", + StartFromFileName: lastNamespace, + InclusiveStartFrom: lastNamespace == "", }) if err != nil { return err @@ -428,7 +456,7 @@ func (h *S3TablesHandler) listTablesInAllNamespaces(ctx context.Context, filerCl continue } hasMore = true - lastFileName = entry.Entry.Name + lastNamespace = entry.Entry.Name if !entry.Entry.IsDirectory { continue @@ -442,7 +470,12 @@ func (h *S3TablesHandler) listTablesInAllNamespaces(ctx context.Context, filerCl namespace := entry.Entry.Name // List tables in this namespace - if err := h.listTablesInNamespaceWithClient(ctx, client, bucketName, namespace, prefix, maxTables-len(*tables), tables); err != nil { + tableNameFilter := "" + if namespace == lastNamespace { + tableNameFilter = startTableName + } + + if err := h.listTablesInNamespaceWithClient(ctx, client, bucketName, namespace, prefix, tableNameFilter, maxTables-len(*tables), tables); err != nil { glog.Warningf("S3Tables: failed to list tables in namespace %s/%s: %v", bucketName, namespace, err) continue } @@ -493,7 +526,13 @@ func (h *S3TablesHandler) handleDeleteTable(w http.ResponseWriter, r *http.Reque return err } - tablePath := getTablePath(bucketName, namespaceName, req.Name) + tableName, err := validateTableName(req.Name) + if err != nil { + h.writeError(w, http.StatusBadRequest, ErrCodeInvalidRequest, err.Error()) + return err + } + + tablePath := getTablePath(bucketName, namespaceName, tableName) // Check if table exists err = filerClient.WithFilerClient(false, func(client filer_pb.SeaweedFilerClient) error { @@ -503,7 +542,7 @@ func (h *S3TablesHandler) handleDeleteTable(w http.ResponseWriter, r *http.Reque if err != nil { if errors.Is(err, filer_pb.ErrNotFound) { - h.writeError(w, http.StatusNotFound, ErrCodeNoSuchTable, fmt.Sprintf("table %s not found", req.Name)) + h.writeError(w, http.StatusNotFound, ErrCodeNoSuchTable, fmt.Sprintf("table %s not found", tableName)) } else { h.writeError(w, http.StatusInternalServerError, ErrCodeInternalError, fmt.Sprintf("failed to check table: %v", err)) }