diff options
author | Xiangrui Meng <meng@databricks.com> | 2015-05-21 10:30:08 -0700 |
---|---|---|
committer | Joseph K. Bradley <joseph@databricks.com> | 2015-05-21 10:30:27 -0700 |
commit | b97a8053a02636b8f62a900d974cffa0e057441c (patch) | |
tree | 26fdf40a90c4a026718fdb15a665cc88b10f717d /mllib/src/test | |
parent | 3aa618510167ef72b4107d964a490be9d90da70d (diff) | |
download | spark-b97a8053a02636b8f62a900d974cffa0e057441c.tar.gz spark-b97a8053a02636b8f62a900d974cffa0e057441c.tar.bz2 spark-b97a8053a02636b8f62a900d974cffa0e057441c.zip |
[SPARK-7752] [MLLIB] Use lowercase letters for NaiveBayes.modelType
to be consistent with other string names in MLlib. This PR also updates the implementation to use vals instead of hardcoded strings. jkbradley leahmcguire
Author: Xiangrui Meng <meng@databricks.com>
Closes #6277 from mengxr/SPARK-7752 and squashes the following commits:
f38b662 [Xiangrui Meng] add another case _ back in test
ae5c66a [Xiangrui Meng] model type -> modelType
711d1c6 [Xiangrui Meng] Merge remote-tracking branch 'apache/master' into SPARK-7752
40ae53e [Xiangrui Meng] fix Java test suite
264a814 [Xiangrui Meng] add case _ back
3c456a8 [Xiangrui Meng] update NB user guide
17bba53 [Xiangrui Meng] update naive Bayes to use lowercase model type strings
(cherry picked from commit 13348e21b6b1c0df42c18b82b86c613291228863)
Signed-off-by: Joseph K. Bradley <joseph@databricks.com>
Diffstat (limited to 'mllib/src/test')
-rw-r--r-- | mllib/src/test/java/org/apache/spark/mllib/classification/JavaNaiveBayesSuite.java | 4 | ||||
-rw-r--r-- | mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala | 46 |
2 files changed, 27 insertions, 23 deletions
diff --git a/mllib/src/test/java/org/apache/spark/mllib/classification/JavaNaiveBayesSuite.java b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaNaiveBayesSuite.java index 71fb7f13c3..3771c0ea7a 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/classification/JavaNaiveBayesSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaNaiveBayesSuite.java @@ -108,7 +108,7 @@ public class JavaNaiveBayesSuite implements Serializable { @Test public void testModelTypeSetters() { NaiveBayes nb = new NaiveBayes() - .setModelType("Bernoulli") - .setModelType("Multinomial"); + .setModelType("bernoulli") + .setModelType("multinomial"); } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala index 40a79a1f19..c111a78a55 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala @@ -19,9 +19,8 @@ package org.apache.spark.mllib.classification import scala.util.Random -import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, argmax => brzArgmax, sum => brzSum, Axis} +import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, argmax => brzArgmax, sum => brzSum} import breeze.stats.distributions.{Multinomial => BrzMultinomial} - import org.scalatest.FunSuite import org.apache.spark.SparkException @@ -30,9 +29,10 @@ import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.util.{LocalClusterSparkContext, MLlibTestSparkContext} import org.apache.spark.util.Utils - object NaiveBayesSuite { + import NaiveBayes.{Multinomial, Bernoulli} + private def calcLabel(p: Double, pi: Array[Double]): Int = { var sum = 0.0 for (j <- 0 until pi.length) { @@ -48,7 +48,7 @@ object NaiveBayesSuite { theta: Array[Array[Double]], // CXD nPoints: Int, seed: Int, - modelType: String = "Multinomial", + modelType: String = Multinomial, sample: Int = 10): Seq[LabeledPoint] = { val D = theta(0).length val rnd = new Random(seed) @@ -58,10 +58,10 @@ object NaiveBayesSuite { for (i <- 0 until nPoints) yield { val y = calcLabel(rnd.nextDouble(), _pi) val xi = modelType match { - case "Bernoulli" => Array.tabulate[Double] (D) { j => + case Bernoulli => Array.tabulate[Double] (D) { j => if (rnd.nextDouble () < _theta(y)(j) ) 1 else 0 } - case "Multinomial" => + case Multinomial => val mult = BrzMultinomial(BDV(_theta(y))) val emptyMap = (0 until D).map(x => (x, 0.0)).toMap val counts = emptyMap ++ mult.sample(sample).groupBy(x => x).map { @@ -70,7 +70,7 @@ object NaiveBayesSuite { counts.toArray.sortBy(_._1).map(_._2) case _ => // This should never happen. - throw new UnknownError(s"NaiveBayesSuite found unknown ModelType: $modelType") + throw new UnknownError(s"Invalid modelType: $modelType.") } LabeledPoint(y, Vectors.dense(xi)) @@ -79,17 +79,17 @@ object NaiveBayesSuite { /** Bernoulli NaiveBayes with binary labels, 3 features */ private val binaryBernoulliModel = new NaiveBayesModel(labels = Array(0.0, 1.0), - pi = Array(0.2, 0.8), theta = Array(Array(0.1, 0.3, 0.6), Array(0.2, 0.4, 0.4)), - "Bernoulli") + pi = Array(0.2, 0.8), theta = Array(Array(0.1, 0.3, 0.6), Array(0.2, 0.4, 0.4)), Bernoulli) /** Multinomial NaiveBayes with binary labels, 3 features */ private val binaryMultinomialModel = new NaiveBayesModel(labels = Array(0.0, 1.0), - pi = Array(0.2, 0.8), theta = Array(Array(0.1, 0.3, 0.6), Array(0.2, 0.4, 0.4)), - "Multinomial") + pi = Array(0.2, 0.8), theta = Array(Array(0.1, 0.3, 0.6), Array(0.2, 0.4, 0.4)), Multinomial) } class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext { + import NaiveBayes.{Multinomial, Bernoulli} + def validatePrediction(predictions: Seq[Double], input: Seq[LabeledPoint]) { val numOfPredictions = predictions.zip(input).count { case (prediction, expected) => @@ -117,6 +117,11 @@ class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext { } } + test("model types") { + assert(Multinomial === "multinomial") + assert(Bernoulli === "bernoulli") + } + test("get, set params") { val nb = new NaiveBayes() nb.setLambda(2.0) @@ -134,16 +139,15 @@ class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext { Array(0.10, 0.10, 0.70, 0.10) // label 2 ).map(_.map(math.log)) - val testData = NaiveBayesSuite.generateNaiveBayesInput( - pi, theta, nPoints, 42, "Multinomial") + val testData = NaiveBayesSuite.generateNaiveBayesInput(pi, theta, nPoints, 42, Multinomial) val testRDD = sc.parallelize(testData, 2) testRDD.cache() - val model = NaiveBayes.train(testRDD, 1.0, "Multinomial") + val model = NaiveBayes.train(testRDD, 1.0, Multinomial) validateModelFit(pi, theta, model) val validationData = NaiveBayesSuite.generateNaiveBayesInput( - pi, theta, nPoints, 17, "Multinomial") + pi, theta, nPoints, 17, Multinomial) val validationRDD = sc.parallelize(validationData, 2) // Test prediction on RDD. @@ -163,15 +167,15 @@ class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext { ).map(_.map(math.log)) val testData = NaiveBayesSuite.generateNaiveBayesInput( - pi, theta, nPoints, 45, "Bernoulli") + pi, theta, nPoints, 45, Bernoulli) val testRDD = sc.parallelize(testData, 2) testRDD.cache() - val model = NaiveBayes.train(testRDD, 1.0, "Bernoulli") + val model = NaiveBayes.train(testRDD, 1.0, Bernoulli) validateModelFit(pi, theta, model) val validationData = NaiveBayesSuite.generateNaiveBayesInput( - pi, theta, nPoints, 20, "Bernoulli") + pi, theta, nPoints, 20, Bernoulli) val validationRDD = sc.parallelize(validationData, 2) // Test prediction on RDD. @@ -216,7 +220,7 @@ class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext { LabeledPoint(1.0, Vectors.dense(0.0))) intercept[SparkException] { - NaiveBayes.train(sc.makeRDD(badTrain, 2), 1.0, "Bernoulli") + NaiveBayes.train(sc.makeRDD(badTrain, 2), 1.0, Bernoulli) } val okTrain = Seq( @@ -235,7 +239,7 @@ class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext { Vectors.dense(1.0), Vectors.dense(0.0)) - val model = NaiveBayes.train(sc.makeRDD(okTrain, 2), 1.0, "Bernoulli") + val model = NaiveBayes.train(sc.makeRDD(okTrain, 2), 1.0, Bernoulli) intercept[SparkException] { model.predict(sc.makeRDD(badPredict, 2)).collect() } @@ -275,7 +279,7 @@ class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext { assert(model.labels === sameModel.labels) assert(model.pi === sameModel.pi) assert(model.theta === sameModel.theta) - assert(model.modelType === "Multinomial") + assert(model.modelType === Multinomial) } finally { Utils.deleteRecursively(tempDir) } |