From 31f1aebbeb77b4eb1080f22c9bece7fafd8022f8 Mon Sep 17 00:00:00 2001 From: BenFradet Date: Fri, 13 May 2016 09:08:04 +0200 Subject: [SPARK-13961][ML] spark.ml ChiSqSelector and RFormula should support other numeric types for label ## What changes were proposed in this pull request? Made ChiSqSelector and RFormula accept all numeric types for label ## How was this patch tested? Unit tests Author: BenFradet Closes #12467 from BenFradet/SPARK-13961. --- .../apache/spark/ml/feature/ChiSqSelector.scala | 4 +-- .../org/apache/spark/ml/feature/RFormula.scala | 4 +-- .../DecisionTreeClassifierSuite.scala | 2 +- .../ml/classification/GBTClassifierSuite.scala | 2 +- .../classification/LogisticRegressionSuite.scala | 2 +- .../MultilayerPerceptronClassifierSuite.scala | 2 +- .../spark/ml/classification/NaiveBayesSuite.scala | 2 +- .../spark/ml/classification/OneVsRestSuite.scala | 2 +- .../RandomForestClassifierSuite.scala | 2 +- .../spark/ml/feature/ChiSqSelectorSuite.scala | 10 +++++++- .../apache/spark/ml/feature/RFormulaSuite.scala | 30 +++++++++++++++++----- .../ml/regression/AFTSurvivalRegressionSuite.scala | 2 +- .../ml/regression/DecisionTreeRegressorSuite.scala | 2 +- .../spark/ml/regression/GBTRegressorSuite.scala | 2 +- .../GeneralizedLinearRegressionSuite.scala | 2 +- .../ml/regression/IsotonicRegressionSuite.scala | 2 +- .../ml/regression/LinearRegressionSuite.scala | 2 +- .../ml/regression/RandomForestRegressorSuite.scala | 2 +- .../org/apache/spark/ml/util/MLTestingUtils.scala | 4 +-- 19 files changed, 53 insertions(+), 27 deletions(-) (limited to 'mllib') diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala index cfecae7e0b..29f55a7f71 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala @@ -80,7 +80,7 @@ final class ChiSqSelector(override val uid: String) @Since("2.0.0") override def fit(dataset: Dataset[_]): ChiSqSelectorModel = { transformSchema(dataset.schema, logging = true) - val input = dataset.select($(labelCol), $(featuresCol)).rdd.map { + val input = dataset.select(col($(labelCol)).cast(DoubleType), col($(featuresCol))).rdd.map { case Row(label: Double, features: Vector) => LabeledPoint(label, features) } @@ -90,7 +90,7 @@ final class ChiSqSelector(override val uid: String) override def transformSchema(schema: StructType): StructType = { SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT) - SchemaUtils.checkColumnType(schema, $(labelCol), DoubleType) + SchemaUtils.checkNumericType(schema, $(labelCol)) SchemaUtils.appendColumn(schema, $(outputCol), new VectorUDT) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala index 5219680be2..a2f3d44132 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala @@ -256,8 +256,8 @@ class RFormulaModel private[feature]( val columnNames = schema.map(_.name) require(!columnNames.contains($(featuresCol)), "Features column already exists.") require( - !columnNames.contains($(labelCol)) || schema($(labelCol)).dataType == DoubleType, - "Label column already exists and is not of type DoubleType.") + !columnNames.contains($(labelCol)) || schema($(labelCol)).dataType.isInstanceOf[NumericType], + "Label column already exists and is not of type NumericType.") } @Since("2.0.0") diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala index f94d336df5..91a947f44b 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala @@ -337,7 +337,7 @@ class DecisionTreeClassifierSuite test("should support all NumericType labels and not support other types") { val dt = new DecisionTreeClassifier().setMaxDepth(1) MLTestingUtils.checkNumericTypes[DecisionTreeClassificationModel, DecisionTreeClassifier]( - dt, isClassification = true, spark) { (expected, actual) => + dt, spark) { (expected, actual) => TreeTests.checkEqual(expected, actual) } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala index c9453aaec2..5a5e5c15fc 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala @@ -106,7 +106,7 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext test("should support all NumericType labels and not support other types") { val gbt = new GBTClassifier().setMaxDepth(1) MLTestingUtils.checkNumericTypes[GBTClassificationModel, GBTClassifier]( - gbt, isClassification = true, spark) { (expected, actual) => + gbt, spark) { (expected, actual) => TreeTests.checkEqual(expected, actual) } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala index cb4d087ce5..f127aa217c 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala @@ -938,7 +938,7 @@ class LogisticRegressionSuite test("should support all NumericType labels and not support other types") { val lr = new LogisticRegression().setMaxIter(1) MLTestingUtils.checkNumericTypes[LogisticRegressionModel, LogisticRegression]( - lr, isClassification = true, spark) { (expected, actual) => + lr, spark) { (expected, actual) => assert(expected.intercept === actual.intercept) assert(expected.coefficients.toArray === actual.coefficients.toArray) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala index 876e047db5..d5282e07d6 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala @@ -169,7 +169,7 @@ class MultilayerPerceptronClassifierSuite val mpc = new MultilayerPerceptronClassifier().setLayers(layers).setMaxIter(1) MLTestingUtils.checkNumericTypes[ MultilayerPerceptronClassificationModel, MultilayerPerceptronClassifier]( - mpc, isClassification = true, spark) { (expected, actual) => + mpc, spark) { (expected, actual) => assert(expected.layers === actual.layers) assert(expected.weights === actual.weights) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala index 15d0358c3f..2a05c446e5 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala @@ -188,7 +188,7 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext with Defa test("should support all NumericType labels and not support other types") { val nb = new NaiveBayes() MLTestingUtils.checkNumericTypes[NaiveBayesModel, NaiveBayes]( - nb, isClassification = true, spark) { (expected, actual) => + nb, spark) { (expected, actual) => assert(expected.pi === actual.pi) assert(expected.theta === actual.theta) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala index 005d609307..5044d40998 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala @@ -228,7 +228,7 @@ class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext with Defau test("should support all NumericType labels and not support other types") { val ovr = new OneVsRest().setClassifier(new LogisticRegression().setMaxIter(1)) MLTestingUtils.checkNumericTypes[OneVsRestModel, OneVsRest]( - ovr, isClassification = true, spark) { (expected, actual) => + ovr, spark) { (expected, actual) => val expectedModels = expected.models.map(m => m.asInstanceOf[LogisticRegressionModel]) val actualModels = actual.models.map(m => m.asInstanceOf[LogisticRegressionModel]) assert(expectedModels.length === actualModels.length) diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala index 97f3feacca..8002a2f4f2 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala @@ -189,7 +189,7 @@ class RandomForestClassifierSuite test("should support all NumericType labels and not support other types") { val rf = new RandomForestClassifier().setMaxDepth(1) MLTestingUtils.checkNumericTypes[RandomForestClassificationModel, RandomForestClassifier]( - rf, isClassification = true, spark) { (expected, actual) => + rf, spark) { (expected, actual) => TreeTests.checkEqual(expected, actual) } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala index 4c6d9c5e26..4fcc9745b7 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.ml.feature import org.apache.spark.SparkFunSuite -import org.apache.spark.ml.util.DefaultReadWriteTest +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.mllib.feature import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.regression.LabeledPoint @@ -81,4 +81,12 @@ class ChiSqSelectorSuite extends SparkFunSuite with MLlibTestSparkContext val newInstance = testDefaultReadWrite(instance) assert(newInstance.selectedFeatures === instance.selectedFeatures) } + + test("should support all NumericType labels and not support other types") { + val css = new ChiSqSelector() + MLTestingUtils.checkNumericTypes[ChiSqSelectorModel, ChiSqSelector]( + css, spark) { (expected, actual) => + assert(expected.selectedFeatures === actual.selectedFeatures) + } + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala index 46e7495297..c623a6210b 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala @@ -20,10 +20,10 @@ package org.apache.spark.ml.feature import org.apache.spark.SparkFunSuite import org.apache.spark.ml.attribute._ import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.util.DefaultReadWriteTest +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.sql.Row +import org.apache.spark.sql.types.DoubleType class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { test("params") { @@ -68,9 +68,9 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul assert(resultSchema.toString == model.transform(original).schema.toString) } - test("label column already exists but is not double type") { + test("label column already exists but is not numeric type") { val formula = new RFormula().setFormula("y ~ x").setLabelCol("y") - val original = spark.createDataFrame(Seq((0, 1), (2, 2))).toDF("x", "y") + val original = spark.createDataFrame(Seq((0, true), (2, false))).toDF("x", "y") val model = formula.fit(original) intercept[IllegalArgumentException] { model.transformSchema(original.schema) @@ -134,7 +134,6 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul ).toDF("id", "a", "b") val model = formula.fit(original) val result = model.transform(original) - val resultSchema = model.transformSchema(original.schema) val expected = spark.createDataFrame( Seq( ("male", "foo", 4, Vectors.dense(0.0, 1.0, 4.0), 1.0), @@ -188,7 +187,7 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul "vec2", Array[Attribute]( NumericAttribute.defaultAttr, - NumericAttribute.defaultAttr)).toMetadata + NumericAttribute.defaultAttr)).toMetadata() val original = base.select(base.col("id"), base.col("vec").as("vec2", metadata)) val model = formula.fit(original) val result = model.transform(original) @@ -309,4 +308,23 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul val newModel = testDefaultReadWrite(model) checkModelData(model, newModel) } + + test("should support all NumericType labels") { + val formula = new RFormula().setFormula("label ~ features") + .setLabelCol("x") + .setFeaturesCol("y") + val dfs = MLTestingUtils.genRegressionDFWithNumericLabelCol(spark) + val expected = formula.fit(dfs(DoubleType)) + val actuals = dfs.keys.filter(_ != DoubleType).map(t => formula.fit(dfs(t))) + actuals.foreach { actual => + assert(expected.pipelineModel.stages.length === actual.pipelineModel.stages.length) + expected.pipelineModel.stages.zip(actual.pipelineModel.stages).foreach { + case (exTransformer, acTransformer) => + assert(exTransformer.params === acTransformer.params) + } + assert(expected.resolvedFormula.label === actual.resolvedFormula.label) + assert(expected.resolvedFormula.terms === actual.resolvedFormula.terms) + assert(expected.resolvedFormula.hasIntercept === actual.resolvedFormula.hasIntercept) + } + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala index f8fc775676..e4772df622 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala @@ -356,7 +356,7 @@ class AFTSurvivalRegressionSuite test("should support all NumericType labels") { val aft = new AFTSurvivalRegression().setMaxIter(1) MLTestingUtils.checkNumericTypes[AFTSurvivalRegressionModel, AFTSurvivalRegression]( - aft, isClassification = false, spark) { (expected, actual) => + aft, spark, isClassification = false) { (expected, actual) => assert(expected.intercept === actual.intercept) assert(expected.coefficients === actual.coefficients) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala index d9f26ad8dc..2d30cbf367 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala @@ -120,7 +120,7 @@ class DecisionTreeRegressorSuite test("should support all NumericType labels and not support other types") { val dt = new DecisionTreeRegressor().setMaxDepth(1) MLTestingUtils.checkNumericTypes[DecisionTreeRegressionModel, DecisionTreeRegressor]( - dt, isClassification = false, spark) { (expected, actual) => + dt, spark, isClassification = false) { (expected, actual) => TreeTests.checkEqual(expected, actual) } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala index f6ea5bb741..ac833b833d 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala @@ -115,7 +115,7 @@ class GBTRegressorSuite extends SparkFunSuite with MLlibTestSparkContext test("should support all NumericType labels and not support other types") { val gbt = new GBTRegressor().setMaxDepth(1) MLTestingUtils.checkNumericTypes[GBTRegressionModel, GBTRegressor]( - gbt, isClassification = false, spark) { (expected, actual) => + gbt, spark, isClassification = false) { (expected, actual) => TreeTests.checkEqual(expected, actual) } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala index 161f8c80f8..3d9aeb8c0a 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala @@ -1021,7 +1021,7 @@ class GeneralizedLinearRegressionSuite val glr = new GeneralizedLinearRegression().setMaxIter(1) MLTestingUtils.checkNumericTypes[ GeneralizedLinearRegressionModel, GeneralizedLinearRegression]( - glr, isClassification = false, spark) { (expected, actual) => + glr, spark, isClassification = false) { (expected, actual) => assert(expected.intercept === actual.intercept) assert(expected.coefficients === actual.coefficients) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala index 9bf7542b12..bed4978b25 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala @@ -184,7 +184,7 @@ class IsotonicRegressionSuite test("should support all NumericType labels and not support other types") { val ir = new IsotonicRegression() MLTestingUtils.checkNumericTypes[IsotonicRegressionModel, IsotonicRegression]( - ir, isClassification = false, spark) { (expected, actual) => + ir, spark, isClassification = false) { (expected, actual) => assert(expected.boundaries === actual.boundaries) assert(expected.predictions === actual.predictions) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala index 10f547b673..a98227d2c1 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala @@ -1010,7 +1010,7 @@ class LinearRegressionSuite test("should support all NumericType labels and not support other types") { val lr = new LinearRegression().setMaxIter(1) MLTestingUtils.checkNumericTypes[LinearRegressionModel, LinearRegression]( - lr, isClassification = false, spark) { (expected, actual) => + lr, spark, isClassification = false) { (expected, actual) => assert(expected.intercept === actual.intercept) assert(expected.coefficients === actual.coefficients) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala index 72f3c65eb8..7a3a3698f9 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala @@ -98,7 +98,7 @@ class RandomForestRegressorSuite extends SparkFunSuite with MLlibTestSparkContex test("should support all NumericType labels and not support other types") { val rf = new RandomForestRegressor().setMaxDepth(1) MLTestingUtils.checkNumericTypes[RandomForestRegressionModel, RandomForestRegressor]( - rf, isClassification = false, spark) { (expected, actual) => + rf, spark, isClassification = false) { (expected, actual) => TreeTests.checkEqual(expected, actual) } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala b/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala index 4fe473bbac..ad7d2c9b8d 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala @@ -37,8 +37,8 @@ object MLTestingUtils extends SparkFunSuite { def checkNumericTypes[M <: Model[M], T <: Estimator[M]]( estimator: T, - isClassification: Boolean, - spark: SparkSession)(check: (M, M) => Unit): Unit = { + spark: SparkSession, + isClassification: Boolean = true)(check: (M, M) => Unit): Unit = { val dfs = if (isClassification) { genClassifDFWithNumericLabelCol(spark) } else { -- cgit v1.2.3