aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorNick Pritchard <nicholas.pritchard@falkonry.com>2015-09-14 13:27:45 -0700
committerXiangrui Meng <meng@databricks.com>2015-09-14 13:27:45 -0700
commit8a634e9bcc671167613fb575c6c0c054fb4b3479 (patch)
tree48ecf76230e052fce7676bc7f03fd308b20e7ee3
parentce6f3f163bc667cb5da9ab4331c8bad10cc0d701 (diff)
downloadspark-8a634e9bcc671167613fb575c6c0c054fb4b3479.tar.gz
spark-8a634e9bcc671167613fb575c6c0c054fb4b3479.tar.bz2
spark-8a634e9bcc671167613fb575c6c0c054fb4b3479.zip
[SPARK-10573] [ML] IndexToString output schema should be StringType
Fixes bug where IndexToString output schema was DoubleType. Correct me if I'm wrong, but it doesn't seem like the output needs to have any "ML Attribute" metadata. Author: Nick Pritchard <nicholas.pritchard@falkonry.com> Closes #8751 from pnpritchard/SPARK-10573.
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala5
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala8
2 files changed, 10 insertions, 3 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
index 3a4ab9a857..2b1592930e 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
@@ -27,7 +27,7 @@ import org.apache.spark.ml.Transformer
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions._
-import org.apache.spark.sql.types.{DoubleType, NumericType, StringType, StructType}
+import org.apache.spark.sql.types._
import org.apache.spark.util.collection.OpenHashMap
/**
@@ -229,8 +229,7 @@ class IndexToString private[ml] (
val outputColName = $(outputCol)
require(inputFields.forall(_.name != outputColName),
s"Output column $outputColName already exists.")
- val attr = NominalAttribute.defaultAttr.withName($(outputCol))
- val outputFields = inputFields :+ attr.toStructField()
+ val outputFields = inputFields :+ StructField($(outputCol), StringType)
StructType(outputFields)
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
index 05e05bdc64..ddcdb5f421 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
@@ -17,6 +17,7 @@
package org.apache.spark.ml.feature
+import org.apache.spark.sql.types.{StringType, StructType, StructField, DoubleType}
import org.apache.spark.{SparkException, SparkFunSuite}
import org.apache.spark.ml.attribute.{Attribute, NominalAttribute}
import org.apache.spark.ml.param.ParamsSuite
@@ -165,4 +166,11 @@ class StringIndexerSuite extends SparkFunSuite with MLlibTestSparkContext {
assert(a === b)
}
}
+
+ test("IndexToString.transformSchema (SPARK-10573)") {
+ val idxToStr = new IndexToString().setInputCol("input").setOutputCol("output")
+ val inSchema = StructType(Seq(StructField("input", DoubleType)))
+ val outSchema = idxToStr.transformSchema(inSchema)
+ assert(outSchema("output").dataType === StringType)
+ }
}