author     Holden Karau <holden@pigscanfly.ca>    2015-12-11 02:35:53 -0500
committer  DB Tsai <dbt@netflix.com>              2015-12-11 02:35:53 -0500
commit     518ab5101073ee35d62e33c8f7281a1e6342101e (patch)
tree       3c48c08b77424d50242aa33ede7c0ce340402f79 /mllib
parent     b1b4ee7f3541d92c8bc2b0b4fdadf46cfdb09504 (diff)
[SPARK-10991][ML] logistic regression training summary handle empty prediction col
LogisticRegression training summary should still function if the predictionCol is set to an empty string or otherwise unset (related to https://issues.apache.org/jira/browse/SPARK-9718).

Author: Holden Karau <holden@pigscanfly.ca>
Author: Holden Karau <holden@us.ibm.com>

Closes #9037 from holdenk/SPARK-10991-LogisticRegressionTrainingSummary-handle-empty-prediction-col.
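For context, a minimal spark-shell style sketch of the behavior this change enables; the `training` DataFrame and its "label"/"features" columns are assumed here and are not part of the commit:

import org.apache.spark.ml.classification.LogisticRegression

// Sketch only: assumes an active SparkSession and a DataFrame `training`
// with "label" and "features" columns.
val lr = new LogisticRegression()
  .setProbabilityCol("") // leave the probability column unset

val model = lr.fit(training)

// With this patch the training summary is still built: a temporary
// "probability_<uuid>" column is generated internally for evaluation
// instead of reusing the empty column name.
assert(model.hasSummary)
println(model.summary.objectiveHistory.mkString(", "))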
Diffstat (limited to 'mllib')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala       20
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala  11
2 files changed, 29 insertions, 2 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
index 19cc323d50..486043e8d9 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
@@ -389,9 +389,10 @@ class LogisticRegression @Since("1.2.0") (
if (handlePersistence) instances.unpersist()
val model = copyValues(new LogisticRegressionModel(uid, coefficients, intercept))
+ val (summaryModel, probabilityColName) = model.findSummaryModelAndProbabilityCol()
val logRegSummary = new BinaryLogisticRegressionTrainingSummary(
- model.transform(dataset),
- $(probabilityCol),
+ summaryModel.transform(dataset),
+ probabilityColName,
$(labelCol),
$(featuresCol),
objectiveHistory)
@@ -469,6 +470,21 @@ class LogisticRegressionModel private[ml] (
new NullPointerException())
}
+ /**
+ * If the probability column is set, returns the current model and probability column;
+ * otherwise generates a new column and sets it as the probability column on a new copy
+ * of the current model.
+ */
+ private[classification] def findSummaryModelAndProbabilityCol():
+ (LogisticRegressionModel, String) = {
+ $(probabilityCol) match {
+ case "" =>
+ val probabilityColName = "probability_" + java.util.UUID.randomUUID.toString()
+ (copy(ParamMap.empty).setProbabilityCol(probabilityColName), probabilityColName)
+ case p => (this, p)
+ }
+ }
+
private[classification] def setSummary(
summary: LogisticRegressionTrainingSummary): this.type = {
this.trainingSummary = Some(summary)
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
index a9a6ff8a78..1087afb0cd 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
@@ -99,6 +99,17 @@ class LogisticRegressionSuite
assert(model.hasParent)
}
+ test("empty probabilityCol") {
+ val lr = new LogisticRegression().setProbabilityCol("")
+ val model = lr.fit(dataset)
+ assert(model.hasSummary)
+ // Validate that we re-insert a probability column for evaluation
+ val fieldNames = model.summary.predictions.schema.fieldNames
+ assert((dataset.schema.fieldNames.toSet).subsetOf(
+ fieldNames.toSet))
+ assert(fieldNames.exists(s => s.startsWith("probability_")))
+ }
+
test("setThreshold, getThreshold") {
val lr = new LogisticRegression
// default
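The fix hinges on the fallback pattern in findSummaryModelAndProbabilityCol() above: when the configured probability column name is empty, the summary is computed against a copy of the model that uses a generated, collision-free temporary column. A standalone sketch of that pattern, using a hypothetical SimpleModel stand-in rather than the real Spark classes:

import java.util.UUID

// Hypothetical stand-in for a model with a configurable probability column;
// not a Spark class. Illustrates the copy-with-generated-column fallback.
case class SimpleModel(probabilityCol: String) {

  /** Returns a model and column name that are safe to use for evaluation. */
  def resolveProbabilityCol(): (SimpleModel, String) = probabilityCol match {
    case "" =>
      // Generate a collision-free temporary name instead of mutating the
      // user-facing parameter on the original model.
      val generated = "probability_" + UUID.randomUUID().toString
      (copy(probabilityCol = generated), generated)
    case p => (this, p)
  }
}

// Usage sketch:
val (evalModel, colName) = SimpleModel(probabilityCol = "").resolveProbabilityCol()
assert(colName.startsWith("probability_"))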