about summary refs log tree commit diff
path: root/mllib
diff options
context:
space:
mode:
authorBenFradet <benjamin.fradet@gmail.com>2016-04-26 08:55:50 +0200
committerNick Pentreath <nickp@za.ibm.com>2016-04-26 08:55:50 +0200
commit2a5c930790b4b92674e74f093380d89a9a625552 (patch)
tree5738e2d5d6d9db957911e85ec273c147dc313e28 /mllib
parentf8709218115f6c7aa4fb321865cdef8ceb443bd1 (diff)
downloadspark-2a5c930790b4b92674e74f093380d89a9a625552.tar.gz
spark-2a5c930790b4b92674e74f093380d89a9a625552.tar.bz2
spark-2a5c930790b4b92674e74f093380d89a9a625552.zip
[SPARK-13962][ML] spark.ml Evaluators should support other numeric types for label
## What changes were proposed in this pull request?

Made BinaryClassificationEvaluator, MulticlassClassificationEvaluator and RegressionEvaluator accept all numeric types for label

## How was this patch tested?

Unit tests

Author: BenFradet <benjamin.fradet@gmail.com>

Closes #12500 from BenFradet/SPARK-13962.
Diffstat (limited to 'mllib')
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala12
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala11
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala19
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluatorSuite.scala7
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluatorSuite.scala6
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala6
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala9
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala69
8 files changed, 88 insertions, 51 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala
index bde8c275fd..0cbc391d96 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala
@@ -24,6 +24,7 @@ import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, I
import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics
import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
import org.apache.spark.sql.{Dataset, Row}
+import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.DoubleType
/**
@@ -73,13 +74,14 @@ class BinaryClassificationEvaluator @Since("1.4.0") (@Since("1.4.0") override va
override def evaluate(dataset: Dataset[_]): Double = {
val schema = dataset.schema
SchemaUtils.checkColumnTypes(schema, $(rawPredictionCol), Seq(DoubleType, new VectorUDT))
- SchemaUtils.checkColumnType(schema, $(labelCol), DoubleType)
+ SchemaUtils.checkNumericType(schema, $(labelCol))
// TODO: When dataset metadata has been implemented, check rawPredictionCol vector length = 2.
- val scoreAndLabels = dataset.select($(rawPredictionCol), $(labelCol)).rdd.map {
- case Row(rawPrediction: Vector, label: Double) => (rawPrediction(1), label)
- case Row(rawPrediction: Double, label: Double) => (rawPrediction, label)
- }
+ val scoreAndLabels =
+ dataset.select(col($(rawPredictionCol)), col($(labelCol)).cast(DoubleType)).rdd.map {
+ case Row(rawPrediction: Vector, label: Double) => (rawPrediction(1), label)
+ case Row(rawPrediction: Double, label: Double) => (rawPrediction, label)
+ }
val metrics = new BinaryClassificationMetrics(scoreAndLabels)
val metric = $(metricName) match {
case "areaUnderROC" => metrics.areaUnderROC()
diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala
index 3acfc221c9..3d89843a0b 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala
@@ -23,6 +23,7 @@ import org.apache.spark.ml.param.shared.{HasLabelCol, HasPredictionCol}
import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable, SchemaUtils}
import org.apache.spark.mllib.evaluation.MulticlassMetrics
import org.apache.spark.sql.{Dataset, Row}
+import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.DoubleType
/**
@@ -72,12 +73,12 @@ class MulticlassClassificationEvaluator @Since("1.5.0") (@Since("1.5.0") overrid
override def evaluate(dataset: Dataset[_]): Double = {
val schema = dataset.schema
SchemaUtils.checkColumnType(schema, $(predictionCol), DoubleType)
- SchemaUtils.checkColumnType(schema, $(labelCol), DoubleType)
+ SchemaUtils.checkNumericType(schema, $(labelCol))
- val predictionAndLabels = dataset.select($(predictionCol), $(labelCol)).rdd.map {
- case Row(prediction: Double, label: Double) =>
- (prediction, label)
- }
+ val predictionAndLabels =
+ dataset.select(col($(predictionCol)), col($(labelCol)).cast(DoubleType)).rdd.map {
+ case Row(prediction: Double, label: Double) => (prediction, label)
+ }
val metrics = new MulticlassMetrics(predictionAndLabels)
val metric = $(metricName) match {
case "f1" => metrics.weightedFMeasure
diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala
index 988f6e918f..031cd0d635 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala
@@ -20,7 +20,7 @@ package org.apache.spark.ml.evaluation
import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.ml.param.{Param, ParamMap, ParamValidators}
import org.apache.spark.ml.param.shared.{HasLabelCol, HasPredictionCol}
-import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable}
+import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable, SchemaUtils}
import org.apache.spark.mllib.evaluation.RegressionMetrics
import org.apache.spark.sql.{Dataset, Row}
import org.apache.spark.sql.functions._
@@ -74,22 +74,13 @@ final class RegressionEvaluator @Since("1.4.0") (@Since("1.4.0") override val ui
@Since("2.0.0")
override def evaluate(dataset: Dataset[_]): Double = {
val schema = dataset.schema
- val predictionColName = $(predictionCol)
- val predictionType = schema($(predictionCol)).dataType
- require(predictionType == FloatType || predictionType == DoubleType,
- s"Prediction column $predictionColName must be of type float or double, " +
- s" but not $predictionType")
- val labelColName = $(labelCol)
- val labelType = schema($(labelCol)).dataType
- require(labelType == FloatType || labelType == DoubleType,
- s"Label column $labelColName must be of type float or double, but not $labelType")
+ SchemaUtils.checkColumnTypes(schema, $(predictionCol), Seq(DoubleType, FloatType))
+ SchemaUtils.checkNumericType(schema, $(labelCol))
val predictionAndLabels = dataset
.select(col($(predictionCol)).cast(DoubleType), col($(labelCol)).cast(DoubleType))
- .rdd.
- map { case Row(prediction: Double, label: Double) =>
- (prediction, label)
- }
+ .rdd
+ .map { case Row(prediction: Double, label: Double) => (prediction, label) }
val metrics = new RegressionMetrics(predictionAndLabels)
val metric = $(metricName) match {
case "rmse" => metrics.rootMeanSquaredError
diff --git a/mllib/src/test/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluatorSuite.scala
index 27349950dc..ff34522178 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluatorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluatorSuite.scala
@@ -19,7 +19,7 @@ package org.apache.spark.ml.evaluation
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.param.ParamsSuite
-import org.apache.spark.ml.util.DefaultReadWriteTest
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.util.MLlibTestSparkContext
@@ -68,4 +68,9 @@ class BinaryClassificationEvaluatorSuite
"equal to one of the following types: [DoubleType, ")
assert(thrown.getMessage.replace("\n", "") contains "but was actually of type StringType.")
}
+
+ test("should support all NumericType labels and not support other types") {
+ val evaluator = new BinaryClassificationEvaluator().setRawPredictionCol("prediction")
+ MLTestingUtils.checkNumericTypes(evaluator, sqlContext)
+ }
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluatorSuite.scala
index 7ee65975d2..87e511a368 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluatorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluatorSuite.scala
@@ -19,7 +19,7 @@ package org.apache.spark.ml.evaluation
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.param.ParamsSuite
-import org.apache.spark.ml.util.DefaultReadWriteTest
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
import org.apache.spark.mllib.util.MLlibTestSparkContext
class MulticlassClassificationEvaluatorSuite
@@ -36,4 +36,8 @@ class MulticlassClassificationEvaluatorSuite
.setMetricName("recall")
testDefaultReadWrite(evaluator)
}
+
+ test("should support all NumericType labels and not support other types") {
+ MLTestingUtils.checkNumericTypes(new MulticlassClassificationEvaluator, sqlContext)
+ }
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala
index 954d3bedc1..c7b9483069 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala
@@ -20,7 +20,7 @@ package org.apache.spark.ml.evaluation
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.param.ParamsSuite
import org.apache.spark.ml.regression.LinearRegression
-import org.apache.spark.ml.util.DefaultReadWriteTest
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext}
import org.apache.spark.mllib.util.TestingUtils._
@@ -83,4 +83,8 @@ class RegressionEvaluatorSuite
.setMetricName("r2")
testDefaultReadWrite(evaluator)
}
+
+ test("should support all NumericType labels and not support other types") {
+ MLTestingUtils.checkNumericTypes(new RegressionEvaluator, sqlContext)
+ }
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala
index b650a9f092..e3f09899d7 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala
@@ -79,16 +79,21 @@ private[ml] object TreeTests extends SparkFunSuite {
* This must be non-empty.
* @param numClasses Number of classes label can take. If 0, mark as continuous.
* @param labelColName Name of the label column on which to set the metadata.
+ * @param featuresColName Name of the features column
* @return DataFrame with metadata
*/
- def setMetadata(data: DataFrame, numClasses: Int, labelColName: String): DataFrame = {
+ def setMetadata(
+ data: DataFrame,
+ numClasses: Int,
+ labelColName: String,
+ featuresColName: String): DataFrame = {
val labelAttribute = if (numClasses == 0) {
NumericAttribute.defaultAttr.withName(labelColName)
} else {
NominalAttribute.defaultAttr.withName(labelColName).withNumValues(numClasses)
}
val labelMetadata = labelAttribute.toMetadata()
- data.select(data("features"), data(labelColName).as(labelColName, labelMetadata))
+ data.select(data(featuresColName), data(labelColName).as(labelColName, labelMetadata))
}
/**
diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala b/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala
index 8108460518..d9e6fd5aae 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala
@@ -19,6 +19,7 @@ package org.apache.spark.ml.util
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.{Estimator, Model}
+import org.apache.spark.ml.evaluation.Evaluator
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.tree.impl.TreeTests
import org.apache.spark.mllib.linalg.Vectors
@@ -47,12 +48,30 @@ object MLTestingUtils extends SparkFunSuite {
val actuals = dfs.keys.filter(_ != DoubleType).map(t => estimator.fit(dfs(t)))
actuals.foreach(actual => check(expected, actual))
- val dfWithStringLabels = generateDFWithStringLabelCol(sqlContext)
+ val dfWithStringLabels = sqlContext.createDataFrame(Seq(
+ ("0", Vectors.dense(0, 2, 3), 0.0)
+ )).toDF("label", "features", "censor")
val thrown = intercept[IllegalArgumentException] {
estimator.fit(dfWithStringLabels)
}
- assert(thrown.getMessage contains
- "Column label must be of type NumericType but was actually of type StringType")
+ assert(thrown.getMessage.contains(
+ "Column label must be of type NumericType but was actually of type StringType"))
+ }
+
+ def checkNumericTypes[T <: Evaluator](evaluator: T, sqlContext: SQLContext): Unit = {
+ val dfs = genEvaluatorDFWithNumericLabelCol(sqlContext, "label", "prediction")
+ val expected = evaluator.evaluate(dfs(DoubleType))
+ val actuals = dfs.keys.filter(_ != DoubleType).map(t => evaluator.evaluate(dfs(t)))
+ actuals.foreach(actual => assert(expected === actual))
+
+ val dfWithStringLabels = sqlContext.createDataFrame(Seq(
+ ("0", 0d)
+ )).toDF("label", "prediction")
+ val thrown = intercept[IllegalArgumentException] {
+ evaluator.evaluate(dfWithStringLabels)
+ }
+ assert(thrown.getMessage.contains(
+ "Column label must be of type NumericType but was actually of type StringType"))
}
def genClassifDFWithNumericLabelCol(
@@ -69,9 +88,10 @@ object MLTestingUtils extends SparkFunSuite {
val types =
Seq(ShortType, LongType, IntegerType, FloatType, ByteType, DoubleType, DecimalType(10, 0))
- types.map(t => t -> df.select(col(labelColName).cast(t), col(featuresColName)))
- .map { case (t, d) => t -> TreeTests.setMetadata(d, 2, labelColName) }
- .toMap
+ types.map { t =>
+ val castDF = df.select(col(labelColName).cast(t), col(featuresColName))
+ t -> TreeTests.setMetadata(castDF, 2, labelColName, featuresColName)
+ }.toMap
}
def genRegressionDFWithNumericLabelCol(
@@ -89,24 +109,29 @@ object MLTestingUtils extends SparkFunSuite {
val types =
Seq(ShortType, LongType, IntegerType, FloatType, ByteType, DoubleType, DecimalType(10, 0))
- types
- .map(t => t -> df.select(col(labelColName).cast(t), col(featuresColName)))
- .map { case (t, d) =>
- t -> TreeTests.setMetadata(d, 0, labelColName).withColumn(censorColName, lit(0.0))
- }
- .toMap
+ types.map { t =>
+ val castDF = df.select(col(labelColName).cast(t), col(featuresColName))
+ t -> TreeTests.setMetadata(castDF, 0, labelColName, featuresColName)
+ .withColumn(censorColName, lit(0.0))
+ }.toMap
}
- def generateDFWithStringLabelCol(
+ def genEvaluatorDFWithNumericLabelCol(
sqlContext: SQLContext,
labelColName: String = "label",
- featuresColName: String = "features",
- censorColName: String = "censor"): DataFrame =
- sqlContext.createDataFrame(Seq(
- ("0", Vectors.dense(0, 2, 3), 0.0),
- ("1", Vectors.dense(0, 3, 1), 1.0),
- ("0", Vectors.dense(0, 2, 2), 0.0),
- ("1", Vectors.dense(0, 3, 9), 1.0),
- ("0", Vectors.dense(0, 2, 6), 0.0)
- )).toDF(labelColName, featuresColName, censorColName)
+ predictionColName: String = "prediction"): Map[NumericType, DataFrame] = {
+ val df = sqlContext.createDataFrame(Seq(
+ (0, 0d),
+ (1, 1d),
+ (2, 2d),
+ (3, 3d),
+ (4, 4d)
+ )).toDF(labelColName, predictionColName)
+
+ val types =
+ Seq(ShortType, LongType, IntegerType, FloatType, ByteType, DoubleType, DecimalType(10, 0))
+ types
+ .map(t => t -> df.select(col(labelColName).cast(t), col(predictionColName)))
+ .toMap
+ }
}