aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorReynold Xin <rxin@databricks.com>2015-11-24 12:54:37 -0800
committerReynold Xin <rxin@databricks.com>2015-11-24 12:54:37 -0800
commitf3152722791b163fa66597b3684009058195ba33 (patch)
treec3b23ead8eaea14229f5b21642d2a4e8c5c191b2
parent81012546ee5a80d2576740af0dad067b0f5962c5 (diff)
downloadspark-f3152722791b163fa66597b3684009058195ba33.tar.gz
spark-f3152722791b163fa66597b3684009058195ba33.tar.bz2
spark-f3152722791b163fa66597b3684009058195ba33.zip
[SPARK-11946][SQL] Audit pivot API for 1.6.
Currently pivot's signature looks like ```scala scala.annotation.varargs def pivot(pivotColumn: Column, values: Column*): GroupedData scala.annotation.varargs def pivot(pivotColumn: String, values: Any*): GroupedData ``` I think we can remove the one that takes "Column" types, since callers should always be passing in literals. It'd also be more clear if the values are not varargs, but rather Seq or java.util.List. I also made similar changes for Python. Author: Reynold Xin <rxin@databricks.com> Closes #9929 from rxin/SPARK-11946.
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala1
-rw-r--r--python/pyspark/sql/group.py12
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala1
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala154
-rw-r--r--sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java16
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala21
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala1
7 files changed, 125 insertions, 81 deletions
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index ae725b467d..77a184dfe4 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -1574,7 +1574,6 @@ class DAGScheduler(
}
def stop() {
- logInfo("Stopping DAGScheduler")
messageScheduler.shutdownNow()
eventProcessLoop.stop()
taskScheduler.stop()
diff --git a/python/pyspark/sql/group.py b/python/pyspark/sql/group.py
index 227f40bc3c..d8ed7eb2dd 100644
--- a/python/pyspark/sql/group.py
+++ b/python/pyspark/sql/group.py
@@ -168,20 +168,24 @@ class GroupedData(object):
"""
@since(1.6)
- def pivot(self, pivot_col, *values):
+ def pivot(self, pivot_col, values=None):
"""Pivots a column of the current DataFrame and preform the specified aggregation.
:param pivot_col: Column to pivot
:param values: Optional list of values of pivotColumn that will be translated to columns in
the output data frame. If values are not provided the method with do an immediate call
to .distinct() on the pivot column.
- >>> df4.groupBy("year").pivot("course", "dotNET", "Java").sum("earnings").collect()
+
+ >>> df4.groupBy("year").pivot("course", ["dotNET", "Java"]).sum("earnings").collect()
[Row(year=2012, dotNET=15000, Java=20000), Row(year=2013, dotNET=48000, Java=30000)]
+
>>> df4.groupBy("year").pivot("course").sum("earnings").collect()
[Row(year=2012, Java=20000, dotNET=15000), Row(year=2013, Java=30000, dotNET=48000)]
"""
- jgd = self._jdf.pivot(_to_java_column(pivot_col),
- _to_seq(self.sql_ctx._sc, values, _create_column_from_literal))
+ if values is None:
+ jgd = self._jdf.pivot(pivot_col)
+ else:
+ jgd = self._jdf.pivot(pivot_col, values)
return GroupedData(jgd, self.sql_ctx)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
index e34fd49be8..68ec688c99 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
@@ -44,6 +44,7 @@ object Literal {
case a: Array[Byte] => Literal(a, BinaryType)
case i: CalendarInterval => Literal(i, CalendarIntervalType)
case null => Literal(null, NullType)
+ case v: Literal => v
case _ =>
throw new RuntimeException("Unsupported literal type " + v.getClass + " " + v)
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala
index 63dd7fbcbe..ee7150cbbf 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala
@@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.analysis.{UnresolvedFunction, UnresolvedAli
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.plans.logical.{Pivot, Rollup, Cube, Aggregate}
-import org.apache.spark.sql.types.{StringType, NumericType}
+import org.apache.spark.sql.types.NumericType
/**
@@ -282,74 +282,96 @@ class GroupedData protected[sql](
}
/**
- * (Scala-specific) Pivots a column of the current [[DataFrame]] and preform the specified
- * aggregation.
- * {{{
- * // Compute the sum of earnings for each year by course with each course as a separate column
- * df.groupBy($"year").pivot($"course", "dotNET", "Java").agg(sum($"earnings"))
- * // Or without specifying column values
- * df.groupBy($"year").pivot($"course").agg(sum($"earnings"))
- * }}}
- * @param pivotColumn Column to pivot
- * @param values Optional list of values of pivotColumn that will be translated to columns in the
- * output data frame. If values are not provided the method with do an immediate
- * call to .distinct() on the pivot column.
- * @since 1.6.0
- */
- @scala.annotation.varargs
- def pivot(pivotColumn: Column, values: Column*): GroupedData = groupType match {
- case _: GroupedData.PivotType =>
- throw new UnsupportedOperationException("repeated pivots are not supported")
- case GroupedData.GroupByType =>
- val pivotValues = if (values.nonEmpty) {
- values.map {
- case Column(literal: Literal) => literal
- case other =>
- throw new UnsupportedOperationException(
- s"The values of a pivot must be literals, found $other")
- }
- } else {
- // This is to prevent unintended OOM errors when the number of distinct values is large
- val maxValues = df.sqlContext.conf.getConf(SQLConf.DATAFRAME_PIVOT_MAX_VALUES)
- // Get the distinct values of the column and sort them so its consistent
- val values = df.select(pivotColumn)
- .distinct()
- .sort(pivotColumn)
- .map(_.get(0))
- .take(maxValues + 1)
- .map(Literal(_)).toSeq
- if (values.length > maxValues) {
- throw new RuntimeException(
- s"The pivot column $pivotColumn has more than $maxValues distinct values, " +
- "this could indicate an error. " +
- "If this was intended, set \"" + SQLConf.DATAFRAME_PIVOT_MAX_VALUES.key + "\" " +
- s"to at least the number of distinct values of the pivot column.")
- }
- values
- }
- new GroupedData(df, groupingExprs, GroupedData.PivotType(pivotColumn.expr, pivotValues))
- case _ =>
- throw new UnsupportedOperationException("pivot is only supported after a groupBy")
+ * Pivots a column of the current [[DataFrame]] and preform the specified aggregation.
+ * There are two versions of pivot function: one that requires the caller to specify the list
+ * of distinct values to pivot on, and one that does not. The latter is more concise but less
+ * efficient, because Spark needs to first compute the list of distinct values internally.
+ *
+ * {{{
+ * // Compute the sum of earnings for each year by course with each course as a separate column
+ * df.groupBy("year").pivot("course", Seq("dotNET", "Java")).sum("earnings")
+ *
+ * // Or without specifying column values (less efficient)
+ * df.groupBy("year").pivot("course").sum("earnings")
+ * }}}
+ *
+ * @param pivotColumn Name of the column to pivot.
+ * @since 1.6.0
+ */
+ def pivot(pivotColumn: String): GroupedData = {
+ // This is to prevent unintended OOM errors when the number of distinct values is large
+ val maxValues = df.sqlContext.conf.getConf(SQLConf.DATAFRAME_PIVOT_MAX_VALUES)
+ // Get the distinct values of the column and sort them so its consistent
+ val values = df.select(pivotColumn)
+ .distinct()
+ .sort(pivotColumn)
+ .map(_.get(0))
+ .take(maxValues + 1)
+ .toSeq
+
+ if (values.length > maxValues) {
+ throw new AnalysisException(
+ s"The pivot column $pivotColumn has more than $maxValues distinct values, " +
+ "this could indicate an error. " +
+ s"If this was intended, set ${SQLConf.DATAFRAME_PIVOT_MAX_VALUES.key} " +
+ "to at least the number of distinct values of the pivot column.")
+ }
+
+ pivot(pivotColumn, values)
}
/**
- * Pivots a column of the current [[DataFrame]] and preform the specified aggregation.
- * {{{
- * // Compute the sum of earnings for each year by course with each course as a separate column
- * df.groupBy("year").pivot("course", "dotNET", "Java").sum("earnings")
- * // Or without specifying column values
- * df.groupBy("year").pivot("course").sum("earnings")
- * }}}
- * @param pivotColumn Column to pivot
- * @param values Optional list of values of pivotColumn that will be translated to columns in the
- * output data frame. If values are not provided the method with do an immediate
- * call to .distinct() on the pivot column.
- * @since 1.6.0
- */
- @scala.annotation.varargs
- def pivot(pivotColumn: String, values: Any*): GroupedData = {
- val resolvedPivotColumn = Column(df.resolve(pivotColumn))
- pivot(resolvedPivotColumn, values.map(functions.lit): _*)
+ * Pivots a column of the current [[DataFrame]] and preform the specified aggregation.
+ * There are two versions of pivot function: one that requires the caller to specify the list
+ * of distinct values to pivot on, and one that does not. The latter is more concise but less
+ * efficient, because Spark needs to first compute the list of distinct values internally.
+ *
+ * {{{
+ * // Compute the sum of earnings for each year by course with each course as a separate column
+ * df.groupBy("year").pivot("course", Seq("dotNET", "Java")).sum("earnings")
+ *
+ * // Or without specifying column values (less efficient)
+ * df.groupBy("year").pivot("course").sum("earnings")
+ * }}}
+ *
+ * @param pivotColumn Name of the column to pivot.
+ * @param values List of values that will be translated to columns in the output DataFrame.
+ * @since 1.6.0
+ */
+ def pivot(pivotColumn: String, values: Seq[Any]): GroupedData = {
+ groupType match {
+ case GroupedData.GroupByType =>
+ new GroupedData(
+ df,
+ groupingExprs,
+ GroupedData.PivotType(df.resolve(pivotColumn), values.map(Literal.apply)))
+ case _: GroupedData.PivotType =>
+ throw new UnsupportedOperationException("repeated pivots are not supported")
+ case _ =>
+ throw new UnsupportedOperationException("pivot is only supported after a groupBy")
+ }
+ }
+
+ /**
+ * Pivots a column of the current [[DataFrame]] and preform the specified aggregation.
+ * There are two versions of pivot function: one that requires the caller to specify the list
+ * of distinct values to pivot on, and one that does not. The latter is more concise but less
+ * efficient, because Spark needs to first compute the list of distinct values internally.
+ *
+ * {{{
+ * // Compute the sum of earnings for each year by course with each course as a separate column
+ * df.groupBy("year").pivot("course", Arrays.<Object>asList("dotNET", "Java")).sum("earnings");
+ *
+ * // Or without specifying column values (less efficient)
+ * df.groupBy("year").pivot("course").sum("earnings");
+ * }}}
+ *
+ * @param pivotColumn Name of the column to pivot.
+ * @param values List of values that will be translated to columns in the output DataFrame.
+ * @since 1.6.0
+ */
+ def pivot(pivotColumn: String, values: java.util.List[Any]): GroupedData = {
+ pivot(pivotColumn, values.asScala)
}
}
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
index 567bdddece..a12fed3c0c 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
@@ -282,4 +282,20 @@ public class JavaDataFrameSuite {
Assert.assertEquals(1, actual[1].getLong(0));
Assert.assertTrue(2 <= actual[1].getLong(1) && actual[1].getLong(1) <= 13);
}
+
+ @Test
+ public void pivot() {
+ DataFrame df = context.table("courseSales");
+ Row[] actual = df.groupBy("year")
+ .pivot("course", Arrays.<Object>asList("dotNET", "Java"))
+ .agg(sum("earnings")).orderBy("year").collect();
+
+ Assert.assertEquals(2012, actual[0].getInt(0));
+ Assert.assertEquals(15000.0, actual[0].getDouble(1), 0.01);
+ Assert.assertEquals(20000.0, actual[0].getDouble(2), 0.01);
+
+ Assert.assertEquals(2013, actual[1].getInt(0));
+ Assert.assertEquals(48000.0, actual[1].getDouble(1), 0.01);
+ Assert.assertEquals(30000.0, actual[1].getDouble(2), 0.01);
+ }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala
index 0c23d14267..fc53aba68e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala
@@ -25,7 +25,7 @@ class DataFramePivotSuite extends QueryTest with SharedSQLContext{
test("pivot courses with literals") {
checkAnswer(
- courseSales.groupBy($"year").pivot($"course", lit("dotNET"), lit("Java"))
+ courseSales.groupBy("year").pivot("course", Seq("dotNET", "Java"))
.agg(sum($"earnings")),
Row(2012, 15000.0, 20000.0) :: Row(2013, 48000.0, 30000.0) :: Nil
)
@@ -33,14 +33,15 @@ class DataFramePivotSuite extends QueryTest with SharedSQLContext{
test("pivot year with literals") {
checkAnswer(
- courseSales.groupBy($"course").pivot($"year", lit(2012), lit(2013)).agg(sum($"earnings")),
+ courseSales.groupBy("course").pivot("year", Seq(2012, 2013)).agg(sum($"earnings")),
Row("dotNET", 15000.0, 48000.0) :: Row("Java", 20000.0, 30000.0) :: Nil
)
}
test("pivot courses with literals and multiple aggregations") {
checkAnswer(
- courseSales.groupBy($"year").pivot($"course", lit("dotNET"), lit("Java"))
+ courseSales.groupBy($"year")
+ .pivot("course", Seq("dotNET", "Java"))
.agg(sum($"earnings"), avg($"earnings")),
Row(2012, 15000.0, 7500.0, 20000.0, 20000.0) ::
Row(2013, 48000.0, 48000.0, 30000.0, 30000.0) :: Nil
@@ -49,14 +50,14 @@ class DataFramePivotSuite extends QueryTest with SharedSQLContext{
test("pivot year with string values (cast)") {
checkAnswer(
- courseSales.groupBy("course").pivot("year", "2012", "2013").sum("earnings"),
+ courseSales.groupBy("course").pivot("year", Seq("2012", "2013")).sum("earnings"),
Row("dotNET", 15000.0, 48000.0) :: Row("Java", 20000.0, 30000.0) :: Nil
)
}
test("pivot year with int values") {
checkAnswer(
- courseSales.groupBy("course").pivot("year", 2012, 2013).sum("earnings"),
+ courseSales.groupBy("course").pivot("year", Seq(2012, 2013)).sum("earnings"),
Row("dotNET", 15000.0, 48000.0) :: Row("Java", 20000.0, 30000.0) :: Nil
)
}
@@ -64,22 +65,22 @@ class DataFramePivotSuite extends QueryTest with SharedSQLContext{
test("pivot courses with no values") {
// Note Java comes before dotNet in sorted order
checkAnswer(
- courseSales.groupBy($"year").pivot($"course").agg(sum($"earnings")),
+ courseSales.groupBy("year").pivot("course").agg(sum($"earnings")),
Row(2012, 20000.0, 15000.0) :: Row(2013, 30000.0, 48000.0) :: Nil
)
}
test("pivot year with no values") {
checkAnswer(
- courseSales.groupBy($"course").pivot($"year").agg(sum($"earnings")),
+ courseSales.groupBy("course").pivot("year").agg(sum($"earnings")),
Row("dotNET", 15000.0, 48000.0) :: Row("Java", 20000.0, 30000.0) :: Nil
)
}
- test("pivot max values inforced") {
+ test("pivot max values enforced") {
sqlContext.conf.setConf(SQLConf.DATAFRAME_PIVOT_MAX_VALUES, 1)
- intercept[RuntimeException](
- courseSales.groupBy($"year").pivot($"course")
+ intercept[AnalysisException](
+ courseSales.groupBy("year").pivot("course")
)
sqlContext.conf.setConf(SQLConf.DATAFRAME_PIVOT_MAX_VALUES,
SQLConf.DATAFRAME_PIVOT_MAX_VALUES.defaultValue.get)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala
index abad0d7eaa..83c63e04f3 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala
@@ -281,6 +281,7 @@ private[sql] trait SQLTestData { self =>
person
salary
complexData
+ courseSales
}
}