aboutsummaryrefslogtreecommitdiff
path: root/mllib/src
diff options
context:
space:
mode:
authorZheng RuiFeng <ruifengz@foxmail.com>2017-01-04 11:54:13 +0000
committerSean Owen <sowen@cloudera.com>2017-01-04 11:54:13 +0000
commit7a82505817d479007adff6424473063d2003fcc1 (patch)
tree4e7e9014d862f741deb53c8e3e101c95246e3bc1 /mllib/src
parent101556d0fa704deca0f4a2e5070906d4af2c861b (diff)
downloadspark-7a82505817d479007adff6424473063d2003fcc1.tar.gz
spark-7a82505817d479007adff6424473063d2003fcc1.tar.bz2
spark-7a82505817d479007adff6424473063d2003fcc1.zip
[SPARK-19054][ML] Eliminate extra pass in NB
## What changes were proposed in this pull request? eliminate unnecessary extra pass in NB's train ## How was this patch tested? existing tests Author: Zheng RuiFeng <ruifengz@foxmail.com> Closes #16453 from zhengruifeng/nb_getNC.
Diffstat (limited to 'mllib/src')
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala10
1 files changed, 4 insertions, 6 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
index 94ee2a2e7d..e90040dbf1 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
@@ -127,13 +127,11 @@ class NaiveBayes @Since("1.5.0") (
private[spark] def trainWithLabelCheck(
dataset: Dataset[_],
positiveLabel: Boolean): NaiveBayesModel = {
- if (positiveLabel) {
+ if (positiveLabel && isDefined(thresholds)) {
val numClasses = getNumClasses(dataset)
- if (isDefined(thresholds)) {
- require($(thresholds).length == numClasses, this.getClass.getSimpleName +
- ".train() called with non-matching numClasses and thresholds.length." +
- s" numClasses=$numClasses, but thresholds has length ${$(thresholds).length}")
- }
+ require($(thresholds).length == numClasses, this.getClass.getSimpleName +
+ ".train() called with non-matching numClasses and thresholds.length." +
+ s" numClasses=$numClasses, but thresholds has length ${$(thresholds).length}")
}
val modelTypeValue = $(modelType)