diff options
author | Liang-Chi Hsieh <viirya@gmail.com> | 2015-02-06 11:22:11 -0800 |
---|---|---|
committer | Xiangrui Meng <meng@databricks.com> | 2015-02-06 11:22:19 -0800 |
commit | 6fda4c136de2a0036e460ef00f60416caabb3ed9 (patch) | |
tree | 9dfd9261fbbb51f731f9384b41b1bd8719a88373 /mllib/src/main | |
parent | 93fee7b903972daa17761bfcdebe1de2e549240d (diff) | |
download | spark-6fda4c136de2a0036e460ef00f60416caabb3ed9.tar.gz spark-6fda4c136de2a0036e460ef00f60416caabb3ed9.tar.bz2 spark-6fda4c136de2a0036e460ef00f60416caabb3ed9.zip |
[SPARK-5652][Mllib] Use broadcasted weights in LogisticRegressionModel
`LogisticRegressionModel`'s `predictPoint` should directly use broadcasted weights. This pr also fixes the compilation errors of two unit test suite: `JavaLogisticRegressionSuite ` and `JavaLinearRegressionSuite`.
Author: Liang-Chi Hsieh <viirya@gmail.com>
Closes #4429 from viirya/use_bcvalue and squashes the following commits:
5a797e5 [Liang-Chi Hsieh] Use broadcasted weights. Fix compilation error.
(cherry picked from commit 80f3bcb58f836cfe1829c85bdd349c10525c8a5e)
Signed-off-by: Xiangrui Meng <meng@databricks.com>
Diffstat (limited to 'mllib/src/main')
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala | 8 |
1 files changed, 4 insertions, 4 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala index a668e7a7a3..9a391bfff7 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala @@ -33,7 +33,7 @@ import org.apache.spark.rdd.RDD * * @param weights Weights computed for every feature. * @param intercept Intercept computed for this model. (Only used in Binary Logistic Regression. - * In Multinomial Logistic Regression, the intercepts will not be a single values, + * In Multinomial Logistic Regression, the intercepts will not be a single value, * so the intercepts will be part of the weights.) * @param numFeatures the dimension of the features. * @param numClasses the number of possible outcomes for k classes classification problem in @@ -107,7 +107,7 @@ class LogisticRegressionModel ( // If dataMatrix and weightMatrix have the same dimension, it's binary logistic regression. if (numClasses == 2) { require(numFeatures == weightMatrix.size) - val margin = dot(weights, dataMatrix) + intercept + val margin = dot(weightMatrix, dataMatrix) + intercept val score = 1.0 / (1.0 + math.exp(-margin)) threshold match { case Some(t) => if (score > t) 1.0 else 0.0 @@ -116,11 +116,11 @@ class LogisticRegressionModel ( } else { val dataWithBiasSize = weightMatrix.size / (numClasses - 1) - val weightsArray = weights match { + val weightsArray = weightMatrix match { case dv: DenseVector => dv.values case _ => throw new IllegalArgumentException( - s"weights only supports dense vector but got type ${weights.getClass}.") + s"weights only supports dense vector but got type ${weightMatrix.getClass}.") } val margins = (0 until numClasses - 1).map { i => |