aboutsummaryrefslogtreecommitdiff
path: root/sql/catalyst
diff options
context:
space:
mode:
authorTejas Patil <tejasp@fb.com>2017-03-07 20:19:30 -0800
committerWenchen Fan <wenchen@databricks.com>2017-03-07 20:19:30 -0800
commitc96d14abae5962a7b15239319c2a151b95f7db94 (patch)
treeeefd1b2e220a4a7afc901e29418f3f5ee92f21d1 /sql/catalyst
parent47b2f68a885b7a2fc593ac7a55cd19742016364d (diff)
downloadspark-c96d14abae5962a7b15239319c2a151b95f7db94.tar.gz
spark-c96d14abae5962a7b15239319c2a151b95f7db94.tar.bz2
spark-c96d14abae5962a7b15239319c2a151b95f7db94.zip
[SPARK-19843][SQL] UTF8String => (int / long) conversion expensive for invalid inputs
## What changes were proposed in this pull request? Jira : https://issues.apache.org/jira/browse/SPARK-19843 Created wrapper classes (`IntWrapper`, `LongWrapper`) to wrap the result of parsing (which are primitive types). In case of problem in parsing, the method would return a boolean. ## How was this patch tested? - Added new unit tests - Ran a prod job which had conversion from string -> int and verified the outputs ## Performance Tiny regression when all strings are valid integers ``` conversion to int: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative -------------------------------------------------------------------------------- trunk 502 / 522 33.4 29.9 1.0X SPARK-19843 493 / 503 34.0 29.4 1.0X ``` Huge gain when all strings are invalid integers ``` conversion to int: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------- trunk 33913 / 34219 0.5 2021.4 1.0X SPARK-19843 154 / 162 108.8 9.2 220.0X ``` Author: Tejas Patil <tejasp@fb.com> Closes #17184 from tejasapatil/SPARK-19843_is_numeric_maybe.
Diffstat (limited to 'sql/catalyst')
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala81
1 files changed, 50 insertions, 31 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index a36d3507d9..7c60f7d57a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
-
+import org.apache.spark.unsafe.types.UTF8String.{IntWrapper, LongWrapper}
object Cast {
@@ -277,9 +277,8 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
// LongConverter
private[this] def castToLong(from: DataType): Any => Any = from match {
case StringType =>
- buildCast[UTF8String](_, s => try s.toLong catch {
- case _: NumberFormatException => null
- })
+ val result = new LongWrapper()
+ buildCast[UTF8String](_, s => if (s.toLong(result)) result.value else null)
case BooleanType =>
buildCast[Boolean](_, b => if (b) 1L else 0L)
case DateType =>
@@ -293,9 +292,8 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
// IntConverter
private[this] def castToInt(from: DataType): Any => Any = from match {
case StringType =>
- buildCast[UTF8String](_, s => try s.toInt catch {
- case _: NumberFormatException => null
- })
+ val result = new IntWrapper()
+ buildCast[UTF8String](_, s => if (s.toInt(result)) result.value else null)
case BooleanType =>
buildCast[Boolean](_, b => if (b) 1 else 0)
case DateType =>
@@ -309,8 +307,11 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
// ShortConverter
private[this] def castToShort(from: DataType): Any => Any = from match {
case StringType =>
- buildCast[UTF8String](_, s => try s.toShort catch {
- case _: NumberFormatException => null
+ val result = new IntWrapper()
+ buildCast[UTF8String](_, s => if (s.toShort(result)) {
+ result.value.toShort
+ } else {
+ null
})
case BooleanType =>
buildCast[Boolean](_, b => if (b) 1.toShort else 0.toShort)
@@ -325,8 +326,11 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
// ByteConverter
private[this] def castToByte(from: DataType): Any => Any = from match {
case StringType =>
- buildCast[UTF8String](_, s => try s.toByte catch {
- case _: NumberFormatException => null
+ val result = new IntWrapper()
+ buildCast[UTF8String](_, s => if (s.toByte(result)) {
+ result.value.toByte
+ } else {
+ null
})
case BooleanType =>
buildCast[Boolean](_, b => if (b) 1.toByte else 0.toByte)
@@ -503,11 +507,11 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
case TimestampType => castToTimestampCode(from, ctx)
case CalendarIntervalType => castToIntervalCode(from)
case BooleanType => castToBooleanCode(from)
- case ByteType => castToByteCode(from)
- case ShortType => castToShortCode(from)
- case IntegerType => castToIntCode(from)
+ case ByteType => castToByteCode(from, ctx)
+ case ShortType => castToShortCode(from, ctx)
+ case IntegerType => castToIntCode(from, ctx)
case FloatType => castToFloatCode(from)
- case LongType => castToLongCode(from)
+ case LongType => castToLongCode(from, ctx)
case DoubleType => castToDoubleCode(from)
case array: ArrayType =>
@@ -734,13 +738,16 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
(c, evPrim, evNull) => s"$evPrim = $c != 0;"
}
- private[this] def castToByteCode(from: DataType): CastFunction = from match {
+ private[this] def castToByteCode(from: DataType, ctx: CodegenContext): CastFunction = from match {
case StringType =>
+ val wrapper = ctx.freshName("wrapper")
+ ctx.addMutableState("UTF8String.IntWrapper", wrapper,
+ s"$wrapper = new UTF8String.IntWrapper();")
(c, evPrim, evNull) =>
s"""
- try {
- $evPrim = $c.toByte();
- } catch (java.lang.NumberFormatException e) {
+ if ($c.toByte($wrapper)) {
+ $evPrim = (byte) $wrapper.value;
+ } else {
$evNull = true;
}
"""
@@ -756,13 +763,18 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
(c, evPrim, evNull) => s"$evPrim = (byte) $c;"
}
- private[this] def castToShortCode(from: DataType): CastFunction = from match {
+ private[this] def castToShortCode(
+ from: DataType,
+ ctx: CodegenContext): CastFunction = from match {
case StringType =>
+ val wrapper = ctx.freshName("wrapper")
+ ctx.addMutableState("UTF8String.IntWrapper", wrapper,
+ s"$wrapper = new UTF8String.IntWrapper();")
(c, evPrim, evNull) =>
s"""
- try {
- $evPrim = $c.toShort();
- } catch (java.lang.NumberFormatException e) {
+ if ($c.toShort($wrapper)) {
+ $evPrim = (short) $wrapper.value;
+ } else {
$evNull = true;
}
"""
@@ -778,13 +790,16 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
(c, evPrim, evNull) => s"$evPrim = (short) $c;"
}
- private[this] def castToIntCode(from: DataType): CastFunction = from match {
+ private[this] def castToIntCode(from: DataType, ctx: CodegenContext): CastFunction = from match {
case StringType =>
+ val wrapper = ctx.freshName("wrapper")
+ ctx.addMutableState("UTF8String.IntWrapper", wrapper,
+ s"$wrapper = new UTF8String.IntWrapper();")
(c, evPrim, evNull) =>
s"""
- try {
- $evPrim = $c.toInt();
- } catch (java.lang.NumberFormatException e) {
+ if ($c.toInt($wrapper)) {
+ $evPrim = $wrapper.value;
+ } else {
$evNull = true;
}
"""
@@ -800,13 +815,17 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
(c, evPrim, evNull) => s"$evPrim = (int) $c;"
}
- private[this] def castToLongCode(from: DataType): CastFunction = from match {
+ private[this] def castToLongCode(from: DataType, ctx: CodegenContext): CastFunction = from match {
case StringType =>
+ val wrapper = ctx.freshName("wrapper")
+ ctx.addMutableState("UTF8String.LongWrapper", wrapper,
+ s"$wrapper = new UTF8String.LongWrapper();")
+
(c, evPrim, evNull) =>
s"""
- try {
- $evPrim = $c.toLong();
- } catch (java.lang.NumberFormatException e) {
+ if ($c.toLong($wrapper)) {
+ $evPrim = $wrapper.value;
+ } else {
$evNull = true;
}
"""