You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
275 lines
7.7 KiB
275 lines
7.7 KiB
package engine
|
|
|
|
import (
|
|
"fmt"
|
|
"testing"
|
|
|
|
"github.com/seaweedfs/seaweedfs/weed/pb/schema_pb"
|
|
)
|
|
|
|
func TestArithmeticExpressionParsing(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
expression string
|
|
expectNil bool
|
|
leftCol string
|
|
rightCol string
|
|
operator string
|
|
}{
|
|
{
|
|
name: "simple addition",
|
|
expression: "id+user_id",
|
|
expectNil: false,
|
|
leftCol: "id",
|
|
rightCol: "user_id",
|
|
operator: "+",
|
|
},
|
|
{
|
|
name: "simple subtraction",
|
|
expression: "col1-col2",
|
|
expectNil: false,
|
|
leftCol: "col1",
|
|
rightCol: "col2",
|
|
operator: "-",
|
|
},
|
|
{
|
|
name: "multiplication with spaces",
|
|
expression: "a * b",
|
|
expectNil: false,
|
|
leftCol: "a",
|
|
rightCol: "b",
|
|
operator: "*",
|
|
},
|
|
{
|
|
name: "string concatenation",
|
|
expression: "first_name||last_name",
|
|
expectNil: false,
|
|
leftCol: "first_name",
|
|
rightCol: "last_name",
|
|
operator: "||",
|
|
},
|
|
{
|
|
name: "string concatenation with spaces",
|
|
expression: "prefix || suffix",
|
|
expectNil: false,
|
|
leftCol: "prefix",
|
|
rightCol: "suffix",
|
|
operator: "||",
|
|
},
|
|
{
|
|
name: "not arithmetic",
|
|
expression: "simple_column",
|
|
expectNil: true,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
// Use CockroachDB parser to parse the expression
|
|
cockroachParser := NewCockroachSQLParser()
|
|
dummySelect := fmt.Sprintf("SELECT %s", tt.expression)
|
|
stmt, err := cockroachParser.ParseSQL(dummySelect)
|
|
|
|
var result *ArithmeticExpr
|
|
if err == nil {
|
|
if selectStmt, ok := stmt.(*SelectStatement); ok && len(selectStmt.SelectExprs) > 0 {
|
|
if aliasedExpr, ok := selectStmt.SelectExprs[0].(*AliasedExpr); ok {
|
|
if arithmeticExpr, ok := aliasedExpr.Expr.(*ArithmeticExpr); ok {
|
|
result = arithmeticExpr
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
if tt.expectNil {
|
|
if result != nil {
|
|
t.Errorf("Expected nil for %s, got %v", tt.expression, result)
|
|
}
|
|
return
|
|
}
|
|
|
|
if result == nil {
|
|
t.Errorf("Expected arithmetic expression for %s, got nil", tt.expression)
|
|
return
|
|
}
|
|
|
|
if result.Operator != tt.operator {
|
|
t.Errorf("Expected operator %s, got %s", tt.operator, result.Operator)
|
|
}
|
|
|
|
// Check left operand
|
|
if leftCol, ok := result.Left.(*ColName); ok {
|
|
if leftCol.Name.String() != tt.leftCol {
|
|
t.Errorf("Expected left column %s, got %s", tt.leftCol, leftCol.Name.String())
|
|
}
|
|
} else {
|
|
t.Errorf("Expected left operand to be ColName, got %T", result.Left)
|
|
}
|
|
|
|
// Check right operand
|
|
if rightCol, ok := result.Right.(*ColName); ok {
|
|
if rightCol.Name.String() != tt.rightCol {
|
|
t.Errorf("Expected right column %s, got %s", tt.rightCol, rightCol.Name.String())
|
|
}
|
|
} else {
|
|
t.Errorf("Expected right operand to be ColName, got %T", result.Right)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestArithmeticExpressionEvaluation(t *testing.T) {
|
|
engine := NewSQLEngine("")
|
|
|
|
// Create test data
|
|
result := HybridScanResult{
|
|
Values: map[string]*schema_pb.Value{
|
|
"id": {Kind: &schema_pb.Value_Int64Value{Int64Value: 10}},
|
|
"user_id": {Kind: &schema_pb.Value_Int64Value{Int64Value: 5}},
|
|
"price": {Kind: &schema_pb.Value_DoubleValue{DoubleValue: 25.5}},
|
|
"qty": {Kind: &schema_pb.Value_Int64Value{Int64Value: 3}},
|
|
"first_name": {Kind: &schema_pb.Value_StringValue{StringValue: "John"}},
|
|
"last_name": {Kind: &schema_pb.Value_StringValue{StringValue: "Doe"}},
|
|
"prefix": {Kind: &schema_pb.Value_StringValue{StringValue: "Hello"}},
|
|
"suffix": {Kind: &schema_pb.Value_StringValue{StringValue: "World"}},
|
|
},
|
|
}
|
|
|
|
tests := []struct {
|
|
name string
|
|
expression string
|
|
expected interface{}
|
|
}{
|
|
{
|
|
name: "integer addition",
|
|
expression: "id+user_id",
|
|
expected: int64(15),
|
|
},
|
|
{
|
|
name: "integer subtraction",
|
|
expression: "id-user_id",
|
|
expected: int64(5),
|
|
},
|
|
{
|
|
name: "mixed types multiplication",
|
|
expression: "price*qty",
|
|
expected: float64(76.5),
|
|
},
|
|
{
|
|
name: "string concatenation",
|
|
expression: "first_name||last_name",
|
|
expected: "JohnDoe",
|
|
},
|
|
{
|
|
name: "string concatenation with spaces",
|
|
expression: "prefix || suffix",
|
|
expected: "HelloWorld",
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
// Parse the arithmetic expression using CockroachDB parser
|
|
cockroachParser := NewCockroachSQLParser()
|
|
dummySelect := fmt.Sprintf("SELECT %s", tt.expression)
|
|
stmt, err := cockroachParser.ParseSQL(dummySelect)
|
|
if err != nil {
|
|
t.Fatalf("Failed to parse expression %s: %v", tt.expression, err)
|
|
}
|
|
|
|
var arithmeticExpr *ArithmeticExpr
|
|
if selectStmt, ok := stmt.(*SelectStatement); ok && len(selectStmt.SelectExprs) > 0 {
|
|
if aliasedExpr, ok := selectStmt.SelectExprs[0].(*AliasedExpr); ok {
|
|
if arithExpr, ok := aliasedExpr.Expr.(*ArithmeticExpr); ok {
|
|
arithmeticExpr = arithExpr
|
|
}
|
|
}
|
|
}
|
|
|
|
if arithmeticExpr == nil {
|
|
t.Fatalf("Failed to parse arithmetic expression: %s", tt.expression)
|
|
}
|
|
|
|
// Evaluate the expression
|
|
value, err := engine.evaluateArithmeticExpression(arithmeticExpr, result)
|
|
if err != nil {
|
|
t.Fatalf("Failed to evaluate expression: %v", err)
|
|
}
|
|
|
|
if value == nil {
|
|
t.Fatalf("Got nil value for expression: %s", tt.expression)
|
|
}
|
|
|
|
// Check the result
|
|
switch expected := tt.expected.(type) {
|
|
case int64:
|
|
if intVal, ok := value.Kind.(*schema_pb.Value_Int64Value); ok {
|
|
if intVal.Int64Value != expected {
|
|
t.Errorf("Expected %d, got %d", expected, intVal.Int64Value)
|
|
}
|
|
} else {
|
|
t.Errorf("Expected int64 result, got %T", value.Kind)
|
|
}
|
|
case float64:
|
|
if doubleVal, ok := value.Kind.(*schema_pb.Value_DoubleValue); ok {
|
|
if doubleVal.DoubleValue != expected {
|
|
t.Errorf("Expected %f, got %f", expected, doubleVal.DoubleValue)
|
|
}
|
|
} else {
|
|
t.Errorf("Expected double result, got %T", value.Kind)
|
|
}
|
|
case string:
|
|
if stringVal, ok := value.Kind.(*schema_pb.Value_StringValue); ok {
|
|
if stringVal.StringValue != expected {
|
|
t.Errorf("Expected %s, got %s", expected, stringVal.StringValue)
|
|
}
|
|
} else {
|
|
t.Errorf("Expected string result, got %T", value.Kind)
|
|
}
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestSelectArithmeticExpression(t *testing.T) {
|
|
// Test parsing a SELECT with arithmetic and string concatenation expressions
|
|
stmt, err := ParseSQL("SELECT id+user_id, user_id*2, first_name||last_name FROM test_table")
|
|
if err != nil {
|
|
t.Fatalf("Failed to parse SQL: %v", err)
|
|
}
|
|
|
|
selectStmt := stmt.(*SelectStatement)
|
|
if len(selectStmt.SelectExprs) != 3 {
|
|
t.Fatalf("Expected 3 select expressions, got %d", len(selectStmt.SelectExprs))
|
|
}
|
|
|
|
// Check first expression (id+user_id)
|
|
aliasedExpr1 := selectStmt.SelectExprs[0].(*AliasedExpr)
|
|
if arithmeticExpr1, ok := aliasedExpr1.Expr.(*ArithmeticExpr); ok {
|
|
if arithmeticExpr1.Operator != "+" {
|
|
t.Errorf("Expected + operator, got %s", arithmeticExpr1.Operator)
|
|
}
|
|
} else {
|
|
t.Errorf("Expected arithmetic expression, got %T", aliasedExpr1.Expr)
|
|
}
|
|
|
|
// Check second expression (user_id*2)
|
|
aliasedExpr2 := selectStmt.SelectExprs[1].(*AliasedExpr)
|
|
if arithmeticExpr2, ok := aliasedExpr2.Expr.(*ArithmeticExpr); ok {
|
|
if arithmeticExpr2.Operator != "*" {
|
|
t.Errorf("Expected * operator, got %s", arithmeticExpr2.Operator)
|
|
}
|
|
} else {
|
|
t.Errorf("Expected arithmetic expression, got %T", aliasedExpr2.Expr)
|
|
}
|
|
|
|
// Check third expression (first_name||last_name)
|
|
aliasedExpr3 := selectStmt.SelectExprs[2].(*AliasedExpr)
|
|
if arithmeticExpr3, ok := aliasedExpr3.Expr.(*ArithmeticExpr); ok {
|
|
if arithmeticExpr3.Operator != "||" {
|
|
t.Errorf("Expected || operator, got %s", arithmeticExpr3.Operator)
|
|
}
|
|
} else {
|
|
t.Errorf("Expected string concatenation expression, got %T", aliasedExpr3.Expr)
|
|
}
|
|
}
|