aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/test/java/org
diff options
context:
space:
mode:
Diffstat (limited to 'mllib/src/test/java/org')
-rw-r--r--mllib/src/test/java/org/apache/spark/mllib/classification/JavaLogisticRegressionSuite.java6
-rw-r--r--mllib/src/test/java/org/apache/spark/mllib/classification/JavaSVMSuite.java3
-rw-r--r--mllib/src/test/java/org/apache/spark/mllib/regression/JavaLinearRegressionSuite.java3
3 files changed, 8 insertions, 4 deletions
diff --git a/mllib/src/test/java/org/apache/spark/mllib/classification/JavaLogisticRegressionSuite.java b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaLogisticRegressionSuite.java
index e18e3bc6a8..d75d3a6b26 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/classification/JavaLogisticRegressionSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaLogisticRegressionSuite.java
@@ -68,6 +68,7 @@ public class JavaLogisticRegressionSuite implements Serializable {
LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 17);
LogisticRegressionWithSGD lrImpl = new LogisticRegressionWithSGD();
+ lrImpl.setIntercept(true);
lrImpl.optimizer().setStepSize(1.0)
.setRegParam(1.0)
.setNumIterations(100);
@@ -80,8 +81,8 @@ public class JavaLogisticRegressionSuite implements Serializable {
@Test
public void runLRUsingStaticMethods() {
int nPoints = 10000;
- double A = 2.0;
- double B = -1.5;
+ double A = 0.0;
+ double B = -2.5;
JavaRDD<LabeledPoint> testRDD = sc.parallelize(
LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache();
@@ -92,6 +93,7 @@ public class JavaLogisticRegressionSuite implements Serializable {
testRDD.rdd(), 100, 1.0, 1.0);
int numAccurate = validatePrediction(validationData, model);
+ System.out.println(numAccurate);
Assert.assertTrue(numAccurate > nPoints * 4.0 / 5.0);
}
diff --git a/mllib/src/test/java/org/apache/spark/mllib/classification/JavaSVMSuite.java b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaSVMSuite.java
index 4701a5e545..667f76a1bd 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/classification/JavaSVMSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaSVMSuite.java
@@ -67,6 +67,7 @@ public class JavaSVMSuite implements Serializable {
SVMSuite.generateSVMInputAsList(A, weights, nPoints, 17);
SVMWithSGD svmSGDImpl = new SVMWithSGD();
+ svmSGDImpl.setIntercept(true);
svmSGDImpl.optimizer().setStepSize(1.0)
.setRegParam(1.0)
.setNumIterations(100);
@@ -79,7 +80,7 @@ public class JavaSVMSuite implements Serializable {
@Test
public void runSVMUsingStaticMethods() {
int nPoints = 10000;
- double A = 2.0;
+ double A = 0.0;
double[] weights = {-1.5, 1.0};
JavaRDD<LabeledPoint> testRDD = sc.parallelize(SVMSuite.generateSVMInputAsList(A,
diff --git a/mllib/src/test/java/org/apache/spark/mllib/regression/JavaLinearRegressionSuite.java b/mllib/src/test/java/org/apache/spark/mllib/regression/JavaLinearRegressionSuite.java
index 5a4410a632..7151e55351 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/regression/JavaLinearRegressionSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/regression/JavaLinearRegressionSuite.java
@@ -68,6 +68,7 @@ public class JavaLinearRegressionSuite implements Serializable {
LinearDataGenerator.generateLinearInputAsList(A, weights, nPoints, 17, 0.1);
LinearRegressionWithSGD linSGDImpl = new LinearRegressionWithSGD();
+ linSGDImpl.setIntercept(true);
LinearRegressionModel model = linSGDImpl.run(testRDD.rdd());
int numAccurate = validatePrediction(validationData, model);
@@ -77,7 +78,7 @@ public class JavaLinearRegressionSuite implements Serializable {
@Test
public void runLinearRegressionUsingStaticMethods() {
int nPoints = 100;
- double A = 3.0;
+ double A = 0.0;
double[] weights = {10, 10};
JavaRDD<LabeledPoint> testRDD = sc.parallelize(