about summary refs log tree commit diff
path: root/mllib/src/main/scala/org
diff options
context:
space:
mode:
authorYanbo Liang <ybliang8@gmail.com>2015-11-09 08:56:22 -0800
committerXiangrui Meng <meng@databricks.com>2015-11-09 08:56:22 -0800
commit8c0e1b50e960d3e8e51d0618c462eed2bb4936f0 (patch)
tree467a738e59a86e39c1f59f00b1c0bbfffba55e1c /mllib/src/main/scala/org
parentb541b31630b1b85b48d6096079d073ccf46a62e8 (diff)
downloadspark-8c0e1b50e960d3e8e51d0618c462eed2bb4936f0.tar.gz
spark-8c0e1b50e960d3e8e51d0618c462eed2bb4936f0.tar.bz2
spark-8c0e1b50e960d3e8e51d0618c462eed2bb4936f0.zip
[SPARK-11494][ML][R] Expose R-like summary statistics in SparkR::glm for linear regression
Expose R-like summary statistics in SparkR::glm for linear regression, the output of ```summary``` like ```Java $DevianceResiduals Min Max -0.9509607 0.7291832 $Coefficients Estimate Std. Error t value Pr(>|t|) (Intercept) 1.6765 0.2353597 7.123139 4.456124e-11 Sepal_Length 0.3498801 0.04630128 7.556598 4.187317e-12 Species_versicolor -0.9833885 0.07207471 -13.64402 0 Species_virginica -1.00751 0.09330565 -10.79796 0 ``` Author: Yanbo Liang <ybliang8@gmail.com> Closes #9561 from yanboliang/spark-11494.
Diffstat (limited to 'mllib/src/main/scala/org')
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala50
1 file changed, 46 insertions(+), 4 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala b/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala
index 5be2f86936..4d82b90bfd 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala
@@ -53,10 +53,35 @@ private[r] object SparkRWrappers {
def getModelCoefficients(model: PipelineModel): Array[Double] = {
model.stages.last match {
+ case m: LinearRegressionModel => {
+ val coefficientStandardErrorsR = Array(m.summary.coefficientStandardErrors.last) ++
+ m.summary.coefficientStandardErrors.dropRight(1)
+ val tValuesR = Array(m.summary.tValues.last) ++ m.summary.tValues.dropRight(1)
+ val pValuesR = Array(m.summary.pValues.last) ++ m.summary.pValues.dropRight(1)
+ if (m.getFitIntercept) {
+ Array(m.intercept) ++ m.coefficients.toArray ++ coefficientStandardErrorsR ++
+ tValuesR ++ pValuesR
+ } else {
+ m.coefficients.toArray ++ coefficientStandardErrorsR ++ tValuesR ++ pValuesR
+ }
+ }
+ case m: LogisticRegressionModel => {
+ if (m.getFitIntercept) {
+ Array(m.intercept) ++ m.coefficients.toArray
+ } else {
+ m.coefficients.toArray
+ }
+ }
+ }
+ }
+
+ def getModelDevianceResiduals(model: PipelineModel): Array[Double] = {
+ model.stages.last match {
case m: LinearRegressionModel =>
- Array(m.intercept) ++ m.coefficients.toArray
+ m.summary.devianceResiduals
case m: LogisticRegressionModel =>
- Array(m.intercept) ++ m.coefficients.toArray
+ throw new UnsupportedOperationException(
+ "No deviance residuals available for LogisticRegressionModel")
}
}
@@ -65,11 +90,28 @@ private[r] object SparkRWrappers {
case m: LinearRegressionModel =>
val attrs = AttributeGroup.fromStructField(
m.summary.predictions.schema(m.summary.featuresCol))
- Array("(Intercept)") ++ attrs.attributes.get.map(_.name.get)
+ if (m.getFitIntercept) {
+ Array("(Intercept)") ++ attrs.attributes.get.map(_.name.get)
+ } else {
+ attrs.attributes.get.map(_.name.get)
+ }
case m: LogisticRegressionModel =>
val attrs = AttributeGroup.fromStructField(
m.summary.predictions.schema(m.summary.featuresCol))
- Array("(Intercept)") ++ attrs.attributes.get.map(_.name.get)
+ if (m.getFitIntercept) {
+ Array("(Intercept)") ++ attrs.attributes.get.map(_.name.get)
+ } else {
+ attrs.attributes.get.map(_.name.get)
+ }
+ }
+ }
+
+ def getModelName(model: PipelineModel): String = {
+ model.stages.last match {
+ case m: LinearRegressionModel =>
+ "LinearRegressionModel"
+ case m: LogisticRegressionModel =>
+ "LogisticRegressionModel"
}
}
}