about | summary | refs | log | tree | commit | diff
path: root/python/pyspark/mllib/tree.py
diff options
context:
space:
mode:
Diffstat (limited to 'python/pyspark/mllib/tree.py')
-rw-r--r-- python/pyspark/mllib/tree.py 50
1 files changed, 15 insertions, 35 deletions
diff --git a/python/pyspark/mllib/tree.py b/python/pyspark/mllib/tree.py
index 2518001ea0..e1a4671709 100644
--- a/python/pyspark/mllib/tree.py
+++ b/python/pyspark/mllib/tree.py
@@ -131,7 +131,7 @@ class DecisionTree(object):
"""
@staticmethod
- def trainClassifier(data, numClasses, categoricalFeaturesInfo={},
+ def trainClassifier(data, numClasses, categoricalFeaturesInfo,
impurity="gini", maxDepth=4, maxBins=100):
"""
Train a DecisionTreeModel for classification.
@@ -150,12 +150,20 @@ class DecisionTree(object):
:param maxBins: Number of bins used for finding splits at each node.
:return: DecisionTreeModel
"""
- return DecisionTree.train(data, "classification", numClasses,
- categoricalFeaturesInfo,
- impurity, maxDepth, maxBins)
+ sc = data.context
+ dataBytes = _get_unmangled_labeled_point_rdd(data)
+ categoricalFeaturesInfoJMap = \
+ MapConverter().convert(categoricalFeaturesInfo,
+ sc._gateway._gateway_client)
+ model = sc._jvm.PythonMLLibAPI().trainDecisionTreeModel(
+ dataBytes._jrdd, "classification",
+ numClasses, categoricalFeaturesInfoJMap,
+ impurity, maxDepth, maxBins)
+ dataBytes.unpersist()
+ return DecisionTreeModel(sc, model)
@staticmethod
- def trainRegressor(data, categoricalFeaturesInfo={},
+ def trainRegressor(data, categoricalFeaturesInfo,
impurity="variance", maxDepth=4, maxBins=100):
"""
Train a DecisionTreeModel for regression.
@@ -173,42 +181,14 @@ class DecisionTree(object):
:param maxBins: Number of bins used for finding splits at each node.
:return: DecisionTreeModel
"""
- return DecisionTree.train(data, "regression", 0,
- categoricalFeaturesInfo,
- impurity, maxDepth, maxBins)
-
- @staticmethod
- def train(data, algo, numClasses, categoricalFeaturesInfo,
- impurity, maxDepth, maxBins=100):
- """
- Train a DecisionTreeModel for classification or regression.
-
- :param data: Training data: RDD of LabeledPoint.
- For classification, labels are integers
- {0,1,...,numClasses}.
- For regression, labels are real numbers.
- :param algo: "classification" or "regression"
- :param numClasses: Number of classes for classification.
- :param categoricalFeaturesInfo: Map from categorical feature index
- to number of categories.
- Any feature not in this map
- is treated as continuous.
- :param impurity: For classification: "entropy" or "gini".
- For regression: "variance".
- :param maxDepth: Max depth of tree.
- E.g., depth 0 means 1 leaf node.
- Depth 1 means 1 internal node + 2 leaf nodes.
- :param maxBins: Number of bins used for finding splits at each node.
- :return: DecisionTreeModel
- """
sc = data.context
dataBytes = _get_unmangled_labeled_point_rdd(data)
categoricalFeaturesInfoJMap = \
MapConverter().convert(categoricalFeaturesInfo,
sc._gateway._gateway_client)
model = sc._jvm.PythonMLLibAPI().trainDecisionTreeModel(
- dataBytes._jrdd, algo,
- numClasses, categoricalFeaturesInfoJMap,
+ dataBytes._jrdd, "regression",
+ 0, categoricalFeaturesInfoJMap,
impurity, maxDepth, maxBins)
dataBytes.unpersist()
return DecisionTreeModel(sc, model)