aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
authorDB Tsai <dbt@netflix.com>2015-06-02 19:12:08 -0700
committerJoseph K. Bradley <joseph@databricks.com>2015-06-02 19:12:08 -0700
commita86b3e9b9b75f5af4fdbba22e87769058f023204 (patch)
treec44cf27beb8bfc2e93efede89c9e2755b1789629 /mllib
parentc3f4c3257194ba34ccd298d13ea1edcfc75f7552 (diff)
downloadspark-a86b3e9b9b75f5af4fdbba22e87769058f023204.tar.gz
spark-a86b3e9b9b75f5af4fdbba22e87769058f023204.tar.bz2
spark-a86b3e9b9b75f5af4fdbba22e87769058f023204.zip
[SPARK-7547] [ML] Scala Example code for ElasticNet
This is scala example code for both linear and logistic regression. Python and Java versions are to be added. Author: DB Tsai <dbt@netflix.com> Closes #6576 from dbtsai/elasticNetExample and squashes the following commits: e7ca406 [DB Tsai] fix test 6bb6d77 [DB Tsai] fix suite and remove duplicated setMaxIter 136e0dd [DB Tsai] address feedback 1ec29d4 [DB Tsai] fix style 9462f5f [DB Tsai] add example
Diffstat (limited to 'mllib')
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala8
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala2
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala4
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala2
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala6
5 files changed, 13 insertions, 9 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
index d13109d9da..f136bcee9c 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
@@ -74,7 +74,7 @@ class LogisticRegression(override val uid: String)
setDefault(elasticNetParam -> 0.0)
/**
- * Set the maximal number of iterations.
+ * Set the maximum number of iterations.
* Default is 100.
* @group setParam
*/
@@ -90,7 +90,11 @@ class LogisticRegression(override val uid: String)
def setTol(value: Double): this.type = set(tol, value)
setDefault(tol -> 1E-6)
- /** @group setParam */
+ /**
+ * Whether to fit an intercept term.
+ * Default is true.
+ * @group setParam
+ * */
def setFitIntercept(value: Boolean): this.type = set(fitIntercept, value)
setDefault(fitIntercept -> true)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
index 1ffb5eddc3..8ffbcf0d8b 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
@@ -33,7 +33,7 @@ private[shared] object SharedParamsCodeGen {
val params = Seq(
ParamDesc[Double]("regParam", "regularization parameter (>= 0)",
isValid = "ParamValidators.gtEq(0)"),
- ParamDesc[Int]("maxIter", "max number of iterations (>= 0)",
+ ParamDesc[Int]("maxIter", "maximum number of iterations (>= 0)",
isValid = "ParamValidators.gtEq(0)"),
ParamDesc[String]("featuresCol", "features column name", Some("\"features\"")),
ParamDesc[String]("labelCol", "label column name", Some("\"label\"")),
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
index ed08417bd4..a0c8ccdac9 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
@@ -45,10 +45,10 @@ private[ml] trait HasRegParam extends Params {
private[ml] trait HasMaxIter extends Params {
/**
- * Param for max number of iterations (>= 0).
+ * Param for maximum number of iterations (>= 0).
* @group param
*/
- final val maxIter: IntParam = new IntParam(this, "maxIter", "max number of iterations (>= 0)", ParamValidators.gtEq(0))
+ final val maxIter: IntParam = new IntParam(this, "maxIter", "maximum number of iterations (>= 0)", ParamValidators.gtEq(0))
/** @group getParam */
final def getMaxIter: Int = $(maxIter)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
index fe2a71a331..70cd8e9e87 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
@@ -83,7 +83,7 @@ class LinearRegression(override val uid: String)
setDefault(elasticNetParam -> 0.0)
/**
- * Set the maximal number of iterations.
+ * Set the maximum number of iterations.
* Default is 100.
* @group setParam
*/
diff --git a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala
index f80e774909..96094d7a09 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala
@@ -27,7 +27,7 @@ class ParamsSuite extends SparkFunSuite {
import solver.{maxIter, inputCol}
assert(maxIter.name === "maxIter")
- assert(maxIter.doc === "max number of iterations (>= 0)")
+ assert(maxIter.doc === "maximum number of iterations (>= 0)")
assert(maxIter.parent === uid)
assert(maxIter.toString === s"${uid}__maxIter")
assert(!maxIter.isValid(-1))
@@ -36,7 +36,7 @@ class ParamsSuite extends SparkFunSuite {
solver.setMaxIter(5)
assert(solver.explainParam(maxIter) ===
- "maxIter: max number of iterations (>= 0) (default: 10, current: 5)")
+ "maxIter: maximum number of iterations (>= 0) (default: 10, current: 5)")
assert(inputCol.toString === s"${uid}__inputCol")
@@ -120,7 +120,7 @@ class ParamsSuite extends SparkFunSuite {
intercept[NoSuchElementException](solver.getInputCol)
assert(solver.explainParam(maxIter) ===
- "maxIter: max number of iterations (>= 0) (default: 10, current: 100)")
+ "maxIter: maximum number of iterations (>= 0) (default: 10, current: 100)")
assert(solver.explainParams() ===
Seq(inputCol, maxIter).map(solver.explainParam).mkString("\n"))