aboutsummaryrefslogtreecommitdiff
path: root/mllib/src
diff options
context:
space:
mode:
authorJoseph K. Bradley <joseph.kurata.bradley@gmail.com>2014-10-01 01:03:24 -0700
committerXiangrui Meng <meng@databricks.com>2014-10-01 01:03:24 -0700
commit7bf6cc9701cbb0f77fb85a412e387fb92274fca5 (patch)
tree21d38a426534826700f9f94b8f8d81034f55ea9b /mllib/src
parenteb43043f411b87b7b412ee31e858246bd93fdd04 (diff)
downloadspark-7bf6cc9701cbb0f77fb85a412e387fb92274fca5.tar.gz
spark-7bf6cc9701cbb0f77fb85a412e387fb92274fca5.tar.bz2
spark-7bf6cc9701cbb0f77fb85a412e387fb92274fca5.zip
[SPARK-3751] [mllib] DecisionTree: example update + print options
DecisionTreeRunner functionality additions: * Allow user to pass in a test dataset * Do not print full model if the model is too large. As part of this, modify DecisionTreeModel and RandomForestModel to allow printing less info. Proposed updates: * toString: prints model summary * toDebugString: prints full model (named after RDD.toDebugString) Similar update to Python API: * __repr__() now prints a model summary * toDebugString() now prints the full model CC: mengxr chouqin manishamde codedeft Small update (whomever can take a look). Thanks! Author: Joseph K. Bradley <joseph.kurata.bradley@gmail.com> Closes #2604 from jkbradley/dtrunner-update and squashes the following commits: b2b3c60 [Joseph K. Bradley] re-added python sql doc test, temporarily removed before 07b1fae [Joseph K. Bradley] repr() now prints a model summary toDebugString() now prints the full model 1d0d93d [Joseph K. Bradley] Updated DT and RF to print less when toString is called. Added toDebugString for verbose printing. 22eac8c [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dtrunner-update e007a95 [Joseph K. Bradley] Updated DecisionTreeRunner to accept a test dataset.
Diffstat (limited to 'mllib/src')
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala14
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/tree/model/RandomForestModel.scala30
2 files changed, 31 insertions, 13 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala
index 271b2c4ad8..ec1d99ab26 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala
@@ -68,15 +68,23 @@ class DecisionTreeModel(val topNode: Node, val algo: Algo) extends Serializable
}
/**
- * Print full model.
+ * Print a summary of the model.
*/
override def toString: String = algo match {
case Classification =>
- s"DecisionTreeModel classifier\n" + topNode.subtreeToString(2)
+ s"DecisionTreeModel classifier of depth $depth with $numNodes nodes"
case Regression =>
- s"DecisionTreeModel regressor\n" + topNode.subtreeToString(2)
+ s"DecisionTreeModel regressor of depth $depth with $numNodes nodes"
case _ => throw new IllegalArgumentException(
s"DecisionTreeModel given unknown algo parameter: $algo.")
}
+ /**
+ * Print the full model to a string.
+ */
+ def toDebugString: String = {
+ val header = toString + "\n"
+ header + topNode.subtreeToString(2)
+ }
+
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/RandomForestModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/RandomForestModel.scala
index 538c0e2332..4d66d6d81c 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/RandomForestModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/RandomForestModel.scala
@@ -73,17 +73,27 @@ class RandomForestModel(val trees: Array[DecisionTreeModel], val algo: Algo) ext
def numTrees: Int = trees.size
/**
- * Print full model.
+ * Get total number of nodes, summed over all trees in the forest.
*/
- override def toString: String = {
- val header = algo match {
- case Classification =>
- s"RandomForestModel classifier with $numTrees trees\n"
- case Regression =>
- s"RandomForestModel regressor with $numTrees trees\n"
- case _ => throw new IllegalArgumentException(
- s"RandomForestModel given unknown algo parameter: $algo.")
- }
+ def totalNumNodes: Int = trees.map(tree => tree.numNodes).sum
+
+ /**
+ * Print a summary of the model.
+ */
+ override def toString: String = algo match {
+ case Classification =>
+ s"RandomForestModel classifier with $numTrees trees"
+ case Regression =>
+ s"RandomForestModel regressor with $numTrees trees"
+ case _ => throw new IllegalArgumentException(
+ s"RandomForestModel given unknown algo parameter: $algo.")
+ }
+
+ /**
+ * Print the full model to a string.
+ */
+ def toDebugString: String = {
+ val header = toString + "\n"
header + trees.zipWithIndex.map { case (tree, treeIndex) =>
s" Tree $treeIndex:\n" + tree.topNode.subtreeToString(4)
}.fold("")(_ + _)