aboutsummaryrefslogtreecommitdiff
path: root/python/google/protobuf/pyext/extension_dict.cc
diff options
context:
space:
mode:
Diffstat (limited to 'python/google/protobuf/pyext/extension_dict.cc')
-rw-r--r--python/google/protobuf/pyext/extension_dict.cc77
1 files changed, 58 insertions, 19 deletions
diff --git a/python/google/protobuf/pyext/extension_dict.cc b/python/google/protobuf/pyext/extension_dict.cc
index dbb7bca0..9423c1d8 100644
--- a/python/google/protobuf/pyext/extension_dict.cc
+++ b/python/google/protobuf/pyext/extension_dict.cc
@@ -38,6 +38,7 @@
#include <google/protobuf/descriptor.h>
#include <google/protobuf/dynamic_message.h>
#include <google/protobuf/message.h>
+#include <google/protobuf/descriptor.pb.h>
#include <google/protobuf/pyext/descriptor.h>
#include <google/protobuf/pyext/message.h>
#include <google/protobuf/pyext/message_factory.h>
@@ -46,6 +47,16 @@
#include <google/protobuf/pyext/scoped_pyobject_ptr.h>
#include <google/protobuf/stubs/shared_ptr.h>
+#if PY_MAJOR_VERSION >= 3
+ #if PY_VERSION_HEX < 0x03030000
+ #error "Python 3.0 - 3.2 are not supported."
+ #endif
+ #define PyString_AsStringAndSize(ob, charpp, sizep) \
+ (PyUnicode_Check(ob)? \
+ ((*(charpp) = PyUnicode_AsUTF8AndSize(ob, (sizep))) == NULL? -1: 0): \
+ PyBytes_AsStringAndSize(ob, (charpp), (sizep)))
+#endif
+
namespace google {
namespace protobuf {
namespace python {
@@ -90,6 +101,7 @@ PyObject* subscript(ExtensionDict* self, PyObject* key) {
if (descriptor->label() != FieldDescriptor::LABEL_REPEATED &&
descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
+ // TODO(plabatut): consider building the class on the fly!
PyObject* sub_message = cmessage::InternalGetSubMessage(
self->parent, descriptor);
if (sub_message == NULL) {
@@ -101,7 +113,17 @@ PyObject* subscript(ExtensionDict* self, PyObject* key) {
if (descriptor->label() == FieldDescriptor::LABEL_REPEATED) {
if (descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
- CMessageClass* message_class = message_factory::GetMessageClass(
+ // On the fly message class creation is needed to support the following
+ // situation:
+ // 1- add FileDescriptor to the pool that contains extensions of a message
+ // defined by another proto file. Do not create any message classes.
+ // 2- instantiate an extended message, and access the extension using
+ // the field descriptor.
+ // 3- the extension submessage fails to be returned, because no class has
+ // been created.
+ // It happens when deserializing text proto format, or when enumerating
+ // fields of a deserialized message.
+ CMessageClass* message_class = message_factory::GetOrCreateMessageClass(
cmessage::GetFactoryForMessage(self->parent),
descriptor->message_type());
if (message_class == NULL) {
@@ -154,34 +176,51 @@ int ass_subscript(ExtensionDict* self, PyObject* key, PyObject* value) {
return 0;
}
-PyObject* _FindExtensionByName(ExtensionDict* self, PyObject* name) {
- ScopedPyObjectPtr extensions_by_name(PyObject_GetAttrString(
- reinterpret_cast<PyObject*>(self->parent), "_extensions_by_name"));
- if (extensions_by_name == NULL) {
+PyObject* _FindExtensionByName(ExtensionDict* self, PyObject* arg) {
+ char* name;
+ Py_ssize_t name_size;
+ if (PyString_AsStringAndSize(arg, &name, &name_size) < 0) {
return NULL;
}
- PyObject* result = PyDict_GetItem(extensions_by_name.get(), name);
- if (result == NULL) {
+
+ PyDescriptorPool* pool = cmessage::GetFactoryForMessage(self->parent)->pool;
+ const FieldDescriptor* message_extension =
+ pool->pool->FindExtensionByName(string(name, name_size));
+ if (message_extension == NULL) {
+ // Is is the name of a message set extension?
+ const Descriptor* message_descriptor = pool->pool->FindMessageTypeByName(
+ string(name, name_size));
+ if (message_descriptor && message_descriptor->extension_count() > 0) {
+ const FieldDescriptor* extension = message_descriptor->extension(0);
+ if (extension->is_extension() &&
+ extension->containing_type()->options().message_set_wire_format() &&
+ extension->type() == FieldDescriptor::TYPE_MESSAGE &&
+ extension->label() == FieldDescriptor::LABEL_OPTIONAL) {
+ message_extension = extension;
+ }
+ }
+ }
+ if (message_extension == NULL) {
Py_RETURN_NONE;
- } else {
- Py_INCREF(result);
- return result;
}
+
+ return PyFieldDescriptor_FromDescriptor(message_extension);
}
-PyObject* _FindExtensionByNumber(ExtensionDict* self, PyObject* number) {
- ScopedPyObjectPtr extensions_by_number(PyObject_GetAttrString(
- reinterpret_cast<PyObject*>(self->parent), "_extensions_by_number"));
- if (extensions_by_number == NULL) {
+PyObject* _FindExtensionByNumber(ExtensionDict* self, PyObject* arg) {
+ int64 number = PyLong_AsLong(arg);
+ if (number == -1 && PyErr_Occurred()) {
return NULL;
}
- PyObject* result = PyDict_GetItem(extensions_by_number.get(), number);
- if (result == NULL) {
+
+ PyDescriptorPool* pool = cmessage::GetFactoryForMessage(self->parent)->pool;
+ const FieldDescriptor* message_extension = pool->pool->FindExtensionByNumber(
+ self->parent->message->GetDescriptor(), number);
+ if (message_extension == NULL) {
Py_RETURN_NONE;
- } else {
- Py_INCREF(result);
- return result;
}
+
+ return PyFieldDescriptor_FromDescriptor(message_extension);
}
ExtensionDict* NewExtensionDict(CMessage *parent) {