aboutsummaryrefslogtreecommitdiff
path: root/mllib/src
diff options
context:
space:
mode:
authorYanbo Liang <ybliang8@gmail.com>2015-03-25 17:05:56 +0000
committerSean Owen <sowen@cloudera.com>2015-03-25 17:05:56 +0000
commit10c78607b2724f5a64b0cdb966e9c5805f23919b (patch)
treebcfc179866c1927d46b2f76f8f915ffcf510adbd /mllib/src
parent64262ed99912e780b51f240a14dc98fc3cdf916d (diff)
downloadspark-10c78607b2724f5a64b0cdb966e9c5805f23919b.tar.gz
spark-10c78607b2724f5a64b0cdb966e9c5805f23919b.tar.bz2
spark-10c78607b2724f5a64b0cdb966e9c5805f23919b.zip
[SPARK-6496] [MLLIB] GeneralizedLinearAlgorithm.run(input, initialWeights) should initialize numFeatures
In GeneralizedLinearAlgorithm ```numFeatures``` is default to -1, we need to update it to correct value when we call run() to train a model. ```LogisticRegressionWithLBFGS.run(input)``` works well, but when we call ```LogisticRegressionWithLBFGS.run(input, initialWeights)``` to train multiclass classification model, it will throw exception due to the numFeatures is not updated. In this PR, we just update numFeatures at the beginning of GeneralizedLinearAlgorithm.run(input, initialWeights) and add test case. Author: Yanbo Liang <ybliang8@gmail.com> Closes #5167 from yanboliang/spark-6496 and squashes the following commits: 8131c48 [Yanbo Liang] LogisticRegressionWithLBFGS.run(input, initialWeights) should initialize numFeatures
Diffstat (limited to 'mllib/src')
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala4
-rw-r--r--mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala6
2 files changed, 10 insertions, 0 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala
index 45b9ebb4cc..9fd60ff7a0 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala
@@ -211,6 +211,10 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel]
*/
def run(input: RDD[LabeledPoint], initialWeights: Vector): M = {
+ if (numFeatures < 0) {
+ numFeatures = input.map(_.features.size).first()
+ }
+
if (input.getStorageLevel == StorageLevel.NONE) {
logWarning("The input data is not directly cached, which may hurt performance if its"
+ " parent RDDs are also uncached.")
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala
index aaa81da9e2..a26c52852c 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala
@@ -425,6 +425,12 @@ class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext with M
val model = lr.run(testRDD)
+ val numFeatures = testRDD.map(_.features.size).first()
+ val initialWeights = Vectors.dense(new Array[Double]((numFeatures + 1) * 2))
+ val model2 = lr.run(testRDD, initialWeights)
+
+ LogisticRegressionSuite.checkModelsEqual(model, model2)
+
/**
* The following is the instruction to reproduce the model using R's glmnet package.
*