Browse Source
feat: Add basic arithmetic operators (+, -, *, /, %) with comprehensive tests
feat: Add basic arithmetic operators (+, -, *, /, %) with comprehensive tests
- Implement EvaluateArithmeticExpression with support for all basic operators
- Handle type conversions between int, float, string, and boolean
- Add proper error handling for division/modulo by zero
- Include 14 comprehensive test cases covering all edge cases
- Support mixed type arithmetic (int + float, string numbers, etc.)
All tests passing ✅
pull/7185/head
2 changed files with 401 additions and 0 deletions
@ -0,0 +1,142 @@ |
|||||
|
package engine |
||||
|
|
||||
|
import ( |
||||
|
"fmt" |
||||
|
"math" |
||||
|
"strconv" |
||||
|
|
||||
|
"github.com/seaweedfs/seaweedfs/weed/pb/schema_pb" |
||||
|
) |
||||
|
|
||||
|
// ArithmeticOperator represents basic arithmetic operations
|
||||
|
type ArithmeticOperator string |
||||
|
|
||||
|
const ( |
||||
|
OpAdd ArithmeticOperator = "+" |
||||
|
OpSub ArithmeticOperator = "-" |
||||
|
OpMul ArithmeticOperator = "*" |
||||
|
OpDiv ArithmeticOperator = "/" |
||||
|
OpMod ArithmeticOperator = "%" |
||||
|
) |
||||
|
|
||||
|
// EvaluateArithmeticExpression evaluates basic arithmetic operations between two values
|
||||
|
func (e *SQLEngine) EvaluateArithmeticExpression(left, right *schema_pb.Value, operator ArithmeticOperator) (*schema_pb.Value, error) { |
||||
|
if left == nil || right == nil { |
||||
|
return nil, fmt.Errorf("arithmetic operation requires non-null operands") |
||||
|
} |
||||
|
|
||||
|
// Convert values to numeric types for calculation
|
||||
|
leftNum, err := e.valueToFloat64(left) |
||||
|
if err != nil { |
||||
|
return nil, fmt.Errorf("left operand conversion error: %v", err) |
||||
|
} |
||||
|
|
||||
|
rightNum, err := e.valueToFloat64(right) |
||||
|
if err != nil { |
||||
|
return nil, fmt.Errorf("right operand conversion error: %v", err) |
||||
|
} |
||||
|
|
||||
|
var result float64 |
||||
|
var resultErr error |
||||
|
|
||||
|
switch operator { |
||||
|
case OpAdd: |
||||
|
result = leftNum + rightNum |
||||
|
case OpSub: |
||||
|
result = leftNum - rightNum |
||||
|
case OpMul: |
||||
|
result = leftNum * rightNum |
||||
|
case OpDiv: |
||||
|
if rightNum == 0 { |
||||
|
return nil, fmt.Errorf("division by zero") |
||||
|
} |
||||
|
result = leftNum / rightNum |
||||
|
case OpMod: |
||||
|
if rightNum == 0 { |
||||
|
return nil, fmt.Errorf("modulo by zero") |
||||
|
} |
||||
|
result = math.Mod(leftNum, rightNum) |
||||
|
default: |
||||
|
return nil, fmt.Errorf("unsupported arithmetic operator: %s", operator) |
||||
|
} |
||||
|
|
||||
|
if resultErr != nil { |
||||
|
return nil, resultErr |
||||
|
} |
||||
|
|
||||
|
// Convert result back to appropriate schema value type
|
||||
|
// If both operands were integers and operation doesn't produce decimal, return integer
|
||||
|
if e.isIntegerValue(left) && e.isIntegerValue(right) && |
||||
|
(operator == OpAdd || operator == OpSub || operator == OpMul || operator == OpMod) { |
||||
|
return &schema_pb.Value{ |
||||
|
Kind: &schema_pb.Value_Int64Value{Int64Value: int64(result)}, |
||||
|
}, nil |
||||
|
} |
||||
|
|
||||
|
// Otherwise return as double/float
|
||||
|
return &schema_pb.Value{ |
||||
|
Kind: &schema_pb.Value_DoubleValue{DoubleValue: result}, |
||||
|
}, nil |
||||
|
} |
||||
|
|
||||
|
// Helper function to convert schema_pb.Value to float64
|
||||
|
func (e *SQLEngine) valueToFloat64(value *schema_pb.Value) (float64, error) { |
||||
|
switch v := value.Kind.(type) { |
||||
|
case *schema_pb.Value_Int32Value: |
||||
|
return float64(v.Int32Value), nil |
||||
|
case *schema_pb.Value_Int64Value: |
||||
|
return float64(v.Int64Value), nil |
||||
|
case *schema_pb.Value_FloatValue: |
||||
|
return float64(v.FloatValue), nil |
||||
|
case *schema_pb.Value_DoubleValue: |
||||
|
return v.DoubleValue, nil |
||||
|
case *schema_pb.Value_StringValue: |
||||
|
// Try to parse string as number
|
||||
|
if f, err := strconv.ParseFloat(v.StringValue, 64); err == nil { |
||||
|
return f, nil |
||||
|
} |
||||
|
return 0, fmt.Errorf("cannot convert string '%s' to number", v.StringValue) |
||||
|
case *schema_pb.Value_BoolValue: |
||||
|
if v.BoolValue { |
||||
|
return 1, nil |
||||
|
} |
||||
|
return 0, nil |
||||
|
default: |
||||
|
return 0, fmt.Errorf("cannot convert value type to number") |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
// Helper function to check if a value is an integer type
|
||||
|
func (e *SQLEngine) isIntegerValue(value *schema_pb.Value) bool { |
||||
|
switch value.Kind.(type) { |
||||
|
case *schema_pb.Value_Int32Value, *schema_pb.Value_Int64Value: |
||||
|
return true |
||||
|
default: |
||||
|
return false |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
// Add evaluates addition (left + right)
|
||||
|
func (e *SQLEngine) Add(left, right *schema_pb.Value) (*schema_pb.Value, error) { |
||||
|
return e.EvaluateArithmeticExpression(left, right, OpAdd) |
||||
|
} |
||||
|
|
||||
|
// Subtract evaluates subtraction (left - right)
|
||||
|
func (e *SQLEngine) Subtract(left, right *schema_pb.Value) (*schema_pb.Value, error) { |
||||
|
return e.EvaluateArithmeticExpression(left, right, OpSub) |
||||
|
} |
||||
|
|
||||
|
// Multiply evaluates multiplication (left * right)
|
||||
|
func (e *SQLEngine) Multiply(left, right *schema_pb.Value) (*schema_pb.Value, error) { |
||||
|
return e.EvaluateArithmeticExpression(left, right, OpMul) |
||||
|
} |
||||
|
|
||||
|
// Divide evaluates division (left / right)
|
||||
|
func (e *SQLEngine) Divide(left, right *schema_pb.Value) (*schema_pb.Value, error) { |
||||
|
return e.EvaluateArithmeticExpression(left, right, OpDiv) |
||||
|
} |
||||
|
|
||||
|
// Modulo evaluates modulo operation (left % right)
|
||||
|
func (e *SQLEngine) Modulo(left, right *schema_pb.Value) (*schema_pb.Value, error) { |
||||
|
return e.EvaluateArithmeticExpression(left, right, OpMod) |
||||
|
} |
@ -0,0 +1,259 @@ |
|||||
|
package engine |
||||
|
|
||||
|
import ( |
||||
|
"testing" |
||||
|
|
||||
|
"github.com/seaweedfs/seaweedfs/weed/pb/schema_pb" |
||||
|
) |
||||
|
|
||||
|
func TestArithmeticOperations(t *testing.T) { |
||||
|
engine := NewTestSQLEngine() |
||||
|
|
||||
|
tests := []struct { |
||||
|
name string |
||||
|
left *schema_pb.Value |
||||
|
right *schema_pb.Value |
||||
|
operator ArithmeticOperator |
||||
|
expected *schema_pb.Value |
||||
|
expectErr bool |
||||
|
}{ |
||||
|
// Addition tests
|
||||
|
{ |
||||
|
name: "Add two integers", |
||||
|
left: &schema_pb.Value{Kind: &schema_pb.Value_Int64Value{Int64Value: 10}}, |
||||
|
right: &schema_pb.Value{Kind: &schema_pb.Value_Int64Value{Int64Value: 5}}, |
||||
|
operator: OpAdd, |
||||
|
expected: &schema_pb.Value{Kind: &schema_pb.Value_Int64Value{Int64Value: 15}}, |
||||
|
expectErr: false, |
||||
|
}, |
||||
|
{ |
||||
|
name: "Add integer and float", |
||||
|
left: &schema_pb.Value{Kind: &schema_pb.Value_Int64Value{Int64Value: 10}}, |
||||
|
right: &schema_pb.Value{Kind: &schema_pb.Value_DoubleValue{DoubleValue: 5.5}}, |
||||
|
operator: OpAdd, |
||||
|
expected: &schema_pb.Value{Kind: &schema_pb.Value_DoubleValue{DoubleValue: 15.5}}, |
||||
|
expectErr: false, |
||||
|
}, |
||||
|
// Subtraction tests
|
||||
|
{ |
||||
|
name: "Subtract two integers", |
||||
|
left: &schema_pb.Value{Kind: &schema_pb.Value_Int64Value{Int64Value: 10}}, |
||||
|
right: &schema_pb.Value{Kind: &schema_pb.Value_Int64Value{Int64Value: 3}}, |
||||
|
operator: OpSub, |
||||
|
expected: &schema_pb.Value{Kind: &schema_pb.Value_Int64Value{Int64Value: 7}}, |
||||
|
expectErr: false, |
||||
|
}, |
||||
|
// Multiplication tests
|
||||
|
{ |
||||
|
name: "Multiply two integers", |
||||
|
left: &schema_pb.Value{Kind: &schema_pb.Value_Int64Value{Int64Value: 6}}, |
||||
|
right: &schema_pb.Value{Kind: &schema_pb.Value_Int64Value{Int64Value: 7}}, |
||||
|
operator: OpMul, |
||||
|
expected: &schema_pb.Value{Kind: &schema_pb.Value_Int64Value{Int64Value: 42}}, |
||||
|
expectErr: false, |
||||
|
}, |
||||
|
{ |
||||
|
name: "Multiply with float", |
||||
|
left: &schema_pb.Value{Kind: &schema_pb.Value_Int64Value{Int64Value: 5}}, |
||||
|
right: &schema_pb.Value{Kind: &schema_pb.Value_FloatValue{FloatValue: 2.5}}, |
||||
|
operator: OpMul, |
||||
|
expected: &schema_pb.Value{Kind: &schema_pb.Value_DoubleValue{DoubleValue: 12.5}}, |
||||
|
expectErr: false, |
||||
|
}, |
||||
|
// Division tests
|
||||
|
{ |
||||
|
name: "Divide two integers", |
||||
|
left: &schema_pb.Value{Kind: &schema_pb.Value_Int64Value{Int64Value: 20}}, |
||||
|
right: &schema_pb.Value{Kind: &schema_pb.Value_Int64Value{Int64Value: 4}}, |
||||
|
operator: OpDiv, |
||||
|
expected: &schema_pb.Value{Kind: &schema_pb.Value_DoubleValue{DoubleValue: 5.0}}, |
||||
|
expectErr: false, |
||||
|
}, |
||||
|
{ |
||||
|
name: "Division by zero", |
||||
|
left: &schema_pb.Value{Kind: &schema_pb.Value_Int64Value{Int64Value: 10}}, |
||||
|
right: &schema_pb.Value{Kind: &schema_pb.Value_Int64Value{Int64Value: 0}}, |
||||
|
operator: OpDiv, |
||||
|
expected: nil, |
||||
|
expectErr: true, |
||||
|
}, |
||||
|
// Modulo tests
|
||||
|
{ |
||||
|
name: "Modulo operation", |
||||
|
left: &schema_pb.Value{Kind: &schema_pb.Value_Int64Value{Int64Value: 17}}, |
||||
|
right: &schema_pb.Value{Kind: &schema_pb.Value_Int64Value{Int64Value: 5}}, |
||||
|
operator: OpMod, |
||||
|
expected: &schema_pb.Value{Kind: &schema_pb.Value_Int64Value{Int64Value: 2}}, |
||||
|
expectErr: false, |
||||
|
}, |
||||
|
{ |
||||
|
name: "Modulo by zero", |
||||
|
left: &schema_pb.Value{Kind: &schema_pb.Value_Int64Value{Int64Value: 10}}, |
||||
|
right: &schema_pb.Value{Kind: &schema_pb.Value_Int64Value{Int64Value: 0}}, |
||||
|
operator: OpMod, |
||||
|
expected: nil, |
||||
|
expectErr: true, |
||||
|
}, |
||||
|
// String conversion tests
|
||||
|
{ |
||||
|
name: "Add string number to integer", |
||||
|
left: &schema_pb.Value{Kind: &schema_pb.Value_StringValue{StringValue: "15"}}, |
||||
|
right: &schema_pb.Value{Kind: &schema_pb.Value_Int64Value{Int64Value: 5}}, |
||||
|
operator: OpAdd, |
||||
|
expected: &schema_pb.Value{Kind: &schema_pb.Value_DoubleValue{DoubleValue: 20.0}}, |
||||
|
expectErr: false, |
||||
|
}, |
||||
|
{ |
||||
|
name: "Invalid string conversion", |
||||
|
left: &schema_pb.Value{Kind: &schema_pb.Value_StringValue{StringValue: "not_a_number"}}, |
||||
|
right: &schema_pb.Value{Kind: &schema_pb.Value_Int64Value{Int64Value: 5}}, |
||||
|
operator: OpAdd, |
||||
|
expected: nil, |
||||
|
expectErr: true, |
||||
|
}, |
||||
|
// Boolean conversion tests
|
||||
|
{ |
||||
|
name: "Add boolean to integer", |
||||
|
left: &schema_pb.Value{Kind: &schema_pb.Value_BoolValue{BoolValue: true}}, |
||||
|
right: &schema_pb.Value{Kind: &schema_pb.Value_Int64Value{Int64Value: 5}}, |
||||
|
operator: OpAdd, |
||||
|
expected: &schema_pb.Value{Kind: &schema_pb.Value_DoubleValue{DoubleValue: 6.0}}, |
||||
|
expectErr: false, |
||||
|
}, |
||||
|
// Null value tests
|
||||
|
{ |
||||
|
name: "Add with null left operand", |
||||
|
left: nil, |
||||
|
right: &schema_pb.Value{Kind: &schema_pb.Value_Int64Value{Int64Value: 5}}, |
||||
|
operator: OpAdd, |
||||
|
expected: nil, |
||||
|
expectErr: true, |
||||
|
}, |
||||
|
{ |
||||
|
name: "Add with null right operand", |
||||
|
left: &schema_pb.Value{Kind: &schema_pb.Value_Int64Value{Int64Value: 5}}, |
||||
|
right: nil, |
||||
|
operator: OpAdd, |
||||
|
expected: nil, |
||||
|
expectErr: true, |
||||
|
}, |
||||
|
} |
||||
|
|
||||
|
for _, tt := range tests { |
||||
|
t.Run(tt.name, func(t *testing.T) { |
||||
|
result, err := engine.EvaluateArithmeticExpression(tt.left, tt.right, tt.operator) |
||||
|
|
||||
|
if tt.expectErr { |
||||
|
if err == nil { |
||||
|
t.Errorf("Expected error but got none") |
||||
|
} |
||||
|
return |
||||
|
} |
||||
|
|
||||
|
if err != nil { |
||||
|
t.Errorf("Unexpected error: %v", err) |
||||
|
return |
||||
|
} |
||||
|
|
||||
|
if !valuesEqual(result, tt.expected) { |
||||
|
t.Errorf("Expected %v, got %v", tt.expected, result) |
||||
|
} |
||||
|
}) |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
func TestIndividualArithmeticFunctions(t *testing.T) { |
||||
|
engine := NewTestSQLEngine() |
||||
|
|
||||
|
left := &schema_pb.Value{Kind: &schema_pb.Value_Int64Value{Int64Value: 10}} |
||||
|
right := &schema_pb.Value{Kind: &schema_pb.Value_Int64Value{Int64Value: 3}} |
||||
|
|
||||
|
// Test Add function
|
||||
|
result, err := engine.Add(left, right) |
||||
|
if err != nil { |
||||
|
t.Errorf("Add function failed: %v", err) |
||||
|
} |
||||
|
expected := &schema_pb.Value{Kind: &schema_pb.Value_Int64Value{Int64Value: 13}} |
||||
|
if !valuesEqual(result, expected) { |
||||
|
t.Errorf("Add: Expected %v, got %v", expected, result) |
||||
|
} |
||||
|
|
||||
|
// Test Subtract function
|
||||
|
result, err = engine.Subtract(left, right) |
||||
|
if err != nil { |
||||
|
t.Errorf("Subtract function failed: %v", err) |
||||
|
} |
||||
|
expected = &schema_pb.Value{Kind: &schema_pb.Value_Int64Value{Int64Value: 7}} |
||||
|
if !valuesEqual(result, expected) { |
||||
|
t.Errorf("Subtract: Expected %v, got %v", expected, result) |
||||
|
} |
||||
|
|
||||
|
// Test Multiply function
|
||||
|
result, err = engine.Multiply(left, right) |
||||
|
if err != nil { |
||||
|
t.Errorf("Multiply function failed: %v", err) |
||||
|
} |
||||
|
expected = &schema_pb.Value{Kind: &schema_pb.Value_Int64Value{Int64Value: 30}} |
||||
|
if !valuesEqual(result, expected) { |
||||
|
t.Errorf("Multiply: Expected %v, got %v", expected, result) |
||||
|
} |
||||
|
|
||||
|
// Test Divide function
|
||||
|
result, err = engine.Divide(left, right) |
||||
|
if err != nil { |
||||
|
t.Errorf("Divide function failed: %v", err) |
||||
|
} |
||||
|
expected = &schema_pb.Value{Kind: &schema_pb.Value_DoubleValue{DoubleValue: 10.0/3.0}} |
||||
|
if !valuesEqual(result, expected) { |
||||
|
t.Errorf("Divide: Expected %v, got %v", expected, result) |
||||
|
} |
||||
|
|
||||
|
// Test Modulo function
|
||||
|
result, err = engine.Modulo(left, right) |
||||
|
if err != nil { |
||||
|
t.Errorf("Modulo function failed: %v", err) |
||||
|
} |
||||
|
expected = &schema_pb.Value{Kind: &schema_pb.Value_Int64Value{Int64Value: 1}} |
||||
|
if !valuesEqual(result, expected) { |
||||
|
t.Errorf("Modulo: Expected %v, got %v", expected, result) |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
// Helper function to compare two schema_pb.Value objects
|
||||
|
func valuesEqual(v1, v2 *schema_pb.Value) bool { |
||||
|
if v1 == nil && v2 == nil { |
||||
|
return true |
||||
|
} |
||||
|
if v1 == nil || v2 == nil { |
||||
|
return false |
||||
|
} |
||||
|
|
||||
|
switch v1Kind := v1.Kind.(type) { |
||||
|
case *schema_pb.Value_Int32Value: |
||||
|
if v2Kind, ok := v2.Kind.(*schema_pb.Value_Int32Value); ok { |
||||
|
return v1Kind.Int32Value == v2Kind.Int32Value |
||||
|
} |
||||
|
case *schema_pb.Value_Int64Value: |
||||
|
if v2Kind, ok := v2.Kind.(*schema_pb.Value_Int64Value); ok { |
||||
|
return v1Kind.Int64Value == v2Kind.Int64Value |
||||
|
} |
||||
|
case *schema_pb.Value_FloatValue: |
||||
|
if v2Kind, ok := v2.Kind.(*schema_pb.Value_FloatValue); ok { |
||||
|
return v1Kind.FloatValue == v2Kind.FloatValue |
||||
|
} |
||||
|
case *schema_pb.Value_DoubleValue: |
||||
|
if v2Kind, ok := v2.Kind.(*schema_pb.Value_DoubleValue); ok { |
||||
|
return v1Kind.DoubleValue == v2Kind.DoubleValue |
||||
|
} |
||||
|
case *schema_pb.Value_StringValue: |
||||
|
if v2Kind, ok := v2.Kind.(*schema_pb.Value_StringValue); ok { |
||||
|
return v1Kind.StringValue == v2Kind.StringValue |
||||
|
} |
||||
|
case *schema_pb.Value_BoolValue: |
||||
|
if v2Kind, ok := v2.Kind.(*schema_pb.Value_BoolValue); ok { |
||||
|
return v1Kind.BoolValue == v2Kind.BoolValue |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
return false |
||||
|
} |
Write
Preview
Loading…
Cancel
Save
Reference in new issue