aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/test
diff options
context:
space:
mode:
authorMechCoder <manojkumarsivaraj334@gmail.com>2015-08-06 10:08:33 -0700
committerJoseph K. Bradley <joseph@databricks.com>2015-08-06 10:08:33 -0700
commitc5c6aded641048a3e66ac79d9e84d34e4b1abae7 (patch)
tree0ffac363ca1bfc8f8b9720ee06dc8794b89b17c1 /mllib/src/test
parent9f94c85ff35df6289371f80edde51c2aa6c4bcdc (diff)
downloadspark-c5c6aded641048a3e66ac79d9e84d34e4b1abae7.tar.gz
spark-c5c6aded641048a3e66ac79d9e84d34e4b1abae7.tar.bz2
spark-c5c6aded641048a3e66ac79d9e84d34e4b1abae7.zip
[SPARK-9112] [ML] Implement Stats for LogisticRegression
I have added support for stats in LogisticRegression. The API is similar to that of LinearRegression, with LogisticRegressionTrainingSummary and LogisticRegressionSummary. I have some queries and have asked them inline. Author: MechCoder <manojkumarsivaraj334@gmail.com> Closes #7538 from MechCoder/log_reg_stats and squashes the following commits: 2e9f7c7 [MechCoder] Change defs into lazy vals d775371 [MechCoder] Clean up class inheritance 9586125 [MechCoder] Add abstraction to handle Multiclass Metrics 40ad8ef [MechCoder] minor 640376a [MechCoder] remove unnecessary dataframe stuff and add docs 80d9954 [MechCoder] Added tests fbed861 [MechCoder] DataFrame support for metrics 70a0fc4 [MechCoder] [SPARK-9112] [ML] Implement Stats for LogisticRegression
Diffstat (limited to 'mllib/src/test')
-rw-r--r--mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java9
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala37
2 files changed, 45 insertions, 1 deletions
diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java
index fb1de51163..7e9aa38372 100644
--- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java
@@ -152,4 +152,13 @@ public class JavaLogisticRegressionSuite implements Serializable {
}
}
}
+
+ @Test
+ public void logisticRegressionTrainingSummary() {
+ LogisticRegression lr = new LogisticRegression();
+ LogisticRegressionModel model = lr.fit(dataset);
+ LogisticRegressionTrainingSummary summary = model.summary();
+ // Use a JUnit assertion: a bare `assert` statement is skipped unless the JVM runs with -ea.
+ org.junit.Assert.assertEquals(summary.totalIterations(), summary.objectiveHistory().length);
+ }
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
index da13dcb42d..8c3d4590f5 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
@@ -723,6 +723,41 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
val weightsR = Vectors.dense(0.0, 0.0, 0.0, 0.0)
assert(model1.intercept ~== interceptR relTol 1E-5)
- assert(model1.weights ~= weightsR absTol 1E-6)
+ assert(model1.weights ~== weightsR absTol 1E-6)
+ }
+
+ test("evaluate on test set") {
+ // Evaluating on the training data must reproduce the metrics of the training summary.
+ val lr = new LogisticRegression()
+ .setMaxIter(10)
+ .setRegParam(1.0)
+ .setThreshold(0.6)
+ val model = lr.fit(dataset)
+ val summary = model.summary.asInstanceOf[BinaryLogisticRegressionSummary]
+
+ val sameSummary = model.evaluate(dataset).asInstanceOf[BinaryLogisticRegressionSummary]
+ assert(summary.areaUnderROC === sameSummary.areaUnderROC)
+ assert(summary.roc.collect() === sameSummary.roc.collect())
+ assert(summary.pr.collect() === sameSummary.pr.collect())
+ assert(
+ summary.fMeasureByThreshold.collect() === sameSummary.fMeasureByThreshold.collect())
+ assert(summary.recallByThreshold.collect() === sameSummary.recallByThreshold.collect())
+ assert(
+ summary.precisionByThreshold.collect() === sameSummary.precisionByThreshold.collect())
+ }
+
+ test("statistics on training data") {
+ // The objective (loss) recorded for each iteration must never increase.
+ val lr = new LogisticRegression()
+ .setMaxIter(10)
+ .setRegParam(1.0)
+ .setThreshold(0.6)
+ val model = lr.fit(dataset)
+ // sliding(2) yields one short window when fewer than two values exist; accept it as monotone.
+ assert(
+ model.summary
+ .objectiveHistory
+ .sliding(2)
+ .forall(w => w.length < 2 || w(0) >= w(1)))
}
}