diff options
author | qiping.lqp <qiping.lqp@alibaba-inc.com> | 2014-09-15 17:43:26 -0700 |
---|---|---|
committer | Xiangrui Meng <meng@databricks.com> | 2014-09-15 17:43:26 -0700 |
commit | fdb302f49c021227026909bdcdade7496059013f (patch) | |
tree | eb9136a917317cf4777f01454436784d89496448 /python/pyspark | |
parent | 983d6a9c48b69c5f0542922aa8b133f69eb1034d (diff) | |
download | spark-fdb302f49c021227026909bdcdade7496059013f.tar.gz spark-fdb302f49c021227026909bdcdade7496059013f.tar.bz2 spark-fdb302f49c021227026909bdcdade7496059013f.zip |
[SPARK-3516] [mllib] DecisionTree: Add minInstancesPerNode, minInfoGain params to example and Python API
Added minInstancesPerNode, minInfoGain params to:
* DecisionTreeRunner.scala example
* Python API (tree.py)
Also:
* Fixed typo in tree suite test "do not choose split that does not satisfy min instance per node requirements"
* small style fixes
CC: mengxr
Author: qiping.lqp <qiping.lqp@alibaba-inc.com>
Author: Joseph K. Bradley <joseph.kurata.bradley@gmail.com>
Author: chouqin <liqiping1991@gmail.com>
Closes #2349 from jkbradley/chouqin-dt-preprune and squashes the following commits:
61b2e72 [Joseph K. Bradley] Added max of 10GB for maxMemoryInMB in Strategy.
a95e7c8 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into chouqin-dt-preprune
95c479d [Joseph K. Bradley] * Fixed typo in tree suite test "do not choose split that does not satisfy min instance per node requirements" * small style fixes
e2628b6 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into chouqin-dt-preprune
19b01af [Joseph K. Bradley] Merge remote-tracking branch 'chouqin/dt-preprune' into chouqin-dt-preprune
f1d11d1 [chouqin] fix typo
c7ebaf1 [chouqin] fix typo
39f9b60 [chouqin] change edge `minInstancesPerNode` to 2 and add one more test
c6e2dfc [Joseph K. Bradley] Added minInstancesPerNode and minInfoGain parameters to DecisionTreeRunner.scala and to Python API in tree.py
0278a11 [chouqin] remove `noSplit` and set `Predict` private to tree
d593ec7 [chouqin] fix docs and change minInstancesPerNode to 1
efcc736 [qiping.lqp] fix bug
10b8012 [qiping.lqp] fix style
6728fad [qiping.lqp] minor fix: remove empty lines
bb465ca [qiping.lqp] Merge branch 'master' of https://github.com/apache/spark into dt-preprune
cadd569 [qiping.lqp] add api docs
46b891f [qiping.lqp] fix bug
e72c7e4 [qiping.lqp] add comments
845c6fa [qiping.lqp] fix style
f195e83 [qiping.lqp] fix style
987cbf4 [qiping.lqp] fix bug
ff34845 [qiping.lqp] separate calculation of predict of node from calculation of info gain
ac42378 [qiping.lqp] add min info gain and min instances per node parameters in decision tree
Diffstat (limited to 'python/pyspark')
-rw-r--r-- | python/pyspark/mllib/tree.py | 16 |
1 files changed, 12 insertions, 4 deletions
diff --git a/python/pyspark/mllib/tree.py b/python/pyspark/mllib/tree.py index ccc000ac70..5b13ab682b 100644 --- a/python/pyspark/mllib/tree.py +++ b/python/pyspark/mllib/tree.py @@ -138,7 +138,8 @@ class DecisionTree(object): @staticmethod def trainClassifier(data, numClasses, categoricalFeaturesInfo, - impurity="gini", maxDepth=5, maxBins=32): + impurity="gini", maxDepth=5, maxBins=32, minInstancesPerNode=1, + minInfoGain=0.0): """ Train a DecisionTreeModel for classification. @@ -154,6 +155,9 @@ class DecisionTree(object): E.g., depth 0 means 1 leaf node. Depth 1 means 1 internal node + 2 leaf nodes. :param maxBins: Number of bins used for finding splits at each node. + :param minInstancesPerNode: Min number of instances required at child nodes to create + the parent split + :param minInfoGain: Min info gain required to create a split :return: DecisionTreeModel """ sc = data.context @@ -164,13 +168,14 @@ class DecisionTree(object): model = sc._jvm.PythonMLLibAPI().trainDecisionTreeModel( dataBytes._jrdd, "classification", numClasses, categoricalFeaturesInfoJMap, - impurity, maxDepth, maxBins) + impurity, maxDepth, maxBins, minInstancesPerNode, minInfoGain) dataBytes.unpersist() return DecisionTreeModel(sc, model) @staticmethod def trainRegressor(data, categoricalFeaturesInfo, - impurity="variance", maxDepth=5, maxBins=32): + impurity="variance", maxDepth=5, maxBins=32, minInstancesPerNode=1, + minInfoGain=0.0): """ Train a DecisionTreeModel for regression. @@ -185,6 +190,9 @@ class DecisionTree(object): E.g., depth 0 means 1 leaf node. Depth 1 means 1 internal node + 2 leaf nodes. :param maxBins: Number of bins used for finding splits at each node. + :param minInstancesPerNode: Min number of instances required at child nodes to create + the parent split + :param minInfoGain: Min info gain required to create a split :return: DecisionTreeModel """ sc = data.context @@ -195,7 +203,7 @@ class DecisionTree(object): model = sc._jvm.PythonMLLibAPI().trainDecisionTreeModel( dataBytes._jrdd, "regression", 0, categoricalFeaturesInfoJMap, - impurity, maxDepth, maxBins) + impurity, maxDepth, maxBins, minInstancesPerNode, minInfoGain) dataBytes.unpersist() return DecisionTreeModel(sc, model) |