aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/test/java
diff options
context:
space:
mode:
authorShivaram Venkataraman <shivaram@eecs.berkeley.edu>2013-08-25 22:24:27 -0700
committerShivaram Venkataraman <shivaram@eecs.berkeley.edu>2013-08-25 22:24:27 -0700
commitb8c50a0642cf74c25fd70cc1e7d1be95ddafc5d8 (patch)
tree8f2d5abb7cd565d81e56d3ceca8134eaae41abe4 /mllib/src/test/java
parent07fe910669b2ec15b6b5c1e5186df5036d05b9b1 (diff)
downloadspark-b8c50a0642cf74c25fd70cc1e7d1be95ddafc5d8.tar.gz
spark-b8c50a0642cf74c25fd70cc1e7d1be95ddafc5d8.tar.bz2
spark-b8c50a0642cf74c25fd70cc1e7d1be95ddafc5d8.zip
Center & scale variables in Ridge, Lasso.
Also add a unit test that checks if ridge regression lowers cross-validation error.
Diffstat (limited to 'mllib/src/test/java')
-rw-r--r--mllib/src/test/java/spark/mllib/regression/JavaLassoSuite.java8
-rw-r--r--mllib/src/test/java/spark/mllib/regression/JavaLinearRegressionSuite.java119
-rw-r--r--mllib/src/test/java/spark/mllib/regression/JavaRidgeRegressionSuite.java139
3 files changed, 138 insertions, 128 deletions
diff --git a/mllib/src/test/java/spark/mllib/regression/JavaLassoSuite.java b/mllib/src/test/java/spark/mllib/regression/JavaLassoSuite.java
index 428902e85c..5863140baf 100644
--- a/mllib/src/test/java/spark/mllib/regression/JavaLassoSuite.java
+++ b/mllib/src/test/java/spark/mllib/regression/JavaLassoSuite.java
@@ -63,9 +63,9 @@ public class JavaLassoSuite implements Serializable {
double[] weights = {-1.5, 1.0e-2};
JavaRDD<LabeledPoint> testRDD = sc.parallelize(LinearDataGenerator.generateLinearInputAsList(A,
- weights, nPoints, 42), 2).cache();
+ weights, nPoints, 42, 0.1), 2).cache();
List<LabeledPoint> validationData =
- LinearDataGenerator.generateLinearInputAsList(A, weights, nPoints, 17);
+ LinearDataGenerator.generateLinearInputAsList(A, weights, nPoints, 17, 0.1);
LassoWithSGD lassoSGDImpl = new LassoWithSGD();
lassoSGDImpl.optimizer().setStepSize(1.0)
@@ -84,9 +84,9 @@ public class JavaLassoSuite implements Serializable {
double[] weights = {-1.5, 1.0e-2};
JavaRDD<LabeledPoint> testRDD = sc.parallelize(LinearDataGenerator.generateLinearInputAsList(A,
- weights, nPoints, 42), 2).cache();
+ weights, nPoints, 42, 0.1), 2).cache();
List<LabeledPoint> validationData =
- LinearDataGenerator.generateLinearInputAsList(A, weights, nPoints, 17);
+ LinearDataGenerator.generateLinearInputAsList(A, weights, nPoints, 17, 0.1);
LassoModel model = LassoWithSGD.train(testRDD.rdd(), 100, 1.0, 0.01, 1.0);
diff --git a/mllib/src/test/java/spark/mllib/regression/JavaLinearRegressionSuite.java b/mllib/src/test/java/spark/mllib/regression/JavaLinearRegressionSuite.java
index 9642e89844..50716c7861 100644
--- a/mllib/src/test/java/spark/mllib/regression/JavaLinearRegressionSuite.java
+++ b/mllib/src/test/java/spark/mllib/regression/JavaLinearRegressionSuite.java
@@ -30,68 +30,65 @@ import spark.api.java.JavaSparkContext;
import spark.mllib.util.LinearDataGenerator;
public class JavaLinearRegressionSuite implements Serializable {
- private transient JavaSparkContext sc;
-
- @Before
- public void setUp() {
- sc = new JavaSparkContext("local", "JavaLinearRegressionSuite");
- }
-
- @After
- public void tearDown() {
- sc.stop();
- sc = null;
- System.clearProperty("spark.driver.port");
- }
-
- int validatePrediction(List<LabeledPoint> validationData, LinearRegressionModel model) {
- int numAccurate = 0;
- for (LabeledPoint point: validationData) {
- Double prediction = model.predict(point.features());
- // A prediction is off if the prediction is more than 0.5 away from expected value.
- if (Math.abs(prediction - point.label()) <= 0.5) {
- numAccurate++;
- }
+ private transient JavaSparkContext sc;
+
+ @Before
+ public void setUp() {
+ sc = new JavaSparkContext("local", "JavaLinearRegressionSuite");
+ }
+
+ @After
+ public void tearDown() {
+ sc.stop();
+ sc = null;
+ System.clearProperty("spark.driver.port");
+ }
+
+ int validatePrediction(List<LabeledPoint> validationData, LinearRegressionModel model) {
+ int numAccurate = 0;
+ for (LabeledPoint point: validationData) {
+ Double prediction = model.predict(point.features());
+ // A prediction is off if the prediction is more than 0.5 away from expected value.
+ if (Math.abs(prediction - point.label()) <= 0.5) {
+ numAccurate++;
}
- return numAccurate;
- }
-
- @Test
- public void runLinearRegressionUsingConstructor() {
- int nPoints = 10000;
- double A = 2.0;
- double[] weights = {-1.5, 1.0e-2};
-
- JavaRDD<LabeledPoint> testRDD = sc.parallelize(LinearDataGenerator.generateLinearInputAsList(A,
- weights, nPoints, 42), 2).cache();
- List<LabeledPoint> validationData =
- LinearDataGenerator.generateLinearInputAsList(A, weights, nPoints, 17);
-
- LinearRegressionWithSGD linSGDImpl = new LinearRegressionWithSGD();
- linSGDImpl.optimizer().setStepSize(1.0)
- .setRegParam(0.01)
- .setNumIterations(20);
- LinearRegressionModel model = linSGDImpl.run(testRDD.rdd());
-
- int numAccurate = validatePrediction(validationData, model);
- Assert.assertTrue(numAccurate > nPoints * 4.0 / 5.0);
- }
-
- @Test
- public void runLinearRegressionUsingStaticMethods() {
- int nPoints = 10000;
- double A = 2.0;
- double[] weights = {-1.5, 1.0e-2};
-
- JavaRDD<LabeledPoint> testRDD = sc.parallelize(LinearDataGenerator.generateLinearInputAsList(A,
- weights, nPoints, 42), 2).cache();
- List<LabeledPoint> validationData =
- LinearDataGenerator.generateLinearInputAsList(A, weights, nPoints, 17);
-
- LinearRegressionModel model = LinearRegressionWithSGD.train(testRDD.rdd(), 100, 1.0, 1.0);
-
- int numAccurate = validatePrediction(validationData, model);
- Assert.assertTrue(numAccurate > nPoints * 4.0 / 5.0);
}
+ return numAccurate;
+ }
+
+ @Test
+ public void runLinearRegressionUsingConstructor() {
+ int nPoints = 100;
+ double A = 3.0;
+ double[] weights = {10, 10};
+
+ JavaRDD<LabeledPoint> testRDD = sc.parallelize(
+ LinearDataGenerator.generateLinearInputAsList(A, weights, nPoints, 42, 0.1), 2).cache();
+ List<LabeledPoint> validationData =
+ LinearDataGenerator.generateLinearInputAsList(A, weights, nPoints, 17, 0.1);
+
+ LinearRegressionWithSGD linSGDImpl = new LinearRegressionWithSGD();
+ LinearRegressionModel model = linSGDImpl.run(testRDD.rdd());
+
+ int numAccurate = validatePrediction(validationData, model);
+ Assert.assertTrue(numAccurate > nPoints * 4.0 / 5.0);
+ }
+
+ @Test
+ public void runLinearRegressionUsingStaticMethods() {
+ int nPoints = 100;
+ double A = 3.0;
+ double[] weights = {10, 10};
+
+ JavaRDD<LabeledPoint> testRDD = sc.parallelize(
+ LinearDataGenerator.generateLinearInputAsList(A, weights, nPoints, 42, 0.1), 2).cache();
+ List<LabeledPoint> validationData =
+ LinearDataGenerator.generateLinearInputAsList(A, weights, nPoints, 17, 0.1);
+
+ LinearRegressionModel model = LinearRegressionWithSGD.train(testRDD.rdd(), 100);
+
+ int numAccurate = validatePrediction(validationData, model);
+ Assert.assertTrue(numAccurate > nPoints * 4.0 / 5.0);
+ }
}
diff --git a/mllib/src/test/java/spark/mllib/regression/JavaRidgeRegressionSuite.java b/mllib/src/test/java/spark/mllib/regression/JavaRidgeRegressionSuite.java
index 5df6d8076d..2c0aabad30 100644
--- a/mllib/src/test/java/spark/mllib/regression/JavaRidgeRegressionSuite.java
+++ b/mllib/src/test/java/spark/mllib/regression/JavaRidgeRegressionSuite.java
@@ -25,73 +25,86 @@ import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
+import org.jblas.DoubleMatrix;
+
import spark.api.java.JavaRDD;
import spark.api.java.JavaSparkContext;
import spark.mllib.util.LinearDataGenerator;
public class JavaRidgeRegressionSuite implements Serializable {
- private transient JavaSparkContext sc;
-
- @Before
- public void setUp() {
- sc = new JavaSparkContext("local", "JavaRidgeRegressionSuite");
- }
-
- @After
- public void tearDown() {
- sc.stop();
- sc = null;
- System.clearProperty("spark.driver.port");
- }
-
- int validatePrediction(List<LabeledPoint> validationData, RidgeRegressionModel model) {
- int numAccurate = 0;
- for (LabeledPoint point: validationData) {
- Double prediction = model.predict(point.features());
- // A prediction is off if the prediction is more than 0.5 away from expected value.
- if (Math.abs(prediction - point.label()) <= 0.5) {
- numAccurate++;
- }
- }
- return numAccurate;
- }
-
- @Test
- public void runRidgeRegressionUsingConstructor() {
- int nPoints = 10000;
- double A = 2.0;
- double[] weights = {-1.5, 1.0e-2};
-
- JavaRDD<LabeledPoint> testRDD = sc.parallelize(LinearDataGenerator.generateLinearInputAsList(A,
- weights, nPoints, 42), 2).cache();
- List<LabeledPoint> validationData =
- LinearDataGenerator.generateLinearInputAsList(A, weights, nPoints, 17);
-
- RidgeRegressionWithSGD ridgeSGDImpl = new RidgeRegressionWithSGD();
- ridgeSGDImpl.optimizer().setStepSize(1.0)
- .setRegParam(0.01)
- .setNumIterations(20);
- RidgeRegressionModel model = ridgeSGDImpl.run(testRDD.rdd());
-
- int numAccurate = validatePrediction(validationData, model);
- Assert.assertTrue(numAccurate > nPoints * 4.0 / 5.0);
+ private transient JavaSparkContext sc;
+
+ @Before
+ public void setUp() {
+ sc = new JavaSparkContext("local", "JavaRidgeRegressionSuite");
+ }
+
+ @After
+ public void tearDown() {
+ sc.stop();
+ sc = null;
+ System.clearProperty("spark.driver.port");
+ }
+
+ double predictionError(List<LabeledPoint> validationData, RidgeRegressionModel model) {
+ double errorSum = 0;
+ for (LabeledPoint point: validationData) {
+ Double prediction = model.predict(point.features());
+ errorSum += (prediction - point.label()) * (prediction - point.label());
}
-
- @Test
- public void runRidgeRegressionUsingStaticMethods() {
- int nPoints = 10000;
- double A = 2.0;
- double[] weights = {-1.5, 1.0e-2};
-
- JavaRDD<LabeledPoint> testRDD = sc.parallelize(LinearDataGenerator.generateLinearInputAsList(A,
- weights, nPoints, 42), 2).cache();
- List<LabeledPoint> validationData =
- LinearDataGenerator.generateLinearInputAsList(A, weights, nPoints, 17);
-
- RidgeRegressionModel model = RidgeRegressionWithSGD.train(testRDD.rdd(), 100, 1.0, 0.01, 1.0);
-
- int numAccurate = validatePrediction(validationData, model);
- Assert.assertTrue(numAccurate > nPoints * 4.0 / 5.0);
- }
-
+ return errorSum / validationData.size();
+ }
+
+ List<LabeledPoint> generateRidgeData(int numPoints, int nfeatures, double eps) {
+ org.jblas.util.Random.seed(42);
+ // Pick weights as random values distributed uniformly in [-0.5, 0.5]
+ DoubleMatrix w = DoubleMatrix.rand(nfeatures, 1).subi(0.5);
+ // Set first two weights to eps
+ w.put(0, 0, eps);
+ w.put(1, 0, eps);
+ return LinearDataGenerator.generateLinearInputAsList(0.0, w.data, numPoints, 42, eps);
+ }
+
+ @Test
+ public void runRidgeRegressionUsingConstructor() {
+ int nexamples = 200;
+ int nfeatures = 20;
+ double eps = 10.0;
+ List<LabeledPoint> data = generateRidgeData(2*nexamples, nfeatures, eps);
+
+ JavaRDD<LabeledPoint> testRDD = sc.parallelize(data.subList(0, nexamples));
+ List<LabeledPoint> validationData = data.subList(nexamples, 2*nexamples);
+
+ RidgeRegressionWithSGD ridgeSGDImpl = new RidgeRegressionWithSGD();
+ ridgeSGDImpl.optimizer().setStepSize(1.0)
+ .setRegParam(0.0)
+ .setNumIterations(200);
+ RidgeRegressionModel model = ridgeSGDImpl.run(testRDD.rdd());
+ double unRegularizedErr = predictionError(validationData, model);
+
+ ridgeSGDImpl.optimizer().setRegParam(0.1);
+ model = ridgeSGDImpl.run(testRDD.rdd());
+ double regularizedErr = predictionError(validationData, model);
+
+ Assert.assertTrue(regularizedErr < unRegularizedErr);
+ }
+
+ @Test
+ public void runRidgeRegressionUsingStaticMethods() {
+ int nexamples = 200;
+ int nfeatures = 20;
+ double eps = 10.0;
+ List<LabeledPoint> data = generateRidgeData(2*nexamples, nfeatures, eps);
+
+ JavaRDD<LabeledPoint> testRDD = sc.parallelize(data.subList(0, nexamples));
+ List<LabeledPoint> validationData = data.subList(nexamples, 2*nexamples);
+
+ RidgeRegressionModel model = RidgeRegressionWithSGD.train(testRDD.rdd(), 200, 1.0, 0.0);
+ double unRegularizedErr = predictionError(validationData, model);
+
+ model = RidgeRegressionWithSGD.train(testRDD.rdd(), 200, 1.0, 0.1);
+ double regularizedErr = predictionError(validationData, model);
+
+ Assert.assertTrue(regularizedErr < unRegularizedErr);
+ }
}