You cannot select more than 25 topics
			Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
		
		
		
		
		
			
		
			
				
					
					
						
							466 lines
						
					
					
						
							11 KiB
						
					
					
				
			
		
		
		
			
			
			
		
		
	
	
							466 lines
						
					
					
						
							11 KiB
						
					
					
				
								package engine
							 | 
						|
								
							 | 
						|
								import (
							 | 
						|
									"testing"
							 | 
						|
								)
							 | 
						|
								
							 | 
						|
								func TestParseSQL_COUNT_Functions(t *testing.T) {
							 | 
						|
									tests := []struct {
							 | 
						|
										name     string
							 | 
						|
										sql      string
							 | 
						|
										wantErr  bool
							 | 
						|
										validate func(t *testing.T, stmt Statement)
							 | 
						|
									}{
							 | 
						|
										{
							 | 
						|
											name:    "COUNT(*) basic",
							 | 
						|
											sql:     "SELECT COUNT(*) FROM test_table",
							 | 
						|
											wantErr: false,
							 | 
						|
											validate: func(t *testing.T, stmt Statement) {
							 | 
						|
												selectStmt, ok := stmt.(*SelectStatement)
							 | 
						|
												if !ok {
							 | 
						|
													t.Fatalf("Expected *SelectStatement, got %T", stmt)
							 | 
						|
												}
							 | 
						|
								
							 | 
						|
												if len(selectStmt.SelectExprs) != 1 {
							 | 
						|
													t.Fatalf("Expected 1 select expression, got %d", len(selectStmt.SelectExprs))
							 | 
						|
												}
							 | 
						|
								
							 | 
						|
												aliasedExpr, ok := selectStmt.SelectExprs[0].(*AliasedExpr)
							 | 
						|
												if !ok {
							 | 
						|
													t.Fatalf("Expected *AliasedExpr, got %T", selectStmt.SelectExprs[0])
							 | 
						|
												}
							 | 
						|
								
							 | 
						|
												funcExpr, ok := aliasedExpr.Expr.(*FuncExpr)
							 | 
						|
												if !ok {
							 | 
						|
													t.Fatalf("Expected *FuncExpr, got %T", aliasedExpr.Expr)
							 | 
						|
												}
							 | 
						|
								
							 | 
						|
												if funcExpr.Name.String() != "COUNT" {
							 | 
						|
													t.Errorf("Expected function name 'COUNT', got '%s'", funcExpr.Name.String())
							 | 
						|
												}
							 | 
						|
								
							 | 
						|
												if len(funcExpr.Exprs) != 1 {
							 | 
						|
													t.Fatalf("Expected 1 function argument, got %d", len(funcExpr.Exprs))
							 | 
						|
												}
							 | 
						|
								
							 | 
						|
												starExpr, ok := funcExpr.Exprs[0].(*StarExpr)
							 | 
						|
												if !ok {
							 | 
						|
													t.Errorf("Expected *StarExpr argument, got %T", funcExpr.Exprs[0])
							 | 
						|
												}
							 | 
						|
												_ = starExpr // Use the variable to avoid unused variable error
							 | 
						|
											},
							 | 
						|
										},
							 | 
						|
										{
							 | 
						|
											name:    "COUNT(column_name)",
							 | 
						|
											sql:     "SELECT COUNT(user_id) FROM users",
							 | 
						|
											wantErr: false,
							 | 
						|
											validate: func(t *testing.T, stmt Statement) {
							 | 
						|
												selectStmt, ok := stmt.(*SelectStatement)
							 | 
						|
												if !ok {
							 | 
						|
													t.Fatalf("Expected *SelectStatement, got %T", stmt)
							 | 
						|
												}
							 | 
						|
								
							 | 
						|
												aliasedExpr := selectStmt.SelectExprs[0].(*AliasedExpr)
							 | 
						|
												funcExpr := aliasedExpr.Expr.(*FuncExpr)
							 | 
						|
								
							 | 
						|
												if funcExpr.Name.String() != "COUNT" {
							 | 
						|
													t.Errorf("Expected function name 'COUNT', got '%s'", funcExpr.Name.String())
							 | 
						|
												}
							 | 
						|
								
							 | 
						|
												if len(funcExpr.Exprs) != 1 {
							 | 
						|
													t.Fatalf("Expected 1 function argument, got %d", len(funcExpr.Exprs))
							 | 
						|
												}
							 | 
						|
								
							 | 
						|
												argExpr, ok := funcExpr.Exprs[0].(*AliasedExpr)
							 | 
						|
												if !ok {
							 | 
						|
													t.Errorf("Expected *AliasedExpr argument, got %T", funcExpr.Exprs[0])
							 | 
						|
												}
							 | 
						|
								
							 | 
						|
												colName, ok := argExpr.Expr.(*ColName)
							 | 
						|
												if !ok {
							 | 
						|
													t.Errorf("Expected *ColName, got %T", argExpr.Expr)
							 | 
						|
												}
							 | 
						|
								
							 | 
						|
												if colName.Name.String() != "user_id" {
							 | 
						|
													t.Errorf("Expected column name 'user_id', got '%s'", colName.Name.String())
							 | 
						|
												}
							 | 
						|
											},
							 | 
						|
										},
							 | 
						|
										{
							 | 
						|
											name:    "Multiple aggregate functions",
							 | 
						|
											sql:     "SELECT COUNT(*), SUM(amount), AVG(score) FROM transactions",
							 | 
						|
											wantErr: false,
							 | 
						|
											validate: func(t *testing.T, stmt Statement) {
							 | 
						|
												selectStmt, ok := stmt.(*SelectStatement)
							 | 
						|
												if !ok {
							 | 
						|
													t.Fatalf("Expected *SelectStatement, got %T", stmt)
							 | 
						|
												}
							 | 
						|
								
							 | 
						|
												if len(selectStmt.SelectExprs) != 3 {
							 | 
						|
													t.Fatalf("Expected 3 select expressions, got %d", len(selectStmt.SelectExprs))
							 | 
						|
												}
							 | 
						|
								
							 | 
						|
												// Verify COUNT(*)
							 | 
						|
												countExpr := selectStmt.SelectExprs[0].(*AliasedExpr)
							 | 
						|
												countFunc := countExpr.Expr.(*FuncExpr)
							 | 
						|
												if countFunc.Name.String() != "COUNT" {
							 | 
						|
													t.Errorf("Expected first function to be COUNT, got %s", countFunc.Name.String())
							 | 
						|
												}
							 | 
						|
								
							 | 
						|
												// Verify SUM(amount)
							 | 
						|
												sumExpr := selectStmt.SelectExprs[1].(*AliasedExpr)
							 | 
						|
												sumFunc := sumExpr.Expr.(*FuncExpr)
							 | 
						|
												if sumFunc.Name.String() != "SUM" {
							 | 
						|
													t.Errorf("Expected second function to be SUM, got %s", sumFunc.Name.String())
							 | 
						|
												}
							 | 
						|
								
							 | 
						|
												// Verify AVG(score)
							 | 
						|
												avgExpr := selectStmt.SelectExprs[2].(*AliasedExpr)
							 | 
						|
												avgFunc := avgExpr.Expr.(*FuncExpr)
							 | 
						|
												if avgFunc.Name.String() != "AVG" {
							 | 
						|
													t.Errorf("Expected third function to be AVG, got %s", avgFunc.Name.String())
							 | 
						|
												}
							 | 
						|
											},
							 | 
						|
										},
							 | 
						|
									}
							 | 
						|
								
							 | 
						|
									for _, tt := range tests {
							 | 
						|
										t.Run(tt.name, func(t *testing.T) {
							 | 
						|
											stmt, err := ParseSQL(tt.sql)
							 | 
						|
								
							 | 
						|
											if tt.wantErr {
							 | 
						|
												if err == nil {
							 | 
						|
													t.Errorf("Expected error, but got none")
							 | 
						|
												}
							 | 
						|
												return
							 | 
						|
											}
							 | 
						|
								
							 | 
						|
											if err != nil {
							 | 
						|
												t.Errorf("Unexpected error: %v", err)
							 | 
						|
												return
							 | 
						|
											}
							 | 
						|
								
							 | 
						|
											if tt.validate != nil {
							 | 
						|
												tt.validate(t, stmt)
							 | 
						|
											}
							 | 
						|
										})
							 | 
						|
									}
							 | 
						|
								}
							 | 
						|
								
							 | 
						|
								func TestParseSQL_SELECT_Expressions(t *testing.T) {
							 | 
						|
									tests := []struct {
							 | 
						|
										name     string
							 | 
						|
										sql      string
							 | 
						|
										wantErr  bool
							 | 
						|
										validate func(t *testing.T, stmt Statement)
							 | 
						|
									}{
							 | 
						|
										{
							 | 
						|
											name:    "SELECT * FROM table",
							 | 
						|
											sql:     "SELECT * FROM users",
							 | 
						|
											wantErr: false,
							 | 
						|
											validate: func(t *testing.T, stmt Statement) {
							 | 
						|
												selectStmt := stmt.(*SelectStatement)
							 | 
						|
												if len(selectStmt.SelectExprs) != 1 {
							 | 
						|
													t.Fatalf("Expected 1 select expression, got %d", len(selectStmt.SelectExprs))
							 | 
						|
												}
							 | 
						|
								
							 | 
						|
												_, ok := selectStmt.SelectExprs[0].(*StarExpr)
							 | 
						|
												if !ok {
							 | 
						|
													t.Errorf("Expected *StarExpr, got %T", selectStmt.SelectExprs[0])
							 | 
						|
												}
							 | 
						|
											},
							 | 
						|
										},
							 | 
						|
										{
							 | 
						|
											name:    "SELECT column FROM table",
							 | 
						|
											sql:     "SELECT user_id FROM users",
							 | 
						|
											wantErr: false,
							 | 
						|
											validate: func(t *testing.T, stmt Statement) {
							 | 
						|
												selectStmt := stmt.(*SelectStatement)
							 | 
						|
												if len(selectStmt.SelectExprs) != 1 {
							 | 
						|
													t.Fatalf("Expected 1 select expression, got %d", len(selectStmt.SelectExprs))
							 | 
						|
												}
							 | 
						|
								
							 | 
						|
												aliasedExpr, ok := selectStmt.SelectExprs[0].(*AliasedExpr)
							 | 
						|
												if !ok {
							 | 
						|
													t.Fatalf("Expected *AliasedExpr, got %T", selectStmt.SelectExprs[0])
							 | 
						|
												}
							 | 
						|
								
							 | 
						|
												colName, ok := aliasedExpr.Expr.(*ColName)
							 | 
						|
												if !ok {
							 | 
						|
													t.Fatalf("Expected *ColName, got %T", aliasedExpr.Expr)
							 | 
						|
												}
							 | 
						|
								
							 | 
						|
												if colName.Name.String() != "user_id" {
							 | 
						|
													t.Errorf("Expected column name 'user_id', got '%s'", colName.Name.String())
							 | 
						|
												}
							 | 
						|
											},
							 | 
						|
										},
							 | 
						|
										{
							 | 
						|
											name:    "SELECT multiple columns",
							 | 
						|
											sql:     "SELECT user_id, name, email FROM users",
							 | 
						|
											wantErr: false,
							 | 
						|
											validate: func(t *testing.T, stmt Statement) {
							 | 
						|
												selectStmt := stmt.(*SelectStatement)
							 | 
						|
												if len(selectStmt.SelectExprs) != 3 {
							 | 
						|
													t.Fatalf("Expected 3 select expressions, got %d", len(selectStmt.SelectExprs))
							 | 
						|
												}
							 | 
						|
								
							 | 
						|
												expectedColumns := []string{"user_id", "name", "email"}
							 | 
						|
												for i, expected := range expectedColumns {
							 | 
						|
													aliasedExpr := selectStmt.SelectExprs[i].(*AliasedExpr)
							 | 
						|
													colName := aliasedExpr.Expr.(*ColName)
							 | 
						|
													if colName.Name.String() != expected {
							 | 
						|
														t.Errorf("Expected column %d to be '%s', got '%s'", i, expected, colName.Name.String())
							 | 
						|
													}
							 | 
						|
												}
							 | 
						|
											},
							 | 
						|
										},
							 | 
						|
									}
							 | 
						|
								
							 | 
						|
									for _, tt := range tests {
							 | 
						|
										t.Run(tt.name, func(t *testing.T) {
							 | 
						|
											stmt, err := ParseSQL(tt.sql)
							 | 
						|
								
							 | 
						|
											if tt.wantErr {
							 | 
						|
												if err == nil {
							 | 
						|
													t.Errorf("Expected error, but got none")
							 | 
						|
												}
							 | 
						|
												return
							 | 
						|
											}
							 | 
						|
								
							 | 
						|
											if err != nil {
							 | 
						|
												t.Errorf("Unexpected error: %v", err)
							 | 
						|
												return
							 | 
						|
											}
							 | 
						|
								
							 | 
						|
											if tt.validate != nil {
							 | 
						|
												tt.validate(t, stmt)
							 | 
						|
											}
							 | 
						|
										})
							 | 
						|
									}
							 | 
						|
								}
							 | 
						|
								
							 | 
						|
								func TestParseSQL_WHERE_Clauses(t *testing.T) {
							 | 
						|
									tests := []struct {
							 | 
						|
										name     string
							 | 
						|
										sql      string
							 | 
						|
										wantErr  bool
							 | 
						|
										validate func(t *testing.T, stmt Statement)
							 | 
						|
									}{
							 | 
						|
										{
							 | 
						|
											name:    "WHERE with simple comparison",
							 | 
						|
											sql:     "SELECT * FROM users WHERE age > 18",
							 | 
						|
											wantErr: false,
							 | 
						|
											validate: func(t *testing.T, stmt Statement) {
							 | 
						|
												selectStmt := stmt.(*SelectStatement)
							 | 
						|
												if selectStmt.Where == nil {
							 | 
						|
													t.Fatal("Expected WHERE clause, got nil")
							 | 
						|
												}
							 | 
						|
								
							 | 
						|
												// Just verify we have a WHERE clause with an expression
							 | 
						|
												if selectStmt.Where.Expr == nil {
							 | 
						|
													t.Error("Expected WHERE expression, got nil")
							 | 
						|
												}
							 | 
						|
											},
							 | 
						|
										},
							 | 
						|
										{
							 | 
						|
											name:    "WHERE with AND condition",
							 | 
						|
											sql:     "SELECT * FROM users WHERE age > 18 AND status = 'active'",
							 | 
						|
											wantErr: false,
							 | 
						|
											validate: func(t *testing.T, stmt Statement) {
							 | 
						|
												selectStmt := stmt.(*SelectStatement)
							 | 
						|
												if selectStmt.Where == nil {
							 | 
						|
													t.Fatal("Expected WHERE clause, got nil")
							 | 
						|
												}
							 | 
						|
								
							 | 
						|
												// Verify we have an AND expression
							 | 
						|
												andExpr, ok := selectStmt.Where.Expr.(*AndExpr)
							 | 
						|
												if !ok {
							 | 
						|
													t.Errorf("Expected *AndExpr, got %T", selectStmt.Where.Expr)
							 | 
						|
												}
							 | 
						|
												_ = andExpr // Use variable to avoid unused error
							 | 
						|
											},
							 | 
						|
										},
							 | 
						|
										{
							 | 
						|
											name:    "WHERE with OR condition",
							 | 
						|
											sql:     "SELECT * FROM users WHERE age < 18 OR age > 65",
							 | 
						|
											wantErr: false,
							 | 
						|
											validate: func(t *testing.T, stmt Statement) {
							 | 
						|
												selectStmt := stmt.(*SelectStatement)
							 | 
						|
												if selectStmt.Where == nil {
							 | 
						|
													t.Fatal("Expected WHERE clause, got nil")
							 | 
						|
												}
							 | 
						|
								
							 | 
						|
												// Verify we have an OR expression
							 | 
						|
												orExpr, ok := selectStmt.Where.Expr.(*OrExpr)
							 | 
						|
												if !ok {
							 | 
						|
													t.Errorf("Expected *OrExpr, got %T", selectStmt.Where.Expr)
							 | 
						|
												}
							 | 
						|
												_ = orExpr // Use variable to avoid unused error
							 | 
						|
											},
							 | 
						|
										},
							 | 
						|
									}
							 | 
						|
								
							 | 
						|
									for _, tt := range tests {
							 | 
						|
										t.Run(tt.name, func(t *testing.T) {
							 | 
						|
											stmt, err := ParseSQL(tt.sql)
							 | 
						|
								
							 | 
						|
											if tt.wantErr {
							 | 
						|
												if err == nil {
							 | 
						|
													t.Errorf("Expected error, but got none")
							 | 
						|
												}
							 | 
						|
												return
							 | 
						|
											}
							 | 
						|
								
							 | 
						|
											if err != nil {
							 | 
						|
												t.Errorf("Unexpected error: %v", err)
							 | 
						|
												return
							 | 
						|
											}
							 | 
						|
								
							 | 
						|
											if tt.validate != nil {
							 | 
						|
												tt.validate(t, stmt)
							 | 
						|
											}
							 | 
						|
										})
							 | 
						|
									}
							 | 
						|
								}
							 | 
						|
								
							 | 
						|
								func TestParseSQL_LIMIT_Clauses(t *testing.T) {
							 | 
						|
									tests := []struct {
							 | 
						|
										name     string
							 | 
						|
										sql      string
							 | 
						|
										wantErr  bool
							 | 
						|
										validate func(t *testing.T, stmt Statement)
							 | 
						|
									}{
							 | 
						|
										{
							 | 
						|
											name:    "LIMIT with number",
							 | 
						|
											sql:     "SELECT * FROM users LIMIT 10",
							 | 
						|
											wantErr: false,
							 | 
						|
											validate: func(t *testing.T, stmt Statement) {
							 | 
						|
												selectStmt := stmt.(*SelectStatement)
							 | 
						|
												if selectStmt.Limit == nil {
							 | 
						|
													t.Fatal("Expected LIMIT clause, got nil")
							 | 
						|
												}
							 | 
						|
								
							 | 
						|
												if selectStmt.Limit.Rowcount == nil {
							 | 
						|
													t.Error("Expected LIMIT rowcount, got nil")
							 | 
						|
												}
							 | 
						|
								
							 | 
						|
												sqlVal, ok := selectStmt.Limit.Rowcount.(*SQLVal)
							 | 
						|
												if !ok {
							 | 
						|
													t.Errorf("Expected *SQLVal, got %T", selectStmt.Limit.Rowcount)
							 | 
						|
												}
							 | 
						|
								
							 | 
						|
												if sqlVal.Type != IntVal {
							 | 
						|
													t.Errorf("Expected IntVal type, got %d", sqlVal.Type)
							 | 
						|
												}
							 | 
						|
								
							 | 
						|
												if string(sqlVal.Val) != "10" {
							 | 
						|
													t.Errorf("Expected limit value '10', got '%s'", string(sqlVal.Val))
							 | 
						|
												}
							 | 
						|
											},
							 | 
						|
										},
							 | 
						|
									}
							 | 
						|
								
							 | 
						|
									for _, tt := range tests {
							 | 
						|
										t.Run(tt.name, func(t *testing.T) {
							 | 
						|
											stmt, err := ParseSQL(tt.sql)
							 | 
						|
								
							 | 
						|
											if tt.wantErr {
							 | 
						|
												if err == nil {
							 | 
						|
													t.Errorf("Expected error, but got none")
							 | 
						|
												}
							 | 
						|
												return
							 | 
						|
											}
							 | 
						|
								
							 | 
						|
											if err != nil {
							 | 
						|
												t.Errorf("Unexpected error: %v", err)
							 | 
						|
												return
							 | 
						|
											}
							 | 
						|
								
							 | 
						|
											if tt.validate != nil {
							 | 
						|
												tt.validate(t, stmt)
							 | 
						|
											}
							 | 
						|
										})
							 | 
						|
									}
							 | 
						|
								}
							 | 
						|
								
							 | 
						|
								func TestParseSQL_SHOW_Statements(t *testing.T) {
							 | 
						|
									tests := []struct {
							 | 
						|
										name     string
							 | 
						|
										sql      string
							 | 
						|
										wantErr  bool
							 | 
						|
										validate func(t *testing.T, stmt Statement)
							 | 
						|
									}{
							 | 
						|
										{
							 | 
						|
											name:    "SHOW DATABASES",
							 | 
						|
											sql:     "SHOW DATABASES",
							 | 
						|
											wantErr: false,
							 | 
						|
											validate: func(t *testing.T, stmt Statement) {
							 | 
						|
												showStmt, ok := stmt.(*ShowStatement)
							 | 
						|
												if !ok {
							 | 
						|
													t.Fatalf("Expected *ShowStatement, got %T", stmt)
							 | 
						|
												}
							 | 
						|
								
							 | 
						|
												if showStmt.Type != "databases" {
							 | 
						|
													t.Errorf("Expected type 'databases', got '%s'", showStmt.Type)
							 | 
						|
												}
							 | 
						|
											},
							 | 
						|
										},
							 | 
						|
										{
							 | 
						|
											name:    "SHOW TABLES",
							 | 
						|
											sql:     "SHOW TABLES",
							 | 
						|
											wantErr: false,
							 | 
						|
											validate: func(t *testing.T, stmt Statement) {
							 | 
						|
												showStmt, ok := stmt.(*ShowStatement)
							 | 
						|
												if !ok {
							 | 
						|
													t.Fatalf("Expected *ShowStatement, got %T", stmt)
							 | 
						|
												}
							 | 
						|
								
							 | 
						|
												if showStmt.Type != "tables" {
							 | 
						|
													t.Errorf("Expected type 'tables', got '%s'", showStmt.Type)
							 | 
						|
												}
							 | 
						|
											},
							 | 
						|
										},
							 | 
						|
										{
							 | 
						|
											name:    "SHOW TABLES FROM database",
							 | 
						|
											sql:     "SHOW TABLES FROM \"test_db\"",
							 | 
						|
											wantErr: false,
							 | 
						|
											validate: func(t *testing.T, stmt Statement) {
							 | 
						|
												showStmt, ok := stmt.(*ShowStatement)
							 | 
						|
												if !ok {
							 | 
						|
													t.Fatalf("Expected *ShowStatement, got %T", stmt)
							 | 
						|
												}
							 | 
						|
								
							 | 
						|
												if showStmt.Type != "tables" {
							 | 
						|
													t.Errorf("Expected type 'tables', got '%s'", showStmt.Type)
							 | 
						|
												}
							 | 
						|
								
							 | 
						|
												if showStmt.Schema != "test_db" {
							 | 
						|
													t.Errorf("Expected schema 'test_db', got '%s'", showStmt.Schema)
							 | 
						|
												}
							 | 
						|
											},
							 | 
						|
										},
							 | 
						|
									}
							 | 
						|
								
							 | 
						|
									for _, tt := range tests {
							 | 
						|
										t.Run(tt.name, func(t *testing.T) {
							 | 
						|
											stmt, err := ParseSQL(tt.sql)
							 | 
						|
								
							 | 
						|
											if tt.wantErr {
							 | 
						|
												if err == nil {
							 | 
						|
													t.Errorf("Expected error, but got none")
							 | 
						|
												}
							 | 
						|
												return
							 | 
						|
											}
							 | 
						|
								
							 | 
						|
											if err != nil {
							 | 
						|
												t.Errorf("Unexpected error: %v", err)
							 | 
						|
												return
							 | 
						|
											}
							 | 
						|
								
							 | 
						|
											if tt.validate != nil {
							 | 
						|
												tt.validate(t, stmt)
							 | 
						|
											}
							 | 
						|
										})
							 | 
						|
									}
							 | 
						|
								}
							 |