From 8a634e9bcc671167613fb575c6c0c054fb4b3479 Mon Sep 17 00:00:00 2001 From: Nick Pritchard Date: Mon, 14 Sep 2015 13:27:45 -0700 Subject: [SPARK-10573] [ML] IndexToString output schema should be StringType Fixes bug where IndexToString output schema was DoubleType. Correct me if I'm wrong, but it doesn't seem like the output needs to have any "ML Attribute" metadata. Author: Nick Pritchard Closes #8751 from pnpritchard/SPARK-10573. --- .../main/scala/org/apache/spark/ml/feature/StringIndexer.scala | 5 ++--- .../scala/org/apache/spark/ml/feature/StringIndexerSuite.scala | 8 ++++++++ 2 files changed, 10 insertions(+), 3 deletions(-) (limited to 'mllib') diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala index 3a4ab9a857..2b1592930e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala @@ -27,7 +27,7 @@ import org.apache.spark.ml.Transformer import org.apache.spark.ml.util.Identifiable import org.apache.spark.sql.DataFrame import org.apache.spark.sql.functions._ -import org.apache.spark.sql.types.{DoubleType, NumericType, StringType, StructType} +import org.apache.spark.sql.types._ import org.apache.spark.util.collection.OpenHashMap /** @@ -229,8 +229,7 @@ class IndexToString private[ml] ( val outputColName = $(outputCol) require(inputFields.forall(_.name != outputColName), s"Output column $outputColName already exists.") - val attr = NominalAttribute.defaultAttr.withName($(outputCol)) - val outputFields = inputFields :+ attr.toStructField() + val outputFields = inputFields :+ StructField($(outputCol), StringType) StructType(outputFields) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala index 05e05bdc64..ddcdb5f421 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.ml.feature +import org.apache.spark.sql.types.{StringType, StructType, StructField, DoubleType} import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.ml.attribute.{Attribute, NominalAttribute} import org.apache.spark.ml.param.ParamsSuite @@ -165,4 +166,11 @@ class StringIndexerSuite extends SparkFunSuite with MLlibTestSparkContext { assert(a === b) } } + + test("IndexToString.transformSchema (SPARK-10573)") { + val idxToStr = new IndexToString().setInputCol("input").setOutputCol("output") + val inSchema = StructType(Seq(StructField("input", DoubleType))) + val outSchema = idxToStr.transformSchema(inSchema) + assert(outSchema("output").dataType === StringType) + } } -- cgit v1.2.3