aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
authorDB Tsai <dbtsai@alpinenow.com>2014-08-11 19:49:29 -0700
committerXiangrui Meng <meng@databricks.com>2014-08-11 19:49:29 -0700
commit6fab941b65f0cb6c9b32e0f8290d76889cda6a87 (patch)
tree21a68ffda086cac8cc263a9493539bde5dc2fa61 /mllib
parent32638b5e74e02410831b391f555223f90c830498 (diff)
downloadspark-6fab941b65f0cb6c9b32e0f8290d76889cda6a87.tar.gz
spark-6fab941b65f0cb6c9b32e0f8290d76889cda6a87.tar.bz2
spark-6fab941b65f0cb6c9b32e0f8290d76889cda6a87.zip
[SPARK-2934][MLlib] Adding LogisticRegressionWithLBFGS Interface
for training with LBFGS Optimizer which will converge faster than SGD. Author: DB Tsai <dbtsai@alpinenow.com> Closes #1862 from dbtsai/dbtsai-lbfgs-lor and squashes the following commits: aa84b81 [DB Tsai] small change f852bcd [DB Tsai] Remove duplicate method f119fdc [DB Tsai] Formatting 97776aa [DB Tsai] address more feedback 85b4a91 [DB Tsai] address feedback 3cf50c2 [DB Tsai] LogisticRegressionWithLBFGS interface
Diffstat (limited to 'mllib')
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala51
-rw-r--r--mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala89
2 files changed, 136 insertions, 4 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala
index 2242329b79..31d474a20f 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala
@@ -101,7 +101,7 @@ class LogisticRegressionWithSGD private (
}
/**
- * Top-level methods for calling Logistic Regression.
+ * Top-level methods for calling Logistic Regression using Stochastic Gradient Descent.
* NOTE: Labels used in Logistic Regression should be {0, 1}
*/
object LogisticRegressionWithSGD {
@@ -188,3 +188,52 @@ object LogisticRegressionWithSGD {
train(input, numIterations, 1.0, 1.0)
}
}
+
+/**
+ * Train a classification model for Logistic Regression using Limited-memory BFGS.
+ * NOTE: Labels used in Logistic Regression should be {0, 1}
+ */
+class LogisticRegressionWithLBFGS private (
+ private var convergenceTol: Double,
+ private var maxNumIterations: Int,
+ private var regParam: Double)
+ extends GeneralizedLinearAlgorithm[LogisticRegressionModel] with Serializable {
+
+ /**
+ * Construct a LogisticRegression object with default parameters
+ */
+ def this() = this(1E-4, 100, 0.0)
+
+ private val gradient = new LogisticGradient()
+ private val updater = new SimpleUpdater()
+ // Have to return new LBFGS object every time since users can reset the parameters anytime.
+ override def optimizer = new LBFGS(gradient, updater)
+ .setNumCorrections(10)
+ .setConvergenceTol(convergenceTol)
+ .setMaxNumIterations(maxNumIterations)
+ .setRegParam(regParam)
+
+ override protected val validators = List(DataValidators.binaryLabelValidator)
+
+ /**
+ * Set the convergence tolerance of iterations for L-BFGS. Default 1E-4.
+ * Smaller value will lead to higher accuracy with the cost of more iterations.
+ */
+ def setConvergenceTol(convergenceTol: Double): this.type = {
+ this.convergenceTol = convergenceTol
+ this
+ }
+
+ /**
+ * Set the maximal number of iterations for L-BFGS. Default 100.
+ */
+ def setNumIterations(numIterations: Int): this.type = {
+ this.maxNumIterations = numIterations
+ this
+ }
+
+ override protected def createModel(weights: Vector, intercept: Double) = {
+ new LogisticRegressionModel(weights, intercept)
+ }
+
+}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala
index da7c633bbd..2289c6cdc1 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala
@@ -67,7 +67,7 @@ class LogisticRegressionSuite extends FunSuite with LocalSparkContext with Match
}
// Test if we can correctly learn A, B where Y = logistic(A + B*X)
- test("logistic regression") {
+ test("logistic regression with SGD") {
val nPoints = 10000
val A = 2.0
val B = -1.5
@@ -94,7 +94,36 @@ class LogisticRegressionSuite extends FunSuite with LocalSparkContext with Match
validatePrediction(validationData.map(row => model.predict(row.features)), validationData)
}
- test("logistic regression with initial weights") {
+ // Test if we can correctly learn A, B where Y = logistic(A + B*X)
+ test("logistic regression with LBFGS") {
+ val nPoints = 10000
+ val A = 2.0
+ val B = -1.5
+
+ val testData = LogisticRegressionSuite.generateLogisticInput(A, B, nPoints, 42)
+
+ val testRDD = sc.parallelize(testData, 2)
+ testRDD.cache()
+ val lr = new LogisticRegressionWithLBFGS().setIntercept(true)
+
+ val model = lr.run(testRDD)
+
+ // Test the weights
+ assert(model.weights(0) ~== -1.52 relTol 0.01)
+ assert(model.intercept ~== 2.00 relTol 0.01)
+ assert(model.weights(0) ~== model.weights(0) relTol 0.01)
+ assert(model.intercept ~== model.intercept relTol 0.01)
+
+ val validationData = LogisticRegressionSuite.generateLogisticInput(A, B, nPoints, 17)
+ val validationRDD = sc.parallelize(validationData, 2)
+ // Test prediction on RDD.
+ validatePrediction(model.predict(validationRDD.map(_.features)).collect(), validationData)
+
+ // Test prediction on Array.
+ validatePrediction(validationData.map(row => model.predict(row.features)), validationData)
+ }
+
+ test("logistic regression with initial weights with SGD") {
val nPoints = 10000
val A = 2.0
val B = -1.5
@@ -125,11 +154,42 @@ class LogisticRegressionSuite extends FunSuite with LocalSparkContext with Match
// Test prediction on Array.
validatePrediction(validationData.map(row => model.predict(row.features)), validationData)
}
+
+ test("logistic regression with initial weights with LBFGS") {
+ val nPoints = 10000
+ val A = 2.0
+ val B = -1.5
+
+ val testData = LogisticRegressionSuite.generateLogisticInput(A, B, nPoints, 42)
+
+ val initialB = -1.0
+ val initialWeights = Vectors.dense(initialB)
+
+ val testRDD = sc.parallelize(testData, 2)
+ testRDD.cache()
+
+ // Use half as many iterations as the previous test.
+ val lr = new LogisticRegressionWithLBFGS().setIntercept(true)
+
+ val model = lr.run(testRDD, initialWeights)
+
+ // Test the weights
+ assert(model.weights(0) ~== -1.50 relTol 0.02)
+ assert(model.intercept ~== 1.97 relTol 0.02)
+
+ val validationData = LogisticRegressionSuite.generateLogisticInput(A, B, nPoints, 17)
+ val validationRDD = sc.parallelize(validationData, 2)
+ // Test prediction on RDD.
+ validatePrediction(model.predict(validationRDD.map(_.features)).collect(), validationData)
+
+ // Test prediction on Array.
+ validatePrediction(validationData.map(row => model.predict(row.features)), validationData)
+ }
}
class LogisticRegressionClusterSuite extends FunSuite with LocalClusterSparkContext {
- test("task size should be small in both training and prediction") {
+ test("task size should be small in both training and prediction using SGD optimizer") {
val m = 4
val n = 200000
val points = sc.parallelize(0 until m, 2).mapPartitionsWithIndex { (idx, iter) =>
@@ -139,6 +199,29 @@ class LogisticRegressionClusterSuite extends FunSuite with LocalClusterSparkCont
// If we serialize data directly in the task closure, the size of the serialized task would be
// greater than 1MB and hence Spark would throw an error.
val model = LogisticRegressionWithSGD.train(points, 2)
+
val predictions = model.predict(points.map(_.features))
+
+ // Materialize the RDDs
+ predictions.count()
}
+
+ test("task size should be small in both training and prediction using LBFGS optimizer") {
+ val m = 4
+ val n = 200000
+ val points = sc.parallelize(0 until m, 2).mapPartitionsWithIndex { (idx, iter) =>
+ val random = new Random(idx)
+ iter.map(i => LabeledPoint(1.0, Vectors.dense(Array.fill(n)(random.nextDouble()))))
+ }.cache()
+ // If we serialize data directly in the task closure, the size of the serialized task would be
+ // greater than 1MB and hence Spark would throw an error.
+ val model =
+ (new LogisticRegressionWithLBFGS().setIntercept(true).setNumIterations(2)).run(points)
+
+ val predictions = model.predict(points.map(_.features))
+
+ // Materialize the RDDs
+ predictions.count()
+ }
+
}