aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
authorZheng RuiFeng <ruifengz@foxmail.com>2016-08-02 07:22:41 -0700
committerYanbo Liang <ybliang8@gmail.com>2016-08-02 07:22:41 -0700
commitd9e0919d30e9f79a0eb1ceb8d1b5e9fc58cf085e (patch)
tree0a273828eeacf49d1c9fc0d350af080fcfd7c98c /mllib
parenta1ff72e1cce6f22249ccc4905e8cef30075beb2f (diff)
downloadspark-d9e0919d30e9f79a0eb1ceb8d1b5e9fc58cf085e.tar.gz
spark-d9e0919d30e9f79a0eb1ceb8d1b5e9fc58cf085e.tar.bz2
spark-d9e0919d30e9f79a0eb1ceb8d1b5e9fc58cf085e.zip
[SPARK-16851][ML] Incorrect threshould length in 'setThresholds()' evoke Exception
## What changes were proposed in this pull request? Add a length checking for threshoulds' length in method `setThreshoulds()` of classification models. ## How was this patch tested? unit tests Author: Zheng RuiFeng <ruifengz@foxmail.com> Closes #14457 from zhengruifeng/check_setThresholds.
Diffstat (limited to 'mllib')
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala7
1 files changed, 6 insertions, 1 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala
index 88642abf63..19df8f7edd 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala
@@ -83,7 +83,12 @@ abstract class ProbabilisticClassificationModel[
def setProbabilityCol(value: String): M = set(probabilityCol, value).asInstanceOf[M]
/** @group setParam */
- def setThresholds(value: Array[Double]): M = set(thresholds, value).asInstanceOf[M]
+ def setThresholds(value: Array[Double]): M = {
+ require(value.length == numClasses, this.getClass.getSimpleName +
+ ".setThresholds() called with non-matching numClasses and thresholds.length." +
+ s" numClasses=$numClasses, but thresholds has length ${value.length}")
+ set(thresholds, value).asInstanceOf[M]
+ }
/**
* Transforms dataset by reading from [[featuresCol]], and appending new columns as specified by