aboutsummaryrefslogtreecommitdiff
path: root/csharp/src
diff options
context:
space:
mode:
Diffstat (limited to 'csharp/src')
-rw-r--r--csharp/src/Google.Protobuf.Test/CodedInputStreamTest.cs87
-rw-r--r--csharp/src/Google.Protobuf/CodedInputStream.cs55
-rw-r--r--csharp/src/Google.Protobuf/Collections/MapField.cs5
-rw-r--r--csharp/src/Google.Protobuf/FieldCodec.cs7
-rw-r--r--csharp/src/Google.Protobuf/MessageExtensions.cs6
-rw-r--r--csharp/src/Google.Protobuf/WireFormat.cs10
6 files changed, 138 insertions, 32 deletions
diff --git a/csharp/src/Google.Protobuf.Test/CodedInputStreamTest.cs b/csharp/src/Google.Protobuf.Test/CodedInputStreamTest.cs
index c4c92efd..42c740ac 100644
--- a/csharp/src/Google.Protobuf.Test/CodedInputStreamTest.cs
+++ b/csharp/src/Google.Protobuf.Test/CodedInputStreamTest.cs
@@ -442,5 +442,92 @@ namespace Google.Protobuf
var input = new CodedInputStream(new byte[] { 0 });
Assert.Throws<InvalidProtocolBufferException>(() => input.ReadTag());
}
+
+ [Test]
+ public void SkipGroup()
+ {
+ // Create an output stream with a group in:
+ // Field 1: string "field 1"
+ // Field 2: group containing:
+ // Field 1: fixed int32 value 100
+ // Field 2: string "ignore me"
+ // Field 3: nested group containing
+ // Field 1: fixed int64 value 1000
+ // Field 3: string "field 3"
+ var stream = new MemoryStream();
+ var output = new CodedOutputStream(stream);
+ output.WriteTag(1, WireFormat.WireType.LengthDelimited);
+ output.WriteString("field 1");
+
+ // The outer group...
+ output.WriteTag(2, WireFormat.WireType.StartGroup);
+ output.WriteTag(1, WireFormat.WireType.Fixed32);
+ output.WriteFixed32(100);
+ output.WriteTag(2, WireFormat.WireType.LengthDelimited);
+ output.WriteString("ignore me");
+ // The nested group...
+ output.WriteTag(3, WireFormat.WireType.StartGroup);
+ output.WriteTag(1, WireFormat.WireType.Fixed64);
+ output.WriteFixed64(1000);
+ // Note: Not sure the field number is relevant for end group...
+ output.WriteTag(3, WireFormat.WireType.EndGroup);
+
+ // End the outer group
+ output.WriteTag(2, WireFormat.WireType.EndGroup);
+
+ output.WriteTag(3, WireFormat.WireType.LengthDelimited);
+ output.WriteString("field 3");
+ output.Flush();
+ stream.Position = 0;
+
+ // Now act like a generated client
+ var input = new CodedInputStream(stream);
+ Assert.AreEqual(WireFormat.MakeTag(1, WireFormat.WireType.LengthDelimited), input.ReadTag());
+ Assert.AreEqual("field 1", input.ReadString());
+ Assert.AreEqual(WireFormat.MakeTag(2, WireFormat.WireType.StartGroup), input.ReadTag());
+ input.SkipLastField(); // Should consume the whole group, including the nested one.
+ Assert.AreEqual(WireFormat.MakeTag(3, WireFormat.WireType.LengthDelimited), input.ReadTag());
+ Assert.AreEqual("field 3", input.ReadString());
+ }
+
+ [Test]
+ public void EndOfStreamReachedWhileSkippingGroup()
+ {
+ var stream = new MemoryStream();
+ var output = new CodedOutputStream(stream);
+ output.WriteTag(1, WireFormat.WireType.StartGroup);
+ output.WriteTag(2, WireFormat.WireType.StartGroup);
+ output.WriteTag(2, WireFormat.WireType.EndGroup);
+
+ output.Flush();
+ stream.Position = 0;
+
+ // Now act like a generated client
+ var input = new CodedInputStream(stream);
+ input.ReadTag();
+ Assert.Throws<InvalidProtocolBufferException>(() => input.SkipLastField());
+ }
+
+ [Test]
+ public void RecursionLimitAppliedWhileSkippingGroup()
+ {
+ var stream = new MemoryStream();
+ var output = new CodedOutputStream(stream);
+ for (int i = 0; i < CodedInputStream.DefaultRecursionLimit + 1; i++)
+ {
+ output.WriteTag(1, WireFormat.WireType.StartGroup);
+ }
+ for (int i = 0; i < CodedInputStream.DefaultRecursionLimit + 1; i++)
+ {
+ output.WriteTag(1, WireFormat.WireType.EndGroup);
+ }
+ output.Flush();
+ stream.Position = 0;
+
+ // Now act like a generated client
+ var input = new CodedInputStream(stream);
+ Assert.AreEqual(WireFormat.MakeTag(1, WireFormat.WireType.StartGroup), input.ReadTag());
+ Assert.Throws<InvalidProtocolBufferException>(() => input.SkipLastField());
+ }
}
} \ No newline at end of file
diff --git a/csharp/src/Google.Protobuf/CodedInputStream.cs b/csharp/src/Google.Protobuf/CodedInputStream.cs
index 0e2495f1..a37fefc1 100644
--- a/csharp/src/Google.Protobuf/CodedInputStream.cs
+++ b/csharp/src/Google.Protobuf/CodedInputStream.cs
@@ -236,17 +236,16 @@ namespace Google.Protobuf
#region Validation
/// <summary>
- /// Verifies that the last call to ReadTag() returned the given tag value.
- /// This is used to verify that a nested group ended with the correct
- /// end tag.
+ /// Verifies that the last call to ReadTag() returned tag 0 - in other words,
+ /// we've reached the end of the stream when we expected to.
/// </summary>
- /// <exception cref="InvalidProtocolBufferException">The last
+ /// <exception cref="InvalidProtocolBufferException">The
/// tag read was not the one specified</exception>
- internal void CheckLastTagWas(uint value)
+ internal void CheckReadEndOfStreamTag()
{
- if (lastTag != value)
+ if (lastTag != 0)
{
- throw InvalidProtocolBufferException.InvalidEndTag();
+ throw InvalidProtocolBufferException.MoreDataAvailable();
}
}
#endregion
@@ -275,6 +274,11 @@ namespace Google.Protobuf
/// <summary>
/// Reads a field tag, returning the tag of 0 for "end of stream".
/// </summary>
+ /// <remarks>
+ /// If this method returns 0, it doesn't necessarily mean the end of all
+ /// the data in this CodedInputStream; it may be the end of the logical stream
+ /// for an embedded message, for example.
+ /// </remarks>
/// <returns>The next field tag, or 0 for end of stream. (0 is never a valid tag.)</returns>
public uint ReadTag()
{
@@ -329,22 +333,24 @@ namespace Google.Protobuf
}
/// <summary>
- /// Consumes the data for the field with the tag we've just read.
+ /// Skips the data for the field with the tag we've just read.
/// This should be called directly after <see cref="ReadTag"/>, when
/// the caller wishes to skip an unknown field.
/// </summary>
- public void ConsumeLastField()
+ public void SkipLastField()
{
if (lastTag == 0)
{
- throw new InvalidOperationException("ConsumeLastField cannot be called at the end of a stream");
+ throw new InvalidOperationException("SkipLastField cannot be called at the end of a stream");
}
switch (WireFormat.GetTagWireType(lastTag))
{
case WireFormat.WireType.StartGroup:
+ ConsumeGroup();
+ break;
case WireFormat.WireType.EndGroup:
- // TODO: Work out how to skip them instead? See issue 688.
- throw new InvalidProtocolBufferException("Group tags not supported by proto3 C# implementation");
+ // Just ignore; there's no data following the tag.
+ break;
case WireFormat.WireType.Fixed32:
ReadFixed32();
break;
@@ -361,6 +367,29 @@ namespace Google.Protobuf
}
}
+ private void ConsumeGroup()
+ {
+ // Note: Currently we expect this to be the way that groups are read. We could put the recursion
+ // depth changes into the ReadTag method instead, potentially...
+ recursionDepth++;
+ if (recursionDepth >= recursionLimit)
+ {
+ throw InvalidProtocolBufferException.RecursionLimitExceeded();
+ }
+ uint tag;
+ do
+ {
+ tag = ReadTag();
+ if (tag == 0)
+ {
+ throw InvalidProtocolBufferException.TruncatedMessage();
+ }
+ // This recursion will allow us to handle nested groups.
+ SkipLastField();
+ } while (WireFormat.GetTagWireType(tag) != WireFormat.WireType.EndGroup);
+ recursionDepth--;
+ }
+
/// <summary>
/// Reads a double field from the stream.
/// </summary>
@@ -475,7 +504,7 @@ namespace Google.Protobuf
int oldLimit = PushLimit(length);
++recursionDepth;
builder.MergeFrom(this);
- CheckLastTagWas(0);
+ CheckReadEndOfStreamTag();
// Check that we've read exactly as much data as expected.
if (!ReachedLimit)
{
diff --git a/csharp/src/Google.Protobuf/Collections/MapField.cs b/csharp/src/Google.Protobuf/Collections/MapField.cs
index 5eb2c2fc..dc4b04cb 100644
--- a/csharp/src/Google.Protobuf/Collections/MapField.cs
+++ b/csharp/src/Google.Protobuf/Collections/MapField.cs
@@ -637,10 +637,9 @@ namespace Google.Protobuf.Collections
{
Value = codec.valueCodec.Read(input);
}
- else if (WireFormat.IsEndGroupTag(tag))
+ else
{
- // TODO(jonskeet): Do we need this? (Given that we don't support groups...)
- return;
+ input.SkipLastField();
}
}
}
diff --git a/csharp/src/Google.Protobuf/FieldCodec.cs b/csharp/src/Google.Protobuf/FieldCodec.cs
index 15d52c7d..20a1f438 100644
--- a/csharp/src/Google.Protobuf/FieldCodec.cs
+++ b/csharp/src/Google.Protobuf/FieldCodec.cs
@@ -304,12 +304,13 @@ namespace Google.Protobuf
{
value = codec.Read(input);
}
- if (WireFormat.IsEndGroupTag(tag))
+ else
{
- break;
+ input.SkipLastField();
}
+
}
- input.CheckLastTagWas(0);
+ input.CheckReadEndOfStreamTag();
input.PopLimit(oldLimit);
return value;
diff --git a/csharp/src/Google.Protobuf/MessageExtensions.cs b/csharp/src/Google.Protobuf/MessageExtensions.cs
index ee78dc8d..d2d057c0 100644
--- a/csharp/src/Google.Protobuf/MessageExtensions.cs
+++ b/csharp/src/Google.Protobuf/MessageExtensions.cs
@@ -50,7 +50,7 @@ namespace Google.Protobuf
Preconditions.CheckNotNull(data, "data");
CodedInputStream input = new CodedInputStream(data);
message.MergeFrom(input);
- input.CheckLastTagWas(0);
+ input.CheckReadEndOfStreamTag();
}
/// <summary>
@@ -64,7 +64,7 @@ namespace Google.Protobuf
Preconditions.CheckNotNull(data, "data");
CodedInputStream input = data.CreateCodedInput();
message.MergeFrom(input);
- input.CheckLastTagWas(0);
+ input.CheckReadEndOfStreamTag();
}
/// <summary>
@@ -78,7 +78,7 @@ namespace Google.Protobuf
Preconditions.CheckNotNull(input, "input");
CodedInputStream codedInput = new CodedInputStream(input);
message.MergeFrom(codedInput);
- codedInput.CheckLastTagWas(0);
+ codedInput.CheckReadEndOfStreamTag();
}
/// <summary>
diff --git a/csharp/src/Google.Protobuf/WireFormat.cs b/csharp/src/Google.Protobuf/WireFormat.cs
index bbd7e4f9..b0e4a41f 100644
--- a/csharp/src/Google.Protobuf/WireFormat.cs
+++ b/csharp/src/Google.Protobuf/WireFormat.cs
@@ -99,16 +99,6 @@ namespace Google.Protobuf
}
/// <summary>
- /// Determines whether the given tag is an end group tag.
- /// </summary>
- /// <param name="tag">The tag to check.</param>
- /// <returns><c>true</c> if the given tag is an end group tag; <c>false</c> otherwise.</returns>
- public static bool IsEndGroupTag(uint tag)
- {
- return (WireType) (tag & TagTypeMask) == WireType.EndGroup;
- }
-
- /// <summary>
/// Given a tag value, determines the field number (the upper 29 bits).
/// </summary>
public static int GetTagFieldNumber(uint tag)