diff options
author | Doris Xin <doris.s.xin@gmail.com> | 2014-08-12 23:47:42 -0700 |
---|---|---|
committer | Xiangrui Meng <meng@databricks.com> | 2014-08-12 23:47:42 -0700 |
commit | fe4735958e62b1b32a01960503876000f3d2e520 (patch) | |
tree | 4d17db757f96e2d70017cb2990b2b020d5cdc4b1 /python/pyspark/mllib | |
parent | 2bd812639c3d8c62a725fb7577365ef0816f2898 (diff) | |
download | spark-fe4735958e62b1b32a01960503876000f3d2e520.tar.gz spark-fe4735958e62b1b32a01960503876000f3d2e520.tar.bz2 spark-fe4735958e62b1b32a01960503876000f3d2e520.zip |
[SPARK-2993] [MLLib] colStats (wrapper around MultivariateStatisticalSummary) in Statistics
For both Scala and Python.
The ser/de util functions were moved out of `PythonMLLibAPI` and into their own object to avoid creating the `PythonMLLibAPI` object inside of `MultivariateStatisticalSummarySerialized`, which is then referenced inside of a method in `PythonMLLibAPI`.
`MultivariateStatisticalSummarySerialized` was created to serialize the `Vector` fields in `MultivariateStatisticalSummary`.
Author: Doris Xin <doris.s.xin@gmail.com>
Closes #1911 from dorx/colStats and squashes the following commits:
77b9924 [Doris Xin] developerAPI tag
de9cbbe [Doris Xin] reviewer comments and moved more ser/de
459faba [Doris Xin] colStats in Statistics for both Scala and Python
Diffstat (limited to 'python/pyspark/mllib')
-rw-r--r-- | python/pyspark/mllib/stat.py | 66 |
1 files changed, 65 insertions, 1 deletions
diff --git a/python/pyspark/mllib/stat.py b/python/pyspark/mllib/stat.py index 982906b9d0..a73abc5ff9 100644 --- a/python/pyspark/mllib/stat.py +++ b/python/pyspark/mllib/stat.py @@ -22,12 +22,76 @@ Python package for statistical functions in MLlib. from pyspark.mllib._common import \ _get_unmangled_double_vector_rdd, _get_unmangled_rdd, \ _serialize_double, _serialize_double_vector, \ - _deserialize_double, _deserialize_double_matrix + _deserialize_double, _deserialize_double_matrix, _deserialize_double_vector + + +class MultivariateStatisticalSummary(object): + + """ + Trait for multivariate statistical summary of a data matrix. + """ + + def __init__(self, sc, java_summary): + """ + :param sc: Spark context + :param java_summary: Handle to Java summary object + """ + self._sc = sc + self._java_summary = java_summary + + def __del__(self): + self._sc._gateway.detach(self._java_summary) + + def mean(self): + return _deserialize_double_vector(self._java_summary.mean()) + + def variance(self): + return _deserialize_double_vector(self._java_summary.variance()) + + def count(self): + return self._java_summary.count() + + def numNonzeros(self): + return _deserialize_double_vector(self._java_summary.numNonzeros()) + + def max(self): + return _deserialize_double_vector(self._java_summary.max()) + + def min(self): + return _deserialize_double_vector(self._java_summary.min()) class Statistics(object): @staticmethod + def colStats(X): + """ + Computes column-wise summary statistics for the input RDD[Vector]. + + >>> from linalg import Vectors + >>> rdd = sc.parallelize([Vectors.dense([2, 0, 0, -2]), + ... Vectors.dense([4, 5, 0, 3]), + ... Vectors.dense([6, 7, 0, 8])]) + >>> cStats = Statistics.colStats(rdd) + >>> cStats.mean() + array([ 4., 4., 0., 3.]) + >>> cStats.variance() + array([ 4., 13., 0., 25.]) + >>> cStats.count() + 3L + >>> cStats.numNonzeros() + array([ 3., 2., 0., 3.]) + >>> cStats.max() + array([ 6., 7., 0., 8.]) + >>> cStats.min() + array([ 2., 0., 0., -2.]) + """ + sc = X.ctx + Xser = _get_unmangled_double_vector_rdd(X) + cStats = sc._jvm.PythonMLLibAPI().colStats(Xser._jrdd) + return MultivariateStatisticalSummary(sc, cStats) + + @staticmethod def corr(x, y=None, method=None): """ Compute the correlation (matrix) for the input RDD(s) using the |