aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
author: Tejas Patil <tejasp@fb.com> 2017-03-07 20:19:30 -0800
committer: Wenchen Fan <wenchen@databricks.com> 2017-03-07 20:19:30 -0800
commit: c96d14abae5962a7b15239319c2a151b95f7db94 (patch)
tree: eefd1b2e220a4a7afc901e29418f3f5ee92f21d1
parent: 47b2f68a885b7a2fc593ac7a55cd19742016364d (diff)
download: spark-c96d14abae5962a7b15239319c2a151b95f7db94.tar.gz
spark-c96d14abae5962a7b15239319c2a151b95f7db94.tar.bz2
spark-c96d14abae5962a7b15239319c2a151b95f7db94.zip
[SPARK-19843][SQL] UTF8String => (int / long) conversion expensive for invalid inputs
## What changes were proposed in this pull request?

Jira : https://issues.apache.org/jira/browse/SPARK-19843

Created wrapper classes (`IntWrapper`, `LongWrapper`) to wrap the result of parsing (which are primitive types). The parsing methods now return a boolean (`true` if parsing succeeded, `false` otherwise) and store the parsed value in the wrapper, instead of throwing a `NumberFormatException` when the input is invalid.

## How was this patch tested?

- Added new unit tests
- Ran a prod job which had conversion from string -> int and verified the outputs

## Performance

Tiny regression when all strings are valid integers

```
conversion to int:     Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
--------------------------------------------------------------------------------
trunk                          502 /  522         33.4          29.9       1.0X
SPARK-19843                    493 /  503         34.0          29.4       1.0X
```

Huge gain when all strings are invalid integers

```
conversion to int:     Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
-------------------------------------------------------------------------------
trunk                        33913 / 34219          0.5        2021.4       1.0X
SPARK-19843                    154 /  162        108.8           9.2     220.0X
```

Author: Tejas Patil <tejasp@fb.com>

Closes #17184 from tejasapatil/SPARK-19843_is_numeric_maybe.
-rw-r--r-- common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java 120
-rw-r--r-- common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java 128
-rw-r--r-- sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala 81
3 files changed, 247 insertions, 82 deletions
diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
index 10a7cb1d06..7abe0fa80a 100644
--- a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
+++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
@@ -850,11 +850,8 @@ public final class UTF8String implements Comparable<UTF8String>, Externalizable,
return fromString(sb.toString());
}
- private int getDigit(byte b) {
- if (b >= '0' && b <= '9') {
- return b - '0';
- }
- throw new NumberFormatException(toString());
+ public static class LongWrapper {
+ public long value = 0;
}
/**
@@ -862,14 +859,18 @@ public final class UTF8String implements Comparable<UTF8String>, Externalizable,
*
* Note that, in this method we accumulate the result in negative format, and convert it to
* positive format at the end, if this string is not started with '-'. This is because min value
- * is bigger than max value in digits, e.g. Integer.MAX_VALUE is '2147483647' and
- * Integer.MIN_VALUE is '-2147483648'.
+ * is bigger than max value in digits, e.g. Long.MAX_VALUE is '9223372036854775807' and
+ * Long.MIN_VALUE is '-9223372036854775808'.
*
* This code is mostly copied from LazyLong.parseLong in Hive.
+ *
+ * @param toLongResult If a valid `long` was parsed from this UTF8String, then its value would
+ * be set in `toLongResult`
+ * @return true if the parsing was successful else false
*/
- public long toLong() {
+ public boolean toLong(LongWrapper toLongResult) {
if (numBytes == 0) {
- throw new NumberFormatException("Empty string");
+ return false;
}
byte b = getByte(0);
@@ -878,7 +879,7 @@ public final class UTF8String implements Comparable<UTF8String>, Externalizable,
if (negative || b == '+') {
offset++;
if (numBytes == 1) {
- throw new NumberFormatException(toString());
+ return false;
}
}
@@ -897,20 +898,25 @@ public final class UTF8String implements Comparable<UTF8String>, Externalizable,
break;
}
- int digit = getDigit(b);
+ int digit;
+ if (b >= '0' && b <= '9') {
+ digit = b - '0';
+ } else {
+ return false;
+ }
+
// We are going to process the new digit and accumulate the result. However, before doing
// this, if the result is already smaller than the stopValue(Long.MIN_VALUE / radix), then
- // result * 10 will definitely be smaller than minValue, and we can stop and throw exception.
+ // result * 10 will definitely be smaller than minValue, and we can stop.
if (result < stopValue) {
- throw new NumberFormatException(toString());
+ return false;
}
result = result * radix - digit;
// Since the previous result is less than or equal to stopValue(Long.MIN_VALUE / radix), we
- // can just use `result > 0` to check overflow. If result overflows, we should stop and throw
- // exception.
+ // can just use `result > 0` to check overflow. If result overflows, we should stop.
if (result > 0) {
- throw new NumberFormatException(toString());
+ return false;
}
}
@@ -918,8 +924,9 @@ public final class UTF8String implements Comparable<UTF8String>, Externalizable,
// part will not change the number, but we will verify that the fractional part
// is well formed.
while (offset < numBytes) {
- if (getDigit(getByte(offset)) == -1) {
- throw new NumberFormatException(toString());
+ byte currentByte = getByte(offset);
+ if (currentByte < '0' || currentByte > '9') {
+ return false;
}
offset++;
}
@@ -927,11 +934,16 @@ public final class UTF8String implements Comparable<UTF8String>, Externalizable,
if (!negative) {
result = -result;
if (result < 0) {
- throw new NumberFormatException(toString());
+ return false;
}
}
- return result;
+ toLongResult.value = result;
+ return true;
+ }
+
+ public static class IntWrapper {
+ public int value = 0;
}
/**
@@ -946,10 +958,14 @@ public final class UTF8String implements Comparable<UTF8String>, Externalizable,
*
* Note that, this method is almost same as `toLong`, but we leave it duplicated for performance
* reasons, like Hive does.
+ *
+ * @param intWrapper If a valid `int` was parsed from this UTF8String, then its value would
+ * be set in `intWrapper`
+ * @return true if the parsing was successful else false
*/
- public int toInt() {
+ public boolean toInt(IntWrapper intWrapper) {
if (numBytes == 0) {
- throw new NumberFormatException("Empty string");
+ return false;
}
byte b = getByte(0);
@@ -958,7 +974,7 @@ public final class UTF8String implements Comparable<UTF8String>, Externalizable,
if (negative || b == '+') {
offset++;
if (numBytes == 1) {
- throw new NumberFormatException(toString());
+ return false;
}
}
@@ -977,20 +993,25 @@ public final class UTF8String implements Comparable<UTF8String>, Externalizable,
break;
}
- int digit = getDigit(b);
+ int digit;
+ if (b >= '0' && b <= '9') {
+ digit = b - '0';
+ } else {
+ return false;
+ }
+
// We are going to process the new digit and accumulate the result. However, before doing
// this, if the result is already smaller than the stopValue(Integer.MIN_VALUE / radix), then
- // result * 10 will definitely be smaller than minValue, and we can stop and throw exception.
+ // result * 10 will definitely be smaller than minValue, and we can stop
if (result < stopValue) {
- throw new NumberFormatException(toString());
+ return false;
}
result = result * radix - digit;
// Since the previous result is less than or equal to stopValue(Integer.MIN_VALUE / radix),
- // we can just use `result > 0` to check overflow. If result overflows, we should stop and
- // throw exception.
+ // we can just use `result > 0` to check overflow. If result overflows, we should stop
if (result > 0) {
- throw new NumberFormatException(toString());
+ return false;
}
}
@@ -998,8 +1019,9 @@ public final class UTF8String implements Comparable<UTF8String>, Externalizable,
// part will not change the number, but we will verify that the fractional part
// is well formed.
while (offset < numBytes) {
- if (getDigit(getByte(offset)) == -1) {
- throw new NumberFormatException(toString());
+ byte currentByte = getByte(offset);
+ if (currentByte < '0' || currentByte > '9') {
+ return false;
}
offset++;
}
@@ -1007,31 +1029,33 @@ public final class UTF8String implements Comparable<UTF8String>, Externalizable,
if (!negative) {
result = -result;
if (result < 0) {
- throw new NumberFormatException(toString());
+ return false;
}
}
-
- return result;
+ intWrapper.value = result;
+ return true;
}
- public short toShort() {
- int intValue = toInt();
- short result = (short) intValue;
- if (result != intValue) {
- throw new NumberFormatException(toString());
+ public boolean toShort(IntWrapper intWrapper) {
+ if (toInt(intWrapper)) {
+ int intValue = intWrapper.value;
+ short result = (short) intValue;
+ if (result == intValue) {
+ return true;
+ }
}
-
- return result;
+ return false;
}
- public byte toByte() {
- int intValue = toInt();
- byte result = (byte) intValue;
- if (result != intValue) {
- throw new NumberFormatException(toString());
+ public boolean toByte(IntWrapper intWrapper) {
+ if (toInt(intWrapper)) {
+ int intValue = intWrapper.value;
+ byte result = (byte) intValue;
+ if (result == intValue) {
+ return true;
+ }
}
-
- return result;
+ return false;
}
@Override
diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java
index 6f6e0ef0e4..c376371abd 100644
--- a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java
+++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java
@@ -22,9 +22,7 @@ import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.charset.StandardCharsets;
-import java.util.Arrays;
-import java.util.HashMap;
-import java.util.HashSet;
+import java.util.*;
import com.google.common.collect.ImmutableMap;
import org.apache.spark.unsafe.Platform;
@@ -608,4 +606,128 @@ public class UTF8StringSuite {
.writeTo(outputStream);
assertEquals("大千世界", outputStream.toString("UTF-8"));
}
+
+ @Test
+ public void testToShort() throws IOException {
+ Map<String, Short> inputToExpectedOutput = new HashMap<>();
+ inputToExpectedOutput.put("1", (short) 1);
+ inputToExpectedOutput.put("+1", (short) 1);
+ inputToExpectedOutput.put("-1", (short) -1);
+ inputToExpectedOutput.put("0", (short) 0);
+ inputToExpectedOutput.put("1111.12345678901234567890", (short) 1111);
+ inputToExpectedOutput.put(String.valueOf(Short.MAX_VALUE), Short.MAX_VALUE);
+ inputToExpectedOutput.put(String.valueOf(Short.MIN_VALUE), Short.MIN_VALUE);
+
+ Random rand = new Random();
+ for (int i = 0; i < 10; i++) {
+ short value = (short) rand.nextInt();
+ inputToExpectedOutput.put(String.valueOf(value), value);
+ }
+
+ IntWrapper wrapper = new IntWrapper();
+ for (Map.Entry<String, Short> entry : inputToExpectedOutput.entrySet()) {
+ assertTrue(entry.getKey(), UTF8String.fromString(entry.getKey()).toShort(wrapper));
+ assertEquals((short) entry.getValue(), wrapper.value);
+ }
+
+ List<String> negativeInputs =
+ Arrays.asList("", " ", "null", "NULL", "\n", "~1212121", "3276700");
+
+ for (String negativeInput : negativeInputs) {
+ assertFalse(negativeInput, UTF8String.fromString(negativeInput).toShort(wrapper));
+ }
+ }
+
+ @Test
+ public void testToByte() throws IOException {
+ Map<String, Byte> inputToExpectedOutput = new HashMap<>();
+ inputToExpectedOutput.put("1", (byte) 1);
+ inputToExpectedOutput.put("+1",(byte) 1);
+ inputToExpectedOutput.put("-1", (byte) -1);
+ inputToExpectedOutput.put("0", (byte) 0);
+ inputToExpectedOutput.put("111.12345678901234567890", (byte) 111);
+ inputToExpectedOutput.put(String.valueOf(Byte.MAX_VALUE), Byte.MAX_VALUE);
+ inputToExpectedOutput.put(String.valueOf(Byte.MIN_VALUE), Byte.MIN_VALUE);
+
+ Random rand = new Random();
+ for (int i = 0; i < 10; i++) {
+ byte value = (byte) rand.nextInt();
+ inputToExpectedOutput.put(String.valueOf(value), value);
+ }
+
+ IntWrapper intWrapper = new IntWrapper();
+ for (Map.Entry<String, Byte> entry : inputToExpectedOutput.entrySet()) {
+ assertTrue(entry.getKey(), UTF8String.fromString(entry.getKey()).toByte(intWrapper));
+ assertEquals((byte) entry.getValue(), intWrapper.value);
+ }
+
+ List<String> negativeInputs =
+ Arrays.asList("", " ", "null", "NULL", "\n", "~1212121", "12345678901234567890");
+
+ for (String negativeInput : negativeInputs) {
+ assertFalse(negativeInput, UTF8String.fromString(negativeInput).toByte(intWrapper));
+ }
+ }
+
+ @Test
+ public void testToInt() throws IOException {
+ Map<String, Integer> inputToExpectedOutput = new HashMap<>();
+ inputToExpectedOutput.put("1", 1);
+ inputToExpectedOutput.put("+1", 1);
+ inputToExpectedOutput.put("-1", -1);
+ inputToExpectedOutput.put("0", 0);
+ inputToExpectedOutput.put("11111.1234567", 11111);
+ inputToExpectedOutput.put(String.valueOf(Integer.MAX_VALUE), Integer.MAX_VALUE);
+ inputToExpectedOutput.put(String.valueOf(Integer.MIN_VALUE), Integer.MIN_VALUE);
+
+ Random rand = new Random();
+ for (int i = 0; i < 10; i++) {
+ int value = rand.nextInt();
+ inputToExpectedOutput.put(String.valueOf(value), value);
+ }
+
+ IntWrapper intWrapper = new IntWrapper();
+ for (Map.Entry<String, Integer> entry : inputToExpectedOutput.entrySet()) {
+ assertTrue(entry.getKey(), UTF8String.fromString(entry.getKey()).toInt(intWrapper));
+ assertEquals((int) entry.getValue(), intWrapper.value);
+ }
+
+ List<String> negativeInputs =
+ Arrays.asList("", " ", "null", "NULL", "\n", "~1212121", "12345678901234567890");
+
+ for (String negativeInput : negativeInputs) {
+ assertFalse(negativeInput, UTF8String.fromString(negativeInput).toInt(intWrapper));
+ }
+ }
+
+ @Test
+ public void testToLong() throws IOException {
+ Map<String, Long> inputToExpectedOutput = new HashMap<>();
+ inputToExpectedOutput.put("1", 1L);
+ inputToExpectedOutput.put("+1", 1L);
+ inputToExpectedOutput.put("-1", -1L);
+ inputToExpectedOutput.put("0", 0L);
+ inputToExpectedOutput.put("1076753423.12345678901234567890", 1076753423L);
+ inputToExpectedOutput.put(String.valueOf(Long.MAX_VALUE), Long.MAX_VALUE);
+ inputToExpectedOutput.put(String.valueOf(Long.MIN_VALUE), Long.MIN_VALUE);
+
+ Random rand = new Random();
+ for (int i = 0; i < 10; i++) {
+ long value = rand.nextLong();
+ inputToExpectedOutput.put(String.valueOf(value), value);
+ }
+
+ LongWrapper wrapper = new LongWrapper();
+ for (Map.Entry<String, Long> entry : inputToExpectedOutput.entrySet()) {
+ assertTrue(entry.getKey(), UTF8String.fromString(entry.getKey()).toLong(wrapper));
+ assertEquals((long) entry.getValue(), wrapper.value);
+ }
+
+ List<String> negativeInputs = Arrays.asList("", " ", "null", "NULL", "\n", "~1212121",
+ "1234567890123456789012345678901234");
+
+ for (String negativeInput : negativeInputs) {
+ assertFalse(negativeInput, UTF8String.fromString(negativeInput).toLong(wrapper));
+ }
+ }
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index a36d3507d9..7c60f7d57a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
-
+import org.apache.spark.unsafe.types.UTF8String.{IntWrapper, LongWrapper}
object Cast {
@@ -277,9 +277,8 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
// LongConverter
private[this] def castToLong(from: DataType): Any => Any = from match {
case StringType =>
- buildCast[UTF8String](_, s => try s.toLong catch {
- case _: NumberFormatException => null
- })
+ val result = new LongWrapper()
+ buildCast[UTF8String](_, s => if (s.toLong(result)) result.value else null)
case BooleanType =>
buildCast[Boolean](_, b => if (b) 1L else 0L)
case DateType =>
@@ -293,9 +292,8 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
// IntConverter
private[this] def castToInt(from: DataType): Any => Any = from match {
case StringType =>
- buildCast[UTF8String](_, s => try s.toInt catch {
- case _: NumberFormatException => null
- })
+ val result = new IntWrapper()
+ buildCast[UTF8String](_, s => if (s.toInt(result)) result.value else null)
case BooleanType =>
buildCast[Boolean](_, b => if (b) 1 else 0)
case DateType =>
@@ -309,8 +307,11 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
// ShortConverter
private[this] def castToShort(from: DataType): Any => Any = from match {
case StringType =>
- buildCast[UTF8String](_, s => try s.toShort catch {
- case _: NumberFormatException => null
+ val result = new IntWrapper()
+ buildCast[UTF8String](_, s => if (s.toShort(result)) {
+ result.value.toShort
+ } else {
+ null
})
case BooleanType =>
buildCast[Boolean](_, b => if (b) 1.toShort else 0.toShort)
@@ -325,8 +326,11 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
// ByteConverter
private[this] def castToByte(from: DataType): Any => Any = from match {
case StringType =>
- buildCast[UTF8String](_, s => try s.toByte catch {
- case _: NumberFormatException => null
+ val result = new IntWrapper()
+ buildCast[UTF8String](_, s => if (s.toByte(result)) {
+ result.value.toByte
+ } else {
+ null
})
case BooleanType =>
buildCast[Boolean](_, b => if (b) 1.toByte else 0.toByte)
@@ -503,11 +507,11 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
case TimestampType => castToTimestampCode(from, ctx)
case CalendarIntervalType => castToIntervalCode(from)
case BooleanType => castToBooleanCode(from)
- case ByteType => castToByteCode(from)
- case ShortType => castToShortCode(from)
- case IntegerType => castToIntCode(from)
+ case ByteType => castToByteCode(from, ctx)
+ case ShortType => castToShortCode(from, ctx)
+ case IntegerType => castToIntCode(from, ctx)
case FloatType => castToFloatCode(from)
- case LongType => castToLongCode(from)
+ case LongType => castToLongCode(from, ctx)
case DoubleType => castToDoubleCode(from)
case array: ArrayType =>
@@ -734,13 +738,16 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
(c, evPrim, evNull) => s"$evPrim = $c != 0;"
}
- private[this] def castToByteCode(from: DataType): CastFunction = from match {
+ private[this] def castToByteCode(from: DataType, ctx: CodegenContext): CastFunction = from match {
case StringType =>
+ val wrapper = ctx.freshName("wrapper")
+ ctx.addMutableState("UTF8String.IntWrapper", wrapper,
+ s"$wrapper = new UTF8String.IntWrapper();")
(c, evPrim, evNull) =>
s"""
- try {
- $evPrim = $c.toByte();
- } catch (java.lang.NumberFormatException e) {
+ if ($c.toByte($wrapper)) {
+ $evPrim = (byte) $wrapper.value;
+ } else {
$evNull = true;
}
"""
@@ -756,13 +763,18 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
(c, evPrim, evNull) => s"$evPrim = (byte) $c;"
}
- private[this] def castToShortCode(from: DataType): CastFunction = from match {
+ private[this] def castToShortCode(
+ from: DataType,
+ ctx: CodegenContext): CastFunction = from match {
case StringType =>
+ val wrapper = ctx.freshName("wrapper")
+ ctx.addMutableState("UTF8String.IntWrapper", wrapper,
+ s"$wrapper = new UTF8String.IntWrapper();")
(c, evPrim, evNull) =>
s"""
- try {
- $evPrim = $c.toShort();
- } catch (java.lang.NumberFormatException e) {
+ if ($c.toShort($wrapper)) {
+ $evPrim = (short) $wrapper.value;
+ } else {
$evNull = true;
}
"""
@@ -778,13 +790,16 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
(c, evPrim, evNull) => s"$evPrim = (short) $c;"
}
- private[this] def castToIntCode(from: DataType): CastFunction = from match {
+ private[this] def castToIntCode(from: DataType, ctx: CodegenContext): CastFunction = from match {
case StringType =>
+ val wrapper = ctx.freshName("wrapper")
+ ctx.addMutableState("UTF8String.IntWrapper", wrapper,
+ s"$wrapper = new UTF8String.IntWrapper();")
(c, evPrim, evNull) =>
s"""
- try {
- $evPrim = $c.toInt();
- } catch (java.lang.NumberFormatException e) {
+ if ($c.toInt($wrapper)) {
+ $evPrim = $wrapper.value;
+ } else {
$evNull = true;
}
"""
@@ -800,13 +815,17 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
(c, evPrim, evNull) => s"$evPrim = (int) $c;"
}
- private[this] def castToLongCode(from: DataType): CastFunction = from match {
+ private[this] def castToLongCode(from: DataType, ctx: CodegenContext): CastFunction = from match {
case StringType =>
+ val wrapper = ctx.freshName("wrapper")
+ ctx.addMutableState("UTF8String.LongWrapper", wrapper,
+ s"$wrapper = new UTF8String.LongWrapper();")
+
(c, evPrim, evNull) =>
s"""
- try {
- $evPrim = $c.toLong();
- } catch (java.lang.NumberFormatException e) {
+ if ($c.toLong($wrapper)) {
+ $evPrim = $wrapper.value;
+ } else {
$evNull = true;
}
"""