diff options
Diffstat (limited to 'python/google/protobuf/pyext/extension_dict.cc')
-rw-r--r-- | python/google/protobuf/pyext/extension_dict.cc | 138 |
1 files changed, 62 insertions, 76 deletions
diff --git a/python/google/protobuf/pyext/extension_dict.cc b/python/google/protobuf/pyext/extension_dict.cc index 555bd293..018b5c2c 100644 --- a/python/google/protobuf/pyext/extension_dict.cc +++ b/python/google/protobuf/pyext/extension_dict.cc @@ -32,19 +32,30 @@ // Author: tibell@google.com (Johan Tibell) #include <google/protobuf/pyext/extension_dict.h> +#include <memory> #include <google/protobuf/stubs/logging.h> #include <google/protobuf/stubs/common.h> #include <google/protobuf/descriptor.h> #include <google/protobuf/dynamic_message.h> #include <google/protobuf/message.h> +#include <google/protobuf/descriptor.pb.h> #include <google/protobuf/pyext/descriptor.h> -#include <google/protobuf/pyext/descriptor_pool.h> #include <google/protobuf/pyext/message.h> +#include <google/protobuf/pyext/message_factory.h> #include <google/protobuf/pyext/repeated_composite_container.h> #include <google/protobuf/pyext/repeated_scalar_container.h> #include <google/protobuf/pyext/scoped_pyobject_ptr.h> -#include <google/protobuf/stubs/shared_ptr.h> + +#if PY_MAJOR_VERSION >= 3 + #if PY_VERSION_HEX < 0x03030000 + #error "Python 3.0 - 3.2 are not supported." + #endif + #define PyString_AsStringAndSize(ob, charpp, sizep) \ + (PyUnicode_Check(ob)? \ + ((*(charpp) = PyUnicode_AsUTF8AndSize(ob, (sizep))) == NULL? -1: 0): \ + PyBytes_AsStringAndSize(ob, (charpp), (sizep))) +#endif namespace google { namespace protobuf { @@ -60,35 +71,6 @@ PyObject* len(ExtensionDict* self) { #endif } -// TODO(tibell): Use VisitCompositeField. -int ReleaseExtension(ExtensionDict* self, - PyObject* extension, - const FieldDescriptor* descriptor) { - if (descriptor->label() == FieldDescriptor::LABEL_REPEATED) { - if (descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) { - if (repeated_composite_container::Release( - reinterpret_cast<RepeatedCompositeContainer*>( - extension)) < 0) { - return -1; - } - } else { - if (repeated_scalar_container::Release( - reinterpret_cast<RepeatedScalarContainer*>( - extension)) < 0) { - return -1; - } - } - } else if (descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) { - if (cmessage::ReleaseSubMessage( - self->parent, descriptor, - reinterpret_cast<CMessage*>(extension)) < 0) { - return -1; - } - } - - return 0; -} - PyObject* subscript(ExtensionDict* self, PyObject* key) { const FieldDescriptor* descriptor = cmessage::GetExtensionDescriptor(key); if (descriptor == NULL) { @@ -119,6 +101,7 @@ PyObject* subscript(ExtensionDict* self, PyObject* key) { if (descriptor->label() != FieldDescriptor::LABEL_REPEATED && descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) { + // TODO(plabatut): consider building the class on the fly! PyObject* sub_message = cmessage::InternalGetSubMessage( self->parent, descriptor); if (sub_message == NULL) { @@ -130,9 +113,21 @@ PyObject* subscript(ExtensionDict* self, PyObject* key) { if (descriptor->label() == FieldDescriptor::LABEL_REPEATED) { if (descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) { - PyObject *message_class = cdescriptor_pool::GetMessageClass( - cmessage::GetDescriptorPoolForMessage(self->parent), + // On the fly message class creation is needed to support the following + // situation: + // 1- add FileDescriptor to the pool that contains extensions of a message + // defined by another proto file. Do not create any message classes. + // 2- instantiate an extended message, and access the extension using + // the field descriptor. + // 3- the extension submessage fails to be returned, because no class has + // been created. + // It happens when deserializing text proto format, or when enumerating + // fields of a deserialized message. + CMessageClass* message_class = message_factory::GetOrCreateMessageClass( + cmessage::GetFactoryForMessage(self->parent), descriptor->message_type()); + ScopedPyObjectPtr message_class_handler( + reinterpret_cast<PyObject*>(message_class)); if (message_class == NULL) { return NULL; } @@ -183,60 +178,51 @@ int ass_subscript(ExtensionDict* self, PyObject* key, PyObject* value) { return 0; } -PyObject* ClearExtension(ExtensionDict* self, PyObject* extension) { - const FieldDescriptor* descriptor = - cmessage::GetExtensionDescriptor(extension); - if (descriptor == NULL) { +PyObject* _FindExtensionByName(ExtensionDict* self, PyObject* arg) { + char* name; + Py_ssize_t name_size; + if (PyString_AsStringAndSize(arg, &name, &name_size) < 0) { return NULL; } - PyObject* value = PyDict_GetItem(self->values, extension); - if (self->parent) { - if (value != NULL) { - if (ReleaseExtension(self, value, descriptor) < 0) { - return NULL; + + PyDescriptorPool* pool = cmessage::GetFactoryForMessage(self->parent)->pool; + const FieldDescriptor* message_extension = + pool->pool->FindExtensionByName(string(name, name_size)); + if (message_extension == NULL) { + // Is is the name of a message set extension? + const Descriptor* message_descriptor = pool->pool->FindMessageTypeByName( + string(name, name_size)); + if (message_descriptor && message_descriptor->extension_count() > 0) { + const FieldDescriptor* extension = message_descriptor->extension(0); + if (extension->is_extension() && + extension->containing_type()->options().message_set_wire_format() && + extension->type() == FieldDescriptor::TYPE_MESSAGE && + extension->label() == FieldDescriptor::LABEL_OPTIONAL) { + message_extension = extension; } } - if (ScopedPyObjectPtr(cmessage::ClearFieldByDescriptor( - self->parent, descriptor)) == NULL) { - return NULL; - } } - if (PyDict_DelItem(self->values, extension) < 0) { - PyErr_Clear(); + if (message_extension == NULL) { + Py_RETURN_NONE; } - Py_RETURN_NONE; -} -PyObject* HasExtension(ExtensionDict* self, PyObject* extension) { - const FieldDescriptor* descriptor = - cmessage::GetExtensionDescriptor(extension); - if (descriptor == NULL) { - return NULL; - } - if (self->parent) { - return cmessage::HasFieldByDescriptor(self->parent, descriptor); - } else { - int exists = PyDict_Contains(self->values, extension); - if (exists < 0) { - return NULL; - } - return PyBool_FromLong(exists); - } + return PyFieldDescriptor_FromDescriptor(message_extension); } -PyObject* _FindExtensionByName(ExtensionDict* self, PyObject* name) { - ScopedPyObjectPtr extensions_by_name(PyObject_GetAttrString( - reinterpret_cast<PyObject*>(self->parent), "_extensions_by_name")); - if (extensions_by_name == NULL) { +PyObject* _FindExtensionByNumber(ExtensionDict* self, PyObject* arg) { + int64 number = PyLong_AsLong(arg); + if (number == -1 && PyErr_Occurred()) { return NULL; } - PyObject* result = PyDict_GetItem(extensions_by_name.get(), name); - if (result == NULL) { + + PyDescriptorPool* pool = cmessage::GetFactoryForMessage(self->parent)->pool; + const FieldDescriptor* message_extension = pool->pool->FindExtensionByNumber( + self->parent->message->GetDescriptor(), number); + if (message_extension == NULL) { Py_RETURN_NONE; - } else { - Py_INCREF(result); - return result; } + + return PyFieldDescriptor_FromDescriptor(message_extension); } ExtensionDict* NewExtensionDict(CMessage *parent) { @@ -267,10 +253,10 @@ static PyMappingMethods MpMethods = { #define EDMETHOD(name, args, doc) { #name, (PyCFunction)name, args, doc } static PyMethodDef Methods[] = { - EDMETHOD(ClearExtension, METH_O, "Clears an extension from the object."), - EDMETHOD(HasExtension, METH_O, "Checks if the object has an extension."), EDMETHOD(_FindExtensionByName, METH_O, "Finds an extension by name."), + EDMETHOD(_FindExtensionByNumber, METH_O, + "Finds an extension by field number."), { NULL, NULL } }; |