aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
authorNick Pritchard <nicholas.pritchard@falkonry.com>2015-09-14 13:27:45 -0700
committerXiangrui Meng <meng@databricks.com>2015-09-14 13:27:45 -0700
commit8a634e9bcc671167613fb575c6c0c054fb4b3479 (patch)
tree48ecf76230e052fce7676bc7f03fd308b20e7ee3 /mllib
parentce6f3f163bc667cb5da9ab4331c8bad10cc0d701 (diff)
downloadspark-8a634e9bcc671167613fb575c6c0c054fb4b3479.tar.gz
spark-8a634e9bcc671167613fb575c6c0c054fb4b3479.tar.bz2
spark-8a634e9bcc671167613fb575c6c0c054fb4b3479.zip
[SPARK-10573] [ML] IndexToString output schema should be StringType
Fixes bug where IndexToString output schema was DoubleType. Correct me if I'm wrong, but it doesn't seem like the output needs to have any "ML Attribute" metadata. Author: Nick Pritchard <nicholas.pritchard@falkonry.com> Closes #8751 from pnpritchard/SPARK-10573.
Diffstat (limited to 'mllib')
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala5
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala8
2 files changed, 10 insertions, 3 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
index 3a4ab9a857..2b1592930e 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
@@ -27,7 +27,7 @@ import org.apache.spark.ml.Transformer
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions._
-import org.apache.spark.sql.types.{DoubleType, NumericType, StringType, StructType}
+import org.apache.spark.sql.types._
import org.apache.spark.util.collection.OpenHashMap
/**
@@ -229,8 +229,7 @@ class IndexToString private[ml] (
val outputColName = $(outputCol)
require(inputFields.forall(_.name != outputColName),
s"Output column $outputColName already exists.")
- val attr = NominalAttribute.defaultAttr.withName($(outputCol))
- val outputFields = inputFields :+ attr.toStructField()
+ val outputFields = inputFields :+ StructField($(outputCol), StringType)
StructType(outputFields)
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
index 05e05bdc64..ddcdb5f421 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
@@ -17,6 +17,7 @@
package org.apache.spark.ml.feature
+import org.apache.spark.sql.types.{StringType, StructType, StructField, DoubleType}
import org.apache.spark.{SparkException, SparkFunSuite}
import org.apache.spark.ml.attribute.{Attribute, NominalAttribute}
import org.apache.spark.ml.param.ParamsSuite
@@ -165,4 +166,11 @@ class StringIndexerSuite extends SparkFunSuite with MLlibTestSparkContext {
assert(a === b)
}
}
+
+ test("IndexToString.transformSchema (SPARK-10573)") {
+ val idxToStr = new IndexToString().setInputCol("input").setOutputCol("output")
+ val inSchema = StructType(Seq(StructField("input", DoubleType)))
+ val outSchema = idxToStr.transformSchema(inSchema)
+ assert(outSchema("output").dataType === StringType)
+ }
}