From e7f88ff1294ada0fca19334ed2c844cdb98ea2f6 Mon Sep 17 00:00:00 2001 From: Jon Skeet Date: Thu, 6 Aug 2015 11:40:32 +0100 Subject: Skip groups properly. Now the generated code doesn't need to check for end group tags, as it will skip whole groups at a time. Currently it will ignore extraneous end group tags, which may or may not be a good thing. Renamed ConsumeLastField to SkipLastField as it felt more natural. Removed WireFormat.IsEndGroupTag as it's no longer useful. This mostly fixes issue 688. (Generated code changes coming in next commit.) --- .../Google.Protobuf.Test/CodedInputStreamTest.cs | 87 ++++++++++++++++++++++ csharp/src/Google.Protobuf/CodedInputStream.cs | 55 ++++++++++---- csharp/src/Google.Protobuf/Collections/MapField.cs | 5 +- csharp/src/Google.Protobuf/FieldCodec.cs | 7 +- csharp/src/Google.Protobuf/MessageExtensions.cs | 6 +- csharp/src/Google.Protobuf/WireFormat.cs | 10 --- 6 files changed, 138 insertions(+), 32 deletions(-) (limited to 'csharp') diff --git a/csharp/src/Google.Protobuf.Test/CodedInputStreamTest.cs b/csharp/src/Google.Protobuf.Test/CodedInputStreamTest.cs index c4c92efd..42c740ac 100644 --- a/csharp/src/Google.Protobuf.Test/CodedInputStreamTest.cs +++ b/csharp/src/Google.Protobuf.Test/CodedInputStreamTest.cs @@ -442,5 +442,92 @@ namespace Google.Protobuf var input = new CodedInputStream(new byte[] { 0 }); Assert.Throws(() => input.ReadTag()); } + + [Test] + public void SkipGroup() + { + // Create an output stream with a group in: + // Field 1: string "field 1" + // Field 2: group containing: + // Field 1: fixed int32 value 100 + // Field 2: string "ignore me" + // Field 3: nested group containing + // Field 1: fixed int64 value 1000 + // Field 3: string "field 3" + var stream = new MemoryStream(); + var output = new CodedOutputStream(stream); + output.WriteTag(1, WireFormat.WireType.LengthDelimited); + output.WriteString("field 1"); + + // The outer group... 
+ output.WriteTag(2, WireFormat.WireType.StartGroup); + output.WriteTag(1, WireFormat.WireType.Fixed32); + output.WriteFixed32(100); + output.WriteTag(2, WireFormat.WireType.LengthDelimited); + output.WriteString("ignore me"); + // The nested group... + output.WriteTag(3, WireFormat.WireType.StartGroup); + output.WriteTag(1, WireFormat.WireType.Fixed64); + output.WriteFixed64(1000); + // Note: Not sure the field number is relevant for end group... + output.WriteTag(3, WireFormat.WireType.EndGroup); + + // End the outer group + output.WriteTag(2, WireFormat.WireType.EndGroup); + + output.WriteTag(3, WireFormat.WireType.LengthDelimited); + output.WriteString("field 3"); + output.Flush(); + stream.Position = 0; + + // Now act like a generated client + var input = new CodedInputStream(stream); + Assert.AreEqual(WireFormat.MakeTag(1, WireFormat.WireType.LengthDelimited), input.ReadTag()); + Assert.AreEqual("field 1", input.ReadString()); + Assert.AreEqual(WireFormat.MakeTag(2, WireFormat.WireType.StartGroup), input.ReadTag()); + input.SkipLastField(); // Should consume the whole group, including the nested one. 
+ Assert.AreEqual(WireFormat.MakeTag(3, WireFormat.WireType.LengthDelimited), input.ReadTag()); + Assert.AreEqual("field 3", input.ReadString()); + } + + [Test] + public void EndOfStreamReachedWhileSkippingGroup() + { + var stream = new MemoryStream(); + var output = new CodedOutputStream(stream); + output.WriteTag(1, WireFormat.WireType.StartGroup); + output.WriteTag(2, WireFormat.WireType.StartGroup); + output.WriteTag(2, WireFormat.WireType.EndGroup); + + output.Flush(); + stream.Position = 0; + + // Now act like a generated client + var input = new CodedInputStream(stream); + input.ReadTag(); + Assert.Throws<InvalidProtocolBufferException>(() => input.SkipLastField()); + } + + [Test] + public void RecursionLimitAppliedWhileSkippingGroup() + { + var stream = new MemoryStream(); + var output = new CodedOutputStream(stream); + for (int i = 0; i < CodedInputStream.DefaultRecursionLimit + 1; i++) + { + output.WriteTag(1, WireFormat.WireType.StartGroup); + } + for (int i = 0; i < CodedInputStream.DefaultRecursionLimit + 1; i++) + { + output.WriteTag(1, WireFormat.WireType.EndGroup); + } + output.Flush(); + stream.Position = 0; + + // Now act like a generated client + var input = new CodedInputStream(stream); + Assert.AreEqual(WireFormat.MakeTag(1, WireFormat.WireType.StartGroup), input.ReadTag()); + Assert.Throws<InvalidProtocolBufferException>(() => input.SkipLastField()); + } } } \ No newline at end of file diff --git a/csharp/src/Google.Protobuf/CodedInputStream.cs b/csharp/src/Google.Protobuf/CodedInputStream.cs index 0e2495f1..a37fefc1 100644 --- a/csharp/src/Google.Protobuf/CodedInputStream.cs +++ b/csharp/src/Google.Protobuf/CodedInputStream.cs @@ -236,17 +236,16 @@ namespace Google.Protobuf #region Validation /// <summary> - /// Verifies that the last call to ReadTag() returned the given tag value. - /// This is used to verify that a nested group ended with the correct - /// end tag. + /// Verifies that the last call to ReadTag() returned tag 0 - in other words, + /// we've reached the end of the stream when we expected to. 
/// </summary> - /// <exception cref="InvalidProtocolBufferException">The last + /// <exception cref="InvalidProtocolBufferException">The /// tag read was not the one specified</exception> - internal void CheckLastTagWas(uint value) + internal void CheckReadEndOfStreamTag() { - if (lastTag != value) + if (lastTag != 0) { - throw InvalidProtocolBufferException.InvalidEndTag(); + throw InvalidProtocolBufferException.MoreDataAvailable(); } } #endregion @@ -275,6 +274,11 @@ namespace Google.Protobuf /// <summary> /// Reads a field tag, returning the tag of 0 for "end of stream". /// </summary> + /// <remarks> + /// If this method returns 0, it doesn't necessarily mean the end of all + /// the data in this CodedInputStream; it may be the end of the logical stream + /// for an embedded message, for example. + /// </remarks> /// <returns>The next field tag, or 0 for end of stream. (0 is never a valid tag.)</returns> public uint ReadTag() { @@ -329,22 +333,24 @@ namespace Google.Protobuf } /// <summary> - /// Consumes the data for the field with the tag we've just read. + /// Skips the data for the field with the tag we've just read. /// This should be called directly after <see cref="ReadTag"/>, when /// the caller wishes to skip an unknown field. /// </summary> - public void ConsumeLastField() + public void SkipLastField() { if (lastTag == 0) { - throw new InvalidOperationException("ConsumeLastField cannot be called at the end of a stream"); + throw new InvalidOperationException("SkipLastField cannot be called at the end of a stream"); } switch (WireFormat.GetTagWireType(lastTag)) { case WireFormat.WireType.StartGroup: + ConsumeGroup(); + break; case WireFormat.WireType.EndGroup: - // TODO: Work out how to skip them instead? See issue 688. - throw new InvalidProtocolBufferException("Group tags not supported by proto3 C# implementation"); + // Just ignore; there's no data following the tag. + break; case WireFormat.WireType.Fixed32: ReadFixed32(); break; @@ -361,6 +367,29 @@ namespace Google.Protobuf } } + private void ConsumeGroup() + { + // Note: Currently we expect this to be the way that groups are read. 
We could put the recursion + // depth changes into the ReadTag method instead, potentially... + recursionDepth++; + if (recursionDepth >= recursionLimit) + { + throw InvalidProtocolBufferException.RecursionLimitExceeded(); + } + uint tag; + do + { + tag = ReadTag(); + if (tag == 0) + { + throw InvalidProtocolBufferException.TruncatedMessage(); + } + // This recursion will allow us to handle nested groups. + SkipLastField(); + } while (WireFormat.GetTagWireType(tag) != WireFormat.WireType.EndGroup); + recursionDepth--; + } + /// <summary> /// Reads a double field from the stream. /// </summary> @@ -475,7 +504,7 @@ namespace Google.Protobuf int oldLimit = PushLimit(length); ++recursionDepth; builder.MergeFrom(this); - CheckLastTagWas(0); + CheckReadEndOfStreamTag(); // Check that we've read exactly as much data as expected. if (!ReachedLimit) { diff --git a/csharp/src/Google.Protobuf/Collections/MapField.cs b/csharp/src/Google.Protobuf/Collections/MapField.cs index 5eb2c2fc..dc4b04cb 100644 --- a/csharp/src/Google.Protobuf/Collections/MapField.cs +++ b/csharp/src/Google.Protobuf/Collections/MapField.cs @@ -637,10 +637,9 @@ namespace Google.Protobuf.Collections { Value = codec.valueCodec.Read(input); } - else if (WireFormat.IsEndGroupTag(tag)) + else { - // TODO(jonskeet): Do we need this? (Given that we don't support groups...) 
- return; + input.SkipLastField(); } } } diff --git a/csharp/src/Google.Protobuf/FieldCodec.cs b/csharp/src/Google.Protobuf/FieldCodec.cs index 15d52c7d..20a1f438 100644 --- a/csharp/src/Google.Protobuf/FieldCodec.cs +++ b/csharp/src/Google.Protobuf/FieldCodec.cs @@ -304,12 +304,13 @@ namespace Google.Protobuf { value = codec.Read(input); } - if (WireFormat.IsEndGroupTag(tag)) + else { - break; + input.SkipLastField(); } + } - input.CheckLastTagWas(0); + input.CheckReadEndOfStreamTag(); input.PopLimit(oldLimit); return value; diff --git a/csharp/src/Google.Protobuf/MessageExtensions.cs b/csharp/src/Google.Protobuf/MessageExtensions.cs index ee78dc8d..d2d057c0 100644 --- a/csharp/src/Google.Protobuf/MessageExtensions.cs +++ b/csharp/src/Google.Protobuf/MessageExtensions.cs @@ -50,7 +50,7 @@ namespace Google.Protobuf Preconditions.CheckNotNull(data, "data"); CodedInputStream input = new CodedInputStream(data); message.MergeFrom(input); - input.CheckLastTagWas(0); + input.CheckReadEndOfStreamTag(); } /// <summary> @@ -64,7 +64,7 @@ namespace Google.Protobuf Preconditions.CheckNotNull(data, "data"); CodedInputStream input = data.CreateCodedInput(); message.MergeFrom(input); - input.CheckLastTagWas(0); + input.CheckReadEndOfStreamTag(); } /// <summary> @@ -78,7 +78,7 @@ namespace Google.Protobuf Preconditions.CheckNotNull(input, "input"); CodedInputStream codedInput = new CodedInputStream(input); message.MergeFrom(codedInput); - codedInput.CheckLastTagWas(0); + codedInput.CheckReadEndOfStreamTag(); } /// <summary> diff --git a/csharp/src/Google.Protobuf/WireFormat.cs b/csharp/src/Google.Protobuf/WireFormat.cs index bbd7e4f9..b0e4a41f 100644 --- a/csharp/src/Google.Protobuf/WireFormat.cs +++ b/csharp/src/Google.Protobuf/WireFormat.cs @@ -98,16 +98,6 @@ namespace Google.Protobuf return (WireType) (tag & TagTypeMask); } - /// <summary> - /// Determines whether the given tag is an end group tag. - /// </summary> - /// <param name="tag">The tag to check.</param> - /// <returns>true if the given tag is an end group tag; false otherwise.</returns> 
- public static bool IsEndGroupTag(uint tag) - { - return (WireType) (tag & TagTypeMask) == WireType.EndGroup; - } - /// <summary> /// Given a tag value, determines the field number (the upper 29 bits). /// </summary> -- cgit v1.2.3