about | summary | refs | log | tree | commit | diff
path: root/mllib/src/test/java/org
diff options
context:
space:
mode:
author: Xiangrui Meng <meng@databricks.com> 2014-05-05 18:32:54 -0700
committer: Matei Zaharia <matei@databricks.com> 2014-05-05 18:32:54 -0700
commit: 98750a74daf7e2b873da85d2d5067f47e3bbdc4e (patch)
tree: 7751cfc30345957b4ee65bde5a0a91fe57a984e3 /mllib/src/test/java/org
parent: ea10b3126167af3f50f7c2a70e1d942e839fcb66 (diff)
download: spark-98750a74daf7e2b873da85d2d5067f47e3bbdc4e.tar.gz
spark-98750a74daf7e2b873da85d2d5067f47e3bbdc4e.tar.bz2
spark-98750a74daf7e2b873da85d2d5067f47e3bbdc4e.zip
[SPARK-1594][MLLIB] Cleaning up MLlib APIs and guide
Final pass before the v1.0 release.

* Remove `VectorRDDs`.
* Move `BinaryClassificationMetrics` from `evaluation.binary` to `evaluation`.
* Change the default value of `addIntercept` to false and allow adding an intercept in Ridge and Lasso.
* Clean `DecisionTree` package doc and test suite.
* Mark model constructors `private[spark]`.
* Rename `loadLibSVMData` to `loadLibSVMFile` and hide `LabelParser` from users.
* Add `saveAsLibSVMFile`.
* Add `appendBias` to `MLUtils`.

Author: Xiangrui Meng <meng@databricks.com>

Closes #524 from mengxr/mllib-cleaning and squashes the following commits:

295dc8b [Xiangrui Meng] update loadLibSVMFile doc
1977ac1 [Xiangrui Meng] fix doc of appendBias
649fcf0 [Xiangrui Meng] rename loadLibSVMData to loadLibSVMFile; hide LabelParser from user APIs
54b812c [Xiangrui Meng] add appendBias
a71e7d0 [Xiangrui Meng] add saveAsLibSVMFile
d976295 [Xiangrui Meng] Merge branch 'master' into mllib-cleaning
b7e5cec [Xiangrui Meng] remove some experimental annotations and make model constructors private[mllib]
9b02b93 [Xiangrui Meng] minor code style update
a593ddc [Xiangrui Meng] fix python tests
fc28c18 [Xiangrui Meng] mark more classes experimental
f6cbbff [Xiangrui Meng] fix Java tests
0af70b0 [Xiangrui Meng] minor
6e139ef [Xiangrui Meng] Merge branch 'master' into mllib-cleaning
94e6dce [Xiangrui Meng] move BinaryLabelCounter and BinaryConfusionMatrixImpl to evaluation.binary
df34907 [Xiangrui Meng] clean DecisionTreeSuite to use LocalSparkContext
c81807f [Xiangrui Meng] set the default value of AddIntercept to false
03389c0 [Xiangrui Meng] allow to add intercept in Ridge and Lasso
c66c56f [Xiangrui Meng] move tree md to package object doc
a2695df [Xiangrui Meng] update guide for BinaryClassificationMetrics
9194f4c [Xiangrui Meng] move BinaryClassificationMetrics one level up
1c1a0e3 [Xiangrui Meng] remove VectorRDDs because it only contains one function that is not necessary for us to maintain
Diffstat (limited to 'mllib/src/test/java/org')
-rw-r--r-- mllib/src/test/java/org/apache/spark/mllib/classification/JavaLogisticRegressionSuite.java 6
-rw-r--r-- mllib/src/test/java/org/apache/spark/mllib/classification/JavaSVMSuite.java 3
-rw-r--r-- mllib/src/test/java/org/apache/spark/mllib/regression/JavaLinearRegressionSuite.java 3
3 files changed, 8 insertions, 4 deletions
diff --git a/mllib/src/test/java/org/apache/spark/mllib/classification/JavaLogisticRegressionSuite.java b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaLogisticRegressionSuite.java
index e18e3bc6a8..d75d3a6b26 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/classification/JavaLogisticRegressionSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaLogisticRegressionSuite.java
@@ -68,6 +68,7 @@ public class JavaLogisticRegressionSuite implements Serializable {
LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 17);
LogisticRegressionWithSGD lrImpl = new LogisticRegressionWithSGD();
+ lrImpl.setIntercept(true);
lrImpl.optimizer().setStepSize(1.0)
.setRegParam(1.0)
.setNumIterations(100);
@@ -80,8 +81,8 @@ public class JavaLogisticRegressionSuite implements Serializable {
@Test
public void runLRUsingStaticMethods() {
int nPoints = 10000;
- double A = 2.0;
- double B = -1.5;
+ double A = 0.0;
+ double B = -2.5;
JavaRDD<LabeledPoint> testRDD = sc.parallelize(
LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache();
@@ -92,6 +93,7 @@ public class JavaLogisticRegressionSuite implements Serializable {
testRDD.rdd(), 100, 1.0, 1.0);
int numAccurate = validatePrediction(validationData, model);
+ System.out.println(numAccurate);
Assert.assertTrue(numAccurate > nPoints * 4.0 / 5.0);
}
diff --git a/mllib/src/test/java/org/apache/spark/mllib/classification/JavaSVMSuite.java b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaSVMSuite.java
index 4701a5e545..667f76a1bd 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/classification/JavaSVMSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaSVMSuite.java
@@ -67,6 +67,7 @@ public class JavaSVMSuite implements Serializable {
SVMSuite.generateSVMInputAsList(A, weights, nPoints, 17);
SVMWithSGD svmSGDImpl = new SVMWithSGD();
+ svmSGDImpl.setIntercept(true);
svmSGDImpl.optimizer().setStepSize(1.0)
.setRegParam(1.0)
.setNumIterations(100);
@@ -79,7 +80,7 @@ public class JavaSVMSuite implements Serializable {
@Test
public void runSVMUsingStaticMethods() {
int nPoints = 10000;
- double A = 2.0;
+ double A = 0.0;
double[] weights = {-1.5, 1.0};
JavaRDD<LabeledPoint> testRDD = sc.parallelize(SVMSuite.generateSVMInputAsList(A,
diff --git a/mllib/src/test/java/org/apache/spark/mllib/regression/JavaLinearRegressionSuite.java b/mllib/src/test/java/org/apache/spark/mllib/regression/JavaLinearRegressionSuite.java
index 5a4410a632..7151e55351 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/regression/JavaLinearRegressionSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/regression/JavaLinearRegressionSuite.java
@@ -68,6 +68,7 @@ public class JavaLinearRegressionSuite implements Serializable {
LinearDataGenerator.generateLinearInputAsList(A, weights, nPoints, 17, 0.1);
LinearRegressionWithSGD linSGDImpl = new LinearRegressionWithSGD();
+ linSGDImpl.setIntercept(true);
LinearRegressionModel model = linSGDImpl.run(testRDD.rdd());
int numAccurate = validatePrediction(validationData, model);
@@ -77,7 +78,7 @@ public class JavaLinearRegressionSuite implements Serializable {
@Test
public void runLinearRegressionUsingStaticMethods() {
int nPoints = 100;
- double A = 3.0;
+ double A = 0.0;
double[] weights = {10, 10};
JavaRDD<LabeledPoint> testRDD = sc.parallelize(