diff options
Diffstat (limited to 'python/pyspark/ml/regression.py')
-rw-r--r-- | python/pyspark/ml/regression.py | 245 |
1 files changed, 243 insertions, 2 deletions
diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py index de8a5e4bed..6cd1b4bf3a 100644 --- a/python/pyspark/ml/regression.py +++ b/python/pyspark/ml/regression.py @@ -20,8 +20,9 @@ import warnings from pyspark import since from pyspark.ml.param.shared import * from pyspark.ml.util import * -from pyspark.ml.wrapper import JavaEstimator, JavaModel +from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaCallable from pyspark.mllib.common import inherit_doc +from pyspark.sql import DataFrame __all__ = ['AFTSurvivalRegression', 'AFTSurvivalRegressionModel', @@ -29,6 +30,7 @@ __all__ = ['AFTSurvivalRegression', 'AFTSurvivalRegressionModel', 'GBTRegressor', 'GBTRegressionModel', 'IsotonicRegression', 'IsotonicRegressionModel', 'LinearRegression', 'LinearRegressionModel', + 'LinearRegressionSummary', 'LinearRegressionTrainingSummary', 'RandomForestRegressor', 'RandomForestRegressionModel'] @@ -131,7 +133,6 @@ class LinearRegressionModel(JavaModel, JavaMLWritable, JavaMLReadable): """ Model weights. """ - warnings.warn("weights is deprecated. Use coefficients instead.") return self._call_java("weights") @@ -151,6 +152,246 @@ class LinearRegressionModel(JavaModel, JavaMLWritable, JavaMLReadable): """ return self._call_java("intercept") + @property + @since("2.0.0") + def summary(self): + """ + Gets summary (e.g. residuals, mse, r-squared ) of model on + training set. An exception is thrown if + `trainingSummary is None`. + """ + java_lrt_summary = self._call_java("summary") + return LinearRegressionTrainingSummary(java_lrt_summary) + + @property + @since("2.0.0") + def hasSummary(self): + """ + Indicates whether a training summary exists for this model + instance. + """ + return self._call_java("hasSummary") + + @since("2.0.0") + def evaluate(self, dataset): + """ + Evaluates the model on a test dataset. + + :param dataset: + Test dataset to evaluate model on, where dataset is an + instance of :py:class:`pyspark.sql.DataFrame` + """ + if not isinstance(dataset, DataFrame): + raise ValueError("dataset must be a DataFrame but got %s." % type(dataset)) + java_lr_summary = self._call_java("evaluate", dataset) + return LinearRegressionSummary(java_lr_summary) + + +class LinearRegressionSummary(JavaCallable): + """ + .. note:: Experimental + + Linear regression results evaluated on a dataset. + + .. versionadded:: 2.0.0 + """ + + @property + @since("2.0.0") + def predictions(self): + """ + Dataframe outputted by the model's `transform` method. + """ + return self._call_java("predictions") + + @property + @since("2.0.0") + def predictionCol(self): + """ + Field in "predictions" which gives the predicted value of + the label at each instance. + """ + return self._call_java("predictionCol") + + @property + @since("2.0.0") + def labelCol(self): + """ + Field in "predictions" which gives the true label of each + instance. + """ + return self._call_java("labelCol") + + @property + @since("2.0.0") + def featuresCol(self): + """ + Field in "predictions" which gives the features of each instance + as a vector. + """ + return self._call_java("featuresCol") + + @property + @since("2.0.0") + def explainedVariance(self): + """ + Returns the explained variance regression score. + explainedVariance = 1 - variance(y - \hat{y}) / variance(y) + Reference: http://en.wikipedia.org/wiki/Explained_variation + + Note: This ignores instance weights (setting all to 1.0) from + `LinearRegression.weightCol`. This will change in later Spark + versions. + """ + return self._call_java("explainedVariance") + + @property + @since("2.0.0") + def meanAbsoluteError(self): + """ + Returns the mean absolute error, which is a risk function + corresponding to the expected value of the absolute error + loss or l1-norm loss. + + Note: This ignores instance weights (setting all to 1.0) from + `LinearRegression.weightCol`. This will change in later Spark + versions. + """ + return self._call_java("meanAbsoluteError") + + @property + @since("2.0.0") + def meanSquaredError(self): + """ + Returns the mean squared error, which is a risk function + corresponding to the expected value of the squared error + loss or quadratic loss. + + Note: This ignores instance weights (setting all to 1.0) from + `LinearRegression.weightCol`. This will change in later Spark + versions. + """ + return self._call_java("meanSquaredError") + + @property + @since("2.0.0") + def rootMeanSquaredError(self): + """ + Returns the root mean squared error, which is defined as the + square root of the mean squared error. + + Note: This ignores instance weights (setting all to 1.0) from + `LinearRegression.weightCol`. This will change in later Spark + versions. + """ + return self._call_java("rootMeanSquaredError") + + @property + @since("2.0.0") + def r2(self): + """ + Returns R^2^, the coefficient of determination. + Reference: http://en.wikipedia.org/wiki/Coefficient_of_determination + + Note: This ignores instance weights (setting all to 1.0) from + `LinearRegression.weightCol`. This will change in later Spark + versions. + """ + return self._call_java("r2") + + @property + @since("2.0.0") + def residuals(self): + """ + Residuals (label - predicted value) + """ + return self._call_java("residuals") + + @property + @since("2.0.0") + def numInstances(self): + """ + Number of instances in DataFrame predictions + """ + return self._call_java("numInstances") + + @property + @since("2.0.0") + def devianceResiduals(self): + """ + The weighted residuals, the usual residuals rescaled by the + square root of the instance weights. + """ + return self._call_java("devianceResiduals") + + @property + @since("2.0.0") + def coefficientStandardErrors(self): + """ + Standard error of estimated coefficients and intercept. + This value is only available when using the "normal" solver. + + .. seealso:: :py:attr:`LinearRegression.solver` + """ + return self._call_java("coefficientStandardErrors") + + @property + @since("2.0.0") + def tValues(self): + """ + T-statistic of estimated coefficients and intercept. + This value is only available when using the "normal" solver. + + .. seealso:: :py:attr:`LinearRegression.solver` + """ + return self._call_java("tValues") + + @property + @since("2.0.0") + def pValues(self): + """ + Two-sided p-value of estimated coefficients and intercept. + This value is only available when using the "normal" solver. + + .. seealso:: :py:attr:`LinearRegression.solver` + """ + return self._call_java("pValues") + + +@inherit_doc +class LinearRegressionTrainingSummary(LinearRegressionSummary): + """ + .. note:: Experimental + + Linear regression training results. Currently, the training summary ignores the + training weights except for the objective trace. + + .. versionadded:: 2.0.0 + """ + + @property + @since("2.0.0") + def objectiveHistory(self): + """ + Objective function (scaled loss + regularization) at each + iteration. + This value is only available when using the "l-bfgs" solver. + + .. seealso:: :py:attr:`LinearRegression.solver` + """ + return self._call_java("objectiveHistory") + + @property + @since("2.0.0") + def totalIterations(self): + """ + Number of training iterations until termination. + This value is only available when using the "l-bfgs" solver. + + .. seealso:: :py:attr:`LinearRegression.solver` + """ + return self._call_java("totalIterations") + @inherit_doc class IsotonicRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, |