diff options
Diffstat (limited to 'src/google/protobuf/extension_set_heavy.cc')
-rw-r--r-- | src/google/protobuf/extension_set_heavy.cc | 162 |
1 files changed, 40 insertions, 122 deletions
diff --git a/src/google/protobuf/extension_set_heavy.cc b/src/google/protobuf/extension_set_heavy.cc index 7c93c61d..20d36ab7 100644 --- a/src/google/protobuf/extension_set_heavy.cc +++ b/src/google/protobuf/extension_set_heavy.cc @@ -45,6 +45,7 @@ #include <google/protobuf/repeated_field.h> #include <google/protobuf/unknown_field_set.h> #include <google/protobuf/wire_format.h> +#include <google/protobuf/wire_format_lite.h> #include <google/protobuf/wire_format_lite_inl.h> @@ -315,22 +316,23 @@ bool DescriptorPoolExtensionFinder::Find(int number, ExtensionInfo* output) { } #if GOOGLE_PROTOBUF_ENABLE_EXPERIMENTAL_PARSER -bool ExtensionSet::FindExtension(uint32 tag, const Message* containing_type, +bool ExtensionSet::FindExtension(int wire_type, uint32 field, + const Message* containing_type, const internal::ParseContext* ctx, - ExtensionInfo* extension, int* number, + ExtensionInfo* extension, bool* was_packed_on_wire) { if (ctx->extra_parse_data().pool == nullptr) { GeneratedExtensionFinder finder(containing_type); - if (!FindExtensionInfoFromTag(tag, &finder, number, extension, - was_packed_on_wire)) { + if (!FindExtensionInfoFromFieldNumber(wire_type, field, &finder, extension, + was_packed_on_wire)) { return false; } } else { DescriptorPoolExtensionFinder finder(ctx->extra_parse_data().pool, ctx->extra_parse_data().factory, containing_type->GetDescriptor()); - if (!FindExtensionInfoFromTag(tag, &finder, number, extension, - was_packed_on_wire)) { + if (!FindExtensionInfoFromFieldNumber(wire_type, field, &finder, extension, + was_packed_on_wire)) { return false; } } @@ -338,14 +340,14 @@ bool ExtensionSet::FindExtension(uint32 tag, const Message* containing_type, } std::pair<const char*, bool> ExtensionSet::ParseField( - uint32 tag, ParseClosure parent, const char* begin, const char* end, + uint64 tag, ParseClosure parent, const char* begin, const char* end, const Message* containing_type, internal::InternalMetadataWithArena* metadata, internal::ParseContext* ctx) { - int number; + int number = tag >> 3; bool was_packed_on_wire; ExtensionInfo extension; - if (!FindExtension(tag, containing_type, ctx, &extension, &number, + if (!FindExtension(tag & 7, number, containing_type, ctx, &extension, &was_packed_on_wire)) { return UnknownFieldParse(tag, parent, begin, end, metadata->mutable_unknown_fields(), ctx); @@ -400,7 +402,7 @@ std::pair<const char*, bool> ExtensionSet::ParseField( case WireFormatLite::TYPE_##UPPERCASE: { \ uint64 value; \ ptr = Varint::Parse64(ptr, &value); \ - if (ptr == nullptr) goto error; \ + GOOGLE_PROTOBUF_ASSERT_RETURN(ptr, std::make_pair(nullptr, true)); \ if (extension.is_repeated) { \ Add##CPP_CAMELCASE(number, WireFormatLite::TYPE_##UPPERCASE, \ extension.is_packed, value, extension.descriptor); \ @@ -419,6 +421,7 @@ std::pair<const char*, bool> ExtensionSet::ParseField( case WireFormatLite::TYPE_##UPPERCASE: { \ uint64 val; \ ptr = Varint::Parse64(ptr, &val); \ + GOOGLE_PROTOBUF_ASSERT_RETURN(ptr, std::make_pair(nullptr, true)); \ auto value = WireFormatLite::ZigZagDecode##SIZE(val); \ if (extension.is_repeated) { \ Add##CPP_CAMELCASE(number, WireFormatLite::TYPE_##UPPERCASE, \ @@ -458,7 +461,7 @@ std::pair<const char*, bool> ExtensionSet::ParseField( case WireFormatLite::TYPE_ENUM: { uint64 val; ptr = Varint::Parse64(ptr, &val); - if (ptr == nullptr) goto error; + GOOGLE_PROTOBUF_ASSERT_RETURN(ptr, std::make_pair(nullptr, true)); int value = val; if (!extension.enum_validity_check.func( @@ -494,9 +497,10 @@ std::pair<const char*, bool> ExtensionSet::ParseField( *extension.message_prototype, extension.descriptor); child = {value->_ParseFunc(), value}; - if (!ctx->PrepareGroup(tag, &depth)) goto error; + bool ok = ctx->PrepareGroup(tag, &depth); + GOOGLE_PROTOBUF_ASSERT_RETURN(ok, std::make_pair(nullptr, true)); ptr = child(ptr, end, ctx); - if (!ptr) goto error; + GOOGLE_PROTOBUF_ASSERT_RETURN(ptr, std::make_pair(nullptr, true)); if (ctx->GroupContinues(depth)) goto group_continues; break; } @@ -517,19 +521,15 @@ std::pair<const char*, bool> ExtensionSet::ParseField( return std::make_pair(ptr, false); -error: - return std::make_pair(nullptr, true); - length_delim: uint32 size; ptr = Varint::Parse32Inline(ptr, &size); - if (!ptr) goto error; + GOOGLE_PROTOBUF_ASSERT_RETURN(ptr, std::make_pair(nullptr, true)); if (size > end - ptr) goto len_delim_till_end; { auto newend = ptr + size; - if (!ctx->ParseExactRange(child, ptr, newend)) { - goto error; - } + bool ok = ctx->ParseExactRange(child, ptr, newend); + GOOGLE_PROTOBUF_ASSERT_RETURN(ok, std::make_pair(nullptr, true)); ptr = newend; } return std::make_pair(ptr, false); @@ -570,64 +570,60 @@ const char* ExtensionSet::ParseMessageSetItem( if (tag == WireFormatLite::kMessageSetTypeIdTag) { uint32 type_id; ptr = Varint::Parse32(ptr, &type_id); - if (!ptr) goto error; + GOOGLE_PROTOBUF_PARSER_ASSERT(ptr); if (ctx->extra_parse_data().payload.empty()) { tag = *ptr++; - if (tag == WireFormatLite::kMessageSetMessageTag) { - auto res = ParseField(type_id * 8 + 2, parent, ptr, end, - containing_type, metadata, ctx); - ptr = res.first; - if (res.second) break; - } else { - goto error; - } + GOOGLE_PROTOBUF_PARSER_ASSERT(tag == + WireFormatLite::kMessageSetMessageTag); + auto res = ParseField(static_cast<uint64>(type_id) * 8 + 2, parent, ptr, + end, containing_type, metadata, ctx); + ptr = res.first; + if (res.second) break; } else { ExtensionInfo extension; GeneratedExtensionFinder finder(containing_type); - int number; bool was_packed_on_wire; - if (!FindExtension(type_id * 8 + 2, containing_type, ctx, &extension, - &number, &was_packed_on_wire)) { + if (!FindExtension(2, type_id, containing_type, ctx, &extension, + &was_packed_on_wire)) { metadata->mutable_unknown_fields()->AddLengthDelimited( type_id, ctx->extra_parse_data().payload); continue; } MessageLite* value = extension.is_repeated - ? AddMessage(number, WireFormatLite::TYPE_MESSAGE, + ? AddMessage(type_id, WireFormatLite::TYPE_MESSAGE, *extension.message_prototype, extension.descriptor) - : MutableMessage(number, WireFormatLite::TYPE_MESSAGE, + : MutableMessage(type_id, WireFormatLite::TYPE_MESSAGE, *extension.message_prototype, extension.descriptor); ParseClosure parser = {value->_ParseFunc(), value}; StringPiece chunk(ctx->extra_parse_data().payload.data()); - if (!ctx->ParseExactRange(parser, chunk.begin(), chunk.end())) { - return nullptr; - } + bool ok = ctx->ParseExactRange(parser, chunk.begin(), chunk.end()); + GOOGLE_PROTOBUF_PARSER_ASSERT(ok); } } else if (tag == WireFormatLite::kMessageSetItemEndTag) { - if (!ctx->ValidEndGroup(tag)) goto error; + bool ok = ctx->ValidEndGroup(tag); + GOOGLE_PROTOBUF_PARSER_ASSERT(ok); break; } else if (tag == WireFormatLite::kMessageSetMessageTag) { uint32 size; ptr = Varint::Parse32Inline(ptr, &size); - if (!ptr) goto error; + GOOGLE_PROTOBUF_PARSER_ASSERT(ptr); ParseClosure child = {internal::StringParser, &ctx->extra_parse_data().payload}; if (size > end - ptr) { return ctx->StoreAndTailCall(ptr, end, parent, child, size); } else { auto newend = ptr + size; - if (!ctx->ParseExactRange(child, ptr, newend)) { - goto error; - } + bool ok = ctx->ParseExactRange(child, ptr, newend); + GOOGLE_PROTOBUF_PARSER_ASSERT(ok); ptr = newend; } } else { ptr--; ptr = Varint::Parse32(ptr, &tag); - if (ptr == nullptr) goto error; + GOOGLE_PROTOBUF_PARSER_ASSERT(ptr); auto res = ParseField(tag, parent, ptr, end, containing_type, metadata, ctx); ptr = res.first; @@ -635,8 +631,6 @@ const char* ExtensionSet::ParseMessageSetItem( } } return ptr; -error: - return nullptr; } #endif // GOOGLE_PROTOBUF_ENABLE_EXPERIMENTAL_PARSER @@ -744,7 +738,7 @@ uint8* ExtensionSet::SerializeMessageSetWithCachedSizesToArray( uint8* ExtensionSet::InternalSerializeWithCachedSizesToArray( int start_field_number, int end_field_number, bool deterministic, uint8* target) const { - if (GOOGLE_PREDICT_FALSE(is_large())) { + if (PROTOBUF_PREDICT_FALSE(is_large())) { const auto& end = map_.large->end(); for (auto it = map_.large->lower_bound(start_field_number); it != end && it->first < end_field_number; ++it) { @@ -983,82 +977,6 @@ bool ExtensionSet::ParseMessageSetItem(io::CodedInputStream* input, MSFull{this, extension_finder, field_skipper}); } -void ExtensionSet::Extension::SerializeMessageSetItemWithCachedSizes( - int number, - io::CodedOutputStream* output) const { - if (type != WireFormatLite::TYPE_MESSAGE || is_repeated) { - // Not a valid MessageSet extension, but serialize it the normal way. - SerializeFieldWithCachedSizes(number, output); - return; - } - - if (is_cleared) return; - - // Start group. - output->WriteTag(WireFormatLite::kMessageSetItemStartTag); - - // Write type ID. - WireFormatLite::WriteUInt32(WireFormatLite::kMessageSetTypeIdNumber, - number, - output); - // Write message. - if (is_lazy) { - lazymessage_value->WriteMessage( - WireFormatLite::kMessageSetMessageNumber, output); - } else { - WireFormatLite::WriteMessageMaybeToArray( - WireFormatLite::kMessageSetMessageNumber, - *message_value, - output); - } - - // End group. - output->WriteTag(WireFormatLite::kMessageSetItemEndTag); -} - -size_t ExtensionSet::Extension::MessageSetItemByteSize(int number) const { - if (type != WireFormatLite::TYPE_MESSAGE || is_repeated) { - // Not a valid MessageSet extension, but compute the byte size for it the - // normal way. - return ByteSize(number); - } - - if (is_cleared) return 0; - - size_t our_size = WireFormatLite::kMessageSetItemTagsSize; - - // type_id - our_size += io::CodedOutputStream::VarintSize32(number); - - // message - size_t message_size = 0; - if (is_lazy) { - message_size = lazymessage_value->ByteSizeLong(); - } else { - message_size = message_value->ByteSizeLong(); - } - - our_size += io::CodedOutputStream::VarintSize32(message_size); - our_size += message_size; - - return our_size; -} - -void ExtensionSet::SerializeMessageSetWithCachedSizes( - io::CodedOutputStream* output) const { - ForEach([output](int number, const Extension& ext) { - ext.SerializeMessageSetItemWithCachedSizes(number, output); - }); -} - -size_t ExtensionSet::MessageSetByteSize() const { - size_t total_size = 0; - ForEach([&total_size](int number, const Extension& ext) { - total_size += ext.MessageSetItemByteSize(number); - }); - return total_size; -} - } // namespace internal } // namespace protobuf } // namespace google |