diff --git a/weed/query/engine/engine.go b/weed/query/engine/engine.go index 9d776fc52..8f422c240 100644 --- a/weed/query/engine/engine.go +++ b/weed/query/engine/engine.go @@ -168,9 +168,11 @@ func (e *SQLEngine) executeSelectStatement(ctx context.Context, stmt *sqlparser. return e.executeSelectWithSampleData(ctx, stmt, database, tableName) } - // Parse SELECT columns + // Parse SELECT columns and detect aggregation functions var columns []string + var aggregations []AggregationSpec selectAll := false + hasAggregations := false for _, selectExpr := range stmt.SelectExprs { switch expr := selectExpr.(type) { @@ -180,6 +182,14 @@ func (e *SQLEngine) executeSelectStatement(ctx context.Context, stmt *sqlparser. switch col := expr.Expr.(type) { case *sqlparser.ColName: columns = append(columns, col.Name.String()) + case *sqlparser.FuncExpr: + // Handle aggregation functions + aggSpec, err := e.parseAggregationFunction(col, expr) + if err != nil { + return &QueryResult{Error: err}, err + } + aggregations = append(aggregations, *aggSpec) + hasAggregations = true default: err := fmt.Errorf("unsupported SELECT expression: %T", col) return &QueryResult{Error: err}, err @@ -190,6 +200,11 @@ func (e *SQLEngine) executeSelectStatement(ctx context.Context, stmt *sqlparser. } } + // If we have aggregations, use aggregation query path + if hasAggregations { + return e.executeAggregationQuery(ctx, hybridScanner, aggregations, stmt) + } + // Parse WHERE clause for predicate pushdown var predicate func(*schema_pb.RecordValue) bool if stmt.Where != nil { @@ -988,6 +1003,338 @@ func (e *SQLEngine) dropTable(ctx context.Context, stmt *sqlparser.DDL) (*QueryR return result, nil } +// AggregationSpec defines an aggregation function to be computed +type AggregationSpec struct { + Function string // COUNT, SUM, AVG, MIN, MAX + Column string // Column name, or "*" for COUNT(*) + Alias string // Optional alias for the result column +} + +// AggregationResult holds the computed result of an aggregation +type AggregationResult struct { + Count int64 + Sum float64 + Min interface{} + Max interface{} +} + +// parseAggregationFunction parses an aggregation function expression +func (e *SQLEngine) parseAggregationFunction(funcExpr *sqlparser.FuncExpr, aliasExpr *sqlparser.AliasedExpr) (*AggregationSpec, error) { + funcName := strings.ToUpper(funcExpr.Name.String()) + + // Get alias name if specified + alias := funcName // Default alias is the function name + if !aliasExpr.As.IsEmpty() { + alias = aliasExpr.As.String() + } + + spec := &AggregationSpec{ + Function: funcName, + Alias: alias, + } + + // Parse function arguments + switch funcName { + case "COUNT": + if len(funcExpr.Exprs) != 1 { + return nil, fmt.Errorf("COUNT function expects exactly 1 argument") + } + + switch arg := funcExpr.Exprs[0].(type) { + case *sqlparser.StarExpr: + spec.Column = "*" + case *sqlparser.AliasedExpr: + if colName, ok := arg.Expr.(*sqlparser.ColName); ok { + spec.Column = colName.Name.String() + } else { + return nil, fmt.Errorf("COUNT argument must be a column name or *") + } + default: + return nil, fmt.Errorf("unsupported COUNT argument: %T", arg) + } + + case "SUM", "AVG", "MIN", "MAX": + if len(funcExpr.Exprs) != 1 { + return nil, fmt.Errorf("%s function expects exactly 1 argument", funcName) + } + + switch arg := funcExpr.Exprs[0].(type) { + case *sqlparser.AliasedExpr: + if colName, ok := arg.Expr.(*sqlparser.ColName); ok { + spec.Column = colName.Name.String() + } else { + return nil, fmt.Errorf("%s argument must be a column name", funcName) + } + default: + return nil, fmt.Errorf("unsupported %s argument: %T", funcName, arg) + } + + default: + return nil, fmt.Errorf("unsupported aggregation function: %s", funcName) + } + + return spec, nil +} + +// executeAggregationQuery handles SELECT queries with aggregation functions +func (e *SQLEngine) executeAggregationQuery(ctx context.Context, hybridScanner *HybridMessageScanner, aggregations []AggregationSpec, stmt *sqlparser.Select) (*QueryResult, error) { + // Parse WHERE clause for filtering + var predicate func(*schema_pb.RecordValue) bool + var err error + if stmt.Where != nil { + predicate, err = e.buildPredicate(stmt.Where.Expr) + if err != nil { + return &QueryResult{Error: err}, err + } + } + + // Extract time filters for optimization + startTimeNs, stopTimeNs := int64(0), int64(0) + if stmt.Where != nil { + startTimeNs, stopTimeNs = e.extractTimeFilters(stmt.Where.Expr) + } + + // Build scan options for full table scan (aggregations need all data) + hybridScanOptions := HybridScanOptions{ + StartTimeNs: startTimeNs, + StopTimeNs: stopTimeNs, + Limit: 0, // No limit for aggregations - need all data + Predicate: predicate, + } + + // Execute the hybrid scan to get all matching records + results, err := hybridScanner.Scan(ctx, hybridScanOptions) + if err != nil { + return &QueryResult{Error: err}, err + } + + // Compute aggregations + aggResults := e.computeAggregations(results, aggregations) + + // Build result set + columns := make([]string, len(aggregations)) + row := make([]sqltypes.Value, len(aggregations)) + + for i, spec := range aggregations { + columns[i] = spec.Alias + row[i] = e.formatAggregationResult(spec, aggResults[i]) + } + + return &QueryResult{ + Columns: columns, + Rows: [][]sqltypes.Value{row}, + }, nil +} + +// computeAggregations computes aggregation functions over the scan results +func (e *SQLEngine) computeAggregations(results []HybridScanResult, aggregations []AggregationSpec) []AggregationResult { + aggResults := make([]AggregationResult, len(aggregations)) + + for i, spec := range aggregations { + switch spec.Function { + case "COUNT": + if spec.Column == "*" { + // COUNT(*) counts all rows + aggResults[i].Count = int64(len(results)) + } else { + // COUNT(column) counts non-null values + count := int64(0) + for _, result := range results { + if value, exists := result.Values[spec.Column]; exists && value != nil { + if !e.isNullValue(value) { + count++ + } + } + } + aggResults[i].Count = count + } + + case "SUM": + sum := float64(0) + for _, result := range results { + if value, exists := result.Values[spec.Column]; exists && value != nil { + if numValue := e.convertToNumber(value); numValue != nil { + sum += *numValue + } + } + } + aggResults[i].Sum = sum + + case "AVG": + sum := float64(0) + count := int64(0) + for _, result := range results { + if value, exists := result.Values[spec.Column]; exists && value != nil { + if numValue := e.convertToNumber(value); numValue != nil { + sum += *numValue + count++ + } + } + } + if count > 0 { + aggResults[i].Sum = sum / float64(count) // Store average in Sum field + aggResults[i].Count = count + } + + case "MIN": + var min interface{} + for _, result := range results { + if value, exists := result.Values[spec.Column]; exists && value != nil { + if min == nil || e.compareValues(value, min) < 0 { + min = e.extractRawValue(value) + } + } + } + aggResults[i].Min = min + + case "MAX": + var max interface{} + for _, result := range results { + if value, exists := result.Values[spec.Column]; exists && value != nil { + if max == nil || e.compareValues(value, max) > 0 { + max = e.extractRawValue(value) + } + } + } + aggResults[i].Max = max + } + } + + return aggResults +} + +// Helper functions for aggregation processing + +func (e *SQLEngine) isNullValue(value *schema_pb.Value) bool { + return value == nil || value.Kind == nil +} + +func (e *SQLEngine) convertToNumber(value *schema_pb.Value) *float64 { + switch v := value.Kind.(type) { + case *schema_pb.Value_Int32Value: + result := float64(v.Int32Value) + return &result + case *schema_pb.Value_Int64Value: + result := float64(v.Int64Value) + return &result + case *schema_pb.Value_FloatValue: + result := float64(v.FloatValue) + return &result + case *schema_pb.Value_DoubleValue: + return &v.DoubleValue + } + return nil +} + +func (e *SQLEngine) extractRawValue(value *schema_pb.Value) interface{} { + switch v := value.Kind.(type) { + case *schema_pb.Value_Int32Value: + return v.Int32Value + case *schema_pb.Value_Int64Value: + return v.Int64Value + case *schema_pb.Value_FloatValue: + return v.FloatValue + case *schema_pb.Value_DoubleValue: + return v.DoubleValue + case *schema_pb.Value_StringValue: + return v.StringValue + case *schema_pb.Value_BoolValue: + return v.BoolValue + } + return nil +} + +func (e *SQLEngine) compareValues(value1 *schema_pb.Value, value2 interface{}) int { + raw1 := e.extractRawValue(value1) + if raw1 == nil { + return -1 + } + + // Simple comparison - in a full implementation this would handle type coercion + switch v1 := raw1.(type) { + case int32: + if v2, ok := value2.(int32); ok { + if v1 < v2 { + return -1 + } else if v1 > v2 { + return 1 + } + return 0 + } + case int64: + if v2, ok := value2.(int64); ok { + if v1 < v2 { + return -1 + } else if v1 > v2 { + return 1 + } + return 0 + } + case float64: + if v2, ok := value2.(float64); ok { + if v1 < v2 { + return -1 + } else if v1 > v2 { + return 1 + } + return 0 + } + case string: + if v2, ok := value2.(string); ok { + if v1 < v2 { + return -1 + } else if v1 > v2 { + return 1 + } + return 0 + } + } + return 0 +} + +func (e *SQLEngine) formatAggregationResult(spec AggregationSpec, result AggregationResult) sqltypes.Value { + switch spec.Function { + case "COUNT": + return sqltypes.NewInt64(result.Count) + case "SUM": + return sqltypes.NewFloat64(result.Sum) + case "AVG": + return sqltypes.NewFloat64(result.Sum) // Sum contains the average for AVG + case "MIN": + if result.Min != nil { + return e.convertRawValueToSQL(result.Min) + } + return sqltypes.NULL + case "MAX": + if result.Max != nil { + return e.convertRawValueToSQL(result.Max) + } + return sqltypes.NULL + } + return sqltypes.NULL +} + +func (e *SQLEngine) convertRawValueToSQL(value interface{}) sqltypes.Value { + switch v := value.(type) { + case int32: + return sqltypes.NewInt32(v) + case int64: + return sqltypes.NewInt64(v) + case float32: + return sqltypes.NewFloat32(v) + case float64: + return sqltypes.NewFloat64(v) + case string: + return sqltypes.NewVarChar(v) + case bool: + if v { + return sqltypes.NewVarChar("1") + } + return sqltypes.NewVarChar("0") + } + return sqltypes.NULL +} + // discoverAndRegisterTopic attempts to discover an existing topic and register it in the SQL catalog func (e *SQLEngine) discoverAndRegisterTopic(ctx context.Context, database, tableName string) error { // First, check if topic exists by trying to get its schema from the broker/filer