From d36e67350c516a96d58abd50a0d5d22b3b22f291 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Tue, 28 Apr 2015 17:41:09 -0700 Subject: [SPARK-6965] [MLLIB] StringIndexer handles numeric input. Cast numeric types to String for indexing. Boolean type is not handled in this PR. jkbradley Author: Xiangrui Meng Closes #5753 from mengxr/SPARK-6965 and squashes the following commits: 2e34f3c [Xiangrui Meng] add actual type in the error message ad938bf [Xiangrui Meng] StringIndexer handles numeric input. --- .../org/apache/spark/ml/feature/StringIndexer.scala | 17 ++++++++++++----- .../apache/spark/ml/feature/StringIndexerSuite.scala | 19 +++++++++++++++++++ 2 files changed, 31 insertions(+), 5 deletions(-) (limited to 'mllib') diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala index 23956c512c..9db3b29e10 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala @@ -23,10 +23,9 @@ import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.attribute.NominalAttribute import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ -import org.apache.spark.ml.util.SchemaUtils import org.apache.spark.sql.DataFrame import org.apache.spark.sql.functions._ -import org.apache.spark.sql.types.{StringType, StructType} +import org.apache.spark.sql.types.{NumericType, StringType, StructType} import org.apache.spark.util.collection.OpenHashMap /** @@ -37,7 +36,11 @@ private[feature] trait StringIndexerBase extends Params with HasInputCol with Ha /** Validates and transforms the input schema. */ protected def validateAndTransformSchema(schema: StructType, paramMap: ParamMap): StructType = { val map = extractParamMap(paramMap) - SchemaUtils.checkColumnType(schema, map(inputCol), StringType) + val inputColName = map(inputCol) + val inputDataType = schema(inputColName).dataType + require(inputDataType == StringType || inputDataType.isInstanceOf[NumericType], + s"The input column $inputColName must be either string type or numeric type, " + + s"but got $inputDataType.") val inputFields = schema.fields val outputColName = map(outputCol) require(inputFields.forall(_.name != outputColName), @@ -51,6 +54,7 @@ private[feature] trait StringIndexerBase extends Params with HasInputCol with Ha /** * :: AlphaComponent :: * A label indexer that maps a string column of labels to an ML column of label indices. + * If the input column is numeric, we cast it to string and index the string values. * The indices are in [0, numLabels), ordered by label frequencies. * So the most frequent label gets index 0. */ @@ -67,7 +71,9 @@ class StringIndexer extends Estimator[StringIndexerModel] with StringIndexerBase override def fit(dataset: DataFrame, paramMap: ParamMap): StringIndexerModel = { val map = extractParamMap(paramMap) - val counts = dataset.select(map(inputCol)).map(_.getString(0)).countByValue() + val counts = dataset.select(col(map(inputCol)).cast(StringType)) + .map(_.getString(0)) + .countByValue() val labels = counts.toSeq.sortBy(-_._2).map(_._1).toArray val model = new StringIndexerModel(this, map, labels) Params.inheritValues(map, this, model) @@ -119,7 +125,8 @@ class StringIndexerModel private[ml] ( val outputColName = map(outputCol) val metadata = NominalAttribute.defaultAttr .withName(outputColName).withValues(labels).toMetadata() - dataset.select(col("*"), indexer(dataset(map(inputCol))).as(outputColName, metadata)) + dataset.select(col("*"), + indexer(dataset(map(inputCol)).cast(StringType)).as(outputColName, metadata)) } override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala index 00b5d094d8..b6939e5870 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala @@ -49,4 +49,23 @@ class StringIndexerSuite extends FunSuite with MLlibTestSparkContext { val expected = Set((0, 0.0), (1, 2.0), (2, 1.0), (3, 0.0), (4, 0.0), (5, 1.0)) assert(output === expected) } + + test("StringIndexer with a numeric input column") { + val data = sc.parallelize(Seq((0, 100), (1, 200), (2, 300), (3, 100), (4, 100), (5, 300)), 2) + val df = sqlContext.createDataFrame(data).toDF("id", "label") + val indexer = new StringIndexer() + .setInputCol("label") + .setOutputCol("labelIndex") + .fit(df) + val transformed = indexer.transform(df) + val attr = Attribute.fromStructField(transformed.schema("labelIndex")) + .asInstanceOf[NominalAttribute] + assert(attr.values.get === Array("100", "300", "200")) + val output = transformed.select("id", "labelIndex").map { r => + (r.getInt(0), r.getDouble(1)) + }.collect().toSet + // 100 -> 0, 200 -> 2, 300 -> 1 + val expected = Set((0, 0.0), (1, 2.0), (2, 1.0), (3, 0.0), (4, 0.0), (5, 1.0)) + assert(output === expected) + } } -- cgit v1.2.3