diff options
author | Yanbo Liang <ybliang8@gmail.com> | 2015-11-04 08:28:33 -0800 |
---|---|---|
committer | Xiangrui Meng <meng@databricks.com> | 2015-11-04 08:28:33 -0800 |
commit | e328b69c31821e4b27673d7ef6182ab3b7a05ca8 (patch) | |
tree | 7bd0416235fa72fca7097accb6c0a7a6019f80e5 /mllib | |
parent | c09e5139874fb3626e005c8240cca5308b902ef3 (diff) | |
download | spark-e328b69c31821e4b27673d7ef6182ab3b7a05ca8.tar.gz spark-e328b69c31821e4b27673d7ef6182ab3b7a05ca8.tar.bz2 spark-e328b69c31821e4b27673d7ef6182ab3b7a05ca8.zip |
[SPARK-9492][ML][R] LogisticRegression in R should provide model statistics
Like ML `LinearRegression`, `LogisticRegression` should provide a training summary including feature names and their coefficients.
Author: Yanbo Liang <ybliang8@gmail.com>
Closes #9303 from yanboliang/spark-9492.
Diffstat (limited to 'mllib')
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala | 17 | ||||
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala | 7 |
2 files changed, 17 insertions, 7 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index a1335e7a1b..f5fca686df 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -378,6 +378,7 @@ class LogisticRegression(override val uid: String) model.transform(dataset), $(probabilityCol), $(labelCol), + $(featuresCol), objectiveHistory) model.setSummary(logRegSummary) } @@ -452,7 +453,8 @@ class LogisticRegressionModel private[ml] ( */ // TODO: decide on a good name before exposing to public API private[classification] def evaluate(dataset: DataFrame): LogisticRegressionSummary = { - new BinaryLogisticRegressionSummary(this.transform(dataset), $(probabilityCol), $(labelCol)) + new BinaryLogisticRegressionSummary( + this.transform(dataset), $(probabilityCol), $(labelCol), $(featuresCol)) } /** @@ -614,9 +616,12 @@ sealed trait LogisticRegressionSummary extends Serializable { /** Field in "predictions" which gives the calibrated probability of each instance as a vector. */ def probabilityCol: String - /** Field in "predictions" which gives the the true label of each instance. */ + /** Field in "predictions" which gives the true label of each instance. */ def labelCol: String + /** Field in "predictions" which gives the features of each instance as a vector. */ + def featuresCol: String + } /** @@ -626,6 +631,7 @@ sealed trait LogisticRegressionSummary extends Serializable { * @param probabilityCol field in "predictions" which gives the calibrated probability of * each instance as a vector. * @param labelCol field in "predictions" which gives the true label of each instance. + * @param featuresCol field in "predictions" which gives the features of each instance as a vector. 
* @param objectiveHistory objective function (scaled loss + regularization) at each iteration. */ @Experimental @@ -633,8 +639,9 @@ class BinaryLogisticRegressionTrainingSummary private[classification] ( predictions: DataFrame, probabilityCol: String, labelCol: String, + featuresCol: String, val objectiveHistory: Array[Double]) - extends BinaryLogisticRegressionSummary(predictions, probabilityCol, labelCol) + extends BinaryLogisticRegressionSummary(predictions, probabilityCol, labelCol, featuresCol) with LogisticRegressionTrainingSummary { } @@ -646,12 +653,14 @@ class BinaryLogisticRegressionTrainingSummary private[classification] ( * @param probabilityCol field in "predictions" which gives the calibrated probability of * each instance. * @param labelCol field in "predictions" which gives the true label of each instance. + * @param featuresCol field in "predictions" which gives the features of each instance as a vector. */ @Experimental class BinaryLogisticRegressionSummary private[classification] ( @transient override val predictions: DataFrame, override val probabilityCol: String, - override val labelCol: String) extends LogisticRegressionSummary { + override val labelCol: String, + override val featuresCol: String) extends LogisticRegressionSummary { private val sqlContext = predictions.sqlContext import sqlContext.implicits._ diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala b/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala index 24f76de806..5be2f86936 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala @@ -66,9 +66,10 @@ private[r] object SparkRWrappers { val attrs = AttributeGroup.fromStructField( m.summary.predictions.schema(m.summary.featuresCol)) Array("(Intercept)") ++ attrs.attributes.get.map(_.name.get) - case _: LogisticRegressionModel => - throw new UnsupportedOperationException( - "No features names available for 
LogisticRegressionModel") // SPARK-9492 + case m: LogisticRegressionModel => + val attrs = AttributeGroup.fromStructField( + m.summary.predictions.schema(m.summary.featuresCol)) + Array("(Intercept)") ++ attrs.attributes.get.map(_.name.get) } } } |