diff options
author | Joseph K. Bradley <joseph.kurata.bradley@gmail.com> | 2014-08-06 22:58:59 -0700 |
---|---|---|
committer | Xiangrui Meng <meng@databricks.com> | 2014-08-06 22:58:59 -0700 |
commit | 47ccd5e71be49b723476f3ff8d5768f0f45c2ea6 (patch) | |
tree | 18b61526f97c93c4112e5a75dddefb42e0a9fafc /python | |
parent | ffd1f59a62a9dd9a4d5a7b09490b9d01ff1cd42d (diff) | |
download | spark-47ccd5e71be49b723476f3ff8d5768f0f45c2ea6.tar.gz spark-47ccd5e71be49b723476f3ff8d5768f0f45c2ea6.tar.bz2 spark-47ccd5e71be49b723476f3ff8d5768f0f45c2ea6.zip |
[SPARK-2851] [mllib] DecisionTree Python consistency update
Added 6 static train methods to match Python API, but without default arguments (but with Python default args noted in docs).
Added factory classes for Algo and Impurity, but made private[mllib].
CC: mengxr dorx Please let me know if there are other changes which would help with API consistency---thanks!
Author: Joseph K. Bradley <joseph.kurata.bradley@gmail.com>
Closes #1798 from jkbradley/dt-python-consistency and squashes the following commits:
6f7edf8 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dt-python-consistency
a0d7dbe [Joseph K. Bradley] DecisionTree: In Java-friendly train* methods, changed to use JavaRDD instead of RDD.
ee1d236 [Joseph K. Bradley] DecisionTree API updates: * Removed train() function in Python API (tree.py) ** Removed corresponding function in Scala/Java API (the ones taking basic types)
00f820e [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dt-python-consistency
fe6dbfa [Joseph K. Bradley] removed unnecessary imports
e358661 [Joseph K. Bradley] DecisionTree API change: * Added 6 static train methods to match Python API, but without default arguments (but with Python default args noted in docs).
c699850 [Joseph K. Bradley] a few doc comments
eaf84c0 [Joseph K. Bradley] Added DecisionTree static train() methods API to match Python, but without default parameters
Diffstat (limited to 'python')
-rw-r--r-- | python/pyspark/mllib/tree.py | 50 |
1 files changed, 15 insertions, 35 deletions
diff --git a/python/pyspark/mllib/tree.py b/python/pyspark/mllib/tree.py index 2518001ea0..e1a4671709 100644 --- a/python/pyspark/mllib/tree.py +++ b/python/pyspark/mllib/tree.py @@ -131,7 +131,7 @@ class DecisionTree(object): """ @staticmethod - def trainClassifier(data, numClasses, categoricalFeaturesInfo={}, + def trainClassifier(data, numClasses, categoricalFeaturesInfo, impurity="gini", maxDepth=4, maxBins=100): """ Train a DecisionTreeModel for classification. @@ -150,12 +150,20 @@ class DecisionTree(object): :param maxBins: Number of bins used for finding splits at each node. :return: DecisionTreeModel """ - return DecisionTree.train(data, "classification", numClasses, - categoricalFeaturesInfo, - impurity, maxDepth, maxBins) + sc = data.context + dataBytes = _get_unmangled_labeled_point_rdd(data) + categoricalFeaturesInfoJMap = \ + MapConverter().convert(categoricalFeaturesInfo, + sc._gateway._gateway_client) + model = sc._jvm.PythonMLLibAPI().trainDecisionTreeModel( + dataBytes._jrdd, "classification", + numClasses, categoricalFeaturesInfoJMap, + impurity, maxDepth, maxBins) + dataBytes.unpersist() + return DecisionTreeModel(sc, model) @staticmethod - def trainRegressor(data, categoricalFeaturesInfo={}, + def trainRegressor(data, categoricalFeaturesInfo, impurity="variance", maxDepth=4, maxBins=100): """ Train a DecisionTreeModel for regression. @@ -173,42 +181,14 @@ class DecisionTree(object): :param maxBins: Number of bins used for finding splits at each node. :return: DecisionTreeModel """ - return DecisionTree.train(data, "regression", 0, - categoricalFeaturesInfo, - impurity, maxDepth, maxBins) - - @staticmethod - def train(data, algo, numClasses, categoricalFeaturesInfo, - impurity, maxDepth, maxBins=100): - """ - Train a DecisionTreeModel for classification or regression. - - :param data: Training data: RDD of LabeledPoint. - For classification, labels are integers - {0,1,...,numClasses}. - For regression, labels are real numbers. - :param algo: "classification" or "regression" - :param numClasses: Number of classes for classification. - :param categoricalFeaturesInfo: Map from categorical feature index - to number of categories. - Any feature not in this map - is treated as continuous. - :param impurity: For classification: "entropy" or "gini". - For regression: "variance". - :param maxDepth: Max depth of tree. - E.g., depth 0 means 1 leaf node. - Depth 1 means 1 internal node + 2 leaf nodes. - :param maxBins: Number of bins used for finding splits at each node. - :return: DecisionTreeModel - """ sc = data.context dataBytes = _get_unmangled_labeled_point_rdd(data) categoricalFeaturesInfoJMap = \ MapConverter().convert(categoricalFeaturesInfo, sc._gateway._gateway_client) model = sc._jvm.PythonMLLibAPI().trainDecisionTreeModel( - dataBytes._jrdd, algo, - numClasses, categoricalFeaturesInfoJMap, + dataBytes._jrdd, "regression", + 0, categoricalFeaturesInfoJMap, impurity, maxDepth, maxBins) dataBytes.unpersist() return DecisionTreeModel(sc, model) |