diff options
author | Evan Sparks <sparks@cs.berkeley.edu> | 2013-08-16 17:48:26 -0700 |
---|---|---|
committer | Shivaram Venkataraman <shivaram@eecs.berkeley.edu> | 2013-08-18 15:03:13 -0700 |
commit | b291db712e73fdff0c02946bac96e330b089409d (patch) | |
tree | 42673ff425dbe405c95453199a82ca454b570693 /mllib/src/test/java | |
parent | b659af83d3f91f0f339d874b2742ddca20a9f610 (diff) | |
download | spark-b291db712e73fdff0c02946bac96e330b089409d.tar.gz spark-b291db712e73fdff0c02946bac96e330b089409d.tar.bz2 spark-b291db712e73fdff0c02946bac96e330b089409d.zip |
Centralizing linear data generator and mllib regression tests to use it.
Diffstat (limited to 'mllib/src/test/java')
3 files changed, 16 insertions, 13 deletions
diff --git a/mllib/src/test/java/spark/mllib/regression/JavaLassoSuite.java b/mllib/src/test/java/spark/mllib/regression/JavaLassoSuite.java index e26d7b385c..8d692c2d0d 100644 --- a/mllib/src/test/java/spark/mllib/regression/JavaLassoSuite.java +++ b/mllib/src/test/java/spark/mllib/regression/JavaLassoSuite.java @@ -27,6 +27,7 @@ import org.junit.Test; import spark.api.java.JavaRDD; import spark.api.java.JavaSparkContext; +import spark.mllib.util.LinearDataGenerator; public class JavaLassoSuite implements Serializable { private transient JavaSparkContext sc; @@ -61,10 +62,10 @@ public class JavaLassoSuite implements Serializable { double A = 2.0; double[] weights = {-1.5, 1.0e-2}; - JavaRDD<LabeledPoint> testRDD = sc.parallelize(LassoSuite.generateLassoInputAsList(A, - weights, nPoints, 42), 2).cache(); + JavaRDD<LabeledPoint> testRDD = sc.parallelize(LinearDataGenerator.generateLinearInputAsList(A, + weights, nPoints, 42), 2).cache(); List<LabeledPoint> validationData = - LassoSuite.generateLassoInputAsList(A, weights, nPoints, 17); + LinearDataGenerator.generateLinearInputAsList(A, weights, nPoints, 17); LassoWithSGD svmSGDImpl = new LassoWithSGD(); svmSGDImpl.optimizer().setStepSize(1.0) @@ -82,10 +83,10 @@ public class JavaLassoSuite implements Serializable { double A = 2.0; double[] weights = {-1.5, 1.0e-2}; - JavaRDD<LabeledPoint> testRDD = sc.parallelize(LassoSuite.generateLassoInputAsList(A, + JavaRDD<LabeledPoint> testRDD = sc.parallelize(LinearDataGenerator.generateLinearInputAsList(A, weights, nPoints, 42), 2).cache(); List<LabeledPoint> validationData = - LassoSuite.generateLassoInputAsList(A, weights, nPoints, 17); + LinearDataGenerator.generateLinearInputAsList(A, weights, nPoints, 17); LassoModel model = LassoWithSGD.train(testRDD.rdd(), 100, 1.0, 0.01, 1.0); diff --git a/mllib/src/test/java/spark/mllib/regression/JavaLinearRegressionSuite.java b/mllib/src/test/java/spark/mllib/regression/JavaLinearRegressionSuite.java index 14d3d4ef39..d2d8a62980 100644 --- a/mllib/src/test/java/spark/mllib/regression/JavaLinearRegressionSuite.java +++ b/mllib/src/test/java/spark/mllib/regression/JavaLinearRegressionSuite.java @@ -27,6 +27,7 @@ import org.junit.Test; import spark.api.java.JavaRDD; import spark.api.java.JavaSparkContext; +import spark.mllib.util.LinearDataGenerator; public class JavaLinearRegressionSuite implements Serializable { private transient JavaSparkContext sc; @@ -61,10 +62,10 @@ public class JavaLinearRegressionSuite implements Serializable { double A = 2.0; double[] weights = {-1.5, 1.0e-2}; - JavaRDD<LabeledPoint> testRDD = sc.parallelize(LinearRegressionSuite.generateLinearRegressionInputAsList(A, + JavaRDD<LabeledPoint> testRDD = sc.parallelize(LinearDataGenerator.generateLinearInputAsList(A, weights, nPoints, 42), 2).cache(); List<LabeledPoint> validationData = - LinearRegressionSuite.generateLinearRegressionInputAsList(A, weights, nPoints, 17); + LinearDataGenerator.generateLinearInputAsList(A, weights, nPoints, 17); LinearRegressionWithSGD svmSGDImpl = new LinearRegressionWithSGD(); svmSGDImpl.optimizer().setStepSize(1.0) @@ -82,10 +83,10 @@ public class JavaLinearRegressionSuite implements Serializable { double A = 2.0; double[] weights = {-1.5, 1.0e-2}; - JavaRDD<LabeledPoint> testRDD = sc.parallelize(LinearRegressionSuite.generateLinearRegressionInputAsList(A, + JavaRDD<LabeledPoint> testRDD = sc.parallelize(LinearDataGenerator.generateLinearInputAsList(A, weights, nPoints, 42), 2).cache(); List<LabeledPoint> validationData = - LinearRegressionSuite.generateLinearRegressionInputAsList(A, weights, nPoints, 17); + LinearDataGenerator.generateLinearInputAsList(A, weights, nPoints, 17); LinearRegressionModel model = LinearRegressionWithSGD.train(testRDD.rdd(), 100, 1.0, 1.0); diff --git a/mllib/src/test/java/spark/mllib/regression/JavaRidgeRegressionSuite.java b/mllib/src/test/java/spark/mllib/regression/JavaRidgeRegressionSuite.java index 4f379b51d5..72ab875985 100644 --- a/mllib/src/test/java/spark/mllib/regression/JavaRidgeRegressionSuite.java +++ b/mllib/src/test/java/spark/mllib/regression/JavaRidgeRegressionSuite.java @@ -27,6 +27,7 @@ import org.junit.Test; import spark.api.java.JavaRDD; import spark.api.java.JavaSparkContext; +import spark.mllib.util.LinearDataGenerator; public class JavaRidgeRegressionSuite implements Serializable { private transient JavaSparkContext sc; @@ -61,10 +62,10 @@ public class JavaRidgeRegressionSuite implements Serializable { double A = 2.0; double[] weights = {-1.5, 1.0e-2}; - JavaRDD<LabeledPoint> testRDD = sc.parallelize(RidgeRegressionSuite.generateRidgeRegressionInputAsList(A, + JavaRDD<LabeledPoint> testRDD = sc.parallelize(LinearDataGenerator.generateLinearInputAsList(A, weights, nPoints, 42), 2).cache(); List<LabeledPoint> validationData = - RidgeRegressionSuite.generateRidgeRegressionInputAsList(A, weights, nPoints, 17); + LinearDataGenerator.generateLinearInputAsList(A, weights, nPoints, 17); RidgeRegressionWithSGD svmSGDImpl = new RidgeRegressionWithSGD(); svmSGDImpl.optimizer().setStepSize(1.0) @@ -82,10 +83,10 @@ public class JavaRidgeRegressionSuite implements Serializable { double A = 2.0; double[] weights = {-1.5, 1.0e-2}; - JavaRDD<LabeledPoint> testRDD = sc.parallelize(RidgeRegressionSuite.generateRidgeRegressionInputAsList(A, + JavaRDD<LabeledPoint> testRDD = sc.parallelize(LinearDataGenerator.generateLinearInputAsList(A, weights, nPoints, 42), 2).cache(); List<LabeledPoint> validationData = - RidgeRegressionSuite.generateRidgeRegressionInputAsList(A, weights, nPoints, 17); + LinearDataGenerator.generateLinearInputAsList(A, weights, nPoints, 17); RidgeRegressionModel model = RidgeRegressionWithSGD.train(testRDD.rdd(), 100, 1.0, 0.01, 1.0); |