diff options
Diffstat (limited to 'src/google/protobuf/extension_set_heavy.cc')
-rw-r--r-- | src/google/protobuf/extension_set_heavy.cc | 274 |
1 files changed, 264 insertions, 10 deletions
diff --git a/src/google/protobuf/extension_set_heavy.cc b/src/google/protobuf/extension_set_heavy.cc index 2721f15d..483d7055 100644 --- a/src/google/protobuf/extension_set_heavy.cc +++ b/src/google/protobuf/extension_set_heavy.cc @@ -35,17 +35,20 @@ // Contains methods defined in extension_set.h which cannot be part of the // lite library because they use descriptors or reflection. -#include <google/protobuf/extension_set.h> +#include <google/protobuf/io/zero_copy_stream_impl_lite.h> #include <google/protobuf/descriptor.h> +#include <google/protobuf/extension_set.h> #include <google/protobuf/message.h> #include <google/protobuf/repeated_field.h> #include <google/protobuf/wire_format.h> #include <google/protobuf/wire_format_lite_inl.h> namespace google { + namespace protobuf { namespace internal { + // Implementation of ExtensionFinder which finds extensions in a given // DescriptorPool, using the given MessageFactory to construct sub-objects. // This class is implemented in extension_set_heavy.cc. @@ -103,6 +106,11 @@ inline FieldDescriptor::CppType cpp_type(FieldType type) { static_cast<FieldDescriptor::Type>(type)); } +inline WireFormatLite::FieldType field_type(FieldType type) { + GOOGLE_DCHECK(type > 0 && type <= WireFormatLite::MAX_FIELD_TYPE); + return static_cast<WireFormatLite::FieldType>(type); +} + #define GOOGLE_DCHECK_TYPE(EXTENSION, LABEL, CPPTYPE) \ GOOGLE_DCHECK_EQ((EXTENSION).is_repeated ? FieldDescriptor::LABEL_REPEATED \ : FieldDescriptor::LABEL_OPTIONAL, \ @@ -118,7 +126,12 @@ const MessageLite& ExtensionSet::GetMessage(int number, return *factory->GetPrototype(message_type); } else { GOOGLE_DCHECK_TYPE(iter->second, OPTIONAL, MESSAGE); - return *iter->second.message_value; + if (iter->second.is_lazy) { + return iter->second.lazymessage_value->GetMessage( + *factory->GetPrototype(message_type)); + } else { + return *iter->second.message_value; + } } } @@ -132,13 +145,41 @@ MessageLite* ExtensionSet::MutableMessage(const FieldDescriptor* descriptor, extension->is_packed = false; const MessageLite* prototype = factory->GetPrototype(descriptor->message_type()); - GOOGLE_CHECK(prototype != NULL); + extension->is_lazy = false; extension->message_value = prototype->New(); + extension->is_cleared = false; + return extension->message_value; } else { GOOGLE_DCHECK_TYPE(*extension, OPTIONAL, MESSAGE); + extension->is_cleared = false; + if (extension->is_lazy) { + return extension->lazymessage_value->MutableMessage( + *factory->GetPrototype(descriptor->message_type())); + } else { + return extension->message_value; + } + } +} + +MessageLite* ExtensionSet::ReleaseMessage(const FieldDescriptor* descriptor, + MessageFactory* factory) { + map<int, Extension>::iterator iter = extensions_.find(descriptor->number()); + if (iter == extensions_.end()) { + // Not present. Return NULL. + return NULL; + } else { + GOOGLE_DCHECK_TYPE(iter->second, OPTIONAL, MESSAGE); + MessageLite* ret = NULL; + if (iter->second.is_lazy) { + ret = iter->second.lazymessage_value->ReleaseMessage( + *factory->GetPrototype(descriptor->message_type())); + delete iter->second.lazymessage_value; + } else { + ret = iter->second.message_value; + } + extensions_.erase(descriptor->number()); + return ret; } - extension->is_cleared = false; - return extension->message_value; } MessageLite* ExtensionSet::AddMessage(const FieldDescriptor* descriptor, @@ -157,7 +198,7 @@ MessageLite* ExtensionSet::AddMessage(const FieldDescriptor* descriptor, // RepeatedPtrField<Message> does not know how to Add() since it cannot // allocate an abstract object, so we have to be tricky. MessageLite* result = extension->repeated_message_value - ->AddFromCleared<internal::GenericTypeHandler<MessageLite> >(); + ->AddFromCleared<GenericTypeHandler<MessageLite> >(); if (result == NULL) { const MessageLite* prototype; if (extension->repeated_message_value->size() == 0) { @@ -286,7 +327,11 @@ int ExtensionSet::Extension::SpaceUsedExcludingSelf() const { StringSpaceUsedExcludingSelf(*string_value); break; case FieldDescriptor::CPPTYPE_MESSAGE: - total_size += down_cast<Message*>(message_value)->SpaceUsed(); + if (is_lazy) { + total_size += lazymessage_value->SpaceUsed(); + } else { + total_size += down_cast<Message*>(message_value)->SpaceUsed(); + } break; default: // No extra storage costs for primitive types. @@ -419,8 +464,15 @@ uint8* ExtensionSet::Extension::SerializeFieldWithCachedSizesToArray( HANDLE_TYPE( BYTES, Bytes, *string_value); HANDLE_TYPE( ENUM, Enum, enum_value); HANDLE_TYPE( GROUP, Group, *message_value); - HANDLE_TYPE( MESSAGE, Message, *message_value); #undef HANDLE_TYPE + case FieldDescriptor::TYPE_MESSAGE: + if (is_lazy) { + target = lazymessage_value->WriteMessageToArray(number, target); + } else { + target = WireFormatLite::WriteMessageToArray( + number, *message_value, target); + } + break; } } return target; @@ -444,14 +496,216 @@ uint8* ExtensionSet::Extension::SerializeMessageSetItemWithCachedSizesToArray( target = WireFormatLite::WriteUInt32ToArray( WireFormatLite::kMessageSetTypeIdNumber, number, target); // Write message. - target = WireFormatLite::WriteMessageToArray( - WireFormatLite::kMessageSetMessageNumber, *message_value, target); + if (is_lazy) { + target = lazymessage_value->WriteMessageToArray( + WireFormatLite::kMessageSetMessageNumber, target); + } else { + target = WireFormatLite::WriteMessageToArray( + WireFormatLite::kMessageSetMessageNumber, *message_value, target); + } // End group. target = io::CodedOutputStream::WriteTagToArray( WireFormatLite::kMessageSetItemEndTag, target); return target; } + +bool ExtensionSet::ParseFieldMaybeLazily( + uint32 tag, io::CodedInputStream* input, + ExtensionFinder* extension_finder, + FieldSkipper* field_skipper) { + return ParseField(tag, input, extension_finder, field_skipper); +} + +bool ExtensionSet::ParseMessageSet(io::CodedInputStream* input, + ExtensionFinder* extension_finder, + FieldSkipper* field_skipper) { + while (true) { + uint32 tag = input->ReadTag(); + switch (tag) { + case 0: + return true; + case WireFormatLite::kMessageSetItemStartTag: + if (!ParseMessageSetItem(input, extension_finder, field_skipper)) { + return false; + } + break; + default: + if (!ParseField(tag, input, extension_finder, field_skipper)) { + return false; + } + break; + } + } +} + +bool ExtensionSet::ParseMessageSet(io::CodedInputStream* input, + const MessageLite* containing_type) { + FieldSkipper skipper; + GeneratedExtensionFinder finder(containing_type); + return ParseMessageSet(input, &finder, &skipper); +} + +bool ExtensionSet::ParseMessageSetItem(io::CodedInputStream* input, + ExtensionFinder* extension_finder, + FieldSkipper* field_skipper) { + // TODO(kenton): It would be nice to share code between this and + // WireFormatLite::ParseAndMergeMessageSetItem(), but I think the + // differences would be hard to factor out. + + // This method parses a group which should contain two fields: + // required int32 type_id = 2; + // required data message = 3; + + // Once we see a type_id, we'll construct a fake tag for this extension + // which is the tag it would have had under the proto2 extensions wire + // format. + uint32 fake_tag = 0; + + // If we see message data before the type_id, we'll append it to this so + // we can parse it later. + string message_data; + + while (true) { + uint32 tag = input->ReadTag(); + if (tag == 0) return false; + + switch (tag) { + case WireFormatLite::kMessageSetTypeIdTag: { + uint32 type_id; + if (!input->ReadVarint32(&type_id)) return false; + fake_tag = WireFormatLite::MakeTag(type_id, + WireFormatLite::WIRETYPE_LENGTH_DELIMITED); + + if (!message_data.empty()) { + // We saw some message data before the type_id. Have to parse it + // now. + io::CodedInputStream sub_input( + reinterpret_cast<const uint8*>(message_data.data()), + message_data.size()); + if (!ParseFieldMaybeLazily(fake_tag, &sub_input, + extension_finder, field_skipper)) { + return false; + } + message_data.clear(); + } + + break; + } + + case WireFormatLite::kMessageSetMessageTag: { + if (fake_tag == 0) { + // We haven't seen a type_id yet. Append this data to message_data. + string temp; + uint32 length; + if (!input->ReadVarint32(&length)) return false; + if (!input->ReadString(&temp, length)) return false; + io::StringOutputStream output_stream(&message_data); + io::CodedOutputStream coded_output(&output_stream); + coded_output.WriteVarint32(length); + coded_output.WriteString(temp); + } else { + // Already saw type_id, so we can parse this directly. + if (!ParseFieldMaybeLazily(fake_tag, input, + extension_finder, field_skipper)) { + return false; + } + } + + break; + } + + case WireFormatLite::kMessageSetItemEndTag: { + return true; + } + + default: { + if (!field_skipper->SkipField(input, tag)) return false; + } + } + } +} + +void ExtensionSet::Extension::SerializeMessageSetItemWithCachedSizes( + int number, + io::CodedOutputStream* output) const { + if (type != WireFormatLite::TYPE_MESSAGE || is_repeated) { + // Not a valid MessageSet extension, but serialize it the normal way. + SerializeFieldWithCachedSizes(number, output); + return; + } + + if (is_cleared) return; + + // Start group. + output->WriteTag(WireFormatLite::kMessageSetItemStartTag); + + // Write type ID. + WireFormatLite::WriteUInt32(WireFormatLite::kMessageSetTypeIdNumber, + number, + output); + // Write message. + if (is_lazy) { + lazymessage_value->WriteMessage( + WireFormatLite::kMessageSetMessageNumber, output); + } else { + WireFormatLite::WriteMessageMaybeToArray( + WireFormatLite::kMessageSetMessageNumber, + *message_value, + output); + } + + // End group. + output->WriteTag(WireFormatLite::kMessageSetItemEndTag); +} + +int ExtensionSet::Extension::MessageSetItemByteSize(int number) const { + if (type != WireFormatLite::TYPE_MESSAGE || is_repeated) { + // Not a valid MessageSet extension, but compute the byte size for it the + // normal way. + return ByteSize(number); + } + + if (is_cleared) return 0; + + int our_size = WireFormatLite::kMessageSetItemTagsSize; + + // type_id + our_size += io::CodedOutputStream::VarintSize32(number); + + // message + int message_size = 0; + if (is_lazy) { + message_size = lazymessage_value->ByteSize(); + } else { + message_size = message_value->ByteSize(); + } + + our_size += io::CodedOutputStream::VarintSize32(message_size); + our_size += message_size; + + return our_size; +} + +void ExtensionSet::SerializeMessageSetWithCachedSizes( + io::CodedOutputStream* output) const { + map<int, Extension>::const_iterator iter; + for (iter = extensions_.begin(); iter != extensions_.end(); ++iter) { + iter->second.SerializeMessageSetItemWithCachedSizes(iter->first, output); + } +} + +int ExtensionSet::MessageSetByteSize() const { + int total_size = 0; + + for (map<int, Extension>::const_iterator iter = extensions_.begin(); + iter != extensions_.end(); ++iter) { + total_size += iter->second.MessageSetItemByteSize(iter->first); + } + + return total_size; +} + } // namespace internal } // namespace protobuf } // namespace google |