path: root/python/pyspark/mllib/regression.py
author     Davies Liu <davies@databricks.com>     2014-10-30 22:25:18 -0700
committer  Xiangrui Meng <meng@databricks.com>    2014-10-30 22:25:18 -0700
commit     872fc669b497fb255db3212568f2a14c2ba0d5db
tree       6dcaa7e0b251fa5f233171e2878a4dc428db2348 /python/pyspark/mllib/regression.py
parent     0734d09320fe37edd3a02718511cda0bda852478
[SPARK-4124] [MLlib] [PySpark] simplify serialization in MLlib Python API
Create several helper functions that call the MLlib Java API, converting the
arguments to Java types and converting return values back to Python objects
automatically. This greatly simplifies serialization in the MLlib Python API:
after this change, the Python API no longer needs to deal with serialization
details, which makes it easier to add new APIs.

cc mengxr

Author: Davies Liu <davies@databricks.com>

Closes #2995 from davies/cleanup and squashes the following commits:

8fa6ec6 [Davies Liu] address comments
16b85a0 [Davies Liu] Merge branch 'master' of github.com:apache/spark into cleanup
43743e5 [Davies Liu] bugfix
731331f [Davies Liu] simplify serialization in MLlib Python API
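For context, the refactor routes every JVM call through pyspark.mllib.common
(see the new import in the diff below). A minimal sketch of what a helper like
callMLlibFunc does, assuming conversion routines _py2java/_java2py from that
same module are in scope; this illustrates the shape of the pattern, not the
exact implementation:

    from pyspark import SparkContext

    def callMLlibFunc(name, *args):
        # Sketch (assumed shape): look up the named method on the JVM-side
        # PythonMLLibAPI stub, convert each Python argument to its Java
        # counterpart, invoke the method, and convert the Java result back
        # into a Python object.
        sc = SparkContext._active_spark_context
        func = getattr(sc._jvm.PythonMLLibAPI(), name)
        args = [_py2java(sc, a) for a in args]
        return _java2py(sc, func(*args))

With this in place, callers pass plain Python objects and get plain Python
objects back; pickling, batching, and Java RDD handles are dealt with in one
place instead of in every trainer.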
Diffstat (limited to 'python/pyspark/mllib/regression.py')
-rw-r--r--  python/pyspark/mllib/regression.py  52
1 file changed, 21 insertions(+), 31 deletions(-)
diff --git a/python/pyspark/mllib/regression.py b/python/pyspark/mllib/regression.py
index 93e17faf5c..43c1a2fc10 100644
--- a/python/pyspark/mllib/regression.py
+++ b/python/pyspark/mllib/regression.py
@@ -18,9 +18,8 @@
 import numpy as np
 from numpy import array
 
-from pyspark import SparkContext
-from pyspark.serializers import PickleSerializer, AutoBatchedSerializer
-from pyspark.mllib.linalg import SparseVector, _convert_to_vector, _to_java_object_rdd
+from pyspark.mllib.common import callMLlibFunc, _to_java_object_rdd
+from pyspark.mllib.linalg import SparseVector, _convert_to_vector
 
 __all__ = ['LabeledPoint', 'LinearModel', 'LinearRegressionModel', 'RidgeRegressionModel',
            'LinearRegressionWithSGD', 'LassoWithSGD', 'RidgeRegressionWithSGD']
@@ -124,17 +123,11 @@ class LinearRegressionModel(LinearRegressionModelBase):
 # train_func should take two parameters, namely data and initial_weights, and
 # return the result of a call to the appropriate JVM stub.
 # _regression_train_wrapper is responsible for setup and error checking.
-def _regression_train_wrapper(sc, train_func, modelClass, data, initial_weights):
+def _regression_train_wrapper(train_func, modelClass, data, initial_weights):
     initial_weights = initial_weights or [0.0] * len(data.first().features)
-    ser = PickleSerializer()
-    initial_bytes = bytearray(ser.dumps(_convert_to_vector(initial_weights)))
-    # use AutoBatchedSerializer before cache to reduce the memory
-    # overhead in JVM
-    cached = data._reserialize(AutoBatchedSerializer(ser)).cache()
-    ans = train_func(_to_java_object_rdd(cached), initial_bytes)
-    assert len(ans) == 2, "JVM call result had unexpected length"
-    weights = ser.loads(str(ans[0]))
-    return modelClass(weights, ans[1])
+    weights, intercept = train_func(_to_java_object_rdd(data, cache=True),
+                                    _convert_to_vector(initial_weights))
+    return modelClass(weights, intercept)
 
 
 class LinearRegressionWithSGD(object):
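Given the contract described in the comment above (train_func receives the
prepared data and initial weights and returns a (weights, intercept) pair from
the JVM), adding a new trainer shrinks to a few lines. A hypothetical sketch,
where trainFooModelWithSGD and FooModel are made-up names used only to
illustrate the pattern:

    # Hypothetical: "Foo" is not a real MLlib algorithm, and FooModel is an
    # assumed LinearRegressionModelBase subclass; this only shows how little
    # code a new API needs after the refactor.
    class FooWithSGD(object):

        @classmethod
        def train(cls, data, iterations=100, step=1.0, initialWeights=None):
            def train(rdd, i):
                # rdd: the cached Java-side RDD; i: the converted initial weights
                return callMLlibFunc("trainFooModelWithSGD", rdd, iterations, step, i)
            return _regression_train_wrapper(train, FooModel, data, initialWeights)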
@@ -168,13 +161,12 @@ class LinearRegressionWithSGD(object):
                       training data (i.e. whether bias features
                       are activated or not).
         """
-        sc = data.context
+        def train(rdd, i):
+            return callMLlibFunc("trainLinearRegressionModelWithSGD", rdd, iterations, step,
+                                 miniBatchFraction, i, regParam, regType, intercept)
 
-        def train(jrdd, i):
-            return sc._jvm.PythonMLLibAPI().trainLinearRegressionModelWithSGD(
-                jrdd, iterations, step, miniBatchFraction, i, regParam, regType, intercept)
-
-        return _regression_train_wrapper(sc, train, LinearRegressionModel, data, initialWeights)
+        return _regression_train_wrapper(train, LinearRegressionModel,
+                                         data, initialWeights)
 
 
 class LassoModel(LinearRegressionModelBase):
@@ -216,12 +208,10 @@ class LassoWithSGD(object):
     def train(cls, data, iterations=100, step=1.0, regParam=1.0,
               miniBatchFraction=1.0, initialWeights=None):
         """Train a Lasso regression model on the given data."""
-        sc = data.context
-
-        def train(jrdd, i):
-            return sc._jvm.PythonMLLibAPI().trainLassoModelWithSGD(
-                jrdd, iterations, step, regParam, miniBatchFraction, i)
-        return _regression_train_wrapper(sc, train, LassoModel, data, initialWeights)
+        def train(rdd, i):
+            return callMLlibFunc("trainLassoModelWithSGD", rdd, iterations, step, regParam,
+                                 miniBatchFraction, i)
+        return _regression_train_wrapper(train, LassoModel, data, initialWeights)
 
 
 class RidgeRegressionModel(LinearRegressionModelBase):
@@ -263,17 +253,17 @@ class RidgeRegressionWithSGD(object):
     def train(cls, data, iterations=100, step=1.0, regParam=1.0,
               miniBatchFraction=1.0, initialWeights=None):
         """Train a ridge regression model on the given data."""
-        sc = data.context
-
-        def train(jrdd, i):
-            return sc._jvm.PythonMLLibAPI().trainRidgeModelWithSGD(
-                jrdd, iterations, step, regParam, miniBatchFraction, i)
+        def train(rdd, i):
+            return callMLlibFunc("trainRidgeModelWithSGD", rdd, iterations, step, regParam,
+                                 miniBatchFraction, i)
 
-        return _regression_train_wrapper(sc, train, RidgeRegressionModel, data, initialWeights)
+        return _regression_train_wrapper(train, RidgeRegressionModel,
+                                         data, initialWeights)
 
 
 def _test():
     import doctest
+    from pyspark import SparkContext
     globs = globals().copy()
     globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2)
     (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
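The public API is unchanged by this refactoring. For reference, a typical
training session against the Spark 1.x MLlib API:

    from pyspark import SparkContext
    from pyspark.mllib.regression import LabeledPoint, LinearRegressionWithSGD

    sc = SparkContext('local[4]', 'RegressionExample')
    # Points on the line y = 3x, so the learned weight should end up near [3.0].
    points = [LabeledPoint(3.0 * x, [x]) for x in [0.1, 0.2, 0.3, 0.4]]
    model = LinearRegressionWithSGD.train(sc.parallelize(points), iterations=100)
    print(model.weights)         # dense vector, approximately [3.0]
    print(model.predict([0.5]))  # approximately 1.5
    sc.stop()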