diff options
Diffstat (limited to 'examples/src/main/python/ml/random_forest_classifier_example.py')
-rw-r--r-- | examples/src/main/python/ml/random_forest_classifier_example.py | 13 |
1 files changed, 9 insertions, 4 deletions
diff --git a/examples/src/main/python/ml/random_forest_classifier_example.py b/examples/src/main/python/ml/random_forest_classifier_example.py index eb9ded9af5..4eaa94dd7f 100644 --- a/examples/src/main/python/ml/random_forest_classifier_example.py +++ b/examples/src/main/python/ml/random_forest_classifier_example.py @@ -23,7 +23,7 @@ from __future__ import print_function # $example on$ from pyspark.ml import Pipeline from pyspark.ml.classification import RandomForestClassifier -from pyspark.ml.feature import StringIndexer, VectorIndexer +from pyspark.ml.feature import IndexToString, StringIndexer, VectorIndexer from pyspark.ml.evaluation import MulticlassClassificationEvaluator # $example off$ from pyspark.sql import SparkSession @@ -31,7 +31,7 @@ from pyspark.sql import SparkSession if __name__ == "__main__": spark = SparkSession\ .builder\ - .appName("random_forest_classifier_example")\ + .appName("RandomForestClassifierExample")\ .getOrCreate() # $example on$ @@ -41,6 +41,7 @@ if __name__ == "__main__": # Index labels, adding metadata to the label column. # Fit on whole dataset to include all labels in index. labelIndexer = StringIndexer(inputCol="label", outputCol="indexedLabel").fit(data) + # Automatically identify categorical features, and index them. # Set maxCategories so features with > 4 distinct values are treated as continuous. featureIndexer =\ @@ -52,8 +53,12 @@ if __name__ == "__main__": # Train a RandomForest model. rf = RandomForestClassifier(labelCol="indexedLabel", featuresCol="indexedFeatures", numTrees=10) + # Convert indexed labels back to original labels. + labelConverter = IndexToString(inputCol="prediction", outputCol="predictedLabel", + labels=labelIndexer.labels) + # Chain indexers and forest in a Pipeline - pipeline = Pipeline(stages=[labelIndexer, featureIndexer, rf]) + pipeline = Pipeline(stages=[labelIndexer, featureIndexer, rf, labelConverter]) # Train model. This also runs the indexers. model = pipeline.fit(trainingData) @@ -62,7 +67,7 @@ if __name__ == "__main__": predictions = model.transform(testData) # Select example rows to display. - predictions.select("prediction", "indexedLabel", "features").show(5) + predictions.select("predictedLabel", "label", "features").show(5) # Select (prediction, true label) and compute test error evaluator = MulticlassClassificationEvaluator( |