diff options
author | sethah <seth.hendrickson16@gmail.com> | 2016-05-13 09:01:20 +0200 |
---|---|---|
committer | Nick Pentreath <nick.pentreath@gmail.com> | 2016-05-13 09:01:20 +0200 |
commit | 5b849766ab080c91864ed06ebbfd82ad978d5e4c (patch) | |
tree | 7ce287278bbeb2e0771300784aba26cb98d13aa4 /python/pyspark/ml/regression.py | |
parent | 87d69a01f027aa18718827f94f921b4a1eaa78a5 (diff) | |
download | spark-5b849766ab080c91864ed06ebbfd82ad978d5e4c.tar.gz spark-5b849766ab080c91864ed06ebbfd82ad978d5e4c.tar.bz2 spark-5b849766ab080c91864ed06ebbfd82ad978d5e4c.zip |
[SPARK-15181][ML][PYSPARK] Python API for GLR summaries.
## What changes were proposed in this pull request?
This patch adds a python API for generalized linear regression summaries (training and test). This helps provide feature parity for Python GLMs.
## How was this patch tested?
Added a unit test to `pyspark.ml.tests`
Author: sethah <seth.hendrickson16@gmail.com>
Closes #12961 from sethah/GLR_summary.
Diffstat (limited to 'python/pyspark/ml/regression.py')
-rw-r--r-- | python/pyspark/ml/regression.py | 201 |
1 files changed, 200 insertions, 1 deletions
diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py index 0d0eb8ae46..fcdc29e69b 100644 --- a/python/pyspark/ml/regression.py +++ b/python/pyspark/ml/regression.py @@ -29,6 +29,7 @@ __all__ = ['AFTSurvivalRegression', 'AFTSurvivalRegressionModel', 'DecisionTreeRegressor', 'DecisionTreeRegressionModel', 'GBTRegressor', 'GBTRegressionModel', 'GeneralizedLinearRegression', 'GeneralizedLinearRegressionModel', + 'GeneralizedLinearRegressionSummary', 'GeneralizedLinearRegressionTrainingSummary', 'IsotonicRegression', 'IsotonicRegressionModel', 'LinearRegression', 'LinearRegressionModel', 'LinearRegressionSummary', 'LinearRegressionTrainingSummary', @@ -1283,7 +1284,7 @@ class GeneralizedLinearRegression(JavaEstimator, HasLabelCol, HasFeaturesCol, Ha family = Param(Params._dummy(), "family", "The name of family which is a description of " + "the error distribution to be used in the model. Supported options: " + - "gaussian(default), binomial, poisson and gamma.", + "gaussian (default), binomial, poisson and gamma.", typeConverter=TypeConverters.toString) link = Param(Params._dummy(), "link", "The name of link function which provides the " + "relationship between the linear predictor and the mean of the distribution " + @@ -1377,6 +1378,204 @@ class GeneralizedLinearRegressionModel(JavaModel, JavaMLWritable, JavaMLReadable """ return self._call_java("intercept") + @property + @since("2.0.0") + def summary(self): + """ + Gets summary (e.g. residuals, deviance, pValues) of model on + training set. An exception is thrown if + `trainingSummary is None`. + """ + java_glrt_summary = self._call_java("summary") + return GeneralizedLinearRegressionTrainingSummary(java_glrt_summary) + + @property + @since("2.0.0") + def hasSummary(self): + """ + Indicates whether a training summary exists for this model + instance. + """ + return self._call_java("hasSummary") + + @since("2.0.0") + def evaluate(self, dataset): + """ + Evaluates the model on a test dataset. + + :param dataset: + Test dataset to evaluate model on, where dataset is an + instance of :py:class:`pyspark.sql.DataFrame` + """ + if not isinstance(dataset, DataFrame): + raise ValueError("dataset must be a DataFrame but got %s." % type(dataset)) + java_glr_summary = self._call_java("evaluate", dataset) + return GeneralizedLinearRegressionSummary(java_glr_summary) + + +class GeneralizedLinearRegressionSummary(JavaWrapper): + """ + .. note:: Experimental + + Generalized linear regression results evaluated on a dataset. + + .. versionadded:: 2.0.0 + """ + + @property + @since("2.0.0") + def predictions(self): + """ + Predictions output by the model's `transform` method. + """ + return self._call_java("predictions") + + @property + @since("2.0.0") + def predictionCol(self): + """ + Field in :py:attr:`predictions` which gives the predicted value of each instance. + This is set to a new column name if the original model's `predictionCol` is not set. + """ + return self._call_java("predictionCol") + + @property + @since("2.0.0") + def rank(self): + """ + The numeric rank of the fitted linear model. + """ + return self._call_java("rank") + + @property + @since("2.0.0") + def degreesOfFreedom(self): + """ + Degrees of freedom. + """ + return self._call_java("degreesOfFreedom") + + @property + @since("2.0.0") + def residualDegreeOfFreedom(self): + """ + The residual degrees of freedom. + """ + return self._call_java("residualDegreeOfFreedom") + + @property + @since("2.0.0") + def residualDegreeOfFreedomNull(self): + """ + The residual degrees of freedom for the null model. + """ + return self._call_java("residualDegreeOfFreedomNull") + + @since("2.0.0") + def residuals(self, residualsType="deviance"): + """ + Get the residuals of the fitted model by type. + + :param residualsType: The type of residuals which should be returned. + Supported options: deviance (default), pearson, working, and response. + """ + return self._call_java("residuals", residualsType) + + @property + @since("2.0.0") + def nullDeviance(self): + """ + The deviance for the null model. + """ + return self._call_java("nullDeviance") + + @property + @since("2.0.0") + def deviance(self): + """ + The deviance for the fitted model. + """ + return self._call_java("deviance") + + @property + @since("2.0.0") + def dispersion(self): + """ + The dispersion of the fitted model. + It is taken as 1.0 for the "binomial" and "poisson" families, and otherwise + estimated by the residual Pearson's Chi-Squared statistic (which is defined as + sum of the squares of the Pearson residuals) divided by the residual degrees of freedom. + """ + return self._call_java("dispersion") + + @property + @since("2.0.0") + def aic(self): + """ + Akaike's "An Information Criterion"(AIC) for the fitted model. + """ + return self._call_java("aic") + + +@inherit_doc +class GeneralizedLinearRegressionTrainingSummary(GeneralizedLinearRegressionSummary): + """ + .. note:: Experimental + + Generalized linear regression training results. + + .. versionadded:: 2.0.0 + """ + + @property + @since("2.0.0") + def numIterations(self): + """ + Number of training iterations. + """ + return self._call_java("numIterations") + + @property + @since("2.0.0") + def solver(self): + """ + The numeric solver used for training. + """ + return self._call_java("solver") + + @property + @since("2.0.0") + def coefficientStandardErrors(self): + """ + Standard error of estimated coefficients and intercept. + + If :py:attr:`GeneralizedLinearRegression.fitIntercept` is set to True, + then the last element returned corresponds to the intercept. + """ + return self._call_java("coefficientStandardErrors") + + @property + @since("2.0.0") + def tValues(self): + """ + T-statistic of estimated coefficients and intercept. + + If :py:attr:`GeneralizedLinearRegression.fitIntercept` is set to True, + then the last element returned corresponds to the intercept. + """ + return self._call_java("tValues") + + @property + @since("2.0.0") + def pValues(self): + """ + Two-sided p-value of estimated coefficients and intercept. + + If :py:attr:`GeneralizedLinearRegression.fitIntercept` is set to True, + then the last element returned corresponds to the intercept. + """ + return self._call_java("pValues") + if __name__ == "__main__": import doctest |