aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
authorGary King <gary@idibon.com>2016-02-07 09:13:28 +0000
committerSean Owen <sowen@cloudera.com>2016-02-07 09:13:28 +0000
commitbc8890b357811612ba6c10d96374902b9e08134f (patch)
treede2ad39d76c48718a7faf5caa995ae6d259f51e2 /mllib
parent81da3bee669aaeb79ec68baaf7c99bff6e5d14fe (diff)
downloadspark-bc8890b357811612ba6c10d96374902b9e08134f.tar.gz
spark-bc8890b357811612ba6c10d96374902b9e08134f.tar.bz2
spark-bc8890b357811612ba6c10d96374902b9e08134f.zip
[SPARK-13132][MLLIB] cache standardization param value in LogisticRegression
cache the value of the standardization Param in LogisticRegression, rather than re-fetching it from the ParamMap for every index and every optimization step in the quasi-newton optimizer also, fix Param#toString to cache the stringified representation, rather than re-interpolating it on every call, so any other implementations that have similar repeated access patterns will see a benefit. this change improves training times for one of my test sets from ~7m30s to ~4m30s Author: Gary King <gary@idibon.com> Closes #11027 from idigary/spark-13132-optimize-logistic-regression.
Diffstat (limited to 'mllib')
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala3
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/param/params.scala4
2 files changed, 5 insertions, 2 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
index 9b2340a1f1..ac0124513f 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
@@ -332,12 +332,13 @@ class LogisticRegression @Since("1.2.0") (
val optimizer = if ($(elasticNetParam) == 0.0 || $(regParam) == 0.0) {
new BreezeLBFGS[BDV[Double]]($(maxIter), 10, $(tol))
} else {
+ val standardizationParam = $(standardization)
def regParamL1Fun = (index: Int) => {
// Remove the L1 penalization on the intercept
if (index == numFeatures) {
0.0
} else {
- if ($(standardization)) {
+ if (standardizationParam) {
regParamL1
} else {
// If `standardization` is false, we still standardize the data
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
index f48923d699..d7d6c0f5fa 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
@@ -117,7 +117,9 @@ class Param[T](val parent: String, val name: String, val doc: String, val isVali
}
}
- override final def toString: String = s"${parent}__$name"
+ private[this] val stringRepresentation = s"${parent}__$name"
+
+ override final def toString: String = stringRepresentation
override final def hashCode: Int = toString.##