author     Xiangrui Meng <meng@databricks.com>  2014-11-13 13:54:16 -0800
committer  Xiangrui Meng <meng@databricks.com>  2014-11-13 13:54:37 -0800
commit     5de97fc4384a8671f859cf8e2808324d0337216f (patch)
tree       f2b419e1a42f160ed9825708ccea7e20d964ff38 /mllib
parent     d993a44de2bf91e93c5ad3f84d35ff4e55f4b2fb (diff)
download   spark-5de97fc4384a8671f859cf8e2808324d0337216f.tar.gz
           spark-5de97fc4384a8671f859cf8e2808324d0337216f.tar.bz2
           spark-5de97fc4384a8671f859cf8e2808324d0337216f.zip
[SPARK-4372][MLLIB] Make LR and SVM's default parameters consistent in Scala and Python
The current default regParam is 1.0, and regType is claimed to be none in Python (but is actually l2), while in Scala regParam = 0.0 and regType is L2. We should make the default values consistent. This PR sets the default regType to L2 and regParam to 0.01. Note that the default regParam value in LIBLINEAR (and hence scikit-learn) is 1.0; however, we use the average loss instead of the total loss in our formulation, so regParam = 1.0 is definitely too heavy. In LinearRegression, we set regParam = 0.0 and regType = None, because we have separate classes for Lasso and Ridge, both of which use regParam = 0.01 as the default.

davies atalwalkar

Author: Xiangrui Meng <meng@databricks.com>

Closes #3232 from mengxr/SPARK-4372 and squashes the following commits:

9979837 [Xiangrui Meng] update Ridge/Lasso to use default regParam 0.01; cast input arguments
d3ba096 [Xiangrui Meng] change 'none' back to None
1909a6e [Xiangrui Meng] change default regParam to 0.01 and regType to L2 in LR and SVM

(cherry picked from commit 32218307edc6de2b08d5f7a0db6d566081d27197)
Signed-off-by: Xiangrui Meng <meng@databricks.com>
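For intuition on why regParam = 1.0 is too heavy under the average-loss formulation, compare the two objectives side by side. This is editorial context, not part of the patch; l (the per-example loss) and R (the regularizer) are generic symbols:

    % Total-loss formulation (LIBLINEAR, and hence scikit-learn):
    \min_w \; \sum_{i=1}^{n} \ell(w; x_i, y_i) + \lambda\, R(w)

    % Average-loss formulation (MLlib):
    \min_w \; \frac{1}{n} \sum_{i=1}^{n} \ell(w; x_i, y_i) + \lambda\, R(w)

Dividing the first objective by n shows that a total-loss lambda corresponds to lambda / n in the average-loss form; equivalently, regParam = 1.0 in MLlib weighs the regularizer as heavily as lambda = n would in LIBLINEAR, hence the much smaller 0.01 default.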
Diffstat (limited to 'mllib')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala                 34
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala         12
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala                        10
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala                           6
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala                 8
-rw-r--r--  mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala    28
6 files changed, 56 insertions, 42 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
index 70d7138e30..c8476a5370 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
@@ -28,22 +28,22 @@ import net.razorvine.pickle._
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.api.java.{JavaRDD, JavaSparkContext}
-import org.apache.spark.api.python.{PythonRDD, SerDeUtil}
+import org.apache.spark.api.python.SerDeUtil
import org.apache.spark.mllib.classification._
import org.apache.spark.mllib.clustering._
import org.apache.spark.mllib.feature._
-import org.apache.spark.mllib.optimization._
import org.apache.spark.mllib.linalg._
+import org.apache.spark.mllib.optimization._
import org.apache.spark.mllib.random.{RandomRDDs => RG}
import org.apache.spark.mllib.recommendation._
import org.apache.spark.mllib.regression._
-import org.apache.spark.mllib.tree.configuration.{Algo, Strategy}
-import org.apache.spark.mllib.tree.DecisionTree
-import org.apache.spark.mllib.tree.impurity._
-import org.apache.spark.mllib.tree.model.DecisionTreeModel
import org.apache.spark.mllib.stat.{MultivariateStatisticalSummary, Statistics}
import org.apache.spark.mllib.stat.correlation.CorrelationNames
import org.apache.spark.mllib.stat.test.ChiSqTestResult
+import org.apache.spark.mllib.tree.DecisionTree
+import org.apache.spark.mllib.tree.configuration.{Algo, Strategy}
+import org.apache.spark.mllib.tree.impurity._
+import org.apache.spark.mllib.tree.model.DecisionTreeModel
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel
@@ -103,9 +103,11 @@ class PythonMLLibAPI extends Serializable {
lrAlg.optimizer.setUpdater(new SquaredL2Updater)
} else if (regType == "l1") {
lrAlg.optimizer.setUpdater(new L1Updater)
- } else if (regType != "none") {
- throw new java.lang.IllegalArgumentException("Invalid value for 'regType' parameter."
- + " Can only be initialized using the following string values: [l1, l2, none].")
+ } else if (regType == null) {
+ lrAlg.optimizer.setUpdater(new SimpleUpdater)
+ } else {
+ throw new java.lang.IllegalArgumentException("Invalid value for 'regType' parameter."
+ + " Can only be initialized using the following string values: ['l1', 'l2', None].")
}
trainRegressionModel(
lrAlg,
@@ -180,9 +182,11 @@ class PythonMLLibAPI extends Serializable {
SVMAlg.optimizer.setUpdater(new SquaredL2Updater)
} else if (regType == "l1") {
SVMAlg.optimizer.setUpdater(new L1Updater)
- } else if (regType != "none") {
+ } else if (regType == null) {
+ SVMAlg.optimizer.setUpdater(new SimpleUpdater)
+ } else {
throw new java.lang.IllegalArgumentException("Invalid value for 'regType' parameter."
- + " Can only be initialized using the following string values: [l1, l2, none].")
+ + " Can only be initialized using the following string values: ['l1', 'l2', None].")
}
trainRegressionModel(
SVMAlg,
@@ -213,9 +217,11 @@ class PythonMLLibAPI extends Serializable {
LogRegAlg.optimizer.setUpdater(new SquaredL2Updater)
} else if (regType == "l1") {
LogRegAlg.optimizer.setUpdater(new L1Updater)
- } else if (regType != "none") {
+ } else if (regType == null) {
+ LogRegAlg.optimizer.setUpdater(new SimpleUpdater)
+ } else {
throw new java.lang.IllegalArgumentException("Invalid value for 'regType' parameter."
- + " Can only be initialized using the following string values: [l1, l2, none].")
+ + " Can only be initialized using the following string values: ['l1', 'l2', None].")
}
trainRegressionModel(
LogRegAlg,
@@ -250,7 +256,7 @@ class PythonMLLibAPI extends Serializable {
.setInitializationMode(initializationMode)
// Disable the uncached input warning because 'data' is a deliberately uncached MappedRDD.
.disableUncachedWarning()
- return kMeansAlg.run(data.rdd)
+ kMeansAlg.run(data.rdd)
}
/**
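The same regType dispatch now appears three times in this file (linear regression, SVM, logistic regression). Below is a minimal standalone sketch of the pattern the patch converges on; the helper name pickUpdater is hypothetical and does not appear in the patch:

    import org.apache.spark.mllib.optimization.{L1Updater, SimpleUpdater, SquaredL2Updater, Updater}

    // "l2" and "l1" select the matching regularizer, a null regType
    // (Python's None after conversion) selects no regularization, and
    // any other value is rejected, mirroring the patched branches above.
    def pickUpdater(regType: String): Updater = regType match {
      case "l2" => new SquaredL2Updater
      case "l1" => new L1Updater
      case null => new SimpleUpdater
      case _ => throw new IllegalArgumentException(
        "Invalid value for 'regType' parameter. Can only be initialized using"
          + " the following string values: ['l1', 'l2', None].")
    }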
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala
index 84d3c7cebd..18b95f1edc 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala
@@ -71,9 +71,10 @@ class LogisticRegressionModel (
}
/**
- * Train a classification model for Logistic Regression using Stochastic Gradient Descent.
- * NOTE: Labels used in Logistic Regression should be {0, 1}
- *
+ * Train a classification model for Logistic Regression using Stochastic Gradient Descent. By
+ * default L2 regularization is used, which can be changed via
+ * [[LogisticRegressionWithSGD.optimizer]].
+ * NOTE: Labels used in Logistic Regression should be {0, 1}.
* Using [[LogisticRegressionWithLBFGS]] is recommended over this.
*/
class LogisticRegressionWithSGD private (
@@ -93,9 +94,10 @@ class LogisticRegressionWithSGD private (
override protected val validators = List(DataValidators.binaryLabelValidator)
/**
- * Construct a LogisticRegression object with default parameters
+ * Construct a LogisticRegression object with default parameters: {stepSize: 1.0,
+ * numIterations: 100, regParam: 0.01, miniBatchFraction: 1.0}.
*/
- def this() = this(1.0, 100, 0.0, 1.0)
+ def this() = this(1.0, 100, 0.01, 1.0)
override protected def createModel(weights: Vector, intercept: Double) = {
new LogisticRegressionModel(weights, intercept)
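A usage sketch of the new LogisticRegressionWithSGD defaults from Scala; this is illustration only, nothing below is added by the patch:

    import org.apache.spark.mllib.classification.LogisticRegressionWithSGD
    import org.apache.spark.mllib.optimization.SimpleUpdater

    // After this patch, the no-arg constructor means stepSize = 1.0,
    // numIterations = 100, regParam = 0.01, with L2 regularization.
    val lr = new LogisticRegressionWithSGD()

    // Opting out of regularization via the optimizer, as the updated
    // scaladoc suggests.
    lr.optimizer
      .setUpdater(new SimpleUpdater)
      .setRegParam(0.0)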
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala
index 80f8a1b2f1..ab9515b2a6 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala
@@ -72,7 +72,8 @@ class SVMModel (
}
/**
- * Train a Support Vector Machine (SVM) using Stochastic Gradient Descent.
+ * Train a Support Vector Machine (SVM) using Stochastic Gradient Descent. By default L2
+ * regularization is used, which can be changed via [[SVMWithSGD.optimizer]].
* NOTE: Labels used in SVM should be {0, 1}.
*/
class SVMWithSGD private (
@@ -92,9 +93,10 @@ class SVMWithSGD private (
override protected val validators = List(DataValidators.binaryLabelValidator)
/**
- * Construct a SVM object with default parameters
+ * Construct a SVM object with default parameters: {stepSize: 1.0, numIterations: 100,
+ * regParam: 0.01, miniBatchFraction: 1.0}.
*/
- def this() = this(1.0, 100, 1.0, 1.0)
+ def this() = this(1.0, 100, 0.01, 1.0)
override protected def createModel(weights: Vector, intercept: Double) = {
new SVMModel(weights, intercept)
@@ -185,6 +187,6 @@ object SVMWithSGD {
* @return a SVMModel which has the weights and offset from training.
*/
def train(input: RDD[LabeledPoint], numIterations: Int): SVMModel = {
- train(input, numIterations, 1.0, 1.0, 1.0)
+ train(input, numIterations, 1.0, 0.01, 1.0)
}
}
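The convenience overload moves in step. A hypothetical call site (the data value is assumed, not from the patch) now implies regParam = 0.01 instead of 1.0:

    import org.apache.spark.mllib.classification.SVMWithSGD
    import org.apache.spark.mllib.regression.LabeledPoint
    import org.apache.spark.rdd.RDD

    // Equivalent to SVMWithSGD.train(data, 100, 1.0, 0.01, 1.0) after
    // this patch; before it, the implied regParam was 1.0.
    def fit(data: RDD[LabeledPoint]) = SVMWithSGD.train(data, 100)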
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala
index cb0d39e759..f9791c6571 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala
@@ -67,9 +67,9 @@ class LassoWithSGD private (
/**
* Construct a Lasso object with default parameters: {stepSize: 1.0, numIterations: 100,
- * regParam: 1.0, miniBatchFraction: 1.0}.
+ * regParam: 0.01, miniBatchFraction: 1.0}.
*/
- def this() = this(1.0, 100, 1.0, 1.0)
+ def this() = this(1.0, 100, 0.01, 1.0)
override protected def createModel(weights: Vector, intercept: Double) = {
new LassoModel(weights, intercept)
@@ -161,6 +161,6 @@ object LassoWithSGD {
def train(
input: RDD[LabeledPoint],
numIterations: Int): LassoModel = {
- train(input, numIterations, 1.0, 1.0, 1.0)
+ train(input, numIterations, 1.0, 0.01, 1.0)
}
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala
index a826deb695..c8cad773f5 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala
@@ -68,9 +68,9 @@ class RidgeRegressionWithSGD private (
/**
* Construct a RidgeRegression object with default parameters: {stepSize: 1.0, numIterations: 100,
- * regParam: 1.0, miniBatchFraction: 1.0}.
+ * regParam: 0.01, miniBatchFraction: 1.0}.
*/
- def this() = this(1.0, 100, 1.0, 1.0)
+ def this() = this(1.0, 100, 0.01, 1.0)
override protected def createModel(weights: Vector, intercept: Double) = {
new RidgeRegressionModel(weights, intercept)
@@ -143,7 +143,7 @@ object RidgeRegressionWithSGD {
numIterations: Int,
stepSize: Double,
regParam: Double): RidgeRegressionModel = {
- train(input, numIterations, stepSize, regParam, 1.0)
+ train(input, numIterations, stepSize, regParam, 0.01)
}
/**
@@ -158,6 +158,6 @@ object RidgeRegressionWithSGD {
def train(
input: RDD[LabeledPoint],
numIterations: Int): RidgeRegressionModel = {
- train(input, numIterations, 1.0, 1.0, 1.0)
+ train(input, numIterations, 1.0, 0.01, 1.0)
}
}
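Lasso and Ridge pick up the same 0.01 default through their two-argument overloads. One reading note on the four-argument RidgeRegressionWithSGD.train hunk above: the fifth positional parameter of the five-argument overload is miniBatchFraction, so the 1.0 -> 0.01 change in that hunk adjusts the batch fraction rather than regParam, which is already passed through. A hypothetical call site for the two-argument overloads (data is assumed to exist):

    import org.apache.spark.mllib.regression.{LabeledPoint, LassoWithSGD, RidgeRegressionWithSGD}
    import org.apache.spark.rdd.RDD

    // Both overloads now imply stepSize = 1.0, regParam = 0.01,
    // miniBatchFraction = 1.0.
    def fitBoth(data: RDD[LabeledPoint]) =
      (LassoWithSGD.train(data, 100), RidgeRegressionWithSGD.train(data, 100))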
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala
index 6c1c784a19..4e81299440 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala
@@ -80,13 +80,16 @@ class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext with M
val testRDD = sc.parallelize(testData, 2)
testRDD.cache()
val lr = new LogisticRegressionWithSGD().setIntercept(true)
- lr.optimizer.setStepSize(10.0).setNumIterations(20)
+ lr.optimizer
+ .setStepSize(10.0)
+ .setRegParam(0.0)
+ .setNumIterations(20)
val model = lr.run(testRDD)
// Test the weights
- assert(model.weights(0) ~== -1.52 relTol 0.01)
- assert(model.intercept ~== 2.00 relTol 0.01)
+ assert(model.weights(0) ~== B relTol 0.02)
+ assert(model.intercept ~== A relTol 0.02)
val validationData = LogisticRegressionSuite.generateLogisticInput(A, B, nPoints, 17)
val validationRDD = sc.parallelize(validationData, 2)
@@ -112,10 +115,8 @@ class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext with M
val model = lr.run(testRDD)
// Test the weights
- assert(model.weights(0) ~== -1.52 relTol 0.01)
- assert(model.intercept ~== 2.00 relTol 0.01)
- assert(model.weights(0) ~== model.weights(0) relTol 0.01)
- assert(model.intercept ~== model.intercept relTol 0.01)
+ assert(model.weights(0) ~== B relTol 0.02)
+ assert(model.intercept ~== A relTol 0.02)
val validationData = LogisticRegressionSuite.generateLogisticInput(A, B, nPoints, 17)
val validationRDD = sc.parallelize(validationData, 2)
@@ -141,13 +142,16 @@ class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext with M
// Use half as many iterations as the previous test.
val lr = new LogisticRegressionWithSGD().setIntercept(true)
- lr.optimizer.setStepSize(10.0).setNumIterations(10)
+ lr.optimizer
+ .setStepSize(10.0)
+ .setRegParam(0.0)
+ .setNumIterations(10)
val model = lr.run(testRDD, initialWeights)
// Test the weights
- assert(model.weights(0) ~== -1.50 relTol 0.01)
- assert(model.intercept ~== 1.97 relTol 0.01)
+ assert(model.weights(0) ~== B relTol 0.02)
+ assert(model.intercept ~== A relTol 0.02)
val validationData = LogisticRegressionSuite.generateLogisticInput(A, B, nPoints, 17)
val validationRDD = sc.parallelize(validationData, 2)
@@ -212,8 +216,8 @@ class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext with M
val model = lr.run(testRDD, initialWeights)
// Test the weights
- assert(model.weights(0) ~== -1.50 relTol 0.02)
- assert(model.intercept ~== 1.97 relTol 0.02)
+ assert(model.weights(0) ~== B relTol 0.02)
+ assert(model.intercept ~== A relTol 0.02)
val validationData = LogisticRegressionSuite.generateLogisticInput(A, B, nPoints, 17)
val validationRDD = sc.parallelize(validationData, 2)
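The test updates follow directly from the new default: with regParam now 0.01 rather than 0.0, the SGD tests pin regParam back to 0.0 so the fitted model can still recover the generating parameters, and the hard-coded expectations (2.00 and -1.52) are replaced by the generating values A and B at a slightly looser 2% tolerance. A condensed sketch of the updated setup, assuming the suite's A = 2.0 (intercept) and B = -1.5 (weight):

    import org.apache.spark.mllib.classification.LogisticRegressionWithSGD

    // Pin regParam to 0.0: the default is now 0.01, and any regularization
    // would bias the recovered weights away from the generating model.
    val lr = new LogisticRegressionWithSGD().setIntercept(true)
    lr.optimizer
      .setStepSize(10.0)
      .setRegParam(0.0)
      .setNumIterations(20)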