author    Dongjoon Hyun <dongjoon@apache.org>  2016-06-09 22:46:51 -0700
committer Reynold Xin <rxin@databricks.com>    2016-06-09 22:46:51 -0700
commit    5a3533e779d8e43ce0980203dfd3cbe343cc7d0a
tree      77c967b1db3c2afe2cb140e76807ef4883854441
parent    6c5fd977fbcb821a57cb4a13bc3d413a695fbc32
[SPARK-15696][SQL] Improve `crosstab` to have a consistent column order
## What changes were proposed in this pull request?

Currently, `crosstab` returns a DataFrame whose columns are in **random order**, obtained by just `distinct`. The documentation of `crosstab`, however, shows the result in sorted order, which differs from the current implementation. This PR explicitly constructs the columns in sorted order to improve the user experience and to make the result match the documentation.

**Before**
```scala
scala> spark.createDataFrame(Seq((1, 1), (1, 2), (2, 1), (2, 1), (2, 3), (3, 2), (3, 3))).toDF("key", "value").stat.crosstab("key", "value").show()
+---------+---+---+---+
|key_value|  3|  2|  1|
+---------+---+---+---+
|        2|  1|  0|  2|
|        1|  0|  1|  1|
|        3|  1|  1|  0|
+---------+---+---+---+

scala> spark.createDataFrame(Seq((1, "a"), (1, "b"), (2, "a"), (2, "a"), (2, "c"), (3, "b"), (3, "c"))).toDF("key", "value").stat.crosstab("key", "value").show()
+---------+---+---+---+
|key_value|  c|  a|  b|
+---------+---+---+---+
|        2|  1|  2|  0|
|        1|  0|  1|  1|
|        3|  1|  0|  1|
+---------+---+---+---+
```

**After**
```scala
scala> spark.createDataFrame(Seq((1, 1), (1, 2), (2, 1), (2, 1), (2, 3), (3, 2), (3, 3))).toDF("key", "value").stat.crosstab("key", "value").show()
+---------+---+---+---+
|key_value|  1|  2|  3|
+---------+---+---+---+
|        2|  2|  0|  1|
|        1|  1|  1|  0|
|        3|  0|  1|  1|
+---------+---+---+---+

scala> spark.createDataFrame(Seq((1, "a"), (1, "b"), (2, "a"), (2, "a"), (2, "c"), (3, "b"), (3, "c"))).toDF("key", "value").stat.crosstab("key", "value").show()
+---------+---+---+---+
|key_value|  a|  b|  c|
+---------+---+---+---+
|        2|  2|  0|  1|
|        1|  1|  1|  0|
|        3|  0|  1|  1|
+---------+---+---+---+
```

## How was this patch tested?

Pass the Jenkins tests with updated test cases.

Author: Dongjoon Hyun <dongjoon@apache.org>

Closes #13436 from dongjoon-hyun/SPARK-15696.
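To see the essence of the change without Spark, here is a minimal, self-contained Scala sketch; `cleanElement` is copied from the patch, while the object name and sample values are hypothetical.

```scala
object CrosstabOrderSketch {
  // Copied from the patch: null-safe stringification of a cell value.
  def cleanElement(element: Any): String =
    if (element == null) "null" else element.toString

  def main(args: Array[String]): Unit = {
    val col2Values: Seq[Any] = Seq(3, 2, 1, 2, null)

    // Pre-patch logic: `distinct` alone. On a local Seq this preserves
    // encounter order, but in Spark the values come back from a distributed
    // aggregate, so the encounter order (and thus the column order) varied.
    val before = col2Values.map(cleanElement).distinct.zipWithIndex.toMap

    // Post-patch logic: `sorted` pins the column order deterministically.
    val after = col2Values.map(cleanElement).distinct.sorted.zipWithIndex.toMap

    println(before) // e.g. Map(3 -> 0, 2 -> 1, 1 -> 2, null -> 3)
    println(after)  // Map(1 -> 0, 2 -> 1, 3 -> 2, null -> 3)
  }
}
```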
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala  | 4
-rw-r--r--  sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java         | 4
2 files changed, 4 insertions, 4 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala
index 9c0406168e..ea58df70b3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala
@@ -423,9 +423,9 @@ private[sql] object StatFunctions extends Logging {
     def cleanElement(element: Any): String = {
       if (element == null) "null" else element.toString
     }
-    // get the distinct values of column 2, so that we can make them the column names
+    // get the distinct sorted values of column 2, so that we can make them the column names
     val distinctCol2: Map[Any, Int] =
-      counts.map(e => cleanElement(e.get(1))).distinct.zipWithIndex.toMap
+      counts.map(e => cleanElement(e.get(1))).distinct.sorted.zipWithIndex.toMap
     val columnSize = distinctCol2.size
     require(columnSize < 1e4, s"The number of distinct values for $col2, can't " +
       s"exceed 1e4. Currently $columnSize")
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
index 1e8f1062c5..0152f3f85a 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
@@ -246,8 +246,8 @@ public class JavaDataFrameSuite {
     Dataset<Row> crosstab = df.stat().crosstab("a", "b");
     String[] columnNames = crosstab.schema().fieldNames();
     Assert.assertEquals("a_b", columnNames[0]);
-    Assert.assertEquals("2", columnNames[1]);
-    Assert.assertEquals("1", columnNames[2]);
+    Assert.assertEquals("1", columnNames[1]);
+    Assert.assertEquals("2", columnNames[2]);
     List<Row> rows = crosstab.collectAsList();
     Collections.sort(rows, crosstabRowComparator);
     Integer count = 1;
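For completeness, a minimal sketch of exercising the patched behavior from an application, mirroring the first example in the description; the object name, app name, and master setting are our own assumptions.

```scala
import org.apache.spark.sql.SparkSession

object CrosstabOrderDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("crosstab-order")
      .getOrCreate()
    import spark.implicits._

    val df = Seq((1, 1), (1, 2), (2, 1), (2, 1), (2, 3), (3, 2), (3, 3))
      .toDF("key", "value")
    val ct = df.stat.crosstab("key", "value")

    // With this patch, the value columns come back in sorted order.
    assert(ct.schema.fieldNames.toSeq == Seq("key_value", "1", "2", "3"))
    ct.show()

    spark.stop()
  }
}
```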