From e328b69c31821e4b27673d7ef6182ab3b7a05ca8 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Wed, 4 Nov 2015 08:28:33 -0800 Subject: [SPARK-9492][ML][R] LogisticRegression in R should provide model statistics Like ml ```LinearRegression```, ```LogisticRegression``` should provide a training summary including feature names and their coefficients. Author: Yanbo Liang Closes #9303 from yanboliang/spark-9492. --- R/pkg/inst/tests/test_mllib.R | 17 +++++++++++++++++ .../spark/ml/classification/LogisticRegression.scala | 17 +++++++++++++---- .../scala/org/apache/spark/ml/r/SparkRWrappers.scala | 7 ++++--- project/MimaExcludes.scala | 4 +++- 4 files changed, 37 insertions(+), 8 deletions(-) diff --git a/R/pkg/inst/tests/test_mllib.R b/R/pkg/inst/tests/test_mllib.R index 3331ce7383..032cfef061 100644 --- a/R/pkg/inst/tests/test_mllib.R +++ b/R/pkg/inst/tests/test_mllib.R @@ -67,3 +67,20 @@ test_that("summary coefficients match with native glm", { as.character(stats$features) == c("(Intercept)", "Sepal_Length", "Species_versicolor", "Species_virginica"))) }) + +test_that("summary coefficients match with native glm of family 'binomial'", { + df <- createDataFrame(sqlContext, iris) + training <- filter(df, df$Species != "setosa") + stats <- summary(glm(Species ~ Sepal_Length + Sepal_Width, data = training, + family = "binomial")) + coefs <- as.vector(stats$coefficients) + + rTraining <- iris[iris$Species %in% c("versicolor","virginica"),] + rCoefs <- as.vector(coef(glm(Species ~ Sepal.Length + Sepal.Width, data = rTraining, + family = binomial(link = "logit")))) + + expect_true(all(abs(rCoefs - coefs) < 1e-4)) + expect_true(all( + as.character(stats$features) == + c("(Intercept)", "Sepal_Length", "Sepal_Width"))) +}) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index a1335e7a1b..f5fca686df 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -378,6 +378,7 @@ class LogisticRegression(override val uid: String) model.transform(dataset), $(probabilityCol), $(labelCol), + $(featuresCol), objectiveHistory) model.setSummary(logRegSummary) } @@ -452,7 +453,8 @@ class LogisticRegressionModel private[ml] ( */ // TODO: decide on a good name before exposing to public API private[classification] def evaluate(dataset: DataFrame): LogisticRegressionSummary = { - new BinaryLogisticRegressionSummary(this.transform(dataset), $(probabilityCol), $(labelCol)) + new BinaryLogisticRegressionSummary( + this.transform(dataset), $(probabilityCol), $(labelCol), $(featuresCol)) } /** @@ -614,9 +616,12 @@ sealed trait LogisticRegressionSummary extends Serializable { /** Field in "predictions" which gives the calibrated probability of each instance as a vector. */ def probabilityCol: String - /** Field in "predictions" which gives the the true label of each instance. */ + /** Field in "predictions" which gives the true label of each instance. */ def labelCol: String + /** Field in "predictions" which gives the features of each instance as a vector. */ + def featuresCol: String + } /** @@ -626,6 +631,7 @@ sealed trait LogisticRegressionSummary extends Serializable { * @param probabilityCol field in "predictions" which gives the calibrated probability of * each instance as a vector. * @param labelCol field in "predictions" which gives the true label of each instance. + * @param featuresCol field in "predictions" which gives the features of each instance as a vector. * @param objectiveHistory objective function (scaled loss + regularization) at each iteration. */ @Experimental @@ -633,8 +639,9 @@ class BinaryLogisticRegressionTrainingSummary private[classification] ( predictions: DataFrame, probabilityCol: String, labelCol: String, + featuresCol: String, val objectiveHistory: Array[Double]) - extends BinaryLogisticRegressionSummary(predictions, probabilityCol, labelCol) + extends BinaryLogisticRegressionSummary(predictions, probabilityCol, labelCol, featuresCol) with LogisticRegressionTrainingSummary { } @@ -646,12 +653,14 @@ class BinaryLogisticRegressionTrainingSummary private[classification] ( * @param probabilityCol field in "predictions" which gives the calibrated probability of * each instance. * @param labelCol field in "predictions" which gives the true label of each instance. + * @param featuresCol field in "predictions" which gives the features of each instance as a vector. */ @Experimental class BinaryLogisticRegressionSummary private[classification] ( @transient override val predictions: DataFrame, override val probabilityCol: String, - override val labelCol: String) extends LogisticRegressionSummary { + override val labelCol: String, + override val featuresCol: String) extends LogisticRegressionSummary { private val sqlContext = predictions.sqlContext import sqlContext.implicits._ diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala b/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala index 24f76de806..5be2f86936 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala @@ -66,9 +66,10 @@ private[r] object SparkRWrappers { val attrs = AttributeGroup.fromStructField( m.summary.predictions.schema(m.summary.featuresCol)) Array("(Intercept)") ++ attrs.attributes.get.map(_.name.get) - case _: LogisticRegressionModel => - throw new UnsupportedOperationException( - "No features names available for LogisticRegressionModel") // SPARK-9492 + case m: LogisticRegressionModel => + val attrs = AttributeGroup.fromStructField( + m.summary.predictions.schema(m.summary.featuresCol)) + Array("(Intercept)") ++ attrs.attributes.get.map(_.name.get) } } } diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index ec0e44b7f2..eeef96c378 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -59,7 +59,9 @@ object MimaExcludes { ProblemFilters.exclude[MissingMethodProblem]( "org.apache.spark.ml.classification.LogisticAggregator.add"), ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.ml.classification.LogisticAggregator.count") + "org.apache.spark.ml.classification.LogisticAggregator.count"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.ml.classification.LogisticRegressionSummary.featuresCol") ) ++ Seq( // SPARK-10381 Fix types / units in private AskPermissionToCommitOutput RPC message. // This class is marked as `private` but MiMa still seems to be confused by the change. -- cgit v1.2.3