aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
authorMechCoder <manojkumarsivaraj334@gmail.com>2015-08-06 10:08:33 -0700
committerJoseph K. Bradley <joseph@databricks.com>2015-08-06 10:08:33 -0700
commitc5c6aded641048a3e66ac79d9e84d34e4b1abae7 (patch)
tree0ffac363ca1bfc8f8b9720ee06dc8794b89b17c1 /mllib
parent9f94c85ff35df6289371f80edde51c2aa6c4bcdc (diff)
downloadspark-c5c6aded641048a3e66ac79d9e84d34e4b1abae7.tar.gz
spark-c5c6aded641048a3e66ac79d9e84d34e4b1abae7.tar.bz2
spark-c5c6aded641048a3e66ac79d9e84d34e4b1abae7.zip
[SPARK-9112] [ML] Implement Stats for LogisticRegression
I have added support for stats in LogisticRegression. The API is similar to that of LinearRegression, with LogisticRegressionTrainingSummary and LogisticRegressionSummary. I have some queries and have asked them inline. Author: MechCoder <manojkumarsivaraj334@gmail.com> Closes #7538 from MechCoder/log_reg_stats and squashes the following commits: 2e9f7c7 [MechCoder] Change defs into lazy vals d775371 [MechCoder] Clean up class inheritance 9586125 [MechCoder] Add abstraction to handle Multiclass Metrics 40ad8ef [MechCoder] minor 640376a [MechCoder] remove unnecessary dataframe stuff and add docs 80d9954 [MechCoder] Added tests fbed861 [MechCoder] DataFrame support for metrics 70a0fc4 [MechCoder] [SPARK-9112] [ML] Implement Stats for LogisticRegression
Diffstat (limited to 'mllib')
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala166
-rw-r--r--mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java9
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala37
3 files changed, 209 insertions, 3 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
index 0d07383925..f55134d258 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
@@ -30,10 +30,12 @@ import org.apache.spark.ml.util.Identifiable
import org.apache.spark.mllib.linalg._
import org.apache.spark.mllib.linalg.BLAS._
import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics
import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.{DataFrame, Row, SQLContext}
+import org.apache.spark.sql.functions.{col, udf}
import org.apache.spark.storage.StorageLevel
/**
@@ -284,7 +286,13 @@ class LogisticRegression(override val uid: String)
if (handlePersistence) instances.unpersist()
- copyValues(new LogisticRegressionModel(uid, weights, intercept))
+ val model = copyValues(new LogisticRegressionModel(uid, weights, intercept))
+ val logRegSummary = new BinaryLogisticRegressionTrainingSummary(
+ model.transform(dataset),
+ $(probabilityCol),
+ $(labelCol),
+ objectiveHistory)
+ model.setSummary(logRegSummary)
}
override def copy(extra: ParamMap): LogisticRegression = defaultCopy(extra)
@@ -319,6 +327,38 @@ class LogisticRegressionModel private[ml] (
override val numClasses: Int = 2
+ private var trainingSummary: Option[LogisticRegressionTrainingSummary] = None
+
+ /**
+ * Gets summary of model on training set. An exception is
+ * thrown if `trainingSummary == None`.
+ */
+ def summary: LogisticRegressionTrainingSummary = trainingSummary match {
+ case Some(summ) => summ
+ case None =>
+ throw new SparkException(
+ "No training summary available for this LogisticRegressionModel",
+ new NullPointerException())
+ }
+
+ private[classification] def setSummary(
+ summary: LogisticRegressionTrainingSummary): this.type = {
+ this.trainingSummary = Some(summary)
+ this
+ }
+
+ /** Indicates whether a training summary exists for this model instance. */
+ def hasSummary: Boolean = trainingSummary.isDefined
+
+ /**
+ * Evaluates the model on a testset.
+ * @param dataset Test dataset to evaluate model on.
+ */
+ // TODO: decide on a good name before exposing to public API
+ private[classification] def evaluate(dataset: DataFrame): LogisticRegressionSummary = {
+ new BinaryLogisticRegressionSummary(this.transform(dataset), $(probabilityCol), $(labelCol))
+ }
+
/**
* Predict label for the given feature vector.
* The behavior of this can be adjusted using [[thresholds]].
@@ -441,6 +481,128 @@ private[classification] class MultiClassSummarizer extends Serializable {
}
/**
+ * Abstraction for multinomial Logistic Regression Training results.
+ */
+sealed trait LogisticRegressionTrainingSummary extends LogisticRegressionSummary {
+
+ /** objective function (scaled loss + regularization) at each iteration. */
+ def objectiveHistory: Array[Double]
+
+ /** Number of training iterations until termination */
+ def totalIterations: Int = objectiveHistory.length
+
+}
+
+/**
+ * Abstraction for Logistic Regression Results for a given model.
+ */
+sealed trait LogisticRegressionSummary extends Serializable {
+
+ /** Dataframe outputted by the model's `transform` method. */
+ def predictions: DataFrame
+
+ /** Field in "predictions" which gives the calibrated probability of each sample as a vector. */
+ def probabilityCol: String
+
+ /** Field in "predictions" which gives the true label of each sample. */
+ def labelCol: String
+
+}
+
+/**
+ * :: Experimental ::
+ * Logistic regression training results.
+ * @param predictions dataframe outputted by the model's `transform` method.
+ * @param probabilityCol field in "predictions" which gives the calibrated probability of
+ * each sample as a vector.
+ * @param labelCol field in "predictions" which gives the true label of each sample.
+ * @param objectiveHistory objective function (scaled loss + regularization) at each iteration.
+ */
+@Experimental
+class BinaryLogisticRegressionTrainingSummary private[classification] (
+ predictions: DataFrame,
+ probabilityCol: String,
+ labelCol: String,
+ val objectiveHistory: Array[Double])
+ extends BinaryLogisticRegressionSummary(predictions, probabilityCol, labelCol)
+ with LogisticRegressionTrainingSummary {
+
+}
+
+/**
+ * :: Experimental ::
+ * Binary Logistic regression results for a given model.
+ * @param predictions dataframe outputted by the model's `transform` method.
+ * @param probabilityCol field in "predictions" which gives the calibrated probability of
+ * each sample.
+ * @param labelCol field in "predictions" which gives the true label of each sample.
+ */
+@Experimental
+class BinaryLogisticRegressionSummary private[classification] (
+ @transient override val predictions: DataFrame,
+ override val probabilityCol: String,
+ override val labelCol: String) extends LogisticRegressionSummary {
+
+ private val sqlContext = predictions.sqlContext
+ import sqlContext.implicits._
+
+ /**
+ * Returns a BinaryClassificationMetrics object.
+ */
+ // TODO: Allow the user to vary the number of bins using a setBins method in
+ // BinaryClassificationMetrics. For now the default is set to 100.
+ @transient private val binaryMetrics = new BinaryClassificationMetrics(
+ predictions.select(probabilityCol, labelCol).map {
+ case Row(score: Vector, label: Double) => (score(1), label)
+ }, 100
+ )
+
+ /**
+ * Returns the receiver operating characteristic (ROC) curve,
+ * which is a DataFrame having two fields (FPR, TPR)
+ * with (0.0, 0.0) prepended and (1.0, 1.0) appended to it.
+ * @see http://en.wikipedia.org/wiki/Receiver_operating_characteristic
+ */
+ @transient lazy val roc: DataFrame = binaryMetrics.roc().toDF("FPR", "TPR")
+
+ /**
+ * Computes the area under the receiver operating characteristic (ROC) curve.
+ */
+ lazy val areaUnderROC: Double = binaryMetrics.areaUnderROC()
+
+ /**
+ * Returns the precision-recall curve, which is a DataFrame containing
+ * two fields (recall, precision) with (0.0, 1.0) prepended to it.
+ */
+ @transient lazy val pr: DataFrame = binaryMetrics.pr().toDF("recall", "precision")
+
+ /**
+ * Returns the F-Measure curve (beta = 1.0) as a dataframe with fields (threshold, F-Measure).
+ */
+ @transient lazy val fMeasureByThreshold: DataFrame = {
+ binaryMetrics.fMeasureByThreshold().toDF("threshold", "F-Measure")
+ }
+
+ /**
+ * Returns the precision curve as a dataframe with two fields (threshold, precision).
+ * Every possible probability obtained in transforming the dataset is used
+ * as a threshold in calculating the precision.
+ */
+ @transient lazy val precisionByThreshold: DataFrame = {
+ binaryMetrics.precisionByThreshold().toDF("threshold", "precision")
+ }
+
+ /**
+ * Returns the recall curve as a dataframe with two fields (threshold, recall).
+ * Every possible probability obtained in transforming the dataset is used
+ * as a threshold in calculating the recall.
+ */
+ @transient lazy val recallByThreshold: DataFrame = {
+ binaryMetrics.recallByThreshold().toDF("threshold", "recall")
+ }
+}
+
+/**
* LogisticAggregator computes the gradient and loss for binary logistic loss function, as used
* in binary classification for samples in sparse or dense vector in a online fashion.
*
diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java
index fb1de51163..7e9aa38372 100644
--- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java
@@ -152,4 +152,13 @@ public class JavaLogisticRegressionSuite implements Serializable {
}
}
}
+
+ @Test
+ public void logisticRegressionTrainingSummary() {
+ LogisticRegression lr = new LogisticRegression();
+ LogisticRegressionModel model = lr.fit(dataset);
+
+ LogisticRegressionTrainingSummary summary = model.summary();
+ assert(summary.totalIterations() == summary.objectiveHistory().length);
+ }
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
index da13dcb42d..8c3d4590f5 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
@@ -723,6 +723,41 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
val weightsR = Vectors.dense(0.0, 0.0, 0.0, 0.0)
assert(model1.intercept ~== interceptR relTol 1E-5)
- assert(model1.weights ~= weightsR absTol 1E-6)
+ assert(model1.weights ~== weightsR absTol 1E-6)
+ }
+
+ test("evaluate on test set") {
+ // Evaluating the model on the training data should reproduce the training summary.
+ val lr = new LogisticRegression()
+ .setMaxIter(10)
+ .setRegParam(1.0)
+ .setThreshold(0.6)
+ val model = lr.fit(dataset)
+ val summary = model.summary.asInstanceOf[BinaryLogisticRegressionSummary]
+
+ val sameSummary = model.evaluate(dataset).asInstanceOf[BinaryLogisticRegressionSummary]
+ assert(summary.areaUnderROC === sameSummary.areaUnderROC)
+ assert(summary.roc.collect() === sameSummary.roc.collect())
+ assert(summary.pr.collect === sameSummary.pr.collect())
+ assert(
+ summary.fMeasureByThreshold.collect() === sameSummary.fMeasureByThreshold.collect())
+ assert(summary.recallByThreshold.collect() === sameSummary.recallByThreshold.collect())
+ assert(
+ summary.precisionByThreshold.collect() === sameSummary.precisionByThreshold.collect())
+ }
+
+ test("statistics on training data") {
+ // Test that the objective history is monotonically non-increasing.
+ val lr = new LogisticRegression()
+ .setMaxIter(10)
+ .setRegParam(1.0)
+ .setThreshold(0.6)
+ val model = lr.fit(dataset)
+ assert(
+ model.summary
+ .objectiveHistory
+ .sliding(2)
+ .forall(x => x(0) >= x(1)))
+
}
}