author     Xiangrui Meng <meng@databricks.com>    2015-05-13 21:27:17 -0700
committer  Xiangrui Meng <meng@databricks.com>    2015-05-13 21:27:17 -0700
commit     d5f18de1657bfabf5493011e0b2c7ec29c02c64c (patch)
tree       a7686906e3598971003c674371ef230e197f2732
parent     3113da9c7067bbf90639866ae9d946f02cc484ff (diff)
[SPARK-7612] [MLLIB] update NB training to use mllib's BLAS
This is similar to the changes to k-means, which give us better control over performance. dbtsai

Author: Xiangrui Meng <meng@databricks.com>

Closes #6128 from mengxr/SPARK-7612 and squashes the following commits:

b5c24c5 [Xiangrui Meng] merge master
a90e3ec [Xiangrui Meng] update NB training to use mllib's BLAS
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala | 43
1 file changed, 20 insertions(+), 23 deletions(-)
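For context, the heart of the patch below replaces Breeze-based accumulation (c._2 += v.toBreeze) with in-place BLAS.axpy on mllib's own DenseVector. The standalone sketch that follows shows the same per-label (count, feature-sum) aggregation pattern; aggregateByLabel is a hypothetical helper name, and since BLAS is declared private[spark], this only compiles inside Spark's own packages. User code should go through NaiveBayes.train instead.

import org.apache.spark.mllib.linalg.{BLAS, DenseVector, Vector}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.rdd.RDD

// Per-label (count, featureSum) aggregation using mllib's BLAS, mirroring the
// combineByKey calls in this patch. axpy(1.0, x, y) computes y += x in place,
// so after the initial v.copy.toDense no further copies of the running sum
// are made.
def aggregateByLabel(data: RDD[LabeledPoint]): Array[(Double, (Long, DenseVector))] = {
  data.map(p => (p.label, p.features)).combineByKey[(Long, DenseVector)](
    createCombiner = (v: Vector) => (1L, v.copy.toDense), // copy: the combiner mutates it
    mergeValue = (c: (Long, DenseVector), v: Vector) => {
      BLAS.axpy(1.0, v, c._2)
      (c._1 + 1L, c._2)
    },
    mergeCombiners = (c1: (Long, DenseVector), c2: (Long, DenseVector)) => {
      BLAS.axpy(1.0, c2._2, c1._2)
      (c1._1 + c2._1, c1._2)
    }
  ).collect()
}

Compared with converting every record to a Breeze vector, this keeps each update a single axpy on the native mllib types, which is where the performance control mentioned in the commit message comes from.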
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
index b381dc2cb0..af24ab6166 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
@@ -21,15 +21,13 @@ import java.lang.{Iterable => JIterable}
import scala.collection.JavaConverters._
-import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, argmax => brzArgmax, sum => brzSum, Axis}
+import breeze.linalg.{Axis, DenseMatrix => BDM, DenseVector => BDV, argmax => brzArgmax, sum => brzSum}
import breeze.numerics.{exp => brzExp, log => brzLog}
-
import org.json4s.JsonDSL._
import org.json4s.jackson.JsonMethods._
-import org.json4s.{DefaultFormats, JValue}
import org.apache.spark.{Logging, SparkContext, SparkException}
-import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector}
+import org.apache.spark.mllib.linalg.{BLAS, DenseVector, SparseVector, Vector}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.util.{Loader, Saveable}
import org.apache.spark.rdd.RDD
@@ -90,13 +88,13 @@ class NaiveBayesModel private[mllib] (
val brzData = testData.toBreeze
modelType match {
case "Multinomial" =>
- labels (brzArgmax (brzPi + brzTheta * brzData) )
+ labels(brzArgmax(brzPi + brzTheta * brzData))
case "Bernoulli" =>
if (!brzData.forall(v => v == 0.0 || v == 1.0)) {
throw new SparkException(
s"Bernoulli Naive Bayes requires 0 or 1 feature values but found $testData.")
}
- labels (brzArgmax (brzPi +
+ labels(brzArgmax(brzPi +
(brzTheta - brzNegTheta.get) * brzData + brzNegThetaSum.get))
case _ =>
// This should never happen.
@@ -152,7 +150,7 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] {
// Check schema explicitly since erasure makes it hard to use match-case for checking.
checkSchema[Data](dataRDD.schema)
val dataArray = dataRDD.select("labels", "pi", "theta", "modelType").take(1)
- assert(dataArray.size == 1, s"Unable to load NaiveBayesModel data from: ${dataPath(path)}")
+ assert(dataArray.length == 1, s"Unable to load NaiveBayesModel data from: ${dataPath(path)}")
val data = dataArray(0)
val labels = data.getAs[Seq[Double]](0).toArray
val pi = data.getAs[Seq[Double]](1).toArray
@@ -198,7 +196,7 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] {
// Check schema explicitly since erasure makes it hard to use match-case for checking.
checkSchema[Data](dataRDD.schema)
val dataArray = dataRDD.select("labels", "pi", "theta").take(1)
- assert(dataArray.size == 1, s"Unable to load NaiveBayesModel data from: ${dataPath(path)}")
+ assert(dataArray.length == 1, s"Unable to load NaiveBayesModel data from: ${dataPath(path)}")
val data = dataArray(0)
val labels = data.getAs[Seq[Double]](0).toArray
val pi = data.getAs[Seq[Double]](1).toArray
@@ -288,10 +286,8 @@ class NaiveBayes private (
def run(data: RDD[LabeledPoint]): NaiveBayesModel = {
val requireNonnegativeValues: Vector => Unit = (v: Vector) => {
val values = v match {
- case SparseVector(size, indices, values) =>
- values
- case DenseVector(values) =>
- values
+ case sv: SparseVector => sv.values
+ case dv: DenseVector => dv.values
}
if (!values.forall(_ >= 0.0)) {
throw new SparkException(s"Naive Bayes requires nonnegative feature values but found $v.")
@@ -300,10 +296,8 @@ class NaiveBayes private (
val requireZeroOneBernoulliValues: Vector => Unit = (v: Vector) => {
val values = v match {
- case SparseVector(size, indices, values) =>
- values
- case DenseVector(values) =>
- values
+ case sv: SparseVector => sv.values
+ case dv: DenseVector => dv.values
}
if (!values.forall(v => v == 0.0 || v == 1.0)) {
throw new SparkException(
@@ -314,21 +308,24 @@ class NaiveBayes private (
// Aggregates term frequencies per label.
// TODO: Calling combineByKey and collect creates two stages, we can implement something
// TODO: similar to reduceByKeyLocally to save one stage.
- val aggregated = data.map(p => (p.label, p.features)).combineByKey[(Long, BDV[Double])](
+ val aggregated = data.map(p => (p.label, p.features)).combineByKey[(Long, DenseVector)](
createCombiner = (v: Vector) => {
if (modelType == "Bernoulli") {
requireZeroOneBernoulliValues(v)
} else {
requireNonnegativeValues(v)
}
- (1L, v.toBreeze.toDenseVector)
+ (1L, v.copy.toDense)
},
- mergeValue = (c: (Long, BDV[Double]), v: Vector) => {
+ mergeValue = (c: (Long, DenseVector), v: Vector) => {
requireNonnegativeValues(v)
- (c._1 + 1L, c._2 += v.toBreeze)
+ BLAS.axpy(1.0, v, c._2)
+ (c._1 + 1L, c._2)
},
- mergeCombiners = (c1: (Long, BDV[Double]), c2: (Long, BDV[Double])) =>
- (c1._1 + c2._1, c1._2 += c2._2)
+ mergeCombiners = (c1: (Long, DenseVector), c2: (Long, DenseVector)) => {
+ BLAS.axpy(1.0, c2._2, c1._2)
+ (c1._1 + c2._1, c1._2)
+ }
).collect()
val numLabels = aggregated.length
@@ -348,7 +345,7 @@ class NaiveBayes private (
labels(i) = label
pi(i) = math.log(n + lambda) - piLogDenom
val thetaLogDenom = modelType match {
- case "Multinomial" => math.log(brzSum(sumTermFreqs) + numFeatures * lambda)
+ case "Multinomial" => math.log(sumTermFreqs.values.sum + numFeatures * lambda)
case "Bernoulli" => math.log(n + 2.0 * lambda)
case _ =>
// This should never happen.
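The public entry point is unchanged by this patch. For reference, a minimal usage sketch that exercises the new aggregation internally (assumes an existing SparkContext named sc; the three-point dataset is made up):

import org.apache.spark.mllib.classification.NaiveBayes
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint

// Tiny hypothetical training set: label 0.0 concentrates on feature 0,
// label 1.0 on features 1 and 2.
val training = sc.parallelize(Seq(
  LabeledPoint(0.0, Vectors.dense(1.0, 0.0, 0.0)),
  LabeledPoint(1.0, Vectors.dense(0.0, 2.0, 0.0)),
  LabeledPoint(1.0, Vectors.dense(0.0, 1.0, 1.0))
))

// train() runs the combineByKey aggregation shown in the diff above.
val model = NaiveBayes.train(training, lambda = 1.0)
model.predict(Vectors.dense(0.0, 1.0, 0.0)) // expected: 1.0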