aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--docs/ml-features.md104
-rw-r--r--examples/src/main/java/org/apache/spark/examples/ml/JavaIndexToStringExample.java75
-rw-r--r--examples/src/main/python/ml/index_to_string_example.py45
-rw-r--r--examples/src/main/scala/org/apache/spark/examples/ml/IndexToStringExample.scala60
4 files changed, 268 insertions, 16 deletions
diff --git a/docs/ml-features.md b/docs/ml-features.md
index 01d6abeb5b..e15c26836a 100644
--- a/docs/ml-features.md
+++ b/docs/ml-features.md
@@ -835,10 +835,10 @@ dctDf.select("featuresDCT").show(3);
`StringIndexer` encodes a string column of labels to a column of label indices.
The indices are in `[0, numLabels)`, ordered by label frequencies.
So the most frequent label gets index `0`.
-If the input column is numeric, we cast it to string and index the string
-values. When downstream pipeline components such as `Estimator` or
-`Transformer` make use of this string-indexed label, you must set the input
-column of the component to this string-indexed column name. In many cases,
+If the input column is numeric, we cast it to string and index the string
+values. When downstream pipeline components such as `Estimator` or
+`Transformer` make use of this string-indexed label, you must set the input
+column of the component to this string-indexed column name. In many cases,
you can set the input column with `setInputCol`.
**Examples**
@@ -951,9 +951,78 @@ indexed.show()
</div>
</div>
+
+## IndexToString
+
+Symmetrically to `StringIndexer`, `IndexToString` maps a column of label indices
+back to a column containing the original labels as strings. The common use case
+is to produce indices from labels with `StringIndexer`, train a model with those
+indices and retrieve the original labels from the column of predicted indices
+with `IndexToString`. However, you are free to supply your own labels.
+
+**Examples**
+
+Building on the `StringIndexer` example, let's assume we have the following
+DataFrame with columns `id` and `categoryIndex`:
+
+~~~~
+ id | categoryIndex
+----|---------------
+ 0 | 0.0
+ 1 | 2.0
+ 2 | 1.0
+ 3 | 0.0
+ 4 | 0.0
+ 5 | 1.0
+~~~~
+
+Applying `IndexToString` with `categoryIndex` as the input column,
+`originalCategory` as the output column, we are able to retrieve our original
+labels (they will be inferred from the column's metadata):
+
+~~~~
+ id | categoryIndex | originalCategory
+----|---------------|-----------------
+ 0 | 0.0 | a
+ 1 | 2.0 | b
+ 2 | 1.0 | c
+ 3 | 0.0 | a
+ 4 | 0.0 | a
+ 5 | 1.0 | c
+~~~~
+
+<div class="codetabs">
+<div data-lang="scala" markdown="1">
+
+Refer to the [IndexToString Scala docs](api/scala/index.html#org.apache.spark.ml.feature.IndexToString)
+for more details on the API.
+
+{% include_example scala/org/apache/spark/examples/ml/IndexToStringExample.scala %}
+
+</div>
+
+<div data-lang="java" markdown="1">
+
+Refer to the [IndexToString Java docs](api/java/org/apache/spark/ml/feature/IndexToString.html)
+for more details on the API.
+
+{% include_example java/org/apache/spark/examples/ml/JavaIndexToStringExample.java %}
+
+</div>
+
+<div data-lang="python" markdown="1">
+
+Refer to the [IndexToString Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.IndexToString)
+for more details on the API.
+
+{% include_example python/ml/index_to_string_example.py %}
+
+</div>
+</div>
+
## OneHotEncoder
-[One-hot encoding](http://en.wikipedia.org/wiki/One-hot) maps a column of label indices to a column of binary vectors, with at most a single one-value. This encoding allows algorithms which expect continuous features, such as Logistic Regression, to use categorical features
+[One-hot encoding](http://en.wikipedia.org/wiki/One-hot) maps a column of label indices to a column of binary vectors, with at most a single one-value. This encoding allows algorithms which expect continuous features, such as Logistic Regression, to use categorical features.
<div class="codetabs">
<div data-lang="scala" markdown="1">
@@ -979,10 +1048,11 @@ val indexer = new StringIndexer()
.fit(df)
val indexed = indexer.transform(df)
-val encoder = new OneHotEncoder().setInputCol("categoryIndex").
- setOutputCol("categoryVec")
+val encoder = new OneHotEncoder()
+ .setInputCol("categoryIndex")
+ .setOutputCol("categoryVec")
val encoded = encoder.transform(indexed)
-encoded.select("id", "categoryVec").foreach(println)
+encoded.select("id", "categoryVec").show()
{% endhighlight %}
</div>
@@ -1015,7 +1085,7 @@ JavaRDD<Row> jrdd = jsc.parallelize(Arrays.asList(
RowFactory.create(5, "c")
));
StructType schema = new StructType(new StructField[]{
- new StructField("id", DataTypes.DoubleType, false, Metadata.empty()),
+ new StructField("id", DataTypes.IntegerType, false, Metadata.empty()),
new StructField("category", DataTypes.StringType, false, Metadata.empty())
});
DataFrame df = sqlContext.createDataFrame(jrdd, schema);
@@ -1029,6 +1099,7 @@ OneHotEncoder encoder = new OneHotEncoder()
.setInputCol("categoryIndex")
.setOutputCol("categoryVec");
DataFrame encoded = encoder.transform(indexed);
+encoded.select("id", "categoryVec").show();
{% endhighlight %}
</div>
@@ -1054,6 +1125,7 @@ model = stringIndexer.fit(df)
indexed = model.transform(df)
encoder = OneHotEncoder(includeFirst=False, inputCol="categoryIndex", outputCol="categoryVec")
encoded = encoder.transform(indexed)
+encoded.select("id", "categoryVec").show()
{% endhighlight %}
</div>
</div>
@@ -1582,7 +1654,7 @@ from pyspark.mllib.linalg import Vectors
data = [(Vectors.dense([1.0, 2.0, 3.0]),), (Vectors.dense([4.0, 5.0, 6.0]),)]
df = sqlContext.createDataFrame(data, ["vector"])
-transformer = ElementwiseProduct(scalingVec=Vectors.dense([0.0, 1.0, 2.0]),
+transformer = ElementwiseProduct(scalingVec=Vectors.dense([0.0, 1.0, 2.0]),
inputCol="vector", outputCol="transformedVector")
transformer.transform(df).show()
@@ -1837,15 +1909,15 @@ for more details on the API.
sub-array of the original features. It is useful for extracting features from a vector column.
`VectorSlicer` accepts a vector column with specified indices, then outputs a new vector column
-whose values are selected via those indices. There are two types of indices,
+whose values are selected via those indices. There are two types of indices,
1. Integer indices that represents the indices into the vector, `setIndices()`;
- 2. String indices that represents the names of features into the vector, `setNames()`.
+ 2. String indices that represent the names of features into the vector, `setNames()`.
*This requires the vector column to have an `AttributeGroup` since the implementation matches on
the name field of an `Attribute`.*
-Specification by integer and string are both acceptable. Moreover, you can use integer index and
+Specification by integer and string are both acceptable. Moreover, you can use integer index and
string name simultaneously. At least one feature must be selected. Duplicate features are not
allowed, so there can be no overlap between selected indices and names. Note that if names of
features are selected, an exception will be thrown when encountering empty input attributes.
@@ -1858,9 +1930,9 @@ followed by the selected names (in the order given).
Suppose that we have a DataFrame with the column `userFeatures`:
~~~
- userFeatures
+ userFeatures
------------------
- [0.0, 10.0, 0.5]
+ [0.0, 10.0, 0.5]
~~~
`userFeatures` is a vector column that contains three user features. Assuming that the first column
@@ -1874,7 +1946,7 @@ column named `features`:
[0.0, 10.0, 0.5] | [10.0, 0.5]
~~~
-Suppose also that we have a potential input attributes for the `userFeatures`, i.e.
+Suppose also that we have potential input attributes for the `userFeatures`, i.e.
`["f1", "f2", "f3"]`, then we can use `setNames("f2", "f3")` to select them.
~~~
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaIndexToStringExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaIndexToStringExample.java
new file mode 100644
index 0000000000..3ccd699326
--- /dev/null
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaIndexToStringExample.java
@@ -0,0 +1,75 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.examples.ml;
+
+import org.apache.spark.SparkConf;
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.sql.SQLContext;
+
+// $example on$
+import java.util.Arrays;
+
+import org.apache.spark.ml.feature.IndexToString;
+import org.apache.spark.ml.feature.StringIndexer;
+import org.apache.spark.ml.feature.StringIndexerModel;
+import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Row;
+import org.apache.spark.sql.RowFactory;
+import org.apache.spark.sql.types.DataTypes;
+import org.apache.spark.sql.types.Metadata;
+import org.apache.spark.sql.types.StructField;
+import org.apache.spark.sql.types.StructType;
+// $example off$
+
+public class JavaIndexToStringExample {
+ public static void main(String[] args) {
+ SparkConf conf = new SparkConf().setAppName("JavaIndexToStringExample");
+ JavaSparkContext jsc = new JavaSparkContext(conf);
+ SQLContext sqlContext = new SQLContext(jsc);
+
+ // $example on$
+ JavaRDD<Row> jrdd = jsc.parallelize(Arrays.asList(
+ RowFactory.create(0, "a"),
+ RowFactory.create(1, "b"),
+ RowFactory.create(2, "c"),
+ RowFactory.create(3, "a"),
+ RowFactory.create(4, "a"),
+ RowFactory.create(5, "c")
+ ));
+ StructType schema = new StructType(new StructField[]{
+ new StructField("id", DataTypes.IntegerType, false, Metadata.empty()),
+ new StructField("category", DataTypes.StringType, false, Metadata.empty())
+ });
+ DataFrame df = sqlContext.createDataFrame(jrdd, schema);
+
+ StringIndexerModel indexer = new StringIndexer()
+ .setInputCol("category")
+ .setOutputCol("categoryIndex")
+ .fit(df);
+ DataFrame indexed = indexer.transform(df);
+
+ IndexToString converter = new IndexToString()
+ .setInputCol("categoryIndex")
+ .setOutputCol("originalCategory");
+ DataFrame converted = converter.transform(indexed);
+ converted.select("id", "originalCategory").show();
+ // $example off$
+ jsc.stop();
+ }
+}
diff --git a/examples/src/main/python/ml/index_to_string_example.py b/examples/src/main/python/ml/index_to_string_example.py
new file mode 100644
index 0000000000..fb0ba2950b
--- /dev/null
+++ b/examples/src/main/python/ml/index_to_string_example.py
@@ -0,0 +1,45 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from __future__ import print_function
+
+from pyspark import SparkContext
+# $example on$
+from pyspark.ml.feature import IndexToString, StringIndexer
+# $example off$
+from pyspark.sql import SQLContext
+
+if __name__ == "__main__":
+ sc = SparkContext(appName="IndexToStringExample")
+ sqlContext = SQLContext(sc)
+
+ # $example on$
+ df = sqlContext.createDataFrame(
+ [(0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")],
+ ["id", "category"])
+
+ stringIndexer = StringIndexer(inputCol="category", outputCol="categoryIndex")
+ model = stringIndexer.fit(df)
+ indexed = model.transform(df)
+
+ converter = IndexToString(inputCol="categoryIndex", outputCol="originalCategory")
+ converted = converter.transform(indexed)
+
+ converted.select("id", "originalCategory").show()
+ # $example off$
+
+ sc.stop()
diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/IndexToStringExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/IndexToStringExample.scala
new file mode 100644
index 0000000000..52537e5bb5
--- /dev/null
+++ b/examples/src/main/scala/org/apache/spark/examples/ml/IndexToStringExample.scala
@@ -0,0 +1,60 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// scalastyle:off println
+package org.apache.spark.examples.ml
+
+import org.apache.spark.sql.SQLContext
+import org.apache.spark.{SparkConf, SparkContext}
+// $example on$
+import org.apache.spark.ml.feature.{StringIndexer, IndexToString}
+// $example off$
+
+object IndexToStringExample {
+ def main(args: Array[String]) {
+ val conf = new SparkConf().setAppName("IndexToStringExample")
+ val sc = new SparkContext(conf)
+
+ val sqlContext = SQLContext.getOrCreate(sc)
+
+ // $example on$
+ val df = sqlContext.createDataFrame(Seq(
+ (0, "a"),
+ (1, "b"),
+ (2, "c"),
+ (3, "a"),
+ (4, "a"),
+ (5, "c")
+ )).toDF("id", "category")
+
+ val indexer = new StringIndexer()
+ .setInputCol("category")
+ .setOutputCol("categoryIndex")
+ .fit(df)
+ val indexed = indexer.transform(df)
+
+ val converter = new IndexToString()
+ .setInputCol("categoryIndex")
+ .setOutputCol("originalCategory")
+
+ val converted = converter.transform(indexed)
+ converted.select("id", "originalCategory").show()
+ // $example off$
+ sc.stop()
+ }
+}
+// scalastyle:on println