Diffstat (limited to 'mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala')
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala | 30
1 file changed, 24 insertions(+), 6 deletions(-)
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala
index 46e7495297..c623a6210b 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala
@@ -20,10 +20,10 @@ package org.apache.spark.ml.feature
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.attribute._
import org.apache.spark.ml.param.ParamsSuite
-import org.apache.spark.ml.util.DefaultReadWriteTest
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.util.MLlibTestSparkContext
-import org.apache.spark.sql.Row
+import org.apache.spark.sql.types.DoubleType
class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
test("params") {
@@ -68,9 +68,9 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
assert(resultSchema.toString == model.transform(original).schema.toString)
}
- test("label column already exists but is not double type") {
+ test("label column already exists but is not numeric type") {
val formula = new RFormula().setFormula("y ~ x").setLabelCol("y")
- val original = spark.createDataFrame(Seq((0, 1), (2, 2))).toDF("x", "y")
+ val original = spark.createDataFrame(Seq((0, true), (2, false))).toDF("x", "y")
val model = formula.fit(original)
intercept[IllegalArgumentException] {
model.transformSchema(original.schema)
@@ -134,7 +134,6 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
).toDF("id", "a", "b")
val model = formula.fit(original)
val result = model.transform(original)
- val resultSchema = model.transformSchema(original.schema)
val expected = spark.createDataFrame(
Seq(
("male", "foo", 4, Vectors.dense(0.0, 1.0, 4.0), 1.0),
@@ -188,7 +187,7 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
"vec2",
Array[Attribute](
NumericAttribute.defaultAttr,
- NumericAttribute.defaultAttr)).toMetadata
+ NumericAttribute.defaultAttr)).toMetadata()
val original = base.select(base.col("id"), base.col("vec").as("vec2", metadata))
val model = formula.fit(original)
val result = model.transform(original)
@@ -309,4 +308,23 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
val newModel = testDefaultReadWrite(model)
checkModelData(model, newModel)
}
+
+ test("should support all NumericType labels") {
+ val formula = new RFormula().setFormula("label ~ features")
+ .setLabelCol("x")
+ .setFeaturesCol("y")
+ val dfs = MLTestingUtils.genRegressionDFWithNumericLabelCol(spark)
+ val expected = formula.fit(dfs(DoubleType))
+ val actuals = dfs.keys.filter(_ != DoubleType).map(t => formula.fit(dfs(t)))
+ actuals.foreach { actual =>
+ assert(expected.pipelineModel.stages.length === actual.pipelineModel.stages.length)
+ expected.pipelineModel.stages.zip(actual.pipelineModel.stages).foreach {
+ case (exTransformer, acTransformer) =>
+ assert(exTransformer.params === acTransformer.params)
+ }
+ assert(expected.resolvedFormula.label === actual.resolvedFormula.label)
+ assert(expected.resolvedFormula.terms === actual.resolvedFormula.terms)
+ assert(expected.resolvedFormula.hasIntercept === actual.resolvedFormula.hasIntercept)
+ }
+ }
}
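
Note: the new test relies on MLTestingUtils.genRegressionDFWithNumericLabelCol(spark), which returns one regression DataFrame per numeric label type so each fitted RFormula model can be compared against the DoubleType baseline. The diff does not include that helper; the following is only a rough sketch, assuming it casts a DoubleType "label" column to every NumericType while leaving the "features" vector column untouched. The object name, data values, and list of types are illustrative assumptions, not Spark's actual test utility.

import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.types._

object NumericLabelDFs {
  // Hypothetical sketch (not the real MLTestingUtils code): build one DataFrame
  // per numeric label type by casting a DoubleType label column.
  def genRegressionDFWithNumericLabelCol(spark: SparkSession): Map[NumericType, DataFrame] = {
    val base = spark.createDataFrame(Seq(
      (1.0, Vectors.dense(1.0, 2.0)),
      (0.0, Vectors.dense(3.0, 4.0)),
      (1.0, Vectors.dense(5.0, 6.0))
    )).toDF("label", "features")

    // Assumed set of label types to exercise; the real helper may cover more.
    val labelTypes: Seq[NumericType] =
      Seq(DoubleType, FloatType, IntegerType, LongType, ShortType, ByteType)

    labelTypes.map { t =>
      // Cast only the label column; the features vector column keeps its type.
      t -> base.select(col("label").cast(t).as("label"), col("features"))
    }.toMap
  }
}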