From e901abffd335dfb1783386d0975f442a11718bd4 Mon Sep 17 00:00:00 2001 From: chrislu Date: Tue, 2 Sep 2025 15:40:38 -0700 Subject: [PATCH] address comments --- weed/command/db.go | 9 +- weed/command/sql.go | 9 +- weed/query/engine/hybrid_message_scanner.go | 5 +- weed/server/postgres/protocol.go | 99 ++++++++++++--------- 4 files changed, 74 insertions(+), 48 deletions(-) diff --git a/weed/command/db.go b/weed/command/db.go index 58d4984d0..71c46af4a 100644 --- a/weed/command/db.go +++ b/weed/command/db.go @@ -11,6 +11,7 @@ import ( "syscall" "time" + "github.com/seaweedfs/seaweedfs/weed/glog" "github.com/seaweedfs/seaweedfs/weed/server/postgres" "github.com/seaweedfs/seaweedfs/weed/util" ) @@ -183,6 +184,12 @@ func runDB(cmd *Command, args []string) bool { return false } + // Validate port number + if err := validatePortNumber(*dbOptions.port); err != nil { + fmt.Fprintf(os.Stderr, "Error: %v\n", err) + return false + } + // Setup TLS if requested var tlsConfig *tls.Config if *dbOptions.tlsCert != "" && *dbOptions.tlsKey != "" { @@ -357,7 +364,7 @@ func validatePortNumber(port int) error { return fmt.Errorf("port number must be between 1 and 65535, got %d", port) } if port < 1024 { - return fmt.Errorf("port number %d may require root privileges", port) + glog.Warningf("port number %d may require root privileges", port) } return nil } diff --git a/weed/command/sql.go b/weed/command/sql.go index 99722adf7..c1ee24452 100644 --- a/weed/command/sql.go +++ b/weed/command/sql.go @@ -15,6 +15,7 @@ import ( "github.com/peterh/liner" "github.com/seaweedfs/seaweedfs/weed/query/engine" "github.com/seaweedfs/seaweedfs/weed/util/grace" + "github.com/xwb1989/sqlparser" ) func init() { @@ -155,8 +156,12 @@ func executeFileQueries(ctx *SQLContext, filename string) bool { fmt.Printf("Executing queries from %s against %s...\n", filename, *sqlMaster) } - // Split file content into individual queries (simple approach) - queries := strings.Split(string(content), ";") + // Split file content into individual queries (robust approach) + queries, err := sqlparser.SplitStatementToPieces(string(content)) + if err != nil { + fmt.Printf("Error splitting SQL statements from file %s: %v\n", filename, err) + return false + } for i, query := range queries { query = strings.TrimSpace(query) diff --git a/weed/query/engine/hybrid_message_scanner.go b/weed/query/engine/hybrid_message_scanner.go index 9d33e399b..af313a5cb 100644 --- a/weed/query/engine/hybrid_message_scanner.go +++ b/weed/query/engine/hybrid_message_scanner.go @@ -380,10 +380,9 @@ func (hms *HybridMessageScanner) discoverTopicPartitions(ctx context.Context) ([ return nil, fmt.Errorf("failed to scan topic directory for partitions: %v", err) } - // If no partitions found, use fallback + // If no partitions found, return error instead of masking the issue if len(allPartitions) == 0 { - fmt.Printf("No partitions found in filesystem for topic %s, using default partition\n", hms.topic.String()) - return []topic.Partition{{RangeStart: 0, RangeStop: 1000}}, nil + return nil, fmt.Errorf("no partitions found for topic %s", hms.topic.String()) } fmt.Printf("Discovered %d partitions for topic %s\n", len(allPartitions), hms.topic.String()) diff --git a/weed/server/postgres/protocol.go b/weed/server/postgres/protocol.go index a67027384..1eb74fead 100644 --- a/weed/server/postgres/protocol.go +++ b/weed/server/postgres/protocol.go @@ -11,6 +11,7 @@ import ( "github.com/seaweedfs/seaweedfs/weed/glog" "github.com/seaweedfs/seaweedfs/weed/query/sqltypes" "github.com/seaweedfs/seaweedfs/weed/util/version" + "github.com/xwb1989/sqlparser" ) // handleMessage processes a single PostgreSQL protocol message @@ -85,69 +86,83 @@ func (s *PostgreSQLServer) handleSimpleQuery(session *PostgreSQLSession, query s } } - // Clean query by removing trailing semicolons and whitespace early - cleanQuery := strings.TrimSpace(query) - cleanQuery = strings.TrimSuffix(cleanQuery, ";") - cleanQuery = strings.TrimSpace(cleanQuery) - // Set database context in SQL engine if session database is different from current if session.database != "" && session.database != s.sqlEngine.GetCatalog().GetCurrentDatabase() { s.sqlEngine.GetCatalog().SetCurrentDatabase(session.database) } - // Handle PostgreSQL-specific system queries directly - if systemResult := s.handleSystemQuery(session, cleanQuery); systemResult != nil { - return s.sendSystemQueryResult(session, systemResult, cleanQuery) + // Split query string into individual statements to handle multi-statement queries + queries, err := sqlparser.SplitStatementToPieces(query) + if err != nil { + // If split fails, fall back to single query processing + queries = []string{strings.TrimSpace(strings.TrimSuffix(strings.TrimSpace(query), ";"))} } - // Execute using SQL engine directly - ctx := context.Background() - result, err := s.sqlEngine.ExecuteSQL(ctx, cleanQuery) - if err != nil { - // Send error message but keep connection alive - sendErr := s.sendError(session, "42000", err.Error()) - if sendErr != nil { - return sendErr + // Execute each statement sequentially + for _, singleQuery := range queries { + cleanQuery := strings.TrimSpace(singleQuery) + if cleanQuery == "" { + continue // Skip empty statements } - // Send ReadyForQuery to keep connection alive - return s.sendReadyForQuery(session) - } - if result.Error != nil { - // Send error message but keep connection alive - sendErr := s.sendError(session, "42000", result.Error.Error()) - if sendErr != nil { - return sendErr + // Handle PostgreSQL-specific system queries directly + if systemResult := s.handleSystemQuery(session, cleanQuery); systemResult != nil { + err := s.sendSystemQueryResult(session, systemResult, cleanQuery) + if err != nil { + return err + } + continue // Continue with next statement } - // Send ReadyForQuery to keep connection alive - return s.sendReadyForQuery(session) - } - // Send results - if len(result.Columns) > 0 { - // Send row description - err = s.sendRowDescription(session, result.Columns, result.Rows) + // Execute using SQL engine directly + ctx := context.Background() + result, err := s.sqlEngine.ExecuteSQL(ctx, cleanQuery) if err != nil { - return err + // Send error message but keep connection alive + sendErr := s.sendError(session, "42000", err.Error()) + if sendErr != nil { + return sendErr + } + // Send ReadyForQuery to keep connection alive + return s.sendReadyForQuery(session) + } + + if result.Error != nil { + // Send error message but keep connection alive + sendErr := s.sendError(session, "42000", result.Error.Error()) + if sendErr != nil { + return sendErr + } + // Send ReadyForQuery to keep connection alive + return s.sendReadyForQuery(session) } - // Send data rows - for _, row := range result.Rows { - err = s.sendDataRow(session, row) + // Send results for this statement + if len(result.Columns) > 0 { + // Send row description + err = s.sendRowDescription(session, result.Columns, result.Rows) if err != nil { return err } + + // Send data rows + for _, row := range result.Rows { + err = s.sendDataRow(session, row) + if err != nil { + return err + } + } } - } - // Send command complete - tag := s.getCommandTag(query, len(result.Rows)) - err = s.sendCommandComplete(session, tag) - if err != nil { - return err + // Send command complete for this statement + tag := s.getCommandTag(cleanQuery, len(result.Rows)) + err = s.sendCommandComplete(session, tag) + if err != nil { + return err + } } - // Send ready for query + // Send ready for query after all statements are processed return s.sendReadyForQuery(session) }