author     MechCoder <manojkumarsivaraj334@gmail.com>  2015-08-06 10:08:33 -0700
committer  Joseph K. Bradley <joseph@databricks.com>   2015-08-06 10:08:33 -0700
commit     c5c6aded641048a3e66ac79d9e84d34e4b1abae7 (patch)
tree       0ffac363ca1bfc8f8b9720ee06dc8794b89b17c1 /mllib/src/main
parent     9f94c85ff35df6289371f80edde51c2aa6c4bcdc (diff)
[SPARK-9112] [ML] Implement Stats for LogisticRegression
I have added support for stats in LogisticRegression. The API is similar to that of
LinearRegression, with LogisticRegressionTrainingSummary and LogisticRegressionSummary.
I have some queries and have asked them inline.

Author: MechCoder <manojkumarsivaraj334@gmail.com>

Closes #7538 from MechCoder/log_reg_stats and squashes the following commits:

2e9f7c7 [MechCoder] Change defs into lazy vals
d775371 [MechCoder] Clean up class inheritance
9586125 [MechCoder] Add abstraction to handle Multiclass Metrics
40ad8ef [MechCoder] minor
640376a [MechCoder] remove unnecessary dataframe stuff and add docs
80d9954 [MechCoder] Added tests
fbed861 [MechCoder] DataFrame support for metrics
70a0fc4 [MechCoder] [SPARK-9112] [ML] Implement Stats for LogisticRegression
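As a usage sketch (not part of this patch): after fitting, the training summary is available on the returned model. This assumes a Spark 1.5-era setup and a hypothetical DataFrame `training` with the default "label"/"features" columns; the classes and methods themselves are the ones added in this diff:

    import org.apache.spark.ml.classification.{BinaryLogisticRegressionTrainingSummary, LogisticRegression}

    val lr = new LogisticRegression().setMaxIter(10).setRegParam(0.01)
    val model = lr.fit(training)

    // A summary is only attached by fit(); guard before reading it.
    if (model.hasSummary) {
      val summary = model.summary
      println(s"Converged after ${summary.totalIterations} iterations")
      summary.objectiveHistory.foreach(println)

      // Binary-specific metrics live on the concrete subclass.
      val binarySummary = summary.asInstanceOf[BinaryLogisticRegressionTrainingSummary]
      println(s"Area under ROC: ${binarySummary.areaUnderROC}")
      binarySummary.roc.show() // columns: FPR, TPR
    }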
Diffstat (limited to 'mllib/src/main')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala  166
1 file changed, 164 insertions(+), 2 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
index 0d07383925..f55134d258 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
@@ -30,10 +30,12 @@ import org.apache.spark.ml.util.Identifiable
import org.apache.spark.mllib.linalg._
import org.apache.spark.mllib.linalg.BLAS._
import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics
import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.{DataFrame, Row, SQLContext}
+import org.apache.spark.sql.functions.{col, udf}
import org.apache.spark.storage.StorageLevel
/**
@@ -284,7 +286,13 @@ class LogisticRegression(override val uid: String)
if (handlePersistence) instances.unpersist()
- copyValues(new LogisticRegressionModel(uid, weights, intercept))
+ val model = copyValues(new LogisticRegressionModel(uid, weights, intercept))
+ val logRegSummary = new BinaryLogisticRegressionTrainingSummary(
+ model.transform(dataset),
+ $(probabilityCol),
+ $(labelCol),
+ objectiveHistory)
+ model.setSummary(logRegSummary)
}
override def copy(extra: ParamMap): LogisticRegression = defaultCopy(extra)
@@ -319,6 +327,38 @@ class LogisticRegressionModel private[ml] (
override val numClasses: Int = 2
+ private var trainingSummary: Option[LogisticRegressionTrainingSummary] = None
+
+ /**
+ * Gets summary of model on training set. An exception is
+ * thrown if `trainingSummary == None`.
+ */
+ def summary: LogisticRegressionTrainingSummary = trainingSummary match {
+ case Some(summ) => summ
+ case None =>
+ throw new SparkException(
+ "No training summary available for this LogisticRegressionModel",
+ new NullPointerException())
+ }
+
+ private[classification] def setSummary(
+ summary: LogisticRegressionTrainingSummary): this.type = {
+ this.trainingSummary = Some(summary)
+ this
+ }
+
+ /** Indicates whether a training summary exists for this model instance. */
+ def hasSummary: Boolean = trainingSummary.isDefined
+
+ /**
+ * Evaluates the model on a test set.
+ * @param dataset Test dataset to evaluate model on.
+ */
+ // TODO: decide on a good name before exposing to public API
+ private[classification] def evaluate(dataset: DataFrame): LogisticRegressionSummary = {
+ new BinaryLogisticRegressionSummary(this.transform(dataset), $(probabilityCol), $(labelCol))
+ }
+
/**
* Predict label for the given feature vector.
* The behavior of this can be adjusted using [[thresholds]].
@@ -441,6 +481,128 @@ private[classification] class MultiClassSummarizer extends Serializable {
}
/**
+ * Abstraction for multinomial logistic regression training results.
+ */
+sealed trait LogisticRegressionTrainingSummary extends LogisticRegressionSummary {
+
+ /** objective function (scaled loss + regularization) at each iteration. */
+ def objectiveHistory: Array[Double]
+
+ * Number of training iterations until termination.
+ def totalIterations: Int = objectiveHistory.length
+
+}
+
+/**
+ * Abstraction for logistic regression results for a given model.
+ */
+sealed trait LogisticRegressionSummary extends Serializable {
+
+ * DataFrame output by the model's `transform` method.
+ def predictions: DataFrame
+
+ /** Field in "predictions" which gives the calibrated probability of each sample as a vector. */
+ def probabilityCol: String
+
+ * Field in "predictions" which gives the true label of each sample.
+ def labelCol: String
+
+}
+
+/**
+ * :: Experimental ::
+ * Logistic regression training results.
+ * @param predictions DataFrame output by the model's `transform` method.
+ * @param probabilityCol field in "predictions" which gives the calibrated probability of
+ * each sample as a vector.
+ * @param labelCol field in "predictions" which gives the true label of each sample.
+ * @param objectiveHistory objective function (scaled loss + regularization) at each iteration.
+ */
+@Experimental
+class BinaryLogisticRegressionTrainingSummary private[classification] (
+ predictions: DataFrame,
+ probabilityCol: String,
+ labelCol: String,
+ val objectiveHistory: Array[Double])
+ extends BinaryLogisticRegressionSummary(predictions, probabilityCol, labelCol)
+ with LogisticRegressionTrainingSummary {
+
+}
+
+/**
+ * :: Experimental ::
+ * Binary Logistic regression results for a given model.
+ * @param predictions DataFrame output by the model's `transform` method.
+ * @param probabilityCol field in "predictions" which gives the calibrated probability of
+ * each sample.
+ * @param labelCol field in "predictions" which gives the true label of each sample.
+ */
+@Experimental
+class BinaryLogisticRegressionSummary private[classification] (
+ @transient override val predictions: DataFrame,
+ override val probabilityCol: String,
+ override val labelCol: String) extends LogisticRegressionSummary {
+
+ private val sqlContext = predictions.sqlContext
+ import sqlContext.implicits._
+
+ /**
+ * BinaryClassificationMetrics computed from (probability of positive class, label) pairs.
+ */
+ // TODO: Allow the user to vary the number of bins using a setBins method in
+ // BinaryClassificationMetrics. For now the default is set to 100.
+ @transient private val binaryMetrics = new BinaryClassificationMetrics(
+ predictions.select(probabilityCol, labelCol).map {
+ case Row(score: Vector, label: Double) => (score(1), label)
+ }, 100
+ )
+
+ /**
+ * Returns the receiver operating characteristic (ROC) curve,
+ * which is a DataFrame having two fields (FPR, TPR)
+ * with (0.0, 0.0) prepended and (1.0, 1.0) appended to it.
+ * @see http://en.wikipedia.org/wiki/Receiver_operating_characteristic
+ */
+ @transient lazy val roc: DataFrame = binaryMetrics.roc().toDF("FPR", "TPR")
+
+ /**
+ * Computes the area under the receiver operating characteristic (ROC) curve.
+ */
+ lazy val areaUnderROC: Double = binaryMetrics.areaUnderROC()
+
+ /**
+ * Returns the precision-recall curve, which is a DataFrame containing
+ * two fields (recall, precision) with (0.0, 1.0) prepended to it.
+ */
+ @transient lazy val pr: DataFrame = binaryMetrics.pr().toDF("recall", "precision")
+
+ /**
+ * Returns a DataFrame with two fields (threshold, F-Measure) giving the F-Measure curve with beta = 1.0.
+ */
+ @transient lazy val fMeasureByThreshold: DataFrame = {
+ binaryMetrics.fMeasureByThreshold().toDF("threshold", "F-Measure")
+ }
+
+ /**
+ * Returns a DataFrame with two fields (threshold, precision) giving the precision curve.
+ * Every distinct probability produced when transforming the dataset is used
+ * as a threshold when computing the precision.
+ */
+ @transient lazy val precisionByThreshold: DataFrame = {
+ binaryMetrics.precisionByThreshold().toDF("threshold", "precision")
+ }
+
+ /**
+ * Returns a DataFrame with two fields (threshold, recall) giving the recall curve.
+ * Every distinct probability produced when transforming the dataset is used
+ * as a threshold when computing the recall.
+ */
+ @transient lazy val recallByThreshold: DataFrame = {
+ binaryMetrics.recallByThreshold().toDF("threshold", "recall")
+ }
+}
+
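A follow-on sketch: one common use of fMeasureByThreshold is to pick the probability threshold that maximizes F1 and configure the model with it. `binarySummary` and `model` are carried over from the usage sketch after the commit message above; `max` is the standard aggregate from org.apache.spark.sql.functions:

    import org.apache.spark.sql.functions.max

    val fMeasure = binarySummary.fMeasureByThreshold
    // Largest F-Measure over all thresholds, then the threshold achieving it.
    val maxFMeasure = fMeasure.agg(max("F-Measure")).head().getDouble(0)
    val bestThreshold = fMeasure
      .where(fMeasure("F-Measure") === maxFMeasure)
      .select("threshold").head().getDouble(0)

    // Adjusts the label decision rule used by transform().
    model.setThreshold(bestThreshold)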
+/**
* LogisticAggregator computes the gradient and loss for binary logistic loss function, as used
* in binary classification for samples in sparse or dense vector in a online fashion.
*