diff --git a/weed/mq/schema/to_parquet_visitor.go b/weed/mq/schema/to_parquet_visitor.go
new file mode 100644
index 000000000..6c73563cd
--- /dev/null
+++ b/weed/mq/schema/to_parquet_visitor.go
@@ -0,0 +1,87 @@
+package schema
+
+import (
+	"fmt"
+
+	"github.com/seaweedfs/seaweedfs/weed/pb/schema_pb"
+)
+
+// ParquetLevels describes where a schema node sits in the flattened Parquet
+// column layout. Record nodes span the columns of all their descendants;
+// scalar nodes span exactly one column.
+type ParquetLevels struct {
+	// startColumnIndex is the index of the first leaf column covered by this node.
+	startColumnIndex int
+	// endColumnIndex is one past the last leaf column covered by this node.
+	endColumnIndex int
+	// definitionDepth is the number of enclosing records above this node.
+	definitionDepth int
+	// levels maps field names to child levels; it is nil for scalar nodes.
+	levels map[string]*ParquetLevels
+}
+
+// ToParquetLevels computes the column ranges and definition depths for every
+// field of recordType, with the root record at column 0 and depth 0.
+//
+// A nil recordType is treated like an empty record. An error is returned when
+// a field has a missing or unsupported type.
+func ToParquetLevels(recordType *schema_pb.RecordType) (*ParquetLevels, error) {
+	return toRecordTypeLevels(recordType, 0, 0)
+}
+
+// toFieldTypeLevels dispatches on the kind of fieldType and returns the levels
+// for the columns it occupies starting at startColumnIndex.
+func toFieldTypeLevels(fieldType *schema_pb.Type, startColumnIndex, definitionDepth int) (*ParquetLevels, error) {
+	// Use the generated getters throughout: they are nil-safe, so a field with
+	// no type reaches the error below instead of panicking.
+	switch kind := fieldType.GetKind().(type) {
+	case *schema_pb.Type_ScalarType:
+		return toFieldTypeScalarLevels(fieldType.GetScalarType(), startColumnIndex, definitionDepth)
+	case *schema_pb.Type_RecordType:
+		return toRecordTypeLevels(fieldType.GetRecordType(), startColumnIndex, definitionDepth)
+	case *schema_pb.Type_ListType:
+		return toFieldTypeListLevels(fieldType.GetListType(), startColumnIndex, definitionDepth)
+	default:
+		return nil, fmt.Errorf("unknown field type: %T", kind)
+	}
+}
+
+// toFieldTypeListLevels returns the levels of the list's element type. A list
+// occupies the same columns, at the same definition depth, as its element.
+func toFieldTypeListLevels(listType *schema_pb.ListType, startColumnIndex, definitionDepth int) (*ParquetLevels, error) {
+	return toFieldTypeLevels(listType.GetElementType(), startColumnIndex, definitionDepth)
+}
+
+// toFieldTypeScalarLevels returns the levels for a scalar, which always
+// occupies exactly one column. The scalar's concrete type does not affect
+// the layout.
+func toFieldTypeScalarLevels(scalarType schema_pb.ScalarType, startColumnIndex, definitionDepth int) (*ParquetLevels, error) {
+	return &ParquetLevels{
+		startColumnIndex: startColumnIndex,
+		endColumnIndex:   startColumnIndex + 1,
+		definitionDepth:  definitionDepth,
+	}, nil
+}
+
+// toRecordTypeLevels lays out the fields of recordType in the order they
+// appear in its field list, assigning each one the columns that follow its
+// predecessor. Fields sit one definition level deeper than the record itself.
+func toRecordTypeLevels(recordType *schema_pb.RecordType, startColumnIndex, definitionDepth int) (*ParquetLevels, error) {
+	fields := recordType.GetFields()
+	recordTypeLevels := &ParquetLevels{
+		startColumnIndex: startColumnIndex,
+		definitionDepth:  definitionDepth,
+		levels:           make(map[string]*ParquetLevels, len(fields)),
+	}
+	nextColumnIndex := startColumnIndex
+	for _, field := range fields {
+		fieldTypeLevels, err := toFieldTypeLevels(field.GetType(), nextColumnIndex, definitionDepth+1)
+		if err != nil {
+			return nil, fmt.Errorf("field %q: %w", field.GetName(), err)
+		}
+		recordTypeLevels.levels[field.GetName()] = fieldTypeLevels
+		nextColumnIndex = fieldTypeLevels.endColumnIndex
+	}
+	recordTypeLevels.endColumnIndex = nextColumnIndex
+	return recordTypeLevels, nil
+}
diff --git a/weed/mq/schema/to_parquet_visitor_test.go b/weed/mq/schema/to_parquet_visitor_test.go
new file mode 100644
index 000000000..8e9379e55
--- /dev/null
+++ b/weed/mq/schema/to_parquet_visitor_test.go
@@ -0,0 +1,112 @@
+package schema
+
+import (
+	"testing"
+
+	"github.com/seaweedfs/seaweedfs/weed/pb/schema_pb"
+	"github.com/stretchr/testify/assert"
+)
+
+// TestToParquetLevels checks the column ranges and definition depths computed
+// for a schema mixing scalars, a list and nested records.
+func TestToParquetLevels(t *testing.T) {
+	type args struct {
+		recordType *schema_pb.RecordType
+	}
+	tests := []struct {
+		name string
+		args args
+		want *ParquetLevels
+	}{
+		{
+			// NOTE(review): the expected column indexes below follow field
+			// names in alphabetical order, not the order of the Add*Field
+			// calls, so this presumably relies on Build sorting fields by
+			// name; confirm against the builder.
+			name: "nested type",
+			args: args{
+				recordType: NewRecordTypeBuilder().
+					AddLongField("ID").
+					AddLongField("CreatedAt").
+					AddRecordField("Person", NewRecordTypeBuilder().
+						AddStringField("zName").
+						AddListField("emails", TypeString)).
+					AddStringField("Company").
+					AddRecordField("Address", NewRecordTypeBuilder().
+						AddStringField("Street").
+						AddStringField("City")).Build(),
+			},
+			want: &ParquetLevels{
+				startColumnIndex: 0,
+				endColumnIndex:   7,
+				definitionDepth:  0,
+				levels: map[string]*ParquetLevels{
+					"Address": {
+						startColumnIndex: 0,
+						endColumnIndex:   2,
+						definitionDepth:  1,
+						levels: map[string]*ParquetLevels{
+							"City": {
+								startColumnIndex: 0,
+								endColumnIndex:   1,
+								definitionDepth:  2,
+							},
+							"Street": {
+								startColumnIndex: 1,
+								endColumnIndex:   2,
+								definitionDepth:  2,
+							},
+						},
+					},
+					"Company": {
+						startColumnIndex: 2,
+						endColumnIndex:   3,
+						definitionDepth:  1,
+					},
+					"CreatedAt": {
+						startColumnIndex: 3,
+						endColumnIndex:   4,
+						definitionDepth:  1,
+					},
+					"ID": {
+						startColumnIndex: 4,
+						endColumnIndex:   5,
+						definitionDepth:  1,
+					},
+					"Person": {
+						startColumnIndex: 5,
+						endColumnIndex:   7,
+						definitionDepth:  1,
+						levels: map[string]*ParquetLevels{
+							"emails": {
+								startColumnIndex: 5,
+								endColumnIndex:   6,
+								definitionDepth:  2,
+							},
+							"zName": {
+								startColumnIndex: 6,
+								endColumnIndex:   7,
+								definitionDepth:  2,
+							},
+						},
+					},
+				},
+			},
+		},
+		{
+			// A nil record type must behave like an empty record, not panic.
+			name: "nil record type",
+			args: args{recordType: nil},
+			want: &ParquetLevels{
+				levels: map[string]*ParquetLevels{},
+			},
+		},
+	}
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			got, err := ToParquetLevels(tt.args.recordType)
+			assert.NoError(t, err)
+			assert.Equalf(t, tt.want, got, "ToParquetLevels(%v)", tt.args.recordType)
+		})
+	}
+}