diff options
Diffstat (limited to 'python/google/protobuf/symbol_database.py')
-rw-r--r-- | python/google/protobuf/symbol_database.py | 82 |
1 files changed, 33 insertions, 49 deletions
diff --git a/python/google/protobuf/symbol_database.py b/python/google/protobuf/symbol_database.py index 87760f26..aa466abd 100644 --- a/python/google/protobuf/symbol_database.py +++ b/python/google/protobuf/symbol_database.py @@ -30,11 +30,9 @@ """A database of Python protocol buffer generated symbols. -SymbolDatabase makes it easy to create new instances of a registered type, given -only the type's protocol buffer symbol name. Once all symbols are registered, -they can be accessed using either the MessageFactory interface which -SymbolDatabase exposes, or the DescriptorPool interface of the underlying -pool. +SymbolDatabase is the MessageFactory for messages generated at compile time, +and makes it easy to create new instances of a registered type, given only the +type's protocol buffer symbol name. Example usage: @@ -61,27 +59,17 @@ Example usage: from google.protobuf import descriptor_pool +from google.protobuf import message_factory -class SymbolDatabase(object): - """A database of Python generated symbols. - - SymbolDatabase also models message_factory.MessageFactory. - - The symbol database can be used to keep a global registry of all protocol - buffer types used within a program. - """ - - def __init__(self, pool=None): - """Constructor.""" - - self._symbols = {} - self._symbols_by_file = {} - self.pool = pool or descriptor_pool.Default() +class SymbolDatabase(message_factory.MessageFactory): + """A database of Python generated symbols.""" def RegisterMessage(self, message): """Registers the given message type in the local database. + Calls to GetSymbol() and GetMessages() will return messages registered here. + Args: message: a message.Message, to be registered. @@ -90,10 +78,7 @@ class SymbolDatabase(object): """ desc = message.DESCRIPTOR - self._symbols[desc.full_name] = message - if desc.file.name not in self._symbols_by_file: - self._symbols_by_file[desc.file.name] = {} - self._symbols_by_file[desc.file.name][desc.full_name] = message + self._classes[desc.full_name] = message self.pool.AddDescriptor(desc) return message @@ -136,47 +121,46 @@ class SymbolDatabase(object): KeyError: if the symbol could not be found. """ - return self._symbols[symbol] - - def GetPrototype(self, descriptor): - """Builds a proto2 message class based on the passed in descriptor. - - Passing a descriptor with a fully qualified name matching a previous - invocation will cause the same class to be returned. - - Args: - descriptor: The descriptor to build from. - - Returns: - A class describing the passed in descriptor. - """ - - return self.GetSymbol(descriptor.full_name) + return self._classes[symbol] def GetMessages(self, files): - """Gets all the messages from a specified file. - - This will find and resolve dependencies, failing if they are not registered - in the symbol database. + # TODO(amauryfa): Fix the differences with MessageFactory. + """Gets all registered messages from a specified file. + Only messages already created and registered will be returned; (this is the + case for imported _pb2 modules) + But unlike MessageFactory, this version also returns nested messages. Args: files: The file names to extract messages from. Returns: - A dictionary mapping proto names to the message classes. This will include - any dependent messages as well as any messages defined in the same file as - a specified message. + A dictionary mapping proto names to the message classes. Raises: KeyError: if a file could not be found. """ + def _GetAllMessageNames(desc): + """Walk a message Descriptor and recursively yields all message names.""" + yield desc.full_name + for msg_desc in desc.nested_types: + for full_name in _GetAllMessageNames(msg_desc): + yield full_name + result = {} - for f in files: - result.update(self._symbols_by_file[f]) + for file_name in files: + file_desc = self.pool.FindFileByName(file_name) + for msg_desc in file_desc.message_types_by_name.values(): + for full_name in _GetAllMessageNames(msg_desc): + try: + result[full_name] = self._classes[full_name] + except KeyError: + # This descriptor has no registered class, skip it. + pass return result + _DEFAULT = SymbolDatabase(pool=descriptor_pool.Default()) |