aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala2
-rw-r--r--mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala44
2 files changed, 42 insertions, 4 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala
index 486bdbfa9c..84d3c7cebd 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala
@@ -84,7 +84,7 @@ class LogisticRegressionWithSGD private (
extends GeneralizedLinearAlgorithm[LogisticRegressionModel] with Serializable {
private val gradient = new LogisticGradient()
- private val updater = new SimpleUpdater()
+ private val updater = new SquaredL2Updater()
override val optimizer = new GradientDescent(gradient, updater)
.setStepSize(stepSize)
.setNumIterations(numIterations)
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala
index 862178694a..e954baaf7d 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala
@@ -43,7 +43,7 @@ object LogisticRegressionSuite {
offset: Double,
scale: Double,
nPoints: Int,
- seed: Int): Seq[LabeledPoint] = {
+ seed: Int): Seq[LabeledPoint] = {
val rnd = new Random(seed)
val x1 = Array.fill[Double](nPoints)(rnd.nextGaussian())
@@ -58,12 +58,15 @@ object LogisticRegressionSuite {
}
class LogisticRegressionSuite extends FunSuite with LocalSparkContext with Matchers {
- def validatePrediction(predictions: Seq[Double], input: Seq[LabeledPoint]) {
+ def validatePrediction(
+ predictions: Seq[Double],
+ input: Seq[LabeledPoint],
+ expectedAcc: Double = 0.83) {
val numOffPredictions = predictions.zip(input).count { case (prediction, expected) =>
prediction != expected.label
}
// At least 83% of the predictions should be on.
- ((input.length - numOffPredictions).toDouble / input.length) should be > 0.83
+ ((input.length - numOffPredictions).toDouble / input.length) should be > expectedAcc
}
// Test if we can correctly learn A, B where Y = logistic(A + B*X)
@@ -155,6 +158,41 @@ class LogisticRegressionSuite extends FunSuite with LocalSparkContext with Match
validatePrediction(validationData.map(row => model.predict(row.features)), validationData)
}
+ test("logistic regression with initial weights and non-default regularization parameter") {
+ val nPoints = 10000
+ val A = 2.0
+ val B = -1.5
+
+ val testData = LogisticRegressionSuite.generateLogisticInput(A, B, nPoints, 42)
+
+ val initialB = -1.0
+ val initialWeights = Vectors.dense(initialB)
+
+ val testRDD = sc.parallelize(testData, 2)
+ testRDD.cache()
+
+ // Use half as many iterations as the previous test.
+ val lr = new LogisticRegressionWithSGD().setIntercept(true)
+ lr.optimizer.
+ setStepSize(10.0).
+ setNumIterations(10).
+ setRegParam(1.0)
+
+ val model = lr.run(testRDD, initialWeights)
+
+ // Test the weights
+ assert(model.weights(0) ~== -430000.0 relTol 20000.0)
+ assert(model.intercept ~== 370000.0 relTol 20000.0)
+
+ val validationData = LogisticRegressionSuite.generateLogisticInput(A, B, nPoints, 17)
+ val validationRDD = sc.parallelize(validationData, 2)
+ // Test prediction on RDD.
+ validatePrediction(model.predict(validationRDD.map(_.features)).collect(), validationData, 0.8)
+
+ // Test prediction on Array.
+ validatePrediction(validationData.map(row => model.predict(row.features)), validationData, 0.8)
+ }
+
test("logistic regression with initial weights with LBFGS") {
val nPoints = 10000
val A = 2.0