diff options
Diffstat (limited to 'python/google/protobuf/pyext/extension_dict.cc')
-rw-r--r-- | python/google/protobuf/pyext/extension_dict.cc | 68 |
1 files changed, 52 insertions, 16 deletions
diff --git a/python/google/protobuf/pyext/extension_dict.cc b/python/google/protobuf/pyext/extension_dict.cc index 9c9b4178..21bbb8c2 100644 --- a/python/google/protobuf/pyext/extension_dict.cc +++ b/python/google/protobuf/pyext/extension_dict.cc @@ -94,13 +94,13 @@ PyObject* subscript(ExtensionDict* self, PyObject* key) { if (descriptor == NULL) { return NULL; } - if (!CheckFieldBelongsToMessage(descriptor, self->parent->message)) { + if (!CheckFieldBelongsToMessage(descriptor, self->message)) { return NULL; } if (descriptor->label() != FieldDescriptor::LABEL_REPEATED && descriptor->cpp_type() != FieldDescriptor::CPPTYPE_MESSAGE) { - return cmessage::InternalGetScalar(self->parent->message, descriptor); + return cmessage::InternalGetScalar(self->message, descriptor); } PyObject* value = PyDict_GetItem(self->values, key); @@ -109,6 +109,14 @@ PyObject* subscript(ExtensionDict* self, PyObject* key) { return value; } + if (self->parent == NULL) { + // We are in "detached" state. Don't allow further modifications. + // TODO(amauryfa): Support adding non-scalars to a detached extension dict. + // This probably requires to store the type of the main message. + PyErr_SetObject(PyExc_KeyError, key); + return NULL; + } + if (descriptor->label() != FieldDescriptor::LABEL_REPEATED && descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) { PyObject* sub_message = cmessage::InternalGetSubMessage( @@ -122,7 +130,7 @@ PyObject* subscript(ExtensionDict* self, PyObject* key) { if (descriptor->label() == FieldDescriptor::LABEL_REPEATED) { if (descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) { - PyObject *message_class = cdescriptor_pool::GetMessageClass( + CMessageClass* message_class = cdescriptor_pool::GetMessageClass( cmessage::GetDescriptorPoolForMessage(self->parent), descriptor->message_type()); if (message_class == NULL) { @@ -154,7 +162,7 @@ int ass_subscript(ExtensionDict* self, PyObject* key, PyObject* value) { if (descriptor == NULL) { return -1; } - if (!CheckFieldBelongsToMessage(descriptor, self->parent->message)) { + if (!CheckFieldBelongsToMessage(descriptor, self->message)) { return -1; } @@ -164,9 +172,11 @@ int ass_subscript(ExtensionDict* self, PyObject* key, PyObject* value) { "type"); return -1; } - cmessage::AssureWritable(self->parent); - if (cmessage::InternalSetScalar(self->parent, descriptor, value) < 0) { - return -1; + if (self->parent) { + cmessage::AssureWritable(self->parent); + if (cmessage::InternalSetScalar(self->parent, descriptor, value) < 0) { + return -1; + } } // TODO(tibell): We shouldn't write scalars to the cache. PyDict_SetItem(self->values, key, value); @@ -180,15 +190,17 @@ PyObject* ClearExtension(ExtensionDict* self, PyObject* extension) { return NULL; } PyObject* value = PyDict_GetItem(self->values, extension); - if (value != NULL) { - if (ReleaseExtension(self, value, descriptor) < 0) { + if (self->parent) { + if (value != NULL) { + if (ReleaseExtension(self, value, descriptor) < 0) { + return NULL; + } + } + if (ScopedPyObjectPtr(cmessage::ClearFieldByDescriptor( + self->parent, descriptor)) == NULL) { return NULL; } } - if (ScopedPyObjectPtr(cmessage::ClearFieldByDescriptor( - self->parent, descriptor)) == NULL) { - return NULL; - } if (PyDict_DelItem(self->values, extension) < 0) { PyErr_Clear(); } @@ -201,8 +213,15 @@ PyObject* HasExtension(ExtensionDict* self, PyObject* extension) { if (descriptor == NULL) { return NULL; } - PyObject* result = cmessage::HasFieldByDescriptor(self->parent, descriptor); - return result; + if (self->parent) { + return cmessage::HasFieldByDescriptor(self->parent, descriptor); + } else { + int exists = PyDict_Contains(self->values, extension); + if (exists < 0) { + return NULL; + } + return PyBool_FromLong(exists); + } } PyObject* _FindExtensionByName(ExtensionDict* self, PyObject* name) { @@ -211,7 +230,22 @@ PyObject* _FindExtensionByName(ExtensionDict* self, PyObject* name) { if (extensions_by_name == NULL) { return NULL; } - PyObject* result = PyDict_GetItem(extensions_by_name, name); + PyObject* result = PyDict_GetItem(extensions_by_name.get(), name); + if (result == NULL) { + Py_RETURN_NONE; + } else { + Py_INCREF(result); + return result; + } +} + +PyObject* _FindExtensionByNumber(ExtensionDict* self, PyObject* number) { + ScopedPyObjectPtr extensions_by_number(PyObject_GetAttrString( + reinterpret_cast<PyObject*>(self->parent), "_extensions_by_number")); + if (extensions_by_number == NULL) { + return NULL; + } + PyObject* result = PyDict_GetItem(extensions_by_number.get(), number); if (result == NULL) { Py_RETURN_NONE; } else { @@ -252,6 +286,8 @@ static PyMethodDef Methods[] = { EDMETHOD(HasExtension, METH_O, "Checks if the object has an extension."), EDMETHOD(_FindExtensionByName, METH_O, "Finds an extension by name."), + EDMETHOD(_FindExtensionByNumber, METH_O, + "Finds an extension by field number."), { NULL, NULL } }; |