aboutsummaryrefslogtreecommitdiff
path: root/mllib/src
diff options
context:
space:
mode:
authorYanbo Liang <ybliang8@gmail.com>2017-01-21 21:15:57 -0800
committerYanbo Liang <ybliang8@gmail.com>2017-01-21 21:15:57 -0800
commit3dcad9fab17297f9966026f29fefb5c726965a13 (patch)
tree0547749fcff5c94d48b0320f4e626698eb0b9273 /mllib/src
parentaa014eb74bec332ca4d734f2501a4a01a806fa37 (diff)
downloadspark-3dcad9fab17297f9966026f29fefb5c726965a13.tar.gz
spark-3dcad9fab17297f9966026f29fefb5c726965a13.tar.bz2
spark-3dcad9fab17297f9966026f29fefb5c726965a13.zip
[SPARK-19155][ML] MLlib GeneralizedLinearRegression family and link should be case insensitive
## What changes were proposed in this pull request?

MLlib `GeneralizedLinearRegression` `family` and `link` should be case insensitive. This is consistent with some other MLlib params such as [`featureSubsetStrategy`](https://github.com/apache/spark/blob/master/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala#L415).

## How was this patch tested?

Update corresponding tests.

Author: Yanbo Liang <ybliang8@gmail.com>

Closes #16516 from yanboliang/spark-19133.
Diffstat (limited to 'mllib/src')
-rw-r--r-- mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala | 8
-rw-r--r-- mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala | 4
2 files changed, 6 insertions, 6 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
index a32302bf5d..116f0f6507 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
@@ -57,7 +57,7 @@ private[regression] trait GeneralizedLinearRegressionBase extends PredictorParam
final val family: Param[String] = new Param(this, "family",
"The name of family which is a description of the error distribution to be used in the " +
s"model. Supported options: ${supportedFamilyNames.mkString(", ")}.",
- ParamValidators.inArray[String](supportedFamilyNames.toArray))
+ (value: String) => supportedFamilyNames.contains(value.toLowerCase))
/** @group getParam */
@Since("2.0.0")
@@ -74,7 +74,7 @@ private[regression] trait GeneralizedLinearRegressionBase extends PredictorParam
final val link: Param[String] = new Param(this, "link", "The name of link function " +
"which provides the relationship between the linear predictor and the mean of the " +
s"distribution function. Supported options: ${supportedLinkNames.mkString(", ")}",
- ParamValidators.inArray[String](supportedLinkNames.toArray))
+ (value: String) => supportedLinkNames.contains(value.toLowerCase))
/** @group getParam */
@Since("2.0.0")
@@ -414,7 +414,7 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine
* @param name family name: "gaussian", "binomial", "poisson" or "gamma".
*/
def fromName(name: String): Family = {
- name match {
+ name.toLowerCase match {
case Gaussian.name => Gaussian
case Binomial.name => Binomial
case Poisson.name => Poisson
@@ -626,7 +626,7 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine
* "inverse", "probit", "cloglog" or "sqrt".
*/
def fromName(name: String): Link = {
- name match {
+ name.toLowerCase match {
case Identity.name => Identity
case Logit.name => Logit
case Log.name => Log
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala
index ed24c1e16a..9f3d643c2b 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala
@@ -553,7 +553,7 @@ class GeneralizedLinearRegressionSuite
for ((link, dataset) <- Seq(("inverse", datasetGammaInverse),
("identity", datasetGammaIdentity), ("log", datasetGammaLog))) {
for (fitIntercept <- Seq(false, true)) {
- val trainer = new GeneralizedLinearRegression().setFamily("gamma").setLink(link)
+ val trainer = new GeneralizedLinearRegression().setFamily("Gamma").setLink(link)
.setFitIntercept(fitIntercept).setLinkPredictionCol("linkPrediction")
val model = trainer.fit(dataset)
val actual = Vectors.dense(model.intercept, model.coefficients(0), model.coefficients(1))
@@ -990,7 +990,7 @@ class GeneralizedLinearRegressionSuite
-0.6344390 0.3172195 0.2114797 -0.1586097
*/
val trainer = new GeneralizedLinearRegression()
- .setFamily("gamma")
+ .setFamily("Gamma")
.setWeightCol("weight")
val model = trainer.fit(datasetWithWeight)