diff options
author | Holden Karau <holden@pigscanfly.ca> | 2015-12-11 02:35:53 -0500 |
---|---|---|
committer | DB Tsai <dbt@netflix.com> | 2015-12-11 02:35:53 -0500 |
commit | 518ab5101073ee35d62e33c8f7281a1e6342101e (patch) | |
tree | 3c48c08b77424d50242aa33ede7c0ce340402f79 | |
parent | b1b4ee7f3541d92c8bc2b0b4fdadf46cfdb09504 (diff) | |
download | spark-518ab5101073ee35d62e33c8f7281a1e6342101e.tar.gz spark-518ab5101073ee35d62e33c8f7281a1e6342101e.tar.bz2 spark-518ab5101073ee35d62e33c8f7281a1e6342101e.zip |
[SPARK-10991][ML] logistic regression training summary handle empty prediction col
LogisticRegression training summary should still function if the predictionCol is set to an empty string or otherwise unset (related to https://issues.apache.org/jira/browse/SPARK-9718).
Author: Holden Karau <holden@pigscanfly.ca>
Author: Holden Karau <holden@us.ibm.com>
Closes #9037 from holdenk/SPARK-10991-LogisticRegressionTrainingSummary-handle-empty-prediction-col.
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala | 20 | ||||
-rw-r--r-- | mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala | 11 |
2 files changed, 29 insertions, 2 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index 19cc323d50..486043e8d9 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -389,9 +389,10 @@ class LogisticRegression @Since("1.2.0") ( if (handlePersistence) instances.unpersist() val model = copyValues(new LogisticRegressionModel(uid, coefficients, intercept)) + val (summaryModel, probabilityColName) = model.findSummaryModelAndProbabilityCol() val logRegSummary = new BinaryLogisticRegressionTrainingSummary( - model.transform(dataset), - $(probabilityCol), + summaryModel.transform(dataset), + probabilityColName, $(labelCol), $(featuresCol), objectiveHistory) @@ -469,6 +470,21 @@ class LogisticRegressionModel private[ml] ( new NullPointerException()) } + /** + * If the probability column is set returns the current model and probability column, + * otherwise generates a new column and sets it as the probability column on a new copy + * of the current model. 
+ */ + private[classification] def findSummaryModelAndProbabilityCol(): + (LogisticRegressionModel, String) = { + $(probabilityCol) match { + case "" => + val probabilityColName = "probability_" + java.util.UUID.randomUUID.toString() + (copy(ParamMap.empty).setProbabilityCol(probabilityColName), probabilityColName) + case p => (this, p) + } + } + private[classification] def setSummary( summary: LogisticRegressionTrainingSummary): this.type = { this.trainingSummary = Some(summary) diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala index a9a6ff8a78..1087afb0cd 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala @@ -99,6 +99,17 @@ class LogisticRegressionSuite assert(model.hasParent) } + test("empty probabilityCol") { + val lr = new LogisticRegression().setProbabilityCol("") + val model = lr.fit(dataset) + assert(model.hasSummary) + // Validate that we re-insert a probability column for evaluation + val fieldNames = model.summary.predictions.schema.fieldNames + assert((dataset.schema.fieldNames.toSet).subsetOf( + fieldNames.toSet)) + assert(fieldNames.exists(s => s.startsWith("probability_"))) + } + test("setThreshold, getThreshold") { val lr = new LogisticRegression // default |