aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/test
diff options
context:
space:
mode:
authorXiangrui Meng <meng@databricks.com>2015-05-21 10:30:08 -0700
committerJoseph K. Bradley <joseph@databricks.com>2015-05-21 10:30:27 -0700
commitb97a8053a02636b8f62a900d974cffa0e057441c (patch)
tree26fdf40a90c4a026718fdb15a665cc88b10f717d /mllib/src/test
parent3aa618510167ef72b4107d964a490be9d90da70d (diff)
downloadspark-b97a8053a02636b8f62a900d974cffa0e057441c.tar.gz
spark-b97a8053a02636b8f62a900d974cffa0e057441c.tar.bz2
spark-b97a8053a02636b8f62a900d974cffa0e057441c.zip
[SPARK-7752] [MLLIB] Use lowercase letters for NaiveBayes.modelType
to be consistent with other string names in MLlib. This PR also updates the implementation to use vals instead of hardcoded strings. jkbradley leahmcguire Author: Xiangrui Meng <meng@databricks.com> Closes #6277 from mengxr/SPARK-7752 and squashes the following commits: f38b662 [Xiangrui Meng] add another case _ back in test ae5c66a [Xiangrui Meng] model type -> modelType 711d1c6 [Xiangrui Meng] Merge remote-tracking branch 'apache/master' into SPARK-7752 40ae53e [Xiangrui Meng] fix Java test suite 264a814 [Xiangrui Meng] add case _ back 3c456a8 [Xiangrui Meng] update NB user guide 17bba53 [Xiangrui Meng] update naive Bayes to use lowercase model type strings (cherry picked from commit 13348e21b6b1c0df42c18b82b86c613291228863) Signed-off-by: Joseph K. Bradley <joseph@databricks.com>
Diffstat (limited to 'mllib/src/test')
-rw-r--r--mllib/src/test/java/org/apache/spark/mllib/classification/JavaNaiveBayesSuite.java4
-rw-r--r--mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala46
2 files changed, 27 insertions, 23 deletions
diff --git a/mllib/src/test/java/org/apache/spark/mllib/classification/JavaNaiveBayesSuite.java b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaNaiveBayesSuite.java
index 71fb7f13c3..3771c0ea7a 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/classification/JavaNaiveBayesSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaNaiveBayesSuite.java
@@ -108,7 +108,7 @@ public class JavaNaiveBayesSuite implements Serializable {
@Test
public void testModelTypeSetters() {
NaiveBayes nb = new NaiveBayes()
- .setModelType("Bernoulli")
- .setModelType("Multinomial");
+ .setModelType("bernoulli")
+ .setModelType("multinomial");
}
}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala
index 40a79a1f19..c111a78a55 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala
@@ -19,9 +19,8 @@ package org.apache.spark.mllib.classification
import scala.util.Random
-import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, argmax => brzArgmax, sum => brzSum, Axis}
+import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, argmax => brzArgmax, sum => brzSum}
import breeze.stats.distributions.{Multinomial => BrzMultinomial}
-
import org.scalatest.FunSuite
import org.apache.spark.SparkException
@@ -30,9 +29,10 @@ import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.util.{LocalClusterSparkContext, MLlibTestSparkContext}
import org.apache.spark.util.Utils
-
object NaiveBayesSuite {
+ import NaiveBayes.{Multinomial, Bernoulli}
+
private def calcLabel(p: Double, pi: Array[Double]): Int = {
var sum = 0.0
for (j <- 0 until pi.length) {
@@ -48,7 +48,7 @@ object NaiveBayesSuite {
theta: Array[Array[Double]], // CXD
nPoints: Int,
seed: Int,
- modelType: String = "Multinomial",
+ modelType: String = Multinomial,
sample: Int = 10): Seq[LabeledPoint] = {
val D = theta(0).length
val rnd = new Random(seed)
@@ -58,10 +58,10 @@ object NaiveBayesSuite {
for (i <- 0 until nPoints) yield {
val y = calcLabel(rnd.nextDouble(), _pi)
val xi = modelType match {
- case "Bernoulli" => Array.tabulate[Double] (D) { j =>
+ case Bernoulli => Array.tabulate[Double] (D) { j =>
if (rnd.nextDouble () < _theta(y)(j) ) 1 else 0
}
- case "Multinomial" =>
+ case Multinomial =>
val mult = BrzMultinomial(BDV(_theta(y)))
val emptyMap = (0 until D).map(x => (x, 0.0)).toMap
val counts = emptyMap ++ mult.sample(sample).groupBy(x => x).map {
@@ -70,7 +70,7 @@ object NaiveBayesSuite {
counts.toArray.sortBy(_._1).map(_._2)
case _ =>
// This should never happen.
- throw new UnknownError(s"NaiveBayesSuite found unknown ModelType: $modelType")
+ throw new UnknownError(s"Invalid modelType: $modelType.")
}
LabeledPoint(y, Vectors.dense(xi))
@@ -79,17 +79,17 @@ object NaiveBayesSuite {
/** Bernoulli NaiveBayes with binary labels, 3 features */
private val binaryBernoulliModel = new NaiveBayesModel(labels = Array(0.0, 1.0),
- pi = Array(0.2, 0.8), theta = Array(Array(0.1, 0.3, 0.6), Array(0.2, 0.4, 0.4)),
- "Bernoulli")
+ pi = Array(0.2, 0.8), theta = Array(Array(0.1, 0.3, 0.6), Array(0.2, 0.4, 0.4)), Bernoulli)
/** Multinomial NaiveBayes with binary labels, 3 features */
private val binaryMultinomialModel = new NaiveBayesModel(labels = Array(0.0, 1.0),
- pi = Array(0.2, 0.8), theta = Array(Array(0.1, 0.3, 0.6), Array(0.2, 0.4, 0.4)),
- "Multinomial")
+ pi = Array(0.2, 0.8), theta = Array(Array(0.1, 0.3, 0.6), Array(0.2, 0.4, 0.4)), Multinomial)
}
class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext {
+ import NaiveBayes.{Multinomial, Bernoulli}
+
def validatePrediction(predictions: Seq[Double], input: Seq[LabeledPoint]) {
val numOfPredictions = predictions.zip(input).count {
case (prediction, expected) =>
@@ -117,6 +117,11 @@ class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext {
}
}
+ test("model types") {
+ assert(Multinomial === "multinomial")
+ assert(Bernoulli === "bernoulli")
+ }
+
test("get, set params") {
val nb = new NaiveBayes()
nb.setLambda(2.0)
@@ -134,16 +139,15 @@ class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext {
Array(0.10, 0.10, 0.70, 0.10) // label 2
).map(_.map(math.log))
- val testData = NaiveBayesSuite.generateNaiveBayesInput(
- pi, theta, nPoints, 42, "Multinomial")
+ val testData = NaiveBayesSuite.generateNaiveBayesInput(pi, theta, nPoints, 42, Multinomial)
val testRDD = sc.parallelize(testData, 2)
testRDD.cache()
- val model = NaiveBayes.train(testRDD, 1.0, "Multinomial")
+ val model = NaiveBayes.train(testRDD, 1.0, Multinomial)
validateModelFit(pi, theta, model)
val validationData = NaiveBayesSuite.generateNaiveBayesInput(
- pi, theta, nPoints, 17, "Multinomial")
+ pi, theta, nPoints, 17, Multinomial)
val validationRDD = sc.parallelize(validationData, 2)
// Test prediction on RDD.
@@ -163,15 +167,15 @@ class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext {
).map(_.map(math.log))
val testData = NaiveBayesSuite.generateNaiveBayesInput(
- pi, theta, nPoints, 45, "Bernoulli")
+ pi, theta, nPoints, 45, Bernoulli)
val testRDD = sc.parallelize(testData, 2)
testRDD.cache()
- val model = NaiveBayes.train(testRDD, 1.0, "Bernoulli")
+ val model = NaiveBayes.train(testRDD, 1.0, Bernoulli)
validateModelFit(pi, theta, model)
val validationData = NaiveBayesSuite.generateNaiveBayesInput(
- pi, theta, nPoints, 20, "Bernoulli")
+ pi, theta, nPoints, 20, Bernoulli)
val validationRDD = sc.parallelize(validationData, 2)
// Test prediction on RDD.
@@ -216,7 +220,7 @@ class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext {
LabeledPoint(1.0, Vectors.dense(0.0)))
intercept[SparkException] {
- NaiveBayes.train(sc.makeRDD(badTrain, 2), 1.0, "Bernoulli")
+ NaiveBayes.train(sc.makeRDD(badTrain, 2), 1.0, Bernoulli)
}
val okTrain = Seq(
@@ -235,7 +239,7 @@ class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext {
Vectors.dense(1.0),
Vectors.dense(0.0))
- val model = NaiveBayes.train(sc.makeRDD(okTrain, 2), 1.0, "Bernoulli")
+ val model = NaiveBayes.train(sc.makeRDD(okTrain, 2), 1.0, Bernoulli)
intercept[SparkException] {
model.predict(sc.makeRDD(badPredict, 2)).collect()
}
@@ -275,7 +279,7 @@ class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext {
assert(model.labels === sameModel.labels)
assert(model.pi === sameModel.pi)
assert(model.theta === sameModel.theta)
- assert(model.modelType === "Multinomial")
+ assert(model.modelType === Multinomial)
} finally {
Utils.deleteRecursively(tempDir)
}