diff options
Diffstat (limited to 'examples/src/main/python/ml/onehot_encoder_example.py')
-rw-r--r-- | examples/src/main/python/ml/onehot_encoder_example.py | 48 |
1 files changed, 48 insertions, 0 deletions
diff --git a/examples/src/main/python/ml/onehot_encoder_example.py b/examples/src/main/python/ml/onehot_encoder_example.py new file mode 100644 index 0000000000..0f94c26638 --- /dev/null +++ b/examples/src/main/python/ml/onehot_encoder_example.py @@ -0,0 +1,48 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +from pyspark import SparkContext +from pyspark.sql import SQLContext +# $example on$ +from pyspark.ml.feature import OneHotEncoder, StringIndexer +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="OneHotEncoderExample") + sqlContext = SQLContext(sc) + + # $example on$ + df = sqlContext.createDataFrame([ + (0, "a"), + (1, "b"), + (2, "c"), + (3, "a"), + (4, "a"), + (5, "c") + ], ["id", "category"]) + + stringIndexer = StringIndexer(inputCol="category", outputCol="categoryIndex") + model = stringIndexer.fit(df) + indexed = model.transform(df) + encoder = OneHotEncoder(dropLast=False, inputCol="categoryIndex", outputCol="categoryVec") + encoded = encoder.transform(indexed) + encoded.select("id", "categoryVec").show() + # $example off$ + + sc.stop() |