diff options
Diffstat (limited to 'src/google/protobuf/extension_set.cc')
-rw-r--r-- | src/google/protobuf/extension_set.cc | 360 |
1 files changed, 325 insertions, 35 deletions
diff --git a/src/google/protobuf/extension_set.cc b/src/google/protobuf/extension_set.cc index cb205c4f..cb40ab71 100644 --- a/src/google/protobuf/extension_set.cc +++ b/src/google/protobuf/extension_set.cc @@ -32,16 +32,25 @@ // Based on original Protocol Buffers design by // Sanjay Ghemawat, Jeff Dean, and others. -#include <google/protobuf/stubs/hash.h> +#include <google/protobuf/extension_set.h> + #include <tuple> +#include <unordered_map> #include <utility> #include <google/protobuf/stubs/common.h> -#include <google/protobuf/extension_set.h> -#include <google/protobuf/message_lite.h> #include <google/protobuf/io/coded_stream.h> -#include <google/protobuf/wire_format_lite_inl.h> +#include <google/protobuf/io/zero_copy_stream_impl_lite.h> +#include <google/protobuf/message_lite.h> +#include <google/protobuf/metadata_lite.h> #include <google/protobuf/repeated_field.h> +#include <google/protobuf/wire_format_lite_inl.h> #include <google/protobuf/stubs/map_util.h> +#include <google/protobuf/stubs/hash.h> + +#include <google/protobuf/port_def.inc> +#if GOOGLE_PROTOBUF_ENABLE_EXPERIMENTAL_PARSER +#include <google/protobuf/parse_context.h> +#endif namespace google { namespace protobuf { @@ -77,8 +86,16 @@ inline bool is_packable(WireFormatLite::WireType type) { } // Registry stuff. -typedef hash_map<std::pair<const MessageLite*, int>, - ExtensionInfo> ExtensionRegistry; +struct ExtensionHasher { + std::size_t operator()(const std::pair<const MessageLite*, int>& p) const { + return std::hash<const MessageLite*>{}(p.first) ^ + std::hash<int>{}(p.second); + } +}; + +typedef std::unordered_map<std::pair<const MessageLite*, int>, ExtensionInfo, + ExtensionHasher> + ExtensionRegistry; static const ExtensionRegistry* global_registry = nullptr; @@ -89,7 +106,7 @@ void Register(const MessageLite* containing_type, static auto local_static_registry = OnShutdownDelete(new ExtensionRegistry); global_registry = local_static_registry; if (!InsertIfNotPresent(local_static_registry, - std::make_pair(containing_type, number), info)) { + std::make_pair(containing_type, number), info)) { GOOGLE_LOG(FATAL) << "Multiple extension registrations for type \"" << containing_type->GetTypeName() << "\", field number " << number << "."; @@ -99,8 +116,9 @@ void Register(const MessageLite* containing_type, const ExtensionInfo* FindRegisteredExtension( const MessageLite* containing_type, int number) { return global_registry == nullptr - ? nullptr - : FindOrNull(*global_registry, std::make_pair(containing_type, number)); + ? nullptr + : FindOrNull(*global_registry, + std::make_pair(containing_type, number)); } } // namespace @@ -168,21 +186,21 @@ void ExtensionSet::RegisterMessageExtension(const MessageLite* containing_type, // =================================================================== // Constructors and basic methods. -ExtensionSet::ExtensionSet(::google::protobuf::Arena* arena) +ExtensionSet::ExtensionSet(Arena* arena) : arena_(arena), flat_capacity_(0), flat_size_(0), - map_{flat_capacity_ == 0 ? NULL - : ::google::protobuf::Arena::CreateArray<KeyValue>( - arena_, flat_capacity_)} {} + map_{flat_capacity_ == 0 + ? NULL + : Arena::CreateArray<KeyValue>(arena_, flat_capacity_)} {} ExtensionSet::ExtensionSet() : arena_(NULL), flat_capacity_(0), flat_size_(0), - map_{flat_capacity_ == 0 ? NULL - : ::google::protobuf::Arena::CreateArray<KeyValue>( - arena_, flat_capacity_)} {} + map_{flat_capacity_ == 0 + ? NULL + : Arena::CreateArray<KeyValue>(arena_, flat_capacity_)} {} ExtensionSet::~ExtensionSet() { // Deletes all allocated extensions. @@ -191,11 +209,27 @@ ExtensionSet::~ExtensionSet() { if (GOOGLE_PREDICT_FALSE(is_large())) { delete map_.large; } else { - delete[] map_.flat; + DeleteFlatMap(map_.flat, flat_capacity_); } } } +void ExtensionSet::DeleteFlatMap( + const ExtensionSet::KeyValue* flat, uint16 flat_capacity) { +#ifdef __cpp_sized_deallocation + // Arena::CreateArray already requires a trivially destructible type, but + // ensure this constraint is not violated in the future. + static_assert(std::is_trivially_destructible<KeyValue>::value, + "CreateArray requires a trivially destructible type"); + // A const-cast is needed, but this is safe as we are about to deallocate the + // array. + ::operator delete[]( + const_cast<ExtensionSet::KeyValue*>(flat), sizeof(*flat) * flat_capacity); +#else // !__cpp_sized_deallocation + delete[] flat; +#endif // !__cpp_sized_deallocation +} + // Defined in extension_set_heavy.cc. // void ExtensionSet::AppendToList(const Descriptor* containing_type, // const DescriptorPool* pool, @@ -594,7 +628,7 @@ void ExtensionSet::SetAllocatedMessage(int number, FieldType type, ClearExtension(number); return; } - ::google::protobuf::Arena* message_arena = message->GetArena(); + Arena* message_arena = message->GetArena(); Extension* extension; if (MaybeNewExtension(number, descriptor, &extension)) { extension->type = type; @@ -746,10 +780,9 @@ MessageLite* ExtensionSet::AddMessage(int number, FieldType type, // RepeatedPtrField<MessageLite> does not know how to Add() since it cannot // allocate an abstract object, so we have to be tricky. - MessageLite* result = - reinterpret_cast<::google::protobuf::internal::RepeatedPtrFieldBase*>( - extension->repeated_message_value) - ->AddFromCleared<GenericTypeHandler<MessageLite> >(); + MessageLite* result = reinterpret_cast<internal::RepeatedPtrFieldBase*>( + extension->repeated_message_value) + ->AddFromCleared<GenericTypeHandler<MessageLite>>(); if (result == NULL) { result = prototype.New(arena_); extension->repeated_message_value->AddAllocated(result); @@ -945,9 +978,9 @@ void ExtensionSet::InternalExtensionMergeFrom( for (int i = 0; i < other_repeated_message->size(); i++) { const MessageLite& other_message = other_repeated_message->Get(i); MessageLite* target = - reinterpret_cast<::google::protobuf::internal::RepeatedPtrFieldBase*>( + reinterpret_cast<internal::RepeatedPtrFieldBase*>( extension->repeated_message_value) - ->AddFromCleared<GenericTypeHandler<MessageLite> >(); + ->AddFromCleared<GenericTypeHandler<MessageLite>>(); if (target == NULL) { target = other_message.New(arena_); extension->repeated_message_value->AddAllocated(target); @@ -1167,6 +1200,214 @@ bool ExtensionSet::ParseField(uint32 tag, io::CodedInputStream* input, } } +#if GOOGLE_PROTOBUF_ENABLE_EXPERIMENTAL_PARSER +std::pair<const char*, bool> ExtensionSet::ParseField( + uint32 tag, ParseClosure parent, const char* begin, const char* end, + const MessageLite* containing_type, + internal::InternalMetadataWithArenaLite* metadata, + internal::ParseContext* ctx) { + GeneratedExtensionFinder finder(containing_type); + int number; + bool was_packed_on_wire; + ExtensionInfo extension; + if (!FindExtensionInfoFromTag(tag, &finder, &number, &extension, + &was_packed_on_wire)) { + return UnknownFieldParse(tag, parent, begin, end, + metadata->mutable_unknown_fields(), ctx); + } + auto ptr = begin; + ParseClosure child; + int depth; + if (was_packed_on_wire) { + switch (extension.type) { +#define HANDLE_TYPE(UPPERCASE, CPP_CAMELCASE) \ + case WireFormatLite::TYPE_##UPPERCASE: \ + child = { \ + internal::Packed##CPP_CAMELCASE##Parser, \ + MutableRawRepeatedField(number, extension.type, extension.is_packed, \ + extension.descriptor)}; \ + goto length_delim + HANDLE_TYPE(INT32, Int32); + HANDLE_TYPE(INT64, Int64); + HANDLE_TYPE(UINT32, UInt32); + HANDLE_TYPE(UINT64, UInt64); + HANDLE_TYPE(SINT32, SInt32); + HANDLE_TYPE(SINT64, SInt64); + HANDLE_TYPE(FIXED32, Fixed32); + HANDLE_TYPE(FIXED64, Fixed64); + HANDLE_TYPE(SFIXED32, SFixed32); + HANDLE_TYPE(SFIXED64, SFixed64); + HANDLE_TYPE(FLOAT, Float); + HANDLE_TYPE(DOUBLE, Double); + HANDLE_TYPE(BOOL, Bool); +#undef HANDLE_TYPE + + case WireFormatLite::TYPE_ENUM: + ctx->extra_parse_data().SetEnumValidatorArg( + extension.enum_validity_check.func, + extension.enum_validity_check.arg, + metadata->mutable_unknown_fields(), tag >> 3); + child = { + internal::PackedValidEnumParserLiteArg, + MutableRawRepeatedField(number, extension.type, extension.is_packed, + extension.descriptor)}; + goto length_delim; + case WireFormatLite::TYPE_STRING: + case WireFormatLite::TYPE_BYTES: + case WireFormatLite::TYPE_GROUP: + case WireFormatLite::TYPE_MESSAGE: + GOOGLE_LOG(FATAL) << "Non-primitive types can't be packed."; + break; + } + } else { + switch (extension.type) { +#define HANDLE_VARINT_TYPE(UPPERCASE, CPP_CAMELCASE) \ + case WireFormatLite::TYPE_##UPPERCASE: { \ + uint64 value; \ + ptr = Varint::Parse64(ptr, &value); \ + if (ptr == nullptr) goto error; \ + if (extension.is_repeated) { \ + Add##CPP_CAMELCASE(number, WireFormatLite::TYPE_##UPPERCASE, \ + extension.is_packed, value, extension.descriptor); \ + } else { \ + Set##CPP_CAMELCASE(number, WireFormatLite::TYPE_##UPPERCASE, value, \ + extension.descriptor); \ + } \ + } break + + HANDLE_VARINT_TYPE(INT32, Int32); + HANDLE_VARINT_TYPE(INT64, Int64); + HANDLE_VARINT_TYPE(UINT32, UInt32); + HANDLE_VARINT_TYPE(UINT64, UInt64); +#undef HANDLE_VARINT_TYPE +#define HANDLE_SVARINT_TYPE(UPPERCASE, CPP_CAMELCASE, SIZE) \ + case WireFormatLite::TYPE_##UPPERCASE: { \ + uint64 val; \ + ptr = Varint::Parse64(ptr, &val); \ + auto value = WireFormatLite::ZigZagDecode##SIZE(val); \ + if (extension.is_repeated) { \ + Add##CPP_CAMELCASE(number, WireFormatLite::TYPE_##UPPERCASE, \ + extension.is_packed, value, extension.descriptor); \ + } else { \ + Set##CPP_CAMELCASE(number, WireFormatLite::TYPE_##UPPERCASE, value, \ + extension.descriptor); \ + } \ + } break + + HANDLE_SVARINT_TYPE(SINT32, Int32, 32); + HANDLE_SVARINT_TYPE(SINT64, Int64, 64); +#undef HANDLE_SVARINT_TYPE +#define HANDLE_FIXED_TYPE(UPPERCASE, CPP_CAMELCASE, CPPTYPE) \ + case WireFormatLite::TYPE_##UPPERCASE: { \ + CPPTYPE value; \ + std::memcpy(&value, ptr, sizeof(CPPTYPE)); \ + ptr += sizeof(CPPTYPE); \ + if (extension.is_repeated) { \ + Add##CPP_CAMELCASE(number, WireFormatLite::TYPE_##UPPERCASE, \ + extension.is_packed, value, extension.descriptor); \ + } else { \ + Set##CPP_CAMELCASE(number, WireFormatLite::TYPE_##UPPERCASE, value, \ + extension.descriptor); \ + } \ + } break + + HANDLE_FIXED_TYPE(FIXED32, UInt32, uint32); + HANDLE_FIXED_TYPE(FIXED64, UInt64, uint64); + HANDLE_FIXED_TYPE(SFIXED32, Int32, int32); + HANDLE_FIXED_TYPE(SFIXED64, Int64, int64); + HANDLE_FIXED_TYPE(FLOAT, Float, float); + HANDLE_FIXED_TYPE(DOUBLE, Double, double); + HANDLE_FIXED_TYPE(BOOL, Bool, bool); +#undef HANDLE_FIXED_TYPE + + case WireFormatLite::TYPE_ENUM: { + uint64 val; + ptr = Varint::Parse64(ptr, &val); + if (ptr == nullptr) goto error; + int value = val; + + if (!extension.enum_validity_check.func( + extension.enum_validity_check.arg, value)) { + WriteVarint(number, val, metadata->mutable_unknown_fields()); + } else if (extension.is_repeated) { + AddEnum(number, WireFormatLite::TYPE_ENUM, extension.is_packed, value, + extension.descriptor); + } else { + SetEnum(number, WireFormatLite::TYPE_ENUM, value, + extension.descriptor); + } + break; + } + + case WireFormatLite::TYPE_BYTES: + case WireFormatLite::TYPE_STRING: { + string* value = extension.is_repeated + ? AddString(number, WireFormatLite::TYPE_STRING, + extension.descriptor) + : MutableString(number, WireFormatLite::TYPE_STRING, + extension.descriptor); + child = {StringParser, value}; + goto length_delim; + } + + case WireFormatLite::TYPE_GROUP: { + MessageLite* value = + extension.is_repeated + ? AddMessage(number, WireFormatLite::TYPE_GROUP, + *extension.message_prototype, extension.descriptor) + : MutableMessage(number, WireFormatLite::TYPE_GROUP, + *extension.message_prototype, + extension.descriptor); + child = {value->_ParseFunc(), value}; + if (!ctx->PrepareGroup(tag, &depth)) goto error; + ptr = child(ptr, end, ctx); + if (!ptr) goto error; + if (ctx->GroupContinues(depth)) goto group_continues; + break; + } + + case WireFormatLite::TYPE_MESSAGE: { + MessageLite* value = + extension.is_repeated + ? AddMessage(number, WireFormatLite::TYPE_MESSAGE, + *extension.message_prototype, extension.descriptor) + : MutableMessage(number, WireFormatLite::TYPE_MESSAGE, + *extension.message_prototype, + extension.descriptor); + child = {value->_ParseFunc(), value}; + goto length_delim; + } + } + } + + return std::make_pair(ptr, false); + +error: + return std::make_pair(nullptr, true); + +length_delim: + uint32 size; + ptr = Varint::Parse32Inline(ptr, &size); + if (!ptr) goto error; + if (size > end - ptr) goto len_delim_till_end; + { + auto newend = ptr + size; + if (!ctx->ParseExactRange(child, ptr, newend)) { + goto error; + } + ptr = newend; + } + return std::make_pair(ptr, false); +len_delim_till_end: + return std::make_pair(ctx->StoreAndTailCall(ptr, end, parent, child, size), + true); + +group_continues: + ctx->StoreGroup(parent, child, depth); + return std::make_pair(ptr, true); +} +#endif + bool ExtensionSet::ParseFieldWithExtensionInfo( int number, bool was_packed_on_wire, const ExtensionInfo& extension, io::CodedInputStream* input, @@ -1342,15 +1583,60 @@ bool ExtensionSet::ParseField(uint32 tag, io::CodedInputStream* input, return ParseField(tag, input, &finder, &skipper); } -// Defined in extension_set_heavy.cc. -// bool ExtensionSet::ParseField(uint32 tag, io::CodedInputStream* input, -// const MessageLite* containing_type, -// UnknownFieldSet* unknown_fields) +bool ExtensionSet::ParseMessageSetLite(io::CodedInputStream* input, + ExtensionFinder* extension_finder, + FieldSkipper* field_skipper) { + while (true) { + const uint32 tag = input->ReadTag(); + switch (tag) { + case 0: + return true; + case WireFormatLite::kMessageSetItemStartTag: + if (!ParseMessageSetItemLite(input, extension_finder, field_skipper)) { + return false; + } + break; + default: + if (!ParseField(tag, input, extension_finder, field_skipper)) { + return false; + } + break; + } + } +} -// Defined in extension_set_heavy.cc. -// bool ExtensionSet::ParseMessageSet(io::CodedInputStream* input, -// const MessageLite* containing_type, -// UnknownFieldSet* unknown_fields); +bool ExtensionSet::ParseMessageSetItemLite(io::CodedInputStream* input, + ExtensionFinder* extension_finder, + FieldSkipper* field_skipper) { + struct MSLite { + bool ParseField(int type_id, io::CodedInputStream* input) { + return me->ParseField( + WireFormatLite::WIRETYPE_LENGTH_DELIMITED + 8 * type_id, input, + extension_finder, field_skipper); + } + + bool SkipField(uint32 tag, io::CodedInputStream* input) { + return field_skipper->SkipField(input, tag); + } + + ExtensionSet* me; + ExtensionFinder* extension_finder; + FieldSkipper* field_skipper; + }; + + return ParseMessageSetItemImpl(input, + MSLite{this, extension_finder, field_skipper}); +} + +bool ExtensionSet::ParseMessageSet(io::CodedInputStream* input, + const MessageLite* containing_type, + string* unknown_fields) { + io::StringOutputStream zcis(unknown_fields); + io::CodedOutputStream output(&zcis); + CodedOutputStreamFieldSkipper skipper(&output); + GeneratedExtensionFinder finder(containing_type); + return ParseMessageSetLite(input, &finder, &skipper); +} void ExtensionSet::SerializeWithCachedSizes( int start_field_number, int end_field_number, @@ -1859,6 +2145,8 @@ void ExtensionSet::GrowCapacity(size_t minimum_new_capacity) { return; } + const auto old_flat_capacity = flat_capacity_; + do { flat_capacity_ = flat_capacity_ == 0 ? 1 : flat_capacity_ * 4; } while (flat_capacity_ < minimum_new_capacity); @@ -1867,17 +2155,19 @@ void ExtensionSet::GrowCapacity(size_t minimum_new_capacity) { const KeyValue* end = flat_end(); if (flat_capacity_ > kMaximumFlatCapacity) { // Switch to LargeMap - map_.large = ::google::protobuf::Arena::Create<LargeMap>(arena_); + map_.large = Arena::Create<LargeMap>(arena_); LargeMap::iterator hint = map_.large->begin(); for (const KeyValue* it = begin; it != end; ++it) { hint = map_.large->insert(hint, {it->first, it->second}); } flat_size_ = 0; } else { - map_.flat = ::google::protobuf::Arena::CreateArray<KeyValue>(arena_, flat_capacity_); + map_.flat = Arena::CreateArray<KeyValue>(arena_, flat_capacity_); std::copy(begin, end, map_.flat); } - if (arena_ == NULL) delete[] begin; + if (arena_ == nullptr) { + DeleteFlatMap(begin, old_flat_capacity); + } } // static |