aboutsummaryrefslogtreecommitdiff
path: root/objectivec/GPBCodedInputStream.m
diff options
context:
space:
mode:
Diffstat (limited to 'objectivec/GPBCodedInputStream.m')
-rw-r--r--objectivec/GPBCodedInputStream.m88
1 files changed, 30 insertions, 58 deletions
diff --git a/objectivec/GPBCodedInputStream.m b/objectivec/GPBCodedInputStream.m
index e8c8989c..dd05ddb4 100644
--- a/objectivec/GPBCodedInputStream.m
+++ b/objectivec/GPBCodedInputStream.m
@@ -45,7 +45,12 @@ NSString *const GPBCodedInputStreamUnderlyingErrorKey =
NSString *const GPBCodedInputStreamErrorDomain =
GPBNSStringifySymbol(GPBCodedInputStreamErrorDomain);
-static const NSUInteger kDefaultRecursionLimit = 64;
+// Matching:
+// https://github.com/google/protobuf/blob/master/java/core/src/main/java/com/google/protobuf/CodedInputStream.java#L62
+// private static final int DEFAULT_RECURSION_LIMIT = 100;
+// https://github.com/google/protobuf/blob/master/src/google/protobuf/io/coded_stream.cc#L86
+// int CodedInputStream::default_recursion_limit_ = 100;
+static const NSUInteger kDefaultRecursionLimit = 100;
static void RaiseException(NSInteger code, NSString *reason) {
NSDictionary *errorInfo = nil;
@@ -58,9 +63,15 @@ static void RaiseException(NSInteger code, NSString *reason) {
NSDictionary *exceptionInfo =
@{ GPBCodedInputStreamUnderlyingErrorKey: error };
- [[[NSException alloc] initWithName:GPBCodedInputStreamException
- reason:reason
- userInfo:exceptionInfo] raise];
+ [[NSException exceptionWithName:GPBCodedInputStreamException
+ reason:reason
+ userInfo:exceptionInfo] raise];
+}
+
+static void CheckRecursionLimit(GPBCodedInputStreamState *state) {
+ if (state->recursionDepth >= kDefaultRecursionLimit) {
+ RaiseException(GPBCodedInputStreamErrorRecursionDepthExceeded, nil);
+ }
}
static void CheckSize(GPBCodedInputStreamState *state, size_t size) {
@@ -94,47 +105,12 @@ static int64_t ReadRawLittleEndian64(GPBCodedInputStreamState *state) {
return value;
}
-static int32_t ReadRawVarint32(GPBCodedInputStreamState *state) {
- int8_t tmp = ReadRawByte(state);
- if (tmp >= 0) {
- return tmp;
- }
- int32_t result = tmp & 0x7f;
- if ((tmp = ReadRawByte(state)) >= 0) {
- result |= tmp << 7;
- } else {
- result |= (tmp & 0x7f) << 7;
- if ((tmp = ReadRawByte(state)) >= 0) {
- result |= tmp << 14;
- } else {
- result |= (tmp & 0x7f) << 14;
- if ((tmp = ReadRawByte(state)) >= 0) {
- result |= tmp << 21;
- } else {
- result |= (tmp & 0x7f) << 21;
- result |= (tmp = ReadRawByte(state)) << 28;
- if (tmp < 0) {
- // Discard upper 32 bits.
- for (int i = 0; i < 5; i++) {
- if (ReadRawByte(state) >= 0) {
- return result;
- }
- }
- RaiseException(GPBCodedInputStreamErrorInvalidVarInt,
- @"Invalid VarInt32");
- }
- }
- }
- }
- return result;
-}
-
static int64_t ReadRawVarint64(GPBCodedInputStreamState *state) {
int32_t shift = 0;
int64_t result = 0;
while (shift < 64) {
int8_t b = ReadRawByte(state);
- result |= (int64_t)(b & 0x7F) << shift;
+ result |= (int64_t)((uint64_t)(b & 0x7F) << shift);
if ((b & 0x80) == 0) {
return result;
}
@@ -144,6 +120,10 @@ static int64_t ReadRawVarint64(GPBCodedInputStreamState *state) {
return 0;
}
+static int32_t ReadRawVarint32(GPBCodedInputStreamState *state) {
+ return (int32_t)ReadRawVarint64(state);
+}
+
static void SkipRawData(GPBCodedInputStreamState *state, size_t size) {
CheckSize(state, size);
state->bufferPos += size;
@@ -225,16 +205,16 @@ int32_t GPBCodedInputStreamReadTag(GPBCodedInputStreamState *state) {
}
state->lastTag = ReadRawVarint32(state);
- if (state->lastTag == 0) {
- // If we actually read zero, that's not a valid tag.
- RaiseException(GPBCodedInputStreamErrorInvalidTag,
- @"A zero tag on the wire is invalid.");
- }
- // Tags have to include a valid wireformat, check that also.
+ // Tags have to include a valid wireformat.
if (!GPBWireFormatIsValidTag(state->lastTag)) {
RaiseException(GPBCodedInputStreamErrorInvalidTag,
@"Invalid wireformat in tag.");
}
+ // Zero is not a valid field number.
+ if (GPBWireFormatGetTagFieldNumber(state->lastTag) == 0) {
+ RaiseException(GPBCodedInputStreamErrorInvalidTag,
+ @"A zero field number on the wire is invalid.");
+ }
return state->lastTag;
}
@@ -447,9 +427,7 @@ void GPBCodedInputStreamCheckLastTagWas(GPBCodedInputStreamState *state,
- (void)readGroup:(int32_t)fieldNumber
message:(GPBMessage *)message
extensionRegistry:(GPBExtensionRegistry *)extensionRegistry {
- if (state_.recursionDepth >= kDefaultRecursionLimit) {
- RaiseException(GPBCodedInputStreamErrorRecursionDepthExceeded, nil);
- }
+ CheckRecursionLimit(&state_);
++state_.recursionDepth;
[message mergeFromCodedInputStream:self extensionRegistry:extensionRegistry];
GPBCodedInputStreamCheckLastTagWas(
@@ -459,9 +437,7 @@ void GPBCodedInputStreamCheckLastTagWas(GPBCodedInputStreamState *state,
- (void)readUnknownGroup:(int32_t)fieldNumber
message:(GPBUnknownFieldSet *)message {
- if (state_.recursionDepth >= kDefaultRecursionLimit) {
- RaiseException(GPBCodedInputStreamErrorRecursionDepthExceeded, nil);
- }
+ CheckRecursionLimit(&state_);
++state_.recursionDepth;
[message mergeFromCodedInputStream:self];
GPBCodedInputStreamCheckLastTagWas(
@@ -471,10 +447,8 @@ void GPBCodedInputStreamCheckLastTagWas(GPBCodedInputStreamState *state,
- (void)readMessage:(GPBMessage *)message
extensionRegistry:(GPBExtensionRegistry *)extensionRegistry {
+ CheckRecursionLimit(&state_);
int32_t length = ReadRawVarint32(&state_);
- if (state_.recursionDepth >= kDefaultRecursionLimit) {
- RaiseException(GPBCodedInputStreamErrorRecursionDepthExceeded, nil);
- }
size_t oldLimit = GPBCodedInputStreamPushLimit(&state_, length);
++state_.recursionDepth;
[message mergeFromCodedInputStream:self extensionRegistry:extensionRegistry];
@@ -487,10 +461,8 @@ void GPBCodedInputStreamCheckLastTagWas(GPBCodedInputStreamState *state,
extensionRegistry:(GPBExtensionRegistry *)extensionRegistry
field:(GPBFieldDescriptor *)field
parentMessage:(GPBMessage *)parentMessage {
+ CheckRecursionLimit(&state_);
int32_t length = ReadRawVarint32(&state_);
- if (state_.recursionDepth >= kDefaultRecursionLimit) {
- RaiseException(GPBCodedInputStreamErrorRecursionDepthExceeded, nil);
- }
size_t oldLimit = GPBCodedInputStreamPushLimit(&state_, length);
++state_.recursionDepth;
GPBDictionaryReadEntry(mapDictionary, self, extensionRegistry, field,