author     Andrew Bullen <andrew.bullen@workday.com>  2014-11-12 22:14:44 -0800
committer  Xiangrui Meng <meng@databricks.com>        2014-11-12 22:14:44 -0800
commit     484fecbf1402c25f310be0b0a5ec15c11cbd65c3 (patch)
tree       c902a79d4cf8efdbcd1058a50196ec80b1e51223 /mllib
parent     b9e1c2eb9b6f7fb609718ef20048a8da452d881b (diff)
[SPARK-4256] Make Binary Evaluation Metrics functions defined in cases where there are 0 positive or 0 negative examples.

Author: Andrew Bullen <andrew.bullen@workday.com>

Closes #3118 from abull/master and squashes the following commits:

c2bf2b1 [Andrew Bullen] [SPARK-4256] Update code formatting for BinaryClassificationMetricsSpec
36b0533 [Andrew Bullen] [SPARK-4256] Extract BinaryClassificationMetricsSuite assertions into private method
4d2f79a [Andrew Bullen] [SPARK-4256] Refactor classification metrics tests - extract comparison functions in test
f411e70 [Andrew Bullen] [SPARK-4256] Define precision as 1.0 when there are no positive examples; update code formatting per pull request comments
d9a09ef [Andrew Bullen] Make Binary Evaluation Metrics functions defined in cases where there are 0 positive or 0 negative examples.
Diffstat (limited to 'mllib')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/evaluation/binary/BinaryClassificationMetricComputers.scala  43
-rw-r--r--  mllib/src/test/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetricsSuite.scala            97
2 files changed, 113 insertions(+), 27 deletions(-)
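The change is easiest to exercise through the public BinaryClassificationMetrics API. The sketch below shows the all-positive-label scenario covered by the new test; it assumes a running SparkContext named sc (as in spark-shell), and the commented results are the values the test expects after this patch.

import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics

// Every example carries the positive label, so there are zero negatives;
// before this patch the false positive rate evaluated 0/0 and came out NaN.
val scoreAndLabels = sc.parallelize(Seq((0.5, 1.0), (0.5, 1.0)), 2)
val metrics = new BinaryClassificationMetrics(scoreAndLabels)

metrics.precisionByThreshold().collect()   // Array((0.5, 1.0))
metrics.recallByThreshold().collect()      // Array((0.5, 1.0))
metrics.roc().collect()                    // Array((0.0, 0.0), (0.0, 1.0), (1.0, 1.0))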
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/binary/BinaryClassificationMetricComputers.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/binary/BinaryClassificationMetricComputers.scala
index 562663ad36..be3319d60c 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/binary/BinaryClassificationMetricComputers.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/binary/BinaryClassificationMetricComputers.scala
@@ -24,26 +24,43 @@ private[evaluation] trait BinaryClassificationMetricComputer extends Serializable
def apply(c: BinaryConfusionMatrix): Double
}
-/** Precision. */
+/** Precision. Defined as 1.0 when there are no positive examples. */
private[evaluation] object Precision extends BinaryClassificationMetricComputer {
- override def apply(c: BinaryConfusionMatrix): Double =
- c.numTruePositives.toDouble / (c.numTruePositives + c.numFalsePositives)
+ override def apply(c: BinaryConfusionMatrix): Double = {
+ val totalPositives = c.numTruePositives + c.numFalsePositives
+ if (totalPositives == 0) {
+ 1.0
+ } else {
+ c.numTruePositives.toDouble / totalPositives
+ }
+ }
}
-/** False positive rate. */
+/** False positive rate. Defined as 0.0 when there are no negative examples. */
private[evaluation] object FalsePositiveRate extends BinaryClassificationMetricComputer {
- override def apply(c: BinaryConfusionMatrix): Double =
- c.numFalsePositives.toDouble / c.numNegatives
+ override def apply(c: BinaryConfusionMatrix): Double = {
+ if (c.numNegatives == 0) {
+ 0.0
+ } else {
+ c.numFalsePositives.toDouble / c.numNegatives
+ }
+ }
}
-/** Recall. */
+/** Recall. Defined as 0.0 when there are no positive examples. */
private[evaluation] object Recall extends BinaryClassificationMetricComputer {
- override def apply(c: BinaryConfusionMatrix): Double =
- c.numTruePositives.toDouble / c.numPositives
+ override def apply(c: BinaryConfusionMatrix): Double = {
+ if (c.numPositives == 0) {
+ 0.0
+ } else {
+ c.numTruePositives.toDouble / c.numPositives
+ }
+ }
}
/**
- * F-Measure.
+ * F-Measure. Defined as 0 if both precision and recall are 0, e.g. in the case that all examples
+ * are false positives.
* @param beta the beta constant in F-Measure
* @see http://en.wikipedia.org/wiki/F1_score
*/
@@ -52,6 +69,10 @@ private[evaluation] case class FMeasure(beta: Double) extends BinaryClassificationMetricComputer
override def apply(c: BinaryConfusionMatrix): Double = {
val precision = Precision(c)
val recall = Recall(c)
- (1.0 + beta2) * (precision * recall) / (beta2 * precision + recall)
+ if (precision + recall == 0) {
+ 0.0
+ } else {
+ (1.0 + beta2) * (precision * recall) / (beta2 * precision + recall)
+ }
}
}
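BinaryConfusionMatrix and these metric computers are private[evaluation], so they cannot be called directly from user code. The standalone Scala sketch below mirrors the guarded divisions introduced above; the names (GuardedMetrics, safePrecision, safeRecall, safeFpr, safeFMeasure) are illustrative only and are not part of Spark.

object GuardedMetrics {
  // Precision: defined as 1.0 when nothing was predicted positive (tp + fp == 0).
  def safePrecision(tp: Long, fp: Long): Double =
    if (tp + fp == 0) 1.0 else tp.toDouble / (tp + fp)

  // Recall: defined as 0.0 when there are no positive examples.
  def safeRecall(tp: Long, positives: Long): Double =
    if (positives == 0) 0.0 else tp.toDouble / positives

  // False positive rate: defined as 0.0 when there are no negative examples.
  def safeFpr(fp: Long, negatives: Long): Double =
    if (negatives == 0) 0.0 else fp.toDouble / negatives

  // F-Measure: defined as 0.0 when both precision and recall are 0.
  def safeFMeasure(precision: Double, recall: Double, beta: Double): Double = {
    val beta2 = beta * beta
    if (precision + recall == 0) 0.0
    else (1.0 + beta2) * (precision * recall) / (beta2 * precision + recall)
  }
}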
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetricsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetricsSuite.scala
index 3a29ccb519..8a18e2971c 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetricsSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetricsSuite.scala
@@ -24,39 +24,104 @@ import org.apache.spark.mllib.util.TestingUtils._
class BinaryClassificationMetricsSuite extends FunSuite with MLlibTestSparkContext {
- def cond1(x: (Double, Double)): Boolean = x._1 ~= (x._2) absTol 1E-5
+ private def areWithinEpsilon(x: (Double, Double)): Boolean = x._1 ~= (x._2) absTol 1E-5
- def cond2(x: ((Double, Double), (Double, Double))): Boolean =
+ private def pairsWithinEpsilon(x: ((Double, Double), (Double, Double))): Boolean =
(x._1._1 ~= x._2._1 absTol 1E-5) && (x._1._2 ~= x._2._2 absTol 1E-5)
+ private def assertSequencesMatch(left: Seq[Double], right: Seq[Double]): Unit = {
+ assert(left.zip(right).forall(areWithinEpsilon))
+ }
+
+ private def assertTupleSequencesMatch(left: Seq[(Double, Double)],
+ right: Seq[(Double, Double)]): Unit = {
+ assert(left.zip(right).forall(pairsWithinEpsilon))
+ }
+
+ private def validateMetrics(metrics: BinaryClassificationMetrics,
+ expectedThresholds: Seq[Double],
+ expectedROCCurve: Seq[(Double, Double)],
+ expectedPRCurve: Seq[(Double, Double)],
+ expectedFMeasures1: Seq[Double],
+ expectedFmeasures2: Seq[Double],
+ expectedPrecisions: Seq[Double],
+ expectedRecalls: Seq[Double]) = {
+
+ assertSequencesMatch(metrics.thresholds().collect(), expectedThresholds)
+ assertTupleSequencesMatch(metrics.roc().collect(), expectedROCCurve)
+ assert(metrics.areaUnderROC() ~== AreaUnderCurve.of(expectedROCCurve) absTol 1E-5)
+ assertTupleSequencesMatch(metrics.pr().collect(), expectedPRCurve)
+ assert(metrics.areaUnderPR() ~== AreaUnderCurve.of(expectedPRCurve) absTol 1E-5)
+ assertTupleSequencesMatch(metrics.fMeasureByThreshold().collect(),
+ expectedThresholds.zip(expectedFMeasures1))
+ assertTupleSequencesMatch(metrics.fMeasureByThreshold(2.0).collect(),
+ expectedThresholds.zip(expectedFmeasures2))
+ assertTupleSequencesMatch(metrics.precisionByThreshold().collect(),
+ expectedThresholds.zip(expectedPrecisions))
+ assertTupleSequencesMatch(metrics.recallByThreshold().collect(),
+ expectedThresholds.zip(expectedRecalls))
+ }
+
test("binary evaluation metrics") {
val scoreAndLabels = sc.parallelize(
Seq((0.1, 0.0), (0.1, 1.0), (0.4, 0.0), (0.6, 0.0), (0.6, 1.0), (0.6, 1.0), (0.8, 1.0)), 2)
val metrics = new BinaryClassificationMetrics(scoreAndLabels)
- val threshold = Seq(0.8, 0.6, 0.4, 0.1)
+ val thresholds = Seq(0.8, 0.6, 0.4, 0.1)
val numTruePositives = Seq(1, 3, 3, 4)
val numFalsePositives = Seq(0, 1, 2, 3)
val numPositives = 4
val numNegatives = 3
- val precision = numTruePositives.zip(numFalsePositives).map { case (t, f) =>
+ val precisions = numTruePositives.zip(numFalsePositives).map { case (t, f) =>
t.toDouble / (t + f)
}
- val recall = numTruePositives.map(t => t.toDouble / numPositives)
+ val recalls = numTruePositives.map(t => t.toDouble / numPositives)
val fpr = numFalsePositives.map(f => f.toDouble / numNegatives)
- val rocCurve = Seq((0.0, 0.0)) ++ fpr.zip(recall) ++ Seq((1.0, 1.0))
- val pr = recall.zip(precision)
+ val rocCurve = Seq((0.0, 0.0)) ++ fpr.zip(recalls) ++ Seq((1.0, 1.0))
+ val pr = recalls.zip(precisions)
val prCurve = Seq((0.0, 1.0)) ++ pr
val f1 = pr.map { case (r, p) => 2.0 * (p * r) / (p + r)}
val f2 = pr.map { case (r, p) => 5.0 * (p * r) / (4.0 * p + r)}
- assert(metrics.thresholds().collect().zip(threshold).forall(cond1))
- assert(metrics.roc().collect().zip(rocCurve).forall(cond2))
- assert(metrics.areaUnderROC() ~== AreaUnderCurve.of(rocCurve) absTol 1E-5)
- assert(metrics.pr().collect().zip(prCurve).forall(cond2))
- assert(metrics.areaUnderPR() ~== AreaUnderCurve.of(prCurve) absTol 1E-5)
- assert(metrics.fMeasureByThreshold().collect().zip(threshold.zip(f1)).forall(cond2))
- assert(metrics.fMeasureByThreshold(2.0).collect().zip(threshold.zip(f2)).forall(cond2))
- assert(metrics.precisionByThreshold().collect().zip(threshold.zip(precision)).forall(cond2))
- assert(metrics.recallByThreshold().collect().zip(threshold.zip(recall)).forall(cond2))
+ validateMetrics(metrics, thresholds, rocCurve, prCurve, f1, f2, precisions, recalls)
+ }
+
+ test("binary evaluation metrics for RDD where all examples have positive label") {
+ val scoreAndLabels = sc.parallelize(Seq((0.5, 1.0), (0.5, 1.0)), 2)
+ val metrics = new BinaryClassificationMetrics(scoreAndLabels)
+
+ val thresholds = Seq(0.5)
+ val precisions = Seq(1.0)
+ val recalls = Seq(1.0)
+ val fpr = Seq(0.0)
+ val rocCurve = Seq((0.0, 0.0)) ++ fpr.zip(recalls) ++ Seq((1.0, 1.0))
+ val pr = recalls.zip(precisions)
+ val prCurve = Seq((0.0, 1.0)) ++ pr
+ val f1 = pr.map { case (r, p) => 2.0 * (p * r) / (p + r)}
+ val f2 = pr.map { case (r, p) => 5.0 * (p * r) / (4.0 * p + r)}
+
+ validateMetrics(metrics, thresholds, rocCurve, prCurve, f1, f2, precisions, recalls)
+ }
+
+ test("binary evaluation metrics for RDD where all examples have negative label") {
+ val scoreAndLabels = sc.parallelize(Seq((0.5, 0.0), (0.5, 0.0)), 2)
+ val metrics = new BinaryClassificationMetrics(scoreAndLabels)
+
+ val thresholds = Seq(0.5)
+ val precisions = Seq(0.0)
+ val recalls = Seq(0.0)
+ val fpr = Seq(1.0)
+ val rocCurve = Seq((0.0, 0.0)) ++ fpr.zip(recalls) ++ Seq((1.0, 1.0))
+ val pr = recalls.zip(precisions)
+ val prCurve = Seq((0.0, 1.0)) ++ pr
+ val f1 = pr.map {
+ case (0, 0) => 0.0
+ case (r, p) => 2.0 * (p * r) / (p + r)
+ }
+ val f2 = pr.map {
+ case (0, 0) => 0.0
+ case (r, p) => 5.0 * (p * r) / (4.0 * p + r)
+ }
+
+ validateMetrics(metrics, thresholds, rocCurve, prCurve, f1, f2, precisions, recalls)
}
}
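The all-negative-label test corresponds to the following sketch through the public API (again assuming a SparkContext named sc). The commented results are the values the test asserts; the F-measure in particular is now 0.0 where it previously evaluated 0/0 as NaN.

import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics

val negOnly = sc.parallelize(Seq((0.5, 0.0), (0.5, 0.0)), 2)
val negMetrics = new BinaryClassificationMetrics(negOnly)

negMetrics.precisionByThreshold().collect()   // Array((0.5, 0.0))
negMetrics.recallByThreshold().collect()      // Array((0.5, 0.0))
negMetrics.fMeasureByThreshold().collect()    // Array((0.5, 0.0))
negMetrics.roc().collect()                    // Array((0.0, 0.0), (1.0, 0.0), (1.0, 1.0))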