aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
Diffstat (limited to 'mllib')
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala7
1 files changed, 6 insertions, 1 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala
index 88642abf63..19df8f7edd 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala
@@ -83,7 +83,12 @@ abstract class ProbabilisticClassificationModel[
def setProbabilityCol(value: String): M = set(probabilityCol, value).asInstanceOf[M]
/** @group setParam */
- def setThresholds(value: Array[Double]): M = set(thresholds, value).asInstanceOf[M]
+ def setThresholds(value: Array[Double]): M = {
+ require(value.length == numClasses, this.getClass.getSimpleName +
+ ".setThresholds() called with non-matching numClasses and thresholds.length." +
+ s" numClasses=$numClasses, but thresholds has length ${value.length}")
+ set(thresholds, value).asInstanceOf[M]
+ }
/**
* Transforms dataset by reading from [[featuresCol]], and appending new columns as specified by