aboutsummaryrefslogtreecommitdiff
path: root/sql/core
diff options
context:
space:
mode:
Diffstat (limited to 'sql/core')
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/Column.scala17
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala7
2 files changed, 17 insertions, 7 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
index 97bf7a0cc4..2ab091e40a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
@@ -133,6 +133,15 @@ class Column(protected[sql] val expr: Expression) extends Logging {
case func: UnresolvedFunction => UnresolvedAlias(func, Some(func.prettyString))
+ // If we have a top level Cast, there is a chance to give it a better alias, if there is a
+ // NamedExpression under this Cast.
+ case c: Cast => c.transformUp {
+ case Cast(ne: NamedExpression, to) => UnresolvedAlias(Cast(ne, to))
+ } match {
+ case ne: NamedExpression => ne
+ case other => Alias(expr, expr.prettyString)()
+ }
+
case expr: Expression => Alias(expr, expr.prettyString)()
}
@@ -921,13 +930,7 @@ class Column(protected[sql] val expr: Expression) extends Logging {
* @group expr_ops
* @since 1.3.0
*/
- def cast(to: DataType): Column = withExpr {
- expr match {
- // keeps the name of expression if possible when do cast.
- case ne: NamedExpression => UnresolvedAlias(Cast(expr, to))
- case _ => Cast(expr, to)
- }
- }
+ def cast(to: DataType): Column = withExpr { Cast(expr, to) }
/**
* Casts the column to a different data type, using the canonical string representation
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index d6c140dfea..afc8df07fd 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -1007,6 +1007,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
test("SPARK-10743: keep the name of expression if possible when do cast") {
val df = (1 to 10).map(Tuple1.apply).toDF("i").as("src")
assert(df.select($"src.i".cast(StringType)).columns.head === "i")
+ assert(df.select($"src.i".cast(StringType).cast(IntegerType)).columns.head === "i")
}
test("SPARK-11301: fix case sensitivity for filter on partitioned columns") {
@@ -1228,4 +1229,10 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
checkAnswer(df.withColumn("col.a", lit("c")), Row("c", "b"))
checkAnswer(df.withColumn("col.c", lit("c")), Row("a", "b", "c"))
}
+
+ test("SPARK-12841: cast in filter") {
+ checkAnswer(
+ Seq(1 -> "a").toDF("i", "j").filter($"i".cast(StringType) === "1"),
+ Row(1, "a"))
+ }
}