aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--docs/mllib-naive-bayes.md17
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala225
-rw-r--r--mllib/src/test/java/org/apache/spark/mllib/classification/JavaNaiveBayesSuite.java23
-rw-r--r--mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala148
4 files changed, 322 insertions, 91 deletions
diff --git a/docs/mllib-naive-bayes.md b/docs/mllib-naive-bayes.md
index a83472f5be..9780ea52c4 100644
--- a/docs/mllib-naive-bayes.md
+++ b/docs/mllib-naive-bayes.md
@@ -13,12 +13,15 @@ compute the conditional probability distribution of label given an observation
and use it for prediction.
MLlib supports [multinomial naive
-Bayes](http://en.wikipedia.org/wiki/Naive_Bayes_classifier#Multinomial_naive_Bayes),
-which is typically used for [document
-classification](http://nlp.stanford.edu/IR-book/html/htmledition/naive-bayes-text-classification-1.html).
+Bayes](http://en.wikipedia.org/wiki/Naive_Bayes_classifier#Multinomial_naive_Bayes)
+and [Bernoulli naive Bayes] (http://nlp.stanford.edu/IR-book/html/htmledition/the-bernoulli-model-1.html).
+These models are typically used for [document classification]
+(http://nlp.stanford.edu/IR-book/html/htmledition/naive-bayes-text-classification-1.html).
Within that context, each observation is a document and each
-feature represents a term whose value is the frequency of the term.
-Feature values must be nonnegative to represent term frequencies.
+feature represents a term whose value is the frequency of the term (in multinomial naive Bayes) or
+a zero or one indicating whether the term was found in the document (in Bernoulli naive Bayes).
+Feature values must be nonnegative. The model type is selected with an optional parameter
+"Multinomial" or "Bernoulli" with "Multinomial" as the default.
[Additive smoothing](http://en.wikipedia.org/wiki/Lidstone_smoothing) can be used by
setting the parameter $\lambda$ (default to $1.0$). For document classification, the input feature
vectors are usually sparse, and sparse vectors should be supplied as input to take advantage of
@@ -32,7 +35,7 @@ sparsity. Since the training data is only used once, it is not necessary to cach
[NaiveBayes](api/scala/index.html#org.apache.spark.mllib.classification.NaiveBayes$) implements
multinomial naive Bayes. It takes an RDD of
[LabeledPoint](api/scala/index.html#org.apache.spark.mllib.regression.LabeledPoint) and an optional
-smoothing parameter `lambda` as input, and output a
+smoothing parameter `lambda` as input, an optional model type parameter (default is Multinomial), and outputs a
[NaiveBayesModel](api/scala/index.html#org.apache.spark.mllib.classification.NaiveBayesModel), which
can be used for evaluation and prediction.
@@ -51,7 +54,7 @@ val splits = parsedData.randomSplit(Array(0.6, 0.4), seed = 11L)
val training = splits(0)
val test = splits(1)
-val model = NaiveBayes.train(training, lambda = 1.0)
+val model = NaiveBayes.train(training, lambda = 1.0, model = "Multinomial")
val predictionAndLabel = test.map(p => (model.predict(p.features), p.label))
val accuracy = 1.0 * predictionAndLabel.filter(x => x._1 == x._2).count() / test.count()
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
index d60e82c410..c9b3ff0172 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
@@ -21,9 +21,12 @@ import java.lang.{Iterable => JIterable}
import scala.collection.JavaConverters._
-import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, argmax => brzArgmax, sum => brzSum}
+import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, argmax => brzArgmax, sum => brzSum, Axis}
+import breeze.numerics.{exp => brzExp, log => brzLog}
+
import org.json4s.JsonDSL._
import org.json4s.jackson.JsonMethods._
+import org.json4s.{DefaultFormats, JValue}
import org.apache.spark.{Logging, SparkContext, SparkException}
import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector}
@@ -32,6 +35,7 @@ import org.apache.spark.mllib.util.{Loader, Saveable}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, SQLContext}
+
/**
* Model for Naive Bayes Classifiers.
*
@@ -39,11 +43,17 @@ import org.apache.spark.sql.{DataFrame, SQLContext}
* @param pi log of class priors, whose dimension is C, number of labels
* @param theta log of class conditional probabilities, whose dimension is C-by-D,
* where D is number of features
+ * @param modelType The type of NB model to fit can be "Multinomial" or "Bernoulli"
*/
class NaiveBayesModel private[mllib] (
val labels: Array[Double],
val pi: Array[Double],
- val theta: Array[Array[Double]]) extends ClassificationModel with Serializable with Saveable {
+ val theta: Array[Array[Double]],
+ val modelType: String)
+ extends ClassificationModel with Serializable with Saveable {
+
+ private[mllib] def this(labels: Array[Double], pi: Array[Double], theta: Array[Array[Double]]) =
+ this(labels, pi, theta, "Multinomial")
/** A Java-friendly constructor that takes three Iterable parameters. */
private[mllib] def this(
@@ -53,19 +63,19 @@ class NaiveBayesModel private[mllib] (
this(labels.asScala.toArray, pi.asScala.toArray, theta.asScala.toArray.map(_.asScala.toArray))
private val brzPi = new BDV[Double](pi)
- private val brzTheta = new BDM[Double](theta.length, theta(0).length)
-
- {
- // Need to put an extra pair of braces to prevent Scala treating `i` as a member.
- var i = 0
- while (i < theta.length) {
- var j = 0
- while (j < theta(i).length) {
- brzTheta(i, j) = theta(i)(j)
- j += 1
- }
- i += 1
- }
+ private val brzTheta = new BDM(theta(0).length, theta.length, theta.flatten).t
+
+ // Bernoulli scoring requires log(condprob) if 1, log(1-condprob) if 0.
+ // This precomputes log(1.0 - exp(theta)) and its sum which are used for the linear algebra
+ // application of this condition (in predict function).
+ private val (brzNegTheta, brzNegThetaSum) = modelType match {
+ case "Multinomial" => (None, None)
+ case "Bernoulli" =>
+ val negTheta = brzLog((brzExp(brzTheta.copy) :*= (-1.0)) :+= 1.0) // log(1.0 - exp(x))
+ (Option(negTheta), Option(brzSum(negTheta, Axis._1)))
+ case _ =>
+ // This should never happen.
+ throw new UnknownError(s"NaiveBayesModel was created with an unknown ModelType: $modelType")
}
override def predict(testData: RDD[Vector]): RDD[Double] = {
@@ -77,22 +87,78 @@ class NaiveBayesModel private[mllib] (
}
override def predict(testData: Vector): Double = {
- labels(brzArgmax(brzPi + brzTheta * testData.toBreeze))
+ modelType match {
+ case "Multinomial" =>
+ labels (brzArgmax (brzPi + brzTheta * testData.toBreeze) )
+ case "Bernoulli" =>
+ labels (brzArgmax (brzPi +
+ (brzTheta - brzNegTheta.get) * testData.toBreeze + brzNegThetaSum.get))
+ case _ =>
+ // This should never happen.
+ throw new UnknownError(s"NaiveBayesModel was created with an unknown ModelType: $modelType")
+ }
}
override def save(sc: SparkContext, path: String): Unit = {
- val data = NaiveBayesModel.SaveLoadV1_0.Data(labels, pi, theta)
- NaiveBayesModel.SaveLoadV1_0.save(sc, path, data)
+ val data = NaiveBayesModel.SaveLoadV2_0.Data(labels, pi, theta, modelType)
+ NaiveBayesModel.SaveLoadV2_0.save(sc, path, data)
}
- override protected def formatVersion: String = "1.0"
+ override protected def formatVersion: String = "2.0"
}
object NaiveBayesModel extends Loader[NaiveBayesModel] {
import org.apache.spark.mllib.util.Loader._
- private object SaveLoadV1_0 {
+ private[mllib] object SaveLoadV2_0 {
+
+ def thisFormatVersion: String = "2.0"
+
+ /** Hard-code class name string in case it changes in the future */
+ def thisClassName: String = "org.apache.spark.mllib.classification.NaiveBayesModel"
+
+ /** Model data for model import/export */
+ case class Data(
+ labels: Array[Double],
+ pi: Array[Double],
+ theta: Array[Array[Double]],
+ modelType: String)
+
+ def save(sc: SparkContext, path: String, data: Data): Unit = {
+ val sqlContext = new SQLContext(sc)
+ import sqlContext.implicits._
+
+ // Create JSON metadata.
+ val metadata = compact(render(
+ ("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~
+ ("numFeatures" -> data.theta(0).length) ~ ("numClasses" -> data.pi.length)))
+ sc.parallelize(Seq(metadata), 1).saveAsTextFile(metadataPath(path))
+
+ // Create Parquet data.
+ val dataRDD: DataFrame = sc.parallelize(Seq(data), 1).toDF()
+ dataRDD.saveAsParquetFile(dataPath(path))
+ }
+
+ def load(sc: SparkContext, path: String): NaiveBayesModel = {
+ val sqlContext = new SQLContext(sc)
+ // Load Parquet data.
+ val dataRDD = sqlContext.parquetFile(dataPath(path))
+ // Check schema explicitly since erasure makes it hard to use match-case for checking.
+ checkSchema[Data](dataRDD.schema)
+ val dataArray = dataRDD.select("labels", "pi", "theta", "modelType").take(1)
+ assert(dataArray.size == 1, s"Unable to load NaiveBayesModel data from: ${dataPath(path)}")
+ val data = dataArray(0)
+ val labels = data.getAs[Seq[Double]](0).toArray
+ val pi = data.getAs[Seq[Double]](1).toArray
+ val theta = data.getAs[Seq[Seq[Double]]](2).map(_.toArray).toArray
+ val modelType = data.getString(3)
+ new NaiveBayesModel(labels, pi, theta, modelType)
+ }
+
+ }
+
+ private[mllib] object SaveLoadV1_0 {
def thisFormatVersion: String = "1.0"
@@ -100,7 +166,10 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] {
def thisClassName: String = "org.apache.spark.mllib.classification.NaiveBayesModel"
/** Model data for model import/export */
- case class Data(labels: Array[Double], pi: Array[Double], theta: Array[Array[Double]])
+ case class Data(
+ labels: Array[Double],
+ pi: Array[Double],
+ theta: Array[Array[Double]])
def save(sc: SparkContext, path: String, data: Data): Unit = {
val sqlContext = new SQLContext(sc)
@@ -136,26 +205,32 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] {
override def load(sc: SparkContext, path: String): NaiveBayesModel = {
val (loadedClassName, version, metadata) = loadMetadata(sc, path)
val classNameV1_0 = SaveLoadV1_0.thisClassName
- (loadedClassName, version) match {
+ val classNameV2_0 = SaveLoadV2_0.thisClassName
+ val (model, numFeatures, numClasses) = (loadedClassName, version) match {
case (className, "1.0") if className == classNameV1_0 =>
val (numFeatures, numClasses) = ClassificationModel.getNumFeaturesClasses(metadata)
val model = SaveLoadV1_0.load(sc, path)
- assert(model.pi.size == numClasses,
- s"NaiveBayesModel.load expected $numClasses classes," +
- s" but class priors vector pi had ${model.pi.size} elements")
- assert(model.theta.size == numClasses,
- s"NaiveBayesModel.load expected $numClasses classes," +
- s" but class conditionals array theta had ${model.theta.size} elements")
- assert(model.theta.forall(_.size == numFeatures),
- s"NaiveBayesModel.load expected $numFeatures features," +
- s" but class conditionals array theta had elements of size:" +
- s" ${model.theta.map(_.size).mkString(",")}")
- model
+ (model, numFeatures, numClasses)
+ case (className, "2.0") if className == classNameV2_0 =>
+ val (numFeatures, numClasses) = ClassificationModel.getNumFeaturesClasses(metadata)
+ val model = SaveLoadV2_0.load(sc, path)
+ (model, numFeatures, numClasses)
case _ => throw new Exception(
s"NaiveBayesModel.load did not recognize model with (className, format version):" +
s"($loadedClassName, $version). Supported:\n" +
s" ($classNameV1_0, 1.0)")
}
+ assert(model.pi.size == numClasses,
+ s"NaiveBayesModel.load expected $numClasses classes," +
+ s" but class priors vector pi had ${model.pi.size} elements")
+ assert(model.theta.size == numClasses,
+ s"NaiveBayesModel.load expected $numClasses classes," +
+ s" but class conditionals array theta had ${model.theta.size} elements")
+ assert(model.theta.forall(_.size == numFeatures),
+ s"NaiveBayesModel.load expected $numFeatures features," +
+ s" but class conditionals array theta had elements of size:" +
+ s" ${model.theta.map(_.size).mkString(",")}")
+ model
}
}
@@ -167,9 +242,14 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] {
* document classification. By making every vector a 0-1 vector, it can also be used as
* Bernoulli NB ([[http://tinyurl.com/p7c96j6]]). The input feature values must be nonnegative.
*/
-class NaiveBayes private (private var lambda: Double) extends Serializable with Logging {
- def this() = this(1.0)
+class NaiveBayes private (
+ private var lambda: Double,
+ private var modelType: String) extends Serializable with Logging {
+
+ def this(lambda: Double) = this(lambda, "Multinomial")
+
+ def this() = this(1.0, "Multinomial")
/** Set the smoothing parameter. Default: 1.0. */
def setLambda(lambda: Double): NaiveBayes = {
@@ -177,10 +257,25 @@ class NaiveBayes private (private var lambda: Double) extends Serializable with
this
}
- /** Get the smoothing parameter. Default: 1.0. */
+ /** Get the smoothing parameter. */
def getLambda: Double = lambda
/**
+ * Set the model type using a string (case-sensitive).
+ * Supported options: "Multinomial" and "Bernoulli".
+ * (default: Multinomial)
+ */
+ def setModelType(modelType:String): NaiveBayes = {
+ require(NaiveBayes.supportedModelTypes.contains(modelType),
+ s"NaiveBayes was created with an unknown ModelType: $modelType")
+ this.modelType = modelType
+ this
+ }
+
+ /** Get the model type. */
+ def getModelType: String = this.modelType
+
+ /**
* Run the algorithm with the configured parameters on an input RDD of LabeledPoint entries.
*
* @param data RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
@@ -213,21 +308,30 @@ class NaiveBayes private (private var lambda: Double) extends Serializable with
mergeCombiners = (c1: (Long, BDV[Double]), c2: (Long, BDV[Double])) =>
(c1._1 + c2._1, c1._2 += c2._2)
).collect()
+
val numLabels = aggregated.length
var numDocuments = 0L
aggregated.foreach { case (_, (n, _)) =>
numDocuments += n
}
val numFeatures = aggregated.head match { case (_, (_, v)) => v.size }
+
val labels = new Array[Double](numLabels)
val pi = new Array[Double](numLabels)
val theta = Array.fill(numLabels)(new Array[Double](numFeatures))
+
val piLogDenom = math.log(numDocuments + numLabels * lambda)
var i = 0
aggregated.foreach { case (label, (n, sumTermFreqs)) =>
labels(i) = label
- val thetaLogDenom = math.log(brzSum(sumTermFreqs) + numFeatures * lambda)
pi(i) = math.log(n + lambda) - piLogDenom
+ val thetaLogDenom = modelType match {
+ case "Multinomial" => math.log(brzSum(sumTermFreqs) + numFeatures * lambda)
+ case "Bernoulli" => math.log(n + 2.0 * lambda)
+ case _ =>
+ // This should never happen.
+ throw new UnknownError(s"NaiveBayes was created with an unknown ModelType: $modelType")
+ }
var j = 0
while (j < numFeatures) {
theta(i)(j) = math.log(sumTermFreqs(j) + lambda) - thetaLogDenom
@@ -236,7 +340,7 @@ class NaiveBayes private (private var lambda: Double) extends Serializable with
i += 1
}
- new NaiveBayesModel(labels, pi, theta)
+ new NaiveBayesModel(labels, pi, theta, modelType)
}
}
@@ -244,13 +348,16 @@ class NaiveBayes private (private var lambda: Double) extends Serializable with
* Top-level methods for calling naive Bayes.
*/
object NaiveBayes {
+
+ /* Set of modelTypes that NaiveBayes supports */
+ private[mllib] val supportedModelTypes = Set("Multinomial", "Bernoulli")
+
/**
* Trains a Naive Bayes model given an RDD of `(label, features)` pairs.
*
- * This is the Multinomial NB ([[http://tinyurl.com/lsdw6p]]) which can handle all kinds of
- * discrete data. For example, by converting documents into TF-IDF vectors, it can be used for
- * document classification. By making every vector a 0-1 vector, it can also be used as
- * Bernoulli NB ([[http://tinyurl.com/p7c96j6]]).
+ * This is the default Multinomial NB ([[http://tinyurl.com/lsdw6p]]) which can handle all
+ * kinds of discrete data. For example, by converting documents into TF-IDF vectors, it
+ * can be used for document classification.
*
* This version of the method uses a default smoothing parameter of 1.0.
*
@@ -264,16 +371,40 @@ object NaiveBayes {
/**
* Trains a Naive Bayes model given an RDD of `(label, features)` pairs.
*
- * This is the Multinomial NB ([[http://tinyurl.com/lsdw6p]]) which can handle all kinds of
- * discrete data. For example, by converting documents into TF-IDF vectors, it can be used for
- * document classification. By making every vector a 0-1 vector, it can also be used as
- * Bernoulli NB ([[http://tinyurl.com/p7c96j6]]).
+ * This is the default Multinomial NB ([[http://tinyurl.com/lsdw6p]]) which can handle all
+ * kinds of discrete data. For example, by converting documents into TF-IDF vectors, it
+ * can be used for document classification.
*
* @param input RDD of `(label, array of features)` pairs. Every vector should be a frequency
* vector or a count vector.
* @param lambda The smoothing parameter
*/
def train(input: RDD[LabeledPoint], lambda: Double): NaiveBayesModel = {
- new NaiveBayes(lambda).run(input)
+ new NaiveBayes(lambda, "Multinomial").run(input)
+ }
+
+ /**
+ * Trains a Naive Bayes model given an RDD of `(label, features)` pairs.
+ *
+ * The model type can be set to either Multinomial NB ([[http://tinyurl.com/lsdw6p]])
+ * or Bernoulli NB ([[http://tinyurl.com/p7c96j6]]). The Multinomial NB can handle
+ * discrete count data and can be called by setting the model type to "multinomial".
+ * For example, it can be used with word counts or TF_IDF vectors of documents.
+ * The Bernoulli model fits presence or absence (0-1) counts. By making every vector a
+ * 0-1 vector and setting the model type to "bernoulli", the fits and predicts as
+ * Bernoulli NB.
+ *
+ * @param input RDD of `(label, array of features)` pairs. Every vector should be a frequency
+ * vector or a count vector.
+ * @param lambda The smoothing parameter
+ *
+ * @param modelType The type of NB model to fit from the enumeration NaiveBayesModels, can be
+ * multinomial or bernoulli
+ */
+ def train(input: RDD[LabeledPoint], lambda: Double, modelType: String): NaiveBayesModel = {
+ require(supportedModelTypes.contains(modelType),
+ s"NaiveBayes was created with an unknown ModelType: $modelType")
+ new NaiveBayes(lambda, modelType).run(input)
}
+
}
diff --git a/mllib/src/test/java/org/apache/spark/mllib/classification/JavaNaiveBayesSuite.java b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaNaiveBayesSuite.java
index 1c90522a07..71fb7f13c3 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/classification/JavaNaiveBayesSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaNaiveBayesSuite.java
@@ -17,20 +17,22 @@
package org.apache.spark.mllib.classification;
+import java.io.Serializable;
+import java.util.Arrays;
+import java.util.List;
+
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
+
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.mllib.regression.LabeledPoint;
-import org.junit.After;
-import org.junit.Assert;
-import org.junit.Before;
-import org.junit.Test;
-import java.io.Serializable;
-import java.util.Arrays;
-import java.util.List;
public class JavaNaiveBayesSuite implements Serializable {
private transient JavaSparkContext sc;
@@ -102,4 +104,11 @@ public class JavaNaiveBayesSuite implements Serializable {
// Should be able to get the first prediction.
predictions.first();
}
+
+ @Test
+ public void testModelTypeSetters() {
+ NaiveBayes nb = new NaiveBayes()
+ .setModelType("Bernoulli")
+ .setModelType("Multinomial");
+ }
}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala
index 5a27c7d230..f9fe3e006c 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala
@@ -19,6 +19,9 @@ package org.apache.spark.mllib.classification
import scala.util.Random
+import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, argmax => brzArgmax, sum => brzSum, Axis}
+import breeze.stats.distributions.{Multinomial => BrzMultinomial}
+
import org.scalatest.FunSuite
import org.apache.spark.SparkException
@@ -41,37 +44,48 @@ object NaiveBayesSuite {
// Generate input of the form Y = (theta * x).argmax()
def generateNaiveBayesInput(
- pi: Array[Double], // 1XC
- theta: Array[Array[Double]], // CXD
- nPoints: Int,
- seed: Int): Seq[LabeledPoint] = {
+ pi: Array[Double], // 1XC
+ theta: Array[Array[Double]], // CXD
+ nPoints: Int,
+ seed: Int,
+ modelType: String = "Multinomial",
+ sample: Int = 10): Seq[LabeledPoint] = {
val D = theta(0).length
val rnd = new Random(seed)
-
val _pi = pi.map(math.pow(math.E, _))
val _theta = theta.map(row => row.map(math.pow(math.E, _)))
for (i <- 0 until nPoints) yield {
val y = calcLabel(rnd.nextDouble(), _pi)
- val xi = Array.tabulate[Double](D) { j =>
- if (rnd.nextDouble() < _theta(y)(j)) 1 else 0
+ val xi = modelType match {
+ case "Bernoulli" => Array.tabulate[Double] (D) { j =>
+ if (rnd.nextDouble () < _theta(y)(j) ) 1 else 0
+ }
+ case "Multinomial" =>
+ val mult = BrzMultinomial(BDV(_theta(y)))
+ val emptyMap = (0 until D).map(x => (x, 0.0)).toMap
+ val counts = emptyMap ++ mult.sample(sample).groupBy(x => x).map {
+ case (index, reps) => (index, reps.size.toDouble)
+ }
+ counts.toArray.sortBy(_._1).map(_._2)
+ case _ =>
+ // This should never happen.
+ throw new UnknownError(s"NaiveBayesSuite found unknown ModelType: $modelType")
}
LabeledPoint(y, Vectors.dense(xi))
}
}
- private val smallPi = Array(0.5, 0.3, 0.2).map(math.log)
+ /** Bernoulli NaiveBayes with binary labels, 3 features */
+ private val binaryBernoulliModel = new NaiveBayesModel(labels = Array(0.0, 1.0),
+ pi = Array(0.2, 0.8), theta = Array(Array(0.1, 0.3, 0.6), Array(0.2, 0.4, 0.4)),
+ "Bernoulli")
- private val smallTheta = Array(
- Array(0.91, 0.03, 0.03, 0.03), // label 0
- Array(0.03, 0.91, 0.03, 0.03), // label 1
- Array(0.03, 0.03, 0.91, 0.03) // label 2
- ).map(_.map(math.log))
-
- /** Binary labels, 3 features */
- private val binaryModel = new NaiveBayesModel(labels = Array(0.0, 1.0), pi = Array(0.2, 0.8),
- theta = Array(Array(0.1, 0.3, 0.6), Array(0.2, 0.4, 0.4)))
+ /** Multinomial NaiveBayes with binary labels, 3 features */
+ private val binaryMultinomialModel = new NaiveBayesModel(labels = Array(0.0, 1.0),
+ pi = Array(0.2, 0.8), theta = Array(Array(0.1, 0.3, 0.6), Array(0.2, 0.4, 0.4)),
+ "Multinomial")
}
class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext {
@@ -85,6 +99,24 @@ class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext {
assert(numOfPredictions < input.length / 5)
}
+ def validateModelFit(
+ piData: Array[Double],
+ thetaData: Array[Array[Double]],
+ model: NaiveBayesModel) = {
+ def closeFit(d1: Double, d2: Double, precision: Double): Boolean = {
+ (d1 - d2).abs <= precision
+ }
+ val modelIndex = (0 until piData.length).zip(model.labels.map(_.toInt))
+ for (i <- modelIndex) {
+ assert(closeFit(math.exp(piData(i._2)), math.exp(model.pi(i._1)), 0.05))
+ }
+ for (i <- modelIndex) {
+ for (j <- 0 until thetaData(i._2).length) {
+ assert(closeFit(math.exp(thetaData(i._2)(j)), math.exp(model.theta(i._1)(j)), 0.05))
+ }
+ }
+ }
+
test("get, set params") {
val nb = new NaiveBayes()
nb.setLambda(2.0)
@@ -93,19 +125,53 @@ class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext {
assert(nb.getLambda === 3.0)
}
- test("Naive Bayes") {
- val nPoints = 10000
+ test("Naive Bayes Multinomial") {
+ val nPoints = 1000
+ val pi = Array(0.5, 0.1, 0.4).map(math.log)
+ val theta = Array(
+ Array(0.70, 0.10, 0.10, 0.10), // label 0
+ Array(0.10, 0.70, 0.10, 0.10), // label 1
+ Array(0.10, 0.10, 0.70, 0.10) // label 2
+ ).map(_.map(math.log))
+
+ val testData = NaiveBayesSuite.generateNaiveBayesInput(
+ pi, theta, nPoints, 42, "Multinomial")
+ val testRDD = sc.parallelize(testData, 2)
+ testRDD.cache()
+
+ val model = NaiveBayes.train(testRDD, 1.0, "Multinomial")
+ validateModelFit(pi, theta, model)
+
+ val validationData = NaiveBayesSuite.generateNaiveBayesInput(
+ pi, theta, nPoints, 17, "Multinomial")
+ val validationRDD = sc.parallelize(validationData, 2)
+
+ // Test prediction on RDD.
+ validatePrediction(model.predict(validationRDD.map(_.features)).collect(), validationData)
- val pi = NaiveBayesSuite.smallPi
- val theta = NaiveBayesSuite.smallTheta
+ // Test prediction on Array.
+ validatePrediction(validationData.map(row => model.predict(row.features)), validationData)
+ }
- val testData = NaiveBayesSuite.generateNaiveBayesInput(pi, theta, nPoints, 42)
+ test("Naive Bayes Bernoulli") {
+ val nPoints = 10000
+ val pi = Array(0.5, 0.3, 0.2).map(math.log)
+ val theta = Array(
+ Array(0.50, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.40), // label 0
+ Array(0.02, 0.70, 0.10, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02), // label 1
+ Array(0.02, 0.02, 0.60, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.30) // label 2
+ ).map(_.map(math.log))
+
+ val testData = NaiveBayesSuite.generateNaiveBayesInput(
+ pi, theta, nPoints, 45, "Bernoulli")
val testRDD = sc.parallelize(testData, 2)
testRDD.cache()
- val model = NaiveBayes.train(testRDD)
+ val model = NaiveBayes.train(testRDD, 1.0, "Bernoulli")
+ validateModelFit(pi, theta, model)
- val validationData = NaiveBayesSuite.generateNaiveBayesInput(pi, theta, nPoints, 17)
+ val validationData = NaiveBayesSuite.generateNaiveBayesInput(
+ pi, theta, nPoints, 20, "Bernoulli")
val validationRDD = sc.parallelize(validationData, 2)
// Test prediction on RDD.
@@ -142,19 +208,41 @@ class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext {
}
}
- test("model save/load") {
- val model = NaiveBayesSuite.binaryModel
+ test("model save/load: 2.0 to 2.0") {
+ val tempDir = Utils.createTempDir()
+ val path = tempDir.toURI.toString
+
+ Seq(NaiveBayesSuite.binaryBernoulliModel, NaiveBayesSuite.binaryMultinomialModel).map {
+ model =>
+ // Save model, load it back, and compare.
+ try {
+ model.save(sc, path)
+ val sameModel = NaiveBayesModel.load(sc, path)
+ assert(model.labels === sameModel.labels)
+ assert(model.pi === sameModel.pi)
+ assert(model.theta === sameModel.theta)
+ assert(model.modelType === sameModel.modelType)
+ } finally {
+ Utils.deleteRecursively(tempDir)
+ }
+ }
+ }
+
+ test("model save/load: 1.0 to 2.0") {
+ val model = NaiveBayesSuite.binaryMultinomialModel
val tempDir = Utils.createTempDir()
val path = tempDir.toURI.toString
- // Save model, load it back, and compare.
+ // Save model as version 1.0, load it back, and compare.
try {
- model.save(sc, path)
+ val data = NaiveBayesModel.SaveLoadV1_0.Data(model.labels, model.pi, model.theta)
+ NaiveBayesModel.SaveLoadV1_0.save(sc, path, data)
val sameModel = NaiveBayesModel.load(sc, path)
assert(model.labels === sameModel.labels)
assert(model.pi === sameModel.pi)
assert(model.theta === sameModel.theta)
+ assert(model.modelType === "Multinomial")
} finally {
Utils.deleteRecursively(tempDir)
}
@@ -172,8 +260,8 @@ class NaiveBayesClusterSuite extends FunSuite with LocalClusterSparkContext {
LabeledPoint(random.nextInt(2), Vectors.dense(Array.fill(n)(random.nextDouble())))
}
}
- // If we serialize data directly in the task closure, the size of the serialized task would be
- // greater than 1MB and hence Spark would throw an error.
+ // If we serialize data directly in the task closure, the size of the serialized task
+ // would be greater than 1MB and hence Spark would throw an error.
val model = NaiveBayes.train(examples)
val predictions = model.predict(examples.map(_.features))
}