aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/main/scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala
diff options
context:
space:
mode:
Diffstat (limited to 'mllib/src/main/scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala')
-rw-r--r-- mllib/src/main/scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala 8
1 files changed, 5 insertions, 3 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala
index 07383d393d..b17207e99b 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala
@@ -21,7 +21,7 @@ import org.apache.spark.ml.{Pipeline, PipelineModel}
import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NominalAttribute}
import org.apache.spark.ml.classification.{NaiveBayes, NaiveBayesModel}
import org.apache.spark.ml.feature.{IndexToString, RFormula}
-import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.{DataFrame, Dataset}
private[r] class NaiveBayesWrapper private (
pipeline: PipelineModel,
@@ -36,8 +36,10 @@ private[r] class NaiveBayesWrapper private (
lazy val tables: Array[Double] = naiveBayesModel.theta.toArray.map(math.exp)
- def transform(dataset: DataFrame): DataFrame = {
- pipeline.transform(dataset).drop(PREDICTED_LABEL_INDEX_COL)
+ def transform(dataset: Dataset[_]): DataFrame = {
+ pipeline.transform(dataset)
+ .drop(PREDICTED_LABEL_INDEX_COL)
+ .drop(naiveBayesModel.getFeaturesCol)
}
}