diff options
author | Zheng RuiFeng <ruifengz@foxmail.com> | 2016-08-02 07:22:41 -0700 |
---|---|---|
committer | Yanbo Liang <ybliang8@gmail.com> | 2016-08-02 07:22:41 -0700 |
commit | d9e0919d30e9f79a0eb1ceb8d1b5e9fc58cf085e (patch) | |
tree | 0a273828eeacf49d1c9fc0d350af080fcfd7c98c /mllib/src | |
parent | a1ff72e1cce6f22249ccc4905e8cef30075beb2f (diff) | |
download | spark-d9e0919d30e9f79a0eb1ceb8d1b5e9fc58cf085e.tar.gz spark-d9e0919d30e9f79a0eb1ceb8d1b5e9fc58cf085e.tar.bz2 spark-d9e0919d30e9f79a0eb1ceb8d1b5e9fc58cf085e.zip |
[SPARK-16851][ML] Incorrect threshould length in 'setThresholds()' evoke Exception
## What changes were proposed in this pull request?
Add a length checking for threshoulds' length in method `setThreshoulds()` of classification models.
## How was this patch tested?
unit tests
Author: Zheng RuiFeng <ruifengz@foxmail.com>
Closes #14457 from zhengruifeng/check_setThresholds.
Diffstat (limited to 'mllib/src')
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala | 7 |
1 files changed, 6 insertions, 1 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala index 88642abf63..19df8f7edd 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala @@ -83,7 +83,12 @@ abstract class ProbabilisticClassificationModel[ def setProbabilityCol(value: String): M = set(probabilityCol, value).asInstanceOf[M] /** @group setParam */ - def setThresholds(value: Array[Double]): M = set(thresholds, value).asInstanceOf[M] + def setThresholds(value: Array[Double]): M = { + require(value.length == numClasses, this.getClass.getSimpleName + + ".setThresholds() called with non-matching numClasses and thresholds.length." + + s" numClasses=$numClasses, but thresholds has length ${value.length}") + set(thresholds, value).asInstanceOf[M] + } /** * Transforms dataset by reading from [[featuresCol]], and appending new columns as specified by |