aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala
diff options
context:
space:
mode:
Diffstat (limited to 'mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala')
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala6
1 files changed, 3 insertions, 3 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala
index 55ff44323a..3acfc221c9 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala
@@ -22,7 +22,7 @@ import org.apache.spark.ml.param.{Param, ParamMap, ParamValidators}
import org.apache.spark.ml.param.shared.{HasLabelCol, HasPredictionCol}
import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable, SchemaUtils}
import org.apache.spark.mllib.evaluation.MulticlassMetrics
-import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.sql.{Dataset, Row}
import org.apache.spark.sql.types.DoubleType
/**
@@ -68,8 +68,8 @@ class MulticlassClassificationEvaluator @Since("1.5.0") (@Since("1.5.0") overrid
setDefault(metricName -> "f1")
- @Since("1.5.0")
- override def evaluate(dataset: DataFrame): Double = {
+ @Since("2.0.0")
+ override def evaluate(dataset: Dataset[_]): Double = {
val schema = dataset.schema
SchemaUtils.checkColumnType(schema, $(predictionCol), DoubleType)
SchemaUtils.checkColumnType(schema, $(labelCol), DoubleType)