aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/test
diff options
context:
space:
mode:
authorXiangrui Meng <meng@databricks.com>2015-04-12 22:41:05 -0700
committerXiangrui Meng <meng@databricks.com>2015-04-12 22:41:05 -0700
commit685ddcf5253c0ecb39853802431e22b0c7b61dee (patch)
treebdca5d046db951a373eda0dd0c27a7e34531978f /mllib/src/test
parentd3792f54974e16cbe8f10b3091d248e0bdd48986 (diff)
downloadspark-685ddcf5253c0ecb39853802431e22b0c7b61dee.tar.gz
spark-685ddcf5253c0ecb39853802431e22b0c7b61dee.tar.bz2
spark-685ddcf5253c0ecb39853802431e22b0c7b61dee.zip
[SPARK-5886][ML] Add StringIndexer as a feature transformer
This PR adds string indexer, which takes a column of string labels and outputs a double column with labels indexed by their frequency. TODOs: - [x] store feature to index map in output metadata Author: Xiangrui Meng <meng@databricks.com> Closes #4735 from mengxr/SPARK-5886 and squashes the following commits: d82575f [Xiangrui Meng] fix test 700e70f [Xiangrui Meng] rename LabelIndexer to StringIndexer 16a6f8c [Xiangrui Meng] Merge remote-tracking branch 'apache/master' into SPARK-5886 457166e [Xiangrui Meng] Merge remote-tracking branch 'apache/master' into SPARK-5886 f8b30f4 [Xiangrui Meng] update label indexer to output metadata e81ec28 [Xiangrui Meng] Merge branch 'openhashmap-contains' into SPARK-5886-2 d6e6f1f [Xiangrui Meng] add contains to primitivekeyopenhashmap 748a69b [Xiangrui Meng] add contains to OpenHashMap def3c5c [Xiangrui Meng] add LabelIndexer
Diffstat (limited to 'mllib/src/test')
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala52
1 files changed, 52 insertions, 0 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
new file mode 100644
index 0000000000..00b5d094d8
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
@@ -0,0 +1,52 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.feature
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.ml.attribute.{Attribute, NominalAttribute}
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.sql.SQLContext
+
+class StringIndexerSuite extends FunSuite with MLlibTestSparkContext {
+ private var sqlContext: SQLContext = _
+
+ override def beforeAll(): Unit = {
+ super.beforeAll()
+ sqlContext = new SQLContext(sc)
+ }
+
+ test("StringIndexer") {
+ val data = sc.parallelize(Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")), 2)
+ val df = sqlContext.createDataFrame(data).toDF("id", "label")
+ val indexer = new StringIndexer()
+ .setInputCol("label")
+ .setOutputCol("labelIndex")
+ .fit(df)
+ val transformed = indexer.transform(df)
+ val attr = Attribute.fromStructField(transformed.schema("labelIndex"))
+ .asInstanceOf[NominalAttribute]
+ assert(attr.values.get === Array("a", "c", "b"))
+ val output = transformed.select("id", "labelIndex").map { r =>
+ (r.getInt(0), r.getDouble(1))
+ }.collect().toSet
+ // a -> 0, b -> 2, c -> 1
+ val expected = Set((0, 0.0), (1, 2.0), (2, 1.0), (3, 0.0), (4, 0.0), (5, 1.0))
+ assert(output === expected)
+ }
+}