From e14a316aeb113213fcc963780c02d4447ea3918e Mon Sep 17 00:00:00 2001 From: chrislu Date: Tue, 2 Sep 2025 20:59:13 -0700 Subject: [PATCH] use schema instead of inferred result types --- weed/command/sql.go | 7 +- weed/query/engine/hybrid_message_scanner.go | 12 ++- weed/query/engine/types.go | 3 + weed/server/postgres/protocol.go | 100 +++++++++++++++++--- weed/server/postgres/server.go | 3 + 5 files changed, 105 insertions(+), 20 deletions(-) diff --git a/weed/command/sql.go b/weed/command/sql.go index 264e0515c..ddcf23569 100644 --- a/weed/command/sql.go +++ b/weed/command/sql.go @@ -295,10 +295,9 @@ func runInteractiveShell(ctx *SQLContext) bool { ctx.currentDatabase = dbName // Also update the SQL engine's catalog current database ctx.engine.GetCatalog().SetCurrentDatabase(dbName) - fmt.Printf("Database changed to: %s\n\n", dbName) - queryBuffer.Reset() - continue - } + fmt.Printf("Database changed to: %s\n\n", dbName) + queryBuffer.Reset() + continue } // Handle output format switching diff --git a/weed/query/engine/hybrid_message_scanner.go b/weed/query/engine/hybrid_message_scanner.go index af313a5cb..dc2b27491 100644 --- a/weed/query/engine/hybrid_message_scanner.go +++ b/weed/query/engine/hybrid_message_scanner.go @@ -780,8 +780,10 @@ func (hms *HybridMessageScanner) convertJSONValueToSchemaValue(jsonValue interfa func (hms *HybridMessageScanner) ConvertToSQLResult(results []HybridScanResult, columns []string) *QueryResult { if len(results) == 0 { return &QueryResult{ - Columns: columns, - Rows: [][]sqltypes.Value{}, + Columns: columns, + Rows: [][]sqltypes.Value{}, + Database: hms.topic.Namespace, + Table: hms.topic.Name, } } @@ -824,8 +826,10 @@ func (hms *HybridMessageScanner) ConvertToSQLResult(results []HybridScanResult, } return &QueryResult{ - Columns: columns, - Rows: rows, + Columns: columns, + Rows: rows, + Database: hms.topic.Namespace, + Table: hms.topic.Name, } } diff --git a/weed/query/engine/types.go b/weed/query/engine/types.go index 
3b72ca7da..25877c01e 100644 --- a/weed/query/engine/types.go +++ b/weed/query/engine/types.go @@ -31,4 +31,7 @@ type QueryResult struct { Rows [][]sqltypes.Value `json:"rows"` Error error `json:"error,omitempty"` ExecutionPlan *QueryExecutionPlan `json:"execution_plan,omitempty"` + // Schema information for type inference (optional) + Database string `json:"database,omitempty"` + Table string `json:"table,omitempty"` } diff --git a/weed/server/postgres/protocol.go b/weed/server/postgres/protocol.go index ccb873a35..82fa777a0 100644 --- a/weed/server/postgres/protocol.go +++ b/weed/server/postgres/protocol.go @@ -9,6 +9,7 @@ import ( "strings" "github.com/seaweedfs/seaweedfs/weed/glog" + "github.com/seaweedfs/seaweedfs/weed/pb/schema_pb" "github.com/seaweedfs/seaweedfs/weed/query/engine" "github.com/seaweedfs/seaweedfs/weed/query/sqltypes" "github.com/seaweedfs/seaweedfs/weed/util/version" @@ -202,7 +203,7 @@ func (s *PostgreSQLServer) handleSimpleQuery(session *PostgreSQLSession, query s // Send results for this statement if len(result.Columns) > 0 { // Send row description - err = s.sendRowDescription(session, result.Columns, result.Rows) + err = s.sendRowDescription(session, result) if err != nil { return err } @@ -324,8 +325,12 @@ func (s *PostgreSQLServer) sendSystemQueryResult(session *PostgreSQLSession, res sqlRows = append(sqlRows, sqlRow) } - // Send row description - err := s.sendRowDescription(session, columns, sqlRows) + // Send row description (create a temporary QueryResult for consistency) + tempResult := &engine.QueryResult{ + Columns: columns, + Rows: sqlRows, + } + err := s.sendRowDescription(session, tempResult) if err != nil { return err } @@ -418,7 +423,11 @@ func (s *PostgreSQLServer) handleDescribe(session *PostgreSQLSession, msgBody [] glog.V(2).Infof("PostgreSQL Describe %c (ID: %d): %s", objectType, session.processID, objectName) // For now, send empty row description - return s.sendRowDescription(session, []string{}, [][]sqltypes.Value{}) 
+ tempResult := &engine.QueryResult{ + Columns: []string{}, + Rows: [][]sqltypes.Value{}, + } + return s.sendRowDescription(session, tempResult) } // handleClose processes a Close message @@ -509,13 +518,13 @@ func (s *PostgreSQLServer) sendReadyForQuery(session *PostgreSQLSession) error { } // sendRowDescription sends row description message -func (s *PostgreSQLServer) sendRowDescription(session *PostgreSQLSession, columns []string, rows [][]sqltypes.Value) error { +func (s *PostgreSQLServer) sendRowDescription(session *PostgreSQLSession, result *engine.QueryResult) error { msg := make([]byte, 0) msg = append(msg, PG_RESP_ROW_DESC) // Calculate message length length := 4 + 2 // length + field count - for _, col := range columns { + for _, col := range result.Columns { length += len(col) + 1 + 4 + 2 + 4 + 2 + 4 + 2 // name + null + tableOID + attrNum + typeOID + typeSize + typeMod + format } @@ -525,11 +534,11 @@ func (s *PostgreSQLServer) sendRowDescription(session *PostgreSQLSession, column // Field count fieldCountBytes := make([]byte, 2) - binary.BigEndian.PutUint16(fieldCountBytes, uint16(len(columns))) + binary.BigEndian.PutUint16(fieldCountBytes, uint16(len(result.Columns))) msg = append(msg, fieldCountBytes...) // Field descriptions - for i, col := range columns { + for i, col := range result.Columns { // Field name msg = append(msg, []byte(col)...) msg = append(msg, 0) // null terminator @@ -544,8 +553,8 @@ func (s *PostgreSQLServer) sendRowDescription(session *PostgreSQLSession, column binary.BigEndian.PutUint16(attrNum, uint16(i+1)) msg = append(msg, attrNum...) - // Type OID (determine from data) - typeOID := s.getPostgreSQLType(columns, rows, i) + // Type OID (determine from schema if available, fallback to data inference) + typeOID := s.getPostgreSQLTypeFromSchema(result, col, i) typeOIDBytes := make([]byte, 4) binary.BigEndian.PutUint32(typeOIDBytes, typeOID) msg = append(msg, typeOIDBytes...) 
@@ -722,8 +731,75 @@ func (s *PostgreSQLServer) getCommandTag(query string, rowCount int) string { return "SELECT 0" } -// getPostgreSQLType determines PostgreSQL type OID from data -func (s *PostgreSQLServer) getPostgreSQLType(columns []string, rows [][]sqltypes.Value, colIndex int) uint32 { +// getPostgreSQLTypeFromSchema resolves a column's PostgreSQL type OID: table schema first, then known system columns, then inference from row data +func (s *PostgreSQLServer) getPostgreSQLTypeFromSchema(result *engine.QueryResult, columnName string, colIndex int) uint32 { + // Consult the catalog only when the result names both a database and a table; a lookup error falls through to the checks below + if result.Database != "" && result.Table != "" { + if tableInfo, err := s.sqlEngine.GetCatalog().GetTableInfo(result.Database, result.Table); err == nil { + if tableInfo.Schema != nil && tableInfo.Schema.RecordType != nil { + // Match the column against the schema's record fields by exact name + for _, field := range tableInfo.Schema.RecordType.Fields { + if field.Name == columnName { + return s.mapSchemaTypeToPostgreSQL(field.Type) + } + } + } + } + } + + // System columns, reached only when the schema lookup returned no field of this name + switch columnName { + case "_timestamp_ns": + return PG_TYPE_INT8 // PostgreSQL BIGINT for nanosecond timestamps + case "_key": + return PG_TYPE_BYTEA // PostgreSQL BYTEA for binary keys + case "_source": + return PG_TYPE_TEXT // PostgreSQL TEXT for source information + } + + // Last resort when neither the schema nor the system columns matched: infer the type from the row data + return s.getPostgreSQLTypeFromData(result.Columns, result.Rows, colIndex) +} + +// mapSchemaTypeToPostgreSQL maps SeaweedFS schema types to PostgreSQL type OIDs; a nil type and any unrecognized kind map to TEXT +func (s *PostgreSQLServer) mapSchemaTypeToPostgreSQL(fieldType *schema_pb.Type) uint32 { + if fieldType == nil { + return PG_TYPE_TEXT + } + + switch kind := fieldType.Kind.(type) { + case *schema_pb.Type_ScalarType: + switch kind.ScalarType { + case schema_pb.ScalarType_BOOL: + return PG_TYPE_BOOL + case schema_pb.ScalarType_INT32: + return PG_TYPE_INT4 + case schema_pb.ScalarType_INT64: + return PG_TYPE_INT8 + case
schema_pb.ScalarType_FLOAT: + return PG_TYPE_FLOAT4 + case schema_pb.ScalarType_DOUBLE: + return PG_TYPE_FLOAT8 + case schema_pb.ScalarType_BYTES: + return PG_TYPE_BYTEA + case schema_pb.ScalarType_STRING: + return PG_TYPE_TEXT + default: + return PG_TYPE_TEXT + } + case *schema_pb.Type_ListType: + // List types are reported to clients with the JSONB type OID + return PG_TYPE_JSONB + case *schema_pb.Type_RecordType: + // Nested record types are reported to clients with the JSONB type OID + return PG_TYPE_JSONB + default: + return PG_TYPE_TEXT + } +} + +// getPostgreSQLTypeFromData determines PostgreSQL type OID from data (legacy fallback method) +func (s *PostgreSQLServer) getPostgreSQLTypeFromData(columns []string, rows [][]sqltypes.Value, colIndex int) uint32 { if len(rows) == 0 || colIndex >= len(rows[0]) { return PG_TYPE_TEXT // Default to text } diff --git a/weed/server/postgres/server.go b/weed/server/postgres/server.go index 89a8c54f5..700aa9895 100644 --- a/weed/server/postgres/server.go +++ b/weed/server/postgres/server.go @@ -65,13 +65,16 @@ const ( // PostgreSQL data types PG_TYPE_BOOL = 16 + PG_TYPE_BYTEA = 17 PG_TYPE_INT8 = 20 PG_TYPE_INT4 = 23 PG_TYPE_TEXT = 25 + PG_TYPE_FLOAT4 = 700 PG_TYPE_FLOAT8 = 701 PG_TYPE_VARCHAR = 1043 PG_TYPE_TIMESTAMP = 1114 PG_TYPE_JSON = 114 + PG_TYPE_JSONB = 3802 // Default values DEFAULT_POSTGRES_PORT = 5432