From 5e1ee28984b169eaab5d2f832921d32cf09de915 Mon Sep 17 00:00:00 2001 From: Dilip Biswal Date: Sat, 21 May 2016 08:36:08 -0700 Subject: [SPARK-15114][SQL] Column name generated by typed aggregate is super verbose ## What changes were proposed in this pull request? Generate a shorter default alias for `AggregateExpression `, In this PR, aggregate function name along with a index is used for generating the alias name. ```SQL val ds = Seq(1, 3, 2, 5).toDS() ds.select(typed.sum((i: Int) => i), typed.avg((i: Int) => i)).show() ``` Output before change. ```SQL +-----------------------------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------------+ |typedsumdouble(unresolveddeserializer(upcast(input[0, int], IntegerType, - root class: "scala.Int"), value#1), upcast(value))|typedaverage(unresolveddeserializer(upcast(input[0, int], IntegerType, - root class: "scala.Int"), value#1), newInstance(class scala.Tuple2))| +-----------------------------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------------+ | 11.0| 2.75| +-----------------------------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------------+ ``` Output after change: ```SQL +-----------------+---------------+ |typedsumdouble_c1|typedaverage_c2| +-----------------+---------------+ | 11.0| 2.75| +-----------------+---------------+ ``` Note: There is one test in ParquetSuites.scala which shows that that the system picked alias name is not usable and is rejected. [test](https://github.com/apache/spark/blob/master/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala#L672-#L687) ## How was this patch tested? A new test was added in DataSetAggregatorSuite. Author: Dilip Biswal Closes #13045 from dilipbiswal/spark-15114. --- .../apache/spark/sql/catalyst/analysis/Analyzer.scala | 6 ++++-- .../apache/spark/sql/catalyst/analysis/unresolved.scala | 7 +++++-- .../src/main/scala/org/apache/spark/sql/Column.scala | 16 +++++++++++++++- .../org/apache/spark/sql/RelationalGroupedDataset.scala | 3 +++ .../org/apache/spark/sql/DatasetAggregatorSuite.scala | 12 ++++++++++++ 5 files changed, 39 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 2c269478ee..9a92330f75 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -177,14 +177,16 @@ class Analyzer( private def assignAliases(exprs: Seq[NamedExpression]) = { exprs.zipWithIndex.map { case (expr, i) => - expr.transformUp { case u @ UnresolvedAlias(child, optionalAliasName) => + expr.transformUp { case u @ UnresolvedAlias(child, optGenAliasFunc) => child match { case ne: NamedExpression => ne case e if !e.resolved => u case g: Generator => MultiAlias(g, Nil) case c @ Cast(ne: NamedExpression, _) => Alias(c, ne.name)() case e: ExtractValue => Alias(e, toPrettySQL(e))() - case e => Alias(e, optionalAliasName.getOrElse(toPrettySQL(e)))() + case e if optGenAliasFunc.isDefined => + Alias(child, optGenAliasFunc.get.apply(e))() + case e => Alias(e, toPrettySQL(e))() } } }.asInstanceOf[Seq[NamedExpression]] diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala index 1f1897dc36..e953eda784 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala @@ -325,10 +325,13 @@ case class UnresolvedExtractValue(child: Expression, extraction: Expression) * Holds the expression that has yet to be aliased. * * @param child The computation that is needs to be resolved during analysis. - * @param aliasName The name if specified to be associated with the result of computing [[child]] + * @param aliasFunc The function if specified to be called to generate an alias to associate + * with the result of computing [[child]] * */ -case class UnresolvedAlias(child: Expression, aliasName: Option[String] = None) +case class UnresolvedAlias( + child: Expression, + aliasFunc: Option[Expression => String] = None) extends UnaryExpression with NamedExpression with Unevaluable { override def toAttribute: Attribute = throw new UnresolvedException(this, "toAttribute") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index 9b8334d334..204af719b2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -24,6 +24,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder} import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.parser.CatalystSqlParser import org.apache.spark.sql.catalyst.util.usePrettyExpression import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression @@ -37,6 +38,14 @@ private[sql] object Column { def apply(expr: Expression): Column = new Column(expr) def unapply(col: Column): Option[Expression] = Some(col.expr) + + private[sql] def generateAlias(e: Expression): String = { + e match { + case a: AggregateExpression if a.aggregateFunction.isInstanceOf[TypedAggregateExpression] => + a.aggregateFunction.toString + case expr => usePrettyExpression(expr).sql + } + } } /** @@ -145,7 +154,7 @@ class Column(protected[sql] val expr: Expression) extends Logging { case jt: JsonTuple => MultiAlias(jt, Nil) - case func: UnresolvedFunction => UnresolvedAlias(func, Some(usePrettyExpression(func).sql)) + case func: UnresolvedFunction => UnresolvedAlias(func, Some(Column.generateAlias)) // If we have a top level Cast, there is a chance to give it a better alias, if there is a // NamedExpression under this Cast. @@ -156,9 +165,14 @@ class Column(protected[sql] val expr: Expression) extends Logging { case other => Alias(expr, usePrettyExpression(expr).sql)() } + case a: AggregateExpression if a.aggregateFunction.isInstanceOf[TypedAggregateExpression] => + UnresolvedAlias(a, Some(Column.generateAlias)) + case expr: Expression => Alias(expr, usePrettyExpression(expr).sql)() } + + override def toString: String = usePrettyExpression(expr).sql override def equals(that: Any): Boolean = that match { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala index 4f5bf633fa..b0e48a6553 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala @@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Pivot} import org.apache.spark.sql.catalyst.util.usePrettyExpression +import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.NumericType @@ -73,6 +74,8 @@ class RelationalGroupedDataset protected[sql]( private[this] def alias(expr: Expression): NamedExpression = expr match { case u: UnresolvedAttribute => UnresolvedAlias(u) case expr: NamedExpression => expr + case a: AggregateExpression if (a.aggregateFunction.isInstanceOf[TypedAggregateExpression]) => + UnresolvedAlias(a, Some(Column.generateAlias)) case expr: Expression => Alias(expr, usePrettyExpression(expr).sql)() } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala index f1585ca3ff..ead7bd9642 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala @@ -240,4 +240,16 @@ class DatasetAggregatorSuite extends QueryTest with SharedSQLContext { val df2 = Seq(1 -> "a", 2 -> "b", 3 -> "b").toDF("i", "j") checkAnswer(df2.agg(RowAgg.toColumn as "b").select("b"), Row(6) :: Nil) } + + test("spark-15114 shorter system generated alias names") { + val ds = Seq(1, 3, 2, 5).toDS() + assert(ds.select(typed.sum((i: Int) => i)).columns.head === "TypedSumDouble(int)") + val ds2 = ds.select(typed.sum((i: Int) => i), typed.avg((i: Int) => i)) + assert(ds2.columns.head === "TypedSumDouble(int)") + assert(ds2.columns.last === "TypedAverage(int)") + val df = Seq(1 -> "a", 2 -> "b", 3 -> "b").toDF("i", "j") + assert(df.groupBy($"j").agg(RowAgg.toColumn).columns.last == + "RowAgg(org.apache.spark.sql.Row)") + assert(df.groupBy($"j").agg(RowAgg.toColumn as "agg1").columns.last == "agg1") + } } -- cgit v1.2.3