diff options
Diffstat (limited to 'python/pyspark/mllib/stat.py')
-rw-r--r-- | python/pyspark/mllib/stat.py | 16 |
1 files changed, 15 insertions, 1 deletions
diff --git a/python/pyspark/mllib/stat.py b/python/pyspark/mllib/stat.py index 0700f8a8e5..1980f5b03f 100644 --- a/python/pyspark/mllib/stat.py +++ b/python/pyspark/mllib/stat.py @@ -22,6 +22,7 @@ Python package for statistical functions in MLlib. from pyspark import RDD from pyspark.mllib.common import callMLlibFunc, JavaModelWrapper from pyspark.mllib.linalg import Matrix, _convert_to_vector +from pyspark.mllib.regression import LabeledPoint __all__ = ['MultivariateStatisticalSummary', 'ChiSqTestResult', 'Statistics'] @@ -107,6 +108,11 @@ class Statistics(object): """ Computes column-wise summary statistics for the input RDD[Vector]. + :param rdd: an RDD[Vector] for which column-wise summary statistics + are to be computed. + :return: :class:`MultivariateStatisticalSummary` object containing + column-wise summary statistics. + >>> from pyspark.mllib.linalg import Vectors >>> rdd = sc.parallelize([Vectors.dense([2, 0, 0, -2]), ... Vectors.dense([4, 5, 0, 3]), @@ -140,6 +146,13 @@ class Statistics(object): to specify the method to be used for single RDD inout. If two RDDs of floats are passed in, a single float is returned. + :param x: an RDD of vector for which the correlation matrix is to be computed, + or an RDD of float of the same cardinality as y when y is specified. + :param y: an RDD of float of the same cardinality as x. + :param method: String specifying the method to use for computing correlation. + Supported: `pearson` (default), `spearman` + :return: Correlation matrix comparing columns in x. + >>> x = sc.parallelize([1.0, 0.0, -2.0], 2) >>> y = sc.parallelize([4.0, 5.0, 3.0], 2) >>> zeros = sc.parallelize([0.0, 0.0, 0.0], 2) @@ -242,7 +255,6 @@ class Statistics(object): >>> print round(chi.statistic, 4) 21.9958 - >>> from pyspark.mllib.regression import LabeledPoint >>> data = [LabeledPoint(0.0, Vectors.dense([0.5, 10.0])), ... LabeledPoint(0.0, Vectors.dense([1.5, 20.0])), ... LabeledPoint(1.0, Vectors.dense([1.5, 30.0])), @@ -257,6 +269,8 @@ class Statistics(object): 1.5 """ if isinstance(observed, RDD): + if not isinstance(observed.first(), LabeledPoint): + raise ValueError("observed should be an RDD of LabeledPoint") jmodels = callMLlibFunc("chiSqTest", observed) return [ChiSqTestResult(m) for m in jmodels] |