diff options
Diffstat (limited to 'python/google/protobuf/descriptor_pool.py')
-rw-r--r-- | python/google/protobuf/descriptor_pool.py | 99 |
1 files changed, 99 insertions, 0 deletions
diff --git a/python/google/protobuf/descriptor_pool.py b/python/google/protobuf/descriptor_pool.py index 5f43ee5f..28b7e843 100644 --- a/python/google/protobuf/descriptor_pool.py +++ b/python/google/protobuf/descriptor_pool.py @@ -57,6 +57,8 @@ directly instead of this class. __author__ = 'matthewtoia@google.com (Matt Toia)' +import collections + from google.protobuf import descriptor from google.protobuf import descriptor_database from google.protobuf import text_encoding @@ -88,6 +90,14 @@ def _OptionsOrNone(descriptor_proto): return None +def _IsMessageSetExtension(field): + return (field.is_extension and + field.containing_type.has_options and + field.containing_type.GetOptions().message_set_wire_format and + field.type == descriptor.FieldDescriptor.TYPE_MESSAGE and + field.label == descriptor.FieldDescriptor.LABEL_OPTIONAL) + + class DescriptorPool(object): """A collection of protobufs dynamically constructed by descriptor protos.""" @@ -115,6 +125,12 @@ class DescriptorPool(object): self._descriptors = {} self._enum_descriptors = {} self._file_descriptors = {} + self._toplevel_extensions = {} + # We store extensions in two two-level mappings: The first key is the + # descriptor of the message being extended, the second key is the extension + # full name or its tag number. + self._extensions_by_name = collections.defaultdict(dict) + self._extensions_by_number = collections.defaultdict(dict) def Add(self, file_desc_proto): """Adds the FileDescriptorProto and its types to this pool. @@ -170,6 +186,48 @@ class DescriptorPool(object): self._enum_descriptors[enum_desc.full_name] = enum_desc self.AddFileDescriptor(enum_desc.file) + def AddExtensionDescriptor(self, extension): + """Adds a FieldDescriptor describing an extension to the pool. + + Args: + extension: A FieldDescriptor. + + Raises: + AssertionError: when another extension with the same number extends the + same message. + TypeError: when the specified extension is not a + descriptor.FieldDescriptor. + """ + if not (isinstance(extension, descriptor.FieldDescriptor) and + extension.is_extension): + raise TypeError('Expected an extension descriptor.') + + if extension.extension_scope is None: + self._toplevel_extensions[extension.full_name] = extension + + try: + existing_desc = self._extensions_by_number[ + extension.containing_type][extension.number] + except KeyError: + pass + else: + if extension is not existing_desc: + raise AssertionError( + 'Extensions "%s" and "%s" both try to extend message type "%s" ' + 'with field number %d.' % + (extension.full_name, existing_desc.full_name, + extension.containing_type.full_name, extension.number)) + + self._extensions_by_number[extension.containing_type][ + extension.number] = extension + self._extensions_by_name[extension.containing_type][ + extension.full_name] = extension + + # Also register MessageSet extensions with the type name. + if _IsMessageSetExtension(extension): + self._extensions_by_name[extension.containing_type][ + extension.message_type.full_name] = extension + def AddFileDescriptor(self, file_desc): """Adds a FileDescriptor to the pool, non-recursively. @@ -302,6 +360,14 @@ class DescriptorPool(object): A FieldDescriptor, describing the named extension. """ full_name = _NormalizeFullyQualifiedName(full_name) + try: + # The proto compiler does not give any link between the FileDescriptor + # and top-level extensions unless the FileDescriptorProto is added to + # the DescriptorDatabase, but this can impact memory usage. + # So we registered these extensions by name explicitly. + return self._toplevel_extensions[full_name] + except KeyError: + pass message_name, _, extension_name = full_name.rpartition('.') try: # Most extensions are nested inside a message. @@ -311,6 +377,39 @@ class DescriptorPool(object): scope = self.FindFileContainingSymbol(full_name) return scope.extensions_by_name[extension_name] + def FindExtensionByNumber(self, message_descriptor, number): + """Gets the extension of the specified message with the specified number. + + Extensions have to be registered to this pool by calling + AddExtensionDescriptor. + + Args: + message_descriptor: descriptor of the extended message. + number: integer, number of the extension field. + + Returns: + A FieldDescriptor describing the extension. + + Raise: + KeyError: when no extension with the given number is known for the + specified message. + """ + return self._extensions_by_number[message_descriptor][number] + + def FindAllExtensions(self, message_descriptor): + """Gets all the known extension of a given message. + + Extensions have to be registered to this pool by calling + AddExtensionDescriptor. + + Args: + message_descriptor: descriptor of the extended message. + + Returns: + A list of FieldDescriptor describing the extensions. + """ + return self._extensions_by_number[message_descriptor].values() + def _ConvertFileProtoToFileDescriptor(self, file_proto): """Creates a FileDescriptor from a proto or returns a cached copy. |