diff options
author | Reynold Xin <rxin@apache.org> | 2014-06-20 00:01:19 -0700 |
---|---|---|
committer | Reynold Xin <rxin@apache.org> | 2014-06-20 00:01:19 -0700 |
commit | c55bbb49f7ec653f0ff635015d3bc789ca26c4eb (patch) | |
tree | a47a29f9075c03e75ceb0ccd8d26775b14126084 | |
parent | 61756409736a64bd42577782cb7468557fa0b642 (diff) | |
download | spark-c55bbb49f7ec653f0ff635015d3bc789ca26c4eb.tar.gz spark-c55bbb49f7ec653f0ff635015d3bc789ca26c4eb.tar.bz2 spark-c55bbb49f7ec653f0ff635015d3bc789ca26c4eb.zip |
[SPARK-2209][SQL] Cast shouldn't do null check twice.
Also took the chance to clean up cast a little bit. Too many arrows on each line before!
Author: Reynold Xin <rxin@apache.org>
Closes #1143 from rxin/cast and squashes the following commits:
dd006cb [Reynold Xin] Code review feedback.
c2b88ae [Reynold Xin] [SPARK-2209][SQL] Cast shouldn't do null check twice.
-rw-r--r-- | sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala | 274 |
1 files changed, 159 insertions, 115 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 0b3a4e728e..1f9716e385 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -24,72 +24,87 @@ import org.apache.spark.sql.catalyst.types._ /** Cast the child expression to the target data type. */ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression { override def foldable = child.foldable - def nullable = (child.dataType, dataType) match { + + override def nullable = (child.dataType, dataType) match { case (StringType, _: NumericType) => true case (StringType, TimestampType) => true case _ => child.nullable } + override def toString = s"CAST($child, $dataType)" type EvaluatedType = Any - def nullOrCast[T](a: Any, func: T => Any): Any = if(a == null) { - null - } else { - func(a.asInstanceOf[T]) - } + // [[func]] assumes the input is no longer null because eval already does the null check. + @inline private[this] def buildCast[T](a: Any, func: T => Any): Any = func(a.asInstanceOf[T]) // UDFToString - def castToString: Any => Any = child.dataType match { - case BinaryType => nullOrCast[Array[Byte]](_, new String(_, "UTF-8")) - case _ => nullOrCast[Any](_, _.toString) + private[this] def castToString: Any => Any = child.dataType match { + case BinaryType => buildCast[Array[Byte]](_, new String(_, "UTF-8")) + case _ => buildCast[Any](_, _.toString) } // BinaryConverter - def castToBinary: Any => Any = child.dataType match { - case StringType => nullOrCast[String](_, _.getBytes("UTF-8")) + private[this] def castToBinary: Any => Any = child.dataType match { + case StringType => buildCast[String](_, _.getBytes("UTF-8")) } // UDFToBoolean - def castToBoolean: Any => Any = child.dataType match { - case StringType => nullOrCast[String](_, _.length() != 0) - case TimestampType => nullOrCast[Timestamp](_, b => {(b.getTime() != 0 || b.getNanos() != 0)}) - case LongType => nullOrCast[Long](_, _ != 0) - case IntegerType => nullOrCast[Int](_, _ != 0) - case ShortType => nullOrCast[Short](_, _ != 0) - case ByteType => nullOrCast[Byte](_, _ != 0) - case DecimalType => nullOrCast[BigDecimal](_, _ != 0) - case DoubleType => nullOrCast[Double](_, _ != 0) - case FloatType => nullOrCast[Float](_, _ != 0) + private[this] def castToBoolean: Any => Any = child.dataType match { + case StringType => + buildCast[String](_, _.length() != 0) + case TimestampType => + buildCast[Timestamp](_, b => b.getTime() != 0 || b.getNanos() != 0) + case LongType => + buildCast[Long](_, _ != 0) + case IntegerType => + buildCast[Int](_, _ != 0) + case ShortType => + buildCast[Short](_, _ != 0) + case ByteType => + buildCast[Byte](_, _ != 0) + case DecimalType => + buildCast[BigDecimal](_, _ != 0) + case DoubleType => + buildCast[Double](_, _ != 0) + case FloatType => + buildCast[Float](_, _ != 0) } // TimestampConverter - def castToTimestamp: Any => Any = child.dataType match { - case StringType => nullOrCast[String](_, s => { - // Throw away extra if more than 9 decimal places - val periodIdx = s.indexOf("."); - var n = s - if (periodIdx != -1) { - if (n.length() - periodIdx > 9) { + private[this] def castToTimestamp: Any => Any = child.dataType match { + case StringType => + buildCast[String](_, s => { + // Throw away extra if more than 9 decimal places + val periodIdx = s.indexOf(".") + var n = s + if (periodIdx != -1 && n.length() - periodIdx > 9) { n = n.substring(0, periodIdx + 10) } - } - try Timestamp.valueOf(n) catch { case _: java.lang.IllegalArgumentException => null} - }) - case BooleanType => nullOrCast[Boolean](_, b => new Timestamp((if(b) 1 else 0) * 1000)) - case LongType => nullOrCast[Long](_, l => new Timestamp(l * 1000)) - case IntegerType => nullOrCast[Int](_, i => new Timestamp(i * 1000)) - case ShortType => nullOrCast[Short](_, s => new Timestamp(s * 1000)) - case ByteType => nullOrCast[Byte](_, b => new Timestamp(b * 1000)) + try Timestamp.valueOf(n) catch { case _: java.lang.IllegalArgumentException => null } + }) + case BooleanType => + buildCast[Boolean](_, b => new Timestamp((if (b) 1 else 0) * 1000)) + case LongType => + buildCast[Long](_, l => new Timestamp(l * 1000)) + case IntegerType => + buildCast[Int](_, i => new Timestamp(i * 1000)) + case ShortType => + buildCast[Short](_, s => new Timestamp(s * 1000)) + case ByteType => + buildCast[Byte](_, b => new Timestamp(b * 1000)) // TimestampWritable.decimalToTimestamp - case DecimalType => nullOrCast[BigDecimal](_, d => decimalToTimestamp(d)) + case DecimalType => + buildCast[BigDecimal](_, d => decimalToTimestamp(d)) // TimestampWritable.doubleToTimestamp - case DoubleType => nullOrCast[Double](_, d => decimalToTimestamp(d)) + case DoubleType => + buildCast[Double](_, d => decimalToTimestamp(d)) // TimestampWritable.floatToTimestamp - case FloatType => nullOrCast[Float](_, f => decimalToTimestamp(f)) + case FloatType => + buildCast[Float](_, f => decimalToTimestamp(f)) } - private def decimalToTimestamp(d: BigDecimal) = { + private[this] def decimalToTimestamp(d: BigDecimal) = { val seconds = d.longValue() val bd = (d - seconds) * 1000000000 val nanos = bd.intValue() @@ -104,85 +119,118 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression { } // Timestamp to long, converting milliseconds to seconds - private def timestampToLong(ts: Timestamp) = ts.getTime / 1000 + private[this] def timestampToLong(ts: Timestamp) = ts.getTime / 1000 - private def timestampToDouble(ts: Timestamp) = { + private[this] def timestampToDouble(ts: Timestamp) = { // First part is the seconds since the beginning of time, followed by nanosecs. ts.getTime / 1000 + ts.getNanos.toDouble / 1000000000 } - def castToLong: Any => Any = child.dataType match { - case StringType => nullOrCast[String](_, s => try s.toLong catch { - case _: NumberFormatException => null - }) - case BooleanType => nullOrCast[Boolean](_, b => if(b) 1L else 0L) - case TimestampType => nullOrCast[Timestamp](_, t => timestampToLong(t)) - case DecimalType => nullOrCast[BigDecimal](_, _.toLong) - case x: NumericType => b => x.numeric.asInstanceOf[Numeric[Any]].toLong(b) - } - - def castToInt: Any => Any = child.dataType match { - case StringType => nullOrCast[String](_, s => try s.toInt catch { - case _: NumberFormatException => null - }) - case BooleanType => nullOrCast[Boolean](_, b => if(b) 1 else 0) - case TimestampType => nullOrCast[Timestamp](_, t => timestampToLong(t).toInt) - case DecimalType => nullOrCast[BigDecimal](_, _.toInt) - case x: NumericType => b => x.numeric.asInstanceOf[Numeric[Any]].toInt(b) - } - - def castToShort: Any => Any = child.dataType match { - case StringType => nullOrCast[String](_, s => try s.toShort catch { - case _: NumberFormatException => null - }) - case BooleanType => nullOrCast[Boolean](_, b => if(b) 1.toShort else 0.toShort) - case TimestampType => nullOrCast[Timestamp](_, t => timestampToLong(t).toShort) - case DecimalType => nullOrCast[BigDecimal](_, _.toShort) - case x: NumericType => b => x.numeric.asInstanceOf[Numeric[Any]].toInt(b).toShort - } - - def castToByte: Any => Any = child.dataType match { - case StringType => nullOrCast[String](_, s => try s.toByte catch { - case _: NumberFormatException => null - }) - case BooleanType => nullOrCast[Boolean](_, b => if(b) 1.toByte else 0.toByte) - case TimestampType => nullOrCast[Timestamp](_, t => timestampToLong(t).toByte) - case DecimalType => nullOrCast[BigDecimal](_, _.toByte) - case x: NumericType => b => x.numeric.asInstanceOf[Numeric[Any]].toInt(b).toByte - } - - def castToDecimal: Any => Any = child.dataType match { - case StringType => nullOrCast[String](_, s => try BigDecimal(s.toDouble) catch { - case _: NumberFormatException => null - }) - case BooleanType => nullOrCast[Boolean](_, b => if(b) BigDecimal(1) else BigDecimal(0)) + private[this] def castToLong: Any => Any = child.dataType match { + case StringType => + buildCast[String](_, s => try s.toLong catch { + case _: NumberFormatException => null + }) + case BooleanType => + buildCast[Boolean](_, b => if (b) 1L else 0L) + case TimestampType => + buildCast[Timestamp](_, t => timestampToLong(t)) + case DecimalType => + buildCast[BigDecimal](_, _.toLong) + case x: NumericType => + b => x.numeric.asInstanceOf[Numeric[Any]].toLong(b) + } + + private[this] def castToInt: Any => Any = child.dataType match { + case StringType => + buildCast[String](_, s => try s.toInt catch { + case _: NumberFormatException => null + }) + case BooleanType => + buildCast[Boolean](_, b => if (b) 1 else 0) + case TimestampType => + buildCast[Timestamp](_, t => timestampToLong(t).toInt) + case DecimalType => + buildCast[BigDecimal](_, _.toInt) + case x: NumericType => + b => x.numeric.asInstanceOf[Numeric[Any]].toInt(b) + } + + private[this] def castToShort: Any => Any = child.dataType match { + case StringType => + buildCast[String](_, s => try s.toShort catch { + case _: NumberFormatException => null + }) + case BooleanType => + buildCast[Boolean](_, b => if (b) 1.toShort else 0.toShort) + case TimestampType => + buildCast[Timestamp](_, t => timestampToLong(t).toShort) + case DecimalType => + buildCast[BigDecimal](_, _.toShort) + case x: NumericType => + b => x.numeric.asInstanceOf[Numeric[Any]].toInt(b).toShort + } + + private[this] def castToByte: Any => Any = child.dataType match { + case StringType => + buildCast[String](_, s => try s.toByte catch { + case _: NumberFormatException => null + }) + case BooleanType => + buildCast[Boolean](_, b => if (b) 1.toByte else 0.toByte) + case TimestampType => + buildCast[Timestamp](_, t => timestampToLong(t).toByte) + case DecimalType => + buildCast[BigDecimal](_, _.toByte) + case x: NumericType => + b => x.numeric.asInstanceOf[Numeric[Any]].toInt(b).toByte + } + + private[this] def castToDecimal: Any => Any = child.dataType match { + case StringType => + buildCast[String](_, s => try BigDecimal(s.toDouble) catch { + case _: NumberFormatException => null + }) + case BooleanType => + buildCast[Boolean](_, b => if (b) BigDecimal(1) else BigDecimal(0)) case TimestampType => // Note that we lose precision here. - nullOrCast[Timestamp](_, t => BigDecimal(timestampToDouble(t))) - case x: NumericType => b => BigDecimal(x.numeric.asInstanceOf[Numeric[Any]].toDouble(b)) - } - - def castToDouble: Any => Any = child.dataType match { - case StringType => nullOrCast[String](_, s => try s.toDouble catch { - case _: NumberFormatException => null - }) - case BooleanType => nullOrCast[Boolean](_, b => if(b) 1d else 0d) - case TimestampType => nullOrCast[Timestamp](_, t => timestampToDouble(t)) - case DecimalType => nullOrCast[BigDecimal](_, _.toDouble) - case x: NumericType => b => x.numeric.asInstanceOf[Numeric[Any]].toDouble(b) - } - - def castToFloat: Any => Any = child.dataType match { - case StringType => nullOrCast[String](_, s => try s.toFloat catch { - case _: NumberFormatException => null - }) - case BooleanType => nullOrCast[Boolean](_, b => if(b) 1f else 0f) - case TimestampType => nullOrCast[Timestamp](_, t => timestampToDouble(t).toFloat) - case DecimalType => nullOrCast[BigDecimal](_, _.toFloat) - case x: NumericType => b => x.numeric.asInstanceOf[Numeric[Any]].toFloat(b) + buildCast[Timestamp](_, t => BigDecimal(timestampToDouble(t))) + case x: NumericType => + b => BigDecimal(x.numeric.asInstanceOf[Numeric[Any]].toDouble(b)) + } + + private[this] def castToDouble: Any => Any = child.dataType match { + case StringType => + buildCast[String](_, s => try s.toDouble catch { + case _: NumberFormatException => null + }) + case BooleanType => + buildCast[Boolean](_, b => if (b) 1d else 0d) + case TimestampType => + buildCast[Timestamp](_, t => timestampToDouble(t)) + case DecimalType => + buildCast[BigDecimal](_, _.toDouble) + case x: NumericType => + b => x.numeric.asInstanceOf[Numeric[Any]].toDouble(b) + } + + private[this] def castToFloat: Any => Any = child.dataType match { + case StringType => + buildCast[String](_, s => try s.toFloat catch { + case _: NumberFormatException => null + }) + case BooleanType => + buildCast[Boolean](_, b => if (b) 1f else 0f) + case TimestampType => + buildCast[Timestamp](_, t => timestampToDouble(t).toFloat) + case DecimalType => + buildCast[BigDecimal](_, _.toFloat) + case x: NumericType => + b => x.numeric.asInstanceOf[Numeric[Any]].toFloat(b) } - private lazy val cast: Any => Any = dataType match { + private[this] lazy val cast: Any => Any = dataType match { case StringType => castToString case BinaryType => castToBinary case DecimalType => castToDecimal @@ -198,10 +246,6 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression { override def eval(input: Row): Any = { val evaluated = child.eval(input) - if (evaluated == null) { - null - } else { - cast(evaluated) - } + if (evaluated == null) null else cast(evaluated) } } |