aboutsummaryrefslogtreecommitdiff
path: root/examples/src/main/python/mllib/random_forest_example.py
diff options
context:
space:
mode:
Diffstat (limited to 'examples/src/main/python/mllib/random_forest_example.py')
-rwxr-xr-xexamples/src/main/python/mllib/random_forest_example.py9
1 files changed, 5 insertions, 4 deletions
diff --git a/examples/src/main/python/mllib/random_forest_example.py b/examples/src/main/python/mllib/random_forest_example.py
index d3c24f7664..4cfdad868c 100755
--- a/examples/src/main/python/mllib/random_forest_example.py
+++ b/examples/src/main/python/mllib/random_forest_example.py
@@ -22,6 +22,7 @@ Note: This example illustrates binary classification.
For information on multiclass classification, please refer to the decision_tree_runner.py
example.
"""
+from __future__ import print_function
import sys
@@ -43,7 +44,7 @@ def testClassification(trainingData, testData):
# Evaluate model on test instances and compute test error
predictions = model.predict(testData.map(lambda x: x.features))
labelsAndPredictions = testData.map(lambda lp: lp.label).zip(predictions)
- testErr = labelsAndPredictions.filter(lambda (v, p): v != p).count()\
+ testErr = labelsAndPredictions.filter(lambda v_p: v_p[0] != v_p[1]).count()\
/ float(testData.count())
print('Test Error = ' + str(testErr))
print('Learned classification forest model:')
@@ -62,8 +63,8 @@ def testRegression(trainingData, testData):
# Evaluate model on test instances and compute test error
predictions = model.predict(testData.map(lambda x: x.features))
labelsAndPredictions = testData.map(lambda lp: lp.label).zip(predictions)
- testMSE = labelsAndPredictions.map(lambda (v, p): (v - p) * (v - p)).sum()\
- / float(testData.count())
+ testMSE = labelsAndPredictions.map(lambda v_p1: (v_p1[0] - v_p1[1]) * (v_p1[0] - v_p1[1]))\
+ .sum() / float(testData.count())
print('Test Mean Squared Error = ' + str(testMSE))
print('Learned regression forest model:')
print(model.toDebugString())
@@ -71,7 +72,7 @@ def testRegression(trainingData, testData):
if __name__ == "__main__":
if len(sys.argv) > 1:
- print >> sys.stderr, "Usage: random_forest_example"
+ print("Usage: random_forest_example", file=sys.stderr)
exit(1)
sc = SparkContext(appName="PythonRandomForestExample")