diff options
Diffstat (limited to 'python/google/protobuf/pyext/message.cc')
-rw-r--r-- | python/google/protobuf/pyext/message.cc | 594 |
1 files changed, 335 insertions, 259 deletions
diff --git a/python/google/protobuf/pyext/message.cc b/python/google/protobuf/pyext/message.cc index 9318c834..83c151ff 100644 --- a/python/google/protobuf/pyext/message.cc +++ b/python/google/protobuf/pyext/message.cc @@ -33,6 +33,7 @@ #include <google/protobuf/pyext/message.h> +#include <map> #include <memory> #ifndef _SHARED_PTR_H #include <google/protobuf/stubs/shared_ptr.h> @@ -61,8 +62,7 @@ #include <google/protobuf/pyext/extension_dict.h> #include <google/protobuf/pyext/repeated_composite_container.h> #include <google/protobuf/pyext/repeated_scalar_container.h> -#include <google/protobuf/pyext/message_map_container.h> -#include <google/protobuf/pyext/scalar_map_container.h> +#include <google/protobuf/pyext/map_container.h> #include <google/protobuf/pyext/scoped_pyobject_ptr.h> #include <google/protobuf/stubs/strutil.h> @@ -96,31 +96,7 @@ static PyObject* k_extensions_by_number; PyObject* EnumTypeWrapper_class; static PyObject* PythonMessage_class; static PyObject* kEmptyWeakref; - -// Defines the Metaclass of all Message classes. -// It allows us to cache some C++ pointers in the class object itself, they are -// faster to extract than from the type's dictionary. - -struct PyMessageMeta { - // This is how CPython subclasses C structures: the base structure must be - // the first member of the object. - PyHeapTypeObject super; - - // C++ descriptor of this message. - const Descriptor* message_descriptor; - - // Owned reference, used to keep the pointer above alive. - PyObject* py_message_descriptor; - - // The Python DescriptorPool used to create the class. It is needed to resolve - // fields descriptors, including extensions fields; its C++ MessageFactory is - // used to instantiate submessages. - // This can be different from DESCRIPTOR.file.pool, in the case of a custom - // DescriptorPool which defines new extensions. - // We own the reference, because it's important to keep the descriptors and - // factory alive. - PyDescriptorPool* py_descriptor_pool; -}; +static PyObject* WKT_classes = NULL; namespace message_meta { @@ -142,7 +118,7 @@ static bool AddFieldNumberToClass( if (number == NULL) { return false; } - if (PyObject_SetAttr(cls, attr_name, number) == -1) { + if (PyObject_SetAttr(cls, attr_name.get(), number.get()) == -1) { return false; } return true; @@ -155,11 +131,11 @@ static int AddDescriptors(PyObject* cls, const Descriptor* descriptor) { // classes will register themselves in this class. if (descriptor->extension_range_count() > 0) { ScopedPyObjectPtr by_name(PyDict_New()); - if (PyObject_SetAttr(cls, k_extensions_by_name, by_name) < 0) { + if (PyObject_SetAttr(cls, k_extensions_by_name, by_name.get()) < 0) { return -1; } ScopedPyObjectPtr by_number(PyDict_New()); - if (PyObject_SetAttr(cls, k_extensions_by_number, by_number) < 0) { + if (PyObject_SetAttr(cls, k_extensions_by_number, by_number.get()) < 0) { return -1; } } @@ -172,10 +148,6 @@ static int AddDescriptors(PyObject* cls, const Descriptor* descriptor) { } // For each enum set cls.<enum name> = EnumTypeWrapper(<enum descriptor>). - // - // The enum descriptor we get from - // <messagedescriptor>.enum_types_by_name[name] - // which was built previously. for (int i = 0; i < descriptor->enum_type_count(); ++i) { const EnumDescriptor* enum_descriptor = descriptor->enum_type(i); ScopedPyObjectPtr enum_type( @@ -190,7 +162,7 @@ static int AddDescriptors(PyObject* cls, const Descriptor* descriptor) { return -1; } if (PyObject_SetAttrString( - cls, enum_descriptor->name().c_str(), wrapped) == -1) { + cls, enum_descriptor->name().c_str(), wrapped.get()) == -1) { return -1; } @@ -203,8 +175,8 @@ static int AddDescriptors(PyObject* cls, const Descriptor* descriptor) { if (value_number == NULL) { return -1; } - if (PyObject_SetAttrString( - cls, enum_value_descriptor->name().c_str(), value_number) == -1) { + if (PyObject_SetAttrString(cls, enum_value_descriptor->name().c_str(), + value_number.get()) == -1) { return -1; } } @@ -224,7 +196,7 @@ static int AddDescriptors(PyObject* cls, const Descriptor* descriptor) { // Add the extension field to the message class. if (PyObject_SetAttrString( - cls, field->name().c_str(), extension_field) == -1) { + cls, field->name().c_str(), extension_field.get()) == -1) { return -1; } @@ -274,17 +246,41 @@ static PyObject* New(PyTypeObject* type, // Build the arguments to the base metaclass. // We change the __bases__ classes. - ScopedPyObjectPtr new_args(Py_BuildValue( - "s(OO)O", name, &CMessage_Type, PythonMessage_class, dict)); + ScopedPyObjectPtr new_args; + const Descriptor* message_descriptor = + PyMessageDescriptor_AsDescriptor(py_descriptor); + if (message_descriptor == NULL) { + return NULL; + } + + if (WKT_classes == NULL) { + ScopedPyObjectPtr well_known_types(PyImport_ImportModule( + "google.protobuf.internal.well_known_types")); + GOOGLE_DCHECK(well_known_types != NULL); + + WKT_classes = PyObject_GetAttrString(well_known_types.get(), "WKTBASES"); + GOOGLE_DCHECK(WKT_classes != NULL); + } + + PyObject* well_known_class = PyDict_GetItemString( + WKT_classes, message_descriptor->full_name().c_str()); + if (well_known_class == NULL) { + new_args.reset(Py_BuildValue("s(OO)O", name, &CMessage_Type, + PythonMessage_class, dict)); + } else { + new_args.reset(Py_BuildValue("s(OOO)O", name, &CMessage_Type, + PythonMessage_class, well_known_class, dict)); + } + if (new_args == NULL) { return NULL; } // Call the base metaclass. - ScopedPyObjectPtr result(PyType_Type.tp_new(type, new_args, NULL)); + ScopedPyObjectPtr result(PyType_Type.tp_new(type, new_args.get(), NULL)); if (result == NULL) { return NULL; } - PyMessageMeta* newtype = reinterpret_cast<PyMessageMeta*>(result.get()); + CMessageClass* newtype = reinterpret_cast<CMessageClass*>(result.get()); // Insert the empty weakref into the base classes. if (InsertEmptyWeakref( @@ -313,28 +309,23 @@ static PyObject* New(PyTypeObject* type, // Add the message to the DescriptorPool. if (cdescriptor_pool::RegisterMessageClass(newtype->py_descriptor_pool, - descriptor, result) < 0) { + descriptor, newtype) < 0) { return NULL; } // Continue with type initialization: add other descriptors, enum values... - if (AddDescriptors(result, descriptor) < 0) { + if (AddDescriptors(result.get(), descriptor) < 0) { return NULL; } return result.release(); } -static void Dealloc(PyMessageMeta *self) { +static void Dealloc(CMessageClass *self) { Py_DECREF(self->py_message_descriptor); Py_DECREF(self->py_descriptor_pool); Py_TYPE(self)->tp_free(reinterpret_cast<PyObject*>(self)); } -static PyObject* GetDescriptor(PyMessageMeta *self, void *closure) { - Py_INCREF(self->py_message_descriptor); - return self->py_message_descriptor; -} - // This function inserts and empty weakref at the end of the list of // subclasses for the main protocol buffer Message class. @@ -358,10 +349,10 @@ static int InsertEmptyWeakref(PyTypeObject *base_type) { } // namespace message_meta -PyTypeObject PyMessageMeta_Type = { +PyTypeObject CMessageClass_Type = { PyVarObject_HEAD_INIT(&PyType_Type, 0) FULL_MODULE_NAME ".MessageMeta", // tp_name - sizeof(PyMessageMeta), // tp_basicsize + sizeof(CMessageClass), // tp_basicsize 0, // tp_itemsize (destructor)message_meta::Dealloc, // tp_dealloc 0, // tp_print @@ -399,16 +390,16 @@ PyTypeObject PyMessageMeta_Type = { message_meta::New, // tp_new }; -static PyMessageMeta* CheckMessageClass(PyTypeObject* cls) { - if (!PyObject_TypeCheck(cls, &PyMessageMeta_Type)) { +static CMessageClass* CheckMessageClass(PyTypeObject* cls) { + if (!PyObject_TypeCheck(cls, &CMessageClass_Type)) { PyErr_Format(PyExc_TypeError, "Class %s is not a Message", cls->tp_name); return NULL; } - return reinterpret_cast<PyMessageMeta*>(cls); + return reinterpret_cast<CMessageClass*>(cls); } static const Descriptor* GetMessageDescriptor(PyTypeObject* cls) { - PyMessageMeta* type = CheckMessageClass(cls); + CMessageClass* type = CheckMessageClass(cls); if (type == NULL) { return NULL; } @@ -453,21 +444,9 @@ static int VisitCompositeField(const FieldDescriptor* descriptor, if (descriptor->label() == FieldDescriptor::LABEL_REPEATED) { if (descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) { if (descriptor->is_map()) { - const Descriptor* entry_type = descriptor->message_type(); - const FieldDescriptor* value_type = - entry_type->FindFieldByName("value"); - if (value_type->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) { - MessageMapContainer* container = - reinterpret_cast<MessageMapContainer*>(child); - if (visitor.VisitMessageMapContainer(container) == -1) { - return -1; - } - } else { - ScalarMapContainer* container = - reinterpret_cast<ScalarMapContainer*>(child); - if (visitor.VisitScalarMapContainer(container) == -1) { - return -1; - } + MapContainer* container = reinterpret_cast<MapContainer*>(child); + if (visitor.VisitMapContainer(container) == -1) { + return -1; } } else { RepeatedCompositeContainer* container = @@ -584,12 +563,14 @@ bool CheckAndGetInteger( if (PyObject_RichCompareBool(min, arg, Py_LE) != 1 || PyObject_RichCompareBool(max, arg, Py_GE) != 1) { #endif - PyObject *s = PyObject_Str(arg); - if (s) { - PyErr_Format(PyExc_ValueError, - "Value out of range: %s", - PyString_AsString(s)); - Py_DECREF(s); + if (!PyErr_Occurred()) { + PyObject *s = PyObject_Str(arg); + if (s) { + PyErr_Format(PyExc_ValueError, + "Value out of range: %s", + PyString_AsString(s)); + Py_DECREF(s); + } } return false; } @@ -647,38 +628,51 @@ bool CheckAndGetBool(PyObject* arg, bool* value) { return true; } -bool CheckAndSetString( - PyObject* arg, Message* message, - const FieldDescriptor* descriptor, - const Reflection* reflection, - bool append, - int index) { +// Checks whether the given object (which must be "bytes" or "unicode") contains +// valid UTF-8. +bool IsValidUTF8(PyObject* obj) { + if (PyBytes_Check(obj)) { + PyObject* unicode = PyUnicode_FromEncodedObject(obj, "utf-8", NULL); + + // Clear the error indicator; we report our own error when desired. + PyErr_Clear(); + + if (unicode) { + Py_DECREF(unicode); + return true; + } else { + return false; + } + } else { + // Unicode object, known to be valid UTF-8. + return true; + } +} + +bool AllowInvalidUTF8(const FieldDescriptor* field) { return false; } + +PyObject* CheckString(PyObject* arg, const FieldDescriptor* descriptor) { GOOGLE_DCHECK(descriptor->type() == FieldDescriptor::TYPE_STRING || descriptor->type() == FieldDescriptor::TYPE_BYTES); if (descriptor->type() == FieldDescriptor::TYPE_STRING) { if (!PyBytes_Check(arg) && !PyUnicode_Check(arg)) { FormatTypeError(arg, "bytes, unicode"); - return false; + return NULL; } - if (PyBytes_Check(arg)) { - PyObject* unicode = PyUnicode_FromEncodedObject(arg, "utf-8", NULL); - if (unicode == NULL) { - PyObject* repr = PyObject_Repr(arg); - PyErr_Format(PyExc_ValueError, - "%s has type str, but isn't valid UTF-8 " - "encoding. Non-UTF-8 strings must be converted to " - "unicode objects before being added.", - PyString_AsString(repr)); - Py_DECREF(repr); - return false; - } else { - Py_DECREF(unicode); - } + if (!IsValidUTF8(arg) && !AllowInvalidUTF8(descriptor)) { + PyObject* repr = PyObject_Repr(arg); + PyErr_Format(PyExc_ValueError, + "%s has type str, but isn't valid UTF-8 " + "encoding. Non-UTF-8 strings must be converted to " + "unicode objects before being added.", + PyString_AsString(repr)); + Py_DECREF(repr); + return NULL; } } else if (!PyBytes_Check(arg)) { FormatTypeError(arg, "bytes"); - return false; + return NULL; } PyObject* encoded_string = NULL; @@ -696,14 +690,24 @@ bool CheckAndSetString( Py_INCREF(encoded_string); } - if (encoded_string == NULL) { + return encoded_string; +} + +bool CheckAndSetString( + PyObject* arg, Message* message, + const FieldDescriptor* descriptor, + const Reflection* reflection, + bool append, + int index) { + ScopedPyObjectPtr encoded_string(CheckString(arg, descriptor)); + + if (encoded_string.get() == NULL) { return false; } char* value; Py_ssize_t value_len; - if (PyBytes_AsStringAndSize(encoded_string, &value, &value_len) < 0) { - Py_DECREF(encoded_string); + if (PyBytes_AsStringAndSize(encoded_string.get(), &value, &value_len) < 0) { return false; } @@ -715,7 +719,6 @@ bool CheckAndSetString( } else { reflection->SetRepeatedString(message, descriptor, index, value_string); } - Py_DECREF(encoded_string); return true; } @@ -751,9 +754,9 @@ namespace cmessage { PyDescriptorPool* GetDescriptorPoolForMessage(CMessage* message) { // No need to check the type: the type of instances of CMessage is always - // an instance of PyMessageMeta. Let's prove it with a debug-only check. + // an instance of CMessageClass. Let's prove it with a debug-only check. GOOGLE_DCHECK(PyObject_TypeCheck(message, &CMessage_Type)); - return reinterpret_cast<PyMessageMeta*>(Py_TYPE(message))->py_descriptor_pool; + return reinterpret_cast<CMessageClass*>(Py_TYPE(message))->py_descriptor_pool; } MessageFactory* GetFactoryForMessage(CMessage* message) { @@ -828,12 +831,7 @@ struct FixupMessageReference : public ChildVisitor { return 0; } - int VisitScalarMapContainer(ScalarMapContainer* container) { - container->message = message_; - return 0; - } - - int VisitMessageMapContainer(MessageMapContainer* container) { + int VisitMapContainer(MapContainer* container) { container->message = message_; return 0; } @@ -862,7 +860,6 @@ int AssureWritable(CMessage* self) { return -1; // Make self->message writable. - Message* parent_message = self->parent->message; Message* mutable_message = GetMutableMessage( self->parent, self->parent_field_descriptor); @@ -876,9 +873,8 @@ int AssureWritable(CMessage* self) { // When a CMessage is made writable its Message pointer is updated // to point to a new mutable Message. When that happens we need to // update any references to the old, read-only CMessage. There are - // five places such references occur: RepeatedScalarContainer, - // RepeatedCompositeContainer, ScalarMapContainer, MessageMapContainer, - // and ExtensionDict. + // four places such references occur: RepeatedScalarContainer, + // RepeatedCompositeContainer, MapContainer, and ExtensionDict. if (self->extensions != NULL) self->extensions->message = self->message; if (ForEachCompositeField(self, FixupMessageReference(self->message)) == -1) @@ -1060,10 +1056,15 @@ int InitAttributes(CMessage* self, PyObject* kwargs) { } const FieldDescriptor* descriptor = GetFieldDescriptor(self, name); if (descriptor == NULL) { - PyErr_Format(PyExc_ValueError, "Protocol message has no \"%s\" field.", + PyErr_Format(PyExc_ValueError, "Protocol message %s has no \"%s\" field.", + self->message->GetDescriptor()->name().c_str(), PyString_AsString(name)); return -1; } + if (value == Py_None) { + // field=None is the same as no field at all. + continue; + } if (descriptor->is_map()) { ScopedPyObjectPtr map(GetAttr(self, name)); const FieldDescriptor* value_descriptor = @@ -1106,8 +1107,8 @@ int InitAttributes(CMessage* self, PyObject* kwargs) { return -1; } ScopedPyObjectPtr next; - while ((next.reset(PyIter_Next(iter))) != NULL) { - PyObject* kwargs = (PyDict_Check(next) ? next.get() : NULL); + while ((next.reset(PyIter_Next(iter.get()))) != NULL) { + PyObject* kwargs = (PyDict_Check(next.get()) ? next.get() : NULL); ScopedPyObjectPtr new_msg( repeated_composite_container::Add(rc_container, NULL, kwargs)); if (new_msg == NULL) { @@ -1115,9 +1116,9 @@ int InitAttributes(CMessage* self, PyObject* kwargs) { } if (kwargs == NULL) { // next was not a dict, it's a message we need to merge - ScopedPyObjectPtr merged( - MergeFrom(reinterpret_cast<CMessage*>(new_msg.get()), next)); - if (merged == NULL) { + ScopedPyObjectPtr merged(MergeFrom( + reinterpret_cast<CMessage*>(new_msg.get()), next.get())); + if (merged.get() == NULL) { return -1; } } @@ -1135,13 +1136,14 @@ int InitAttributes(CMessage* self, PyObject* kwargs) { return -1; } ScopedPyObjectPtr next; - while ((next.reset(PyIter_Next(iter))) != NULL) { - ScopedPyObjectPtr enum_value(GetIntegerEnumValue(*descriptor, next)); + while ((next.reset(PyIter_Next(iter.get()))) != NULL) { + ScopedPyObjectPtr enum_value( + GetIntegerEnumValue(*descriptor, next.get())); if (enum_value == NULL) { return -1; } - ScopedPyObjectPtr new_msg( - repeated_scalar_container::Append(rs_container, enum_value)); + ScopedPyObjectPtr new_msg(repeated_scalar_container::Append( + rs_container, enum_value.get())); if (new_msg == NULL) { return -1; } @@ -1182,7 +1184,8 @@ int InitAttributes(CMessage* self, PyObject* kwargs) { return -1; } } - if (SetAttr(self, name, (new_val == NULL) ? value : new_val) < 0) { + if (SetAttr(self, name, (new_val.get() == NULL) ? value : new_val.get()) < + 0) { return -1; } } @@ -1192,9 +1195,9 @@ int InitAttributes(CMessage* self, PyObject* kwargs) { // Allocates an incomplete Python Message: the caller must fill self->message, // self->owner and eventually self->parent. -CMessage* NewEmptyMessage(PyObject* type, const Descriptor *descriptor) { +CMessage* NewEmptyMessage(CMessageClass* type) { CMessage* self = reinterpret_cast<CMessage*>( - PyType_GenericAlloc(reinterpret_cast<PyTypeObject*>(type), 0)); + PyType_GenericAlloc(&type->super.ht_type, 0)); if (self == NULL) { return NULL; } @@ -1207,18 +1210,6 @@ CMessage* NewEmptyMessage(PyObject* type, const Descriptor *descriptor) { self->composite_fields = NULL; - // If there are extension_ranges, the message is "extendable". Allocate a - // dictionary to store the extension fields. - if (descriptor->extension_range_count() > 0) { - // TODO(amauryfa): Delay the construction of this dict until extensions are - // really used on the object. - ExtensionDict* extension_dict = extension_dict::NewExtensionDict(self); - if (extension_dict == NULL) { - return NULL; - } - self->extensions = extension_dict; - } - return self; } @@ -1226,7 +1217,7 @@ CMessage* NewEmptyMessage(PyObject* type, const Descriptor *descriptor) { // Creates a new C++ message and takes ownership. static PyObject* New(PyTypeObject* cls, PyObject* unused_args, PyObject* unused_kwargs) { - PyMessageMeta* type = CheckMessageClass(cls); + CMessageClass* type = CheckMessageClass(cls); if (type == NULL) { return NULL; } @@ -1242,8 +1233,7 @@ static PyObject* New(PyTypeObject* cls, return NULL; } - CMessage* self = NewEmptyMessage(reinterpret_cast<PyObject*>(type), - message_descriptor); + CMessage* self = NewEmptyMessage(type); if (self == NULL) { return NULL; } @@ -1289,12 +1279,7 @@ struct ClearWeakReferences : public ChildVisitor { return 0; } - int VisitScalarMapContainer(ScalarMapContainer* container) { - container->parent = NULL; - return 0; - } - - int VisitMessageMapContainer(MessageMapContainer* container) { + int VisitMapContainer(MapContainer* container) { container->parent = NULL; return 0; } @@ -1309,6 +1294,9 @@ struct ClearWeakReferences : public ChildVisitor { static void Dealloc(CMessage* self) { // Null out all weak references from children to this message. GOOGLE_CHECK_EQ(0, ForEachCompositeField(self, ClearWeakReferences())); + if (self->extensions) { + self->extensions->parent = NULL; + } Py_CLEAR(self->extensions); Py_CLEAR(self->composite_fields); @@ -1470,20 +1458,27 @@ PyObject* HasField(CMessage* self, PyObject* arg) { Py_RETURN_FALSE; } -PyObject* ClearExtension(CMessage* self, PyObject* arg) { +PyObject* ClearExtension(CMessage* self, PyObject* extension) { if (self->extensions != NULL) { - return extension_dict::ClearExtension(self->extensions, arg); + return extension_dict::ClearExtension(self->extensions, extension); + } else { + const FieldDescriptor* descriptor = GetExtensionDescriptor(extension); + if (descriptor == NULL) { + return NULL; + } + if (ScopedPyObjectPtr(ClearFieldByDescriptor(self, descriptor)) == NULL) { + return NULL; + } } - PyErr_SetString(PyExc_TypeError, "Message is not extendable"); - return NULL; + Py_RETURN_NONE; } -PyObject* HasExtension(CMessage* self, PyObject* arg) { - if (self->extensions != NULL) { - return extension_dict::HasExtension(self->extensions, arg); +PyObject* HasExtension(CMessage* self, PyObject* extension) { + const FieldDescriptor* descriptor = GetExtensionDescriptor(extension); + if (descriptor == NULL) { + return NULL; } - PyErr_SetString(PyExc_TypeError, "Message is not extendable"); - return NULL; + return HasFieldByDescriptor(self, descriptor); } // --------------------------------------------------------------------- @@ -1533,13 +1528,8 @@ struct SetOwnerVisitor : public ChildVisitor { return 0; } - int VisitScalarMapContainer(ScalarMapContainer* container) { - scalar_map_container::SetOwner(container, new_owner_); - return 0; - } - - int VisitMessageMapContainer(MessageMapContainer* container) { - message_map_container::SetOwner(container, new_owner_); + int VisitMapContainer(MapContainer* container) { + container->SetOwner(new_owner_); return 0; } @@ -1612,14 +1602,8 @@ struct ReleaseChild : public ChildVisitor { reinterpret_cast<RepeatedScalarContainer*>(container)); } - int VisitScalarMapContainer(ScalarMapContainer* container) { - return scalar_map_container::Release( - reinterpret_cast<ScalarMapContainer*>(container)); - } - - int VisitMessageMapContainer(MessageMapContainer* container) { - return message_map_container::Release( - reinterpret_cast<MessageMapContainer*>(container)); + int VisitMapContainer(MapContainer* container) { + return reinterpret_cast<MapContainer*>(container)->Release(); } int VisitCMessage(CMessage* cmessage, @@ -1711,17 +1695,7 @@ PyObject* Clear(CMessage* self) { AssureWritable(self); if (ForEachCompositeField(self, ReleaseChild(self)) == -1) return NULL; - - // The old ExtensionDict still aliases this CMessage, but all its - // fields have been released. - if (self->extensions != NULL) { - Py_CLEAR(self->extensions); - ExtensionDict* extension_dict = extension_dict::NewExtensionDict(self); - if (extension_dict == NULL) { - return NULL; - } - self->extensions = extension_dict; - } + Py_CLEAR(self->extensions); if (self->composite_fields) { PyDict_Clear(self->composite_fields); } @@ -1769,13 +1743,13 @@ static PyObject* SerializeToString(CMessage* self, PyObject* args) { } ScopedPyObjectPtr encode_error( - PyObject_GetAttrString(message_module, "EncodeError")); + PyObject_GetAttrString(message_module.get(), "EncodeError")); if (encode_error.get() == NULL) { return NULL; } PyErr_Format(encode_error.get(), "Message %s is missing required fields: %s", - GetMessageName(self).c_str(), PyString_AsString(joined)); + GetMessageName(self).c_str(), PyString_AsString(joined.get())); return NULL; } int size = self->message->ByteSize(); @@ -1851,8 +1825,12 @@ static PyObject* ToStr(CMessage* self) { PyObject* MergeFrom(CMessage* self, PyObject* arg) { CMessage* other_message; - if (!PyObject_TypeCheck(reinterpret_cast<PyObject *>(arg), &CMessage_Type)) { - PyErr_SetString(PyExc_TypeError, "Must be a message"); + if (!PyObject_TypeCheck(arg, &CMessage_Type)) { + PyErr_Format(PyExc_TypeError, + "Parameter to MergeFrom() must be instance of same class: " + "expected %s got %s.", + self->message->GetDescriptor()->full_name().c_str(), + Py_TYPE(arg)->tp_name); return NULL; } @@ -1860,8 +1838,8 @@ PyObject* MergeFrom(CMessage* self, PyObject* arg) { if (other_message->message->GetDescriptor() != self->message->GetDescriptor()) { PyErr_Format(PyExc_TypeError, - "Tried to merge from a message with a different type. " - "to: %s, from: %s", + "Parameter to MergeFrom() must be instance of same class: " + "expected %s got %s.", self->message->GetDescriptor()->full_name().c_str(), other_message->message->GetDescriptor()->full_name().c_str()); return NULL; @@ -1879,8 +1857,12 @@ PyObject* MergeFrom(CMessage* self, PyObject* arg) { static PyObject* CopyFrom(CMessage* self, PyObject* arg) { CMessage* other_message; - if (!PyObject_TypeCheck(reinterpret_cast<PyObject *>(arg), &CMessage_Type)) { - PyErr_SetString(PyExc_TypeError, "Must be a message"); + if (!PyObject_TypeCheck(arg, &CMessage_Type)) { + PyErr_Format(PyExc_TypeError, + "Parameter to CopyFrom() must be instance of same class: " + "expected %s got %s.", + self->message->GetDescriptor()->full_name().c_str(), + Py_TYPE(arg)->tp_name); return NULL; } @@ -1893,8 +1875,8 @@ static PyObject* CopyFrom(CMessage* self, PyObject* arg) { if (other_message->message->GetDescriptor() != self->message->GetDescriptor()) { PyErr_Format(PyExc_TypeError, - "Tried to copy from a message with a different type. " - "to: %s, from: %s", + "Parameter to CopyFrom() must be instance of same class: " + "expected %s got %s.", self->message->GetDescriptor()->full_name().c_str(), other_message->message->GetDescriptor()->full_name().c_str()); return NULL; @@ -1911,6 +1893,30 @@ static PyObject* CopyFrom(CMessage* self, PyObject* arg) { Py_RETURN_NONE; } +// Protobuf has a 64MB limit built in, this variable will override this. Please +// do not enable this unless you fully understand the implications: protobufs +// must all be kept in memory at the same time, so if they grow too big you may +// get OOM errors. The protobuf APIs do not provide any tools for processing +// protobufs in chunks. If you have protos this big you should break them up if +// it is at all convenient to do so. +static bool allow_oversize_protos = false; + +// Provide a method in the module to set allow_oversize_protos to a boolean +// value. This method returns the newly value of allow_oversize_protos. +static PyObject* SetAllowOversizeProtos(PyObject* m, PyObject* arg) { + if (!arg || !PyBool_Check(arg)) { + PyErr_SetString(PyExc_TypeError, + "Argument to SetAllowOversizeProtos must be boolean"); + return NULL; + } + allow_oversize_protos = PyObject_IsTrue(arg); + if (allow_oversize_protos) { + Py_RETURN_TRUE; + } else { + Py_RETURN_FALSE; + } +} + static PyObject* MergeFromString(CMessage* self, PyObject* arg) { const void* data; Py_ssize_t data_length; @@ -1921,6 +1927,9 @@ static PyObject* MergeFromString(CMessage* self, PyObject* arg) { AssureWritable(self); io::CodedInputStream input( reinterpret_cast<const uint8*>(data), data_length); + if (allow_oversize_protos) { + input.SetTotalBytesLimit(INT_MAX, INT_MAX); + } PyDescriptorPool* pool = GetDescriptorPoolForMessage(self); input.SetExtensionRegistry(pool->pool, pool->message_factory); bool success = self->message->MergePartialFromCodedStream(&input); @@ -1950,8 +1959,6 @@ static PyObject* RegisterExtension(PyObject* cls, if (descriptor == NULL) { return NULL; } - const Descriptor* cmessage_descriptor = GetMessageDescriptor( - reinterpret_cast<PyTypeObject*>(cls)); ScopedPyObjectPtr extensions_by_name( PyObject_GetAttr(cls, k_extensions_by_name)); @@ -1965,7 +1972,8 @@ static PyObject* RegisterExtension(PyObject* cls, } // If the extension was already registered, check that it is the same. - PyObject* existing_extension = PyDict_GetItem(extensions_by_name, full_name); + PyObject* existing_extension = + PyDict_GetItem(extensions_by_name.get(), full_name.get()); if (existing_extension != NULL) { const FieldDescriptor* existing_extension_descriptor = GetExtensionDescriptor(existing_extension); @@ -1977,7 +1985,8 @@ static PyObject* RegisterExtension(PyObject* cls, Py_RETURN_NONE; } - if (PyDict_SetItem(extensions_by_name, full_name, extension_handle) < 0) { + if (PyDict_SetItem(extensions_by_name.get(), full_name.get(), + extension_handle) < 0) { return NULL; } @@ -1988,11 +1997,36 @@ static PyObject* RegisterExtension(PyObject* cls, PyErr_SetString(PyExc_TypeError, "no extensions_by_number on class"); return NULL; } + ScopedPyObjectPtr number(PyObject_GetAttrString(extension_handle, "number")); if (number == NULL) { return NULL; } - if (PyDict_SetItem(extensions_by_number, number, extension_handle) < 0) { + + // If the extension was already registered by number, check that it is the + // same. + existing_extension = PyDict_GetItem(extensions_by_number.get(), number.get()); + if (existing_extension != NULL) { + const FieldDescriptor* existing_extension_descriptor = + GetExtensionDescriptor(existing_extension); + if (existing_extension_descriptor != descriptor) { + const Descriptor* msg_desc = GetMessageDescriptor( + reinterpret_cast<PyTypeObject*>(cls)); + PyErr_Format( + PyExc_ValueError, + "Extensions \"%s\" and \"%s\" both try to extend message type " + "\"%s\" with field number %ld.", + existing_extension_descriptor->full_name().c_str(), + descriptor->full_name().c_str(), + msg_desc->full_name().c_str(), + PyInt_AsLong(number.get())); + return NULL; + } + // Nothing else to do. + Py_RETURN_NONE; + } + if (PyDict_SetItem(extensions_by_number.get(), number.get(), + extension_handle) < 0) { return NULL; } @@ -2000,7 +2034,6 @@ static PyObject* RegisterExtension(PyObject* cls, if (descriptor->is_extension() && descriptor->containing_type()->options().message_set_wire_format() && descriptor->type() == FieldDescriptor::TYPE_MESSAGE && - descriptor->message_type() == descriptor->extension_scope() && descriptor->label() == FieldDescriptor::LABEL_OPTIONAL) { ScopedPyObjectPtr message_name(PyString_FromStringAndSize( descriptor->message_type()->full_name().c_str(), @@ -2008,7 +2041,8 @@ static PyObject* RegisterExtension(PyObject* cls, if (message_name == NULL) { return NULL; } - PyDict_SetItem(extensions_by_name, message_name, extension_handle); + PyDict_SetItem(extensions_by_name.get(), message_name.get(), + extension_handle); } Py_RETURN_NONE; @@ -2044,6 +2078,8 @@ static PyObject* WhichOneof(CMessage* self, PyObject* arg) { } } +static PyObject* GetExtensionDict(CMessage* self, void *closure); + static PyObject* ListFields(CMessage* self) { vector<const FieldDescriptor*> fields; self->message->GetReflection()->ListFields(*self->message, &fields); @@ -2058,7 +2094,7 @@ static PyObject* ListFields(CMessage* self) { // the field information. Thus the actual size of the py list will be // smaller than the size of fields. Set the actual size at the end. Py_ssize_t actual_size = 0; - for (Py_ssize_t i = 0; i < fields.size(); ++i) { + for (size_t i = 0; i < fields.size(); ++i) { ScopedPyObjectPtr t(PyTuple_New(2)); if (t == NULL) { return NULL; @@ -2081,12 +2117,13 @@ static PyObject* ListFields(CMessage* self) { PyErr_Clear(); continue; } - PyObject* extensions = reinterpret_cast<PyObject*>(self->extensions); + ScopedPyObjectPtr extensions(GetExtensionDict(self, NULL)); if (extensions == NULL) { return NULL; } // 'extension' reference later stolen by PyTuple_SET_ITEM. - PyObject* extension = PyObject_GetItem(extensions, extension_field); + PyObject* extension = PyObject_GetItem( + extensions.get(), extension_field.get()); if (extension == NULL) { return NULL; } @@ -2108,9 +2145,9 @@ static PyObject* ListFields(CMessage* self) { return NULL; } - PyObject* field_value = GetAttr(self, py_field_name); + PyObject* field_value = GetAttr(self, py_field_name.get()); if (field_value == NULL) { - PyErr_SetObject(PyExc_ValueError, py_field_name); + PyErr_SetObject(PyExc_ValueError, py_field_name.get()); return NULL; } PyTuple_SET_ITEM(t.get(), 0, field_descriptor.release()); @@ -2119,10 +2156,20 @@ static PyObject* ListFields(CMessage* self) { PyList_SET_ITEM(all_fields.get(), actual_size, t.release()); ++actual_size; } - Py_SIZE(all_fields.get()) = actual_size; + if (static_cast<size_t>(actual_size) != fields.size() && + (PyList_SetSlice(all_fields.get(), actual_size, fields.size(), NULL) < + 0)) { + return NULL; + } return all_fields.release(); } +static PyObject* DiscardUnknownFields(CMessage* self) { + AssureWritable(self); + self->message->DiscardUnknownFields(); + Py_RETURN_NONE; +} + PyObject* FindInitializationErrors(CMessage* self) { Message* message = self->message; vector<string> errors; @@ -2132,7 +2179,7 @@ PyObject* FindInitializationErrors(CMessage* self) { if (error_list == NULL) { return NULL; } - for (Py_ssize_t i = 0; i < errors.size(); ++i) { + for (size_t i = 0; i < errors.size(); ++i) { const string& error = errors[i]; PyObject* error_string = PyString_FromStringAndSize( error.c_str(), error.length()); @@ -2266,14 +2313,13 @@ PyObject* InternalGetSubMessage( const Message& sub_message = reflection->GetMessage( *self->message, field_descriptor, pool->message_factory); - PyObject *message_class = cdescriptor_pool::GetMessageClass( + CMessageClass* message_class = cdescriptor_pool::GetMessageClass( pool, field_descriptor->message_type()); if (message_class == NULL) { return NULL; } - CMessage* cmsg = cmessage::NewEmptyMessage(message_class, - sub_message.GetDescriptor()); + CMessage* cmsg = cmessage::NewEmptyMessage(message_class); if (cmsg == NULL) { return NULL; } @@ -2430,16 +2476,16 @@ PyObject* ToUnicode(CMessage* self) { return NULL; } Py_INCREF(Py_True); - ScopedPyObjectPtr encoded(PyObject_CallMethodObjArgs(text_format, method_name, - self, Py_True, NULL)); + ScopedPyObjectPtr encoded(PyObject_CallMethodObjArgs( + text_format.get(), method_name.get(), self, Py_True, NULL)); Py_DECREF(Py_True); if (encoded == NULL) { return NULL; } #if PY_MAJOR_VERSION < 3 - PyObject* decoded = PyString_AsDecodedObject(encoded, "utf-8", NULL); + PyObject* decoded = PyString_AsDecodedObject(encoded.get(), "utf-8", NULL); #else - PyObject* decoded = PyUnicode_FromEncodedObject(encoded, "utf-8", NULL); + PyObject* decoded = PyUnicode_FromEncodedObject(encoded.get(), "utf-8", NULL); #endif if (decoded == NULL) { return NULL; @@ -2462,7 +2508,7 @@ PyObject* Reduce(CMessage* self) { if (serialized == NULL) { return NULL; } - if (PyDict_SetItemString(state, "serialized", serialized) < 0) { + if (PyDict_SetItemString(state.get(), "serialized", serialized.get()) < 0) { return NULL; } return Py_BuildValue("OOO", constructor.get(), args.get(), state.get()); @@ -2495,9 +2541,31 @@ PyObject* _CheckCalledFromGeneratedFile(PyObject* unused, Py_RETURN_NONE; } -static PyMemberDef Members[] = { - {"Extensions", T_OBJECT_EX, offsetof(CMessage, extensions), 0, - "Extension dict"}, +static PyObject* GetExtensionDict(CMessage* self, void *closure) { + if (self->extensions) { + Py_INCREF(self->extensions); + return reinterpret_cast<PyObject*>(self->extensions); + } + + // If there are extension_ranges, the message is "extendable". Allocate a + // dictionary to store the extension fields. + const Descriptor* descriptor = GetMessageDescriptor(Py_TYPE(self)); + if (descriptor->extension_range_count() > 0) { + ExtensionDict* extension_dict = extension_dict::NewExtensionDict(self); + if (extension_dict == NULL) { + return NULL; + } + self->extensions = extension_dict; + Py_INCREF(self->extensions); + return reinterpret_cast<PyObject*>(self->extensions); + } + + PyErr_SetNone(PyExc_AttributeError); + return NULL; +} + +static PyGetSetDef Getters[] = { + {"Extensions", (getter)GetExtensionDict, NULL, "Extension dict"}, {NULL} }; @@ -2520,6 +2588,8 @@ static PyMethodDef Methods[] = { "Clears a message field." }, { "CopyFrom", (PyCFunction)CopyFrom, METH_O, "Copies a protocol message into the current message." }, + { "DiscardUnknownFields", (PyCFunction)DiscardUnknownFields, METH_NOARGS, + "Discards the unknown fields." }, { "FindInitializationErrors", (PyCFunction)FindInitializationErrors, METH_NOARGS, "Finds unset required fields." }, @@ -2589,15 +2659,15 @@ PyObject* GetAttr(CMessage* self, PyObject* name) { const Descriptor* entry_type = field_descriptor->message_type(); const FieldDescriptor* value_type = entry_type->FindFieldByName("value"); if (value_type->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) { - PyObject* value_class = cdescriptor_pool::GetMessageClass( + CMessageClass* value_class = cdescriptor_pool::GetMessageClass( GetDescriptorPoolForMessage(self), value_type->message_type()); if (value_class == NULL) { return NULL; } - py_container = message_map_container::NewContainer(self, field_descriptor, - value_class); + py_container = + NewMessageMapContainer(self, field_descriptor, value_class); } else { - py_container = scalar_map_container::NewContainer(self, field_descriptor); + py_container = NewScalarMapContainer(self, field_descriptor); } if (py_container == NULL) { return NULL; @@ -2612,7 +2682,7 @@ PyObject* GetAttr(CMessage* self, PyObject* name) { if (field_descriptor->label() == FieldDescriptor::LABEL_REPEATED) { PyObject* py_container = NULL; if (field_descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) { - PyObject *message_class = cdescriptor_pool::GetMessageClass( + CMessageClass* message_class = cdescriptor_pool::GetMessageClass( GetDescriptorPoolForMessage(self), field_descriptor->message_type()); if (message_class == NULL) { return NULL; @@ -2674,14 +2744,17 @@ int SetAttr(CMessage* self, PyObject* name, PyObject* value) { } } - PyErr_Format(PyExc_AttributeError, "Assignment not allowed"); + PyErr_Format(PyExc_AttributeError, + "Assignment not allowed " + "(no field \"%s\" in protocol message object).", + PyString_AsString(name)); return -1; } } // namespace cmessage PyTypeObject CMessage_Type = { - PyVarObject_HEAD_INIT(&PyMessageMeta_Type, 0) + PyVarObject_HEAD_INIT(&CMessageClass_Type, 0) FULL_MODULE_NAME ".CMessage", // tp_name sizeof(CMessage), // tp_basicsize 0, // tp_itemsize @@ -2690,7 +2763,7 @@ PyTypeObject CMessage_Type = { 0, // tp_getattr 0, // tp_setattr 0, // tp_compare - 0, // tp_repr + (reprfunc)cmessage::ToStr, // tp_repr 0, // tp_as_number 0, // tp_as_sequence 0, // tp_as_mapping @@ -2709,8 +2782,8 @@ PyTypeObject CMessage_Type = { 0, // tp_iter 0, // tp_iternext cmessage::Methods, // tp_methods - cmessage::Members, // tp_members - 0, // tp_getset + 0, // tp_members + cmessage::Getters, // tp_getset 0, // tp_base 0, // tp_dict 0, // tp_descr_get @@ -2796,12 +2869,12 @@ bool InitProto2MessageModule(PyObject *m) { // Initialize constants defined in this file. InitGlobals(); - PyMessageMeta_Type.tp_base = &PyType_Type; - if (PyType_Ready(&PyMessageMeta_Type) < 0) { + CMessageClass_Type.tp_base = &PyType_Type; + if (PyType_Ready(&CMessageClass_Type) < 0) { return false; } PyModule_AddObject(m, "MessageMeta", - reinterpret_cast<PyObject*>(&PyMessageMeta_Type)); + reinterpret_cast<PyObject*>(&CMessageClass_Type)); if (PyType_Ready(&CMessage_Type) < 0) { return false; @@ -2817,16 +2890,16 @@ bool InitProto2MessageModule(PyObject *m) { if (empty_dict == NULL) { return false; } - ScopedPyObjectPtr immutable_dict(PyDictProxy_New(empty_dict)); + ScopedPyObjectPtr immutable_dict(PyDictProxy_New(empty_dict.get())); if (immutable_dict == NULL) { return false; } if (PyDict_SetItem(CMessage_Type.tp_dict, - k_extensions_by_name, immutable_dict) < 0) { + k_extensions_by_name, immutable_dict.get()) < 0) { return false; } if (PyDict_SetItem(CMessage_Type.tp_dict, - k_extensions_by_number, immutable_dict) < 0) { + k_extensions_by_number, immutable_dict.get()) < 0) { return false; } @@ -2856,19 +2929,19 @@ bool InitProto2MessageModule(PyObject *m) { if (collections == NULL) { return false; } - ScopedPyObjectPtr mutable_sequence(PyObject_GetAttrString( - collections, "MutableSequence")); + ScopedPyObjectPtr mutable_sequence( + PyObject_GetAttrString(collections.get(), "MutableSequence")); if (mutable_sequence == NULL) { return false; } - if (ScopedPyObjectPtr(PyObject_CallMethod(mutable_sequence, "register", "O", - &RepeatedScalarContainer_Type)) - == NULL) { + if (ScopedPyObjectPtr( + PyObject_CallMethod(mutable_sequence.get(), "register", "O", + &RepeatedScalarContainer_Type)) == NULL) { return false; } - if (ScopedPyObjectPtr(PyObject_CallMethod(mutable_sequence, "register", "O", - &RepeatedCompositeContainer_Type)) - == NULL) { + if (ScopedPyObjectPtr( + PyObject_CallMethod(mutable_sequence.get(), "register", "O", + &RepeatedCompositeContainer_Type)) == NULL) { return false; } } @@ -2883,16 +2956,16 @@ bool InitProto2MessageModule(PyObject *m) { } ScopedPyObjectPtr mutable_mapping( - PyObject_GetAttrString(containers, "MutableMapping")); + PyObject_GetAttrString(containers.get(), "MutableMapping")); if (mutable_mapping == NULL) { return false; } - if (!PyObject_TypeCheck(mutable_mapping, &PyType_Type)) { + if (!PyObject_TypeCheck(mutable_mapping.get(), &PyType_Type)) { return false; } - Py_INCREF(mutable_mapping); + Py_INCREF(mutable_mapping.get()); #if PY_MAJOR_VERSION >= 3 PyObject* bases = PyTuple_New(1); PyTuple_SET_ITEM(bases, 0, mutable_mapping.get()); @@ -2912,12 +2985,12 @@ bool InitProto2MessageModule(PyObject *m) { reinterpret_cast<PyObject*>(&ScalarMapContainer_Type)); #endif - if (PyType_Ready(&ScalarMapIterator_Type) < 0) { + if (PyType_Ready(&MapIterator_Type) < 0) { return false; } - PyModule_AddObject(m, "ScalarMapIterator", - reinterpret_cast<PyObject*>(&ScalarMapIterator_Type)); + PyModule_AddObject(m, "MapIterator", + reinterpret_cast<PyObject*>(&MapIterator_Type)); #if PY_MAJOR_VERSION >= 3 @@ -2925,7 +2998,7 @@ bool InitProto2MessageModule(PyObject *m) { PyType_FromSpecWithBases(&MessageMapContainer_Type_spec, bases); PyModule_AddObject(m, "MessageMapContainer", MessageMapContainer_Type); #else - Py_INCREF(mutable_mapping); + Py_INCREF(mutable_mapping.get()); MessageMapContainer_Type.tp_base = reinterpret_cast<PyTypeObject*>(mutable_mapping.get()); @@ -2936,13 +3009,6 @@ bool InitProto2MessageModule(PyObject *m) { PyModule_AddObject(m, "MessageMapContainer", reinterpret_cast<PyObject*>(&MessageMapContainer_Type)); #endif - - if (PyType_Ready(&MessageMapIterator_Type) < 0) { - return false; - } - - PyModule_AddObject(m, "MessageMapIterator", - reinterpret_cast<PyObject*>(&MessageMapIterator_Type)); } if (PyType_Ready(&ExtensionDict_Type) < 0) { @@ -2959,6 +3025,9 @@ bool InitProto2MessageModule(PyObject *m) { PyModule_AddObject(m, "default_pool", reinterpret_cast<PyObject*>(GetDefaultDescriptorPool())); + PyModule_AddObject(m, "DescriptorPool", reinterpret_cast<PyObject*>( + &PyDescriptorPool_Type)); + // This implementation provides full Descriptor types, we advertise it so that // descriptor.py can use them in replacement of the Python classes. PyModule_AddIntConstant(m, "_USE_C_DESCRIPTORS", 1); @@ -3012,6 +3081,12 @@ bool InitProto2MessageModule(PyObject *m) { } // namespace python } // namespace protobuf +static PyMethodDef ModuleMethods[] = { + {"SetAllowOversizeProtos", + (PyCFunction)google::protobuf::python::cmessage::SetAllowOversizeProtos, + METH_O, "Enable/disable oversize proto parsing."}, + { NULL, NULL} +}; #if PY_MAJOR_VERSION >= 3 static struct PyModuleDef _module = { @@ -3019,7 +3094,7 @@ static struct PyModuleDef _module = { "_message", google::protobuf::python::module_docstring, -1, - NULL, + ModuleMethods, /* m_methods */ NULL, NULL, NULL, @@ -3038,7 +3113,8 @@ extern "C" { #if PY_MAJOR_VERSION >= 3 m = PyModule_Create(&_module); #else - m = Py_InitModule3("_message", NULL, google::protobuf::python::module_docstring); + m = Py_InitModule3("_message", ModuleMethods, + google::protobuf::python::module_docstring); #endif if (m == NULL) { return INITFUNC_ERRORVAL; |