author      Davies Liu <davies@databricks.com>      2014-10-30 22:25:18 -0700
committer   Xiangrui Meng <meng@databricks.com>     2014-10-30 22:25:18 -0700
commit      872fc669b497fb255db3212568f2a14c2ba0d5db (patch)
tree        6dcaa7e0b251fa5f233171e2878a4dc428db2348 /python/pyspark/mllib/stat.py
parent      0734d09320fe37edd3a02718511cda0bda852478 (diff)
[SPARK-4124] [MLlib] [PySpark] simplify serialization in MLlib Python API
Create several helper functions that call the MLlib Java API, convert the arguments to Java types, and convert the return values back to Python objects automatically. This greatly simplifies serialization in the MLlib Python API: after this change, the Python API no longer needs to deal with serialization details, which makes it easier to add new APIs.

cc mengxr

Author: Davies Liu <davies@databricks.com>

Closes #2995 from davies/cleanup and squashes the following commits:

8fa6ec6 [Davies Liu] address comments
16b85a0 [Davies Liu] Merge branch 'master' of github.com:apache/spark into cleanup
43743e5 [Davies Liu] bugfix
731331f [Davies Liu] simplify serialization in MLlib Python API
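For context, here is a minimal sketch of the helper pattern this commit relies on, pieced together from the SerDe pickling and _to_java_object_rdd calls visible in the removed code below. The real callMLlibFunc and JavaModelWrapper live in pyspark.mllib.common and handle more argument types (vectors, matrices, None), so treat this as an approximation rather than the actual implementation:

# Minimal sketch only; the real helpers are in pyspark.mllib.common.
from pyspark import SparkContext, PickleSerializer
from pyspark.rdd import RDD
from pyspark.mllib.linalg import _to_java_object_rdd  # helper the removed code used

def _py2java(sc, obj):
    # RDDs become RDDs of Java objects (the removed code did this at every
    # call site); primitives such as str/float cross the Py4J gateway natively.
    if isinstance(obj, RDD):
        return _to_java_object_rdd(obj)
    return obj

def _java2py(sc, java_obj):
    # Pickle the JVM result (SerDe.dumps) and unpickle it locally, the same
    # dance the removed @serialize decorator performed by hand.
    data = sc._jvm.SerDe.dumps(java_obj)
    return PickleSerializer().loads(str(data))

def callMLlibFunc(name, *args):
    # Call the named method on the JVM-side PythonMLLibAPI helper, converting
    # arguments on the way in and the result on the way out.
    sc = SparkContext._active_spark_context
    api = getattr(sc._jvm.PythonMLLibAPI(), name)
    return _java2py(sc, api(*[_py2java(sc, a) for a in args]))

class JavaModelWrapper(object):
    """Wraps a handle to a JVM model; subclasses forward methods via call()."""

    def __init__(self, java_model):
        self._sc = SparkContext._active_spark_context
        self._java_model = java_model

    def __del__(self):
        # Release the Py4J reference so the JVM-side object can be collected.
        self._sc._gateway.detach(self._java_model)

    def call(self, name, *args):
        args = [_py2java(self._sc, a) for a in args]
        return _java2py(self._sc, getattr(self._java_model, name)(*args))

With these in place, each wrapper method collapses from the decorator-plus-handle pattern on the removed side of the diff to a one-liner such as return self.call("mean").toArray().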
Diffstat (limited to 'python/pyspark/mllib/stat.py')
-rw-r--r--  python/pyspark/mllib/stat.py  65
1 file changed, 13 insertions(+), 52 deletions(-)
diff --git a/python/pyspark/mllib/stat.py b/python/pyspark/mllib/stat.py
index 84baf12b90..15f0652f83 100644
--- a/python/pyspark/mllib/stat.py
+++ b/python/pyspark/mllib/stat.py
@@ -19,66 +19,36 @@
Python package for statistical functions in MLlib.
"""
-from functools import wraps
-
-from pyspark import PickleSerializer
-from pyspark.mllib.linalg import _convert_to_vector, _to_java_object_rdd
+from pyspark.mllib.common import callMLlibFunc, JavaModelWrapper
+from pyspark.mllib.linalg import _convert_to_vector
__all__ = ['MultivariateStatisticalSummary', 'Statistics']
-def serialize(f):
- ser = PickleSerializer()
-
- @wraps(f)
- def func(self):
- jvec = f(self)
- bytes = self._sc._jvm.SerDe.dumps(jvec)
- return ser.loads(str(bytes)).toArray()
-
- return func
-
-
-class MultivariateStatisticalSummary(object):
+class MultivariateStatisticalSummary(JavaModelWrapper):
"""
Trait for multivariate statistical summary of a data matrix.
"""
- def __init__(self, sc, java_summary):
- """
- :param sc: Spark context
- :param java_summary: Handle to Java summary object
- """
- self._sc = sc
- self._java_summary = java_summary
-
- def __del__(self):
- self._sc._gateway.detach(self._java_summary)
-
- @serialize
def mean(self):
- return self._java_summary.mean()
+ return self.call("mean").toArray()
- @serialize
def variance(self):
- return self._java_summary.variance()
+ return self.call("variance").toArray()
def count(self):
- return self._java_summary.count()
+ return self.call("count")
- @serialize
def numNonzeros(self):
- return self._java_summary.numNonzeros()
+ return self.call("numNonzeros").toArray()
- @serialize
def max(self):
- return self._java_summary.max()
+ return self.call("max").toArray()
- @serialize
def min(self):
- return self._java_summary.min()
+ return self.call("min").toArray()
class Statistics(object):
@@ -106,10 +76,8 @@ class Statistics(object):
>>> cStats.min()
array([ 2., 0., 0., -2.])
"""
- sc = rdd.ctx
- jrdd = _to_java_object_rdd(rdd.map(_convert_to_vector))
- cStats = sc._jvm.PythonMLLibAPI().colStats(jrdd)
- return MultivariateStatisticalSummary(sc, cStats)
+ cStats = callMLlibFunc("colStats", rdd.map(_convert_to_vector))
+ return MultivariateStatisticalSummary(cStats)
@staticmethod
def corr(x, y=None, method=None):
@@ -156,7 +124,6 @@ class Statistics(object):
... except TypeError:
... pass
"""
- sc = x.ctx
# Check inputs to determine whether a single value or a matrix is needed for output.
# Since it's legal for users to use the method name as the second argument, we need to
# check if y is used to specify the method name instead.
@@ -164,15 +131,9 @@ class Statistics(object):
raise TypeError("Use 'method=' to specify method name.")
if not y:
- jx = _to_java_object_rdd(x.map(_convert_to_vector))
- resultMat = sc._jvm.PythonMLLibAPI().corr(jx, method)
- bytes = sc._jvm.SerDe.dumps(resultMat)
- ser = PickleSerializer()
- return ser.loads(str(bytes)).toArray()
+ return callMLlibFunc("corr", x.map(_convert_to_vector), method).toArray()
else:
- jx = _to_java_object_rdd(x.map(float))
- jy = _to_java_object_rdd(y.map(float))
- return sc._jvm.PythonMLLibAPI().corr(jx, jy, method)
+ return callMLlibFunc("corr", x.map(float), y.map(float), method)
def _test():
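
End to end, the public API is unchanged by this refactoring; only the plumbing underneath differs. A short usage example in the spirit of the doctests above (the SparkContext setup and sample data are illustrative):

from pyspark import SparkContext
from pyspark.mllib.stat import Statistics

sc = SparkContext("local", "stat-example")

# Column-wise summary statistics over an RDD of vectors.
mat = sc.parallelize([[2.0, 0.0, 5.0], [4.0, 8.0, 7.0], [6.0, 4.0, 9.0]])
summary = Statistics.colStats(mat)
print(summary.mean())       # per-column means as a numpy array
print(summary.variance())   # per-column variances
print(summary.count())      # number of rows

# Correlation between two RDDs of doubles.
x = sc.parallelize([1.0, 2.0, 3.0, 4.0])
y = sc.parallelize([2.0, 4.0, 6.0, 8.0])
print(Statistics.corr(x, y, method="pearson"))  # 1.0 for perfectly correlated series

sc.stop()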