author     Burak Yavuz <brkyvz@gmail.com>     2015-05-03 21:44:39 -0700
committer  Reynold Xin <rxin@databricks.com>  2015-05-03 21:44:39 -0700
commit     9646018bb4466433521b4e602b808f16e8d0ffdb (patch)
tree       fe4df54135dcc0806af5f801c4753d9b88dbcc16 /sql
parent     1ffa8cb91f8badf12a8aa190dc25920715a00db7 (diff)
[SPARK-7241] Pearson correlation for DataFrames
Submitting this PR from a phone, excuse the brevity. Adds Pearson correlation to DataFrames, reusing the covariance calculation code.

cc mengxr rxin

Author: Burak Yavuz <brkyvz@gmail.com>

Closes #5858 from brkyvz/df-corr and squashes the following commits:

285b838 [Burak Yavuz] addressed comments v2.0
d10babb [Burak Yavuz] addressed comments v0.2
4b74b24 [Burak Yavuz] Merge branch 'master' of github.com:apache/spark into df-corr
4fe693b [Burak Yavuz] addressed comments v0.1
a682d06 [Burak Yavuz] ready for PR
Diffstat (limited to 'sql')
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala        | 26
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala  | 58
-rw-r--r--  sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java         |  7
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala            | 33
4 files changed, 98 insertions(+), 26 deletions(-)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala
index e8fa829477..9035321052 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala
@@ -28,6 +28,33 @@ import org.apache.spark.sql.execution.stat._
final class DataFrameStatFunctions private[sql](df: DataFrame) {
/**
+ * Calculates the correlation of two columns of a DataFrame. Currently only supports the Pearson
+ * Correlation Coefficient. For Spearman Correlation, consider using RDD methods found in
+ * MLlib's Statistics.
+ *
+ * @param col1 the name of the column
+ * @param col2 the name of the column to calculate the correlation against
+ * @param method the correlation method; currently only "pearson" is supported
+ * @return The Pearson Correlation Coefficient as a Double.
+ */
+ def corr(col1: String, col2: String, method: String): Double = {
+ require(method == "pearson", "Currently only the calculation of the Pearson Correlation " +
+ "coefficient is supported.")
+ StatFunctions.pearsonCorrelation(df, Seq(col1, col2))
+ }
+
+ /**
+ * Calculates the Pearson Correlation Coefficient of two columns of a DataFrame.
+ *
+ * @param col1 the name of the column
+ * @param col2 the name of the column to calculate the correlation against
+ * @return The Pearson Correlation Coefficient as a Double.
+ */
+ def corr(col1: String, col2: String): Double = {
+ corr(col1, col2, "pearson")
+ }
+
+ /**
* Finding frequent items for columns, possibly with false positives. Using the
* frequent element count algorithm described in
* [[http://dx.doi.org/10.1145/762471.762473, proposed by Karp, Schenker, and Papadimitriou]].
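Both overloads land in the same code path: the two-argument form simply delegates with method = "pearson". A minimal usage sketch, assuming a Spark 1.x shell where `sc` is an existing SparkContext (the setup below is illustrative, not part of the patch):

import org.apache.spark.sql.SQLContext

val sqlContext = new SQLContext(sc)
import sqlContext.implicits._

// Two perfectly linearly related columns, so the Pearson coefficient is 1.0.
val df = Seq.tabulate(10)(i => (i.toDouble, 2.0 * i)).toDF("x", "y")
val r1 = df.stat.corr("x", "y")             // two-arg overload defaults to "pearson"
val r2 = df.stat.corr("x", "y", "pearson")  // explicit method name
assert(math.abs(r1 - 1.0) < 1e-12)
assert(math.abs(r1 - r2) < 1e-12)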
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala
index d4a94c24d9..67b48e58b1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala
@@ -23,29 +23,43 @@ import org.apache.spark.sql.types.{DoubleType, NumericType}
private[sql] object StatFunctions {
+ /** Calculate the Pearson Correlation Coefficient for the given columns */
+ private[sql] def pearsonCorrelation(df: DataFrame, cols: Seq[String]): Double = {
+ val counts = collectStatisticalData(df, cols)
+ counts.Ck / math.sqrt(counts.MkX * counts.MkY)
+ }
+
/** Helper class to simplify tracking and merging counts. */
private class CovarianceCounter extends Serializable {
- var xAvg = 0.0
- var yAvg = 0.0
- var Ck = 0.0
- var count = 0L
+ var xAvg = 0.0 // the mean of all examples seen so far in col1
+ var yAvg = 0.0 // the mean of all examples seen so far in col2
+ var Ck = 0.0 // the co-moment after k examples
+ var MkX = 0.0 // sum of squares of differences from the (current) mean for col1
+ var MkY = 0.0 // sum of squares of differences from the (current) mean for col2
+ var count = 0L // count of observed examples
// add an example to the calculation
def add(x: Double, y: Double): this.type = {
- val oldX = xAvg
+ val deltaX = x - xAvg
+ val deltaY = y - yAvg
count += 1
- xAvg += (x - xAvg) / count
- yAvg += (y - yAvg) / count
- Ck += (y - yAvg) * (x - oldX)
+ xAvg += deltaX / count
+ yAvg += deltaY / count
+ Ck += deltaX * (y - yAvg)
+ MkX += deltaX * (x - xAvg)
+ MkY += deltaY * (y - yAvg)
this
}
// merge counters from other partitions. Formula can be found at:
- // http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Covariance
+ // http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance
def merge(other: CovarianceCounter): this.type = {
val totalCount = count + other.count
- Ck += other.Ck +
- (xAvg - other.xAvg) * (yAvg - other.yAvg) * count / totalCount * other.count
+ val deltaX = xAvg - other.xAvg
+ val deltaY = yAvg - other.yAvg
+ Ck += other.Ck + deltaX * deltaY * count / totalCount * other.count
xAvg = (xAvg * count + other.xAvg * other.count) / totalCount
yAvg = (yAvg * count + other.yAvg * other.count) / totalCount
+ MkX += other.MkX + deltaX * deltaX * count / totalCount * other.count
+ MkY += other.MkY + deltaY * deltaY * count / totalCount * other.count
count = totalCount
this
}
@@ -53,13 +67,7 @@ private[sql] object StatFunctions {
def cov: Double = Ck / (count - 1)
}
- /**
- * Calculate the covariance of two numerical columns of a DataFrame.
- * @param df The DataFrame
- * @param cols the column names
- * @return the covariance of the two columns.
- */
- private[sql] def calculateCov(df: DataFrame, cols: Seq[String]): Double = {
+ private def collectStatisticalData(df: DataFrame, cols: Seq[String]): CovarianceCounter = {
require(cols.length == 2, "Currently covariance and correlation calculation " +
"is only supported between two columns.")
cols.map(name => (name, df.schema.fields.find(_.name == name))).foreach { case (name, data) =>
@@ -68,13 +76,23 @@ private[sql] object StatFunctions {
s"with dataType ${data.get.dataType} not supported.")
}
val columns = cols.map(n => Column(Cast(Column(n).expr, DoubleType)))
- val counts = df.select(columns:_*).rdd.aggregate(new CovarianceCounter)(
+ df.select(columns: _*).rdd.aggregate(new CovarianceCounter)(
seqOp = (counter, row) => {
counter.add(row.getDouble(0), row.getDouble(1))
},
combOp = (baseCounter, other) => {
baseCounter.merge(other)
- })
+ })
+ }
+
+ /**
+ * Calculate the covariance of two numerical columns of a DataFrame.
+ * @param df The DataFrame
+ * @param cols the column names
+ * @return the covariance of the two columns.
+ */
+ private[sql] def calculateCov(df: DataFrame, cols: Seq[String]): Double = {
+ val counts = collectStatisticalData(df, cols)
counts.cov
}
}
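The counter is a merge-able single-pass accumulator in the style of Welford's algorithm: after k rows, Ck holds the co-moment sum((x_i - xAvg) * (y_i - yAvg)), while MkX and MkY hold the sums of squared deviations for each column, so pearsonCorrelation reduces to Ck / sqrt(MkX * MkY) without a second pass over the data. Below is a self-contained sketch that mirrors the class above and cross-checks a split-and-merge run against a naive two-pass computation; the Mirror class and driver are illustrative only, and reuse the data from the "non-trivial" test case further down:

object CounterCheck extends App {
  // Illustrative mirror of CovarianceCounter; not part of the patch.
  class Mirror {
    var xAvg, yAvg, Ck, MkX, MkY = 0.0
    var count = 0L
    def add(x: Double, y: Double): this.type = {
      val deltaX = x - xAvg          // deviation from the pre-update mean
      val deltaY = y - yAvg
      count += 1
      xAvg += deltaX / count
      yAvg += deltaY / count
      Ck  += deltaX * (y - yAvg)     // old x-delta times post-update y-delta
      MkX += deltaX * (x - xAvg)
      MkY += deltaY * (y - yAvg)
      this
    }
    def merge(other: Mirror): this.type = {
      val total  = count + other.count
      val deltaX = xAvg - other.xAvg
      val deltaY = yAvg - other.yAvg
      // pairwise-update formulas from the Wikipedia page cited above
      Ck  += other.Ck  + deltaX * deltaY * count / total * other.count
      MkX += other.MkX + deltaX * deltaX * count / total * other.count
      MkY += other.MkY + deltaY * deltaY * count / total * other.count
      xAvg = (xAvg * count + other.xAvg * other.count) / total
      yAvg = (yAvg * count + other.yAvg * other.count) / total
      count = total
      this
    }
  }

  val xs = (0 until 20).map(_.toDouble)
  val ys = xs.map(x => x * x - 2 * x + 3.5)

  // Two "partitions", accumulated independently and then merged.
  val (left, right) = (xs zip ys).splitAt(10)
  val a = new Mirror; left.foreach { case (x, y) => a.add(x, y) }
  val b = new Mirror; right.foreach { case (x, y) => b.add(x, y) }
  val merged  = a.merge(b)
  val onePass = merged.Ck / math.sqrt(merged.MkX * merged.MkY)

  // Naive two-pass Pearson for comparison.
  val (mx, my) = (xs.sum / xs.size, ys.sum / ys.size)
  val twoPass  = (xs zip ys).map { case (x, y) => (x - mx) * (y - my) }.sum /
    math.sqrt(xs.map(x => (x - mx) * (x - mx)).sum * ys.map(y => (y - my) * (y - my)).sum)

  assert(math.abs(onePass - twoPass) < 1e-12)
}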
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
index 96fe66d0b8..78e847239f 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
@@ -188,6 +188,13 @@ public class JavaDataFrameSuite {
}
@Test
+ public void testCorrelation() {
+ DataFrame df = context.table("testData2");
+ Double pearsonCorr = df.stat().corr("a", "b", "pearson");
+ Assert.assertTrue(Math.abs(pearsonCorr) < 1e-6);
+ }
+
+ @Test
public void testCovariance() {
DataFrame df = context.table("testData2");
Double result = df.stat().cov("a", "b");
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
index 4f5a2ff696..06764d2a12 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
@@ -30,10 +30,10 @@ class DataFrameStatSuite extends FunSuite {
def toLetter(i: Int): String = (i + 97).toChar.toString
test("Frequent Items") {
- val rows = Array.tabulate(1000) { i =>
+ val rows = Seq.tabulate(1000) { i =>
if (i % 3 == 0) (1, toLetter(1), -1.0) else (i, toLetter(i), i * -1.0)
}
- val df = sqlCtx.sparkContext.parallelize(rows).toDF("numbers", "letters", "negDoubles")
+ val df = rows.toDF("numbers", "letters", "negDoubles")
val results = df.stat.freqItems(Array("numbers", "letters"), 0.1)
val items = results.collect().head
@@ -43,19 +43,40 @@ class DataFrameStatSuite extends FunSuite {
val singleColResults = df.stat.freqItems(Array("negDoubles"), 0.1)
val items2 = singleColResults.collect().head
items2.getSeq[Double](0) should contain (-1.0)
+ }
+ test("pearson correlation") {
+ val df = Seq.tabulate(10)(i => (i, 2 * i, i * -1.0)).toDF("a", "b", "c")
+ val corr1 = df.stat.corr("a", "b", "pearson")
+ assert(math.abs(corr1 - 1.0) < 1e-12)
+ val corr2 = df.stat.corr("a", "c", "pearson")
+ assert(math.abs(corr2 + 1.0) < 1e-12)
+ // non-trivial example. To reproduce in python, use:
+ // >>> from scipy.stats import pearsonr
+ // >>> import numpy as np
+ // >>> a = np.array(range(20))
+ // >>> b = np.array([x * x - 2 * x + 3.5 for x in range(20)])
+ // >>> pearsonr(a, b)
+ // (0.95723391394758572, 3.8902121417802199e-11)
+ // In R, use:
+ // > a <- 0:19
+ // > b <- mapply(function(x) x * x - 2 * x + 3.5, a)
+ // > cor(a, b)
+ // [1] 0.957233913947585835
+ val df2 = Seq.tabulate(20)(x => (x, x * x - 2 * x + 3.5)).toDF("a", "b")
+ val corr3 = df2.stat.corr("a", "b", "pearson")
+ assert(math.abs(corr3 - 0.95723391394758572) < 1e-12)
}
test("covariance") {
- val rows = Array.tabulate(10)(i => (i, 2.0 * i, toLetter(i)))
- val df = sqlCtx.sparkContext.parallelize(rows).toDF("singles", "doubles", "letters")
+ val df = Seq.tabulate(10)(i => (i, 2.0 * i, toLetter(i))).toDF("singles", "doubles", "letters")
val results = df.stat.cov("singles", "doubles")
- assert(math.abs(results - 55.0 / 3) < 1e-6)
+ assert(math.abs(results - 55.0 / 3) < 1e-12)
intercept[IllegalArgumentException] {
df.stat.cov("singles", "letters") // doesn't accept non-numerical dataTypes
}
val decimalRes = decimalData.stat.cov("a", "b")
- assert(math.abs(decimalRes) < 1e-6)
+ assert(math.abs(decimalRes) < 1e-12)
}
}
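The failure modes exercised implicitly above both surface as IllegalArgumentException, since require and the dataType check throw it: an unsupported method name fails fast, as does a non-numeric column. A hedged sketch, reusing the illustrative sqlContext setup from the first example:

// Only "pearson" is accepted at this point; for Spearman, use MLlib's Statistics.
val df2 = Seq.tabulate(10)(i => (i, 2 * i, (i + 97).toChar.toString)).toDF("a", "b", "letters")
try df2.stat.corr("a", "b", "spearman")
catch { case e: IllegalArgumentException => println(e.getMessage) }

// Non-numeric columns are rejected by the dataType check in collectStatisticalData.
try df2.stat.corr("a", "letters")
catch { case e: IllegalArgumentException => println(e.getMessage) }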