diff options
author | Jisi Liu <jisi.liu@gmail.com> | 2015-10-05 11:59:43 -0700 |
---|---|---|
committer | Jisi Liu <jisi.liu@gmail.com> | 2015-10-05 11:59:43 -0700 |
commit | 46e8ff63cb67a6520711da5317aaaef04d0414d0 (patch) | |
tree | 64370726fe469f8dfca7b14f8b8cb80b6cc856f6 /python/google/protobuf/pyext/message.cc | |
parent | 0087da9d4775f79c67362cc89c653f3a33a9bae2 (diff) | |
download | protobuf-46e8ff63cb67a6520711da5317aaaef04d0414d0.tar.gz protobuf-46e8ff63cb67a6520711da5317aaaef04d0414d0.tar.bz2 protobuf-46e8ff63cb67a6520711da5317aaaef04d0414d0.zip |
Down-integrate from google internal.
Diffstat (limited to 'python/google/protobuf/pyext/message.cc')
-rw-r--r-- | python/google/protobuf/pyext/message.cc | 153 |
1 files changed, 92 insertions, 61 deletions
diff --git a/python/google/protobuf/pyext/message.cc b/python/google/protobuf/pyext/message.cc index 62c7c478..63d53136 100644 --- a/python/google/protobuf/pyext/message.cc +++ b/python/google/protobuf/pyext/message.cc @@ -55,6 +55,7 @@ #include <google/protobuf/descriptor.h> #include <google/protobuf/message.h> #include <google/protobuf/text_format.h> +#include <google/protobuf/unknown_field_set.h> #include <google/protobuf/pyext/descriptor.h> #include <google/protobuf/pyext/descriptor_pool.h> #include <google/protobuf/pyext/extension_dict.h> @@ -107,8 +108,18 @@ struct PyMessageMeta { // C++ descriptor of this message. const Descriptor* message_descriptor; + // Owned reference, used to keep the pointer above alive. PyObject* py_message_descriptor; + + // The Python DescriptorPool used to create the class. It is needed to resolve + // fields descriptors, including extensions fields; its C++ MessageFactory is + // used to instantiate submessages. + // This can be different from DESCRIPTOR.file.pool, in the case of a custom + // DescriptorPool which defines new extensions. + // We own the reference, because it's important to keep the descriptors and + // factory alive. + PyDescriptorPool* py_descriptor_pool; }; namespace message_meta { @@ -139,18 +150,10 @@ static bool AddFieldNumberToClass( // Finalize the creation of the Message class. -// Called from its metaclass: GeneratedProtocolMessageType.__init__(). -static int AddDescriptors(PyObject* cls, PyObject* descriptor) { - const Descriptor* message_descriptor = - cdescriptor_pool::RegisterMessageClass( - GetDescriptorPool(), cls, descriptor); - if (message_descriptor == NULL) { - return -1; - } - +static int AddDescriptors(PyObject* cls, const Descriptor* descriptor) { // If there are extension_ranges, the message is "extendable", and extension // classes will register themselves in this class. - if (message_descriptor->extension_range_count() > 0) { + if (descriptor->extension_range_count() > 0) { ScopedPyObjectPtr by_name(PyDict_New()); if (PyObject_SetAttr(cls, k_extensions_by_name, by_name) < 0) { return -1; @@ -162,8 +165,8 @@ static int AddDescriptors(PyObject* cls, PyObject* descriptor) { } // For each field set: cls.<field>_FIELD_NUMBER = <number> - for (int i = 0; i < message_descriptor->field_count(); ++i) { - if (!AddFieldNumberToClass(cls, message_descriptor->field(i))) { + for (int i = 0; i < descriptor->field_count(); ++i) { + if (!AddFieldNumberToClass(cls, descriptor->field(i))) { return -1; } } @@ -173,8 +176,8 @@ static int AddDescriptors(PyObject* cls, PyObject* descriptor) { // The enum descriptor we get from // <messagedescriptor>.enum_types_by_name[name] // which was built previously. - for (int i = 0; i < message_descriptor->enum_type_count(); ++i) { - const EnumDescriptor* enum_descriptor = message_descriptor->enum_type(i); + for (int i = 0; i < descriptor->enum_type_count(); ++i) { + const EnumDescriptor* enum_descriptor = descriptor->enum_type(i); ScopedPyObjectPtr enum_type( PyEnumDescriptor_FromDescriptor(enum_descriptor)); if (enum_type == NULL) { @@ -212,8 +215,8 @@ static int AddDescriptors(PyObject* cls, PyObject* descriptor) { // Extension descriptors come from // <message descriptor>.extensions_by_name[name] // which was defined previously. - for (int i = 0; i < message_descriptor->extension_count(); ++i) { - const google::protobuf::FieldDescriptor* field = message_descriptor->extension(i); + for (int i = 0; i < descriptor->extension_count(); ++i) { + const google::protobuf::FieldDescriptor* field = descriptor->extension(i); ScopedPyObjectPtr extension_field(PyFieldDescriptor_FromDescriptor(field)); if (extension_field == NULL) { return -1; @@ -258,14 +261,14 @@ static PyObject* New(PyTypeObject* type, } // Check dict['DESCRIPTOR'] - PyObject* descriptor = PyDict_GetItem(dict, kDESCRIPTOR); - if (descriptor == NULL) { + PyObject* py_descriptor = PyDict_GetItem(dict, kDESCRIPTOR); + if (py_descriptor == NULL) { PyErr_SetString(PyExc_TypeError, "Message class has no DESCRIPTOR"); return NULL; } - if (!PyObject_TypeCheck(descriptor, &PyMessageDescriptor_Type)) { + if (!PyObject_TypeCheck(py_descriptor, &PyMessageDescriptor_Type)) { PyErr_Format(PyExc_TypeError, "Expected a message Descriptor, got %s", - descriptor->ob_type->tp_name); + py_descriptor->ob_type->tp_name); return NULL; } @@ -291,14 +294,28 @@ static PyObject* New(PyTypeObject* type, } // Cache the descriptor, both as Python object and as C++ pointer. - const Descriptor* message_descriptor = - PyMessageDescriptor_AsDescriptor(descriptor); - if (message_descriptor == NULL) { + const Descriptor* descriptor = + PyMessageDescriptor_AsDescriptor(py_descriptor); + if (descriptor == NULL) { + return NULL; + } + Py_INCREF(py_descriptor); + newtype->py_message_descriptor = py_descriptor; + newtype->message_descriptor = descriptor; + // TODO(amauryfa): Don't always use the canonical pool of the descriptor, + // use the MessageFactory optionally passed in the class dict. + newtype->py_descriptor_pool = GetDescriptorPool_FromPool( + descriptor->file()->pool()); + if (newtype->py_descriptor_pool == NULL) { + return NULL; + } + Py_INCREF(newtype->py_descriptor_pool); + + // Add the message to the DescriptorPool. + if (cdescriptor_pool::RegisterMessageClass(newtype->py_descriptor_pool, + descriptor, result) < 0) { return NULL; } - Py_INCREF(descriptor); - newtype->py_message_descriptor = descriptor; - newtype->message_descriptor = message_descriptor; // Continue with type initialization: add other descriptors, enum values... if (AddDescriptors(result, descriptor) < 0) { @@ -309,6 +326,7 @@ static PyObject* New(PyTypeObject* type, static void Dealloc(PyMessageMeta *self) { Py_DECREF(self->py_message_descriptor); + Py_DECREF(self->py_descriptor_pool); Py_TYPE(self)->tp_free(reinterpret_cast<PyObject*>(self)); } @@ -381,12 +399,20 @@ PyTypeObject PyMessageMeta_Type = { message_meta::New, // tp_new }; -static const Descriptor* GetMessageDescriptor(PyTypeObject* cls) { +static PyMessageMeta* CheckMessageClass(PyTypeObject* cls) { if (!PyObject_TypeCheck(cls, &PyMessageMeta_Type)) { PyErr_Format(PyExc_TypeError, "Class %s is not a Message", cls->tp_name); return NULL; } - return reinterpret_cast<PyMessageMeta*>(cls)->message_descriptor; + return reinterpret_cast<PyMessageMeta*>(cls); +} + +static const Descriptor* GetMessageDescriptor(PyTypeObject* cls) { + PyMessageMeta* type = CheckMessageClass(cls); + if (type == NULL) { + return NULL; + } + return type->message_descriptor; } // Forward declarations @@ -723,6 +749,17 @@ bool CheckFieldBelongsToMessage(const FieldDescriptor* field_descriptor, namespace cmessage { +PyDescriptorPool* GetDescriptorPoolForMessage(CMessage* message) { + // No need to check the type: the type of instances of CMessage is always + // an instance of PyMessageMeta. Let's prove it with a debug-only check. + GOOGLE_DCHECK(PyObject_TypeCheck(message, &CMessage_Type)); + return reinterpret_cast<PyMessageMeta*>(Py_TYPE(message))->py_descriptor_pool; +} + +MessageFactory* GetFactoryForMessage(CMessage* message) { + return GetDescriptorPoolForMessage(message)->message_factory; +} + static int MaybeReleaseOverlappingOneofField( CMessage* cmessage, const FieldDescriptor* field) { @@ -773,7 +810,7 @@ static Message* GetMutableMessage( return NULL; } return reflection->MutableMessage( - parent_message, parent_field, GetDescriptorPool()->message_factory); + parent_message, parent_field, GetFactoryForMessage(parent)); } struct FixupMessageReference : public ChildVisitor { @@ -814,10 +851,7 @@ int AssureWritable(CMessage* self) { // If parent is NULL but we are trying to modify a read-only message, this // is a reference to a constant default instance that needs to be replaced // with a mutable top-level message. - const Message* prototype = - GetDescriptorPool()->message_factory->GetPrototype( - self->message->GetDescriptor()); - self->message = prototype->New(); + self->message = self->message->New(); self->owner.reset(self->message); // Cascade the new owner to eventual children: even if this message is // empty, some submessages or repeated containers might exist already. @@ -1190,15 +1224,19 @@ CMessage* NewEmptyMessage(PyObject* type, const Descriptor *descriptor) { // The __new__ method of Message classes. // Creates a new C++ message and takes ownership. -static PyObject* New(PyTypeObject* type, +static PyObject* New(PyTypeObject* cls, PyObject* unused_args, PyObject* unused_kwargs) { + PyMessageMeta* type = CheckMessageClass(cls); + if (type == NULL) { + return NULL; + } // Retrieve the message descriptor and the default instance (=prototype). - const Descriptor* message_descriptor = GetMessageDescriptor(type); + const Descriptor* message_descriptor = type->message_descriptor; if (message_descriptor == NULL) { return NULL; } - const Message* default_message = - GetDescriptorPool()->message_factory->GetPrototype(message_descriptor); + const Message* default_message = type->py_descriptor_pool->message_factory + ->GetPrototype(message_descriptor); if (default_message == NULL) { PyErr_SetString(PyExc_TypeError, message_descriptor->full_name().c_str()); return NULL; @@ -1528,7 +1566,7 @@ int SetOwner(CMessage* self, const shared_ptr<Message>& new_owner) { Message* ReleaseMessage(CMessage* self, const Descriptor* descriptor, const FieldDescriptor* field_descriptor) { - MessageFactory* message_factory = GetDescriptorPool()->message_factory; + MessageFactory* message_factory = GetFactoryForMessage(self); Message* released_message = self->message->GetReflection()->ReleaseMessage( self->message, field_descriptor, message_factory); // ReleaseMessage will return NULL which differs from @@ -1883,8 +1921,8 @@ static PyObject* MergeFromString(CMessage* self, PyObject* arg) { AssureWritable(self); io::CodedInputStream input( reinterpret_cast<const uint8*>(data), data_length); - input.SetExtensionRegistry(GetDescriptorPool()->pool, - GetDescriptorPool()->message_factory); + PyDescriptorPool* pool = GetDescriptorPoolForMessage(self); + input.SetExtensionRegistry(pool->pool, pool->message_factory); bool success = self->message->MergePartialFromCodedStream(&input); if (success) { return PyInt_FromLong(input.CurrentPosition()); @@ -1907,11 +1945,6 @@ static PyObject* ByteSize(CMessage* self, PyObject* args) { static PyObject* RegisterExtension(PyObject* cls, PyObject* extension_handle) { - ScopedPyObjectPtr message_descriptor(PyObject_GetAttr(cls, kDESCRIPTOR)); - if (message_descriptor == NULL) { - return NULL; - } - const FieldDescriptor* descriptor = GetExtensionDescriptor(extension_handle); if (descriptor == NULL) { @@ -1920,13 +1953,6 @@ static PyObject* RegisterExtension(PyObject* cls, const Descriptor* cmessage_descriptor = GetMessageDescriptor( reinterpret_cast<PyTypeObject*>(cls)); - if (cmessage_descriptor != descriptor->containing_type()) { - if (PyObject_SetAttrString(extension_handle, "containing_type", - message_descriptor) < 0) { - return NULL; - } - } - ScopedPyObjectPtr extensions_by_name( PyObject_GetAttr(cls, k_extensions_by_name)); if (extensions_by_name == NULL) { @@ -2050,7 +2076,8 @@ static PyObject* ListFields(CMessage* self) { // TODO(amauryfa): consider building the class on the fly! if (fields[i]->message_type() != NULL && cdescriptor_pool::GetMessageClass( - GetDescriptorPool(), fields[i]->message_type()) == NULL) { + GetDescriptorPoolForMessage(self), + fields[i]->message_type()) == NULL) { PyErr_Clear(); continue; } @@ -2207,7 +2234,9 @@ PyObject* InternalGetScalar(const Message* message, message->GetReflection()->GetUnknownFields(*message); for (int i = 0; i < unknown_field_set.field_count(); ++i) { if (unknown_field_set.field(i).number() == - field_descriptor->number()) { + field_descriptor->number() && + unknown_field_set.field(i).type() == + google::protobuf::UnknownField::TYPE_VARINT) { result = PyInt_FromLong(unknown_field_set.field(i).varint()); break; } @@ -2233,11 +2262,12 @@ PyObject* InternalGetScalar(const Message* message, PyObject* InternalGetSubMessage( CMessage* self, const FieldDescriptor* field_descriptor) { const Reflection* reflection = self->message->GetReflection(); + PyDescriptorPool* pool = GetDescriptorPoolForMessage(self); const Message& sub_message = reflection->GetMessage( - *self->message, field_descriptor, GetDescriptorPool()->message_factory); + *self->message, field_descriptor, pool->message_factory); PyObject *message_class = cdescriptor_pool::GetMessageClass( - GetDescriptorPool(), field_descriptor->message_type()); + pool, field_descriptor->message_type()); if (message_class == NULL) { return NULL; } @@ -2560,7 +2590,7 @@ PyObject* GetAttr(CMessage* self, PyObject* name) { const FieldDescriptor* value_type = entry_type->FindFieldByName("value"); if (value_type->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) { PyObject* value_class = cdescriptor_pool::GetMessageClass( - GetDescriptorPool(), value_type->message_type()); + GetDescriptorPoolForMessage(self), value_type->message_type()); if (value_class == NULL) { return NULL; } @@ -2583,7 +2613,7 @@ PyObject* GetAttr(CMessage* self, PyObject* name) { PyObject* py_container = NULL; if (field_descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) { PyObject *message_class = cdescriptor_pool::GetMessageClass( - GetDescriptorPool(), field_descriptor->message_type()); + GetDescriptorPoolForMessage(self), field_descriptor->message_type()); if (message_class == NULL) { return NULL; } @@ -2908,9 +2938,10 @@ bool InitProto2MessageModule(PyObject *m) { // Expose the DescriptorPool used to hold all descriptors added from generated // pb2.py files. - Py_INCREF(GetDescriptorPool()); // PyModule_AddObject steals a reference. - PyModule_AddObject( - m, "default_pool", reinterpret_cast<PyObject*>(GetDescriptorPool())); + // PyModule_AddObject steals a reference. + Py_INCREF(GetDefaultDescriptorPool()); + PyModule_AddObject(m, "default_pool", + reinterpret_cast<PyObject*>(GetDefaultDescriptorPool())); // This implementation provides full Descriptor types, we advertise it so that // descriptor.py can use them in replacement of the Python classes. |