diff options
author | Liang-Chi Hsieh <viirya@gmail.com> | 2015-02-06 11:22:11 -0800 |
---|---|---|
committer | Xiangrui Meng <meng@databricks.com> | 2015-02-06 11:22:11 -0800 |
commit | 80f3bcb58f836cfe1829c85bdd349c10525c8a5e (patch) | |
tree | 9dfd9261fbbb51f731f9384b41b1bd8719a88373 /mllib/src/test | |
parent | 0d74bd7fd7b2722d08eddc5c269b8b2b6cb47635 (diff) | |
download | spark-80f3bcb58f836cfe1829c85bdd349c10525c8a5e.tar.gz spark-80f3bcb58f836cfe1829c85bdd349c10525c8a5e.tar.bz2 spark-80f3bcb58f836cfe1829c85bdd349c10525c8a5e.zip |
[SPARK-5652][Mllib] Use broadcasted weights in LogisticRegressionModel
`LogisticRegressionModel`'s `predictPoint` should directly use broadcasted weights. This pr also fixes the compilation errors of two unit test suite: `JavaLogisticRegressionSuite ` and `JavaLinearRegressionSuite`.
Author: Liang-Chi Hsieh <viirya@gmail.com>
Closes #4429 from viirya/use_bcvalue and squashes the following commits:
5a797e5 [Liang-Chi Hsieh] Use broadcasted weights. Fix compilation error.
Diffstat (limited to 'mllib/src/test')
-rw-r--r-- | mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java | 4 | ||||
-rw-r--r-- | mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java | 4 |
2 files changed, 4 insertions, 4 deletions
diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java index 26284023b0..d4b6644792 100644 --- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java @@ -84,7 +84,7 @@ public class JavaLogisticRegressionSuite implements Serializable { .setThreshold(0.6) .setProbabilityCol("myProbability"); LogisticRegressionModel model = lr.fit(dataset); - assert(model.fittingParamMap().apply(lr.maxIter()) == 10); + assert(model.fittingParamMap().apply(lr.maxIter()).equals(10)); assert(model.fittingParamMap().apply(lr.regParam()).equals(1.0)); assert(model.fittingParamMap().apply(lr.threshold()).equals(0.6)); assert(model.getThreshold() == 0.6); @@ -109,7 +109,7 @@ public class JavaLogisticRegressionSuite implements Serializable { // Call fit() with new params, and check as many params as we can. LogisticRegressionModel model2 = lr.fit(dataset, lr.maxIter().w(5), lr.regParam().w(0.1), lr.threshold().w(0.4), lr.probabilityCol().w("theProb")); - assert(model2.fittingParamMap().apply(lr.maxIter()) == 5); + assert(model2.fittingParamMap().apply(lr.maxIter()).equals(5)); assert(model2.fittingParamMap().apply(lr.regParam()).equals(0.1)); assert(model2.fittingParamMap().apply(lr.threshold()).equals(0.4)); assert(model2.getThreshold() == 0.4); diff --git a/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java b/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java index 5bd616e74d..40d5a92bb3 100644 --- a/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java @@ -76,13 +76,13 @@ public class JavaLinearRegressionSuite implements Serializable { .setMaxIter(10) .setRegParam(1.0); LinearRegressionModel model = lr.fit(dataset); - assert(model.fittingParamMap().apply(lr.maxIter()) == 10); + assert(model.fittingParamMap().apply(lr.maxIter()).equals(10)); assert(model.fittingParamMap().apply(lr.regParam()).equals(1.0)); // Call fit() with new params, and check as many params as we can. LinearRegressionModel model2 = lr.fit(dataset, lr.maxIter().w(5), lr.regParam().w(0.1), lr.predictionCol().w("thePred")); - assert(model2.fittingParamMap().apply(lr.maxIter()) == 5); + assert(model2.fittingParamMap().apply(lr.maxIter()).equals(5)); assert(model2.fittingParamMap().apply(lr.regParam()).equals(0.1)); assert(model2.getPredictionCol().equals("thePred")); } |