diff options
author | Matei Zaharia <matei@databricks.com> | 2014-01-10 16:08:35 -0800 |
---|---|---|
committer | Matei Zaharia <matei@databricks.com> | 2014-01-11 22:30:48 -0800 |
commit | f00e949f84df949fbe32c254b592a580b4623811 (patch) | |
tree | 6d1562dedcd649573ab5603cdb70d92123d11808 /mllib/src/test/java | |
parent | 4c28a2bad8a6d64ee69213eede440837636fe58b (diff) | |
download | spark-f00e949f84df949fbe32c254b592a580b4623811.tar.gz spark-f00e949f84df949fbe32c254b592a580b4623811.tar.bz2 spark-f00e949f84df949fbe32c254b592a580b4623811.zip |
Added Java unit test, data, and main method for Naive Bayes
Also fixes mains of a few other algorithms to print the final model
Diffstat (limited to 'mllib/src/test/java')
-rw-r--r-- | mllib/src/test/java/org/apache/spark/mllib/classification/JavaNaiveBayesSuite.java | 72 |
1 files changed, 72 insertions, 0 deletions
diff --git a/mllib/src/test/java/org/apache/spark/mllib/classification/JavaNaiveBayesSuite.java b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaNaiveBayesSuite.java new file mode 100644 index 0000000000..23ea3548b9 --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaNaiveBayesSuite.java @@ -0,0 +1,72 @@ +package org.apache.spark.mllib.classification; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import java.io.Serializable; +import java.util.Arrays; +import java.util.List; + +public class JavaNaiveBayesSuite implements Serializable { + private transient JavaSparkContext sc; + + @Before + public void setUp() { + sc = new JavaSparkContext("local", "JavaNaiveBayesSuite"); + } + + @After + public void tearDown() { + sc.stop(); + sc = null; + System.clearProperty("spark.driver.port"); + } + + private static final List<LabeledPoint> POINTS = Arrays.asList( + new LabeledPoint(0, new double[] {1.0, 0.0, 0.0}), + new LabeledPoint(0, new double[] {2.0, 0.0, 0.0}), + new LabeledPoint(1, new double[] {0.0, 1.0, 0.0}), + new LabeledPoint(1, new double[] {0.0, 2.0, 0.0}), + new LabeledPoint(2, new double[] {0.0, 0.0, 1.0}), + new LabeledPoint(2, new double[] {0.0, 0.0, 2.0}) + ); + + private int validatePrediction(List<LabeledPoint> points, NaiveBayesModel model) { + int correct = 0; + for (LabeledPoint p: points) { + if (model.predict(p.features()) == p.label()) { + correct += 1; + } + } + return correct; + } + + @Test + public void runUsingConstructor() { + JavaRDD<LabeledPoint> testRDD = sc.parallelize(POINTS, 2).cache(); + + NaiveBayes nb = new NaiveBayes().setLambda(1.0); + NaiveBayesModel model = nb.run(testRDD.rdd()); + + int numAccurate = validatePrediction(POINTS, model); + Assert.assertEquals(POINTS.size(), numAccurate); + } + + @Test + public void runUsingStaticMethods() { + JavaRDD<LabeledPoint> testRDD = sc.parallelize(POINTS, 2).cache(); + + NaiveBayesModel model1 = NaiveBayes.train(testRDD.rdd()); + int numAccurate1 = validatePrediction(POINTS, model1); + Assert.assertEquals(POINTS.size(), numAccurate1); + + NaiveBayesModel model2 = NaiveBayes.train(testRDD.rdd(), 0.5); + int numAccurate2 = validatePrediction(POINTS, model2); + Assert.assertEquals(POINTS.size(), numAccurate2); + } +} |