author    Yanbo Liang <ybliang8@gmail.com>        2015-03-31 11:32:14 -0700
committer Joseph K. Bradley <joseph@databricks.com>  2015-03-31 11:32:14 -0700
commit    b5bd75d90a761199c3f9cb583c1fe48c8fda7780 (patch)
tree      8defa75fba18d3fbb223bc2d780d21d33d00424b /python/pyspark/mllib/regression.py
parent    46de6c05e0619250346f0988e296849f8f93d2b1 (diff)
[SPARK-6255] [MLLIB] Support multiclass classification in Python API
Python API parity check for classification and multiclass classification support. The following major disparities need to be added to the Python API:
```scala
LogisticRegressionWithLBFGS
setNumClasses
setValidateData
LogisticRegressionModel
getThreshold
numClasses
numFeatures
SVMWithSGD
setValidateData
SVMModel
getThreshold
```
For users, the greatest benefit of this PR is that multiclass classification is now supported by the Python API: users can train a multiclass classification model and use it for prediction in pyspark.
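The multiclass prediction this PR enables follows the multinomial logistic regression rule: class 0 is the reference class with an implicit margin of 0, and the model picks the class whose margin is largest. A minimal self-contained sketch of that rule (ignoring the intercept term, which Spark folds into the weight vector for the multinomial case; `multinomial_predict` is an illustrative name, not a pyspark API):

```python
def multinomial_predict(features, weights, num_classes):
    """Return the predicted class label for a multinomial LR model.

    `weights` holds (num_classes - 1) stacked coefficient blocks, one
    per non-reference class; class 0 has an implicit margin of 0.0.
    """
    num_features = len(features)
    best_class, max_margin = 0, 0.0
    for i in range(num_classes - 1):
        margin = sum(features[j] * weights[i * num_features + j]
                     for j in range(num_features))
        if margin > max_margin:
            max_margin = margin
            best_class = i + 1  # block i corresponds to class i + 1
    return best_class
```

If no class beats the reference margin of 0.0, class 0 is returned.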
Author: Yanbo Liang <ybliang8@gmail.com>
Closes #5137 from yanboliang/spark-6255 and squashes the following commits:
0bd531e [Yanbo Liang] address comments
444d5e2 [Yanbo Liang] LogisticRegressionModel.predict() optimization
fc7990b [Yanbo Liang] address comments
b0d9c63 [Yanbo Liang] Support multinomial LR model predict in Python API
ded847c [Yanbo Liang] Python API parity check for classification (support multiclass classification)
Diffstat (limited to 'python/pyspark/mllib/regression.py')
-rw-r--r-- | python/pyspark/mllib/regression.py | 10 |
1 file changed, 8 insertions(+), 2 deletions(-)
```diff
diff --git a/python/pyspark/mllib/regression.py b/python/pyspark/mllib/regression.py
index 209f1ee473..cd7310a64f 100644
--- a/python/pyspark/mllib/regression.py
+++ b/python/pyspark/mllib/regression.py
@@ -167,13 +167,19 @@ class LinearRegressionModel(LinearRegressionModelBase):
 # return the result of a call to the appropriate JVM stub.
 # _regression_train_wrapper is responsible for setup and error checking.
 def _regression_train_wrapper(train_func, modelClass, data, initial_weights):
+    from pyspark.mllib.classification import LogisticRegressionModel
     first = data.first()
     if not isinstance(first, LabeledPoint):
         raise ValueError("data should be an RDD of LabeledPoint, but got %s" % first)
     if initial_weights is None:
         initial_weights = [0.0] * len(data.first().features)
-    weights, intercept = train_func(data, _convert_to_vector(initial_weights))
-    return modelClass(weights, intercept)
+    if (modelClass == LogisticRegressionModel):
+        weights, intercept, numFeatures, numClasses = train_func(
+            data, _convert_to_vector(initial_weights))
+        return modelClass(weights, intercept, numFeatures, numClasses)
+    else:
+        weights, intercept = train_func(data, _convert_to_vector(initial_weights))
+        return modelClass(weights, intercept)


 class LinearRegressionWithSGD(object):
```
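The essence of the patch is a dispatch on the model class: logistic regression training now returns a 4-tuple (weights, intercept, numFeatures, numClasses), while all other models keep the 2-tuple shape. A stripped-down sketch of that dispatch, using stand-in model classes rather than the real pyspark.mllib ones:

```python
class LinearModel:
    """Stand-in for pyspark's LinearModel: weights plus intercept."""
    def __init__(self, weights, intercept):
        self.weights, self.intercept = weights, intercept

class LogisticRegressionModel(LinearModel):
    """Stand-in: after this PR the model also tracks feature/class counts."""
    def __init__(self, weights, intercept, numFeatures, numClasses):
        super().__init__(weights, intercept)
        self.numFeatures, self.numClasses = numFeatures, numClasses

def regression_train_wrapper(train_func, model_class, data, initial_weights):
    result = train_func(data, initial_weights)
    if model_class is LogisticRegressionModel:
        # Logistic regression training returns two extra fields.
        weights, intercept, num_features, num_classes = result
        return model_class(weights, intercept, num_features, num_classes)
    weights, intercept = result
    return model_class(weights, intercept)
```

Branching inside the shared wrapper keeps every other caller (linear regression, SVM, etc.) unchanged while letting logistic regression carry the metadata that multiclass prediction needs.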