aboutsummaryrefslogtreecommitdiff
path: root/python
diff options
context:
space:
mode:
authorKai Jiang <jiangkai@gmail.com>2016-04-12 11:29:12 -0700
committerXiangrui Meng <meng@databricks.com>2016-04-12 11:29:12 -0700
commit7f024c47441a2f84fcc34a6021b976f036ea24c4 (patch)
tree72a90b101ba983517017726132244423867134e3 /python
parent101663f1ae222a919fc40510aa4f2bad22d1be6f (diff)
downloadspark-7f024c47441a2f84fcc34a6021b976f036ea24c4.tar.gz
spark-7f024c47441a2f84fcc34a6021b976f036ea24c4.tar.bz2
spark-7f024c47441a2f84fcc34a6021b976f036ea24c4.zip
[SPARK-13597][PYSPARK][ML] Python API for GeneralizedLinearRegression
## What changes were proposed in this pull request? Python API for GeneralizedLinearRegression JIRA: https://issues.apache.org/jira/browse/SPARK-13597 ## How was this patch tested? The patch is tested with Python doctest. Author: Kai Jiang <jiangkai@gmail.com> Closes #11468 from vectorijk/spark-13597.
Diffstat (limited to 'python')
-rw-r--r--python/pyspark/ml/regression.py145
1 files changed, 145 insertions, 0 deletions
diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py
index 1c18df3b27..bc88f88b7f 100644
--- a/python/pyspark/ml/regression.py
+++ b/python/pyspark/ml/regression.py
@@ -28,6 +28,7 @@ from pyspark.sql import DataFrame
__all__ = ['AFTSurvivalRegression', 'AFTSurvivalRegressionModel',
'DecisionTreeRegressor', 'DecisionTreeRegressionModel',
'GBTRegressor', 'GBTRegressionModel',
+ 'GeneralizedLinearRegression', 'GeneralizedLinearRegressionModel'
'IsotonicRegression', 'IsotonicRegressionModel',
'LinearRegression', 'LinearRegressionModel',
'LinearRegressionSummary', 'LinearRegressionTrainingSummary',
@@ -1197,6 +1198,150 @@ class AFTSurvivalRegressionModel(JavaModel, JavaMLWritable, JavaMLReadable):
return self._call_java("predict", features)
+@inherit_doc
+class GeneralizedLinearRegression(JavaEstimator, HasLabelCol, HasFeaturesCol, HasPredictionCol,
+ HasFitIntercept, HasMaxIter, HasTol, HasRegParam, HasWeightCol,
+ HasSolver, JavaMLWritable, JavaMLReadable):
+ """
+ Generalized Linear Regression.
+
+ Fit a Generalized Linear Model specified by giving a symbolic description of the linear
+ predictor (link function) and a description of the error distribution (family). It supports
+ "gaussian", "binomial", "poisson" and "gamma" as family. Valid link functions for each family
+ is listed below. The first link function of each family is the default one.
+ - "gaussian" -> "identity", "log", "inverse"
+ - "binomial" -> "logit", "probit", "cloglog"
+ - "poisson" -> "log", "identity", "sqrt"
+ - "gamma" -> "inverse", "identity", "log"
+
+ .. seealso:: `GLM <https://en.wikipedia.org/wiki/Generalized_linear_model>`_
+
+ >>> from pyspark.mllib.linalg import Vectors
+ >>> df = sqlContext.createDataFrame([
+ ... (1.0, Vectors.dense(0.0, 0.0)),
+ ... (1.0, Vectors.dense(1.0, 2.0)),
+ ... (2.0, Vectors.dense(0.0, 0.0)),
+ ... (2.0, Vectors.dense(1.0, 1.0)),], ["label", "features"])
+ >>> glr = GeneralizedLinearRegression(family="gaussian", link="identity")
+ >>> model = glr.fit(df)
+ >>> abs(model.transform(df).head().prediction - 1.5) < 0.001
+ True
+ >>> model.coefficients
+ DenseVector([1.5..., -1.0...])
+ >>> abs(model.intercept - 1.5) < 0.001
+ True
+ >>> glr_path = temp_path + "/glr"
+ >>> glr.save(glr_path)
+ >>> glr2 = GeneralizedLinearRegression.load(glr_path)
+ >>> glr.getFamily() == glr2.getFamily()
+ True
+ >>> model_path = temp_path + "/glr_model"
+ >>> model.save(model_path)
+ >>> model2 = GeneralizedLinearRegressionModel.load(model_path)
+ >>> model.intercept == model2.intercept
+ True
+ >>> model.coefficients[0] == model2.coefficients[0]
+ True
+
+ .. versionadded:: 2.0.0
+ """
+
+ family = Param(Params._dummy(), "family", "The name of family which is a description of " +
+ "the error distribution to be used in the model. Supported options: " +
+ "gaussian(default), binomial, poisson and gamma.")
+ link = Param(Params._dummy(), "link", "The name of link function which provides the " +
+ "relationship between the linear predictor and the mean of the distribution " +
+ "function. Supported options: identity, log, inverse, logit, probit, cloglog " +
+ "and sqrt.")
+
+ @keyword_only
+ def __init__(self, labelCol="label", featuresCol="features", predictionCol="prediction",
+ family="gaussian", link=None, fitIntercept=True, maxIter=25, tol=1e-6,
+ regParam=0.0, weightCol=None, solver="irls"):
+ """
+ __init__(self, labelCol="label", featuresCol="features", predictionCol="prediction", \
+ family="gaussian", link=None, fitIntercept=True, maxIter=25, tol=1e-6, \
+ regParam=0.0, weightCol=None, solver="irls")
+ """
+ super(GeneralizedLinearRegression, self).__init__()
+ self._java_obj = self._new_java_obj(
+ "org.apache.spark.ml.regression.GeneralizedLinearRegression", self.uid)
+ self._setDefault(family="gaussian", maxIter=25, tol=1e-6, regParam=0.0, solver="irls")
+ kwargs = self.__init__._input_kwargs
+ self.setParams(**kwargs)
+
+ @keyword_only
+ @since("2.0.0")
+ def setParams(self, labelCol="label", featuresCol="features", predictionCol="prediction",
+ family="gaussian", link=None, fitIntercept=True, maxIter=25, tol=1e-6,
+ regParam=0.0, weightCol=None, solver="irls"):
+ """
+ setParams(self, labelCol="label", featuresCol="features", predictionCol="prediction", \
+ family="gaussian", link=None, fitIntercept=True, maxIter=25, tol=1e-6, \
+ regParam=0.0, weightCol=None, solver="irls")
+ Sets params for generalized linear regression.
+ """
+ kwargs = self.setParams._input_kwargs
+ return self._set(**kwargs)
+
+ def _create_model(self, java_model):
+ return GeneralizedLinearRegressionModel(java_model)
+
+ @since("2.0.0")
+ def setFamily(self, value):
+ """
+ Sets the value of :py:attr:`family`.
+ """
+ self._paramMap[self.family] = value
+ return self
+
+ @since("2.0.0")
+ def getFamily(self):
+ """
+ Gets the value of family or its default value.
+ """
+ return self.getOrDefault(self.family)
+
+ @since("2.0.0")
+ def setLink(self, value):
+ """
+ Sets the value of :py:attr:`link`.
+ """
+ self._paramMap[self.link] = value
+ return self
+
+ @since("2.0.0")
+ def getLink(self):
+ """
+ Gets the value of link or its default value.
+ """
+ return self.getOrDefault(self.link)
+
+
+class GeneralizedLinearRegressionModel(JavaModel, JavaMLWritable, JavaMLReadable):
+ """
+ Model fitted by GeneralizedLinearRegression.
+
+ .. versionadded:: 2.0.0
+ """
+
+ @property
+ @since("2.0.0")
+ def coefficients(self):
+ """
+ Model coefficients.
+ """
+ return self._call_java("coefficients")
+
+ @property
+ @since("2.0.0")
+ def intercept(self):
+ """
+ Model intercept.
+ """
+ return self._call_java("intercept")
+
+
if __name__ == "__main__":
import doctest
import pyspark.ml.regression