aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
authorJoseph K. Bradley <joseph@databricks.com>2016-03-27 19:04:18 -0700
committerJoseph K. Bradley <joseph@databricks.com>2016-03-27 19:04:18 -0700
commit8ef493760f58687df766d03ccf64039635a2609f (patch)
tree159f15bff4caa365bf74cef0a492d2f7de2e89ac /mllib
parent0f02a5c6e63a95f910e6aba572729ca8085ac3ab (diff)
downloadspark-8ef493760f58687df766d03ccf64039635a2609f.tar.gz
spark-8ef493760f58687df766d03ccf64039635a2609f.tar.bz2
spark-8ef493760f58687df766d03ccf64039635a2609f.zip
[SPARK-10691][ML] Make LogisticRegressionModel, LinearRegressionModel evaluate() public
## What changes were proposed in this pull request? Made evaluate method public. Fixed LogisticRegressionModel evaluate to handle case when probabilityCol is not specified. ## How was this patch tested? There were already unit tests for these methods. Author: Joseph K. Bradley <joseph@databricks.com> Closes #11928 from jkbradley/public-evaluate.
Diffstat (limited to 'mllib')
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala12
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala8
2 files changed, 11 insertions, 9 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
index 861b1d4b66..3d1d5b6892 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
@@ -539,13 +539,15 @@ class LogisticRegressionModel private[spark] (
def hasSummary: Boolean = trainingSummary.isDefined
/**
- * Evaluates the model on a testset.
+ * Evaluates the model on a test dataset.
* @param dataset Test dataset to evaluate model on.
*/
- // TODO: decide on a good name before exposing to public API
- private[classification] def evaluate(dataset: DataFrame): LogisticRegressionSummary = {
- new BinaryLogisticRegressionSummary(
- this.transform(dataset), $(probabilityCol), $(labelCol), $(featuresCol))
+ @Since("2.0.0")
+ def evaluate(dataset: DataFrame): LogisticRegressionSummary = {
+ // Handle possible missing or invalid prediction columns
+ val (summaryModel, probabilityColName) = findSummaryModelAndProbabilityCol()
+ new BinaryLogisticRegressionSummary(summaryModel.transform(dataset),
+ probabilityColName, $(labelCol), $(featuresCol))
}
/**
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
index b81c588e44..5ec02135cc 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
@@ -412,15 +412,15 @@ class LinearRegressionModel private[ml] (
def hasSummary: Boolean = trainingSummary.isDefined
/**
- * Evaluates the model on a testset.
+ * Evaluates the model on a test dataset.
* @param dataset Test dataset to evaluate model on.
*/
- // TODO: decide on a good name before exposing to public API
- private[regression] def evaluate(dataset: DataFrame): LinearRegressionSummary = {
+ @Since("2.0.0")
+ def evaluate(dataset: DataFrame): LinearRegressionSummary = {
// Handle possible missing or invalid prediction columns
val (summaryModel, predictionColName) = findSummaryModelAndPredictionCol()
new LinearRegressionSummary(summaryModel.transform(dataset), predictionColName,
- $(labelCol), this, Array(0D))
+ $(labelCol), summaryModel, Array(0D))
}
/**