aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
diff options
context:
space:
mode:
Diffstat (limited to 'mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala')
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala6
1 files changed, 3 insertions, 3 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
index cb42532271..9d80b8eb68 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
@@ -31,7 +31,7 @@ import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
import org.apache.spark.mllib.tree.model.{RandomForestModel => OldRandomForestModel}
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.functions._
@@ -98,7 +98,7 @@ final class RandomForestClassifier @Since("1.4.0") (
override def setFeatureSubsetStrategy(value: String): this.type =
super.setFeatureSubsetStrategy(value)
- override protected def train(dataset: DataFrame): RandomForestClassificationModel = {
+ override protected def train(dataset: Dataset[_]): RandomForestClassificationModel = {
val categoricalFeatures: Map[Int, Int] =
MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
val numClasses: Int = MetadataUtils.getNumClasses(dataset.schema($(labelCol))) match {
@@ -180,7 +180,7 @@ final class RandomForestClassificationModel private[ml] (
@Since("1.4.0")
override def treeWeights: Array[Double] = _treeWeights
- override protected def transformImpl(dataset: DataFrame): DataFrame = {
+ override protected def transformImpl(dataset: Dataset[_]): DataFrame = {
val bcastModel = dataset.sqlContext.sparkContext.broadcast(this)
val predictUDF = udf { (features: Any) =>
bcastModel.value.predict(features.asInstanceOf[Vector])