aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/main
diff options
context:
space:
mode:
authorYanbo Liang <ybliang8@gmail.com>2016-10-12 19:56:40 -0700
committerYanbo Liang <ybliang8@gmail.com>2016-10-12 19:56:40 -0700
commit21cb59f1cd137d96b2596f1abe691b544581cf59 (patch)
tree308c713889af9e4333df8f5b3e0be134bda20978 /mllib/src/main
parent0d4a695279c514c76aa0e9288c70ac7aaef91b03 (diff)
downloadspark-21cb59f1cd137d96b2596f1abe691b544581cf59.tar.gz
spark-21cb59f1cd137d96b2596f1abe691b544581cf59.tar.bz2
spark-21cb59f1cd137d96b2596f1abe691b544581cf59.zip
[SPARK-17835][ML][MLLIB] Optimize NaiveBayes mllib wrapper to eliminate extra pass on data
## What changes were proposed in this pull request? [SPARK-14077](https://issues.apache.org/jira/browse/SPARK-14077) copied the ```NaiveBayes``` implementation from mllib to ml and left mllib as a wrapper. However, there are some differences between mllib and ml in how they handle labels: * mllib allows input labels as {-1, +1}; however, ml assumes the input labels are in the range [0, numClasses). * mllib ```NaiveBayesModel``` exposes ```labels```, but ml does not, due to the assumption mentioned above. During the copy in [SPARK-14077](https://issues.apache.org/jira/browse/SPARK-14077), we used ```val labels = data.map(_.label).distinct().collect().sorted``` to get the distinct labels first, and then encoded the labels for training. This involves an extra Spark job compared with the original implementation. Since ```NaiveBayes``` only does one-pass aggregation during training, adding another pass seems less efficient. We can collect the labels in a single pass along with the ```NaiveBayes``` training and send them to the mllib side. ## How was this patch tested? Existing tests. Author: Yanbo Liang <ybliang8@gmail.com> Closes #15402 from yanboliang/spark-17835.
Diffstat (limited to 'mllib/src/main')
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala46
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala15
2 files changed, 43 insertions, 18 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
index e565a6fd3e..994ed993c9 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
@@ -110,16 +110,28 @@ class NaiveBayes @Since("1.5.0") (
@Since("2.1.0")
def setWeightCol(value: String): this.type = set(weightCol, value)
- override protected def train(dataset: Dataset[_]): NaiveBayesModel = {
- val numClasses = getNumClasses(dataset)
+ /**
+ * ml assumes input labels in range [0, numClasses). But this implementation
+ * is also called by mllib NaiveBayes which allows other kinds of input labels
+ * such as {-1, +1}. Here we use this parameter to switch between different processing logic.
+ * It should be removed when we remove mllib NaiveBayes.
+ */
+ private[spark] var isML: Boolean = true
- if (isDefined(thresholds)) {
- require($(thresholds).length == numClasses, this.getClass.getSimpleName +
- ".train() called with non-matching numClasses and thresholds.length." +
- s" numClasses=$numClasses, but thresholds has length ${$(thresholds).length}")
- }
+ private[spark] def setIsML(isML: Boolean): this.type = {
+ this.isML = isML
+ this
+ }
- val numFeatures = dataset.select(col($(featuresCol))).head().getAs[Vector](0).size
+ override protected def train(dataset: Dataset[_]): NaiveBayesModel = {
+ if (isML) {
+ val numClasses = getNumClasses(dataset)
+ if (isDefined(thresholds)) {
+ require($(thresholds).length == numClasses, this.getClass.getSimpleName +
+ ".train() called with non-matching numClasses and thresholds.length." +
+ s" numClasses=$numClasses, but thresholds has length ${$(thresholds).length}")
+ }
+ }
val requireNonnegativeValues: Vector => Unit = (v: Vector) => {
val values = v match {
@@ -153,6 +165,7 @@ class NaiveBayes @Since("1.5.0") (
}
}
+ val numFeatures = dataset.select(col($(featuresCol))).head().getAs[Vector](0).size
val w = if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol))
// Aggregates term frequencies per label.
@@ -176,6 +189,7 @@ class NaiveBayes @Since("1.5.0") (
val numLabels = aggregated.length
val numDocuments = aggregated.map(_._2._1).sum
+ val labelArray = new Array[Double](numLabels)
val piArray = new Array[Double](numLabels)
val thetaArray = new Array[Double](numLabels * numFeatures)
@@ -183,6 +197,7 @@ class NaiveBayes @Since("1.5.0") (
val piLogDenom = math.log(numDocuments + numLabels * lambda)
var i = 0
aggregated.foreach { case (label, (n, sumTermFreqs)) =>
+ labelArray(i) = label
piArray(i) = math.log(n + lambda) - piLogDenom
val thetaLogDenom = $(modelType) match {
case Multinomial => math.log(sumTermFreqs.values.sum + numFeatures * lambda)
@@ -201,7 +216,7 @@ class NaiveBayes @Since("1.5.0") (
val pi = Vectors.dense(piArray)
val theta = new DenseMatrix(numLabels, numFeatures, thetaArray, true)
- new NaiveBayesModel(uid, pi, theta)
+ new NaiveBayesModel(uid, pi, theta).setOldLabels(labelArray)
}
@Since("1.5.0")
@@ -240,6 +255,19 @@ class NaiveBayesModel private[ml] (
import NaiveBayes.{Bernoulli, Multinomial}
/**
+ * mllib NaiveBayes is a wrapper of ml implementation currently.
+ * Input labels of mllib could be {-1, +1} and mllib NaiveBayesModel exposes labels,
+ * both of which are different from ml, so we should store the labels sequentially
+ * to be called by mllib. This should be removed when we remove mllib NaiveBayes.
+ */
+ private[spark] var oldLabels: Array[Double] = null
+
+ private[spark] def setOldLabels(labels: Array[Double]): this.type = {
+ this.oldLabels = labels
+ this
+ }
+
+ /**
* Bernoulli scoring requires log(condprob) if 1, log(1-condprob) if 0.
* This precomputes log(1.0 - exp(theta)) and its sum which are used for the linear algebra
* application of this condition (in predict function).
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
index 32d6968a4e..33561be4b5 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
@@ -364,15 +364,10 @@ class NaiveBayes private (
val nb = new NewNaiveBayes()
.setModelType(modelType)
.setSmoothing(lambda)
+ .setIsML(false)
- val labels = data.map(_.label).distinct().collect().sorted
-
- // Input labels for [[org.apache.spark.ml.classification.NaiveBayes]] must be
- // in range [0, numClasses).
- val dataset = data.map {
- case LabeledPoint(label, features) =>
- (labels.indexOf(label).toDouble, features.asML)
- }.toDF("label", "features")
+ val dataset = data.map { case LabeledPoint(label, features) => (label, features.asML) }
+ .toDF("label", "features")
val newModel = nb.fit(dataset)
@@ -383,7 +378,9 @@ class NaiveBayes private (
theta(i)(j) = v
}
- new NaiveBayesModel(labels, pi, theta, modelType)
+ require(newModel.oldLabels != null,
+ "The underlying ML NaiveBayes training does not produce labels.")
+ new NaiveBayesModel(newModel.oldLabels, pi, theta, modelType)
}
}