aboutsummaryrefslogtreecommitdiff
path: root/python
diff options
context:
space:
mode:
authorJoseph K. Bradley <joseph.kurata.bradley@gmail.com>2014-10-01 01:03:24 -0700
committerXiangrui Meng <meng@databricks.com>2014-10-01 01:03:24 -0700
commit7bf6cc9701cbb0f77fb85a412e387fb92274fca5 (patch)
tree21d38a426534826700f9f94b8f8d81034f55ea9b /python
parenteb43043f411b87b7b412ee31e858246bd93fdd04 (diff)
downloadspark-7bf6cc9701cbb0f77fb85a412e387fb92274fca5.tar.gz
spark-7bf6cc9701cbb0f77fb85a412e387fb92274fca5.tar.bz2
spark-7bf6cc9701cbb0f77fb85a412e387fb92274fca5.zip
[SPARK-3751] [mllib] DecisionTree: example update + print options
DecisionTreeRunner functionality additions: * Allow user to pass in a test dataset * Do not print full model if the model is too large. As part of this, modify DecisionTreeModel and RandomForestModel to allow printing less info. Proposed updates: * toString: prints model summary * toDebugString: prints full model (named after RDD.toDebugString) Similar update to Python API: * __repr__() now prints a model summary * toDebugString() now prints the full model CC: mengxr chouqin manishamde codedeft Small update (whomever can take a look). Thanks! Author: Joseph K. Bradley <joseph.kurata.bradley@gmail.com> Closes #2604 from jkbradley/dtrunner-update and squashes the following commits: b2b3c60 [Joseph K. Bradley] re-added python sql doc test, temporarily removed before 07b1fae [Joseph K. Bradley] repr() now prints a model summary toDebugString() now prints the full model 1d0d93d [Joseph K. Bradley] Updated DT and RF to print less when toString is called. Added toDebugString for verbose printing. 22eac8c [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dtrunner-update e007a95 [Joseph K. Bradley] Updated DecisionTreeRunner to accept a test dataset.
Diffstat (limited to 'python')
-rw-r--r--python/pyspark/mllib/tree.py10
1 files changed, 8 insertions, 2 deletions
diff --git a/python/pyspark/mllib/tree.py b/python/pyspark/mllib/tree.py
index f59a818a6e..afdcdbdf3a 100644
--- a/python/pyspark/mllib/tree.py
+++ b/python/pyspark/mllib/tree.py
@@ -77,8 +77,13 @@ class DecisionTreeModel(object):
return self._java_model.depth()
def __repr__(self):
+ """ Print summary of model. """
return self._java_model.toString()
+ def toDebugString(self):
+ """ Print full model. """
+ return self._java_model.toDebugString()
+
class DecisionTree(object):
@@ -135,7 +140,6 @@ class DecisionTree(object):
>>> from numpy import array
>>> from pyspark.mllib.regression import LabeledPoint
>>> from pyspark.mllib.tree import DecisionTree
- >>> from pyspark.mllib.linalg import SparseVector
>>>
>>> data = [
... LabeledPoint(0.0, [0.0]),
@@ -145,7 +149,9 @@ class DecisionTree(object):
... ]
>>> model = DecisionTree.trainClassifier(sc.parallelize(data), 2, {})
>>> print model, # it already has newline
- DecisionTreeModel classifier
+ DecisionTreeModel classifier of depth 1 with 3 nodes
+ >>> print model.toDebugString(), # it already has newline
+ DecisionTreeModel classifier of depth 1 with 3 nodes
If (feature 0 <= 0.5)
Predict: 0.0
Else (feature 0 > 0.5)