aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
diff options
context:
space:
mode:
Diffstat (limited to 'mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala')
-rw-r--r-- mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala 14
1 files changed, 7 insertions, 7 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
index 3a99979a88..82066726a0 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
@@ -31,7 +31,7 @@ import org.apache.spark.annotation.{DeveloperApi, Experimental, Since}
import org.apache.spark.internal.Logging
import org.apache.spark.ml.param.{Param, ParamMap, Params}
import org.apache.spark.ml.util._
-import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.types.StructType
/**
@@ -123,8 +123,8 @@ class Pipeline @Since("1.4.0") (
* @param dataset input dataset
* @return fitted pipeline
*/
- @Since("1.2.0")
- override def fit(dataset: DataFrame): PipelineModel = {
+ @Since("2.0.0")
+ override def fit(dataset: Dataset[_]): PipelineModel = {
transformSchema(dataset.schema, logging = true)
val theStages = $(stages)
// Search for the last estimator.
@@ -147,7 +147,7 @@ class Pipeline @Since("1.4.0") (
t
case _ =>
throw new IllegalArgumentException(
- s"Do not support stage $stage of type ${stage.getClass}")
+ s"Does not support stage $stage of type ${stage.getClass}")
}
if (index < indexOfLastEstimator) {
curDataset = transformer.transform(curDataset)
@@ -291,10 +291,10 @@ class PipelineModel private[ml] (
this(uid, stages.asScala.toArray)
}
- @Since("1.2.0")
- override def transform(dataset: DataFrame): DataFrame = {
+ @Since("2.0.0")
+ override def transform(dataset: Dataset[_]): DataFrame = {
transformSchema(dataset.schema, logging = true)
- stages.foldLeft(dataset)((cur, transformer) => transformer.transform(cur))
+ stages.foldLeft(dataset.toDF)((cur, transformer) => transformer.transform(cur))
}
@Since("1.2.0")