aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
diff options
context:
space:
mode:
Diffstat (limited to 'mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala')
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala13
1 files changed, 8 insertions, 5 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
index faa0f6f407..7e0d374f02 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
@@ -26,7 +26,7 @@ import org.apache.spark.ml.attribute.{Attribute, NominalAttribute}
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared._
import org.apache.spark.ml.util._
-import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
import org.apache.spark.util.collection.OpenHashMap
@@ -80,7 +80,8 @@ class StringIndexer(override val uid: String) extends Estimator[StringIndexerMod
def setOutputCol(value: String): this.type = set(outputCol, value)
- override def fit(dataset: DataFrame): StringIndexerModel = {
+ @Since("2.0.0")
+ override def fit(dataset: Dataset[_]): StringIndexerModel = {
val counts = dataset.select(col($(inputCol)).cast(StringType))
.rdd
.map(_.getString(0))
@@ -144,11 +145,12 @@ class StringIndexerModel (
/** @group setParam */
def setOutputCol(value: String): this.type = set(outputCol, value)
- override def transform(dataset: DataFrame): DataFrame = {
+ @Since("2.0.0")
+ override def transform(dataset: Dataset[_]): DataFrame = {
if (!dataset.schema.fieldNames.contains($(inputCol))) {
logInfo(s"Input column ${$(inputCol)} does not exist during transformation. " +
"Skip StringIndexerModel.")
- return dataset
+ return dataset.toDF
}
validateAndTransformSchema(dataset.schema)
@@ -286,7 +288,8 @@ class IndexToString private[ml] (override val uid: String)
StructType(outputFields)
}
- override def transform(dataset: DataFrame): DataFrame = {
+ @Since("2.0.0")
+ override def transform(dataset: Dataset[_]): DataFrame = {
val inputColSchema = dataset.schema($(inputCol))
// If the labels array is empty use column metadata
val values = if ($(labels).isEmpty) {