diff options
2 files changed, 508 insertions, 51 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 3346d3c9f9..e66cd82848 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -26,6 +26,8 @@ import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.{Interval, UTF8String} +import scala.collection.mutable + object Cast { @@ -418,51 +420,506 @@ case class Cast(child: Expression, dataType: DataType) protected override def nullSafeEval(input: Any): Any = cast(input) override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - // TODO: Add support for more data types. - (child.dataType, dataType) match { + val eval = child.gen(ctx) + val nullSafeCast = nullSafeCastFunction(child.dataType, dataType, ctx) + eval.code + + castCode(ctx, eval.primitive, eval.isNull, ev.primitive, ev.isNull, dataType, nullSafeCast) + } + + // three function arguments are: child.primitive, result.primitive and result.isNull + // it returns the code snippets to be put in null safe evaluation region + private[this] type CastFunction = (String, String, String) => String + + private[this] def nullSafeCastFunction( + from: DataType, + to: DataType, + ctx: CodeGenContext): CastFunction = to match { + + case _ if from == NullType => (c, evPrim, evNull) => s"$evNull = true;" + case _ if to == from => (c, evPrim, evNull) => s"$evPrim = $c;" + case StringType => castToStringCode(from, ctx) + case BinaryType => castToBinaryCode(from) + case DateType => castToDateCode(from, ctx) + case decimal: DecimalType => castToDecimalCode(from, decimal) + case TimestampType => castToTimestampCode(from, ctx) + case IntervalType => castToIntervalCode(from) + case BooleanType => castToBooleanCode(from) + case ByteType => castToByteCode(from) + case ShortType => castToShortCode(from) + case IntegerType => castToIntCode(from) + case FloatType => castToFloatCode(from) + case LongType => castToLongCode(from) + case DoubleType => castToDoubleCode(from) + + case array: ArrayType => castArrayCode(from.asInstanceOf[ArrayType], array, ctx) + case map: MapType => castMapCode(from.asInstanceOf[MapType], map, ctx) + case struct: StructType => castStructCode(from.asInstanceOf[StructType], struct, ctx) + } + + // Since we need to cast child expressions recursively inside ComplexTypes, such as Map's + // Key and Value, Struct's field, we need to name out all the variable names involved in a cast. + private[this] def castCode(ctx: CodeGenContext, childPrim: String, childNull: String, + resultPrim: String, resultNull: String, resultType: DataType, cast: CastFunction): String = { + s""" + boolean $resultNull = $childNull; + ${ctx.javaType(resultType)} $resultPrim = ${ctx.defaultValue(resultType)}; + if (!${childNull}) { + ${cast(childPrim, resultPrim, resultNull)} + } + """ + } + + private[this] def castToStringCode(from: DataType, ctx: CodeGenContext): CastFunction = { + from match { + case BinaryType => + (c, evPrim, evNull) => s"$evPrim = UTF8String.fromBytes($c);" + case DateType => + (c, evPrim, evNull) => s"""$evPrim = UTF8String.fromString( + org.apache.spark.sql.catalyst.util.DateTimeUtils.dateToString($c));""" + case TimestampType => + (c, evPrim, evNull) => s"""$evPrim = UTF8String.fromString( + org.apache.spark.sql.catalyst.util.DateTimeUtils.timestampToString($c));""" + case _ => + (c, evPrim, evNull) => s"$evPrim = UTF8String.fromString(String.valueOf($c));" + } + } + + private[this] def castToBinaryCode(from: DataType): CastFunction = from match { + case StringType => + (c, evPrim, evNull) => s"$evPrim = $c.getBytes();" + } + + private[this] def castToDateCode( + from: DataType, + ctx: CodeGenContext): CastFunction = from match { + case StringType => + val intOpt = ctx.freshName("intOpt") + (c, evPrim, evNull) => s""" + scala.Option<Integer> $intOpt = + org.apache.spark.sql.catalyst.util.DateTimeUtils.stringToDate($c); + if ($intOpt.isDefined()) { + $evPrim = ((Integer) $intOpt.get()).intValue(); + } else { + $evNull = true; + } + """ + case TimestampType => + (c, evPrim, evNull) => + s"$evPrim = org.apache.spark.sql.catalyst.util.DateTimeUtils.millisToDays($c / 1000L);"; + case _ => + (c, evPrim, evNull) => s"$evNull = true;" + } + + private[this] def changePrecision(d: String, decimalType: DecimalType, + evPrim: String, evNull: String): String = { + decimalType match { + case DecimalType.Unlimited => + s"$evPrim = $d;" + case DecimalType.Fixed(precision, scale) => + s""" + if ($d.changePrecision($precision, $scale)) { + $evPrim = $d; + } else { + $evNull = true; + } + """ + } + } - case (BinaryType, StringType) => - defineCodeGen (ctx, ev, c => - s"UTF8String.fromBytes($c)") + private[this] def castToDecimalCode(from: DataType, target: DecimalType): CastFunction = { + from match { + case StringType => + (c, evPrim, evNull) => + s""" + try { + org.apache.spark.sql.types.Decimal tmpDecimal = + new org.apache.spark.sql.types.Decimal().set( + new scala.math.BigDecimal( + new java.math.BigDecimal($c.toString()))); + ${changePrecision("tmpDecimal", target, evPrim, evNull)} + } catch (java.lang.NumberFormatException e) { + $evNull = true; + } + """ + case BooleanType => + (c, evPrim, evNull) => + s""" + org.apache.spark.sql.types.Decimal tmpDecimal = null; + if ($c) { + tmpDecimal = new org.apache.spark.sql.types.Decimal().set(1); + } else { + tmpDecimal = new org.apache.spark.sql.types.Decimal().set(0); + } + ${changePrecision("tmpDecimal", target, evPrim, evNull)} + """ + case DateType => + // date can't cast to decimal in Hive + (c, evPrim, evNull) => s"$evNull = true;" + case TimestampType => + // Note that we lose precision here. + (c, evPrim, evNull) => + s""" + org.apache.spark.sql.types.Decimal tmpDecimal = + new org.apache.spark.sql.types.Decimal().set( + scala.math.BigDecimal.valueOf(${timestampToDoubleCode(c)})); + ${changePrecision("tmpDecimal", target, evPrim, evNull)} + """ + case DecimalType() => + (c, evPrim, evNull) => + s""" + org.apache.spark.sql.types.Decimal tmpDecimal = $c.clone(); + ${changePrecision("tmpDecimal", target, evPrim, evNull)} + """ + case LongType => + (c, evPrim, evNull) => + s""" + org.apache.spark.sql.types.Decimal tmpDecimal = + new org.apache.spark.sql.types.Decimal().set($c); + ${changePrecision("tmpDecimal", target, evPrim, evNull)} + """ + case x: NumericType => + // All other numeric types can be represented precisely as Doubles + (c, evPrim, evNull) => + s""" + try { + org.apache.spark.sql.types.Decimal tmpDecimal = + new org.apache.spark.sql.types.Decimal().set( + scala.math.BigDecimal.valueOf((double) $c)); + ${changePrecision("tmpDecimal", target, evPrim, evNull)} + } catch (java.lang.NumberFormatException e) { + $evNull = true; + } + """ + } + } - case (DateType, StringType) => - defineCodeGen(ctx, ev, c => - s"""UTF8String.fromString( - org.apache.spark.sql.catalyst.util.DateTimeUtils.dateToString($c))""") + private[this] def castToTimestampCode( + from: DataType, + ctx: CodeGenContext): CastFunction = from match { + case StringType => + val longOpt = ctx.freshName("longOpt") + (c, evPrim, evNull) => + s""" + scala.Option<Long> $longOpt = + org.apache.spark.sql.catalyst.util.DateTimeUtils.stringToTimestamp($c); + if ($longOpt.isDefined()) { + $evPrim = ((Long) $longOpt.get()).longValue(); + } else { + $evNull = true; + } + """ + case BooleanType => + (c, evPrim, evNull) => s"$evPrim = $c ? 1L : 0;" + case _: IntegralType => + (c, evPrim, evNull) => s"$evPrim = ${longToTimeStampCode(c)};" + case DateType => + (c, evPrim, evNull) => + s"$evPrim = org.apache.spark.sql.catalyst.util.DateTimeUtils.daysToMillis($c) * 1000;" + case DecimalType() => + (c, evPrim, evNull) => s"$evPrim = ${decimalToTimestampCode(c)};" + case DoubleType => + (c, evPrim, evNull) => + s""" + if (Double.isNaN($c) || Double.isInfinite($c)) { + $evNull = true; + } else { + $evPrim = (long)($c * 1000000L); + } + """ + case FloatType => + (c, evPrim, evNull) => + s""" + if (Float.isNaN($c) || Float.isInfinite($c)) { + $evNull = true; + } else { + $evPrim = (long)($c * 1000000L); + } + """ + } - case (TimestampType, StringType) => - defineCodeGen(ctx, ev, c => - s"""UTF8String.fromString( - org.apache.spark.sql.catalyst.util.DateTimeUtils.timestampToString($c))""") + private[this] def castToIntervalCode(from: DataType): CastFunction = from match { + case StringType => + (c, evPrim, evNull) => + s"$evPrim = org.apache.spark.unsafe.types.Interval.fromString($c.toString());" + } + + private[this] def decimalToTimestampCode(d: String): String = + s"($d.toBigDecimal().bigDecimal().multiply(new java.math.BigDecimal(1000000L))).longValue()" + private[this] def longToTimeStampCode(l: String): String = s"$l * 1000L" + private[this] def timestampToIntegerCode(ts: String): String = + s"java.lang.Math.floor((double) $ts / 1000000L)" + private[this] def timestampToDoubleCode(ts: String): String = s"$ts / 1000000.0" + + private[this] def castToBooleanCode(from: DataType): CastFunction = from match { + case StringType => + (c, evPrim, evNull) => s"$evPrim = $c.numBytes() != 0;" + case TimestampType => + (c, evPrim, evNull) => s"$evPrim = $c != 0;" + case DateType => + // Hive would return null when cast from date to boolean + (c, evPrim, evNull) => s"$evNull = true;" + case DecimalType() => + (c, evPrim, evNull) => s"$evPrim = !$c.isZero();" + case n: NumericType => + (c, evPrim, evNull) => s"$evPrim = $c != 0;" + } + + private[this] def castToByteCode(from: DataType): CastFunction = from match { + case StringType => + (c, evPrim, evNull) => + s""" + try { + $evPrim = Byte.valueOf($c.toString()); + } catch (java.lang.NumberFormatException e) { + $evNull = true; + } + """ + case BooleanType => + (c, evPrim, evNull) => s"$evPrim = $c ? 1 : 0;" + case DateType => + (c, evPrim, evNull) => s"$evNull = true;" + case TimestampType => + (c, evPrim, evNull) => s"$evPrim = (byte) ${timestampToIntegerCode(c)};" + case DecimalType() => + (c, evPrim, evNull) => s"$evPrim = $c.toByte();" + case x: NumericType => + (c, evPrim, evNull) => s"$evPrim = (byte) $c;" + } - case (_, StringType) => - defineCodeGen(ctx, ev, c => s"UTF8String.fromString(String.valueOf($c))") + private[this] def castToShortCode(from: DataType): CastFunction = from match { + case StringType => + (c, evPrim, evNull) => + s""" + try { + $evPrim = Short.valueOf($c.toString()); + } catch (java.lang.NumberFormatException e) { + $evNull = true; + } + """ + case BooleanType => + (c, evPrim, evNull) => s"$evPrim = $c ? 1 : 0;" + case DateType => + (c, evPrim, evNull) => s"$evNull = true;" + case TimestampType => + (c, evPrim, evNull) => s"$evPrim = (short) ${timestampToIntegerCode(c)};" + case DecimalType() => + (c, evPrim, evNull) => s"$evPrim = $c.toShort();" + case x: NumericType => + (c, evPrim, evNull) => s"$evPrim = (short) $c;" + } - case (StringType, IntervalType) => - defineCodeGen(ctx, ev, c => - s"org.apache.spark.unsafe.types.Interval.fromString($c.toString())") + private[this] def castToIntCode(from: DataType): CastFunction = from match { + case StringType => + (c, evPrim, evNull) => + s""" + try { + $evPrim = Integer.valueOf($c.toString()); + } catch (java.lang.NumberFormatException e) { + $evNull = true; + } + """ + case BooleanType => + (c, evPrim, evNull) => s"$evPrim = $c ? 1 : 0;" + case DateType => + (c, evPrim, evNull) => s"$evNull = true;" + case TimestampType => + (c, evPrim, evNull) => s"$evPrim = (int) ${timestampToIntegerCode(c)};" + case DecimalType() => + (c, evPrim, evNull) => s"$evPrim = $c.toInt();" + case x: NumericType => + (c, evPrim, evNull) => s"$evPrim = (int) $c;" + } - // fallback for DecimalType, this must be before other numeric types - case (_, dt: DecimalType) => - super.genCode(ctx, ev) + private[this] def castToLongCode(from: DataType): CastFunction = from match { + case StringType => + (c, evPrim, evNull) => + s""" + try { + $evPrim = Long.valueOf($c.toString()); + } catch (java.lang.NumberFormatException e) { + $evNull = true; + } + """ + case BooleanType => + (c, evPrim, evNull) => s"$evPrim = $c ? 1 : 0;" + case DateType => + (c, evPrim, evNull) => s"$evNull = true;" + case TimestampType => + (c, evPrim, evNull) => s"$evPrim = (long) ${timestampToIntegerCode(c)};" + case DecimalType() => + (c, evPrim, evNull) => s"$evPrim = $c.toLong();" + case x: NumericType => + (c, evPrim, evNull) => s"$evPrim = (long) $c;" + } - case (BooleanType, dt: NumericType) => - defineCodeGen(ctx, ev, c => s"(${ctx.javaType(dt)})($c ? 1 : 0)") + private[this] def castToFloatCode(from: DataType): CastFunction = from match { + case StringType => + (c, evPrim, evNull) => + s""" + try { + $evPrim = Float.valueOf($c.toString()); + } catch (java.lang.NumberFormatException e) { + $evNull = true; + } + """ + case BooleanType => + (c, evPrim, evNull) => s"$evPrim = $c ? 1 : 0;" + case DateType => + (c, evPrim, evNull) => s"$evNull = true;" + case TimestampType => + (c, evPrim, evNull) => s"$evPrim = (float) (${timestampToDoubleCode(c)});" + case DecimalType() => + (c, evPrim, evNull) => s"$evPrim = $c.toFloat();" + case x: NumericType => + (c, evPrim, evNull) => s"$evPrim = (float) $c;" + } - case (dt: DecimalType, BooleanType) => - defineCodeGen(ctx, ev, c => s"!$c.isZero()") + private[this] def castToDoubleCode(from: DataType): CastFunction = from match { + case StringType => + (c, evPrim, evNull) => + s""" + try { + $evPrim = Double.valueOf($c.toString()); + } catch (java.lang.NumberFormatException e) { + $evNull = true; + } + """ + case BooleanType => + (c, evPrim, evNull) => s"$evPrim = $c ? 1 : 0;" + case DateType => + (c, evPrim, evNull) => s"$evNull = true;" + case TimestampType => + (c, evPrim, evNull) => s"$evPrim = ${timestampToDoubleCode(c)};" + case DecimalType() => + (c, evPrim, evNull) => s"$evPrim = $c.toDouble();" + case x: NumericType => + (c, evPrim, evNull) => s"$evPrim = (double) $c;" + } - case (dt: NumericType, BooleanType) => - defineCodeGen(ctx, ev, c => s"$c != 0") + private[this] def castArrayCode( + from: ArrayType, to: ArrayType, ctx: CodeGenContext): CastFunction = { + val elementCast = nullSafeCastFunction(from.elementType, to.elementType, ctx) + + val arraySeqClass = classOf[mutable.ArraySeq[Any]].getName + val fromElementNull = ctx.freshName("feNull") + val fromElementPrim = ctx.freshName("fePrim") + val toElementNull = ctx.freshName("teNull") + val toElementPrim = ctx.freshName("tePrim") + val size = ctx.freshName("n") + val j = ctx.freshName("j") + val result = ctx.freshName("result") + + (c, evPrim, evNull) => + s""" + final int $size = $c.size(); + final $arraySeqClass<Object> $result = new $arraySeqClass<Object>($size); + for (int $j = 0; $j < $size; $j ++) { + if ($c.apply($j) == null) { + $result.update($j, null); + } else { + boolean $fromElementNull = false; + ${ctx.javaType(from.elementType)} $fromElementPrim = + (${ctx.boxedType(from.elementType)}) $c.apply($j); + ${castCode(ctx, fromElementPrim, + fromElementNull, toElementPrim, toElementNull, to.elementType, elementCast)} + if ($toElementNull) { + $result.update($j, null); + } else { + $result.update($j, $toElementPrim); + } + } + } + $evPrim = $result; + """ + } - case (_: DecimalType, dt: NumericType) => - defineCodeGen(ctx, ev, c => s"($c).to${ctx.primitiveTypeName(dt)}()") + private[this] def castMapCode(from: MapType, to: MapType, ctx: CodeGenContext): CastFunction = { + val keyCast = nullSafeCastFunction(from.keyType, to.keyType, ctx) + val valueCast = nullSafeCastFunction(from.valueType, to.valueType, ctx) + + val hashMapClass = classOf[mutable.HashMap[Any, Any]].getName + val fromKeyPrim = ctx.freshName("fkp") + val fromKeyNull = ctx.freshName("fkn") + val fromValuePrim = ctx.freshName("fvp") + val fromValueNull = ctx.freshName("fvn") + val toKeyPrim = ctx.freshName("tkp") + val toKeyNull = ctx.freshName("tkn") + val toValuePrim = ctx.freshName("tvp") + val toValueNull = ctx.freshName("tvn") + val result = ctx.freshName("result") + + (c, evPrim, evNull) => + s""" + final $hashMapClass $result = new $hashMapClass(); + scala.collection.Iterator iter = $c.iterator(); + while (iter.hasNext()) { + scala.Tuple2 kv = (scala.Tuple2) iter.next(); + boolean $fromKeyNull = false; + ${ctx.javaType(from.keyType)} $fromKeyPrim = + (${ctx.boxedType(from.keyType)}) kv._1(); + ${castCode(ctx, fromKeyPrim, + fromKeyNull, toKeyPrim, toKeyNull, to.keyType, keyCast)} + + boolean $fromValueNull = kv._2() == null; + if ($fromValueNull) { + $result.put($toKeyPrim, null); + } else { + ${ctx.javaType(from.valueType)} $fromValuePrim = + (${ctx.boxedType(from.valueType)}) kv._2(); + ${castCode(ctx, fromValuePrim, + fromValueNull, toValuePrim, toValueNull, to.valueType, valueCast)} + if ($toValueNull) { + $result.put($toKeyPrim, null); + } else { + $result.put($toKeyPrim, $toValuePrim); + } + } + } + $evPrim = $result; + """ + } - case (_: NumericType, dt: NumericType) => - defineCodeGen(ctx, ev, c => s"(${ctx.javaType(dt)})($c)") + private[this] def castStructCode( + from: StructType, to: StructType, ctx: CodeGenContext): CastFunction = { - case other => - super.genCode(ctx, ev) + val fieldsCasts = from.fields.zip(to.fields).map { + case (fromField, toField) => nullSafeCastFunction(fromField.dataType, toField.dataType, ctx) } + val rowClass = classOf[GenericMutableRow].getName + val result = ctx.freshName("result") + val tmpRow = ctx.freshName("tmpRow") + + val fieldsEvalCode = fieldsCasts.zipWithIndex.map { case (cast, i) => { + val fromFieldPrim = ctx.freshName("ffp") + val fromFieldNull = ctx.freshName("ffn") + val toFieldPrim = ctx.freshName("tfp") + val toFieldNull = ctx.freshName("tfn") + val fromType = ctx.javaType(from.fields(i).dataType) + s""" + boolean $fromFieldNull = $tmpRow.isNullAt($i); + if ($fromFieldNull) { + $result.setNullAt($i); + } else { + $fromType $fromFieldPrim = + ${ctx.getColumn(tmpRow, from.fields(i).dataType, i)}; + ${castCode(ctx, fromFieldPrim, + fromFieldNull, toFieldPrim, toFieldNull, to.fields(i).dataType, cast)} + if ($toFieldNull) { + $result.setNullAt($i); + } else { + ${ctx.setColumn(result, to.fields(i).dataType, i, toFieldPrim)}; + } + } + """ + } + }.mkString("\n") + + (c, evPrim, evNull) => + s""" + final $rowClass $result = new $rowClass(${fieldsCasts.size}); + final InternalRow $tmpRow = $c; + $fieldsEvalCode + $evPrim = $result.copy(); + """ } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala index f724bab4d8..bdba6ce891 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala @@ -39,7 +39,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { val c = Calendar.getInstance() c.set(y, m, 28, 0, 0, 0) c.add(Calendar.DATE, i) - checkEvaluation(DayOfYear(Cast(Literal(new Date(c.getTimeInMillis)), DateType)), + checkEvaluation(DayOfYear(Literal(new Date(c.getTimeInMillis))), sdfDay.format(c.getTime).toInt) } } @@ -51,7 +51,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { val c = Calendar.getInstance() c.set(y, m, 28, 0, 0, 0) c.add(Calendar.DATE, i) - checkEvaluation(DayOfYear(Cast(Literal(new Date(c.getTimeInMillis)), DateType)), + checkEvaluation(DayOfYear(Literal(new Date(c.getTimeInMillis))), sdfDay.format(c.getTime).toInt) } } @@ -63,7 +63,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { val c = Calendar.getInstance() c.set(y, m, 28, 0, 0, 0) c.add(Calendar.DATE, i) - checkEvaluation(DayOfYear(Cast(Literal(new Date(c.getTimeInMillis)), DateType)), + checkEvaluation(DayOfYear(Literal(new Date(c.getTimeInMillis))), sdfDay.format(c.getTime).toInt) } } @@ -75,7 +75,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { val c = Calendar.getInstance() c.set(y, m, 28, 0, 0, 0) c.add(Calendar.DATE, i) - checkEvaluation(DayOfYear(Cast(Literal(new Date(c.getTimeInMillis)), DateType)), + checkEvaluation(DayOfYear(Literal(new Date(c.getTimeInMillis))), sdfDay.format(c.getTime).toInt) } } @@ -87,7 +87,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { val c = Calendar.getInstance() c.set(y, m, 28, 0, 0, 0) c.add(Calendar.DATE, i) - checkEvaluation(DayOfYear(Cast(Literal(new Date(c.getTimeInMillis)), DateType)), + checkEvaluation(DayOfYear(Literal(new Date(c.getTimeInMillis))), sdfDay.format(c.getTime).toInt) } } @@ -96,7 +96,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { test("Year") { checkEvaluation(Year(Literal.create(null, DateType)), null) - checkEvaluation(Year(Cast(Literal(d), DateType)), 2015) + checkEvaluation(Year(Literal(d)), 2015) checkEvaluation(Year(Cast(Literal(sdfDate.format(d)), DateType)), 2015) checkEvaluation(Year(Cast(Literal(ts), DateType)), 2013) @@ -106,7 +106,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { c.set(y, m, 28) (0 to 5 * 24).foreach { i => c.add(Calendar.HOUR_OF_DAY, 1) - checkEvaluation(Year(Cast(Literal(new Date(c.getTimeInMillis)), DateType)), + checkEvaluation(Year(Literal(new Date(c.getTimeInMillis))), c.get(Calendar.YEAR)) } } @@ -115,7 +115,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { test("Quarter") { checkEvaluation(Quarter(Literal.create(null, DateType)), null) - checkEvaluation(Quarter(Cast(Literal(d), DateType)), 2) + checkEvaluation(Quarter(Literal(d)), 2) checkEvaluation(Quarter(Cast(Literal(sdfDate.format(d)), DateType)), 2) checkEvaluation(Quarter(Cast(Literal(ts), DateType)), 4) @@ -125,7 +125,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { c.set(y, m, 28, 0, 0, 0) (0 to 5 * 24).foreach { i => c.add(Calendar.HOUR_OF_DAY, 1) - checkEvaluation(Quarter(Cast(Literal(new Date(c.getTimeInMillis)), DateType)), + checkEvaluation(Quarter(Literal(new Date(c.getTimeInMillis))), c.get(Calendar.MONTH) / 3 + 1) } } @@ -134,7 +134,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { test("Month") { checkEvaluation(Month(Literal.create(null, DateType)), null) - checkEvaluation(Month(Cast(Literal(d), DateType)), 4) + checkEvaluation(Month(Literal(d)), 4) checkEvaluation(Month(Cast(Literal(sdfDate.format(d)), DateType)), 4) checkEvaluation(Month(Cast(Literal(ts), DateType)), 11) @@ -144,7 +144,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { val c = Calendar.getInstance() c.set(y, m, 28, 0, 0, 0) c.add(Calendar.HOUR_OF_DAY, i) - checkEvaluation(Month(Cast(Literal(new Date(c.getTimeInMillis)), DateType)), + checkEvaluation(Month(Literal(new Date(c.getTimeInMillis))), c.get(Calendar.MONTH) + 1) } } @@ -156,7 +156,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { val c = Calendar.getInstance() c.set(y, m, 28, 0, 0, 0) c.add(Calendar.HOUR_OF_DAY, i) - checkEvaluation(Month(Cast(Literal(new Date(c.getTimeInMillis)), DateType)), + checkEvaluation(Month(Literal(new Date(c.getTimeInMillis))), c.get(Calendar.MONTH) + 1) } } @@ -166,7 +166,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { test("Day / DayOfMonth") { checkEvaluation(DayOfMonth(Cast(Literal("2000-02-29"), DateType)), 29) checkEvaluation(DayOfMonth(Literal.create(null, DateType)), null) - checkEvaluation(DayOfMonth(Cast(Literal(d), DateType)), 8) + checkEvaluation(DayOfMonth(Literal(d)), 8) checkEvaluation(DayOfMonth(Cast(Literal(sdfDate.format(d)), DateType)), 8) checkEvaluation(DayOfMonth(Cast(Literal(ts), DateType)), 8) @@ -175,7 +175,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { c.set(y, 0, 1, 0, 0, 0) (0 to 365).foreach { d => c.add(Calendar.DATE, 1) - checkEvaluation(DayOfMonth(Cast(Literal(new Date(c.getTimeInMillis)), DateType)), + checkEvaluation(DayOfMonth(Literal(new Date(c.getTimeInMillis))), c.get(Calendar.DAY_OF_MONTH)) } } @@ -190,14 +190,14 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { val c = Calendar.getInstance() (0 to 60 by 5).foreach { s => c.set(2015, 18, 3, 3, 5, s) - checkEvaluation(Second(Cast(Literal(new Timestamp(c.getTimeInMillis)), TimestampType)), + checkEvaluation(Second(Literal(new Timestamp(c.getTimeInMillis))), c.get(Calendar.SECOND)) } } test("WeekOfYear") { checkEvaluation(WeekOfYear(Literal.create(null, DateType)), null) - checkEvaluation(WeekOfYear(Cast(Literal(d), DateType)), 15) + checkEvaluation(WeekOfYear(Literal(d)), 15) checkEvaluation(WeekOfYear(Cast(Literal(sdfDate.format(d)), DateType)), 15) checkEvaluation(WeekOfYear(Cast(Literal(ts), DateType)), 45) checkEvaluation(WeekOfYear(Cast(Literal("2011-05-06"), DateType)), 18) @@ -223,7 +223,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { (0 to 60 by 15).foreach { m => (0 to 60 by 15).foreach { s => c.set(2015, 18, 3, h, m, s) - checkEvaluation(Hour(Cast(Literal(new Timestamp(c.getTimeInMillis)), TimestampType)), + checkEvaluation(Hour(Literal(new Timestamp(c.getTimeInMillis))), c.get(Calendar.HOUR_OF_DAY)) } } @@ -240,7 +240,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { (0 to 60 by 5).foreach { m => (0 to 60 by 15).foreach { s => c.set(2015, 18, 3, 3, m, s) - checkEvaluation(Minute(Cast(Literal(new Timestamp(c.getTimeInMillis)), TimestampType)), + checkEvaluation(Minute(Literal(new Timestamp(c.getTimeInMillis))), c.get(Calendar.MINUTE)) } } |