diff options
author | BenFradet <benjamin.fradet@gmail.com> | 2016-01-19 14:59:20 -0800 |
---|---|---|
committer | Joseph K. Bradley <joseph@databricks.com> | 2016-01-19 14:59:20 -0800 |
commit | f6f7ca9d2ef65da15f42085993e58e618637fad5 (patch) | |
tree | 5c2d266d6b48111ff853ea4bb468c08038e3c0fa /mllib/src/test/scala/org/apache | |
parent | 43f1d59e17d89d19b322d639c5069a3fc0c8e2ed (diff) | |
download | spark-f6f7ca9d2ef65da15f42085993e58e618637fad5.tar.gz spark-f6f7ca9d2ef65da15f42085993e58e618637fad5.tar.bz2 spark-f6f7ca9d2ef65da15f42085993e58e618637fad5.zip |
[SPARK-9716][ML] BinaryClassificationEvaluator should accept Double prediction column
This PR aims to allow the prediction column of `BinaryClassificationEvaluator` to be of double type.
Author: BenFradet <benjamin.fradet@gmail.com>
Closes #10472 from BenFradet/SPARK-9716.
Diffstat (limited to 'mllib/src/test/scala/org/apache')
-rw-r--r-- | mllib/src/test/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluatorSuite.scala | 32 |
1 files changed, 32 insertions, 0 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluatorSuite.scala index a535c1218e..27349950dc 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluatorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluatorSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.ml.evaluation import org.apache.spark.SparkFunSuite import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.util.DefaultReadWriteTest +import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.MLlibTestSparkContext class BinaryClassificationEvaluatorSuite @@ -36,4 +37,35 @@ class BinaryClassificationEvaluatorSuite .setMetricName("areaUnderPR") testDefaultReadWrite(evaluator) } + + test("should accept both vector and double raw prediction col") { + val evaluator = new BinaryClassificationEvaluator() + .setMetricName("areaUnderPR") + + val vectorDF = sqlContext.createDataFrame(Seq( + (0d, Vectors.dense(12, 2.5)), + (1d, Vectors.dense(1, 3)), + (0d, Vectors.dense(10, 2)) + )).toDF("label", "rawPrediction") + assert(evaluator.evaluate(vectorDF) === 1.0) + + val doubleDF = sqlContext.createDataFrame(Seq( + (0d, 0d), + (1d, 1d), + (0d, 0d) + )).toDF("label", "rawPrediction") + assert(evaluator.evaluate(doubleDF) === 1.0) + + val stringDF = sqlContext.createDataFrame(Seq( + (0d, "0d"), + (1d, "1d"), + (0d, "0d") + )).toDF("label", "rawPrediction") + val thrown = intercept[IllegalArgumentException] { + evaluator.evaluate(stringDF) + } + assert(thrown.getMessage.replace("\n", "") contains "Column rawPrediction must be of type " + + "equal to one of the following types: [DoubleType, ") + assert(thrown.getMessage.replace("\n", "") contains "but was actually of type StringType.") + } } |