From 6396cc0303ceabea53c4df436ffa50b82b7e233f Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Mon, 1 Jun 2015 21:11:19 -0700 Subject: [SPARK-7982][SQL] DataFrame.stat.crosstab should use 0 instead of null for pairs that don't appear Author: Reynold Xin Closes #6566 from rxin/crosstab and squashes the following commits: e0ace1c [Reynold Xin] [SPARK-7982][SQL] DataFrame.stat.crosstab should use 0 instead of null for pairs that don't appear --- .../org/apache/spark/sql/execution/stat/StatFunctions.scala | 9 ++++++--- .../src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala | 4 ++-- 2 files changed, 8 insertions(+), 5 deletions(-) (limited to 'sql/core/src') diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala index b1a8204dd5..93383e5a62 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.execution.stat import org.apache.spark.Logging -import org.apache.spark.sql.{Column, DataFrame} +import org.apache.spark.sql.{Row, Column, DataFrame} import org.apache.spark.sql.catalyst.expressions.{GenericMutableRow, Cast} import org.apache.spark.sql.catalyst.plans.logical.LocalRelation import org.apache.spark.sql.functions._ @@ -116,7 +116,10 @@ private[sql] object StatFunctions extends Logging { s"exceed 1e4. Currently $columnSize") val table = counts.groupBy(_.get(0)).map { case (col1Item, rows) => val countsRow = new GenericMutableRow(columnSize + 1) - rows.foreach { row => + rows.foreach { (row: Row) => + // row.get(0) is column 1 + // row.get(1) is column 2 + // row.get(3) is the frequency countsRow.setLong(distinctCol2.get(row.get(1)).get + 1, row.getLong(2)) } // the value of col1 is the first value, the rest are the counts @@ -126,6 +129,6 @@ private[sql] object StatFunctions extends Logging { val headerNames = distinctCol2.map(r => StructField(r._1.toString, LongType)).toSeq val schema = StructType(StructField(tableName, StringType) +: headerNames) - new DataFrame(df.sqlContext, LocalRelation(schema.toAttributes, table)) + new DataFrame(df.sqlContext, LocalRelation(schema.toAttributes, table)).na.fill(0.0) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala index 78de89f0b9..438f479459 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala @@ -74,10 +74,10 @@ class DataFrameStatSuite extends SparkFunSuite { val rows: Array[Row] = crosstab.collect().sortBy(_.getString(0)) assert(rows(0).get(0).toString === "0") assert(rows(0).getLong(1) === 2L) - assert(rows(0).get(2) === null) + assert(rows(0).get(2) === 0L) assert(rows(1).get(0).toString === "1") assert(rows(1).getLong(1) === 1L) - assert(rows(1).get(2) === null) + assert(rows(1).get(2) === 0L) assert(rows(2).get(0).toString === "2") assert(rows(2).getLong(1) === 2L) assert(rows(2).getLong(2) === 1L) -- cgit v1.2.3