aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
authorYanbo Liang <ybliang8@gmail.com>2016-03-01 08:47:56 -0800
committerXiangrui Meng <meng@databricks.com>2016-03-01 08:47:56 -0800
commit5ed48dd84d38dfe621428e164a02e74ddbdbc622 (patch)
tree1d342ac5a27e37fe11a9aaf6b5863322f56faffb /mllib
parentc43899a04e4de18e238a1761bf4fe9f54e182320 (diff)
downloadspark-5ed48dd84d38dfe621428e164a02e74ddbdbc622.tar.gz
spark-5ed48dd84d38dfe621428e164a02e74ddbdbc622.tar.bz2
spark-5ed48dd84d38dfe621428e164a02e74ddbdbc622.zip
[SPARK-12811][ML] Estimator for Generalized Linear Models(GLMs)
Estimator for Generalized Linear Models(GLMs) which will be solved by IRLS. cc mengxr Author: Yanbo Liang <ybliang8@gmail.com> Closes #11136 from yanboliang/spark-12811.
Diffstat (limited to 'mllib')
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala10
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala577
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala4
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala507
4 files changed, 1094 insertions, 4 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala b/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala
index 61b3642131..55b7510656 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala
@@ -157,6 +157,12 @@ private[ml] class WeightedLeastSquares(
private[ml] object WeightedLeastSquares {
/**
+ * In order to take the normal equation approach efficiently, [[WeightedLeastSquares]]
+ * only supports the number of features is no more than 4096.
+ */
+ val MAX_NUM_FEATURES: Int = 4096
+
+ /**
* Aggregator to provide necessary summary statistics for solving [[WeightedLeastSquares]].
*/
// TODO: consolidate aggregates for summary statistics
@@ -174,8 +180,8 @@ private[ml] object WeightedLeastSquares {
private var aaSum: DenseVector = _
private def init(k: Int): Unit = {
- require(k <= 4096, "In order to take the normal equation approach efficiently, " +
- s"we set the max number of features to 4096 but got $k.")
+ require(k <= MAX_NUM_FEATURES, "In order to take the normal equation approach efficiently, " +
+ s"we set the max number of features to $MAX_NUM_FEATURES but got $k.")
this.k = k
triK = k * (k + 1) / 2
count = 0L
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
new file mode 100644
index 0000000000..a850dfee0a
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
@@ -0,0 +1,577 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.regression
+
+import breeze.stats.distributions.{Gaussian => GD}
+
+import org.apache.spark.{Logging, SparkException}
+import org.apache.spark.annotation.{Experimental, Since}
+import org.apache.spark.ml.PredictorParams
+import org.apache.spark.ml.feature.Instance
+import org.apache.spark.ml.optim._
+import org.apache.spark.ml.param._
+import org.apache.spark.ml.param.shared._
+import org.apache.spark.ml.util.Identifiable
+import org.apache.spark.mllib.linalg.{BLAS, Vector}
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.sql.functions._
+
+/**
+ * Params for Generalized Linear Regression.
+ */
+private[regression] trait GeneralizedLinearRegressionBase extends PredictorParams
+ with HasFitIntercept with HasMaxIter with HasTol with HasRegParam with HasWeightCol
+ with HasSolver with Logging {
+
+ /**
+ * Param for the name of family which is a description of the error distribution
+ * to be used in the model.
+ * Supported options: "gaussian", "binomial", "poisson" and "gamma".
+ * Default is "gaussian".
+ * @group param
+ */
+ @Since("2.0.0")
+ final val family: Param[String] = new Param(this, "family",
+ "The name of family which is a description of the error distribution to be used in the " +
+ "model. Supported options: gaussian(default), binomial, poisson and gamma.",
+ ParamValidators.inArray[String](GeneralizedLinearRegression.supportedFamilyNames.toArray))
+
+ /** @group getParam */
+ @Since("2.0.0")
+ def getFamily: String = $(family)
+
+ /**
+ * Param for the name of link function which provides the relationship
+ * between the linear predictor and the mean of the distribution function.
+ * Supported options: "identity", "log", "inverse", "logit", "probit", "cloglog" and "sqrt".
+ * @group param
+ */
+ @Since("2.0.0")
+ final val link: Param[String] = new Param(this, "link", "The name of link function " +
+ "which provides the relationship between the linear predictor and the mean of the " +
+ "distribution function. Supported options: identity, log, inverse, logit, probit, " +
+ "cloglog and sqrt.",
+ ParamValidators.inArray[String](GeneralizedLinearRegression.supportedLinkNames.toArray))
+
+ /** @group getParam */
+ @Since("2.0.0")
+ def getLink: String = $(link)
+
+ import GeneralizedLinearRegression._
+
+ @Since("2.0.0")
+ override def validateParams(): Unit = {
+ if ($(solver) == "irls") {
+ setDefault(maxIter -> 25)
+ }
+ if (isDefined(link)) {
+ require(supportedFamilyAndLinkPairs.contains(
+ Family.fromName($(family)) -> Link.fromName($(link))), "Generalized Linear Regression " +
+ s"with ${$(family)} family does not support ${$(link)} link function.")
+ }
+ }
+}
+
+/**
+ * :: Experimental ::
+ *
+ * Fit a Generalized Linear Model ([[https://en.wikipedia.org/wiki/Generalized_linear_model]])
+ * specified by giving a symbolic description of the linear predictor (link function) and
+ * a description of the error distribution (family).
+ * It supports "gaussian", "binomial", "poisson" and "gamma" as family.
+ * Valid link functions for each family is listed below. The first link function of each family
+ * is the default one.
+ * - "gaussian" -> "identity", "log", "inverse"
+ * - "binomial" -> "logit", "probit", "cloglog"
+ * - "poisson" -> "log", "identity", "sqrt"
+ * - "gamma" -> "inverse", "identity", "log"
+ */
+@Experimental
+@Since("2.0.0")
+class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val uid: String)
+ extends Regressor[Vector, GeneralizedLinearRegression, GeneralizedLinearRegressionModel]
+ with GeneralizedLinearRegressionBase with Logging {
+
+ import GeneralizedLinearRegression._
+
+ @Since("2.0.0")
+ def this() = this(Identifiable.randomUID("glm"))
+
+ /**
+ * Sets the value of param [[family]].
+ * Default is "gaussian".
+ * @group setParam
+ */
+ @Since("2.0.0")
+ def setFamily(value: String): this.type = set(family, value)
+ setDefault(family -> Gaussian.name)
+
+ /**
+ * Sets the value of param [[link]].
+ * @group setParam
+ */
+ @Since("2.0.0")
+ def setLink(value: String): this.type = set(link, value)
+
+ /**
+ * Sets if we should fit the intercept.
+ * Default is true.
+ * @group setParam
+ */
+ @Since("2.0.0")
+ def setFitIntercept(value: Boolean): this.type = set(fitIntercept, value)
+
+ /**
+ * Sets the maximum number of iterations.
+ * Default is 25 if the solver algorithm is "irls".
+ * @group setParam
+ */
+ @Since("2.0.0")
+ def setMaxIter(value: Int): this.type = set(maxIter, value)
+
+ /**
+ * Sets the convergence tolerance of iterations.
+ * Smaller value will lead to higher accuracy with the cost of more iterations.
+ * Default is 1E-6.
+ * @group setParam
+ */
+ @Since("2.0.0")
+ def setTol(value: Double): this.type = set(tol, value)
+ setDefault(tol -> 1E-6)
+
+ /**
+ * Sets the regularization parameter.
+ * Default is 0.0.
+ * @group setParam
+ */
+ @Since("2.0.0")
+ def setRegParam(value: Double): this.type = set(regParam, value)
+ setDefault(regParam -> 0.0)
+
+ /**
+ * Sets the value of param [[weightCol]].
+ * If this is not set or empty, we treat all instance weights as 1.0.
+ * Default is empty, so all instances have weight one.
+ * @group setParam
+ */
+ @Since("2.0.0")
+ def setWeightCol(value: String): this.type = set(weightCol, value)
+ setDefault(weightCol -> "")
+
+ /**
+ * Sets the solver algorithm used for optimization.
+ * Currently only support "irls" which is also the default solver.
+ * @group setParam
+ */
+ @Since("2.0.0")
+ def setSolver(value: String): this.type = set(solver, value)
+ setDefault(solver -> "irls")
+
+ override protected def train(dataset: DataFrame): GeneralizedLinearRegressionModel = {
+ val familyObj = Family.fromName($(family))
+ val linkObj = if (isDefined(link)) {
+ Link.fromName($(link))
+ } else {
+ familyObj.defaultLink
+ }
+ val familyAndLink = new FamilyAndLink(familyObj, linkObj)
+
+ val numFeatures = dataset.select(col($(featuresCol))).limit(1).rdd
+ .map { case Row(features: Vector) =>
+ features.size
+ }.first()
+ if (numFeatures > WeightedLeastSquares.MAX_NUM_FEATURES) {
+ val msg = "Currently, GeneralizedLinearRegression only supports number of features" +
+ s" <= ${WeightedLeastSquares.MAX_NUM_FEATURES}. Found $numFeatures in the input dataset."
+ throw new SparkException(msg)
+ }
+
+ val w = if ($(weightCol).isEmpty) lit(1.0) else col($(weightCol))
+ val instances: RDD[Instance] = dataset.select(col($(labelCol)), w, col($(featuresCol))).rdd
+ .map { case Row(label: Double, weight: Double, features: Vector) =>
+ Instance(label, weight, features)
+ }
+
+ if (familyObj == Gaussian && linkObj == Identity) {
+ // TODO: Make standardizeFeatures and standardizeLabel configurable.
+ val optimizer = new WeightedLeastSquares($(fitIntercept), $(regParam),
+ standardizeFeatures = true, standardizeLabel = true)
+ val wlsModel = optimizer.fit(instances)
+ val model = copyValues(
+ new GeneralizedLinearRegressionModel(uid, wlsModel.coefficients, wlsModel.intercept)
+ .setParent(this))
+ return model
+ }
+
+ // Fit Generalized Linear Model by iteratively reweighted least squares (IRLS).
+ val initialModel = familyAndLink.initialize(instances, $(fitIntercept), $(regParam))
+ val optimizer = new IterativelyReweightedLeastSquares(initialModel, familyAndLink.reweightFunc,
+ $(fitIntercept), $(regParam), $(maxIter), $(tol))
+ val irlsModel = optimizer.fit(instances)
+
+ val model = copyValues(
+ new GeneralizedLinearRegressionModel(uid, irlsModel.coefficients, irlsModel.intercept)
+ .setParent(this))
+ model
+ }
+
+ @Since("2.0.0")
+ override def copy(extra: ParamMap): GeneralizedLinearRegression = defaultCopy(extra)
+}
+
+@Since("2.0.0")
+private[ml] object GeneralizedLinearRegression {
+
+ /** Set of family and link pairs that GeneralizedLinearRegression supports. */
+ lazy val supportedFamilyAndLinkPairs = Set(
+ Gaussian -> Identity, Gaussian -> Log, Gaussian -> Inverse,
+ Binomial -> Logit, Binomial -> Probit, Binomial -> CLogLog,
+ Poisson -> Log, Poisson -> Identity, Poisson -> Sqrt,
+ Gamma -> Inverse, Gamma -> Identity, Gamma -> Log
+ )
+
+ /** Set of family names that GeneralizedLinearRegression supports. */
+ lazy val supportedFamilyNames = supportedFamilyAndLinkPairs.map(_._1.name)
+
+ /** Set of link names that GeneralizedLinearRegression supports. */
+ lazy val supportedLinkNames = supportedFamilyAndLinkPairs.map(_._2.name)
+
+ val epsilon: Double = 1E-16
+
+ /**
+ * Wrapper of family and link combination used in the model.
+ */
+ private[ml] class FamilyAndLink(val family: Family, val link: Link) extends Serializable {
+
+ /** Linear predictor based on given mu. */
+ def predict(mu: Double): Double = link.link(family.project(mu))
+
+ /** Fitted value based on linear predictor eta. */
+ def fitted(eta: Double): Double = family.project(link.unlink(eta))
+
+ /**
+ * Get the initial guess model for [[IterativelyReweightedLeastSquares]].
+ */
+ def initialize(
+ instances: RDD[Instance],
+ fitIntercept: Boolean,
+ regParam: Double): WeightedLeastSquaresModel = {
+ val newInstances = instances.map { instance =>
+ val mu = family.initialize(instance.label, instance.weight)
+ val eta = predict(mu)
+ Instance(eta, instance.weight, instance.features)
+ }
+ // TODO: Make standardizeFeatures and standardizeLabel configurable.
+ val initialModel = new WeightedLeastSquares(fitIntercept, regParam,
+ standardizeFeatures = true, standardizeLabel = true)
+ .fit(newInstances)
+ initialModel
+ }
+
+ /**
+ * The reweight function used to update offsets and weights
+ * at each iteration of [[IterativelyReweightedLeastSquares]].
+ */
+ val reweightFunc: (Instance, WeightedLeastSquaresModel) => (Double, Double) = {
+ (instance: Instance, model: WeightedLeastSquaresModel) => {
+ val eta = model.predict(instance.features)
+ val mu = fitted(eta)
+ val offset = eta + (instance.label - mu) * link.deriv(mu)
+ val weight = instance.weight / (math.pow(this.link.deriv(mu), 2.0) * family.variance(mu))
+ (offset, weight)
+ }
+ }
+ }
+
+ /**
+ * A description of the error distribution to be used in the model.
+ * @param name the name of the family.
+ */
+ private[ml] abstract class Family(val name: String) extends Serializable {
+
+ /** The default link instance of this family. */
+ val defaultLink: Link
+
+ /** Initialize the starting value for mu. */
+ def initialize(y: Double, weight: Double): Double
+
+ /** The variance of the endogenous variable's mean, given the value mu. */
+ def variance(mu: Double): Double
+
+ /** Trim the fitted value so that it will be in valid range. */
+ def project(mu: Double): Double = mu
+ }
+
+ private[ml] object Family {
+
+ /**
+ * Gets the [[Family]] object from its name.
+ * @param name family name: "gaussian", "binomial", "poisson" or "gamma".
+ */
+ def fromName(name: String): Family = {
+ name match {
+ case Gaussian.name => Gaussian
+ case Binomial.name => Binomial
+ case Poisson.name => Poisson
+ case Gamma.name => Gamma
+ }
+ }
+ }
+
+ /**
+ * Gaussian exponential family distribution.
+ * The default link for the Gaussian family is the identity link.
+ */
+ private[ml] object Gaussian extends Family("gaussian") {
+
+ val defaultLink: Link = Identity
+
+ override def initialize(y: Double, weight: Double): Double = y
+
+ def variance(mu: Double): Double = 1.0
+
+ override def project(mu: Double): Double = {
+ if (mu.isNegInfinity) {
+ Double.MinValue
+ } else if (mu.isPosInfinity) {
+ Double.MaxValue
+ } else {
+ mu
+ }
+ }
+ }
+
+ /**
+ * Binomial exponential family distribution.
+ * The default link for the Binomial family is the logit link.
+ */
+ private[ml] object Binomial extends Family("binomial") {
+
+ val defaultLink: Link = Logit
+
+ override def initialize(y: Double, weight: Double): Double = {
+ val mu = (weight * y + 0.5) / (weight + 1.0)
+ require(mu > 0.0 && mu < 1.0, "The response variable of Binomial family" +
+ s"should be in range (0, 1), but got $mu")
+ mu
+ }
+
+ override def variance(mu: Double): Double = mu * (1.0 - mu)
+
+ override def project(mu: Double): Double = {
+ if (mu < epsilon) {
+ epsilon
+ } else if (mu > 1.0 - epsilon) {
+ 1.0 - epsilon
+ } else {
+ mu
+ }
+ }
+ }
+
+ /**
+ * Poisson exponential family distribution.
+ * The default link for the Poisson family is the log link.
+ */
+ private[ml] object Poisson extends Family("poisson") {
+
+ val defaultLink: Link = Log
+
+ override def initialize(y: Double, weight: Double): Double = {
+ require(y > 0.0, "The response variable of Poisson family " +
+ s"should be positive, but got $y")
+ y
+ }
+
+ override def variance(mu: Double): Double = mu
+
+ override def project(mu: Double): Double = {
+ if (mu < epsilon) {
+ epsilon
+ } else if (mu.isInfinity) {
+ Double.MaxValue
+ } else {
+ mu
+ }
+ }
+ }
+
+ /**
+ * Gamma exponential family distribution.
+ * The default link for the Gamma family is the inverse link.
+ */
+ private[ml] object Gamma extends Family("gamma") {
+
+ val defaultLink: Link = Inverse
+
+ override def initialize(y: Double, weight: Double): Double = {
+ require(y > 0.0, "The response variable of Gamma family " +
+ s"should be positive, but got $y")
+ y
+ }
+
+ override def variance(mu: Double): Double = math.pow(mu, 2.0)
+
+ override def project(mu: Double): Double = {
+ if (mu < epsilon) {
+ epsilon
+ } else if (mu.isInfinity) {
+ Double.MaxValue
+ } else {
+ mu
+ }
+ }
+ }
+
+ /**
+ * A description of the link function to be used in the model.
+ * The link function provides the relationship between the linear predictor
+ * and the mean of the distribution function.
+ * @param name the name of link function.
+ */
+ private[ml] abstract class Link(val name: String) extends Serializable {
+
+ /** The link function. */
+ def link(mu: Double): Double
+
+ /** Derivative of the link function. */
+ def deriv(mu: Double): Double
+
+ /** The inverse link function. */
+ def unlink(eta: Double): Double
+ }
+
+ private[ml] object Link {
+
+ /**
+ * Gets the [[Link]] object from its name.
+ * @param name link name: "identity", "logit", "log",
+ * "inverse", "probit", "cloglog" or "sqrt".
+ */
+ def fromName(name: String): Link = {
+ name match {
+ case Identity.name => Identity
+ case Logit.name => Logit
+ case Log.name => Log
+ case Inverse.name => Inverse
+ case Probit.name => Probit
+ case CLogLog.name => CLogLog
+ case Sqrt.name => Sqrt
+ }
+ }
+ }
+
+ private[ml] object Identity extends Link("identity") {
+
+ override def link(mu: Double): Double = mu
+
+ override def deriv(mu: Double): Double = 1.0
+
+ override def unlink(eta: Double): Double = eta
+ }
+
+ private[ml] object Logit extends Link("logit") {
+
+ override def link(mu: Double): Double = math.log(mu / (1.0 - mu))
+
+ override def deriv(mu: Double): Double = 1.0 / (mu * (1.0 - mu))
+
+ override def unlink(eta: Double): Double = 1.0 / (1.0 + math.exp(-1.0 * eta))
+ }
+
+ private[ml] object Log extends Link("log") {
+
+ override def link(mu: Double): Double = math.log(mu)
+
+ override def deriv(mu: Double): Double = 1.0 / mu
+
+ override def unlink(eta: Double): Double = math.exp(eta)
+ }
+
+ private[ml] object Inverse extends Link("inverse") {
+
+ override def link(mu: Double): Double = 1.0 / mu
+
+ override def deriv(mu: Double): Double = -1.0 * math.pow(mu, -2.0)
+
+ override def unlink(eta: Double): Double = 1.0 / eta
+ }
+
+ private[ml] object Probit extends Link("probit") {
+
+ override def link(mu: Double): Double = GD(0.0, 1.0).icdf(mu)
+
+ override def deriv(mu: Double): Double = 1.0 / GD(0.0, 1.0).pdf(GD(0.0, 1.0).icdf(mu))
+
+ override def unlink(eta: Double): Double = GD(0.0, 1.0).cdf(eta)
+ }
+
+ private[ml] object CLogLog extends Link("cloglog") {
+
+ override def link(mu: Double): Double = math.log(-1.0 * math.log(1 - mu))
+
+ override def deriv(mu: Double): Double = 1.0 / ((mu - 1.0) * math.log(1.0 - mu))
+
+ override def unlink(eta: Double): Double = 1.0 - math.exp(-1.0 * math.exp(eta))
+ }
+
+ private[ml] object Sqrt extends Link("sqrt") {
+
+ override def link(mu: Double): Double = math.sqrt(mu)
+
+ override def deriv(mu: Double): Double = 1.0 / (2.0 * math.sqrt(mu))
+
+ override def unlink(eta: Double): Double = math.pow(eta, 2.0)
+ }
+}
+
+/**
+ * :: Experimental ::
+ * Model produced by [[GeneralizedLinearRegression]].
+ */
+@Experimental
+@Since("2.0.0")
+class GeneralizedLinearRegressionModel private[ml] (
+ @Since("2.0.0") override val uid: String,
+ @Since("2.0.0") val coefficients: Vector,
+ @Since("2.0.0") val intercept: Double)
+ extends RegressionModel[Vector, GeneralizedLinearRegressionModel]
+ with GeneralizedLinearRegressionBase {
+
+ import GeneralizedLinearRegression._
+
+ lazy val familyObj = Family.fromName($(family))
+ lazy val linkObj = if (isDefined(link)) {
+ Link.fromName($(link))
+ } else {
+ familyObj.defaultLink
+ }
+ lazy val familyAndLink = new FamilyAndLink(familyObj, linkObj)
+
+ override protected def predict(features: Vector): Double = {
+ val eta = BLAS.dot(features, coefficients) + intercept
+ familyAndLink.fitted(eta)
+ }
+
+ @Since("2.0.0")
+ override def copy(extra: ParamMap): GeneralizedLinearRegressionModel = {
+ copyValues(new GeneralizedLinearRegressionModel(uid, coefficients, intercept), extra)
+ .setParent(parent)
+ }
+}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
index 8f78fd122f..b4f17b8e28 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
@@ -163,8 +163,8 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
}.first()
val w = if ($(weightCol).isEmpty) lit(1.0) else col($(weightCol))
- if (($(solver) == "auto" && $(elasticNetParam) == 0.0 && numFeatures <= 4096) ||
- $(solver) == "normal") {
+ if (($(solver) == "auto" && $(elasticNetParam) == 0.0 &&
+ numFeatures <= WeightedLeastSquares.MAX_NUM_FEATURES) || $(solver) == "normal") {
require($(elasticNetParam) == 0.0, "Only L2 regularization can be used when normal " +
"solver is used.'")
// For low dimensional data, WeightedLeastSquares is more efficiently since the
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala
new file mode 100644
index 0000000000..8bfa9855ce
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala
@@ -0,0 +1,507 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.regression
+
+import scala.util.Random
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.ml.param.ParamsSuite
+import org.apache.spark.ml.util.MLTestingUtils
+import org.apache.spark.mllib.classification.LogisticRegressionSuite._
+import org.apache.spark.mllib.linalg.{BLAS, DenseVector, Vectors}
+import org.apache.spark.mllib.random._
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.mllib.util.TestingUtils._
+import org.apache.spark.sql.{DataFrame, Row}
+
+class GeneralizedLinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
+
+ private val seed: Int = 42
+ @transient var datasetGaussianIdentity: DataFrame = _
+ @transient var datasetGaussianLog: DataFrame = _
+ @transient var datasetGaussianInverse: DataFrame = _
+ @transient var datasetBinomial: DataFrame = _
+ @transient var datasetPoissonLog: DataFrame = _
+ @transient var datasetPoissonIdentity: DataFrame = _
+ @transient var datasetPoissonSqrt: DataFrame = _
+ @transient var datasetGammaInverse: DataFrame = _
+ @transient var datasetGammaIdentity: DataFrame = _
+ @transient var datasetGammaLog: DataFrame = _
+
+ override def beforeAll(): Unit = {
+ super.beforeAll()
+
+ import GeneralizedLinearRegressionSuite._
+
+ datasetGaussianIdentity = sqlContext.createDataFrame(
+ sc.parallelize(generateGeneralizedLinearRegressionInput(
+ intercept = 2.5, coefficients = Array(2.2, 0.6), xMean = Array(2.9, 10.5),
+ xVariance = Array(0.7, 1.2), nPoints = 10000, seed, noiseLevel = 0.01,
+ family = "gaussian", link = "identity"), 2))
+
+ datasetGaussianLog = sqlContext.createDataFrame(
+ sc.parallelize(generateGeneralizedLinearRegressionInput(
+ intercept = 0.25, coefficients = Array(0.22, 0.06), xMean = Array(2.9, 10.5),
+ xVariance = Array(0.7, 1.2), nPoints = 10000, seed, noiseLevel = 0.01,
+ family = "gaussian", link = "log"), 2))
+
+ datasetGaussianInverse = sqlContext.createDataFrame(
+ sc.parallelize(generateGeneralizedLinearRegressionInput(
+ intercept = 2.5, coefficients = Array(2.2, 0.6), xMean = Array(2.9, 10.5),
+ xVariance = Array(0.7, 1.2), nPoints = 10000, seed, noiseLevel = 0.01,
+ family = "gaussian", link = "inverse"), 2))
+
+ datasetBinomial = {
+ val nPoints = 10000
+ val coefficients = Array(-0.57997, 0.912083, -0.371077, -0.819866, 2.688191)
+ val xMean = Array(5.843, 3.057, 3.758, 1.199)
+ val xVariance = Array(0.6856, 0.1899, 3.116, 0.581)
+
+ val testData =
+ generateMultinomialLogisticInput(coefficients, xMean, xVariance,
+ addIntercept = true, nPoints, seed)
+
+ sqlContext.createDataFrame(sc.parallelize(testData, 2))
+ }
+
+ datasetPoissonLog = sqlContext.createDataFrame(
+ sc.parallelize(generateGeneralizedLinearRegressionInput(
+ intercept = 0.25, coefficients = Array(0.22, 0.06), xMean = Array(2.9, 10.5),
+ xVariance = Array(0.7, 1.2), nPoints = 10000, seed, noiseLevel = 0.01,
+ family = "poisson", link = "log"), 2))
+
+ datasetPoissonIdentity = sqlContext.createDataFrame(
+ sc.parallelize(generateGeneralizedLinearRegressionInput(
+ intercept = 2.5, coefficients = Array(2.2, 0.6), xMean = Array(2.9, 10.5),
+ xVariance = Array(0.7, 1.2), nPoints = 10000, seed, noiseLevel = 0.01,
+ family = "poisson", link = "identity"), 2))
+
+ datasetPoissonSqrt = sqlContext.createDataFrame(
+ sc.parallelize(generateGeneralizedLinearRegressionInput(
+ intercept = 2.5, coefficients = Array(2.2, 0.6), xMean = Array(2.9, 10.5),
+ xVariance = Array(0.7, 1.2), nPoints = 10000, seed, noiseLevel = 0.01,
+ family = "poisson", link = "sqrt"), 2))
+
+ datasetGammaInverse = sqlContext.createDataFrame(
+ sc.parallelize(generateGeneralizedLinearRegressionInput(
+ intercept = 2.5, coefficients = Array(2.2, 0.6), xMean = Array(2.9, 10.5),
+ xVariance = Array(0.7, 1.2), nPoints = 10000, seed, noiseLevel = 0.01,
+ family = "gamma", link = "inverse"), 2))
+
+ datasetGammaIdentity = sqlContext.createDataFrame(
+ sc.parallelize(generateGeneralizedLinearRegressionInput(
+ intercept = 2.5, coefficients = Array(2.2, 0.6), xMean = Array(2.9, 10.5),
+ xVariance = Array(0.7, 1.2), nPoints = 10000, seed, noiseLevel = 0.01,
+ family = "gamma", link = "identity"), 2))
+
+ datasetGammaLog = sqlContext.createDataFrame(
+ sc.parallelize(generateGeneralizedLinearRegressionInput(
+ intercept = 0.25, coefficients = Array(0.22, 0.06), xMean = Array(2.9, 10.5),
+ xVariance = Array(0.7, 1.2), nPoints = 10000, seed, noiseLevel = 0.01,
+ family = "gamma", link = "log"), 2))
+ }
+
+ test("params") {
+ ParamsSuite.checkParams(new GeneralizedLinearRegression)
+ val model = new GeneralizedLinearRegressionModel("genLinReg", Vectors.dense(0.0), 0.0)
+ ParamsSuite.checkParams(model)
+ }
+
+ test("generalized linear regression: default params") {
+ val glr = new GeneralizedLinearRegression
+ assert(glr.getLabelCol === "label")
+ assert(glr.getFeaturesCol === "features")
+ assert(glr.getPredictionCol === "prediction")
+ assert(glr.getFitIntercept)
+ assert(glr.getTol === 1E-6)
+ assert(glr.getWeightCol === "")
+ assert(glr.getRegParam === 0.0)
+ assert(glr.getSolver == "irls")
+ // TODO: Construct model directly instead of via fitting.
+ val model = glr.setFamily("gaussian").setLink("identity")
+ .fit(datasetGaussianIdentity)
+
+ // copied model must have the same parent.
+ MLTestingUtils.checkCopy(model)
+
+ assert(model.getFeaturesCol === "features")
+ assert(model.getPredictionCol === "prediction")
+ assert(model.intercept !== 0.0)
+ assert(model.hasParent)
+ assert(model.getFamily === "gaussian")
+ assert(model.getLink === "identity")
+ }
+
+ test("generalized linear regression: gaussian family against glm") {
+ /*
+ R code:
+ f1 <- data$V1 ~ data$V2 + data$V3 - 1
+ f2 <- data$V1 ~ data$V2 + data$V3
+
+ data <- read.csv("path", header=FALSE)
+ for (formula in c(f1, f2)) {
+ model <- glm(formula, family="gaussian", data=data)
+ print(as.vector(coef(model)))
+ }
+
+ [1] 2.2960999 0.8087933
+ [1] 2.5002642 2.2000403 0.5999485
+
+ data <- read.csv("path", header=FALSE)
+ model1 <- glm(f1, family=gaussian(link=log), data=data, start=c(0,0))
+ model2 <- glm(f2, family=gaussian(link=log), data=data, start=c(0,0,0))
+ print(as.vector(coef(model1)))
+ print(as.vector(coef(model2)))
+
+ [1] 0.23069326 0.07993778
+ [1] 0.25001858 0.22002452 0.05998789
+
+ data <- read.csv("path", header=FALSE)
+ for (formula in c(f1, f2)) {
+ model <- glm(formula, family=gaussian(link=inverse), data=data)
+ print(as.vector(coef(model)))
+ }
+
+ [1] 2.3010179 0.8198976
+ [1] 2.4108902 2.2130248 0.6086152
+ */
+
+ val expected = Seq(
+ Vectors.dense(0.0, 2.2960999, 0.8087933),
+ Vectors.dense(2.5002642, 2.2000403, 0.5999485),
+ Vectors.dense(0.0, 0.23069326, 0.07993778),
+ Vectors.dense(0.25001858, 0.22002452, 0.05998789),
+ Vectors.dense(0.0, 2.3010179, 0.8198976),
+ Vectors.dense(2.4108902, 2.2130248, 0.6086152))
+
+ import GeneralizedLinearRegression._
+
+ var idx = 0
+ for ((link, dataset) <- Seq(("identity", datasetGaussianIdentity), ("log", datasetGaussianLog),
+ ("inverse", datasetGaussianInverse))) {
+ for (fitIntercept <- Seq(false, true)) {
+ val trainer = new GeneralizedLinearRegression().setFamily("gaussian").setLink(link)
+ .setFitIntercept(fitIntercept)
+ val model = trainer.fit(dataset)
+ val actual = Vectors.dense(model.intercept, model.coefficients(0), model.coefficients(1))
+ assert(actual ~= expected(idx) absTol 1e-4, "Model mismatch: GLM with gaussian family, " +
+ s"$link link and fitIntercept = $fitIntercept.")
+
+ val familyLink = new FamilyAndLink(Gaussian, Link.fromName(link))
+ model.transform(dataset).select("features", "prediction").collect().foreach {
+ case Row(features: DenseVector, prediction1: Double) =>
+ val eta = BLAS.dot(features, model.coefficients) + model.intercept
+ val prediction2 = familyLink.fitted(eta)
+ assert(prediction1 ~= prediction2 relTol 1E-5, "Prediction mismatch: GLM with " +
+ s"gaussian family, $link link and fitIntercept = $fitIntercept.")
+ }
+
+ idx += 1
+ }
+ }
+ }
+
+ test("generalized linear regression: gaussian family against glmnet") {
+ /*
+ R code:
+ library(glmnet)
+ data <- read.csv("path", header=FALSE)
+ label = data$V1
+ features = as.matrix(data.frame(data$V2, data$V3))
+ for (intercept in c(FALSE, TRUE)) {
+ for (lambda in c(0.0, 0.1, 1.0)) {
+ model <- glmnet(features, label, family="gaussian", intercept=intercept,
+ lambda=lambda, alpha=0, thresh=1E-14)
+ print(as.vector(coef(model)))
+ }
+ }
+
+ [1] 0.0000000 2.2961005 0.8087932
+ [1] 0.0000000 2.2130368 0.8309556
+ [1] 0.0000000 1.7176137 0.9610657
+ [1] 2.5002642 2.2000403 0.5999485
+ [1] 3.1106389 2.0935142 0.5712711
+ [1] 6.7597127 1.4581054 0.3994266
+ */
+
+ val expected = Seq(
+ Vectors.dense(0.0, 2.2961005, 0.8087932),
+ Vectors.dense(0.0, 2.2130368, 0.8309556),
+ Vectors.dense(0.0, 1.7176137, 0.9610657),
+ Vectors.dense(2.5002642, 2.2000403, 0.5999485),
+ Vectors.dense(3.1106389, 2.0935142, 0.5712711),
+ Vectors.dense(6.7597127, 1.4581054, 0.3994266))
+
+ var idx = 0
+ for (fitIntercept <- Seq(false, true);
+ regParam <- Seq(0.0, 0.1, 1.0)) {
+ val trainer = new GeneralizedLinearRegression().setFamily("gaussian")
+ .setFitIntercept(fitIntercept).setRegParam(regParam)
+ val model = trainer.fit(datasetGaussianIdentity)
+ val actual = Vectors.dense(model.intercept, model.coefficients(0), model.coefficients(1))
+ assert(actual ~= expected(idx) absTol 1e-4, "Model mismatch: GLM with gaussian family, " +
+ s"fitIntercept = $fitIntercept and regParam = $regParam.")
+
+ idx += 1
+ }
+ }
+
+ test("generalized linear regression: binomial family against glm") {
+ /*
+ R code:
+ f1 <- data$V1 ~ data$V2 + data$V3 + data$V4 + data$V5 - 1
+ f2 <- data$V1 ~ data$V2 + data$V3 + data$V4 + data$V5
+ data <- read.csv("path", header=FALSE)
+
+ for (formula in c(f1, f2)) {
+ model <- glm(formula, family="binomial", data=data)
+ print(as.vector(coef(model)))
+ }
+
+ [1] -0.3560284 1.3010002 -0.3570805 -0.7406762
+ [1] 2.8367406 -0.5896187 0.8931655 -0.3925169 -0.7996989
+
+ for (formula in c(f1, f2)) {
+ model <- glm(formula, family=binomial(link=probit), data=data)
+ print(as.vector(coef(model)))
+ }
+
+ [1] -0.2134390 0.7800646 -0.2144267 -0.4438358
+ [1] 1.6995366 -0.3524694 0.5332651 -0.2352985 -0.4780850
+
+ for (formula in c(f1, f2)) {
+ model <- glm(formula, family=binomial(link=cloglog), data=data)
+ print(as.vector(coef(model)))
+ }
+
+ [1] -0.2832198 0.8434144 -0.2524727 -0.5293452
+ [1] 1.5063590 -0.4038015 0.6133664 -0.2687882 -0.5541758
+ */
+ val expected = Seq(
+ Vectors.dense(0.0, -0.3560284, 1.3010002, -0.3570805, -0.7406762),
+ Vectors.dense(2.8367406, -0.5896187, 0.8931655, -0.3925169, -0.7996989),
+ Vectors.dense(0.0, -0.2134390, 0.7800646, -0.2144267, -0.4438358),
+ Vectors.dense(1.6995366, -0.3524694, 0.5332651, -0.2352985, -0.4780850),
+ Vectors.dense(0.0, -0.2832198, 0.8434144, -0.2524727, -0.5293452),
+ Vectors.dense(1.5063590, -0.4038015, 0.6133664, -0.2687882, -0.5541758))
+
+ import GeneralizedLinearRegression._
+
+ var idx = 0
+ for ((link, dataset) <- Seq(("logit", datasetBinomial), ("probit", datasetBinomial),
+ ("cloglog", datasetBinomial))) {
+ for (fitIntercept <- Seq(false, true)) {
+ val trainer = new GeneralizedLinearRegression().setFamily("binomial").setLink(link)
+ .setFitIntercept(fitIntercept)
+ val model = trainer.fit(dataset)
+ val actual = Vectors.dense(model.intercept, model.coefficients(0), model.coefficients(1),
+ model.coefficients(2), model.coefficients(3))
+ assert(actual ~= expected(idx) absTol 1e-4, "Model mismatch: GLM with binomial family, " +
+ s"$link link and fitIntercept = $fitIntercept.")
+
+ val familyLink = new FamilyAndLink(Binomial, Link.fromName(link))
+ model.transform(dataset).select("features", "prediction").collect().foreach {
+ case Row(features: DenseVector, prediction1: Double) =>
+ val eta = BLAS.dot(features, model.coefficients) + model.intercept
+ val prediction2 = familyLink.fitted(eta)
+ assert(prediction1 ~= prediction2 relTol 1E-5, "Prediction mismatch: GLM with " +
+ s"binomial family, $link link and fitIntercept = $fitIntercept.")
+ }
+
+ idx += 1
+ }
+ }
+ }
+
+ test("generalized linear regression: poisson family against glm") {
+ /*
+ R code:
+ f1 <- data$V1 ~ data$V2 + data$V3 - 1
+ f2 <- data$V1 ~ data$V2 + data$V3
+
+ data <- read.csv("path", header=FALSE)
+ for (formula in c(f1, f2)) {
+ model <- glm(formula, family="poisson", data=data)
+ print(as.vector(coef(model)))
+ }
+
+ [1] 0.22999393 0.08047088
+ [1] 0.25022353 0.21998599 0.05998621
+
+ data <- read.csv("path", header=FALSE)
+ for (formula in c(f1, f2)) {
+ model <- glm(formula, family=poisson(link=identity), data=data)
+ print(as.vector(coef(model)))
+ }
+
+ [1] 2.2929501 0.8119415
+ [1] 2.5012730 2.1999407 0.5999107
+
+ data <- read.csv("path", header=FALSE)
+ for (formula in c(f1, f2)) {
+ model <- glm(formula, family=poisson(link=sqrt), data=data)
+ print(as.vector(coef(model)))
+ }
+
+ [1] 2.2958947 0.8090515
+ [1] 2.5000480 2.1999972 0.5999968
+ */
+ val expected = Seq(
+ Vectors.dense(0.0, 0.22999393, 0.08047088),
+ Vectors.dense(0.25022353, 0.21998599, 0.05998621),
+ Vectors.dense(0.0, 2.2929501, 0.8119415),
+ Vectors.dense(2.5012730, 2.1999407, 0.5999107),
+ Vectors.dense(0.0, 2.2958947, 0.8090515),
+ Vectors.dense(2.5000480, 2.1999972, 0.5999968))
+
+ import GeneralizedLinearRegression._
+
+ var idx = 0
+ for ((link, dataset) <- Seq(("log", datasetPoissonLog), ("identity", datasetPoissonIdentity),
+ ("sqrt", datasetPoissonSqrt))) {
+ for (fitIntercept <- Seq(false, true)) {
+ val trainer = new GeneralizedLinearRegression().setFamily("poisson").setLink(link)
+ .setFitIntercept(fitIntercept)
+ val model = trainer.fit(dataset)
+ val actual = Vectors.dense(model.intercept, model.coefficients(0), model.coefficients(1))
+ assert(actual ~= expected(idx) absTol 1e-4, "Model mismatch: GLM with poisson family, " +
+ s"$link link and fitIntercept = $fitIntercept.")
+
+ val familyLink = new FamilyAndLink(Poisson, Link.fromName(link))
+ model.transform(dataset).select("features", "prediction").collect().foreach {
+ case Row(features: DenseVector, prediction1: Double) =>
+ val eta = BLAS.dot(features, model.coefficients) + model.intercept
+ val prediction2 = familyLink.fitted(eta)
+ assert(prediction1 ~= prediction2 relTol 1E-5, "Prediction mismatch: GLM with " +
+ s"poisson family, $link link and fitIntercept = $fitIntercept.")
+ }
+
+ idx += 1
+ }
+ }
+ }
+
+ test("generalized linear regression: gamma family against glm") {
+ /*
+ R code:
+ f1 <- data$V1 ~ data$V2 + data$V3 - 1
+ f2 <- data$V1 ~ data$V2 + data$V3
+
+ data <- read.csv("path", header=FALSE)
+ for (formula in c(f1, f2)) {
+ model <- glm(formula, family="Gamma", data=data)
+ print(as.vector(coef(model)))
+ }
+
+ [1] 2.3392419 0.8058058
+ [1] 2.3507700 2.2533574 0.6042991
+
+ data <- read.csv("path", header=FALSE)
+ for (formula in c(f1, f2)) {
+ model <- glm(formula, family=Gamma(link=identity), data=data)
+ print(as.vector(coef(model)))
+ }
+
+ [1] 2.2908883 0.8147796
+ [1] 2.5002406 2.1998346 0.6000059
+
+ data <- read.csv("path", header=FALSE)
+ for (formula in c(f1, f2)) {
+ model <- glm(formula, family=Gamma(link=log), data=data)
+ print(as.vector(coef(model)))
+ }
+
+ [1] 0.22958970 0.08091066
+ [1] 0.25003210 0.21996957 0.06000215
+ */
+ val expected = Seq(
+ Vectors.dense(0.0, 2.3392419, 0.8058058),
+ Vectors.dense(2.3507700, 2.2533574, 0.6042991),
+ Vectors.dense(0.0, 2.2908883, 0.8147796),
+ Vectors.dense(2.5002406, 2.1998346, 0.6000059),
+ Vectors.dense(0.0, 0.22958970, 0.08091066),
+ Vectors.dense(0.25003210, 0.21996957, 0.06000215))
+
+ import GeneralizedLinearRegression._
+
+ var idx = 0
+ for ((link, dataset) <- Seq(("inverse", datasetGammaInverse),
+ ("identity", datasetGammaIdentity), ("log", datasetGammaLog))) {
+ for (fitIntercept <- Seq(false, true)) {
+ val trainer = new GeneralizedLinearRegression().setFamily("gamma").setLink(link)
+ .setFitIntercept(fitIntercept)
+ val model = trainer.fit(dataset)
+ val actual = Vectors.dense(model.intercept, model.coefficients(0), model.coefficients(1))
+ assert(actual ~= expected(idx) absTol 1e-4, "Model mismatch: GLM with gamma family, " +
+ s"$link link and fitIntercept = $fitIntercept.")
+
+ val familyLink = new FamilyAndLink(Gamma, Link.fromName(link))
+ model.transform(dataset).select("features", "prediction").collect().foreach {
+ case Row(features: DenseVector, prediction1: Double) =>
+ val eta = BLAS.dot(features, model.coefficients) + model.intercept
+ val prediction2 = familyLink.fitted(eta)
+ assert(prediction1 ~= prediction2 relTol 1E-5, "Prediction mismatch: GLM with " +
+ s"gamma family, $link link and fitIntercept = $fitIntercept.")
+ }
+
+ idx += 1
+ }
+ }
+ }
+}
+
+object GeneralizedLinearRegressionSuite {
+
+ def generateGeneralizedLinearRegressionInput(
+ intercept: Double,
+ coefficients: Array[Double],
+ xMean: Array[Double],
+ xVariance: Array[Double],
+ nPoints: Int,
+ seed: Int,
+ noiseLevel: Double,
+ family: String,
+ link: String): Seq[LabeledPoint] = {
+
+ val rnd = new Random(seed)
+ def rndElement(i: Int) = {
+ (rnd.nextDouble() - 0.5) * math.sqrt(12.0 * xVariance(i)) + xMean(i)
+ }
+ val (generator, mean) = family match {
+ case "gaussian" => (new StandardNormalGenerator, 0.0)
+ case "poisson" => (new PoissonGenerator(1.0), 1.0)
+ case "gamma" => (new GammaGenerator(1.0, 1.0), 1.0)
+ }
+ generator.setSeed(seed)
+
+ (0 until nPoints).map { _ =>
+ val features = Vectors.dense(coefficients.indices.map { rndElement(_) }.toArray)
+ val eta = BLAS.dot(Vectors.dense(coefficients), features) + intercept
+ val mu = link match {
+ case "identity" => eta
+ case "log" => math.exp(eta)
+ case "sqrt" => math.pow(eta, 2.0)
+ case "inverse" => 1.0 / eta
+ }
+ val label = mu + noiseLevel * (generator.nextValue() - mean)
+ // Return LabeledPoints with DenseVector
+ LabeledPoint(label, features)
+ }
+ }
+}