diff options
author | Adam Cozzette <acozzette@google.com> | 2016-11-17 16:48:38 -0800 |
---|---|---|
committer | Adam Cozzette <acozzette@google.com> | 2016-11-17 16:59:59 -0800 |
commit | 5a76e633ea9b5adb215e93fdc11e1c0c08b3fc74 (patch) | |
tree | 0276f81f8848a05d84cd7e287b43d665e30f04e3 /python/google/protobuf/pyext/message.cc | |
parent | e28286fa05d8327fd6c5aa70cfb3be558f0932b8 (diff) | |
download | protobuf-5a76e633ea9b5adb215e93fdc11e1c0c08b3fc74.tar.gz protobuf-5a76e633ea9b5adb215e93fdc11e1c0c08b3fc74.tar.bz2 protobuf-5a76e633ea9b5adb215e93fdc11e1c0c08b3fc74.zip |
Integrated internal changes from Google
Diffstat (limited to 'python/google/protobuf/pyext/message.cc')
-rw-r--r-- | python/google/protobuf/pyext/message.cc | 409 |
1 files changed, 205 insertions, 204 deletions
diff --git a/python/google/protobuf/pyext/message.cc b/python/google/protobuf/pyext/message.cc index 7ff99aea..5967a587 100644 --- a/python/google/protobuf/pyext/message.cc +++ b/python/google/protobuf/pyext/message.cc @@ -64,11 +64,11 @@ #include <google/protobuf/pyext/repeated_scalar_container.h> #include <google/protobuf/pyext/map_container.h> #include <google/protobuf/pyext/message_factory.h> +#include <google/protobuf/pyext/safe_numerics.h> #include <google/protobuf/pyext/scoped_pyobject_ptr.h> #include <google/protobuf/stubs/strutil.h> #if PY_MAJOR_VERSION >= 3 - #define PyInt_Check PyLong_Check #define PyInt_AsLong PyLong_AsLong #define PyInt_FromLong PyLong_FromLong #define PyInt_FromSize_t PyLong_FromSize_t @@ -92,8 +92,6 @@ namespace protobuf { namespace python { static PyObject* kDESCRIPTOR; -static PyObject* k_extensions_by_name; -static PyObject* k_extensions_by_number; PyObject* EnumTypeWrapper_class; static PyObject* PythonMessage_class; static PyObject* kEmptyWeakref; @@ -128,19 +126,6 @@ static bool AddFieldNumberToClass( // Finalize the creation of the Message class. static int AddDescriptors(PyObject* cls, const Descriptor* descriptor) { - // If there are extension_ranges, the message is "extendable", and extension - // classes will register themselves in this class. - if (descriptor->extension_range_count() > 0) { - ScopedPyObjectPtr by_name(PyDict_New()); - if (PyObject_SetAttr(cls, k_extensions_by_name, by_name.get()) < 0) { - return -1; - } - ScopedPyObjectPtr by_number(PyDict_New()); - if (PyObject_SetAttr(cls, k_extensions_by_number, by_number.get()) < 0) { - return -1; - } - } - // For each field set: cls.<field>_FIELD_NUMBER = <number> for (int i = 0; i < descriptor->field_count(); ++i) { if (!AddFieldNumberToClass(cls, descriptor->field(i))) { @@ -357,6 +342,61 @@ static int InsertEmptyWeakref(PyTypeObject *base_type) { #endif // PY_MAJOR_VERSION >= 3 } +// The _extensions_by_name dictionary is built on every access. +// TODO(amauryfa): Migrate all users to pool.FindAllExtensions() +static PyObject* GetExtensionsByName(CMessageClass *self, void *closure) { + const PyDescriptorPool* pool = self->py_message_factory->pool; + + std::vector<const FieldDescriptor*> extensions; + pool->pool->FindAllExtensions(self->message_descriptor, &extensions); + + ScopedPyObjectPtr result(PyDict_New()); + for (int i = 0; i < extensions.size(); i++) { + ScopedPyObjectPtr extension( + PyFieldDescriptor_FromDescriptor(extensions[i])); + if (extension == NULL) { + return NULL; + } + if (PyDict_SetItemString(result.get(), extensions[i]->full_name().c_str(), + extension.get()) < 0) { + return NULL; + } + } + return result.release(); +} + +// The _extensions_by_number dictionary is built on every access. +// TODO(amauryfa): Migrate all users to pool.FindExtensionByNumber() +static PyObject* GetExtensionsByNumber(CMessageClass *self, void *closure) { + const PyDescriptorPool* pool = self->py_message_factory->pool; + + std::vector<const FieldDescriptor*> extensions; + pool->pool->FindAllExtensions(self->message_descriptor, &extensions); + + ScopedPyObjectPtr result(PyDict_New()); + for (int i = 0; i < extensions.size(); i++) { + ScopedPyObjectPtr extension( + PyFieldDescriptor_FromDescriptor(extensions[i])); + if (extension == NULL) { + return NULL; + } + ScopedPyObjectPtr number(PyInt_FromLong(extensions[i]->number())); + if (number == NULL) { + return NULL; + } + if (PyDict_SetItem(result.get(), number.get(), extension.get()) < 0) { + return NULL; + } + } + return result.release(); +} + +static PyGetSetDef Getters[] = { + {"_extensions_by_name", (getter)GetExtensionsByName, NULL}, + {"_extensions_by_number", (getter)GetExtensionsByNumber, NULL}, + {NULL} +}; + } // namespace message_meta PyTypeObject CMessageClass_Type = { @@ -389,7 +429,7 @@ PyTypeObject CMessageClass_Type = { 0, // tp_iternext 0, // tp_methods 0, // tp_members - 0, // tp_getset + message_meta::Getters, // tp_getset 0, // tp_base 0, // tp_dict 0, // tp_descr_get @@ -525,23 +565,10 @@ int ForEachCompositeField(CMessage* self, Visitor visitor) { // --------------------------------------------------------------------- -// Constants used for integer type range checking. -PyObject* kPythonZero; -PyObject* kint32min_py; -PyObject* kint32max_py; -PyObject* kuint32max_py; -PyObject* kint64min_py; -PyObject* kint64max_py; -PyObject* kuint64max_py; - PyObject* EncodeError_class; PyObject* DecodeError_class; PyObject* PickleError_class; -// Constant PyString values used for GetAttr/GetItem. -static PyObject* k_cdescriptor; -static PyObject* kfull_name; - /* Is 64bit */ void FormatTypeError(PyObject* arg, char* expected_types) { PyObject* repr = PyObject_Repr(arg); @@ -555,68 +582,126 @@ void FormatTypeError(PyObject* arg, char* expected_types) { } } -template<class T> -bool CheckAndGetInteger( - PyObject* arg, T* value, PyObject* min, PyObject* max) { - bool is_long = PyLong_Check(arg); -#if PY_MAJOR_VERSION < 3 - if (!PyInt_Check(arg) && !is_long) { - FormatTypeError(arg, "int, long"); - return false; +void OutOfRangeError(PyObject* arg) { + PyObject *s = PyObject_Str(arg); + if (s) { + PyErr_Format(PyExc_ValueError, + "Value out of range: %s", + PyString_AsString(s)); + Py_DECREF(s); } - if (PyObject_Compare(min, arg) > 0 || PyObject_Compare(max, arg) < 0) { -#else - if (!is_long) { - FormatTypeError(arg, "int"); +} + +template<class RangeType, class ValueType> +bool VerifyIntegerCastAndRange(PyObject* arg, ValueType value) { + if GOOGLE_PREDICT_FALSE(value == -1 && PyErr_Occurred()) { + if (PyErr_ExceptionMatches(PyExc_OverflowError)) { + // Replace it with the same ValueError as pure python protos instead of + // the default one. + PyErr_Clear(); + OutOfRangeError(arg); + } // Otherwise propagate existing error. return false; } - if (PyObject_RichCompareBool(min, arg, Py_LE) != 1 || - PyObject_RichCompareBool(max, arg, Py_GE) != 1) { -#endif - if (!PyErr_Occurred()) { - PyObject *s = PyObject_Str(arg); - if (s) { - PyErr_Format(PyExc_ValueError, - "Value out of range: %s", - PyString_AsString(s)); - Py_DECREF(s); - } - } + if GOOGLE_PREDICT_FALSE(!IsValidNumericCast<RangeType>(value)) { + OutOfRangeError(arg); return false; } + return true; +} + +template<class T> +bool CheckAndGetInteger(PyObject* arg, T* value) { + // The fast path. #if PY_MAJOR_VERSION < 3 - if (!is_long) { - *value = static_cast<T>(PyInt_AsLong(arg)); - } else // NOLINT + // For the typical case, offer a fast path. + if GOOGLE_PREDICT_TRUE(PyInt_Check(arg)) { + long int_result = PyInt_AsLong(arg); + if GOOGLE_PREDICT_TRUE(IsValidNumericCast<T>(int_result)) { + *value = static_cast<T>(int_result); + return true; + } else { + OutOfRangeError(arg); + return false; + } + } #endif - { - if (min == kPythonZero) { - *value = static_cast<T>(PyLong_AsUnsignedLongLong(arg)); + // This effectively defines an integer as "an object that can be cast as + // an integer and can be used as an ordinal number". + // This definition includes everything that implements numbers.Integral + // and shouldn't cast the net too wide. + if GOOGLE_PREDICT_FALSE(!PyIndex_Check(arg)) { + FormatTypeError(arg, "int, long"); + return false; + } + + // Now we have an integral number so we can safely use PyLong_ functions. + // We need to treat the signed and unsigned cases differently in case arg is + // holding a value above the maximum for signed longs. + if (std::numeric_limits<T>::min() == 0) { + // Unsigned case. + unsigned PY_LONG_LONG ulong_result; + if (PyLong_Check(arg)) { + ulong_result = PyLong_AsUnsignedLongLong(arg); + } else { + // Unlike PyLong_AsLongLong, PyLong_AsUnsignedLongLong is very + // picky about the exact type. + PyObject* casted = PyNumber_Long(arg); + if GOOGLE_PREDICT_FALSE(casted == NULL) { + // Propagate existing error. + return false; + } + ulong_result = PyLong_AsUnsignedLongLong(casted); + Py_DECREF(casted); + } + if (VerifyIntegerCastAndRange<T, unsigned PY_LONG_LONG>(arg, + ulong_result)) { + *value = static_cast<T>(ulong_result); + } else { + return false; + } + } else { + // Signed case. + PY_LONG_LONG long_result; + PyNumberMethods *nb; + if ((nb = arg->ob_type->tp_as_number) != NULL && nb->nb_int != NULL) { + // PyLong_AsLongLong requires it to be a long or to have an __int__() + // method. + long_result = PyLong_AsLongLong(arg); } else { - *value = static_cast<T>(PyLong_AsLongLong(arg)); + // Valid subclasses of numbers.Integral should have a __long__() method + // so fall back to that. + PyObject* casted = PyNumber_Long(arg); + if GOOGLE_PREDICT_FALSE(casted == NULL) { + // Propagate existing error. + return false; + } + long_result = PyLong_AsLongLong(casted); + Py_DECREF(casted); + } + if (VerifyIntegerCastAndRange<T, PY_LONG_LONG>(arg, long_result)) { + *value = static_cast<T>(long_result); + } else { + return false; } } + return true; } // These are referenced by repeated_scalar_container, and must // be explicitly instantiated. -template bool CheckAndGetInteger<int32>( - PyObject*, int32*, PyObject*, PyObject*); -template bool CheckAndGetInteger<int64>( - PyObject*, int64*, PyObject*, PyObject*); -template bool CheckAndGetInteger<uint32>( - PyObject*, uint32*, PyObject*, PyObject*); -template bool CheckAndGetInteger<uint64>( - PyObject*, uint64*, PyObject*, PyObject*); +template bool CheckAndGetInteger<int32>(PyObject*, int32*); +template bool CheckAndGetInteger<int64>(PyObject*, int64*); +template bool CheckAndGetInteger<uint32>(PyObject*, uint32*); +template bool CheckAndGetInteger<uint64>(PyObject*, uint64*); bool CheckAndGetDouble(PyObject* arg, double* value) { - if (!PyInt_Check(arg) && !PyLong_Check(arg) && - !PyFloat_Check(arg)) { + *value = PyFloat_AsDouble(arg); + if GOOGLE_PREDICT_FALSE(*value == -1 && PyErr_Occurred()) { FormatTypeError(arg, "int, long, float"); return false; } - *value = PyFloat_AsDouble(arg); return true; } @@ -630,11 +715,13 @@ bool CheckAndGetFloat(PyObject* arg, float* value) { } bool CheckAndGetBool(PyObject* arg, bool* value) { - if (!PyInt_Check(arg) && !PyBool_Check(arg) && !PyLong_Check(arg)) { + long long_value = PyInt_AsLong(arg); + if (long_value == -1 && PyErr_Occurred()) { FormatTypeError(arg, "int, long, bool"); return false; } - *value = static_cast<bool>(PyInt_AsLong(arg)); + *value = static_cast<bool>(long_value); + return true; } @@ -966,20 +1053,7 @@ int InternalDeleteRepeatedField( int min, max; length = reflection->FieldSize(*message, field_descriptor); - if (PyInt_Check(slice) || PyLong_Check(slice)) { - from = to = PyLong_AsLong(slice); - if (from < 0) { - from = to = length + from; - } - step = 1; - min = max = from; - - // Range check. - if (from < 0 || from >= length) { - PyErr_Format(PyExc_IndexError, "list assignment index out of range"); - return -1; - } - } else if (PySlice_Check(slice)) { + if (PySlice_Check(slice)) { from = to = step = slice_length = 0; PySlice_GetIndicesEx( #if PY_MAJOR_VERSION < 3 @@ -996,8 +1070,23 @@ int InternalDeleteRepeatedField( max = from; } } else { - PyErr_SetString(PyExc_TypeError, "list indices must be integers"); - return -1; + from = to = PyLong_AsLong(slice); + if (from == -1 && PyErr_Occurred()) { + PyErr_SetString(PyExc_TypeError, "list indices must be integers"); + return -1; + } + + if (from < 0) { + from = to = length + from; + } + step = 1; + min = max = from; + + // Range check. + if (from < 0 || from >= length) { + PyErr_Format(PyExc_IndexError, "list assignment index out of range"); + return -1; + } } Py_ssize_t i = from; @@ -1958,99 +2047,29 @@ static PyObject* ByteSize(CMessage* self, PyObject* args) { return PyLong_FromLong(self->message->ByteSize()); } -static PyObject* RegisterExtension(PyObject* cls, - PyObject* extension_handle) { +PyObject* RegisterExtension(PyObject* cls, PyObject* extension_handle) { const FieldDescriptor* descriptor = GetExtensionDescriptor(extension_handle); if (descriptor == NULL) { return NULL; } - - ScopedPyObjectPtr extensions_by_name( - PyObject_GetAttr(cls, k_extensions_by_name)); - if (extensions_by_name == NULL) { - PyErr_SetString(PyExc_TypeError, "no extensions_by_name on class"); + if (!PyObject_TypeCheck(cls, &CMessageClass_Type)) { + PyErr_Format(PyExc_TypeError, "Expected a message class, got %s", + cls->ob_type->tp_name); return NULL; } - ScopedPyObjectPtr full_name(PyObject_GetAttr(extension_handle, kfull_name)); - if (full_name == NULL) { + CMessageClass *message_class = reinterpret_cast<CMessageClass*>(cls); + if (message_class == NULL) { return NULL; } - // If the extension was already registered, check that it is the same. - PyObject* existing_extension = - PyDict_GetItem(extensions_by_name.get(), full_name.get()); - if (existing_extension != NULL) { - const FieldDescriptor* existing_extension_descriptor = - GetExtensionDescriptor(existing_extension); - if (existing_extension_descriptor != descriptor) { - PyErr_SetString(PyExc_ValueError, "Double registration of Extensions"); - return NULL; - } - // Nothing else to do. - Py_RETURN_NONE; - } - - if (PyDict_SetItem(extensions_by_name.get(), full_name.get(), - extension_handle) < 0) { - return NULL; - } - - // Also store a mapping from extension number to implementing class. - ScopedPyObjectPtr extensions_by_number( - PyObject_GetAttr(cls, k_extensions_by_number)); - if (extensions_by_number == NULL) { - PyErr_SetString(PyExc_TypeError, "no extensions_by_number on class"); - return NULL; - } - - ScopedPyObjectPtr number(PyObject_GetAttrString(extension_handle, "number")); - if (number == NULL) { - return NULL; - } - - // If the extension was already registered by number, check that it is the - // same. - existing_extension = PyDict_GetItem(extensions_by_number.get(), number.get()); - if (existing_extension != NULL) { - const FieldDescriptor* existing_extension_descriptor = - GetExtensionDescriptor(existing_extension); - if (existing_extension_descriptor != descriptor) { - const Descriptor* msg_desc = GetMessageDescriptor( - reinterpret_cast<PyTypeObject*>(cls)); - PyErr_Format( - PyExc_ValueError, - "Extensions \"%s\" and \"%s\" both try to extend message type " - "\"%s\" with field number %ld.", - existing_extension_descriptor->full_name().c_str(), - descriptor->full_name().c_str(), - msg_desc->full_name().c_str(), - PyInt_AsLong(number.get())); - return NULL; - } - // Nothing else to do. - Py_RETURN_NONE; - } - if (PyDict_SetItem(extensions_by_number.get(), number.get(), - extension_handle) < 0) { + const FieldDescriptor* existing_extension = + message_class->py_message_factory->pool->pool->FindExtensionByNumber( + descriptor->containing_type(), descriptor->number()); + if (existing_extension != NULL && existing_extension != descriptor) { + PyErr_SetString(PyExc_ValueError, "Double registration of Extensions"); return NULL; } - - // Check if it's a message set - if (descriptor->is_extension() && - descriptor->containing_type()->options().message_set_wire_format() && - descriptor->type() == FieldDescriptor::TYPE_MESSAGE && - descriptor->label() == FieldDescriptor::LABEL_OPTIONAL) { - ScopedPyObjectPtr message_name(PyString_FromStringAndSize( - descriptor->message_type()->full_name().c_str(), - descriptor->message_type()->full_name().size())); - if (message_name == NULL) { - return NULL; - } - PyDict_SetItem(extensions_by_name.get(), message_name.get(), - extension_handle); - } - Py_RETURN_NONE; } @@ -2087,7 +2106,7 @@ static PyObject* WhichOneof(CMessage* self, PyObject* arg) { static PyObject* GetExtensionDict(CMessage* self, void *closure); static PyObject* ListFields(CMessage* self) { - vector<const FieldDescriptor*> fields; + std::vector<const FieldDescriptor*> fields; self->message->GetReflection()->ListFields(*self->message, &fields); // Normally, the list will be exactly the size of the fields. @@ -2178,7 +2197,7 @@ static PyObject* DiscardUnknownFields(CMessage* self) { PyObject* FindInitializationErrors(CMessage* self) { Message* message = self->message; - vector<string> errors; + std::vector<string> errors; message->FindInitializationErrors(&errors); PyObject* error_list = PyList_New(errors.size()); @@ -2570,11 +2589,24 @@ static PyObject* GetExtensionDict(CMessage* self, void *closure) { return NULL; } +static PyObject* GetExtensionsByName(CMessage *self, void *closure) { + return message_meta::GetExtensionsByName( + reinterpret_cast<CMessageClass*>(Py_TYPE(self)), closure); +} + +static PyObject* GetExtensionsByNumber(CMessage *self, void *closure) { + return message_meta::GetExtensionsByNumber( + reinterpret_cast<CMessageClass*>(Py_TYPE(self)), closure); +} + static PyGetSetDef Getters[] = { {"Extensions", (getter)GetExtensionDict, NULL, "Extension dict"}, + {"_extensions_by_name", (getter)GetExtensionsByName, NULL}, + {"_extensions_by_number", (getter)GetExtensionsByNumber, NULL}, {NULL} }; + static PyMethodDef Methods[] = { { "__deepcopy__", (PyCFunction)DeepCopy, METH_VARARGS, "Makes a deep copy of the class." }, @@ -2835,19 +2867,7 @@ void InitGlobals() { // TODO(gps): Check all return values in this function for NULL and propagate // the error (MemoryError) on up to result in an import failure. These should // also be freed and reset to NULL during finalization. - kPythonZero = PyInt_FromLong(0); - kint32min_py = PyInt_FromLong(kint32min); - kint32max_py = PyInt_FromLong(kint32max); - kuint32max_py = PyLong_FromLongLong(kuint32max); - kint64min_py = PyLong_FromLongLong(kint64min); - kint64max_py = PyLong_FromLongLong(kint64max); - kuint64max_py = PyLong_FromUnsignedLongLong(kuint64max); - kDESCRIPTOR = PyString_FromString("DESCRIPTOR"); - k_cdescriptor = PyString_FromString("_cdescriptor"); - kfull_name = PyString_FromString("full_name"); - k_extensions_by_name = PyString_FromString("_extensions_by_name"); - k_extensions_by_number = PyString_FromString("_extensions_by_number"); PyObject *dummy_obj = PySet_New(NULL); kEmptyWeakref = PyWeakref_NewRef(dummy_obj, NULL); @@ -2887,25 +2907,6 @@ bool InitProto2MessageModule(PyObject *m) { // DESCRIPTOR is set on each protocol buffer message class elsewhere, but set // it here as well to document that subclasses need to set it. PyDict_SetItem(CMessage_Type.tp_dict, kDESCRIPTOR, Py_None); - // Subclasses with message extensions will override _extensions_by_name and - // _extensions_by_number with fresh mutable dictionaries in AddDescriptors. - // All other classes can share this same immutable mapping. - ScopedPyObjectPtr empty_dict(PyDict_New()); - if (empty_dict == NULL) { - return false; - } - ScopedPyObjectPtr immutable_dict(PyDictProxy_New(empty_dict.get())); - if (immutable_dict == NULL) { - return false; - } - if (PyDict_SetItem(CMessage_Type.tp_dict, - k_extensions_by_name, immutable_dict.get()) < 0) { - return false; - } - if (PyDict_SetItem(CMessage_Type.tp_dict, - k_extensions_by_number, immutable_dict.get()) < 0) { - return false; - } PyModule_AddObject(m, "Message", reinterpret_cast<PyObject*>(&CMessage_Type)); |