diff options
Diffstat (limited to 'mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala')
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala | 7 |
1 files changed, 4 insertions, 3 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala index 957e8e7a59..4d3e46e488 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala @@ -27,7 +27,7 @@ import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util._ import org.apache.spark.mllib.linalg.{Vector, Vectors, VectorUDT} -import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.{DataFrame, Dataset, Row} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ @@ -47,10 +47,11 @@ class VectorAssembler(override val uid: String) /** @group setParam */ def setOutputCol(value: String): this.type = set(outputCol, value) - override def transform(dataset: DataFrame): DataFrame = { + @Since("2.0.0") + override def transform(dataset: Dataset[_]): DataFrame = { // Schema transformation. val schema = dataset.schema - lazy val first = dataset.first() + lazy val first = dataset.toDF.first() val attrs = $(inputCols).flatMap { c => val field = schema(c) val index = schema.fieldIndex(c) |