aboutsummaryrefslogtreecommitdiff
path: root/python
diff options
context:
space:
mode:
Diffstat (limited to 'python')
-rw-r--r--python/pyspark/sql/dataframe.py26
-rw-r--r--python/pyspark/sql/tests.py6
2 files changed, 32 insertions, 0 deletions
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 8ddcff8fcd..aac5b8c4c5 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -875,6 +875,27 @@ class DataFrame(object):
return DataFrame(self._jdf.na().fill(value, self._jseq(subset)), self.sql_ctx)
+ def corr(self, col1, col2, method=None):
+ """
+ Calculates the correlation of two columns of a DataFrame as a double value. Currently only
+ supports the Pearson Correlation Coefficient.
+ :func:`DataFrame.corr` and :func:`DataFrameStatFunctions.corr` are aliases.
+
+ :param col1: The name of the first column
+ :param col2: The name of the second column
+ :param method: The correlation method. Currently only supports "pearson"
+ """
+ if not isinstance(col1, str):
+ raise ValueError("col1 should be a string.")
+ if not isinstance(col2, str):
+ raise ValueError("col2 should be a string.")
+ if not method:
+ method = "pearson"
+ if not method == "pearson":
+ raise ValueError("Currently only the calculation of the Pearson Correlation " +
+ "coefficient is supported.")
+ return self._jdf.stat().corr(col1, col2, method)
+
def cov(self, col1, col2):
"""
Calculate the sample covariance for the given columns, specified by their names, as a
@@ -1359,6 +1380,11 @@ class DataFrameStatFunctions(object):
def __init__(self, df):
self.df = df
+ def corr(self, col1, col2, method=None):
+ return self.df.corr(col1, col2, method)
+
+ corr.__doc__ = DataFrame.corr.__doc__
+
def cov(self, col1, col2):
return self.df.cov(col1, col2)
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 613efc0ac0..d652c302a5 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -394,6 +394,12 @@ class SQLTests(ReusedPySparkTestCase):
self.assertTrue(95 < g.agg(functions.approxCountDistinct(df.key)).first()[0])
self.assertEqual(100, g.agg(functions.countDistinct(df.value)).first()[0])
+ def test_corr(self):
+ import math
+ df = self.sc.parallelize([Row(a=i, b=math.sqrt(i)) for i in range(10)]).toDF()
+ corr = df.stat.corr("a", "b")
+ self.assertTrue(abs(corr - 0.95734012) < 1e-6)
+
def test_cov(self):
df = self.sc.parallelize([Row(a=i, b=2 * i) for i in range(10)]).toDF()
cov = df.stat.cov("a", "b")