From 75e05a5a964c9585dd09a2ef6178881929bab1f1 Mon Sep 17 00:00:00 2001
From: Yanbo Liang
Date: Tue, 12 Apr 2016 10:51:07 -0700
Subject: [SPARK-12566][SPARK-14324][ML] GLM model family, link function support in SparkR:::glm

* SparkR glm supports families and link functions that match R's signature for family.
* SparkR glm API refactor. The reference for the new API is R's glm, so I only expose the arguments that R's glm supports: ```formula, family, data, epsilon and maxit```.
* This PR focuses on glm() and predict(); summary statistics will be done in a separate PR after this one goes in.
* This PR depends on #12287, which makes GLMs support link prediction on the Scala side. After that is merged, I will add more tests for predict() to this PR.

Unit tests.

cc mengxr jkbradley hhbyyh

Author: Yanbo Liang

Closes #12294 from yanboliang/spark-12566.
---
 .../ml/r/GeneralizedLinearRegressionWrapper.scala  |  79 ++++++++++++++
 .../org/apache/spark/ml/r/SparkRWrappers.scala     | 115 ---------------------
 2 files changed, 79 insertions(+), 115 deletions(-)
 create mode 100644 mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala
 delete mode 100644 mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala

(limited to 'mllib/src')

diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala
new file mode 100644
index 0000000000..475a308385
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala
@@ -0,0 +1,79 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.r
+
+import org.apache.spark.ml.{Pipeline, PipelineModel}
+import org.apache.spark.ml.attribute.AttributeGroup
+import org.apache.spark.ml.feature.RFormula
+import org.apache.spark.ml.regression._
+import org.apache.spark.sql._
+
+private[r] class GeneralizedLinearRegressionWrapper private (
+    pipeline: PipelineModel,
+    val features: Array[String]) {
+
+  private val glm: GeneralizedLinearRegressionModel =
+    pipeline.stages(1).asInstanceOf[GeneralizedLinearRegressionModel]
+
+  lazy val rCoefficients: Array[Double] = if (glm.getFitIntercept) {
+    Array(glm.intercept) ++ glm.coefficients.toArray
+  } else {
+    glm.coefficients.toArray
+  }
+
+  lazy val rFeatures: Array[String] = if (glm.getFitIntercept) {
+    Array("(Intercept)") ++ features
+  } else {
+    features
+  }
+
+  def transform(dataset: DataFrame): DataFrame = {
+    pipeline.transform(dataset).drop(glm.getFeaturesCol)
+  }
+}
+
+private[r] object GeneralizedLinearRegressionWrapper {
+
+  def fit(
+      formula: String,
+      data: DataFrame,
+      family: String,
+      link: String,
+      epsilon: Double,
+      maxit: Int): GeneralizedLinearRegressionWrapper = {
+    val rFormula = new RFormula()
+      .setFormula(formula)
+    val rFormulaModel = rFormula.fit(data)
+    // get labels and feature names from output schema
+    val schema = rFormulaModel.transform(data).schema
+    val featureAttrs = AttributeGroup.fromStructField(schema(rFormula.getFeaturesCol))
+      .attributes.get
+    val features = featureAttrs.map(_.name.get)
+    // assemble and fit the pipeline
+    val glm = new GeneralizedLinearRegression()
+      .setFamily(family)
+      .setLink(link)
+      .setFitIntercept(rFormula.hasIntercept)
+      .setTol(epsilon)
+      .setMaxIter(maxit)
+    val pipeline = new Pipeline()
+      .setStages(Array(rFormulaModel, glm))
+      .fit(data)
+    new GeneralizedLinearRegressionWrapper(pipeline, features)
+  }
+}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala b/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala
deleted file mode 100644
index fa143715be..0000000000
--- a/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala
+++ /dev/null
@@ -1,115 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.ml.api.r
-
-import org.apache.spark.ml.{Pipeline, PipelineModel}
-import org.apache.spark.ml.attribute._
-import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressionModel}
-import org.apache.spark.ml.feature.RFormula
-import org.apache.spark.ml.regression.{LinearRegression, LinearRegressionModel}
-import org.apache.spark.sql.DataFrame
-
-private[r] object SparkRWrappers {
-  def fitRModelFormula(
-      value: String,
-      df: DataFrame,
-      family: String,
-      lambda: Double,
-      alpha: Double,
-      standardize: Boolean,
-      solver: String): PipelineModel = {
-    val formula = new RFormula().setFormula(value)
-    val estimator = family match {
-      case "gaussian" => new LinearRegression()
-        .setRegParam(lambda)
-        .setElasticNetParam(alpha)
-        .setFitIntercept(formula.hasIntercept)
-        .setStandardization(standardize)
-        .setSolver(solver)
-      case "binomial" => new LogisticRegression()
-        .setRegParam(lambda)
-        .setElasticNetParam(alpha)
-        .setFitIntercept(formula.hasIntercept)
-        .setStandardization(standardize)
-    }
-    val pipeline = new Pipeline().setStages(Array(formula, estimator))
-    pipeline.fit(df)
-  }
-
-  def getModelCoefficients(model: PipelineModel): Array[Double] = {
-    model.stages.last match {
-      case m: LinearRegressionModel =>
-        val coefficientStandardErrorsR = Array(m.summary.coefficientStandardErrors.last) ++
-          m.summary.coefficientStandardErrors.dropRight(1)
-        val tValuesR = Array(m.summary.tValues.last) ++ m.summary.tValues.dropRight(1)
-        val pValuesR = Array(m.summary.pValues.last) ++ m.summary.pValues.dropRight(1)
-        if (m.getFitIntercept) {
-          Array(m.intercept) ++ m.coefficients.toArray ++ coefficientStandardErrorsR ++
-            tValuesR ++ pValuesR
-        } else {
-          m.coefficients.toArray ++ coefficientStandardErrorsR ++ tValuesR ++ pValuesR
-        }
-      case m: LogisticRegressionModel =>
-        if (m.getFitIntercept) {
-          Array(m.intercept) ++ m.coefficients.toArray
-        } else {
-          m.coefficients.toArray
-        }
-    }
-  }
-
-  def getModelDevianceResiduals(model: PipelineModel): Array[Double] = {
-    model.stages.last match {
-      case m: LinearRegressionModel =>
-        m.summary.devianceResiduals
-      case m: LogisticRegressionModel =>
-        throw new UnsupportedOperationException(
-          "No deviance residuals available for LogisticRegressionModel")
-    }
-  }
-
-  def getModelFeatures(model: PipelineModel): Array[String] = {
-    model.stages.last match {
-      case m: LinearRegressionModel =>
-        val attrs = AttributeGroup.fromStructField(
-          m.summary.predictions.schema(m.summary.featuresCol))
-        if (m.getFitIntercept) {
-          Array("(Intercept)") ++ attrs.attributes.get.map(_.name.get)
-        } else {
-          attrs.attributes.get.map(_.name.get)
-        }
-      case m: LogisticRegressionModel =>
-        val attrs = AttributeGroup.fromStructField(
-          m.summary.predictions.schema(m.summary.featuresCol))
-        if (m.getFitIntercept) {
-          Array("(Intercept)") ++ attrs.attributes.get.map(_.name.get)
-        } else {
-          attrs.attributes.get.map(_.name.get)
-        }
-    }
-  }
-
-  def getModelName(model: PipelineModel): String = {
-    model.stages.last match {
-      case m: LinearRegressionModel =>
-        "LinearRegressionModel"
-      case m: LogisticRegressionModel =>
-        "LogisticRegressionModel"
-    }
-  }
-}
-- 
cgit v1.2.3
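For reviewers who want to see how the new JVM-side entry points fit together, below is a minimal usage sketch. It is illustrative only and not part of this patch: the object `GlmWrapperExample`, the DataFrame `df`, and the columns `y`, `x1`, `x2` are made up, and the sketch sits in `org.apache.spark.ml.r` only because the wrapper is `private[r]`; in practice SparkR's `glm()` reaches these calls through the R-to-JVM backend with the `formula`, `family`, `data`, `epsilon`, and `maxit` arguments described in the commit message.

```scala
// Illustrative sketch only (not part of this patch). Placed in the same
// package because GeneralizedLinearRegressionWrapper is private[r]; SparkR's
// glm() normally reaches it via the R-to-JVM backend. `df` and the columns
// y, x1, x2 are hypothetical placeholders.
package org.apache.spark.ml.r

import org.apache.spark.sql.DataFrame

object GlmWrapperExample {
  def run(df: DataFrame): DataFrame = {
    // One call fits the RFormula + GeneralizedLinearRegression pipeline,
    // mirroring R's glm(formula, family, data, epsilon, maxit) arguments.
    val wrapper = GeneralizedLinearRegressionWrapper.fit(
      "y ~ x1 + x2", // R model formula
      df,            // training data
      "gaussian",    // family, matching R's glm
      "identity",    // link function
      1e-6,          // epsilon -> setTol (convergence tolerance)
      25)            // maxit -> setMaxIter (maximum iterations)

    // Coefficients in R's layout: "(Intercept)" comes first when the
    // formula has an intercept.
    wrapper.rFeatures.zip(wrapper.rCoefficients).foreach {
      case (name, coef) => println(s"$name: $coef")
    }

    // predict(): appends the prediction column and drops the assembled
    // features vector column before the result is handed back to R.
    wrapper.transform(df)
  }
}
```

Nothing in the sketch is new behavior; it simply strings together the `fit`, `rFeatures`/`rCoefficients`, and `transform` entry points introduced in the diff above.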