author     Reza Zadeh <rizlar@gmail.com>    2014-01-13 23:52:34 -0800
committer  Reza Zadeh <rizlar@gmail.com>    2014-01-13 23:52:34 -0800
commit     845e568fada0550e632e7381748c5a9ebbe53e16 (patch)
tree       3a4fa34894df649b5ef429cd794b73cf4b3e99b1 /mllib
parent     f324d5355514b1c7ae85019b476046bb64b5593e (diff)
parent     fdaabdc67387524ffb84354f87985f48bd31cf60 (diff)
Merge remote-tracking branch 'upstream/master' into sparsesvd
Diffstat (limited to 'mllib')
-rw-r--r--  mllib/data/sample_naive_bayes_data.txt                                                6
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala          46
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala   4
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala          65
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala                  2
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala                 15
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/regression/LabeledPoint.scala             6
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala                    4
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala         2
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala          4
-rw-r--r--  mllib/src/test/java/org/apache/spark/mllib/classification/JavaNaiveBayesSuite.java   72
11 files changed, 191 insertions, 35 deletions
diff --git a/mllib/data/sample_naive_bayes_data.txt b/mllib/data/sample_naive_bayes_data.txt
new file mode 100644
index 0000000000..f874adbaf4
--- /dev/null
+++ b/mllib/data/sample_naive_bayes_data.txt
@@ -0,0 +1,6 @@
+0, 1 0 0
+0, 2 0 0
+1, 0 1 0
+1, 0 2 0
+2, 0 0 1
+2, 0 0 2
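
The new sample file follows the plain-text labeled-data layout that MLUtils.loadLabeledData consumes: a numeric label, a comma, then space-separated feature values. A minimal parsing sketch for illustration only (parseLine is a hypothetical helper; the real loader is org.apache.spark.mllib.util.MLUtils.loadLabeledData):

    import org.apache.spark.mllib.regression.LabeledPoint

    // Hypothetical one-line parser for the "label, f1 f2 f3" layout shown above.
    def parseLine(line: String): LabeledPoint = {
      val Array(label, features) = line.split(",").map(_.trim)
      LabeledPoint(label.toDouble, features.split(' ').map(_.toDouble))
    }

    // parseLine("0, 1 0 0") == LabeledPoint(0.0, [1.0, 0.0, 0.0])
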
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
index 2d8623392e..3fec1a909d 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
@@ -48,7 +48,7 @@ class PythonMLLibAPI extends Serializable {
val db = bb.asDoubleBuffer()
val ans = new Array[Double](length.toInt)
db.get(ans)
- return ans
+ ans
}
private def serializeDoubleVector(doubles: Array[Double]): Array[Byte] = {
@@ -60,7 +60,7 @@ class PythonMLLibAPI extends Serializable {
bb.putLong(len)
val db = bb.asDoubleBuffer()
db.put(doubles)
- return bytes
+ bytes
}
private def deserializeDoubleMatrix(bytes: Array[Byte]): Array[Array[Double]] = {
@@ -86,7 +86,7 @@ class PythonMLLibAPI extends Serializable {
ans(i) = new Array[Double](cols.toInt)
db.get(ans(i))
}
- return ans
+ ans
}
private def serializeDoubleMatrix(doubles: Array[Array[Double]]): Array[Byte] = {
@@ -102,11 +102,10 @@ class PythonMLLibAPI extends Serializable {
bb.putLong(rows)
bb.putLong(cols)
val db = bb.asDoubleBuffer()
- var i = 0
for (i <- 0 until rows) {
db.put(doubles(i))
}
- return bytes
+ bytes
}
private def trainRegressionModel(trainFunc: (RDD[LabeledPoint], Array[Double]) => GeneralizedLinearModel,
@@ -121,7 +120,7 @@ class PythonMLLibAPI extends Serializable {
val ret = new java.util.LinkedList[java.lang.Object]()
ret.add(serializeDoubleVector(model.weights))
ret.add(model.intercept: java.lang.Double)
- return ret
+ ret
}
/**
@@ -130,7 +129,7 @@ class PythonMLLibAPI extends Serializable {
def trainLinearRegressionModelWithSGD(dataBytesJRDD: JavaRDD[Array[Byte]],
numIterations: Int, stepSize: Double, miniBatchFraction: Double,
initialWeightsBA: Array[Byte]): java.util.List[java.lang.Object] = {
- return trainRegressionModel((data, initialWeights) =>
+ trainRegressionModel((data, initialWeights) =>
LinearRegressionWithSGD.train(data, numIterations, stepSize,
miniBatchFraction, initialWeights),
dataBytesJRDD, initialWeightsBA)
@@ -142,7 +141,7 @@ class PythonMLLibAPI extends Serializable {
def trainLassoModelWithSGD(dataBytesJRDD: JavaRDD[Array[Byte]], numIterations: Int,
stepSize: Double, regParam: Double, miniBatchFraction: Double,
initialWeightsBA: Array[Byte]): java.util.List[java.lang.Object] = {
- return trainRegressionModel((data, initialWeights) =>
+ trainRegressionModel((data, initialWeights) =>
LassoWithSGD.train(data, numIterations, stepSize, regParam,
miniBatchFraction, initialWeights),
dataBytesJRDD, initialWeightsBA)
@@ -154,7 +153,7 @@ class PythonMLLibAPI extends Serializable {
def trainRidgeModelWithSGD(dataBytesJRDD: JavaRDD[Array[Byte]], numIterations: Int,
stepSize: Double, regParam: Double, miniBatchFraction: Double,
initialWeightsBA: Array[Byte]): java.util.List[java.lang.Object] = {
- return trainRegressionModel((data, initialWeights) =>
+ trainRegressionModel((data, initialWeights) =>
RidgeRegressionWithSGD.train(data, numIterations, stepSize, regParam,
miniBatchFraction, initialWeights),
dataBytesJRDD, initialWeightsBA)
@@ -166,7 +165,7 @@ class PythonMLLibAPI extends Serializable {
def trainSVMModelWithSGD(dataBytesJRDD: JavaRDD[Array[Byte]], numIterations: Int,
stepSize: Double, regParam: Double, miniBatchFraction: Double,
initialWeightsBA: Array[Byte]): java.util.List[java.lang.Object] = {
- return trainRegressionModel((data, initialWeights) =>
+ trainRegressionModel((data, initialWeights) =>
SVMWithSGD.train(data, numIterations, stepSize, regParam,
miniBatchFraction, initialWeights),
dataBytesJRDD, initialWeightsBA)
@@ -178,13 +177,30 @@ class PythonMLLibAPI extends Serializable {
def trainLogisticRegressionModelWithSGD(dataBytesJRDD: JavaRDD[Array[Byte]],
numIterations: Int, stepSize: Double, miniBatchFraction: Double,
initialWeightsBA: Array[Byte]): java.util.List[java.lang.Object] = {
- return trainRegressionModel((data, initialWeights) =>
+ trainRegressionModel((data, initialWeights) =>
LogisticRegressionWithSGD.train(data, numIterations, stepSize,
miniBatchFraction, initialWeights),
dataBytesJRDD, initialWeightsBA)
}
/**
+ * Java stub for NaiveBayes.train()
+ */
+ def trainNaiveBayes(dataBytesJRDD: JavaRDD[Array[Byte]], lambda: Double)
+ : java.util.List[java.lang.Object] =
+ {
+ val data = dataBytesJRDD.rdd.map(xBytes => {
+ val x = deserializeDoubleVector(xBytes)
+ LabeledPoint(x(0), x.slice(1, x.length))
+ })
+ val model = NaiveBayes.train(data, lambda)
+ val ret = new java.util.LinkedList[java.lang.Object]()
+ ret.add(serializeDoubleVector(model.pi))
+ ret.add(serializeDoubleMatrix(model.theta))
+ ret
+ }
+
+ /**
* Java stub for Python mllib KMeans.train()
*/
def trainKMeansModel(dataBytesJRDD: JavaRDD[Array[Byte]], k: Int,
@@ -194,7 +210,7 @@ class PythonMLLibAPI extends Serializable {
val model = KMeans.train(data, k, maxIterations, runs, initializationMode)
val ret = new java.util.LinkedList[java.lang.Object]()
ret.add(serializeDoubleMatrix(model.clusterCenters))
- return ret
+ ret
}
/** Unpack a Rating object from an array of bytes */
@@ -204,7 +220,7 @@ class PythonMLLibAPI extends Serializable {
val user = bb.getInt()
val product = bb.getInt()
val rating = bb.getDouble()
- return new Rating(user, product, rating)
+ new Rating(user, product, rating)
}
/** Unpack a tuple of Ints from an array of bytes */
@@ -245,7 +261,7 @@ class PythonMLLibAPI extends Serializable {
def trainALSModel(ratingsBytesJRDD: JavaRDD[Array[Byte]], rank: Int,
iterations: Int, lambda: Double, blocks: Int): MatrixFactorizationModel = {
val ratings = ratingsBytesJRDD.rdd.map(unpackRating)
- return ALS.train(ratings, rank, iterations, lambda, blocks)
+ ALS.train(ratings, rank, iterations, lambda, blocks)
}
/**
@@ -257,6 +273,6 @@ class PythonMLLibAPI extends Serializable {
def trainImplicitALSModel(ratingsBytesJRDD: JavaRDD[Array[Byte]], rank: Int,
iterations: Int, lambda: Double, blocks: Int, alpha: Double): MatrixFactorizationModel = {
val ratings = ratingsBytesJRDD.rdd.map(unpackRating)
- return ALS.trainImplicit(ratings, rank, iterations, lambda, blocks, alpha)
+ ALS.trainImplicit(ratings, rank, iterations, lambda, blocks, alpha)
}
}
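
The new trainNaiveBayes stub keeps the convention used throughout this file: each incoming byte array decodes to a double vector whose first element is the label and whose remaining elements are the features, and the reply is a list holding the serialized pi vector and theta matrix. A sketch of how one such record could be built, assuming the layout implied by serializeDoubleVector above (an 8-byte element count followed by the doubles; the byte-order handling in the real helpers is not reproduced here):

    import java.nio.ByteBuffer

    // Illustrative encoder for one labeled point, label first, features after.
    // encodeLabeledPoint is a hypothetical helper, not part of the patch.
    def encodeLabeledPoint(label: Double, features: Array[Double]): Array[Byte] = {
      val values = label +: features
      val bytes = new Array[Byte](8 + 8 * values.length)
      val bb = ByteBuffer.wrap(bytes)          // byte order left at the JVM default here
      bb.putLong(values.length.toLong)
      values.foreach(v => bb.putDouble(v))
      bytes
    }
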
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala
index 50aede9c07..a481f52276 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala
@@ -97,7 +97,7 @@ object LogisticRegressionWithSGD {
* @param numIterations Number of iterations of gradient descent to run.
* @param stepSize Step size to be used for each iteration of gradient descent.
* @param miniBatchFraction Fraction of data to be used per iteration.
- * @param initialWeights Initial set of weights to be used. Array should be equal in size to
+ * @param initialWeights Initial set of weights to be used. Array should be equal in size to
* the number of features in the data.
*/
def train(
@@ -183,6 +183,8 @@ object LogisticRegressionWithSGD {
val sc = new SparkContext(args(0), "LogisticRegression")
val data = MLUtils.loadLabeledData(sc, args(1))
val model = LogisticRegressionWithSGD.train(data, args(3).toInt, args(2).toDouble)
+ println("Weights: " + model.weights.mkString("[", ", ", "]"))
+ println("Intercept: " + model.intercept)
sc.stop()
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
index 524300d6ae..6539b2f339 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
@@ -21,17 +21,18 @@ import scala.collection.mutable
import org.jblas.DoubleMatrix
-import org.apache.spark.Logging
+import org.apache.spark.{SparkContext, Logging}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.rdd.RDD
+import org.apache.spark.mllib.util.MLUtils
/**
* Model for Naive Bayes Classifiers.
*
* @param pi Log of class priors, whose dimension is C.
- * @param theta Log of class conditional probabilities, whose dimension is CXD.
+ * @param theta Log of class conditional probabilities, whose dimension is CxD.
*/
-class NaiveBayesModel(pi: Array[Double], theta: Array[Array[Double]])
+class NaiveBayesModel(val pi: Array[Double], val theta: Array[Array[Double]])
extends ClassificationModel with Serializable {
// Create a column vector that can be used for predictions
@@ -50,10 +51,21 @@ class NaiveBayesModel(pi: Array[Double], theta: Array[Array[Double]])
/**
* Trains a Naive Bayes model given an RDD of `(label, features)` pairs.
*
- * @param lambda The smooth parameter
+ * This is the Multinomial NB ([[http://tinyurl.com/lsdw6p]]) which can handle all kinds of
+ * discrete data. For example, by converting documents into TF-IDF vectors, it can be used for
+ * document classification. By making every vector a 0-1 vector, it can also be used as
+ * Bernoulli NB ([[http://tinyurl.com/p7c96j6]]).
*/
-class NaiveBayes private (val lambda: Double = 1.0)
- extends Serializable with Logging {
+class NaiveBayes private (var lambda: Double)
+ extends Serializable with Logging
+{
+ def this() = this(1.0)
+
+ /** Set the smoothing parameter. Default: 1.0. */
+ def setLambda(lambda: Double): NaiveBayes = {
+ this.lambda = lambda
+ this
+ }
/**
* Run the algorithm with the configured parameters on an input RDD of LabeledPoint entries.
@@ -106,14 +118,49 @@ object NaiveBayes {
*
* This is the Multinomial NB ([[http://tinyurl.com/lsdw6p]]) which can handle all kinds of
* discrete data. For example, by converting documents into TF-IDF vectors, it can be used for
- * document classification. By making every vector a 0-1 vector. it can also be used as
+ * document classification. By making every vector a 0-1 vector, it can also be used as
+ * Bernoulli NB ([[http://tinyurl.com/p7c96j6]]).
+ *
+ * This version of the method uses a default smoothing parameter of 1.0.
+ *
+ * @param input RDD of `(label, array of features)` pairs. Every vector should be a frequency
+ * vector or a count vector.
+ */
+ def train(input: RDD[LabeledPoint]): NaiveBayesModel = {
+ new NaiveBayes().run(input)
+ }
+
+ /**
+ * Trains a Naive Bayes model given an RDD of `(label, features)` pairs.
+ *
+ * This is the Multinomial NB ([[http://tinyurl.com/lsdw6p]]) which can handle all kinds of
+ * discrete data. For example, by converting documents into TF-IDF vectors, it can be used for
+ * document classification. By making every vector a 0-1 vector, it can also be used as
* Bernoulli NB ([[http://tinyurl.com/p7c96j6]]).
*
* @param input RDD of `(label, array of features)` pairs. Every vector should be a frequency
* vector or a count vector.
- * @param lambda The smooth parameter
+ * @param lambda The smoothing parameter
*/
- def train(input: RDD[LabeledPoint], lambda: Double = 1.0): NaiveBayesModel = {
+ def train(input: RDD[LabeledPoint], lambda: Double): NaiveBayesModel = {
new NaiveBayes(lambda).run(input)
}
+
+ def main(args: Array[String]) {
+ if (args.length != 2 && args.length != 3) {
+ println("Usage: NaiveBayes <master> <input_dir> [<lambda>]")
+ System.exit(1)
+ }
+ val sc = new SparkContext(args(0), "NaiveBayes")
+ val data = MLUtils.loadLabeledData(sc, args(1))
+ val model = if (args.length == 2) {
+ NaiveBayes.train(data)
+ } else {
+ NaiveBayes.train(data, args(2).toDouble)
+ }
+ println("Pi: " + model.pi.mkString("[", ", ", "]"))
+ println("Theta:\n" + model.theta.map(_.mkString("[", ", ", "]")).mkString("[", "\n ", "]"))
+
+ sc.stop()
+ }
}
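
Taken together, the NaiveBayes changes give two equivalent entry points: the builder-style class (now exposing setLambda) and the static train overloads with and without an explicit smoothing value. A usage sketch against the new API, assuming a local context and the sample file added above:

    import org.apache.spark.SparkContext
    import org.apache.spark.mllib.classification.NaiveBayes
    import org.apache.spark.mllib.util.MLUtils

    val sc = new SparkContext("local", "NaiveBayesExample")
    val data = MLUtils.loadLabeledData(sc, "mllib/data/sample_naive_bayes_data.txt")

    // Static helpers: default lambda = 1.0, or an explicit smoothing value.
    val modelDefault  = NaiveBayes.train(data)
    val modelSmoothed = NaiveBayes.train(data, 0.5)

    // Builder style, the same path exercised by JavaNaiveBayesSuite.
    val modelBuilder = new NaiveBayes().setLambda(1.0).run(data)

    println(modelSmoothed.predict(Array(0.0, 0.0, 2.0)))   // 2.0 for the sample data
    sc.stop()
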
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala
index 3b8f8550d0..f2964ea446 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala
@@ -183,6 +183,8 @@ object SVMWithSGD {
val sc = new SparkContext(args(0), "SVM")
val data = MLUtils.loadLabeledData(sc, args(1))
val model = SVMWithSGD.train(data, args(4).toInt, args(2).toDouble, args(3).toDouble)
+ println("Weights: " + model.weights.mkString("[", ", ", "]"))
+ println("Intercept: " + model.intercept)
sc.stop()
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala
index 8b27ecf82c..89ee07063d 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala
@@ -22,7 +22,7 @@ import scala.util.Random
import scala.util.Sorting
import org.apache.spark.broadcast.Broadcast
-import org.apache.spark.{Logging, HashPartitioner, Partitioner, SparkContext}
+import org.apache.spark.{Logging, HashPartitioner, Partitioner, SparkContext, SparkConf}
import org.apache.spark.storage.StorageLevel
import org.apache.spark.rdd.RDD
import org.apache.spark.serializer.KryoRegistrator
@@ -578,12 +578,13 @@ object ALS {
val implicitPrefs = if (args.length >= 7) args(6).toBoolean else false
val alpha = if (args.length >= 8) args(7).toDouble else 1
val blocks = if (args.length == 9) args(8).toInt else -1
- val sc = new SparkContext(master, "ALS")
- sc.conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
- sc.conf.set("spark.kryo.registrator", classOf[ALSRegistrator].getName)
- sc.conf.set("spark.kryo.referenceTracking", "false")
- sc.conf.set("spark.kryoserializer.buffer.mb", "8")
- sc.conf.set("spark.locality.wait", "10000")
+ val conf = new SparkConf()
+ .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
+ .set("spark.kryo.registrator", classOf[ALSRegistrator].getName)
+ .set("spark.kryo.referenceTracking", "false")
+ .set("spark.kryoserializer.buffer.mb", "8")
+ .set("spark.locality.wait", "10000")
+ val sc = new SparkContext(master, "ALS", conf)
val ratings = sc.textFile(ratingsFile).map { line =>
val fields = line.split(',')
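
The ALS change switches from mutating sc.conf after the context exists to building a SparkConf first, which is what SparkConf is for: serializer and Kryo settings need to be in place before the context (and its executors) start, so setting them afterwards may silently have no effect. A minimal sketch of the same pattern in user code (the values are illustrative):

    import org.apache.spark.{SparkConf, SparkContext}

    // Configure serialization before the context is created.
    val conf = new SparkConf()
      .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
      .set("spark.kryoserializer.buffer.mb", "8")
    val sc = new SparkContext("local", "MyApp", conf)
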
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/LabeledPoint.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/LabeledPoint.scala
index 63240e24dc..1a18292fe3 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/regression/LabeledPoint.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/LabeledPoint.scala
@@ -23,4 +23,8 @@ package org.apache.spark.mllib.regression
* @param label Label for this data point.
* @param features List of features for this data point.
*/
-case class LabeledPoint(val label: Double, val features: Array[Double])
+case class LabeledPoint(label: Double, features: Array[Double]) {
+ override def toString: String = {
+ "LabeledPoint(%s, %s)".format(label, features.mkString("[", ", ", "]"))
+ }
+}
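
With the new toString, labeled points print in a readable form; for example (an illustration, not part of the patch):

    LabeledPoint(1.0, Array(0.0, 2.0, 0.0)).toString
    // => "LabeledPoint(1.0, [0.0, 2.0, 0.0])"
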
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala
index d959695325..7c41793722 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala
@@ -121,7 +121,7 @@ object LassoWithSGD {
* @param stepSize Step size to be used for each iteration of gradient descent.
* @param regParam Regularization parameter.
* @param miniBatchFraction Fraction of data to be used per iteration.
- * @param initialWeights Initial set of weights to be used. Array should be equal in size to
+ * @param initialWeights Initial set of weights to be used. Array should be equal in size to
* the number of features in the data.
*/
def train(
@@ -205,6 +205,8 @@ object LassoWithSGD {
val sc = new SparkContext(args(0), "Lasso")
val data = MLUtils.loadLabeledData(sc, args(1))
val model = LassoWithSGD.train(data, args(4).toInt, args(2).toDouble, args(3).toDouble)
+ println("Weights: " + model.weights.mkString("[", ", ", "]"))
+ println("Intercept: " + model.intercept)
sc.stop()
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala
index 597d55e0bb..fe5cce064b 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala
@@ -162,6 +162,8 @@ object LinearRegressionWithSGD {
val sc = new SparkContext(args(0), "LinearRegression")
val data = MLUtils.loadLabeledData(sc, args(1))
val model = LinearRegressionWithSGD.train(data, args(3).toInt, args(2).toDouble)
+ println("Weights: " + model.weights.mkString("[", ", ", "]"))
+ println("Intercept: " + model.intercept)
sc.stop()
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala
index b29508d2b9..c125c6797a 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala
@@ -122,7 +122,7 @@ object RidgeRegressionWithSGD {
* @param stepSize Step size to be used for each iteration of gradient descent.
* @param regParam Regularization parameter.
* @param miniBatchFraction Fraction of data to be used per iteration.
- * @param initialWeights Initial set of weights to be used. Array should be equal in size to
+ * @param initialWeights Initial set of weights to be used. Array should be equal in size to
* the number of features in the data.
*/
def train(
@@ -208,6 +208,8 @@ object RidgeRegressionWithSGD {
val data = MLUtils.loadLabeledData(sc, args(1))
val model = RidgeRegressionWithSGD.train(data, args(4).toInt, args(2).toDouble,
args(3).toDouble)
+ println("Weights: " + model.weights.mkString("[", ", ", "]"))
+ println("Intercept: " + model.intercept)
sc.stop()
}
diff --git a/mllib/src/test/java/org/apache/spark/mllib/classification/JavaNaiveBayesSuite.java b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaNaiveBayesSuite.java
new file mode 100644
index 0000000000..23ea3548b9
--- /dev/null
+++ b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaNaiveBayesSuite.java
@@ -0,0 +1,72 @@
+package org.apache.spark.mllib.classification;
+
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.mllib.regression.LabeledPoint;
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
+
+import java.io.Serializable;
+import java.util.Arrays;
+import java.util.List;
+
+public class JavaNaiveBayesSuite implements Serializable {
+ private transient JavaSparkContext sc;
+
+ @Before
+ public void setUp() {
+ sc = new JavaSparkContext("local", "JavaNaiveBayesSuite");
+ }
+
+ @After
+ public void tearDown() {
+ sc.stop();
+ sc = null;
+ System.clearProperty("spark.driver.port");
+ }
+
+ private static final List<LabeledPoint> POINTS = Arrays.asList(
+ new LabeledPoint(0, new double[] {1.0, 0.0, 0.0}),
+ new LabeledPoint(0, new double[] {2.0, 0.0, 0.0}),
+ new LabeledPoint(1, new double[] {0.0, 1.0, 0.0}),
+ new LabeledPoint(1, new double[] {0.0, 2.0, 0.0}),
+ new LabeledPoint(2, new double[] {0.0, 0.0, 1.0}),
+ new LabeledPoint(2, new double[] {0.0, 0.0, 2.0})
+ );
+
+ private int validatePrediction(List<LabeledPoint> points, NaiveBayesModel model) {
+ int correct = 0;
+ for (LabeledPoint p: points) {
+ if (model.predict(p.features()) == p.label()) {
+ correct += 1;
+ }
+ }
+ return correct;
+ }
+
+ @Test
+ public void runUsingConstructor() {
+ JavaRDD<LabeledPoint> testRDD = sc.parallelize(POINTS, 2).cache();
+
+ NaiveBayes nb = new NaiveBayes().setLambda(1.0);
+ NaiveBayesModel model = nb.run(testRDD.rdd());
+
+ int numAccurate = validatePrediction(POINTS, model);
+ Assert.assertEquals(POINTS.size(), numAccurate);
+ }
+
+ @Test
+ public void runUsingStaticMethods() {
+ JavaRDD<LabeledPoint> testRDD = sc.parallelize(POINTS, 2).cache();
+
+ NaiveBayesModel model1 = NaiveBayes.train(testRDD.rdd());
+ int numAccurate1 = validatePrediction(POINTS, model1);
+ Assert.assertEquals(POINTS.size(), numAccurate1);
+
+ NaiveBayesModel model2 = NaiveBayes.train(testRDD.rdd(), 0.5);
+ int numAccurate2 = validatePrediction(POINTS, model2);
+ Assert.assertEquals(POINTS.size(), numAccurate2);
+ }
+}