diff options
Diffstat (limited to 'python/google/protobuf/message_factory.py')
-rw-r--r-- | python/google/protobuf/message_factory.py | 110 |
1 files changed, 76 insertions, 34 deletions
diff --git a/python/google/protobuf/message_factory.py b/python/google/protobuf/message_factory.py index 36e2fef0..9004ffd9 100644 --- a/python/google/protobuf/message_factory.py +++ b/python/google/protobuf/message_factory.py @@ -28,10 +28,22 @@ # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -"""Provides a factory class for generating dynamic messages.""" +#PY25 compatible for GAE. +# +# Copyright 2012 Google Inc. All Rights Reserved. + +"""Provides a factory class for generating dynamic messages. + +The easiest way to use this class is if you have access to the FileDescriptor +protos containing the messages you want to create you can just do the following: + +message_classes = message_factory.GetMessages(iterable_of_file_descriptors) +my_proto_instance = message_classes['some.proto.package.MessageName']() +""" __author__ = 'matthewtoia@google.com (Matt Toia)' +import sys ##PY25 from google.protobuf import descriptor_database from google.protobuf import descriptor_pool from google.protobuf import message @@ -41,8 +53,12 @@ from google.protobuf import reflection class MessageFactory(object): """Factory for creating Proto2 messages from descriptors in a pool.""" - def __init__(self): + def __init__(self, pool=None): """Initializes a new factory.""" + self.pool = (pool or descriptor_pool.DescriptorPool( + descriptor_database.DescriptorDatabase())) + + # local cache of all classes built from protobuf descriptors self._classes = {} def GetPrototype(self, descriptor): @@ -57,21 +73,69 @@ class MessageFactory(object): Returns: A class describing the passed in descriptor. """ - if descriptor.full_name not in self._classes: + descriptor_name = descriptor.name + if sys.version_info[0] < 3: ##PY25 +##!PY25 if str is bytes: # PY2 + descriptor_name = descriptor.name.encode('ascii', 'ignore') result_class = reflection.GeneratedProtocolMessageType( - descriptor.name.encode('ascii', 'ignore'), + descriptor_name, (message.Message,), - {'DESCRIPTOR': descriptor}) + {'DESCRIPTOR': descriptor, '__module__': None}) + # If module not set, it wrongly points to the reflection.py module. self._classes[descriptor.full_name] = result_class for field in descriptor.fields: if field.message_type: self.GetPrototype(field.message_type) + for extension in result_class.DESCRIPTOR.extensions: + if extension.containing_type.full_name not in self._classes: + self.GetPrototype(extension.containing_type) + extended_class = self._classes[extension.containing_type.full_name] + extended_class.RegisterExtension(extension) return self._classes[descriptor.full_name] + def GetMessages(self, files): + """Gets all the messages from a specified file. + + This will find and resolve dependencies, failing if the descriptor + pool cannot satisfy them. + + Args: + files: The file names to extract messages from. + + Returns: + A dictionary mapping proto names to the message classes. This will include + any dependent messages as well as any messages defined in the same file as + a specified message. + """ + result = {} + for file_name in files: + file_desc = self.pool.FindFileByName(file_name) + for name, msg in file_desc.message_types_by_name.iteritems(): + if file_desc.package: + full_name = '.'.join([file_desc.package, name]) + else: + full_name = msg.name + result[full_name] = self.GetPrototype( + self.pool.FindMessageTypeByName(full_name)) + + # While the extension FieldDescriptors are created by the descriptor pool, + # the python classes created in the factory need them to be registered + # explicitly, which is done below. + # + # The call to RegisterExtension will specifically check if the + # extension was already registered on the object and either + # ignore the registration if the original was the same, or raise + # an error if they were different. + + for name, extension in file_desc.extensions_by_name.iteritems(): + if extension.containing_type.full_name not in self._classes: + self.GetPrototype(extension.containing_type) + extended_class = self._classes[extension.containing_type.full_name] + extended_class.RegisterExtension(extension) + return result + -_DB = descriptor_database.DescriptorDatabase() -_POOL = descriptor_pool.DescriptorPool(_DB) _FACTORY = MessageFactory() @@ -82,32 +146,10 @@ def GetMessages(file_protos): file_protos: A sequence of file protos to build messages out of. Returns: - A dictionary containing all the message types in the files mapping the - fully qualified name to a Message subclass for the descriptor. + A dictionary mapping proto names to the message classes. This will include + any dependent messages as well as any messages defined in the same file as + a specified message. """ - - result = {} for file_proto in file_protos: - _DB.Add(file_proto) - for file_proto in file_protos: - for desc in _GetAllDescriptors(file_proto.message_type, file_proto.package): - result[desc.full_name] = _FACTORY.GetPrototype(desc) - return result - - -def _GetAllDescriptors(desc_protos, package): - """Gets all levels of nested message types as a flattened list of descriptors. - - Args: - desc_protos: The descriptor protos to process. - package: The package where the protos are defined. - - Yields: - Each message descriptor for each nested type. - """ - - for desc_proto in desc_protos: - name = '.'.join((package, desc_proto.name)) - yield _POOL.FindMessageTypeByName(name) - for nested_desc in _GetAllDescriptors(desc_proto.nested_type, name): - yield nested_desc + _FACTORY.pool.Add(file_proto) + return _FACTORY.GetMessages([file_proto.name for file_proto in file_protos]) |