aboutsummaryrefslogtreecommitdiff
path: root/examples/src/main/python/mllib/decision_tree_runner.py
diff options
context:
space:
mode:
Diffstat (limited to 'examples/src/main/python/mllib/decision_tree_runner.py')
-rwxr-xr-xexamples/src/main/python/mllib/decision_tree_runner.py17
1 files changed, 10 insertions, 7 deletions
diff --git a/examples/src/main/python/mllib/decision_tree_runner.py b/examples/src/main/python/mllib/decision_tree_runner.py
index 61ea4e06ec..fccabd841b 100755
--- a/examples/src/main/python/mllib/decision_tree_runner.py
+++ b/examples/src/main/python/mllib/decision_tree_runner.py
@@ -106,8 +106,7 @@ def reindexClassLabels(data):
def usage():
print >> sys.stderr, \
- "Usage: decision_tree_runner [libsvm format data filepath]\n" + \
- " Note: This only supports binary classification."
+ "Usage: decision_tree_runner [libsvm format data filepath]"
exit(1)
@@ -127,16 +126,20 @@ if __name__ == "__main__":
# Re-index class labels if needed.
(reindexedData, origToNewLabels) = reindexClassLabels(points)
+ numClasses = len(origToNewLabels)
# Train a classifier.
categoricalFeaturesInfo = {} # no categorical features
- model = DecisionTree.trainClassifier(reindexedData, numClasses=2,
+ model = DecisionTree.trainClassifier(reindexedData, numClasses=numClasses,
categoricalFeaturesInfo=categoricalFeaturesInfo)
# Print learned tree and stats.
print "Trained DecisionTree for classification:"
- print " Model numNodes: %d\n" % model.numNodes()
- print " Model depth: %d\n" % model.depth()
- print " Training accuracy: %g\n" % getAccuracy(model, reindexedData)
- print model
+ print " Model numNodes: %d" % model.numNodes()
+ print " Model depth: %d" % model.depth()
+ print " Training accuracy: %g" % getAccuracy(model, reindexedData)
+ if model.numNodes() < 20:
+ print model.toDebugString()
+ else:
+ print model
sc.stop()