aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
author Yanbo Liang <ybliang8@gmail.com> 2016-04-21 17:31:33 -0700
committer Xiangrui Meng <meng@databricks.com> 2016-04-21 17:31:33 -0700
commit 4e726227a3e68c776ea30b78b7db8d01d00b44d6 (patch)
tree 8476a45adfd764147d89c4f0f4b904f0067e2a94
parent f25a3ea8d3ee6972efb925826981918549deacaa (diff)
download spark-4e726227a3e68c776ea30b78b7db8d01d00b44d6.tar.gz
spark-4e726227a3e68c776ea30b78b7db8d01d00b44d6.tar.bz2
spark-4e726227a3e68c776ea30b78b7db8d01d00b44d6.zip
[SPARK-14479][ML] GLM supports output link prediction
## What changes were proposed in this pull request?
GLM supports output link prediction.

## How was this patch tested?
unit test.

Author: Yanbo Liang <ybliang8@gmail.com>

Closes #12287 from yanboliang/spark-14479.
-rw-r--r-- mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala 62
-rw-r--r-- mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala 80
2 files changed, 108 insertions, 34 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
index e92a3e7fa1..dcf69afe0d 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
@@ -78,6 +78,20 @@ private[regression] trait GeneralizedLinearRegressionBase extends PredictorParam
@Since("2.0.0")
def getLink: String = $(link)
+ /**
+ * Param for link prediction (linear predictor) column name.
+ * Default is empty, which means we do not output link prediction.
+ * @group param
+ */
+ @Since("2.0.0")
+ final val linkPredictionCol: Param[String] = new Param[String](this, "linkPredictionCol",
+ "link prediction (linear predictor) column name")
+ setDefault(linkPredictionCol, "")
+
+ /** @group getParam */
+ @Since("2.0.0")
+ def getLinkPredictionCol: String = $(linkPredictionCol)
+
import GeneralizedLinearRegression._
@Since("2.0.0")
@@ -93,7 +107,12 @@ private[regression] trait GeneralizedLinearRegressionBase extends PredictorParam
Family.fromName($(family)) -> Link.fromName($(link))), "Generalized Linear Regression " +
s"with ${$(family)} family does not support ${$(link)} link function.")
}
- super.validateAndTransformSchema(schema, fitting, featuresDataType)
+ val newSchema = super.validateAndTransformSchema(schema, fitting, featuresDataType)
+ if ($(linkPredictionCol).nonEmpty) {
+ SchemaUtils.appendColumn(newSchema, $(linkPredictionCol), DoubleType)
+ } else {
+ newSchema
+ }
}
}
@@ -196,6 +215,13 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val
def setSolver(value: String): this.type = set(solver, value)
setDefault(solver -> "irls")
+ /**
+ * Sets the link prediction (linear predictor) column name.
+ * @group setParam
+ */
+ @Since("2.0.0")
+ def setLinkPredictionCol(value: String): this.type = set(linkPredictionCol, value)
+
override protected def train(dataset: Dataset[_]): GeneralizedLinearRegressionModel = {
val familyObj = Family.fromName($(family))
val linkObj = if (isDefined(link)) {
@@ -666,6 +692,13 @@ class GeneralizedLinearRegressionModel private[ml] (
extends RegressionModel[Vector, GeneralizedLinearRegressionModel]
with GeneralizedLinearRegressionBase with MLWritable {
+ /**
+ * Sets the link prediction (linear predictor) column name.
+ * @group setParam
+ */
+ @Since("2.0.0")
+ def setLinkPredictionCol(value: String): this.type = set(linkPredictionCol, value)
+
import GeneralizedLinearRegression._
lazy val familyObj = Family.fromName($(family))
@@ -677,10 +710,35 @@ class GeneralizedLinearRegressionModel private[ml] (
lazy val familyAndLink = new FamilyAndLink(familyObj, linkObj)
override protected def predict(features: Vector): Double = {
- val eta = BLAS.dot(features, coefficients) + intercept
+ val eta = predictLink(features)
familyAndLink.fitted(eta)
}
+ /**
+ * Calculate the link prediction (linear predictor) of the given instance.
+ */
+ private def predictLink(features: Vector): Double = {
+ BLAS.dot(features, coefficients) + intercept
+ }
+
+ override def transform(dataset: Dataset[_]): DataFrame = {
+ transformSchema(dataset.schema)
+ transformImpl(dataset)
+ }
+
+ override protected def transformImpl(dataset: Dataset[_]): DataFrame = {
+ val predictUDF = udf { (features: Vector) => predict(features) }
+ val predictLinkUDF = udf { (features: Vector) => predictLink(features) }
+ var output = dataset
+ if ($(predictionCol).nonEmpty) {
+ output = output.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
+ }
+ if ($(linkPredictionCol).nonEmpty) {
+ output = output.withColumn($(linkPredictionCol), predictLinkUDF(col($(featuresCol))))
+ }
+ output.toDF
+ }
+
private var trainingSummary: Option[GeneralizedLinearRegressionSummary] = None
/**
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala
index 3ecc210abd..0b5e77afc3 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala
@@ -247,20 +247,24 @@ class GeneralizedLinearRegressionSuite
("inverse", datasetGaussianInverse))) {
for (fitIntercept <- Seq(false, true)) {
val trainer = new GeneralizedLinearRegression().setFamily("gaussian").setLink(link)
- .setFitIntercept(fitIntercept)
+ .setFitIntercept(fitIntercept).setLinkPredictionCol("linkPrediction")
val model = trainer.fit(dataset)
val actual = Vectors.dense(model.intercept, model.coefficients(0), model.coefficients(1))
assert(actual ~= expected(idx) absTol 1e-4, "Model mismatch: GLM with gaussian family, " +
s"$link link and fitIntercept = $fitIntercept.")
val familyLink = new FamilyAndLink(Gaussian, Link.fromName(link))
- model.transform(dataset).select("features", "prediction").collect().foreach {
- case Row(features: DenseVector, prediction1: Double) =>
- val eta = BLAS.dot(features, model.coefficients) + model.intercept
- val prediction2 = familyLink.fitted(eta)
- assert(prediction1 ~= prediction2 relTol 1E-5, "Prediction mismatch: GLM with " +
- s"gaussian family, $link link and fitIntercept = $fitIntercept.")
- }
+ model.transform(dataset).select("features", "prediction", "linkPrediction").collect()
+ .foreach {
+ case Row(features: DenseVector, prediction1: Double, linkPrediction1: Double) =>
+ val eta = BLAS.dot(features, model.coefficients) + model.intercept
+ val prediction2 = familyLink.fitted(eta)
+ val linkPrediction2 = eta
+ assert(prediction1 ~= prediction2 relTol 1E-5, "Prediction mismatch: GLM with " +
+ s"gaussian family, $link link and fitIntercept = $fitIntercept.")
+ assert(linkPrediction1 ~= linkPrediction2 relTol 1E-5, "Link Prediction mismatch: " +
+ s"GLM with gaussian family, $link link and fitIntercept = $fitIntercept.")
+ }
idx += 1
}
@@ -358,7 +362,7 @@ class GeneralizedLinearRegressionSuite
("cloglog", datasetBinomial))) {
for (fitIntercept <- Seq(false, true)) {
val trainer = new GeneralizedLinearRegression().setFamily("binomial").setLink(link)
- .setFitIntercept(fitIntercept)
+ .setFitIntercept(fitIntercept).setLinkPredictionCol("linkPrediction")
val model = trainer.fit(dataset)
val actual = Vectors.dense(model.intercept, model.coefficients(0), model.coefficients(1),
model.coefficients(2), model.coefficients(3))
@@ -366,13 +370,17 @@ class GeneralizedLinearRegressionSuite
s"$link link and fitIntercept = $fitIntercept.")
val familyLink = new FamilyAndLink(Binomial, Link.fromName(link))
- model.transform(dataset).select("features", "prediction").collect().foreach {
- case Row(features: DenseVector, prediction1: Double) =>
- val eta = BLAS.dot(features, model.coefficients) + model.intercept
- val prediction2 = familyLink.fitted(eta)
- assert(prediction1 ~= prediction2 relTol 1E-5, "Prediction mismatch: GLM with " +
- s"binomial family, $link link and fitIntercept = $fitIntercept.")
- }
+ model.transform(dataset).select("features", "prediction", "linkPrediction").collect()
+ .foreach {
+ case Row(features: DenseVector, prediction1: Double, linkPrediction1: Double) =>
+ val eta = BLAS.dot(features, model.coefficients) + model.intercept
+ val prediction2 = familyLink.fitted(eta)
+ val linkPrediction2 = eta
+ assert(prediction1 ~= prediction2 relTol 1E-5, "Prediction mismatch: GLM with " +
+ s"binomial family, $link link and fitIntercept = $fitIntercept.")
+ assert(linkPrediction1 ~= linkPrediction2 relTol 1E-5, "Link Prediction mismatch: " +
+ s"GLM with binomial family, $link link and fitIntercept = $fitIntercept.")
+ }
idx += 1
}
@@ -427,20 +435,24 @@ class GeneralizedLinearRegressionSuite
("sqrt", datasetPoissonSqrt))) {
for (fitIntercept <- Seq(false, true)) {
val trainer = new GeneralizedLinearRegression().setFamily("poisson").setLink(link)
- .setFitIntercept(fitIntercept)
+ .setFitIntercept(fitIntercept).setLinkPredictionCol("linkPrediction")
val model = trainer.fit(dataset)
val actual = Vectors.dense(model.intercept, model.coefficients(0), model.coefficients(1))
assert(actual ~= expected(idx) absTol 1e-4, "Model mismatch: GLM with poisson family, " +
s"$link link and fitIntercept = $fitIntercept.")
val familyLink = new FamilyAndLink(Poisson, Link.fromName(link))
- model.transform(dataset).select("features", "prediction").collect().foreach {
- case Row(features: DenseVector, prediction1: Double) =>
- val eta = BLAS.dot(features, model.coefficients) + model.intercept
- val prediction2 = familyLink.fitted(eta)
- assert(prediction1 ~= prediction2 relTol 1E-5, "Prediction mismatch: GLM with " +
- s"poisson family, $link link and fitIntercept = $fitIntercept.")
- }
+ model.transform(dataset).select("features", "prediction", "linkPrediction").collect()
+ .foreach {
+ case Row(features: DenseVector, prediction1: Double, linkPrediction1: Double) =>
+ val eta = BLAS.dot(features, model.coefficients) + model.intercept
+ val prediction2 = familyLink.fitted(eta)
+ val linkPrediction2 = eta
+ assert(prediction1 ~= prediction2 relTol 1E-5, "Prediction mismatch: GLM with " +
+ s"poisson family, $link link and fitIntercept = $fitIntercept.")
+ assert(linkPrediction1 ~= linkPrediction2 relTol 1E-5, "Link Prediction mismatch: " +
+ s"GLM with poisson family, $link link and fitIntercept = $fitIntercept.")
+ }
idx += 1
}
@@ -495,20 +507,24 @@ class GeneralizedLinearRegressionSuite
("identity", datasetGammaIdentity), ("log", datasetGammaLog))) {
for (fitIntercept <- Seq(false, true)) {
val trainer = new GeneralizedLinearRegression().setFamily("gamma").setLink(link)
- .setFitIntercept(fitIntercept)
+ .setFitIntercept(fitIntercept).setLinkPredictionCol("linkPrediction")
val model = trainer.fit(dataset)
val actual = Vectors.dense(model.intercept, model.coefficients(0), model.coefficients(1))
assert(actual ~= expected(idx) absTol 1e-4, "Model mismatch: GLM with gamma family, " +
s"$link link and fitIntercept = $fitIntercept.")
val familyLink = new FamilyAndLink(Gamma, Link.fromName(link))
- model.transform(dataset).select("features", "prediction").collect().foreach {
- case Row(features: DenseVector, prediction1: Double) =>
- val eta = BLAS.dot(features, model.coefficients) + model.intercept
- val prediction2 = familyLink.fitted(eta)
- assert(prediction1 ~= prediction2 relTol 1E-5, "Prediction mismatch: GLM with " +
- s"gamma family, $link link and fitIntercept = $fitIntercept.")
- }
+ model.transform(dataset).select("features", "prediction", "linkPrediction").collect()
+ .foreach {
+ case Row(features: DenseVector, prediction1: Double, linkPrediction1: Double) =>
+ val eta = BLAS.dot(features, model.coefficients) + model.intercept
+ val prediction2 = familyLink.fitted(eta)
+ val linkPrediction2 = eta
+ assert(prediction1 ~= prediction2 relTol 1E-5, "Prediction mismatch: GLM with " +
+ s"gamma family, $link link and fitIntercept = $fitIntercept.")
+ assert(linkPrediction1 ~= linkPrediction2 relTol 1E-5, "Link Prediction mismatch: " +
+ s"GLM with gamma family, $link link and fitIntercept = $fitIntercept.")
+ }
idx += 1
}