aboutsummaryrefslogtreecommitdiff
path: root/python/google/protobuf/descriptor_pool.py
diff options
context:
space:
mode:
Diffstat (limited to 'python/google/protobuf/descriptor_pool.py')
-rw-r--r--python/google/protobuf/descriptor_pool.py196
1 files changed, 170 insertions, 26 deletions
diff --git a/python/google/protobuf/descriptor_pool.py b/python/google/protobuf/descriptor_pool.py
index 8983f76f..42f7bcb5 100644
--- a/python/google/protobuf/descriptor_pool.py
+++ b/python/google/protobuf/descriptor_pool.py
@@ -131,33 +131,46 @@ class DescriptorPool(object):
# TODO(jieluo): Remove _file_desc_by_toplevel_extension after
# maybe year 2020 for compatibility issue (with 3.4.1 only).
self._file_desc_by_toplevel_extension = {}
+ self._top_enum_values = {}
# We store extensions in two two-level mappings: The first key is the
# descriptor of the message being extended, the second key is the extension
# full name or its tag number.
self._extensions_by_name = collections.defaultdict(dict)
self._extensions_by_number = collections.defaultdict(dict)
- def _CheckConflictRegister(self, desc):
+ def _CheckConflictRegister(self, desc, desc_name, file_name):
"""Check if the descriptor name conflicts with another of the same name.
Args:
- desc: Descriptor of a message, enum, service or extension.
+ desc: Descriptor of a message, enum, service, extension or enum value.
+ desc_name: the full name of desc.
+ file_name: The file name of descriptor.
"""
- desc_name = desc.full_name
for register, descriptor_type in [
(self._descriptors, descriptor.Descriptor),
(self._enum_descriptors, descriptor.EnumDescriptor),
(self._service_descriptors, descriptor.ServiceDescriptor),
- (self._toplevel_extensions, descriptor.FieldDescriptor)]:
+ (self._toplevel_extensions, descriptor.FieldDescriptor),
+ (self._top_enum_values, descriptor.EnumValueDescriptor)]:
if desc_name in register:
- file_name = register[desc_name].file.name
+ old_desc = register[desc_name]
+ if isinstance(old_desc, descriptor.EnumValueDescriptor):
+ old_file = old_desc.type.file.name
+ else:
+ old_file = old_desc.file.name
+
if not isinstance(desc, descriptor_type) or (
- file_name != desc.file.name):
- warn_msg = ('Conflict register for file "' + desc.file.name +
+ old_file != file_name):
+ warn_msg = ('Conflict register for file "' + file_name +
'": ' + desc_name +
' is already defined in file "' +
- file_name + '"')
+ old_file + '"')
+ if isinstance(desc, descriptor.EnumValueDescriptor):
+ warn_msg += ('\nNote: enum values appear as '
+ 'siblings of the enum type instead of '
+ 'children of it.')
warnings.warn(warn_msg, RuntimeWarning)
+
return
def Add(self, file_desc_proto):
@@ -196,7 +209,7 @@ class DescriptorPool(object):
if not isinstance(desc, descriptor.Descriptor):
raise TypeError('Expected instance of descriptor.Descriptor.')
- self._CheckConflictRegister(desc)
+ self._CheckConflictRegister(desc, desc.full_name, desc.file.name)
self._descriptors[desc.full_name] = desc
self._AddFileDescriptor(desc.file)
@@ -213,8 +226,26 @@ class DescriptorPool(object):
if not isinstance(enum_desc, descriptor.EnumDescriptor):
raise TypeError('Expected instance of descriptor.EnumDescriptor.')
- self._CheckConflictRegister(enum_desc)
+ file_name = enum_desc.file.name
+ self._CheckConflictRegister(enum_desc, enum_desc.full_name, file_name)
self._enum_descriptors[enum_desc.full_name] = enum_desc
+
+ # Top enum values need to be indexed.
+ # Count the number of dots to see whether the enum is toplevel or nested
+ # in a message. We cannot use enum_desc.containing_type at this stage.
+ if enum_desc.file.package:
+ top_level = (enum_desc.full_name.count('.')
+ - enum_desc.file.package.count('.') == 1)
+ else:
+ top_level = enum_desc.full_name.count('.') == 0
+ if top_level:
+ file_name = enum_desc.file.name
+ package = enum_desc.file.package
+ for enum_value in enum_desc.values:
+ full_name = _NormalizeFullyQualifiedName(
+ '.'.join((package, enum_value.name)))
+ self._CheckConflictRegister(enum_value, full_name, file_name)
+ self._top_enum_values[full_name] = enum_value
self._AddFileDescriptor(enum_desc.file)
def AddServiceDescriptor(self, service_desc):
@@ -227,7 +258,8 @@ class DescriptorPool(object):
if not isinstance(service_desc, descriptor.ServiceDescriptor):
raise TypeError('Expected instance of descriptor.ServiceDescriptor.')
- self._CheckConflictRegister(service_desc)
+ self._CheckConflictRegister(service_desc, service_desc.full_name,
+ service_desc.file.name)
self._service_descriptors[service_desc.full_name] = service_desc
def AddExtensionDescriptor(self, extension):
@@ -247,7 +279,6 @@ class DescriptorPool(object):
raise TypeError('Expected an extension descriptor.')
if extension.extension_scope is None:
- self._CheckConflictRegister(extension)
self._toplevel_extensions[extension.full_name] = extension
try:
@@ -349,6 +380,30 @@ class DescriptorPool(object):
symbol = _NormalizeFullyQualifiedName(symbol)
try:
+ return self._InternalFindFileContainingSymbol(symbol)
+ except KeyError:
+ pass
+
+ try:
+ # Try fallback database. Build and find again if possible.
+ self._FindFileContainingSymbolInDb(symbol)
+ return self._InternalFindFileContainingSymbol(symbol)
+ except KeyError:
+ raise KeyError('Cannot find a file containing %s' % symbol)
+
+ def _InternalFindFileContainingSymbol(self, symbol):
+ """Gets the already built FileDescriptor containing the specified symbol.
+
+ Args:
+ symbol: The name of the symbol to search for.
+
+ Returns:
+ A FileDescriptor that contains the specified symbol.
+
+ Raises:
+ KeyError: if the file cannot be found in the pool.
+ """
+ try:
return self._descriptors[symbol].file
except KeyError:
pass
@@ -364,7 +419,7 @@ class DescriptorPool(object):
pass
try:
- return self._FindFileContainingSymbolInDb(symbol)
+ return self._top_enum_values[symbol].type.file
except KeyError:
pass
@@ -373,13 +428,15 @@ class DescriptorPool(object):
except KeyError:
pass
- # Try nested extensions inside a message.
- message_name, _, extension_name = symbol.rpartition('.')
+ # Try fields, enum values and nested extensions inside a message.
+ top_name, _, sub_name = symbol.rpartition('.')
try:
- message = self.FindMessageTypeByName(message_name)
- assert message.extensions_by_name[extension_name]
+ message = self.FindMessageTypeByName(top_name)
+ assert (sub_name in message.extensions_by_name or
+ sub_name in message.fields_by_name or
+ sub_name in message.enum_values_by_name)
return message.file
- except KeyError:
+ except (KeyError, AssertionError):
raise KeyError('Cannot find a file containing %s' % symbol)
def FindMessageTypeByName(self, full_name):
@@ -499,7 +556,11 @@ class DescriptorPool(object):
KeyError: when no extension with the given number is known for the
specified message.
"""
- return self._extensions_by_number[message_descriptor][number]
+ try:
+ return self._extensions_by_number[message_descriptor][number]
+ except KeyError:
+ self._TryLoadExtensionFromDB(message_descriptor, number)
+ return self._extensions_by_number[message_descriptor][number]
def FindAllExtensions(self, message_descriptor):
"""Gets all the known extension of a given message.
@@ -513,8 +574,57 @@ class DescriptorPool(object):
Returns:
A list of FieldDescriptor describing the extensions.
"""
+ # Fallback to descriptor db if FindAllExtensionNumbers is provided.
+ if self._descriptor_db and hasattr(
+ self._descriptor_db, 'FindAllExtensionNumbers'):
+ full_name = message_descriptor.full_name
+ all_numbers = self._descriptor_db.FindAllExtensionNumbers(full_name)
+ for number in all_numbers:
+ if number in self._extensions_by_number[message_descriptor]:
+ continue
+ self._TryLoadExtensionFromDB(message_descriptor, number)
+
return list(self._extensions_by_number[message_descriptor].values())
+ def _TryLoadExtensionFromDB(self, message_descriptor, number):
+ """Try to Load extensions from decriptor db.
+
+ Args:
+ message_descriptor: descriptor of the extended message.
+ number: the extension number that needs to be loaded.
+ """
+ if not self._descriptor_db:
+ return
+ # Only supported when FindFileContainingExtension is provided.
+ if not hasattr(
+ self._descriptor_db, 'FindFileContainingExtension'):
+ return
+
+ full_name = message_descriptor.full_name
+ file_proto = self._descriptor_db.FindFileContainingExtension(
+ full_name, number)
+
+ if file_proto is None:
+ return
+
+ try:
+ file_desc = self._ConvertFileProtoToFileDescriptor(file_proto)
+ for extension in file_desc.extensions_by_name.values():
+ self._extensions_by_number[extension.containing_type][
+ extension.number] = extension
+ self._extensions_by_name[extension.containing_type][
+ extension.full_name] = extension
+ for message_type in file_desc.message_types_by_name.values():
+ for extension in message_type.extensions:
+ self._extensions_by_number[extension.containing_type][
+ extension.number] = extension
+ self._extensions_by_name[extension.containing_type][
+ extension.full_name] = extension
+ except:
+ warn_msg = ('Unable to load proto file %s for extension number %d.' %
+ (file_proto.name, number))
+ warnings.warn(warn_msg, RuntimeWarning)
+
def FindServiceByName(self, full_name):
"""Loads the named service descriptor from the pool.
@@ -532,6 +642,23 @@ class DescriptorPool(object):
self._FindFileContainingSymbolInDb(full_name)
return self._service_descriptors[full_name]
+ def FindMethodByName(self, full_name):
+ """Loads the named service method descriptor from the pool.
+
+ Args:
+ full_name: The full name of the method descriptor to load.
+
+ Returns:
+ The method descriptor for the service method.
+
+ Raises:
+ KeyError: if the method cannot be found in the pool.
+ """
+ full_name = _NormalizeFullyQualifiedName(full_name)
+ service_name, _, method_name = full_name.rpartition('.')
+ service_descriptor = self.FindServiceByName(service_name)
+ return service_descriptor.methods_by_name[method_name]
+
def _FindFileContainingSymbolInDb(self, symbol):
"""Finds the file in descriptor DB containing the specified symbol.
@@ -567,7 +694,6 @@ class DescriptorPool(object):
Returns:
A FileDescriptor matching the passed in proto.
"""
-
if file_proto.name not in self._file_descriptors:
built_deps = list(self._GetDeps(file_proto.dependency))
direct_deps = [self.FindFileByName(n) for n in file_proto.dependency]
@@ -604,7 +730,7 @@ class DescriptorPool(object):
for enum_type in file_proto.enum_type:
file_descriptor.enum_types_by_name[enum_type.name] = (
self._ConvertEnumDescriptor(enum_type, file_proto.package,
- file_descriptor, None, scope))
+ file_descriptor, None, scope, True))
for index, extension_proto in enumerate(file_proto.extension):
extension_desc = self._MakeFieldDescriptor(
@@ -616,6 +742,8 @@ class DescriptorPool(object):
file_descriptor.package, scope)
file_descriptor.extensions_by_name[extension_desc.name] = (
extension_desc)
+ self._file_desc_by_toplevel_extension[extension_desc.full_name] = (
+ file_descriptor)
for desc_proto in file_proto.message_type:
self._SetAllFieldTypes(file_proto.package, desc_proto, scope)
@@ -673,7 +801,8 @@ class DescriptorPool(object):
nested, desc_name, file_desc, scope, syntax)
for nested in desc_proto.nested_type]
enums = [
- self._ConvertEnumDescriptor(enum, desc_name, file_desc, None, scope)
+ self._ConvertEnumDescriptor(enum, desc_name, file_desc, None,
+ scope, False)
for enum in desc_proto.enum_type]
fields = [self._MakeFieldDescriptor(field, desc_name, index, file_desc)
for index, field in enumerate(desc_proto.field)]
@@ -718,12 +847,12 @@ class DescriptorPool(object):
fields[field_index].containing_oneof = oneofs[oneof_index]
scope[_PrefixWithDot(desc_name)] = desc
- self._CheckConflictRegister(desc)
+ self._CheckConflictRegister(desc, desc.full_name, desc.file.name)
self._descriptors[desc_name] = desc
return desc
def _ConvertEnumDescriptor(self, enum_proto, package=None, file_desc=None,
- containing_type=None, scope=None):
+ containing_type=None, scope=None, top_level=False):
"""Make a protobuf EnumDescriptor given an EnumDescriptorProto protobuf.
Args:
@@ -732,6 +861,8 @@ class DescriptorPool(object):
file_desc: The file containing the enum descriptor.
containing_type: The type containing this enum.
scope: Scope containing available types.
+ top_level: If True, the enum is a top level symbol. If False, the enum
+ is defined inside a message.
Returns:
The added descriptor
@@ -757,8 +888,17 @@ class DescriptorPool(object):
containing_type=containing_type,
options=_OptionsOrNone(enum_proto))
scope['.%s' % enum_name] = desc
- self._CheckConflictRegister(desc)
+ self._CheckConflictRegister(desc, desc.full_name, desc.file.name)
self._enum_descriptors[enum_name] = desc
+
+ # Add top level enum values.
+ if top_level:
+ for value in values:
+ full_name = _NormalizeFullyQualifiedName(
+ '.'.join((package, value.name)))
+ self._CheckConflictRegister(value, full_name, file_name)
+ self._top_enum_values[full_name] = value
+
return desc
def _MakeFieldDescriptor(self, field_proto, message_name, index,
@@ -885,6 +1025,8 @@ class DescriptorPool(object):
elif field_proto.type == descriptor.FieldDescriptor.TYPE_BYTES:
field_desc.default_value = text_encoding.CUnescape(
field_proto.default_value)
+ elif field_proto.type == descriptor.FieldDescriptor.TYPE_MESSAGE:
+ field_desc.default_value = None
else:
# All other types are of the "int" type.
field_desc.default_value = int(field_proto.default_value)
@@ -901,6 +1043,8 @@ class DescriptorPool(object):
field_desc.default_value = field_desc.enum_type.values[0].number
elif field_proto.type == descriptor.FieldDescriptor.TYPE_BYTES:
field_desc.default_value = b''
+ elif field_proto.type == descriptor.FieldDescriptor.TYPE_MESSAGE:
+ field_desc.default_value = None
else:
# All other types are of the "int" type.
field_desc.default_value = 0
@@ -954,7 +1098,7 @@ class DescriptorPool(object):
methods=methods,
options=_OptionsOrNone(service_proto),
file=file_desc)
- self._CheckConflictRegister(desc)
+ self._CheckConflictRegister(desc, desc.full_name, desc.file.name)
self._service_descriptors[service_name] = desc
return desc