diff options
Diffstat (limited to 'mllib')
2 files changed, 6 insertions, 8 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala index 0b84e0a3fa..794b1e7d9d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala @@ -39,16 +39,16 @@ class MulticlassClassificationEvaluator @Since("1.5.0") (@Since("1.5.0") overrid def this() = this(Identifiable.randomUID("mcEval")) /** - * param for metric name in evaluation (supports `"f1"` (default), `"precision"`, `"recall"`, - * `"weightedPrecision"`, `"weightedRecall"`, `"accuracy"`) + * param for metric name in evaluation (supports `"f1"` (default), `"weightedPrecision"`, + * `"weightedRecall"`, `"accuracy"`) * @group param */ @Since("1.5.0") val metricName: Param[String] = { - val allowedParams = ParamValidators.inArray(Array("f1", "precision", - "recall", "weightedPrecision", "weightedRecall", "accuracy")) + val allowedParams = ParamValidators.inArray(Array("f1", "weightedPrecision", + "weightedRecall", "accuracy")) new Param(this, "metricName", "metric name in evaluation " + - "(f1|precision|recall|weightedPrecision|weightedRecall|accuracy)", allowedParams) + "(f1|weightedPrecision|weightedRecall|accuracy)", allowedParams) } /** @group getParam */ @@ -82,8 +82,6 @@ class MulticlassClassificationEvaluator @Since("1.5.0") (@Since("1.5.0") overrid val metrics = new MulticlassMetrics(predictionAndLabels) val metric = $(metricName) match { case "f1" => metrics.weightedFMeasure - case "precision" => metrics.accuracy - case "recall" => metrics.accuracy case "weightedPrecision" => metrics.weightedPrecision case "weightedRecall" => metrics.weightedRecall case "accuracy" => metrics.accuracy diff --git a/mllib/src/test/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluatorSuite.scala index 522f6675d7..1a3a8a13a2 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluatorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluatorSuite.scala @@ -33,7 +33,7 @@ class MulticlassClassificationEvaluatorSuite val evaluator = new MulticlassClassificationEvaluator() .setPredictionCol("myPrediction") .setLabelCol("myLabel") - .setMetricName("recall") + .setMetricName("accuracy") testDefaultReadWrite(evaluator) } |