aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala
diff options
context:
space:
mode:
Diffstat (limited to 'mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala')
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala6
1 files changed, 3 insertions, 3 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala
index 337ffbe90f..bde8c275fd 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala
@@ -23,7 +23,7 @@ import org.apache.spark.ml.param.shared._
import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable, SchemaUtils}
import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics
import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
-import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.sql.{Dataset, Row}
import org.apache.spark.sql.types.DoubleType
/**
@@ -69,8 +69,8 @@ class BinaryClassificationEvaluator @Since("1.4.0") (@Since("1.4.0") override va
setDefault(metricName -> "areaUnderROC")
- @Since("1.2.0")
- override def evaluate(dataset: DataFrame): Double = {
+ @Since("2.0.0")
+ override def evaluate(dataset: Dataset[_]): Double = {
val schema = dataset.schema
SchemaUtils.checkColumnTypes(schema, $(rawPredictionCol), Seq(DoubleType, new VectorUDT))
SchemaUtils.checkColumnType(schema, $(labelCol), DoubleType)