aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark/ml/classification.py
diff options
context:
space:
mode:
authorBryan Cutler <cutlerb@gmail.com>2016-04-06 12:07:47 -0700
committerJoseph K. Bradley <joseph@databricks.com>2016-04-06 12:07:47 -0700
commit9c6556c5f8ab013b36312db4bf02c4c6d965a535 (patch)
treee4200c088c376f26f27de4f3a96c99006dd99b20 /python/pyspark/ml/classification.py
parentbb1fa5b2182f384cb711fc2be45b0f1a8c466ed6 (diff)
downloadspark-9c6556c5f8ab013b36312db4bf02c4c6d965a535.tar.gz
spark-9c6556c5f8ab013b36312db4bf02c4c6d965a535.tar.bz2
spark-9c6556c5f8ab013b36312db4bf02c4c6d965a535.zip
[SPARK-13430][PYSPARK][ML] Python API for training summaries of linear and logistic regression
## What changes were proposed in this pull request? Adding Python API for training summaries of LogisticRegression and LinearRegression in PySpark ML. ## How was this patch tested? Added unit tests to exercise the api calls for the summary classes. Also, manually verified values are expected and match those from Scala directly. Author: Bryan Cutler <cutlerb@gmail.com> Closes #11621 from BryanCutler/pyspark-ml-summary-SPARK-13430.
Diffstat (limited to 'python/pyspark/ml/classification.py')
-rw-r--r--python/pyspark/ml/classification.py218
1 files changed, 217 insertions, 1 deletions
diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py
index 067009559b..be7f9ea9ef 100644
--- a/python/pyspark/ml/classification.py
+++ b/python/pyspark/ml/classification.py
@@ -19,15 +19,18 @@ import warnings
from pyspark import since
from pyspark.ml.util import *
-from pyspark.ml.wrapper import JavaEstimator, JavaModel
+from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaCallable
from pyspark.ml.param import TypeConverters
from pyspark.ml.param.shared import *
from pyspark.ml.regression import (
RandomForestParams, TreeEnsembleParams, DecisionTreeModel, TreeEnsembleModels)
from pyspark.mllib.common import inherit_doc
+from pyspark.sql import DataFrame
__all__ = ['LogisticRegression', 'LogisticRegressionModel',
+ 'LogisticRegressionSummary', 'LogisticRegressionTrainingSummary',
+ 'BinaryLogisticRegressionSummary', 'BinaryLogisticRegressionTrainingSummary',
'DecisionTreeClassifier', 'DecisionTreeClassificationModel',
'GBTClassifier', 'GBTClassificationModel',
'RandomForestClassifier', 'RandomForestClassificationModel',
@@ -233,6 +236,219 @@ class LogisticRegressionModel(JavaModel, JavaMLWritable, JavaMLReadable):
"""
return self._call_java("intercept")
+ @property
+ @since("2.0.0")
+ def summary(self):
+ """
+        Gets summary (e.g. accuracy/precision/recall, objective history,
+        total iterations) of the model trained on the training set. An
+        exception is thrown if `trainingSummary is None`.
+ """
+ java_blrt_summary = self._call_java("summary")
+ # Note: Once multiclass is added, update this to return correct summary
+ return BinaryLogisticRegressionTrainingSummary(java_blrt_summary)
+
+ @property
+ @since("2.0.0")
+ def hasSummary(self):
+ """
+ Indicates whether a training summary exists for this model
+ instance.
+ """
+ return self._call_java("hasSummary")
+
+ @since("2.0.0")
+ def evaluate(self, dataset):
+ """
+ Evaluates the model on a test dataset.
+
+ :param dataset:
+ Test dataset to evaluate model on, where dataset is an
+ instance of :py:class:`pyspark.sql.DataFrame`
+ """
+ if not isinstance(dataset, DataFrame):
+ raise ValueError("dataset must be a DataFrame but got %s." % type(dataset))
+ java_blr_summary = self._call_java("evaluate", dataset)
+ return BinaryLogisticRegressionSummary(java_blr_summary)
+
+
+class LogisticRegressionSummary(JavaCallable):
+ """
+ Abstraction for Logistic Regression Results for a given model.
+
+ .. versionadded:: 2.0.0
+ """
+
+ @property
+ @since("2.0.0")
+ def predictions(self):
+ """
+        DataFrame output by the model's `transform` method.
+ """
+ return self._call_java("predictions")
+
+ @property
+ @since("2.0.0")
+ def probabilityCol(self):
+ """
+ Field in "predictions" which gives the calibrated probability
+ of each class as a vector.
+ """
+ return self._call_java("probabilityCol")
+
+ @property
+ @since("2.0.0")
+ def labelCol(self):
+ """
+ Field in "predictions" which gives the true label of each
+ instance.
+ """
+ return self._call_java("labelCol")
+
+ @property
+ @since("2.0.0")
+ def featuresCol(self):
+ """
+ Field in "predictions" which gives the features of each instance
+ as a vector.
+ """
+ return self._call_java("featuresCol")
+
+
+@inherit_doc
+class LogisticRegressionTrainingSummary(LogisticRegressionSummary):
+ """
+ Abstraction for multinomial Logistic Regression Training results.
+ Currently, the training summary ignores the training weights except
+ for the objective trace.
+
+ .. versionadded:: 2.0.0
+ """
+
+ @property
+ @since("2.0.0")
+ def objectiveHistory(self):
+ """
+ Objective function (scaled loss + regularization) at each
+ iteration.
+ """
+ return self._call_java("objectiveHistory")
+
+ @property
+ @since("2.0.0")
+ def totalIterations(self):
+ """
+ Number of training iterations until termination.
+ """
+ return self._call_java("totalIterations")
+
+
+@inherit_doc
+class BinaryLogisticRegressionSummary(LogisticRegressionSummary):
+ """
+ .. note:: Experimental
+
+ Binary Logistic regression results for a given model.
+
+ .. versionadded:: 2.0.0
+ """
+
+ @property
+ @since("2.0.0")
+ def roc(self):
+ """
+ Returns the receiver operating characteristic (ROC) curve,
+        which is a DataFrame having two fields (FPR, TPR) with
+ (0.0, 0.0) prepended and (1.0, 1.0) appended to it.
+ Reference: http://en.wikipedia.org/wiki/Receiver_operating_characteristic
+
+ Note: This ignores instance weights (setting all to 1.0) from
+ `LogisticRegression.weightCol`. This will change in later Spark
+ versions.
+ """
+ return self._call_java("roc")
+
+ @property
+ @since("2.0.0")
+ def areaUnderROC(self):
+ """
+ Computes the area under the receiver operating characteristic
+ (ROC) curve.
+
+ Note: This ignores instance weights (setting all to 1.0) from
+ `LogisticRegression.weightCol`. This will change in later Spark
+ versions.
+ """
+ return self._call_java("areaUnderROC")
+
+ @property
+ @since("2.0.0")
+ def pr(self):
+ """
+        Returns the precision-recall curve, which is a DataFrame
+        containing two fields (recall, precision) with (0.0, 1.0)
+        prepended to it.
+
+ Note: This ignores instance weights (setting all to 1.0) from
+ `LogisticRegression.weightCol`. This will change in later Spark
+ versions.
+ """
+ return self._call_java("pr")
+
+ @property
+ @since("2.0.0")
+ def fMeasureByThreshold(self):
+ """
+        Returns a DataFrame with two fields (threshold, F-Measure) for
+        the F-Measure curve with beta = 1.0.
+
+ Note: This ignores instance weights (setting all to 1.0) from
+ `LogisticRegression.weightCol`. This will change in later Spark
+ versions.
+ """
+ return self._call_java("fMeasureByThreshold")
+
+ @property
+ @since("2.0.0")
+ def precisionByThreshold(self):
+ """
+        Returns a DataFrame with two fields (threshold, precision) curve.
+        Every possible probability obtained in transforming the dataset
+        is used as a threshold in calculating the precision.
+
+ Note: This ignores instance weights (setting all to 1.0) from
+ `LogisticRegression.weightCol`. This will change in later Spark
+ versions.
+ """
+ return self._call_java("precisionByThreshold")
+
+ @property
+ @since("2.0.0")
+ def recallByThreshold(self):
+ """
+        Returns a DataFrame with two fields (threshold, recall) curve.
+        Every possible probability obtained in transforming the dataset
+        is used as a threshold in calculating the recall.
+
+ Note: This ignores instance weights (setting all to 1.0) from
+ `LogisticRegression.weightCol`. This will change in later Spark
+ versions.
+ """
+ return self._call_java("recallByThreshold")
+
+
+@inherit_doc
+class BinaryLogisticRegressionTrainingSummary(BinaryLogisticRegressionSummary,
+ LogisticRegressionTrainingSummary):
+ """
+ .. note:: Experimental
+
+ Binary Logistic regression training results for a given model.
+
+ .. versionadded:: 2.0.0
+ """
+ pass
+
class TreeClassifierParams(object):
"""