author     Burak Yavuz <brkyvz@gmail.com>      2015-06-29 13:15:04 -0700
committer  Reynold Xin <rxin@databricks.com>   2015-06-29 13:15:04 -0700
commit     be7ef067620408859144e0244b0f1b8eb56faa86 (patch)
tree       a2afa180ddb8744c4d52e1571e617ce24598ba10 /sql/core
parent     c6ba2ea341ad23de265d870669b25e6a41f461e5 (diff)
[SPARK-8681] fixed wrong ordering of columns in crosstab
I specifically randomized the test. What crosstab does is equivalent to a countByKey; therefore, if this test fails again for any reason, we will know that we hit a corner case.

cc rxin marmbrus

Author: Burak Yavuz <brkyvz@gmail.com>

Closes #7060 from brkyvz/crosstab-fixes and squashes the following commits:

0a65234 [Burak Yavuz] addressed comments v1
d96da7e [Burak Yavuz] fixed wrong ordering of columns in crosstab
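For context, the API under test looks roughly like this. A minimal illustrative sketch, not part of the commit, assuming a SQLContext whose implicits are imported (as the test suite does):

    // Each crosstab cell holds the frequency of one (a, b) pair, which is
    // why the commit message calls crosstab equivalent to a countByKey.
    val df = Seq((1, "x"), (1, "y"), (2, "x")).toDF("a", "b")
    val ct = df.stat.crosstab("a", "b")
    // ct's schema is "a_b" (the values of column a, as strings) followed by
    // one LongType count column per distinct value of b, so ct.collect()
    // yields rows such as ["1", 1, 1] and ["2", 1, 0].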
Diffstat (limited to 'sql/core')
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala    8
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala             28
2 files changed, 20 insertions(+), 16 deletions(-)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala
index 042e2c9cbb..b624ef7e8f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala
@@ -111,7 +111,7 @@ private[sql] object StatFunctions extends Logging {
       "the pairs. Please try reducing the amount of distinct items in your columns.")
     }
     // get the distinct values of column 2, so that we can make them the column names
-    val distinctCol2 = counts.map(_.get(1)).distinct.zipWithIndex.toMap
+    val distinctCol2: Map[Any, Int] = counts.map(_.get(1)).distinct.zipWithIndex.toMap
     val columnSize = distinctCol2.size
     require(columnSize < 1e4, s"The number of distinct values for $col2, can't " +
       s"exceed 1e4. Currently $columnSize")
@@ -120,14 +120,16 @@
       rows.foreach { (row: Row) =>
         // row.get(0) is column 1
         // row.get(1) is column 2
-        // row.get(3) is the frequency
+        // row.get(2) is the frequency
         countsRow.setLong(distinctCol2.get(row.get(1)).get + 1, row.getLong(2))
       }
       // the value of col1 is the first value, the rest are the counts
       countsRow.update(0, UTF8String.fromString(col1Item.toString))
       countsRow
     }.toSeq
-    val headerNames = distinctCol2.map(r => StructField(r._1.toString, LongType)).toSeq
+    // In the map, the column names (._1) are not ordered by the index (._2). This was the bug in
+    // SPARK-8681. We need to explicitly sort by the column index and assign the column names.
+    val headerNames = distinctCol2.toSeq.sortBy(_._2).map(r => StructField(r._1.toString, LongType))
     val schema = StructType(StructField(tableName, StringType) +: headerNames)
     new DataFrame(df.sqlContext, LocalRelation(schema.toAttributes, table)).na.fill(0.0)
   }
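The map-iteration pitfall the new comment describes is easy to reproduce in a plain Scala REPL. A minimal sketch, illustrative and not from the commit:

    // zipWithIndex fixes the intended column order, but toMap forgets it:
    // a HashMap iterates in hash order, not insertion order.
    val distinctCol2: Map[Any, Int] = Seq("b", "d", "a", "c", "e").zipWithIndex.toMap
    val buggyHeader = distinctCol2.map(_._1.toString).toSeq              // hash order, may be scrambled
    val fixedHeader = distinctCol2.toSeq.sortBy(_._2).map(_._1.toString) // Seq("b", "d", "a", "c", "e")
    // The counts are written into each row at (index + 1), so only
    // fixedHeader lines up the column names with the counts beneath them.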
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
index 0d3ff899da..64ec1a70c4 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
@@ -17,6 +17,8 @@

 package org.apache.spark.sql

+import java.util.Random
+
 import org.scalatest.Matchers._

 import org.apache.spark.SparkFunSuite
@@ -65,22 +67,22 @@
   }

   test("crosstab") {
-    val df = Seq((0, 0), (2, 1), (1, 0), (2, 0), (0, 0), (2, 0)).toDF("a", "b")
+    val rng = new Random()
+    val data = Seq.tabulate(25)(i => (rng.nextInt(5), rng.nextInt(10)))
+    val df = data.toDF("a", "b")
     val crosstab = df.stat.crosstab("a", "b")
     val columnNames = crosstab.schema.fieldNames
     assert(columnNames(0) === "a_b")
-    assert(columnNames(1) === "0")
-    assert(columnNames(2) === "1")
-    val rows: Array[Row] = crosstab.collect().sortBy(_.getString(0))
-    assert(rows(0).get(0).toString === "0")
-    assert(rows(0).getLong(1) === 2L)
-    assert(rows(0).get(2) === 0L)
-    assert(rows(1).get(0).toString === "1")
-    assert(rows(1).getLong(1) === 1L)
-    assert(rows(1).get(2) === 0L)
-    assert(rows(2).get(0).toString === "2")
-    assert(rows(2).getLong(1) === 2L)
-    assert(rows(2).getLong(2) === 1L)
+    // reduce by key
+    val expected = data.map(t => (t, 1)).groupBy(_._1).mapValues(_.length)
+    val rows = crosstab.collect()
+    rows.foreach { row =>
+      val i = row.getString(0).toInt
+      for (col <- 1 to 9) {
+        val j = columnNames(col).toInt
+        assert(row.getLong(col) === expected.getOrElse((i, j), 0).toLong)
+      }
+    }
   }

   test("Frequent Items") {