diff options
author | Feng Xiao <xfxyjwf@gmail.com> | 2015-12-11 17:09:20 -0800 |
---|---|---|
committer | Feng Xiao <xfxyjwf@gmail.com> | 2015-12-11 17:10:28 -0800 |
commit | e841bac4fcf47f809e089a70d5f84ac37b3883df (patch) | |
tree | d25dc5fc814db182c04c5f276ff1a609c5965a5a /python/google/protobuf/pyext/message.cc | |
parent | 99a6a95c751a28a3cc33dd2384959179f83f682c (diff) | |
download | protobuf-e841bac4fcf47f809e089a70d5f84ac37b3883df.tar.gz protobuf-e841bac4fcf47f809e089a70d5f84ac37b3883df.tar.bz2 protobuf-e841bac4fcf47f809e089a70d5f84ac37b3883df.zip |
Down-integrate from internal code base.
Diffstat (limited to 'python/google/protobuf/pyext/message.cc')
-rw-r--r-- | python/google/protobuf/pyext/message.cc | 299 |
1 files changed, 163 insertions, 136 deletions
diff --git a/python/google/protobuf/pyext/message.cc b/python/google/protobuf/pyext/message.cc index 72f51ec1..863cde01 100644 --- a/python/google/protobuf/pyext/message.cc +++ b/python/google/protobuf/pyext/message.cc @@ -33,6 +33,7 @@ #include <google/protobuf/pyext/message.h> +#include <map> #include <memory> #ifndef _SHARED_PTR_H #include <google/protobuf/stubs/shared_ptr.h> @@ -61,8 +62,7 @@ #include <google/protobuf/pyext/extension_dict.h> #include <google/protobuf/pyext/repeated_composite_container.h> #include <google/protobuf/pyext/repeated_scalar_container.h> -#include <google/protobuf/pyext/message_map_container.h> -#include <google/protobuf/pyext/scalar_map_container.h> +#include <google/protobuf/pyext/map_container.h> #include <google/protobuf/pyext/scoped_pyobject_ptr.h> #include <google/protobuf/stubs/strutil.h> @@ -96,6 +96,7 @@ static PyObject* k_extensions_by_number; PyObject* EnumTypeWrapper_class; static PyObject* PythonMessage_class; static PyObject* kEmptyWeakref; +static PyObject* WKT_classes = NULL; // Defines the Metaclass of all Message classes. // It allows us to cache some C++ pointers in the class object itself, they are @@ -274,8 +275,32 @@ static PyObject* New(PyTypeObject* type, // Build the arguments to the base metaclass. // We change the __bases__ classes. - ScopedPyObjectPtr new_args(Py_BuildValue( - "s(OO)O", name, &CMessage_Type, PythonMessage_class, dict)); + ScopedPyObjectPtr new_args; + const Descriptor* message_descriptor = + PyMessageDescriptor_AsDescriptor(py_descriptor); + if (message_descriptor == NULL) { + return NULL; + } + + if (WKT_classes == NULL) { + ScopedPyObjectPtr well_known_types(PyImport_ImportModule( + "google.protobuf.internal.well_known_types")); + GOOGLE_DCHECK(well_known_types != NULL); + + WKT_classes = PyObject_GetAttrString(well_known_types.get(), "WKTBASES"); + GOOGLE_DCHECK(WKT_classes != NULL); + } + + PyObject* well_known_class = PyDict_GetItemString( + WKT_classes, message_descriptor->full_name().c_str()); + if (well_known_class == NULL) { + new_args.reset(Py_BuildValue("s(OO)O", name, &CMessage_Type, + PythonMessage_class, dict)); + } else { + new_args.reset(Py_BuildValue("s(OOO)O", name, &CMessage_Type, + PythonMessage_class, well_known_class, dict)); + } + if (new_args == NULL) { return NULL; } @@ -448,21 +473,9 @@ static int VisitCompositeField(const FieldDescriptor* descriptor, if (descriptor->label() == FieldDescriptor::LABEL_REPEATED) { if (descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) { if (descriptor->is_map()) { - const Descriptor* entry_type = descriptor->message_type(); - const FieldDescriptor* value_type = - entry_type->FindFieldByName("value"); - if (value_type->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) { - MessageMapContainer* container = - reinterpret_cast<MessageMapContainer*>(child); - if (visitor.VisitMessageMapContainer(container) == -1) { - return -1; - } - } else { - ScalarMapContainer* container = - reinterpret_cast<ScalarMapContainer*>(child); - if (visitor.VisitScalarMapContainer(container) == -1) { - return -1; - } + MapContainer* container = reinterpret_cast<MapContainer*>(child); + if (visitor.VisitMapContainer(container) == -1) { + return -1; } } else { RepeatedCompositeContainer* container = @@ -579,12 +592,14 @@ bool CheckAndGetInteger( if (PyObject_RichCompareBool(min, arg, Py_LE) != 1 || PyObject_RichCompareBool(max, arg, Py_GE) != 1) { #endif - PyObject *s = PyObject_Str(arg); - if (s) { - PyErr_Format(PyExc_ValueError, - "Value out of range: %s", - PyString_AsString(s)); - Py_DECREF(s); + if (!PyErr_Occurred()) { + PyObject *s = PyObject_Str(arg); + if (s) { + PyErr_Format(PyExc_ValueError, + "Value out of range: %s", + PyString_AsString(s)); + Py_DECREF(s); + } } return false; } @@ -642,38 +657,51 @@ bool CheckAndGetBool(PyObject* arg, bool* value) { return true; } -bool CheckAndSetString( - PyObject* arg, Message* message, - const FieldDescriptor* descriptor, - const Reflection* reflection, - bool append, - int index) { +// Checks whether the given object (which must be "bytes" or "unicode") contains +// valid UTF-8. +bool IsValidUTF8(PyObject* obj) { + if (PyBytes_Check(obj)) { + PyObject* unicode = PyUnicode_FromEncodedObject(obj, "utf-8", NULL); + + // Clear the error indicator; we report our own error when desired. + PyErr_Clear(); + + if (unicode) { + Py_DECREF(unicode); + return true; + } else { + return false; + } + } else { + // Unicode object, known to be valid UTF-8. + return true; + } +} + +bool AllowInvalidUTF8(const FieldDescriptor* field) { return false; } + +PyObject* CheckString(PyObject* arg, const FieldDescriptor* descriptor) { GOOGLE_DCHECK(descriptor->type() == FieldDescriptor::TYPE_STRING || descriptor->type() == FieldDescriptor::TYPE_BYTES); if (descriptor->type() == FieldDescriptor::TYPE_STRING) { if (!PyBytes_Check(arg) && !PyUnicode_Check(arg)) { FormatTypeError(arg, "bytes, unicode"); - return false; + return NULL; } - if (PyBytes_Check(arg)) { - PyObject* unicode = PyUnicode_FromEncodedObject(arg, "utf-8", NULL); - if (unicode == NULL) { - PyObject* repr = PyObject_Repr(arg); - PyErr_Format(PyExc_ValueError, - "%s has type str, but isn't valid UTF-8 " - "encoding. Non-UTF-8 strings must be converted to " - "unicode objects before being added.", - PyString_AsString(repr)); - Py_DECREF(repr); - return false; - } else { - Py_DECREF(unicode); - } + if (!IsValidUTF8(arg) && !AllowInvalidUTF8(descriptor)) { + PyObject* repr = PyObject_Repr(arg); + PyErr_Format(PyExc_ValueError, + "%s has type str, but isn't valid UTF-8 " + "encoding. Non-UTF-8 strings must be converted to " + "unicode objects before being added.", + PyString_AsString(repr)); + Py_DECREF(repr); + return NULL; } } else if (!PyBytes_Check(arg)) { FormatTypeError(arg, "bytes"); - return false; + return NULL; } PyObject* encoded_string = NULL; @@ -691,14 +719,24 @@ bool CheckAndSetString( Py_INCREF(encoded_string); } - if (encoded_string == NULL) { + return encoded_string; +} + +bool CheckAndSetString( + PyObject* arg, Message* message, + const FieldDescriptor* descriptor, + const Reflection* reflection, + bool append, + int index) { + ScopedPyObjectPtr encoded_string(CheckString(arg, descriptor)); + + if (encoded_string.get() == NULL) { return false; } char* value; Py_ssize_t value_len; - if (PyBytes_AsStringAndSize(encoded_string, &value, &value_len) < 0) { - Py_DECREF(encoded_string); + if (PyBytes_AsStringAndSize(encoded_string.get(), &value, &value_len) < 0) { return false; } @@ -710,7 +748,6 @@ bool CheckAndSetString( } else { reflection->SetRepeatedString(message, descriptor, index, value_string); } - Py_DECREF(encoded_string); return true; } @@ -823,12 +860,7 @@ struct FixupMessageReference : public ChildVisitor { return 0; } - int VisitScalarMapContainer(ScalarMapContainer* container) { - container->message = message_; - return 0; - } - - int VisitMessageMapContainer(MessageMapContainer* container) { + int VisitMapContainer(MapContainer* container) { container->message = message_; return 0; } @@ -870,9 +902,8 @@ int AssureWritable(CMessage* self) { // When a CMessage is made writable its Message pointer is updated // to point to a new mutable Message. When that happens we need to // update any references to the old, read-only CMessage. There are - // five places such references occur: RepeatedScalarContainer, - // RepeatedCompositeContainer, ScalarMapContainer, MessageMapContainer, - // and ExtensionDict. + // four places such references occur: RepeatedScalarContainer, + // RepeatedCompositeContainer, MapContainer, and ExtensionDict. if (self->extensions != NULL) self->extensions->message = self->message; if (ForEachCompositeField(self, FixupMessageReference(self->message)) == -1) @@ -1054,7 +1085,8 @@ int InitAttributes(CMessage* self, PyObject* kwargs) { } const FieldDescriptor* descriptor = GetFieldDescriptor(self, name); if (descriptor == NULL) { - PyErr_Format(PyExc_ValueError, "Protocol message has no \"%s\" field.", + PyErr_Format(PyExc_ValueError, "Protocol message %s has no \"%s\" field.", + self->message->GetDescriptor()->name().c_str(), PyString_AsString(name)); return -1; } @@ -1203,18 +1235,6 @@ CMessage* NewEmptyMessage(PyObject* type, const Descriptor *descriptor) { self->composite_fields = NULL; - // If there are extension_ranges, the message is "extendable". Allocate a - // dictionary to store the extension fields. - if (descriptor->extension_range_count() > 0) { - // TODO(amauryfa): Delay the construction of this dict until extensions are - // really used on the object. - ExtensionDict* extension_dict = extension_dict::NewExtensionDict(self); - if (extension_dict == NULL) { - return NULL; - } - self->extensions = extension_dict; - } - return self; } @@ -1285,12 +1305,7 @@ struct ClearWeakReferences : public ChildVisitor { return 0; } - int VisitScalarMapContainer(ScalarMapContainer* container) { - container->parent = NULL; - return 0; - } - - int VisitMessageMapContainer(MessageMapContainer* container) { + int VisitMapContainer(MapContainer* container) { container->parent = NULL; return 0; } @@ -1305,6 +1320,9 @@ struct ClearWeakReferences : public ChildVisitor { static void Dealloc(CMessage* self) { // Null out all weak references from children to this message. GOOGLE_CHECK_EQ(0, ForEachCompositeField(self, ClearWeakReferences())); + if (self->extensions) { + self->extensions->parent = NULL; + } Py_CLEAR(self->extensions); Py_CLEAR(self->composite_fields); @@ -1466,20 +1484,27 @@ PyObject* HasField(CMessage* self, PyObject* arg) { Py_RETURN_FALSE; } -PyObject* ClearExtension(CMessage* self, PyObject* arg) { +PyObject* ClearExtension(CMessage* self, PyObject* extension) { if (self->extensions != NULL) { - return extension_dict::ClearExtension(self->extensions, arg); + return extension_dict::ClearExtension(self->extensions, extension); + } else { + const FieldDescriptor* descriptor = GetExtensionDescriptor(extension); + if (descriptor == NULL) { + return NULL; + } + if (ScopedPyObjectPtr(ClearFieldByDescriptor(self, descriptor)) == NULL) { + return NULL; + } } - PyErr_SetString(PyExc_TypeError, "Message is not extendable"); - return NULL; + Py_RETURN_NONE; } -PyObject* HasExtension(CMessage* self, PyObject* arg) { - if (self->extensions != NULL) { - return extension_dict::HasExtension(self->extensions, arg); +PyObject* HasExtension(CMessage* self, PyObject* extension) { + const FieldDescriptor* descriptor = GetExtensionDescriptor(extension); + if (descriptor == NULL) { + return NULL; } - PyErr_SetString(PyExc_TypeError, "Message is not extendable"); - return NULL; + return HasFieldByDescriptor(self, descriptor); } // --------------------------------------------------------------------- @@ -1529,13 +1554,8 @@ struct SetOwnerVisitor : public ChildVisitor { return 0; } - int VisitScalarMapContainer(ScalarMapContainer* container) { - scalar_map_container::SetOwner(container, new_owner_); - return 0; - } - - int VisitMessageMapContainer(MessageMapContainer* container) { - message_map_container::SetOwner(container, new_owner_); + int VisitMapContainer(MapContainer* container) { + container->SetOwner(new_owner_); return 0; } @@ -1608,14 +1628,8 @@ struct ReleaseChild : public ChildVisitor { reinterpret_cast<RepeatedScalarContainer*>(container)); } - int VisitScalarMapContainer(ScalarMapContainer* container) { - return scalar_map_container::Release( - reinterpret_cast<ScalarMapContainer*>(container)); - } - - int VisitMessageMapContainer(MessageMapContainer* container) { - return message_map_container::Release( - reinterpret_cast<MessageMapContainer*>(container)); + int VisitMapContainer(MapContainer* container) { + return reinterpret_cast<MapContainer*>(container)->Release(); } int VisitCMessage(CMessage* cmessage, @@ -1707,17 +1721,7 @@ PyObject* Clear(CMessage* self) { AssureWritable(self); if (ForEachCompositeField(self, ReleaseChild(self)) == -1) return NULL; - - // The old ExtensionDict still aliases this CMessage, but all its - // fields have been released. - if (self->extensions != NULL) { - Py_CLEAR(self->extensions); - ExtensionDict* extension_dict = extension_dict::NewExtensionDict(self); - if (extension_dict == NULL) { - return NULL; - } - self->extensions = extension_dict; - } + Py_CLEAR(self->extensions); if (self->composite_fields) { PyDict_Clear(self->composite_fields); } @@ -1997,7 +2001,6 @@ static PyObject* RegisterExtension(PyObject* cls, if (descriptor->is_extension() && descriptor->containing_type()->options().message_set_wire_format() && descriptor->type() == FieldDescriptor::TYPE_MESSAGE && - descriptor->message_type() == descriptor->extension_scope() && descriptor->label() == FieldDescriptor::LABEL_OPTIONAL) { ScopedPyObjectPtr message_name(PyString_FromStringAndSize( descriptor->message_type()->full_name().c_str(), @@ -2042,6 +2045,8 @@ static PyObject* WhichOneof(CMessage* self, PyObject* arg) { } } +static PyObject* GetExtensionDict(CMessage* self, void *closure); + static PyObject* ListFields(CMessage* self) { vector<const FieldDescriptor*> fields; self->message->GetReflection()->ListFields(*self->message, &fields); @@ -2079,12 +2084,13 @@ static PyObject* ListFields(CMessage* self) { PyErr_Clear(); continue; } - PyObject* extensions = reinterpret_cast<PyObject*>(self->extensions); + ScopedPyObjectPtr extensions(GetExtensionDict(self, NULL)); if (extensions == NULL) { return NULL; } // 'extension' reference later stolen by PyTuple_SET_ITEM. - PyObject* extension = PyObject_GetItem(extensions, extension_field.get()); + PyObject* extension = PyObject_GetItem( + extensions.get(), extension_field.get()); if (extension == NULL) { return NULL; } @@ -2493,9 +2499,31 @@ PyObject* _CheckCalledFromGeneratedFile(PyObject* unused, Py_RETURN_NONE; } -static PyMemberDef Members[] = { - {"Extensions", T_OBJECT_EX, offsetof(CMessage, extensions), 0, - "Extension dict"}, +static PyObject* GetExtensionDict(CMessage* self, void *closure) { + if (self->extensions) { + Py_INCREF(self->extensions); + return reinterpret_cast<PyObject*>(self->extensions); + } + + // If there are extension_ranges, the message is "extendable". Allocate a + // dictionary to store the extension fields. + const Descriptor* descriptor = GetMessageDescriptor(Py_TYPE(self)); + if (descriptor->extension_range_count() > 0) { + ExtensionDict* extension_dict = extension_dict::NewExtensionDict(self); + if (extension_dict == NULL) { + return NULL; + } + self->extensions = extension_dict; + Py_INCREF(self->extensions); + return reinterpret_cast<PyObject*>(self->extensions); + } + + PyErr_SetNone(PyExc_AttributeError); + return NULL; +} + +static PyGetSetDef Getters[] = { + {"Extensions", (getter)GetExtensionDict, NULL, "Extension dict"}, {NULL} }; @@ -2592,10 +2620,10 @@ PyObject* GetAttr(CMessage* self, PyObject* name) { if (value_class == NULL) { return NULL; } - py_container = message_map_container::NewContainer(self, field_descriptor, - value_class); + py_container = + NewMessageMapContainer(self, field_descriptor, value_class); } else { - py_container = scalar_map_container::NewContainer(self, field_descriptor); + py_container = NewScalarMapContainer(self, field_descriptor); } if (py_container == NULL) { return NULL; @@ -2672,7 +2700,10 @@ int SetAttr(CMessage* self, PyObject* name, PyObject* value) { } } - PyErr_Format(PyExc_AttributeError, "Assignment not allowed"); + PyErr_Format(PyExc_AttributeError, + "Assignment not allowed " + "(no field \"%s\"in protocol message object).", + PyString_AsString(name)); return -1; } @@ -2707,8 +2738,8 @@ PyTypeObject CMessage_Type = { 0, // tp_iter 0, // tp_iternext cmessage::Methods, // tp_methods - cmessage::Members, // tp_members - 0, // tp_getset + 0, // tp_members + cmessage::Getters, // tp_getset 0, // tp_base 0, // tp_dict 0, // tp_descr_get @@ -2910,12 +2941,12 @@ bool InitProto2MessageModule(PyObject *m) { reinterpret_cast<PyObject*>(&ScalarMapContainer_Type)); #endif - if (PyType_Ready(&ScalarMapIterator_Type) < 0) { + if (PyType_Ready(&MapIterator_Type) < 0) { return false; } - PyModule_AddObject(m, "ScalarMapIterator", - reinterpret_cast<PyObject*>(&ScalarMapIterator_Type)); + PyModule_AddObject(m, "MapIterator", + reinterpret_cast<PyObject*>(&MapIterator_Type)); #if PY_MAJOR_VERSION >= 3 @@ -2934,13 +2965,6 @@ bool InitProto2MessageModule(PyObject *m) { PyModule_AddObject(m, "MessageMapContainer", reinterpret_cast<PyObject*>(&MessageMapContainer_Type)); #endif - - if (PyType_Ready(&MessageMapIterator_Type) < 0) { - return false; - } - - PyModule_AddObject(m, "MessageMapIterator", - reinterpret_cast<PyObject*>(&MessageMapIterator_Type)); } if (PyType_Ready(&ExtensionDict_Type) < 0) { @@ -2957,6 +2981,9 @@ bool InitProto2MessageModule(PyObject *m) { PyModule_AddObject(m, "default_pool", reinterpret_cast<PyObject*>(GetDefaultDescriptorPool())); + PyModule_AddObject(m, "DescriptorPool", reinterpret_cast<PyObject*>( + &PyDescriptorPool_Type)); + // This implementation provides full Descriptor types, we advertise it so that // descriptor.py can use them in replacement of the Python classes. PyModule_AddIntConstant(m, "_USE_C_DESCRIPTORS", 1); |