diff options
author | Jon Skeet <skeet@pobox.com> | 2009-06-25 07:52:07 +0100 |
---|---|---|
committer | Jon Skeet <skeet@pobox.com> | 2009-06-25 07:52:07 +0100 |
commit | 2178b93bbb121e4cbb38aa370758742d723dd8fd (patch) | |
tree | b7f3eb806956551e451d75d68bd703cae4e1dc27 | |
parent | 60fb63e3704091d0d681181dbab2055f6878f2ea (diff) | |
download | protobuf-2178b93bbb121e4cbb38aa370758742d723dd8fd.tar.gz protobuf-2178b93bbb121e4cbb38aa370758742d723dd8fd.tar.bz2 protobuf-2178b93bbb121e4cbb38aa370758742d723dd8fd.zip |
Fix bug when reading many messages - size guard was triggered
-rw-r--r-- | src/ProtocolBuffers.Test/MessageStreamIteratorTest.cs | 27 | ||||
-rw-r--r-- | src/ProtocolBuffers/CodedInputStream.cs | 8 | ||||
-rw-r--r-- | src/ProtocolBuffers/MessageStreamIterator.cs | 23 |
3 files changed, 51 insertions, 7 deletions
diff --git a/src/ProtocolBuffers.Test/MessageStreamIteratorTest.cs b/src/ProtocolBuffers.Test/MessageStreamIteratorTest.cs index 7eebde39..c7f31c0f 100644 --- a/src/ProtocolBuffers.Test/MessageStreamIteratorTest.cs +++ b/src/ProtocolBuffers.Test/MessageStreamIteratorTest.cs @@ -3,6 +3,7 @@ using System.Collections.Generic; using System.IO; using NUnit.Framework; using NestedMessage = Google.ProtocolBuffers.TestProtos.TestAllTypes.Types.NestedMessage; +using Google.ProtocolBuffers.TestProtos; namespace Google.ProtocolBuffers { [TestFixture] @@ -19,5 +20,31 @@ namespace Google.ProtocolBuffers { Assert.AreEqual(1500, messages[1].Bb); Assert.IsFalse(messages[2].HasBb); } + + [Test] + public void ManyMessagesShouldNotTriggerSizeAlert() { + int messageSize = TestUtil.GetAllSet().SerializedSize; + // Enough messages to trigger the alert unless we've reset the size + // Note that currently we need to make this big enough to copy two whole buffers, + // as otherwise when we refill the buffer the second type, the alert triggers instantly. + int correctCount = (CodedInputStream.BufferSize * 2) / messageSize + 1; + using (MemoryStream stream = new MemoryStream()) { + MessageStreamWriter<TestAllTypes> writer = new MessageStreamWriter<TestAllTypes>(stream); + for (int i = 0; i < correctCount; i++) { + writer.Write(TestUtil.GetAllSet()); + } + writer.Flush(); + + stream.Position = 0; + + int count = 0; + foreach (var message in MessageStreamIterator<TestAllTypes>.FromStreamProvider(() => stream) + .WithSizeLimit(CodedInputStream.BufferSize * 2)) { + count++; + TestUtil.AssertAllFieldsSet(message); + } + Assert.AreEqual(correctCount, count); + } + } } } diff --git a/src/ProtocolBuffers/CodedInputStream.cs b/src/ProtocolBuffers/CodedInputStream.cs index 313bddf3..e652af0d 100644 --- a/src/ProtocolBuffers/CodedInputStream.cs +++ b/src/ProtocolBuffers/CodedInputStream.cs @@ -61,9 +61,9 @@ namespace Google.ProtocolBuffers { private readonly Stream input; private uint lastTag = 0; - const int DefaultRecursionLimit = 64; - const int DefaultSizeLimit = 64 << 20; // 64MB - const int BufferSize = 4096; + internal const int DefaultRecursionLimit = 64; + internal const int DefaultSizeLimit = 64 << 20; // 64MB + internal const int BufferSize = 4096; /// <summary> /// The total number of bytes read before the current buffer. The @@ -741,7 +741,7 @@ namespace Google.ProtocolBuffers { /// Read one byte from the input. /// </summary> /// <exception cref="InvalidProtocolBufferException"> - /// he end of the stream or the current limit was reached + /// the end of the stream or the current limit was reached /// </exception> public byte ReadRawByte() { if (bufferPos == bufferSize) { diff --git a/src/ProtocolBuffers/MessageStreamIterator.cs b/src/ProtocolBuffers/MessageStreamIterator.cs index e8cc4306..5ddfc62a 100644 --- a/src/ProtocolBuffers/MessageStreamIterator.cs +++ b/src/ProtocolBuffers/MessageStreamIterator.cs @@ -18,6 +18,7 @@ namespace Google.ProtocolBuffers { private readonly StreamProvider streamProvider; private readonly ExtensionRegistry extensionRegistry; + private readonly int sizeLimit; /// <summary> /// Delegate created via reflection trickery (once per type) to create a builder @@ -103,17 +104,22 @@ namespace Google.ProtocolBuffers { TBuilder builder = builderBuilder(); input.ReadMessage(builder, registry); return builder.Build(); - } + } #pragma warning restore 0414 private static readonly uint ExpectedTag = WireFormat.MakeTag(1, WireFormat.WireType.LengthDelimited); - private MessageStreamIterator(StreamProvider streamProvider, ExtensionRegistry extensionRegistry) { + private MessageStreamIterator(StreamProvider streamProvider, ExtensionRegistry extensionRegistry, int sizeLimit) { if (messageReader == null) { throw typeInitializationException; } this.streamProvider = streamProvider; this.extensionRegistry = extensionRegistry; + this.sizeLimit = sizeLimit; + } + + private MessageStreamIterator(StreamProvider streamProvider, ExtensionRegistry extensionRegistry) + : this (streamProvider, extensionRegistry, CodedInputStream.DefaultSizeLimit) { } /// <summary> @@ -121,7 +127,16 @@ namespace Google.ProtocolBuffers { /// but the specified extension registry. /// </summary> public MessageStreamIterator<TMessage> WithExtensionRegistry(ExtensionRegistry newRegistry) { - return new MessageStreamIterator<TMessage>(streamProvider, newRegistry); + return new MessageStreamIterator<TMessage>(streamProvider, newRegistry, sizeLimit); + } + + /// <summary> + /// Creates a new instance which uses the same stream provider and extension registry as this one, + /// but with the specified size limit. Note that this must be big enough for the largest message + /// and the tag and size preceding it. + /// </summary> + public MessageStreamIterator<TMessage> WithSizeLimit(int newSizeLimit) { + return new MessageStreamIterator<TMessage>(streamProvider, extensionRegistry, newSizeLimit); } public static MessageStreamIterator<TMessage> FromFile(string file) { @@ -135,12 +150,14 @@ namespace Google.ProtocolBuffers { public IEnumerator<TMessage> GetEnumerator() { using (Stream stream = streamProvider()) { CodedInputStream input = CodedInputStream.CreateInstance(stream); + input.SetSizeLimit(sizeLimit); uint tag; while ((tag = input.ReadTag()) != 0) { if (tag != ExpectedTag) { throw InvalidProtocolBufferException.InvalidMessageStreamTag(); } yield return messageReader(input, extensionRegistry); + input.ResetSizeCounter(); } } } |