author    Shivaram Venkataraman <shivaram@eecs.berkeley.edu>  2013-08-25 22:24:27 -0700
committer Shivaram Venkataraman <shivaram@eecs.berkeley.edu>  2013-08-25 22:24:27 -0700
commit    b8c50a0642cf74c25fd70cc1e7d1be95ddafc5d8 (patch)
tree      8f2d5abb7cd565d81e56d3ceca8134eaae41abe4 /mllib
parent    07fe910669b2ec15b6b5c1e5186df5036d05b9b1 (diff)
Center & scale variables in Ridge, Lasso.
Also add a unit test that checks if ridge regression lowers cross-validation error.
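
At a glance, the patch standardizes the training data before running SGD: each feature column is centered by its mean and divided by its standard deviation, and the label is centered by its mean, so the regularizer never touches an intercept term. A minimal sketch of that per-point transform (plain Scala with hypothetical names, mirroring the normalizedData map in the hunks below):

    // Sketch only, not part of the commit: standardize one labeled point
    // given precomputed column means and standard deviations.
    case class Point(label: Double, features: Array[Double])

    def standardize(p: Point, yMean: Double,
                    xMean: Array[Double], xSd: Array[Double]): Point = {
      // Center the label; center and scale every feature column.
      val scaled = p.features.indices.map(i => (p.features(i) - xMean(i)) / xSd(i)).toArray
      Point(p.label - yMean, scaled)
    }
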
Diffstat (limited to 'mllib')
-rw-r--r--  mllib/src/main/scala/spark/mllib/regression/Lasso.scala | 45
-rw-r--r--  mllib/src/main/scala/spark/mllib/regression/LinearRegression.scala | 44
-rw-r--r--  mllib/src/main/scala/spark/mllib/regression/RidgeRegression.scala | 95
-rw-r--r--  mllib/src/main/scala/spark/mllib/util/LinearDataGenerator.scala | 24
-rw-r--r--  mllib/src/main/scala/spark/mllib/util/MLUtils.scala | 10
-rw-r--r--  mllib/src/test/java/spark/mllib/regression/JavaLassoSuite.java | 8
-rw-r--r--  mllib/src/test/java/spark/mllib/regression/JavaLinearRegressionSuite.java | 119
-rw-r--r--  mllib/src/test/java/spark/mllib/regression/JavaRidgeRegressionSuite.java | 139
-rw-r--r--  mllib/src/test/scala/spark/mllib/regression/LinearRegressionSuite.scala | 27
-rw-r--r--  mllib/src/test/scala/spark/mllib/regression/RidgeRegressionSuite.scala | 64
10 files changed, 347 insertions, 228 deletions
diff --git a/mllib/src/main/scala/spark/mllib/regression/Lasso.scala b/mllib/src/main/scala/spark/mllib/regression/Lasso.scala
index 6bbc990a5a..929c36bd76 100644
--- a/mllib/src/main/scala/spark/mllib/regression/Lasso.scala
+++ b/mllib/src/main/scala/spark/mllib/regression/Lasso.scala
@@ -55,10 +55,17 @@ class LassoWithSGD private (
val gradient = new SquaredGradient()
val updater = new L1Updater()
- val optimizer = new GradientDescent(gradient, updater).setStepSize(stepSize)
- .setNumIterations(numIterations)
- .setRegParam(regParam)
- .setMiniBatchFraction(miniBatchFraction)
+ @transient val optimizer = new GradientDescent(gradient, updater).setStepSize(stepSize)
+ .setNumIterations(numIterations)
+ .setRegParam(regParam)
+ .setMiniBatchFraction(miniBatchFraction)
+
+ // We don't want to penalize the intercept, so set this to false.
+ setIntercept(false)
+
+ var yMean = 0.0
+ var xColMean: DoubleMatrix = _
+ var xColSd: DoubleMatrix = _
/**
* Construct a Lasso object with default parameters
@@ -66,7 +73,35 @@ class LassoWithSGD private (
def this() = this(1.0, 100, 1.0, 1.0, true)
def createModel(weights: Array[Double], intercept: Double) = {
- new LassoModel(weights, intercept)
+ val weightsMat = new DoubleMatrix(weights.length + 1, 1, (Array(intercept) ++ weights):_*)
+ val weightsScaled = weightsMat.div(xColSd)
+ val interceptScaled = yMean - (weightsMat.transpose().mmul(xColMean.div(xColSd)).get(0))
+
+ new LassoModel(weightsScaled.data, interceptScaled)
+ }
+
+ override def run(
+ input: RDD[LabeledPoint],
+ initialWeights: Array[Double])
+ : LassoModel =
+ {
+ val nfeatures: Int = input.first.features.length
+ val nexamples: Long = input.count()
+
+ // To avoid penalizing the intercept, we center and scale the data.
+ val stats = MLUtils.computeStats(input, nfeatures, nexamples)
+ yMean = stats._1
+ xColMean = stats._2
+ xColSd = stats._3
+
+ val normalizedData = input.map { point =>
+ val yNormalized = point.label - yMean
+ val featuresMat = new DoubleMatrix(nfeatures, 1, point.features:_*)
+ val featuresNormalized = featuresMat.sub(xColMean).divi(xColSd)
+ LabeledPoint(yNormalized, featuresNormalized.toArray)
+ }
+
+ super.run(normalizedData, initialWeights)
}
}
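
The createModel change above folds the standardization back out of the fitted model. If SGD learned weights w on the centered and scaled data, i.e. y - yMean ≈ sum_j w(j) * (x(j) - mean(j)) / sd(j), then the equivalent original-scale model has weights w(j)/sd(j) and intercept yMean - sum_j w(j)*mean(j)/sd(j); that is the algebra the weightsScaled and interceptScaled lines implement. A self-contained sketch of the back-transform (hypothetical helper, not the patch code):

    // Undo standardization on a fitted weight vector (sketch, assumed names).
    def unscale(w: Array[Double], yMean: Double,
                xMean: Array[Double], xSd: Array[Double]): (Array[Double], Double) = {
      val weights = w.indices.map(j => w(j) / xSd(j)).toArray
      val intercept = yMean - w.indices.map(j => w(j) * xMean(j) / xSd(j)).sum
      (weights, intercept)
    }
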
diff --git a/mllib/src/main/scala/spark/mllib/regression/LinearRegression.scala b/mllib/src/main/scala/spark/mllib/regression/LinearRegression.scala
index 0ea5348a1f..5b3743f2fa 100644
--- a/mllib/src/main/scala/spark/mllib/regression/LinearRegression.scala
+++ b/mllib/src/main/scala/spark/mllib/regression/LinearRegression.scala
@@ -45,10 +45,10 @@ class LinearRegressionModel(
* Train a regression model with no regularization using Stochastic Gradient Descent.
*/
class LinearRegressionWithSGD private (
- var stepSize: Double,
- var numIterations: Int,
- var miniBatchFraction: Double,
- var addIntercept: Boolean)
+ var stepSize: Double,
+ var numIterations: Int,
+ var miniBatchFraction: Double,
+ var addIntercept: Boolean)
extends GeneralizedLinearAlgorithm[LinearRegressionModel]
with Serializable {
@@ -87,12 +87,12 @@ object LinearRegressionWithSGD {
* the number of features in the data.
*/
def train(
- input: RDD[LabeledPoint],
- numIterations: Int,
- stepSize: Double,
- miniBatchFraction: Double,
- initialWeights: Array[Double])
- : LinearRegressionModel =
+ input: RDD[LabeledPoint],
+ numIterations: Int,
+ stepSize: Double,
+ miniBatchFraction: Double,
+ initialWeights: Array[Double])
+ : LinearRegressionModel =
{
new LinearRegressionWithSGD(stepSize, numIterations, miniBatchFraction, true).run(input,
initialWeights)
@@ -109,11 +109,11 @@ object LinearRegressionWithSGD {
* @param miniBatchFraction Fraction of data to be used per iteration.
*/
def train(
- input: RDD[LabeledPoint],
- numIterations: Int,
- stepSize: Double,
- miniBatchFraction: Double)
- : LinearRegressionModel =
+ input: RDD[LabeledPoint],
+ numIterations: Int,
+ stepSize: Double,
+ miniBatchFraction: Double)
+ : LinearRegressionModel =
{
new LinearRegressionWithSGD(stepSize, numIterations, miniBatchFraction, true).run(input)
}
@@ -129,10 +129,10 @@ object LinearRegressionWithSGD {
* @return a LinearRegressionModel which has the weights and offset from training.
*/
def train(
- input: RDD[LabeledPoint],
- numIterations: Int,
- stepSize: Double)
- : LinearRegressionModel =
+ input: RDD[LabeledPoint],
+ numIterations: Int,
+ stepSize: Double)
+ : LinearRegressionModel =
{
train(input, numIterations, stepSize, 1.0)
}
@@ -147,9 +147,9 @@ object LinearRegressionWithSGD {
* @return a LinearRegressionModel which has the weights and offset from training.
*/
def train(
- input: RDD[LabeledPoint],
- numIterations: Int)
- : LinearRegressionModel =
+ input: RDD[LabeledPoint],
+ numIterations: Int)
+ : LinearRegressionModel =
{
train(input, numIterations, 1.0, 1.0)
}
diff --git a/mllib/src/main/scala/spark/mllib/regression/RidgeRegression.scala b/mllib/src/main/scala/spark/mllib/regression/RidgeRegression.scala
index addf8cd59e..ccf7364806 100644
--- a/mllib/src/main/scala/spark/mllib/regression/RidgeRegression.scala
+++ b/mllib/src/main/scala/spark/mllib/regression/RidgeRegression.scala
@@ -55,18 +55,54 @@ class RidgeRegressionWithSGD private (
val gradient = new SquaredGradient()
val updater = new SquaredL2Updater()
- val optimizer = new GradientDescent(gradient, updater).setStepSize(stepSize)
+
+ @transient val optimizer = new GradientDescent(gradient, updater).setStepSize(stepSize)
.setNumIterations(numIterations)
.setRegParam(regParam)
.setMiniBatchFraction(miniBatchFraction)
+ // We don't want to penalize the intercept in RidgeRegression, so set this to false.
+ setIntercept(false)
+
+ var yMean = 0.0
+ var xColMean: DoubleMatrix = _
+ var xColSd: DoubleMatrix = _
+
/**
* Construct a RidgeRegression object with default parameters
*/
def this() = this(1.0, 100, 1.0, 1.0, true)
def createModel(weights: Array[Double], intercept: Double) = {
- new RidgeRegressionModel(weights, intercept)
+ val weightsMat = new DoubleMatrix(weights.length + 1, 1, (Array(intercept) ++ weights):_*)
+ val weightsScaled = weightsMat.div(xColSd)
+ val interceptScaled = yMean - (weightsMat.transpose().mmul(xColMean.div(xColSd)).get(0))
+
+ new RidgeRegressionModel(weightsScaled.data, interceptScaled)
+ }
+
+ override def run(
+ input: RDD[LabeledPoint],
+ initialWeights: Array[Double])
+ : RidgeRegressionModel =
+ {
+ val nfeatures: Int = input.first.features.length
+ val nexamples: Long = input.count()
+
+ // To avoid penalizing the intercept, we center and scale the data.
+ val stats = MLUtils.computeStats(input, nfeatures, nexamples)
+ yMean = stats._1
+ xColMean = stats._2
+ xColSd = stats._3
+
+ val normalizedData = input.map { point =>
+ val yNormalized = point.label - yMean
+ val featuresMat = new DoubleMatrix(nfeatures, 1, point.features:_*)
+ val featuresNormalized = featuresMat.sub(xColMean).divi(xColSd)
+ LabeledPoint(yNormalized, featuresNormalized.toArray)
+ }
+
+ super.run(normalizedData, initialWeights)
}
}
@@ -90,16 +126,16 @@ object RidgeRegressionWithSGD {
* the number of features in the data.
*/
def train(
- input: RDD[LabeledPoint],
- numIterations: Int,
- stepSize: Double,
- regParam: Double,
- miniBatchFraction: Double,
- initialWeights: Array[Double])
- : RidgeRegressionModel =
+ input: RDD[LabeledPoint],
+ numIterations: Int,
+ stepSize: Double,
+ regParam: Double,
+ miniBatchFraction: Double,
+ initialWeights: Array[Double])
+ : RidgeRegressionModel =
{
- new RidgeRegressionWithSGD(stepSize, numIterations, regParam, miniBatchFraction, true).run(input,
- initialWeights)
+ new RidgeRegressionWithSGD(stepSize, numIterations, regParam, miniBatchFraction, true).run(
+ input, initialWeights)
}
/**
@@ -114,14 +150,15 @@ object RidgeRegressionWithSGD {
* @param miniBatchFraction Fraction of data to be used per iteration.
*/
def train(
- input: RDD[LabeledPoint],
- numIterations: Int,
- stepSize: Double,
- regParam: Double,
- miniBatchFraction: Double)
- : RidgeRegressionModel =
+ input: RDD[LabeledPoint],
+ numIterations: Int,
+ stepSize: Double,
+ regParam: Double,
+ miniBatchFraction: Double)
+ : RidgeRegressionModel =
{
- new RidgeRegressionWithSGD(stepSize, numIterations, regParam, miniBatchFraction, true).run(input)
+ new RidgeRegressionWithSGD(stepSize, numIterations, regParam, miniBatchFraction, true).run(
+ input)
}
/**
@@ -136,11 +173,11 @@ object RidgeRegressionWithSGD {
* @return a RidgeRegressionModel which has the weights and offset from training.
*/
def train(
- input: RDD[LabeledPoint],
- numIterations: Int,
- stepSize: Double,
- regParam: Double)
- : RidgeRegressionModel =
+ input: RDD[LabeledPoint],
+ numIterations: Int,
+ stepSize: Double,
+ regParam: Double)
+ : RidgeRegressionModel =
{
train(input, numIterations, stepSize, regParam, 1.0)
}
@@ -155,21 +192,23 @@ object RidgeRegressionWithSGD {
* @return a RidgeRegressionModel which has the weights and offset from training.
*/
def train(
- input: RDD[LabeledPoint],
- numIterations: Int)
- : RidgeRegressionModel =
+ input: RDD[LabeledPoint],
+ numIterations: Int)
+ : RidgeRegressionModel =
{
train(input, numIterations, 1.0, 1.0, 1.0)
}
def main(args: Array[String]) {
if (args.length != 5) {
- println("Usage: RidgeRegression <master> <input_dir> <step_size> <regularization_parameter> <niters>")
+ println("Usage: RidgeRegression <master> <input_dir> <step_size> <regularization_parameter>" +
+ " <niters>")
System.exit(1)
}
val sc = new SparkContext(args(0), "RidgeRegression")
val data = MLUtils.loadLabeledData(sc, args(1))
- val model = RidgeRegressionWithSGD.train(data, args(4).toInt, args(2).toDouble, args(3).toDouble)
+ val model = RidgeRegressionWithSGD.train(data, args(4).toInt, args(2).toDouble,
+ args(3).toDouble)
sc.stop()
}
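
For reference, a hedged usage sketch of the ridge API after this change (argument order taken from the train overloads above; the SparkContext `sc` and data `path` are assumed to exist):

    // Assumes an existing SparkContext `sc` and labeled data at `path`.
    val data = MLUtils.loadLabeledData(sc, path)
    // numIterations = 200, stepSize = 1.0, regParam = 0.1, miniBatchFraction = 1.0
    val model = RidgeRegressionWithSGD.train(data, 200, 1.0, 0.1, 1.0)
    val predictions = model.predict(data.map(_.features))
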
diff --git a/mllib/src/main/scala/spark/mllib/util/LinearDataGenerator.scala b/mllib/src/main/scala/spark/mllib/util/LinearDataGenerator.scala
index 20e1656beb..9f48477f84 100644
--- a/mllib/src/main/scala/spark/mllib/util/LinearDataGenerator.scala
+++ b/mllib/src/main/scala/spark/mllib/util/LinearDataGenerator.scala
@@ -17,20 +17,19 @@
package spark.mllib.util
+import scala.collection.JavaConversions._
import scala.util.Random
import org.jblas.DoubleMatrix
import spark.{RDD, SparkContext}
import spark.mllib.regression.LabeledPoint
-import scala.collection.JavaConversions._
import spark.mllib.regression.LabeledPoint
/**
 * Generate sample data for linear regression. This object generates
 * uniformly random values for every feature and adds Gaussian noise with
 * standard deviation `eps` to the response variable `Y`.
- *
*/
object LinearDataGenerator {
@@ -47,8 +46,9 @@ object LinearDataGenerator {
intercept: Double,
weights: Array[Double],
nPoints: Int,
- seed: Int): java.util.List[LabeledPoint] = {
- seqAsJavaList(generateLinearInput(intercept, weights, nPoints, seed))
+ seed: Int,
+ eps: Double): java.util.List[LabeledPoint] = {
+ seqAsJavaList(generateLinearInput(intercept, weights, nPoints, seed, eps))
}
/**
@@ -70,10 +70,10 @@ object LinearDataGenerator {
val rnd = new Random(seed)
val weightsMat = new DoubleMatrix(1, weights.length, weights:_*)
val x = Array.fill[Array[Double]](nPoints)(
- Array.fill[Double](weights.length)(rnd.nextGaussian()))
- val y = x.map(xi =>
+ Array.fill[Double](weights.length)(2 * rnd.nextDouble - 1.0))
+ val y = x.map { xi =>
(new DoubleMatrix(1, xi.length, xi:_*)).dot(weightsMat) + intercept + eps * rnd.nextGaussian()
- )
+ }
y.zip(x).map(p => LabeledPoint(p._1, p._2))
}
@@ -95,19 +95,15 @@ object LinearDataGenerator {
nexamples: Int,
nfeatures: Int,
eps: Double,
- weights: Array[Double] = Array[Double](),
nparts: Int = 2,
intercept: Double = 0.0) : RDD[LabeledPoint] = {
org.jblas.util.Random.seed(42)
// Random values distributed uniformly in [-0.5, 0.5]
val w = DoubleMatrix.rand(nfeatures, 1).subi(0.5)
- (0 until weights.length.max(nfeatures)).map(i => w.put(i, 0, weights(i)))
-
val data: RDD[LabeledPoint] = sc.parallelize(0 until nparts, nparts).flatMap { p =>
- val seed = 42+p
+ val seed = 42 + p
val examplesInPartition = nexamples / nparts
-
generateLinearInput(intercept, w.toArray, examplesInPartition, seed, eps)
}
data
@@ -115,7 +111,7 @@ object LinearDataGenerator {
def main(args: Array[String]) {
if (args.length < 2) {
- println("Usage: RidgeRegressionGenerator " +
+ println("Usage: LinearDataGenerator " +
"<master> <output_dir> [num_examples] [num_features] [num_partitions]")
System.exit(1)
}
@@ -127,7 +123,7 @@ object LinearDataGenerator {
val parts: Int = if (args.length > 4) args(4).toInt else 2
val eps = 10
- val sc = new SparkContext(sparkMaster, "RidgeRegressionDataGenerator")
+ val sc = new SparkContext(sparkMaster, "LinearDataGenerator")
val data = generateLinearRDD(sc, nexamples, nfeatures, eps, nparts = parts)
MLUtils.saveLabeledData(data, outputPath)
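
After this change the generator draws features uniformly in [-1, 1] and takes an explicit noise level: eps multiplies a standard Gaussian added to the response. A usage sketch matching the calls in the test suites below (the values are the suites' own):

    // intercept = 3.0, true weights (10, 10), 100 points, seed 42, noise sd 0.1
    val points = LinearDataGenerator.generateLinearInput(3.0, Array(10.0, 10.0), 100, 42, 0.1)
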
diff --git a/mllib/src/main/scala/spark/mllib/util/MLUtils.scala b/mllib/src/main/scala/spark/mllib/util/MLUtils.scala
index 4e030a81b4..a8e6ae9953 100644
--- a/mllib/src/main/scala/spark/mllib/util/MLUtils.scala
+++ b/mllib/src/main/scala/spark/mllib/util/MLUtils.scala
@@ -72,16 +72,16 @@ object MLUtils {
* xColMean - Row vector with mean for every column (or feature) of the input data
 * xColSd - Row vector with standard deviation for every column (or feature) of the input data.
*/
- def computeStats(data: RDD[(Double, Array[Double])], nfeatures: Int, nexamples: Long):
+ def computeStats(data: RDD[LabeledPoint], nfeatures: Int, nexamples: Long):
(Double, DoubleMatrix, DoubleMatrix) = {
- val yMean: Double = data.map { case (y, features) => y }.reduce(_ + _) / nexamples
+ val yMean: Double = data.map { labeledPoint => labeledPoint.label }.reduce(_ + _) / nexamples
// NOTE: We shuffle X by column here to compute column sum and sum of squares.
- val xColSumSq: RDD[(Int, (Double, Double))] = data.flatMap { case(y, features) =>
- val nCols = features.length
+ val xColSumSq: RDD[(Int, (Double, Double))] = data.flatMap { labeledPoint =>
+ val nCols = labeledPoint.features.length
// Traverse over every column and emit (col, value, value^2)
Iterator.tabulate(nCols) { i =>
- (i, (features(i), features(i)*features(i)))
+ (i, (labeledPoint.features(i), labeledPoint.features(i)*labeledPoint.features(i)))
}
}.reduceByKey { case(x1, x2) =>
(x1._1 + x2._1, x1._2 + x2._2)
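
computeStats now takes RDD[LabeledPoint] and, per the hunk above, shuffles per-column (sum, sum-of-squares) pairs. The tail of the function falls outside the hunk; presumably it recovers each column's mean and standard deviation from those pairs with the usual one-pass identity, roughly:

    // Sketch (assumption; this part of computeStats is not shown in the diff).
    def meanAndSd(sum: Double, sumSq: Double, n: Long): (Double, Double) = {
      val mean = sum / n
      // Population form; an n/(n-1) correction is also plausible here.
      val sd = math.sqrt(sumSq / n - mean * mean)
      (mean, sd)
    }
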
diff --git a/mllib/src/test/java/spark/mllib/regression/JavaLassoSuite.java b/mllib/src/test/java/spark/mllib/regression/JavaLassoSuite.java
index 428902e85c..5863140baf 100644
--- a/mllib/src/test/java/spark/mllib/regression/JavaLassoSuite.java
+++ b/mllib/src/test/java/spark/mllib/regression/JavaLassoSuite.java
@@ -63,9 +63,9 @@ public class JavaLassoSuite implements Serializable {
double[] weights = {-1.5, 1.0e-2};
JavaRDD<LabeledPoint> testRDD = sc.parallelize(LinearDataGenerator.generateLinearInputAsList(A,
- weights, nPoints, 42), 2).cache();
+ weights, nPoints, 42, 0.1), 2).cache();
List<LabeledPoint> validationData =
- LinearDataGenerator.generateLinearInputAsList(A, weights, nPoints, 17);
+ LinearDataGenerator.generateLinearInputAsList(A, weights, nPoints, 17, 0.1);
LassoWithSGD lassoSGDImpl = new LassoWithSGD();
lassoSGDImpl.optimizer().setStepSize(1.0)
@@ -84,9 +84,9 @@ public class JavaLassoSuite implements Serializable {
double[] weights = {-1.5, 1.0e-2};
JavaRDD<LabeledPoint> testRDD = sc.parallelize(LinearDataGenerator.generateLinearInputAsList(A,
- weights, nPoints, 42), 2).cache();
+ weights, nPoints, 42, 0.1), 2).cache();
List<LabeledPoint> validationData =
- LinearDataGenerator.generateLinearInputAsList(A, weights, nPoints, 17);
+ LinearDataGenerator.generateLinearInputAsList(A, weights, nPoints, 17, 0.1);
LassoModel model = LassoWithSGD.train(testRDD.rdd(), 100, 1.0, 0.01, 1.0);
diff --git a/mllib/src/test/java/spark/mllib/regression/JavaLinearRegressionSuite.java b/mllib/src/test/java/spark/mllib/regression/JavaLinearRegressionSuite.java
index 9642e89844..50716c7861 100644
--- a/mllib/src/test/java/spark/mllib/regression/JavaLinearRegressionSuite.java
+++ b/mllib/src/test/java/spark/mllib/regression/JavaLinearRegressionSuite.java
@@ -30,68 +30,65 @@ import spark.api.java.JavaSparkContext;
import spark.mllib.util.LinearDataGenerator;
public class JavaLinearRegressionSuite implements Serializable {
- private transient JavaSparkContext sc;
-
- @Before
- public void setUp() {
- sc = new JavaSparkContext("local", "JavaLinearRegressionSuite");
- }
-
- @After
- public void tearDown() {
- sc.stop();
- sc = null;
- System.clearProperty("spark.driver.port");
- }
-
- int validatePrediction(List<LabeledPoint> validationData, LinearRegressionModel model) {
- int numAccurate = 0;
- for (LabeledPoint point: validationData) {
- Double prediction = model.predict(point.features());
- // A prediction is off if the prediction is more than 0.5 away from expected value.
- if (Math.abs(prediction - point.label()) <= 0.5) {
- numAccurate++;
- }
+ private transient JavaSparkContext sc;
+
+ @Before
+ public void setUp() {
+ sc = new JavaSparkContext("local", "JavaLinearRegressionSuite");
+ }
+
+ @After
+ public void tearDown() {
+ sc.stop();
+ sc = null;
+ System.clearProperty("spark.driver.port");
+ }
+
+ int validatePrediction(List<LabeledPoint> validationData, LinearRegressionModel model) {
+ int numAccurate = 0;
+ for (LabeledPoint point: validationData) {
+ Double prediction = model.predict(point.features());
+ // A prediction is off if the prediction is more than 0.5 away from expected value.
+ if (Math.abs(prediction - point.label()) <= 0.5) {
+ numAccurate++;
}
- return numAccurate;
- }
-
- @Test
- public void runLinearRegressionUsingConstructor() {
- int nPoints = 10000;
- double A = 2.0;
- double[] weights = {-1.5, 1.0e-2};
-
- JavaRDD<LabeledPoint> testRDD = sc.parallelize(LinearDataGenerator.generateLinearInputAsList(A,
- weights, nPoints, 42), 2).cache();
- List<LabeledPoint> validationData =
- LinearDataGenerator.generateLinearInputAsList(A, weights, nPoints, 17);
-
- LinearRegressionWithSGD linSGDImpl = new LinearRegressionWithSGD();
- linSGDImpl.optimizer().setStepSize(1.0)
- .setRegParam(0.01)
- .setNumIterations(20);
- LinearRegressionModel model = linSGDImpl.run(testRDD.rdd());
-
- int numAccurate = validatePrediction(validationData, model);
- Assert.assertTrue(numAccurate > nPoints * 4.0 / 5.0);
- }
-
- @Test
- public void runLinearRegressionUsingStaticMethods() {
- int nPoints = 10000;
- double A = 2.0;
- double[] weights = {-1.5, 1.0e-2};
-
- JavaRDD<LabeledPoint> testRDD = sc.parallelize(LinearDataGenerator.generateLinearInputAsList(A,
- weights, nPoints, 42), 2).cache();
- List<LabeledPoint> validationData =
- LinearDataGenerator.generateLinearInputAsList(A, weights, nPoints, 17);
-
- LinearRegressionModel model = LinearRegressionWithSGD.train(testRDD.rdd(), 100, 1.0, 1.0);
-
- int numAccurate = validatePrediction(validationData, model);
- Assert.assertTrue(numAccurate > nPoints * 4.0 / 5.0);
}
+ return numAccurate;
+ }
+
+ @Test
+ public void runLinearRegressionUsingConstructor() {
+ int nPoints = 100;
+ double A = 3.0;
+ double[] weights = {10, 10};
+
+ JavaRDD<LabeledPoint> testRDD = sc.parallelize(
+ LinearDataGenerator.generateLinearInputAsList(A, weights, nPoints, 42, 0.1), 2).cache();
+ List<LabeledPoint> validationData =
+ LinearDataGenerator.generateLinearInputAsList(A, weights, nPoints, 17, 0.1);
+
+ LinearRegressionWithSGD linSGDImpl = new LinearRegressionWithSGD();
+ LinearRegressionModel model = linSGDImpl.run(testRDD.rdd());
+
+ int numAccurate = validatePrediction(validationData, model);
+ Assert.assertTrue(numAccurate > nPoints * 4.0 / 5.0);
+ }
+
+ @Test
+ public void runLinearRegressionUsingStaticMethods() {
+ int nPoints = 100;
+ double A = 3.0;
+ double[] weights = {10, 10};
+
+ JavaRDD<LabeledPoint> testRDD = sc.parallelize(
+ LinearDataGenerator.generateLinearInputAsList(A, weights, nPoints, 42, 0.1), 2).cache();
+ List<LabeledPoint> validationData =
+ LinearDataGenerator.generateLinearInputAsList(A, weights, nPoints, 17, 0.1);
+
+ LinearRegressionModel model = LinearRegressionWithSGD.train(testRDD.rdd(), 100);
+
+ int numAccurate = validatePrediction(validationData, model);
+ Assert.assertTrue(numAccurate > nPoints * 4.0 / 5.0);
+ }
}
diff --git a/mllib/src/test/java/spark/mllib/regression/JavaRidgeRegressionSuite.java b/mllib/src/test/java/spark/mllib/regression/JavaRidgeRegressionSuite.java
index 5df6d8076d..2c0aabad30 100644
--- a/mllib/src/test/java/spark/mllib/regression/JavaRidgeRegressionSuite.java
+++ b/mllib/src/test/java/spark/mllib/regression/JavaRidgeRegressionSuite.java
@@ -25,73 +25,86 @@ import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
+import org.jblas.DoubleMatrix;
+
import spark.api.java.JavaRDD;
import spark.api.java.JavaSparkContext;
import spark.mllib.util.LinearDataGenerator;
public class JavaRidgeRegressionSuite implements Serializable {
- private transient JavaSparkContext sc;
-
- @Before
- public void setUp() {
- sc = new JavaSparkContext("local", "JavaRidgeRegressionSuite");
- }
-
- @After
- public void tearDown() {
- sc.stop();
- sc = null;
- System.clearProperty("spark.driver.port");
- }
-
- int validatePrediction(List<LabeledPoint> validationData, RidgeRegressionModel model) {
- int numAccurate = 0;
- for (LabeledPoint point: validationData) {
- Double prediction = model.predict(point.features());
- // A prediction is off if the prediction is more than 0.5 away from expected value.
- if (Math.abs(prediction - point.label()) <= 0.5) {
- numAccurate++;
- }
- }
- return numAccurate;
- }
-
- @Test
- public void runRidgeRegressionUsingConstructor() {
- int nPoints = 10000;
- double A = 2.0;
- double[] weights = {-1.5, 1.0e-2};
-
- JavaRDD<LabeledPoint> testRDD = sc.parallelize(LinearDataGenerator.generateLinearInputAsList(A,
- weights, nPoints, 42), 2).cache();
- List<LabeledPoint> validationData =
- LinearDataGenerator.generateLinearInputAsList(A, weights, nPoints, 17);
-
- RidgeRegressionWithSGD ridgeSGDImpl = new RidgeRegressionWithSGD();
- ridgeSGDImpl.optimizer().setStepSize(1.0)
- .setRegParam(0.01)
- .setNumIterations(20);
- RidgeRegressionModel model = ridgeSGDImpl.run(testRDD.rdd());
-
- int numAccurate = validatePrediction(validationData, model);
- Assert.assertTrue(numAccurate > nPoints * 4.0 / 5.0);
+ private transient JavaSparkContext sc;
+
+ @Before
+ public void setUp() {
+ sc = new JavaSparkContext("local", "JavaRidgeRegressionSuite");
+ }
+
+ @After
+ public void tearDown() {
+ sc.stop();
+ sc = null;
+ System.clearProperty("spark.driver.port");
+ }
+
+ double predictionError(List<LabeledPoint> validationData, RidgeRegressionModel model) {
+ double errorSum = 0;
+ for (LabeledPoint point: validationData) {
+ Double prediction = model.predict(point.features());
+ errorSum += (prediction - point.label()) * (prediction - point.label());
}
-
- @Test
- public void runRidgeRegressionUsingStaticMethods() {
- int nPoints = 10000;
- double A = 2.0;
- double[] weights = {-1.5, 1.0e-2};
-
- JavaRDD<LabeledPoint> testRDD = sc.parallelize(LinearDataGenerator.generateLinearInputAsList(A,
- weights, nPoints, 42), 2).cache();
- List<LabeledPoint> validationData =
- LinearDataGenerator.generateLinearInputAsList(A, weights, nPoints, 17);
-
- RidgeRegressionModel model = RidgeRegressionWithSGD.train(testRDD.rdd(), 100, 1.0, 0.01, 1.0);
-
- int numAccurate = validatePrediction(validationData, model);
- Assert.assertTrue(numAccurate > nPoints * 4.0 / 5.0);
- }
-
+ return errorSum / validationData.size();
+ }
+
+ List<LabeledPoint> generateRidgeData(int numPoints, int nfeatures, double eps) {
+ org.jblas.util.Random.seed(42);
+ // Pick weights as random values distributed uniformly in [-0.5, 0.5]
+ DoubleMatrix w = DoubleMatrix.rand(nfeatures, 1).subi(0.5);
+ // Set first two weights to eps
+ w.put(0, 0, eps);
+ w.put(1, 0, eps);
+ return LinearDataGenerator.generateLinearInputAsList(0.0, w.data, numPoints, 42, eps);
+ }
+
+ @Test
+ public void runRidgeRegressionUsingConstructor() {
+ int nexamples = 200;
+ int nfeatures = 20;
+ double eps = 10.0;
+ List<LabeledPoint> data = generateRidgeData(2*nexamples, nfeatures, eps);
+
+ JavaRDD<LabeledPoint> testRDD = sc.parallelize(data.subList(0, nexamples));
+ List<LabeledPoint> validationData = data.subList(nexamples, 2*nexamples);
+
+ RidgeRegressionWithSGD ridgeSGDImpl = new RidgeRegressionWithSGD();
+ ridgeSGDImpl.optimizer().setStepSize(1.0)
+ .setRegParam(0.0)
+ .setNumIterations(200);
+ RidgeRegressionModel model = ridgeSGDImpl.run(testRDD.rdd());
+ double unRegularizedErr = predictionError(validationData, model);
+
+ ridgeSGDImpl.optimizer().setRegParam(0.1);
+ model = ridgeSGDImpl.run(testRDD.rdd());
+ double regularizedErr = predictionError(validationData, model);
+
+ Assert.assertTrue(regularizedErr < unRegularizedErr);
+ }
+
+ @Test
+ public void runRidgeRegressionUsingStaticMethods() {
+ int nexamples = 200;
+ int nfeatures = 20;
+ double eps = 10.0;
+ List<LabeledPoint> data = generateRidgeData(2*nexamples, nfeatures, eps);
+
+ JavaRDD<LabeledPoint> testRDD = sc.parallelize(data.subList(0, nexamples));
+ List<LabeledPoint> validationData = data.subList(nexamples, 2*nexamples);
+
+ RidgeRegressionModel model = RidgeRegressionWithSGD.train(testRDD.rdd(), 200, 1.0, 0.0);
+ double unRegularizedErr = predictionError(validationData, model);
+
+ model = RidgeRegressionWithSGD.train(testRDD.rdd(), 200, 1.0, 0.1);
+ double regularizedErr = predictionError(validationData, model);
+
+ Assert.assertTrue(regularizedErr < unRegularizedErr);
+ }
}
diff --git a/mllib/src/test/scala/spark/mllib/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/spark/mllib/regression/LinearRegressionSuite.scala
index 3d22b7d385..acc48a3283 100644
--- a/mllib/src/test/scala/spark/mllib/regression/LinearRegressionSuite.scala
+++ b/mllib/src/test/scala/spark/mllib/regression/LinearRegressionSuite.scala
@@ -36,10 +36,19 @@ class LinearRegressionSuite extends FunSuite with BeforeAndAfterAll {
System.clearProperty("spark.driver.port")
}
- // Test if we can correctly learn Y = 3 + 10*X1 + 10*X2 when
- // X1 and X2 are collinear.
- test("multi-collinear variables") {
- val testRDD = LinearDataGenerator.generateLinearRDD(sc, 100, 2, 0.0, Array(10.0, 10.0), intercept=3.0).cache()
+ def validatePrediction(predictions: Seq[Double], input: Seq[LabeledPoint]) {
+ val numOffPredictions = predictions.zip(input).filter { case (prediction, expected) =>
+ // A prediction is off if the prediction is more than 0.5 away from expected value.
+ math.abs(prediction - expected.label) > 0.5
+ }.size
+ // At least 80% of the predictions should be on.
+ assert(numOffPredictions < input.length / 5)
+ }
+
+ // Test if we can correctly learn Y = 3 + 10*X1 + 10*X2
+ test("linear regression") {
+ val testRDD = sc.parallelize(LinearDataGenerator.generateLinearInput(
+ 3.0, Array(10.0, 10.0), 100, 42), 2).cache()
val linReg = new LinearRegressionWithSGD()
linReg.optimizer.setNumIterations(1000).setStepSize(1.0)
@@ -49,5 +58,15 @@ class LinearRegressionSuite extends FunSuite with BeforeAndAfterAll {
assert(model.weights.length === 2)
assert(model.weights(0) >= 9.0 && model.weights(0) <= 11.0)
assert(model.weights(1) >= 9.0 && model.weights(1) <= 11.0)
+
+ val validationData = LinearDataGenerator.generateLinearInput(
+ 3.0, Array(10.0, 10.0), 100, 17)
+ val validationRDD = sc.parallelize(validationData, 2).cache()
+
+ // Test prediction on RDD.
+ validatePrediction(model.predict(validationRDD.map(_.features)).collect(), validationData)
+
+ // Test prediction on Array.
+ validatePrediction(validationData.map(row => model.predict(row.features)), validationData)
}
}
diff --git a/mllib/src/test/scala/spark/mllib/regression/RidgeRegressionSuite.scala b/mllib/src/test/scala/spark/mllib/regression/RidgeRegressionSuite.scala
index 0237ccdf87..c482035706 100644
--- a/mllib/src/test/scala/spark/mllib/regression/RidgeRegressionSuite.scala
+++ b/mllib/src/test/scala/spark/mllib/regression/RidgeRegressionSuite.scala
@@ -20,6 +20,7 @@ package spark.mllib.regression
import scala.collection.JavaConversions._
import scala.util.Random
+import org.jblas.DoubleMatrix
import org.scalatest.BeforeAndAfterAll
import org.scalatest.FunSuite
@@ -27,7 +28,6 @@ import spark.SparkContext
import spark.SparkContext._
import spark.mllib.util.LinearDataGenerator
-
class RidgeRegressionSuite extends FunSuite with BeforeAndAfterAll {
@transient private var sc: SparkContext = _
@@ -40,31 +40,51 @@ class RidgeRegressionSuite extends FunSuite with BeforeAndAfterAll {
System.clearProperty("spark.driver.port")
}
- // Test if we can correctly learn Y = 3 + 10*X1 + 10*X2 when
- // X1 and X2 are collinear.
- test("multi-collinear variables") {
- val testRDD = LinearDataGenerator.generateLinearRDD(sc, 100, 2, 0.0, Array(10.0, 10.0), intercept=3.0).cache()
- val ridgeReg = new RidgeRegressionWithSGD()
- ridgeReg.optimizer.setNumIterations(1000).setRegParam(0.0).setStepSize(1.0)
+ def predictionError(predictions: Seq[Double], input: Seq[LabeledPoint]) = {
+ predictions.zip(input).map { case (prediction, expected) =>
+ (prediction - expected.label) * (prediction - expected.label)
+ }.reduceLeft(_ + _) / predictions.size
+ }
- val model = ridgeReg.run(testRDD)
+ test("regularization with skewed weights") {
+ val nexamples = 200
+ val nfeatures = 20
+ val eps = 10
- assert(model.intercept >= 2.5 && model.intercept <= 3.5)
- assert(model.weights.length === 2)
- assert(model.weights(0) >= 9.0 && model.weights(0) <= 11.0)
- assert(model.weights(1) >= 9.0 && model.weights(1) <= 11.0)
- }
+ org.jblas.util.Random.seed(42)
+ // Pick weights as random values distributed uniformly in [-0.5, 0.5]
+ val w = DoubleMatrix.rand(nfeatures, 1).subi(0.5)
+ // Set first two weights to eps
+ w.put(0, 0, eps)
+ w.put(1, 0, eps)
- test("multi-collinear variables with regularization") {
- val testRDD = LinearDataGenerator.generateLinearRDD(sc, 100, 2, 0.0, Array(10.0, 10.0), intercept=3.0).cache()
- val ridgeReg = new RidgeRegressionWithSGD()
- ridgeReg.optimizer.setNumIterations(1000).setRegParam(1.0).setStepSize(1.0)
+ // Use half of data for training and other half for validation
+ val data = LinearDataGenerator.generateLinearInput(3.0, w.toArray, 2*nexamples, 42, eps)
+ val testData = data.take(nexamples)
+ val validationData = data.takeRight(nexamples)
- val model = ridgeReg.run(testRDD)
+ val testRDD = sc.parallelize(testData, 2).cache()
+ val validationRDD = sc.parallelize(validationData, 2).cache()
+
+ // First run without regularization.
+ val linearReg = new LinearRegressionWithSGD()
+ linearReg.optimizer.setNumIterations(200)
+ .setStepSize(1.0)
+
+ val linearModel = linearReg.run(testRDD)
+ val linearErr = predictionError(
+ linearModel.predict(validationRDD.map(_.features)).collect(), validationData)
+
+ val ridgeReg = new RidgeRegressionWithSGD()
+ ridgeReg.optimizer.setNumIterations(200)
+ .setRegParam(0.1)
+ .setStepSize(1.0)
+ val ridgeModel = ridgeReg.run(testRDD)
+ val ridgeErr = predictionError(
+ ridgeModel.predict(validationRDD.map(_.features)).collect(), validationData)
- assert(model.intercept <= 5.0)
- assert(model.weights.length === 2)
- assert(model.weights(0) <= 4.0)
- assert(model.weights(1) <= 3.0)
+ // Ridge CV-error should be lower than linear regression
+ assert(ridgeErr < linearErr,
+ "ridgeError (" + ridgeErr + ") was not less than linearError(" + linearErr + ")")
}
}