You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
564 lines
14 KiB
564 lines
14 KiB
package engine
|
|
|
|
import (
|
|
"testing"
|
|
)
|
|
|
|
func TestParseSQL_COUNT_Functions(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
sql string
|
|
wantErr bool
|
|
validate func(t *testing.T, stmt Statement)
|
|
}{
|
|
{
|
|
name: "COUNT(*) basic",
|
|
sql: "SELECT COUNT(*) FROM test_table",
|
|
wantErr: false,
|
|
validate: func(t *testing.T, stmt Statement) {
|
|
selectStmt, ok := stmt.(*SelectStatement)
|
|
if !ok {
|
|
t.Fatalf("Expected *SelectStatement, got %T", stmt)
|
|
}
|
|
|
|
if len(selectStmt.SelectExprs) != 1 {
|
|
t.Fatalf("Expected 1 select expression, got %d", len(selectStmt.SelectExprs))
|
|
}
|
|
|
|
aliasedExpr, ok := selectStmt.SelectExprs[0].(*AliasedExpr)
|
|
if !ok {
|
|
t.Fatalf("Expected *AliasedExpr, got %T", selectStmt.SelectExprs[0])
|
|
}
|
|
|
|
funcExpr, ok := aliasedExpr.Expr.(*FuncExpr)
|
|
if !ok {
|
|
t.Fatalf("Expected *FuncExpr, got %T", aliasedExpr.Expr)
|
|
}
|
|
|
|
if funcExpr.Name.String() != "COUNT" {
|
|
t.Errorf("Expected function name 'COUNT', got '%s'", funcExpr.Name.String())
|
|
}
|
|
|
|
if len(funcExpr.Exprs) != 1 {
|
|
t.Fatalf("Expected 1 function argument, got %d", len(funcExpr.Exprs))
|
|
}
|
|
|
|
starExpr, ok := funcExpr.Exprs[0].(*StarExpr)
|
|
if !ok {
|
|
t.Errorf("Expected *StarExpr argument, got %T", funcExpr.Exprs[0])
|
|
}
|
|
_ = starExpr // Use the variable to avoid unused variable error
|
|
},
|
|
},
|
|
{
|
|
name: "COUNT(column_name)",
|
|
sql: "SELECT COUNT(user_id) FROM users",
|
|
wantErr: false,
|
|
validate: func(t *testing.T, stmt Statement) {
|
|
selectStmt, ok := stmt.(*SelectStatement)
|
|
if !ok {
|
|
t.Fatalf("Expected *SelectStatement, got %T", stmt)
|
|
}
|
|
|
|
aliasedExpr := selectStmt.SelectExprs[0].(*AliasedExpr)
|
|
funcExpr := aliasedExpr.Expr.(*FuncExpr)
|
|
|
|
if funcExpr.Name.String() != "COUNT" {
|
|
t.Errorf("Expected function name 'COUNT', got '%s'", funcExpr.Name.String())
|
|
}
|
|
|
|
if len(funcExpr.Exprs) != 1 {
|
|
t.Fatalf("Expected 1 function argument, got %d", len(funcExpr.Exprs))
|
|
}
|
|
|
|
argExpr, ok := funcExpr.Exprs[0].(*AliasedExpr)
|
|
if !ok {
|
|
t.Errorf("Expected *AliasedExpr argument, got %T", funcExpr.Exprs[0])
|
|
}
|
|
|
|
colName, ok := argExpr.Expr.(*ColName)
|
|
if !ok {
|
|
t.Errorf("Expected *ColName, got %T", argExpr.Expr)
|
|
}
|
|
|
|
if colName.Name.String() != "user_id" {
|
|
t.Errorf("Expected column name 'user_id', got '%s'", colName.Name.String())
|
|
}
|
|
},
|
|
},
|
|
{
|
|
name: "Multiple aggregate functions",
|
|
sql: "SELECT COUNT(*), SUM(amount), AVG(score) FROM transactions",
|
|
wantErr: false,
|
|
validate: func(t *testing.T, stmt Statement) {
|
|
selectStmt, ok := stmt.(*SelectStatement)
|
|
if !ok {
|
|
t.Fatalf("Expected *SelectStatement, got %T", stmt)
|
|
}
|
|
|
|
if len(selectStmt.SelectExprs) != 3 {
|
|
t.Fatalf("Expected 3 select expressions, got %d", len(selectStmt.SelectExprs))
|
|
}
|
|
|
|
// Verify COUNT(*)
|
|
countExpr := selectStmt.SelectExprs[0].(*AliasedExpr)
|
|
countFunc := countExpr.Expr.(*FuncExpr)
|
|
if countFunc.Name.String() != "COUNT" {
|
|
t.Errorf("Expected first function to be COUNT, got %s", countFunc.Name.String())
|
|
}
|
|
|
|
// Verify SUM(amount)
|
|
sumExpr := selectStmt.SelectExprs[1].(*AliasedExpr)
|
|
sumFunc := sumExpr.Expr.(*FuncExpr)
|
|
if sumFunc.Name.String() != "SUM" {
|
|
t.Errorf("Expected second function to be SUM, got %s", sumFunc.Name.String())
|
|
}
|
|
|
|
// Verify AVG(score)
|
|
avgExpr := selectStmt.SelectExprs[2].(*AliasedExpr)
|
|
avgFunc := avgExpr.Expr.(*FuncExpr)
|
|
if avgFunc.Name.String() != "AVG" {
|
|
t.Errorf("Expected third function to be AVG, got %s", avgFunc.Name.String())
|
|
}
|
|
},
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
stmt, err := ParseSQL(tt.sql)
|
|
|
|
if tt.wantErr {
|
|
if err == nil {
|
|
t.Errorf("Expected error, but got none")
|
|
}
|
|
return
|
|
}
|
|
|
|
if err != nil {
|
|
t.Errorf("Unexpected error: %v", err)
|
|
return
|
|
}
|
|
|
|
if tt.validate != nil {
|
|
tt.validate(t, stmt)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestParseSQL_SELECT_Expressions(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
sql string
|
|
wantErr bool
|
|
validate func(t *testing.T, stmt Statement)
|
|
}{
|
|
{
|
|
name: "SELECT * FROM table",
|
|
sql: "SELECT * FROM users",
|
|
wantErr: false,
|
|
validate: func(t *testing.T, stmt Statement) {
|
|
selectStmt := stmt.(*SelectStatement)
|
|
if len(selectStmt.SelectExprs) != 1 {
|
|
t.Fatalf("Expected 1 select expression, got %d", len(selectStmt.SelectExprs))
|
|
}
|
|
|
|
_, ok := selectStmt.SelectExprs[0].(*StarExpr)
|
|
if !ok {
|
|
t.Errorf("Expected *StarExpr, got %T", selectStmt.SelectExprs[0])
|
|
}
|
|
},
|
|
},
|
|
{
|
|
name: "SELECT column FROM table",
|
|
sql: "SELECT user_id FROM users",
|
|
wantErr: false,
|
|
validate: func(t *testing.T, stmt Statement) {
|
|
selectStmt := stmt.(*SelectStatement)
|
|
if len(selectStmt.SelectExprs) != 1 {
|
|
t.Fatalf("Expected 1 select expression, got %d", len(selectStmt.SelectExprs))
|
|
}
|
|
|
|
aliasedExpr, ok := selectStmt.SelectExprs[0].(*AliasedExpr)
|
|
if !ok {
|
|
t.Fatalf("Expected *AliasedExpr, got %T", selectStmt.SelectExprs[0])
|
|
}
|
|
|
|
colName, ok := aliasedExpr.Expr.(*ColName)
|
|
if !ok {
|
|
t.Fatalf("Expected *ColName, got %T", aliasedExpr.Expr)
|
|
}
|
|
|
|
if colName.Name.String() != "user_id" {
|
|
t.Errorf("Expected column name 'user_id', got '%s'", colName.Name.String())
|
|
}
|
|
},
|
|
},
|
|
{
|
|
name: "SELECT multiple columns",
|
|
sql: "SELECT user_id, name, email FROM users",
|
|
wantErr: false,
|
|
validate: func(t *testing.T, stmt Statement) {
|
|
selectStmt := stmt.(*SelectStatement)
|
|
if len(selectStmt.SelectExprs) != 3 {
|
|
t.Fatalf("Expected 3 select expressions, got %d", len(selectStmt.SelectExprs))
|
|
}
|
|
|
|
expectedColumns := []string{"user_id", "name", "email"}
|
|
for i, expected := range expectedColumns {
|
|
aliasedExpr := selectStmt.SelectExprs[i].(*AliasedExpr)
|
|
colName := aliasedExpr.Expr.(*ColName)
|
|
if colName.Name.String() != expected {
|
|
t.Errorf("Expected column %d to be '%s', got '%s'", i, expected, colName.Name.String())
|
|
}
|
|
}
|
|
},
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
stmt, err := ParseSQL(tt.sql)
|
|
|
|
if tt.wantErr {
|
|
if err == nil {
|
|
t.Errorf("Expected error, but got none")
|
|
}
|
|
return
|
|
}
|
|
|
|
if err != nil {
|
|
t.Errorf("Unexpected error: %v", err)
|
|
return
|
|
}
|
|
|
|
if tt.validate != nil {
|
|
tt.validate(t, stmt)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestParseSQL_WHERE_Clauses(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
sql string
|
|
wantErr bool
|
|
validate func(t *testing.T, stmt Statement)
|
|
}{
|
|
{
|
|
name: "WHERE with simple comparison",
|
|
sql: "SELECT * FROM users WHERE age > 18",
|
|
wantErr: false,
|
|
validate: func(t *testing.T, stmt Statement) {
|
|
selectStmt := stmt.(*SelectStatement)
|
|
if selectStmt.Where == nil {
|
|
t.Fatal("Expected WHERE clause, got nil")
|
|
}
|
|
|
|
// Just verify we have a WHERE clause with an expression
|
|
if selectStmt.Where.Expr == nil {
|
|
t.Error("Expected WHERE expression, got nil")
|
|
}
|
|
},
|
|
},
|
|
{
|
|
name: "WHERE with AND condition",
|
|
sql: "SELECT * FROM users WHERE age > 18 AND status = 'active'",
|
|
wantErr: false,
|
|
validate: func(t *testing.T, stmt Statement) {
|
|
selectStmt := stmt.(*SelectStatement)
|
|
if selectStmt.Where == nil {
|
|
t.Fatal("Expected WHERE clause, got nil")
|
|
}
|
|
|
|
// Verify we have an AND expression
|
|
andExpr, ok := selectStmt.Where.Expr.(*AndExpr)
|
|
if !ok {
|
|
t.Errorf("Expected *AndExpr, got %T", selectStmt.Where.Expr)
|
|
}
|
|
_ = andExpr // Use variable to avoid unused error
|
|
},
|
|
},
|
|
{
|
|
name: "WHERE with OR condition",
|
|
sql: "SELECT * FROM users WHERE age < 18 OR age > 65",
|
|
wantErr: false,
|
|
validate: func(t *testing.T, stmt Statement) {
|
|
selectStmt := stmt.(*SelectStatement)
|
|
if selectStmt.Where == nil {
|
|
t.Fatal("Expected WHERE clause, got nil")
|
|
}
|
|
|
|
// Verify we have an OR expression
|
|
orExpr, ok := selectStmt.Where.Expr.(*OrExpr)
|
|
if !ok {
|
|
t.Errorf("Expected *OrExpr, got %T", selectStmt.Where.Expr)
|
|
}
|
|
_ = orExpr // Use variable to avoid unused error
|
|
},
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
stmt, err := ParseSQL(tt.sql)
|
|
|
|
if tt.wantErr {
|
|
if err == nil {
|
|
t.Errorf("Expected error, but got none")
|
|
}
|
|
return
|
|
}
|
|
|
|
if err != nil {
|
|
t.Errorf("Unexpected error: %v", err)
|
|
return
|
|
}
|
|
|
|
if tt.validate != nil {
|
|
tt.validate(t, stmt)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestParseSQL_LIMIT_Clauses(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
sql string
|
|
wantErr bool
|
|
validate func(t *testing.T, stmt Statement)
|
|
}{
|
|
{
|
|
name: "LIMIT with number",
|
|
sql: "SELECT * FROM users LIMIT 10",
|
|
wantErr: false,
|
|
validate: func(t *testing.T, stmt Statement) {
|
|
selectStmt := stmt.(*SelectStatement)
|
|
if selectStmt.Limit == nil {
|
|
t.Fatal("Expected LIMIT clause, got nil")
|
|
}
|
|
|
|
if selectStmt.Limit.Rowcount == nil {
|
|
t.Error("Expected LIMIT rowcount, got nil")
|
|
}
|
|
|
|
// Verify no OFFSET is set
|
|
if selectStmt.Limit.Offset != nil {
|
|
t.Error("Expected OFFSET to be nil for LIMIT-only query")
|
|
}
|
|
|
|
sqlVal, ok := selectStmt.Limit.Rowcount.(*SQLVal)
|
|
if !ok {
|
|
t.Errorf("Expected *SQLVal, got %T", selectStmt.Limit.Rowcount)
|
|
}
|
|
|
|
if sqlVal.Type != IntVal {
|
|
t.Errorf("Expected IntVal type, got %d", sqlVal.Type)
|
|
}
|
|
|
|
if string(sqlVal.Val) != "10" {
|
|
t.Errorf("Expected limit value '10', got '%s'", string(sqlVal.Val))
|
|
}
|
|
},
|
|
},
|
|
{
|
|
name: "LIMIT with OFFSET",
|
|
sql: "SELECT * FROM users LIMIT 10 OFFSET 5",
|
|
wantErr: false,
|
|
validate: func(t *testing.T, stmt Statement) {
|
|
selectStmt := stmt.(*SelectStatement)
|
|
if selectStmt.Limit == nil {
|
|
t.Fatal("Expected LIMIT clause, got nil")
|
|
}
|
|
|
|
// Verify LIMIT value
|
|
if selectStmt.Limit.Rowcount == nil {
|
|
t.Error("Expected LIMIT rowcount, got nil")
|
|
}
|
|
|
|
limitVal, ok := selectStmt.Limit.Rowcount.(*SQLVal)
|
|
if !ok {
|
|
t.Errorf("Expected *SQLVal for LIMIT, got %T", selectStmt.Limit.Rowcount)
|
|
}
|
|
|
|
if limitVal.Type != IntVal {
|
|
t.Errorf("Expected IntVal type for LIMIT, got %d", limitVal.Type)
|
|
}
|
|
|
|
if string(limitVal.Val) != "10" {
|
|
t.Errorf("Expected limit value '10', got '%s'", string(limitVal.Val))
|
|
}
|
|
|
|
// Verify OFFSET value
|
|
if selectStmt.Limit.Offset == nil {
|
|
t.Fatal("Expected OFFSET clause, got nil")
|
|
}
|
|
|
|
offsetVal, ok := selectStmt.Limit.Offset.(*SQLVal)
|
|
if !ok {
|
|
t.Errorf("Expected *SQLVal for OFFSET, got %T", selectStmt.Limit.Offset)
|
|
}
|
|
|
|
if offsetVal.Type != IntVal {
|
|
t.Errorf("Expected IntVal type for OFFSET, got %d", offsetVal.Type)
|
|
}
|
|
|
|
if string(offsetVal.Val) != "5" {
|
|
t.Errorf("Expected offset value '5', got '%s'", string(offsetVal.Val))
|
|
}
|
|
},
|
|
},
|
|
{
|
|
name: "LIMIT with OFFSET zero",
|
|
sql: "SELECT * FROM users LIMIT 5 OFFSET 0",
|
|
wantErr: false,
|
|
validate: func(t *testing.T, stmt Statement) {
|
|
selectStmt := stmt.(*SelectStatement)
|
|
if selectStmt.Limit == nil {
|
|
t.Fatal("Expected LIMIT clause, got nil")
|
|
}
|
|
|
|
// Verify OFFSET is 0
|
|
if selectStmt.Limit.Offset == nil {
|
|
t.Fatal("Expected OFFSET clause, got nil")
|
|
}
|
|
|
|
offsetVal, ok := selectStmt.Limit.Offset.(*SQLVal)
|
|
if !ok {
|
|
t.Errorf("Expected *SQLVal for OFFSET, got %T", selectStmt.Limit.Offset)
|
|
}
|
|
|
|
if string(offsetVal.Val) != "0" {
|
|
t.Errorf("Expected offset value '0', got '%s'", string(offsetVal.Val))
|
|
}
|
|
},
|
|
},
|
|
{
|
|
name: "LIMIT with large OFFSET",
|
|
sql: "SELECT * FROM users LIMIT 100 OFFSET 1000",
|
|
wantErr: false,
|
|
validate: func(t *testing.T, stmt Statement) {
|
|
selectStmt := stmt.(*SelectStatement)
|
|
if selectStmt.Limit == nil {
|
|
t.Fatal("Expected LIMIT clause, got nil")
|
|
}
|
|
|
|
// Verify large OFFSET value
|
|
offsetVal, ok := selectStmt.Limit.Offset.(*SQLVal)
|
|
if !ok {
|
|
t.Errorf("Expected *SQLVal for OFFSET, got %T", selectStmt.Limit.Offset)
|
|
}
|
|
|
|
if string(offsetVal.Val) != "1000" {
|
|
t.Errorf("Expected offset value '1000', got '%s'", string(offsetVal.Val))
|
|
}
|
|
},
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
stmt, err := ParseSQL(tt.sql)
|
|
|
|
if tt.wantErr {
|
|
if err == nil {
|
|
t.Errorf("Expected error, but got none")
|
|
}
|
|
return
|
|
}
|
|
|
|
if err != nil {
|
|
t.Errorf("Unexpected error: %v", err)
|
|
return
|
|
}
|
|
|
|
if tt.validate != nil {
|
|
tt.validate(t, stmt)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestParseSQL_SHOW_Statements(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
sql string
|
|
wantErr bool
|
|
validate func(t *testing.T, stmt Statement)
|
|
}{
|
|
{
|
|
name: "SHOW DATABASES",
|
|
sql: "SHOW DATABASES",
|
|
wantErr: false,
|
|
validate: func(t *testing.T, stmt Statement) {
|
|
showStmt, ok := stmt.(*ShowStatement)
|
|
if !ok {
|
|
t.Fatalf("Expected *ShowStatement, got %T", stmt)
|
|
}
|
|
|
|
if showStmt.Type != "databases" {
|
|
t.Errorf("Expected type 'databases', got '%s'", showStmt.Type)
|
|
}
|
|
},
|
|
},
|
|
{
|
|
name: "SHOW TABLES",
|
|
sql: "SHOW TABLES",
|
|
wantErr: false,
|
|
validate: func(t *testing.T, stmt Statement) {
|
|
showStmt, ok := stmt.(*ShowStatement)
|
|
if !ok {
|
|
t.Fatalf("Expected *ShowStatement, got %T", stmt)
|
|
}
|
|
|
|
if showStmt.Type != "tables" {
|
|
t.Errorf("Expected type 'tables', got '%s'", showStmt.Type)
|
|
}
|
|
},
|
|
},
|
|
{
|
|
name: "SHOW TABLES FROM database",
|
|
sql: "SHOW TABLES FROM \"test_db\"",
|
|
wantErr: false,
|
|
validate: func(t *testing.T, stmt Statement) {
|
|
showStmt, ok := stmt.(*ShowStatement)
|
|
if !ok {
|
|
t.Fatalf("Expected *ShowStatement, got %T", stmt)
|
|
}
|
|
|
|
if showStmt.Type != "tables" {
|
|
t.Errorf("Expected type 'tables', got '%s'", showStmt.Type)
|
|
}
|
|
|
|
if showStmt.Schema != "test_db" {
|
|
t.Errorf("Expected schema 'test_db', got '%s'", showStmt.Schema)
|
|
}
|
|
},
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
stmt, err := ParseSQL(tt.sql)
|
|
|
|
if tt.wantErr {
|
|
if err == nil {
|
|
t.Errorf("Expected error, but got none")
|
|
}
|
|
return
|
|
}
|
|
|
|
if err != nil {
|
|
t.Errorf("Unexpected error: %v", err)
|
|
return
|
|
}
|
|
|
|
if tt.validate != nil {
|
|
tt.validate(t, stmt)
|
|
}
|
|
})
|
|
}
|
|
}
|