diff options
author | hyukjinkwon <gurwls223@gmail.com> | 2016-06-12 14:26:53 -0700 |
---|---|---|
committer | Joseph K. Bradley <joseph@databricks.com> | 2016-06-12 14:26:53 -0700 |
commit | e3554605b36bdce63ac180cc66dbdee5c1528ec7 (patch) | |
tree | 813ed4fafabf35fde2a0a59e8f2ce2ff15f39cfd /mllib/src/test | |
parent | 0ff8a68b9faefecf90c9ceef49580b2909beb19e (diff) | |
download | spark-e3554605b36bdce63ac180cc66dbdee5c1528ec7.tar.gz spark-e3554605b36bdce63ac180cc66dbdee5c1528ec7.tar.bz2 spark-e3554605b36bdce63ac180cc66dbdee5c1528ec7.zip |
[SPARK-15892][ML] Incorrectly merged AFTAggregator with zero total count
## What changes were proposed in this pull request?
Currently, `AFTAggregator` is not being merged correctly. For example, if there is any single empty partition in the data, this creates an `AFTAggregator` with zero total count which causes the exception below:
```
IllegalArgumentException: u'requirement failed: The number of instances should be greater than 0.0, but got 0.'
```
Please see [AFTSurvivalRegression.scala#L573-L575](https://github.com/apache/spark/blob/6ecedf39b44c9acd58cdddf1a31cf11e8e24428c/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala#L573-L575) as well.
Just to be clear, the python example `aft_survival_regression.py` seems using 5 rows. So, if there exist partitions more than 5, it throws the exception above since it contains empty partitions which results in an incorrectly merged `AFTAggregator`.
Executing `bin/spark-submit examples/src/main/python/ml/aft_survival_regression.py` on a machine with CPUs more than 5 is being failed because it creates tasks with some empty partitions with defualt configurations (AFAIK, it sets the parallelism level to the number of CPU cores).
## How was this patch tested?
An unit test in `AFTSurvivalRegressionSuite.scala` and manually tested by `bin/spark-submit examples/src/main/python/ml/aft_survival_regression.py`.
Author: hyukjinkwon <gurwls223@gmail.com>
Author: Hyukjin Kwon <gurwls223@gmail.com>
Closes #13619 from HyukjinKwon/SPARK-15892.
Diffstat (limited to 'mllib/src/test')
-rw-r--r-- | mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala | 12 |
1 files changed, 12 insertions, 0 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala index 05aae80c66..1c70b702de 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala @@ -390,6 +390,18 @@ class AFTSurvivalRegressionSuite testEstimatorAndModelReadWrite(aft, datasetMultivariate, AFTSurvivalRegressionSuite.allParamSettings, checkModelData) } + + test("SPARK-15892: Incorrectly merged AFTAggregator with zero total count") { + // This `dataset` will contain an empty partition because it has two rows but + // the parallelism is bigger than that. Because the issue was about `AFTAggregator`s + // being merged incorrectly when it has an empty partition, running the codes below + // should not throw an exception. + val dataset = spark.createDataFrame( + sc.parallelize(generateAFTInput( + 1, Array(5.5), Array(0.8), 2, 42, 1.0, 2.0, 2.0), numSlices = 3)) + val trainer = new AFTSurvivalRegression() + trainer.fit(dataset) + } } object AFTSurvivalRegressionSuite { |