author      Wenchen Fan <cloud0fan@outlook.com>    2015-07-13 00:49:39 -0700
committer   Reynold Xin <rxin@databricks.com>      2015-07-13 00:49:39 -0700
commit      6b89943834a8d9d5d0ecfd97efcc10056d08532a (patch)
tree        7383eb5ef241c044e01393cccedc8fdf5fb94e48 /sql
parent      92540d22e45f9300f413f520a1770e9f36cfa833 (diff)
download    spark-6b89943834a8d9d5d0ecfd97efcc10056d08532a.tar.gz
            spark-6b89943834a8d9d5d0ecfd97efcc10056d08532a.tar.bz2
            spark-6b89943834a8d9d5d0ecfd97efcc10056d08532a.zip
[SPARK-8944][SQL] Support casting between IntervalType and StringType
Author: Wenchen Fan <cloud0fan@outlook.com>

Closes #7355 from cloud-fan/fromString and squashes the following commits:

3bbb9d6 [Wenchen Fan] fix code gen
7dab957 [Wenchen Fan] naming fix
0fbbe19 [Wenchen Fan] address comments
ac1f3d1 [Wenchen Fan] Support casting between IntervalType and StringType
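As an illustration of the new behavior (not part of the commit itself), the cast can be exercised at the Catalyst expression level roughly as sketched below; treat it as a sketch, since the exact eval() signature and the rendered interval string are assumptions inferred from the code and tests that follow.

    import org.apache.spark.sql.catalyst.expressions.{Cast, Literal}
    import org.apache.spark.sql.types.{IntervalType, StringType}
    import org.apache.spark.unsafe.types.Interval

    // StringType -> IntervalType: parsed via Interval.fromString (see castToInterval below)
    val toInterval = Cast(Literal("interval 7 hours"), IntervalType)

    // IntervalType -> StringType: goes through the existing generic cast-to-string path
    val backToString = Cast(
      Literal.create(new Interval(0, 7 * Interval.MICROS_PER_HOUR), IntervalType), StringType)

    // Evaluating with no input row; the results in the comments are expectations, not captured output.
    toInterval.eval()    // an Interval with 0 months and 7 hours of microseconds
    backToString.eval()  // a UTF8String such as "interval 7 hours"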
Diffstat (limited to 'sql')
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala       17
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala  10
2 files changed, 26 insertions(+), 1 deletion(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index 7f2383dedc..ab02addfb4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode}
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.types._
-import org.apache.spark.unsafe.types.UTF8String
+import org.apache.spark.unsafe.types.{Interval, UTF8String}
object Cast {
@@ -55,6 +55,9 @@ object Cast {
case (_, DateType) => true
+ case (StringType, IntervalType) => true
+ case (IntervalType, StringType) => true
+
case (StringType, _: NumericType) => true
case (BooleanType, _: NumericType) => true
case (DateType, _: NumericType) => true
@@ -232,6 +235,13 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
case _ => _ => null
}
+ // IntervalConverter
+ private[this] def castToInterval(from: DataType): Any => Any = from match {
+ case StringType =>
+ buildCast[UTF8String](_, s => Interval.fromString(s.toString))
+ case _ => _ => null
+ }
+
// LongConverter
private[this] def castToLong(from: DataType): Any => Any = from match {
case StringType =>
@@ -405,6 +415,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
case DateType => castToDate(from)
case decimal: DecimalType => castToDecimal(from, decimal)
case TimestampType => castToTimestamp(from)
+ case IntervalType => castToInterval(from)
case BooleanType => castToBoolean(from)
case ByteType => castToByte(from)
case ShortType => castToShort(from)
@@ -442,6 +453,10 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
case (_, StringType) =>
defineCodeGen(ctx, ev, c => s"${ctx.stringType}.fromString(String.valueOf($c))")
+ case (StringType, IntervalType) =>
+ defineCodeGen(ctx, ev, c =>
+ s"org.apache.spark.unsafe.types.Interval.fromString($c.toString())")
+
// fallback for DecimalType, this must be before other numeric types
case (_, dt: DecimalType) =>
super.genCode(ctx, ev)
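For reference (not part of the diff), the interpreted path added in castToInterval reduces to a null-safe parse with Interval.fromString. A minimal standalone sketch of that conversion is shown below; the explicit null check only stands in for what buildCast already does inside Cast.scala.

    import org.apache.spark.unsafe.types.{Interval, UTF8String}

    // Equivalent of the StringType branch of castToInterval above:
    // a null UTF8String stays null, anything else is parsed by Interval.fromString.
    def stringToInterval(s: UTF8String): Interval =
      if (s == null) null else Interval.fromString(s.toString)

    // The generated code for the same case simply inlines the call:
    //   org.apache.spark.unsafe.types.Interval.fromString(c.toString())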
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala
index 919fdd470b..1de161c367 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala
@@ -563,4 +563,14 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
InternalRow(0L)))
}
+ test("case between string and interval") {
+ import org.apache.spark.unsafe.types.Interval
+
+ checkEvaluation(Cast(Literal("interval -3 month 7 hours"), IntervalType),
+ new Interval(-3, 7 * Interval.MICROS_PER_HOUR))
+ checkEvaluation(Cast(Literal.create(
+ new Interval(15, -3 * Interval.MICROS_PER_DAY), IntervalType), StringType),
+ "interval 1 years 3 months -3 days")
+ }
+
}
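Read together, the two assertions in the new test cover both directions of the cast. Stripped of the test harness, the round trip looks roughly like the sketch below; it assumes the IntervalType-to-StringType direction ultimately calls Interval.toString (as the generic string-cast path in Cast.scala suggests) and that Interval defines value-based equality.

    import org.apache.spark.unsafe.types.Interval

    // First assertion: months = -3, microseconds = 7 hours.
    val parsed = Interval.fromString("interval -3 month 7 hours")
    assert(parsed == new Interval(-3, 7 * Interval.MICROS_PER_HOUR))

    // Second assertion: 15 months renders back as "1 years 3 months".
    val rendered = new Interval(15, -3 * Interval.MICROS_PER_DAY).toString
    assert(rendered == "interval 1 years 3 months -3 days")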