diff options
author | Yanbo Liang <ybliang8@gmail.com> | 2016-07-29 04:40:20 -0700 |
---|---|---|
committer | Sean Owen <sowen@cloudera.com> | 2016-07-29 04:40:20 -0700 |
commit | 0557a45452f6e73877e5ec972110825ce8f3fbc5 (patch) | |
tree | 28b18541ba9bfc1217041a08a2210c3d5835c757 /mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala | |
parent | d1d5069aa3744d46abd3889abab5f15e9067382a (diff) | |
download | spark-0557a45452f6e73877e5ec972110825ce8f3fbc5.tar.gz spark-0557a45452f6e73877e5ec972110825ce8f3fbc5.tar.bz2 spark-0557a45452f6e73877e5ec972110825ce8f3fbc5.zip |
[SPARK-16750][ML] Fix GaussianMixture training failed due to feature column type mistake
## What changes were proposed in this pull request?
ML ```GaussianMixture``` training failed because of a feature column type mistake: the feature column type should be ```ml.linalg.VectorUDT```, but ```mllib.linalg.VectorUDT``` was used by mistake.
See [SPARK-16750](https://issues.apache.org/jira/browse/SPARK-16750) for how to reproduce this bug.
Why did the unit tests not catch this error? Because some estimators/transformers did not call ```transformSchema(dataset.schema)``` first during ```fit``` or ```transform```. In this PR I will also add this call to all estimators/transformers that were missing it.
## How was this patch tested?
No new tests, should pass existing ones.
Author: Yanbo Liang <ybliang8@gmail.com>
Closes #14378 from yanboliang/spark-16750.
Diffstat (limited to 'mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala')
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala | 4 |
1 files changed, 2 insertions, 2 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala index 2b9912657f..d4ae59deff 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala @@ -196,7 +196,7 @@ class AFTSurvivalRegression @Since("1.6.0") (@Since("1.6.0") override val uid: S @Since("2.0.0") override def fit(dataset: Dataset[_]): AFTSurvivalRegressionModel = { - validateAndTransformSchema(dataset.schema, fitting = true) + transformSchema(dataset.schema, logging = true) val instances = extractAFTPoints(dataset) val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE if (handlePersistence) instances.persist(StorageLevel.MEMORY_AND_DISK) @@ -326,7 +326,7 @@ class AFTSurvivalRegressionModel private[ml] ( @Since("2.0.0") override def transform(dataset: Dataset[_]): DataFrame = { - transformSchema(dataset.schema) + transformSchema(dataset.schema, logging = true) val predictUDF = udf { features: Vector => predict(features) } val predictQuantilesUDF = udf { features: Vector => predictQuantiles(features)} if (hasQuantilesCol) { |