aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala2
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala12
2 files changed, 13 insertions, 1 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
index e5f23f44bc..7f57af19e9 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
@@ -538,7 +538,7 @@ private class AFTAggregator(
* @return This AFTAggregator object.
*/
def merge(other: AFTAggregator): this.type = {
- if (totalCnt != 0) {
+ if (other.count != 0) {
totalCnt += other.totalCnt
lossSum += other.lossSum
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala
index 05aae80c66..1c70b702de 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala
@@ -390,6 +390,18 @@ class AFTSurvivalRegressionSuite
testEstimatorAndModelReadWrite(aft, datasetMultivariate,
AFTSurvivalRegressionSuite.allParamSettings, checkModelData)
}
+
+ test("SPARK-15892: Incorrectly merged AFTAggregator with zero total count") {
+ // This `dataset` will contain an empty partition because it has only two rows but
+ // is split into three slices. The issue was that `AFTAggregator`s were merged
+ // incorrectly when one of them came from an empty partition, so running the code
+ // below should not throw an exception.
+ val dataset = spark.createDataFrame(
+ sc.parallelize(generateAFTInput(
+ 1, Array(5.5), Array(0.8), 2, 42, 1.0, 2.0, 2.0), numSlices = 3))
+ val trainer = new AFTSurvivalRegression()
+ trainer.fit(dataset)
+ }
}
object AFTSurvivalRegressionSuite {