diff options
Diffstat (limited to 'mllib/src/main/scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala')
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala | 8 |
1 files changed, 5 insertions, 3 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala index 07383d393d..b17207e99b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala @@ -21,7 +21,7 @@ import org.apache.spark.ml.{Pipeline, PipelineModel} import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NominalAttribute} import org.apache.spark.ml.classification.{NaiveBayes, NaiveBayesModel} import org.apache.spark.ml.feature.{IndexToString, RFormula} -import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.{DataFrame, Dataset} private[r] class NaiveBayesWrapper private ( pipeline: PipelineModel, @@ -36,8 +36,10 @@ private[r] class NaiveBayesWrapper private ( lazy val tables: Array[Double] = naiveBayesModel.theta.toArray.map(math.exp) - def transform(dataset: DataFrame): DataFrame = { - pipeline.transform(dataset).drop(PREDICTED_LABEL_INDEX_COL) + def transform(dataset: Dataset[_]): DataFrame = { + pipeline.transform(dataset) + .drop(PREDICTED_LABEL_INDEX_COL) + .drop(naiveBayesModel.getFeaturesCol) } } |