aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala
diff options
context:
space:
mode:
Diffstat (limited to 'mllib/src/main/scala/org/apache/spark/ml/Estimator.scala')
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/Estimator.scala16
1 files changed, 10 insertions, 6 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala b/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala
index 57e416591d..1247882d6c 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala
@@ -19,9 +19,9 @@ package org.apache.spark.ml
import scala.annotation.varargs
-import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.annotation.{DeveloperApi, Since}
import org.apache.spark.ml.param.{ParamMap, ParamPair}
-import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.Dataset
/**
* :: DeveloperApi ::
@@ -39,8 +39,9 @@ abstract class Estimator[M <: Model[M]] extends PipelineStage {
* Estimator's embedded ParamMap.
* @return fitted model
*/
+ @Since("2.0.0")
@varargs
- def fit(dataset: DataFrame, firstParamPair: ParamPair[_], otherParamPairs: ParamPair[_]*): M = {
+ def fit(dataset: Dataset[_], firstParamPair: ParamPair[_], otherParamPairs: ParamPair[_]*): M = {
val map = new ParamMap()
.put(firstParamPair)
.put(otherParamPairs: _*)
@@ -55,14 +56,16 @@ abstract class Estimator[M <: Model[M]] extends PipelineStage {
* These values override any specified in this Estimator's embedded ParamMap.
* @return fitted model
*/
- def fit(dataset: DataFrame, paramMap: ParamMap): M = {
+ @Since("2.0.0")
+ def fit(dataset: Dataset[_], paramMap: ParamMap): M = {
copy(paramMap).fit(dataset)
}
/**
* Fits a model to the input data.
*/
- def fit(dataset: DataFrame): M
+ @Since("2.0.0")
+ def fit(dataset: Dataset[_]): M
/**
* Fits multiple models to the input data with multiple sets of parameters.
@@ -74,7 +77,8 @@ abstract class Estimator[M <: Model[M]] extends PipelineStage {
* These values override any specified in this Estimator's embedded ParamMap.
* @return fitted models, matching the input parameter maps
*/
- def fit(dataset: DataFrame, paramMaps: Array[ParamMap]): Seq[M] = {
+ @Since("2.0.0")
+ def fit(dataset: Dataset[_], paramMaps: Array[ParamMap]): Seq[M] = {
paramMaps.map(fit(dataset, _))
}