diff options
Diffstat (limited to 'mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala')
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala | 14 |
1 files changed, 7 insertions, 7 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala index 3a99979a88..82066726a0 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala @@ -31,7 +31,7 @@ import org.apache.spark.annotation.{DeveloperApi, Experimental, Since} import org.apache.spark.internal.Logging import org.apache.spark.ml.param.{Param, ParamMap, Params} import org.apache.spark.ml.util._ -import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.{DataFrame, Dataset} import org.apache.spark.sql.types.StructType /** @@ -123,8 +123,8 @@ class Pipeline @Since("1.4.0") ( * @param dataset input dataset * @return fitted pipeline */ - @Since("1.2.0") - override def fit(dataset: DataFrame): PipelineModel = { + @Since("2.0.0") + override def fit(dataset: Dataset[_]): PipelineModel = { transformSchema(dataset.schema, logging = true) val theStages = $(stages) // Search for the last estimator. @@ -147,7 +147,7 @@ class Pipeline @Since("1.4.0") ( t case _ => throw new IllegalArgumentException( - s"Do not support stage $stage of type ${stage.getClass}") + s"Does not support stage $stage of type ${stage.getClass}") } if (index < indexOfLastEstimator) { curDataset = transformer.transform(curDataset) @@ -291,10 +291,10 @@ class PipelineModel private[ml] ( this(uid, stages.asScala.toArray) } - @Since("1.2.0") - override def transform(dataset: DataFrame): DataFrame = { + @Since("2.0.0") + override def transform(dataset: Dataset[_]): DataFrame = { transformSchema(dataset.schema, logging = true) - stages.foldLeft(dataset)((cur, transformer) => transformer.transform(cur)) + stages.foldLeft(dataset.toDF)((cur, transformer) => transformer.transform(cur)) } @Since("1.2.0") |