Diffstat (limited to 'sql/core')
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala          | 154
-rw-r--r--  sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java |  16
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala   |  21
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala      |   1
4 files changed, 116 insertions, 76 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala
index 63dd7fbcbe..ee7150cbbf 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala
@@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.analysis.{UnresolvedFunction, UnresolvedAli
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.plans.logical.{Pivot, Rollup, Cube, Aggregate}
-import org.apache.spark.sql.types.{StringType, NumericType}
+import org.apache.spark.sql.types.NumericType
/**
@@ -282,74 +282,96 @@ class GroupedData protected[sql](
}
/**
- * (Scala-specific) Pivots a column of the current [[DataFrame]] and preform the specified
- * aggregation.
- * {{{
- * // Compute the sum of earnings for each year by course with each course as a separate column
- * df.groupBy($"year").pivot($"course", "dotNET", "Java").agg(sum($"earnings"))
- * // Or without specifying column values
- * df.groupBy($"year").pivot($"course").agg(sum($"earnings"))
- * }}}
- * @param pivotColumn Column to pivot
- * @param values Optional list of values of pivotColumn that will be translated to columns in the
- * output data frame. If values are not provided the method with do an immediate
- * call to .distinct() on the pivot column.
- * @since 1.6.0
- */
- @scala.annotation.varargs
- def pivot(pivotColumn: Column, values: Column*): GroupedData = groupType match {
- case _: GroupedData.PivotType =>
- throw new UnsupportedOperationException("repeated pivots are not supported")
- case GroupedData.GroupByType =>
- val pivotValues = if (values.nonEmpty) {
- values.map {
- case Column(literal: Literal) => literal
- case other =>
- throw new UnsupportedOperationException(
- s"The values of a pivot must be literals, found $other")
- }
- } else {
- // This is to prevent unintended OOM errors when the number of distinct values is large
- val maxValues = df.sqlContext.conf.getConf(SQLConf.DATAFRAME_PIVOT_MAX_VALUES)
- // Get the distinct values of the column and sort them so its consistent
- val values = df.select(pivotColumn)
- .distinct()
- .sort(pivotColumn)
- .map(_.get(0))
- .take(maxValues + 1)
- .map(Literal(_)).toSeq
- if (values.length > maxValues) {
- throw new RuntimeException(
- s"The pivot column $pivotColumn has more than $maxValues distinct values, " +
- "this could indicate an error. " +
- "If this was intended, set \"" + SQLConf.DATAFRAME_PIVOT_MAX_VALUES.key + "\" " +
- s"to at least the number of distinct values of the pivot column.")
- }
- values
- }
- new GroupedData(df, groupingExprs, GroupedData.PivotType(pivotColumn.expr, pivotValues))
- case _ =>
- throw new UnsupportedOperationException("pivot is only supported after a groupBy")
+ * Pivots a column of the current [[DataFrame]] and performs the specified aggregation.
+ * There are two versions of the pivot function: one that requires the caller to specify the list
+ * of distinct values to pivot on, and one that does not. The latter is more concise but less
+ * efficient, because Spark needs to first compute the list of distinct values internally.
+ *
+ * {{{
+ * // Compute the sum of earnings for each year by course with each course as a separate column
+ * df.groupBy("year").pivot("course", Seq("dotNET", "Java")).sum("earnings")
+ *
+ * // Or without specifying column values (less efficient)
+ * df.groupBy("year").pivot("course").sum("earnings")
+ * }}}
+ *
+ * @param pivotColumn Name of the column to pivot.
+ * @since 1.6.0
+ */
+ def pivot(pivotColumn: String): GroupedData = {
+ // This is to prevent unintended OOM errors when the number of distinct values is large
+ val maxValues = df.sqlContext.conf.getConf(SQLConf.DATAFRAME_PIVOT_MAX_VALUES)
+ // Get the distinct values of the column and sort them so the result is deterministic
+ val values = df.select(pivotColumn)
+ .distinct()
+ .sort(pivotColumn)
+ .map(_.get(0))
+ .take(maxValues + 1)
+ .toSeq
+
+ if (values.length > maxValues) {
+ throw new AnalysisException(
+ s"The pivot column $pivotColumn has more than $maxValues distinct values, " +
+ "this could indicate an error. " +
+ s"If this was intended, set ${SQLConf.DATAFRAME_PIVOT_MAX_VALUES.key} " +
+ "to at least the number of distinct values of the pivot column.")
+ }
+
+ pivot(pivotColumn, values)
}
/**
- * Pivots a column of the current [[DataFrame]] and preform the specified aggregation.
- * {{{
- * // Compute the sum of earnings for each year by course with each course as a separate column
- * df.groupBy("year").pivot("course", "dotNET", "Java").sum("earnings")
- * // Or without specifying column values
- * df.groupBy("year").pivot("course").sum("earnings")
- * }}}
- * @param pivotColumn Column to pivot
- * @param values Optional list of values of pivotColumn that will be translated to columns in the
- * output data frame. If values are not provided the method with do an immediate
- * call to .distinct() on the pivot column.
- * @since 1.6.0
- */
- @scala.annotation.varargs
- def pivot(pivotColumn: String, values: Any*): GroupedData = {
- val resolvedPivotColumn = Column(df.resolve(pivotColumn))
- pivot(resolvedPivotColumn, values.map(functions.lit): _*)
+ * Pivots a column of the current [[DataFrame]] and performs the specified aggregation.
+ * There are two versions of the pivot function: one that requires the caller to specify the list
+ * of distinct values to pivot on, and one that does not. The latter is more concise but less
+ * efficient, because Spark needs to first compute the list of distinct values internally.
+ *
+ * {{{
+ * // Compute the sum of earnings for each year by course with each course as a separate column
+ * df.groupBy("year").pivot("course", Seq("dotNET", "Java")).sum("earnings")
+ *
+ * // Or without specifying column values (less efficient)
+ * df.groupBy("year").pivot("course").sum("earnings")
+ * }}}
+ *
+ * @param pivotColumn Name of the column to pivot.
+ * @param values List of values that will be translated to columns in the output DataFrame.
+ * @since 1.6.0
+ */
+ def pivot(pivotColumn: String, values: Seq[Any]): GroupedData = {
+ groupType match {
+ case GroupedData.GroupByType =>
+ new GroupedData(
+ df,
+ groupingExprs,
+ GroupedData.PivotType(df.resolve(pivotColumn), values.map(Literal.apply)))
+ case _: GroupedData.PivotType =>
+ throw new UnsupportedOperationException("repeated pivots are not supported")
+ case _ =>
+ throw new UnsupportedOperationException("pivot is only supported after a groupBy")
+ }
+ }
+
+ /**
+ * Pivots a column of the current [[DataFrame]] and performs the specified aggregation.
+ * There are two versions of the pivot function: one that requires the caller to specify the list
+ * of distinct values to pivot on, and one that does not. The latter is more concise but less
+ * efficient, because Spark needs to first compute the list of distinct values internally.
+ *
+ * {{{
+ * // Compute the sum of earnings for each year by course with each course as a separate column
+ * df.groupBy("year").pivot("course", Arrays.<Object>asList("dotNET", "Java")).sum("earnings");
+ *
+ * // Or without specifying column values (less efficient)
+ * df.groupBy("year").pivot("course").sum("earnings");
+ * }}}
+ *
+ * @param pivotColumn Name of the column to pivot.
+ * @param values List of values that will be translated to columns in the output DataFrame.
+ * @since 1.6.0
+ */
+ def pivot(pivotColumn: String, values: java.util.List[Any]): GroupedData = {
+ pivot(pivotColumn, values.asScala)
}
}
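
Taken together, the patch replaces the Column-varargs overloads with a String column name plus an explicit value list, and splits the no-values case into its own overload. A minimal sketch of the resulting call chain, assuming a Spark 1.6-era sqlContext; the data and the names df, year, course, earnings are illustrative only:

    import org.apache.spark.sql.functions.sum

    // Hypothetical data; any (year, course, earnings) shape works.
    val df = sqlContext.createDataFrame(Seq(
      (2012, "dotNET", 10000.0),
      (2012, "Java", 20000.0),
      (2013, "dotNET", 48000.0)
    )).toDF("year", "course", "earnings")

    // Explicit values: no extra job runs, and the output column order is fixed.
    df.groupBy("year").pivot("course", Seq("dotNET", "Java")).agg(sum("earnings"))

    // No values: pivot(pivotColumn) first runs a distinct().sort().take() job
    // to discover the values, then delegates to pivot(pivotColumn, values).
    df.groupBy("year").pivot("course").agg(sum("earnings"))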
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
index 567bdddece..a12fed3c0c 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
@@ -282,4 +282,20 @@ public class JavaDataFrameSuite {
Assert.assertEquals(1, actual[1].getLong(0));
Assert.assertTrue(2 <= actual[1].getLong(1) && actual[1].getLong(1) <= 13);
}
+
+ @Test
+ public void pivot() {
+ DataFrame df = context.table("courseSales");
+ Row[] actual = df.groupBy("year")
+ .pivot("course", Arrays.<Object>asList("dotNET", "Java"))
+ .agg(sum("earnings")).orderBy("year").collect();
+
+ Assert.assertEquals(2012, actual[0].getInt(0));
+ Assert.assertEquals(15000.0, actual[0].getDouble(1), 0.01);
+ Assert.assertEquals(20000.0, actual[0].getDouble(2), 0.01);
+
+ Assert.assertEquals(2013, actual[1].getInt(0));
+ Assert.assertEquals(48000.0, actual[1].getDouble(1), 0.01);
+ Assert.assertEquals(30000.0, actual[1].getDouble(2), 0.01);
+ }
}
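
The positional assertions above rely on the output layout: the grouping column first, then one column per supplied pivot value, in list order. A small Scala sketch of the same query (hypothetical session; assumes the courseSales table is registered, and that a single aggregation keeps the bare value names, as the positional indexing suggests):

    import org.apache.spark.sql.functions.sum

    val pivoted = sqlContext.table("courseSales")
      .groupBy("year")
      .pivot("course", Seq("dotNET", "Java"))
      .agg(sum("earnings"))

    // Expected layout: Array("year", "dotNET", "Java")
    println(pivoted.columns.mkString(", "))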
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala
index 0c23d14267..fc53aba68e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala
@@ -25,7 +25,7 @@ class DataFramePivotSuite extends QueryTest with SharedSQLContext{
test("pivot courses with literals") {
checkAnswer(
- courseSales.groupBy($"year").pivot($"course", lit("dotNET"), lit("Java"))
+ courseSales.groupBy("year").pivot("course", Seq("dotNET", "Java"))
.agg(sum($"earnings")),
Row(2012, 15000.0, 20000.0) :: Row(2013, 48000.0, 30000.0) :: Nil
)
@@ -33,14 +33,15 @@ class DataFramePivotSuite extends QueryTest with SharedSQLContext{
test("pivot year with literals") {
checkAnswer(
- courseSales.groupBy($"course").pivot($"year", lit(2012), lit(2013)).agg(sum($"earnings")),
+ courseSales.groupBy("course").pivot("year", Seq(2012, 2013)).agg(sum($"earnings")),
Row("dotNET", 15000.0, 48000.0) :: Row("Java", 20000.0, 30000.0) :: Nil
)
}
test("pivot courses with literals and multiple aggregations") {
checkAnswer(
- courseSales.groupBy($"year").pivot($"course", lit("dotNET"), lit("Java"))
+ courseSales.groupBy($"year")
+ .pivot("course", Seq("dotNET", "Java"))
.agg(sum($"earnings"), avg($"earnings")),
Row(2012, 15000.0, 7500.0, 20000.0, 20000.0) ::
Row(2013, 48000.0, 48000.0, 30000.0, 30000.0) :: Nil
@@ -49,14 +50,14 @@ class DataFramePivotSuite extends QueryTest with SharedSQLContext{
test("pivot year with string values (cast)") {
checkAnswer(
- courseSales.groupBy("course").pivot("year", "2012", "2013").sum("earnings"),
+ courseSales.groupBy("course").pivot("year", Seq("2012", "2013")).sum("earnings"),
Row("dotNET", 15000.0, 48000.0) :: Row("Java", 20000.0, 30000.0) :: Nil
)
}
test("pivot year with int values") {
checkAnswer(
- courseSales.groupBy("course").pivot("year", 2012, 2013).sum("earnings"),
+ courseSales.groupBy("course").pivot("year", Seq(2012, 2013)).sum("earnings"),
Row("dotNET", 15000.0, 48000.0) :: Row("Java", 20000.0, 30000.0) :: Nil
)
}
@@ -64,22 +65,22 @@ class DataFramePivotSuite extends QueryTest with SharedSQLContext{
test("pivot courses with no values") {
// Note Java comes before dotNet in sorted order
checkAnswer(
- courseSales.groupBy($"year").pivot($"course").agg(sum($"earnings")),
+ courseSales.groupBy("year").pivot("course").agg(sum($"earnings")),
Row(2012, 20000.0, 15000.0) :: Row(2013, 30000.0, 48000.0) :: Nil
)
}
test("pivot year with no values") {
checkAnswer(
- courseSales.groupBy($"course").pivot($"year").agg(sum($"earnings")),
+ courseSales.groupBy("course").pivot("year").agg(sum($"earnings")),
Row("dotNET", 15000.0, 48000.0) :: Row("Java", 20000.0, 30000.0) :: Nil
)
}
- test("pivot max values inforced") {
+ test("pivot max values enforced") {
sqlContext.conf.setConf(SQLConf.DATAFRAME_PIVOT_MAX_VALUES, 1)
- intercept[RuntimeException](
- courseSales.groupBy($"year").pivot($"course")
+ intercept[AnalysisException](
+ courseSales.groupBy("year").pivot("course")
)
sqlContext.conf.setConf(SQLConf.DATAFRAME_PIVOT_MAX_VALUES,
SQLConf.DATAFRAME_PIVOT_MAX_VALUES.defaultValue.get)
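
The renamed test pins down the guard added in pivot(pivotColumn: String): the distinct-value scan is eager, so the failure surfaces when pivot() is called, before any aggregation is specified. A hedged sketch of that behavior (assumes the patched API and the courseSales fixture; imports mirror the suite above):

    import org.apache.spark.sql.{AnalysisException, SQLConf}

    // Lower the cap below the two distinct courses, then pivot without values.
    sqlContext.conf.setConf(SQLConf.DATAFRAME_PIVOT_MAX_VALUES, 1)
    try {
      courseSales.groupBy("year").pivot("course") // throws AnalysisException here
    } catch {
      case e: AnalysisException =>
        println(e.getMessage) // names the SQLConf key and suggests raising it
    } finally {
      sqlContext.conf.setConf(SQLConf.DATAFRAME_PIVOT_MAX_VALUES,
        SQLConf.DATAFRAME_PIVOT_MAX_VALUES.defaultValue.get)
    }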
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala
index abad0d7eaa..83c63e04f3 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala
@@ -281,6 +281,7 @@ private[sql] trait SQLTestData { self =>
person
salary
complexData
+ courseSales
}
}
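
For reference, a courseSales fixture consistent with every expected Row in the suites can be inferred from the assertions (the 2012 dotNET average of 7500 against a sum of 15000 forces two rows). This is a reconstruction, not necessarily the exact fixture in SQLTestData:

    // Inferred sketch of the courseSales test data.
    case class CourseSales(course: String, year: Int, earnings: Double)

    val courseSales = sqlContext.createDataFrame(Seq(
      CourseSales("dotNET", 2012, 10000),
      CourseSales("Java", 2012, 20000),
      CourseSales("dotNET", 2012, 5000), // dotNET 2012: sum 15000, avg 7500
      CourseSales("dotNET", 2013, 48000),
      CourseSales("Java", 2013, 30000)
    ))
    courseSales.registerTempTable("courseSales")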