aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark/mllib/random.py
diff options
context:
space:
mode:
authorDavies Liu <davies@databricks.com>2014-10-30 22:25:18 -0700
committerXiangrui Meng <meng@databricks.com>2014-10-30 22:25:18 -0700
commit872fc669b497fb255db3212568f2a14c2ba0d5db (patch)
tree6dcaa7e0b251fa5f233171e2878a4dc428db2348 /python/pyspark/mllib/random.py
parent0734d09320fe37edd3a02718511cda0bda852478 (diff)
downloadspark-872fc669b497fb255db3212568f2a14c2ba0d5db.tar.gz
spark-872fc669b497fb255db3212568f2a14c2ba0d5db.tar.bz2
spark-872fc669b497fb255db3212568f2a14c2ba0d5db.zip
[SPARK-4124] [MLlib] [PySpark] simplify serialization in MLlib Python API
Create several helper functions to call MLlib Java API, convert the arguments to Java type and convert return value to Python object automatically, this simplify serialization in MLlib Python API very much. After this, the MLlib Python API does not need to deal with serialization details anymore, it's easier to add new API. cc mengxr Author: Davies Liu <davies@databricks.com> Closes #2995 from davies/cleanup and squashes the following commits: 8fa6ec6 [Davies Liu] address comments 16b85a0 [Davies Liu] Merge branch 'master' of github.com:apache/spark into cleanup 43743e5 [Davies Liu] bugfix 731331f [Davies Liu] simplify serialization in MLlib Python API
Diffstat (limited to 'python/pyspark/mllib/random.py')
-rw-r--r--python/pyspark/mllib/random.py34
1 files changed, 8 insertions, 26 deletions
diff --git a/python/pyspark/mllib/random.py b/python/pyspark/mllib/random.py
index 2202c51ab9..7eebfc6bcd 100644
--- a/python/pyspark/mllib/random.py
+++ b/python/pyspark/mllib/random.py
@@ -21,22 +21,12 @@ Python package for random data generation.
from functools import wraps
-from pyspark.rdd import RDD
-from pyspark.serializers import BatchedSerializer, PickleSerializer
+from pyspark.mllib.common import callMLlibFunc
__all__ = ['RandomRDDs', ]
-def serialize(f):
- @wraps(f)
- def func(sc, *a, **kw):
- jrdd = f(sc, *a, **kw)
- return RDD(sc._jvm.SerDe.javaToPython(jrdd), sc,
- BatchedSerializer(PickleSerializer(), 1024))
- return func
-
-
def toArray(f):
@wraps(f)
def func(sc, *a, **kw):
@@ -52,7 +42,6 @@ class RandomRDDs(object):
"""
@staticmethod
- @serialize
def uniformRDD(sc, size, numPartitions=None, seed=None):
"""
Generates an RDD comprised of i.i.d. samples from the
@@ -74,10 +63,9 @@ class RandomRDDs(object):
>>> parts == sc.defaultParallelism
True
"""
- return sc._jvm.PythonMLLibAPI().uniformRDD(sc._jsc, size, numPartitions, seed)
+ return callMLlibFunc("uniformRDD", sc._jsc, size, numPartitions, seed)
@staticmethod
- @serialize
def normalRDD(sc, size, numPartitions=None, seed=None):
"""
Generates an RDD comprised of i.i.d. samples from the standard normal
@@ -97,10 +85,9 @@ class RandomRDDs(object):
>>> abs(stats.stdev() - 1.0) < 0.1
True
"""
- return sc._jvm.PythonMLLibAPI().normalRDD(sc._jsc, size, numPartitions, seed)
+ return callMLlibFunc("normalRDD", sc._jsc, size, numPartitions, seed)
@staticmethod
- @serialize
def poissonRDD(sc, mean, size, numPartitions=None, seed=None):
"""
Generates an RDD comprised of i.i.d. samples from the Poisson
@@ -117,11 +104,10 @@ class RandomRDDs(object):
>>> abs(stats.stdev() - sqrt(mean)) < 0.5
True
"""
- return sc._jvm.PythonMLLibAPI().poissonRDD(sc._jsc, mean, size, numPartitions, seed)
+ return callMLlibFunc("poissonRDD", sc._jsc, mean, size, numPartitions, seed)
@staticmethod
@toArray
- @serialize
def uniformVectorRDD(sc, numRows, numCols, numPartitions=None, seed=None):
"""
Generates an RDD comprised of vectors containing i.i.d. samples drawn
@@ -136,12 +122,10 @@ class RandomRDDs(object):
>>> RandomRDDs.uniformVectorRDD(sc, 10, 10, 4).getNumPartitions()
4
"""
- return sc._jvm.PythonMLLibAPI() \
- .uniformVectorRDD(sc._jsc, numRows, numCols, numPartitions, seed)
+ return callMLlibFunc("uniformVectorRDD", sc._jsc, numRows, numCols, numPartitions, seed)
@staticmethod
@toArray
- @serialize
def normalVectorRDD(sc, numRows, numCols, numPartitions=None, seed=None):
"""
Generates an RDD comprised of vectors containing i.i.d. samples drawn
@@ -156,12 +140,10 @@ class RandomRDDs(object):
>>> abs(mat.std() - 1.0) < 0.1
True
"""
- return sc._jvm.PythonMLLibAPI() \
- .normalVectorRDD(sc._jsc, numRows, numCols, numPartitions, seed)
+ return callMLlibFunc("normalVectorRDD", sc._jsc, numRows, numCols, numPartitions, seed)
@staticmethod
@toArray
- @serialize
def poissonVectorRDD(sc, mean, numRows, numCols, numPartitions=None, seed=None):
"""
Generates an RDD comprised of vectors containing i.i.d. samples drawn
@@ -179,8 +161,8 @@ class RandomRDDs(object):
>>> abs(mat.std() - sqrt(mean)) < 0.5
True
"""
- return sc._jvm.PythonMLLibAPI() \
- .poissonVectorRDD(sc._jsc, mean, numRows, numCols, numPartitions, seed)
+ return callMLlibFunc("poissonVectorRDD", sc._jsc, mean, numRows, numCols,
+ numPartitions, seed)
def _test():