aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
author: Yanbo Liang <ybliang8@gmail.com>, 2015-09-14 12:08:52 -0700
committer: Xiangrui Meng <meng@databricks.com>, 2015-09-14 12:08:52 -0700
commit: ce6f3f163bc667cb5da9ab4331c8bad10cc0d701 (patch)
tree: b45c7d15811bef5d745ebe2b33b496a91ae9ab21 /mllib
parent: cf2821ef5fd9965eb6256e8e8b3f1e00c0788098 (diff)
download: spark-ce6f3f163bc667cb5da9ab4331c8bad10cc0d701.tar.gz
spark-ce6f3f163bc667cb5da9ab4331c8bad10cc0d701.tar.bz2
spark-ce6f3f163bc667cb5da9ab4331c8bad10cc0d701.zip
[SPARK-10194] [MLLIB] [PYSPARK] SGD algorithms need convergenceTol parameter in Python
[SPARK-3382](https://issues.apache.org/jira/browse/SPARK-3382) added a `convergenceTol` parameter for GradientDescent-based methods in Scala. We need that parameter in Python; otherwise, Python users will not be able to adjust that behavior (or even reproduce behavior from previous releases, since the default changed).

Author: Yanbo Liang <ybliang8@gmail.com>

Closes #8457 from yanboliang/spark-10194.
Diffstat (limited to 'mllib')
-rw-r--r-- mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala | 20
1 file changed, 15 insertions, 5 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
index f585aacd45..69ce7f5070 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
@@ -132,7 +132,8 @@ private[python] class PythonMLLibAPI extends Serializable {
regParam: Double,
regType: String,
intercept: Boolean,
- validateData: Boolean): JList[Object] = {
+ validateData: Boolean,
+ convergenceTol: Double): JList[Object] = {
val lrAlg = new LinearRegressionWithSGD()
lrAlg.setIntercept(intercept)
.setValidateData(validateData)
@@ -141,6 +142,7 @@ private[python] class PythonMLLibAPI extends Serializable {
.setRegParam(regParam)
.setStepSize(stepSize)
.setMiniBatchFraction(miniBatchFraction)
+ .setConvergenceTol(convergenceTol)
lrAlg.optimizer.setUpdater(getUpdaterFromString(regType))
trainRegressionModel(
lrAlg,
@@ -159,7 +161,8 @@ private[python] class PythonMLLibAPI extends Serializable {
miniBatchFraction: Double,
initialWeights: Vector,
intercept: Boolean,
- validateData: Boolean): JList[Object] = {
+ validateData: Boolean,
+ convergenceTol: Double): JList[Object] = {
val lassoAlg = new LassoWithSGD()
lassoAlg.setIntercept(intercept)
.setValidateData(validateData)
@@ -168,6 +171,7 @@ private[python] class PythonMLLibAPI extends Serializable {
.setRegParam(regParam)
.setStepSize(stepSize)
.setMiniBatchFraction(miniBatchFraction)
+ .setConvergenceTol(convergenceTol)
trainRegressionModel(
lassoAlg,
data,
@@ -185,7 +189,8 @@ private[python] class PythonMLLibAPI extends Serializable {
miniBatchFraction: Double,
initialWeights: Vector,
intercept: Boolean,
- validateData: Boolean): JList[Object] = {
+ validateData: Boolean,
+ convergenceTol: Double): JList[Object] = {
val ridgeAlg = new RidgeRegressionWithSGD()
ridgeAlg.setIntercept(intercept)
.setValidateData(validateData)
@@ -194,6 +199,7 @@ private[python] class PythonMLLibAPI extends Serializable {
.setRegParam(regParam)
.setStepSize(stepSize)
.setMiniBatchFraction(miniBatchFraction)
+ .setConvergenceTol(convergenceTol)
trainRegressionModel(
ridgeAlg,
data,
@@ -212,7 +218,8 @@ private[python] class PythonMLLibAPI extends Serializable {
initialWeights: Vector,
regType: String,
intercept: Boolean,
- validateData: Boolean): JList[Object] = {
+ validateData: Boolean,
+ convergenceTol: Double): JList[Object] = {
val SVMAlg = new SVMWithSGD()
SVMAlg.setIntercept(intercept)
.setValidateData(validateData)
@@ -221,6 +228,7 @@ private[python] class PythonMLLibAPI extends Serializable {
.setRegParam(regParam)
.setStepSize(stepSize)
.setMiniBatchFraction(miniBatchFraction)
+ .setConvergenceTol(convergenceTol)
SVMAlg.optimizer.setUpdater(getUpdaterFromString(regType))
trainRegressionModel(
SVMAlg,
@@ -240,7 +248,8 @@ private[python] class PythonMLLibAPI extends Serializable {
regParam: Double,
regType: String,
intercept: Boolean,
- validateData: Boolean): JList[Object] = {
+ validateData: Boolean,
+ convergenceTol: Double): JList[Object] = {
val LogRegAlg = new LogisticRegressionWithSGD()
LogRegAlg.setIntercept(intercept)
.setValidateData(validateData)
@@ -249,6 +258,7 @@ private[python] class PythonMLLibAPI extends Serializable {
.setRegParam(regParam)
.setStepSize(stepSize)
.setMiniBatchFraction(miniBatchFraction)
+ .setConvergenceTol(convergenceTol)
LogRegAlg.optimizer.setUpdater(getUpdaterFromString(regType))
trainRegressionModel(
LogRegAlg,