aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
authorXiangrui Meng <meng@databricks.com>2014-08-19 21:01:23 -0700
committerXiangrui Meng <meng@databricks.com>2014-08-19 21:01:23 -0700
commit068b6fe6a10eb1c6b2102d88832203267f030e85 (patch)
treeeb12c866970102b636d0edb80351bee0b6cb7b28 /mllib
parent0e3ab94d413fd70fff748fded42ab5e2ebd66fcc (diff)
downloadspark-068b6fe6a10eb1c6b2102d88832203267f030e85.tar.gz
spark-068b6fe6a10eb1c6b2102d88832203267f030e85.tar.bz2
spark-068b6fe6a10eb1c6b2102d88832203267f030e85.zip
[SPARK-3130][MLLIB] detect negative values in naive Bayes
because NB treats feature values as term frequencies. jkbradley Author: Xiangrui Meng <meng@databricks.com> Closes #2038 from mengxr/nb-neg and squashes the following commits: 52c37c3 [Xiangrui Meng] address comments 65f892d [Xiangrui Meng] detect negative values in nb
Diffstat (limited to 'mllib')
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala28
-rw-r--r--mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala28
2 files changed, 51 insertions, 5 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
index 6c7be0a4f1..8c8e4a161a 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
@@ -19,9 +19,9 @@ package org.apache.spark.mllib.classification
import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, argmax => brzArgmax, sum => brzSum}
-import org.apache.spark.Logging
+import org.apache.spark.{SparkException, Logging}
import org.apache.spark.SparkContext._
-import org.apache.spark.mllib.linalg.Vector
+import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.rdd.RDD
@@ -73,7 +73,7 @@ class NaiveBayesModel private[mllib] (
* This is the Multinomial NB ([[http://tinyurl.com/lsdw6p]]) which can handle all kinds of
* discrete data. For example, by converting documents into TF-IDF vectors, it can be used for
* document classification. By making every vector a 0-1 vector, it can also be used as
- * Bernoulli NB ([[http://tinyurl.com/p7c96j6]]).
+ * Bernoulli NB ([[http://tinyurl.com/p7c96j6]]). The input feature values must be nonnegative.
*/
class NaiveBayes private (private var lambda: Double) extends Serializable with Logging {
@@ -91,12 +91,30 @@ class NaiveBayes private (private var lambda: Double) extends Serializable with
* @param data RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
*/
def run(data: RDD[LabeledPoint]) = {
+ val requireNonnegativeValues: Vector => Unit = (v: Vector) => {
+ val values = v match {
+ case sv: SparseVector =>
+ sv.values
+ case dv: DenseVector =>
+ dv.values
+ }
+ if (!values.forall(_ >= 0.0)) {
+ throw new SparkException(s"Naive Bayes requires nonnegative feature values but found $v.")
+ }
+ }
+
// Aggregates term frequencies per label.
// TODO: Calling combineByKey and collect creates two stages, we can implement something
// TODO: similar to reduceByKeyLocally to save one stage.
val aggregated = data.map(p => (p.label, p.features)).combineByKey[(Long, BDV[Double])](
- createCombiner = (v: Vector) => (1L, v.toBreeze.toDenseVector),
- mergeValue = (c: (Long, BDV[Double]), v: Vector) => (c._1 + 1L, c._2 += v.toBreeze),
+ createCombiner = (v: Vector) => {
+ requireNonnegativeValues(v)
+ (1L, v.toBreeze.toDenseVector)
+ },
+ mergeValue = (c: (Long, BDV[Double]), v: Vector) => {
+ requireNonnegativeValues(v)
+ (c._1 + 1L, c._2 += v.toBreeze)
+ },
mergeCombiners = (c1: (Long, BDV[Double]), c2: (Long, BDV[Double])) =>
(c1._1 + c2._1, c1._2 += c2._2)
).collect()
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala
index 06cdd04f5f..80989bc074 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala
@@ -21,6 +21,7 @@ import scala.util.Random
import org.scalatest.FunSuite
+import org.apache.spark.SparkException
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.util.{LocalClusterSparkContext, LocalSparkContext}
@@ -95,6 +96,33 @@ class NaiveBayesSuite extends FunSuite with LocalSparkContext {
// Test prediction on Array.
validatePrediction(validationData.map(row => model.predict(row.features)), validationData)
}
+
+ test("detect negative values") {
+ val dense = Seq(
+ LabeledPoint(1.0, Vectors.dense(1.0)),
+ LabeledPoint(0.0, Vectors.dense(-1.0)),
+ LabeledPoint(1.0, Vectors.dense(1.0)),
+ LabeledPoint(1.0, Vectors.dense(0.0)))
+ intercept[SparkException] {
+ NaiveBayes.train(sc.makeRDD(dense, 2))
+ }
+ val sparse = Seq(
+ LabeledPoint(1.0, Vectors.sparse(1, Array(0), Array(1.0))),
+ LabeledPoint(0.0, Vectors.sparse(1, Array(0), Array(-1.0))),
+ LabeledPoint(1.0, Vectors.sparse(1, Array(0), Array(1.0))),
+ LabeledPoint(1.0, Vectors.sparse(1, Array.empty, Array.empty)))
+ intercept[SparkException] {
+ NaiveBayes.train(sc.makeRDD(sparse, 2))
+ }
+ val nan = Seq(
+ LabeledPoint(1.0, Vectors.sparse(1, Array(0), Array(1.0))),
+ LabeledPoint(0.0, Vectors.sparse(1, Array(0), Array(Double.NaN))),
+ LabeledPoint(1.0, Vectors.sparse(1, Array(0), Array(1.0))),
+ LabeledPoint(1.0, Vectors.sparse(1, Array.empty, Array.empty)))
+ intercept[SparkException] {
+ NaiveBayes.train(sc.makeRDD(nan, 2))
+ }
+ }
}
class NaiveBayesClusterSuite extends FunSuite with LocalClusterSparkContext {