From f00e949f84df949fbe32c254b592a580b4623811 Mon Sep 17 00:00:00 2001
From: Matei Zaharia
Date: Fri, 10 Jan 2014 16:08:35 -0800
Subject: Added Java unit test, data, and main method for Naive Bayes

Also fixes mains of a few other algorithms to print the final model
---
 mllib/data/sample_naive_bayes_data.txt             |  6 ++
 .../mllib/classification/LogisticRegression.scala  |  4 +-
 .../spark/mllib/classification/NaiveBayes.scala    | 21 ++++++-
 .../apache/spark/mllib/classification/SVM.scala    |  2 +
 .../org/apache/spark/mllib/regression/Lasso.scala  |  4 +-
 .../spark/mllib/regression/LinearRegression.scala  |  2 +
 .../spark/mllib/regression/RidgeRegression.scala   |  4 +-
 .../mllib/classification/JavaNaiveBayesSuite.java  | 72 ++++++++++++++++++++++
 8 files changed, 111 insertions(+), 4 deletions(-)
 create mode 100644 mllib/data/sample_naive_bayes_data.txt
 create mode 100644 mllib/src/test/java/org/apache/spark/mllib/classification/JavaNaiveBayesSuite.java

diff --git a/mllib/data/sample_naive_bayes_data.txt b/mllib/data/sample_naive_bayes_data.txt
new file mode 100644
index 0000000000..f874adbaf4
--- /dev/null
+++ b/mllib/data/sample_naive_bayes_data.txt
@@ -0,0 +1,6 @@
+0, 1 0 0
+0, 2 0 0
+1, 0 1 0
+1, 0 2 0
+2, 0 0 1
+2, 0 0 2
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala
index 50aede9c07..a481f52276 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala
@@ -97,7 +97,7 @@ object LogisticRegressionWithSGD {
    * @param numIterations Number of iterations of gradient descent to run.
    * @param stepSize Step size to be used for each iteration of gradient descent.
    * @param miniBatchFraction Fraction of data to be used per iteration.
-   * @param initialWeights Initial set of weights to be used. Array should be equal in size to 
+   * @param initialWeights Initial set of weights to be used. Array should be equal in size to
    *        the number of features in the data.
    */
   def train(
@@ -183,6 +183,8 @@ object LogisticRegressionWithSGD {
     val sc = new SparkContext(args(0), "LogisticRegression")
     val data = MLUtils.loadLabeledData(sc, args(1))
     val model = LogisticRegressionWithSGD.train(data, args(3).toInt, args(2).toDouble)
+    println("Weights: " + model.weights.mkString("[", ", ", "]"))
+    println("Intercept: " + model.intercept)

     sc.stop()
   }
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
index f45802cd0b..6539b2f339 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
@@ -21,9 +21,10 @@ import scala.collection.mutable

 import org.jblas.DoubleMatrix

-import org.apache.spark.Logging
+import org.apache.spark.{SparkContext, Logging}
 import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.rdd.RDD
+import org.apache.spark.mllib.util.MLUtils

 /**
  * Model for Naive Bayes Classifiers.
@@ -144,4 +145,22 @@ object NaiveBayes {
   def train(input: RDD[LabeledPoint], lambda: Double): NaiveBayesModel = {
     new NaiveBayes(lambda).run(input)
   }
+
+  def main(args: Array[String]) {
+    if (args.length != 2 && args.length != 3) {
+      println("Usage: NaiveBayes <master> <input_dir> [<lambda>]")
+      System.exit(1)
+    }
+    val sc = new SparkContext(args(0), "NaiveBayes")
+    val data = MLUtils.loadLabeledData(sc, args(1))
+    val model = if (args.length == 2) {
+      NaiveBayes.train(data)
+    } else {
+      NaiveBayes.train(data, args(2).toDouble)
+    }
+    println("Pi: " + model.pi.mkString("[", ", ", "]"))
+    println("Theta:\n" + model.theta.map(_.mkString("[", ", ", "]")).mkString("[", "\n ", "]"))
+
+    sc.stop()
+  }
 }
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala
index 3b8f8550d0..f2964ea446 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala
@@ -183,6 +183,8 @@ object SVMWithSGD {
     val sc = new SparkContext(args(0), "SVM")
     val data = MLUtils.loadLabeledData(sc, args(1))
     val model = SVMWithSGD.train(data, args(4).toInt, args(2).toDouble, args(3).toDouble)
+    println("Weights: " + model.weights.mkString("[", ", ", "]"))
+    println("Intercept: " + model.intercept)

     sc.stop()
   }
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala
index d959695325..7c41793722 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala
@@ -121,7 +121,7 @@ object LassoWithSGD {
    * @param stepSize Step size to be used for each iteration of gradient descent.
    * @param regParam Regularization parameter.
    * @param miniBatchFraction Fraction of data to be used per iteration.
-   * @param initialWeights Initial set of weights to be used. Array should be equal in size to 
+   * @param initialWeights Initial set of weights to be used. Array should be equal in size to
    *        the number of features in the data.
    */
   def train(
@@ -205,6 +205,8 @@ object LassoWithSGD {
     val sc = new SparkContext(args(0), "Lasso")
     val data = MLUtils.loadLabeledData(sc, args(1))
     val model = LassoWithSGD.train(data, args(4).toInt, args(2).toDouble, args(3).toDouble)
+    println("Weights: " + model.weights.mkString("[", ", ", "]"))
+    println("Intercept: " + model.intercept)

     sc.stop()
   }
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala
index 597d55e0bb..fe5cce064b 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala
@@ -162,6 +162,8 @@ object LinearRegressionWithSGD {
     val sc = new SparkContext(args(0), "LinearRegression")
     val data = MLUtils.loadLabeledData(sc, args(1))
     val model = LinearRegressionWithSGD.train(data, args(3).toInt, args(2).toDouble)
+    println("Weights: " + model.weights.mkString("[", ", ", "]"))
+    println("Intercept: " + model.intercept)

     sc.stop()
   }
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala
index b29508d2b9..c125c6797a 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala
@@ -122,7 +122,7 @@ object RidgeRegressionWithSGD {
    * @param stepSize Step size to be used for each iteration of gradient descent.
    * @param regParam Regularization parameter.
    * @param miniBatchFraction Fraction of data to be used per iteration.
-   * @param initialWeights Initial set of weights to be used. Array should be equal in size to 
+   * @param initialWeights Initial set of weights to be used. Array should be equal in size to
    *        the number of features in the data.
    */
   def train(
@@ -208,6 +208,8 @@ object RidgeRegressionWithSGD {
     val data = MLUtils.loadLabeledData(sc, args(1))
     val model = RidgeRegressionWithSGD.train(data, args(4).toInt, args(2).toDouble,
       args(3).toDouble)
+    println("Weights: " + model.weights.mkString("[", ", ", "]"))
+    println("Intercept: " + model.intercept)

     sc.stop()
   }
diff --git a/mllib/src/test/java/org/apache/spark/mllib/classification/JavaNaiveBayesSuite.java b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaNaiveBayesSuite.java
new file mode 100644
index 0000000000..23ea3548b9
--- /dev/null
+++ b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaNaiveBayesSuite.java
@@ -0,0 +1,72 @@
+package org.apache.spark.mllib.classification;
+
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.mllib.regression.LabeledPoint;
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
+
+import java.io.Serializable;
+import java.util.Arrays;
+import java.util.List;
+
+public class JavaNaiveBayesSuite implements Serializable {
+  private transient JavaSparkContext sc;
+
+  @Before
+  public void setUp() {
+    sc = new JavaSparkContext("local", "JavaNaiveBayesSuite");
+  }
+
+  @After
+  public void tearDown() {
+    sc.stop();
+    sc = null;
+    System.clearProperty("spark.driver.port");
+  }
+
+  private static final List<LabeledPoint> POINTS = Arrays.asList(
+    new LabeledPoint(0, new double[] {1.0, 0.0, 0.0}),
+    new LabeledPoint(0, new double[] {2.0, 0.0, 0.0}),
+    new LabeledPoint(1, new double[] {0.0, 1.0, 0.0}),
+    new LabeledPoint(1, new double[] {0.0, 2.0, 0.0}),
+    new LabeledPoint(2, new double[] {0.0, 0.0, 1.0}),
+    new LabeledPoint(2, new double[] {0.0, 0.0, 2.0})
+  );
+
+  private int validatePrediction(List<LabeledPoint> points, NaiveBayesModel model) {
+    int correct = 0;
+    for (LabeledPoint p: points) {
+      if (model.predict(p.features()) == p.label()) {
+        correct += 1;
+      }
+    }
+    return correct;
+  }
+
+  @Test
+  public void runUsingConstructor() {
+    JavaRDD<LabeledPoint> testRDD = sc.parallelize(POINTS, 2).cache();
+
+    NaiveBayes nb = new NaiveBayes().setLambda(1.0);
+    NaiveBayesModel model = nb.run(testRDD.rdd());
+
+    int numAccurate = validatePrediction(POINTS, model);
+    Assert.assertEquals(POINTS.size(), numAccurate);
+  }
+
+  @Test
+  public void runUsingStaticMethods() {
+    JavaRDD<LabeledPoint> testRDD = sc.parallelize(POINTS, 2).cache();
+
+    NaiveBayesModel model1 = NaiveBayes.train(testRDD.rdd());
+    int numAccurate1 = validatePrediction(POINTS, model1);
+    Assert.assertEquals(POINTS.size(), numAccurate1);
+
+    NaiveBayesModel model2 = NaiveBayes.train(testRDD.rdd(), 0.5);
+    int numAccurate2 = validatePrediction(POINTS, model2);
+    Assert.assertEquals(POINTS.size(), numAccurate2);
+  }
+}
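
Note: as a quick sanity check of the API this patch adds, here is a minimal Scala sketch of
driving the new NaiveBayes code from a standalone program rather than through the main method
above. It only uses calls that appear in the diff (MLUtils.loadLabeledData, NaiveBayes.train,
model.pi); the object name NaiveBayesExample, the "local" master, and the relative path to the
sample data file are assumptions for illustration.

    import org.apache.spark.SparkContext
    import org.apache.spark.mllib.classification.NaiveBayes
    import org.apache.spark.mllib.util.MLUtils

    // Hypothetical driver object, for illustration only.
    object NaiveBayesExample {
      def main(args: Array[String]) {
        // Assumed local master and data path for this sketch.
        val sc = new SparkContext("local", "NaiveBayesExample")
        // loadLabeledData parses lines like "0, 1 0 0" into LabeledPoint(label, features).
        val data = MLUtils.loadLabeledData(sc, "mllib/data/sample_naive_bayes_data.txt")
        // lambda is the additive smoothing parameter; the no-argument train(data)
        // overload in the diff uses a default instead.
        val model = NaiveBayes.train(data, 1.0)
        // pi and theta are the learned model parameters, printed the same way
        // as in the new main method.
        println("Pi: " + model.pi.mkString("[", ", ", "]"))
        sc.stop()
      }
    }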