diff options
author | Evan Sparks <sparks@cs.berkeley.edu> | 2013-08-14 16:24:23 -0700 |
---|---|---|
committer | Shivaram Venkataraman <shivaram@eecs.berkeley.edu> | 2013-08-18 15:03:13 -0700 |
commit | b659af83d3f91f0f339d874b2742ddca20a9f610 (patch) | |
tree | 77b2b98a2bdf1433fd49144632f2d29f1b53f803 /mllib/src/test | |
parent | 044a088c0db68220aae2dad425886b618bb0023f (diff) | |
download | spark-b659af83d3f91f0f339d874b2742ddca20a9f610.tar.gz spark-b659af83d3f91f0f339d874b2742ddca20a9f610.tar.bz2 spark-b659af83d3f91f0f339d874b2742ddca20a9f610.zip |
Adding Linear Regression, and refactoring Ridge Regression.
Diffstat (limited to 'mllib/src/test')
4 files changed, 329 insertions, 18 deletions
diff --git a/mllib/src/test/java/spark/mllib/regression/JavaLinearRegressionSuite.java b/mllib/src/test/java/spark/mllib/regression/JavaLinearRegressionSuite.java new file mode 100644 index 0000000000..14d3d4ef39 --- /dev/null +++ b/mllib/src/test/java/spark/mllib/regression/JavaLinearRegressionSuite.java @@ -0,0 +1,96 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package spark.mllib.regression; + +import java.io.Serializable; +import java.util.List; + +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import spark.api.java.JavaRDD; +import spark.api.java.JavaSparkContext; + +public class JavaLinearRegressionSuite implements Serializable { + private transient JavaSparkContext sc; + + @Before + public void setUp() { + sc = new JavaSparkContext("local", "JavaLinearRegressionSuite"); + } + + @After + public void tearDown() { + sc.stop(); + sc = null; + System.clearProperty("spark.driver.port"); + } + + int validatePrediction(List<LabeledPoint> validationData, LinearRegressionModel model) { + int numAccurate = 0; + for (LabeledPoint point: validationData) { + Double prediction = model.predict(point.features()); + // A prediction is off if the prediction is more than 0.5 away from expected value. + if (Math.abs(prediction - point.label()) <= 0.5) { + numAccurate++; + } + } + return numAccurate; + } + + @Test + public void runLinearRegressionUsingConstructor() { + int nPoints = 10000; + double A = 2.0; + double[] weights = {-1.5, 1.0e-2}; + + JavaRDD<LabeledPoint> testRDD = sc.parallelize(LinearRegressionSuite.generateLinearRegressionInputAsList(A, + weights, nPoints, 42), 2).cache(); + List<LabeledPoint> validationData = + LinearRegressionSuite.generateLinearRegressionInputAsList(A, weights, nPoints, 17); + + LinearRegressionWithSGD svmSGDImpl = new LinearRegressionWithSGD(); + svmSGDImpl.optimizer().setStepSize(1.0) + .setRegParam(0.01) + .setNumIterations(20); + LinearRegressionModel model = svmSGDImpl.run(testRDD.rdd()); + + int numAccurate = validatePrediction(validationData, model); + Assert.assertTrue(numAccurate > nPoints * 4.0 / 5.0); + } + + @Test + public void runLinearRegressionUsingStaticMethods() { + int nPoints = 10000; + double A = 2.0; + double[] weights = {-1.5, 1.0e-2}; + + JavaRDD<LabeledPoint> testRDD = sc.parallelize(LinearRegressionSuite.generateLinearRegressionInputAsList(A, + weights, nPoints, 42), 2).cache(); + List<LabeledPoint> validationData = + LinearRegressionSuite.generateLinearRegressionInputAsList(A, weights, nPoints, 17); + + LinearRegressionModel model = LinearRegressionWithSGD.train(testRDD.rdd(), 100, 1.0, 1.0); + + int numAccurate = validatePrediction(validationData, model); + Assert.assertTrue(numAccurate > nPoints * 4.0 / 5.0); + } + +} diff --git a/mllib/src/test/java/spark/mllib/regression/JavaRidgeRegressionSuite.java b/mllib/src/test/java/spark/mllib/regression/JavaRidgeRegressionSuite.java new file mode 100644 index 0000000000..4f379b51d5 --- /dev/null +++ b/mllib/src/test/java/spark/mllib/regression/JavaRidgeRegressionSuite.java @@ -0,0 +1,96 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package spark.mllib.regression; + +import java.io.Serializable; +import java.util.List; + +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import spark.api.java.JavaRDD; +import spark.api.java.JavaSparkContext; + +public class JavaRidgeRegressionSuite implements Serializable { + private transient JavaSparkContext sc; + + @Before + public void setUp() { + sc = new JavaSparkContext("local", "JavaRidgeRegressionSuite"); + } + + @After + public void tearDown() { + sc.stop(); + sc = null; + System.clearProperty("spark.driver.port"); + } + + int validatePrediction(List<LabeledPoint> validationData, RidgeRegressionModel model) { + int numAccurate = 0; + for (LabeledPoint point: validationData) { + Double prediction = model.predict(point.features()); + // A prediction is off if the prediction is more than 0.5 away from expected value. + if (Math.abs(prediction - point.label()) <= 0.5) { + numAccurate++; + } + } + return numAccurate; + } + + @Test + public void runRidgeRegressionUsingConstructor() { + int nPoints = 10000; + double A = 2.0; + double[] weights = {-1.5, 1.0e-2}; + + JavaRDD<LabeledPoint> testRDD = sc.parallelize(RidgeRegressionSuite.generateRidgeRegressionInputAsList(A, + weights, nPoints, 42), 2).cache(); + List<LabeledPoint> validationData = + RidgeRegressionSuite.generateRidgeRegressionInputAsList(A, weights, nPoints, 17); + + RidgeRegressionWithSGD svmSGDImpl = new RidgeRegressionWithSGD(); + svmSGDImpl.optimizer().setStepSize(1.0) + .setRegParam(0.01) + .setNumIterations(20); + RidgeRegressionModel model = svmSGDImpl.run(testRDD.rdd()); + + int numAccurate = validatePrediction(validationData, model); + Assert.assertTrue(numAccurate > nPoints * 4.0 / 5.0); + } + + @Test + public void runRidgeRegressionUsingStaticMethods() { + int nPoints = 10000; + double A = 2.0; + double[] weights = {-1.5, 1.0e-2}; + + JavaRDD<LabeledPoint> testRDD = sc.parallelize(RidgeRegressionSuite.generateRidgeRegressionInputAsList(A, + weights, nPoints, 42), 2).cache(); + List<LabeledPoint> validationData = + RidgeRegressionSuite.generateRidgeRegressionInputAsList(A, weights, nPoints, 17); + + RidgeRegressionModel model = RidgeRegressionWithSGD.train(testRDD.rdd(), 100, 1.0, 0.01, 1.0); + + int numAccurate = validatePrediction(validationData, model); + Assert.assertTrue(numAccurate > nPoints * 4.0 / 5.0); + } + +} diff --git a/mllib/src/test/scala/spark/mllib/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/spark/mllib/regression/LinearRegressionSuite.scala new file mode 100644 index 0000000000..c794c1cac5 --- /dev/null +++ b/mllib/src/test/scala/spark/mllib/regression/LinearRegressionSuite.scala @@ -0,0 +1,87 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package spark.mllib.regression + +import scala.collection.JavaConversions._ +import scala.util.Random + +import org.scalatest.BeforeAndAfterAll +import org.scalatest.FunSuite + +import spark.SparkContext +import spark.SparkContext._ +import spark.mllib.util.LinearRegressionDataGenerator +import spark.mllib.regression.LabeledPoint +import org.jblas.DoubleMatrix + +object LinearRegressionSuite { + + def generateLinearRegressionInputAsList( + intercept: Double, + weights: Array[Double], + nPoints: Int, + seed: Int): java.util.List[LabeledPoint] = { + seqAsJavaList(generateLinearRegressionInput(intercept, weights, nPoints, seed)) + } + + + // Generate noisy input of the form Y = x.dot(weights) + intercept + noise + def generateLinearRegressionInput( + intercept: Double, + weights: Array[Double], + nPoints: Int, + seed: Int): Seq[LabeledPoint] = { + val rnd = new Random(seed) + val weightsMat = new DoubleMatrix(1, weights.length, weights:_*) + val x = Array.fill[Array[Double]](nPoints)( + Array.fill[Double](weights.length)(rnd.nextGaussian())) + val y = x.map(xi => + (new DoubleMatrix(1, xi.length, xi:_*)).dot(weightsMat) + intercept + 0.1 * rnd.nextGaussian() + ) + y.zip(x).map(p => LabeledPoint(p._1, p._2)) + } + +} + +class LinearRegressionSuite extends FunSuite with BeforeAndAfterAll { + @transient private var sc: SparkContext = _ + + override def beforeAll() { + sc = new SparkContext("local", "test") + } + + override def afterAll() { + sc.stop() + System.clearProperty("spark.driver.port") + } + + // Test if we can correctly learn Y = 3 + 10*X1 + 10*X2 when + // X1 and X2 are collinear. + test("multi-collinear variables") { + val testRDD = LinearRegressionDataGenerator.generateLinearRDD(sc, 100, 2, 0.0, intercept=3.0).cache() + val linReg = new LinearRegressionWithSGD() + linReg.optimizer.setNumIterations(1000).setStepSize(1.0) + + val model = linReg.run(testRDD) + + assert(model.intercept >= 2.5 && model.intercept <= 3.5) + assert(model.weights.length === 2) + assert(model.weights(0) >= 9.0 && model.weights(0) <= 11.0) + assert(model.weights(1) >= 9.0 && model.weights(1) <= 11.0) + } +} diff --git a/mllib/src/test/scala/spark/mllib/regression/RidgeRegressionSuite.scala b/mllib/src/test/scala/spark/mllib/regression/RidgeRegressionSuite.scala index e2b244894d..aaac083ad9 100644 --- a/mllib/src/test/scala/spark/mllib/regression/RidgeRegressionSuite.scala +++ b/mllib/src/test/scala/spark/mllib/regression/RidgeRegressionSuite.scala @@ -17,6 +17,7 @@ package spark.mllib.regression +import scala.collection.JavaConversions._ import scala.util.Random import org.scalatest.BeforeAndAfterAll @@ -24,6 +25,37 @@ import org.scalatest.FunSuite import spark.SparkContext import spark.SparkContext._ +import spark.mllib.util.RidgeRegressionDataGenerator +import org.jblas.DoubleMatrix + +object RidgeRegressionSuite { + + def generateRidgeRegressionInputAsList( + intercept: Double, + weights: Array[Double], + nPoints: Int, + seed: Int): java.util.List[LabeledPoint] = { + seqAsJavaList(generateRidgeRegressionInput(intercept, weights, nPoints, seed)) + } + + + // Generate noisy input of the form Y = x.dot(weights) + intercept + noise + def generateRidgeRegressionInput( + intercept: Double, + weights: Array[Double], + nPoints: Int, + seed: Int): Seq[LabeledPoint] = { + val rnd = new Random(seed) + val weightsMat = new DoubleMatrix(1, weights.length, weights:_*) + val x = Array.fill[Array[Double]](nPoints)( + Array.fill[Double](weights.length)(rnd.nextGaussian())) + val y = x.map(xi => + (new DoubleMatrix(1, xi.length, xi:_*)).dot(weightsMat) + intercept + 0.1 * rnd.nextGaussian() + ) + y.zip(x).map(p => LabeledPoint(p._1, p._2)) + } + +} class RidgeRegressionSuite extends FunSuite with BeforeAndAfterAll { @@ -38,31 +70,31 @@ class RidgeRegressionSuite extends FunSuite with BeforeAndAfterAll { System.clearProperty("spark.driver.port") } - // Test if we can correctly learn Y = 3 + X1 + X2 when + // Test if we can correctly learn Y = 3 + 10*X1 + 10*X2 when // X1 and X2 are collinear. test("multi-collinear variables") { - val rnd = new Random(43) - val x1 = Array.fill[Double](20)(rnd.nextGaussian()) - - // Pick a mean close to mean of x1 - val rnd1 = new Random(42) //new NormalDistribution(0.1, 0.01) - val x2 = Array.fill[Double](20)(0.1 + rnd1.nextGaussian() * 0.01) + val testRDD = RidgeRegressionDataGenerator.generateRidgeRDD(sc, 100, 2, 0.0, intercept=3.0).cache() + val ridgeReg = new RidgeRegressionWithSGD() + ridgeReg.optimizer.setNumIterations(1000).setRegParam(0.0).setStepSize(1.0) - val xMat = (0 until 20).map(i => Array(x1(i), x2(i))).toArray + val model = ridgeReg.run(testRDD) - val y = xMat.map(i => 3 + i(0) + i(1)) - val testData = (0 until 20).map(i => LabeledPoint(y(i), xMat(i))).toArray + assert(model.intercept >= 2.5 && model.intercept <= 3.5) + assert(model.weights.length === 2) + assert(model.weights(0) >= 9.0 && model.weights(0) <= 11.0) + assert(model.weights(1) >= 9.0 && model.weights(1) <= 11.0) + } - val testRDD = sc.parallelize(testData, 2) - testRDD.cache() - val ridgeReg = new RidgeRegression().setLowLambda(0) - .setHighLambda(10) + test("multi-collinear variables with regularization") { + val testRDD = RidgeRegressionDataGenerator.generateRidgeRDD(sc, 100, 2, 0.0, intercept=3.0).cache() + val ridgeReg = new RidgeRegressionWithSGD() + ridgeReg.optimizer.setNumIterations(1000).setRegParam(1.0).setStepSize(1.0) - val model = ridgeReg.train(testRDD) + val model = ridgeReg.run(testRDD) - assert(model.intercept >= 2.9 && model.intercept <= 3.1) + assert(model.intercept <= 5.0) assert(model.weights.length === 2) - assert(model.weights.get(0) >= 0.9 && model.weights.get(0) <= 1.1) - assert(model.weights.get(1) >= 0.9 && model.weights.get(1) <= 1.1) + assert(model.weights(0) <= 3.0) + assert(model.weights(1) <= 3.0) } } |