aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
authorZheng RuiFeng <ruifengz@foxmail.com>2016-08-04 21:44:54 +0100
committerSean Owen <sowen@cloudera.com>2016-08-04 21:44:54 +0100
commit0e2e5d7d0b42226c61c3200fd63d2831c558519d (patch)
treeda78e677d02b8968f38c5c7332ddf2caa3288dc3 /mllib
parent1d781572e832058e2ef54bccd76ef71bc1fd548c (diff)
downloadspark-0e2e5d7d0b42226c61c3200fd63d2831c558519d.tar.gz
spark-0e2e5d7d0b42226c61c3200fd63d2831c558519d.tar.bz2
spark-0e2e5d7d0b42226c61c3200fd63d2831c558519d.zip
[SPARK-16863][ML] ProbabilisticClassifier.fit check thresholds' length
## What changes were proposed in this pull request? Add thresholds' length checking for classifiers that extend ProbabilisticClassifier. ## How was this patch tested? Unit tests and manual tests. Author: Zheng RuiFeng <ruifengz@foxmail.com> Closes #14470 from zhengruifeng/classifier_check_setThreshoulds_length.
Diffstat (limited to 'mllib')
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala7
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala6
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala8
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala7
4 files changed, 28 insertions, 0 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
index 71293017e0..bb192ab5f2 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
@@ -84,6 +84,13 @@ class DecisionTreeClassifier @Since("1.4.0") (
val categoricalFeatures: Map[Int, Int] =
MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
val numClasses: Int = getNumClasses(dataset)
+
+ if (isDefined(thresholds)) {
+ require($(thresholds).length == numClasses, this.getClass.getSimpleName +
+ ".train() called with non-matching numClasses and thresholds.length." +
+ s" numClasses=$numClasses, but thresholds has length ${$(thresholds).length}")
+ }
+
val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset, numClasses)
val strategy = getOldStrategy(categoricalFeatures, numClasses)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
index 7694773c81..90baa41918 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
@@ -292,6 +292,12 @@ class LogisticRegression @Since("1.2.0") (
val numClasses = histogram.length
val numFeatures = summarizer.mean.size
+ if (isDefined(thresholds)) {
+ require($(thresholds).length == numClasses, this.getClass.getSimpleName +
+ ".train() called with non-matching numClasses and thresholds.length." +
+ s" numClasses=$numClasses, but thresholds has length ${$(thresholds).length}")
+ }
+
instr.logNumClasses(numClasses)
instr.logNumFeatures(numFeatures)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
index ab977c8802..f939a1c680 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
@@ -101,6 +101,14 @@ class NaiveBayes @Since("1.5.0") (
setDefault(modelType -> OldNaiveBayes.Multinomial)
override protected def train(dataset: Dataset[_]): NaiveBayesModel = {
+ val numClasses = getNumClasses(dataset)
+
+ if (isDefined(thresholds)) {
+ require($(thresholds).length == numClasses, this.getClass.getSimpleName +
+ ".train() called with non-matching numClasses and thresholds.length." +
+ s" numClasses=$numClasses, but thresholds has length ${$(thresholds).length}")
+ }
+
val oldDataset: RDD[OldLabeledPoint] =
extractLabeledPoints(dataset).map(OldLabeledPoint.fromML)
val oldModel = OldNaiveBayes.train(oldDataset, $(smoothing), $(modelType))
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
index 4ab132e5f2..52345b0626 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
@@ -100,6 +100,13 @@ class RandomForestClassifier @Since("1.4.0") (
val categoricalFeatures: Map[Int, Int] =
MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
val numClasses: Int = getNumClasses(dataset)
+
+ if (isDefined(thresholds)) {
+ require($(thresholds).length == numClasses, this.getClass.getSimpleName +
+ ".train() called with non-matching numClasses and thresholds.length." +
+ s" numClasses=$numClasses, but thresholds has length ${$(thresholds).length}")
+ }
+
val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset, numClasses)
val strategy =
super.getOldStrategy(categoricalFeatures, numClasses, OldAlgo.Classification, getOldImpurity)