aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
authorHolden Karau <holden@pigscanfly.ca>2015-09-05 00:04:00 -1000
committerReynold Xin <rxin@databricks.com>2015-09-05 00:04:00 -1000
commit871764c6ce531af5b1ac7ccccb32e7a903b59a2a (patch)
tree7699e590f0962d61291c047360b61a90d9d76c0e /mllib
parentbca8c072bd710beda6cfac1533a67f32f579b134 (diff)
downloadspark-871764c6ce531af5b1ac7ccccb32e7a903b59a2a.tar.gz
spark-871764c6ce531af5b1ac7ccccb32e7a903b59a2a.tar.bz2
spark-871764c6ce531af5b1ac7ccccb32e7a903b59a2a.zip
[SPARK-10013] [ML] [JAVA] [TEST] remove java assert from java unit tests
From Jira: We should use assertTrue, etc. instead to make sure the asserts are not ignored in tests. Author: Holden Karau <holden@pigscanfly.ca> Closes #8607 from holdenk/SPARK-10013-remove-java-assert-from-java-unit-tests.
Diffstat (limited to 'mllib')
-rw-r--r--mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java51
-rw-r--r--mllib/src/test/java/org/apache/spark/ml/classification/JavaNaiveBayesSuite.java13
-rw-r--r--mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java2
-rw-r--r--mllib/src/test/java/org/apache/spark/mllib/linalg/JavaMatricesSuite.java40
4 files changed, 54 insertions, 52 deletions
diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java
index 618b95b9bd..fd22eb6dca 100644
--- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java
@@ -22,6 +22,7 @@ import java.lang.Math;
import java.util.List;
import org.junit.After;
+import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
@@ -63,16 +64,16 @@ public class JavaLogisticRegressionSuite implements Serializable {
@Test
public void logisticRegressionDefaultParams() {
LogisticRegression lr = new LogisticRegression();
- assert(lr.getLabelCol().equals("label"));
+ Assert.assertEquals(lr.getLabelCol(), "label");
LogisticRegressionModel model = lr.fit(dataset);
model.transform(dataset).registerTempTable("prediction");
DataFrame predictions = jsql.sql("SELECT label, probability, prediction FROM prediction");
predictions.collectAsList();
// Check defaults
- assert(model.getThreshold() == 0.5);
- assert(model.getFeaturesCol().equals("features"));
- assert(model.getPredictionCol().equals("prediction"));
- assert(model.getProbabilityCol().equals("probability"));
+ Assert.assertEquals(0.5, model.getThreshold(), eps);
+ Assert.assertEquals("features", model.getFeaturesCol());
+ Assert.assertEquals("prediction", model.getPredictionCol());
+ Assert.assertEquals("probability", model.getProbabilityCol());
}
@Test
@@ -85,19 +86,19 @@ public class JavaLogisticRegressionSuite implements Serializable {
.setProbabilityCol("myProbability");
LogisticRegressionModel model = lr.fit(dataset);
LogisticRegression parent = (LogisticRegression) model.parent();
- assert(parent.getMaxIter() == 10);
- assert(parent.getRegParam() == 1.0);
- assert(parent.getThresholds()[0] == 0.4);
- assert(parent.getThresholds()[1] == 0.6);
- assert(parent.getThreshold() == 0.6);
- assert(model.getThreshold() == 0.6);
+ Assert.assertEquals(10, parent.getMaxIter());
+ Assert.assertEquals(1.0, parent.getRegParam(), eps);
+ Assert.assertEquals(0.4, parent.getThresholds()[0], eps);
+ Assert.assertEquals(0.6, parent.getThresholds()[1], eps);
+ Assert.assertEquals(0.6, parent.getThreshold(), eps);
+ Assert.assertEquals(0.6, model.getThreshold(), eps);
// Modify model params, and check that the params worked.
model.setThreshold(1.0);
model.transform(dataset).registerTempTable("predAllZero");
DataFrame predAllZero = jsql.sql("SELECT prediction, myProbability FROM predAllZero");
for (Row r: predAllZero.collectAsList()) {
- assert(r.getDouble(0) == 0.0);
+ Assert.assertEquals(0.0, r.getDouble(0), eps);
}
// Call transform with params, and check that the params worked.
model.transform(dataset, model.threshold().w(0.0), model.probabilityCol().w("myProb"))
@@ -107,17 +108,17 @@ public class JavaLogisticRegressionSuite implements Serializable {
for (Row r: predNotAllZero.collectAsList()) {
if (r.getDouble(0) != 0.0) foundNonZero = true;
}
- assert(foundNonZero);
+ Assert.assertTrue(foundNonZero);
// Call fit() with new params, and check as many params as we can.
LogisticRegressionModel model2 = lr.fit(dataset, lr.maxIter().w(5), lr.regParam().w(0.1),
lr.threshold().w(0.4), lr.probabilityCol().w("theProb"));
LogisticRegression parent2 = (LogisticRegression) model2.parent();
- assert(parent2.getMaxIter() == 5);
- assert(parent2.getRegParam() == 0.1);
- assert(parent2.getThreshold() == 0.4);
- assert(model2.getThreshold() == 0.4);
- assert(model2.getProbabilityCol().equals("theProb"));
+ Assert.assertEquals(5, parent2.getMaxIter());
+ Assert.assertEquals(0.1, parent2.getRegParam(), eps);
+ Assert.assertEquals(0.4, parent2.getThreshold(), eps);
+ Assert.assertEquals(0.4, model2.getThreshold(), eps);
+ Assert.assertEquals("theProb", model2.getProbabilityCol());
}
@SuppressWarnings("unchecked")
@@ -125,18 +126,18 @@ public class JavaLogisticRegressionSuite implements Serializable {
public void logisticRegressionPredictorClassifierMethods() {
LogisticRegression lr = new LogisticRegression();
LogisticRegressionModel model = lr.fit(dataset);
- assert(model.numClasses() == 2);
+ Assert.assertEquals(2, model.numClasses());
model.transform(dataset).registerTempTable("transformed");
DataFrame trans1 = jsql.sql("SELECT rawPrediction, probability FROM transformed");
for (Row row: trans1.collect()) {
Vector raw = (Vector)row.get(0);
Vector prob = (Vector)row.get(1);
- assert(raw.size() == 2);
- assert(prob.size() == 2);
+ Assert.assertEquals(raw.size(), 2);
+ Assert.assertEquals(prob.size(), 2);
double probFromRaw1 = 1.0 / (1.0 + Math.exp(-raw.apply(1)));
- assert(Math.abs(prob.apply(1) - probFromRaw1) < eps);
- assert(Math.abs(prob.apply(0) - (1.0 - probFromRaw1)) < eps);
+ Assert.assertEquals(0, Math.abs(prob.apply(1) - probFromRaw1), eps);
+ Assert.assertEquals(0, Math.abs(prob.apply(0) - (1.0 - probFromRaw1)), eps);
}
DataFrame trans2 = jsql.sql("SELECT prediction, probability FROM transformed");
@@ -145,7 +146,7 @@ public class JavaLogisticRegressionSuite implements Serializable {
Vector prob = (Vector)row.get(1);
double probOfPred = prob.apply((int)pred);
for (int i = 0; i < prob.size(); ++i) {
- assert(probOfPred >= prob.apply(i));
+ Assert.assertTrue(probOfPred >= prob.apply(i));
}
}
}
@@ -156,6 +157,6 @@ public class JavaLogisticRegressionSuite implements Serializable {
LogisticRegressionModel model = lr.fit(dataset);
LogisticRegressionTrainingSummary summary = model.summary();
- assert(summary.totalIterations() == summary.objectiveHistory().length);
+ Assert.assertEquals(summary.totalIterations(), summary.objectiveHistory().length);
}
}
diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaNaiveBayesSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaNaiveBayesSuite.java
index 8fd7bf55a2..075a62c493 100644
--- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaNaiveBayesSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaNaiveBayesSuite.java
@@ -23,6 +23,7 @@ import java.util.Arrays;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
+import static org.junit.Assert.assertEquals;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
@@ -58,18 +59,18 @@ public class JavaNaiveBayesSuite implements Serializable {
for (Row r : predictionAndLabels.collect()) {
double prediction = r.getAs(0);
double label = r.getAs(1);
- assert(prediction == label);
+ assertEquals(label, prediction, 1E-5);
}
}
@Test
public void naiveBayesDefaultParams() {
NaiveBayes nb = new NaiveBayes();
- assert(nb.getLabelCol() == "label");
- assert(nb.getFeaturesCol() == "features");
- assert(nb.getPredictionCol() == "prediction");
- assert(nb.getSmoothing() == 1.0);
- assert(nb.getModelType() == "multinomial");
+ assertEquals("label", nb.getLabelCol());
+ assertEquals("features", nb.getFeaturesCol());
+ assertEquals("prediction", nb.getPredictionCol());
+ assertEquals(1.0, nb.getSmoothing(), 1E-5);
+ assertEquals("multinomial", nb.getModelType());
}
@Test
diff --git a/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java b/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java
index d591a45686..91c589d00a 100644
--- a/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java
@@ -60,7 +60,7 @@ public class JavaLinearRegressionSuite implements Serializable {
@Test
public void linearRegressionDefaultParams() {
LinearRegression lr = new LinearRegression();
- assert(lr.getLabelCol().equals("label"));
+ assertEquals("label", lr.getLabelCol());
LinearRegressionModel model = lr.fit(dataset);
model.transform(dataset).registerTempTable("prediction");
DataFrame predictions = jsql.sql("SELECT label, prediction FROM prediction");
diff --git a/mllib/src/test/java/org/apache/spark/mllib/linalg/JavaMatricesSuite.java b/mllib/src/test/java/org/apache/spark/mllib/linalg/JavaMatricesSuite.java
index 3349c50224..8beea102ef 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/linalg/JavaMatricesSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/linalg/JavaMatricesSuite.java
@@ -80,10 +80,10 @@ public class JavaMatricesSuite implements Serializable {
assertArrayEquals(sd.toArray(), s.toArray(), 0.0);
assertArrayEquals(s.toArray(), ss.toArray(), 0.0);
assertArrayEquals(s.values(), ss.values(), 0.0);
- assert(s.values().length == 2);
- assert(ss.values().length == 2);
- assert(s.colPtrs().length == 4);
- assert(ss.colPtrs().length == 4);
+ assertEquals(2, s.values().length);
+ assertEquals(2, ss.values().length);
+ assertEquals(4, s.colPtrs().length);
+ assertEquals(4, ss.colPtrs().length);
}
@Test
@@ -137,27 +137,27 @@ public class JavaMatricesSuite implements Serializable {
Matrix deHorz2 = Matrices.horzcat(new Matrix[]{spMat1, deMat2});
Matrix deHorz3 = Matrices.horzcat(new Matrix[]{deMat1, spMat2});
- assert(deHorz1.numRows() == 3);
- assert(deHorz2.numRows() == 3);
- assert(deHorz3.numRows() == 3);
- assert(spHorz.numRows() == 3);
- assert(deHorz1.numCols() == 5);
- assert(deHorz2.numCols() == 5);
- assert(deHorz3.numCols() == 5);
- assert(spHorz.numCols() == 5);
+ assertEquals(3, deHorz1.numRows());
+ assertEquals(3, deHorz2.numRows());
+ assertEquals(3, deHorz3.numRows());
+ assertEquals(3, spHorz.numRows());
+ assertEquals(5, deHorz1.numCols());
+ assertEquals(5, deHorz2.numCols());
+ assertEquals(5, deHorz3.numCols());
+ assertEquals(5, spHorz.numCols());
Matrix spVert = Matrices.vertcat(new Matrix[]{spMat1, spMat3});
Matrix deVert1 = Matrices.vertcat(new Matrix[]{deMat1, deMat3});
Matrix deVert2 = Matrices.vertcat(new Matrix[]{spMat1, deMat3});
Matrix deVert3 = Matrices.vertcat(new Matrix[]{deMat1, spMat3});
- assert(deVert1.numRows() == 5);
- assert(deVert2.numRows() == 5);
- assert(deVert3.numRows() == 5);
- assert(spVert.numRows() == 5);
- assert(deVert1.numCols() == 2);
- assert(deVert2.numCols() == 2);
- assert(deVert3.numCols() == 2);
- assert(spVert.numCols() == 2);
+ assertEquals(5, deVert1.numRows());
+ assertEquals(5, deVert2.numRows());
+ assertEquals(5, deVert3.numRows());
+ assertEquals(5, spVert.numRows());
+ assertEquals(2, deVert1.numCols());
+ assertEquals(2, deVert2.numCols());
+ assertEquals(2, deVert3.numCols());
+ assertEquals(2, spVert.numCols());
}
}