diff options
author | Joseph K. Bradley <joseph@databricks.com> | 2016-03-27 19:04:18 -0700 |
---|---|---|
committer | Joseph K. Bradley <joseph@databricks.com> | 2016-03-27 19:04:18 -0700 |
commit | 8ef493760f58687df766d03ccf64039635a2609f (patch) | |
tree | 159f15bff4caa365bf74cef0a492d2f7de2e89ac /mllib/src | |
parent | 0f02a5c6e63a95f910e6aba572729ca8085ac3ab (diff) | |
download | spark-8ef493760f58687df766d03ccf64039635a2609f.tar.gz spark-8ef493760f58687df766d03ccf64039635a2609f.tar.bz2 spark-8ef493760f58687df766d03ccf64039635a2609f.zip |
[SPARK-10691][ML] Make LogisticRegressionModel, LinearRegressionModel evaluate() public
## What changes were proposed in this pull request?
Made evaluate method public. Fixed LogisticRegressionModel evaluate to handle case when probabilityCol is not specified.
## How was this patch tested?
There were already unit tests for these methods.
Author: Joseph K. Bradley <joseph@databricks.com>
Closes #11928 from jkbradley/public-evaluate.
Diffstat (limited to 'mllib/src')
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala | 12 | ||||
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala | 8 |
2 files changed, 11 insertions, 9 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index 861b1d4b66..3d1d5b6892 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -539,13 +539,15 @@ class LogisticRegressionModel private[spark] ( def hasSummary: Boolean = trainingSummary.isDefined /** - * Evaluates the model on a testset. + * Evaluates the model on a test dataset. * @param dataset Test dataset to evaluate model on. */ - // TODO: decide on a good name before exposing to public API - private[classification] def evaluate(dataset: DataFrame): LogisticRegressionSummary = { - new BinaryLogisticRegressionSummary( - this.transform(dataset), $(probabilityCol), $(labelCol), $(featuresCol)) + @Since("2.0.0") + def evaluate(dataset: DataFrame): LogisticRegressionSummary = { + // Handle possible missing or invalid prediction columns + val (summaryModel, probabilityColName) = findSummaryModelAndProbabilityCol() + new BinaryLogisticRegressionSummary(summaryModel.transform(dataset), + probabilityColName, $(labelCol), $(featuresCol)) } /** diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index b81c588e44..5ec02135cc 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -412,15 +412,15 @@ class LinearRegressionModel private[ml] ( def hasSummary: Boolean = trainingSummary.isDefined /** - * Evaluates the model on a testset. + * Evaluates the model on a test dataset. * @param dataset Test dataset to evaluate model on. */ - // TODO: decide on a good name before exposing to public API - private[regression] def evaluate(dataset: DataFrame): LinearRegressionSummary = { + @Since("2.0.0") + def evaluate(dataset: DataFrame): LinearRegressionSummary = { // Handle possible missing or invalid prediction columns val (summaryModel, predictionColName) = findSummaryModelAndPredictionCol() new LinearRegressionSummary(summaryModel.transform(dataset), predictionColName, - $(labelCol), this, Array(0D)) + $(labelCol), summaryModel, Array(0D)) } /** |