From 2388de51912efccaceeb663ac56fc500a79d2ceb Mon Sep 17 00:00:00 2001
From: Feynman Liang
Date: Tue, 19 Jan 2016 11:08:52 -0800
Subject: [SPARK-12804][ML] Fix LogisticRegression with FitIntercept on all same label training data

CC jkbradley mengxr dbtsai

Author: Feynman Liang

Closes #10743 from feynmanliang/SPARK-12804.
---
 .../classification/LogisticRegressionSuite.scala | 43 ++++++++++++++++++++++
 1 file changed, 43 insertions(+)

(limited to 'mllib/src/test')

diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
index ff0d0ff771..972c0868a4 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
@@ -30,6 +30,7 @@ import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.mllib.util.TestingUtils._
 import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.sql.functions.lit
 
 class LogisticRegressionSuite
   extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
@@ -883,6 +884,48 @@ class LogisticRegressionSuite
     assert(model1a0.intercept ~== model1b.intercept absTol 1E-3)
   }
 
+  test("logistic regression with all labels the same") {
+    val sameLabels = dataset
+      .withColumn("zeroLabel", lit(0.0))
+      .withColumn("oneLabel", lit(1.0))
+
+    // fitIntercept=true
+    val lrIntercept = new LogisticRegression()
+      .setFitIntercept(true)
+      .setMaxIter(3)
+
+    val allZeroInterceptModel = lrIntercept
+      .setLabelCol("zeroLabel")
+      .fit(sameLabels)
+    assert(allZeroInterceptModel.coefficients ~== Vectors.dense(0.0) absTol 1E-3)
+    assert(allZeroInterceptModel.intercept === Double.NegativeInfinity)
+    assert(allZeroInterceptModel.summary.totalIterations === 0)
+
+    val allOneInterceptModel = lrIntercept
+      .setLabelCol("oneLabel")
+      .fit(sameLabels)
+    assert(allOneInterceptModel.coefficients ~== Vectors.dense(0.0) absTol 1E-3)
+    assert(allOneInterceptModel.intercept === Double.PositiveInfinity)
+    assert(allOneInterceptModel.summary.totalIterations === 0)
+
+    // fitIntercept=false
+    val lrNoIntercept = new LogisticRegression()
+      .setFitIntercept(false)
+      .setMaxIter(3)
+
+    val allZeroNoInterceptModel = lrNoIntercept
+      .setLabelCol("zeroLabel")
+      .fit(sameLabels)
+    assert(allZeroNoInterceptModel.intercept === 0.0)
+    assert(allZeroNoInterceptModel.summary.totalIterations > 0)
+
+    val allOneNoInterceptModel = lrNoIntercept
+      .setLabelCol("oneLabel")
+      .fit(sameLabels)
+    assert(allOneNoInterceptModel.intercept === 0.0)
+    assert(allOneNoInterceptModel.summary.totalIterations > 0)
+  }
+
   test("read/write") {
     def checkModelData(model: LogisticRegressionModel, model2: LogisticRegressionModel): Unit = {
       assert(model.intercept === model2.intercept)
-- 
cgit v1.2.3