aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
authorYanbo Liang <ybliang8@gmail.com>2015-11-04 08:28:33 -0800
committerXiangrui Meng <meng@databricks.com>2015-11-04 08:28:33 -0800
commite328b69c31821e4b27673d7ef6182ab3b7a05ca8 (patch)
tree7bd0416235fa72fca7097accb6c0a7a6019f80e5 /mllib
parentc09e5139874fb3626e005c8240cca5308b902ef3 (diff)
downloadspark-e328b69c31821e4b27673d7ef6182ab3b7a05ca8.tar.gz
spark-e328b69c31821e4b27673d7ef6182ab3b7a05ca8.tar.bz2
spark-e328b69c31821e4b27673d7ef6182ab3b7a05ca8.zip
[SPARK-9492][ML][R] LogisticRegression in R should provide model statistics
Like ml ```LinearRegression```, ```LogisticRegression``` should provide a training summary including feature names and their coefficients. Author: Yanbo Liang <ybliang8@gmail.com> Closes #9303 from yanboliang/spark-9492.
Diffstat (limited to 'mllib')
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala17
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala7
2 files changed, 17 insertions, 7 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
index a1335e7a1b..f5fca686df 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
@@ -378,6 +378,7 @@ class LogisticRegression(override val uid: String)
model.transform(dataset),
$(probabilityCol),
$(labelCol),
+ $(featuresCol),
objectiveHistory)
model.setSummary(logRegSummary)
}
@@ -452,7 +453,8 @@ class LogisticRegressionModel private[ml] (
*/
// TODO: decide on a good name before exposing to public API
private[classification] def evaluate(dataset: DataFrame): LogisticRegressionSummary = {
- new BinaryLogisticRegressionSummary(this.transform(dataset), $(probabilityCol), $(labelCol))
+ new BinaryLogisticRegressionSummary(
+ this.transform(dataset), $(probabilityCol), $(labelCol), $(featuresCol))
}
/**
@@ -614,9 +616,12 @@ sealed trait LogisticRegressionSummary extends Serializable {
/** Field in "predictions" which gives the calibrated probability of each instance as a vector. */
def probabilityCol: String
- /** Field in "predictions" which gives the the true label of each instance. */
+ /** Field in "predictions" which gives the true label of each instance. */
def labelCol: String
+ /** Field in "predictions" which gives the features of each instance as a vector. */
+ def featuresCol: String
+
}
/**
@@ -626,6 +631,7 @@ sealed trait LogisticRegressionSummary extends Serializable {
* @param probabilityCol field in "predictions" which gives the calibrated probability of
* each instance as a vector.
* @param labelCol field in "predictions" which gives the true label of each instance.
+ * @param featuresCol field in "predictions" which gives the features of each instance as a vector.
* @param objectiveHistory objective function (scaled loss + regularization) at each iteration.
*/
@Experimental
@@ -633,8 +639,9 @@ class BinaryLogisticRegressionTrainingSummary private[classification] (
predictions: DataFrame,
probabilityCol: String,
labelCol: String,
+ featuresCol: String,
val objectiveHistory: Array[Double])
- extends BinaryLogisticRegressionSummary(predictions, probabilityCol, labelCol)
+ extends BinaryLogisticRegressionSummary(predictions, probabilityCol, labelCol, featuresCol)
with LogisticRegressionTrainingSummary {
}
@@ -646,12 +653,14 @@ class BinaryLogisticRegressionTrainingSummary private[classification] (
* @param probabilityCol field in "predictions" which gives the calibrated probability of
* each instance.
* @param labelCol field in "predictions" which gives the true label of each instance.
+ * @param featuresCol field in "predictions" which gives the features of each instance as a vector.
*/
@Experimental
class BinaryLogisticRegressionSummary private[classification] (
@transient override val predictions: DataFrame,
override val probabilityCol: String,
- override val labelCol: String) extends LogisticRegressionSummary {
+ override val labelCol: String,
+ override val featuresCol: String) extends LogisticRegressionSummary {
private val sqlContext = predictions.sqlContext
import sqlContext.implicits._
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala b/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala
index 24f76de806..5be2f86936 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala
@@ -66,9 +66,10 @@ private[r] object SparkRWrappers {
val attrs = AttributeGroup.fromStructField(
m.summary.predictions.schema(m.summary.featuresCol))
Array("(Intercept)") ++ attrs.attributes.get.map(_.name.get)
- case _: LogisticRegressionModel =>
- throw new UnsupportedOperationException(
- "No features names available for LogisticRegressionModel") // SPARK-9492
+ case m: LogisticRegressionModel =>
+ val attrs = AttributeGroup.fromStructField(
+ m.summary.predictions.schema(m.summary.featuresCol))
+ Array("(Intercept)") ++ attrs.attributes.get.map(_.name.get)
}
}
}