about | summary | refs | log | tree | commit | diff
path: root/mllib
diff options
context:
space:
mode:
author: BenFradet <benjamin.fradet@gmail.com> 2016-05-13 09:08:04 +0200
committer: Nick Pentreath <nick.pentreath@gmail.com> 2016-05-13 09:08:04 +0200
commit: 31f1aebbeb77b4eb1080f22c9bece7fafd8022f8 (patch)
tree: 95723ad90476594c4eeb1458fe61cc6e0007eaa3 /mllib
parent: 5b849766ab080c91864ed06ebbfd82ad978d5e4c (diff)
download: spark-31f1aebbeb77b4eb1080f22c9bece7fafd8022f8.tar.gz
spark-31f1aebbeb77b4eb1080f22c9bece7fafd8022f8.tar.bz2
spark-31f1aebbeb77b4eb1080f22c9bece7fafd8022f8.zip
[SPARK-13961][ML] spark.ml ChiSqSelector and RFormula should support other numeric types for label
## What changes were proposed in this pull request?

Made ChiSqSelector and RFormula accept all numeric types for the label column.

## How was this patch tested?

Unit tests.

Author: BenFradet <benjamin.fradet@gmail.com>

Closes #12467 from BenFradet/SPARK-13961.
Diffstat (limited to 'mllib')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala  4
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala  4
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala  2
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala  2
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala  2
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala  2
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala  2
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala  2
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala  2
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala  10
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala  30
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala  2
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala  2
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala  2
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala  2
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala  2
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala  2
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala  2
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala  4
19 files changed, 53 insertions, 27 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala
index cfecae7e0b..29f55a7f71 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala
@@ -80,7 +80,7 @@ final class ChiSqSelector(override val uid: String)
@Since("2.0.0")
override def fit(dataset: Dataset[_]): ChiSqSelectorModel = {
transformSchema(dataset.schema, logging = true)
- val input = dataset.select($(labelCol), $(featuresCol)).rdd.map {
+ val input = dataset.select(col($(labelCol)).cast(DoubleType), col($(featuresCol))).rdd.map {
case Row(label: Double, features: Vector) =>
LabeledPoint(label, features)
}
@@ -90,7 +90,7 @@ final class ChiSqSelector(override val uid: String)
override def transformSchema(schema: StructType): StructType = {
SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT)
- SchemaUtils.checkColumnType(schema, $(labelCol), DoubleType)
+ SchemaUtils.checkNumericType(schema, $(labelCol))
SchemaUtils.appendColumn(schema, $(outputCol), new VectorUDT)
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
index 5219680be2..a2f3d44132 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
@@ -256,8 +256,8 @@ class RFormulaModel private[feature](
val columnNames = schema.map(_.name)
require(!columnNames.contains($(featuresCol)), "Features column already exists.")
require(
- !columnNames.contains($(labelCol)) || schema($(labelCol)).dataType == DoubleType,
- "Label column already exists and is not of type DoubleType.")
+ !columnNames.contains($(labelCol)) || schema($(labelCol)).dataType.isInstanceOf[NumericType],
+ "Label column already exists and is not of type NumericType.")
}
@Since("2.0.0")
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
index f94d336df5..91a947f44b 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
@@ -337,7 +337,7 @@ class DecisionTreeClassifierSuite
test("should support all NumericType labels and not support other types") {
val dt = new DecisionTreeClassifier().setMaxDepth(1)
MLTestingUtils.checkNumericTypes[DecisionTreeClassificationModel, DecisionTreeClassifier](
- dt, isClassification = true, spark) { (expected, actual) =>
+ dt, spark) { (expected, actual) =>
TreeTests.checkEqual(expected, actual)
}
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
index c9453aaec2..5a5e5c15fc 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
@@ -106,7 +106,7 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext
test("should support all NumericType labels and not support other types") {
val gbt = new GBTClassifier().setMaxDepth(1)
MLTestingUtils.checkNumericTypes[GBTClassificationModel, GBTClassifier](
- gbt, isClassification = true, spark) { (expected, actual) =>
+ gbt, spark) { (expected, actual) =>
TreeTests.checkEqual(expected, actual)
}
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
index cb4d087ce5..f127aa217c 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
@@ -938,7 +938,7 @@ class LogisticRegressionSuite
test("should support all NumericType labels and not support other types") {
val lr = new LogisticRegression().setMaxIter(1)
MLTestingUtils.checkNumericTypes[LogisticRegressionModel, LogisticRegression](
- lr, isClassification = true, spark) { (expected, actual) =>
+ lr, spark) { (expected, actual) =>
assert(expected.intercept === actual.intercept)
assert(expected.coefficients.toArray === actual.coefficients.toArray)
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala
index 876e047db5..d5282e07d6 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala
@@ -169,7 +169,7 @@ class MultilayerPerceptronClassifierSuite
val mpc = new MultilayerPerceptronClassifier().setLayers(layers).setMaxIter(1)
MLTestingUtils.checkNumericTypes[
MultilayerPerceptronClassificationModel, MultilayerPerceptronClassifier](
- mpc, isClassification = true, spark) { (expected, actual) =>
+ mpc, spark) { (expected, actual) =>
assert(expected.layers === actual.layers)
assert(expected.weights === actual.weights)
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala
index 15d0358c3f..2a05c446e5 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala
@@ -188,7 +188,7 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
test("should support all NumericType labels and not support other types") {
val nb = new NaiveBayes()
MLTestingUtils.checkNumericTypes[NaiveBayesModel, NaiveBayes](
- nb, isClassification = true, spark) { (expected, actual) =>
+ nb, spark) { (expected, actual) =>
assert(expected.pi === actual.pi)
assert(expected.theta === actual.theta)
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala
index 005d609307..5044d40998 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala
@@ -228,7 +228,7 @@ class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext with Defau
test("should support all NumericType labels and not support other types") {
val ovr = new OneVsRest().setClassifier(new LogisticRegression().setMaxIter(1))
MLTestingUtils.checkNumericTypes[OneVsRestModel, OneVsRest](
- ovr, isClassification = true, spark) { (expected, actual) =>
+ ovr, spark) { (expected, actual) =>
val expectedModels = expected.models.map(m => m.asInstanceOf[LogisticRegressionModel])
val actualModels = actual.models.map(m => m.asInstanceOf[LogisticRegressionModel])
assert(expectedModels.length === actualModels.length)
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
index 97f3feacca..8002a2f4f2 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
@@ -189,7 +189,7 @@ class RandomForestClassifierSuite
test("should support all NumericType labels and not support other types") {
val rf = new RandomForestClassifier().setMaxDepth(1)
MLTestingUtils.checkNumericTypes[RandomForestClassificationModel, RandomForestClassifier](
- rf, isClassification = true, spark) { (expected, actual) =>
+ rf, spark) { (expected, actual) =>
TreeTests.checkEqual(expected, actual)
}
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala
index 4c6d9c5e26..4fcc9745b7 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala
@@ -18,7 +18,7 @@
package org.apache.spark.ml.feature
import org.apache.spark.SparkFunSuite
-import org.apache.spark.ml.util.DefaultReadWriteTest
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
import org.apache.spark.mllib.feature
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.regression.LabeledPoint
@@ -81,4 +81,12 @@ class ChiSqSelectorSuite extends SparkFunSuite with MLlibTestSparkContext
val newInstance = testDefaultReadWrite(instance)
assert(newInstance.selectedFeatures === instance.selectedFeatures)
}
+
+ test("should support all NumericType labels and not support other types") {
+ val css = new ChiSqSelector()
+ MLTestingUtils.checkNumericTypes[ChiSqSelectorModel, ChiSqSelector](
+ css, spark) { (expected, actual) =>
+ assert(expected.selectedFeatures === actual.selectedFeatures)
+ }
+ }
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala
index 46e7495297..c623a6210b 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala
@@ -20,10 +20,10 @@ package org.apache.spark.ml.feature
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.attribute._
import org.apache.spark.ml.param.ParamsSuite
-import org.apache.spark.ml.util.DefaultReadWriteTest
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.util.MLlibTestSparkContext
-import org.apache.spark.sql.Row
+import org.apache.spark.sql.types.DoubleType
class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
test("params") {
@@ -68,9 +68,9 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
assert(resultSchema.toString == model.transform(original).schema.toString)
}
- test("label column already exists but is not double type") {
+ test("label column already exists but is not numeric type") {
val formula = new RFormula().setFormula("y ~ x").setLabelCol("y")
- val original = spark.createDataFrame(Seq((0, 1), (2, 2))).toDF("x", "y")
+ val original = spark.createDataFrame(Seq((0, true), (2, false))).toDF("x", "y")
val model = formula.fit(original)
intercept[IllegalArgumentException] {
model.transformSchema(original.schema)
@@ -134,7 +134,6 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
).toDF("id", "a", "b")
val model = formula.fit(original)
val result = model.transform(original)
- val resultSchema = model.transformSchema(original.schema)
val expected = spark.createDataFrame(
Seq(
("male", "foo", 4, Vectors.dense(0.0, 1.0, 4.0), 1.0),
@@ -188,7 +187,7 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
"vec2",
Array[Attribute](
NumericAttribute.defaultAttr,
- NumericAttribute.defaultAttr)).toMetadata
+ NumericAttribute.defaultAttr)).toMetadata()
val original = base.select(base.col("id"), base.col("vec").as("vec2", metadata))
val model = formula.fit(original)
val result = model.transform(original)
@@ -309,4 +308,23 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
val newModel = testDefaultReadWrite(model)
checkModelData(model, newModel)
}
+
+ test("should support all NumericType labels") {
+ val formula = new RFormula().setFormula("label ~ features")
+ .setLabelCol("x")
+ .setFeaturesCol("y")
+ val dfs = MLTestingUtils.genRegressionDFWithNumericLabelCol(spark)
+ val expected = formula.fit(dfs(DoubleType))
+ val actuals = dfs.keys.filter(_ != DoubleType).map(t => formula.fit(dfs(t)))
+ actuals.foreach { actual =>
+ assert(expected.pipelineModel.stages.length === actual.pipelineModel.stages.length)
+ expected.pipelineModel.stages.zip(actual.pipelineModel.stages).foreach {
+ case (exTransformer, acTransformer) =>
+ assert(exTransformer.params === acTransformer.params)
+ }
+ assert(expected.resolvedFormula.label === actual.resolvedFormula.label)
+ assert(expected.resolvedFormula.terms === actual.resolvedFormula.terms)
+ assert(expected.resolvedFormula.hasIntercept === actual.resolvedFormula.hasIntercept)
+ }
+ }
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala
index f8fc775676..e4772df622 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala
@@ -356,7 +356,7 @@ class AFTSurvivalRegressionSuite
test("should support all NumericType labels") {
val aft = new AFTSurvivalRegression().setMaxIter(1)
MLTestingUtils.checkNumericTypes[AFTSurvivalRegressionModel, AFTSurvivalRegression](
- aft, isClassification = false, spark) { (expected, actual) =>
+ aft, spark, isClassification = false) { (expected, actual) =>
assert(expected.intercept === actual.intercept)
assert(expected.coefficients === actual.coefficients)
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
index d9f26ad8dc..2d30cbf367 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
@@ -120,7 +120,7 @@ class DecisionTreeRegressorSuite
test("should support all NumericType labels and not support other types") {
val dt = new DecisionTreeRegressor().setMaxDepth(1)
MLTestingUtils.checkNumericTypes[DecisionTreeRegressionModel, DecisionTreeRegressor](
- dt, isClassification = false, spark) { (expected, actual) =>
+ dt, spark, isClassification = false) { (expected, actual) =>
TreeTests.checkEqual(expected, actual)
}
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
index f6ea5bb741..ac833b833d 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
@@ -115,7 +115,7 @@ class GBTRegressorSuite extends SparkFunSuite with MLlibTestSparkContext
test("should support all NumericType labels and not support other types") {
val gbt = new GBTRegressor().setMaxDepth(1)
MLTestingUtils.checkNumericTypes[GBTRegressionModel, GBTRegressor](
- gbt, isClassification = false, spark) { (expected, actual) =>
+ gbt, spark, isClassification = false) { (expected, actual) =>
TreeTests.checkEqual(expected, actual)
}
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala
index 161f8c80f8..3d9aeb8c0a 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala
@@ -1021,7 +1021,7 @@ class GeneralizedLinearRegressionSuite
val glr = new GeneralizedLinearRegression().setMaxIter(1)
MLTestingUtils.checkNumericTypes[
GeneralizedLinearRegressionModel, GeneralizedLinearRegression](
- glr, isClassification = false, spark) { (expected, actual) =>
+ glr, spark, isClassification = false) { (expected, actual) =>
assert(expected.intercept === actual.intercept)
assert(expected.coefficients === actual.coefficients)
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala
index 9bf7542b12..bed4978b25 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala
@@ -184,7 +184,7 @@ class IsotonicRegressionSuite
test("should support all NumericType labels and not support other types") {
val ir = new IsotonicRegression()
MLTestingUtils.checkNumericTypes[IsotonicRegressionModel, IsotonicRegression](
- ir, isClassification = false, spark) { (expected, actual) =>
+ ir, spark, isClassification = false) { (expected, actual) =>
assert(expected.boundaries === actual.boundaries)
assert(expected.predictions === actual.predictions)
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
index 10f547b673..a98227d2c1 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
@@ -1010,7 +1010,7 @@ class LinearRegressionSuite
test("should support all NumericType labels and not support other types") {
val lr = new LinearRegression().setMaxIter(1)
MLTestingUtils.checkNumericTypes[LinearRegressionModel, LinearRegression](
- lr, isClassification = false, spark) { (expected, actual) =>
+ lr, spark, isClassification = false) { (expected, actual) =>
assert(expected.intercept === actual.intercept)
assert(expected.coefficients === actual.coefficients)
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala
index 72f3c65eb8..7a3a3698f9 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala
@@ -98,7 +98,7 @@ class RandomForestRegressorSuite extends SparkFunSuite with MLlibTestSparkContex
test("should support all NumericType labels and not support other types") {
val rf = new RandomForestRegressor().setMaxDepth(1)
MLTestingUtils.checkNumericTypes[RandomForestRegressionModel, RandomForestRegressor](
- rf, isClassification = false, spark) { (expected, actual) =>
+ rf, spark, isClassification = false) { (expected, actual) =>
TreeTests.checkEqual(expected, actual)
}
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala b/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala
index 4fe473bbac..ad7d2c9b8d 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala
@@ -37,8 +37,8 @@ object MLTestingUtils extends SparkFunSuite {
def checkNumericTypes[M <: Model[M], T <: Estimator[M]](
estimator: T,
- isClassification: Boolean,
- spark: SparkSession)(check: (M, M) => Unit): Unit = {
+ spark: SparkSession,
+ isClassification: Boolean = true)(check: (M, M) => Unit): Unit = {
val dfs = if (isClassification) {
genClassifDFWithNumericLabelCol(spark)
} else {