aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark/ml/classification.py
diff options
context:
space:
mode:
Diffstat (limited to 'python/pyspark/ml/classification.py')
-rw-r--r--python/pyspark/ml/classification.py75
1 files changed, 11 insertions, 64 deletions
diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py
index 83f808efc3..22bdd1b322 100644
--- a/python/pyspark/ml/classification.py
+++ b/python/pyspark/ml/classification.py
@@ -31,7 +31,8 @@ __all__ = ['LogisticRegression', 'LogisticRegressionModel', 'DecisionTreeClassif
@inherit_doc
class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasMaxIter,
- HasRegParam, HasTol, HasProbabilityCol, HasRawPredictionCol):
+ HasRegParam, HasTol, HasProbabilityCol, HasRawPredictionCol,
+ HasElasticNetParam, HasFitIntercept, HasStandardization, HasThresholds):
"""
Logistic regression.
Currently, this class only supports binary classification.
@@ -65,17 +66,6 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti
"""
# a placeholder to make it appear in the generated doc
- elasticNetParam = \
- Param(Params._dummy(), "elasticNetParam",
- "the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, " +
- "the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.")
- fitIntercept = Param(Params._dummy(), "fitIntercept", "whether to fit an intercept term.")
- thresholds = Param(Params._dummy(), "thresholds",
- "Thresholds in multi-class classification" +
- " to adjust the probability of predicting each class." +
- " Array must have length equal to the number of classes, with values >= 0." +
- " The class with largest value p/t is predicted, where p is the original" +
- " probability of that class and t is the class' threshold.")
threshold = Param(Params._dummy(), "threshold",
"Threshold in binary classification prediction, in range [0, 1]." +
" If threshold and thresholds are both set, they must match.")
@@ -83,40 +73,23 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti
@keyword_only
def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
maxIter=100, regParam=0.1, elasticNetParam=0.0, tol=1e-6, fitIntercept=True,
- threshold=0.5, thresholds=None,
- probabilityCol="probability", rawPredictionCol="rawPrediction"):
+ threshold=0.5, thresholds=None, probabilityCol="probability",
+ rawPredictionCol="rawPrediction", standardization=True):
"""
__init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
maxIter=100, regParam=0.1, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, \
- threshold=0.5, thresholds=None, \
- probabilityCol="probability", rawPredictionCol="rawPrediction")
+ threshold=0.5, thresholds=None, probabilityCol="probability", \
+ rawPredictionCol="rawPrediction", standardization=True)
If the threshold and thresholds Params are both set, they must be equivalent.
"""
super(LogisticRegression, self).__init__()
self._java_obj = self._new_java_obj(
"org.apache.spark.ml.classification.LogisticRegression", self.uid)
- #: param for the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, the penalty
- # is an L2 penalty. For alpha = 1, it is an L1 penalty.
- self.elasticNetParam = \
- Param(self, "elasticNetParam",
- "the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, " +
- "the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.")
- #: param for whether to fit an intercept term.
- self.fitIntercept = Param(self, "fitIntercept", "whether to fit an intercept term.")
#: param for threshold in binary classification, in range [0, 1].
self.threshold = Param(self, "threshold",
"Threshold in binary classification prediction, in range [0, 1]." +
" If threshold and thresholds are both set, they must match.")
- #: param for thresholds or cutoffs in binary or multiclass classification
- self.thresholds = \
- Param(self, "thresholds",
- "Thresholds in multi-class classification" +
- " to adjust the probability of predicting each class." +
- " Array must have length equal to the number of classes, with values >= 0." +
- " The class with largest value p/t is predicted, where p is the original" +
- " probability of that class and t is the class' threshold.")
- self._setDefault(maxIter=100, regParam=0.1, elasticNetParam=0.0, tol=1E-6,
- fitIntercept=True, threshold=0.5)
+ self._setDefault(maxIter=100, regParam=0.1, tol=1E-6, threshold=0.5)
kwargs = self.__init__._input_kwargs
self.setParams(**kwargs)
self._checkThresholdConsistency()
@@ -124,13 +97,13 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti
@keyword_only
def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction",
maxIter=100, regParam=0.1, elasticNetParam=0.0, tol=1e-6, fitIntercept=True,
- threshold=0.5, thresholds=None,
- probabilityCol="probability", rawPredictionCol="rawPrediction"):
+ threshold=0.5, thresholds=None, probabilityCol="probability",
+ rawPredictionCol="rawPrediction", standardization=True):
"""
setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
maxIter=100, regParam=0.1, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, \
- threshold=0.5, thresholds=None, \
- probabilityCol="probability", rawPredictionCol="rawPrediction")
+ threshold=0.5, thresholds=None, probabilityCol="probability", \
+ rawPredictionCol="rawPrediction", standardization=True)
Sets params for logistic regression.
If the threshold and thresholds Params are both set, they must be equivalent.
"""
@@ -142,32 +115,6 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti
def _create_model(self, java_model):
return LogisticRegressionModel(java_model)
- def setElasticNetParam(self, value):
- """
- Sets the value of :py:attr:`elasticNetParam`.
- """
- self._paramMap[self.elasticNetParam] = value
- return self
-
- def getElasticNetParam(self):
- """
- Gets the value of elasticNetParam or its default value.
- """
- return self.getOrDefault(self.elasticNetParam)
-
- def setFitIntercept(self, value):
- """
- Sets the value of :py:attr:`fitIntercept`.
- """
- self._paramMap[self.fitIntercept] = value
- return self
-
- def getFitIntercept(self):
- """
- Gets the value of fitIntercept or its default value.
- """
- return self.getOrDefault(self.fitIntercept)
-
def setThreshold(self, value):
"""
Sets the value of :py:attr:`threshold`.