diff options
Diffstat (limited to 'sql')
-rw-r--r-- | sql/core/src/main/scala/org/apache/spark/sql/Column.scala | 19 | ||||
-rw-r--r-- | sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala | 8 |
2 files changed, 21 insertions, 6 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index c58addaf90..9b8334d334 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -68,6 +68,18 @@ class TypedColumn[-T, U]( } new TypedColumn[T, U](newExpr, encoder) } + + /** + * Gives the TypedColumn a name (alias). + * If the current TypedColumn has metadata associated with it, this metadata will be propagated + * to the new column. + * + * @group expr_ops + * @since 2.0.0 + */ + override def name(alias: String): TypedColumn[T, U] = + new TypedColumn[T, U](super.name(alias).expr, encoder) + } /** @@ -910,12 +922,7 @@ class Column(protected[sql] val expr: Expression) extends Logging { * @group expr_ops * @since 1.3.0 */ - def as(alias: Symbol): Column = withExpr { - expr match { - case ne: NamedExpression => Alias(expr, alias.name)(explicitMetadata = Some(ne.metadata)) - case other => Alias(other, alias.name)() - } - } + def as(alias: Symbol): Column = name(alias.name) /** * Gives the column an alias with metadata. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala index 6eae3ed7ad..b2a0f3d67e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala @@ -232,4 +232,12 @@ class DatasetAggregatorSuite extends QueryTest with SharedSQLContext { "a" -> Seq(1, 2) ) } + + test("spark-15051 alias of aggregator in DataFrame/Dataset[Row]") { + val df1 = Seq(1 -> "a", 2 -> "b", 3 -> "b").toDF("i", "j") + checkAnswer(df1.agg(RowAgg.toColumn as "b"), Row(6) :: Nil) + + val df2 = Seq(1 -> "a", 2 -> "b", 3 -> "b").toDF("i", "j") + checkAnswer(df2.agg(RowAgg.toColumn as "b").select("b"), Row(6) :: Nil) + } } |