aboutsummaryrefslogtreecommitdiff
path: root/python/google/protobuf/pyext/message.cc
diff options
context:
space:
mode:
authorAdam Cozzette <acozzette@google.com>2016-11-17 16:48:38 -0800
committerAdam Cozzette <acozzette@google.com>2016-11-17 16:59:59 -0800
commit5a76e633ea9b5adb215e93fdc11e1c0c08b3fc74 (patch)
tree0276f81f8848a05d84cd7e287b43d665e30f04e3 /python/google/protobuf/pyext/message.cc
parente28286fa05d8327fd6c5aa70cfb3be558f0932b8 (diff)
downloadprotobuf-5a76e633ea9b5adb215e93fdc11e1c0c08b3fc74.tar.gz
protobuf-5a76e633ea9b5adb215e93fdc11e1c0c08b3fc74.tar.bz2
protobuf-5a76e633ea9b5adb215e93fdc11e1c0c08b3fc74.zip
Integrated internal changes from Google
Diffstat (limited to 'python/google/protobuf/pyext/message.cc')
-rw-r--r--python/google/protobuf/pyext/message.cc409
1 files changed, 205 insertions, 204 deletions
diff --git a/python/google/protobuf/pyext/message.cc b/python/google/protobuf/pyext/message.cc
index 7ff99aea..5967a587 100644
--- a/python/google/protobuf/pyext/message.cc
+++ b/python/google/protobuf/pyext/message.cc
@@ -64,11 +64,11 @@
#include <google/protobuf/pyext/repeated_scalar_container.h>
#include <google/protobuf/pyext/map_container.h>
#include <google/protobuf/pyext/message_factory.h>
+#include <google/protobuf/pyext/safe_numerics.h>
#include <google/protobuf/pyext/scoped_pyobject_ptr.h>
#include <google/protobuf/stubs/strutil.h>
#if PY_MAJOR_VERSION >= 3
- #define PyInt_Check PyLong_Check
#define PyInt_AsLong PyLong_AsLong
#define PyInt_FromLong PyLong_FromLong
#define PyInt_FromSize_t PyLong_FromSize_t
@@ -92,8 +92,6 @@ namespace protobuf {
namespace python {
static PyObject* kDESCRIPTOR;
-static PyObject* k_extensions_by_name;
-static PyObject* k_extensions_by_number;
PyObject* EnumTypeWrapper_class;
static PyObject* PythonMessage_class;
static PyObject* kEmptyWeakref;
@@ -128,19 +126,6 @@ static bool AddFieldNumberToClass(
// Finalize the creation of the Message class.
static int AddDescriptors(PyObject* cls, const Descriptor* descriptor) {
- // If there are extension_ranges, the message is "extendable", and extension
- // classes will register themselves in this class.
- if (descriptor->extension_range_count() > 0) {
- ScopedPyObjectPtr by_name(PyDict_New());
- if (PyObject_SetAttr(cls, k_extensions_by_name, by_name.get()) < 0) {
- return -1;
- }
- ScopedPyObjectPtr by_number(PyDict_New());
- if (PyObject_SetAttr(cls, k_extensions_by_number, by_number.get()) < 0) {
- return -1;
- }
- }
-
// For each field set: cls.<field>_FIELD_NUMBER = <number>
for (int i = 0; i < descriptor->field_count(); ++i) {
if (!AddFieldNumberToClass(cls, descriptor->field(i))) {
@@ -357,6 +342,61 @@ static int InsertEmptyWeakref(PyTypeObject *base_type) {
#endif // PY_MAJOR_VERSION >= 3
}
+// The _extensions_by_name dictionary is built on every access.
+// TODO(amauryfa): Migrate all users to pool.FindAllExtensions()
+static PyObject* GetExtensionsByName(CMessageClass *self, void *closure) {
+ const PyDescriptorPool* pool = self->py_message_factory->pool;
+
+ std::vector<const FieldDescriptor*> extensions;
+ pool->pool->FindAllExtensions(self->message_descriptor, &extensions);
+
+ ScopedPyObjectPtr result(PyDict_New());
+ for (int i = 0; i < extensions.size(); i++) {
+ ScopedPyObjectPtr extension(
+ PyFieldDescriptor_FromDescriptor(extensions[i]));
+ if (extension == NULL) {
+ return NULL;
+ }
+ if (PyDict_SetItemString(result.get(), extensions[i]->full_name().c_str(),
+ extension.get()) < 0) {
+ return NULL;
+ }
+ }
+ return result.release();
+}
+
+// The _extensions_by_number dictionary is built on every access.
+// TODO(amauryfa): Migrate all users to pool.FindExtensionByNumber()
+static PyObject* GetExtensionsByNumber(CMessageClass *self, void *closure) {
+ const PyDescriptorPool* pool = self->py_message_factory->pool;
+
+ std::vector<const FieldDescriptor*> extensions;
+ pool->pool->FindAllExtensions(self->message_descriptor, &extensions);
+
+ ScopedPyObjectPtr result(PyDict_New());
+ for (int i = 0; i < extensions.size(); i++) {
+ ScopedPyObjectPtr extension(
+ PyFieldDescriptor_FromDescriptor(extensions[i]));
+ if (extension == NULL) {
+ return NULL;
+ }
+ ScopedPyObjectPtr number(PyInt_FromLong(extensions[i]->number()));
+ if (number == NULL) {
+ return NULL;
+ }
+ if (PyDict_SetItem(result.get(), number.get(), extension.get()) < 0) {
+ return NULL;
+ }
+ }
+ return result.release();
+}
+
+static PyGetSetDef Getters[] = {
+ {"_extensions_by_name", (getter)GetExtensionsByName, NULL},
+ {"_extensions_by_number", (getter)GetExtensionsByNumber, NULL},
+ {NULL}
+};
+
} // namespace message_meta
PyTypeObject CMessageClass_Type = {
@@ -389,7 +429,7 @@ PyTypeObject CMessageClass_Type = {
0, // tp_iternext
0, // tp_methods
0, // tp_members
- 0, // tp_getset
+ message_meta::Getters, // tp_getset
0, // tp_base
0, // tp_dict
0, // tp_descr_get
@@ -525,23 +565,10 @@ int ForEachCompositeField(CMessage* self, Visitor visitor) {
// ---------------------------------------------------------------------
-// Constants used for integer type range checking.
-PyObject* kPythonZero;
-PyObject* kint32min_py;
-PyObject* kint32max_py;
-PyObject* kuint32max_py;
-PyObject* kint64min_py;
-PyObject* kint64max_py;
-PyObject* kuint64max_py;
-
PyObject* EncodeError_class;
PyObject* DecodeError_class;
PyObject* PickleError_class;
-// Constant PyString values used for GetAttr/GetItem.
-static PyObject* k_cdescriptor;
-static PyObject* kfull_name;
-
/* Is 64bit */
void FormatTypeError(PyObject* arg, char* expected_types) {
PyObject* repr = PyObject_Repr(arg);
@@ -555,68 +582,126 @@ void FormatTypeError(PyObject* arg, char* expected_types) {
}
}
-template<class T>
-bool CheckAndGetInteger(
- PyObject* arg, T* value, PyObject* min, PyObject* max) {
- bool is_long = PyLong_Check(arg);
-#if PY_MAJOR_VERSION < 3
- if (!PyInt_Check(arg) && !is_long) {
- FormatTypeError(arg, "int, long");
- return false;
+void OutOfRangeError(PyObject* arg) {
+ PyObject *s = PyObject_Str(arg);
+ if (s) {
+ PyErr_Format(PyExc_ValueError,
+ "Value out of range: %s",
+ PyString_AsString(s));
+ Py_DECREF(s);
}
- if (PyObject_Compare(min, arg) > 0 || PyObject_Compare(max, arg) < 0) {
-#else
- if (!is_long) {
- FormatTypeError(arg, "int");
+}
+
+template<class RangeType, class ValueType>
+bool VerifyIntegerCastAndRange(PyObject* arg, ValueType value) {
+ if GOOGLE_PREDICT_FALSE(value == -1 && PyErr_Occurred()) {
+ if (PyErr_ExceptionMatches(PyExc_OverflowError)) {
+ // Replace it with the same ValueError as pure python protos instead of
+ // the default one.
+ PyErr_Clear();
+ OutOfRangeError(arg);
+ } // Otherwise propagate existing error.
return false;
}
- if (PyObject_RichCompareBool(min, arg, Py_LE) != 1 ||
- PyObject_RichCompareBool(max, arg, Py_GE) != 1) {
-#endif
- if (!PyErr_Occurred()) {
- PyObject *s = PyObject_Str(arg);
- if (s) {
- PyErr_Format(PyExc_ValueError,
- "Value out of range: %s",
- PyString_AsString(s));
- Py_DECREF(s);
- }
- }
+ if GOOGLE_PREDICT_FALSE(!IsValidNumericCast<RangeType>(value)) {
+ OutOfRangeError(arg);
return false;
}
+ return true;
+}
+
+template<class T>
+bool CheckAndGetInteger(PyObject* arg, T* value) {
+ // The fast path.
#if PY_MAJOR_VERSION < 3
- if (!is_long) {
- *value = static_cast<T>(PyInt_AsLong(arg));
- } else // NOLINT
+ // For the typical case, offer a fast path.
+ if GOOGLE_PREDICT_TRUE(PyInt_Check(arg)) {
+ long int_result = PyInt_AsLong(arg);
+ if GOOGLE_PREDICT_TRUE(IsValidNumericCast<T>(int_result)) {
+ *value = static_cast<T>(int_result);
+ return true;
+ } else {
+ OutOfRangeError(arg);
+ return false;
+ }
+ }
#endif
- {
- if (min == kPythonZero) {
- *value = static_cast<T>(PyLong_AsUnsignedLongLong(arg));
+ // This effectively defines an integer as "an object that can be cast as
+ // an integer and can be used as an ordinal number".
+ // This definition includes everything that implements numbers.Integral
+ // and shouldn't cast the net too wide.
+ if GOOGLE_PREDICT_FALSE(!PyIndex_Check(arg)) {
+ FormatTypeError(arg, "int, long");
+ return false;
+ }
+
+ // Now we have an integral number so we can safely use PyLong_ functions.
+ // We need to treat the signed and unsigned cases differently in case arg is
+ // holding a value above the maximum for signed longs.
+ if (std::numeric_limits<T>::min() == 0) {
+ // Unsigned case.
+ unsigned PY_LONG_LONG ulong_result;
+ if (PyLong_Check(arg)) {
+ ulong_result = PyLong_AsUnsignedLongLong(arg);
+ } else {
+ // Unlike PyLong_AsLongLong, PyLong_AsUnsignedLongLong is very
+ // picky about the exact type.
+ PyObject* casted = PyNumber_Long(arg);
+ if GOOGLE_PREDICT_FALSE(casted == NULL) {
+ // Propagate existing error.
+ return false;
+ }
+ ulong_result = PyLong_AsUnsignedLongLong(casted);
+ Py_DECREF(casted);
+ }
+ if (VerifyIntegerCastAndRange<T, unsigned PY_LONG_LONG>(arg,
+ ulong_result)) {
+ *value = static_cast<T>(ulong_result);
+ } else {
+ return false;
+ }
+ } else {
+ // Signed case.
+ PY_LONG_LONG long_result;
+ PyNumberMethods *nb;
+ if ((nb = arg->ob_type->tp_as_number) != NULL && nb->nb_int != NULL) {
+ // PyLong_AsLongLong requires it to be a long or to have an __int__()
+ // method.
+ long_result = PyLong_AsLongLong(arg);
} else {
- *value = static_cast<T>(PyLong_AsLongLong(arg));
+ // Valid subclasses of numbers.Integral should have a __long__() method
+ // so fall back to that.
+ PyObject* casted = PyNumber_Long(arg);
+ if GOOGLE_PREDICT_FALSE(casted == NULL) {
+ // Propagate existing error.
+ return false;
+ }
+ long_result = PyLong_AsLongLong(casted);
+ Py_DECREF(casted);
+ }
+ if (VerifyIntegerCastAndRange<T, PY_LONG_LONG>(arg, long_result)) {
+ *value = static_cast<T>(long_result);
+ } else {
+ return false;
}
}
+
return true;
}
// These are referenced by repeated_scalar_container, and must
// be explicitly instantiated.
-template bool CheckAndGetInteger<int32>(
- PyObject*, int32*, PyObject*, PyObject*);
-template bool CheckAndGetInteger<int64>(
- PyObject*, int64*, PyObject*, PyObject*);
-template bool CheckAndGetInteger<uint32>(
- PyObject*, uint32*, PyObject*, PyObject*);
-template bool CheckAndGetInteger<uint64>(
- PyObject*, uint64*, PyObject*, PyObject*);
+template bool CheckAndGetInteger<int32>(PyObject*, int32*);
+template bool CheckAndGetInteger<int64>(PyObject*, int64*);
+template bool CheckAndGetInteger<uint32>(PyObject*, uint32*);
+template bool CheckAndGetInteger<uint64>(PyObject*, uint64*);
bool CheckAndGetDouble(PyObject* arg, double* value) {
- if (!PyInt_Check(arg) && !PyLong_Check(arg) &&
- !PyFloat_Check(arg)) {
+ *value = PyFloat_AsDouble(arg);
+ if GOOGLE_PREDICT_FALSE(*value == -1 && PyErr_Occurred()) {
FormatTypeError(arg, "int, long, float");
return false;
}
- *value = PyFloat_AsDouble(arg);
return true;
}
@@ -630,11 +715,13 @@ bool CheckAndGetFloat(PyObject* arg, float* value) {
}
bool CheckAndGetBool(PyObject* arg, bool* value) {
- if (!PyInt_Check(arg) && !PyBool_Check(arg) && !PyLong_Check(arg)) {
+ long long_value = PyInt_AsLong(arg);
+ if (long_value == -1 && PyErr_Occurred()) {
FormatTypeError(arg, "int, long, bool");
return false;
}
- *value = static_cast<bool>(PyInt_AsLong(arg));
+ *value = static_cast<bool>(long_value);
+
return true;
}
@@ -966,20 +1053,7 @@ int InternalDeleteRepeatedField(
int min, max;
length = reflection->FieldSize(*message, field_descriptor);
- if (PyInt_Check(slice) || PyLong_Check(slice)) {
- from = to = PyLong_AsLong(slice);
- if (from < 0) {
- from = to = length + from;
- }
- step = 1;
- min = max = from;
-
- // Range check.
- if (from < 0 || from >= length) {
- PyErr_Format(PyExc_IndexError, "list assignment index out of range");
- return -1;
- }
- } else if (PySlice_Check(slice)) {
+ if (PySlice_Check(slice)) {
from = to = step = slice_length = 0;
PySlice_GetIndicesEx(
#if PY_MAJOR_VERSION < 3
@@ -996,8 +1070,23 @@ int InternalDeleteRepeatedField(
max = from;
}
} else {
- PyErr_SetString(PyExc_TypeError, "list indices must be integers");
- return -1;
+ from = to = PyLong_AsLong(slice);
+ if (from == -1 && PyErr_Occurred()) {
+ PyErr_SetString(PyExc_TypeError, "list indices must be integers");
+ return -1;
+ }
+
+ if (from < 0) {
+ from = to = length + from;
+ }
+ step = 1;
+ min = max = from;
+
+ // Range check.
+ if (from < 0 || from >= length) {
+ PyErr_Format(PyExc_IndexError, "list assignment index out of range");
+ return -1;
+ }
}
Py_ssize_t i = from;
@@ -1958,99 +2047,29 @@ static PyObject* ByteSize(CMessage* self, PyObject* args) {
return PyLong_FromLong(self->message->ByteSize());
}
-static PyObject* RegisterExtension(PyObject* cls,
- PyObject* extension_handle) {
+PyObject* RegisterExtension(PyObject* cls, PyObject* extension_handle) {
const FieldDescriptor* descriptor =
GetExtensionDescriptor(extension_handle);
if (descriptor == NULL) {
return NULL;
}
-
- ScopedPyObjectPtr extensions_by_name(
- PyObject_GetAttr(cls, k_extensions_by_name));
- if (extensions_by_name == NULL) {
- PyErr_SetString(PyExc_TypeError, "no extensions_by_name on class");
+ if (!PyObject_TypeCheck(cls, &CMessageClass_Type)) {
+ PyErr_Format(PyExc_TypeError, "Expected a message class, got %s",
+ cls->ob_type->tp_name);
return NULL;
}
- ScopedPyObjectPtr full_name(PyObject_GetAttr(extension_handle, kfull_name));
- if (full_name == NULL) {
+ CMessageClass *message_class = reinterpret_cast<CMessageClass*>(cls);
+ if (message_class == NULL) {
return NULL;
}
-
// If the extension was already registered, check that it is the same.
- PyObject* existing_extension =
- PyDict_GetItem(extensions_by_name.get(), full_name.get());
- if (existing_extension != NULL) {
- const FieldDescriptor* existing_extension_descriptor =
- GetExtensionDescriptor(existing_extension);
- if (existing_extension_descriptor != descriptor) {
- PyErr_SetString(PyExc_ValueError, "Double registration of Extensions");
- return NULL;
- }
- // Nothing else to do.
- Py_RETURN_NONE;
- }
-
- if (PyDict_SetItem(extensions_by_name.get(), full_name.get(),
- extension_handle) < 0) {
- return NULL;
- }
-
- // Also store a mapping from extension number to implementing class.
- ScopedPyObjectPtr extensions_by_number(
- PyObject_GetAttr(cls, k_extensions_by_number));
- if (extensions_by_number == NULL) {
- PyErr_SetString(PyExc_TypeError, "no extensions_by_number on class");
- return NULL;
- }
-
- ScopedPyObjectPtr number(PyObject_GetAttrString(extension_handle, "number"));
- if (number == NULL) {
- return NULL;
- }
-
- // If the extension was already registered by number, check that it is the
- // same.
- existing_extension = PyDict_GetItem(extensions_by_number.get(), number.get());
- if (existing_extension != NULL) {
- const FieldDescriptor* existing_extension_descriptor =
- GetExtensionDescriptor(existing_extension);
- if (existing_extension_descriptor != descriptor) {
- const Descriptor* msg_desc = GetMessageDescriptor(
- reinterpret_cast<PyTypeObject*>(cls));
- PyErr_Format(
- PyExc_ValueError,
- "Extensions \"%s\" and \"%s\" both try to extend message type "
- "\"%s\" with field number %ld.",
- existing_extension_descriptor->full_name().c_str(),
- descriptor->full_name().c_str(),
- msg_desc->full_name().c_str(),
- PyInt_AsLong(number.get()));
- return NULL;
- }
- // Nothing else to do.
- Py_RETURN_NONE;
- }
- if (PyDict_SetItem(extensions_by_number.get(), number.get(),
- extension_handle) < 0) {
+ const FieldDescriptor* existing_extension =
+ message_class->py_message_factory->pool->pool->FindExtensionByNumber(
+ descriptor->containing_type(), descriptor->number());
+ if (existing_extension != NULL && existing_extension != descriptor) {
+ PyErr_SetString(PyExc_ValueError, "Double registration of Extensions");
return NULL;
}
-
- // Check if it's a message set
- if (descriptor->is_extension() &&
- descriptor->containing_type()->options().message_set_wire_format() &&
- descriptor->type() == FieldDescriptor::TYPE_MESSAGE &&
- descriptor->label() == FieldDescriptor::LABEL_OPTIONAL) {
- ScopedPyObjectPtr message_name(PyString_FromStringAndSize(
- descriptor->message_type()->full_name().c_str(),
- descriptor->message_type()->full_name().size()));
- if (message_name == NULL) {
- return NULL;
- }
- PyDict_SetItem(extensions_by_name.get(), message_name.get(),
- extension_handle);
- }
-
Py_RETURN_NONE;
}
@@ -2087,7 +2106,7 @@ static PyObject* WhichOneof(CMessage* self, PyObject* arg) {
static PyObject* GetExtensionDict(CMessage* self, void *closure);
static PyObject* ListFields(CMessage* self) {
- vector<const FieldDescriptor*> fields;
+ std::vector<const FieldDescriptor*> fields;
self->message->GetReflection()->ListFields(*self->message, &fields);
// Normally, the list will be exactly the size of the fields.
@@ -2178,7 +2197,7 @@ static PyObject* DiscardUnknownFields(CMessage* self) {
PyObject* FindInitializationErrors(CMessage* self) {
Message* message = self->message;
- vector<string> errors;
+ std::vector<string> errors;
message->FindInitializationErrors(&errors);
PyObject* error_list = PyList_New(errors.size());
@@ -2570,11 +2589,24 @@ static PyObject* GetExtensionDict(CMessage* self, void *closure) {
return NULL;
}
+static PyObject* GetExtensionsByName(CMessage *self, void *closure) {
+ return message_meta::GetExtensionsByName(
+ reinterpret_cast<CMessageClass*>(Py_TYPE(self)), closure);
+}
+
+static PyObject* GetExtensionsByNumber(CMessage *self, void *closure) {
+ return message_meta::GetExtensionsByNumber(
+ reinterpret_cast<CMessageClass*>(Py_TYPE(self)), closure);
+}
+
static PyGetSetDef Getters[] = {
{"Extensions", (getter)GetExtensionDict, NULL, "Extension dict"},
+ {"_extensions_by_name", (getter)GetExtensionsByName, NULL},
+ {"_extensions_by_number", (getter)GetExtensionsByNumber, NULL},
{NULL}
};
+
static PyMethodDef Methods[] = {
{ "__deepcopy__", (PyCFunction)DeepCopy, METH_VARARGS,
"Makes a deep copy of the class." },
@@ -2835,19 +2867,7 @@ void InitGlobals() {
// TODO(gps): Check all return values in this function for NULL and propagate
// the error (MemoryError) on up to result in an import failure. These should
// also be freed and reset to NULL during finalization.
- kPythonZero = PyInt_FromLong(0);
- kint32min_py = PyInt_FromLong(kint32min);
- kint32max_py = PyInt_FromLong(kint32max);
- kuint32max_py = PyLong_FromLongLong(kuint32max);
- kint64min_py = PyLong_FromLongLong(kint64min);
- kint64max_py = PyLong_FromLongLong(kint64max);
- kuint64max_py = PyLong_FromUnsignedLongLong(kuint64max);
-
kDESCRIPTOR = PyString_FromString("DESCRIPTOR");
- k_cdescriptor = PyString_FromString("_cdescriptor");
- kfull_name = PyString_FromString("full_name");
- k_extensions_by_name = PyString_FromString("_extensions_by_name");
- k_extensions_by_number = PyString_FromString("_extensions_by_number");
PyObject *dummy_obj = PySet_New(NULL);
kEmptyWeakref = PyWeakref_NewRef(dummy_obj, NULL);
@@ -2887,25 +2907,6 @@ bool InitProto2MessageModule(PyObject *m) {
// DESCRIPTOR is set on each protocol buffer message class elsewhere, but set
// it here as well to document that subclasses need to set it.
PyDict_SetItem(CMessage_Type.tp_dict, kDESCRIPTOR, Py_None);
- // Subclasses with message extensions will override _extensions_by_name and
- // _extensions_by_number with fresh mutable dictionaries in AddDescriptors.
- // All other classes can share this same immutable mapping.
- ScopedPyObjectPtr empty_dict(PyDict_New());
- if (empty_dict == NULL) {
- return false;
- }
- ScopedPyObjectPtr immutable_dict(PyDictProxy_New(empty_dict.get()));
- if (immutable_dict == NULL) {
- return false;
- }
- if (PyDict_SetItem(CMessage_Type.tp_dict,
- k_extensions_by_name, immutable_dict.get()) < 0) {
- return false;
- }
- if (PyDict_SetItem(CMessage_Type.tp_dict,
- k_extensions_by_number, immutable_dict.get()) < 0) {
- return false;
- }
PyModule_AddObject(m, "Message", reinterpret_cast<PyObject*>(&CMessage_Type));