Browse Source

s3tables: add table name validation and cleanup duplicated logic in table handlers

pull/8147/head
Chris Lu 3 weeks ago
parent
commit
47ef8c3cce
  1. 105
      weed/s3api/s3tables/handler_table.go

105
weed/s3api/s3tables/handler_table.go

@ -63,20 +63,10 @@ func (h *S3TablesHandler) handleCreateTable(w http.ResponseWriter, r *http.Reque
} }
// Validate table name // Validate table name
if len(req.Name) < 1 || len(req.Name) > 255 {
h.writeError(w, http.StatusBadRequest, ErrCodeInvalidRequest, "table name must be between 1 and 255 characters")
return fmt.Errorf("invalid table name length")
}
if req.Name == "." || req.Name == ".." || strings.Contains(req.Name, "/") {
h.writeError(w, http.StatusBadRequest, ErrCodeInvalidRequest, "invalid table name: cannot be '.', '..' or contain '/'")
return fmt.Errorf("invalid table name")
}
for _, ch := range req.Name {
if (ch >= 'a' && ch <= 'z') || (ch >= '0' && ch <= '9') || ch == '_' {
continue
}
h.writeError(w, http.StatusBadRequest, ErrCodeInvalidRequest, "invalid table name: only 'a-z', '0-9', and '_' are allowed")
return fmt.Errorf("invalid table name")
tableName, err := validateTableName(req.Name)
if err != nil {
h.writeError(w, http.StatusBadRequest, ErrCodeInvalidRequest, err.Error())
return err
} }
// Check if namespace exists // Check if namespace exists
@ -95,7 +85,7 @@ func (h *S3TablesHandler) handleCreateTable(w http.ResponseWriter, r *http.Reque
return err return err
} }
tablePath := getTablePath(bucketName, namespaceName, req.Name)
tablePath := getTablePath(bucketName, namespaceName, tableName)
// Check if table already exists // Check if table already exists
err = filerClient.WithFilerClient(false, func(client filer_pb.SeaweedFilerClient) error { err = filerClient.WithFilerClient(false, func(client filer_pb.SeaweedFilerClient) error {
@ -104,7 +94,7 @@ func (h *S3TablesHandler) handleCreateTable(w http.ResponseWriter, r *http.Reque
}) })
if err == nil { if err == nil {
h.writeError(w, http.StatusConflict, ErrCodeTableAlreadyExists, fmt.Sprintf("table %s already exists", req.Name))
h.writeError(w, http.StatusConflict, ErrCodeTableAlreadyExists, fmt.Sprintf("table %s already exists", tableName))
return fmt.Errorf("table already exists") return fmt.Errorf("table already exists")
} else if !errors.Is(err, filer_pb.ErrNotFound) { } else if !errors.Is(err, filer_pb.ErrNotFound) {
h.writeError(w, http.StatusInternalServerError, ErrCodeInternalError, fmt.Sprintf("failed to check table: %v", err)) h.writeError(w, http.StatusInternalServerError, ErrCodeInternalError, fmt.Sprintf("failed to check table: %v", err))
@ -116,7 +106,7 @@ func (h *S3TablesHandler) handleCreateTable(w http.ResponseWriter, r *http.Reque
versionToken := generateVersionToken() versionToken := generateVersionToken()
metadata := &tableMetadataInternal{ metadata := &tableMetadataInternal{
Name: req.Name,
Name: tableName,
Namespace: namespaceName, Namespace: namespaceName,
Format: req.Format, Format: req.Format,
CreatedAt: now, CreatedAt: now,
@ -168,7 +158,7 @@ func (h *S3TablesHandler) handleCreateTable(w http.ResponseWriter, r *http.Reque
return err return err
} }
tableARN := h.generateTableARN(bucketName, namespaceName+"/"+req.Name)
tableARN := h.generateTableARN(bucketName, namespaceName+"/"+tableName)
resp := &CreateTableResponse{ resp := &CreateTableResponse{
TableARN: tableARN, TableARN: tableARN,
@ -215,7 +205,11 @@ func (h *S3TablesHandler) handleGetTable(w http.ResponseWriter, r *http.Request,
h.writeError(w, http.StatusBadRequest, ErrCodeInvalidRequest, err.Error()) h.writeError(w, http.StatusBadRequest, ErrCodeInvalidRequest, err.Error())
return err return err
} }
tableName = req.Name
tableName, err = validateTableName(req.Name)
if err != nil {
h.writeError(w, http.StatusBadRequest, ErrCodeInvalidRequest, err.Error())
return err
}
} else { } else {
h.writeError(w, http.StatusBadRequest, ErrCodeInvalidRequest, "either tableARN or (tableBucketARN, namespace, name) is required") h.writeError(w, http.StatusBadRequest, ErrCodeInvalidRequest, "either tableARN or (tableBucketARN, namespace, name) is required")
return fmt.Errorf("missing required parameters") return fmt.Errorf("missing required parameters")
@ -300,11 +294,11 @@ func (h *S3TablesHandler) handleListTables(w http.ResponseWriter, r *http.Reques
return err return err
} }
err = filerClient.WithFilerClient(false, func(client filer_pb.SeaweedFilerClient) error { err = filerClient.WithFilerClient(false, func(client filer_pb.SeaweedFilerClient) error {
return h.listTablesInNamespaceWithClient(r.Context(), client, bucketName, namespaceName, req.Prefix, maxTables, &tables)
return h.listTablesInNamespaceWithClient(r.Context(), client, bucketName, namespaceName, req.Prefix, req.ContinuationToken, maxTables, &tables)
}) })
} else { } else {
// List tables in all namespaces // List tables in all namespaces
err = h.listTablesInAllNamespaces(r.Context(), filerClient, bucketName, req.Prefix, maxTables, &tables)
err = h.listTablesInAllNamespaces(r.Context(), filerClient, bucketName, req.Prefix, req.ContinuationToken, maxTables, &tables)
} }
if err != nil { if err != nil {
@ -312,24 +306,42 @@ func (h *S3TablesHandler) handleListTables(w http.ResponseWriter, r *http.Reques
return err return err
} }
paginationToken := ""
if len(tables) >= maxTables && len(tables) > 0 {
// This is tricky for cross-namespace listing. For now, we'll store the full path if possible,
// but standard S3 tables usually lists within a namespace.
// If we are listing within a namespace, lastFileName is just the table name.
// If we are listing all namespaces, we might need a more complex token.
// For simplicity, let's assume if we hit the limit, we return the last seen entry's path-related info.
if len(req.Namespace) > 0 {
paginationToken = tables[len(tables)-1].Name
} else {
// For all-namespaces listing, we'd need to encode the namespace too.
// Let's use namespace/name as token.
lastTable := tables[len(tables)-1]
paginationToken = lastTable.Namespace[0] + "/" + lastTable.Name
}
}
resp := &ListTablesResponse{ resp := &ListTablesResponse{
Tables: tables,
Tables: tables,
ContinuationToken: paginationToken,
} }
h.writeJSON(w, http.StatusOK, resp) h.writeJSON(w, http.StatusOK, resp)
return nil return nil
} }
func (h *S3TablesHandler) listTablesInNamespaceWithClient(ctx context.Context, client filer_pb.SeaweedFilerClient, bucketName, namespace, prefix string, maxTables int, tables *[]TableSummary) error {
func (h *S3TablesHandler) listTablesInNamespaceWithClient(ctx context.Context, client filer_pb.SeaweedFilerClient, bucketName, namespace, prefix, continuationToken string, maxTables int, tables *[]TableSummary) error {
namespacePath := getNamespacePath(bucketName, namespace) namespacePath := getNamespacePath(bucketName, namespace)
var lastFileName string
lastFileName := continuationToken
for len(*tables) < maxTables { for len(*tables) < maxTables {
resp, err := client.ListEntries(ctx, &filer_pb.ListEntriesRequest{ resp, err := client.ListEntries(ctx, &filer_pb.ListEntriesRequest{
Directory: namespacePath, Directory: namespacePath,
Limit: uint32(maxTables * 2), Limit: uint32(maxTables * 2),
StartFromFileName: lastFileName, StartFromFileName: lastFileName,
InclusiveStartFrom: lastFileName == "",
InclusiveStartFrom: lastFileName == "" || lastFileName == continuationToken,
}) })
if err != nil { if err != nil {
return err return err
@ -347,6 +359,12 @@ func (h *S3TablesHandler) listTablesInNamespaceWithClient(ctx context.Context, c
if entry.Entry == nil { if entry.Entry == nil {
continue continue
} }
// Skip the start item if it was included in the previous page
if len(*tables) == 0 && continuationToken != "" && entry.Entry.Name == continuationToken {
continue
}
hasMore = true hasMore = true
lastFileName = entry.Entry.Name lastFileName = entry.Entry.Name
@ -398,18 +416,28 @@ func (h *S3TablesHandler) listTablesInNamespaceWithClient(ctx context.Context, c
return nil return nil
} }
func (h *S3TablesHandler) listTablesInAllNamespaces(ctx context.Context, filerClient FilerClient, bucketName, prefix string, maxTables int, tables *[]TableSummary) error {
func (h *S3TablesHandler) listTablesInAllNamespaces(ctx context.Context, filerClient FilerClient, bucketName, prefix, continuationToken string, maxTables int, tables *[]TableSummary) error {
bucketPath := getTableBucketPath(bucketName) bucketPath := getTableBucketPath(bucketName)
var lastFileName string
var lastNamespace string
var startTableName string
if continuationToken != "" {
if parts := strings.SplitN(continuationToken, "/", 2); len(parts) == 2 {
lastNamespace = parts[0]
startTableName = parts[1]
} else {
lastNamespace = continuationToken
}
}
return filerClient.WithFilerClient(false, func(client filer_pb.SeaweedFilerClient) error { return filerClient.WithFilerClient(false, func(client filer_pb.SeaweedFilerClient) error {
for { for {
// List namespaces in batches // List namespaces in batches
resp, err := client.ListEntries(ctx, &filer_pb.ListEntriesRequest{ resp, err := client.ListEntries(ctx, &filer_pb.ListEntriesRequest{
Directory: bucketPath, Directory: bucketPath,
Limit: 100, Limit: 100,
StartFromFileName: lastFileName,
InclusiveStartFrom: lastFileName == "",
StartFromFileName: lastNamespace,
InclusiveStartFrom: lastNamespace == "",
}) })
if err != nil { if err != nil {
return err return err
@ -428,7 +456,7 @@ func (h *S3TablesHandler) listTablesInAllNamespaces(ctx context.Context, filerCl
continue continue
} }
hasMore = true hasMore = true
lastFileName = entry.Entry.Name
lastNamespace = entry.Entry.Name
if !entry.Entry.IsDirectory { if !entry.Entry.IsDirectory {
continue continue
@ -442,7 +470,12 @@ func (h *S3TablesHandler) listTablesInAllNamespaces(ctx context.Context, filerCl
namespace := entry.Entry.Name namespace := entry.Entry.Name
// List tables in this namespace // List tables in this namespace
if err := h.listTablesInNamespaceWithClient(ctx, client, bucketName, namespace, prefix, maxTables-len(*tables), tables); err != nil {
tableNameFilter := ""
if namespace == lastNamespace {
tableNameFilter = startTableName
}
if err := h.listTablesInNamespaceWithClient(ctx, client, bucketName, namespace, prefix, tableNameFilter, maxTables-len(*tables), tables); err != nil {
glog.Warningf("S3Tables: failed to list tables in namespace %s/%s: %v", bucketName, namespace, err) glog.Warningf("S3Tables: failed to list tables in namespace %s/%s: %v", bucketName, namespace, err)
continue continue
} }
@ -493,7 +526,13 @@ func (h *S3TablesHandler) handleDeleteTable(w http.ResponseWriter, r *http.Reque
return err return err
} }
tablePath := getTablePath(bucketName, namespaceName, req.Name)
tableName, err := validateTableName(req.Name)
if err != nil {
h.writeError(w, http.StatusBadRequest, ErrCodeInvalidRequest, err.Error())
return err
}
tablePath := getTablePath(bucketName, namespaceName, tableName)
// Check if table exists // Check if table exists
err = filerClient.WithFilerClient(false, func(client filer_pb.SeaweedFilerClient) error { err = filerClient.WithFilerClient(false, func(client filer_pb.SeaweedFilerClient) error {
@ -503,7 +542,7 @@ func (h *S3TablesHandler) handleDeleteTable(w http.ResponseWriter, r *http.Reque
if err != nil { if err != nil {
if errors.Is(err, filer_pb.ErrNotFound) { if errors.Is(err, filer_pb.ErrNotFound) {
h.writeError(w, http.StatusNotFound, ErrCodeNoSuchTable, fmt.Sprintf("table %s not found", req.Name))
h.writeError(w, http.StatusNotFound, ErrCodeNoSuchTable, fmt.Sprintf("table %s not found", tableName))
} else { } else {
h.writeError(w, http.StatusInternalServerError, ErrCodeInternalError, fmt.Sprintf("failed to check table: %v", err)) h.writeError(w, http.StatusInternalServerError, ErrCodeInternalError, fmt.Sprintf("failed to check table: %v", err))
} }

Loading…
Cancel
Save