aboutsummaryrefslogtreecommitdiff
path: root/python/google/protobuf/symbol_database.py
diff options
context:
space:
mode:
Diffstat (limited to 'python/google/protobuf/symbol_database.py')
-rw-r--r--python/google/protobuf/symbol_database.py37
1 files changed, 28 insertions, 9 deletions
diff --git a/python/google/protobuf/symbol_database.py b/python/google/protobuf/symbol_database.py
index ecbef211..5ad869f4 100644
--- a/python/google/protobuf/symbol_database.py
+++ b/python/google/protobuf/symbol_database.py
@@ -78,10 +78,18 @@ class SymbolDatabase(message_factory.MessageFactory):
"""
desc = message.DESCRIPTOR
- self._classes[desc.full_name] = message
- self.pool.AddDescriptor(desc)
+ self._classes[desc] = message
+ self.RegisterMessageDescriptor(desc)
return message
+ def RegisterMessageDescriptor(self, message_descriptor):
+ """Registers the given message descriptor in the local database.
+
+ Args:
+ message_descriptor: a descriptor.MessageDescriptor.
+ """
+ self.pool.AddDescriptor(message_descriptor)
+
def RegisterEnumDescriptor(self, enum_descriptor):
"""Registers the given enum descriptor in the local database.
@@ -94,6 +102,17 @@ class SymbolDatabase(message_factory.MessageFactory):
self.pool.AddEnumDescriptor(enum_descriptor)
return enum_descriptor
+ def RegisterServiceDescriptor(self, service_descriptor):
+ """Registers the given service descriptor in the local database.
+
+ Args:
+ service_descriptor: a descriptor.ServiceDescriptor.
+
+ Returns:
+ The provided descriptor.
+ """
+ self.pool.AddServiceDescriptor(service_descriptor)
+
def RegisterFileDescriptor(self, file_descriptor):
"""Registers the given file descriptor in the local database.
@@ -121,7 +140,7 @@ class SymbolDatabase(message_factory.MessageFactory):
KeyError: if the symbol could not be found.
"""
- return self._classes[symbol]
+ return self._classes[self.pool.FindMessageTypeByName(symbol)]
def GetMessages(self, files):
# TODO(amauryfa): Fix the differences with MessageFactory.
@@ -142,20 +161,20 @@ class SymbolDatabase(message_factory.MessageFactory):
KeyError: if a file could not be found.
"""
- def _GetAllMessageNames(desc):
+ def _GetAllMessages(desc):
"""Walk a message Descriptor and recursively yields all message names."""
- yield desc.full_name
+ yield desc
for msg_desc in desc.nested_types:
- for full_name in _GetAllMessageNames(msg_desc):
- yield full_name
+ for nested_desc in _GetAllMessages(msg_desc):
+ yield nested_desc
result = {}
for file_name in files:
file_desc = self.pool.FindFileByName(file_name)
for msg_desc in file_desc.message_types_by_name.values():
- for full_name in _GetAllMessageNames(msg_desc):
+ for desc in _GetAllMessages(msg_desc):
try:
- result[full_name] = self._classes[full_name]
+ result[desc.full_name] = self._classes[desc]
except KeyError:
# This descriptor has no registered class, skip it.
pass