aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark/ml/stat.py
diff options
context:
space:
mode:
Diffstat (limited to 'python/pyspark/ml/stat.py')
-rw-r--r--python/pyspark/ml/stat.py61
1 files changed, 61 insertions, 0 deletions
diff --git a/python/pyspark/ml/stat.py b/python/pyspark/ml/stat.py
index db043ff68f..079b0833e1 100644
--- a/python/pyspark/ml/stat.py
+++ b/python/pyspark/ml/stat.py
@@ -71,6 +71,67 @@ class ChiSquareTest(object):
return _java2py(sc, javaTestObj.test(*args))
+class Correlation(object):
+ """
+ .. note:: Experimental
+
+ Compute the correlation matrix for the input dataset of Vectors using the specified method.
+ Methods currently supported: `pearson` (default), `spearman`.
+
+ .. note:: For Spearman, a rank correlation, we need to create an RDD[Double] for each column
+ and sort it in order to retrieve the ranks and then join the columns back into an RDD[Vector],
+ which is fairly costly. Cache the input Dataset before calling corr with `method = 'spearman'`
+ to avoid recomputing the common lineage.
+
+ :param dataset:
+ A dataset or a dataframe.
+ :param column:
+ The name of the column of vectors for which the correlation coefficient needs
+ to be computed. This must be a column of the dataset, and it must contain
+ Vector objects.
+ :param method:
+ String specifying the method to use for computing correlation.
+ Supported: `pearson` (default), `spearman`.
+ :return:
+ A dataframe that contains the correlation matrix of the column of vectors. This
+ dataframe contains a single row and a single column of name
+ '$METHODNAME($COLUMN)'.
+
+ >>> from pyspark.ml.linalg import Vectors
+ >>> from pyspark.ml.stat import Correlation
+ >>> dataset = [[Vectors.dense([1, 0, 0, -2])],
+ ... [Vectors.dense([4, 5, 0, 3])],
+ ... [Vectors.dense([6, 7, 0, 8])],
+ ... [Vectors.dense([9, 0, 0, 1])]]
+ >>> dataset = spark.createDataFrame(dataset, ['features'])
+ >>> pearsonCorr = Correlation.corr(dataset, 'features', 'pearson').collect()[0][0]
+ >>> print(str(pearsonCorr).replace('nan', 'NaN'))
+ DenseMatrix([[ 1. , 0.0556..., NaN, 0.4004...],
+ [ 0.0556..., 1. , NaN, 0.9135...],
+ [ NaN, NaN, 1. , NaN],
+ [ 0.4004..., 0.9135..., NaN, 1. ]])
+ >>> spearmanCorr = Correlation.corr(dataset, 'features', method='spearman').collect()[0][0]
+ >>> print(str(spearmanCorr).replace('nan', 'NaN'))
+ DenseMatrix([[ 1. , 0.1054..., NaN, 0.4 ],
+ [ 0.1054..., 1. , NaN, 0.9486... ],
+ [ NaN, NaN, 1. , NaN],
+ [ 0.4 , 0.9486... , NaN, 1. ]])
+
+ .. versionadded:: 2.2.0
+
+ """
+ @staticmethod
+ @since("2.2.0")
+ def corr(dataset, column, method="pearson"):
+ """
+ Compute the correlation matrix with specified method using dataset.
+ """
+ sc = SparkContext._active_spark_context
+ javaCorrObj = _jvm().org.apache.spark.ml.stat.Correlation
+ args = [_py2java(sc, arg) for arg in (dataset, column, method)]
+ return _java2py(sc, javaCorrObj.corr(*args))
+
+
if __name__ == "__main__":
import doctest
import pyspark.ml.stat