aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
diff options
context:
space:
mode:
authorYanbo Liang <ybliang8@gmail.com>2016-07-29 04:40:20 -0700
committerSean Owen <sowen@cloudera.com>2016-07-29 04:40:20 -0700
commit0557a45452f6e73877e5ec972110825ce8f3fbc5 (patch)
tree28b18541ba9bfc1217041a08a2210c3d5835c757 /mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
parentd1d5069aa3744d46abd3889abab5f15e9067382a (diff)
downloadspark-0557a45452f6e73877e5ec972110825ce8f3fbc5.tar.gz
spark-0557a45452f6e73877e5ec972110825ce8f3fbc5.tar.bz2
spark-0557a45452f6e73877e5ec972110825ce8f3fbc5.zip
[SPARK-16750][ML] Fix GaussianMixture training failed due to feature column type mistake
## What changes were proposed in this pull request? ML ```GaussianMixture``` training failed due to feature column type mistake. The feature column type should be ```ml.linalg.VectorUDT``` but got ```mllib.linalg.VectorUDT``` by mistake. See [SPARK-16750](https://issues.apache.org/jira/browse/SPARK-16750) for how to reproduce this bug. Why the unit tests did not complain this errors? Because some estimators/transformers missed calling ```transformSchema(dataset.schema)``` firstly during ```fit``` or ```transform```. I will also add this function to all estimators/transformers who missed in this PR. ## How was this patch tested? No new tests, should pass existing ones. Author: Yanbo Liang <ybliang8@gmail.com> Closes #14378 from yanboliang/spark-16750.
Diffstat (limited to 'mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala')
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala4
1 files changed, 2 insertions, 2 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
index 2b9912657f..d4ae59deff 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
@@ -196,7 +196,7 @@ class AFTSurvivalRegression @Since("1.6.0") (@Since("1.6.0") override val uid: S
@Since("2.0.0")
override def fit(dataset: Dataset[_]): AFTSurvivalRegressionModel = {
- validateAndTransformSchema(dataset.schema, fitting = true)
+ transformSchema(dataset.schema, logging = true)
val instances = extractAFTPoints(dataset)
val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE
if (handlePersistence) instances.persist(StorageLevel.MEMORY_AND_DISK)
@@ -326,7 +326,7 @@ class AFTSurvivalRegressionModel private[ml] (
@Since("2.0.0")
override def transform(dataset: Dataset[_]): DataFrame = {
- transformSchema(dataset.schema)
+ transformSchema(dataset.schema, logging = true)
val predictUDF = udf { features: Vector => predict(features) }
val predictQuantilesUDF = udf { features: Vector => predictQuantiles(features)}
if (hasQuantilesCol) {