about summary refs log tree commit diff
path: root/python/pyspark/mllib/classification.py
diff options
context:
space:
mode:
Diffstat (limited to 'python/pyspark/mllib/classification.py')
-rw-r--r-- python/pyspark/mllib/classification.py 13
1 files changed, 8 insertions, 5 deletions
diff --git a/python/pyspark/mllib/classification.py b/python/pyspark/mllib/classification.py
index 297a2bf37d..5d90dddb5d 100644
--- a/python/pyspark/mllib/classification.py
+++ b/python/pyspark/mllib/classification.py
@@ -62,6 +62,7 @@ class LogisticRegressionModel(LinearModel):
"""
def predict(self, x):
+ x = _convert_to_vector(x)
margin = self.weights.dot(x) + self._intercept
if margin > 0:
prob = 1 / (1 + exp(-margin))
@@ -79,7 +80,7 @@ class LogisticRegressionWithSGD(object):
"""
Train a logistic regression model on the given data.
- :param data: The training data.
+ :param data: The training data, an RDD of LabeledPoint.
:param iterations: The number of iterations (default: 100).
:param step: The step parameter used in SGD
(default: 1.0).
@@ -136,6 +137,7 @@ class SVMModel(LinearModel):
"""
def predict(self, x):
+ x = _convert_to_vector(x)
margin = self.weights.dot(x) + self.intercept
return 1 if margin >= 0 else 0
@@ -148,7 +150,7 @@ class SVMWithSGD(object):
"""
Train a support vector machine on the given data.
- :param data: The training data.
+ :param data: The training data, an RDD of LabeledPoint.
:param iterations: The number of iterations (default: 100).
:param step: The step parameter used in SGD
(default: 1.0).
@@ -233,11 +235,12 @@ class NaiveBayes(object):
classification. By making every vector a 0-1 vector, it can also be
used as Bernoulli NB (U{http://tinyurl.com/p7c96j6}).
- :param data: RDD of NumPy vectors, one per element, where the first
- coordinate is the label and the rest is the feature vector
- (e.g. a count vector).
+ :param data: RDD of LabeledPoint.
:param lambda_: The smoothing parameter
"""
+ first = data.first()
+ if not isinstance(first, LabeledPoint):
+ raise ValueError("`data` should be an RDD of LabeledPoint")
labels, pi, theta = callMLlibFunc("trainNaiveBayes", data, lambda_)
return NaiveBayesModel(labels.toArray(), pi.toArray(), numpy.array(theta))