diff options
author | Yanbo Liang <ybliang8@gmail.com> | 2016-10-12 19:56:40 -0700 |
---|---|---|
committer | Yanbo Liang <ybliang8@gmail.com> | 2016-10-12 19:56:40 -0700 |
commit | 21cb59f1cd137d96b2596f1abe691b544581cf59 (patch) | |
tree | 308c713889af9e4333df8f5b3e0be134bda20978 /mllib/src/main/scala/org | |
parent | 0d4a695279c514c76aa0e9288c70ac7aaef91b03 (diff) | |
download | spark-21cb59f1cd137d96b2596f1abe691b544581cf59.tar.gz spark-21cb59f1cd137d96b2596f1abe691b544581cf59.tar.bz2 spark-21cb59f1cd137d96b2596f1abe691b544581cf59.zip |
[SPARK-17835][ML][MLLIB] Optimize NaiveBayes mllib wrapper to eliminate extra pass on data
## What changes were proposed in this pull request?
[SPARK-14077](https://issues.apache.org/jira/browse/SPARK-14077) copied the ```NaiveBayes``` implementation from mllib to ml and left mllib as a wrapper. However, there are some differences between mllib and ml in how they handle labels:
* mllib allows input labels such as {-1, +1}; however, ml assumes the input labels are in the range [0, numClasses).
* mllib ```NaiveBayesModel``` exposes ```labels``` but ml does not, due to the assumption mentioned above.
During the copy in [SPARK-14077](https://issues.apache.org/jira/browse/SPARK-14077), we use
```val labels = data.map(_.label).distinct().collect().sorted```
to get the distinct labels first, and then encode the labels for training. This involves an extra Spark job compared with the original implementation. Since ```NaiveBayes``` only does one pass of aggregation during training, adding another pass seems inefficient. We can get the labels in a single pass along with the ```NaiveBayes``` training and send them to the MLlib side.
## How was this patch tested?
Existing tests.
Author: Yanbo Liang <ybliang8@gmail.com>
Closes #15402 from yanboliang/spark-17835.
Diffstat (limited to 'mllib/src/main/scala/org')
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala | 46 | ||||
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala | 15 |
2 files changed, 43 insertions, 18 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala index e565a6fd3e..994ed993c9 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala @@ -110,16 +110,28 @@ class NaiveBayes @Since("1.5.0") ( @Since("2.1.0") def setWeightCol(value: String): this.type = set(weightCol, value) - override protected def train(dataset: Dataset[_]): NaiveBayesModel = { - val numClasses = getNumClasses(dataset) + /** + * ml assumes input labels in range [0, numClasses). But this implementation + * is also called by mllib NaiveBayes which allows other kinds of input labels + * such as {-1, +1}. Here we use this parameter to switch between different processing logic. + * It should be removed when we remove mllib NaiveBayes. + */ + private[spark] var isML: Boolean = true - if (isDefined(thresholds)) { - require($(thresholds).length == numClasses, this.getClass.getSimpleName + - ".train() called with non-matching numClasses and thresholds.length." + - s" numClasses=$numClasses, but thresholds has length ${$(thresholds).length}") - } + private[spark] def setIsML(isML: Boolean): this.type = { + this.isML = isML + this + } - val numFeatures = dataset.select(col($(featuresCol))).head().getAs[Vector](0).size + override protected def train(dataset: Dataset[_]): NaiveBayesModel = { + if (isML) { + val numClasses = getNumClasses(dataset) + if (isDefined(thresholds)) { + require($(thresholds).length == numClasses, this.getClass.getSimpleName + + ".train() called with non-matching numClasses and thresholds.length." 
+ + s" numClasses=$numClasses, but thresholds has length ${$(thresholds).length}") + } + } val requireNonnegativeValues: Vector => Unit = (v: Vector) => { val values = v match { @@ -153,6 +165,7 @@ class NaiveBayes @Since("1.5.0") ( } } + val numFeatures = dataset.select(col($(featuresCol))).head().getAs[Vector](0).size val w = if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol)) // Aggregates term frequencies per label. @@ -176,6 +189,7 @@ class NaiveBayes @Since("1.5.0") ( val numLabels = aggregated.length val numDocuments = aggregated.map(_._2._1).sum + val labelArray = new Array[Double](numLabels) val piArray = new Array[Double](numLabels) val thetaArray = new Array[Double](numLabels * numFeatures) @@ -183,6 +197,7 @@ class NaiveBayes @Since("1.5.0") ( val piLogDenom = math.log(numDocuments + numLabels * lambda) var i = 0 aggregated.foreach { case (label, (n, sumTermFreqs)) => + labelArray(i) = label piArray(i) = math.log(n + lambda) - piLogDenom val thetaLogDenom = $(modelType) match { case Multinomial => math.log(sumTermFreqs.values.sum + numFeatures * lambda) @@ -201,7 +216,7 @@ class NaiveBayes @Since("1.5.0") ( val pi = Vectors.dense(piArray) val theta = new DenseMatrix(numLabels, numFeatures, thetaArray, true) - new NaiveBayesModel(uid, pi, theta) + new NaiveBayesModel(uid, pi, theta).setOldLabels(labelArray) } @Since("1.5.0") @@ -240,6 +255,19 @@ class NaiveBayesModel private[ml] ( import NaiveBayes.{Bernoulli, Multinomial} /** + * mllib NaiveBayes is a wrapper of ml implementation currently. + * Input labels of mllib could be {-1, +1} and mllib NaiveBayesModel exposes labels, + * both of which are different from ml, so we should store the labels sequentially + * to be called by mllib. This should be removed when we remove mllib NaiveBayes. 
+ */ + private[spark] var oldLabels: Array[Double] = null + + private[spark] def setOldLabels(labels: Array[Double]): this.type = { + this.oldLabels = labels + this + } + + /** * Bernoulli scoring requires log(condprob) if 1, log(1-condprob) if 0. * This precomputes log(1.0 - exp(theta)) and its sum which are used for the linear algebra * application of this condition (in predict function). diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala index 32d6968a4e..33561be4b5 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala @@ -364,15 +364,10 @@ class NaiveBayes private ( val nb = new NewNaiveBayes() .setModelType(modelType) .setSmoothing(lambda) + .setIsML(false) - val labels = data.map(_.label).distinct().collect().sorted - - // Input labels for [[org.apache.spark.ml.classification.NaiveBayes]] must be - // in range [0, numClasses). - val dataset = data.map { - case LabeledPoint(label, features) => - (labels.indexOf(label).toDouble, features.asML) - }.toDF("label", "features") + val dataset = data.map { case LabeledPoint(label, features) => (label, features.asML) } + .toDF("label", "features") val newModel = nb.fit(dataset) @@ -383,7 +378,9 @@ class NaiveBayes private ( theta(i)(j) = v } - new NaiveBayesModel(labels, pi, theta, modelType) + require(newModel.oldLabels != null, + "The underlying ML NaiveBayes training does not produce labels.") + new NaiveBayesModel(newModel.oldLabels, pi, theta, modelType) } } |