diff options
Diffstat (limited to 'python/google/protobuf/pyext/map_container.cc')
-rw-r--r-- | python/google/protobuf/pyext/map_container.cc | 138 |
1 files changed, 109 insertions, 29 deletions
diff --git a/python/google/protobuf/pyext/map_container.cc b/python/google/protobuf/pyext/map_container.cc index 6d7ee285..77c61706 100644 --- a/python/google/protobuf/pyext/map_container.cc +++ b/python/google/protobuf/pyext/map_container.cc @@ -68,6 +68,8 @@ class MapReflectionFriend { static PyObject* MessageMapGetItem(PyObject* _self, PyObject* key); static int ScalarMapSetItem(PyObject* _self, PyObject* key, PyObject* v); static int MessageMapSetItem(PyObject* _self, PyObject* key, PyObject* v); + static PyObject* ScalarMapToStr(PyObject* _self); + static PyObject* MessageMapToStr(PyObject* _self); }; struct MapIterator { @@ -199,26 +201,26 @@ static PyObject* MapKeyToPython(const FieldDescriptor* field_descriptor, // This is only used for ScalarMap, so we don't need to handle the // CPPTYPE_MESSAGE case. PyObject* MapValueRefToPython(const FieldDescriptor* field_descriptor, - MapValueRef* value) { + const MapValueRef& value) { switch (field_descriptor->cpp_type()) { case FieldDescriptor::CPPTYPE_INT32: - return PyInt_FromLong(value->GetInt32Value()); + return PyInt_FromLong(value.GetInt32Value()); case FieldDescriptor::CPPTYPE_INT64: - return PyLong_FromLongLong(value->GetInt64Value()); + return PyLong_FromLongLong(value.GetInt64Value()); case FieldDescriptor::CPPTYPE_UINT32: - return PyInt_FromSize_t(value->GetUInt32Value()); + return PyInt_FromSize_t(value.GetUInt32Value()); case FieldDescriptor::CPPTYPE_UINT64: - return PyLong_FromUnsignedLongLong(value->GetUInt64Value()); + return PyLong_FromUnsignedLongLong(value.GetUInt64Value()); case FieldDescriptor::CPPTYPE_FLOAT: - return PyFloat_FromDouble(value->GetFloatValue()); + return PyFloat_FromDouble(value.GetFloatValue()); case FieldDescriptor::CPPTYPE_DOUBLE: - return PyFloat_FromDouble(value->GetDoubleValue()); + return PyFloat_FromDouble(value.GetDoubleValue()); case FieldDescriptor::CPPTYPE_BOOL: - return PyBool_FromLong(value->GetBoolValue()); + return PyBool_FromLong(value.GetBoolValue()); case FieldDescriptor::CPPTYPE_STRING: - return ToStringObject(field_descriptor, value->GetStringValue()); + return ToStringObject(field_descriptor, value.GetStringValue()); case FieldDescriptor::CPPTYPE_ENUM: - return PyInt_FromLong(value->GetEnumValue()); + return PyInt_FromLong(value.GetEnumValue()); default: PyErr_Format( PyExc_SystemError, "Couldn't convert type %d to value", @@ -472,7 +474,7 @@ PyObject* MapReflectionFriend::ScalarMapGetItem(PyObject* _self, self->version++; } - return MapValueRefToPython(self->value_field_descriptor, &value); + return MapValueRefToPython(self->value_field_descriptor, value); } int MapReflectionFriend::ScalarMapSetItem(PyObject* _self, PyObject* key, @@ -535,10 +537,47 @@ static PyObject* ScalarMapGet(PyObject* self, PyObject* args) { } } +PyObject* MapReflectionFriend::ScalarMapToStr(PyObject* _self) { + ScopedPyObjectPtr dict(PyDict_New()); + if (dict == NULL) { + return NULL; + } + ScopedPyObjectPtr key; + ScopedPyObjectPtr value; + + MapContainer* self = GetMap(_self); + Message* message = self->GetMutableMessage(); + const Reflection* reflection = message->GetReflection(); + for (google::protobuf::MapIterator it = reflection->MapBegin( + message, self->parent_field_descriptor); + it != reflection->MapEnd(message, self->parent_field_descriptor); + ++it) { + key.reset(MapKeyToPython(self->key_field_descriptor, + it.GetKey())); + if (key == NULL) { + return NULL; + } + value.reset(MapValueRefToPython(self->value_field_descriptor, + it.GetValueRef())); + if (value == NULL) { + return NULL; + } + if (PyDict_SetItem(dict.get(), key.get(), value.get()) < 0) { + return NULL; + } + } + return PyObject_Repr(dict.get()); +} + static void ScalarMapDealloc(PyObject* _self) { MapContainer* self = GetMap(_self); self->owner.reset(); - Py_TYPE(_self)->tp_free(_self); + PyTypeObject *type = Py_TYPE(_self); + type->tp_free(_self); + if (type->tp_flags & Py_TPFLAGS_HEAPTYPE) { + // With Python3, the Map class is not static, and must be managed. + Py_DECREF(type); + } } static PyMethodDef ScalarMapMethods[] = { @@ -570,6 +609,7 @@ PyTypeObject *ScalarMapContainer_Type; {Py_mp_ass_subscript, (void *)MapReflectionFriend::ScalarMapSetItem}, {Py_tp_methods, (void *)ScalarMapMethods}, {Py_tp_iter, (void *)MapReflectionFriend::GetIterator}, + {Py_tp_repr, (void *)MapReflectionFriend::ScalarMapToStr}, {0, 0}, }; @@ -597,7 +637,7 @@ PyTypeObject *ScalarMapContainer_Type; 0, // tp_getattr 0, // tp_setattr 0, // tp_compare - 0, // tp_repr + MapReflectionFriend::ScalarMapToStr, // tp_repr 0, // tp_as_number 0, // tp_as_sequence &ScalarMapMappingMethods, // tp_as_mapping @@ -634,7 +674,8 @@ static MessageMapContainer* GetMessageMap(PyObject* obj) { return reinterpret_cast<MessageMapContainer*>(obj); } -static PyObject* GetCMessage(MessageMapContainer* self, Message* message) { +static PyObject* GetCMessage(MessageMapContainer* self, Message* message, + bool insert_message_dict) { // Get or create the CMessage object corresponding to this message. ScopedPyObjectPtr key(PyLong_FromVoidPtr(message)); PyObject* ret = PyDict_GetItem(self->message_dict, key.get()); @@ -649,10 +690,11 @@ static PyObject* GetCMessage(MessageMapContainer* self, Message* message) { cmsg->owner = self->owner; cmsg->message = message; cmsg->parent = self->parent; - - if (PyDict_SetItem(self->message_dict, key.get(), ret) < 0) { - Py_DECREF(ret); - return NULL; + if (insert_message_dict) { + if (PyDict_SetItem(self->message_dict, key.get(), ret) < 0) { + Py_DECREF(ret); + return NULL; + } } } else { Py_INCREF(ret); @@ -781,7 +823,41 @@ PyObject* MapReflectionFriend::MessageMapGetItem(PyObject* _self, self->version++; } - return GetCMessage(self, value.MutableMessageValue()); + return GetCMessage(self, value.MutableMessageValue(), true); +} + +PyObject* MapReflectionFriend::MessageMapToStr(PyObject* _self) { + ScopedPyObjectPtr dict(PyDict_New()); + if (dict == NULL) { + return NULL; + } + ScopedPyObjectPtr key; + ScopedPyObjectPtr value; + + MessageMapContainer* self = GetMessageMap(_self); + Message* message = self->GetMutableMessage(); + const Reflection* reflection = message->GetReflection(); + for (google::protobuf::MapIterator it = reflection->MapBegin( + message, self->parent_field_descriptor); + it != reflection->MapEnd(message, self->parent_field_descriptor); + ++it) { + key.reset(MapKeyToPython(self->key_field_descriptor, + it.GetKey())); + if (key == NULL) { + return NULL; + } + // Do not insert the cmessage to self->message_dict because + // the returned CMessage will not escape this function. + value.reset(GetCMessage( + self, it.MutableValueRef()->MutableMessageValue(), false)); + if (value == NULL) { + return NULL; + } + if (PyDict_SetItem(dict.get(), key.get(), value.get()) < 0) { + return NULL; + } + } + return PyObject_Repr(dict.get()); } PyObject* MessageMapGet(PyObject* self, PyObject* args) { @@ -813,7 +889,12 @@ static void MessageMapDealloc(PyObject* _self) { self->owner.reset(); Py_DECREF(self->message_dict); Py_DECREF(self->message_class); - Py_TYPE(_self)->tp_free(_self); + PyTypeObject *type = Py_TYPE(_self); + type->tp_free(_self); + if (type->tp_flags & Py_TPFLAGS_HEAPTYPE) { + // With Python3, the Map class is not static, and must be managed. + Py_DECREF(type); + } } static PyMethodDef MessageMapMethods[] = { @@ -847,6 +928,7 @@ PyTypeObject *MessageMapContainer_Type; {Py_mp_ass_subscript, (void *)MapReflectionFriend::MessageMapSetItem}, {Py_tp_methods, (void *)MessageMapMethods}, {Py_tp_iter, (void *)MapReflectionFriend::GetIterator}, + {Py_tp_repr, (void *)MapReflectionFriend::MessageMapToStr}, {0, 0} }; @@ -874,7 +956,7 @@ PyTypeObject *MessageMapContainer_Type; 0, // tp_getattr 0, // tp_setattr 0, // tp_compare - 0, // tp_repr + MapReflectionFriend::MessageMapToStr, // tp_repr 0, // tp_as_number 0, // tp_as_sequence &MessageMapMappingMethods, // tp_as_mapping @@ -1027,17 +1109,15 @@ bool InitMapContainers() { return false; } - if (!PyObject_TypeCheck(mutable_mapping.get(), &PyType_Type)) { - return false; - } - Py_INCREF(mutable_mapping.get()); #if PY_MAJOR_VERSION >= 3 - PyObject* bases = PyTuple_New(1); - PyTuple_SET_ITEM(bases, 0, mutable_mapping.get()); + ScopedPyObjectPtr bases(PyTuple_Pack(1, mutable_mapping.get())); + if (bases == NULL) { + return false; + } ScalarMapContainer_Type = reinterpret_cast<PyTypeObject*>( - PyType_FromSpecWithBases(&ScalarMapContainer_Type_spec, bases)); + PyType_FromSpecWithBases(&ScalarMapContainer_Type_spec, bases.get())); #else _ScalarMapContainer_Type.tp_base = reinterpret_cast<PyTypeObject*>(mutable_mapping.get()); @@ -1055,7 +1135,7 @@ bool InitMapContainers() { #if PY_MAJOR_VERSION >= 3 MessageMapContainer_Type = reinterpret_cast<PyTypeObject*>( - PyType_FromSpecWithBases(&MessageMapContainer_Type_spec, bases)); + PyType_FromSpecWithBases(&MessageMapContainer_Type_spec, bases.get())); #else Py_INCREF(mutable_mapping.get()); _MessageMapContainer_Type.tp_base = |