aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark/ml/classification.py
diff options
context:
space:
mode:
authorvectorijk <jiangkai@gmail.com>2015-10-27 13:55:03 -0700
committerXiangrui Meng <meng@databricks.com>2015-10-27 13:55:03 -0700
commit9dba5fb2b59174cefde5b62a5c892fe5925bea38 (patch)
treec94ad0820a14a5b120354d76d71ed4cb0cf0fa49 /python/pyspark/ml/classification.py
parent5a5f65905a202e59bc85170b01c57a883718ddf6 (diff)
downloadspark-9dba5fb2b59174cefde5b62a5c892fe5925bea38.tar.gz
spark-9dba5fb2b59174cefde5b62a5c892fe5925bea38.tar.bz2
spark-9dba5fb2b59174cefde5b62a5c892fe5925bea38.zip
[SPARK-10024][PYSPARK] Python API RF and GBT related params clear up
implement {RandomForest, GBT, TreeEnsemble, TreeClassifier, TreeRegressor}Params for Python API in pyspark/ml/{classification, regression}.py Author: vectorijk <jiangkai@gmail.com> Closes #9233 from vectorijk/spark-10024.
Diffstat (limited to 'python/pyspark/ml/classification.py')
-rw-r--r--python/pyspark/ml/classification.py182
1 files changed, 31 insertions, 151 deletions
diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py
index 88815e561f..4cbe7fbd48 100644
--- a/python/pyspark/ml/classification.py
+++ b/python/pyspark/ml/classification.py
@@ -19,7 +19,7 @@ from pyspark.ml.util import keyword_only
from pyspark.ml.wrapper import JavaEstimator, JavaModel
from pyspark.ml.param.shared import *
from pyspark.ml.regression import (
- RandomForestParams, DecisionTreeModel, TreeEnsembleModels)
+ RandomForestParams, TreeEnsembleParams, DecisionTreeModel, TreeEnsembleModels)
from pyspark.mllib.common import inherit_doc
@@ -205,8 +205,34 @@ class TreeClassifierParams(object):
"""
supportedImpurities = ["entropy", "gini"]
+ # a placeholder to make it appear in the generated doc
+ impurity = Param(Params._dummy(), "impurity",
+ "Criterion used for information gain calculation (case-insensitive). " +
+ "Supported options: " +
+ ", ".join(supportedImpurities))
+
+ def __init__(self):
+ super(TreeClassifierParams, self).__init__()
+ #: param for Criterion used for information gain calculation (case-insensitive).
+ self.impurity = Param(self, "impurity", "Criterion used for information " +
+ "gain calculation (case-insensitive). Supported options: " +
+ ", ".join(self.supportedImpurities))
+
+ def setImpurity(self, value):
+ """
+ Sets the value of :py:attr:`impurity`.
+ """
+ self._paramMap[self.impurity] = value
+ return self
-class GBTParams(object):
+ def getImpurity(self):
+ """
+ Gets the value of impurity or its default value.
+ """
+ return self.getOrDefault(self.impurity)
+
+
+class GBTParams(TreeEnsembleParams):
"""
Private class to track supported GBT params.
"""
@@ -216,7 +242,7 @@ class GBTParams(object):
@inherit_doc
class DecisionTreeClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol,
HasProbabilityCol, HasRawPredictionCol, DecisionTreeParams,
- HasCheckpointInterval):
+ TreeClassifierParams, HasCheckpointInterval):
"""
`http://en.wikipedia.org/wiki/Decision_tree_learning Decision tree`
learning algorithm for classification.
@@ -250,11 +276,6 @@ class DecisionTreeClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred
1.0
"""
- # a placeholder to make it appear in the generated doc
- impurity = Param(Params._dummy(), "impurity",
- "Criterion used for information gain calculation (case-insensitive). " +
- "Supported options: " + ", ".join(TreeClassifierParams.supportedImpurities))
-
@keyword_only
def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
probabilityCol="probability", rawPredictionCol="rawPrediction",
@@ -269,11 +290,6 @@ class DecisionTreeClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred
super(DecisionTreeClassifier, self).__init__()
self._java_obj = self._new_java_obj(
"org.apache.spark.ml.classification.DecisionTreeClassifier", self.uid)
- #: param for Criterion used for information gain calculation (case-insensitive).
- self.impurity = \
- Param(self, "impurity",
- "Criterion used for information gain calculation (case-insensitive). " +
- "Supported options: " + ", ".join(TreeClassifierParams.supportedImpurities))
self._setDefault(maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10,
impurity="gini")
@@ -299,19 +315,6 @@ class DecisionTreeClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred
def _create_model(self, java_model):
return DecisionTreeClassificationModel(java_model)
- def setImpurity(self, value):
- """
- Sets the value of :py:attr:`impurity`.
- """
- self._paramMap[self.impurity] = value
- return self
-
- def getImpurity(self):
- """
- Gets the value of impurity or its default value.
- """
- return self.getOrDefault(self.impurity)
-
@inherit_doc
class DecisionTreeClassificationModel(DecisionTreeModel):
@@ -323,7 +326,7 @@ class DecisionTreeClassificationModel(DecisionTreeModel):
@inherit_doc
class RandomForestClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasSeed,
HasRawPredictionCol, HasProbabilityCol,
- DecisionTreeParams, HasCheckpointInterval):
+ RandomForestParams, TreeClassifierParams, HasCheckpointInterval):
"""
`http://en.wikipedia.org/wiki/Random_forest Random Forest`
learning algorithm for classification.
@@ -357,19 +360,6 @@ class RandomForestClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred
1.0
"""
- # a placeholder to make it appear in the generated doc
- impurity = Param(Params._dummy(), "impurity",
- "Criterion used for information gain calculation (case-insensitive). " +
- "Supported options: " + ", ".join(TreeClassifierParams.supportedImpurities))
- subsamplingRate = Param(Params._dummy(), "subsamplingRate",
- "Fraction of the training data used for learning each decision tree, " +
- "in range (0, 1].")
- numTrees = Param(Params._dummy(), "numTrees", "Number of trees to train (>= 1)")
- featureSubsetStrategy = \
- Param(Params._dummy(), "featureSubsetStrategy",
- "The number of features to consider for splits at each tree node. Supported " +
- "options: " + ", ".join(RandomForestParams.supportedFeatureSubsetStrategies))
-
@keyword_only
def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
probabilityCol="probability", rawPredictionCol="rawPrediction",
@@ -386,23 +376,6 @@ class RandomForestClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred
super(RandomForestClassifier, self).__init__()
self._java_obj = self._new_java_obj(
"org.apache.spark.ml.classification.RandomForestClassifier", self.uid)
- #: param for Criterion used for information gain calculation (case-insensitive).
- self.impurity = \
- Param(self, "impurity",
- "Criterion used for information gain calculation (case-insensitive). " +
- "Supported options: " + ", ".join(TreeClassifierParams.supportedImpurities))
- #: param for Fraction of the training data used for learning each decision tree,
- # in range (0, 1]
- self.subsamplingRate = Param(self, "subsamplingRate",
- "Fraction of the training data used for learning each " +
- "decision tree, in range (0, 1].")
- #: param for Number of trees to train (>= 1)
- self.numTrees = Param(self, "numTrees", "Number of trees to train (>= 1)")
- #: param for The number of features to consider for splits at each tree node
- self.featureSubsetStrategy = \
- Param(self, "featureSubsetStrategy",
- "The number of features to consider for splits at each tree node. Supported " +
- "options: " + ", ".join(RandomForestParams.supportedFeatureSubsetStrategies))
self._setDefault(maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, seed=None,
impurity="gini", numTrees=20, featureSubsetStrategy="auto")
@@ -429,58 +402,6 @@ class RandomForestClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred
def _create_model(self, java_model):
return RandomForestClassificationModel(java_model)
- def setImpurity(self, value):
- """
- Sets the value of :py:attr:`impurity`.
- """
- self._paramMap[self.impurity] = value
- return self
-
- def getImpurity(self):
- """
- Gets the value of impurity or its default value.
- """
- return self.getOrDefault(self.impurity)
-
- def setSubsamplingRate(self, value):
- """
- Sets the value of :py:attr:`subsamplingRate`.
- """
- self._paramMap[self.subsamplingRate] = value
- return self
-
- def getSubsamplingRate(self):
- """
- Gets the value of subsamplingRate or its default value.
- """
- return self.getOrDefault(self.subsamplingRate)
-
- def setNumTrees(self, value):
- """
- Sets the value of :py:attr:`numTrees`.
- """
- self._paramMap[self.numTrees] = value
- return self
-
- def getNumTrees(self):
- """
- Gets the value of numTrees or its default value.
- """
- return self.getOrDefault(self.numTrees)
-
- def setFeatureSubsetStrategy(self, value):
- """
- Sets the value of :py:attr:`featureSubsetStrategy`.
- """
- self._paramMap[self.featureSubsetStrategy] = value
- return self
-
- def getFeatureSubsetStrategy(self):
- """
- Gets the value of featureSubsetStrategy or its default value.
- """
- return self.getOrDefault(self.featureSubsetStrategy)
-
class RandomForestClassificationModel(TreeEnsembleModels):
"""
@@ -490,7 +411,7 @@ class RandomForestClassificationModel(TreeEnsembleModels):
@inherit_doc
class GBTClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasMaxIter,
- DecisionTreeParams, HasCheckpointInterval):
+ GBTParams, HasCheckpointInterval, HasStepSize, HasSeed):
"""
`http://en.wikipedia.org/wiki/Gradient_boosting Gradient-Boosted Trees (GBTs)`
learning algorithm for classification.
@@ -522,12 +443,6 @@ class GBTClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol
lossType = Param(Params._dummy(), "lossType",
"Loss function which GBT tries to minimize (case-insensitive). " +
"Supported options: " + ", ".join(GBTParams.supportedLossTypes))
- subsamplingRate = Param(Params._dummy(), "subsamplingRate",
- "Fraction of the training data used for learning each decision tree, " +
- "in range (0, 1].")
- stepSize = Param(Params._dummy(), "stepSize",
- "Step size (a.k.a. learning rate) in interval (0, 1] for shrinking the " +
- "contribution of each estimator")
@keyword_only
def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
@@ -547,15 +462,6 @@ class GBTClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol
self.lossType = Param(self, "lossType",
"Loss function which GBT tries to minimize (case-insensitive). " +
"Supported options: " + ", ".join(GBTParams.supportedLossTypes))
- #: Fraction of the training data used for learning each decision tree, in range (0, 1].
- self.subsamplingRate = Param(self, "subsamplingRate",
- "Fraction of the training data used for learning each " +
- "decision tree, in range (0, 1].")
- #: Step size (a.k.a. learning rate) in interval (0, 1] for shrinking the contribution of
- # each estimator
- self.stepSize = Param(self, "stepSize",
- "Step size (a.k.a. learning rate) in interval (0, 1] for shrinking " +
- "the contribution of each estimator")
self._setDefault(maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10,
lossType="logistic", maxIter=20, stepSize=0.1)
@@ -593,32 +499,6 @@ class GBTClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol
"""
return self.getOrDefault(self.lossType)
- def setSubsamplingRate(self, value):
- """
- Sets the value of :py:attr:`subsamplingRate`.
- """
- self._paramMap[self.subsamplingRate] = value
- return self
-
- def getSubsamplingRate(self):
- """
- Gets the value of subsamplingRate or its default value.
- """
- return self.getOrDefault(self.subsamplingRate)
-
- def setStepSize(self, value):
- """
- Sets the value of :py:attr:`stepSize`.
- """
- self._paramMap[self.stepSize] = value
- return self
-
- def getStepSize(self):
- """
- Gets the value of stepSize or its default value.
- """
- return self.getOrDefault(self.stepSize)
-
class GBTClassificationModel(TreeEnsembleModels):
"""