author    Burak Yavuz <brkyvz@gmail.com>     2015-05-04 17:02:49 -0700
committer Reynold Xin <rxin@databricks.com>  2015-05-04 17:02:49 -0700
commit    80554111703c08e2bedbe303e04ecd162ec119e1 (patch)
tree      fc2ea97df7c1111f33020329d93a9338ecd5fecb /sql
parent    fc8b58195afa67fbb75b4c8303e022f703cbf007 (diff)
[SPARK-7243][SQL] Contingency Tables for DataFrames
Computes a pair-wise frequency table of the given columns, also known as cross-tabulation.

cc mengxr rxin

Author: Burak Yavuz <brkyvz@gmail.com>

Closes #5842 from brkyvz/df-cont and squashes the following commits:

a07c01e [Burak Yavuz] addressed comments v4.1
ae9e01d [Burak Yavuz] fix test
9106585 [Burak Yavuz] addressed comments v4.0
bced829 [Burak Yavuz] fix merge conflicts
a63ad00 [Burak Yavuz] addressed comments v3.0
a0cad97 [Burak Yavuz] addressed comments v3.0
6805df8 [Burak Yavuz] addressed comments and fixed test
939b7c4 [Burak Yavuz] lint python
7f098bc [Burak Yavuz] add crosstab pyTest
fd53b00 [Burak Yavuz] added python support for crosstab
27a5a81 [Burak Yavuz] implemented crosstab
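For illustration, a minimal usage sketch of the new API (a hypothetical session; the data and column names are made up and are not part of this commit):

    // Cross-tabulate two columns of a small DataFrame.
    import sqlContext.implicits._
    val df = Seq((0, 0), (1, 0), (2, 1)).toDF("a", "b")
    val ct = df.stat.crosstab("a", "b")
    // ct is a local DataFrame: its first column is named "a_b" and holds the
    // distinct values of "a"; the remaining columns are named after the
    // distinct values of "b" and hold Long counts (null where a pair never occurs).
    ct.show()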
Diffstat (limited to 'sql')
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala        |  37
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala  |  37
-rw-r--r--  sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java         |  28
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala            |  55
4 files changed, 126 insertions(+), 31 deletions(-)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala
index 9035321052..fcf21ca741 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala
@@ -28,6 +28,16 @@ import org.apache.spark.sql.execution.stat._
final class DataFrameStatFunctions private[sql](df: DataFrame) {
/**
+ * Calculate the sample covariance of two numerical columns of a DataFrame.
+ * @param col1 the name of the first column
+ * @param col2 the name of the second column
+ * @return the covariance of the two columns.
+ */
+ def cov(col1: String, col2: String): Double = {
+ StatFunctions.calculateCov(df, Seq(col1, col2))
+ }
+
+ /**
* Calculates the correlation of two columns of a DataFrame. Currently only supports the Pearson
* Correlation Coefficient. For Spearman Correlation, consider using RDD methods found in
* MLlib's Statistics.
@@ -54,6 +64,23 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) {
}
/**
+ * Computes a pair-wise frequency table of the given columns. Also known as a contingency table.
+ * The number of distinct values for each column should be less than 1e4. The first
+ * column of each row will be the distinct values of `col1` and the column names will be the
+ * distinct values of `col2`. The name of the first column will be `$col1_$col2`. Counts will be
+ * returned as `Long`s. Pairs that have no occurrences will have `null` as their counts.
+ *
+ * @param col1 The name of the first column. Distinct items will make the first item of
+ * each row.
+ * @param col2 The name of the second column. Distinct items will make the column names
+ * of the DataFrame.
+ * @return A local DataFrame containing the table.
+ */
+ def crosstab(col1: String, col2: String): DataFrame = {
+ StatFunctions.crossTabulate(df, col1, col2)
+ }
+
+ /**
* Finding frequent items for columns, possibly with false positives. Using the
* frequent element count algorithm described in
* [[http://dx.doi.org/10.1145/762471.762473, proposed by Karp, Schenker, and Papadimitriou]].
@@ -94,14 +121,4 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) {
def freqItems(cols: Seq[String]): DataFrame = {
FrequentItems.singlePassFreqItems(df, cols, 0.01)
}
-
- /**
- * Calculate the sample covariance of two numerical columns of a DataFrame.
- * @param col1 the name of the first column
- * @param col2 the name of the second column
- * @return the covariance of the two columns.
- */
- def cov(col1: String, col2: String): Double = {
- StatFunctions.calculateCov(df, Seq(col1, col2))
- }
}
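For contrast with crosstab, the relocated cov method returns a single Double rather than a DataFrame. A minimal sketch (illustrative data, in the style of the test suite below):

    import sqlContext.implicits._
    // Two perfectly linearly related columns, so the sample covariance is positive.
    val df = Seq.tabulate(10)(i => (i, 2.0 * i)).toDF("x", "y")
    val c: Double = df.stat.cov("x", "y")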
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala
index 67b48e58b1..b50f606d9c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala
@@ -17,11 +17,14 @@
package org.apache.spark.sql.execution.stat
-import org.apache.spark.sql.catalyst.expressions.Cast
+import org.apache.spark.Logging
import org.apache.spark.sql.{Column, DataFrame}
-import org.apache.spark.sql.types.{DoubleType, NumericType}
+import org.apache.spark.sql.catalyst.expressions.{GenericMutableRow, Cast}
+import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.types._
-private[sql] object StatFunctions {
+private[sql] object StatFunctions extends Logging {
/** Calculate the Pearson Correlation Coefficient for the given columns */
private[sql] def pearsonCorrelation(df: DataFrame, cols: Seq[String]): Double = {
@@ -95,4 +98,32 @@ private[sql] object StatFunctions {
val counts = collectStatisticalData(df, cols)
counts.cov
}
+
+ /** Generate a table of frequencies for the elements of two columns. */
+ private[sql] def crossTabulate(df: DataFrame, col1: String, col2: String): DataFrame = {
+ val tableName = s"${col1}_$col2"
+ val counts = df.groupBy(col1, col2).agg(col(col1), col(col2), count("*")).take(1e8.toInt)
+ if (counts.length == 1e8.toInt) {
+ logWarning("The maximum limit of 1e8 pairs has been collected, which may not be all of " +
+ "the pairs. Please try reducing the number of distinct items in your columns.")
+ }
+ // get the distinct values of column 2, so that we can make them the column names
+ val distinctCol2 = counts.map(_.get(1)).distinct.zipWithIndex.toMap
+ val columnSize = distinctCol2.size
+ require(columnSize < 1e4, s"The number of distinct values for $col2 can't " +
+ s"exceed 1e4. Currently: $columnSize")
+ val table = counts.groupBy(_.get(0)).map { case (col1Item, rows) =>
+ val countsRow = new GenericMutableRow(columnSize + 1)
+ rows.foreach { row =>
+ countsRow.setLong(distinctCol2.get(row.get(1)).get + 1, row.getLong(2))
+ }
+ // the value of col1 is the first value, the rest are the counts
+ countsRow.setString(0, col1Item.toString)
+ countsRow
+ }.toSeq
+ val headerNames = distinctCol2.map(r => StructField(r._1.toString, LongType)).toSeq
+ val schema = StructType(StructField(tableName, StringType) +: headerNames)
+
+ new DataFrame(df.sqlContext, LocalRelation(schema.toAttributes, table))
+ }
}
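The pivot that crossTabulate performs can be modeled with plain Scala collections (a simplified sketch of the same grouping logic, without Spark's Row and schema machinery; the sample triples are made up):

    // counts plays the role of the collected (col1 value, col2 value, count) triples.
    val counts = Seq((0, 0, 2L), (1, 0, 1L), (2, 0, 2L), (2, 1, 1L))
    // Map each distinct col2 value to a column index, as distinctCol2 does above.
    val distinctCol2 = counts.map(_._2).distinct.zipWithIndex.toMap
    val table = counts.groupBy(_._1).map { case (col1Item, rows) =>
      // One output row per distinct col1 value; null marks pairs that never occurred.
      val row = Array.fill[Any](distinctCol2.size + 1)(null)
      row(0) = col1Item.toString
      rows.foreach { case (_, col2Item, cnt) => row(distinctCol2(col2Item) + 1) = cnt }
      row.toSeq
    }.toSeq

As in the real implementation, the column order comes from the indices assigned by zipWithIndex, and absent (col1, col2) pairs are left null rather than zero.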
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
index 78e847239f..58cc8e5be6 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
@@ -34,6 +34,7 @@ import scala.collection.mutable.Buffer;
import java.io.Serializable;
import java.util.Arrays;
+import java.util.Comparator;
import java.util.List;
import java.util.Map;
@@ -178,6 +179,33 @@ public class JavaDataFrameSuite {
Assert.assertEquals(bean.getD().get(i), d.apply(i));
}
}
+
+ private static Comparator<Row> CrosstabRowComparator = new Comparator<Row>() {
+ public int compare(Row row1, Row row2) {
+ String item1 = row1.getString(0);
+ String item2 = row2.getString(0);
+ return item1.compareTo(item2);
+ }
+ };
+
+ @Test
+ public void testCrosstab() {
+ DataFrame df = context.table("testData2");
+ DataFrame crosstab = df.stat().crosstab("a", "b");
+ String[] columnNames = crosstab.schema().fieldNames();
+ Assert.assertEquals(columnNames[0], "a_b");
+ Assert.assertEquals(columnNames[1], "1");
+ Assert.assertEquals(columnNames[2], "2");
+ Row[] rows = crosstab.collect();
+ Arrays.sort(rows, CrosstabRowComparator);
+ Integer count = 1;
+ for (Row row : rows) {
+ Assert.assertEquals(row.get(0).toString(), count.toString());
+ Assert.assertEquals(row.getLong(1), 1L);
+ Assert.assertEquals(row.getLong(2), 1L);
+ count++;
+ }
+ }
@Test
public void testFrequentItems() {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
index 06764d2a12..46b1845a91 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
@@ -24,26 +24,9 @@ import org.apache.spark.sql.test.TestSQLContext
import org.apache.spark.sql.test.TestSQLContext.implicits._
class DataFrameStatSuite extends FunSuite {
-
- import TestData._
+
val sqlCtx = TestSQLContext
def toLetter(i: Int): String = (i + 97).toChar.toString
-
- test("Frequent Items") {
- val rows = Seq.tabulate(1000) { i =>
- if (i % 3 == 0) (1, toLetter(1), -1.0) else (i, toLetter(i), i * -1.0)
- }
- val df = rows.toDF("numbers", "letters", "negDoubles")
-
- val results = df.stat.freqItems(Array("numbers", "letters"), 0.1)
- val items = results.collect().head
- items.getSeq[Int](0) should contain (1)
- items.getSeq[String](1) should contain (toLetter(1))
-
- val singleColResults = df.stat.freqItems(Array("negDoubles"), 0.1)
- val items2 = singleColResults.collect().head
- items2.getSeq[Double](0) should contain (-1.0)
- }
test("pearson correlation") {
val df = Seq.tabulate(10)(i => (i, 2 * i, i * -1.0)).toDF("a", "b", "c")
@@ -76,7 +59,43 @@ class DataFrameStatSuite extends FunSuite {
intercept[IllegalArgumentException] {
df.stat.cov("singles", "letters") // doesn't accept non-numerical dataTypes
}
+ val decimalData = Seq.tabulate(6)(i => (BigDecimal(i % 3), BigDecimal(i % 2))).toDF("a", "b")
val decimalRes = decimalData.stat.cov("a", "b")
assert(math.abs(decimalRes) < 1e-12)
}
+
+ test("crosstab") {
+ val df = Seq((0, 0), (2, 1), (1, 0), (2, 0), (0, 0), (2, 0)).toDF("a", "b")
+ val crosstab = df.stat.crosstab("a", "b")
+ val columnNames = crosstab.schema.fieldNames
+ assert(columnNames(0) === "a_b")
+ assert(columnNames(1) === "0")
+ assert(columnNames(2) === "1")
+ val rows: Array[Row] = crosstab.collect().sortBy(_.getString(0))
+ assert(rows(0).get(0).toString === "0")
+ assert(rows(0).getLong(1) === 2L)
+ assert(rows(0).get(2) === null)
+ assert(rows(1).get(0).toString === "1")
+ assert(rows(1).getLong(1) === 1L)
+ assert(rows(1).get(2) === null)
+ assert(rows(2).get(0).toString === "2")
+ assert(rows(2).getLong(1) === 2L)
+ assert(rows(2).getLong(2) === 1L)
+ }
+
+ test("Frequent Items") {
+ val rows = Seq.tabulate(1000) { i =>
+ if (i % 3 == 0) (1, toLetter(1), -1.0) else (i, toLetter(i), i * -1.0)
+ }
+ val df = rows.toDF("numbers", "letters", "negDoubles")
+
+ val results = df.stat.freqItems(Array("numbers", "letters"), 0.1)
+ val items = results.collect().head
+ items.getSeq[Int](0) should contain (1)
+ items.getSeq[String](1) should contain (toLetter(1))
+
+ val singleColResults = df.stat.freqItems(Array("negDoubles"), 0.1)
+ val items2 = singleColResults.collect().head
+ items2.getSeq[Double](0) should contain (-1.0)
+ }
}
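For reference, the table that the "crosstab" test above asserts on, written out with rows sorted by the a_b column:

    a_b    0     1
    0      2     null
    1      1     null
    2      2     1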