aboutsummaryrefslogtreecommitdiff
path: root/python/google/protobuf/message_factory.py
diff options
context:
space:
mode:
Diffstat (limited to 'python/google/protobuf/message_factory.py')
-rw-r--r--python/google/protobuf/message_factory.py27
1 files changed, 18 insertions, 9 deletions
diff --git a/python/google/protobuf/message_factory.py b/python/google/protobuf/message_factory.py
index 8ab1c513..e4fb065e 100644
--- a/python/google/protobuf/message_factory.py
+++ b/python/google/protobuf/message_factory.py
@@ -66,7 +66,7 @@ class MessageFactory(object):
Returns:
A class describing the passed in descriptor.
"""
- if descriptor.full_name not in self._classes:
+ if descriptor not in self._classes:
descriptor_name = descriptor.name
if str is bytes: # PY2
descriptor_name = descriptor.name.encode('ascii', 'ignore')
@@ -75,16 +75,16 @@ class MessageFactory(object):
(message.Message,),
{'DESCRIPTOR': descriptor, '__module__': None})
# If module not set, it wrongly points to the reflection.py module.
- self._classes[descriptor.full_name] = result_class
+ self._classes[descriptor] = result_class
for field in descriptor.fields:
if field.message_type:
self.GetPrototype(field.message_type)
for extension in result_class.DESCRIPTOR.extensions:
- if extension.containing_type.full_name not in self._classes:
+ if extension.containing_type not in self._classes:
self.GetPrototype(extension.containing_type)
- extended_class = self._classes[extension.containing_type.full_name]
+ extended_class = self._classes[extension.containing_type]
extended_class.RegisterExtension(extension)
- return self._classes[descriptor.full_name]
+ return self._classes[descriptor]
def GetMessages(self, files):
"""Gets all the messages from a specified file.
@@ -116,9 +116,9 @@ class MessageFactory(object):
# an error if they were different.
for extension in file_desc.extensions_by_name.values():
- if extension.containing_type.full_name not in self._classes:
+ if extension.containing_type not in self._classes:
self.GetPrototype(extension.containing_type)
- extended_class = self._classes[extension.containing_type.full_name]
+ extended_class = self._classes[extension.containing_type]
extended_class.RegisterExtension(extension)
return result
@@ -130,13 +130,22 @@ def GetMessages(file_protos):
"""Builds a dictionary of all the messages available in a set of files.
Args:
- file_protos: A sequence of file protos to build messages out of.
+ file_protos: Iterable of FileDescriptorProto to build messages out of.
Returns:
A dictionary mapping proto names to the message classes. This will include
any dependent messages as well as any messages defined in the same file as
a specified message.
"""
- for file_proto in file_protos:
+ # The cpp implementation of the protocol buffer library requires to add the
+ # message in topological order of the dependency graph.
+ file_by_name = {file_proto.name: file_proto for file_proto in file_protos}
+ def _AddFile(file_proto):
+ for dependency in file_proto.dependency:
+ if dependency in file_by_name:
+ # Remove from elements to be visited, in order to cut cycles.
+ _AddFile(file_by_name.pop(dependency))
_FACTORY.pool.Add(file_proto)
+ while file_by_name:
+ _AddFile(file_by_name.popitem()[1])
return _FACTORY.GetMessages([file_proto.name for file_proto in file_protos])