diff options
author | Yanbo Liang <ybliang8@gmail.com> | 2016-04-12 11:34:40 -0700 |
---|---|---|
committer | Xiangrui Meng <meng@databricks.com> | 2016-04-12 11:34:40 -0700 |
commit | 111a62474a2fb7f4e7f19fcfb8efaae37aa40400 (patch) | |
tree | 12cf81dcb4606a3d0bac50708444fccb37d6facd | |
parent | 1995c2e6482bf4af5a4be087bfc156311c1bec19 (diff) | |
download | spark-111a62474a2fb7f4e7f19fcfb8efaae37aa40400.tar.gz spark-111a62474a2fb7f4e7f19fcfb8efaae37aa40400.tar.bz2 spark-111a62474a2fb7f4e7f19fcfb8efaae37aa40400.zip |
[SPARK-14147][ML][SPARKR] SparkR predict should not output feature column
## What changes were proposed in this pull request?
SparkR does not support type of vector which is the default type of feature column in ML. R predict also does not output intermediate feature column. So SparkR ```predict``` should not output feature column. In this PR, I only fix this issue for ```naiveBayes``` and ```survreg```. ```kmeans``` has the right code route already and ```glm``` will be fixed at SparkRWrapper refactor(#12294).
## How was this patch tested?
No new tests.
cc mengxr shivaram
Author: Yanbo Liang <ybliang8@gmail.com>
Closes #11958 from yanboliang/spark-14147.
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/ml/r/AFTSurvivalRegressionWrapper.scala | 2 | ||||
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala | 4 |
2 files changed, 4 insertions, 2 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/AFTSurvivalRegressionWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/AFTSurvivalRegressionWrapper.scala index 2ae411555f..7835468626 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/AFTSurvivalRegressionWrapper.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/AFTSurvivalRegressionWrapper.scala @@ -44,7 +44,7 @@ private[r] class AFTSurvivalRegressionWrapper private ( } def transform(dataset: Dataset[_]): DataFrame = { - pipeline.transform(dataset) + pipeline.transform(dataset).drop(aftModel.getFeaturesCol) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala index 2cd709d2ee..b17207e99b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala @@ -37,7 +37,9 @@ private[r] class NaiveBayesWrapper private ( lazy val tables: Array[Double] = naiveBayesModel.theta.toArray.map(math.exp) def transform(dataset: Dataset[_]): DataFrame = { - pipeline.transform(dataset).drop(PREDICTED_LABEL_INDEX_COL) + pipeline.transform(dataset) + .drop(PREDICTED_LABEL_INDEX_COL) + .drop(naiveBayesModel.getFeaturesCol) } } |