aboutsummaryrefslogtreecommitdiff
path: root/examples/src/main/python/ml/random_forest_classifier_example.py
diff options
context:
space:
mode:
Diffstat (limited to 'examples/src/main/python/ml/random_forest_classifier_example.py')
-rw-r--r--examples/src/main/python/ml/random_forest_classifier_example.py13
1 files changed, 9 insertions, 4 deletions
diff --git a/examples/src/main/python/ml/random_forest_classifier_example.py b/examples/src/main/python/ml/random_forest_classifier_example.py
index eb9ded9af5..4eaa94dd7f 100644
--- a/examples/src/main/python/ml/random_forest_classifier_example.py
+++ b/examples/src/main/python/ml/random_forest_classifier_example.py
@@ -23,7 +23,7 @@ from __future__ import print_function
# $example on$
from pyspark.ml import Pipeline
from pyspark.ml.classification import RandomForestClassifier
-from pyspark.ml.feature import StringIndexer, VectorIndexer
+from pyspark.ml.feature import IndexToString, StringIndexer, VectorIndexer
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
# $example off$
from pyspark.sql import SparkSession
@@ -31,7 +31,7 @@ from pyspark.sql import SparkSession
if __name__ == "__main__":
spark = SparkSession\
.builder\
- .appName("random_forest_classifier_example")\
+ .appName("RandomForestClassifierExample")\
.getOrCreate()
# $example on$
@@ -41,6 +41,7 @@ if __name__ == "__main__":
# Index labels, adding metadata to the label column.
# Fit on whole dataset to include all labels in index.
labelIndexer = StringIndexer(inputCol="label", outputCol="indexedLabel").fit(data)
+
# Automatically identify categorical features, and index them.
# Set maxCategories so features with > 4 distinct values are treated as continuous.
featureIndexer =\
@@ -52,8 +53,12 @@ if __name__ == "__main__":
# Train a RandomForest model.
rf = RandomForestClassifier(labelCol="indexedLabel", featuresCol="indexedFeatures", numTrees=10)
+ # Convert indexed labels back to original labels.
+ labelConverter = IndexToString(inputCol="prediction", outputCol="predictedLabel",
+ labels=labelIndexer.labels)
+
# Chain indexers and forest in a Pipeline
- pipeline = Pipeline(stages=[labelIndexer, featureIndexer, rf])
+ pipeline = Pipeline(stages=[labelIndexer, featureIndexer, rf, labelConverter])
# Train model. This also runs the indexers.
model = pipeline.fit(trainingData)
@@ -62,7 +67,7 @@ if __name__ == "__main__":
predictions = model.transform(testData)
# Select example rows to display.
- predictions.select("prediction", "indexedLabel", "features").show(5)
+ predictions.select("predictedLabel", "label", "features").show(5)
# Select (prediction, true label) and compute test error
evaluator = MulticlassClassificationEvaluator(