about | summary | refs | log | tree | commit | diff
path: root/mllib/src/test
diff options
context:
space:
mode:
author: Feynman Liang <feynman.liang@gmail.com>, 2016-01-19 11:08:52 -0800
committer: DB Tsai <dbt@netflix.com>, 2016-01-19 11:08:52 -0800
commit: 2388de51912efccaceeb663ac56fc500a79d2ceb (patch)
tree: bd5cb346672770e96fde624a343e7b5c12dcf9ac /mllib/src/test
parent: b122c861cd72b580334a7532f0a52c0439552bdf (diff)
download: spark-2388de51912efccaceeb663ac56fc500a79d2ceb.tar.gz
          spark-2388de51912efccaceeb663ac56fc500a79d2ceb.tar.bz2
          spark-2388de51912efccaceeb663ac56fc500a79d2ceb.zip
[SPARK-12804][ML] Fix LogisticRegression with FitIntercept on all same label training data
CC jkbradley mengxr dbtsai. Author: Feynman Liang <feynman.liang@gmail.com>. Closes #10743 from feynmanliang/SPARK-12804.
Diffstat (limited to 'mllib/src/test')
-rw-r--r-- mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala | 43
1 files changed, 43 insertions, 0 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
index ff0d0ff771..972c0868a4 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
@@ -30,6 +30,7 @@ import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.sql.functions.lit
class LogisticRegressionSuite
extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
@@ -883,6 +884,48 @@ class LogisticRegressionSuite
assert(model1a0.intercept ~== model1b.intercept absTol 1E-3)
}
+ test("logistic regression with all labels the same") { // test added for SPARK-12804
+ val sameLabels = dataset
+ .withColumn("zeroLabel", lit(0.0)) // label column that is 0.0 on every row
+ .withColumn("oneLabel", lit(1.0)) // label column that is 1.0 on every row
+
+ // fitIntercept=true: expect ~0 coefficients, an infinite intercept and totalIterations == 0
+ val lrIntercept = new LogisticRegression()
+ .setFitIntercept(true)
+ .setMaxIter(3)
+
+ val allZeroInterceptModel = lrIntercept
+ .setLabelCol("zeroLabel")
+ .fit(sameLabels)
+ assert(allZeroInterceptModel.coefficients ~== Vectors.dense(0.0) absTol 1E-3)
+ assert(allZeroInterceptModel.intercept === Double.NegativeInfinity)
+ assert(allZeroInterceptModel.summary.totalIterations === 0)
+
+ val allOneInterceptModel = lrIntercept
+ .setLabelCol("oneLabel")
+ .fit(sameLabels)
+ assert(allOneInterceptModel.coefficients ~== Vectors.dense(0.0) absTol 1E-3)
+ assert(allOneInterceptModel.intercept === Double.PositiveInfinity)
+ assert(allOneInterceptModel.summary.totalIterations === 0)
+
+ // fitIntercept=false: expect intercept == 0.0 and totalIterations > 0 for both label columns
+ val lrNoIntercept = new LogisticRegression()
+ .setFitIntercept(false)
+ .setMaxIter(3)
+
+ val allZeroNoInterceptModel = lrNoIntercept
+ .setLabelCol("zeroLabel")
+ .fit(sameLabels)
+ assert(allZeroNoInterceptModel.intercept === 0.0)
+ assert(allZeroNoInterceptModel.summary.totalIterations > 0)
+
+ val allOneNoInterceptModel = lrNoIntercept
+ .setLabelCol("oneLabel")
+ .fit(sameLabels)
+ assert(allOneNoInterceptModel.intercept === 0.0)
+ assert(allOneNoInterceptModel.summary.totalIterations > 0)
+ }
+
test("read/write") {
def checkModelData(model: LogisticRegressionModel, model2: LogisticRegressionModel): Unit = {
assert(model.intercept === model2.intercept)