diff options
author | Joseph K. Bradley <joseph.kurata.bradley@gmail.com> | 2014-10-01 01:03:24 -0700 |
---|---|---|
committer | Xiangrui Meng <meng@databricks.com> | 2014-10-01 01:03:24 -0700 |
commit | 7bf6cc9701cbb0f77fb85a412e387fb92274fca5 (patch) | |
tree | 21d38a426534826700f9f94b8f8d81034f55ea9b /python/pyspark/mllib | |
parent | eb43043f411b87b7b412ee31e858246bd93fdd04 (diff) | |
download | spark-7bf6cc9701cbb0f77fb85a412e387fb92274fca5.tar.gz spark-7bf6cc9701cbb0f77fb85a412e387fb92274fca5.tar.bz2 spark-7bf6cc9701cbb0f77fb85a412e387fb92274fca5.zip |
[SPARK-3751] [mllib] DecisionTree: example update + print options
DecisionTreeRunner functionality additions:
* Allow user to pass in a test dataset
* Do not print full model if the model is too large.
As part of this, modify DecisionTreeModel and RandomForestModel to allow printing less info. Proposed updates:
* toString: prints model summary
* toDebugString: prints full model (named after RDD.toDebugString)
Similar update to Python API:
* __repr__() now prints a model summary
* toDebugString() now prints the full model
CC: mengxr chouqin manishamde codedeft Small update (whomever can take a look). Thanks!
Author: Joseph K. Bradley <joseph.kurata.bradley@gmail.com>
Closes #2604 from jkbradley/dtrunner-update and squashes the following commits:
b2b3c60 [Joseph K. Bradley] re-added python sql doc test, temporarily removed before
07b1fae [Joseph K. Bradley] repr() now prints a model summary toDebugString() now prints the full model
1d0d93d [Joseph K. Bradley] Updated DT and RF to print less when toString is called. Added toDebugString for verbose printing.
22eac8c [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dtrunner-update
e007a95 [Joseph K. Bradley] Updated DecisionTreeRunner to accept a test dataset.
Diffstat (limited to 'python/pyspark/mllib')
-rw-r--r-- | python/pyspark/mllib/tree.py | 10 |
1 files changed, 8 insertions, 2 deletions
diff --git a/python/pyspark/mllib/tree.py b/python/pyspark/mllib/tree.py index f59a818a6e..afdcdbdf3a 100644 --- a/python/pyspark/mllib/tree.py +++ b/python/pyspark/mllib/tree.py @@ -77,8 +77,13 @@ class DecisionTreeModel(object): return self._java_model.depth() def __repr__(self): + """ Print summary of model. """ return self._java_model.toString() + def toDebugString(self): + """ Print full model. """ + return self._java_model.toDebugString() + class DecisionTree(object): @@ -135,7 +140,6 @@ class DecisionTree(object): >>> from numpy import array >>> from pyspark.mllib.regression import LabeledPoint >>> from pyspark.mllib.tree import DecisionTree - >>> from pyspark.mllib.linalg import SparseVector >>> >>> data = [ ... LabeledPoint(0.0, [0.0]), @@ -145,7 +149,9 @@ class DecisionTree(object): ... ] >>> model = DecisionTree.trainClassifier(sc.parallelize(data), 2, {}) >>> print model, # it already has newline - DecisionTreeModel classifier + DecisionTreeModel classifier of depth 1 with 3 nodes + >>> print model.toDebugString(), # it already has newline + DecisionTreeModel classifier of depth 1 with 3 nodes If (feature 0 <= 0.5) Predict: 0.0 Else (feature 0 > 0.5) |