You cannot select more than 25 topics
			Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
		
		
		
		
		
			
		
			
				
					
					
						
							275 lines
						
					
					
						
							7.7 KiB
						
					
					
				
			
		
		
		
			
			
			
		
		
	
	
							275 lines
						
					
					
						
							7.7 KiB
						
					
					
				| package engine | |
| 
 | |
| import ( | |
| 	"fmt" | |
| 	"testing" | |
| 
 | |
| 	"github.com/seaweedfs/seaweedfs/weed/pb/schema_pb" | |
| ) | |
| 
 | |
| func TestArithmeticExpressionParsing(t *testing.T) { | |
| 	tests := []struct { | |
| 		name       string | |
| 		expression string | |
| 		expectNil  bool | |
| 		leftCol    string | |
| 		rightCol   string | |
| 		operator   string | |
| 	}{ | |
| 		{ | |
| 			name:       "simple addition", | |
| 			expression: "id+user_id", | |
| 			expectNil:  false, | |
| 			leftCol:    "id", | |
| 			rightCol:   "user_id", | |
| 			operator:   "+", | |
| 		}, | |
| 		{ | |
| 			name:       "simple subtraction", | |
| 			expression: "col1-col2", | |
| 			expectNil:  false, | |
| 			leftCol:    "col1", | |
| 			rightCol:   "col2", | |
| 			operator:   "-", | |
| 		}, | |
| 		{ | |
| 			name:       "multiplication with spaces", | |
| 			expression: "a * b", | |
| 			expectNil:  false, | |
| 			leftCol:    "a", | |
| 			rightCol:   "b", | |
| 			operator:   "*", | |
| 		}, | |
| 		{ | |
| 			name:       "string concatenation", | |
| 			expression: "first_name||last_name", | |
| 			expectNil:  false, | |
| 			leftCol:    "first_name", | |
| 			rightCol:   "last_name", | |
| 			operator:   "||", | |
| 		}, | |
| 		{ | |
| 			name:       "string concatenation with spaces", | |
| 			expression: "prefix || suffix", | |
| 			expectNil:  false, | |
| 			leftCol:    "prefix", | |
| 			rightCol:   "suffix", | |
| 			operator:   "||", | |
| 		}, | |
| 		{ | |
| 			name:       "not arithmetic", | |
| 			expression: "simple_column", | |
| 			expectNil:  true, | |
| 		}, | |
| 	} | |
| 
 | |
| 	for _, tt := range tests { | |
| 		t.Run(tt.name, func(t *testing.T) { | |
| 			// Use CockroachDB parser to parse the expression | |
| 			cockroachParser := NewCockroachSQLParser() | |
| 			dummySelect := fmt.Sprintf("SELECT %s", tt.expression) | |
| 			stmt, err := cockroachParser.ParseSQL(dummySelect) | |
| 
 | |
| 			var result *ArithmeticExpr | |
| 			if err == nil { | |
| 				if selectStmt, ok := stmt.(*SelectStatement); ok && len(selectStmt.SelectExprs) > 0 { | |
| 					if aliasedExpr, ok := selectStmt.SelectExprs[0].(*AliasedExpr); ok { | |
| 						if arithmeticExpr, ok := aliasedExpr.Expr.(*ArithmeticExpr); ok { | |
| 							result = arithmeticExpr | |
| 						} | |
| 					} | |
| 				} | |
| 			} | |
| 
 | |
| 			if tt.expectNil { | |
| 				if result != nil { | |
| 					t.Errorf("Expected nil for %s, got %v", tt.expression, result) | |
| 				} | |
| 				return | |
| 			} | |
| 
 | |
| 			if result == nil { | |
| 				t.Errorf("Expected arithmetic expression for %s, got nil", tt.expression) | |
| 				return | |
| 			} | |
| 
 | |
| 			if result.Operator != tt.operator { | |
| 				t.Errorf("Expected operator %s, got %s", tt.operator, result.Operator) | |
| 			} | |
| 
 | |
| 			// Check left operand | |
| 			if leftCol, ok := result.Left.(*ColName); ok { | |
| 				if leftCol.Name.String() != tt.leftCol { | |
| 					t.Errorf("Expected left column %s, got %s", tt.leftCol, leftCol.Name.String()) | |
| 				} | |
| 			} else { | |
| 				t.Errorf("Expected left operand to be ColName, got %T", result.Left) | |
| 			} | |
| 
 | |
| 			// Check right operand | |
| 			if rightCol, ok := result.Right.(*ColName); ok { | |
| 				if rightCol.Name.String() != tt.rightCol { | |
| 					t.Errorf("Expected right column %s, got %s", tt.rightCol, rightCol.Name.String()) | |
| 				} | |
| 			} else { | |
| 				t.Errorf("Expected right operand to be ColName, got %T", result.Right) | |
| 			} | |
| 		}) | |
| 	} | |
| } | |
| 
 | |
| func TestArithmeticExpressionEvaluation(t *testing.T) { | |
| 	engine := NewSQLEngine("") | |
| 
 | |
| 	// Create test data | |
| 	result := HybridScanResult{ | |
| 		Values: map[string]*schema_pb.Value{ | |
| 			"id":         {Kind: &schema_pb.Value_Int64Value{Int64Value: 10}}, | |
| 			"user_id":    {Kind: &schema_pb.Value_Int64Value{Int64Value: 5}}, | |
| 			"price":      {Kind: &schema_pb.Value_DoubleValue{DoubleValue: 25.5}}, | |
| 			"qty":        {Kind: &schema_pb.Value_Int64Value{Int64Value: 3}}, | |
| 			"first_name": {Kind: &schema_pb.Value_StringValue{StringValue: "John"}}, | |
| 			"last_name":  {Kind: &schema_pb.Value_StringValue{StringValue: "Doe"}}, | |
| 			"prefix":     {Kind: &schema_pb.Value_StringValue{StringValue: "Hello"}}, | |
| 			"suffix":     {Kind: &schema_pb.Value_StringValue{StringValue: "World"}}, | |
| 		}, | |
| 	} | |
| 
 | |
| 	tests := []struct { | |
| 		name       string | |
| 		expression string | |
| 		expected   interface{} | |
| 	}{ | |
| 		{ | |
| 			name:       "integer addition", | |
| 			expression: "id+user_id", | |
| 			expected:   int64(15), | |
| 		}, | |
| 		{ | |
| 			name:       "integer subtraction", | |
| 			expression: "id-user_id", | |
| 			expected:   int64(5), | |
| 		}, | |
| 		{ | |
| 			name:       "mixed types multiplication", | |
| 			expression: "price*qty", | |
| 			expected:   float64(76.5), | |
| 		}, | |
| 		{ | |
| 			name:       "string concatenation", | |
| 			expression: "first_name||last_name", | |
| 			expected:   "JohnDoe", | |
| 		}, | |
| 		{ | |
| 			name:       "string concatenation with spaces", | |
| 			expression: "prefix || suffix", | |
| 			expected:   "HelloWorld", | |
| 		}, | |
| 	} | |
| 
 | |
| 	for _, tt := range tests { | |
| 		t.Run(tt.name, func(t *testing.T) { | |
| 			// Parse the arithmetic expression using CockroachDB parser | |
| 			cockroachParser := NewCockroachSQLParser() | |
| 			dummySelect := fmt.Sprintf("SELECT %s", tt.expression) | |
| 			stmt, err := cockroachParser.ParseSQL(dummySelect) | |
| 			if err != nil { | |
| 				t.Fatalf("Failed to parse expression %s: %v", tt.expression, err) | |
| 			} | |
| 
 | |
| 			var arithmeticExpr *ArithmeticExpr | |
| 			if selectStmt, ok := stmt.(*SelectStatement); ok && len(selectStmt.SelectExprs) > 0 { | |
| 				if aliasedExpr, ok := selectStmt.SelectExprs[0].(*AliasedExpr); ok { | |
| 					if arithExpr, ok := aliasedExpr.Expr.(*ArithmeticExpr); ok { | |
| 						arithmeticExpr = arithExpr | |
| 					} | |
| 				} | |
| 			} | |
| 
 | |
| 			if arithmeticExpr == nil { | |
| 				t.Fatalf("Failed to parse arithmetic expression: %s", tt.expression) | |
| 			} | |
| 
 | |
| 			// Evaluate the expression | |
| 			value, err := engine.evaluateArithmeticExpression(arithmeticExpr, result) | |
| 			if err != nil { | |
| 				t.Fatalf("Failed to evaluate expression: %v", err) | |
| 			} | |
| 
 | |
| 			if value == nil { | |
| 				t.Fatalf("Got nil value for expression: %s", tt.expression) | |
| 			} | |
| 
 | |
| 			// Check the result | |
| 			switch expected := tt.expected.(type) { | |
| 			case int64: | |
| 				if intVal, ok := value.Kind.(*schema_pb.Value_Int64Value); ok { | |
| 					if intVal.Int64Value != expected { | |
| 						t.Errorf("Expected %d, got %d", expected, intVal.Int64Value) | |
| 					} | |
| 				} else { | |
| 					t.Errorf("Expected int64 result, got %T", value.Kind) | |
| 				} | |
| 			case float64: | |
| 				if doubleVal, ok := value.Kind.(*schema_pb.Value_DoubleValue); ok { | |
| 					if doubleVal.DoubleValue != expected { | |
| 						t.Errorf("Expected %f, got %f", expected, doubleVal.DoubleValue) | |
| 					} | |
| 				} else { | |
| 					t.Errorf("Expected double result, got %T", value.Kind) | |
| 				} | |
| 			case string: | |
| 				if stringVal, ok := value.Kind.(*schema_pb.Value_StringValue); ok { | |
| 					if stringVal.StringValue != expected { | |
| 						t.Errorf("Expected %s, got %s", expected, stringVal.StringValue) | |
| 					} | |
| 				} else { | |
| 					t.Errorf("Expected string result, got %T", value.Kind) | |
| 				} | |
| 			} | |
| 		}) | |
| 	} | |
| } | |
| 
 | |
| func TestSelectArithmeticExpression(t *testing.T) { | |
| 	// Test parsing a SELECT with arithmetic and string concatenation expressions | |
| 	stmt, err := ParseSQL("SELECT id+user_id, user_id*2, first_name||last_name FROM test_table") | |
| 	if err != nil { | |
| 		t.Fatalf("Failed to parse SQL: %v", err) | |
| 	} | |
| 
 | |
| 	selectStmt := stmt.(*SelectStatement) | |
| 	if len(selectStmt.SelectExprs) != 3 { | |
| 		t.Fatalf("Expected 3 select expressions, got %d", len(selectStmt.SelectExprs)) | |
| 	} | |
| 
 | |
| 	// Check first expression (id+user_id) | |
| 	aliasedExpr1 := selectStmt.SelectExprs[0].(*AliasedExpr) | |
| 	if arithmeticExpr1, ok := aliasedExpr1.Expr.(*ArithmeticExpr); ok { | |
| 		if arithmeticExpr1.Operator != "+" { | |
| 			t.Errorf("Expected + operator, got %s", arithmeticExpr1.Operator) | |
| 		} | |
| 	} else { | |
| 		t.Errorf("Expected arithmetic expression, got %T", aliasedExpr1.Expr) | |
| 	} | |
| 
 | |
| 	// Check second expression (user_id*2) | |
| 	aliasedExpr2 := selectStmt.SelectExprs[1].(*AliasedExpr) | |
| 	if arithmeticExpr2, ok := aliasedExpr2.Expr.(*ArithmeticExpr); ok { | |
| 		if arithmeticExpr2.Operator != "*" { | |
| 			t.Errorf("Expected * operator, got %s", arithmeticExpr2.Operator) | |
| 		} | |
| 	} else { | |
| 		t.Errorf("Expected arithmetic expression, got %T", aliasedExpr2.Expr) | |
| 	} | |
| 
 | |
| 	// Check third expression (first_name||last_name) | |
| 	aliasedExpr3 := selectStmt.SelectExprs[2].(*AliasedExpr) | |
| 	if arithmeticExpr3, ok := aliasedExpr3.Expr.(*ArithmeticExpr); ok { | |
| 		if arithmeticExpr3.Operator != "||" { | |
| 			t.Errorf("Expected || operator, got %s", arithmeticExpr3.Operator) | |
| 		} | |
| 	} else { | |
| 		t.Errorf("Expected string concatenation expression, got %T", aliasedExpr3.Expr) | |
| 	} | |
| }
 |