diff options
Diffstat (limited to 'python/google/protobuf/pyext/message.cc')
-rw-r--r-- | python/google/protobuf/pyext/message.cc | 212 |
1 files changed, 86 insertions, 126 deletions
diff --git a/python/google/protobuf/pyext/message.cc b/python/google/protobuf/pyext/message.cc index 5535338d..1b325469 100644 --- a/python/google/protobuf/pyext/message.cc +++ b/python/google/protobuf/pyext/message.cc @@ -63,6 +63,7 @@ #include <google/protobuf/pyext/repeated_composite_container.h> #include <google/protobuf/pyext/repeated_scalar_container.h> #include <google/protobuf/pyext/map_container.h> +#include <google/protobuf/pyext/message_factory.h> #include <google/protobuf/pyext/scoped_pyobject_ptr.h> #include <google/protobuf/stubs/strutil.h> @@ -244,6 +245,12 @@ static PyObject* New(PyTypeObject* type, return NULL; } + // Messages have no __dict__ + ScopedPyObjectPtr slots(PyTuple_New(0)); + if (PyDict_SetItemString(dict, "__slots__", slots.get()) < 0) { + return NULL; + } + // Build the arguments to the base metaclass. // We change the __bases__ classes. ScopedPyObjectPtr new_args; @@ -300,16 +307,19 @@ static PyObject* New(PyTypeObject* type, newtype->message_descriptor = descriptor; // TODO(amauryfa): Don't always use the canonical pool of the descriptor, // use the MessageFactory optionally passed in the class dict. - newtype->py_descriptor_pool = GetDescriptorPool_FromPool( - descriptor->file()->pool()); - if (newtype->py_descriptor_pool == NULL) { + PyDescriptorPool* py_descriptor_pool = + GetDescriptorPool_FromPool(descriptor->file()->pool()); + if (py_descriptor_pool == NULL) { return NULL; } - Py_INCREF(newtype->py_descriptor_pool); + newtype->py_message_factory = py_descriptor_pool->py_message_factory; + Py_INCREF(newtype->py_message_factory); - // Add the message to the DescriptorPool. - if (cdescriptor_pool::RegisterMessageClass(newtype->py_descriptor_pool, - descriptor, newtype) < 0) { + // Register the message in the MessageFactory. + // TODO(amauryfa): Move this call to MessageFactory.GetPrototype() when the + // MessageFactory is fully implemented in C++. + if (message_factory::RegisterMessageClass(newtype->py_message_factory, + descriptor, newtype) < 0) { return NULL; } @@ -321,8 +331,8 @@ static PyObject* New(PyTypeObject* type, } static void Dealloc(CMessageClass *self) { - Py_DECREF(self->py_message_descriptor); - Py_DECREF(self->py_descriptor_pool); + Py_XDECREF(self->py_message_descriptor); + Py_XDECREF(self->py_message_factory); Py_TYPE(self)->tp_free(reinterpret_cast<PyObject*>(self)); } @@ -752,15 +762,9 @@ bool CheckFieldBelongsToMessage(const FieldDescriptor* field_descriptor, namespace cmessage { -PyDescriptorPool* GetDescriptorPoolForMessage(CMessage* message) { - // No need to check the type: the type of instances of CMessage is always - // an instance of CMessageClass. Let's prove it with a debug-only check. +PyMessageFactory* GetFactoryForMessage(CMessage* message) { GOOGLE_DCHECK(PyObject_TypeCheck(message, &CMessage_Type)); - return reinterpret_cast<CMessageClass*>(Py_TYPE(message))->py_descriptor_pool; -} - -MessageFactory* GetFactoryForMessage(CMessage* message) { - return GetDescriptorPoolForMessage(message)->message_factory; + return reinterpret_cast<CMessageClass*>(Py_TYPE(message))->py_message_factory; } static int MaybeReleaseOverlappingOneofField( @@ -813,7 +817,8 @@ static Message* GetMutableMessage( return NULL; } return reflection->MutableMessage( - parent_message, parent_field, GetFactoryForMessage(parent)); + parent_message, parent_field, + GetFactoryForMessage(parent)->message_factory); } struct FixupMessageReference : public ChildVisitor { @@ -1172,6 +1177,8 @@ int InitAttributes(CMessage* self, PyObject* args, PyObject* kwargs) { } CMessage* cmessage = reinterpret_cast<CMessage*>(message.get()); if (PyDict_Check(value)) { + // Make the message exist even if the dict is empty. + AssureWritable(cmessage); if (InitAttributes(cmessage, NULL, value) < 0) { return -1; } @@ -1231,7 +1238,7 @@ static PyObject* New(PyTypeObject* cls, if (message_descriptor == NULL) { return NULL; } - const Message* default_message = type->py_descriptor_pool->message_factory + const Message* default_message = type->py_message_factory->message_factory ->GetPrototype(message_descriptor); if (default_message == NULL) { PyErr_SetString(PyExc_TypeError, message_descriptor->full_name().c_str()); @@ -1292,6 +1299,9 @@ struct ClearWeakReferences : public ChildVisitor { }; static void Dealloc(CMessage* self) { + if (self->weakreflist) { + PyObject_ClearWeakRefs(reinterpret_cast<PyObject*>(self)); + } // Null out all weak references from children to this message. GOOGLE_CHECK_EQ(0, ForEachCompositeField(self, ClearWeakReferences())); if (self->extensions) { @@ -1459,18 +1469,20 @@ PyObject* HasField(CMessage* self, PyObject* arg) { } PyObject* ClearExtension(CMessage* self, PyObject* extension) { + const FieldDescriptor* descriptor = GetExtensionDescriptor(extension); + if (descriptor == NULL) { + return NULL; + } if (self->extensions != NULL) { - return extension_dict::ClearExtension(self->extensions, extension); - } else { - const FieldDescriptor* descriptor = GetExtensionDescriptor(extension); - if (descriptor == NULL) { - return NULL; - } - if (ScopedPyObjectPtr(ClearFieldByDescriptor(self, descriptor)) == NULL) { - return NULL; + PyObject* value = PyDict_GetItem(self->extensions->values, extension); + if (value != NULL) { + if (InternalReleaseFieldByDescriptor(self, descriptor, value) < 0) { + return NULL; + } + PyDict_DelItem(self->extensions->values, extension); } } - Py_RETURN_NONE; + return ClearFieldByDescriptor(self, descriptor); } PyObject* HasExtension(CMessage* self, PyObject* extension) { @@ -1556,7 +1568,7 @@ int SetOwner(CMessage* self, const shared_ptr<Message>& new_owner) { Message* ReleaseMessage(CMessage* self, const Descriptor* descriptor, const FieldDescriptor* field_descriptor) { - MessageFactory* message_factory = GetFactoryForMessage(self); + MessageFactory* message_factory = GetFactoryForMessage(self)->message_factory; Message* released_message = self->message->GetReflection()->ReleaseMessage( self->message, field_descriptor, message_factory); // ReleaseMessage will return NULL which differs from @@ -1624,12 +1636,19 @@ int InternalReleaseFieldByDescriptor( PyObject* ClearFieldByDescriptor( CMessage* self, - const FieldDescriptor* descriptor) { - if (!CheckFieldBelongsToMessage(descriptor, self->message)) { + const FieldDescriptor* field_descriptor) { + if (!CheckFieldBelongsToMessage(field_descriptor, self->message)) { return NULL; } AssureWritable(self); - self->message->GetReflection()->ClearField(self->message, descriptor); + Message* message = self->message; + message->GetReflection()->ClearField(message, field_descriptor); + if (field_descriptor->cpp_type() == FieldDescriptor::CPPTYPE_ENUM && + !message->GetReflection()->SupportsUnknownEnumValues()) { + UnknownFieldSet* unknown_field_set = + message->GetReflection()->MutableUnknownFields(message); + unknown_field_set->DeleteByNumber(field_descriptor->number()); + } Py_RETURN_NONE; } @@ -1665,27 +1684,17 @@ PyObject* ClearField(CMessage* self, PyObject* arg) { arg = arg_in_oneof.get(); } - PyObject* composite_field = self->composite_fields ? - PyDict_GetItem(self->composite_fields, arg) : NULL; - - // Only release the field if there's a possibility that there are - // references to it. - if (composite_field != NULL) { - if (InternalReleaseFieldByDescriptor(self, field_descriptor, - composite_field) < 0) { - return NULL; + // Release the field if it exists in the dict of composite fields. + if (self->composite_fields) { + PyObject* value = PyDict_GetItem(self->composite_fields, arg); + if (value != NULL) { + if (InternalReleaseFieldByDescriptor(self, field_descriptor, value) < 0) { + return NULL; + } + PyDict_DelItem(self->composite_fields, arg); } - PyDict_DelItem(self->composite_fields, arg); - } - message->GetReflection()->ClearField(message, field_descriptor); - if (field_descriptor->cpp_type() == FieldDescriptor::CPPTYPE_ENUM && - !message->GetReflection()->SupportsUnknownEnumValues()) { - UnknownFieldSet* unknown_field_set = - message->GetReflection()->MutableUnknownFields(message); - unknown_field_set->DeleteByNumber(field_descriptor->number()); } - - Py_RETURN_NONE; + return ClearFieldByDescriptor(self, field_descriptor); } PyObject* Clear(CMessage* self) { @@ -1927,8 +1936,8 @@ static PyObject* MergeFromString(CMessage* self, PyObject* arg) { if (allow_oversize_protos) { input.SetTotalBytesLimit(INT_MAX, INT_MAX); } - PyDescriptorPool* pool = GetDescriptorPoolForMessage(self); - input.SetExtensionRegistry(pool->pool, pool->message_factory); + PyMessageFactory* factory = GetFactoryForMessage(self); + input.SetExtensionRegistry(factory->pool->pool, factory->message_factory); bool success = self->message->MergePartialFromCodedStream(&input); if (success) { return PyInt_FromLong(input.CurrentPosition()); @@ -2108,8 +2117,8 @@ static PyObject* ListFields(CMessage* self) { // is no message class and we cannot retrieve the value. // TODO(amauryfa): consider building the class on the fly! if (fields[i]->message_type() != NULL && - cdescriptor_pool::GetMessageClass( - GetDescriptorPoolForMessage(self), + message_factory::GetMessageClass( + GetFactoryForMessage(self), fields[i]->message_type()) == NULL) { PyErr_Clear(); continue; @@ -2306,12 +2315,12 @@ PyObject* InternalGetScalar(const Message* message, PyObject* InternalGetSubMessage( CMessage* self, const FieldDescriptor* field_descriptor) { const Reflection* reflection = self->message->GetReflection(); - PyDescriptorPool* pool = GetDescriptorPoolForMessage(self); + PyMessageFactory* factory = GetFactoryForMessage(self); const Message& sub_message = reflection->GetMessage( - *self->message, field_descriptor, pool->message_factory); + *self->message, field_descriptor, factory->message_factory); - CMessageClass* message_class = cdescriptor_pool::GetMessageClass( - pool, field_descriptor->message_type()); + CMessageClass* message_class = message_factory::GetMessageClass( + factory, field_descriptor->message_type()); if (message_class == NULL) { return NULL; } @@ -2656,8 +2665,8 @@ PyObject* GetAttr(CMessage* self, PyObject* name) { const Descriptor* entry_type = field_descriptor->message_type(); const FieldDescriptor* value_type = entry_type->FindFieldByName("value"); if (value_type->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) { - CMessageClass* value_class = cdescriptor_pool::GetMessageClass( - GetDescriptorPoolForMessage(self), value_type->message_type()); + CMessageClass* value_class = message_factory::GetMessageClass( + GetFactoryForMessage(self), value_type->message_type()); if (value_class == NULL) { return NULL; } @@ -2679,8 +2688,8 @@ PyObject* GetAttr(CMessage* self, PyObject* name) { if (field_descriptor->label() == FieldDescriptor::LABEL_REPEATED) { PyObject* py_container = NULL; if (field_descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) { - CMessageClass* message_class = cdescriptor_pool::GetMessageClass( - GetDescriptorPoolForMessage(self), field_descriptor->message_type()); + CMessageClass* message_class = message_factory::GetMessageClass( + GetFactoryForMessage(self), field_descriptor->message_type()); if (message_class == NULL) { return NULL; } @@ -2775,7 +2784,7 @@ PyTypeObject CMessage_Type = { 0, // tp_traverse 0, // tp_clear (richcmpfunc)cmessage::RichCompare, // tp_richcompare - 0, // tp_weaklistoffset + offsetof(CMessage, weakreflist), // tp_weaklistoffset 0, // tp_iter 0, // tp_iternext cmessage::Methods, // tp_methods @@ -2863,6 +2872,11 @@ bool InitProto2MessageModule(PyObject *m) { return false; } + // Initialize types and globals in message_factory.cc + if (!InitMessageFactory()) { + return false; + } + // Initialize constants defined in this file. InitGlobals(); @@ -2944,69 +2958,15 @@ bool InitProto2MessageModule(PyObject *m) { } // Initialize Map container types. - { - // ScalarMapContainer_Type derives from our MutableMapping type. - ScopedPyObjectPtr containers(PyImport_ImportModule( - "google.protobuf.internal.containers")); - if (containers == NULL) { - return false; - } - - ScopedPyObjectPtr mutable_mapping( - PyObject_GetAttrString(containers.get(), "MutableMapping")); - if (mutable_mapping == NULL) { - return false; - } - - if (!PyObject_TypeCheck(mutable_mapping.get(), &PyType_Type)) { - return false; - } - - Py_INCREF(mutable_mapping.get()); -#if PY_MAJOR_VERSION >= 3 - PyObject* bases = PyTuple_New(1); - PyTuple_SET_ITEM(bases, 0, mutable_mapping.get()); - - ScalarMapContainer_Type = - PyType_FromSpecWithBases(&ScalarMapContainer_Type_spec, bases); - PyModule_AddObject(m, "ScalarMapContainer", ScalarMapContainer_Type); -#else - ScalarMapContainer_Type.tp_base = - reinterpret_cast<PyTypeObject*>(mutable_mapping.get()); - - if (PyType_Ready(&ScalarMapContainer_Type) < 0) { - return false; - } - - PyModule_AddObject(m, "ScalarMapContainer", - reinterpret_cast<PyObject*>(&ScalarMapContainer_Type)); -#endif - - if (PyType_Ready(&MapIterator_Type) < 0) { - return false; - } - - PyModule_AddObject(m, "MapIterator", - reinterpret_cast<PyObject*>(&MapIterator_Type)); - - -#if PY_MAJOR_VERSION >= 3 - MessageMapContainer_Type = - PyType_FromSpecWithBases(&MessageMapContainer_Type_spec, bases); - PyModule_AddObject(m, "MessageMapContainer", MessageMapContainer_Type); -#else - Py_INCREF(mutable_mapping.get()); - MessageMapContainer_Type.tp_base = - reinterpret_cast<PyTypeObject*>(mutable_mapping.get()); - - if (PyType_Ready(&MessageMapContainer_Type) < 0) { - return false; - } - - PyModule_AddObject(m, "MessageMapContainer", - reinterpret_cast<PyObject*>(&MessageMapContainer_Type)); -#endif + if (!InitMapContainers()) { + return false; } + PyModule_AddObject(m, "ScalarMapContainer", + reinterpret_cast<PyObject*>(ScalarMapContainer_Type)); + PyModule_AddObject(m, "MessageMapContainer", + reinterpret_cast<PyObject*>(MessageMapContainer_Type)); + PyModule_AddObject(m, "MapIterator", + reinterpret_cast<PyObject*>(&MapIterator_Type)); if (PyType_Ready(&ExtensionDict_Type) < 0) { return false; |