You can not select more than 25 topics
			Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
		
		
		
		
		
			
		
			
				
					
					
						
							485 lines
						
					
					
						
							15 KiB
						
					
					
				
			
		
		
		
			
			
			
		
		
	
	
							485 lines
						
					
					
						
							15 KiB
						
					
					
				| package schema | |
| 
 | |
| import ( | |
| 	"fmt" | |
| 	"sync" | |
| 
 | |
| 	"google.golang.org/protobuf/proto" | |
| 	"google.golang.org/protobuf/reflect/protodesc" | |
| 	"google.golang.org/protobuf/reflect/protoreflect" | |
| 	"google.golang.org/protobuf/reflect/protoregistry" | |
| 	"google.golang.org/protobuf/types/descriptorpb" | |
| 	"google.golang.org/protobuf/types/dynamicpb" | |
| ) | |
| 
 | |
| // ProtobufSchema represents a parsed Protobuf schema with message type information | |
| type ProtobufSchema struct { | |
| 	FileDescriptorSet *descriptorpb.FileDescriptorSet | |
| 	MessageDescriptor protoreflect.MessageDescriptor | |
| 	MessageName       string | |
| 	PackageName       string | |
| 	Dependencies      []string | |
| } | |
| 
 | |
| // ProtobufDescriptorParser handles parsing of Confluent Schema Registry Protobuf descriptors | |
| type ProtobufDescriptorParser struct { | |
| 	mu sync.RWMutex | |
| 	// Cache for parsed descriptors to avoid re-parsing | |
| 	descriptorCache map[string]*ProtobufSchema | |
| } | |
| 
 | |
| // NewProtobufDescriptorParser creates a new parser instance | |
| func NewProtobufDescriptorParser() *ProtobufDescriptorParser { | |
| 	return &ProtobufDescriptorParser{ | |
| 		descriptorCache: make(map[string]*ProtobufSchema), | |
| 	} | |
| } | |
| 
 | |
| // ParseBinaryDescriptor parses a Confluent Schema Registry Protobuf binary descriptor | |
| // The input is typically a serialized FileDescriptorSet from the schema registry | |
| func (p *ProtobufDescriptorParser) ParseBinaryDescriptor(binaryData []byte, messageName string) (*ProtobufSchema, error) { | |
| 	// Check cache first | |
| 	cacheKey := fmt.Sprintf("%x:%s", binaryData[:min(32, len(binaryData))], messageName) | |
| 	p.mu.RLock() | |
| 	if cached, exists := p.descriptorCache[cacheKey]; exists { | |
| 		p.mu.RUnlock() | |
| 		// If we have a cached schema but no message descriptor, return the same error | |
| 		if cached.MessageDescriptor == nil { | |
| 			return cached, fmt.Errorf("failed to find message descriptor for %s: message descriptor resolution not fully implemented in Phase E1 - found message %s in package %s", messageName, messageName, cached.PackageName) | |
| 		} | |
| 		return cached, nil | |
| 	} | |
| 	p.mu.RUnlock() | |
| 
 | |
| 	// Parse the FileDescriptorSet from binary data | |
| 	var fileDescriptorSet descriptorpb.FileDescriptorSet | |
| 	if err := proto.Unmarshal(binaryData, &fileDescriptorSet); err != nil { | |
| 		return nil, fmt.Errorf("failed to unmarshal FileDescriptorSet: %w", err) | |
| 	} | |
| 
 | |
| 	// Validate the descriptor set | |
| 	if err := p.validateDescriptorSet(&fileDescriptorSet); err != nil { | |
| 		return nil, fmt.Errorf("invalid descriptor set: %w", err) | |
| 	} | |
| 
 | |
| 	// If no message name provided, try to find the first available message | |
| 	if messageName == "" { | |
| 		messageName = p.findFirstMessageName(&fileDescriptorSet) | |
| 		if messageName == "" { | |
| 			return nil, fmt.Errorf("no messages found in FileDescriptorSet") | |
| 		} | |
| 	} | |
| 
 | |
| 	// Find the target message descriptor | |
| 	messageDesc, packageName, err := p.findMessageDescriptor(&fileDescriptorSet, messageName) | |
| 	if err != nil { | |
| 		// For Phase E1, we still cache the FileDescriptorSet even if message resolution fails | |
| 		// This allows us to test caching behavior and avoid re-parsing the same binary data | |
| 		schema := &ProtobufSchema{ | |
| 			FileDescriptorSet: &fileDescriptorSet, | |
| 			MessageDescriptor: nil, // Not resolved in Phase E1 | |
| 			MessageName:       messageName, | |
| 			PackageName:       packageName, | |
| 			Dependencies:      p.extractDependencies(&fileDescriptorSet), | |
| 		} | |
| 		p.mu.Lock() | |
| 		p.descriptorCache[cacheKey] = schema | |
| 		p.mu.Unlock() | |
| 		return schema, fmt.Errorf("failed to find message descriptor for %s: %w", messageName, err) | |
| 	} | |
| 
 | |
| 	// Extract dependencies | |
| 	dependencies := p.extractDependencies(&fileDescriptorSet) | |
| 
 | |
| 	// Create the schema object | |
| 	schema := &ProtobufSchema{ | |
| 		FileDescriptorSet: &fileDescriptorSet, | |
| 		MessageDescriptor: messageDesc, | |
| 		MessageName:       messageName, | |
| 		PackageName:       packageName, | |
| 		Dependencies:      dependencies, | |
| 	} | |
| 
 | |
| 	// Cache the result | |
| 	p.mu.Lock() | |
| 	p.descriptorCache[cacheKey] = schema | |
| 	p.mu.Unlock() | |
| 
 | |
| 	return schema, nil | |
| } | |
| 
 | |
| // validateDescriptorSet performs basic validation on the FileDescriptorSet | |
| func (p *ProtobufDescriptorParser) validateDescriptorSet(fds *descriptorpb.FileDescriptorSet) error { | |
| 	if len(fds.File) == 0 { | |
| 		return fmt.Errorf("FileDescriptorSet contains no files") | |
| 	} | |
| 
 | |
| 	for i, file := range fds.File { | |
| 		if file.Name == nil { | |
| 			return fmt.Errorf("file descriptor %d has no name", i) | |
| 		} | |
| 		if file.Package == nil { | |
| 			return fmt.Errorf("file descriptor %s has no package", *file.Name) | |
| 		} | |
| 	} | |
| 
 | |
| 	return nil | |
| } | |
| 
 | |
| // findFirstMessageName finds the first message name in the FileDescriptorSet | |
| func (p *ProtobufDescriptorParser) findFirstMessageName(fds *descriptorpb.FileDescriptorSet) string { | |
| 	for _, file := range fds.File { | |
| 		if len(file.MessageType) > 0 { | |
| 			return file.MessageType[0].GetName() | |
| 		} | |
| 	} | |
| 	return "" | |
| } | |
| 
 | |
| // findMessageDescriptor locates a specific message descriptor within the FileDescriptorSet | |
| func (p *ProtobufDescriptorParser) findMessageDescriptor(fds *descriptorpb.FileDescriptorSet, messageName string) (protoreflect.MessageDescriptor, string, error) { | |
| 	// This is a simplified implementation for Phase E1 | |
| 	// In a complete implementation, we would: | |
| 	// 1. Build a complete descriptor registry from the FileDescriptorSet | |
| 	// 2. Resolve all imports and dependencies | |
| 	// 3. Handle nested message types and packages correctly | |
| 	// 4. Support fully qualified message names | |
|  | |
| 	for _, file := range fds.File { | |
| 		packageName := "" | |
| 		if file.Package != nil { | |
| 			packageName = *file.Package | |
| 		} | |
| 
 | |
| 		// Search for the message in this file | |
| 		for _, messageType := range file.MessageType { | |
| 			if messageType.Name != nil && *messageType.Name == messageName { | |
| 				// Try to build a proper descriptor from the FileDescriptorProto | |
| 				fileDesc, err := p.buildFileDescriptor(file) | |
| 				if err != nil { | |
| 					return nil, packageName, fmt.Errorf("failed to build file descriptor: %w", err) | |
| 				} | |
| 
 | |
| 				// Find the message descriptor in the built file | |
| 				msgDesc := p.findMessageInFileDescriptor(fileDesc, messageName) | |
| 				if msgDesc != nil { | |
| 					return msgDesc, packageName, nil | |
| 				} | |
| 
 | |
| 				return nil, packageName, fmt.Errorf("message descriptor built but not found: %s", messageName) | |
| 			} | |
| 
 | |
| 			// Search nested messages (simplified) | |
| 			if nestedDesc := p.searchNestedMessages(messageType, messageName); nestedDesc != nil { | |
| 				// Try to build descriptor for nested message | |
| 				fileDesc, err := p.buildFileDescriptor(file) | |
| 				if err != nil { | |
| 					return nil, packageName, fmt.Errorf("failed to build file descriptor for nested message: %w", err) | |
| 				} | |
| 
 | |
| 				msgDesc := p.findMessageInFileDescriptor(fileDesc, messageName) | |
| 				if msgDesc != nil { | |
| 					return msgDesc, packageName, nil | |
| 				} | |
| 
 | |
| 				return nil, packageName, fmt.Errorf("nested message descriptor built but not found: %s", messageName) | |
| 			} | |
| 		} | |
| 	} | |
| 
 | |
| 	return nil, "", fmt.Errorf("message %s not found in descriptor set", messageName) | |
| } | |
| 
 | |
| // buildFileDescriptor builds a protoreflect.FileDescriptor from a FileDescriptorProto | |
| func (p *ProtobufDescriptorParser) buildFileDescriptor(fileProto *descriptorpb.FileDescriptorProto) (protoreflect.FileDescriptor, error) { | |
| 	// Create a local registry to avoid conflicts | |
| 	localFiles := &protoregistry.Files{} | |
| 
 | |
| 	// Build the file descriptor using protodesc | |
| 	fileDesc, err := protodesc.NewFile(fileProto, localFiles) | |
| 	if err != nil { | |
| 		return nil, fmt.Errorf("failed to create file descriptor: %w", err) | |
| 	} | |
| 
 | |
| 	return fileDesc, nil | |
| } | |
| 
 | |
| // findMessageInFileDescriptor searches for a message descriptor within a file descriptor | |
| func (p *ProtobufDescriptorParser) findMessageInFileDescriptor(fileDesc protoreflect.FileDescriptor, messageName string) protoreflect.MessageDescriptor { | |
| 	// Search top-level messages | |
| 	messages := fileDesc.Messages() | |
| 	for i := 0; i < messages.Len(); i++ { | |
| 		msgDesc := messages.Get(i) | |
| 		if string(msgDesc.Name()) == messageName { | |
| 			return msgDesc | |
| 		} | |
| 
 | |
| 		// Search nested messages | |
| 		if nestedDesc := p.findNestedMessageDescriptor(msgDesc, messageName); nestedDesc != nil { | |
| 			return nestedDesc | |
| 		} | |
| 	} | |
| 
 | |
| 	return nil | |
| } | |
| 
 | |
| // findNestedMessageDescriptor recursively searches for nested messages | |
| func (p *ProtobufDescriptorParser) findNestedMessageDescriptor(msgDesc protoreflect.MessageDescriptor, messageName string) protoreflect.MessageDescriptor { | |
| 	nestedMessages := msgDesc.Messages() | |
| 	for i := 0; i < nestedMessages.Len(); i++ { | |
| 		nestedDesc := nestedMessages.Get(i) | |
| 		if string(nestedDesc.Name()) == messageName { | |
| 			return nestedDesc | |
| 		} | |
| 
 | |
| 		// Recursively search deeper nested messages | |
| 		if deeperNested := p.findNestedMessageDescriptor(nestedDesc, messageName); deeperNested != nil { | |
| 			return deeperNested | |
| 		} | |
| 	} | |
| 
 | |
| 	return nil | |
| } | |
| 
 | |
| // searchNestedMessages recursively searches for nested message types | |
| func (p *ProtobufDescriptorParser) searchNestedMessages(messageType *descriptorpb.DescriptorProto, targetName string) *descriptorpb.DescriptorProto { | |
| 	for _, nested := range messageType.NestedType { | |
| 		if nested.Name != nil && *nested.Name == targetName { | |
| 			return nested | |
| 		} | |
| 		// Recursively search deeper nesting | |
| 		if found := p.searchNestedMessages(nested, targetName); found != nil { | |
| 			return found | |
| 		} | |
| 	} | |
| 	return nil | |
| } | |
| 
 | |
| // extractDependencies extracts the list of dependencies from the FileDescriptorSet | |
| func (p *ProtobufDescriptorParser) extractDependencies(fds *descriptorpb.FileDescriptorSet) []string { | |
| 	dependencySet := make(map[string]bool) | |
| 
 | |
| 	for _, file := range fds.File { | |
| 		for _, dep := range file.Dependency { | |
| 			dependencySet[dep] = true | |
| 		} | |
| 	} | |
| 
 | |
| 	dependencies := make([]string, 0, len(dependencySet)) | |
| 	for dep := range dependencySet { | |
| 		dependencies = append(dependencies, dep) | |
| 	} | |
| 
 | |
| 	return dependencies | |
| } | |
| 
 | |
| // GetMessageFields returns information about the fields in the message | |
| func (s *ProtobufSchema) GetMessageFields() ([]FieldInfo, error) { | |
| 	if s.FileDescriptorSet == nil { | |
| 		return nil, fmt.Errorf("no FileDescriptorSet available") | |
| 	} | |
| 
 | |
| 	// Find the message descriptor for this schema | |
| 	messageDesc := s.findMessageDescriptor(s.MessageName) | |
| 	if messageDesc == nil { | |
| 		return nil, fmt.Errorf("message %s not found in descriptor set", s.MessageName) | |
| 	} | |
| 
 | |
| 	// Extract field information | |
| 	fields := make([]FieldInfo, 0, len(messageDesc.Field)) | |
| 	for _, field := range messageDesc.Field { | |
| 		fieldInfo := FieldInfo{ | |
| 			Name:   field.GetName(), | |
| 			Number: field.GetNumber(), | |
| 			Type:   s.fieldTypeToString(field.GetType()), | |
| 			Label:  s.fieldLabelToString(field.GetLabel()), | |
| 		} | |
| 
 | |
| 		// Set TypeName for message/enum types | |
| 		if field.GetTypeName() != "" { | |
| 			fieldInfo.TypeName = field.GetTypeName() | |
| 		} | |
| 
 | |
| 		fields = append(fields, fieldInfo) | |
| 	} | |
| 
 | |
| 	return fields, nil | |
| } | |
| 
 | |
// FieldInfo describes a single field of a Protobuf message.
type FieldInfo struct {
	// Name is the field name.
	Name string
	// Number is the field number.
	Number int32
	// Type is the field type as a string, e.g. "int32", "string", "message".
	Type string
	// Label is one of optional, required, repeated.
	Label string
	// TypeName is set for message/enum types.
	TypeName string
}
| 
 | |
| // GetFieldByName returns information about a specific field | |
| func (s *ProtobufSchema) GetFieldByName(fieldName string) (*FieldInfo, error) { | |
| 	fields, err := s.GetMessageFields() | |
| 	if err != nil { | |
| 		return nil, err | |
| 	} | |
| 
 | |
| 	for _, field := range fields { | |
| 		if field.Name == fieldName { | |
| 			return &field, nil | |
| 		} | |
| 	} | |
| 
 | |
| 	return nil, fmt.Errorf("field %s not found", fieldName) | |
| } | |
| 
 | |
| // GetFieldByNumber returns information about a field by its number | |
| func (s *ProtobufSchema) GetFieldByNumber(fieldNumber int32) (*FieldInfo, error) { | |
| 	fields, err := s.GetMessageFields() | |
| 	if err != nil { | |
| 		return nil, err | |
| 	} | |
| 
 | |
| 	for _, field := range fields { | |
| 		if field.Number == fieldNumber { | |
| 			return &field, nil | |
| 		} | |
| 	} | |
| 
 | |
| 	return nil, fmt.Errorf("field number %d not found", fieldNumber) | |
| } | |
| 
 | |
| // findMessageDescriptor finds a message descriptor by name in the FileDescriptorSet | |
| func (s *ProtobufSchema) findMessageDescriptor(messageName string) *descriptorpb.DescriptorProto { | |
| 	if s.FileDescriptorSet == nil { | |
| 		return nil | |
| 	} | |
| 
 | |
| 	for _, file := range s.FileDescriptorSet.File { | |
| 		// Check top-level messages | |
| 		for _, message := range file.MessageType { | |
| 			if message.GetName() == messageName { | |
| 				return message | |
| 			} | |
| 			// Check nested messages | |
| 			if nested := searchNestedMessages(message, messageName); nested != nil { | |
| 				return nested | |
| 			} | |
| 		} | |
| 	} | |
| 
 | |
| 	return nil | |
| } | |
| 
 | |
| // searchNestedMessages recursively searches for nested message types | |
| func searchNestedMessages(messageType *descriptorpb.DescriptorProto, targetName string) *descriptorpb.DescriptorProto { | |
| 	for _, nested := range messageType.NestedType { | |
| 		if nested.Name != nil && *nested.Name == targetName { | |
| 			return nested | |
| 		} | |
| 		// Recursively search deeper nesting | |
| 		if found := searchNestedMessages(nested, targetName); found != nil { | |
| 			return found | |
| 		} | |
| 	} | |
| 	return nil | |
| } | |
| 
 | |
| // fieldTypeToString converts a FieldDescriptorProto_Type to string | |
| func (s *ProtobufSchema) fieldTypeToString(fieldType descriptorpb.FieldDescriptorProto_Type) string { | |
| 	switch fieldType { | |
| 	case descriptorpb.FieldDescriptorProto_TYPE_DOUBLE: | |
| 		return "double" | |
| 	case descriptorpb.FieldDescriptorProto_TYPE_FLOAT: | |
| 		return "float" | |
| 	case descriptorpb.FieldDescriptorProto_TYPE_INT64: | |
| 		return "int64" | |
| 	case descriptorpb.FieldDescriptorProto_TYPE_UINT64: | |
| 		return "uint64" | |
| 	case descriptorpb.FieldDescriptorProto_TYPE_INT32: | |
| 		return "int32" | |
| 	case descriptorpb.FieldDescriptorProto_TYPE_FIXED64: | |
| 		return "fixed64" | |
| 	case descriptorpb.FieldDescriptorProto_TYPE_FIXED32: | |
| 		return "fixed32" | |
| 	case descriptorpb.FieldDescriptorProto_TYPE_BOOL: | |
| 		return "bool" | |
| 	case descriptorpb.FieldDescriptorProto_TYPE_STRING: | |
| 		return "string" | |
| 	case descriptorpb.FieldDescriptorProto_TYPE_GROUP: | |
| 		return "group" | |
| 	case descriptorpb.FieldDescriptorProto_TYPE_MESSAGE: | |
| 		return "message" | |
| 	case descriptorpb.FieldDescriptorProto_TYPE_BYTES: | |
| 		return "bytes" | |
| 	case descriptorpb.FieldDescriptorProto_TYPE_UINT32: | |
| 		return "uint32" | |
| 	case descriptorpb.FieldDescriptorProto_TYPE_ENUM: | |
| 		return "enum" | |
| 	case descriptorpb.FieldDescriptorProto_TYPE_SFIXED32: | |
| 		return "sfixed32" | |
| 	case descriptorpb.FieldDescriptorProto_TYPE_SFIXED64: | |
| 		return "sfixed64" | |
| 	case descriptorpb.FieldDescriptorProto_TYPE_SINT32: | |
| 		return "sint32" | |
| 	case descriptorpb.FieldDescriptorProto_TYPE_SINT64: | |
| 		return "sint64" | |
| 	default: | |
| 		return "unknown" | |
| 	} | |
| } | |
| 
 | |
| // fieldLabelToString converts a FieldDescriptorProto_Label to string | |
| func (s *ProtobufSchema) fieldLabelToString(label descriptorpb.FieldDescriptorProto_Label) string { | |
| 	switch label { | |
| 	case descriptorpb.FieldDescriptorProto_LABEL_OPTIONAL: | |
| 		return "optional" | |
| 	case descriptorpb.FieldDescriptorProto_LABEL_REQUIRED: | |
| 		return "required" | |
| 	case descriptorpb.FieldDescriptorProto_LABEL_REPEATED: | |
| 		return "repeated" | |
| 	default: | |
| 		return "unknown" | |
| 	} | |
| } | |
| 
 | |
| // ValidateMessage validates that a message conforms to the schema | |
| func (s *ProtobufSchema) ValidateMessage(messageData []byte) error { | |
| 	if s.MessageDescriptor == nil { | |
| 		return fmt.Errorf("no message descriptor available for validation") | |
| 	} | |
| 
 | |
| 	// Create a dynamic message from the descriptor | |
| 	msgType := dynamicpb.NewMessageType(s.MessageDescriptor) | |
| 	msg := msgType.New() | |
| 
 | |
| 	// Try to unmarshal the message data | |
| 	if err := proto.Unmarshal(messageData, msg.Interface()); err != nil { | |
| 		return fmt.Errorf("message validation failed: %w", err) | |
| 	} | |
| 
 | |
| 	// Basic validation passed - the message can be unmarshaled with the schema | |
| 	return nil | |
| } | |
| 
 | |
| // ClearCache clears the descriptor cache | |
| func (p *ProtobufDescriptorParser) ClearCache() { | |
| 	p.mu.Lock() | |
| 	defer p.mu.Unlock() | |
| 	p.descriptorCache = make(map[string]*ProtobufSchema) | |
| } | |
| 
 | |
| // GetCacheStats returns statistics about the descriptor cache | |
| func (p *ProtobufDescriptorParser) GetCacheStats() map[string]interface{} { | |
| 	p.mu.RLock() | |
| 	defer p.mu.RUnlock() | |
| 	return map[string]interface{}{ | |
| 		"cached_descriptors": len(p.descriptorCache), | |
| 	} | |
| } | |
| 
 | |
// min returns the smaller of the two ints a and b.
func min(a, b int) int {
	if b <= a {
		return b
	}
	return a
}
 |