aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark
diff options
context:
space:
mode:
authorYanbo Liang <ybliang8@gmail.com>2017-03-08 02:09:36 -0800
committerYanbo Liang <ybliang8@gmail.com>2017-03-08 02:09:36 -0800
commit81303f7ca7808d51229411dce8feeed8c23dbe15 (patch)
treee7f6ebae38cfbd2877c933bafa63b21d766df531 /python/pyspark
parent1fa58868bc6635ff2119264665bd3d00b4b1253a (diff)
downloadspark-81303f7ca7808d51229411dce8feeed8c23dbe15.tar.gz
spark-81303f7ca7808d51229411dce8feeed8c23dbe15.tar.bz2
spark-81303f7ca7808d51229411dce8feeed8c23dbe15.zip
[SPARK-19806][ML][PYSPARK] PySpark GeneralizedLinearRegression supports tweedie distribution.
## What changes were proposed in this pull request? PySpark ```GeneralizedLinearRegression``` supports tweedie distribution. ## How was this patch tested? Add unit tests. Author: Yanbo Liang <ybliang8@gmail.com> Closes #17146 from yanboliang/spark-19806.
Diffstat (limited to 'python/pyspark')
-rw-r--r--python/pyspark/ml/regression.py61
-rwxr-xr-xpython/pyspark/ml/tests.py20
2 files changed, 73 insertions, 8 deletions
diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py
index b199bf282e..3c3fcc8d9b 100644
--- a/python/pyspark/ml/regression.py
+++ b/python/pyspark/ml/regression.py
@@ -1294,8 +1294,8 @@ class GeneralizedLinearRegression(JavaEstimator, HasLabelCol, HasFeaturesCol, Ha
Fit a Generalized Linear Model specified by giving a symbolic description of the linear
predictor (link function) and a description of the error distribution (family). It supports
- "gaussian", "binomial", "poisson" and "gamma" as family. Valid link functions for each family
- is listed below. The first link function of each family is the default one.
+ "gaussian", "binomial", "poisson", "gamma" and "tweedie" as family. Valid link functions for
+ each family is listed below. The first link function of each family is the default one.
* "gaussian" -> "identity", "log", "inverse"
@@ -1305,6 +1305,9 @@ class GeneralizedLinearRegression(JavaEstimator, HasLabelCol, HasFeaturesCol, Ha
* "gamma" -> "inverse", "identity", "log"
+ * "tweedie" -> power link function specified through "linkPower". \
+ The default link power in the tweedie family is 1 - variancePower.
+
.. seealso:: `GLM <https://en.wikipedia.org/wiki/Generalized_linear_model>`_
>>> from pyspark.ml.linalg import Vectors
@@ -1344,7 +1347,7 @@ class GeneralizedLinearRegression(JavaEstimator, HasLabelCol, HasFeaturesCol, Ha
family = Param(Params._dummy(), "family", "The name of family which is a description of " +
"the error distribution to be used in the model. Supported options: " +
- "gaussian (default), binomial, poisson and gamma.",
+ "gaussian (default), binomial, poisson, gamma and tweedie.",
typeConverter=TypeConverters.toString)
link = Param(Params._dummy(), "link", "The name of link function which provides the " +
"relationship between the linear predictor and the mean of the distribution " +
@@ -1352,32 +1355,46 @@ class GeneralizedLinearRegression(JavaEstimator, HasLabelCol, HasFeaturesCol, Ha
"and sqrt.", typeConverter=TypeConverters.toString)
linkPredictionCol = Param(Params._dummy(), "linkPredictionCol", "link prediction (linear " +
"predictor) column name", typeConverter=TypeConverters.toString)
+ variancePower = Param(Params._dummy(), "variancePower", "The power in the variance function " +
+ "of the Tweedie distribution which characterizes the relationship " +
+ "between the variance and mean of the distribution. Only applicable " +
+ "for the Tweedie family. Supported values: 0 and [1, Inf).",
+ typeConverter=TypeConverters.toFloat)
+ linkPower = Param(Params._dummy(), "linkPower", "The index in the power link function. " +
+ "Only applicable to the Tweedie family.",
+ typeConverter=TypeConverters.toFloat)
@keyword_only
def __init__(self, labelCol="label", featuresCol="features", predictionCol="prediction",
family="gaussian", link=None, fitIntercept=True, maxIter=25, tol=1e-6,
- regParam=0.0, weightCol=None, solver="irls", linkPredictionCol=None):
+ regParam=0.0, weightCol=None, solver="irls", linkPredictionCol=None,
+ variancePower=0.0, linkPower=None):
"""
__init__(self, labelCol="label", featuresCol="features", predictionCol="prediction", \
family="gaussian", link=None, fitIntercept=True, maxIter=25, tol=1e-6, \
- regParam=0.0, weightCol=None, solver="irls", linkPredictionCol=None)
+ regParam=0.0, weightCol=None, solver="irls", linkPredictionCol=None, \
+ variancePower=0.0, linkPower=None)
"""
super(GeneralizedLinearRegression, self).__init__()
self._java_obj = self._new_java_obj(
"org.apache.spark.ml.regression.GeneralizedLinearRegression", self.uid)
- self._setDefault(family="gaussian", maxIter=25, tol=1e-6, regParam=0.0, solver="irls")
+ self._setDefault(family="gaussian", maxIter=25, tol=1e-6, regParam=0.0, solver="irls",
+ variancePower=0.0)
kwargs = self._input_kwargs
+
self.setParams(**kwargs)
@keyword_only
@since("2.0.0")
def setParams(self, labelCol="label", featuresCol="features", predictionCol="prediction",
family="gaussian", link=None, fitIntercept=True, maxIter=25, tol=1e-6,
- regParam=0.0, weightCol=None, solver="irls", linkPredictionCol=None):
+ regParam=0.0, weightCol=None, solver="irls", linkPredictionCol=None,
+ variancePower=0.0, linkPower=None):
"""
setParams(self, labelCol="label", featuresCol="features", predictionCol="prediction", \
family="gaussian", link=None, fitIntercept=True, maxIter=25, tol=1e-6, \
- regParam=0.0, weightCol=None, solver="irls", linkPredictionCol=None)
+ regParam=0.0, weightCol=None, solver="irls", linkPredictionCol=None, \
+ variancePower=0.0, linkPower=None)
Sets params for generalized linear regression.
"""
kwargs = self._input_kwargs
@@ -1428,6 +1445,34 @@ class GeneralizedLinearRegression(JavaEstimator, HasLabelCol, HasFeaturesCol, Ha
"""
return self.getOrDefault(self.link)
+ @since("2.2.0")
+ def setVariancePower(self, value):
+ """
+ Sets the value of :py:attr:`variancePower`.
+ """
+ return self._set(variancePower=value)
+
+ @since("2.2.0")
+ def getVariancePower(self):
+ """
+ Gets the value of variancePower or its default value.
+ """
+ return self.getOrDefault(self.variancePower)
+
+ @since("2.2.0")
+ def setLinkPower(self, value):
+ """
+ Sets the value of :py:attr:`linkPower`.
+ """
+ return self._set(linkPower=value)
+
+ @since("2.2.0")
+ def getLinkPower(self):
+ """
+ Gets the value of linkPower or its default value.
+ """
+ return self.getOrDefault(self.linkPower)
+
class GeneralizedLinearRegressionModel(JavaModel, JavaPredictionModel, JavaMLWritable,
JavaMLReadable):
diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py
index 3524160557..f052f5bb77 100755
--- a/python/pyspark/ml/tests.py
+++ b/python/pyspark/ml/tests.py
@@ -1223,6 +1223,26 @@ class HashingTFTest(SparkSessionTestCase):
": expected " + str(expected[i]) + ", got " + str(features[i]))
+class GeneralizedLinearRegressionTest(SparkSessionTestCase):
+
+ def test_tweedie_distribution(self):
+
+ df = self.spark.createDataFrame(
+ [(1.0, Vectors.dense(0.0, 0.0)),
+ (1.0, Vectors.dense(1.0, 2.0)),
+ (2.0, Vectors.dense(0.0, 0.0)),
+ (2.0, Vectors.dense(1.0, 1.0)), ], ["label", "features"])
+
+ glr = GeneralizedLinearRegression(family="tweedie", variancePower=1.6)
+ model = glr.fit(df)
+ self.assertTrue(np.allclose(model.coefficients.toArray(), [-0.4645, 0.3402], atol=1E-4))
+ self.assertTrue(np.isclose(model.intercept, 0.7841, atol=1E-4))
+
+ model2 = glr.setLinkPower(-1.0).fit(df)
+ self.assertTrue(np.allclose(model2.coefficients.toArray(), [-0.6667, 0.5], atol=1E-4))
+ self.assertTrue(np.isclose(model2.intercept, 0.6667, atol=1E-4))
+
+
class ALSTest(SparkSessionTestCase):
def test_storage_levels(self):