You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
 
 
 

704 lines
17 KiB

package postgres
import (
"bufio"
"crypto/md5"
"crypto/rand"
"crypto/tls"
"encoding/binary"
"fmt"
"io"
"net"
"strings"
"sync"
"time"
"github.com/seaweedfs/seaweedfs/weed/glog"
"github.com/seaweedfs/seaweedfs/weed/query/engine"
"github.com/seaweedfs/seaweedfs/weed/util/version"
)
// PostgreSQL wire-protocol constants used by this server.
const (
// Protocol versions and special request codes. handleStartup reads one of
// these from the first four bytes of each startup packet's payload.
PG_PROTOCOL_VERSION_3 = 196608 // PostgreSQL 3.0 protocol (0x00030000)
PG_SSL_REQUEST = 80877103 // SSL request (0x04d2162f)
PG_GSSAPI_REQUEST = 80877104 // GSSAPI request (0x04d21630)
// Message type bytes sent by the client.
PG_MSG_STARTUP = 0x00
PG_MSG_QUERY = 'Q'
PG_MSG_PARSE = 'P'
PG_MSG_BIND = 'B'
PG_MSG_EXECUTE = 'E'
PG_MSG_DESCRIBE = 'D'
PG_MSG_CLOSE = 'C'
PG_MSG_FLUSH = 'H'
PG_MSG_SYNC = 'S'
PG_MSG_TERMINATE = 'X'
PG_MSG_PASSWORD = 'p'
// Message type bytes sent to the client.
// PG_RESP_AUTH_OK ('R') is the type byte of every authentication message this
// file builds: the final "OK" as well as the clear-text and MD5 password
// requests.
PG_RESP_AUTH_OK = 'R'
PG_RESP_BACKEND_KEY = 'K'
PG_RESP_PARAMETER = 'S'
PG_RESP_READY = 'Z'
PG_RESP_COMMAND = 'C'
PG_RESP_DATA_ROW = 'D'
PG_RESP_ROW_DESC = 'T'
PG_RESP_PARSE_COMPLETE = '1'
PG_RESP_BIND_COMPLETE = '2'
PG_RESP_CLOSE_COMPLETE = '3'
PG_RESP_ERROR = 'E'
PG_RESP_NOTICE = 'N'
// Transaction states. New sessions start in PG_TRANS_IDLE.
PG_TRANS_IDLE = 'I'
PG_TRANS_INTRANS = 'T'
PG_TRANS_ERROR = 'E'
// Authentication request codes, written as the 32-bit body of an 'R' message.
// NOTE(review): AUTH_TRUST is not referenced in the visible code, and in the
// PostgreSQL protocol the code 10 denotes AuthenticationSASL rather than
// "trust" — confirm the intended use before sending it on the wire.
AUTH_OK = 0
AUTH_CLEAR = 3
AUTH_MD5 = 5
AUTH_TRUST = 10
// PostgreSQL data type identifiers (presumably the values carried in
// FieldDescription.TypeOID; not referenced in the visible code).
PG_TYPE_BOOL = 16
PG_TYPE_BYTEA = 17
PG_TYPE_INT8 = 20
PG_TYPE_INT4 = 23
PG_TYPE_TEXT = 25
PG_TYPE_FLOAT4 = 700
PG_TYPE_FLOAT8 = 701
PG_TYPE_VARCHAR = 1043
PG_TYPE_TIMESTAMP = 1114
PG_TYPE_JSON = 114
PG_TYPE_JSONB = 3802
// Default values
// DEFAULT_POSTGRES_PORT is applied by NewPostgreSQLServer when the configured
// port is zero or negative.
DEFAULT_POSTGRES_PORT = 5432
)
// AuthMethod selects how handleAuthentication verifies a connecting client.
type AuthMethod int
const (
// AuthTrust accepts the client without any password exchange. It is the zero
// value, so a config that leaves AuthMethod unset uses it.
AuthTrust AuthMethod = iota
// AuthPassword requests the password in clear text and compares it with the
// configured one.
AuthPassword
// AuthMD5 requests a salted MD5 hash of the password.
AuthMD5
)
// PostgreSQLServerConfig holds the settings of a PostgreSQLServer.
// NewPostgreSQLServer fills in defaults for unset fields directly on the
// struct it is given.
type PostgreSQLServerConfig struct {
// Host is the address to listen on; defaults to "localhost" when empty.
Host string
// Port is the TCP port to listen on; defaults to DEFAULT_POSTGRES_PORT when <= 0.
Port int
// AuthMethod selects the authentication exchange performed during startup.
AuthMethod AuthMethod
// Users maps a user name to its password. The password is compared directly
// (AuthPassword) or hashed together with the user name (AuthMD5), so it is
// stored as plain text.
Users map[string]string
// TLSConfig, when non-nil, makes Start listen with tls.Listen instead of net.Listen.
TLSConfig *tls.Config
// MaxConns is the session count at which new connections are rejected;
// defaults to 100 when <= 0.
MaxConns int
// IdleTimeout is how long a session may stay inactive before the cleanup
// routine closes it; defaults to one hour when <= 0.
IdleTimeout time.Duration
StartupTimeout time.Duration // Timeout for client startup handshake; defaults to 30s when <= 0
// Database is the database name assigned to a session until the client's
// startup message names one; defaults to "default" when empty.
Database string
}
// PostgreSQLServer accepts PostgreSQL wire-protocol connections and tracks the
// sessions created for them.
type PostgreSQLServer struct {
// config is the configuration passed to NewPostgreSQLServer (with defaults applied).
config *PostgreSQLServerConfig
// listener is set by Start and closed by Stop.
listener net.Listener
// sqlEngine is created in NewPostgreSQLServer from the master address.
sqlEngine *engine.SQLEngine
// sessions holds the live sessions keyed by connection ID.
sessions map[uint32]*PostgreSQLSession
// sessionMux guards sessions and nextConnID.
sessionMux sync.RWMutex
// shutdown is closed by Stop to signal every server goroutine to exit.
shutdown chan struct{}
// wg tracks the accept loop, the cleanup routine and every connection handler.
wg sync.WaitGroup
// nextConnID is the ID handed to the next connection; it starts at 1.
nextConnID uint32
}
// PostgreSQLSession is the per-connection state of one client.
type PostgreSQLSession struct {
// conn is the client connection; close() closes it and sets it to nil.
conn net.Conn
// reader and writer are buffered wrappers around conn.
reader *bufio.Reader
writer *bufio.Writer
// authenticated is set to true once the startup sequence has completed.
authenticated bool
// username and database come from the client's startup parameters; database
// starts out as the server's configured default.
username string
database string
// parameters holds every name/value pair from the startup message.
parameters map[string]string
// preparedStmts and portals are created empty for each session
// (presumably keyed by statement/portal name — their users are outside this view).
preparedStmts map[string]*PreparedStatement
portals map[string]*Portal
// transactionState starts as PG_TRANS_IDLE.
transactionState byte
// processID is the connection ID, which is also the key in the server's session map.
processID uint32
// secretKey is a random value generated per connection.
secretKey uint32
// created is the time the session was set up.
created time.Time
// lastActivity is refreshed after each handled message and read by the idle cleanup.
lastActivity time.Time
// mutex is held by close() while it closes and clears conn.
mutex sync.Mutex
}
// PreparedStatement describes a named statement stored in
// PostgreSQLSession.preparedStmts. The code that fills it in is outside this
// view, so the field notes below are read from the names and types only.
type PreparedStatement struct {
Name string
// Query is presumably the SQL text of the statement — TODO confirm.
Query string
// ParamTypes presumably lists the type identifiers of the parameters — TODO confirm.
ParamTypes []uint32
// Fields presumably describes the result columns — TODO confirm.
Fields []FieldDescription
}
// Portal (cursor) stored in PostgreSQLSession.portals. The code that creates
// and executes portals is outside this view.
type Portal struct {
Name string
// Statement is a string; presumably the name of the prepared statement the
// portal was bound to — TODO confirm against the Bind handler.
Statement string
// Parameters presumably holds the raw bound parameter values — TODO confirm.
Parameters [][]byte
Suspended bool
}
// FieldDescription describes one column, as referenced by
// PreparedStatement.Fields. NOTE(review): the fields look like the per-column
// entries of a PostgreSQL RowDescription message — confirm against the code
// that serializes them, which is outside this view.
type FieldDescription struct {
Name string
TableOID uint32
AttrNum int16
TypeOID uint32
TypeSize int16
TypeMod int32
Format int16
}
// NewPostgreSQLServer creates a PostgreSQL protocol server that is ready to be
// started with Start.
//
// config must not be nil; a nil config yields an error instead of a panic.
// Unset fields are defaulted in place on the caller's struct: Port 5432,
// Host "localhost", Database "default", MaxConns 100, IdleTimeout one hour and
// StartupTimeout 30 seconds. masterAddr is passed to engine.NewSQLEngine to
// create the SQL engine used by the server.
func NewPostgreSQLServer(config *PostgreSQLServerConfig, masterAddr string) (*PostgreSQLServer, error) {
	if config == nil {
		return nil, fmt.Errorf("postgres server config must not be nil")
	}

	// Apply defaults for everything the caller left unset.
	if config.Port <= 0 {
		config.Port = DEFAULT_POSTGRES_PORT
	}
	if config.Host == "" {
		config.Host = "localhost"
	}
	if config.Database == "" {
		config.Database = "default"
	}
	if config.MaxConns <= 0 {
		config.MaxConns = 100
	}
	if config.IdleTimeout <= 0 {
		config.IdleTimeout = time.Hour
	}
	if config.StartupTimeout <= 0 {
		config.StartupTimeout = 30 * time.Second
	}

	return &PostgreSQLServer{
		config:     config,
		sqlEngine:  engine.NewSQLEngine(masterAddr),
		sessions:   make(map[uint32]*PostgreSQLSession),
		shutdown:   make(chan struct{}),
		nextConnID: 1, // connection IDs start at 1
	}, nil
}
// Start opens the listening socket and launches the background goroutines.
//
// The listener uses TLS when config.TLSConfig is non-nil and plain TCP
// otherwise. On success two goroutines are started and tracked by s.wg: the
// accept loop and the idle-session cleanup routine. An error is returned, and
// nothing is started, when the socket cannot be opened.
func (s *PostgreSQLServer) Start() error {
	// net.JoinHostPort brackets IPv6 literals ("::1" -> "[::1]:5432"); a plain
	// "%s:%d" would produce an address net.Listen rejects. Brackets the
	// operator already supplied are stripped first so both spellings work.
	host := strings.TrimSuffix(strings.TrimPrefix(s.config.Host, "["), "]")
	addr := net.JoinHostPort(host, fmt.Sprint(s.config.Port))

	useTLS := s.config.TLSConfig != nil

	var (
		listener net.Listener
		err      error
	)
	if useTLS {
		listener, err = tls.Listen("tcp", addr, s.config.TLSConfig)
	} else {
		listener, err = net.Listen("tcp", addr)
	}
	if err != nil {
		return fmt.Errorf("failed to start PostgreSQL server on %s: %w", addr, err)
	}
	s.listener = listener

	// Only announce the listener once it actually exists.
	if useTLS {
		glog.Infof("PostgreSQL Server with TLS listening on %s", addr)
	} else {
		glog.Infof("PostgreSQL Server listening on %s", addr)
	}

	// One slot each for the accept loop and the cleanup routine.
	s.wg.Add(2)
	go s.acceptConnections()
	go s.cleanupSessions()

	return nil
}
// Stop shuts the server down and waits for all of its goroutines to finish.
//
// It closes the shutdown channel, closes the listener (if Start created one),
// closes every registered session and then blocks on s.wg. Calling Stop again
// after a shutdown has begun is a no-op that returns nil instead of panicking
// on a second close of the channel. The returned error is always nil; a
// failure to close the listener is logged.
func (s *PostgreSQLServer) Stop() error {
	// Serialize the "already stopped?" check with the close so that concurrent
	// or repeated Stop calls cannot close the channel twice.
	s.sessionMux.Lock()
	select {
	case <-s.shutdown:
		s.sessionMux.Unlock()
		return nil
	default:
		close(s.shutdown)
	}
	s.sessionMux.Unlock()

	// Closing the listener makes a blocked Accept return in the accept loop.
	if s.listener != nil {
		if err := s.listener.Close(); err != nil {
			glog.Warningf("Failed to close PostgreSQL listener: %v", err)
		}
	}

	// Close all sessions so handlers blocked on a read return promptly.
	s.sessionMux.Lock()
	for _, session := range s.sessions {
		session.close()
	}
	s.sessions = make(map[uint32]*PostgreSQLSession)
	s.sessionMux.Unlock()

	s.wg.Wait()
	glog.Infof("PostgreSQL Server stopped")
	return nil
}
// acceptConnections runs the accept loop until the server is shut down.
//
// Every accepted connection is handed to handleConnection on its own
// goroutine (tracked by s.wg), unless the number of registered sessions has
// already reached config.MaxConns, in which case the connection is closed
// right away. When Accept fails for a reason other than shutdown, the loop
// waits with a growing delay (5ms doubling up to 1s) before retrying, so a
// persistent failure such as file-descriptor exhaustion neither spins the CPU
// nor floods the log.
func (s *PostgreSQLServer) acceptConnections() {
	defer s.wg.Done()

	const (
		minAcceptDelay = 5 * time.Millisecond
		maxAcceptDelay = time.Second
	)
	var retryDelay time.Duration

	for {
		select {
		case <-s.shutdown:
			return
		default:
		}

		conn, err := s.listener.Accept()
		if err != nil {
			// Stop closes the listener, so an error during shutdown is expected.
			select {
			case <-s.shutdown:
				return
			default:
			}
			glog.Errorf("Failed to accept PostgreSQL connection: %v", err)

			switch {
			case retryDelay == 0:
				retryDelay = minAcceptDelay
			case retryDelay*2 > maxAcceptDelay:
				retryDelay = maxAcceptDelay
			default:
				retryDelay *= 2
			}
			// Back off, but remain responsive to shutdown while waiting.
			select {
			case <-s.shutdown:
				return
			case <-time.After(retryDelay):
			}
			continue
		}
		retryDelay = 0

		// Check connection limit. The session is registered later by the
		// handler goroutine, so this is a best-effort limit.
		s.sessionMux.RLock()
		sessionCount := len(s.sessions)
		s.sessionMux.RUnlock()
		if sessionCount >= s.config.MaxConns {
			glog.Warningf("Maximum connections reached (%d), rejecting connection from %s",
				s.config.MaxConns, conn.RemoteAddr())
			conn.Close()
			continue
		}

		s.wg.Add(1)
		go s.handleConnection(conn)
	}
}
// handleConnection serves one client connection from startup to disconnect.
//
// It creates and registers a session under a fresh connection ID, runs the
// startup handshake and then processes messages one at a time until the
// client disconnects, an error occurs or the server shuts down. The session
// is unregistered and the connection closed on return.
//
// The read deadline for each message is config.IdleTimeout, so an idle client
// is dropped after the configured idle period rather than after a fixed 30
// seconds (30 seconds is kept only as a fallback for a non-positive
// IdleTimeout).
func (s *PostgreSQLServer) handleConnection(conn net.Conn) {
	defer s.wg.Done()
	defer conn.Close()

	// Generate unique connection ID and the per-connection secret.
	connID := s.generateConnectionID()
	secretKey := s.generateSecretKey()

	now := time.Now()
	session := &PostgreSQLSession{
		conn:             conn,
		reader:           bufio.NewReader(conn),
		writer:           bufio.NewWriter(conn),
		authenticated:    false,
		database:         s.config.Database,
		parameters:       make(map[string]string),
		preparedStmts:    make(map[string]*PreparedStatement),
		portals:          make(map[string]*Portal),
		transactionState: PG_TRANS_IDLE,
		processID:        connID,
		secretKey:        secretKey,
		created:          now,
		lastActivity:     now,
	}

	// Register the session and make sure it is removed again on exit.
	s.sessionMux.Lock()
	s.sessions[connID] = session
	s.sessionMux.Unlock()
	defer func() {
		s.sessionMux.Lock()
		delete(s.sessions, connID)
		s.sessionMux.Unlock()
	}()

	glog.V(2).Infof("New PostgreSQL connection from %s (ID: %d)", conn.RemoteAddr(), connID)

	// Handle startup. Expected failure modes are logged at a lower severity;
	// they are recognised by the wording handleStartup puts into its errors.
	if err := s.handleStartup(session); err != nil {
		switch reason := err.Error(); {
		case strings.Contains(reason, "client disconnected"):
			glog.V(1).Infof("Client startup disconnected from %s (ID: %d): %v", conn.RemoteAddr(), connID, err)
		case strings.Contains(reason, "timeout"):
			glog.Warningf("Startup timeout for connection %d from %s: %v", connID, conn.RemoteAddr(), err)
		default:
			glog.Errorf("Startup failed for connection %d from %s: %v", connID, conn.RemoteAddr(), err)
		}
		return
	}

	readTimeout := s.config.IdleTimeout
	if readTimeout <= 0 {
		readTimeout = 30 * time.Second
	}

	// Handle messages
	for {
		select {
		case <-s.shutdown:
			return
		default:
		}

		// Bound how long we wait for the client's next message.
		conn.SetReadDeadline(time.Now().Add(readTimeout))
		if err := s.handleMessage(session); err != nil {
			if err == io.EOF {
				glog.Infof("PostgreSQL client disconnected (ID: %d)", connID)
			} else {
				glog.Errorf("Error handling PostgreSQL message (ID: %d): %v", connID, err)
			}
			return
		}
		// NOTE(review): lastActivity is also read by the idle-cleanup goroutine
		// without synchronisation.
		session.lastActivity = time.Now()
	}
}
// handleStartup performs the PostgreSQL startup sequence for a new session.
//
// It reads startup packets until it receives a protocol 3.0 StartupMessage,
// answering any SSLRequest or GSSAPI request with 'N' (not supported). The
// name/value pairs of the StartupMessage are stored in session.parameters,
// with "user" and "database" also copied to session.username and
// session.database. It then runs the configured authentication, sends the
// parameter status messages, the backend key data and ReadyForQuery, and
// marks the session as authenticated.
//
// All reads are bounded by config.StartupTimeout; the deadline is cleared on
// return. Errors for a client that hung up contain "client disconnected" and
// errors for an expired deadline contain "timeout"; the caller relies on that
// wording to choose a log level.
//
// A packet whose length word is smaller than 8 (4 bytes for the length itself
// plus 4 bytes for the protocol/request code) is rejected, because it cannot
// hold the code that is read next.
func (s *PostgreSQLServer) handleStartup(session *PostgreSQLSession) error {
	const (
		minStartupLength  = 8     // length word + protocol/request code
		maxStartupPayload = 10000 // Reasonable limit for startup messages
	)

	// close() may clear session.conn from another goroutine (shutdown or idle
	// cleanup), so take one snapshot under the same mutex and use it throughout.
	session.mutex.Lock()
	conn := session.conn
	session.mutex.Unlock()
	if conn == nil {
		return fmt.Errorf("session closed before startup handshake")
	}

	// Set a startup timeout to prevent hanging connections.
	startupTimeout := s.config.StartupTimeout
	conn.SetReadDeadline(time.Now().Add(startupTimeout))
	defer conn.SetReadDeadline(time.Time{}) // Clear timeout

	var paramBytes []byte
handshake:
	for {
		// Read startup message length
		var header [4]byte
		if _, err := io.ReadFull(session.reader, header[:]); err != nil {
			if err == io.EOF {
				// Client disconnected during startup - this is common for health checks
				return fmt.Errorf("client disconnected during startup handshake")
			}
			if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
				return fmt.Errorf("startup handshake timeout after %v", startupTimeout)
			}
			return fmt.Errorf("failed to read message length during startup: %w", err)
		}

		declared := binary.BigEndian.Uint32(header[:])
		if declared < minStartupLength {
			return fmt.Errorf("invalid startup message length: %d", declared)
		}
		msgLength := declared - 4
		if msgLength > maxStartupPayload {
			return fmt.Errorf("startup message too large: %d bytes", msgLength)
		}

		// Read startup message content
		msg := make([]byte, msgLength)
		if _, err := io.ReadFull(session.reader, msg); err != nil {
			if err == io.EOF {
				return fmt.Errorf("client disconnected while reading startup message")
			}
			if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
				return fmt.Errorf("startup message read timeout")
			}
			return fmt.Errorf("failed to read startup message: %w", err)
		}

		// The first four bytes are the protocol version or a request code.
		switch protocolVersion := binary.BigEndian.Uint32(msg[:4]); protocolVersion {
		case PG_SSL_REQUEST:
			// Reject SSL request - 'N' tells the client SSL is not supported;
			// it then sends another startup packet, read on the next iteration.
			if _, err := conn.Write([]byte{'N'}); err != nil {
				return fmt.Errorf("failed to reject SSL request: %w", err)
			}
		case PG_GSSAPI_REQUEST:
			// Reject GSSAPI request the same way.
			if _, err := conn.Write([]byte{'N'}); err != nil {
				return fmt.Errorf("failed to reject GSSAPI request: %w", err)
			}
		case PG_PROTOCOL_VERSION_3:
			// This is the actual startup message.
			paramBytes = msg[4:]
			break handshake
		default:
			return fmt.Errorf("unsupported protocol version: %d", protocolVersion)
		}
	}

	// Parse parameters: NUL-terminated name/value pairs; an empty name is the
	// list terminator and is not stored.
	fields := strings.Split(string(paramBytes), "\x00")
	for i := 0; i+1 < len(fields); i += 2 {
		name, value := fields[i], fields[i+1]
		if name == "" {
			break
		}
		switch name {
		case "user":
			session.username = value
		case "database":
			session.database = value
		}
		session.parameters[name] = value
	}

	// Handle authentication
	if err := s.handleAuthentication(session); err != nil {
		return err
	}

	// Send parameter status messages
	serverParams := [...][2]string{
		{"server_version", fmt.Sprintf("%s (SeaweedFS)", version.VERSION_NUMBER)},
		{"server_encoding", "UTF8"},
		{"client_encoding", "UTF8"},
		{"DateStyle", "ISO, MDY"},
		{"integer_datetimes", "on"},
	}
	for _, p := range serverParams {
		if err := s.sendParameterStatus(session, p[0], p[1]); err != nil {
			return err
		}
	}

	// Send backend key data
	if err := s.sendBackendKeyData(session); err != nil {
		return err
	}

	// Send ready for query
	if err := s.sendReadyForQuery(session); err != nil {
		return err
	}

	session.authenticated = true
	return nil
}
// handleAuthentication runs the authentication exchange selected by
// config.AuthMethod: trust sends AuthenticationOk straight away, while the
// password and MD5 methods delegate to their handlers. Any other method value
// yields an error.
func (s *PostgreSQLServer) handleAuthentication(session *PostgreSQLSession) error {
	var authenticate func(*PostgreSQLSession) error

	switch s.config.AuthMethod {
	case AuthTrust:
		authenticate = s.sendAuthenticationOk
	case AuthPassword:
		authenticate = s.handlePasswordAuth
	case AuthMD5:
		authenticate = s.handleMD5Auth
	default:
		return fmt.Errorf("unsupported authentication method")
	}

	return authenticate(session)
}
// sendAuthenticationOk writes the 9-byte AuthenticationOk message (type 'R',
// length 8, code AUTH_OK) to the session's writer and flushes it. The first
// write or flush error is returned.
func (s *PostgreSQLServer) sendAuthenticationOk(session *PostgreSQLSession) error {
	var frame [9]byte
	frame[0] = PG_RESP_AUTH_OK
	binary.BigEndian.PutUint32(frame[1:5], 8) // length counts itself plus the code
	binary.BigEndian.PutUint32(frame[5:9], AUTH_OK)

	if _, err := session.writer.Write(frame[:]); err != nil {
		return err
	}
	return session.writer.Flush()
}
// handlePasswordAuth performs clear-text password authentication.
//
// It sends an AuthenticationCleartextPassword request, reads the client's
// password message and compares the password with the one configured for
// session.username. On a match AuthenticationOk is sent and nil is returned.
// On an unknown user or a mismatch the client is sent error 28P01 and a
// non-nil error is returned, so the caller aborts the startup sequence.
func (s *PostgreSQLServer) handlePasswordAuth(session *PostgreSQLSession) error {
	// Send password request
	var request [9]byte
	request[0] = PG_RESP_AUTH_OK // 'R' is the type byte of every authentication message
	binary.BigEndian.PutUint32(request[1:5], 8)
	binary.BigEndian.PutUint32(request[5:9], AUTH_CLEAR)
	if _, err := session.writer.Write(request[:]); err != nil {
		return err
	}
	if err := session.writer.Flush(); err != nil {
		return err
	}

	// Read password response
	password, err := session.readPasswordResponse()
	if err != nil {
		return err
	}

	// Verify password
	expectedPassword, exists := s.config.Users[session.username]
	if !exists || password != expectedPassword {
		return s.rejectAuthentication(session)
	}
	return s.sendAuthenticationOk(session)
}

// handleMD5Auth performs MD5 password authentication.
//
// It sends an AuthenticationMD5Password request carrying a fresh random
// 4-byte salt, reads the client's response and compares it with
// "md5" + hex(md5(hex(md5(password + username)) + salt)) computed from the
// configured password. On a match AuthenticationOk is sent and nil is
// returned. On an unknown user or a mismatch the client is sent error 28P01
// and a non-nil error is returned, so the caller aborts the startup sequence.
func (s *PostgreSQLServer) handleMD5Auth(session *PostgreSQLSession) error {
	// Generate salt
	var salt [4]byte
	if _, err := rand.Read(salt[:]); err != nil {
		return err
	}

	// Send MD5 request
	var request [13]byte
	request[0] = PG_RESP_AUTH_OK
	binary.BigEndian.PutUint32(request[1:5], 12)
	binary.BigEndian.PutUint32(request[5:9], AUTH_MD5)
	copy(request[9:13], salt[:])
	if _, err := session.writer.Write(request[:]); err != nil {
		return err
	}
	if err := session.writer.Flush(); err != nil {
		return err
	}

	// Read password response. This happens before the user lookup so that the
	// exchange looks the same for known and unknown users.
	response, err := session.readPasswordResponse()
	if err != nil {
		return err
	}

	// Verify MD5 hash
	expectedPassword, exists := s.config.Users[session.username]
	if !exists {
		return s.rejectAuthentication(session)
	}
	// Calculate expected hash: md5(md5(password + username) + salt)
	inner := md5.Sum([]byte(expectedPassword + session.username))
	outer := md5.Sum(append([]byte(fmt.Sprintf("%x", inner)), salt[:]...))
	if response != fmt.Sprintf("md5%x", outer) {
		return s.rejectAuthentication(session)
	}
	return s.sendAuthenticationOk(session)
}

// readPasswordResponse reads one password message ('p') from the client and
// returns its payload with the trailing NUL terminator, if present, removed.
//
// The declared length is validated before anything is allocated: it must
// cover its own four bytes and the payload may not exceed 10000 bytes. This
// rejects a length below 4, which would otherwise wrap around to a multi-
// gigabyte allocation, and tolerates an empty payload.
func (s *PostgreSQLSession) readPasswordResponse() (string, error) {
	const maxPasswordPayload = 10000

	var header [5]byte
	if _, err := io.ReadFull(s.reader, header[:1]); err != nil {
		return "", err
	}
	if header[0] != PG_MSG_PASSWORD {
		return "", fmt.Errorf("expected password message, got %c", header[0])
	}
	if _, err := io.ReadFull(s.reader, header[1:]); err != nil {
		return "", err
	}

	declared := binary.BigEndian.Uint32(header[1:])
	if declared < 4 || declared-4 > maxPasswordPayload {
		return "", fmt.Errorf("invalid password message length: %d", declared)
	}

	payload := make([]byte, declared-4)
	if _, err := io.ReadFull(s.reader, payload); err != nil {
		return "", err
	}
	// Remove null terminator
	if n := len(payload); n > 0 && payload[n-1] == 0 {
		payload = payload[:n-1]
	}
	return string(payload), nil
}

// rejectAuthentication reports a failed login to the client (SQLSTATE 28P01)
// and always returns a non-nil error.
//
// sendError is defined outside this view; presumably its result only says
// whether the error response could be written. Returning that result directly
// would let a failed login continue as if it had succeeded whenever the write
// worked, so the failure is reported to the caller explicitly.
func (s *PostgreSQLServer) rejectAuthentication(session *PostgreSQLSession) error {
	authErr := fmt.Errorf("authentication failed for user %q", session.username)
	if err := s.sendError(session, "28P01", "authentication failed for user \""+session.username+"\""); err != nil {
		return fmt.Errorf("%v (sending error response: %w)", authErr, err)
	}
	return authErr
}
// generateConnectionID hands out the next connection ID. The counter is read
// and advanced while holding sessionMux, so concurrent callers receive
// distinct values.
func (s *PostgreSQLServer) generateConnectionID() uint32 {
	s.sessionMux.Lock()
	assigned := s.nextConnID
	s.nextConnID = assigned + 1
	s.sessionMux.Unlock()

	return assigned
}
// generateSecretKey returns a 32-bit key for the connection, read from
// crypto/rand as a big-endian value. A failure of the random source is not
// reported; the key is then built from whatever the buffer holds.
func (s *PostgreSQLServer) generateSecretKey() uint32 {
	var raw [4]byte
	rand.Read(raw[:])
	return binary.BigEndian.Uint32(raw[:])
}
// close closes the session's connection, if it still has one, and clears the
// conn field so that later calls do nothing. The session mutex is held for
// the duration.
func (s *PostgreSQLSession) close() {
	s.mutex.Lock()
	defer s.mutex.Unlock()

	if s.conn == nil {
		return // already closed
	}
	s.conn.Close()
	s.conn = nil
}
// cleanupSessions runs cleanupIdleSessions once a minute until the shutdown
// channel is closed. It is started by Start and releases its s.wg slot on
// return.
func (s *PostgreSQLServer) cleanupSessions() {
	defer s.wg.Done()

	sweep := time.NewTicker(time.Minute)
	defer sweep.Stop()

	for {
		select {
		case <-sweep.C:
			s.cleanupIdleSessions()
		case <-s.shutdown:
			return
		}
	}
}
// cleanupIdleSessions closes and unregisters every session whose last
// activity lies further back than config.IdleTimeout. The session map is
// locked for the whole sweep.
// NOTE(review): lastActivity is read here without holding the session's own
// mutex.
func (s *PostgreSQLServer) cleanupIdleSessions() {
	now := time.Now()

	s.sessionMux.Lock()
	defer s.sessionMux.Unlock()

	for id, session := range s.sessions {
		idleFor := now.Sub(session.lastActivity)
		if idleFor <= s.config.IdleTimeout {
			continue
		}
		glog.Infof("Closing idle PostgreSQL session %d", id)
		session.close()
		delete(s.sessions, id)
	}
}
// GetAddress returns the configured listen address as "host:port".
//
// The address is built with net.JoinHostPort, so an IPv6 literal host is
// bracketed ("::1" becomes "[::1]:5432") and the result can be passed to
// net.Dial. Brackets already present in the configured host are stripped
// first so they are not doubled. For host names and IPv4 addresses the result
// is the same as a plain "host:port" concatenation.
func (s *PostgreSQLServer) GetAddress() string {
	host := strings.TrimSuffix(strings.TrimPrefix(s.config.Host, "["), "]")
	return net.JoinHostPort(host, fmt.Sprint(s.config.Port))
}