diff options
author | Shivaram Venkataraman <shivaram@eecs.berkeley.edu> | 2013-08-25 22:24:27 -0700 |
---|---|---|
committer | Shivaram Venkataraman <shivaram@eecs.berkeley.edu> | 2013-08-25 22:24:27 -0700 |
commit | b8c50a0642cf74c25fd70cc1e7d1be95ddafc5d8 (patch) | |
tree | 8f2d5abb7cd565d81e56d3ceca8134eaae41abe4 /mllib/src/test/java | |
parent | 07fe910669b2ec15b6b5c1e5186df5036d05b9b1 (diff) | |
download | spark-b8c50a0642cf74c25fd70cc1e7d1be95ddafc5d8.tar.gz spark-b8c50a0642cf74c25fd70cc1e7d1be95ddafc5d8.tar.bz2 spark-b8c50a0642cf74c25fd70cc1e7d1be95ddafc5d8.zip |
Center & scale variables in Ridge, Lasso.
Also add a unit test that checks if ridge regression lowers
cross-validation error.
Diffstat (limited to 'mllib/src/test/java')
3 files changed, 138 insertions, 128 deletions
diff --git a/mllib/src/test/java/spark/mllib/regression/JavaLassoSuite.java b/mllib/src/test/java/spark/mllib/regression/JavaLassoSuite.java index 428902e85c..5863140baf 100644 --- a/mllib/src/test/java/spark/mllib/regression/JavaLassoSuite.java +++ b/mllib/src/test/java/spark/mllib/regression/JavaLassoSuite.java @@ -63,9 +63,9 @@ public class JavaLassoSuite implements Serializable { double[] weights = {-1.5, 1.0e-2}; JavaRDD<LabeledPoint> testRDD = sc.parallelize(LinearDataGenerator.generateLinearInputAsList(A, - weights, nPoints, 42), 2).cache(); + weights, nPoints, 42, 0.1), 2).cache(); List<LabeledPoint> validationData = - LinearDataGenerator.generateLinearInputAsList(A, weights, nPoints, 17); + LinearDataGenerator.generateLinearInputAsList(A, weights, nPoints, 17, 0.1); LassoWithSGD lassoSGDImpl = new LassoWithSGD(); lassoSGDImpl.optimizer().setStepSize(1.0) @@ -84,9 +84,9 @@ public class JavaLassoSuite implements Serializable { double[] weights = {-1.5, 1.0e-2}; JavaRDD<LabeledPoint> testRDD = sc.parallelize(LinearDataGenerator.generateLinearInputAsList(A, - weights, nPoints, 42), 2).cache(); + weights, nPoints, 42, 0.1), 2).cache(); List<LabeledPoint> validationData = - LinearDataGenerator.generateLinearInputAsList(A, weights, nPoints, 17); + LinearDataGenerator.generateLinearInputAsList(A, weights, nPoints, 17, 0.1); LassoModel model = LassoWithSGD.train(testRDD.rdd(), 100, 1.0, 0.01, 1.0); diff --git a/mllib/src/test/java/spark/mllib/regression/JavaLinearRegressionSuite.java b/mllib/src/test/java/spark/mllib/regression/JavaLinearRegressionSuite.java index 9642e89844..50716c7861 100644 --- a/mllib/src/test/java/spark/mllib/regression/JavaLinearRegressionSuite.java +++ b/mllib/src/test/java/spark/mllib/regression/JavaLinearRegressionSuite.java @@ -30,68 +30,65 @@ import spark.api.java.JavaSparkContext; import spark.mllib.util.LinearDataGenerator; public class JavaLinearRegressionSuite implements Serializable { - private transient JavaSparkContext sc; - - @Before - public void setUp() { - sc = new JavaSparkContext("local", "JavaLinearRegressionSuite"); - } - - @After - public void tearDown() { - sc.stop(); - sc = null; - System.clearProperty("spark.driver.port"); - } - - int validatePrediction(List<LabeledPoint> validationData, LinearRegressionModel model) { - int numAccurate = 0; - for (LabeledPoint point: validationData) { - Double prediction = model.predict(point.features()); - // A prediction is off if the prediction is more than 0.5 away from expected value. - if (Math.abs(prediction - point.label()) <= 0.5) { - numAccurate++; - } + private transient JavaSparkContext sc; + + @Before + public void setUp() { + sc = new JavaSparkContext("local", "JavaLinearRegressionSuite"); + } + + @After + public void tearDown() { + sc.stop(); + sc = null; + System.clearProperty("spark.driver.port"); + } + + int validatePrediction(List<LabeledPoint> validationData, LinearRegressionModel model) { + int numAccurate = 0; + for (LabeledPoint point: validationData) { + Double prediction = model.predict(point.features()); + // A prediction is off if the prediction is more than 0.5 away from expected value. + if (Math.abs(prediction - point.label()) <= 0.5) { + numAccurate++; } - return numAccurate; - } - - @Test - public void runLinearRegressionUsingConstructor() { - int nPoints = 10000; - double A = 2.0; - double[] weights = {-1.5, 1.0e-2}; - - JavaRDD<LabeledPoint> testRDD = sc.parallelize(LinearDataGenerator.generateLinearInputAsList(A, - weights, nPoints, 42), 2).cache(); - List<LabeledPoint> validationData = - LinearDataGenerator.generateLinearInputAsList(A, weights, nPoints, 17); - - LinearRegressionWithSGD linSGDImpl = new LinearRegressionWithSGD(); - linSGDImpl.optimizer().setStepSize(1.0) - .setRegParam(0.01) - .setNumIterations(20); - LinearRegressionModel model = linSGDImpl.run(testRDD.rdd()); - - int numAccurate = validatePrediction(validationData, model); - Assert.assertTrue(numAccurate > nPoints * 4.0 / 5.0); - } - - @Test - public void runLinearRegressionUsingStaticMethods() { - int nPoints = 10000; - double A = 2.0; - double[] weights = {-1.5, 1.0e-2}; - - JavaRDD<LabeledPoint> testRDD = sc.parallelize(LinearDataGenerator.generateLinearInputAsList(A, - weights, nPoints, 42), 2).cache(); - List<LabeledPoint> validationData = - LinearDataGenerator.generateLinearInputAsList(A, weights, nPoints, 17); - - LinearRegressionModel model = LinearRegressionWithSGD.train(testRDD.rdd(), 100, 1.0, 1.0); - - int numAccurate = validatePrediction(validationData, model); - Assert.assertTrue(numAccurate > nPoints * 4.0 / 5.0); } + return numAccurate; + } + + @Test + public void runLinearRegressionUsingConstructor() { + int nPoints = 100; + double A = 3.0; + double[] weights = {10, 10}; + + JavaRDD<LabeledPoint> testRDD = sc.parallelize( + LinearDataGenerator.generateLinearInputAsList(A, weights, nPoints, 42, 0.1), 2).cache(); + List<LabeledPoint> validationData = + LinearDataGenerator.generateLinearInputAsList(A, weights, nPoints, 17, 0.1); + + LinearRegressionWithSGD linSGDImpl = new LinearRegressionWithSGD(); + LinearRegressionModel model = linSGDImpl.run(testRDD.rdd()); + + int numAccurate = validatePrediction(validationData, model); + Assert.assertTrue(numAccurate > nPoints * 4.0 / 5.0); + } + + @Test + public void runLinearRegressionUsingStaticMethods() { + int nPoints = 100; + double A = 3.0; + double[] weights = {10, 10}; + + JavaRDD<LabeledPoint> testRDD = sc.parallelize( + LinearDataGenerator.generateLinearInputAsList(A, weights, nPoints, 42, 0.1), 2).cache(); + List<LabeledPoint> validationData = + LinearDataGenerator.generateLinearInputAsList(A, weights, nPoints, 17, 0.1); + + LinearRegressionModel model = LinearRegressionWithSGD.train(testRDD.rdd(), 100); + + int numAccurate = validatePrediction(validationData, model); + Assert.assertTrue(numAccurate > nPoints * 4.0 / 5.0); + } } diff --git a/mllib/src/test/java/spark/mllib/regression/JavaRidgeRegressionSuite.java b/mllib/src/test/java/spark/mllib/regression/JavaRidgeRegressionSuite.java index 5df6d8076d..2c0aabad30 100644 --- a/mllib/src/test/java/spark/mllib/regression/JavaRidgeRegressionSuite.java +++ b/mllib/src/test/java/spark/mllib/regression/JavaRidgeRegressionSuite.java @@ -25,73 +25,86 @@ import org.junit.Assert; import org.junit.Before; import org.junit.Test; +import org.jblas.DoubleMatrix; + import spark.api.java.JavaRDD; import spark.api.java.JavaSparkContext; import spark.mllib.util.LinearDataGenerator; public class JavaRidgeRegressionSuite implements Serializable { - private transient JavaSparkContext sc; - - @Before - public void setUp() { - sc = new JavaSparkContext("local", "JavaRidgeRegressionSuite"); - } - - @After - public void tearDown() { - sc.stop(); - sc = null; - System.clearProperty("spark.driver.port"); - } - - int validatePrediction(List<LabeledPoint> validationData, RidgeRegressionModel model) { - int numAccurate = 0; - for (LabeledPoint point: validationData) { - Double prediction = model.predict(point.features()); - // A prediction is off if the prediction is more than 0.5 away from expected value. - if (Math.abs(prediction - point.label()) <= 0.5) { - numAccurate++; - } - } - return numAccurate; - } - - @Test - public void runRidgeRegressionUsingConstructor() { - int nPoints = 10000; - double A = 2.0; - double[] weights = {-1.5, 1.0e-2}; - - JavaRDD<LabeledPoint> testRDD = sc.parallelize(LinearDataGenerator.generateLinearInputAsList(A, - weights, nPoints, 42), 2).cache(); - List<LabeledPoint> validationData = - LinearDataGenerator.generateLinearInputAsList(A, weights, nPoints, 17); - - RidgeRegressionWithSGD ridgeSGDImpl = new RidgeRegressionWithSGD(); - ridgeSGDImpl.optimizer().setStepSize(1.0) - .setRegParam(0.01) - .setNumIterations(20); - RidgeRegressionModel model = ridgeSGDImpl.run(testRDD.rdd()); - - int numAccurate = validatePrediction(validationData, model); - Assert.assertTrue(numAccurate > nPoints * 4.0 / 5.0); + private transient JavaSparkContext sc; + + @Before + public void setUp() { + sc = new JavaSparkContext("local", "JavaRidgeRegressionSuite"); + } + + @After + public void tearDown() { + sc.stop(); + sc = null; + System.clearProperty("spark.driver.port"); + } + + double predictionError(List<LabeledPoint> validationData, RidgeRegressionModel model) { + double errorSum = 0; + for (LabeledPoint point: validationData) { + Double prediction = model.predict(point.features()); + errorSum += (prediction - point.label()) * (prediction - point.label()); } - - @Test - public void runRidgeRegressionUsingStaticMethods() { - int nPoints = 10000; - double A = 2.0; - double[] weights = {-1.5, 1.0e-2}; - - JavaRDD<LabeledPoint> testRDD = sc.parallelize(LinearDataGenerator.generateLinearInputAsList(A, - weights, nPoints, 42), 2).cache(); - List<LabeledPoint> validationData = - LinearDataGenerator.generateLinearInputAsList(A, weights, nPoints, 17); - - RidgeRegressionModel model = RidgeRegressionWithSGD.train(testRDD.rdd(), 100, 1.0, 0.01, 1.0); - - int numAccurate = validatePrediction(validationData, model); - Assert.assertTrue(numAccurate > nPoints * 4.0 / 5.0); - } - + return errorSum / validationData.size(); + } + + List<LabeledPoint> generateRidgeData(int numPoints, int nfeatures, double eps) { + org.jblas.util.Random.seed(42); + // Pick weights as random values distributed uniformly in [-0.5, 0.5] + DoubleMatrix w = DoubleMatrix.rand(nfeatures, 1).subi(0.5); + // Set first two weights to eps + w.put(0, 0, eps); + w.put(1, 0, eps); + return LinearDataGenerator.generateLinearInputAsList(0.0, w.data, numPoints, 42, eps); + } + + @Test + public void runRidgeRegressionUsingConstructor() { + int nexamples = 200; + int nfeatures = 20; + double eps = 10.0; + List<LabeledPoint> data = generateRidgeData(2*nexamples, nfeatures, eps); + + JavaRDD<LabeledPoint> testRDD = sc.parallelize(data.subList(0, nexamples)); + List<LabeledPoint> validationData = data.subList(nexamples, 2*nexamples); + + RidgeRegressionWithSGD ridgeSGDImpl = new RidgeRegressionWithSGD(); + ridgeSGDImpl.optimizer().setStepSize(1.0) + .setRegParam(0.0) + .setNumIterations(200); + RidgeRegressionModel model = ridgeSGDImpl.run(testRDD.rdd()); + double unRegularizedErr = predictionError(validationData, model); + + ridgeSGDImpl.optimizer().setRegParam(0.1); + model = ridgeSGDImpl.run(testRDD.rdd()); + double regularizedErr = predictionError(validationData, model); + + Assert.assertTrue(regularizedErr < unRegularizedErr); + } + + @Test + public void runRidgeRegressionUsingStaticMethods() { + int nexamples = 200; + int nfeatures = 20; + double eps = 10.0; + List<LabeledPoint> data = generateRidgeData(2*nexamples, nfeatures, eps); + + JavaRDD<LabeledPoint> testRDD = sc.parallelize(data.subList(0, nexamples)); + List<LabeledPoint> validationData = data.subList(nexamples, 2*nexamples); + + RidgeRegressionModel model = RidgeRegressionWithSGD.train(testRDD.rdd(), 200, 1.0, 0.0); + double unRegularizedErr = predictionError(validationData, model); + + model = RidgeRegressionWithSGD.train(testRDD.rdd(), 200, 1.0, 0.1); + double regularizedErr = predictionError(validationData, model); + + Assert.assertTrue(regularizedErr < unRegularizedErr); + } } |