From 871764c6ce531af5b1ac7ccccb32e7a903b59a2a Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Sat, 5 Sep 2015 00:04:00 -1000 Subject: [SPARK-10013] [ML] [JAVA] [TEST] remove java assert from java unit tests From Jira: We should use assertTrue, etc. instead to make sure the asserts are not ignored in tests. Author: Holden Karau Closes #8607 from holdenk/SPARK-10013-remove-java-assert-from-java-unit-tests. --- .../JavaLogisticRegressionSuite.java | 51 +++++++++++----------- .../ml/classification/JavaNaiveBayesSuite.java | 13 +++--- .../ml/regression/JavaLinearRegressionSuite.java | 2 +- .../spark/mllib/linalg/JavaMatricesSuite.java | 40 ++++++++--------- 4 files changed, 54 insertions(+), 52 deletions(-) (limited to 'mllib') diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java index 618b95b9bd..fd22eb6dca 100644 --- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java @@ -22,6 +22,7 @@ import java.lang.Math; import java.util.List; import org.junit.After; +import org.junit.Assert; import org.junit.Before; import org.junit.Test; @@ -63,16 +64,16 @@ public class JavaLogisticRegressionSuite implements Serializable { @Test public void logisticRegressionDefaultParams() { LogisticRegression lr = new LogisticRegression(); - assert(lr.getLabelCol().equals("label")); + Assert.assertEquals(lr.getLabelCol(), "label"); LogisticRegressionModel model = lr.fit(dataset); model.transform(dataset).registerTempTable("prediction"); DataFrame predictions = jsql.sql("SELECT label, probability, prediction FROM prediction"); predictions.collectAsList(); // Check defaults - assert(model.getThreshold() == 0.5); - assert(model.getFeaturesCol().equals("features")); - assert(model.getPredictionCol().equals("prediction")); - assert(model.getProbabilityCol().equals("probability")); + Assert.assertEquals(0.5, model.getThreshold(), eps); + Assert.assertEquals("features", model.getFeaturesCol()); + Assert.assertEquals("prediction", model.getPredictionCol()); + Assert.assertEquals("probability", model.getProbabilityCol()); } @Test @@ -85,19 +86,19 @@ public class JavaLogisticRegressionSuite implements Serializable { .setProbabilityCol("myProbability"); LogisticRegressionModel model = lr.fit(dataset); LogisticRegression parent = (LogisticRegression) model.parent(); - assert(parent.getMaxIter() == 10); - assert(parent.getRegParam() == 1.0); - assert(parent.getThresholds()[0] == 0.4); - assert(parent.getThresholds()[1] == 0.6); - assert(parent.getThreshold() == 0.6); - assert(model.getThreshold() == 0.6); + Assert.assertEquals(10, parent.getMaxIter()); + Assert.assertEquals(1.0, parent.getRegParam(), eps); + Assert.assertEquals(0.4, parent.getThresholds()[0], eps); + Assert.assertEquals(0.6, parent.getThresholds()[1], eps); + Assert.assertEquals(0.6, parent.getThreshold(), eps); + Assert.assertEquals(0.6, model.getThreshold(), eps); // Modify model params, and check that the params worked. model.setThreshold(1.0); model.transform(dataset).registerTempTable("predAllZero"); DataFrame predAllZero = jsql.sql("SELECT prediction, myProbability FROM predAllZero"); for (Row r: predAllZero.collectAsList()) { - assert(r.getDouble(0) == 0.0); + Assert.assertEquals(0.0, r.getDouble(0), eps); } // Call transform with params, and check that the params worked. model.transform(dataset, model.threshold().w(0.0), model.probabilityCol().w("myProb")) @@ -107,17 +108,17 @@ public class JavaLogisticRegressionSuite implements Serializable { for (Row r: predNotAllZero.collectAsList()) { if (r.getDouble(0) != 0.0) foundNonZero = true; } - assert(foundNonZero); + Assert.assertTrue(foundNonZero); // Call fit() with new params, and check as many params as we can. LogisticRegressionModel model2 = lr.fit(dataset, lr.maxIter().w(5), lr.regParam().w(0.1), lr.threshold().w(0.4), lr.probabilityCol().w("theProb")); LogisticRegression parent2 = (LogisticRegression) model2.parent(); - assert(parent2.getMaxIter() == 5); - assert(parent2.getRegParam() == 0.1); - assert(parent2.getThreshold() == 0.4); - assert(model2.getThreshold() == 0.4); - assert(model2.getProbabilityCol().equals("theProb")); + Assert.assertEquals(5, parent2.getMaxIter()); + Assert.assertEquals(0.1, parent2.getRegParam(), eps); + Assert.assertEquals(0.4, parent2.getThreshold(), eps); + Assert.assertEquals(0.4, model2.getThreshold(), eps); + Assert.assertEquals("theProb", model2.getProbabilityCol()); } @SuppressWarnings("unchecked") @@ -125,18 +126,18 @@ public class JavaLogisticRegressionSuite implements Serializable { public void logisticRegressionPredictorClassifierMethods() { LogisticRegression lr = new LogisticRegression(); LogisticRegressionModel model = lr.fit(dataset); - assert(model.numClasses() == 2); + Assert.assertEquals(2, model.numClasses()); model.transform(dataset).registerTempTable("transformed"); DataFrame trans1 = jsql.sql("SELECT rawPrediction, probability FROM transformed"); for (Row row: trans1.collect()) { Vector raw = (Vector)row.get(0); Vector prob = (Vector)row.get(1); - assert(raw.size() == 2); - assert(prob.size() == 2); + Assert.assertEquals(raw.size(), 2); + Assert.assertEquals(prob.size(), 2); double probFromRaw1 = 1.0 / (1.0 + Math.exp(-raw.apply(1))); - assert(Math.abs(prob.apply(1) - probFromRaw1) < eps); - assert(Math.abs(prob.apply(0) - (1.0 - probFromRaw1)) < eps); + Assert.assertEquals(0, Math.abs(prob.apply(1) - probFromRaw1), eps); + Assert.assertEquals(0, Math.abs(prob.apply(0) - (1.0 - probFromRaw1)), eps); } DataFrame trans2 = jsql.sql("SELECT prediction, probability FROM transformed"); @@ -145,7 +146,7 @@ public class JavaLogisticRegressionSuite implements Serializable { Vector prob = (Vector)row.get(1); double probOfPred = prob.apply((int)pred); for (int i = 0; i < prob.size(); ++i) { - assert(probOfPred >= prob.apply(i)); + Assert.assertTrue(probOfPred >= prob.apply(i)); } } } @@ -156,6 +157,6 @@ public class JavaLogisticRegressionSuite implements Serializable { LogisticRegressionModel model = lr.fit(dataset); LogisticRegressionTrainingSummary summary = model.summary(); - assert(summary.totalIterations() == summary.objectiveHistory().length); + Assert.assertEquals(summary.totalIterations(), summary.objectiveHistory().length); } } diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaNaiveBayesSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaNaiveBayesSuite.java index 8fd7bf55a2..075a62c493 100644 --- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaNaiveBayesSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaNaiveBayesSuite.java @@ -23,6 +23,7 @@ import java.util.Arrays; import org.junit.After; import org.junit.Before; import org.junit.Test; +import static org.junit.Assert.assertEquals; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; @@ -58,18 +59,18 @@ public class JavaNaiveBayesSuite implements Serializable { for (Row r : predictionAndLabels.collect()) { double prediction = r.getAs(0); double label = r.getAs(1); - assert(prediction == label); + assertEquals(label, prediction, 1E-5); } } @Test public void naiveBayesDefaultParams() { NaiveBayes nb = new NaiveBayes(); - assert(nb.getLabelCol() == "label"); - assert(nb.getFeaturesCol() == "features"); - assert(nb.getPredictionCol() == "prediction"); - assert(nb.getSmoothing() == 1.0); - assert(nb.getModelType() == "multinomial"); + assertEquals("label", nb.getLabelCol()); + assertEquals("features", nb.getFeaturesCol()); + assertEquals("prediction", nb.getPredictionCol()); + assertEquals(1.0, nb.getSmoothing(), 1E-5); + assertEquals("multinomial", nb.getModelType()); } @Test diff --git a/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java b/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java index d591a45686..91c589d00a 100644 --- a/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java @@ -60,7 +60,7 @@ public class JavaLinearRegressionSuite implements Serializable { @Test public void linearRegressionDefaultParams() { LinearRegression lr = new LinearRegression(); - assert(lr.getLabelCol().equals("label")); + assertEquals("label", lr.getLabelCol()); LinearRegressionModel model = lr.fit(dataset); model.transform(dataset).registerTempTable("prediction"); DataFrame predictions = jsql.sql("SELECT label, prediction FROM prediction"); diff --git a/mllib/src/test/java/org/apache/spark/mllib/linalg/JavaMatricesSuite.java b/mllib/src/test/java/org/apache/spark/mllib/linalg/JavaMatricesSuite.java index 3349c50224..8beea102ef 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/linalg/JavaMatricesSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/linalg/JavaMatricesSuite.java @@ -80,10 +80,10 @@ public class JavaMatricesSuite implements Serializable { assertArrayEquals(sd.toArray(), s.toArray(), 0.0); assertArrayEquals(s.toArray(), ss.toArray(), 0.0); assertArrayEquals(s.values(), ss.values(), 0.0); - assert(s.values().length == 2); - assert(ss.values().length == 2); - assert(s.colPtrs().length == 4); - assert(ss.colPtrs().length == 4); + assertEquals(2, s.values().length); + assertEquals(2, ss.values().length); + assertEquals(4, s.colPtrs().length); + assertEquals(4, ss.colPtrs().length); } @Test @@ -137,27 +137,27 @@ public class JavaMatricesSuite implements Serializable { Matrix deHorz2 = Matrices.horzcat(new Matrix[]{spMat1, deMat2}); Matrix deHorz3 = Matrices.horzcat(new Matrix[]{deMat1, spMat2}); - assert(deHorz1.numRows() == 3); - assert(deHorz2.numRows() == 3); - assert(deHorz3.numRows() == 3); - assert(spHorz.numRows() == 3); - assert(deHorz1.numCols() == 5); - assert(deHorz2.numCols() == 5); - assert(deHorz3.numCols() == 5); - assert(spHorz.numCols() == 5); + assertEquals(3, deHorz1.numRows()); + assertEquals(3, deHorz2.numRows()); + assertEquals(3, deHorz3.numRows()); + assertEquals(3, spHorz.numRows()); + assertEquals(5, deHorz1.numCols()); + assertEquals(5, deHorz2.numCols()); + assertEquals(5, deHorz3.numCols()); + assertEquals(5, spHorz.numCols()); Matrix spVert = Matrices.vertcat(new Matrix[]{spMat1, spMat3}); Matrix deVert1 = Matrices.vertcat(new Matrix[]{deMat1, deMat3}); Matrix deVert2 = Matrices.vertcat(new Matrix[]{spMat1, deMat3}); Matrix deVert3 = Matrices.vertcat(new Matrix[]{deMat1, spMat3}); - assert(deVert1.numRows() == 5); - assert(deVert2.numRows() == 5); - assert(deVert3.numRows() == 5); - assert(spVert.numRows() == 5); - assert(deVert1.numCols() == 2); - assert(deVert2.numCols() == 2); - assert(deVert3.numCols() == 2); - assert(spVert.numCols() == 2); + assertEquals(5, deVert1.numRows()); + assertEquals(5, deVert2.numRows()); + assertEquals(5, deVert3.numRows()); + assertEquals(5, spVert.numRows()); + assertEquals(2, deVert1.numCols()); + assertEquals(2, deVert2.numCols()); + assertEquals(2, deVert3.numCols()); + assertEquals(2, spVert.numCols()); } } -- cgit v1.2.3