diff options
author | Cheng Hao <hao.cheng@intel.com> | 2014-04-03 15:33:17 -0700 |
---|---|---|
committer | Reynold Xin <rxin@apache.org> | 2014-04-03 15:33:17 -0700 |
commit | 5d1feda217d25616d190f9bb369664e57417cd45 (patch) | |
tree | 35c79acc30995b8683843dc1aa514f3657726371 /sql/catalyst | |
parent | fbebaedf26286ee8a75065822a3af1148351f828 (diff) | |
download | spark-5d1feda217d25616d190f9bb369664e57417cd45.tar.gz spark-5d1feda217d25616d190f9bb369664e57417cd45.tar.bz2 spark-5d1feda217d25616d190f9bb369664e57417cd45.zip |
[SPARK-1360] Add Timestamp Support for SQL
This PR includes:
1) Add new data type Timestamp
2) Add more data type casting base on Hive's Rule
3) Fix bug missing data type in both parsers (HiveQl & SQLParser).
Author: Cheng Hao <hao.cheng@intel.com>
Closes #275 from chenghao-intel/timestamp and squashes the following commits:
df709e5 [Cheng Hao] Move orc_ends_with_nulls to blacklist
24b04b0 [Cheng Hao] Put 3 cases into the black lists(describe_pretty,describe_syntax,lateral_view_outer)
fc512c2 [Cheng Hao] remove the unnecessary data type equality check in data casting
d0d1919 [Cheng Hao] Add more data type for scala reflection
3259808 [Cheng Hao] Add the new Golden files
3823b97 [Cheng Hao] Update the UnitTest cases & add timestamp type for HiveQL
54a0489 [Cheng Hao] fix bug mapping to 0 (which is supposed to be null) when NumberFormatException occurs
9cb505c [Cheng Hao] Fix issues according to PR comments
e529168 [Cheng Hao] Fix bug of converting from String
6fc8100 [Cheng Hao] Update Unit Test & CodeStyle
8a1d4d6 [Cheng Hao] Add DataType for SqlParser
ce4385e [Cheng Hao] Add TimestampType Support
Diffstat (limited to 'sql/catalyst')
8 files changed, 329 insertions, 95 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 5aaa63bf3b..446d0e0bd7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst +import java.sql.Timestamp + import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.expressions.AttributeReference import org.apache.spark.sql.catalyst.plans.logical.LocalRelation @@ -54,14 +56,15 @@ object ScalaReflection { val TypeRef(_, _, Seq(keyType, valueType)) = t MapType(schemaFor(keyType), schemaFor(valueType)) case t if t <:< typeOf[String] => StringType + case t if t <:< typeOf[Timestamp] => TimestampType + case t if t <:< typeOf[BigDecimal] => DecimalType case t if t <:< definitions.IntTpe => IntegerType case t if t <:< definitions.LongTpe => LongType - case t if t <:< definitions.FloatTpe => FloatType case t if t <:< definitions.DoubleTpe => DoubleType + case t if t <:< definitions.FloatTpe => FloatType case t if t <:< definitions.ShortTpe => ShortType case t if t <:< definitions.ByteTpe => ByteType case t if t <:< definitions.BooleanTpe => BooleanType - case t if t <:< typeOf[BigDecimal] => DecimalType } implicit class CaseClassRelation[A <: Product : TypeTag](data: Seq[A]) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index 44abe671c0..2c4bf1715b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst +import java.sql.Timestamp + import scala.language.implicitConversions import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute @@ -72,6 +74,7 @@ package object dsl { def like(other: Expression) = Like(expr, other) def rlike(other: Expression) = RLike(expr, other) + def cast(to: DataType) = Cast(expr, to) def asc = SortOrder(expr, Ascending) def desc = SortOrder(expr, Descending) @@ -84,15 +87,22 @@ package object dsl { def expr = e } + implicit def booleanToLiteral(b: Boolean) = Literal(b) + implicit def byteToLiteral(b: Byte) = Literal(b) + implicit def shortToLiteral(s: Short) = Literal(s) implicit def intToLiteral(i: Int) = Literal(i) implicit def longToLiteral(l: Long) = Literal(l) implicit def floatToLiteral(f: Float) = Literal(f) implicit def doubleToLiteral(d: Double) = Literal(d) implicit def stringToLiteral(s: String) = Literal(s) + implicit def decimalToLiteral(d: BigDecimal) = Literal(d) + implicit def timestampToLiteral(t: Timestamp) = Literal(t) + implicit def binaryToLiteral(a: Array[Byte]) = Literal(a) implicit def symbolToUnresolvedAttribute(s: Symbol) = analysis.UnresolvedAttribute(s.name) implicit class DslSymbol(sym: Symbol) extends ImplicitAttribute { def s = sym.name } + // TODO more implicit class for literal? implicit class DslString(val s: String) extends ImplicitOperators { def expr: Expression = Literal(s) def attr = analysis.UnresolvedAttribute(s) @@ -103,11 +113,38 @@ package object dsl { def expr = attr def attr = analysis.UnresolvedAttribute(s) - /** Creates a new typed attributes of type int */ + /** Creates a new AttributeReference of type boolean */ + def boolean = AttributeReference(s, BooleanType, nullable = false)() + + /** Creates a new AttributeReference of type byte */ + def byte = AttributeReference(s, ByteType, nullable = false)() + + /** Creates a new AttributeReference of type short */ + def short = AttributeReference(s, ShortType, nullable = false)() + + /** Creates a new AttributeReference of type int */ def int = AttributeReference(s, IntegerType, nullable = false)() - /** Creates a new typed attributes of type string */ + /** Creates a new AttributeReference of type long */ + def long = AttributeReference(s, LongType, nullable = false)() + + /** Creates a new AttributeReference of type float */ + def float = AttributeReference(s, FloatType, nullable = false)() + + /** Creates a new AttributeReference of type double */ + def double = AttributeReference(s, DoubleType, nullable = false)() + + /** Creates a new AttributeReference of type string */ def string = AttributeReference(s, StringType, nullable = false)() + + /** Creates a new AttributeReference of type decimal */ + def decimal = AttributeReference(s, DecimalType, nullable = false)() + + /** Creates a new AttributeReference of type timestamp */ + def timestamp = AttributeReference(s, TimestampType, nullable = false)() + + /** Creates a new AttributeReference of type binary */ + def binary = AttributeReference(s, BinaryType, nullable = false)() } implicit class DslAttribute(a: AttributeReference) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index c26fc3d0f3..941b53fe70 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.expressions +import java.sql.Timestamp + import org.apache.spark.sql.catalyst.types._ /** Cast the child expression to the target data type. */ @@ -26,52 +28,169 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression { override def toString = s"CAST($child, $dataType)" type EvaluatedType = Any + + def nullOrCast[T](a: Any, func: T => Any): Any = if(a == null) { + null + } else { + func(a.asInstanceOf[T]) + } - lazy val castingFunction: Any => Any = (child.dataType, dataType) match { - case (BinaryType, StringType) => a: Any => new String(a.asInstanceOf[Array[Byte]]) - case (StringType, BinaryType) => a: Any => a.asInstanceOf[String].getBytes - case (_, StringType) => a: Any => a.toString - case (StringType, IntegerType) => a: Any => castOrNull(a, _.toInt) - case (StringType, DoubleType) => a: Any => castOrNull(a, _.toDouble) - case (StringType, FloatType) => a: Any => castOrNull(a, _.toFloat) - case (StringType, LongType) => a: Any => castOrNull(a, _.toLong) - case (StringType, ShortType) => a: Any => castOrNull(a, _.toShort) - case (StringType, ByteType) => a: Any => castOrNull(a, _.toByte) - case (StringType, DecimalType) => a: Any => castOrNull(a, BigDecimal(_)) - case (BooleanType, ByteType) => { - case null => null - case true => 1.toByte - case false => 0.toByte - } - case (dt, IntegerType) => - a: Any => dt.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]].toInt(a) - case (dt, DoubleType) => - a: Any => dt.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]].toDouble(a) - case (dt, FloatType) => - a: Any => dt.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]].toFloat(a) - case (dt, LongType) => - a: Any => dt.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]].toLong(a) - case (dt, ShortType) => - a: Any => dt.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]].toInt(a).toShort - case (dt, ByteType) => - a: Any => dt.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]].toInt(a).toByte - case (dt, DecimalType) => - a: Any => - BigDecimal(dt.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]].toDouble(a)) + // UDFToString + def castToString: Any => Any = child.dataType match { + case BinaryType => nullOrCast[Array[Byte]](_, new String(_, "UTF-8")) + case _ => nullOrCast[Any](_, _.toString) + } + + // BinaryConverter + def castToBinary: Any => Any = child.dataType match { + case StringType => nullOrCast[String](_, _.getBytes("UTF-8")) } - @inline - protected def castOrNull[A](a: Any, f: String => A) = - try f(a.asInstanceOf[String]) catch { - case _: java.lang.NumberFormatException => null - } + // UDFToBoolean + def castToBoolean: Any => Any = child.dataType match { + case StringType => nullOrCast[String](_, _.length() != 0) + case TimestampType => nullOrCast[Timestamp](_, b => {(b.getTime() != 0 || b.getNanos() != 0)}) + case LongType => nullOrCast[Long](_, _ != 0) + case IntegerType => nullOrCast[Int](_, _ != 0) + case ShortType => nullOrCast[Short](_, _ != 0) + case ByteType => nullOrCast[Byte](_, _ != 0) + case DecimalType => nullOrCast[BigDecimal](_, _ != 0) + case DoubleType => nullOrCast[Double](_, _ != 0) + case FloatType => nullOrCast[Float](_, _ != 0) + } + + // TimestampConverter + def castToTimestamp: Any => Any = child.dataType match { + case StringType => nullOrCast[String](_, s => { + // Throw away extra if more than 9 decimal places + val periodIdx = s.indexOf("."); + var n = s + if (periodIdx != -1) { + if (n.length() - periodIdx > 9) { + n = n.substring(0, periodIdx + 10) + } + } + try Timestamp.valueOf(n) catch { case _: java.lang.IllegalArgumentException => null} + }) + case BooleanType => nullOrCast[Boolean](_, b => new Timestamp((if(b) 1 else 0) * 1000)) + case LongType => nullOrCast[Long](_, l => new Timestamp(l * 1000)) + case IntegerType => nullOrCast[Int](_, i => new Timestamp(i * 1000)) + case ShortType => nullOrCast[Short](_, s => new Timestamp(s * 1000)) + case ByteType => nullOrCast[Byte](_, b => new Timestamp(b * 1000)) + // TimestampWritable.decimalToTimestamp + case DecimalType => nullOrCast[BigDecimal](_, d => decimalToTimestamp(d)) + // TimestampWritable.doubleToTimestamp + case DoubleType => nullOrCast[Double](_, d => decimalToTimestamp(d)) + // TimestampWritable.floatToTimestamp + case FloatType => nullOrCast[Float](_, f => decimalToTimestamp(f)) + } + + private def decimalToTimestamp(d: BigDecimal) = { + val seconds = d.longValue() + val bd = (d - seconds) * (1000000000) + val nanos = bd.intValue() + + // Convert to millis + val millis = seconds * 1000 + val t = new Timestamp(millis) + + // remaining fractional portion as nanos + t.setNanos(nanos) + + t + } + + private def timestampToDouble(t: Timestamp) = (t.getSeconds() + t.getNanos().toDouble / 1000) + + def castToLong: Any => Any = child.dataType match { + case StringType => nullOrCast[String](_, s => try s.toLong catch { + case _: NumberFormatException => null + }) + case BooleanType => nullOrCast[Boolean](_, b => if(b) 1 else 0) + case TimestampType => nullOrCast[Timestamp](_, t => timestampToDouble(t).toLong) + case DecimalType => nullOrCast[BigDecimal](_, _.toLong) + case x: NumericType => b => x.numeric.asInstanceOf[Numeric[Any]].toLong(b) + } + + def castToInt: Any => Any = child.dataType match { + case StringType => nullOrCast[String](_, s => try s.toInt catch { + case _: NumberFormatException => null + }) + case BooleanType => nullOrCast[Boolean](_, b => if(b) 1 else 0) + case TimestampType => nullOrCast[Timestamp](_, t => timestampToDouble(t).toInt) + case DecimalType => nullOrCast[BigDecimal](_, _.toInt) + case x: NumericType => b => x.numeric.asInstanceOf[Numeric[Any]].toInt(b) + } + + def castToShort: Any => Any = child.dataType match { + case StringType => nullOrCast[String](_, s => try s.toShort catch { + case _: NumberFormatException => null + }) + case BooleanType => nullOrCast[Boolean](_, b => if(b) 1 else 0) + case TimestampType => nullOrCast[Timestamp](_, t => timestampToDouble(t).toShort) + case DecimalType => nullOrCast[BigDecimal](_, _.toShort) + case x: NumericType => b => x.numeric.asInstanceOf[Numeric[Any]].toInt(b).toShort + } + + def castToByte: Any => Any = child.dataType match { + case StringType => nullOrCast[String](_, s => try s.toByte catch { + case _: NumberFormatException => null + }) + case BooleanType => nullOrCast[Boolean](_, b => if(b) 1 else 0) + case TimestampType => nullOrCast[Timestamp](_, t => timestampToDouble(t).toByte) + case DecimalType => nullOrCast[BigDecimal](_, _.toByte) + case x: NumericType => b => x.numeric.asInstanceOf[Numeric[Any]].toInt(b).toByte + } + + def castToDecimal: Any => Any = child.dataType match { + case StringType => nullOrCast[String](_, s => try BigDecimal(s.toDouble) catch { + case _: NumberFormatException => null + }) + case BooleanType => nullOrCast[Boolean](_, b => if(b) BigDecimal(1) else BigDecimal(0)) + case TimestampType => nullOrCast[Timestamp](_, t => BigDecimal(timestampToDouble(t))) + case x: NumericType => b => BigDecimal(x.numeric.asInstanceOf[Numeric[Any]].toDouble(b)) + } + + def castToDouble: Any => Any = child.dataType match { + case StringType => nullOrCast[String](_, s => try s.toDouble catch { + case _: NumberFormatException => null + }) + case BooleanType => nullOrCast[Boolean](_, b => if(b) 1 else 0) + case TimestampType => nullOrCast[Timestamp](_, t => timestampToDouble(t)) + case DecimalType => nullOrCast[BigDecimal](_, _.toDouble) + case x: NumericType => b => x.numeric.asInstanceOf[Numeric[Any]].toDouble(b) + } + + def castToFloat: Any => Any = child.dataType match { + case StringType => nullOrCast[String](_, s => try s.toFloat catch { + case _: NumberFormatException => null + }) + case BooleanType => nullOrCast[Boolean](_, b => if(b) 1 else 0) + case TimestampType => nullOrCast[Timestamp](_, t => timestampToDouble(t).toFloat) + case DecimalType => nullOrCast[BigDecimal](_, _.toFloat) + case x: NumericType => b => x.numeric.asInstanceOf[Numeric[Any]].toFloat(b) + } + + def cast: Any => Any = dataType match { + case StringType => castToString + case BinaryType => castToBinary + case DecimalType => castToDecimal + case TimestampType => castToTimestamp + case BooleanType => castToBoolean + case ByteType => castToByte + case ShortType => castToShort + case IntegerType => castToInt + case FloatType => castToFloat + case LongType => castToLong + case DoubleType => castToDouble + } override def apply(input: Row): Any = { val evaluated = child.apply(input) if (evaluated == null) { null } else { - castingFunction(evaluated) + cast(evaluated) } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 81fd160e00..a3d1952550 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.trees import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.catalyst.trees.TreeNode -import org.apache.spark.sql.catalyst.types.{DataType, FractionalType, IntegralType, NumericType} +import org.apache.spark.sql.catalyst.types.{DataType, FractionalType, IntegralType, NumericType, NativeType} abstract class Expression extends TreeNode[Expression] { self: Product => @@ -86,6 +86,11 @@ abstract class Expression extends TreeNode[Expression] { } } + /** + * Evaluation helper function for 2 Numeric children expressions. Those expressions are supposed + * to be in the same data type, and also the return type. + * Either one of the expressions result is null, the evaluation result should be null. + */ @inline protected final def n2( i: Row, @@ -115,6 +120,11 @@ abstract class Expression extends TreeNode[Expression] { } } + /** + * Evaluation helper function for 2 Fractional children expressions. Those expressions are + * supposed to be in the same data type, and also the return type. + * Either one of the expressions result is null, the evaluation result should be null. + */ @inline protected final def f2( i: Row, @@ -143,6 +153,11 @@ abstract class Expression extends TreeNode[Expression] { } } + /** + * Evaluation helper function for 2 Integral children expressions. Those expressions are + * supposed to be in the same data type, and also the return type. + * Either one of the expressions result is null, the evaluation result should be null. + */ @inline protected final def i2( i: Row, @@ -170,6 +185,43 @@ abstract class Expression extends TreeNode[Expression] { } } } + + /** + * Evaluation helper function for 2 Comparable children expressions. Those expressions are + * supposed to be in the same data type, and the return type should be Integer: + * Negative value: 1st argument less than 2nd argument + * Zero: 1st argument equals 2nd argument + * Positive value: 1st argument greater than 2nd argument + * + * Either one of the expressions result is null, the evaluation result should be null. + */ + @inline + protected final def c2( + i: Row, + e1: Expression, + e2: Expression, + f: ((Ordering[Any], Any, Any) => Any)): Any = { + if (e1.dataType != e2.dataType) { + throw new TreeNodeException(this, s"Types do not match ${e1.dataType} != ${e2.dataType}") + } + + val evalE1 = e1.apply(i) + if(evalE1 == null) { + null + } else { + val evalE2 = e2.apply(i) + if (evalE2 == null) { + null + } else { + e1.dataType match { + case i: NativeType => + f.asInstanceOf[(Ordering[i.JvmType], i.JvmType, i.JvmType) => Boolean]( + i.ordering, evalE1.asInstanceOf[i.JvmType], evalE2.asInstanceOf[i.JvmType]) + case other => sys.error(s"Type $other does not support ordered operations") + } + } + } + } } abstract class BinaryExpression extends Expression with trees.BinaryNode[Expression] { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala index b82a12e0f7..d879b2b5e8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.expressions +import java.sql.Timestamp + import org.apache.spark.sql.catalyst.types._ object Literal { @@ -29,6 +31,9 @@ object Literal { case s: Short => Literal(s, ShortType) case s: String => Literal(s, StringType) case b: Boolean => Literal(b, BooleanType) + case d: BigDecimal => Literal(d, DecimalType) + case t: Timestamp => Literal(t, TimestampType) + case a: Array[Byte] => Literal(a, BinaryType) case null => Literal(null, NullType) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 02fedd16b8..b74809e5ca 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -18,8 +18,9 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.trees +import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.catalyst.analysis.UnresolvedException -import org.apache.spark.sql.catalyst.types.{BooleanType, StringType} +import org.apache.spark.sql.catalyst.types.{BooleanType, StringType, TimestampType} object InterpretedPredicate { def apply(expression: Expression): (Row => Boolean) = { @@ -123,70 +124,22 @@ case class Equals(left: Expression, right: Expression) extends BinaryComparison case class LessThan(left: Expression, right: Expression) extends BinaryComparison { def symbol = "<" - override def apply(input: Row): Any = { - if (left.dataType == StringType && right.dataType == StringType) { - val l = left.apply(input) - val r = right.apply(input) - if(l == null || r == null) { - null - } else { - l.asInstanceOf[String] < r.asInstanceOf[String] - } - } else { - n2(input, left, right, _.lt(_, _)) - } - } + override def apply(input: Row): Any = c2(input, left, right, _.lt(_, _)) } case class LessThanOrEqual(left: Expression, right: Expression) extends BinaryComparison { def symbol = "<=" - override def apply(input: Row): Any = { - if (left.dataType == StringType && right.dataType == StringType) { - val l = left.apply(input) - val r = right.apply(input) - if(l == null || r == null) { - null - } else { - l.asInstanceOf[String] <= r.asInstanceOf[String] - } - } else { - n2(input, left, right, _.lteq(_, _)) - } - } + override def apply(input: Row): Any = c2(input, left, right, _.lteq(_, _)) } case class GreaterThan(left: Expression, right: Expression) extends BinaryComparison { def symbol = ">" - override def apply(input: Row): Any = { - if (left.dataType == StringType && right.dataType == StringType) { - val l = left.apply(input) - val r = right.apply(input) - if(l == null || r == null) { - null - } else { - l.asInstanceOf[String] > r.asInstanceOf[String] - } - } else { - n2(input, left, right, _.gt(_, _)) - } - } + override def apply(input: Row): Any = c2(input, left, right, _.gt(_, _)) } case class GreaterThanOrEqual(left: Expression, right: Expression) extends BinaryComparison { def symbol = ">=" - override def apply(input: Row): Any = { - if (left.dataType == StringType && right.dataType == StringType) { - val l = left.apply(input) - val r = right.apply(input) - if(l == null || r == null) { - null - } else { - l.asInstanceOf[String] >= r.asInstanceOf[String] - } - } else { - n2(input, left, right, _.gteq(_, _)) - } - } + override def apply(input: Row): Any = c2(input, left, right, _.gteq(_, _)) } case class If(predicate: Expression, trueValue: Expression, falseValue: Expression) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala index 7a45d1a1b8..cdeb01a965 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.types +import java.sql.Timestamp + import scala.reflect.runtime.universe.{typeTag, TypeTag} import org.apache.spark.sql.catalyst.expressions.Expression @@ -51,6 +53,16 @@ case object BooleanType extends NativeType { val ordering = implicitly[Ordering[JvmType]] } +case object TimestampType extends NativeType { + type JvmType = Timestamp + + @transient lazy val tag = typeTag[JvmType] + + val ordering = new Ordering[JvmType] { + def compare(x: Timestamp, y: Timestamp) = x.compareTo(y) + } +} + abstract class NumericType extends NativeType { // Unfortunately we can't get this implicitly as that breaks Spark Serialization. In order for // implicitly[Numeric[JvmType]] to be valid, we have to change JvmType from a type variable to a diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala index 52a205be3e..43876033d3 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.expressions +import java.sql.Timestamp + import org.scalatest.FunSuite import org.apache.spark.sql.catalyst.types._ @@ -191,5 +193,56 @@ class ExpressionEvaluationSuite extends FunSuite { evaluate("abbbbc" rlike regEx, new GenericRow(Array[Any]("**"))) } } + + test("data type casting") { + + val sts = "1970-01-01 00:00:01.0" + val ts = Timestamp.valueOf(sts) + + checkEvaluation("abdef" cast StringType, "abdef") + checkEvaluation("abdef" cast DecimalType, null) + checkEvaluation("abdef" cast TimestampType, null) + checkEvaluation("12.65" cast DecimalType, BigDecimal(12.65)) + + checkEvaluation(Literal(1) cast LongType, 1) + checkEvaluation(Cast(Literal(1) cast TimestampType, LongType), 1) + checkEvaluation(Cast(Literal(BigDecimal(1)) cast TimestampType, DecimalType), 1) + checkEvaluation(Cast(Literal(1.toDouble) cast TimestampType, DoubleType), 1.toDouble) + + checkEvaluation(Cast(Literal(sts) cast TimestampType, StringType), sts) + checkEvaluation(Cast(Literal(ts) cast StringType, TimestampType), ts) + + checkEvaluation(Cast("abdef" cast BinaryType, StringType), "abdef") + + checkEvaluation(Cast(Cast(Cast(Cast( + Cast("5" cast ByteType, ShortType), IntegerType), FloatType), DoubleType), LongType), 5) + checkEvaluation(Cast(Cast(Cast(Cast( + Cast("5" cast ByteType, TimestampType), DecimalType), LongType), StringType), ShortType), 5) + checkEvaluation(Cast(Cast(Cast(Cast( + Cast("5" cast TimestampType, ByteType), DecimalType), LongType), StringType), ShortType), null) + checkEvaluation(Cast(Cast(Cast(Cast( + Cast("5" cast DecimalType, ByteType), TimestampType), LongType), StringType), ShortType), 5) + checkEvaluation(Literal(true) cast IntegerType, 1) + checkEvaluation(Literal(false) cast IntegerType, 0) + checkEvaluation(Cast(Literal(1) cast BooleanType, IntegerType), 1) + checkEvaluation(Cast(Literal(0) cast BooleanType, IntegerType), 0) + checkEvaluation("23" cast DoubleType, 23) + checkEvaluation("23" cast IntegerType, 23) + checkEvaluation("23" cast FloatType, 23) + checkEvaluation("23" cast DecimalType, 23) + checkEvaluation("23" cast ByteType, 23) + checkEvaluation("23" cast ShortType, 23) + checkEvaluation("2012-12-11" cast DoubleType, null) + checkEvaluation(Literal(123) cast IntegerType, 123) + + intercept[Exception] {evaluate(Literal(1) cast BinaryType, null)} + } + + test("timestamp") { + val ts1 = new Timestamp(12) + val ts2 = new Timestamp(123) + checkEvaluation(Literal("ab") < Literal("abc"), true) + checkEvaluation(Literal(ts1) < Literal(ts2), true) + } } |