|
@ -168,9 +168,11 @@ func (e *SQLEngine) executeSelectStatement(ctx context.Context, stmt *sqlparser. |
|
|
return e.executeSelectWithSampleData(ctx, stmt, database, tableName) |
|
|
return e.executeSelectWithSampleData(ctx, stmt, database, tableName) |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
// Parse SELECT columns
|
|
|
|
|
|
|
|
|
// Parse SELECT columns and detect aggregation functions
|
|
|
var columns []string |
|
|
var columns []string |
|
|
|
|
|
var aggregations []AggregationSpec |
|
|
selectAll := false |
|
|
selectAll := false |
|
|
|
|
|
hasAggregations := false |
|
|
|
|
|
|
|
|
for _, selectExpr := range stmt.SelectExprs { |
|
|
for _, selectExpr := range stmt.SelectExprs { |
|
|
switch expr := selectExpr.(type) { |
|
|
switch expr := selectExpr.(type) { |
|
@ -180,6 +182,14 @@ func (e *SQLEngine) executeSelectStatement(ctx context.Context, stmt *sqlparser. |
|
|
switch col := expr.Expr.(type) { |
|
|
switch col := expr.Expr.(type) { |
|
|
case *sqlparser.ColName: |
|
|
case *sqlparser.ColName: |
|
|
columns = append(columns, col.Name.String()) |
|
|
columns = append(columns, col.Name.String()) |
|
|
|
|
|
case *sqlparser.FuncExpr: |
|
|
|
|
|
// Handle aggregation functions
|
|
|
|
|
|
aggSpec, err := e.parseAggregationFunction(col, expr) |
|
|
|
|
|
if err != nil { |
|
|
|
|
|
return &QueryResult{Error: err}, err |
|
|
|
|
|
} |
|
|
|
|
|
aggregations = append(aggregations, *aggSpec) |
|
|
|
|
|
hasAggregations = true |
|
|
default: |
|
|
default: |
|
|
err := fmt.Errorf("unsupported SELECT expression: %T", col) |
|
|
err := fmt.Errorf("unsupported SELECT expression: %T", col) |
|
|
return &QueryResult{Error: err}, err |
|
|
return &QueryResult{Error: err}, err |
|
@ -190,6 +200,11 @@ func (e *SQLEngine) executeSelectStatement(ctx context.Context, stmt *sqlparser. |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
// If we have aggregations, use aggregation query path
|
|
|
|
|
|
if hasAggregations { |
|
|
|
|
|
return e.executeAggregationQuery(ctx, hybridScanner, aggregations, stmt) |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
// Parse WHERE clause for predicate pushdown
|
|
|
// Parse WHERE clause for predicate pushdown
|
|
|
var predicate func(*schema_pb.RecordValue) bool |
|
|
var predicate func(*schema_pb.RecordValue) bool |
|
|
if stmt.Where != nil { |
|
|
if stmt.Where != nil { |
|
@ -988,6 +1003,338 @@ func (e *SQLEngine) dropTable(ctx context.Context, stmt *sqlparser.DDL) (*QueryR |
|
|
return result, nil |
|
|
return result, nil |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
// AggregationSpec defines an aggregation function to be computed
|
|
|
|
|
|
type AggregationSpec struct { |
|
|
|
|
|
Function string // COUNT, SUM, AVG, MIN, MAX
|
|
|
|
|
|
Column string // Column name, or "*" for COUNT(*)
|
|
|
|
|
|
Alias string // Optional alias for the result column
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
// AggregationResult holds the computed result of an aggregation
|
|
|
|
|
|
type AggregationResult struct { |
|
|
|
|
|
Count int64 |
|
|
|
|
|
Sum float64 |
|
|
|
|
|
Min interface{} |
|
|
|
|
|
Max interface{} |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
// parseAggregationFunction parses an aggregation function expression
|
|
|
|
|
|
func (e *SQLEngine) parseAggregationFunction(funcExpr *sqlparser.FuncExpr, aliasExpr *sqlparser.AliasedExpr) (*AggregationSpec, error) { |
|
|
|
|
|
funcName := strings.ToUpper(funcExpr.Name.String()) |
|
|
|
|
|
|
|
|
|
|
|
// Get alias name if specified
|
|
|
|
|
|
alias := funcName // Default alias is the function name
|
|
|
|
|
|
if !aliasExpr.As.IsEmpty() { |
|
|
|
|
|
alias = aliasExpr.As.String() |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
spec := &AggregationSpec{ |
|
|
|
|
|
Function: funcName, |
|
|
|
|
|
Alias: alias, |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
// Parse function arguments
|
|
|
|
|
|
switch funcName { |
|
|
|
|
|
case "COUNT": |
|
|
|
|
|
if len(funcExpr.Exprs) != 1 { |
|
|
|
|
|
return nil, fmt.Errorf("COUNT function expects exactly 1 argument") |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
switch arg := funcExpr.Exprs[0].(type) { |
|
|
|
|
|
case *sqlparser.StarExpr: |
|
|
|
|
|
spec.Column = "*" |
|
|
|
|
|
case *sqlparser.AliasedExpr: |
|
|
|
|
|
if colName, ok := arg.Expr.(*sqlparser.ColName); ok { |
|
|
|
|
|
spec.Column = colName.Name.String() |
|
|
|
|
|
} else { |
|
|
|
|
|
return nil, fmt.Errorf("COUNT argument must be a column name or *") |
|
|
|
|
|
} |
|
|
|
|
|
default: |
|
|
|
|
|
return nil, fmt.Errorf("unsupported COUNT argument: %T", arg) |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
case "SUM", "AVG", "MIN", "MAX": |
|
|
|
|
|
if len(funcExpr.Exprs) != 1 { |
|
|
|
|
|
return nil, fmt.Errorf("%s function expects exactly 1 argument", funcName) |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
switch arg := funcExpr.Exprs[0].(type) { |
|
|
|
|
|
case *sqlparser.AliasedExpr: |
|
|
|
|
|
if colName, ok := arg.Expr.(*sqlparser.ColName); ok { |
|
|
|
|
|
spec.Column = colName.Name.String() |
|
|
|
|
|
} else { |
|
|
|
|
|
return nil, fmt.Errorf("%s argument must be a column name", funcName) |
|
|
|
|
|
} |
|
|
|
|
|
default: |
|
|
|
|
|
return nil, fmt.Errorf("unsupported %s argument: %T", funcName, arg) |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
default: |
|
|
|
|
|
return nil, fmt.Errorf("unsupported aggregation function: %s", funcName) |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
return spec, nil |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
// executeAggregationQuery handles SELECT queries with aggregation functions
|
|
|
|
|
|
func (e *SQLEngine) executeAggregationQuery(ctx context.Context, hybridScanner *HybridMessageScanner, aggregations []AggregationSpec, stmt *sqlparser.Select) (*QueryResult, error) { |
|
|
|
|
|
// Parse WHERE clause for filtering
|
|
|
|
|
|
var predicate func(*schema_pb.RecordValue) bool |
|
|
|
|
|
var err error |
|
|
|
|
|
if stmt.Where != nil { |
|
|
|
|
|
predicate, err = e.buildPredicate(stmt.Where.Expr) |
|
|
|
|
|
if err != nil { |
|
|
|
|
|
return &QueryResult{Error: err}, err |
|
|
|
|
|
} |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
// Extract time filters for optimization
|
|
|
|
|
|
startTimeNs, stopTimeNs := int64(0), int64(0) |
|
|
|
|
|
if stmt.Where != nil { |
|
|
|
|
|
startTimeNs, stopTimeNs = e.extractTimeFilters(stmt.Where.Expr) |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
// Build scan options for full table scan (aggregations need all data)
|
|
|
|
|
|
hybridScanOptions := HybridScanOptions{ |
|
|
|
|
|
StartTimeNs: startTimeNs, |
|
|
|
|
|
StopTimeNs: stopTimeNs, |
|
|
|
|
|
Limit: 0, // No limit for aggregations - need all data
|
|
|
|
|
|
Predicate: predicate, |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
// Execute the hybrid scan to get all matching records
|
|
|
|
|
|
results, err := hybridScanner.Scan(ctx, hybridScanOptions) |
|
|
|
|
|
if err != nil { |
|
|
|
|
|
return &QueryResult{Error: err}, err |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
// Compute aggregations
|
|
|
|
|
|
aggResults := e.computeAggregations(results, aggregations) |
|
|
|
|
|
|
|
|
|
|
|
// Build result set
|
|
|
|
|
|
columns := make([]string, len(aggregations)) |
|
|
|
|
|
row := make([]sqltypes.Value, len(aggregations)) |
|
|
|
|
|
|
|
|
|
|
|
for i, spec := range aggregations { |
|
|
|
|
|
columns[i] = spec.Alias |
|
|
|
|
|
row[i] = e.formatAggregationResult(spec, aggResults[i]) |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
return &QueryResult{ |
|
|
|
|
|
Columns: columns, |
|
|
|
|
|
Rows: [][]sqltypes.Value{row}, |
|
|
|
|
|
}, nil |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
// computeAggregations computes aggregation functions over the scan results
|
|
|
|
|
|
func (e *SQLEngine) computeAggregations(results []HybridScanResult, aggregations []AggregationSpec) []AggregationResult { |
|
|
|
|
|
aggResults := make([]AggregationResult, len(aggregations)) |
|
|
|
|
|
|
|
|
|
|
|
for i, spec := range aggregations { |
|
|
|
|
|
switch spec.Function { |
|
|
|
|
|
case "COUNT": |
|
|
|
|
|
if spec.Column == "*" { |
|
|
|
|
|
// COUNT(*) counts all rows
|
|
|
|
|
|
aggResults[i].Count = int64(len(results)) |
|
|
|
|
|
} else { |
|
|
|
|
|
// COUNT(column) counts non-null values
|
|
|
|
|
|
count := int64(0) |
|
|
|
|
|
for _, result := range results { |
|
|
|
|
|
if value, exists := result.Values[spec.Column]; exists && value != nil { |
|
|
|
|
|
if !e.isNullValue(value) { |
|
|
|
|
|
count++ |
|
|
|
|
|
} |
|
|
|
|
|
} |
|
|
|
|
|
} |
|
|
|
|
|
aggResults[i].Count = count |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
case "SUM": |
|
|
|
|
|
sum := float64(0) |
|
|
|
|
|
for _, result := range results { |
|
|
|
|
|
if value, exists := result.Values[spec.Column]; exists && value != nil { |
|
|
|
|
|
if numValue := e.convertToNumber(value); numValue != nil { |
|
|
|
|
|
sum += *numValue |
|
|
|
|
|
} |
|
|
|
|
|
} |
|
|
|
|
|
} |
|
|
|
|
|
aggResults[i].Sum = sum |
|
|
|
|
|
|
|
|
|
|
|
case "AVG": |
|
|
|
|
|
sum := float64(0) |
|
|
|
|
|
count := int64(0) |
|
|
|
|
|
for _, result := range results { |
|
|
|
|
|
if value, exists := result.Values[spec.Column]; exists && value != nil { |
|
|
|
|
|
if numValue := e.convertToNumber(value); numValue != nil { |
|
|
|
|
|
sum += *numValue |
|
|
|
|
|
count++ |
|
|
|
|
|
} |
|
|
|
|
|
} |
|
|
|
|
|
} |
|
|
|
|
|
if count > 0 { |
|
|
|
|
|
aggResults[i].Sum = sum / float64(count) // Store average in Sum field
|
|
|
|
|
|
aggResults[i].Count = count |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
case "MIN": |
|
|
|
|
|
var min interface{} |
|
|
|
|
|
for _, result := range results { |
|
|
|
|
|
if value, exists := result.Values[spec.Column]; exists && value != nil { |
|
|
|
|
|
if min == nil || e.compareValues(value, min) < 0 { |
|
|
|
|
|
min = e.extractRawValue(value) |
|
|
|
|
|
} |
|
|
|
|
|
} |
|
|
|
|
|
} |
|
|
|
|
|
aggResults[i].Min = min |
|
|
|
|
|
|
|
|
|
|
|
case "MAX": |
|
|
|
|
|
var max interface{} |
|
|
|
|
|
for _, result := range results { |
|
|
|
|
|
if value, exists := result.Values[spec.Column]; exists && value != nil { |
|
|
|
|
|
if max == nil || e.compareValues(value, max) > 0 { |
|
|
|
|
|
max = e.extractRawValue(value) |
|
|
|
|
|
} |
|
|
|
|
|
} |
|
|
|
|
|
} |
|
|
|
|
|
aggResults[i].Max = max |
|
|
|
|
|
} |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
return aggResults |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
// Helper functions for aggregation processing
|
|
|
|
|
|
|
|
|
|
|
|
func (e *SQLEngine) isNullValue(value *schema_pb.Value) bool { |
|
|
|
|
|
return value == nil || value.Kind == nil |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
func (e *SQLEngine) convertToNumber(value *schema_pb.Value) *float64 { |
|
|
|
|
|
switch v := value.Kind.(type) { |
|
|
|
|
|
case *schema_pb.Value_Int32Value: |
|
|
|
|
|
result := float64(v.Int32Value) |
|
|
|
|
|
return &result |
|
|
|
|
|
case *schema_pb.Value_Int64Value: |
|
|
|
|
|
result := float64(v.Int64Value) |
|
|
|
|
|
return &result |
|
|
|
|
|
case *schema_pb.Value_FloatValue: |
|
|
|
|
|
result := float64(v.FloatValue) |
|
|
|
|
|
return &result |
|
|
|
|
|
case *schema_pb.Value_DoubleValue: |
|
|
|
|
|
return &v.DoubleValue |
|
|
|
|
|
} |
|
|
|
|
|
return nil |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
func (e *SQLEngine) extractRawValue(value *schema_pb.Value) interface{} { |
|
|
|
|
|
switch v := value.Kind.(type) { |
|
|
|
|
|
case *schema_pb.Value_Int32Value: |
|
|
|
|
|
return v.Int32Value |
|
|
|
|
|
case *schema_pb.Value_Int64Value: |
|
|
|
|
|
return v.Int64Value |
|
|
|
|
|
case *schema_pb.Value_FloatValue: |
|
|
|
|
|
return v.FloatValue |
|
|
|
|
|
case *schema_pb.Value_DoubleValue: |
|
|
|
|
|
return v.DoubleValue |
|
|
|
|
|
case *schema_pb.Value_StringValue: |
|
|
|
|
|
return v.StringValue |
|
|
|
|
|
case *schema_pb.Value_BoolValue: |
|
|
|
|
|
return v.BoolValue |
|
|
|
|
|
} |
|
|
|
|
|
return nil |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
func (e *SQLEngine) compareValues(value1 *schema_pb.Value, value2 interface{}) int { |
|
|
|
|
|
raw1 := e.extractRawValue(value1) |
|
|
|
|
|
if raw1 == nil { |
|
|
|
|
|
return -1 |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
// Simple comparison - in a full implementation this would handle type coercion
|
|
|
|
|
|
switch v1 := raw1.(type) { |
|
|
|
|
|
case int32: |
|
|
|
|
|
if v2, ok := value2.(int32); ok { |
|
|
|
|
|
if v1 < v2 { |
|
|
|
|
|
return -1 |
|
|
|
|
|
} else if v1 > v2 { |
|
|
|
|
|
return 1 |
|
|
|
|
|
} |
|
|
|
|
|
return 0 |
|
|
|
|
|
} |
|
|
|
|
|
case int64: |
|
|
|
|
|
if v2, ok := value2.(int64); ok { |
|
|
|
|
|
if v1 < v2 { |
|
|
|
|
|
return -1 |
|
|
|
|
|
} else if v1 > v2 { |
|
|
|
|
|
return 1 |
|
|
|
|
|
} |
|
|
|
|
|
return 0 |
|
|
|
|
|
} |
|
|
|
|
|
case float64: |
|
|
|
|
|
if v2, ok := value2.(float64); ok { |
|
|
|
|
|
if v1 < v2 { |
|
|
|
|
|
return -1 |
|
|
|
|
|
} else if v1 > v2 { |
|
|
|
|
|
return 1 |
|
|
|
|
|
} |
|
|
|
|
|
return 0 |
|
|
|
|
|
} |
|
|
|
|
|
case string: |
|
|
|
|
|
if v2, ok := value2.(string); ok { |
|
|
|
|
|
if v1 < v2 { |
|
|
|
|
|
return -1 |
|
|
|
|
|
} else if v1 > v2 { |
|
|
|
|
|
return 1 |
|
|
|
|
|
} |
|
|
|
|
|
return 0 |
|
|
|
|
|
} |
|
|
|
|
|
} |
|
|
|
|
|
return 0 |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
func (e *SQLEngine) formatAggregationResult(spec AggregationSpec, result AggregationResult) sqltypes.Value { |
|
|
|
|
|
switch spec.Function { |
|
|
|
|
|
case "COUNT": |
|
|
|
|
|
return sqltypes.NewInt64(result.Count) |
|
|
|
|
|
case "SUM": |
|
|
|
|
|
return sqltypes.NewFloat64(result.Sum) |
|
|
|
|
|
case "AVG": |
|
|
|
|
|
return sqltypes.NewFloat64(result.Sum) // Sum contains the average for AVG
|
|
|
|
|
|
case "MIN": |
|
|
|
|
|
if result.Min != nil { |
|
|
|
|
|
return e.convertRawValueToSQL(result.Min) |
|
|
|
|
|
} |
|
|
|
|
|
return sqltypes.NULL |
|
|
|
|
|
case "MAX": |
|
|
|
|
|
if result.Max != nil { |
|
|
|
|
|
return e.convertRawValueToSQL(result.Max) |
|
|
|
|
|
} |
|
|
|
|
|
return sqltypes.NULL |
|
|
|
|
|
} |
|
|
|
|
|
return sqltypes.NULL |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
func (e *SQLEngine) convertRawValueToSQL(value interface{}) sqltypes.Value { |
|
|
|
|
|
switch v := value.(type) { |
|
|
|
|
|
case int32: |
|
|
|
|
|
return sqltypes.NewInt32(v) |
|
|
|
|
|
case int64: |
|
|
|
|
|
return sqltypes.NewInt64(v) |
|
|
|
|
|
case float32: |
|
|
|
|
|
return sqltypes.NewFloat32(v) |
|
|
|
|
|
case float64: |
|
|
|
|
|
return sqltypes.NewFloat64(v) |
|
|
|
|
|
case string: |
|
|
|
|
|
return sqltypes.NewVarChar(v) |
|
|
|
|
|
case bool: |
|
|
|
|
|
if v { |
|
|
|
|
|
return sqltypes.NewVarChar("1") |
|
|
|
|
|
} |
|
|
|
|
|
return sqltypes.NewVarChar("0") |
|
|
|
|
|
} |
|
|
|
|
|
return sqltypes.NULL |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
// discoverAndRegisterTopic attempts to discover an existing topic and register it in the SQL catalog
|
|
|
// discoverAndRegisterTopic attempts to discover an existing topic and register it in the SQL catalog
|
|
|
func (e *SQLEngine) discoverAndRegisterTopic(ctx context.Context, database, tableName string) error { |
|
|
func (e *SQLEngine) discoverAndRegisterTopic(ctx context.Context, database, tableName string) error { |
|
|
// First, check if topic exists by trying to get its schema from the broker/filer
|
|
|
// First, check if topic exists by trying to get its schema from the broker/filer
|
|
|