diff --git a/weed/mq/schema/schema_builder.go b/weed/mq/schema/schema_builder.go index cabd0ee7b..313bbc5d3 100644 --- a/weed/mq/schema/schema_builder.go +++ b/weed/mq/schema/schema_builder.go @@ -2,6 +2,7 @@ package schema import ( "github.com/seaweedfs/seaweedfs/weed/pb/schema_pb" + "sort" ) var ( @@ -23,6 +24,10 @@ func NewRecordTypeBuilder() *RecordTypeBuilder { } func (rtb *RecordTypeBuilder) Build() *schema_pb.RecordType { + // be consistent with parquet.node.go `func (g Group) Fields() []Field` + sort.Slice(rtb.recordType.Fields, func(i, j int) bool { + return rtb.recordType.Fields[i].Name < rtb.recordType.Fields[j].Name + }) return rtb.recordType } diff --git a/weed/mq/schema/to_parquet_schema.go b/weed/mq/schema/to_parquet_schema.go index bae42e0a7..3019bc8a7 100644 --- a/weed/mq/schema/to_parquet_schema.go +++ b/weed/mq/schema/to_parquet_schema.go @@ -12,6 +12,8 @@ func ToParquetSchema(topicName string, recordType *schema_pb.RecordType) (*parqu return nil, fmt.Errorf("failed to convert record type to parquet schema: %v", err) } + // Fields are sorted by name, so the value should be sorted also + // the sorting is inside parquet.`func (g Group) Fields() []Field` return parquet.NewSchema(topicName, rootNode), nil } @@ -53,7 +55,7 @@ func toParquetFieldTypeScalar(scalarType schema_pb.ScalarType) (parquet.Node, er case schema_pb.ScalarType_BYTES: return parquet.Leaf(parquet.ByteArrayType), nil case schema_pb.ScalarType_STRING: - return parquet.String(), nil + return parquet.Leaf(parquet.ByteArrayType), nil default: return nil, fmt.Errorf("unknown scalar type: %v", scalarType) } diff --git a/weed/mq/schema/to_parquet_value.go b/weed/mq/schema/to_parquet_value.go index 8041da3ad..a5b981f4d 100644 --- a/weed/mq/schema/to_parquet_value.go +++ b/weed/mq/schema/to_parquet_value.go @@ -6,19 +6,32 @@ import ( "github.com/seaweedfs/seaweedfs/weed/pb/schema_pb" ) - -func AddRecordValue(rowBuilder *parquet.RowBuilder, fieldType *schema_pb.Type, fieldValue *schema_pb.Value) error { - visitor := func(fieldType *schema_pb.Type, fieldValue *schema_pb.Value, index int) error { - switch fieldType.Kind.(type) { - case *schema_pb.Type_ScalarType: - parquetValue, err := toParquetValue(fieldValue) - if err != nil { +func rowBuilderVisit(rowBuilder *parquet.RowBuilder, fieldType *schema_pb.Type, fieldValue *schema_pb.Value, columnIndex int) error { + switch fieldType.Kind.(type) { + case *schema_pb.Type_ScalarType: + parquetValue, err := toParquetValue(fieldValue) + if err != nil { + return err + } + rowBuilder.Add(columnIndex, parquetValue) + case *schema_pb.Type_ListType: + elementType := fieldType.GetListType().ElementType + for _, value := range fieldValue.GetListValue().Values { + if err := rowBuilderVisit(rowBuilder, elementType, value, columnIndex); err != nil { return err } - rowBuilder.Add(index, parquetValue) } - return nil + rowBuilder.Next(columnIndex) + } + return nil +} + +func AddRecordValue(rowBuilder *parquet.RowBuilder, recordType *schema_pb.RecordType, recordValue *schema_pb.RecordValue) error { + visitor := func(fieldType *schema_pb.Type, fieldValue *schema_pb.Value, index int) error { + return rowBuilderVisit(rowBuilder, fieldType, fieldValue, index) } + fieldType := &schema_pb.Type{Kind: &schema_pb.Type_RecordType{RecordType: recordType}} + fieldValue := &schema_pb.Value{Kind: &schema_pb.Value_RecordValue{RecordValue: recordValue}} return visitValue(fieldType, fieldValue, visitor) } @@ -39,13 +52,7 @@ func doVisitValue(fieldType *schema_pb.Type, fieldValue *schema_pb.Value, column case *schema_pb.Type_ScalarType: return columnIndex+1, visitor(fieldType, fieldValue, columnIndex) case *schema_pb.Type_ListType: - for _, value := range fieldValue.GetListValue().Values { - err = visitor(fieldType, value, columnIndex) - if err != nil { - return - } - } - return columnIndex+1, nil + return columnIndex+1, visitor(fieldType, fieldValue, columnIndex) case *schema_pb.Type_RecordType: for _, field := range fieldType.GetRecordType().Fields { fieldValue, found := fieldValue.GetRecordValue().Fields[field.Name] diff --git a/weed/mq/schema/to_schema_value.go b/weed/mq/schema/to_schema_value.go new file mode 100644 index 000000000..9f8cd5d91 --- /dev/null +++ b/weed/mq/schema/to_schema_value.go @@ -0,0 +1,85 @@ +package schema + +import ( + "fmt" + "github.com/parquet-go/parquet-go" + "github.com/seaweedfs/seaweedfs/weed/pb/schema_pb" +) + +func ToRecordValue(recordType *schema_pb.RecordType, row parquet.Row) (*schema_pb.RecordValue, error) { + values := []parquet.Value(row) + recordValue, _, err := toRecordValue(recordType, values, 0) + if err != nil { + return nil, err + } + return recordValue.GetRecordValue(), nil +} + +func ToValue(t *schema_pb.Type, values []parquet.Value, columnIndex int) (value *schema_pb.Value, endIndex int, err error) { + switch t.Kind.(type) { + case *schema_pb.Type_ScalarType: + return toScalarValue(t.GetScalarType(), values, columnIndex) + case *schema_pb.Type_ListType: + return toListValue(t.GetListType(), values, columnIndex) + case *schema_pb.Type_RecordType: + return toRecordValue(t.GetRecordType(), values, columnIndex) + } + return nil, 0, fmt.Errorf("unsupported type: %v", t) +} + +func toRecordValue(recordType *schema_pb.RecordType, values []parquet.Value, columnIndex int) (*schema_pb.Value, int, error) { + recordValue := schema_pb.RecordValue{Fields: make(map[string]*schema_pb.Value)} + for _, field := range recordType.Fields { + fieldValue, endIndex, err := ToValue(field.Type, values, columnIndex) + if err != nil { + return nil, 0, err + } + if endIndex == columnIndex { + continue + } + columnIndex = endIndex + recordValue.Fields[field.Name] = fieldValue + } + return &schema_pb.Value{Kind: &schema_pb.Value_RecordValue{RecordValue: &recordValue}}, columnIndex, nil +} + +func toListValue(listType *schema_pb.ListType, values []parquet.Value, index int) (listValue *schema_pb.Value, endIndex int, err error) { + listValues := make([]*schema_pb.Value, 0) + var value *schema_pb.Value + for i := index; i < len(values); { + value, endIndex, err = ToValue(listType.ElementType, values, i) + if err != nil { + return nil, 0, err + } + if endIndex == i { + break + } + listValues = append(listValues, value) + i = endIndex + } + return &schema_pb.Value{Kind: &schema_pb.Value_ListValue{ListValue: &schema_pb.ListValue{Values: listValues}}}, endIndex, nil +} + +func toScalarValue(scalarType schema_pb.ScalarType, values []parquet.Value, columnIndex int) (*schema_pb.Value, int, error) { + value := values[columnIndex] + if value.Column() != columnIndex { + return nil, columnIndex, nil + } + switch scalarType { + case schema_pb.ScalarType_BOOLEAN: + return &schema_pb.Value{Kind: &schema_pb.Value_BoolValue{BoolValue: value.Boolean()}}, columnIndex + 1, nil + case schema_pb.ScalarType_INTEGER: + return &schema_pb.Value{Kind: &schema_pb.Value_Int32Value{Int32Value: value.Int32()}}, columnIndex + 1, nil + case schema_pb.ScalarType_LONG: + return &schema_pb.Value{Kind: &schema_pb.Value_Int64Value{Int64Value: value.Int64()}}, columnIndex + 1, nil + case schema_pb.ScalarType_FLOAT: + return &schema_pb.Value{Kind: &schema_pb.Value_FloatValue{FloatValue: value.Float()}}, columnIndex + 1, nil + case schema_pb.ScalarType_DOUBLE: + return &schema_pb.Value{Kind: &schema_pb.Value_DoubleValue{DoubleValue: value.Double()}}, columnIndex + 1, nil + case schema_pb.ScalarType_BYTES: + return &schema_pb.Value{Kind: &schema_pb.Value_BytesValue{BytesValue: value.ByteArray()}}, columnIndex + 1, nil + case schema_pb.ScalarType_STRING: + return &schema_pb.Value{Kind: &schema_pb.Value_StringValue{StringValue: string(value.ByteArray())}}, columnIndex + 1, nil + } + return nil, columnIndex, fmt.Errorf("unsupported scalar type: %v", scalarType) +} diff --git a/weed/mq/schema/value_builder.go b/weed/mq/schema/value_builder.go new file mode 100644 index 000000000..600252833 --- /dev/null +++ b/weed/mq/schema/value_builder.go @@ -0,0 +1,113 @@ +package schema + +import "github.com/seaweedfs/seaweedfs/weed/pb/schema_pb" + +// RecordValueBuilder helps in constructing RecordValue protobuf messages +type RecordValueBuilder struct { + recordValue *schema_pb.RecordValue +} + +// NewRecordValueBuilder creates a new RecordValueBuilder instance +func NewRecordValueBuilder() *RecordValueBuilder { + return &RecordValueBuilder{recordValue: &schema_pb.RecordValue{Fields: make(map[string]*schema_pb.Value)}} +} + +// Build returns the constructed RecordValue message +func (rvb *RecordValueBuilder) Build() *schema_pb.RecordValue { + return rvb.recordValue +} + +func (rvb *RecordValueBuilder) AddBoolValue(key string, value bool) *RecordValueBuilder { + rvb.recordValue.Fields[key] = &schema_pb.Value{Kind: &schema_pb.Value_BoolValue{BoolValue: value}} + return rvb +} +func (rvb *RecordValueBuilder) AddIntValue(key string, value int32) *RecordValueBuilder { + rvb.recordValue.Fields[key] = &schema_pb.Value{Kind: &schema_pb.Value_Int32Value{Int32Value: value}} + return rvb +} +func (rvb *RecordValueBuilder) AddLongValue(key string, value int64) *RecordValueBuilder { + rvb.recordValue.Fields[key] = &schema_pb.Value{Kind: &schema_pb.Value_Int64Value{Int64Value: value}} + return rvb +} +func (rvb *RecordValueBuilder) AddFloatValue(key string, value float32) *RecordValueBuilder { + rvb.recordValue.Fields[key] = &schema_pb.Value{Kind: &schema_pb.Value_FloatValue{FloatValue: value}} + return rvb +} +func (rvb *RecordValueBuilder) AddDoubleValue(key string, value float64) *RecordValueBuilder { + rvb.recordValue.Fields[key] = &schema_pb.Value{Kind: &schema_pb.Value_DoubleValue{DoubleValue: value}} + return rvb +} +func (rvb *RecordValueBuilder) AddBytesValue(key string, value []byte) *RecordValueBuilder { + rvb.recordValue.Fields[key] = &schema_pb.Value{Kind: &schema_pb.Value_BytesValue{BytesValue: value}} + return rvb +} +func (rvb *RecordValueBuilder) AddStringValue(key string, value string) *RecordValueBuilder { + rvb.recordValue.Fields[key] = &schema_pb.Value{Kind: &schema_pb.Value_StringValue{StringValue: value}} + return rvb +} +func (rvb *RecordValueBuilder) AddRecordValue(key string, value *RecordValueBuilder) *RecordValueBuilder { + rvb.recordValue.Fields[key] = &schema_pb.Value{Kind: &schema_pb.Value_RecordValue{RecordValue: value.Build()}} + return rvb +} + +func (rvb *RecordValueBuilder) addListValue(key string, values []*schema_pb.Value) *RecordValueBuilder { + rvb.recordValue.Fields[key] = &schema_pb.Value{Kind: &schema_pb.Value_ListValue{ListValue: &schema_pb.ListValue{Values: values}}} + return rvb +} + +func (rvb *RecordValueBuilder) AddBoolListValue(key string, values ...bool) *RecordValueBuilder { + var listValues []*schema_pb.Value + for _, v := range values { + listValues = append(listValues, &schema_pb.Value{Kind: &schema_pb.Value_BoolValue{BoolValue: v}}) + } + return rvb.addListValue(key, listValues) +} +func (rvb *RecordValueBuilder) AddIntListValue(key string, values ...int32) *RecordValueBuilder { + var listValues []*schema_pb.Value + for _, v := range values { + listValues = append(listValues, &schema_pb.Value{Kind: &schema_pb.Value_Int32Value{Int32Value: v}}) + } + return rvb.addListValue(key, listValues) +} +func (rvb *RecordValueBuilder) AddLongListValue(key string, values ...int64) *RecordValueBuilder { + var listValues []*schema_pb.Value + for _, v := range values { + listValues = append(listValues, &schema_pb.Value{Kind: &schema_pb.Value_Int64Value{Int64Value: v}}) + } + return rvb.addListValue(key, listValues) +} +func (rvb *RecordValueBuilder) AddFloatListValue(key string, values ...float32) *RecordValueBuilder { + var listValues []*schema_pb.Value + for _, v := range values { + listValues = append(listValues, &schema_pb.Value{Kind: &schema_pb.Value_FloatValue{FloatValue: v}}) + } + return rvb.addListValue(key, listValues) +} +func (rvb *RecordValueBuilder) AddDoubleListValue(key string, values ...float64) *RecordValueBuilder { + var listValues []*schema_pb.Value + for _, v := range values { + listValues = append(listValues, &schema_pb.Value{Kind: &schema_pb.Value_DoubleValue{DoubleValue: v}}) + } + return rvb.addListValue(key, listValues) +} +func (rvb *RecordValueBuilder) AddBytesListValue(key string, values ...[]byte) *RecordValueBuilder { + var listValues []*schema_pb.Value + for _, v := range values { + listValues = append(listValues, &schema_pb.Value{Kind: &schema_pb.Value_BytesValue{BytesValue: v}}) + } + return rvb.addListValue(key, listValues) +} +func (rvb *RecordValueBuilder) AddStringListValue(key string, values ...string) *RecordValueBuilder { + var listValues []*schema_pb.Value + for _, v := range values { + listValues = append(listValues, &schema_pb.Value{Kind: &schema_pb.Value_StringValue{StringValue: v}}) + } + return rvb.addListValue(key, listValues) +} +func (rvb *RecordValueBuilder) AddRecordListValue(key string, values ...*RecordValueBuilder) *RecordValueBuilder { + var listValues []*schema_pb.Value + for _, v := range values { + listValues = append(listValues, &schema_pb.Value{Kind: &schema_pb.Value_RecordValue{RecordValue: v.Build()}}) + } + return rvb.addListValue(key, listValues) +} diff --git a/weed/mq/schema/write_parquet_test.go b/weed/mq/schema/write_parquet_test.go index 7920bf3a1..1b4ecdf59 100644 --- a/weed/mq/schema/write_parquet_test.go +++ b/weed/mq/schema/write_parquet_test.go @@ -2,6 +2,11 @@ package schema import ( "fmt" + "github.com/parquet-go/parquet-go" + "github.com/parquet-go/parquet-go/compress/zstd" + "github.com/seaweedfs/seaweedfs/weed/pb/schema_pb" + "io" + "os" "testing" ) @@ -11,8 +16,9 @@ func TestWriteParquet(t *testing.T) { AddLongField("ID"). AddLongField("CreatedAt"). AddRecordField("Person", NewRecordTypeBuilder(). - AddStringField("Name"). - AddListField("emails", TypeString)).Build() + AddStringField("zName"). + AddListField("emails", TypeString)). + AddStringField("Company").Build() fmt.Printf("RecordType: %v\n", recordType) // create a parquet schema @@ -21,5 +27,85 @@ func TestWriteParquet(t *testing.T) { t.Fatalf("ToParquetSchema failed: %v", err) } fmt.Printf("ParquetSchema: %v\n", parquetSchema) + parquet.PrintSchema(os.Stdout, "example", parquetSchema) + filename := "example.parquet" + + testWritingParquetFile(t, filename, parquetSchema, recordType) + + total := testReadingParquetFile(t, filename, parquetSchema, recordType) + + if total != 128*1024 { + t.Fatalf("total != 128*1024: %v", total) + } + +} + +func testWritingParquetFile(t *testing.T, filename string, parquetSchema *parquet.Schema, recordType *schema_pb.RecordType) { + // create a parquet file + file, err := os.OpenFile(filename, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0664) + if err != nil { + t.Fatalf("os.Open failed: %v", err) + } + defer file.Close() + writer := parquet.NewWriter(file, parquetSchema, parquet.Compression(&zstd.Codec{Level: zstd.SpeedDefault})) + rowBuilder := parquet.NewRowBuilder(parquetSchema) + for i := 0; i < 128; i++ { + rowBuilder.Reset() + // generate random data + AddRecordValue(rowBuilder, recordType, NewRecordValueBuilder(). + AddLongValue("ID", int64(1+i)). + AddLongValue("CreatedAt", 2*int64(i)). + AddRecordValue("Person", NewRecordValueBuilder(). + AddStringValue("zName", fmt.Sprintf("john_%d", i)). + AddStringListValue("emails", + fmt.Sprintf("john_%d@y.com", i), + fmt.Sprintf("john_%d@g.com", i), + fmt.Sprintf("john_%d@t.com", i))). + AddStringValue("Company", fmt.Sprintf("company_%d", i)).Build()) + + row := rowBuilder.Row() + if err != nil { + t.Fatalf("rowBuilder.Build failed: %v", err) + } + + if _, err = writer.WriteRows([]parquet.Row{row}); err != nil { + t.Fatalf("writer.Write failed: %v", err) + } + } + if err = writer.Close(); err != nil { + t.Fatalf("writer.WriteStop failed: %v", err) + } +} + +func testReadingParquetFile(t *testing.T, filename string, parquetSchema *parquet.Schema, recordType *schema_pb.RecordType) (total int) { + // read the parquet file + file, err := os.Open(filename) + if err != nil { + t.Fatalf("os.Open failed: %v", err) + } + defer file.Close() + reader := parquet.NewReader(file, parquetSchema) + rows := make([]parquet.Row, 128) + for { + rowCount, err := reader.ReadRows(rows) + if err != nil { + if err == io.EOF { + break + } + t.Fatalf("reader.Read failed: %v", err) + } + for i := 0; i < rowCount; i++ { + row := rows[i] + // convert parquet row to schema_pb.RecordValue + recordValue, err := ToRecordValue(recordType, row) + if err != nil { + t.Fatalf("ToRecordValue failed: %v", err) + } + fmt.Printf("RecordValue: %v\n", recordValue) + } + total += rowCount + } + fmt.Printf("total: %v\n", total) + return }