aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark/ml/regression.py
diff options
context:
space:
mode:
Diffstat (limited to 'python/pyspark/ml/regression.py')
-rw-r--r--python/pyspark/ml/regression.py245
1 files changed, 243 insertions, 2 deletions
diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py
index de8a5e4bed..6cd1b4bf3a 100644
--- a/python/pyspark/ml/regression.py
+++ b/python/pyspark/ml/regression.py
@@ -20,8 +20,9 @@ import warnings
from pyspark import since
from pyspark.ml.param.shared import *
from pyspark.ml.util import *
-from pyspark.ml.wrapper import JavaEstimator, JavaModel
+from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaCallable
from pyspark.mllib.common import inherit_doc
+from pyspark.sql import DataFrame
__all__ = ['AFTSurvivalRegression', 'AFTSurvivalRegressionModel',
@@ -29,6 +30,7 @@ __all__ = ['AFTSurvivalRegression', 'AFTSurvivalRegressionModel',
'GBTRegressor', 'GBTRegressionModel',
'IsotonicRegression', 'IsotonicRegressionModel',
'LinearRegression', 'LinearRegressionModel',
+ 'LinearRegressionSummary', 'LinearRegressionTrainingSummary',
'RandomForestRegressor', 'RandomForestRegressionModel']
@@ -131,7 +133,6 @@ class LinearRegressionModel(JavaModel, JavaMLWritable, JavaMLReadable):
"""
Model weights.
"""
-
warnings.warn("weights is deprecated. Use coefficients instead.")
return self._call_java("weights")
@@ -151,6 +152,246 @@ class LinearRegressionModel(JavaModel, JavaMLWritable, JavaMLReadable):
"""
return self._call_java("intercept")
+ @property
+ @since("2.0.0")
+ def summary(self):
+ """
+ Gets summary (e.g. residuals, mse, r-squared) of model on
+ training set. An exception is thrown if
+ `trainingSummary is None`.
+ """
+ java_lrt_summary = self._call_java("summary")
+ return LinearRegressionTrainingSummary(java_lrt_summary)
+
+ @property
+ @since("2.0.0")
+ def hasSummary(self):
+ """
+ Indicates whether a training summary exists for this model
+ instance.
+ """
+ return self._call_java("hasSummary")
+
+ @since("2.0.0")
+ def evaluate(self, dataset):
+ """
+ Evaluates the model on a test dataset.
+
+ :param dataset:
+ Test dataset to evaluate model on, where dataset is an
+ instance of :py:class:`pyspark.sql.DataFrame`
+ """
+ if not isinstance(dataset, DataFrame):
+ raise ValueError("dataset must be a DataFrame but got %s." % type(dataset))
+ java_lr_summary = self._call_java("evaluate", dataset)
+ return LinearRegressionSummary(java_lr_summary)
+
+
+class LinearRegressionSummary(JavaCallable):
+ """
+ .. note:: Experimental
+
+ Linear regression results evaluated on a dataset.
+
+ .. versionadded:: 2.0.0
+ """
+
+ @property
+ @since("2.0.0")
+ def predictions(self):
+ """
+ DataFrame output by the model's `transform` method.
+ """
+ return self._call_java("predictions")
+
+ @property
+ @since("2.0.0")
+ def predictionCol(self):
+ """
+ Field in "predictions" which gives the predicted value of
+ the label at each instance.
+ """
+ return self._call_java("predictionCol")
+
+ @property
+ @since("2.0.0")
+ def labelCol(self):
+ """
+ Field in "predictions" which gives the true label of each
+ instance.
+ """
+ return self._call_java("labelCol")
+
+ @property
+ @since("2.0.0")
+ def featuresCol(self):
+ """
+ Field in "predictions" which gives the features of each instance
+ as a vector.
+ """
+ return self._call_java("featuresCol")
+
+ @property
+ @since("2.0.0")
+ def explainedVariance(self):
+ """
+ Returns the explained variance regression score.
+ explainedVariance = 1 - variance(y - \\hat{y}) / variance(y)
+ Reference: http://en.wikipedia.org/wiki/Explained_variation
+
+ Note: This ignores instance weights (setting all to 1.0) from
+ `LinearRegression.weightCol`. This will change in later Spark
+ versions.
+ """
+ return self._call_java("explainedVariance")
+
+ @property
+ @since("2.0.0")
+ def meanAbsoluteError(self):
+ """
+ Returns the mean absolute error, which is a risk function
+ corresponding to the expected value of the absolute error
+ loss or l1-norm loss.
+
+ Note: This ignores instance weights (setting all to 1.0) from
+ `LinearRegression.weightCol`. This will change in later Spark
+ versions.
+ """
+ return self._call_java("meanAbsoluteError")
+
+ @property
+ @since("2.0.0")
+ def meanSquaredError(self):
+ """
+ Returns the mean squared error, which is a risk function
+ corresponding to the expected value of the squared error
+ loss or quadratic loss.
+
+ Note: This ignores instance weights (setting all to 1.0) from
+ `LinearRegression.weightCol`. This will change in later Spark
+ versions.
+ """
+ return self._call_java("meanSquaredError")
+
+ @property
+ @since("2.0.0")
+ def rootMeanSquaredError(self):
+ """
+ Returns the root mean squared error, which is defined as the
+ square root of the mean squared error.
+
+ Note: This ignores instance weights (setting all to 1.0) from
+ `LinearRegression.weightCol`. This will change in later Spark
+ versions.
+ """
+ return self._call_java("rootMeanSquaredError")
+
+ @property
+ @since("2.0.0")
+ def r2(self):
+ """
+ Returns R^2, the coefficient of determination.
+ Reference: http://en.wikipedia.org/wiki/Coefficient_of_determination
+
+ Note: This ignores instance weights (setting all to 1.0) from
+ `LinearRegression.weightCol`. This will change in later Spark
+ versions.
+ """
+ return self._call_java("r2")
+
+ @property
+ @since("2.0.0")
+ def residuals(self):
+ """
+ Residuals (label - predicted value)
+ """
+ return self._call_java("residuals")
+
+ @property
+ @since("2.0.0")
+ def numInstances(self):
+ """
+ Number of instances in the DataFrame "predictions".
+ """
+ return self._call_java("numInstances")
+
+ @property
+ @since("2.0.0")
+ def devianceResiduals(self):
+ """
+ The weighted residuals, the usual residuals rescaled by the
+ square root of the instance weights.
+ """
+ return self._call_java("devianceResiduals")
+
+ @property
+ @since("2.0.0")
+ def coefficientStandardErrors(self):
+ """
+ Standard error of estimated coefficients and intercept.
+ This value is only available when using the "normal" solver.
+
+ .. seealso:: :py:attr:`LinearRegression.solver`
+ """
+ return self._call_java("coefficientStandardErrors")
+
+ @property
+ @since("2.0.0")
+ def tValues(self):
+ """
+ T-statistic of estimated coefficients and intercept.
+ This value is only available when using the "normal" solver.
+
+ .. seealso:: :py:attr:`LinearRegression.solver`
+ """
+ return self._call_java("tValues")
+
+ @property
+ @since("2.0.0")
+ def pValues(self):
+ """
+ Two-sided p-value of estimated coefficients and intercept.
+ This value is only available when using the "normal" solver.
+
+ .. seealso:: :py:attr:`LinearRegression.solver`
+ """
+ return self._call_java("pValues")
+
+
+@inherit_doc
+class LinearRegressionTrainingSummary(LinearRegressionSummary):
+ """
+ .. note:: Experimental
+
+ Linear regression training results. Currently, the training summary ignores the
+ training weights except for the objective trace.
+
+ .. versionadded:: 2.0.0
+ """
+
+ @property
+ @since("2.0.0")
+ def objectiveHistory(self):
+ """
+ Objective function (scaled loss + regularization) at each
+ iteration.
+ This value is only available when using the "l-bfgs" solver.
+
+ .. seealso:: :py:attr:`LinearRegression.solver`
+ """
+ return self._call_java("objectiveHistory")
+
+ @property
+ @since("2.0.0")
+ def totalIterations(self):
+ """
+ Number of training iterations until termination.
+ This value is only available when using the "l-bfgs" solver.
+
+ .. seealso:: :py:attr:`LinearRegression.solver`
+ """
+ return self._call_java("totalIterations")
+
@inherit_doc
class IsotonicRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol,