aboutsummaryrefslogtreecommitdiff
path: root/python
diff options
context:
space:
mode:
authorJoseph K. Bradley <joseph.kurata.bradley@gmail.com>2014-08-06 22:58:59 -0700
committerXiangrui Meng <meng@databricks.com>2014-08-06 22:58:59 -0700
commit47ccd5e71be49b723476f3ff8d5768f0f45c2ea6 (patch)
tree18b61526f97c93c4112e5a75dddefb42e0a9fafc /python
parentffd1f59a62a9dd9a4d5a7b09490b9d01ff1cd42d (diff)
downloadspark-47ccd5e71be49b723476f3ff8d5768f0f45c2ea6.tar.gz
spark-47ccd5e71be49b723476f3ff8d5768f0f45c2ea6.tar.bz2
spark-47ccd5e71be49b723476f3ff8d5768f0f45c2ea6.zip
[SPARK-2851] [mllib] DecisionTree Python consistency update
Added 6 static train methods to match Python API, but without default arguments (but with Python default args noted in docs). Added factory classes for Algo and Impurity, but made private[mllib]. CC: mengxr dorx Please let me know if there are other changes which would help with API consistency---thanks! Author: Joseph K. Bradley <joseph.kurata.bradley@gmail.com> Closes #1798 from jkbradley/dt-python-consistency and squashes the following commits: 6f7edf8 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dt-python-consistency a0d7dbe [Joseph K. Bradley] DecisionTree: In Java-friendly train* methods, changed to use JavaRDD instead of RDD. ee1d236 [Joseph K. Bradley] DecisionTree API updates: * Removed train() function in Python API (tree.py) ** Removed corresponding function in Scala/Java API (the ones taking basic types) 00f820e [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dt-python-consistency fe6dbfa [Joseph K. Bradley] removed unnecessary imports e358661 [Joseph K. Bradley] DecisionTree API change: * Added 6 static train methods to match Python API, but without default arguments (but with Python default args noted in docs). c699850 [Joseph K. Bradley] a few doc comments eaf84c0 [Joseph K. Bradley] Added DecisionTree static train() methods API to match Python, but without default parameters
Diffstat (limited to 'python')
-rw-r--r--python/pyspark/mllib/tree.py50
1 files changed, 15 insertions, 35 deletions
diff --git a/python/pyspark/mllib/tree.py b/python/pyspark/mllib/tree.py
index 2518001ea0..e1a4671709 100644
--- a/python/pyspark/mllib/tree.py
+++ b/python/pyspark/mllib/tree.py
@@ -131,7 +131,7 @@ class DecisionTree(object):
"""
@staticmethod
- def trainClassifier(data, numClasses, categoricalFeaturesInfo={},
+ def trainClassifier(data, numClasses, categoricalFeaturesInfo,
impurity="gini", maxDepth=4, maxBins=100):
"""
Train a DecisionTreeModel for classification.
@@ -150,12 +150,20 @@ class DecisionTree(object):
:param maxBins: Number of bins used for finding splits at each node.
:return: DecisionTreeModel
"""
- return DecisionTree.train(data, "classification", numClasses,
- categoricalFeaturesInfo,
- impurity, maxDepth, maxBins)
+ sc = data.context
+ dataBytes = _get_unmangled_labeled_point_rdd(data)
+ categoricalFeaturesInfoJMap = \
+ MapConverter().convert(categoricalFeaturesInfo,
+ sc._gateway._gateway_client)
+ model = sc._jvm.PythonMLLibAPI().trainDecisionTreeModel(
+ dataBytes._jrdd, "classification",
+ numClasses, categoricalFeaturesInfoJMap,
+ impurity, maxDepth, maxBins)
+ dataBytes.unpersist()
+ return DecisionTreeModel(sc, model)
@staticmethod
- def trainRegressor(data, categoricalFeaturesInfo={},
+ def trainRegressor(data, categoricalFeaturesInfo,
impurity="variance", maxDepth=4, maxBins=100):
"""
Train a DecisionTreeModel for regression.
@@ -173,42 +181,14 @@ class DecisionTree(object):
:param maxBins: Number of bins used for finding splits at each node.
:return: DecisionTreeModel
"""
- return DecisionTree.train(data, "regression", 0,
- categoricalFeaturesInfo,
- impurity, maxDepth, maxBins)
-
- @staticmethod
- def train(data, algo, numClasses, categoricalFeaturesInfo,
- impurity, maxDepth, maxBins=100):
- """
- Train a DecisionTreeModel for classification or regression.
-
- :param data: Training data: RDD of LabeledPoint.
- For classification, labels are integers
- {0,1,...,numClasses}.
- For regression, labels are real numbers.
- :param algo: "classification" or "regression"
- :param numClasses: Number of classes for classification.
- :param categoricalFeaturesInfo: Map from categorical feature index
- to number of categories.
- Any feature not in this map
- is treated as continuous.
- :param impurity: For classification: "entropy" or "gini".
- For regression: "variance".
- :param maxDepth: Max depth of tree.
- E.g., depth 0 means 1 leaf node.
- Depth 1 means 1 internal node + 2 leaf nodes.
- :param maxBins: Number of bins used for finding splits at each node.
- :return: DecisionTreeModel
- """
sc = data.context
dataBytes = _get_unmangled_labeled_point_rdd(data)
categoricalFeaturesInfoJMap = \
MapConverter().convert(categoricalFeaturesInfo,
sc._gateway._gateway_client)
model = sc._jvm.PythonMLLibAPI().trainDecisionTreeModel(
- dataBytes._jrdd, algo,
- numClasses, categoricalFeaturesInfoJMap,
+ dataBytes._jrdd, "regression",
+ 0, categoricalFeaturesInfoJMap,
impurity, maxDepth, maxBins)
dataBytes.unpersist()
return DecisionTreeModel(sc, model)