diff options
Diffstat (limited to 'python/google/protobuf/message_factory.py')
-rw-r--r-- | python/google/protobuf/message_factory.py | 38 |
1 files changed, 21 insertions, 17 deletions
diff --git a/python/google/protobuf/message_factory.py b/python/google/protobuf/message_factory.py index 1b059d13..e4fb065e 100644 --- a/python/google/protobuf/message_factory.py +++ b/python/google/protobuf/message_factory.py @@ -66,7 +66,7 @@ class MessageFactory(object): Returns: A class describing the passed in descriptor. """ - if descriptor.full_name not in self._classes: + if descriptor not in self._classes: descriptor_name = descriptor.name if str is bytes: # PY2 descriptor_name = descriptor.name.encode('ascii', 'ignore') @@ -75,16 +75,16 @@ class MessageFactory(object): (message.Message,), {'DESCRIPTOR': descriptor, '__module__': None}) # If module not set, it wrongly points to the reflection.py module. - self._classes[descriptor.full_name] = result_class + self._classes[descriptor] = result_class for field in descriptor.fields: if field.message_type: self.GetPrototype(field.message_type) for extension in result_class.DESCRIPTOR.extensions: - if extension.containing_type.full_name not in self._classes: + if extension.containing_type not in self._classes: self.GetPrototype(extension.containing_type) - extended_class = self._classes[extension.containing_type.full_name] + extended_class = self._classes[extension.containing_type] extended_class.RegisterExtension(extension) - return self._classes[descriptor.full_name] + return self._classes[descriptor] def GetMessages(self, files): """Gets all the messages from a specified file. @@ -103,13 +103,8 @@ class MessageFactory(object): result = {} for file_name in files: file_desc = self.pool.FindFileByName(file_name) - for name, msg in file_desc.message_types_by_name.items(): - if file_desc.package: - full_name = '.'.join([file_desc.package, name]) - else: - full_name = msg.name - result[full_name] = self.GetPrototype( - self.pool.FindMessageTypeByName(full_name)) + for desc in file_desc.message_types_by_name.values(): + result[desc.full_name] = self.GetPrototype(desc) # While the extension FieldDescriptors are created by the descriptor pool, # the python classes created in the factory need them to be registered @@ -120,10 +115,10 @@ class MessageFactory(object): # ignore the registration if the original was the same, or raise # an error if they were different. - for name, extension in file_desc.extensions_by_name.items(): - if extension.containing_type.full_name not in self._classes: + for extension in file_desc.extensions_by_name.values(): + if extension.containing_type not in self._classes: self.GetPrototype(extension.containing_type) - extended_class = self._classes[extension.containing_type.full_name] + extended_class = self._classes[extension.containing_type] extended_class.RegisterExtension(extension) return result @@ -135,13 +130,22 @@ def GetMessages(file_protos): """Builds a dictionary of all the messages available in a set of files. Args: - file_protos: A sequence of file protos to build messages out of. + file_protos: Iterable of FileDescriptorProto to build messages out of. Returns: A dictionary mapping proto names to the message classes. This will include any dependent messages as well as any messages defined in the same file as a specified message. """ - for file_proto in file_protos: + # The cpp implementation of the protocol buffer library requires to add the + # message in topological order of the dependency graph. + file_by_name = {file_proto.name: file_proto for file_proto in file_protos} + def _AddFile(file_proto): + for dependency in file_proto.dependency: + if dependency in file_by_name: + # Remove from elements to be visited, in order to cut cycles. + _AddFile(file_by_name.pop(dependency)) _FACTORY.pool.Add(file_proto) + while file_by_name: + _AddFile(file_by_name.popitem()[1]) return _FACTORY.GetMessages([file_proto.name for file_proto in file_protos]) |