aboutsummaryrefslogtreecommitdiff
path: root/python
diff options
context:
space:
mode:
authorqiping.lqp <qiping.lqp@alibaba-inc.com>2014-09-15 17:43:26 -0700
committerXiangrui Meng <meng@databricks.com>2014-09-15 17:43:26 -0700
commitfdb302f49c021227026909bdcdade7496059013f (patch)
treeeb9136a917317cf4777f01454436784d89496448 /python
parent983d6a9c48b69c5f0542922aa8b133f69eb1034d (diff)
downloadspark-fdb302f49c021227026909bdcdade7496059013f.tar.gz
spark-fdb302f49c021227026909bdcdade7496059013f.tar.bz2
spark-fdb302f49c021227026909bdcdade7496059013f.zip
[SPARK-3516] [mllib] DecisionTree: Add minInstancesPerNode, minInfoGain params to example and Python API
Added minInstancesPerNode, minInfoGain params to: * DecisionTreeRunner.scala example * Python API (tree.py) Also: * Fixed typo in tree suite test "do not choose split that does not satisfy min instance per node requirements" * small style fixes CC: mengxr Author: qiping.lqp <qiping.lqp@alibaba-inc.com> Author: Joseph K. Bradley <joseph.kurata.bradley@gmail.com> Author: chouqin <liqiping1991@gmail.com> Closes #2349 from jkbradley/chouqin-dt-preprune and squashes the following commits: 61b2e72 [Joseph K. Bradley] Added max of 10GB for maxMemoryInMB in Strategy. a95e7c8 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into chouqin-dt-preprune 95c479d [Joseph K. Bradley] * Fixed typo in tree suite test "do not choose split that does not satisfy min instance per node requirements" * small style fixes e2628b6 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into chouqin-dt-preprune 19b01af [Joseph K. Bradley] Merge remote-tracking branch 'chouqin/dt-preprune' into chouqin-dt-preprune f1d11d1 [chouqin] fix typo c7ebaf1 [chouqin] fix typo 39f9b60 [chouqin] change edge `minInstancesPerNode` to 2 and add one more test c6e2dfc [Joseph K. Bradley] Added minInstancesPerNode and minInfoGain parameters to DecisionTreeRunner.scala and to Python API in tree.py 0278a11 [chouqin] remove `noSplit` and set `Predict` private to tree d593ec7 [chouqin] fix docs and change minInstancesPerNode to 1 efcc736 [qiping.lqp] fix bug 10b8012 [qiping.lqp] fix style 6728fad [qiping.lqp] minor fix: remove empty lines bb465ca [qiping.lqp] Merge branch 'master' of https://github.com/apache/spark into dt-preprune cadd569 [qiping.lqp] add api docs 46b891f [qiping.lqp] fix bug e72c7e4 [qiping.lqp] add comments 845c6fa [qiping.lqp] fix style f195e83 [qiping.lqp] fix style 987cbf4 [qiping.lqp] fix bug ff34845 [qiping.lqp] separate calculation of predict of node from calculation of info gain ac42378 [qiping.lqp] add min info gain and min instances per node parameters in decision tree
Diffstat (limited to 'python')
-rw-r--r--python/pyspark/mllib/tree.py16
1 files changed, 12 insertions, 4 deletions
diff --git a/python/pyspark/mllib/tree.py b/python/pyspark/mllib/tree.py
index ccc000ac70..5b13ab682b 100644
--- a/python/pyspark/mllib/tree.py
+++ b/python/pyspark/mllib/tree.py
@@ -138,7 +138,8 @@ class DecisionTree(object):
@staticmethod
def trainClassifier(data, numClasses, categoricalFeaturesInfo,
- impurity="gini", maxDepth=5, maxBins=32):
+ impurity="gini", maxDepth=5, maxBins=32, minInstancesPerNode=1,
+ minInfoGain=0.0):
"""
Train a DecisionTreeModel for classification.
@@ -154,6 +155,9 @@ class DecisionTree(object):
E.g., depth 0 means 1 leaf node.
Depth 1 means 1 internal node + 2 leaf nodes.
:param maxBins: Number of bins used for finding splits at each node.
+ :param minInstancesPerNode: Min number of instances required at child nodes to create
+ the parent split
+ :param minInfoGain: Min info gain required to create a split
:return: DecisionTreeModel
"""
sc = data.context
@@ -164,13 +168,14 @@ class DecisionTree(object):
model = sc._jvm.PythonMLLibAPI().trainDecisionTreeModel(
dataBytes._jrdd, "classification",
numClasses, categoricalFeaturesInfoJMap,
- impurity, maxDepth, maxBins)
+ impurity, maxDepth, maxBins, minInstancesPerNode, minInfoGain)
dataBytes.unpersist()
return DecisionTreeModel(sc, model)
@staticmethod
def trainRegressor(data, categoricalFeaturesInfo,
- impurity="variance", maxDepth=5, maxBins=32):
+ impurity="variance", maxDepth=5, maxBins=32, minInstancesPerNode=1,
+ minInfoGain=0.0):
"""
Train a DecisionTreeModel for regression.
@@ -185,6 +190,9 @@ class DecisionTree(object):
E.g., depth 0 means 1 leaf node.
Depth 1 means 1 internal node + 2 leaf nodes.
:param maxBins: Number of bins used for finding splits at each node.
+ :param minInstancesPerNode: Min number of instances required at child nodes to create
+ the parent split
+ :param minInfoGain: Min info gain required to create a split
:return: DecisionTreeModel
"""
sc = data.context
@@ -195,7 +203,7 @@ class DecisionTree(object):
model = sc._jvm.PythonMLLibAPI().trainDecisionTreeModel(
dataBytes._jrdd, "regression",
0, categoricalFeaturesInfoJMap,
- impurity, maxDepth, maxBins)
+ impurity, maxDepth, maxBins, minInstancesPerNode, minInfoGain)
dataBytes.unpersist()
return DecisionTreeModel(sc, model)