From cc8ca5b6a5478b40546d4206392eb1471454460d Mon Sep 17 00:00:00 2001 From: Bo Yang Date: Mon, 19 Sep 2016 13:45:07 -0700 Subject: Integrate internal changes --- .../main/java/com/google/protobuf/ByteString.java | 14 + .../java/com/google/protobuf/CodedInputStream.java | 3593 ++++++++++++++------ .../com/google/protobuf/CodedOutputStream.java | 400 ++- .../main/java/com/google/protobuf/Descriptors.java | 31 +- .../com/google/protobuf/GeneratedMessageV3.java | 481 ++- .../main/java/com/google/protobuf/Internal.java | 10 + .../protobuf/InvalidProtocolBufferException.java | 16 +- .../main/java/com/google/protobuf/MapEntry.java | 9 + .../java/com/google/protobuf/NioByteString.java | 9 +- .../java/com/google/protobuf/RopeByteString.java | 1 + .../main/java/com/google/protobuf/TextFormat.java | 22 +- .../google/protobuf/TextFormatParseInfoTree.java | 3 +- .../com/google/protobuf/UnsafeByteOperations.java | 31 +- .../main/java/com/google/protobuf/UnsafeUtil.java | 119 +- .../main/java/com/google/protobuf/WireFormat.java | 4 + 15 files changed, 3663 insertions(+), 1080 deletions(-) (limited to 'java/core/src/main') diff --git a/java/core/src/main/java/com/google/protobuf/ByteString.java b/java/core/src/main/java/com/google/protobuf/ByteString.java index 3f3f9f3c..5b24976d 100644 --- a/java/core/src/main/java/com/google/protobuf/ByteString.java +++ b/java/core/src/main/java/com/google/protobuf/ByteString.java @@ -310,6 +310,18 @@ public abstract class ByteString implements Iterable, Serializable { return copyFrom(bytes, 0, bytes.length); } + /** + * Wraps the given bytes into a {@code ByteString}. Intended for internal only usage. + */ + static ByteString wrap(ByteBuffer buffer) { + if (buffer.hasArray()) { + final int offset = buffer.arrayOffset(); + return ByteString.wrap(buffer.array(), offset + buffer.position(), buffer.remaining()); + } else { + return new NioByteString(buffer); + } + } + /** * Wraps the given bytes into a {@code ByteString}. Intended for internal only * usage to force a classload of ByteString before LiteralByteString. @@ -679,6 +691,7 @@ public abstract class ByteString implements Iterable, Serializable { */ abstract void writeTo(ByteOutput byteOutput) throws IOException; + /** * Constructs a read-only {@code java.nio.ByteBuffer} whose content * is equal to the contents of this byte string. @@ -820,6 +833,7 @@ public abstract class ByteString implements Iterable, Serializable { return true; } + /** * Check equality of the substring of given length of this object starting at * zero with another {@code ByteString} substring starting at offset. diff --git a/java/core/src/main/java/com/google/protobuf/CodedInputStream.java b/java/core/src/main/java/com/google/protobuf/CodedInputStream.java index e8860651..e461fa28 100644 --- a/java/core/src/main/java/com/google/protobuf/CodedInputStream.java +++ b/java/core/src/main/java/com/google/protobuf/CodedInputStream.java @@ -30,6 +30,14 @@ package com.google.protobuf; +import static com.google.protobuf.Internal.EMPTY_BYTE_ARRAY; +import static com.google.protobuf.Internal.EMPTY_BYTE_BUFFER; +import static com.google.protobuf.Internal.UTF_8; +import static com.google.protobuf.Internal.checkNotNull; +import static com.google.protobuf.WireFormat.FIXED_32_SIZE; +import static com.google.protobuf.WireFormat.FIXED_64_SIZE; +import static com.google.protobuf.WireFormat.MAX_VARINT_SIZE; + import java.io.ByteArrayOutputStream; import java.io.IOException; import java.io.InputStream; @@ -41,51 +49,55 @@ import java.util.List; /** * Reads and decodes protocol message fields. * - * This class contains two kinds of methods: methods that read specific - * protocol message constructs and field types (e.g. {@link #readTag()} and - * {@link #readInt32()}) and methods that read low-level values (e.g. - * {@link #readRawVarint32()} and {@link #readRawBytes}). If you are reading - * encoded protocol messages, you should use the former methods, but if you are - * reading some other format of your own design, use the latter. + *

This class contains two kinds of methods: methods that read specific protocol message + * constructs and field types (e.g. {@link #readTag()} and {@link #readInt32()}) and methods that + * read low-level values (e.g. {@link #readRawVarint32()} and {@link #readRawBytes}). If you are + * reading encoded protocol messages, you should use the former methods, but if you are reading some + * other format of your own design, use the latter. * * @author kenton@google.com Kenton Varda */ -public final class CodedInputStream { - /** - * Create a new CodedInputStream wrapping the given InputStream. - */ +public abstract class CodedInputStream { + private static final int DEFAULT_BUFFER_SIZE = 4096; + private static final int DEFAULT_RECURSION_LIMIT = 100; + private static final int DEFAULT_SIZE_LIMIT = 64 << 20; // 64MB + + /** Visible for subclasses. See setRecursionLimit() */ + int recursionDepth; + + int recursionLimit = DEFAULT_RECURSION_LIMIT; + + /** Visible for subclasses. See setSizeLimit() */ + int sizeLimit = DEFAULT_SIZE_LIMIT; + + /** Create a new CodedInputStream wrapping the given InputStream. */ public static CodedInputStream newInstance(final InputStream input) { - return new CodedInputStream(input, BUFFER_SIZE); + return newInstance(input, DEFAULT_BUFFER_SIZE); } - - /** - * Create a new CodedInputStream wrapping the given InputStream. - */ + + /** Create a new CodedInputStream wrapping the given InputStream. */ static CodedInputStream newInstance(final InputStream input, int bufferSize) { - return new CodedInputStream(input, bufferSize); + if (input == null) { + // TODO(nathanmittler): Ideally we should throw here. This is done for backward compatibility. + return newInstance(EMPTY_BYTE_ARRAY); + } + return new StreamDecoder(input, bufferSize); } - /** - * Create a new CodedInputStream wrapping the given byte array. - */ + /** Create a new CodedInputStream wrapping the given byte array. */ public static CodedInputStream newInstance(final byte[] buf) { return newInstance(buf, 0, buf.length); } - /** - * Create a new CodedInputStream wrapping the given byte array slice. - */ - public static CodedInputStream newInstance(final byte[] buf, final int off, - final int len) { + /** Create a new CodedInputStream wrapping the given byte array slice. */ + public static CodedInputStream newInstance(final byte[] buf, final int off, final int len) { return newInstance(buf, off, len, false /* bufferIsImmutable */); } - - /** - * Create a new CodedInputStream wrapping the given byte array slice. - */ + + /** Create a new CodedInputStream wrapping the given byte array slice. */ static CodedInputStream newInstance( final byte[] buf, final int off, final int len, final boolean bufferIsImmutable) { - CodedInputStream result = new CodedInputStream(buf, off, len, bufferIsImmutable); + ArrayDecoder result = new ArrayDecoder(buf, off, len, bufferIsImmutable); try { // Some uses of CodedInputStream can be more efficient if they know // exactly how many bytes are available. By pushing the end point of the @@ -107,821 +119,229 @@ public final class CodedInputStream { } /** - * Create a new CodedInputStream wrapping the given ByteBuffer. The data - * starting from the ByteBuffer's current position to its limit will be read. - * The returned CodedInputStream may or may not share the underlying data - * in the ByteBuffer, therefore the ByteBuffer cannot be changed while the - * CodedInputStream is in use. - * Note that the ByteBuffer's position won't be changed by this function. - * Concurrent calls with the same ByteBuffer object are safe if no other - * thread is trying to alter the ByteBuffer's status. + * Create a new CodedInputStream wrapping the given ByteBuffer. The data starting from the + * ByteBuffer's current position to its limit will be read. The returned CodedInputStream may or + * may not share the underlying data in the ByteBuffer, therefore the ByteBuffer cannot be changed + * while the CodedInputStream is in use. Note that the ByteBuffer's position won't be changed by + * this function. Concurrent calls with the same ByteBuffer object are safe if no other thread is + * trying to alter the ByteBuffer's status. */ public static CodedInputStream newInstance(ByteBuffer buf) { + return newInstance(buf, false /* bufferIsImmutable */); + } + + /** Create a new CodedInputStream wrapping the given buffer. */ + static CodedInputStream newInstance(ByteBuffer buf, boolean bufferIsImmutable) { if (buf.hasArray()) { - return newInstance(buf.array(), buf.arrayOffset() + buf.position(), - buf.remaining()); - } else { - ByteBuffer temp = buf.duplicate(); - byte[] buffer = new byte[temp.remaining()]; - temp.get(buffer); - return newInstance(buffer); + return newInstance( + buf.array(), buf.arrayOffset() + buf.position(), buf.remaining(), bufferIsImmutable); + } + + if (buf.isDirect() && UnsafeDirectNioDecoder.isSupported()) { + return new UnsafeDirectNioDecoder(buf, bufferIsImmutable); } + + // The buffer is non-direct and does not expose the underlying array. Using the ByteBuffer API + // to access individual bytes is very slow, so just copy the buffer to an array. + // TODO(nathanmittler): Re-evaluate with Java 9 + byte[] buffer = new byte[buf.remaining()]; + buf.duplicate().get(buffer); + return newInstance(buffer, 0, buffer.length, true); } + /** Disable construction/inheritance outside of this class. */ + private CodedInputStream() {} + // ----------------------------------------------------------------- /** - * Attempt to read a field tag, returning zero if we have reached EOF. - * Protocol message parsers use this to read tags, since a protocol message - * may legally end wherever a tag occurs, and zero is not a valid tag number. + * Attempt to read a field tag, returning zero if we have reached EOF. Protocol message parsers + * use this to read tags, since a protocol message may legally end wherever a tag occurs, and zero + * is not a valid tag number. */ - public int readTag() throws IOException { - if (isAtEnd()) { - lastTag = 0; - return 0; - } - - lastTag = readRawVarint32(); - if (WireFormat.getTagFieldNumber(lastTag) == 0) { - // If we actually read zero (or any tag number corresponding to field - // number zero), that's not a valid tag. - throw InvalidProtocolBufferException.invalidTag(); - } - return lastTag; - } + public abstract int readTag() throws IOException; /** - * Verifies that the last call to readTag() returned the given tag value. - * This is used to verify that a nested group ended with the correct - * end tag. + * Verifies that the last call to readTag() returned the given tag value. This is used to verify + * that a nested group ended with the correct end tag. * - * @throws InvalidProtocolBufferException {@code value} does not match the - * last tag. + * @throws InvalidProtocolBufferException {@code value} does not match the last tag. */ - public void checkLastTagWas(final int value) - throws InvalidProtocolBufferException { - if (lastTag != value) { - throw InvalidProtocolBufferException.invalidEndTag(); - } - } + public abstract void checkLastTagWas(final int value) throws InvalidProtocolBufferException; - public int getLastTag() { - return lastTag; - } + public abstract int getLastTag(); /** * Reads and discards a single field, given its tag value. * - * @return {@code false} if the tag is an endgroup tag, in which case - * nothing is skipped. Otherwise, returns {@code true}. + * @return {@code false} if the tag is an endgroup tag, in which case nothing is skipped. + * Otherwise, returns {@code true}. */ - public boolean skipField(final int tag) throws IOException { - switch (WireFormat.getTagWireType(tag)) { - case WireFormat.WIRETYPE_VARINT: - skipRawVarint(); - return true; - case WireFormat.WIRETYPE_FIXED64: - skipRawBytes(8); - return true; - case WireFormat.WIRETYPE_LENGTH_DELIMITED: - skipRawBytes(readRawVarint32()); - return true; - case WireFormat.WIRETYPE_START_GROUP: - skipMessage(); - checkLastTagWas( - WireFormat.makeTag(WireFormat.getTagFieldNumber(tag), - WireFormat.WIRETYPE_END_GROUP)); - return true; - case WireFormat.WIRETYPE_END_GROUP: - return false; - case WireFormat.WIRETYPE_FIXED32: - skipRawBytes(4); - return true; - default: - throw InvalidProtocolBufferException.invalidWireType(); - } - } + public abstract boolean skipField(final int tag) throws IOException; /** - * Reads a single field and writes it to output in wire format, - * given its tag value. + * Reads a single field and writes it to output in wire format, given its tag value. * - * @return {@code false} if the tag is an endgroup tag, in which case - * nothing is skipped. Otherwise, returns {@code true}. - */ - public boolean skipField(final int tag, final CodedOutputStream output) - throws IOException { - switch (WireFormat.getTagWireType(tag)) { - case WireFormat.WIRETYPE_VARINT: { - long value = readInt64(); - output.writeRawVarint32(tag); - output.writeUInt64NoTag(value); - return true; - } - case WireFormat.WIRETYPE_FIXED64: { - long value = readRawLittleEndian64(); - output.writeRawVarint32(tag); - output.writeFixed64NoTag(value); - return true; - } - case WireFormat.WIRETYPE_LENGTH_DELIMITED: { - ByteString value = readBytes(); - output.writeRawVarint32(tag); - output.writeBytesNoTag(value); - return true; - } - case WireFormat.WIRETYPE_START_GROUP: { - output.writeRawVarint32(tag); - skipMessage(output); - int endtag = WireFormat.makeTag(WireFormat.getTagFieldNumber(tag), - WireFormat.WIRETYPE_END_GROUP); - checkLastTagWas(endtag); - output.writeRawVarint32(endtag); - return true; - } - case WireFormat.WIRETYPE_END_GROUP: { - return false; - } - case WireFormat.WIRETYPE_FIXED32: { - int value = readRawLittleEndian32(); - output.writeRawVarint32(tag); - output.writeFixed32NoTag(value); - return true; - } - default: - throw InvalidProtocolBufferException.invalidWireType(); - } - } - - /** - * Reads and discards an entire message. This will read either until EOF - * or until an endgroup tag, whichever comes first. + * @return {@code false} if the tag is an endgroup tag, in which case nothing is skipped. + * Otherwise, returns {@code true}. + * @deprecated use {@code UnknownFieldSet} or {@code UnknownFieldSetLite} to skip to an output + * stream. */ - public void skipMessage() throws IOException { - while (true) { - final int tag = readTag(); - if (tag == 0 || !skipField(tag)) { - return; - } - } - } + @Deprecated + public abstract boolean skipField(final int tag, final CodedOutputStream output) + throws IOException; /** - * Reads an entire message and writes it to output in wire format. - * This will read either until EOF or until an endgroup tag, + * Reads and discards an entire message. This will read either until EOF or until an endgroup tag, * whichever comes first. */ - public void skipMessage(CodedOutputStream output) throws IOException { - while (true) { - final int tag = readTag(); - if (tag == 0 || !skipField(tag, output)) { - return; - } - } - } + public abstract void skipMessage() throws IOException; /** - * Collects the bytes skipped and returns the data in a ByteBuffer. + * Reads an entire message and writes it to output in wire format. This will read either until EOF + * or until an endgroup tag, whichever comes first. */ - private class SkippedDataSink implements RefillCallback { - private int lastPos = bufferPos; - private ByteArrayOutputStream byteArrayStream; - - @Override - public void onRefill() { - if (byteArrayStream == null) { - byteArrayStream = new ByteArrayOutputStream(); - } - byteArrayStream.write(buffer, lastPos, bufferPos - lastPos); - lastPos = 0; - } - - /** - * Gets skipped data in a ByteBuffer. This method should only be - * called once. - */ - ByteBuffer getSkippedData() { - if (byteArrayStream == null) { - return ByteBuffer.wrap(buffer, lastPos, bufferPos - lastPos); - } else { - byteArrayStream.write(buffer, lastPos, bufferPos); - return ByteBuffer.wrap(byteArrayStream.toByteArray()); - } - } - } + public abstract void skipMessage(CodedOutputStream output) throws IOException; // ----------------------------------------------------------------- /** Read a {@code double} field value from the stream. */ - public double readDouble() throws IOException { - return Double.longBitsToDouble(readRawLittleEndian64()); - } + public abstract double readDouble() throws IOException; /** Read a {@code float} field value from the stream. */ - public float readFloat() throws IOException { - return Float.intBitsToFloat(readRawLittleEndian32()); - } + public abstract float readFloat() throws IOException; /** Read a {@code uint64} field value from the stream. */ - public long readUInt64() throws IOException { - return readRawVarint64(); - } + public abstract long readUInt64() throws IOException; /** Read an {@code int64} field value from the stream. */ - public long readInt64() throws IOException { - return readRawVarint64(); - } + public abstract long readInt64() throws IOException; /** Read an {@code int32} field value from the stream. */ - public int readInt32() throws IOException { - return readRawVarint32(); - } + public abstract int readInt32() throws IOException; /** Read a {@code fixed64} field value from the stream. */ - public long readFixed64() throws IOException { - return readRawLittleEndian64(); - } + public abstract long readFixed64() throws IOException; /** Read a {@code fixed32} field value from the stream. */ - public int readFixed32() throws IOException { - return readRawLittleEndian32(); - } + public abstract int readFixed32() throws IOException; /** Read a {@code bool} field value from the stream. */ - public boolean readBool() throws IOException { - return readRawVarint64() != 0; - } + public abstract boolean readBool() throws IOException; /** - * Read a {@code string} field value from the stream. - * If the stream contains malformed UTF-8, + * Read a {@code string} field value from the stream. If the stream contains malformed UTF-8, * replace the offending bytes with the standard UTF-8 replacement character. */ - public String readString() throws IOException { - final int size = readRawVarint32(); - if (size <= (bufferSize - bufferPos) && size > 0) { - // Fast path: We already have the bytes in a contiguous buffer, so - // just copy directly from it. - final String result = new String(buffer, bufferPos, size, Internal.UTF_8); - bufferPos += size; - return result; - } else if (size == 0) { - return ""; - } else if (size <= bufferSize) { - refillBuffer(size); - String result = new String(buffer, bufferPos, size, Internal.UTF_8); - bufferPos += size; - return result; - } else { - // Slow path: Build a byte array first then copy it. - return new String(readRawBytesSlowPath(size), Internal.UTF_8); - } - } + public abstract String readString() throws IOException; /** - * Read a {@code string} field value from the stream. - * If the stream contains malformed UTF-8, + * Read a {@code string} field value from the stream. If the stream contains malformed UTF-8, * throw exception {@link InvalidProtocolBufferException}. */ - public String readStringRequireUtf8() throws IOException { - final int size = readRawVarint32(); - final byte[] bytes; - final int oldPos = bufferPos; - final int pos; - if (size <= (bufferSize - oldPos) && size > 0) { - // Fast path: We already have the bytes in a contiguous buffer, so - // just copy directly from it. - bytes = buffer; - bufferPos = oldPos + size; - pos = oldPos; - } else if (size == 0) { - return ""; - } else if (size <= bufferSize) { - refillBuffer(size); - bytes = buffer; - pos = 0; - bufferPos = pos + size; - } else { - // Slow path: Build a byte array first then copy it. - bytes = readRawBytesSlowPath(size); - pos = 0; - } - // TODO(martinrb): We could save a pass by validating while decoding. - if (!Utf8.isValidUtf8(bytes, pos, pos + size)) { - throw InvalidProtocolBufferException.invalidUtf8(); - } - return new String(bytes, pos, size, Internal.UTF_8); - } + public abstract String readStringRequireUtf8() throws IOException; /** Read a {@code group} field value from the stream. */ - public void readGroup(final int fieldNumber, - final MessageLite.Builder builder, - final ExtensionRegistryLite extensionRegistry) - throws IOException { - if (recursionDepth >= recursionLimit) { - throw InvalidProtocolBufferException.recursionLimitExceeded(); - } - ++recursionDepth; - builder.mergeFrom(this, extensionRegistry); - checkLastTagWas( - WireFormat.makeTag(fieldNumber, WireFormat.WIRETYPE_END_GROUP)); - --recursionDepth; - } + public abstract void readGroup( + final int fieldNumber, + final MessageLite.Builder builder, + final ExtensionRegistryLite extensionRegistry) + throws IOException; /** Read a {@code group} field value from the stream. */ - public T readGroup( - final int fieldNumber, - final Parser parser, - final ExtensionRegistryLite extensionRegistry) - throws IOException { - if (recursionDepth >= recursionLimit) { - throw InvalidProtocolBufferException.recursionLimitExceeded(); - } - ++recursionDepth; - T result = parser.parsePartialFrom(this, extensionRegistry); - checkLastTagWas( - WireFormat.makeTag(fieldNumber, WireFormat.WIRETYPE_END_GROUP)); - --recursionDepth; - return result; - } + public abstract T readGroup( + final int fieldNumber, final Parser parser, final ExtensionRegistryLite extensionRegistry) + throws IOException; /** - * Reads a {@code group} field value from the stream and merges it into the - * given {@link UnknownFieldSet}. + * Reads a {@code group} field value from the stream and merges it into the given {@link + * UnknownFieldSet}. * - * @deprecated UnknownFieldSet.Builder now implements MessageLite.Builder, so - * you can just call {@link #readGroup}. + * @deprecated UnknownFieldSet.Builder now implements MessageLite.Builder, so you can just call + * {@link #readGroup}. */ @Deprecated - public void readUnknownGroup(final int fieldNumber, - final MessageLite.Builder builder) - throws IOException { - // We know that UnknownFieldSet will ignore any ExtensionRegistry so it - // is safe to pass null here. (We can't call - // ExtensionRegistry.getEmptyRegistry() because that would make this - // class depend on ExtensionRegistry, which is not part of the lite - // library.) - readGroup(fieldNumber, builder, null); - } + public abstract void readUnknownGroup(final int fieldNumber, final MessageLite.Builder builder) + throws IOException; /** Read an embedded message field value from the stream. */ - public void readMessage(final MessageLite.Builder builder, - final ExtensionRegistryLite extensionRegistry) - throws IOException { - final int length = readRawVarint32(); - if (recursionDepth >= recursionLimit) { - throw InvalidProtocolBufferException.recursionLimitExceeded(); - } - final int oldLimit = pushLimit(length); - ++recursionDepth; - builder.mergeFrom(this, extensionRegistry); - checkLastTagWas(0); - --recursionDepth; - popLimit(oldLimit); - } + public abstract void readMessage( + final MessageLite.Builder builder, final ExtensionRegistryLite extensionRegistry) + throws IOException; /** Read an embedded message field value from the stream. */ - public T readMessage( - final Parser parser, - final ExtensionRegistryLite extensionRegistry) - throws IOException { - int length = readRawVarint32(); - if (recursionDepth >= recursionLimit) { - throw InvalidProtocolBufferException.recursionLimitExceeded(); - } - final int oldLimit = pushLimit(length); - ++recursionDepth; - T result = parser.parsePartialFrom(this, extensionRegistry); - checkLastTagWas(0); - --recursionDepth; - popLimit(oldLimit); - return result; - } + public abstract T readMessage( + final Parser parser, final ExtensionRegistryLite extensionRegistry) throws IOException; /** Read a {@code bytes} field value from the stream. */ - public ByteString readBytes() throws IOException { - final int size = readRawVarint32(); - if (size <= (bufferSize - bufferPos) && size > 0) { - // Fast path: We already have the bytes in a contiguous buffer, so - // just copy directly from it. - final ByteString result = bufferIsImmutable && enableAliasing - ? ByteString.wrap(buffer, bufferPos, size) - : ByteString.copyFrom(buffer, bufferPos, size); - bufferPos += size; - return result; - } else if (size == 0) { - return ByteString.EMPTY; - } else { - // Slow path: Build a byte array first then copy it. - return ByteString.wrap(readRawBytesSlowPath(size)); - } - } + public abstract ByteString readBytes() throws IOException; /** Read a {@code bytes} field value from the stream. */ - public byte[] readByteArray() throws IOException { - final int size = readRawVarint32(); - if (size <= (bufferSize - bufferPos) && size > 0) { - // Fast path: We already have the bytes in a contiguous buffer, so - // just copy directly from it. - final byte[] result = - Arrays.copyOfRange(buffer, bufferPos, bufferPos + size); - bufferPos += size; - return result; - } else { - // Slow path: Build a byte array first then copy it. - return readRawBytesSlowPath(size); - } - } + public abstract byte[] readByteArray() throws IOException; /** Read a {@code bytes} field value from the stream. */ - public ByteBuffer readByteBuffer() throws IOException { - final int size = readRawVarint32(); - if (size <= (bufferSize - bufferPos) && size > 0) { - // Fast path: We already have the bytes in a contiguous buffer. - // When aliasing is enabled, we can return a ByteBuffer pointing directly - // into the underlying byte array without copy if the CodedInputStream is - // constructed from a byte array. If aliasing is disabled or the input is - // from an InputStream or ByteString, we have to make a copy of the bytes. - ByteBuffer result = input == null && !bufferIsImmutable && enableAliasing - ? ByteBuffer.wrap(buffer, bufferPos, size).slice() - : ByteBuffer.wrap(Arrays.copyOfRange( - buffer, bufferPos, bufferPos + size)); - bufferPos += size; - return result; - } else if (size == 0) { - return Internal.EMPTY_BYTE_BUFFER; - } else { - // Slow path: Build a byte array first then copy it. - return ByteBuffer.wrap(readRawBytesSlowPath(size)); - } - } + public abstract ByteBuffer readByteBuffer() throws IOException; /** Read a {@code uint32} field value from the stream. */ - public int readUInt32() throws IOException { - return readRawVarint32(); - } + public abstract int readUInt32() throws IOException; /** - * Read an enum field value from the stream. Caller is responsible - * for converting the numeric value to an actual enum. + * Read an enum field value from the stream. Caller is responsible for converting the numeric + * value to an actual enum. */ - public int readEnum() throws IOException { - return readRawVarint32(); - } + public abstract int readEnum() throws IOException; /** Read an {@code sfixed32} field value from the stream. */ - public int readSFixed32() throws IOException { - return readRawLittleEndian32(); - } + public abstract int readSFixed32() throws IOException; /** Read an {@code sfixed64} field value from the stream. */ - public long readSFixed64() throws IOException { - return readRawLittleEndian64(); - } + public abstract long readSFixed64() throws IOException; /** Read an {@code sint32} field value from the stream. */ - public int readSInt32() throws IOException { - return decodeZigZag32(readRawVarint32()); - } + public abstract int readSInt32() throws IOException; /** Read an {@code sint64} field value from the stream. */ - public long readSInt64() throws IOException { - return decodeZigZag64(readRawVarint64()); - } + public abstract long readSInt64() throws IOException; // ================================================================= - /** - * Read a raw Varint from the stream. If larger than 32 bits, discard the - * upper bits. - */ - public int readRawVarint32() throws IOException { - // See implementation notes for readRawVarint64 - fastpath: { - int pos = bufferPos; - - if (bufferSize == pos) { - break fastpath; - } - - final byte[] buffer = this.buffer; - int x; - if ((x = buffer[pos++]) >= 0) { - bufferPos = pos; - return x; - } else if (bufferSize - pos < 9) { - break fastpath; - } else if ((x ^= (buffer[pos++] << 7)) < 0) { - x ^= (~0 << 7); - } else if ((x ^= (buffer[pos++] << 14)) >= 0) { - x ^= (~0 << 7) ^ (~0 << 14); - } else if ((x ^= (buffer[pos++] << 21)) < 0) { - x ^= (~0 << 7) ^ (~0 << 14) ^ (~0 << 21); - } else { - int y = buffer[pos++]; - x ^= y << 28; - x ^= (~0 << 7) ^ (~0 << 14) ^ (~0 << 21) ^ (~0 << 28); - if (y < 0 && - buffer[pos++] < 0 && - buffer[pos++] < 0 && - buffer[pos++] < 0 && - buffer[pos++] < 0 && - buffer[pos++] < 0) { - break fastpath; // Will throw malformedVarint() - } - } - bufferPos = pos; - return x; - } - return (int) readRawVarint64SlowPath(); - } - - private void skipRawVarint() throws IOException { - if (bufferSize - bufferPos >= 10) { - final byte[] buffer = this.buffer; - int pos = bufferPos; - for (int i = 0; i < 10; i++) { - if (buffer[pos++] >= 0) { - bufferPos = pos; - return; - } - } - } - skipRawVarintSlowPath(); - } - - private void skipRawVarintSlowPath() throws IOException { - for (int i = 0; i < 10; i++) { - if (readRawByte() >= 0) { - return; - } - } - throw InvalidProtocolBufferException.malformedVarint(); - } - - /** - * Reads a varint from the input one byte at a time, so that it does not - * read any bytes after the end of the varint. If you simply wrapped the - * stream in a CodedInputStream and used {@link #readRawVarint32(InputStream)} - * then you would probably end up reading past the end of the varint since - * CodedInputStream buffers its input. - */ - static int readRawVarint32(final InputStream input) throws IOException { - final int firstByte = input.read(); - if (firstByte == -1) { - throw InvalidProtocolBufferException.truncatedMessage(); - } - return readRawVarint32(firstByte, input); - } - - /** - * Like {@link #readRawVarint32(InputStream)}, but expects that the caller - * has already read one byte. This allows the caller to determine if EOF - * has been reached before attempting to read. - */ - public static int readRawVarint32( - final int firstByte, final InputStream input) throws IOException { - if ((firstByte & 0x80) == 0) { - return firstByte; - } - - int result = firstByte & 0x7f; - int offset = 7; - for (; offset < 32; offset += 7) { - final int b = input.read(); - if (b == -1) { - throw InvalidProtocolBufferException.truncatedMessage(); - } - result |= (b & 0x7f) << offset; - if ((b & 0x80) == 0) { - return result; - } - } - // Keep reading up to 64 bits. - for (; offset < 64; offset += 7) { - final int b = input.read(); - if (b == -1) { - throw InvalidProtocolBufferException.truncatedMessage(); - } - if ((b & 0x80) == 0) { - return result; - } - } - throw InvalidProtocolBufferException.malformedVarint(); - } + /** Read a raw Varint from the stream. If larger than 32 bits, discard the upper bits. */ + public abstract int readRawVarint32() throws IOException; /** Read a raw Varint from the stream. */ - public long readRawVarint64() throws IOException { - // Implementation notes: - // - // Optimized for one-byte values, expected to be common. - // The particular code below was selected from various candidates - // empirically, by winning VarintBenchmark. - // - // Sign extension of (signed) Java bytes is usually a nuisance, but - // we exploit it here to more easily obtain the sign of bytes read. - // Instead of cleaning up the sign extension bits by masking eagerly, - // we delay until we find the final (positive) byte, when we clear all - // accumulated bits with one xor. We depend on javac to constant fold. - fastpath: { - int pos = bufferPos; - - if (bufferSize == pos) { - break fastpath; - } - - final byte[] buffer = this.buffer; - long x; - int y; - if ((y = buffer[pos++]) >= 0) { - bufferPos = pos; - return y; - } else if (bufferSize - pos < 9) { - break fastpath; - } else if ((y ^= (buffer[pos++] << 7)) < 0) { - x = y ^ (~0 << 7); - } else if ((y ^= (buffer[pos++] << 14)) >= 0) { - x = y ^ ((~0 << 7) ^ (~0 << 14)); - } else if ((y ^= (buffer[pos++] << 21)) < 0) { - x = y ^ ((~0 << 7) ^ (~0 << 14) ^ (~0 << 21)); - } else if ((x = ((long) y) ^ ((long) buffer[pos++] << 28)) >= 0L) { - x ^= (~0L << 7) ^ (~0L << 14) ^ (~0L << 21) ^ (~0L << 28); - } else if ((x ^= ((long) buffer[pos++] << 35)) < 0L) { - x ^= (~0L << 7) ^ (~0L << 14) ^ (~0L << 21) ^ (~0L << 28) ^ (~0L << 35); - } else if ((x ^= ((long) buffer[pos++] << 42)) >= 0L) { - x ^= (~0L << 7) ^ (~0L << 14) ^ (~0L << 21) ^ (~0L << 28) ^ (~0L << 35) ^ (~0L << 42); - } else if ((x ^= ((long) buffer[pos++] << 49)) < 0L) { - x ^= (~0L << 7) ^ (~0L << 14) ^ (~0L << 21) ^ (~0L << 28) ^ (~0L << 35) ^ (~0L << 42) - ^ (~0L << 49); - } else { - x ^= ((long) buffer[pos++] << 56); - x ^= (~0L << 7) ^ (~0L << 14) ^ (~0L << 21) ^ (~0L << 28) ^ (~0L << 35) ^ (~0L << 42) - ^ (~0L << 49) ^ (~0L << 56); - if (x < 0L) { - if (buffer[pos++] < 0L) { - break fastpath; // Will throw malformedVarint() - } - } - } - bufferPos = pos; - return x; - } - return readRawVarint64SlowPath(); - } + public abstract long readRawVarint64() throws IOException; /** Variant of readRawVarint64 for when uncomfortably close to the limit. */ /* Visible for testing */ - long readRawVarint64SlowPath() throws IOException { - long result = 0; - for (int shift = 0; shift < 64; shift += 7) { - final byte b = readRawByte(); - result |= (long) (b & 0x7F) << shift; - if ((b & 0x80) == 0) { - return result; - } - } - throw InvalidProtocolBufferException.malformedVarint(); - } + abstract long readRawVarint64SlowPath() throws IOException; /** Read a 32-bit little-endian integer from the stream. */ - public int readRawLittleEndian32() throws IOException { - int pos = bufferPos; - - // hand-inlined ensureAvailable(4); - if (bufferSize - pos < 4) { - refillBuffer(4); - pos = bufferPos; - } - - final byte[] buffer = this.buffer; - bufferPos = pos + 4; - return (((buffer[pos] & 0xff)) | - ((buffer[pos + 1] & 0xff) << 8) | - ((buffer[pos + 2] & 0xff) << 16) | - ((buffer[pos + 3] & 0xff) << 24)); - } + public abstract int readRawLittleEndian32() throws IOException; /** Read a 64-bit little-endian integer from the stream. */ - public long readRawLittleEndian64() throws IOException { - int pos = bufferPos; - - // hand-inlined ensureAvailable(8); - if (bufferSize - pos < 8) { - refillBuffer(8); - pos = bufferPos; - } - - final byte[] buffer = this.buffer; - bufferPos = pos + 8; - return ((((long) buffer[pos] & 0xffL)) | - (((long) buffer[pos + 1] & 0xffL) << 8) | - (((long) buffer[pos + 2] & 0xffL) << 16) | - (((long) buffer[pos + 3] & 0xffL) << 24) | - (((long) buffer[pos + 4] & 0xffL) << 32) | - (((long) buffer[pos + 5] & 0xffL) << 40) | - (((long) buffer[pos + 6] & 0xffL) << 48) | - (((long) buffer[pos + 7] & 0xffL) << 56)); - } - - /** - * Decode a ZigZag-encoded 32-bit value. ZigZag encodes signed integers - * into values that can be efficiently encoded with varint. (Otherwise, - * negative values must be sign-extended to 64 bits to be varint encoded, - * thus always taking 10 bytes on the wire.) - * - * @param n An unsigned 32-bit integer, stored in a signed int because - * Java has no explicit unsigned support. - * @return A signed 32-bit integer. - */ - public static int decodeZigZag32(final int n) { - return (n >>> 1) ^ -(n & 1); - } - - /** - * Decode a ZigZag-encoded 64-bit value. ZigZag encodes signed integers - * into values that can be efficiently encoded with varint. (Otherwise, - * negative values must be sign-extended to 64 bits to be varint encoded, - * thus always taking 10 bytes on the wire.) - * - * @param n An unsigned 64-bit integer, stored in a signed int because - * Java has no explicit unsigned support. - * @return A signed 64-bit integer. - */ - public static long decodeZigZag64(final long n) { - return (n >>> 1) ^ -(n & 1); - } + public abstract long readRawLittleEndian64() throws IOException; // ----------------------------------------------------------------- - private final byte[] buffer; - private final boolean bufferIsImmutable; - private int bufferSize; - private int bufferSizeAfterLimit; - private int bufferPos; - private final InputStream input; - private int lastTag; - private boolean enableAliasing = false; - /** - * The total number of bytes read before the current buffer. The total - * bytes read up to the current position can be computed as - * {@code totalBytesRetired + bufferPos}. This value may be negative if - * reading started in the middle of the current buffer (e.g. if the - * constructor that takes a byte array and an offset was used). + * Enables {@link ByteString} aliasing of the underlying buffer, trading off on buffer pinning for + * data copies. Only valid for buffer-backed streams. */ - private int totalBytesRetired; - - /** The absolute position of the end of the current message. */ - private int currentLimit = Integer.MAX_VALUE; - - /** See setRecursionLimit() */ - private int recursionDepth; - private int recursionLimit = DEFAULT_RECURSION_LIMIT; - - /** See setSizeLimit() */ - private int sizeLimit = DEFAULT_SIZE_LIMIT; - - private static final int DEFAULT_RECURSION_LIMIT = 100; - private static final int DEFAULT_SIZE_LIMIT = 64 << 20; // 64MB - private static final int BUFFER_SIZE = 4096; - - private CodedInputStream( - final byte[] buffer, final int off, final int len, boolean bufferIsImmutable) { - this.buffer = buffer; - bufferSize = off + len; - bufferPos = off; - totalBytesRetired = -off; - input = null; - this.bufferIsImmutable = bufferIsImmutable; - } - - private CodedInputStream(final InputStream input, int bufferSize) { - buffer = new byte[bufferSize]; - bufferSize = 0; - bufferPos = 0; - totalBytesRetired = 0; - this.input = input; - bufferIsImmutable = false; - } - - public void enableAliasing(boolean enabled) { - this.enableAliasing = enabled; - } + public abstract void enableAliasing(boolean enabled); /** - * Set the maximum message recursion depth. In order to prevent malicious - * messages from causing stack overflows, {@code CodedInputStream} limits - * how deeply messages may be nested. The default limit is 64. + * Set the maximum message recursion depth. In order to prevent malicious messages from causing + * stack overflows, {@code CodedInputStream} limits how deeply messages may be nested. The default + * limit is 64. * * @return the old limit. */ - public int setRecursionLimit(final int limit) { + public final int setRecursionLimit(final int limit) { if (limit < 0) { - throw new IllegalArgumentException( - "Recursion limit cannot be negative: " + limit); + throw new IllegalArgumentException("Recursion limit cannot be negative: " + limit); } final int oldLimit = recursionLimit; recursionLimit = limit; @@ -929,25 +349,22 @@ public final class CodedInputStream { } /** - * Set the maximum message size. In order to prevent malicious - * messages from exhausting memory or causing integer overflows, - * {@code CodedInputStream} limits how large a message may be. - * The default limit is 64MB. You should set this limit as small - * as you can without harming your app's functionality. Note that - * size limits only apply when reading from an {@code InputStream}, not - * when constructed around a raw byte array (nor with - * {@link ByteString#newCodedInput}). - *

- * If you want to read several messages from a single CodedInputStream, you - * could call {@link #resetSizeCounter()} after each one to avoid hitting the - * size limit. + * Only valid for {@link InputStream}-backed streams. + * + *

Set the maximum message size. In order to prevent malicious messages from exhausting memory + * or causing integer overflows, {@code CodedInputStream} limits how large a message may be. The + * default limit is 64MB. You should set this limit as small as you can without harming your app's + * functionality. Note that size limits only apply when reading from an {@code InputStream}, not + * when constructed around a raw byte array (nor with {@link ByteString#newCodedInput}). + * + *

If you want to read several messages from a single CodedInputStream, you could call {@link + * #resetSizeCounter()} after each one to avoid hitting the size limit. * * @return the old limit. */ - public int setSizeLimit(final int limit) { + public final int setSizeLimit(final int limit) { if (limit < 0) { - throw new IllegalArgumentException( - "Size limit cannot be negative: " + limit); + throw new IllegalArgumentException("Size limit cannot be negative: " + limit); } final int oldLimit = sizeLimit; sizeLimit = limit; @@ -955,348 +372,2524 @@ public final class CodedInputStream { } /** - * Resets the current size counter to zero (see {@link #setSizeLimit(int)}). + * Resets the current size counter to zero (see {@link #setSizeLimit(int)}). Only valid for {@link + * InputStream}-backed streams. */ - public void resetSizeCounter() { - totalBytesRetired = -bufferPos; - } + public abstract void resetSizeCounter(); /** - * Sets {@code currentLimit} to (current position) + {@code byteLimit}. This - * is called when descending into a length-delimited embedded message. + * Sets {@code currentLimit} to (current position) + {@code byteLimit}. This is called when + * descending into a length-delimited embedded message. * - *

Note that {@code pushLimit()} does NOT affect how many bytes the - * {@code CodedInputStream} reads from an underlying {@code InputStream} when - * refreshing its buffer. If you need to prevent reading past a certain - * point in the underlying {@code InputStream} (e.g. because you expect it to - * contain more data after the end of the message which you need to handle - * differently) then you must place a wrapper around your {@code InputStream} - * which limits the amount of data that can be read from it. + *

Note that {@code pushLimit()} does NOT affect how many bytes the {@code CodedInputStream} + * reads from an underlying {@code InputStream} when refreshing its buffer. If you need to prevent + * reading past a certain point in the underlying {@code InputStream} (e.g. because you expect it + * to contain more data after the end of the message which you need to handle differently) then + * you must place a wrapper around your {@code InputStream} which limits the amount of data that + * can be read from it. * * @return the old limit. */ - public int pushLimit(int byteLimit) throws InvalidProtocolBufferException { - if (byteLimit < 0) { - throw InvalidProtocolBufferException.negativeSize(); - } - byteLimit += totalBytesRetired + bufferPos; - final int oldLimit = currentLimit; - if (byteLimit > oldLimit) { - throw InvalidProtocolBufferException.truncatedMessage(); - } - currentLimit = byteLimit; - - recomputeBufferSizeAfterLimit(); - - return oldLimit; - } - - private void recomputeBufferSizeAfterLimit() { - bufferSize += bufferSizeAfterLimit; - final int bufferEnd = totalBytesRetired + bufferSize; - if (bufferEnd > currentLimit) { - // Limit is in current buffer. - bufferSizeAfterLimit = bufferEnd - currentLimit; - bufferSize -= bufferSizeAfterLimit; - } else { - bufferSizeAfterLimit = 0; - } - } + public abstract int pushLimit(int byteLimit) throws InvalidProtocolBufferException; /** * Discards the current limit, returning to the previous limit. * * @param oldLimit The old limit, as returned by {@code pushLimit}. */ - public void popLimit(final int oldLimit) { - currentLimit = oldLimit; - recomputeBufferSizeAfterLimit(); - } + public abstract void popLimit(final int oldLimit); /** - * Returns the number of bytes to be read before the current limit. - * If no limit is set, returns -1. + * Returns the number of bytes to be read before the current limit. If no limit is set, returns + * -1. */ - public int getBytesUntilLimit() { - if (currentLimit == Integer.MAX_VALUE) { - return -1; - } - - final int currentAbsolutePosition = totalBytesRetired + bufferPos; - return currentLimit - currentAbsolutePosition; - } + public abstract int getBytesUntilLimit(); /** - * Returns true if the stream has reached the end of the input. This is the - * case if either the end of the underlying input source has been reached or - * if the stream has reached a limit created using {@link #pushLimit(int)}. + * Returns true if the stream has reached the end of the input. This is the case if either the end + * of the underlying input source has been reached or if the stream has reached a limit created + * using {@link #pushLimit(int)}. */ - public boolean isAtEnd() throws IOException { - return bufferPos == bufferSize && !tryRefillBuffer(1); - } + public abstract boolean isAtEnd() throws IOException; /** - * The total bytes read up to the current position. Calling - * {@link #resetSizeCounter()} resets this value to zero. + * The total bytes read up to the current position. Calling {@link #resetSizeCounter()} resets + * this value to zero. */ - public int getTotalBytesRead() { - return totalBytesRetired + bufferPos; - } + public abstract int getTotalBytesRead(); - private interface RefillCallback { - void onRefill(); - } + /** + * Read one byte from the input. + * + * @throws InvalidProtocolBufferException The end of the stream or the current limit was reached. + */ + public abstract byte readRawByte() throws IOException; - private RefillCallback refillCallback = null; + /** + * Read a fixed size of bytes from the input. + * + * @throws InvalidProtocolBufferException The end of the stream or the current limit was reached. + */ + public abstract byte[] readRawBytes(final int size) throws IOException; /** - * Reads more bytes from the input, making at least {@code n} bytes available - * in the buffer. Caller must ensure that the requested space is not yet - * available, and that the requested space is less than BUFFER_SIZE. + * Reads and discards {@code size} bytes. * - * @throws InvalidProtocolBufferException The end of the stream or the current - * limit was reached. + * @throws InvalidProtocolBufferException The end of the stream or the current limit was reached. */ - private void refillBuffer(int n) throws IOException { - if (!tryRefillBuffer(n)) { - throw InvalidProtocolBufferException.truncatedMessage(); - } + public abstract void skipRawBytes(final int size) throws IOException; + + /** + * Decode a ZigZag-encoded 32-bit value. ZigZag encodes signed integers into values that can be + * efficiently encoded with varint. (Otherwise, negative values must be sign-extended to 64 bits + * to be varint encoded, thus always taking 10 bytes on the wire.) + * + * @param n An unsigned 32-bit integer, stored in a signed int because Java has no explicit + * unsigned support. + * @return A signed 32-bit integer. + */ + public static int decodeZigZag32(final int n) { + return (n >>> 1) ^ -(n & 1); } /** - * Tries to read more bytes from the input, making at least {@code n} bytes - * available in the buffer. Caller must ensure that the requested space is - * not yet available, and that the requested space is less than BUFFER_SIZE. + * Decode a ZigZag-encoded 64-bit value. ZigZag encodes signed integers into values that can be + * efficiently encoded with varint. (Otherwise, negative values must be sign-extended to 64 bits + * to be varint encoded, thus always taking 10 bytes on the wire.) * - * @return {@code true} if the bytes could be made available; {@code false} - * if the end of the stream or the current limit was reached. + * @param n An unsigned 64-bit integer, stored in a signed int because Java has no explicit + * unsigned support. + * @return A signed 64-bit integer. + */ + public static long decodeZigZag64(final long n) { + return (n >>> 1) ^ -(n & 1); + } + + /** + * Like {@link #readRawVarint32(InputStream)}, but expects that the caller has already read one + * byte. This allows the caller to determine if EOF has been reached before attempting to read. */ - private boolean tryRefillBuffer(int n) throws IOException { - if (bufferPos + n <= bufferSize) { - throw new IllegalStateException( - "refillBuffer() called when " + n + - " bytes were already available in buffer"); + public static int readRawVarint32(final int firstByte, final InputStream input) + throws IOException { + if ((firstByte & 0x80) == 0) { + return firstByte; } - if (totalBytesRetired + bufferPos + n > currentLimit) { - // Oops, we hit a limit. - return false; + int result = firstByte & 0x7f; + int offset = 7; + for (; offset < 32; offset += 7) { + final int b = input.read(); + if (b == -1) { + throw InvalidProtocolBufferException.truncatedMessage(); + } + result |= (b & 0x7f) << offset; + if ((b & 0x80) == 0) { + return result; + } + } + // Keep reading up to 64 bits. + for (; offset < 64; offset += 7) { + final int b = input.read(); + if (b == -1) { + throw InvalidProtocolBufferException.truncatedMessage(); + } + if ((b & 0x80) == 0) { + return result; + } + } + throw InvalidProtocolBufferException.malformedVarint(); + } + + /** + * Reads a varint from the input one byte at a time, so that it does not read any bytes after the + * end of the varint. If you simply wrapped the stream in a CodedInputStream and used {@link + * #readRawVarint32(InputStream)} then you would probably end up reading past the end of the + * varint since CodedInputStream buffers its input. + */ + static int readRawVarint32(final InputStream input) throws IOException { + final int firstByte = input.read(); + if (firstByte == -1) { + throw InvalidProtocolBufferException.truncatedMessage(); } + return readRawVarint32(firstByte, input); + } - if (refillCallback != null) { - refillCallback.onRefill(); + /** A {@link CodedInputStream} implementation that uses a backing array as the input. */ + private static final class ArrayDecoder extends CodedInputStream { + private final byte[] buffer; + private final boolean immutable; + private int limit; + private int bufferSizeAfterLimit; + private int pos; + private int startPos; + private int lastTag; + private boolean enableAliasing; + + /** The absolute position of the end of the current message. */ + private int currentLimit = Integer.MAX_VALUE; + + private ArrayDecoder(final byte[] buffer, final int offset, final int len, boolean immutable) { + this.buffer = buffer; + limit = offset + len; + pos = offset; + startPos = pos; + this.immutable = immutable; } - if (input != null) { - int pos = bufferPos; - if (pos > 0) { - if (bufferSize > pos) { - System.arraycopy(buffer, pos, buffer, 0, bufferSize - pos); - } - totalBytesRetired += pos; - bufferSize -= pos; - bufferPos = 0; + @Override + public int readTag() throws IOException { + if (isAtEnd()) { + lastTag = 0; + return 0; } - int bytesRead = input.read(buffer, bufferSize, buffer.length - bufferSize); - if (bytesRead == 0 || bytesRead < -1 || bytesRead > buffer.length) { - throw new IllegalStateException( - "InputStream#read(byte[]) returned invalid result: " + bytesRead + - "\nThe InputStream implementation is buggy."); + lastTag = readRawVarint32(); + if (WireFormat.getTagFieldNumber(lastTag) == 0) { + // If we actually read zero (or any tag number corresponding to field + // number zero), that's not a valid tag. + throw InvalidProtocolBufferException.invalidTag(); } - if (bytesRead > 0) { - bufferSize += bytesRead; - // Integer-overflow-conscious check against sizeLimit - if (totalBytesRetired + n - sizeLimit > 0) { - throw InvalidProtocolBufferException.sizeLimitExceeded(); + return lastTag; + } + + @Override + public void checkLastTagWas(final int value) throws InvalidProtocolBufferException { + if (lastTag != value) { + throw InvalidProtocolBufferException.invalidEndTag(); + } + } + + @Override + public int getLastTag() { + return lastTag; + } + + @Override + public boolean skipField(final int tag) throws IOException { + switch (WireFormat.getTagWireType(tag)) { + case WireFormat.WIRETYPE_VARINT: + skipRawVarint(); + return true; + case WireFormat.WIRETYPE_FIXED64: + skipRawBytes(FIXED_64_SIZE); + return true; + case WireFormat.WIRETYPE_LENGTH_DELIMITED: + skipRawBytes(readRawVarint32()); + return true; + case WireFormat.WIRETYPE_START_GROUP: + skipMessage(); + checkLastTagWas( + WireFormat.makeTag(WireFormat.getTagFieldNumber(tag), WireFormat.WIRETYPE_END_GROUP)); + return true; + case WireFormat.WIRETYPE_END_GROUP: + return false; + case WireFormat.WIRETYPE_FIXED32: + skipRawBytes(FIXED_32_SIZE); + return true; + default: + throw InvalidProtocolBufferException.invalidWireType(); + } + } + + @Override + public boolean skipField(final int tag, final CodedOutputStream output) throws IOException { + switch (WireFormat.getTagWireType(tag)) { + case WireFormat.WIRETYPE_VARINT: + { + long value = readInt64(); + output.writeRawVarint32(tag); + output.writeUInt64NoTag(value); + return true; + } + case WireFormat.WIRETYPE_FIXED64: + { + long value = readRawLittleEndian64(); + output.writeRawVarint32(tag); + output.writeFixed64NoTag(value); + return true; + } + case WireFormat.WIRETYPE_LENGTH_DELIMITED: + { + ByteString value = readBytes(); + output.writeRawVarint32(tag); + output.writeBytesNoTag(value); + return true; + } + case WireFormat.WIRETYPE_START_GROUP: + { + output.writeRawVarint32(tag); + skipMessage(output); + int endtag = + WireFormat.makeTag( + WireFormat.getTagFieldNumber(tag), WireFormat.WIRETYPE_END_GROUP); + checkLastTagWas(endtag); + output.writeRawVarint32(endtag); + return true; + } + case WireFormat.WIRETYPE_END_GROUP: + { + return false; + } + case WireFormat.WIRETYPE_FIXED32: + { + int value = readRawLittleEndian32(); + output.writeRawVarint32(tag); + output.writeFixed32NoTag(value); + return true; + } + default: + throw InvalidProtocolBufferException.invalidWireType(); + } + } + + @Override + public void skipMessage() throws IOException { + while (true) { + final int tag = readTag(); + if (tag == 0 || !skipField(tag)) { + return; } - recomputeBufferSizeAfterLimit(); - return (bufferSize >= n) ? true : tryRefillBuffer(n); } } - return false; - } + @Override + public void skipMessage(CodedOutputStream output) throws IOException { + while (true) { + final int tag = readTag(); + if (tag == 0 || !skipField(tag, output)) { + return; + } + } + } - /** - * Read one byte from the input. - * - * @throws InvalidProtocolBufferException The end of the stream or the current - * limit was reached. - */ - public byte readRawByte() throws IOException { - if (bufferPos == bufferSize) { - refillBuffer(1); + + // ----------------------------------------------------------------- + + @Override + public double readDouble() throws IOException { + return Double.longBitsToDouble(readRawLittleEndian64()); } - return buffer[bufferPos++]; - } - /** - * Read a fixed size of bytes from the input. - * - * @throws InvalidProtocolBufferException The end of the stream or the current - * limit was reached. - */ - public byte[] readRawBytes(final int size) throws IOException { - final int pos = bufferPos; - if (size <= (bufferSize - pos) && size > 0) { - bufferPos = pos + size; - return Arrays.copyOfRange(buffer, pos, pos + size); - } else { - return readRawBytesSlowPath(size); + @Override + public float readFloat() throws IOException { + return Float.intBitsToFloat(readRawLittleEndian32()); } - } - /** - * Exactly like readRawBytes, but caller must have already checked the fast - * path: (size <= (bufferSize - pos) && size > 0) - */ - private byte[] readRawBytesSlowPath(final int size) throws IOException { - if (size <= 0) { + @Override + public long readUInt64() throws IOException { + return readRawVarint64(); + } + + @Override + public long readInt64() throws IOException { + return readRawVarint64(); + } + + @Override + public int readInt32() throws IOException { + return readRawVarint32(); + } + + @Override + public long readFixed64() throws IOException { + return readRawLittleEndian64(); + } + + @Override + public int readFixed32() throws IOException { + return readRawLittleEndian32(); + } + + @Override + public boolean readBool() throws IOException { + return readRawVarint64() != 0; + } + + @Override + public String readString() throws IOException { + final int size = readRawVarint32(); + if (size > 0 && size <= (limit - pos)) { + // Fast path: We already have the bytes in a contiguous buffer, so + // just copy directly from it. + final String result = new String(buffer, pos, size, UTF_8); + pos += size; + return result; + } + if (size == 0) { - return Internal.EMPTY_BYTE_ARRAY; - } else { + return ""; + } + if (size < 0) { throw InvalidProtocolBufferException.negativeSize(); } + throw InvalidProtocolBufferException.truncatedMessage(); } - // Verify that the message size so far has not exceeded sizeLimit. - int currentMessageSize = totalBytesRetired + bufferPos + size; - if (currentMessageSize > sizeLimit) { - throw InvalidProtocolBufferException.sizeLimitExceeded(); - } + @Override + public String readStringRequireUtf8() throws IOException { + final int size = readRawVarint32(); + if (size > 0 && size <= (limit - pos)) { + // TODO(martinrb): We could save a pass by validating while decoding. + if (!Utf8.isValidUtf8(buffer, pos, pos + size)) { + throw InvalidProtocolBufferException.invalidUtf8(); + } + final int tempPos = pos; + pos += size; + return new String(buffer, tempPos, size, UTF_8); + } - // Verify that the message size so far has not exceeded currentLimit. - if (currentMessageSize > currentLimit) { - // Read to the end of the stream anyway. - skipRawBytes(currentLimit - totalBytesRetired - bufferPos); + if (size == 0) { + return ""; + } + if (size <= 0) { + throw InvalidProtocolBufferException.negativeSize(); + } throw InvalidProtocolBufferException.truncatedMessage(); } - // We need the input stream to proceed. - if (input == null) { - throw InvalidProtocolBufferException.truncatedMessage(); + @Override + public void readGroup( + final int fieldNumber, + final MessageLite.Builder builder, + final ExtensionRegistryLite extensionRegistry) + throws IOException { + if (recursionDepth >= recursionLimit) { + throw InvalidProtocolBufferException.recursionLimitExceeded(); + } + ++recursionDepth; + builder.mergeFrom(this, extensionRegistry); + checkLastTagWas(WireFormat.makeTag(fieldNumber, WireFormat.WIRETYPE_END_GROUP)); + --recursionDepth; } - final int originalBufferPos = bufferPos; - final int bufferedBytes = bufferSize - bufferPos; - // Mark the current buffer consumed. - totalBytesRetired += bufferSize; - bufferPos = 0; - bufferSize = 0; + @Override + public T readGroup( + final int fieldNumber, + final Parser parser, + final ExtensionRegistryLite extensionRegistry) + throws IOException { + if (recursionDepth >= recursionLimit) { + throw InvalidProtocolBufferException.recursionLimitExceeded(); + } + ++recursionDepth; + T result = parser.parsePartialFrom(this, extensionRegistry); + checkLastTagWas(WireFormat.makeTag(fieldNumber, WireFormat.WIRETYPE_END_GROUP)); + --recursionDepth; + return result; + } - // Determine the number of bytes we need to read from the input stream. - int sizeLeft = size - bufferedBytes; - // TODO(nathanmittler): Consider using a value larger than BUFFER_SIZE. - if (sizeLeft < BUFFER_SIZE || sizeLeft <= input.available()) { - // Either the bytes we need are known to be available, or the required buffer is - // within an allowed threshold - go ahead and allocate the buffer now. - final byte[] bytes = new byte[size]; + @Deprecated + @Override + public void readUnknownGroup(final int fieldNumber, final MessageLite.Builder builder) + throws IOException { + readGroup(fieldNumber, builder, ExtensionRegistryLite.getEmptyRegistry()); + } - // Copy all of the buffered bytes to the result buffer. - System.arraycopy(buffer, originalBufferPos, bytes, 0, bufferedBytes); + @Override + public void readMessage( + final MessageLite.Builder builder, final ExtensionRegistryLite extensionRegistry) + throws IOException { + final int length = readRawVarint32(); + if (recursionDepth >= recursionLimit) { + throw InvalidProtocolBufferException.recursionLimitExceeded(); + } + final int oldLimit = pushLimit(length); + ++recursionDepth; + builder.mergeFrom(this, extensionRegistry); + checkLastTagWas(0); + --recursionDepth; + popLimit(oldLimit); + } - // Fill the remaining bytes from the input stream. - int pos = bufferedBytes; - while (pos < bytes.length) { - int n = input.read(bytes, pos, size - pos); - if (n == -1) { - throw InvalidProtocolBufferException.truncatedMessage(); - } - totalBytesRetired += n; - pos += n; + + @Override + public T readMessage( + final Parser parser, final ExtensionRegistryLite extensionRegistry) throws IOException { + int length = readRawVarint32(); + if (recursionDepth >= recursionLimit) { + throw InvalidProtocolBufferException.recursionLimitExceeded(); + } + final int oldLimit = pushLimit(length); + ++recursionDepth; + T result = parser.parsePartialFrom(this, extensionRegistry); + checkLastTagWas(0); + --recursionDepth; + popLimit(oldLimit); + return result; + } + + @Override + public ByteString readBytes() throws IOException { + final int size = readRawVarint32(); + if (size > 0 && size <= (limit - pos)) { + // Fast path: We already have the bytes in a contiguous buffer, so + // just copy directly from it. + final ByteString result = + immutable && enableAliasing + ? ByteString.wrap(buffer, pos, size) + : ByteString.copyFrom(buffer, pos, size); + pos += size; + return result; + } + if (size == 0) { + return ByteString.EMPTY; } + // Slow path: Build a byte array first then copy it. + return ByteString.wrap(readRawBytes(size)); + } - return bytes; + @Override + public byte[] readByteArray() throws IOException { + final int size = readRawVarint32(); + return readRawBytes(size); } - // The size is very large. For security reasons, we can't allocate the - // entire byte array yet. The size comes directly from the input, so a - // maliciously-crafted message could provide a bogus very large size in - // order to trick the app into allocating a lot of memory. We avoid this - // by allocating and reading only a small chunk at a time, so that the - // malicious message must actually *be* extremely large to cause - // problems. Meanwhile, we limit the allowed size of a message elsewhere. - final List chunks = new ArrayList(); - - while (sizeLeft > 0) { - // TODO(nathanmittler): Consider using a value larger than BUFFER_SIZE. - final byte[] chunk = new byte[Math.min(sizeLeft, BUFFER_SIZE)]; - int pos = 0; - while (pos < chunk.length) { - final int n = input.read(chunk, pos, chunk.length - pos); - if (n == -1) { - throw InvalidProtocolBufferException.truncatedMessage(); - } - totalBytesRetired += n; - pos += n; + @Override + public ByteBuffer readByteBuffer() throws IOException { + final int size = readRawVarint32(); + if (size > 0 && size <= (limit - pos)) { + // Fast path: We already have the bytes in a contiguous buffer. + // When aliasing is enabled, we can return a ByteBuffer pointing directly + // into the underlying byte array without copy if the CodedInputStream is + // constructed from a byte array. If aliasing is disabled or the input is + // from an InputStream or ByteString, we have to make a copy of the bytes. + ByteBuffer result = + !immutable && enableAliasing + ? ByteBuffer.wrap(buffer, pos, size).slice() + : ByteBuffer.wrap(Arrays.copyOfRange(buffer, pos, pos + size)); + pos += size; + // TODO(nathanmittler): Investigate making the ByteBuffer be made read-only + return result; + } + + if (size == 0) { + return EMPTY_BYTE_BUFFER; + } + if (size < 0) { + throw InvalidProtocolBufferException.negativeSize(); } - sizeLeft -= chunk.length; - chunks.add(chunk); + throw InvalidProtocolBufferException.truncatedMessage(); } - // OK, got everything. Now concatenate it all into one buffer. - final byte[] bytes = new byte[size]; + @Override + public int readUInt32() throws IOException { + return readRawVarint32(); + } - // Start by copying the leftover bytes from this.buffer. - System.arraycopy(buffer, originalBufferPos, bytes, 0, bufferedBytes); + @Override + public int readEnum() throws IOException { + return readRawVarint32(); + } - // And now all the chunks. - int pos = bufferedBytes; - for (final byte[] chunk : chunks) { - System.arraycopy(chunk, 0, bytes, pos, chunk.length); - pos += chunk.length; + @Override + public int readSFixed32() throws IOException { + return readRawLittleEndian32(); } - // Done. - return bytes; - } + @Override + public long readSFixed64() throws IOException { + return readRawLittleEndian64(); + } - /** - * Reads and discards {@code size} bytes. - * - * @throws InvalidProtocolBufferException The end of the stream or the current - * limit was reached. - */ - public void skipRawBytes(final int size) throws IOException { - if (size <= (bufferSize - bufferPos) && size >= 0) { - // We have all the bytes we need already. - bufferPos += size; - } else { - skipRawBytesSlowPath(size); + @Override + public int readSInt32() throws IOException { + return decodeZigZag32(readRawVarint32()); } - } - /** - * Exactly like skipRawBytes, but caller must have already checked the fast - * path: (size <= (bufferSize - pos) && size >= 0) - */ - private void skipRawBytesSlowPath(final int size) throws IOException { - if (size < 0) { - throw InvalidProtocolBufferException.negativeSize(); + @Override + public long readSInt64() throws IOException { + return decodeZigZag64(readRawVarint64()); } - if (totalBytesRetired + bufferPos + size > currentLimit) { - // Read to the end of the stream anyway. - skipRawBytes(currentLimit - totalBytesRetired - bufferPos); - // Then fail. - throw InvalidProtocolBufferException.truncatedMessage(); + // ================================================================= + + @Override + public int readRawVarint32() throws IOException { + // See implementation notes for readRawVarint64 + fastpath: + { + int tempPos = pos; + + if (limit == tempPos) { + break fastpath; + } + + final byte[] buffer = this.buffer; + int x; + if ((x = buffer[tempPos++]) >= 0) { + pos = tempPos; + return x; + } else if (limit - tempPos < 9) { + break fastpath; + } else if ((x ^= (buffer[tempPos++] << 7)) < 0) { + x ^= (~0 << 7); + } else if ((x ^= (buffer[tempPos++] << 14)) >= 0) { + x ^= (~0 << 7) ^ (~0 << 14); + } else if ((x ^= (buffer[tempPos++] << 21)) < 0) { + x ^= (~0 << 7) ^ (~0 << 14) ^ (~0 << 21); + } else { + int y = buffer[tempPos++]; + x ^= y << 28; + x ^= (~0 << 7) ^ (~0 << 14) ^ (~0 << 21) ^ (~0 << 28); + if (y < 0 + && buffer[tempPos++] < 0 + && buffer[tempPos++] < 0 + && buffer[tempPos++] < 0 + && buffer[tempPos++] < 0 + && buffer[tempPos++] < 0) { + break fastpath; // Will throw malformedVarint() + } + } + pos = tempPos; + return x; + } + return (int) readRawVarint64SlowPath(); + } + + private void skipRawVarint() throws IOException { + if (limit - pos >= MAX_VARINT_SIZE) { + skipRawVarintFastPath(); + } else { + skipRawVarintSlowPath(); + } } - // Skipping more bytes than are in the buffer. First skip what we have. - int pos = bufferSize - bufferPos; - bufferPos = bufferSize; + private void skipRawVarintFastPath() throws IOException { + for (int i = 0; i < MAX_VARINT_SIZE; i++) { + if (buffer[pos++] >= 0) { + return; + } + } + throw InvalidProtocolBufferException.malformedVarint(); + } - // Keep refilling the buffer until we get to the point we wanted to skip to. - // This has the side effect of ensuring the limits are updated correctly. - refillBuffer(1); - while (size - pos > bufferSize) { - pos += bufferSize; - bufferPos = bufferSize; - refillBuffer(1); + private void skipRawVarintSlowPath() throws IOException { + for (int i = 0; i < MAX_VARINT_SIZE; i++) { + if (readRawByte() >= 0) { + return; + } + } + throw InvalidProtocolBufferException.malformedVarint(); + } + + @Override + public long readRawVarint64() throws IOException { + // Implementation notes: + // + // Optimized for one-byte values, expected to be common. + // The particular code below was selected from various candidates + // empirically, by winning VarintBenchmark. + // + // Sign extension of (signed) Java bytes is usually a nuisance, but + // we exploit it here to more easily obtain the sign of bytes read. + // Instead of cleaning up the sign extension bits by masking eagerly, + // we delay until we find the final (positive) byte, when we clear all + // accumulated bits with one xor. We depend on javac to constant fold. + fastpath: + { + int tempPos = pos; + + if (limit == tempPos) { + break fastpath; + } + + final byte[] buffer = this.buffer; + long x; + int y; + if ((y = buffer[tempPos++]) >= 0) { + pos = tempPos; + return y; + } else if (limit - tempPos < 9) { + break fastpath; + } else if ((y ^= (buffer[tempPos++] << 7)) < 0) { + x = y ^ (~0 << 7); + } else if ((y ^= (buffer[tempPos++] << 14)) >= 0) { + x = y ^ ((~0 << 7) ^ (~0 << 14)); + } else if ((y ^= (buffer[tempPos++] << 21)) < 0) { + x = y ^ ((~0 << 7) ^ (~0 << 14) ^ (~0 << 21)); + } else if ((x = y ^ ((long) buffer[tempPos++] << 28)) >= 0L) { + x ^= (~0L << 7) ^ (~0L << 14) ^ (~0L << 21) ^ (~0L << 28); + } else if ((x ^= ((long) buffer[tempPos++] << 35)) < 0L) { + x ^= (~0L << 7) ^ (~0L << 14) ^ (~0L << 21) ^ (~0L << 28) ^ (~0L << 35); + } else if ((x ^= ((long) buffer[tempPos++] << 42)) >= 0L) { + x ^= (~0L << 7) ^ (~0L << 14) ^ (~0L << 21) ^ (~0L << 28) ^ (~0L << 35) ^ (~0L << 42); + } else if ((x ^= ((long) buffer[tempPos++] << 49)) < 0L) { + x ^= + (~0L << 7) + ^ (~0L << 14) + ^ (~0L << 21) + ^ (~0L << 28) + ^ (~0L << 35) + ^ (~0L << 42) + ^ (~0L << 49); + } else { + x ^= ((long) buffer[tempPos++] << 56); + x ^= + (~0L << 7) + ^ (~0L << 14) + ^ (~0L << 21) + ^ (~0L << 28) + ^ (~0L << 35) + ^ (~0L << 42) + ^ (~0L << 49) + ^ (~0L << 56); + if (x < 0L) { + if (buffer[tempPos++] < 0L) { + break fastpath; // Will throw malformedVarint() + } + } + } + pos = tempPos; + return x; + } + return readRawVarint64SlowPath(); + } + + @Override + long readRawVarint64SlowPath() throws IOException { + long result = 0; + for (int shift = 0; shift < 64; shift += 7) { + final byte b = readRawByte(); + result |= (long) (b & 0x7F) << shift; + if ((b & 0x80) == 0) { + return result; + } + } + throw InvalidProtocolBufferException.malformedVarint(); } - bufferPos = size - pos; + @Override + public int readRawLittleEndian32() throws IOException { + int tempPos = pos; + + if (limit - tempPos < FIXED_32_SIZE) { + throw InvalidProtocolBufferException.truncatedMessage(); + } + + final byte[] buffer = this.buffer; + pos = tempPos + FIXED_32_SIZE; + return (((buffer[tempPos] & 0xff)) + | ((buffer[tempPos + 1] & 0xff) << 8) + | ((buffer[tempPos + 2] & 0xff) << 16) + | ((buffer[tempPos + 3] & 0xff) << 24)); + } + + @Override + public long readRawLittleEndian64() throws IOException { + int tempPos = pos; + + if (limit - tempPos < FIXED_64_SIZE) { + throw InvalidProtocolBufferException.truncatedMessage(); + } + + final byte[] buffer = this.buffer; + pos = tempPos + FIXED_64_SIZE; + return (((buffer[tempPos] & 0xffL)) + | ((buffer[tempPos + 1] & 0xffL) << 8) + | ((buffer[tempPos + 2] & 0xffL) << 16) + | ((buffer[tempPos + 3] & 0xffL) << 24) + | ((buffer[tempPos + 4] & 0xffL) << 32) + | ((buffer[tempPos + 5] & 0xffL) << 40) + | ((buffer[tempPos + 6] & 0xffL) << 48) + | ((buffer[tempPos + 7] & 0xffL) << 56)); + } + + @Override + public void enableAliasing(boolean enabled) { + this.enableAliasing = enabled; + } + + @Override + public void resetSizeCounter() { + startPos = pos; + } + + @Override + public int pushLimit(int byteLimit) throws InvalidProtocolBufferException { + if (byteLimit < 0) { + throw InvalidProtocolBufferException.negativeSize(); + } + byteLimit += getTotalBytesRead(); + final int oldLimit = currentLimit; + if (byteLimit > oldLimit) { + throw InvalidProtocolBufferException.truncatedMessage(); + } + currentLimit = byteLimit; + + recomputeBufferSizeAfterLimit(); + + return oldLimit; + } + + private void recomputeBufferSizeAfterLimit() { + limit += bufferSizeAfterLimit; + final int bufferEnd = limit - startPos; + if (bufferEnd > currentLimit) { + // Limit is in current buffer. + bufferSizeAfterLimit = bufferEnd - currentLimit; + limit -= bufferSizeAfterLimit; + } else { + bufferSizeAfterLimit = 0; + } + } + + @Override + public void popLimit(final int oldLimit) { + currentLimit = oldLimit; + recomputeBufferSizeAfterLimit(); + } + + @Override + public int getBytesUntilLimit() { + if (currentLimit == Integer.MAX_VALUE) { + return -1; + } + + return currentLimit - getTotalBytesRead(); + } + + @Override + public boolean isAtEnd() throws IOException { + return pos == limit; + } + + @Override + public int getTotalBytesRead() { + return pos - startPos; + } + + @Override + public byte readRawByte() throws IOException { + if (pos == limit) { + throw InvalidProtocolBufferException.truncatedMessage(); + } + return buffer[pos++]; + } + + @Override + public byte[] readRawBytes(final int length) throws IOException { + if (length > 0 && length <= (limit - pos)) { + final int tempPos = pos; + pos += length; + return Arrays.copyOfRange(buffer, tempPos, pos); + } + + if (length <= 0) { + if (length == 0) { + return Internal.EMPTY_BYTE_ARRAY; + } else { + throw InvalidProtocolBufferException.negativeSize(); + } + } + throw InvalidProtocolBufferException.truncatedMessage(); + } + + @Override + public void skipRawBytes(final int length) throws IOException { + if (length >= 0 && length <= (limit - pos)) { + // We have all the bytes we need already. + pos += length; + return; + } + + if (length < 0) { + throw InvalidProtocolBufferException.negativeSize(); + } + throw InvalidProtocolBufferException.truncatedMessage(); + } + } + + /** + * A {@link CodedInputStream} implementation that uses a backing direct ByteBuffer as the input. + * Requires the use of {@code sun.misc.Unsafe} to perform fast reads on the buffer. + */ + private static final class UnsafeDirectNioDecoder extends CodedInputStream { + /** The direct buffer that is backing this stream. */ + private final ByteBuffer buffer; + + /** + * If {@code true}, indicates that the buffer is backing a {@link ByteString} and is therefore + * considered to be an immutable input source. + */ + private final boolean immutable; + + /** The unsafe address of the content of {@link #buffer}. */ + private final long address; + + /** The unsafe address of the current read limit of the buffer. */ + private long limit; + + /** The unsafe address of the current read position of the buffer. */ + private long pos; + + /** The unsafe address of the starting read position. */ + private long startPos; + + /** The amount of available data in the buffer beyond {@link #limit}. */ + private int bufferSizeAfterLimit; + + /** The last tag that was read from this stream. */ + private int lastTag; + + /** + * If {@code true}, indicates that calls to read {@link ByteString} or {@code byte[]} + * may return slices of the underlying buffer, rather than copies. + */ + private boolean enableAliasing; + + /** The absolute position of the end of the current message. */ + private int currentLimit = Integer.MAX_VALUE; + + static boolean isSupported() { + return UnsafeUtil.hasUnsafeByteBufferOperations(); + } + + private UnsafeDirectNioDecoder(ByteBuffer buffer, boolean immutable) { + this.buffer = buffer; + address = UnsafeUtil.addressOffset(buffer); + limit = address + buffer.limit(); + pos = address + buffer.position(); + startPos = pos; + this.immutable = immutable; + } + + @Override + public int readTag() throws IOException { + if (isAtEnd()) { + lastTag = 0; + return 0; + } + + lastTag = readRawVarint32(); + if (WireFormat.getTagFieldNumber(lastTag) == 0) { + // If we actually read zero (or any tag number corresponding to field + // number zero), that's not a valid tag. + throw InvalidProtocolBufferException.invalidTag(); + } + return lastTag; + } + + @Override + public void checkLastTagWas(final int value) throws InvalidProtocolBufferException { + if (lastTag != value) { + throw InvalidProtocolBufferException.invalidEndTag(); + } + } + + @Override + public int getLastTag() { + return lastTag; + } + + @Override + public boolean skipField(final int tag) throws IOException { + switch (WireFormat.getTagWireType(tag)) { + case WireFormat.WIRETYPE_VARINT: + skipRawVarint(); + return true; + case WireFormat.WIRETYPE_FIXED64: + skipRawBytes(FIXED_64_SIZE); + return true; + case WireFormat.WIRETYPE_LENGTH_DELIMITED: + skipRawBytes(readRawVarint32()); + return true; + case WireFormat.WIRETYPE_START_GROUP: + skipMessage(); + checkLastTagWas( + WireFormat.makeTag(WireFormat.getTagFieldNumber(tag), WireFormat.WIRETYPE_END_GROUP)); + return true; + case WireFormat.WIRETYPE_END_GROUP: + return false; + case WireFormat.WIRETYPE_FIXED32: + skipRawBytes(FIXED_32_SIZE); + return true; + default: + throw InvalidProtocolBufferException.invalidWireType(); + } + } + + @Override + public boolean skipField(final int tag, final CodedOutputStream output) throws IOException { + switch (WireFormat.getTagWireType(tag)) { + case WireFormat.WIRETYPE_VARINT: + { + long value = readInt64(); + output.writeRawVarint32(tag); + output.writeUInt64NoTag(value); + return true; + } + case WireFormat.WIRETYPE_FIXED64: + { + long value = readRawLittleEndian64(); + output.writeRawVarint32(tag); + output.writeFixed64NoTag(value); + return true; + } + case WireFormat.WIRETYPE_LENGTH_DELIMITED: + { + ByteString value = readBytes(); + output.writeRawVarint32(tag); + output.writeBytesNoTag(value); + return true; + } + case WireFormat.WIRETYPE_START_GROUP: + { + output.writeRawVarint32(tag); + skipMessage(output); + int endtag = + WireFormat.makeTag( + WireFormat.getTagFieldNumber(tag), WireFormat.WIRETYPE_END_GROUP); + checkLastTagWas(endtag); + output.writeRawVarint32(endtag); + return true; + } + case WireFormat.WIRETYPE_END_GROUP: + { + return false; + } + case WireFormat.WIRETYPE_FIXED32: + { + int value = readRawLittleEndian32(); + output.writeRawVarint32(tag); + output.writeFixed32NoTag(value); + return true; + } + default: + throw InvalidProtocolBufferException.invalidWireType(); + } + } + + @Override + public void skipMessage() throws IOException { + while (true) { + final int tag = readTag(); + if (tag == 0 || !skipField(tag)) { + return; + } + } + } + + @Override + public void skipMessage(CodedOutputStream output) throws IOException { + while (true) { + final int tag = readTag(); + if (tag == 0 || !skipField(tag, output)) { + return; + } + } + } + + + // ----------------------------------------------------------------- + + @Override + public double readDouble() throws IOException { + return Double.longBitsToDouble(readRawLittleEndian64()); + } + + @Override + public float readFloat() throws IOException { + return Float.intBitsToFloat(readRawLittleEndian32()); + } + + @Override + public long readUInt64() throws IOException { + return readRawVarint64(); + } + + @Override + public long readInt64() throws IOException { + return readRawVarint64(); + } + + @Override + public int readInt32() throws IOException { + return readRawVarint32(); + } + + @Override + public long readFixed64() throws IOException { + return readRawLittleEndian64(); + } + + @Override + public int readFixed32() throws IOException { + return readRawLittleEndian32(); + } + + @Override + public boolean readBool() throws IOException { + return readRawVarint64() != 0; + } + + @Override + public String readString() throws IOException { + final int size = readRawVarint32(); + if (size > 0 && size <= remaining()) { + // TODO(nathanmittler): Is there a way to avoid this copy? + byte[] bytes = copyToArray(pos, pos + size); + String result = new String(bytes, UTF_8); + pos += size; + return result; + } + + if (size == 0) { + return ""; + } + if (size < 0) { + throw InvalidProtocolBufferException.negativeSize(); + } + throw InvalidProtocolBufferException.truncatedMessage(); + } + + @Override + public String readStringRequireUtf8() throws IOException { + final int size = readRawVarint32(); + if (size >= 0 && size <= remaining()) { + // TODO(nathanmittler): Is there a way to avoid this copy? + byte[] bytes = copyToArray(pos, pos + size); + // TODO(martinrb): We could save a pass by validating while decoding. + if (!Utf8.isValidUtf8(bytes)) { + throw InvalidProtocolBufferException.invalidUtf8(); + } + + String result = new String(bytes, UTF_8); + pos += size; + return result; + } + + if (size == 0) { + return ""; + } + if (size <= 0) { + throw InvalidProtocolBufferException.negativeSize(); + } + throw InvalidProtocolBufferException.truncatedMessage(); + } + + @Override + public void readGroup( + final int fieldNumber, + final MessageLite.Builder builder, + final ExtensionRegistryLite extensionRegistry) + throws IOException { + if (recursionDepth >= recursionLimit) { + throw InvalidProtocolBufferException.recursionLimitExceeded(); + } + ++recursionDepth; + builder.mergeFrom(this, extensionRegistry); + checkLastTagWas(WireFormat.makeTag(fieldNumber, WireFormat.WIRETYPE_END_GROUP)); + --recursionDepth; + } + + + @Override + public T readGroup( + final int fieldNumber, + final Parser parser, + final ExtensionRegistryLite extensionRegistry) + throws IOException { + if (recursionDepth >= recursionLimit) { + throw InvalidProtocolBufferException.recursionLimitExceeded(); + } + ++recursionDepth; + T result = parser.parsePartialFrom(this, extensionRegistry); + checkLastTagWas(WireFormat.makeTag(fieldNumber, WireFormat.WIRETYPE_END_GROUP)); + --recursionDepth; + return result; + } + + @Deprecated + @Override + public void readUnknownGroup(final int fieldNumber, final MessageLite.Builder builder) + throws IOException { + readGroup(fieldNumber, builder, ExtensionRegistryLite.getEmptyRegistry()); + } + + @Override + public void readMessage( + final MessageLite.Builder builder, final ExtensionRegistryLite extensionRegistry) + throws IOException { + final int length = readRawVarint32(); + if (recursionDepth >= recursionLimit) { + throw InvalidProtocolBufferException.recursionLimitExceeded(); + } + final int oldLimit = pushLimit(length); + ++recursionDepth; + builder.mergeFrom(this, extensionRegistry); + checkLastTagWas(0); + --recursionDepth; + popLimit(oldLimit); + } + + + @Override + public T readMessage( + final Parser parser, final ExtensionRegistryLite extensionRegistry) throws IOException { + int length = readRawVarint32(); + if (recursionDepth >= recursionLimit) { + throw InvalidProtocolBufferException.recursionLimitExceeded(); + } + final int oldLimit = pushLimit(length); + ++recursionDepth; + T result = parser.parsePartialFrom(this, extensionRegistry); + checkLastTagWas(0); + --recursionDepth; + popLimit(oldLimit); + return result; + } + + @Override + public ByteString readBytes() throws IOException { + final int size = readRawVarint32(); + if (size > 0 && size <= remaining()) { + ByteBuffer result; + if (immutable && enableAliasing) { + result = slice(pos, pos + size); + } else { + result = copy(pos, pos + size); + } + pos += size; + return ByteString.wrap(result); + } + + if (size == 0) { + return ByteString.EMPTY; + } + if (size < 0) { + throw InvalidProtocolBufferException.negativeSize(); + } + throw InvalidProtocolBufferException.truncatedMessage(); + } + + @Override + public byte[] readByteArray() throws IOException { + return readRawBytes(readRawVarint32()); + } + + @Override + public ByteBuffer readByteBuffer() throws IOException { + final int size = readRawVarint32(); + if (size > 0 && size <= remaining()) { + ByteBuffer result; + // "Immutable" implies that buffer is backing a ByteString. + // Disallow slicing in this case to prevent the caller from modifying the contents + // of the ByteString. + if (!immutable && enableAliasing) { + result = slice(pos, pos + size); + } else { + result = copy(pos, pos + size); + } + pos += size; + // TODO(nathanmittler): Investigate making the ByteBuffer be made read-only + return result; + } + + if (size == 0) { + return EMPTY_BYTE_BUFFER; + } + if (size < 0) { + throw InvalidProtocolBufferException.negativeSize(); + } + throw InvalidProtocolBufferException.truncatedMessage(); + } + + @Override + public int readUInt32() throws IOException { + return readRawVarint32(); + } + + @Override + public int readEnum() throws IOException { + return readRawVarint32(); + } + + @Override + public int readSFixed32() throws IOException { + return readRawLittleEndian32(); + } + + @Override + public long readSFixed64() throws IOException { + return readRawLittleEndian64(); + } + + @Override + public int readSInt32() throws IOException { + return decodeZigZag32(readRawVarint32()); + } + + @Override + public long readSInt64() throws IOException { + return decodeZigZag64(readRawVarint64()); + } + + // ================================================================= + + @Override + public int readRawVarint32() throws IOException { + // See implementation notes for readRawVarint64 + fastpath: + { + long tempPos = pos; + + if (limit == tempPos) { + break fastpath; + } + + int x; + if ((x = UnsafeUtil.getByte(tempPos++)) >= 0) { + pos = tempPos; + return x; + } else if (limit - tempPos < 9) { + break fastpath; + } else if ((x ^= (UnsafeUtil.getByte(tempPos++) << 7)) < 0) { + x ^= (~0 << 7); + } else if ((x ^= (UnsafeUtil.getByte(tempPos++) << 14)) >= 0) { + x ^= (~0 << 7) ^ (~0 << 14); + } else if ((x ^= (UnsafeUtil.getByte(tempPos++) << 21)) < 0) { + x ^= (~0 << 7) ^ (~0 << 14) ^ (~0 << 21); + } else { + int y = UnsafeUtil.getByte(tempPos++); + x ^= y << 28; + x ^= (~0 << 7) ^ (~0 << 14) ^ (~0 << 21) ^ (~0 << 28); + if (y < 0 + && UnsafeUtil.getByte(tempPos++) < 0 + && UnsafeUtil.getByte(tempPos++) < 0 + && UnsafeUtil.getByte(tempPos++) < 0 + && UnsafeUtil.getByte(tempPos++) < 0 + && UnsafeUtil.getByte(tempPos++) < 0) { + break fastpath; // Will throw malformedVarint() + } + } + pos = tempPos; + return x; + } + return (int) readRawVarint64SlowPath(); + } + + private void skipRawVarint() throws IOException { + if (remaining() >= MAX_VARINT_SIZE) { + skipRawVarintFastPath(); + } else { + skipRawVarintSlowPath(); + } + } + + private void skipRawVarintFastPath() throws IOException { + for (int i = 0; i < MAX_VARINT_SIZE; i++) { + if (UnsafeUtil.getByte(pos++) >= 0) { + return; + } + } + throw InvalidProtocolBufferException.malformedVarint(); + } + + private void skipRawVarintSlowPath() throws IOException { + for (int i = 0; i < MAX_VARINT_SIZE; i++) { + if (readRawByte() >= 0) { + return; + } + } + throw InvalidProtocolBufferException.malformedVarint(); + } + + @Override + public long readRawVarint64() throws IOException { + // Implementation notes: + // + // Optimized for one-byte values, expected to be common. + // The particular code below was selected from various candidates + // empirically, by winning VarintBenchmark. + // + // Sign extension of (signed) Java bytes is usually a nuisance, but + // we exploit it here to more easily obtain the sign of bytes read. + // Instead of cleaning up the sign extension bits by masking eagerly, + // we delay until we find the final (positive) byte, when we clear all + // accumulated bits with one xor. We depend on javac to constant fold. + fastpath: + { + long tempPos = pos; + + if (limit == tempPos) { + break fastpath; + } + + long x; + int y; + if ((y = UnsafeUtil.getByte(tempPos++)) >= 0) { + pos = tempPos; + return y; + } else if (limit - tempPos < 9) { + break fastpath; + } else if ((y ^= (UnsafeUtil.getByte(tempPos++) << 7)) < 0) { + x = y ^ (~0 << 7); + } else if ((y ^= (UnsafeUtil.getByte(tempPos++) << 14)) >= 0) { + x = y ^ ((~0 << 7) ^ (~0 << 14)); + } else if ((y ^= (UnsafeUtil.getByte(tempPos++) << 21)) < 0) { + x = y ^ ((~0 << 7) ^ (~0 << 14) ^ (~0 << 21)); + } else if ((x = y ^ ((long) UnsafeUtil.getByte(tempPos++) << 28)) >= 0L) { + x ^= (~0L << 7) ^ (~0L << 14) ^ (~0L << 21) ^ (~0L << 28); + } else if ((x ^= ((long) UnsafeUtil.getByte(tempPos++) << 35)) < 0L) { + x ^= (~0L << 7) ^ (~0L << 14) ^ (~0L << 21) ^ (~0L << 28) ^ (~0L << 35); + } else if ((x ^= ((long) UnsafeUtil.getByte(tempPos++) << 42)) >= 0L) { + x ^= (~0L << 7) ^ (~0L << 14) ^ (~0L << 21) ^ (~0L << 28) ^ (~0L << 35) ^ (~0L << 42); + } else if ((x ^= ((long) UnsafeUtil.getByte(tempPos++) << 49)) < 0L) { + x ^= + (~0L << 7) + ^ (~0L << 14) + ^ (~0L << 21) + ^ (~0L << 28) + ^ (~0L << 35) + ^ (~0L << 42) + ^ (~0L << 49); + } else { + x ^= ((long) UnsafeUtil.getByte(tempPos++) << 56); + x ^= + (~0L << 7) + ^ (~0L << 14) + ^ (~0L << 21) + ^ (~0L << 28) + ^ (~0L << 35) + ^ (~0L << 42) + ^ (~0L << 49) + ^ (~0L << 56); + if (x < 0L) { + if (UnsafeUtil.getByte(tempPos++) < 0L) { + break fastpath; // Will throw malformedVarint() + } + } + } + pos = tempPos; + return x; + } + return readRawVarint64SlowPath(); + } + + @Override + long readRawVarint64SlowPath() throws IOException { + long result = 0; + for (int shift = 0; shift < 64; shift += 7) { + final byte b = readRawByte(); + result |= (long) (b & 0x7F) << shift; + if ((b & 0x80) == 0) { + return result; + } + } + throw InvalidProtocolBufferException.malformedVarint(); + } + + @Override + public int readRawLittleEndian32() throws IOException { + long tempPos = pos; + + if (limit - tempPos < FIXED_32_SIZE) { + throw InvalidProtocolBufferException.truncatedMessage(); + } + + pos = tempPos + FIXED_32_SIZE; + return (((UnsafeUtil.getByte(tempPos) & 0xff)) + | ((UnsafeUtil.getByte(tempPos + 1) & 0xff) << 8) + | ((UnsafeUtil.getByte(tempPos + 2) & 0xff) << 16) + | ((UnsafeUtil.getByte(tempPos + 3) & 0xff) << 24)); + } + + @Override + public long readRawLittleEndian64() throws IOException { + long tempPos = pos; + + if (limit - tempPos < FIXED_64_SIZE) { + throw InvalidProtocolBufferException.truncatedMessage(); + } + + pos = tempPos + FIXED_64_SIZE; + return (((UnsafeUtil.getByte(tempPos) & 0xffL)) + | ((UnsafeUtil.getByte(tempPos + 1) & 0xffL) << 8) + | ((UnsafeUtil.getByte(tempPos + 2) & 0xffL) << 16) + | ((UnsafeUtil.getByte(tempPos + 3) & 0xffL) << 24) + | ((UnsafeUtil.getByte(tempPos + 4) & 0xffL) << 32) + | ((UnsafeUtil.getByte(tempPos + 5) & 0xffL) << 40) + | ((UnsafeUtil.getByte(tempPos + 6) & 0xffL) << 48) + | ((UnsafeUtil.getByte(tempPos + 7) & 0xffL) << 56)); + } + + @Override + public void enableAliasing(boolean enabled) { + this.enableAliasing = enabled; + } + + @Override + public void resetSizeCounter() { + startPos = pos; + } + + @Override + public int pushLimit(int byteLimit) throws InvalidProtocolBufferException { + if (byteLimit < 0) { + throw InvalidProtocolBufferException.negativeSize(); + } + byteLimit += getTotalBytesRead(); + final int oldLimit = currentLimit; + if (byteLimit > oldLimit) { + throw InvalidProtocolBufferException.truncatedMessage(); + } + currentLimit = byteLimit; + + recomputeBufferSizeAfterLimit(); + + return oldLimit; + } + + @Override + public void popLimit(final int oldLimit) { + currentLimit = oldLimit; + recomputeBufferSizeAfterLimit(); + } + + @Override + public int getBytesUntilLimit() { + if (currentLimit == Integer.MAX_VALUE) { + return -1; + } + + return currentLimit - getTotalBytesRead(); + } + + @Override + public boolean isAtEnd() throws IOException { + return pos == limit; + } + + @Override + public int getTotalBytesRead() { + return (int) (pos - startPos); + } + + @Override + public byte readRawByte() throws IOException { + if (pos == limit) { + throw InvalidProtocolBufferException.truncatedMessage(); + } + return UnsafeUtil.getByte(pos++); + } + + @Override + public byte[] readRawBytes(final int length) throws IOException { + if (length >= 0 && length <= remaining()) { + byte[] bytes = new byte[length]; + slice(pos, pos + length).get(bytes); + pos += length; + return bytes; + } + + if (length <= 0) { + if (length == 0) { + return EMPTY_BYTE_ARRAY; + } else { + throw InvalidProtocolBufferException.negativeSize(); + } + } + + throw InvalidProtocolBufferException.truncatedMessage(); + } + + @Override + public void skipRawBytes(final int length) throws IOException { + if (length >= 0 && length <= remaining()) { + // We have all the bytes we need already. + pos += length; + return; + } + + if (length < 0) { + throw InvalidProtocolBufferException.negativeSize(); + } + throw InvalidProtocolBufferException.truncatedMessage(); + } + + private void recomputeBufferSizeAfterLimit() { + limit += bufferSizeAfterLimit; + final int bufferEnd = (int) (limit - startPos); + if (bufferEnd > currentLimit) { + // Limit is in current buffer. + bufferSizeAfterLimit = bufferEnd - currentLimit; + limit -= bufferSizeAfterLimit; + } else { + bufferSizeAfterLimit = 0; + } + } + + private int remaining() { + return (int) (limit - pos); + } + + private int bufferPos(long pos) { + return (int) (pos - address); + } + + private ByteBuffer slice(long begin, long end) throws IOException { + int prevPos = buffer.position(); + int prevLimit = buffer.limit(); + try { + buffer.position(bufferPos(begin)); + buffer.limit(bufferPos(end)); + return buffer.slice(); + } catch (IllegalArgumentException e) { + throw InvalidProtocolBufferException.truncatedMessage(); + } finally { + buffer.position(prevPos); + buffer.limit(prevLimit); + } + } + + private ByteBuffer copy(long begin, long end) throws IOException { + return ByteBuffer.wrap(copyToArray(begin, end)); + } + + private byte[] copyToArray(long begin, long end) throws IOException { + int prevPos = buffer.position(); + int prevLimit = buffer.limit(); + try { + buffer.position(bufferPos(begin)); + buffer.limit(bufferPos(end)); + byte[] bytes = new byte[(int) (end - begin)]; + buffer.get(bytes); + return bytes; + } catch (IllegalArgumentException e) { + throw InvalidProtocolBufferException.truncatedMessage(); + } finally { + buffer.position(prevPos); + buffer.limit(prevLimit); + } + } + } + + /** + * Implementation of {@link CodedInputStream} that uses an {@link InputStream} as the data source. + */ + private static final class StreamDecoder extends CodedInputStream { + private final InputStream input; + private final byte[] buffer; + /** bufferSize represents how many bytes are currently filled in the buffer */ + private int bufferSize; + + private int bufferSizeAfterLimit; + private int pos; + private int lastTag; + + /** + * The total number of bytes read before the current buffer. The total bytes read up to the + * current position can be computed as {@code totalBytesRetired + pos}. This value may be + * negative if reading started in the middle of the current buffer (e.g. if the constructor that + * takes a byte array and an offset was used). + */ + private int totalBytesRetired; + + /** The absolute position of the end of the current message. */ + private int currentLimit = Integer.MAX_VALUE; + + private StreamDecoder(final InputStream input, int bufferSize) { + checkNotNull(input, "input"); + this.input = input; + this.buffer = new byte[bufferSize]; + this.bufferSize = 0; + pos = 0; + totalBytesRetired = 0; + } + + @Override + public int readTag() throws IOException { + if (isAtEnd()) { + lastTag = 0; + return 0; + } + + lastTag = readRawVarint32(); + if (WireFormat.getTagFieldNumber(lastTag) == 0) { + // If we actually read zero (or any tag number corresponding to field + // number zero), that's not a valid tag. + throw InvalidProtocolBufferException.invalidTag(); + } + return lastTag; + } + + @Override + public void checkLastTagWas(final int value) throws InvalidProtocolBufferException { + if (lastTag != value) { + throw InvalidProtocolBufferException.invalidEndTag(); + } + } + + @Override + public int getLastTag() { + return lastTag; + } + + @Override + public boolean skipField(final int tag) throws IOException { + switch (WireFormat.getTagWireType(tag)) { + case WireFormat.WIRETYPE_VARINT: + skipRawVarint(); + return true; + case WireFormat.WIRETYPE_FIXED64: + skipRawBytes(FIXED_64_SIZE); + return true; + case WireFormat.WIRETYPE_LENGTH_DELIMITED: + skipRawBytes(readRawVarint32()); + return true; + case WireFormat.WIRETYPE_START_GROUP: + skipMessage(); + checkLastTagWas( + WireFormat.makeTag(WireFormat.getTagFieldNumber(tag), WireFormat.WIRETYPE_END_GROUP)); + return true; + case WireFormat.WIRETYPE_END_GROUP: + return false; + case WireFormat.WIRETYPE_FIXED32: + skipRawBytes(FIXED_32_SIZE); + return true; + default: + throw InvalidProtocolBufferException.invalidWireType(); + } + } + + @Override + public boolean skipField(final int tag, final CodedOutputStream output) throws IOException { + switch (WireFormat.getTagWireType(tag)) { + case WireFormat.WIRETYPE_VARINT: + { + long value = readInt64(); + output.writeRawVarint32(tag); + output.writeUInt64NoTag(value); + return true; + } + case WireFormat.WIRETYPE_FIXED64: + { + long value = readRawLittleEndian64(); + output.writeRawVarint32(tag); + output.writeFixed64NoTag(value); + return true; + } + case WireFormat.WIRETYPE_LENGTH_DELIMITED: + { + ByteString value = readBytes(); + output.writeRawVarint32(tag); + output.writeBytesNoTag(value); + return true; + } + case WireFormat.WIRETYPE_START_GROUP: + { + output.writeRawVarint32(tag); + skipMessage(output); + int endtag = + WireFormat.makeTag( + WireFormat.getTagFieldNumber(tag), WireFormat.WIRETYPE_END_GROUP); + checkLastTagWas(endtag); + output.writeRawVarint32(endtag); + return true; + } + case WireFormat.WIRETYPE_END_GROUP: + { + return false; + } + case WireFormat.WIRETYPE_FIXED32: + { + int value = readRawLittleEndian32(); + output.writeRawVarint32(tag); + output.writeFixed32NoTag(value); + return true; + } + default: + throw InvalidProtocolBufferException.invalidWireType(); + } + } + + @Override + public void skipMessage() throws IOException { + while (true) { + final int tag = readTag(); + if (tag == 0 || !skipField(tag)) { + return; + } + } + } + + @Override + public void skipMessage(CodedOutputStream output) throws IOException { + while (true) { + final int tag = readTag(); + if (tag == 0 || !skipField(tag, output)) { + return; + } + } + } + + /** Collects the bytes skipped and returns the data in a ByteBuffer. */ + private class SkippedDataSink implements RefillCallback { + private int lastPos = pos; + private ByteArrayOutputStream byteArrayStream; + + @Override + public void onRefill() { + if (byteArrayStream == null) { + byteArrayStream = new ByteArrayOutputStream(); + } + byteArrayStream.write(buffer, lastPos, pos - lastPos); + lastPos = 0; + } + + /** Gets skipped data in a ByteBuffer. This method should only be called once. */ + ByteBuffer getSkippedData() { + if (byteArrayStream == null) { + return ByteBuffer.wrap(buffer, lastPos, pos - lastPos); + } else { + byteArrayStream.write(buffer, lastPos, pos); + return ByteBuffer.wrap(byteArrayStream.toByteArray()); + } + } + } + + + // ----------------------------------------------------------------- + + @Override + public double readDouble() throws IOException { + return Double.longBitsToDouble(readRawLittleEndian64()); + } + + @Override + public float readFloat() throws IOException { + return Float.intBitsToFloat(readRawLittleEndian32()); + } + + @Override + public long readUInt64() throws IOException { + return readRawVarint64(); + } + + @Override + public long readInt64() throws IOException { + return readRawVarint64(); + } + + @Override + public int readInt32() throws IOException { + return readRawVarint32(); + } + + @Override + public long readFixed64() throws IOException { + return readRawLittleEndian64(); + } + + @Override + public int readFixed32() throws IOException { + return readRawLittleEndian32(); + } + + @Override + public boolean readBool() throws IOException { + return readRawVarint64() != 0; + } + + @Override + public String readString() throws IOException { + final int size = readRawVarint32(); + if (size > 0 && size <= (bufferSize - pos)) { + // Fast path: We already have the bytes in a contiguous buffer, so + // just copy directly from it. + final String result = new String(buffer, pos, size, UTF_8); + pos += size; + return result; + } + if (size == 0) { + return ""; + } + if (size <= bufferSize) { + refillBuffer(size); + String result = new String(buffer, pos, size, UTF_8); + pos += size; + return result; + } + // Slow path: Build a byte array first then copy it. + return new String(readRawBytesSlowPath(size), UTF_8); + } + + @Override + public String readStringRequireUtf8() throws IOException { + final int size = readRawVarint32(); + final byte[] bytes; + final int oldPos = pos; + final int tempPos; + if (size <= (bufferSize - oldPos) && size > 0) { + // Fast path: We already have the bytes in a contiguous buffer, so + // just copy directly from it. + bytes = buffer; + pos = oldPos + size; + tempPos = oldPos; + } else if (size == 0) { + return ""; + } else if (size <= bufferSize) { + refillBuffer(size); + bytes = buffer; + tempPos = 0; + pos = tempPos + size; + } else { + // Slow path: Build a byte array first then copy it. + bytes = readRawBytesSlowPath(size); + tempPos = 0; + } + // TODO(martinrb): We could save a pass by validating while decoding. + if (!Utf8.isValidUtf8(bytes, tempPos, tempPos + size)) { + throw InvalidProtocolBufferException.invalidUtf8(); + } + return new String(bytes, tempPos, size, UTF_8); + } + + @Override + public void readGroup( + final int fieldNumber, + final MessageLite.Builder builder, + final ExtensionRegistryLite extensionRegistry) + throws IOException { + if (recursionDepth >= recursionLimit) { + throw InvalidProtocolBufferException.recursionLimitExceeded(); + } + ++recursionDepth; + builder.mergeFrom(this, extensionRegistry); + checkLastTagWas(WireFormat.makeTag(fieldNumber, WireFormat.WIRETYPE_END_GROUP)); + --recursionDepth; + } + + + @Override + public T readGroup( + final int fieldNumber, + final Parser parser, + final ExtensionRegistryLite extensionRegistry) + throws IOException { + if (recursionDepth >= recursionLimit) { + throw InvalidProtocolBufferException.recursionLimitExceeded(); + } + ++recursionDepth; + T result = parser.parsePartialFrom(this, extensionRegistry); + checkLastTagWas(WireFormat.makeTag(fieldNumber, WireFormat.WIRETYPE_END_GROUP)); + --recursionDepth; + return result; + } + + @Deprecated + @Override + public void readUnknownGroup(final int fieldNumber, final MessageLite.Builder builder) + throws IOException { + readGroup(fieldNumber, builder, ExtensionRegistryLite.getEmptyRegistry()); + } + + @Override + public void readMessage( + final MessageLite.Builder builder, final ExtensionRegistryLite extensionRegistry) + throws IOException { + final int length = readRawVarint32(); + if (recursionDepth >= recursionLimit) { + throw InvalidProtocolBufferException.recursionLimitExceeded(); + } + final int oldLimit = pushLimit(length); + ++recursionDepth; + builder.mergeFrom(this, extensionRegistry); + checkLastTagWas(0); + --recursionDepth; + popLimit(oldLimit); + } + + + @Override + public T readMessage( + final Parser parser, final ExtensionRegistryLite extensionRegistry) throws IOException { + int length = readRawVarint32(); + if (recursionDepth >= recursionLimit) { + throw InvalidProtocolBufferException.recursionLimitExceeded(); + } + final int oldLimit = pushLimit(length); + ++recursionDepth; + T result = parser.parsePartialFrom(this, extensionRegistry); + checkLastTagWas(0); + --recursionDepth; + popLimit(oldLimit); + return result; + } + + @Override + public ByteString readBytes() throws IOException { + final int size = readRawVarint32(); + if (size <= (bufferSize - pos) && size > 0) { + // Fast path: We already have the bytes in a contiguous buffer, so + // just copy directly from it. + final ByteString result = ByteString.copyFrom(buffer, pos, size); + pos += size; + return result; + } + if (size == 0) { + return ByteString.EMPTY; + } + // Slow path: Build a byte array first then copy it. + return ByteString.wrap(readRawBytesSlowPath(size)); + } + + @Override + public byte[] readByteArray() throws IOException { + final int size = readRawVarint32(); + if (size <= (bufferSize - pos) && size > 0) { + // Fast path: We already have the bytes in a contiguous buffer, so + // just copy directly from it. + final byte[] result = Arrays.copyOfRange(buffer, pos, pos + size); + pos += size; + return result; + } else { + // Slow path: Build a byte array first then copy it. + return readRawBytesSlowPath(size); + } + } + + @Override + public ByteBuffer readByteBuffer() throws IOException { + final int size = readRawVarint32(); + if (size <= (bufferSize - pos) && size > 0) { + // Fast path: We already have the bytes in a contiguous buffer. + ByteBuffer result = ByteBuffer.wrap(Arrays.copyOfRange(buffer, pos, pos + size)); + pos += size; + return result; + } + if (size == 0) { + return Internal.EMPTY_BYTE_BUFFER; + } + // Slow path: Build a byte array first then copy it. + return ByteBuffer.wrap(readRawBytesSlowPath(size)); + } + + @Override + public int readUInt32() throws IOException { + return readRawVarint32(); + } + + @Override + public int readEnum() throws IOException { + return readRawVarint32(); + } + + @Override + public int readSFixed32() throws IOException { + return readRawLittleEndian32(); + } + + @Override + public long readSFixed64() throws IOException { + return readRawLittleEndian64(); + } + + @Override + public int readSInt32() throws IOException { + return decodeZigZag32(readRawVarint32()); + } + + @Override + public long readSInt64() throws IOException { + return decodeZigZag64(readRawVarint64()); + } + + // ================================================================= + + @Override + public int readRawVarint32() throws IOException { + // See implementation notes for readRawVarint64 + fastpath: + { + int tempPos = pos; + + if (bufferSize == tempPos) { + break fastpath; + } + + final byte[] buffer = this.buffer; + int x; + if ((x = buffer[tempPos++]) >= 0) { + pos = tempPos; + return x; + } else if (bufferSize - tempPos < 9) { + break fastpath; + } else if ((x ^= (buffer[tempPos++] << 7)) < 0) { + x ^= (~0 << 7); + } else if ((x ^= (buffer[tempPos++] << 14)) >= 0) { + x ^= (~0 << 7) ^ (~0 << 14); + } else if ((x ^= (buffer[tempPos++] << 21)) < 0) { + x ^= (~0 << 7) ^ (~0 << 14) ^ (~0 << 21); + } else { + int y = buffer[tempPos++]; + x ^= y << 28; + x ^= (~0 << 7) ^ (~0 << 14) ^ (~0 << 21) ^ (~0 << 28); + if (y < 0 + && buffer[tempPos++] < 0 + && buffer[tempPos++] < 0 + && buffer[tempPos++] < 0 + && buffer[tempPos++] < 0 + && buffer[tempPos++] < 0) { + break fastpath; // Will throw malformedVarint() + } + } + pos = tempPos; + return x; + } + return (int) readRawVarint64SlowPath(); + } + + private void skipRawVarint() throws IOException { + if (bufferSize - pos >= MAX_VARINT_SIZE) { + skipRawVarintFastPath(); + } else { + skipRawVarintSlowPath(); + } + } + + private void skipRawVarintFastPath() throws IOException { + for (int i = 0; i < MAX_VARINT_SIZE; i++) { + if (buffer[pos++] >= 0) { + return; + } + } + throw InvalidProtocolBufferException.malformedVarint(); + } + + private void skipRawVarintSlowPath() throws IOException { + for (int i = 0; i < MAX_VARINT_SIZE; i++) { + if (readRawByte() >= 0) { + return; + } + } + throw InvalidProtocolBufferException.malformedVarint(); + } + + @Override + public long readRawVarint64() throws IOException { + // Implementation notes: + // + // Optimized for one-byte values, expected to be common. + // The particular code below was selected from various candidates + // empirically, by winning VarintBenchmark. + // + // Sign extension of (signed) Java bytes is usually a nuisance, but + // we exploit it here to more easily obtain the sign of bytes read. + // Instead of cleaning up the sign extension bits by masking eagerly, + // we delay until we find the final (positive) byte, when we clear all + // accumulated bits with one xor. We depend on javac to constant fold. + fastpath: + { + int tempPos = pos; + + if (bufferSize == tempPos) { + break fastpath; + } + + final byte[] buffer = this.buffer; + long x; + int y; + if ((y = buffer[tempPos++]) >= 0) { + pos = tempPos; + return y; + } else if (bufferSize - tempPos < 9) { + break fastpath; + } else if ((y ^= (buffer[tempPos++] << 7)) < 0) { + x = y ^ (~0 << 7); + } else if ((y ^= (buffer[tempPos++] << 14)) >= 0) { + x = y ^ ((~0 << 7) ^ (~0 << 14)); + } else if ((y ^= (buffer[tempPos++] << 21)) < 0) { + x = y ^ ((~0 << 7) ^ (~0 << 14) ^ (~0 << 21)); + } else if ((x = y ^ ((long) buffer[tempPos++] << 28)) >= 0L) { + x ^= (~0L << 7) ^ (~0L << 14) ^ (~0L << 21) ^ (~0L << 28); + } else if ((x ^= ((long) buffer[tempPos++] << 35)) < 0L) { + x ^= (~0L << 7) ^ (~0L << 14) ^ (~0L << 21) ^ (~0L << 28) ^ (~0L << 35); + } else if ((x ^= ((long) buffer[tempPos++] << 42)) >= 0L) { + x ^= (~0L << 7) ^ (~0L << 14) ^ (~0L << 21) ^ (~0L << 28) ^ (~0L << 35) ^ (~0L << 42); + } else if ((x ^= ((long) buffer[tempPos++] << 49)) < 0L) { + x ^= + (~0L << 7) + ^ (~0L << 14) + ^ (~0L << 21) + ^ (~0L << 28) + ^ (~0L << 35) + ^ (~0L << 42) + ^ (~0L << 49); + } else { + x ^= ((long) buffer[tempPos++] << 56); + x ^= + (~0L << 7) + ^ (~0L << 14) + ^ (~0L << 21) + ^ (~0L << 28) + ^ (~0L << 35) + ^ (~0L << 42) + ^ (~0L << 49) + ^ (~0L << 56); + if (x < 0L) { + if (buffer[tempPos++] < 0L) { + break fastpath; // Will throw malformedVarint() + } + } + } + pos = tempPos; + return x; + } + return readRawVarint64SlowPath(); + } + + @Override + long readRawVarint64SlowPath() throws IOException { + long result = 0; + for (int shift = 0; shift < 64; shift += 7) { + final byte b = readRawByte(); + result |= (long) (b & 0x7F) << shift; + if ((b & 0x80) == 0) { + return result; + } + } + throw InvalidProtocolBufferException.malformedVarint(); + } + + @Override + public int readRawLittleEndian32() throws IOException { + int tempPos = pos; + + if (bufferSize - tempPos < FIXED_32_SIZE) { + refillBuffer(FIXED_32_SIZE); + tempPos = pos; + } + + final byte[] buffer = this.buffer; + pos = tempPos + FIXED_32_SIZE; + return (((buffer[tempPos] & 0xff)) + | ((buffer[tempPos + 1] & 0xff) << 8) + | ((buffer[tempPos + 2] & 0xff) << 16) + | ((buffer[tempPos + 3] & 0xff) << 24)); + } + + @Override + public long readRawLittleEndian64() throws IOException { + int tempPos = pos; + + if (bufferSize - tempPos < FIXED_64_SIZE) { + refillBuffer(FIXED_64_SIZE); + tempPos = pos; + } + + final byte[] buffer = this.buffer; + pos = tempPos + FIXED_64_SIZE; + return (((buffer[tempPos] & 0xffL)) + | ((buffer[tempPos + 1] & 0xffL) << 8) + | ((buffer[tempPos + 2] & 0xffL) << 16) + | ((buffer[tempPos + 3] & 0xffL) << 24) + | ((buffer[tempPos + 4] & 0xffL) << 32) + | ((buffer[tempPos + 5] & 0xffL) << 40) + | ((buffer[tempPos + 6] & 0xffL) << 48) + | ((buffer[tempPos + 7] & 0xffL) << 56)); + } + + // ----------------------------------------------------------------- + + @Override + public void enableAliasing(boolean enabled) { + // TODO(nathanmittler): Ideally we should throw here. Do nothing for backward compatibility. + } + + @Override + public void resetSizeCounter() { + totalBytesRetired = -pos; + } + + @Override + public int pushLimit(int byteLimit) throws InvalidProtocolBufferException { + if (byteLimit < 0) { + throw InvalidProtocolBufferException.negativeSize(); + } + byteLimit += totalBytesRetired + pos; + final int oldLimit = currentLimit; + if (byteLimit > oldLimit) { + throw InvalidProtocolBufferException.truncatedMessage(); + } + currentLimit = byteLimit; + + recomputeBufferSizeAfterLimit(); + + return oldLimit; + } + + private void recomputeBufferSizeAfterLimit() { + bufferSize += bufferSizeAfterLimit; + final int bufferEnd = totalBytesRetired + bufferSize; + if (bufferEnd > currentLimit) { + // Limit is in current buffer. + bufferSizeAfterLimit = bufferEnd - currentLimit; + bufferSize -= bufferSizeAfterLimit; + } else { + bufferSizeAfterLimit = 0; + } + } + + @Override + public void popLimit(final int oldLimit) { + currentLimit = oldLimit; + recomputeBufferSizeAfterLimit(); + } + + @Override + public int getBytesUntilLimit() { + if (currentLimit == Integer.MAX_VALUE) { + return -1; + } + + final int currentAbsolutePosition = totalBytesRetired + pos; + return currentLimit - currentAbsolutePosition; + } + + @Override + public boolean isAtEnd() throws IOException { + return pos == bufferSize && !tryRefillBuffer(1); + } + + @Override + public int getTotalBytesRead() { + return totalBytesRetired + pos; + } + + private interface RefillCallback { + void onRefill(); + } + + private RefillCallback refillCallback = null; + + /** + * Reads more bytes from the input, making at least {@code n} bytes available in the buffer. + * Caller must ensure that the requested space is not yet available, and that the requested + * space is less than BUFFER_SIZE. + * + * @throws InvalidProtocolBufferException The end of the stream or the current limit was + * reached. + */ + private void refillBuffer(int n) throws IOException { + if (!tryRefillBuffer(n)) { + throw InvalidProtocolBufferException.truncatedMessage(); + } + } + + /** + * Tries to read more bytes from the input, making at least {@code n} bytes available in the + * buffer. Caller must ensure that the requested space is not yet available, and that the + * requested space is less than BUFFER_SIZE. + * + * @return {@code true} if the bytes could be made available; {@code false} if the end of the + * stream or the current limit was reached. + */ + private boolean tryRefillBuffer(int n) throws IOException { + if (pos + n <= bufferSize) { + throw new IllegalStateException( + "refillBuffer() called when " + n + " bytes were already available in buffer"); + } + + if (totalBytesRetired + pos + n > currentLimit) { + // Oops, we hit a limit. + return false; + } + + if (refillCallback != null) { + refillCallback.onRefill(); + } + + int tempPos = pos; + if (tempPos > 0) { + if (bufferSize > tempPos) { + System.arraycopy(buffer, tempPos, buffer, 0, bufferSize - tempPos); + } + totalBytesRetired += tempPos; + bufferSize -= tempPos; + pos = 0; + } + + int bytesRead = input.read(buffer, bufferSize, buffer.length - bufferSize); + if (bytesRead == 0 || bytesRead < -1 || bytesRead > buffer.length) { + throw new IllegalStateException( + "InputStream#read(byte[]) returned invalid result: " + + bytesRead + + "\nThe InputStream implementation is buggy."); + } + if (bytesRead > 0) { + bufferSize += bytesRead; + // Integer-overflow-conscious check against sizeLimit + if (totalBytesRetired + n - sizeLimit > 0) { + throw InvalidProtocolBufferException.sizeLimitExceeded(); + } + recomputeBufferSizeAfterLimit(); + return (bufferSize >= n) ? true : tryRefillBuffer(n); + } + + return false; + } + + @Override + public byte readRawByte() throws IOException { + if (pos == bufferSize) { + refillBuffer(1); + } + return buffer[pos++]; + } + + @Override + public byte[] readRawBytes(final int size) throws IOException { + final int tempPos = pos; + if (size <= (bufferSize - tempPos) && size > 0) { + pos = tempPos + size; + return Arrays.copyOfRange(buffer, tempPos, tempPos + size); + } else { + return readRawBytesSlowPath(size); + } + } + + /** + * Exactly like readRawBytes, but caller must have already checked the fast path: (size <= + * (bufferSize - pos) && size > 0) + */ + private byte[] readRawBytesSlowPath(final int size) throws IOException { + if (size == 0) { + return Internal.EMPTY_BYTE_ARRAY; + } + if (size < 0) { + throw InvalidProtocolBufferException.negativeSize(); + } + + // Verify that the message size so far has not exceeded sizeLimit. + int currentMessageSize = totalBytesRetired + pos + size; + if (currentMessageSize > sizeLimit) { + throw InvalidProtocolBufferException.sizeLimitExceeded(); + } + + // Verify that the message size so far has not exceeded currentLimit. + if (currentMessageSize > currentLimit) { + // Read to the end of the stream anyway. + skipRawBytes(currentLimit - totalBytesRetired - pos); + throw InvalidProtocolBufferException.truncatedMessage(); + } + + final int originalBufferPos = pos; + final int bufferedBytes = bufferSize - pos; + + // Mark the current buffer consumed. + totalBytesRetired += bufferSize; + pos = 0; + bufferSize = 0; + + // Determine the number of bytes we need to read from the input stream. + int sizeLeft = size - bufferedBytes; + // TODO(nathanmittler): Consider using a value larger than DEFAULT_BUFFER_SIZE. + if (sizeLeft < DEFAULT_BUFFER_SIZE || sizeLeft <= input.available()) { + // Either the bytes we need are known to be available, or the required buffer is + // within an allowed threshold - go ahead and allocate the buffer now. + final byte[] bytes = new byte[size]; + + // Copy all of the buffered bytes to the result buffer. + System.arraycopy(buffer, originalBufferPos, bytes, 0, bufferedBytes); + + // Fill the remaining bytes from the input stream. + int tempPos = bufferedBytes; + while (tempPos < bytes.length) { + int n = input.read(bytes, tempPos, size - tempPos); + if (n == -1) { + throw InvalidProtocolBufferException.truncatedMessage(); + } + totalBytesRetired += n; + tempPos += n; + } + + return bytes; + } + + // The size is very large. For security reasons, we can't allocate the + // entire byte array yet. The size comes directly from the input, so a + // maliciously-crafted message could provide a bogus very large size in + // order to trick the app into allocating a lot of memory. We avoid this + // by allocating and reading only a small chunk at a time, so that the + // malicious message must actually *be* extremely large to cause + // problems. Meanwhile, we limit the allowed size of a message elsewhere. + final List chunks = new ArrayList(); + + while (sizeLeft > 0) { + // TODO(nathanmittler): Consider using a value larger than DEFAULT_BUFFER_SIZE. + final byte[] chunk = new byte[Math.min(sizeLeft, DEFAULT_BUFFER_SIZE)]; + int tempPos = 0; + while (tempPos < chunk.length) { + final int n = input.read(chunk, tempPos, chunk.length - tempPos); + if (n == -1) { + throw InvalidProtocolBufferException.truncatedMessage(); + } + totalBytesRetired += n; + tempPos += n; + } + sizeLeft -= chunk.length; + chunks.add(chunk); + } + + // OK, got everything. Now concatenate it all into one buffer. + final byte[] bytes = new byte[size]; + + // Start by copying the leftover bytes from this.buffer. + System.arraycopy(buffer, originalBufferPos, bytes, 0, bufferedBytes); + + // And now all the chunks. + int tempPos = bufferedBytes; + for (final byte[] chunk : chunks) { + System.arraycopy(chunk, 0, bytes, tempPos, chunk.length); + tempPos += chunk.length; + } + + // Done. + return bytes; + } + + @Override + public void skipRawBytes(final int size) throws IOException { + if (size <= (bufferSize - pos) && size >= 0) { + // We have all the bytes we need already. + pos += size; + } else { + skipRawBytesSlowPath(size); + } + } + + /** + * Exactly like skipRawBytes, but caller must have already checked the fast path: (size <= + * (bufferSize - pos) && size >= 0) + */ + private void skipRawBytesSlowPath(final int size) throws IOException { + if (size < 0) { + throw InvalidProtocolBufferException.negativeSize(); + } + + if (totalBytesRetired + pos + size > currentLimit) { + // Read to the end of the stream anyway. + skipRawBytes(currentLimit - totalBytesRetired - pos); + // Then fail. + throw InvalidProtocolBufferException.truncatedMessage(); + } + + // Skipping more bytes than are in the buffer. First skip what we have. + int tempPos = bufferSize - pos; + pos = bufferSize; + + // Keep refilling the buffer until we get to the point we wanted to skip to. + // This has the side effect of ensuring the limits are updated correctly. + refillBuffer(1); + while (size - tempPos > bufferSize) { + tempPos += bufferSize; + pos = bufferSize; + refillBuffer(1); + } + + pos = size - tempPos; + } } } diff --git a/java/core/src/main/java/com/google/protobuf/CodedOutputStream.java b/java/core/src/main/java/com/google/protobuf/CodedOutputStream.java index e5515285..3e32c2c5 100644 --- a/java/core/src/main/java/com/google/protobuf/CodedOutputStream.java +++ b/java/core/src/main/java/com/google/protobuf/CodedOutputStream.java @@ -30,10 +30,12 @@ package com.google.protobuf; +import static com.google.protobuf.WireFormat.FIXED_32_SIZE; +import static com.google.protobuf.WireFormat.FIXED_64_SIZE; +import static com.google.protobuf.WireFormat.MAX_VARINT_SIZE; import static java.lang.Math.max; import com.google.protobuf.Utf8.UnpairedSurrogateException; - import java.io.IOException; import java.io.OutputStream; import java.nio.BufferOverflowException; @@ -59,10 +61,6 @@ public abstract class CodedOutputStream extends ByteOutput { private static final boolean HAS_UNSAFE_ARRAY_OPERATIONS = UnsafeUtil.hasUnsafeArrayOperations(); private static final long ARRAY_BASE_OFFSET = UnsafeUtil.getArrayBaseOffset(); - private static final int FIXED_32_SIZE = 4; - private static final int FIXED_64_SIZE = 8; - private static final int MAX_VARINT_SIZE = 10; - /** * @deprecated Use {@link #computeFixed32SizeNoTag(int)} instead. */ @@ -134,14 +132,27 @@ public abstract class CodedOutputStream extends ByteOutput { return new ArrayEncoder(flatArray, offset, length); } - /** - * Create a new {@code CodedOutputStream} that writes to the given {@link ByteBuffer}. - */ - public static CodedOutputStream newInstance(ByteBuffer byteBuffer) { - if (byteBuffer.hasArray()) { - return new NioHeapEncoder(byteBuffer); + /** Create a new {@code CodedOutputStream} that writes to the given {@link ByteBuffer}. */ + public static CodedOutputStream newInstance(ByteBuffer buffer) { + if (buffer.hasArray()) { + return new HeapNioEncoder(buffer); } - return new NioEncoder(byteBuffer); + if (buffer.isDirect() && !buffer.isReadOnly()) { + return UnsafeDirectNioEncoder.isSupported() + ? newUnsafeInstance(buffer) + : newSafeInstance(buffer); + } + throw new IllegalArgumentException("ByteBuffer is read-only"); + } + + /** For testing purposes only. */ + static CodedOutputStream newUnsafeInstance(ByteBuffer buffer) { + return new UnsafeDirectNioEncoder(buffer); + } + + /** For testing purposes only. */ + static CodedOutputStream newSafeInstance(ByteBuffer buffer) { + return new SafeDirectNioEncoder(buffer); } /** @@ -979,6 +990,10 @@ public abstract class CodedOutputStream extends ByteOutput { super(MESSAGE); } + OutOfSpaceException(String explanationMessage) { + super(MESSAGE + ": " + explanationMessage); + } + OutOfSpaceException(Throwable cause) { super(MESSAGE, cause); } @@ -1486,11 +1501,11 @@ public abstract class CodedOutputStream extends ByteOutput { * A {@link CodedOutputStream} that writes directly to a heap {@link ByteBuffer}. Writes are * done directly to the underlying array. The buffer position is only updated after a flush. */ - private static final class NioHeapEncoder extends ArrayEncoder { + private static final class HeapNioEncoder extends ArrayEncoder { private final ByteBuffer byteBuffer; private int initialPosition; - NioHeapEncoder(ByteBuffer byteBuffer) { + HeapNioEncoder(ByteBuffer byteBuffer) { super(byteBuffer.array(), byteBuffer.arrayOffset() + byteBuffer.position(), byteBuffer.remaining()); this.byteBuffer = byteBuffer; @@ -1505,14 +1520,15 @@ public abstract class CodedOutputStream extends ByteOutput { } /** - * A {@link CodedOutputStream} that writes directly to a {@link ByteBuffer}. + * A {@link CodedOutputStream} that writes directly to a direct {@link ByteBuffer}, using only + * safe operations.. */ - private static final class NioEncoder extends CodedOutputStream { + private static final class SafeDirectNioEncoder extends CodedOutputStream { private final ByteBuffer originalBuffer; private final ByteBuffer buffer; private final int initialPosition; - NioEncoder(ByteBuffer buffer) { + SafeDirectNioEncoder(ByteBuffer buffer) { this.originalBuffer = buffer; this.buffer = buffer.duplicate().order(ByteOrder.LITTLE_ENDIAN); initialPosition = buffer.position(); @@ -1814,6 +1830,356 @@ public abstract class CodedOutputStream extends ByteOutput { } } + /** + * A {@link CodedOutputStream} that writes directly to a direct {@link ByteBuffer} using {@code + * sun.misc.Unsafe}. + */ + private static final class UnsafeDirectNioEncoder extends CodedOutputStream { + private final ByteBuffer originalBuffer; + private final ByteBuffer buffer; + private final long address; + private final long initialPosition; + private final long limit; + private final long oneVarintLimit; + private long position; + + UnsafeDirectNioEncoder(ByteBuffer buffer) { + this.originalBuffer = buffer; + this.buffer = buffer.duplicate().order(ByteOrder.LITTLE_ENDIAN); + address = UnsafeUtil.addressOffset(buffer); + initialPosition = address + buffer.position(); + limit = address + buffer.limit(); + oneVarintLimit = limit - MAX_VARINT_SIZE; + position = initialPosition; + } + + static boolean isSupported() { + return UnsafeUtil.hasUnsafeByteBufferOperations(); + } + + @Override + public void writeTag(int fieldNumber, int wireType) throws IOException { + writeUInt32NoTag(WireFormat.makeTag(fieldNumber, wireType)); + } + + @Override + public void writeInt32(int fieldNumber, int value) throws IOException { + writeTag(fieldNumber, WireFormat.WIRETYPE_VARINT); + writeInt32NoTag(value); + } + + @Override + public void writeUInt32(int fieldNumber, int value) throws IOException { + writeTag(fieldNumber, WireFormat.WIRETYPE_VARINT); + writeUInt32NoTag(value); + } + + @Override + public void writeFixed32(int fieldNumber, int value) throws IOException { + writeTag(fieldNumber, WireFormat.WIRETYPE_FIXED32); + writeFixed32NoTag(value); + } + + @Override + public void writeUInt64(int fieldNumber, long value) throws IOException { + writeTag(fieldNumber, WireFormat.WIRETYPE_VARINT); + writeUInt64NoTag(value); + } + + @Override + public void writeFixed64(int fieldNumber, long value) throws IOException { + writeTag(fieldNumber, WireFormat.WIRETYPE_FIXED64); + writeFixed64NoTag(value); + } + + @Override + public void writeBool(int fieldNumber, boolean value) throws IOException { + writeTag(fieldNumber, WireFormat.WIRETYPE_VARINT); + write((byte) (value ? 1 : 0)); + } + + @Override + public void writeString(int fieldNumber, String value) throws IOException { + writeTag(fieldNumber, WireFormat.WIRETYPE_LENGTH_DELIMITED); + writeStringNoTag(value); + } + + @Override + public void writeBytes(int fieldNumber, ByteString value) throws IOException { + writeTag(fieldNumber, WireFormat.WIRETYPE_LENGTH_DELIMITED); + writeBytesNoTag(value); + } + + @Override + public void writeByteArray(int fieldNumber, byte[] value) throws IOException { + writeByteArray(fieldNumber, value, 0, value.length); + } + + @Override + public void writeByteArray(int fieldNumber, byte[] value, int offset, int length) + throws IOException { + writeTag(fieldNumber, WireFormat.WIRETYPE_LENGTH_DELIMITED); + writeByteArrayNoTag(value, offset, length); + } + + @Override + public void writeByteBuffer(int fieldNumber, ByteBuffer value) throws IOException { + writeTag(fieldNumber, WireFormat.WIRETYPE_LENGTH_DELIMITED); + writeUInt32NoTag(value.capacity()); + writeRawBytes(value); + } + + @Override + public void writeMessage(int fieldNumber, MessageLite value) throws IOException { + writeTag(fieldNumber, WireFormat.WIRETYPE_LENGTH_DELIMITED); + writeMessageNoTag(value); + } + + @Override + public void writeMessageSetExtension(int fieldNumber, MessageLite value) throws IOException { + writeTag(WireFormat.MESSAGE_SET_ITEM, WireFormat.WIRETYPE_START_GROUP); + writeUInt32(WireFormat.MESSAGE_SET_TYPE_ID, fieldNumber); + writeMessage(WireFormat.MESSAGE_SET_MESSAGE, value); + writeTag(WireFormat.MESSAGE_SET_ITEM, WireFormat.WIRETYPE_END_GROUP); + } + + @Override + public void writeRawMessageSetExtension(int fieldNumber, ByteString value) throws IOException { + writeTag(WireFormat.MESSAGE_SET_ITEM, WireFormat.WIRETYPE_START_GROUP); + writeUInt32(WireFormat.MESSAGE_SET_TYPE_ID, fieldNumber); + writeBytes(WireFormat.MESSAGE_SET_MESSAGE, value); + writeTag(WireFormat.MESSAGE_SET_ITEM, WireFormat.WIRETYPE_END_GROUP); + } + + @Override + public void writeMessageNoTag(MessageLite value) throws IOException { + writeUInt32NoTag(value.getSerializedSize()); + value.writeTo(this); + } + + @Override + public void write(byte value) throws IOException { + if (position >= limit) { + throw new OutOfSpaceException( + String.format("Pos: %d, limit: %d, len: %d", position, limit, 1)); + } + UnsafeUtil.putByte(position++, value); + } + + @Override + public void writeBytesNoTag(ByteString value) throws IOException { + writeUInt32NoTag(value.size()); + value.writeTo(this); + } + + @Override + public void writeByteArrayNoTag(byte[] value, int offset, int length) throws IOException { + writeUInt32NoTag(length); + write(value, offset, length); + } + + @Override + public void writeRawBytes(ByteBuffer value) throws IOException { + if (value.hasArray()) { + write(value.array(), value.arrayOffset(), value.capacity()); + } else { + ByteBuffer duplicated = value.duplicate(); + duplicated.clear(); + write(duplicated); + } + } + + @Override + public void writeInt32NoTag(int value) throws IOException { + if (value >= 0) { + writeUInt32NoTag(value); + } else { + // Must sign-extend. + writeUInt64NoTag(value); + } + } + + @Override + public void writeUInt32NoTag(int value) throws IOException { + if (position <= oneVarintLimit) { + // Optimization to avoid bounds checks on each iteration. + while (true) { + if ((value & ~0x7F) == 0) { + UnsafeUtil.putByte(position++, (byte) value); + return; + } else { + UnsafeUtil.putByte(position++, (byte) ((value & 0x7F) | 0x80)); + value >>>= 7; + } + } + } else { + while (position < limit) { + if ((value & ~0x7F) == 0) { + UnsafeUtil.putByte(position++, (byte) value); + return; + } else { + UnsafeUtil.putByte(position++, (byte) ((value & 0x7F) | 0x80)); + value >>>= 7; + } + } + throw new OutOfSpaceException( + String.format("Pos: %d, limit: %d, len: %d", position, limit, 1)); + } + } + + @Override + public void writeFixed32NoTag(int value) throws IOException { + buffer.putInt(bufferPos(position), value); + position += FIXED_32_SIZE; + } + + @Override + public void writeUInt64NoTag(long value) throws IOException { + if (position <= oneVarintLimit) { + // Optimization to avoid bounds checks on each iteration. + while (true) { + if ((value & ~0x7FL) == 0) { + UnsafeUtil.putByte(position++, (byte) value); + return; + } else { + UnsafeUtil.putByte(position++, (byte) (((int) value & 0x7F) | 0x80)); + value >>>= 7; + } + } + } else { + while (position < limit) { + if ((value & ~0x7FL) == 0) { + UnsafeUtil.putByte(position++, (byte) value); + return; + } else { + UnsafeUtil.putByte(position++, (byte) (((int) value & 0x7F) | 0x80)); + value >>>= 7; + } + } + throw new OutOfSpaceException( + String.format("Pos: %d, limit: %d, len: %d", position, limit, 1)); + } + } + + @Override + public void writeFixed64NoTag(long value) throws IOException { + buffer.putLong(bufferPos(position), value); + position += FIXED_64_SIZE; + } + + @Override + public void write(byte[] value, int offset, int length) throws IOException { + if (value == null + || offset < 0 + || length < 0 + || (value.length - length) < offset + || (limit - length) < position) { + if (value == null) { + throw new NullPointerException("value"); + } + throw new OutOfSpaceException( + String.format("Pos: %d, limit: %d, len: %d", position, limit, length)); + } + + UnsafeUtil.copyMemory( + value, UnsafeUtil.getArrayBaseOffset() + offset, null, position, length); + position += length; + } + + @Override + public void writeLazy(byte[] value, int offset, int length) throws IOException { + write(value, offset, length); + } + + @Override + public void write(ByteBuffer value) throws IOException { + try { + int length = value.remaining(); + repositionBuffer(position); + buffer.put(value); + position += length; + } catch (BufferOverflowException e) { + throw new OutOfSpaceException(e); + } + } + + @Override + public void writeLazy(ByteBuffer value) throws IOException { + write(value); + } + + @Override + public void writeStringNoTag(String value) throws IOException { + long prevPos = position; + try { + // UTF-8 byte length of the string is at least its UTF-16 code unit length (value.length()), + // and at most 3 times of it. We take advantage of this in both branches below. + int maxEncodedSize = value.length() * Utf8.MAX_BYTES_PER_CHAR; + int maxLengthVarIntSize = computeUInt32SizeNoTag(maxEncodedSize); + int minLengthVarIntSize = computeUInt32SizeNoTag(value.length()); + if (minLengthVarIntSize == maxLengthVarIntSize) { + // Save the current position and increment past the length field. We'll come back + // and write the length field after the encoding is complete. + int stringStart = bufferPos(position) + minLengthVarIntSize; + buffer.position(stringStart); + + // Encode the string. + Utf8.encodeUtf8(value, buffer); + + // Write the length and advance the position. + int length = buffer.position() - stringStart; + writeUInt32NoTag(length); + position += length; + } else { + // Calculate and write the encoded length. + int length = Utf8.encodedLength(value); + writeUInt32NoTag(length); + + // Write the string and advance the position. + repositionBuffer(position); + Utf8.encodeUtf8(value, buffer); + position += length; + } + } catch (UnpairedSurrogateException e) { + // Roll back the change and convert to an IOException. + position = prevPos; + repositionBuffer(position); + + // TODO(nathanmittler): We should throw an IOException here instead. + inefficientWriteStringNoTag(value, e); + } catch (IllegalArgumentException e) { + // Thrown by buffer.position() if out of range. + throw new OutOfSpaceException(e); + } catch (IndexOutOfBoundsException e) { + throw new OutOfSpaceException(e); + } + } + + @Override + public void flush() { + // Update the position of the original buffer. + originalBuffer.position(bufferPos(position)); + } + + @Override + public int spaceLeft() { + return (int) (limit - position); + } + + @Override + public int getTotalBytesWritten() { + return (int) (position - initialPosition); + } + + private void repositionBuffer(long pos) { + buffer.position(bufferPos(pos)); + } + + private int bufferPos(long pos) { + return (int) (pos - address); + } + } + /** * Abstract base class for buffered encoders. */ diff --git a/java/core/src/main/java/com/google/protobuf/Descriptors.java b/java/core/src/main/java/com/google/protobuf/Descriptors.java index 1c34c24f..cab7118a 100644 --- a/java/core/src/main/java/com/google/protobuf/Descriptors.java +++ b/java/core/src/main/java/com/google/protobuf/Descriptors.java @@ -1212,33 +1212,20 @@ public final class Descriptors { private final Object defaultDefault; } - // TODO(xiaofeng): Implement it consistently across different languages. See b/24751348. - private static String fieldNameToLowerCamelCase(String name) { + // This method should match exactly with the ToJsonName() function in C++ + // descriptor.cc. + private static String fieldNameToJsonName(String name) { StringBuilder result = new StringBuilder(name.length()); boolean isNextUpperCase = false; for (int i = 0; i < name.length(); i++) { Character ch = name.charAt(i); - if (Character.isLowerCase(ch)) { - if (isNextUpperCase) { - result.append(Character.toUpperCase(ch)); - } else { - result.append(ch); - } - isNextUpperCase = false; - } else if (Character.isUpperCase(ch)) { - if (i == 0) { - // Force first letter to lower-case. - result.append(Character.toLowerCase(ch)); - } else { - // Capital letters after the first are left as-is. - result.append(ch); - } - isNextUpperCase = false; - } else if (Character.isDigit(ch)) { - result.append(ch); + if (ch == '_') { + isNextUpperCase = true; + } else if (isNextUpperCase) { + result.append(Character.toUpperCase(ch)); isNextUpperCase = false; } else { - isNextUpperCase = true; + result.append(ch); } } return result.toString(); @@ -1257,7 +1244,7 @@ public final class Descriptors { if (proto.hasJsonName()) { jsonName = proto.getJsonName(); } else { - jsonName = fieldNameToLowerCamelCase(proto.getName()); + jsonName = fieldNameToJsonName(proto.getName()); } if (proto.hasType()) { diff --git a/java/core/src/main/java/com/google/protobuf/GeneratedMessageV3.java b/java/core/src/main/java/com/google/protobuf/GeneratedMessageV3.java index 5dfe7ff7..5c0b9967 100644 --- a/java/core/src/main/java/com/google/protobuf/GeneratedMessageV3.java +++ b/java/core/src/main/java/com/google/protobuf/GeneratedMessageV3.java @@ -36,7 +36,6 @@ import com.google.protobuf.Descriptors.EnumValueDescriptor; import com.google.protobuf.Descriptors.FieldDescriptor; import com.google.protobuf.Descriptors.FileDescriptor; import com.google.protobuf.Descriptors.OneofDescriptor; -import com.google.protobuf.GeneratedMessage.GeneratedExtension; import java.io.IOException; import java.io.InputStream; @@ -45,6 +44,7 @@ import java.io.Serializable; import java.lang.reflect.InvocationTargetException; import java.lang.reflect.Method; import java.util.ArrayList; +import java.util.Arrays; import java.util.Collections; import java.util.Iterator; import java.util.List; @@ -1609,6 +1609,43 @@ public abstract class GeneratedMessageV3 extends AbstractMessage FieldDescriptor getDescriptor(); } + /** For use by generated code only. */ + public static + GeneratedExtension + newMessageScopedGeneratedExtension(final Message scope, + final int descriptorIndex, + final Class singularType, + final Message defaultInstance) { + // For extensions scoped within a Message, we use the Message to resolve + // the outer class's descriptor, from which the extension descriptor is + // obtained. + return new GeneratedExtension( + new CachedDescriptorRetriever() { + @Override + public FieldDescriptor loadDescriptor() { + return scope.getDescriptorForType().getExtensions().get(descriptorIndex); + } + }, + singularType, + defaultInstance, + Extension.ExtensionType.IMMUTABLE); + } + + /** For use by generated code only. */ + public static + GeneratedExtension + newFileScopedGeneratedExtension(final Class singularType, + final Message defaultInstance) { + // For extensions scoped within a file, we rely on the outer class's + // static initializer to call internalInit() on the extension when the + // descriptor is available. + return new GeneratedExtension( + null, // ExtensionDescriptorRetriever is initialized in internalInit(); + singularType, + defaultInstance, + Extension.ExtensionType.IMMUTABLE); + } + private abstract static class CachedDescriptorRetriever implements ExtensionDescriptorRetriever { private volatile FieldDescriptor descriptor; @@ -1627,6 +1664,301 @@ public abstract class GeneratedMessageV3 extends AbstractMessage } } + /** + * Used in proto1 generated code only. + * + * After enabling bridge, we can define proto2 extensions (the extended type + * is a proto2 mutable message) in a proto1 .proto file. For these extensions + * we should generate proto2 GeneratedExtensions. + */ + public static + GeneratedExtension + newMessageScopedGeneratedExtension( + final Message scope, final String name, + final Class singularType, final Message defaultInstance) { + // For extensions scoped within a Message, we use the Message to resolve + // the outer class's descriptor, from which the extension descriptor is + // obtained. + return new GeneratedExtension( + new CachedDescriptorRetriever() { + @Override + protected FieldDescriptor loadDescriptor() { + return scope.getDescriptorForType().findFieldByName(name); + } + }, + singularType, + defaultInstance, + Extension.ExtensionType.MUTABLE); + } + + /** + * Used in proto1 generated code only. + * + * After enabling bridge, we can define proto2 extensions (the extended type + * is a proto2 mutable message) in a proto1 .proto file. For these extensions + * we should generate proto2 GeneratedExtensions. + */ + public static + GeneratedExtension + newFileScopedGeneratedExtension( + final Class singularType, final Message defaultInstance, + final String descriptorOuterClass, final String extensionName) { + // For extensions scoped within a file, we load the descriptor outer + // class and rely on it to get the FileDescriptor which then can be + // used to obtain the extension's FieldDescriptor. + return new GeneratedExtension( + new CachedDescriptorRetriever() { + @Override + protected FieldDescriptor loadDescriptor() { + try { + Class clazz = singularType.getClassLoader().loadClass(descriptorOuterClass); + FileDescriptor file = (FileDescriptor) clazz.getField("descriptor").get(null); + return file.findExtensionByName(extensionName); + } catch (Exception e) { + throw new RuntimeException( + "Cannot load descriptors: " + + descriptorOuterClass + + " is not a valid descriptor class name", + e); + } + } + }, + singularType, + defaultInstance, + Extension.ExtensionType.MUTABLE); + } + + /** + * Type used to represent generated extensions. The protocol compiler + * generates a static singleton instance of this class for each extension. + * + *

For example, imagine you have the {@code .proto} file: + * + *

+   * option java_class = "MyProto";
+   *
+   * message Foo {
+   *   extensions 1000 to max;
+   * }
+   *
+   * extend Foo {
+   *   optional int32 bar;
+   * }
+   * 
+ * + *

Then, {@code MyProto.Foo.bar} has type + * {@code GeneratedExtension}. + * + *

In general, users should ignore the details of this type, and simply use + * these static singletons as parameters to the extension accessors defined + * in {@link ExtendableMessage} and {@link ExtendableBuilder}. + */ + public static class GeneratedExtension< + ContainingType extends Message, Type> extends + Extension { + // TODO(kenton): Find ways to avoid using Java reflection within this + // class. Also try to avoid suppressing unchecked warnings. + + // We can't always initialize the descriptor of a GeneratedExtension when + // we first construct it due to initialization order difficulties (namely, + // the descriptor may not have been constructed yet, since it is often + // constructed by the initializer of a separate module). + // + // In the case of nested extensions, we initialize the + // ExtensionDescriptorRetriever with an instance that uses the scoping + // Message's default instance to retrieve the extension's descriptor. + // + // In the case of non-nested extensions, we initialize the + // ExtensionDescriptorRetriever to null and rely on the outer class's static + // initializer to call internalInit() after the descriptor has been parsed. + GeneratedExtension(ExtensionDescriptorRetriever descriptorRetriever, + Class singularType, + Message messageDefaultInstance, + ExtensionType extensionType) { + if (Message.class.isAssignableFrom(singularType) && + !singularType.isInstance(messageDefaultInstance)) { + throw new IllegalArgumentException( + "Bad messageDefaultInstance for " + singularType.getName()); + } + this.descriptorRetriever = descriptorRetriever; + this.singularType = singularType; + this.messageDefaultInstance = messageDefaultInstance; + + if (ProtocolMessageEnum.class.isAssignableFrom(singularType)) { + this.enumValueOf = getMethodOrDie(singularType, "valueOf", + EnumValueDescriptor.class); + this.enumGetValueDescriptor = + getMethodOrDie(singularType, "getValueDescriptor"); + } else { + this.enumValueOf = null; + this.enumGetValueDescriptor = null; + } + this.extensionType = extensionType; + } + + /** For use by generated code only. */ + public void internalInit(final FieldDescriptor descriptor) { + if (descriptorRetriever != null) { + throw new IllegalStateException("Already initialized."); + } + descriptorRetriever = + new ExtensionDescriptorRetriever() { + @Override + public FieldDescriptor getDescriptor() { + return descriptor; + } + }; + } + + private ExtensionDescriptorRetriever descriptorRetriever; + private final Class singularType; + private final Message messageDefaultInstance; + private final Method enumValueOf; + private final Method enumGetValueDescriptor; + private final ExtensionType extensionType; + + @Override + public FieldDescriptor getDescriptor() { + if (descriptorRetriever == null) { + throw new IllegalStateException( + "getDescriptor() called before internalInit()"); + } + return descriptorRetriever.getDescriptor(); + } + + /** + * If the extension is an embedded message or group, returns the default + * instance of the message. + */ + @Override + public Message getMessageDefaultInstance() { + return messageDefaultInstance; + } + + @Override + protected ExtensionType getExtensionType() { + return extensionType; + } + + /** + * Convert from the type used by the reflection accessors to the type used + * by native accessors. E.g., for enums, the reflection accessors use + * EnumValueDescriptors but the native accessors use the generated enum + * type. + */ + @Override + @SuppressWarnings("unchecked") + protected Object fromReflectionType(final Object value) { + FieldDescriptor descriptor = getDescriptor(); + if (descriptor.isRepeated()) { + if (descriptor.getJavaType() == FieldDescriptor.JavaType.MESSAGE || + descriptor.getJavaType() == FieldDescriptor.JavaType.ENUM) { + // Must convert the whole list. + final List result = new ArrayList(); + for (final Object element : (List) value) { + result.add(singularFromReflectionType(element)); + } + return result; + } else { + return value; + } + } else { + return singularFromReflectionType(value); + } + } + + /** + * Like {@link #fromReflectionType(Object)}, but if the type is a repeated + * type, this converts a single element. + */ + @Override + protected Object singularFromReflectionType(final Object value) { + FieldDescriptor descriptor = getDescriptor(); + switch (descriptor.getJavaType()) { + case MESSAGE: + if (singularType.isInstance(value)) { + return value; + } else { + return messageDefaultInstance.newBuilderForType() + .mergeFrom((Message) value).build(); + } + case ENUM: + return invokeOrDie(enumValueOf, null, (EnumValueDescriptor) value); + default: + return value; + } + } + + /** + * Convert from the type used by the native accessors to the type used + * by reflection accessors. E.g., for enums, the reflection accessors use + * EnumValueDescriptors but the native accessors use the generated enum + * type. + */ + @Override + @SuppressWarnings("unchecked") + protected Object toReflectionType(final Object value) { + FieldDescriptor descriptor = getDescriptor(); + if (descriptor.isRepeated()) { + if (descriptor.getJavaType() == FieldDescriptor.JavaType.ENUM) { + // Must convert the whole list. + final List result = new ArrayList(); + for (final Object element : (List) value) { + result.add(singularToReflectionType(element)); + } + return result; + } else { + return value; + } + } else { + return singularToReflectionType(value); + } + } + + /** + * Like {@link #toReflectionType(Object)}, but if the type is a repeated + * type, this converts a single element. + */ + @Override + protected Object singularToReflectionType(final Object value) { + FieldDescriptor descriptor = getDescriptor(); + switch (descriptor.getJavaType()) { + case ENUM: + return invokeOrDie(enumGetValueDescriptor, value); + default: + return value; + } + } + + @Override + public int getNumber() { + return getDescriptor().getNumber(); + } + + @Override + public WireFormat.FieldType getLiteType() { + return getDescriptor().getLiteType(); + } + + @Override + public boolean isRepeated() { + return getDescriptor().isRepeated(); + } + + @Override + @SuppressWarnings("unchecked") + public Type getDefaultValue() { + if (isRepeated()) { + return (Type) Collections.emptyList(); + } + if (getDescriptor().getJavaType() == FieldDescriptor.JavaType.MESSAGE) { + return (Type) messageDefaultInstance; + } + return (Type) singularFromReflectionType( + getDescriptor().getDefaultValue()); + } + } + // ================================================================= /** Calls Class.getMethod and throws a RuntimeException if it fails. */ @@ -2223,6 +2555,20 @@ public abstract class GeneratedMessageV3 extends AbstractMessage field.getNumber()); } + private Message coerceType(Message value) { + if (value == null) { + return null; + } + if (mapEntryMessageDefaultInstance.getClass().isInstance(value)) { + return value; + } + // The value is not the exact right message type. However, if it + // is an alternative implementation of the same type -- e.g. a + // DynamicMessage -- we should accept it. In this case we can make + // a copy of the message. + return mapEntryMessageDefaultInstance.toBuilder().mergeFrom(value).build(); + } + @Override public Object get(GeneratedMessageV3 message) { List result = new ArrayList(); @@ -2278,15 +2624,15 @@ public abstract class GeneratedMessageV3 extends AbstractMessage public Object getRepeatedRaw(Builder builder, int index) { return getRepeated(builder, index); } - + @Override public void setRepeated(Builder builder, int index, Object value) { - getMutableMapField(builder).getMutableList().set(index, (Message) value); + getMutableMapField(builder).getMutableList().set(index, coerceType((Message) value)); } @Override public void addRepeated(Builder builder, Object value) { - getMutableMapField(builder).getMutableList().add((Message) value); + getMutableMapField(builder).getMutableList().add(coerceType((Message) value)); } @Override @@ -2713,4 +3059,131 @@ public abstract class GeneratedMessageV3 extends AbstractMessage output.writeBytesNoTag((ByteString) value); } } + + protected static void serializeIntegerMapTo( + CodedOutputStream out, + MapField field, + MapEntry defaultEntry, + int fieldNumber) throws IOException { + Map m = field.getMap(); + if (!out.isSerializationDeterministic()) { + serializeMapTo(out, m, defaultEntry, fieldNumber); + return; + } + // Sorting the unboxed keys and then look up the values during serialziation is 2x faster + // than sorting map entries with a custom comparator directly. + int[] keys = new int[m.size()]; + int index = 0; + for (int k : m.keySet()) { + keys[index++] = k; + } + Arrays.sort(keys); + for (int key : keys) { + out.writeMessage(fieldNumber, + defaultEntry.newBuilderForType() + .setKey(key) + .setValue(m.get(key)) + .build()); + } + } + + protected static void serializeLongMapTo( + CodedOutputStream out, + MapField field, + MapEntry defaultEntry, + int fieldNumber) + throws IOException { + Map m = field.getMap(); + if (!out.isSerializationDeterministic()) { + serializeMapTo(out, m, defaultEntry, fieldNumber); + return; + } + + long[] keys = new long[m.size()]; + int index = 0; + for (long k : m.keySet()) { + keys[index++] = k; + } + Arrays.sort(keys); + for (long key : keys) { + out.writeMessage(fieldNumber, + defaultEntry.newBuilderForType() + .setKey(key) + .setValue(m.get(key)) + .build()); + } + } + + protected static void serializeStringMapTo( + CodedOutputStream out, + MapField field, + MapEntry defaultEntry, + int fieldNumber) + throws IOException { + Map m = field.getMap(); + if (!out.isSerializationDeterministic()) { + serializeMapTo(out, m, defaultEntry, fieldNumber); + return; + } + + // Sorting the String keys and then look up the values during serialziation is 25% faster than + // sorting map entries with a custom comparator directly. + String[] keys = new String[m.size()]; + keys = m.keySet().toArray(keys); + Arrays.sort(keys); + for (String key : keys) { + out.writeMessage(fieldNumber, + defaultEntry.newBuilderForType() + .setKey(key) + .setValue(m.get(key)) + .build()); + } + } + + protected static void serializeBooleanMapTo( + CodedOutputStream out, + MapField field, + MapEntry defaultEntry, + int fieldNumber) + throws IOException { + Map m = field.getMap(); + if (!out.isSerializationDeterministic()) { + serializeMapTo(out, m, defaultEntry, fieldNumber); + return; + } + maybeSerializeBooleanEntryTo(out, m, defaultEntry, fieldNumber, false); + maybeSerializeBooleanEntryTo(out, m, defaultEntry, fieldNumber, true); + } + + private static void maybeSerializeBooleanEntryTo( + CodedOutputStream out, + Map m, + MapEntry defaultEntry, + int fieldNumber, + boolean key) + throws IOException { + if (m.containsKey(key)) { + out.writeMessage(fieldNumber, + defaultEntry.newBuilderForType() + .setKey(key) + .setValue(m.get(key)) + .build()); + } + } + + /** Serialize the map using the iteration order. */ + private static void serializeMapTo( + CodedOutputStream out, + Map m, + MapEntry defaultEntry, + int fieldNumber) + throws IOException { + for (Map.Entry entry : m.entrySet()) { + out.writeMessage(fieldNumber, + defaultEntry.newBuilderForType() + .setKey(entry.getKey()) + .setValue(entry.getValue()) + .build()); + } + } } diff --git a/java/core/src/main/java/com/google/protobuf/Internal.java b/java/core/src/main/java/com/google/protobuf/Internal.java index 3b4a0412..c234559c 100644 --- a/java/core/src/main/java/com/google/protobuf/Internal.java +++ b/java/core/src/main/java/com/google/protobuf/Internal.java @@ -59,6 +59,16 @@ public final class Internal { static final Charset UTF_8 = Charset.forName("UTF-8"); static final Charset ISO_8859_1 = Charset.forName("ISO-8859-1"); + /** + * Throws an appropriate {@link NullPointerException} if the given objects is {@code null}. + */ + static T checkNotNull(T obj, String message) { + if (obj == null) { + throw new NullPointerException(message); + } + return obj; + } + /** * Helper called by generated code to construct default values for string * fields. diff --git a/java/core/src/main/java/com/google/protobuf/InvalidProtocolBufferException.java b/java/core/src/main/java/com/google/protobuf/InvalidProtocolBufferException.java index 85ce7b24..55e33d21 100644 --- a/java/core/src/main/java/com/google/protobuf/InvalidProtocolBufferException.java +++ b/java/core/src/main/java/com/google/protobuf/InvalidProtocolBufferException.java @@ -107,11 +107,23 @@ public class InvalidProtocolBufferException extends IOException { "Protocol message end-group tag did not match expected tag."); } - static InvalidProtocolBufferException invalidWireType() { - return new InvalidProtocolBufferException( + static InvalidWireTypeException invalidWireType() { + return new InvalidWireTypeException( "Protocol message tag had invalid wire type."); } + /** + * Exception indicating that and unexpected wire type was encountered for a field. + */ + @ExperimentalApi + public static class InvalidWireTypeException extends InvalidProtocolBufferException { + private static final long serialVersionUID = 3283890091615336259L; + + public InvalidWireTypeException(String description) { + super(description); + } + } + static InvalidProtocolBufferException recursionLimitExceeded() { return new InvalidProtocolBufferException( "Protocol message had too many levels of nesting. May be malicious. " + diff --git a/java/core/src/main/java/com/google/protobuf/MapEntry.java b/java/core/src/main/java/com/google/protobuf/MapEntry.java index 117cd911..4f0351f4 100644 --- a/java/core/src/main/java/com/google/protobuf/MapEntry.java +++ b/java/core/src/main/java/com/google/protobuf/MapEntry.java @@ -334,6 +334,15 @@ public final class MapEntry extends AbstractMessage { } else { if (field.getType() == FieldDescriptor.Type.ENUM) { value = ((EnumValueDescriptor) value).getNumber(); + } else if (field.getType() == FieldDescriptor.Type.MESSAGE) { + if (value != null && !metadata.defaultValue.getClass().isInstance(value)) { + // The value is not the exact right message type. However, if it + // is an alternative implementation of the same type -- e.g. a + // DynamicMessage -- we should accept it. In this case we can make + // a copy of the message. + value = + ((Message) metadata.defaultValue).toBuilder().mergeFrom((Message) value).build(); + } } setValue((V) value); } diff --git a/java/core/src/main/java/com/google/protobuf/NioByteString.java b/java/core/src/main/java/com/google/protobuf/NioByteString.java index 6163c7b1..76594809 100644 --- a/java/core/src/main/java/com/google/protobuf/NioByteString.java +++ b/java/core/src/main/java/com/google/protobuf/NioByteString.java @@ -30,6 +30,8 @@ package com.google.protobuf; +import static com.google.protobuf.Internal.checkNotNull; + import java.io.IOException; import java.io.InputStream; import java.io.InvalidObjectException; @@ -49,10 +51,9 @@ final class NioByteString extends ByteString.LeafByteString { private final ByteBuffer buffer; NioByteString(ByteBuffer buffer) { - if (buffer == null) { - throw new NullPointerException("buffer"); - } + checkNotNull(buffer, "buffer"); + // Use native byte order for fast fixed32/64 operations. this.buffer = buffer.slice().order(ByteOrder.nativeOrder()); } @@ -266,7 +267,7 @@ final class NioByteString extends ByteString.LeafByteString { @Override public CodedInputStream newCodedInput() { - return CodedInputStream.newInstance(buffer); + return CodedInputStream.newInstance(buffer, true); } /** diff --git a/java/core/src/main/java/com/google/protobuf/RopeByteString.java b/java/core/src/main/java/com/google/protobuf/RopeByteString.java index 3f3e9bd1..6fa555df 100644 --- a/java/core/src/main/java/com/google/protobuf/RopeByteString.java +++ b/java/core/src/main/java/com/google/protobuf/RopeByteString.java @@ -406,6 +406,7 @@ final class RopeByteString extends ByteString { right.writeTo(output); } + @Override protected String toStringInternal(Charset charset) { return new String(toByteArray(), charset); diff --git a/java/core/src/main/java/com/google/protobuf/TextFormat.java b/java/core/src/main/java/com/google/protobuf/TextFormat.java index ff13675d..49708242 100644 --- a/java/core/src/main/java/com/google/protobuf/TextFormat.java +++ b/java/core/src/main/java/com/google/protobuf/TextFormat.java @@ -1576,14 +1576,22 @@ public final class TextFormat { // Support specifying repeated field values as a comma-separated list. // Ex."foo: [1, 2, 3]" if (field.isRepeated() && tokenizer.tryConsume("[")) { - while (true) { - consumeFieldValue(tokenizer, extensionRegistry, target, field, extension, - parseTreeBuilder, unknownFields); - if (tokenizer.tryConsume("]")) { - // End of list. - break; + if (!tokenizer.tryConsume("]")) { // Allow "foo: []" to be treated as empty. + while (true) { + consumeFieldValue( + tokenizer, + extensionRegistry, + target, + field, + extension, + parseTreeBuilder, + unknownFields); + if (tokenizer.tryConsume("]")) { + // End of list. + break; + } + tokenizer.consume(","); } - tokenizer.consume(","); } } else { consumeFieldValue(tokenizer, extensionRegistry, target, field, diff --git a/java/core/src/main/java/com/google/protobuf/TextFormatParseInfoTree.java b/java/core/src/main/java/com/google/protobuf/TextFormatParseInfoTree.java index 5c43b2c3..0127ce92 100644 --- a/java/core/src/main/java/com/google/protobuf/TextFormatParseInfoTree.java +++ b/java/core/src/main/java/com/google/protobuf/TextFormatParseInfoTree.java @@ -45,7 +45,8 @@ import java.util.Map.Entry; * *

The locations of primary fields values are retrieved by {@code getLocation} or * {@code getLocations}. The locations of sub message values are within nested - * {@code TextFormatParseInfoTree}s and are retrieve by {@code getNestedTree} or {@code getNestedTrees}. + * {@code TextFormatParseInfoTree}s and are retrieve by {@code getNestedTree} or + * {@code getNestedTrees}. * *

The {@code TextFormatParseInfoTree} is created by a Builder. */ diff --git a/java/core/src/main/java/com/google/protobuf/UnsafeByteOperations.java b/java/core/src/main/java/com/google/protobuf/UnsafeByteOperations.java index e72a6b4d..878c7758 100644 --- a/java/core/src/main/java/com/google/protobuf/UnsafeByteOperations.java +++ b/java/core/src/main/java/com/google/protobuf/UnsafeByteOperations.java @@ -64,6 +64,29 @@ import java.nio.ByteBuffer; public final class UnsafeByteOperations { private UnsafeByteOperations() {} + /** + * An unsafe operation that returns a {@link ByteString} that is backed by the provided buffer. + * + * @param buffer the buffer to be wrapped + * @return a {@link ByteString} backed by the provided buffer + */ + public static ByteString unsafeWrap(byte[] buffer) { + return ByteString.wrap(buffer); + } + + /** + * An unsafe operation that returns a {@link ByteString} that is backed by a subregion of the + * provided buffer. + * + * @param buffer the buffer to be wrapped + * @param offset the offset of the wrapped region + * @param length the number of bytes of the wrapped region + * @return a {@link ByteString} backed by the provided buffer + */ + public static ByteString unsafeWrap(byte[] buffer, int offset, int length) { + return ByteString.wrap(buffer, offset, length); + } + /** * An unsafe operation that returns a {@link ByteString} that is backed by the provided buffer. * @@ -71,12 +94,7 @@ public final class UnsafeByteOperations { * @return a {@link ByteString} backed by the provided buffer */ public static ByteString unsafeWrap(ByteBuffer buffer) { - if (buffer.hasArray()) { - final int offset = buffer.arrayOffset(); - return ByteString.wrap(buffer.array(), offset + buffer.position(), buffer.remaining()); - } else { - return new NioByteString(buffer); - } + return ByteString.wrap(buffer); } /** @@ -98,4 +116,5 @@ public final class UnsafeByteOperations { public static void unsafeWriteTo(ByteString bytes, ByteOutput output) throws IOException { bytes.writeTo(output); } + } diff --git a/java/core/src/main/java/com/google/protobuf/UnsafeUtil.java b/java/core/src/main/java/com/google/protobuf/UnsafeUtil.java index 6a4787d1..5f7bafd6 100644 --- a/java/core/src/main/java/com/google/protobuf/UnsafeUtil.java +++ b/java/core/src/main/java/com/google/protobuf/UnsafeUtil.java @@ -30,17 +30,14 @@ package com.google.protobuf; -import sun.misc.Unsafe; - import java.lang.reflect.Field; import java.nio.Buffer; import java.nio.ByteBuffer; import java.security.AccessController; import java.security.PrivilegedExceptionAction; +import sun.misc.Unsafe; -/** - * Utility class for working with unsafe operations. - */ +/** Utility class for working with unsafe operations. */ // TODO(nathanmittler): Add support for Android Memory/MemoryBlock final class UnsafeUtil { private static final sun.misc.Unsafe UNSAFE = getUnsafe(); @@ -50,8 +47,7 @@ final class UnsafeUtil { private static final long ARRAY_BASE_OFFSET = byteArrayBaseOffset(); private static final long BUFFER_ADDRESS_OFFSET = fieldOffset(field(Buffer.class, "address")); - private UnsafeUtil() { - } + private UnsafeUtil() {} static boolean hasUnsafeArrayOperations() { return HAS_UNSAFE_ARRAY_OPERATIONS; @@ -61,27 +57,83 @@ final class UnsafeUtil { return HAS_UNSAFE_BYTEBUFFER_OPERATIONS; } + static Object allocateInstance(Class clazz) { + try { + return UNSAFE.allocateInstance(clazz); + } catch (InstantiationException e) { + throw new RuntimeException(e); + } + } + + static long objectFieldOffset(Field field) { + return UNSAFE.objectFieldOffset(field); + } + static long getArrayBaseOffset() { return ARRAY_BASE_OFFSET; } - static byte getByte(byte[] target, long offset) { + static byte getByte(Object target, long offset) { return UNSAFE.getByte(target, offset); } - static void putByte(byte[] target, long offset, byte value) { + static void putByte(Object target, long offset, byte value) { UNSAFE.putByte(target, offset, value); } - static void copyMemory( - byte[] src, long srcOffset, byte[] target, long targetOffset, long length) { - UNSAFE.copyMemory(src, srcOffset, target, targetOffset, length); + static int getInt(Object target, long offset) { + return UNSAFE.getInt(target, offset); } - static long getLong(byte[] target, long offset) { + static void putInt(Object target, long offset, int value) { + UNSAFE.putInt(target, offset, value); + } + + static long getLong(Object target, long offset) { return UNSAFE.getLong(target, offset); } + static void putLong(Object target, long offset, long value) { + UNSAFE.putLong(target, offset, value); + } + + static boolean getBoolean(Object target, long offset) { + return UNSAFE.getBoolean(target, offset); + } + + static void putBoolean(Object target, long offset, boolean value) { + UNSAFE.putBoolean(target, offset, value); + } + + static float getFloat(Object target, long offset) { + return UNSAFE.getFloat(target, offset); + } + + static void putFloat(Object target, long offset, float value) { + UNSAFE.putFloat(target, offset, value); + } + + static double getDouble(Object target, long offset) { + return UNSAFE.getDouble(target, offset); + } + + static void putDouble(Object target, long offset, double value) { + UNSAFE.putDouble(target, offset, value); + } + + static Object getObject(Object target, long offset) { + return UNSAFE.getObject(target, offset); + } + + static void putObject(Object target, long offset, Object value) { + UNSAFE.putObject(target, offset, value); + } + + static void copyMemory( + Object src, long srcOffset, Object target, long targetOffset, long length) { + UNSAFE.copyMemory(src, srcOffset, target, targetOffset, length); + } + static byte getByte(long address) { return UNSAFE.getByte(address); } @@ -90,14 +142,30 @@ final class UnsafeUtil { UNSAFE.putByte(address, value); } + static int getInt(long address) { + return UNSAFE.getInt(address); + } + + static void putInt(long address, int value) { + UNSAFE.putInt(address, value); + } + static long getLong(long address) { return UNSAFE.getLong(address); } + static void putLong(long address, long value) { + UNSAFE.putLong(address, value); + } + static void copyMemory(long srcAddress, long targetAddress, long length) { UNSAFE.copyMemory(srcAddress, targetAddress, length); } + static void setMemory(long address, long numBytes, byte value) { + UNSAFE.setMemory(address, numBytes, value); + } + /** * Gets the offset of the {@code address} field of the given direct {@link ByteBuffer}. */ @@ -136,18 +204,29 @@ final class UnsafeUtil { return unsafe; } - /** - * Indicates whether or not unsafe array operations are supported on this platform. - */ + /** Indicates whether or not unsafe array operations are supported on this platform. */ private static boolean supportsUnsafeArrayOperations() { boolean supported = false; if (UNSAFE != null) { try { Class clazz = UNSAFE.getClass(); + clazz.getMethod("objectFieldOffset", Field.class); + clazz.getMethod("allocateInstance", Class.class); clazz.getMethod("arrayBaseOffset", Class.class); clazz.getMethod("getByte", Object.class, long.class); clazz.getMethod("putByte", Object.class, long.class, byte.class); + clazz.getMethod("getBoolean", Object.class, long.class); + clazz.getMethod("putBoolean", Object.class, long.class, boolean.class); + clazz.getMethod("getInt", Object.class, long.class); + clazz.getMethod("putInt", Object.class, long.class, int.class); clazz.getMethod("getLong", Object.class, long.class); + clazz.getMethod("putLong", Object.class, long.class, long.class); + clazz.getMethod("getFloat", Object.class, long.class); + clazz.getMethod("putFloat", Object.class, long.class, float.class); + clazz.getMethod("getDouble", Object.class, long.class); + clazz.getMethod("putDouble", Object.class, long.class, double.class); + clazz.getMethod("getObject", Object.class, long.class); + clazz.getMethod("putObject", Object.class, long.class, Object.class); clazz.getMethod( "copyMemory", Object.class, long.class, Object.class, long.class, long.class); supported = true; @@ -163,11 +242,17 @@ final class UnsafeUtil { if (UNSAFE != null) { try { Class clazz = UNSAFE.getClass(); + // Methods for getting direct buffer address. clazz.getMethod("objectFieldOffset", Field.class); - clazz.getMethod("getByte", long.class); clazz.getMethod("getLong", Object.class, long.class); + + clazz.getMethod("getByte", long.class); clazz.getMethod("putByte", long.class, byte.class); + clazz.getMethod("getInt", long.class); + clazz.getMethod("putInt", long.class, int.class); clazz.getMethod("getLong", long.class); + clazz.getMethod("putLong", long.class, long.class); + clazz.getMethod("setMemory", long.class, long.class, byte.class); clazz.getMethod("copyMemory", long.class, long.class, long.class); supported = true; } catch (Throwable e) { diff --git a/java/core/src/main/java/com/google/protobuf/WireFormat.java b/java/core/src/main/java/com/google/protobuf/WireFormat.java index 6a58b081..0b0cdb7d 100644 --- a/java/core/src/main/java/com/google/protobuf/WireFormat.java +++ b/java/core/src/main/java/com/google/protobuf/WireFormat.java @@ -47,6 +47,10 @@ public final class WireFormat { // Do not allow instantiation. private WireFormat() {} + static final int FIXED_32_SIZE = 4; + static final int FIXED_64_SIZE = 8; + static final int MAX_VARINT_SIZE = 10; + public static final int WIRETYPE_VARINT = 0; public static final int WIRETYPE_FIXED64 = 1; public static final int WIRETYPE_LENGTH_DELIMITED = 2; -- cgit v1.2.3