diff options
Diffstat (limited to 'java/core/src/main/java')
50 files changed, 6830 insertions, 1737 deletions
diff --git a/java/core/src/main/java/com/google/protobuf/AbstractMessage.java b/java/core/src/main/java/com/google/protobuf/AbstractMessage.java index 7639efcf..fc3c2a5d 100644 --- a/java/core/src/main/java/com/google/protobuf/AbstractMessage.java +++ b/java/core/src/main/java/com/google/protobuf/AbstractMessage.java @@ -32,9 +32,9 @@ package com.google.protobuf; import com.google.protobuf.Descriptors.EnumValueDescriptor; import com.google.protobuf.Descriptors.FieldDescriptor; +import com.google.protobuf.Descriptors.FileDescriptor.Syntax; import com.google.protobuf.Descriptors.OneofDescriptor; import com.google.protobuf.Internal.EnumLite; - import java.io.IOException; import java.io.InputStream; import java.util.Arrays; @@ -125,6 +125,16 @@ public abstract class AbstractMessage protected int memoizedSize = -1; @Override + int getMemoizedSerializedSize() { + return memoizedSize; + } + + @Override + void setMemoizedSerializedSize(int size) { + memoizedSize = size; + } + + @Override public int getSerializedSize() { int size = memoizedSize; if (size != -1) { @@ -163,7 +173,7 @@ public abstract class AbstractMessage } return hash; } - + private static ByteString toByteString(Object value) { if (value instanceof byte[]) { return ByteString.copyFrom((byte[]) value); @@ -171,7 +181,7 @@ public abstract class AbstractMessage return (ByteString) value; } } - + /** * Compares two bytes fields. The parameters must be either a byte array or a * ByteString object. They can be of different type though. @@ -182,7 +192,7 @@ public abstract class AbstractMessage } return toByteString(a).equals(toByteString(b)); } - + /** * Converts a list of MapEntry messages into a Map used for equals() and * hashCode(). @@ -213,7 +223,7 @@ public abstract class AbstractMessage } return result; } - + /** * Compares two map fields. The parameters must be a list of MapEntry * messages. @@ -224,13 +234,13 @@ public abstract class AbstractMessage Map mb = convertMapEntryListToMap((List) b); return MapFieldLite.equals(ma, mb); } - + /** * Compares two set of fields. * This method is used to implement {@link AbstractMessage#equals(Object)} * and {@link AbstractMutableMessage#equals(Object)}. It takes special care * of bytes fields because immutable messages and mutable messages use - * different Java type to reprensent a bytes field and this method should be + * different Java type to represent a bytes field and this method should be * able to compare immutable messages, mutable messages and also an immutable * message to a mutable message. */ @@ -276,7 +286,7 @@ public abstract class AbstractMessage } return true; } - + /** * Calculates the hash code of a map field. {@code value} must be a list of * MapEntry messages. @@ -328,8 +338,12 @@ public abstract class AbstractMessage extends AbstractMessageLite.Builder implements Message.Builder { // The compiler produces an error if this is not declared explicitly. + // Method isn't abstract to bypass Java 1.6 compiler issue: + // http://bugs.java.com/view_bug.do?bug_id=6908259 @Override - public abstract BuilderType clone(); + public BuilderType clone() { + throw new UnsupportedOperationException("clone() should be implemented in subclasses."); + } /** TODO(jieluo): Clear it when all subclasses have implemented this method. */ @Override @@ -368,7 +382,7 @@ public abstract class AbstractMessage public String getInitializationErrorString() { return MessageReflection.delimitWithCommas(findInitializationErrors()); } - + @Override protected BuilderType internalMergeFrom(AbstractMessageLite other) { return mergeFrom((Message) other); @@ -376,6 +390,10 @@ public abstract class AbstractMessage @Override public BuilderType mergeFrom(final Message other) { + return mergeFrom(other, other.getAllFields()); + } + + BuilderType mergeFrom(final Message other, Map<FieldDescriptor, Object> allFields) { if (other.getDescriptorForType() != getDescriptorForType()) { throw new IllegalArgumentException( "mergeFrom(Message) can only merge messages of the same type."); @@ -390,8 +408,7 @@ public abstract class AbstractMessage // TODO(kenton): Provide a function somewhere called makeDeepCopy() // which allows people to make secure deep copies of messages. - for (final Map.Entry<FieldDescriptor, Object> entry : - other.getAllFields().entrySet()) { + for (final Map.Entry<FieldDescriptor, Object> entry : allFields.entrySet()) { final FieldDescriptor field = entry.getKey(); if (field.isRepeated()) { for (final Object element : (List)entry.getValue()) { @@ -429,8 +446,12 @@ public abstract class AbstractMessage final CodedInputStream input, final ExtensionRegistryLite extensionRegistry) throws IOException { + boolean discardUnknown = + getDescriptorForType().getFile().getSyntax() == Syntax.PROTO3 + ? input.shouldDiscardUnknownFieldsProto3() + : input.shouldDiscardUnknownFields(); final UnknownFieldSet.Builder unknownFields = - UnknownFieldSet.newBuilder(getUnknownFields()); + discardUnknown ? null : UnknownFieldSet.newBuilder(getUnknownFields()); while (true) { final int tag = input.readTag(); if (tag == 0) { @@ -448,7 +469,9 @@ public abstract class AbstractMessage break; } } - setUnknownFields(unknownFields.build()); + if (unknownFields != null) { + setUnknownFields(unknownFields.build()); + } return (BuilderType) this; } diff --git a/java/core/src/main/java/com/google/protobuf/AbstractMessageLite.java b/java/core/src/main/java/com/google/protobuf/AbstractMessageLite.java index 046030f3..b22bbaab 100644 --- a/java/core/src/main/java/com/google/protobuf/AbstractMessageLite.java +++ b/java/core/src/main/java/com/google/protobuf/AbstractMessageLite.java @@ -30,11 +30,15 @@ package com.google.protobuf; +import static com.google.protobuf.Internal.checkNotNull; + import java.io.FilterInputStream; import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; +import java.util.ArrayList; import java.util.Collection; +import java.util.List; /** * A partial implementation of the {@link MessageLite} interface which @@ -48,7 +52,6 @@ public abstract class AbstractMessageLite< BuilderType extends AbstractMessageLite.Builder<MessageType, BuilderType>> implements MessageLite { protected int memoizedHashCode = 0; - @Override public ByteString toByteString() { try { @@ -96,6 +99,16 @@ public abstract class AbstractMessageLite< codedOutput.flush(); } + // We'd like these to be abstract but some folks are extending this class directly. They shouldn't + // be doing that and they should feel bad. + int getMemoizedSerializedSize() { + throw new UnsupportedOperationException(); + } + + void setMemoizedSerializedSize(int size) { + throw new UnsupportedOperationException(); + } + /** * Package private helper method for AbstractParser to create @@ -117,8 +130,13 @@ public abstract class AbstractMessageLite< } } - protected static <T> void addAll(final Iterable<T> values, - final Collection<? super T> list) { + // For binary compatibility + @Deprecated + protected static <T> void addAll(final Iterable<T> values, final Collection<? super T> list) { + Builder.addAll(values, (List) list); + } + + protected static <T> void addAll(final Iterable<T> values, final List<? super T> list) { Builder.addAll(values, list); } @@ -333,6 +351,25 @@ public abstract class AbstractMessageLite< + " threw an IOException (should never happen)."; } + // We check nulls as we iterate to avoid iterating over values twice. + private static <T> void addAllCheckingNulls(Iterable<T> values, List<? super T> list) { + if (list instanceof ArrayList && values instanceof Collection) { + ((ArrayList<T>) list).ensureCapacity(list.size() + ((Collection<T>) values).size()); + } + int begin = list.size(); + for (T value : values) { + if (value == null) { + // encountered a null value so we must undo our modifications prior to throwing + String message = "Element at index " + (list.size() - begin) + " is null."; + for (int i = list.size() - 1; i >= begin; i--) { + list.remove(i); + } + throw new NullPointerException(message); + } + list.add(value); + } + } + /** * Construct an UninitializedMessageException reporting missing fields in * the given message. @@ -342,41 +379,50 @@ public abstract class AbstractMessageLite< return new UninitializedMessageException(message); } + // For binary compatibility. + @Deprecated + protected static <T> void addAll(final Iterable<T> values, final Collection<? super T> list) { + addAll(values, (List<T>) list); + } + /** - * Adds the {@code values} to the {@code list}. This is a helper method - * used by generated code. Users should ignore it. + * Adds the {@code values} to the {@code list}. This is a helper method used by generated code. + * Users should ignore it. * - * @throws NullPointerException if {@code values} or any of the elements of - * {@code values} is null. When that happens, some elements of - * {@code values} may have already been added to the result {@code list}. + * @throws NullPointerException if {@code values} or any of the elements of {@code values} is + * null. */ - protected static <T> void addAll(final Iterable<T> values, - final Collection<? super T> list) { - if (values == null) { - throw new NullPointerException(); - } + protected static <T> void addAll(final Iterable<T> values, final List<? super T> list) { + checkNotNull(values); if (values instanceof LazyStringList) { // For StringOrByteStringLists, check the underlying elements to avoid // forcing conversions of ByteStrings to Strings. - checkForNullValues(((LazyStringList) values).getUnderlyingElements()); - list.addAll((Collection<T>) values); - } else if (values instanceof Collection) { - checkForNullValues(values); - list.addAll((Collection<T>) values); - } else { - for (final T value : values) { + // TODO(dweis): Could we just prohibit nulls in all protobuf lists and get rid of this? Is + // if even possible to hit this condition as all protobuf methods check for null first, + // right? + List<?> lazyValues = ((LazyStringList) values).getUnderlyingElements(); + LazyStringList lazyList = (LazyStringList) list; + int begin = list.size(); + for (Object value : lazyValues) { if (value == null) { - throw new NullPointerException(); + // encountered a null value so we must undo our modifications prior to throwing + String message = "Element at index " + (lazyList.size() - begin) + " is null."; + for (int i = lazyList.size() - 1; i >= begin; i--) { + lazyList.remove(i); + } + throw new NullPointerException(message); + } + if (value instanceof ByteString) { + lazyList.add((ByteString) value); + } else { + lazyList.add((String) value); } - list.add(value); } - } - } - - private static void checkForNullValues(final Iterable<?> values) { - for (final Object value : values) { - if (value == null) { - throw new NullPointerException(); + } else { + if (values instanceof PrimitiveNonBoxingCollection) { + list.addAll((Collection<T>) values); + } else { + addAllCheckingNulls(values, list); } } } diff --git a/java/core/src/main/java/com/google/protobuf/AbstractParser.java b/java/core/src/main/java/com/google/protobuf/AbstractParser.java index 66b0ee3b..ba570e3d 100644 --- a/java/core/src/main/java/com/google/protobuf/AbstractParser.java +++ b/java/core/src/main/java/com/google/protobuf/AbstractParser.java @@ -31,9 +31,9 @@ package com.google.protobuf; import com.google.protobuf.AbstractMessageLite.Builder.LimitedInputStream; - import java.io.IOException; import java.io.InputStream; +import java.nio.ByteBuffer; /** * A partial implementation of the {@link Parser} interface which implements @@ -131,6 +131,30 @@ public abstract class AbstractParser<MessageType extends MessageLite> } @Override + public MessageType parseFrom(ByteBuffer data, ExtensionRegistryLite extensionRegistry) + throws InvalidProtocolBufferException { + MessageType message; + try { + CodedInputStream input = CodedInputStream.newInstance(data); + message = parsePartialFrom(input, extensionRegistry); + try { + input.checkLastTagWas(0); + } catch (InvalidProtocolBufferException e) { + throw e.setUnfinishedMessage(message); + } + } catch (InvalidProtocolBufferException e) { + throw e; + } + + return checkMessageInitialized(message); + } + + @Override + public MessageType parseFrom(ByteBuffer data) throws InvalidProtocolBufferException { + return parseFrom(data, EMPTY_REGISTRY); + } + + @Override public MessageType parsePartialFrom( byte[] data, int off, int len, ExtensionRegistryLite extensionRegistry) throws InvalidProtocolBufferException { @@ -232,7 +256,7 @@ public abstract class AbstractParser<MessageType extends MessageLite> } size = CodedInputStream.readRawVarint32(firstByte, input); } catch (IOException e) { - throw new InvalidProtocolBufferException(e.getMessage()); + throw new InvalidProtocolBufferException(e); } InputStream limitedInput = new LimitedInputStream(input, size); return parsePartialFrom(limitedInput, extensionRegistry); diff --git a/java/core/src/main/java/com/google/protobuf/Android.java b/java/core/src/main/java/com/google/protobuf/Android.java new file mode 100644 index 00000000..cad54783 --- /dev/null +++ b/java/core/src/main/java/com/google/protobuf/Android.java @@ -0,0 +1,57 @@ +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// https://developers.google.com/protocol-buffers/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +package com.google.protobuf; + +final class Android { + + private static final Class<?> MEMORY_CLASS = getClassForName("libcore.io.Memory"); + private static final boolean IS_ROBOLECTRIC = + getClassForName("org.robolectric.Robolectric") != null; + + /** Returns {@code true} if running on an Android device. */ + static boolean isOnAndroidDevice() { + return MEMORY_CLASS != null && !IS_ROBOLECTRIC; + } + + /** Returns the memory class or {@code null} if not on Android device. */ + static Class<?> getMemoryClass() { + return MEMORY_CLASS; + } + + @SuppressWarnings("unchecked") + private static <T> Class<T> getClassForName(String name) { + try { + return (Class<T>) Class.forName(name); + } catch (Throwable e) { + return null; + } + } +} diff --git a/java/core/src/main/java/com/google/protobuf/BooleanArrayList.java b/java/core/src/main/java/com/google/protobuf/BooleanArrayList.java index 0d9f87ba..4d7a9727 100644 --- a/java/core/src/main/java/com/google/protobuf/BooleanArrayList.java +++ b/java/core/src/main/java/com/google/protobuf/BooleanArrayList.java @@ -30,8 +30,9 @@ package com.google.protobuf; -import com.google.protobuf.Internal.BooleanList; +import static com.google.protobuf.Internal.checkNotNull; +import com.google.protobuf.Internal.BooleanList; import java.util.Arrays; import java.util.Collection; import java.util.RandomAccess; @@ -41,9 +42,8 @@ import java.util.RandomAccess; * * @author dweis@google.com (Daniel Weis) */ -final class BooleanArrayList - extends AbstractProtobufList<Boolean> - implements BooleanList, RandomAccess { +final class BooleanArrayList extends AbstractProtobufList<Boolean> + implements BooleanList, RandomAccess, PrimitiveNonBoxingCollection { private static final BooleanArrayList EMPTY_LIST = new BooleanArrayList(); static { @@ -82,6 +82,18 @@ final class BooleanArrayList } @Override + protected void removeRange(int fromIndex, int toIndex) { + ensureIsMutable(); + if (toIndex < fromIndex) { + throw new IndexOutOfBoundsException("toIndex < fromIndex"); + } + + System.arraycopy(array, toIndex, array, fromIndex, size - toIndex); + size -= (toIndex - fromIndex); + modCount++; + } + + @Override public boolean equals(Object o) { if (this == o) { return true; @@ -198,9 +210,7 @@ final class BooleanArrayList public boolean addAll(Collection<? extends Boolean> collection) { ensureIsMutable(); - if (collection == null) { - throw new NullPointerException(); - } + checkNotNull(collection); // We specialize when adding another BooleanArrayList to avoid boxing elements. if (!(collection instanceof BooleanArrayList)) { @@ -248,7 +258,9 @@ final class BooleanArrayList ensureIsMutable(); ensureIndexInRange(index); boolean value = array[index]; - System.arraycopy(array, index + 1, array, index, size - index); + if (index < size - 1) { + System.arraycopy(array, index + 1, array, index, size - index); + } size--; modCount++; return value; diff --git a/java/core/src/main/java/com/google/protobuf/ByteBufferWriter.java b/java/core/src/main/java/com/google/protobuf/ByteBufferWriter.java index 0cc38175..6157a52f 100644 --- a/java/core/src/main/java/com/google/protobuf/ByteBufferWriter.java +++ b/java/core/src/main/java/com/google/protobuf/ByteBufferWriter.java @@ -33,11 +33,12 @@ package com.google.protobuf; import static java.lang.Math.max; import static java.lang.Math.min; -import java.io.FileOutputStream; import java.io.IOException; import java.io.OutputStream; import java.lang.ref.SoftReference; +import java.lang.reflect.Field; import java.nio.ByteBuffer; +import java.nio.channels.WritableByteChannel; /** * Utility class to provide efficient writing of {@link ByteBuffer}s to {@link OutputStream}s. @@ -75,6 +76,12 @@ final class ByteBufferWriter { new ThreadLocal<SoftReference<byte[]>>(); /** + * This is a hack for GAE, where {@code FileOutputStream} is unavailable. + */ + private static final Class<?> FILE_OUTPUT_STREAM_CLASS = safeGetClass("java.io.FileOutputStream"); + private static final long CHANNEL_FIELD_OFFSET = getChannelFieldOffset(FILE_OUTPUT_STREAM_CLASS); + + /** * For testing purposes only. Clears the cached buffer to force a new allocation on the next * invocation. */ @@ -93,10 +100,7 @@ final class ByteBufferWriter { // Optimized write for array-backed buffers. // Note that we're taking the risk that a malicious OutputStream could modify the array. output.write(buffer.array(), buffer.arrayOffset() + buffer.position(), buffer.remaining()); - } else if (output instanceof FileOutputStream) { - // Use a channel to write out the ByteBuffer. This will automatically empty the buffer. - ((FileOutputStream) output).getChannel().write(buffer); - } else { + } else if (!writeToChannel(buffer, output)){ // Read all of the data from the buffer to an array. // TODO(nathanmittler): Consider performance improvements for other "known" stream types. final byte[] array = getOrCreateBuffer(buffer.remaining()); @@ -142,4 +146,40 @@ final class ByteBufferWriter { private static void setBuffer(byte[] value) { BUFFER.set(new SoftReference<byte[]>(value)); } + + private static boolean writeToChannel(ByteBuffer buffer, OutputStream output) throws IOException { + if (CHANNEL_FIELD_OFFSET >= 0 && FILE_OUTPUT_STREAM_CLASS.isInstance(output)) { + // Use a channel to write out the ByteBuffer. This will automatically empty the buffer. + WritableByteChannel channel = null; + try { + channel = (WritableByteChannel) UnsafeUtil.getObject(output, CHANNEL_FIELD_OFFSET); + } catch (ClassCastException e) { + // Absorb. + } + if (channel != null) { + channel.write(buffer); + return true; + } + } + return false; + } + + private static Class<?> safeGetClass(String className) { + try { + return Class.forName(className); + } catch (ClassNotFoundException e) { + return null; + } + } + private static long getChannelFieldOffset(Class<?> clazz) { + try { + if (clazz != null && UnsafeUtil.hasUnsafeArrayOperations()) { + Field field = clazz.getDeclaredField("channel"); + return UnsafeUtil.objectFieldOffset(field); + } + } catch (Throwable e) { + // Absorb + } + return -1; + } } diff --git a/java/core/src/main/java/com/google/protobuf/ByteString.java b/java/core/src/main/java/com/google/protobuf/ByteString.java index 3f3f9f3c..d67bb54a 100644 --- a/java/core/src/main/java/com/google/protobuf/ByteString.java +++ b/java/core/src/main/java/com/google/protobuf/ByteString.java @@ -51,14 +51,12 @@ import java.util.List; import java.util.NoSuchElementException; /** - * Immutable sequence of bytes. Substring is supported by sharing the reference - * to the immutable underlying bytes. Concatenation is likewise supported - * without copying (long strings) by building a tree of pieces in - * {@link RopeByteString}. - * <p> - * Like {@link String}, the contents of a {@link ByteString} can never be - * observed to change, not even in the presence of a data race or incorrect - * API usage in the client code. + * Immutable sequence of bytes. Substring is supported by sharing the reference to the immutable + * underlying bytes. Concatenation is likewise supported without copying (long strings) by building + * a tree of pieces in {@link RopeByteString}. + * + * <p>Like {@link String}, the contents of a {@link ByteString} can never be observed to change, not + * even in the presence of a data race or incorrect API usage in the client code. * * @author crazybob@google.com Bob Lee * @author kenton@google.com Kenton Varda @@ -126,14 +124,8 @@ public abstract class ByteString implements Iterable<Byte>, Serializable { private static final ByteArrayCopier byteArrayCopier; static { - boolean isAndroid = true; - try { - Class.forName("android.content.Context"); - } catch (ClassNotFoundException e) { - isAndroid = false; - } - - byteArrayCopier = isAndroid ? new SystemByteArrayCopier() : new ArraysByteArrayCopier(); + byteArrayCopier = + Android.isOnAndroidDevice() ? new SystemByteArrayCopier() : new ArraysByteArrayCopier(); } /** @@ -311,6 +303,18 @@ public abstract class ByteString implements Iterable<Byte>, Serializable { } /** + * Wraps the given bytes into a {@code ByteString}. Intended for internal only usage. + */ + static ByteString wrap(ByteBuffer buffer) { + if (buffer.hasArray()) { + final int offset = buffer.arrayOffset(); + return ByteString.wrap(buffer.array(), offset + buffer.position(), buffer.remaining()); + } else { + return new NioByteString(buffer); + } + } + + /** * Wraps the given bytes into a {@code ByteString}. Intended for internal only * usage to force a classload of ByteString before LiteralByteString. */ @@ -553,7 +557,9 @@ public abstract class ByteString implements Iterable<Byte>, Serializable { // Create a balanced concatenation of the next "length" elements from the // iterable. private static ByteString balancedConcat(Iterator<ByteString> iterator, int length) { - assert length >= 1; + if (length < 1) { + throw new IllegalArgumentException(String.format("length (%s) must be >= 1", length)); + } ByteString result; if (length == 1) { result = iterator.next(); @@ -679,6 +685,7 @@ public abstract class ByteString implements Iterable<Byte>, Serializable { */ abstract void writeTo(ByteOutput byteOutput) throws IOException; + /** * Constructs a read-only {@code java.nio.ByteBuffer} whose content * is equal to the contents of this byte string. @@ -820,6 +827,7 @@ public abstract class ByteString implements Iterable<Byte>, Serializable { return true; } + /** * Check equality of the substring of given length of this object starting at * zero with another {@code ByteString} substring starting at offset. diff --git a/java/core/src/main/java/com/google/protobuf/CodedInputStream.java b/java/core/src/main/java/com/google/protobuf/CodedInputStream.java index e8860651..1297462e 100644 --- a/java/core/src/main/java/com/google/protobuf/CodedInputStream.java +++ b/java/core/src/main/java/com/google/protobuf/CodedInputStream.java @@ -30,62 +30,119 @@ package com.google.protobuf; +import static com.google.protobuf.Internal.EMPTY_BYTE_ARRAY; +import static com.google.protobuf.Internal.EMPTY_BYTE_BUFFER; +import static com.google.protobuf.Internal.UTF_8; +import static com.google.protobuf.Internal.checkNotNull; +import static com.google.protobuf.WireFormat.FIXED32_SIZE; +import static com.google.protobuf.WireFormat.FIXED64_SIZE; +import static com.google.protobuf.WireFormat.MAX_VARINT_SIZE; + import java.io.ByteArrayOutputStream; import java.io.IOException; import java.io.InputStream; import java.nio.ByteBuffer; import java.util.ArrayList; import java.util.Arrays; +import java.util.Iterator; import java.util.List; /** * Reads and decodes protocol message fields. * - * This class contains two kinds of methods: methods that read specific - * protocol message constructs and field types (e.g. {@link #readTag()} and - * {@link #readInt32()}) and methods that read low-level values (e.g. - * {@link #readRawVarint32()} and {@link #readRawBytes}). If you are reading - * encoded protocol messages, you should use the former methods, but if you are - * reading some other format of your own design, use the latter. + * <p>This class contains two kinds of methods: methods that read specific protocol message + * constructs and field types (e.g. {@link #readTag()} and {@link #readInt32()}) and methods that + * read low-level values (e.g. {@link #readRawVarint32()} and {@link #readRawBytes}). If you are + * reading encoded protocol messages, you should use the former methods, but if you are reading some + * other format of your own design, use the latter. * * @author kenton@google.com Kenton Varda */ -public final class CodedInputStream { +public abstract class CodedInputStream { + private static final int DEFAULT_BUFFER_SIZE = 4096; + private static final int DEFAULT_RECURSION_LIMIT = 100; + // Integer.MAX_VALUE == 0x7FFFFFF == INT_MAX from limits.h + private static final int DEFAULT_SIZE_LIMIT = Integer.MAX_VALUE; + /** - * Create a new CodedInputStream wrapping the given InputStream. + * Whether to enable our custom UTF-8 decode codepath which does not use {@link StringCoding}. + * Currently disabled. */ + private static final boolean ENABLE_CUSTOM_UTF8_DECODE = false; + + /** Visible for subclasses. See setRecursionLimit() */ + int recursionDepth; + + int recursionLimit = DEFAULT_RECURSION_LIMIT; + + /** Visible for subclasses. See setSizeLimit() */ + int sizeLimit = DEFAULT_SIZE_LIMIT; + + /** Create a new CodedInputStream wrapping the given InputStream. */ public static CodedInputStream newInstance(final InputStream input) { - return new CodedInputStream(input, BUFFER_SIZE); + return newInstance(input, DEFAULT_BUFFER_SIZE); } - - /** - * Create a new CodedInputStream wrapping the given InputStream. - */ + + /** Create a new CodedInputStream wrapping the given InputStream. */ static CodedInputStream newInstance(final InputStream input, int bufferSize) { - return new CodedInputStream(input, bufferSize); + if (input == null) { + // TODO(nathanmittler): Ideally we should throw here. This is done for backward compatibility. + return newInstance(EMPTY_BYTE_ARRAY); + } + return new StreamDecoder(input, bufferSize); } - /** - * Create a new CodedInputStream wrapping the given byte array. - */ + /** Create a new CodedInputStream wrapping the given {@code Iterable <ByteBuffer>}. */ + public static CodedInputStream newInstance(final Iterable<ByteBuffer> input) { + if (!UnsafeDirectNioDecoder.isSupported()) { + return newInstance(new IterableByteBufferInputStream(input)); + } + return newInstance(input, false); + } + + /** Create a new CodedInputStream wrapping the given {@code Iterable <ByteBuffer>}. */ + static CodedInputStream newInstance( + final Iterable<ByteBuffer> bufs, final boolean bufferIsImmutable) { + // flag is to check the type of input's ByteBuffers. + // flag equals 1: all ByteBuffers have array. + // flag equals 2: all ByteBuffers are direct ByteBuffers. + // flag equals 3: some ByteBuffers are direct and some have array. + // flag greater than 3: other cases. + int flag = 0; + // Total size of the input + int totalSize = 0; + for (ByteBuffer buf : bufs) { + totalSize += buf.remaining(); + if (buf.hasArray()) { + flag |= 1; + } else if (buf.isDirect()) { + flag |= 2; + } else { + flag |= 4; + } + } + if (flag == 2) { + return new IterableDirectByteBufferDecoder(bufs, totalSize, bufferIsImmutable); + } else { + // TODO(yilunchong): add another decoders to deal case 1 and 3. + return newInstance(new IterableByteBufferInputStream(bufs)); + } + } + + /** Create a new CodedInputStream wrapping the given byte array. */ public static CodedInputStream newInstance(final byte[] buf) { return newInstance(buf, 0, buf.length); } - /** - * Create a new CodedInputStream wrapping the given byte array slice. - */ - public static CodedInputStream newInstance(final byte[] buf, final int off, - final int len) { + /** Create a new CodedInputStream wrapping the given byte array slice. */ + public static CodedInputStream newInstance(final byte[] buf, final int off, final int len) { return newInstance(buf, off, len, false /* bufferIsImmutable */); } - - /** - * Create a new CodedInputStream wrapping the given byte array slice. - */ + + /** Create a new CodedInputStream wrapping the given byte array slice. */ static CodedInputStream newInstance( final byte[] buf, final int off, final int len, final boolean bufferIsImmutable) { - CodedInputStream result = new CodedInputStream(buf, off, len, bufferIsImmutable); + ArrayDecoder result = new ArrayDecoder(buf, off, len, bufferIsImmutable); try { // Some uses of CodedInputStream can be more efficient if they know // exactly how many bytes are available. By pushing the end point of the @@ -107,583 +164,415 @@ public final class CodedInputStream { } /** - * Create a new CodedInputStream wrapping the given ByteBuffer. The data - * starting from the ByteBuffer's current position to its limit will be read. - * The returned CodedInputStream may or may not share the underlying data - * in the ByteBuffer, therefore the ByteBuffer cannot be changed while the - * CodedInputStream is in use. - * Note that the ByteBuffer's position won't be changed by this function. - * Concurrent calls with the same ByteBuffer object are safe if no other - * thread is trying to alter the ByteBuffer's status. + * Create a new CodedInputStream wrapping the given ByteBuffer. The data starting from the + * ByteBuffer's current position to its limit will be read. The returned CodedInputStream may or + * may not share the underlying data in the ByteBuffer, therefore the ByteBuffer cannot be changed + * while the CodedInputStream is in use. Note that the ByteBuffer's position won't be changed by + * this function. Concurrent calls with the same ByteBuffer object are safe if no other thread is + * trying to alter the ByteBuffer's status. */ public static CodedInputStream newInstance(ByteBuffer buf) { + return newInstance(buf, false /* bufferIsImmutable */); + } + + /** Create a new CodedInputStream wrapping the given buffer. */ + static CodedInputStream newInstance(ByteBuffer buf, boolean bufferIsImmutable) { if (buf.hasArray()) { - return newInstance(buf.array(), buf.arrayOffset() + buf.position(), - buf.remaining()); - } else { - ByteBuffer temp = buf.duplicate(); - byte[] buffer = new byte[temp.remaining()]; - temp.get(buffer); - return newInstance(buffer); + return newInstance( + buf.array(), buf.arrayOffset() + buf.position(), buf.remaining(), bufferIsImmutable); } + + if (buf.isDirect() && UnsafeDirectNioDecoder.isSupported()) { + return new UnsafeDirectNioDecoder(buf, bufferIsImmutable); + } + + // The buffer is non-direct and does not expose the underlying array. Using the ByteBuffer API + // to access individual bytes is very slow, so just copy the buffer to an array. + // TODO(nathanmittler): Re-evaluate with Java 9 + byte[] buffer = new byte[buf.remaining()]; + buf.duplicate().get(buffer); + return newInstance(buffer, 0, buffer.length, true); } + /** Disable construction/inheritance outside of this class. */ + private CodedInputStream() {} + // ----------------------------------------------------------------- /** - * Attempt to read a field tag, returning zero if we have reached EOF. - * Protocol message parsers use this to read tags, since a protocol message - * may legally end wherever a tag occurs, and zero is not a valid tag number. + * Attempt to read a field tag, returning zero if we have reached EOF. Protocol message parsers + * use this to read tags, since a protocol message may legally end wherever a tag occurs, and zero + * is not a valid tag number. */ - public int readTag() throws IOException { - if (isAtEnd()) { - lastTag = 0; - return 0; - } - - lastTag = readRawVarint32(); - if (WireFormat.getTagFieldNumber(lastTag) == 0) { - // If we actually read zero (or any tag number corresponding to field - // number zero), that's not a valid tag. - throw InvalidProtocolBufferException.invalidTag(); - } - return lastTag; - } + public abstract int readTag() throws IOException; /** - * Verifies that the last call to readTag() returned the given tag value. - * This is used to verify that a nested group ended with the correct - * end tag. + * Verifies that the last call to readTag() returned the given tag value. This is used to verify + * that a nested group ended with the correct end tag. * - * @throws InvalidProtocolBufferException {@code value} does not match the - * last tag. + * @throws InvalidProtocolBufferException {@code value} does not match the last tag. */ - public void checkLastTagWas(final int value) - throws InvalidProtocolBufferException { - if (lastTag != value) { - throw InvalidProtocolBufferException.invalidEndTag(); - } - } + public abstract void checkLastTagWas(final int value) throws InvalidProtocolBufferException; - public int getLastTag() { - return lastTag; - } + public abstract int getLastTag(); /** * Reads and discards a single field, given its tag value. * - * @return {@code false} if the tag is an endgroup tag, in which case - * nothing is skipped. Otherwise, returns {@code true}. + * @return {@code false} if the tag is an endgroup tag, in which case nothing is skipped. + * Otherwise, returns {@code true}. */ - public boolean skipField(final int tag) throws IOException { - switch (WireFormat.getTagWireType(tag)) { - case WireFormat.WIRETYPE_VARINT: - skipRawVarint(); - return true; - case WireFormat.WIRETYPE_FIXED64: - skipRawBytes(8); - return true; - case WireFormat.WIRETYPE_LENGTH_DELIMITED: - skipRawBytes(readRawVarint32()); - return true; - case WireFormat.WIRETYPE_START_GROUP: - skipMessage(); - checkLastTagWas( - WireFormat.makeTag(WireFormat.getTagFieldNumber(tag), - WireFormat.WIRETYPE_END_GROUP)); - return true; - case WireFormat.WIRETYPE_END_GROUP: - return false; - case WireFormat.WIRETYPE_FIXED32: - skipRawBytes(4); - return true; - default: - throw InvalidProtocolBufferException.invalidWireType(); - } - } + public abstract boolean skipField(final int tag) throws IOException; /** - * Reads a single field and writes it to output in wire format, - * given its tag value. + * Reads a single field and writes it to output in wire format, given its tag value. * - * @return {@code false} if the tag is an endgroup tag, in which case - * nothing is skipped. Otherwise, returns {@code true}. + * @return {@code false} if the tag is an endgroup tag, in which case nothing is skipped. + * Otherwise, returns {@code true}. + * @deprecated use {@code UnknownFieldSet} or {@code UnknownFieldSetLite} to skip to an output + * stream. */ - public boolean skipField(final int tag, final CodedOutputStream output) - throws IOException { - switch (WireFormat.getTagWireType(tag)) { - case WireFormat.WIRETYPE_VARINT: { - long value = readInt64(); - output.writeRawVarint32(tag); - output.writeUInt64NoTag(value); - return true; - } - case WireFormat.WIRETYPE_FIXED64: { - long value = readRawLittleEndian64(); - output.writeRawVarint32(tag); - output.writeFixed64NoTag(value); - return true; - } - case WireFormat.WIRETYPE_LENGTH_DELIMITED: { - ByteString value = readBytes(); - output.writeRawVarint32(tag); - output.writeBytesNoTag(value); - return true; - } - case WireFormat.WIRETYPE_START_GROUP: { - output.writeRawVarint32(tag); - skipMessage(output); - int endtag = WireFormat.makeTag(WireFormat.getTagFieldNumber(tag), - WireFormat.WIRETYPE_END_GROUP); - checkLastTagWas(endtag); - output.writeRawVarint32(endtag); - return true; - } - case WireFormat.WIRETYPE_END_GROUP: { - return false; - } - case WireFormat.WIRETYPE_FIXED32: { - int value = readRawLittleEndian32(); - output.writeRawVarint32(tag); - output.writeFixed32NoTag(value); - return true; - } - default: - throw InvalidProtocolBufferException.invalidWireType(); - } - } - - /** - * Reads and discards an entire message. This will read either until EOF - * or until an endgroup tag, whichever comes first. - */ - public void skipMessage() throws IOException { - while (true) { - final int tag = readTag(); - if (tag == 0 || !skipField(tag)) { - return; - } - } - } + @Deprecated + public abstract boolean skipField(final int tag, final CodedOutputStream output) + throws IOException; /** - * Reads an entire message and writes it to output in wire format. - * This will read either until EOF or until an endgroup tag, + * Reads and discards an entire message. This will read either until EOF or until an endgroup tag, * whichever comes first. */ - public void skipMessage(CodedOutputStream output) throws IOException { - while (true) { - final int tag = readTag(); - if (tag == 0 || !skipField(tag, output)) { - return; - } - } - } + public abstract void skipMessage() throws IOException; /** - * Collects the bytes skipped and returns the data in a ByteBuffer. + * Reads an entire message and writes it to output in wire format. This will read either until EOF + * or until an endgroup tag, whichever comes first. */ - private class SkippedDataSink implements RefillCallback { - private int lastPos = bufferPos; - private ByteArrayOutputStream byteArrayStream; - - @Override - public void onRefill() { - if (byteArrayStream == null) { - byteArrayStream = new ByteArrayOutputStream(); - } - byteArrayStream.write(buffer, lastPos, bufferPos - lastPos); - lastPos = 0; - } - - /** - * Gets skipped data in a ByteBuffer. This method should only be - * called once. - */ - ByteBuffer getSkippedData() { - if (byteArrayStream == null) { - return ByteBuffer.wrap(buffer, lastPos, bufferPos - lastPos); - } else { - byteArrayStream.write(buffer, lastPos, bufferPos); - return ByteBuffer.wrap(byteArrayStream.toByteArray()); - } - } - } + public abstract void skipMessage(CodedOutputStream output) throws IOException; // ----------------------------------------------------------------- /** Read a {@code double} field value from the stream. */ - public double readDouble() throws IOException { - return Double.longBitsToDouble(readRawLittleEndian64()); - } + public abstract double readDouble() throws IOException; /** Read a {@code float} field value from the stream. */ - public float readFloat() throws IOException { - return Float.intBitsToFloat(readRawLittleEndian32()); - } + public abstract float readFloat() throws IOException; /** Read a {@code uint64} field value from the stream. */ - public long readUInt64() throws IOException { - return readRawVarint64(); - } + public abstract long readUInt64() throws IOException; /** Read an {@code int64} field value from the stream. */ - public long readInt64() throws IOException { - return readRawVarint64(); - } + public abstract long readInt64() throws IOException; /** Read an {@code int32} field value from the stream. */ - public int readInt32() throws IOException { - return readRawVarint32(); - } + public abstract int readInt32() throws IOException; /** Read a {@code fixed64} field value from the stream. */ - public long readFixed64() throws IOException { - return readRawLittleEndian64(); - } + public abstract long readFixed64() throws IOException; /** Read a {@code fixed32} field value from the stream. */ - public int readFixed32() throws IOException { - return readRawLittleEndian32(); - } + public abstract int readFixed32() throws IOException; /** Read a {@code bool} field value from the stream. */ - public boolean readBool() throws IOException { - return readRawVarint64() != 0; - } + public abstract boolean readBool() throws IOException; /** - * Read a {@code string} field value from the stream. - * If the stream contains malformed UTF-8, + * Read a {@code string} field value from the stream. If the stream contains malformed UTF-8, * replace the offending bytes with the standard UTF-8 replacement character. */ - public String readString() throws IOException { - final int size = readRawVarint32(); - if (size <= (bufferSize - bufferPos) && size > 0) { - // Fast path: We already have the bytes in a contiguous buffer, so - // just copy directly from it. - final String result = new String(buffer, bufferPos, size, Internal.UTF_8); - bufferPos += size; - return result; - } else if (size == 0) { - return ""; - } else if (size <= bufferSize) { - refillBuffer(size); - String result = new String(buffer, bufferPos, size, Internal.UTF_8); - bufferPos += size; - return result; - } else { - // Slow path: Build a byte array first then copy it. - return new String(readRawBytesSlowPath(size), Internal.UTF_8); - } - } + public abstract String readString() throws IOException; /** - * Read a {@code string} field value from the stream. - * If the stream contains malformed UTF-8, + * Read a {@code string} field value from the stream. If the stream contains malformed UTF-8, * throw exception {@link InvalidProtocolBufferException}. */ - public String readStringRequireUtf8() throws IOException { - final int size = readRawVarint32(); - final byte[] bytes; - final int oldPos = bufferPos; - final int pos; - if (size <= (bufferSize - oldPos) && size > 0) { - // Fast path: We already have the bytes in a contiguous buffer, so - // just copy directly from it. - bytes = buffer; - bufferPos = oldPos + size; - pos = oldPos; - } else if (size == 0) { - return ""; - } else if (size <= bufferSize) { - refillBuffer(size); - bytes = buffer; - pos = 0; - bufferPos = pos + size; - } else { - // Slow path: Build a byte array first then copy it. - bytes = readRawBytesSlowPath(size); - pos = 0; - } - // TODO(martinrb): We could save a pass by validating while decoding. - if (!Utf8.isValidUtf8(bytes, pos, pos + size)) { - throw InvalidProtocolBufferException.invalidUtf8(); - } - return new String(bytes, pos, size, Internal.UTF_8); - } + public abstract String readStringRequireUtf8() throws IOException; /** Read a {@code group} field value from the stream. */ - public void readGroup(final int fieldNumber, - final MessageLite.Builder builder, - final ExtensionRegistryLite extensionRegistry) - throws IOException { - if (recursionDepth >= recursionLimit) { - throw InvalidProtocolBufferException.recursionLimitExceeded(); - } - ++recursionDepth; - builder.mergeFrom(this, extensionRegistry); - checkLastTagWas( - WireFormat.makeTag(fieldNumber, WireFormat.WIRETYPE_END_GROUP)); - --recursionDepth; - } + public abstract void readGroup( + final int fieldNumber, + final MessageLite.Builder builder, + final ExtensionRegistryLite extensionRegistry) + throws IOException; /** Read a {@code group} field value from the stream. */ - public <T extends MessageLite> T readGroup( - final int fieldNumber, - final Parser<T> parser, - final ExtensionRegistryLite extensionRegistry) - throws IOException { - if (recursionDepth >= recursionLimit) { - throw InvalidProtocolBufferException.recursionLimitExceeded(); - } - ++recursionDepth; - T result = parser.parsePartialFrom(this, extensionRegistry); - checkLastTagWas( - WireFormat.makeTag(fieldNumber, WireFormat.WIRETYPE_END_GROUP)); - --recursionDepth; - return result; - } + public abstract <T extends MessageLite> T readGroup( + final int fieldNumber, final Parser<T> parser, final ExtensionRegistryLite extensionRegistry) + throws IOException; /** - * Reads a {@code group} field value from the stream and merges it into the - * given {@link UnknownFieldSet}. + * Reads a {@code group} field value from the stream and merges it into the given {@link + * UnknownFieldSet}. * - * @deprecated UnknownFieldSet.Builder now implements MessageLite.Builder, so - * you can just call {@link #readGroup}. + * @deprecated UnknownFieldSet.Builder now implements MessageLite.Builder, so you can just call + * {@link #readGroup}. */ @Deprecated - public void readUnknownGroup(final int fieldNumber, - final MessageLite.Builder builder) - throws IOException { - // We know that UnknownFieldSet will ignore any ExtensionRegistry so it - // is safe to pass null here. (We can't call - // ExtensionRegistry.getEmptyRegistry() because that would make this - // class depend on ExtensionRegistry, which is not part of the lite - // library.) - readGroup(fieldNumber, builder, null); - } + public abstract void readUnknownGroup(final int fieldNumber, final MessageLite.Builder builder) + throws IOException; /** Read an embedded message field value from the stream. */ - public void readMessage(final MessageLite.Builder builder, - final ExtensionRegistryLite extensionRegistry) - throws IOException { - final int length = readRawVarint32(); - if (recursionDepth >= recursionLimit) { - throw InvalidProtocolBufferException.recursionLimitExceeded(); - } - final int oldLimit = pushLimit(length); - ++recursionDepth; - builder.mergeFrom(this, extensionRegistry); - checkLastTagWas(0); - --recursionDepth; - popLimit(oldLimit); - } + public abstract void readMessage( + final MessageLite.Builder builder, final ExtensionRegistryLite extensionRegistry) + throws IOException; /** Read an embedded message field value from the stream. */ - public <T extends MessageLite> T readMessage( - final Parser<T> parser, - final ExtensionRegistryLite extensionRegistry) - throws IOException { - int length = readRawVarint32(); - if (recursionDepth >= recursionLimit) { - throw InvalidProtocolBufferException.recursionLimitExceeded(); - } - final int oldLimit = pushLimit(length); - ++recursionDepth; - T result = parser.parsePartialFrom(this, extensionRegistry); - checkLastTagWas(0); - --recursionDepth; - popLimit(oldLimit); - return result; - } + public abstract <T extends MessageLite> T readMessage( + final Parser<T> parser, final ExtensionRegistryLite extensionRegistry) throws IOException; /** Read a {@code bytes} field value from the stream. */ - public ByteString readBytes() throws IOException { - final int size = readRawVarint32(); - if (size <= (bufferSize - bufferPos) && size > 0) { - // Fast path: We already have the bytes in a contiguous buffer, so - // just copy directly from it. - final ByteString result = bufferIsImmutable && enableAliasing - ? ByteString.wrap(buffer, bufferPos, size) - : ByteString.copyFrom(buffer, bufferPos, size); - bufferPos += size; - return result; - } else if (size == 0) { - return ByteString.EMPTY; - } else { - // Slow path: Build a byte array first then copy it. - return ByteString.wrap(readRawBytesSlowPath(size)); - } - } + public abstract ByteString readBytes() throws IOException; /** Read a {@code bytes} field value from the stream. */ - public byte[] readByteArray() throws IOException { - final int size = readRawVarint32(); - if (size <= (bufferSize - bufferPos) && size > 0) { - // Fast path: We already have the bytes in a contiguous buffer, so - // just copy directly from it. - final byte[] result = - Arrays.copyOfRange(buffer, bufferPos, bufferPos + size); - bufferPos += size; - return result; - } else { - // Slow path: Build a byte array first then copy it. - return readRawBytesSlowPath(size); - } - } + public abstract byte[] readByteArray() throws IOException; /** Read a {@code bytes} field value from the stream. */ - public ByteBuffer readByteBuffer() throws IOException { - final int size = readRawVarint32(); - if (size <= (bufferSize - bufferPos) && size > 0) { - // Fast path: We already have the bytes in a contiguous buffer. - // When aliasing is enabled, we can return a ByteBuffer pointing directly - // into the underlying byte array without copy if the CodedInputStream is - // constructed from a byte array. If aliasing is disabled or the input is - // from an InputStream or ByteString, we have to make a copy of the bytes. - ByteBuffer result = input == null && !bufferIsImmutable && enableAliasing - ? ByteBuffer.wrap(buffer, bufferPos, size).slice() - : ByteBuffer.wrap(Arrays.copyOfRange( - buffer, bufferPos, bufferPos + size)); - bufferPos += size; - return result; - } else if (size == 0) { - return Internal.EMPTY_BYTE_BUFFER; - } else { - // Slow path: Build a byte array first then copy it. - return ByteBuffer.wrap(readRawBytesSlowPath(size)); - } - } + public abstract ByteBuffer readByteBuffer() throws IOException; /** Read a {@code uint32} field value from the stream. */ - public int readUInt32() throws IOException { - return readRawVarint32(); - } + public abstract int readUInt32() throws IOException; /** - * Read an enum field value from the stream. Caller is responsible - * for converting the numeric value to an actual enum. + * Read an enum field value from the stream. Caller is responsible for converting the numeric + * value to an actual enum. */ - public int readEnum() throws IOException { - return readRawVarint32(); - } + public abstract int readEnum() throws IOException; /** Read an {@code sfixed32} field value from the stream. */ - public int readSFixed32() throws IOException { - return readRawLittleEndian32(); - } + public abstract int readSFixed32() throws IOException; /** Read an {@code sfixed64} field value from the stream. */ - public long readSFixed64() throws IOException { - return readRawLittleEndian64(); - } + public abstract long readSFixed64() throws IOException; /** Read an {@code sint32} field value from the stream. */ - public int readSInt32() throws IOException { - return decodeZigZag32(readRawVarint32()); - } + public abstract int readSInt32() throws IOException; /** Read an {@code sint64} field value from the stream. */ - public long readSInt64() throws IOException { - return decodeZigZag64(readRawVarint64()); - } + public abstract long readSInt64() throws IOException; // ================================================================= + /** Read a raw Varint from the stream. If larger than 32 bits, discard the upper bits. */ + public abstract int readRawVarint32() throws IOException; + + /** Read a raw Varint from the stream. */ + public abstract long readRawVarint64() throws IOException; + + /** Variant of readRawVarint64 for when uncomfortably close to the limit. */ + /* Visible for testing */ + abstract long readRawVarint64SlowPath() throws IOException; + + /** Read a 32-bit little-endian integer from the stream. */ + public abstract int readRawLittleEndian32() throws IOException; + + /** Read a 64-bit little-endian integer from the stream. */ + public abstract long readRawLittleEndian64() throws IOException; + + // ----------------------------------------------------------------- + /** - * Read a raw Varint from the stream. If larger than 32 bits, discard the - * upper bits. + * Enables {@link ByteString} aliasing of the underlying buffer, trading off on buffer pinning for + * data copies. Only valid for buffer-backed streams. */ - public int readRawVarint32() throws IOException { - // See implementation notes for readRawVarint64 - fastpath: { - int pos = bufferPos; + public abstract void enableAliasing(boolean enabled); - if (bufferSize == pos) { - break fastpath; - } - - final byte[] buffer = this.buffer; - int x; - if ((x = buffer[pos++]) >= 0) { - bufferPos = pos; - return x; - } else if (bufferSize - pos < 9) { - break fastpath; - } else if ((x ^= (buffer[pos++] << 7)) < 0) { - x ^= (~0 << 7); - } else if ((x ^= (buffer[pos++] << 14)) >= 0) { - x ^= (~0 << 7) ^ (~0 << 14); - } else if ((x ^= (buffer[pos++] << 21)) < 0) { - x ^= (~0 << 7) ^ (~0 << 14) ^ (~0 << 21); - } else { - int y = buffer[pos++]; - x ^= y << 28; - x ^= (~0 << 7) ^ (~0 << 14) ^ (~0 << 21) ^ (~0 << 28); - if (y < 0 && - buffer[pos++] < 0 && - buffer[pos++] < 0 && - buffer[pos++] < 0 && - buffer[pos++] < 0 && - buffer[pos++] < 0) { - break fastpath; // Will throw malformedVarint() - } - } - bufferPos = pos; - return x; + /** + * Set the maximum message recursion depth. In order to prevent malicious messages from causing + * stack overflows, {@code CodedInputStream} limits how deeply messages may be nested. The default + * limit is 64. + * + * @return the old limit. + */ + public final int setRecursionLimit(final int limit) { + if (limit < 0) { + throw new IllegalArgumentException("Recursion limit cannot be negative: " + limit); } - return (int) readRawVarint64SlowPath(); + final int oldLimit = recursionLimit; + recursionLimit = limit; + return oldLimit; } - private void skipRawVarint() throws IOException { - if (bufferSize - bufferPos >= 10) { - final byte[] buffer = this.buffer; - int pos = bufferPos; - for (int i = 0; i < 10; i++) { - if (buffer[pos++] >= 0) { - bufferPos = pos; - return; - } - } + /** + * Only valid for {@link InputStream}-backed streams. + * + * <p>Set the maximum message size. In order to prevent malicious messages from exhausting memory + * or causing integer overflows, {@code CodedInputStream} limits how large a message may be. The + * default limit is {@code Integer.MAX_INT}. You should set this limit as small as you can without + * harming your app's functionality. Note that size limits only apply when reading from an {@code + * InputStream}, not when constructed around a raw byte array. + * + * <p>If you want to read several messages from a single CodedInputStream, you could call {@link + * #resetSizeCounter()} after each one to avoid hitting the size limit. + * + * @return the old limit. + */ + public final int setSizeLimit(final int limit) { + if (limit < 0) { + throw new IllegalArgumentException("Size limit cannot be negative: " + limit); } - skipRawVarintSlowPath(); + final int oldLimit = sizeLimit; + sizeLimit = limit; + return oldLimit; } - private void skipRawVarintSlowPath() throws IOException { - for (int i = 0; i < 10; i++) { - if (readRawByte() >= 0) { - return; - } - } - throw InvalidProtocolBufferException.malformedVarint(); + + private boolean explicitDiscardUnknownFields = false; + + private static volatile boolean proto3DiscardUnknownFieldsDefault = false; + + static void setProto3DiscardUnknownsByDefaultForTest() { + proto3DiscardUnknownFieldsDefault = true; + } + + static void setProto3KeepUnknownsByDefaultForTest() { + proto3DiscardUnknownFieldsDefault = false; + } + + static boolean getProto3DiscardUnknownFieldsDefault() { + return proto3DiscardUnknownFieldsDefault; } /** - * Reads a varint from the input one byte at a time, so that it does not - * read any bytes after the end of the varint. If you simply wrapped the - * stream in a CodedInputStream and used {@link #readRawVarint32(InputStream)} - * then you would probably end up reading past the end of the varint since - * CodedInputStream buffers its input. + * Sets this {@code CodedInputStream} to discard unknown fields. Only applies to full runtime + * messages; lite messages will always preserve unknowns. + * + * <p>Note calling this function alone will have NO immediate effect on the underlying input data. + * The unknown fields will be discarded during parsing. This affects both Proto2 and Proto3 full + * runtime. */ - static int readRawVarint32(final InputStream input) throws IOException { - final int firstByte = input.read(); - if (firstByte == -1) { - throw InvalidProtocolBufferException.truncatedMessage(); - } - return readRawVarint32(firstByte, input); + final void discardUnknownFields() { + explicitDiscardUnknownFields = true; + } + + /** + * Reverts the unknown fields preservation behavior for Proto2 and Proto3 full runtime to their + * default. + */ + final void unsetDiscardUnknownFields() { + explicitDiscardUnknownFields = false; + } + + /** + * Whether unknown fields in this input stream should be discarded during parsing into full + * runtime messages. + */ + final boolean shouldDiscardUnknownFields() { + return explicitDiscardUnknownFields; + } + + /** + * Whether unknown fields in this input stream should be discarded during parsing for proto3 full + * runtime messages. + * + * <p>This function was temporarily introduced before proto3 unknown fields behavior is changed. + * TODO(liujisi): remove this and related code in GeneratedMessage after proto3 unknown + * fields migration is done. + */ + final boolean shouldDiscardUnknownFieldsProto3() { + return explicitDiscardUnknownFields ? true : proto3DiscardUnknownFieldsDefault; + } + + /** + * Resets the current size counter to zero (see {@link #setSizeLimit(int)}). Only valid for {@link + * InputStream}-backed streams. + */ + public abstract void resetSizeCounter(); + + /** + * Sets {@code currentLimit} to (current position) + {@code byteLimit}. This is called when + * descending into a length-delimited embedded message. + * + * <p>Note that {@code pushLimit()} does NOT affect how many bytes the {@code CodedInputStream} + * reads from an underlying {@code InputStream} when refreshing its buffer. If you need to prevent + * reading past a certain point in the underlying {@code InputStream} (e.g. because you expect it + * to contain more data after the end of the message which you need to handle differently) then + * you must place a wrapper around your {@code InputStream} which limits the amount of data that + * can be read from it. + * + * @return the old limit. + */ + public abstract int pushLimit(int byteLimit) throws InvalidProtocolBufferException; + + /** + * Discards the current limit, returning to the previous limit. + * + * @param oldLimit The old limit, as returned by {@code pushLimit}. + */ + public abstract void popLimit(final int oldLimit); + + /** + * Returns the number of bytes to be read before the current limit. If no limit is set, returns + * -1. + */ + public abstract int getBytesUntilLimit(); + + /** + * Returns true if the stream has reached the end of the input. This is the case if either the end + * of the underlying input source has been reached or if the stream has reached a limit created + * using {@link #pushLimit(int)}. + */ + public abstract boolean isAtEnd() throws IOException; + + /** + * The total bytes read up to the current position. Calling {@link #resetSizeCounter()} resets + * this value to zero. + */ + public abstract int getTotalBytesRead(); + + /** + * Read one byte from the input. + * + * @throws InvalidProtocolBufferException The end of the stream or the current limit was reached. + */ + public abstract byte readRawByte() throws IOException; + + /** + * Read a fixed size of bytes from the input. + * + * @throws InvalidProtocolBufferException The end of the stream or the current limit was reached. + */ + public abstract byte[] readRawBytes(final int size) throws IOException; + + /** + * Reads and discards {@code size} bytes. + * + * @throws InvalidProtocolBufferException The end of the stream or the current limit was reached. + */ + public abstract void skipRawBytes(final int size) throws IOException; + + /** + * Decode a ZigZag-encoded 32-bit value. ZigZag encodes signed integers into values that can be + * efficiently encoded with varint. (Otherwise, negative values must be sign-extended to 64 bits + * to be varint encoded, thus always taking 10 bytes on the wire.) + * + * @param n An unsigned 32-bit integer, stored in a signed int because Java has no explicit + * unsigned support. + * @return A signed 32-bit integer. + */ + public static int decodeZigZag32(final int n) { + return (n >>> 1) ^ -(n & 1); } /** - * Like {@link #readRawVarint32(InputStream)}, but expects that the caller - * has already read one byte. This allows the caller to determine if EOF - * has been reached before attempting to read. + * Decode a ZigZag-encoded 64-bit value. ZigZag encodes signed integers into values that can be + * efficiently encoded with varint. (Otherwise, negative values must be sign-extended to 64 bits + * to be varint encoded, thus always taking 10 bytes on the wire.) + * + * @param n An unsigned 64-bit integer, stored in a signed int because Java has no explicit + * unsigned support. + * @return A signed 64-bit integer. + */ + public static long decodeZigZag64(final long n) { + return (n >>> 1) ^ -(n & 1); + } + + /** + * Like {@link #readRawVarint32(InputStream)}, but expects that the caller has already read one + * byte. This allows the caller to determine if EOF has been reached before attempting to read. */ - public static int readRawVarint32( - final int firstByte, final InputStream input) throws IOException { + public static int readRawVarint32(final int firstByte, final InputStream input) + throws IOException { if ((firstByte & 0x80) == 0) { return firstByte; } @@ -713,590 +602,3341 @@ public final class CodedInputStream { throw InvalidProtocolBufferException.malformedVarint(); } - /** Read a raw Varint from the stream. */ - public long readRawVarint64() throws IOException { - // Implementation notes: - // - // Optimized for one-byte values, expected to be common. - // The particular code below was selected from various candidates - // empirically, by winning VarintBenchmark. - // - // Sign extension of (signed) Java bytes is usually a nuisance, but - // we exploit it here to more easily obtain the sign of bytes read. - // Instead of cleaning up the sign extension bits by masking eagerly, - // we delay until we find the final (positive) byte, when we clear all - // accumulated bits with one xor. We depend on javac to constant fold. - fastpath: { - int pos = bufferPos; - - if (bufferSize == pos) { - break fastpath; + /** + * Reads a varint from the input one byte at a time, so that it does not read any bytes after the + * end of the varint. If you simply wrapped the stream in a CodedInputStream and used {@link + * #readRawVarint32(InputStream)} then you would probably end up reading past the end of the + * varint since CodedInputStream buffers its input. + */ + static int readRawVarint32(final InputStream input) throws IOException { + final int firstByte = input.read(); + if (firstByte == -1) { + throw InvalidProtocolBufferException.truncatedMessage(); + } + return readRawVarint32(firstByte, input); + } + + /** A {@link CodedInputStream} implementation that uses a backing array as the input. */ + private static final class ArrayDecoder extends CodedInputStream { + private final byte[] buffer; + private final boolean immutable; + private int limit; + private int bufferSizeAfterLimit; + private int pos; + private int startPos; + private int lastTag; + private boolean enableAliasing; + + /** The absolute position of the end of the current message. */ + private int currentLimit = Integer.MAX_VALUE; + + private ArrayDecoder(final byte[] buffer, final int offset, final int len, boolean immutable) { + this.buffer = buffer; + limit = offset + len; + pos = offset; + startPos = pos; + this.immutable = immutable; + } + + @Override + public int readTag() throws IOException { + if (isAtEnd()) { + lastTag = 0; + return 0; } - final byte[] buffer = this.buffer; - long x; - int y; - if ((y = buffer[pos++]) >= 0) { - bufferPos = pos; - return y; - } else if (bufferSize - pos < 9) { - break fastpath; - } else if ((y ^= (buffer[pos++] << 7)) < 0) { - x = y ^ (~0 << 7); - } else if ((y ^= (buffer[pos++] << 14)) >= 0) { - x = y ^ ((~0 << 7) ^ (~0 << 14)); - } else if ((y ^= (buffer[pos++] << 21)) < 0) { - x = y ^ ((~0 << 7) ^ (~0 << 14) ^ (~0 << 21)); - } else if ((x = ((long) y) ^ ((long) buffer[pos++] << 28)) >= 0L) { - x ^= (~0L << 7) ^ (~0L << 14) ^ (~0L << 21) ^ (~0L << 28); - } else if ((x ^= ((long) buffer[pos++] << 35)) < 0L) { - x ^= (~0L << 7) ^ (~0L << 14) ^ (~0L << 21) ^ (~0L << 28) ^ (~0L << 35); - } else if ((x ^= ((long) buffer[pos++] << 42)) >= 0L) { - x ^= (~0L << 7) ^ (~0L << 14) ^ (~0L << 21) ^ (~0L << 28) ^ (~0L << 35) ^ (~0L << 42); - } else if ((x ^= ((long) buffer[pos++] << 49)) < 0L) { - x ^= (~0L << 7) ^ (~0L << 14) ^ (~0L << 21) ^ (~0L << 28) ^ (~0L << 35) ^ (~0L << 42) - ^ (~0L << 49); - } else { - x ^= ((long) buffer[pos++] << 56); - x ^= (~0L << 7) ^ (~0L << 14) ^ (~0L << 21) ^ (~0L << 28) ^ (~0L << 35) ^ (~0L << 42) - ^ (~0L << 49) ^ (~0L << 56); - if (x < 0L) { - if (buffer[pos++] < 0L) { - break fastpath; // Will throw malformedVarint() + lastTag = readRawVarint32(); + if (WireFormat.getTagFieldNumber(lastTag) == 0) { + // If we actually read zero (or any tag number corresponding to field + // number zero), that's not a valid tag. + throw InvalidProtocolBufferException.invalidTag(); + } + return lastTag; + } + + @Override + public void checkLastTagWas(final int value) throws InvalidProtocolBufferException { + if (lastTag != value) { + throw InvalidProtocolBufferException.invalidEndTag(); + } + } + + @Override + public int getLastTag() { + return lastTag; + } + + @Override + public boolean skipField(final int tag) throws IOException { + switch (WireFormat.getTagWireType(tag)) { + case WireFormat.WIRETYPE_VARINT: + skipRawVarint(); + return true; + case WireFormat.WIRETYPE_FIXED64: + skipRawBytes(FIXED64_SIZE); + return true; + case WireFormat.WIRETYPE_LENGTH_DELIMITED: + skipRawBytes(readRawVarint32()); + return true; + case WireFormat.WIRETYPE_START_GROUP: + skipMessage(); + checkLastTagWas( + WireFormat.makeTag(WireFormat.getTagFieldNumber(tag), WireFormat.WIRETYPE_END_GROUP)); + return true; + case WireFormat.WIRETYPE_END_GROUP: + return false; + case WireFormat.WIRETYPE_FIXED32: + skipRawBytes(FIXED32_SIZE); + return true; + default: + throw InvalidProtocolBufferException.invalidWireType(); + } + } + + @Override + public boolean skipField(final int tag, final CodedOutputStream output) throws IOException { + switch (WireFormat.getTagWireType(tag)) { + case WireFormat.WIRETYPE_VARINT: + { + long value = readInt64(); + output.writeRawVarint32(tag); + output.writeUInt64NoTag(value); + return true; + } + case WireFormat.WIRETYPE_FIXED64: + { + long value = readRawLittleEndian64(); + output.writeRawVarint32(tag); + output.writeFixed64NoTag(value); + return true; } + case WireFormat.WIRETYPE_LENGTH_DELIMITED: + { + ByteString value = readBytes(); + output.writeRawVarint32(tag); + output.writeBytesNoTag(value); + return true; + } + case WireFormat.WIRETYPE_START_GROUP: + { + output.writeRawVarint32(tag); + skipMessage(output); + int endtag = + WireFormat.makeTag( + WireFormat.getTagFieldNumber(tag), WireFormat.WIRETYPE_END_GROUP); + checkLastTagWas(endtag); + output.writeRawVarint32(endtag); + return true; + } + case WireFormat.WIRETYPE_END_GROUP: + { + return false; + } + case WireFormat.WIRETYPE_FIXED32: + { + int value = readRawLittleEndian32(); + output.writeRawVarint32(tag); + output.writeFixed32NoTag(value); + return true; + } + default: + throw InvalidProtocolBufferException.invalidWireType(); + } + } + + @Override + public void skipMessage() throws IOException { + while (true) { + final int tag = readTag(); + if (tag == 0 || !skipField(tag)) { + return; } } - bufferPos = pos; - return x; } - return readRawVarint64SlowPath(); - } - /** Variant of readRawVarint64 for when uncomfortably close to the limit. */ - /* Visible for testing */ - long readRawVarint64SlowPath() throws IOException { - long result = 0; - for (int shift = 0; shift < 64; shift += 7) { - final byte b = readRawByte(); - result |= (long) (b & 0x7F) << shift; - if ((b & 0x80) == 0) { + @Override + public void skipMessage(CodedOutputStream output) throws IOException { + while (true) { + final int tag = readTag(); + if (tag == 0 || !skipField(tag, output)) { + return; + } + } + } + + + // ----------------------------------------------------------------- + + @Override + public double readDouble() throws IOException { + return Double.longBitsToDouble(readRawLittleEndian64()); + } + + @Override + public float readFloat() throws IOException { + return Float.intBitsToFloat(readRawLittleEndian32()); + } + + @Override + public long readUInt64() throws IOException { + return readRawVarint64(); + } + + @Override + public long readInt64() throws IOException { + return readRawVarint64(); + } + + @Override + public int readInt32() throws IOException { + return readRawVarint32(); + } + + @Override + public long readFixed64() throws IOException { + return readRawLittleEndian64(); + } + + @Override + public int readFixed32() throws IOException { + return readRawLittleEndian32(); + } + + @Override + public boolean readBool() throws IOException { + return readRawVarint64() != 0; + } + + @Override + public String readString() throws IOException { + final int size = readRawVarint32(); + if (size > 0 && size <= (limit - pos)) { + // Fast path: We already have the bytes in a contiguous buffer, so + // just copy directly from it. + final String result = new String(buffer, pos, size, UTF_8); + pos += size; return result; } + + if (size == 0) { + return ""; + } + if (size < 0) { + throw InvalidProtocolBufferException.negativeSize(); + } + throw InvalidProtocolBufferException.truncatedMessage(); } - throw InvalidProtocolBufferException.malformedVarint(); - } - /** Read a 32-bit little-endian integer from the stream. */ - public int readRawLittleEndian32() throws IOException { - int pos = bufferPos; + @Override + public String readStringRequireUtf8() throws IOException { + final int size = readRawVarint32(); + if (size > 0 && size <= (limit - pos)) { + if (ENABLE_CUSTOM_UTF8_DECODE) { + String result = Utf8.decodeUtf8(buffer, pos, size); + pos += size; + return result; + } else { + // TODO(martinrb): We could save a pass by validating while decoding. + if (!Utf8.isValidUtf8(buffer, pos, pos + size)) { + throw InvalidProtocolBufferException.invalidUtf8(); + } + final int tempPos = pos; + pos += size; + return new String(buffer, tempPos, size, UTF_8); + } + } - // hand-inlined ensureAvailable(4); - if (bufferSize - pos < 4) { - refillBuffer(4); - pos = bufferPos; + if (size == 0) { + return ""; + } + if (size <= 0) { + throw InvalidProtocolBufferException.negativeSize(); + } + throw InvalidProtocolBufferException.truncatedMessage(); } - final byte[] buffer = this.buffer; - bufferPos = pos + 4; - return (((buffer[pos] & 0xff)) | - ((buffer[pos + 1] & 0xff) << 8) | - ((buffer[pos + 2] & 0xff) << 16) | - ((buffer[pos + 3] & 0xff) << 24)); - } + @Override + public void readGroup( + final int fieldNumber, + final MessageLite.Builder builder, + final ExtensionRegistryLite extensionRegistry) + throws IOException { + if (recursionDepth >= recursionLimit) { + throw InvalidProtocolBufferException.recursionLimitExceeded(); + } + ++recursionDepth; + builder.mergeFrom(this, extensionRegistry); + checkLastTagWas(WireFormat.makeTag(fieldNumber, WireFormat.WIRETYPE_END_GROUP)); + --recursionDepth; + } - /** Read a 64-bit little-endian integer from the stream. */ - public long readRawLittleEndian64() throws IOException { - int pos = bufferPos; - - // hand-inlined ensureAvailable(8); - if (bufferSize - pos < 8) { - refillBuffer(8); - pos = bufferPos; - } - - final byte[] buffer = this.buffer; - bufferPos = pos + 8; - return ((((long) buffer[pos] & 0xffL)) | - (((long) buffer[pos + 1] & 0xffL) << 8) | - (((long) buffer[pos + 2] & 0xffL) << 16) | - (((long) buffer[pos + 3] & 0xffL) << 24) | - (((long) buffer[pos + 4] & 0xffL) << 32) | - (((long) buffer[pos + 5] & 0xffL) << 40) | - (((long) buffer[pos + 6] & 0xffL) << 48) | - (((long) buffer[pos + 7] & 0xffL) << 56)); - } - /** - * Decode a ZigZag-encoded 32-bit value. ZigZag encodes signed integers - * into values that can be efficiently encoded with varint. (Otherwise, - * negative values must be sign-extended to 64 bits to be varint encoded, - * thus always taking 10 bytes on the wire.) - * - * @param n An unsigned 32-bit integer, stored in a signed int because - * Java has no explicit unsigned support. - * @return A signed 32-bit integer. - */ - public static int decodeZigZag32(final int n) { - return (n >>> 1) ^ -(n & 1); - } + @Override + public <T extends MessageLite> T readGroup( + final int fieldNumber, + final Parser<T> parser, + final ExtensionRegistryLite extensionRegistry) + throws IOException { + if (recursionDepth >= recursionLimit) { + throw InvalidProtocolBufferException.recursionLimitExceeded(); + } + ++recursionDepth; + T result = parser.parsePartialFrom(this, extensionRegistry); + checkLastTagWas(WireFormat.makeTag(fieldNumber, WireFormat.WIRETYPE_END_GROUP)); + --recursionDepth; + return result; + } - /** - * Decode a ZigZag-encoded 64-bit value. ZigZag encodes signed integers - * into values that can be efficiently encoded with varint. (Otherwise, - * negative values must be sign-extended to 64 bits to be varint encoded, - * thus always taking 10 bytes on the wire.) - * - * @param n An unsigned 64-bit integer, stored in a signed int because - * Java has no explicit unsigned support. - * @return A signed 64-bit integer. - */ - public static long decodeZigZag64(final long n) { - return (n >>> 1) ^ -(n & 1); - } + @Deprecated + @Override + public void readUnknownGroup(final int fieldNumber, final MessageLite.Builder builder) + throws IOException { + readGroup(fieldNumber, builder, ExtensionRegistryLite.getEmptyRegistry()); + } - // ----------------------------------------------------------------- + @Override + public void readMessage( + final MessageLite.Builder builder, final ExtensionRegistryLite extensionRegistry) + throws IOException { + final int length = readRawVarint32(); + if (recursionDepth >= recursionLimit) { + throw InvalidProtocolBufferException.recursionLimitExceeded(); + } + final int oldLimit = pushLimit(length); + ++recursionDepth; + builder.mergeFrom(this, extensionRegistry); + checkLastTagWas(0); + --recursionDepth; + popLimit(oldLimit); + } - private final byte[] buffer; - private final boolean bufferIsImmutable; - private int bufferSize; - private int bufferSizeAfterLimit; - private int bufferPos; - private final InputStream input; - private int lastTag; - private boolean enableAliasing = false; - /** - * The total number of bytes read before the current buffer. The total - * bytes read up to the current position can be computed as - * {@code totalBytesRetired + bufferPos}. This value may be negative if - * reading started in the middle of the current buffer (e.g. if the - * constructor that takes a byte array and an offset was used). - */ - private int totalBytesRetired; + @Override + public <T extends MessageLite> T readMessage( + final Parser<T> parser, final ExtensionRegistryLite extensionRegistry) throws IOException { + int length = readRawVarint32(); + if (recursionDepth >= recursionLimit) { + throw InvalidProtocolBufferException.recursionLimitExceeded(); + } + final int oldLimit = pushLimit(length); + ++recursionDepth; + T result = parser.parsePartialFrom(this, extensionRegistry); + checkLastTagWas(0); + --recursionDepth; + popLimit(oldLimit); + return result; + } - /** The absolute position of the end of the current message. */ - private int currentLimit = Integer.MAX_VALUE; + @Override + public ByteString readBytes() throws IOException { + final int size = readRawVarint32(); + if (size > 0 && size <= (limit - pos)) { + // Fast path: We already have the bytes in a contiguous buffer, so + // just copy directly from it. + final ByteString result = + immutable && enableAliasing + ? ByteString.wrap(buffer, pos, size) + : ByteString.copyFrom(buffer, pos, size); + pos += size; + return result; + } + if (size == 0) { + return ByteString.EMPTY; + } + // Slow path: Build a byte array first then copy it. + return ByteString.wrap(readRawBytes(size)); + } - /** See setRecursionLimit() */ - private int recursionDepth; - private int recursionLimit = DEFAULT_RECURSION_LIMIT; + @Override + public byte[] readByteArray() throws IOException { + final int size = readRawVarint32(); + return readRawBytes(size); + } - /** See setSizeLimit() */ - private int sizeLimit = DEFAULT_SIZE_LIMIT; + @Override + public ByteBuffer readByteBuffer() throws IOException { + final int size = readRawVarint32(); + if (size > 0 && size <= (limit - pos)) { + // Fast path: We already have the bytes in a contiguous buffer. + // When aliasing is enabled, we can return a ByteBuffer pointing directly + // into the underlying byte array without copy if the CodedInputStream is + // constructed from a byte array. If aliasing is disabled or the input is + // from an InputStream or ByteString, we have to make a copy of the bytes. + ByteBuffer result = + !immutable && enableAliasing + ? ByteBuffer.wrap(buffer, pos, size).slice() + : ByteBuffer.wrap(Arrays.copyOfRange(buffer, pos, pos + size)); + pos += size; + // TODO(nathanmittler): Investigate making the ByteBuffer be made read-only + return result; + } - private static final int DEFAULT_RECURSION_LIMIT = 100; - private static final int DEFAULT_SIZE_LIMIT = 64 << 20; // 64MB - private static final int BUFFER_SIZE = 4096; - - private CodedInputStream( - final byte[] buffer, final int off, final int len, boolean bufferIsImmutable) { - this.buffer = buffer; - bufferSize = off + len; - bufferPos = off; - totalBytesRetired = -off; - input = null; - this.bufferIsImmutable = bufferIsImmutable; - } + if (size == 0) { + return EMPTY_BYTE_BUFFER; + } + if (size < 0) { + throw InvalidProtocolBufferException.negativeSize(); + } + throw InvalidProtocolBufferException.truncatedMessage(); + } - private CodedInputStream(final InputStream input, int bufferSize) { - buffer = new byte[bufferSize]; - bufferSize = 0; - bufferPos = 0; - totalBytesRetired = 0; - this.input = input; - bufferIsImmutable = false; - } + @Override + public int readUInt32() throws IOException { + return readRawVarint32(); + } - public void enableAliasing(boolean enabled) { - this.enableAliasing = enabled; - } + @Override + public int readEnum() throws IOException { + return readRawVarint32(); + } - /** - * Set the maximum message recursion depth. In order to prevent malicious - * messages from causing stack overflows, {@code CodedInputStream} limits - * how deeply messages may be nested. The default limit is 64. - * - * @return the old limit. - */ - public int setRecursionLimit(final int limit) { - if (limit < 0) { - throw new IllegalArgumentException( - "Recursion limit cannot be negative: " + limit); + @Override + public int readSFixed32() throws IOException { + return readRawLittleEndian32(); } - final int oldLimit = recursionLimit; - recursionLimit = limit; - return oldLimit; - } - /** - * Set the maximum message size. In order to prevent malicious - * messages from exhausting memory or causing integer overflows, - * {@code CodedInputStream} limits how large a message may be. - * The default limit is 64MB. You should set this limit as small - * as you can without harming your app's functionality. Note that - * size limits only apply when reading from an {@code InputStream}, not - * when constructed around a raw byte array (nor with - * {@link ByteString#newCodedInput}). - * <p> - * If you want to read several messages from a single CodedInputStream, you - * could call {@link #resetSizeCounter()} after each one to avoid hitting the - * size limit. - * - * @return the old limit. - */ - public int setSizeLimit(final int limit) { - if (limit < 0) { - throw new IllegalArgumentException( - "Size limit cannot be negative: " + limit); + @Override + public long readSFixed64() throws IOException { + return readRawLittleEndian64(); } - final int oldLimit = sizeLimit; - sizeLimit = limit; - return oldLimit; - } - /** - * Resets the current size counter to zero (see {@link #setSizeLimit(int)}). - */ - public void resetSizeCounter() { - totalBytesRetired = -bufferPos; + @Override + public int readSInt32() throws IOException { + return decodeZigZag32(readRawVarint32()); + } + + @Override + public long readSInt64() throws IOException { + return decodeZigZag64(readRawVarint64()); + } + + // ================================================================= + + @Override + public int readRawVarint32() throws IOException { + // See implementation notes for readRawVarint64 + fastpath: + { + int tempPos = pos; + + if (limit == tempPos) { + break fastpath; + } + + final byte[] buffer = this.buffer; + int x; + if ((x = buffer[tempPos++]) >= 0) { + pos = tempPos; + return x; + } else if (limit - tempPos < 9) { + break fastpath; + } else if ((x ^= (buffer[tempPos++] << 7)) < 0) { + x ^= (~0 << 7); + } else if ((x ^= (buffer[tempPos++] << 14)) >= 0) { + x ^= (~0 << 7) ^ (~0 << 14); + } else if ((x ^= (buffer[tempPos++] << 21)) < 0) { + x ^= (~0 << 7) ^ (~0 << 14) ^ (~0 << 21); + } else { + int y = buffer[tempPos++]; + x ^= y << 28; + x ^= (~0 << 7) ^ (~0 << 14) ^ (~0 << 21) ^ (~0 << 28); + if (y < 0 + && buffer[tempPos++] < 0 + && buffer[tempPos++] < 0 + && buffer[tempPos++] < 0 + && buffer[tempPos++] < 0 + && buffer[tempPos++] < 0) { + break fastpath; // Will throw malformedVarint() + } + } + pos = tempPos; + return x; + } + return (int) readRawVarint64SlowPath(); + } + + private void skipRawVarint() throws IOException { + if (limit - pos >= MAX_VARINT_SIZE) { + skipRawVarintFastPath(); + } else { + skipRawVarintSlowPath(); + } + } + + private void skipRawVarintFastPath() throws IOException { + for (int i = 0; i < MAX_VARINT_SIZE; i++) { + if (buffer[pos++] >= 0) { + return; + } + } + throw InvalidProtocolBufferException.malformedVarint(); + } + + private void skipRawVarintSlowPath() throws IOException { + for (int i = 0; i < MAX_VARINT_SIZE; i++) { + if (readRawByte() >= 0) { + return; + } + } + throw InvalidProtocolBufferException.malformedVarint(); + } + + @Override + public long readRawVarint64() throws IOException { + // Implementation notes: + // + // Optimized for one-byte values, expected to be common. + // The particular code below was selected from various candidates + // empirically, by winning VarintBenchmark. + // + // Sign extension of (signed) Java bytes is usually a nuisance, but + // we exploit it here to more easily obtain the sign of bytes read. + // Instead of cleaning up the sign extension bits by masking eagerly, + // we delay until we find the final (positive) byte, when we clear all + // accumulated bits with one xor. We depend on javac to constant fold. + fastpath: + { + int tempPos = pos; + + if (limit == tempPos) { + break fastpath; + } + + final byte[] buffer = this.buffer; + long x; + int y; + if ((y = buffer[tempPos++]) >= 0) { + pos = tempPos; + return y; + } else if (limit - tempPos < 9) { + break fastpath; + } else if ((y ^= (buffer[tempPos++] << 7)) < 0) { + x = y ^ (~0 << 7); + } else if ((y ^= (buffer[tempPos++] << 14)) >= 0) { + x = y ^ ((~0 << 7) ^ (~0 << 14)); + } else if ((y ^= (buffer[tempPos++] << 21)) < 0) { + x = y ^ ((~0 << 7) ^ (~0 << 14) ^ (~0 << 21)); + } else if ((x = y ^ ((long) buffer[tempPos++] << 28)) >= 0L) { + x ^= (~0L << 7) ^ (~0L << 14) ^ (~0L << 21) ^ (~0L << 28); + } else if ((x ^= ((long) buffer[tempPos++] << 35)) < 0L) { + x ^= (~0L << 7) ^ (~0L << 14) ^ (~0L << 21) ^ (~0L << 28) ^ (~0L << 35); + } else if ((x ^= ((long) buffer[tempPos++] << 42)) >= 0L) { + x ^= (~0L << 7) ^ (~0L << 14) ^ (~0L << 21) ^ (~0L << 28) ^ (~0L << 35) ^ (~0L << 42); + } else if ((x ^= ((long) buffer[tempPos++] << 49)) < 0L) { + x ^= + (~0L << 7) + ^ (~0L << 14) + ^ (~0L << 21) + ^ (~0L << 28) + ^ (~0L << 35) + ^ (~0L << 42) + ^ (~0L << 49); + } else { + x ^= ((long) buffer[tempPos++] << 56); + x ^= + (~0L << 7) + ^ (~0L << 14) + ^ (~0L << 21) + ^ (~0L << 28) + ^ (~0L << 35) + ^ (~0L << 42) + ^ (~0L << 49) + ^ (~0L << 56); + if (x < 0L) { + if (buffer[tempPos++] < 0L) { + break fastpath; // Will throw malformedVarint() + } + } + } + pos = tempPos; + return x; + } + return readRawVarint64SlowPath(); + } + + @Override + long readRawVarint64SlowPath() throws IOException { + long result = 0; + for (int shift = 0; shift < 64; shift += 7) { + final byte b = readRawByte(); + result |= (long) (b & 0x7F) << shift; + if ((b & 0x80) == 0) { + return result; + } + } + throw InvalidProtocolBufferException.malformedVarint(); + } + + @Override + public int readRawLittleEndian32() throws IOException { + int tempPos = pos; + + if (limit - tempPos < FIXED32_SIZE) { + throw InvalidProtocolBufferException.truncatedMessage(); + } + + final byte[] buffer = this.buffer; + pos = tempPos + FIXED32_SIZE; + return (((buffer[tempPos] & 0xff)) + | ((buffer[tempPos + 1] & 0xff) << 8) + | ((buffer[tempPos + 2] & 0xff) << 16) + | ((buffer[tempPos + 3] & 0xff) << 24)); + } + + @Override + public long readRawLittleEndian64() throws IOException { + int tempPos = pos; + + if (limit - tempPos < FIXED64_SIZE) { + throw InvalidProtocolBufferException.truncatedMessage(); + } + + final byte[] buffer = this.buffer; + pos = tempPos + FIXED64_SIZE; + return (((buffer[tempPos] & 0xffL)) + | ((buffer[tempPos + 1] & 0xffL) << 8) + | ((buffer[tempPos + 2] & 0xffL) << 16) + | ((buffer[tempPos + 3] & 0xffL) << 24) + | ((buffer[tempPos + 4] & 0xffL) << 32) + | ((buffer[tempPos + 5] & 0xffL) << 40) + | ((buffer[tempPos + 6] & 0xffL) << 48) + | ((buffer[tempPos + 7] & 0xffL) << 56)); + } + + @Override + public void enableAliasing(boolean enabled) { + this.enableAliasing = enabled; + } + + @Override + public void resetSizeCounter() { + startPos = pos; + } + + @Override + public int pushLimit(int byteLimit) throws InvalidProtocolBufferException { + if (byteLimit < 0) { + throw InvalidProtocolBufferException.negativeSize(); + } + byteLimit += getTotalBytesRead(); + final int oldLimit = currentLimit; + if (byteLimit > oldLimit) { + throw InvalidProtocolBufferException.truncatedMessage(); + } + currentLimit = byteLimit; + + recomputeBufferSizeAfterLimit(); + + return oldLimit; + } + + private void recomputeBufferSizeAfterLimit() { + limit += bufferSizeAfterLimit; + final int bufferEnd = limit - startPos; + if (bufferEnd > currentLimit) { + // Limit is in current buffer. + bufferSizeAfterLimit = bufferEnd - currentLimit; + limit -= bufferSizeAfterLimit; + } else { + bufferSizeAfterLimit = 0; + } + } + + @Override + public void popLimit(final int oldLimit) { + currentLimit = oldLimit; + recomputeBufferSizeAfterLimit(); + } + + @Override + public int getBytesUntilLimit() { + if (currentLimit == Integer.MAX_VALUE) { + return -1; + } + + return currentLimit - getTotalBytesRead(); + } + + @Override + public boolean isAtEnd() throws IOException { + return pos == limit; + } + + @Override + public int getTotalBytesRead() { + return pos - startPos; + } + + @Override + public byte readRawByte() throws IOException { + if (pos == limit) { + throw InvalidProtocolBufferException.truncatedMessage(); + } + return buffer[pos++]; + } + + @Override + public byte[] readRawBytes(final int length) throws IOException { + if (length > 0 && length <= (limit - pos)) { + final int tempPos = pos; + pos += length; + return Arrays.copyOfRange(buffer, tempPos, pos); + } + + if (length <= 0) { + if (length == 0) { + return Internal.EMPTY_BYTE_ARRAY; + } else { + throw InvalidProtocolBufferException.negativeSize(); + } + } + throw InvalidProtocolBufferException.truncatedMessage(); + } + + @Override + public void skipRawBytes(final int length) throws IOException { + if (length >= 0 && length <= (limit - pos)) { + // We have all the bytes we need already. + pos += length; + return; + } + + if (length < 0) { + throw InvalidProtocolBufferException.negativeSize(); + } + throw InvalidProtocolBufferException.truncatedMessage(); + } } /** - * Sets {@code currentLimit} to (current position) + {@code byteLimit}. This - * is called when descending into a length-delimited embedded message. - * - * <p>Note that {@code pushLimit()} does NOT affect how many bytes the - * {@code CodedInputStream} reads from an underlying {@code InputStream} when - * refreshing its buffer. If you need to prevent reading past a certain - * point in the underlying {@code InputStream} (e.g. because you expect it to - * contain more data after the end of the message which you need to handle - * differently) then you must place a wrapper around your {@code InputStream} - * which limits the amount of data that can be read from it. - * - * @return the old limit. + * A {@link CodedInputStream} implementation that uses a backing direct ByteBuffer as the input. + * Requires the use of {@code sun.misc.Unsafe} to perform fast reads on the buffer. */ - public int pushLimit(int byteLimit) throws InvalidProtocolBufferException { - if (byteLimit < 0) { - throw InvalidProtocolBufferException.negativeSize(); + private static final class UnsafeDirectNioDecoder extends CodedInputStream { + /** The direct buffer that is backing this stream. */ + private final ByteBuffer buffer; + + /** + * If {@code true}, indicates that the buffer is backing a {@link ByteString} and is therefore + * considered to be an immutable input source. + */ + private final boolean immutable; + + /** The unsafe address of the content of {@link #buffer}. */ + private final long address; + + /** The unsafe address of the current read limit of the buffer. */ + private long limit; + + /** The unsafe address of the current read position of the buffer. */ + private long pos; + + /** The unsafe address of the starting read position. */ + private long startPos; + + /** The amount of available data in the buffer beyond {@link #limit}. */ + private int bufferSizeAfterLimit; + + /** The last tag that was read from this stream. */ + private int lastTag; + + /** + * If {@code true}, indicates that calls to read {@link ByteString} or {@code byte[]} + * <strong>may</strong> return slices of the underlying buffer, rather than copies. + */ + private boolean enableAliasing; + + /** The absolute position of the end of the current message. */ + private int currentLimit = Integer.MAX_VALUE; + + static boolean isSupported() { + return UnsafeUtil.hasUnsafeByteBufferOperations(); + } + + private UnsafeDirectNioDecoder(ByteBuffer buffer, boolean immutable) { + this.buffer = buffer; + address = UnsafeUtil.addressOffset(buffer); + limit = address + buffer.limit(); + pos = address + buffer.position(); + startPos = pos; + this.immutable = immutable; + } + + @Override + public int readTag() throws IOException { + if (isAtEnd()) { + lastTag = 0; + return 0; + } + + lastTag = readRawVarint32(); + if (WireFormat.getTagFieldNumber(lastTag) == 0) { + // If we actually read zero (or any tag number corresponding to field + // number zero), that's not a valid tag. + throw InvalidProtocolBufferException.invalidTag(); + } + return lastTag; + } + + @Override + public void checkLastTagWas(final int value) throws InvalidProtocolBufferException { + if (lastTag != value) { + throw InvalidProtocolBufferException.invalidEndTag(); + } + } + + @Override + public int getLastTag() { + return lastTag; + } + + @Override + public boolean skipField(final int tag) throws IOException { + switch (WireFormat.getTagWireType(tag)) { + case WireFormat.WIRETYPE_VARINT: + skipRawVarint(); + return true; + case WireFormat.WIRETYPE_FIXED64: + skipRawBytes(FIXED64_SIZE); + return true; + case WireFormat.WIRETYPE_LENGTH_DELIMITED: + skipRawBytes(readRawVarint32()); + return true; + case WireFormat.WIRETYPE_START_GROUP: + skipMessage(); + checkLastTagWas( + WireFormat.makeTag(WireFormat.getTagFieldNumber(tag), WireFormat.WIRETYPE_END_GROUP)); + return true; + case WireFormat.WIRETYPE_END_GROUP: + return false; + case WireFormat.WIRETYPE_FIXED32: + skipRawBytes(FIXED32_SIZE); + return true; + default: + throw InvalidProtocolBufferException.invalidWireType(); + } + } + + @Override + public boolean skipField(final int tag, final CodedOutputStream output) throws IOException { + switch (WireFormat.getTagWireType(tag)) { + case WireFormat.WIRETYPE_VARINT: + { + long value = readInt64(); + output.writeRawVarint32(tag); + output.writeUInt64NoTag(value); + return true; + } + case WireFormat.WIRETYPE_FIXED64: + { + long value = readRawLittleEndian64(); + output.writeRawVarint32(tag); + output.writeFixed64NoTag(value); + return true; + } + case WireFormat.WIRETYPE_LENGTH_DELIMITED: + { + ByteString value = readBytes(); + output.writeRawVarint32(tag); + output.writeBytesNoTag(value); + return true; + } + case WireFormat.WIRETYPE_START_GROUP: + { + output.writeRawVarint32(tag); + skipMessage(output); + int endtag = + WireFormat.makeTag( + WireFormat.getTagFieldNumber(tag), WireFormat.WIRETYPE_END_GROUP); + checkLastTagWas(endtag); + output.writeRawVarint32(endtag); + return true; + } + case WireFormat.WIRETYPE_END_GROUP: + { + return false; + } + case WireFormat.WIRETYPE_FIXED32: + { + int value = readRawLittleEndian32(); + output.writeRawVarint32(tag); + output.writeFixed32NoTag(value); + return true; + } + default: + throw InvalidProtocolBufferException.invalidWireType(); + } } - byteLimit += totalBytesRetired + bufferPos; - final int oldLimit = currentLimit; - if (byteLimit > oldLimit) { + + @Override + public void skipMessage() throws IOException { + while (true) { + final int tag = readTag(); + if (tag == 0 || !skipField(tag)) { + return; + } + } + } + + @Override + public void skipMessage(CodedOutputStream output) throws IOException { + while (true) { + final int tag = readTag(); + if (tag == 0 || !skipField(tag, output)) { + return; + } + } + } + + + // ----------------------------------------------------------------- + + @Override + public double readDouble() throws IOException { + return Double.longBitsToDouble(readRawLittleEndian64()); + } + + @Override + public float readFloat() throws IOException { + return Float.intBitsToFloat(readRawLittleEndian32()); + } + + @Override + public long readUInt64() throws IOException { + return readRawVarint64(); + } + + @Override + public long readInt64() throws IOException { + return readRawVarint64(); + } + + @Override + public int readInt32() throws IOException { + return readRawVarint32(); + } + + @Override + public long readFixed64() throws IOException { + return readRawLittleEndian64(); + } + + @Override + public int readFixed32() throws IOException { + return readRawLittleEndian32(); + } + + @Override + public boolean readBool() throws IOException { + return readRawVarint64() != 0; + } + + @Override + public String readString() throws IOException { + final int size = readRawVarint32(); + if (size > 0 && size <= remaining()) { + // TODO(nathanmittler): Is there a way to avoid this copy? + // TODO(anuraaga): It might be possible to share the optimized loop with + // readStringRequireUtf8 by implementing Java replacement logic there. + // The same as readBytes' logic + byte[] bytes = new byte[size]; + UnsafeUtil.copyMemory(pos, bytes, 0, size); + String result = new String(bytes, UTF_8); + pos += size; + return result; + } + + if (size == 0) { + return ""; + } + if (size < 0) { + throw InvalidProtocolBufferException.negativeSize(); + } throw InvalidProtocolBufferException.truncatedMessage(); } - currentLimit = byteLimit; - recomputeBufferSizeAfterLimit(); + @Override + public String readStringRequireUtf8() throws IOException { + final int size = readRawVarint32(); + if (size > 0 && size <= remaining()) { + if (ENABLE_CUSTOM_UTF8_DECODE) { + final int bufferPos = bufferPos(pos); + String result = Utf8.decodeUtf8(buffer, bufferPos, size); + pos += size; + return result; + } else { + // TODO(nathanmittler): Is there a way to avoid this copy? + // The same as readBytes' logic + byte[] bytes = new byte[size]; + UnsafeUtil.copyMemory(pos, bytes, 0, size); + // TODO(martinrb): We could save a pass by validating while decoding. + if (!Utf8.isValidUtf8(bytes)) { + throw InvalidProtocolBufferException.invalidUtf8(); + } - return oldLimit; - } + String result = new String(bytes, UTF_8); + pos += size; + return result; + } + } - private void recomputeBufferSizeAfterLimit() { - bufferSize += bufferSizeAfterLimit; - final int bufferEnd = totalBytesRetired + bufferSize; - if (bufferEnd > currentLimit) { - // Limit is in current buffer. - bufferSizeAfterLimit = bufferEnd - currentLimit; - bufferSize -= bufferSizeAfterLimit; - } else { - bufferSizeAfterLimit = 0; + if (size == 0) { + return ""; + } + if (size <= 0) { + throw InvalidProtocolBufferException.negativeSize(); + } + throw InvalidProtocolBufferException.truncatedMessage(); } - } - /** - * Discards the current limit, returning to the previous limit. - * - * @param oldLimit The old limit, as returned by {@code pushLimit}. - */ - public void popLimit(final int oldLimit) { - currentLimit = oldLimit; - recomputeBufferSizeAfterLimit(); - } + @Override + public void readGroup( + final int fieldNumber, + final MessageLite.Builder builder, + final ExtensionRegistryLite extensionRegistry) + throws IOException { + if (recursionDepth >= recursionLimit) { + throw InvalidProtocolBufferException.recursionLimitExceeded(); + } + ++recursionDepth; + builder.mergeFrom(this, extensionRegistry); + checkLastTagWas(WireFormat.makeTag(fieldNumber, WireFormat.WIRETYPE_END_GROUP)); + --recursionDepth; + } - /** - * Returns the number of bytes to be read before the current limit. - * If no limit is set, returns -1. - */ - public int getBytesUntilLimit() { - if (currentLimit == Integer.MAX_VALUE) { - return -1; + + @Override + public <T extends MessageLite> T readGroup( + final int fieldNumber, + final Parser<T> parser, + final ExtensionRegistryLite extensionRegistry) + throws IOException { + if (recursionDepth >= recursionLimit) { + throw InvalidProtocolBufferException.recursionLimitExceeded(); + } + ++recursionDepth; + T result = parser.parsePartialFrom(this, extensionRegistry); + checkLastTagWas(WireFormat.makeTag(fieldNumber, WireFormat.WIRETYPE_END_GROUP)); + --recursionDepth; + return result; } - final int currentAbsolutePosition = totalBytesRetired + bufferPos; - return currentLimit - currentAbsolutePosition; - } + @Deprecated + @Override + public void readUnknownGroup(final int fieldNumber, final MessageLite.Builder builder) + throws IOException { + readGroup(fieldNumber, builder, ExtensionRegistryLite.getEmptyRegistry()); + } - /** - * Returns true if the stream has reached the end of the input. This is the - * case if either the end of the underlying input source has been reached or - * if the stream has reached a limit created using {@link #pushLimit(int)}. - */ - public boolean isAtEnd() throws IOException { - return bufferPos == bufferSize && !tryRefillBuffer(1); - } + @Override + public void readMessage( + final MessageLite.Builder builder, final ExtensionRegistryLite extensionRegistry) + throws IOException { + final int length = readRawVarint32(); + if (recursionDepth >= recursionLimit) { + throw InvalidProtocolBufferException.recursionLimitExceeded(); + } + final int oldLimit = pushLimit(length); + ++recursionDepth; + builder.mergeFrom(this, extensionRegistry); + checkLastTagWas(0); + --recursionDepth; + popLimit(oldLimit); + } - /** - * The total bytes read up to the current position. Calling - * {@link #resetSizeCounter()} resets this value to zero. - */ - public int getTotalBytesRead() { - return totalBytesRetired + bufferPos; - } - private interface RefillCallback { - void onRefill(); - } + @Override + public <T extends MessageLite> T readMessage( + final Parser<T> parser, final ExtensionRegistryLite extensionRegistry) throws IOException { + int length = readRawVarint32(); + if (recursionDepth >= recursionLimit) { + throw InvalidProtocolBufferException.recursionLimitExceeded(); + } + final int oldLimit = pushLimit(length); + ++recursionDepth; + T result = parser.parsePartialFrom(this, extensionRegistry); + checkLastTagWas(0); + --recursionDepth; + popLimit(oldLimit); + return result; + } - private RefillCallback refillCallback = null; + @Override + public ByteString readBytes() throws IOException { + final int size = readRawVarint32(); + if (size > 0 && size <= remaining()) { + if (immutable && enableAliasing) { + final ByteBuffer result = slice(pos, pos + size); + pos += size; + return ByteString.wrap(result); + } else { + // Use UnsafeUtil to copy the memory to bytes instead of using ByteBuffer ways. + byte[] bytes = new byte[size]; + UnsafeUtil.copyMemory(pos, bytes, 0, size); + pos += size; + return ByteString.wrap(bytes); + } + } + + if (size == 0) { + return ByteString.EMPTY; + } + if (size < 0) { + throw InvalidProtocolBufferException.negativeSize(); + } + throw InvalidProtocolBufferException.truncatedMessage(); + } + + @Override + public byte[] readByteArray() throws IOException { + return readRawBytes(readRawVarint32()); + } + + @Override + public ByteBuffer readByteBuffer() throws IOException { + final int size = readRawVarint32(); + if (size > 0 && size <= remaining()) { + // "Immutable" implies that buffer is backing a ByteString. + // Disallow slicing in this case to prevent the caller from modifying the contents + // of the ByteString. + if (!immutable && enableAliasing) { + final ByteBuffer result = slice(pos, pos + size); + pos += size; + return result; + } else { + // The same as readBytes' logic + byte[] bytes = new byte[size]; + UnsafeUtil.copyMemory(pos, bytes, 0, size); + pos += size; + return ByteBuffer.wrap(bytes); + } + // TODO(nathanmittler): Investigate making the ByteBuffer be made read-only + } + + if (size == 0) { + return EMPTY_BYTE_BUFFER; + } + if (size < 0) { + throw InvalidProtocolBufferException.negativeSize(); + } + throw InvalidProtocolBufferException.truncatedMessage(); + } + + @Override + public int readUInt32() throws IOException { + return readRawVarint32(); + } + + @Override + public int readEnum() throws IOException { + return readRawVarint32(); + } + + @Override + public int readSFixed32() throws IOException { + return readRawLittleEndian32(); + } + + @Override + public long readSFixed64() throws IOException { + return readRawLittleEndian64(); + } + + @Override + public int readSInt32() throws IOException { + return decodeZigZag32(readRawVarint32()); + } + + @Override + public long readSInt64() throws IOException { + return decodeZigZag64(readRawVarint64()); + } + + // ================================================================= + + @Override + public int readRawVarint32() throws IOException { + // See implementation notes for readRawVarint64 + fastpath: + { + long tempPos = pos; + + if (limit == tempPos) { + break fastpath; + } + + int x; + if ((x = UnsafeUtil.getByte(tempPos++)) >= 0) { + pos = tempPos; + return x; + } else if (limit - tempPos < 9) { + break fastpath; + } else if ((x ^= (UnsafeUtil.getByte(tempPos++) << 7)) < 0) { + x ^= (~0 << 7); + } else if ((x ^= (UnsafeUtil.getByte(tempPos++) << 14)) >= 0) { + x ^= (~0 << 7) ^ (~0 << 14); + } else if ((x ^= (UnsafeUtil.getByte(tempPos++) << 21)) < 0) { + x ^= (~0 << 7) ^ (~0 << 14) ^ (~0 << 21); + } else { + int y = UnsafeUtil.getByte(tempPos++); + x ^= y << 28; + x ^= (~0 << 7) ^ (~0 << 14) ^ (~0 << 21) ^ (~0 << 28); + if (y < 0 + && UnsafeUtil.getByte(tempPos++) < 0 + && UnsafeUtil.getByte(tempPos++) < 0 + && UnsafeUtil.getByte(tempPos++) < 0 + && UnsafeUtil.getByte(tempPos++) < 0 + && UnsafeUtil.getByte(tempPos++) < 0) { + break fastpath; // Will throw malformedVarint() + } + } + pos = tempPos; + return x; + } + return (int) readRawVarint64SlowPath(); + } + + private void skipRawVarint() throws IOException { + if (remaining() >= MAX_VARINT_SIZE) { + skipRawVarintFastPath(); + } else { + skipRawVarintSlowPath(); + } + } + + private void skipRawVarintFastPath() throws IOException { + for (int i = 0; i < MAX_VARINT_SIZE; i++) { + if (UnsafeUtil.getByte(pos++) >= 0) { + return; + } + } + throw InvalidProtocolBufferException.malformedVarint(); + } + + private void skipRawVarintSlowPath() throws IOException { + for (int i = 0; i < MAX_VARINT_SIZE; i++) { + if (readRawByte() >= 0) { + return; + } + } + throw InvalidProtocolBufferException.malformedVarint(); + } + + @Override + public long readRawVarint64() throws IOException { + // Implementation notes: + // + // Optimized for one-byte values, expected to be common. + // The particular code below was selected from various candidates + // empirically, by winning VarintBenchmark. + // + // Sign extension of (signed) Java bytes is usually a nuisance, but + // we exploit it here to more easily obtain the sign of bytes read. + // Instead of cleaning up the sign extension bits by masking eagerly, + // we delay until we find the final (positive) byte, when we clear all + // accumulated bits with one xor. We depend on javac to constant fold. + fastpath: + { + long tempPos = pos; + + if (limit == tempPos) { + break fastpath; + } + + long x; + int y; + if ((y = UnsafeUtil.getByte(tempPos++)) >= 0) { + pos = tempPos; + return y; + } else if (limit - tempPos < 9) { + break fastpath; + } else if ((y ^= (UnsafeUtil.getByte(tempPos++) << 7)) < 0) { + x = y ^ (~0 << 7); + } else if ((y ^= (UnsafeUtil.getByte(tempPos++) << 14)) >= 0) { + x = y ^ ((~0 << 7) ^ (~0 << 14)); + } else if ((y ^= (UnsafeUtil.getByte(tempPos++) << 21)) < 0) { + x = y ^ ((~0 << 7) ^ (~0 << 14) ^ (~0 << 21)); + } else if ((x = y ^ ((long) UnsafeUtil.getByte(tempPos++) << 28)) >= 0L) { + x ^= (~0L << 7) ^ (~0L << 14) ^ (~0L << 21) ^ (~0L << 28); + } else if ((x ^= ((long) UnsafeUtil.getByte(tempPos++) << 35)) < 0L) { + x ^= (~0L << 7) ^ (~0L << 14) ^ (~0L << 21) ^ (~0L << 28) ^ (~0L << 35); + } else if ((x ^= ((long) UnsafeUtil.getByte(tempPos++) << 42)) >= 0L) { + x ^= (~0L << 7) ^ (~0L << 14) ^ (~0L << 21) ^ (~0L << 28) ^ (~0L << 35) ^ (~0L << 42); + } else if ((x ^= ((long) UnsafeUtil.getByte(tempPos++) << 49)) < 0L) { + x ^= + (~0L << 7) + ^ (~0L << 14) + ^ (~0L << 21) + ^ (~0L << 28) + ^ (~0L << 35) + ^ (~0L << 42) + ^ (~0L << 49); + } else { + x ^= ((long) UnsafeUtil.getByte(tempPos++) << 56); + x ^= + (~0L << 7) + ^ (~0L << 14) + ^ (~0L << 21) + ^ (~0L << 28) + ^ (~0L << 35) + ^ (~0L << 42) + ^ (~0L << 49) + ^ (~0L << 56); + if (x < 0L) { + if (UnsafeUtil.getByte(tempPos++) < 0L) { + break fastpath; // Will throw malformedVarint() + } + } + } + pos = tempPos; + return x; + } + return readRawVarint64SlowPath(); + } + + @Override + long readRawVarint64SlowPath() throws IOException { + long result = 0; + for (int shift = 0; shift < 64; shift += 7) { + final byte b = readRawByte(); + result |= (long) (b & 0x7F) << shift; + if ((b & 0x80) == 0) { + return result; + } + } + throw InvalidProtocolBufferException.malformedVarint(); + } + + @Override + public int readRawLittleEndian32() throws IOException { + long tempPos = pos; + + if (limit - tempPos < FIXED32_SIZE) { + throw InvalidProtocolBufferException.truncatedMessage(); + } + + pos = tempPos + FIXED32_SIZE; + return (((UnsafeUtil.getByte(tempPos) & 0xff)) + | ((UnsafeUtil.getByte(tempPos + 1) & 0xff) << 8) + | ((UnsafeUtil.getByte(tempPos + 2) & 0xff) << 16) + | ((UnsafeUtil.getByte(tempPos + 3) & 0xff) << 24)); + } + + @Override + public long readRawLittleEndian64() throws IOException { + long tempPos = pos; + + if (limit - tempPos < FIXED64_SIZE) { + throw InvalidProtocolBufferException.truncatedMessage(); + } + + pos = tempPos + FIXED64_SIZE; + return (((UnsafeUtil.getByte(tempPos) & 0xffL)) + | ((UnsafeUtil.getByte(tempPos + 1) & 0xffL) << 8) + | ((UnsafeUtil.getByte(tempPos + 2) & 0xffL) << 16) + | ((UnsafeUtil.getByte(tempPos + 3) & 0xffL) << 24) + | ((UnsafeUtil.getByte(tempPos + 4) & 0xffL) << 32) + | ((UnsafeUtil.getByte(tempPos + 5) & 0xffL) << 40) + | ((UnsafeUtil.getByte(tempPos + 6) & 0xffL) << 48) + | ((UnsafeUtil.getByte(tempPos + 7) & 0xffL) << 56)); + } + + @Override + public void enableAliasing(boolean enabled) { + this.enableAliasing = enabled; + } + + @Override + public void resetSizeCounter() { + startPos = pos; + } + + @Override + public int pushLimit(int byteLimit) throws InvalidProtocolBufferException { + if (byteLimit < 0) { + throw InvalidProtocolBufferException.negativeSize(); + } + byteLimit += getTotalBytesRead(); + final int oldLimit = currentLimit; + if (byteLimit > oldLimit) { + throw InvalidProtocolBufferException.truncatedMessage(); + } + currentLimit = byteLimit; + + recomputeBufferSizeAfterLimit(); + + return oldLimit; + } + + @Override + public void popLimit(final int oldLimit) { + currentLimit = oldLimit; + recomputeBufferSizeAfterLimit(); + } + + @Override + public int getBytesUntilLimit() { + if (currentLimit == Integer.MAX_VALUE) { + return -1; + } + + return currentLimit - getTotalBytesRead(); + } + + @Override + public boolean isAtEnd() throws IOException { + return pos == limit; + } + + @Override + public int getTotalBytesRead() { + return (int) (pos - startPos); + } + + @Override + public byte readRawByte() throws IOException { + if (pos == limit) { + throw InvalidProtocolBufferException.truncatedMessage(); + } + return UnsafeUtil.getByte(pos++); + } + + @Override + public byte[] readRawBytes(final int length) throws IOException { + if (length >= 0 && length <= remaining()) { + byte[] bytes = new byte[length]; + slice(pos, pos + length).get(bytes); + pos += length; + return bytes; + } + + if (length <= 0) { + if (length == 0) { + return EMPTY_BYTE_ARRAY; + } else { + throw InvalidProtocolBufferException.negativeSize(); + } + } - /** - * Reads more bytes from the input, making at least {@code n} bytes available - * in the buffer. Caller must ensure that the requested space is not yet - * available, and that the requested space is less than BUFFER_SIZE. - * - * @throws InvalidProtocolBufferException The end of the stream or the current - * limit was reached. - */ - private void refillBuffer(int n) throws IOException { - if (!tryRefillBuffer(n)) { throw InvalidProtocolBufferException.truncatedMessage(); } + + @Override + public void skipRawBytes(final int length) throws IOException { + if (length >= 0 && length <= remaining()) { + // We have all the bytes we need already. + pos += length; + return; + } + + if (length < 0) { + throw InvalidProtocolBufferException.negativeSize(); + } + throw InvalidProtocolBufferException.truncatedMessage(); + } + + private void recomputeBufferSizeAfterLimit() { + limit += bufferSizeAfterLimit; + final int bufferEnd = (int) (limit - startPos); + if (bufferEnd > currentLimit) { + // Limit is in current buffer. + bufferSizeAfterLimit = bufferEnd - currentLimit; + limit -= bufferSizeAfterLimit; + } else { + bufferSizeAfterLimit = 0; + } + } + + private int remaining() { + return (int) (limit - pos); + } + + private int bufferPos(long pos) { + return (int) (pos - address); + } + + private ByteBuffer slice(long begin, long end) throws IOException { + int prevPos = buffer.position(); + int prevLimit = buffer.limit(); + try { + buffer.position(bufferPos(begin)); + buffer.limit(bufferPos(end)); + return buffer.slice(); + } catch (IllegalArgumentException e) { + throw InvalidProtocolBufferException.truncatedMessage(); + } finally { + buffer.position(prevPos); + buffer.limit(prevLimit); + } + } } /** - * Tries to read more bytes from the input, making at least {@code n} bytes - * available in the buffer. Caller must ensure that the requested space is - * not yet available, and that the requested space is less than BUFFER_SIZE. - * - * @return {@code true} if the bytes could be made available; {@code false} - * if the end of the stream or the current limit was reached. + * Implementation of {@link CodedInputStream} that uses an {@link InputStream} as the data source. */ - private boolean tryRefillBuffer(int n) throws IOException { - if (bufferPos + n <= bufferSize) { - throw new IllegalStateException( - "refillBuffer() called when " + n + - " bytes were already available in buffer"); + private static final class StreamDecoder extends CodedInputStream { + private final InputStream input; + private final byte[] buffer; + /** bufferSize represents how many bytes are currently filled in the buffer */ + private int bufferSize; + + private int bufferSizeAfterLimit; + private int pos; + private int lastTag; + + /** + * The total number of bytes read before the current buffer. The total bytes read up to the + * current position can be computed as {@code totalBytesRetired + pos}. This value may be + * negative if reading started in the middle of the current buffer (e.g. if the constructor that + * takes a byte array and an offset was used). + */ + private int totalBytesRetired; + + /** The absolute position of the end of the current message. */ + private int currentLimit = Integer.MAX_VALUE; + + private StreamDecoder(final InputStream input, int bufferSize) { + checkNotNull(input, "input"); + this.input = input; + this.buffer = new byte[bufferSize]; + this.bufferSize = 0; + pos = 0; + totalBytesRetired = 0; } - if (totalBytesRetired + bufferPos + n > currentLimit) { - // Oops, we hit a limit. - return false; + @Override + public int readTag() throws IOException { + if (isAtEnd()) { + lastTag = 0; + return 0; + } + + lastTag = readRawVarint32(); + if (WireFormat.getTagFieldNumber(lastTag) == 0) { + // If we actually read zero (or any tag number corresponding to field + // number zero), that's not a valid tag. + throw InvalidProtocolBufferException.invalidTag(); + } + return lastTag; + } + + @Override + public void checkLastTagWas(final int value) throws InvalidProtocolBufferException { + if (lastTag != value) { + throw InvalidProtocolBufferException.invalidEndTag(); + } + } + + @Override + public int getLastTag() { + return lastTag; + } + + @Override + public boolean skipField(final int tag) throws IOException { + switch (WireFormat.getTagWireType(tag)) { + case WireFormat.WIRETYPE_VARINT: + skipRawVarint(); + return true; + case WireFormat.WIRETYPE_FIXED64: + skipRawBytes(FIXED64_SIZE); + return true; + case WireFormat.WIRETYPE_LENGTH_DELIMITED: + skipRawBytes(readRawVarint32()); + return true; + case WireFormat.WIRETYPE_START_GROUP: + skipMessage(); + checkLastTagWas( + WireFormat.makeTag(WireFormat.getTagFieldNumber(tag), WireFormat.WIRETYPE_END_GROUP)); + return true; + case WireFormat.WIRETYPE_END_GROUP: + return false; + case WireFormat.WIRETYPE_FIXED32: + skipRawBytes(FIXED32_SIZE); + return true; + default: + throw InvalidProtocolBufferException.invalidWireType(); + } + } + + @Override + public boolean skipField(final int tag, final CodedOutputStream output) throws IOException { + switch (WireFormat.getTagWireType(tag)) { + case WireFormat.WIRETYPE_VARINT: + { + long value = readInt64(); + output.writeRawVarint32(tag); + output.writeUInt64NoTag(value); + return true; + } + case WireFormat.WIRETYPE_FIXED64: + { + long value = readRawLittleEndian64(); + output.writeRawVarint32(tag); + output.writeFixed64NoTag(value); + return true; + } + case WireFormat.WIRETYPE_LENGTH_DELIMITED: + { + ByteString value = readBytes(); + output.writeRawVarint32(tag); + output.writeBytesNoTag(value); + return true; + } + case WireFormat.WIRETYPE_START_GROUP: + { + output.writeRawVarint32(tag); + skipMessage(output); + int endtag = + WireFormat.makeTag( + WireFormat.getTagFieldNumber(tag), WireFormat.WIRETYPE_END_GROUP); + checkLastTagWas(endtag); + output.writeRawVarint32(endtag); + return true; + } + case WireFormat.WIRETYPE_END_GROUP: + { + return false; + } + case WireFormat.WIRETYPE_FIXED32: + { + int value = readRawLittleEndian32(); + output.writeRawVarint32(tag); + output.writeFixed32NoTag(value); + return true; + } + default: + throw InvalidProtocolBufferException.invalidWireType(); + } + } + + @Override + public void skipMessage() throws IOException { + while (true) { + final int tag = readTag(); + if (tag == 0 || !skipField(tag)) { + return; + } + } + } + + @Override + public void skipMessage(CodedOutputStream output) throws IOException { + while (true) { + final int tag = readTag(); + if (tag == 0 || !skipField(tag, output)) { + return; + } + } + } + + /** Collects the bytes skipped and returns the data in a ByteBuffer. */ + private class SkippedDataSink implements RefillCallback { + private int lastPos = pos; + private ByteArrayOutputStream byteArrayStream; + + @Override + public void onRefill() { + if (byteArrayStream == null) { + byteArrayStream = new ByteArrayOutputStream(); + } + byteArrayStream.write(buffer, lastPos, pos - lastPos); + lastPos = 0; + } + + /** Gets skipped data in a ByteBuffer. This method should only be called once. */ + ByteBuffer getSkippedData() { + if (byteArrayStream == null) { + return ByteBuffer.wrap(buffer, lastPos, pos - lastPos); + } else { + byteArrayStream.write(buffer, lastPos, pos); + return ByteBuffer.wrap(byteArrayStream.toByteArray()); + } + } + } + + + // ----------------------------------------------------------------- + + @Override + public double readDouble() throws IOException { + return Double.longBitsToDouble(readRawLittleEndian64()); + } + + @Override + public float readFloat() throws IOException { + return Float.intBitsToFloat(readRawLittleEndian32()); + } + + @Override + public long readUInt64() throws IOException { + return readRawVarint64(); + } + + @Override + public long readInt64() throws IOException { + return readRawVarint64(); + } + + @Override + public int readInt32() throws IOException { + return readRawVarint32(); + } + + @Override + public long readFixed64() throws IOException { + return readRawLittleEndian64(); + } + + @Override + public int readFixed32() throws IOException { + return readRawLittleEndian32(); } - if (refillCallback != null) { - refillCallback.onRefill(); + @Override + public boolean readBool() throws IOException { + return readRawVarint64() != 0; + } + + @Override + public String readString() throws IOException { + final int size = readRawVarint32(); + if (size > 0 && size <= (bufferSize - pos)) { + // Fast path: We already have the bytes in a contiguous buffer, so + // just copy directly from it. + final String result = new String(buffer, pos, size, UTF_8); + pos += size; + return result; + } + if (size == 0) { + return ""; + } + if (size <= bufferSize) { + refillBuffer(size); + String result = new String(buffer, pos, size, UTF_8); + pos += size; + return result; + } + // Slow path: Build a byte array first then copy it. + return new String(readRawBytesSlowPath(size), UTF_8); + } + + @Override + public String readStringRequireUtf8() throws IOException { + final int size = readRawVarint32(); + final byte[] bytes; + final int oldPos = pos; + final int tempPos; + if (size <= (bufferSize - oldPos) && size > 0) { + // Fast path: We already have the bytes in a contiguous buffer, so + // just copy directly from it. + bytes = buffer; + pos = oldPos + size; + tempPos = oldPos; + } else if (size == 0) { + return ""; + } else if (size <= bufferSize) { + refillBuffer(size); + bytes = buffer; + tempPos = 0; + pos = tempPos + size; + } else { + // Slow path: Build a byte array first then copy it. + bytes = readRawBytesSlowPath(size); + tempPos = 0; + } + if (ENABLE_CUSTOM_UTF8_DECODE) { + return Utf8.decodeUtf8(bytes, tempPos, size); + } else { + // TODO(martinrb): We could save a pass by validating while decoding. + if (!Utf8.isValidUtf8(bytes, tempPos, tempPos + size)) { + throw InvalidProtocolBufferException.invalidUtf8(); + } + return new String(bytes, tempPos, size, UTF_8); + } + } + + @Override + public void readGroup( + final int fieldNumber, + final MessageLite.Builder builder, + final ExtensionRegistryLite extensionRegistry) + throws IOException { + if (recursionDepth >= recursionLimit) { + throw InvalidProtocolBufferException.recursionLimitExceeded(); + } + ++recursionDepth; + builder.mergeFrom(this, extensionRegistry); + checkLastTagWas(WireFormat.makeTag(fieldNumber, WireFormat.WIRETYPE_END_GROUP)); + --recursionDepth; + } + + + @Override + public <T extends MessageLite> T readGroup( + final int fieldNumber, + final Parser<T> parser, + final ExtensionRegistryLite extensionRegistry) + throws IOException { + if (recursionDepth >= recursionLimit) { + throw InvalidProtocolBufferException.recursionLimitExceeded(); + } + ++recursionDepth; + T result = parser.parsePartialFrom(this, extensionRegistry); + checkLastTagWas(WireFormat.makeTag(fieldNumber, WireFormat.WIRETYPE_END_GROUP)); + --recursionDepth; + return result; + } + + @Deprecated + @Override + public void readUnknownGroup(final int fieldNumber, final MessageLite.Builder builder) + throws IOException { + readGroup(fieldNumber, builder, ExtensionRegistryLite.getEmptyRegistry()); + } + + @Override + public void readMessage( + final MessageLite.Builder builder, final ExtensionRegistryLite extensionRegistry) + throws IOException { + final int length = readRawVarint32(); + if (recursionDepth >= recursionLimit) { + throw InvalidProtocolBufferException.recursionLimitExceeded(); + } + final int oldLimit = pushLimit(length); + ++recursionDepth; + builder.mergeFrom(this, extensionRegistry); + checkLastTagWas(0); + --recursionDepth; + popLimit(oldLimit); + } + + + @Override + public <T extends MessageLite> T readMessage( + final Parser<T> parser, final ExtensionRegistryLite extensionRegistry) throws IOException { + int length = readRawVarint32(); + if (recursionDepth >= recursionLimit) { + throw InvalidProtocolBufferException.recursionLimitExceeded(); + } + final int oldLimit = pushLimit(length); + ++recursionDepth; + T result = parser.parsePartialFrom(this, extensionRegistry); + checkLastTagWas(0); + --recursionDepth; + popLimit(oldLimit); + return result; + } + + @Override + public ByteString readBytes() throws IOException { + final int size = readRawVarint32(); + if (size <= (bufferSize - pos) && size > 0) { + // Fast path: We already have the bytes in a contiguous buffer, so + // just copy directly from it. + final ByteString result = ByteString.copyFrom(buffer, pos, size); + pos += size; + return result; + } + if (size == 0) { + return ByteString.EMPTY; + } + return readBytesSlowPath(size); + } + + @Override + public byte[] readByteArray() throws IOException { + final int size = readRawVarint32(); + if (size <= (bufferSize - pos) && size > 0) { + // Fast path: We already have the bytes in a contiguous buffer, so + // just copy directly from it. + final byte[] result = Arrays.copyOfRange(buffer, pos, pos + size); + pos += size; + return result; + } else { + // Slow path: Build a byte array first then copy it. + return readRawBytesSlowPath(size); + } + } + + @Override + public ByteBuffer readByteBuffer() throws IOException { + final int size = readRawVarint32(); + if (size <= (bufferSize - pos) && size > 0) { + // Fast path: We already have the bytes in a contiguous buffer. + ByteBuffer result = ByteBuffer.wrap(Arrays.copyOfRange(buffer, pos, pos + size)); + pos += size; + return result; + } + if (size == 0) { + return Internal.EMPTY_BYTE_BUFFER; + } + // Slow path: Build a byte array first then copy it. + return ByteBuffer.wrap(readRawBytesSlowPath(size)); + } + + @Override + public int readUInt32() throws IOException { + return readRawVarint32(); + } + + @Override + public int readEnum() throws IOException { + return readRawVarint32(); } - if (input != null) { - int pos = bufferPos; - if (pos > 0) { - if (bufferSize > pos) { - System.arraycopy(buffer, pos, buffer, 0, bufferSize - pos); + @Override + public int readSFixed32() throws IOException { + return readRawLittleEndian32(); + } + + @Override + public long readSFixed64() throws IOException { + return readRawLittleEndian64(); + } + + @Override + public int readSInt32() throws IOException { + return decodeZigZag32(readRawVarint32()); + } + + @Override + public long readSInt64() throws IOException { + return decodeZigZag64(readRawVarint64()); + } + + // ================================================================= + + @Override + public int readRawVarint32() throws IOException { + // See implementation notes for readRawVarint64 + fastpath: + { + int tempPos = pos; + + if (bufferSize == tempPos) { + break fastpath; + } + + final byte[] buffer = this.buffer; + int x; + if ((x = buffer[tempPos++]) >= 0) { + pos = tempPos; + return x; + } else if (bufferSize - tempPos < 9) { + break fastpath; + } else if ((x ^= (buffer[tempPos++] << 7)) < 0) { + x ^= (~0 << 7); + } else if ((x ^= (buffer[tempPos++] << 14)) >= 0) { + x ^= (~0 << 7) ^ (~0 << 14); + } else if ((x ^= (buffer[tempPos++] << 21)) < 0) { + x ^= (~0 << 7) ^ (~0 << 14) ^ (~0 << 21); + } else { + int y = buffer[tempPos++]; + x ^= y << 28; + x ^= (~0 << 7) ^ (~0 << 14) ^ (~0 << 21) ^ (~0 << 28); + if (y < 0 + && buffer[tempPos++] < 0 + && buffer[tempPos++] < 0 + && buffer[tempPos++] < 0 + && buffer[tempPos++] < 0 + && buffer[tempPos++] < 0) { + break fastpath; // Will throw malformedVarint() + } + } + pos = tempPos; + return x; + } + return (int) readRawVarint64SlowPath(); + } + + private void skipRawVarint() throws IOException { + if (bufferSize - pos >= MAX_VARINT_SIZE) { + skipRawVarintFastPath(); + } else { + skipRawVarintSlowPath(); + } + } + + private void skipRawVarintFastPath() throws IOException { + for (int i = 0; i < MAX_VARINT_SIZE; i++) { + if (buffer[pos++] >= 0) { + return; + } + } + throw InvalidProtocolBufferException.malformedVarint(); + } + + private void skipRawVarintSlowPath() throws IOException { + for (int i = 0; i < MAX_VARINT_SIZE; i++) { + if (readRawByte() >= 0) { + return; + } + } + throw InvalidProtocolBufferException.malformedVarint(); + } + + @Override + public long readRawVarint64() throws IOException { + // Implementation notes: + // + // Optimized for one-byte values, expected to be common. + // The particular code below was selected from various candidates + // empirically, by winning VarintBenchmark. + // + // Sign extension of (signed) Java bytes is usually a nuisance, but + // we exploit it here to more easily obtain the sign of bytes read. + // Instead of cleaning up the sign extension bits by masking eagerly, + // we delay until we find the final (positive) byte, when we clear all + // accumulated bits with one xor. We depend on javac to constant fold. + fastpath: + { + int tempPos = pos; + + if (bufferSize == tempPos) { + break fastpath; + } + + final byte[] buffer = this.buffer; + long x; + int y; + if ((y = buffer[tempPos++]) >= 0) { + pos = tempPos; + return y; + } else if (bufferSize - tempPos < 9) { + break fastpath; + } else if ((y ^= (buffer[tempPos++] << 7)) < 0) { + x = y ^ (~0 << 7); + } else if ((y ^= (buffer[tempPos++] << 14)) >= 0) { + x = y ^ ((~0 << 7) ^ (~0 << 14)); + } else if ((y ^= (buffer[tempPos++] << 21)) < 0) { + x = y ^ ((~0 << 7) ^ (~0 << 14) ^ (~0 << 21)); + } else if ((x = y ^ ((long) buffer[tempPos++] << 28)) >= 0L) { + x ^= (~0L << 7) ^ (~0L << 14) ^ (~0L << 21) ^ (~0L << 28); + } else if ((x ^= ((long) buffer[tempPos++] << 35)) < 0L) { + x ^= (~0L << 7) ^ (~0L << 14) ^ (~0L << 21) ^ (~0L << 28) ^ (~0L << 35); + } else if ((x ^= ((long) buffer[tempPos++] << 42)) >= 0L) { + x ^= (~0L << 7) ^ (~0L << 14) ^ (~0L << 21) ^ (~0L << 28) ^ (~0L << 35) ^ (~0L << 42); + } else if ((x ^= ((long) buffer[tempPos++] << 49)) < 0L) { + x ^= + (~0L << 7) + ^ (~0L << 14) + ^ (~0L << 21) + ^ (~0L << 28) + ^ (~0L << 35) + ^ (~0L << 42) + ^ (~0L << 49); + } else { + x ^= ((long) buffer[tempPos++] << 56); + x ^= + (~0L << 7) + ^ (~0L << 14) + ^ (~0L << 21) + ^ (~0L << 28) + ^ (~0L << 35) + ^ (~0L << 42) + ^ (~0L << 49) + ^ (~0L << 56); + if (x < 0L) { + if (buffer[tempPos++] < 0L) { + break fastpath; // Will throw malformedVarint() + } + } + } + pos = tempPos; + return x; + } + return readRawVarint64SlowPath(); + } + + @Override + long readRawVarint64SlowPath() throws IOException { + long result = 0; + for (int shift = 0; shift < 64; shift += 7) { + final byte b = readRawByte(); + result |= (long) (b & 0x7F) << shift; + if ((b & 0x80) == 0) { + return result; + } + } + throw InvalidProtocolBufferException.malformedVarint(); + } + + @Override + public int readRawLittleEndian32() throws IOException { + int tempPos = pos; + + if (bufferSize - tempPos < FIXED32_SIZE) { + refillBuffer(FIXED32_SIZE); + tempPos = pos; + } + + final byte[] buffer = this.buffer; + pos = tempPos + FIXED32_SIZE; + return (((buffer[tempPos] & 0xff)) + | ((buffer[tempPos + 1] & 0xff) << 8) + | ((buffer[tempPos + 2] & 0xff) << 16) + | ((buffer[tempPos + 3] & 0xff) << 24)); + } + + @Override + public long readRawLittleEndian64() throws IOException { + int tempPos = pos; + + if (bufferSize - tempPos < FIXED64_SIZE) { + refillBuffer(FIXED64_SIZE); + tempPos = pos; + } + + final byte[] buffer = this.buffer; + pos = tempPos + FIXED64_SIZE; + return (((buffer[tempPos] & 0xffL)) + | ((buffer[tempPos + 1] & 0xffL) << 8) + | ((buffer[tempPos + 2] & 0xffL) << 16) + | ((buffer[tempPos + 3] & 0xffL) << 24) + | ((buffer[tempPos + 4] & 0xffL) << 32) + | ((buffer[tempPos + 5] & 0xffL) << 40) + | ((buffer[tempPos + 6] & 0xffL) << 48) + | ((buffer[tempPos + 7] & 0xffL) << 56)); + } + + // ----------------------------------------------------------------- + + @Override + public void enableAliasing(boolean enabled) { + // TODO(nathanmittler): Ideally we should throw here. Do nothing for backward compatibility. + } + + @Override + public void resetSizeCounter() { + totalBytesRetired = -pos; + } + + @Override + public int pushLimit(int byteLimit) throws InvalidProtocolBufferException { + if (byteLimit < 0) { + throw InvalidProtocolBufferException.negativeSize(); + } + byteLimit += totalBytesRetired + pos; + final int oldLimit = currentLimit; + if (byteLimit > oldLimit) { + throw InvalidProtocolBufferException.truncatedMessage(); + } + currentLimit = byteLimit; + + recomputeBufferSizeAfterLimit(); + + return oldLimit; + } + + private void recomputeBufferSizeAfterLimit() { + bufferSize += bufferSizeAfterLimit; + final int bufferEnd = totalBytesRetired + bufferSize; + if (bufferEnd > currentLimit) { + // Limit is in current buffer. + bufferSizeAfterLimit = bufferEnd - currentLimit; + bufferSize -= bufferSizeAfterLimit; + } else { + bufferSizeAfterLimit = 0; + } + } + + @Override + public void popLimit(final int oldLimit) { + currentLimit = oldLimit; + recomputeBufferSizeAfterLimit(); + } + + @Override + public int getBytesUntilLimit() { + if (currentLimit == Integer.MAX_VALUE) { + return -1; + } + + final int currentAbsolutePosition = totalBytesRetired + pos; + return currentLimit - currentAbsolutePosition; + } + + @Override + public boolean isAtEnd() throws IOException { + return pos == bufferSize && !tryRefillBuffer(1); + } + + @Override + public int getTotalBytesRead() { + return totalBytesRetired + pos; + } + + private interface RefillCallback { + void onRefill(); + } + + private RefillCallback refillCallback = null; + + /** + * Reads more bytes from the input, making at least {@code n} bytes available in the buffer. + * Caller must ensure that the requested space is not yet available, and that the requested + * space is less than BUFFER_SIZE. + * + * @throws InvalidProtocolBufferException The end of the stream or the current limit was + * reached. + */ + private void refillBuffer(int n) throws IOException { + if (!tryRefillBuffer(n)) { + // We have to distinguish the exception between sizeLimitExceeded and truncatedMessage. So + // we just throw an sizeLimitExceeded exception here if it exceeds the sizeLimit + if (n > sizeLimit - totalBytesRetired - pos) { + throw InvalidProtocolBufferException.sizeLimitExceeded(); + } else { + throw InvalidProtocolBufferException.truncatedMessage(); + } + } + } + + /** + * Tries to read more bytes from the input, making at least {@code n} bytes available in the + * buffer. Caller must ensure that the requested space is not yet available, and that the + * requested space is less than BUFFER_SIZE. + * + * @return {@code true} If the bytes could be made available; {@code false} 1. Current at the + * end of the stream 2. The current limit was reached 3. The total size limit was reached + */ + private boolean tryRefillBuffer(int n) throws IOException { + if (pos + n <= bufferSize) { + throw new IllegalStateException( + "refillBuffer() called when " + n + " bytes were already available in buffer"); + } + + // Check whether the size of total message needs to read is bigger than the size limit. + // We shouldn't throw an exception here as isAtEnd() function needs to get this function's + // return as the result. + if (n > sizeLimit - totalBytesRetired - pos) { + return false; + } + + // Shouldn't throw the exception here either. + if (totalBytesRetired + pos + n > currentLimit) { + // Oops, we hit a limit. + return false; + } + + if (refillCallback != null) { + refillCallback.onRefill(); + } + + int tempPos = pos; + if (tempPos > 0) { + if (bufferSize > tempPos) { + System.arraycopy(buffer, tempPos, buffer, 0, bufferSize - tempPos); } - totalBytesRetired += pos; - bufferSize -= pos; - bufferPos = 0; + totalBytesRetired += tempPos; + bufferSize -= tempPos; + pos = 0; } - int bytesRead = input.read(buffer, bufferSize, buffer.length - bufferSize); + // Here we should refill the buffer as many bytes as possible. + int bytesRead = + input.read( + buffer, + bufferSize, + Math.min( + // the size of allocated but unused bytes in the buffer + buffer.length - bufferSize, + // do not exceed the total bytes limit + sizeLimit - totalBytesRetired - bufferSize)); if (bytesRead == 0 || bytesRead < -1 || bytesRead > buffer.length) { throw new IllegalStateException( - "InputStream#read(byte[]) returned invalid result: " + bytesRead + - "\nThe InputStream implementation is buggy."); + "InputStream#read(byte[]) returned invalid result: " + + bytesRead + + "\nThe InputStream implementation is buggy."); } if (bytesRead > 0) { bufferSize += bytesRead; - // Integer-overflow-conscious check against sizeLimit - if (totalBytesRetired + n - sizeLimit > 0) { - throw InvalidProtocolBufferException.sizeLimitExceeded(); - } recomputeBufferSizeAfterLimit(); return (bufferSize >= n) ? true : tryRefillBuffer(n); } + + return false; } - return false; - } + @Override + public byte readRawByte() throws IOException { + if (pos == bufferSize) { + refillBuffer(1); + } + return buffer[pos++]; + } - /** - * Read one byte from the input. - * - * @throws InvalidProtocolBufferException The end of the stream or the current - * limit was reached. - */ - public byte readRawByte() throws IOException { - if (bufferPos == bufferSize) { - refillBuffer(1); + @Override + public byte[] readRawBytes(final int size) throws IOException { + final int tempPos = pos; + if (size <= (bufferSize - tempPos) && size > 0) { + pos = tempPos + size; + return Arrays.copyOfRange(buffer, tempPos, tempPos + size); + } else { + return readRawBytesSlowPath(size); + } } - return buffer[bufferPos++]; - } - /** - * Read a fixed size of bytes from the input. - * - * @throws InvalidProtocolBufferException The end of the stream or the current - * limit was reached. - */ - public byte[] readRawBytes(final int size) throws IOException { - final int pos = bufferPos; - if (size <= (bufferSize - pos) && size > 0) { - bufferPos = pos + size; - return Arrays.copyOfRange(buffer, pos, pos + size); - } else { - return readRawBytesSlowPath(size); + /** + * Exactly like readRawBytes, but caller must have already checked the fast path: (size <= + * (bufferSize - pos) && size > 0) + */ + private byte[] readRawBytesSlowPath(final int size) throws IOException { + // Attempt to read the data in one byte array when it's safe to do. + byte[] result = readRawBytesSlowPathOneChunk(size); + if (result != null) { + return result; + } + + final int originalBufferPos = pos; + final int bufferedBytes = bufferSize - pos; + + // Mark the current buffer consumed. + totalBytesRetired += bufferSize; + pos = 0; + bufferSize = 0; + + // Determine the number of bytes we need to read from the input stream. + int sizeLeft = size - bufferedBytes; + + // The size is very large. For security reasons we read them in small + // chunks. + List<byte[]> chunks = readRawBytesSlowPathRemainingChunks(sizeLeft); + + // OK, got everything. Now concatenate it all into one buffer. + final byte[] bytes = new byte[size]; + + // Start by copying the leftover bytes from this.buffer. + System.arraycopy(buffer, originalBufferPos, bytes, 0, bufferedBytes); + + // And now all the chunks. + int tempPos = bufferedBytes; + for (final byte[] chunk : chunks) { + System.arraycopy(chunk, 0, bytes, tempPos, chunk.length); + tempPos += chunk.length; + } + + // Done. + return bytes; + } + + /** + * Attempts to read the data in one byte array when it's safe to do. Returns null if the size to + * read is too large and needs to be allocated in smaller chunks for security reasons. + */ + private byte[] readRawBytesSlowPathOneChunk(final int size) throws IOException { + if (size == 0) { + return Internal.EMPTY_BYTE_ARRAY; + } + if (size < 0) { + throw InvalidProtocolBufferException.negativeSize(); + } + + // Integer-overflow-conscious check that the message size so far has not exceeded sizeLimit. + int currentMessageSize = totalBytesRetired + pos + size; + if (currentMessageSize - sizeLimit > 0) { + throw InvalidProtocolBufferException.sizeLimitExceeded(); + } + + // Verify that the message size so far has not exceeded currentLimit. + if (currentMessageSize > currentLimit) { + // Read to the end of the stream anyway. + skipRawBytes(currentLimit - totalBytesRetired - pos); + throw InvalidProtocolBufferException.truncatedMessage(); + } + + final int bufferedBytes = bufferSize - pos; + // Determine the number of bytes we need to read from the input stream. + int sizeLeft = size - bufferedBytes; + // TODO(nathanmittler): Consider using a value larger than DEFAULT_BUFFER_SIZE. + if (sizeLeft < DEFAULT_BUFFER_SIZE || sizeLeft <= input.available()) { + // Either the bytes we need are known to be available, or the required buffer is + // within an allowed threshold - go ahead and allocate the buffer now. + final byte[] bytes = new byte[size]; + + // Copy all of the buffered bytes to the result buffer. + System.arraycopy(buffer, pos, bytes, 0, bufferedBytes); + totalBytesRetired += bufferSize; + pos = 0; + bufferSize = 0; + + // Fill the remaining bytes from the input stream. + int tempPos = bufferedBytes; + while (tempPos < bytes.length) { + int n = input.read(bytes, tempPos, size - tempPos); + if (n == -1) { + throw InvalidProtocolBufferException.truncatedMessage(); + } + totalBytesRetired += n; + tempPos += n; + } + + return bytes; + } + + return null; + } + + /** Reads the remaining data in small chunks from the input stream. */ + private List<byte[]> readRawBytesSlowPathRemainingChunks(int sizeLeft) throws IOException { + // The size is very large. For security reasons, we can't allocate the + // entire byte array yet. The size comes directly from the input, so a + // maliciously-crafted message could provide a bogus very large size in + // order to trick the app into allocating a lot of memory. We avoid this + // by allocating and reading only a small chunk at a time, so that the + // malicious message must actually *be* extremely large to cause + // problems. Meanwhile, we limit the allowed size of a message elsewhere. + final List<byte[]> chunks = new ArrayList<byte[]>(); + + while (sizeLeft > 0) { + // TODO(nathanmittler): Consider using a value larger than DEFAULT_BUFFER_SIZE. + final byte[] chunk = new byte[Math.min(sizeLeft, DEFAULT_BUFFER_SIZE)]; + int tempPos = 0; + while (tempPos < chunk.length) { + final int n = input.read(chunk, tempPos, chunk.length - tempPos); + if (n == -1) { + throw InvalidProtocolBufferException.truncatedMessage(); + } + totalBytesRetired += n; + tempPos += n; + } + sizeLeft -= chunk.length; + chunks.add(chunk); + } + + return chunks; + } + + /** + * Like readBytes, but caller must have already checked the fast path: (size <= (bufferSize - + * pos) && size > 0 || size == 0) + */ + private ByteString readBytesSlowPath(final int size) throws IOException { + final byte[] result = readRawBytesSlowPathOneChunk(size); + if (result != null) { + return ByteString.wrap(result); + } + + final int originalBufferPos = pos; + final int bufferedBytes = bufferSize - pos; + + // Mark the current buffer consumed. + totalBytesRetired += bufferSize; + pos = 0; + bufferSize = 0; + + // Determine the number of bytes we need to read from the input stream. + int sizeLeft = size - bufferedBytes; + + // The size is very large. For security reasons we read them in small + // chunks. + List<byte[]> chunks = readRawBytesSlowPathRemainingChunks(sizeLeft); + + // Wrap the byte arrays into a single ByteString. + List<ByteString> byteStrings = new ArrayList<ByteString>(1 + chunks.size()); + byteStrings.add(ByteString.copyFrom(buffer, originalBufferPos, bufferedBytes)); + for (byte[] chunk : chunks) { + byteStrings.add(ByteString.wrap(chunk)); + } + return ByteString.copyFrom(byteStrings); + } + + @Override + public void skipRawBytes(final int size) throws IOException { + if (size <= (bufferSize - pos) && size >= 0) { + // We have all the bytes we need already. + pos += size; + } else { + skipRawBytesSlowPath(size); + } + } + + /** + * Exactly like skipRawBytes, but caller must have already checked the fast path: (size <= + * (bufferSize - pos) && size >= 0) + */ + private void skipRawBytesSlowPath(final int size) throws IOException { + if (size < 0) { + throw InvalidProtocolBufferException.negativeSize(); + } + + if (totalBytesRetired + pos + size > currentLimit) { + // Read to the end of the stream anyway. + skipRawBytes(currentLimit - totalBytesRetired - pos); + // Then fail. + throw InvalidProtocolBufferException.truncatedMessage(); + } + + // Skipping more bytes than are in the buffer. First skip what we have. + int tempPos = bufferSize - pos; + pos = bufferSize; + + // Keep refilling the buffer until we get to the point we wanted to skip to. + // This has the side effect of ensuring the limits are updated correctly. + refillBuffer(1); + while (size - tempPos > bufferSize) { + tempPos += bufferSize; + pos = bufferSize; + refillBuffer(1); + } + + pos = size - tempPos; } } /** - * Exactly like readRawBytes, but caller must have already checked the fast - * path: (size <= (bufferSize - pos) && size > 0) + * Implementation of {@link CodedInputStream} that uses an {@link Iterable <ByteBuffer>} as the + * data source. Requires the use of {@code sun.misc.Unsafe} to perform fast reads on the buffer. */ - private byte[] readRawBytesSlowPath(final int size) throws IOException { - if (size <= 0) { + private static final class IterableDirectByteBufferDecoder extends CodedInputStream { + /** The object that need to decode. */ + private Iterable<ByteBuffer> input; + /** The {@link Iterator} with type {@link ByteBuffer} of {@code input} */ + private Iterator<ByteBuffer> iterator; + /** The current ByteBuffer; */ + private ByteBuffer currentByteBuffer; + /** + * If {@code true}, indicates that all the buffer are backing a {@link ByteString} and are + * therefore considered to be an immutable input source. + */ + private boolean immutable; + /** + * If {@code true}, indicates that calls to read {@link ByteString} or {@code byte[]} + * <strong>may</strong> return slices of the underlying buffer, rather than copies. + */ + private boolean enableAliasing; + /** The global total message length limit */ + private int totalBufferSize; + /** The amount of available data in the input beyond {@link #currentLimit}. */ + private int bufferSizeAfterCurrentLimit; + /** The absolute position of the end of the current message. */ + private int currentLimit = Integer.MAX_VALUE; + /** The last tag that was read from this stream. */ + private int lastTag; + /** Total Bytes have been Read from the {@link Iterable} {@link ByteBuffer} */ + private int totalBytesRead; + /** The start position offset of the whole message, used as to reset the totalBytesRead */ + private int startOffset; + /** The current position for current ByteBuffer */ + private long currentByteBufferPos; + + private long currentByteBufferStartPos; + /** + * If the current ByteBuffer is unsafe-direct based, currentAddress is the start address of this + * ByteBuffer; otherwise should be zero. + */ + private long currentAddress; + /** The limit position for current ByteBuffer */ + private long currentByteBufferLimit; + + /** + * The constructor of {@code Iterable<ByteBuffer>} decoder. + * + * @param inputBufs The input data. + * @param size The total size of the input data. + * @param immutableFlag whether the input data is immutable. + */ + private IterableDirectByteBufferDecoder( + Iterable<ByteBuffer> inputBufs, int size, boolean immutableFlag) { + totalBufferSize = size; + input = inputBufs; + iterator = input.iterator(); + immutable = immutableFlag; + startOffset = totalBytesRead = 0; if (size == 0) { - return Internal.EMPTY_BYTE_ARRAY; + currentByteBuffer = EMPTY_BYTE_BUFFER; + currentByteBufferPos = 0; + currentByteBufferStartPos = 0; + currentByteBufferLimit = 0; + currentAddress = 0; } else { + tryGetNextByteBuffer(); + } + } + + /** To get the next ByteBuffer from {@code input}, and then update the parameters */ + private void getNextByteBuffer() throws InvalidProtocolBufferException { + if (!iterator.hasNext()) { + throw InvalidProtocolBufferException.truncatedMessage(); + } + tryGetNextByteBuffer(); + } + + private void tryGetNextByteBuffer() { + currentByteBuffer = iterator.next(); + totalBytesRead += (int) (currentByteBufferPos - currentByteBufferStartPos); + currentByteBufferPos = currentByteBuffer.position(); + currentByteBufferStartPos = currentByteBufferPos; + currentByteBufferLimit = currentByteBuffer.limit(); + currentAddress = UnsafeUtil.addressOffset(currentByteBuffer); + currentByteBufferPos += currentAddress; + currentByteBufferStartPos += currentAddress; + currentByteBufferLimit += currentAddress; + } + + @Override + public int readTag() throws IOException { + if (isAtEnd()) { + lastTag = 0; + return 0; + } + + lastTag = readRawVarint32(); + if (WireFormat.getTagFieldNumber(lastTag) == 0) { + // If we actually read zero (or any tag number corresponding to field + // number zero), that's not a valid tag. + throw InvalidProtocolBufferException.invalidTag(); + } + return lastTag; + } + + @Override + public void checkLastTagWas(final int value) throws InvalidProtocolBufferException { + if (lastTag != value) { + throw InvalidProtocolBufferException.invalidEndTag(); + } + } + + @Override + public int getLastTag() { + return lastTag; + } + + @Override + public boolean skipField(final int tag) throws IOException { + switch (WireFormat.getTagWireType(tag)) { + case WireFormat.WIRETYPE_VARINT: + skipRawVarint(); + return true; + case WireFormat.WIRETYPE_FIXED64: + skipRawBytes(FIXED64_SIZE); + return true; + case WireFormat.WIRETYPE_LENGTH_DELIMITED: + skipRawBytes(readRawVarint32()); + return true; + case WireFormat.WIRETYPE_START_GROUP: + skipMessage(); + checkLastTagWas( + WireFormat.makeTag(WireFormat.getTagFieldNumber(tag), WireFormat.WIRETYPE_END_GROUP)); + return true; + case WireFormat.WIRETYPE_END_GROUP: + return false; + case WireFormat.WIRETYPE_FIXED32: + skipRawBytes(FIXED32_SIZE); + return true; + default: + throw InvalidProtocolBufferException.invalidWireType(); + } + } + + @Override + public boolean skipField(final int tag, final CodedOutputStream output) throws IOException { + switch (WireFormat.getTagWireType(tag)) { + case WireFormat.WIRETYPE_VARINT: + { + long value = readInt64(); + output.writeRawVarint32(tag); + output.writeUInt64NoTag(value); + return true; + } + case WireFormat.WIRETYPE_FIXED64: + { + long value = readRawLittleEndian64(); + output.writeRawVarint32(tag); + output.writeFixed64NoTag(value); + return true; + } + case WireFormat.WIRETYPE_LENGTH_DELIMITED: + { + ByteString value = readBytes(); + output.writeRawVarint32(tag); + output.writeBytesNoTag(value); + return true; + } + case WireFormat.WIRETYPE_START_GROUP: + { + output.writeRawVarint32(tag); + skipMessage(output); + int endtag = + WireFormat.makeTag( + WireFormat.getTagFieldNumber(tag), WireFormat.WIRETYPE_END_GROUP); + checkLastTagWas(endtag); + output.writeRawVarint32(endtag); + return true; + } + case WireFormat.WIRETYPE_END_GROUP: + { + return false; + } + case WireFormat.WIRETYPE_FIXED32: + { + int value = readRawLittleEndian32(); + output.writeRawVarint32(tag); + output.writeFixed32NoTag(value); + return true; + } + default: + throw InvalidProtocolBufferException.invalidWireType(); + } + } + + @Override + public void skipMessage() throws IOException { + while (true) { + final int tag = readTag(); + if (tag == 0 || !skipField(tag)) { + return; + } + } + } + + @Override + public void skipMessage(CodedOutputStream output) throws IOException { + while (true) { + final int tag = readTag(); + if (tag == 0 || !skipField(tag, output)) { + return; + } + } + } + + // ----------------------------------------------------------------- + + @Override + public double readDouble() throws IOException { + return Double.longBitsToDouble(readRawLittleEndian64()); + } + + @Override + public float readFloat() throws IOException { + return Float.intBitsToFloat(readRawLittleEndian32()); + } + + @Override + public long readUInt64() throws IOException { + return readRawVarint64(); + } + + @Override + public long readInt64() throws IOException { + return readRawVarint64(); + } + + @Override + public int readInt32() throws IOException { + return readRawVarint32(); + } + + @Override + public long readFixed64() throws IOException { + return readRawLittleEndian64(); + } + + @Override + public int readFixed32() throws IOException { + return readRawLittleEndian32(); + } + + @Override + public boolean readBool() throws IOException { + return readRawVarint64() != 0; + } + + @Override + public String readString() throws IOException { + final int size = readRawVarint32(); + if (size > 0 && size <= currentByteBufferLimit - currentByteBufferPos) { + byte[] bytes = new byte[size]; + UnsafeUtil.copyMemory(currentByteBufferPos, bytes, 0, size); + String result = new String(bytes, UTF_8); + currentByteBufferPos += size; + return result; + } else if (size > 0 && size <= remaining()) { + // TODO(yilunchong): To use an underlying bytes[] instead of allocating a new bytes[] + byte[] bytes = new byte[size]; + readRawBytesTo(bytes, 0, size); + String result = new String(bytes, UTF_8); + return result; + } + + if (size == 0) { + return ""; + } + if (size < 0) { + throw InvalidProtocolBufferException.negativeSize(); + } + throw InvalidProtocolBufferException.truncatedMessage(); + } + + @Override + public String readStringRequireUtf8() throws IOException { + final int size = readRawVarint32(); + if (size > 0 && size <= currentByteBufferLimit - currentByteBufferPos) { + if (ENABLE_CUSTOM_UTF8_DECODE) { + final int bufferPos = (int) (currentByteBufferPos - currentByteBufferStartPos); + String result = Utf8.decodeUtf8(currentByteBuffer, bufferPos, size); + currentByteBufferPos += size; + return result; + } else { + byte[] bytes = new byte[size]; + UnsafeUtil.copyMemory(currentByteBufferPos, bytes, 0, size); + if (!Utf8.isValidUtf8(bytes)) { + throw InvalidProtocolBufferException.invalidUtf8(); + } + String result = new String(bytes, UTF_8); + currentByteBufferPos += size; + return result; + } + } + if (size >= 0 && size <= remaining()) { + byte[] bytes = new byte[size]; + readRawBytesTo(bytes, 0, size); + if (ENABLE_CUSTOM_UTF8_DECODE) { + return Utf8.decodeUtf8(bytes, 0, size); + } else { + if (!Utf8.isValidUtf8(bytes)) { + throw InvalidProtocolBufferException.invalidUtf8(); + } + String result = new String(bytes, UTF_8); + return result; + } + } + + if (size == 0) { + return ""; + } + if (size <= 0) { throw InvalidProtocolBufferException.negativeSize(); } + throw InvalidProtocolBufferException.truncatedMessage(); } - // Verify that the message size so far has not exceeded sizeLimit. - int currentMessageSize = totalBytesRetired + bufferPos + size; - if (currentMessageSize > sizeLimit) { - throw InvalidProtocolBufferException.sizeLimitExceeded(); + @Override + public void readGroup( + final int fieldNumber, + final MessageLite.Builder builder, + final ExtensionRegistryLite extensionRegistry) + throws IOException { + if (recursionDepth >= recursionLimit) { + throw InvalidProtocolBufferException.recursionLimitExceeded(); + } + ++recursionDepth; + builder.mergeFrom(this, extensionRegistry); + checkLastTagWas(WireFormat.makeTag(fieldNumber, WireFormat.WIRETYPE_END_GROUP)); + --recursionDepth; } - // Verify that the message size so far has not exceeded currentLimit. - if (currentMessageSize > currentLimit) { - // Read to the end of the stream anyway. - skipRawBytes(currentLimit - totalBytesRetired - bufferPos); + + @Override + public <T extends MessageLite> T readGroup( + final int fieldNumber, + final Parser<T> parser, + final ExtensionRegistryLite extensionRegistry) + throws IOException { + if (recursionDepth >= recursionLimit) { + throw InvalidProtocolBufferException.recursionLimitExceeded(); + } + ++recursionDepth; + T result = parser.parsePartialFrom(this, extensionRegistry); + checkLastTagWas(WireFormat.makeTag(fieldNumber, WireFormat.WIRETYPE_END_GROUP)); + --recursionDepth; + return result; + } + + @Deprecated + @Override + public void readUnknownGroup(final int fieldNumber, final MessageLite.Builder builder) + throws IOException { + readGroup(fieldNumber, builder, ExtensionRegistryLite.getEmptyRegistry()); + } + + @Override + public void readMessage( + final MessageLite.Builder builder, final ExtensionRegistryLite extensionRegistry) + throws IOException { + final int length = readRawVarint32(); + if (recursionDepth >= recursionLimit) { + throw InvalidProtocolBufferException.recursionLimitExceeded(); + } + final int oldLimit = pushLimit(length); + ++recursionDepth; + builder.mergeFrom(this, extensionRegistry); + checkLastTagWas(0); + --recursionDepth; + popLimit(oldLimit); + } + + + @Override + public <T extends MessageLite> T readMessage( + final Parser<T> parser, final ExtensionRegistryLite extensionRegistry) throws IOException { + int length = readRawVarint32(); + if (recursionDepth >= recursionLimit) { + throw InvalidProtocolBufferException.recursionLimitExceeded(); + } + final int oldLimit = pushLimit(length); + ++recursionDepth; + T result = parser.parsePartialFrom(this, extensionRegistry); + checkLastTagWas(0); + --recursionDepth; + popLimit(oldLimit); + return result; + } + + @Override + public ByteString readBytes() throws IOException { + final int size = readRawVarint32(); + if (size > 0 && size <= currentByteBufferLimit - currentByteBufferPos) { + if (immutable && enableAliasing) { + final int idx = (int) (currentByteBufferPos - currentAddress); + final ByteString result = ByteString.wrap(slice(idx, idx + size)); + currentByteBufferPos += size; + return result; + } else { + byte[] bytes; + bytes = new byte[size]; + UnsafeUtil.copyMemory(currentByteBufferPos, bytes, 0, size); + currentByteBufferPos += size; + return ByteString.wrap(bytes); + } + } else if (size > 0 && size <= remaining()) { + byte[] temp = new byte[size]; + readRawBytesTo(temp, 0, size); + return ByteString.wrap(temp); + } + + if (size == 0) { + return ByteString.EMPTY; + } + if (size < 0) { + throw InvalidProtocolBufferException.negativeSize(); + } throw InvalidProtocolBufferException.truncatedMessage(); } - // We need the input stream to proceed. - if (input == null) { + @Override + public byte[] readByteArray() throws IOException { + return readRawBytes(readRawVarint32()); + } + + @Override + public ByteBuffer readByteBuffer() throws IOException { + final int size = readRawVarint32(); + if (size > 0 && size <= currentRemaining()) { + if (!immutable && enableAliasing) { + currentByteBufferPos += size; + return slice( + (int) (currentByteBufferPos - currentAddress - size), + (int) (currentByteBufferPos - currentAddress)); + } else { + byte[] bytes = new byte[size]; + UnsafeUtil.copyMemory(currentByteBufferPos, bytes, 0, size); + currentByteBufferPos += size; + return ByteBuffer.wrap(bytes); + } + } else if (size > 0 && size <= remaining()) { + byte[] temp = new byte[size]; + readRawBytesTo(temp, 0, size); + return ByteBuffer.wrap(temp); + } + + if (size == 0) { + return EMPTY_BYTE_BUFFER; + } + if (size < 0) { + throw InvalidProtocolBufferException.negativeSize(); + } throw InvalidProtocolBufferException.truncatedMessage(); } - final int originalBufferPos = bufferPos; - final int bufferedBytes = bufferSize - bufferPos; + @Override + public int readUInt32() throws IOException { + return readRawVarint32(); + } - // Mark the current buffer consumed. - totalBytesRetired += bufferSize; - bufferPos = 0; - bufferSize = 0; + @Override + public int readEnum() throws IOException { + return readRawVarint32(); + } - // Determine the number of bytes we need to read from the input stream. - int sizeLeft = size - bufferedBytes; - // TODO(nathanmittler): Consider using a value larger than BUFFER_SIZE. - if (sizeLeft < BUFFER_SIZE || sizeLeft <= input.available()) { - // Either the bytes we need are known to be available, or the required buffer is - // within an allowed threshold - go ahead and allocate the buffer now. - final byte[] bytes = new byte[size]; + @Override + public int readSFixed32() throws IOException { + return readRawLittleEndian32(); + } - // Copy all of the buffered bytes to the result buffer. - System.arraycopy(buffer, originalBufferPos, bytes, 0, bufferedBytes); + @Override + public long readSFixed64() throws IOException { + return readRawLittleEndian64(); + } - // Fill the remaining bytes from the input stream. - int pos = bufferedBytes; - while (pos < bytes.length) { - int n = input.read(bytes, pos, size - pos); - if (n == -1) { - throw InvalidProtocolBufferException.truncatedMessage(); + @Override + public int readSInt32() throws IOException { + return decodeZigZag32(readRawVarint32()); + } + + @Override + public long readSInt64() throws IOException { + return decodeZigZag64(readRawVarint64()); + } + + @Override + public int readRawVarint32() throws IOException { + fastpath: + { + long tempPos = currentByteBufferPos; + + if (currentByteBufferLimit == currentByteBufferPos) { + break fastpath; } - totalBytesRetired += n; - pos += n; + + int x; + if ((x = UnsafeUtil.getByte(tempPos++)) >= 0) { + currentByteBufferPos++; + return x; + } else if (currentByteBufferLimit - currentByteBufferPos < 10) { + break fastpath; + } else if ((x ^= (UnsafeUtil.getByte(tempPos++) << 7)) < 0) { + x ^= (~0 << 7); + } else if ((x ^= (UnsafeUtil.getByte(tempPos++) << 14)) >= 0) { + x ^= (~0 << 7) ^ (~0 << 14); + } else if ((x ^= (UnsafeUtil.getByte(tempPos++) << 21)) < 0) { + x ^= (~0 << 7) ^ (~0 << 14) ^ (~0 << 21); + } else { + int y = UnsafeUtil.getByte(tempPos++); + x ^= y << 28; + x ^= (~0 << 7) ^ (~0 << 14) ^ (~0 << 21) ^ (~0 << 28); + if (y < 0 + && UnsafeUtil.getByte(tempPos++) < 0 + && UnsafeUtil.getByte(tempPos++) < 0 + && UnsafeUtil.getByte(tempPos++) < 0 + && UnsafeUtil.getByte(tempPos++) < 0 + && UnsafeUtil.getByte(tempPos++) < 0) { + break fastpath; // Will throw malformedVarint() + } + } + currentByteBufferPos = tempPos; + return x; } + return (int) readRawVarint64SlowPath(); + } - return bytes; + @Override + public long readRawVarint64() throws IOException { + fastpath: + { + long tempPos = currentByteBufferPos; + + if (currentByteBufferLimit == currentByteBufferPos) { + break fastpath; + } + + long x; + int y; + if ((y = UnsafeUtil.getByte(tempPos++)) >= 0) { + currentByteBufferPos++; + return y; + } else if (currentByteBufferLimit - currentByteBufferPos < 10) { + break fastpath; + } else if ((y ^= (UnsafeUtil.getByte(tempPos++) << 7)) < 0) { + x = y ^ (~0 << 7); + } else if ((y ^= (UnsafeUtil.getByte(tempPos++) << 14)) >= 0) { + x = y ^ ((~0 << 7) ^ (~0 << 14)); + } else if ((y ^= (UnsafeUtil.getByte(tempPos++) << 21)) < 0) { + x = y ^ ((~0 << 7) ^ (~0 << 14) ^ (~0 << 21)); + } else if ((x = y ^ ((long) UnsafeUtil.getByte(tempPos++) << 28)) >= 0L) { + x ^= (~0L << 7) ^ (~0L << 14) ^ (~0L << 21) ^ (~0L << 28); + } else if ((x ^= ((long) UnsafeUtil.getByte(tempPos++) << 35)) < 0L) { + x ^= (~0L << 7) ^ (~0L << 14) ^ (~0L << 21) ^ (~0L << 28) ^ (~0L << 35); + } else if ((x ^= ((long) UnsafeUtil.getByte(tempPos++) << 42)) >= 0L) { + x ^= (~0L << 7) ^ (~0L << 14) ^ (~0L << 21) ^ (~0L << 28) ^ (~0L << 35) ^ (~0L << 42); + } else if ((x ^= ((long) UnsafeUtil.getByte(tempPos++) << 49)) < 0L) { + x ^= + (~0L << 7) + ^ (~0L << 14) + ^ (~0L << 21) + ^ (~0L << 28) + ^ (~0L << 35) + ^ (~0L << 42) + ^ (~0L << 49); + } else { + x ^= ((long) UnsafeUtil.getByte(tempPos++) << 56); + x ^= + (~0L << 7) + ^ (~0L << 14) + ^ (~0L << 21) + ^ (~0L << 28) + ^ (~0L << 35) + ^ (~0L << 42) + ^ (~0L << 49) + ^ (~0L << 56); + if (x < 0L) { + if (UnsafeUtil.getByte(tempPos++) < 0L) { + break fastpath; // Will throw malformedVarint() + } + } + } + currentByteBufferPos = tempPos; + return x; + } + return readRawVarint64SlowPath(); } - // The size is very large. For security reasons, we can't allocate the - // entire byte array yet. The size comes directly from the input, so a - // maliciously-crafted message could provide a bogus very large size in - // order to trick the app into allocating a lot of memory. We avoid this - // by allocating and reading only a small chunk at a time, so that the - // malicious message must actually *be* extremely large to cause - // problems. Meanwhile, we limit the allowed size of a message elsewhere. - final List<byte[]> chunks = new ArrayList<byte[]>(); - - while (sizeLeft > 0) { - // TODO(nathanmittler): Consider using a value larger than BUFFER_SIZE. - final byte[] chunk = new byte[Math.min(sizeLeft, BUFFER_SIZE)]; - int pos = 0; - while (pos < chunk.length) { - final int n = input.read(chunk, pos, chunk.length - pos); - if (n == -1) { - throw InvalidProtocolBufferException.truncatedMessage(); + @Override + long readRawVarint64SlowPath() throws IOException { + long result = 0; + for (int shift = 0; shift < 64; shift += 7) { + final byte b = readRawByte(); + result |= (long) (b & 0x7F) << shift; + if ((b & 0x80) == 0) { + return result; } - totalBytesRetired += n; - pos += n; } - sizeLeft -= chunk.length; - chunks.add(chunk); + throw InvalidProtocolBufferException.malformedVarint(); } - // OK, got everything. Now concatenate it all into one buffer. - final byte[] bytes = new byte[size]; + @Override + public int readRawLittleEndian32() throws IOException { + if (currentRemaining() >= FIXED32_SIZE) { + long tempPos = currentByteBufferPos; + currentByteBufferPos += FIXED32_SIZE; + return (((UnsafeUtil.getByte(tempPos) & 0xff)) + | ((UnsafeUtil.getByte(tempPos + 1) & 0xff) << 8) + | ((UnsafeUtil.getByte(tempPos + 2) & 0xff) << 16) + | ((UnsafeUtil.getByte(tempPos + 3) & 0xff) << 24)); + } + return ((readRawByte() & 0xff) + | ((readRawByte() & 0xff) << 8) + | ((readRawByte() & 0xff) << 16) + | ((readRawByte() & 0xff) << 24)); + } - // Start by copying the leftover bytes from this.buffer. - System.arraycopy(buffer, originalBufferPos, bytes, 0, bufferedBytes); + @Override + public long readRawLittleEndian64() throws IOException { + if (currentRemaining() >= FIXED64_SIZE) { + long tempPos = currentByteBufferPos; + currentByteBufferPos += FIXED64_SIZE; + return (((UnsafeUtil.getByte(tempPos) & 0xffL)) + | ((UnsafeUtil.getByte(tempPos + 1) & 0xffL) << 8) + | ((UnsafeUtil.getByte(tempPos + 2) & 0xffL) << 16) + | ((UnsafeUtil.getByte(tempPos + 3) & 0xffL) << 24) + | ((UnsafeUtil.getByte(tempPos + 4) & 0xffL) << 32) + | ((UnsafeUtil.getByte(tempPos + 5) & 0xffL) << 40) + | ((UnsafeUtil.getByte(tempPos + 6) & 0xffL) << 48) + | ((UnsafeUtil.getByte(tempPos + 7) & 0xffL) << 56)); + } + return ((readRawByte() & 0xffL) + | ((readRawByte() & 0xffL) << 8) + | ((readRawByte() & 0xffL) << 16) + | ((readRawByte() & 0xffL) << 24) + | ((readRawByte() & 0xffL) << 32) + | ((readRawByte() & 0xffL) << 40) + | ((readRawByte() & 0xffL) << 48) + | ((readRawByte() & 0xffL) << 56)); + } - // And now all the chunks. - int pos = bufferedBytes; - for (final byte[] chunk : chunks) { - System.arraycopy(chunk, 0, bytes, pos, chunk.length); - pos += chunk.length; + @Override + public void enableAliasing(boolean enabled) { + this.enableAliasing = enabled; } - // Done. - return bytes; - } + @Override + public void resetSizeCounter() { + startOffset = (int) (totalBytesRead + currentByteBufferPos - currentByteBufferStartPos); + } - /** - * Reads and discards {@code size} bytes. - * - * @throws InvalidProtocolBufferException The end of the stream or the current - * limit was reached. - */ - public void skipRawBytes(final int size) throws IOException { - if (size <= (bufferSize - bufferPos) && size >= 0) { - // We have all the bytes we need already. - bufferPos += size; - } else { - skipRawBytesSlowPath(size); + @Override + public int pushLimit(int byteLimit) throws InvalidProtocolBufferException { + if (byteLimit < 0) { + throw InvalidProtocolBufferException.negativeSize(); + } + byteLimit += getTotalBytesRead(); + final int oldLimit = currentLimit; + if (byteLimit > oldLimit) { + throw InvalidProtocolBufferException.truncatedMessage(); + } + currentLimit = byteLimit; + + recomputeBufferSizeAfterLimit(); + + return oldLimit; } - } - /** - * Exactly like skipRawBytes, but caller must have already checked the fast - * path: (size <= (bufferSize - pos) && size >= 0) - */ - private void skipRawBytesSlowPath(final int size) throws IOException { - if (size < 0) { - throw InvalidProtocolBufferException.negativeSize(); + private void recomputeBufferSizeAfterLimit() { + totalBufferSize += bufferSizeAfterCurrentLimit; + final int bufferEnd = totalBufferSize - startOffset; + if (bufferEnd > currentLimit) { + // Limit is in current buffer. + bufferSizeAfterCurrentLimit = bufferEnd - currentLimit; + totalBufferSize -= bufferSizeAfterCurrentLimit; + } else { + bufferSizeAfterCurrentLimit = 0; + } } - if (totalBytesRetired + bufferPos + size > currentLimit) { - // Read to the end of the stream anyway. - skipRawBytes(currentLimit - totalBytesRetired - bufferPos); - // Then fail. + @Override + public void popLimit(final int oldLimit) { + currentLimit = oldLimit; + recomputeBufferSizeAfterLimit(); + } + + @Override + public int getBytesUntilLimit() { + if (currentLimit == Integer.MAX_VALUE) { + return -1; + } + + return currentLimit - getTotalBytesRead(); + } + + @Override + public boolean isAtEnd() throws IOException { + return totalBytesRead + currentByteBufferPos - currentByteBufferStartPos == totalBufferSize; + } + + @Override + public int getTotalBytesRead() { + return (int) + (totalBytesRead - startOffset + currentByteBufferPos - currentByteBufferStartPos); + } + + @Override + public byte readRawByte() throws IOException { + if (currentRemaining() == 0) { + getNextByteBuffer(); + } + return UnsafeUtil.getByte(currentByteBufferPos++); + } + + @Override + public byte[] readRawBytes(final int length) throws IOException { + if (length >= 0 && length <= currentRemaining()) { + byte[] bytes = new byte[length]; + UnsafeUtil.copyMemory(currentByteBufferPos, bytes, 0, length); + currentByteBufferPos += length; + return bytes; + } + if (length >= 0 && length <= remaining()) { + byte[] bytes = new byte[length]; + readRawBytesTo(bytes, 0, length); + return bytes; + } + + if (length <= 0) { + if (length == 0) { + return EMPTY_BYTE_ARRAY; + } else { + throw InvalidProtocolBufferException.negativeSize(); + } + } + throw InvalidProtocolBufferException.truncatedMessage(); } - // Skipping more bytes than are in the buffer. First skip what we have. - int pos = bufferSize - bufferPos; - bufferPos = bufferSize; + /** + * Try to get raw bytes from {@code input} with the size of {@code length} and copy to {@code + * bytes} array. If the size is bigger than the number of remaining bytes in the input, then + * throw {@code truncatedMessage} exception. + * + * @param bytes + * @param offset + * @param length + * @throws IOException + */ + private void readRawBytesTo(byte[] bytes, int offset, final int length) throws IOException { + if (length >= 0 && length <= remaining()) { + int l = length; + while (l > 0) { + if (currentRemaining() == 0) { + getNextByteBuffer(); + } + int bytesToCopy = Math.min(l, (int) currentRemaining()); + UnsafeUtil.copyMemory(currentByteBufferPos, bytes, length - l + offset, bytesToCopy); + l -= bytesToCopy; + currentByteBufferPos += bytesToCopy; + } + return; + } - // Keep refilling the buffer until we get to the point we wanted to skip to. - // This has the side effect of ensuring the limits are updated correctly. - refillBuffer(1); - while (size - pos > bufferSize) { - pos += bufferSize; - bufferPos = bufferSize; - refillBuffer(1); + if (length <= 0) { + if (length == 0) { + return; + } else { + throw InvalidProtocolBufferException.negativeSize(); + } + } + throw InvalidProtocolBufferException.truncatedMessage(); + } + + @Override + public void skipRawBytes(final int length) throws IOException { + if (length >= 0 + && length + <= (totalBufferSize + - totalBytesRead + - currentByteBufferPos + + currentByteBufferStartPos)) { + // We have all the bytes we need already. + int l = length; + while (l > 0) { + if (currentRemaining() == 0) { + getNextByteBuffer(); + } + int rl = Math.min(l, (int) currentRemaining()); + l -= rl; + currentByteBufferPos += rl; + } + return; + } + + if (length < 0) { + throw InvalidProtocolBufferException.negativeSize(); + } + throw InvalidProtocolBufferException.truncatedMessage(); + } + + // TODO: optimize to fastpath + private void skipRawVarint() throws IOException { + for (int i = 0; i < MAX_VARINT_SIZE; i++) { + if (readRawByte() >= 0) { + return; + } + } + throw InvalidProtocolBufferException.malformedVarint(); } - bufferPos = size - pos; + /** + * Try to get the number of remaining bytes in {@code input}. + * + * @return the number of remaining bytes in {@code input}. + */ + private int remaining() { + return (int) + (totalBufferSize - totalBytesRead - currentByteBufferPos + currentByteBufferStartPos); + } + + /** + * Try to get the number of remaining bytes in {@code currentByteBuffer}. + * + * @return the number of remaining bytes in {@code currentByteBuffer} + */ + private long currentRemaining() { + return (currentByteBufferLimit - currentByteBufferPos); + } + + private ByteBuffer slice(int begin, int end) throws IOException { + int prevPos = currentByteBuffer.position(); + int prevLimit = currentByteBuffer.limit(); + try { + currentByteBuffer.position(begin); + currentByteBuffer.limit(end); + return currentByteBuffer.slice(); + } catch (IllegalArgumentException e) { + throw InvalidProtocolBufferException.truncatedMessage(); + } finally { + currentByteBuffer.position(prevPos); + currentByteBuffer.limit(prevLimit); + } + } } } diff --git a/java/core/src/main/java/com/google/protobuf/CodedOutputStream.java b/java/core/src/main/java/com/google/protobuf/CodedOutputStream.java index e5515285..7b1ac651 100644 --- a/java/core/src/main/java/com/google/protobuf/CodedOutputStream.java +++ b/java/core/src/main/java/com/google/protobuf/CodedOutputStream.java @@ -30,10 +30,12 @@ package com.google.protobuf; +import static com.google.protobuf.WireFormat.FIXED32_SIZE; +import static com.google.protobuf.WireFormat.FIXED64_SIZE; +import static com.google.protobuf.WireFormat.MAX_VARINT_SIZE; import static java.lang.Math.max; import com.google.protobuf.Utf8.UnpairedSurrogateException; - import java.io.IOException; import java.io.OutputStream; import java.nio.BufferOverflowException; @@ -57,17 +59,12 @@ import java.util.logging.Logger; public abstract class CodedOutputStream extends ByteOutput { private static final Logger logger = Logger.getLogger(CodedOutputStream.class.getName()); private static final boolean HAS_UNSAFE_ARRAY_OPERATIONS = UnsafeUtil.hasUnsafeArrayOperations(); - private static final long ARRAY_BASE_OFFSET = UnsafeUtil.getArrayBaseOffset(); - - private static final int FIXED_32_SIZE = 4; - private static final int FIXED_64_SIZE = 8; - private static final int MAX_VARINT_SIZE = 10; /** * @deprecated Use {@link #computeFixed32SizeNoTag(int)} instead. */ @Deprecated - public static final int LITTLE_ENDIAN_32_SIZE = FIXED_32_SIZE; + public static final int LITTLE_ENDIAN_32_SIZE = FIXED32_SIZE; /** * The buffer size used in {@link #newInstance(OutputStream)}. @@ -134,14 +131,27 @@ public abstract class CodedOutputStream extends ByteOutput { return new ArrayEncoder(flatArray, offset, length); } - /** - * Create a new {@code CodedOutputStream} that writes to the given {@link ByteBuffer}. - */ - public static CodedOutputStream newInstance(ByteBuffer byteBuffer) { - if (byteBuffer.hasArray()) { - return new NioHeapEncoder(byteBuffer); + /** Create a new {@code CodedOutputStream} that writes to the given {@link ByteBuffer}. */ + public static CodedOutputStream newInstance(ByteBuffer buffer) { + if (buffer.hasArray()) { + return new HeapNioEncoder(buffer); } - return new NioEncoder(byteBuffer); + if (buffer.isDirect() && !buffer.isReadOnly()) { + return UnsafeDirectNioEncoder.isSupported() + ? newUnsafeInstance(buffer) + : newSafeInstance(buffer); + } + throw new IllegalArgumentException("ByteBuffer is read-only"); + } + + /** For testing purposes only. */ + static CodedOutputStream newUnsafeInstance(ByteBuffer buffer) { + return new UnsafeDirectNioEncoder(buffer); + } + + /** For testing purposes only. */ + static CodedOutputStream newSafeInstance(ByteBuffer buffer) { + return new SafeDirectNioEncoder(buffer); } /** @@ -173,7 +183,7 @@ public abstract class CodedOutputStream extends ByteOutput { * maps are sorted on the lexicographical order of the UTF8 encoded keys. * </ul> */ - void useDeterministicSerialization() { + public void useDeterministicSerialization() { serializationDeterministic = true; } @@ -367,6 +377,7 @@ public abstract class CodedOutputStream extends ByteOutput { public abstract void writeMessage(final int fieldNumber, final MessageLite value) throws IOException; + /** * Write a MessageSet extension field to the stream. For historical reasons, * the wire format differs from normal fields. @@ -471,6 +482,7 @@ public abstract class CodedOutputStream extends ByteOutput { // Abstract to avoid overhead of additional virtual method calls. public abstract void writeMessageNoTag(final MessageLite value) throws IOException; + //================================================================= @ExperimentalApi @@ -656,6 +668,7 @@ public abstract class CodedOutputStream extends ByteOutput { return computeTagSize(fieldNumber) + computeMessageSizeNoTag(value); } + /** * Compute the number of bytes that would be needed to encode a * MessageSet extension to the stream. For historical reasons, @@ -744,7 +757,7 @@ public abstract class CodedOutputStream extends ByteOutput { * {@code fixed32} field. */ public static int computeFixed32SizeNoTag(@SuppressWarnings("unused") final int unused) { - return FIXED_32_SIZE; + return FIXED32_SIZE; } /** @@ -752,7 +765,7 @@ public abstract class CodedOutputStream extends ByteOutput { * {@code sfixed32} field. */ public static int computeSFixed32SizeNoTag(@SuppressWarnings("unused") final int unused) { - return FIXED_32_SIZE; + return FIXED32_SIZE; } /** @@ -802,7 +815,7 @@ public abstract class CodedOutputStream extends ByteOutput { * {@code fixed64} field. */ public static int computeFixed64SizeNoTag(@SuppressWarnings("unused") final long unused) { - return FIXED_64_SIZE; + return FIXED64_SIZE; } /** @@ -810,7 +823,7 @@ public abstract class CodedOutputStream extends ByteOutput { * {@code sfixed64} field. */ public static int computeSFixed64SizeNoTag(@SuppressWarnings("unused") final long unused) { - return FIXED_64_SIZE; + return FIXED64_SIZE; } /** @@ -818,7 +831,7 @@ public abstract class CodedOutputStream extends ByteOutput { * {@code float} field, including tag. */ public static int computeFloatSizeNoTag(@SuppressWarnings("unused") final float unused) { - return FIXED_32_SIZE; + return FIXED32_SIZE; } /** @@ -826,7 +839,7 @@ public abstract class CodedOutputStream extends ByteOutput { * {@code double} field, including tag. */ public static int computeDoubleSizeNoTag(@SuppressWarnings("unused") final double unused) { - return FIXED_64_SIZE; + return FIXED64_SIZE; } /** @@ -903,6 +916,7 @@ public abstract class CodedOutputStream extends ByteOutput { return computeLengthDelimitedFieldSize(value.getSerializedSize()); } + static int computeLengthDelimitedFieldSize(int fieldLength) { return computeUInt32SizeNoTag(fieldLength) + fieldLength; } @@ -979,6 +993,10 @@ public abstract class CodedOutputStream extends ByteOutput { super(MESSAGE); } + OutOfSpaceException(String explanationMessage) { + super(MESSAGE + ": " + explanationMessage); + } + OutOfSpaceException(Throwable cause) { super(MESSAGE, cause); } @@ -1035,6 +1053,7 @@ public abstract class CodedOutputStream extends ByteOutput { writeTag(fieldNumber, WireFormat.WIRETYPE_END_GROUP); } + /** * Write a {@code group} field to the stream. * @@ -1045,6 +1064,7 @@ public abstract class CodedOutputStream extends ByteOutput { value.writeTo(this); } + /** * Compute the number of bytes that would be needed to encode a * {@code group} field, including tag. @@ -1056,6 +1076,7 @@ public abstract class CodedOutputStream extends ByteOutput { return computeTagSize(fieldNumber) * 2 + computeGroupSizeNoTag(value); } + /** * Compute the number of bytes that would be needed to encode a * {@code group} field. @@ -1065,6 +1086,7 @@ public abstract class CodedOutputStream extends ByteOutput { return value.getSerializedSize(); } + /** * Encode and write a varint. {@code value} is treated as * unsigned, so it won't be sign-extended if negative. @@ -1259,6 +1281,7 @@ public abstract class CodedOutputStream extends ByteOutput { writeMessageNoTag(value); } + @Override public final void writeMessageSetExtension(final int fieldNumber, final MessageLite value) throws IOException { @@ -1283,6 +1306,7 @@ public abstract class CodedOutputStream extends ByteOutput { value.writeTo(this); } + @Override public final void write(byte value) throws IOException { try { @@ -1306,15 +1330,12 @@ public abstract class CodedOutputStream extends ByteOutput { @Override public final void writeUInt32NoTag(int value) throws IOException { if (HAS_UNSAFE_ARRAY_OPERATIONS && spaceLeft() >= MAX_VARINT_SIZE) { - long pos = ARRAY_BASE_OFFSET + position; while (true) { if ((value & ~0x7F) == 0) { - UnsafeUtil.putByte(buffer, pos++, (byte) value); - position++; + UnsafeUtil.putByte(buffer, position++, (byte) value); return; } else { - UnsafeUtil.putByte(buffer, pos++, (byte) ((value & 0x7F) | 0x80)); - position++; + UnsafeUtil.putByte(buffer, position++, (byte) ((value & 0x7F) | 0x80)); value >>>= 7; } } @@ -1352,15 +1373,12 @@ public abstract class CodedOutputStream extends ByteOutput { @Override public final void writeUInt64NoTag(long value) throws IOException { if (HAS_UNSAFE_ARRAY_OPERATIONS && spaceLeft() >= MAX_VARINT_SIZE) { - long pos = ARRAY_BASE_OFFSET + position; while (true) { if ((value & ~0x7FL) == 0) { - UnsafeUtil.putByte(buffer, pos++, (byte) value); - position++; + UnsafeUtil.putByte(buffer, position++, (byte) value); return; } else { - UnsafeUtil.putByte(buffer, pos++, (byte) (((int) value & 0x7F) | 0x80)); - position++; + UnsafeUtil.putByte(buffer, position++, (byte) (((int) value & 0x7F) | 0x80)); value >>>= 7; } } @@ -1486,11 +1504,11 @@ public abstract class CodedOutputStream extends ByteOutput { * A {@link CodedOutputStream} that writes directly to a heap {@link ByteBuffer}. Writes are * done directly to the underlying array. The buffer position is only updated after a flush. */ - private static final class NioHeapEncoder extends ArrayEncoder { + private static final class HeapNioEncoder extends ArrayEncoder { private final ByteBuffer byteBuffer; private int initialPosition; - NioHeapEncoder(ByteBuffer byteBuffer) { + HeapNioEncoder(ByteBuffer byteBuffer) { super(byteBuffer.array(), byteBuffer.arrayOffset() + byteBuffer.position(), byteBuffer.remaining()); this.byteBuffer = byteBuffer; @@ -1505,14 +1523,15 @@ public abstract class CodedOutputStream extends ByteOutput { } /** - * A {@link CodedOutputStream} that writes directly to a {@link ByteBuffer}. + * A {@link CodedOutputStream} that writes directly to a direct {@link ByteBuffer}, using only + * safe operations.. */ - private static final class NioEncoder extends CodedOutputStream { + private static final class SafeDirectNioEncoder extends CodedOutputStream { private final ByteBuffer originalBuffer; private final ByteBuffer buffer; private final int initialPosition; - NioEncoder(ByteBuffer buffer) { + SafeDirectNioEncoder(ByteBuffer buffer) { this.originalBuffer = buffer; this.buffer = buffer.duplicate().order(ByteOrder.LITTLE_ENDIAN); initialPosition = buffer.position(); @@ -1599,6 +1618,7 @@ public abstract class CodedOutputStream extends ByteOutput { writeMessageNoTag(value); } + @Override public void writeMessageSetExtension(final int fieldNumber, final MessageLite value) throws IOException { @@ -1623,6 +1643,7 @@ public abstract class CodedOutputStream extends ByteOutput { value.writeTo(this); } + @Override public void write(byte value) throws IOException { try { @@ -1815,6 +1836,357 @@ public abstract class CodedOutputStream extends ByteOutput { } /** + * A {@link CodedOutputStream} that writes directly to a direct {@link ByteBuffer} using {@code + * sun.misc.Unsafe}. + */ + private static final class UnsafeDirectNioEncoder extends CodedOutputStream { + private final ByteBuffer originalBuffer; + private final ByteBuffer buffer; + private final long address; + private final long initialPosition; + private final long limit; + private final long oneVarintLimit; + private long position; + + UnsafeDirectNioEncoder(ByteBuffer buffer) { + this.originalBuffer = buffer; + this.buffer = buffer.duplicate().order(ByteOrder.LITTLE_ENDIAN); + address = UnsafeUtil.addressOffset(buffer); + initialPosition = address + buffer.position(); + limit = address + buffer.limit(); + oneVarintLimit = limit - MAX_VARINT_SIZE; + position = initialPosition; + } + + static boolean isSupported() { + return UnsafeUtil.hasUnsafeByteBufferOperations(); + } + + @Override + public void writeTag(int fieldNumber, int wireType) throws IOException { + writeUInt32NoTag(WireFormat.makeTag(fieldNumber, wireType)); + } + + @Override + public void writeInt32(int fieldNumber, int value) throws IOException { + writeTag(fieldNumber, WireFormat.WIRETYPE_VARINT); + writeInt32NoTag(value); + } + + @Override + public void writeUInt32(int fieldNumber, int value) throws IOException { + writeTag(fieldNumber, WireFormat.WIRETYPE_VARINT); + writeUInt32NoTag(value); + } + + @Override + public void writeFixed32(int fieldNumber, int value) throws IOException { + writeTag(fieldNumber, WireFormat.WIRETYPE_FIXED32); + writeFixed32NoTag(value); + } + + @Override + public void writeUInt64(int fieldNumber, long value) throws IOException { + writeTag(fieldNumber, WireFormat.WIRETYPE_VARINT); + writeUInt64NoTag(value); + } + + @Override + public void writeFixed64(int fieldNumber, long value) throws IOException { + writeTag(fieldNumber, WireFormat.WIRETYPE_FIXED64); + writeFixed64NoTag(value); + } + + @Override + public void writeBool(int fieldNumber, boolean value) throws IOException { + writeTag(fieldNumber, WireFormat.WIRETYPE_VARINT); + write((byte) (value ? 1 : 0)); + } + + @Override + public void writeString(int fieldNumber, String value) throws IOException { + writeTag(fieldNumber, WireFormat.WIRETYPE_LENGTH_DELIMITED); + writeStringNoTag(value); + } + + @Override + public void writeBytes(int fieldNumber, ByteString value) throws IOException { + writeTag(fieldNumber, WireFormat.WIRETYPE_LENGTH_DELIMITED); + writeBytesNoTag(value); + } + + @Override + public void writeByteArray(int fieldNumber, byte[] value) throws IOException { + writeByteArray(fieldNumber, value, 0, value.length); + } + + @Override + public void writeByteArray(int fieldNumber, byte[] value, int offset, int length) + throws IOException { + writeTag(fieldNumber, WireFormat.WIRETYPE_LENGTH_DELIMITED); + writeByteArrayNoTag(value, offset, length); + } + + @Override + public void writeByteBuffer(int fieldNumber, ByteBuffer value) throws IOException { + writeTag(fieldNumber, WireFormat.WIRETYPE_LENGTH_DELIMITED); + writeUInt32NoTag(value.capacity()); + writeRawBytes(value); + } + + @Override + public void writeMessage(int fieldNumber, MessageLite value) throws IOException { + writeTag(fieldNumber, WireFormat.WIRETYPE_LENGTH_DELIMITED); + writeMessageNoTag(value); + } + + + @Override + public void writeMessageSetExtension(int fieldNumber, MessageLite value) throws IOException { + writeTag(WireFormat.MESSAGE_SET_ITEM, WireFormat.WIRETYPE_START_GROUP); + writeUInt32(WireFormat.MESSAGE_SET_TYPE_ID, fieldNumber); + writeMessage(WireFormat.MESSAGE_SET_MESSAGE, value); + writeTag(WireFormat.MESSAGE_SET_ITEM, WireFormat.WIRETYPE_END_GROUP); + } + + @Override + public void writeRawMessageSetExtension(int fieldNumber, ByteString value) throws IOException { + writeTag(WireFormat.MESSAGE_SET_ITEM, WireFormat.WIRETYPE_START_GROUP); + writeUInt32(WireFormat.MESSAGE_SET_TYPE_ID, fieldNumber); + writeBytes(WireFormat.MESSAGE_SET_MESSAGE, value); + writeTag(WireFormat.MESSAGE_SET_ITEM, WireFormat.WIRETYPE_END_GROUP); + } + + @Override + public void writeMessageNoTag(MessageLite value) throws IOException { + writeUInt32NoTag(value.getSerializedSize()); + value.writeTo(this); + } + + + @Override + public void write(byte value) throws IOException { + if (position >= limit) { + throw new OutOfSpaceException( + String.format("Pos: %d, limit: %d, len: %d", position, limit, 1)); + } + UnsafeUtil.putByte(position++, value); + } + + @Override + public void writeBytesNoTag(ByteString value) throws IOException { + writeUInt32NoTag(value.size()); + value.writeTo(this); + } + + @Override + public void writeByteArrayNoTag(byte[] value, int offset, int length) throws IOException { + writeUInt32NoTag(length); + write(value, offset, length); + } + + @Override + public void writeRawBytes(ByteBuffer value) throws IOException { + if (value.hasArray()) { + write(value.array(), value.arrayOffset(), value.capacity()); + } else { + ByteBuffer duplicated = value.duplicate(); + duplicated.clear(); + write(duplicated); + } + } + + @Override + public void writeInt32NoTag(int value) throws IOException { + if (value >= 0) { + writeUInt32NoTag(value); + } else { + // Must sign-extend. + writeUInt64NoTag(value); + } + } + + @Override + public void writeUInt32NoTag(int value) throws IOException { + if (position <= oneVarintLimit) { + // Optimization to avoid bounds checks on each iteration. + while (true) { + if ((value & ~0x7F) == 0) { + UnsafeUtil.putByte(position++, (byte) value); + return; + } else { + UnsafeUtil.putByte(position++, (byte) ((value & 0x7F) | 0x80)); + value >>>= 7; + } + } + } else { + while (position < limit) { + if ((value & ~0x7F) == 0) { + UnsafeUtil.putByte(position++, (byte) value); + return; + } else { + UnsafeUtil.putByte(position++, (byte) ((value & 0x7F) | 0x80)); + value >>>= 7; + } + } + throw new OutOfSpaceException( + String.format("Pos: %d, limit: %d, len: %d", position, limit, 1)); + } + } + + @Override + public void writeFixed32NoTag(int value) throws IOException { + buffer.putInt(bufferPos(position), value); + position += FIXED32_SIZE; + } + + @Override + public void writeUInt64NoTag(long value) throws IOException { + if (position <= oneVarintLimit) { + // Optimization to avoid bounds checks on each iteration. + while (true) { + if ((value & ~0x7FL) == 0) { + UnsafeUtil.putByte(position++, (byte) value); + return; + } else { + UnsafeUtil.putByte(position++, (byte) (((int) value & 0x7F) | 0x80)); + value >>>= 7; + } + } + } else { + while (position < limit) { + if ((value & ~0x7FL) == 0) { + UnsafeUtil.putByte(position++, (byte) value); + return; + } else { + UnsafeUtil.putByte(position++, (byte) (((int) value & 0x7F) | 0x80)); + value >>>= 7; + } + } + throw new OutOfSpaceException( + String.format("Pos: %d, limit: %d, len: %d", position, limit, 1)); + } + } + + @Override + public void writeFixed64NoTag(long value) throws IOException { + buffer.putLong(bufferPos(position), value); + position += FIXED64_SIZE; + } + + @Override + public void write(byte[] value, int offset, int length) throws IOException { + if (value == null + || offset < 0 + || length < 0 + || (value.length - length) < offset + || (limit - length) < position) { + if (value == null) { + throw new NullPointerException("value"); + } + throw new OutOfSpaceException( + String.format("Pos: %d, limit: %d, len: %d", position, limit, length)); + } + + UnsafeUtil.copyMemory(value, offset, position, length); + position += length; + } + + @Override + public void writeLazy(byte[] value, int offset, int length) throws IOException { + write(value, offset, length); + } + + @Override + public void write(ByteBuffer value) throws IOException { + try { + int length = value.remaining(); + repositionBuffer(position); + buffer.put(value); + position += length; + } catch (BufferOverflowException e) { + throw new OutOfSpaceException(e); + } + } + + @Override + public void writeLazy(ByteBuffer value) throws IOException { + write(value); + } + + @Override + public void writeStringNoTag(String value) throws IOException { + long prevPos = position; + try { + // UTF-8 byte length of the string is at least its UTF-16 code unit length (value.length()), + // and at most 3 times of it. We take advantage of this in both branches below. + int maxEncodedSize = value.length() * Utf8.MAX_BYTES_PER_CHAR; + int maxLengthVarIntSize = computeUInt32SizeNoTag(maxEncodedSize); + int minLengthVarIntSize = computeUInt32SizeNoTag(value.length()); + if (minLengthVarIntSize == maxLengthVarIntSize) { + // Save the current position and increment past the length field. We'll come back + // and write the length field after the encoding is complete. + int stringStart = bufferPos(position) + minLengthVarIntSize; + buffer.position(stringStart); + + // Encode the string. + Utf8.encodeUtf8(value, buffer); + + // Write the length and advance the position. + int length = buffer.position() - stringStart; + writeUInt32NoTag(length); + position += length; + } else { + // Calculate and write the encoded length. + int length = Utf8.encodedLength(value); + writeUInt32NoTag(length); + + // Write the string and advance the position. + repositionBuffer(position); + Utf8.encodeUtf8(value, buffer); + position += length; + } + } catch (UnpairedSurrogateException e) { + // Roll back the change and convert to an IOException. + position = prevPos; + repositionBuffer(position); + + // TODO(nathanmittler): We should throw an IOException here instead. + inefficientWriteStringNoTag(value, e); + } catch (IllegalArgumentException e) { + // Thrown by buffer.position() if out of range. + throw new OutOfSpaceException(e); + } catch (IndexOutOfBoundsException e) { + throw new OutOfSpaceException(e); + } + } + + @Override + public void flush() { + // Update the position of the original buffer. + originalBuffer.position(bufferPos(position)); + } + + @Override + public int spaceLeft() { + return (int) (limit - position); + } + + @Override + public int getTotalBytesWritten() { + return (int) (position - initialPosition); + } + + private void repositionBuffer(long pos) { + buffer.position(bufferPos(pos)); + } + + private int bufferPos(long pos) { + return (int) (pos - address); + } + } + + /** * Abstract base class for buffered encoders. */ private abstract static class AbstractBufferedEncoder extends CodedOutputStream { @@ -1883,19 +2255,17 @@ public abstract class CodedOutputStream extends ByteOutput { */ final void bufferUInt32NoTag(int value) { if (HAS_UNSAFE_ARRAY_OPERATIONS) { - final long originalPos = ARRAY_BASE_OFFSET + position; - long pos = originalPos; + final long originalPos = position; while (true) { if ((value & ~0x7F) == 0) { - UnsafeUtil.putByte(buffer, pos++, (byte) value); + UnsafeUtil.putByte(buffer, position++, (byte) value); break; } else { - UnsafeUtil.putByte(buffer, pos++, (byte) ((value & 0x7F) | 0x80)); + UnsafeUtil.putByte(buffer, position++, (byte) ((value & 0x7F) | 0x80)); value >>>= 7; } } - int delta = (int) (pos - originalPos); - position += delta; + int delta = (int) (position - originalPos); totalBytesWritten += delta; } else { while (true) { @@ -1918,19 +2288,17 @@ public abstract class CodedOutputStream extends ByteOutput { */ final void bufferUInt64NoTag(long value) { if (HAS_UNSAFE_ARRAY_OPERATIONS) { - final long originalPos = ARRAY_BASE_OFFSET + position; - long pos = originalPos; + final long originalPos = position; while (true) { if ((value & ~0x7FL) == 0) { - UnsafeUtil.putByte(buffer, pos++, (byte) value); + UnsafeUtil.putByte(buffer, position++, (byte) value); break; } else { - UnsafeUtil.putByte(buffer, pos++, (byte) (((int) value & 0x7F) | 0x80)); + UnsafeUtil.putByte(buffer, position++, (byte) (((int) value & 0x7F) | 0x80)); value >>>= 7; } } - int delta = (int) (pos - originalPos); - position += delta; + int delta = (int) (position - originalPos); totalBytesWritten += delta; } else { while (true) { @@ -1956,7 +2324,7 @@ public abstract class CodedOutputStream extends ByteOutput { buffer[position++] = (byte) ((value >> 8) & 0xFF); buffer[position++] = (byte) ((value >> 16) & 0xFF); buffer[position++] = (byte) ((value >> 24) & 0xFF); - totalBytesWritten += FIXED_32_SIZE; + totalBytesWritten += FIXED32_SIZE; } /** @@ -1972,7 +2340,7 @@ public abstract class CodedOutputStream extends ByteOutput { buffer[position++] = (byte) ((int) (value >> 40) & 0xFF); buffer[position++] = (byte) ((int) (value >> 48) & 0xFF); buffer[position++] = (byte) ((int) (value >> 56) & 0xFF); - totalBytesWritten += FIXED_64_SIZE; + totalBytesWritten += FIXED64_SIZE; } } @@ -2013,7 +2381,7 @@ public abstract class CodedOutputStream extends ByteOutput { @Override public void writeFixed32(final int fieldNumber, final int value) throws IOException { - flushIfNotAvailable(MAX_VARINT_SIZE + FIXED_32_SIZE); + flushIfNotAvailable(MAX_VARINT_SIZE + FIXED32_SIZE); bufferTag(fieldNumber, WireFormat.WIRETYPE_FIXED32); bufferFixed32NoTag(value); } @@ -2027,7 +2395,7 @@ public abstract class CodedOutputStream extends ByteOutput { @Override public void writeFixed64(final int fieldNumber, final long value) throws IOException { - flushIfNotAvailable(MAX_VARINT_SIZE + FIXED_64_SIZE); + flushIfNotAvailable(MAX_VARINT_SIZE + FIXED64_SIZE); bufferTag(fieldNumber, WireFormat.WIRETYPE_FIXED64); bufferFixed64NoTag(value); } @@ -2102,6 +2470,7 @@ public abstract class CodedOutputStream extends ByteOutput { writeMessageNoTag(value); } + @Override public void writeMessageSetExtension(final int fieldNumber, final MessageLite value) throws IOException { @@ -2126,6 +2495,7 @@ public abstract class CodedOutputStream extends ByteOutput { value.writeTo(this); } + @Override public void write(byte value) throws IOException { if (position == limit) { @@ -2153,7 +2523,7 @@ public abstract class CodedOutputStream extends ByteOutput { @Override public void writeFixed32NoTag(final int value) throws IOException { - flushIfNotAvailable(FIXED_32_SIZE); + flushIfNotAvailable(FIXED32_SIZE); bufferFixed32NoTag(value); } @@ -2165,7 +2535,7 @@ public abstract class CodedOutputStream extends ByteOutput { @Override public void writeFixed64NoTag(final long value) throws IOException { - flushIfNotAvailable(FIXED_64_SIZE); + flushIfNotAvailable(FIXED64_SIZE); bufferFixed64NoTag(value); } @@ -2316,7 +2686,7 @@ public abstract class CodedOutputStream extends ByteOutput { @Override public void writeFixed32(final int fieldNumber, final int value) throws IOException { - flushIfNotAvailable(MAX_VARINT_SIZE + FIXED_32_SIZE); + flushIfNotAvailable(MAX_VARINT_SIZE + FIXED32_SIZE); bufferTag(fieldNumber, WireFormat.WIRETYPE_FIXED32); bufferFixed32NoTag(value); } @@ -2330,7 +2700,7 @@ public abstract class CodedOutputStream extends ByteOutput { @Override public void writeFixed64(final int fieldNumber, final long value) throws IOException { - flushIfNotAvailable(MAX_VARINT_SIZE + FIXED_64_SIZE); + flushIfNotAvailable(MAX_VARINT_SIZE + FIXED64_SIZE); bufferTag(fieldNumber, WireFormat.WIRETYPE_FIXED64); bufferFixed64NoTag(value); } @@ -2405,6 +2775,7 @@ public abstract class CodedOutputStream extends ByteOutput { writeMessageNoTag(value); } + @Override public void writeMessageSetExtension(final int fieldNumber, final MessageLite value) throws IOException { @@ -2429,6 +2800,7 @@ public abstract class CodedOutputStream extends ByteOutput { value.writeTo(this); } + @Override public void write(byte value) throws IOException { if (position == limit) { @@ -2456,7 +2828,7 @@ public abstract class CodedOutputStream extends ByteOutput { @Override public void writeFixed32NoTag(final int value) throws IOException { - flushIfNotAvailable(FIXED_32_SIZE); + flushIfNotAvailable(FIXED32_SIZE); bufferFixed32NoTag(value); } @@ -2468,7 +2840,7 @@ public abstract class CodedOutputStream extends ByteOutput { @Override public void writeFixed64NoTag(final long value) throws IOException { - flushIfNotAvailable(FIXED_64_SIZE); + flushIfNotAvailable(FIXED64_SIZE); bufferFixed64NoTag(value); } diff --git a/java/core/src/main/java/com/google/protobuf/Descriptors.java b/java/core/src/main/java/com/google/protobuf/Descriptors.java index 1c34c24f..75b16fe3 100644 --- a/java/core/src/main/java/com/google/protobuf/Descriptors.java +++ b/java/core/src/main/java/com/google/protobuf/Descriptors.java @@ -30,9 +30,10 @@ package com.google.protobuf; +import static com.google.protobuf.Internal.checkNotNull; + import com.google.protobuf.DescriptorProtos.*; import com.google.protobuf.Descriptors.FileDescriptor.Syntax; - import java.lang.ref.WeakReference; import java.util.ArrayList; import java.util.Arrays; @@ -682,9 +683,7 @@ public final class Descriptors { /** Determines if the given field name is reserved. */ public boolean isReservedName(final String name) { - if (name == null) { - throw new NullPointerException(); - } + checkNotNull(name); for (final String reservedName : proto.getReservedNameList()) { if (reservedName.equals(name)) { return true; @@ -1212,33 +1211,20 @@ public final class Descriptors { private final Object defaultDefault; } - // TODO(xiaofeng): Implement it consistently across different languages. See b/24751348. - private static String fieldNameToLowerCamelCase(String name) { + // This method should match exactly with the ToJsonName() function in C++ + // descriptor.cc. + private static String fieldNameToJsonName(String name) { StringBuilder result = new StringBuilder(name.length()); boolean isNextUpperCase = false; for (int i = 0; i < name.length(); i++) { Character ch = name.charAt(i); - if (Character.isLowerCase(ch)) { - if (isNextUpperCase) { - result.append(Character.toUpperCase(ch)); - } else { - result.append(ch); - } - isNextUpperCase = false; - } else if (Character.isUpperCase(ch)) { - if (i == 0) { - // Force first letter to lower-case. - result.append(Character.toLowerCase(ch)); - } else { - // Capital letters after the first are left as-is. - result.append(ch); - } - isNextUpperCase = false; - } else if (Character.isDigit(ch)) { - result.append(ch); + if (ch == '_') { + isNextUpperCase = true; + } else if (isNextUpperCase) { + result.append(Character.toUpperCase(ch)); isNextUpperCase = false; } else { - isNextUpperCase = true; + result.append(ch); } } return result.toString(); @@ -1257,7 +1243,7 @@ public final class Descriptors { if (proto.hasJsonName()) { jsonName = proto.getJsonName(); } else { - jsonName = fieldNameToLowerCamelCase(proto.getName()); + jsonName = fieldNameToJsonName(proto.getName()); } if (proto.hasType()) { @@ -2136,7 +2122,7 @@ public final class Descriptors { // Can't happen, because addPackage() only fails when the name // conflicts with a non-package, but we have not yet added any // non-packages at this point. - assert false; + throw new AssertionError(e); } } } diff --git a/java/core/src/main/java/com/google/protobuf/DiscardUnknownFieldsParser.java b/java/core/src/main/java/com/google/protobuf/DiscardUnknownFieldsParser.java new file mode 100644 index 00000000..7ae94349 --- /dev/null +++ b/java/core/src/main/java/com/google/protobuf/DiscardUnknownFieldsParser.java @@ -0,0 +1,71 @@ +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// https://developers.google.com/protocol-buffers/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +package com.google.protobuf; + +/** + * Parsers to discard unknown fields during parsing. + */ +public final class DiscardUnknownFieldsParser { + + /** + * Warps a given {@link Parser} into a new {@link Parser} that discards unknown fields during + * parsing. + * + * <p>Usage example: + * <pre>{@code + * private final static Parser<Foo> FOO_PARSER = DiscardUnknownFieldsParser.wrap(Foo.parser()); + * Foo parseFooDiscardUnknown(ByteBuffer input) throws IOException { + * return FOO_PARSER.parseFrom(input); + * } + * }</pre> + * + * <p>Like all other implementations of {@code Parser}, this parser is stateless and thread-safe. + * + * @param parser The delegated parser that parses messages. + * @return a {@link Parser} that will discard unknown fields during parsing. + */ + public static final <T extends Message> Parser<T> wrap(final Parser<T> parser) { + return new AbstractParser<T>() { + @Override + public T parsePartialFrom(CodedInputStream input, ExtensionRegistryLite extensionRegistry) + throws InvalidProtocolBufferException { + try { + input.discardUnknownFields(); + return parser.parsePartialFrom(input, extensionRegistry); + } finally { + input.unsetDiscardUnknownFields(); + } + } + }; + } + + private DiscardUnknownFieldsParser() {} +} diff --git a/java/core/src/main/java/com/google/protobuf/DoubleArrayList.java b/java/core/src/main/java/com/google/protobuf/DoubleArrayList.java index 6177f3ca..5b28b4a8 100644 --- a/java/core/src/main/java/com/google/protobuf/DoubleArrayList.java +++ b/java/core/src/main/java/com/google/protobuf/DoubleArrayList.java @@ -30,8 +30,9 @@ package com.google.protobuf; -import com.google.protobuf.Internal.DoubleList; +import static com.google.protobuf.Internal.checkNotNull; +import com.google.protobuf.Internal.DoubleList; import java.util.Arrays; import java.util.Collection; import java.util.RandomAccess; @@ -41,9 +42,8 @@ import java.util.RandomAccess; * * @author dweis@google.com (Daniel Weis) */ -final class DoubleArrayList - extends AbstractProtobufList<Double> - implements DoubleList, RandomAccess { +final class DoubleArrayList extends AbstractProtobufList<Double> + implements DoubleList, RandomAccess, PrimitiveNonBoxingCollection { private static final DoubleArrayList EMPTY_LIST = new DoubleArrayList(); static { @@ -82,6 +82,18 @@ final class DoubleArrayList } @Override + protected void removeRange(int fromIndex, int toIndex) { + ensureIsMutable(); + if (toIndex < fromIndex) { + throw new IndexOutOfBoundsException("toIndex < fromIndex"); + } + + System.arraycopy(array, toIndex, array, fromIndex, size - toIndex); + size -= (toIndex - fromIndex); + modCount++; + } + + @Override public boolean equals(Object o) { if (this == o) { return true; @@ -199,9 +211,7 @@ final class DoubleArrayList public boolean addAll(Collection<? extends Double> collection) { ensureIsMutable(); - if (collection == null) { - throw new NullPointerException(); - } + checkNotNull(collection); // We specialize when adding another DoubleArrayList to avoid boxing elements. if (!(collection instanceof DoubleArrayList)) { @@ -249,7 +259,9 @@ final class DoubleArrayList ensureIsMutable(); ensureIndexInRange(index); double value = array[index]; - System.arraycopy(array, index + 1, array, index, size - index); + if (index < size - 1) { + System.arraycopy(array, index + 1, array, index, size - index); + } size--; modCount++; return value; diff --git a/java/core/src/main/java/com/google/protobuf/DynamicMessage.java b/java/core/src/main/java/com/google/protobuf/DynamicMessage.java index c54da67f..a6a774b7 100644 --- a/java/core/src/main/java/com/google/protobuf/DynamicMessage.java +++ b/java/core/src/main/java/com/google/protobuf/DynamicMessage.java @@ -30,11 +30,12 @@ package com.google.protobuf; +import static com.google.protobuf.Internal.checkNotNull; + import com.google.protobuf.Descriptors.Descriptor; import com.google.protobuf.Descriptors.EnumValueDescriptor; import com.google.protobuf.Descriptors.FieldDescriptor; import com.google.protobuf.Descriptors.OneofDescriptor; - import java.io.IOException; import java.io.InputStream; import java.util.Collections; @@ -297,7 +298,7 @@ public final class DynamicMessage extends AbstractMessage { } catch (InvalidProtocolBufferException e) { throw e.setUnfinishedMessage(builder.buildPartial()); } catch (IOException e) { - throw new InvalidProtocolBufferException(e.getMessage()) + throw new InvalidProtocolBufferException(e) .setUnfinishedMessage(builder.buildPartial()); } return builder.buildPartial(); @@ -338,6 +339,20 @@ public final class DynamicMessage extends AbstractMessage { this.fields = FieldSet.newFieldSet(); this.unknownFields = UnknownFieldSet.getDefaultInstance(); this.oneofCases = new FieldDescriptor[type.toProto().getOneofDeclCount()]; + // A MapEntry has all of its fields present at all times. + if (type.getOptions().getMapEntry()) { + populateMapEntry(); + } + } + + private void populateMapEntry() { + for (FieldDescriptor field : type.getFields()) { + if (field.getJavaType() == FieldDescriptor.JavaType.MESSAGE) { + fields.setField(field, getDefaultInstance(field.getMessageType())); + } else { + fields.setField(field, field.getDefaultValue()); + } + } } // --------------------------------------------------------------- @@ -350,6 +365,10 @@ public final class DynamicMessage extends AbstractMessage { } else { fields.clear(); } + // A MapEntry has all of its fields present at all times. + if (type.getOptions().getMapEntry()) { + populateMapEntry(); + } unknownFields = UnknownFieldSet.getDefaultInstance(); return this; } @@ -589,9 +608,8 @@ public final class DynamicMessage extends AbstractMessage { @Override public Builder setUnknownFields(UnknownFieldSet unknownFields) { - if (getDescriptorForType().getFile().getSyntax() - == Descriptors.FileDescriptor.Syntax.PROTO3) { - // Proto3 discards unknown fields. + if (getDescriptorForType().getFile().getSyntax() == Descriptors.FileDescriptor.Syntax.PROTO3 + && CodedInputStream.getProto3DiscardUnknownFieldsDefault()) { return this; } this.unknownFields = unknownFields; @@ -600,9 +618,8 @@ public final class DynamicMessage extends AbstractMessage { @Override public Builder mergeUnknownFields(UnknownFieldSet unknownFields) { - if (getDescriptorForType().getFile().getSyntax() - == Descriptors.FileDescriptor.Syntax.PROTO3) { - // Proto3 discards unknown fields. + if (getDescriptorForType().getFile().getSyntax() == Descriptors.FileDescriptor.Syntax.PROTO3 + && CodedInputStream.getProto3DiscardUnknownFieldsDefault()) { return this; } this.unknownFields = @@ -631,9 +648,7 @@ public final class DynamicMessage extends AbstractMessage { /** Verifies that the value is EnumValueDescriptor and matches Enum Type. */ private void ensureSingularEnumValueDescriptor( FieldDescriptor field, Object value) { - if (value == null) { - throw new NullPointerException(); - } + checkNotNull(value); if (!(value instanceof EnumValueDescriptor)) { throw new IllegalArgumentException( "DynamicMessage should use EnumValueDescriptor to set Enum Value."); diff --git a/java/core/src/main/java/com/google/protobuf/Extension.java b/java/core/src/main/java/com/google/protobuf/Extension.java index 08ec5b45..5df12e64 100644 --- a/java/core/src/main/java/com/google/protobuf/Extension.java +++ b/java/core/src/main/java/com/google/protobuf/Extension.java @@ -58,10 +58,7 @@ public abstract class Extension<ContainingType extends MessageLite, Type> PROTO1, } - protected ExtensionType getExtensionType() { - // TODO(liujisi): make this abstract after we fix proto1. - return ExtensionType.IMMUTABLE; - } + protected abstract ExtensionType getExtensionType(); /** * Type of a message extension. @@ -70,7 +67,7 @@ public abstract class Extension<ContainingType extends MessageLite, Type> PROTO1, PROTO2, } - + /** * If the extension is a message extension (i.e., getLiteType() == MESSAGE), * returns the type of the message, otherwise undefined. diff --git a/java/core/src/main/java/com/google/protobuf/ExtensionRegistry.java b/java/core/src/main/java/com/google/protobuf/ExtensionRegistry.java index 1c2e7e6f..a22a74a0 100644 --- a/java/core/src/main/java/com/google/protobuf/ExtensionRegistry.java +++ b/java/core/src/main/java/com/google/protobuf/ExtensionRegistry.java @@ -32,7 +32,6 @@ package com.google.protobuf; import com.google.protobuf.Descriptors.Descriptor; import com.google.protobuf.Descriptors.FieldDescriptor; - import java.util.Collection; import java.util.Collections; import java.util.HashMap; diff --git a/java/core/src/main/java/com/google/protobuf/ExtensionRegistryFactory.java b/java/core/src/main/java/com/google/protobuf/ExtensionRegistryFactory.java index 23174e24..89f7ab9b 100644 --- a/java/core/src/main/java/com/google/protobuf/ExtensionRegistryFactory.java +++ b/java/core/src/main/java/com/google/protobuf/ExtensionRegistryFactory.java @@ -34,7 +34,7 @@ import static com.google.protobuf.ExtensionRegistryLite.EMPTY_REGISTRY_LITE; /** * A factory object to create instances of {@link ExtensionRegistryLite}. - * + * * <p> * This factory detects (via reflection) if the full (non-Lite) protocol buffer libraries * are available, and if so, the instances returned are actually {@link ExtensionRegistry}. @@ -82,6 +82,7 @@ final class ExtensionRegistryFactory { return EMPTY_REGISTRY_LITE; } + static boolean isFullRegistry(ExtensionRegistryLite registry) { return EXTENSION_REGISTRY_CLASS != null && EXTENSION_REGISTRY_CLASS.isAssignableFrom(registry.getClass()); @@ -90,6 +91,6 @@ final class ExtensionRegistryFactory { private static final ExtensionRegistryLite invokeSubclassFactory(String methodName) throws Exception { return (ExtensionRegistryLite) EXTENSION_REGISTRY_CLASS - .getMethod(methodName).invoke(null); + .getDeclaredMethod(methodName).invoke(null); } } diff --git a/java/core/src/main/java/com/google/protobuf/ExtensionRegistryLite.java b/java/core/src/main/java/com/google/protobuf/ExtensionRegistryLite.java index 5e4d7739..f3d48d3a 100644 --- a/java/core/src/main/java/com/google/protobuf/ExtensionRegistryLite.java +++ b/java/core/src/main/java/com/google/protobuf/ExtensionRegistryLite.java @@ -105,9 +105,9 @@ public class ExtensionRegistryLite { /** * Construct a new, empty instance. - * - * <p> - * This may be an {@code ExtensionRegistry} if the full (non-Lite) proto libraries are available. + * + * <p>This may be an {@code ExtensionRegistry} if the full (non-Lite) proto libraries are + * available. */ public static ExtensionRegistryLite newInstance() { return ExtensionRegistryFactory.create(); @@ -121,6 +121,7 @@ public class ExtensionRegistryLite { return ExtensionRegistryFactory.createEmpty(); } + /** Returns an unmodifiable view of the registry. */ public ExtensionRegistryLite getUnmodifiable() { return new ExtensionRegistryLite(this); diff --git a/java/core/src/main/java/com/google/protobuf/FieldSet.java b/java/core/src/main/java/com/google/protobuf/FieldSet.java index 5b251743..c09daa32 100644 --- a/java/core/src/main/java/com/google/protobuf/FieldSet.java +++ b/java/core/src/main/java/com/google/protobuf/FieldSet.java @@ -30,8 +30,9 @@ package com.google.protobuf; -import com.google.protobuf.LazyField.LazyIterator; +import static com.google.protobuf.Internal.checkNotNull; +import com.google.protobuf.LazyField.LazyIterator; import java.io.IOException; import java.util.ArrayList; import java.util.Collections; @@ -101,6 +102,11 @@ final class FieldSet<FieldDescriptorType extends @SuppressWarnings("rawtypes") private static final FieldSet DEFAULT_INSTANCE = new FieldSet(true); + /** Returns {@code true} if empty, {@code false} otherwise. */ + boolean isEmpty() { + return fields.isEmpty(); + } + /** Make this FieldSet immutable from this point forward. */ @SuppressWarnings("unchecked") public void makeImmutable() { @@ -220,6 +226,7 @@ final class FieldSet<FieldDescriptorType extends return fields.entrySet().iterator(); } + /** * Useful for implementing * {@link Message#hasField(Descriptors.FieldDescriptor)}. @@ -384,9 +391,7 @@ final class FieldSet<FieldDescriptorType extends */ private static void verifyType(final WireFormat.FieldType type, final Object value) { - if (value == null) { - throw new NullPointerException(); - } + checkNotNull(value); boolean isValid = false; switch (type.getJavaType()) { diff --git a/java/core/src/main/java/com/google/protobuf/FloatArrayList.java b/java/core/src/main/java/com/google/protobuf/FloatArrayList.java index 90d6154b..7c080af3 100644 --- a/java/core/src/main/java/com/google/protobuf/FloatArrayList.java +++ b/java/core/src/main/java/com/google/protobuf/FloatArrayList.java @@ -30,8 +30,9 @@ package com.google.protobuf; -import com.google.protobuf.Internal.FloatList; +import static com.google.protobuf.Internal.checkNotNull; +import com.google.protobuf.Internal.FloatList; import java.util.Arrays; import java.util.Collection; import java.util.RandomAccess; @@ -41,9 +42,8 @@ import java.util.RandomAccess; * * @author dweis@google.com (Daniel Weis) */ -final class FloatArrayList - extends AbstractProtobufList<Float> - implements FloatList, RandomAccess { +final class FloatArrayList extends AbstractProtobufList<Float> + implements FloatList, RandomAccess, PrimitiveNonBoxingCollection { private static final FloatArrayList EMPTY_LIST = new FloatArrayList(); static { @@ -82,6 +82,18 @@ final class FloatArrayList } @Override + protected void removeRange(int fromIndex, int toIndex) { + ensureIsMutable(); + if (toIndex < fromIndex) { + throw new IndexOutOfBoundsException("toIndex < fromIndex"); + } + + System.arraycopy(array, toIndex, array, fromIndex, size - toIndex); + size -= (toIndex - fromIndex); + modCount++; + } + + @Override public boolean equals(Object o) { if (this == o) { return true; @@ -198,9 +210,7 @@ final class FloatArrayList public boolean addAll(Collection<? extends Float> collection) { ensureIsMutable(); - if (collection == null) { - throw new NullPointerException(); - } + checkNotNull(collection); // We specialize when adding another FloatArrayList to avoid boxing elements. if (!(collection instanceof FloatArrayList)) { @@ -248,7 +258,9 @@ final class FloatArrayList ensureIsMutable(); ensureIndexInRange(index); float value = array[index]; - System.arraycopy(array, index + 1, array, index, size - index); + if (index < size - 1) { + System.arraycopy(array, index + 1, array, index, size - index); + } size--; modCount++; return value; diff --git a/java/core/src/main/java/com/google/protobuf/GeneratedMessage.java b/java/core/src/main/java/com/google/protobuf/GeneratedMessage.java index cea05794..60179e37 100644 --- a/java/core/src/main/java/com/google/protobuf/GeneratedMessage.java +++ b/java/core/src/main/java/com/google/protobuf/GeneratedMessage.java @@ -854,6 +854,7 @@ public abstract class GeneratedMessage extends AbstractMessage /** Check if a singular extension is present. */ @Override + @SuppressWarnings("unchecked") public final <Type> boolean hasExtension(final ExtensionLite<MessageType, Type> extensionLite) { Extension<MessageType, Type> extension = checkNotLite(extensionLite); @@ -863,6 +864,7 @@ public abstract class GeneratedMessage extends AbstractMessage /** Get the number of elements in a repeated extension. */ @Override + @SuppressWarnings("unchecked") public final <Type> int getExtensionCount( final ExtensionLite<MessageType, List<Type>> extensionLite) { Extension<MessageType, List<Type>> extension = checkNotLite(extensionLite); @@ -2555,6 +2557,7 @@ public abstract class GeneratedMessage extends AbstractMessage } @Override + @SuppressWarnings("unchecked") public Object get(GeneratedMessage message) { List result = new ArrayList(); for (int i = 0; i < getRepeatedCount(message); i++) { @@ -2564,6 +2567,7 @@ public abstract class GeneratedMessage extends AbstractMessage } @Override + @SuppressWarnings("unchecked") public Object get(Builder builder) { List result = new ArrayList(); for (int i = 0; i < getRepeatedCount(builder); i++) { diff --git a/java/core/src/main/java/com/google/protobuf/GeneratedMessageLite.java b/java/core/src/main/java/com/google/protobuf/GeneratedMessageLite.java index 214971b1..df01547e 100644 --- a/java/core/src/main/java/com/google/protobuf/GeneratedMessageLite.java +++ b/java/core/src/main/java/com/google/protobuf/GeneratedMessageLite.java @@ -31,26 +31,27 @@ package com.google.protobuf; import com.google.protobuf.AbstractMessageLite.Builder.LimitedInputStream; -import com.google.protobuf.GeneratedMessageLite.EqualsVisitor.NotEqualsException; import com.google.protobuf.Internal.BooleanList; import com.google.protobuf.Internal.DoubleList; +import com.google.protobuf.Internal.EnumLiteMap; import com.google.protobuf.Internal.FloatList; import com.google.protobuf.Internal.IntList; import com.google.protobuf.Internal.LongList; import com.google.protobuf.Internal.ProtobufList; import com.google.protobuf.WireFormat.FieldType; - import java.io.IOException; import java.io.InputStream; import java.io.ObjectStreamException; import java.io.Serializable; import java.lang.reflect.InvocationTargetException; import java.lang.reflect.Method; +import java.nio.ByteBuffer; import java.util.ArrayList; import java.util.Collections; import java.util.Iterator; import java.util.List; import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; /** * Lite version of {@link GeneratedMessage}. @@ -61,6 +62,12 @@ public abstract class GeneratedMessageLite< MessageType extends GeneratedMessageLite<MessageType, BuilderType>, BuilderType extends GeneratedMessageLite.Builder<MessageType, BuilderType>> extends AbstractMessageLite<MessageType, BuilderType> { + // BEGIN REGULAR + static final boolean ENABLE_EXPERIMENTAL_RUNTIME_AT_BUILD_TIME = false; + // END REGULAR + // BEGIN EXPERIMENTAL + // static final boolean ENABLE_EXPERIMENTAL_RUNTIME_AT_BUILD_TIME = true; + // END EXPERIMENTAL /** For use by generated code only. Lazily initialized to reduce allocations. */ protected UnknownFieldSetLite unknownFields = UnknownFieldSetLite.getDefaultInstance(); @@ -106,14 +113,22 @@ public abstract class GeneratedMessageLite< @SuppressWarnings("unchecked") // Guaranteed by runtime @Override public int hashCode() { - if (memoizedHashCode == 0) { - HashCodeVisitor visitor = new HashCodeVisitor(); - visit(visitor, (MessageType) this); - memoizedHashCode = visitor.hashCode; - } + if (memoizedHashCode != 0) { + return memoizedHashCode; + } + // BEGIN EXPERIMENTAL + // memoizedHashCode = Protobuf.getInstance().schemaFor(this).hashCode(this); + // return memoizedHashCode; + // END EXPERIMENTAL + // BEGIN REGULAR + HashCodeVisitor visitor = new HashCodeVisitor(); + visit(visitor, (MessageType) this); + memoizedHashCode = visitor.hashCode; return memoizedHashCode; + // END REGULAR } + // BEGIN REGULAR @SuppressWarnings("unchecked") // Guaranteed by runtime int hashCode(HashCodeVisitor visitor) { if (memoizedHashCode == 0) { @@ -125,6 +140,7 @@ public abstract class GeneratedMessageLite< } return memoizedHashCode; } + // END REGULAR @SuppressWarnings("unchecked") // Guaranteed by isInstance + runtime @Override @@ -137,17 +153,22 @@ public abstract class GeneratedMessageLite< return false; } + // BEGIN EXPERIMENTAL + // return Protobuf.getInstance().schemaFor(this).equals(this, (MessageType) other); + // END EXPERIMENTAL + // BEGIN REGULAR + try { visit(EqualsVisitor.INSTANCE, (MessageType) other); - } catch (NotEqualsException e) { + } catch (EqualsVisitor.NotEqualsException e) { return false; } return true; + // END REGULAR } - /** - * Same as {@link #equals(Object)} but throws {@code NotEqualsException}. - */ + // BEGIN REGULAR + /** Same as {@link #equals(Object)} but throws {@code NotEqualsException}. */ @SuppressWarnings("unchecked") // Guaranteed by isInstance + runtime boolean equals(EqualsVisitor visitor, MessageLite other) { if (this == other) { @@ -161,14 +182,13 @@ public abstract class GeneratedMessageLite< visit(visitor, (MessageType) other); return true; } + // END REGULAR // The general strategy for unknown fields is to use an UnknownFieldSetLite that is treated as // mutable during the parsing constructor and immutable after. This allows us to avoid // any unnecessary intermediary allocations while reducing the generated code size. - /** - * Lazily initializes unknown fields. - */ + /** Lazily initializes unknown fields. */ private final void ensureUnknownFieldsInitialized() { if (unknownFields == UnknownFieldSetLite.getDefaultInstance()) { unknownFields = UnknownFieldSetLite.newInstance(); @@ -210,17 +230,36 @@ public abstract class GeneratedMessageLite< * Called by subclasses to complete parsing. For use by generated code only. */ protected void makeImmutable() { + // BEGIN REGULAR dynamicMethod(MethodToInvoke.MAKE_IMMUTABLE); - unknownFields.makeImmutable(); + // END REGULAR + // BEGIN EXPERIMENTAL + // Protobuf.getInstance().schemaFor(this).makeImmutable(this); + // END EXPERIMENTAL + } + + protected final < + MessageType extends GeneratedMessageLite<MessageType, BuilderType>, + BuilderType extends GeneratedMessageLite.Builder<MessageType, BuilderType>> + BuilderType createBuilder() { + return (BuilderType) dynamicMethod(MethodToInvoke.NEW_BUILDER); + } + + protected final < + MessageType extends GeneratedMessageLite<MessageType, BuilderType>, + BuilderType extends GeneratedMessageLite.Builder<MessageType, BuilderType>> + BuilderType createBuilder(MessageType prototype) { + return ((BuilderType) createBuilder()).mergeFrom(prototype); } @Override public final boolean isInitialized() { - return dynamicMethod(MethodToInvoke.IS_INITIALIZED, Boolean.TRUE) != null; + return isInitialized((MessageType) this, Boolean.TRUE); } @Override + @SuppressWarnings("unchecked") public final BuilderType toBuilder() { BuilderType builder = (BuilderType) dynamicMethod(MethodToInvoke.NEW_BUILDER); builder.mergeFrom((MessageType) this); @@ -234,11 +273,15 @@ public abstract class GeneratedMessageLite< * For use by generated code only. */ public static enum MethodToInvoke { - // Rely on/modify instance state + // BEGIN REGULAR IS_INITIALIZED, VISIT, MERGE_FROM_STREAM, MAKE_IMMUTABLE, + // END REGULAR + // Rely on/modify instance state + GET_MEMOIZED_IS_INITIALIZED, + SET_MEMOIZED_IS_INITIALIZED, // Rely on static state NEW_MUTABLE_INSTANCE, @@ -252,25 +295,30 @@ public abstract class GeneratedMessageLite< * Theses different kinds of operations are required to implement message-level operations for * builders in the runtime. This method bundles those operations to reduce the generated methods * count. + * * <ul> - * <li>{@code MERGE_FROM_STREAM} is parameterized with an {@link CodedInputStream} and - * {@link ExtensionRegistryLite}. It consumes the input stream, parsing the contents into the - * returned protocol buffer. If parsing throws an {@link InvalidProtocolBufferException}, the - * implementation wraps it in a RuntimeException. - * <li>{@code NEW_INSTANCE} returns a new instance of the protocol buffer that has not yet been - * made immutable. See {@code MAKE_IMMUTABLE}. - * <li>{@code IS_INITIALIZED} is parameterized with a {@code Boolean} detailing whether to - * memoize. It returns {@code null} for false and the default instance for true. We optionally - * memoize to support the Builder case, where memoization is not desired. - * <li>{@code NEW_BUILDER} returns a {@code BuilderType} instance. - * <li>{@code VISIT} is parameterized with a {@code Visitor} and a {@code MessageType} and - * recursively iterates through the fields side by side between this and the instance. - * <li>{@code MAKE_IMMUTABLE} sets all internal fields to an immutable state. + * <li>{@code MERGE_FROM_STREAM} is parameterized with an {@link CodedInputStream} and {@link + * ExtensionRegistryLite}. It consumes the input stream, parsing the contents into the + * returned protocol buffer. If parsing throws an {@link InvalidProtocolBufferException}, + * the implementation wraps it in a RuntimeException. + * <li>{@code NEW_INSTANCE} returns a new instance of the protocol buffer that has not yet been + * made immutable. See {@code MAKE_IMMUTABLE}. + * <li>{@code IS_INITIALIZED} returns {@code null} for false and the default instance for true. + * It doesn't use or modify any memoized value. + * <li>{@code GET_MEMOIZED_IS_INITIALIZED} returns the memoized {@code isInitialized} byte + * value. + * <li>{@code SET_MEMOIZED_IS_INITIALIZED} sets the memoized {@code isInitilaized} byte value to + * 1 if the first parameter is not null, or to 0 if the first parameter is null. + * <li>{@code NEW_BUILDER} returns a {@code BuilderType} instance. + * <li>{@code VISIT} is parameterized with a {@code Visitor} and a {@code MessageType} and + * recursively iterates through the fields side by side between this and the instance. + * <li>{@code MAKE_IMMUTABLE} sets all internal fields to an immutable state. * </ul> + * * This method, plus the implementation of the Builder, enables the Builder class to be proguarded * away entirely on Android. - * <p> - * For use by generated code only. + * + * <p>For use by generated code only. */ protected abstract Object dynamicMethod(MethodToInvoke method, Object arg0, Object arg1); @@ -288,14 +336,27 @@ public abstract class GeneratedMessageLite< return dynamicMethod(method, null, null); } + // BEGIN REGULAR void visit(Visitor visitor, MessageType other) { dynamicMethod(MethodToInvoke.VISIT, visitor, other); unknownFields = visitor.visitUnknownFields(unknownFields, other.unknownFields); } + // END REGULAR + + @Override + int getMemoizedSerializedSize() { + return memoizedSerializedSize; + } + + @Override + void setMemoizedSerializedSize(int size) { + memoizedSerializedSize = size; + } + + /** - * Merge some unknown fields into the {@link UnknownFieldSetLite} for this - * message. + * Merge some unknown fields into the {@link UnknownFieldSetLite} for this message. * * <p>For use by generated code only. */ @@ -328,7 +389,7 @@ public abstract class GeneratedMessageLite< if (isBuilt) { MessageType newInstance = (MessageType) instance.dynamicMethod(MethodToInvoke.NEW_MUTABLE_INSTANCE); - newInstance.visit(MergeFromVisitor.INSTANCE, instance); + mergeFromInstance(newInstance, instance); instance = newInstance; isBuilt = false; } @@ -383,23 +444,60 @@ public abstract class GeneratedMessageLite< /** All subclasses implement this. */ public BuilderType mergeFrom(MessageType message) { copyOnWrite(); - instance.visit(MergeFromVisitor.INSTANCE, message); + mergeFromInstance(instance, message); return (BuilderType) this; } + private void mergeFromInstance(MessageType dest, MessageType src) { + // BEGIN EXPERIMENTAL + // Protobuf.getInstance().schemaFor(dest).mergeFrom(dest, src); + // END EXPERIMENTAL + // BEGIN REGULAR + dest.visit(MergeFromVisitor.INSTANCE, src); + // END REGULAR + } + @Override public MessageType getDefaultInstanceForType() { return defaultInstance; } @Override + public BuilderType mergeFrom(byte[] input, int offset, int length) + throws InvalidProtocolBufferException { + // BEGIN REGULAR + return super.mergeFrom(input, offset, length); + // END REGULAR + // BEGIN EXPERIMENTAL + // copyOnWrite(); + // try { + // Protobuf.getInstance().schemaFor(instance).mergeFrom( + // instance, input, offset, offset + length, new ArrayDecoders.Registers()); + // } catch (InvalidProtocolBufferException e) { + // throw e; + // } catch (IndexOutOfBoundsException e) { + // throw InvalidProtocolBufferException.truncatedMessage(); + // } catch (IOException e) { + // throw new RuntimeException("Reading from byte array should not throw IOException.", e); + // } + // return (BuilderType) this; + // END EXPERIMENTAL + } + + @Override public BuilderType mergeFrom( com.google.protobuf.CodedInputStream input, com.google.protobuf.ExtensionRegistryLite extensionRegistry) throws IOException { copyOnWrite(); try { + // BEGIN REGULAR instance.dynamicMethod(MethodToInvoke.MERGE_FROM_STREAM, input, extensionRegistry); + // END REGULAR + // BEGIN EXPERIMENTAL + // Protobuf.getInstance().schemaFor(instance).mergeFrom( + // instance, CodedInputStreamReader.forCodedInput(input), extensionRegistry); + // END EXPERIMENTAL } catch (RuntimeException e) { if (e.getCause() instanceof IOException) { throw (IOException) e.getCause(); @@ -448,12 +546,10 @@ public abstract class GeneratedMessageLite< extends GeneratedMessageLite<MessageType, BuilderType> implements ExtendableMessageOrBuilder<MessageType, BuilderType> { - /** - * Represents the set of extensions on this message. For use by generated - * code only. - */ - protected FieldSet<ExtensionDescriptor> extensions = FieldSet.newFieldSet(); + /** Represents the set of extensions on this message. For use by generated code only. */ + protected FieldSet<ExtensionDescriptor> extensions = FieldSet.emptySet(); + @SuppressWarnings("unchecked") protected final void mergeExtensionFields(final MessageType other) { if (extensions.isImmutable()) { extensions = extensions.clone(); @@ -461,11 +557,13 @@ public abstract class GeneratedMessageLite< extensions.mergeFrom(((ExtendableMessage) other).extensions); } + // BEGIN REGULAR @Override final void visit(Visitor visitor, MessageType other) { super.visit(visitor, other); extensions = visitor.visitExtensions(extensions, other.extensions); } + // END REGULAR /** * Parse an unknown field or an extension. For use by generated code only. @@ -478,8 +576,8 @@ public abstract class GeneratedMessageLite< MessageType defaultInstance, CodedInputStream input, ExtensionRegistryLite extensionRegistry, - int tag) throws IOException { - int wireType = WireFormat.getTagWireType(tag); + int tag) + throws IOException { int fieldNumber = WireFormat.getTagFieldNumber(tag); // TODO(dweis): How much bytecode would be saved by not requiring the generated code to @@ -487,6 +585,17 @@ public abstract class GeneratedMessageLite< GeneratedExtension<MessageType, ?> extension = extensionRegistry.findLiteExtensionByNumber( defaultInstance, fieldNumber); + return parseExtension(input, extensionRegistry, extension, tag, fieldNumber); + } + + private boolean parseExtension( + CodedInputStream input, + ExtensionRegistryLite extensionRegistry, + GeneratedExtension<?, ?> extension, + int tag, + int fieldNumber) + throws IOException { + int wireType = WireFormat.getTagWireType(tag); boolean unknown = false; boolean packed = false; if (extension == null) { @@ -509,6 +618,8 @@ public abstract class GeneratedMessageLite< return parseUnknownField(tag, input); } + ensureExtensionsAreMutable(); + if (packed) { int length = input.readRawVarint32(); int limit = input.pushLimit(length); @@ -587,9 +698,155 @@ public abstract class GeneratedMessageLite< extension.singularToFieldSetType(value)); } } - return true; } + + /** + * Parse an unknown field or an extension. For use by generated code only. + * + * <p>For use by generated code only. + * + * @return {@code true} unless the tag is an end-group tag. + */ + protected <MessageType extends MessageLite> boolean parseUnknownFieldAsMessageSet( + MessageType defaultInstance, + CodedInputStream input, + ExtensionRegistryLite extensionRegistry, + int tag) + throws IOException { + + if (tag == WireFormat.MESSAGE_SET_ITEM_TAG) { + mergeMessageSetExtensionFromCodedStream(defaultInstance, input, extensionRegistry); + return true; + } + + // TODO(dweis): Do we really want to support non message set wire format in message sets? + // Full runtime does... So we do for now. + int wireType = WireFormat.getTagWireType(tag); + if (wireType == WireFormat.WIRETYPE_LENGTH_DELIMITED) { + return parseUnknownField(defaultInstance, input, extensionRegistry, tag); + } else { + // TODO(dweis): Should we throw on invalid input? Full runtime does not... + return input.skipField(tag); + } + } + + /** + * Merges the message set from the input stream; requires message set wire format. + * + * @param defaultInstance the default instance of the containing message we are parsing in + * @param input the stream to parse from + * @param extensionRegistry the registry to use when parsing + */ + private <MessageType extends MessageLite> void mergeMessageSetExtensionFromCodedStream( + MessageType defaultInstance, + CodedInputStream input, + ExtensionRegistryLite extensionRegistry) + throws IOException { + // The wire format for MessageSet is: + // message MessageSet { + // repeated group Item = 1 { + // required int32 typeId = 2; + // required bytes message = 3; + // } + // } + // "typeId" is the extension's field number. The extension can only be + // a message type, where "message" contains the encoded bytes of that + // message. + // + // In practice, we will probably never see a MessageSet item in which + // the message appears before the type ID, or where either field does not + // appear exactly once. However, in theory such cases are valid, so we + // should be prepared to accept them. + + int typeId = 0; + ByteString rawBytes = null; // If we encounter "message" before "typeId" + GeneratedExtension<?, ?> extension = null; + + // Read bytes from input, if we get it's type first then parse it eagerly, + // otherwise we store the raw bytes in a local variable. + while (true) { + final int tag = input.readTag(); + if (tag == 0) { + break; + } + + if (tag == WireFormat.MESSAGE_SET_TYPE_ID_TAG) { + typeId = input.readUInt32(); + if (typeId != 0) { + extension = extensionRegistry.findLiteExtensionByNumber(defaultInstance, typeId); + } + + } else if (tag == WireFormat.MESSAGE_SET_MESSAGE_TAG) { + if (typeId != 0) { + if (extension != null) { + // We already know the type, so we can parse directly from the + // input with no copying. Hooray! + eagerlyMergeMessageSetExtension(input, extension, extensionRegistry, typeId); + rawBytes = null; + continue; + } + } + // We haven't seen a type ID yet or we want parse message lazily. + rawBytes = input.readBytes(); + + } else { // Unknown tag. Skip it. + if (!input.skipField(tag)) { + break; // End of group + } + } + } + input.checkLastTagWas(WireFormat.MESSAGE_SET_ITEM_END_TAG); + + // Process the raw bytes. + if (rawBytes != null && typeId != 0) { // Zero is not a valid type ID. + if (extension != null) { // We known the type + mergeMessageSetExtensionFromBytes(rawBytes, extensionRegistry, extension); + } else { // We don't know how to parse this. Ignore it. + if (rawBytes != null) { + mergeLengthDelimitedField(typeId, rawBytes); + } + } + } + } + + private void eagerlyMergeMessageSetExtension( + CodedInputStream input, + GeneratedExtension<?, ?> extension, + ExtensionRegistryLite extensionRegistry, + int typeId) + throws IOException { + int fieldNumber = typeId; + int tag = WireFormat.makeTag(typeId, WireFormat.WIRETYPE_LENGTH_DELIMITED); + parseExtension(input, extensionRegistry, extension, tag, fieldNumber); + } + + private void mergeMessageSetExtensionFromBytes( + ByteString rawBytes, + ExtensionRegistryLite extensionRegistry, + GeneratedExtension<?, ?> extension) + throws IOException { + MessageLite.Builder subBuilder = null; + MessageLite existingValue = (MessageLite) extensions.getField(extension.descriptor); + if (existingValue != null) { + subBuilder = existingValue.toBuilder(); + } + if (subBuilder == null) { + subBuilder = extension.getMessageDefaultInstance().newBuilderForType(); + } + subBuilder.mergeFrom(rawBytes, extensionRegistry); + MessageLite value = subBuilder.build(); + + ensureExtensionsAreMutable().setField( + extension.descriptor, extension.singularToFieldSetType(value)); + } + + private FieldSet<ExtensionDescriptor> ensureExtensionsAreMutable() { + if (extensions.isImmutable()) { + extensions = extensions.clone(); + } + return extensions; + } private void verifyExtensionContainingType( final GeneratedExtension<MessageType, ?> extension) { @@ -660,10 +917,12 @@ public abstract class GeneratedMessageLite< @Override protected final void makeImmutable() { super.makeImmutable(); - + // BEGIN REGULAR extensions.makeImmutable(); + // END REGULAR } + /** * Used by subclasses to serialize extensions. Extension ranges may be * interleaved with field numbers, but we must write them in canonical @@ -734,12 +993,6 @@ public abstract class GeneratedMessageLite< implements ExtendableMessageOrBuilder<MessageType, BuilderType> { protected ExtendableBuilder(MessageType defaultInstance) { super(defaultInstance); - - // TODO(dweis): This is kind of an unnecessary clone since we construct a - // new instance in the parent constructor which makes the extensions - // immutable. This extra allocation shouldn't matter in practice - // though. - instance.extensions = instance.extensions.clone(); } // For immutable message conversion. @@ -758,6 +1011,15 @@ public abstract class GeneratedMessageLite< instance.extensions = instance.extensions.clone(); } + private FieldSet<ExtensionDescriptor> ensureExtensionsAreMutable() { + FieldSet<ExtensionDescriptor> extensions = instance.extensions; + if (extensions.isImmutable()) { + extensions = extensions.clone(); + instance.extensions = extensions; + } + return extensions; + } + @Override public final MessageType buildPartial() { if (isBuilt) { @@ -807,14 +1069,6 @@ public abstract class GeneratedMessageLite< return instance.getExtension(extension, index); } - // This is implemented here only to work around an apparent bug in the - // Java compiler and/or build system. See bug #1898463. The mere presence - // of this dummy clone() implementation makes it go away. - @Override - public BuilderType clone() { - return super.clone(); - } - /** Set the value of an extension. */ public final <Type> BuilderType setExtension( final ExtensionLite<MessageType, Type> extension, @@ -824,7 +1078,8 @@ public abstract class GeneratedMessageLite< verifyExtensionContainingType(extensionLite); copyOnWrite(); - instance.extensions.setField(extensionLite.descriptor, extensionLite.toFieldSetType(value)); + ensureExtensionsAreMutable() + .setField(extensionLite.descriptor, extensionLite.toFieldSetType(value)); return (BuilderType) this; } @@ -837,8 +1092,9 @@ public abstract class GeneratedMessageLite< verifyExtensionContainingType(extensionLite); copyOnWrite(); - instance.extensions.setRepeatedField( - extensionLite.descriptor, index, extensionLite.singularToFieldSetType(value)); + ensureExtensionsAreMutable() + .setRepeatedField( + extensionLite.descriptor, index, extensionLite.singularToFieldSetType(value)); return (BuilderType) this; } @@ -851,8 +1107,8 @@ public abstract class GeneratedMessageLite< verifyExtensionContainingType(extensionLite); copyOnWrite(); - instance.extensions.addRepeatedField( - extensionLite.descriptor, extensionLite.singularToFieldSetType(value)); + ensureExtensionsAreMutable() + .addRepeatedField(extensionLite.descriptor, extensionLite.singularToFieldSetType(value)); return (BuilderType) this; } @@ -863,7 +1119,7 @@ public abstract class GeneratedMessageLite< verifyExtensionContainingType(extensionLite); copyOnWrite(); - instance.extensions.clearField(extensionLite.descriptor); + ensureExtensionsAreMutable().clearField(extensionLite.descriptor); return (BuilderType) this; } } @@ -1013,6 +1269,7 @@ public abstract class GeneratedMessageLite< } } + /** * Lite equivalent to {@link GeneratedMessage.GeneratedExtension}. * @@ -1253,11 +1510,26 @@ public abstract class GeneratedMessageLite< */ protected static final <T extends GeneratedMessageLite<T, ?>> boolean isInitialized( T message, boolean shouldMemoize) { - return message.dynamicMethod(MethodToInvoke.IS_INITIALIZED, shouldMemoize) != null; - } - - protected static final <T extends GeneratedMessageLite<T, ?>> void makeImmutable(T message) { - message.dynamicMethod(MethodToInvoke.MAKE_IMMUTABLE); + byte memoizedIsInitialized = + (Byte) message.dynamicMethod(MethodToInvoke.GET_MEMOIZED_IS_INITIALIZED); + if (memoizedIsInitialized == 1) { + return true; + } + if (memoizedIsInitialized == 0) { + return false; + } + // BEGIN EXPERIMENTAL + // boolean isInitialized = Protobuf.getInstance().schemaFor(message).isInitialized(message); + // END EXPERIMENTAL + // BEGIN REGULAR + boolean isInitialized = + message.dynamicMethod(MethodToInvoke.IS_INITIALIZED, Boolean.FALSE) != null; + // END REGULAR + if (shouldMemoize) { + message.dynamicMethod( + MethodToInvoke.SET_MEMOIZED_IS_INITIALIZED, isInitialized ? message : null); + } + return isInitialized; } protected static IntList emptyIntList() { @@ -1339,6 +1611,11 @@ public abstract class GeneratedMessageLite< throws InvalidProtocolBufferException { return GeneratedMessageLite.parsePartialFrom(defaultInstance, input, extensionRegistry); } + + @Override + public T parsePartialFrom(byte[] input) throws InvalidProtocolBufferException { + return GeneratedMessageLite.parsePartialFrom(defaultInstance, input); + } } /** @@ -1352,8 +1629,21 @@ public abstract class GeneratedMessageLite< @SuppressWarnings("unchecked") // Guaranteed by protoc T result = (T) instance.dynamicMethod(MethodToInvoke.NEW_MUTABLE_INSTANCE); try { + // BEGIN REGULAR result.dynamicMethod(MethodToInvoke.MERGE_FROM_STREAM, input, extensionRegistry); + // END REGULAR + // BEGIN EXPERIMENTAL + // Protobuf.getInstance().schemaFor(result).mergeFrom( + // result, CodedInputStreamReader.forCodedInput(input), extensionRegistry); + // END EXPERIMENTAL result.makeImmutable(); + // BEGIN EXPERIMENTAL + // } catch (IOException e) { + // if (e.getCause() instanceof InvalidProtocolBufferException) { + // throw (InvalidProtocolBufferException) e.getCause(); + // } + // throw new InvalidProtocolBufferException(e.getMessage()).setUnfinishedMessage(result); + // END EXPERIMENTAL } catch (RuntimeException e) { if (e.getCause() instanceof InvalidProtocolBufferException) { throw (InvalidProtocolBufferException) e.getCause(); @@ -1363,6 +1653,34 @@ public abstract class GeneratedMessageLite< return result; } + /** A static helper method for parsing a partial from byte array. */ + static <T extends GeneratedMessageLite<T, ?>> T parsePartialFrom(T instance, byte[] input) + throws InvalidProtocolBufferException { + // BEGIN REGULAR + return parsePartialFrom(instance, input, ExtensionRegistryLite.getEmptyRegistry()); + // END REGULAR + // BEGIN EXPERIMENTAL + // @SuppressWarnings("unchecked") // Guaranteed by protoc + // T result = (T) instance.dynamicMethod(MethodToInvoke.NEW_MUTABLE_INSTANCE); + // try { + // Protobuf.getInstance().schemaFor(result).mergeFrom( + // result, input, 0, input.length, new ArrayDecoders.Registers()); + // result.makeImmutable(); + // if (result.memoizedHashCode != 0) { + // throw new RuntimeException(); + // } + // } catch (IOException e) { + // if (e.getCause() instanceof InvalidProtocolBufferException) { + // throw (InvalidProtocolBufferException) e.getCause(); + // } + // throw new InvalidProtocolBufferException(e.getMessage()).setUnfinishedMessage(result); + // } catch (IndexOutOfBoundsException e) { + // throw InvalidProtocolBufferException.truncatedMessage().setUnfinishedMessage(result); + // } + // return result; + // END EXPERIMENTAL + } + protected static <T extends GeneratedMessageLite<T, ?>> T parsePartialFrom( T defaultInstance, CodedInputStream input) @@ -1388,6 +1706,20 @@ public abstract class GeneratedMessageLite< // Validates last tag. protected static <T extends GeneratedMessageLite<T, ?>> T parseFrom( + T defaultInstance, ByteBuffer data, ExtensionRegistryLite extensionRegistry) + throws InvalidProtocolBufferException { + return checkMessageInitialized( + parseFrom(defaultInstance, CodedInputStream.newInstance(data), extensionRegistry)); + } + + // Validates last tag. + protected static <T extends GeneratedMessageLite<T, ?>> T parseFrom( + T defaultInstance, ByteBuffer data) throws InvalidProtocolBufferException { + return parseFrom(defaultInstance, data, ExtensionRegistryLite.getEmptyRegistry()); + } + + // Validates last tag. + protected static <T extends GeneratedMessageLite<T, ?>> T parseFrom( T defaultInstance, ByteString data) throws InvalidProtocolBufferException { return checkMessageInitialized( @@ -1445,8 +1777,7 @@ public abstract class GeneratedMessageLite< protected static <T extends GeneratedMessageLite<T, ?>> T parseFrom( T defaultInstance, byte[] data) throws InvalidProtocolBufferException { - return checkMessageInitialized( - parsePartialFrom(defaultInstance, data, ExtensionRegistryLite.getEmptyRegistry())); + return checkMessageInitialized(parsePartialFrom(defaultInstance, data)); } // Validates last tag. @@ -1531,6 +1862,7 @@ public abstract class GeneratedMessageLite< return message; } + // BEGIN REGULAR /** * An abstract visitor that the generated code calls into that we use to implement various * features. Fields that are not members of oneofs are always visited. Members of a oneof are only @@ -1554,7 +1886,6 @@ public abstract class GeneratedMessageLite< Object visitOneofLong(boolean minePresent, Object mine, Object other); Object visitOneofString(boolean minePresent, Object mine, Object other); Object visitOneofByteString(boolean minePresent, Object mine, Object other); - Object visitOneofLazyMessage(boolean minePresent, Object mine, Object other); Object visitOneofMessage(boolean minePresent, Object mine, Object other); void visitOneofNotSet(boolean minePresent); @@ -1562,7 +1893,6 @@ public abstract class GeneratedMessageLite< * Message fields use null sentinals. */ <T extends MessageLite> T visitMessage(T mine, T other); - LazyFieldLite visitLazyMessage(LazyFieldLite mine, LazyFieldLite other); <T> ProtobufList<T> visitList(ProtobufList<T> mine, ProtobufList<T> other); BooleanList visitBooleanList(BooleanList mine, BooleanList other); @@ -1706,14 +2036,6 @@ public abstract class GeneratedMessageLite< } @Override - public Object visitOneofLazyMessage(boolean minePresent, Object mine, Object other) { - if (minePresent && mine.equals(other)) { - return mine; - } - throw NOT_EQUALS; - } - - @Override public Object visitOneofMessage(boolean minePresent, Object mine, Object other) { if (minePresent && ((GeneratedMessageLite<?, ?>) mine).equals(this, (MessageLite) other)) { return mine; @@ -1744,21 +2066,6 @@ public abstract class GeneratedMessageLite< } @Override - public LazyFieldLite visitLazyMessage( - LazyFieldLite mine, LazyFieldLite other) { - if (mine == null && other == null) { - return null; - } - if (mine == null || other == null) { - throw NOT_EQUALS; - } - if (mine.equals(other)) { - return mine; - } - throw NOT_EQUALS; - } - - @Override public <T> ProtobufList<T> visitList(ProtobufList<T> mine, ProtobufList<T> other) { if (!mine.equals(other)) { throw NOT_EQUALS; @@ -1838,13 +2145,13 @@ public abstract class GeneratedMessageLite< /** * Implements hashCode by accumulating state. */ - private static class HashCodeVisitor implements Visitor { + static class HashCodeVisitor implements Visitor { // The caller must ensure that the visitor is invoked parameterized with this and this such that // other is this. This is required due to how oneof cases are handled. See the class comment // on Visitor for more information. - private int hashCode = 0; + int hashCode = 0; @Override public boolean visitBoolean( @@ -1935,12 +2242,6 @@ public abstract class GeneratedMessageLite< } @Override - public Object visitOneofLazyMessage(boolean minePresent, Object mine, Object other) { - hashCode = (53 * hashCode) + mine.hashCode(); - return mine; - } - - @Override public Object visitOneofMessage(boolean minePresent, Object mine, Object other) { return visitMessage((MessageLite) mine, (MessageLite) other); } @@ -1969,18 +2270,6 @@ public abstract class GeneratedMessageLite< } @Override - public LazyFieldLite visitLazyMessage(LazyFieldLite mine, LazyFieldLite other) { - final int protoHash; - if (mine != null) { - protoHash = mine.hashCode(); - } else { - protoHash = 37; - } - hashCode = (53 * hashCode) + protoHash; - return mine; - } - - @Override public <T> ProtobufList<T> visitList(ProtobufList<T> mine, ProtobufList<T> other) { hashCode = (53 * hashCode) + mine.hashCode(); return mine; @@ -2123,13 +2412,6 @@ public abstract class GeneratedMessageLite< } @Override - public Object visitOneofLazyMessage(boolean minePresent, Object mine, Object other) { - LazyFieldLite lazy = minePresent ? (LazyFieldLite) mine : new LazyFieldLite(); - lazy.merge((LazyFieldLite) other); - return lazy; - } - - @Override public Object visitOneofMessage(boolean minePresent, Object mine, Object other) { if (minePresent) { return visitMessage((MessageLite) mine, (MessageLite) other); @@ -2153,17 +2435,6 @@ public abstract class GeneratedMessageLite< } @Override - public LazyFieldLite visitLazyMessage(LazyFieldLite mine, LazyFieldLite other) { - if (other != null) { - if (mine == null) { - mine = new LazyFieldLite(); - } - mine.merge(other); - } - return mine; - } - - @Override public <T> ProtobufList<T> visitList(ProtobufList<T> mine, ProtobufList<T> other) { int size = mine.size(); int otherSize = other.size(); @@ -2277,4 +2548,5 @@ public abstract class GeneratedMessageLite< return mine; } } + // END REGULAR } diff --git a/java/core/src/main/java/com/google/protobuf/GeneratedMessageV3.java b/java/core/src/main/java/com/google/protobuf/GeneratedMessageV3.java index 5dfe7ff7..4acd8f2f 100644 --- a/java/core/src/main/java/com/google/protobuf/GeneratedMessageV3.java +++ b/java/core/src/main/java/com/google/protobuf/GeneratedMessageV3.java @@ -30,14 +30,25 @@ package com.google.protobuf; +import static com.google.protobuf.Internal.checkNotNull; + import com.google.protobuf.Descriptors.Descriptor; import com.google.protobuf.Descriptors.EnumDescriptor; import com.google.protobuf.Descriptors.EnumValueDescriptor; import com.google.protobuf.Descriptors.FieldDescriptor; import com.google.protobuf.Descriptors.FileDescriptor; import com.google.protobuf.Descriptors.OneofDescriptor; +// In opensource protobuf, we have versioned this GeneratedMessageV3 class to GeneratedMessageV3V3 and +// in the future may have GeneratedMessageV3V4 etc. This allows us to change some aspects of this +// class without breaking binary compatibility with old generated code that still subclasses +// the old GeneratedMessageV3 class. To allow these different GeneratedMessageV3V? classes to +// interoperate (e.g., a GeneratedMessageV3V3 object has a message extension field whose class +// type is GeneratedMessageV3V4), these classes still share a common parent class AbstractMessage +// and are using the same GeneratedMessage.GeneratedExtension class for extension definitions. +// Since this class becomes GeneratedMessageV3V? in opensource, we have to add an import here +// to be able to use GeneratedMessage.GeneratedExtension. The GeneratedExtension definition in +// this file is also excluded from opensource to avoid conflict. import com.google.protobuf.GeneratedMessage.GeneratedExtension; - import java.io.IOException; import java.io.InputStream; import java.io.ObjectStreamException; @@ -45,6 +56,7 @@ import java.io.Serializable; import java.lang.reflect.InvocationTargetException; import java.lang.reflect.Method; import java.util.ArrayList; +import java.util.Arrays; import java.util.Collections; import java.util.Iterator; import java.util.List; @@ -266,13 +278,30 @@ public abstract class GeneratedMessageV3 extends AbstractMessage /** * Called by subclasses to parse an unknown field. + * * @return {@code true} unless the tag is an end-group tag. */ protected boolean parseUnknownField( CodedInputStream input, UnknownFieldSet.Builder unknownFields, ExtensionRegistryLite extensionRegistry, - int tag) throws IOException { + int tag) + throws IOException { + if (input.shouldDiscardUnknownFields()) { + return input.skipField(tag); + } + return unknownFields.mergeFieldFrom(tag, input); + } + + protected boolean parseUnknownFieldProto3( + CodedInputStream input, + UnknownFieldSet.Builder unknownFields, + ExtensionRegistryLite extensionRegistry, + int tag) + throws IOException { + if (input.shouldDiscardUnknownFieldsProto3()) { + return input.skipField(tag); + } return unknownFields.mergeFieldFrom(tag, input); } @@ -329,6 +358,10 @@ public abstract class GeneratedMessageV3 extends AbstractMessage throw e.unwrapIOException(); } } + + protected static boolean canUseUnsafe() { + return UnsafeUtil.hasUnsafeArrayOperations() && UnsafeUtil.hasUnsafeByteBufferOperations(); + } @Override public void writeTo(final CodedOutputStream output) throws IOException { @@ -608,17 +641,25 @@ public abstract class GeneratedMessageV3 extends AbstractMessage return (BuilderType) this; } + protected BuilderType setUnknownFieldsProto3(final UnknownFieldSet unknownFields) { + if (CodedInputStream.getProto3DiscardUnknownFieldsDefault()) { + return (BuilderType) this; + } + this.unknownFields = unknownFields; + onChanged(); + return (BuilderType) this; + } + @Override public BuilderType mergeUnknownFields( final UnknownFieldSet unknownFields) { - this.unknownFields = + return setUnknownFields( UnknownFieldSet.newBuilder(this.unknownFields) .mergeFrom(unknownFields) - .build(); - onChanged(); - return (BuilderType) this; + .build()); } + @Override public boolean isInitialized() { for (final FieldDescriptor field : getDescriptorForType().getFields()) { @@ -655,18 +696,6 @@ public abstract class GeneratedMessageV3 extends AbstractMessage } /** - * Called by subclasses to parse an unknown field. - * @return {@code true} unless the tag is an end-group tag. - */ - protected boolean parseUnknownField( - final CodedInputStream input, - final UnknownFieldSet.Builder unknownFields, - final ExtensionRegistryLite extensionRegistry, - final int tag) throws IOException { - return unknownFields.mergeFieldFrom(tag, input); - } - - /** * Implementation of {@link BuilderParent} for giving to our children. This * small inner class makes it so we don't publicly expose the BuilderParent * methods. @@ -855,6 +884,7 @@ public abstract class GeneratedMessageV3 extends AbstractMessage /** Check if a singular extension is present. */ @Override + @SuppressWarnings("unchecked") public final <Type> boolean hasExtension(final ExtensionLite<MessageType, Type> extensionLite) { Extension<MessageType, Type> extension = checkNotLite(extensionLite); @@ -864,6 +894,7 @@ public abstract class GeneratedMessageV3 extends AbstractMessage /** Get the number of elements in a repeated extension. */ @Override + @SuppressWarnings("unchecked") public final <Type> int getExtensionCount( final ExtensionLite<MessageType, List<Type>> extensionLite) { Extension<MessageType, List<Type>> extension = checkNotLite(extensionLite); @@ -974,8 +1005,23 @@ public abstract class GeneratedMessageV3 extends AbstractMessage ExtensionRegistryLite extensionRegistry, int tag) throws IOException { return MessageReflection.mergeFieldFrom( - input, unknownFields, extensionRegistry, getDescriptorForType(), - new MessageReflection.ExtensionAdapter(extensions), tag); + input, input.shouldDiscardUnknownFields() ? null : unknownFields, extensionRegistry, + getDescriptorForType(), new MessageReflection.ExtensionAdapter(extensions), tag); + } + + @Override + protected boolean parseUnknownFieldProto3( + CodedInputStream input, + UnknownFieldSet.Builder unknownFields, + ExtensionRegistryLite extensionRegistry, + int tag) throws IOException { + return MessageReflection.mergeFieldFrom( + input, + input.shouldDiscardUnknownFieldsProto3() ? null : unknownFields, + extensionRegistry, + getDescriptorForType(), + new MessageReflection.ExtensionAdapter(extensions), + tag); } @@ -1206,14 +1252,6 @@ public abstract class GeneratedMessageV3 extends AbstractMessage return super.clear(); } - // This is implemented here only to work around an apparent bug in the - // Java compiler and/or build system. See bug #1898463. The mere presence - // of this clone() implementation makes it go away. - @Override - public BuilderType clone() { - return super.clone(); - } - private void ensureExtensionsIsMutable() { if (extensions.isImmutable()) { extensions = extensions.clone(); @@ -1453,21 +1491,6 @@ public abstract class GeneratedMessageV3 extends AbstractMessage return super.isInitialized() && extensionsAreInitialized(); } - /** - * Called by subclasses to parse an unknown field or an extension. - * @return {@code true} unless the tag is an end-group tag. - */ - @Override - protected boolean parseUnknownField( - final CodedInputStream input, - final UnknownFieldSet.Builder unknownFields, - final ExtensionRegistryLite extensionRegistry, - final int tag) throws IOException { - return MessageReflection.mergeFieldFrom( - input, unknownFields, extensionRegistry, getDescriptorForType(), - new MessageReflection.BuilderAdapter(this), tag); - } - // --------------------------------------------------------------- // Reflection @@ -1609,23 +1632,6 @@ public abstract class GeneratedMessageV3 extends AbstractMessage FieldDescriptor getDescriptor(); } - private abstract static class CachedDescriptorRetriever - implements ExtensionDescriptorRetriever { - private volatile FieldDescriptor descriptor; - protected abstract FieldDescriptor loadDescriptor(); - - @Override - public FieldDescriptor getDescriptor() { - if (descriptor == null) { - synchronized (this) { - if (descriptor == null) { - descriptor = loadDescriptor(); - } - } - } - return descriptor; - } - } // ================================================================= @@ -1722,11 +1728,6 @@ public abstract class GeneratedMessageV3 extends AbstractMessage initialized = false; } - private boolean isMapFieldEnabled(FieldDescriptor field) { - boolean result = true; - return result; - } - /** * Ensures the field accessors are initialized. This method is thread-safe. * @@ -1750,7 +1751,7 @@ public abstract class GeneratedMessageV3 extends AbstractMessage } if (field.isRepeated()) { if (field.getJavaType() == FieldDescriptor.JavaType.MESSAGE) { - if (field.isMapField() && isMapFieldEnabled(field)) { + if (field.isMapField()) { fields[i] = new MapFieldAccessor( field, camelCaseNames[i], messageClass, builderClass); } else { @@ -2223,7 +2224,22 @@ public abstract class GeneratedMessageV3 extends AbstractMessage field.getNumber()); } + private Message coerceType(Message value) { + if (value == null) { + return null; + } + if (mapEntryMessageDefaultInstance.getClass().isInstance(value)) { + return value; + } + // The value is not the exact right message type. However, if it + // is an alternative implementation of the same type -- e.g. a + // DynamicMessage -- we should accept it. In this case we can make + // a copy of the message. + return mapEntryMessageDefaultInstance.toBuilder().mergeFrom(value).build(); + } + @Override + @SuppressWarnings("unchecked") public Object get(GeneratedMessageV3 message) { List result = new ArrayList(); for (int i = 0; i < getRepeatedCount(message); i++) { @@ -2233,6 +2249,7 @@ public abstract class GeneratedMessageV3 extends AbstractMessage } @Override + @SuppressWarnings("unchecked") public Object get(Builder builder) { List result = new ArrayList(); for (int i = 0; i < getRepeatedCount(builder); i++) { @@ -2281,12 +2298,12 @@ public abstract class GeneratedMessageV3 extends AbstractMessage @Override public void setRepeated(Builder builder, int index, Object value) { - getMutableMapField(builder).getMutableList().set(index, (Message) value); + getMutableMapField(builder).getMutableList().set(index, coerceType((Message) value)); } @Override public void addRepeated(Builder builder, Object value) { - getMutableMapField(builder).getMutableList().add((Message) value); + getMutableMapField(builder).getMutableList().add(coerceType((Message) value)); } @Override @@ -2679,7 +2696,7 @@ public abstract class GeneratedMessageV3 extends AbstractMessage return (Extension<MessageType, T>) extension; } - + protected static int computeStringSize(final int fieldNumber, final Object value) { if (value instanceof String) { return CodedOutputStream.computeStringSize(fieldNumber, (String) value); @@ -2687,7 +2704,7 @@ public abstract class GeneratedMessageV3 extends AbstractMessage return CodedOutputStream.computeBytesSize(fieldNumber, (ByteString) value); } } - + protected static int computeStringSizeNoTag(final Object value) { if (value instanceof String) { return CodedOutputStream.computeStringSizeNoTag((String) value); @@ -2695,7 +2712,7 @@ public abstract class GeneratedMessageV3 extends AbstractMessage return CodedOutputStream.computeBytesSizeNoTag((ByteString) value); } } - + protected static void writeString( CodedOutputStream output, final int fieldNumber, final Object value) throws IOException { if (value instanceof String) { @@ -2704,7 +2721,7 @@ public abstract class GeneratedMessageV3 extends AbstractMessage output.writeBytes(fieldNumber, (ByteString) value); } } - + protected static void writeStringNoTag( CodedOutputStream output, final Object value) throws IOException { if (value instanceof String) { @@ -2713,4 +2730,132 @@ public abstract class GeneratedMessageV3 extends AbstractMessage output.writeBytesNoTag((ByteString) value); } } + + protected static <V> void serializeIntegerMapTo( + CodedOutputStream out, + MapField<Integer, V> field, + MapEntry<Integer, V> defaultEntry, + int fieldNumber) throws IOException { + Map<Integer, V> m = field.getMap(); + if (!out.isSerializationDeterministic()) { + serializeMapTo(out, m, defaultEntry, fieldNumber); + return; + } + // Sorting the unboxed keys and then look up the values during serialziation is 2x faster + // than sorting map entries with a custom comparator directly. + int[] keys = new int[m.size()]; + int index = 0; + for (int k : m.keySet()) { + keys[index++] = k; + } + Arrays.sort(keys); + for (int key : keys) { + out.writeMessage(fieldNumber, + defaultEntry.newBuilderForType() + .setKey(key) + .setValue(m.get(key)) + .build()); + } + } + + protected static <V> void serializeLongMapTo( + CodedOutputStream out, + MapField<Long, V> field, + MapEntry<Long, V> defaultEntry, + int fieldNumber) + throws IOException { + Map<Long, V> m = field.getMap(); + if (!out.isSerializationDeterministic()) { + serializeMapTo(out, m, defaultEntry, fieldNumber); + return; + } + + long[] keys = new long[m.size()]; + int index = 0; + for (long k : m.keySet()) { + keys[index++] = k; + } + Arrays.sort(keys); + for (long key : keys) { + out.writeMessage(fieldNumber, + defaultEntry.newBuilderForType() + .setKey(key) + .setValue(m.get(key)) + .build()); + } + } + + protected static <V> void serializeStringMapTo( + CodedOutputStream out, + MapField<String, V> field, + MapEntry<String, V> defaultEntry, + int fieldNumber) + throws IOException { + Map<String, V> m = field.getMap(); + if (!out.isSerializationDeterministic()) { + serializeMapTo(out, m, defaultEntry, fieldNumber); + return; + } + + // Sorting the String keys and then look up the values during serialziation is 25% faster than + // sorting map entries with a custom comparator directly. + String[] keys = new String[m.size()]; + keys = m.keySet().toArray(keys); + Arrays.sort(keys); + for (String key : keys) { + out.writeMessage(fieldNumber, + defaultEntry.newBuilderForType() + .setKey(key) + .setValue(m.get(key)) + .build()); + } + } + + protected static <V> void serializeBooleanMapTo( + CodedOutputStream out, + MapField<Boolean, V> field, + MapEntry<Boolean, V> defaultEntry, + int fieldNumber) + throws IOException { + Map<Boolean, V> m = field.getMap(); + if (!out.isSerializationDeterministic()) { + serializeMapTo(out, m, defaultEntry, fieldNumber); + return; + } + maybeSerializeBooleanEntryTo(out, m, defaultEntry, fieldNumber, false); + maybeSerializeBooleanEntryTo(out, m, defaultEntry, fieldNumber, true); + } + + private static <V> void maybeSerializeBooleanEntryTo( + CodedOutputStream out, + Map<Boolean, V> m, + MapEntry<Boolean, V> defaultEntry, + int fieldNumber, + boolean key) + throws IOException { + if (m.containsKey(key)) { + out.writeMessage(fieldNumber, + defaultEntry.newBuilderForType() + .setKey(key) + .setValue(m.get(key)) + .build()); + } + } + + /** Serialize the map using the iteration order. */ + private static <K, V> void serializeMapTo( + CodedOutputStream out, + Map<K, V> m, + MapEntry<K, V> defaultEntry, + int fieldNumber) + throws IOException { + for (Map.Entry<K, V> entry : m.entrySet()) { + out.writeMessage(fieldNumber, + defaultEntry.newBuilderForType() + .setKey(entry.getKey()) + .setValue(entry.getValue()) + .build()); + } + } } + diff --git a/java/core/src/main/java/com/google/protobuf/IntArrayList.java b/java/core/src/main/java/com/google/protobuf/IntArrayList.java index 2f526e3f..aacd71e1 100644 --- a/java/core/src/main/java/com/google/protobuf/IntArrayList.java +++ b/java/core/src/main/java/com/google/protobuf/IntArrayList.java @@ -30,8 +30,9 @@ package com.google.protobuf; -import com.google.protobuf.Internal.IntList; +import static com.google.protobuf.Internal.checkNotNull; +import com.google.protobuf.Internal.IntList; import java.util.Arrays; import java.util.Collection; import java.util.RandomAccess; @@ -41,9 +42,8 @@ import java.util.RandomAccess; * * @author dweis@google.com (Daniel Weis) */ -final class IntArrayList - extends AbstractProtobufList<Integer> - implements IntList, RandomAccess { +final class IntArrayList extends AbstractProtobufList<Integer> + implements IntList, RandomAccess, PrimitiveNonBoxingCollection { private static final IntArrayList EMPTY_LIST = new IntArrayList(); static { @@ -82,6 +82,18 @@ final class IntArrayList } @Override + protected void removeRange(int fromIndex, int toIndex) { + ensureIsMutable(); + if (toIndex < fromIndex) { + throw new IndexOutOfBoundsException("toIndex < fromIndex"); + } + + System.arraycopy(array, toIndex, array, fromIndex, size - toIndex); + size -= (toIndex - fromIndex); + modCount++; + } + + @Override public boolean equals(Object o) { if (this == o) { return true; @@ -198,9 +210,7 @@ final class IntArrayList public boolean addAll(Collection<? extends Integer> collection) { ensureIsMutable(); - if (collection == null) { - throw new NullPointerException(); - } + checkNotNull(collection); // We specialize when adding another IntArrayList to avoid boxing elements. if (!(collection instanceof IntArrayList)) { @@ -248,7 +258,9 @@ final class IntArrayList ensureIsMutable(); ensureIndexInRange(index); int value = array[index]; - System.arraycopy(array, index + 1, array, index, size - index); + if (index < size - 1) { + System.arraycopy(array, index + 1, array, index, size - index); + } size--; modCount++; return value; diff --git a/java/core/src/main/java/com/google/protobuf/Internal.java b/java/core/src/main/java/com/google/protobuf/Internal.java index 3b4a0412..848cad03 100644 --- a/java/core/src/main/java/com/google/protobuf/Internal.java +++ b/java/core/src/main/java/com/google/protobuf/Internal.java @@ -60,6 +60,26 @@ public final class Internal { static final Charset ISO_8859_1 = Charset.forName("ISO-8859-1"); /** + * Throws an appropriate {@link NullPointerException} if the given objects is {@code null}. + */ + static <T> T checkNotNull(T obj) { + if (obj == null) { + throw new NullPointerException(); + } + return obj; + } + + /** + * Throws an appropriate {@link NullPointerException} if the given objects is {@code null}. + */ + static <T> T checkNotNull(T obj, String message) { + if (obj == null) { + throw new NullPointerException(message); + } + return obj; + } + + /** * Helper called by generated code to construct default values for string * fields. * <p> @@ -394,9 +414,8 @@ public final class Internal { } } - /** - * An empty byte array constant used in generated code. - */ + + /** An empty byte array constant used in generated code. */ public static final byte[] EMPTY_BYTE_ARRAY = new byte[0]; /** @@ -410,6 +429,11 @@ public final class Internal { CodedInputStream.newInstance(EMPTY_BYTE_ARRAY); + /** Helper method to merge two MessageLite instances. */ + static Object mergeMessage(Object destination, Object source) { + return ((MessageLite) destination).toBuilder().mergeFrom((MessageLite) source).buildPartial(); + } + /** * Provides an immutable view of {@code List<T>} around a {@code List<F>}. * diff --git a/java/core/src/main/java/com/google/protobuf/InvalidProtocolBufferException.java b/java/core/src/main/java/com/google/protobuf/InvalidProtocolBufferException.java index 85ce7b24..510c6aac 100644 --- a/java/core/src/main/java/com/google/protobuf/InvalidProtocolBufferException.java +++ b/java/core/src/main/java/com/google/protobuf/InvalidProtocolBufferException.java @@ -50,6 +50,10 @@ public class InvalidProtocolBufferException extends IOException { super(e.getMessage(), e); } + public InvalidProtocolBufferException(final String description, IOException e) { + super(description, e); + } + /** * Attaches an unfinished message to the exception to support best-effort * parsing in {@code Parser} interface. @@ -107,11 +111,23 @@ public class InvalidProtocolBufferException extends IOException { "Protocol message end-group tag did not match expected tag."); } - static InvalidProtocolBufferException invalidWireType() { - return new InvalidProtocolBufferException( + static InvalidWireTypeException invalidWireType() { + return new InvalidWireTypeException( "Protocol message tag had invalid wire type."); } + /** + * Exception indicating that and unexpected wire type was encountered for a field. + */ + @ExperimentalApi + public static class InvalidWireTypeException extends InvalidProtocolBufferException { + private static final long serialVersionUID = 3283890091615336259L; + + public InvalidWireTypeException(String description) { + super(description); + } + } + static InvalidProtocolBufferException recursionLimitExceeded() { return new InvalidProtocolBufferException( "Protocol message had too many levels of nesting. May be malicious. " + diff --git a/java/core/src/main/java/com/google/protobuf/IterableByteBufferInputStream.java b/java/core/src/main/java/com/google/protobuf/IterableByteBufferInputStream.java new file mode 100644 index 00000000..713e8064 --- /dev/null +++ b/java/core/src/main/java/com/google/protobuf/IterableByteBufferInputStream.java @@ -0,0 +1,150 @@ +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// https://developers.google.com/protocol-buffers/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +package com.google.protobuf; + +import static com.google.protobuf.Internal.EMPTY_BYTE_BUFFER; + +import java.io.IOException; +import java.io.InputStream; +import java.nio.ByteBuffer; +import java.util.Iterator; + +class IterableByteBufferInputStream extends InputStream { + /** The {@link Iterator} with type {@link ByteBuffer} of {@code input} */ + private Iterator<ByteBuffer> iterator; + /** The current ByteBuffer; */ + private ByteBuffer currentByteBuffer; + /** The number of ByteBuffers in the input data. */ + private int dataSize; + /** + * Current {@code ByteBuffer}'s index + * + * <p>If index equals dataSize, then all the data in the InputStream has been consumed + */ + private int currentIndex; + /** The current position for current ByteBuffer */ + private int currentByteBufferPos; + /** Whether current ByteBuffer has an array */ + private boolean hasArray; + /** + * If the current ByteBuffer is unsafe-direct based, currentArray is null; otherwise should be the + * array inside ByteBuffer. + */ + private byte[] currentArray; + /** Current ByteBuffer's array offset */ + private int currentArrayOffset; + /** + * If the current ByteBuffer is unsafe-direct based, currentAddress is the start address of this + * ByteBuffer; otherwise should be zero. + */ + private long currentAddress; + + IterableByteBufferInputStream(Iterable<ByteBuffer> data) { + iterator = data.iterator(); + dataSize = 0; + for (ByteBuffer unused : data) { + dataSize++; + } + currentIndex = -1; + + if (!getNextByteBuffer()) { + currentByteBuffer = EMPTY_BYTE_BUFFER; + currentIndex = 0; + currentByteBufferPos = 0; + currentAddress = 0; + } + } + + private boolean getNextByteBuffer() { + currentIndex++; + if (!iterator.hasNext()) { + return false; + } + currentByteBuffer = iterator.next(); + currentByteBufferPos = currentByteBuffer.position(); + if (currentByteBuffer.hasArray()) { + hasArray = true; + currentArray = currentByteBuffer.array(); + currentArrayOffset = currentByteBuffer.arrayOffset(); + } else { + hasArray = false; + currentAddress = UnsafeUtil.addressOffset(currentByteBuffer); + currentArray = null; + } + return true; + } + + private void updateCurrentByteBufferPos(int numberOfBytesRead) { + currentByteBufferPos += numberOfBytesRead; + if (currentByteBufferPos == currentByteBuffer.limit()) { + getNextByteBuffer(); + } + } + + @Override + public int read() throws IOException { + if (currentIndex == dataSize) { + return -1; + } + if (hasArray) { + int result = currentArray[currentByteBufferPos + currentArrayOffset] & 0xFF; + updateCurrentByteBufferPos(1); + return result; + } else { + int result = UnsafeUtil.getByte(currentByteBufferPos + currentAddress) & 0xFF; + updateCurrentByteBufferPos(1); + return result; + } + } + + @Override + public int read(byte[] output, int offset, int length) throws IOException { + if (currentIndex == dataSize) { + return -1; + } + int remaining = currentByteBuffer.limit() - currentByteBufferPos; + if (length > remaining) { + length = remaining; + } + if (hasArray) { + System.arraycopy( + currentArray, currentByteBufferPos + currentArrayOffset, output, offset, length); + updateCurrentByteBufferPos(length); + } else { + int prevPos = currentByteBuffer.position(); + currentByteBuffer.position(currentByteBufferPos); + currentByteBuffer.get(output, offset, length); + currentByteBuffer.position(prevPos); + updateCurrentByteBufferPos(length); + } + return length; + } +} diff --git a/java/core/src/main/java/com/google/protobuf/LazyFieldLite.java b/java/core/src/main/java/com/google/protobuf/LazyFieldLite.java index 2febaace..49ecfc0b 100644 --- a/java/core/src/main/java/com/google/protobuf/LazyFieldLite.java +++ b/java/core/src/main/java/com/google/protobuf/LazyFieldLite.java @@ -284,29 +284,8 @@ public class LazyFieldLite { return; } - // At this point we have two fully parsed messages. We can't merge directly from one to the - // other because only generated builder code contains methods to mergeFrom another parsed - // message. We have to serialize one instance and then merge the bytes into the other. This may - // drop extensions from one of the messages if one of the values had an extension set on it - // directly. - // - // To mitigate this we prefer serializing a message that has an extension registry, and - // therefore a chance that all extensions set on it are in that registry. - // - // NOTE: The check for other.extensionRegistry not being null must come first because at this - // point in time if other.extensionRegistry is not null then this.extensionRegistry will not be - // null either. - if (other.extensionRegistry != null) { - setValue(mergeValueAndBytes(this.value, other.toByteString(), other.extensionRegistry)); - return; - } else if (this.extensionRegistry != null) { - setValue(mergeValueAndBytes(other.value, this.toByteString(), this.extensionRegistry)); - return; - } else { - // All extensions from the other message will be dropped because we have no registry. - setValue(mergeValueAndBytes(this.value, other.toByteString(), EMPTY_REGISTRY)); - return; - } + // At this point we have two fully parsed messages. + setValue(this.value.toBuilder().mergeFrom(other.value).build()); } /** @@ -415,6 +394,7 @@ public class LazyFieldLite { } } + /** * Might lazily parse the bytes that were previously passed in. Is thread-safe. */ diff --git a/java/core/src/main/java/com/google/protobuf/LongArrayList.java b/java/core/src/main/java/com/google/protobuf/LongArrayList.java index 5a772e3a..95945cb7 100644 --- a/java/core/src/main/java/com/google/protobuf/LongArrayList.java +++ b/java/core/src/main/java/com/google/protobuf/LongArrayList.java @@ -30,8 +30,9 @@ package com.google.protobuf; -import com.google.protobuf.Internal.LongList; +import static com.google.protobuf.Internal.checkNotNull; +import com.google.protobuf.Internal.LongList; import java.util.Arrays; import java.util.Collection; import java.util.RandomAccess; @@ -41,9 +42,8 @@ import java.util.RandomAccess; * * @author dweis@google.com (Daniel Weis) */ -final class LongArrayList - extends AbstractProtobufList<Long> - implements LongList, RandomAccess { +final class LongArrayList extends AbstractProtobufList<Long> + implements LongList, RandomAccess, PrimitiveNonBoxingCollection { private static final LongArrayList EMPTY_LIST = new LongArrayList(); static { @@ -82,6 +82,18 @@ final class LongArrayList } @Override + protected void removeRange(int fromIndex, int toIndex) { + ensureIsMutable(); + if (toIndex < fromIndex) { + throw new IndexOutOfBoundsException("toIndex < fromIndex"); + } + + System.arraycopy(array, toIndex, array, fromIndex, size - toIndex); + size -= (toIndex - fromIndex); + modCount++; + } + + @Override public boolean equals(Object o) { if (this == o) { return true; @@ -198,9 +210,7 @@ final class LongArrayList public boolean addAll(Collection<? extends Long> collection) { ensureIsMutable(); - if (collection == null) { - throw new NullPointerException(); - } + checkNotNull(collection); // We specialize when adding another LongArrayList to avoid boxing elements. if (!(collection instanceof LongArrayList)) { @@ -248,7 +258,9 @@ final class LongArrayList ensureIsMutable(); ensureIndexInRange(index); long value = array[index]; - System.arraycopy(array, index + 1, array, index, size - index); + if (index < size - 1) { + System.arraycopy(array, index + 1, array, index, size - index); + } size--; modCount++; return value; diff --git a/java/core/src/main/java/com/google/protobuf/MapEntry.java b/java/core/src/main/java/com/google/protobuf/MapEntry.java index 117cd911..0849b821 100644 --- a/java/core/src/main/java/com/google/protobuf/MapEntry.java +++ b/java/core/src/main/java/com/google/protobuf/MapEntry.java @@ -33,7 +33,6 @@ package com.google.protobuf; import com.google.protobuf.Descriptors.Descriptor; import com.google.protobuf.Descriptors.EnumValueDescriptor; import com.google.protobuf.Descriptors.FieldDescriptor; - import java.io.IOException; import java.util.Collections; import java.util.Map; @@ -89,6 +88,7 @@ public final class MapEntry<K, V> extends AbstractMessage { } /** Create a MapEntry with the provided key and value. */ + @SuppressWarnings("unchecked") private MapEntry(Metadata metadata, K key, V value) { this.key = key; this.value = value; @@ -109,7 +109,7 @@ public final class MapEntry<K, V> extends AbstractMessage { } catch (InvalidProtocolBufferException e) { throw e.setUnfinishedMessage(this); } catch (IOException e) { - throw new InvalidProtocolBufferException(e.getMessage()).setUnfinishedMessage(this); + throw new InvalidProtocolBufferException(e).setUnfinishedMessage(this); } } @@ -170,7 +170,7 @@ public final class MapEntry<K, V> extends AbstractMessage { @Override public Builder<K, V> toBuilder() { - return new Builder<K, V>(metadata, key, value); + return new Builder<K, V>(metadata, key, value, true, true); } @Override @@ -246,15 +246,19 @@ public final class MapEntry<K, V> extends AbstractMessage { private final Metadata<K, V> metadata; private K key; private V value; + private boolean hasKey; + private boolean hasValue; private Builder(Metadata<K, V> metadata) { - this(metadata, metadata.defaultKey, metadata.defaultValue); + this(metadata, metadata.defaultKey, metadata.defaultValue, false, false); } - private Builder(Metadata<K, V> metadata, K key, V value) { + private Builder(Metadata<K, V> metadata, K key, V value, boolean hasKey, boolean hasValue) { this.metadata = metadata; this.key = key; this.value = value; + this.hasKey = hasKey; + this.hasValue = hasValue; } public K getKey() { @@ -267,21 +271,25 @@ public final class MapEntry<K, V> extends AbstractMessage { public Builder<K, V> setKey(K key) { this.key = key; + this.hasKey = true; return this; } public Builder<K, V> clearKey() { this.key = metadata.defaultKey; + this.hasKey = false; return this; } public Builder<K, V> setValue(V value) { this.value = value; + this.hasValue = true; return this; } public Builder<K, V> clearValue() { this.value = metadata.defaultValue; + this.hasValue = false; return this; } @@ -334,6 +342,15 @@ public final class MapEntry<K, V> extends AbstractMessage { } else { if (field.getType() == FieldDescriptor.Type.ENUM) { value = ((EnumValueDescriptor) value).getNumber(); + } else if (field.getType() == FieldDescriptor.Type.MESSAGE) { + if (value != null && !metadata.defaultValue.getClass().isInstance(value)) { + // The value is not the exact right message type. However, if it + // is an alternative implementation of the same type -- e.g. a + // DynamicMessage -- we should accept it. In this case we can make + // a copy of the message. + value = + ((Message) metadata.defaultValue).toBuilder().mergeFrom((Message) value).build(); + } } setValue((V) value); } @@ -394,7 +411,7 @@ public final class MapEntry<K, V> extends AbstractMessage { @Override public boolean hasField(FieldDescriptor field) { checkFieldDescriptor(field); - return true; + return field.getNumber() == 1 ? hasKey : hasValue; } @Override @@ -426,8 +443,9 @@ public final class MapEntry<K, V> extends AbstractMessage { } @Override + @SuppressWarnings("unchecked") public Builder<K, V> clone() { - return new Builder(metadata, key, value); + return new Builder(metadata, key, value, hasKey, hasValue); } } @@ -437,4 +455,9 @@ public final class MapEntry<K, V> extends AbstractMessage { } return true; } + + /** Returns the metadata only for experimental runtime. */ + final Metadata<K, V> getMetadata() { + return metadata; + } } diff --git a/java/core/src/main/java/com/google/protobuf/MapEntryLite.java b/java/core/src/main/java/com/google/protobuf/MapEntryLite.java index 22aef8f9..dcb5dfad 100644 --- a/java/core/src/main/java/com/google/protobuf/MapEntryLite.java +++ b/java/core/src/main/java/com/google/protobuf/MapEntryLite.java @@ -223,4 +223,9 @@ public class MapEntryLite<K, V> { input.popLimit(oldLimit); map.put(key, value); } + + /** For experimental runtime internal use only. */ + Metadata<K, V> getMetadata() { + return metadata; + } } diff --git a/java/core/src/main/java/com/google/protobuf/MapField.java b/java/core/src/main/java/com/google/protobuf/MapField.java index a6109f98..ad8ceb02 100644 --- a/java/core/src/main/java/com/google/protobuf/MapField.java +++ b/java/core/src/main/java/com/google/protobuf/MapField.java @@ -30,6 +30,8 @@ package com.google.protobuf; +import static com.google.protobuf.Internal.checkNotNull; + import java.util.ArrayList; import java.util.Collection; import java.util.Collections; @@ -100,6 +102,7 @@ public class MapField<K, V> implements MutabilityOracle { } @Override + @SuppressWarnings("unchecked") public void convertMessageToKeyAndValue(Message message, Map<K, V> map) { MapEntry<K, V> entry = (MapEntry<K, V>) message; map.put(entry.getKey(), entry.getValue()); @@ -328,6 +331,8 @@ public class MapField<K, V> implements MutabilityOracle { @Override public V put(K key, V value) { mutabilityOracle.ensureMutable(); + checkNotNull(key); + checkNotNull(value); return delegate.put(key, value); } @@ -340,6 +345,10 @@ public class MapField<K, V> implements MutabilityOracle { @Override public void putAll(Map<? extends K, ? extends V> m) { mutabilityOracle.ensureMutable(); + for (K key : m.keySet()) { + checkNotNull(key); + checkNotNull(m.get(key)); + } delegate.putAll(m); } diff --git a/java/core/src/main/java/com/google/protobuf/MapFieldLite.java b/java/core/src/main/java/com/google/protobuf/MapFieldLite.java index 3c0ad89a..a8b3dd88 100644 --- a/java/core/src/main/java/com/google/protobuf/MapFieldLite.java +++ b/java/core/src/main/java/com/google/protobuf/MapFieldLite.java @@ -30,8 +30,9 @@ package com.google.protobuf; -import com.google.protobuf.Internal.EnumLite; +import static com.google.protobuf.Internal.checkNotNull; +import com.google.protobuf.Internal.EnumLite; import java.util.Arrays; import java.util.Collections; import java.util.LinkedHashMap; @@ -58,7 +59,7 @@ public final class MapFieldLite<K, V> extends LinkedHashMap<K, V> { } @SuppressWarnings({"rawtypes", "unchecked"}) - private static final MapFieldLite EMPTY_MAP_FIELD = new MapFieldLite(Collections.emptyMap()); + private static final MapFieldLite EMPTY_MAP_FIELD = new MapFieldLite(); static { EMPTY_MAP_FIELD.makeImmutable(); } @@ -83,11 +84,14 @@ public final class MapFieldLite<K, V> extends LinkedHashMap<K, V> { @Override public void clear() { ensureMutable(); - clear(); + super.clear(); } @Override public V put(K key, V value) { ensureMutable(); + checkNotNull(key); + + checkNotNull(value); return super.put(key, value); } @@ -97,6 +101,7 @@ public final class MapFieldLite<K, V> extends LinkedHashMap<K, V> { @Override public void putAll(Map<? extends K, ? extends V> m) { ensureMutable(); + checkForNullKeysAndValues(m); super.putAll(m); } @@ -105,6 +110,13 @@ public final class MapFieldLite<K, V> extends LinkedHashMap<K, V> { return super.remove(key); } + private static void checkForNullKeysAndValues(Map<?, ?> m) { + for (Object key : m.keySet()) { + checkNotNull(key); + checkNotNull(m.get(key)); + } + } + private static boolean equals(Object a, Object b) { if (a instanceof byte[] && b instanceof byte[]) { return Arrays.equals((byte[]) a, (byte[]) b); diff --git a/java/core/src/main/java/com/google/protobuf/Message.java b/java/core/src/main/java/com/google/protobuf/Message.java index 94590fb9..0770d417 100644 --- a/java/core/src/main/java/com/google/protobuf/Message.java +++ b/java/core/src/main/java/com/google/protobuf/Message.java @@ -125,7 +125,7 @@ public interface Message extends MessageLite, MessageOrBuilder { * it is merged into the corresponding sub-message of this message * using the same merging rules.<br> * * For repeated fields, the elements in {@code other} are concatenated - * with the elements in this message. + * with the elements in this message.<br> * * For oneof groups, if the other message has one of the fields set, * the group of this message is cleared and replaced by the field * of the other message, so that the oneof constraint is preserved. diff --git a/java/core/src/main/java/com/google/protobuf/MessageLiteToString.java b/java/core/src/main/java/com/google/protobuf/MessageLiteToString.java index 43847651..8e265935 100644 --- a/java/core/src/main/java/com/google/protobuf/MessageLiteToString.java +++ b/java/core/src/main/java/com/google/protobuf/MessageLiteToString.java @@ -31,6 +31,7 @@ package com.google.protobuf; import java.lang.reflect.Method; +import java.lang.reflect.Modifier; import java.util.HashMap; import java.util.Iterator; import java.util.List; @@ -38,20 +39,18 @@ import java.util.Map; import java.util.Set; import java.util.TreeSet; -/** - * Helps generate {@link String} representations of {@link MessageLite} protos. - */ -// TODO(dweis): Fix map fields. +/** Helps generate {@link String} representations of {@link MessageLite} protos. */ final class MessageLiteToString { private static final String LIST_SUFFIX = "List"; private static final String BUILDER_LIST_SUFFIX = "OrBuilderList"; + private static final String MAP_SUFFIX = "Map"; private static final String BYTES_SUFFIX = "Bytes"; - + /** - * Returns a {@link String} representation of the {@link MessageLite} object. The first line of + * Returns a {@link String} representation of the {@link MessageLite} object. The first line of * the {@code String} representation representation includes a comment string to uniquely identify - * the objcet instance. This acts as an indicator that this should not be relied on for + * the object instance. This acts as an indicator that this should not be relied on for * comparisons. * * <p>For use by generated code only. @@ -71,8 +70,9 @@ final class MessageLiteToString { */ private static void reflectivePrintWithIndent( MessageLite messageLite, StringBuilder buffer, int indent) { - // Build a map of method name to method. We're looking for methods like getFoo(), hasFoo(), and - // getFooList() which might be useful for building an object's string representation. + // Build a map of method name to method. We're looking for methods like getFoo(), hasFoo(), + // getFooList() and getFooMap() which might be useful for building an object's string + // representation. Map<String, Method> nameToNoArgMethod = new HashMap<String, Method>(); Map<String, Method> nameToMethod = new HashMap<String, Method>(); Set<String> getters = new TreeSet<String>(); @@ -89,13 +89,17 @@ final class MessageLiteToString { for (String getter : getters) { String suffix = getter.replaceFirst("get", ""); - if (suffix.endsWith(LIST_SUFFIX) && !suffix.endsWith(BUILDER_LIST_SUFFIX)) { - String camelCase = suffix.substring(0, 1).toLowerCase() - + suffix.substring(1, suffix.length() - LIST_SUFFIX.length()); + if (suffix.endsWith(LIST_SUFFIX) + && !suffix.endsWith(BUILDER_LIST_SUFFIX) + // Sometimes people have fields named 'list' that aren't repeated. + && !suffix.equals(LIST_SUFFIX)) { + String camelCase = + suffix.substring(0, 1).toLowerCase() + + suffix.substring(1, suffix.length() - LIST_SUFFIX.length()); // Try to reflectively get the value and toString() the field as if it were repeated. This - // only works if the method names have not be proguarded out or renamed. - Method listMethod = nameToNoArgMethod.get("get" + suffix); - if (listMethod != null) { + // only works if the method names have not been proguarded out or renamed. + Method listMethod = nameToNoArgMethod.get(getter); + if (listMethod != null && listMethod.getReturnType().equals(List.class)) { printField( buffer, indent, @@ -104,6 +108,30 @@ final class MessageLiteToString { continue; } } + if (suffix.endsWith(MAP_SUFFIX) + // Sometimes people have fields named 'map' that aren't maps. + && !suffix.equals(MAP_SUFFIX)) { + String camelCase = + suffix.substring(0, 1).toLowerCase() + + suffix.substring(1, suffix.length() - MAP_SUFFIX.length()); + // Try to reflectively get the value and toString() the field as if it were a map. This only + // works if the method names have not been proguarded out or renamed. + Method mapMethod = nameToNoArgMethod.get(getter); + if (mapMethod != null + && mapMethod.getReturnType().equals(Map.class) + // Skip the deprecated getter method with no prefix "Map" when the field name ends with + // "map". + && !mapMethod.isAnnotationPresent(Deprecated.class) + // Skip the internal mutable getter method. + && Modifier.isPublic(mapMethod.getModifiers())) { + printField( + buffer, + indent, + camelCaseToSnakeCase(camelCase), + GeneratedMessageLite.invokeOrDie(mapMethod, messageLite)); + continue; + } + } Method setter = nameToMethod.get("set" + suffix); if (setter == null) { @@ -115,26 +143,23 @@ final class MessageLiteToString { // Heuristic to skip bytes based accessors for string fields. continue; } - + String camelCase = suffix.substring(0, 1).toLowerCase() + suffix.substring(1); // Try to reflectively get the value and toString() the field as if it were optional. This - // only works if the method names have not be proguarded out or renamed. + // only works if the method names have not been proguarded out or renamed. Method getMethod = nameToNoArgMethod.get("get" + suffix); Method hasMethod = nameToNoArgMethod.get("has" + suffix); // TODO(dweis): Fix proto3 semantics. if (getMethod != null) { Object value = GeneratedMessageLite.invokeOrDie(getMethod, messageLite); - final boolean hasValue = hasMethod == null - ? !isDefaultValue(value) - : (Boolean) GeneratedMessageLite.invokeOrDie(hasMethod, messageLite); - // TODO(dweis): This doesn't stop printing oneof case twice: value and enum style. + final boolean hasValue = + hasMethod == null + ? !isDefaultValue(value) + : (Boolean) GeneratedMessageLite.invokeOrDie(hasMethod, messageLite); + // TODO(dweis): This doesn't stop printing oneof case twice: value and enum style. if (hasValue) { - printField( - buffer, - indent, - camelCaseToSnakeCase(camelCase), - value); + printField(buffer, indent, camelCaseToSnakeCase(camelCase), value); } continue; } @@ -153,7 +178,7 @@ final class MessageLiteToString { ((GeneratedMessageLite<?, ?>) messageLite).unknownFields.printWithIndent(buffer, indent); } } - + private static boolean isDefaultValue(Object o) { if (o instanceof Boolean) { return !((Boolean) o); @@ -179,7 +204,7 @@ final class MessageLiteToString { if (o instanceof java.lang.Enum<?>) { // Catches oneof enums. return ((java.lang.Enum<?>) o).ordinal() == 0; } - + return false; } @@ -201,6 +226,13 @@ final class MessageLiteToString { } return; } + if (object instanceof Map<?, ?>) { + Map<?, ?> map = (Map<?, ?>) object; + for (Map.Entry<?, ?> entry : map.entrySet()) { + printField(buffer, indent, name, entry); + } + return; + } buffer.append('\n'); for (int i = 0; i < indent; i++) { @@ -220,11 +252,21 @@ final class MessageLiteToString { buffer.append(' '); } buffer.append("}"); + } else if (object instanceof Map.Entry<?, ?>) { + buffer.append(" {"); + Map.Entry<?, ?> entry = (Map.Entry<?, ?>) object; + printField(buffer, indent + 2, "key", entry.getKey()); + printField(buffer, indent + 2, "value", entry.getValue()); + buffer.append("\n"); + for (int i = 0; i < indent; i++) { + buffer.append(' '); + } + buffer.append("}"); } else { buffer.append(": ").append(object.toString()); } } - + private static final String camelCaseToSnakeCase(String camelCase) { StringBuilder builder = new StringBuilder(); for (int i = 0; i < camelCase.length(); i++) { diff --git a/java/core/src/main/java/com/google/protobuf/MessageReflection.java b/java/core/src/main/java/com/google/protobuf/MessageReflection.java index 3d73efb3..69ad7ddf 100644 --- a/java/core/src/main/java/com/google/protobuf/MessageReflection.java +++ b/java/core/src/main/java/com/google/protobuf/MessageReflection.java @@ -31,7 +31,6 @@ package com.google.protobuf; import com.google.protobuf.Descriptors.FieldDescriptor; - import java.io.IOException; import java.util.ArrayList; import java.util.List; @@ -714,12 +713,14 @@ class MessageReflection { } /** - * Parses a single field into MergeTarget. The target can be Message.Builder, - * FieldSet or MutableMessage. + * Parses a single field into MergeTarget. The target can be Message.Builder, FieldSet or + * MutableMessage. * - * Package-private because it is used by GeneratedMessage.ExtendableMessage. + * <p>Package-private because it is used by GeneratedMessage.ExtendableMessage. * * @param tag The tag, which should have already been read. + * @param unknownFields If not null, unknown fields will be merged to this {@link + * UnknownFieldSet}, otherwise unknown fields will be discarded. * @return {@code true} unless the tag is an end-group tag. */ static boolean mergeFieldFrom( @@ -728,7 +729,8 @@ class MessageReflection { ExtensionRegistryLite extensionRegistry, Descriptors.Descriptor type, MergeTarget target, - int tag) throws IOException { + int tag) + throws IOException { if (type.getOptions().getMessageSetWireFormat() && tag == WireFormat.MESSAGE_SET_ITEM_TAG) { mergeMessageSetExtensionFromCodedStream( @@ -792,7 +794,11 @@ class MessageReflection { } if (unknown) { // Unknown field or wrong wire type. Skip. - return unknownFields.mergeFieldFrom(tag, input); + if (unknownFields != null) { + return unknownFields.mergeFieldFrom(tag, input); + } else { + return input.skipField(tag); + } } if (packed) { @@ -844,7 +850,9 @@ class MessageReflection { // If the number isn't recognized as a valid value for this enum, // drop it. if (value == null) { - unknownFields.mergeVarintField(fieldNumber, rawValue); + if (unknownFields != null) { + unknownFields.mergeVarintField(fieldNumber, rawValue); + } return true; } } @@ -947,7 +955,7 @@ class MessageReflection { mergeMessageSetExtensionFromBytes( rawBytes, extension, extensionRegistry, target); } else { // We don't know how to parse this. Ignore it. - if (rawBytes != null) { + if (rawBytes != null && unknownFields != null) { unknownFields.mergeField(typeId, UnknownFieldSet.Field.newBuilder() .addLengthDelimited(rawBytes).build()); } diff --git a/java/core/src/main/java/com/google/protobuf/NioByteString.java b/java/core/src/main/java/com/google/protobuf/NioByteString.java index 6163c7b1..76594809 100644 --- a/java/core/src/main/java/com/google/protobuf/NioByteString.java +++ b/java/core/src/main/java/com/google/protobuf/NioByteString.java @@ -30,6 +30,8 @@ package com.google.protobuf; +import static com.google.protobuf.Internal.checkNotNull; + import java.io.IOException; import java.io.InputStream; import java.io.InvalidObjectException; @@ -49,10 +51,9 @@ final class NioByteString extends ByteString.LeafByteString { private final ByteBuffer buffer; NioByteString(ByteBuffer buffer) { - if (buffer == null) { - throw new NullPointerException("buffer"); - } + checkNotNull(buffer, "buffer"); + // Use native byte order for fast fixed32/64 operations. this.buffer = buffer.slice().order(ByteOrder.nativeOrder()); } @@ -266,7 +267,7 @@ final class NioByteString extends ByteString.LeafByteString { @Override public CodedInputStream newCodedInput() { - return CodedInputStream.newInstance(buffer); + return CodedInputStream.newInstance(buffer, true); } /** diff --git a/java/core/src/main/java/com/google/protobuf/Parser.java b/java/core/src/main/java/com/google/protobuf/Parser.java index cfbcb442..e07c6895 100644 --- a/java/core/src/main/java/com/google/protobuf/Parser.java +++ b/java/core/src/main/java/com/google/protobuf/Parser.java @@ -31,6 +31,7 @@ package com.google.protobuf; import java.io.InputStream; +import java.nio.ByteBuffer; /** * Abstract interface for parsing Protocol Messages. @@ -93,6 +94,18 @@ public interface Parser<MessageType> { // Convenience methods. /** + * Parses {@code data} as a message of {@code MessageType}. This is just a small wrapper around + * {@link #parseFrom(CodedInputStream)}. + */ + public MessageType parseFrom(ByteBuffer data) throws InvalidProtocolBufferException; + + /** + * Parses {@code data} as a message of {@code MessageType}. This is just a small wrapper around + * {@link #parseFrom(CodedInputStream, ExtensionRegistryLite)}. + */ + public MessageType parseFrom(ByteBuffer data, ExtensionRegistryLite extensionRegistry) + throws InvalidProtocolBufferException; + /** * Parses {@code data} as a message of {@code MessageType}. * This is just a small wrapper around {@link #parseFrom(CodedInputStream)}. */ diff --git a/java/core/src/main/java/com/google/protobuf/PrimitiveNonBoxingCollection.java b/java/core/src/main/java/com/google/protobuf/PrimitiveNonBoxingCollection.java new file mode 100644 index 00000000..79b5769d --- /dev/null +++ b/java/core/src/main/java/com/google/protobuf/PrimitiveNonBoxingCollection.java @@ -0,0 +1,34 @@ +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// https://developers.google.com/protocol-buffers/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +package com.google.protobuf; + +/** A marker interface indicating that the collection supports primitives and is non-boxing. */ +interface PrimitiveNonBoxingCollection {} diff --git a/java/core/src/main/java/com/google/protobuf/RepeatedFieldBuilderV3.java b/java/core/src/main/java/com/google/protobuf/RepeatedFieldBuilderV3.java index 77b61b5f..30c991d4 100644 --- a/java/core/src/main/java/com/google/protobuf/RepeatedFieldBuilderV3.java +++ b/java/core/src/main/java/com/google/protobuf/RepeatedFieldBuilderV3.java @@ -30,6 +30,8 @@ package com.google.protobuf; +import static com.google.protobuf.Internal.checkNotNull; + import java.util.AbstractList; import java.util.ArrayList; import java.util.Collection; @@ -290,9 +292,7 @@ public class RepeatedFieldBuilderV3 */ public RepeatedFieldBuilderV3<MType, BType, IType> setMessage( int index, MType message) { - if (message == null) { - throw new NullPointerException(); - } + checkNotNull(message); ensureMutableMessageList(); messages.set(index, message); if (builders != null) { @@ -315,9 +315,7 @@ public class RepeatedFieldBuilderV3 */ public RepeatedFieldBuilderV3<MType, BType, IType> addMessage( MType message) { - if (message == null) { - throw new NullPointerException(); - } + checkNotNull(message); ensureMutableMessageList(); messages.add(message); if (builders != null) { @@ -339,9 +337,7 @@ public class RepeatedFieldBuilderV3 */ public RepeatedFieldBuilderV3<MType, BType, IType> addMessage( int index, MType message) { - if (message == null) { - throw new NullPointerException(); - } + checkNotNull(message); ensureMutableMessageList(); messages.add(index, message); if (builders != null) { @@ -363,9 +359,7 @@ public class RepeatedFieldBuilderV3 public RepeatedFieldBuilderV3<MType, BType, IType> addAllMessages( Iterable<? extends MType> values) { for (final MType value : values) { - if (value == null) { - throw new NullPointerException(); - } + checkNotNull(value); } // If we can inspect the size, we can more efficiently add messages. diff --git a/java/core/src/main/java/com/google/protobuf/RopeByteString.java b/java/core/src/main/java/com/google/protobuf/RopeByteString.java index 3f3e9bd1..6fa555df 100644 --- a/java/core/src/main/java/com/google/protobuf/RopeByteString.java +++ b/java/core/src/main/java/com/google/protobuf/RopeByteString.java @@ -406,6 +406,7 @@ final class RopeByteString extends ByteString { right.writeTo(output); } + @Override protected String toStringInternal(Charset charset) { return new String(toByteArray(), charset); diff --git a/java/core/src/main/java/com/google/protobuf/SingleFieldBuilderV3.java b/java/core/src/main/java/com/google/protobuf/SingleFieldBuilderV3.java index fb1f76a7..8ab0f26d 100644 --- a/java/core/src/main/java/com/google/protobuf/SingleFieldBuilderV3.java +++ b/java/core/src/main/java/com/google/protobuf/SingleFieldBuilderV3.java @@ -30,6 +30,8 @@ package com.google.protobuf; +import static com.google.protobuf.Internal.checkNotNull; + /** * {@code SingleFieldBuilderV3} implements a structure that a protocol * message uses to hold a single field of another protocol message. It supports @@ -84,10 +86,7 @@ public class SingleFieldBuilderV3 MType message, AbstractMessage.BuilderParent parent, boolean isClean) { - if (message == null) { - throw new NullPointerException(); - } - this.message = message; + this.message = checkNotNull(message); this.parent = parent; this.isClean = isClean; } @@ -169,10 +168,7 @@ public class SingleFieldBuilderV3 */ public SingleFieldBuilderV3<MType, BType, IType> setMessage( MType message) { - if (message == null) { - throw new NullPointerException(); - } - this.message = message; + this.message = checkNotNull(message); if (builder != null) { builder.dispose(); builder = null; diff --git a/java/core/src/main/java/com/google/protobuf/SmallSortedMap.java b/java/core/src/main/java/com/google/protobuf/SmallSortedMap.java index 409fec10..279edc4d 100644 --- a/java/core/src/main/java/com/google/protobuf/SmallSortedMap.java +++ b/java/core/src/main/java/com/google/protobuf/SmallSortedMap.java @@ -197,6 +197,7 @@ class SmallSortedMap<K extends Comparable<K>, V> extends AbstractMap<K, V> { overflowEntries.entrySet(); } + @Override public int size() { return entryList.size() + overflowEntries.size(); @@ -356,6 +357,7 @@ class SmallSortedMap<K extends Comparable<K>, V> extends AbstractMap<K, V> { return lazyEntrySet; } + /** * @throws UnsupportedOperationException if {@link #makeImmutable()} has * has been called. @@ -525,6 +527,7 @@ class SmallSortedMap<K extends Comparable<K>, V> extends AbstractMap<K, V> { } } + /** * Iterator implementation that switches from the entry array to the overflow * entries appropriately. @@ -537,8 +540,8 @@ class SmallSortedMap<K extends Comparable<K>, V> extends AbstractMap<K, V> { @Override public boolean hasNext() { - return (pos + 1) < entryList.size() || - getOverflowIterator().hasNext(); + return (pos + 1) < entryList.size() + || (!overflowEntries.isEmpty() && getOverflowIterator().hasNext()); } @Override @@ -617,43 +620,43 @@ class SmallSortedMap<K extends Comparable<K>, V> extends AbstractMap<K, V> { return (Iterable<T>) ITERABLE; } } - + @Override public boolean equals(Object o) { if (this == o) { return true; } - + if (!(o instanceof SmallSortedMap)) { return super.equals(o); } - + SmallSortedMap<?, ?> other = (SmallSortedMap<?, ?>) o; final int size = size(); if (size != other.size()) { return false; } - + // Best effort try to avoid allocating an entry set. final int numArrayEntries = getNumArrayEntries(); if (numArrayEntries != other.getNumArrayEntries()) { return entrySet().equals(other.entrySet()); } - + for (int i = 0; i < numArrayEntries; i++) { if (!getArrayEntryAt(i).equals(other.getArrayEntryAt(i))) { return false; } } - + if (numArrayEntries != size) { return overflowEntries.equals(other.overflowEntries); } - - + + return true; } - + @Override public int hashCode() { int h = 0; diff --git a/java/core/src/main/java/com/google/protobuf/TextFormat.java b/java/core/src/main/java/com/google/protobuf/TextFormat.java index ff13675d..25c3474f 100644 --- a/java/core/src/main/java/com/google/protobuf/TextFormat.java +++ b/java/core/src/main/java/com/google/protobuf/TextFormat.java @@ -34,7 +34,6 @@ import com.google.protobuf.Descriptors.Descriptor; import com.google.protobuf.Descriptors.EnumDescriptor; import com.google.protobuf.Descriptors.EnumValueDescriptor; import com.google.protobuf.Descriptors.FieldDescriptor; - import java.io.IOException; import java.math.BigInteger; import java.nio.CharBuffer; @@ -56,14 +55,7 @@ import java.util.regex.Pattern; public final class TextFormat { private TextFormat() {} - private static final Logger logger = - Logger.getLogger(TextFormat.class.getName()); - - private static final Printer DEFAULT_PRINTER = new Printer(); - private static final Printer SINGLE_LINE_PRINTER = - (new Printer()).setSingleLineMode(true); - private static final Printer UNICODE_PRINTER = - (new Printer()).setEscapeNonAscii(false); + private static final Logger logger = Logger.getLogger(TextFormat.class.getName()); /** * Outputs a textual representation of the Protocol Message supplied into @@ -73,14 +65,14 @@ public final class TextFormat { public static void print( final MessageOrBuilder message, final Appendable output) throws IOException { - DEFAULT_PRINTER.print(message, new TextGenerator(output)); + Printer.DEFAULT.print(message, multiLineOutput(output)); } /** Outputs a textual representation of {@code fields} to {@code output}. */ public static void print(final UnknownFieldSet fields, final Appendable output) throws IOException { - DEFAULT_PRINTER.printUnknownFields(fields, new TextGenerator(output)); + Printer.DEFAULT.printUnknownFields(fields, multiLineOutput(output)); } /** @@ -90,7 +82,7 @@ public final class TextFormat { public static void printUnicode( final MessageOrBuilder message, final Appendable output) throws IOException { - UNICODE_PRINTER.print(message, new TextGenerator(output)); + Printer.UNICODE.print(message, multiLineOutput(output)); } /** @@ -100,7 +92,7 @@ public final class TextFormat { public static void printUnicode(final UnknownFieldSet fields, final Appendable output) throws IOException { - UNICODE_PRINTER.printUnknownFields(fields, new TextGenerator(output)); + Printer.UNICODE.printUnknownFields(fields, multiLineOutput(output)); } /** @@ -109,10 +101,9 @@ public final class TextFormat { */ public static String shortDebugString(final MessageOrBuilder message) { try { - final StringBuilder sb = new StringBuilder(); - SINGLE_LINE_PRINTER.print(message, new TextGenerator(sb)); - // Single line mode currently might have an extra space at the end. - return sb.toString().trim(); + final StringBuilder text = new StringBuilder(); + Printer.DEFAULT.print(message, singleLineOutput(text)); + return text.toString(); } catch (IOException e) { throw new IllegalStateException(e); } @@ -125,11 +116,11 @@ public final class TextFormat { public static String shortDebugString(final FieldDescriptor field, final Object value) { try { - final StringBuilder sb = new StringBuilder(); - SINGLE_LINE_PRINTER.printField(field, value, new TextGenerator(sb)); - return sb.toString().trim(); + final StringBuilder text = new StringBuilder(); + Printer.DEFAULT.printField(field, value, singleLineOutput(text)); + return text.toString(); } catch (IOException e) { - throw new IllegalStateException(e); + throw new IllegalStateException(e); } } @@ -139,10 +130,9 @@ public final class TextFormat { */ public static String shortDebugString(final UnknownFieldSet fields) { try { - final StringBuilder sb = new StringBuilder(); - SINGLE_LINE_PRINTER.printUnknownFields(fields, new TextGenerator(sb)); - // Single line mode currently might have an extra space at the end. - return sb.toString().trim(); + final StringBuilder text = new StringBuilder(); + Printer.DEFAULT.printUnknownFields(fields, singleLineOutput(text)); + return text.toString(); } catch (IOException e) { throw new IllegalStateException(e); } @@ -183,7 +173,7 @@ public final class TextFormat { public static String printToUnicodeString(final MessageOrBuilder message) { try { final StringBuilder text = new StringBuilder(); - UNICODE_PRINTER.print(message, new TextGenerator(text)); + Printer.UNICODE.print(message, multiLineOutput(text)); return text.toString(); } catch (IOException e) { throw new IllegalStateException(e); @@ -197,7 +187,7 @@ public final class TextFormat { public static String printToUnicodeString(final UnknownFieldSet fields) { try { final StringBuilder text = new StringBuilder(); - UNICODE_PRINTER.printUnknownFields(fields, new TextGenerator(text)); + Printer.UNICODE.printUnknownFields(fields, multiLineOutput(text)); return text.toString(); } catch (IOException e) { throw new IllegalStateException(e); @@ -208,7 +198,7 @@ public final class TextFormat { final Object value, final Appendable output) throws IOException { - DEFAULT_PRINTER.printField(field, value, new TextGenerator(output)); + Printer.DEFAULT.printField(field, value, multiLineOutput(output)); } public static String printFieldToString(final FieldDescriptor field, @@ -223,6 +213,23 @@ public final class TextFormat { } /** + * Outputs a unicode textual representation of the value of given field value. + * + * <p>Same as {@code printFieldValue()}, except that non-ASCII characters in string type fields + * are not escaped in backslash+octals. + * + * @param field the descriptor of the field + * @param value the value of the field + * @param output the output to which to append the formatted value + * @throws ClassCastException if the value is not appropriate for the given field descriptor + * @throws IOException if there is an exception writing to the output + */ + public static void printUnicodeFieldValue( + final FieldDescriptor field, final Object value, final Appendable output) throws IOException { + Printer.UNICODE.printFieldValue(field, value, multiLineOutput(output)); + } + + /** * Outputs a textual representation of the value of given field value. * * @param field the descriptor of the field @@ -236,7 +243,7 @@ public final class TextFormat { final Object value, final Appendable output) throws IOException { - DEFAULT_PRINTER.printFieldValue(field, value, new TextGenerator(output)); + Printer.DEFAULT.printFieldValue(field, value, multiLineOutput(output)); } /** @@ -253,7 +260,7 @@ public final class TextFormat { final Object value, final Appendable output) throws IOException { - printUnknownFieldValue(tag, value, new TextGenerator(output)); + printUnknownFieldValue(tag, value, multiLineOutput(output)); } private static void printUnknownFieldValue(final int tag, @@ -272,12 +279,24 @@ public final class TextFormat { generator.print(String.format((Locale) null, "0x%016x", (Long) value)); break; case WireFormat.WIRETYPE_LENGTH_DELIMITED: - generator.print("\""); - generator.print(escapeBytes((ByteString) value)); - generator.print("\""); + try { + // Try to parse and print the field as an embedded message + UnknownFieldSet message = UnknownFieldSet.parseFrom((ByteString) value); + generator.print("{"); + generator.eol(); + generator.indent(); + Printer.DEFAULT.printUnknownFields(message, generator); + generator.outdent(); + generator.print("}"); + } catch (InvalidProtocolBufferException e) { + // If not parseable as a message, print as a String + generator.print("\""); + generator.print(escapeBytes((ByteString) value)); + generator.print("\""); + } break; case WireFormat.WIRETYPE_START_GROUP: - DEFAULT_PRINTER.printUnknownFields((UnknownFieldSet) value, generator); + Printer.DEFAULT.printUnknownFields((UnknownFieldSet) value, generator); break; default: throw new IllegalArgumentException("Bad tag: " + tag); @@ -286,24 +305,16 @@ public final class TextFormat { /** Helper class for converting protobufs to text. */ private static final class Printer { - /** Whether to omit newlines from the output. */ - boolean singleLineMode = false; + // Printer instance which escapes non-ASCII characters. + static final Printer DEFAULT = new Printer(true); + // Printer instance which emits Unicode (it still escapes newlines and quotes in strings). + static final Printer UNICODE = new Printer(false); /** Whether to escape non ASCII characters with backslash and octal. */ - boolean escapeNonAscii = true; + private final boolean escapeNonAscii; - private Printer() {} - - /** Setter of singleLineMode */ - private Printer setSingleLineMode(boolean singleLineMode) { - this.singleLineMode = singleLineMode; - return this; - } - - /** Setter of escapeNonAscii */ - private Printer setEscapeNonAscii(boolean escapeNonAscii) { + private Printer(boolean escapeNonAscii) { this.escapeNonAscii = escapeNonAscii; - return this; } private void print( @@ -355,12 +366,9 @@ public final class TextFormat { } if (field.getJavaType() == FieldDescriptor.JavaType.MESSAGE) { - if (singleLineMode) { - generator.print(" { "); - } else { - generator.print(" {\n"); - generator.indent(); - } + generator.print(" {"); + generator.eol(); + generator.indent(); } else { generator.print(": "); } @@ -368,19 +376,10 @@ public final class TextFormat { printFieldValue(field, value, generator); if (field.getJavaType() == FieldDescriptor.JavaType.MESSAGE) { - if (singleLineMode) { - generator.print("} "); - } else { - generator.outdent(); - generator.print("}\n"); - } - } else { - if (singleLineMode) { - generator.print(" "); - } else { - generator.print("\n"); - } + generator.outdent(); + generator.print("}"); } + generator.eol(); } private void printFieldValue(final FieldDescriptor field, @@ -469,19 +468,13 @@ public final class TextFormat { field.getLengthDelimitedList(), generator); for (final UnknownFieldSet value : field.getGroupList()) { generator.print(entry.getKey().toString()); - if (singleLineMode) { - generator.print(" { "); - } else { - generator.print(" {\n"); - generator.indent(); - } + generator.print(" {"); + generator.eol(); + generator.indent(); printUnknownFields(value, generator); - if (singleLineMode) { - generator.print("} "); - } else { - generator.outdent(); - generator.print("}\n"); - } + generator.outdent(); + generator.print("}"); + generator.eol(); } } } @@ -495,7 +488,7 @@ public final class TextFormat { generator.print(String.valueOf(number)); generator.print(": "); printUnknownFieldValue(wireType, value, generator); - generator.print(singleLineMode ? " " : "\n"); + generator.eol(); } } } @@ -521,16 +514,29 @@ public final class TextFormat { } } - /** + private static TextGenerator multiLineOutput(Appendable output) { + return new TextGenerator(output, false); + } + + private static TextGenerator singleLineOutput(Appendable output) { + return new TextGenerator(output, true); + } + + /** * An inner class for writing text to the output stream. */ private static final class TextGenerator { private final Appendable output; private final StringBuilder indent = new StringBuilder(); - private boolean atStartOfLine = true; + private final boolean singleLineMode; + // While technically we are "at the start of a line" at the very beginning of the output, all + // we would do in response to this is emit the (zero length) indentation, so it has no effect. + // Setting it false here does however suppress an unwanted leading space in single-line mode. + private boolean atStartOfLine = false; - private TextGenerator(final Appendable output) { + private TextGenerator(final Appendable output, boolean singleLineMode) { this.output = output; + this.singleLineMode = singleLineMode; } /** @@ -552,35 +558,31 @@ public final class TextFormat { throw new IllegalArgumentException( " Outdent() without matching Indent()."); } - indent.delete(length - 2, length); + indent.setLength(length - 2); } /** - * Print text to the output stream. + * Print text to the output stream. Bare newlines are never expected to be passed to this + * method; to indicate the end of a line, call "eol()". */ public void print(final CharSequence text) throws IOException { - final int size = text.length(); - int pos = 0; - - for (int i = 0; i < size; i++) { - if (text.charAt(i) == '\n') { - write(text.subSequence(pos, i + 1)); - pos = i + 1; - atStartOfLine = true; - } + if (atStartOfLine) { + atStartOfLine = false; + output.append(singleLineMode ? " " : indent); } - write(text.subSequence(pos, size)); + output.append(text); } - private void write(final CharSequence data) throws IOException { - if (data.length() == 0) { - return; - } - if (atStartOfLine) { - atStartOfLine = false; - output.append(indent); + /** + * Signifies reaching the "end of the current line" in the output. In single-line mode, this + * does not result in a newline being emitted, but ensures that a separating space is written + * before the next output. + */ + public void eol() throws IOException { + if (!singleLineMode) { + output.append("\n"); } - output.append(data); + atStartOfLine = true; } } @@ -985,7 +987,7 @@ public final class TextFormat { nextToken(); return false; } else { - throw parseException("Expected \"true\" or \"false\"."); + throw parseException("Expected \"true\" or \"false\". Found \"" + currentToken + "\"."); } } @@ -1222,6 +1224,22 @@ public final class TextFormat { } /** + * Parse a text-format message from {@code input}. + * + * @return the parsed message, guaranteed initialized + */ + public static <T extends Message> T parse(final CharSequence input, + final Class<T> protoClass) + throws ParseException { + Message.Builder builder = + Internal.getDefaultInstance(protoClass).newBuilderForType(); + merge(input, builder); + @SuppressWarnings("unchecked") + T output = (T) builder.build(); + return output; + } + + /** * Parse a text-format message from {@code input} and merge the contents * into {@code builder}. Extensions will be recognized if they are * registered in {@code extensionRegistry}. @@ -1246,6 +1264,25 @@ public final class TextFormat { PARSER.merge(input, extensionRegistry, builder); } + /** + * Parse a text-format message from {@code input}. Extensions will be + * recognized if they are registered in {@code extensionRegistry}. + * + * @return the parsed message, guaranteed initialized + */ + public static <T extends Message> T parse( + final CharSequence input, + final ExtensionRegistry extensionRegistry, + final Class<T> protoClass) + throws ParseException { + Message.Builder builder = + Internal.getDefaultInstance(protoClass).newBuilderForType(); + merge(input, extensionRegistry, builder); + @SuppressWarnings("unchecked") + T output = (T) builder.build(); + return output; + } + /** * Parser for text-format proto2 instances. This class is thread-safe. @@ -1274,13 +1311,17 @@ public final class TextFormat { } private final boolean allowUnknownFields; + private final boolean allowUnknownEnumValues; private final SingularOverwritePolicy singularOverwritePolicy; private TextFormatParseInfoTree.Builder parseInfoTreeBuilder; private Parser( - boolean allowUnknownFields, SingularOverwritePolicy singularOverwritePolicy, + boolean allowUnknownFields, + boolean allowUnknownEnumValues, + SingularOverwritePolicy singularOverwritePolicy, TextFormatParseInfoTree.Builder parseInfoTreeBuilder) { this.allowUnknownFields = allowUnknownFields; + this.allowUnknownEnumValues = allowUnknownEnumValues; this.singularOverwritePolicy = singularOverwritePolicy; this.parseInfoTreeBuilder = parseInfoTreeBuilder; } @@ -1297,6 +1338,7 @@ public final class TextFormat { */ public static class Builder { private boolean allowUnknownFields = false; + private boolean allowUnknownEnumValues = false; private SingularOverwritePolicy singularOverwritePolicy = SingularOverwritePolicy.ALLOW_SINGULAR_OVERWRITES; private TextFormatParseInfoTree.Builder parseInfoTreeBuilder = null; @@ -1318,7 +1360,10 @@ public final class TextFormat { public Parser build() { return new Parser( - allowUnknownFields, singularOverwritePolicy, parseInfoTreeBuilder); + allowUnknownFields, + allowUnknownEnumValues, + singularOverwritePolicy, + parseInfoTreeBuilder); } } @@ -1382,7 +1427,7 @@ public final class TextFormat { return text; } - // Check both unknown fields and unknown extensions and log warming messages + // Check both unknown fields and unknown extensions and log warning messages // or throw exceptions according to the flag. private void checkUnknownFields(final List<String> unknownFields) throws ParseException { @@ -1442,7 +1487,7 @@ public final class TextFormat { /** * Parse a single field from {@code tokenizer} and merge it into - * {@code builder}. + * {@code target}. */ private void mergeField(final Tokenizer tokenizer, final ExtensionRegistry extensionRegistry, @@ -1469,9 +1514,15 @@ public final class TextFormat { extensionRegistry, name.toString()); if (extension == null) { - unknownFields.add((tokenizer.getPreviousLine() + 1) + ":" + - (tokenizer.getPreviousColumn() + 1) + ":\t" + - type.getFullName() + ".[" + name + "]"); + unknownFields.add( + (tokenizer.getPreviousLine() + 1) + + ":" + + (tokenizer.getPreviousColumn() + 1) + + ":\t" + + type.getFullName() + + ".[" + + name + + "]"); } else { if (extension.descriptor.getContainingType() != type) { throw tokenizer.parseExceptionPreviousToken( @@ -1506,9 +1557,14 @@ public final class TextFormat { } if (field == null) { - unknownFields.add((tokenizer.getPreviousLine() + 1) + ":" + - (tokenizer.getPreviousColumn() + 1) + ":\t" + - type.getFullName() + "." + name); + unknownFields.add( + (tokenizer.getPreviousLine() + 1) + + ":" + + (tokenizer.getPreviousColumn() + 1) + + ":\t" + + type.getFullName() + + "." + + name); } } @@ -1576,14 +1632,22 @@ public final class TextFormat { // Support specifying repeated field values as a comma-separated list. // Ex."foo: [1, 2, 3]" if (field.isRepeated() && tokenizer.tryConsume("[")) { - while (true) { - consumeFieldValue(tokenizer, extensionRegistry, target, field, extension, - parseTreeBuilder, unknownFields); - if (tokenizer.tryConsume("]")) { - // End of list. - break; + if (!tokenizer.tryConsume("]")) { // Allow "foo: []" to be treated as empty. + while (true) { + consumeFieldValue( + tokenizer, + extensionRegistry, + target, + field, + extension, + parseTreeBuilder, + unknownFields); + if (tokenizer.tryConsume("]")) { + // End of list. + break; + } + tokenizer.consume(","); } - tokenizer.consume(","); } } else { consumeFieldValue(tokenizer, extensionRegistry, target, field, @@ -1681,17 +1745,40 @@ public final class TextFormat { final int number = tokenizer.consumeInt32(); value = enumType.findValueByNumber(number); if (value == null) { - throw tokenizer.parseExceptionPreviousToken( - "Enum type \"" + enumType.getFullName() - + "\" has no value with number " + number + '.'); + String unknownValueMsg = + "Enum type \"" + + enumType.getFullName() + + "\" has no value with number " + + number + + '.'; + if (allowUnknownEnumValues) { + logger.warning(unknownValueMsg); + return; + } else { + throw tokenizer.parseExceptionPreviousToken( + "Enum type \"" + + enumType.getFullName() + + "\" has no value with number " + + number + + '.'); + } } } else { final String id = tokenizer.consumeIdentifier(); value = enumType.findValueByName(id); if (value == null) { - throw tokenizer.parseExceptionPreviousToken( - "Enum type \"" + enumType.getFullName() - + "\" has no value named \"" + id + "\"."); + String unknownValueMsg = + "Enum type \"" + + enumType.getFullName() + + "\" has no value named \"" + + id + + "\"."; + if (allowUnknownEnumValues) { + logger.warning(unknownValueMsg); + return; + } else { + throw tokenizer.parseExceptionPreviousToken(unknownValueMsg); + } } } @@ -1704,6 +1791,8 @@ public final class TextFormat { } if (field.isRepeated()) { + // TODO(b/29122459): If field.isMapField() and FORBID_SINGULAR_OVERWRITES mode, + // check for duplicate map keys here. target.addRepeatedField(field, value); } else if ((singularOverwritePolicy == SingularOverwritePolicy.FORBID_SINGULAR_OVERWRITES) diff --git a/java/core/src/main/java/com/google/protobuf/TextFormatParseInfoTree.java b/java/core/src/main/java/com/google/protobuf/TextFormatParseInfoTree.java index 5c43b2c3..0127ce92 100644 --- a/java/core/src/main/java/com/google/protobuf/TextFormatParseInfoTree.java +++ b/java/core/src/main/java/com/google/protobuf/TextFormatParseInfoTree.java @@ -45,7 +45,8 @@ import java.util.Map.Entry; * * <p>The locations of primary fields values are retrieved by {@code getLocation} or * {@code getLocations}. The locations of sub message values are within nested - * {@code TextFormatParseInfoTree}s and are retrieve by {@code getNestedTree} or {@code getNestedTrees}. + * {@code TextFormatParseInfoTree}s and are retrieve by {@code getNestedTree} or + * {@code getNestedTrees}. * * <p>The {@code TextFormatParseInfoTree} is created by a Builder. */ diff --git a/java/core/src/main/java/com/google/protobuf/UnknownFieldSet.java b/java/core/src/main/java/com/google/protobuf/UnknownFieldSet.java index 6d33d3a8..37d64633 100644 --- a/java/core/src/main/java/com/google/protobuf/UnknownFieldSet.java +++ b/java/core/src/main/java/com/google/protobuf/UnknownFieldSet.java @@ -31,7 +31,6 @@ package com.google.protobuf; import com.google.protobuf.AbstractMessageLite.Builder.LimitedInputStream; - import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; @@ -39,6 +38,7 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; import java.util.List; +import java.util.ListIterator; import java.util.Map; import java.util.TreeMap; @@ -58,7 +58,9 @@ import java.util.TreeMap; */ public final class UnknownFieldSet implements MessageLite { - private UnknownFieldSet() {} + private UnknownFieldSet() { + fields = null; + } /** Create a new {@link Builder}. */ public static Builder newBuilder() { @@ -82,16 +84,18 @@ public final class UnknownFieldSet implements MessageLite { return defaultInstance; } private static final UnknownFieldSet defaultInstance = - new UnknownFieldSet(Collections.<Integer, Field>emptyMap()); + new UnknownFieldSet(Collections.<Integer, Field>emptyMap(), + Collections.<Integer, Field>emptyMap()); /** * Construct an {@code UnknownFieldSet} around the given map. The map is * expected to be immutable. */ - private UnknownFieldSet(final Map<Integer, Field> fields) { + UnknownFieldSet(final Map<Integer, Field> fields, + final Map<Integer, Field> fieldsDescending) { this.fields = fields; } - private Map<Integer, Field> fields; + private final Map<Integer, Field> fields; @Override @@ -224,10 +228,8 @@ public final class UnknownFieldSet implements MessageLite { } } - /** - * Get the number of bytes required to encode this set using - * {@code MessageSet} wire format. - */ + + /** Get the number of bytes required to encode this set using {@code MessageSet} wire format. */ public int getSerializedSizeAsMessageSet() { int result = 0; for (final Map.Entry<Integer, Field> entry : fields.entrySet()) { @@ -343,12 +345,13 @@ public final class UnknownFieldSet implements MessageLite { */ @Override public UnknownFieldSet build() { - getFieldBuilder(0); // Force lastField to be built. + getFieldBuilder(0); // Force lastField to be built. final UnknownFieldSet result; if (fields.isEmpty()) { result = getDefaultInstance(); } else { - result = new UnknownFieldSet(Collections.unmodifiableMap(fields)); + Map<Integer, Field> descendingFields = null; + result = new UnknownFieldSet(Collections.unmodifiableMap(fields), descendingFields); } fields = null; return result; @@ -363,8 +366,9 @@ public final class UnknownFieldSet implements MessageLite { @Override public Builder clone() { getFieldBuilder(0); // Force lastField to be built. + Map<Integer, Field> descendingFields = null; return UnknownFieldSet.newBuilder().mergeFrom( - new UnknownFieldSet(fields)); + new UnknownFieldSet(fields, descendingFields)); } @Override @@ -841,9 +845,10 @@ public final class UnknownFieldSet implements MessageLite { } } + /** - * Get the number of bytes required to encode this field, including field - * number, using {@code MessageSet} wire format. + * Get the number of bytes required to encode this field, including field number, using {@code + * MessageSet} wire format. */ public int getSerializedSizeAsMessageSetExtension(final int fieldNumber) { int result = 0; @@ -1022,7 +1027,7 @@ public final class UnknownFieldSet implements MessageLite { } catch (InvalidProtocolBufferException e) { throw e.setUnfinishedMessage(builder.buildPartial()); } catch (IOException e) { - throw new InvalidProtocolBufferException(e.getMessage()) + throw new InvalidProtocolBufferException(e) .setUnfinishedMessage(builder.buildPartial()); } return builder.buildPartial(); diff --git a/java/core/src/main/java/com/google/protobuf/UnknownFieldSetLite.java b/java/core/src/main/java/com/google/protobuf/UnknownFieldSetLite.java index 9500f905..f0b919ad 100644 --- a/java/core/src/main/java/com/google/protobuf/UnknownFieldSetLite.java +++ b/java/core/src/main/java/com/google/protobuf/UnknownFieldSetLite.java @@ -81,7 +81,7 @@ public final class UnknownFieldSetLite { System.arraycopy(second.objects, 0, objects, first.count, second.count); return new UnknownFieldSetLite(count, tags, objects, true /* isMutable */); } - + /** * The number of elements in the set. */ @@ -176,6 +176,42 @@ public final class UnknownFieldSetLite { } /** + * Serializes the set and writes it to {@code output} using {@code MessageSet} wire format. + * + * <p>For use by generated code only. + */ + public void writeAsMessageSetTo(CodedOutputStream output) throws IOException { + for (int i = 0; i < count; i++) { + int fieldNumber = WireFormat.getTagFieldNumber(tags[i]); + output.writeRawMessageSetExtension(fieldNumber, (ByteString) objects[i]); + } + } + + + /** + * Get the number of bytes required to encode this field, including field number, using {@code + * MessageSet} wire format. + */ + public int getSerializedSizeAsMessageSet() { + int size = memoizedSerializedSize; + if (size != -1) { + return size; + } + + size = 0; + for (int i = 0; i < count; i++) { + int tag = tags[i]; + int fieldNumber = WireFormat.getTagFieldNumber(tag); + size += CodedOutputStream.computeRawMessageSetExtensionSize( + fieldNumber, (ByteString) objects[i]); + } + + memoizedSerializedSize = size; + + return size; + } + + /** * Get the number of bytes required to encode this set. * * <p>For use by generated code only. @@ -216,6 +252,24 @@ public final class UnknownFieldSetLite { return size; } + + private static boolean equals(int[] tags1, int[] tags2, int count) { + for (int i = 0; i < count; ++i) { + if (tags1[i] != tags2[i]) { + return false; + } + } + return true; + } + + private static boolean equals(Object[] objects1, Object[] objects2, int count) { + for (int i = 0; i < count; ++i) { + if (!objects1[i].equals(objects2[i])) { + return false; + } + } + return true; + } @Override public boolean equals(Object obj) { @@ -233,23 +287,38 @@ public final class UnknownFieldSetLite { UnknownFieldSetLite other = (UnknownFieldSetLite) obj; if (count != other.count - // TODO(dweis): Only have to compare up to count but at worst 2x worse than we need to do. - || !Arrays.equals(tags, other.tags) - || !Arrays.deepEquals(objects, other.objects)) { + || !equals(tags, other.tags, count) + || !equals(objects, other.objects, count)) { return false; } return true; } + private static int hashCode(int[] tags, int count) { + int hashCode = 17; + for (int i = 0; i < count; ++i) { + hashCode = 31 * hashCode + tags[i]; + } + return hashCode; + } + + private static int hashCode(Object[] objects, int count) { + int hashCode = 17; + for (int i = 0; i < count; ++i) { + hashCode = 31 * hashCode + objects[i].hashCode(); + } + return hashCode; + } + @Override public int hashCode() { int hashCode = 17; - + hashCode = 31 * hashCode + count; - hashCode = 31 * hashCode + Arrays.hashCode(tags); - hashCode = 31 * hashCode + Arrays.deepHashCode(objects); - + hashCode = 31 * hashCode + hashCode(tags, count); + hashCode = 31 * hashCode + hashCode(objects, count); + return hashCode; } @@ -268,7 +337,9 @@ public final class UnknownFieldSetLite { } } - private void storeField(int tag, Object value) { + // Package private for unsafe experimental runtime. + void storeField(int tag, Object value) { + checkMutable(); ensureCapacity(); tags[count] = tag; diff --git a/java/core/src/main/java/com/google/protobuf/UnsafeByteOperations.java b/java/core/src/main/java/com/google/protobuf/UnsafeByteOperations.java index e72a6b4d..878c7758 100644 --- a/java/core/src/main/java/com/google/protobuf/UnsafeByteOperations.java +++ b/java/core/src/main/java/com/google/protobuf/UnsafeByteOperations.java @@ -67,16 +67,34 @@ public final class UnsafeByteOperations { /** * An unsafe operation that returns a {@link ByteString} that is backed by the provided buffer. * + * @param buffer the buffer to be wrapped + * @return a {@link ByteString} backed by the provided buffer + */ + public static ByteString unsafeWrap(byte[] buffer) { + return ByteString.wrap(buffer); + } + + /** + * An unsafe operation that returns a {@link ByteString} that is backed by a subregion of the + * provided buffer. + * + * @param buffer the buffer to be wrapped + * @param offset the offset of the wrapped region + * @param length the number of bytes of the wrapped region + * @return a {@link ByteString} backed by the provided buffer + */ + public static ByteString unsafeWrap(byte[] buffer, int offset, int length) { + return ByteString.wrap(buffer, offset, length); + } + + /** + * An unsafe operation that returns a {@link ByteString} that is backed by the provided buffer. + * * @param buffer the Java NIO buffer to be wrapped * @return a {@link ByteString} backed by the provided buffer */ public static ByteString unsafeWrap(ByteBuffer buffer) { - if (buffer.hasArray()) { - final int offset = buffer.arrayOffset(); - return ByteString.wrap(buffer.array(), offset + buffer.position(), buffer.remaining()); - } else { - return new NioByteString(buffer); - } + return ByteString.wrap(buffer); } /** @@ -98,4 +116,5 @@ public final class UnsafeByteOperations { public static void unsafeWriteTo(ByteString bytes, ByteOutput output) throws IOException { bytes.writeTo(output); } + } diff --git a/java/core/src/main/java/com/google/protobuf/UnsafeUtil.java b/java/core/src/main/java/com/google/protobuf/UnsafeUtil.java index 6a4787d1..d84ef3c5 100644 --- a/java/core/src/main/java/com/google/protobuf/UnsafeUtil.java +++ b/java/core/src/main/java/com/google/protobuf/UnsafeUtil.java @@ -30,28 +30,48 @@ package com.google.protobuf; -import sun.misc.Unsafe; - import java.lang.reflect.Field; import java.nio.Buffer; import java.nio.ByteBuffer; import java.security.AccessController; import java.security.PrivilegedExceptionAction; +import java.util.logging.Level; +import java.util.logging.Logger; -/** - * Utility class for working with unsafe operations. - */ -// TODO(nathanmittler): Add support for Android Memory/MemoryBlock +/** Utility class for working with unsafe operations. */ final class UnsafeUtil { + private static final Logger logger = Logger.getLogger(UnsafeUtil.class.getName()); private static final sun.misc.Unsafe UNSAFE = getUnsafe(); + private static final MemoryAccessor MEMORY_ACCESSOR = getMemoryAccessor(); private static final boolean HAS_UNSAFE_BYTEBUFFER_OPERATIONS = supportsUnsafeByteBufferOperations(); private static final boolean HAS_UNSAFE_ARRAY_OPERATIONS = supportsUnsafeArrayOperations(); - private static final long ARRAY_BASE_OFFSET = byteArrayBaseOffset(); - private static final long BUFFER_ADDRESS_OFFSET = fieldOffset(field(Buffer.class, "address")); - private UnsafeUtil() { - } + private static final long BYTE_ARRAY_BASE_OFFSET = arrayBaseOffset(byte[].class); + // Micro-optimization: we can assume a scale of 1 and skip the multiply + // private static final long BYTE_ARRAY_INDEX_SCALE = 1; + + private static final long BOOLEAN_ARRAY_BASE_OFFSET = arrayBaseOffset(boolean[].class); + private static final long BOOLEAN_ARRAY_INDEX_SCALE = arrayIndexScale(boolean[].class); + + private static final long INT_ARRAY_BASE_OFFSET = arrayBaseOffset(int[].class); + private static final long INT_ARRAY_INDEX_SCALE = arrayIndexScale(int[].class); + + private static final long LONG_ARRAY_BASE_OFFSET = arrayBaseOffset(long[].class); + private static final long LONG_ARRAY_INDEX_SCALE = arrayIndexScale(long[].class); + + private static final long FLOAT_ARRAY_BASE_OFFSET = arrayBaseOffset(float[].class); + private static final long FLOAT_ARRAY_INDEX_SCALE = arrayIndexScale(float[].class); + + private static final long DOUBLE_ARRAY_BASE_OFFSET = arrayBaseOffset(double[].class); + private static final long DOUBLE_ARRAY_INDEX_SCALE = arrayIndexScale(double[].class); + + private static final long OBJECT_ARRAY_BASE_OFFSET = arrayBaseOffset(Object[].class); + private static final long OBJECT_ARRAY_INDEX_SCALE = arrayIndexScale(Object[].class); + + private static final long BUFFER_ADDRESS_OFFSET = fieldOffset(bufferAddressField()); + + private UnsafeUtil() {} static boolean hasUnsafeArrayOperations() { return HAS_UNSAFE_ARRAY_OPERATIONS; @@ -61,59 +81,193 @@ final class UnsafeUtil { return HAS_UNSAFE_BYTEBUFFER_OPERATIONS; } - static long getArrayBaseOffset() { - return ARRAY_BASE_OFFSET; + + static long objectFieldOffset(Field field) { + return MEMORY_ACCESSOR.objectFieldOffset(field); + } + + private static int arrayBaseOffset(Class<?> clazz) { + return HAS_UNSAFE_ARRAY_OPERATIONS ? MEMORY_ACCESSOR.arrayBaseOffset(clazz) : -1; + } + + private static int arrayIndexScale(Class<?> clazz) { + return HAS_UNSAFE_ARRAY_OPERATIONS ? MEMORY_ACCESSOR.arrayIndexScale(clazz) : -1; + } + + static byte getByte(Object target, long offset) { + return MEMORY_ACCESSOR.getByte(target, offset); + } + + static void putByte(Object target, long offset, byte value) { + MEMORY_ACCESSOR.putByte(target, offset, value); + } + + static int getInt(Object target, long offset) { + return MEMORY_ACCESSOR.getInt(target, offset); + } + + static void putInt(Object target, long offset, int value) { + MEMORY_ACCESSOR.putInt(target, offset, value); + } + + static long getLong(Object target, long offset) { + return MEMORY_ACCESSOR.getLong(target, offset); + } + + static void putLong(Object target, long offset, long value) { + MEMORY_ACCESSOR.putLong(target, offset, value); + } + + static boolean getBoolean(Object target, long offset) { + return MEMORY_ACCESSOR.getBoolean(target, offset); + } + + static void putBoolean(Object target, long offset, boolean value) { + MEMORY_ACCESSOR.putBoolean(target, offset, value); + } + + static float getFloat(Object target, long offset) { + return MEMORY_ACCESSOR.getFloat(target, offset); + } + + static void putFloat(Object target, long offset, float value) { + MEMORY_ACCESSOR.putFloat(target, offset, value); + } + + static double getDouble(Object target, long offset) { + return MEMORY_ACCESSOR.getDouble(target, offset); + } + + static void putDouble(Object target, long offset, double value) { + MEMORY_ACCESSOR.putDouble(target, offset, value); + } + + static Object getObject(Object target, long offset) { + return MEMORY_ACCESSOR.getObject(target, offset); + } + + static byte getByte(byte[] target, long index) { + return MEMORY_ACCESSOR.getByte(target, BYTE_ARRAY_BASE_OFFSET + index); + } + + static void putByte(byte[] target, long index, byte value) { + MEMORY_ACCESSOR.putByte(target, BYTE_ARRAY_BASE_OFFSET + index, value); + } + + static int getInt(int[] target, long index) { + return MEMORY_ACCESSOR.getInt(target, INT_ARRAY_BASE_OFFSET + (index * INT_ARRAY_INDEX_SCALE)); + } + + static void putInt(int[] target, long index, int value) { + MEMORY_ACCESSOR.putInt(target, INT_ARRAY_BASE_OFFSET + (index * INT_ARRAY_INDEX_SCALE), value); + } + + static long getLong(long[] target, long index) { + return MEMORY_ACCESSOR.getLong( + target, LONG_ARRAY_BASE_OFFSET + (index * LONG_ARRAY_INDEX_SCALE)); + } + + static void putLong(long[] target, long index, long value) { + MEMORY_ACCESSOR.putLong( + target, LONG_ARRAY_BASE_OFFSET + (index * LONG_ARRAY_INDEX_SCALE), value); + } + + static boolean getBoolean(boolean[] target, long index) { + return MEMORY_ACCESSOR.getBoolean( + target, BOOLEAN_ARRAY_BASE_OFFSET + (index * BOOLEAN_ARRAY_INDEX_SCALE)); + } + + static void putBoolean(boolean[] target, long index, boolean value) { + MEMORY_ACCESSOR.putBoolean( + target, BOOLEAN_ARRAY_BASE_OFFSET + (index * BOOLEAN_ARRAY_INDEX_SCALE), value); + } + + static float getFloat(float[] target, long index) { + return MEMORY_ACCESSOR.getFloat( + target, FLOAT_ARRAY_BASE_OFFSET + (index * FLOAT_ARRAY_INDEX_SCALE)); + } + + static void putFloat(float[] target, long index, float value) { + MEMORY_ACCESSOR.putFloat( + target, FLOAT_ARRAY_BASE_OFFSET + (index * FLOAT_ARRAY_INDEX_SCALE), value); + } + + static double getDouble(double[] target, long index) { + return MEMORY_ACCESSOR.getDouble( + target, DOUBLE_ARRAY_BASE_OFFSET + (index * DOUBLE_ARRAY_INDEX_SCALE)); } - static byte getByte(byte[] target, long offset) { - return UNSAFE.getByte(target, offset); + static void putDouble(double[] target, long index, double value) { + MEMORY_ACCESSOR.putDouble( + target, DOUBLE_ARRAY_BASE_OFFSET + (index * DOUBLE_ARRAY_INDEX_SCALE), value); } - static void putByte(byte[] target, long offset, byte value) { - UNSAFE.putByte(target, offset, value); + static Object getObject(Object[] target, long index) { + return MEMORY_ACCESSOR.getObject( + target, OBJECT_ARRAY_BASE_OFFSET + (index * OBJECT_ARRAY_INDEX_SCALE)); } - static void copyMemory( - byte[] src, long srcOffset, byte[] target, long targetOffset, long length) { - UNSAFE.copyMemory(src, srcOffset, target, targetOffset, length); + static void putObject(Object[] target, long index, Object value) { + MEMORY_ACCESSOR.putObject( + target, OBJECT_ARRAY_BASE_OFFSET + (index * OBJECT_ARRAY_INDEX_SCALE), value); } - static long getLong(byte[] target, long offset) { - return UNSAFE.getLong(target, offset); + static void copyMemory(byte[] src, long srcIndex, long targetOffset, long length) { + MEMORY_ACCESSOR.copyMemory(src, srcIndex, targetOffset, length); + } + + static void copyMemory(long srcOffset, byte[] target, long targetIndex, long length) { + MEMORY_ACCESSOR.copyMemory(srcOffset, target, targetIndex, length); + } + + static void copyMemory(byte[] src, long srcIndex, byte[] target, long targetIndex, long length) { + System.arraycopy(src, (int) srcIndex, target, (int) targetIndex, (int) length); } static byte getByte(long address) { - return UNSAFE.getByte(address); + return MEMORY_ACCESSOR.getByte(address); } static void putByte(long address, byte value) { - UNSAFE.putByte(address, value); + MEMORY_ACCESSOR.putByte(address, value); + } + + static int getInt(long address) { + return MEMORY_ACCESSOR.getInt(address); + } + + static void putInt(long address, int value) { + MEMORY_ACCESSOR.putInt(address, value); } static long getLong(long address) { - return UNSAFE.getLong(address); + return MEMORY_ACCESSOR.getLong(address); } - static void copyMemory(long srcAddress, long targetAddress, long length) { - UNSAFE.copyMemory(srcAddress, targetAddress, length); + static void putLong(long address, long value) { + MEMORY_ACCESSOR.putLong(address, value); } /** * Gets the offset of the {@code address} field of the given direct {@link ByteBuffer}. */ static long addressOffset(ByteBuffer buffer) { - return UNSAFE.getLong(buffer, BUFFER_ADDRESS_OFFSET); + return MEMORY_ACCESSOR.getLong(buffer, BUFFER_ADDRESS_OFFSET); + } + + static Object getStaticObject(Field field) { + return MEMORY_ACCESSOR.getStaticObject(field); } /** * Gets the {@code sun.misc.Unsafe} instance, or {@code null} if not available on this platform. */ - private static sun.misc.Unsafe getUnsafe() { + static sun.misc.Unsafe getUnsafe() { sun.misc.Unsafe unsafe = null; try { unsafe = AccessController.doPrivileged( - new PrivilegedExceptionAction<Unsafe>() { + new PrivilegedExceptionAction<sun.misc.Unsafe>() { @Override public sun.misc.Unsafe run() throws Exception { Class<sun.misc.Unsafe> k = sun.misc.Unsafe.class; @@ -136,52 +290,90 @@ final class UnsafeUtil { return unsafe; } - /** - * Indicates whether or not unsafe array operations are supported on this platform. - */ + /** Get a {@link MemoryAccessor} appropriate for the platform, or null if not supported. */ + private static MemoryAccessor getMemoryAccessor() { + if (UNSAFE == null) { + return null; + } + return new JvmMemoryAccessor(UNSAFE); + } + + /** Indicates whether or not unsafe array operations are supported on this platform. */ private static boolean supportsUnsafeArrayOperations() { - boolean supported = false; - if (UNSAFE != null) { - try { - Class<?> clazz = UNSAFE.getClass(); - clazz.getMethod("arrayBaseOffset", Class.class); - clazz.getMethod("getByte", Object.class, long.class); - clazz.getMethod("putByte", Object.class, long.class, byte.class); - clazz.getMethod("getLong", Object.class, long.class); - clazz.getMethod( - "copyMemory", Object.class, long.class, Object.class, long.class, long.class); - supported = true; - } catch (Throwable e) { - // Do nothing. - } + if (UNSAFE == null) { + return false; } - return supported; + try { + Class<?> clazz = UNSAFE.getClass(); + clazz.getMethod("objectFieldOffset", Field.class); + clazz.getMethod("arrayBaseOffset", Class.class); + clazz.getMethod("arrayIndexScale", Class.class); + clazz.getMethod("getInt", Object.class, long.class); + clazz.getMethod("putInt", Object.class, long.class, int.class); + clazz.getMethod("getLong", Object.class, long.class); + clazz.getMethod("putLong", Object.class, long.class, long.class); + clazz.getMethod("getObject", Object.class, long.class); + clazz.getMethod("putObject", Object.class, long.class, Object.class); + clazz.getMethod("getByte", Object.class, long.class); + clazz.getMethod("putByte", Object.class, long.class, byte.class); + clazz.getMethod("getBoolean", Object.class, long.class); + clazz.getMethod("putBoolean", Object.class, long.class, boolean.class); + clazz.getMethod("getFloat", Object.class, long.class); + clazz.getMethod("putFloat", Object.class, long.class, float.class); + clazz.getMethod("getDouble", Object.class, long.class); + clazz.getMethod("putDouble", Object.class, long.class, double.class); + + return true; + } catch (Throwable e) { + logger.log( + Level.WARNING, + "platform method missing - proto runtime falling back to safer methods: " + e); + } + return false; } private static boolean supportsUnsafeByteBufferOperations() { - boolean supported = false; - if (UNSAFE != null) { - try { - Class<?> clazz = UNSAFE.getClass(); - clazz.getMethod("objectFieldOffset", Field.class); - clazz.getMethod("getByte", long.class); - clazz.getMethod("getLong", Object.class, long.class); - clazz.getMethod("putByte", long.class, byte.class); - clazz.getMethod("getLong", long.class); - clazz.getMethod("copyMemory", long.class, long.class, long.class); - supported = true; - } catch (Throwable e) { - // Do nothing. + if (UNSAFE == null) { + return false; + } + try { + Class<?> clazz = UNSAFE.getClass(); + // Methods for getting direct buffer address. + clazz.getMethod("objectFieldOffset", Field.class); + clazz.getMethod("getLong", Object.class, long.class); + + if (bufferAddressField() == null) { + return false; } + + clazz.getMethod("getByte", long.class); + clazz.getMethod("putByte", long.class, byte.class); + clazz.getMethod("getInt", long.class); + clazz.getMethod("putInt", long.class, int.class); + clazz.getMethod("getLong", long.class); + clazz.getMethod("putLong", long.class, long.class); + clazz.getMethod("copyMemory", long.class, long.class, long.class); + clazz.getMethod("copyMemory", Object.class, long.class, Object.class, long.class, long.class); + return true; + } catch (Throwable e) { + logger.log( + Level.WARNING, + "platform method missing - proto runtime falling back to safer methods: " + e); } - return supported; + return false; } - /** - * Get the base offset for byte arrays, or {@code -1} if {@code sun.misc.Unsafe} is not available. - */ - private static int byteArrayBaseOffset() { - return HAS_UNSAFE_ARRAY_OPERATIONS ? UNSAFE.arrayBaseOffset(byte[].class) : -1; + + /** Finds the address field within a direct {@link Buffer}. */ + private static Field bufferAddressField() { + Field field = field(Buffer.class, "address"); + return field != null && field.getType() == long.class ? field : null; + } + + /** Finds the value field within a {@link String}. */ + private static Field stringValueField() { + Field field = field(String.class, "value"); + return field != null && field.getType() == char[].class ? field : null; } /** @@ -189,7 +381,7 @@ final class UnsafeUtil { * available. */ private static long fieldOffset(Field field) { - return field == null || UNSAFE == null ? -1 : UNSAFE.objectFieldOffset(field); + return field == null || MEMORY_ACCESSOR == null ? -1 : MEMORY_ACCESSOR.objectFieldOffset(field); } /** @@ -207,4 +399,176 @@ final class UnsafeUtil { } return field; } + + private abstract static class MemoryAccessor { + + sun.misc.Unsafe unsafe; + + MemoryAccessor(sun.misc.Unsafe unsafe) { + this.unsafe = unsafe; + } + + public final long objectFieldOffset(Field field) { + return unsafe.objectFieldOffset(field); + } + + public abstract byte getByte(Object target, long offset); + + public abstract void putByte(Object target, long offset, byte value); + + public final int getInt(Object target, long offset) { + return unsafe.getInt(target, offset); + } + + public final void putInt(Object target, long offset, int value) { + unsafe.putInt(target, offset, value); + } + + public final long getLong(Object target, long offset) { + return unsafe.getLong(target, offset); + } + + public final void putLong(Object target, long offset, long value) { + unsafe.putLong(target, offset, value); + } + + public abstract boolean getBoolean(Object target, long offset); + + public abstract void putBoolean(Object target, long offset, boolean value); + + public abstract float getFloat(Object target, long offset); + + public abstract void putFloat(Object target, long offset, float value); + + public abstract double getDouble(Object target, long offset); + + public abstract void putDouble(Object target, long offset, double value); + + public final Object getObject(Object target, long offset) { + return unsafe.getObject(target, offset); + } + + public final void putObject(Object target, long offset, Object value) { + unsafe.putObject(target, offset, value); + } + + public final int arrayBaseOffset(Class<?> clazz) { + return unsafe.arrayBaseOffset(clazz); + } + + public final int arrayIndexScale(Class<?> clazz) { + return unsafe.arrayIndexScale(clazz); + } + + public abstract byte getByte(long address); + + public abstract void putByte(long address, byte value); + + public abstract int getInt(long address); + + public abstract void putInt(long address, int value); + + public abstract long getLong(long address); + + public abstract void putLong(long address, long value); + + public abstract Object getStaticObject(Field field); + + public abstract void copyMemory(long srcOffset, byte[] target, long targetIndex, long length); + + public abstract void copyMemory(byte[] src, long srcIndex, long targetOffset, long length); + } + + private static final class JvmMemoryAccessor extends MemoryAccessor { + + JvmMemoryAccessor(sun.misc.Unsafe unsafe) { + super(unsafe); + } + + @Override + public byte getByte(long address) { + return unsafe.getByte(address); + } + + @Override + public void putByte(long address, byte value) { + unsafe.putByte(address, value); + } + + @Override + public int getInt(long address) { + return unsafe.getInt(address); + } + + @Override + public void putInt(long address, int value) { + unsafe.putInt(address, value); + } + + @Override + public long getLong(long address) { + return unsafe.getLong(address); + } + + @Override + public void putLong(long address, long value) { + unsafe.putLong(address, value); + } + + @Override + public byte getByte(Object target, long offset) { + return unsafe.getByte(target, offset); + } + + @Override + public void putByte(Object target, long offset, byte value) { + unsafe.putByte(target, offset, value); + } + + @Override + public boolean getBoolean(Object target, long offset) { + return unsafe.getBoolean(target, offset); + } + + @Override + public void putBoolean(Object target, long offset, boolean value) { + unsafe.putBoolean(target, offset, value); + } + + @Override + public float getFloat(Object target, long offset) { + return unsafe.getFloat(target, offset); + } + + @Override + public void putFloat(Object target, long offset, float value) { + unsafe.putFloat(target, offset, value); + } + + @Override + public double getDouble(Object target, long offset) { + return unsafe.getDouble(target, offset); + } + + @Override + public void putDouble(Object target, long offset, double value) { + unsafe.putDouble(target, offset, value); + } + + @Override + public void copyMemory(long srcOffset, byte[] target, long targetIndex, long length) { + unsafe.copyMemory(null, srcOffset, target, BYTE_ARRAY_BASE_OFFSET + targetIndex, length); + } + + @Override + public void copyMemory(byte[] src, long srcIndex, long targetOffset, long length) { + unsafe.copyMemory(src, BYTE_ARRAY_BASE_OFFSET + srcIndex, null, targetOffset, length); + } + + @Override + public Object getStaticObject(Field field) { + return getObject(unsafe.staticFieldBase(field), unsafe.staticFieldOffset(field)); + } + } + } diff --git a/java/core/src/main/java/com/google/protobuf/Utf8.java b/java/core/src/main/java/com/google/protobuf/Utf8.java index 5b80d405..de75fe6b 100644 --- a/java/core/src/main/java/com/google/protobuf/Utf8.java +++ b/java/core/src/main/java/com/google/protobuf/Utf8.java @@ -31,15 +31,18 @@ package com.google.protobuf; import static com.google.protobuf.UnsafeUtil.addressOffset; -import static com.google.protobuf.UnsafeUtil.getArrayBaseOffset; import static com.google.protobuf.UnsafeUtil.hasUnsafeArrayOperations; import static com.google.protobuf.UnsafeUtil.hasUnsafeByteBufferOperations; import static java.lang.Character.MAX_SURROGATE; +import static java.lang.Character.MIN_HIGH_SURROGATE; +import static java.lang.Character.MIN_LOW_SURROGATE; +import static java.lang.Character.MIN_SUPPLEMENTARY_CODE_POINT; import static java.lang.Character.MIN_SURROGATE; import static java.lang.Character.isSurrogatePair; import static java.lang.Character.toCodePoint; import java.nio.ByteBuffer; +import java.util.Arrays; /** * A set of low-level, high-performance static utility methods related @@ -290,7 +293,7 @@ final class Utf8 { if (Character.MIN_SURROGATE <= c && c <= Character.MAX_SURROGATE) { // Check that we have a well-formed surrogate pair. int cp = Character.codePointAt(sequence, i); - if (cp < Character.MIN_SUPPLEMENTARY_CODE_POINT) { + if (cp < MIN_SUPPLEMENTARY_CODE_POINT) { throw new UnpairedSurrogateException(i, utf16Length); } i++; @@ -332,6 +335,26 @@ final class Utf8 { } /** + * Decodes the given UTF-8 portion of the {@link ByteBuffer} into a {@link String}. + * + * @throws InvalidProtocolBufferException if the input is not valid UTF-8. + */ + static String decodeUtf8(ByteBuffer buffer, int index, int size) + throws InvalidProtocolBufferException { + return processor.decodeUtf8(buffer, index, size); + } + + /** + * Decodes the given UTF-8 encoded byte array slice into a {@link String}. + * + * @throws InvalidProtocolBufferException if the input is not valid UTF-8. + */ + static String decodeUtf8(byte[] bytes, int index, int size) + throws InvalidProtocolBufferException { + return processor.decodeUtf8(bytes, index, size); + } + + /** * Encodes the given characters to the target {@link ByteBuffer} using UTF-8 encoding. * * <p>Selects an optimal algorithm based on the type of {@link ByteBuffer} (i.e. heap or direct) @@ -611,6 +634,116 @@ final class Utf8 { } /** + * Decodes the given byte array slice into a {@link String}. + * + * @throws InvalidProtocolBufferException if the byte array slice is not valid UTF-8. + */ + abstract String decodeUtf8(byte[] bytes, int index, int size) + throws InvalidProtocolBufferException; + + /** + * Decodes the given portion of the {@link ByteBuffer} into a {@link String}. + * + * @throws InvalidProtocolBufferException if the portion of the buffer is not valid UTF-8. + */ + final String decodeUtf8(ByteBuffer buffer, int index, int size) + throws InvalidProtocolBufferException { + if (buffer.hasArray()) { + final int offset = buffer.arrayOffset(); + return decodeUtf8(buffer.array(), offset + index, size); + } else if (buffer.isDirect()) { + return decodeUtf8Direct(buffer, index, size); + } + return decodeUtf8Default(buffer, index, size); + } + + /** + * Decodes direct {@link ByteBuffer} instances into {@link String}. + */ + abstract String decodeUtf8Direct(ByteBuffer buffer, int index, int size) + throws InvalidProtocolBufferException; + + /** + * Decodes {@link ByteBuffer} instances using the {@link ByteBuffer} API rather than + * potentially faster approaches. + */ + final String decodeUtf8Default(ByteBuffer buffer, int index, int size) + throws InvalidProtocolBufferException { + // Bitwise OR combines the sign bits so any negative value fails the check. + if ((index | size | buffer.limit() - index - size) < 0) { + throw new ArrayIndexOutOfBoundsException( + String.format("buffer limit=%d, index=%d, limit=%d", buffer.limit(), index, size)); + } + + int offset = index; + final int limit = offset + size; + + // The longest possible resulting String is the same as the number of input bytes, when it is + // all ASCII. For other cases, this over-allocates and we will truncate in the end. + char[] resultArr = new char[size]; + int resultPos = 0; + + // Optimize for 100% ASCII (Hotspot loves small simple top-level loops like this). + // This simple loop stops when we encounter a byte >= 0x80 (i.e. non-ASCII). + while (offset < limit) { + byte b = buffer.get(offset); + if (!DecodeUtil.isOneByte(b)) { + break; + } + offset++; + DecodeUtil.handleOneByte(b, resultArr, resultPos++); + } + + while (offset < limit) { + byte byte1 = buffer.get(offset++); + if (DecodeUtil.isOneByte(byte1)) { + DecodeUtil.handleOneByte(byte1, resultArr, resultPos++); + // It's common for there to be multiple ASCII characters in a run mixed in, so add an + // extra optimized loop to take care of these runs. + while (offset < limit) { + byte b = buffer.get(offset); + if (!DecodeUtil.isOneByte(b)) { + break; + } + offset++; + DecodeUtil.handleOneByte(b, resultArr, resultPos++); + } + } else if (DecodeUtil.isTwoBytes(byte1)) { + if (offset >= limit) { + throw InvalidProtocolBufferException.invalidUtf8(); + } + DecodeUtil.handleTwoBytes( + byte1, /* byte2 */ buffer.get(offset++), resultArr, resultPos++); + } else if (DecodeUtil.isThreeBytes(byte1)) { + if (offset >= limit - 1) { + throw InvalidProtocolBufferException.invalidUtf8(); + } + DecodeUtil.handleThreeBytes( + byte1, + /* byte2 */ buffer.get(offset++), + /* byte3 */ buffer.get(offset++), + resultArr, + resultPos++); + } else { + if (offset >= limit - 2) { + throw InvalidProtocolBufferException.invalidUtf8(); + } + DecodeUtil.handleFourBytes( + byte1, + /* byte2 */ buffer.get(offset++), + /* byte3 */ buffer.get(offset++), + /* byte4 */ buffer.get(offset++), + resultArr, + resultPos++); + // 4-byte case requires two chars. + resultPos++; + } + } + + return new String(resultArr, 0, resultPos); + } + + /** * Encodes an input character sequence ({@code in}) to UTF-8 in the target array ({@code out}). * For a string, this method is similar to * <pre>{@code @@ -852,6 +985,88 @@ final class Utf8 { } @Override + String decodeUtf8(byte[] bytes, int index, int size) throws InvalidProtocolBufferException { + // Bitwise OR combines the sign bits so any negative value fails the check. + if ((index | size | bytes.length - index - size) < 0) { + throw new ArrayIndexOutOfBoundsException( + String.format("buffer length=%d, index=%d, size=%d", bytes.length, index, size)); + } + + int offset = index; + final int limit = offset + size; + + // The longest possible resulting String is the same as the number of input bytes, when it is + // all ASCII. For other cases, this over-allocates and we will truncate in the end. + char[] resultArr = new char[size]; + int resultPos = 0; + + // Optimize for 100% ASCII (Hotspot loves small simple top-level loops like this). + // This simple loop stops when we encounter a byte >= 0x80 (i.e. non-ASCII). + while (offset < limit) { + byte b = bytes[offset]; + if (!DecodeUtil.isOneByte(b)) { + break; + } + offset++; + DecodeUtil.handleOneByte(b, resultArr, resultPos++); + } + + while (offset < limit) { + byte byte1 = bytes[offset++]; + if (DecodeUtil.isOneByte(byte1)) { + DecodeUtil.handleOneByte(byte1, resultArr, resultPos++); + // It's common for there to be multiple ASCII characters in a run mixed in, so add an + // extra optimized loop to take care of these runs. + while (offset < limit) { + byte b = bytes[offset]; + if (!DecodeUtil.isOneByte(b)) { + break; + } + offset++; + DecodeUtil.handleOneByte(b, resultArr, resultPos++); + } + } else if (DecodeUtil.isTwoBytes(byte1)) { + if (offset >= limit) { + throw InvalidProtocolBufferException.invalidUtf8(); + } + DecodeUtil.handleTwoBytes(byte1, /* byte2 */ bytes[offset++], resultArr, resultPos++); + } else if (DecodeUtil.isThreeBytes(byte1)) { + if (offset >= limit - 1) { + throw InvalidProtocolBufferException.invalidUtf8(); + } + DecodeUtil.handleThreeBytes( + byte1, + /* byte2 */ bytes[offset++], + /* byte3 */ bytes[offset++], + resultArr, + resultPos++); + } else { + if (offset >= limit - 2) { + throw InvalidProtocolBufferException.invalidUtf8(); + } + DecodeUtil.handleFourBytes( + byte1, + /* byte2 */ bytes[offset++], + /* byte3 */ bytes[offset++], + /* byte4 */ bytes[offset++], + resultArr, + resultPos++); + // 4-byte case requires two chars. + resultPos++; + } + } + + return new String(resultArr, 0, resultPos); + } + + @Override + String decodeUtf8Direct(ByteBuffer buffer, int index, int size) + throws InvalidProtocolBufferException { + // For safe processing, we have to use the ByteBufferAPI. + return decodeUtf8Default(buffer, index, size); + } + + @Override int encodeUtf8(CharSequence in, byte[] out, int offset, int length) { int utf16Length = in.length(); int j = offset; @@ -997,12 +1212,13 @@ final class Utf8 { @Override int partialIsValidUtf8(int state, byte[] bytes, final int index, final int limit) { + // Bitwise OR combines the sign bits so any negative value fails the check. if ((index | limit | bytes.length - limit) < 0) { throw new ArrayIndexOutOfBoundsException( String.format("Array length=%d, index=%d, limit=%d", bytes.length, index, limit)); } - long offset = getArrayBaseOffset() + index; - final long offsetLimit = getArrayBaseOffset() + limit; + long offset = index; + final long offsetLimit = limit; if (state != COMPLETE) { // The previous decoding operation was incomplete (or malformed). // We look for a well-formed sequence consisting of bytes from @@ -1092,6 +1308,7 @@ final class Utf8 { @Override int partialIsValidUtf8Direct( final int state, ByteBuffer buffer, final int index, final int limit) { + // Bitwise OR combines the sign bits so any negative value fails the check. if ((index | limit | buffer.limit() - limit) < 0) { throw new ArrayIndexOutOfBoundsException( String.format("buffer limit=%d, index=%d, limit=%d", buffer.limit(), index, limit)); @@ -1186,8 +1403,159 @@ final class Utf8 { } @Override + String decodeUtf8(byte[] bytes, int index, int size) throws InvalidProtocolBufferException { + if ((index | size | bytes.length - index - size) < 0) { + throw new ArrayIndexOutOfBoundsException( + String.format("buffer length=%d, index=%d, size=%d", bytes.length, index, size)); + } + + int offset = index; + final int limit = offset + size; + + // The longest possible resulting String is the same as the number of input bytes, when it is + // all ASCII. For other cases, this over-allocates and we will truncate in the end. + char[] resultArr = new char[size]; + int resultPos = 0; + + // Optimize for 100% ASCII (Hotspot loves small simple top-level loops like this). + // This simple loop stops when we encounter a byte >= 0x80 (i.e. non-ASCII). + while (offset < limit) { + byte b = UnsafeUtil.getByte(bytes, offset); + if (!DecodeUtil.isOneByte(b)) { + break; + } + offset++; + DecodeUtil.handleOneByte(b, resultArr, resultPos++); + } + + while (offset < limit) { + byte byte1 = UnsafeUtil.getByte(bytes, offset++); + if (DecodeUtil.isOneByte(byte1)) { + DecodeUtil.handleOneByte(byte1, resultArr, resultPos++); + // It's common for there to be multiple ASCII characters in a run mixed in, so add an + // extra optimized loop to take care of these runs. + while (offset < limit) { + byte b = UnsafeUtil.getByte(bytes, offset); + if (!DecodeUtil.isOneByte(b)) { + break; + } + offset++; + DecodeUtil.handleOneByte(b, resultArr, resultPos++); + } + } else if (DecodeUtil.isTwoBytes(byte1)) { + if (offset >= limit) { + throw InvalidProtocolBufferException.invalidUtf8(); + } + DecodeUtil.handleTwoBytes( + byte1, /* byte2 */ UnsafeUtil.getByte(bytes, offset++), resultArr, resultPos++); + } else if (DecodeUtil.isThreeBytes(byte1)) { + if (offset >= limit - 1) { + throw InvalidProtocolBufferException.invalidUtf8(); + } + DecodeUtil.handleThreeBytes( + byte1, + /* byte2 */ UnsafeUtil.getByte(bytes, offset++), + /* byte3 */ UnsafeUtil.getByte(bytes, offset++), + resultArr, + resultPos++); + } else { + if (offset >= limit - 2) { + throw InvalidProtocolBufferException.invalidUtf8(); + } + DecodeUtil.handleFourBytes( + byte1, + /* byte2 */ UnsafeUtil.getByte(bytes, offset++), + /* byte3 */ UnsafeUtil.getByte(bytes, offset++), + /* byte4 */ UnsafeUtil.getByte(bytes, offset++), + resultArr, + resultPos++); + // 4-byte case requires two chars. + resultPos++; + } + } + + return new String(resultArr, 0, resultPos); + } + + @Override + String decodeUtf8Direct(ByteBuffer buffer, int index, int size) + throws InvalidProtocolBufferException { + // Bitwise OR combines the sign bits so any negative value fails the check. + if ((index | size | buffer.limit() - index - size) < 0) { + throw new ArrayIndexOutOfBoundsException( + String.format("buffer limit=%d, index=%d, limit=%d", buffer.limit(), index, size)); + } + long address = UnsafeUtil.addressOffset(buffer) + index; + final long addressLimit = address + size; + + // The longest possible resulting String is the same as the number of input bytes, when it is + // all ASCII. For other cases, this over-allocates and we will truncate in the end. + char[] resultArr = new char[size]; + int resultPos = 0; + + // Optimize for 100% ASCII (Hotspot loves small simple top-level loops like this). + // This simple loop stops when we encounter a byte >= 0x80 (i.e. non-ASCII). + while (address < addressLimit) { + byte b = UnsafeUtil.getByte(address); + if (!DecodeUtil.isOneByte(b)) { + break; + } + address++; + DecodeUtil.handleOneByte(b, resultArr, resultPos++); + } + + while (address < addressLimit) { + byte byte1 = UnsafeUtil.getByte(address++); + if (DecodeUtil.isOneByte(byte1)) { + DecodeUtil.handleOneByte(byte1, resultArr, resultPos++); + // It's common for there to be multiple ASCII characters in a run mixed in, so add an + // extra optimized loop to take care of these runs. + while (address < addressLimit) { + byte b = UnsafeUtil.getByte(address); + if (!DecodeUtil.isOneByte(b)) { + break; + } + address++; + DecodeUtil.handleOneByte(b, resultArr, resultPos++); + } + } else if (DecodeUtil.isTwoBytes(byte1)) { + if (address >= addressLimit) { + throw InvalidProtocolBufferException.invalidUtf8(); + } + DecodeUtil.handleTwoBytes( + byte1, /* byte2 */ UnsafeUtil.getByte(address++), resultArr, resultPos++); + } else if (DecodeUtil.isThreeBytes(byte1)) { + if (address >= addressLimit - 1) { + throw InvalidProtocolBufferException.invalidUtf8(); + } + DecodeUtil.handleThreeBytes( + byte1, + /* byte2 */ UnsafeUtil.getByte(address++), + /* byte3 */ UnsafeUtil.getByte(address++), + resultArr, + resultPos++); + } else { + if (address >= addressLimit - 2) { + throw InvalidProtocolBufferException.invalidUtf8(); + } + DecodeUtil.handleFourBytes( + byte1, + /* byte2 */ UnsafeUtil.getByte(address++), + /* byte3 */ UnsafeUtil.getByte(address++), + /* byte4 */ UnsafeUtil.getByte(address++), + resultArr, + resultPos++); + // 4-byte case requires two chars. + resultPos++; + } + } + + return new String(resultArr, 0, resultPos); + } + + @Override int encodeUtf8(final CharSequence in, final byte[] out, final int offset, final int length) { - long outIx = getArrayBaseOffset() + offset; + long outIx = offset; final long outLimit = outIx + length; final int inLimit = in.length(); if (inLimit > length || out.length - length < offset) { @@ -1204,7 +1572,7 @@ final class Utf8 { } if (inIx == inLimit) { // We're done, it was ASCII encoded. - return (int) (outIx - getArrayBaseOffset()); + return (int) outIx; } for (char c; inIx < inLimit; ++inIx) { @@ -1243,7 +1611,7 @@ final class Utf8 { } // All bytes have been encoded. - return (int) (outIx - getArrayBaseOffset()); + return (int) outIx; } @Override @@ -1321,31 +1689,17 @@ final class Utf8 { */ private static int unsafeEstimateConsecutiveAscii( byte[] bytes, long offset, final int maxChars) { - int remaining = maxChars; - if (remaining < UNSAFE_COUNT_ASCII_THRESHOLD) { + if (maxChars < UNSAFE_COUNT_ASCII_THRESHOLD) { // Don't bother with small strings. return 0; } - // Read bytes until 8-byte aligned so that we can read longs in the loop below. - // Byte arrays are already either 8 or 16-byte aligned, so we just need to make sure that - // the index (relative to the start of the array) is also 8-byte aligned. We do this by - // ANDing the index with 7 to determine the number of bytes that need to be read before - // we're 8-byte aligned. - final int unaligned = (int) offset & 7; - for (int j = unaligned; j > 0; j--) { + for (int i = 0; i < maxChars; i++) { if (UnsafeUtil.getByte(bytes, offset++) < 0) { - return unaligned - j; + return i; } } - - // This simple loop stops when we encounter a byte >= 0x80 (i.e. non-ASCII). - // To speed things up further, we're reading longs instead of bytes so we use a mask to - // determine if any byte in the current long is non-ASCII. - remaining -= unaligned; - for (; remaining >= 8 && (UnsafeUtil.getLong(bytes, offset) & ASCII_MASK_LONG) == 0; - offset += 8, remaining -= 8) {} - return maxChars - remaining; + return maxChars; } /** @@ -1362,7 +1716,7 @@ final class Utf8 { // Read bytes until 8-byte aligned so that we can read longs in the loop below. // We do this by ANDing the address with 7 to determine the number of bytes that need to // be read before we're 8-byte aligned. - final int unaligned = (int) address & 7; + final int unaligned = 8 - ((int) address & 7); for (int j = unaligned; j > 0; j--) { if (UnsafeUtil.getByte(address++) < 0) { return unaligned - j; @@ -1569,5 +1923,112 @@ final class Utf8 { } } + /** + * Utility methods for decoding bytes into {@link String}. Callers are responsible for extracting + * bytes (possibly using Unsafe methods), and checking remaining bytes. All other UTF-8 validity + * checks and codepoint conversion happen in this class. + */ + private static class DecodeUtil { + + /** + * Returns whether this is a single-byte codepoint (i.e., ASCII) with the form '0XXXXXXX'. + */ + private static boolean isOneByte(byte b) { + return b >= 0; + } + + /** + * Returns whether this is a two-byte codepoint with the form '10XXXXXX'. + */ + private static boolean isTwoBytes(byte b) { + return b < (byte) 0xE0; + } + + /** + * Returns whether this is a three-byte codepoint with the form '110XXXXX'. + */ + private static boolean isThreeBytes(byte b) { + return b < (byte) 0xF0; + } + + private static void handleOneByte(byte byte1, char[] resultArr, int resultPos) { + resultArr[resultPos] = (char) byte1; + } + + private static void handleTwoBytes( + byte byte1, byte byte2, char[] resultArr, int resultPos) + throws InvalidProtocolBufferException { + // Simultaneously checks for illegal trailing-byte in leading position (<= '11000000') and + // overlong 2-byte, '11000001'. + if (byte1 < (byte) 0xC2 + || isNotTrailingByte(byte2)) { + throw InvalidProtocolBufferException.invalidUtf8(); + } + resultArr[resultPos] = (char) (((byte1 & 0x1F) << 6) | trailingByteValue(byte2)); + } + + private static void handleThreeBytes( + byte byte1, byte byte2, byte byte3, char[] resultArr, int resultPos) + throws InvalidProtocolBufferException { + if (isNotTrailingByte(byte2) + // overlong? 5 most significant bits must not all be zero + || (byte1 == (byte) 0xE0 && byte2 < (byte) 0xA0) + // check for illegal surrogate codepoints + || (byte1 == (byte) 0xED && byte2 >= (byte) 0xA0) + || isNotTrailingByte(byte3)) { + throw InvalidProtocolBufferException.invalidUtf8(); + } + resultArr[resultPos] = (char) + (((byte1 & 0x0F) << 12) | (trailingByteValue(byte2) << 6) | trailingByteValue(byte3)); + } + + private static void handleFourBytes( + byte byte1, byte byte2, byte byte3, byte byte4, char[] resultArr, int resultPos) + throws InvalidProtocolBufferException{ + if (isNotTrailingByte(byte2) + // Check that 1 <= plane <= 16. Tricky optimized form of: + // valid 4-byte leading byte? + // if (byte1 > (byte) 0xF4 || + // overlong? 4 most significant bits must not all be zero + // byte1 == (byte) 0xF0 && byte2 < (byte) 0x90 || + // codepoint larger than the highest code point (U+10FFFF)? + // byte1 == (byte) 0xF4 && byte2 > (byte) 0x8F) + || (((byte1 << 28) + (byte2 - (byte) 0x90)) >> 30) != 0 + || isNotTrailingByte(byte3) + || isNotTrailingByte(byte4)) { + throw InvalidProtocolBufferException.invalidUtf8(); + } + int codepoint = ((byte1 & 0x07) << 18) + | (trailingByteValue(byte2) << 12) + | (trailingByteValue(byte3) << 6) + | trailingByteValue(byte4); + resultArr[resultPos] = DecodeUtil.highSurrogate(codepoint); + resultArr[resultPos + 1] = DecodeUtil.lowSurrogate(codepoint); + } + + /** + * Returns whether the byte is not a valid continuation of the form '10XXXXXX'. + */ + private static boolean isNotTrailingByte(byte b) { + return b > (byte) 0xBF; + } + + /** + * Returns the actual value of the trailing byte (removes the prefix '10') for composition. + */ + private static int trailingByteValue(byte b) { + return b & 0x3F; + } + + private static char highSurrogate(int codePoint) { + return (char) ((MIN_HIGH_SURROGATE - (MIN_SUPPLEMENTARY_CODE_POINT >>> 10)) + + (codePoint >>> 10)); + } + + private static char lowSurrogate(int codePoint) { + return (char) (MIN_LOW_SURROGATE + (codePoint & 0x3ff)); + } + } + private Utf8() {} } diff --git a/java/core/src/main/java/com/google/protobuf/WireFormat.java b/java/core/src/main/java/com/google/protobuf/WireFormat.java index 6a58b081..8b837ee5 100644 --- a/java/core/src/main/java/com/google/protobuf/WireFormat.java +++ b/java/core/src/main/java/com/google/protobuf/WireFormat.java @@ -47,6 +47,12 @@ public final class WireFormat { // Do not allow instantiation. private WireFormat() {} + static final int FIXED32_SIZE = 4; + static final int FIXED64_SIZE = 8; + static final int MAX_VARINT32_SIZE = 5; + static final int MAX_VARINT64_SIZE = 10; + static final int MAX_VARINT_SIZE = 10; + public static final int WIRETYPE_VARINT = 0; public static final int WIRETYPE_FIXED64 = 1; public static final int WIRETYPE_LENGTH_DELIMITED = 2; |