diff options
Diffstat (limited to 'python/google/protobuf/descriptor_pool.py')
-rw-r--r-- | python/google/protobuf/descriptor_pool.py | 99 |
1 files changed, 84 insertions, 15 deletions
diff --git a/python/google/protobuf/descriptor_pool.py b/python/google/protobuf/descriptor_pool.py index fc3a7f44..7844575f 100644 --- a/python/google/protobuf/descriptor_pool.py +++ b/python/google/protobuf/descriptor_pool.py @@ -124,6 +124,7 @@ class DescriptorPool(object): self._descriptor_db = descriptor_db self._descriptors = {} self._enum_descriptors = {} + self._service_descriptors = {} self._file_descriptors = {} self._toplevel_extensions = {} # We store extensions in two two-level mappings: The first key is the @@ -174,7 +175,7 @@ class DescriptorPool(object): def AddEnumDescriptor(self, enum_desc): """Adds an EnumDescriptor to the pool. - This method also registers the FileDescriptor associated with the message. + This method also registers the FileDescriptor associated with the enum. Args: enum_desc: An EnumDescriptor. @@ -186,6 +187,18 @@ class DescriptorPool(object): self._enum_descriptors[enum_desc.full_name] = enum_desc self.AddFileDescriptor(enum_desc.file) + def AddServiceDescriptor(self, service_desc): + """Adds a ServiceDescriptor to the pool. + + Args: + service_desc: A ServiceDescriptor. + """ + + if not isinstance(service_desc, descriptor.ServiceDescriptor): + raise TypeError('Expected instance of descriptor.ServiceDescriptor.') + + self._service_descriptors[service_desc.full_name] = service_desc + def AddExtensionDescriptor(self, extension): """Adds a FieldDescriptor describing an extension to the pool. @@ -252,7 +265,7 @@ class DescriptorPool(object): A FileDescriptor for the named file. Raises: - KeyError: if the file can not be found in the pool. + KeyError: if the file cannot be found in the pool. """ try: @@ -281,7 +294,7 @@ class DescriptorPool(object): A FileDescriptor that contains the specified symbol. Raises: - KeyError: if the file can not be found in the pool. + KeyError: if the file cannot be found in the pool. """ symbol = _NormalizeFullyQualifiedName(symbol) @@ -296,15 +309,18 @@ class DescriptorPool(object): pass try: - file_proto = self._internal_db.FindFileContainingSymbol(symbol) - except KeyError as error: - if self._descriptor_db: - file_proto = self._descriptor_db.FindFileContainingSymbol(symbol) - else: - raise error - if not file_proto: + return self._FindFileContainingSymbolInDb(symbol) + except KeyError: + pass + + # Try nested extensions inside a message. + message_name, _, extension_name = symbol.rpartition('.') + try: + scope = self.FindMessageTypeByName(message_name) + assert scope.extensions_by_name[extension_name] + return scope.file + except KeyError: raise KeyError('Cannot find a file containing %s' % symbol) - return self._ConvertFileProtoToFileDescriptor(file_proto) def FindMessageTypeByName(self, full_name): """Loads the named descriptor from the pool. @@ -314,11 +330,14 @@ class DescriptorPool(object): Returns: The descriptor for the named type. + + Raises: + KeyError: if the message cannot be found in the pool. """ full_name = _NormalizeFullyQualifiedName(full_name) if full_name not in self._descriptors: - self.FindFileContainingSymbol(full_name) + self._FindFileContainingSymbolInDb(full_name) return self._descriptors[full_name] def FindEnumTypeByName(self, full_name): @@ -329,11 +348,14 @@ class DescriptorPool(object): Returns: The enum descriptor for the named type. + + Raises: + KeyError: if the enum cannot be found in the pool. """ full_name = _NormalizeFullyQualifiedName(full_name) if full_name not in self._enum_descriptors: - self.FindFileContainingSymbol(full_name) + self._FindFileContainingSymbolInDb(full_name) return self._enum_descriptors[full_name] def FindFieldByName(self, full_name): @@ -344,6 +366,9 @@ class DescriptorPool(object): Returns: The field descriptor for the named field. + + Raises: + KeyError: if the field cannot be found in the pool. """ full_name = _NormalizeFullyQualifiedName(full_name) message_name, _, field_name = full_name.rpartition('.') @@ -358,6 +383,9 @@ class DescriptorPool(object): Returns: A FieldDescriptor, describing the named extension. + + Raises: + KeyError: if the extension cannot be found in the pool. """ full_name = _NormalizeFullyQualifiedName(full_name) try: @@ -374,7 +402,7 @@ class DescriptorPool(object): scope = self.FindMessageTypeByName(message_name) except KeyError: # Some extensions are defined at file scope. - scope = self.FindFileContainingSymbol(full_name) + scope = self._FindFileContainingSymbolInDb(full_name) return scope.extensions_by_name[extension_name] def FindExtensionByNumber(self, message_descriptor, number): @@ -390,7 +418,7 @@ class DescriptorPool(object): Returns: A FieldDescriptor describing the extension. - Raise: + Raises: KeyError: when no extension with the given number is known for the specified message. """ @@ -410,6 +438,46 @@ class DescriptorPool(object): """ return list(self._extensions_by_number[message_descriptor].values()) + def FindServiceByName(self, full_name): + """Loads the named service descriptor from the pool. + + Args: + full_name: The full name of the service descriptor to load. + + Returns: + The service descriptor for the named service. + + Raises: + KeyError: if the service cannot be found in the pool. + """ + full_name = _NormalizeFullyQualifiedName(full_name) + if full_name not in self._service_descriptors: + self._FindFileContainingSymbolInDb(full_name) + return self._service_descriptors[full_name] + + def _FindFileContainingSymbolInDb(self, symbol): + """Finds the file in descriptor DB containing the specified symbol. + + Args: + symbol: The name of the symbol to search for. + + Returns: + A FileDescriptor that contains the specified symbol. + + Raises: + KeyError: if the file cannot be found in the descriptor database. + """ + try: + file_proto = self._internal_db.FindFileContainingSymbol(symbol) + except KeyError as error: + if self._descriptor_db: + file_proto = self._descriptor_db.FindFileContainingSymbol(symbol) + else: + raise error + if not file_proto: + raise KeyError('Cannot find a file containing %s' % symbol) + return self._ConvertFileProtoToFileDescriptor(file_proto) + def _ConvertFileProtoToFileDescriptor(self, file_proto): """Creates a FileDescriptor from a proto or returns a cached copy. @@ -804,6 +872,7 @@ class DescriptorPool(object): methods=methods, options=_OptionsOrNone(service_proto), file=file_desc) + self._service_descriptors[service_name] = desc return desc def _MakeMethodDescriptor(self, method_proto, service_name, package, scope, |