From 13fd045dbb2b4dacea32be162a41d5a4b0d1802f Mon Sep 17 00:00:00 2001 From: Adam Cozzette Date: Tue, 12 Sep 2017 10:32:01 -0700 Subject: Integrated internal changes from Google --- .../java/com/google/protobuf/CodedInputStream.java | 882 +++++++++++++++++++++ .../com/google/protobuf/GeneratedMessageLite.java | 59 +- .../protobuf/IterableByteBufferInputStream.java | 150 ++++ .../java/com/google/protobuf/SmallSortedMap.java | 4 +- .../main/java/com/google/protobuf/TextFormat.java | 18 +- .../com/google/protobuf/UnknownFieldSetLite.java | 3 +- .../main/java/com/google/protobuf/UnsafeUtil.java | 11 - 7 files changed, 1091 insertions(+), 36 deletions(-) create mode 100644 java/core/src/main/java/com/google/protobuf/IterableByteBufferInputStream.java (limited to 'java/core/src/main') diff --git a/java/core/src/main/java/com/google/protobuf/CodedInputStream.java b/java/core/src/main/java/com/google/protobuf/CodedInputStream.java index d6a941b1..511501d4 100644 --- a/java/core/src/main/java/com/google/protobuf/CodedInputStream.java +++ b/java/core/src/main/java/com/google/protobuf/CodedInputStream.java @@ -44,6 +44,7 @@ import java.io.InputStream; import java.nio.ByteBuffer; import java.util.ArrayList; import java.util.Arrays; +import java.util.Iterator; import java.util.List; /** @@ -85,6 +86,43 @@ public abstract class CodedInputStream { return new StreamDecoder(input, bufferSize); } + /** Create a new CodedInputStream wrapping the given {@code Iterable }. */ + public static CodedInputStream newInstance(final Iterable input) { + if (!UnsafeDirectNioDecoder.isSupported()) { + return newInstance(new IterableByteBufferInputStream(input)); + } + return newInstance(input, false); + } + + /** Create a new CodedInputStream wrapping the given {@code Iterable }. */ + static CodedInputStream newInstance( + final Iterable bufs, final boolean bufferIsImmutable) { + // flag is to check the type of input's ByteBuffers. + // flag equals 1: all ByteBuffers have array. + // flag equals 2: all ByteBuffers are direct ByteBuffers. + // flag equals 3: some ByteBuffers are direct and some have array. + // flag greater than 3: other cases. + int flag = 0; + // Total size of the input + int totalSize = 0; + for (ByteBuffer buf : bufs) { + totalSize += buf.remaining(); + if (buf.hasArray()) { + flag |= 1; + } else if (buf.isDirect()) { + flag |= 2; + } else { + flag |= 4; + } + } + if (flag == 2) { + return new IterableDirectByteBufferDecoder(bufs, totalSize, bufferIsImmutable); + } else { + // TODO(yilunchong): add another decoders to deal case 1 and 3. + return newInstance(new IterableByteBufferInputStream(bufs)); + } + } + /** Create a new CodedInputStream wrapping the given byte array. */ public static CodedInputStream newInstance(final byte[] buf) { return newInstance(buf, 0, buf.length); @@ -3022,4 +3060,848 @@ public abstract class CodedInputStream { pos = size - tempPos; } } + + /** + * Implementation of {@link CodedInputStream} that uses an {@link Iterable } as the + * data source. Requires the use of {@code sun.misc.Unsafe} to perform fast reads on the buffer. + */ + private static final class IterableDirectByteBufferDecoder extends CodedInputStream { + /** The object that need to decode. */ + private Iterable input; + /** The {@link Iterator} with type {@link ByteBuffer} of {@code input} */ + private Iterator iterator; + /** The current ByteBuffer; */ + private ByteBuffer currentByteBuffer; + /** + * If {@code true}, indicates that all the buffer are backing a {@link ByteString} and are + * therefore considered to be an immutable input source. + */ + private boolean immutable; + /** + * If {@code true}, indicates that calls to read {@link ByteString} or {@code byte[]} + * may return slices of the underlying buffer, rather than copies. + */ + private boolean enableAliasing; + /** The global total message length limit */ + private int totalBufferSize; + /** The amount of available data in the input beyond {@link #currentLimit}. */ + private int bufferSizeAfterCurrentLimit; + /** The absolute position of the end of the current message. */ + private int currentLimit = Integer.MAX_VALUE; + /** The last tag that was read from this stream. */ + private int lastTag; + /** Total Bytes have been Read from the {@link Iterable} {@link ByteBuffer} */ + private int totalBytesRead; + /** The start position offset of the whole message, used as to reset the totalBytesRead */ + private int startOffset; + /** The current position for current ByteBuffer */ + private long currentByteBufferPos; + + private long currentByteBufferStartPos; + /** + * If the current ByteBuffer is unsafe-direct based, currentAddress is the start address of this + * ByteBuffer; otherwise should be zero. + */ + private long currentAddress; + /** The limit position for current ByteBuffer */ + private long currentByteBufferLimit; + + /** + * The constructor of {@code Iterable} decoder. + * + * @param inputBufs The input data. + * @param size The total size of the input data. + * @param immutableFlag whether the input data is immutable. + */ + private IterableDirectByteBufferDecoder( + Iterable inputBufs, int size, boolean immutableFlag) { + totalBufferSize = size; + input = inputBufs; + iterator = input.iterator(); + immutable = immutableFlag; + startOffset = totalBytesRead = 0; + if (size == 0) { + currentByteBuffer = EMPTY_BYTE_BUFFER; + currentByteBufferPos = 0; + currentByteBufferStartPos = 0; + currentByteBufferLimit = 0; + currentAddress = 0; + } else { + tryGetNextByteBuffer(); + } + } + + /** To get the next ByteBuffer from {@code input}, and then update the parameters */ + private void getNextByteBuffer() throws InvalidProtocolBufferException { + if (!iterator.hasNext()) { + throw InvalidProtocolBufferException.truncatedMessage(); + } + tryGetNextByteBuffer(); + } + + private void tryGetNextByteBuffer() { + currentByteBuffer = iterator.next(); + totalBytesRead += (int) (currentByteBufferPos - currentByteBufferStartPos); + currentByteBufferPos = currentByteBuffer.position(); + currentByteBufferStartPos = currentByteBufferPos; + currentByteBufferLimit = currentByteBuffer.limit(); + currentAddress = UnsafeUtil.addressOffset(currentByteBuffer); + currentByteBufferPos += currentAddress; + currentByteBufferStartPos += currentAddress; + currentByteBufferLimit += currentAddress; + } + + @Override + public int readTag() throws IOException { + if (isAtEnd()) { + lastTag = 0; + return 0; + } + + lastTag = readRawVarint32(); + if (WireFormat.getTagFieldNumber(lastTag) == 0) { + // If we actually read zero (or any tag number corresponding to field + // number zero), that's not a valid tag. + throw InvalidProtocolBufferException.invalidTag(); + } + return lastTag; + } + + @Override + public void checkLastTagWas(final int value) throws InvalidProtocolBufferException { + if (lastTag != value) { + throw InvalidProtocolBufferException.invalidEndTag(); + } + } + + @Override + public int getLastTag() { + return lastTag; + } + + @Override + public boolean skipField(final int tag) throws IOException { + switch (WireFormat.getTagWireType(tag)) { + case WireFormat.WIRETYPE_VARINT: + skipRawVarint(); + return true; + case WireFormat.WIRETYPE_FIXED64: + skipRawBytes(FIXED64_SIZE); + return true; + case WireFormat.WIRETYPE_LENGTH_DELIMITED: + skipRawBytes(readRawVarint32()); + return true; + case WireFormat.WIRETYPE_START_GROUP: + skipMessage(); + checkLastTagWas( + WireFormat.makeTag(WireFormat.getTagFieldNumber(tag), WireFormat.WIRETYPE_END_GROUP)); + return true; + case WireFormat.WIRETYPE_END_GROUP: + return false; + case WireFormat.WIRETYPE_FIXED32: + skipRawBytes(FIXED32_SIZE); + return true; + default: + throw InvalidProtocolBufferException.invalidWireType(); + } + } + + @Override + public boolean skipField(final int tag, final CodedOutputStream output) throws IOException { + switch (WireFormat.getTagWireType(tag)) { + case WireFormat.WIRETYPE_VARINT: + { + long value = readInt64(); + output.writeRawVarint32(tag); + output.writeUInt64NoTag(value); + return true; + } + case WireFormat.WIRETYPE_FIXED64: + { + long value = readRawLittleEndian64(); + output.writeRawVarint32(tag); + output.writeFixed64NoTag(value); + return true; + } + case WireFormat.WIRETYPE_LENGTH_DELIMITED: + { + ByteString value = readBytes(); + output.writeRawVarint32(tag); + output.writeBytesNoTag(value); + return true; + } + case WireFormat.WIRETYPE_START_GROUP: + { + output.writeRawVarint32(tag); + skipMessage(output); + int endtag = + WireFormat.makeTag( + WireFormat.getTagFieldNumber(tag), WireFormat.WIRETYPE_END_GROUP); + checkLastTagWas(endtag); + output.writeRawVarint32(endtag); + return true; + } + case WireFormat.WIRETYPE_END_GROUP: + { + return false; + } + case WireFormat.WIRETYPE_FIXED32: + { + int value = readRawLittleEndian32(); + output.writeRawVarint32(tag); + output.writeFixed32NoTag(value); + return true; + } + default: + throw InvalidProtocolBufferException.invalidWireType(); + } + } + + @Override + public void skipMessage() throws IOException { + while (true) { + final int tag = readTag(); + if (tag == 0 || !skipField(tag)) { + return; + } + } + } + + @Override + public void skipMessage(CodedOutputStream output) throws IOException { + while (true) { + final int tag = readTag(); + if (tag == 0 || !skipField(tag, output)) { + return; + } + } + } + + // ----------------------------------------------------------------- + + @Override + public double readDouble() throws IOException { + return Double.longBitsToDouble(readRawLittleEndian64()); + } + + @Override + public float readFloat() throws IOException { + return Float.intBitsToFloat(readRawLittleEndian32()); + } + + @Override + public long readUInt64() throws IOException { + return readRawVarint64(); + } + + @Override + public long readInt64() throws IOException { + return readRawVarint64(); + } + + @Override + public int readInt32() throws IOException { + return readRawVarint32(); + } + + @Override + public long readFixed64() throws IOException { + return readRawLittleEndian64(); + } + + @Override + public int readFixed32() throws IOException { + return readRawLittleEndian32(); + } + + @Override + public boolean readBool() throws IOException { + return readRawVarint64() != 0; + } + + @Override + public String readString() throws IOException { + final int size = readRawVarint32(); + if (size > 0 && size <= currentByteBufferLimit - currentByteBufferPos) { + byte[] bytes = new byte[size]; + UnsafeUtil.copyMemory(currentByteBufferPos, bytes, 0, size); + String result = new String(bytes, UTF_8); + currentByteBufferPos += size; + return result; + } else if (size > 0 && size <= remaining()) { + // TODO(yilunchong): To use an underlying bytes[] instead of allocating a new bytes[] + byte[] bytes = new byte[size]; + readRawBytesTo(bytes, 0, size); + String result = new String(bytes, UTF_8); + return result; + } + + if (size == 0) { + return ""; + } + if (size < 0) { + throw InvalidProtocolBufferException.negativeSize(); + } + throw InvalidProtocolBufferException.truncatedMessage(); + } + + @Override + public String readStringRequireUtf8() throws IOException { + final int size = readRawVarint32(); + if (size > 0 && size <= currentByteBufferLimit - currentByteBufferPos) { + byte[] bytes = new byte[size]; + UnsafeUtil.copyMemory(currentByteBufferPos, bytes, 0, size); + if (!Utf8.isValidUtf8(bytes)) { + throw InvalidProtocolBufferException.invalidUtf8(); + } + String result = new String(bytes, UTF_8); + currentByteBufferPos += size; + return result; + } + if (size >= 0 && size <= remaining()) { + byte[] bytes = new byte[size]; + readRawBytesTo(bytes, 0, size); + if (!Utf8.isValidUtf8(bytes)) { + throw InvalidProtocolBufferException.invalidUtf8(); + } + String result = new String(bytes, UTF_8); + return result; + } + + if (size == 0) { + return ""; + } + if (size <= 0) { + throw InvalidProtocolBufferException.negativeSize(); + } + throw InvalidProtocolBufferException.truncatedMessage(); + } + + @Override + public void readGroup( + final int fieldNumber, + final MessageLite.Builder builder, + final ExtensionRegistryLite extensionRegistry) + throws IOException { + if (recursionDepth >= recursionLimit) { + throw InvalidProtocolBufferException.recursionLimitExceeded(); + } + ++recursionDepth; + builder.mergeFrom(this, extensionRegistry); + checkLastTagWas(WireFormat.makeTag(fieldNumber, WireFormat.WIRETYPE_END_GROUP)); + --recursionDepth; + } + + + @Override + public T readGroup( + final int fieldNumber, + final Parser parser, + final ExtensionRegistryLite extensionRegistry) + throws IOException { + if (recursionDepth >= recursionLimit) { + throw InvalidProtocolBufferException.recursionLimitExceeded(); + } + ++recursionDepth; + T result = parser.parsePartialFrom(this, extensionRegistry); + checkLastTagWas(WireFormat.makeTag(fieldNumber, WireFormat.WIRETYPE_END_GROUP)); + --recursionDepth; + return result; + } + + @Deprecated + @Override + public void readUnknownGroup(final int fieldNumber, final MessageLite.Builder builder) + throws IOException { + readGroup(fieldNumber, builder, ExtensionRegistryLite.getEmptyRegistry()); + } + + @Override + public void readMessage( + final MessageLite.Builder builder, final ExtensionRegistryLite extensionRegistry) + throws IOException { + final int length = readRawVarint32(); + if (recursionDepth >= recursionLimit) { + throw InvalidProtocolBufferException.recursionLimitExceeded(); + } + final int oldLimit = pushLimit(length); + ++recursionDepth; + builder.mergeFrom(this, extensionRegistry); + checkLastTagWas(0); + --recursionDepth; + popLimit(oldLimit); + } + + + @Override + public T readMessage( + final Parser parser, final ExtensionRegistryLite extensionRegistry) throws IOException { + int length = readRawVarint32(); + if (recursionDepth >= recursionLimit) { + throw InvalidProtocolBufferException.recursionLimitExceeded(); + } + final int oldLimit = pushLimit(length); + ++recursionDepth; + T result = parser.parsePartialFrom(this, extensionRegistry); + checkLastTagWas(0); + --recursionDepth; + popLimit(oldLimit); + return result; + } + + @Override + public ByteString readBytes() throws IOException { + final int size = readRawVarint32(); + if (size > 0 && size <= currentByteBufferLimit - currentByteBufferPos) { + if (immutable && enableAliasing) { + final int idx = (int) (currentByteBufferPos - currentAddress); + final ByteString result = ByteString.wrap(slice(idx, idx + size)); + currentByteBufferPos += size; + return result; + } else { + byte[] bytes; + bytes = new byte[size]; + UnsafeUtil.copyMemory(currentByteBufferPos, bytes, 0, size); + currentByteBufferPos += size; + return ByteString.wrap(bytes); + } + } else if (size > 0 && size <= remaining()) { + byte[] temp = new byte[size]; + readRawBytesTo(temp, 0, size); + return ByteString.wrap(temp); + } + + if (size == 0) { + return ByteString.EMPTY; + } + if (size < 0) { + throw InvalidProtocolBufferException.negativeSize(); + } + throw InvalidProtocolBufferException.truncatedMessage(); + } + + @Override + public byte[] readByteArray() throws IOException { + return readRawBytes(readRawVarint32()); + } + + @Override + public ByteBuffer readByteBuffer() throws IOException { + final int size = readRawVarint32(); + if (size > 0 && size <= currentRemaining()) { + if (!immutable && enableAliasing) { + currentByteBufferPos += size; + return slice( + (int) (currentByteBufferPos - currentAddress - size), + (int) (currentByteBufferPos - currentAddress)); + } else { + byte[] bytes = new byte[size]; + UnsafeUtil.copyMemory(currentByteBufferPos, bytes, 0, size); + currentByteBufferPos += size; + return ByteBuffer.wrap(bytes); + } + } else if (size > 0 && size <= remaining()) { + byte[] temp = new byte[size]; + readRawBytesTo(temp, 0, size); + return ByteBuffer.wrap(temp); + } + + if (size == 0) { + return EMPTY_BYTE_BUFFER; + } + if (size < 0) { + throw InvalidProtocolBufferException.negativeSize(); + } + throw InvalidProtocolBufferException.truncatedMessage(); + } + + @Override + public int readUInt32() throws IOException { + return readRawVarint32(); + } + + @Override + public int readEnum() throws IOException { + return readRawVarint32(); + } + + @Override + public int readSFixed32() throws IOException { + return readRawLittleEndian32(); + } + + @Override + public long readSFixed64() throws IOException { + return readRawLittleEndian64(); + } + + @Override + public int readSInt32() throws IOException { + return decodeZigZag32(readRawVarint32()); + } + + @Override + public long readSInt64() throws IOException { + return decodeZigZag64(readRawVarint64()); + } + + @Override + public int readRawVarint32() throws IOException { + fastpath: + { + long tempPos = currentByteBufferPos; + + if (currentByteBufferLimit == currentByteBufferPos) { + break fastpath; + } + + int x; + if ((x = UnsafeUtil.getByte(tempPos++)) >= 0) { + currentByteBufferPos++; + return x; + } else if (currentByteBufferLimit - currentByteBufferPos < 10) { + break fastpath; + } else if ((x ^= (UnsafeUtil.getByte(tempPos++) << 7)) < 0) { + x ^= (~0 << 7); + } else if ((x ^= (UnsafeUtil.getByte(tempPos++) << 14)) >= 0) { + x ^= (~0 << 7) ^ (~0 << 14); + } else if ((x ^= (UnsafeUtil.getByte(tempPos++) << 21)) < 0) { + x ^= (~0 << 7) ^ (~0 << 14) ^ (~0 << 21); + } else { + int y = UnsafeUtil.getByte(tempPos++); + x ^= y << 28; + x ^= (~0 << 7) ^ (~0 << 14) ^ (~0 << 21) ^ (~0 << 28); + if (y < 0 + && UnsafeUtil.getByte(tempPos++) < 0 + && UnsafeUtil.getByte(tempPos++) < 0 + && UnsafeUtil.getByte(tempPos++) < 0 + && UnsafeUtil.getByte(tempPos++) < 0 + && UnsafeUtil.getByte(tempPos++) < 0) { + break fastpath; // Will throw malformedVarint() + } + } + currentByteBufferPos = tempPos; + return x; + } + return (int) readRawVarint64SlowPath(); + } + + @Override + public long readRawVarint64() throws IOException { + fastpath: + { + long tempPos = currentByteBufferPos; + + if (currentByteBufferLimit == currentByteBufferPos) { + break fastpath; + } + + long x; + int y; + if ((y = UnsafeUtil.getByte(tempPos++)) >= 0) { + currentByteBufferPos++; + return y; + } else if (currentByteBufferLimit - currentByteBufferPos < 10) { + break fastpath; + } else if ((y ^= (UnsafeUtil.getByte(tempPos++) << 7)) < 0) { + x = y ^ (~0 << 7); + } else if ((y ^= (UnsafeUtil.getByte(tempPos++) << 14)) >= 0) { + x = y ^ ((~0 << 7) ^ (~0 << 14)); + } else if ((y ^= (UnsafeUtil.getByte(tempPos++) << 21)) < 0) { + x = y ^ ((~0 << 7) ^ (~0 << 14) ^ (~0 << 21)); + } else if ((x = y ^ ((long) UnsafeUtil.getByte(tempPos++) << 28)) >= 0L) { + x ^= (~0L << 7) ^ (~0L << 14) ^ (~0L << 21) ^ (~0L << 28); + } else if ((x ^= ((long) UnsafeUtil.getByte(tempPos++) << 35)) < 0L) { + x ^= (~0L << 7) ^ (~0L << 14) ^ (~0L << 21) ^ (~0L << 28) ^ (~0L << 35); + } else if ((x ^= ((long) UnsafeUtil.getByte(tempPos++) << 42)) >= 0L) { + x ^= (~0L << 7) ^ (~0L << 14) ^ (~0L << 21) ^ (~0L << 28) ^ (~0L << 35) ^ (~0L << 42); + } else if ((x ^= ((long) UnsafeUtil.getByte(tempPos++) << 49)) < 0L) { + x ^= + (~0L << 7) + ^ (~0L << 14) + ^ (~0L << 21) + ^ (~0L << 28) + ^ (~0L << 35) + ^ (~0L << 42) + ^ (~0L << 49); + } else { + x ^= ((long) UnsafeUtil.getByte(tempPos++) << 56); + x ^= + (~0L << 7) + ^ (~0L << 14) + ^ (~0L << 21) + ^ (~0L << 28) + ^ (~0L << 35) + ^ (~0L << 42) + ^ (~0L << 49) + ^ (~0L << 56); + if (x < 0L) { + if (UnsafeUtil.getByte(tempPos++) < 0L) { + break fastpath; // Will throw malformedVarint() + } + } + } + currentByteBufferPos = tempPos; + return x; + } + return readRawVarint64SlowPath(); + } + + @Override + long readRawVarint64SlowPath() throws IOException { + long result = 0; + for (int shift = 0; shift < 64; shift += 7) { + final byte b = readRawByte(); + result |= (long) (b & 0x7F) << shift; + if ((b & 0x80) == 0) { + return result; + } + } + throw InvalidProtocolBufferException.malformedVarint(); + } + + @Override + public int readRawLittleEndian32() throws IOException { + if (currentRemaining() >= FIXED32_SIZE) { + long tempPos = currentByteBufferPos; + currentByteBufferPos += FIXED32_SIZE; + return (((UnsafeUtil.getByte(tempPos) & 0xff)) + | ((UnsafeUtil.getByte(tempPos + 1) & 0xff) << 8) + | ((UnsafeUtil.getByte(tempPos + 2) & 0xff) << 16) + | ((UnsafeUtil.getByte(tempPos + 3) & 0xff) << 24)); + } + return ((readRawByte() & 0xff) + | ((readRawByte() & 0xff) << 8) + | ((readRawByte() & 0xff) << 16) + | ((readRawByte() & 0xff) << 24)); + } + + @Override + public long readRawLittleEndian64() throws IOException { + if (currentRemaining() >= FIXED64_SIZE) { + long tempPos = currentByteBufferPos; + currentByteBufferPos += FIXED64_SIZE; + return (((UnsafeUtil.getByte(tempPos) & 0xffL)) + | ((UnsafeUtil.getByte(tempPos + 1) & 0xffL) << 8) + | ((UnsafeUtil.getByte(tempPos + 2) & 0xffL) << 16) + | ((UnsafeUtil.getByte(tempPos + 3) & 0xffL) << 24) + | ((UnsafeUtil.getByte(tempPos + 4) & 0xffL) << 32) + | ((UnsafeUtil.getByte(tempPos + 5) & 0xffL) << 40) + | ((UnsafeUtil.getByte(tempPos + 6) & 0xffL) << 48) + | ((UnsafeUtil.getByte(tempPos + 7) & 0xffL) << 56)); + } + return ((readRawByte() & 0xffL) + | ((readRawByte() & 0xffL) << 8) + | ((readRawByte() & 0xffL) << 16) + | ((readRawByte() & 0xffL) << 24) + | ((readRawByte() & 0xffL) << 32) + | ((readRawByte() & 0xffL) << 40) + | ((readRawByte() & 0xffL) << 48) + | ((readRawByte() & 0xffL) << 56)); + } + + @Override + public void enableAliasing(boolean enabled) { + this.enableAliasing = enabled; + } + + @Override + public void resetSizeCounter() { + startOffset = (int) (totalBytesRead + currentByteBufferPos - currentByteBufferStartPos); + } + + @Override + public int pushLimit(int byteLimit) throws InvalidProtocolBufferException { + if (byteLimit < 0) { + throw InvalidProtocolBufferException.negativeSize(); + } + byteLimit += getTotalBytesRead(); + final int oldLimit = currentLimit; + if (byteLimit > oldLimit) { + throw InvalidProtocolBufferException.truncatedMessage(); + } + currentLimit = byteLimit; + + recomputeBufferSizeAfterLimit(); + + return oldLimit; + } + + private void recomputeBufferSizeAfterLimit() { + totalBufferSize += bufferSizeAfterCurrentLimit; + final int bufferEnd = totalBufferSize - startOffset; + if (bufferEnd > currentLimit) { + // Limit is in current buffer. + bufferSizeAfterCurrentLimit = bufferEnd - currentLimit; + totalBufferSize -= bufferSizeAfterCurrentLimit; + } else { + bufferSizeAfterCurrentLimit = 0; + } + } + + @Override + public void popLimit(final int oldLimit) { + currentLimit = oldLimit; + recomputeBufferSizeAfterLimit(); + } + + @Override + public int getBytesUntilLimit() { + if (currentLimit == Integer.MAX_VALUE) { + return -1; + } + + return currentLimit - getTotalBytesRead(); + } + + @Override + public boolean isAtEnd() throws IOException { + return totalBytesRead + currentByteBufferPos - currentByteBufferStartPos == totalBufferSize; + } + + @Override + public int getTotalBytesRead() { + return (int) + (totalBytesRead - startOffset + currentByteBufferPos - currentByteBufferStartPos); + } + + @Override + public byte readRawByte() throws IOException { + if (currentRemaining() == 0) { + getNextByteBuffer(); + } + return UnsafeUtil.getByte(currentByteBufferPos++); + } + + @Override + public byte[] readRawBytes(final int length) throws IOException { + if (length >= 0 && length <= currentRemaining()) { + byte[] bytes = new byte[length]; + UnsafeUtil.copyMemory(currentByteBufferPos, bytes, 0, length); + currentByteBufferPos += length; + return bytes; + } + if (length >= 0 && length <= remaining()) { + byte[] bytes = new byte[length]; + readRawBytesTo(bytes, 0, length); + return bytes; + } + + if (length <= 0) { + if (length == 0) { + return EMPTY_BYTE_ARRAY; + } else { + throw InvalidProtocolBufferException.negativeSize(); + } + } + + throw InvalidProtocolBufferException.truncatedMessage(); + } + + /** + * Try to get raw bytes from {@code input} with the size of {@code length} and copy to {@code + * bytes} array. If the size is bigger than the number of remaining bytes in the input, then + * throw {@code truncatedMessage} exception. + * + * @param bytes + * @param offset + * @param length + * @throws IOException + */ + private void readRawBytesTo(byte[] bytes, int offset, final int length) throws IOException { + if (length >= 0 && length <= remaining()) { + int l = length; + while (l > 0) { + if (currentRemaining() == 0) { + getNextByteBuffer(); + } + int bytesToCopy = Math.min(l, (int) currentRemaining()); + UnsafeUtil.copyMemory(currentByteBufferPos, bytes, length - l + offset, bytesToCopy); + l -= bytesToCopy; + currentByteBufferPos += bytesToCopy; + } + return; + } + + if (length <= 0) { + if (length == 0) { + return; + } else { + throw InvalidProtocolBufferException.negativeSize(); + } + } + throw InvalidProtocolBufferException.truncatedMessage(); + } + + @Override + public void skipRawBytes(final int length) throws IOException { + if (length >= 0 + && length + <= (totalBufferSize + - totalBytesRead + - currentByteBufferPos + + currentByteBufferStartPos)) { + // We have all the bytes we need already. + int l = length; + while (l > 0) { + if (currentRemaining() == 0) { + getNextByteBuffer(); + } + int rl = Math.min(l, (int) currentRemaining()); + l -= rl; + currentByteBufferPos += rl; + } + return; + } + + if (length < 0) { + throw InvalidProtocolBufferException.negativeSize(); + } + throw InvalidProtocolBufferException.truncatedMessage(); + } + + // TODO: optimize to fastpath + private void skipRawVarint() throws IOException { + for (int i = 0; i < MAX_VARINT_SIZE; i++) { + if (readRawByte() >= 0) { + return; + } + } + throw InvalidProtocolBufferException.malformedVarint(); + } + + /** + * Try to get the number of remaining bytes in {@code input}. + * + * @return the number of remaining bytes in {@code input}. + */ + private int remaining() { + return (int) + (totalBufferSize - totalBytesRead - currentByteBufferPos + currentByteBufferStartPos); + } + + /** + * Try to get the number of remaining bytes in {@code currentByteBuffer}. + * + * @return the number of remaining bytes in {@code currentByteBuffer} + */ + private long currentRemaining() { + return (currentByteBufferLimit - currentByteBufferPos); + } + + private ByteBuffer slice(int begin, int end) throws IOException { + int prevPos = currentByteBuffer.position(); + int prevLimit = currentByteBuffer.limit(); + try { + currentByteBuffer.position(begin); + currentByteBuffer.limit(end); + return currentByteBuffer.slice(); + } catch (IllegalArgumentException e) { + throw InvalidProtocolBufferException.truncatedMessage(); + } finally { + currentByteBuffer.position(prevPos); + currentByteBuffer.limit(prevLimit); + } + } + } } diff --git a/java/core/src/main/java/com/google/protobuf/GeneratedMessageLite.java b/java/core/src/main/java/com/google/protobuf/GeneratedMessageLite.java index 7116ae1c..99864964 100644 --- a/java/core/src/main/java/com/google/protobuf/GeneratedMessageLite.java +++ b/java/core/src/main/java/com/google/protobuf/GeneratedMessageLite.java @@ -220,7 +220,7 @@ public abstract class GeneratedMessageLite< @Override public final boolean isInitialized() { - return dynamicMethod(MethodToInvoke.IS_INITIALIZED, Boolean.TRUE) != null; + return isInitialized((MessageType) this, Boolean.TRUE); } @Override @@ -240,6 +240,8 @@ public abstract class GeneratedMessageLite< public static enum MethodToInvoke { // Rely on/modify instance state IS_INITIALIZED, + GET_MEMOIZED_IS_INITIALIZED, + SET_MEMOIZED_IS_INITIALIZED, VISIT, MERGE_FROM_STREAM, MAKE_IMMUTABLE, @@ -256,25 +258,30 @@ public abstract class GeneratedMessageLite< * Theses different kinds of operations are required to implement message-level operations for * builders in the runtime. This method bundles those operations to reduce the generated methods * count. + * *
    - *
  • {@code MERGE_FROM_STREAM} is parameterized with an {@link CodedInputStream} and - * {@link ExtensionRegistryLite}. It consumes the input stream, parsing the contents into the - * returned protocol buffer. If parsing throws an {@link InvalidProtocolBufferException}, the - * implementation wraps it in a RuntimeException. - *
  • {@code NEW_INSTANCE} returns a new instance of the protocol buffer that has not yet been - * made immutable. See {@code MAKE_IMMUTABLE}. - *
  • {@code IS_INITIALIZED} is parameterized with a {@code Boolean} detailing whether to - * memoize. It returns {@code null} for false and the default instance for true. We optionally - * memoize to support the Builder case, where memoization is not desired. - *
  • {@code NEW_BUILDER} returns a {@code BuilderType} instance. - *
  • {@code VISIT} is parameterized with a {@code Visitor} and a {@code MessageType} and - * recursively iterates through the fields side by side between this and the instance. - *
  • {@code MAKE_IMMUTABLE} sets all internal fields to an immutable state. + *
  • {@code MERGE_FROM_STREAM} is parameterized with an {@link CodedInputStream} and {@link + * ExtensionRegistryLite}. It consumes the input stream, parsing the contents into the + * returned protocol buffer. If parsing throws an {@link InvalidProtocolBufferException}, + * the implementation wraps it in a RuntimeException. + *
  • {@code NEW_INSTANCE} returns a new instance of the protocol buffer that has not yet been + * made immutable. See {@code MAKE_IMMUTABLE}. + *
  • {@code IS_INITIALIZED} returns {@code null} for false and the default instance for true. + * It doesn't use or modify any memoized value. + *
  • {@code GET_MEMOIZED_IS_INITIALIZED} returns the memoized {@code isInitialized} byte + * value. + *
  • {@code SET_MEMOIZED_IS_INITIALIZED} sets the memoized {@code isInitilaized} byte value to + * 1 if the first parameter is not null, or to 0 if the first parameter is null. + *
  • {@code NEW_BUILDER} returns a {@code BuilderType} instance. + *
  • {@code VISIT} is parameterized with a {@code Visitor} and a {@code MessageType} and + * recursively iterates through the fields side by side between this and the instance. + *
  • {@code MAKE_IMMUTABLE} sets all internal fields to an immutable state. *
+ * * This method, plus the implementation of the Builder, enables the Builder class to be proguarded * away entirely on Android. - *

- * For use by generated code only. + * + *

For use by generated code only. */ protected abstract Object dynamicMethod(MethodToInvoke method, Object arg0, Object arg1); @@ -297,9 +304,9 @@ public abstract class GeneratedMessageLite< unknownFields = visitor.visitUnknownFields(unknownFields, other.unknownFields); } + /** - * Merge some unknown fields into the {@link UnknownFieldSetLite} for this - * message. + * Merge some unknown fields into the {@link UnknownFieldSetLite} for this message. * *

For use by generated code only. */ @@ -1403,7 +1410,21 @@ public abstract class GeneratedMessageLite< */ protected static final > boolean isInitialized( T message, boolean shouldMemoize) { - return message.dynamicMethod(MethodToInvoke.IS_INITIALIZED, shouldMemoize) != null; + byte memoizedIsInitialized = + (Byte) message.dynamicMethod(MethodToInvoke.GET_MEMOIZED_IS_INITIALIZED); + if (memoizedIsInitialized == 1) { + return true; + } + if (memoizedIsInitialized == 0) { + return false; + } + boolean isInitialized = + message.dynamicMethod(MethodToInvoke.IS_INITIALIZED, Boolean.FALSE) != null; + if (shouldMemoize) { + message.dynamicMethod( + MethodToInvoke.SET_MEMOIZED_IS_INITIALIZED, isInitialized ? message : null); + } + return isInitialized; } protected static final > void makeImmutable(T message) { diff --git a/java/core/src/main/java/com/google/protobuf/IterableByteBufferInputStream.java b/java/core/src/main/java/com/google/protobuf/IterableByteBufferInputStream.java new file mode 100644 index 00000000..713e8064 --- /dev/null +++ b/java/core/src/main/java/com/google/protobuf/IterableByteBufferInputStream.java @@ -0,0 +1,150 @@ +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// https://developers.google.com/protocol-buffers/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +package com.google.protobuf; + +import static com.google.protobuf.Internal.EMPTY_BYTE_BUFFER; + +import java.io.IOException; +import java.io.InputStream; +import java.nio.ByteBuffer; +import java.util.Iterator; + +class IterableByteBufferInputStream extends InputStream { + /** The {@link Iterator} with type {@link ByteBuffer} of {@code input} */ + private Iterator iterator; + /** The current ByteBuffer; */ + private ByteBuffer currentByteBuffer; + /** The number of ByteBuffers in the input data. */ + private int dataSize; + /** + * Current {@code ByteBuffer}'s index + * + *

If index equals dataSize, then all the data in the InputStream has been consumed + */ + private int currentIndex; + /** The current position for current ByteBuffer */ + private int currentByteBufferPos; + /** Whether current ByteBuffer has an array */ + private boolean hasArray; + /** + * If the current ByteBuffer is unsafe-direct based, currentArray is null; otherwise should be the + * array inside ByteBuffer. + */ + private byte[] currentArray; + /** Current ByteBuffer's array offset */ + private int currentArrayOffset; + /** + * If the current ByteBuffer is unsafe-direct based, currentAddress is the start address of this + * ByteBuffer; otherwise should be zero. + */ + private long currentAddress; + + IterableByteBufferInputStream(Iterable data) { + iterator = data.iterator(); + dataSize = 0; + for (ByteBuffer unused : data) { + dataSize++; + } + currentIndex = -1; + + if (!getNextByteBuffer()) { + currentByteBuffer = EMPTY_BYTE_BUFFER; + currentIndex = 0; + currentByteBufferPos = 0; + currentAddress = 0; + } + } + + private boolean getNextByteBuffer() { + currentIndex++; + if (!iterator.hasNext()) { + return false; + } + currentByteBuffer = iterator.next(); + currentByteBufferPos = currentByteBuffer.position(); + if (currentByteBuffer.hasArray()) { + hasArray = true; + currentArray = currentByteBuffer.array(); + currentArrayOffset = currentByteBuffer.arrayOffset(); + } else { + hasArray = false; + currentAddress = UnsafeUtil.addressOffset(currentByteBuffer); + currentArray = null; + } + return true; + } + + private void updateCurrentByteBufferPos(int numberOfBytesRead) { + currentByteBufferPos += numberOfBytesRead; + if (currentByteBufferPos == currentByteBuffer.limit()) { + getNextByteBuffer(); + } + } + + @Override + public int read() throws IOException { + if (currentIndex == dataSize) { + return -1; + } + if (hasArray) { + int result = currentArray[currentByteBufferPos + currentArrayOffset] & 0xFF; + updateCurrentByteBufferPos(1); + return result; + } else { + int result = UnsafeUtil.getByte(currentByteBufferPos + currentAddress) & 0xFF; + updateCurrentByteBufferPos(1); + return result; + } + } + + @Override + public int read(byte[] output, int offset, int length) throws IOException { + if (currentIndex == dataSize) { + return -1; + } + int remaining = currentByteBuffer.limit() - currentByteBufferPos; + if (length > remaining) { + length = remaining; + } + if (hasArray) { + System.arraycopy( + currentArray, currentByteBufferPos + currentArrayOffset, output, offset, length); + updateCurrentByteBufferPos(length); + } else { + int prevPos = currentByteBuffer.position(); + currentByteBuffer.position(currentByteBufferPos); + currentByteBuffer.get(output, offset, length); + currentByteBuffer.position(prevPos); + updateCurrentByteBufferPos(length); + } + return length; + } +} diff --git a/java/core/src/main/java/com/google/protobuf/SmallSortedMap.java b/java/core/src/main/java/com/google/protobuf/SmallSortedMap.java index 66033f58..279edc4d 100644 --- a/java/core/src/main/java/com/google/protobuf/SmallSortedMap.java +++ b/java/core/src/main/java/com/google/protobuf/SmallSortedMap.java @@ -540,8 +540,8 @@ class SmallSortedMap, V> extends AbstractMap { @Override public boolean hasNext() { - return (pos + 1) < entryList.size() || - getOverflowIterator().hasNext(); + return (pos + 1) < entryList.size() + || (!overflowEntries.isEmpty() && getOverflowIterator().hasNext()); } @Override diff --git a/java/core/src/main/java/com/google/protobuf/TextFormat.java b/java/core/src/main/java/com/google/protobuf/TextFormat.java index 64094d09..ab9acf2f 100644 --- a/java/core/src/main/java/com/google/protobuf/TextFormat.java +++ b/java/core/src/main/java/com/google/protobuf/TextFormat.java @@ -279,9 +279,21 @@ public final class TextFormat { generator.print(String.format((Locale) null, "0x%016x", (Long) value)); break; case WireFormat.WIRETYPE_LENGTH_DELIMITED: - generator.print("\""); - generator.print(escapeBytes((ByteString) value)); - generator.print("\""); + try { + // Try to parse and print the field as an embedded message + UnknownFieldSet message = UnknownFieldSet.parseFrom((ByteString) value); + generator.print("{"); + generator.eol(); + generator.indent(); + Printer.DEFAULT.printUnknownFields(message, generator); + generator.outdent(); + generator.print("}"); + } catch (InvalidProtocolBufferException e) { + // If not parseable as a message, print as a String + generator.print("\""); + generator.print(escapeBytes((ByteString) value)); + generator.print("\""); + } break; case WireFormat.WIRETYPE_START_GROUP: Printer.DEFAULT.printUnknownFields((UnknownFieldSet) value, generator); diff --git a/java/core/src/main/java/com/google/protobuf/UnknownFieldSetLite.java b/java/core/src/main/java/com/google/protobuf/UnknownFieldSetLite.java index d6226fc7..2a614c84 100644 --- a/java/core/src/main/java/com/google/protobuf/UnknownFieldSetLite.java +++ b/java/core/src/main/java/com/google/protobuf/UnknownFieldSetLite.java @@ -81,7 +81,7 @@ public final class UnknownFieldSetLite { System.arraycopy(second.objects, 0, objects, first.count, second.count); return new UnknownFieldSetLite(count, tags, objects, true /* isMutable */); } - + /** * The number of elements in the set. */ @@ -323,6 +323,7 @@ public final class UnknownFieldSetLite { // Package private for unsafe experimental runtime. void storeField(int tag, Object value) { + checkMutable(); ensureCapacity(); tags[count] = tag; diff --git a/java/core/src/main/java/com/google/protobuf/UnsafeUtil.java b/java/core/src/main/java/com/google/protobuf/UnsafeUtil.java index acc03a7c..88315cb6 100644 --- a/java/core/src/main/java/com/google/protobuf/UnsafeUtil.java +++ b/java/core/src/main/java/com/google/protobuf/UnsafeUtil.java @@ -252,10 +252,6 @@ final class UnsafeUtil { MEMORY_ACCESSOR.putLong(address, value); } - static void copyMemory(long srcAddress, long targetAddress, long length) { - MEMORY_ACCESSOR.copyMemory(srcAddress, targetAddress, length); - } - /** * Gets the offset of the {@code address} field of the given direct {@link ByteBuffer}. */ @@ -478,8 +474,6 @@ final class UnsafeUtil { public abstract void putLong(long address, long value); - public abstract void copyMemory(long srcAddress, long targetAddress, long length); - public abstract Object getStaticObject(Field field); public abstract void copyMemory(long srcOffset, byte[] target, long targetIndex, long length); @@ -562,11 +556,6 @@ final class UnsafeUtil { public void putDouble(Object target, long offset, double value) { unsafe.putDouble(target, offset, value); } - - @Override - public void copyMemory(long srcAddress, long targetAddress, long length) { - unsafe.copyMemory(srcAddress, targetAddress, length); - } @Override public void copyMemory(long srcOffset, byte[] target, long targetIndex, long length) { -- cgit v1.2.3