diff options
Diffstat (limited to 'mllib')
8 files changed, 88 insertions, 51 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala index bde8c275fd..0cbc391d96 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala @@ -24,6 +24,7 @@ import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, I import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics import org.apache.spark.mllib.linalg.{Vector, VectorUDT} import org.apache.spark.sql.{Dataset, Row} +import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.DoubleType /** @@ -73,13 +74,14 @@ class BinaryClassificationEvaluator @Since("1.4.0") (@Since("1.4.0") override va override def evaluate(dataset: Dataset[_]): Double = { val schema = dataset.schema SchemaUtils.checkColumnTypes(schema, $(rawPredictionCol), Seq(DoubleType, new VectorUDT)) - SchemaUtils.checkColumnType(schema, $(labelCol), DoubleType) + SchemaUtils.checkNumericType(schema, $(labelCol)) // TODO: When dataset metadata has been implemented, check rawPredictionCol vector length = 2. - val scoreAndLabels = dataset.select($(rawPredictionCol), $(labelCol)).rdd.map { - case Row(rawPrediction: Vector, label: Double) => (rawPrediction(1), label) - case Row(rawPrediction: Double, label: Double) => (rawPrediction, label) - } + val scoreAndLabels = + dataset.select(col($(rawPredictionCol)), col($(labelCol)).cast(DoubleType)).rdd.map { + case Row(rawPrediction: Vector, label: Double) => (rawPrediction(1), label) + case Row(rawPrediction: Double, label: Double) => (rawPrediction, label) + } val metrics = new BinaryClassificationMetrics(scoreAndLabels) val metric = $(metricName) match { case "areaUnderROC" => metrics.areaUnderROC() diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala index 3acfc221c9..3d89843a0b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala @@ -23,6 +23,7 @@ import org.apache.spark.ml.param.shared.{HasLabelCol, HasPredictionCol} import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable, SchemaUtils} import org.apache.spark.mllib.evaluation.MulticlassMetrics import org.apache.spark.sql.{Dataset, Row} +import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.DoubleType /** @@ -72,12 +73,12 @@ class MulticlassClassificationEvaluator @Since("1.5.0") (@Since("1.5.0") overrid override def evaluate(dataset: Dataset[_]): Double = { val schema = dataset.schema SchemaUtils.checkColumnType(schema, $(predictionCol), DoubleType) - SchemaUtils.checkColumnType(schema, $(labelCol), DoubleType) + SchemaUtils.checkNumericType(schema, $(labelCol)) - val predictionAndLabels = dataset.select($(predictionCol), $(labelCol)).rdd.map { - case Row(prediction: Double, label: Double) => - (prediction, label) - } + val predictionAndLabels = + dataset.select(col($(predictionCol)), col($(labelCol)).cast(DoubleType)).rdd.map { + case Row(prediction: Double, label: Double) => (prediction, label) + } val metrics = new MulticlassMetrics(predictionAndLabels) val metric = $(metricName) match { case "f1" => metrics.weightedFMeasure diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala index 988f6e918f..031cd0d635 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala @@ -20,7 +20,7 @@ package org.apache.spark.ml.evaluation import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.param.{Param, ParamMap, ParamValidators} import org.apache.spark.ml.param.shared.{HasLabelCol, HasPredictionCol} -import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable} +import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable, SchemaUtils} import org.apache.spark.mllib.evaluation.RegressionMetrics import org.apache.spark.sql.{Dataset, Row} import org.apache.spark.sql.functions._ @@ -74,22 +74,13 @@ final class RegressionEvaluator @Since("1.4.0") (@Since("1.4.0") override val ui @Since("2.0.0") override def evaluate(dataset: Dataset[_]): Double = { val schema = dataset.schema - val predictionColName = $(predictionCol) - val predictionType = schema($(predictionCol)).dataType - require(predictionType == FloatType || predictionType == DoubleType, - s"Prediction column $predictionColName must be of type float or double, " + - s" but not $predictionType") - val labelColName = $(labelCol) - val labelType = schema($(labelCol)).dataType - require(labelType == FloatType || labelType == DoubleType, - s"Label column $labelColName must be of type float or double, but not $labelType") + SchemaUtils.checkColumnTypes(schema, $(predictionCol), Seq(DoubleType, FloatType)) + SchemaUtils.checkNumericType(schema, $(labelCol)) val predictionAndLabels = dataset .select(col($(predictionCol)).cast(DoubleType), col($(labelCol)).cast(DoubleType)) - .rdd. - map { case Row(prediction: Double, label: Double) => - (prediction, label) - } + .rdd + .map { case Row(prediction: Double, label: Double) => (prediction, label) } val metrics = new RegressionMetrics(predictionAndLabels) val metric = $(metricName) match { case "rmse" => metrics.rootMeanSquaredError diff --git a/mllib/src/test/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluatorSuite.scala index 27349950dc..ff34522178 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluatorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluatorSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.ml.evaluation import org.apache.spark.SparkFunSuite import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.util.DefaultReadWriteTest +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.MLlibTestSparkContext @@ -68,4 +68,9 @@ class BinaryClassificationEvaluatorSuite "equal to one of the following types: [DoubleType, ") assert(thrown.getMessage.replace("\n", "") contains "but was actually of type StringType.") } + + test("should support all NumericType labels and not support other types") { + val evaluator = new BinaryClassificationEvaluator().setRawPredictionCol("prediction") + MLTestingUtils.checkNumericTypes(evaluator, sqlContext) + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluatorSuite.scala index 7ee65975d2..87e511a368 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluatorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluatorSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.ml.evaluation import org.apache.spark.SparkFunSuite import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.util.DefaultReadWriteTest +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.mllib.util.MLlibTestSparkContext class MulticlassClassificationEvaluatorSuite @@ -36,4 +36,8 @@ class MulticlassClassificationEvaluatorSuite .setMetricName("recall") testDefaultReadWrite(evaluator) } + + test("should support all NumericType labels and not support other types") { + MLTestingUtils.checkNumericTypes(new MulticlassClassificationEvaluator, sqlContext) + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala index 954d3bedc1..c7b9483069 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.ml.evaluation import org.apache.spark.SparkFunSuite import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.regression.LinearRegression -import org.apache.spark.ml.util.DefaultReadWriteTest +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext} import org.apache.spark.mllib.util.TestingUtils._ @@ -83,4 +83,8 @@ class RegressionEvaluatorSuite .setMetricName("r2") testDefaultReadWrite(evaluator) } + + test("should support all NumericType labels and not support other types") { + MLTestingUtils.checkNumericTypes(new RegressionEvaluator, sqlContext) + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala index b650a9f092..e3f09899d7 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala @@ -79,16 +79,21 @@ private[ml] object TreeTests extends SparkFunSuite { * This must be non-empty. * @param numClasses Number of classes label can take. If 0, mark as continuous. * @param labelColName Name of the label column on which to set the metadata. + * @param featuresColName Name of the features column * @return DataFrame with metadata */ - def setMetadata(data: DataFrame, numClasses: Int, labelColName: String): DataFrame = { + def setMetadata( + data: DataFrame, + numClasses: Int, + labelColName: String, + featuresColName: String): DataFrame = { val labelAttribute = if (numClasses == 0) { NumericAttribute.defaultAttr.withName(labelColName) } else { NominalAttribute.defaultAttr.withName(labelColName).withNumValues(numClasses) } val labelMetadata = labelAttribute.toMetadata() - data.select(data("features"), data(labelColName).as(labelColName, labelMetadata)) + data.select(data(featuresColName), data(labelColName).as(labelColName, labelMetadata)) } /** diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala b/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala index 8108460518..d9e6fd5aae 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala @@ -19,6 +19,7 @@ package org.apache.spark.ml.util import org.apache.spark.SparkFunSuite import org.apache.spark.ml.{Estimator, Model} +import org.apache.spark.ml.evaluation.Evaluator import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.tree.impl.TreeTests import org.apache.spark.mllib.linalg.Vectors @@ -47,12 +48,30 @@ object MLTestingUtils extends SparkFunSuite { val actuals = dfs.keys.filter(_ != DoubleType).map(t => estimator.fit(dfs(t))) actuals.foreach(actual => check(expected, actual)) - val dfWithStringLabels = generateDFWithStringLabelCol(sqlContext) + val dfWithStringLabels = sqlContext.createDataFrame(Seq( + ("0", Vectors.dense(0, 2, 3), 0.0) + )).toDF("label", "features", "censor") val thrown = intercept[IllegalArgumentException] { estimator.fit(dfWithStringLabels) } - assert(thrown.getMessage contains - "Column label must be of type NumericType but was actually of type StringType") + assert(thrown.getMessage.contains( + "Column label must be of type NumericType but was actually of type StringType")) + } + + def checkNumericTypes[T <: Evaluator](evaluator: T, sqlContext: SQLContext): Unit = { + val dfs = genEvaluatorDFWithNumericLabelCol(sqlContext, "label", "prediction") + val expected = evaluator.evaluate(dfs(DoubleType)) + val actuals = dfs.keys.filter(_ != DoubleType).map(t => evaluator.evaluate(dfs(t))) + actuals.foreach(actual => assert(expected === actual)) + + val dfWithStringLabels = sqlContext.createDataFrame(Seq( + ("0", 0d) + )).toDF("label", "prediction") + val thrown = intercept[IllegalArgumentException] { + evaluator.evaluate(dfWithStringLabels) + } + assert(thrown.getMessage.contains( + "Column label must be of type NumericType but was actually of type StringType")) } def genClassifDFWithNumericLabelCol( @@ -69,9 +88,10 @@ object MLTestingUtils extends SparkFunSuite { val types = Seq(ShortType, LongType, IntegerType, FloatType, ByteType, DoubleType, DecimalType(10, 0)) - types.map(t => t -> df.select(col(labelColName).cast(t), col(featuresColName))) - .map { case (t, d) => t -> TreeTests.setMetadata(d, 2, labelColName) } - .toMap + types.map { t => + val castDF = df.select(col(labelColName).cast(t), col(featuresColName)) + t -> TreeTests.setMetadata(castDF, 2, labelColName, featuresColName) + }.toMap } def genRegressionDFWithNumericLabelCol( @@ -89,24 +109,29 @@ object MLTestingUtils extends SparkFunSuite { val types = Seq(ShortType, LongType, IntegerType, FloatType, ByteType, DoubleType, DecimalType(10, 0)) - types - .map(t => t -> df.select(col(labelColName).cast(t), col(featuresColName))) - .map { case (t, d) => - t -> TreeTests.setMetadata(d, 0, labelColName).withColumn(censorColName, lit(0.0)) - } - .toMap + types.map { t => + val castDF = df.select(col(labelColName).cast(t), col(featuresColName)) + t -> TreeTests.setMetadata(castDF, 0, labelColName, featuresColName) + .withColumn(censorColName, lit(0.0)) + }.toMap } - def generateDFWithStringLabelCol( + def genEvaluatorDFWithNumericLabelCol( sqlContext: SQLContext, labelColName: String = "label", - featuresColName: String = "features", - censorColName: String = "censor"): DataFrame = - sqlContext.createDataFrame(Seq( - ("0", Vectors.dense(0, 2, 3), 0.0), - ("1", Vectors.dense(0, 3, 1), 1.0), - ("0", Vectors.dense(0, 2, 2), 0.0), - ("1", Vectors.dense(0, 3, 9), 1.0), - ("0", Vectors.dense(0, 2, 6), 0.0) - )).toDF(labelColName, featuresColName, censorColName) + predictionColName: String = "prediction"): Map[NumericType, DataFrame] = { + val df = sqlContext.createDataFrame(Seq( + (0, 0d), + (1, 1d), + (2, 2d), + (3, 3d), + (4, 4d) + )).toDF(labelColName, predictionColName) + + val types = + Seq(ShortType, LongType, IntegerType, FloatType, ByteType, DoubleType, DecimalType(10, 0)) + types + .map(t => t -> df.select(col(labelColName).cast(t), col(predictionColName))) + .toMap + } } |