aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala6
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala2
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala34
3 files changed, 37 insertions, 5 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
index 2f78dd30b3..4b3608330c 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
@@ -106,7 +106,7 @@ private[regression] trait AFTSurvivalRegressionParams extends Params
fitting: Boolean): StructType = {
SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT)
if (fitting) {
- SchemaUtils.checkColumnType(schema, $(censorCol), DoubleType)
+ SchemaUtils.checkNumericType(schema, $(censorCol))
SchemaUtils.checkNumericType(schema, $(labelCol))
}
if (hasQuantilesCol) {
@@ -200,8 +200,8 @@ class AFTSurvivalRegression @Since("1.6.0") (@Since("1.6.0") override val uid: S
* and put it in an RDD with strong types.
*/
protected[ml] def extractAFTPoints(dataset: Dataset[_]): RDD[AFTPoint] = {
- dataset.select(col($(featuresCol)), col($(labelCol)).cast(DoubleType), col($(censorCol)))
- .rdd.map {
+ dataset.select(col($(featuresCol)), col($(labelCol)).cast(DoubleType),
+ col($(censorCol)).cast(DoubleType)).rdd.map {
case Row(features: Vector, label: Double, censor: Double) =>
AFTPoint(features, label, censor)
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala
index a6c29433d7..529f66eadb 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala
@@ -49,7 +49,7 @@ private[regression] trait IsotonicRegressionBase extends Params with HasFeatures
*/
final val isotonic: BooleanParam =
new BooleanParam(this, "isotonic",
- "whether the output sequence should be isotonic/increasing (true) or" +
+ "whether the output sequence should be isotonic/increasing (true) or " +
"antitonic/decreasing (false)")
/** @group getParam */
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala
index 0fdfdf37cf..3cd4b0ac30 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala
@@ -27,6 +27,8 @@ import org.apache.spark.ml.util.TestingUtils._
import org.apache.spark.mllib.random.{ExponentialGenerator, WeibullGenerator}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.sql.functions.{col, lit}
+import org.apache.spark.sql.types._
class AFTSurvivalRegressionSuite
extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
@@ -352,7 +354,7 @@ class AFTSurvivalRegressionSuite
}
}
- test("should support all NumericType labels") {
+ test("should support all NumericType labels, and not support other types") {
val aft = new AFTSurvivalRegression().setMaxIter(1)
MLTestingUtils.checkNumericTypes[AFTSurvivalRegressionModel, AFTSurvivalRegression](
aft, spark, isClassification = false) { (expected, actual) =>
@@ -361,6 +363,36 @@ class AFTSurvivalRegressionSuite
}
}
+ test("should support all NumericType censors, and not support other types") {
+ val df = spark.createDataFrame(Seq(
+ (0, Vectors.dense(0)),
+ (1, Vectors.dense(1)),
+ (2, Vectors.dense(2)),
+ (3, Vectors.dense(3)),
+ (4, Vectors.dense(4))
+ )).toDF("label", "features")
+ .withColumn("censor", lit(0.0))
+ val aft = new AFTSurvivalRegression().setMaxIter(1)
+ val expected = aft.fit(df)
+
+ val types = Seq(ShortType, LongType, IntegerType, FloatType, ByteType, DecimalType(10, 0))
+ types.foreach { t =>
+ val actual = aft.fit(df.select(col("label"), col("features"),
+ col("censor").cast(t)))
+ assert(expected.intercept === actual.intercept)
+ assert(expected.coefficients === actual.coefficients)
+ }
+
+ val dfWithStringCensors = spark.createDataFrame(Seq(
+ (0, Vectors.dense(0, 2, 3), "0")
+ )).toDF("label", "features", "censor")
+ val thrown = intercept[IllegalArgumentException] {
+ aft.fit(dfWithStringCensors)
+ }
+ assert(thrown.getMessage.contains(
+ "Column censor must be of type NumericType but was actually of type StringType"))
+ }
+
test("numerical stability of standardization") {
val trainer = new AFTSurvivalRegression()
val model1 = trainer.fit(datasetUnivariate)