aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
author: Yijie Shen <henry.yijieshen@gmail.com> 2015-07-22 23:44:08 -0700
committer: Davies Liu <davies.liu@gmail.com> 2015-07-22 23:44:08 -0700
commit: 6d0d8b406942edcf9fc97e76fb227ff1eb35ca3a (patch)
tree: b836bc18a446c600a0cd7d1d295c62b4c4af900a
parent: 825ab1e4526059a77e3278769797c4d065f48bd3 (diff)
download: spark-6d0d8b406942edcf9fc97e76fb227ff1eb35ca3a.tar.gz
spark-6d0d8b406942edcf9fc97e76fb227ff1eb35ca3a.tar.bz2
spark-6d0d8b406942edcf9fc97e76fb227ff1eb35ca3a.zip
[SPARK-8935] [SQL] Implement code generation for all casts
JIRA: https://issues.apache.org/jira/browse/SPARK-8935

Author: Yijie Shen <henry.yijieshen@gmail.com>

Closes #7365 from yjshen/cast_codegen and squashes the following commits:

ef6e8b5 [Yijie Shen] getColumn and setColumn in struct cast, autounboxing in array and map
eaece18 [Yijie Shen] remove null case in cast code gen
fd7eba4 [Yijie Shen] resolve comments
80378a5 [Yijie Shen] the missing self cast
611d66e [Yijie Shen] Bug fix: NullType & primitive object unboxing
6d5c0fe [Yijie Shen] rebase and add Interval codegen
9424b65 [Yijie Shen] tiny style fix
4a1c801 [Yijie Shen] remove CodeHolder class, use function instead.
3f5df88 [Yijie Shen] CodeHolder for complex dataTypes
c286f13 [Yijie Shen] moved all the cast code into class body
4edfd76 [Yijie Shen] [WIP] finished primitive part
-rw-r--r-- sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala 523
-rw-r--r-- sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala 36
2 files changed, 508 insertions, 51 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index 3346d3c9f9..e66cd82848 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -26,6 +26,8 @@ import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.{Interval, UTF8String}
+import scala.collection.mutable
+
object Cast {
@@ -418,51 +420,506 @@ case class Cast(child: Expression, dataType: DataType)
protected override def nullSafeEval(input: Any): Any = cast(input)
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
- // TODO: Add support for more data types.
- (child.dataType, dataType) match {
+ // Generate the child's evaluation code first; its primitive/isNull variables feed the cast.
+ val eval = child.gen(ctx)
+ // Pick the snippet generator for this (child type, target type) pair.
+ val nullSafeCast = nullSafeCastFunction(child.dataType, dataType, ctx)
+ // The result is the child's code followed by the null-guarded cast that declares and fills
+ // ev.primitive and ev.isNull.
+ eval.code +
+ castCode(ctx, eval.primitive, eval.isNull, ev.primitive, ev.isNull, dataType, nullSafeCast)
+ }
+
+ // A CastFunction receives three generated-code variable names, in this order: the child's
+ // primitive value, the result's primitive value and the result's isNull flag.
+ // It returns the code snippet to be placed in the null-safe evaluation region, i.e. the part
+ // of the generated code that castCode only executes when the child value is not null.
+ private[this] type CastFunction = (String, String, String) => String
+
+  /**
+   * Selects the generator of the Java cast snippet for a `from` -> `to` conversion.
+   *
+   * Two shortcuts are checked before dispatching on the target type: a NullType input always
+   * produces null, and a cast between equal types is a plain assignment.
+   */
+  private[this] def nullSafeCastFunction(
+      from: DataType,
+      to: DataType,
+      ctx: CodeGenContext): CastFunction = {
+    if (from == NullType) {
+      (c, evPrim, evNull) => s"$evNull = true;"
+    } else if (to == from) {
+      (c, evPrim, evNull) => s"$evPrim = $c;"
+    } else {
+      to match {
+        // Atomic target types.
+        case StringType => castToStringCode(from, ctx)
+        case BinaryType => castToBinaryCode(from)
+        case DateType => castToDateCode(from, ctx)
+        case decimal: DecimalType => castToDecimalCode(from, decimal)
+        case TimestampType => castToTimestampCode(from, ctx)
+        case IntervalType => castToIntervalCode(from)
+        case BooleanType => castToBooleanCode(from)
+        case ByteType => castToByteCode(from)
+        case ShortType => castToShortCode(from)
+        case IntegerType => castToIntCode(from)
+        case FloatType => castToFloatCode(from)
+        case LongType => castToLongCode(from)
+        case DoubleType => castToDoubleCode(from)
+
+        // Complex target types: the source is cast to the same kind of complex type.
+        case array: ArrayType => castArrayCode(from.asInstanceOf[ArrayType], array, ctx)
+        case map: MapType => castMapCode(from.asInstanceOf[MapType], map, ctx)
+        case struct: StructType => castStructCode(from.asInstanceOf[StructType], struct, ctx)
+      }
+    }
+  }
+
+  // Complex types are cast by recursively casting their children (a map's key and value, a
+  // struct's fields), so every generated variable name involved in a cast is passed in
+  // explicitly. The emitted code declares the result variables and only runs the cast snippet
+  // when the child is not null.
+  private[this] def castCode(ctx: CodeGenContext, childPrim: String, childNull: String,
+      resultPrim: String, resultNull: String, resultType: DataType, cast: CastFunction): String = {
+    val resultJavaType = ctx.javaType(resultType)
+    val initialValue = ctx.defaultValue(resultType)
+    val nullSafeBody = cast(childPrim, resultPrim, resultNull)
+    s"""
+      boolean $resultNull = $childNull;
+      $resultJavaType $resultPrim = $initialValue;
+      if (!${childNull}) {
+        $nullSafeBody
+      }
+    """
+  }
+
+  /** Generates code that converts a value of type `from` into a UTF8String. */
+  private[this] def castToStringCode(from: DataType, ctx: CodeGenContext): CastFunction = {
+    // Fully qualified class name as it appears in the generated Java.
+    val dateTimeUtils = "org.apache.spark.sql.catalyst.util.DateTimeUtils"
+    // Wraps a Java expression of type String into the UTF8String assignment.
+    def viaJavaString(toJavaString: String => String): CastFunction =
+      (c, evPrim, evNull) => s"""$evPrim = UTF8String.fromString(
+        ${toJavaString(c)});"""
+
+    from match {
+      case BinaryType =>
+        (c, evPrim, evNull) => s"$evPrim = UTF8String.fromBytes($c);"
+      case DateType =>
+        viaJavaString(c => s"$dateTimeUtils.dateToString($c)")
+      case TimestampType =>
+        viaJavaString(c => s"$dateTimeUtils.timestampToString($c)")
+      case _ =>
+        viaJavaString(c => s"String.valueOf($c)")
+    }
+  }
+
+  /** Generates code that converts a value to binary; only a string source is handled. */
+  private[this] def castToBinaryCode(from: DataType): CastFunction = from match {
+    case StringType =>
+      (input, result, _) => s"$result = $input.getBytes();"
+  }
+
+  /**
+   * Generates code that converts a value of type `from` into a date. A string that cannot be
+   * parsed, and any source type other than string or timestamp, yields null.
+   */
+  private[this] def castToDateCode(
+      from: DataType,
+      ctx: CodeGenContext): CastFunction = {
+    val dateTimeUtils = "org.apache.spark.sql.catalyst.util.DateTimeUtils"
+    from match {
+      case StringType =>
+        val parsedDays = ctx.freshName("intOpt")
+        (c, evPrim, evNull) =>
+          s"""
+            scala.Option<Integer> $parsedDays =
+              $dateTimeUtils.stringToDate($c);
+            if ($parsedDays.isDefined()) {
+              $evPrim = ((Integer) $parsedDays.get()).intValue();
+            } else {
+              $evNull = true;
+            }
+          """
+      case TimestampType =>
+        (c, evPrim, evNull) => s"$evPrim = $dateTimeUtils.millisToDays($c / 1000L);"
+      case _ =>
+        (c, evPrim, evNull) => s"$evNull = true;"
+    }
+  }
+
+  /**
+   * Generates code that stores the Decimal variable `d` into `evPrim` after adjusting it to
+   * `decimalType`. For a fixed precision the result becomes null when the generated
+   * `changePrecision` call returns false.
+   */
+  private[this] def changePrecision(d: String, decimalType: DecimalType,
+      evPrim: String, evNull: String): String = {
+    val storeResult = s"$evPrim = $d;"
+    decimalType match {
+      case DecimalType.Unlimited =>
+        storeResult
+      case DecimalType.Fixed(precision, scale) =>
+        s"""
+          if ($d.changePrecision($precision, $scale)) {
+            $storeResult
+          } else {
+            $evNull = true;
+          }
+        """
+    }
+  }
- case (BinaryType, StringType) =>
- defineCodeGen (ctx, ev, c =>
- s"UTF8String.fromBytes($c)")
+  /**
+   * Generates code that converts a value of type `from` into a Decimal of type `target`.
+   * Every branch that produces a value builds a local `tmpDecimal` and then fits it into the
+   * target precision with `changePrecision`.
+   */
+  private[this] def castToDecimalCode(from: DataType, target: DecimalType): CastFunction = {
+    // Fully qualified class name as it appears in the generated Java.
+    val decimalClass = "org.apache.spark.sql.types.Decimal"
+    // Final step shared by the branches: move tmpDecimal into the result variables.
+    def fitToTarget(evPrim: String, evNull: String): String =
+      changePrecision("tmpDecimal", target, evPrim, evNull)
+
+    from match {
+      case StringType =>
+        (c, evPrim, evNull) =>
+          s"""
+            try {
+              $decimalClass tmpDecimal =
+                new $decimalClass().set(
+                  new scala.math.BigDecimal(
+                    new java.math.BigDecimal($c.toString())));
+              ${fitToTarget(evPrim, evNull)}
+            } catch (java.lang.NumberFormatException e) {
+              $evNull = true;
+            }
+          """
+      case BooleanType =>
+        (c, evPrim, evNull) =>
+          s"""
+            $decimalClass tmpDecimal = null;
+            if ($c) {
+              tmpDecimal = new $decimalClass().set(1);
+            } else {
+              tmpDecimal = new $decimalClass().set(0);
+            }
+            ${fitToTarget(evPrim, evNull)}
+          """
+      case DateType =>
+        // date can't cast to decimal in Hive
+        (c, evPrim, evNull) => s"$evNull = true;"
+      case TimestampType =>
+        // Note that we lose precision here.
+        (c, evPrim, evNull) =>
+          s"""
+            $decimalClass tmpDecimal =
+              new $decimalClass().set(
+                scala.math.BigDecimal.valueOf(${timestampToDoubleCode(c)}));
+            ${fitToTarget(evPrim, evNull)}
+          """
+      case DecimalType() =>
+        (c, evPrim, evNull) =>
+          s"""
+            $decimalClass tmpDecimal = $c.clone();
+            ${fitToTarget(evPrim, evNull)}
+          """
+      case LongType =>
+        (c, evPrim, evNull) =>
+          s"""
+            $decimalClass tmpDecimal =
+              new $decimalClass().set($c);
+            ${fitToTarget(evPrim, evNull)}
+          """
+      case _: NumericType =>
+        // All other numeric types can be represented precisely as Doubles
+        (c, evPrim, evNull) =>
+          s"""
+            try {
+              $decimalClass tmpDecimal =
+                new $decimalClass().set(
+                  scala.math.BigDecimal.valueOf((double) $c));
+              ${fitToTarget(evPrim, evNull)}
+            } catch (java.lang.NumberFormatException e) {
+              $evNull = true;
+            }
+          """
+    }
+  }
- case (DateType, StringType) =>
- defineCodeGen(ctx, ev, c =>
- s"""UTF8String.fromString(
- org.apache.spark.sql.catalyst.util.DateTimeUtils.dateToString($c))""")
+  /**
+   * Generates code that converts a value of type `from` into a timestamp (a Java long).
+   * Unparseable strings and NaN / infinite floating point values yield null.
+   */
+  private[this] def castToTimestampCode(
+      from: DataType,
+      ctx: CodeGenContext): CastFunction = {
+    // Double and Float only differ in the wrapper class used for the NaN / infinity checks.
+    def fractionalToTimestamp(wrapperClass: String): CastFunction =
+      (c, evPrim, evNull) =>
+        s"""
+          if ($wrapperClass.isNaN($c) || $wrapperClass.isInfinite($c)) {
+            $evNull = true;
+          } else {
+            $evPrim = (long)($c * 1000000L);
+          }
+        """
+
+    from match {
+      case StringType =>
+        val parsedTimestamp = ctx.freshName("longOpt")
+        (c, evPrim, evNull) =>
+          s"""
+            scala.Option<Long> $parsedTimestamp =
+              org.apache.spark.sql.catalyst.util.DateTimeUtils.stringToTimestamp($c);
+            if ($parsedTimestamp.isDefined()) {
+              $evPrim = ((Long) $parsedTimestamp.get()).longValue();
+            } else {
+              $evNull = true;
+            }
+          """
+      case BooleanType =>
+        (c, evPrim, evNull) => s"$evPrim = $c ? 1L : 0;"
+      case _: IntegralType =>
+        (c, evPrim, evNull) => s"$evPrim = ${longToTimeStampCode(c)};"
+      case DateType =>
+        (c, evPrim, evNull) =>
+          s"$evPrim = org.apache.spark.sql.catalyst.util.DateTimeUtils.daysToMillis($c) * 1000;"
+      case DecimalType() =>
+        (c, evPrim, evNull) => s"$evPrim = ${decimalToTimestampCode(c)};"
+      case DoubleType =>
+        fractionalToTimestamp("Double")
+      case FloatType =>
+        fractionalToTimestamp("Float")
+    }
+  }
- case (TimestampType, StringType) =>
- defineCodeGen(ctx, ev, c =>
- s"""UTF8String.fromString(
- org.apache.spark.sql.catalyst.util.DateTimeUtils.timestampToString($c))""")
+  /** Generates code that parses a string into an Interval; only a string source is handled. */
+  private[this] def castToIntervalCode(from: DataType): CastFunction = from match {
+    case StringType =>
+      val intervalClass = "org.apache.spark.unsafe.types.Interval"
+      (input, result, _) => s"$result = $intervalClass.fromString($input.toString());"
+  }
+
+ // Java expression that multiplies the Decimal named `d` by 1,000,000 and takes the long value
+ // (presumably seconds to microseconds -- TODO confirm the timestamp unit).
+ private[this] def decimalToTimestampCode(d: String): String =
+ s"($d.toBigDecimal().bigDecimal().multiply(new java.math.BigDecimal(1000000L))).longValue()"
+ // Java expression that multiplies the integral value named `l` by 1000 to form a timestamp.
+ private[this] def longToTimeStampCode(l: String): String = s"$l * 1000L"
+ // Java expression for the floor of the timestamp named `ts` divided by 1,000,000. The
+ // expression has type double, so callers prepend the narrowing cast they need.
+ private[this] def timestampToIntegerCode(ts: String): String =
+ s"java.lang.Math.floor((double) $ts / 1000000L)"
+ // Java expression dividing the timestamp named `ts` by 1,000,000.0, yielding a double.
+ private[this] def timestampToDoubleCode(ts: String): String = s"$ts / 1000000.0"
+
+  /** Generates code that converts a value of type `from` into a boolean. */
+  private[this] def castToBooleanCode(from: DataType): CastFunction = {
+    // Shared by the source types whose value is simply compared against zero.
+    val isNonZero: CastFunction = (c, evPrim, evNull) => s"$evPrim = $c != 0;"
+
+    from match {
+      case StringType =>
+        // A string is true exactly when it is not empty.
+        (c, evPrim, evNull) => s"$evPrim = $c.numBytes() != 0;"
+      case TimestampType =>
+        isNonZero
+      case DateType =>
+        // Hive would return null when cast from date to boolean
+        (c, evPrim, evNull) => s"$evNull = true;"
+      case DecimalType() =>
+        (c, evPrim, evNull) => s"$evPrim = !$c.isZero();"
+      case _: NumericType =>
+        isNonZero
+    }
+  }
+
+  /**
+   * Generates code that converts a value of type `from` into a Java `byte`.
+   * A string that is not a valid byte, and any date, yields null.
+   */
+  private[this] def castToByteCode(from: DataType): CastFunction = from match {
+    case StringType =>
+      // parseByte returns the primitive directly, so the generated code does not allocate a
+      // boxed Byte and does not depend on auto-unboxing.
+      (c, evPrim, evNull) =>
+        s"""
+          try {
+            $evPrim = Byte.parseByte($c.toString());
+          } catch (java.lang.NumberFormatException e) {
+            $evNull = true;
+          }
+        """
+    case BooleanType =>
+      // Both branches are cast explicitly: `c ? 1 : 0` has type int and is not a constant
+      // expression, so Java does not allow assigning it to a byte without a cast.
+      (c, evPrim, evNull) => s"$evPrim = $c ? (byte) 1 : (byte) 0;"
+    case DateType =>
+      (c, evPrim, evNull) => s"$evNull = true;"
+    case TimestampType =>
+      (c, evPrim, evNull) => s"$evPrim = (byte) ${timestampToIntegerCode(c)};"
+    case DecimalType() =>
+      (c, evPrim, evNull) => s"$evPrim = $c.toByte();"
+    case x: NumericType =>
+      (c, evPrim, evNull) => s"$evPrim = (byte) $c;"
+  }
- case (_, StringType) =>
- defineCodeGen(ctx, ev, c => s"UTF8String.fromString(String.valueOf($c))")
+  /**
+   * Generates code that converts a value of type `from` into a Java `short`.
+   * A string that is not a valid short, and any date, yields null.
+   */
+  private[this] def castToShortCode(from: DataType): CastFunction = from match {
+    case StringType =>
+      // parseShort returns the primitive directly, so the generated code does not allocate a
+      // boxed Short and does not depend on auto-unboxing.
+      (c, evPrim, evNull) =>
+        s"""
+          try {
+            $evPrim = Short.parseShort($c.toString());
+          } catch (java.lang.NumberFormatException e) {
+            $evNull = true;
+          }
+        """
+    case BooleanType =>
+      // Both branches are cast explicitly: `c ? 1 : 0` has type int and is not a constant
+      // expression, so Java does not allow assigning it to a short without a cast.
+      (c, evPrim, evNull) => s"$evPrim = $c ? (short) 1 : (short) 0;"
+    case DateType =>
+      (c, evPrim, evNull) => s"$evNull = true;"
+    case TimestampType =>
+      (c, evPrim, evNull) => s"$evPrim = (short) ${timestampToIntegerCode(c)};"
+    case DecimalType() =>
+      (c, evPrim, evNull) => s"$evPrim = $c.toShort();"
+    case x: NumericType =>
+      (c, evPrim, evNull) => s"$evPrim = (short) $c;"
+  }
- case (StringType, IntervalType) =>
- defineCodeGen(ctx, ev, c =>
- s"org.apache.spark.unsafe.types.Interval.fromString($c.toString())")
+  /**
+   * Generates code that converts a value of type `from` into a Java `int`.
+   * A string that is not a valid int, and any date, yields null.
+   */
+  private[this] def castToIntCode(from: DataType): CastFunction = from match {
+    case StringType =>
+      // parseInt returns the primitive directly, so the generated code does not allocate a
+      // boxed Integer and does not depend on auto-unboxing.
+      (c, evPrim, evNull) =>
+        s"""
+          try {
+            $evPrim = Integer.parseInt($c.toString());
+          } catch (java.lang.NumberFormatException e) {
+            $evNull = true;
+          }
+        """
+    case BooleanType =>
+      (c, evPrim, evNull) => s"$evPrim = $c ? 1 : 0;"
+    case DateType =>
+      (c, evPrim, evNull) => s"$evNull = true;"
+    case TimestampType =>
+      (c, evPrim, evNull) => s"$evPrim = (int) ${timestampToIntegerCode(c)};"
+    case DecimalType() =>
+      (c, evPrim, evNull) => s"$evPrim = $c.toInt();"
+    case x: NumericType =>
+      (c, evPrim, evNull) => s"$evPrim = (int) $c;"
+  }
- // fallback for DecimalType, this must be before other numeric types
- case (_, dt: DecimalType) =>
- super.genCode(ctx, ev)
+  /**
+   * Generates code that converts a value of type `from` into a Java `long`.
+   * A string that is not a valid long, and any date, yields null.
+   */
+  private[this] def castToLongCode(from: DataType): CastFunction = from match {
+    case StringType =>
+      // parseLong returns the primitive directly, so the generated code does not allocate a
+      // boxed Long and does not depend on auto-unboxing.
+      (c, evPrim, evNull) =>
+        s"""
+          try {
+            $evPrim = Long.parseLong($c.toString());
+          } catch (java.lang.NumberFormatException e) {
+            $evNull = true;
+          }
+        """
+    case BooleanType =>
+      (c, evPrim, evNull) => s"$evPrim = $c ? 1 : 0;"
+    case DateType =>
+      (c, evPrim, evNull) => s"$evNull = true;"
+    case TimestampType =>
+      (c, evPrim, evNull) => s"$evPrim = (long) ${timestampToIntegerCode(c)};"
+    case DecimalType() =>
+      (c, evPrim, evNull) => s"$evPrim = $c.toLong();"
+    case x: NumericType =>
+      (c, evPrim, evNull) => s"$evPrim = (long) $c;"
+  }
- case (BooleanType, dt: NumericType) =>
- defineCodeGen(ctx, ev, c => s"(${ctx.javaType(dt)})($c ? 1 : 0)")
+  /**
+   * Generates code that converts a value of type `from` into a Java `float`.
+   * A string that is not a valid float, and any date, yields null.
+   */
+  private[this] def castToFloatCode(from: DataType): CastFunction = from match {
+    case StringType =>
+      // parseFloat returns the primitive directly, so the generated code does not allocate a
+      // boxed Float and does not depend on auto-unboxing.
+      (c, evPrim, evNull) =>
+        s"""
+          try {
+            $evPrim = Float.parseFloat($c.toString());
+          } catch (java.lang.NumberFormatException e) {
+            $evNull = true;
+          }
+        """
+    case BooleanType =>
+      (c, evPrim, evNull) => s"$evPrim = $c ? 1 : 0;"
+    case DateType =>
+      (c, evPrim, evNull) => s"$evNull = true;"
+    case TimestampType =>
+      (c, evPrim, evNull) => s"$evPrim = (float) (${timestampToDoubleCode(c)});"
+    case DecimalType() =>
+      (c, evPrim, evNull) => s"$evPrim = $c.toFloat();"
+    case x: NumericType =>
+      (c, evPrim, evNull) => s"$evPrim = (float) $c;"
+  }
- case (dt: DecimalType, BooleanType) =>
- defineCodeGen(ctx, ev, c => s"!$c.isZero()")
+  /**
+   * Generates code that converts a value of type `from` into a Java `double`.
+   * A string that is not a valid double, and any date, yields null.
+   */
+  private[this] def castToDoubleCode(from: DataType): CastFunction = from match {
+    case StringType =>
+      // parseDouble returns the primitive directly, so the generated code does not allocate a
+      // boxed Double and does not depend on auto-unboxing.
+      (c, evPrim, evNull) =>
+        s"""
+          try {
+            $evPrim = Double.parseDouble($c.toString());
+          } catch (java.lang.NumberFormatException e) {
+            $evNull = true;
+          }
+        """
+    case BooleanType =>
+      (c, evPrim, evNull) => s"$evPrim = $c ? 1 : 0;"
+    case DateType =>
+      (c, evPrim, evNull) => s"$evNull = true;"
+    case TimestampType =>
+      (c, evPrim, evNull) => s"$evPrim = ${timestampToDoubleCode(c)};"
+    case DecimalType() =>
+      (c, evPrim, evNull) => s"$evPrim = $c.toDouble();"
+    case x: NumericType =>
+      (c, evPrim, evNull) => s"$evPrim = (double) $c;"
+  }
- case (dt: NumericType, BooleanType) =>
- defineCodeGen(ctx, ev, c => s"$c != 0")
+  /**
+   * Generates code that casts an array element by element into a new ArraySeq.
+   * Null elements, and elements whose cast yields null, are stored as null.
+   */
+  private[this] def castArrayCode(
+      from: ArrayType, to: ArrayType, ctx: CodeGenContext): CastFunction = {
+    val castElement = nullSafeCastFunction(from.elementType, to.elementType, ctx)
+
+    val seqClass = classOf[mutable.ArraySeq[Any]].getName
+    // Names of the generated variables; they are fixed once per cast function.
+    val srcNull = ctx.freshName("feNull")
+    val srcPrim = ctx.freshName("fePrim")
+    val dstNull = ctx.freshName("teNull")
+    val dstPrim = ctx.freshName("tePrim")
+    val length = ctx.freshName("n")
+    val index = ctx.freshName("j")
+    val casted = ctx.freshName("result")
+
+    (c, evPrim, evNull) => {
+      val srcJavaType = ctx.javaType(from.elementType)
+      val srcBoxedType = ctx.boxedType(from.elementType)
+      val elementCode =
+        castCode(ctx, srcPrim, srcNull, dstPrim, dstNull, to.elementType, castElement)
+      s"""
+        final int $length = $c.size();
+        final $seqClass<Object> $casted = new $seqClass<Object>($length);
+        for (int $index = 0; $index < $length; $index ++) {
+          if ($c.apply($index) == null) {
+            $casted.update($index, null);
+          } else {
+            boolean $srcNull = false;
+            $srcJavaType $srcPrim =
+              ($srcBoxedType) $c.apply($index);
+            $elementCode
+            if ($dstNull) {
+              $casted.update($index, null);
+            } else {
+              $casted.update($index, $dstPrim);
+            }
+          }
+        }
+        $evPrim = $casted;
+      """
+    }
+  }
- case (_: DecimalType, dt: NumericType) =>
- defineCodeGen(ctx, ev, c => s"($c).to${ctx.primitiveTypeName(dt)}()")
+  /**
+   * Generates code that casts a map entry by entry into a new mutable HashMap.
+   * Keys are always cast; a null value, or a value whose cast yields null, is stored as null.
+   */
+  private[this] def castMapCode(from: MapType, to: MapType, ctx: CodeGenContext): CastFunction = {
+    val keyCast = nullSafeCastFunction(from.keyType, to.keyType, ctx)
+    val valueCast = nullSafeCastFunction(from.valueType, to.valueType, ctx)
+
+    val hashMapClass = classOf[mutable.HashMap[Any, Any]].getName
+    val fromKeyPrim = ctx.freshName("fkp")
+    val fromKeyNull = ctx.freshName("fkn")
+    val fromValuePrim = ctx.freshName("fvp")
+    val fromValueNull = ctx.freshName("fvn")
+    val toKeyPrim = ctx.freshName("tkp")
+    val toKeyNull = ctx.freshName("tkn")
+    val toValuePrim = ctx.freshName("tvp")
+    val toValueNull = ctx.freshName("tvn")
+    val result = ctx.freshName("result")
+    // The iterator and entry variables need fresh names as well: when the key or value type
+    // itself contains a map, the nested cast code is emitted inside this loop, and fixed names
+    // would redeclare locals that are still in scope, which Java rejects.
+    val iter = ctx.freshName("iter")
+    val kv = ctx.freshName("kv")
+
+    (c, evPrim, evNull) =>
+      s"""
+        final $hashMapClass $result = new $hashMapClass();
+        scala.collection.Iterator $iter = $c.iterator();
+        while ($iter.hasNext()) {
+          scala.Tuple2 $kv = (scala.Tuple2) $iter.next();
+          boolean $fromKeyNull = false;
+          ${ctx.javaType(from.keyType)} $fromKeyPrim =
+            (${ctx.boxedType(from.keyType)}) $kv._1();
+          ${castCode(ctx, fromKeyPrim,
+            fromKeyNull, toKeyPrim, toKeyNull, to.keyType, keyCast)}
+
+          boolean $fromValueNull = $kv._2() == null;
+          if ($fromValueNull) {
+            $result.put($toKeyPrim, null);
+          } else {
+            ${ctx.javaType(from.valueType)} $fromValuePrim =
+              (${ctx.boxedType(from.valueType)}) $kv._2();
+            ${castCode(ctx, fromValuePrim,
+              fromValueNull, toValuePrim, toValueNull, to.valueType, valueCast)}
+            if ($toValueNull) {
+              $result.put($toKeyPrim, null);
+            } else {
+              $result.put($toKeyPrim, $toValuePrim);
+            }
+          }
+        }
+        $evPrim = $result;
+      """
+  }
- case (_: NumericType, dt: NumericType) =>
- defineCodeGen(ctx, ev, c => s"(${ctx.javaType(dt)})($c)")
+ // Generates code that casts a struct field by field into a new GenericMutableRow and stores
+ // a copy of that row in the result variable.
+ private[this] def castStructCode(
+ from: StructType, to: StructType, ctx: CodeGenContext): CastFunction = {
- case other =>
- super.genCode(ctx, ev)
+ // One cast function per field; source and target fields are paired by position.
+ val fieldsCasts = from.fields.zip(to.fields).map {
+ case (fromField, toField) => nullSafeCastFunction(fromField.dataType, toField.dataType, ctx)
}
+ val rowClass = classOf[GenericMutableRow].getName
+ val result = ctx.freshName("result")
+ val tmpRow = ctx.freshName("tmpRow")
+
+ // Code for each field: a null source field, or a cast that yields null, sets the target
+ // field to null; otherwise the cast value is written to the same position.
+ val fieldsEvalCode = fieldsCasts.zipWithIndex.map { case (cast, i) => {
+ val fromFieldPrim = ctx.freshName("ffp")
+ val fromFieldNull = ctx.freshName("ffn")
+ val toFieldPrim = ctx.freshName("tfp")
+ val toFieldNull = ctx.freshName("tfn")
+ val fromType = ctx.javaType(from.fields(i).dataType)
+ s"""
+ boolean $fromFieldNull = $tmpRow.isNullAt($i);
+ if ($fromFieldNull) {
+ $result.setNullAt($i);
+ } else {
+ $fromType $fromFieldPrim =
+ ${ctx.getColumn(tmpRow, from.fields(i).dataType, i)};
+ ${castCode(ctx, fromFieldPrim,
+ fromFieldNull, toFieldPrim, toFieldNull, to.fields(i).dataType, cast)}
+ if ($toFieldNull) {
+ $result.setNullAt($i);
+ } else {
+ ${ctx.setColumn(result, to.fields(i).dataType, i, toFieldPrim)};
+ }
+ }
+ """
+ }
+ }.mkString("\n")
+
+ // The per-field code is generated once above; the returned function only wraps it with the
+ // declarations of the target row and of a local alias for the input row.
+ (c, evPrim, evNull) =>
+ s"""
+ final $rowClass $result = new $rowClass(${fieldsCasts.size});
+ final InternalRow $tmpRow = $c;
+ $fieldsEvalCode
+ $evPrim = $result.copy();
+ """
}
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala
index f724bab4d8..bdba6ce891 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala
@@ -39,7 +39,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
val c = Calendar.getInstance()
c.set(y, m, 28, 0, 0, 0)
c.add(Calendar.DATE, i)
- checkEvaluation(DayOfYear(Cast(Literal(new Date(c.getTimeInMillis)), DateType)),
+ checkEvaluation(DayOfYear(Literal(new Date(c.getTimeInMillis))),
sdfDay.format(c.getTime).toInt)
}
}
@@ -51,7 +51,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
val c = Calendar.getInstance()
c.set(y, m, 28, 0, 0, 0)
c.add(Calendar.DATE, i)
- checkEvaluation(DayOfYear(Cast(Literal(new Date(c.getTimeInMillis)), DateType)),
+ checkEvaluation(DayOfYear(Literal(new Date(c.getTimeInMillis))),
sdfDay.format(c.getTime).toInt)
}
}
@@ -63,7 +63,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
val c = Calendar.getInstance()
c.set(y, m, 28, 0, 0, 0)
c.add(Calendar.DATE, i)
- checkEvaluation(DayOfYear(Cast(Literal(new Date(c.getTimeInMillis)), DateType)),
+ checkEvaluation(DayOfYear(Literal(new Date(c.getTimeInMillis))),
sdfDay.format(c.getTime).toInt)
}
}
@@ -75,7 +75,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
val c = Calendar.getInstance()
c.set(y, m, 28, 0, 0, 0)
c.add(Calendar.DATE, i)
- checkEvaluation(DayOfYear(Cast(Literal(new Date(c.getTimeInMillis)), DateType)),
+ checkEvaluation(DayOfYear(Literal(new Date(c.getTimeInMillis))),
sdfDay.format(c.getTime).toInt)
}
}
@@ -87,7 +87,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
val c = Calendar.getInstance()
c.set(y, m, 28, 0, 0, 0)
c.add(Calendar.DATE, i)
- checkEvaluation(DayOfYear(Cast(Literal(new Date(c.getTimeInMillis)), DateType)),
+ checkEvaluation(DayOfYear(Literal(new Date(c.getTimeInMillis))),
sdfDay.format(c.getTime).toInt)
}
}
@@ -96,7 +96,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
test("Year") {
checkEvaluation(Year(Literal.create(null, DateType)), null)
- checkEvaluation(Year(Cast(Literal(d), DateType)), 2015)
+ checkEvaluation(Year(Literal(d)), 2015)
checkEvaluation(Year(Cast(Literal(sdfDate.format(d)), DateType)), 2015)
checkEvaluation(Year(Cast(Literal(ts), DateType)), 2013)
@@ -106,7 +106,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
c.set(y, m, 28)
(0 to 5 * 24).foreach { i =>
c.add(Calendar.HOUR_OF_DAY, 1)
- checkEvaluation(Year(Cast(Literal(new Date(c.getTimeInMillis)), DateType)),
+ checkEvaluation(Year(Literal(new Date(c.getTimeInMillis))),
c.get(Calendar.YEAR))
}
}
@@ -115,7 +115,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
test("Quarter") {
checkEvaluation(Quarter(Literal.create(null, DateType)), null)
- checkEvaluation(Quarter(Cast(Literal(d), DateType)), 2)
+ checkEvaluation(Quarter(Literal(d)), 2)
checkEvaluation(Quarter(Cast(Literal(sdfDate.format(d)), DateType)), 2)
checkEvaluation(Quarter(Cast(Literal(ts), DateType)), 4)
@@ -125,7 +125,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
c.set(y, m, 28, 0, 0, 0)
(0 to 5 * 24).foreach { i =>
c.add(Calendar.HOUR_OF_DAY, 1)
- checkEvaluation(Quarter(Cast(Literal(new Date(c.getTimeInMillis)), DateType)),
+ checkEvaluation(Quarter(Literal(new Date(c.getTimeInMillis))),
c.get(Calendar.MONTH) / 3 + 1)
}
}
@@ -134,7 +134,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
test("Month") {
checkEvaluation(Month(Literal.create(null, DateType)), null)
- checkEvaluation(Month(Cast(Literal(d), DateType)), 4)
+ checkEvaluation(Month(Literal(d)), 4)
checkEvaluation(Month(Cast(Literal(sdfDate.format(d)), DateType)), 4)
checkEvaluation(Month(Cast(Literal(ts), DateType)), 11)
@@ -144,7 +144,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
val c = Calendar.getInstance()
c.set(y, m, 28, 0, 0, 0)
c.add(Calendar.HOUR_OF_DAY, i)
- checkEvaluation(Month(Cast(Literal(new Date(c.getTimeInMillis)), DateType)),
+ checkEvaluation(Month(Literal(new Date(c.getTimeInMillis))),
c.get(Calendar.MONTH) + 1)
}
}
@@ -156,7 +156,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
val c = Calendar.getInstance()
c.set(y, m, 28, 0, 0, 0)
c.add(Calendar.HOUR_OF_DAY, i)
- checkEvaluation(Month(Cast(Literal(new Date(c.getTimeInMillis)), DateType)),
+ checkEvaluation(Month(Literal(new Date(c.getTimeInMillis))),
c.get(Calendar.MONTH) + 1)
}
}
@@ -166,7 +166,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
test("Day / DayOfMonth") {
checkEvaluation(DayOfMonth(Cast(Literal("2000-02-29"), DateType)), 29)
checkEvaluation(DayOfMonth(Literal.create(null, DateType)), null)
- checkEvaluation(DayOfMonth(Cast(Literal(d), DateType)), 8)
+ checkEvaluation(DayOfMonth(Literal(d)), 8)
checkEvaluation(DayOfMonth(Cast(Literal(sdfDate.format(d)), DateType)), 8)
checkEvaluation(DayOfMonth(Cast(Literal(ts), DateType)), 8)
@@ -175,7 +175,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
c.set(y, 0, 1, 0, 0, 0)
(0 to 365).foreach { d =>
c.add(Calendar.DATE, 1)
- checkEvaluation(DayOfMonth(Cast(Literal(new Date(c.getTimeInMillis)), DateType)),
+ checkEvaluation(DayOfMonth(Literal(new Date(c.getTimeInMillis))),
c.get(Calendar.DAY_OF_MONTH))
}
}
@@ -190,14 +190,14 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
val c = Calendar.getInstance()
(0 to 60 by 5).foreach { s =>
c.set(2015, 18, 3, 3, 5, s)
- checkEvaluation(Second(Cast(Literal(new Timestamp(c.getTimeInMillis)), TimestampType)),
+ checkEvaluation(Second(Literal(new Timestamp(c.getTimeInMillis))),
c.get(Calendar.SECOND))
}
}
test("WeekOfYear") {
checkEvaluation(WeekOfYear(Literal.create(null, DateType)), null)
- checkEvaluation(WeekOfYear(Cast(Literal(d), DateType)), 15)
+ checkEvaluation(WeekOfYear(Literal(d)), 15)
checkEvaluation(WeekOfYear(Cast(Literal(sdfDate.format(d)), DateType)), 15)
checkEvaluation(WeekOfYear(Cast(Literal(ts), DateType)), 45)
checkEvaluation(WeekOfYear(Cast(Literal("2011-05-06"), DateType)), 18)
@@ -223,7 +223,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
(0 to 60 by 15).foreach { m =>
(0 to 60 by 15).foreach { s =>
c.set(2015, 18, 3, h, m, s)
- checkEvaluation(Hour(Cast(Literal(new Timestamp(c.getTimeInMillis)), TimestampType)),
+ checkEvaluation(Hour(Literal(new Timestamp(c.getTimeInMillis))),
c.get(Calendar.HOUR_OF_DAY))
}
}
@@ -240,7 +240,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
(0 to 60 by 5).foreach { m =>
(0 to 60 by 15).foreach { s =>
c.set(2015, 18, 3, 3, m, s)
- checkEvaluation(Minute(Cast(Literal(new Timestamp(c.getTimeInMillis)), TimestampType)),
+ checkEvaluation(Minute(Literal(new Timestamp(c.getTimeInMillis))),
c.get(Calendar.MINUTE))
}
}