aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark/ml/classification.py
diff options
context:
space:
mode:
authorYanbo Liang <ybliang8@gmail.com>2015-11-18 13:32:06 -0800
committerXiangrui Meng <meng@databricks.com>2015-11-18 13:32:06 -0800
commit603a721c21488e17c15c45ce1de893e6b3d02274 (patch)
tree4f9d763ea38fcc7419b0be4b5a396e47ee18b416 /python/pyspark/ml/classification.py
parente222d758499ad2609046cc1a2cc8afb45c5bccbb (diff)
downloadspark-603a721c21488e17c15c45ce1de893e6b3d02274.tar.gz
spark-603a721c21488e17c15c45ce1de893e6b3d02274.tar.bz2
spark-603a721c21488e17c15c45ce1de893e6b3d02274.zip
[SPARK-11820][ML][PYSPARK] PySpark LiR & LoR should support weightCol
[SPARK-7685](https://issues.apache.org/jira/browse/SPARK-7685) and [SPARK-9642](https://issues.apache.org/jira/browse/SPARK-9642) have already supported setting weight column for ```LogisticRegression``` and ```LinearRegression```. It's a very important feature, PySpark should also support. mengxr Author: Yanbo Liang <ybliang8@gmail.com> Closes #9811 from yanboliang/spark-11820.
Diffstat (limited to 'python/pyspark/ml/classification.py')
-rw-r--r--python/pyspark/ml/classification.py17
1 files changed, 9 insertions, 8 deletions
diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py
index 603f2c7f79..4a2982e204 100644
--- a/python/pyspark/ml/classification.py
+++ b/python/pyspark/ml/classification.py
@@ -36,7 +36,8 @@ __all__ = ['LogisticRegression', 'LogisticRegressionModel', 'DecisionTreeClassif
@inherit_doc
class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasMaxIter,
HasRegParam, HasTol, HasProbabilityCol, HasRawPredictionCol,
- HasElasticNetParam, HasFitIntercept, HasStandardization, HasThresholds):
+ HasElasticNetParam, HasFitIntercept, HasStandardization, HasThresholds,
+ HasWeightCol):
"""
Logistic regression.
Currently, this class only supports binary classification.
@@ -44,9 +45,9 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti
>>> from pyspark.sql import Row
>>> from pyspark.mllib.linalg import Vectors
>>> df = sc.parallelize([
- ... Row(label=1.0, features=Vectors.dense(1.0)),
- ... Row(label=0.0, features=Vectors.sparse(1, [], []))]).toDF()
- >>> lr = LogisticRegression(maxIter=5, regParam=0.01)
+ ... Row(label=1.0, weight=2.0, features=Vectors.dense(1.0)),
+ ... Row(label=0.0, weight=2.0, features=Vectors.sparse(1, [], []))]).toDF()
+ >>> lr = LogisticRegression(maxIter=5, regParam=0.01, weightCol="weight")
>>> model = lr.fit(df)
>>> model.weights
DenseVector([5.5...])
@@ -80,12 +81,12 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti
def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
maxIter=100, regParam=0.1, elasticNetParam=0.0, tol=1e-6, fitIntercept=True,
threshold=0.5, thresholds=None, probabilityCol="probability",
- rawPredictionCol="rawPrediction", standardization=True):
+ rawPredictionCol="rawPrediction", standardization=True, weightCol=None):
"""
__init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
maxIter=100, regParam=0.1, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, \
threshold=0.5, thresholds=None, probabilityCol="probability", \
- rawPredictionCol="rawPrediction", standardization=True)
+ rawPredictionCol="rawPrediction", standardization=True, weightCol=None)
If the threshold and thresholds Params are both set, they must be equivalent.
"""
super(LogisticRegression, self).__init__()
@@ -105,12 +106,12 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti
def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction",
maxIter=100, regParam=0.1, elasticNetParam=0.0, tol=1e-6, fitIntercept=True,
threshold=0.5, thresholds=None, probabilityCol="probability",
- rawPredictionCol="rawPrediction", standardization=True):
+ rawPredictionCol="rawPrediction", standardization=True, weightCol=None):
"""
setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
maxIter=100, regParam=0.1, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, \
threshold=0.5, thresholds=None, probabilityCol="probability", \
- rawPredictionCol="rawPrediction", standardization=True)
+ rawPredictionCol="rawPrediction", standardization=True, weightCol=None)
Sets params for logistic regression.
If the threshold and thresholds Params are both set, they must be equivalent.
"""