From 0e2e5d7d0b42226c61c3200fd63d2831c558519d Mon Sep 17 00:00:00 2001 From: Zheng RuiFeng Date: Thu, 4 Aug 2016 21:44:54 +0100 Subject: [SPARK-16863][ML] ProbabilisticClassifier.fit check threshoulds' length ## What changes were proposed in this pull request? Add threshoulds' length checking for Classifiers which extends ProbabilisticClassifier ## How was this patch tested? unit tests and manual tests Author: Zheng RuiFeng Closes #14470 from zhengruifeng/classifier_check_setThreshoulds_length. --- .../apache/spark/ml/classification/DecisionTreeClassifier.scala | 7 +++++++ .../org/apache/spark/ml/classification/LogisticRegression.scala | 6 ++++++ .../scala/org/apache/spark/ml/classification/NaiveBayes.scala | 8 ++++++++ .../apache/spark/ml/classification/RandomForestClassifier.scala | 7 +++++++ 4 files changed, 28 insertions(+) (limited to 'mllib/src') diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala index 71293017e0..bb192ab5f2 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala @@ -84,6 +84,13 @@ class DecisionTreeClassifier @Since("1.4.0") ( val categoricalFeatures: Map[Int, Int] = MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol))) val numClasses: Int = getNumClasses(dataset) + + if (isDefined(thresholds)) { + require($(thresholds).length == numClasses, this.getClass.getSimpleName + + ".train() called with non-matching numClasses and thresholds.length." + + s" numClasses=$numClasses, but thresholds has length ${$(thresholds).length}") + } + val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset, numClasses) val strategy = getOldStrategy(categoricalFeatures, numClasses) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index 7694773c81..90baa41918 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -292,6 +292,12 @@ class LogisticRegression @Since("1.2.0") ( val numClasses = histogram.length val numFeatures = summarizer.mean.size + if (isDefined(thresholds)) { + require($(thresholds).length == numClasses, this.getClass.getSimpleName + + ".train() called with non-matching numClasses and thresholds.length." + + s" numClasses=$numClasses, but thresholds has length ${$(thresholds).length}") + } + instr.logNumClasses(numClasses) instr.logNumFeatures(numFeatures) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala index ab977c8802..f939a1c680 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala @@ -101,6 +101,14 @@ class NaiveBayes @Since("1.5.0") ( setDefault(modelType -> OldNaiveBayes.Multinomial) override protected def train(dataset: Dataset[_]): NaiveBayesModel = { + val numClasses = getNumClasses(dataset) + + if (isDefined(thresholds)) { + require($(thresholds).length == numClasses, this.getClass.getSimpleName + + ".train() called with non-matching numClasses and thresholds.length." + + s" numClasses=$numClasses, but thresholds has length ${$(thresholds).length}") + } + val oldDataset: RDD[OldLabeledPoint] = extractLabeledPoints(dataset).map(OldLabeledPoint.fromML) val oldModel = OldNaiveBayes.train(oldDataset, $(smoothing), $(modelType)) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala index 4ab132e5f2..52345b0626 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala @@ -100,6 +100,13 @@ class RandomForestClassifier @Since("1.4.0") ( val categoricalFeatures: Map[Int, Int] = MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol))) val numClasses: Int = getNumClasses(dataset) + + if (isDefined(thresholds)) { + require($(thresholds).length == numClasses, this.getClass.getSimpleName + + ".train() called with non-matching numClasses and thresholds.length." + + s" numClasses=$numClasses, but thresholds has length ${$(thresholds).length}") + } + val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset, numClasses) val strategy = super.getOldStrategy(categoricalFeatures, numClasses, OldAlgo.Classification, getOldImpurity) -- cgit v1.2.3