author     Yanbo Liang <ybliang8@gmail.com>    2015-03-31 11:32:14 -0700
committer  Joseph K. Bradley <joseph@databricks.com>    2015-03-31 11:32:14 -0700
commit     b5bd75d90a761199c3f9cb583c1fe48c8fda7780 (patch)
tree       8defa75fba18d3fbb223bc2d780d21d33d00424b /python/pyspark/mllib/regression.py
parent     46de6c05e0619250346f0988e296849f8f93d2b1 (diff)
[SPARK-6255] [MLLIB] Support multiclass classification in Python API
Python API parity check for classification, with multiclass classification support; the major gaps to close on the Python side are:

```scala
LogisticRegressionWithLBFGS
    setNumClasses
    setValidateData

LogisticRegressionModel
    getThreshold
    numClasses
    numFeatures

SVMWithSGD
    setValidateData

SVMModel
    getThreshold
```

For users, the greatest benefit of this PR is that multiclass classification is now supported by the Python API: users can train a multiclass classification model and use it for prediction in pyspark.

Author: Yanbo Liang <ybliang8@gmail.com>

Closes #5137 from yanboliang/spark-6255 and squashes the following commits:

0bd531e [Yanbo Liang] address comments
444d5e2 [Yanbo Liang] LogisticRegressionModel.predict() optimization
fc7990b [Yanbo Liang] address comments
b0d9c63 [Yanbo Liang] Support multinomial LR model predict in Python API
ded847c [Yanbo Liang] Python API parity check for classification (support multiclass classification)
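As a quick illustration of the new capability, here is a minimal sketch of training and using a multiclass model from pyspark. It is not part of the patch: it assumes a live SparkContext named `sc`, and the three-class toy dataset is made up for illustration.

```python
from pyspark.mllib.classification import LogisticRegressionWithLBFGS
from pyspark.mllib.regression import LabeledPoint

# Made-up three-class toy dataset; `sc` is an existing SparkContext.
data = sc.parallelize([
    LabeledPoint(0.0, [0.0, 1.0]),
    LabeledPoint(1.0, [1.0, 0.0]),
    LabeledPoint(2.0, [1.0, 1.0]),
])

# Passing numClasses > 2 trains a multinomial model (the new capability).
model = LogisticRegressionWithLBFGS.train(data, iterations=50, numClasses=3)

print(model.numClasses)          # 3; attribute exposed by this PR
print(model.numFeatures)         # 2; attribute exposed by this PR
print(model.predict([1.0, 1.0])) # predicted class label
```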
Diffstat (limited to 'python/pyspark/mllib/regression.py')
-rw-r--r--  python/pyspark/mllib/regression.py  10
1 file changed, 8 insertions, 2 deletions
diff --git a/python/pyspark/mllib/regression.py b/python/pyspark/mllib/regression.py
index 209f1ee473..cd7310a64f 100644
--- a/python/pyspark/mllib/regression.py
+++ b/python/pyspark/mllib/regression.py
@@ -167,13 +167,19 @@ class LinearRegressionModel(LinearRegressionModelBase):
 # return the result of a call to the appropriate JVM stub.
 # _regression_train_wrapper is responsible for setup and error checking.
 def _regression_train_wrapper(train_func, modelClass, data, initial_weights):
+    from pyspark.mllib.classification import LogisticRegressionModel
     first = data.first()
     if not isinstance(first, LabeledPoint):
         raise ValueError("data should be an RDD of LabeledPoint, but got %s" % first)
     if initial_weights is None:
         initial_weights = [0.0] * len(data.first().features)
-    weights, intercept = train_func(data, _convert_to_vector(initial_weights))
-    return modelClass(weights, intercept)
+    if (modelClass == LogisticRegressionModel):
+        weights, intercept, numFeatures, numClasses = train_func(
+            data, _convert_to_vector(initial_weights))
+        return modelClass(weights, intercept, numFeatures, numClasses)
+    else:
+        weights, intercept = train_func(data, _convert_to_vector(initial_weights))
+        return modelClass(weights, intercept)


 class LinearRegressionWithSGD(object):
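The special case exists because the JVM trainer for logistic regression now returns four values (weights, intercept, numFeatures, numClasses), while every other trainer still returns two; the import is local to the function, presumably to avoid a circular dependency between the regression and classification modules.

On the prediction side, a multinomial model of this kind stores one weight row per non-pivot class and predicts the class with the largest linear margin, with class 0 as the pivot. Below is a simplified sketch of that decision rule, assuming numpy; it is an illustration, not the exact pyspark implementation (which also handles the intercept layout and sparse vectors).

```python
import numpy as np

def predict_multinomial(x, weights, intercepts):
    """Sketch of multinomial LR prediction with class 0 as the pivot.

    weights: (numClasses - 1, numFeatures) array, one row per non-pivot class.
    intercepts: (numClasses - 1,) array of per-class intercepts.
    """
    margins = weights.dot(x) + intercepts  # linear margin of each non-pivot class
    best = int(np.argmax(margins))         # strongest non-pivot class (0-based)
    if margins[best] <= 0.0:
        return 0                           # pivot class 0 has implicit margin 0
    return best + 1                        # shift index past the pivot class
```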