Diffstat (limited to 'mllib')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala      | 5 ++---
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala | 8 ++++++++
2 files changed, 10 insertions(+), 3 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
index 3a4ab9a857..2b1592930e 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
@@ -27,7 +27,7 @@ import org.apache.spark.ml.Transformer
 import org.apache.spark.ml.util.Identifiable
 import org.apache.spark.sql.DataFrame
 import org.apache.spark.sql.functions._
-import org.apache.spark.sql.types.{DoubleType, NumericType, StringType, StructType}
+import org.apache.spark.sql.types._
 import org.apache.spark.util.collection.OpenHashMap

 /**
@@ -229,8 +229,7 @@ class IndexToString private[ml] (
     val outputColName = $(outputCol)
     require(inputFields.forall(_.name != outputColName),
       s"Output column $outputColName already exists.")
-    val attr = NominalAttribute.defaultAttr.withName($(outputCol))
-    val outputFields = inputFields :+ attr.toStructField()
+    val outputFields = inputFields :+ StructField($(outputCol), StringType)
     StructType(outputFields)
   }

diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
index 05e05bdc64..ddcdb5f421 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
@@ -17,6 +17,7 @@

 package org.apache.spark.ml.feature

+import org.apache.spark.sql.types.{StringType, StructType, StructField, DoubleType}
 import org.apache.spark.{SparkException, SparkFunSuite}
 import org.apache.spark.ml.attribute.{Attribute, NominalAttribute}
 import org.apache.spark.ml.param.ParamsSuite
@@ -165,4 +166,11 @@ class StringIndexerSuite extends SparkFunSuite with MLlibTestSparkContext {
       assert(a === b)
     }
   }
+
+  test("IndexToString.transformSchema (SPARK-10573)") {
+    val idxToStr = new IndexToString().setInputCol("input").setOutputCol("output")
+    val inSchema = StructType(Seq(StructField("input", DoubleType)))
+    val outSchema = idxToStr.transformSchema(inSchema)
+    assert(outSchema("output").dataType === StringType)
+  }
 }
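
A minimal sketch of what the new test exercises, assuming code compiled inside the org.apache.spark.ml.feature package (the IndexToString constructor is private[ml] in this version, per the class header in the hunk above); the wrapper object name and the "input"/"output" column names are illustrative only. It shows that transformSchema now declares the output column as a plain StringType field, matching the strings that transform() actually produces, where it previously derived a nominal (DoubleType) field from NominalAttribute.

package org.apache.spark.ml.feature

import org.apache.spark.sql.types.{DoubleType, StringType, StructField, StructType}

// Hypothetical wrapper object, placed in the ml.feature package only because
// the IndexToString constructor is private[ml] in this version.
object IndexToStringSchemaSketch {
  def main(args: Array[String]): Unit = {
    // Illustrative column names; any numeric input column behaves the same way.
    val idxToStr = new IndexToString()
      .setInputCol("input")
      .setOutputCol("output")

    // transformSchema only inspects the schema, so no SparkContext is needed.
    val inSchema = StructType(Seq(StructField("input", DoubleType)))
    val outSchema = idxToStr.transformSchema(inSchema)

    // With this patch the declared output field is StringType, consistent with
    // the label strings that transform() produces at runtime.
    assert(outSchema("output").dataType === StringType)
  }
}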