From 3e944aec9ebdf5043780fba751d604c0a55511f2 Mon Sep 17 00:00:00 2001 From: Anuraag Agrawal Date: Tue, 31 Oct 2017 15:12:11 +0900 Subject: Add a UTF-8 decoder that uses Unsafe to directly decode a byte buffer. --- .../java/com/google/protobuf/CodedInputStream.java | 104 +++-- .../main/java/com/google/protobuf/UnsafeUtil.java | 35 +- .../src/main/java/com/google/protobuf/Utf8.java | 484 ++++++++++++++++++++- .../java/com/google/protobuf/DecodeUtf8Test.java | 325 ++++++++++++++ .../com/google/protobuf/IsValidUtf8TestUtil.java | 9 + 5 files changed, 920 insertions(+), 37 deletions(-) create mode 100644 java/core/src/test/java/com/google/protobuf/DecodeUtf8Test.java (limited to 'java') diff --git a/java/core/src/main/java/com/google/protobuf/CodedInputStream.java b/java/core/src/main/java/com/google/protobuf/CodedInputStream.java index e08a993b..7a3f0eb9 100644 --- a/java/core/src/main/java/com/google/protobuf/CodedInputStream.java +++ b/java/core/src/main/java/com/google/protobuf/CodedInputStream.java @@ -64,6 +64,14 @@ public abstract class CodedInputStream { // Integer.MAX_VALUE == 0x7FFFFFF == INT_MAX from limits.h private static final int DEFAULT_SIZE_LIMIT = Integer.MAX_VALUE; + /** + * Whether to enable our custom UTF-8 decode codepath which does not use {@link StringCoding}. + * Enabled by default, disable by setting + * {@code -Dcom.google.protobuf.enableCustomutf8Decode=false} in JVM args. + */ + private static final boolean ENABLE_CUSTOM_UTF8_DECODE + = !"false".equals(System.getProperty("com.google.protobuf.enableCustomUtf8Decode")); + /** Visible for subclasses. See setRecursionLimit() */ int recursionDepth; @@ -825,13 +833,19 @@ public abstract class CodedInputStream { public String readStringRequireUtf8() throws IOException { final int size = readRawVarint32(); if (size > 0 && size <= (limit - pos)) { - // TODO(martinrb): We could save a pass by validating while decoding. - if (!Utf8.isValidUtf8(buffer, pos, pos + size)) { - throw InvalidProtocolBufferException.invalidUtf8(); + if (ENABLE_CUSTOM_UTF8_DECODE) { + String result = Utf8.decodeUtf8(buffer, pos, size); + pos += size; + return result; + } else { + // TODO(martinrb): We could save a pass by validating while decoding. + if (!Utf8.isValidUtf8(buffer, pos, pos + size)) { + throw InvalidProtocolBufferException.invalidUtf8(); + } + final int tempPos = pos; + pos += size; + return new String(buffer, tempPos, size, UTF_8); } - final int tempPos = pos; - pos += size; - return new String(buffer, tempPos, size, UTF_8); } if (size == 0) { @@ -1524,6 +1538,8 @@ public abstract class CodedInputStream { final int size = readRawVarint32(); if (size > 0 && size <= remaining()) { // TODO(nathanmittler): Is there a way to avoid this copy? + // TODO(anuraaga): It might be possible to share the optimized loop with + // readStringRequireUtf8 by implementing Java replacement logic there. // The same as readBytes' logic byte[] bytes = new byte[size]; UnsafeUtil.copyMemory(pos, bytes, 0, size); @@ -1544,19 +1560,26 @@ public abstract class CodedInputStream { @Override public String readStringRequireUtf8() throws IOException { final int size = readRawVarint32(); - if (size >= 0 && size <= remaining()) { - // TODO(nathanmittler): Is there a way to avoid this copy? - // The same as readBytes' logic - byte[] bytes = new byte[size]; - UnsafeUtil.copyMemory(pos, bytes, 0, size); - // TODO(martinrb): We could save a pass by validating while decoding. - if (!Utf8.isValidUtf8(bytes)) { - throw InvalidProtocolBufferException.invalidUtf8(); - } + if (size > 0 && size <= remaining()) { + if (ENABLE_CUSTOM_UTF8_DECODE) { + final int bufferPos = bufferPos(pos); + String result = Utf8.decodeUtf8(buffer, bufferPos, size); + pos += size; + return result; + } else { + // TODO(nathanmittler): Is there a way to avoid this copy? + // The same as readBytes' logic + byte[] bytes = new byte[size]; + UnsafeUtil.copyMemory(pos, bytes, 0, size); + // TODO(martinrb): We could save a pass by validating while decoding. + if (!Utf8.isValidUtf8(bytes)) { + throw InvalidProtocolBufferException.invalidUtf8(); + } - String result = new String(bytes, UTF_8); - pos += size; - return result; + String result = new String(bytes, UTF_8); + pos += size; + return result; + } } if (size == 0) { @@ -2324,11 +2347,15 @@ public abstract class CodedInputStream { bytes = readRawBytesSlowPath(size); tempPos = 0; } - // TODO(martinrb): We could save a pass by validating while decoding. - if (!Utf8.isValidUtf8(bytes, tempPos, tempPos + size)) { - throw InvalidProtocolBufferException.invalidUtf8(); + if (ENABLE_CUSTOM_UTF8_DECODE) { + return Utf8.decodeUtf8(bytes, tempPos, size); + } else { + // TODO(martinrb): We could save a pass by validating while decoding. + if (!Utf8.isValidUtf8(bytes, tempPos, tempPos + size)) { + throw InvalidProtocolBufferException.invalidUtf8(); + } + return new String(bytes, tempPos, size, UTF_8); } - return new String(bytes, tempPos, size, UTF_8); } @Override @@ -3348,23 +3375,34 @@ public abstract class CodedInputStream { public String readStringRequireUtf8() throws IOException { final int size = readRawVarint32(); if (size > 0 && size <= currentByteBufferLimit - currentByteBufferPos) { - byte[] bytes = new byte[size]; - UnsafeUtil.copyMemory(currentByteBufferPos, bytes, 0, size); - if (!Utf8.isValidUtf8(bytes)) { - throw InvalidProtocolBufferException.invalidUtf8(); + if (ENABLE_CUSTOM_UTF8_DECODE) { + final int bufferPos = (int) (currentByteBufferPos - currentByteBufferStartPos); + String result = Utf8.decodeUtf8(currentByteBuffer, bufferPos, size); + currentByteBufferPos += size; + return result; + } else { + byte[] bytes = new byte[size]; + UnsafeUtil.copyMemory(currentByteBufferPos, bytes, 0, size); + if (!Utf8.isValidUtf8(bytes)) { + throw InvalidProtocolBufferException.invalidUtf8(); + } + String result = new String(bytes, UTF_8); + currentByteBufferPos += size; + return result; } - String result = new String(bytes, UTF_8); - currentByteBufferPos += size; - return result; } if (size >= 0 && size <= remaining()) { byte[] bytes = new byte[size]; readRawBytesTo(bytes, 0, size); - if (!Utf8.isValidUtf8(bytes)) { - throw InvalidProtocolBufferException.invalidUtf8(); + if (ENABLE_CUSTOM_UTF8_DECODE) { + return Utf8.decodeUtf8(bytes, 0, size); + } else { + if (!Utf8.isValidUtf8(bytes)) { + throw InvalidProtocolBufferException.invalidUtf8(); + } + String result = new String(bytes, UTF_8); + return result; } - String result = new String(bytes, UTF_8); - return result; } if (size == 0) { diff --git a/java/core/src/main/java/com/google/protobuf/UnsafeUtil.java b/java/core/src/main/java/com/google/protobuf/UnsafeUtil.java index 88315cb6..76fc687e 100644 --- a/java/core/src/main/java/com/google/protobuf/UnsafeUtil.java +++ b/java/core/src/main/java/com/google/protobuf/UnsafeUtil.java @@ -33,7 +33,6 @@ package com.google.protobuf; import java.lang.reflect.Field; import java.nio.Buffer; import java.nio.ByteBuffer; -import java.nio.ByteOrder; import java.security.AccessController; import java.security.PrivilegedExceptionAction; import java.util.logging.Level; @@ -72,6 +71,8 @@ final class UnsafeUtil { private static final long BUFFER_ADDRESS_OFFSET = fieldOffset(bufferAddressField()); + private static final long STRING_VALUE_OFFSET = fieldOffset(stringValueField()); + private UnsafeUtil() {} static boolean hasUnsafeArrayOperations() { @@ -259,6 +260,26 @@ final class UnsafeUtil { return MEMORY_ACCESSOR.getLong(buffer, BUFFER_ADDRESS_OFFSET); } + /** + * Returns a new {@link String} backed by the given {@code chars}. The char array should not + * be mutated any more after calling this function. + */ + static String moveToString(char[] chars) { + if (STRING_VALUE_OFFSET == -1) { + // In the off-chance that this JDK does not implement String as we'd expect, just do a copy. + return new String(chars); + } + final String str; + try { + str = (String) UNSAFE.allocateInstance(String.class); + } catch (InstantiationException e) { + // This should never happen, but return a copy as a fallback just in case. + return new String(chars); + } + putObject(str, STRING_VALUE_OFFSET, chars); + return str; + } + static Object getStaticObject(Field field) { return MEMORY_ACCESSOR.getStaticObject(field); } @@ -375,7 +396,12 @@ final class UnsafeUtil { /** Finds the address field within a direct {@link Buffer}. */ private static Field bufferAddressField() { - return field(Buffer.class, "address"); + return field(Buffer.class, "address", long.class); + } + + /** Finds the value field within a {@link String}. */ + private static Field stringValueField() { + return field(String.class, "value", char[].class); } /** @@ -390,11 +416,14 @@ final class UnsafeUtil { * Gets the field with the given name within the class, or {@code null} if not found. If found, * the field is made accessible. */ - private static Field field(Class clazz, String fieldName) { + private static Field field(Class clazz, String fieldName, Class expectedType) { Field field; try { field = clazz.getDeclaredField(fieldName); field.setAccessible(true); + if (!field.getType().equals(expectedType)) { + return null; + } } catch (Throwable t) { // Failed to access the fields. field = null; diff --git a/java/core/src/main/java/com/google/protobuf/Utf8.java b/java/core/src/main/java/com/google/protobuf/Utf8.java index 1b136144..6968abb3 100644 --- a/java/core/src/main/java/com/google/protobuf/Utf8.java +++ b/java/core/src/main/java/com/google/protobuf/Utf8.java @@ -34,11 +34,15 @@ import static com.google.protobuf.UnsafeUtil.addressOffset; import static com.google.protobuf.UnsafeUtil.hasUnsafeArrayOperations; import static com.google.protobuf.UnsafeUtil.hasUnsafeByteBufferOperations; import static java.lang.Character.MAX_SURROGATE; +import static java.lang.Character.MIN_HIGH_SURROGATE; +import static java.lang.Character.MIN_LOW_SURROGATE; +import static java.lang.Character.MIN_SUPPLEMENTARY_CODE_POINT; import static java.lang.Character.MIN_SURROGATE; import static java.lang.Character.isSurrogatePair; import static java.lang.Character.toCodePoint; import java.nio.ByteBuffer; +import java.util.Arrays; /** * A set of low-level, high-performance static utility methods related @@ -289,7 +293,7 @@ final class Utf8 { if (Character.MIN_SURROGATE <= c && c <= Character.MAX_SURROGATE) { // Check that we have a well-formed surrogate pair. int cp = Character.codePointAt(sequence, i); - if (cp < Character.MIN_SUPPLEMENTARY_CODE_POINT) { + if (cp < MIN_SUPPLEMENTARY_CODE_POINT) { throw new UnpairedSurrogateException(i, utf16Length); } i++; @@ -330,6 +334,26 @@ final class Utf8 { return processor.partialIsValidUtf8(state, buffer, index, limit); } + /** + * Decodes the given UTF-8 portion of the {@link ByteBuffer} into a {@link String}. + * + * @throws InvalidProtocolBufferException if the input is not valid UTF-8. + */ + static String decodeUtf8(ByteBuffer buffer, int index, int size) + throws InvalidProtocolBufferException { + return processor.decodeUtf8(buffer, index, size); + } + + /** + * Decodes the given UTF-8 encoded byte array slice into a {@link String}. + * + * @throws InvalidProtocolBufferException if the input is not valid UTF-8. + */ + static String decodeUtf8(byte[] bytes, int index, int size) + throws InvalidProtocolBufferException { + return processor.decodeUtf8(bytes, index, size); + } + /** * Encodes the given characters to the target {@link ByteBuffer} using UTF-8 encoding. * @@ -609,6 +633,116 @@ final class Utf8 { } } + /** + * Decodes the given byte array slice into a {@link String}. + * + * @throws InvalidProtocolBufferException if the byte array slice is not valid UTF-8. + */ + abstract String decodeUtf8(byte[] bytes, int index, int size) + throws InvalidProtocolBufferException; + + /** + * Decodes the given portion of the {@link ByteBuffer} into a {@link String}. + * + * @throws InvalidProtocolBufferException if the portion of the buffer is not valid UTF-8. + */ + final String decodeUtf8(ByteBuffer buffer, int index, int size) + throws InvalidProtocolBufferException { + if (buffer.hasArray()) { + final int offset = buffer.arrayOffset(); + return decodeUtf8(buffer.array(), offset + index, size); + } else if (buffer.isDirect()) { + return decodeUtf8Direct(buffer, index, size); + } + return decodeUtf8Default(buffer, index, size); + } + + /** + * Decodes direct {@link ByteBuffer} instances into {@link String}. + */ + abstract String decodeUtf8Direct(ByteBuffer buffer, int index, int size) + throws InvalidProtocolBufferException; + + /** + * Decodes {@link ByteBuffer} instances using the {@link ByteBuffer} API rather than + * potentially faster approaches. + */ + final String decodeUtf8Default(ByteBuffer buffer, int index, int size) + throws InvalidProtocolBufferException { + // Bitwise OR combines the sign bits so any negative value fails the check. + if ((index | size | buffer.limit() - index - size) < 0) { + throw new ArrayIndexOutOfBoundsException( + String.format("buffer limit=%d, index=%d, limit=%d", buffer.limit(), index, size)); + } + + int offset = index; + final int limit = offset + size; + + // The longest possible resulting String is the same as the number of input bytes, when it is + // all ASCII. For other cases, this over-allocates and we will truncate in the end. + char[] resultArr = new char[size]; + int resultPos = 0; + + // Optimize for 100% ASCII (Hotspot loves small simple top-level loops like this). + // This simple loop stops when we encounter a byte >= 0x80 (i.e. non-ASCII). + while (offset < limit) { + byte b = buffer.get(offset); + if (!DecodeUtil.isOneByte(b)) { + break; + } + offset++; + DecodeUtil.handleOneByte(b, resultArr, resultPos++); + } + + while (offset < limit) { + byte byte1 = buffer.get(offset++); + if (DecodeUtil.isOneByte(byte1)) { + DecodeUtil.handleOneByte(byte1, resultArr, resultPos++); + // It's common for there to be multiple ASCII characters in a run mixed in, so add an + // extra optimized loop to take care of these runs. + while (offset < limit) { + byte b = buffer.get(offset); + if (!DecodeUtil.isOneByte(b)) { + break; + } + offset++; + DecodeUtil.handleOneByte(b, resultArr, resultPos++); + } + } else if (DecodeUtil.isTwoBytes(byte1)) { + if (offset >= limit) { + throw InvalidProtocolBufferException.invalidUtf8(); + } + DecodeUtil.handleTwoBytes( + byte1, /* byte2 */ buffer.get(offset++), resultArr, resultPos++); + } else if (DecodeUtil.isThreeBytes(byte1)) { + if (offset >= limit - 1) { + throw InvalidProtocolBufferException.invalidUtf8(); + } + DecodeUtil.handleThreeBytes( + byte1, + /* byte2 */ buffer.get(offset++), + /* byte3 */ buffer.get(offset++), + resultArr, + resultPos++); + } else { + if (offset >= limit - 2) { + throw InvalidProtocolBufferException.invalidUtf8(); + } + DecodeUtil.handleFourBytes( + byte1, + /* byte2 */ buffer.get(offset++), + /* byte3 */ buffer.get(offset++), + /* byte4 */ buffer.get(offset++), + resultArr, + resultPos++); + // 4-byte case requires two chars. + resultPos++; + } + } + + return new String(resultArr, 0, resultPos); + } + /** * Encodes an input character sequence ({@code in}) to UTF-8 in the target array ({@code out}). * For a string, this method is similar to @@ -850,6 +984,88 @@ final class Utf8 { return partialIsValidUtf8Default(state, buffer, index, limit); } + @Override + String decodeUtf8(byte[] bytes, int index, int size) throws InvalidProtocolBufferException { + // Bitwise OR combines the sign bits so any negative value fails the check. + if ((index | size | bytes.length - index - size) < 0) { + throw new ArrayIndexOutOfBoundsException( + String.format("buffer length=%d, index=%d, size=%d", bytes.length, index, size)); + } + + int offset = index; + final int limit = offset + size; + + // The longest possible resulting String is the same as the number of input bytes, when it is + // all ASCII. For other cases, this over-allocates and we will truncate in the end. + char[] resultArr = new char[size]; + int resultPos = 0; + + // Optimize for 100% ASCII (Hotspot loves small simple top-level loops like this). + // This simple loop stops when we encounter a byte >= 0x80 (i.e. non-ASCII). + while (offset < limit) { + byte b = bytes[offset]; + if (!DecodeUtil.isOneByte(b)) { + break; + } + offset++; + DecodeUtil.handleOneByte(b, resultArr, resultPos++); + } + + while (offset < limit) { + byte byte1 = bytes[offset++]; + if (DecodeUtil.isOneByte(byte1)) { + DecodeUtil.handleOneByte(byte1, resultArr, resultPos++); + // It's common for there to be multiple ASCII characters in a run mixed in, so add an + // extra optimized loop to take care of these runs. + while (offset < limit) { + byte b = bytes[offset]; + if (!DecodeUtil.isOneByte(b)) { + break; + } + offset++; + DecodeUtil.handleOneByte(b, resultArr, resultPos++); + } + } else if (DecodeUtil.isTwoBytes(byte1)) { + if (offset >= limit) { + throw InvalidProtocolBufferException.invalidUtf8(); + } + DecodeUtil.handleTwoBytes(byte1, /* byte2 */ bytes[offset++], resultArr, resultPos++); + } else if (DecodeUtil.isThreeBytes(byte1)) { + if (offset >= limit - 1) { + throw InvalidProtocolBufferException.invalidUtf8(); + } + DecodeUtil.handleThreeBytes( + byte1, + /* byte2 */ bytes[offset++], + /* byte3 */ bytes[offset++], + resultArr, + resultPos++); + } else { + if (offset >= limit - 2) { + throw InvalidProtocolBufferException.invalidUtf8(); + } + DecodeUtil.handleFourBytes( + byte1, + /* byte2 */ bytes[offset++], + /* byte3 */ bytes[offset++], + /* byte4 */ bytes[offset++], + resultArr, + resultPos++); + // 4-byte case requires two chars. + resultPos++; + } + } + + return new String(resultArr, 0, resultPos); + } + + @Override + String decodeUtf8Direct(ByteBuffer buffer, int index, int size) + throws InvalidProtocolBufferException { + // For safe processing, we have to use the ByteBufferAPI. + return decodeUtf8Default(buffer, index, size); + } + @Override int encodeUtf8(CharSequence in, byte[] out, int offset, int length) { int utf16Length = in.length(); @@ -996,6 +1212,7 @@ final class Utf8 { @Override int partialIsValidUtf8(int state, byte[] bytes, final int index, final int limit) { + // Bitwise OR combines the sign bits so any negative value fails the check. if ((index | limit | bytes.length - limit) < 0) { throw new ArrayIndexOutOfBoundsException( String.format("Array length=%d, index=%d, limit=%d", bytes.length, index, limit)); @@ -1091,6 +1308,7 @@ final class Utf8 { @Override int partialIsValidUtf8Direct( final int state, ByteBuffer buffer, final int index, final int limit) { + // Bitwise OR combines the sign bits so any negative value fails the check. if ((index | limit | buffer.limit() - limit) < 0) { throw new ArrayIndexOutOfBoundsException( String.format("buffer limit=%d, index=%d, limit=%d", buffer.limit(), index, limit)); @@ -1184,6 +1402,163 @@ final class Utf8 { return partialIsValidUtf8(address, (int) (addressLimit - address)); } + @Override + String decodeUtf8(byte[] bytes, int index, int size) throws InvalidProtocolBufferException { + if ((index | size | bytes.length - index - size) < 0) { + throw new ArrayIndexOutOfBoundsException( + String.format("buffer length=%d, index=%d, size=%d", bytes.length, index, size)); + } + + int offset = index; + final int limit = offset + size; + + // The longest possible resulting String is the same as the number of input bytes, when it is + // all ASCII. For other cases, this over-allocates and we will truncate in the end. + char[] resultArr = new char[size]; + int resultPos = 0; + + // Optimize for 100% ASCII (Hotspot loves small simple top-level loops like this). + // This simple loop stops when we encounter a byte >= 0x80 (i.e. non-ASCII). + while (offset < limit) { + byte b = UnsafeUtil.getByte(bytes, offset); + if (!DecodeUtil.isOneByte(b)) { + break; + } + offset++; + DecodeUtil.handleOneByte(b, resultArr, resultPos++); + } + + while (offset < limit) { + byte byte1 = UnsafeUtil.getByte(bytes, offset++); + if (DecodeUtil.isOneByte(byte1)) { + DecodeUtil.handleOneByte(byte1, resultArr, resultPos++); + // It's common for there to be multiple ASCII characters in a run mixed in, so add an + // extra optimized loop to take care of these runs. + while (offset < limit) { + byte b = UnsafeUtil.getByte(bytes, offset); + if (!DecodeUtil.isOneByte(b)) { + break; + } + offset++; + DecodeUtil.handleOneByte(b, resultArr, resultPos++); + } + } else if (DecodeUtil.isTwoBytes(byte1)) { + if (offset >= limit) { + throw InvalidProtocolBufferException.invalidUtf8(); + } + DecodeUtil.handleTwoBytes( + byte1, /* byte2 */ UnsafeUtil.getByte(bytes, offset++), resultArr, resultPos++); + } else if (DecodeUtil.isThreeBytes(byte1)) { + if (offset >= limit - 1) { + throw InvalidProtocolBufferException.invalidUtf8(); + } + DecodeUtil.handleThreeBytes( + byte1, + /* byte2 */ UnsafeUtil.getByte(bytes, offset++), + /* byte3 */ UnsafeUtil.getByte(bytes, offset++), + resultArr, + resultPos++); + } else { + if (offset >= limit - 2) { + throw InvalidProtocolBufferException.invalidUtf8(); + } + DecodeUtil.handleFourBytes( + byte1, + /* byte2 */ UnsafeUtil.getByte(bytes, offset++), + /* byte3 */ UnsafeUtil.getByte(bytes, offset++), + /* byte4 */ UnsafeUtil.getByte(bytes, offset++), + resultArr, + resultPos++); + // 4-byte case requires two chars. + resultPos++; + } + } + + if (resultPos < resultArr.length) { + resultArr = Arrays.copyOf(resultArr, resultPos); + } + return UnsafeUtil.moveToString(resultArr); + } + + @Override + String decodeUtf8Direct(ByteBuffer buffer, int index, int size) + throws InvalidProtocolBufferException { + // Bitwise OR combines the sign bits so any negative value fails the check. + if ((index | size | buffer.limit() - index - size) < 0) { + throw new ArrayIndexOutOfBoundsException( + String.format("buffer limit=%d, index=%d, limit=%d", buffer.limit(), index, size)); + } + long address = UnsafeUtil.addressOffset(buffer) + index; + final long addressLimit = address + size; + + // The longest possible resulting String is the same as the number of input bytes, when it is + // all ASCII. For other cases, this over-allocates and we will truncate in the end. + char[] resultArr = new char[size]; + int resultPos = 0; + + // Optimize for 100% ASCII (Hotspot loves small simple top-level loops like this). + // This simple loop stops when we encounter a byte >= 0x80 (i.e. non-ASCII). + while (address < addressLimit) { + byte b = UnsafeUtil.getByte(address); + if (!DecodeUtil.isOneByte(b)) { + break; + } + address++; + DecodeUtil.handleOneByte(b, resultArr, resultPos++); + } + + while (address < addressLimit) { + byte byte1 = UnsafeUtil.getByte(address++); + if (DecodeUtil.isOneByte(byte1)) { + DecodeUtil.handleOneByte(byte1, resultArr, resultPos++); + // It's common for there to be multiple ASCII characters in a run mixed in, so add an + // extra optimized loop to take care of these runs. + while (address < addressLimit) { + byte b = UnsafeUtil.getByte(address); + if (!DecodeUtil.isOneByte(b)) { + break; + } + address++; + DecodeUtil.handleOneByte(b, resultArr, resultPos++); + } + } else if (DecodeUtil.isTwoBytes(byte1)) { + if (address >= addressLimit) { + throw InvalidProtocolBufferException.invalidUtf8(); + } + DecodeUtil.handleTwoBytes( + byte1, /* byte2 */ UnsafeUtil.getByte(address++), resultArr, resultPos++); + } else if (DecodeUtil.isThreeBytes(byte1)) { + if (address >= addressLimit - 1) { + throw InvalidProtocolBufferException.invalidUtf8(); + } + DecodeUtil.handleThreeBytes( + byte1, + /* byte2 */ UnsafeUtil.getByte(address++), + /* byte3 */ UnsafeUtil.getByte(address++), + resultArr, + resultPos++); + } else { + if (address >= addressLimit - 2) { + throw InvalidProtocolBufferException.invalidUtf8(); + } + DecodeUtil.handleFourBytes( + byte1, + /* byte2 */ UnsafeUtil.getByte(address++), + /* byte3 */ UnsafeUtil.getByte(address++), + /* byte4 */ UnsafeUtil.getByte(address++), + resultArr, + resultPos++); + // 4-byte case requires two chars. + resultPos++; + } + } + + if (resultPos < resultArr.length) { + resultArr = Arrays.copyOf(resultArr, resultPos); + } + return UnsafeUtil.moveToString(resultArr); + } + @Override int encodeUtf8(final CharSequence in, final byte[] out, final int offset, final int length) { long outIx = offset; @@ -1554,5 +1929,112 @@ final class Utf8 { } } + /** + * Utility methods for decoding bytes into {@link String}. Callers are responsible for extracting + * bytes (possibly using Unsafe methods), and checking remaining bytes. All other UTF-8 validity + * checks and codepoint conversion happen in this class. + */ + private static class DecodeUtil { + + /** + * Returns whether this is a single-byte codepoint (i.e., ASCII) with the form '0XXXXXXX'. + */ + private static boolean isOneByte(byte b) { + return b >= 0; + } + + /** + * Returns whether this is a two-byte codepoint with the form '10XXXXXX'. + */ + private static boolean isTwoBytes(byte b) { + return b < (byte) 0xE0; + } + + /** + * Returns whether this is a three-byte codepoint with the form '110XXXXX'. + */ + private static boolean isThreeBytes(byte b) { + return b < (byte) 0xF0; + } + + private static void handleOneByte(byte byte1, char[] resultArr, int resultPos) { + resultArr[resultPos] = (char) byte1; + } + + private static void handleTwoBytes( + byte byte1, byte byte2, char[] resultArr, int resultPos) + throws InvalidProtocolBufferException { + // Simultaneously checks for illegal trailing-byte in leading position (<= '11000000') and + // overlong 2-byte, '11000001'. + if (byte1 < (byte) 0xC2 + || isNotTrailingByte(byte2)) { + throw InvalidProtocolBufferException.invalidUtf8(); + } + resultArr[resultPos] = (char) (((byte1 & 0x1F) << 6) | trailingByteValue(byte2)); + } + + private static void handleThreeBytes( + byte byte1, byte byte2, byte byte3, char[] resultArr, int resultPos) + throws InvalidProtocolBufferException { + if (isNotTrailingByte(byte2) + // overlong? 5 most significant bits must not all be zero + || (byte1 == (byte) 0xE0 && byte2 < (byte) 0xA0) + // check for illegal surrogate codepoints + || (byte1 == (byte) 0xED && byte2 >= (byte) 0xA0) + || isNotTrailingByte(byte3)) { + throw InvalidProtocolBufferException.invalidUtf8(); + } + resultArr[resultPos] = (char) + (((byte1 & 0x0F) << 12) | (trailingByteValue(byte2) << 6) | trailingByteValue(byte3)); + } + + private static void handleFourBytes( + byte byte1, byte byte2, byte byte3, byte byte4, char[] resultArr, int resultPos) + throws InvalidProtocolBufferException{ + if (isNotTrailingByte(byte2) + // Check that 1 <= plane <= 16. Tricky optimized form of: + // valid 4-byte leading byte? + // if (byte1 > (byte) 0xF4 || + // overlong? 4 most significant bits must not all be zero + // byte1 == (byte) 0xF0 && byte2 < (byte) 0x90 || + // codepoint larger than the highest code point (U+10FFFF)? + // byte1 == (byte) 0xF4 && byte2 > (byte) 0x8F) + || (((byte1 << 28) + (byte2 - (byte) 0x90)) >> 30) != 0 + || isNotTrailingByte(byte3) + || isNotTrailingByte(byte4)) { + throw InvalidProtocolBufferException.invalidUtf8(); + } + int codepoint = ((byte1 & 0x07) << 18) + | (trailingByteValue(byte2) << 12) + | (trailingByteValue(byte3) << 6) + | trailingByteValue(byte4); + resultArr[resultPos] = DecodeUtil.highSurrogate(codepoint); + resultArr[resultPos + 1] = DecodeUtil.lowSurrogate(codepoint); + } + + /** + * Returns whether the byte is not a valid continuation of the form '10XXXXXX'. + */ + private static boolean isNotTrailingByte(byte b) { + return b > (byte) 0xBF; + } + + /** + * Returns the actual value of the trailing byte (removes the prefix '10') for composition. + */ + private static int trailingByteValue(byte b) { + return b & 0x3F; + } + + private static char highSurrogate(int codePoint) { + return (char) ((MIN_HIGH_SURROGATE - (MIN_SUPPLEMENTARY_CODE_POINT >>> 10)) + + (codePoint >>> 10)); + } + + private static char lowSurrogate(int codePoint) { + return (char) (MIN_LOW_SURROGATE + (codePoint & 0x3ff)); + } + } + private Utf8() {} } diff --git a/java/core/src/test/java/com/google/protobuf/DecodeUtf8Test.java b/java/core/src/test/java/com/google/protobuf/DecodeUtf8Test.java new file mode 100644 index 00000000..359d4d74 --- /dev/null +++ b/java/core/src/test/java/com/google/protobuf/DecodeUtf8Test.java @@ -0,0 +1,325 @@ +package com.google.protobuf; + +import com.google.protobuf.Utf8.Processor; +import com.google.protobuf.Utf8.SafeProcessor; +import com.google.protobuf.Utf8.UnsafeProcessor; +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.List; +import java.util.logging.Logger; +import junit.framework.TestCase; + +public class DecodeUtf8Test extends TestCase { + private static Logger logger = Logger.getLogger(DecodeUtf8Test.class.getName()); + + private static final Processor SAFE_PROCESSOR = new SafeProcessor(); + private static final Processor UNSAFE_PROCESSOR = new UnsafeProcessor(); + + public void testRoundTripAllValidChars() throws Exception { + for (int i = Character.MIN_CODE_POINT; i < Character.MAX_CODE_POINT; i++) { + if (i < Character.MIN_SURROGATE || i > Character.MAX_SURROGATE) { + String str = new String(Character.toChars(i)); + assertRoundTrips(str); + } + } + } + + // Test all 1, 2, 3 invalid byte combinations. Valid ones would have been covered above. + + public void testOneByte() throws Exception { + int valid = 0; + for (int i = Byte.MIN_VALUE; i <= Byte.MAX_VALUE; i++) { + ByteString bs = ByteString.copyFrom(new byte[] { (byte) i }); + if (!bs.isValidUtf8()) { + assertInvalid(bs.toByteArray()); + } else { + valid++; + } + } + assertEquals(IsValidUtf8TestUtil.EXPECTED_ONE_BYTE_ROUNDTRIPPABLE_COUNT, valid); + } + + public void testTwoBytes() throws Exception { + int valid = 0; + for (int i = Byte.MIN_VALUE; i <= Byte.MAX_VALUE; i++) { + for (int j = Byte.MIN_VALUE; j <= Byte.MAX_VALUE; j++) { + ByteString bs = ByteString.copyFrom(new byte[]{(byte) i, (byte) j}); + if (!bs.isValidUtf8()) { + assertInvalid(bs.toByteArray()); + } else { + valid++; + } + } + } + assertEquals(IsValidUtf8TestUtil.EXPECTED_TWO_BYTE_ROUNDTRIPPABLE_COUNT, valid); + } + + public void testThreeBytes() throws Exception { + // Travis' OOM killer doesn't like this test + if (System.getenv("TRAVIS") == null) { + int count = 0; + int valid = 0; + for (int i = Byte.MIN_VALUE; i <= Byte.MAX_VALUE; i++) { + for (int j = Byte.MIN_VALUE; j <= Byte.MAX_VALUE; j++) { + for (int k = Byte.MIN_VALUE; k <= Byte.MAX_VALUE; k++) { + byte[] bytes = new byte[]{(byte) i, (byte) j, (byte) k}; + ByteString bs = ByteString.copyFrom(bytes); + if (!bs.isValidUtf8()) { + assertInvalid(bytes); + } else { + valid++; + } + count++; + if (count % 1000000L == 0) { + logger.info("Processed " + (count / 1000000L) + " million characters"); + } + } + } + } + assertEquals(IsValidUtf8TestUtil.EXPECTED_THREE_BYTE_ROUNDTRIPPABLE_COUNT, valid); + } + } + + /** + * Tests that round tripping of a sample of four byte permutations work. + */ + public void testInvalid_4BytesSamples() throws Exception { + // Bad trailing bytes + assertInvalid(0xF0, 0xA4, 0xAD, 0x7F); + assertInvalid(0xF0, 0xA4, 0xAD, 0xC0); + + // Special cases for byte2 + assertInvalid(0xF0, 0x8F, 0xAD, 0xA2); + assertInvalid(0xF4, 0x90, 0xAD, 0xA2); + } + + public void testRealStrings() throws Exception { + // English + assertRoundTrips("The quick brown fox jumps over the lazy dog"); + // German + assertRoundTrips("Quizdeltagerne spiste jordb\u00e6r med fl\u00f8de, mens cirkusklovnen"); + // Japanese + assertRoundTrips( + "\u3044\u308d\u306f\u306b\u307b\u3078\u3068\u3061\u308a\u306c\u308b\u3092"); + // Hebrew + assertRoundTrips( + "\u05d3\u05d2 \u05e1\u05e7\u05e8\u05df \u05e9\u05d8 \u05d1\u05d9\u05dd " + + "\u05de\u05d0\u05d5\u05db\u05d6\u05d1 \u05d5\u05dc\u05e4\u05ea\u05e2" + + " \u05de\u05e6\u05d0 \u05dc\u05d5 \u05d7\u05d1\u05e8\u05d4 " + + "\u05d0\u05d9\u05da \u05d4\u05e7\u05dc\u05d9\u05d8\u05d4"); + // Thai + assertRoundTrips( + " \u0e08\u0e07\u0e1d\u0e48\u0e32\u0e1f\u0e31\u0e19\u0e1e\u0e31\u0e12" + + "\u0e19\u0e32\u0e27\u0e34\u0e0a\u0e32\u0e01\u0e32\u0e23"); + // Chinese + assertRoundTrips( + "\u8fd4\u56de\u94fe\u4e2d\u7684\u4e0b\u4e00\u4e2a\u4ee3\u7406\u9879\u9009\u62e9\u5668"); + // Chinese with 4-byte chars + assertRoundTrips("\uD841\uDF0E\uD841\uDF31\uD841\uDF79\uD843\uDC53\uD843\uDC78" + + "\uD843\uDC96\uD843\uDCCF\uD843\uDCD5\uD843\uDD15\uD843\uDD7C\uD843\uDD7F" + + "\uD843\uDE0E\uD843\uDE0F\uD843\uDE77\uD843\uDE9D\uD843\uDEA2"); + // Mixed + assertRoundTrips( + "The quick brown \u3044\u308d\u306f\u306b\u307b\u3078\u8fd4\u56de\u94fe" + + "\u4e2d\u7684\u4e0b\u4e00"); + } + + public void testOverlong() throws Exception { + assertInvalid(0xc0, 0xaf); + assertInvalid(0xe0, 0x80, 0xaf); + assertInvalid(0xf0, 0x80, 0x80, 0xaf); + + // Max overlong + assertInvalid(0xc1, 0xbf); + assertInvalid(0xe0, 0x9f, 0xbf); + assertInvalid(0xf0 ,0x8f, 0xbf, 0xbf); + + // null overlong + assertInvalid(0xc0, 0x80); + assertInvalid(0xe0, 0x80, 0x80); + assertInvalid(0xf0, 0x80, 0x80, 0x80); + } + + public void testIllegalCodepoints() throws Exception { + // Single surrogate + assertInvalid(0xed, 0xa0, 0x80); + assertInvalid(0xed, 0xad, 0xbf); + assertInvalid(0xed, 0xae, 0x80); + assertInvalid(0xed, 0xaf, 0xbf); + assertInvalid(0xed, 0xb0, 0x80); + assertInvalid(0xed, 0xbe, 0x80); + assertInvalid(0xed, 0xbf, 0xbf); + + // Paired surrogates + assertInvalid(0xed, 0xa0, 0x80, 0xed, 0xb0, 0x80); + assertInvalid(0xed, 0xa0, 0x80, 0xed, 0xbf, 0xbf); + assertInvalid(0xed, 0xad, 0xbf, 0xed, 0xb0, 0x80); + assertInvalid(0xed, 0xad, 0xbf, 0xed, 0xbf, 0xbf); + assertInvalid(0xed, 0xae, 0x80, 0xed, 0xb0, 0x80); + assertInvalid(0xed, 0xae, 0x80, 0xed, 0xbf, 0xbf); + assertInvalid(0xed, 0xaf, 0xbf, 0xed, 0xb0, 0x80); + assertInvalid(0xed, 0xaf, 0xbf, 0xed, 0xbf, 0xbf); + } + + public void testBufferSlice() throws Exception { + String str = "The quick brown fox jumps over the lazy dog"; + assertRoundTrips(str, 10, 4); + assertRoundTrips(str, str.length(), 0); + } + + public void testInvalidBufferSlice() throws Exception { + byte[] bytes = "The quick brown fox jumps over the lazy dog".getBytes(Internal.UTF_8); + assertInvalidSlice(bytes, bytes.length - 3, 4); + assertInvalidSlice(bytes, bytes.length, 1); + assertInvalidSlice(bytes, bytes.length + 1, 0); + assertInvalidSlice(bytes, 0, bytes.length + 1); + } + + private void assertInvalid(int... bytesAsInt) throws Exception { + byte[] bytes = new byte[bytesAsInt.length]; + for (int i = 0; i < bytesAsInt.length; i++) { + bytes[i] = (byte) bytesAsInt[i]; + } + assertInvalid(bytes); + } + + private void assertInvalid(byte[] bytes) throws Exception { + try { + UNSAFE_PROCESSOR.decodeUtf8(bytes, 0, bytes.length); + fail(); + } catch (InvalidProtocolBufferException e) { + // Expected. + } + try { + SAFE_PROCESSOR.decodeUtf8(bytes, 0, bytes.length); + fail(); + } catch (InvalidProtocolBufferException e) { + // Expected. + } + + ByteBuffer direct = ByteBuffer.allocateDirect(bytes.length); + direct.put(bytes); + direct.flip(); + try { + UNSAFE_PROCESSOR.decodeUtf8(direct, 0, bytes.length); + fail(); + } catch (InvalidProtocolBufferException e) { + // Expected. + } + try { + SAFE_PROCESSOR.decodeUtf8(direct, 0, bytes.length); + fail(); + } catch (InvalidProtocolBufferException e) { + // Expected. + } + + ByteBuffer heap = ByteBuffer.allocate(bytes.length); + heap.put(bytes); + heap.flip(); + try { + UNSAFE_PROCESSOR.decodeUtf8(heap, 0, bytes.length); + fail(); + } catch (InvalidProtocolBufferException e) { + // Expected. + } + try { + SAFE_PROCESSOR.decodeUtf8(heap, 0, bytes.length); + fail(); + } catch (InvalidProtocolBufferException e) { + // Expected. + } + } + + private void assertInvalidSlice(byte[] bytes, int index, int size) throws Exception { + try { + UNSAFE_PROCESSOR.decodeUtf8(bytes, index, size); + fail(); + } catch (ArrayIndexOutOfBoundsException e) { + // Expected. + } + try { + SAFE_PROCESSOR.decodeUtf8(bytes, index, size); + fail(); + } catch (ArrayIndexOutOfBoundsException e) { + // Expected. + } + + ByteBuffer direct = ByteBuffer.allocateDirect(bytes.length); + direct.put(bytes); + direct.flip(); + try { + UNSAFE_PROCESSOR.decodeUtf8(direct, index, size); + fail(); + } catch (ArrayIndexOutOfBoundsException e) { + // Expected. + } + try { + SAFE_PROCESSOR.decodeUtf8(direct, index, size); + fail(); + } catch (ArrayIndexOutOfBoundsException e) { + // Expected. + } + + ByteBuffer heap = ByteBuffer.allocate(bytes.length); + heap.put(bytes); + heap.flip(); + try { + UNSAFE_PROCESSOR.decodeUtf8(heap, index, size); + fail(); + } catch (ArrayIndexOutOfBoundsException e) { + // Expected. + } + try { + SAFE_PROCESSOR.decodeUtf8(heap, index, size); + fail(); + } catch (ArrayIndexOutOfBoundsException e) { + // Expected. + } + } + + private void assertRoundTrips(String str) throws Exception { + assertRoundTrips(str, 0, -1); + } + + private void assertRoundTrips(String str, int index, int size) throws Exception { + byte[] bytes = str.getBytes(Internal.UTF_8); + if (size == -1) { + size = bytes.length; + } + assertDecode(new String(bytes, index, size, Internal.UTF_8), + UNSAFE_PROCESSOR.decodeUtf8(bytes, index, size)); + assertDecode(new String(bytes, index, size, Internal.UTF_8), + SAFE_PROCESSOR.decodeUtf8(bytes, index, size)); + + ByteBuffer direct = ByteBuffer.allocateDirect(bytes.length); + direct.put(bytes); + direct.flip(); + assertDecode(new String(bytes, index, size, Internal.UTF_8), + UNSAFE_PROCESSOR.decodeUtf8(direct, index, size)); + assertDecode(new String(bytes, index, size, Internal.UTF_8), + SAFE_PROCESSOR.decodeUtf8(direct, index, size)); + + ByteBuffer heap = ByteBuffer.allocate(bytes.length); + heap.put(bytes); + heap.flip(); + assertDecode(new String(bytes, index, size, Internal.UTF_8), + UNSAFE_PROCESSOR.decodeUtf8(heap, index, size)); + assertDecode(new String(bytes, index, size, Internal.UTF_8), + SAFE_PROCESSOR.decodeUtf8(heap, index, size)); + } + + private void assertDecode(String expected, String actual) { + if (!expected.equals(actual)) { + fail("Failure: Expected (" + codepoints(expected) + ") Actual (" + codepoints(actual) + ")"); + } + } + + private List codepoints(String str) { + List codepoints = new ArrayList(); + for (int i = 0; i < str.length(); i++) { + codepoints.add(Long.toHexString(str.charAt(i))); + } + return codepoints; + } + +} diff --git a/java/core/src/test/java/com/google/protobuf/IsValidUtf8TestUtil.java b/java/core/src/test/java/com/google/protobuf/IsValidUtf8TestUtil.java index 16a808bf..1bcf63e7 100644 --- a/java/core/src/test/java/com/google/protobuf/IsValidUtf8TestUtil.java +++ b/java/core/src/test/java/com/google/protobuf/IsValidUtf8TestUtil.java @@ -273,6 +273,15 @@ final class IsValidUtf8TestUtil { assertEquals(isRoundTrippable, Utf8.isValidUtf8(bytes)); assertEquals(isRoundTrippable, Utf8.isValidUtf8(bytes, 0, numBytes)); + try { + assertEquals(s, Utf8.decodeUtf8(bytes, 0, numBytes)); + } catch (InvalidProtocolBufferException e) { + if (isRoundTrippable) { + System.out.println("Could not decode utf-8"); + outputFailure(byteChar, bytes, bytesReencoded); + } + } + // Test partial sequences. // Partition numBytes into three segments (not necessarily non-empty). int i = rnd.nextInt(numBytes); -- cgit v1.2.3