From c5c6aded641048a3e66ac79d9e84d34e4b1abae7 Mon Sep 17 00:00:00 2001 From: MechCoder Date: Thu, 6 Aug 2015 10:08:33 -0700 Subject: [SPARK-9112] [ML] Implement Stats for LogisticRegression I have added support for stats in LogisticRegression. The API is similar to that of LinearRegression with LogisticRegressionTrainingSummary and LogisticRegressionSummary I have some queries and asked them inline. Author: MechCoder Closes #7538 from MechCoder/log_reg_stats and squashes the following commits: 2e9f7c7 [MechCoder] Change defs into lazy vals d775371 [MechCoder] Clean up class inheritance 9586125 [MechCoder] Add abstraction to handle Multiclass Metrics 40ad8ef [MechCoder] minor 640376a [MechCoder] remove unnecessary dataframe stuff and add docs 80d9954 [MechCoder] Added tests fbed861 [MechCoder] DataFrame support for metrics 70a0fc4 [MechCoder] [SPARK-9112] [ML] Implement Stats for LogisticRegression --- .../JavaLogisticRegressionSuite.java | 9 ++++++ .../classification/LogisticRegressionSuite.scala | 37 +++++++++++++++++++++- 2 files changed, 45 insertions(+), 1 deletion(-) (limited to 'mllib/src/test') diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java index fb1de51163..7e9aa38372 100644 --- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java @@ -152,4 +152,13 @@ public class JavaLogisticRegressionSuite implements Serializable { } } } + + @Test + public void logisticRegressionTrainingSummary() { + LogisticRegression lr = new LogisticRegression(); + LogisticRegressionModel model = lr.fit(dataset); + + LogisticRegressionTrainingSummary summary = model.summary(); + assert(summary.totalIterations() == summary.objectiveHistory().length); + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala index da13dcb42d..8c3d4590f5 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala @@ -723,6 +723,41 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { val weightsR = Vectors.dense(0.0, 0.0, 0.0, 0.0) assert(model1.intercept ~== interceptR relTol 1E-5) - assert(model1.weights ~= weightsR absTol 1E-6) + assert(model1.weights ~== weightsR absTol 1E-6) + } + + test("evaluate on test set") { + // Evaluate on test set should be same as that of the transformed training data. + val lr = new LogisticRegression() + .setMaxIter(10) + .setRegParam(1.0) + .setThreshold(0.6) + val model = lr.fit(dataset) + val summary = model.summary.asInstanceOf[BinaryLogisticRegressionSummary] + + val sameSummary = model.evaluate(dataset).asInstanceOf[BinaryLogisticRegressionSummary] + assert(summary.areaUnderROC === sameSummary.areaUnderROC) + assert(summary.roc.collect() === sameSummary.roc.collect()) + assert(summary.pr.collect === sameSummary.pr.collect()) + assert( + summary.fMeasureByThreshold.collect() === sameSummary.fMeasureByThreshold.collect()) + assert(summary.recallByThreshold.collect() === sameSummary.recallByThreshold.collect()) + assert( + summary.precisionByThreshold.collect() === sameSummary.precisionByThreshold.collect()) + } + + test("statistics on training data") { + // Test that loss is monotonically decreasing. + val lr = new LogisticRegression() + .setMaxIter(10) + .setRegParam(1.0) + .setThreshold(0.6) + val model = lr.fit(dataset) + assert( + model.summary + .objectiveHistory + .sliding(2) + .forall(x => x(0) >= x(1))) + } } -- cgit v1.2.3