aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/main/scala
diff options
context:
space:
mode:
author: Yanbo Liang <ybliang8@gmail.com> 2016-05-19 23:35:20 -0700
committer: Xiangrui Meng <meng@databricks.com> 2016-05-19 23:35:20 -0700
commit: c94b34ebbf4c6ce353c899c571beb34e8db98917 (patch)
tree: a742c44515259359153599ee62f6aa0e6bd58e91 /mllib/src/main/scala
parent: 5e203505f1a092e5849ebd01d9ff9e4fc6cdc34a (diff)
download: spark-c94b34ebbf4c6ce353c899c571beb34e8db98917.tar.gz
download: spark-c94b34ebbf4c6ce353c899c571beb34e8db98917.tar.bz2
download: spark-c94b34ebbf4c6ce353c899c571beb34e8db98917.zip
[SPARK-15339][ML] ML 2.0 QA: Scala APIs and code audit for regression
## What changes were proposed in this pull request?
* `GeneralizedLinearRegression` API docs enhancement.
* The default value of `GeneralizedLinearRegression`'s `linkPredictionCol` is now unset rather than an empty string. This is consistent with other similar params such as `weightCol`.
* Make some methods more private.
* Fix a minor bug in `LinearRegression`.
* Fix some other issues.

## How was this patch tested?
Existing tests.

Author: Yanbo Liang <ybliang8@gmail.com>

Closes #13129 from yanboliang/spark-15339.
Diffstat (limited to 'mllib/src/main/scala')
-rw-r--r-- mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala | 4
-rw-r--r-- mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala | 74
-rw-r--r-- mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala | 4
-rw-r--r-- mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala | 8
4 files changed, 45 insertions, 45 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
index cc16c2f038..e63eb71080 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
@@ -89,8 +89,8 @@ private[regression] trait AFTSurvivalRegressionParams extends Params
def getQuantilesCol: String = $(quantilesCol)
/** Checks whether the input has quantiles column name. */
- protected[regression] def hasQuantilesCol: Boolean = {
- isDefined(quantilesCol) && $(quantilesCol) != ""
+ private[regression] def hasQuantilesCol: Boolean = {
+ isDefined(quantilesCol) && $(quantilesCol).nonEmpty
}
/**
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
index e8474d035e..adbdd345e9 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
@@ -43,6 +43,8 @@ private[regression] trait GeneralizedLinearRegressionBase extends PredictorParam
with HasFitIntercept with HasMaxIter with HasTol with HasRegParam with HasWeightCol
with HasSolver with Logging {
+ import GeneralizedLinearRegression._
+
/**
* Param for the name of family which is a description of the error distribution
* to be used in the model.
@@ -54,8 +56,8 @@ private[regression] trait GeneralizedLinearRegressionBase extends PredictorParam
@Since("2.0.0")
final val family: Param[String] = new Param(this, "family",
"The name of family which is a description of the error distribution to be used in the " +
- "model. Supported options: gaussian(default), binomial, poisson and gamma.",
- ParamValidators.inArray[String](GeneralizedLinearRegression.supportedFamilyNames.toArray))
+ s"model. Supported options: ${supportedFamilyNames.mkString(", ")}.",
+ ParamValidators.inArray[String](supportedFamilyNames.toArray))
/** @group getParam */
@Since("2.0.0")
@@ -71,9 +73,8 @@ private[regression] trait GeneralizedLinearRegressionBase extends PredictorParam
@Since("2.0.0")
final val link: Param[String] = new Param(this, "link", "The name of link function " +
"which provides the relationship between the linear predictor and the mean of the " +
- "distribution function. Supported options: identity, log, inverse, logit, probit, " +
- "cloglog and sqrt.",
- ParamValidators.inArray[String](GeneralizedLinearRegression.supportedLinkNames.toArray))
+ s"distribution function. Supported options: ${supportedLinkNames.mkString(", ")}",
+ ParamValidators.inArray[String](supportedLinkNames.toArray))
/** @group getParam */
@Since("2.0.0")
@@ -81,19 +82,23 @@ private[regression] trait GeneralizedLinearRegressionBase extends PredictorParam
/**
* Param for link prediction (linear predictor) column name.
- * Default is empty, which means we do not output link prediction.
+ * Default is not set, which means we do not output link prediction.
*
* @group param
*/
@Since("2.0.0")
final val linkPredictionCol: Param[String] = new Param[String](this, "linkPredictionCol",
"link prediction (linear predictor) column name")
- setDefault(linkPredictionCol, "")
/** @group getParam */
@Since("2.0.0")
def getLinkPredictionCol: String = $(linkPredictionCol)
+ /** Checks whether we should output link prediction. */
+ private[regression] def hasLinkPredictionCol: Boolean = {
+ isDefined(linkPredictionCol) && $(linkPredictionCol).nonEmpty
+ }
+
import GeneralizedLinearRegression._
@Since("2.0.0")
@@ -107,7 +112,7 @@ private[regression] trait GeneralizedLinearRegressionBase extends PredictorParam
s"with ${$(family)} family does not support ${$(link)} link function.")
}
val newSchema = super.validateAndTransformSchema(schema, fitting, featuresDataType)
- if ($(linkPredictionCol).nonEmpty) {
+ if (hasLinkPredictionCol) {
SchemaUtils.appendColumn(newSchema, $(linkPredictionCol), DoubleType)
} else {
newSchema
@@ -205,7 +210,7 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val
/**
* Sets the value of param [[weightCol]].
* If this is not set or empty, we treat all instance weights as 1.0.
- * Default is empty, so all instances have weight one.
+ * Default is not set, so all instances have weight one.
*
* @group setParam
*/
@@ -214,7 +219,7 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val
/**
* Sets the solver algorithm used for optimization.
- * Currently only support "irls" which is also the default solver.
+ * Currently only supports "irls" which is also the default solver.
*
* @group setParam
*/
@@ -239,10 +244,7 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val
}
val familyAndLink = new FamilyAndLink(familyObj, linkObj)
- val numFeatures = dataset.select(col($(featuresCol))).limit(1).rdd
- .map { case Row(features: Vector) =>
- features.size
- }.first()
+ val numFeatures = dataset.select(col($(featuresCol))).first().getAs[Vector](0).size
if (numFeatures > WeightedLeastSquares.MAX_NUM_FEATURES) {
val msg = "Currently, GeneralizedLinearRegression only supports number of features" +
s" <= ${WeightedLeastSquares.MAX_NUM_FEATURES}. Found $numFeatures in the input dataset."
@@ -294,7 +296,7 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine
override def load(path: String): GeneralizedLinearRegression = super.load(path)
/** Set of family and link pairs that GeneralizedLinearRegression supports. */
- private[ml] lazy val supportedFamilyAndLinkPairs = Set(
+ private[regression] lazy val supportedFamilyAndLinkPairs = Set(
Gaussian -> Identity, Gaussian -> Log, Gaussian -> Inverse,
Binomial -> Logit, Binomial -> Probit, Binomial -> CLogLog,
Poisson -> Log, Poisson -> Identity, Poisson -> Sqrt,
@@ -302,17 +304,17 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine
)
/** Set of family names that GeneralizedLinearRegression supports. */
- private[ml] lazy val supportedFamilyNames = supportedFamilyAndLinkPairs.map(_._1.name)
+ private[regression] lazy val supportedFamilyNames = supportedFamilyAndLinkPairs.map(_._1.name)
/** Set of link names that GeneralizedLinearRegression supports. */
- private[ml] lazy val supportedLinkNames = supportedFamilyAndLinkPairs.map(_._2.name)
+ private[regression] lazy val supportedLinkNames = supportedFamilyAndLinkPairs.map(_._2.name)
- private[ml] val epsilon: Double = 1E-16
+ private[regression] val epsilon: Double = 1E-16
/**
* Wrapper of family and link combination used in the model.
*/
- private[ml] class FamilyAndLink(val family: Family, val link: Link) extends Serializable {
+ private[regression] class FamilyAndLink(val family: Family, val link: Link) extends Serializable {
/** Linear predictor based on given mu. */
def predict(mu: Double): Double = link.link(family.project(mu))
@@ -359,7 +361,7 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine
*
* @param name the name of the family.
*/
- private[ml] abstract class Family(val name: String) extends Serializable {
+ private[regression] abstract class Family(val name: String) extends Serializable {
/** The default link instance of this family. */
val defaultLink: Link
@@ -391,7 +393,7 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine
def project(mu: Double): Double = mu
}
- private[ml] object Family {
+ private[regression] object Family {
/**
* Gets the [[Family]] object from its name.
@@ -412,7 +414,7 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine
* Gaussian exponential family distribution.
* The default link for the Gaussian family is the identity link.
*/
- private[ml] object Gaussian extends Family("gaussian") {
+ private[regression] object Gaussian extends Family("gaussian") {
val defaultLink: Link = Identity
@@ -448,7 +450,7 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine
* Binomial exponential family distribution.
* The default link for the Binomial family is the logit link.
*/
- private[ml] object Binomial extends Family("binomial") {
+ private[regression] object Binomial extends Family("binomial") {
val defaultLink: Link = Logit
@@ -492,7 +494,7 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine
* Poisson exponential family distribution.
* The default link for the Poisson family is the log link.
*/
- private[ml] object Poisson extends Family("poisson") {
+ private[regression] object Poisson extends Family("poisson") {
val defaultLink: Link = Log
@@ -533,7 +535,7 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine
* Gamma exponential family distribution.
* The default link for the Gamma family is the inverse link.
*/
- private[ml] object Gamma extends Family("gamma") {
+ private[regression] object Gamma extends Family("gamma") {
val defaultLink: Link = Inverse
@@ -578,7 +580,7 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine
*
* @param name the name of link function.
*/
- private[ml] abstract class Link(val name: String) extends Serializable {
+ private[regression] abstract class Link(val name: String) extends Serializable {
/** The link function. */
def link(mu: Double): Double
@@ -590,7 +592,7 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine
def unlink(eta: Double): Double
}
- private[ml] object Link {
+ private[regression] object Link {
/**
* Gets the [[Link]] object from its name.
@@ -611,7 +613,7 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine
}
}
- private[ml] object Identity extends Link("identity") {
+ private[regression] object Identity extends Link("identity") {
override def link(mu: Double): Double = mu
@@ -620,7 +622,7 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine
override def unlink(eta: Double): Double = eta
}
- private[ml] object Logit extends Link("logit") {
+ private[regression] object Logit extends Link("logit") {
override def link(mu: Double): Double = math.log(mu / (1.0 - mu))
@@ -629,7 +631,7 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine
override def unlink(eta: Double): Double = 1.0 / (1.0 + math.exp(-1.0 * eta))
}
- private[ml] object Log extends Link("log") {
+ private[regression] object Log extends Link("log") {
override def link(mu: Double): Double = math.log(mu)
@@ -638,7 +640,7 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine
override def unlink(eta: Double): Double = math.exp(eta)
}
- private[ml] object Inverse extends Link("inverse") {
+ private[regression] object Inverse extends Link("inverse") {
override def link(mu: Double): Double = 1.0 / mu
@@ -647,7 +649,7 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine
override def unlink(eta: Double): Double = 1.0 / eta
}
- private[ml] object Probit extends Link("probit") {
+ private[regression] object Probit extends Link("probit") {
override def link(mu: Double): Double = dist.Gaussian(0.0, 1.0).icdf(mu)
@@ -658,7 +660,7 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine
override def unlink(eta: Double): Double = dist.Gaussian(0.0, 1.0).cdf(eta)
}
- private[ml] object CLogLog extends Link("cloglog") {
+ private[regression] object CLogLog extends Link("cloglog") {
override def link(mu: Double): Double = math.log(-1.0 * math.log(1 - mu))
@@ -667,7 +669,7 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine
override def unlink(eta: Double): Double = 1.0 - math.exp(-1.0 * math.exp(eta))
}
- private[ml] object Sqrt extends Link("sqrt") {
+ private[regression] object Sqrt extends Link("sqrt") {
override def link(mu: Double): Double = math.sqrt(mu)
@@ -732,7 +734,7 @@ class GeneralizedLinearRegressionModel private[ml] (
if ($(predictionCol).nonEmpty) {
output = output.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
}
- if ($(linkPredictionCol).nonEmpty) {
+ if (hasLinkPredictionCol) {
output = output.withColumn($(linkPredictionCol), predictLinkUDF(col($(featuresCol))))
}
output.toDF()
@@ -860,7 +862,7 @@ class GeneralizedLinearRegressionSummary private[regression] (
*/
@Since("2.0.0")
val predictionCol: String = {
- if (origModel.isDefined(origModel.predictionCol) && origModel.getPredictionCol != "") {
+ if (origModel.isDefined(origModel.predictionCol) && origModel.getPredictionCol.nonEmpty) {
origModel.getPredictionCol
} else {
"prediction_" + java.util.UUID.randomUUID.toString
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala
index ba0f59e89b..d16e8e3f6b 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala
@@ -69,8 +69,8 @@ private[regression] trait IsotonicRegressionBase extends Params with HasFeatures
setDefault(isotonic -> true, featureIndex -> 0)
/** Checks whether the input has weight column. */
- protected[ml] def hasWeightCol: Boolean = {
- isDefined(weightCol) && $(weightCol) != ""
+ private[regression] def hasWeightCol: Boolean = {
+ isDefined(weightCol) && $(weightCol).nonEmpty
}
/**
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
index a702f02c91..ff1038cbf1 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
@@ -161,9 +161,7 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
override protected def train(dataset: Dataset[_]): LinearRegressionModel = {
// Extract the number of features before deciding optimization solver.
- val numFeatures = dataset.select(col($(featuresCol))).limit(1).rdd.map {
- case Row(features: Vector) => features.size
- }.first()
+ val numFeatures = dataset.select(col($(featuresCol))).first().getAs[Vector](0).size
val w = if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol))
if (($(solver) == "auto" && $(elasticNetParam) == 0.0 &&
@@ -242,7 +240,7 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
val coefficients = Vectors.sparse(numFeatures, Seq())
val intercept = yMean
- val model = new LinearRegressionModel(uid, coefficients, intercept)
+ val model = copyValues(new LinearRegressionModel(uid, coefficients, intercept))
// Handle possible missing or invalid prediction columns
val (summaryModel, predictionColName) = model.findSummaryModelAndPredictionCol()
@@ -254,7 +252,7 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
model,
Array(0D),
Array(0D))
- return copyValues(model.setSummary(trainingSummary))
+ return model.setSummary(trainingSummary)
} else {
require($(regParam) == 0.0, "The standard deviation of the label is zero. " +
"Model cannot be regularized.")