diff options
Diffstat (limited to 'objectivec/GPBCodedInputStream.m')
-rw-r--r-- | objectivec/GPBCodedInputStream.m | 80 |
1 files changed, 26 insertions, 54 deletions
diff --git a/objectivec/GPBCodedInputStream.m b/objectivec/GPBCodedInputStream.m index e8c8989c..0759640d 100644 --- a/objectivec/GPBCodedInputStream.m +++ b/objectivec/GPBCodedInputStream.m @@ -45,7 +45,12 @@ NSString *const GPBCodedInputStreamUnderlyingErrorKey = NSString *const GPBCodedInputStreamErrorDomain = GPBNSStringifySymbol(GPBCodedInputStreamErrorDomain); -static const NSUInteger kDefaultRecursionLimit = 64; +// Matching: +// https://github.com/google/protobuf/blob/master/java/core/src/main/java/com/google/protobuf/CodedInputStream.java#L62 +// private static final int DEFAULT_RECURSION_LIMIT = 100; +// https://github.com/google/protobuf/blob/master/src/google/protobuf/io/coded_stream.cc#L86 +// int CodedInputStream::default_recursion_limit_ = 100; +static const NSUInteger kDefaultRecursionLimit = 100; static void RaiseException(NSInteger code, NSString *reason) { NSDictionary *errorInfo = nil; @@ -63,6 +68,12 @@ static void RaiseException(NSInteger code, NSString *reason) { userInfo:exceptionInfo] raise]; } +static void CheckRecursionLimit(GPBCodedInputStreamState *state) { + if (state->recursionDepth >= kDefaultRecursionLimit) { + RaiseException(GPBCodedInputStreamErrorRecursionDepthExceeded, nil); + } +} + static void CheckSize(GPBCodedInputStreamState *state, size_t size) { size_t newSize = state->bufferPos + size; if (newSize > state->bufferSize) { @@ -94,41 +105,6 @@ static int64_t ReadRawLittleEndian64(GPBCodedInputStreamState *state) { return value; } -static int32_t ReadRawVarint32(GPBCodedInputStreamState *state) { - int8_t tmp = ReadRawByte(state); - if (tmp >= 0) { - return tmp; - } - int32_t result = tmp & 0x7f; - if ((tmp = ReadRawByte(state)) >= 0) { - result |= tmp << 7; - } else { - result |= (tmp & 0x7f) << 7; - if ((tmp = ReadRawByte(state)) >= 0) { - result |= tmp << 14; - } else { - result |= (tmp & 0x7f) << 14; - if ((tmp = ReadRawByte(state)) >= 0) { - result |= tmp << 21; - } else { - result |= (tmp & 0x7f) << 21; - result |= (tmp = ReadRawByte(state)) << 28; - if (tmp < 0) { - // Discard upper 32 bits. - for (int i = 0; i < 5; i++) { - if (ReadRawByte(state) >= 0) { - return result; - } - } - RaiseException(GPBCodedInputStreamErrorInvalidVarInt, - @"Invalid VarInt32"); - } - } - } - } - return result; -} - static int64_t ReadRawVarint64(GPBCodedInputStreamState *state) { int32_t shift = 0; int64_t result = 0; @@ -144,6 +120,10 @@ static int64_t ReadRawVarint64(GPBCodedInputStreamState *state) { return 0; } +static int32_t ReadRawVarint32(GPBCodedInputStreamState *state) { + return (int32_t)ReadRawVarint64(state); +} + static void SkipRawData(GPBCodedInputStreamState *state, size_t size) { CheckSize(state, size); state->bufferPos += size; @@ -225,16 +205,16 @@ int32_t GPBCodedInputStreamReadTag(GPBCodedInputStreamState *state) { } state->lastTag = ReadRawVarint32(state); - if (state->lastTag == 0) { - // If we actually read zero, that's not a valid tag. - RaiseException(GPBCodedInputStreamErrorInvalidTag, - @"A zero tag on the wire is invalid."); - } - // Tags have to include a valid wireformat, check that also. + // Tags have to include a valid wireformat. if (!GPBWireFormatIsValidTag(state->lastTag)) { RaiseException(GPBCodedInputStreamErrorInvalidTag, @"Invalid wireformat in tag."); } + // Zero is not a valid field number. + if (GPBWireFormatGetTagFieldNumber(state->lastTag) == 0) { + RaiseException(GPBCodedInputStreamErrorInvalidTag, + @"A zero field number on the wire is invalid."); + } return state->lastTag; } @@ -447,9 +427,7 @@ void GPBCodedInputStreamCheckLastTagWas(GPBCodedInputStreamState *state, - (void)readGroup:(int32_t)fieldNumber message:(GPBMessage *)message extensionRegistry:(GPBExtensionRegistry *)extensionRegistry { - if (state_.recursionDepth >= kDefaultRecursionLimit) { - RaiseException(GPBCodedInputStreamErrorRecursionDepthExceeded, nil); - } + CheckRecursionLimit(&state_); ++state_.recursionDepth; [message mergeFromCodedInputStream:self extensionRegistry:extensionRegistry]; GPBCodedInputStreamCheckLastTagWas( @@ -459,9 +437,7 @@ void GPBCodedInputStreamCheckLastTagWas(GPBCodedInputStreamState *state, - (void)readUnknownGroup:(int32_t)fieldNumber message:(GPBUnknownFieldSet *)message { - if (state_.recursionDepth >= kDefaultRecursionLimit) { - RaiseException(GPBCodedInputStreamErrorRecursionDepthExceeded, nil); - } + CheckRecursionLimit(&state_); ++state_.recursionDepth; [message mergeFromCodedInputStream:self]; GPBCodedInputStreamCheckLastTagWas( @@ -471,10 +447,8 @@ void GPBCodedInputStreamCheckLastTagWas(GPBCodedInputStreamState *state, - (void)readMessage:(GPBMessage *)message extensionRegistry:(GPBExtensionRegistry *)extensionRegistry { + CheckRecursionLimit(&state_); int32_t length = ReadRawVarint32(&state_); - if (state_.recursionDepth >= kDefaultRecursionLimit) { - RaiseException(GPBCodedInputStreamErrorRecursionDepthExceeded, nil); - } size_t oldLimit = GPBCodedInputStreamPushLimit(&state_, length); ++state_.recursionDepth; [message mergeFromCodedInputStream:self extensionRegistry:extensionRegistry]; @@ -487,10 +461,8 @@ void GPBCodedInputStreamCheckLastTagWas(GPBCodedInputStreamState *state, extensionRegistry:(GPBExtensionRegistry *)extensionRegistry field:(GPBFieldDescriptor *)field parentMessage:(GPBMessage *)parentMessage { + CheckRecursionLimit(&state_); int32_t length = ReadRawVarint32(&state_); - if (state_.recursionDepth >= kDefaultRecursionLimit) { - RaiseException(GPBCodedInputStreamErrorRecursionDepthExceeded, nil); - } size_t oldLimit = GPBCodedInputStreamPushLimit(&state_, length); ++state_.recursionDepth; GPBDictionaryReadEntry(mapDictionary, self, extensionRegistry, field, |