author    MechCoder <manojkumarsivaraj334@gmail.com>  2015-08-27 15:33:43 -0700
committer Joseph K. Bradley <joseph@databricks.com>   2015-08-27 15:33:51 -0700
commit    501e10a912d540d02fd3a611911e65b781692109 (patch)
tree      e329b60209d0ba9a8180c30c22b6197196872042
parent    66db9cdc6ad3367ddf8d49d4d48c7506a4459675 (diff)
[SPARK-9906] [ML] User guide for LogisticRegressionSummary
User guide for LogisticRegression summaries

Author: MechCoder <manojkumarsivaraj334@gmail.com>
Author: Manoj Kumar <mks542@nyu.edu>
Author: Feynman Liang <fliang@databricks.com>

Closes #8197 from MechCoder/log_summary_user_guide.

(cherry picked from commit c94ecdfc5b3c0fe6c38a170dc2af9259354dc9e3)
Signed-off-by: Joseph K. Bradley <joseph@databricks.com>
-rw-r--r--  docs/ml-linear-methods.md  149
1 file changed, 133 insertions(+), 16 deletions(-)
diff --git a/docs/ml-linear-methods.md b/docs/ml-linear-methods.md
index 1ac83d94c9..2761aeb789 100644
--- a/docs/ml-linear-methods.md
+++ b/docs/ml-linear-methods.md
@@ -23,20 +23,41 @@ displayTitle: <a href="ml-guide.html">ML</a> - Linear Methods
\]`
-In MLlib, we implement popular linear methods such as logistic regression and linear least squares with L1 or L2 regularization. Refer to [the linear methods in mllib](mllib-linear-methods.html) for details. In `spark.ml`, we also include Pipelines API for [Elastic net](http://en.wikipedia.org/wiki/Elastic_net_regularization), a hybrid of L1 and L2 regularization proposed in [this paper](http://users.stat.umn.edu/~zouxx019/Papers/elasticnet.pdf). Mathematically it is defined as a linear combination of the L1-norm and the L2-norm:
+In MLlib, we implement popular linear methods such as logistic
+regression and linear least squares with $L_1$ or $L_2$ regularization.
+Refer to [the linear methods in MLlib](mllib-linear-methods.html) for
+details. In `spark.ml`, we also include the Pipelines API for [elastic
+net](http://en.wikipedia.org/wiki/Elastic_net_regularization), a hybrid
+of $L_1$ and $L_2$ regularization proposed in [Zou et al., Regularization
+and variable selection via the elastic
+net](http://users.stat.umn.edu/~zouxx019/Papers/elasticnet.pdf).
+Mathematically, it is defined as a convex combination of the $L_1$ and
+$L_2$ regularization terms:
`\[
-\alpha \|\wv\|_1 + (1-\alpha) \frac{1}{2}\|\wv\|_2^2, \alpha \in [0, 1].
+\alpha~\lambda \|\wv\|_1 + (1-\alpha) \frac{\lambda}{2}\|\wv\|_2^2, \alpha \in [0, 1], \lambda \geq 0.
\]`
-By setting $\alpha$ properly, it contains both L1 and L2 regularization as special cases. For example, if a [linear regression](https://en.wikipedia.org/wiki/Linear_regression) model is trained with the elastic net parameter $\alpha$ set to $1$, it is equivalent to a [Lasso](http://en.wikipedia.org/wiki/Least_squares#Lasso_method) model. On the other hand, if $\alpha$ is set to $0$, the trained model reduces to a [ridge regression](http://en.wikipedia.org/wiki/Tikhonov_regularization) model. We implement Pipelines API for both linear regression and logistic regression with elastic net regularization.
-
-**Examples**
+By setting $\alpha$ properly, elastic net contains both $L_1$ and $L_2$
+regularization as special cases. For example, if a [linear
+regression](https://en.wikipedia.org/wiki/Linear_regression) model is
+trained with the elastic net parameter $\alpha$ set to $1$, it is
+equivalent to a
+[Lasso](http://en.wikipedia.org/wiki/Least_squares#Lasso_method) model.
+On the other hand, if $\alpha$ is set to $0$, the trained model reduces
+to a [ridge
+regression](http://en.wikipedia.org/wiki/Tikhonov_regularization) model.
+We implement the Pipelines API for both linear regression and logistic
+regression with elastic net regularization.
+
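+As a minimal sketch of these special cases (using the `spark.ml`
+`LinearRegression` estimator; the `regParam` value here is purely
+illustrative):
+
+{% highlight scala %}
+import org.apache.spark.ml.regression.LinearRegression
+
+// alpha = 1: the penalty reduces to lambda * ||w||_1 (Lasso).
+val lasso = new LinearRegression()
+  .setRegParam(0.3)        // lambda
+  .setElasticNetParam(1.0) // alpha
+
+// alpha = 0: the penalty reduces to (lambda / 2) * ||w||_2^2 (ridge regression).
+val ridge = new LinearRegression()
+  .setRegParam(0.3)        // lambda
+  .setElasticNetParam(0.0) // alpha
+{% endhighlight %}
+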
+## Example: Logistic Regression
+
+The following example shows how to train a logistic regression model
+with elastic net regularization. `elasticNetParam` corresponds to
+$\alpha$ and `regParam` corresponds to $\lambda$.
<div class="codetabs">
<div data-lang="scala" markdown="1">
-
{% highlight scala %}
-
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.mllib.util.MLUtils
@@ -53,15 +74,11 @@ val lrModel = lr.fit(training)
// Print the weights and intercept for logistic regression
println(s"Weights: ${lrModel.weights} Intercept: ${lrModel.intercept}")
-
{% endhighlight %}
-
</div>
<div data-lang="java" markdown="1">
-
{% highlight java %}
-
import org.apache.spark.ml.classification.LogisticRegression;
import org.apache.spark.ml.classification.LogisticRegressionModel;
import org.apache.spark.mllib.regression.LabeledPoint;
@@ -99,9 +116,7 @@ public class LogisticRegressionWithElasticNetExample {
</div>
<div data-lang="python" markdown="1">
-
{% highlight python %}
-
from pyspark.ml.classification import LogisticRegression
from pyspark.mllib.regression import LabeledPoint
from pyspark.mllib.util import MLUtils
@@ -118,12 +133,114 @@ lrModel = lr.fit(training)
print("Weights: " + str(lrModel.weights))
print("Intercept: " + str(lrModel.intercept))
{% endhighlight %}
+</div>
</div>
+The `spark.ml` implementation of logistic regression also supports
+extracting a summary of the model over the training set. Note that the
+predictions and metrics, which are stored as `DataFrame`s in
+`BinaryLogisticRegressionSummary`, are annotated `@transient` and hence
+are only available on the driver.
+
+<div class="codetabs">
+
+<div data-lang="scala" markdown="1">
+
+[`LogisticRegressionTrainingSummary`](api/scala/index.html#org.apache.spark.ml.classification.LogisticRegressionTrainingSummary)
+provides a summary for a
+[`LogisticRegressionModel`](api/scala/index.html#org.apache.spark.ml.classification.LogisticRegressionModel).
+Currently, only binary classification is supported and the
+summary must be explicitly cast to
+[`BinaryLogisticRegressionTrainingSummary`](api/scala/index.html#org.apache.spark.ml.classification.BinaryLogisticRegressionTrainingSummary).
+This will likely change when multiclass classification is supported.
+
+Continuing the earlier example:
+
+{% highlight scala %}
+import org.apache.spark.ml.classification.BinaryLogisticRegressionSummary
+import org.apache.spark.sql.functions.{col, max}
+
+// Extract the summary from the LogisticRegressionModel instance trained
+// in the earlier example.
+val trainingSummary = lrModel.summary
+
+// Obtain the loss per iteration.
+val objectiveHistory = trainingSummary.objectiveHistory
+objectiveHistory.foreach(loss => println(loss))
+
+// Obtain the metrics useful to judge performance on test data.
+// We cast the summary to a BinaryLogisticRegressionSummary since the
+// problem is a binary classification problem.
+val binarySummary = trainingSummary.asInstanceOf[BinaryLogisticRegressionSummary]
+
+// Obtain the receiver operating characteristic as a DataFrame and areaUnderROC.
+val roc = binarySummary.roc
+roc.show()
+roc.select("FPR").show()
+println(binarySummary.areaUnderROC)
+
+// Get the threshold corresponding to the maximum F-Measure and rerun
+// LogisticRegression with this selected threshold.
+val fMeasure = binarySummary.fMeasureByThreshold
+val maxFMeasure = fMeasure.select(max("F-Measure")).head().getDouble(0)
+val bestThreshold = fMeasure.where(col("F-Measure") === maxFMeasure)
+  .select("threshold").head().getDouble(0)
+lr.setThreshold(bestThreshold)
+lr.fit(training)
+{% endhighlight %}
</div>
-### Optimization
+<div data-lang="java" markdown="1">
+[`LogisticRegressionTrainingSummary`](api/java/org/apache/spark/ml/classification/LogisticRegressionTrainingSummary.html)
+provides a summary for a
+[`LogisticRegressionModel`](api/java/org/apache/spark/ml/classification/LogisticRegressionModel.html).
+Currently, only binary classification is supported and the
+summary must be explicitly cast to
+[`BinaryLogisticRegressionTrainingSummary`](api/java/org/apache/spark/ml/classification/BinaryLogisticRegressionTrainingSummary.html).
+This will likely change when multiclass classification is supported.
+
+Continuing the earlier example:
+
+{% highlight java %}
+import org.apache.spark.ml.classification.BinaryLogisticRegressionSummary;
+import org.apache.spark.ml.classification.LogisticRegressionTrainingSummary;
+import org.apache.spark.sql.DataFrame;
+import static org.apache.spark.sql.functions.max;
+
+// Extract the summary from the LogisticRegressionModel instance trained
+// in the earlier example.
+LogisticRegressionTrainingSummary trainingSummary = lrModel.summary();
+
+// Obtain the loss per iteration.
+double[] objectiveHistory = trainingSummary.objectiveHistory();
+for (double lossPerIteration : objectiveHistory) {
+  System.out.println(lossPerIteration);
+}
+
+// Obtain the metrics useful to judge performance on test data.
+// We cast the summary to a BinaryLogisticRegressionSummary since the
+// problem is a binary classification problem.
+BinaryLogisticRegressionSummary binarySummary =
+  (BinaryLogisticRegressionSummary) trainingSummary;
+
+// Obtain the receiver operating characteristic as a DataFrame and areaUnderROC.
+DataFrame roc = binarySummary.roc();
+roc.show();
+roc.select("FPR").show();
+System.out.println(binarySummary.areaUnderROC());
+
+// Get the threshold corresponding to the maximum F-Measure and rerun
+// LogisticRegression with this selected threshold.
+DataFrame fMeasure = binarySummary.fMeasureByThreshold();
+double maxFMeasure = fMeasure.select(max("F-Measure")).head().getDouble(0);
+double bestThreshold = fMeasure.where(fMeasure.col("F-Measure").equalTo(maxFMeasure))
+  .select("threshold").head().getDouble(0);
+lr.setThreshold(bestThreshold);
+lr.fit(training);
+{% endhighlight %}
+</div>
+
+<div data-lang="python" markdown="1">
+Logistic regression model summary is not yet supported in Python.
+</div>
+
+</div>
+
+## Optimization
+
+The optimization algorithm underlying the implementation is called
+[Orthant-Wise Limited-memory
+Quasi-Newton](http://research-srv.microsoft.com/en-us/um/people/jfgao/paper/icml07scalable.pdf)
+(OWL-QN). It is an extension of L-BFGS that can effectively handle the
+$L_1$ penalty, which is not differentiable at zero, and therefore
+supports elastic net regularization.
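+
+Concretely, for binary logistic regression this corresponds to
+minimizing an objective of the following form (a sketch combining the
+logistic loss, with labels $y_i \in \{-1, +1\}$ over $n$ training
+instances, and the elastic net penalty introduced above):
+
+`\[
+\min_{\wv} \frac{1}{n} \sum_{i=1}^{n} \log\left(1 + \exp\left(-y_i \wv^T x_i\right)\right) + \alpha~\lambda \|\wv\|_1 + (1-\alpha) \frac{\lambda}{2}\|\wv\|_2^2.
+\]`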
-The optimization algorithm underlies the implementation is called [Orthant-Wise Limited-memory QuasiNewton](http://research-srv.microsoft.com/en-us/um/people/jfgao/paper/icml07scalable.pdf)
-(OWL-QN). It is an extension of L-BFGS that can effectively handle L1 regularization and elastic net.