aboutsummaryrefslogtreecommitdiff
path: root/sql
diff options
context:
space:
mode:
Diffstat (limited to 'sql')
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/Column.scala19
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala8
2 files changed, 21 insertions, 6 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
index c58addaf90..9b8334d334 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
@@ -68,6 +68,18 @@ class TypedColumn[-T, U](
}
new TypedColumn[T, U](newExpr, encoder)
}
+
+ /**
+ * Gives the TypedColumn a name (alias).
+ * If the current TypedColumn has metadata associated with it, this metadata will be propagated
+ * to the new column.
+ *
+ * @group expr_ops
+ * @since 2.0.0
+ */
+ override def name(alias: String): TypedColumn[T, U] =
+ new TypedColumn[T, U](super.name(alias).expr, encoder)
+
}
/**
@@ -910,12 +922,7 @@ class Column(protected[sql] val expr: Expression) extends Logging {
* @group expr_ops
* @since 1.3.0
*/
- def as(alias: Symbol): Column = withExpr {
- expr match {
- case ne: NamedExpression => Alias(expr, alias.name)(explicitMetadata = Some(ne.metadata))
- case other => Alias(other, alias.name)()
- }
- }
+ def as(alias: Symbol): Column = name(alias.name)
/**
* Gives the column an alias with metadata.
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala
index 6eae3ed7ad..b2a0f3d67e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala
@@ -232,4 +232,12 @@ class DatasetAggregatorSuite extends QueryTest with SharedSQLContext {
"a" -> Seq(1, 2)
)
}
+
+ test("spark-15051 alias of aggregator in DataFrame/Dataset[Row]") {
+ val df1 = Seq(1 -> "a", 2 -> "b", 3 -> "b").toDF("i", "j")
+ checkAnswer(df1.agg(RowAgg.toColumn as "b"), Row(6) :: Nil)
+
+ val df2 = Seq(1 -> "a", 2 -> "b", 3 -> "b").toDF("i", "j")
+ checkAnswer(df2.agg(RowAgg.toColumn as "b").select("b"), Row(6) :: Nil)
+ }
}