From 7bf6cc9701cbb0f77fb85a412e387fb92274fca5 Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Wed, 1 Oct 2014 01:03:24 -0700 Subject: [SPARK-3751] [mllib] DecisionTree: example update + print options DecisionTreeRunner functionality additions: * Allow user to pass in a test dataset * Do not print full model if the model is too large. As part of this, modify DecisionTreeModel and RandomForestModel to allow printing less info. Proposed updates: * toString: prints model summary * toDebugString: prints full model (named after RDD.toDebugString) Similar update to Python API: * __repr__() now prints a model summary * toDebugString() now prints the full model CC: mengxr chouqin manishamde codedeft Small update (whomever can take a look). Thanks! Author: Joseph K. Bradley Closes #2604 from jkbradley/dtrunner-update and squashes the following commits: b2b3c60 [Joseph K. Bradley] re-added python sql doc test, temporarily removed before 07b1fae [Joseph K. Bradley] repr() now prints a model summary toDebugString() now prints the full model 1d0d93d [Joseph K. Bradley] Updated DT and RF to print less when toString is called. Added toDebugString for verbose printing. 22eac8c [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dtrunner-update e007a95 [Joseph K. Bradley] Updated DecisionTreeRunner to accept a test dataset. --- python/pyspark/mllib/tree.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) (limited to 'python') diff --git a/python/pyspark/mllib/tree.py b/python/pyspark/mllib/tree.py index f59a818a6e..afdcdbdf3a 100644 --- a/python/pyspark/mllib/tree.py +++ b/python/pyspark/mllib/tree.py @@ -77,8 +77,13 @@ class DecisionTreeModel(object): return self._java_model.depth() def __repr__(self): + """ Print summary of model. """ return self._java_model.toString() + def toDebugString(self): + """ Print full model. """ + return self._java_model.toDebugString() + class DecisionTree(object): @@ -135,7 +140,6 @@ class DecisionTree(object): >>> from numpy import array >>> from pyspark.mllib.regression import LabeledPoint >>> from pyspark.mllib.tree import DecisionTree - >>> from pyspark.mllib.linalg import SparseVector >>> >>> data = [ ... LabeledPoint(0.0, [0.0]), @@ -145,7 +149,9 @@ class DecisionTree(object): ... ] >>> model = DecisionTree.trainClassifier(sc.parallelize(data), 2, {}) >>> print model, # it already has newline - DecisionTreeModel classifier + DecisionTreeModel classifier of depth 1 with 3 nodes + >>> print model.toDebugString(), # it already has newline + DecisionTreeModel classifier of depth 1 with 3 nodes If (feature 0 <= 0.5) Predict: 0.0 Else (feature 0 > 0.5) -- cgit v1.2.3