aboutsummaryrefslogtreecommitdiff
path: root/mllib/src
diff options
context:
space:
mode:
authorleahmcguire <lmcguire@salesforce.com>2015-03-31 11:16:55 -0700
committerJoseph K. Bradley <joseph@databricks.com>2015-03-31 11:16:55 -0700
commitd01a6d8c33fc5c8325b0cc4b51395dba5eb3462c (patch)
tree14a3faac411b44804fc141d32b4b2001952a0125 /mllib/src
parenta05835b89fe2086e460f0b80f7c22e284c0c32d0 (diff)
downloadspark-d01a6d8c33fc5c8325b0cc4b51395dba5eb3462c.tar.gz
spark-d01a6d8c33fc5c8325b0cc4b51395dba5eb3462c.tar.bz2
spark-d01a6d8c33fc5c8325b0cc4b51395dba5eb3462c.zip
[SPARK-4894][mllib] Added Bernoulli option to NaiveBayes model in mllib
Added optional model type parameter for NaiveBayes training. Can be either Multinomial or Bernoulli. When Bernoulli is given the Bernoulli smoothing is used for fitting and for prediction as per: http://nlp.stanford.edu/IR-book/html/htmledition/the-bernoulli-model-1.html. Default for model is original Multinomial fit and predict. Added additional testing for Bernoulli and Multinomial models. Author: leahmcguire <lmcguire@salesforce.com> Author: Joseph K. Bradley <joseph@databricks.com> Author: Leah McGuire <lmcguire@salesforce.com> Closes #4087 from leahmcguire/master and squashes the following commits: f3c8994 [leahmcguire] changed checks on model type to requires acb69af [leahmcguire] removed enum type and replaces all modelType parameters with strings 2224b15 [Leah McGuire] Merge pull request #2 from jkbradley/leahmcguire-master 9ad89ca [Joseph K. Bradley] removed old code 6a8f383 [Joseph K. Bradley] Added new model save/load format 2.0 for NaiveBayesModel after modelType parameter was added. Updated tests. Also updated ModelType enum-like type. 852a727 [leahmcguire] merged with upstream master a22d670 [leahmcguire] changed NaiveBayesModel modelType parameter back to NaiveBayes.ModelType, made NaiveBayes.ModelType serializable, fixed getter method in NavieBayes 18f3219 [leahmcguire] removed private from naive bayes constructor for lambda only bea62af [leahmcguire] put back in constructor for NaiveBayes 01baad7 [leahmcguire] made fixes from code review fb0a5c7 [leahmcguire] removed typo e2d925e [leahmcguire] fixed nonserializable error that was causing naivebayes test failures 2d0c1ba [leahmcguire] fixed typo in NaiveBayes c298e78 [leahmcguire] fixed scala style errors b85b0c9 [leahmcguire] Merge remote-tracking branch 'upstream/master' 900b586 [leahmcguire] fixed model call so that uses type argument ea09b28 [leahmcguire] Merge remote-tracking branch 'upstream/master' e016569 [leahmcguire] updated test suite with model type fix 85f298f [leahmcguire] Merge remote-tracking branch 'upstream/master' dc65374 [leahmcguire] integrated model type fix 7622b0c [leahmcguire] added comments and fixed style as per rb b93aaf6 [Leah McGuire] Merge pull request #1 from jkbradley/nb-model-type 3730572 [Joseph K. Bradley] modified NB model type to be more Java-friendly b61b5e2 [leahmcguire] added back compatable constructor to NaiveBayesModel to fix MIMA test failure 5a4a534 [leahmcguire] fixed scala style error in NaiveBayes 3891bf2 [leahmcguire] synced with apache spark and resolved merge conflict d9477ed [leahmcguire] removed old inaccurate comment from test suite for mllib naive bayes 76e5b0f [leahmcguire] removed unnecessary sort from test 0313c0c [leahmcguire] fixed style error in NaiveBayes.scala 4a3676d [leahmcguire] Updated changes re-comments. Got rid of verbose populateMatrix method. Public api now has string instead of enumeration. Docs are updated." ce73c63 [leahmcguire] added Bernoulli option to niave bayes model in mllib, added optional model type parameter for training. When Bernoulli is given the Bernoulli smoothing is used for fitting and for prediction http://nlp.stanford.edu/IR-book/html/htmledition/the-bernoulli-model-1.html
Diffstat (limited to 'mllib/src')
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala225
-rw-r--r--mllib/src/test/java/org/apache/spark/mllib/classification/JavaNaiveBayesSuite.java23
-rw-r--r--mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala148
3 files changed, 312 insertions, 84 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
index d60e82c410..c9b3ff0172 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
@@ -21,9 +21,12 @@ import java.lang.{Iterable => JIterable}
import scala.collection.JavaConverters._
-import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, argmax => brzArgmax, sum => brzSum}
+import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, argmax => brzArgmax, sum => brzSum, Axis}
+import breeze.numerics.{exp => brzExp, log => brzLog}
+
import org.json4s.JsonDSL._
import org.json4s.jackson.JsonMethods._
+import org.json4s.{DefaultFormats, JValue}
import org.apache.spark.{Logging, SparkContext, SparkException}
import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector}
@@ -32,6 +35,7 @@ import org.apache.spark.mllib.util.{Loader, Saveable}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, SQLContext}
+
/**
* Model for Naive Bayes Classifiers.
*
@@ -39,11 +43,17 @@ import org.apache.spark.sql.{DataFrame, SQLContext}
* @param pi log of class priors, whose dimension is C, number of labels
* @param theta log of class conditional probabilities, whose dimension is C-by-D,
* where D is number of features
+ * @param modelType The type of NB model to fit can be "Multinomial" or "Bernoulli"
*/
class NaiveBayesModel private[mllib] (
val labels: Array[Double],
val pi: Array[Double],
- val theta: Array[Array[Double]]) extends ClassificationModel with Serializable with Saveable {
+ val theta: Array[Array[Double]],
+ val modelType: String)
+ extends ClassificationModel with Serializable with Saveable {
+
+ private[mllib] def this(labels: Array[Double], pi: Array[Double], theta: Array[Array[Double]]) =
+ this(labels, pi, theta, "Multinomial")
/** A Java-friendly constructor that takes three Iterable parameters. */
private[mllib] def this(
@@ -53,19 +63,19 @@ class NaiveBayesModel private[mllib] (
this(labels.asScala.toArray, pi.asScala.toArray, theta.asScala.toArray.map(_.asScala.toArray))
private val brzPi = new BDV[Double](pi)
- private val brzTheta = new BDM[Double](theta.length, theta(0).length)
-
- {
- // Need to put an extra pair of braces to prevent Scala treating `i` as a member.
- var i = 0
- while (i < theta.length) {
- var j = 0
- while (j < theta(i).length) {
- brzTheta(i, j) = theta(i)(j)
- j += 1
- }
- i += 1
- }
+ private val brzTheta = new BDM(theta(0).length, theta.length, theta.flatten).t
+
+ // Bernoulli scoring requires log(condprob) if 1, log(1-condprob) if 0.
+ // This precomputes log(1.0 - exp(theta)) and its sum which are used for the linear algebra
+ // application of this condition (in predict function).
+ private val (brzNegTheta, brzNegThetaSum) = modelType match {
+ case "Multinomial" => (None, None)
+ case "Bernoulli" =>
+ val negTheta = brzLog((brzExp(brzTheta.copy) :*= (-1.0)) :+= 1.0) // log(1.0 - exp(x))
+ (Option(negTheta), Option(brzSum(negTheta, Axis._1)))
+ case _ =>
+ // This should never happen.
+ throw new UnknownError(s"NaiveBayesModel was created with an unknown ModelType: $modelType")
}
override def predict(testData: RDD[Vector]): RDD[Double] = {
@@ -77,22 +87,78 @@ class NaiveBayesModel private[mllib] (
}
override def predict(testData: Vector): Double = {
- labels(brzArgmax(brzPi + brzTheta * testData.toBreeze))
+ modelType match {
+ case "Multinomial" =>
+ labels (brzArgmax (brzPi + brzTheta * testData.toBreeze) )
+ case "Bernoulli" =>
+ labels (brzArgmax (brzPi +
+ (brzTheta - brzNegTheta.get) * testData.toBreeze + brzNegThetaSum.get))
+ case _ =>
+ // This should never happen.
+ throw new UnknownError(s"NaiveBayesModel was created with an unknown ModelType: $modelType")
+ }
}
override def save(sc: SparkContext, path: String): Unit = {
- val data = NaiveBayesModel.SaveLoadV1_0.Data(labels, pi, theta)
- NaiveBayesModel.SaveLoadV1_0.save(sc, path, data)
+ val data = NaiveBayesModel.SaveLoadV2_0.Data(labels, pi, theta, modelType)
+ NaiveBayesModel.SaveLoadV2_0.save(sc, path, data)
}
- override protected def formatVersion: String = "1.0"
+ override protected def formatVersion: String = "2.0"
}
object NaiveBayesModel extends Loader[NaiveBayesModel] {
import org.apache.spark.mllib.util.Loader._
- private object SaveLoadV1_0 {
+ private[mllib] object SaveLoadV2_0 {
+
+ def thisFormatVersion: String = "2.0"
+
+ /** Hard-code class name string in case it changes in the future */
+ def thisClassName: String = "org.apache.spark.mllib.classification.NaiveBayesModel"
+
+ /** Model data for model import/export */
+ case class Data(
+ labels: Array[Double],
+ pi: Array[Double],
+ theta: Array[Array[Double]],
+ modelType: String)
+
+ def save(sc: SparkContext, path: String, data: Data): Unit = {
+ val sqlContext = new SQLContext(sc)
+ import sqlContext.implicits._
+
+ // Create JSON metadata.
+ val metadata = compact(render(
+ ("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~
+ ("numFeatures" -> data.theta(0).length) ~ ("numClasses" -> data.pi.length)))
+ sc.parallelize(Seq(metadata), 1).saveAsTextFile(metadataPath(path))
+
+ // Create Parquet data.
+ val dataRDD: DataFrame = sc.parallelize(Seq(data), 1).toDF()
+ dataRDD.saveAsParquetFile(dataPath(path))
+ }
+
+ def load(sc: SparkContext, path: String): NaiveBayesModel = {
+ val sqlContext = new SQLContext(sc)
+ // Load Parquet data.
+ val dataRDD = sqlContext.parquetFile(dataPath(path))
+ // Check schema explicitly since erasure makes it hard to use match-case for checking.
+ checkSchema[Data](dataRDD.schema)
+ val dataArray = dataRDD.select("labels", "pi", "theta", "modelType").take(1)
+ assert(dataArray.size == 1, s"Unable to load NaiveBayesModel data from: ${dataPath(path)}")
+ val data = dataArray(0)
+ val labels = data.getAs[Seq[Double]](0).toArray
+ val pi = data.getAs[Seq[Double]](1).toArray
+ val theta = data.getAs[Seq[Seq[Double]]](2).map(_.toArray).toArray
+ val modelType = data.getString(3)
+ new NaiveBayesModel(labels, pi, theta, modelType)
+ }
+
+ }
+
+ private[mllib] object SaveLoadV1_0 {
def thisFormatVersion: String = "1.0"
@@ -100,7 +166,10 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] {
def thisClassName: String = "org.apache.spark.mllib.classification.NaiveBayesModel"
/** Model data for model import/export */
- case class Data(labels: Array[Double], pi: Array[Double], theta: Array[Array[Double]])
+ case class Data(
+ labels: Array[Double],
+ pi: Array[Double],
+ theta: Array[Array[Double]])
def save(sc: SparkContext, path: String, data: Data): Unit = {
val sqlContext = new SQLContext(sc)
@@ -136,26 +205,32 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] {
override def load(sc: SparkContext, path: String): NaiveBayesModel = {
val (loadedClassName, version, metadata) = loadMetadata(sc, path)
val classNameV1_0 = SaveLoadV1_0.thisClassName
- (loadedClassName, version) match {
+ val classNameV2_0 = SaveLoadV2_0.thisClassName
+ val (model, numFeatures, numClasses) = (loadedClassName, version) match {
case (className, "1.0") if className == classNameV1_0 =>
val (numFeatures, numClasses) = ClassificationModel.getNumFeaturesClasses(metadata)
val model = SaveLoadV1_0.load(sc, path)
- assert(model.pi.size == numClasses,
- s"NaiveBayesModel.load expected $numClasses classes," +
- s" but class priors vector pi had ${model.pi.size} elements")
- assert(model.theta.size == numClasses,
- s"NaiveBayesModel.load expected $numClasses classes," +
- s" but class conditionals array theta had ${model.theta.size} elements")
- assert(model.theta.forall(_.size == numFeatures),
- s"NaiveBayesModel.load expected $numFeatures features," +
- s" but class conditionals array theta had elements of size:" +
- s" ${model.theta.map(_.size).mkString(",")}")
- model
+ (model, numFeatures, numClasses)
+ case (className, "2.0") if className == classNameV2_0 =>
+ val (numFeatures, numClasses) = ClassificationModel.getNumFeaturesClasses(metadata)
+ val model = SaveLoadV2_0.load(sc, path)
+ (model, numFeatures, numClasses)
case _ => throw new Exception(
s"NaiveBayesModel.load did not recognize model with (className, format version):" +
s"($loadedClassName, $version). Supported:\n" +
s" ($classNameV1_0, 1.0)")
}
+ assert(model.pi.size == numClasses,
+ s"NaiveBayesModel.load expected $numClasses classes," +
+ s" but class priors vector pi had ${model.pi.size} elements")
+ assert(model.theta.size == numClasses,
+ s"NaiveBayesModel.load expected $numClasses classes," +
+ s" but class conditionals array theta had ${model.theta.size} elements")
+ assert(model.theta.forall(_.size == numFeatures),
+ s"NaiveBayesModel.load expected $numFeatures features," +
+ s" but class conditionals array theta had elements of size:" +
+ s" ${model.theta.map(_.size).mkString(",")}")
+ model
}
}
@@ -167,9 +242,14 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] {
* document classification. By making every vector a 0-1 vector, it can also be used as
* Bernoulli NB ([[http://tinyurl.com/p7c96j6]]). The input feature values must be nonnegative.
*/
-class NaiveBayes private (private var lambda: Double) extends Serializable with Logging {
- def this() = this(1.0)
+class NaiveBayes private (
+ private var lambda: Double,
+ private var modelType: String) extends Serializable with Logging {
+
+ def this(lambda: Double) = this(lambda, "Multinomial")
+
+ def this() = this(1.0, "Multinomial")
/** Set the smoothing parameter. Default: 1.0. */
def setLambda(lambda: Double): NaiveBayes = {
@@ -177,10 +257,25 @@ class NaiveBayes private (private var lambda: Double) extends Serializable with
this
}
- /** Get the smoothing parameter. Default: 1.0. */
+ /** Get the smoothing parameter. */
def getLambda: Double = lambda
/**
+ * Set the model type using a string (case-sensitive).
+ * Supported options: "Multinomial" and "Bernoulli".
+ * (default: Multinomial)
+ */
+ def setModelType(modelType:String): NaiveBayes = {
+ require(NaiveBayes.supportedModelTypes.contains(modelType),
+ s"NaiveBayes was created with an unknown ModelType: $modelType")
+ this.modelType = modelType
+ this
+ }
+
+ /** Get the model type. */
+ def getModelType: String = this.modelType
+
+ /**
* Run the algorithm with the configured parameters on an input RDD of LabeledPoint entries.
*
* @param data RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
@@ -213,21 +308,30 @@ class NaiveBayes private (private var lambda: Double) extends Serializable with
mergeCombiners = (c1: (Long, BDV[Double]), c2: (Long, BDV[Double])) =>
(c1._1 + c2._1, c1._2 += c2._2)
).collect()
+
val numLabels = aggregated.length
var numDocuments = 0L
aggregated.foreach { case (_, (n, _)) =>
numDocuments += n
}
val numFeatures = aggregated.head match { case (_, (_, v)) => v.size }
+
val labels = new Array[Double](numLabels)
val pi = new Array[Double](numLabels)
val theta = Array.fill(numLabels)(new Array[Double](numFeatures))
+
val piLogDenom = math.log(numDocuments + numLabels * lambda)
var i = 0
aggregated.foreach { case (label, (n, sumTermFreqs)) =>
labels(i) = label
- val thetaLogDenom = math.log(brzSum(sumTermFreqs) + numFeatures * lambda)
pi(i) = math.log(n + lambda) - piLogDenom
+ val thetaLogDenom = modelType match {
+ case "Multinomial" => math.log(brzSum(sumTermFreqs) + numFeatures * lambda)
+ case "Bernoulli" => math.log(n + 2.0 * lambda)
+ case _ =>
+ // This should never happen.
+ throw new UnknownError(s"NaiveBayes was created with an unknown ModelType: $modelType")
+ }
var j = 0
while (j < numFeatures) {
theta(i)(j) = math.log(sumTermFreqs(j) + lambda) - thetaLogDenom
@@ -236,7 +340,7 @@ class NaiveBayes private (private var lambda: Double) extends Serializable with
i += 1
}
- new NaiveBayesModel(labels, pi, theta)
+ new NaiveBayesModel(labels, pi, theta, modelType)
}
}
@@ -244,13 +348,16 @@ class NaiveBayes private (private var lambda: Double) extends Serializable with
* Top-level methods for calling naive Bayes.
*/
object NaiveBayes {
+
+ /* Set of modelTypes that NaiveBayes supports */
+ private[mllib] val supportedModelTypes = Set("Multinomial", "Bernoulli")
+
/**
* Trains a Naive Bayes model given an RDD of `(label, features)` pairs.
*
- * This is the Multinomial NB ([[http://tinyurl.com/lsdw6p]]) which can handle all kinds of
- * discrete data. For example, by converting documents into TF-IDF vectors, it can be used for
- * document classification. By making every vector a 0-1 vector, it can also be used as
- * Bernoulli NB ([[http://tinyurl.com/p7c96j6]]).
+ * This is the default Multinomial NB ([[http://tinyurl.com/lsdw6p]]) which can handle all
+ * kinds of discrete data. For example, by converting documents into TF-IDF vectors, it
+ * can be used for document classification.
*
* This version of the method uses a default smoothing parameter of 1.0.
*
@@ -264,16 +371,40 @@ object NaiveBayes {
/**
* Trains a Naive Bayes model given an RDD of `(label, features)` pairs.
*
- * This is the Multinomial NB ([[http://tinyurl.com/lsdw6p]]) which can handle all kinds of
- * discrete data. For example, by converting documents into TF-IDF vectors, it can be used for
- * document classification. By making every vector a 0-1 vector, it can also be used as
- * Bernoulli NB ([[http://tinyurl.com/p7c96j6]]).
+ * This is the default Multinomial NB ([[http://tinyurl.com/lsdw6p]]) which can handle all
+ * kinds of discrete data. For example, by converting documents into TF-IDF vectors, it
+ * can be used for document classification.
*
* @param input RDD of `(label, array of features)` pairs. Every vector should be a frequency
* vector or a count vector.
* @param lambda The smoothing parameter
*/
def train(input: RDD[LabeledPoint], lambda: Double): NaiveBayesModel = {
- new NaiveBayes(lambda).run(input)
+ new NaiveBayes(lambda, "Multinomial").run(input)
+ }
+
+ /**
+ * Trains a Naive Bayes model given an RDD of `(label, features)` pairs.
+ *
+ * The model type can be set to either Multinomial NB ([[http://tinyurl.com/lsdw6p]])
+ * or Bernoulli NB ([[http://tinyurl.com/p7c96j6]]). The Multinomial NB can handle
+ * discrete count data and can be called by setting the model type to "multinomial".
+ * For example, it can be used with word counts or TF_IDF vectors of documents.
+ * The Bernoulli model fits presence or absence (0-1) counts. By making every vector a
+ * 0-1 vector and setting the model type to "bernoulli", the fits and predicts as
+ * Bernoulli NB.
+ *
+ * @param input RDD of `(label, array of features)` pairs. Every vector should be a frequency
+ * vector or a count vector.
+ * @param lambda The smoothing parameter
+ *
+ * @param modelType The type of NB model to fit from the enumeration NaiveBayesModels, can be
+ * multinomial or bernoulli
+ */
+ def train(input: RDD[LabeledPoint], lambda: Double, modelType: String): NaiveBayesModel = {
+ require(supportedModelTypes.contains(modelType),
+ s"NaiveBayes was created with an unknown ModelType: $modelType")
+ new NaiveBayes(lambda, modelType).run(input)
}
+
}
diff --git a/mllib/src/test/java/org/apache/spark/mllib/classification/JavaNaiveBayesSuite.java b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaNaiveBayesSuite.java
index 1c90522a07..71fb7f13c3 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/classification/JavaNaiveBayesSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaNaiveBayesSuite.java
@@ -17,20 +17,22 @@
package org.apache.spark.mllib.classification;
+import java.io.Serializable;
+import java.util.Arrays;
+import java.util.List;
+
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
+
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.mllib.regression.LabeledPoint;
-import org.junit.After;
-import org.junit.Assert;
-import org.junit.Before;
-import org.junit.Test;
-import java.io.Serializable;
-import java.util.Arrays;
-import java.util.List;
public class JavaNaiveBayesSuite implements Serializable {
private transient JavaSparkContext sc;
@@ -102,4 +104,11 @@ public class JavaNaiveBayesSuite implements Serializable {
// Should be able to get the first prediction.
predictions.first();
}
+
+ @Test
+ public void testModelTypeSetters() {
+ NaiveBayes nb = new NaiveBayes()
+ .setModelType("Bernoulli")
+ .setModelType("Multinomial");
+ }
}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala
index 5a27c7d230..f9fe3e006c 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala
@@ -19,6 +19,9 @@ package org.apache.spark.mllib.classification
import scala.util.Random
+import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, argmax => brzArgmax, sum => brzSum, Axis}
+import breeze.stats.distributions.{Multinomial => BrzMultinomial}
+
import org.scalatest.FunSuite
import org.apache.spark.SparkException
@@ -41,37 +44,48 @@ object NaiveBayesSuite {
// Generate input of the form Y = (theta * x).argmax()
def generateNaiveBayesInput(
- pi: Array[Double], // 1XC
- theta: Array[Array[Double]], // CXD
- nPoints: Int,
- seed: Int): Seq[LabeledPoint] = {
+ pi: Array[Double], // 1XC
+ theta: Array[Array[Double]], // CXD
+ nPoints: Int,
+ seed: Int,
+ modelType: String = "Multinomial",
+ sample: Int = 10): Seq[LabeledPoint] = {
val D = theta(0).length
val rnd = new Random(seed)
-
val _pi = pi.map(math.pow(math.E, _))
val _theta = theta.map(row => row.map(math.pow(math.E, _)))
for (i <- 0 until nPoints) yield {
val y = calcLabel(rnd.nextDouble(), _pi)
- val xi = Array.tabulate[Double](D) { j =>
- if (rnd.nextDouble() < _theta(y)(j)) 1 else 0
+ val xi = modelType match {
+ case "Bernoulli" => Array.tabulate[Double] (D) { j =>
+ if (rnd.nextDouble () < _theta(y)(j) ) 1 else 0
+ }
+ case "Multinomial" =>
+ val mult = BrzMultinomial(BDV(_theta(y)))
+ val emptyMap = (0 until D).map(x => (x, 0.0)).toMap
+ val counts = emptyMap ++ mult.sample(sample).groupBy(x => x).map {
+ case (index, reps) => (index, reps.size.toDouble)
+ }
+ counts.toArray.sortBy(_._1).map(_._2)
+ case _ =>
+ // This should never happen.
+ throw new UnknownError(s"NaiveBayesSuite found unknown ModelType: $modelType")
}
LabeledPoint(y, Vectors.dense(xi))
}
}
- private val smallPi = Array(0.5, 0.3, 0.2).map(math.log)
+ /** Bernoulli NaiveBayes with binary labels, 3 features */
+ private val binaryBernoulliModel = new NaiveBayesModel(labels = Array(0.0, 1.0),
+ pi = Array(0.2, 0.8), theta = Array(Array(0.1, 0.3, 0.6), Array(0.2, 0.4, 0.4)),
+ "Bernoulli")
- private val smallTheta = Array(
- Array(0.91, 0.03, 0.03, 0.03), // label 0
- Array(0.03, 0.91, 0.03, 0.03), // label 1
- Array(0.03, 0.03, 0.91, 0.03) // label 2
- ).map(_.map(math.log))
-
- /** Binary labels, 3 features */
- private val binaryModel = new NaiveBayesModel(labels = Array(0.0, 1.0), pi = Array(0.2, 0.8),
- theta = Array(Array(0.1, 0.3, 0.6), Array(0.2, 0.4, 0.4)))
+ /** Multinomial NaiveBayes with binary labels, 3 features */
+ private val binaryMultinomialModel = new NaiveBayesModel(labels = Array(0.0, 1.0),
+ pi = Array(0.2, 0.8), theta = Array(Array(0.1, 0.3, 0.6), Array(0.2, 0.4, 0.4)),
+ "Multinomial")
}
class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext {
@@ -85,6 +99,24 @@ class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext {
assert(numOfPredictions < input.length / 5)
}
+ def validateModelFit(
+ piData: Array[Double],
+ thetaData: Array[Array[Double]],
+ model: NaiveBayesModel) = {
+ def closeFit(d1: Double, d2: Double, precision: Double): Boolean = {
+ (d1 - d2).abs <= precision
+ }
+ val modelIndex = (0 until piData.length).zip(model.labels.map(_.toInt))
+ for (i <- modelIndex) {
+ assert(closeFit(math.exp(piData(i._2)), math.exp(model.pi(i._1)), 0.05))
+ }
+ for (i <- modelIndex) {
+ for (j <- 0 until thetaData(i._2).length) {
+ assert(closeFit(math.exp(thetaData(i._2)(j)), math.exp(model.theta(i._1)(j)), 0.05))
+ }
+ }
+ }
+
test("get, set params") {
val nb = new NaiveBayes()
nb.setLambda(2.0)
@@ -93,19 +125,53 @@ class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext {
assert(nb.getLambda === 3.0)
}
- test("Naive Bayes") {
- val nPoints = 10000
+ test("Naive Bayes Multinomial") {
+ val nPoints = 1000
+ val pi = Array(0.5, 0.1, 0.4).map(math.log)
+ val theta = Array(
+ Array(0.70, 0.10, 0.10, 0.10), // label 0
+ Array(0.10, 0.70, 0.10, 0.10), // label 1
+ Array(0.10, 0.10, 0.70, 0.10) // label 2
+ ).map(_.map(math.log))
+
+ val testData = NaiveBayesSuite.generateNaiveBayesInput(
+ pi, theta, nPoints, 42, "Multinomial")
+ val testRDD = sc.parallelize(testData, 2)
+ testRDD.cache()
+
+ val model = NaiveBayes.train(testRDD, 1.0, "Multinomial")
+ validateModelFit(pi, theta, model)
+
+ val validationData = NaiveBayesSuite.generateNaiveBayesInput(
+ pi, theta, nPoints, 17, "Multinomial")
+ val validationRDD = sc.parallelize(validationData, 2)
+
+ // Test prediction on RDD.
+ validatePrediction(model.predict(validationRDD.map(_.features)).collect(), validationData)
- val pi = NaiveBayesSuite.smallPi
- val theta = NaiveBayesSuite.smallTheta
+ // Test prediction on Array.
+ validatePrediction(validationData.map(row => model.predict(row.features)), validationData)
+ }
- val testData = NaiveBayesSuite.generateNaiveBayesInput(pi, theta, nPoints, 42)
+ test("Naive Bayes Bernoulli") {
+ val nPoints = 10000
+ val pi = Array(0.5, 0.3, 0.2).map(math.log)
+ val theta = Array(
+ Array(0.50, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.40), // label 0
+ Array(0.02, 0.70, 0.10, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02), // label 1
+ Array(0.02, 0.02, 0.60, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.30) // label 2
+ ).map(_.map(math.log))
+
+ val testData = NaiveBayesSuite.generateNaiveBayesInput(
+ pi, theta, nPoints, 45, "Bernoulli")
val testRDD = sc.parallelize(testData, 2)
testRDD.cache()
- val model = NaiveBayes.train(testRDD)
+ val model = NaiveBayes.train(testRDD, 1.0, "Bernoulli")
+ validateModelFit(pi, theta, model)
- val validationData = NaiveBayesSuite.generateNaiveBayesInput(pi, theta, nPoints, 17)
+ val validationData = NaiveBayesSuite.generateNaiveBayesInput(
+ pi, theta, nPoints, 20, "Bernoulli")
val validationRDD = sc.parallelize(validationData, 2)
// Test prediction on RDD.
@@ -142,19 +208,41 @@ class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext {
}
}
- test("model save/load") {
- val model = NaiveBayesSuite.binaryModel
+ test("model save/load: 2.0 to 2.0") {
+ val tempDir = Utils.createTempDir()
+ val path = tempDir.toURI.toString
+
+ Seq(NaiveBayesSuite.binaryBernoulliModel, NaiveBayesSuite.binaryMultinomialModel).map {
+ model =>
+ // Save model, load it back, and compare.
+ try {
+ model.save(sc, path)
+ val sameModel = NaiveBayesModel.load(sc, path)
+ assert(model.labels === sameModel.labels)
+ assert(model.pi === sameModel.pi)
+ assert(model.theta === sameModel.theta)
+ assert(model.modelType === sameModel.modelType)
+ } finally {
+ Utils.deleteRecursively(tempDir)
+ }
+ }
+ }
+
+ test("model save/load: 1.0 to 2.0") {
+ val model = NaiveBayesSuite.binaryMultinomialModel
val tempDir = Utils.createTempDir()
val path = tempDir.toURI.toString
- // Save model, load it back, and compare.
+ // Save model as version 1.0, load it back, and compare.
try {
- model.save(sc, path)
+ val data = NaiveBayesModel.SaveLoadV1_0.Data(model.labels, model.pi, model.theta)
+ NaiveBayesModel.SaveLoadV1_0.save(sc, path, data)
val sameModel = NaiveBayesModel.load(sc, path)
assert(model.labels === sameModel.labels)
assert(model.pi === sameModel.pi)
assert(model.theta === sameModel.theta)
+ assert(model.modelType === "Multinomial")
} finally {
Utils.deleteRecursively(tempDir)
}
@@ -172,8 +260,8 @@ class NaiveBayesClusterSuite extends FunSuite with LocalClusterSparkContext {
LabeledPoint(random.nextInt(2), Vectors.dense(Array.fill(n)(random.nextDouble())))
}
}
- // If we serialize data directly in the task closure, the size of the serialized task would be
- // greater than 1MB and hence Spark would throw an error.
+ // If we serialize data directly in the task closure, the size of the serialized task
+ // would be greater than 1MB and hence Spark would throw an error.
val model = NaiveBayes.train(examples)
val predictions = model.predict(examples.map(_.features))
}