diff options
author | Yanbo Liang <ybliang8@gmail.com> | 2017-03-08 02:09:36 -0800 |
---|---|---|
committer | Yanbo Liang <ybliang8@gmail.com> | 2017-03-08 02:09:36 -0800 |
commit | 81303f7ca7808d51229411dce8feeed8c23dbe15 (patch) | |
tree | e7f6ebae38cfbd2877c933bafa63b21d766df531 /python/pyspark | |
parent | 1fa58868bc6635ff2119264665bd3d00b4b1253a (diff) | |
download | spark-81303f7ca7808d51229411dce8feeed8c23dbe15.tar.gz spark-81303f7ca7808d51229411dce8feeed8c23dbe15.tar.bz2 spark-81303f7ca7808d51229411dce8feeed8c23dbe15.zip |
[SPARK-19806][ML][PYSPARK] PySpark GeneralizedLinearRegression supports tweedie distribution.
## What changes were proposed in this pull request?
PySpark ```GeneralizedLinearRegression``` supports tweedie distribution.
## How was this patch tested?
Add unit tests.
Author: Yanbo Liang <ybliang8@gmail.com>
Closes #17146 from yanboliang/spark-19806.
Diffstat (limited to 'python/pyspark')
-rw-r--r-- | python/pyspark/ml/regression.py | 61 | ||||
-rwxr-xr-x | python/pyspark/ml/tests.py | 20 |
2 files changed, 73 insertions, 8 deletions
diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py index b199bf282e..3c3fcc8d9b 100644 --- a/python/pyspark/ml/regression.py +++ b/python/pyspark/ml/regression.py @@ -1294,8 +1294,8 @@ class GeneralizedLinearRegression(JavaEstimator, HasLabelCol, HasFeaturesCol, Ha Fit a Generalized Linear Model specified by giving a symbolic description of the linear predictor (link function) and a description of the error distribution (family). It supports - "gaussian", "binomial", "poisson" and "gamma" as family. Valid link functions for each family - is listed below. The first link function of each family is the default one. + "gaussian", "binomial", "poisson", "gamma" and "tweedie" as family. Valid link functions for + each family is listed below. The first link function of each family is the default one. * "gaussian" -> "identity", "log", "inverse" @@ -1305,6 +1305,9 @@ class GeneralizedLinearRegression(JavaEstimator, HasLabelCol, HasFeaturesCol, Ha * "gamma" -> "inverse", "identity", "log" + * "tweedie" -> power link function specified through "linkPower". \ + The default link power in the tweedie family is 1 - variancePower. + .. seealso:: `GLM <https://en.wikipedia.org/wiki/Generalized_linear_model>`_ >>> from pyspark.ml.linalg import Vectors @@ -1344,7 +1347,7 @@ class GeneralizedLinearRegression(JavaEstimator, HasLabelCol, HasFeaturesCol, Ha family = Param(Params._dummy(), "family", "The name of family which is a description of " + "the error distribution to be used in the model. Supported options: " + - "gaussian (default), binomial, poisson and gamma.", + "gaussian (default), binomial, poisson, gamma and tweedie.", typeConverter=TypeConverters.toString) link = Param(Params._dummy(), "link", "The name of link function which provides the " + "relationship between the linear predictor and the mean of the distribution " + @@ -1352,32 +1355,46 @@ class GeneralizedLinearRegression(JavaEstimator, HasLabelCol, HasFeaturesCol, Ha "and sqrt.", typeConverter=TypeConverters.toString) linkPredictionCol = Param(Params._dummy(), "linkPredictionCol", "link prediction (linear " + "predictor) column name", typeConverter=TypeConverters.toString) + variancePower = Param(Params._dummy(), "variancePower", "The power in the variance function " + + "of the Tweedie distribution which characterizes the relationship " + + "between the variance and mean of the distribution. Only applicable " + + "for the Tweedie family. Supported values: 0 and [1, Inf).", + typeConverter=TypeConverters.toFloat) + linkPower = Param(Params._dummy(), "linkPower", "The index in the power link function. " + + "Only applicable to the Tweedie family.", + typeConverter=TypeConverters.toFloat) @keyword_only def __init__(self, labelCol="label", featuresCol="features", predictionCol="prediction", family="gaussian", link=None, fitIntercept=True, maxIter=25, tol=1e-6, - regParam=0.0, weightCol=None, solver="irls", linkPredictionCol=None): + regParam=0.0, weightCol=None, solver="irls", linkPredictionCol=None, + variancePower=0.0, linkPower=None): """ __init__(self, labelCol="label", featuresCol="features", predictionCol="prediction", \ family="gaussian", link=None, fitIntercept=True, maxIter=25, tol=1e-6, \ - regParam=0.0, weightCol=None, solver="irls", linkPredictionCol=None) + regParam=0.0, weightCol=None, solver="irls", linkPredictionCol=None, \ + variancePower=0.0, linkPower=None) """ super(GeneralizedLinearRegression, self).__init__() self._java_obj = self._new_java_obj( "org.apache.spark.ml.regression.GeneralizedLinearRegression", self.uid) - self._setDefault(family="gaussian", maxIter=25, tol=1e-6, regParam=0.0, solver="irls") + self._setDefault(family="gaussian", maxIter=25, tol=1e-6, regParam=0.0, solver="irls", + variancePower=0.0) kwargs = self._input_kwargs + self.setParams(**kwargs) @keyword_only @since("2.0.0") def setParams(self, labelCol="label", featuresCol="features", predictionCol="prediction", family="gaussian", link=None, fitIntercept=True, maxIter=25, tol=1e-6, - regParam=0.0, weightCol=None, solver="irls", linkPredictionCol=None): + regParam=0.0, weightCol=None, solver="irls", linkPredictionCol=None, + variancePower=0.0, linkPower=None): """ setParams(self, labelCol="label", featuresCol="features", predictionCol="prediction", \ family="gaussian", link=None, fitIntercept=True, maxIter=25, tol=1e-6, \ - regParam=0.0, weightCol=None, solver="irls", linkPredictionCol=None) + regParam=0.0, weightCol=None, solver="irls", linkPredictionCol=None, \ + variancePower=0.0, linkPower=None) Sets params for generalized linear regression. """ kwargs = self._input_kwargs @@ -1428,6 +1445,34 @@ class GeneralizedLinearRegression(JavaEstimator, HasLabelCol, HasFeaturesCol, Ha """ return self.getOrDefault(self.link) + @since("2.2.0") + def setVariancePower(self, value): + """ + Sets the value of :py:attr:`variancePower`. + """ + return self._set(variancePower=value) + + @since("2.2.0") + def getVariancePower(self): + """ + Gets the value of variancePower or its default value. + """ + return self.getOrDefault(self.variancePower) + + @since("2.2.0") + def setLinkPower(self, value): + """ + Sets the value of :py:attr:`linkPower`. + """ + return self._set(linkPower=value) + + @since("2.2.0") + def getLinkPower(self): + """ + Gets the value of linkPower or its default value. + """ + return self.getOrDefault(self.linkPower) + class GeneralizedLinearRegressionModel(JavaModel, JavaPredictionModel, JavaMLWritable, JavaMLReadable): diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index 3524160557..f052f5bb77 100755 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -1223,6 +1223,26 @@ class HashingTFTest(SparkSessionTestCase): ": expected " + str(expected[i]) + ", got " + str(features[i])) +class GeneralizedLinearRegressionTest(SparkSessionTestCase): + + def test_tweedie_distribution(self): + + df = self.spark.createDataFrame( + [(1.0, Vectors.dense(0.0, 0.0)), + (1.0, Vectors.dense(1.0, 2.0)), + (2.0, Vectors.dense(0.0, 0.0)), + (2.0, Vectors.dense(1.0, 1.0)), ], ["label", "features"]) + + glr = GeneralizedLinearRegression(family="tweedie", variancePower=1.6) + model = glr.fit(df) + self.assertTrue(np.allclose(model.coefficients.toArray(), [-0.4645, 0.3402], atol=1E-4)) + self.assertTrue(np.isclose(model.intercept, 0.7841, atol=1E-4)) + + model2 = glr.setLinkPower(-1.0).fit(df) + self.assertTrue(np.allclose(model2.coefficients.toArray(), [-0.6667, 0.5], atol=1E-4)) + self.assertTrue(np.isclose(model2.intercept, 0.6667, atol=1E-4)) + + class ALSTest(SparkSessionTestCase): def test_storage_levels(self): |