aboutsummaryrefslogtreecommitdiff
path: root/mllib/src
diff options
context:
space:
mode:
Diffstat (limited to 'mllib/src')
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala22
1 files changed, 18 insertions, 4 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
index 22fa684fd2..662ec5fbed 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
@@ -77,7 +77,13 @@ private[python] class PythonMLLibAPI extends Serializable {
initialWeights: Vector): JList[Object] = {
try {
val model = learner.run(data.rdd.persist(StorageLevel.MEMORY_AND_DISK), initialWeights)
- List(model.weights, model.intercept).map(_.asInstanceOf[Object]).asJava
+ if (model.isInstanceOf[LogisticRegressionModel]) {
+ val lrModel = model.asInstanceOf[LogisticRegressionModel]
+ List(lrModel.weights, lrModel.intercept, lrModel.numFeatures, lrModel.numClasses)
+ .map(_.asInstanceOf[Object]).asJava
+ } else {
+ List(model.weights, model.intercept).map(_.asInstanceOf[Object]).asJava
+ }
} finally {
data.rdd.unpersist(blocking = false)
}
@@ -190,9 +196,11 @@ private[python] class PythonMLLibAPI extends Serializable {
miniBatchFraction: Double,
initialWeights: Vector,
regType: String,
- intercept: Boolean): JList[Object] = {
+ intercept: Boolean,
+ validateData: Boolean): JList[Object] = {
val SVMAlg = new SVMWithSGD()
SVMAlg.setIntercept(intercept)
+ .setValidateData(validateData)
SVMAlg.optimizer
.setNumIterations(numIterations)
.setRegParam(regParam)
@@ -216,9 +224,11 @@ private[python] class PythonMLLibAPI extends Serializable {
initialWeights: Vector,
regParam: Double,
regType: String,
- intercept: Boolean): JList[Object] = {
+ intercept: Boolean,
+ validateData: Boolean): JList[Object] = {
val LogRegAlg = new LogisticRegressionWithSGD()
LogRegAlg.setIntercept(intercept)
+ .setValidateData(validateData)
LogRegAlg.optimizer
.setNumIterations(numIterations)
.setRegParam(regParam)
@@ -242,9 +252,13 @@ private[python] class PythonMLLibAPI extends Serializable {
regType: String,
intercept: Boolean,
corrections: Int,
- tolerance: Double): JList[Object] = {
+ tolerance: Double,
+ validateData: Boolean,
+ numClasses: Int): JList[Object] = {
val LogRegAlg = new LogisticRegressionWithLBFGS()
LogRegAlg.setIntercept(intercept)
+ .setValidateData(validateData)
+ .setNumClasses(numClasses)
LogRegAlg.optimizer
.setNumIterations(numIterations)
.setRegParam(regParam)