Browse Source

fix tests

pull/7185/head
chrislu 1 month ago
parent
commit
991247facf
  1. 108
      weed/query/engine/engine.go
  2. 4
      weed/query/engine/real_namespace_test.go

108
weed/query/engine/engine.go

@ -204,27 +204,31 @@ const (
// ParseSQL uses PostgreSQL parser to parse SQL statements
func ParseSQL(sql string) (Statement, error) {
sql = strings.TrimSpace(sql)
sqlUpper := strings.ToUpper(sql)
// Parse with pg_query_go
result, err := pg_query.Parse(sql)
if err != nil {
return nil, fmt.Errorf("PostgreSQL parse error: %v", err)
// Handle SHOW statements first (before PostgreSQL parser since pg doesn't support SHOW)
if strings.HasPrefix(sqlUpper, "SHOW DATABASES") || strings.HasPrefix(sqlUpper, "SHOW SCHEMAS") {
return &ShowStatement{Type: "databases"}, nil
}
if strings.HasPrefix(sqlUpper, "SHOW TABLES") {
stmt := &ShowStatement{Type: "tables"}
// Handle "SHOW TABLES FROM database" syntax
if strings.Contains(sqlUpper, "FROM") {
partsUpper := strings.Fields(sqlUpper)
partsOriginal := strings.Fields(sql) // Use original casing
for i, part := range partsUpper {
if part == "FROM" && i+1 < len(partsOriginal) {
// Remove quotes if present (PostgreSQL uses double quotes, MySQL uses backticks)
dbName := strings.Trim(partsOriginal[i+1], "\"'`")
stmt.OnTable.Name = stringValue(dbName)
break
}
if len(result.Stmts) == 0 {
return nil, fmt.Errorf("no statements parsed")
}
// Convert first statement
stmt := result.Stmts[0]
// Handle SELECT statements
if selectStmt := stmt.Stmt.GetSelectStmt(); selectStmt != nil {
return convertSelectStatement(selectStmt), nil
}
return stmt, nil
}
// Handle DDL statements by parsing SQL text patterns
sqlUpper := strings.ToUpper(sql)
// Handle DDL statements by parsing SQL text patterns (before PostgreSQL parser)
if strings.HasPrefix(sqlUpper, "CREATE TABLE") {
return parseCreateTableFromSQL(sql)
}
@ -235,12 +239,22 @@ func ParseSQL(sql string) (Statement, error) {
return parseAlterTableFromSQL(sql)
}
// Handle SHOW statements
if strings.HasPrefix(sqlUpper, "SHOW DATABASES") || strings.HasPrefix(sqlUpper, "SHOW SCHEMAS") {
return &ShowStatement{Type: "databases"}, nil
// Parse with pg_query_go for SELECT and other standard PostgreSQL statements
result, err := pg_query.Parse(sql)
if err != nil {
return nil, fmt.Errorf("PostgreSQL parse error: %v", err)
}
if strings.HasPrefix(sqlUpper, "SHOW TABLES") {
return &ShowStatement{Type: "tables"}, nil
if len(result.Stmts) == 0 {
return nil, fmt.Errorf("no statements parsed")
}
// Convert first statement
stmt := result.Stmts[0]
// Handle SELECT statements
if selectStmt := stmt.Stmt.GetSelectStmt(); selectStmt != nil {
return convertSelectStatement(selectStmt), nil
}
return nil, fmt.Errorf("unsupported statement type")
@ -282,9 +296,61 @@ func convertSelectStatement(stmt *pg_query.SelectStmt) *SelectStatement {
}
}
// Convert WHERE clause
if stmt.GetWhereClause() != nil {
s.Where = &WhereClause{
Expr: convertExpressionNode(stmt.GetWhereClause()),
}
}
// Convert LIMIT clause
if stmt.GetLimitCount() != nil {
s.Limit = &LimitClause{
Rowcount: convertExpressionNode(stmt.GetLimitCount()),
}
}
return s
}
// convertExpressionNode converts PostgreSQL parser expression nodes to our internal ExprNode types
func convertExpressionNode(node *pg_query.Node) ExprNode {
if node == nil {
return nil
}
// Handle different expression types
if aConst := node.GetAConst(); aConst != nil {
// Handle constants (numbers, strings)
if aConst.GetIval() != nil {
return &SQLVal{
Type: IntVal,
Val: []byte(fmt.Sprintf("%d", aConst.GetIval().GetIval())),
}
}
if aConst.GetSval() != nil {
return &SQLVal{
Type: StrVal,
Val: []byte(aConst.GetSval().GetSval()),
}
}
}
if columnRef := node.GetColumnRef(); columnRef != nil {
// Handle column references
return &ColName{
Name: stringValue("column"), // Simplified - would need more complex parsing
}
}
// For now, return a simple placeholder for other expression types
// In a full implementation, we'd handle all PostgreSQL expression types
return &SQLVal{
Type: StrVal,
Val: []byte(""),
}
}
func parseCreateTableFromSQL(sql string) (*DDLStatement, error) {
parts := strings.Fields(sql)
if len(parts) < 3 {

4
weed/query/engine/real_namespace_test.go

@ -37,8 +37,8 @@ func TestRealNamespaceDiscovery(t *testing.T) {
func TestRealTopicDiscovery(t *testing.T) {
engine := NewSQLEngine("localhost:8888")
// Test SHOW TABLES with real topic discovery (use backticks for reserved keyword)
result, err := engine.ExecuteSQL(context.Background(), "SHOW TABLES FROM `default`")
// Test SHOW TABLES with real topic discovery (use double quotes for PostgreSQL)
result, err := engine.ExecuteSQL(context.Background(), "SHOW TABLES FROM \"default\"")
if err != nil {
t.Fatalf("SHOW TABLES failed: %v", err)
}

Loading…
Cancel
Save