diff options
Diffstat (limited to 'python/google/protobuf/pyext/extension_dict.cc')
-rw-r--r-- | python/google/protobuf/pyext/extension_dict.cc | 77 |
1 files changed, 58 insertions, 19 deletions
diff --git a/python/google/protobuf/pyext/extension_dict.cc b/python/google/protobuf/pyext/extension_dict.cc index dbb7bca0..9423c1d8 100644 --- a/python/google/protobuf/pyext/extension_dict.cc +++ b/python/google/protobuf/pyext/extension_dict.cc @@ -38,6 +38,7 @@ #include <google/protobuf/descriptor.h> #include <google/protobuf/dynamic_message.h> #include <google/protobuf/message.h> +#include <google/protobuf/descriptor.pb.h> #include <google/protobuf/pyext/descriptor.h> #include <google/protobuf/pyext/message.h> #include <google/protobuf/pyext/message_factory.h> @@ -46,6 +47,16 @@ #include <google/protobuf/pyext/scoped_pyobject_ptr.h> #include <google/protobuf/stubs/shared_ptr.h> +#if PY_MAJOR_VERSION >= 3 + #if PY_VERSION_HEX < 0x03030000 + #error "Python 3.0 - 3.2 are not supported." + #endif + #define PyString_AsStringAndSize(ob, charpp, sizep) \ + (PyUnicode_Check(ob)? \ + ((*(charpp) = PyUnicode_AsUTF8AndSize(ob, (sizep))) == NULL? -1: 0): \ + PyBytes_AsStringAndSize(ob, (charpp), (sizep))) +#endif + namespace google { namespace protobuf { namespace python { @@ -90,6 +101,7 @@ PyObject* subscript(ExtensionDict* self, PyObject* key) { if (descriptor->label() != FieldDescriptor::LABEL_REPEATED && descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) { + // TODO(plabatut): consider building the class on the fly! PyObject* sub_message = cmessage::InternalGetSubMessage( self->parent, descriptor); if (sub_message == NULL) { @@ -101,7 +113,17 @@ PyObject* subscript(ExtensionDict* self, PyObject* key) { if (descriptor->label() == FieldDescriptor::LABEL_REPEATED) { if (descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) { - CMessageClass* message_class = message_factory::GetMessageClass( + // On the fly message class creation is needed to support the following + // situation: + // 1- add FileDescriptor to the pool that contains extensions of a message + // defined by another proto file. Do not create any message classes. + // 2- instantiate an extended message, and access the extension using + // the field descriptor. + // 3- the extension submessage fails to be returned, because no class has + // been created. + // It happens when deserializing text proto format, or when enumerating + // fields of a deserialized message. + CMessageClass* message_class = message_factory::GetOrCreateMessageClass( cmessage::GetFactoryForMessage(self->parent), descriptor->message_type()); if (message_class == NULL) { @@ -154,34 +176,51 @@ int ass_subscript(ExtensionDict* self, PyObject* key, PyObject* value) { return 0; } -PyObject* _FindExtensionByName(ExtensionDict* self, PyObject* name) { - ScopedPyObjectPtr extensions_by_name(PyObject_GetAttrString( - reinterpret_cast<PyObject*>(self->parent), "_extensions_by_name")); - if (extensions_by_name == NULL) { +PyObject* _FindExtensionByName(ExtensionDict* self, PyObject* arg) { + char* name; + Py_ssize_t name_size; + if (PyString_AsStringAndSize(arg, &name, &name_size) < 0) { return NULL; } - PyObject* result = PyDict_GetItem(extensions_by_name.get(), name); - if (result == NULL) { + + PyDescriptorPool* pool = cmessage::GetFactoryForMessage(self->parent)->pool; + const FieldDescriptor* message_extension = + pool->pool->FindExtensionByName(string(name, name_size)); + if (message_extension == NULL) { + // Is is the name of a message set extension? + const Descriptor* message_descriptor = pool->pool->FindMessageTypeByName( + string(name, name_size)); + if (message_descriptor && message_descriptor->extension_count() > 0) { + const FieldDescriptor* extension = message_descriptor->extension(0); + if (extension->is_extension() && + extension->containing_type()->options().message_set_wire_format() && + extension->type() == FieldDescriptor::TYPE_MESSAGE && + extension->label() == FieldDescriptor::LABEL_OPTIONAL) { + message_extension = extension; + } + } + } + if (message_extension == NULL) { Py_RETURN_NONE; - } else { - Py_INCREF(result); - return result; } + + return PyFieldDescriptor_FromDescriptor(message_extension); } -PyObject* _FindExtensionByNumber(ExtensionDict* self, PyObject* number) { - ScopedPyObjectPtr extensions_by_number(PyObject_GetAttrString( - reinterpret_cast<PyObject*>(self->parent), "_extensions_by_number")); - if (extensions_by_number == NULL) { +PyObject* _FindExtensionByNumber(ExtensionDict* self, PyObject* arg) { + int64 number = PyLong_AsLong(arg); + if (number == -1 && PyErr_Occurred()) { return NULL; } - PyObject* result = PyDict_GetItem(extensions_by_number.get(), number); - if (result == NULL) { + + PyDescriptorPool* pool = cmessage::GetFactoryForMessage(self->parent)->pool; + const FieldDescriptor* message_extension = pool->pool->FindExtensionByNumber( + self->parent->message->GetDescriptor(), number); + if (message_extension == NULL) { Py_RETURN_NONE; - } else { - Py_INCREF(result); - return result; } + + return PyFieldDescriptor_FromDescriptor(message_extension); } ExtensionDict* NewExtensionDict(CMessage *parent) { |