aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
authorXiangrui Meng <meng@databricks.com>2015-05-21 10:30:08 -0700
committerJoseph K. Bradley <joseph@databricks.com>2015-05-21 10:30:08 -0700
commit13348e21b6b1c0df42c18b82b86c613291228863 (patch)
tree83915ccd4d2ff42e0da660f1f9e26c8c893eaff4 /mllib
parenta25c1ab8f04a4e19d82ff4c18a0b1689d8b3ddac (diff)
downloadspark-13348e21b6b1c0df42c18b82b86c613291228863.tar.gz
spark-13348e21b6b1c0df42c18b82b86c613291228863.tar.bz2
spark-13348e21b6b1c0df42c18b82b86c613291228863.zip
[SPARK-7752] [MLLIB] Use lowercase letters for NaiveBayes.modelType
to be consistent with other string names in MLlib. This PR also updates the implementation to use vals instead of hardcoded strings. jkbradley leahmcguire Author: Xiangrui Meng <meng@databricks.com> Closes #6277 from mengxr/SPARK-7752 and squashes the following commits: f38b662 [Xiangrui Meng] add another case _ back in test ae5c66a [Xiangrui Meng] model type -> modelType 711d1c6 [Xiangrui Meng] Merge remote-tracking branch 'apache/master' into SPARK-7752 40ae53e [Xiangrui Meng] fix Java test suite 264a814 [Xiangrui Meng] add case _ back 3c456a8 [Xiangrui Meng] update NB user guide 17bba53 [Xiangrui Meng] update naive Bayes to use lowercase model type strings
Diffstat (limited to 'mllib')
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala75
-rw-r--r--mllib/src/test/java/org/apache/spark/mllib/classification/JavaNaiveBayesSuite.java4
-rw-r--r--mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala46
3 files changed, 70 insertions, 55 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
index cffe9ef1e0..f51ee36d0d 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
@@ -25,13 +25,12 @@ import org.json4s.JsonDSL._
import org.json4s.jackson.JsonMethods._
import org.apache.spark.{Logging, SparkContext, SparkException}
-import org.apache.spark.mllib.linalg.{BLAS, DenseMatrix, DenseVector, SparseVector, Vector, Vectors}
+import org.apache.spark.mllib.linalg.{BLAS, DenseMatrix, DenseVector, SparseVector, Vector}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.util.{Loader, Saveable}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, SQLContext}
-
/**
* Model for Naive Bayes Classifiers.
*
@@ -39,7 +38,7 @@ import org.apache.spark.sql.{DataFrame, SQLContext}
* @param pi log of class priors, whose dimension is C, number of labels
* @param theta log of class conditional probabilities, whose dimension is C-by-D,
* where D is number of features
- * @param modelType The type of NB model to fit can be "Multinomial" or "Bernoulli"
+ * @param modelType The type of NB model to fit can be "multinomial" or "bernoulli"
*/
class NaiveBayesModel private[mllib] (
val labels: Array[Double],
@@ -48,11 +47,13 @@ class NaiveBayesModel private[mllib] (
val modelType: String)
extends ClassificationModel with Serializable with Saveable {
+ import NaiveBayes.{Bernoulli, Multinomial, supportedModelTypes}
+
private val piVector = new DenseVector(pi)
- private val thetaMatrix = new DenseMatrix(labels.size, theta(0).size, theta.flatten, true)
+ private val thetaMatrix = new DenseMatrix(labels.length, theta(0).length, theta.flatten, true)
private[mllib] def this(labels: Array[Double], pi: Array[Double], theta: Array[Array[Double]]) =
- this(labels, pi, theta, "Multinomial")
+ this(labels, pi, theta, NaiveBayes.Multinomial)
/** A Java-friendly constructor that takes three Iterable parameters. */
private[mllib] def this(
@@ -61,12 +62,15 @@ class NaiveBayesModel private[mllib] (
theta: JIterable[JIterable[Double]]) =
this(labels.asScala.toArray, pi.asScala.toArray, theta.asScala.toArray.map(_.asScala.toArray))
+ require(supportedModelTypes.contains(modelType),
+ s"Invalid modelType $modelType. Supported modelTypes are $supportedModelTypes.")
+
// Bernoulli scoring requires log(condprob) if 1, log(1-condprob) if 0.
// This precomputes log(1.0 - exp(theta)) and its sum which are used for the linear algebra
// application of this condition (in predict function).
private val (thetaMinusNegTheta, negThetaSum) = modelType match {
- case "Multinomial" => (None, None)
- case "Bernoulli" =>
+ case Multinomial => (None, None)
+ case Bernoulli =>
val negTheta = thetaMatrix.map(value => math.log(1.0 - math.exp(value)))
val ones = new DenseVector(Array.fill(thetaMatrix.numCols){1.0})
val thetaMinusNegTheta = thetaMatrix.map { value =>
@@ -75,7 +79,7 @@ class NaiveBayesModel private[mllib] (
(Option(thetaMinusNegTheta), Option(negTheta.multiply(ones)))
case _ =>
// This should never happen.
- throw new UnknownError(s"NaiveBayesModel was created with an unknown ModelType: $modelType")
+ throw new UnknownError(s"Invalid modelType: $modelType.")
}
override def predict(testData: RDD[Vector]): RDD[Double] = {
@@ -88,15 +92,15 @@ class NaiveBayesModel private[mllib] (
override def predict(testData: Vector): Double = {
modelType match {
- case "Multinomial" =>
+ case Multinomial =>
val prob = thetaMatrix.multiply(testData)
BLAS.axpy(1.0, piVector, prob)
labels(prob.argmax)
- case "Bernoulli" =>
+ case Bernoulli =>
testData.foreachActive { (index, value) =>
if (value != 0.0 && value != 1.0) {
throw new SparkException(
- s"Bernoulli Naive Bayes requires 0 or 1 feature values but found $testData.")
+ s"Bernoulli naive Bayes requires 0 or 1 feature values but found $testData.")
}
}
val prob = thetaMinusNegTheta.get.multiply(testData)
@@ -105,7 +109,7 @@ class NaiveBayesModel private[mllib] (
labels(prob.argmax)
case _ =>
// This should never happen.
- throw new UnknownError(s"NaiveBayesModel was created with an unknown ModelType: $modelType")
+ throw new UnknownError(s"Invalid modelType: $modelType.")
}
}
@@ -230,16 +234,16 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] {
s"($loadedClassName, $version). Supported:\n" +
s" ($classNameV1_0, 1.0)")
}
- assert(model.pi.size == numClasses,
+ assert(model.pi.length == numClasses,
s"NaiveBayesModel.load expected $numClasses classes," +
- s" but class priors vector pi had ${model.pi.size} elements")
- assert(model.theta.size == numClasses,
+ s" but class priors vector pi had ${model.pi.length} elements")
+ assert(model.theta.length == numClasses,
s"NaiveBayesModel.load expected $numClasses classes," +
- s" but class conditionals array theta had ${model.theta.size} elements")
- assert(model.theta.forall(_.size == numFeatures),
+ s" but class conditionals array theta had ${model.theta.length} elements")
+ assert(model.theta.forall(_.length == numFeatures),
s"NaiveBayesModel.load expected $numFeatures features," +
s" but class conditionals array theta had elements of size:" +
- s" ${model.theta.map(_.size).mkString(",")}")
+ s" ${model.theta.map(_.length).mkString(",")}")
model
}
}
@@ -257,9 +261,11 @@ class NaiveBayes private (
private var lambda: Double,
private var modelType: String) extends Serializable with Logging {
- def this(lambda: Double) = this(lambda, "Multinomial")
+ import NaiveBayes.{Bernoulli, Multinomial}
- def this() = this(1.0, "Multinomial")
+ def this(lambda: Double) = this(lambda, NaiveBayes.Multinomial)
+
+ def this() = this(1.0, NaiveBayes.Multinomial)
/** Set the smoothing parameter. Default: 1.0. */
def setLambda(lambda: Double): NaiveBayes = {
@@ -272,12 +278,11 @@ class NaiveBayes private (
/**
* Set the model type using a string (case-sensitive).
- * Supported options: "Multinomial" and "Bernoulli".
- * (default: Multinomial)
+ * Supported options: "multinomial" (default) and "bernoulli".
*/
- def setModelType(modelType:String): NaiveBayes = {
+ def setModelType(modelType: String): NaiveBayes = {
require(NaiveBayes.supportedModelTypes.contains(modelType),
- s"NaiveBayes was created with an unknown ModelType: $modelType")
+ s"NaiveBayes was created with an unknown modelType: $modelType.")
this.modelType = modelType
this
}
@@ -308,7 +313,7 @@ class NaiveBayes private (
}
if (!values.forall(v => v == 0.0 || v == 1.0)) {
throw new SparkException(
- s"Bernoulli Naive Bayes requires 0 or 1 feature values but found $v.")
+ s"Bernoulli naive Bayes requires 0 or 1 feature values but found $v.")
}
}
@@ -317,7 +322,7 @@ class NaiveBayes private (
// TODO: similar to reduceByKeyLocally to save one stage.
val aggregated = data.map(p => (p.label, p.features)).combineByKey[(Long, DenseVector)](
createCombiner = (v: Vector) => {
- if (modelType == "Bernoulli") {
+ if (modelType == Bernoulli) {
requireZeroOneBernoulliValues(v)
} else {
requireNonnegativeValues(v)
@@ -352,11 +357,11 @@ class NaiveBayes private (
labels(i) = label
pi(i) = math.log(n + lambda) - piLogDenom
val thetaLogDenom = modelType match {
- case "Multinomial" => math.log(sumTermFreqs.values.sum + numFeatures * lambda)
- case "Bernoulli" => math.log(n + 2.0 * lambda)
+ case Multinomial => math.log(sumTermFreqs.values.sum + numFeatures * lambda)
+ case Bernoulli => math.log(n + 2.0 * lambda)
case _ =>
// This should never happen.
- throw new UnknownError(s"NaiveBayes was created with an unknown ModelType: $modelType")
+ throw new UnknownError(s"Invalid modelType: $modelType.")
}
var j = 0
while (j < numFeatures) {
@@ -375,8 +380,14 @@ class NaiveBayes private (
*/
object NaiveBayes {
+ /** String name for multinomial model type. */
+ private[classification] val Multinomial: String = "multinomial"
+
+ /** String name for Bernoulli model type. */
+ private[classification] val Bernoulli: String = "bernoulli"
+
/* Set of modelTypes that NaiveBayes supports */
- private[mllib] val supportedModelTypes = Set("Multinomial", "Bernoulli")
+ private[classification] val supportedModelTypes = Set(Multinomial, Bernoulli)
/**
* Trains a Naive Bayes model given an RDD of `(label, features)` pairs.
@@ -406,7 +417,7 @@ object NaiveBayes {
* @param lambda The smoothing parameter
*/
def train(input: RDD[LabeledPoint], lambda: Double): NaiveBayesModel = {
- new NaiveBayes(lambda, "Multinomial").run(input)
+ new NaiveBayes(lambda, Multinomial).run(input)
}
/**
@@ -429,7 +440,7 @@ object NaiveBayes {
*/
def train(input: RDD[LabeledPoint], lambda: Double, modelType: String): NaiveBayesModel = {
require(supportedModelTypes.contains(modelType),
- s"NaiveBayes was created with an unknown ModelType: $modelType")
+ s"NaiveBayes was created with an unknown modelType: $modelType.")
new NaiveBayes(lambda, modelType).run(input)
}
diff --git a/mllib/src/test/java/org/apache/spark/mllib/classification/JavaNaiveBayesSuite.java b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaNaiveBayesSuite.java
index 71fb7f13c3..3771c0ea7a 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/classification/JavaNaiveBayesSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaNaiveBayesSuite.java
@@ -108,7 +108,7 @@ public class JavaNaiveBayesSuite implements Serializable {
@Test
public void testModelTypeSetters() {
NaiveBayes nb = new NaiveBayes()
- .setModelType("Bernoulli")
- .setModelType("Multinomial");
+ .setModelType("bernoulli")
+ .setModelType("multinomial");
}
}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala
index 40a79a1f19..c111a78a55 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala
@@ -19,9 +19,8 @@ package org.apache.spark.mllib.classification
import scala.util.Random
-import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, argmax => brzArgmax, sum => brzSum, Axis}
+import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, argmax => brzArgmax, sum => brzSum}
import breeze.stats.distributions.{Multinomial => BrzMultinomial}
-
import org.scalatest.FunSuite
import org.apache.spark.SparkException
@@ -30,9 +29,10 @@ import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.util.{LocalClusterSparkContext, MLlibTestSparkContext}
import org.apache.spark.util.Utils
-
object NaiveBayesSuite {
+ import NaiveBayes.{Multinomial, Bernoulli}
+
private def calcLabel(p: Double, pi: Array[Double]): Int = {
var sum = 0.0
for (j <- 0 until pi.length) {
@@ -48,7 +48,7 @@ object NaiveBayesSuite {
theta: Array[Array[Double]], // CXD
nPoints: Int,
seed: Int,
- modelType: String = "Multinomial",
+ modelType: String = Multinomial,
sample: Int = 10): Seq[LabeledPoint] = {
val D = theta(0).length
val rnd = new Random(seed)
@@ -58,10 +58,10 @@ object NaiveBayesSuite {
for (i <- 0 until nPoints) yield {
val y = calcLabel(rnd.nextDouble(), _pi)
val xi = modelType match {
- case "Bernoulli" => Array.tabulate[Double] (D) { j =>
+ case Bernoulli => Array.tabulate[Double] (D) { j =>
if (rnd.nextDouble () < _theta(y)(j) ) 1 else 0
}
- case "Multinomial" =>
+ case Multinomial =>
val mult = BrzMultinomial(BDV(_theta(y)))
val emptyMap = (0 until D).map(x => (x, 0.0)).toMap
val counts = emptyMap ++ mult.sample(sample).groupBy(x => x).map {
@@ -70,7 +70,7 @@ object NaiveBayesSuite {
counts.toArray.sortBy(_._1).map(_._2)
case _ =>
// This should never happen.
- throw new UnknownError(s"NaiveBayesSuite found unknown ModelType: $modelType")
+ throw new UnknownError(s"Invalid modelType: $modelType.")
}
LabeledPoint(y, Vectors.dense(xi))
@@ -79,17 +79,17 @@ object NaiveBayesSuite {
/** Bernoulli NaiveBayes with binary labels, 3 features */
private val binaryBernoulliModel = new NaiveBayesModel(labels = Array(0.0, 1.0),
- pi = Array(0.2, 0.8), theta = Array(Array(0.1, 0.3, 0.6), Array(0.2, 0.4, 0.4)),
- "Bernoulli")
+ pi = Array(0.2, 0.8), theta = Array(Array(0.1, 0.3, 0.6), Array(0.2, 0.4, 0.4)), Bernoulli)
/** Multinomial NaiveBayes with binary labels, 3 features */
private val binaryMultinomialModel = new NaiveBayesModel(labels = Array(0.0, 1.0),
- pi = Array(0.2, 0.8), theta = Array(Array(0.1, 0.3, 0.6), Array(0.2, 0.4, 0.4)),
- "Multinomial")
+ pi = Array(0.2, 0.8), theta = Array(Array(0.1, 0.3, 0.6), Array(0.2, 0.4, 0.4)), Multinomial)
}
class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext {
+ import NaiveBayes.{Multinomial, Bernoulli}
+
def validatePrediction(predictions: Seq[Double], input: Seq[LabeledPoint]) {
val numOfPredictions = predictions.zip(input).count {
case (prediction, expected) =>
@@ -117,6 +117,11 @@ class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext {
}
}
+ test("model types") {
+ assert(Multinomial === "multinomial")
+ assert(Bernoulli === "bernoulli")
+ }
+
test("get, set params") {
val nb = new NaiveBayes()
nb.setLambda(2.0)
@@ -134,16 +139,15 @@ class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext {
Array(0.10, 0.10, 0.70, 0.10) // label 2
).map(_.map(math.log))
- val testData = NaiveBayesSuite.generateNaiveBayesInput(
- pi, theta, nPoints, 42, "Multinomial")
+ val testData = NaiveBayesSuite.generateNaiveBayesInput(pi, theta, nPoints, 42, Multinomial)
val testRDD = sc.parallelize(testData, 2)
testRDD.cache()
- val model = NaiveBayes.train(testRDD, 1.0, "Multinomial")
+ val model = NaiveBayes.train(testRDD, 1.0, Multinomial)
validateModelFit(pi, theta, model)
val validationData = NaiveBayesSuite.generateNaiveBayesInput(
- pi, theta, nPoints, 17, "Multinomial")
+ pi, theta, nPoints, 17, Multinomial)
val validationRDD = sc.parallelize(validationData, 2)
// Test prediction on RDD.
@@ -163,15 +167,15 @@ class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext {
).map(_.map(math.log))
val testData = NaiveBayesSuite.generateNaiveBayesInput(
- pi, theta, nPoints, 45, "Bernoulli")
+ pi, theta, nPoints, 45, Bernoulli)
val testRDD = sc.parallelize(testData, 2)
testRDD.cache()
- val model = NaiveBayes.train(testRDD, 1.0, "Bernoulli")
+ val model = NaiveBayes.train(testRDD, 1.0, Bernoulli)
validateModelFit(pi, theta, model)
val validationData = NaiveBayesSuite.generateNaiveBayesInput(
- pi, theta, nPoints, 20, "Bernoulli")
+ pi, theta, nPoints, 20, Bernoulli)
val validationRDD = sc.parallelize(validationData, 2)
// Test prediction on RDD.
@@ -216,7 +220,7 @@ class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext {
LabeledPoint(1.0, Vectors.dense(0.0)))
intercept[SparkException] {
- NaiveBayes.train(sc.makeRDD(badTrain, 2), 1.0, "Bernoulli")
+ NaiveBayes.train(sc.makeRDD(badTrain, 2), 1.0, Bernoulli)
}
val okTrain = Seq(
@@ -235,7 +239,7 @@ class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext {
Vectors.dense(1.0),
Vectors.dense(0.0))
- val model = NaiveBayes.train(sc.makeRDD(okTrain, 2), 1.0, "Bernoulli")
+ val model = NaiveBayes.train(sc.makeRDD(okTrain, 2), 1.0, Bernoulli)
intercept[SparkException] {
model.predict(sc.makeRDD(badPredict, 2)).collect()
}
@@ -275,7 +279,7 @@ class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext {
assert(model.labels === sameModel.labels)
assert(model.pi === sameModel.pi)
assert(model.theta === sameModel.theta)
- assert(model.modelType === "Multinomial")
+ assert(model.modelType === Multinomial)
} finally {
Utils.deleteRecursively(tempDir)
}