From b5bd75d90a761199c3f9cb583c1fe48c8fda7780 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Tue, 31 Mar 2015 11:32:14 -0700 Subject: [SPARK-6255] [MLLIB] Support multiclass classification in Python API Python API parity check for classification and multiclass classification support, major disparities need to be added for Python: ```scala LogisticRegressionWithLBFGS setNumClasses setValidateData LogisticRegressionModel getThreshold numClasses numFeatures SVMWithSGD setValidateData SVMModel getThreshold ``` For users the greatest benefit in this PR is multiclass classification was supported by Python API. Users can train multiclass classification model and use it to predict in pyspark. Author: Yanbo Liang Closes #5137 from yanboliang/spark-6255 and squashes the following commits: 0bd531e [Yanbo Liang] address comments 444d5e2 [Yanbo Liang] LogisticRegressionModel.predict() optimization fc7990b [Yanbo Liang] address comments b0d9c63 [Yanbo Liang] Support Mulinomial LR model predict in Python API ded847c [Yanbo Liang] Python API parity check for classification (support multiclass classification) --- python/pyspark/mllib/regression.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) (limited to 'python/pyspark/mllib/regression.py') diff --git a/python/pyspark/mllib/regression.py b/python/pyspark/mllib/regression.py index 209f1ee473..cd7310a64f 100644 --- a/python/pyspark/mllib/regression.py +++ b/python/pyspark/mllib/regression.py @@ -167,13 +167,19 @@ class LinearRegressionModel(LinearRegressionModelBase): # return the result of a call to the appropriate JVM stub. # _regression_train_wrapper is responsible for setup and error checking. def _regression_train_wrapper(train_func, modelClass, data, initial_weights): + from pyspark.mllib.classification import LogisticRegressionModel first = data.first() if not isinstance(first, LabeledPoint): raise ValueError("data should be an RDD of LabeledPoint, but got %s" % first) if initial_weights is None: initial_weights = [0.0] * len(data.first().features) - weights, intercept = train_func(data, _convert_to_vector(initial_weights)) - return modelClass(weights, intercept) + if (modelClass == LogisticRegressionModel): + weights, intercept, numFeatures, numClasses = train_func( + data, _convert_to_vector(initial_weights)) + return modelClass(weights, intercept, numFeatures, numClasses) + else: + weights, intercept = train_func(data, _convert_to_vector(initial_weights)) + return modelClass(weights, intercept) class LinearRegressionWithSGD(object): -- cgit v1.2.3