aboutsummaryrefslogtreecommitdiff
path: root/sql
diff options
context:
space:
mode:
authorReynold Xin <rxin@databricks.com>2015-06-30 23:04:54 -0700
committerReynold Xin <rxin@databricks.com>2015-06-30 23:04:54 -0700
commit365c14055e90db5ea4b25afec03022be81c8a704 (patch)
tree2524baa166530f9cc73729ef06e529d9e5bf8e87 /sql
parent64c14618d3f4ede042bd3f6a542bc17a730afb0e (diff)
downloadspark-365c14055e90db5ea4b25afec03022be81c8a704.tar.gz
spark-365c14055e90db5ea4b25afec03022be81c8a704.tar.bz2
spark-365c14055e90db5ea4b25afec03022be81c8a704.zip
[SPARK-8748][SQL] Move castability test out from Cast case class into Cast object.
This patch moved resolve function in Cast case class into the companion object, and renamed it canCast. We can then use this in the analyzer without a Cast expr. Author: Reynold Xin <rxin@databricks.com> Closes #7145 from rxin/cast and squashes the following commits: cd086a9 [Reynold Xin] Whitespace changes. 4d2d989 [Reynold Xin] [SPARK-8748][SQL] Move castability test out from Cast case class into Cast object.
Diffstat (limited to 'sql')
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala144
1 files changed, 78 insertions, 66 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index d69d490ad6..2d99d1a3fe 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -27,23 +27,65 @@ import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
-/** Cast the child expression to the target data type. */
-case class Cast(child: Expression, dataType: DataType) extends UnaryExpression with Logging {
- override def checkInputDataTypes(): TypeCheckResult = {
- if (resolve(child.dataType, dataType)) {
- TypeCheckResult.TypeCheckSuccess
- } else {
- TypeCheckResult.TypeCheckFailure(
- s"cannot cast ${child.dataType} to $dataType")
- }
- }
+object Cast {
- override def foldable: Boolean = child.foldable
+ /**
+ * Returns true iff we can cast `from` type to `to` type.
+ */
+ def canCast(from: DataType, to: DataType): Boolean = (from, to) match {
+ case (fromType, toType) if fromType == toType => true
+
+ case (NullType, _) => true
+
+ case (_, StringType) => true
- override def nullable: Boolean = forceNullable(child.dataType, dataType) || child.nullable
+ case (StringType, BinaryType) => true
- private[this] def forceNullable(from: DataType, to: DataType) = (from, to) match {
+ case (StringType, BooleanType) => true
+ case (DateType, BooleanType) => true
+ case (TimestampType, BooleanType) => true
+ case (_: NumericType, BooleanType) => true
+
+ case (StringType, TimestampType) => true
+ case (BooleanType, TimestampType) => true
+ case (DateType, TimestampType) => true
+ case (_: NumericType, TimestampType) => true
+
+ case (_, DateType) => true
+
+ case (StringType, _: NumericType) => true
+ case (BooleanType, _: NumericType) => true
+ case (DateType, _: NumericType) => true
+ case (TimestampType, _: NumericType) => true
+ case (_: NumericType, _: NumericType) => true
+
+ case (ArrayType(fromType, fn), ArrayType(toType, tn)) =>
+ canCast(fromType, toType) &&
+ resolvableNullability(fn || forceNullable(fromType, toType), tn)
+
+ case (MapType(fromKey, fromValue, fn), MapType(toKey, toValue, tn)) =>
+ canCast(fromKey, toKey) &&
+ (!forceNullable(fromKey, toKey)) &&
+ canCast(fromValue, toValue) &&
+ resolvableNullability(fn || forceNullable(fromValue, toValue), tn)
+
+ case (StructType(fromFields), StructType(toFields)) =>
+ fromFields.length == toFields.length &&
+ fromFields.zip(toFields).forall {
+ case (fromField, toField) =>
+ canCast(fromField.dataType, toField.dataType) &&
+ resolvableNullability(
+ fromField.nullable || forceNullable(fromField.dataType, toField.dataType),
+ toField.nullable)
+ }
+
+ case _ => false
+ }
+
+ private def resolvableNullability(from: Boolean, to: Boolean) = !from || to
+
+ private def forceNullable(from: DataType, to: DataType) = (from, to) match {
case (StringType, _: NumericType) => true
case (StringType, TimestampType) => true
case (DoubleType, TimestampType) => true
@@ -58,61 +100,24 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
case (_, DecimalType.Fixed(_, _)) => true // TODO: not all upcasts here can really give null
case _ => false
}
+}
- private[this] def resolvableNullability(from: Boolean, to: Boolean) = !from || to
-
- private[this] def resolve(from: DataType, to: DataType): Boolean = {
- (from, to) match {
- case (from, to) if from == to => true
-
- case (NullType, _) => true
-
- case (_, StringType) => true
-
- case (StringType, BinaryType) => true
-
- case (StringType, BooleanType) => true
- case (DateType, BooleanType) => true
- case (TimestampType, BooleanType) => true
- case (_: NumericType, BooleanType) => true
-
- case (StringType, TimestampType) => true
- case (BooleanType, TimestampType) => true
- case (DateType, TimestampType) => true
- case (_: NumericType, TimestampType) => true
-
- case (_, DateType) => true
-
- case (StringType, _: NumericType) => true
- case (BooleanType, _: NumericType) => true
- case (DateType, _: NumericType) => true
- case (TimestampType, _: NumericType) => true
- case (_: NumericType, _: NumericType) => true
-
- case (ArrayType(from, fn), ArrayType(to, tn)) =>
- resolve(from, to) &&
- resolvableNullability(fn || forceNullable(from, to), tn)
-
- case (MapType(fromKey, fromValue, fn), MapType(toKey, toValue, tn)) =>
- resolve(fromKey, toKey) &&
- (!forceNullable(fromKey, toKey)) &&
- resolve(fromValue, toValue) &&
- resolvableNullability(fn || forceNullable(fromValue, toValue), tn)
-
- case (StructType(fromFields), StructType(toFields)) =>
- fromFields.size == toFields.size &&
- fromFields.zip(toFields).forall {
- case (fromField, toField) =>
- resolve(fromField.dataType, toField.dataType) &&
- resolvableNullability(
- fromField.nullable || forceNullable(fromField.dataType, toField.dataType),
- toField.nullable)
- }
+/** Cast the child expression to the target data type. */
+case class Cast(child: Expression, dataType: DataType) extends UnaryExpression with Logging {
- case _ => false
+ override def checkInputDataTypes(): TypeCheckResult = {
+ if (Cast.canCast(child.dataType, dataType)) {
+ TypeCheckResult.TypeCheckSuccess
+ } else {
+ TypeCheckResult.TypeCheckFailure(
+ s"cannot cast ${child.dataType} to $dataType")
}
}
+ override def foldable: Boolean = child.foldable
+
+ override def nullable: Boolean = Cast.forceNullable(child.dataType, dataType) || child.nullable
+
override def toString: String = s"CAST($child, $dataType)"
// [[func]] assumes the input is no longer null because eval already does the null check.
@@ -172,7 +177,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
catch { case _: java.lang.IllegalArgumentException => null }
})
case BooleanType =>
- buildCast[Boolean](_, b => (if (b) 1L else 0))
+ buildCast[Boolean](_, b => if (b) 1L else 0)
case LongType =>
buildCast[Long](_, l => longToTimestamp(l))
case IntegerType =>
@@ -388,7 +393,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
case (fromField, toField) => cast(fromField.dataType, toField.dataType)
}
// TODO: Could be faster?
- val newRow = new GenericMutableRow(from.fields.size)
+ val newRow = new GenericMutableRow(from.fields.length)
buildCast[InternalRow](_, row => {
var i = 0
while (i < row.length) {
@@ -427,20 +432,23 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
}
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
- // TODO(cg): Add support for more data types.
+ // TODO: Add support for more data types.
(child.dataType, dataType) match {
case (BinaryType, StringType) =>
defineCodeGen (ctx, ev, c =>
s"${ctx.stringType}.fromBytes($c)")
+
case (DateType, StringType) =>
defineCodeGen(ctx, ev, c =>
s"""${ctx.stringType}.fromString(
org.apache.spark.sql.catalyst.util.DateTimeUtils.dateToString($c))""")
+
case (TimestampType, StringType) =>
defineCodeGen(ctx, ev, c =>
s"""${ctx.stringType}.fromString(
org.apache.spark.sql.catalyst.util.DateTimeUtils.timestampToString($c))""")
+
case (_, StringType) =>
defineCodeGen(ctx, ev, c => s"${ctx.stringType}.fromString(String.valueOf($c))")
@@ -450,12 +458,16 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
case (BooleanType, dt: NumericType) =>
defineCodeGen(ctx, ev, c => s"(${ctx.javaType(dt)})($c ? 1 : 0)")
+
case (dt: DecimalType, BooleanType) =>
defineCodeGen(ctx, ev, c => s"!$c.isZero()")
+
case (dt: NumericType, BooleanType) =>
defineCodeGen(ctx, ev, c => s"$c != 0")
+
case (_: DecimalType, dt: NumericType) =>
defineCodeGen(ctx, ev, c => s"($c).to${ctx.primitiveTypeName(dt)}()")
+
case (_: NumericType, dt: NumericType) =>
defineCodeGen(ctx, ev, c => s"(${ctx.javaType(dt)})($c)")