aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala99
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala14
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/tree/model/RandomForestModel.scala30
-rw-r--r--python/pyspark/mllib/tree.py10
4 files changed, 111 insertions, 42 deletions
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala
index 96fb068e9e..4adc91d2fb 100644
--- a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala
@@ -52,6 +52,7 @@ object DecisionTreeRunner {
case class Params(
input: String = null,
+ testInput: String = "",
dataFormat: String = "libsvm",
algo: Algo = Classification,
maxDepth: Int = 5,
@@ -98,13 +99,18 @@ object DecisionTreeRunner {
s"default: ${defaultParams.featureSubsetStrategy}")
.action((x, c) => c.copy(featureSubsetStrategy = x))
opt[Double]("fracTest")
- .text(s"fraction of data to hold out for testing, default: ${defaultParams.fracTest}")
+ .text(s"fraction of data to hold out for testing. If given option testInput, " +
+ s"this option is ignored. default: ${defaultParams.fracTest}")
.action((x, c) => c.copy(fracTest = x))
+ opt[String]("testInput")
+ .text(s"input path to test dataset. If given, option fracTest is ignored." +
+ s" default: ${defaultParams.testInput}")
+ .action((x, c) => c.copy(testInput = x))
opt[String]("<dataFormat>")
.text("data format: libsvm (default), dense (deprecated in Spark v1.1)")
.action((x, c) => c.copy(dataFormat = x))
arg[String]("<input>")
- .text("input paths to labeled examples in dense format (label,f0 f1 f2 ...)")
+ .text("input path to labeled examples")
.required()
.action((x, c) => c.copy(input = x))
checkConfig { params =>
@@ -141,7 +147,7 @@ object DecisionTreeRunner {
case "libsvm" => MLUtils.loadLibSVMFile(sc, params.input).cache()
}
// For classification, re-index classes if needed.
- val (examples, numClasses) = params.algo match {
+ val (examples, classIndexMap, numClasses) = params.algo match {
case Classification => {
// classCounts: class --> # examples in class
val classCounts = origExamples.map(_.label).countByValue()
@@ -170,16 +176,40 @@ object DecisionTreeRunner {
val frac = classCounts(c) / numExamples.toDouble
println(s"$c\t$frac\t${classCounts(c)}")
}
- (examples, numClasses)
+ (examples, classIndexMap, numClasses)
}
case Regression =>
- (origExamples, 0)
+ (origExamples, null, 0)
case _ =>
throw new IllegalArgumentException("Algo ${params.algo} not supported.")
}
- // Split into training, test.
- val splits = examples.randomSplit(Array(1.0 - params.fracTest, params.fracTest))
+ // Create training, test sets.
+ val splits = if (params.testInput != "") {
+ // Load testInput.
+ val origTestExamples = params.dataFormat match {
+ case "dense" => MLUtils.loadLabeledPoints(sc, params.testInput)
+ case "libsvm" => MLUtils.loadLibSVMFile(sc, params.testInput)
+ }
+ params.algo match {
+ case Classification => {
+ // Re-index test labels with classIndexMap (labels are left unchanged when the map is empty).
+ val testExamples = {
+ if (classIndexMap.isEmpty) {
+ origTestExamples
+ } else {
+ origTestExamples.map(lp => LabeledPoint(classIndexMap(lp.label), lp.features))
+ }
+ }
+ Array(examples, testExamples)
+ }
+ case Regression =>
+ Array(examples, origTestExamples)
+ }
+ } else {
+ // Split input into training, test.
+ examples.randomSplit(Array(1.0 - params.fracTest, params.fracTest))
+ }
val training = splits(0).cache()
val test = splits(1).cache()
val numTraining = training.count()
@@ -206,32 +236,56 @@ object DecisionTreeRunner {
minInfoGain = params.minInfoGain)
if (params.numTrees == 1) {
val model = DecisionTree.train(training, strategy)
- println(model)
+ if (model.numNodes < 20) {
+ println(model.toDebugString) // Print full model.
+ } else {
+ println(model) // Print model summary.
+ }
if (params.algo == Classification) {
- val accuracy =
+ val trainAccuracy =
+ new MulticlassMetrics(training.map(lp => (model.predict(lp.features), lp.label)))
+ .precision
+ println(s"Train accuracy = $trainAccuracy")
+ val testAccuracy =
new MulticlassMetrics(test.map(lp => (model.predict(lp.features), lp.label))).precision
- println(s"Test accuracy = $accuracy")
+ println(s"Test accuracy = $testAccuracy")
}
if (params.algo == Regression) {
- val mse = meanSquaredError(model, test)
- println(s"Test mean squared error = $mse")
+ val trainMSE = meanSquaredError(model, training)
+ println(s"Train mean squared error = $trainMSE")
+ val testMSE = meanSquaredError(model, test)
+ println(s"Test mean squared error = $testMSE")
}
} else {
val randomSeed = Utils.random.nextInt()
if (params.algo == Classification) {
val model = RandomForest.trainClassifier(training, strategy, params.numTrees,
params.featureSubsetStrategy, randomSeed)
- println(model)
- val accuracy =
+ if (model.totalNumNodes < 30) {
+ println(model.toDebugString) // Print full model.
+ } else {
+ println(model) // Print model summary.
+ }
+ val trainAccuracy =
+ new MulticlassMetrics(training.map(lp => (model.predict(lp.features), lp.label)))
+ .precision
+ println(s"Train accuracy = $trainAccuracy")
+ val testAccuracy =
new MulticlassMetrics(test.map(lp => (model.predict(lp.features), lp.label))).precision
- println(s"Test accuracy = $accuracy")
+ println(s"Test accuracy = $testAccuracy")
}
if (params.algo == Regression) {
val model = RandomForest.trainRegressor(training, strategy, params.numTrees,
params.featureSubsetStrategy, randomSeed)
- println(model)
- val mse = meanSquaredError(model, test)
- println(s"Test mean squared error = $mse")
+ if (model.totalNumNodes < 30) {
+ println(model.toDebugString) // Print full model.
+ } else {
+ println(model) // Print model summary.
+ }
+ val trainMSE = meanSquaredError(model, training)
+ println(s"Train mean squared error = $trainMSE")
+ val testMSE = meanSquaredError(model, test)
+ println(s"Test mean squared error = $testMSE")
}
}
@@ -239,15 +293,6 @@ object DecisionTreeRunner {
}
/**
- * Calculates the classifier accuracy.
- */
- private def accuracyScore(model: DecisionTreeModel, data: RDD[LabeledPoint]): Double = {
- val correctCount = data.filter(y => model.predict(y.features) == y.label).count()
- val count = data.count()
- correctCount.toDouble / count
- }
-
- /**
* Calculates the mean squared error for regression.
*/
private def meanSquaredError(tree: DecisionTreeModel, data: RDD[LabeledPoint]): Double = {
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala
index 271b2c4ad8..ec1d99ab26 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala
@@ -68,15 +68,23 @@ class DecisionTreeModel(val topNode: Node, val algo: Algo) extends Serializable
}
/**
- * Print full model.
+ * Returns a one-line summary of the model (algorithm type, depth, and number of nodes).
*/
override def toString: String = algo match {
case Classification =>
- s"DecisionTreeModel classifier\n" + topNode.subtreeToString(2)
+ s"DecisionTreeModel classifier of depth $depth with $numNodes nodes"
case Regression =>
- s"DecisionTreeModel regressor\n" + topNode.subtreeToString(2)
+ s"DecisionTreeModel regressor of depth $depth with $numNodes nodes"
case _ => throw new IllegalArgumentException(
s"DecisionTreeModel given unknown algo parameter: $algo.")
}
+ /**
+ * Returns the full model as a string: the summary line followed by the tree rooted at topNode.
+ */
+ def toDebugString: String = {
+ val header = toString + "\n"
+ header + topNode.subtreeToString(2)
+ }
+
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/RandomForestModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/RandomForestModel.scala
index 538c0e2332..4d66d6d81c 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/RandomForestModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/RandomForestModel.scala
@@ -73,17 +73,27 @@ class RandomForestModel(val trees: Array[DecisionTreeModel], val algo: Algo) ext
def numTrees: Int = trees.size
/**
- * Print full model.
+ * Get total number of nodes, summed over all trees in the forest.
*/
- override def toString: String = {
- val header = algo match {
- case Classification =>
- s"RandomForestModel classifier with $numTrees trees\n"
- case Regression =>
- s"RandomForestModel regressor with $numTrees trees\n"
- case _ => throw new IllegalArgumentException(
- s"RandomForestModel given unknown algo parameter: $algo.")
- }
+ def totalNumNodes: Int = trees.map(tree => tree.numNodes).sum
+
+ /**
+ * Returns a one-line summary of the model (algorithm type and number of trees).
+ */
+ override def toString: String = algo match {
+ case Classification =>
+ s"RandomForestModel classifier with $numTrees trees"
+ case Regression =>
+ s"RandomForestModel regressor with $numTrees trees"
+ case _ => throw new IllegalArgumentException(
+ s"RandomForestModel given unknown algo parameter: $algo.")
+ }
+
+ /**
+ * Returns the full model as a string: the summary line followed by every tree in the forest.
+ */
+ def toDebugString: String = {
+ val header = toString + "\n"
header + trees.zipWithIndex.map { case (tree, treeIndex) =>
s" Tree $treeIndex:\n" + tree.topNode.subtreeToString(4)
}.fold("")(_ + _)
diff --git a/python/pyspark/mllib/tree.py b/python/pyspark/mllib/tree.py
index f59a818a6e..afdcdbdf3a 100644
--- a/python/pyspark/mllib/tree.py
+++ b/python/pyspark/mllib/tree.py
@@ -77,8 +77,13 @@ class DecisionTreeModel(object):
return self._java_model.depth()
def __repr__(self):
+ """ Return a summary of the model as a string. """
return self._java_model.toString()
+ def toDebugString(self):
+ """ Return the full model as a string. """
+ return self._java_model.toDebugString()
+
class DecisionTree(object):
@@ -135,7 +140,6 @@ class DecisionTree(object):
>>> from numpy import array
>>> from pyspark.mllib.regression import LabeledPoint
>>> from pyspark.mllib.tree import DecisionTree
- >>> from pyspark.mllib.linalg import SparseVector
>>>
>>> data = [
... LabeledPoint(0.0, [0.0]),
@@ -145,7 +149,9 @@ class DecisionTree(object):
... ]
>>> model = DecisionTree.trainClassifier(sc.parallelize(data), 2, {})
>>> print model, # it already has newline
- DecisionTreeModel classifier
+ DecisionTreeModel classifier of depth 1 with 3 nodes
+ >>> print model.toDebugString(), # it already has newline
+ DecisionTreeModel classifier of depth 1 with 3 nodes
If (feature 0 <= 0.5)
Predict: 0.0
Else (feature 0 > 0.5)