author    Cheng Hao <hao.cheng@intel.com>    2014-04-03 15:33:17 -0700
committer Reynold Xin <rxin@apache.org>      2014-04-03 15:33:17 -0700
commit    5d1feda217d25616d190f9bb369664e57417cd45 (patch)
tree      35c79acc30995b8683843dc1aa514f3657726371 /sql/catalyst
parent    fbebaedf26286ee8a75065822a3af1148351f828 (diff)
[SPARK-1360] Add Timestamp Support for SQL
This PR includes:
1) Add a new data type: Timestamp.
2) Add more data type casts, based on Hive's rules.
3) Fix a bug where the data type was missing in both parsers (HiveQl & SQLParser).

Author: Cheng Hao <hao.cheng@intel.com>

Closes #275 from chenghao-intel/timestamp and squashes the following commits:

df709e5 [Cheng Hao] Move orc_ends_with_nulls to blacklist
24b04b0 [Cheng Hao] Put 3 cases into the black lists(describe_pretty,describe_syntax,lateral_view_outer)
fc512c2 [Cheng Hao] remove the unnecessary data type equality check in data casting
d0d1919 [Cheng Hao] Add more data type for scala reflection
3259808 [Cheng Hao] Add the new Golden files
3823b97 [Cheng Hao] Update the UnitTest cases & add timestamp type for HiveQL
54a0489 [Cheng Hao] fix bug mapping to 0 (which is supposed to be null) when NumberFormatException occurs
9cb505c [Cheng Hao] Fix issues according to PR comments
e529168 [Cheng Hao] Fix bug of converting from String
6fc8100 [Cheng Hao] Update Unit Test & CodeStyle
8a1d4d6 [Cheng Hao] Add DataType for SqlParser
ce4385e [Cheng Hao] Add TimestampType Support
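As a quick illustration of the new behavior, a minimal sketch (not part of the patch; it assumes the Cast/Literal expression API shown in the diff below):

    import java.sql.Timestamp
    import org.apache.spark.sql.catalyst.expressions.{Cast, Literal}
    import org.apache.spark.sql.catalyst.types.{LongType, TimestampType}

    // Strings in Timestamp.valueOf format now cast to TimestampType...
    val ts = Cast(Literal("1970-01-01 00:00:01.0"), TimestampType)
    // ...and timestamps cast to numeric types as seconds since the epoch,
    // so evaluating this yields 1 (cf. the new tests at the end of this diff).
    val secs = Cast(ts, LongType)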
Diffstat (limited to 'sql/catalyst')
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala                          7
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala                             41
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala                       193
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala                  54
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala                     5
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala                  59
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala                         12
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala   53
8 files changed, 329 insertions, 95 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index 5aaa63bf3b..446d0e0bd7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -17,6 +17,8 @@
package org.apache.spark.sql.catalyst
+import java.sql.Timestamp
+
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
@@ -54,14 +56,15 @@ object ScalaReflection {
val TypeRef(_, _, Seq(keyType, valueType)) = t
MapType(schemaFor(keyType), schemaFor(valueType))
case t if t <:< typeOf[String] => StringType
+ case t if t <:< typeOf[Timestamp] => TimestampType
+ case t if t <:< typeOf[BigDecimal] => DecimalType
case t if t <:< definitions.IntTpe => IntegerType
case t if t <:< definitions.LongTpe => LongType
- case t if t <:< definitions.FloatTpe => FloatType
case t if t <:< definitions.DoubleTpe => DoubleType
+ case t if t <:< definitions.FloatTpe => FloatType
case t if t <:< definitions.ShortTpe => ShortType
case t if t <:< definitions.ByteTpe => ByteType
case t if t <:< definitions.BooleanTpe => BooleanType
- case t if t <:< typeOf[BigDecimal] => DecimalType
}
implicit class CaseClassRelation[A <: Product : TypeTag](data: Seq[A]) {
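With the two new reflection cases, a product type carrying java.sql.Timestamp or BigDecimal fields now maps cleanly to catalyst types; a small sketch (the Event case class is hypothetical):

    import java.sql.Timestamp

    // Hypothetical record: schemaFor would now resolve its fields to
    // IntegerType, TimestampType and DecimalType respectively.
    case class Event(id: Int, createdAt: Timestamp, amount: BigDecimal)

Note that the BigDecimal case also moved above the primitive cases; since only one of these cases can match a given type, the reordering is cosmetic.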
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
index 44abe671c0..2c4bf1715b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
@@ -17,6 +17,8 @@
package org.apache.spark.sql.catalyst
+import java.sql.Timestamp
+
import scala.language.implicitConversions
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
@@ -72,6 +74,7 @@ package object dsl {
def like(other: Expression) = Like(expr, other)
def rlike(other: Expression) = RLike(expr, other)
+ def cast(to: DataType) = Cast(expr, to)
def asc = SortOrder(expr, Ascending)
def desc = SortOrder(expr, Descending)
@@ -84,15 +87,22 @@ package object dsl {
def expr = e
}
+ implicit def booleanToLiteral(b: Boolean) = Literal(b)
+ implicit def byteToLiteral(b: Byte) = Literal(b)
+ implicit def shortToLiteral(s: Short) = Literal(s)
implicit def intToLiteral(i: Int) = Literal(i)
implicit def longToLiteral(l: Long) = Literal(l)
implicit def floatToLiteral(f: Float) = Literal(f)
implicit def doubleToLiteral(d: Double) = Literal(d)
implicit def stringToLiteral(s: String) = Literal(s)
+ implicit def decimalToLiteral(d: BigDecimal) = Literal(d)
+ implicit def timestampToLiteral(t: Timestamp) = Literal(t)
+ implicit def binaryToLiteral(a: Array[Byte]) = Literal(a)
implicit def symbolToUnresolvedAttribute(s: Symbol) = analysis.UnresolvedAttribute(s.name)
implicit class DslSymbol(sym: Symbol) extends ImplicitAttribute { def s = sym.name }
+ // TODO: more implicit classes for literals?
implicit class DslString(val s: String) extends ImplicitOperators {
def expr: Expression = Literal(s)
def attr = analysis.UnresolvedAttribute(s)
@@ -103,11 +113,38 @@ package object dsl {
def expr = attr
def attr = analysis.UnresolvedAttribute(s)
- /** Creates a new typed attributes of type int */
+ /** Creates a new AttributeReference of type boolean */
+ def boolean = AttributeReference(s, BooleanType, nullable = false)()
+
+ /** Creates a new AttributeReference of type byte */
+ def byte = AttributeReference(s, ByteType, nullable = false)()
+
+ /** Creates a new AttributeReference of type short */
+ def short = AttributeReference(s, ShortType, nullable = false)()
+
+ /** Creates a new AttributeReference of type int */
def int = AttributeReference(s, IntegerType, nullable = false)()
- /** Creates a new typed attributes of type string */
+ /** Creates a new AttributeReference of type long */
+ def long = AttributeReference(s, LongType, nullable = false)()
+
+ /** Creates a new AttributeReference of type float */
+ def float = AttributeReference(s, FloatType, nullable = false)()
+
+ /** Creates a new AttributeReference of type double */
+ def double = AttributeReference(s, DoubleType, nullable = false)()
+
+ /** Creates a new AttributeReference of type string */
def string = AttributeReference(s, StringType, nullable = false)()
+
+ /** Creates a new AttributeReference of type decimal */
+ def decimal = AttributeReference(s, DecimalType, nullable = false)()
+
+ /** Creates a new AttributeReference of type timestamp */
+ def timestamp = AttributeReference(s, TimestampType, nullable = false)()
+
+ /** Creates a new AttributeReference of type binary */
+ def binary = AttributeReference(s, BinaryType, nullable = false)()
}
implicit class DslAttribute(a: AttributeReference) {
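The DSL additions make typed attributes and the new literal types directly expressible; a sketch assuming the dsl package object's implicits are in scope (names illustrative):

    import java.sql.Timestamp
    import org.apache.spark.sql.catalyst.dsl._
    import org.apache.spark.sql.catalyst.expressions.Expression
    import org.apache.spark.sql.catalyst.types.LongType

    val eventTime = 'eventTime.timestamp      // AttributeReference of TimestampType
    val asSeconds = eventTime cast LongType   // the new cast operator
    val epoch: Expression = new Timestamp(0)  // lifted by timestampToLiteral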
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index c26fc3d0f3..941b53fe70 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -17,6 +17,8 @@
package org.apache.spark.sql.catalyst.expressions
+import java.sql.Timestamp
+
import org.apache.spark.sql.catalyst.types._
/** Cast the child expression to the target data type. */
@@ -26,52 +28,169 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression {
override def toString = s"CAST($child, $dataType)"
type EvaluatedType = Any
+
+ def nullOrCast[T](a: Any, func: T => Any): Any = if(a == null) {
+ null
+ } else {
+ func(a.asInstanceOf[T])
+ }
- lazy val castingFunction: Any => Any = (child.dataType, dataType) match {
- case (BinaryType, StringType) => a: Any => new String(a.asInstanceOf[Array[Byte]])
- case (StringType, BinaryType) => a: Any => a.asInstanceOf[String].getBytes
- case (_, StringType) => a: Any => a.toString
- case (StringType, IntegerType) => a: Any => castOrNull(a, _.toInt)
- case (StringType, DoubleType) => a: Any => castOrNull(a, _.toDouble)
- case (StringType, FloatType) => a: Any => castOrNull(a, _.toFloat)
- case (StringType, LongType) => a: Any => castOrNull(a, _.toLong)
- case (StringType, ShortType) => a: Any => castOrNull(a, _.toShort)
- case (StringType, ByteType) => a: Any => castOrNull(a, _.toByte)
- case (StringType, DecimalType) => a: Any => castOrNull(a, BigDecimal(_))
- case (BooleanType, ByteType) => {
- case null => null
- case true => 1.toByte
- case false => 0.toByte
- }
- case (dt, IntegerType) =>
- a: Any => dt.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]].toInt(a)
- case (dt, DoubleType) =>
- a: Any => dt.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]].toDouble(a)
- case (dt, FloatType) =>
- a: Any => dt.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]].toFloat(a)
- case (dt, LongType) =>
- a: Any => dt.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]].toLong(a)
- case (dt, ShortType) =>
- a: Any => dt.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]].toInt(a).toShort
- case (dt, ByteType) =>
- a: Any => dt.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]].toInt(a).toByte
- case (dt, DecimalType) =>
- a: Any =>
- BigDecimal(dt.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]].toDouble(a))
+ // UDFToString
+ def castToString: Any => Any = child.dataType match {
+ case BinaryType => nullOrCast[Array[Byte]](_, new String(_, "UTF-8"))
+ case _ => nullOrCast[Any](_, _.toString)
+ }
+
+ // BinaryConverter
+ def castToBinary: Any => Any = child.dataType match {
+ case StringType => nullOrCast[String](_, _.getBytes("UTF-8"))
}
- @inline
- protected def castOrNull[A](a: Any, f: String => A) =
- try f(a.asInstanceOf[String]) catch {
- case _: java.lang.NumberFormatException => null
- }
+ // UDFToBoolean
+ def castToBoolean: Any => Any = child.dataType match {
+ case StringType => nullOrCast[String](_, _.length() != 0)
+ case TimestampType => nullOrCast[Timestamp](_, b => {(b.getTime() != 0 || b.getNanos() != 0)})
+ case LongType => nullOrCast[Long](_, _ != 0)
+ case IntegerType => nullOrCast[Int](_, _ != 0)
+ case ShortType => nullOrCast[Short](_, _ != 0)
+ case ByteType => nullOrCast[Byte](_, _ != 0)
+ case DecimalType => nullOrCast[BigDecimal](_, _ != 0)
+ case DoubleType => nullOrCast[Double](_, _ != 0)
+ case FloatType => nullOrCast[Float](_, _ != 0)
+ }
+
+ // TimestampConverter
+ def castToTimestamp: Any => Any = child.dataType match {
+ case StringType => nullOrCast[String](_, s => {
+ // Throw away extra if more than 9 decimal places
+ val periodIdx = s.indexOf(".");
+ var n = s
+ if (periodIdx != -1) {
+ if (n.length() - periodIdx > 9) {
+ n = n.substring(0, periodIdx + 10)
+ }
+ }
+ try Timestamp.valueOf(n) catch { case _: java.lang.IllegalArgumentException => null}
+ })
+ case BooleanType => nullOrCast[Boolean](_, b => new Timestamp((if(b) 1 else 0) * 1000))
+ case LongType => nullOrCast[Long](_, l => new Timestamp(l * 1000))
+ case IntegerType => nullOrCast[Int](_, i => new Timestamp(i * 1000))
+ case ShortType => nullOrCast[Short](_, s => new Timestamp(s * 1000))
+ case ByteType => nullOrCast[Byte](_, b => new Timestamp(b * 1000))
+ // TimestampWritable.decimalToTimestamp
+ case DecimalType => nullOrCast[BigDecimal](_, d => decimalToTimestamp(d))
+ // TimestampWritable.doubleToTimestamp
+ case DoubleType => nullOrCast[Double](_, d => decimalToTimestamp(d))
+ // TimestampWritable.floatToTimestamp
+ case FloatType => nullOrCast[Float](_, f => decimalToTimestamp(f))
+ }
+
+ private def decimalToTimestamp(d: BigDecimal) = {
+ val seconds = d.longValue()
+ val bd = (d - seconds) * (1000000000)
+ val nanos = bd.intValue()
+
+ // Convert to millis
+ val millis = seconds * 1000
+ val t = new Timestamp(millis)
+
+ // remaining fractional portion as nanos
+ t.setNanos(nanos)
+
+ t
+ }
+
+ private def timestampToDouble(t: Timestamp) = (t.getSeconds() + t.getNanos().toDouble / 1000)
+
+ def castToLong: Any => Any = child.dataType match {
+ case StringType => nullOrCast[String](_, s => try s.toLong catch {
+ case _: NumberFormatException => null
+ })
+ case BooleanType => nullOrCast[Boolean](_, b => if(b) 1 else 0)
+ case TimestampType => nullOrCast[Timestamp](_, t => timestampToDouble(t).toLong)
+ case DecimalType => nullOrCast[BigDecimal](_, _.toLong)
+ case x: NumericType => b => x.numeric.asInstanceOf[Numeric[Any]].toLong(b)
+ }
+
+ def castToInt: Any => Any = child.dataType match {
+ case StringType => nullOrCast[String](_, s => try s.toInt catch {
+ case _: NumberFormatException => null
+ })
+ case BooleanType => nullOrCast[Boolean](_, b => if(b) 1 else 0)
+ case TimestampType => nullOrCast[Timestamp](_, t => timestampToDouble(t).toInt)
+ case DecimalType => nullOrCast[BigDecimal](_, _.toInt)
+ case x: NumericType => b => x.numeric.asInstanceOf[Numeric[Any]].toInt(b)
+ }
+
+ def castToShort: Any => Any = child.dataType match {
+ case StringType => nullOrCast[String](_, s => try s.toShort catch {
+ case _: NumberFormatException => null
+ })
+ case BooleanType => nullOrCast[Boolean](_, b => if(b) 1 else 0)
+ case TimestampType => nullOrCast[Timestamp](_, t => timestampToDouble(t).toShort)
+ case DecimalType => nullOrCast[BigDecimal](_, _.toShort)
+ case x: NumericType => b => x.numeric.asInstanceOf[Numeric[Any]].toInt(b).toShort
+ }
+
+ def castToByte: Any => Any = child.dataType match {
+ case StringType => nullOrCast[String](_, s => try s.toByte catch {
+ case _: NumberFormatException => null
+ })
+ case BooleanType => nullOrCast[Boolean](_, b => if(b) 1 else 0)
+ case TimestampType => nullOrCast[Timestamp](_, t => timestampToDouble(t).toByte)
+ case DecimalType => nullOrCast[BigDecimal](_, _.toByte)
+ case x: NumericType => b => x.numeric.asInstanceOf[Numeric[Any]].toInt(b).toByte
+ }
+
+ def castToDecimal: Any => Any = child.dataType match {
+ case StringType => nullOrCast[String](_, s => try BigDecimal(s.toDouble) catch {
+ case _: NumberFormatException => null
+ })
+ case BooleanType => nullOrCast[Boolean](_, b => if(b) BigDecimal(1) else BigDecimal(0))
+ case TimestampType => nullOrCast[Timestamp](_, t => BigDecimal(timestampToDouble(t)))
+ case x: NumericType => b => BigDecimal(x.numeric.asInstanceOf[Numeric[Any]].toDouble(b))
+ }
+
+ def castToDouble: Any => Any = child.dataType match {
+ case StringType => nullOrCast[String](_, s => try s.toDouble catch {
+ case _: NumberFormatException => null
+ })
+ case BooleanType => nullOrCast[Boolean](_, b => if(b) 1 else 0)
+ case TimestampType => nullOrCast[Timestamp](_, t => timestampToDouble(t))
+ case DecimalType => nullOrCast[BigDecimal](_, _.toDouble)
+ case x: NumericType => b => x.numeric.asInstanceOf[Numeric[Any]].toDouble(b)
+ }
+
+ def castToFloat: Any => Any = child.dataType match {
+ case StringType => nullOrCast[String](_, s => try s.toFloat catch {
+ case _: NumberFormatException => null
+ })
+ case BooleanType => nullOrCast[Boolean](_, b => if(b) 1 else 0)
+ case TimestampType => nullOrCast[Timestamp](_, t => timestampToDouble(t).toFloat)
+ case DecimalType => nullOrCast[BigDecimal](_, _.toFloat)
+ case x: NumericType => b => x.numeric.asInstanceOf[Numeric[Any]].toFloat(b)
+ }
+
+ def cast: Any => Any = dataType match {
+ case StringType => castToString
+ case BinaryType => castToBinary
+ case DecimalType => castToDecimal
+ case TimestampType => castToTimestamp
+ case BooleanType => castToBoolean
+ case ByteType => castToByte
+ case ShortType => castToShort
+ case IntegerType => castToInt
+ case FloatType => castToFloat
+ case LongType => castToLong
+ case DoubleType => castToDouble
+ }
override def apply(input: Row): Any = {
val evaluated = child.apply(input)
if (evaluated == null) {
null
} else {
- castingFunction(evaluated)
+ cast(evaluated)
}
}
}
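Two helpers above carry the interesting conversion rules, restated standalone here: on the numeric-to-timestamp path the integral part becomes epoch seconds and the fraction becomes nanoseconds, and on the string path at most nine fractional digits are kept before handing off to Timestamp.valueOf. A runnable sketch mirroring the patch:

    import java.sql.Timestamp

    def decimalToTimestamp(d: BigDecimal): Timestamp = {
      val seconds = d.longValue()                          // integral part: epoch seconds
      val nanos = ((d - seconds) * 1000000000).intValue()  // fraction: nanoseconds
      val t = new Timestamp(seconds * 1000)                // seconds -> millis
      t.setNanos(nanos)
      t
    }
    decimalToTimestamp(BigDecimal("1.5"))  // 1970-01-01 00:00:01.5

    def truncateFraction(s: String): String = {
      val dot = s.indexOf('.')
      if (dot != -1 && s.length - dot > 9) s.substring(0, dot + 10) else s
    }
    truncateFraction("1970-01-01 00:00:00.1234567891")  // keeps 9 digits, drops the last "1"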
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
index 81fd160e00..a3d1952550 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.trees
import org.apache.spark.sql.catalyst.errors.TreeNodeException
import org.apache.spark.sql.catalyst.trees.TreeNode
-import org.apache.spark.sql.catalyst.types.{DataType, FractionalType, IntegralType, NumericType}
+import org.apache.spark.sql.catalyst.types.{DataType, FractionalType, IntegralType, NumericType, NativeType}
abstract class Expression extends TreeNode[Expression] {
self: Product =>
@@ -86,6 +86,11 @@ abstract class Expression extends TreeNode[Expression] {
}
}
+ /**
+ * Evaluation helper function for two Numeric child expressions. Both expressions, and the
+ * return value, are expected to have the same data type. If either expression evaluates
+ * to null, the result is null.
+ */
@inline
protected final def n2(
i: Row,
@@ -115,6 +120,11 @@ abstract class Expression extends TreeNode[Expression] {
}
}
+ /**
+ * Evaluation helper function for two Fractional child expressions. Both expressions, and
+ * the return value, are expected to have the same data type. If either expression
+ * evaluates to null, the result is null.
+ */
@inline
protected final def f2(
i: Row,
@@ -143,6 +153,11 @@ abstract class Expression extends TreeNode[Expression] {
}
}
+ /**
+ * Evaluation helper function for two Integral child expressions. Both expressions, and
+ * the return value, are expected to have the same data type. If either expression
+ * evaluates to null, the result is null.
+ */
@inline
protected final def i2(
i: Row,
@@ -170,6 +185,43 @@ abstract class Expression extends TreeNode[Expression] {
}
}
}
+
+ /**
+ * Evaluation helper function for two Comparable child expressions. Both expressions
+ * are expected to have the same data type.
+ * The given function receives the data type's Ordering together with the two
+ * evaluated values, so e.g. _.lt(_, _) yields a less-than comparison and
+ * _.gteq(_, _) a greater-than-or-equal comparison.
+ *
+ * If either expression evaluates to null, the result is null.
+ */
+ @inline
+ protected final def c2(
+ i: Row,
+ e1: Expression,
+ e2: Expression,
+ f: ((Ordering[Any], Any, Any) => Any)): Any = {
+ if (e1.dataType != e2.dataType) {
+ throw new TreeNodeException(this, s"Types do not match ${e1.dataType} != ${e2.dataType}")
+ }
+
+ val evalE1 = e1.apply(i)
+ if(evalE1 == null) {
+ null
+ } else {
+ val evalE2 = e2.apply(i)
+ if (evalE2 == null) {
+ null
+ } else {
+ e1.dataType match {
+ case i: NativeType =>
+ f.asInstanceOf[(Ordering[i.JvmType], i.JvmType, i.JvmType) => Boolean](
+ i.ordering, evalE1.asInstanceOf[i.JvmType], evalE2.asInstanceOf[i.JvmType])
+ case other => sys.error(s"Type $other does not support ordered operations")
+ }
+ }
+ }
+ }
}
abstract class BinaryExpression extends Expression with trees.BinaryNode[Expression] {
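c2 generalizes the old string-only special case: the two children must share a NativeType, that type's Ordering drives the comparison, and a null on either side yields null. A standalone sketch of this contract (the lt helper here is illustrative):

    // SQL three-valued logic over an Ordering: null (here None) short-circuits.
    def lt[T](l: Option[T], r: Option[T])(implicit ord: Ordering[T]): Option[Boolean] =
      for (a <- l; b <- r) yield ord.lt(a, b)

    lt(Option("ab"), Option("abc"))        // Some(true)
    lt(Option.empty[String], Option("a"))  // None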
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
index b82a12e0f7..d879b2b5e8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
@@ -17,6 +17,8 @@
package org.apache.spark.sql.catalyst.expressions
+import java.sql.Timestamp
+
import org.apache.spark.sql.catalyst.types._
object Literal {
@@ -29,6 +31,9 @@ object Literal {
case s: Short => Literal(s, ShortType)
case s: String => Literal(s, StringType)
case b: Boolean => Literal(b, BooleanType)
+ case d: BigDecimal => Literal(d, DecimalType)
+ case t: Timestamp => Literal(t, TimestampType)
+ case a: Array[Byte] => Literal(a, BinaryType)
case null => Literal(null, NullType)
}
}
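The extended factory now lifts the three new value classes directly; for example (values illustrative, shapes per the cases above):

    import java.sql.Timestamp

    Literal(BigDecimal(12.65))       // Literal(12.65, DecimalType)
    Literal(new Timestamp(0))        // Literal(..., TimestampType)
    Literal("ab".getBytes("UTF-8"))  // Literal(..., BinaryType)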
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
index 02fedd16b8..b74809e5ca 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
@@ -18,8 +18,9 @@
package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.trees
+import org.apache.spark.sql.catalyst.errors.TreeNodeException
import org.apache.spark.sql.catalyst.analysis.UnresolvedException
-import org.apache.spark.sql.catalyst.types.{BooleanType, StringType}
+import org.apache.spark.sql.catalyst.types.{BooleanType, StringType, TimestampType}
object InterpretedPredicate {
def apply(expression: Expression): (Row => Boolean) = {
@@ -123,70 +124,22 @@ case class Equals(left: Expression, right: Expression) extends BinaryComparison
case class LessThan(left: Expression, right: Expression) extends BinaryComparison {
def symbol = "<"
- override def apply(input: Row): Any = {
- if (left.dataType == StringType && right.dataType == StringType) {
- val l = left.apply(input)
- val r = right.apply(input)
- if(l == null || r == null) {
- null
- } else {
- l.asInstanceOf[String] < r.asInstanceOf[String]
- }
- } else {
- n2(input, left, right, _.lt(_, _))
- }
- }
+ override def apply(input: Row): Any = c2(input, left, right, _.lt(_, _))
}
case class LessThanOrEqual(left: Expression, right: Expression) extends BinaryComparison {
def symbol = "<="
- override def apply(input: Row): Any = {
- if (left.dataType == StringType && right.dataType == StringType) {
- val l = left.apply(input)
- val r = right.apply(input)
- if(l == null || r == null) {
- null
- } else {
- l.asInstanceOf[String] <= r.asInstanceOf[String]
- }
- } else {
- n2(input, left, right, _.lteq(_, _))
- }
- }
+ override def apply(input: Row): Any = c2(input, left, right, _.lteq(_, _))
}
case class GreaterThan(left: Expression, right: Expression) extends BinaryComparison {
def symbol = ">"
- override def apply(input: Row): Any = {
- if (left.dataType == StringType && right.dataType == StringType) {
- val l = left.apply(input)
- val r = right.apply(input)
- if(l == null || r == null) {
- null
- } else {
- l.asInstanceOf[String] > r.asInstanceOf[String]
- }
- } else {
- n2(input, left, right, _.gt(_, _))
- }
- }
+ override def apply(input: Row): Any = c2(input, left, right, _.gt(_, _))
}
case class GreaterThanOrEqual(left: Expression, right: Expression) extends BinaryComparison {
def symbol = ">="
- override def apply(input: Row): Any = {
- if (left.dataType == StringType && right.dataType == StringType) {
- val l = left.apply(input)
- val r = right.apply(input)
- if(l == null || r == null) {
- null
- } else {
- l.asInstanceOf[String] >= r.asInstanceOf[String]
- }
- } else {
- n2(input, left, right, _.gteq(_, _))
- }
- }
+ override def apply(input: Row): Any = c2(input, left, right, _.gteq(_, _))
}
case class If(predicate: Expression, trueValue: Expression, falseValue: Expression)
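All four comparison operators now funnel through c2, so ordered comparisons behave uniformly for every NativeType instead of special-casing strings; e.g. both of these evaluate to true (mirroring the new tests at the end of this diff):

    import java.sql.Timestamp

    LessThan(Literal("ab"), Literal("abc"))
    LessThan(Literal(new Timestamp(12)), Literal(new Timestamp(123)))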
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala
index 7a45d1a1b8..cdeb01a965 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala
@@ -17,6 +17,8 @@
package org.apache.spark.sql.catalyst.types
+import java.sql.Timestamp
+
import scala.reflect.runtime.universe.{typeTag, TypeTag}
import org.apache.spark.sql.catalyst.expressions.Expression
@@ -51,6 +53,16 @@ case object BooleanType extends NativeType {
val ordering = implicitly[Ordering[JvmType]]
}
+case object TimestampType extends NativeType {
+ type JvmType = Timestamp
+
+ @transient lazy val tag = typeTag[JvmType]
+
+ val ordering = new Ordering[JvmType] {
+ def compare(x: Timestamp, y: Timestamp) = x.compareTo(y)
+ }
+}
+
abstract class NumericType extends NativeType {
// Unfortunately we can't get this implicitly as that breaks Spark Serialization. In order for
// implicitly[Numeric[JvmType]] to be valid, we have to change JvmType from a type variable to a
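Because the new ordering delegates to Timestamp.compareTo, nanosecond precision participates in comparisons; a small sketch (values illustrative):

    import java.sql.Timestamp

    val a = new Timestamp(0); a.setNanos(1)
    val b = new Timestamp(0); b.setNanos(2)
    TimestampType.ordering.lt(a, b)  // true: the two differ only in nanos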
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala
index 52a205be3e..43876033d3 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala
@@ -17,6 +17,8 @@
package org.apache.spark.sql.catalyst.expressions
+import java.sql.Timestamp
+
import org.scalatest.FunSuite
import org.apache.spark.sql.catalyst.types._
@@ -191,5 +193,56 @@ class ExpressionEvaluationSuite extends FunSuite {
evaluate("abbbbc" rlike regEx, new GenericRow(Array[Any]("**")))
}
}
+
+ test("data type casting") {
+
+ val sts = "1970-01-01 00:00:01.0"
+ val ts = Timestamp.valueOf(sts)
+
+ checkEvaluation("abdef" cast StringType, "abdef")
+ checkEvaluation("abdef" cast DecimalType, null)
+ checkEvaluation("abdef" cast TimestampType, null)
+ checkEvaluation("12.65" cast DecimalType, BigDecimal(12.65))
+
+ checkEvaluation(Literal(1) cast LongType, 1)
+ checkEvaluation(Cast(Literal(1) cast TimestampType, LongType), 1)
+ checkEvaluation(Cast(Literal(BigDecimal(1)) cast TimestampType, DecimalType), 1)
+ checkEvaluation(Cast(Literal(1.toDouble) cast TimestampType, DoubleType), 1.toDouble)
+
+ checkEvaluation(Cast(Literal(sts) cast TimestampType, StringType), sts)
+ checkEvaluation(Cast(Literal(ts) cast StringType, TimestampType), ts)
+
+ checkEvaluation(Cast("abdef" cast BinaryType, StringType), "abdef")
+
+ checkEvaluation(Cast(Cast(Cast(Cast(
+ Cast("5" cast ByteType, ShortType), IntegerType), FloatType), DoubleType), LongType), 5)
+ checkEvaluation(Cast(Cast(Cast(Cast(
+ Cast("5" cast ByteType, TimestampType), DecimalType), LongType), StringType), ShortType), 5)
+ checkEvaluation(Cast(Cast(Cast(Cast(
+ Cast("5" cast TimestampType, ByteType), DecimalType), LongType), StringType), ShortType), null)
+ checkEvaluation(Cast(Cast(Cast(Cast(
+ Cast("5" cast DecimalType, ByteType), TimestampType), LongType), StringType), ShortType), 5)
+ checkEvaluation(Literal(true) cast IntegerType, 1)
+ checkEvaluation(Literal(false) cast IntegerType, 0)
+ checkEvaluation(Cast(Literal(1) cast BooleanType, IntegerType), 1)
+ checkEvaluation(Cast(Literal(0) cast BooleanType, IntegerType), 0)
+ checkEvaluation("23" cast DoubleType, 23)
+ checkEvaluation("23" cast IntegerType, 23)
+ checkEvaluation("23" cast FloatType, 23)
+ checkEvaluation("23" cast DecimalType, 23)
+ checkEvaluation("23" cast ByteType, 23)
+ checkEvaluation("23" cast ShortType, 23)
+ checkEvaluation("2012-12-11" cast DoubleType, null)
+ checkEvaluation(Literal(123) cast IntegerType, 123)
+
+ intercept[Exception] {evaluate(Literal(1) cast BinaryType, null)}
+ }
+
+ test("timestamp") {
+ val ts1 = new Timestamp(12)
+ val ts2 = new Timestamp(123)
+ checkEvaluation(Literal("ab") < Literal("abc"), true)
+ checkEvaluation(Literal(ts1) < Literal(ts2), true)
+ }
}