From 8ef493760f58687df766d03ccf64039635a2609f Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Sun, 27 Mar 2016 19:04:18 -0700 Subject: [SPARK-10691][ML] Make LogisticRegressionModel, LinearRegressionModel evaluate() public ## What changes were proposed in this pull request? Made evaluate method public. Fixed LogisticRegressionModel evaluate to handle case when probabilityCol is not specified. ## How was this patch tested? There were already unit tests for these methods. Author: Joseph K. Bradley Closes #11928 from jkbradley/public-evaluate. --- .../apache/spark/ml/classification/LogisticRegression.scala | 12 +++++++----- .../org/apache/spark/ml/regression/LinearRegression.scala | 8 ++++---- 2 files changed, 11 insertions(+), 9 deletions(-) (limited to 'mllib') diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index 861b1d4b66..3d1d5b6892 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -539,13 +539,15 @@ class LogisticRegressionModel private[spark] ( def hasSummary: Boolean = trainingSummary.isDefined /** - * Evaluates the model on a testset. + * Evaluates the model on a test dataset. * @param dataset Test dataset to evaluate model on. */ - // TODO: decide on a good name before exposing to public API - private[classification] def evaluate(dataset: DataFrame): LogisticRegressionSummary = { - new BinaryLogisticRegressionSummary( - this.transform(dataset), $(probabilityCol), $(labelCol), $(featuresCol)) + @Since("2.0.0") + def evaluate(dataset: DataFrame): LogisticRegressionSummary = { + // Handle possible missing or invalid prediction columns + val (summaryModel, probabilityColName) = findSummaryModelAndProbabilityCol() + new BinaryLogisticRegressionSummary(summaryModel.transform(dataset), + probabilityColName, $(labelCol), $(featuresCol)) } /** diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index b81c588e44..5ec02135cc 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -412,15 +412,15 @@ class LinearRegressionModel private[ml] ( def hasSummary: Boolean = trainingSummary.isDefined /** - * Evaluates the model on a testset. + * Evaluates the model on a test dataset. * @param dataset Test dataset to evaluate model on. */ - // TODO: decide on a good name before exposing to public API - private[regression] def evaluate(dataset: DataFrame): LinearRegressionSummary = { + @Since("2.0.0") + def evaluate(dataset: DataFrame): LinearRegressionSummary = { // Handle possible missing or invalid prediction columns val (summaryModel, predictionColName) = findSummaryModelAndPredictionCol() new LinearRegressionSummary(summaryModel.transform(dataset), predictionColName, - $(labelCol), this, Array(0D)) + $(labelCol), summaryModel, Array(0D)) } /** -- cgit v1.2.3