aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
Diffstat (limited to 'mllib')
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala50
1 files changed, 46 insertions, 4 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala b/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala
index 5be2f86936..4d82b90bfd 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala
@@ -53,10 +53,35 @@ private[r] object SparkRWrappers {
def getModelCoefficients(model: PipelineModel): Array[Double] = {
model.stages.last match {
+ case m: LinearRegressionModel => {
+ val coefficientStandardErrorsR = Array(m.summary.coefficientStandardErrors.last) ++
+ m.summary.coefficientStandardErrors.dropRight(1)
+ val tValuesR = Array(m.summary.tValues.last) ++ m.summary.tValues.dropRight(1)
+ val pValuesR = Array(m.summary.pValues.last) ++ m.summary.pValues.dropRight(1)
+ if (m.getFitIntercept) {
+ Array(m.intercept) ++ m.coefficients.toArray ++ coefficientStandardErrorsR ++
+ tValuesR ++ pValuesR
+ } else {
+ m.coefficients.toArray ++ coefficientStandardErrorsR ++ tValuesR ++ pValuesR
+ }
+ }
+ case m: LogisticRegressionModel => {
+ if (m.getFitIntercept) {
+ Array(m.intercept) ++ m.coefficients.toArray
+ } else {
+ m.coefficients.toArray
+ }
+ }
+ }
+ }
+
+ def getModelDevianceResiduals(model: PipelineModel): Array[Double] = {
+ model.stages.last match {
case m: LinearRegressionModel =>
- Array(m.intercept) ++ m.coefficients.toArray
+ m.summary.devianceResiduals
case m: LogisticRegressionModel =>
- Array(m.intercept) ++ m.coefficients.toArray
+ throw new UnsupportedOperationException(
+ "No deviance residuals available for LogisticRegressionModel")
}
}
@@ -65,11 +90,28 @@ private[r] object SparkRWrappers {
case m: LinearRegressionModel =>
val attrs = AttributeGroup.fromStructField(
m.summary.predictions.schema(m.summary.featuresCol))
- Array("(Intercept)") ++ attrs.attributes.get.map(_.name.get)
+ if (m.getFitIntercept) {
+ Array("(Intercept)") ++ attrs.attributes.get.map(_.name.get)
+ } else {
+ attrs.attributes.get.map(_.name.get)
+ }
case m: LogisticRegressionModel =>
val attrs = AttributeGroup.fromStructField(
m.summary.predictions.schema(m.summary.featuresCol))
- Array("(Intercept)") ++ attrs.attributes.get.map(_.name.get)
+ if (m.getFitIntercept) {
+ Array("(Intercept)") ++ attrs.attributes.get.map(_.name.get)
+ } else {
+ attrs.attributes.get.map(_.name.get)
+ }
+ }
+ }
+
+ def getModelName(model: PipelineModel): String = {
+ model.stages.last match {
+ case m: LinearRegressionModel =>
+ "LinearRegressionModel"
+ case m: LogisticRegressionModel =>
+ "LogisticRegressionModel"
}
}
}