aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark/ml/regression.py
diff options
context:
space:
mode:
Diffstat (limited to 'python/pyspark/ml/regression.py')
-rw-r--r--python/pyspark/ml/regression.py16
1 files changed, 8 insertions, 8 deletions
diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py
index 7648bf1326..944e648ec8 100644
--- a/python/pyspark/ml/regression.py
+++ b/python/pyspark/ml/regression.py
@@ -35,7 +35,7 @@ __all__ = ['AFTSurvivalRegression', 'AFTSurvivalRegressionModel',
@inherit_doc
class LinearRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasMaxIter,
HasRegParam, HasTol, HasElasticNetParam, HasFitIntercept,
- HasStandardization, HasSolver):
+ HasStandardization, HasSolver, HasWeightCol):
"""
Linear regression.
@@ -50,9 +50,9 @@ class LinearRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPrediction
>>> from pyspark.mllib.linalg import Vectors
>>> df = sqlContext.createDataFrame([
- ... (1.0, Vectors.dense(1.0)),
- ... (0.0, Vectors.sparse(1, [], []))], ["label", "features"])
- >>> lr = LinearRegression(maxIter=5, regParam=0.0, solver="normal")
+ ... (1.0, 2.0, Vectors.dense(1.0)),
+ ... (0.0, 2.0, Vectors.sparse(1, [], []))], ["label", "weight", "features"])
+ >>> lr = LinearRegression(maxIter=5, regParam=0.0, solver="normal", weightCol="weight")
>>> model = lr.fit(df)
>>> test0 = sqlContext.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
>>> abs(model.transform(test0).head().prediction - (-1.0)) < 0.001
@@ -75,11 +75,11 @@ class LinearRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPrediction
@keyword_only
def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True,
- standardization=True, solver="auto"):
+ standardization=True, solver="auto", weightCol=None):
"""
__init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, \
- standardization=True, solver="auto")
+ standardization=True, solver="auto", weightCol=None)
"""
super(LinearRegression, self).__init__()
self._java_obj = self._new_java_obj(
@@ -92,11 +92,11 @@ class LinearRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPrediction
@since("1.4.0")
def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction",
maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True,
- standardization=True, solver="auto"):
+ standardization=True, solver="auto", weightCol=None):
"""
setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, \
- standardization=True, solver="auto")
+ standardization=True, solver="auto", weightCol=None)
Sets params for linear regression.
"""
kwargs = self.setParams._input_kwargs