diff options
Diffstat (limited to 'mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala')
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala | 13 |
1 file changed, 8 insertions, 5 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala index faa0f6f407..7e0d374f02 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala @@ -26,7 +26,7 @@ import org.apache.spark.ml.attribute.{Attribute, NominalAttribute} import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util._ -import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.{DataFrame, Dataset} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ import org.apache.spark.util.collection.OpenHashMap @@ -80,7 +80,8 @@ class StringIndexer(override val uid: String) extends Estimator[StringIndexerMod def setOutputCol(value: String): this.type = set(outputCol, value) - override def fit(dataset: DataFrame): StringIndexerModel = { + @Since("2.0.0") + override def fit(dataset: Dataset[_]): StringIndexerModel = { val counts = dataset.select(col($(inputCol)).cast(StringType)) .rdd .map(_.getString(0)) @@ -144,11 +145,12 @@ class StringIndexerModel ( /** @group setParam */ def setOutputCol(value: String): this.type = set(outputCol, value) - override def transform(dataset: DataFrame): DataFrame = { + @Since("2.0.0") + override def transform(dataset: Dataset[_]): DataFrame = { if (!dataset.schema.fieldNames.contains($(inputCol))) { logInfo(s"Input column ${$(inputCol)} does not exist during transformation. 
" + "Skip StringIndexerModel.") - return dataset + return dataset.toDF } validateAndTransformSchema(dataset.schema) @@ -286,7 +288,8 @@ class IndexToString private[ml] (override val uid: String) StructType(outputFields) } - override def transform(dataset: DataFrame): DataFrame = { + @Since("2.0.0") + override def transform(dataset: Dataset[_]): DataFrame = { val inputColSchema = dataset.schema($(inputCol)) // If the labels array is empty use column metadata val values = if ($(labels).isEmpty) { |