diff options
author | Xiangrui Meng <meng@databricks.com> | 2015-05-13 21:27:17 -0700 |
---|---|---|
committer | Xiangrui Meng <meng@databricks.com> | 2015-05-13 21:27:39 -0700 |
commit | 82f387fe23d3f5477df5d1be9a47d6df63fcbcf6 (patch) | |
tree | eb78a4e556ceb1dbafe3f69ad4ab5d7f731ae9e1 /mllib | |
parent | 2d4a961f82ccea3d3fc6d21fae1fc3a52e338634 (diff) | |
download | spark-82f387fe23d3f5477df5d1be9a47d6df63fcbcf6.tar.gz spark-82f387fe23d3f5477df5d1be9a47d6df63fcbcf6.tar.bz2 spark-82f387fe23d3f5477df5d1be9a47d6df63fcbcf6.zip |
[SPARK-7612] [MLLIB] update NB training to use mllib's BLAS
This is similar to the changes to k-means, which gives us better control on the performance. dbtsai
Author: Xiangrui Meng <meng@databricks.com>
Closes #6128 from mengxr/SPARK-7612 and squashes the following commits:
b5c24c5 [Xiangrui Meng] merge master
a90e3ec [Xiangrui Meng] update NB training to use mllib's BLAS
(cherry picked from commit d5f18de1657bfabf5493011e0b2c7ec29c02c64c)
Signed-off-by: Xiangrui Meng <meng@databricks.com>
Diffstat (limited to 'mllib')
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala | 43 |
1 file changed, 20 insertions, 23 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala index b381dc2cb0..af24ab6166 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala @@ -21,15 +21,13 @@ import java.lang.{Iterable => JIterable} import scala.collection.JavaConverters._ -import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, argmax => brzArgmax, sum => brzSum, Axis} +import breeze.linalg.{Axis, DenseMatrix => BDM, DenseVector => BDV, argmax => brzArgmax, sum => brzSum} import breeze.numerics.{exp => brzExp, log => brzLog} - import org.json4s.JsonDSL._ import org.json4s.jackson.JsonMethods._ -import org.json4s.{DefaultFormats, JValue} import org.apache.spark.{Logging, SparkContext, SparkException} -import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector} +import org.apache.spark.mllib.linalg.{BLAS, DenseVector, SparseVector, Vector} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.util.{Loader, Saveable} import org.apache.spark.rdd.RDD @@ -90,13 +88,13 @@ class NaiveBayesModel private[mllib] ( val brzData = testData.toBreeze modelType match { case "Multinomial" => - labels (brzArgmax (brzPi + brzTheta * brzData) ) + labels(brzArgmax(brzPi + brzTheta * brzData)) case "Bernoulli" => if (!brzData.forall(v => v == 0.0 || v == 1.0)) { throw new SparkException( s"Bernoulli Naive Bayes requires 0 or 1 feature values but found $testData.") } - labels (brzArgmax (brzPi + + labels(brzArgmax(brzPi + (brzTheta - brzNegTheta.get) * brzData + brzNegThetaSum.get)) case _ => // This should never happen. @@ -152,7 +150,7 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] { // Check schema explicitly since erasure makes it hard to use match-case for checking. 
checkSchema[Data](dataRDD.schema) val dataArray = dataRDD.select("labels", "pi", "theta", "modelType").take(1) - assert(dataArray.size == 1, s"Unable to load NaiveBayesModel data from: ${dataPath(path)}") + assert(dataArray.length == 1, s"Unable to load NaiveBayesModel data from: ${dataPath(path)}") val data = dataArray(0) val labels = data.getAs[Seq[Double]](0).toArray val pi = data.getAs[Seq[Double]](1).toArray @@ -198,7 +196,7 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] { // Check schema explicitly since erasure makes it hard to use match-case for checking. checkSchema[Data](dataRDD.schema) val dataArray = dataRDD.select("labels", "pi", "theta").take(1) - assert(dataArray.size == 1, s"Unable to load NaiveBayesModel data from: ${dataPath(path)}") + assert(dataArray.length == 1, s"Unable to load NaiveBayesModel data from: ${dataPath(path)}") val data = dataArray(0) val labels = data.getAs[Seq[Double]](0).toArray val pi = data.getAs[Seq[Double]](1).toArray @@ -288,10 +286,8 @@ class NaiveBayes private ( def run(data: RDD[LabeledPoint]): NaiveBayesModel = { val requireNonnegativeValues: Vector => Unit = (v: Vector) => { val values = v match { - case SparseVector(size, indices, values) => - values - case DenseVector(values) => - values + case sv: SparseVector => sv.values + case dv: DenseVector => dv.values } if (!values.forall(_ >= 0.0)) { throw new SparkException(s"Naive Bayes requires nonnegative feature values but found $v.") @@ -300,10 +296,8 @@ class NaiveBayes private ( val requireZeroOneBernoulliValues: Vector => Unit = (v: Vector) => { val values = v match { - case SparseVector(size, indices, values) => - values - case DenseVector(values) => - values + case sv: SparseVector => sv.values + case dv: DenseVector => dv.values } if (!values.forall(v => v == 0.0 || v == 1.0)) { throw new SparkException( @@ -314,21 +308,24 @@ class NaiveBayes private ( // Aggregates term frequencies per label. 
// TODO: Calling combineByKey and collect creates two stages, we can implement something // TODO: similar to reduceByKeyLocally to save one stage. - val aggregated = data.map(p => (p.label, p.features)).combineByKey[(Long, BDV[Double])]( + val aggregated = data.map(p => (p.label, p.features)).combineByKey[(Long, DenseVector)]( createCombiner = (v: Vector) => { if (modelType == "Bernoulli") { requireZeroOneBernoulliValues(v) } else { requireNonnegativeValues(v) } - (1L, v.toBreeze.toDenseVector) + (1L, v.copy.toDense) }, - mergeValue = (c: (Long, BDV[Double]), v: Vector) => { + mergeValue = (c: (Long, DenseVector), v: Vector) => { requireNonnegativeValues(v) - (c._1 + 1L, c._2 += v.toBreeze) + BLAS.axpy(1.0, v, c._2) + (c._1 + 1L, c._2) }, - mergeCombiners = (c1: (Long, BDV[Double]), c2: (Long, BDV[Double])) => - (c1._1 + c2._1, c1._2 += c2._2) + mergeCombiners = (c1: (Long, DenseVector), c2: (Long, DenseVector)) => { + BLAS.axpy(1.0, c2._2, c1._2) + (c1._1 + c2._1, c1._2) + } ).collect() val numLabels = aggregated.length @@ -348,7 +345,7 @@ class NaiveBayes private ( labels(i) = label pi(i) = math.log(n + lambda) - piLogDenom val thetaLogDenom = modelType match { - case "Multinomial" => math.log(brzSum(sumTermFreqs) + numFeatures * lambda) + case "Multinomial" => math.log(sumTermFreqs.values.sum + numFeatures * lambda) case "Bernoulli" => math.log(n + 2.0 * lambda) case _ => // This should never happen. |