aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
authorLiang-Chi Hsieh <viirya@gmail.com>2015-05-19 13:53:08 -0700
committerXiangrui Meng <meng@databricks.com>2015-05-19 13:53:08 -0700
commitc12dff9b82e4869f866a9b96ce0bf05503dd7dda (patch)
tree18ed22cdb7287c076fd57e06866474de432941f9 /mllib
parent68fb2a46edc95f867d4b28597d20da2597f008c1 (diff)
downloadspark-c12dff9b82e4869f866a9b96ce0bf05503dd7dda.tar.gz
spark-c12dff9b82e4869f866a9b96ce0bf05503dd7dda.tar.bz2
spark-c12dff9b82e4869f866a9b96ce0bf05503dd7dda.zip
[SPARK-7652] [MLLIB] Update the implementation of naive Bayes prediction with BLAS
JIRA: https://issues.apache.org/jira/browse/SPARK-7652 Author: Liang-Chi Hsieh <viirya@gmail.com> Closes #6189 from viirya/naive_bayes_blas_prediction and squashes the following commits: ab611fd [Liang-Chi Hsieh] Remove unnecessary space. ddc48b9 [Liang-Chi Hsieh] Merge remote-tracking branch 'upstream/master' into naive_bayes_blas_prediction b5772b4 [Liang-Chi Hsieh] Fix binary compatibility. 2f65186 [Liang-Chi Hsieh] Remove toDense. 1b6cdfe [Liang-Chi Hsieh] Update the implementation of naive Bayes prediction with BLAS.
Diffstat (limited to 'mllib')
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala41
1 files changed, 24 insertions, 17 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
index ac0ebeceaa..53fb2cba03 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
@@ -21,13 +21,11 @@ import java.lang.{Iterable => JIterable}
import scala.collection.JavaConverters._
-import breeze.linalg.{Axis, DenseMatrix => BDM, DenseVector => BDV, argmax => brzArgmax, sum => brzSum}
-import breeze.numerics.{exp => brzExp, log => brzLog}
import org.json4s.JsonDSL._
import org.json4s.jackson.JsonMethods._
import org.apache.spark.{Logging, SparkContext, SparkException}
-import org.apache.spark.mllib.linalg.{BLAS, DenseVector, SparseVector, Vector}
+import org.apache.spark.mllib.linalg.{BLAS, DenseMatrix, DenseVector, SparseVector, Vector, Vectors}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.util.{Loader, Saveable}
import org.apache.spark.rdd.RDD
@@ -50,6 +48,9 @@ class NaiveBayesModel private[mllib] (
val modelType: String)
extends ClassificationModel with Serializable with Saveable {
+ private val piVector = new DenseVector(pi)
+ private val thetaMatrix = new DenseMatrix(labels.size, theta(0).size, theta.flatten, true)
+
private[mllib] def this(labels: Array[Double], pi: Array[Double], theta: Array[Array[Double]]) =
this(labels, pi, theta, "Multinomial")
@@ -60,17 +61,18 @@ class NaiveBayesModel private[mllib] (
theta: JIterable[JIterable[Double]]) =
this(labels.asScala.toArray, pi.asScala.toArray, theta.asScala.toArray.map(_.asScala.toArray))
- private val brzPi = new BDV[Double](pi)
- private val brzTheta = new BDM(theta(0).length, theta.length, theta.flatten).t
-
// Bernoulli scoring requires log(condprob) if 1, log(1-condprob) if 0.
- // This precomputes log(1.0 - exp(theta)) and its sum which are used for the linear algebra
+ // This precomputes log(1.0 - exp(theta)) and its sum which are used for the linear algebra
// application of this condition (in predict function).
- private val (brzNegTheta, brzNegThetaSum) = modelType match {
+ private val (thetaMinusNegTheta, negThetaSum) = modelType match {
case "Multinomial" => (None, None)
case "Bernoulli" =>
- val negTheta = brzLog((brzExp(brzTheta.copy) :*= (-1.0)) :+= 1.0) // log(1.0 - exp(x))
- (Option(negTheta), Option(brzSum(negTheta, Axis._1)))
+ val negTheta = thetaMatrix.map(value => math.log(1.0 - math.exp(value)))
+ val ones = new DenseVector(Array.fill(thetaMatrix.numCols){1.0})
+ val thetaMinusNegTheta = thetaMatrix.map { value =>
+ value - math.log(1.0 - math.exp(value))
+ }
+ (Option(thetaMinusNegTheta), Option(negTheta.multiply(ones)))
case _ =>
// This should never happen.
throw new UnknownError(s"NaiveBayesModel was created with an unknown ModelType: $modelType")
@@ -85,17 +87,22 @@ class NaiveBayesModel private[mllib] (
}
override def predict(testData: Vector): Double = {
- val brzData = testData.toBreeze
modelType match {
case "Multinomial" =>
- labels(brzArgmax(brzPi + brzTheta * brzData))
+ val prob = thetaMatrix.multiply(testData)
+ BLAS.axpy(1.0, piVector, prob)
+ labels(prob.argmax)
case "Bernoulli" =>
- if (!brzData.forall(v => v == 0.0 || v == 1.0)) {
- throw new SparkException(
- s"Bernoulli Naive Bayes requires 0 or 1 feature values but found $testData.")
+ testData.foreachActive { (index, value) =>
+ if (value != 0.0 && value != 1.0) {
+ throw new SparkException(
+ s"Bernoulli Naive Bayes requires 0 or 1 feature values but found $testData.")
+ }
}
- labels(brzArgmax(brzPi +
- (brzTheta - brzNegTheta.get) * brzData + brzNegThetaSum.get))
+ val prob = thetaMinusNegTheta.get.multiply(testData)
+ BLAS.axpy(1.0, piVector, prob)
+ BLAS.axpy(1.0, negThetaSum.get, prob)
+ labels(prob.argmax)
case _ =>
// This should never happen.
throw new UnknownError(s"NaiveBayesModel was created with an unknown ModelType: $modelType")