diff options
author | Nick Pritchard <nicholas.pritchard@falkonry.com> | 2015-09-14 13:27:45 -0700 |
---|---|---|
committer | Xiangrui Meng <meng@databricks.com> | 2015-09-14 13:27:45 -0700 |
commit | 8a634e9bcc671167613fb575c6c0c054fb4b3479 (patch) | |
tree | 48ecf76230e052fce7676bc7f03fd308b20e7ee3 /mllib | |
parent | ce6f3f163bc667cb5da9ab4331c8bad10cc0d701 (diff) | |
download | spark-8a634e9bcc671167613fb575c6c0c054fb4b3479.tar.gz spark-8a634e9bcc671167613fb575c6c0c054fb4b3479.tar.bz2 spark-8a634e9bcc671167613fb575c6c0c054fb4b3479.zip |
[SPARK-10573] [ML] IndexToString output schema should be StringType
Fixes bug where IndexToString output schema was DoubleType. Correct me if I'm wrong, but it doesn't seem like the output needs to have any "ML Attribute" metadata.
Author: Nick Pritchard <nicholas.pritchard@falkonry.com>
Closes #8751 from pnpritchard/SPARK-10573.
Diffstat (limited to 'mllib')
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala | 5 | ||||
-rw-r--r-- | mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala | 8 |
2 files changed, 10 insertions, 3 deletions
```diff
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
index 3a4ab9a857..2b1592930e 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
@@ -27,7 +27,7 @@ import org.apache.spark.ml.Transformer
 import org.apache.spark.ml.util.Identifiable
 import org.apache.spark.sql.DataFrame
 import org.apache.spark.sql.functions._
-import org.apache.spark.sql.types.{DoubleType, NumericType, StringType, StructType}
+import org.apache.spark.sql.types._
 import org.apache.spark.util.collection.OpenHashMap
 
 /**
@@ -229,8 +229,7 @@ class IndexToString private[ml] (
     val outputColName = $(outputCol)
     require(inputFields.forall(_.name != outputColName),
       s"Output column $outputColName already exists.")
-    val attr = NominalAttribute.defaultAttr.withName($(outputCol))
-    val outputFields = inputFields :+ attr.toStructField()
+    val outputFields = inputFields :+ StructField($(outputCol), StringType)
     StructType(outputFields)
   }
 
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
index 05e05bdc64..ddcdb5f421 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
@@ -17,6 +17,7 @@
 
 package org.apache.spark.ml.feature
 
+import org.apache.spark.sql.types.{StringType, StructType, StructField, DoubleType}
 import org.apache.spark.{SparkException, SparkFunSuite}
 import org.apache.spark.ml.attribute.{Attribute, NominalAttribute}
 import org.apache.spark.ml.param.ParamsSuite
@@ -165,4 +166,11 @@ class StringIndexerSuite extends SparkFunSuite with MLlibTestSparkContext {
         assert(a === b)
     }
   }
+
+  test("IndexToString.transformSchema (SPARK-10573)") {
+    val idxToStr = new IndexToString().setInputCol("input").setOutputCol("output")
+    val inSchema = StructType(Seq(StructField("input", DoubleType)))
+    val outSchema = idxToStr.transformSchema(inSchema)
+    assert(outSchema("output").dataType === StringType)
+  }
 }
```