From d679843a39bb4918a08a5aebdf113ac8886a5275 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Wed, 26 Mar 2014 19:30:20 -0700 Subject: [SPARK-1327] GLM needs to check addIntercept for intercept and weights GLM needs to check addIntercept for intercept and weights. The current implementation always uses the first weight as intercept. Added a test for training without adding intercept. JIRA: https://spark-project.atlassian.net/browse/SPARK-1327 Author: Xiangrui Meng Closes #236 from mengxr/glm and squashes the following commits: bcac1ac [Xiangrui Meng] add two tests to ensure {Lasso, Ridge}.setIntercept will throw an exceptions a104072 [Xiangrui Meng] remove protected to be compatible with 0.9 0e57aa4 [Xiangrui Meng] update Lasso and RidgeRegression to parse the weights correctly from GLM mark createModel protected mark predictPoint protected d7f629f [Xiangrui Meng] fix a bug in GLM when intercept is not used --- .../regression/GeneralizedLinearAlgorithm.scala | 21 +++++++++-------- .../org/apache/spark/mllib/regression/Lasso.scala | 20 ++++++++++++----- .../spark/mllib/regression/LinearRegression.scala | 20 ++++++++--------- .../spark/mllib/regression/RidgeRegression.scala | 18 ++++++++++----- .../apache/spark/mllib/regression/LassoSuite.scala | 9 +++++--- .../mllib/regression/LinearRegressionSuite.scala | 26 +++++++++++++++++++++- .../mllib/regression/RidgeRegressionSuite.scala | 9 +++++--- 7 files changed, 86 insertions(+), 37 deletions(-) (limited to 'mllib') diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala index b9621530ef..3e1ed91bf6 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala @@ -136,25 +136,28 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel] // Prepend an extra variable consisting of all 1.0's for the intercept. val data = if (addIntercept) { - input.map(labeledPoint => (labeledPoint.label, labeledPoint.features.+:(1.0))) + input.map(labeledPoint => (labeledPoint.label, 1.0 +: labeledPoint.features)) } else { input.map(labeledPoint => (labeledPoint.label, labeledPoint.features)) } val initialWeightsWithIntercept = if (addIntercept) { - initialWeights.+:(1.0) + 0.0 +: initialWeights } else { initialWeights } - val weights = optimizer.optimize(data, initialWeightsWithIntercept) - val intercept = weights(0) - val weightsScaled = weights.tail + val weightsWithIntercept = optimizer.optimize(data, initialWeightsWithIntercept) - val model = createModel(weightsScaled, intercept) + val (intercept, weights) = if (addIntercept) { + (weightsWithIntercept(0), weightsWithIntercept.tail) + } else { + (0.0, weightsWithIntercept) + } + + logInfo("Final weights " + weights.mkString(",")) + logInfo("Final intercept " + intercept) - logInfo("Final model weights " + model.weights.mkString(",")) - logInfo("Final model intercept " + model.intercept) - model + createModel(weights, intercept) } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala index fb2bc9b92a..be63ce8538 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala @@ -36,8 +36,10 @@ class LassoModel( extends GeneralizedLinearModel(weights, intercept) with RegressionModel with Serializable { - override def predictPoint(dataMatrix: DoubleMatrix, weightMatrix: DoubleMatrix, - intercept: Double) = { + override def predictPoint( + dataMatrix: DoubleMatrix, + weightMatrix: DoubleMatrix, + intercept: Double): Double = { dataMatrix.dot(weightMatrix) + intercept } } @@ -66,7 +68,7 @@ class LassoWithSGD private ( .setMiniBatchFraction(miniBatchFraction) // We don't want to penalize the intercept, so set this to false. - setIntercept(false) + super.setIntercept(false) var yMean = 0.0 var xColMean: DoubleMatrix = _ @@ -77,10 +79,16 @@ class LassoWithSGD private ( */ def this() = this(1.0, 100, 1.0, 1.0) - def createModel(weights: Array[Double], intercept: Double) = { - val weightsMat = new DoubleMatrix(weights.length + 1, 1, (Array(intercept) ++ weights):_*) + override def setIntercept(addIntercept: Boolean): this.type = { + // TODO: Support adding intercept. + if (addIntercept) throw new UnsupportedOperationException("Adding intercept is not supported.") + this + } + + override def createModel(weights: Array[Double], intercept: Double) = { + val weightsMat = new DoubleMatrix(weights.length, 1, weights: _*) val weightsScaled = weightsMat.div(xColSd) - val interceptScaled = yMean - (weightsMat.transpose().mmul(xColMean.div(xColSd)).get(0)) + val interceptScaled = yMean - weightsMat.transpose().mmul(xColMean.div(xColSd)).get(0) new LassoModel(weightsScaled.data, interceptScaled) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala index 8ee40addb2..f5f15d1a33 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala @@ -31,13 +31,14 @@ import org.jblas.DoubleMatrix * @param intercept Intercept computed for this model. */ class LinearRegressionModel( - override val weights: Array[Double], - override val intercept: Double) - extends GeneralizedLinearModel(weights, intercept) - with RegressionModel with Serializable { - - override def predictPoint(dataMatrix: DoubleMatrix, weightMatrix: DoubleMatrix, - intercept: Double) = { + override val weights: Array[Double], + override val intercept: Double) + extends GeneralizedLinearModel(weights, intercept) with RegressionModel with Serializable { + + override def predictPoint( + dataMatrix: DoubleMatrix, + weightMatrix: DoubleMatrix, + intercept: Double): Double = { dataMatrix.dot(weightMatrix) + intercept } } @@ -55,8 +56,7 @@ class LinearRegressionWithSGD private ( var stepSize: Double, var numIterations: Int, var miniBatchFraction: Double) - extends GeneralizedLinearAlgorithm[LinearRegressionModel] - with Serializable { + extends GeneralizedLinearAlgorithm[LinearRegressionModel] with Serializable { val gradient = new LeastSquaresGradient() val updater = new SimpleUpdater() @@ -69,7 +69,7 @@ class LinearRegressionWithSGD private ( */ def this() = this(1.0, 100, 1.0) - def createModel(weights: Array[Double], intercept: Double) = { + override def createModel(weights: Array[Double], intercept: Double) = { new LinearRegressionModel(weights, intercept) } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala index c504d3d40c..feb100f218 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala @@ -36,8 +36,10 @@ class RidgeRegressionModel( extends GeneralizedLinearModel(weights, intercept) with RegressionModel with Serializable { - override def predictPoint(dataMatrix: DoubleMatrix, weightMatrix: DoubleMatrix, - intercept: Double) = { + override def predictPoint( + dataMatrix: DoubleMatrix, + weightMatrix: DoubleMatrix, + intercept: Double): Double = { dataMatrix.dot(weightMatrix) + intercept } } @@ -67,7 +69,7 @@ class RidgeRegressionWithSGD private ( .setMiniBatchFraction(miniBatchFraction) // We don't want to penalize the intercept in RidgeRegression, so set this to false. - setIntercept(false) + super.setIntercept(false) var yMean = 0.0 var xColMean: DoubleMatrix = _ @@ -78,8 +80,14 @@ class RidgeRegressionWithSGD private ( */ def this() = this(1.0, 100, 1.0, 1.0) - def createModel(weights: Array[Double], intercept: Double) = { - val weightsMat = new DoubleMatrix(weights.length + 1, 1, (Array(intercept) ++ weights):_*) + override def setIntercept(addIntercept: Boolean): this.type = { + // TODO: Support adding intercept. + if (addIntercept) throw new UnsupportedOperationException("Adding intercept is not supported.") + this + } + + override def createModel(weights: Array[Double], intercept: Double) = { + val weightsMat = new DoubleMatrix(weights.length, 1, weights: _*) val weightsScaled = weightsMat.div(xColSd) val interceptScaled = yMean - weightsMat.transpose().mmul(xColMean.div(xColSd)).get(0) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala index 64e4cbb860..2cebac943e 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala @@ -17,11 +17,8 @@ package org.apache.spark.mllib.regression - -import org.scalatest.BeforeAndAfterAll import org.scalatest.FunSuite -import org.apache.spark.SparkContext import org.apache.spark.mllib.util.{LinearDataGenerator, LocalSparkContext} class LassoSuite extends FunSuite with LocalSparkContext { @@ -104,4 +101,10 @@ class LassoSuite extends FunSuite with LocalSparkContext { // Test prediction on Array. validatePrediction(validationData.map(row => model.predict(row.features)), validationData) } + + test("do not support intercept") { + intercept[UnsupportedOperationException] { + new LassoWithSGD().setIntercept(true) + } + } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala index 281f9df36d..5d251bcbf3 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala @@ -17,7 +17,6 @@ package org.apache.spark.mllib.regression -import org.scalatest.BeforeAndAfterAll import org.scalatest.FunSuite import org.apache.spark.mllib.util.{LinearDataGenerator, LocalSparkContext} @@ -57,4 +56,29 @@ class LinearRegressionSuite extends FunSuite with LocalSparkContext { // Test prediction on Array. validatePrediction(validationData.map(row => model.predict(row.features)), validationData) } + + // Test if we can correctly learn Y = 10*X1 + 10*X2 + test("linear regression without intercept") { + val testRDD = sc.parallelize(LinearDataGenerator.generateLinearInput( + 0.0, Array(10.0, 10.0), 100, 42), 2).cache() + val linReg = new LinearRegressionWithSGD().setIntercept(false) + linReg.optimizer.setNumIterations(1000).setStepSize(1.0) + + val model = linReg.run(testRDD) + + assert(model.intercept === 0.0) + assert(model.weights.length === 2) + assert(model.weights(0) >= 9.0 && model.weights(0) <= 11.0) + assert(model.weights(1) >= 9.0 && model.weights(1) <= 11.0) + + val validationData = LinearDataGenerator.generateLinearInput( + 0.0, Array(10.0, 10.0), 100, 17) + val validationRDD = sc.parallelize(validationData, 2).cache() + + // Test prediction on RDD. + validatePrediction(model.predict(validationRDD.map(_.features)).collect(), validationData) + + // Test prediction on Array. + validatePrediction(validationData.map(row => model.predict(row.features)), validationData) + } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala index 67dd06cc0f..b2044ed0d8 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala @@ -17,14 +17,11 @@ package org.apache.spark.mllib.regression - import org.jblas.DoubleMatrix -import org.scalatest.BeforeAndAfterAll import org.scalatest.FunSuite import org.apache.spark.mllib.util.{LinearDataGenerator, LocalSparkContext} - class RidgeRegressionSuite extends FunSuite with LocalSparkContext { def predictionError(predictions: Seq[Double], input: Seq[LabeledPoint]) = { @@ -74,4 +71,10 @@ class RidgeRegressionSuite extends FunSuite with LocalSparkContext { assert(ridgeErr < linearErr, "ridgeError (" + ridgeErr + ") was not less than linearError(" + linearErr + ")") } + + test("do not support intercept") { + intercept[UnsupportedOperationException] { + new RidgeRegressionWithSGD().setIntercept(true) + } + } } -- cgit v1.2.3