author    Reynold Xin <rxin@databricks.com>  2015-06-01 21:11:19 -0700
committer Reynold Xin <rxin@databricks.com>  2015-06-01 21:11:19 -0700
commit    6396cc0303ceabea53c4df436ffa50b82b7e233f (patch)
tree      4b7242d2f1ce9af82f6d36433c7841d8a7d15cb0 /sql
parent    6b44278ef7cd2a278dfa67e8393ef30775c72726 (diff)
[SPARK-7982][SQL] DataFrame.stat.crosstab should use 0 instead of null for pairs that don't appear
Author: Reynold Xin <rxin@databricks.com>

Closes #6566 from rxin/crosstab and squashes the following commits:

e0ace1c [Reynold Xin] [SPARK-7982][SQL] DataFrame.stat.crosstab should use 0 instead of null for pairs that don't appear
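The behavioral change, in short: crosstab builds a contingency table of two columns, and cells for (col1, col2) pairs that never co-occur used to come back as null; after this patch they come back as 0. A minimal sketch of the API (assuming a Spark 1.4-era SQLContext named sqlContext; the data and column ordering are illustrative, not from this patch):

    import sqlContext.implicits._

    val df = Seq(("x", 1), ("x", 1), ("y", 2)).toDF("c1", "c2")
    val ct = df.stat.crosstab("c1", "c2")
    ct.show()
    // c1_c2  1  2
    //   x    2  0   <- the ("x", 2) cell was null before this patch
    //   y    0  1   <- likewise the ("y", 1) cell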
Diffstat (limited to 'sql')
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala | 9 ++++++---
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala           | 4 ++--
2 files changed, 8 insertions(+), 5 deletions(-)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala
index b1a8204dd5..93383e5a62 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala
@@ -18,7 +18,7 @@
 package org.apache.spark.sql.execution.stat

 import org.apache.spark.Logging
-import org.apache.spark.sql.{Column, DataFrame}
+import org.apache.spark.sql.{Row, Column, DataFrame}
 import org.apache.spark.sql.catalyst.expressions.{GenericMutableRow, Cast}
 import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
 import org.apache.spark.sql.functions._
@@ -116,7 +116,10 @@ private[sql] object StatFunctions extends Logging {
       s"exceed 1e4. Currently $columnSize")
     val table = counts.groupBy(_.get(0)).map { case (col1Item, rows) =>
       val countsRow = new GenericMutableRow(columnSize + 1)
-      rows.foreach { row =>
+      rows.foreach { (row: Row) =>
+        // row.get(0) is column 1
+        // row.get(1) is column 2
+        // row.get(2) is the frequency
         countsRow.setLong(distinctCol2.get(row.get(1)).get + 1, row.getLong(2))
       }
       // the value of col1 is the first value, the rest are the counts
@@ -126,6 +129,6 @@
     val headerNames = distinctCol2.map(r => StructField(r._1.toString, LongType)).toSeq
     val schema = StructType(StructField(tableName, StringType) +: headerNames)

-    new DataFrame(df.sqlContext, LocalRelation(schema.toAttributes, table))
+    new DataFrame(df.sqlContext, LocalRelation(schema.toAttributes, table)).na.fill(0.0)
   }
 }
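The one-line fix leans on DataFrameNaFunctions rather than changing the table construction: fill(0.0) replaces null in every numeric column, and because the count columns are LongType the filled value surfaces as 0L. A self-contained sketch of that semantics (hypothetical data, assuming a SQLContext named sqlContext; not part of this patch):

    import org.apache.spark.sql.Row
    import org.apache.spark.sql.types._

    // A frame with a nullable LongType column, built explicitly so it can hold a null.
    val schema = StructType(Seq(StructField("k", StringType), StructField("n", LongType)))
    val rdd = sqlContext.sparkContext.parallelize(Seq(Row("a", null), Row("b", 2L)))
    val withNulls = sqlContext.createDataFrame(rdd, schema)

    withNulls.na.fill(0.0).collect()
    // Array(Row("a", 0L), Row("b", 2L)) -- the null Long becomes 0L; string columns are untouched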
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
index 78de89f0b9..438f479459 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
@@ -74,10 +74,10 @@ class DataFrameStatSuite extends SparkFunSuite {
     val rows: Array[Row] = crosstab.collect().sortBy(_.getString(0))
     assert(rows(0).get(0).toString === "0")
     assert(rows(0).getLong(1) === 2L)
-    assert(rows(0).get(2) === null)
+    assert(rows(0).get(2) === 0L)
     assert(rows(1).get(0).toString === "1")
     assert(rows(1).getLong(1) === 1L)
-    assert(rows(1).get(2) === null)
+    assert(rows(1).get(2) === 0L)
     assert(rows(2).get(0).toString === "2")
     assert(rows(2).getLong(1) === 2L)
     assert(rows(2).getLong(2) === 1L)
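For orientation, a hedged reconstruction of the scenario these assertions describe (the suite's actual input data sits outside this hunk, so the values below are illustrative only): a first-column value "0" that co-occurs twice with one second-column value and never with the other now yields the row ("0", 2, 0) instead of ("0", 2, null):

    // Illustrative data consistent with the assertions above, not the suite's real fixture.
    val df = sqlContext.createDataFrame(Seq(
      (0, "a"), (0, "a"),            // "0" pairs twice with "a", never with "b"
      (1, "a"),                      // "1" pairs once with "a"
      (2, "a"), (2, "a"), (2, "b"))) // "2" pairs with both
      .toDF("key", "value")

    val crosstab = df.stat.crosstab("key", "value")
    // rows sorted by key: ("0", 2, 0), ("1", 1, 0), ("2", 2, 1) -- zeros, not nulls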