You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
 
 
 

485 lines
15 KiB

package schema
import (
"fmt"
"sync"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/reflect/protodesc"
"google.golang.org/protobuf/reflect/protoreflect"
"google.golang.org/protobuf/reflect/protoregistry"
"google.golang.org/protobuf/types/descriptorpb"
"google.golang.org/protobuf/types/dynamicpb"
)
// ProtobufSchema represents a parsed Protobuf schema with message type information.
// Values are produced by ProtobufDescriptorParser.ParseBinaryDescriptor.
type ProtobufSchema struct {
// FileDescriptorSet is the descriptor set decoded from the binary input.
FileDescriptorSet *descriptorpb.FileDescriptorSet
// MessageDescriptor is the resolved descriptor for MessageName. It is nil
// when the descriptor set parsed but the message could not be resolved.
MessageDescriptor protoreflect.MessageDescriptor
// MessageName is the requested message name, or the first message found in
// the set when the caller passed an empty name.
MessageName string
// PackageName is the package of the file in which the message was located;
// it is empty when the message was not found in any file.
PackageName string
// Dependencies lists the distinct import paths declared across all files
// in the descriptor set.
Dependencies []string
}
// ProtobufDescriptorParser handles parsing of Confluent Schema Registry Protobuf descriptors.
// It remembers previously parsed schemas in an in-memory cache.
type ProtobufDescriptorParser struct {
// mu guards descriptorCache: lookups and stats take the read lock, inserts
// and ClearCache take the write lock.
mu sync.RWMutex
// Cache for parsed descriptors to avoid re-parsing. Keys are strings
// derived from the descriptor bytes and the requested message name.
descriptorCache map[string]*ProtobufSchema
}
// NewProtobufDescriptorParser returns a parser whose descriptor cache is
// allocated and empty, so it can be used immediately.
func NewProtobufDescriptorParser() *ProtobufDescriptorParser {
	parser := new(ProtobufDescriptorParser)
	parser.descriptorCache = map[string]*ProtobufSchema{}
	return parser
}
// ParseBinaryDescriptor parses a Confluent Schema Registry Protobuf binary
// descriptor (a serialized FileDescriptorSet) and resolves messageName in it.
//
// Parameters:
//   - binaryData: the serialized FileDescriptorSet bytes.
//   - messageName: the message to resolve. When empty, the first message
//     declared in the descriptor set is used.
//
// Returns:
//   - the resolved schema and a nil error on success;
//   - a nil schema and an error when the bytes cannot be unmarshaled, the
//     descriptor set fails validation, or it declares no messages;
//   - a schema whose MessageDescriptor is nil together with a non-nil error
//     when the set parses but the message descriptor cannot be resolved.
//
// Both resolved and unresolved schemas are cached, keyed by the complete
// descriptor bytes and the messageName argument as passed by the caller.
func (p *ProtobufDescriptorParser) ParseBinaryDescriptor(binaryData []byte, messageName string) (*ProtobufSchema, error) {
	// Key on the whole payload. Keying on a prefix would let two different
	// descriptor sets that merely share their leading bytes be served each
	// other's schema. Hex output contains no ':' so the separator is
	// unambiguous. If key size ever matters, a digest of binaryData can be
	// substituted here.
	cacheKey := fmt.Sprintf("%x:%s", binaryData, messageName)

	p.mu.RLock()
	cached, exists := p.descriptorCache[cacheKey]
	p.mu.RUnlock()
	if exists {
		// An entry without a message descriptor records an earlier failed
		// resolution; keep reporting a failure for it.
		if cached.MessageDescriptor == nil {
			return cached, fmt.Errorf("failed to find message descriptor for %s: message descriptor resolution not fully implemented in Phase E1 - found message %s in package %s", messageName, messageName, cached.PackageName)
		}
		return cached, nil
	}

	// Decode and sanity-check the descriptor set.
	var fileDescriptorSet descriptorpb.FileDescriptorSet
	if err := proto.Unmarshal(binaryData, &fileDescriptorSet); err != nil {
		return nil, fmt.Errorf("failed to unmarshal FileDescriptorSet: %w", err)
	}
	if err := p.validateDescriptorSet(&fileDescriptorSet); err != nil {
		return nil, fmt.Errorf("invalid descriptor set: %w", err)
	}

	// Default to the first declared message when the caller named none.
	if messageName == "" {
		messageName = p.findFirstMessageName(&fileDescriptorSet)
		if messageName == "" {
			return nil, fmt.Errorf("no messages found in FileDescriptorSet")
		}
	}

	messageDesc, packageName, err := p.findMessageDescriptor(&fileDescriptorSet, messageName)

	schema := &ProtobufSchema{
		FileDescriptorSet: &fileDescriptorSet,
		MessageName:       messageName,
		PackageName:       packageName,
		Dependencies:      p.extractDependencies(&fileDescriptorSet),
	}
	if err == nil {
		schema.MessageDescriptor = messageDesc
	}

	// The schema is cached even when resolution failed, so the same binary
	// data is not unmarshaled again on the next call.
	p.mu.Lock()
	p.descriptorCache[cacheKey] = schema
	p.mu.Unlock()

	if err != nil {
		return schema, fmt.Errorf("failed to find message descriptor for %s: %w", messageName, err)
	}
	return schema, nil
}
// validateDescriptorSet rejects descriptor sets that are empty or that contain
// a file with no name or no package set. The first problem found is returned.
func (p *ProtobufDescriptorParser) validateDescriptorSet(fds *descriptorpb.FileDescriptorSet) error {
	files := fds.File
	if len(files) == 0 {
		return fmt.Errorf("FileDescriptorSet contains no files")
	}

	for idx := range files {
		fd := files[idx]
		switch {
		case fd.Name == nil:
			// Without a name the file can only be identified by position.
			return fmt.Errorf("file descriptor %d has no name", idx)
		case fd.Package == nil:
			return fmt.Errorf("file descriptor %s has no package", *fd.Name)
		}
	}
	return nil
}
// findFirstMessageName returns the name of the first top-level message of the
// first file that declares any message, or "" when no file declares one.
func (p *ProtobufDescriptorParser) findFirstMessageName(fds *descriptorpb.FileDescriptorSet) string {
	for _, fd := range fds.File {
		if topLevel := fd.MessageType; len(topLevel) != 0 {
			return topLevel[0].GetName()
		}
	}
	return ""
}
// findMessageDescriptor locates messageName within the FileDescriptorSet and
// returns its resolved descriptor together with the package of the file that
// declares it.
//
// Files and their top-level messages are visited in declaration order. A
// top-level message is selected when its own name equals messageName or when
// one of its nested messages (at any depth) does. The owning file is then
// built into a file descriptor and the message is looked up in it by name.
// Names are compared as given; fully qualified names are not interpreted.
//
// On failure the descriptor is nil. The package name is that of the file being
// processed, or "" when no file contains the message.
func (p *ProtobufDescriptorParser) findMessageDescriptor(fds *descriptorpb.FileDescriptorSet, messageName string) (protoreflect.MessageDescriptor, string, error) {
	for _, fileProto := range fds.File {
		pkg := fileProto.GetPackage()

		for _, top := range fileProto.MessageType {
			isTopLevel := top.Name != nil && *top.Name == messageName
			if !isTopLevel && p.searchNestedMessages(top, messageName) == nil {
				// Neither this message nor anything nested in it matches.
				continue
			}

			built, err := p.buildFileDescriptor(fileProto)
			if err != nil {
				if isTopLevel {
					return nil, pkg, fmt.Errorf("failed to build file descriptor: %w", err)
				}
				return nil, pkg, fmt.Errorf("failed to build file descriptor for nested message: %w", err)
			}

			if resolved := p.findMessageInFileDescriptor(built, messageName); resolved != nil {
				return resolved, pkg, nil
			}

			// The proto declared the message but the built descriptor does not expose it.
			if isTopLevel {
				return nil, pkg, fmt.Errorf("message descriptor built but not found: %s", messageName)
			}
			return nil, pkg, fmt.Errorf("nested message descriptor built but not found: %s", messageName)
		}
	}

	return nil, "", fmt.Errorf("message %s not found in descriptor set", messageName)
}
// buildFileDescriptor converts a FileDescriptorProto into a
// protoreflect.FileDescriptor using protodesc.NewFile. Any failure is wrapped
// and returned with a nil descriptor.
func (p *ProtobufDescriptorParser) buildFileDescriptor(fileProto *descriptorpb.FileDescriptorProto) (protoreflect.FileDescriptor, error) {
	// Resolve against a fresh, empty registry rather than any shared one to
	// avoid conflicts.
	// NOTE(review): with an empty resolver, a file that declares imports
	// presumably cannot be built here — confirm against protodesc.NewFile.
	var resolver protoregistry.Files

	built, err := protodesc.NewFile(fileProto, &resolver)
	if err == nil {
		return built, nil
	}
	return nil, fmt.Errorf("failed to create file descriptor: %w", err)
}
// findMessageInFileDescriptor returns the first message in fileDesc whose
// short name equals messageName. Each top-level message is checked first and
// then its nested messages, before moving on to the next top-level message.
// It returns nil when nothing matches.
func (p *ProtobufDescriptorParser) findMessageInFileDescriptor(fileDesc protoreflect.FileDescriptor, messageName string) protoreflect.MessageDescriptor {
	wanted := protoreflect.Name(messageName)
	topLevel := fileDesc.Messages()

	for idx, count := 0, topLevel.Len(); idx < count; idx++ {
		candidate := topLevel.Get(idx)
		if candidate.Name() == wanted {
			return candidate
		}
		if inner := p.findNestedMessageDescriptor(candidate, messageName); inner != nil {
			return inner
		}
	}
	return nil
}
// findNestedMessageDescriptor walks the messages nested inside msgDesc,
// depth-first in declaration order, and returns the first one whose short name
// equals messageName. msgDesc itself is not compared. It returns nil when
// nothing matches.
func (p *ProtobufDescriptorParser) findNestedMessageDescriptor(msgDesc protoreflect.MessageDescriptor, messageName string) protoreflect.MessageDescriptor {
	wanted := protoreflect.Name(messageName)
	children := msgDesc.Messages()

	for idx, count := 0, children.Len(); idx < count; idx++ {
		child := children.Get(idx)
		if child.Name() == wanted {
			return child
		}
		// No match at this level: descend before trying the next sibling.
		if deeper := p.findNestedMessageDescriptor(child, messageName); deeper != nil {
			return deeper
		}
	}
	return nil
}
// searchNestedMessages returns the first message nested inside messageType, at
// any depth, whose name equals targetName, or nil when there is none.
//
// It delegates to the package-level searchNestedMessages so the traversal is
// implemented in exactly one place; the method is kept for existing callers.
func (p *ProtobufDescriptorParser) searchNestedMessages(messageType *descriptorpb.DescriptorProto, targetName string) *descriptorpb.DescriptorProto {
	return searchNestedMessages(messageType, targetName)
}
// extractDependencies returns the distinct import paths declared by the files
// in the FileDescriptorSet.
//
// Entries appear in the order they are first encountered (file order, then
// declaration order within a file), so the result is stable across calls for
// the same input. The returned slice is never nil; it is empty when no file
// declares a dependency.
func (p *ProtobufDescriptorParser) extractDependencies(fds *descriptorpb.FileDescriptorSet) []string {
	seen := make(map[string]struct{})
	// Allocate explicitly so callers keep receiving an empty, non-nil slice.
	dependencies := make([]string, 0)

	for _, file := range fds.File {
		for _, dep := range file.Dependency {
			if _, dup := seen[dep]; dup {
				continue
			}
			seen[dep] = struct{}{}
			// Appending while scanning, rather than ranging over the set
			// afterwards, avoids Go's randomized map iteration order.
			dependencies = append(dependencies, dep)
		}
	}
	return dependencies
}
// GetMessageFields describes every field declared directly on the schema's
// message, in declaration order.
//
// It returns an error when the schema has no FileDescriptorSet or when
// MessageName cannot be found in it.
func (s *ProtobufSchema) GetMessageFields() ([]FieldInfo, error) {
	if s.FileDescriptorSet == nil {
		return nil, fmt.Errorf("no FileDescriptorSet available")
	}

	target := s.findMessageDescriptor(s.MessageName)
	if target == nil {
		return nil, fmt.Errorf("message %s not found in descriptor set", s.MessageName)
	}

	infos := make([]FieldInfo, len(target.Field))
	for i, fd := range target.Field {
		infos[i] = FieldInfo{
			Name:   fd.GetName(),
			Number: fd.GetNumber(),
			Type:   s.fieldTypeToString(fd.GetType()),
			Label:  s.fieldLabelToString(fd.GetLabel()),
			// GetTypeName yields "" when the descriptor has no type name.
			TypeName: fd.GetTypeName(),
		}
	}
	return infos, nil
}
// FieldInfo represents information about a Protobuf field, as reported by
// ProtobufSchema.GetMessageFields.
type FieldInfo struct {
// Name is the field name taken from the field descriptor.
Name string
// Number is the field number taken from the field descriptor.
Number int32
// Type is the lower-case name of the field's declared type, such as
// "int32", "string", "message" or "enum"; "unknown" for unrecognized types.
Type string
Label string // "optional", "required", "repeated", or "unknown"
TypeName string // the descriptor's type name (message/enum types); empty when not set
}
// GetFieldByName returns the first field of the schema's message whose name
// equals fieldName. Errors from GetMessageFields are returned unchanged; an
// error is also returned when no field has that name.
func (s *ProtobufSchema) GetFieldByName(fieldName string) (*FieldInfo, error) {
	all, err := s.GetMessageFields()
	if err != nil {
		return nil, err
	}

	for i := range all {
		if all[i].Name != fieldName {
			continue
		}
		return &all[i], nil
	}
	return nil, fmt.Errorf("field %s not found", fieldName)
}
// GetFieldByNumber returns the first field of the schema's message whose
// number equals fieldNumber. Errors from GetMessageFields are returned
// unchanged; an error is also returned when no field has that number.
func (s *ProtobufSchema) GetFieldByNumber(fieldNumber int32) (*FieldInfo, error) {
	all, err := s.GetMessageFields()
	if err != nil {
		return nil, err
	}

	for i := range all {
		if all[i].Number != fieldNumber {
			continue
		}
		return &all[i], nil
	}
	return nil, fmt.Errorf("field number %d not found", fieldNumber)
}
// findMessageDescriptor looks messageName up in the schema's
// FileDescriptorSet and returns the matching DescriptorProto.
//
// Files and top-level messages are scanned in order; each top-level message is
// compared first and its nested messages are searched next. It returns nil
// when the schema has no descriptor set or nothing matches.
func (s *ProtobufSchema) findMessageDescriptor(messageName string) *descriptorpb.DescriptorProto {
	set := s.FileDescriptorSet
	if set == nil {
		return nil
	}

	for _, file := range set.File {
		for _, top := range file.MessageType {
			if top.GetName() == messageName {
				return top
			}
			if inner := searchNestedMessages(top, messageName); inner != nil {
				return inner
			}
		}
	}
	return nil
}
// searchNestedMessages returns the first message nested inside messageType, at
// any depth, whose name is set and equals targetName. Children are visited
// depth-first in declaration order; messageType itself is not compared. It
// returns nil when there is no match.
func searchNestedMessages(messageType *descriptorpb.DescriptorProto, targetName string) *descriptorpb.DescriptorProto {
	for _, child := range messageType.NestedType {
		if child.Name == nil || *child.Name != targetName {
			// Not this one: look inside it before moving to the next sibling.
			if hit := searchNestedMessages(child, targetName); hit != nil {
				return hit
			}
			continue
		}
		return child
	}
	return nil
}
// fieldTypeToString maps a FieldDescriptorProto_Type to its lower-case
// name (for example "int32" or "message"). Values outside the listed types
// yield "unknown".
func (s *ProtobufSchema) fieldTypeToString(fieldType descriptorpb.FieldDescriptorProto_Type) string {
	// Lookup table indexed by the enum value. Indices that are not listed
	// (including 0) hold the empty string.
	names := [...]string{
		descriptorpb.FieldDescriptorProto_TYPE_DOUBLE:   "double",
		descriptorpb.FieldDescriptorProto_TYPE_FLOAT:    "float",
		descriptorpb.FieldDescriptorProto_TYPE_INT64:    "int64",
		descriptorpb.FieldDescriptorProto_TYPE_UINT64:   "uint64",
		descriptorpb.FieldDescriptorProto_TYPE_INT32:    "int32",
		descriptorpb.FieldDescriptorProto_TYPE_FIXED64:  "fixed64",
		descriptorpb.FieldDescriptorProto_TYPE_FIXED32:  "fixed32",
		descriptorpb.FieldDescriptorProto_TYPE_BOOL:     "bool",
		descriptorpb.FieldDescriptorProto_TYPE_STRING:   "string",
		descriptorpb.FieldDescriptorProto_TYPE_GROUP:    "group",
		descriptorpb.FieldDescriptorProto_TYPE_MESSAGE:  "message",
		descriptorpb.FieldDescriptorProto_TYPE_BYTES:    "bytes",
		descriptorpb.FieldDescriptorProto_TYPE_UINT32:   "uint32",
		descriptorpb.FieldDescriptorProto_TYPE_ENUM:     "enum",
		descriptorpb.FieldDescriptorProto_TYPE_SFIXED32: "sfixed32",
		descriptorpb.FieldDescriptorProto_TYPE_SFIXED64: "sfixed64",
		descriptorpb.FieldDescriptorProto_TYPE_SINT32:   "sint32",
		descriptorpb.FieldDescriptorProto_TYPE_SINT64:   "sint64",
	}

	if fieldType >= 0 && int(fieldType) < len(names) {
		if name := names[fieldType]; name != "" {
			return name
		}
	}
	return "unknown"
}
// fieldLabelToString maps a FieldDescriptorProto_Label to "optional",
// "required" or "repeated"; any other value yields "unknown".
func (s *ProtobufSchema) fieldLabelToString(label descriptorpb.FieldDescriptorProto_Label) string {
	name := "unknown"
	switch label {
	case descriptorpb.FieldDescriptorProto_LABEL_OPTIONAL:
		name = "optional"
	case descriptorpb.FieldDescriptorProto_LABEL_REQUIRED:
		name = "required"
	case descriptorpb.FieldDescriptorProto_LABEL_REPEATED:
		name = "repeated"
	}
	return name
}
// ValidateMessage checks that messageData can be unmarshaled into a dynamic
// message built from the schema's MessageDescriptor. No further checks are
// performed beyond a successful unmarshal.
//
// It returns an error when the schema has no message descriptor or when
// unmarshaling fails.
func (s *ProtobufSchema) ValidateMessage(messageData []byte) error {
	desc := s.MessageDescriptor
	if desc == nil {
		return fmt.Errorf("no message descriptor available for validation")
	}

	// Decode into a throwaway dynamic message; only the outcome matters.
	scratch := dynamicpb.NewMessageType(desc).New().Interface()
	if err := proto.Unmarshal(messageData, scratch); err != nil {
		return fmt.Errorf("message validation failed: %w", err)
	}
	return nil
}
// ClearCache discards every cached schema by swapping in a new, empty map
// while holding the write lock.
func (p *ProtobufDescriptorParser) ClearCache() {
	fresh := map[string]*ProtobufSchema{}

	p.mu.Lock()
	p.descriptorCache = fresh
	p.mu.Unlock()
}
// GetCacheStats reports cache statistics. The returned map has a single
// entry, "cached_descriptors", holding the number of cached schemas read
// under the read lock.
func (p *ProtobufDescriptorParser) GetCacheStats() map[string]interface{} {
	p.mu.RLock()
	count := len(p.descriptorCache)
	p.mu.RUnlock()

	stats := make(map[string]interface{}, 1)
	stats["cached_descriptors"] = count
	return stats
}
// min returns the smaller of the two ints a and b.
func min(a, b int) int {
	if b <= a {
		return b
	}
	return a
}