aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
diff options
context:
space:
mode:
Diffstat (limited to 'mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala')
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala9
1 files changed, 5 insertions, 4 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
index 267d63b51e..a98bdeca6b 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
@@ -22,14 +22,14 @@ import org.apache.hadoop.fs.Path
import org.apache.spark.SparkException
import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.ml.PredictorParams
+import org.apache.spark.ml.linalg._
import org.apache.spark.ml.param.{DoubleParam, Param, ParamMap, ParamValidators}
import org.apache.spark.ml.util._
import org.apache.spark.mllib.classification.{NaiveBayes => OldNaiveBayes}
import org.apache.spark.mllib.classification.{NaiveBayesModel => OldNaiveBayesModel}
-import org.apache.spark.mllib.linalg._
-import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.regression.{LabeledPoint => OldLabeledPoint}
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{DataFrame, Dataset}
+import org.apache.spark.sql.Dataset
/**
* Params for Naive Bayes Classifiers.
@@ -102,7 +102,8 @@ class NaiveBayes @Since("1.5.0") (
setDefault(modelType -> OldNaiveBayes.Multinomial)
override protected def train(dataset: Dataset[_]): NaiveBayesModel = {
- val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset)
+ val oldDataset: RDD[OldLabeledPoint] =
+ extractLabeledPoints(dataset).map(OldLabeledPoint.fromML)
val oldModel = OldNaiveBayes.train(oldDataset, $(smoothing), $(modelType))
NaiveBayesModel.fromOld(oldModel, this)
}