aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark/mllib/stat.py
diff options
context:
space:
mode:
Diffstat (limited to 'python/pyspark/mllib/stat.py')
-rw-r--r--python/pyspark/mllib/stat.py66
1 files changed, 65 insertions, 1 deletions
diff --git a/python/pyspark/mllib/stat.py b/python/pyspark/mllib/stat.py
index 982906b9d0..a73abc5ff9 100644
--- a/python/pyspark/mllib/stat.py
+++ b/python/pyspark/mllib/stat.py
@@ -22,12 +22,76 @@ Python package for statistical functions in MLlib.
from pyspark.mllib._common import \
_get_unmangled_double_vector_rdd, _get_unmangled_rdd, \
_serialize_double, _serialize_double_vector, \
- _deserialize_double, _deserialize_double_matrix
+ _deserialize_double, _deserialize_double_matrix, _deserialize_double_vector
+
+
+class MultivariateStatisticalSummary(object):
+
+ """
+ Trait for multivariate statistical summary of a data matrix.
+ """
+
+ def __init__(self, sc, java_summary):
+ """
+ :param sc: Spark context
+ :param java_summary: Handle to Java summary object
+ """
+ self._sc = sc
+ self._java_summary = java_summary
+
+ def __del__(self):
+ self._sc._gateway.detach(self._java_summary)
+
+ def mean(self):
+ return _deserialize_double_vector(self._java_summary.mean())
+
+ def variance(self):
+ return _deserialize_double_vector(self._java_summary.variance())
+
+ def count(self):
+ return self._java_summary.count()
+
+ def numNonzeros(self):
+ return _deserialize_double_vector(self._java_summary.numNonzeros())
+
+ def max(self):
+ return _deserialize_double_vector(self._java_summary.max())
+
+ def min(self):
+ return _deserialize_double_vector(self._java_summary.min())
class Statistics(object):
@staticmethod
+ def colStats(X):
+ """
+ Computes column-wise summary statistics for the input RDD[Vector].
+
+ >>> from linalg import Vectors
+ >>> rdd = sc.parallelize([Vectors.dense([2, 0, 0, -2]),
+ ... Vectors.dense([4, 5, 0, 3]),
+ ... Vectors.dense([6, 7, 0, 8])])
+ >>> cStats = Statistics.colStats(rdd)
+ >>> cStats.mean()
+ array([ 4., 4., 0., 3.])
+ >>> cStats.variance()
+ array([ 4., 13., 0., 25.])
+ >>> cStats.count()
+ 3L
+ >>> cStats.numNonzeros()
+ array([ 3., 2., 0., 3.])
+ >>> cStats.max()
+ array([ 6., 7., 0., 8.])
+ >>> cStats.min()
+ array([ 2., 0., 0., -2.])
+ """
+ sc = X.ctx
+ Xser = _get_unmangled_double_vector_rdd(X)
+ cStats = sc._jvm.PythonMLLibAPI().colStats(Xser._jrdd)
+ return MultivariateStatisticalSummary(sc, cStats)
+
+ @staticmethod
def corr(x, y=None, method=None):
"""
Compute the correlation (matrix) for the input RDD(s) using the