diff options
author | Gary King <gary@idibon.com> | 2016-02-07 09:13:28 +0000 |
---|---|---|
committer | Sean Owen <sowen@cloudera.com> | 2016-02-07 09:13:28 +0000 |
commit | bc8890b357811612ba6c10d96374902b9e08134f (patch) | |
tree | de2ad39d76c48718a7faf5caa995ae6d259f51e2 | |
parent | 81da3bee669aaeb79ec68baaf7c99bff6e5d14fe (diff) | |
download | spark-bc8890b357811612ba6c10d96374902b9e08134f.tar.gz spark-bc8890b357811612ba6c10d96374902b9e08134f.tar.bz2 spark-bc8890b357811612ba6c10d96374902b9e08134f.zip |
[SPARK-13132][MLLIB] cache standardization param value in LogisticRegression
cache the value of the standardization Param in LogisticRegression, rather than re-fetching it from the ParamMap for every index and every optimization step in the quasi-newton optimizer
also, fix Param#toString to cache the stringified representation, rather than re-interpolating it on every call, so any other implementations that have similar repeated access patterns will see a benefit.
this change improves training times for one of my test sets from ~7m30s to ~4m30s
Author: Gary King <gary@idibon.com>
Closes #11027 from idigary/spark-13132-optimize-logistic-regression.
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala | 3 | ||||
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/ml/param/params.scala | 4 |
2 files changed, 5 insertions, 2 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index 9b2340a1f1..ac0124513f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -332,12 +332,13 @@ class LogisticRegression @Since("1.2.0") ( val optimizer = if ($(elasticNetParam) == 0.0 || $(regParam) == 0.0) { new BreezeLBFGS[BDV[Double]]($(maxIter), 10, $(tol)) } else { + val standardizationParam = $(standardization) def regParamL1Fun = (index: Int) => { // Remove the L1 penalization on the intercept if (index == numFeatures) { 0.0 } else { - if ($(standardization)) { + if (standardizationParam) { regParamL1 } else { // If `standardization` is false, we still standardize the data diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala index f48923d699..d7d6c0f5fa 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala @@ -117,7 +117,9 @@ class Param[T](val parent: String, val name: String, val doc: String, val isVali } } - override final def toString: String = s"${parent}__$name" + private[this] val stringRepresentation = s"${parent}__$name" + + override final def toString: String = stringRepresentation override final def hashCode: Int = toString.## |