From 8645f3a26452cf0595c73e1068e776397dd66893 Mon Sep 17 00:00:00 2001 From: chrislu Date: Mon, 1 Sep 2025 11:25:04 -0700 Subject: [PATCH] column name case insensitive, better auto column names --- weed/query/engine/engine.go | 43 ++++++++++++++++++++++++++----------- 1 file changed, 31 insertions(+), 12 deletions(-) diff --git a/weed/query/engine/engine.go b/weed/query/engine/engine.go index 8f422c240..db93f6c0d 100644 --- a/weed/query/engine/engine.go +++ b/weed/query/engine/engine.go @@ -1022,15 +1022,8 @@ type AggregationResult struct { func (e *SQLEngine) parseAggregationFunction(funcExpr *sqlparser.FuncExpr, aliasExpr *sqlparser.AliasedExpr) (*AggregationSpec, error) { funcName := strings.ToUpper(funcExpr.Name.String()) - // Get alias name if specified - alias := funcName // Default alias is the function name - if !aliasExpr.As.IsEmpty() { - alias = aliasExpr.As.String() - } - spec := &AggregationSpec{ Function: funcName, - Alias: alias, } // Parse function arguments @@ -1043,9 +1036,11 @@ func (e *SQLEngine) parseAggregationFunction(funcExpr *sqlparser.FuncExpr, alias switch arg := funcExpr.Exprs[0].(type) { case *sqlparser.StarExpr: spec.Column = "*" + spec.Alias = "COUNT(*)" case *sqlparser.AliasedExpr: if colName, ok := arg.Expr.(*sqlparser.ColName); ok { spec.Column = colName.Name.String() + spec.Alias = fmt.Sprintf("COUNT(%s)", spec.Column) } else { return nil, fmt.Errorf("COUNT argument must be a column name or *") } @@ -1062,6 +1057,7 @@ func (e *SQLEngine) parseAggregationFunction(funcExpr *sqlparser.FuncExpr, alias case *sqlparser.AliasedExpr: if colName, ok := arg.Expr.(*sqlparser.ColName); ok { spec.Column = colName.Name.String() + spec.Alias = fmt.Sprintf("%s(%s)", funcName, spec.Column) } else { return nil, fmt.Errorf("%s argument must be a column name", funcName) } @@ -1073,6 +1069,11 @@ func (e *SQLEngine) parseAggregationFunction(funcExpr *sqlparser.FuncExpr, alias return nil, fmt.Errorf("unsupported aggregation function: %s", funcName) } + // Override with user-specified alias if provided + if !aliasExpr.As.IsEmpty() { + spec.Alias = aliasExpr.As.String() + } + return spec, nil } @@ -1140,7 +1141,7 @@ func (e *SQLEngine) computeAggregations(results []HybridScanResult, aggregations // COUNT(column) counts non-null values count := int64(0) for _, result := range results { - if value, exists := result.Values[spec.Column]; exists && value != nil { + if value := e.findColumnValue(result.Values, spec.Column); value != nil { if !e.isNullValue(value) { count++ } @@ -1152,7 +1153,7 @@ func (e *SQLEngine) computeAggregations(results []HybridScanResult, aggregations case "SUM": sum := float64(0) for _, result := range results { - if value, exists := result.Values[spec.Column]; exists && value != nil { + if value := e.findColumnValue(result.Values, spec.Column); value != nil { if numValue := e.convertToNumber(value); numValue != nil { sum += *numValue } @@ -1164,7 +1165,7 @@ func (e *SQLEngine) computeAggregations(results []HybridScanResult, aggregations sum := float64(0) count := int64(0) for _, result := range results { - if value, exists := result.Values[spec.Column]; exists && value != nil { + if value := e.findColumnValue(result.Values, spec.Column); value != nil { if numValue := e.convertToNumber(value); numValue != nil { sum += *numValue count++ @@ -1179,7 +1180,7 @@ func (e *SQLEngine) computeAggregations(results []HybridScanResult, aggregations case "MIN": var min interface{} for _, result := range results { - if value, exists := result.Values[spec.Column]; exists && value != nil { + if value := e.findColumnValue(result.Values, spec.Column); value != nil { if min == nil || e.compareValues(value, min) < 0 { min = e.extractRawValue(value) } @@ -1190,7 +1191,7 @@ func (e *SQLEngine) computeAggregations(results []HybridScanResult, aggregations case "MAX": var max interface{} for _, result := range results { - if value, exists := result.Values[spec.Column]; exists && value != nil { + if value := e.findColumnValue(result.Values, spec.Column); value != nil { if max == nil || e.compareValues(value, max) > 0 { max = e.extractRawValue(value) } @@ -1335,6 +1336,24 @@ func (e *SQLEngine) convertRawValueToSQL(value interface{}) sqltypes.Value { return sqltypes.NULL } +// findColumnValue performs case-insensitive lookup of column values +func (e *SQLEngine) findColumnValue(values map[string]*schema_pb.Value, columnName string) *schema_pb.Value { + // First try exact match + if value, exists := values[columnName]; exists { + return value + } + + // Then try case-insensitive match + lowerColumnName := strings.ToLower(columnName) + for key, value := range values { + if strings.ToLower(key) == lowerColumnName { + return value + } + } + + return nil +} + // discoverAndRegisterTopic attempts to discover an existing topic and register it in the SQL catalog func (e *SQLEngine) discoverAndRegisterTopic(ctx context.Context, database, tableName string) error { // First, check if topic exists by trying to get its schema from the broker/filer