aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala
diff options
context:
space:
mode:
Diffstat (limited to 'mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala')
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala7
1 files changed, 4 insertions, 3 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala
index 957e8e7a59..4d3e46e488 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala
@@ -27,7 +27,7 @@ import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.param.shared._
import org.apache.spark.ml.util._
import org.apache.spark.mllib.linalg.{Vector, Vectors, VectorUDT}
-import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.sql.{DataFrame, Dataset, Row}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
@@ -47,10 +47,11 @@ class VectorAssembler(override val uid: String)
/** @group setParam */
def setOutputCol(value: String): this.type = set(outputCol, value)
- override def transform(dataset: DataFrame): DataFrame = {
+ @Since("2.0.0")
+ override def transform(dataset: Dataset[_]): DataFrame = {
// Schema transformation.
val schema = dataset.schema
- lazy val first = dataset.first()
+ lazy val first = dataset.toDF.first()
val attrs = $(inputCols).flatMap { c =>
val field = schema(c)
val index = schema.fieldIndex(c)