4 changed files with 537 additions and 9 deletions
-
21test/postgres/client.go
-
2test/postgres/docker-compose.yml
-
57weed/query/engine/engine.go
-
466weed/query/engine/query_parsing_test.go
@ -0,0 +1,466 @@ |
|||||
|
package engine |
||||
|
|
||||
|
import ( |
||||
|
"testing" |
||||
|
) |
||||
|
|
||||
|
func TestParseSQL_COUNT_Functions(t *testing.T) { |
||||
|
tests := []struct { |
||||
|
name string |
||||
|
sql string |
||||
|
wantErr bool |
||||
|
validate func(t *testing.T, stmt Statement) |
||||
|
}{ |
||||
|
{ |
||||
|
name: "COUNT(*) basic", |
||||
|
sql: "SELECT COUNT(*) FROM test_table", |
||||
|
wantErr: false, |
||||
|
validate: func(t *testing.T, stmt Statement) { |
||||
|
selectStmt, ok := stmt.(*SelectStatement) |
||||
|
if !ok { |
||||
|
t.Fatalf("Expected *SelectStatement, got %T", stmt) |
||||
|
} |
||||
|
|
||||
|
if len(selectStmt.SelectExprs) != 1 { |
||||
|
t.Fatalf("Expected 1 select expression, got %d", len(selectStmt.SelectExprs)) |
||||
|
} |
||||
|
|
||||
|
aliasedExpr, ok := selectStmt.SelectExprs[0].(*AliasedExpr) |
||||
|
if !ok { |
||||
|
t.Fatalf("Expected *AliasedExpr, got %T", selectStmt.SelectExprs[0]) |
||||
|
} |
||||
|
|
||||
|
funcExpr, ok := aliasedExpr.Expr.(*FuncExpr) |
||||
|
if !ok { |
||||
|
t.Fatalf("Expected *FuncExpr, got %T", aliasedExpr.Expr) |
||||
|
} |
||||
|
|
||||
|
if funcExpr.Name.String() != "COUNT" { |
||||
|
t.Errorf("Expected function name 'COUNT', got '%s'", funcExpr.Name.String()) |
||||
|
} |
||||
|
|
||||
|
if len(funcExpr.Exprs) != 1 { |
||||
|
t.Fatalf("Expected 1 function argument, got %d", len(funcExpr.Exprs)) |
||||
|
} |
||||
|
|
||||
|
starExpr, ok := funcExpr.Exprs[0].(*StarExpr) |
||||
|
if !ok { |
||||
|
t.Errorf("Expected *StarExpr argument, got %T", funcExpr.Exprs[0]) |
||||
|
} |
||||
|
_ = starExpr // Use the variable to avoid unused variable error
|
||||
|
}, |
||||
|
}, |
||||
|
{ |
||||
|
name: "COUNT(column_name)", |
||||
|
sql: "SELECT COUNT(user_id) FROM users", |
||||
|
wantErr: false, |
||||
|
validate: func(t *testing.T, stmt Statement) { |
||||
|
selectStmt, ok := stmt.(*SelectStatement) |
||||
|
if !ok { |
||||
|
t.Fatalf("Expected *SelectStatement, got %T", stmt) |
||||
|
} |
||||
|
|
||||
|
aliasedExpr := selectStmt.SelectExprs[0].(*AliasedExpr) |
||||
|
funcExpr := aliasedExpr.Expr.(*FuncExpr) |
||||
|
|
||||
|
if funcExpr.Name.String() != "COUNT" { |
||||
|
t.Errorf("Expected function name 'COUNT', got '%s'", funcExpr.Name.String()) |
||||
|
} |
||||
|
|
||||
|
if len(funcExpr.Exprs) != 1 { |
||||
|
t.Fatalf("Expected 1 function argument, got %d", len(funcExpr.Exprs)) |
||||
|
} |
||||
|
|
||||
|
argExpr, ok := funcExpr.Exprs[0].(*AliasedExpr) |
||||
|
if !ok { |
||||
|
t.Errorf("Expected *AliasedExpr argument, got %T", funcExpr.Exprs[0]) |
||||
|
} |
||||
|
|
||||
|
colName, ok := argExpr.Expr.(*ColName) |
||||
|
if !ok { |
||||
|
t.Errorf("Expected *ColName, got %T", argExpr.Expr) |
||||
|
} |
||||
|
|
||||
|
if colName.Name.String() != "user_id" { |
||||
|
t.Errorf("Expected column name 'user_id', got '%s'", colName.Name.String()) |
||||
|
} |
||||
|
}, |
||||
|
}, |
||||
|
{ |
||||
|
name: "Multiple aggregate functions", |
||||
|
sql: "SELECT COUNT(*), SUM(amount), AVG(score) FROM transactions", |
||||
|
wantErr: false, |
||||
|
validate: func(t *testing.T, stmt Statement) { |
||||
|
selectStmt, ok := stmt.(*SelectStatement) |
||||
|
if !ok { |
||||
|
t.Fatalf("Expected *SelectStatement, got %T", stmt) |
||||
|
} |
||||
|
|
||||
|
if len(selectStmt.SelectExprs) != 3 { |
||||
|
t.Fatalf("Expected 3 select expressions, got %d", len(selectStmt.SelectExprs)) |
||||
|
} |
||||
|
|
||||
|
// Verify COUNT(*)
|
||||
|
countExpr := selectStmt.SelectExprs[0].(*AliasedExpr) |
||||
|
countFunc := countExpr.Expr.(*FuncExpr) |
||||
|
if countFunc.Name.String() != "COUNT" { |
||||
|
t.Errorf("Expected first function to be COUNT, got %s", countFunc.Name.String()) |
||||
|
} |
||||
|
|
||||
|
// Verify SUM(amount)
|
||||
|
sumExpr := selectStmt.SelectExprs[1].(*AliasedExpr) |
||||
|
sumFunc := sumExpr.Expr.(*FuncExpr) |
||||
|
if sumFunc.Name.String() != "SUM" { |
||||
|
t.Errorf("Expected second function to be SUM, got %s", sumFunc.Name.String()) |
||||
|
} |
||||
|
|
||||
|
// Verify AVG(score)
|
||||
|
avgExpr := selectStmt.SelectExprs[2].(*AliasedExpr) |
||||
|
avgFunc := avgExpr.Expr.(*FuncExpr) |
||||
|
if avgFunc.Name.String() != "AVG" { |
||||
|
t.Errorf("Expected third function to be AVG, got %s", avgFunc.Name.String()) |
||||
|
} |
||||
|
}, |
||||
|
}, |
||||
|
} |
||||
|
|
||||
|
for _, tt := range tests { |
||||
|
t.Run(tt.name, func(t *testing.T) { |
||||
|
stmt, err := ParseSQL(tt.sql) |
||||
|
|
||||
|
if tt.wantErr { |
||||
|
if err == nil { |
||||
|
t.Errorf("Expected error, but got none") |
||||
|
} |
||||
|
return |
||||
|
} |
||||
|
|
||||
|
if err != nil { |
||||
|
t.Errorf("Unexpected error: %v", err) |
||||
|
return |
||||
|
} |
||||
|
|
||||
|
if tt.validate != nil { |
||||
|
tt.validate(t, stmt) |
||||
|
} |
||||
|
}) |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
func TestParseSQL_SELECT_Expressions(t *testing.T) { |
||||
|
tests := []struct { |
||||
|
name string |
||||
|
sql string |
||||
|
wantErr bool |
||||
|
validate func(t *testing.T, stmt Statement) |
||||
|
}{ |
||||
|
{ |
||||
|
name: "SELECT * FROM table", |
||||
|
sql: "SELECT * FROM users", |
||||
|
wantErr: false, |
||||
|
validate: func(t *testing.T, stmt Statement) { |
||||
|
selectStmt := stmt.(*SelectStatement) |
||||
|
if len(selectStmt.SelectExprs) != 1 { |
||||
|
t.Fatalf("Expected 1 select expression, got %d", len(selectStmt.SelectExprs)) |
||||
|
} |
||||
|
|
||||
|
_, ok := selectStmt.SelectExprs[0].(*StarExpr) |
||||
|
if !ok { |
||||
|
t.Errorf("Expected *StarExpr, got %T", selectStmt.SelectExprs[0]) |
||||
|
} |
||||
|
}, |
||||
|
}, |
||||
|
{ |
||||
|
name: "SELECT column FROM table", |
||||
|
sql: "SELECT user_id FROM users", |
||||
|
wantErr: false, |
||||
|
validate: func(t *testing.T, stmt Statement) { |
||||
|
selectStmt := stmt.(*SelectStatement) |
||||
|
if len(selectStmt.SelectExprs) != 1 { |
||||
|
t.Fatalf("Expected 1 select expression, got %d", len(selectStmt.SelectExprs)) |
||||
|
} |
||||
|
|
||||
|
aliasedExpr, ok := selectStmt.SelectExprs[0].(*AliasedExpr) |
||||
|
if !ok { |
||||
|
t.Fatalf("Expected *AliasedExpr, got %T", selectStmt.SelectExprs[0]) |
||||
|
} |
||||
|
|
||||
|
colName, ok := aliasedExpr.Expr.(*ColName) |
||||
|
if !ok { |
||||
|
t.Fatalf("Expected *ColName, got %T", aliasedExpr.Expr) |
||||
|
} |
||||
|
|
||||
|
if colName.Name.String() != "user_id" { |
||||
|
t.Errorf("Expected column name 'user_id', got '%s'", colName.Name.String()) |
||||
|
} |
||||
|
}, |
||||
|
}, |
||||
|
{ |
||||
|
name: "SELECT multiple columns", |
||||
|
sql: "SELECT user_id, name, email FROM users", |
||||
|
wantErr: false, |
||||
|
validate: func(t *testing.T, stmt Statement) { |
||||
|
selectStmt := stmt.(*SelectStatement) |
||||
|
if len(selectStmt.SelectExprs) != 3 { |
||||
|
t.Fatalf("Expected 3 select expressions, got %d", len(selectStmt.SelectExprs)) |
||||
|
} |
||||
|
|
||||
|
expectedColumns := []string{"user_id", "name", "email"} |
||||
|
for i, expected := range expectedColumns { |
||||
|
aliasedExpr := selectStmt.SelectExprs[i].(*AliasedExpr) |
||||
|
colName := aliasedExpr.Expr.(*ColName) |
||||
|
if colName.Name.String() != expected { |
||||
|
t.Errorf("Expected column %d to be '%s', got '%s'", i, expected, colName.Name.String()) |
||||
|
} |
||||
|
} |
||||
|
}, |
||||
|
}, |
||||
|
} |
||||
|
|
||||
|
for _, tt := range tests { |
||||
|
t.Run(tt.name, func(t *testing.T) { |
||||
|
stmt, err := ParseSQL(tt.sql) |
||||
|
|
||||
|
if tt.wantErr { |
||||
|
if err == nil { |
||||
|
t.Errorf("Expected error, but got none") |
||||
|
} |
||||
|
return |
||||
|
} |
||||
|
|
||||
|
if err != nil { |
||||
|
t.Errorf("Unexpected error: %v", err) |
||||
|
return |
||||
|
} |
||||
|
|
||||
|
if tt.validate != nil { |
||||
|
tt.validate(t, stmt) |
||||
|
} |
||||
|
}) |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
func TestParseSQL_WHERE_Clauses(t *testing.T) { |
||||
|
tests := []struct { |
||||
|
name string |
||||
|
sql string |
||||
|
wantErr bool |
||||
|
validate func(t *testing.T, stmt Statement) |
||||
|
}{ |
||||
|
{ |
||||
|
name: "WHERE with simple comparison", |
||||
|
sql: "SELECT * FROM users WHERE age > 18", |
||||
|
wantErr: false, |
||||
|
validate: func(t *testing.T, stmt Statement) { |
||||
|
selectStmt := stmt.(*SelectStatement) |
||||
|
if selectStmt.Where == nil { |
||||
|
t.Fatal("Expected WHERE clause, got nil") |
||||
|
} |
||||
|
|
||||
|
// Just verify we have a WHERE clause with an expression
|
||||
|
if selectStmt.Where.Expr == nil { |
||||
|
t.Error("Expected WHERE expression, got nil") |
||||
|
} |
||||
|
}, |
||||
|
}, |
||||
|
{ |
||||
|
name: "WHERE with AND condition", |
||||
|
sql: "SELECT * FROM users WHERE age > 18 AND status = 'active'", |
||||
|
wantErr: false, |
||||
|
validate: func(t *testing.T, stmt Statement) { |
||||
|
selectStmt := stmt.(*SelectStatement) |
||||
|
if selectStmt.Where == nil { |
||||
|
t.Fatal("Expected WHERE clause, got nil") |
||||
|
} |
||||
|
|
||||
|
// Verify we have an AND expression
|
||||
|
andExpr, ok := selectStmt.Where.Expr.(*AndExpr) |
||||
|
if !ok { |
||||
|
t.Errorf("Expected *AndExpr, got %T", selectStmt.Where.Expr) |
||||
|
} |
||||
|
_ = andExpr // Use variable to avoid unused error
|
||||
|
}, |
||||
|
}, |
||||
|
{ |
||||
|
name: "WHERE with OR condition", |
||||
|
sql: "SELECT * FROM users WHERE age < 18 OR age > 65", |
||||
|
wantErr: false, |
||||
|
validate: func(t *testing.T, stmt Statement) { |
||||
|
selectStmt := stmt.(*SelectStatement) |
||||
|
if selectStmt.Where == nil { |
||||
|
t.Fatal("Expected WHERE clause, got nil") |
||||
|
} |
||||
|
|
||||
|
// Verify we have an OR expression
|
||||
|
orExpr, ok := selectStmt.Where.Expr.(*OrExpr) |
||||
|
if !ok { |
||||
|
t.Errorf("Expected *OrExpr, got %T", selectStmt.Where.Expr) |
||||
|
} |
||||
|
_ = orExpr // Use variable to avoid unused error
|
||||
|
}, |
||||
|
}, |
||||
|
} |
||||
|
|
||||
|
for _, tt := range tests { |
||||
|
t.Run(tt.name, func(t *testing.T) { |
||||
|
stmt, err := ParseSQL(tt.sql) |
||||
|
|
||||
|
if tt.wantErr { |
||||
|
if err == nil { |
||||
|
t.Errorf("Expected error, but got none") |
||||
|
} |
||||
|
return |
||||
|
} |
||||
|
|
||||
|
if err != nil { |
||||
|
t.Errorf("Unexpected error: %v", err) |
||||
|
return |
||||
|
} |
||||
|
|
||||
|
if tt.validate != nil { |
||||
|
tt.validate(t, stmt) |
||||
|
} |
||||
|
}) |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
func TestParseSQL_LIMIT_Clauses(t *testing.T) { |
||||
|
tests := []struct { |
||||
|
name string |
||||
|
sql string |
||||
|
wantErr bool |
||||
|
validate func(t *testing.T, stmt Statement) |
||||
|
}{ |
||||
|
{ |
||||
|
name: "LIMIT with number", |
||||
|
sql: "SELECT * FROM users LIMIT 10", |
||||
|
wantErr: false, |
||||
|
validate: func(t *testing.T, stmt Statement) { |
||||
|
selectStmt := stmt.(*SelectStatement) |
||||
|
if selectStmt.Limit == nil { |
||||
|
t.Fatal("Expected LIMIT clause, got nil") |
||||
|
} |
||||
|
|
||||
|
if selectStmt.Limit.Rowcount == nil { |
||||
|
t.Error("Expected LIMIT rowcount, got nil") |
||||
|
} |
||||
|
|
||||
|
sqlVal, ok := selectStmt.Limit.Rowcount.(*SQLVal) |
||||
|
if !ok { |
||||
|
t.Errorf("Expected *SQLVal, got %T", selectStmt.Limit.Rowcount) |
||||
|
} |
||||
|
|
||||
|
if sqlVal.Type != IntVal { |
||||
|
t.Errorf("Expected IntVal type, got %d", sqlVal.Type) |
||||
|
} |
||||
|
|
||||
|
if string(sqlVal.Val) != "10" { |
||||
|
t.Errorf("Expected limit value '10', got '%s'", string(sqlVal.Val)) |
||||
|
} |
||||
|
}, |
||||
|
}, |
||||
|
} |
||||
|
|
||||
|
for _, tt := range tests { |
||||
|
t.Run(tt.name, func(t *testing.T) { |
||||
|
stmt, err := ParseSQL(tt.sql) |
||||
|
|
||||
|
if tt.wantErr { |
||||
|
if err == nil { |
||||
|
t.Errorf("Expected error, but got none") |
||||
|
} |
||||
|
return |
||||
|
} |
||||
|
|
||||
|
if err != nil { |
||||
|
t.Errorf("Unexpected error: %v", err) |
||||
|
return |
||||
|
} |
||||
|
|
||||
|
if tt.validate != nil { |
||||
|
tt.validate(t, stmt) |
||||
|
} |
||||
|
}) |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
func TestParseSQL_SHOW_Statements(t *testing.T) { |
||||
|
tests := []struct { |
||||
|
name string |
||||
|
sql string |
||||
|
wantErr bool |
||||
|
validate func(t *testing.T, stmt Statement) |
||||
|
}{ |
||||
|
{ |
||||
|
name: "SHOW DATABASES", |
||||
|
sql: "SHOW DATABASES", |
||||
|
wantErr: false, |
||||
|
validate: func(t *testing.T, stmt Statement) { |
||||
|
showStmt, ok := stmt.(*ShowStatement) |
||||
|
if !ok { |
||||
|
t.Fatalf("Expected *ShowStatement, got %T", stmt) |
||||
|
} |
||||
|
|
||||
|
if showStmt.Type != "databases" { |
||||
|
t.Errorf("Expected type 'databases', got '%s'", showStmt.Type) |
||||
|
} |
||||
|
}, |
||||
|
}, |
||||
|
{ |
||||
|
name: "SHOW TABLES", |
||||
|
sql: "SHOW TABLES", |
||||
|
wantErr: false, |
||||
|
validate: func(t *testing.T, stmt Statement) { |
||||
|
showStmt, ok := stmt.(*ShowStatement) |
||||
|
if !ok { |
||||
|
t.Fatalf("Expected *ShowStatement, got %T", stmt) |
||||
|
} |
||||
|
|
||||
|
if showStmt.Type != "tables" { |
||||
|
t.Errorf("Expected type 'tables', got '%s'", showStmt.Type) |
||||
|
} |
||||
|
}, |
||||
|
}, |
||||
|
{ |
||||
|
name: "SHOW TABLES FROM database", |
||||
|
sql: "SHOW TABLES FROM \"test_db\"", |
||||
|
wantErr: false, |
||||
|
validate: func(t *testing.T, stmt Statement) { |
||||
|
showStmt, ok := stmt.(*ShowStatement) |
||||
|
if !ok { |
||||
|
t.Fatalf("Expected *ShowStatement, got %T", stmt) |
||||
|
} |
||||
|
|
||||
|
if showStmt.Type != "tables" { |
||||
|
t.Errorf("Expected type 'tables', got '%s'", showStmt.Type) |
||||
|
} |
||||
|
|
||||
|
if showStmt.Schema != "test_db" { |
||||
|
t.Errorf("Expected schema 'test_db', got '%s'", showStmt.Schema) |
||||
|
} |
||||
|
}, |
||||
|
}, |
||||
|
} |
||||
|
|
||||
|
for _, tt := range tests { |
||||
|
t.Run(tt.name, func(t *testing.T) { |
||||
|
stmt, err := ParseSQL(tt.sql) |
||||
|
|
||||
|
if tt.wantErr { |
||||
|
if err == nil { |
||||
|
t.Errorf("Expected error, but got none") |
||||
|
} |
||||
|
return |
||||
|
} |
||||
|
|
||||
|
if err != nil { |
||||
|
t.Errorf("Unexpected error: %v", err) |
||||
|
return |
||||
|
} |
||||
|
|
||||
|
if tt.validate != nil { |
||||
|
tt.validate(t, stmt) |
||||
|
} |
||||
|
}) |
||||
|
} |
||||
|
} |
Write
Preview
Loading…
Cancel
Save
Reference in new issue