diff options
Diffstat (limited to 'python/google/protobuf/symbol_database.py')
-rw-r--r-- | python/google/protobuf/symbol_database.py | 26 |
1 files changed, 17 insertions, 9 deletions
diff --git a/python/google/protobuf/symbol_database.py b/python/google/protobuf/symbol_database.py index 07341efa..5ad869f4 100644 --- a/python/google/protobuf/symbol_database.py +++ b/python/google/protobuf/symbol_database.py @@ -78,10 +78,18 @@ class SymbolDatabase(message_factory.MessageFactory): """ desc = message.DESCRIPTOR - self._classes[desc.full_name] = message - self.pool.AddDescriptor(desc) + self._classes[desc] = message + self.RegisterMessageDescriptor(desc) return message + def RegisterMessageDescriptor(self, message_descriptor): + """Registers the given message descriptor in the local database. + + Args: + message_descriptor: a descriptor.MessageDescriptor. + """ + self.pool.AddDescriptor(message_descriptor) + def RegisterEnumDescriptor(self, enum_descriptor): """Registers the given enum descriptor in the local database. @@ -132,7 +140,7 @@ class SymbolDatabase(message_factory.MessageFactory): KeyError: if the symbol could not be found. """ - return self._classes[symbol] + return self._classes[self.pool.FindMessageTypeByName(symbol)] def GetMessages(self, files): # TODO(amauryfa): Fix the differences with MessageFactory. @@ -153,20 +161,20 @@ class SymbolDatabase(message_factory.MessageFactory): KeyError: if a file could not be found. """ - def _GetAllMessageNames(desc): + def _GetAllMessages(desc): """Walk a message Descriptor and recursively yields all message names.""" - yield desc.full_name + yield desc for msg_desc in desc.nested_types: - for full_name in _GetAllMessageNames(msg_desc): - yield full_name + for nested_desc in _GetAllMessages(msg_desc): + yield nested_desc result = {} for file_name in files: file_desc = self.pool.FindFileByName(file_name) for msg_desc in file_desc.message_types_by_name.values(): - for full_name in _GetAllMessageNames(msg_desc): + for desc in _GetAllMessages(msg_desc): try: - result[full_name] = self._classes[full_name] + result[desc.full_name] = self._classes[desc] except KeyError: # This descriptor has no registered class, skip it. pass |