author     Yanbo Liang <ybliang8@gmail.com>       2016-04-12 10:51:07 -0700
committer  Xiangrui Meng <meng@databricks.com>    2016-04-12 10:51:09 -0700
commit     75e05a5a964c9585dd09a2ef6178881929bab1f1 (patch)
tree       2519cff0d3117b50b459f48a0e60601daea8257a /mllib/src
parent     6bf692147c21dd74e91e2bd95845f11ef0a303e6 (diff)
[SPARK-12566][SPARK-14324][ML] GLM model family, link function support in SparkR:::glm
* SparkR glm supports families and link functions that match R's signature for `family`.
* SparkR glm API refactor. The new API is modeled on R's glm, so it exposes only the arguments that R glm supports: ```formula, family, data, epsilon and maxit```.
* This PR focuses on glm() and predict(); summary statistics will be handled in a separate PR after this one gets in.
* This PR depends on #12287, which makes GLMs support link prediction on the Scala side. After that is merged, more tests for predict() will be added to this PR.

Unit tests.

cc mengxr jkbradley hhbyyh

Author: Yanbo Liang <ybliang8@gmail.com>

Closes #12294 from yanboliang/spark-12566.
Diffstat (limited to 'mllib/src')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala   79
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala                      115
2 files changed, 79 insertions, 115 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala
new file mode 100644
index 0000000000..475a308385
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala
@@ -0,0 +1,79 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.r
+
+import org.apache.spark.ml.{Pipeline, PipelineModel}
+import org.apache.spark.ml.attribute.AttributeGroup
+import org.apache.spark.ml.feature.RFormula
+import org.apache.spark.ml.regression._
+import org.apache.spark.sql._
+
+private[r] class GeneralizedLinearRegressionWrapper private (
+ pipeline: PipelineModel,
+ val features: Array[String]) {
+
+ private val glm: GeneralizedLinearRegressionModel =
+ pipeline.stages(1).asInstanceOf[GeneralizedLinearRegressionModel]
+
+ lazy val rCoefficients: Array[Double] = if (glm.getFitIntercept) {
+ Array(glm.intercept) ++ glm.coefficients.toArray
+ } else {
+ glm.coefficients.toArray
+ }
+
+ lazy val rFeatures: Array[String] = if (glm.getFitIntercept) {
+ Array("(Intercept)") ++ features
+ } else {
+ features
+ }
+
+ def transform(dataset: DataFrame): DataFrame = {
+ pipeline.transform(dataset).drop(glm.getFeaturesCol)
+ }
+}
+
+private[r] object GeneralizedLinearRegressionWrapper {
+
+ def fit(
+ formula: String,
+ data: DataFrame,
+ family: String,
+ link: String,
+ epsilon: Double,
+ maxit: Int): GeneralizedLinearRegressionWrapper = {
+ val rFormula = new RFormula()
+ .setFormula(formula)
+ val rFormulaModel = rFormula.fit(data)
+ // get labels and feature names from output schema
+ val schema = rFormulaModel.transform(data).schema
+ val featureAttrs = AttributeGroup.fromStructField(schema(rFormula.getFeaturesCol))
+ .attributes.get
+ val features = featureAttrs.map(_.name.get)
+ // assemble and fit the pipeline
+ val glm = new GeneralizedLinearRegression()
+ .setFamily(family)
+ .setLink(link)
+ .setFitIntercept(rFormula.hasIntercept)
+ .setTol(epsilon)
+ .setMaxIter(maxit)
+ val pipeline = new Pipeline()
+ .setStages(Array(rFormulaModel, glm))
+ .fit(data)
+ new GeneralizedLinearRegressionWrapper(pipeline, features)
+ }
+}
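
For reference, a minimal sketch of how the new wrapper might be driven from the Scala side. This snippet is illustrative and not part of the patch: the object name `GlmWrapperExample`, the `training` DataFrame with columns `y` and `x`, and the argument values are all assumptions, and since the wrapper is `private[r]` the sketch assumes it is compiled inside the `org.apache.spark.ml.r` package.

```scala
package org.apache.spark.ml.r

import org.apache.spark.sql.DataFrame

object GlmWrapperExample {
  // `training` is an assumed DataFrame with a numeric label column "y" and a feature column "x".
  def run(training: DataFrame): DataFrame = {
    // Fit with the same arguments the R-side glm() exposes: formula, family/link, epsilon, maxit.
    val wrapper = GeneralizedLinearRegressionWrapper.fit(
      formula = "y ~ x",
      data = training,
      family = "gaussian",   // error distribution, matching R's family argument
      link = "identity",     // link function for the chosen family
      epsilon = 1e-6,        // convergence tolerance, mapped to setTol
      maxit = 25)            // iteration limit, mapped to setMaxIter

    // Coefficients are reported R-style, with "(Intercept)" first when an intercept is fit.
    wrapper.rFeatures.zip(wrapper.rCoefficients).foreach { case (name, value) =>
      println(s"$name\t$value")
    }

    // transform() runs the whole pipeline (RFormula + GLM) and drops the assembled features column.
    wrapper.transform(training)
  }
}
```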
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala b/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala
deleted file mode 100644
index fa143715be..0000000000
--- a/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala
+++ /dev/null
@@ -1,115 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.ml.api.r
-
-import org.apache.spark.ml.{Pipeline, PipelineModel}
-import org.apache.spark.ml.attribute._
-import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressionModel}
-import org.apache.spark.ml.feature.RFormula
-import org.apache.spark.ml.regression.{LinearRegression, LinearRegressionModel}
-import org.apache.spark.sql.DataFrame
-
-private[r] object SparkRWrappers {
- def fitRModelFormula(
- value: String,
- df: DataFrame,
- family: String,
- lambda: Double,
- alpha: Double,
- standardize: Boolean,
- solver: String): PipelineModel = {
- val formula = new RFormula().setFormula(value)
- val estimator = family match {
- case "gaussian" => new LinearRegression()
- .setRegParam(lambda)
- .setElasticNetParam(alpha)
- .setFitIntercept(formula.hasIntercept)
- .setStandardization(standardize)
- .setSolver(solver)
- case "binomial" => new LogisticRegression()
- .setRegParam(lambda)
- .setElasticNetParam(alpha)
- .setFitIntercept(formula.hasIntercept)
- .setStandardization(standardize)
- }
- val pipeline = new Pipeline().setStages(Array(formula, estimator))
- pipeline.fit(df)
- }
-
- def getModelCoefficients(model: PipelineModel): Array[Double] = {
- model.stages.last match {
- case m: LinearRegressionModel =>
- val coefficientStandardErrorsR = Array(m.summary.coefficientStandardErrors.last) ++
- m.summary.coefficientStandardErrors.dropRight(1)
- val tValuesR = Array(m.summary.tValues.last) ++ m.summary.tValues.dropRight(1)
- val pValuesR = Array(m.summary.pValues.last) ++ m.summary.pValues.dropRight(1)
- if (m.getFitIntercept) {
- Array(m.intercept) ++ m.coefficients.toArray ++ coefficientStandardErrorsR ++
- tValuesR ++ pValuesR
- } else {
- m.coefficients.toArray ++ coefficientStandardErrorsR ++ tValuesR ++ pValuesR
- }
- case m: LogisticRegressionModel =>
- if (m.getFitIntercept) {
- Array(m.intercept) ++ m.coefficients.toArray
- } else {
- m.coefficients.toArray
- }
- }
- }
-
- def getModelDevianceResiduals(model: PipelineModel): Array[Double] = {
- model.stages.last match {
- case m: LinearRegressionModel =>
- m.summary.devianceResiduals
- case m: LogisticRegressionModel =>
- throw new UnsupportedOperationException(
- "No deviance residuals available for LogisticRegressionModel")
- }
- }
-
- def getModelFeatures(model: PipelineModel): Array[String] = {
- model.stages.last match {
- case m: LinearRegressionModel =>
- val attrs = AttributeGroup.fromStructField(
- m.summary.predictions.schema(m.summary.featuresCol))
- if (m.getFitIntercept) {
- Array("(Intercept)") ++ attrs.attributes.get.map(_.name.get)
- } else {
- attrs.attributes.get.map(_.name.get)
- }
- case m: LogisticRegressionModel =>
- val attrs = AttributeGroup.fromStructField(
- m.summary.predictions.schema(m.summary.featuresCol))
- if (m.getFitIntercept) {
- Array("(Intercept)") ++ attrs.attributes.get.map(_.name.get)
- } else {
- attrs.attributes.get.map(_.name.get)
- }
- }
- }
-
- def getModelName(model: PipelineModel): String = {
- model.stages.last match {
- case m: LinearRegressionModel =>
- "LinearRegressionModel"
- case m: LogisticRegressionModel =>
- "LogisticRegressionModel"
- }
- }
-}