about | summary | refs | log | tree | commit | diff
path: root/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
diff options
context:
space:
mode:
Diffstat (limited to 'mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala')
-rw-r--r-- mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala | 44
1 files changed, 32 insertions, 12 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
index 0e71e8d8e1..e92a3e7fa1 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
@@ -31,9 +31,9 @@ import org.apache.spark.ml.param.shared._
import org.apache.spark.ml.util._
import org.apache.spark.mllib.linalg.{BLAS, Vector}
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.sql.{DataFrame, Dataset, Row}
import org.apache.spark.sql.functions._
-import org.apache.spark.sql.types.{DataType, StructType}
+import org.apache.spark.sql.types.{DataType, DoubleType, StructType}
/**
* Params for Generalized Linear Regression.
@@ -47,6 +47,7 @@ private[regression] trait GeneralizedLinearRegressionBase extends PredictorParam
* to be used in the model.
* Supported options: "gaussian", "binomial", "poisson" and "gamma".
* Default is "gaussian".
+ *
* @group param
*/
@Since("2.0.0")
@@ -63,6 +64,7 @@ private[regression] trait GeneralizedLinearRegressionBase extends PredictorParam
* Param for the name of link function which provides the relationship
* between the linear predictor and the mean of the distribution function.
* Supported options: "identity", "log", "inverse", "logit", "probit", "cloglog" and "sqrt".
+ *
* @group param
*/
@Since("2.0.0")
@@ -163,7 +165,11 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val
setDefault(tol -> 1E-6)
/**
- * Sets the regularization parameter.
+ * Sets the regularization parameter for L2 regularization.
+ * The regularization term is
+ * {{{
+ * 0.5 * regParam * L2norm(coefficients)^2
+ * }}}
* Default is 0.0.
* @group setParam
*/
@@ -190,7 +196,7 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val
def setSolver(value: String): this.type = set(solver, value)
setDefault(solver -> "irls")
- override protected def train(dataset: DataFrame): GeneralizedLinearRegressionModel = {
+ override protected def train(dataset: Dataset[_]): GeneralizedLinearRegressionModel = {
val familyObj = Family.fromName($(family))
val linkObj = if (isDefined(link)) {
Link.fromName($(link))
@@ -210,9 +216,10 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val
}
val w = if ($(weightCol).isEmpty) lit(1.0) else col($(weightCol))
- val instances: RDD[Instance] = dataset.select(col($(labelCol)), w, col($(featuresCol))).rdd
- .map { case Row(label: Double, weight: Double, features: Vector) =>
- Instance(label, weight, features)
+ val instances: RDD[Instance] =
+ dataset.select(col($(labelCol)).cast(DoubleType), w, col($(featuresCol))).rdd.map {
+ case Row(label: Double, weight: Double, features: Vector) =>
+ Instance(label, weight, features)
}
if (familyObj == Gaussian && linkObj == Identity) {
@@ -230,7 +237,8 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val
predictionColName,
model,
wlsModel.diagInvAtWA.toArray,
- 1)
+ 1,
+ getSolver)
return model.setSummary(trainingSummary)
}
@@ -250,7 +258,8 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val
predictionColName,
model,
irlsModel.diagInvAtWA.toArray,
- irlsModel.numIterations)
+ irlsModel.numIterations,
+ getSolver)
model.setSummary(trainingSummary)
}
@@ -698,7 +707,7 @@ class GeneralizedLinearRegressionModel private[ml] (
: (GeneralizedLinearRegressionModel, String) = {
$(predictionCol) match {
case "" =>
- val predictionColName = "prediction_" + java.util.UUID.randomUUID.toString()
+ val predictionColName = "prediction_" + java.util.UUID.randomUUID.toString
(copy(ParamMap.empty).setPredictionCol(predictionColName), predictionColName)
case p => (this, p)
}
@@ -769,11 +778,12 @@ object GeneralizedLinearRegressionModel extends MLReadable[GeneralizedLinearRegr
* :: Experimental ::
* Summarizing Generalized Linear regression Fits.
*
- * @param predictions predictions outputted by the model's `transform` method
+ * @param predictions predictions output by the model's `transform` method
* @param predictionCol field in "predictions" which gives the prediction value of each instance
* @param model the model that should be summarized
* @param diagInvAtWA diagonal of matrix (A^T * W * A)^-1 in the last iteration
* @param numIterations number of iterations
+ * @param solver the solver algorithm used for model training
*/
@Since("2.0.0")
@Experimental
@@ -782,7 +792,8 @@ class GeneralizedLinearRegressionSummary private[regression] (
@Since("2.0.0") val predictionCol: String,
@Since("2.0.0") val model: GeneralizedLinearRegressionModel,
private val diagInvAtWA: Array[Double],
- @Since("2.0.0") val numIterations: Int) extends Serializable {
+ @Since("2.0.0") val numIterations: Int,
+ @Since("2.0.0") val solver: String) extends Serializable {
import GeneralizedLinearRegression._
@@ -930,6 +941,9 @@ class GeneralizedLinearRegressionSummary private[regression] (
/**
* Standard error of estimated coefficients and intercept.
+ *
+ * If [[GeneralizedLinearRegression.fitIntercept]] is set to true,
+ * then the last element returned corresponds to the intercept.
*/
@Since("2.0.0")
lazy val coefficientStandardErrors: Array[Double] = {
@@ -938,6 +952,9 @@ class GeneralizedLinearRegressionSummary private[regression] (
/**
* T-statistic of estimated coefficients and intercept.
+ *
+ * If [[GeneralizedLinearRegression.fitIntercept]] is set to true,
+ * then the last element returned corresponds to the intercept.
*/
@Since("2.0.0")
lazy val tValues: Array[Double] = {
@@ -951,6 +968,9 @@ class GeneralizedLinearRegressionSummary private[regression] (
/**
* Two-sided p-value of estimated coefficients and intercept.
+ *
+ * If [[GeneralizedLinearRegression.fitIntercept]] is set to true,
+ * then the last element returned corresponds to the intercept.
*/
@Since("2.0.0")
lazy val pValues: Array[Double] = {