diff options
Diffstat (limited to 'mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala')
-rw-r--r-- | mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala | 30 |
1 files changed, 24 insertions, 6 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala index 46e7495297..c623a6210b 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala @@ -20,10 +20,10 @@ package org.apache.spark.ml.feature import org.apache.spark.SparkFunSuite import org.apache.spark.ml.attribute._ import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.util.DefaultReadWriteTest +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.sql.Row +import org.apache.spark.sql.types.DoubleType class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { test("params") { @@ -68,9 +68,9 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul assert(resultSchema.toString == model.transform(original).schema.toString) } - test("label column already exists but is not double type") { + test("label column already exists but is not numeric type") { val formula = new RFormula().setFormula("y ~ x").setLabelCol("y") - val original = spark.createDataFrame(Seq((0, 1), (2, 2))).toDF("x", "y") + val original = spark.createDataFrame(Seq((0, true), (2, false))).toDF("x", "y") val model = formula.fit(original) intercept[IllegalArgumentException] { model.transformSchema(original.schema) @@ -134,7 +134,6 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul ).toDF("id", "a", "b") val model = formula.fit(original) val result = model.transform(original) - val resultSchema = model.transformSchema(original.schema) val expected = spark.createDataFrame( Seq( ("male", "foo", 4, Vectors.dense(0.0, 1.0, 4.0), 1.0), @@ -188,7 +187,7 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul "vec2", Array[Attribute]( NumericAttribute.defaultAttr, - NumericAttribute.defaultAttr)).toMetadata + NumericAttribute.defaultAttr)).toMetadata() val original = base.select(base.col("id"), base.col("vec").as("vec2", metadata)) val model = formula.fit(original) val result = model.transform(original) @@ -309,4 +308,23 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul val newModel = testDefaultReadWrite(model) checkModelData(model, newModel) } + + test("should support all NumericType labels") { + val formula = new RFormula().setFormula("label ~ features") + .setLabelCol("x") + .setFeaturesCol("y") + val dfs = MLTestingUtils.genRegressionDFWithNumericLabelCol(spark) + val expected = formula.fit(dfs(DoubleType)) + val actuals = dfs.keys.filter(_ != DoubleType).map(t => formula.fit(dfs(t))) + actuals.foreach { actual => + assert(expected.pipelineModel.stages.length === actual.pipelineModel.stages.length) + expected.pipelineModel.stages.zip(actual.pipelineModel.stages).foreach { + case (exTransformer, acTransformer) => + assert(exTransformer.params === acTransformer.params) + } + assert(expected.resolvedFormula.label === actual.resolvedFormula.label) + assert(expected.resolvedFormula.terms === actual.resolvedFormula.terms) + assert(expected.resolvedFormula.hasIntercept === actual.resolvedFormula.hasIntercept) + } + } } |