aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
author Xiangrui Meng <meng@databricks.com> 2015-04-28 17:41:09 -0700
committer Joseph K. Bradley <joseph@databricks.com> 2015-04-28 17:41:09 -0700
commit d36e67350c516a96d58abd50a0d5d22b3b22f291 (patch)
tree 06f22988f9d0e6909cfc3132d95882959c0b9090 /mllib
parent 555213ebbf2be2ee523be8665bd5b9a47ae4bec8 (diff)
download spark-d36e67350c516a96d58abd50a0d5d22b3b22f291.tar.gz
spark-d36e67350c516a96d58abd50a0d5d22b3b22f291.tar.bz2
spark-d36e67350c516a96d58abd50a0d5d22b3b22f291.zip
[SPARK-6965] [MLLIB] StringIndexer handles numeric input.
Cast numeric types to String for indexing. Boolean type is not handled in this PR.

jkbradley

Author: Xiangrui Meng <meng@databricks.com>

Closes #5753 from mengxr/SPARK-6965 and squashes the following commits:

2e34f3c [Xiangrui Meng] add actual type in the error message
ad938bf [Xiangrui Meng] StringIndexer handles numeric input.
Diffstat (limited to 'mllib')
-rw-r--r-- mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala 17
-rw-r--r-- mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala 19
2 files changed, 31 insertions, 5 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
index 23956c512c..9db3b29e10 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
@@ -23,10 +23,9 @@ import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.ml.attribute.NominalAttribute
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared._
-import org.apache.spark.ml.util.SchemaUtils
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions._
-import org.apache.spark.sql.types.{StringType, StructType}
+import org.apache.spark.sql.types.{NumericType, StringType, StructType}
import org.apache.spark.util.collection.OpenHashMap
/**
@@ -37,7 +36,11 @@ private[feature] trait StringIndexerBase extends Params with HasInputCol with Ha
/** Validates and transforms the input schema. */
protected def validateAndTransformSchema(schema: StructType, paramMap: ParamMap): StructType = {
val map = extractParamMap(paramMap)
- SchemaUtils.checkColumnType(schema, map(inputCol), StringType)
+ val inputColName = map(inputCol)
+ val inputDataType = schema(inputColName).dataType
+ require(inputDataType == StringType || inputDataType.isInstanceOf[NumericType],
+ s"The input column $inputColName must be either string type or numeric type, " +
+ s"but got $inputDataType.")
val inputFields = schema.fields
val outputColName = map(outputCol)
require(inputFields.forall(_.name != outputColName),
@@ -51,6 +54,7 @@ private[feature] trait StringIndexerBase extends Params with HasInputCol with Ha
/**
* :: AlphaComponent ::
* A label indexer that maps a string column of labels to an ML column of label indices.
+ * If the input column is numeric, we cast it to string and index the string values.
* The indices are in [0, numLabels), ordered by label frequencies.
* So the most frequent label gets index 0.
*/
@@ -67,7 +71,9 @@ class StringIndexer extends Estimator[StringIndexerModel] with StringIndexerBase
override def fit(dataset: DataFrame, paramMap: ParamMap): StringIndexerModel = {
val map = extractParamMap(paramMap)
- val counts = dataset.select(map(inputCol)).map(_.getString(0)).countByValue()
+ val counts = dataset.select(col(map(inputCol)).cast(StringType))
+ .map(_.getString(0))
+ .countByValue()
val labels = counts.toSeq.sortBy(-_._2).map(_._1).toArray
val model = new StringIndexerModel(this, map, labels)
Params.inheritValues(map, this, model)
@@ -119,7 +125,8 @@ class StringIndexerModel private[ml] (
val outputColName = map(outputCol)
val metadata = NominalAttribute.defaultAttr
.withName(outputColName).withValues(labels).toMetadata()
- dataset.select(col("*"), indexer(dataset(map(inputCol))).as(outputColName, metadata))
+ dataset.select(col("*"),
+ indexer(dataset(map(inputCol)).cast(StringType)).as(outputColName, metadata))
}
override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
index 00b5d094d8..b6939e5870 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
@@ -49,4 +49,23 @@ class StringIndexerSuite extends FunSuite with MLlibTestSparkContext {
val expected = Set((0, 0.0), (1, 2.0), (2, 1.0), (3, 0.0), (4, 0.0), (5, 1.0))
assert(output === expected)
}
+
+ test("StringIndexer with a numeric input column") {
+ val data = sc.parallelize(Seq((0, 100), (1, 200), (2, 300), (3, 100), (4, 100), (5, 300)), 2)
+ val df = sqlContext.createDataFrame(data).toDF("id", "label")
+ val indexer = new StringIndexer()
+ .setInputCol("label")
+ .setOutputCol("labelIndex")
+ .fit(df)
+ val transformed = indexer.transform(df)
+ val attr = Attribute.fromStructField(transformed.schema("labelIndex"))
+ .asInstanceOf[NominalAttribute]
+ assert(attr.values.get === Array("100", "300", "200"))
+ val output = transformed.select("id", "labelIndex").map { r =>
+ (r.getInt(0), r.getDouble(1))
+ }.collect().toSet
+ // 100 -> 0, 200 -> 2, 300 -> 1
+ val expected = Set((0, 0.0), (1, 2.0), (2, 1.0), (3, 0.0), (4, 0.0), (5, 1.0))
+ assert(output === expected)
+ }
}