diff options
Diffstat (limited to 'src/google/protobuf/extension_set_heavy.cc')
-rw-r--r-- | src/google/protobuf/extension_set_heavy.cc | 430 |
1 files changed, 339 insertions, 91 deletions
diff --git a/src/google/protobuf/extension_set_heavy.cc b/src/google/protobuf/extension_set_heavy.cc index 372aea57..7c93c61d 100644 --- a/src/google/protobuf/extension_set_heavy.cc +++ b/src/google/protobuf/extension_set_heavy.cc @@ -47,8 +47,10 @@ #include <google/protobuf/wire_format.h> #include <google/protobuf/wire_format_lite_inl.h> -namespace google { +#include <google/protobuf/port_def.inc> + +namespace google { namespace protobuf { namespace internal { @@ -85,9 +87,9 @@ class DescriptorPoolExtensionFinder : public ExtensionFinder { MessageFactory* factory, const Descriptor* containing_type) : pool_(pool), factory_(factory), containing_type_(containing_type) {} - virtual ~DescriptorPoolExtensionFinder() override {} + ~DescriptorPoolExtensionFinder() override {} - virtual bool Find(int number, ExtensionInfo* output) override; + bool Find(int number, ExtensionInfo* output) override; private: const DescriptorPool* pool_; @@ -244,7 +246,7 @@ ExtensionSet::Extension* ExtensionSet::MaybeNewRepeatedExtension(const FieldDesc GOOGLE_DCHECK_EQ(cpp_type(extension->type), FieldDescriptor::CPPTYPE_MESSAGE); extension->is_repeated = true; extension->repeated_message_value = - ::google::protobuf::Arena::CreateMessage<RepeatedPtrField<MessageLite> >(arena_); + Arena::CreateMessage<RepeatedPtrField<MessageLite> >(arena_); } else { GOOGLE_DCHECK_TYPE(*extension, REPEATED, MESSAGE); } @@ -258,7 +260,7 @@ MessageLite* ExtensionSet::AddMessage(const FieldDescriptor* descriptor, // RepeatedPtrField<Message> does not know how to Add() since it cannot // allocate an abstract object, so we have to be tricky. MessageLite* result = - reinterpret_cast<::google::protobuf::internal::RepeatedPtrFieldBase*>( + reinterpret_cast<internal::RepeatedPtrFieldBase*>( extension->repeated_message_value) ->AddFromCleared<GenericTypeHandler<MessageLite> >(); if (result == NULL) { @@ -312,6 +314,235 @@ bool DescriptorPoolExtensionFinder::Find(int number, ExtensionInfo* output) { } } +#if GOOGLE_PROTOBUF_ENABLE_EXPERIMENTAL_PARSER +bool ExtensionSet::FindExtension(uint32 tag, const Message* containing_type, + const internal::ParseContext* ctx, + ExtensionInfo* extension, int* number, + bool* was_packed_on_wire) { + if (ctx->extra_parse_data().pool == nullptr) { + GeneratedExtensionFinder finder(containing_type); + if (!FindExtensionInfoFromTag(tag, &finder, number, extension, + was_packed_on_wire)) { + return false; + } + } else { + DescriptorPoolExtensionFinder finder(ctx->extra_parse_data().pool, + ctx->extra_parse_data().factory, + containing_type->GetDescriptor()); + if (!FindExtensionInfoFromTag(tag, &finder, number, extension, + was_packed_on_wire)) { + return false; + } + } + return true; +} + +std::pair<const char*, bool> ExtensionSet::ParseField( + uint32 tag, ParseClosure parent, const char* begin, const char* end, + const Message* containing_type, + internal::InternalMetadataWithArena* metadata, + internal::ParseContext* ctx) { + int number; + bool was_packed_on_wire; + ExtensionInfo extension; + if (!FindExtension(tag, containing_type, ctx, &extension, &number, + &was_packed_on_wire)) { + return UnknownFieldParse(tag, parent, begin, end, + metadata->mutable_unknown_fields(), ctx); + } + auto ptr = begin; + ParseClosure child; + int depth; + if (was_packed_on_wire) { + switch (extension.type) { +#define HANDLE_TYPE(UPPERCASE, CPP_CAMELCASE) \ + case WireFormatLite::TYPE_##UPPERCASE: \ + child = { \ + internal::Packed##CPP_CAMELCASE##Parser, \ + MutableRawRepeatedField(number, extension.type, extension.is_packed, \ + extension.descriptor)}; \ + goto length_delim + HANDLE_TYPE(INT32, Int32); + HANDLE_TYPE(INT64, Int64); + HANDLE_TYPE(UINT32, UInt32); + HANDLE_TYPE(UINT64, UInt64); + HANDLE_TYPE(SINT32, SInt32); + HANDLE_TYPE(SINT64, SInt64); + HANDLE_TYPE(FIXED32, Fixed32); + HANDLE_TYPE(FIXED64, Fixed64); + HANDLE_TYPE(SFIXED32, SFixed32); + HANDLE_TYPE(SFIXED64, SFixed64); + HANDLE_TYPE(FLOAT, Float); + HANDLE_TYPE(DOUBLE, Double); + HANDLE_TYPE(BOOL, Bool); +#undef HANDLE_TYPE + + case WireFormatLite::TYPE_ENUM: + ctx->extra_parse_data().SetEnumValidatorArg( + extension.enum_validity_check.func, + extension.enum_validity_check.arg, + metadata->mutable_unknown_fields(), tag >> 3); + child = { + internal::PackedValidEnumParserArg, + MutableRawRepeatedField(number, extension.type, extension.is_packed, + extension.descriptor)}; + goto length_delim; + case WireFormatLite::TYPE_STRING: + case WireFormatLite::TYPE_BYTES: + case WireFormatLite::TYPE_GROUP: + case WireFormatLite::TYPE_MESSAGE: + GOOGLE_LOG(FATAL) << "Non-primitive types can't be packed."; + break; + } + } else { + switch (extension.type) { +#define HANDLE_VARINT_TYPE(UPPERCASE, CPP_CAMELCASE) \ + case WireFormatLite::TYPE_##UPPERCASE: { \ + uint64 value; \ + ptr = Varint::Parse64(ptr, &value); \ + if (ptr == nullptr) goto error; \ + if (extension.is_repeated) { \ + Add##CPP_CAMELCASE(number, WireFormatLite::TYPE_##UPPERCASE, \ + extension.is_packed, value, extension.descriptor); \ + } else { \ + Set##CPP_CAMELCASE(number, WireFormatLite::TYPE_##UPPERCASE, value, \ + extension.descriptor); \ + } \ + } break + + HANDLE_VARINT_TYPE(INT32, Int32); + HANDLE_VARINT_TYPE(INT64, Int64); + HANDLE_VARINT_TYPE(UINT32, UInt32); + HANDLE_VARINT_TYPE(UINT64, UInt64); +#undef HANDLE_VARINT_TYPE +#define HANDLE_SVARINT_TYPE(UPPERCASE, CPP_CAMELCASE, SIZE) \ + case WireFormatLite::TYPE_##UPPERCASE: { \ + uint64 val; \ + ptr = Varint::Parse64(ptr, &val); \ + auto value = WireFormatLite::ZigZagDecode##SIZE(val); \ + if (extension.is_repeated) { \ + Add##CPP_CAMELCASE(number, WireFormatLite::TYPE_##UPPERCASE, \ + extension.is_packed, value, extension.descriptor); \ + } else { \ + Set##CPP_CAMELCASE(number, WireFormatLite::TYPE_##UPPERCASE, value, \ + extension.descriptor); \ + } \ + } break + + HANDLE_SVARINT_TYPE(SINT32, Int32, 32); + HANDLE_SVARINT_TYPE(SINT64, Int64, 64); +#undef HANDLE_SVARINT_TYPE +#define HANDLE_FIXED_TYPE(UPPERCASE, CPP_CAMELCASE, CPPTYPE) \ + case WireFormatLite::TYPE_##UPPERCASE: { \ + CPPTYPE value; \ + std::memcpy(&value, ptr, sizeof(CPPTYPE)); \ + ptr += sizeof(CPPTYPE); \ + if (extension.is_repeated) { \ + Add##CPP_CAMELCASE(number, WireFormatLite::TYPE_##UPPERCASE, \ + extension.is_packed, value, extension.descriptor); \ + } else { \ + Set##CPP_CAMELCASE(number, WireFormatLite::TYPE_##UPPERCASE, value, \ + extension.descriptor); \ + } \ + } break + + HANDLE_FIXED_TYPE(FIXED32, UInt32, uint32); + HANDLE_FIXED_TYPE(FIXED64, UInt64, uint64); + HANDLE_FIXED_TYPE(SFIXED32, Int32, int32); + HANDLE_FIXED_TYPE(SFIXED64, Int64, int64); + HANDLE_FIXED_TYPE(FLOAT, Float, float); + HANDLE_FIXED_TYPE(DOUBLE, Double, double); + HANDLE_FIXED_TYPE(BOOL, Bool, bool); +#undef HANDLE_FIXED_TYPE + + case WireFormatLite::TYPE_ENUM: { + uint64 val; + ptr = Varint::Parse64(ptr, &val); + if (ptr == nullptr) goto error; + int value = val; + + if (!extension.enum_validity_check.func( + extension.enum_validity_check.arg, value)) { + WriteVarint(number, val, metadata->mutable_unknown_fields()); + } else if (extension.is_repeated) { + AddEnum(number, WireFormatLite::TYPE_ENUM, extension.is_packed, value, + extension.descriptor); + } else { + SetEnum(number, WireFormatLite::TYPE_ENUM, value, + extension.descriptor); + } + break; + } + + case WireFormatLite::TYPE_BYTES: + case WireFormatLite::TYPE_STRING: { + string* value = extension.is_repeated + ? AddString(number, WireFormatLite::TYPE_STRING, + extension.descriptor) + : MutableString(number, WireFormatLite::TYPE_STRING, + extension.descriptor); + child = {StringParser, value}; + goto length_delim; + } + + case WireFormatLite::TYPE_GROUP: { + MessageLite* value = + extension.is_repeated + ? AddMessage(number, WireFormatLite::TYPE_GROUP, + *extension.message_prototype, extension.descriptor) + : MutableMessage(number, WireFormatLite::TYPE_GROUP, + *extension.message_prototype, + extension.descriptor); + child = {value->_ParseFunc(), value}; + if (!ctx->PrepareGroup(tag, &depth)) goto error; + ptr = child(ptr, end, ctx); + if (!ptr) goto error; + if (ctx->GroupContinues(depth)) goto group_continues; + break; + } + + case WireFormatLite::TYPE_MESSAGE: { + MessageLite* value = + extension.is_repeated + ? AddMessage(number, WireFormatLite::TYPE_MESSAGE, + *extension.message_prototype, extension.descriptor) + : MutableMessage(number, WireFormatLite::TYPE_MESSAGE, + *extension.message_prototype, + extension.descriptor); + child = {value->_ParseFunc(), value}; + goto length_delim; + } + } + } + + return std::make_pair(ptr, false); + +error: + return std::make_pair(nullptr, true); + +length_delim: + uint32 size; + ptr = Varint::Parse32Inline(ptr, &size); + if (!ptr) goto error; + if (size > end - ptr) goto len_delim_till_end; + { + auto newend = ptr + size; + if (!ctx->ParseExactRange(child, ptr, newend)) { + goto error; + } + ptr = newend; + } + return std::make_pair(ptr, false); +len_delim_till_end: + return std::make_pair(ctx->StoreAndTailCall(ptr, end, parent, child, size), + true); + +group_continues: + ctx->StoreGroup(parent, child, depth); + return std::make_pair(ptr, true); +} +#endif // GOOGLE_PROTOBUF_ENABLE_EXPERIMENTAL_PARSER + bool ExtensionSet::ParseField(uint32 tag, io::CodedInputStream* input, const Message* containing_type, UnknownFieldSet* unknown_fields) { @@ -327,6 +558,88 @@ bool ExtensionSet::ParseField(uint32 tag, io::CodedInputStream* input, } } +#if GOOGLE_PROTOBUF_ENABLE_EXPERIMENTAL_PARSER +const char* ExtensionSet::ParseMessageSetItem( + ParseClosure parent, const char* begin, const char* end, + const Message* containing_type, + internal::InternalMetadataWithArena* metadata, + internal::ParseContext* ctx) { + auto ptr = begin; + while (ptr < end) { + uint32 tag = *ptr++; + if (tag == WireFormatLite::kMessageSetTypeIdTag) { + uint32 type_id; + ptr = Varint::Parse32(ptr, &type_id); + if (!ptr) goto error; + + if (ctx->extra_parse_data().payload.empty()) { + tag = *ptr++; + if (tag == WireFormatLite::kMessageSetMessageTag) { + auto res = ParseField(type_id * 8 + 2, parent, ptr, end, + containing_type, metadata, ctx); + ptr = res.first; + if (res.second) break; + } else { + goto error; + } + } else { + ExtensionInfo extension; + GeneratedExtensionFinder finder(containing_type); + int number; + bool was_packed_on_wire; + if (!FindExtension(type_id * 8 + 2, containing_type, ctx, &extension, + &number, &was_packed_on_wire)) { + metadata->mutable_unknown_fields()->AddLengthDelimited( + type_id, ctx->extra_parse_data().payload); + continue; + } + MessageLite* value = + extension.is_repeated + ? AddMessage(number, WireFormatLite::TYPE_MESSAGE, + *extension.message_prototype, extension.descriptor) + : MutableMessage(number, WireFormatLite::TYPE_MESSAGE, + *extension.message_prototype, + extension.descriptor); + ParseClosure parser = {value->_ParseFunc(), value}; + StringPiece chunk(ctx->extra_parse_data().payload.data()); + if (!ctx->ParseExactRange(parser, chunk.begin(), chunk.end())) { + return nullptr; + } + } + } else if (tag == WireFormatLite::kMessageSetItemEndTag) { + if (!ctx->ValidEndGroup(tag)) goto error; + break; + } else if (tag == WireFormatLite::kMessageSetMessageTag) { + uint32 size; + ptr = Varint::Parse32Inline(ptr, &size); + if (!ptr) goto error; + ParseClosure child = {internal::StringParser, + &ctx->extra_parse_data().payload}; + if (size > end - ptr) { + return ctx->StoreAndTailCall(ptr, end, parent, child, size); + } else { + auto newend = ptr + size; + if (!ctx->ParseExactRange(child, ptr, newend)) { + goto error; + } + ptr = newend; + } + } else { + ptr--; + ptr = Varint::Parse32(ptr, &tag); + if (ptr == nullptr) goto error; + auto res = + ParseField(tag, parent, ptr, end, containing_type, metadata, ctx); + ptr = res.first; + if (res.second) break; + } + } + return ptr; +error: + return nullptr; +} +#endif // GOOGLE_PROTOBUF_ENABLE_EXPERIMENTAL_PARSER + bool ExtensionSet::ParseMessageSet(io::CodedInputStream* input, const Message* containing_type, UnknownFieldSet* unknown_fields) { @@ -385,11 +698,10 @@ size_t ExtensionSet::Extension::SpaceUsedExcludingSelfLong() const { // but MessageLite has no SpaceUsedLong(), so we must directly call // RepeatedPtrFieldBase::SpaceUsedExcludingSelfLong() with a different // type handler. - total_size += - sizeof(*repeated_message_value) + - RepeatedMessage_SpaceUsedExcludingSelfLong( - reinterpret_cast<::google::protobuf::internal::RepeatedPtrFieldBase*>( - repeated_message_value)); + total_size += sizeof(*repeated_message_value) + + RepeatedMessage_SpaceUsedExcludingSelfLong( + reinterpret_cast<internal::RepeatedPtrFieldBase*>( + repeated_message_value)); break; } } else { @@ -420,15 +732,13 @@ uint8* ExtensionSet::SerializeWithCachedSizesToArray(int start_field_number, uint8* target) const { return InternalSerializeWithCachedSizesToArray( start_field_number, end_field_number, - google::protobuf::io::CodedOutputStream::IsDefaultSerializationDeterministic(), - target); + io::CodedOutputStream::IsDefaultSerializationDeterministic(), target); } uint8* ExtensionSet::SerializeMessageSetWithCachedSizesToArray( uint8* target) const { return InternalSerializeMessageSetWithCachedSizesToArray( - google::protobuf::io::CodedOutputStream::IsDefaultSerializationDeterministic(), - target); + io::CodedOutputStream::IsDefaultSerializationDeterministic(), target); } uint8* ExtensionSet::InternalSerializeWithCachedSizesToArray( @@ -650,89 +960,27 @@ bool ExtensionSet::ParseMessageSet(io::CodedInputStream* input, } } -bool ExtensionSet::ParseMessageSet(io::CodedInputStream* input, - const MessageLite* containing_type) { - MessageSetFieldSkipper skipper(NULL); - GeneratedExtensionFinder finder(containing_type); - return ParseMessageSet(input, &finder, &skipper); -} - bool ExtensionSet::ParseMessageSetItem(io::CodedInputStream* input, ExtensionFinder* extension_finder, MessageSetFieldSkipper* field_skipper) { - // TODO(kenton): It would be nice to share code between this and - // WireFormatLite::ParseAndMergeMessageSetItem(), but I think the - // differences would be hard to factor out. - - // This method parses a group which should contain two fields: - // required int32 type_id = 2; - // required data message = 3; - - uint32 last_type_id = 0; - - // If we see message data before the type_id, we'll append it to this so - // we can parse it later. - string message_data; - - while (true) { - const uint32 tag = input->ReadTagNoLastTag(); - if (tag == 0) return false; - - switch (tag) { - case WireFormatLite::kMessageSetTypeIdTag: { - uint32 type_id; - if (!input->ReadVarint32(&type_id)) return false; - last_type_id = type_id; - - if (!message_data.empty()) { - // We saw some message data before the type_id. Have to parse it - // now. - io::CodedInputStream sub_input( - reinterpret_cast<const uint8*>(message_data.data()), - message_data.size()); - if (!ParseFieldMaybeLazily(WireFormatLite::WIRETYPE_LENGTH_DELIMITED, - last_type_id, &sub_input, - extension_finder, field_skipper)) { - return false; - } - message_data.clear(); - } - - break; - } - - case WireFormatLite::kMessageSetMessageTag: { - if (last_type_id == 0) { - // We haven't seen a type_id yet. Append this data to message_data. - string temp; - uint32 length; - if (!input->ReadVarint32(&length)) return false; - if (!input->ReadString(&temp, length)) return false; - io::StringOutputStream output_stream(&message_data); - io::CodedOutputStream coded_output(&output_stream); - coded_output.WriteVarint32(length); - coded_output.WriteString(temp); - } else { - // Already saw type_id, so we can parse this directly. - if (!ParseFieldMaybeLazily(WireFormatLite::WIRETYPE_LENGTH_DELIMITED, - last_type_id, input, - extension_finder, field_skipper)) { - return false; - } - } + struct MSFull { + bool ParseField(int type_id, io::CodedInputStream* input) { + return me->ParseFieldMaybeLazily( + WireFormatLite::WIRETYPE_LENGTH_DELIMITED, type_id, input, + extension_finder, field_skipper); + } - break; - } + bool SkipField(uint32 tag, io::CodedInputStream* input) { + return field_skipper->SkipField(input, tag); + } - case WireFormatLite::kMessageSetItemEndTag: { - return true; - } + ExtensionSet* me; + ExtensionFinder* extension_finder; + MessageSetFieldSkipper* field_skipper; + }; - default: { - if (!field_skipper->SkipField(input, tag)) return false; - } - } - } + return ParseMessageSetItemImpl(input, + MSFull{this, extension_finder, field_skipper}); } void ExtensionSet::Extension::SerializeMessageSetItemWithCachedSizes( |