From f6f7ca9d2ef65da15f42085993e58e618637fad5 Mon Sep 17 00:00:00 2001
From: BenFradet
Date: Tue, 19 Jan 2016 14:59:20 -0800
Subject: [SPARK-9716][ML] BinaryClassificationEvaluator should accept Double
 prediction column

This PR aims to allow the prediction column of `BinaryClassificationEvaluator` to be of double type.

Author: BenFradet

Closes #10472 from BenFradet/SPARK-9716.
---
 .../evaluation/BinaryClassificationEvaluator.scala |  9 ++++--
 .../org/apache/spark/ml/util/SchemaUtils.scala     | 17 ++++++++++++
 .../BinaryClassificationEvaluatorSuite.scala       | 32 ++++++++++++++++++++++
 python/pyspark/ml/evaluation.py                    |  5 ++--
 4 files changed, 58 insertions(+), 5 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala
index f71726f110..a1d36c4bec 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala
@@ -29,6 +29,8 @@ import org.apache.spark.sql.types.DoubleType
 /**
  * :: Experimental ::
  * Evaluator for binary classification, which expects two input columns: rawPrediction and label.
+ * The rawPrediction column can be of type double (binary 0/1 prediction, or probability of label 1)
+ * or of type vector (length-2 vector of raw predictions, scores, or label probabilities).
  */
 @Since("1.2.0")
 @Experimental
@@ -78,13 +80,14 @@ class BinaryClassificationEvaluator @Since("1.4.0") (@Since("1.4.0") override va
   @Since("1.2.0")
   override def evaluate(dataset: DataFrame): Double = {
     val schema = dataset.schema
-    SchemaUtils.checkColumnType(schema, $(rawPredictionCol), new VectorUDT)
+    SchemaUtils.checkColumnTypes(schema, $(rawPredictionCol), Seq(DoubleType, new VectorUDT))
     SchemaUtils.checkColumnType(schema, $(labelCol), DoubleType)
 
     // TODO: When dataset metadata has been implemented, check rawPredictionCol vector length = 2.
     val scoreAndLabels = dataset.select($(rawPredictionCol), $(labelCol))
-      .map { case Row(rawPrediction: Vector, label: Double) =>
-        (rawPrediction(1), label)
+      .map {
+        case Row(rawPrediction: Vector, label: Double) => (rawPrediction(1), label)
+        case Row(rawPrediction: Double, label: Double) => (rawPrediction, label)
       }
     val metrics = new BinaryClassificationMetrics(scoreAndLabels)
     val metric = $(metricName) match {
diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala b/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala
index 76f651488a..e71dd9eee0 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala
@@ -43,6 +43,23 @@ private[spark] object SchemaUtils {
       s"Column $colName must be of type $dataType but was actually $actualDataType.$message")
   }
 
+  /**
+   * Check whether the given schema contains a column of one of the required data types.
+   * @param colName column name
+   * @param dataTypes required column data types
+   */
+  def checkColumnTypes(
+      schema: StructType,
+      colName: String,
+      dataTypes: Seq[DataType],
+      msg: String = ""): Unit = {
+    val actualDataType = schema(colName).dataType
+    val message = if (msg != null && msg.trim.length > 0) " " + msg else ""
+    require(dataTypes.exists(actualDataType.equals),
+      s"Column $colName must be of type equal to one of the following types: " +
+        s"${dataTypes.mkString("[", ", ", "]")} but was actually of type $actualDataType.$message")
+  }
+
   /**
    * Appends a new column to the input schema. This fails if the given output column already exists.
    * @param schema input schema
diff --git a/mllib/src/test/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluatorSuite.scala
index a535c1218e..27349950dc 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluatorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluatorSuite.scala
@@ -20,6 +20,7 @@ package org.apache.spark.ml.evaluation
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.ml.param.ParamsSuite
 import org.apache.spark.ml.util.DefaultReadWriteTest
+import org.apache.spark.mllib.linalg.Vectors
 import org.apache.spark.mllib.util.MLlibTestSparkContext
 
 class BinaryClassificationEvaluatorSuite
@@ -36,4 +37,35 @@ class BinaryClassificationEvaluatorSuite
       .setMetricName("areaUnderPR")
     testDefaultReadWrite(evaluator)
   }
+
+  test("should accept both vector and double raw prediction col") {
+    val evaluator = new BinaryClassificationEvaluator()
+      .setMetricName("areaUnderPR")
+
+    val vectorDF = sqlContext.createDataFrame(Seq(
+      (0d, Vectors.dense(12, 2.5)),
+      (1d, Vectors.dense(1, 3)),
+      (0d, Vectors.dense(10, 2))
+    )).toDF("label", "rawPrediction")
+    assert(evaluator.evaluate(vectorDF) === 1.0)
+
+    val doubleDF = sqlContext.createDataFrame(Seq(
+      (0d, 0d),
+      (1d, 1d),
+      (0d, 0d)
+    )).toDF("label", "rawPrediction")
+    assert(evaluator.evaluate(doubleDF) === 1.0)
+
+    val stringDF = sqlContext.createDataFrame(Seq(
+      (0d, "0d"),
+      (1d, "1d"),
+      (0d, "0d")
+    )).toDF("label", "rawPrediction")
+    val thrown = intercept[IllegalArgumentException] {
+      evaluator.evaluate(stringDF)
+    }
+    assert(thrown.getMessage.replace("\n", "") contains "Column rawPrediction must be of type " +
+      "equal to one of the following types: [DoubleType, ")
+    assert(thrown.getMessage.replace("\n", "") contains "but was actually of type StringType.")
+  }
 }
diff --git a/python/pyspark/ml/evaluation.py b/python/pyspark/ml/evaluation.py
index dcc1738ec5..6ff68abd8f 100644
--- a/python/pyspark/ml/evaluation.py
+++ b/python/pyspark/ml/evaluation.py
@@ -106,8 +106,9 @@ class JavaEvaluator(Evaluator, JavaWrapper):
 @inherit_doc
 class BinaryClassificationEvaluator(JavaEvaluator, HasLabelCol, HasRawPredictionCol):
     """
-    Evaluator for binary classification, which expects two input
-    columns: rawPrediction and label.
+    Evaluator for binary classification, which expects two input columns: rawPrediction and label.
+    The rawPrediction column can be of type double (binary 0/1 prediction, or probability of label
+    1) or of type vector (length-2 vector of raw predictions, scores, or label probabilities).
 
     >>> from pyspark.mllib.linalg import Vectors
     >>> scoreAndLabels = map(lambda x: (Vectors.dense([1.0 - x[0], x[0]]), x[1]),
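
Usage sketch: a minimal example of what this change enables, mirroring the `doubleDF` test above. After this patch, the rawPrediction column may hold a plain double score (e.g. the probability of label 1) instead of a length-2 vector. This assumes a Spark 1.6-style shell where `sqlContext` is already in scope; the labels and scores below are illustrative toy data.

    import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator

    val evaluator = new BinaryClassificationEvaluator()
      .setMetricName("areaUnderROC")

    // rawPrediction given as a plain double score for label 1;
    // previously this column had to be a length-2 vector.
    val scores = sqlContext.createDataFrame(Seq(
      (0.0, 0.1),
      (1.0, 0.9),
      (0.0, 0.2)
    )).toDF("label", "rawPrediction")

    // The positive example scores above both negatives,
    // so the area under the ROC curve is 1.0.
    println(evaluator.evaluate(scores))

A string-typed rawPrediction column still fails fast: `checkColumnTypes` raises an `IllegalArgumentException` naming the accepted types, as exercised by the `stringDF` case in the test suite.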