-rw-r--r--  mllib/data/sample_naive_bayes_data.txt                                               |  6
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala  |  4
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala          | 21
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala                 |  2
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala                   |  4
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala        |  2
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala         |  4
-rw-r--r--  mllib/src/test/java/org/apache/spark/mllib/classification/JavaNaiveBayesSuite.java   | 72
8 files changed, 111 insertions(+), 4 deletions(-)
diff --git a/mllib/data/sample_naive_bayes_data.txt b/mllib/data/sample_naive_bayes_data.txt
new file mode 100644
index 0000000000..f874adbaf4
--- /dev/null
+++ b/mllib/data/sample_naive_bayes_data.txt
@@ -0,0 +1,6 @@
+0, 1 0 0
+0, 2 0 0
+1, 0 1 0
+1, 0 2 0
+2, 0 0 1
+2, 0 0 2
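
The new sample file is in the plain-text labeled-point format read by MLUtils.loadLabeledData: a numeric label, a comma, then space-separated feature values. A minimal parsing sketch for a single such line (illustration only, not the MLUtils implementation):

import org.apache.spark.mllib.regression.LabeledPoint

// Sketch: parse a line such as "0, 1 0 0" into a LabeledPoint (label before the comma,
// space-separated feature values after it). The real parsing lives in MLUtils.loadLabeledData.
def parseSampleLine(line: String): LabeledPoint = {
  val Array(labelPart, featurePart) = line.split(',')
  val label = labelPart.trim.toDouble
  val features = featurePart.trim.split(' ').map(_.toDouble)
  LabeledPoint(label, features)
}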
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala
index 50aede9c07..a481f52276 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala
@@ -97,7 +97,7 @@ object LogisticRegressionWithSGD {
* @param numIterations Number of iterations of gradient descent to run.
* @param stepSize Step size to be used for each iteration of gradient descent.
* @param miniBatchFraction Fraction of data to be used per iteration.
- * @param initialWeights Initial set of weights to be used. Array should be equal in size to
+ * @param initialWeights Initial set of weights to be used. Array should be equal in size to
* the number of features in the data.
*/
def train(
@@ -183,6 +183,8 @@ object LogisticRegressionWithSGD {
val sc = new SparkContext(args(0), "LogisticRegression")
val data = MLUtils.loadLabeledData(sc, args(1))
val model = LogisticRegressionWithSGD.train(data, args(3).toInt, args(2).toDouble)
+ println("Weights: " + model.weights.mkString("[", ", ", "]"))
+ println("Intercept: " + model.intercept)
sc.stop()
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
index f45802cd0b..6539b2f339 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
@@ -21,9 +21,10 @@ import scala.collection.mutable
import org.jblas.DoubleMatrix
-import org.apache.spark.Logging
+import org.apache.spark.{SparkContext, Logging}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.rdd.RDD
+import org.apache.spark.mllib.util.MLUtils
/**
* Model for Naive Bayes Classifiers.
@@ -144,4 +145,22 @@ object NaiveBayes {
def train(input: RDD[LabeledPoint], lambda: Double): NaiveBayesModel = {
new NaiveBayes(lambda).run(input)
}
+
+ def main(args: Array[String]) {
+ if (args.length != 2 && args.length != 3) {
+ println("Usage: NaiveBayes <master> <input_dir> [<lambda>]")
+ System.exit(1)
+ }
+ val sc = new SparkContext(args(0), "NaiveBayes")
+ val data = MLUtils.loadLabeledData(sc, args(1))
+ val model = if (args.length == 2) {
+ NaiveBayes.train(data)
+ } else {
+ NaiveBayes.train(data, args(2).toDouble)
+ }
+ println("Pi: " + model.pi.mkString("[", ", ", "]"))
+ println("Theta:\n" + model.theta.map(_.mkString("[", ", ", "]")).mkString("[", "\n ", "]"))
+
+ sc.stop()
+ }
}
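
The new main method makes NaiveBayes runnable from the command line with a master URL, an input directory, and an optional lambda. The same flow can be reproduced programmatically, for example from the Spark shell; a minimal sketch assuming a local master and the sample file added above:

import org.apache.spark.SparkContext
import org.apache.spark.mllib.classification.NaiveBayes
import org.apache.spark.mllib.util.MLUtils

// Sketch of the programmatic equivalent of the new main(); the "local" master and the
// data path are assumptions for illustration.
val sc = new SparkContext("local", "NaiveBayesExample")
val data = MLUtils.loadLabeledData(sc, "mllib/data/sample_naive_bayes_data.txt")
val model = NaiveBayes.train(data, 1.0)
println("Pi: " + model.pi.mkString("[", ", ", "]"))
println("Prediction for [0, 1, 0]: " + model.predict(Array(0.0, 1.0, 0.0)))
sc.stop()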
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala
index 3b8f8550d0..f2964ea446 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala
@@ -183,6 +183,8 @@ object SVMWithSGD {
val sc = new SparkContext(args(0), "SVM")
val data = MLUtils.loadLabeledData(sc, args(1))
val model = SVMWithSGD.train(data, args(4).toInt, args(2).toDouble, args(3).toDouble)
+ println("Weights: " + model.weights.mkString("[", ", ", "]"))
+ println("Intercept: " + model.intercept)
sc.stop()
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala
index d959695325..7c41793722 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala
@@ -121,7 +121,7 @@ object LassoWithSGD {
* @param stepSize Step size to be used for each iteration of gradient descent.
* @param regParam Regularization parameter.
* @param miniBatchFraction Fraction of data to be used per iteration.
- * @param initialWeights Initial set of weights to be used. Array should be equal in size to
+ * @param initialWeights Initial set of weights to be used. Array should be equal in size to
* the number of features in the data.
*/
def train(
@@ -205,6 +205,8 @@ object LassoWithSGD {
val sc = new SparkContext(args(0), "Lasso")
val data = MLUtils.loadLabeledData(sc, args(1))
val model = LassoWithSGD.train(data, args(4).toInt, args(2).toDouble, args(3).toDouble)
+ println("Weights: " + model.weights.mkString("[", ", ", "]"))
+ println("Intercept: " + model.intercept)
sc.stop()
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala
index 597d55e0bb..fe5cce064b 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala
@@ -162,6 +162,8 @@ object LinearRegressionWithSGD {
val sc = new SparkContext(args(0), "LinearRegression")
val data = MLUtils.loadLabeledData(sc, args(1))
val model = LinearRegressionWithSGD.train(data, args(3).toInt, args(2).toDouble)
+ println("Weights: " + model.weights.mkString("[", ", ", "]"))
+ println("Intercept: " + model.intercept)
sc.stop()
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala
index b29508d2b9..c125c6797a 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala
@@ -122,7 +122,7 @@ object RidgeRegressionWithSGD {
* @param stepSize Step size to be used for each iteration of gradient descent.
* @param regParam Regularization parameter.
* @param miniBatchFraction Fraction of data to be used per iteration.
- * @param initialWeights Initial set of weights to be used. Array should be equal in size to
+ * @param initialWeights Initial set of weights to be used. Array should be equal in size to
* the number of features in the data.
*/
def train(
@@ -208,6 +208,8 @@ object RidgeRegressionWithSGD {
val data = MLUtils.loadLabeledData(sc, args(1))
val model = RidgeRegressionWithSGD.train(data, args(4).toInt, args(2).toDouble,
args(3).toDouble)
+ println("Weights: " + model.weights.mkString("[", ", ", "]"))
+ println("Intercept: " + model.intercept)
sc.stop()
}
diff --git a/mllib/src/test/java/org/apache/spark/mllib/classification/JavaNaiveBayesSuite.java b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaNaiveBayesSuite.java
new file mode 100644
index 0000000000..23ea3548b9
--- /dev/null
+++ b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaNaiveBayesSuite.java
@@ -0,0 +1,72 @@
+package org.apache.spark.mllib.classification;
+
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.mllib.regression.LabeledPoint;
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
+
+import java.io.Serializable;
+import java.util.Arrays;
+import java.util.List;
+
+public class JavaNaiveBayesSuite implements Serializable {
+ private transient JavaSparkContext sc;
+
+ @Before
+ public void setUp() {
+ sc = new JavaSparkContext("local", "JavaNaiveBayesSuite");
+ }
+
+ @After
+ public void tearDown() {
+ sc.stop();
+ sc = null;
+ System.clearProperty("spark.driver.port");
+ }
+
+ private static final List<LabeledPoint> POINTS = Arrays.asList(
+ new LabeledPoint(0, new double[] {1.0, 0.0, 0.0}),
+ new LabeledPoint(0, new double[] {2.0, 0.0, 0.0}),
+ new LabeledPoint(1, new double[] {0.0, 1.0, 0.0}),
+ new LabeledPoint(1, new double[] {0.0, 2.0, 0.0}),
+ new LabeledPoint(2, new double[] {0.0, 0.0, 1.0}),
+ new LabeledPoint(2, new double[] {0.0, 0.0, 2.0})
+ );
+
+ private int validatePrediction(List<LabeledPoint> points, NaiveBayesModel model) {
+ int correct = 0;
+ for (LabeledPoint p: points) {
+ if (model.predict(p.features()) == p.label()) {
+ correct += 1;
+ }
+ }
+ return correct;
+ }
+
+ @Test
+ public void runUsingConstructor() {
+ JavaRDD<LabeledPoint> testRDD = sc.parallelize(POINTS, 2).cache();
+
+ NaiveBayes nb = new NaiveBayes().setLambda(1.0);
+ NaiveBayesModel model = nb.run(testRDD.rdd());
+
+ int numAccurate = validatePrediction(POINTS, model);
+ Assert.assertEquals(POINTS.size(), numAccurate);
+ }
+
+ @Test
+ public void runUsingStaticMethods() {
+ JavaRDD<LabeledPoint> testRDD = sc.parallelize(POINTS, 2).cache();
+
+ NaiveBayesModel model1 = NaiveBayes.train(testRDD.rdd());
+ int numAccurate1 = validatePrediction(POINTS, model1);
+ Assert.assertEquals(POINTS.size(), numAccurate1);
+
+ NaiveBayesModel model2 = NaiveBayes.train(testRDD.rdd(), 0.5);
+ int numAccurate2 = validatePrediction(POINTS, model2);
+ Assert.assertEquals(POINTS.size(), numAccurate2);
+ }
+}
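
For comparison, the two entry points exercised by the Java suite look like this from Scala (a sketch assuming an existing data: RDD[LabeledPoint]):

// Builder-style entry point, as in runUsingConstructor.
val modelA = new NaiveBayes().setLambda(1.0).run(data)

// Static entry points, as in runUsingStaticMethods.
val modelB = NaiveBayes.train(data)        // default lambda
val modelC = NaiveBayes.train(data, 0.5)   // explicit smoothing parameter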