author     BenFradet <benjamin.fradet@gmail.com>      2016-01-19 14:59:20 -0800
committer  Joseph K. Bradley <joseph@databricks.com>  2016-01-19 14:59:20 -0800
commit     f6f7ca9d2ef65da15f42085993e58e618637fad5 (patch)
tree       5c2d266d6b48111ff853ea4bb468c08038e3c0fa /mllib
parent     43f1d59e17d89d19b322d639c5069a3fc0c8e2ed (diff)
[SPARK-9716][ML] BinaryClassificationEvaluator should accept Double prediction column
This PR aims to allow the prediction column of `BinaryClassificationEvaluator` to be of double type.

Author: BenFradet <benjamin.fradet@gmail.com>

Closes #10472 from BenFradet/SPARK-9716.
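For context, a minimal sketch of the usage this patch enables. The data and column names are illustrative (they mirror the new test below), and a 1.6-era sqlContext is assumed to be in scope:

import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator

// With this patch, rawPrediction may be a plain Double (a 0/1 prediction,
// or the probability of label 1) instead of a length-2 Vector.
val predictions = sqlContext.createDataFrame(Seq(
  (0.0, 0.1),
  (1.0, 0.9),
  (0.0, 0.2)
)).toDF("label", "rawPrediction")

val areaUnderROC = new BinaryClassificationEvaluator()
  .setMetricName("areaUnderROC")
  .evaluate(predictions)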
Diffstat (limited to 'mllib')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala      |  9
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala                              | 17
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluatorSuite.scala | 32
3 files changed, 55 insertions(+), 3 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala
index f71726f110..a1d36c4bec 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala
@@ -29,6 +29,8 @@ import org.apache.spark.sql.types.DoubleType
 /**
  * :: Experimental ::
  * Evaluator for binary classification, which expects two input columns: rawPrediction and label.
+ * The rawPrediction column can be of type double (binary 0/1 prediction, or probability of label 1)
+ * or of type vector (length-2 vector of raw predictions, scores, or label probabilities).
  */
 @Since("1.2.0")
 @Experimental
@@ -78,13 +80,14 @@ class BinaryClassificationEvaluator @Since("1.4.0") (@Since("1.4.0") override va
   @Since("1.2.0")
   override def evaluate(dataset: DataFrame): Double = {
     val schema = dataset.schema
-    SchemaUtils.checkColumnType(schema, $(rawPredictionCol), new VectorUDT)
+    SchemaUtils.checkColumnTypes(schema, $(rawPredictionCol), Seq(DoubleType, new VectorUDT))
     SchemaUtils.checkColumnType(schema, $(labelCol), DoubleType)

     // TODO: When dataset metadata has been implemented, check rawPredictionCol vector length = 2.
     val scoreAndLabels = dataset.select($(rawPredictionCol), $(labelCol))
-      .map { case Row(rawPrediction: Vector, label: Double) =>
-        (rawPrediction(1), label)
+      .map {
+        case Row(rawPrediction: Vector, label: Double) => (rawPrediction(1), label)
+        case Row(rawPrediction: Double, label: Double) => (rawPrediction, label)
       }
     val metrics = new BinaryClassificationMetrics(scoreAndLabels)
     val metric = $(metricName) match {
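The widened match above is the heart of the change: the score fed to BinaryClassificationMetrics is rawPrediction(1) (the raw score for label 1) when the column holds a Vector, and the value itself when it holds a Double. A self-contained sketch of the same dispatch, with illustrative inputs:

import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.sql.Row

// Mirrors the new match: map either column shape to (score, label).
def toScoreAndLabel(row: Row): (Double, Double) = row match {
  case Row(rawPrediction: Vector, label: Double) => (rawPrediction(1), label) // score for label 1
  case Row(rawPrediction: Double, label: Double) => (rawPrediction, label)    // already a score
}

toScoreAndLabel(Row(Vectors.dense(0.2, 0.8), 1.0)) // (0.8, 1.0)
toScoreAndLabel(Row(0.8, 1.0))                     // (0.8, 1.0)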
diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala b/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala
index 76f651488a..e71dd9eee0 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala
@@ -44,6 +44,23 @@ private[spark] object SchemaUtils {
   }

   /**
+   * Check whether the given schema contains a column of one of the required data types.
+   * @param colName column name
+   * @param dataTypes required column data types
+   */
+  def checkColumnTypes(
+      schema: StructType,
+      colName: String,
+      dataTypes: Seq[DataType],
+      msg: String = ""): Unit = {
+    val actualDataType = schema(colName).dataType
+    val message = if (msg != null && msg.trim.length > 0) " " + msg else ""
+    require(dataTypes.exists(actualDataType.equals),
+      s"Column $colName must be of type equal to one of the following types: " +
+      s"${dataTypes.mkString("[", ", ", "]")} but was actually of type $actualDataType.$message")
+  }
+
+  /**
    * Appends a new column to the input schema. This fails if the given output column already exists.
    * @param schema input schema
    * @param colName new column name. If this column name is an empty string "", this method returns
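A sketch of how the new helper behaves. Note that SchemaUtils is private[spark], so code like this only compiles inside Spark's own packages; it is shown purely for illustration:

import org.apache.spark.ml.util.SchemaUtils
import org.apache.spark.mllib.linalg.VectorUDT
import org.apache.spark.sql.types.{DoubleType, StringType, StructField, StructType}

val schema = StructType(Seq(StructField("rawPrediction", StringType)))

// Passes silently for a DoubleType or VectorUDT column; here the column is
// StringType, so require fails with an IllegalArgumentException naming the
// accepted types and the actual type.
SchemaUtils.checkColumnTypes(schema, "rawPrediction", Seq(DoubleType, new VectorUDT))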
diff --git a/mllib/src/test/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluatorSuite.scala
index a535c1218e..27349950dc 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluatorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluatorSuite.scala
@@ -20,6 +20,7 @@ package org.apache.spark.ml.evaluation
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.ml.param.ParamsSuite
 import org.apache.spark.ml.util.DefaultReadWriteTest
+import org.apache.spark.mllib.linalg.Vectors
 import org.apache.spark.mllib.util.MLlibTestSparkContext

 class BinaryClassificationEvaluatorSuite
@@ -36,4 +37,35 @@ class BinaryClassificationEvaluatorSuite
       .setMetricName("areaUnderPR")
     testDefaultReadWrite(evaluator)
   }
+
+  test("should accept both vector and double raw prediction col") {
+    val evaluator = new BinaryClassificationEvaluator()
+      .setMetricName("areaUnderPR")
+
+    val vectorDF = sqlContext.createDataFrame(Seq(
+      (0d, Vectors.dense(12, 2.5)),
+      (1d, Vectors.dense(1, 3)),
+      (0d, Vectors.dense(10, 2))
+    )).toDF("label", "rawPrediction")
+    assert(evaluator.evaluate(vectorDF) === 1.0)
+
+    val doubleDF = sqlContext.createDataFrame(Seq(
+      (0d, 0d),
+      (1d, 1d),
+      (0d, 0d)
+    )).toDF("label", "rawPrediction")
+    assert(evaluator.evaluate(doubleDF) === 1.0)
+
+    val stringDF = sqlContext.createDataFrame(Seq(
+      (0d, "0d"),
+      (1d, "1d"),
+      (0d, "0d")
+    )).toDF("label", "rawPrediction")
+    val thrown = intercept[IllegalArgumentException] {
+      evaluator.evaluate(stringDF)
+    }
+    assert(thrown.getMessage.replace("\n", "") contains "Column rawPrediction must be of type " +
+      "equal to one of the following types: [DoubleType, ")
+    assert(thrown.getMessage.replace("\n", "") contains "but was actually of type StringType.")
+  }
 }