aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDilip Biswal <dbiswal@us.ibm.com>2016-05-21 08:36:08 -0700
committerWenchen Fan <wenchen@databricks.com>2016-05-21 08:36:08 -0700
commit5e1ee28984b169eaab5d2f832921d32cf09de915 (patch)
treeda3bbbc7f541e3dbf561522561e4f080290e7c48
parentf39621c998a0fe91a5115f3f843c3ca8dd71c1ab (diff)
downloadspark-5e1ee28984b169eaab5d2f832921d32cf09de915.tar.gz
spark-5e1ee28984b169eaab5d2f832921d32cf09de915.tar.bz2
spark-5e1ee28984b169eaab5d2f832921d32cf09de915.zip
[SPARK-15114][SQL] Column name generated by typed aggregate is super verbose
## What changes were proposed in this pull request? Generate a shorter default alias for `AggregateExpression`. In this PR, the aggregate function name along with an index is used for generating the alias name. ```SQL val ds = Seq(1, 3, 2, 5).toDS() ds.select(typed.sum((i: Int) => i), typed.avg((i: Int) => i)).show() ``` Output before change: ```SQL +-----------------------------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------------+ |typedsumdouble(unresolveddeserializer(upcast(input[0, int], IntegerType, - root class: "scala.Int"), value#1), upcast(value))|typedaverage(unresolveddeserializer(upcast(input[0, int], IntegerType, - root class: "scala.Int"), value#1), newInstance(class scala.Tuple2))| +-----------------------------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------------+ | 11.0| 2.75| +-----------------------------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------------+ ``` Output after change: ```SQL +-----------------+---------------+ |typedsumdouble_c1|typedaverage_c2| +-----------------+---------------+ | 11.0| 2.75| +-----------------+---------------+ ``` Note: There is one test in ParquetSuites.scala which shows that the system-picked alias name is not usable and is rejected. [test](https://github.com/apache/spark/blob/master/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala#L672-#L687) ## How was this patch tested? A new test was added in DataSetAggregatorSuite. 
Author: Dilip Biswal <dbiswal@us.ibm.com> Closes #13045 from dilipbiswal/spark-15114.
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala6
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala7
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/Column.scala16
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala3
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala12
5 files changed, 39 insertions, 5 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 2c269478ee..9a92330f75 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -177,14 +177,16 @@ class Analyzer(
private def assignAliases(exprs: Seq[NamedExpression]) = {
exprs.zipWithIndex.map {
case (expr, i) =>
- expr.transformUp { case u @ UnresolvedAlias(child, optionalAliasName) =>
+ expr.transformUp { case u @ UnresolvedAlias(child, optGenAliasFunc) =>
child match {
case ne: NamedExpression => ne
case e if !e.resolved => u
case g: Generator => MultiAlias(g, Nil)
case c @ Cast(ne: NamedExpression, _) => Alias(c, ne.name)()
case e: ExtractValue => Alias(e, toPrettySQL(e))()
- case e => Alias(e, optionalAliasName.getOrElse(toPrettySQL(e)))()
+ case e if optGenAliasFunc.isDefined =>
+ Alias(child, optGenAliasFunc.get.apply(e))()
+ case e => Alias(e, toPrettySQL(e))()
}
}
}.asInstanceOf[Seq[NamedExpression]]
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
index 1f1897dc36..e953eda784 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
@@ -325,10 +325,13 @@ case class UnresolvedExtractValue(child: Expression, extraction: Expression)
* Holds the expression that has yet to be aliased.
*
* @param child The computation that is needs to be resolved during analysis.
- * @param aliasName The name if specified to be associated with the result of computing [[child]]
+ * @param aliasFunc The function if specified to be called to generate an alias to associate
+ * with the result of computing [[child]]
*
*/
-case class UnresolvedAlias(child: Expression, aliasName: Option[String] = None)
+case class UnresolvedAlias(
+ child: Expression,
+ aliasFunc: Option[Expression => String] = None)
extends UnaryExpression with NamedExpression with Unevaluable {
override def toAttribute: Attribute = throw new UnresolvedException(this, "toAttribute")
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
index 9b8334d334..204af719b2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
@@ -24,6 +24,7 @@ import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder}
import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
import org.apache.spark.sql.catalyst.util.usePrettyExpression
import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression
@@ -37,6 +38,14 @@ private[sql] object Column {
def apply(expr: Expression): Column = new Column(expr)
def unapply(col: Column): Option[Expression] = Some(col.expr)
+
+ private[sql] def generateAlias(e: Expression): String = {
+ e match {
+ case a: AggregateExpression if a.aggregateFunction.isInstanceOf[TypedAggregateExpression] =>
+ a.aggregateFunction.toString
+ case expr => usePrettyExpression(expr).sql
+ }
+ }
}
/**
@@ -145,7 +154,7 @@ class Column(protected[sql] val expr: Expression) extends Logging {
case jt: JsonTuple => MultiAlias(jt, Nil)
- case func: UnresolvedFunction => UnresolvedAlias(func, Some(usePrettyExpression(func).sql))
+ case func: UnresolvedFunction => UnresolvedAlias(func, Some(Column.generateAlias))
// If we have a top level Cast, there is a chance to give it a better alias, if there is a
// NamedExpression under this Cast.
@@ -156,9 +165,14 @@ class Column(protected[sql] val expr: Expression) extends Logging {
case other => Alias(expr, usePrettyExpression(expr).sql)()
}
+ case a: AggregateExpression if a.aggregateFunction.isInstanceOf[TypedAggregateExpression] =>
+ UnresolvedAlias(a, Some(Column.generateAlias))
+
case expr: Expression => Alias(expr, usePrettyExpression(expr).sql)()
}
+
+
override def toString: String = usePrettyExpression(expr).sql
override def equals(that: Any): Boolean = that match {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
index 4f5bf633fa..b0e48a6553 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
@@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Pivot}
import org.apache.spark.sql.catalyst.util.usePrettyExpression
+import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.NumericType
@@ -73,6 +74,8 @@ class RelationalGroupedDataset protected[sql](
private[this] def alias(expr: Expression): NamedExpression = expr match {
case u: UnresolvedAttribute => UnresolvedAlias(u)
case expr: NamedExpression => expr
+ case a: AggregateExpression if (a.aggregateFunction.isInstanceOf[TypedAggregateExpression]) =>
+ UnresolvedAlias(a, Some(Column.generateAlias))
case expr: Expression => Alias(expr, usePrettyExpression(expr).sql)()
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala
index f1585ca3ff..ead7bd9642 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala
@@ -240,4 +240,16 @@ class DatasetAggregatorSuite extends QueryTest with SharedSQLContext {
val df2 = Seq(1 -> "a", 2 -> "b", 3 -> "b").toDF("i", "j")
checkAnswer(df2.agg(RowAgg.toColumn as "b").select("b"), Row(6) :: Nil)
}
+
+ test("spark-15114 shorter system generated alias names") {
+ val ds = Seq(1, 3, 2, 5).toDS()
+ assert(ds.select(typed.sum((i: Int) => i)).columns.head === "TypedSumDouble(int)")
+ val ds2 = ds.select(typed.sum((i: Int) => i), typed.avg((i: Int) => i))
+ assert(ds2.columns.head === "TypedSumDouble(int)")
+ assert(ds2.columns.last === "TypedAverage(int)")
+ val df = Seq(1 -> "a", 2 -> "b", 3 -> "b").toDF("i", "j")
+ assert(df.groupBy($"j").agg(RowAgg.toColumn).columns.last ==
+ "RowAgg(org.apache.spark.sql.Row)")
+ assert(df.groupBy($"j").agg(RowAgg.toColumn as "agg1").columns.last == "agg1")
+ }
}