diff options
Diffstat (limited to 'python/google/protobuf/descriptor_pool.py')
-rw-r--r-- | python/google/protobuf/descriptor_pool.py | 196 |
1 files changed, 170 insertions, 26 deletions
diff --git a/python/google/protobuf/descriptor_pool.py b/python/google/protobuf/descriptor_pool.py index 8983f76f..42f7bcb5 100644 --- a/python/google/protobuf/descriptor_pool.py +++ b/python/google/protobuf/descriptor_pool.py @@ -131,33 +131,46 @@ class DescriptorPool(object): # TODO(jieluo): Remove _file_desc_by_toplevel_extension after # maybe year 2020 for compatibility issue (with 3.4.1 only). self._file_desc_by_toplevel_extension = {} + self._top_enum_values = {} # We store extensions in two two-level mappings: The first key is the # descriptor of the message being extended, the second key is the extension # full name or its tag number. self._extensions_by_name = collections.defaultdict(dict) self._extensions_by_number = collections.defaultdict(dict) - def _CheckConflictRegister(self, desc): + def _CheckConflictRegister(self, desc, desc_name, file_name): """Check if the descriptor name conflicts with another of the same name. Args: - desc: Descriptor of a message, enum, service or extension. + desc: Descriptor of a message, enum, service, extension or enum value. + desc_name: the full name of desc. + file_name: The file name of descriptor. """ - desc_name = desc.full_name for register, descriptor_type in [ (self._descriptors, descriptor.Descriptor), (self._enum_descriptors, descriptor.EnumDescriptor), (self._service_descriptors, descriptor.ServiceDescriptor), - (self._toplevel_extensions, descriptor.FieldDescriptor)]: + (self._toplevel_extensions, descriptor.FieldDescriptor), + (self._top_enum_values, descriptor.EnumValueDescriptor)]: if desc_name in register: - file_name = register[desc_name].file.name + old_desc = register[desc_name] + if isinstance(old_desc, descriptor.EnumValueDescriptor): + old_file = old_desc.type.file.name + else: + old_file = old_desc.file.name + if not isinstance(desc, descriptor_type) or ( - file_name != desc.file.name): - warn_msg = ('Conflict register for file "' + desc.file.name + + old_file != file_name): + warn_msg = ('Conflict register for file "' + file_name + '": ' + desc_name + ' is already defined in file "' + - file_name + '"') + old_file + '"') + if isinstance(desc, descriptor.EnumValueDescriptor): + warn_msg += ('\nNote: enum values appear as ' + 'siblings of the enum type instead of ' + 'children of it.') warnings.warn(warn_msg, RuntimeWarning) + return def Add(self, file_desc_proto): @@ -196,7 +209,7 @@ class DescriptorPool(object): if not isinstance(desc, descriptor.Descriptor): raise TypeError('Expected instance of descriptor.Descriptor.') - self._CheckConflictRegister(desc) + self._CheckConflictRegister(desc, desc.full_name, desc.file.name) self._descriptors[desc.full_name] = desc self._AddFileDescriptor(desc.file) @@ -213,8 +226,26 @@ class DescriptorPool(object): if not isinstance(enum_desc, descriptor.EnumDescriptor): raise TypeError('Expected instance of descriptor.EnumDescriptor.') - self._CheckConflictRegister(enum_desc) + file_name = enum_desc.file.name + self._CheckConflictRegister(enum_desc, enum_desc.full_name, file_name) self._enum_descriptors[enum_desc.full_name] = enum_desc + + # Top enum values need to be indexed. + # Count the number of dots to see whether the enum is toplevel or nested + # in a message. We cannot use enum_desc.containing_type at this stage. + if enum_desc.file.package: + top_level = (enum_desc.full_name.count('.') + - enum_desc.file.package.count('.') == 1) + else: + top_level = enum_desc.full_name.count('.') == 0 + if top_level: + file_name = enum_desc.file.name + package = enum_desc.file.package + for enum_value in enum_desc.values: + full_name = _NormalizeFullyQualifiedName( + '.'.join((package, enum_value.name))) + self._CheckConflictRegister(enum_value, full_name, file_name) + self._top_enum_values[full_name] = enum_value self._AddFileDescriptor(enum_desc.file) def AddServiceDescriptor(self, service_desc): @@ -227,7 +258,8 @@ class DescriptorPool(object): if not isinstance(service_desc, descriptor.ServiceDescriptor): raise TypeError('Expected instance of descriptor.ServiceDescriptor.') - self._CheckConflictRegister(service_desc) + self._CheckConflictRegister(service_desc, service_desc.full_name, + service_desc.file.name) self._service_descriptors[service_desc.full_name] = service_desc def AddExtensionDescriptor(self, extension): @@ -247,7 +279,6 @@ class DescriptorPool(object): raise TypeError('Expected an extension descriptor.') if extension.extension_scope is None: - self._CheckConflictRegister(extension) self._toplevel_extensions[extension.full_name] = extension try: @@ -349,6 +380,30 @@ class DescriptorPool(object): symbol = _NormalizeFullyQualifiedName(symbol) try: + return self._InternalFindFileContainingSymbol(symbol) + except KeyError: + pass + + try: + # Try fallback database. Build and find again if possible. + self._FindFileContainingSymbolInDb(symbol) + return self._InternalFindFileContainingSymbol(symbol) + except KeyError: + raise KeyError('Cannot find a file containing %s' % symbol) + + def _InternalFindFileContainingSymbol(self, symbol): + """Gets the already built FileDescriptor containing the specified symbol. + + Args: + symbol: The name of the symbol to search for. + + Returns: + A FileDescriptor that contains the specified symbol. + + Raises: + KeyError: if the file cannot be found in the pool. + """ + try: return self._descriptors[symbol].file except KeyError: pass @@ -364,7 +419,7 @@ class DescriptorPool(object): pass try: - return self._FindFileContainingSymbolInDb(symbol) + return self._top_enum_values[symbol].type.file except KeyError: pass @@ -373,13 +428,15 @@ class DescriptorPool(object): except KeyError: pass - # Try nested extensions inside a message. - message_name, _, extension_name = symbol.rpartition('.') + # Try fields, enum values and nested extensions inside a message. + top_name, _, sub_name = symbol.rpartition('.') try: - message = self.FindMessageTypeByName(message_name) - assert message.extensions_by_name[extension_name] + message = self.FindMessageTypeByName(top_name) + assert (sub_name in message.extensions_by_name or + sub_name in message.fields_by_name or + sub_name in message.enum_values_by_name) return message.file - except KeyError: + except (KeyError, AssertionError): raise KeyError('Cannot find a file containing %s' % symbol) def FindMessageTypeByName(self, full_name): @@ -499,7 +556,11 @@ class DescriptorPool(object): KeyError: when no extension with the given number is known for the specified message. """ - return self._extensions_by_number[message_descriptor][number] + try: + return self._extensions_by_number[message_descriptor][number] + except KeyError: + self._TryLoadExtensionFromDB(message_descriptor, number) + return self._extensions_by_number[message_descriptor][number] def FindAllExtensions(self, message_descriptor): """Gets all the known extension of a given message. @@ -513,8 +574,57 @@ class DescriptorPool(object): Returns: A list of FieldDescriptor describing the extensions. """ + # Fallback to descriptor db if FindAllExtensionNumbers is provided. + if self._descriptor_db and hasattr( + self._descriptor_db, 'FindAllExtensionNumbers'): + full_name = message_descriptor.full_name + all_numbers = self._descriptor_db.FindAllExtensionNumbers(full_name) + for number in all_numbers: + if number in self._extensions_by_number[message_descriptor]: + continue + self._TryLoadExtensionFromDB(message_descriptor, number) + return list(self._extensions_by_number[message_descriptor].values()) + def _TryLoadExtensionFromDB(self, message_descriptor, number): + """Try to Load extensions from decriptor db. + + Args: + message_descriptor: descriptor of the extended message. + number: the extension number that needs to be loaded. + """ + if not self._descriptor_db: + return + # Only supported when FindFileContainingExtension is provided. + if not hasattr( + self._descriptor_db, 'FindFileContainingExtension'): + return + + full_name = message_descriptor.full_name + file_proto = self._descriptor_db.FindFileContainingExtension( + full_name, number) + + if file_proto is None: + return + + try: + file_desc = self._ConvertFileProtoToFileDescriptor(file_proto) + for extension in file_desc.extensions_by_name.values(): + self._extensions_by_number[extension.containing_type][ + extension.number] = extension + self._extensions_by_name[extension.containing_type][ + extension.full_name] = extension + for message_type in file_desc.message_types_by_name.values(): + for extension in message_type.extensions: + self._extensions_by_number[extension.containing_type][ + extension.number] = extension + self._extensions_by_name[extension.containing_type][ + extension.full_name] = extension + except: + warn_msg = ('Unable to load proto file %s for extension number %d.' % + (file_proto.name, number)) + warnings.warn(warn_msg, RuntimeWarning) + def FindServiceByName(self, full_name): """Loads the named service descriptor from the pool. @@ -532,6 +642,23 @@ class DescriptorPool(object): self._FindFileContainingSymbolInDb(full_name) return self._service_descriptors[full_name] + def FindMethodByName(self, full_name): + """Loads the named service method descriptor from the pool. + + Args: + full_name: The full name of the method descriptor to load. + + Returns: + The method descriptor for the service method. + + Raises: + KeyError: if the method cannot be found in the pool. + """ + full_name = _NormalizeFullyQualifiedName(full_name) + service_name, _, method_name = full_name.rpartition('.') + service_descriptor = self.FindServiceByName(service_name) + return service_descriptor.methods_by_name[method_name] + def _FindFileContainingSymbolInDb(self, symbol): """Finds the file in descriptor DB containing the specified symbol. @@ -567,7 +694,6 @@ class DescriptorPool(object): Returns: A FileDescriptor matching the passed in proto. """ - if file_proto.name not in self._file_descriptors: built_deps = list(self._GetDeps(file_proto.dependency)) direct_deps = [self.FindFileByName(n) for n in file_proto.dependency] @@ -604,7 +730,7 @@ class DescriptorPool(object): for enum_type in file_proto.enum_type: file_descriptor.enum_types_by_name[enum_type.name] = ( self._ConvertEnumDescriptor(enum_type, file_proto.package, - file_descriptor, None, scope)) + file_descriptor, None, scope, True)) for index, extension_proto in enumerate(file_proto.extension): extension_desc = self._MakeFieldDescriptor( @@ -616,6 +742,8 @@ class DescriptorPool(object): file_descriptor.package, scope) file_descriptor.extensions_by_name[extension_desc.name] = ( extension_desc) + self._file_desc_by_toplevel_extension[extension_desc.full_name] = ( + file_descriptor) for desc_proto in file_proto.message_type: self._SetAllFieldTypes(file_proto.package, desc_proto, scope) @@ -673,7 +801,8 @@ class DescriptorPool(object): nested, desc_name, file_desc, scope, syntax) for nested in desc_proto.nested_type] enums = [ - self._ConvertEnumDescriptor(enum, desc_name, file_desc, None, scope) + self._ConvertEnumDescriptor(enum, desc_name, file_desc, None, + scope, False) for enum in desc_proto.enum_type] fields = [self._MakeFieldDescriptor(field, desc_name, index, file_desc) for index, field in enumerate(desc_proto.field)] @@ -718,12 +847,12 @@ class DescriptorPool(object): fields[field_index].containing_oneof = oneofs[oneof_index] scope[_PrefixWithDot(desc_name)] = desc - self._CheckConflictRegister(desc) + self._CheckConflictRegister(desc, desc.full_name, desc.file.name) self._descriptors[desc_name] = desc return desc def _ConvertEnumDescriptor(self, enum_proto, package=None, file_desc=None, - containing_type=None, scope=None): + containing_type=None, scope=None, top_level=False): """Make a protobuf EnumDescriptor given an EnumDescriptorProto protobuf. Args: @@ -732,6 +861,8 @@ class DescriptorPool(object): file_desc: The file containing the enum descriptor. containing_type: The type containing this enum. scope: Scope containing available types. + top_level: If True, the enum is a top level symbol. If False, the enum + is defined inside a message. Returns: The added descriptor @@ -757,8 +888,17 @@ class DescriptorPool(object): containing_type=containing_type, options=_OptionsOrNone(enum_proto)) scope['.%s' % enum_name] = desc - self._CheckConflictRegister(desc) + self._CheckConflictRegister(desc, desc.full_name, desc.file.name) self._enum_descriptors[enum_name] = desc + + # Add top level enum values. + if top_level: + for value in values: + full_name = _NormalizeFullyQualifiedName( + '.'.join((package, value.name))) + self._CheckConflictRegister(value, full_name, file_name) + self._top_enum_values[full_name] = value + return desc def _MakeFieldDescriptor(self, field_proto, message_name, index, @@ -885,6 +1025,8 @@ class DescriptorPool(object): elif field_proto.type == descriptor.FieldDescriptor.TYPE_BYTES: field_desc.default_value = text_encoding.CUnescape( field_proto.default_value) + elif field_proto.type == descriptor.FieldDescriptor.TYPE_MESSAGE: + field_desc.default_value = None else: # All other types are of the "int" type. field_desc.default_value = int(field_proto.default_value) @@ -901,6 +1043,8 @@ class DescriptorPool(object): field_desc.default_value = field_desc.enum_type.values[0].number elif field_proto.type == descriptor.FieldDescriptor.TYPE_BYTES: field_desc.default_value = b'' + elif field_proto.type == descriptor.FieldDescriptor.TYPE_MESSAGE: + field_desc.default_value = None else: # All other types are of the "int" type. field_desc.default_value = 0 @@ -954,7 +1098,7 @@ class DescriptorPool(object): methods=methods, options=_OptionsOrNone(service_proto), file=file_desc) - self._CheckConflictRegister(desc) + self._CheckConflictRegister(desc, desc.full_name, desc.file.name) self._service_descriptors[service_name] = desc return desc |