aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorqiping.lqp <qiping.lqp@alibaba-inc.com>2014-09-15 17:43:26 -0700
committerXiangrui Meng <meng@databricks.com>2014-09-15 17:43:26 -0700
commitfdb302f49c021227026909bdcdade7496059013f (patch)
treeeb9136a917317cf4777f01454436784d89496448
parent983d6a9c48b69c5f0542922aa8b133f69eb1034d (diff)
downloadspark-fdb302f49c021227026909bdcdade7496059013f.tar.gz
spark-fdb302f49c021227026909bdcdade7496059013f.tar.bz2
spark-fdb302f49c021227026909bdcdade7496059013f.zip
[SPARK-3516] [mllib] DecisionTree: Add minInstancesPerNode, minInfoGain params to example and Python API
Added minInstancesPerNode, minInfoGain params to: * DecisionTreeRunner.scala example * Python API (tree.py) Also: * Fixed typo in tree suite test "do not choose split that does not satisfy min instance per node requirements" * small style fixes CC: mengxr Author: qiping.lqp <qiping.lqp@alibaba-inc.com> Author: Joseph K. Bradley <joseph.kurata.bradley@gmail.com> Author: chouqin <liqiping1991@gmail.com> Closes #2349 from jkbradley/chouqin-dt-preprune and squashes the following commits: 61b2e72 [Joseph K. Bradley] Added max of 10GB for maxMemoryInMB in Strategy. a95e7c8 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into chouqin-dt-preprune 95c479d [Joseph K. Bradley] * Fixed typo in tree suite test "do not choose split that does not satisfy min instance per node requirements" * small style fixes e2628b6 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into chouqin-dt-preprune 19b01af [Joseph K. Bradley] Merge remote-tracking branch 'chouqin/dt-preprune' into chouqin-dt-preprune f1d11d1 [chouqin] fix typo c7ebaf1 [chouqin] fix typo 39f9b60 [chouqin] change edge `minInstancesPerNode` to 2 and add one more test c6e2dfc [Joseph K. Bradley] Added minInstancesPerNode and minInfoGain parameters to DecisionTreeRunner.scala and to Python API in tree.py 0278a11 [chouqin] remove `noSplit` and set `Predict` private to tree d593ec7 [chouqin] fix docs and change minInstancesPerNode to 1 efcc736 [qiping.lqp] fix bug 10b8012 [qiping.lqp] fix style 6728fad [qiping.lqp] minor fix: remove empty lines bb465ca [qiping.lqp] Merge branch 'master' of https://github.com/apache/spark into dt-preprune cadd569 [qiping.lqp] add api docs 46b891f [qiping.lqp] fix bug e72c7e4 [qiping.lqp] add comments 845c6fa [qiping.lqp] fix style f195e83 [qiping.lqp] fix style 987cbf4 [qiping.lqp] fix bug ff34845 [qiping.lqp] separate calculation of predict of node from calculation of info gain ac42378 [qiping.lqp] add min info gain and min instances per node parameters in decision tree
-rw-r--r--examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala13
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala8
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala4
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala2
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/tree/model/Predict.scala6
-rw-r--r--mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala4
-rw-r--r--python/pyspark/mllib/tree.py16
7 files changed, 37 insertions, 16 deletions
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala
index 72c3ab475b..4683e6eb96 100644
--- a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala
@@ -55,6 +55,8 @@ object DecisionTreeRunner {
maxDepth: Int = 5,
impurity: ImpurityType = Gini,
maxBins: Int = 32,
+ minInstancesPerNode: Int = 1,
+ minInfoGain: Double = 0.0,
fracTest: Double = 0.2)
def main(args: Array[String]) {
@@ -75,6 +77,13 @@ object DecisionTreeRunner {
opt[Int]("maxBins")
.text(s"max number of bins, default: ${defaultParams.maxBins}")
.action((x, c) => c.copy(maxBins = x))
+ opt[Int]("minInstancesPerNode")
+ .text(s"min number of instances required at child nodes to create the parent split," +
+ s" default: ${defaultParams.minInstancesPerNode}")
+ .action((x, c) => c.copy(minInstancesPerNode = x))
+ opt[Double]("minInfoGain")
+ .text(s"min info gain required to create a split, default: ${defaultParams.minInfoGain}")
+ .action((x, c) => c.copy(minInfoGain = x))
opt[Double]("fracTest")
.text(s"fraction of data to hold out for testing, default: ${defaultParams.fracTest}")
.action((x, c) => c.copy(fracTest = x))
@@ -179,7 +188,9 @@ object DecisionTreeRunner {
impurity = impurityCalculator,
maxDepth = params.maxDepth,
maxBins = params.maxBins,
- numClassesForClassification = numClasses)
+ numClassesForClassification = numClasses,
+ minInstancesPerNode = params.minInstancesPerNode,
+ minInfoGain = params.minInfoGain)
val model = DecisionTree.train(training, strategy)
println(model)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
index 4343124f10..fa0fa69f38 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
@@ -303,7 +303,9 @@ class PythonMLLibAPI extends Serializable {
categoricalFeaturesInfoJMap: java.util.Map[Int, Int],
impurityStr: String,
maxDepth: Int,
- maxBins: Int): DecisionTreeModel = {
+ maxBins: Int,
+ minInstancesPerNode: Int,
+ minInfoGain: Double): DecisionTreeModel = {
val data = dataBytesJRDD.rdd.map(SerDe.deserializeLabeledPoint)
@@ -316,7 +318,9 @@ class PythonMLLibAPI extends Serializable {
maxDepth = maxDepth,
numClassesForClassification = numClasses,
maxBins = maxBins,
- categoricalFeaturesInfo = categoricalFeaturesInfoJMap.asScala.toMap)
+ categoricalFeaturesInfo = categoricalFeaturesInfoJMap.asScala.toMap,
+ minInstancesPerNode = minInstancesPerNode,
+ minInfoGain = minInfoGain)
DecisionTree.train(data, strategy)
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
index 56bb881210..c7f2576c82 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
@@ -389,7 +389,7 @@ object DecisionTree extends Serializable with Logging {
var groupIndex = 0
var doneTraining = true
while (groupIndex < numGroups) {
- val (tmpRoot, doneTrainingGroup) = findBestSplitsPerGroup(input, metadata, level,
+ val (_, doneTrainingGroup) = findBestSplitsPerGroup(input, metadata, level,
topNode, splits, bins, timer, numGroups, groupIndex)
doneTraining = doneTraining && doneTrainingGroup
groupIndex += 1
@@ -898,7 +898,7 @@ object DecisionTree extends Serializable with Logging {
}
}.maxBy(_._2.gain)
- require(predict.isDefined, "must calculate predict for each node")
+ assert(predict.isDefined, "must calculate predict for each node")
(bestSplit, bestSplitStats, predict.get)
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
index 31d1e8ac30..caaccbfb8a 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
@@ -77,6 +77,8 @@ class Strategy (
}
require(minInstancesPerNode >= 1,
s"DecisionTree Strategy requires minInstancesPerNode >= 1 but was given $minInstancesPerNode")
+ require(maxMemoryInMB <= 10240,
+ s"DecisionTree Strategy requires maxMemoryInMB <= 10240, but was given $maxMemoryInMB")
val isMulticlassClassification =
algo == Classification && numClassesForClassification > 2
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Predict.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Predict.scala
index 6fac2be279..d8476b5cd7 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Predict.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Predict.scala
@@ -17,18 +17,14 @@
package org.apache.spark.mllib.tree.model
-import org.apache.spark.annotation.DeveloperApi
-
/**
- * :: DeveloperApi ::
* Predicted value for a node
* @param predict predicted value
* @param prob probability of the label (classification only)
*/
-@DeveloperApi
private[tree] class Predict(
val predict: Double,
- val prob: Double = 0.0) extends Serializable{
+ val prob: Double = 0.0) extends Serializable {
override def toString = {
"predict = %f, prob = %f".format(predict, prob)
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
index 1bd7ea05c4..2b2e579b99 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
@@ -714,8 +714,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
assert(gain == InformationGainStats.invalidInformationGainStats)
}
- test("don't choose split that doesn't satisfy min instance per node requirements") {
- // if a split doesn't satisfy min instances per node requirements,
+ test("do not choose split that does not satisfy min instance per node requirements") {
+ // if a split does not satisfy min instances per node requirements,
// this split is invalid, even though the information gain of split is large.
val arr = new Array[LabeledPoint](4)
arr(0) = new LabeledPoint(0.0, Vectors.dense(0.0, 1.0))
diff --git a/python/pyspark/mllib/tree.py b/python/pyspark/mllib/tree.py
index ccc000ac70..5b13ab682b 100644
--- a/python/pyspark/mllib/tree.py
+++ b/python/pyspark/mllib/tree.py
@@ -138,7 +138,8 @@ class DecisionTree(object):
@staticmethod
def trainClassifier(data, numClasses, categoricalFeaturesInfo,
- impurity="gini", maxDepth=5, maxBins=32):
+ impurity="gini", maxDepth=5, maxBins=32, minInstancesPerNode=1,
+ minInfoGain=0.0):
"""
Train a DecisionTreeModel for classification.
@@ -154,6 +155,9 @@ class DecisionTree(object):
E.g., depth 0 means 1 leaf node.
Depth 1 means 1 internal node + 2 leaf nodes.
:param maxBins: Number of bins used for finding splits at each node.
+ :param minInstancesPerNode: Min number of instances required at child nodes to create
+ the parent split
+ :param minInfoGain: Min info gain required to create a split
:return: DecisionTreeModel
"""
sc = data.context
@@ -164,13 +168,14 @@ class DecisionTree(object):
model = sc._jvm.PythonMLLibAPI().trainDecisionTreeModel(
dataBytes._jrdd, "classification",
numClasses, categoricalFeaturesInfoJMap,
- impurity, maxDepth, maxBins)
+ impurity, maxDepth, maxBins, minInstancesPerNode, minInfoGain)
dataBytes.unpersist()
return DecisionTreeModel(sc, model)
@staticmethod
def trainRegressor(data, categoricalFeaturesInfo,
- impurity="variance", maxDepth=5, maxBins=32):
+ impurity="variance", maxDepth=5, maxBins=32, minInstancesPerNode=1,
+ minInfoGain=0.0):
"""
Train a DecisionTreeModel for regression.
@@ -185,6 +190,9 @@ class DecisionTree(object):
E.g., depth 0 means 1 leaf node.
Depth 1 means 1 internal node + 2 leaf nodes.
:param maxBins: Number of bins used for finding splits at each node.
+ :param minInstancesPerNode: Min number of instances required at child nodes to create
+ the parent split
+ :param minInfoGain: Min info gain required to create a split
:return: DecisionTreeModel
"""
sc = data.context
@@ -195,7 +203,7 @@ class DecisionTree(object):
model = sc._jvm.PythonMLLibAPI().trainDecisionTreeModel(
dataBytes._jrdd, "regression",
0, categoricalFeaturesInfoJMap,
- impurity, maxDepth, maxBins)
+ impurity, maxDepth, maxBins, minInstancesPerNode, minInfoGain)
dataBytes.unpersist()
return DecisionTreeModel(sc, model)