diff options
Diffstat (limited to 'sql')
-rw-r--r-- | sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala | 144 |
1 files changed, 78 insertions, 66 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index d69d490ad6..2d99d1a3fe 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -27,23 +27,65 @@ import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String -/** Cast the child expression to the target data type. */ -case class Cast(child: Expression, dataType: DataType) extends UnaryExpression with Logging { - override def checkInputDataTypes(): TypeCheckResult = { - if (resolve(child.dataType, dataType)) { - TypeCheckResult.TypeCheckSuccess - } else { - TypeCheckResult.TypeCheckFailure( - s"cannot cast ${child.dataType} to $dataType") - } - } +object Cast { - override def foldable: Boolean = child.foldable + /** + * Returns true iff we can cast `from` type to `to` type. + */ + def canCast(from: DataType, to: DataType): Boolean = (from, to) match { + case (fromType, toType) if fromType == toType => true + + case (NullType, _) => true + + case (_, StringType) => true - override def nullable: Boolean = forceNullable(child.dataType, dataType) || child.nullable + case (StringType, BinaryType) => true - private[this] def forceNullable(from: DataType, to: DataType) = (from, to) match { + case (StringType, BooleanType) => true + case (DateType, BooleanType) => true + case (TimestampType, BooleanType) => true + case (_: NumericType, BooleanType) => true + + case (StringType, TimestampType) => true + case (BooleanType, TimestampType) => true + case (DateType, TimestampType) => true + case (_: NumericType, TimestampType) => true + + case (_, DateType) => true + + case (StringType, _: NumericType) => true + case (BooleanType, _: NumericType) => true + case (DateType, _: NumericType) => true + case (TimestampType, _: NumericType) => true + case (_: NumericType, _: NumericType) => true + + case (ArrayType(fromType, fn), ArrayType(toType, tn)) => + canCast(fromType, toType) && + resolvableNullability(fn || forceNullable(fromType, toType), tn) + + case (MapType(fromKey, fromValue, fn), MapType(toKey, toValue, tn)) => + canCast(fromKey, toKey) && + (!forceNullable(fromKey, toKey)) && + canCast(fromValue, toValue) && + resolvableNullability(fn || forceNullable(fromValue, toValue), tn) + + case (StructType(fromFields), StructType(toFields)) => + fromFields.length == toFields.length && + fromFields.zip(toFields).forall { + case (fromField, toField) => + canCast(fromField.dataType, toField.dataType) && + resolvableNullability( + fromField.nullable || forceNullable(fromField.dataType, toField.dataType), + toField.nullable) + } + + case _ => false + } + + private def resolvableNullability(from: Boolean, to: Boolean) = !from || to + + private def forceNullable(from: DataType, to: DataType) = (from, to) match { case (StringType, _: NumericType) => true case (StringType, TimestampType) => true case (DoubleType, TimestampType) => true @@ -58,61 +100,24 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w case (_, DecimalType.Fixed(_, _)) => true // TODO: not all upcasts here can really give null case _ => false } +} - private[this] def resolvableNullability(from: Boolean, to: Boolean) = !from || to - - private[this] def resolve(from: DataType, to: DataType): Boolean = { - (from, to) match { - case (from, to) if from == to => true - - case (NullType, _) => true - - case (_, StringType) => true - - case (StringType, BinaryType) => true - - case (StringType, BooleanType) => true - case (DateType, BooleanType) => true - case (TimestampType, BooleanType) => true - case (_: NumericType, BooleanType) => true - - case (StringType, TimestampType) => true - case (BooleanType, TimestampType) => true - case (DateType, TimestampType) => true - case (_: NumericType, TimestampType) => true - - case (_, DateType) => true - - case (StringType, _: NumericType) => true - case (BooleanType, _: NumericType) => true - case (DateType, _: NumericType) => true - case (TimestampType, _: NumericType) => true - case (_: NumericType, _: NumericType) => true - - case (ArrayType(from, fn), ArrayType(to, tn)) => - resolve(from, to) && - resolvableNullability(fn || forceNullable(from, to), tn) - - case (MapType(fromKey, fromValue, fn), MapType(toKey, toValue, tn)) => - resolve(fromKey, toKey) && - (!forceNullable(fromKey, toKey)) && - resolve(fromValue, toValue) && - resolvableNullability(fn || forceNullable(fromValue, toValue), tn) - - case (StructType(fromFields), StructType(toFields)) => - fromFields.size == toFields.size && - fromFields.zip(toFields).forall { - case (fromField, toField) => - resolve(fromField.dataType, toField.dataType) && - resolvableNullability( - fromField.nullable || forceNullable(fromField.dataType, toField.dataType), - toField.nullable) - } +/** Cast the child expression to the target data type. */ +case class Cast(child: Expression, dataType: DataType) extends UnaryExpression with Logging { - case _ => false + override def checkInputDataTypes(): TypeCheckResult = { + if (Cast.canCast(child.dataType, dataType)) { + TypeCheckResult.TypeCheckSuccess + } else { + TypeCheckResult.TypeCheckFailure( + s"cannot cast ${child.dataType} to $dataType") } } + override def foldable: Boolean = child.foldable + + override def nullable: Boolean = Cast.forceNullable(child.dataType, dataType) || child.nullable + override def toString: String = s"CAST($child, $dataType)" // [[func]] assumes the input is no longer null because eval already does the null check. @@ -172,7 +177,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w catch { case _: java.lang.IllegalArgumentException => null } }) case BooleanType => - buildCast[Boolean](_, b => (if (b) 1L else 0)) + buildCast[Boolean](_, b => if (b) 1L else 0) case LongType => buildCast[Long](_, l => longToTimestamp(l)) case IntegerType => @@ -388,7 +393,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w case (fromField, toField) => cast(fromField.dataType, toField.dataType) } // TODO: Could be faster? - val newRow = new GenericMutableRow(from.fields.size) + val newRow = new GenericMutableRow(from.fields.length) buildCast[InternalRow](_, row => { var i = 0 while (i < row.length) { @@ -427,20 +432,23 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w } override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - // TODO(cg): Add support for more data types. + // TODO: Add support for more data types. (child.dataType, dataType) match { case (BinaryType, StringType) => defineCodeGen (ctx, ev, c => s"${ctx.stringType}.fromBytes($c)") + case (DateType, StringType) => defineCodeGen(ctx, ev, c => s"""${ctx.stringType}.fromString( org.apache.spark.sql.catalyst.util.DateTimeUtils.dateToString($c))""") + case (TimestampType, StringType) => defineCodeGen(ctx, ev, c => s"""${ctx.stringType}.fromString( org.apache.spark.sql.catalyst.util.DateTimeUtils.timestampToString($c))""") + case (_, StringType) => defineCodeGen(ctx, ev, c => s"${ctx.stringType}.fromString(String.valueOf($c))") @@ -450,12 +458,16 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w case (BooleanType, dt: NumericType) => defineCodeGen(ctx, ev, c => s"(${ctx.javaType(dt)})($c ? 1 : 0)") + case (dt: DecimalType, BooleanType) => defineCodeGen(ctx, ev, c => s"!$c.isZero()") + case (dt: NumericType, BooleanType) => defineCodeGen(ctx, ev, c => s"$c != 0") + case (_: DecimalType, dt: NumericType) => defineCodeGen(ctx, ev, c => s"($c).to${ctx.primitiveTypeName(dt)}()") + case (_: NumericType, dt: NumericType) => defineCodeGen(ctx, ev, c => s"(${ctx.javaType(dt)})($c)") |