diff options
Diffstat (limited to 'src/google/protobuf/text_format.cc')
-rw-r--r-- | src/google/protobuf/text_format.cc | 216 |
1 files changed, 204 insertions, 12 deletions
diff --git a/src/google/protobuf/text_format.cc b/src/google/protobuf/text_format.cc index 93c24b23..c8de875d 100644 --- a/src/google/protobuf/text_format.cc +++ b/src/google/protobuf/text_format.cc @@ -53,6 +53,7 @@ #include <google/protobuf/io/zero_copy_stream_impl.h> #include <google/protobuf/descriptor.h> #include <google/protobuf/dynamic_message.h> +#include <google/protobuf/map_field.h> #include <google/protobuf/repeated_field.h> #include <google/protobuf/unknown_field_set.h> #include <google/protobuf/wire_format_lite.h> @@ -1356,9 +1357,9 @@ bool CheckParseInputSize(StringPiece input, io::ErrorCollector* error_collector) { if (input.size() > INT_MAX) { error_collector->AddError( - -1, 0, StrCat("Input size too large: ", - static_cast<int64>(input.size()), " bytes", - " > ", INT_MAX, " bytes.")); + -1, 0, + StrCat("Input size too large: ", static_cast<int64>(input.size()), + " bytes", " > ", INT_MAX, " bytes.")); return false; } return true; @@ -1714,9 +1715,9 @@ class FieldValuePrinterWrapper : public TextFormat::FastFieldValuePrinter { void PrintFieldName(const Message& message, int field_index, int field_count, const Reflection* reflection, const FieldDescriptor* field, - TextFormat::BaseTextGenerator* generator) const { - generator->PrintString(delegate_->PrintFieldName( - message, reflection, field)); + TextFormat::BaseTextGenerator* generator) const override { + generator->PrintString( + delegate_->PrintFieldName(message, reflection, field)); } void PrintFieldName(const Message& message, const Reflection* reflection, const FieldDescriptor* field, @@ -1947,7 +1948,13 @@ void TextFormat::Printer::Print(const Message& message, return; } std::vector<const FieldDescriptor*> fields; - reflection->ListFields(message, &fields); + if (descriptor->options().map_entry()) { + fields.push_back(descriptor->field(0)); + fields.push_back(descriptor->field(1)); + } else { + reflection->ListFields(message, &fields); + } + if (print_message_fields_in_index_order_) { std::sort(fields.begin(), fields.end(), FieldIndexSorter()); } @@ -1973,6 +1980,181 @@ void TextFormat::Printer::PrintFieldValueToString( PrintFieldValue(message, message.GetReflection(), field, index, &generator); } +class MapEntryMessageComparator { + public: + explicit MapEntryMessageComparator(const Descriptor* descriptor) + : field_(descriptor->field(0)) {} + + bool operator()(const Message* a, const Message* b) { + const Reflection* reflection = a->GetReflection(); + switch (field_->cpp_type()) { + case FieldDescriptor::CPPTYPE_BOOL: { + bool first = reflection->GetBool(*a, field_); + bool second = reflection->GetBool(*b, field_); + return first < second; + } + case FieldDescriptor::CPPTYPE_INT32: { + int32 first = reflection->GetInt32(*a, field_); + int32 second = reflection->GetInt32(*b, field_); + return first < second; + } + case FieldDescriptor::CPPTYPE_INT64: { + int64 first = reflection->GetInt64(*a, field_); + int64 second = reflection->GetInt64(*b, field_); + return first < second; + } + case FieldDescriptor::CPPTYPE_UINT32: { + uint32 first = reflection->GetUInt32(*a, field_); + uint32 second = reflection->GetUInt32(*b, field_); + return first < second; + } + case FieldDescriptor::CPPTYPE_UINT64: { + uint64 first = reflection->GetUInt64(*a, field_); + uint64 second = reflection->GetUInt64(*b, field_); + return first < second; + } + case FieldDescriptor::CPPTYPE_STRING: { + string first = reflection->GetString(*a, field_); + string second = reflection->GetString(*b, field_); + return first < second; + } + default: + GOOGLE_LOG(DFATAL) << "Invalid key for map field."; + return true; + } + } + + private: + const FieldDescriptor* field_; +}; + +namespace internal { +class MapFieldPrinterHelper { + public: + // DynamicMapSorter::Sort cannot be used because it enfores syncing with + // repeated field. + static bool SortMap(const Message& message, const Reflection* reflection, + const FieldDescriptor* field, MessageFactory* factory, + std::vector<const Message*>* sorted_map_field); + static void CopyKey(const MapKey& key, Message* message, + const FieldDescriptor* field_desc); + static void CopyValue(const MapValueRef& value, Message* message, + const FieldDescriptor* field_desc); +}; + +// Returns true if elements contained in sorted_map_field need to be released. +bool MapFieldPrinterHelper::SortMap( + const Message& message, const Reflection* reflection, + const FieldDescriptor* field, MessageFactory* factory, + std::vector<const Message*>* sorted_map_field) { + bool need_release = false; + const MapFieldBase& base = + *reflection->MapData(const_cast<Message*>(&message), field); + + if (base.IsRepeatedFieldValid()) { + const RepeatedPtrField<Message>& map_field = + reflection->GetRepeatedPtrField<Message>(message, field); + for (int i = 0; i < map_field.size(); ++i) { + sorted_map_field->push_back( + const_cast<RepeatedPtrField<Message>*>(&map_field)->Mutable(i)); + } + } else { + // TODO(teboring): For performance, instead of creating map entry message + // for each element, just store map keys and sort them. + const Descriptor* map_entry_desc = field->message_type(); + const Message* prototype = factory->GetPrototype(map_entry_desc); + for (MapIterator iter = + reflection->MapBegin(const_cast<Message*>(&message), field); + iter != reflection->MapEnd(const_cast<Message*>(&message), field); + ++iter) { + Message* map_entry_message = prototype->New(); + CopyKey(iter.GetKey(), map_entry_message, map_entry_desc->field(0)); + CopyValue(iter.GetValueRef(), map_entry_message, + map_entry_desc->field(1)); + sorted_map_field->push_back(map_entry_message); + } + need_release = true; + } + + MapEntryMessageComparator comparator(field->message_type()); + std::stable_sort(sorted_map_field->begin(), sorted_map_field->end(), + comparator); + return need_release; +} + +void MapFieldPrinterHelper::CopyKey(const MapKey& key, Message* message, + const FieldDescriptor* field_desc) { + const Reflection* reflection = message->GetReflection(); + switch (field_desc->cpp_type()) { + case FieldDescriptor::CPPTYPE_DOUBLE: + case FieldDescriptor::CPPTYPE_FLOAT: + case FieldDescriptor::CPPTYPE_ENUM: + case FieldDescriptor::CPPTYPE_MESSAGE: + GOOGLE_LOG(ERROR) << "Not supported."; + break; + case FieldDescriptor::CPPTYPE_STRING: + reflection->SetString(message, field_desc, key.GetStringValue()); + return; + case FieldDescriptor::CPPTYPE_INT64: + reflection->SetInt64(message, field_desc, key.GetInt64Value()); + return; + case FieldDescriptor::CPPTYPE_INT32: + reflection->SetInt32(message, field_desc, key.GetInt32Value()); + return; + case FieldDescriptor::CPPTYPE_UINT64: + reflection->SetUInt64(message, field_desc, key.GetUInt64Value()); + return; + case FieldDescriptor::CPPTYPE_UINT32: + reflection->SetUInt32(message, field_desc, key.GetUInt32Value()); + return; + case FieldDescriptor::CPPTYPE_BOOL: + reflection->SetBool(message, field_desc, key.GetBoolValue()); + return; + } +} + +void MapFieldPrinterHelper::CopyValue(const MapValueRef& value, + Message* message, + const FieldDescriptor* field_desc) { + const Reflection* reflection = message->GetReflection(); + switch (field_desc->cpp_type()) { + case FieldDescriptor::CPPTYPE_DOUBLE: + reflection->SetDouble(message, field_desc, value.GetDoubleValue()); + return; + case FieldDescriptor::CPPTYPE_FLOAT: + reflection->SetFloat(message, field_desc, value.GetFloatValue()); + return; + case FieldDescriptor::CPPTYPE_ENUM: + reflection->SetEnumValue(message, field_desc, value.GetEnumValue()); + return; + case FieldDescriptor::CPPTYPE_MESSAGE: { + Message* sub_message = value.GetMessageValue().New(); + sub_message->CopyFrom(value.GetMessageValue()); + reflection->SetAllocatedMessage(message, sub_message, field_desc); + return; + } + case FieldDescriptor::CPPTYPE_STRING: + reflection->SetString(message, field_desc, value.GetStringValue()); + return; + case FieldDescriptor::CPPTYPE_INT64: + reflection->SetInt64(message, field_desc, value.GetInt64Value()); + return; + case FieldDescriptor::CPPTYPE_INT32: + reflection->SetInt32(message, field_desc, value.GetInt32Value()); + return; + case FieldDescriptor::CPPTYPE_UINT64: + reflection->SetUInt64(message, field_desc, value.GetUInt64Value()); + return; + case FieldDescriptor::CPPTYPE_UINT32: + reflection->SetUInt32(message, field_desc, value.GetUInt32Value()); + return; + case FieldDescriptor::CPPTYPE_BOOL: + reflection->SetBool(message, field_desc, value.GetBoolValue()); + return; + } +} +} // namespace internal + void TextFormat::Printer::PrintField(const Message& message, const Reflection* reflection, const FieldDescriptor* field, @@ -1989,14 +2171,18 @@ void TextFormat::Printer::PrintField(const Message& message, if (field->is_repeated()) { count = reflection->FieldSize(message, field); - } else if (reflection->HasField(message, field)) { + } else if (reflection->HasField(message, field) || + field->containing_type()->options().map_entry()) { count = 1; } - std::vector<const Message*> map_entries; - const bool is_map = field->is_map(); + DynamicMessageFactory factory; + std::vector<const Message*> sorted_map_field; + bool need_release = false; + bool is_map = field->is_map(); if (is_map) { - map_entries = DynamicMapSorter::Sort(message, count, reflection, field); + need_release = internal::MapFieldPrinterHelper::SortMap( + message, reflection, field, &factory, &sorted_map_field); } for (int j = 0; j < count; ++j) { @@ -2009,7 +2195,7 @@ void TextFormat::Printer::PrintField(const Message& message, custom_printers_, field, default_field_value_printer_.get()); const Message& sub_message = field->is_repeated() - ? (is_map ? *map_entries[j] + ? (is_map ? *sorted_map_field[j] : reflection->GetRepeatedMessage(message, field, j)) : reflection->GetMessage(message, field); printer->PrintMessageStart(sub_message, field_index, count, @@ -2030,6 +2216,12 @@ void TextFormat::Printer::PrintField(const Message& message, } } } + + if (need_release) { + for (int j = 0; j < sorted_map_field.size(); ++j) { + delete sorted_map_field[j]; + } + } } void TextFormat::Printer::PrintShortRepeatedField( |