Diffstat (limited to 'python/pyspark/mllib/regression.py')
-rw-r--r--  python/pyspark/mllib/regression.py | 52
 1 file changed, 21 insertions(+), 31 deletions(-)
diff --git a/python/pyspark/mllib/regression.py b/python/pyspark/mllib/regression.py
index 93e17faf5c..43c1a2fc10 100644
--- a/python/pyspark/mllib/regression.py
+++ b/python/pyspark/mllib/regression.py
@@ -18,9 +18,8 @@
 import numpy as np
 from numpy import array

-from pyspark import SparkContext
-from pyspark.serializers import PickleSerializer, AutoBatchedSerializer
-from pyspark.mllib.linalg import SparseVector, _convert_to_vector, _to_java_object_rdd
+from pyspark.mllib.common import callMLlibFunc, _to_java_object_rdd
+from pyspark.mllib.linalg import SparseVector, _convert_to_vector

 __all__ = ['LabeledPoint', 'LinearModel', 'LinearRegressionModel', 'RidgeRegressionModel',
            'LinearRegressionWithSGD', 'LassoWithSGD', 'RidgeRegressionWithSGD']
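
The import change above is the heart of this patch: the hand-rolled
PickleSerializer/Py4J plumbing is replaced by the shared callMLlibFunc helper
from pyspark.mllib.common. As a rough sketch (not the verbatim implementation),
the helper looks up a method on the JVM-side PythonMLLibAPI by name and lets
the common conversion code handle argument and result (de)serialization:

    # Sketch of the helper this patch adopts; the real code lives in
    # pyspark/mllib/common.py alongside callJavaFunc.
    def callMLlibFunc(name, *args):
        from pyspark import SparkContext
        sc = SparkContext._active_spark_context
        api = getattr(sc._jvm.PythonMLLibAPI(), name)  # e.g. "trainLassoModelWithSGD"
        return callJavaFunc(sc, api, *args)            # converts args and return value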
@@ -124,17 +123,11 @@ class LinearRegressionModel(LinearRegressionModelBase):
 # train_func should take two parameters, namely data and initial_weights, and
 # return the result of a call to the appropriate JVM stub.
 # _regression_train_wrapper is responsible for setup and error checking.
-def _regression_train_wrapper(sc, train_func, modelClass, data, initial_weights):
+def _regression_train_wrapper(train_func, modelClass, data, initial_weights):
     initial_weights = initial_weights or [0.0] * len(data.first().features)
-    ser = PickleSerializer()
-    initial_bytes = bytearray(ser.dumps(_convert_to_vector(initial_weights)))
-    # use AutoBatchedSerializer before cache to reduce the memory
-    # overhead in JVM
-    cached = data._reserialize(AutoBatchedSerializer(ser)).cache()
-    ans = train_func(_to_java_object_rdd(cached), initial_bytes)
-    assert len(ans) == 2, "JVM call result had unexpected length"
-    weights = ser.loads(str(ans[0]))
-    return modelClass(weights, ans[1])
+    weights, intercept = train_func(_to_java_object_rdd(data, cache=True),
+                                    _convert_to_vector(initial_weights))
+    return modelClass(weights, intercept)


 class LinearRegressionWithSGD(object):
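
With the manual serializer shuffling gone, a train_func handed to
_regression_train_wrapper now simply receives the converted Java RDD plus the
initial weights as a Vector, and must return a (weights, intercept) pair that
the wrapper unpacks directly. A minimal stub illustrating the contract
(fake_train and points_rdd are hypothetical names, not part of the patch;
points_rdd must be an RDD of LabeledPoint since the wrapper inspects
data.first().features):

    # Hypothetical stub: train_func(java_rdd, initial_weights) -> (weights, intercept)
    def fake_train(rdd, i):
        return [0.5], 0.0  # pretend-trained weights and intercept

    model = _regression_train_wrapper(fake_train, LinearRegressionModel,
                                      points_rdd, None)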
@@ -168,13 +161,12 @@ class LinearRegressionWithSGD(object):
                                  training data (i.e. whether bias features
                                  are activated or not).
         """
-        sc = data.context
-
-        def train(jrdd, i):
-            return sc._jvm.PythonMLLibAPI().trainLinearRegressionModelWithSGD(
-                jrdd, iterations, step, miniBatchFraction, i, regParam, regType, intercept)
-
-        return _regression_train_wrapper(sc, train, LinearRegressionModel, data, initialWeights)
+        def train(rdd, i):
+            return callMLlibFunc("trainLinearRegressionModelWithSGD", rdd, iterations, step,
+                                 miniBatchFraction, i, regParam, regType, intercept)
+
+        return _regression_train_wrapper(train, LinearRegressionModel,
+                                         data, initialWeights)


 class LassoModel(LinearRegressionModelBase):
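
The public API is untouched by the refactor; training looks the same from user
code. A minimal end-to-end run, mirroring the module's own doctests:

    import numpy as np
    from pyspark import SparkContext
    from pyspark.mllib.regression import LabeledPoint, LinearRegressionWithSGD

    sc = SparkContext("local[2]", "regression-example")
    data = sc.parallelize([LabeledPoint(0.0, [0.0]), LabeledPoint(1.0, [1.0]),
                           LabeledPoint(3.0, [2.0]), LabeledPoint(2.0, [3.0])])
    lrm = LinearRegressionWithSGD.train(data, iterations=100,
                                        initialWeights=np.array([1.0]))
    lrm.predict(np.array([1.0]))  # roughly 1.0 for this data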
@@ -216,12 +208,10 @@ class LassoWithSGD(object):
     def train(cls, data, iterations=100, step=1.0, regParam=1.0,
               miniBatchFraction=1.0, initialWeights=None):
         """Train a Lasso regression model on the given data."""
-        sc = data.context
-
-        def train(jrdd, i):
-            return sc._jvm.PythonMLLibAPI().trainLassoModelWithSGD(
-                jrdd, iterations, step, regParam, miniBatchFraction, i)
-        return _regression_train_wrapper(sc, train, LassoModel, data, initialWeights)
+        def train(rdd, i):
+            return callMLlibFunc("trainLassoModelWithSGD", rdd, iterations, step, regParam,
+                                 miniBatchFraction, i)
+        return _regression_train_wrapper(train, LassoModel, data, initialWeights)


 class RidgeRegressionModel(LinearRegressionModelBase):
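
LassoWithSGD.train follows the same calling convention, and since the module
still imports SparseVector, sparse features work unchanged. Reusing the sc
from the example above:

    from pyspark.mllib.linalg import SparseVector
    from pyspark.mllib.regression import LabeledPoint, LassoWithSGD

    sparse_data = sc.parallelize([LabeledPoint(0.0, SparseVector(1, {0: 0.0})),
                                  LabeledPoint(1.0, SparseVector(1, {0: 1.0})),
                                  LabeledPoint(3.0, SparseVector(1, {0: 2.0})),
                                  LabeledPoint(2.0, SparseVector(1, {0: 3.0}))])
    lasso = LassoWithSGD.train(sparse_data, iterations=100, regParam=0.01)
    lasso.predict(SparseVector(1, {0: 1.0}))  # roughly 1.0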
@@ -263,17 +253,17 @@ class RidgeRegressionWithSGD(object):
     def train(cls, data, iterations=100, step=1.0, regParam=1.0,
               miniBatchFraction=1.0, initialWeights=None):
         """Train a ridge regression model on the given data."""
-        sc = data.context
-
-        def train(jrdd, i):
-            return sc._jvm.PythonMLLibAPI().trainRidgeModelWithSGD(
-                jrdd, iterations, step, regParam, miniBatchFraction, i)
+        def train(rdd, i):
+            return callMLlibFunc("trainRidgeModelWithSGD", rdd, iterations, step, regParam,
+                                 miniBatchFraction, i)

-        return _regression_train_wrapper(sc, train, RidgeRegressionModel, data, initialWeights)
+        return _regression_train_wrapper(train, RidgeRegressionModel,
+                                         data, initialWeights)


 def _test():
     import doctest
+    from pyspark import SparkContext
     globs = globals().copy()
     globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2)
     (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
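
Note that SparkContext is now imported inside _test() only: after this patch
the training paths reach the JVM through callMLlibFunc, so the module no longer
needs a top-level SparkContext import. The diff is truncated here; for context,
PySpark mllib modules conventionally end with the guard below so the doctests
run when the file is executed directly:

    if __name__ == "__main__":
        _test()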