diff options
Diffstat (limited to 'java/core/src/main/java/com/google/protobuf/CodedInputStream.java')
-rw-r--r-- | java/core/src/main/java/com/google/protobuf/CodedInputStream.java | 1260 |
1 files changed, 1153 insertions, 107 deletions
diff --git a/java/core/src/main/java/com/google/protobuf/CodedInputStream.java b/java/core/src/main/java/com/google/protobuf/CodedInputStream.java index 14169dc4..1297462e 100644 --- a/java/core/src/main/java/com/google/protobuf/CodedInputStream.java +++ b/java/core/src/main/java/com/google/protobuf/CodedInputStream.java @@ -34,8 +34,8 @@ import static com.google.protobuf.Internal.EMPTY_BYTE_ARRAY; import static com.google.protobuf.Internal.EMPTY_BYTE_BUFFER; import static com.google.protobuf.Internal.UTF_8; import static com.google.protobuf.Internal.checkNotNull; -import static com.google.protobuf.WireFormat.FIXED_32_SIZE; -import static com.google.protobuf.WireFormat.FIXED_64_SIZE; +import static com.google.protobuf.WireFormat.FIXED32_SIZE; +import static com.google.protobuf.WireFormat.FIXED64_SIZE; import static com.google.protobuf.WireFormat.MAX_VARINT_SIZE; import java.io.ByteArrayOutputStream; @@ -44,6 +44,7 @@ import java.io.InputStream; import java.nio.ByteBuffer; import java.util.ArrayList; import java.util.Arrays; +import java.util.Iterator; import java.util.List; /** @@ -63,6 +64,12 @@ public abstract class CodedInputStream { // Integer.MAX_VALUE == 0x7FFFFFF == INT_MAX from limits.h private static final int DEFAULT_SIZE_LIMIT = Integer.MAX_VALUE; + /** + * Whether to enable our custom UTF-8 decode codepath which does not use {@link StringCoding}. + * Currently disabled. + */ + private static final boolean ENABLE_CUSTOM_UTF8_DECODE = false; + /** Visible for subclasses. See setRecursionLimit() */ int recursionDepth; @@ -85,6 +92,43 @@ public abstract class CodedInputStream { return new StreamDecoder(input, bufferSize); } + /** Create a new CodedInputStream wrapping the given {@code Iterable <ByteBuffer>}. */ + public static CodedInputStream newInstance(final Iterable<ByteBuffer> input) { + if (!UnsafeDirectNioDecoder.isSupported()) { + return newInstance(new IterableByteBufferInputStream(input)); + } + return newInstance(input, false); + } + + /** Create a new CodedInputStream wrapping the given {@code Iterable <ByteBuffer>}. */ + static CodedInputStream newInstance( + final Iterable<ByteBuffer> bufs, final boolean bufferIsImmutable) { + // flag is to check the type of input's ByteBuffers. + // flag equals 1: all ByteBuffers have array. + // flag equals 2: all ByteBuffers are direct ByteBuffers. + // flag equals 3: some ByteBuffers are direct and some have array. + // flag greater than 3: other cases. + int flag = 0; + // Total size of the input + int totalSize = 0; + for (ByteBuffer buf : bufs) { + totalSize += buf.remaining(); + if (buf.hasArray()) { + flag |= 1; + } else if (buf.isDirect()) { + flag |= 2; + } else { + flag |= 4; + } + } + if (flag == 2) { + return new IterableDirectByteBufferDecoder(bufs, totalSize, bufferIsImmutable); + } else { + // TODO(yilunchong): add another decoders to deal case 1 and 3. + return newInstance(new IterableByteBufferInputStream(bufs)); + } + } + /** Create a new CodedInputStream wrapping the given byte array. */ public static CodedInputStream newInstance(final byte[] buf) { return newInstance(buf, 0, buf.length); @@ -354,9 +398,9 @@ public abstract class CodedInputStream { * * <p>Set the maximum message size. In order to prevent malicious messages from exhausting memory * or causing integer overflows, {@code CodedInputStream} limits how large a message may be. The - * default limit is 64MB. You should set this limit as small as you can without harming your app's - * functionality. Note that size limits only apply when reading from an {@code InputStream}, not - * when constructed around a raw byte array (nor with {@link ByteString#newCodedInput}). + * default limit is {@code Integer.MAX_INT}. You should set this limit as small as you can without + * harming your app's functionality. Note that size limits only apply when reading from an {@code + * InputStream}, not when constructed around a raw byte array. * * <p>If you want to read several messages from a single CodedInputStream, you could call {@link * #resetSizeCounter()} after each one to avoid hitting the size limit. @@ -372,6 +416,63 @@ public abstract class CodedInputStream { return oldLimit; } + + private boolean explicitDiscardUnknownFields = false; + + private static volatile boolean proto3DiscardUnknownFieldsDefault = false; + + static void setProto3DiscardUnknownsByDefaultForTest() { + proto3DiscardUnknownFieldsDefault = true; + } + + static void setProto3KeepUnknownsByDefaultForTest() { + proto3DiscardUnknownFieldsDefault = false; + } + + static boolean getProto3DiscardUnknownFieldsDefault() { + return proto3DiscardUnknownFieldsDefault; + } + + /** + * Sets this {@code CodedInputStream} to discard unknown fields. Only applies to full runtime + * messages; lite messages will always preserve unknowns. + * + * <p>Note calling this function alone will have NO immediate effect on the underlying input data. + * The unknown fields will be discarded during parsing. This affects both Proto2 and Proto3 full + * runtime. + */ + final void discardUnknownFields() { + explicitDiscardUnknownFields = true; + } + + /** + * Reverts the unknown fields preservation behavior for Proto2 and Proto3 full runtime to their + * default. + */ + final void unsetDiscardUnknownFields() { + explicitDiscardUnknownFields = false; + } + + /** + * Whether unknown fields in this input stream should be discarded during parsing into full + * runtime messages. + */ + final boolean shouldDiscardUnknownFields() { + return explicitDiscardUnknownFields; + } + + /** + * Whether unknown fields in this input stream should be discarded during parsing for proto3 full + * runtime messages. + * + * <p>This function was temporarily introduced before proto3 unknown fields behavior is changed. + * TODO(liujisi): remove this and related code in GeneratedMessage after proto3 unknown + * fields migration is done. + */ + final boolean shouldDiscardUnknownFieldsProto3() { + return explicitDiscardUnknownFields ? true : proto3DiscardUnknownFieldsDefault; + } + /** * Resets the current size counter to zero (see {@link #setSizeLimit(int)}). Only valid for {@link * InputStream}-backed streams. @@ -572,7 +673,7 @@ public abstract class CodedInputStream { skipRawVarint(); return true; case WireFormat.WIRETYPE_FIXED64: - skipRawBytes(FIXED_64_SIZE); + skipRawBytes(FIXED64_SIZE); return true; case WireFormat.WIRETYPE_LENGTH_DELIMITED: skipRawBytes(readRawVarint32()); @@ -585,7 +686,7 @@ public abstract class CodedInputStream { case WireFormat.WIRETYPE_END_GROUP: return false; case WireFormat.WIRETYPE_FIXED32: - skipRawBytes(FIXED_32_SIZE); + skipRawBytes(FIXED32_SIZE); return true; default: throw InvalidProtocolBufferException.invalidWireType(); @@ -730,13 +831,19 @@ public abstract class CodedInputStream { public String readStringRequireUtf8() throws IOException { final int size = readRawVarint32(); if (size > 0 && size <= (limit - pos)) { - // TODO(martinrb): We could save a pass by validating while decoding. - if (!Utf8.isValidUtf8(buffer, pos, pos + size)) { - throw InvalidProtocolBufferException.invalidUtf8(); + if (ENABLE_CUSTOM_UTF8_DECODE) { + String result = Utf8.decodeUtf8(buffer, pos, size); + pos += size; + return result; + } else { + // TODO(martinrb): We could save a pass by validating while decoding. + if (!Utf8.isValidUtf8(buffer, pos, pos + size)) { + throw InvalidProtocolBufferException.invalidUtf8(); + } + final int tempPos = pos; + pos += size; + return new String(buffer, tempPos, size, UTF_8); } - final int tempPos = pos; - pos += size; - return new String(buffer, tempPos, size, UTF_8); } if (size == 0) { @@ -1064,12 +1171,12 @@ public abstract class CodedInputStream { public int readRawLittleEndian32() throws IOException { int tempPos = pos; - if (limit - tempPos < FIXED_32_SIZE) { + if (limit - tempPos < FIXED32_SIZE) { throw InvalidProtocolBufferException.truncatedMessage(); } final byte[] buffer = this.buffer; - pos = tempPos + FIXED_32_SIZE; + pos = tempPos + FIXED32_SIZE; return (((buffer[tempPos] & 0xff)) | ((buffer[tempPos + 1] & 0xff) << 8) | ((buffer[tempPos + 2] & 0xff) << 16) @@ -1080,12 +1187,12 @@ public abstract class CodedInputStream { public long readRawLittleEndian64() throws IOException { int tempPos = pos; - if (limit - tempPos < FIXED_64_SIZE) { + if (limit - tempPos < FIXED64_SIZE) { throw InvalidProtocolBufferException.truncatedMessage(); } final byte[] buffer = this.buffer; - pos = tempPos + FIXED_64_SIZE; + pos = tempPos + FIXED64_SIZE; return (((buffer[tempPos] & 0xffL)) | ((buffer[tempPos + 1] & 0xffL) << 8) | ((buffer[tempPos + 2] & 0xffL) << 16) @@ -1290,7 +1397,7 @@ public abstract class CodedInputStream { skipRawVarint(); return true; case WireFormat.WIRETYPE_FIXED64: - skipRawBytes(FIXED_64_SIZE); + skipRawBytes(FIXED64_SIZE); return true; case WireFormat.WIRETYPE_LENGTH_DELIMITED: skipRawBytes(readRawVarint32()); @@ -1303,7 +1410,7 @@ public abstract class CodedInputStream { case WireFormat.WIRETYPE_END_GROUP: return false; case WireFormat.WIRETYPE_FIXED32: - skipRawBytes(FIXED_32_SIZE); + skipRawBytes(FIXED32_SIZE); return true; default: throw InvalidProtocolBufferException.invalidWireType(); @@ -1429,7 +1536,11 @@ public abstract class CodedInputStream { final int size = readRawVarint32(); if (size > 0 && size <= remaining()) { // TODO(nathanmittler): Is there a way to avoid this copy? - byte[] bytes = copyToArray(pos, pos + size); + // TODO(anuraaga): It might be possible to share the optimized loop with + // readStringRequireUtf8 by implementing Java replacement logic there. + // The same as readBytes' logic + byte[] bytes = new byte[size]; + UnsafeUtil.copyMemory(pos, bytes, 0, size); String result = new String(bytes, UTF_8); pos += size; return result; @@ -1447,17 +1558,26 @@ public abstract class CodedInputStream { @Override public String readStringRequireUtf8() throws IOException { final int size = readRawVarint32(); - if (size >= 0 && size <= remaining()) { - // TODO(nathanmittler): Is there a way to avoid this copy? - byte[] bytes = copyToArray(pos, pos + size); - // TODO(martinrb): We could save a pass by validating while decoding. - if (!Utf8.isValidUtf8(bytes)) { - throw InvalidProtocolBufferException.invalidUtf8(); - } + if (size > 0 && size <= remaining()) { + if (ENABLE_CUSTOM_UTF8_DECODE) { + final int bufferPos = bufferPos(pos); + String result = Utf8.decodeUtf8(buffer, bufferPos, size); + pos += size; + return result; + } else { + // TODO(nathanmittler): Is there a way to avoid this copy? + // The same as readBytes' logic + byte[] bytes = new byte[size]; + UnsafeUtil.copyMemory(pos, bytes, 0, size); + // TODO(martinrb): We could save a pass by validating while decoding. + if (!Utf8.isValidUtf8(bytes)) { + throw InvalidProtocolBufferException.invalidUtf8(); + } - String result = new String(bytes, UTF_8); - pos += size; - return result; + String result = new String(bytes, UTF_8); + pos += size; + return result; + } } if (size == 0) { @@ -1545,14 +1665,17 @@ public abstract class CodedInputStream { public ByteString readBytes() throws IOException { final int size = readRawVarint32(); if (size > 0 && size <= remaining()) { - ByteBuffer result; if (immutable && enableAliasing) { - result = slice(pos, pos + size); + final ByteBuffer result = slice(pos, pos + size); + pos += size; + return ByteString.wrap(result); } else { - result = copy(pos, pos + size); + // Use UnsafeUtil to copy the memory to bytes instead of using ByteBuffer ways. + byte[] bytes = new byte[size]; + UnsafeUtil.copyMemory(pos, bytes, 0, size); + pos += size; + return ByteString.wrap(bytes); } - pos += size; - return ByteString.wrap(result); } if (size == 0) { @@ -1573,18 +1696,21 @@ public abstract class CodedInputStream { public ByteBuffer readByteBuffer() throws IOException { final int size = readRawVarint32(); if (size > 0 && size <= remaining()) { - ByteBuffer result; // "Immutable" implies that buffer is backing a ByteString. // Disallow slicing in this case to prevent the caller from modifying the contents // of the ByteString. if (!immutable && enableAliasing) { - result = slice(pos, pos + size); + final ByteBuffer result = slice(pos, pos + size); + pos += size; + return result; } else { - result = copy(pos, pos + size); + // The same as readBytes' logic + byte[] bytes = new byte[size]; + UnsafeUtil.copyMemory(pos, bytes, 0, size); + pos += size; + return ByteBuffer.wrap(bytes); } - pos += size; // TODO(nathanmittler): Investigate making the ByteBuffer be made read-only - return result; } if (size == 0) { @@ -1785,11 +1911,11 @@ public abstract class CodedInputStream { public int readRawLittleEndian32() throws IOException { long tempPos = pos; - if (limit - tempPos < FIXED_32_SIZE) { + if (limit - tempPos < FIXED32_SIZE) { throw InvalidProtocolBufferException.truncatedMessage(); } - pos = tempPos + FIXED_32_SIZE; + pos = tempPos + FIXED32_SIZE; return (((UnsafeUtil.getByte(tempPos) & 0xff)) | ((UnsafeUtil.getByte(tempPos + 1) & 0xff) << 8) | ((UnsafeUtil.getByte(tempPos + 2) & 0xff) << 16) @@ -1800,11 +1926,11 @@ public abstract class CodedInputStream { public long readRawLittleEndian64() throws IOException { long tempPos = pos; - if (limit - tempPos < FIXED_64_SIZE) { + if (limit - tempPos < FIXED64_SIZE) { throw InvalidProtocolBufferException.truncatedMessage(); } - pos = tempPos + FIXED_64_SIZE; + pos = tempPos + FIXED64_SIZE; return (((UnsafeUtil.getByte(tempPos) & 0xffL)) | ((UnsafeUtil.getByte(tempPos + 1) & 0xffL) << 8) | ((UnsafeUtil.getByte(tempPos + 2) & 0xffL) << 16) @@ -1943,27 +2069,6 @@ public abstract class CodedInputStream { buffer.limit(prevLimit); } } - - private ByteBuffer copy(long begin, long end) throws IOException { - return ByteBuffer.wrap(copyToArray(begin, end)); - } - - private byte[] copyToArray(long begin, long end) throws IOException { - int prevPos = buffer.position(); - int prevLimit = buffer.limit(); - try { - buffer.position(bufferPos(begin)); - buffer.limit(bufferPos(end)); - byte[] bytes = new byte[(int) (end - begin)]; - buffer.get(bytes); - return bytes; - } catch (IllegalArgumentException e) { - throw InvalidProtocolBufferException.truncatedMessage(); - } finally { - buffer.position(prevPos); - buffer.limit(prevLimit); - } - } } /** @@ -2034,7 +2139,7 @@ public abstract class CodedInputStream { skipRawVarint(); return true; case WireFormat.WIRETYPE_FIXED64: - skipRawBytes(FIXED_64_SIZE); + skipRawBytes(FIXED64_SIZE); return true; case WireFormat.WIRETYPE_LENGTH_DELIMITED: skipRawBytes(readRawVarint32()); @@ -2047,7 +2152,7 @@ public abstract class CodedInputStream { case WireFormat.WIRETYPE_END_GROUP: return false; case WireFormat.WIRETYPE_FIXED32: - skipRawBytes(FIXED_32_SIZE); + skipRawBytes(FIXED32_SIZE); return true; default: throw InvalidProtocolBufferException.invalidWireType(); @@ -2240,11 +2345,15 @@ public abstract class CodedInputStream { bytes = readRawBytesSlowPath(size); tempPos = 0; } - // TODO(martinrb): We could save a pass by validating while decoding. - if (!Utf8.isValidUtf8(bytes, tempPos, tempPos + size)) { - throw InvalidProtocolBufferException.invalidUtf8(); + if (ENABLE_CUSTOM_UTF8_DECODE) { + return Utf8.decodeUtf8(bytes, tempPos, size); + } else { + // TODO(martinrb): We could save a pass by validating while decoding. + if (!Utf8.isValidUtf8(bytes, tempPos, tempPos + size)) { + throw InvalidProtocolBufferException.invalidUtf8(); + } + return new String(bytes, tempPos, size, UTF_8); } - return new String(bytes, tempPos, size, UTF_8); } @Override @@ -2332,8 +2441,7 @@ public abstract class CodedInputStream { if (size == 0) { return ByteString.EMPTY; } - // Slow path: Build a byte array first then copy it. - return ByteString.wrap(readRawBytesSlowPath(size)); + return readBytesSlowPath(size); } @Override @@ -2558,13 +2666,13 @@ public abstract class CodedInputStream { public int readRawLittleEndian32() throws IOException { int tempPos = pos; - if (bufferSize - tempPos < FIXED_32_SIZE) { - refillBuffer(FIXED_32_SIZE); + if (bufferSize - tempPos < FIXED32_SIZE) { + refillBuffer(FIXED32_SIZE); tempPos = pos; } final byte[] buffer = this.buffer; - pos = tempPos + FIXED_32_SIZE; + pos = tempPos + FIXED32_SIZE; return (((buffer[tempPos] & 0xff)) | ((buffer[tempPos + 1] & 0xff) << 8) | ((buffer[tempPos + 2] & 0xff) << 16) @@ -2575,13 +2683,13 @@ public abstract class CodedInputStream { public long readRawLittleEndian64() throws IOException { int tempPos = pos; - if (bufferSize - tempPos < FIXED_64_SIZE) { - refillBuffer(FIXED_64_SIZE); + if (bufferSize - tempPos < FIXED64_SIZE) { + refillBuffer(FIXED64_SIZE); tempPos = pos; } final byte[] buffer = this.buffer; - pos = tempPos + FIXED_64_SIZE; + pos = tempPos + FIXED64_SIZE; return (((buffer[tempPos] & 0xffL)) | ((buffer[tempPos + 1] & 0xffL) << 8) | ((buffer[tempPos + 2] & 0xffL) << 16) @@ -2675,7 +2783,13 @@ public abstract class CodedInputStream { */ private void refillBuffer(int n) throws IOException { if (!tryRefillBuffer(n)) { - throw InvalidProtocolBufferException.truncatedMessage(); + // We have to distinguish the exception between sizeLimitExceeded and truncatedMessage. So + // we just throw an sizeLimitExceeded exception here if it exceeds the sizeLimit + if (n > sizeLimit - totalBytesRetired - pos) { + throw InvalidProtocolBufferException.sizeLimitExceeded(); + } else { + throw InvalidProtocolBufferException.truncatedMessage(); + } } } @@ -2684,8 +2798,8 @@ public abstract class CodedInputStream { * buffer. Caller must ensure that the requested space is not yet available, and that the * requested space is less than BUFFER_SIZE. * - * @return {@code true} if the bytes could be made available; {@code false} if the end of the - * stream or the current limit was reached. + * @return {@code true} If the bytes could be made available; {@code false} 1. Current at the + * end of the stream 2. The current limit was reached 3. The total size limit was reached */ private boolean tryRefillBuffer(int n) throws IOException { if (pos + n <= bufferSize) { @@ -2693,6 +2807,14 @@ public abstract class CodedInputStream { "refillBuffer() called when " + n + " bytes were already available in buffer"); } + // Check whether the size of total message needs to read is bigger than the size limit. + // We shouldn't throw an exception here as isAtEnd() function needs to get this function's + // return as the result. + if (n > sizeLimit - totalBytesRetired - pos) { + return false; + } + + // Shouldn't throw the exception here either. if (totalBytesRetired + pos + n > currentLimit) { // Oops, we hit a limit. return false; @@ -2712,7 +2834,16 @@ public abstract class CodedInputStream { pos = 0; } - int bytesRead = input.read(buffer, bufferSize, buffer.length - bufferSize); + // Here we should refill the buffer as many bytes as possible. + int bytesRead = + input.read( + buffer, + bufferSize, + Math.min( + // the size of allocated but unused bytes in the buffer + buffer.length - bufferSize, + // do not exceed the total bytes limit + sizeLimit - totalBytesRetired - bufferSize)); if (bytesRead == 0 || bytesRead < -1 || bytesRead > buffer.length) { throw new IllegalStateException( "InputStream#read(byte[]) returned invalid result: " @@ -2721,10 +2852,6 @@ public abstract class CodedInputStream { } if (bytesRead > 0) { bufferSize += bytesRead; - // Integer-overflow-conscious check against sizeLimit - if (totalBytesRetired + n - sizeLimit > 0) { - throw InvalidProtocolBufferException.sizeLimitExceeded(); - } recomputeBufferSizeAfterLimit(); return (bufferSize >= n) ? true : tryRefillBuffer(n); } @@ -2756,6 +2883,49 @@ public abstract class CodedInputStream { * (bufferSize - pos) && size > 0) */ private byte[] readRawBytesSlowPath(final int size) throws IOException { + // Attempt to read the data in one byte array when it's safe to do. + byte[] result = readRawBytesSlowPathOneChunk(size); + if (result != null) { + return result; + } + + final int originalBufferPos = pos; + final int bufferedBytes = bufferSize - pos; + + // Mark the current buffer consumed. + totalBytesRetired += bufferSize; + pos = 0; + bufferSize = 0; + + // Determine the number of bytes we need to read from the input stream. + int sizeLeft = size - bufferedBytes; + + // The size is very large. For security reasons we read them in small + // chunks. + List<byte[]> chunks = readRawBytesSlowPathRemainingChunks(sizeLeft); + + // OK, got everything. Now concatenate it all into one buffer. + final byte[] bytes = new byte[size]; + + // Start by copying the leftover bytes from this.buffer. + System.arraycopy(buffer, originalBufferPos, bytes, 0, bufferedBytes); + + // And now all the chunks. + int tempPos = bufferedBytes; + for (final byte[] chunk : chunks) { + System.arraycopy(chunk, 0, bytes, tempPos, chunk.length); + tempPos += chunk.length; + } + + // Done. + return bytes; + } + + /** + * Attempts to read the data in one byte array when it's safe to do. Returns null if the size to + * read is too large and needs to be allocated in smaller chunks for security reasons. + */ + private byte[] readRawBytesSlowPathOneChunk(final int size) throws IOException { if (size == 0) { return Internal.EMPTY_BYTE_ARRAY; } @@ -2776,14 +2946,7 @@ public abstract class CodedInputStream { throw InvalidProtocolBufferException.truncatedMessage(); } - final int originalBufferPos = pos; final int bufferedBytes = bufferSize - pos; - - // Mark the current buffer consumed. - totalBytesRetired += bufferSize; - pos = 0; - bufferSize = 0; - // Determine the number of bytes we need to read from the input stream. int sizeLeft = size - bufferedBytes; // TODO(nathanmittler): Consider using a value larger than DEFAULT_BUFFER_SIZE. @@ -2793,7 +2956,10 @@ public abstract class CodedInputStream { final byte[] bytes = new byte[size]; // Copy all of the buffered bytes to the result buffer. - System.arraycopy(buffer, originalBufferPos, bytes, 0, bufferedBytes); + System.arraycopy(buffer, pos, bytes, 0, bufferedBytes); + totalBytesRetired += bufferSize; + pos = 0; + bufferSize = 0; // Fill the remaining bytes from the input stream. int tempPos = bufferedBytes; @@ -2809,6 +2975,11 @@ public abstract class CodedInputStream { return bytes; } + return null; + } + + /** Reads the remaining data in small chunks from the input stream. */ + private List<byte[]> readRawBytesSlowPathRemainingChunks(int sizeLeft) throws IOException { // The size is very large. For security reasons, we can't allocate the // entire byte array yet. The size comes directly from the input, so a // maliciously-crafted message could provide a bogus very large size in @@ -2834,21 +3005,41 @@ public abstract class CodedInputStream { chunks.add(chunk); } - // OK, got everything. Now concatenate it all into one buffer. - final byte[] bytes = new byte[size]; - - // Start by copying the leftover bytes from this.buffer. - System.arraycopy(buffer, originalBufferPos, bytes, 0, bufferedBytes); + return chunks; + } - // And now all the chunks. - int tempPos = bufferedBytes; - for (final byte[] chunk : chunks) { - System.arraycopy(chunk, 0, bytes, tempPos, chunk.length); - tempPos += chunk.length; + /** + * Like readBytes, but caller must have already checked the fast path: (size <= (bufferSize - + * pos) && size > 0 || size == 0) + */ + private ByteString readBytesSlowPath(final int size) throws IOException { + final byte[] result = readRawBytesSlowPathOneChunk(size); + if (result != null) { + return ByteString.wrap(result); } - // Done. - return bytes; + final int originalBufferPos = pos; + final int bufferedBytes = bufferSize - pos; + + // Mark the current buffer consumed. + totalBytesRetired += bufferSize; + pos = 0; + bufferSize = 0; + + // Determine the number of bytes we need to read from the input stream. + int sizeLeft = size - bufferedBytes; + + // The size is very large. For security reasons we read them in small + // chunks. + List<byte[]> chunks = readRawBytesSlowPathRemainingChunks(sizeLeft); + + // Wrap the byte arrays into a single ByteString. + List<ByteString> byteStrings = new ArrayList<ByteString>(1 + chunks.size()); + byteStrings.add(ByteString.copyFrom(buffer, originalBufferPos, bufferedBytes)); + for (byte[] chunk : chunks) { + byteStrings.add(ByteString.wrap(chunk)); + } + return ByteString.copyFrom(byteStrings); } @Override @@ -2893,4 +3084,859 @@ public abstract class CodedInputStream { pos = size - tempPos; } } + + /** + * Implementation of {@link CodedInputStream} that uses an {@link Iterable <ByteBuffer>} as the + * data source. Requires the use of {@code sun.misc.Unsafe} to perform fast reads on the buffer. + */ + private static final class IterableDirectByteBufferDecoder extends CodedInputStream { + /** The object that need to decode. */ + private Iterable<ByteBuffer> input; + /** The {@link Iterator} with type {@link ByteBuffer} of {@code input} */ + private Iterator<ByteBuffer> iterator; + /** The current ByteBuffer; */ + private ByteBuffer currentByteBuffer; + /** + * If {@code true}, indicates that all the buffer are backing a {@link ByteString} and are + * therefore considered to be an immutable input source. + */ + private boolean immutable; + /** + * If {@code true}, indicates that calls to read {@link ByteString} or {@code byte[]} + * <strong>may</strong> return slices of the underlying buffer, rather than copies. + */ + private boolean enableAliasing; + /** The global total message length limit */ + private int totalBufferSize; + /** The amount of available data in the input beyond {@link #currentLimit}. */ + private int bufferSizeAfterCurrentLimit; + /** The absolute position of the end of the current message. */ + private int currentLimit = Integer.MAX_VALUE; + /** The last tag that was read from this stream. */ + private int lastTag; + /** Total Bytes have been Read from the {@link Iterable} {@link ByteBuffer} */ + private int totalBytesRead; + /** The start position offset of the whole message, used as to reset the totalBytesRead */ + private int startOffset; + /** The current position for current ByteBuffer */ + private long currentByteBufferPos; + + private long currentByteBufferStartPos; + /** + * If the current ByteBuffer is unsafe-direct based, currentAddress is the start address of this + * ByteBuffer; otherwise should be zero. + */ + private long currentAddress; + /** The limit position for current ByteBuffer */ + private long currentByteBufferLimit; + + /** + * The constructor of {@code Iterable<ByteBuffer>} decoder. + * + * @param inputBufs The input data. + * @param size The total size of the input data. + * @param immutableFlag whether the input data is immutable. + */ + private IterableDirectByteBufferDecoder( + Iterable<ByteBuffer> inputBufs, int size, boolean immutableFlag) { + totalBufferSize = size; + input = inputBufs; + iterator = input.iterator(); + immutable = immutableFlag; + startOffset = totalBytesRead = 0; + if (size == 0) { + currentByteBuffer = EMPTY_BYTE_BUFFER; + currentByteBufferPos = 0; + currentByteBufferStartPos = 0; + currentByteBufferLimit = 0; + currentAddress = 0; + } else { + tryGetNextByteBuffer(); + } + } + + /** To get the next ByteBuffer from {@code input}, and then update the parameters */ + private void getNextByteBuffer() throws InvalidProtocolBufferException { + if (!iterator.hasNext()) { + throw InvalidProtocolBufferException.truncatedMessage(); + } + tryGetNextByteBuffer(); + } + + private void tryGetNextByteBuffer() { + currentByteBuffer = iterator.next(); + totalBytesRead += (int) (currentByteBufferPos - currentByteBufferStartPos); + currentByteBufferPos = currentByteBuffer.position(); + currentByteBufferStartPos = currentByteBufferPos; + currentByteBufferLimit = currentByteBuffer.limit(); + currentAddress = UnsafeUtil.addressOffset(currentByteBuffer); + currentByteBufferPos += currentAddress; + currentByteBufferStartPos += currentAddress; + currentByteBufferLimit += currentAddress; + } + + @Override + public int readTag() throws IOException { + if (isAtEnd()) { + lastTag = 0; + return 0; + } + + lastTag = readRawVarint32(); + if (WireFormat.getTagFieldNumber(lastTag) == 0) { + // If we actually read zero (or any tag number corresponding to field + // number zero), that's not a valid tag. + throw InvalidProtocolBufferException.invalidTag(); + } + return lastTag; + } + + @Override + public void checkLastTagWas(final int value) throws InvalidProtocolBufferException { + if (lastTag != value) { + throw InvalidProtocolBufferException.invalidEndTag(); + } + } + + @Override + public int getLastTag() { + return lastTag; + } + + @Override + public boolean skipField(final int tag) throws IOException { + switch (WireFormat.getTagWireType(tag)) { + case WireFormat.WIRETYPE_VARINT: + skipRawVarint(); + return true; + case WireFormat.WIRETYPE_FIXED64: + skipRawBytes(FIXED64_SIZE); + return true; + case WireFormat.WIRETYPE_LENGTH_DELIMITED: + skipRawBytes(readRawVarint32()); + return true; + case WireFormat.WIRETYPE_START_GROUP: + skipMessage(); + checkLastTagWas( + WireFormat.makeTag(WireFormat.getTagFieldNumber(tag), WireFormat.WIRETYPE_END_GROUP)); + return true; + case WireFormat.WIRETYPE_END_GROUP: + return false; + case WireFormat.WIRETYPE_FIXED32: + skipRawBytes(FIXED32_SIZE); + return true; + default: + throw InvalidProtocolBufferException.invalidWireType(); + } + } + + @Override + public boolean skipField(final int tag, final CodedOutputStream output) throws IOException { + switch (WireFormat.getTagWireType(tag)) { + case WireFormat.WIRETYPE_VARINT: + { + long value = readInt64(); + output.writeRawVarint32(tag); + output.writeUInt64NoTag(value); + return true; + } + case WireFormat.WIRETYPE_FIXED64: + { + long value = readRawLittleEndian64(); + output.writeRawVarint32(tag); + output.writeFixed64NoTag(value); + return true; + } + case WireFormat.WIRETYPE_LENGTH_DELIMITED: + { + ByteString value = readBytes(); + output.writeRawVarint32(tag); + output.writeBytesNoTag(value); + return true; + } + case WireFormat.WIRETYPE_START_GROUP: + { + output.writeRawVarint32(tag); + skipMessage(output); + int endtag = + WireFormat.makeTag( + WireFormat.getTagFieldNumber(tag), WireFormat.WIRETYPE_END_GROUP); + checkLastTagWas(endtag); + output.writeRawVarint32(endtag); + return true; + } + case WireFormat.WIRETYPE_END_GROUP: + { + return false; + } + case WireFormat.WIRETYPE_FIXED32: + { + int value = readRawLittleEndian32(); + output.writeRawVarint32(tag); + output.writeFixed32NoTag(value); + return true; + } + default: + throw InvalidProtocolBufferException.invalidWireType(); + } + } + + @Override + public void skipMessage() throws IOException { + while (true) { + final int tag = readTag(); + if (tag == 0 || !skipField(tag)) { + return; + } + } + } + + @Override + public void skipMessage(CodedOutputStream output) throws IOException { + while (true) { + final int tag = readTag(); + if (tag == 0 || !skipField(tag, output)) { + return; + } + } + } + + // ----------------------------------------------------------------- + + @Override + public double readDouble() throws IOException { + return Double.longBitsToDouble(readRawLittleEndian64()); + } + + @Override + public float readFloat() throws IOException { + return Float.intBitsToFloat(readRawLittleEndian32()); + } + + @Override + public long readUInt64() throws IOException { + return readRawVarint64(); + } + + @Override + public long readInt64() throws IOException { + return readRawVarint64(); + } + + @Override + public int readInt32() throws IOException { + return readRawVarint32(); + } + + @Override + public long readFixed64() throws IOException { + return readRawLittleEndian64(); + } + + @Override + public int readFixed32() throws IOException { + return readRawLittleEndian32(); + } + + @Override + public boolean readBool() throws IOException { + return readRawVarint64() != 0; + } + + @Override + public String readString() throws IOException { + final int size = readRawVarint32(); + if (size > 0 && size <= currentByteBufferLimit - currentByteBufferPos) { + byte[] bytes = new byte[size]; + UnsafeUtil.copyMemory(currentByteBufferPos, bytes, 0, size); + String result = new String(bytes, UTF_8); + currentByteBufferPos += size; + return result; + } else if (size > 0 && size <= remaining()) { + // TODO(yilunchong): To use an underlying bytes[] instead of allocating a new bytes[] + byte[] bytes = new byte[size]; + readRawBytesTo(bytes, 0, size); + String result = new String(bytes, UTF_8); + return result; + } + + if (size == 0) { + return ""; + } + if (size < 0) { + throw InvalidProtocolBufferException.negativeSize(); + } + throw InvalidProtocolBufferException.truncatedMessage(); + } + + @Override + public String readStringRequireUtf8() throws IOException { + final int size = readRawVarint32(); + if (size > 0 && size <= currentByteBufferLimit - currentByteBufferPos) { + if (ENABLE_CUSTOM_UTF8_DECODE) { + final int bufferPos = (int) (currentByteBufferPos - currentByteBufferStartPos); + String result = Utf8.decodeUtf8(currentByteBuffer, bufferPos, size); + currentByteBufferPos += size; + return result; + } else { + byte[] bytes = new byte[size]; + UnsafeUtil.copyMemory(currentByteBufferPos, bytes, 0, size); + if (!Utf8.isValidUtf8(bytes)) { + throw InvalidProtocolBufferException.invalidUtf8(); + } + String result = new String(bytes, UTF_8); + currentByteBufferPos += size; + return result; + } + } + if (size >= 0 && size <= remaining()) { + byte[] bytes = new byte[size]; + readRawBytesTo(bytes, 0, size); + if (ENABLE_CUSTOM_UTF8_DECODE) { + return Utf8.decodeUtf8(bytes, 0, size); + } else { + if (!Utf8.isValidUtf8(bytes)) { + throw InvalidProtocolBufferException.invalidUtf8(); + } + String result = new String(bytes, UTF_8); + return result; + } + } + + if (size == 0) { + return ""; + } + if (size <= 0) { + throw InvalidProtocolBufferException.negativeSize(); + } + throw InvalidProtocolBufferException.truncatedMessage(); + } + + @Override + public void readGroup( + final int fieldNumber, + final MessageLite.Builder builder, + final ExtensionRegistryLite extensionRegistry) + throws IOException { + if (recursionDepth >= recursionLimit) { + throw InvalidProtocolBufferException.recursionLimitExceeded(); + } + ++recursionDepth; + builder.mergeFrom(this, extensionRegistry); + checkLastTagWas(WireFormat.makeTag(fieldNumber, WireFormat.WIRETYPE_END_GROUP)); + --recursionDepth; + } + + + @Override + public <T extends MessageLite> T readGroup( + final int fieldNumber, + final Parser<T> parser, + final ExtensionRegistryLite extensionRegistry) + throws IOException { + if (recursionDepth >= recursionLimit) { + throw InvalidProtocolBufferException.recursionLimitExceeded(); + } + ++recursionDepth; + T result = parser.parsePartialFrom(this, extensionRegistry); + checkLastTagWas(WireFormat.makeTag(fieldNumber, WireFormat.WIRETYPE_END_GROUP)); + --recursionDepth; + return result; + } + + @Deprecated + @Override + public void readUnknownGroup(final int fieldNumber, final MessageLite.Builder builder) + throws IOException { + readGroup(fieldNumber, builder, ExtensionRegistryLite.getEmptyRegistry()); + } + + @Override + public void readMessage( + final MessageLite.Builder builder, final ExtensionRegistryLite extensionRegistry) + throws IOException { + final int length = readRawVarint32(); + if (recursionDepth >= recursionLimit) { + throw InvalidProtocolBufferException.recursionLimitExceeded(); + } + final int oldLimit = pushLimit(length); + ++recursionDepth; + builder.mergeFrom(this, extensionRegistry); + checkLastTagWas(0); + --recursionDepth; + popLimit(oldLimit); + } + + + @Override + public <T extends MessageLite> T readMessage( + final Parser<T> parser, final ExtensionRegistryLite extensionRegistry) throws IOException { + int length = readRawVarint32(); + if (recursionDepth >= recursionLimit) { + throw InvalidProtocolBufferException.recursionLimitExceeded(); + } + final int oldLimit = pushLimit(length); + ++recursionDepth; + T result = parser.parsePartialFrom(this, extensionRegistry); + checkLastTagWas(0); + --recursionDepth; + popLimit(oldLimit); + return result; + } + + @Override + public ByteString readBytes() throws IOException { + final int size = readRawVarint32(); + if (size > 0 && size <= currentByteBufferLimit - currentByteBufferPos) { + if (immutable && enableAliasing) { + final int idx = (int) (currentByteBufferPos - currentAddress); + final ByteString result = ByteString.wrap(slice(idx, idx + size)); + currentByteBufferPos += size; + return result; + } else { + byte[] bytes; + bytes = new byte[size]; + UnsafeUtil.copyMemory(currentByteBufferPos, bytes, 0, size); + currentByteBufferPos += size; + return ByteString.wrap(bytes); + } + } else if (size > 0 && size <= remaining()) { + byte[] temp = new byte[size]; + readRawBytesTo(temp, 0, size); + return ByteString.wrap(temp); + } + + if (size == 0) { + return ByteString.EMPTY; + } + if (size < 0) { + throw InvalidProtocolBufferException.negativeSize(); + } + throw InvalidProtocolBufferException.truncatedMessage(); + } + + @Override + public byte[] readByteArray() throws IOException { + return readRawBytes(readRawVarint32()); + } + + @Override + public ByteBuffer readByteBuffer() throws IOException { + final int size = readRawVarint32(); + if (size > 0 && size <= currentRemaining()) { + if (!immutable && enableAliasing) { + currentByteBufferPos += size; + return slice( + (int) (currentByteBufferPos - currentAddress - size), + (int) (currentByteBufferPos - currentAddress)); + } else { + byte[] bytes = new byte[size]; + UnsafeUtil.copyMemory(currentByteBufferPos, bytes, 0, size); + currentByteBufferPos += size; + return ByteBuffer.wrap(bytes); + } + } else if (size > 0 && size <= remaining()) { + byte[] temp = new byte[size]; + readRawBytesTo(temp, 0, size); + return ByteBuffer.wrap(temp); + } + + if (size == 0) { + return EMPTY_BYTE_BUFFER; + } + if (size < 0) { + throw InvalidProtocolBufferException.negativeSize(); + } + throw InvalidProtocolBufferException.truncatedMessage(); + } + + @Override + public int readUInt32() throws IOException { + return readRawVarint32(); + } + + @Override + public int readEnum() throws IOException { + return readRawVarint32(); + } + + @Override + public int readSFixed32() throws IOException { + return readRawLittleEndian32(); + } + + @Override + public long readSFixed64() throws IOException { + return readRawLittleEndian64(); + } + + @Override + public int readSInt32() throws IOException { + return decodeZigZag32(readRawVarint32()); + } + + @Override + public long readSInt64() throws IOException { + return decodeZigZag64(readRawVarint64()); + } + + @Override + public int readRawVarint32() throws IOException { + fastpath: + { + long tempPos = currentByteBufferPos; + + if (currentByteBufferLimit == currentByteBufferPos) { + break fastpath; + } + + int x; + if ((x = UnsafeUtil.getByte(tempPos++)) >= 0) { + currentByteBufferPos++; + return x; + } else if (currentByteBufferLimit - currentByteBufferPos < 10) { + break fastpath; + } else if ((x ^= (UnsafeUtil.getByte(tempPos++) << 7)) < 0) { + x ^= (~0 << 7); + } else if ((x ^= (UnsafeUtil.getByte(tempPos++) << 14)) >= 0) { + x ^= (~0 << 7) ^ (~0 << 14); + } else if ((x ^= (UnsafeUtil.getByte(tempPos++) << 21)) < 0) { + x ^= (~0 << 7) ^ (~0 << 14) ^ (~0 << 21); + } else { + int y = UnsafeUtil.getByte(tempPos++); + x ^= y << 28; + x ^= (~0 << 7) ^ (~0 << 14) ^ (~0 << 21) ^ (~0 << 28); + if (y < 0 + && UnsafeUtil.getByte(tempPos++) < 0 + && UnsafeUtil.getByte(tempPos++) < 0 + && UnsafeUtil.getByte(tempPos++) < 0 + && UnsafeUtil.getByte(tempPos++) < 0 + && UnsafeUtil.getByte(tempPos++) < 0) { + break fastpath; // Will throw malformedVarint() + } + } + currentByteBufferPos = tempPos; + return x; + } + return (int) readRawVarint64SlowPath(); + } + + @Override + public long readRawVarint64() throws IOException { + fastpath: + { + long tempPos = currentByteBufferPos; + + if (currentByteBufferLimit == currentByteBufferPos) { + break fastpath; + } + + long x; + int y; + if ((y = UnsafeUtil.getByte(tempPos++)) >= 0) { + currentByteBufferPos++; + return y; + } else if (currentByteBufferLimit - currentByteBufferPos < 10) { + break fastpath; + } else if ((y ^= (UnsafeUtil.getByte(tempPos++) << 7)) < 0) { + x = y ^ (~0 << 7); + } else if ((y ^= (UnsafeUtil.getByte(tempPos++) << 14)) >= 0) { + x = y ^ ((~0 << 7) ^ (~0 << 14)); + } else if ((y ^= (UnsafeUtil.getByte(tempPos++) << 21)) < 0) { + x = y ^ ((~0 << 7) ^ (~0 << 14) ^ (~0 << 21)); + } else if ((x = y ^ ((long) UnsafeUtil.getByte(tempPos++) << 28)) >= 0L) { + x ^= (~0L << 7) ^ (~0L << 14) ^ (~0L << 21) ^ (~0L << 28); + } else if ((x ^= ((long) UnsafeUtil.getByte(tempPos++) << 35)) < 0L) { + x ^= (~0L << 7) ^ (~0L << 14) ^ (~0L << 21) ^ (~0L << 28) ^ (~0L << 35); + } else if ((x ^= ((long) UnsafeUtil.getByte(tempPos++) << 42)) >= 0L) { + x ^= (~0L << 7) ^ (~0L << 14) ^ (~0L << 21) ^ (~0L << 28) ^ (~0L << 35) ^ (~0L << 42); + } else if ((x ^= ((long) UnsafeUtil.getByte(tempPos++) << 49)) < 0L) { + x ^= + (~0L << 7) + ^ (~0L << 14) + ^ (~0L << 21) + ^ (~0L << 28) + ^ (~0L << 35) + ^ (~0L << 42) + ^ (~0L << 49); + } else { + x ^= ((long) UnsafeUtil.getByte(tempPos++) << 56); + x ^= + (~0L << 7) + ^ (~0L << 14) + ^ (~0L << 21) + ^ (~0L << 28) + ^ (~0L << 35) + ^ (~0L << 42) + ^ (~0L << 49) + ^ (~0L << 56); + if (x < 0L) { + if (UnsafeUtil.getByte(tempPos++) < 0L) { + break fastpath; // Will throw malformedVarint() + } + } + } + currentByteBufferPos = tempPos; + return x; + } + return readRawVarint64SlowPath(); + } + + @Override + long readRawVarint64SlowPath() throws IOException { + long result = 0; + for (int shift = 0; shift < 64; shift += 7) { + final byte b = readRawByte(); + result |= (long) (b & 0x7F) << shift; + if ((b & 0x80) == 0) { + return result; + } + } + throw InvalidProtocolBufferException.malformedVarint(); + } + + @Override + public int readRawLittleEndian32() throws IOException { + if (currentRemaining() >= FIXED32_SIZE) { + long tempPos = currentByteBufferPos; + currentByteBufferPos += FIXED32_SIZE; + return (((UnsafeUtil.getByte(tempPos) & 0xff)) + | ((UnsafeUtil.getByte(tempPos + 1) & 0xff) << 8) + | ((UnsafeUtil.getByte(tempPos + 2) & 0xff) << 16) + | ((UnsafeUtil.getByte(tempPos + 3) & 0xff) << 24)); + } + return ((readRawByte() & 0xff) + | ((readRawByte() & 0xff) << 8) + | ((readRawByte() & 0xff) << 16) + | ((readRawByte() & 0xff) << 24)); + } + + @Override + public long readRawLittleEndian64() throws IOException { + if (currentRemaining() >= FIXED64_SIZE) { + long tempPos = currentByteBufferPos; + currentByteBufferPos += FIXED64_SIZE; + return (((UnsafeUtil.getByte(tempPos) & 0xffL)) + | ((UnsafeUtil.getByte(tempPos + 1) & 0xffL) << 8) + | ((UnsafeUtil.getByte(tempPos + 2) & 0xffL) << 16) + | ((UnsafeUtil.getByte(tempPos + 3) & 0xffL) << 24) + | ((UnsafeUtil.getByte(tempPos + 4) & 0xffL) << 32) + | ((UnsafeUtil.getByte(tempPos + 5) & 0xffL) << 40) + | ((UnsafeUtil.getByte(tempPos + 6) & 0xffL) << 48) + | ((UnsafeUtil.getByte(tempPos + 7) & 0xffL) << 56)); + } + return ((readRawByte() & 0xffL) + | ((readRawByte() & 0xffL) << 8) + | ((readRawByte() & 0xffL) << 16) + | ((readRawByte() & 0xffL) << 24) + | ((readRawByte() & 0xffL) << 32) + | ((readRawByte() & 0xffL) << 40) + | ((readRawByte() & 0xffL) << 48) + | ((readRawByte() & 0xffL) << 56)); + } + + @Override + public void enableAliasing(boolean enabled) { + this.enableAliasing = enabled; + } + + @Override + public void resetSizeCounter() { + startOffset = (int) (totalBytesRead + currentByteBufferPos - currentByteBufferStartPos); + } + + @Override + public int pushLimit(int byteLimit) throws InvalidProtocolBufferException { + if (byteLimit < 0) { + throw InvalidProtocolBufferException.negativeSize(); + } + byteLimit += getTotalBytesRead(); + final int oldLimit = currentLimit; + if (byteLimit > oldLimit) { + throw InvalidProtocolBufferException.truncatedMessage(); + } + currentLimit = byteLimit; + + recomputeBufferSizeAfterLimit(); + + return oldLimit; + } + + private void recomputeBufferSizeAfterLimit() { + totalBufferSize += bufferSizeAfterCurrentLimit; + final int bufferEnd = totalBufferSize - startOffset; + if (bufferEnd > currentLimit) { + // Limit is in current buffer. + bufferSizeAfterCurrentLimit = bufferEnd - currentLimit; + totalBufferSize -= bufferSizeAfterCurrentLimit; + } else { + bufferSizeAfterCurrentLimit = 0; + } + } + + @Override + public void popLimit(final int oldLimit) { + currentLimit = oldLimit; + recomputeBufferSizeAfterLimit(); + } + + @Override + public int getBytesUntilLimit() { + if (currentLimit == Integer.MAX_VALUE) { + return -1; + } + + return currentLimit - getTotalBytesRead(); + } + + @Override + public boolean isAtEnd() throws IOException { + return totalBytesRead + currentByteBufferPos - currentByteBufferStartPos == totalBufferSize; + } + + @Override + public int getTotalBytesRead() { + return (int) + (totalBytesRead - startOffset + currentByteBufferPos - currentByteBufferStartPos); + } + + @Override + public byte readRawByte() throws IOException { + if (currentRemaining() == 0) { + getNextByteBuffer(); + } + return UnsafeUtil.getByte(currentByteBufferPos++); + } + + @Override + public byte[] readRawBytes(final int length) throws IOException { + if (length >= 0 && length <= currentRemaining()) { + byte[] bytes = new byte[length]; + UnsafeUtil.copyMemory(currentByteBufferPos, bytes, 0, length); + currentByteBufferPos += length; + return bytes; + } + if (length >= 0 && length <= remaining()) { + byte[] bytes = new byte[length]; + readRawBytesTo(bytes, 0, length); + return bytes; + } + + if (length <= 0) { + if (length == 0) { + return EMPTY_BYTE_ARRAY; + } else { + throw InvalidProtocolBufferException.negativeSize(); + } + } + + throw InvalidProtocolBufferException.truncatedMessage(); + } + + /** + * Try to get raw bytes from {@code input} with the size of {@code length} and copy to {@code + * bytes} array. If the size is bigger than the number of remaining bytes in the input, then + * throw {@code truncatedMessage} exception. + * + * @param bytes + * @param offset + * @param length + * @throws IOException + */ + private void readRawBytesTo(byte[] bytes, int offset, final int length) throws IOException { + if (length >= 0 && length <= remaining()) { + int l = length; + while (l > 0) { + if (currentRemaining() == 0) { + getNextByteBuffer(); + } + int bytesToCopy = Math.min(l, (int) currentRemaining()); + UnsafeUtil.copyMemory(currentByteBufferPos, bytes, length - l + offset, bytesToCopy); + l -= bytesToCopy; + currentByteBufferPos += bytesToCopy; + } + return; + } + + if (length <= 0) { + if (length == 0) { + return; + } else { + throw InvalidProtocolBufferException.negativeSize(); + } + } + throw InvalidProtocolBufferException.truncatedMessage(); + } + + @Override + public void skipRawBytes(final int length) throws IOException { + if (length >= 0 + && length + <= (totalBufferSize + - totalBytesRead + - currentByteBufferPos + + currentByteBufferStartPos)) { + // We have all the bytes we need already. + int l = length; + while (l > 0) { + if (currentRemaining() == 0) { + getNextByteBuffer(); + } + int rl = Math.min(l, (int) currentRemaining()); + l -= rl; + currentByteBufferPos += rl; + } + return; + } + + if (length < 0) { + throw InvalidProtocolBufferException.negativeSize(); + } + throw InvalidProtocolBufferException.truncatedMessage(); + } + + // TODO: optimize to fastpath + private void skipRawVarint() throws IOException { + for (int i = 0; i < MAX_VARINT_SIZE; i++) { + if (readRawByte() >= 0) { + return; + } + } + throw InvalidProtocolBufferException.malformedVarint(); + } + + /** + * Try to get the number of remaining bytes in {@code input}. + * + * @return the number of remaining bytes in {@code input}. + */ + private int remaining() { + return (int) + (totalBufferSize - totalBytesRead - currentByteBufferPos + currentByteBufferStartPos); + } + + /** + * Try to get the number of remaining bytes in {@code currentByteBuffer}. + * + * @return the number of remaining bytes in {@code currentByteBuffer} + */ + private long currentRemaining() { + return (currentByteBufferLimit - currentByteBufferPos); + } + + private ByteBuffer slice(int begin, int end) throws IOException { + int prevPos = currentByteBuffer.position(); + int prevLimit = currentByteBuffer.limit(); + try { + currentByteBuffer.position(begin); + currentByteBuffer.limit(end); + return currentByteBuffer.slice(); + } catch (IllegalArgumentException e) { + throw InvalidProtocolBufferException.truncatedMessage(); + } finally { + currentByteBuffer.position(prevPos); + currentByteBuffer.limit(prevLimit); + } + } + } } |