|
|
@ -91,14 +91,15 @@ const ( |
|
|
|
|
|
|
|
// PostgreSQL server configuration
|
|
|
|
type PostgreSQLServerConfig struct { |
|
|
|
Host string |
|
|
|
Port int |
|
|
|
AuthMethod AuthMethod |
|
|
|
Users map[string]string |
|
|
|
TLSConfig *tls.Config |
|
|
|
MaxConns int |
|
|
|
IdleTimeout time.Duration |
|
|
|
Database string |
|
|
|
Host string |
|
|
|
Port int |
|
|
|
AuthMethod AuthMethod |
|
|
|
Users map[string]string |
|
|
|
TLSConfig *tls.Config |
|
|
|
MaxConns int |
|
|
|
IdleTimeout time.Duration |
|
|
|
StartupTimeout time.Duration // Timeout for client startup handshake
|
|
|
|
Database string |
|
|
|
} |
|
|
|
|
|
|
|
// PostgreSQL server
|
|
|
@ -177,6 +178,9 @@ func NewPostgreSQLServer(config *PostgreSQLServerConfig, masterAddr string) (*Po |
|
|
|
if config.IdleTimeout <= 0 { |
|
|
|
config.IdleTimeout = time.Hour |
|
|
|
} |
|
|
|
if config.StartupTimeout <= 0 { |
|
|
|
config.StartupTimeout = 30 * time.Second |
|
|
|
} |
|
|
|
|
|
|
|
// Create SQL engine with PostgreSQL parser for proper dialect compatibility
|
|
|
|
// Use PostgreSQL parser since we're implementing PostgreSQL wire protocol
|
|
|
@ -325,12 +329,19 @@ func (s *PostgreSQLServer) handleConnection(conn net.Conn) { |
|
|
|
s.sessionMux.Unlock() |
|
|
|
}() |
|
|
|
|
|
|
|
glog.Infof("New PostgreSQL connection from %s (ID: %d)", conn.RemoteAddr(), connID) |
|
|
|
glog.V(2).Infof("New PostgreSQL connection from %s (ID: %d)", conn.RemoteAddr(), connID) |
|
|
|
|
|
|
|
// Handle startup
|
|
|
|
err := s.handleStartup(session) |
|
|
|
if err != nil { |
|
|
|
glog.Errorf("Startup failed for connection %d: %v", connID, err) |
|
|
|
// Handle common disconnection scenarios more gracefully
|
|
|
|
if strings.Contains(err.Error(), "client disconnected") { |
|
|
|
glog.V(1).Infof("Client startup disconnected from %s (ID: %d): %v", conn.RemoteAddr(), connID, err) |
|
|
|
} else if strings.Contains(err.Error(), "timeout") { |
|
|
|
glog.Warningf("Startup timeout for connection %d from %s: %v", connID, conn.RemoteAddr(), err) |
|
|
|
} else { |
|
|
|
glog.Errorf("Startup failed for connection %d from %s: %v", connID, conn.RemoteAddr(), err) |
|
|
|
} |
|
|
|
return |
|
|
|
} |
|
|
|
|
|
|
@ -361,19 +372,42 @@ func (s *PostgreSQLServer) handleConnection(conn net.Conn) { |
|
|
|
|
|
|
|
// handleStartup processes the PostgreSQL startup sequence
|
|
|
|
func (s *PostgreSQLServer) handleStartup(session *PostgreSQLSession) error { |
|
|
|
// Set a startup timeout to prevent hanging connections
|
|
|
|
startupTimeout := s.config.StartupTimeout |
|
|
|
session.conn.SetReadDeadline(time.Now().Add(startupTimeout)) |
|
|
|
defer session.conn.SetReadDeadline(time.Time{}) // Clear timeout
|
|
|
|
|
|
|
|
for { |
|
|
|
// Read startup message
|
|
|
|
// Read startup message length
|
|
|
|
length := make([]byte, 4) |
|
|
|
_, err := io.ReadFull(session.reader, length) |
|
|
|
if err != nil { |
|
|
|
return err |
|
|
|
if err == io.EOF { |
|
|
|
// Client disconnected during startup - this is common for health checks
|
|
|
|
return fmt.Errorf("client disconnected during startup handshake") |
|
|
|
} |
|
|
|
if netErr, ok := err.(net.Error); ok && netErr.Timeout() { |
|
|
|
return fmt.Errorf("startup handshake timeout after %v", startupTimeout) |
|
|
|
} |
|
|
|
return fmt.Errorf("failed to read message length during startup: %v", err) |
|
|
|
} |
|
|
|
|
|
|
|
msgLength := binary.BigEndian.Uint32(length) - 4 |
|
|
|
if msgLength > 10000 { // Reasonable limit for startup messages
|
|
|
|
return fmt.Errorf("startup message too large: %d bytes", msgLength) |
|
|
|
} |
|
|
|
|
|
|
|
// Read startup message content
|
|
|
|
msg := make([]byte, msgLength) |
|
|
|
_, err = io.ReadFull(session.reader, msg) |
|
|
|
if err != nil { |
|
|
|
return err |
|
|
|
if err == io.EOF { |
|
|
|
return fmt.Errorf("client disconnected while reading startup message") |
|
|
|
} |
|
|
|
if netErr, ok := err.(net.Error); ok && netErr.Timeout() { |
|
|
|
return fmt.Errorf("startup message read timeout") |
|
|
|
} |
|
|
|
return fmt.Errorf("failed to read startup message: %v", err) |
|
|
|
} |
|
|
|
|
|
|
|
// Parse protocol version
|
|
|
|