about summary refs log tree commit diff
path: root/python
diff options
context:
space:
mode:
Diffstat (limited to 'python')
-rw-r--r-- python/pyspark/mllib/tree.py 16
1 files changed, 12 insertions, 4 deletions
diff --git a/python/pyspark/mllib/tree.py b/python/pyspark/mllib/tree.py
index ccc000ac70..5b13ab682b 100644
--- a/python/pyspark/mllib/tree.py
+++ b/python/pyspark/mllib/tree.py
@@ -138,7 +138,8 @@ class DecisionTree(object):
@staticmethod
def trainClassifier(data, numClasses, categoricalFeaturesInfo,
- impurity="gini", maxDepth=5, maxBins=32):
+ impurity="gini", maxDepth=5, maxBins=32, minInstancesPerNode=1,
+ minInfoGain=0.0):
"""
Train a DecisionTreeModel for classification.
@@ -154,6 +155,9 @@ class DecisionTree(object):
E.g., depth 0 means 1 leaf node.
Depth 1 means 1 internal node + 2 leaf nodes.
:param maxBins: Number of bins used for finding splits at each node.
+ :param minInstancesPerNode: Min number of instances required at child nodes to create
+ the parent split
+ :param minInfoGain: Min info gain required to create a split
:return: DecisionTreeModel
"""
sc = data.context
@@ -164,13 +168,14 @@ class DecisionTree(object):
model = sc._jvm.PythonMLLibAPI().trainDecisionTreeModel(
dataBytes._jrdd, "classification",
numClasses, categoricalFeaturesInfoJMap,
- impurity, maxDepth, maxBins)
+ impurity, maxDepth, maxBins, minInstancesPerNode, minInfoGain)
dataBytes.unpersist()
return DecisionTreeModel(sc, model)
@staticmethod
def trainRegressor(data, categoricalFeaturesInfo,
- impurity="variance", maxDepth=5, maxBins=32):
+ impurity="variance", maxDepth=5, maxBins=32, minInstancesPerNode=1,
+ minInfoGain=0.0):
"""
Train a DecisionTreeModel for regression.
@@ -185,6 +190,9 @@ class DecisionTree(object):
E.g., depth 0 means 1 leaf node.
Depth 1 means 1 internal node + 2 leaf nodes.
:param maxBins: Number of bins used for finding splits at each node.
+ :param minInstancesPerNode: Min number of instances required at child nodes to create
+ the parent split
+ :param minInfoGain: Min info gain required to create a split
:return: DecisionTreeModel
"""
sc = data.context
@@ -195,7 +203,7 @@ class DecisionTree(object):
model = sc._jvm.PythonMLLibAPI().trainDecisionTreeModel(
dataBytes._jrdd, "regression",
0, categoricalFeaturesInfoJMap,
- impurity, maxDepth, maxBins)
+ impurity, maxDepth, maxBins, minInstancesPerNode, minInfoGain)
dataBytes.unpersist()
return DecisionTreeModel(sc, model)