aboutsummaryrefslogtreecommitdiff
path: root/python
diff options
context:
space:
mode:
authorsethah <seth.hendrickson16@gmail.com>2016-05-13 09:01:20 +0200
committerNick Pentreath <nick.pentreath@gmail.com>2016-05-13 09:01:20 +0200
commit5b849766ab080c91864ed06ebbfd82ad978d5e4c (patch)
tree7ce287278bbeb2e0771300784aba26cb98d13aa4 /python
parent87d69a01f027aa18718827f94f921b4a1eaa78a5 (diff)
downloadspark-5b849766ab080c91864ed06ebbfd82ad978d5e4c.tar.gz
spark-5b849766ab080c91864ed06ebbfd82ad978d5e4c.tar.bz2
spark-5b849766ab080c91864ed06ebbfd82ad978d5e4c.zip
[SPARK-15181][ML][PYSPARK] Python API for GLR summaries.
## What changes were proposed in this pull request? This patch adds a python API for generalized linear regression summaries (training and test). This helps provide feature parity for Python GLMs. ## How was this patch tested? Added a unit test to `pyspark.ml.tests` Author: sethah <seth.hendrickson16@gmail.com> Closes #12961 from sethah/GLR_summary.
Diffstat (limited to 'python')
-rw-r--r--python/pyspark/ml/regression.py201
-rwxr-xr-xpython/pyspark/ml/tests.py39
2 files changed, 238 insertions, 2 deletions
diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py
index 0d0eb8ae46..fcdc29e69b 100644
--- a/python/pyspark/ml/regression.py
+++ b/python/pyspark/ml/regression.py
@@ -29,6 +29,7 @@ __all__ = ['AFTSurvivalRegression', 'AFTSurvivalRegressionModel',
'DecisionTreeRegressor', 'DecisionTreeRegressionModel',
'GBTRegressor', 'GBTRegressionModel',
'GeneralizedLinearRegression', 'GeneralizedLinearRegressionModel',
+ 'GeneralizedLinearRegressionSummary', 'GeneralizedLinearRegressionTrainingSummary',
'IsotonicRegression', 'IsotonicRegressionModel',
'LinearRegression', 'LinearRegressionModel',
'LinearRegressionSummary', 'LinearRegressionTrainingSummary',
@@ -1283,7 +1284,7 @@ class GeneralizedLinearRegression(JavaEstimator, HasLabelCol, HasFeaturesCol, Ha
family = Param(Params._dummy(), "family", "The name of family which is a description of " +
"the error distribution to be used in the model. Supported options: " +
- "gaussian(default), binomial, poisson and gamma.",
+ "gaussian (default), binomial, poisson and gamma.",
typeConverter=TypeConverters.toString)
link = Param(Params._dummy(), "link", "The name of link function which provides the " +
"relationship between the linear predictor and the mean of the distribution " +
@@ -1377,6 +1378,204 @@ class GeneralizedLinearRegressionModel(JavaModel, JavaMLWritable, JavaMLReadable
"""
return self._call_java("intercept")
+ @property
+ @since("2.0.0")
+ def summary(self):
+ """
+ Gets summary (e.g. residuals, deviance, pValues) of model on
+ training set. An exception is thrown if
+ `trainingSummary is None`.
+ """
+ java_glrt_summary = self._call_java("summary")
+ return GeneralizedLinearRegressionTrainingSummary(java_glrt_summary)
+
+ @property
+ @since("2.0.0")
+ def hasSummary(self):
+ """
+ Indicates whether a training summary exists for this model
+ instance.
+ """
+ return self._call_java("hasSummary")
+
+ @since("2.0.0")
+ def evaluate(self, dataset):
+ """
+ Evaluates the model on a test dataset.
+
+ :param dataset:
+ Test dataset to evaluate model on, where dataset is an
+ instance of :py:class:`pyspark.sql.DataFrame`
+ """
+ if not isinstance(dataset, DataFrame):
+ raise ValueError("dataset must be a DataFrame but got %s." % type(dataset))
+ java_glr_summary = self._call_java("evaluate", dataset)
+ return GeneralizedLinearRegressionSummary(java_glr_summary)
+
+
+class GeneralizedLinearRegressionSummary(JavaWrapper):
+ """
+ .. note:: Experimental
+
+ Generalized linear regression results evaluated on a dataset.
+
+ .. versionadded:: 2.0.0
+ """
+
+ @property
+ @since("2.0.0")
+ def predictions(self):
+ """
+ Predictions output by the model's `transform` method.
+ """
+ return self._call_java("predictions")
+
+ @property
+ @since("2.0.0")
+ def predictionCol(self):
+ """
+ Field in :py:attr:`predictions` which gives the predicted value of each instance.
+ This is set to a new column name if the original model's `predictionCol` is not set.
+ """
+ return self._call_java("predictionCol")
+
+ @property
+ @since("2.0.0")
+ def rank(self):
+ """
+ The numeric rank of the fitted linear model.
+ """
+ return self._call_java("rank")
+
+ @property
+ @since("2.0.0")
+ def degreesOfFreedom(self):
+ """
+ Degrees of freedom.
+ """
+ return self._call_java("degreesOfFreedom")
+
+ @property
+ @since("2.0.0")
+ def residualDegreeOfFreedom(self):
+ """
+ The residual degrees of freedom.
+ """
+ return self._call_java("residualDegreeOfFreedom")
+
+ @property
+ @since("2.0.0")
+ def residualDegreeOfFreedomNull(self):
+ """
+ The residual degrees of freedom for the null model.
+ """
+ return self._call_java("residualDegreeOfFreedomNull")
+
+ @since("2.0.0")
+ def residuals(self, residualsType="deviance"):
+ """
+ Get the residuals of the fitted model by type.
+
+ :param residualsType: The type of residuals which should be returned.
+ Supported options: deviance (default), pearson, working, and response.
+ """
+ return self._call_java("residuals", residualsType)
+
+ @property
+ @since("2.0.0")
+ def nullDeviance(self):
+ """
+ The deviance for the null model.
+ """
+ return self._call_java("nullDeviance")
+
+ @property
+ @since("2.0.0")
+ def deviance(self):
+ """
+ The deviance for the fitted model.
+ """
+ return self._call_java("deviance")
+
+ @property
+ @since("2.0.0")
+ def dispersion(self):
+ """
+ The dispersion of the fitted model.
+ It is taken as 1.0 for the "binomial" and "poisson" families, and otherwise
+ estimated by the residual Pearson's Chi-Squared statistic (which is defined as
+ sum of the squares of the Pearson residuals) divided by the residual degrees of freedom.
+ """
+ return self._call_java("dispersion")
+
+ @property
+ @since("2.0.0")
+ def aic(self):
+ """
+ Akaike's "An Information Criterion"(AIC) for the fitted model.
+ """
+ return self._call_java("aic")
+
+
+@inherit_doc
+class GeneralizedLinearRegressionTrainingSummary(GeneralizedLinearRegressionSummary):
+ """
+ .. note:: Experimental
+
+ Generalized linear regression training results.
+
+ .. versionadded:: 2.0.0
+ """
+
+ @property
+ @since("2.0.0")
+ def numIterations(self):
+ """
+ Number of training iterations.
+ """
+ return self._call_java("numIterations")
+
+ @property
+ @since("2.0.0")
+ def solver(self):
+ """
+ The numeric solver used for training.
+ """
+ return self._call_java("solver")
+
+ @property
+ @since("2.0.0")
+ def coefficientStandardErrors(self):
+ """
+ Standard error of estimated coefficients and intercept.
+
+ If :py:attr:`GeneralizedLinearRegression.fitIntercept` is set to True,
+ then the last element returned corresponds to the intercept.
+ """
+ return self._call_java("coefficientStandardErrors")
+
+ @property
+ @since("2.0.0")
+ def tValues(self):
+ """
+ T-statistic of estimated coefficients and intercept.
+
+ If :py:attr:`GeneralizedLinearRegression.fitIntercept` is set to True,
+ then the last element returned corresponds to the intercept.
+ """
+ return self._call_java("tValues")
+
+ @property
+ @since("2.0.0")
+ def pValues(self):
+ """
+ Two-sided p-value of estimated coefficients and intercept.
+
+ If :py:attr:`GeneralizedLinearRegression.fitIntercept` is set to True,
+ then the last element returned corresponds to the intercept.
+ """
+ return self._call_java("pValues")
+
if __name__ == "__main__":
import doctest
diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py
index 49d3a4a332..8e56b0d6ff 100755
--- a/python/pyspark/ml/tests.py
+++ b/python/pyspark/ml/tests.py
@@ -52,7 +52,8 @@ from pyspark.ml.feature import *
from pyspark.ml.param import Param, Params, TypeConverters
from pyspark.ml.param.shared import HasMaxIter, HasInputCol, HasSeed
from pyspark.ml.recommendation import ALS
-from pyspark.ml.regression import LinearRegression, DecisionTreeRegressor
+from pyspark.ml.regression import LinearRegression, DecisionTreeRegressor, \
+ GeneralizedLinearRegression
from pyspark.ml.tuning import *
from pyspark.ml.wrapper import JavaParams
from pyspark.mllib.common import _java2py
@@ -909,6 +910,42 @@ class TrainingSummaryTest(SparkSessionTestCase):
sameSummary = model.evaluate(df)
self.assertAlmostEqual(sameSummary.explainedVariance, s.explainedVariance)
+ def test_glr_summary(self):
+ from pyspark.mllib.linalg import Vectors
+ df = self.spark.createDataFrame([(1.0, 2.0, Vectors.dense(1.0)),
+ (0.0, 2.0, Vectors.sparse(1, [], []))],
+ ["label", "weight", "features"])
+ glr = GeneralizedLinearRegression(family="gaussian", link="identity", weightCol="weight",
+ fitIntercept=False)
+ model = glr.fit(df)
+ self.assertTrue(model.hasSummary)
+ s = model.summary
+ # test that api is callable and returns expected types
+ self.assertEqual(s.numIterations, 1) # this should default to a single iteration of WLS
+ self.assertTrue(isinstance(s.predictions, DataFrame))
+ self.assertEqual(s.predictionCol, "prediction")
+ self.assertTrue(isinstance(s.residuals(), DataFrame))
+ self.assertTrue(isinstance(s.residuals("pearson"), DataFrame))
+ coefStdErr = s.coefficientStandardErrors
+ self.assertTrue(isinstance(coefStdErr, list) and isinstance(coefStdErr[0], float))
+ tValues = s.tValues
+ self.assertTrue(isinstance(tValues, list) and isinstance(tValues[0], float))
+ pValues = s.pValues
+ self.assertTrue(isinstance(pValues, list) and isinstance(pValues[0], float))
+ self.assertEqual(s.degreesOfFreedom, 1)
+ self.assertEqual(s.residualDegreeOfFreedom, 1)
+ self.assertEqual(s.residualDegreeOfFreedomNull, 2)
+ self.assertEqual(s.rank, 1)
+ self.assertTrue(isinstance(s.solver, basestring))
+ self.assertTrue(isinstance(s.aic, float))
+ self.assertTrue(isinstance(s.deviance, float))
+ self.assertTrue(isinstance(s.nullDeviance, float))
+ self.assertTrue(isinstance(s.dispersion, float))
+ # test evaluation (with training dataset) produces a summary with same values
+ # one check is enough to verify a summary is returned, Scala version runs full test
+ sameSummary = model.evaluate(df)
+ self.assertAlmostEqual(sameSummary.deviance, s.deviance)
+
def test_logistic_regression_summary(self):
from pyspark.mllib.linalg import Vectors
df = self.spark.createDataFrame([(1.0, 2.0, Vectors.dense(1.0)),