aboutsummaryrefslogtreecommitdiff
path: root/mllib/src
diff options
context:
space:
mode:
authorqiping.lqp <qiping.lqp@alibaba-inc.com>2014-09-15 17:43:26 -0700
committerXiangrui Meng <meng@databricks.com>2014-09-15 17:43:26 -0700
commitfdb302f49c021227026909bdcdade7496059013f (patch)
treeeb9136a917317cf4777f01454436784d89496448 /mllib/src
parent983d6a9c48b69c5f0542922aa8b133f69eb1034d (diff)
downloadspark-fdb302f49c021227026909bdcdade7496059013f.tar.gz
spark-fdb302f49c021227026909bdcdade7496059013f.tar.bz2
spark-fdb302f49c021227026909bdcdade7496059013f.zip
[SPARK-3516] [mllib] DecisionTree: Add minInstancesPerNode, minInfoGain params to example and Python API
Added minInstancesPerNode, minInfoGain params to: * DecisionTreeRunner.scala example * Python API (tree.py) Also: * Fixed typo in tree suite test "do not choose split that does not satisfy min instance per node requirements" * small style fixes CC: mengxr Author: qiping.lqp <qiping.lqp@alibaba-inc.com> Author: Joseph K. Bradley <joseph.kurata.bradley@gmail.com> Author: chouqin <liqiping1991@gmail.com> Closes #2349 from jkbradley/chouqin-dt-preprune and squashes the following commits: 61b2e72 [Joseph K. Bradley] Added max of 10GB for maxMemoryInMB in Strategy. a95e7c8 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into chouqin-dt-preprune 95c479d [Joseph K. Bradley] * Fixed typo in tree suite test "do not choose split that does not satisfy min instance per node requirements" * small style fixes e2628b6 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into chouqin-dt-preprune 19b01af [Joseph K. Bradley] Merge remote-tracking branch 'chouqin/dt-preprune' into chouqin-dt-preprune f1d11d1 [chouqin] fix typo c7ebaf1 [chouqin] fix typo 39f9b60 [chouqin] change edge `minInstancesPerNode` to 2 and add one more test c6e2dfc [Joseph K. Bradley] Added minInstancesPerNode and minInfoGain parameters to DecisionTreeRunner.scala and to Python API in tree.py 0278a11 [chouqin] remove `noSplit` and set `Predict` private to tree d593ec7 [chouqin] fix docs and change minInstancesPerNode to 1 efcc736 [qiping.lqp] fix bug 10b8012 [qiping.lqp] fix style 6728fad [qiping.lqp] minor fix: remove empty lines bb465ca [qiping.lqp] Merge branch 'master' of https://github.com/apache/spark into dt-preprune cadd569 [qiping.lqp] add api docs 46b891f [qiping.lqp] fix bug e72c7e4 [qiping.lqp] add comments 845c6fa [qiping.lqp] fix style f195e83 [qiping.lqp] fix style 987cbf4 [qiping.lqp] fix bug ff34845 [qiping.lqp] separate calculation of predict of node from calculation of info gain ac42378 [qiping.lqp] add min info gain and min instances per node parameters in decision tree
Diffstat (limited to 'mllib/src')
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala8
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala4
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala2
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/tree/model/Predict.scala6
-rw-r--r--mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala4
5 files changed, 13 insertions, 11 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
index 4343124f10..fa0fa69f38 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
@@ -303,7 +303,9 @@ class PythonMLLibAPI extends Serializable {
categoricalFeaturesInfoJMap: java.util.Map[Int, Int],
impurityStr: String,
maxDepth: Int,
- maxBins: Int): DecisionTreeModel = {
+ maxBins: Int,
+ minInstancesPerNode: Int,
+ minInfoGain: Double): DecisionTreeModel = {
val data = dataBytesJRDD.rdd.map(SerDe.deserializeLabeledPoint)
@@ -316,7 +318,9 @@ class PythonMLLibAPI extends Serializable {
maxDepth = maxDepth,
numClassesForClassification = numClasses,
maxBins = maxBins,
- categoricalFeaturesInfo = categoricalFeaturesInfoJMap.asScala.toMap)
+ categoricalFeaturesInfo = categoricalFeaturesInfoJMap.asScala.toMap,
+ minInstancesPerNode = minInstancesPerNode,
+ minInfoGain = minInfoGain)
DecisionTree.train(data, strategy)
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
index 56bb881210..c7f2576c82 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
@@ -389,7 +389,7 @@ object DecisionTree extends Serializable with Logging {
var groupIndex = 0
var doneTraining = true
while (groupIndex < numGroups) {
- val (tmpRoot, doneTrainingGroup) = findBestSplitsPerGroup(input, metadata, level,
+ val (_, doneTrainingGroup) = findBestSplitsPerGroup(input, metadata, level,
topNode, splits, bins, timer, numGroups, groupIndex)
doneTraining = doneTraining && doneTrainingGroup
groupIndex += 1
@@ -898,7 +898,7 @@ object DecisionTree extends Serializable with Logging {
}
}.maxBy(_._2.gain)
- require(predict.isDefined, "must calculate predict for each node")
+ assert(predict.isDefined, "must calculate predict for each node")
(bestSplit, bestSplitStats, predict.get)
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
index 31d1e8ac30..caaccbfb8a 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
@@ -77,6 +77,8 @@ class Strategy (
}
require(minInstancesPerNode >= 1,
s"DecisionTree Strategy requires minInstancesPerNode >= 1 but was given $minInstancesPerNode")
+ require(maxMemoryInMB <= 10240,
+ s"DecisionTree Strategy requires maxMemoryInMB <= 10240, but was given $maxMemoryInMB")
val isMulticlassClassification =
algo == Classification && numClassesForClassification > 2
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Predict.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Predict.scala
index 6fac2be279..d8476b5cd7 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Predict.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Predict.scala
@@ -17,18 +17,14 @@
package org.apache.spark.mllib.tree.model
-import org.apache.spark.annotation.DeveloperApi
-
/**
- * :: DeveloperApi ::
* Predicted value for a node
* @param predict predicted value
* @param prob probability of the label (classification only)
*/
-@DeveloperApi
private[tree] class Predict(
val predict: Double,
- val prob: Double = 0.0) extends Serializable{
+ val prob: Double = 0.0) extends Serializable {
override def toString = {
"predict = %f, prob = %f".format(predict, prob)
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
index 1bd7ea05c4..2b2e579b99 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
@@ -714,8 +714,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
assert(gain == InformationGainStats.invalidInformationGainStats)
}
- test("don't choose split that doesn't satisfy min instance per node requirements") {
- // if a split doesn't satisfy min instances per node requirements,
+ test("do not choose split that does not satisfy min instance per node requirements") {
+ // if a split does not satisfy min instances per node requirements,
// this split is invalid, even though the information gain of split is large.
val arr = new Array[LabeledPoint](4)
arr(0) = new LabeledPoint(0.0, Vectors.dense(0.0, 1.0))