aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark/mllib/regression.py
diff options
context:
space:
mode:
Diffstat (limited to 'python/pyspark/mllib/regression.py')
-rw-r--r--python/pyspark/mllib/regression.py15
1 files changed, 10 insertions, 5 deletions
diff --git a/python/pyspark/mllib/regression.py b/python/pyspark/mllib/regression.py
index 43c1a2fc10..66e25a48df 100644
--- a/python/pyspark/mllib/regression.py
+++ b/python/pyspark/mllib/regression.py
@@ -36,7 +36,7 @@ class LabeledPoint(object):
"""
def __init__(self, label, features):
- self.label = label
+ self.label = float(label)
self.features = _convert_to_vector(features)
def __reduce__(self):
@@ -46,7 +46,7 @@ class LabeledPoint(object):
return "(" + ",".join((str(self.label), str(self.features))) + ")"
def __repr__(self):
- return "LabeledPoint(" + ",".join((repr(self.label), repr(self.features))) + ")"
+ return "LabeledPoint(%s, %s)" % (self.label, self.features)
class LinearModel(object):
@@ -55,7 +55,7 @@ class LinearModel(object):
def __init__(self, weights, intercept):
self._coeff = _convert_to_vector(weights)
- self._intercept = intercept
+ self._intercept = float(intercept)
@property
def weights(self):
@@ -66,7 +66,7 @@ class LinearModel(object):
return self._intercept
def __repr__(self):
- return "(weights=%s, intercept=%s)" % (self._coeff, self._intercept)
+ return "(weights=%s, intercept=%r)" % (self._coeff, self._intercept)
class LinearRegressionModelBase(LinearModel):
@@ -85,6 +85,7 @@ class LinearRegressionModelBase(LinearModel):
Predict the value of the dependent variable given a vector x
containing values for the independent variables.
"""
+ x = _convert_to_vector(x)
return self.weights.dot(x) + self.intercept
@@ -124,6 +125,9 @@ class LinearRegressionModel(LinearRegressionModelBase):
# return the result of a call to the appropriate JVM stub.
# _regression_train_wrapper is responsible for setup and error checking.
def _regression_train_wrapper(train_func, modelClass, data, initial_weights):
+ first = data.first()
+ if not isinstance(first, LabeledPoint):
+ raise ValueError("data should be an RDD of LabeledPoint, but got %s" % first)
initial_weights = initial_weights or [0.0] * len(data.first().features)
weights, intercept = train_func(_to_java_object_rdd(data, cache=True),
_convert_to_vector(initial_weights))
@@ -264,7 +268,8 @@ class RidgeRegressionWithSGD(object):
def _test():
import doctest
from pyspark import SparkContext
- globs = globals().copy()
+ import pyspark.mllib.regression
+ globs = pyspark.mllib.regression.__dict__.copy()
globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2)
(failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
globs['sc'].stop()