aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/test
diff options
context:
space:
mode:
authorMatei Zaharia <matei@databricks.com>2014-01-10 16:08:35 -0800
committerMatei Zaharia <matei@databricks.com>2014-01-11 22:30:48 -0800
commitf00e949f84df949fbe32c254b592a580b4623811 (patch)
tree6d1562dedcd649573ab5603cdb70d92123d11808 /mllib/src/test
parent4c28a2bad8a6d64ee69213eede440837636fe58b (diff)
downloadspark-f00e949f84df949fbe32c254b592a580b4623811.tar.gz
spark-f00e949f84df949fbe32c254b592a580b4623811.tar.bz2
spark-f00e949f84df949fbe32c254b592a580b4623811.zip
Added Java unit test, data, and main method for Naive Bayes
Also fixes mains of a few other algorithms to print the final model
Diffstat (limited to 'mllib/src/test')
-rw-r--r--mllib/src/test/java/org/apache/spark/mllib/classification/JavaNaiveBayesSuite.java72
1 files changed, 72 insertions, 0 deletions
diff --git a/mllib/src/test/java/org/apache/spark/mllib/classification/JavaNaiveBayesSuite.java b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaNaiveBayesSuite.java
new file mode 100644
index 0000000000..23ea3548b9
--- /dev/null
+++ b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaNaiveBayesSuite.java
@@ -0,0 +1,72 @@
+package org.apache.spark.mllib.classification;
+
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.mllib.regression.LabeledPoint;
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
+
+import java.io.Serializable;
+import java.util.Arrays;
+import java.util.List;
+
+public class JavaNaiveBayesSuite implements Serializable {
+ private transient JavaSparkContext sc;
+
+ @Before
+ public void setUp() {
+ sc = new JavaSparkContext("local", "JavaNaiveBayesSuite");
+ }
+
+ @After
+ public void tearDown() {
+ sc.stop();
+ sc = null;
+ System.clearProperty("spark.driver.port");
+ }
+
+ private static final List<LabeledPoint> POINTS = Arrays.asList(
+ new LabeledPoint(0, new double[] {1.0, 0.0, 0.0}),
+ new LabeledPoint(0, new double[] {2.0, 0.0, 0.0}),
+ new LabeledPoint(1, new double[] {0.0, 1.0, 0.0}),
+ new LabeledPoint(1, new double[] {0.0, 2.0, 0.0}),
+ new LabeledPoint(2, new double[] {0.0, 0.0, 1.0}),
+ new LabeledPoint(2, new double[] {0.0, 0.0, 2.0})
+ );
+
+ private int validatePrediction(List<LabeledPoint> points, NaiveBayesModel model) {
+ int correct = 0;
+ for (LabeledPoint p: points) {
+ if (model.predict(p.features()) == p.label()) {
+ correct += 1;
+ }
+ }
+ return correct;
+ }
+
+ @Test
+ public void runUsingConstructor() {
+ JavaRDD<LabeledPoint> testRDD = sc.parallelize(POINTS, 2).cache();
+
+ NaiveBayes nb = new NaiveBayes().setLambda(1.0);
+ NaiveBayesModel model = nb.run(testRDD.rdd());
+
+ int numAccurate = validatePrediction(POINTS, model);
+ Assert.assertEquals(POINTS.size(), numAccurate);
+ }
+
+ @Test
+ public void runUsingStaticMethods() {
+ JavaRDD<LabeledPoint> testRDD = sc.parallelize(POINTS, 2).cache();
+
+ NaiveBayesModel model1 = NaiveBayes.train(testRDD.rdd());
+ int numAccurate1 = validatePrediction(POINTS, model1);
+ Assert.assertEquals(POINTS.size(), numAccurate1);
+
+ NaiveBayesModel model2 = NaiveBayes.train(testRDD.rdd(), 0.5);
+ int numAccurate2 = validatePrediction(POINTS, model2);
+ Assert.assertEquals(POINTS.size(), numAccurate2);
+ }
+}