aboutsummaryrefslogtreecommitdiff
path: root/mllib/src
diff options
context:
space:
mode:
authorYanbo Liang <ybliang8@gmail.com>2015-03-31 11:32:14 -0700
committerJoseph K. Bradley <joseph@databricks.com>2015-03-31 11:32:14 -0700
commitb5bd75d90a761199c3f9cb583c1fe48c8fda7780 (patch)
tree8defa75fba18d3fbb223bc2d780d21d33d00424b /mllib/src
parent46de6c05e0619250346f0988e296849f8f93d2b1 (diff)
downloadspark-b5bd75d90a761199c3f9cb583c1fe48c8fda7780.tar.gz
spark-b5bd75d90a761199c3f9cb583c1fe48c8fda7780.tar.bz2
spark-b5bd75d90a761199c3f9cb583c1fe48c8fda7780.zip
[SPARK-6255] [MLLIB] Support multiclass classification in Python API
Python API parity check for classification and multiclass classification support, major disparities need to be added for Python: ```scala LogisticRegressionWithLBFGS setNumClasses setValidateData LogisticRegressionModel getThreshold numClasses numFeatures SVMWithSGD setValidateData SVMModel getThreshold ``` For users the greatest benefit in this PR is multiclass classification was supported by Python API. Users can train multiclass classification model and use it to predict in pyspark. Author: Yanbo Liang <ybliang8@gmail.com> Closes #5137 from yanboliang/spark-6255 and squashes the following commits: 0bd531e [Yanbo Liang] address comments 444d5e2 [Yanbo Liang] LogisticRegressionModel.predict() optimization fc7990b [Yanbo Liang] address comments b0d9c63 [Yanbo Liang] Support Mulinomial LR model predict in Python API ded847c [Yanbo Liang] Python API parity check for classification (support multiclass classification)
Diffstat (limited to 'mllib/src')
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala22
1 files changed, 18 insertions, 4 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
index 22fa684fd2..662ec5fbed 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
@@ -77,7 +77,13 @@ private[python] class PythonMLLibAPI extends Serializable {
initialWeights: Vector): JList[Object] = {
try {
val model = learner.run(data.rdd.persist(StorageLevel.MEMORY_AND_DISK), initialWeights)
- List(model.weights, model.intercept).map(_.asInstanceOf[Object]).asJava
+ if (model.isInstanceOf[LogisticRegressionModel]) {
+ val lrModel = model.asInstanceOf[LogisticRegressionModel]
+ List(lrModel.weights, lrModel.intercept, lrModel.numFeatures, lrModel.numClasses)
+ .map(_.asInstanceOf[Object]).asJava
+ } else {
+ List(model.weights, model.intercept).map(_.asInstanceOf[Object]).asJava
+ }
} finally {
data.rdd.unpersist(blocking = false)
}
@@ -190,9 +196,11 @@ private[python] class PythonMLLibAPI extends Serializable {
miniBatchFraction: Double,
initialWeights: Vector,
regType: String,
- intercept: Boolean): JList[Object] = {
+ intercept: Boolean,
+ validateData: Boolean): JList[Object] = {
val SVMAlg = new SVMWithSGD()
SVMAlg.setIntercept(intercept)
+ .setValidateData(validateData)
SVMAlg.optimizer
.setNumIterations(numIterations)
.setRegParam(regParam)
@@ -216,9 +224,11 @@ private[python] class PythonMLLibAPI extends Serializable {
initialWeights: Vector,
regParam: Double,
regType: String,
- intercept: Boolean): JList[Object] = {
+ intercept: Boolean,
+ validateData: Boolean): JList[Object] = {
val LogRegAlg = new LogisticRegressionWithSGD()
LogRegAlg.setIntercept(intercept)
+ .setValidateData(validateData)
LogRegAlg.optimizer
.setNumIterations(numIterations)
.setRegParam(regParam)
@@ -242,9 +252,13 @@ private[python] class PythonMLLibAPI extends Serializable {
regType: String,
intercept: Boolean,
corrections: Int,
- tolerance: Double): JList[Object] = {
+ tolerance: Double,
+ validateData: Boolean,
+ numClasses: Int): JList[Object] = {
val LogRegAlg = new LogisticRegressionWithLBFGS()
LogRegAlg.setIntercept(intercept)
+ .setValidateData(validateData)
+ .setNumClasses(numClasses)
LogRegAlg.optimizer
.setNumIterations(numIterations)
.setRegParam(regParam)