|
|
@ -204,27 +204,31 @@ const ( |
|
|
|
// ParseSQL uses PostgreSQL parser to parse SQL statements
|
|
|
|
func ParseSQL(sql string) (Statement, error) { |
|
|
|
sql = strings.TrimSpace(sql) |
|
|
|
sqlUpper := strings.ToUpper(sql) |
|
|
|
|
|
|
|
// Parse with pg_query_go
|
|
|
|
result, err := pg_query.Parse(sql) |
|
|
|
if err != nil { |
|
|
|
return nil, fmt.Errorf("PostgreSQL parse error: %v", err) |
|
|
|
// Handle SHOW statements first (before PostgreSQL parser since pg doesn't support SHOW)
|
|
|
|
if strings.HasPrefix(sqlUpper, "SHOW DATABASES") || strings.HasPrefix(sqlUpper, "SHOW SCHEMAS") { |
|
|
|
return &ShowStatement{Type: "databases"}, nil |
|
|
|
} |
|
|
|
if strings.HasPrefix(sqlUpper, "SHOW TABLES") { |
|
|
|
stmt := &ShowStatement{Type: "tables"} |
|
|
|
// Handle "SHOW TABLES FROM database" syntax
|
|
|
|
if strings.Contains(sqlUpper, "FROM") { |
|
|
|
partsUpper := strings.Fields(sqlUpper) |
|
|
|
partsOriginal := strings.Fields(sql) // Use original casing
|
|
|
|
for i, part := range partsUpper { |
|
|
|
if part == "FROM" && i+1 < len(partsOriginal) { |
|
|
|
// Remove quotes if present (PostgreSQL uses double quotes, MySQL uses backticks)
|
|
|
|
dbName := strings.Trim(partsOriginal[i+1], "\"'`") |
|
|
|
stmt.OnTable.Name = stringValue(dbName) |
|
|
|
break |
|
|
|
} |
|
|
|
|
|
|
|
if len(result.Stmts) == 0 { |
|
|
|
return nil, fmt.Errorf("no statements parsed") |
|
|
|
} |
|
|
|
|
|
|
|
// Convert first statement
|
|
|
|
stmt := result.Stmts[0] |
|
|
|
|
|
|
|
// Handle SELECT statements
|
|
|
|
if selectStmt := stmt.Stmt.GetSelectStmt(); selectStmt != nil { |
|
|
|
return convertSelectStatement(selectStmt), nil |
|
|
|
} |
|
|
|
return stmt, nil |
|
|
|
} |
|
|
|
|
|
|
|
// Handle DDL statements by parsing SQL text patterns
|
|
|
|
sqlUpper := strings.ToUpper(sql) |
|
|
|
// Handle DDL statements by parsing SQL text patterns (before PostgreSQL parser)
|
|
|
|
if strings.HasPrefix(sqlUpper, "CREATE TABLE") { |
|
|
|
return parseCreateTableFromSQL(sql) |
|
|
|
} |
|
|
@ -235,12 +239,22 @@ func ParseSQL(sql string) (Statement, error) { |
|
|
|
return parseAlterTableFromSQL(sql) |
|
|
|
} |
|
|
|
|
|
|
|
// Handle SHOW statements
|
|
|
|
if strings.HasPrefix(sqlUpper, "SHOW DATABASES") || strings.HasPrefix(sqlUpper, "SHOW SCHEMAS") { |
|
|
|
return &ShowStatement{Type: "databases"}, nil |
|
|
|
// Parse with pg_query_go for SELECT and other standard PostgreSQL statements
|
|
|
|
result, err := pg_query.Parse(sql) |
|
|
|
if err != nil { |
|
|
|
return nil, fmt.Errorf("PostgreSQL parse error: %v", err) |
|
|
|
} |
|
|
|
if strings.HasPrefix(sqlUpper, "SHOW TABLES") { |
|
|
|
return &ShowStatement{Type: "tables"}, nil |
|
|
|
|
|
|
|
if len(result.Stmts) == 0 { |
|
|
|
return nil, fmt.Errorf("no statements parsed") |
|
|
|
} |
|
|
|
|
|
|
|
// Convert first statement
|
|
|
|
stmt := result.Stmts[0] |
|
|
|
|
|
|
|
// Handle SELECT statements
|
|
|
|
if selectStmt := stmt.Stmt.GetSelectStmt(); selectStmt != nil { |
|
|
|
return convertSelectStatement(selectStmt), nil |
|
|
|
} |
|
|
|
|
|
|
|
return nil, fmt.Errorf("unsupported statement type") |
|
|
@ -282,9 +296,61 @@ func convertSelectStatement(stmt *pg_query.SelectStmt) *SelectStatement { |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
// Convert WHERE clause
|
|
|
|
if stmt.GetWhereClause() != nil { |
|
|
|
s.Where = &WhereClause{ |
|
|
|
Expr: convertExpressionNode(stmt.GetWhereClause()), |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
// Convert LIMIT clause
|
|
|
|
if stmt.GetLimitCount() != nil { |
|
|
|
s.Limit = &LimitClause{ |
|
|
|
Rowcount: convertExpressionNode(stmt.GetLimitCount()), |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
return s |
|
|
|
} |
|
|
|
|
|
|
|
// convertExpressionNode converts PostgreSQL parser expression nodes to our internal ExprNode types
|
|
|
|
func convertExpressionNode(node *pg_query.Node) ExprNode { |
|
|
|
if node == nil { |
|
|
|
return nil |
|
|
|
} |
|
|
|
|
|
|
|
// Handle different expression types
|
|
|
|
if aConst := node.GetAConst(); aConst != nil { |
|
|
|
// Handle constants (numbers, strings)
|
|
|
|
if aConst.GetIval() != nil { |
|
|
|
return &SQLVal{ |
|
|
|
Type: IntVal, |
|
|
|
Val: []byte(fmt.Sprintf("%d", aConst.GetIval().GetIval())), |
|
|
|
} |
|
|
|
} |
|
|
|
if aConst.GetSval() != nil { |
|
|
|
return &SQLVal{ |
|
|
|
Type: StrVal, |
|
|
|
Val: []byte(aConst.GetSval().GetSval()), |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
if columnRef := node.GetColumnRef(); columnRef != nil { |
|
|
|
// Handle column references
|
|
|
|
return &ColName{ |
|
|
|
Name: stringValue("column"), // Simplified - would need more complex parsing
|
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
// For now, return a simple placeholder for other expression types
|
|
|
|
// In a full implementation, we'd handle all PostgreSQL expression types
|
|
|
|
return &SQLVal{ |
|
|
|
Type: StrVal, |
|
|
|
Val: []byte(""), |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
func parseCreateTableFromSQL(sql string) (*DDLStatement, error) { |
|
|
|
parts := strings.Fields(sql) |
|
|
|
if len(parts) < 3 { |
|
|
|