aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
diff options
context:
space:
mode:
Diffstat (limited to 'mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala')
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala4
1 files changed, 2 insertions, 2 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
index 483ef0d88c..267d63b51e 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
@@ -29,7 +29,7 @@ import org.apache.spark.mllib.classification.{NaiveBayesModel => OldNaiveBayesMo
import org.apache.spark.mllib.linalg._
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.{DataFrame, Dataset}
/**
* Params for Naive Bayes Classifiers.
@@ -101,7 +101,7 @@ class NaiveBayes @Since("1.5.0") (
def setModelType(value: String): this.type = set(modelType, value)
setDefault(modelType -> OldNaiveBayes.Multinomial)
- override protected def train(dataset: DataFrame): NaiveBayesModel = {
+ override protected def train(dataset: Dataset[_]): NaiveBayesModel = {
val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset)
val oldModel = OldNaiveBayes.train(oldDataset, $(smoothing), $(modelType))
NaiveBayesModel.fromOld(oldModel, this)