diff options
Diffstat (limited to 'python/google/protobuf/internal/descriptor_pool_test.py')
-rw-r--r-- | python/google/protobuf/internal/descriptor_pool_test.py | 91 |
1 files changed, 73 insertions, 18 deletions
diff --git a/python/google/protobuf/internal/descriptor_pool_test.py b/python/google/protobuf/internal/descriptor_pool_test.py index 1e710dcf..6015e6f8 100644 --- a/python/google/protobuf/internal/descriptor_pool_test.py +++ b/python/google/protobuf/internal/descriptor_pool_test.py @@ -63,6 +63,9 @@ from google.protobuf import symbol_database class DescriptorPoolTest(unittest.TestCase): def setUp(self): + # TODO(jieluo): Should make the pool which is created by + # serialized_pb same with generated pool. + # TODO(jieluo): More test coverage for the generated pool. self.pool = descriptor_pool.DescriptorPool() self.factory_test1_fd = descriptor_pb2.FileDescriptorProto.FromString( factory_test1_pb2.DESCRIPTOR.serialized_pb) @@ -71,6 +74,13 @@ class DescriptorPoolTest(unittest.TestCase): self.pool.Add(self.factory_test1_fd) self.pool.Add(self.factory_test2_fd) + self.pool.Add(descriptor_pb2.FileDescriptorProto.FromString( + unittest_import_public_pb2.DESCRIPTOR.serialized_pb)) + self.pool.Add(descriptor_pb2.FileDescriptorProto.FromString( + unittest_import_pb2.DESCRIPTOR.serialized_pb)) + self.pool.Add(descriptor_pb2.FileDescriptorProto.FromString( + unittest_pb2.DESCRIPTOR.serialized_pb)) + def testFindFileByName(self): name1 = 'google/protobuf/internal/factory_test1.proto' file_desc1 = self.pool.FindFileByName(name1) @@ -107,6 +117,34 @@ class DescriptorPoolTest(unittest.TestCase): self.assertEqual('google.protobuf.python.internal', file_desc2.package) self.assertIn('Factory2Message', file_desc2.message_types_by_name) + # Tests top level extension. + file_desc3 = self.pool.FindFileContainingSymbol( + 'google.protobuf.python.internal.another_field') + self.assertIsInstance(file_desc3, descriptor.FileDescriptor) + self.assertEqual('google/protobuf/internal/factory_test2.proto', + file_desc3.name) + + # Tests nested extension inside a message. + file_desc4 = self.pool.FindFileContainingSymbol( + 'google.protobuf.python.internal.Factory2Message.one_more_field') + self.assertIsInstance(file_desc4, descriptor.FileDescriptor) + self.assertEqual('google/protobuf/internal/factory_test2.proto', + file_desc4.name) + + file_desc5 = self.pool.FindFileContainingSymbol( + 'protobuf_unittest.TestService') + self.assertIsInstance(file_desc5, descriptor.FileDescriptor) + self.assertEqual('google/protobuf/unittest.proto', + file_desc5.name) + + # Tests the generated pool. + assert descriptor_pool.Default().FindFileContainingSymbol( + 'google.protobuf.python.internal.Factory2Message.one_more_field') + assert descriptor_pool.Default().FindFileContainingSymbol( + 'google.protobuf.python.internal.another_field') + assert descriptor_pool.Default().FindFileContainingSymbol( + 'protobuf_unittest.TestService') + def testFindFileContainingSymbolFailure(self): with self.assertRaises(KeyError): self.pool.FindFileContainingSymbol('Does not exist') @@ -311,6 +349,10 @@ class DescriptorPoolTest(unittest.TestCase): self.pool.FindExtensionByName( 'google.protobuf.python.internal.Factory1Message.list_value') + def testFindService(self): + service = self.pool.FindServiceByName('protobuf_unittest.TestService') + self.assertEqual(service.full_name, 'protobuf_unittest.TestService') + def testUserDefinedDB(self): db = descriptor_database.DescriptorDatabase() self.pool = descriptor_pool.DescriptorPool(db) @@ -472,10 +514,10 @@ class MessageType(object): subtype.CheckType(test, desc, name, file_desc) for index, (name, field) in enumerate(self.field_list): - field.CheckField(test, desc, name, index) + field.CheckField(test, desc, name, index, file_desc) for index, (name, field) in enumerate(self.extensions): - field.CheckField(test, desc, name, index) + field.CheckField(test, desc, name, index, file_desc) class EnumField(object): @@ -485,7 +527,7 @@ class EnumField(object): self.type_name = type_name self.default_value = default_value - def CheckField(self, test, msg_desc, name, index): + def CheckField(self, test, msg_desc, name, index, file_desc): field_desc = msg_desc.fields_by_name[name] enum_desc = msg_desc.enum_types_by_name[self.type_name] test.assertEqual(name, field_desc.name) @@ -502,6 +544,7 @@ class EnumField(object): test.assertFalse(enum_desc.values_by_name[self.default_value].has_options) test.assertEqual(msg_desc, field_desc.containing_type) test.assertEqual(enum_desc, field_desc.enum_type) + test.assertEqual(file_desc, enum_desc.file) class MessageField(object): @@ -510,7 +553,7 @@ class MessageField(object): self.number = number self.type_name = type_name - def CheckField(self, test, msg_desc, name, index): + def CheckField(self, test, msg_desc, name, index, file_desc): field_desc = msg_desc.fields_by_name[name] field_type_desc = msg_desc.nested_types_by_name[self.type_name] test.assertEqual(name, field_desc.name) @@ -524,6 +567,7 @@ class MessageField(object): test.assertFalse(field_desc.has_default_value) test.assertEqual(msg_desc, field_desc.containing_type) test.assertEqual(field_type_desc, field_desc.message_type) + test.assertEqual(file_desc, field_desc.file) class StringField(object): @@ -532,7 +576,7 @@ class StringField(object): self.number = number self.default_value = default_value - def CheckField(self, test, msg_desc, name, index): + def CheckField(self, test, msg_desc, name, index, file_desc): field_desc = msg_desc.fields_by_name[name] test.assertEqual(name, field_desc.name) expected_field_full_name = '.'.join([msg_desc.full_name, name]) @@ -544,6 +588,7 @@ class StringField(object): field_desc.cpp_type) test.assertTrue(field_desc.has_default_value) test.assertEqual(self.default_value, field_desc.default_value) + test.assertEqual(file_desc, field_desc.file) class ExtensionField(object): @@ -552,7 +597,7 @@ class ExtensionField(object): self.number = number self.extended_type = extended_type - def CheckField(self, test, msg_desc, name, index): + def CheckField(self, test, msg_desc, name, index, file_desc): field_desc = msg_desc.extensions_by_name[name] test.assertEqual(name, field_desc.name) expected_field_full_name = '.'.join([msg_desc.full_name, name]) @@ -567,6 +612,7 @@ class ExtensionField(object): test.assertEqual(msg_desc, field_desc.extension_scope) test.assertEqual(msg_desc, field_desc.message_type) test.assertEqual(self.extended_type, field_desc.containing_type.name) + test.assertEqual(file_desc, field_desc.file) class AddDescriptorTest(unittest.TestCase): @@ -645,6 +691,17 @@ class AddDescriptorTest(unittest.TestCase): @unittest.skipIf(api_implementation.Type() == 'cpp', 'With the cpp implementation, Add() must be called first') + def testService(self): + pool = descriptor_pool.DescriptorPool() + with self.assertRaises(KeyError): + pool.FindServiceByName('protobuf_unittest.TestService') + pool.AddServiceDescriptor(unittest_pb2._TESTSERVICE) + self.assertEqual( + 'protobuf_unittest.TestService', + pool.FindServiceByName('protobuf_unittest.TestService').full_name) + + @unittest.skipIf(api_implementation.Type() == 'cpp', + 'With the cpp implementation, Add() must be called first') def testFile(self): pool = descriptor_pool.DescriptorPool() pool.AddFileDescriptor(unittest_pb2.DESCRIPTOR) @@ -701,15 +758,10 @@ class AddDescriptorTest(unittest.TestCase): self.assertIs(options, file_descriptor.GetOptions()) -@unittest.skipIf( - api_implementation.Type() != 'cpp', - 'default_pool is only supported by the C++ implementation') class DefaultPoolTest(unittest.TestCase): def testFindMethods(self): - # pylint: disable=g-import-not-at-top - from google.protobuf.pyext import _message - pool = _message.default_pool + pool = descriptor_pool.Default() self.assertIs( pool.FindFileByName('google/protobuf/unittest.proto'), unittest_pb2.DESCRIPTOR) @@ -720,19 +772,22 @@ class DefaultPoolTest(unittest.TestCase): pool.FindFieldByName('protobuf_unittest.TestAllTypes.optional_int32'), unittest_pb2.TestAllTypes.DESCRIPTOR.fields_by_name['optional_int32']) self.assertIs( - pool.FindExtensionByName('protobuf_unittest.optional_int32_extension'), - unittest_pb2.DESCRIPTOR.extensions_by_name['optional_int32_extension']) - self.assertIs( pool.FindEnumTypeByName('protobuf_unittest.ForeignEnum'), unittest_pb2.ForeignEnum.DESCRIPTOR) + if api_implementation.Type() != 'cpp': + self.skipTest('Only the C++ implementation correctly indexes all types') + self.assertIs( + pool.FindExtensionByName('protobuf_unittest.optional_int32_extension'), + unittest_pb2.DESCRIPTOR.extensions_by_name['optional_int32_extension']) self.assertIs( pool.FindOneofByName('protobuf_unittest.TestAllTypes.oneof_field'), unittest_pb2.TestAllTypes.DESCRIPTOR.oneofs_by_name['oneof_field']) + self.assertIs( + pool.FindServiceByName('protobuf_unittest.TestService'), + unittest_pb2.DESCRIPTOR.services_by_name['TestService']) def testAddFileDescriptor(self): - # pylint: disable=g-import-not-at-top - from google.protobuf.pyext import _message - pool = _message.default_pool + pool = descriptor_pool.Default() file_desc = descriptor_pb2.FileDescriptorProto(name='some/file.proto') pool.Add(file_desc) pool.AddSerializedFile(file_desc.SerializeToString()) |