diff options
Diffstat (limited to 'python/pyspark/ml/classification.py')
-rw-r--r-- | python/pyspark/ml/classification.py | 75 |
1 files changed, 11 insertions, 64 deletions
diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index 83f808efc3..22bdd1b322 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -31,7 +31,8 @@ __all__ = ['LogisticRegression', 'LogisticRegressionModel', 'DecisionTreeClassif @inherit_doc class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasMaxIter, - HasRegParam, HasTol, HasProbabilityCol, HasRawPredictionCol): + HasRegParam, HasTol, HasProbabilityCol, HasRawPredictionCol, + HasElasticNetParam, HasFitIntercept, HasStandardization, HasThresholds): """ Logistic regression. Currently, this class only supports binary classification. @@ -65,17 +66,6 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti """ # a placeholder to make it appear in the generated doc - elasticNetParam = \ - Param(Params._dummy(), "elasticNetParam", - "the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, " + - "the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.") - fitIntercept = Param(Params._dummy(), "fitIntercept", "whether to fit an intercept term.") - thresholds = Param(Params._dummy(), "thresholds", - "Thresholds in multi-class classification" + - " to adjust the probability of predicting each class." + - " Array must have length equal to the number of classes, with values >= 0." + - " The class with largest value p/t is predicted, where p is the original" + - " probability of that class and t is the class' threshold.") threshold = Param(Params._dummy(), "threshold", "Threshold in binary classification prediction, in range [0, 1]." + " If threshold and thresholds are both set, they must match.") @@ -83,40 +73,23 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti @keyword_only def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", maxIter=100, regParam=0.1, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, - threshold=0.5, thresholds=None, - probabilityCol="probability", rawPredictionCol="rawPrediction"): + threshold=0.5, thresholds=None, probabilityCol="probability", + rawPredictionCol="rawPrediction", standardization=True): """ __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ maxIter=100, regParam=0.1, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, \ - threshold=0.5, thresholds=None, \ - probabilityCol="probability", rawPredictionCol="rawPrediction") + threshold=0.5, thresholds=None, probabilityCol="probability", \ + rawPredictionCol="rawPrediction", standardization=True) If the threshold and thresholds Params are both set, they must be equivalent. """ super(LogisticRegression, self).__init__() self._java_obj = self._new_java_obj( "org.apache.spark.ml.classification.LogisticRegression", self.uid) - #: param for the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, the penalty - # is an L2 penalty. For alpha = 1, it is an L1 penalty. - self.elasticNetParam = \ - Param(self, "elasticNetParam", - "the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, " + - "the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.") - #: param for whether to fit an intercept term. - self.fitIntercept = Param(self, "fitIntercept", "whether to fit an intercept term.") #: param for threshold in binary classification, in range [0, 1]. self.threshold = Param(self, "threshold", "Threshold in binary classification prediction, in range [0, 1]." + " If threshold and thresholds are both set, they must match.") - #: param for thresholds or cutoffs in binary or multiclass classification - self.thresholds = \ - Param(self, "thresholds", - "Thresholds in multi-class classification" + - " to adjust the probability of predicting each class." + - " Array must have length equal to the number of classes, with values >= 0." + - " The class with largest value p/t is predicted, where p is the original" + - " probability of that class and t is the class' threshold.") - self._setDefault(maxIter=100, regParam=0.1, elasticNetParam=0.0, tol=1E-6, - fitIntercept=True, threshold=0.5) + self._setDefault(maxIter=100, regParam=0.1, tol=1E-6, threshold=0.5) kwargs = self.__init__._input_kwargs self.setParams(**kwargs) self._checkThresholdConsistency() @@ -124,13 +97,13 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti @keyword_only def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", maxIter=100, regParam=0.1, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, - threshold=0.5, thresholds=None, - probabilityCol="probability", rawPredictionCol="rawPrediction"): + threshold=0.5, thresholds=None, probabilityCol="probability", + rawPredictionCol="rawPrediction", standardization=True): """ setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ maxIter=100, regParam=0.1, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, \ - threshold=0.5, thresholds=None, \ - probabilityCol="probability", rawPredictionCol="rawPrediction") + threshold=0.5, thresholds=None, probabilityCol="probability", \ + rawPredictionCol="rawPrediction", standardization=True) Sets params for logistic regression. If the threshold and thresholds Params are both set, they must be equivalent. """ @@ -142,32 +115,6 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti def _create_model(self, java_model): return LogisticRegressionModel(java_model) - def setElasticNetParam(self, value): - """ - Sets the value of :py:attr:`elasticNetParam`. - """ - self._paramMap[self.elasticNetParam] = value - return self - - def getElasticNetParam(self): - """ - Gets the value of elasticNetParam or its default value. - """ - return self.getOrDefault(self.elasticNetParam) - - def setFitIntercept(self, value): - """ - Sets the value of :py:attr:`fitIntercept`. - """ - self._paramMap[self.fitIntercept] = value - return self - - def getFitIntercept(self): - """ - Gets the value of fitIntercept or its default value. - """ - return self.getOrDefault(self.fitIntercept) - def setThreshold(self, value): """ Sets the value of :py:attr:`threshold`. |