diff options
Diffstat (limited to 'mllib/src/test/java/org')
3 files changed, 8 insertions, 4 deletions
diff --git a/mllib/src/test/java/org/apache/spark/mllib/classification/JavaLogisticRegressionSuite.java b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaLogisticRegressionSuite.java index e18e3bc6a8..d75d3a6b26 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/classification/JavaLogisticRegressionSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaLogisticRegressionSuite.java @@ -68,6 +68,7 @@ public class JavaLogisticRegressionSuite implements Serializable { LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 17); LogisticRegressionWithSGD lrImpl = new LogisticRegressionWithSGD(); + lrImpl.setIntercept(true); lrImpl.optimizer().setStepSize(1.0) .setRegParam(1.0) .setNumIterations(100); @@ -80,8 +81,8 @@ public class JavaLogisticRegressionSuite implements Serializable { @Test public void runLRUsingStaticMethods() { int nPoints = 10000; - double A = 2.0; - double B = -1.5; + double A = 0.0; + double B = -2.5; JavaRDD<LabeledPoint> testRDD = sc.parallelize( LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache(); @@ -92,6 +93,7 @@ public class JavaLogisticRegressionSuite implements Serializable { testRDD.rdd(), 100, 1.0, 1.0); int numAccurate = validatePrediction(validationData, model); + System.out.println(numAccurate); Assert.assertTrue(numAccurate > nPoints * 4.0 / 5.0); } diff --git a/mllib/src/test/java/org/apache/spark/mllib/classification/JavaSVMSuite.java b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaSVMSuite.java index 4701a5e545..667f76a1bd 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/classification/JavaSVMSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaSVMSuite.java @@ -67,6 +67,7 @@ public class JavaSVMSuite implements Serializable { SVMSuite.generateSVMInputAsList(A, weights, nPoints, 17); SVMWithSGD svmSGDImpl = new SVMWithSGD(); + svmSGDImpl.setIntercept(true); svmSGDImpl.optimizer().setStepSize(1.0) .setRegParam(1.0) .setNumIterations(100); @@ -79,7 +80,7 @@ public class JavaSVMSuite implements Serializable { @Test public void runSVMUsingStaticMethods() { int nPoints = 10000; - double A = 2.0; + double A = 0.0; double[] weights = {-1.5, 1.0}; JavaRDD<LabeledPoint> testRDD = sc.parallelize(SVMSuite.generateSVMInputAsList(A, diff --git a/mllib/src/test/java/org/apache/spark/mllib/regression/JavaLinearRegressionSuite.java b/mllib/src/test/java/org/apache/spark/mllib/regression/JavaLinearRegressionSuite.java index 5a4410a632..7151e55351 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/regression/JavaLinearRegressionSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/regression/JavaLinearRegressionSuite.java @@ -68,6 +68,7 @@ public class JavaLinearRegressionSuite implements Serializable { LinearDataGenerator.generateLinearInputAsList(A, weights, nPoints, 17, 0.1); LinearRegressionWithSGD linSGDImpl = new LinearRegressionWithSGD(); + linSGDImpl.setIntercept(true); LinearRegressionModel model = linSGDImpl.run(testRDD.rdd()); int numAccurate = validatePrediction(validationData, model); @@ -77,7 +78,7 @@ public class JavaLinearRegressionSuite implements Serializable { @Test public void runLinearRegressionUsingStaticMethods() { int nPoints = 100; - double A = 3.0; + double A = 0.0; double[] weights = {10, 10}; JavaRDD<LabeledPoint> testRDD = sc.parallelize( |