diff options
Diffstat (limited to 'src/google/protobuf/extension_set_heavy.cc')
-rw-r--r-- | src/google/protobuf/extension_set_heavy.cc | 500 |
1 files changed, 333 insertions, 167 deletions
diff --git a/src/google/protobuf/extension_set_heavy.cc b/src/google/protobuf/extension_set_heavy.cc index a3c84167..20d36ab7 100644 --- a/src/google/protobuf/extension_set_heavy.cc +++ b/src/google/protobuf/extension_set_heavy.cc @@ -45,10 +45,13 @@ #include <google/protobuf/repeated_field.h> #include <google/protobuf/unknown_field_set.h> #include <google/protobuf/wire_format.h> +#include <google/protobuf/wire_format_lite.h> #include <google/protobuf/wire_format_lite_inl.h> -namespace google { +#include <google/protobuf/port_def.inc> + +namespace google { namespace protobuf { namespace internal { @@ -85,9 +88,9 @@ class DescriptorPoolExtensionFinder : public ExtensionFinder { MessageFactory* factory, const Descriptor* containing_type) : pool_(pool), factory_(factory), containing_type_(containing_type) {} - virtual ~DescriptorPoolExtensionFinder() {} + ~DescriptorPoolExtensionFinder() override {} - virtual bool Find(int number, ExtensionInfo* output); + bool Find(int number, ExtensionInfo* output) override; private: const DescriptorPool* pool_; @@ -244,7 +247,7 @@ ExtensionSet::Extension* ExtensionSet::MaybeNewRepeatedExtension(const FieldDesc GOOGLE_DCHECK_EQ(cpp_type(extension->type), FieldDescriptor::CPPTYPE_MESSAGE); extension->is_repeated = true; extension->repeated_message_value = - ::google::protobuf::Arena::CreateMessage<RepeatedPtrField<MessageLite> >(arena_); + Arena::CreateMessage<RepeatedPtrField<MessageLite> >(arena_); } else { GOOGLE_DCHECK_TYPE(*extension, REPEATED, MESSAGE); } @@ -258,7 +261,7 @@ MessageLite* ExtensionSet::AddMessage(const FieldDescriptor* descriptor, // RepeatedPtrField<Message> does not know how to Add() since it cannot // allocate an abstract object, so we have to be tricky. MessageLite* result = - reinterpret_cast<::google::protobuf::internal::RepeatedPtrFieldBase*>( + reinterpret_cast<internal::RepeatedPtrFieldBase*>( extension->repeated_message_value) ->AddFromCleared<GenericTypeHandler<MessageLite> >(); if (result == NULL) { @@ -312,6 +315,234 @@ bool DescriptorPoolExtensionFinder::Find(int number, ExtensionInfo* output) { } } +#if GOOGLE_PROTOBUF_ENABLE_EXPERIMENTAL_PARSER +bool ExtensionSet::FindExtension(int wire_type, uint32 field, + const Message* containing_type, + const internal::ParseContext* ctx, + ExtensionInfo* extension, + bool* was_packed_on_wire) { + if (ctx->extra_parse_data().pool == nullptr) { + GeneratedExtensionFinder finder(containing_type); + if (!FindExtensionInfoFromFieldNumber(wire_type, field, &finder, extension, + was_packed_on_wire)) { + return false; + } + } else { + DescriptorPoolExtensionFinder finder(ctx->extra_parse_data().pool, + ctx->extra_parse_data().factory, + containing_type->GetDescriptor()); + if (!FindExtensionInfoFromFieldNumber(wire_type, field, &finder, extension, + was_packed_on_wire)) { + return false; + } + } + return true; +} + +std::pair<const char*, bool> ExtensionSet::ParseField( + uint64 tag, ParseClosure parent, const char* begin, const char* end, + const Message* containing_type, + internal::InternalMetadataWithArena* metadata, + internal::ParseContext* ctx) { + int number = tag >> 3; + bool was_packed_on_wire; + ExtensionInfo extension; + if (!FindExtension(tag & 7, number, containing_type, ctx, &extension, + &was_packed_on_wire)) { + return UnknownFieldParse(tag, parent, begin, end, + metadata->mutable_unknown_fields(), ctx); + } + auto ptr = begin; + ParseClosure child; + int depth; + if (was_packed_on_wire) { + switch (extension.type) { +#define HANDLE_TYPE(UPPERCASE, CPP_CAMELCASE) \ + case WireFormatLite::TYPE_##UPPERCASE: \ + child = { \ + internal::Packed##CPP_CAMELCASE##Parser, \ + MutableRawRepeatedField(number, extension.type, extension.is_packed, \ + extension.descriptor)}; \ + goto length_delim + HANDLE_TYPE(INT32, Int32); + HANDLE_TYPE(INT64, Int64); + HANDLE_TYPE(UINT32, UInt32); + HANDLE_TYPE(UINT64, UInt64); + HANDLE_TYPE(SINT32, SInt32); + HANDLE_TYPE(SINT64, SInt64); + HANDLE_TYPE(FIXED32, Fixed32); + HANDLE_TYPE(FIXED64, Fixed64); + HANDLE_TYPE(SFIXED32, SFixed32); + HANDLE_TYPE(SFIXED64, SFixed64); + HANDLE_TYPE(FLOAT, Float); + HANDLE_TYPE(DOUBLE, Double); + HANDLE_TYPE(BOOL, Bool); +#undef HANDLE_TYPE + + case WireFormatLite::TYPE_ENUM: + ctx->extra_parse_data().SetEnumValidatorArg( + extension.enum_validity_check.func, + extension.enum_validity_check.arg, + metadata->mutable_unknown_fields(), tag >> 3); + child = { + internal::PackedValidEnumParserArg, + MutableRawRepeatedField(number, extension.type, extension.is_packed, + extension.descriptor)}; + goto length_delim; + case WireFormatLite::TYPE_STRING: + case WireFormatLite::TYPE_BYTES: + case WireFormatLite::TYPE_GROUP: + case WireFormatLite::TYPE_MESSAGE: + GOOGLE_LOG(FATAL) << "Non-primitive types can't be packed."; + break; + } + } else { + switch (extension.type) { +#define HANDLE_VARINT_TYPE(UPPERCASE, CPP_CAMELCASE) \ + case WireFormatLite::TYPE_##UPPERCASE: { \ + uint64 value; \ + ptr = Varint::Parse64(ptr, &value); \ + GOOGLE_PROTOBUF_ASSERT_RETURN(ptr, std::make_pair(nullptr, true)); \ + if (extension.is_repeated) { \ + Add##CPP_CAMELCASE(number, WireFormatLite::TYPE_##UPPERCASE, \ + extension.is_packed, value, extension.descriptor); \ + } else { \ + Set##CPP_CAMELCASE(number, WireFormatLite::TYPE_##UPPERCASE, value, \ + extension.descriptor); \ + } \ + } break + + HANDLE_VARINT_TYPE(INT32, Int32); + HANDLE_VARINT_TYPE(INT64, Int64); + HANDLE_VARINT_TYPE(UINT32, UInt32); + HANDLE_VARINT_TYPE(UINT64, UInt64); +#undef HANDLE_VARINT_TYPE +#define HANDLE_SVARINT_TYPE(UPPERCASE, CPP_CAMELCASE, SIZE) \ + case WireFormatLite::TYPE_##UPPERCASE: { \ + uint64 val; \ + ptr = Varint::Parse64(ptr, &val); \ + GOOGLE_PROTOBUF_ASSERT_RETURN(ptr, std::make_pair(nullptr, true)); \ + auto value = WireFormatLite::ZigZagDecode##SIZE(val); \ + if (extension.is_repeated) { \ + Add##CPP_CAMELCASE(number, WireFormatLite::TYPE_##UPPERCASE, \ + extension.is_packed, value, extension.descriptor); \ + } else { \ + Set##CPP_CAMELCASE(number, WireFormatLite::TYPE_##UPPERCASE, value, \ + extension.descriptor); \ + } \ + } break + + HANDLE_SVARINT_TYPE(SINT32, Int32, 32); + HANDLE_SVARINT_TYPE(SINT64, Int64, 64); +#undef HANDLE_SVARINT_TYPE +#define HANDLE_FIXED_TYPE(UPPERCASE, CPP_CAMELCASE, CPPTYPE) \ + case WireFormatLite::TYPE_##UPPERCASE: { \ + CPPTYPE value; \ + std::memcpy(&value, ptr, sizeof(CPPTYPE)); \ + ptr += sizeof(CPPTYPE); \ + if (extension.is_repeated) { \ + Add##CPP_CAMELCASE(number, WireFormatLite::TYPE_##UPPERCASE, \ + extension.is_packed, value, extension.descriptor); \ + } else { \ + Set##CPP_CAMELCASE(number, WireFormatLite::TYPE_##UPPERCASE, value, \ + extension.descriptor); \ + } \ + } break + + HANDLE_FIXED_TYPE(FIXED32, UInt32, uint32); + HANDLE_FIXED_TYPE(FIXED64, UInt64, uint64); + HANDLE_FIXED_TYPE(SFIXED32, Int32, int32); + HANDLE_FIXED_TYPE(SFIXED64, Int64, int64); + HANDLE_FIXED_TYPE(FLOAT, Float, float); + HANDLE_FIXED_TYPE(DOUBLE, Double, double); + HANDLE_FIXED_TYPE(BOOL, Bool, bool); +#undef HANDLE_FIXED_TYPE + + case WireFormatLite::TYPE_ENUM: { + uint64 val; + ptr = Varint::Parse64(ptr, &val); + GOOGLE_PROTOBUF_ASSERT_RETURN(ptr, std::make_pair(nullptr, true)); + int value = val; + + if (!extension.enum_validity_check.func( + extension.enum_validity_check.arg, value)) { + WriteVarint(number, val, metadata->mutable_unknown_fields()); + } else if (extension.is_repeated) { + AddEnum(number, WireFormatLite::TYPE_ENUM, extension.is_packed, value, + extension.descriptor); + } else { + SetEnum(number, WireFormatLite::TYPE_ENUM, value, + extension.descriptor); + } + break; + } + + case WireFormatLite::TYPE_BYTES: + case WireFormatLite::TYPE_STRING: { + string* value = extension.is_repeated + ? AddString(number, WireFormatLite::TYPE_STRING, + extension.descriptor) + : MutableString(number, WireFormatLite::TYPE_STRING, + extension.descriptor); + child = {StringParser, value}; + goto length_delim; + } + + case WireFormatLite::TYPE_GROUP: { + MessageLite* value = + extension.is_repeated + ? AddMessage(number, WireFormatLite::TYPE_GROUP, + *extension.message_prototype, extension.descriptor) + : MutableMessage(number, WireFormatLite::TYPE_GROUP, + *extension.message_prototype, + extension.descriptor); + child = {value->_ParseFunc(), value}; + bool ok = ctx->PrepareGroup(tag, &depth); + GOOGLE_PROTOBUF_ASSERT_RETURN(ok, std::make_pair(nullptr, true)); + ptr = child(ptr, end, ctx); + GOOGLE_PROTOBUF_ASSERT_RETURN(ptr, std::make_pair(nullptr, true)); + if (ctx->GroupContinues(depth)) goto group_continues; + break; + } + + case WireFormatLite::TYPE_MESSAGE: { + MessageLite* value = + extension.is_repeated + ? AddMessage(number, WireFormatLite::TYPE_MESSAGE, + *extension.message_prototype, extension.descriptor) + : MutableMessage(number, WireFormatLite::TYPE_MESSAGE, + *extension.message_prototype, + extension.descriptor); + child = {value->_ParseFunc(), value}; + goto length_delim; + } + } + } + + return std::make_pair(ptr, false); + +length_delim: + uint32 size; + ptr = Varint::Parse32Inline(ptr, &size); + GOOGLE_PROTOBUF_ASSERT_RETURN(ptr, std::make_pair(nullptr, true)); + if (size > end - ptr) goto len_delim_till_end; + { + auto newend = ptr + size; + bool ok = ctx->ParseExactRange(child, ptr, newend); + GOOGLE_PROTOBUF_ASSERT_RETURN(ok, std::make_pair(nullptr, true)); + ptr = newend; + } + return std::make_pair(ptr, false); +len_delim_till_end: + return std::make_pair(ctx->StoreAndTailCall(ptr, end, parent, child, size), + true); + +group_continues: + ctx->StoreGroup(parent, child, depth); + return std::make_pair(ptr, true); +} +#endif // GOOGLE_PROTOBUF_ENABLE_EXPERIMENTAL_PARSER + bool ExtensionSet::ParseField(uint32 tag, io::CodedInputStream* input, const Message* containing_type, UnknownFieldSet* unknown_fields) { @@ -327,6 +558,82 @@ bool ExtensionSet::ParseField(uint32 tag, io::CodedInputStream* input, } } +#if GOOGLE_PROTOBUF_ENABLE_EXPERIMENTAL_PARSER +const char* ExtensionSet::ParseMessageSetItem( + ParseClosure parent, const char* begin, const char* end, + const Message* containing_type, + internal::InternalMetadataWithArena* metadata, + internal::ParseContext* ctx) { + auto ptr = begin; + while (ptr < end) { + uint32 tag = *ptr++; + if (tag == WireFormatLite::kMessageSetTypeIdTag) { + uint32 type_id; + ptr = Varint::Parse32(ptr, &type_id); + GOOGLE_PROTOBUF_PARSER_ASSERT(ptr); + + if (ctx->extra_parse_data().payload.empty()) { + tag = *ptr++; + GOOGLE_PROTOBUF_PARSER_ASSERT(tag == + WireFormatLite::kMessageSetMessageTag); + auto res = ParseField(static_cast<uint64>(type_id) * 8 + 2, parent, ptr, + end, containing_type, metadata, ctx); + ptr = res.first; + if (res.second) break; + } else { + ExtensionInfo extension; + GeneratedExtensionFinder finder(containing_type); + bool was_packed_on_wire; + if (!FindExtension(2, type_id, containing_type, ctx, &extension, + &was_packed_on_wire)) { + metadata->mutable_unknown_fields()->AddLengthDelimited( + type_id, ctx->extra_parse_data().payload); + continue; + } + MessageLite* value = + extension.is_repeated + ? AddMessage(type_id, WireFormatLite::TYPE_MESSAGE, + *extension.message_prototype, extension.descriptor) + : MutableMessage(type_id, WireFormatLite::TYPE_MESSAGE, + *extension.message_prototype, + extension.descriptor); + ParseClosure parser = {value->_ParseFunc(), value}; + StringPiece chunk(ctx->extra_parse_data().payload.data()); + bool ok = ctx->ParseExactRange(parser, chunk.begin(), chunk.end()); + GOOGLE_PROTOBUF_PARSER_ASSERT(ok); + } + } else if (tag == WireFormatLite::kMessageSetItemEndTag) { + bool ok = ctx->ValidEndGroup(tag); + GOOGLE_PROTOBUF_PARSER_ASSERT(ok); + break; + } else if (tag == WireFormatLite::kMessageSetMessageTag) { + uint32 size; + ptr = Varint::Parse32Inline(ptr, &size); + GOOGLE_PROTOBUF_PARSER_ASSERT(ptr); + ParseClosure child = {internal::StringParser, + &ctx->extra_parse_data().payload}; + if (size > end - ptr) { + return ctx->StoreAndTailCall(ptr, end, parent, child, size); + } else { + auto newend = ptr + size; + bool ok = ctx->ParseExactRange(child, ptr, newend); + GOOGLE_PROTOBUF_PARSER_ASSERT(ok); + ptr = newend; + } + } else { + ptr--; + ptr = Varint::Parse32(ptr, &tag); + GOOGLE_PROTOBUF_PARSER_ASSERT(ptr); + auto res = + ParseField(tag, parent, ptr, end, containing_type, metadata, ctx); + ptr = res.first; + if (res.second) break; + } + } + return ptr; +} +#endif // GOOGLE_PROTOBUF_ENABLE_EXPERIMENTAL_PARSER + bool ExtensionSet::ParseMessageSet(io::CodedInputStream* input, const Message* containing_type, UnknownFieldSet* unknown_fields) { @@ -385,11 +692,10 @@ size_t ExtensionSet::Extension::SpaceUsedExcludingSelfLong() const { // but MessageLite has no SpaceUsedLong(), so we must directly call // RepeatedPtrFieldBase::SpaceUsedExcludingSelfLong() with a different // type handler. - total_size += - sizeof(*repeated_message_value) + - RepeatedMessage_SpaceUsedExcludingSelfLong( - reinterpret_cast<::google::protobuf::internal::RepeatedPtrFieldBase*>( - repeated_message_value)); + total_size += sizeof(*repeated_message_value) + + RepeatedMessage_SpaceUsedExcludingSelfLong( + reinterpret_cast<internal::RepeatedPtrFieldBase*>( + repeated_message_value)); break; } } else { @@ -420,21 +726,19 @@ uint8* ExtensionSet::SerializeWithCachedSizesToArray(int start_field_number, uint8* target) const { return InternalSerializeWithCachedSizesToArray( start_field_number, end_field_number, - google::protobuf::io::CodedOutputStream::IsDefaultSerializationDeterministic(), - target); + io::CodedOutputStream::IsDefaultSerializationDeterministic(), target); } uint8* ExtensionSet::SerializeMessageSetWithCachedSizesToArray( uint8* target) const { return InternalSerializeMessageSetWithCachedSizesToArray( - google::protobuf::io::CodedOutputStream::IsDefaultSerializationDeterministic(), - target); + io::CodedOutputStream::IsDefaultSerializationDeterministic(), target); } uint8* ExtensionSet::InternalSerializeWithCachedSizesToArray( int start_field_number, int end_field_number, bool deterministic, uint8* target) const { - if (GOOGLE_PREDICT_FALSE(is_large())) { + if (PROTOBUF_PREDICT_FALSE(is_large())) { const auto& end = map_.large->end(); for (auto it = map_.large->lower_bound(start_field_number); it != end && it->first < end_field_number; ++it) { @@ -650,165 +954,27 @@ bool ExtensionSet::ParseMessageSet(io::CodedInputStream* input, } } -bool ExtensionSet::ParseMessageSet(io::CodedInputStream* input, - const MessageLite* containing_type) { - MessageSetFieldSkipper skipper(NULL); - GeneratedExtensionFinder finder(containing_type); - return ParseMessageSet(input, &finder, &skipper); -} - bool ExtensionSet::ParseMessageSetItem(io::CodedInputStream* input, ExtensionFinder* extension_finder, MessageSetFieldSkipper* field_skipper) { - // TODO(kenton): It would be nice to share code between this and - // WireFormatLite::ParseAndMergeMessageSetItem(), but I think the - // differences would be hard to factor out. - - // This method parses a group which should contain two fields: - // required int32 type_id = 2; - // required data message = 3; - - uint32 last_type_id = 0; - - // If we see message data before the type_id, we'll append it to this so - // we can parse it later. - string message_data; - - while (true) { - const uint32 tag = input->ReadTagNoLastTag(); - if (tag == 0) return false; - - switch (tag) { - case WireFormatLite::kMessageSetTypeIdTag: { - uint32 type_id; - if (!input->ReadVarint32(&type_id)) return false; - last_type_id = type_id; - - if (!message_data.empty()) { - // We saw some message data before the type_id. Have to parse it - // now. - io::CodedInputStream sub_input( - reinterpret_cast<const uint8*>(message_data.data()), - message_data.size()); - if (!ParseFieldMaybeLazily(WireFormatLite::WIRETYPE_LENGTH_DELIMITED, - last_type_id, &sub_input, - extension_finder, field_skipper)) { - return false; - } - message_data.clear(); - } - - break; - } - - case WireFormatLite::kMessageSetMessageTag: { - if (last_type_id == 0) { - // We haven't seen a type_id yet. Append this data to message_data. - string temp; - uint32 length; - if (!input->ReadVarint32(&length)) return false; - if (!input->ReadString(&temp, length)) return false; - io::StringOutputStream output_stream(&message_data); - io::CodedOutputStream coded_output(&output_stream); - coded_output.WriteVarint32(length); - coded_output.WriteString(temp); - } else { - // Already saw type_id, so we can parse this directly. - if (!ParseFieldMaybeLazily(WireFormatLite::WIRETYPE_LENGTH_DELIMITED, - last_type_id, input, - extension_finder, field_skipper)) { - return false; - } - } - - break; - } - - case WireFormatLite::kMessageSetItemEndTag: { - return true; - } - - default: { - if (!field_skipper->SkipField(input, tag)) return false; - } + struct MSFull { + bool ParseField(int type_id, io::CodedInputStream* input) { + return me->ParseFieldMaybeLazily( + WireFormatLite::WIRETYPE_LENGTH_DELIMITED, type_id, input, + extension_finder, field_skipper); } - } -} - -void ExtensionSet::Extension::SerializeMessageSetItemWithCachedSizes( - int number, - io::CodedOutputStream* output) const { - if (type != WireFormatLite::TYPE_MESSAGE || is_repeated) { - // Not a valid MessageSet extension, but serialize it the normal way. - SerializeFieldWithCachedSizes(number, output); - return; - } - - if (is_cleared) return; - - // Start group. - output->WriteTag(WireFormatLite::kMessageSetItemStartTag); - - // Write type ID. - WireFormatLite::WriteUInt32(WireFormatLite::kMessageSetTypeIdNumber, - number, - output); - // Write message. - if (is_lazy) { - lazymessage_value->WriteMessage( - WireFormatLite::kMessageSetMessageNumber, output); - } else { - WireFormatLite::WriteMessageMaybeToArray( - WireFormatLite::kMessageSetMessageNumber, - *message_value, - output); - } - - // End group. - output->WriteTag(WireFormatLite::kMessageSetItemEndTag); -} - -size_t ExtensionSet::Extension::MessageSetItemByteSize(int number) const { - if (type != WireFormatLite::TYPE_MESSAGE || is_repeated) { - // Not a valid MessageSet extension, but compute the byte size for it the - // normal way. - return ByteSize(number); - } - - if (is_cleared) return 0; - - size_t our_size = WireFormatLite::kMessageSetItemTagsSize; - // type_id - our_size += io::CodedOutputStream::VarintSize32(number); - - // message - size_t message_size = 0; - if (is_lazy) { - message_size = lazymessage_value->ByteSizeLong(); - } else { - message_size = message_value->ByteSizeLong(); - } - - our_size += io::CodedOutputStream::VarintSize32(message_size); - our_size += message_size; - - return our_size; -} + bool SkipField(uint32 tag, io::CodedInputStream* input) { + return field_skipper->SkipField(input, tag); + } -void ExtensionSet::SerializeMessageSetWithCachedSizes( - io::CodedOutputStream* output) const { - ForEach([output](int number, const Extension& ext) { - ext.SerializeMessageSetItemWithCachedSizes(number, output); - }); -} + ExtensionSet* me; + ExtensionFinder* extension_finder; + MessageSetFieldSkipper* field_skipper; + }; -size_t ExtensionSet::MessageSetByteSize() const { - size_t total_size = 0; - ForEach([&total_size](int number, const Extension& ext) { - total_size += ext.MessageSetItemByteSize(number); - }); - return total_size; + return ParseMessageSetItemImpl(input, + MSFull{this, extension_finder, field_skipper}); } } // namespace internal |