From be7ef067620408859144e0244b0f1b8eb56faa86 Mon Sep 17 00:00:00 2001 From: Burak Yavuz Date: Mon, 29 Jun 2015 13:15:04 -0700 Subject: [SPARK-8681] fixed wrong ordering of columns in crosstab I specifically randomized the test. What crosstab does is equivalent to a countByKey, therefore if this test fails again for any reason, we will know that we hit a corner case or something. cc rxin marmbrus Author: Burak Yavuz Closes #7060 from brkyvz/crosstab-fixes and squashes the following commits: 0a65234 [Burak Yavuz] addressed comments v1 d96da7e [Burak Yavuz] fixed wrong ordering of columns in crosstab --- .../spark/sql/execution/stat/StatFunctions.scala | 8 ++++--- .../org/apache/spark/sql/DataFrameStatSuite.scala | 28 ++++++++++++---------- 2 files changed, 20 insertions(+), 16 deletions(-) (limited to 'sql') diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala index 042e2c9cbb..b624ef7e8f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala @@ -111,7 +111,7 @@ private[sql] object StatFunctions extends Logging { "the pairs. Please try reducing the amount of distinct items in your columns.") } // get the distinct values of column 2, so that we can make them the column names - val distinctCol2 = counts.map(_.get(1)).distinct.zipWithIndex.toMap + val distinctCol2: Map[Any, Int] = counts.map(_.get(1)).distinct.zipWithIndex.toMap val columnSize = distinctCol2.size require(columnSize < 1e4, s"The number of distinct values for $col2, can't " + s"exceed 1e4. Currently $columnSize") @@ -120,14 +120,16 @@ private[sql] object StatFunctions extends Logging { rows.foreach { (row: Row) => // row.get(0) is column 1 // row.get(1) is column 2 - // row.get(3) is the frequency + // row.get(2) is the frequency countsRow.setLong(distinctCol2.get(row.get(1)).get + 1, row.getLong(2)) } // the value of col1 is the first value, the rest are the counts countsRow.update(0, UTF8String.fromString(col1Item.toString)) countsRow }.toSeq - val headerNames = distinctCol2.map(r => StructField(r._1.toString, LongType)).toSeq + // In the map, the column names (._1) are not ordered by the index (._2). This was the bug in + // SPARK-8681. We need to explicitly sort by the column index and assign the column names. + val headerNames = distinctCol2.toSeq.sortBy(_._2).map(r => StructField(r._1.toString, LongType)) val schema = StructType(StructField(tableName, StringType) +: headerNames) new DataFrame(df.sqlContext, LocalRelation(schema.toAttributes, table)).na.fill(0.0) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala index 0d3ff899da..64ec1a70c4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql +import java.util.Random + import org.scalatest.Matchers._ import org.apache.spark.SparkFunSuite @@ -65,22 +67,22 @@ class DataFrameStatSuite extends SparkFunSuite { } test("crosstab") { - val df = Seq((0, 0), (2, 1), (1, 0), (2, 0), (0, 0), (2, 0)).toDF("a", "b") + val rng = new Random() + val data = Seq.tabulate(25)(i => (rng.nextInt(5), rng.nextInt(10))) + val df = data.toDF("a", "b") val crosstab = df.stat.crosstab("a", "b") val columnNames = crosstab.schema.fieldNames assert(columnNames(0) === "a_b") - assert(columnNames(1) === "0") - assert(columnNames(2) === "1") - val rows: Array[Row] = crosstab.collect().sortBy(_.getString(0)) - assert(rows(0).get(0).toString === "0") - assert(rows(0).getLong(1) === 2L) - assert(rows(0).get(2) === 0L) - assert(rows(1).get(0).toString === "1") - assert(rows(1).getLong(1) === 1L) - assert(rows(1).get(2) === 0L) - assert(rows(2).get(0).toString === "2") - assert(rows(2).getLong(1) === 2L) - assert(rows(2).getLong(2) === 1L) + // reduce by key + val expected = data.map(t => (t, 1)).groupBy(_._1).mapValues(_.length) + val rows = crosstab.collect() + rows.foreach { row => + val i = row.getString(0).toInt + for (col <- 1 to 9) { + val j = columnNames(col).toInt + assert(row.getLong(col) === expected.getOrElse((i, j), 0).toLong) + } + } } test("Frequent Items") { -- cgit v1.2.3