about | summary | refs | log | tree | commit | diff
path: root/sql
diff options
context:
space:
mode:
author Reynold Xin <rxin@apache.org> 2014-06-20 00:01:19 -0700
committer Reynold Xin <rxin@apache.org> 2014-06-20 00:01:19 -0700
commit c55bbb49f7ec653f0ff635015d3bc789ca26c4eb (patch)
tree a47a29f9075c03e75ceb0ccd8d26775b14126084 /sql
parent 61756409736a64bd42577782cb7468557fa0b642 (diff)
download spark-c55bbb49f7ec653f0ff635015d3bc789ca26c4eb.tar.gz
spark-c55bbb49f7ec653f0ff635015d3bc789ca26c4eb.tar.bz2
spark-c55bbb49f7ec653f0ff635015d3bc789ca26c4eb.zip
[SPARK-2209][SQL] Cast shouldn't do null check twice.
Also took the chance to clean up cast a little bit. Too many arrows on each line before!

Author: Reynold Xin <rxin@apache.org>

Closes #1143 from rxin/cast and squashes the following commits:

dd006cb [Reynold Xin] Code review feedback.
c2b88ae [Reynold Xin] [SPARK-2209][SQL] Cast shouldn't do null check twice.
Diffstat (limited to 'sql')
-rw-r--r-- sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala 274
1 files changed, 159 insertions, 115 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index 0b3a4e728e..1f9716e385 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -24,72 +24,87 @@ import org.apache.spark.sql.catalyst.types._
/** Cast the child expression to the target data type. */
case class Cast(child: Expression, dataType: DataType) extends UnaryExpression {
override def foldable = child.foldable
- def nullable = (child.dataType, dataType) match {
+
+ override def nullable = (child.dataType, dataType) match {
case (StringType, _: NumericType) => true
case (StringType, TimestampType) => true
case _ => child.nullable
}
+
override def toString = s"CAST($child, $dataType)"
type EvaluatedType = Any
- def nullOrCast[T](a: Any, func: T => Any): Any = if(a == null) {
- null
- } else {
- func(a.asInstanceOf[T])
- }
+ // [[func]] assumes the input is no longer null because eval already does the null check.
+ @inline private[this] def buildCast[T](a: Any, func: T => Any): Any = func(a.asInstanceOf[T])
// UDFToString
- def castToString: Any => Any = child.dataType match {
- case BinaryType => nullOrCast[Array[Byte]](_, new String(_, "UTF-8"))
- case _ => nullOrCast[Any](_, _.toString)
+ private[this] def castToString: Any => Any = child.dataType match {
+ case BinaryType => buildCast[Array[Byte]](_, new String(_, "UTF-8"))
+ case _ => buildCast[Any](_, _.toString)
}
// BinaryConverter
- def castToBinary: Any => Any = child.dataType match {
- case StringType => nullOrCast[String](_, _.getBytes("UTF-8"))
+ private[this] def castToBinary: Any => Any = child.dataType match {
+ case StringType => buildCast[String](_, _.getBytes("UTF-8"))
}
// UDFToBoolean
- def castToBoolean: Any => Any = child.dataType match {
- case StringType => nullOrCast[String](_, _.length() != 0)
- case TimestampType => nullOrCast[Timestamp](_, b => {(b.getTime() != 0 || b.getNanos() != 0)})
- case LongType => nullOrCast[Long](_, _ != 0)
- case IntegerType => nullOrCast[Int](_, _ != 0)
- case ShortType => nullOrCast[Short](_, _ != 0)
- case ByteType => nullOrCast[Byte](_, _ != 0)
- case DecimalType => nullOrCast[BigDecimal](_, _ != 0)
- case DoubleType => nullOrCast[Double](_, _ != 0)
- case FloatType => nullOrCast[Float](_, _ != 0)
+ private[this] def castToBoolean: Any => Any = child.dataType match {
+ case StringType =>
+ buildCast[String](_, _.length() != 0)
+ case TimestampType =>
+ buildCast[Timestamp](_, b => b.getTime() != 0 || b.getNanos() != 0)
+ case LongType =>
+ buildCast[Long](_, _ != 0)
+ case IntegerType =>
+ buildCast[Int](_, _ != 0)
+ case ShortType =>
+ buildCast[Short](_, _ != 0)
+ case ByteType =>
+ buildCast[Byte](_, _ != 0)
+ case DecimalType =>
+ buildCast[BigDecimal](_, _ != 0)
+ case DoubleType =>
+ buildCast[Double](_, _ != 0)
+ case FloatType =>
+ buildCast[Float](_, _ != 0)
}
// TimestampConverter
- def castToTimestamp: Any => Any = child.dataType match {
- case StringType => nullOrCast[String](_, s => {
- // Throw away extra if more than 9 decimal places
- val periodIdx = s.indexOf(".");
- var n = s
- if (periodIdx != -1) {
- if (n.length() - periodIdx > 9) {
+ private[this] def castToTimestamp: Any => Any = child.dataType match {
+ case StringType =>
+ buildCast[String](_, s => {
+ // Throw away extra if more than 9 decimal places
+ val periodIdx = s.indexOf(".")
+ var n = s
+ if (periodIdx != -1 && n.length() - periodIdx > 9) {
n = n.substring(0, periodIdx + 10)
}
- }
- try Timestamp.valueOf(n) catch { case _: java.lang.IllegalArgumentException => null}
- })
- case BooleanType => nullOrCast[Boolean](_, b => new Timestamp((if(b) 1 else 0) * 1000))
- case LongType => nullOrCast[Long](_, l => new Timestamp(l * 1000))
- case IntegerType => nullOrCast[Int](_, i => new Timestamp(i * 1000))
- case ShortType => nullOrCast[Short](_, s => new Timestamp(s * 1000))
- case ByteType => nullOrCast[Byte](_, b => new Timestamp(b * 1000))
+ try Timestamp.valueOf(n) catch { case _: java.lang.IllegalArgumentException => null }
+ })
+ case BooleanType =>
+ buildCast[Boolean](_, b => new Timestamp((if (b) 1 else 0) * 1000))
+ case LongType =>
+ buildCast[Long](_, l => new Timestamp(l * 1000))
+ case IntegerType =>
+ buildCast[Int](_, i => new Timestamp(i * 1000))
+ case ShortType =>
+ buildCast[Short](_, s => new Timestamp(s * 1000))
+ case ByteType =>
+ buildCast[Byte](_, b => new Timestamp(b * 1000))
// TimestampWritable.decimalToTimestamp
- case DecimalType => nullOrCast[BigDecimal](_, d => decimalToTimestamp(d))
+ case DecimalType =>
+ buildCast[BigDecimal](_, d => decimalToTimestamp(d))
// TimestampWritable.doubleToTimestamp
- case DoubleType => nullOrCast[Double](_, d => decimalToTimestamp(d))
+ case DoubleType =>
+ buildCast[Double](_, d => decimalToTimestamp(d))
// TimestampWritable.floatToTimestamp
- case FloatType => nullOrCast[Float](_, f => decimalToTimestamp(f))
+ case FloatType =>
+ buildCast[Float](_, f => decimalToTimestamp(f))
}
- private def decimalToTimestamp(d: BigDecimal) = {
+ private[this] def decimalToTimestamp(d: BigDecimal) = {
val seconds = d.longValue()
val bd = (d - seconds) * 1000000000
val nanos = bd.intValue()
@@ -104,85 +119,118 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression {
}
// Timestamp to long, converting milliseconds to seconds
- private def timestampToLong(ts: Timestamp) = ts.getTime / 1000
+ private[this] def timestampToLong(ts: Timestamp) = ts.getTime / 1000
- private def timestampToDouble(ts: Timestamp) = {
+ private[this] def timestampToDouble(ts: Timestamp) = {
// First part is the seconds since the beginning of time, followed by nanosecs.
ts.getTime / 1000 + ts.getNanos.toDouble / 1000000000
}
- def castToLong: Any => Any = child.dataType match {
- case StringType => nullOrCast[String](_, s => try s.toLong catch {
- case _: NumberFormatException => null
- })
- case BooleanType => nullOrCast[Boolean](_, b => if(b) 1L else 0L)
- case TimestampType => nullOrCast[Timestamp](_, t => timestampToLong(t))
- case DecimalType => nullOrCast[BigDecimal](_, _.toLong)
- case x: NumericType => b => x.numeric.asInstanceOf[Numeric[Any]].toLong(b)
- }
-
- def castToInt: Any => Any = child.dataType match {
- case StringType => nullOrCast[String](_, s => try s.toInt catch {
- case _: NumberFormatException => null
- })
- case BooleanType => nullOrCast[Boolean](_, b => if(b) 1 else 0)
- case TimestampType => nullOrCast[Timestamp](_, t => timestampToLong(t).toInt)
- case DecimalType => nullOrCast[BigDecimal](_, _.toInt)
- case x: NumericType => b => x.numeric.asInstanceOf[Numeric[Any]].toInt(b)
- }
-
- def castToShort: Any => Any = child.dataType match {
- case StringType => nullOrCast[String](_, s => try s.toShort catch {
- case _: NumberFormatException => null
- })
- case BooleanType => nullOrCast[Boolean](_, b => if(b) 1.toShort else 0.toShort)
- case TimestampType => nullOrCast[Timestamp](_, t => timestampToLong(t).toShort)
- case DecimalType => nullOrCast[BigDecimal](_, _.toShort)
- case x: NumericType => b => x.numeric.asInstanceOf[Numeric[Any]].toInt(b).toShort
- }
-
- def castToByte: Any => Any = child.dataType match {
- case StringType => nullOrCast[String](_, s => try s.toByte catch {
- case _: NumberFormatException => null
- })
- case BooleanType => nullOrCast[Boolean](_, b => if(b) 1.toByte else 0.toByte)
- case TimestampType => nullOrCast[Timestamp](_, t => timestampToLong(t).toByte)
- case DecimalType => nullOrCast[BigDecimal](_, _.toByte)
- case x: NumericType => b => x.numeric.asInstanceOf[Numeric[Any]].toInt(b).toByte
- }
-
- def castToDecimal: Any => Any = child.dataType match {
- case StringType => nullOrCast[String](_, s => try BigDecimal(s.toDouble) catch {
- case _: NumberFormatException => null
- })
- case BooleanType => nullOrCast[Boolean](_, b => if(b) BigDecimal(1) else BigDecimal(0))
+ private[this] def castToLong: Any => Any = child.dataType match {
+ case StringType =>
+ buildCast[String](_, s => try s.toLong catch {
+ case _: NumberFormatException => null
+ })
+ case BooleanType =>
+ buildCast[Boolean](_, b => if (b) 1L else 0L)
+ case TimestampType =>
+ buildCast[Timestamp](_, t => timestampToLong(t))
+ case DecimalType =>
+ buildCast[BigDecimal](_, _.toLong)
+ case x: NumericType =>
+ b => x.numeric.asInstanceOf[Numeric[Any]].toLong(b)
+ }
+
+ private[this] def castToInt: Any => Any = child.dataType match {
+ case StringType =>
+ buildCast[String](_, s => try s.toInt catch {
+ case _: NumberFormatException => null
+ })
+ case BooleanType =>
+ buildCast[Boolean](_, b => if (b) 1 else 0)
+ case TimestampType =>
+ buildCast[Timestamp](_, t => timestampToLong(t).toInt)
+ case DecimalType =>
+ buildCast[BigDecimal](_, _.toInt)
+ case x: NumericType =>
+ b => x.numeric.asInstanceOf[Numeric[Any]].toInt(b)
+ }
+
+ private[this] def castToShort: Any => Any = child.dataType match {
+ case StringType =>
+ buildCast[String](_, s => try s.toShort catch {
+ case _: NumberFormatException => null
+ })
+ case BooleanType =>
+ buildCast[Boolean](_, b => if (b) 1.toShort else 0.toShort)
+ case TimestampType =>
+ buildCast[Timestamp](_, t => timestampToLong(t).toShort)
+ case DecimalType =>
+ buildCast[BigDecimal](_, _.toShort)
+ case x: NumericType =>
+ b => x.numeric.asInstanceOf[Numeric[Any]].toInt(b).toShort
+ }
+
+ private[this] def castToByte: Any => Any = child.dataType match {
+ case StringType =>
+ buildCast[String](_, s => try s.toByte catch {
+ case _: NumberFormatException => null
+ })
+ case BooleanType =>
+ buildCast[Boolean](_, b => if (b) 1.toByte else 0.toByte)
+ case TimestampType =>
+ buildCast[Timestamp](_, t => timestampToLong(t).toByte)
+ case DecimalType =>
+ buildCast[BigDecimal](_, _.toByte)
+ case x: NumericType =>
+ b => x.numeric.asInstanceOf[Numeric[Any]].toInt(b).toByte
+ }
+
+ private[this] def castToDecimal: Any => Any = child.dataType match {
+ case StringType =>
+ buildCast[String](_, s => try BigDecimal(s.toDouble) catch {
+ case _: NumberFormatException => null
+ })
+ case BooleanType =>
+ buildCast[Boolean](_, b => if (b) BigDecimal(1) else BigDecimal(0))
case TimestampType =>
// Note that we lose precision here.
- nullOrCast[Timestamp](_, t => BigDecimal(timestampToDouble(t)))
- case x: NumericType => b => BigDecimal(x.numeric.asInstanceOf[Numeric[Any]].toDouble(b))
- }
-
- def castToDouble: Any => Any = child.dataType match {
- case StringType => nullOrCast[String](_, s => try s.toDouble catch {
- case _: NumberFormatException => null
- })
- case BooleanType => nullOrCast[Boolean](_, b => if(b) 1d else 0d)
- case TimestampType => nullOrCast[Timestamp](_, t => timestampToDouble(t))
- case DecimalType => nullOrCast[BigDecimal](_, _.toDouble)
- case x: NumericType => b => x.numeric.asInstanceOf[Numeric[Any]].toDouble(b)
- }
-
- def castToFloat: Any => Any = child.dataType match {
- case StringType => nullOrCast[String](_, s => try s.toFloat catch {
- case _: NumberFormatException => null
- })
- case BooleanType => nullOrCast[Boolean](_, b => if(b) 1f else 0f)
- case TimestampType => nullOrCast[Timestamp](_, t => timestampToDouble(t).toFloat)
- case DecimalType => nullOrCast[BigDecimal](_, _.toFloat)
- case x: NumericType => b => x.numeric.asInstanceOf[Numeric[Any]].toFloat(b)
+ buildCast[Timestamp](_, t => BigDecimal(timestampToDouble(t)))
+ case x: NumericType =>
+ b => BigDecimal(x.numeric.asInstanceOf[Numeric[Any]].toDouble(b))
+ }
+
+ private[this] def castToDouble: Any => Any = child.dataType match {
+ case StringType =>
+ buildCast[String](_, s => try s.toDouble catch {
+ case _: NumberFormatException => null
+ })
+ case BooleanType =>
+ buildCast[Boolean](_, b => if (b) 1d else 0d)
+ case TimestampType =>
+ buildCast[Timestamp](_, t => timestampToDouble(t))
+ case DecimalType =>
+ buildCast[BigDecimal](_, _.toDouble)
+ case x: NumericType =>
+ b => x.numeric.asInstanceOf[Numeric[Any]].toDouble(b)
+ }
+
+ private[this] def castToFloat: Any => Any = child.dataType match {
+ case StringType =>
+ buildCast[String](_, s => try s.toFloat catch {
+ case _: NumberFormatException => null
+ })
+ case BooleanType =>
+ buildCast[Boolean](_, b => if (b) 1f else 0f)
+ case TimestampType =>
+ buildCast[Timestamp](_, t => timestampToDouble(t).toFloat)
+ case DecimalType =>
+ buildCast[BigDecimal](_, _.toFloat)
+ case x: NumericType =>
+ b => x.numeric.asInstanceOf[Numeric[Any]].toFloat(b)
}
- private lazy val cast: Any => Any = dataType match {
+ private[this] lazy val cast: Any => Any = dataType match {
case StringType => castToString
case BinaryType => castToBinary
case DecimalType => castToDecimal
@@ -198,10 +246,6 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression {
override def eval(input: Row): Any = {
val evaluated = child.eval(input)
- if (evaluated == null) {
- null
- } else {
- cast(evaluated)
- }
+ if (evaluated == null) null else cast(evaluated)
}
}