aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorXiangrui Meng <meng@databricks.com>2014-03-26 19:30:20 -0700
committerTathagata Das <tathagata.das1565@gmail.com>2014-03-26 19:30:20 -0700
commitd679843a39bb4918a08a5aebdf113ac8886a5275 (patch)
tree6438556c9b1fc76dd5bc386113dccdc998348d87
parent1fa48d9422d543827011eec0cdf12d060b78a7c7 (diff)
downloadspark-d679843a39bb4918a08a5aebdf113ac8886a5275.tar.gz
spark-d679843a39bb4918a08a5aebdf113ac8886a5275.tar.bz2
spark-d679843a39bb4918a08a5aebdf113ac8886a5275.zip
[SPARK-1327] GLM needs to check addIntercept for intercept and weights
GLM needs to check addIntercept for intercept and weights. The current implementation always uses the first weight as intercept. Added a test for training without adding intercept. JIRA: https://spark-project.atlassian.net/browse/SPARK-1327 Author: Xiangrui Meng <meng@databricks.com> Closes #236 from mengxr/glm and squashes the following commits: bcac1ac [Xiangrui Meng] add two tests to ensure {Lasso, Ridge}.setIntercept will throw an exception a104072 [Xiangrui Meng] remove protected to be compatible with 0.9 0e57aa4 [Xiangrui Meng] update Lasso and RidgeRegression to parse the weights correctly from GLM mark createModel protected mark predictPoint protected d7f629f [Xiangrui Meng] fix a bug in GLM when intercept is not used
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala21
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala20
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala20
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala18
-rw-r--r--mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala9
-rw-r--r--mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala26
-rw-r--r--mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala9
7 files changed, 86 insertions, 37 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala
index b9621530ef..3e1ed91bf6 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala
@@ -136,25 +136,28 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel]
// Prepend an extra variable consisting of all 1.0's for the intercept.
val data = if (addIntercept) {
- input.map(labeledPoint => (labeledPoint.label, labeledPoint.features.+:(1.0)))
+ input.map(labeledPoint => (labeledPoint.label, 1.0 +: labeledPoint.features))
} else {
input.map(labeledPoint => (labeledPoint.label, labeledPoint.features))
}
val initialWeightsWithIntercept = if (addIntercept) {
- initialWeights.+:(1.0)
+ 0.0 +: initialWeights
} else {
initialWeights
}
- val weights = optimizer.optimize(data, initialWeightsWithIntercept)
- val intercept = weights(0)
- val weightsScaled = weights.tail
+ val weightsWithIntercept = optimizer.optimize(data, initialWeightsWithIntercept)
- val model = createModel(weightsScaled, intercept)
+ val (intercept, weights) = if (addIntercept) {
+ (weightsWithIntercept(0), weightsWithIntercept.tail)
+ } else {
+ (0.0, weightsWithIntercept)
+ }
+
+ logInfo("Final weights " + weights.mkString(","))
+ logInfo("Final intercept " + intercept)
- logInfo("Final model weights " + model.weights.mkString(","))
- logInfo("Final model intercept " + model.intercept)
- model
+ createModel(weights, intercept)
}
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala
index fb2bc9b92a..be63ce8538 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala
@@ -36,8 +36,10 @@ class LassoModel(
extends GeneralizedLinearModel(weights, intercept)
with RegressionModel with Serializable {
- override def predictPoint(dataMatrix: DoubleMatrix, weightMatrix: DoubleMatrix,
- intercept: Double) = {
+ override def predictPoint(
+ dataMatrix: DoubleMatrix,
+ weightMatrix: DoubleMatrix,
+ intercept: Double): Double = {
dataMatrix.dot(weightMatrix) + intercept
}
}
@@ -66,7 +68,7 @@ class LassoWithSGD private (
.setMiniBatchFraction(miniBatchFraction)
// We don't want to penalize the intercept, so set this to false.
- setIntercept(false)
+ super.setIntercept(false)
var yMean = 0.0
var xColMean: DoubleMatrix = _
@@ -77,10 +79,16 @@ class LassoWithSGD private (
*/
def this() = this(1.0, 100, 1.0, 1.0)
- def createModel(weights: Array[Double], intercept: Double) = {
- val weightsMat = new DoubleMatrix(weights.length + 1, 1, (Array(intercept) ++ weights):_*)
+ override def setIntercept(addIntercept: Boolean): this.type = {
+ // TODO: Support adding intercept.
+ if (addIntercept) throw new UnsupportedOperationException("Adding intercept is not supported.")
+ this
+ }
+
+ override def createModel(weights: Array[Double], intercept: Double) = {
+ val weightsMat = new DoubleMatrix(weights.length, 1, weights: _*)
val weightsScaled = weightsMat.div(xColSd)
- val interceptScaled = yMean - (weightsMat.transpose().mmul(xColMean.div(xColSd)).get(0))
+ val interceptScaled = yMean - weightsMat.transpose().mmul(xColMean.div(xColSd)).get(0)
new LassoModel(weightsScaled.data, interceptScaled)
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala
index 8ee40addb2..f5f15d1a33 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala
@@ -31,13 +31,14 @@ import org.jblas.DoubleMatrix
* @param intercept Intercept computed for this model.
*/
class LinearRegressionModel(
- override val weights: Array[Double],
- override val intercept: Double)
- extends GeneralizedLinearModel(weights, intercept)
- with RegressionModel with Serializable {
-
- override def predictPoint(dataMatrix: DoubleMatrix, weightMatrix: DoubleMatrix,
- intercept: Double) = {
+ override val weights: Array[Double],
+ override val intercept: Double)
+ extends GeneralizedLinearModel(weights, intercept) with RegressionModel with Serializable {
+
+ override def predictPoint(
+ dataMatrix: DoubleMatrix,
+ weightMatrix: DoubleMatrix,
+ intercept: Double): Double = {
dataMatrix.dot(weightMatrix) + intercept
}
}
@@ -55,8 +56,7 @@ class LinearRegressionWithSGD private (
var stepSize: Double,
var numIterations: Int,
var miniBatchFraction: Double)
- extends GeneralizedLinearAlgorithm[LinearRegressionModel]
- with Serializable {
+ extends GeneralizedLinearAlgorithm[LinearRegressionModel] with Serializable {
val gradient = new LeastSquaresGradient()
val updater = new SimpleUpdater()
@@ -69,7 +69,7 @@ class LinearRegressionWithSGD private (
*/
def this() = this(1.0, 100, 1.0)
- def createModel(weights: Array[Double], intercept: Double) = {
+ override def createModel(weights: Array[Double], intercept: Double) = {
new LinearRegressionModel(weights, intercept)
}
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala
index c504d3d40c..feb100f218 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala
@@ -36,8 +36,10 @@ class RidgeRegressionModel(
extends GeneralizedLinearModel(weights, intercept)
with RegressionModel with Serializable {
- override def predictPoint(dataMatrix: DoubleMatrix, weightMatrix: DoubleMatrix,
- intercept: Double) = {
+ override def predictPoint(
+ dataMatrix: DoubleMatrix,
+ weightMatrix: DoubleMatrix,
+ intercept: Double): Double = {
dataMatrix.dot(weightMatrix) + intercept
}
}
@@ -67,7 +69,7 @@ class RidgeRegressionWithSGD private (
.setMiniBatchFraction(miniBatchFraction)
// We don't want to penalize the intercept in RidgeRegression, so set this to false.
- setIntercept(false)
+ super.setIntercept(false)
var yMean = 0.0
var xColMean: DoubleMatrix = _
@@ -78,8 +80,14 @@ class RidgeRegressionWithSGD private (
*/
def this() = this(1.0, 100, 1.0, 1.0)
- def createModel(weights: Array[Double], intercept: Double) = {
- val weightsMat = new DoubleMatrix(weights.length + 1, 1, (Array(intercept) ++ weights):_*)
+ override def setIntercept(addIntercept: Boolean): this.type = {
+ // TODO: Support adding intercept.
+ if (addIntercept) throw new UnsupportedOperationException("Adding intercept is not supported.")
+ this
+ }
+
+ override def createModel(weights: Array[Double], intercept: Double) = {
+ val weightsMat = new DoubleMatrix(weights.length, 1, weights: _*)
val weightsScaled = weightsMat.div(xColSd)
val interceptScaled = yMean - weightsMat.transpose().mmul(xColMean.div(xColSd)).get(0)
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala
index 64e4cbb860..2cebac943e 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala
@@ -17,11 +17,8 @@
package org.apache.spark.mllib.regression
-
-import org.scalatest.BeforeAndAfterAll
import org.scalatest.FunSuite
-import org.apache.spark.SparkContext
import org.apache.spark.mllib.util.{LinearDataGenerator, LocalSparkContext}
class LassoSuite extends FunSuite with LocalSparkContext {
@@ -104,4 +101,10 @@ class LassoSuite extends FunSuite with LocalSparkContext {
// Test prediction on Array.
validatePrediction(validationData.map(row => model.predict(row.features)), validationData)
}
+
+ test("do not support intercept") {
+ intercept[UnsupportedOperationException] {
+ new LassoWithSGD().setIntercept(true)
+ }
+ }
}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala
index 281f9df36d..5d251bcbf3 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala
@@ -17,7 +17,6 @@
package org.apache.spark.mllib.regression
-import org.scalatest.BeforeAndAfterAll
import org.scalatest.FunSuite
import org.apache.spark.mllib.util.{LinearDataGenerator, LocalSparkContext}
@@ -57,4 +56,29 @@ class LinearRegressionSuite extends FunSuite with LocalSparkContext {
// Test prediction on Array.
validatePrediction(validationData.map(row => model.predict(row.features)), validationData)
}
+
+ // Test if we can correctly learn Y = 10*X1 + 10*X2
+ test("linear regression without intercept") {
+ val testRDD = sc.parallelize(LinearDataGenerator.generateLinearInput(
+ 0.0, Array(10.0, 10.0), 100, 42), 2).cache()
+ val linReg = new LinearRegressionWithSGD().setIntercept(false)
+ linReg.optimizer.setNumIterations(1000).setStepSize(1.0)
+
+ val model = linReg.run(testRDD)
+
+ assert(model.intercept === 0.0)
+ assert(model.weights.length === 2)
+ assert(model.weights(0) >= 9.0 && model.weights(0) <= 11.0)
+ assert(model.weights(1) >= 9.0 && model.weights(1) <= 11.0)
+
+ val validationData = LinearDataGenerator.generateLinearInput(
+ 0.0, Array(10.0, 10.0), 100, 17)
+ val validationRDD = sc.parallelize(validationData, 2).cache()
+
+ // Test prediction on RDD.
+ validatePrediction(model.predict(validationRDD.map(_.features)).collect(), validationData)
+
+ // Test prediction on Array.
+ validatePrediction(validationData.map(row => model.predict(row.features)), validationData)
+ }
}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala
index 67dd06cc0f..b2044ed0d8 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala
@@ -17,14 +17,11 @@
package org.apache.spark.mllib.regression
-
import org.jblas.DoubleMatrix
-import org.scalatest.BeforeAndAfterAll
import org.scalatest.FunSuite
import org.apache.spark.mllib.util.{LinearDataGenerator, LocalSparkContext}
-
class RidgeRegressionSuite extends FunSuite with LocalSparkContext {
def predictionError(predictions: Seq[Double], input: Seq[LabeledPoint]) = {
@@ -74,4 +71,10 @@ class RidgeRegressionSuite extends FunSuite with LocalSparkContext {
assert(ridgeErr < linearErr,
"ridgeError (" + ridgeErr + ") was not less than linearError(" + linearErr + ")")
}
+
+ test("do not support intercept") {
+ intercept[UnsupportedOperationException] {
+ new RidgeRegressionWithSGD().setIntercept(true)
+ }
+ }
}