diff options
Diffstat (limited to 'mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala')
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala | 4 |
1 files changed, 2 insertions, 2 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala index 483ef0d88c..267d63b51e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala @@ -29,7 +29,7 @@ import org.apache.spark.mllib.classification.{NaiveBayesModel => OldNaiveBayesMo import org.apache.spark.mllib.linalg._ import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.rdd.RDD -import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.{DataFrame, Dataset} /** * Params for Naive Bayes Classifiers. @@ -101,7 +101,7 @@ class NaiveBayes @Since("1.5.0") ( def setModelType(value: String): this.type = set(modelType, value) setDefault(modelType -> OldNaiveBayes.Multinomial) - override protected def train(dataset: DataFrame): NaiveBayesModel = { + override protected def train(dataset: Dataset[_]): NaiveBayesModel = { val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset) val oldModel = OldNaiveBayes.train(oldDataset, $(smoothing), $(modelType)) NaiveBayesModel.fromOld(oldModel, this) |