author     Liang-Chi Hsieh <viirya@gmail.com>    2017-04-07 11:00:10 +0200
committer  Nick Pentreath <nickp@za.ibm.com>     2017-04-07 11:00:10 +0200
commit     1a52a62377a87cec493c8c6711bfd44e779c7973 (patch)
tree       ddaccd67e7298284cfac3f01d038449900a001b3
parent     ad3cc1312db3b5667cea134940a09896a4609b74 (diff)
[SPARK-20076][ML][PYSPARK] Add Python interface for ml.stats.Correlation
## What changes were proposed in this pull request?

The DataFrame-based support for correlation statistics was added in #17108. This patch adds the Python interface for it.

## How was this patch tested?

Python unit test.

Please review http://spark.apache.org/contributing.html before opening a pull request.

Author: Liang-Chi Hsieh <viirya@gmail.com>

Closes #17494 from viirya/correlation-python-api.
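For context, this is roughly how the new Python API is intended to be used once the patch lands. A minimal sketch, assuming a pyspark shell where `spark` is a live SparkSession:

    from pyspark.ml.linalg import Vectors
    from pyspark.ml.stat import Correlation

    # A toy DataFrame with a single column of Vectors; a column of
    # vectors is what Correlation.corr consumes.
    df = spark.createDataFrame(
        [(Vectors.dense([1.0, 2.0]),),
         (Vectors.dense([3.0, 1.0]),),
         (Vectors.dense([5.0, 4.0]),)],
        ["features"])

    # corr returns a one-row, one-column DataFrame holding the
    # correlation Matrix; the method defaults to 'pearson'.
    pearson = Correlation.corr(df, "features").head()[0]
    print(pearson)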
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/stat/Correlation.scala |  8
-rw-r--r--  python/pyspark/ml/stat.py                                       | 61
2 files changed, 65 insertions(+), 4 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/stat/Correlation.scala b/mllib/src/main/scala/org/apache/spark/ml/stat/Correlation.scala
index d3c84b77d2..e185bc8a6f 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/stat/Correlation.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/stat/Correlation.scala
@@ -38,7 +38,7 @@ object Correlation {
 
   /**
    * :: Experimental ::
-   * Compute the correlation matrix for the input RDD of Vectors using the specified method.
+   * Compute the correlation matrix for the input Dataset of Vectors using the specified method.
    * Methods currently supported: `pearson` (default), `spearman`.
    *
    * @param dataset A dataset or a dataframe
@@ -56,14 +56,14 @@ object Correlation {
    * Here is how to access the correlation coefficient:
    * {{{
    *   val data: Dataset[Vector] = ...
-   *   val Row(coeff: Matrix) = Statistics.corr(data, "value").head
+   *   val Row(coeff: Matrix) = Correlation.corr(data, "value").head
    *   // coeff now contains the Pearson correlation matrix.
    * }}}
    *
    * @note For Spearman, a rank correlation, we need to create an RDD[Double] for each column
    * and sort it in order to retrieve the ranks and then join the columns back into an RDD[Vector],
-   * which is fairly costly. Cache the input RDD before calling corr with `method = "spearman"` to
-   * avoid recomputing the common lineage.
+   * which is fairly costly. Cache the input Dataset before calling corr with `method = "spearman"`
+   * to avoid recomputing the common lineage.
    */
   @Since("2.2.0")
   def corr(dataset: Dataset[_], column: String, method: String): DataFrame = {
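The @note above carries over to the Python side unchanged: Spearman ranks each column separately and joins the ranked columns back together, so the input is traversed more than once. A minimal sketch of the caching advice, reusing the hypothetical `df` from the sketch above:

    # Cache before a Spearman correlation so the rank-and-join passes
    # do not recompute df's lineage from scratch each time.
    df.cache()
    spearman = Correlation.corr(df, "features", method="spearman").head()[0]
    df.unpersist()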
diff --git a/python/pyspark/ml/stat.py b/python/pyspark/ml/stat.py
index db043ff68f..079b0833e1 100644
--- a/python/pyspark/ml/stat.py
+++ b/python/pyspark/ml/stat.py
@@ -71,6 +71,67 @@ class ChiSquareTest(object):
         return _java2py(sc, javaTestObj.test(*args))
 
 
+class Correlation(object):
+    """
+    .. note:: Experimental
+
+    Compute the correlation matrix for the input dataset of Vectors using the specified method.
+    Methods currently supported: `pearson` (default), `spearman`.
+
+    .. note:: For Spearman, a rank correlation, we need to create an RDD[Double] for each column
+      and sort it in order to retrieve the ranks and then join the columns back into an RDD[Vector],
+      which is fairly costly. Cache the input Dataset before calling corr with `method = 'spearman'`
+      to avoid recomputing the common lineage.
+
+    :param dataset:
+      A dataset or a dataframe.
+    :param column:
+      The name of the column of vectors for which the correlation coefficient needs
+      to be computed. This must be a column of the dataset, and it must contain
+      Vector objects.
+    :param method:
+      String specifying the method to use for computing correlation.
+      Supported: `pearson` (default), `spearman`.
+    :return:
+      A dataframe that contains the correlation matrix of the column of vectors. This
+      dataframe contains a single row and a single column of name
+      '$METHODNAME($COLUMN)'.
+
+    >>> from pyspark.ml.linalg import Vectors
+    >>> from pyspark.ml.stat import Correlation
+    >>> dataset = [[Vectors.dense([1, 0, 0, -2])],
+    ...            [Vectors.dense([4, 5, 0, 3])],
+    ...            [Vectors.dense([6, 7, 0, 8])],
+    ...            [Vectors.dense([9, 0, 0, 1])]]
+    >>> dataset = spark.createDataFrame(dataset, ['features'])
+    >>> pearsonCorr = Correlation.corr(dataset, 'features', 'pearson').collect()[0][0]
+    >>> print(str(pearsonCorr).replace('nan', 'NaN'))
+    DenseMatrix([[ 1.        ,  0.0556...,         NaN,  0.4004...],
+                 [ 0.0556...,  1.        ,         NaN,  0.9135...],
+                 [        NaN,         NaN,  1.        ,         NaN],
+                 [ 0.4004...,  0.9135...,         NaN,  1.        ]])
+    >>> spearmanCorr = Correlation.corr(dataset, 'features', method='spearman').collect()[0][0]
+    >>> print(str(spearmanCorr).replace('nan', 'NaN'))
+    DenseMatrix([[ 1.        ,  0.1054...,         NaN,  0.4       ],
+                 [ 0.1054...,  1.        ,         NaN,  0.9486... ],
+                 [        NaN,         NaN,  1.        ,         NaN],
+                 [ 0.4       ,  0.9486... ,         NaN,  1.        ]])
+
+    .. versionadded:: 2.2.0
+
+    """
+    @staticmethod
+    @since("2.2.0")
+    def corr(dataset, column, method="pearson"):
+        """
+        Compute the correlation matrix with specified method using dataset.
+        """
+        sc = SparkContext._active_spark_context
+        javaCorrObj = _jvm().org.apache.spark.ml.stat.Correlation
+        args = [_py2java(sc, arg) for arg in (dataset, column, method)]
+        return _java2py(sc, javaCorrObj.corr(*args))
+
+
 if __name__ == "__main__":
     import doctest
     import pyspark.ml.stat
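The `__main__` block is truncated here; in the repository it goes on to run the module's doctests, which is how the examples in the `Correlation` docstring double as the Python unit test mentioned in the commit message. A hypothetical sketch of that conventional pattern (the real block also places a SparkSession into the doctest globals so that `spark` resolves inside the examples):

    import doctest
    import sys

    import pyspark.ml.stat

    # ELLIPSIS lets outputs like 0.0556... in the docstring match the
    # full floating-point values printed at run time.
    (failure_count, test_count) = doctest.testmod(
        pyspark.ml.stat, optionflags=doctest.ELLIPSIS)
    if failure_count:
        sys.exit(-1)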