author     Yanbo Liang <ybliang8@gmail.com>      2015-11-04 08:28:33 -0800
committer  Xiangrui Meng <meng@databricks.com>   2015-11-04 08:28:33 -0800
commit     e328b69c31821e4b27673d7ef6182ab3b7a05ca8 (patch)
tree       7bd0416235fa72fca7097accb6c0a7a6019f80e5
parent     c09e5139874fb3626e005c8240cca5308b902ef3 (diff)
[SPARK-9492][ML][R] LogisticRegression in R should provide model statistics
Like ml ```LinearRegression```, ```LogisticRegression``` should provide a training summary including feature names and their coefficients.

Author: Yanbo Liang <ybliang8@gmail.com>

Closes #9303 from yanboliang/spark-9492.
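The patch threads the features column name through the binary summary, so feature names can be recovered from a fitted model's prediction schema instead of throwing `UnsupportedOperationException` as `SparkRWrappers` did before. A minimal sketch of what this enables on the Scala side, assuming a DataFrame `training` whose vector "features" column carries ML attribute metadata (as RFormula writes during fitting) and a double "label" column; both names are illustrative:

```scala
import org.apache.spark.ml.attribute.AttributeGroup
import org.apache.spark.ml.classification.LogisticRegression

// Fit a binary model; `training` is an assumed DataFrame with a
// vector "features" column and a double "label" column.
val model = new LogisticRegression().fit(training)

// The training summary now records which column held the features,
// so the ML attribute metadata on that column yields feature names.
val summary = model.summary
val attrs = AttributeGroup.fromStructField(
  summary.predictions.schema(summary.featuresCol))
val names = Array("(Intercept)") ++ attrs.attributes.get.map(_.name.get)
```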
-rw-r--r--  R/pkg/inst/tests/test_mllib.R                                                      17
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala  17
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala                    7
-rw-r--r--  project/MimaExcludes.scala                                                          4
4 files changed, 37 insertions(+), 8 deletions(-)
diff --git a/R/pkg/inst/tests/test_mllib.R b/R/pkg/inst/tests/test_mllib.R
index 3331ce7383..032cfef061 100644
--- a/R/pkg/inst/tests/test_mllib.R
+++ b/R/pkg/inst/tests/test_mllib.R
@@ -67,3 +67,20 @@ test_that("summary coefficients match with native glm", {
as.character(stats$features) ==
c("(Intercept)", "Sepal_Length", "Species_versicolor", "Species_virginica")))
})
+
+test_that("summary coefficients match with native glm of family 'binomial'", {
+ df <- createDataFrame(sqlContext, iris)
+ training <- filter(df, df$Species != "setosa")
+ stats <- summary(glm(Species ~ Sepal_Length + Sepal_Width, data = training,
+ family = "binomial"))
+ coefs <- as.vector(stats$coefficients)
+
+ rTraining <- iris[iris$Species %in% c("versicolor","virginica"),]
+ rCoefs <- as.vector(coef(glm(Species ~ Sepal.Length + Sepal.Width, data = rTraining,
+ family = binomial(link = "logit"))))
+
+ expect_true(all(abs(rCoefs - coefs) < 1e-4))
+ expect_true(all(
+ as.character(stats$features) ==
+ c("(Intercept)", "Sepal_Length", "Sepal_Width")))
+})
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
index a1335e7a1b..f5fca686df 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
@@ -378,6 +378,7 @@ class LogisticRegression(override val uid: String)
model.transform(dataset),
$(probabilityCol),
$(labelCol),
+ $(featuresCol),
objectiveHistory)
model.setSummary(logRegSummary)
}
@@ -452,7 +453,8 @@ class LogisticRegressionModel private[ml] (
*/
// TODO: decide on a good name before exposing to public API
private[classification] def evaluate(dataset: DataFrame): LogisticRegressionSummary = {
- new BinaryLogisticRegressionSummary(this.transform(dataset), $(probabilityCol), $(labelCol))
+ new BinaryLogisticRegressionSummary(
+ this.transform(dataset), $(probabilityCol), $(labelCol), $(featuresCol))
}
/**
@@ -614,9 +616,12 @@ sealed trait LogisticRegressionSummary extends Serializable {
/** Field in "predictions" which gives the calibrated probability of each instance as a vector. */
def probabilityCol: String
- /** Field in "predictions" which gives the the true label of each instance. */
+ /** Field in "predictions" which gives the true label of each instance. */
def labelCol: String
+ /** Field in "predictions" which gives the features of each instance as a vector. */
+ def featuresCol: String
+
}
/**
@@ -626,6 +631,7 @@ sealed trait LogisticRegressionSummary extends Serializable {
* @param probabilityCol field in "predictions" which gives the calibrated probability of
* each instance as a vector.
* @param labelCol field in "predictions" which gives the true label of each instance.
+ * @param featuresCol field in "predictions" which gives the features of each instance as a vector.
* @param objectiveHistory objective function (scaled loss + regularization) at each iteration.
*/
@Experimental
@@ -633,8 +639,9 @@ class BinaryLogisticRegressionTrainingSummary private[classification] (
predictions: DataFrame,
probabilityCol: String,
labelCol: String,
+ featuresCol: String,
val objectiveHistory: Array[Double])
- extends BinaryLogisticRegressionSummary(predictions, probabilityCol, labelCol)
+ extends BinaryLogisticRegressionSummary(predictions, probabilityCol, labelCol, featuresCol)
with LogisticRegressionTrainingSummary {
}
@@ -646,12 +653,14 @@ class BinaryLogisticRegressionTrainingSummary private[classification] (
* @param probabilityCol field in "predictions" which gives the calibrated probability of
* each instance.
* @param labelCol field in "predictions" which gives the true label of each instance.
+ * @param featuresCol field in "predictions" which gives the features of each instance as a vector.
*/
@Experimental
class BinaryLogisticRegressionSummary private[classification] (
@transient override val predictions: DataFrame,
override val probabilityCol: String,
- override val labelCol: String) extends LogisticRegressionSummary {
+ override val labelCol: String,
+ override val featuresCol: String) extends LogisticRegressionSummary {
private val sqlContext = predictions.sqlContext
import sqlContext.implicits._
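Because `featuresCol` now sits alongside `probabilityCol` and `labelCol` on the sealed trait, a summary is self-describing: downstream code can select its own prediction columns without consulting the estimator's Params. A short sketch, reusing the assumed `model` from the example above:

```scala
// Every column name travels with the summary itself.
val s = model.summary
s.predictions
  .select(s.featuresCol, s.probabilityCol, s.labelCol)
  .show(5)

// Training summaries additionally expose the per-iteration objective.
s.objectiveHistory.zipWithIndex.foreach { case (loss, i) =>
  println(s"iteration $i: objective = $loss")
}
```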
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala b/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala
index 24f76de806..5be2f86936 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala
@@ -66,9 +66,10 @@ private[r] object SparkRWrappers {
val attrs = AttributeGroup.fromStructField(
m.summary.predictions.schema(m.summary.featuresCol))
Array("(Intercept)") ++ attrs.attributes.get.map(_.name.get)
- case _: LogisticRegressionModel =>
- throw new UnsupportedOperationException(
- "No features names available for LogisticRegressionModel") // SPARK-9492
+ case m: LogisticRegressionModel =>
+ val attrs = AttributeGroup.fromStructField(
+ m.summary.predictions.schema(m.summary.featuresCol))
+ Array("(Intercept)") ++ attrs.attributes.get.map(_.name.get)
}
}
}
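The wrapper's lookup assumes the features column carries ML attribute metadata (RFormula records it while fitting), so `attributes.get` would throw on a column without attributes. A defensive variant of the same lookup, written as a hypothetical helper that is not part of the patch:

```scala
import org.apache.spark.ml.attribute.AttributeGroup
import org.apache.spark.sql.types.StructField

// Hypothetical helper: read feature names from a vector column's ML
// attribute metadata, falling back to positional names ("V0", "V1", ...)
// when the pipeline recorded no attributes.
def featureNames(field: StructField): Array[String] = {
  val attrs = AttributeGroup.fromStructField(field)
  attrs.attributes match {
    case Some(as) =>
      as.zipWithIndex.map { case (attr, i) => attr.name.getOrElse(s"V$i") }
    case None =>
      Array.tabulate(math.max(attrs.size, 0))(i => s"V$i")
  }
}
```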
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index ec0e44b7f2..eeef96c378 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -59,7 +59,9 @@ object MimaExcludes {
ProblemFilters.exclude[MissingMethodProblem](
"org.apache.spark.ml.classification.LogisticAggregator.add"),
ProblemFilters.exclude[MissingMethodProblem](
- "org.apache.spark.ml.classification.LogisticAggregator.count")
+ "org.apache.spark.ml.classification.LogisticAggregator.count"),
+ ProblemFilters.exclude[MissingMethodProblem](
+ "org.apache.spark.ml.classification.LogisticRegressionSummary.featuresCol")
) ++ Seq(
// SPARK-10381 Fix types / units in private AskPermissionToCommitOutput RPC message.
// This class is marked as `private` but MiMa still seems to be confused by the change.