aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
authorXiangrui Meng <meng@databricks.com>2015-02-17 10:17:45 -0800
committerXiangrui Meng <meng@databricks.com>2015-02-17 10:17:45 -0800
commitc76da36c2163276b5c34e59fbb139eeb34ed0faa (patch)
treeded3c1ff10a27cd88bc8ba81b4a71b2ae84a8aa7 /mllib
parent3ce46e94fe77d15f18e916b76b37fa96356ace93 (diff)
downloadspark-c76da36c2163276b5c34e59fbb139eeb34ed0faa.tar.gz
spark-c76da36c2163276b5c34e59fbb139eeb34ed0faa.tar.bz2
spark-c76da36c2163276b5c34e59fbb139eeb34ed0faa.zip
[SPARK-5858][MLLIB] Remove unnecessary first() call in GLM
`numFeatures` is only used by multinomial logistic regression. Calling `.first()` for every GLM causes performance regression, especially in Python. Author: Xiangrui Meng <meng@databricks.com> Closes #4647 from mengxr/SPARK-5858 and squashes the following commits: 036dc7f [Xiangrui Meng] remove unnecessary first() call 12c5548 [Xiangrui Meng] check numFeatures only once
Diffstat (limited to 'mllib')
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala6
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala7
2 files changed, 9 insertions, 4 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala
index 420d6e2861..b787667b01 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala
@@ -355,6 +355,10 @@ class LogisticRegressionWithLBFGS
}
override protected def createModel(weights: Vector, intercept: Double) = {
- new LogisticRegressionModel(weights, intercept, numFeatures, numOfLinearPredictor + 1)
+ if (numOfLinearPredictor == 1) {
+ new LogisticRegressionModel(weights, intercept)
+ } else {
+ new LogisticRegressionModel(weights, intercept, numFeatures, numOfLinearPredictor + 1)
+ }
}
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala
index 2b7145362a..7c66e8cdeb 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala
@@ -126,7 +126,7 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel]
/**
* The dimension of training features.
*/
- protected var numFeatures: Int = 0
+ protected var numFeatures: Int = -1
/**
* Set if the algorithm should use feature scaling to improve the convergence during optimization.
@@ -163,7 +163,9 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel]
* RDD of LabeledPoint entries.
*/
def run(input: RDD[LabeledPoint]): M = {
- numFeatures = input.first().features.size
+ if (numFeatures < 0) {
+ numFeatures = input.map(_.features.size).first()
+ }
/**
* When `numOfLinearPredictor > 1`, the intercepts are encapsulated into weights,
@@ -193,7 +195,6 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel]
* of LabeledPoint entries starting from the initial weights provided.
*/
def run(input: RDD[LabeledPoint], initialWeights: Vector): M = {
- numFeatures = input.first().features.size
if (input.getStorageLevel == StorageLevel.NONE) {
logWarning("The input data is not directly cached, which may hurt performance if its"