diff options
Diffstat (limited to 'python/pyspark/mllib/regression.py')
-rw-r--r-- | python/pyspark/mllib/regression.py | 15 |
1 files changed, 10 insertions, 5 deletions
diff --git a/python/pyspark/mllib/regression.py b/python/pyspark/mllib/regression.py index 43c1a2fc10..66e25a48df 100644 --- a/python/pyspark/mllib/regression.py +++ b/python/pyspark/mllib/regression.py @@ -36,7 +36,7 @@ class LabeledPoint(object): """ def __init__(self, label, features): - self.label = label + self.label = float(label) self.features = _convert_to_vector(features) def __reduce__(self): @@ -46,7 +46,7 @@ class LabeledPoint(object): return "(" + ",".join((str(self.label), str(self.features))) + ")" def __repr__(self): - return "LabeledPoint(" + ",".join((repr(self.label), repr(self.features))) + ")" + return "LabeledPoint(%s, %s)" % (self.label, self.features) class LinearModel(object): @@ -55,7 +55,7 @@ class LinearModel(object): def __init__(self, weights, intercept): self._coeff = _convert_to_vector(weights) - self._intercept = intercept + self._intercept = float(intercept) @property def weights(self): @@ -66,7 +66,7 @@ class LinearModel(object): return self._intercept def __repr__(self): - return "(weights=%s, intercept=%s)" % (self._coeff, self._intercept) + return "(weights=%s, intercept=%r)" % (self._coeff, self._intercept) class LinearRegressionModelBase(LinearModel): @@ -85,6 +85,7 @@ class LinearRegressionModelBase(LinearModel): Predict the value of the dependent variable given a vector x containing values for the independent variables. """ + x = _convert_to_vector(x) return self.weights.dot(x) + self.intercept @@ -124,6 +125,9 @@ class LinearRegressionModel(LinearRegressionModelBase): # return the result of a call to the appropriate JVM stub. # _regression_train_wrapper is responsible for setup and error checking. def _regression_train_wrapper(train_func, modelClass, data, initial_weights): + first = data.first() + if not isinstance(first, LabeledPoint): + raise ValueError("data should be an RDD of LabeledPoint, but got %s" % first) initial_weights = initial_weights or [0.0] * len(data.first().features) weights, intercept = train_func(_to_java_object_rdd(data, cache=True), _convert_to_vector(initial_weights)) @@ -264,7 +268,8 @@ class RidgeRegressionWithSGD(object): def _test(): import doctest from pyspark import SparkContext - globs = globals().copy() + import pyspark.mllib.regression + globs = pyspark.mllib.regression.__dict__.copy() globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2) (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) globs['sc'].stop() |