You can not select more than 25 topics
			Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
		
		
		
		
		
			
		
			
				
					
					
						
							704 lines
						
					
					
						
							17 KiB
						
					
					
				
			
		
		
		
			
			
			
		
		
	
	
							704 lines
						
					
					
						
							17 KiB
						
					
					
				
								package postgres
							 | 
						|
								
							 | 
						|
								import (
									"bufio"
									"crypto/md5"
									"crypto/rand"
									"crypto/subtle"
									"crypto/tls"
									"encoding/binary"
									"errors"
									"fmt"
									"io"
									"net"
									"strings"
									"sync"
									"time"

									"github.com/seaweedfs/seaweedfs/weed/glog"
									"github.com/seaweedfs/seaweedfs/weed/query/engine"
									"github.com/seaweedfs/seaweedfs/weed/util/version"
								)
							 | 
						|
								
							 | 
						|
								// PostgreSQL wire-protocol constants. The historical ALL_CAPS names are
								// kept because other code refers to them. Every value is an untyped
								// constant so it compares directly against bytes, runes and uint32s.

								// Codes carried in the first four bytes of a startup packet. Request
								// codes are written as a 16-bit major half and a 16-bit minor half, the
								// way the protocol specification defines them.
								const (
									PG_PROTOCOL_VERSION_3 = 3 << 16         // protocol 3.0 = 196608 (0x00030000)
									PG_SSL_REQUEST        = 1234<<16 | 5679 // SSLRequest = 80877103 (0x04d2162f)
									PG_GSSAPI_REQUEST     = 1234<<16 | 5680 // GSSENCRequest = 80877104 (0x04d21630)
								)

								// Frontend (client -> server) message type bytes. The startup packet
								// itself has no type byte; PG_MSG_STARTUP is only a placeholder for it.
								const (
									PG_MSG_STARTUP   = 0x00
									PG_MSG_QUERY     = 'Q' // simple query
									PG_MSG_PARSE     = 'P' // extended query: Parse
									PG_MSG_BIND      = 'B' // extended query: Bind
									PG_MSG_EXECUTE   = 'E' // extended query: Execute
									PG_MSG_DESCRIBE  = 'D' // extended query: Describe
									PG_MSG_CLOSE     = 'C' // extended query: Close
									PG_MSG_FLUSH     = 'H'
									PG_MSG_SYNC      = 'S'
									PG_MSG_TERMINATE = 'X'
									PG_MSG_PASSWORD  = 'p' // PasswordMessage (cleartext or MD5 response)
								)

								// Backend (server -> client) message type bytes.
								const (
									PG_RESP_AUTH_OK        = 'R' // type byte of every Authentication* message, not only AuthenticationOk
									PG_RESP_BACKEND_KEY    = 'K' // BackendKeyData
									PG_RESP_PARAMETER      = 'S' // ParameterStatus
									PG_RESP_READY          = 'Z' // ReadyForQuery
									PG_RESP_COMMAND        = 'C' // CommandComplete
									PG_RESP_DATA_ROW       = 'D' // DataRow
									PG_RESP_ROW_DESC       = 'T' // RowDescription
									PG_RESP_PARSE_COMPLETE = '1'
									PG_RESP_BIND_COMPLETE  = '2'
									PG_RESP_CLOSE_COMPLETE = '3'
									PG_RESP_ERROR          = 'E' // ErrorResponse
									PG_RESP_NOTICE         = 'N' // NoticeResponse
								)

								// Transaction status indicator values (as reported in ReadyForQuery).
								const (
									PG_TRANS_IDLE    = 'I' // not inside a transaction block
									PG_TRANS_INTRANS = 'T' // inside a transaction block
									PG_TRANS_ERROR   = 'E' // inside a failed transaction block
								)

								// Authentication request codes carried in an 'R' message.
								//
								// NOTE(review): the protocol has no "trust" request. Trust auth simply sends
								// AuthenticationOk, and code 10 on the wire is AuthenticationSASL.
								// AUTH_TRUST keeps its historical value; do not send it to a client that is
								// expected to treat it as "trust".
								const (
									AUTH_OK    = 0 // AuthenticationOk
									AUTH_CLEAR = 3 // AuthenticationCleartextPassword
									AUTH_MD5   = 5 // AuthenticationMD5Password
									AUTH_TRUST = 10
								)

								// Built-in type OIDs from the pg_type catalog.
								const (
									PG_TYPE_BOOL      = 16
									PG_TYPE_BYTEA     = 17
									PG_TYPE_INT8      = 20
									PG_TYPE_INT4      = 23
									PG_TYPE_TEXT      = 25
									PG_TYPE_JSON      = 114
									PG_TYPE_FLOAT4    = 700
									PG_TYPE_FLOAT8    = 701
									PG_TYPE_VARCHAR   = 1043
									PG_TYPE_TIMESTAMP = 1114
									PG_TYPE_JSONB     = 3802
								)

								// DEFAULT_POSTGRES_PORT is the listen port used when none is configured.
								const DEFAULT_POSTGRES_PORT = 5432
							 | 
						|
								
							 | 
						|
								// AuthMethod selects how clients are authenticated during startup. The
								// values are package-internal choices, not protocol codes; the AUTH_*
								// constants hold the codes that go on the wire.
								type AuthMethod int

								// Supported authentication methods.
								const (
									AuthTrust    AuthMethod = 0 // no password; AuthenticationOk is sent immediately
									AuthPassword AuthMethod = 1 // cleartext password checked against config Users
									AuthMD5      AuthMethod = 2 // MD5-salted password checked against config Users
								)
							 | 
						|
								
							 | 
						|
								// PostgreSQL server configuration
							 | 
						|
								type PostgreSQLServerConfig struct {
							 | 
						|
									Host           string
							 | 
						|
									Port           int
							 | 
						|
									AuthMethod     AuthMethod
							 | 
						|
									Users          map[string]string
							 | 
						|
									TLSConfig      *tls.Config
							 | 
						|
									MaxConns       int
							 | 
						|
									IdleTimeout    time.Duration
							 | 
						|
									StartupTimeout time.Duration // Timeout for client startup handshake
							 | 
						|
									Database       string
							 | 
						|
								}
							 | 
						|
								
							 | 
						|
								// PostgreSQL server
							 | 
						|
								type PostgreSQLServer struct {
							 | 
						|
									config     *PostgreSQLServerConfig
							 | 
						|
									listener   net.Listener
							 | 
						|
									sqlEngine  *engine.SQLEngine
							 | 
						|
									sessions   map[uint32]*PostgreSQLSession
							 | 
						|
									sessionMux sync.RWMutex
							 | 
						|
									shutdown   chan struct{}
							 | 
						|
									wg         sync.WaitGroup
							 | 
						|
									nextConnID uint32
							 | 
						|
								}
							 | 
						|
								
							 | 
						|
								// PostgreSQL session
							 | 
						|
								type PostgreSQLSession struct {
							 | 
						|
									conn             net.Conn
							 | 
						|
									reader           *bufio.Reader
							 | 
						|
									writer           *bufio.Writer
							 | 
						|
									authenticated    bool
							 | 
						|
									username         string
							 | 
						|
									database         string
							 | 
						|
									parameters       map[string]string
							 | 
						|
									preparedStmts    map[string]*PreparedStatement
							 | 
						|
									portals          map[string]*Portal
							 | 
						|
									transactionState byte
							 | 
						|
									processID        uint32
							 | 
						|
									secretKey        uint32
							 | 
						|
									created          time.Time
							 | 
						|
									lastActivity     time.Time
							 | 
						|
									mutex            sync.Mutex
							 | 
						|
								}
							 | 
						|
								
							 | 
						|
								// Prepared statement
							 | 
						|
								type PreparedStatement struct {
							 | 
						|
									Name       string
							 | 
						|
									Query      string
							 | 
						|
									ParamTypes []uint32
							 | 
						|
									Fields     []FieldDescription
							 | 
						|
								}
							 | 
						|
								
							 | 
						|
								// Portal is a bound statement (a cursor) of the extended query protocol.
								type Portal struct {
									Name       string
									Statement  string   // presumably the name of the prepared statement it binds
									Parameters [][]byte // bound parameter values
									Suspended  bool     // presumably set when execution paused at a row limit
								}
							 | 
						|
								
							 | 
						|
								// FieldDescription describes one result column. Its fields mirror the
								// per-column entries of a RowDescription message.
								type FieldDescription struct {
									Name     string
									TableOID uint32 // source table OID, or 0
									AttrNum  int16  // source column number, or 0
									TypeOID  uint32 // data type OID (see the PG_TYPE_* constants)
									TypeSize int16  // pg_type.typlen; negative for variable width
									TypeMod  int32  // type modifier
									Format   int16  // 0 = text, 1 = binary
								}
							 | 
						|
								
							 | 
						|
								// NewPostgreSQLServer creates a new PostgreSQL protocol server
							 | 
						|
								func NewPostgreSQLServer(config *PostgreSQLServerConfig, masterAddr string) (*PostgreSQLServer, error) {
							 | 
						|
									if config.Port <= 0 {
							 | 
						|
										config.Port = DEFAULT_POSTGRES_PORT
							 | 
						|
									}
							 | 
						|
									if config.Host == "" {
							 | 
						|
										config.Host = "localhost"
							 | 
						|
									}
							 | 
						|
									if config.Database == "" {
							 | 
						|
										config.Database = "default"
							 | 
						|
									}
							 | 
						|
									if config.MaxConns <= 0 {
							 | 
						|
										config.MaxConns = 100
							 | 
						|
									}
							 | 
						|
									if config.IdleTimeout <= 0 {
							 | 
						|
										config.IdleTimeout = time.Hour
							 | 
						|
									}
							 | 
						|
									if config.StartupTimeout <= 0 {
							 | 
						|
										config.StartupTimeout = 30 * time.Second
							 | 
						|
									}
							 | 
						|
								
							 | 
						|
									// Create SQL engine (now uses CockroachDB parser for PostgreSQL compatibility)
							 | 
						|
									sqlEngine := engine.NewSQLEngine(masterAddr)
							 | 
						|
								
							 | 
						|
									server := &PostgreSQLServer{
							 | 
						|
										config:     config,
							 | 
						|
										sqlEngine:  sqlEngine,
							 | 
						|
										sessions:   make(map[uint32]*PostgreSQLSession),
							 | 
						|
										shutdown:   make(chan struct{}),
							 | 
						|
										nextConnID: 1,
							 | 
						|
									}
							 | 
						|
								
							 | 
						|
									return server, nil
							 | 
						|
								}
							 | 
						|
								
							 | 
						|
								// Start begins listening for PostgreSQL connections
							 | 
						|
								func (s *PostgreSQLServer) Start() error {
							 | 
						|
									addr := fmt.Sprintf("%s:%d", s.config.Host, s.config.Port)
							 | 
						|
								
							 | 
						|
									var listener net.Listener
							 | 
						|
									var err error
							 | 
						|
								
							 | 
						|
									if s.config.TLSConfig != nil {
							 | 
						|
										listener, err = tls.Listen("tcp", addr, s.config.TLSConfig)
							 | 
						|
										glog.Infof("PostgreSQL Server with TLS listening on %s", addr)
							 | 
						|
									} else {
							 | 
						|
										listener, err = net.Listen("tcp", addr)
							 | 
						|
										glog.Infof("PostgreSQL Server listening on %s", addr)
							 | 
						|
									}
							 | 
						|
								
							 | 
						|
									if err != nil {
							 | 
						|
										return fmt.Errorf("failed to start PostgreSQL server on %s: %v", addr, err)
							 | 
						|
									}
							 | 
						|
								
							 | 
						|
									s.listener = listener
							 | 
						|
								
							 | 
						|
									// Start accepting connections
							 | 
						|
									s.wg.Add(1)
							 | 
						|
									go s.acceptConnections()
							 | 
						|
								
							 | 
						|
									// Start cleanup routine
							 | 
						|
									s.wg.Add(1)
							 | 
						|
									go s.cleanupSessions()
							 | 
						|
								
							 | 
						|
									return nil
							 | 
						|
								}
							 | 
						|
								
							 | 
						|
								// Stop gracefully shuts down the PostgreSQL server
							 | 
						|
								func (s *PostgreSQLServer) Stop() error {
							 | 
						|
									close(s.shutdown)
							 | 
						|
								
							 | 
						|
									if s.listener != nil {
							 | 
						|
										s.listener.Close()
							 | 
						|
									}
							 | 
						|
								
							 | 
						|
									// Close all sessions
							 | 
						|
									s.sessionMux.Lock()
							 | 
						|
									for _, session := range s.sessions {
							 | 
						|
										session.close()
							 | 
						|
									}
							 | 
						|
									s.sessions = make(map[uint32]*PostgreSQLSession)
							 | 
						|
									s.sessionMux.Unlock()
							 | 
						|
								
							 | 
						|
									s.wg.Wait()
							 | 
						|
									glog.Infof("PostgreSQL Server stopped")
							 | 
						|
									return nil
							 | 
						|
								}
							 | 
						|
								
							 | 
						|
								// acceptConnections handles incoming PostgreSQL connections
							 | 
						|
								func (s *PostgreSQLServer) acceptConnections() {
							 | 
						|
									defer s.wg.Done()
							 | 
						|
								
							 | 
						|
									for {
							 | 
						|
										select {
							 | 
						|
										case <-s.shutdown:
							 | 
						|
											return
							 | 
						|
										default:
							 | 
						|
										}
							 | 
						|
								
							 | 
						|
										conn, err := s.listener.Accept()
							 | 
						|
										if err != nil {
							 | 
						|
											select {
							 | 
						|
											case <-s.shutdown:
							 | 
						|
												return
							 | 
						|
											default:
							 | 
						|
												glog.Errorf("Failed to accept PostgreSQL connection: %v", err)
							 | 
						|
												continue
							 | 
						|
											}
							 | 
						|
										}
							 | 
						|
								
							 | 
						|
										// Check connection limit
							 | 
						|
										s.sessionMux.RLock()
							 | 
						|
										sessionCount := len(s.sessions)
							 | 
						|
										s.sessionMux.RUnlock()
							 | 
						|
								
							 | 
						|
										if sessionCount >= s.config.MaxConns {
							 | 
						|
											glog.Warningf("Maximum connections reached (%d), rejecting connection from %s",
							 | 
						|
												s.config.MaxConns, conn.RemoteAddr())
							 | 
						|
											conn.Close()
							 | 
						|
											continue
							 | 
						|
										}
							 | 
						|
								
							 | 
						|
										s.wg.Add(1)
							 | 
						|
										go s.handleConnection(conn)
							 | 
						|
									}
							 | 
						|
								}
							 | 
						|
								
							 | 
						|
								// handleConnection processes a single PostgreSQL connection
							 | 
						|
								func (s *PostgreSQLServer) handleConnection(conn net.Conn) {
							 | 
						|
									defer s.wg.Done()
							 | 
						|
									defer conn.Close()
							 | 
						|
								
							 | 
						|
									// Generate unique connection ID
							 | 
						|
									connID := s.generateConnectionID()
							 | 
						|
									secretKey := s.generateSecretKey()
							 | 
						|
								
							 | 
						|
									// Create session
							 | 
						|
									session := &PostgreSQLSession{
							 | 
						|
										conn:             conn,
							 | 
						|
										reader:           bufio.NewReader(conn),
							 | 
						|
										writer:           bufio.NewWriter(conn),
							 | 
						|
										authenticated:    false,
							 | 
						|
										database:         s.config.Database,
							 | 
						|
										parameters:       make(map[string]string),
							 | 
						|
										preparedStmts:    make(map[string]*PreparedStatement),
							 | 
						|
										portals:          make(map[string]*Portal),
							 | 
						|
										transactionState: PG_TRANS_IDLE,
							 | 
						|
										processID:        connID,
							 | 
						|
										secretKey:        secretKey,
							 | 
						|
										created:          time.Now(),
							 | 
						|
										lastActivity:     time.Now(),
							 | 
						|
									}
							 | 
						|
								
							 | 
						|
									// Register session
							 | 
						|
									s.sessionMux.Lock()
							 | 
						|
									s.sessions[connID] = session
							 | 
						|
									s.sessionMux.Unlock()
							 | 
						|
								
							 | 
						|
									// Clean up on exit
							 | 
						|
									defer func() {
							 | 
						|
										s.sessionMux.Lock()
							 | 
						|
										delete(s.sessions, connID)
							 | 
						|
										s.sessionMux.Unlock()
							 | 
						|
									}()
							 | 
						|
								
							 | 
						|
									glog.V(2).Infof("New PostgreSQL connection from %s (ID: %d)", conn.RemoteAddr(), connID)
							 | 
						|
								
							 | 
						|
									// Handle startup
							 | 
						|
									err := s.handleStartup(session)
							 | 
						|
									if err != nil {
							 | 
						|
										// Handle common disconnection scenarios more gracefully
							 | 
						|
										if strings.Contains(err.Error(), "client disconnected") {
							 | 
						|
											glog.V(1).Infof("Client startup disconnected from %s (ID: %d): %v", conn.RemoteAddr(), connID, err)
							 | 
						|
										} else if strings.Contains(err.Error(), "timeout") {
							 | 
						|
											glog.Warningf("Startup timeout for connection %d from %s: %v", connID, conn.RemoteAddr(), err)
							 | 
						|
										} else {
							 | 
						|
											glog.Errorf("Startup failed for connection %d from %s: %v", connID, conn.RemoteAddr(), err)
							 | 
						|
										}
							 | 
						|
										return
							 | 
						|
									}
							 | 
						|
								
							 | 
						|
									// Handle messages
							 | 
						|
									for {
							 | 
						|
										select {
							 | 
						|
										case <-s.shutdown:
							 | 
						|
											return
							 | 
						|
										default:
							 | 
						|
										}
							 | 
						|
								
							 | 
						|
										// Set read timeout
							 | 
						|
										conn.SetReadDeadline(time.Now().Add(30 * time.Second))
							 | 
						|
								
							 | 
						|
										err := s.handleMessage(session)
							 | 
						|
										if err != nil {
							 | 
						|
											if err == io.EOF {
							 | 
						|
												glog.Infof("PostgreSQL client disconnected (ID: %d)", connID)
							 | 
						|
											} else {
							 | 
						|
												glog.Errorf("Error handling PostgreSQL message (ID: %d): %v", connID, err)
							 | 
						|
											}
							 | 
						|
											return
							 | 
						|
										}
							 | 
						|
								
							 | 
						|
										session.lastActivity = time.Now()
							 | 
						|
									}
							 | 
						|
								}
							 | 
						|
								
							 | 
						|
								// handleStartup processes the PostgreSQL startup sequence
							 | 
						|
								func (s *PostgreSQLServer) handleStartup(session *PostgreSQLSession) error {
							 | 
						|
									// Set a startup timeout to prevent hanging connections
							 | 
						|
									startupTimeout := s.config.StartupTimeout
							 | 
						|
									session.conn.SetReadDeadline(time.Now().Add(startupTimeout))
							 | 
						|
									defer session.conn.SetReadDeadline(time.Time{}) // Clear timeout
							 | 
						|
								
							 | 
						|
									for {
							 | 
						|
										// Read startup message length
							 | 
						|
										length := make([]byte, 4)
							 | 
						|
										_, err := io.ReadFull(session.reader, length)
							 | 
						|
										if err != nil {
							 | 
						|
											if err == io.EOF {
							 | 
						|
												// Client disconnected during startup - this is common for health checks
							 | 
						|
												return fmt.Errorf("client disconnected during startup handshake")
							 | 
						|
											}
							 | 
						|
											if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
							 | 
						|
												return fmt.Errorf("startup handshake timeout after %v", startupTimeout)
							 | 
						|
											}
							 | 
						|
											return fmt.Errorf("failed to read message length during startup: %v", err)
							 | 
						|
										}
							 | 
						|
								
							 | 
						|
										msgLength := binary.BigEndian.Uint32(length) - 4
							 | 
						|
										if msgLength > 10000 { // Reasonable limit for startup messages
							 | 
						|
											return fmt.Errorf("startup message too large: %d bytes", msgLength)
							 | 
						|
										}
							 | 
						|
								
							 | 
						|
										// Read startup message content
							 | 
						|
										msg := make([]byte, msgLength)
							 | 
						|
										_, err = io.ReadFull(session.reader, msg)
							 | 
						|
										if err != nil {
							 | 
						|
											if err == io.EOF {
							 | 
						|
												return fmt.Errorf("client disconnected while reading startup message")
							 | 
						|
											}
							 | 
						|
											if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
							 | 
						|
												return fmt.Errorf("startup message read timeout")
							 | 
						|
											}
							 | 
						|
											return fmt.Errorf("failed to read startup message: %v", err)
							 | 
						|
										}
							 | 
						|
								
							 | 
						|
										// Parse protocol version
							 | 
						|
										protocolVersion := binary.BigEndian.Uint32(msg[0:4])
							 | 
						|
								
							 | 
						|
										switch protocolVersion {
							 | 
						|
										case PG_SSL_REQUEST:
							 | 
						|
											// Reject SSL request - send 'N' to indicate SSL not supported
							 | 
						|
											_, err = session.conn.Write([]byte{'N'})
							 | 
						|
											if err != nil {
							 | 
						|
												return fmt.Errorf("failed to reject SSL request: %v", err)
							 | 
						|
											}
							 | 
						|
											// Continue loop to read the actual startup message
							 | 
						|
											continue
							 | 
						|
								
							 | 
						|
										case PG_GSSAPI_REQUEST:
							 | 
						|
											// Reject GSSAPI request - send 'N' to indicate GSSAPI not supported
							 | 
						|
											_, err = session.conn.Write([]byte{'N'})
							 | 
						|
											if err != nil {
							 | 
						|
												return fmt.Errorf("failed to reject GSSAPI request: %v", err)
							 | 
						|
											}
							 | 
						|
											// Continue loop to read the actual startup message
							 | 
						|
											continue
							 | 
						|
								
							 | 
						|
										case PG_PROTOCOL_VERSION_3:
							 | 
						|
											// This is the actual startup message, break out of loop
							 | 
						|
											break
							 | 
						|
								
							 | 
						|
										default:
							 | 
						|
											return fmt.Errorf("unsupported protocol version: %d", protocolVersion)
							 | 
						|
										}
							 | 
						|
								
							 | 
						|
										// Parse parameters
							 | 
						|
										params := strings.Split(string(msg[4:]), "\x00")
							 | 
						|
										for i := 0; i < len(params)-1; i += 2 {
							 | 
						|
											if params[i] == "user" {
							 | 
						|
												session.username = params[i+1]
							 | 
						|
											} else if params[i] == "database" {
							 | 
						|
												session.database = params[i+1]
							 | 
						|
											}
							 | 
						|
											session.parameters[params[i]] = params[i+1]
							 | 
						|
										}
							 | 
						|
								
							 | 
						|
										// Break out of the main loop - we have the startup message
							 | 
						|
										break
							 | 
						|
									}
							 | 
						|
								
							 | 
						|
									// Handle authentication
							 | 
						|
									err := s.handleAuthentication(session)
							 | 
						|
									if err != nil {
							 | 
						|
										return err
							 | 
						|
									}
							 | 
						|
								
							 | 
						|
									// Send parameter status messages
							 | 
						|
									err = s.sendParameterStatus(session, "server_version", fmt.Sprintf("%s (SeaweedFS)", version.VERSION_NUMBER))
							 | 
						|
									if err != nil {
							 | 
						|
										return err
							 | 
						|
									}
							 | 
						|
									err = s.sendParameterStatus(session, "server_encoding", "UTF8")
							 | 
						|
									if err != nil {
							 | 
						|
										return err
							 | 
						|
									}
							 | 
						|
									err = s.sendParameterStatus(session, "client_encoding", "UTF8")
							 | 
						|
									if err != nil {
							 | 
						|
										return err
							 | 
						|
									}
							 | 
						|
									err = s.sendParameterStatus(session, "DateStyle", "ISO, MDY")
							 | 
						|
									if err != nil {
							 | 
						|
										return err
							 | 
						|
									}
							 | 
						|
									err = s.sendParameterStatus(session, "integer_datetimes", "on")
							 | 
						|
									if err != nil {
							 | 
						|
										return err
							 | 
						|
									}
							 | 
						|
								
							 | 
						|
									// Send backend key data
							 | 
						|
									err = s.sendBackendKeyData(session)
							 | 
						|
									if err != nil {
							 | 
						|
										return err
							 | 
						|
									}
							 | 
						|
								
							 | 
						|
									// Send ready for query
							 | 
						|
									err = s.sendReadyForQuery(session)
							 | 
						|
									if err != nil {
							 | 
						|
										return err
							 | 
						|
									}
							 | 
						|
								
							 | 
						|
									session.authenticated = true
							 | 
						|
									return nil
							 | 
						|
								}
							 | 
						|
								
							 | 
						|
								// handleAuthentication processes authentication
							 | 
						|
								func (s *PostgreSQLServer) handleAuthentication(session *PostgreSQLSession) error {
							 | 
						|
									switch s.config.AuthMethod {
							 | 
						|
									case AuthTrust:
							 | 
						|
										return s.sendAuthenticationOk(session)
							 | 
						|
									case AuthPassword:
							 | 
						|
										return s.handlePasswordAuth(session)
							 | 
						|
									case AuthMD5:
							 | 
						|
										return s.handleMD5Auth(session)
							 | 
						|
									default:
							 | 
						|
										return fmt.Errorf("unsupported authentication method")
							 | 
						|
									}
							 | 
						|
								}
							 | 
						|
								
							 | 
						|
								// sendAuthenticationOk sends authentication OK message
							 | 
						|
								func (s *PostgreSQLServer) sendAuthenticationOk(session *PostgreSQLSession) error {
							 | 
						|
									msg := make([]byte, 9)
							 | 
						|
									msg[0] = PG_RESP_AUTH_OK
							 | 
						|
									binary.BigEndian.PutUint32(msg[1:5], 8)
							 | 
						|
									binary.BigEndian.PutUint32(msg[5:9], AUTH_OK)
							 | 
						|
								
							 | 
						|
									_, err := session.writer.Write(msg)
							 | 
						|
									if err == nil {
							 | 
						|
										err = session.writer.Flush()
							 | 
						|
									}
							 | 
						|
									return err
							 | 
						|
								}
							 | 
						|
								
							 | 
						|
								// handlePasswordAuth handles clear password authentication
							 | 
						|
								func (s *PostgreSQLServer) handlePasswordAuth(session *PostgreSQLSession) error {
							 | 
						|
									// Send password request
							 | 
						|
									msg := make([]byte, 9)
							 | 
						|
									msg[0] = PG_RESP_AUTH_OK
							 | 
						|
									binary.BigEndian.PutUint32(msg[1:5], 8)
							 | 
						|
									binary.BigEndian.PutUint32(msg[5:9], AUTH_CLEAR)
							 | 
						|
								
							 | 
						|
									_, err := session.writer.Write(msg)
							 | 
						|
									if err != nil {
							 | 
						|
										return err
							 | 
						|
									}
							 | 
						|
									err = session.writer.Flush()
							 | 
						|
									if err != nil {
							 | 
						|
										return err
							 | 
						|
									}
							 | 
						|
								
							 | 
						|
									// Read password response
							 | 
						|
									msgType := make([]byte, 1)
							 | 
						|
									_, err = io.ReadFull(session.reader, msgType)
							 | 
						|
									if err != nil {
							 | 
						|
										return err
							 | 
						|
									}
							 | 
						|
								
							 | 
						|
									if msgType[0] != PG_MSG_PASSWORD {
							 | 
						|
										return fmt.Errorf("expected password message, got %c", msgType[0])
							 | 
						|
									}
							 | 
						|
								
							 | 
						|
									length := make([]byte, 4)
							 | 
						|
									_, err = io.ReadFull(session.reader, length)
							 | 
						|
									if err != nil {
							 | 
						|
										return err
							 | 
						|
									}
							 | 
						|
								
							 | 
						|
									msgLength := binary.BigEndian.Uint32(length) - 4
							 | 
						|
									password := make([]byte, msgLength)
							 | 
						|
									_, err = io.ReadFull(session.reader, password)
							 | 
						|
									if err != nil {
							 | 
						|
										return err
							 | 
						|
									}
							 | 
						|
								
							 | 
						|
									// Verify password
							 | 
						|
									expectedPassword, exists := s.config.Users[session.username]
							 | 
						|
									if !exists || string(password[:len(password)-1]) != expectedPassword { // Remove null terminator
							 | 
						|
										return s.sendError(session, "28P01", "authentication failed for user \""+session.username+"\"")
							 | 
						|
									}
							 | 
						|
								
							 | 
						|
									return s.sendAuthenticationOk(session)
							 | 
						|
								}
							 | 
						|
								
							 | 
						|
								// handleMD5Auth handles MD5 password authentication
							 | 
						|
								func (s *PostgreSQLServer) handleMD5Auth(session *PostgreSQLSession) error {
							 | 
						|
									// Generate salt
							 | 
						|
									salt := make([]byte, 4)
							 | 
						|
									_, err := rand.Read(salt)
							 | 
						|
									if err != nil {
							 | 
						|
										return err
							 | 
						|
									}
							 | 
						|
								
							 | 
						|
									// Send MD5 request
							 | 
						|
									msg := make([]byte, 13)
							 | 
						|
									msg[0] = PG_RESP_AUTH_OK
							 | 
						|
									binary.BigEndian.PutUint32(msg[1:5], 12)
							 | 
						|
									binary.BigEndian.PutUint32(msg[5:9], AUTH_MD5)
							 | 
						|
									copy(msg[9:13], salt)
							 | 
						|
								
							 | 
						|
									_, err = session.writer.Write(msg)
							 | 
						|
									if err != nil {
							 | 
						|
										return err
							 | 
						|
									}
							 | 
						|
									err = session.writer.Flush()
							 | 
						|
									if err != nil {
							 | 
						|
										return err
							 | 
						|
									}
							 | 
						|
								
							 | 
						|
									// Read password response
							 | 
						|
									msgType := make([]byte, 1)
							 | 
						|
									_, err = io.ReadFull(session.reader, msgType)
							 | 
						|
									if err != nil {
							 | 
						|
										return err
							 | 
						|
									}
							 | 
						|
								
							 | 
						|
									if msgType[0] != PG_MSG_PASSWORD {
							 | 
						|
										return fmt.Errorf("expected password message, got %c", msgType[0])
							 | 
						|
									}
							 | 
						|
								
							 | 
						|
									length := make([]byte, 4)
							 | 
						|
									_, err = io.ReadFull(session.reader, length)
							 | 
						|
									if err != nil {
							 | 
						|
										return err
							 | 
						|
									}
							 | 
						|
								
							 | 
						|
									msgLength := binary.BigEndian.Uint32(length) - 4
							 | 
						|
									response := make([]byte, msgLength)
							 | 
						|
									_, err = io.ReadFull(session.reader, response)
							 | 
						|
									if err != nil {
							 | 
						|
										return err
							 | 
						|
									}
							 | 
						|
								
							 | 
						|
									// Verify MD5 hash
							 | 
						|
									expectedPassword, exists := s.config.Users[session.username]
							 | 
						|
									if !exists {
							 | 
						|
										return s.sendError(session, "28P01", "authentication failed for user \""+session.username+"\"")
							 | 
						|
									}
							 | 
						|
								
							 | 
						|
									// Calculate expected hash: md5(md5(password + username) + salt)
							 | 
						|
									inner := md5.Sum([]byte(expectedPassword + session.username))
							 | 
						|
									expected := fmt.Sprintf("md5%x", md5.Sum(append([]byte(fmt.Sprintf("%x", inner)), salt...)))
							 | 
						|
								
							 | 
						|
									if string(response[:len(response)-1]) != expected { // Remove null terminator
							 | 
						|
										return s.sendError(session, "28P01", "authentication failed for user \""+session.username+"\"")
							 | 
						|
									}
							 | 
						|
								
							 | 
						|
									return s.sendAuthenticationOk(session)
							 | 
						|
								}
							 | 
						|
								
							 | 
						|
								// generateConnectionID generates a unique connection ID
							 | 
						|
								func (s *PostgreSQLServer) generateConnectionID() uint32 {
							 | 
						|
									s.sessionMux.Lock()
							 | 
						|
									defer s.sessionMux.Unlock()
							 | 
						|
									id := s.nextConnID
							 | 
						|
									s.nextConnID++
							 | 
						|
									return id
							 | 
						|
								}
							 | 
						|
								
							 | 
						|
								// generateSecretKey generates a secret key for the connection
							 | 
						|
								func (s *PostgreSQLServer) generateSecretKey() uint32 {
							 | 
						|
									key := make([]byte, 4)
							 | 
						|
									rand.Read(key)
							 | 
						|
									return binary.BigEndian.Uint32(key)
							 | 
						|
								}
							 | 
						|
								
							 | 
						|
								// close marks the session as closed
							 | 
						|
								func (s *PostgreSQLSession) close() {
							 | 
						|
									s.mutex.Lock()
							 | 
						|
									defer s.mutex.Unlock()
							 | 
						|
									if s.conn != nil {
							 | 
						|
										s.conn.Close()
							 | 
						|
										s.conn = nil
							 | 
						|
									}
							 | 
						|
								}
							 | 
						|
								
							 | 
						|
								// cleanupSessions periodically cleans up idle sessions
							 | 
						|
								func (s *PostgreSQLServer) cleanupSessions() {
							 | 
						|
									defer s.wg.Done()
							 | 
						|
								
							 | 
						|
									ticker := time.NewTicker(time.Minute)
							 | 
						|
									defer ticker.Stop()
							 | 
						|
								
							 | 
						|
									for {
							 | 
						|
										select {
							 | 
						|
										case <-s.shutdown:
							 | 
						|
											return
							 | 
						|
										case <-ticker.C:
							 | 
						|
											s.cleanupIdleSessions()
							 | 
						|
										}
							 | 
						|
									}
							 | 
						|
								}
							 | 
						|
								
							 | 
						|
								// cleanupIdleSessions removes sessions that have been idle too long
							 | 
						|
								func (s *PostgreSQLServer) cleanupIdleSessions() {
							 | 
						|
									now := time.Now()
							 | 
						|
								
							 | 
						|
									s.sessionMux.Lock()
							 | 
						|
									defer s.sessionMux.Unlock()
							 | 
						|
								
							 | 
						|
									for id, session := range s.sessions {
							 | 
						|
										if now.Sub(session.lastActivity) > s.config.IdleTimeout {
							 | 
						|
											glog.Infof("Closing idle PostgreSQL session %d", id)
							 | 
						|
											session.close()
							 | 
						|
											delete(s.sessions, id)
							 | 
						|
										}
							 | 
						|
									}
							 | 
						|
								}
							 | 
						|
								
							 | 
						|
								// GetAddress returns the server address
							 | 
						|
								func (s *PostgreSQLServer) GetAddress() string {
							 | 
						|
									return fmt.Sprintf("%s:%d", s.config.Host, s.config.Port)
							 | 
						|
								}
							 |