about summary refs log tree commit diff
diff options
context:
space:
mode:
author Jon Skeet <jonskeet@google.com> 2015-08-06 11:40:32 +0100
committer Jon Skeet <jonskeet@google.com> 2015-08-06 11:40:32 +0100
commit e7f88ff1294ada0fca19334ed2c844cdb98ea2f6 (patch)
tree 97ab85611ecdc29c56afe217893bafa1d520fc27
parent ad8a889d1e1e2b0efd5b7579aa57ea5326cda6da (diff)
download protobuf-e7f88ff1294ada0fca19334ed2c844cdb98ea2f6.tar.gz
protobuf-e7f88ff1294ada0fca19334ed2c844cdb98ea2f6.tar.bz2
protobuf-e7f88ff1294ada0fca19334ed2c844cdb98ea2f6.zip
Skip groups properly.

Now the generated code doesn't need to check for end group tags, as it will skip whole groups at a time.
Currently it will ignore extraneous end group tags, which may or may not be a good thing.
Renamed ConsumeLastField to SkipLastField as it felt more natural.
Removed WireFormat.IsEndGroupTag as it's no longer useful.

This mostly fixes issue 688.

(Generated code changes coming in next commit.)

-rw-r--r-- csharp/src/Google.Protobuf.Test/CodedInputStreamTest.cs 87
-rw-r--r-- csharp/src/Google.Protobuf/CodedInputStream.cs 55
-rw-r--r-- csharp/src/Google.Protobuf/Collections/MapField.cs 5
-rw-r--r-- csharp/src/Google.Protobuf/FieldCodec.cs 7
-rw-r--r-- csharp/src/Google.Protobuf/MessageExtensions.cs 6
-rw-r--r-- csharp/src/Google.Protobuf/WireFormat.cs 10
-rw-r--r-- src/google/protobuf/compiler/csharp/csharp_message.cc 5
7 files changed, 139 insertions, 36 deletions
diff --git a/csharp/src/Google.Protobuf.Test/CodedInputStreamTest.cs b/csharp/src/Google.Protobuf.Test/CodedInputStreamTest.cs
index c4c92efd..42c740ac 100644
--- a/csharp/src/Google.Protobuf.Test/CodedInputStreamTest.cs
+++ b/csharp/src/Google.Protobuf.Test/CodedInputStreamTest.cs
@@ -442,5 +442,92 @@ namespace Google.Protobuf
var input = new CodedInputStream(new byte[] { 0 });
Assert.Throws<InvalidProtocolBufferException>(() => input.ReadTag());
}
+
+ [Test]
+ public void SkipGroup()
+ {
+ // Create an output stream with a group in:
+ // Field 1: string "field 1"
+ // Field 2: group containing:
+ // Field 1: fixed int32 value 100
+ // Field 2: string "ignore me"
+ // Field 3: nested group containing
+ // Field 1: fixed int64 value 1000
+ // Field 3: string "field 3"
+ var stream = new MemoryStream();
+ var output = new CodedOutputStream(stream);
+ output.WriteTag(1, WireFormat.WireType.LengthDelimited);
+ output.WriteString("field 1");
+
+ // The outer group...
+ output.WriteTag(2, WireFormat.WireType.StartGroup);
+ output.WriteTag(1, WireFormat.WireType.Fixed32);
+ output.WriteFixed32(100);
+ output.WriteTag(2, WireFormat.WireType.LengthDelimited);
+ output.WriteString("ignore me");
+ // The nested group...
+ output.WriteTag(3, WireFormat.WireType.StartGroup);
+ output.WriteTag(1, WireFormat.WireType.Fixed64);
+ output.WriteFixed64(1000);
+ // Note: Not sure the field number is relevant for end group...
+ output.WriteTag(3, WireFormat.WireType.EndGroup);
+
+ // End the outer group
+ output.WriteTag(2, WireFormat.WireType.EndGroup);
+
+ output.WriteTag(3, WireFormat.WireType.LengthDelimited);
+ output.WriteString("field 3");
+ output.Flush();
+ stream.Position = 0;
+
+ // Now act like a generated client
+ var input = new CodedInputStream(stream);
+ Assert.AreEqual(WireFormat.MakeTag(1, WireFormat.WireType.LengthDelimited), input.ReadTag());
+ Assert.AreEqual("field 1", input.ReadString());
+ Assert.AreEqual(WireFormat.MakeTag(2, WireFormat.WireType.StartGroup), input.ReadTag());
+ input.SkipLastField(); // Should consume the whole group, including the nested one.
+ Assert.AreEqual(WireFormat.MakeTag(3, WireFormat.WireType.LengthDelimited), input.ReadTag());
+ Assert.AreEqual("field 3", input.ReadString());
+ }
+
+ [Test]
+ public void EndOfStreamReachedWhileSkippingGroup()
+ {
+ var stream = new MemoryStream();
+ var output = new CodedOutputStream(stream);
+ output.WriteTag(1, WireFormat.WireType.StartGroup);
+ output.WriteTag(2, WireFormat.WireType.StartGroup);
+ output.WriteTag(2, WireFormat.WireType.EndGroup);
+
+ output.Flush();
+ stream.Position = 0;
+
+ // Now act like a generated client
+ var input = new CodedInputStream(stream);
+ input.ReadTag();
+ Assert.Throws<InvalidProtocolBufferException>(() => input.SkipLastField());
+ }
+
+ [Test]
+ public void RecursionLimitAppliedWhileSkippingGroup()
+ {
+ var stream = new MemoryStream();
+ var output = new CodedOutputStream(stream);
+ for (int i = 0; i < CodedInputStream.DefaultRecursionLimit + 1; i++)
+ {
+ output.WriteTag(1, WireFormat.WireType.StartGroup);
+ }
+ for (int i = 0; i < CodedInputStream.DefaultRecursionLimit + 1; i++)
+ {
+ output.WriteTag(1, WireFormat.WireType.EndGroup);
+ }
+ output.Flush();
+ stream.Position = 0;
+
+ // Now act like a generated client
+ var input = new CodedInputStream(stream);
+ Assert.AreEqual(WireFormat.MakeTag(1, WireFormat.WireType.StartGroup), input.ReadTag());
+ Assert.Throws<InvalidProtocolBufferException>(() => input.SkipLastField());
+ }
}
}
\ No newline at end of file
diff --git a/csharp/src/Google.Protobuf/CodedInputStream.cs b/csharp/src/Google.Protobuf/CodedInputStream.cs
index 0e2495f1..a37fefc1 100644
--- a/csharp/src/Google.Protobuf/CodedInputStream.cs
+++ b/csharp/src/Google.Protobuf/CodedInputStream.cs
@@ -236,17 +236,16 @@ namespace Google.Protobuf
#region Validation
/// <summary>
- /// Verifies that the last call to ReadTag() returned the given tag value.
- /// This is used to verify that a nested group ended with the correct
- /// end tag.
+ /// Verifies that the last call to ReadTag() returned tag 0 - in other words,
+ /// we've reached the end of the stream when we expected to.
/// </summary>
- /// <exception cref="InvalidProtocolBufferException">The last
+ /// <exception cref="InvalidProtocolBufferException">The
/// tag read was not the one specified</exception>
- internal void CheckLastTagWas(uint value)
+ internal void CheckReadEndOfStreamTag()
{
- if (lastTag != value)
+ if (lastTag != 0)
{
- throw InvalidProtocolBufferException.InvalidEndTag();
+ throw InvalidProtocolBufferException.MoreDataAvailable();
}
}
#endregion
@@ -275,6 +274,11 @@ namespace Google.Protobuf
/// <summary>
/// Reads a field tag, returning the tag of 0 for "end of stream".
/// </summary>
+ /// <remarks>
+ /// If this method returns 0, it doesn't necessarily mean the end of all
+ /// the data in this CodedInputStream; it may be the end of the logical stream
+ /// for an embedded message, for example.
+ /// </remarks>
/// <returns>The next field tag, or 0 for end of stream. (0 is never a valid tag.)</returns>
public uint ReadTag()
{
@@ -329,22 +333,24 @@ namespace Google.Protobuf
}
/// <summary>
- /// Consumes the data for the field with the tag we've just read.
+ /// Skips the data for the field with the tag we've just read.
/// This should be called directly after <see cref="ReadTag"/>, when
/// the caller wishes to skip an unknown field.
/// </summary>
- public void ConsumeLastField()
+ public void SkipLastField()
{
if (lastTag == 0)
{
- throw new InvalidOperationException("ConsumeLastField cannot be called at the end of a stream");
+ throw new InvalidOperationException("SkipLastField cannot be called at the end of a stream");
}
switch (WireFormat.GetTagWireType(lastTag))
{
case WireFormat.WireType.StartGroup:
+ ConsumeGroup();
+ break;
case WireFormat.WireType.EndGroup:
- // TODO: Work out how to skip them instead? See issue 688.
- throw new InvalidProtocolBufferException("Group tags not supported by proto3 C# implementation");
+ // Just ignore; there's no data following the tag.
+ break;
case WireFormat.WireType.Fixed32:
ReadFixed32();
break;
@@ -361,6 +367,29 @@ namespace Google.Protobuf
}
}
+ private void ConsumeGroup()
+ {
+ // Note: Currently we expect this to be the way that groups are read. We could put the recursion
+ // depth changes into the ReadTag method instead, potentially...
+ recursionDepth++;
+ if (recursionDepth >= recursionLimit)
+ {
+ throw InvalidProtocolBufferException.RecursionLimitExceeded();
+ }
+ uint tag;
+ do
+ {
+ tag = ReadTag();
+ if (tag == 0)
+ {
+ throw InvalidProtocolBufferException.TruncatedMessage();
+ }
+ // This recursion will allow us to handle nested groups.
+ SkipLastField();
+ } while (WireFormat.GetTagWireType(tag) != WireFormat.WireType.EndGroup);
+ recursionDepth--;
+ }
+
/// <summary>
/// Reads a double field from the stream.
/// </summary>
@@ -475,7 +504,7 @@ namespace Google.Protobuf
int oldLimit = PushLimit(length);
++recursionDepth;
builder.MergeFrom(this);
- CheckLastTagWas(0);
+ CheckReadEndOfStreamTag();
// Check that we've read exactly as much data as expected.
if (!ReachedLimit)
{
diff --git a/csharp/src/Google.Protobuf/Collections/MapField.cs b/csharp/src/Google.Protobuf/Collections/MapField.cs
index 5eb2c2fc..dc4b04cb 100644
--- a/csharp/src/Google.Protobuf/Collections/MapField.cs
+++ b/csharp/src/Google.Protobuf/Collections/MapField.cs
@@ -637,10 +637,9 @@ namespace Google.Protobuf.Collections
{
Value = codec.valueCodec.Read(input);
}
- else if (WireFormat.IsEndGroupTag(tag))
+ else
{
- // TODO(jonskeet): Do we need this? (Given that we don't support groups...)
- return;
+ input.SkipLastField();
}
}
}
diff --git a/csharp/src/Google.Protobuf/FieldCodec.cs b/csharp/src/Google.Protobuf/FieldCodec.cs
index 15d52c7d..20a1f438 100644
--- a/csharp/src/Google.Protobuf/FieldCodec.cs
+++ b/csharp/src/Google.Protobuf/FieldCodec.cs
@@ -304,12 +304,13 @@ namespace Google.Protobuf
{
value = codec.Read(input);
}
- if (WireFormat.IsEndGroupTag(tag))
+ else
{
- break;
+ input.SkipLastField();
}
+
}
- input.CheckLastTagWas(0);
+ input.CheckReadEndOfStreamTag();
input.PopLimit(oldLimit);
return value;
diff --git a/csharp/src/Google.Protobuf/MessageExtensions.cs b/csharp/src/Google.Protobuf/MessageExtensions.cs
index ee78dc8d..d2d057c0 100644
--- a/csharp/src/Google.Protobuf/MessageExtensions.cs
+++ b/csharp/src/Google.Protobuf/MessageExtensions.cs
@@ -50,7 +50,7 @@ namespace Google.Protobuf
Preconditions.CheckNotNull(data, "data");
CodedInputStream input = new CodedInputStream(data);
message.MergeFrom(input);
- input.CheckLastTagWas(0);
+ input.CheckReadEndOfStreamTag();
}
/// <summary>
@@ -64,7 +64,7 @@ namespace Google.Protobuf
Preconditions.CheckNotNull(data, "data");
CodedInputStream input = data.CreateCodedInput();
message.MergeFrom(input);
- input.CheckLastTagWas(0);
+ input.CheckReadEndOfStreamTag();
}
/// <summary>
@@ -78,7 +78,7 @@ namespace Google.Protobuf
Preconditions.CheckNotNull(input, "input");
CodedInputStream codedInput = new CodedInputStream(input);
message.MergeFrom(codedInput);
- codedInput.CheckLastTagWas(0);
+ codedInput.CheckReadEndOfStreamTag();
}
/// <summary>
diff --git a/csharp/src/Google.Protobuf/WireFormat.cs b/csharp/src/Google.Protobuf/WireFormat.cs
index bbd7e4f9..b0e4a41f 100644
--- a/csharp/src/Google.Protobuf/WireFormat.cs
+++ b/csharp/src/Google.Protobuf/WireFormat.cs
@@ -99,16 +99,6 @@ namespace Google.Protobuf
}
/// <summary>
- /// Determines whether the given tag is an end group tag.
- /// </summary>
- /// <param name="tag">The tag to check.</param>
- /// <returns><c>true</c> if the given tag is an end group tag; <c>false</c> otherwise.</returns>
- public static bool IsEndGroupTag(uint tag)
- {
- return (WireType) (tag & TagTypeMask) == WireType.EndGroup;
- }
-
- /// <summary>
/// Given a tag value, determines the field number (the upper 29 bits).
/// </summary>
public static int GetTagFieldNumber(uint tag)
diff --git a/src/google/protobuf/compiler/csharp/csharp_message.cc b/src/google/protobuf/compiler/csharp/csharp_message.cc
index 40c13de5..a71a7909 100644
--- a/src/google/protobuf/compiler/csharp/csharp_message.cc
+++ b/src/google/protobuf/compiler/csharp/csharp_message.cc
@@ -423,10 +423,7 @@ void MessageGenerator::GenerateMergingMethods(io::Printer* printer) {
printer->Indent();
printer->Print(
"default:\n"
- " if (pb::WireFormat.IsEndGroupTag(tag)) {\n"
- " return;\n"
- " }\n"
- " input.ConsumeLastField();\n" // We're not storing the data, but we still need to consume it.
+ " input.SkipLastField();\n" // We're not storing the data, but we still need to consume it.
" break;\n");
for (int i = 0; i < fields_by_number().size(); i++) {
const FieldDescriptor* field = fields_by_number()[i];