diff options
Diffstat (limited to 'python/pyspark/mllib/classification.py')
-rw-r--r-- | python/pyspark/mllib/classification.py | 13 |
1 files changed, 8 insertions, 5 deletions
diff --git a/python/pyspark/mllib/classification.py b/python/pyspark/mllib/classification.py index 297a2bf37d..5d90dddb5d 100644 --- a/python/pyspark/mllib/classification.py +++ b/python/pyspark/mllib/classification.py @@ -62,6 +62,7 @@ class LogisticRegressionModel(LinearModel): """ def predict(self, x): + x = _convert_to_vector(x) margin = self.weights.dot(x) + self._intercept if margin > 0: prob = 1 / (1 + exp(-margin)) @@ -79,7 +80,7 @@ class LogisticRegressionWithSGD(object): """ Train a logistic regression model on the given data. - :param data: The training data. + :param data: The training data, an RDD of LabeledPoint. :param iterations: The number of iterations (default: 100). :param step: The step parameter used in SGD (default: 1.0). @@ -136,6 +137,7 @@ class SVMModel(LinearModel): """ def predict(self, x): + x = _convert_to_vector(x) margin = self.weights.dot(x) + self.intercept return 1 if margin >= 0 else 0 @@ -148,7 +150,7 @@ class SVMWithSGD(object): """ Train a support vector machine on the given data. - :param data: The training data. + :param data: The training data, an RDD of LabeledPoint. :param iterations: The number of iterations (default: 100). :param step: The step parameter used in SGD (default: 1.0). @@ -233,11 +235,12 @@ class NaiveBayes(object): classification. By making every vector a 0-1 vector, it can also be used as Bernoulli NB (U{http://tinyurl.com/p7c96j6}). - :param data: RDD of NumPy vectors, one per element, where the first - coordinate is the label and the rest is the feature vector - (e.g. a count vector). + :param data: RDD of LabeledPoint. :param lambda_: The smoothing parameter """ + first = data.first() + if not isinstance(first, LabeledPoint): + raise ValueError("`data` should be an RDD of LabeledPoint") labels, pi, theta = callMLlibFunc("trainNaiveBayes", data, lambda_) return NaiveBayesModel(labels.toArray(), pi.toArray(), numpy.array(theta)) |