diff options
Diffstat (limited to 'python/google/protobuf/pyext/message.cc')
-rw-r--r-- | python/google/protobuf/pyext/message.cc | 646 |
1 files changed, 387 insertions, 259 deletions
diff --git a/python/google/protobuf/pyext/message.cc b/python/google/protobuf/pyext/message.cc index 53736b9c..fecb9364 100644 --- a/python/google/protobuf/pyext/message.cc +++ b/python/google/protobuf/pyext/message.cc @@ -45,12 +45,11 @@ #ifndef Py_TYPE #define Py_TYPE(ob) (((PyObject*)(ob))->ob_type) #endif -#include <google/protobuf/descriptor.pb.h> #include <google/protobuf/stubs/common.h> #include <google/protobuf/stubs/logging.h> #include <google/protobuf/io/coded_stream.h> #include <google/protobuf/io/zero_copy_stream_impl_lite.h> -#include <google/protobuf/util/message_differencer.h> +#include <google/protobuf/descriptor.pb.h> #include <google/protobuf/descriptor.h> #include <google/protobuf/message.h> #include <google/protobuf/text_format.h> @@ -58,12 +57,18 @@ #include <google/protobuf/pyext/descriptor.h> #include <google/protobuf/pyext/descriptor_pool.h> #include <google/protobuf/pyext/extension_dict.h> -#include <google/protobuf/pyext/repeated_composite_container.h> -#include <google/protobuf/pyext/repeated_scalar_container.h> +#include <google/protobuf/pyext/field.h> #include <google/protobuf/pyext/map_container.h> #include <google/protobuf/pyext/message_factory.h> +#include <google/protobuf/pyext/repeated_composite_container.h> +#include <google/protobuf/pyext/repeated_scalar_container.h> +#include <google/protobuf/pyext/unknown_fields.h> #include <google/protobuf/pyext/safe_numerics.h> #include <google/protobuf/pyext/scoped_pyobject_ptr.h> +#include <google/protobuf/util/message_differencer.h> +#include <google/protobuf/stubs/strutil.h> + +#include <google/protobuf/port_def.inc> #if PY_MAJOR_VERSION >= 3 #define PyInt_AsLong PyLong_AsLong @@ -72,16 +77,19 @@ #define PyString_Check PyUnicode_Check #define PyString_FromString PyUnicode_FromString #define PyString_FromStringAndSize PyUnicode_FromStringAndSize + #define PyString_FromFormat PyUnicode_FromFormat #if PY_VERSION_HEX < 0x03030000 #error "Python 3.0 - 3.2 are not supported." #else #define PyString_AsString(ob) \ (PyUnicode_Check(ob)? PyUnicode_AsUTF8(ob): PyBytes_AsString(ob)) - #define PyString_AsStringAndSize(ob, charpp, sizep) \ - (PyUnicode_Check(ob)? \ - ((*(charpp) = PyUnicode_AsUTF8AndSize(ob, (sizep))) == NULL? -1: 0): \ - PyBytes_AsStringAndSize(ob, (charpp), (sizep))) - #endif +#define PyString_AsStringAndSize(ob, charpp, sizep) \ + (PyUnicode_Check(ob) ? ((*(charpp) = const_cast<char*>( \ + PyUnicode_AsUTF8AndSize(ob, (sizep)))) == NULL \ + ? -1 \ + : 0) \ + : PyBytes_AsStringAndSize(ob, (charpp), (sizep))) +#endif #endif namespace google { @@ -99,44 +107,27 @@ namespace message_meta { static int InsertEmptyWeakref(PyTypeObject* base); namespace { -// Copied oveer from internal 'google/protobuf/stubs/strutil.h'. -inline void UpperString(string * s) { +// Copied over from internal 'google/protobuf/stubs/strutil.h'. +inline void LowerString(string * s) { string::iterator end = s->end(); for (string::iterator i = s->begin(); i != end; ++i) { - // toupper() changes based on locale. We don't want this! - if ('a' <= *i && *i <= 'z') *i += 'A' - 'a'; + // tolower() changes based on locale. We don't want this! + if ('A' <= *i && *i <= 'Z') *i += 'a' - 'A'; } } } -// Add the number of a field descriptor to the containing message class. -// Equivalent to: -// _cls.<field>_FIELD_NUMBER = <number> -static bool AddFieldNumberToClass( - PyObject* cls, const FieldDescriptor* field_descriptor) { - string constant_name = field_descriptor->name() + "_FIELD_NUMBER"; - UpperString(&constant_name); - ScopedPyObjectPtr attr_name(PyString_FromStringAndSize( - constant_name.c_str(), constant_name.size())); - if (attr_name == NULL) { - return false; - } - ScopedPyObjectPtr number(PyInt_FromLong(field_descriptor->number())); - if (number == NULL) { - return false; - } - if (PyObject_SetAttr(cls, attr_name.get(), number.get()) == -1) { - return false; - } - return true; -} - - // Finalize the creation of the Message class. static int AddDescriptors(PyObject* cls, const Descriptor* descriptor) { // For each field set: cls.<field>_FIELD_NUMBER = <number> for (int i = 0; i < descriptor->field_count(); ++i) { - if (!AddFieldNumberToClass(cls, descriptor->field(i))) { + const FieldDescriptor* field_descriptor = descriptor->field(i); + ScopedPyObjectPtr property(NewFieldProperty(field_descriptor)); + if (property == NULL) { + return -1; + } + if (PyObject_SetAttrString(cls, field_descriptor->name().c_str(), + property.get()) < 0) { return -1; } } @@ -193,11 +184,6 @@ static int AddDescriptors(PyObject* cls, const Descriptor* descriptor) { cls, field->name().c_str(), extension_field.get()) == -1) { return -1; } - - // For each extension set cls.<extension name>_FIELD_NUMBER = <number>. - if (!AddFieldNumberToClass(cls, field)) { - return -1; - } } return 0; @@ -265,10 +251,10 @@ static PyObject* New(PyTypeObject* type, PyObject* well_known_class = PyDict_GetItemString( WKT_classes, message_descriptor->full_name().c_str()); if (well_known_class == NULL) { - new_args.reset(Py_BuildValue("s(OO)O", name, &CMessage_Type, + new_args.reset(Py_BuildValue("s(OO)O", name, CMessage_Type, PythonMessage_class, dict)); } else { - new_args.reset(Py_BuildValue("s(OOO)O", name, &CMessage_Type, + new_args.reset(Py_BuildValue("s(OOO)O", name, CMessage_Type, PythonMessage_class, well_known_class, dict)); } @@ -285,7 +271,7 @@ static PyObject* New(PyTypeObject* type, // Insert the empty weakref into the base classes. if (InsertEmptyWeakref( reinterpret_cast<PyTypeObject*>(PythonMessage_class)) < 0 || - InsertEmptyWeakref(&CMessage_Type) < 0) { + InsertEmptyWeakref(CMessage_Type) < 0) { return NULL; } @@ -353,6 +339,13 @@ static int InsertEmptyWeakref(PyTypeObject *base_type) { // The _extensions_by_name dictionary is built on every access. // TODO(amauryfa): Migrate all users to pool.FindAllExtensions() static PyObject* GetExtensionsByName(CMessageClass *self, void *closure) { + if (self->message_descriptor == NULL) { + // This is the base Message object, simply raise AttributeError. + PyErr_SetString(PyExc_AttributeError, + "Base Message class has no DESCRIPTOR"); + return NULL; + } + const PyDescriptorPool* pool = self->py_message_factory->pool; std::vector<const FieldDescriptor*> extensions; @@ -376,6 +369,13 @@ static PyObject* GetExtensionsByName(CMessageClass *self, void *closure) { // The _extensions_by_number dictionary is built on every access. // TODO(amauryfa): Migrate all users to pool.FindExtensionByNumber() static PyObject* GetExtensionsByNumber(CMessageClass *self, void *closure) { + if (self->message_descriptor == NULL) { + // This is the base Message object, simply raise AttributeError. + PyErr_SetString(PyExc_AttributeError, + "Base Message class has no DESCRIPTOR"); + return NULL; + } + const PyDescriptorPool* pool = self->py_message_factory->pool; std::vector<const FieldDescriptor*> extensions; @@ -405,9 +405,51 @@ static PyGetSetDef Getters[] = { {NULL} }; +// Compute some class attributes on the fly: +// - All the _FIELD_NUMBER attributes, for all fields and nested extensions. +// Returns a new reference, or NULL with an exception set. +static PyObject* GetClassAttribute(CMessageClass *self, PyObject* name) { + char* attr; + Py_ssize_t attr_size; + static const char kSuffix[] = "_FIELD_NUMBER"; + if (PyString_AsStringAndSize(name, &attr, &attr_size) >= 0 && + strings::EndsWith(StringPiece(attr, attr_size), kSuffix)) { + string field_name(attr, attr_size - sizeof(kSuffix) + 1); + LowerString(&field_name); + + // Try to find a field with the given name, without the suffix. + const FieldDescriptor* field = + self->message_descriptor->FindFieldByLowercaseName(field_name); + if (!field) { + // Search nested extensions as well. + field = + self->message_descriptor->FindExtensionByLowercaseName(field_name); + } + if (field) { + return PyInt_FromLong(field->number()); + } + } + PyErr_SetObject(PyExc_AttributeError, name); + return NULL; +} + +static PyObject* GetAttr(CMessageClass* self, PyObject* name) { + PyObject* result = CMessageClass_Type->tp_base->tp_getattro( + reinterpret_cast<PyObject*>(self), name); + if (result != NULL) { + return result; + } + if (!PyErr_ExceptionMatches(PyExc_AttributeError)) { + return NULL; + } + + PyErr_Clear(); + return GetClassAttribute(self, name); +} + } // namespace message_meta -PyTypeObject CMessageClass_Type = { +static PyTypeObject _CMessageClass_Type = { PyVarObject_HEAD_INIT(&PyType_Type, 0) FULL_MODULE_NAME ".MessageMeta", // tp_name sizeof(CMessageClass), // tp_basicsize @@ -424,7 +466,7 @@ PyTypeObject CMessageClass_Type = { 0, // tp_hash 0, // tp_call 0, // tp_str - 0, // tp_getattro + (getattrofunc)message_meta::GetAttr, // tp_getattro 0, // tp_setattro 0, // tp_as_buffer Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, // tp_flags @@ -447,9 +489,10 @@ PyTypeObject CMessageClass_Type = { 0, // tp_alloc message_meta::New, // tp_new }; +PyTypeObject* CMessageClass_Type = &_CMessageClass_Type; static CMessageClass* CheckMessageClass(PyTypeObject* cls) { - if (!PyObject_TypeCheck(cls, &CMessageClass_Type)) { + if (!PyObject_TypeCheck(cls, CMessageClass_Type)) { PyErr_Format(PyExc_TypeError, "Class %s is not a Message", cls->tp_name); return NULL; } @@ -487,10 +530,20 @@ struct ChildVisitor { } // Returns 0 on success, -1 on failure. + int VisitMapContainer(MapContainer* container) { + return 0; + } + + // Returns 0 on success, -1 on failure. int VisitCMessage(CMessage* cmessage, const FieldDescriptor* field_descriptor) { return 0; } + + // Returns 0 on success, -1 on failure. + int VisitUnknownFieldSet(PyUnknownFields* unknown_field_set) { + return 0; + } }; // Apply a function to a composite field. Does nothing if child is of @@ -538,34 +591,19 @@ int ForEachCompositeField(CMessage* self, Visitor visitor) { // Visit normal fields. if (self->composite_fields) { - // Never use self->message in this function, it may be already freed. - const Descriptor* message_descriptor = - GetMessageDescriptor(Py_TYPE(self)); - while (PyDict_Next(self->composite_fields, &pos, &key, &field)) { - Py_ssize_t key_str_size; - char *key_str_data; - if (PyString_AsStringAndSize(key, &key_str_data, &key_str_size) != 0) - return -1; - const string key_str(key_str_data, key_str_size); - const FieldDescriptor* descriptor = - message_descriptor->FindFieldByName(key_str); - if (descriptor != NULL) { - if (VisitCompositeField(descriptor, field, visitor) == -1) - return -1; - } + for (CMessage::CompositeFieldsMap::iterator it = + self->composite_fields->begin(); + it != self->composite_fields->end(); it++) { + const FieldDescriptor* descriptor = it->first; + PyObject* field = it->second; + if (VisitCompositeField(descriptor, field, visitor) == -1) return -1; } } - // Visit extension fields. - if (self->extensions != NULL) { - pos = 0; - while (PyDict_Next(self->extensions->values, &pos, &key, &field)) { - const FieldDescriptor* descriptor = cmessage::GetExtensionDescriptor(key); - if (descriptor == NULL) - return -1; - if (VisitCompositeField(descriptor, field, visitor) == -1) - return -1; - } + if (self->unknown_field_set) { + PyUnknownFields* unknown_field_set = + reinterpret_cast<PyUnknownFields*>(self->unknown_field_set); + visitor.VisitUnknownFieldSet(unknown_field_set); } return 0; @@ -577,8 +615,12 @@ PyObject* EncodeError_class; PyObject* DecodeError_class; PyObject* PickleError_class; -/* Is 64bit */ +// Format an error message for unexpected types. +// Always return with an exception set. void FormatTypeError(PyObject* arg, char* expected_types) { + // This function is often called with an exception set. + // Clear it to call PyObject_Repr() in good conditions. + PyErr_Clear(); PyObject* repr = PyObject_Repr(arg); if (repr) { PyErr_Format(PyExc_TypeError, @@ -602,7 +644,7 @@ void OutOfRangeError(PyObject* arg) { template<class RangeType, class ValueType> bool VerifyIntegerCastAndRange(PyObject* arg, ValueType value) { - if (GOOGLE_PREDICT_FALSE(value == -1 && PyErr_Occurred())) { + if (PROTOBUF_PREDICT_FALSE(value == -1 && PyErr_Occurred())) { if (PyErr_ExceptionMatches(PyExc_OverflowError)) { // Replace it with the same ValueError as pure python protos instead of // the default one. @@ -611,7 +653,7 @@ bool VerifyIntegerCastAndRange(PyObject* arg, ValueType value) { } // Otherwise propagate existing error. return false; } - if (GOOGLE_PREDICT_FALSE(!IsValidNumericCast<RangeType>(value))) { + if (PROTOBUF_PREDICT_FALSE(!IsValidNumericCast<RangeType>(value))) { OutOfRangeError(arg); return false; } @@ -623,22 +665,22 @@ bool CheckAndGetInteger(PyObject* arg, T* value) { // The fast path. #if PY_MAJOR_VERSION < 3 // For the typical case, offer a fast path. - if (GOOGLE_PREDICT_TRUE(PyInt_Check(arg))) { + if (PROTOBUF_PREDICT_TRUE(PyInt_Check(arg))) { long int_result = PyInt_AsLong(arg); - if (GOOGLE_PREDICT_TRUE(IsValidNumericCast<T>(int_result))) { + if (PROTOBUF_PREDICT_TRUE(IsValidNumericCast<T>(int_result))) { *value = static_cast<T>(int_result); return true; } else { OutOfRangeError(arg); return false; } - } + } #endif // This effectively defines an integer as "an object that can be cast as // an integer and can be used as an ordinal number". // This definition includes everything that implements numbers.Integral // and shouldn't cast the net too wide. - if (GOOGLE_PREDICT_FALSE(!PyIndex_Check(arg))) { + if (PROTOBUF_PREDICT_FALSE(!PyIndex_Check(arg))) { FormatTypeError(arg, "int, long"); return false; } @@ -655,7 +697,7 @@ bool CheckAndGetInteger(PyObject* arg, T* value) { // Unlike PyLong_AsLongLong, PyLong_AsUnsignedLongLong is very // picky about the exact type. PyObject* casted = PyNumber_Long(arg); - if (GOOGLE_PREDICT_FALSE(casted == nullptr)) { + if (PROTOBUF_PREDICT_FALSE(casted == nullptr)) { // Propagate existing error. return false; } @@ -680,7 +722,7 @@ bool CheckAndGetInteger(PyObject* arg, T* value) { // Valid subclasses of numbers.Integral should have a __long__() method // so fall back to that. PyObject* casted = PyNumber_Long(arg); - if (GOOGLE_PREDICT_FALSE(casted == nullptr)) { + if (PROTOBUF_PREDICT_FALSE(casted == nullptr)) { // Propagate existing error. return false; } @@ -706,7 +748,7 @@ template bool CheckAndGetInteger<uint64>(PyObject*, uint64*); bool CheckAndGetDouble(PyObject* arg, double* value) { *value = PyFloat_AsDouble(arg); - if (GOOGLE_PREDICT_FALSE(*value == -1 && PyErr_Occurred())) { + if (PROTOBUF_PREDICT_FALSE(*value == -1 && PyErr_Occurred())) { FormatTypeError(arg, "int, long, float"); return false; } @@ -859,7 +901,7 @@ bool CheckFieldBelongsToMessage(const FieldDescriptor* field_descriptor, namespace cmessage { PyMessageFactory* GetFactoryForMessage(CMessage* message) { - GOOGLE_DCHECK(PyObject_TypeCheck(message, &CMessage_Type)); + GOOGLE_DCHECK(PyObject_TypeCheck(message, CMessage_Type)); return reinterpret_cast<CMessageClass*>(Py_TYPE(message))->py_message_factory; } @@ -883,22 +925,20 @@ static int MaybeReleaseOverlappingOneofField( // Non-message fields don't need to be released. return 0; } - const char* field_name = existing_field->name().c_str(); - PyObject* child_message = cmessage->composite_fields ? - PyDict_GetItemString(cmessage->composite_fields, field_name) : NULL; - if (child_message == NULL) { - // No python reference to this field so no need to release. - return 0; - } - - if (InternalReleaseFieldByDescriptor( - cmessage, existing_field, child_message) < 0) { - return -1; + if (cmessage->composite_fields) { + CMessage::CompositeFieldsMap::iterator iterator = + cmessage->composite_fields->find(existing_field); + if (iterator != cmessage->composite_fields->end()) { + if (InternalReleaseFieldByDescriptor(cmessage, existing_field, + iterator->second) < 0) { + return -1; + } + Py_DECREF(iterator->second); + cmessage->composite_fields->erase(iterator); + } } - return PyDict_DelItemString(cmessage->composite_fields, field_name); -#else - return 0; #endif + return 0; } // --------------------------------------------------------------------- @@ -937,10 +977,49 @@ struct FixupMessageReference : public ChildVisitor { return 0; } + int VisitUnknownFieldSet(PyUnknownFields* unknown_field_set) { + const Reflection* reflection = message_->GetReflection(); + unknown_field_set->fields = &reflection->GetUnknownFields(*message_); + return 0; + } + private: Message* message_; }; +// After a Merge, visit every sub-message that was read-only, and +// eventually update their pointer if the Merge operation modified them. +struct FixupMessageAfterMerge : public FixupMessageReference { + explicit FixupMessageAfterMerge(CMessage* parent) : + FixupMessageReference(parent->message), + parent_cmessage(parent), message(parent->message) {} + + int VisitCMessage(CMessage* cmessage, + const FieldDescriptor* field_descriptor) { + if (cmessage->read_only == false) { + return 0; + } + if (message->GetReflection()->HasField(*message, field_descriptor)) { + Message* mutable_message = GetMutableMessage( + parent_cmessage, field_descriptor); + if (mutable_message == NULL) { + return -1; + } + cmessage->message = mutable_message; + cmessage->read_only = false; + if (ForEachCompositeField( + cmessage, FixupMessageAfterMerge(cmessage)) == -1) { + return -1; + } + } + return 0; + } + + private: + CMessage* parent_cmessage; + Message* message; +}; + int AssureWritable(CMessage* self) { if (self == NULL || !self->read_only) { return 0; @@ -974,10 +1053,8 @@ int AssureWritable(CMessage* self) { // When a CMessage is made writable its Message pointer is updated // to point to a new mutable Message. When that happens we need to // update any references to the old, read-only CMessage. There are - // four places such references occur: RepeatedScalarContainer, - // RepeatedCompositeContainer, MapContainer, and ExtensionDict. - if (self->extensions != NULL) - self->extensions->message = self->message; + // three places such references occur: RepeatedScalarContainer, + // RepeatedCompositeContainer, and MapContainer. if (ForEachCompositeField(self, FixupMessageReference(self->message)) == -1) return -1; @@ -986,27 +1063,6 @@ int AssureWritable(CMessage* self) { // --- Globals: -// Retrieve a C++ FieldDescriptor for a message attribute. -// The C++ message must be valid. -// TODO(amauryfa): This function should stay internal, because exception -// handling is not consistent. -static const FieldDescriptor* GetFieldDescriptor( - CMessage* self, PyObject* name) { - const Descriptor *message_descriptor = self->message->GetDescriptor(); - char* field_name; - Py_ssize_t size; - if (PyString_AsStringAndSize(name, &field_name, &size) < 0) { - return NULL; - } - const FieldDescriptor *field_descriptor = - message_descriptor->FindFieldByName(string(field_name, size)); - if (field_descriptor == NULL) { - // Note: No exception is set! - return NULL; - } - return field_descriptor; -} - // Retrieve a C++ FieldDescriptor for an extension handle. const FieldDescriptor* GetExtensionDescriptor(PyObject* extension) { ScopedPyObjectPtr cdescriptor; @@ -1038,7 +1094,7 @@ static PyObject* GetIntegerEnumValue(const FieldDescriptor& descriptor, const EnumValueDescriptor* enum_value_descriptor = enum_descriptor->FindValueByName(string(enum_label, size)); if (enum_value_descriptor == NULL) { - PyErr_SetString(PyExc_ValueError, "unknown enum label"); + PyErr_Format(PyExc_ValueError, "unknown enum label \"%s\"", enum_label); return NULL; } return PyInt_FromLong(enum_value_descriptor->number()); @@ -1052,11 +1108,10 @@ static PyObject* GetIntegerEnumValue(const FieldDescriptor& descriptor, // needs to do this to make sure CMessages stay alive if they're still // referenced after deletion. Repeated scalar container doesn't need to worry. int InternalDeleteRepeatedField( - CMessage* self, + Message* message, const FieldDescriptor* field_descriptor, PyObject* slice, PyObject* cmessage_list) { - Message* message = self->message; Py_ssize_t length, from, to, step, slice_length; const Reflection* reflection = message->GetReflection(); int min, max; @@ -1134,7 +1189,7 @@ int InternalDeleteRepeatedField( CMessage* last_cmessage = reinterpret_cast<CMessage*>( PyList_GET_ITEM(cmessage_list, PyList_GET_SIZE(cmessage_list) - 1)); repeated_composite_container::ReleaseLastTo( - self, field_descriptor, last_cmessage); + message, field_descriptor, last_cmessage); if (PySequence_DelItem(cmessage_list, -1) < 0) { return -1; } @@ -1160,23 +1215,28 @@ int InitAttributes(CMessage* self, PyObject* args, PyObject* kwargs) { PyObject* name; PyObject* value; while (PyDict_Next(kwargs, &pos, &name, &value)) { - if (!PyString_Check(name)) { + if (!(PyString_Check(name) || PyUnicode_Check(name))) { PyErr_SetString(PyExc_ValueError, "Field name must be a string"); return -1; } - const FieldDescriptor* descriptor = GetFieldDescriptor(self, name); - if (descriptor == NULL) { + ScopedPyObjectPtr property( + PyObject_GetAttr(reinterpret_cast<PyObject*>(Py_TYPE(self)), name)); + if (property == NULL || + !PyObject_TypeCheck(property.get(), CFieldProperty_Type)) { PyErr_Format(PyExc_ValueError, "Protocol message %s has no \"%s\" field.", self->message->GetDescriptor()->name().c_str(), PyString_AsString(name)); return -1; } + const FieldDescriptor* descriptor = + reinterpret_cast<PyMessageFieldProperty*>(property.get()) + ->field_descriptor; if (value == Py_None) { // field=None is the same as no field at all. continue; } if (descriptor->is_map()) { - ScopedPyObjectPtr map(GetAttr(reinterpret_cast<PyObject*>(self), name)); + ScopedPyObjectPtr map(GetFieldValue(self, descriptor)); const FieldDescriptor* value_descriptor = descriptor->message_type()->FindFieldByName("value"); if (value_descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) { @@ -1204,8 +1264,7 @@ int InitAttributes(CMessage* self, PyObject* args, PyObject* kwargs) { } } } else if (descriptor->label() == FieldDescriptor::LABEL_REPEATED) { - ScopedPyObjectPtr container( - GetAttr(reinterpret_cast<PyObject*>(self), name)); + ScopedPyObjectPtr container(GetFieldValue(self, descriptor)); if (container == NULL) { return -1; } @@ -1272,8 +1331,7 @@ int InitAttributes(CMessage* self, PyObject* args, PyObject* kwargs) { } } } else if (descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) { - ScopedPyObjectPtr message( - GetAttr(reinterpret_cast<PyObject*>(self), name)); + ScopedPyObjectPtr message(GetFieldValue(self, descriptor)); if (message == NULL) { return -1; } @@ -1297,9 +1355,9 @@ int InitAttributes(CMessage* self, PyObject* args, PyObject* kwargs) { if (new_val == NULL) { return -1; } + value = new_val.get(); } - if (SetAttr(reinterpret_cast<PyObject*>(self), name, - (new_val.get() == NULL) ? value : new_val.get()) < 0) { + if (SetFieldValue(self, descriptor, value) < 0) { return -1; } } @@ -1322,10 +1380,11 @@ CMessage* NewEmptyMessage(CMessageClass* type) { self->parent = NULL; self->parent_field_descriptor = NULL; self->read_only = false; - self->extensions = NULL; self->composite_fields = NULL; + self->unknown_field_set = NULL; + return self; } @@ -1408,12 +1467,20 @@ static void Dealloc(CMessage* self) { } // Null out all weak references from children to this message. GOOGLE_CHECK_EQ(0, ForEachCompositeField(self, ClearWeakReferences())); - if (self->extensions) { - self->extensions->parent = NULL; - } - Py_CLEAR(self->extensions); - Py_CLEAR(self->composite_fields); + if (self->composite_fields) { + for (CMessage::CompositeFieldsMap::iterator it = + self->composite_fields->begin(); + it != self->composite_fields->end(); it++) { + Py_DECREF(it->second); + } + delete self->composite_fields; + } + if (self->unknown_field_set) { + unknown_fields::Clear( + reinterpret_cast<PyUnknownFields*>(self->unknown_field_set)); + Py_CLEAR(self->unknown_field_set); + } self->owner.~ThreadUnsafeSharedPtr<Message>(); Py_TYPE(self)->tp_free(reinterpret_cast<PyObject*>(self)); } @@ -1529,7 +1596,7 @@ PyObject* HasField(CMessage* self, PyObject* arg) { return NULL; } #else - field_name = PyUnicode_AsUTF8AndSize(arg, &size); + field_name = const_cast<char*>(PyUnicode_AsUTF8AndSize(arg, &size)); if (!field_name) { return NULL; } @@ -1564,13 +1631,16 @@ PyObject* ClearExtension(CMessage* self, PyObject* extension) { if (descriptor == NULL) { return NULL; } - if (self->extensions != NULL) { - PyObject* value = PyDict_GetItem(self->extensions->values, extension); - if (value != NULL) { - if (InternalReleaseFieldByDescriptor(self, descriptor, value) < 0) { + if (self->composite_fields != NULL) { + CMessage::CompositeFieldsMap::iterator iterator = + self->composite_fields->find(descriptor); + if (iterator != self->composite_fields->end()) { + if (InternalReleaseFieldByDescriptor(self, descriptor, + iterator->second) < 0) { return NULL; } - PyDict_DelItem(self->extensions->values, extension); + Py_DECREF(iterator->second); + self->composite_fields->erase(iterator); } } return ClearFieldByDescriptor(self, descriptor); @@ -1739,13 +1809,16 @@ PyObject* ClearFieldByDescriptor( } PyObject* ClearField(CMessage* self, PyObject* arg) { - if (!PyString_Check(arg)) { + if (!(PyString_Check(arg) || PyUnicode_Check(arg))) { PyErr_SetString(PyExc_TypeError, "field name must be a string"); return NULL; } #if PY_MAJOR_VERSION < 3 - const char* field_name = PyString_AS_STRING(arg); - Py_ssize_t size = PyString_GET_SIZE(arg); + char* field_name; + Py_ssize_t size; + if (PyString_AsStringAndSize(arg, &field_name, &size) < 0) { + return NULL; + } #else Py_ssize_t size; const char* field_name = PyUnicode_AsUTF8AndSize(arg, &size); @@ -1770,14 +1843,16 @@ PyObject* ClearField(CMessage* self, PyObject* arg) { arg = arg_in_oneof.get(); } - // Release the field if it exists in the dict of composite fields. if (self->composite_fields) { - PyObject* value = PyDict_GetItem(self->composite_fields, arg); - if (value != NULL) { - if (InternalReleaseFieldByDescriptor(self, field_descriptor, value) < 0) { + CMessage::CompositeFieldsMap::iterator iterator = + self->composite_fields->find(field_descriptor); + if (iterator != self->composite_fields->end()) { + if (InternalReleaseFieldByDescriptor(self, field_descriptor, + iterator->second) < 0) { return NULL; } - PyDict_DelItem(self->composite_fields, arg); + Py_DECREF(iterator->second); + self->composite_fields->erase(iterator); } } return ClearFieldByDescriptor(self, field_descriptor); @@ -1787,9 +1862,18 @@ PyObject* Clear(CMessage* self) { AssureWritable(self); if (ForEachCompositeField(self, ReleaseChild(self)) == -1) return NULL; - Py_CLEAR(self->extensions); if (self->composite_fields) { - PyDict_Clear(self->composite_fields); + for (CMessage::CompositeFieldsMap::iterator it = + self->composite_fields->begin(); + it != self->composite_fields->end(); it++) { + Py_DECREF(it->second); + } + self->composite_fields->clear(); + } + if (self->unknown_field_set) { + unknown_fields::Clear( + reinterpret_cast<PyUnknownFields*>(self->unknown_field_set)); + Py_CLEAR(self->unknown_field_set); } self->message->Clear(); Py_RETURN_NONE; @@ -1946,7 +2030,7 @@ static PyObject* ToStr(CMessage* self) { PyObject* MergeFrom(CMessage* self, PyObject* arg) { CMessage* other_message; - if (!PyObject_TypeCheck(arg, &CMessage_Type)) { + if (!PyObject_TypeCheck(arg, CMessage_Type)) { PyErr_Format(PyExc_TypeError, "Parameter to MergeFrom() must be instance of same class: " "expected %s got %s.", @@ -1967,18 +2051,19 @@ PyObject* MergeFrom(CMessage* self, PyObject* arg) { } AssureWritable(self); - // TODO(tibell): Message::MergeFrom might turn some child Messages - // into mutable messages, invalidating the message field in the - // corresponding CMessages. We should run a FixupMessageReferences - // pass here. - self->message->MergeFrom(*other_message->message); + // Child message might be lazily created before MergeFrom. Make sure they + // are mutable at this point if child messages are really created. + if (ForEachCompositeField(self, FixupMessageAfterMerge(self)) == -1) { + return NULL; + } + Py_RETURN_NONE; } static PyObject* CopyFrom(CMessage* self, PyObject* arg) { CMessage* other_message; - if (!PyObject_TypeCheck(arg, &CMessage_Type)) { + if (!PyObject_TypeCheck(arg, CMessage_Type)) { PyErr_Format(PyExc_TypeError, "Parameter to CopyFrom() must be instance of same class: " "expected %s got %s.", @@ -2050,6 +2135,7 @@ static PyObject* MergeFromString(CMessage* self, PyObject* arg) { } AssureWritable(self); + io::CodedInputStream input( reinterpret_cast<const uint8*>(data), data_length); if (allow_oversize_protos) { @@ -2058,6 +2144,12 @@ static PyObject* MergeFromString(CMessage* self, PyObject* arg) { PyMessageFactory* factory = GetFactoryForMessage(self); input.SetExtensionRegistry(factory->pool->pool, factory->message_factory); bool success = self->message->MergePartialFromCodedStream(&input); + // Child message might be lazily created before MergeFrom. Make sure they + // are mutable at this point if child messages are really created. + if (ForEachCompositeField(self, FixupMessageAfterMerge(self)) == -1) { + return NULL; + } + if (success) { if (!input.ConsumedEntireMessage()) { // TODO(jieluo): Raise error and return NULL instead. @@ -2088,7 +2180,7 @@ PyObject* RegisterExtension(PyObject* cls, PyObject* extension_handle) { if (descriptor == NULL) { return NULL; } - if (!PyObject_TypeCheck(cls, &CMessageClass_Type)) { + if (!PyObject_TypeCheck(cls, CMessageClass_Type)) { PyErr_Format(PyExc_TypeError, "Expected a message class, got %s", cls->ob_type->tp_name); return NULL; @@ -2192,23 +2284,15 @@ static PyObject* ListFields(CMessage* self) { PyTuple_SET_ITEM(t.get(), 1, extension); } else { // Normal field - const string& field_name = fields[i]->name(); - ScopedPyObjectPtr py_field_name(PyString_FromStringAndSize( - field_name.c_str(), field_name.length())); - if (py_field_name == NULL) { - PyErr_SetString(PyExc_ValueError, "bad string"); - return NULL; - } ScopedPyObjectPtr field_descriptor( PyFieldDescriptor_FromDescriptor(fields[i])); if (field_descriptor == NULL) { return NULL; } - PyObject* field_value = - GetAttr(reinterpret_cast<PyObject*>(self), py_field_name.get()); + PyObject* field_value = GetFieldValue(self, fields[i]); if (field_value == NULL) { - PyErr_SetObject(PyExc_ValueError, py_field_name.get()); + PyErr_SetString(PyExc_ValueError, fields[i]->name().c_str()); return NULL; } PyTuple_SET_ITEM(t.get(), 0, field_descriptor.release()); @@ -2261,7 +2345,7 @@ static PyObject* RichCompare(CMessage* self, PyObject* other, int opid) { } bool equals = true; // If other is not a message, it cannot be equal. - if (!PyObject_TypeCheck(other, &CMessage_Type)) { + if (!PyObject_TypeCheck(other, CMessage_Type)) { equals = false; } const google::protobuf::Message* other_message = @@ -2277,6 +2361,7 @@ static PyObject* RichCompare(CMessage* self, PyObject* other, int opid) { *reinterpret_cast<CMessage*>(other)->message)) { equals = false; } + if (equals ^ (opid == Py_EQ)) { Py_RETURN_FALSE; } else { @@ -2498,7 +2583,7 @@ PyObject* DeepCopy(CMessage* self, PyObject* arg) { if (clone == NULL) { return NULL; } - if (!PyObject_TypeCheck(clone, &CMessage_Type)) { + if (!PyObject_TypeCheck(clone, CMessage_Type)) { Py_DECREF(clone); return NULL; } @@ -2592,26 +2677,29 @@ PyObject* _CheckCalledFromGeneratedFile(PyObject* unused, } static PyObject* GetExtensionDict(CMessage* self, void *closure) { - if (self->extensions) { - Py_INCREF(self->extensions); - return reinterpret_cast<PyObject*>(self->extensions); - } - // If there are extension_ranges, the message is "extendable". Allocate a // dictionary to store the extension fields. const Descriptor* descriptor = GetMessageDescriptor(Py_TYPE(self)); - if (descriptor->extension_range_count() > 0) { - ExtensionDict* extension_dict = extension_dict::NewExtensionDict(self); - if (extension_dict == NULL) { - return NULL; - } - self->extensions = extension_dict; - Py_INCREF(self->extensions); - return reinterpret_cast<PyObject*>(self->extensions); + if (!descriptor->extension_range_count()) { + PyErr_SetNone(PyExc_AttributeError); + return NULL; + } + if (!self->composite_fields) { + self->composite_fields = new CMessage::CompositeFieldsMap(); } + if (!self->composite_fields) { + return NULL; + } + ExtensionDict* extension_dict = extension_dict::NewExtensionDict(self); + return reinterpret_cast<PyObject*>(extension_dict); +} - PyErr_SetNone(PyExc_AttributeError); - return NULL; +static PyObject* UnknownFieldSet(CMessage* self) { + if (self->unknown_field_set == NULL) { + self->unknown_field_set = unknown_fields::NewPyUnknownFields(self); + } + Py_INCREF(self->unknown_field_set); + return self->unknown_field_set; } static PyObject* GetExtensionsByName(CMessage *self, void *closure) { @@ -2682,6 +2770,8 @@ static PyMethodDef Methods[] = { "Serializes the message to a string, only for initialized messages." }, { "SetInParent", (PyCFunction)SetInParent, METH_NOARGS, "Sets the has bit of the given field in its parent message." }, + { "UnknownFields", (PyCFunction)UnknownFieldSet, METH_NOARGS, + "Parse unknown field set"}, { "WhichOneof", (PyCFunction)WhichOneof, METH_O, "Returns the name of the field set inside a oneof, " "or None if no field is set." }, @@ -2693,30 +2783,53 @@ static PyMethodDef Methods[] = { { NULL, NULL} }; -static bool SetCompositeField( - CMessage* self, PyObject* name, PyObject* value) { +static bool SetCompositeField(CMessage* self, const FieldDescriptor* field, + PyObject* value) { if (self->composite_fields == NULL) { - self->composite_fields = PyDict_New(); - if (self->composite_fields == NULL) { - return false; - } + self->composite_fields = new CMessage::CompositeFieldsMap(); } - return PyDict_SetItem(self->composite_fields, name, value) == 0; + Py_INCREF(value); + Py_XDECREF((*self->composite_fields)[field]); + (*self->composite_fields)[field] = value; + return true; } PyObject* GetAttr(PyObject* pself, PyObject* name) { CMessage* self = reinterpret_cast<CMessage*>(pself); - PyObject* value = self->composite_fields ? - PyDict_GetItem(self->composite_fields, name) : NULL; - if (value != NULL) { - Py_INCREF(value); - return value; + PyObject* result = PyObject_GenericGetAttr( + reinterpret_cast<PyObject*>(self), name); + if (result != NULL) { + return result; + } + if (!PyErr_ExceptionMatches(PyExc_AttributeError)) { + return NULL; } - const FieldDescriptor* field_descriptor = GetFieldDescriptor(self, name); - if (field_descriptor == NULL) { - return CMessage_Type.tp_base->tp_getattro( - reinterpret_cast<PyObject*>(self), name); + PyErr_Clear(); + return message_meta::GetClassAttribute( + CheckMessageClass(Py_TYPE(self)), name); +} + +PyObject* GetFieldValue(CMessage* self, + const FieldDescriptor* field_descriptor) { + if (self->composite_fields) { + CMessage::CompositeFieldsMap::iterator it = + self->composite_fields->find(field_descriptor); + if (it != self->composite_fields->end()) { + PyObject* value = it->second; + Py_INCREF(value); + return value; + } + } + + const Descriptor* message_descriptor = + (reinterpret_cast<CMessageClass*>(Py_TYPE(self)))->message_descriptor; + if (self->message->GetDescriptor() != field_descriptor->containing_type()) { + PyErr_Format(PyExc_TypeError, + "descriptor to field '%s' doesn't apply to '%s' object", + field_descriptor->full_name().c_str(), + Py_TYPE(self)->tp_name); + return NULL; } if (field_descriptor->is_map()) { @@ -2737,7 +2850,7 @@ PyObject* GetAttr(PyObject* pself, PyObject* name) { if (py_container == NULL) { return NULL; } - if (!SetCompositeField(self, name, py_container)) { + if (!SetCompositeField(self, field_descriptor, py_container)) { Py_DECREF(py_container); return NULL; } @@ -2761,7 +2874,7 @@ PyObject* GetAttr(PyObject* pself, PyObject* name) { if (py_container == NULL) { return NULL; } - if (!SetCompositeField(self, name, py_container)) { + if (!SetCompositeField(self, field_descriptor, py_container)) { Py_DECREF(py_container); return NULL; } @@ -2773,7 +2886,7 @@ PyObject* GetAttr(PyObject* pself, PyObject* name) { if (sub_message == NULL) { return NULL; } - if (!SetCompositeField(self, name, sub_message)) { + if (!SetCompositeField(self, field_descriptor, sub_message)) { Py_DECREF(sub_message); return NULL; } @@ -2783,44 +2896,35 @@ PyObject* GetAttr(PyObject* pself, PyObject* name) { return InternalGetScalar(self->message, field_descriptor); } -int SetAttr(PyObject* pself, PyObject* name, PyObject* value) { - CMessage* self = reinterpret_cast<CMessage*>(pself); - if (self->composite_fields && PyDict_Contains(self->composite_fields, name)) { - PyErr_SetString(PyExc_TypeError, "Can't set composite field"); +int SetFieldValue(CMessage* self, const FieldDescriptor* field_descriptor, + PyObject* value) { + if (self->message->GetDescriptor() != field_descriptor->containing_type()) { + PyErr_Format(PyExc_TypeError, + "descriptor to field '%s' doesn't apply to '%s' object", + field_descriptor->full_name().c_str(), + Py_TYPE(self)->tp_name); return -1; - } - - const FieldDescriptor* field_descriptor = GetFieldDescriptor(self, name); - if (field_descriptor != NULL) { + } else if (field_descriptor->label() == FieldDescriptor::LABEL_REPEATED) { + PyErr_Format(PyExc_AttributeError, + "Assignment not allowed to repeated " + "field \"%s\" in protocol message object.", + field_descriptor->name().c_str()); + return -1; + } else if (field_descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) { + PyErr_Format(PyExc_AttributeError, + "Assignment not allowed to " + "field \"%s\" in protocol message object.", + field_descriptor->name().c_str()); + return -1; + } else { AssureWritable(self); - if (field_descriptor->label() == FieldDescriptor::LABEL_REPEATED) { - PyErr_Format(PyExc_AttributeError, "Assignment not allowed to repeated " - "field \"%s\" in protocol message object.", - field_descriptor->name().c_str()); - return -1; - } else { - if (field_descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) { - PyErr_Format(PyExc_AttributeError, "Assignment not allowed to " - "field \"%s\" in protocol message object.", - field_descriptor->name().c_str()); - return -1; - } else { - return InternalSetScalar(self, field_descriptor, value); - } - } + return InternalSetScalar(self, field_descriptor, value); } - - PyErr_Format(PyExc_AttributeError, - "Assignment not allowed " - "(no field \"%s\" in protocol message object).", - PyString_AsString(name)); - return -1; } - } // namespace cmessage -PyTypeObject CMessage_Type = { - PyVarObject_HEAD_INIT(&CMessageClass_Type, 0) +static CMessageClass _CMessage_Type = { { { + PyVarObject_HEAD_INIT(&_CMessageClass_Type, 0) FULL_MODULE_NAME ".CMessage", // tp_name sizeof(CMessage), // tp_basicsize 0, // tp_itemsize @@ -2837,9 +2941,10 @@ PyTypeObject CMessage_Type = { 0, // tp_call (reprfunc)cmessage::ToStr, // tp_str cmessage::GetAttr, // tp_getattro - cmessage::SetAttr, // tp_setattro + 0, // tp_setattro 0, // tp_as_buffer - Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, // tp_flags + Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE + | Py_TPFLAGS_HAVE_VERSION_TAG, // tp_flags "A ProtocolMessage", // tp_doc 0, // tp_traverse 0, // tp_clear @@ -2858,7 +2963,8 @@ PyTypeObject CMessage_Type = { (initproc)cmessage::Init, // tp_init 0, // tp_alloc cmessage::New, // tp_new -}; +} } }; +PyTypeObject* CMessage_Type = &_CMessage_Type.super.ht_type; // --- Exposing the C proto living inside Python proto to C code: @@ -2884,7 +2990,7 @@ static Message* MutableCProtoInsidePyProtoImpl(PyObject* msg) { } const Message* PyMessage_GetMessagePointer(PyObject* msg) { - if (!PyObject_TypeCheck(msg, &CMessage_Type)) { + if (!PyObject_TypeCheck(msg, CMessage_Type)) { PyErr_SetString(PyExc_TypeError, "Not a Message instance"); return NULL; } @@ -2893,15 +2999,14 @@ const Message* PyMessage_GetMessagePointer(PyObject* msg) { } Message* PyMessage_GetMutableMessagePointer(PyObject* msg) { - if (!PyObject_TypeCheck(msg, &CMessage_Type)) { + if (!PyObject_TypeCheck(msg, CMessage_Type)) { PyErr_SetString(PyExc_TypeError, "Not a Message instance"); return NULL; } + CMessage* cmsg = reinterpret_cast<CMessage*>(msg); - if ((cmsg->composite_fields && PyDict_Size(cmsg->composite_fields) != 0) || - (cmsg->extensions != NULL && - PyDict_Size(cmsg->extensions->values) != 0)) { + if (cmsg->composite_fields && !cmsg->composite_fields->empty()) { // There is currently no way of accurately syncing arbitrary changes to // the underlying C++ message back to the CMessage (e.g. removed repeated // composite containers). We only allow direct mutation of the underlying @@ -2945,22 +3050,29 @@ bool InitProto2MessageModule(PyObject *m) { // Initialize constants defined in this file. InitGlobals(); - CMessageClass_Type.tp_base = &PyType_Type; - if (PyType_Ready(&CMessageClass_Type) < 0) { + CMessageClass_Type->tp_base = &PyType_Type; + if (PyType_Ready(CMessageClass_Type) < 0) { return false; } PyModule_AddObject(m, "MessageMeta", - reinterpret_cast<PyObject*>(&CMessageClass_Type)); + reinterpret_cast<PyObject*>(CMessageClass_Type)); - if (PyType_Ready(&CMessage_Type) < 0) { + if (PyType_Ready(CMessage_Type) < 0) { + return false; + } + if (PyType_Ready(CFieldProperty_Type) < 0) { return false; } // DESCRIPTOR is set on each protocol buffer message class elsewhere, but set // it here as well to document that subclasses need to set it. - PyDict_SetItem(CMessage_Type.tp_dict, kDESCRIPTOR, Py_None); + PyDict_SetItem(CMessage_Type->tp_dict, kDESCRIPTOR, Py_None); + // Invalidate any cached data for the CMessage type. + // This call is necessary to correctly support Py_TPFLAGS_HAVE_VERSION_TAG, + // after we have modified CMessage_Type.tp_dict. + PyType_Modified(CMessage_Type); - PyModule_AddObject(m, "Message", reinterpret_cast<PyObject*>(&CMessage_Type)); + PyModule_AddObject(m, "Message", reinterpret_cast<PyObject*>(CMessage_Type)); // Initialize Repeated container types. { @@ -3003,6 +3115,22 @@ bool InitProto2MessageModule(PyObject *m) { } } + if (PyType_Ready(&PyUnknownFields_Type) < 0) { + return false; + } + + PyModule_AddObject(m, "UnknownFieldSet", + reinterpret_cast<PyObject*>( + &PyUnknownFields_Type)); + + if (PyType_Ready(&PyUnknownFieldRef_Type) < 0) { + return false; + } + + PyModule_AddObject(m, "UnknownField", + reinterpret_cast<PyObject*>( + &PyUnknownFieldRef_Type)); + // Initialize Map container types. if (!InitMapContainers()) { return false; |