diff options
Diffstat (limited to 'mllib/src/main/scala/org/apache/spark/ml/Estimator.scala')
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/ml/Estimator.scala | 16 |
1 files changed, 10 insertions, 6 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala b/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala index 57e416591d..1247882d6c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala @@ -19,9 +19,9 @@ package org.apache.spark.ml import scala.annotation.varargs -import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.annotation.{DeveloperApi, Since} import org.apache.spark.ml.param.{ParamMap, ParamPair} -import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.Dataset /** * :: DeveloperApi :: @@ -39,8 +39,9 @@ abstract class Estimator[M <: Model[M]] extends PipelineStage { * Estimator's embedded ParamMap. * @return fitted model */ + @Since("2.0.0") @varargs - def fit(dataset: DataFrame, firstParamPair: ParamPair[_], otherParamPairs: ParamPair[_]*): M = { + def fit(dataset: Dataset[_], firstParamPair: ParamPair[_], otherParamPairs: ParamPair[_]*): M = { val map = new ParamMap() .put(firstParamPair) .put(otherParamPairs: _*) @@ -55,14 +56,16 @@ abstract class Estimator[M <: Model[M]] extends PipelineStage { * These values override any specified in this Estimator's embedded ParamMap. * @return fitted model */ - def fit(dataset: DataFrame, paramMap: ParamMap): M = { + @Since("2.0.0") + def fit(dataset: Dataset[_], paramMap: ParamMap): M = { copy(paramMap).fit(dataset) } /** * Fits a model to the input data. */ - def fit(dataset: DataFrame): M + @Since("2.0.0") + def fit(dataset: Dataset[_]): M /** * Fits multiple models to the input data with multiple sets of parameters. @@ -74,7 +77,8 @@ abstract class Estimator[M <: Model[M]] extends PipelineStage { * These values override any specified in this Estimator's embedded ParamMap. * @return fitted models, matching the input parameter maps */ - def fit(dataset: DataFrame, paramMaps: Array[ParamMap]): Seq[M] = { + @Since("2.0.0") + def fit(dataset: Dataset[_], paramMaps: Array[ParamMap]): Seq[M] = { paramMaps.map(fit(dataset, _)) } |