diff options
Diffstat (limited to 'python/pyspark/ml/classification.py')
-rw-r--r-- | python/pyspark/ml/classification.py | 218 |
1 files changed, 217 insertions, 1 deletions
diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index 067009559b..be7f9ea9ef 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -19,15 +19,18 @@ import warnings from pyspark import since from pyspark.ml.util import * -from pyspark.ml.wrapper import JavaEstimator, JavaModel +from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaCallable from pyspark.ml.param import TypeConverters from pyspark.ml.param.shared import * from pyspark.ml.regression import ( RandomForestParams, TreeEnsembleParams, DecisionTreeModel, TreeEnsembleModels) from pyspark.mllib.common import inherit_doc +from pyspark.sql import DataFrame __all__ = ['LogisticRegression', 'LogisticRegressionModel', + 'LogisticRegressionSummary', 'LogisticRegressionTrainingSummary', + 'BinaryLogisticRegressionSummary', 'BinaryLogisticRegressionTrainingSummary', 'DecisionTreeClassifier', 'DecisionTreeClassificationModel', 'GBTClassifier', 'GBTClassificationModel', 'RandomForestClassifier', 'RandomForestClassificationModel', @@ -233,6 +236,219 @@ class LogisticRegressionModel(JavaModel, JavaMLWritable, JavaMLReadable): """ return self._call_java("intercept") + @property + @since("2.0.0") + def summary(self): + """ + Gets summary (e.g. residuals, mse, r-squared ) of model on + training set. An exception is thrown if + `trainingSummary is None`. + """ + java_blrt_summary = self._call_java("summary") + # Note: Once multiclass is added, update this to return correct summary + return BinaryLogisticRegressionTrainingSummary(java_blrt_summary) + + @property + @since("2.0.0") + def hasSummary(self): + """ + Indicates whether a training summary exists for this model + instance. + """ + return self._call_java("hasSummary") + + @since("2.0.0") + def evaluate(self, dataset): + """ + Evaluates the model on a test dataset. + + :param dataset: + Test dataset to evaluate model on, where dataset is an + instance of :py:class:`pyspark.sql.DataFrame` + """ + if not isinstance(dataset, DataFrame): + raise ValueError("dataset must be a DataFrame but got %s." % type(dataset)) + java_blr_summary = self._call_java("evaluate", dataset) + return BinaryLogisticRegressionSummary(java_blr_summary) + + +class LogisticRegressionSummary(JavaCallable): + """ + Abstraction for Logistic Regression Results for a given model. + + .. versionadded:: 2.0.0 + """ + + @property + @since("2.0.0") + def predictions(self): + """ + Dataframe outputted by the model's `transform` method. + """ + return self._call_java("predictions") + + @property + @since("2.0.0") + def probabilityCol(self): + """ + Field in "predictions" which gives the calibrated probability + of each class as a vector. + """ + return self._call_java("probabilityCol") + + @property + @since("2.0.0") + def labelCol(self): + """ + Field in "predictions" which gives the true label of each + instance. + """ + return self._call_java("labelCol") + + @property + @since("2.0.0") + def featuresCol(self): + """ + Field in "predictions" which gives the features of each instance + as a vector. + """ + return self._call_java("featuresCol") + + +@inherit_doc +class LogisticRegressionTrainingSummary(LogisticRegressionSummary): + """ + Abstraction for multinomial Logistic Regression Training results. + Currently, the training summary ignores the training weights except + for the objective trace. + + .. versionadded:: 2.0.0 + """ + + @property + @since("2.0.0") + def objectiveHistory(self): + """ + Objective function (scaled loss + regularization) at each + iteration. + """ + return self._call_java("objectiveHistory") + + @property + @since("2.0.0") + def totalIterations(self): + """ + Number of training iterations until termination. + """ + return self._call_java("totalIterations") + + +@inherit_doc +class BinaryLogisticRegressionSummary(LogisticRegressionSummary): + """ + .. note:: Experimental + + Binary Logistic regression results for a given model. + + .. versionadded:: 2.0.0 + """ + + @property + @since("2.0.0") + def roc(self): + """ + Returns the receiver operating characteristic (ROC) curve, + which is an Dataframe having two fields (FPR, TPR) with + (0.0, 0.0) prepended and (1.0, 1.0) appended to it. + Reference: http://en.wikipedia.org/wiki/Receiver_operating_characteristic + + Note: This ignores instance weights (setting all to 1.0) from + `LogisticRegression.weightCol`. This will change in later Spark + versions. + """ + return self._call_java("roc") + + @property + @since("2.0.0") + def areaUnderROC(self): + """ + Computes the area under the receiver operating characteristic + (ROC) curve. + + Note: This ignores instance weights (setting all to 1.0) from + `LogisticRegression.weightCol`. This will change in later Spark + versions. + """ + return self._call_java("areaUnderROC") + + @property + @since("2.0.0") + def pr(self): + """ + Returns the precision-recall curve, which is an Dataframe + containing two fields recall, precision with (0.0, 1.0) prepended + to it. + + Note: This ignores instance weights (setting all to 1.0) from + `LogisticRegression.weightCol`. This will change in later Spark + versions. + """ + return self._call_java("pr") + + @property + @since("2.0.0") + def fMeasureByThreshold(self): + """ + Returns a dataframe with two fields (threshold, F-Measure) curve + with beta = 1.0. + + Note: This ignores instance weights (setting all to 1.0) from + `LogisticRegression.weightCol`. This will change in later Spark + versions. + """ + return self._call_java("fMeasureByThreshold") + + @property + @since("2.0.0") + def precisionByThreshold(self): + """ + Returns a dataframe with two fields (threshold, precision) curve. + Every possible probability obtained in transforming the dataset + are used as thresholds used in calculating the precision. + + Note: This ignores instance weights (setting all to 1.0) from + `LogisticRegression.weightCol`. This will change in later Spark + versions. + """ + return self._call_java("precisionByThreshold") + + @property + @since("2.0.0") + def recallByThreshold(self): + """ + Returns a dataframe with two fields (threshold, recall) curve. + Every possible probability obtained in transforming the dataset + are used as thresholds used in calculating the recall. + + Note: This ignores instance weights (setting all to 1.0) from + `LogisticRegression.weightCol`. This will change in later Spark + versions. + """ + return self._call_java("recallByThreshold") + + +@inherit_doc +class BinaryLogisticRegressionTrainingSummary(BinaryLogisticRegressionSummary, + LogisticRegressionTrainingSummary): + """ + .. note:: Experimental + + Binary Logistic regression training results for a given model. + + .. versionadded:: 2.0.0 + """ + pass + class TreeClassifierParams(object): """ |