diff options
author | Yanbo Liang <ybliang8@gmail.com> | 2016-04-21 17:31:33 -0700 |
---|---|---|
committer | Xiangrui Meng <meng@databricks.com> | 2016-04-21 17:31:33 -0700 |
commit | 4e726227a3e68c776ea30b78b7db8d01d00b44d6 (patch) | |
tree | 8476a45adfd764147d89c4f0f4b904f0067e2a94 | |
parent | f25a3ea8d3ee6972efb925826981918549deacaa (diff) | |
download | spark-4e726227a3e68c776ea30b78b7db8d01d00b44d6.tar.gz spark-4e726227a3e68c776ea30b78b7db8d01d00b44d6.tar.bz2 spark-4e726227a3e68c776ea30b78b7db8d01d00b44d6.zip |
[SPARK-14479][ML] GLM supports output link prediction
## What changes were proposed in this pull request?
GLM supports output link prediction.
## How was this patch tested?
unit test.
Author: Yanbo Liang <ybliang8@gmail.com>
Closes #12287 from yanboliang/spark-14479.
2 files changed, 108 insertions, 34 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala index e92a3e7fa1..dcf69afe0d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala @@ -78,6 +78,20 @@ private[regression] trait GeneralizedLinearRegressionBase extends PredictorParam @Since("2.0.0") def getLink: String = $(link) + /** + * Param for link prediction (linear predictor) column name. + * Default is empty, which means we do not output link prediction. + * @group param + */ + @Since("2.0.0") + final val linkPredictionCol: Param[String] = new Param[String](this, "linkPredictionCol", + "link prediction (linear predictor) column name") + setDefault(linkPredictionCol, "") + + /** @group getParam */ + @Since("2.0.0") + def getLinkPredictionCol: String = $(linkPredictionCol) + import GeneralizedLinearRegression._ @Since("2.0.0") @@ -93,7 +107,12 @@ private[regression] trait GeneralizedLinearRegressionBase extends PredictorParam Family.fromName($(family)) -> Link.fromName($(link))), "Generalized Linear Regression " + s"with ${$(family)} family does not support ${$(link)} link function.") } - super.validateAndTransformSchema(schema, fitting, featuresDataType) + val newSchema = super.validateAndTransformSchema(schema, fitting, featuresDataType) + if ($(linkPredictionCol).nonEmpty) { + SchemaUtils.appendColumn(newSchema, $(linkPredictionCol), DoubleType) + } else { + newSchema + } } } @@ -196,6 +215,13 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val def setSolver(value: String): this.type = set(solver, value) setDefault(solver -> "irls") + /** + * Sets the link prediction (linear predictor) column name. + * @group setParam + */ + @Since("2.0.0") + def setLinkPredictionCol(value: String): this.type = set(linkPredictionCol, value) + override protected def train(dataset: Dataset[_]): GeneralizedLinearRegressionModel = { val familyObj = Family.fromName($(family)) val linkObj = if (isDefined(link)) { @@ -666,6 +692,13 @@ class GeneralizedLinearRegressionModel private[ml] ( extends RegressionModel[Vector, GeneralizedLinearRegressionModel] with GeneralizedLinearRegressionBase with MLWritable { + /** + * Sets the link prediction (linear predictor) column name. + * @group setParam + */ + @Since("2.0.0") + def setLinkPredictionCol(value: String): this.type = set(linkPredictionCol, value) + import GeneralizedLinearRegression._ lazy val familyObj = Family.fromName($(family)) @@ -677,10 +710,35 @@ class GeneralizedLinearRegressionModel private[ml] ( lazy val familyAndLink = new FamilyAndLink(familyObj, linkObj) override protected def predict(features: Vector): Double = { - val eta = BLAS.dot(features, coefficients) + intercept + val eta = predictLink(features) familyAndLink.fitted(eta) } + /** + * Calculate the link prediction (linear predictor) of the given instance. + */ + private def predictLink(features: Vector): Double = { + BLAS.dot(features, coefficients) + intercept + } + + override def transform(dataset: Dataset[_]): DataFrame = { + transformSchema(dataset.schema) + transformImpl(dataset) + } + + override protected def transformImpl(dataset: Dataset[_]): DataFrame = { + val predictUDF = udf { (features: Vector) => predict(features) } + val predictLinkUDF = udf { (features: Vector) => predictLink(features) } + var output = dataset + if ($(predictionCol).nonEmpty) { + output = output.withColumn($(predictionCol), predictUDF(col($(featuresCol)))) + } + if ($(linkPredictionCol).nonEmpty) { + output = output.withColumn($(linkPredictionCol), predictLinkUDF(col($(featuresCol)))) + } + output.toDF + } + private var trainingSummary: Option[GeneralizedLinearRegressionSummary] = None /** diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala index 3ecc210abd..0b5e77afc3 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala @@ -247,20 +247,24 @@ class GeneralizedLinearRegressionSuite ("inverse", datasetGaussianInverse))) { for (fitIntercept <- Seq(false, true)) { val trainer = new GeneralizedLinearRegression().setFamily("gaussian").setLink(link) - .setFitIntercept(fitIntercept) + .setFitIntercept(fitIntercept).setLinkPredictionCol("linkPrediction") val model = trainer.fit(dataset) val actual = Vectors.dense(model.intercept, model.coefficients(0), model.coefficients(1)) assert(actual ~= expected(idx) absTol 1e-4, "Model mismatch: GLM with gaussian family, " + s"$link link and fitIntercept = $fitIntercept.") val familyLink = new FamilyAndLink(Gaussian, Link.fromName(link)) - model.transform(dataset).select("features", "prediction").collect().foreach { - case Row(features: DenseVector, prediction1: Double) => - val eta = BLAS.dot(features, model.coefficients) + model.intercept - val prediction2 = familyLink.fitted(eta) - assert(prediction1 ~= prediction2 relTol 1E-5, "Prediction mismatch: GLM with " + - s"gaussian family, $link link and fitIntercept = $fitIntercept.") - } + model.transform(dataset).select("features", "prediction", "linkPrediction").collect() + .foreach { + case Row(features: DenseVector, prediction1: Double, linkPrediction1: Double) => + val eta = BLAS.dot(features, model.coefficients) + model.intercept + val prediction2 = familyLink.fitted(eta) + val linkPrediction2 = eta + assert(prediction1 ~= prediction2 relTol 1E-5, "Prediction mismatch: GLM with " + + s"gaussian family, $link link and fitIntercept = $fitIntercept.") + assert(linkPrediction1 ~= linkPrediction2 relTol 1E-5, "Link Prediction mismatch: " + + s"GLM with gaussian family, $link link and fitIntercept = $fitIntercept.") + } idx += 1 } @@ -358,7 +362,7 @@ class GeneralizedLinearRegressionSuite ("cloglog", datasetBinomial))) { for (fitIntercept <- Seq(false, true)) { val trainer = new GeneralizedLinearRegression().setFamily("binomial").setLink(link) - .setFitIntercept(fitIntercept) + .setFitIntercept(fitIntercept).setLinkPredictionCol("linkPrediction") val model = trainer.fit(dataset) val actual = Vectors.dense(model.intercept, model.coefficients(0), model.coefficients(1), model.coefficients(2), model.coefficients(3)) @@ -366,13 +370,17 @@ class GeneralizedLinearRegressionSuite s"$link link and fitIntercept = $fitIntercept.") val familyLink = new FamilyAndLink(Binomial, Link.fromName(link)) - model.transform(dataset).select("features", "prediction").collect().foreach { - case Row(features: DenseVector, prediction1: Double) => - val eta = BLAS.dot(features, model.coefficients) + model.intercept - val prediction2 = familyLink.fitted(eta) - assert(prediction1 ~= prediction2 relTol 1E-5, "Prediction mismatch: GLM with " + - s"binomial family, $link link and fitIntercept = $fitIntercept.") - } + model.transform(dataset).select("features", "prediction", "linkPrediction").collect() + .foreach { + case Row(features: DenseVector, prediction1: Double, linkPrediction1: Double) => + val eta = BLAS.dot(features, model.coefficients) + model.intercept + val prediction2 = familyLink.fitted(eta) + val linkPrediction2 = eta + assert(prediction1 ~= prediction2 relTol 1E-5, "Prediction mismatch: GLM with " + + s"binomial family, $link link and fitIntercept = $fitIntercept.") + assert(linkPrediction1 ~= linkPrediction2 relTol 1E-5, "Link Prediction mismatch: " + + s"GLM with binomial family, $link link and fitIntercept = $fitIntercept.") + } idx += 1 } @@ -427,20 +435,24 @@ class GeneralizedLinearRegressionSuite ("sqrt", datasetPoissonSqrt))) { for (fitIntercept <- Seq(false, true)) { val trainer = new GeneralizedLinearRegression().setFamily("poisson").setLink(link) - .setFitIntercept(fitIntercept) + .setFitIntercept(fitIntercept).setLinkPredictionCol("linkPrediction") val model = trainer.fit(dataset) val actual = Vectors.dense(model.intercept, model.coefficients(0), model.coefficients(1)) assert(actual ~= expected(idx) absTol 1e-4, "Model mismatch: GLM with poisson family, " + s"$link link and fitIntercept = $fitIntercept.") val familyLink = new FamilyAndLink(Poisson, Link.fromName(link)) - model.transform(dataset).select("features", "prediction").collect().foreach { - case Row(features: DenseVector, prediction1: Double) => - val eta = BLAS.dot(features, model.coefficients) + model.intercept - val prediction2 = familyLink.fitted(eta) - assert(prediction1 ~= prediction2 relTol 1E-5, "Prediction mismatch: GLM with " + - s"poisson family, $link link and fitIntercept = $fitIntercept.") - } + model.transform(dataset).select("features", "prediction", "linkPrediction").collect() + .foreach { + case Row(features: DenseVector, prediction1: Double, linkPrediction1: Double) => + val eta = BLAS.dot(features, model.coefficients) + model.intercept + val prediction2 = familyLink.fitted(eta) + val linkPrediction2 = eta + assert(prediction1 ~= prediction2 relTol 1E-5, "Prediction mismatch: GLM with " + + s"poisson family, $link link and fitIntercept = $fitIntercept.") + assert(linkPrediction1 ~= linkPrediction2 relTol 1E-5, "Link Prediction mismatch: " + + s"GLM with poisson family, $link link and fitIntercept = $fitIntercept.") + } idx += 1 } @@ -495,20 +507,24 @@ class GeneralizedLinearRegressionSuite ("identity", datasetGammaIdentity), ("log", datasetGammaLog))) { for (fitIntercept <- Seq(false, true)) { val trainer = new GeneralizedLinearRegression().setFamily("gamma").setLink(link) - .setFitIntercept(fitIntercept) + .setFitIntercept(fitIntercept).setLinkPredictionCol("linkPrediction") val model = trainer.fit(dataset) val actual = Vectors.dense(model.intercept, model.coefficients(0), model.coefficients(1)) assert(actual ~= expected(idx) absTol 1e-4, "Model mismatch: GLM with gamma family, " + s"$link link and fitIntercept = $fitIntercept.") val familyLink = new FamilyAndLink(Gamma, Link.fromName(link)) - model.transform(dataset).select("features", "prediction").collect().foreach { - case Row(features: DenseVector, prediction1: Double) => - val eta = BLAS.dot(features, model.coefficients) + model.intercept - val prediction2 = familyLink.fitted(eta) - assert(prediction1 ~= prediction2 relTol 1E-5, "Prediction mismatch: GLM with " + - s"gamma family, $link link and fitIntercept = $fitIntercept.") - } + model.transform(dataset).select("features", "prediction", "linkPrediction").collect() + .foreach { + case Row(features: DenseVector, prediction1: Double, linkPrediction1: Double) => + val eta = BLAS.dot(features, model.coefficients) + model.intercept + val prediction2 = familyLink.fitted(eta) + val linkPrediction2 = eta + assert(prediction1 ~= prediction2 relTol 1E-5, "Prediction mismatch: GLM with " + + s"gamma family, $link link and fitIntercept = $fitIntercept.") + assert(linkPrediction1 ~= linkPrediction2 relTol 1E-5, "Link Prediction mismatch: " + + s"GLM with gamma family, $link link and fitIntercept = $fitIntercept.") + } idx += 1 } |