aboutsummaryrefslogtreecommitdiff
path: root/python
diff options
context:
space:
mode:
authorDoris Xin <doris.s.xin@gmail.com>2014-08-12 23:47:42 -0700
committerXiangrui Meng <meng@databricks.com>2014-08-12 23:47:42 -0700
commitfe4735958e62b1b32a01960503876000f3d2e520 (patch)
tree4d17db757f96e2d70017cb2990b2b020d5cdc4b1 /python
parent2bd812639c3d8c62a725fb7577365ef0816f2898 (diff)
downloadspark-fe4735958e62b1b32a01960503876000f3d2e520.tar.gz
spark-fe4735958e62b1b32a01960503876000f3d2e520.tar.bz2
spark-fe4735958e62b1b32a01960503876000f3d2e520.zip
[SPARK-2993] [MLLib] colStats (wrapper around MultivariateStatisticalSummary) in Statistics
For both Scala and Python. The ser/de util functions were moved out of `PythonMLLibAPI` and into their own object to avoid creating the `PythonMLLibAPI` object inside of `MultivariateStatisticalSummarySerialized`, which is then referenced inside of a method in `PythonMLLibAPI`. `MultivariateStatisticalSummarySerialized` was created to serialize the `Vector` fields in `MultivariateStatisticalSummary`. Author: Doris Xin <doris.s.xin@gmail.com> Closes #1911 from dorx/colStats and squashes the following commits: 77b9924 [Doris Xin] developerAPI tag de9cbbe [Doris Xin] reviewer comments and moved more ser/de 459faba [Doris Xin] colStats in Statistics for both Scala and Python
Diffstat (limited to 'python')
-rw-r--r--python/pyspark/mllib/stat.py66
1 file changed, 65 insertions, 1 deletion
diff --git a/python/pyspark/mllib/stat.py b/python/pyspark/mllib/stat.py
index 982906b9d0..a73abc5ff9 100644
--- a/python/pyspark/mllib/stat.py
+++ b/python/pyspark/mllib/stat.py
@@ -22,12 +22,76 @@ Python package for statistical functions in MLlib.
from pyspark.mllib._common import \
_get_unmangled_double_vector_rdd, _get_unmangled_rdd, \
_serialize_double, _serialize_double_vector, \
- _deserialize_double, _deserialize_double_matrix
+ _deserialize_double, _deserialize_double_matrix, _deserialize_double_vector
+
+
class MultivariateStatisticalSummary(object):

    """
    Trait for multivariate statistical summary of a data matrix.

    Thin Python wrapper around a Java ``MultivariateStatisticalSummary``
    handle obtained through Py4J.  Vector-valued statistics are converted
    to NumPy arrays by ``_deserialize_double_vector`` (imported at the top
    of this module).
    """

    def __init__(self, sc, java_summary):
        """
        :param sc: Spark context
        :param java_summary: Handle to Java summary object
        """
        self._sc = sc
        self._java_summary = java_summary

    def __del__(self):
        # Drop the Py4J reference to the Java summary so the JVM-side
        # object can be garbage collected when this wrapper goes away.
        self._sc._gateway.detach(self._java_summary)

    def mean(self):
        # Column-wise mean, deserialized into a NumPy array.
        return _deserialize_double_vector(self._java_summary.mean())

    def variance(self):
        # Column-wise variance, deserialized into a NumPy array.
        return _deserialize_double_vector(self._java_summary.variance())

    def count(self):
        # Number of rows (vectors) in the data matrix; returned directly
        # from the JVM (a Python long, per the colStats doctest below).
        return self._java_summary.count()

    def numNonzeros(self):
        # Per-column count of nonzero entries, as a NumPy array.
        return _deserialize_double_vector(self._java_summary.numNonzeros())

    def max(self):
        # Column-wise maximum, as a NumPy array.
        return _deserialize_double_vector(self._java_summary.max())

    def min(self):
        # Column-wise minimum, as a NumPy array.
        return _deserialize_double_vector(self._java_summary.min())
class Statistics(object):
    @staticmethod
    def colStats(X):
        """
        Computes column-wise summary statistics for the input RDD[Vector].

        :param X: an RDD of Vectors, one per row of the data matrix.
        :return: a :class:`MultivariateStatisticalSummary` wrapping the
                 Java-side summary object.

        >>> from pyspark.mllib.linalg import Vectors
        >>> rdd = sc.parallelize([Vectors.dense([2, 0, 0, -2]),
        ...                       Vectors.dense([4, 5, 0, 3]),
        ...                       Vectors.dense([6, 7, 0, 8])])
        >>> cStats = Statistics.colStats(rdd)
        >>> cStats.mean()
        array([ 4.,  4.,  0.,  3.])
        >>> cStats.variance()
        array([  4.,  13.,   0.,  25.])
        >>> cStats.count()
        3L
        >>> cStats.numNonzeros()
        array([ 3.,  2.,  0.,  3.])
        >>> cStats.max()
        array([ 6.,  7.,  0.,  8.])
        >>> cStats.min()
        array([ 2.,  0.,  0., -2.])
        """
        sc = X.ctx
        # Serialize the vectors into the flat byte format the JVM-side
        # PythonMLLibAPI expects before handing the RDD across Py4J.
        Xser = _get_unmangled_double_vector_rdd(X)
        cStats = sc._jvm.PythonMLLibAPI().colStats(Xser._jrdd)
        return MultivariateStatisticalSummary(sc, cStats)
+
+ @staticmethod
def corr(x, y=None, method=None):
"""
Compute the correlation (matrix) for the input RDD(s) using the