aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/main
diff options
context:
space:
mode:
authorXiangrui Meng <meng@databricks.com>2014-08-19 21:01:23 -0700
committerXiangrui Meng <meng@databricks.com>2014-08-19 21:01:23 -0700
commit068b6fe6a10eb1c6b2102d88832203267f030e85 (patch)
treeeb12c866970102b636d0edb80351bee0b6cb7b28 /mllib/src/main
parent0e3ab94d413fd70fff748fded42ab5e2ebd66fcc (diff)
downloadspark-068b6fe6a10eb1c6b2102d88832203267f030e85.tar.gz
spark-068b6fe6a10eb1c6b2102d88832203267f030e85.tar.bz2
spark-068b6fe6a10eb1c6b2102d88832203267f030e85.zip
[SPARK-3130][MLLIB] detect negative values in naive Bayes
because NB treats feature values as term frequencies. jkbradley Author: Xiangrui Meng <meng@databricks.com> Closes #2038 from mengxr/nb-neg and squashes the following commits: 52c37c3 [Xiangrui Meng] address comments 65f892d [Xiangrui Meng] detect negative values in nb
Diffstat (limited to 'mllib/src/main')
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala28
1 files changed, 23 insertions, 5 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
index 6c7be0a4f1..8c8e4a161a 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
@@ -19,9 +19,9 @@ package org.apache.spark.mllib.classification
import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, argmax => brzArgmax, sum => brzSum}
-import org.apache.spark.Logging
+import org.apache.spark.{SparkException, Logging}
import org.apache.spark.SparkContext._
-import org.apache.spark.mllib.linalg.Vector
+import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.rdd.RDD
@@ -73,7 +73,7 @@ class NaiveBayesModel private[mllib] (
* This is the Multinomial NB ([[http://tinyurl.com/lsdw6p]]) which can handle all kinds of
* discrete data. For example, by converting documents into TF-IDF vectors, it can be used for
* document classification. By making every vector a 0-1 vector, it can also be used as
- * Bernoulli NB ([[http://tinyurl.com/p7c96j6]]).
+ * Bernoulli NB ([[http://tinyurl.com/p7c96j6]]). The input feature values must be nonnegative.
*/
class NaiveBayes private (private var lambda: Double) extends Serializable with Logging {
@@ -91,12 +91,30 @@ class NaiveBayes private (private var lambda: Double) extends Serializable with
* @param data RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
*/
def run(data: RDD[LabeledPoint]) = {
+ val requireNonnegativeValues: Vector => Unit = (v: Vector) => {
+ val values = v match {
+ case sv: SparseVector =>
+ sv.values
+ case dv: DenseVector =>
+ dv.values
+ }
+ if (!values.forall(_ >= 0.0)) {
+ throw new SparkException(s"Naive Bayes requires nonnegative feature values but found $v.")
+ }
+ }
+
// Aggregates term frequencies per label.
// TODO: Calling combineByKey and collect creates two stages, we can implement something
// TODO: similar to reduceByKeyLocally to save one stage.
val aggregated = data.map(p => (p.label, p.features)).combineByKey[(Long, BDV[Double])](
- createCombiner = (v: Vector) => (1L, v.toBreeze.toDenseVector),
- mergeValue = (c: (Long, BDV[Double]), v: Vector) => (c._1 + 1L, c._2 += v.toBreeze),
+ createCombiner = (v: Vector) => {
+ requireNonnegativeValues(v)
+ (1L, v.toBreeze.toDenseVector)
+ },
+ mergeValue = (c: (Long, BDV[Double]), v: Vector) => {
+ requireNonnegativeValues(v)
+ (c._1 + 1L, c._2 += v.toBreeze)
+ },
mergeCombiners = (c1: (Long, BDV[Double]), c2: (Long, BDV[Double])) =>
(c1._1 + c2._1, c1._2 += c2._2)
).collect()