aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala
diff options
context:
space:
mode:
Diffstat (limited to 'mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala')
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala34
1 files changed, 33 insertions, 1 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala
index 0fdfdf37cf..3cd4b0ac30 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala
@@ -27,6 +27,8 @@ import org.apache.spark.ml.util.TestingUtils._
import org.apache.spark.mllib.random.{ExponentialGenerator, WeibullGenerator}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.sql.functions.{col, lit}
+import org.apache.spark.sql.types._
class AFTSurvivalRegressionSuite
extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
@@ -352,7 +354,7 @@ class AFTSurvivalRegressionSuite
}
}
- test("should support all NumericType labels") {
+ test("should support all NumericType labels, and not support other types") {
val aft = new AFTSurvivalRegression().setMaxIter(1)
MLTestingUtils.checkNumericTypes[AFTSurvivalRegressionModel, AFTSurvivalRegression](
aft, spark, isClassification = false) { (expected, actual) =>
@@ -361,6 +363,36 @@ class AFTSurvivalRegressionSuite
}
}
+ test("should support all NumericType censors, and not support other types") {
+ val df = spark.createDataFrame(Seq(
+ (0, Vectors.dense(0)),
+ (1, Vectors.dense(1)),
+ (2, Vectors.dense(2)),
+ (3, Vectors.dense(3)),
+ (4, Vectors.dense(4))
+ )).toDF("label", "features")
+ .withColumn("censor", lit(0.0))
+ val aft = new AFTSurvivalRegression().setMaxIter(1)
+ val expected = aft.fit(df)
+
+ val types = Seq(ShortType, LongType, IntegerType, FloatType, ByteType, DecimalType(10, 0))
+ types.foreach { t =>
+ val actual = aft.fit(df.select(col("label"), col("features"),
+ col("censor").cast(t)))
+ assert(expected.intercept === actual.intercept)
+ assert(expected.coefficients === actual.coefficients)
+ }
+
+ val dfWithStringCensors = spark.createDataFrame(Seq(
+ (0, Vectors.dense(0, 2, 3), "0")
+ )).toDF("label", "features", "censor")
+ val thrown = intercept[IllegalArgumentException] {
+ aft.fit(dfWithStringCensors)
+ }
+ assert(thrown.getMessage.contains(
+ "Column censor must be of type NumericType but was actually of type StringType"))
+ }
+
test("numerical stability of standardization") {
val trainer = new AFTSurvivalRegression()
val model1 = trainer.fit(datasetUnivariate)