diff options
Diffstat (limited to 'python/google/protobuf/descriptor_pool.py')
-rw-r--r-- | python/google/protobuf/descriptor_pool.py | 196 |
1 files changed, 174 insertions, 22 deletions
diff --git a/python/google/protobuf/descriptor_pool.py b/python/google/protobuf/descriptor_pool.py index fc3a7f44..8983f76f 100644 --- a/python/google/protobuf/descriptor_pool.py +++ b/python/google/protobuf/descriptor_pool.py @@ -58,6 +58,7 @@ directly instead of this class. __author__ = 'matthewtoia@google.com (Matt Toia)' import collections +import warnings from google.protobuf import descriptor from google.protobuf import descriptor_database @@ -124,14 +125,41 @@ class DescriptorPool(object): self._descriptor_db = descriptor_db self._descriptors = {} self._enum_descriptors = {} + self._service_descriptors = {} self._file_descriptors = {} self._toplevel_extensions = {} + # TODO(jieluo): Remove _file_desc_by_toplevel_extension after + # maybe year 2020 for compatibility issue (with 3.4.1 only). + self._file_desc_by_toplevel_extension = {} # We store extensions in two two-level mappings: The first key is the # descriptor of the message being extended, the second key is the extension # full name or its tag number. self._extensions_by_name = collections.defaultdict(dict) self._extensions_by_number = collections.defaultdict(dict) + def _CheckConflictRegister(self, desc): + """Check if the descriptor name conflicts with another of the same name. + + Args: + desc: Descriptor of a message, enum, service or extension. + """ + desc_name = desc.full_name + for register, descriptor_type in [ + (self._descriptors, descriptor.Descriptor), + (self._enum_descriptors, descriptor.EnumDescriptor), + (self._service_descriptors, descriptor.ServiceDescriptor), + (self._toplevel_extensions, descriptor.FieldDescriptor)]: + if desc_name in register: + file_name = register[desc_name].file.name + if not isinstance(desc, descriptor_type) or ( + file_name != desc.file.name): + warn_msg = ('Conflict register for file "' + desc.file.name + + '": ' + desc_name + + ' is already defined in file "' + + file_name + '"') + warnings.warn(warn_msg, RuntimeWarning) + return + def Add(self, file_desc_proto): """Adds the FileDescriptorProto and its types to this pool. @@ -168,13 +196,15 @@ class DescriptorPool(object): if not isinstance(desc, descriptor.Descriptor): raise TypeError('Expected instance of descriptor.Descriptor.') + self._CheckConflictRegister(desc) + self._descriptors[desc.full_name] = desc - self.AddFileDescriptor(desc.file) + self._AddFileDescriptor(desc.file) def AddEnumDescriptor(self, enum_desc): """Adds an EnumDescriptor to the pool. - This method also registers the FileDescriptor associated with the message. + This method also registers the FileDescriptor associated with the enum. Args: enum_desc: An EnumDescriptor. @@ -183,8 +213,22 @@ class DescriptorPool(object): if not isinstance(enum_desc, descriptor.EnumDescriptor): raise TypeError('Expected instance of descriptor.EnumDescriptor.') + self._CheckConflictRegister(enum_desc) self._enum_descriptors[enum_desc.full_name] = enum_desc - self.AddFileDescriptor(enum_desc.file) + self._AddFileDescriptor(enum_desc.file) + + def AddServiceDescriptor(self, service_desc): + """Adds a ServiceDescriptor to the pool. + + Args: + service_desc: A ServiceDescriptor. + """ + + if not isinstance(service_desc, descriptor.ServiceDescriptor): + raise TypeError('Expected instance of descriptor.ServiceDescriptor.') + + self._CheckConflictRegister(service_desc) + self._service_descriptors[service_desc.full_name] = service_desc def AddExtensionDescriptor(self, extension): """Adds a FieldDescriptor describing an extension to the pool. @@ -203,6 +247,7 @@ class DescriptorPool(object): raise TypeError('Expected an extension descriptor.') if extension.extension_scope is None: + self._CheckConflictRegister(extension) self._toplevel_extensions[extension.full_name] = extension try: @@ -238,6 +283,24 @@ class DescriptorPool(object): file_desc: A FileDescriptor. """ + self._AddFileDescriptor(file_desc) + # TODO(jieluo): This is a temporary solution for FieldDescriptor.file. + # FieldDescriptor.file is added in code gen. Remove this solution after + # maybe 2020 for compatibility reason (with 3.4.1 only). + for extension in file_desc.extensions_by_name.values(): + self._file_desc_by_toplevel_extension[ + extension.full_name] = file_desc + + def _AddFileDescriptor(self, file_desc): + """Adds a FileDescriptor to the pool, non-recursively. + + If the FileDescriptor contains messages or enums, the caller must explicitly + register them. + + Args: + file_desc: A FileDescriptor. + """ + if not isinstance(file_desc, descriptor.FileDescriptor): raise TypeError('Expected instance of descriptor.FileDescriptor.') self._file_descriptors[file_desc.name] = file_desc @@ -252,7 +315,7 @@ class DescriptorPool(object): A FileDescriptor for the named file. Raises: - KeyError: if the file can not be found in the pool. + KeyError: if the file cannot be found in the pool. """ try: @@ -281,7 +344,7 @@ class DescriptorPool(object): A FileDescriptor that contains the specified symbol. Raises: - KeyError: if the file can not be found in the pool. + KeyError: if the file cannot be found in the pool. """ symbol = _NormalizeFullyQualifiedName(symbol) @@ -296,15 +359,28 @@ class DescriptorPool(object): pass try: - file_proto = self._internal_db.FindFileContainingSymbol(symbol) - except KeyError as error: - if self._descriptor_db: - file_proto = self._descriptor_db.FindFileContainingSymbol(symbol) - else: - raise error - if not file_proto: + return self._service_descriptors[symbol].file + except KeyError: + pass + + try: + return self._FindFileContainingSymbolInDb(symbol) + except KeyError: + pass + + try: + return self._file_desc_by_toplevel_extension[symbol] + except KeyError: + pass + + # Try nested extensions inside a message. + message_name, _, extension_name = symbol.rpartition('.') + try: + message = self.FindMessageTypeByName(message_name) + assert message.extensions_by_name[extension_name] + return message.file + except KeyError: raise KeyError('Cannot find a file containing %s' % symbol) - return self._ConvertFileProtoToFileDescriptor(file_proto) def FindMessageTypeByName(self, full_name): """Loads the named descriptor from the pool. @@ -314,11 +390,14 @@ class DescriptorPool(object): Returns: The descriptor for the named type. + + Raises: + KeyError: if the message cannot be found in the pool. """ full_name = _NormalizeFullyQualifiedName(full_name) if full_name not in self._descriptors: - self.FindFileContainingSymbol(full_name) + self._FindFileContainingSymbolInDb(full_name) return self._descriptors[full_name] def FindEnumTypeByName(self, full_name): @@ -329,11 +408,14 @@ class DescriptorPool(object): Returns: The enum descriptor for the named type. + + Raises: + KeyError: if the enum cannot be found in the pool. """ full_name = _NormalizeFullyQualifiedName(full_name) if full_name not in self._enum_descriptors: - self.FindFileContainingSymbol(full_name) + self._FindFileContainingSymbolInDb(full_name) return self._enum_descriptors[full_name] def FindFieldByName(self, full_name): @@ -344,12 +426,32 @@ class DescriptorPool(object): Returns: The field descriptor for the named field. + + Raises: + KeyError: if the field cannot be found in the pool. """ full_name = _NormalizeFullyQualifiedName(full_name) message_name, _, field_name = full_name.rpartition('.') message_descriptor = self.FindMessageTypeByName(message_name) return message_descriptor.fields_by_name[field_name] + def FindOneofByName(self, full_name): + """Loads the named oneof descriptor from the pool. + + Args: + full_name: The full name of the oneof descriptor to load. + + Returns: + The oneof descriptor for the named oneof. + + Raises: + KeyError: if the oneof cannot be found in the pool. + """ + full_name = _NormalizeFullyQualifiedName(full_name) + message_name, _, oneof_name = full_name.rpartition('.') + message_descriptor = self.FindMessageTypeByName(message_name) + return message_descriptor.oneofs_by_name[oneof_name] + def FindExtensionByName(self, full_name): """Loads the named extension descriptor from the pool. @@ -358,6 +460,9 @@ class DescriptorPool(object): Returns: A FieldDescriptor, describing the named extension. + + Raises: + KeyError: if the extension cannot be found in the pool. """ full_name = _NormalizeFullyQualifiedName(full_name) try: @@ -374,7 +479,7 @@ class DescriptorPool(object): scope = self.FindMessageTypeByName(message_name) except KeyError: # Some extensions are defined at file scope. - scope = self.FindFileContainingSymbol(full_name) + scope = self._FindFileContainingSymbolInDb(full_name) return scope.extensions_by_name[extension_name] def FindExtensionByNumber(self, message_descriptor, number): @@ -390,7 +495,7 @@ class DescriptorPool(object): Returns: A FieldDescriptor describing the extension. - Raise: + Raises: KeyError: when no extension with the given number is known for the specified message. """ @@ -410,6 +515,46 @@ class DescriptorPool(object): """ return list(self._extensions_by_number[message_descriptor].values()) + def FindServiceByName(self, full_name): + """Loads the named service descriptor from the pool. + + Args: + full_name: The full name of the service descriptor to load. + + Returns: + The service descriptor for the named service. + + Raises: + KeyError: if the service cannot be found in the pool. + """ + full_name = _NormalizeFullyQualifiedName(full_name) + if full_name not in self._service_descriptors: + self._FindFileContainingSymbolInDb(full_name) + return self._service_descriptors[full_name] + + def _FindFileContainingSymbolInDb(self, symbol): + """Finds the file in descriptor DB containing the specified symbol. + + Args: + symbol: The name of the symbol to search for. + + Returns: + A FileDescriptor that contains the specified symbol. + + Raises: + KeyError: if the file cannot be found in the descriptor database. + """ + try: + file_proto = self._internal_db.FindFileContainingSymbol(symbol) + except KeyError as error: + if self._descriptor_db: + file_proto = self._descriptor_db.FindFileContainingSymbol(symbol) + else: + raise error + if not file_proto: + raise KeyError('Cannot find a file containing %s' % symbol) + return self._ConvertFileProtoToFileDescriptor(file_proto) + def _ConvertFileProtoToFileDescriptor(self, file_proto): """Creates a FileDescriptor from a proto or returns a cached copy. @@ -463,7 +608,8 @@ class DescriptorPool(object): for index, extension_proto in enumerate(file_proto.extension): extension_desc = self._MakeFieldDescriptor( - extension_proto, file_proto.package, index, is_extension=True) + extension_proto, file_proto.package, index, file_descriptor, + is_extension=True) extension_desc.containing_type = self._GetTypeFromScope( file_descriptor.package, extension_proto.extendee, scope) self._SetFieldType(extension_proto, extension_desc, @@ -529,10 +675,10 @@ class DescriptorPool(object): enums = [ self._ConvertEnumDescriptor(enum, desc_name, file_desc, None, scope) for enum in desc_proto.enum_type] - fields = [self._MakeFieldDescriptor(field, desc_name, index) + fields = [self._MakeFieldDescriptor(field, desc_name, index, file_desc) for index, field in enumerate(desc_proto.field)] extensions = [ - self._MakeFieldDescriptor(extension, desc_name, index, + self._MakeFieldDescriptor(extension, desc_name, index, file_desc, is_extension=True) for index, extension in enumerate(desc_proto.extension)] oneofs = [ @@ -572,6 +718,7 @@ class DescriptorPool(object): fields[field_index].containing_oneof = oneofs[oneof_index] scope[_PrefixWithDot(desc_name)] = desc + self._CheckConflictRegister(desc) self._descriptors[desc_name] = desc return desc @@ -610,11 +757,12 @@ class DescriptorPool(object): containing_type=containing_type, options=_OptionsOrNone(enum_proto)) scope['.%s' % enum_name] = desc + self._CheckConflictRegister(desc) self._enum_descriptors[enum_name] = desc return desc def _MakeFieldDescriptor(self, field_proto, message_name, index, - is_extension=False): + file_desc, is_extension=False): """Creates a field descriptor from a FieldDescriptorProto. For message and enum type fields, this method will do a look up @@ -627,6 +775,7 @@ class DescriptorPool(object): field_proto: The proto describing the field. message_name: The name of the containing message. index: Index of the field + file_desc: The file containing the field descriptor. is_extension: Indication that this field is for an extension. Returns: @@ -653,7 +802,8 @@ class DescriptorPool(object): default_value=None, is_extension=is_extension, extension_scope=None, - options=_OptionsOrNone(field_proto)) + options=_OptionsOrNone(field_proto), + file=file_desc) def _SetAllFieldTypes(self, package, desc_proto, scope): """Sets all the descriptor's fields's types. @@ -804,6 +954,8 @@ class DescriptorPool(object): methods=methods, options=_OptionsOrNone(service_proto), file=file_desc) + self._CheckConflictRegister(desc) + self._service_descriptors[service_name] = desc return desc def _MakeMethodDescriptor(self, method_proto, service_name, package, scope, |