aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark/ml/classification.py
diff options
context:
space:
mode:
authorYanbo Liang <ybliang8@gmail.com>2015-09-11 08:50:35 -0700
committerXiangrui Meng <meng@databricks.com>2015-09-11 08:50:35 -0700
commitb656e6134fc5cd27e1fe6b6ab30fd7633cab0b14 (patch)
tree10d2d556a148adab585979cc387109588c6fda43 /python/pyspark/ml/classification.py
parentc268ca4ddde2f5213b2e3985dcaaac5900aea71c (diff)
downloadspark-b656e6134fc5cd27e1fe6b6ab30fd7633cab0b14.tar.gz
spark-b656e6134fc5cd27e1fe6b6ab30fd7633cab0b14.tar.bz2
spark-b656e6134fc5cd27e1fe6b6ab30fd7633cab0b14.zip
[SPARK-10026] [ML] [PySpark] Implement some common Params for regression in PySpark
LinearRegression and LogisticRegression lack some Params in Python, and some Params are not shared classes, which means we have to write them separately for each class. These Params are listed here: ```scala HasElasticNetParam HasFitIntercept HasStandardization HasThresholds ``` Here we implement them as shared params on the Python side and bring the LinearRegression/LogisticRegression parameters to parity with the Scala ones. Author: Yanbo Liang <ybliang8@gmail.com> Closes #8508 from yanboliang/spark-10026.
Diffstat (limited to 'python/pyspark/ml/classification.py')
-rw-r--r--python/pyspark/ml/classification.py75
1 files changed, 11 insertions, 64 deletions
diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py
index 83f808efc3..22bdd1b322 100644
--- a/python/pyspark/ml/classification.py
+++ b/python/pyspark/ml/classification.py
@@ -31,7 +31,8 @@ __all__ = ['LogisticRegression', 'LogisticRegressionModel', 'DecisionTreeClassif
@inherit_doc
class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasMaxIter,
- HasRegParam, HasTol, HasProbabilityCol, HasRawPredictionCol):
+ HasRegParam, HasTol, HasProbabilityCol, HasRawPredictionCol,
+ HasElasticNetParam, HasFitIntercept, HasStandardization, HasThresholds):
"""
Logistic regression.
Currently, this class only supports binary classification.
@@ -65,17 +66,6 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti
"""
# a placeholder to make it appear in the generated doc
- elasticNetParam = \
- Param(Params._dummy(), "elasticNetParam",
- "the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, " +
- "the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.")
- fitIntercept = Param(Params._dummy(), "fitIntercept", "whether to fit an intercept term.")
- thresholds = Param(Params._dummy(), "thresholds",
- "Thresholds in multi-class classification" +
- " to adjust the probability of predicting each class." +
- " Array must have length equal to the number of classes, with values >= 0." +
- " The class with largest value p/t is predicted, where p is the original" +
- " probability of that class and t is the class' threshold.")
threshold = Param(Params._dummy(), "threshold",
"Threshold in binary classification prediction, in range [0, 1]." +
" If threshold and thresholds are both set, they must match.")
@@ -83,40 +73,23 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti
@keyword_only
def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
maxIter=100, regParam=0.1, elasticNetParam=0.0, tol=1e-6, fitIntercept=True,
- threshold=0.5, thresholds=None,
- probabilityCol="probability", rawPredictionCol="rawPrediction"):
+ threshold=0.5, thresholds=None, probabilityCol="probability",
+ rawPredictionCol="rawPrediction", standardization=True):
"""
__init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
maxIter=100, regParam=0.1, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, \
- threshold=0.5, thresholds=None, \
- probabilityCol="probability", rawPredictionCol="rawPrediction")
+ threshold=0.5, thresholds=None, probabilityCol="probability", \
+ rawPredictionCol="rawPrediction", standardization=True)
If the threshold and thresholds Params are both set, they must be equivalent.
"""
super(LogisticRegression, self).__init__()
self._java_obj = self._new_java_obj(
"org.apache.spark.ml.classification.LogisticRegression", self.uid)
- #: param for the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, the penalty
- # is an L2 penalty. For alpha = 1, it is an L1 penalty.
- self.elasticNetParam = \
- Param(self, "elasticNetParam",
- "the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, " +
- "the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.")
- #: param for whether to fit an intercept term.
- self.fitIntercept = Param(self, "fitIntercept", "whether to fit an intercept term.")
#: param for threshold in binary classification, in range [0, 1].
self.threshold = Param(self, "threshold",
"Threshold in binary classification prediction, in range [0, 1]." +
" If threshold and thresholds are both set, they must match.")
- #: param for thresholds or cutoffs in binary or multiclass classification
- self.thresholds = \
- Param(self, "thresholds",
- "Thresholds in multi-class classification" +
- " to adjust the probability of predicting each class." +
- " Array must have length equal to the number of classes, with values >= 0." +
- " The class with largest value p/t is predicted, where p is the original" +
- " probability of that class and t is the class' threshold.")
- self._setDefault(maxIter=100, regParam=0.1, elasticNetParam=0.0, tol=1E-6,
- fitIntercept=True, threshold=0.5)
+ self._setDefault(maxIter=100, regParam=0.1, tol=1E-6, threshold=0.5)
kwargs = self.__init__._input_kwargs
self.setParams(**kwargs)
self._checkThresholdConsistency()
@@ -124,13 +97,13 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti
@keyword_only
def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction",
maxIter=100, regParam=0.1, elasticNetParam=0.0, tol=1e-6, fitIntercept=True,
- threshold=0.5, thresholds=None,
- probabilityCol="probability", rawPredictionCol="rawPrediction"):
+ threshold=0.5, thresholds=None, probabilityCol="probability",
+ rawPredictionCol="rawPrediction", standardization=True):
"""
setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
maxIter=100, regParam=0.1, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, \
- threshold=0.5, thresholds=None, \
- probabilityCol="probability", rawPredictionCol="rawPrediction")
+ threshold=0.5, thresholds=None, probabilityCol="probability", \
+ rawPredictionCol="rawPrediction", standardization=True)
Sets params for logistic regression.
If the threshold and thresholds Params are both set, they must be equivalent.
"""
@@ -142,32 +115,6 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti
def _create_model(self, java_model):
return LogisticRegressionModel(java_model)
- def setElasticNetParam(self, value):
- """
- Sets the value of :py:attr:`elasticNetParam`.
- """
- self._paramMap[self.elasticNetParam] = value
- return self
-
- def getElasticNetParam(self):
- """
- Gets the value of elasticNetParam or its default value.
- """
- return self.getOrDefault(self.elasticNetParam)
-
- def setFitIntercept(self, value):
- """
- Sets the value of :py:attr:`fitIntercept`.
- """
- self._paramMap[self.fitIntercept] = value
- return self
-
- def getFitIntercept(self):
- """
- Gets the value of fitIntercept or its default value.
- """
- return self.getOrDefault(self.fitIntercept)
-
def setThreshold(self, value):
"""
Sets the value of :py:attr:`threshold`.