about summary refs log tree commit diff
path: root/examples/src/main/python/mllib/decision_tree_runner.py
diff options
context:
space:
mode:
Diffstat (limited to 'examples/src/main/python/mllib/decision_tree_runner.py')
-rwxr-xr-x examples/src/main/python/mllib/decision_tree_runner.py 29
1 files changed, 14 insertions, 15 deletions
diff --git a/examples/src/main/python/mllib/decision_tree_runner.py b/examples/src/main/python/mllib/decision_tree_runner.py
index fccabd841b..513ed8fd51 100755
--- a/examples/src/main/python/mllib/decision_tree_runner.py
+++ b/examples/src/main/python/mllib/decision_tree_runner.py
@@ -20,6 +20,7 @@ Decision tree classification and regression using MLlib.
This example requires NumPy (http://www.numpy.org/).
"""
+from __future__ import print_function
import numpy
import os
@@ -83,18 +84,17 @@ def reindexClassLabels(data):
numClasses = len(classCounts)
# origToNewLabels: class --> index in 0,...,numClasses-1
if (numClasses < 2):
- print >> sys.stderr, \
- "Dataset for classification should have at least 2 classes." + \
- " The given dataset had only %d classes." % numClasses
+ print("Dataset for classification should have at least 2 classes."
+ " The given dataset had only %d classes." % numClasses, file=sys.stderr)
exit(1)
origToNewLabels = dict([(sortedClasses[i], i) for i in range(0, numClasses)])
- print "numClasses = %d" % numClasses
- print "Per-class example fractions, counts:"
- print "Class\tFrac\tCount"
+ print("numClasses = %d" % numClasses)
+ print("Per-class example fractions, counts:")
+ print("Class\tFrac\tCount")
for c in sortedClasses:
frac = classCounts[c] / (numExamples + 0.0)
- print "%g\t%g\t%d" % (c, frac, classCounts[c])
+ print("%g\t%g\t%d" % (c, frac, classCounts[c]))
if (sortedClasses[0] == 0 and sortedClasses[-1] == numClasses - 1):
return (data, origToNewLabels)
@@ -105,8 +105,7 @@ def reindexClassLabels(data):
def usage():
- print >> sys.stderr, \
- "Usage: decision_tree_runner [libsvm format data filepath]"
+ print("Usage: decision_tree_runner [libsvm format data filepath]", file=sys.stderr)
exit(1)
@@ -133,13 +132,13 @@ if __name__ == "__main__":
model = DecisionTree.trainClassifier(reindexedData, numClasses=numClasses,
categoricalFeaturesInfo=categoricalFeaturesInfo)
# Print learned tree and stats.
- print "Trained DecisionTree for classification:"
- print " Model numNodes: %d" % model.numNodes()
- print " Model depth: %d" % model.depth()
- print " Training accuracy: %g" % getAccuracy(model, reindexedData)
+ print("Trained DecisionTree for classification:")
+ print(" Model numNodes: %d" % model.numNodes())
+ print(" Model depth: %d" % model.depth())
+ print(" Training accuracy: %g" % getAccuracy(model, reindexedData))
if model.numNodes() < 20:
- print model.toDebugString()
+ print(model.toDebugString())
else:
- print model
+ print(model)
sc.stop()