From 36e8fb8005eccea67a9dea8cf68ec3105aa43351 Mon Sep 17 00:00:00 2001 From: BenFradet Date: Fri, 1 Apr 2016 18:25:43 -0700 Subject: [SPARK-7425][ML] spark.ml Predictor should support other numeric types for label Currently, the Predictor abstraction expects the input labelCol type to be DoubleType, but we should support other numeric types. This will involve updating the PredictorParams.validateAndTransformSchema method. Author: BenFradet Closes #10355 from BenFradet/SPARK-7425. --- .../DecisionTreeClassifierSuite.scala | 15 +++- .../ml/classification/GBTClassifierSuite.scala | 9 ++- .../classification/LogisticRegressionSuite.scala | 11 ++- .../MultilayerPerceptronClassifierSuite.scala | 12 +++ .../spark/ml/classification/NaiveBayesSuite.scala | 14 +++- .../spark/ml/classification/OneVsRestSuite.scala | 16 +++- .../RandomForestClassifierSuite.scala | 8 ++ .../ml/regression/AFTSurvivalRegressionSuite.scala | 9 +++ .../ml/regression/DecisionTreeRegressorSuite.scala | 8 ++ .../spark/ml/regression/GBTRegressorSuite.scala | 8 +- .../GeneralizedLinearRegressionSuite.scala | 12 ++- .../ml/regression/IsotonicRegressionSuite.scala | 9 +++ .../ml/regression/LinearRegressionSuite.scala | 17 ++++- .../ml/regression/RandomForestRegressorSuite.scala | 8 ++ .../org/apache/spark/ml/tree/impl/TreeTests.scala | 18 +++++ .../org/apache/spark/ml/util/MLTestingUtils.scala | 86 +++++++++++++++++++++- 16 files changed, 242 insertions(+), 18 deletions(-) (limited to 'mllib/src/test') diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala index 2b07524815..fe839e15e9 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala @@ -27,8 +27,7 @@ import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.{DecisionTree => OldDecisionTree, DecisionTreeSuite => OldDecisionTreeSuite} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.rdd.RDD -import org.apache.spark.sql.DataFrame -import org.apache.spark.sql.Row +import org.apache.spark.sql.{DataFrame, Row} class DecisionTreeClassifierSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { @@ -176,7 +175,7 @@ class DecisionTreeClassifierSuite } test("Multiclass classification tree with 10-ary (ordered) categorical features," + - " with just enough bins") { + " with just enough bins") { val rdd = categoricalDataPointsForMulticlassForOrderedFeaturesRDD val dt = new DecisionTreeClassifier() .setImpurity("Gini") @@ -273,7 +272,7 @@ class DecisionTreeClassifierSuite )) val df = TreeTests.setMetadata(data, Map(0 -> 1), 2) val dt = new DecisionTreeClassifier().setMaxDepth(3) - val model = dt.fit(df) + dt.fit(df) } test("Use soft prediction for binary classification with ordered categorical features") { @@ -335,6 +334,14 @@ class DecisionTreeClassifierSuite assert(importances.toArray.forall(_ >= 0.0)) } + test("should support all NumericType labels and not support other types") { + val dt = new DecisionTreeClassifier().setMaxDepth(1) + MLTestingUtils.checkNumericTypes[DecisionTreeClassificationModel, DecisionTreeClassifier]( + dt, isClassification = true, sqlContext) { (expected, actual) => + TreeTests.checkEqual(expected, actual) + } + } + ///////////////////////////////////////////////////////////////////////////// // Tests of model save/load ///////////////////////////////////////////////////////////////////////////// diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala index bf7481e8a3..76d8c9372e 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala @@ -31,7 +31,6 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.DataFrame import org.apache.spark.util.Utils - /** * Test suite for [[GBTClassifier]]. */ @@ -102,6 +101,14 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext { Utils.deleteRecursively(tempDir) } + test("should support all NumericType labels and not support other types") { + val gbt = new GBTClassifier().setMaxDepth(1) + MLTestingUtils.checkNumericTypes[GBTClassificationModel, GBTClassifier]( + gbt, isClassification = true, sqlContext) { (expected, actual) => + TreeTests.checkEqual(expected, actual) + } + } + // TODO: Reinstate test once runWithValidation is implemented SPARK-7132 /* test("runWithValidation stops early and performs better on a validation dataset") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala index afeeaf7fb5..7eefaf2346 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala @@ -103,7 +103,7 @@ class LogisticRegressionSuite assert(model.hasSummary) // Validate that we re-insert a probability column for evaluation val fieldNames = model.summary.predictions.schema.fieldNames - assert((dataset.schema.fieldNames.toSet).subsetOf( + assert(dataset.schema.fieldNames.toSet.subsetOf( fieldNames.toSet)) assert(fieldNames.exists(s => s.startsWith("probability_"))) } @@ -934,6 +934,15 @@ class LogisticRegressionSuite testEstimatorAndModelReadWrite(lr, dataset, LogisticRegressionSuite.allParamSettings, checkModelData) } + + test("should support all NumericType labels and not support other types") { + val lr = new LogisticRegression().setMaxIter(1) + MLTestingUtils.checkNumericTypes[LogisticRegressionModel, LogisticRegression]( + lr, isClassification = true, sqlContext) { (expected, actual) => + assert(expected.intercept === actual.intercept) + assert(expected.coefficients.toArray === actual.coefficients.toArray) + } + } } object LogisticRegressionSuite { diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala index 43781385db..06ff049b48 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.ml.classification import org.apache.spark.SparkFunSuite import org.apache.spark.ml.util.DefaultReadWriteTest +import org.apache.spark.ml.util.MLTestingUtils import org.apache.spark.mllib.classification.LogisticRegressionSuite._ import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS import org.apache.spark.mllib.evaluation.MulticlassMetrics @@ -162,4 +163,15 @@ class MultilayerPerceptronClassifierSuite assert(newMlpModel.layers === mlpModel.layers) assert(newMlpModel.weights === mlpModel.weights) } + + test("should support all NumericType labels and not support other types") { + val layers = Array(3, 2) + val mpc = new MultilayerPerceptronClassifier().setLayers(layers).setMaxIter(1) + MLTestingUtils.checkNumericTypes[ + MultilayerPerceptronClassificationModel, MultilayerPerceptronClassifier]( + mpc, isClassification = true, sqlContext) { (expected, actual) => + assert(expected.layers === actual.layers) + assert(expected.weights === actual.weights) + } + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala index 082a6bcd21..4727cd436f 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala @@ -21,7 +21,7 @@ import breeze.linalg.{Vector => BV} import org.apache.spark.SparkFunSuite import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.util.DefaultReadWriteTest +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.mllib.classification.NaiveBayes.{Bernoulli, Multinomial} import org.apache.spark.mllib.classification.NaiveBayesSuite._ import org.apache.spark.mllib.linalg._ @@ -86,7 +86,7 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext with Defa model: NaiveBayesModel, modelType: String): Unit = { featureAndProbabilities.collect().foreach { - case Row(features: Vector, probability: Vector) => { + case Row(features: Vector, probability: Vector) => assert(probability.toArray.sum ~== 1.0 relTol 1.0e-10) val expected = modelType match { case Multinomial => @@ -97,7 +97,6 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext with Defa throw new UnknownError(s"Invalid modelType: $modelType.") } assert(probability ~== expected relTol 1.0e-10) - } } } @@ -185,6 +184,15 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext with Defa val nb = new NaiveBayes() testEstimatorAndModelReadWrite(nb, dataset, NaiveBayesSuite.allParamSettings, checkModelData) } + + test("should support all NumericType labels and not support other types") { + val nb = new NaiveBayes() + MLTestingUtils.checkNumericTypes[NaiveBayesModel, NaiveBayes]( + nb, isClassification = true, sqlContext) { (expected, actual) => + assert(expected.pi === actual.pi) + assert(expected.theta === actual.theta) + } + } } object NaiveBayesSuite { diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala index 51c1baf682..4131396726 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala @@ -74,7 +74,7 @@ class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext with Defau // copied model must have the same parent. MLTestingUtils.checkCopy(ovaModel) - assert(ovaModel.models.size === numClasses) + assert(ovaModel.models.length === numClasses) val transformedDataset = ovaModel.transform(dataset) // check for label metadata in prediction col @@ -224,6 +224,20 @@ class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext with Defau val newOvaModel = testDefaultReadWrite(ovaModel, testParams = false) checkModelData(ovaModel, newOvaModel) } + + test("should support all NumericType labels and not support other types") { + val ovr = new OneVsRest().setClassifier(new LogisticRegression().setMaxIter(1)) + MLTestingUtils.checkNumericTypes[OneVsRestModel, OneVsRest]( + ovr, isClassification = true, sqlContext) { (expected, actual) => + val expectedModels = expected.models.map(m => m.asInstanceOf[LogisticRegressionModel]) + val actualModels = actual.models.map(m => m.asInstanceOf[LogisticRegressionModel]) + assert(expectedModels.length === actualModels.length) + expectedModels.zip(actualModels).foreach { case (e, a) => + assert(e.intercept === a.intercept) + assert(e.coefficients.toArray === a.coefficients.toArray) + } + } + } } private class MockLogisticRegression(uid: String) extends LogisticRegression(uid) { diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala index b896099e31..052bc83c38 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala @@ -178,6 +178,14 @@ class RandomForestClassifierSuite extends SparkFunSuite with MLlibTestSparkConte assert(importances.toArray.forall(_ >= 0.0)) } + test("should support all NumericType labels and not support other types") { + val rf = new RandomForestClassifier().setMaxDepth(1) + MLTestingUtils.checkNumericTypes[RandomForestClassificationModel, RandomForestClassifier]( + rf, isClassification = true, sqlContext) { (expected, actual) => + TreeTests.checkEqual(expected, actual) + } + } + ///////////////////////////////////////////////////////////////////////////// // Tests of model save/load ///////////////////////////////////////////////////////////////////////////// diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala index dbd752d2aa..f4844cc671 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala @@ -347,6 +347,15 @@ class AFTSurvivalRegressionSuite } } + test("should support all NumericType labels") { + val aft = new AFTSurvivalRegression().setMaxIter(1) + MLTestingUtils.checkNumericTypes[AFTSurvivalRegressionModel, AFTSurvivalRegression]( + aft, isClassification = false, sqlContext) { (expected, actual) => + assert(expected.intercept === actual.intercept) + assert(expected.coefficients === actual.coefficients) + } + } + test("read/write") { def checkModelData( model: AFTSurvivalRegressionModel, diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala index 662e3fc679..e9fb2677b2 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala @@ -117,6 +117,14 @@ class DecisionTreeRegressorSuite assert(importances.toArray.forall(_ >= 0.0)) } + test("should support all NumericType labels and not support other types") { + val dt = new DecisionTreeRegressor().setMaxDepth(1) + MLTestingUtils.checkNumericTypes[DecisionTreeRegressionModel, DecisionTreeRegressor]( + dt, isClassification = false, sqlContext) { (expected, actual) => + TreeTests.checkEqual(expected, actual) + } + } + ///////////////////////////////////////////////////////////////////////////// // Tests of model save/load ///////////////////////////////////////////////////////////////////////////// diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala index dfb8418086..914818f41f 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala @@ -29,7 +29,6 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.DataFrame import org.apache.spark.util.Utils - /** * Test suite for [[GBTRegressor]]. */ @@ -110,7 +109,14 @@ class GBTRegressorSuite extends SparkFunSuite with MLlibTestSparkContext { sc.checkpointDir = None Utils.deleteRecursively(tempDir) + } + test("should support all NumericType labels and not support other types") { + val gbt = new GBTRegressor().setMaxDepth(1) + MLTestingUtils.checkNumericTypes[GBTRegressionModel, GBTRegressor]( + gbt, isClassification = false, sqlContext) { (expected, actual) => + TreeTests.checkEqual(expected, actual) + } } // TODO: Reinstate test once runWithValidation is implemented SPARK-7132 diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala index 4ebdbf2213..2265464b51 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala @@ -982,6 +982,16 @@ class GeneralizedLinearRegressionSuite testEstimatorAndModelReadWrite(glr, datasetPoissonLog, GeneralizedLinearRegressionSuite.allParamSettings, checkModelData) } + + test("should support all NumericType labels and not support other types") { + val glr = new GeneralizedLinearRegression().setMaxIter(1) + MLTestingUtils.checkNumericTypes[ + GeneralizedLinearRegressionModel, GeneralizedLinearRegression]( + glr, isClassification = false, sqlContext) { (expected, actual) => + assert(expected.intercept === actual.intercept) + assert(expected.coefficients === actual.coefficients) + } + } } object GeneralizedLinearRegressionSuite { @@ -1023,7 +1033,7 @@ object GeneralizedLinearRegressionSuite { generator.setSeed(seed) (0 until nPoints).map { _ => - val features = Vectors.dense(coefficients.indices.map { rndElement(_) }.toArray) + val features = Vectors.dense(coefficients.indices.map(rndElement).toArray) val eta = BLAS.dot(Vectors.dense(coefficients), features) + intercept val mu = link match { case "identity" => eta diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala index b8874b4cd3..3a10ad7ed0 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala @@ -180,6 +180,15 @@ class IsotonicRegressionSuite testEstimatorAndModelReadWrite(ir, dataset, IsotonicRegressionSuite.allParamSettings, checkModelData) } + + test("should support all NumericType labels and not support other types") { + val ir = new IsotonicRegression() + MLTestingUtils.checkNumericTypes[IsotonicRegressionModel, IsotonicRegression]( + ir, isClassification = false, sqlContext) { (expected, actual) => + assert(expected.boundaries === actual.boundaries) + assert(expected.predictions === actual.predictions) + } + } } object IsotonicRegressionSuite { diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala index bd45d21e8d..cccb7f8d1b 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala @@ -61,9 +61,9 @@ class LinearRegressionSuite val featureSize = 4100 datasetWithSparseFeature = sqlContext.createDataFrame( sc.parallelize(LinearDataGenerator.generateLinearInput( - intercept = 0.0, weights = Seq.fill(featureSize)(r.nextDouble).toArray, - xMean = Seq.fill(featureSize)(r.nextDouble).toArray, - xVariance = Seq.fill(featureSize)(r.nextDouble).toArray, nPoints = 200, + intercept = 0.0, weights = Seq.fill(featureSize)(r.nextDouble()).toArray, + xMean = Seq.fill(featureSize)(r.nextDouble()).toArray, + xVariance = Seq.fill(featureSize)(r.nextDouble()).toArray, nPoints = 200, seed, eps = 0.1, sparsity = 0.7), 2)) /* @@ -687,7 +687,7 @@ class LinearRegressionSuite // Validate that we re-insert a prediction column for evaluation val modelNoPredictionColFieldNames = modelNoPredictionCol.summary.predictions.schema.fieldNames - assert((datasetWithDenseFeature.schema.fieldNames.toSet).subsetOf( + assert(datasetWithDenseFeature.schema.fieldNames.toSet.subsetOf( modelNoPredictionColFieldNames.toSet)) assert(modelNoPredictionColFieldNames.exists(s => s.startsWith("prediction_"))) @@ -1006,6 +1006,15 @@ class LinearRegressionSuite testEstimatorAndModelReadWrite(lr, datasetWithWeight, LinearRegressionSuite.allParamSettings, checkModelData) } + + test("should support all NumericType labels and not support other types") { + val lr = new LinearRegression().setMaxIter(1) + MLTestingUtils.checkNumericTypes[LinearRegressionModel, LinearRegression]( + lr, isClassification = false, sqlContext) { (expected, actual) => + assert(expected.intercept === actual.intercept) + assert(expected.coefficients === actual.coefficients) + } + } } object LinearRegressionSuite { diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala index 6be0c8bca0..2ab4f1b146 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala @@ -94,6 +94,14 @@ class RandomForestRegressorSuite extends SparkFunSuite with MLlibTestSparkContex assert(importances.toArray.forall(_ >= 0.0)) } + test("should support all NumericType labels and not support other types") { + val rf = new RandomForestRegressor().setMaxDepth(1) + MLTestingUtils.checkNumericTypes[RandomForestRegressionModel, RandomForestRegressor]( + rf, isClassification = false, sqlContext) { (expected, actual) => + TreeTests.checkEqual(expected, actual) + } + } + ///////////////////////////////////////////////////////////////////////////// // Tests of model save/load ///////////////////////////////////////////////////////////////////////////// diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala index 12808b0305..bd5bd17147 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala @@ -73,6 +73,24 @@ private[ml] object TreeTests extends SparkFunSuite { numClasses) } + /** + * Set label metadata (particularly the number of classes) on a DataFrame. + * @param data Dataset. Categorical features and labels must already have 0-based indices. + * This must be non-empty. + * @param numClasses Number of classes label can take. If 0, mark as continuous. + * @param labelColName Name of the label column on which to set the metadata. + * @return DataFrame with metadata + */ + def setMetadata(data: DataFrame, numClasses: Int, labelColName: String): DataFrame = { + val labelAttribute = if (numClasses == 0) { + NumericAttribute.defaultAttr.withName(labelColName) + } else { + NominalAttribute.defaultAttr.withName(labelColName).withNumValues(numClasses) + } + val labelMetadata = labelAttribute.toMetadata() + data.select(data("features"), data(labelColName).as(labelColName, labelMetadata)) + } + /** * Check if the two trees are exactly the same. * Note: I hesitate to override Node.equals since it could cause problems if users diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala b/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala index d290cc9b06..8108460518 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala @@ -17,14 +17,96 @@ package org.apache.spark.ml.util -import org.apache.spark.ml.Model +import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.param.ParamMap +import org.apache.spark.ml.tree.impl.TreeTests +import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.sql.{DataFrame, SQLContext} +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.types._ -object MLTestingUtils { +object MLTestingUtils extends SparkFunSuite { def checkCopy(model: Model[_]): Unit = { val copied = model.copy(ParamMap.empty) .asInstanceOf[Model[_]] assert(copied.parent.uid == model.parent.uid) assert(copied.parent == model.parent) } + + def checkNumericTypes[M <: Model[M], T <: Estimator[M]]( + estimator: T, + isClassification: Boolean, + sqlContext: SQLContext)(check: (M, M) => Unit): Unit = { + val dfs = if (isClassification) { + genClassifDFWithNumericLabelCol(sqlContext) + } else { + genRegressionDFWithNumericLabelCol(sqlContext) + } + val expected = estimator.fit(dfs(DoubleType)) + val actuals = dfs.keys.filter(_ != DoubleType).map(t => estimator.fit(dfs(t))) + actuals.foreach(actual => check(expected, actual)) + + val dfWithStringLabels = generateDFWithStringLabelCol(sqlContext) + val thrown = intercept[IllegalArgumentException] { + estimator.fit(dfWithStringLabels) + } + assert(thrown.getMessage contains + "Column label must be of type NumericType but was actually of type StringType") + } + + def genClassifDFWithNumericLabelCol( + sqlContext: SQLContext, + labelColName: String = "label", + featuresColName: String = "features"): Map[NumericType, DataFrame] = { + val df = sqlContext.createDataFrame(Seq( + (0, Vectors.dense(0, 2, 3)), + (1, Vectors.dense(0, 3, 1)), + (0, Vectors.dense(0, 2, 2)), + (1, Vectors.dense(0, 3, 9)), + (0, Vectors.dense(0, 2, 6)) + )).toDF(labelColName, featuresColName) + + val types = + Seq(ShortType, LongType, IntegerType, FloatType, ByteType, DoubleType, DecimalType(10, 0)) + types.map(t => t -> df.select(col(labelColName).cast(t), col(featuresColName))) + .map { case (t, d) => t -> TreeTests.setMetadata(d, 2, labelColName) } + .toMap + } + + def genRegressionDFWithNumericLabelCol( + sqlContext: SQLContext, + labelColName: String = "label", + featuresColName: String = "features", + censorColName: String = "censor"): Map[NumericType, DataFrame] = { + val df = sqlContext.createDataFrame(Seq( + (0, Vectors.dense(0)), + (1, Vectors.dense(1)), + (2, Vectors.dense(2)), + (3, Vectors.dense(3)), + (4, Vectors.dense(4)) + )).toDF(labelColName, featuresColName) + + val types = + Seq(ShortType, LongType, IntegerType, FloatType, ByteType, DoubleType, DecimalType(10, 0)) + types + .map(t => t -> df.select(col(labelColName).cast(t), col(featuresColName))) + .map { case (t, d) => + t -> TreeTests.setMetadata(d, 0, labelColName).withColumn(censorColName, lit(0.0)) + } + .toMap + } + + def generateDFWithStringLabelCol( + sqlContext: SQLContext, + labelColName: String = "label", + featuresColName: String = "features", + censorColName: String = "censor"): DataFrame = + sqlContext.createDataFrame(Seq( + ("0", Vectors.dense(0, 2, 3), 0.0), + ("1", Vectors.dense(0, 3, 1), 1.0), + ("0", Vectors.dense(0, 2, 2), 0.0), + ("1", Vectors.dense(0, 3, 9), 1.0), + ("0", Vectors.dense(0, 2, 6), 0.0) + )).toDF(labelColName, featuresColName, censorColName) } -- cgit v1.2.3