aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark/mllib/stat.py
diff options
context:
space:
mode:
Diffstat (limited to 'python/pyspark/mllib/stat.py')
-rw-r--r--python/pyspark/mllib/stat.py16
1 files changed, 15 insertions, 1 deletions
diff --git a/python/pyspark/mllib/stat.py b/python/pyspark/mllib/stat.py
index 0700f8a8e5..1980f5b03f 100644
--- a/python/pyspark/mllib/stat.py
+++ b/python/pyspark/mllib/stat.py
@@ -22,6 +22,7 @@ Python package for statistical functions in MLlib.
from pyspark import RDD
from pyspark.mllib.common import callMLlibFunc, JavaModelWrapper
from pyspark.mllib.linalg import Matrix, _convert_to_vector
+from pyspark.mllib.regression import LabeledPoint
__all__ = ['MultivariateStatisticalSummary', 'ChiSqTestResult', 'Statistics']
@@ -107,6 +108,11 @@ class Statistics(object):
"""
Computes column-wise summary statistics for the input RDD[Vector].
+ :param rdd: an RDD[Vector] for which column-wise summary statistics
+ are to be computed.
+ :return: :class:`MultivariateStatisticalSummary` object containing
+ column-wise summary statistics.
+
>>> from pyspark.mllib.linalg import Vectors
>>> rdd = sc.parallelize([Vectors.dense([2, 0, 0, -2]),
... Vectors.dense([4, 5, 0, 3]),
@@ -140,6 +146,13 @@ class Statistics(object):
to specify the method to be used for single RDD input.
If two RDDs of floats are passed in, a single float is returned.
+ :param x: an RDD of vectors for which the correlation matrix is to be computed,
+ or an RDD of floats of the same cardinality as y when y is specified.
+ :param y: an RDD of floats of the same cardinality as x.
+ :param method: String specifying the method to use for computing correlation.
+ Supported: `pearson` (default), `spearman`
+ :return: Correlation matrix comparing columns in x, or a single float when y is specified.
+
>>> x = sc.parallelize([1.0, 0.0, -2.0], 2)
>>> y = sc.parallelize([4.0, 5.0, 3.0], 2)
>>> zeros = sc.parallelize([0.0, 0.0, 0.0], 2)
@@ -242,7 +255,6 @@ class Statistics(object):
>>> print round(chi.statistic, 4)
21.9958
- >>> from pyspark.mllib.regression import LabeledPoint
>>> data = [LabeledPoint(0.0, Vectors.dense([0.5, 10.0])),
... LabeledPoint(0.0, Vectors.dense([1.5, 20.0])),
... LabeledPoint(1.0, Vectors.dense([1.5, 30.0])),
@@ -257,6 +269,8 @@ class Statistics(object):
1.5
"""
if isinstance(observed, RDD):
+ if not isinstance(observed.first(), LabeledPoint):
+ raise ValueError("observed should be an RDD of LabeledPoint")
jmodels = callMLlibFunc("chiSqTest", observed)
return [ChiSqTestResult(m) for m in jmodels]