about summary refs log tree commit diff
path: root/sql
diff options
context:
space:
mode:
Diffstat (limited to 'sql')
-rw-r--r-- sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala 144
1 file changed, 78 insertions, 66 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index d69d490ad6..2d99d1a3fe 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -27,23 +27,65 @@ import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
-/** Cast the child expression to the target data type. */
-case class Cast(child: Expression, dataType: DataType) extends UnaryExpression with Logging {
- override def checkInputDataTypes(): TypeCheckResult = {
- if (resolve(child.dataType, dataType)) {
- TypeCheckResult.TypeCheckSuccess
- } else {
- TypeCheckResult.TypeCheckFailure(
- s"cannot cast ${child.dataType} to $dataType")
- }
- }
+object Cast {
- override def foldable: Boolean = child.foldable
+ /**
+ * Returns true iff we can cast `from` type to `to` type.
+ */
+ def canCast(from: DataType, to: DataType): Boolean = (from, to) match {
+ case (fromType, toType) if fromType == toType => true
+
+ case (NullType, _) => true
+
+ case (_, StringType) => true
- override def nullable: Boolean = forceNullable(child.dataType, dataType) || child.nullable
+ case (StringType, BinaryType) => true
- private[this] def forceNullable(from: DataType, to: DataType) = (from, to) match {
+ case (StringType, BooleanType) => true
+ case (DateType, BooleanType) => true
+ case (TimestampType, BooleanType) => true
+ case (_: NumericType, BooleanType) => true
+
+ case (StringType, TimestampType) => true
+ case (BooleanType, TimestampType) => true
+ case (DateType, TimestampType) => true
+ case (_: NumericType, TimestampType) => true
+
+ case (_, DateType) => true
+
+ case (StringType, _: NumericType) => true
+ case (BooleanType, _: NumericType) => true
+ case (DateType, _: NumericType) => true
+ case (TimestampType, _: NumericType) => true
+ case (_: NumericType, _: NumericType) => true
+
+ case (ArrayType(fromType, fn), ArrayType(toType, tn)) =>
+ canCast(fromType, toType) &&
+ resolvableNullability(fn || forceNullable(fromType, toType), tn)
+
+ case (MapType(fromKey, fromValue, fn), MapType(toKey, toValue, tn)) =>
+ canCast(fromKey, toKey) &&
+ (!forceNullable(fromKey, toKey)) &&
+ canCast(fromValue, toValue) &&
+ resolvableNullability(fn || forceNullable(fromValue, toValue), tn)
+
+ case (StructType(fromFields), StructType(toFields)) =>
+ fromFields.length == toFields.length &&
+ fromFields.zip(toFields).forall {
+ case (fromField, toField) =>
+ canCast(fromField.dataType, toField.dataType) &&
+ resolvableNullability(
+ fromField.nullable || forceNullable(fromField.dataType, toField.dataType),
+ toField.nullable)
+ }
+
+ case _ => false
+ }
+
+ private def resolvableNullability(from: Boolean, to: Boolean) = !from || to
+
+ private def forceNullable(from: DataType, to: DataType) = (from, to) match {
case (StringType, _: NumericType) => true
case (StringType, TimestampType) => true
case (DoubleType, TimestampType) => true
@@ -58,61 +100,24 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
case (_, DecimalType.Fixed(_, _)) => true // TODO: not all upcasts here can really give null
case _ => false
}
+}
- private[this] def resolvableNullability(from: Boolean, to: Boolean) = !from || to
-
- private[this] def resolve(from: DataType, to: DataType): Boolean = {
- (from, to) match {
- case (from, to) if from == to => true
-
- case (NullType, _) => true
-
- case (_, StringType) => true
-
- case (StringType, BinaryType) => true
-
- case (StringType, BooleanType) => true
- case (DateType, BooleanType) => true
- case (TimestampType, BooleanType) => true
- case (_: NumericType, BooleanType) => true
-
- case (StringType, TimestampType) => true
- case (BooleanType, TimestampType) => true
- case (DateType, TimestampType) => true
- case (_: NumericType, TimestampType) => true
-
- case (_, DateType) => true
-
- case (StringType, _: NumericType) => true
- case (BooleanType, _: NumericType) => true
- case (DateType, _: NumericType) => true
- case (TimestampType, _: NumericType) => true
- case (_: NumericType, _: NumericType) => true
-
- case (ArrayType(from, fn), ArrayType(to, tn)) =>
- resolve(from, to) &&
- resolvableNullability(fn || forceNullable(from, to), tn)
-
- case (MapType(fromKey, fromValue, fn), MapType(toKey, toValue, tn)) =>
- resolve(fromKey, toKey) &&
- (!forceNullable(fromKey, toKey)) &&
- resolve(fromValue, toValue) &&
- resolvableNullability(fn || forceNullable(fromValue, toValue), tn)
-
- case (StructType(fromFields), StructType(toFields)) =>
- fromFields.size == toFields.size &&
- fromFields.zip(toFields).forall {
- case (fromField, toField) =>
- resolve(fromField.dataType, toField.dataType) &&
- resolvableNullability(
- fromField.nullable || forceNullable(fromField.dataType, toField.dataType),
- toField.nullable)
- }
+/** Cast the child expression to the target data type. */
+case class Cast(child: Expression, dataType: DataType) extends UnaryExpression with Logging {
- case _ => false
+ override def checkInputDataTypes(): TypeCheckResult = {
+ if (Cast.canCast(child.dataType, dataType)) {
+ TypeCheckResult.TypeCheckSuccess
+ } else {
+ TypeCheckResult.TypeCheckFailure(
+ s"cannot cast ${child.dataType} to $dataType")
}
}
+ override def foldable: Boolean = child.foldable
+
+ override def nullable: Boolean = Cast.forceNullable(child.dataType, dataType) || child.nullable
+
override def toString: String = s"CAST($child, $dataType)"
// [[func]] assumes the input is no longer null because eval already does the null check.
@@ -172,7 +177,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
catch { case _: java.lang.IllegalArgumentException => null }
})
case BooleanType =>
- buildCast[Boolean](_, b => (if (b) 1L else 0))
+ buildCast[Boolean](_, b => if (b) 1L else 0)
case LongType =>
buildCast[Long](_, l => longToTimestamp(l))
case IntegerType =>
@@ -388,7 +393,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
case (fromField, toField) => cast(fromField.dataType, toField.dataType)
}
// TODO: Could be faster?
- val newRow = new GenericMutableRow(from.fields.size)
+ val newRow = new GenericMutableRow(from.fields.length)
buildCast[InternalRow](_, row => {
var i = 0
while (i < row.length) {
@@ -427,20 +432,23 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
}
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
- // TODO(cg): Add support for more data types.
+ // TODO: Add support for more data types.
(child.dataType, dataType) match {
case (BinaryType, StringType) =>
defineCodeGen (ctx, ev, c =>
s"${ctx.stringType}.fromBytes($c)")
+
case (DateType, StringType) =>
defineCodeGen(ctx, ev, c =>
s"""${ctx.stringType}.fromString(
org.apache.spark.sql.catalyst.util.DateTimeUtils.dateToString($c))""")
+
case (TimestampType, StringType) =>
defineCodeGen(ctx, ev, c =>
s"""${ctx.stringType}.fromString(
org.apache.spark.sql.catalyst.util.DateTimeUtils.timestampToString($c))""")
+
case (_, StringType) =>
defineCodeGen(ctx, ev, c => s"${ctx.stringType}.fromString(String.valueOf($c))")
@@ -450,12 +458,16 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
case (BooleanType, dt: NumericType) =>
defineCodeGen(ctx, ev, c => s"(${ctx.javaType(dt)})($c ? 1 : 0)")
+
case (dt: DecimalType, BooleanType) =>
defineCodeGen(ctx, ev, c => s"!$c.isZero()")
+
case (dt: NumericType, BooleanType) =>
defineCodeGen(ctx, ev, c => s"$c != 0")
+
case (_: DecimalType, dt: NumericType) =>
defineCodeGen(ctx, ev, c => s"($c).to${ctx.primitiveTypeName(dt)}()")
+
case (_: NumericType, dt: NumericType) =>
defineCodeGen(ctx, ev, c => s"(${ctx.javaType(dt)})($c)")