diff options
author | Feynman Liang <feynman.liang@gmail.com> | 2016-01-19 11:08:52 -0800 |
---|---|---|
committer | DB Tsai <dbt@netflix.com> | 2016-01-19 11:08:52 -0800 |
commit | 2388de51912efccaceeb663ac56fc500a79d2ceb (patch) | |
tree | bd5cb346672770e96fde624a343e7b5c12dcf9ac /mllib/src/test | |
parent | b122c861cd72b580334a7532f0a52c0439552bdf (diff) | |
download | spark-2388de51912efccaceeb663ac56fc500a79d2ceb.tar.gz spark-2388de51912efccaceeb663ac56fc500a79d2ceb.tar.bz2 spark-2388de51912efccaceeb663ac56fc500a79d2ceb.zip |
[SPARK-12804][ML] Fix LogisticRegression with FitIntercept on all same label training data
CC jkbradley mengxr dbtsai
Author: Feynman Liang <feynman.liang@gmail.com>
Closes #10743 from feynmanliang/SPARK-12804.
Diffstat (limited to 'mllib/src/test')
-rw-r--r-- | mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala | 43 |
1 files changed, 43 insertions, 0 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala index ff0d0ff771..972c0868a4 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala @@ -30,6 +30,7 @@ import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.functions.lit class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { @@ -883,6 +884,48 @@ class LogisticRegressionSuite assert(model1a0.intercept ~== model1b.intercept absTol 1E-3) } + test("logistic regression with all labels the same") { + val sameLabels = dataset + .withColumn("zeroLabel", lit(0.0)) + .withColumn("oneLabel", lit(1.0)) + + // fitIntercept=true + val lrIntercept = new LogisticRegression() + .setFitIntercept(true) + .setMaxIter(3) + + val allZeroInterceptModel = lrIntercept + .setLabelCol("zeroLabel") + .fit(sameLabels) + assert(allZeroInterceptModel.coefficients ~== Vectors.dense(0.0) absTol 1E-3) + assert(allZeroInterceptModel.intercept === Double.NegativeInfinity) + assert(allZeroInterceptModel.summary.totalIterations === 0) + + val allOneInterceptModel = lrIntercept + .setLabelCol("oneLabel") + .fit(sameLabels) + assert(allOneInterceptModel.coefficients ~== Vectors.dense(0.0) absTol 1E-3) + assert(allOneInterceptModel.intercept === Double.PositiveInfinity) + assert(allOneInterceptModel.summary.totalIterations === 0) + + // fitIntercept=false + val lrNoIntercept = new LogisticRegression() + .setFitIntercept(false) + .setMaxIter(3) + + val allZeroNoInterceptModel = lrNoIntercept + .setLabelCol("zeroLabel") + .fit(sameLabels) + assert(allZeroNoInterceptModel.intercept === 0.0) + assert(allZeroNoInterceptModel.summary.totalIterations > 0) + + val allOneNoInterceptModel = lrNoIntercept + .setLabelCol("oneLabel") + .fit(sameLabels) + assert(allOneNoInterceptModel.intercept === 0.0) + assert(allOneNoInterceptModel.summary.totalIterations > 0) + } + test("read/write") { def checkModelData(model: LogisticRegressionModel, model2: LogisticRegressionModel): Unit = { assert(model.intercept === model2.intercept) |