aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/test/java
diff options
context:
space:
mode:
authorEvan Sparks <sparks@cs.berkeley.edu>2013-08-16 17:48:26 -0700
committerShivaram Venkataraman <shivaram@eecs.berkeley.edu>2013-08-18 15:03:13 -0700
commitb291db712e73fdff0c02946bac96e330b089409d (patch)
tree42673ff425dbe405c95453199a82ca454b570693 /mllib/src/test/java
parentb659af83d3f91f0f339d874b2742ddca20a9f610 (diff)
downloadspark-b291db712e73fdff0c02946bac96e330b089409d.tar.gz
spark-b291db712e73fdff0c02946bac96e330b089409d.tar.bz2
spark-b291db712e73fdff0c02946bac96e330b089409d.zip
Centralizing linear data generator and mllib regression tests to use it.
Diffstat (limited to 'mllib/src/test/java')
-rw-r--r--mllib/src/test/java/spark/mllib/regression/JavaLassoSuite.java11
-rw-r--r--mllib/src/test/java/spark/mllib/regression/JavaLinearRegressionSuite.java9
-rw-r--r--mllib/src/test/java/spark/mllib/regression/JavaRidgeRegressionSuite.java9
3 files changed, 16 insertions, 13 deletions
diff --git a/mllib/src/test/java/spark/mllib/regression/JavaLassoSuite.java b/mllib/src/test/java/spark/mllib/regression/JavaLassoSuite.java
index e26d7b385c..8d692c2d0d 100644
--- a/mllib/src/test/java/spark/mllib/regression/JavaLassoSuite.java
+++ b/mllib/src/test/java/spark/mllib/regression/JavaLassoSuite.java
@@ -27,6 +27,7 @@ import org.junit.Test;
import spark.api.java.JavaRDD;
import spark.api.java.JavaSparkContext;
+import spark.mllib.util.LinearDataGenerator;
public class JavaLassoSuite implements Serializable {
private transient JavaSparkContext sc;
@@ -61,10 +62,10 @@ public class JavaLassoSuite implements Serializable {
double A = 2.0;
double[] weights = {-1.5, 1.0e-2};
- JavaRDD<LabeledPoint> testRDD = sc.parallelize(LassoSuite.generateLassoInputAsList(A,
- weights, nPoints, 42), 2).cache();
+ JavaRDD<LabeledPoint> testRDD = sc.parallelize(LinearDataGenerator.generateLinearInputAsList(A,
+ weights, nPoints, 42), 2).cache();
List<LabeledPoint> validationData =
- LassoSuite.generateLassoInputAsList(A, weights, nPoints, 17);
+ LinearDataGenerator.generateLinearInputAsList(A, weights, nPoints, 17);
LassoWithSGD svmSGDImpl = new LassoWithSGD();
svmSGDImpl.optimizer().setStepSize(1.0)
@@ -82,10 +83,10 @@ public class JavaLassoSuite implements Serializable {
double A = 2.0;
double[] weights = {-1.5, 1.0e-2};
- JavaRDD<LabeledPoint> testRDD = sc.parallelize(LassoSuite.generateLassoInputAsList(A,
+ JavaRDD<LabeledPoint> testRDD = sc.parallelize(LinearDataGenerator.generateLinearInputAsList(A,
weights, nPoints, 42), 2).cache();
List<LabeledPoint> validationData =
- LassoSuite.generateLassoInputAsList(A, weights, nPoints, 17);
+ LinearDataGenerator.generateLinearInputAsList(A, weights, nPoints, 17);
LassoModel model = LassoWithSGD.train(testRDD.rdd(), 100, 1.0, 0.01, 1.0);
diff --git a/mllib/src/test/java/spark/mllib/regression/JavaLinearRegressionSuite.java b/mllib/src/test/java/spark/mllib/regression/JavaLinearRegressionSuite.java
index 14d3d4ef39..d2d8a62980 100644
--- a/mllib/src/test/java/spark/mllib/regression/JavaLinearRegressionSuite.java
+++ b/mllib/src/test/java/spark/mllib/regression/JavaLinearRegressionSuite.java
@@ -27,6 +27,7 @@ import org.junit.Test;
import spark.api.java.JavaRDD;
import spark.api.java.JavaSparkContext;
+import spark.mllib.util.LinearDataGenerator;
public class JavaLinearRegressionSuite implements Serializable {
private transient JavaSparkContext sc;
@@ -61,10 +62,10 @@ public class JavaLinearRegressionSuite implements Serializable {
double A = 2.0;
double[] weights = {-1.5, 1.0e-2};
- JavaRDD<LabeledPoint> testRDD = sc.parallelize(LinearRegressionSuite.generateLinearRegressionInputAsList(A,
+ JavaRDD<LabeledPoint> testRDD = sc.parallelize(LinearDataGenerator.generateLinearInputAsList(A,
weights, nPoints, 42), 2).cache();
List<LabeledPoint> validationData =
- LinearRegressionSuite.generateLinearRegressionInputAsList(A, weights, nPoints, 17);
+ LinearDataGenerator.generateLinearInputAsList(A, weights, nPoints, 17);
LinearRegressionWithSGD svmSGDImpl = new LinearRegressionWithSGD();
svmSGDImpl.optimizer().setStepSize(1.0)
@@ -82,10 +83,10 @@ public class JavaLinearRegressionSuite implements Serializable {
double A = 2.0;
double[] weights = {-1.5, 1.0e-2};
- JavaRDD<LabeledPoint> testRDD = sc.parallelize(LinearRegressionSuite.generateLinearRegressionInputAsList(A,
+ JavaRDD<LabeledPoint> testRDD = sc.parallelize(LinearDataGenerator.generateLinearInputAsList(A,
weights, nPoints, 42), 2).cache();
List<LabeledPoint> validationData =
- LinearRegressionSuite.generateLinearRegressionInputAsList(A, weights, nPoints, 17);
+ LinearDataGenerator.generateLinearInputAsList(A, weights, nPoints, 17);
LinearRegressionModel model = LinearRegressionWithSGD.train(testRDD.rdd(), 100, 1.0, 1.0);
diff --git a/mllib/src/test/java/spark/mllib/regression/JavaRidgeRegressionSuite.java b/mllib/src/test/java/spark/mllib/regression/JavaRidgeRegressionSuite.java
index 4f379b51d5..72ab875985 100644
--- a/mllib/src/test/java/spark/mllib/regression/JavaRidgeRegressionSuite.java
+++ b/mllib/src/test/java/spark/mllib/regression/JavaRidgeRegressionSuite.java
@@ -27,6 +27,7 @@ import org.junit.Test;
import spark.api.java.JavaRDD;
import spark.api.java.JavaSparkContext;
+import spark.mllib.util.LinearDataGenerator;
public class JavaRidgeRegressionSuite implements Serializable {
private transient JavaSparkContext sc;
@@ -61,10 +62,10 @@ public class JavaRidgeRegressionSuite implements Serializable {
double A = 2.0;
double[] weights = {-1.5, 1.0e-2};
- JavaRDD<LabeledPoint> testRDD = sc.parallelize(RidgeRegressionSuite.generateRidgeRegressionInputAsList(A,
+ JavaRDD<LabeledPoint> testRDD = sc.parallelize(LinearDataGenerator.generateLinearInputAsList(A,
weights, nPoints, 42), 2).cache();
List<LabeledPoint> validationData =
- RidgeRegressionSuite.generateRidgeRegressionInputAsList(A, weights, nPoints, 17);
+ LinearDataGenerator.generateLinearInputAsList(A, weights, nPoints, 17);
RidgeRegressionWithSGD svmSGDImpl = new RidgeRegressionWithSGD();
svmSGDImpl.optimizer().setStepSize(1.0)
@@ -82,10 +83,10 @@ public class JavaRidgeRegressionSuite implements Serializable {
double A = 2.0;
double[] weights = {-1.5, 1.0e-2};
- JavaRDD<LabeledPoint> testRDD = sc.parallelize(RidgeRegressionSuite.generateRidgeRegressionInputAsList(A,
+ JavaRDD<LabeledPoint> testRDD = sc.parallelize(LinearDataGenerator.generateLinearInputAsList(A,
weights, nPoints, 42), 2).cache();
List<LabeledPoint> validationData =
- RidgeRegressionSuite.generateRidgeRegressionInputAsList(A, weights, nPoints, 17);
+ LinearDataGenerator.generateLinearInputAsList(A, weights, nPoints, 17);
RidgeRegressionModel model = RidgeRegressionWithSGD.train(testRDD.rdd(), 100, 1.0, 0.01, 1.0);