diff options
author | Sean Owen <sowen@cloudera.com> | 2015-09-12 10:40:10 +0100 |
---|---|---|
committer | Sean Owen <sowen@cloudera.com> | 2015-09-12 10:40:10 +0100 |
commit | 22730ad54d681ad30e63fe910e8d89360853177d (patch) | |
tree | 81194034499a6d391a0949e865fc0aa6dd5fc4ec /sql/core | |
parent | 8285e3b0d3dc0eff669eba993742dfe0401116f9 (diff) | |
download | spark-22730ad54d681ad30e63fe910e8d89360853177d.tar.gz spark-22730ad54d681ad30e63fe910e8d89360853177d.tar.bz2 spark-22730ad54d681ad30e63fe910e8d89360853177d.zip |
[SPARK-10547] [TEST] Streamline / improve style of Java API tests
Fix a few Java API test style issues: unused generic types, exceptions, wrong assert argument order
Author: Sean Owen <sowen@cloudera.com>
Closes #8706 from srowen/SPARK-10547.
Diffstat (limited to 'sql/core')
5 files changed, 54 insertions, 48 deletions
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java index bf693c7c39..7b50aad4ad 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java @@ -18,6 +18,7 @@ package test.org.apache.spark.sql; import java.io.Serializable; +import java.math.BigDecimal; import java.util.ArrayList; import java.util.Arrays; import java.util.List; @@ -83,7 +84,7 @@ public class JavaApplySchemaSuite implements Serializable { @Test public void applySchema() { - List<Person> personList = new ArrayList<Person>(2); + List<Person> personList = new ArrayList<>(2); Person person1 = new Person(); person1.setName("Michael"); person1.setAge(29); @@ -95,12 +96,13 @@ public class JavaApplySchemaSuite implements Serializable { JavaRDD<Row> rowRDD = javaCtx.parallelize(personList).map( new Function<Person, Row>() { + @Override public Row call(Person person) throws Exception { return RowFactory.create(person.getName(), person.getAge()); } }); - List<StructField> fields = new ArrayList<StructField>(2); + List<StructField> fields = new ArrayList<>(2); fields.add(DataTypes.createStructField("name", DataTypes.StringType, false)); fields.add(DataTypes.createStructField("age", DataTypes.IntegerType, false)); StructType schema = DataTypes.createStructType(fields); @@ -118,7 +120,7 @@ public class JavaApplySchemaSuite implements Serializable { @Test public void dataFrameRDDOperations() { - List<Person> personList = new ArrayList<Person>(2); + List<Person> personList = new ArrayList<>(2); Person person1 = new Person(); person1.setName("Michael"); person1.setAge(29); @@ -129,27 +131,28 @@ public class JavaApplySchemaSuite implements Serializable { personList.add(person2); JavaRDD<Row> rowRDD = javaCtx.parallelize(personList).map( - new Function<Person, Row>() { - public Row call(Person person) throws Exception { - return RowFactory.create(person.getName(), person.getAge()); - } - }); - - List<StructField> fields = new ArrayList<StructField>(2); - fields.add(DataTypes.createStructField("name", DataTypes.StringType, false)); + new Function<Person, Row>() { + @Override + public Row call(Person person) { + return RowFactory.create(person.getName(), person.getAge()); + } + }); + + List<StructField> fields = new ArrayList<>(2); + fields.add(DataTypes.createStructField("", DataTypes.StringType, false)); fields.add(DataTypes.createStructField("age", DataTypes.IntegerType, false)); StructType schema = DataTypes.createStructType(fields); DataFrame df = sqlContext.applySchema(rowRDD, schema); df.registerTempTable("people"); List<String> actual = sqlContext.sql("SELECT * FROM people").toJavaRDD().map(new Function<Row, String>() { - + @Override public String call(Row row) { - return row.getString(0) + "_" + row.get(1).toString(); + return row.getString(0) + "_" + row.get(1); } }).collect(); - List<String> expected = new ArrayList<String>(2); + List<String> expected = new ArrayList<>(2); expected.add("Michael_29"); expected.add("Yin_28"); @@ -165,7 +168,7 @@ public class JavaApplySchemaSuite implements Serializable { "{\"string\":\"this is another simple string.\", \"integer\":11, \"long\":21474836469, " + "\"bigInteger\":92233720368547758069, \"double\":1.7976931348623157E305, " + "\"boolean\":false, \"null\":null}")); - List<StructField> fields = new ArrayList<StructField>(7); + List<StructField> fields = new ArrayList<>(7); fields.add(DataTypes.createStructField("bigInteger", DataTypes.createDecimalType(20, 0), true)); fields.add(DataTypes.createStructField("boolean", DataTypes.BooleanType, true)); @@ -175,10 +178,10 @@ public class JavaApplySchemaSuite implements Serializable { fields.add(DataTypes.createStructField("null", DataTypes.StringType, true)); fields.add(DataTypes.createStructField("string", DataTypes.StringType, true)); StructType expectedSchema = DataTypes.createStructType(fields); - List<Row> expectedResult = new ArrayList<Row>(2); + List<Row> expectedResult = new ArrayList<>(2); expectedResult.add( RowFactory.create( - new java.math.BigDecimal("92233720368547758070"), + new BigDecimal("92233720368547758070"), true, 1.7976931348623157E308, 10, @@ -187,7 +190,7 @@ public class JavaApplySchemaSuite implements Serializable { "this is a simple string.")); expectedResult.add( RowFactory.create( - new java.math.BigDecimal("92233720368547758069"), + new BigDecimal("92233720368547758069"), false, 1.7976931348623157E305, 11, diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java index 4867cebf53..d981ce947f 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java @@ -61,7 +61,7 @@ public class JavaDataFrameSuite { @Test public void testExecution() { DataFrame df = context.table("testData").filter("key = 1"); - Assert.assertEquals(df.select("key").collect()[0].get(0), 1); + Assert.assertEquals(1, df.select("key").collect()[0].get(0)); } /** @@ -119,7 +119,7 @@ public class JavaDataFrameSuite { public static class Bean implements Serializable { private double a = 0.0; - private Integer[] b = new Integer[]{0, 1}; + private Integer[] b = { 0, 1 }; private Map<String, int[]> c = ImmutableMap.of("hello", new int[] { 1, 2 }); private List<String> d = Arrays.asList("floppy", "disk"); @@ -161,7 +161,7 @@ public class JavaDataFrameSuite { schema.apply("d")); Row first = df.select("a", "b", "c", "d").first(); Assert.assertEquals(bean.getA(), first.getDouble(0), 0.0); - // Now Java lists and maps are converetd to Scala Seq's and Map's. Once we get a Seq below, + // Now Java lists and maps are converted to Scala Seq's and Map's. Once we get a Seq below, // verify that it has the expected length, and contains expected elements. Seq<Integer> result = first.getAs(1); Assert.assertEquals(bean.getB().length, result.length()); @@ -180,7 +180,8 @@ public class JavaDataFrameSuite { } } - private static Comparator<Row> CrosstabRowComparator = new Comparator<Row>() { + private static final Comparator<Row> crosstabRowComparator = new Comparator<Row>() { + @Override public int compare(Row row1, Row row2) { String item1 = row1.getString(0); String item2 = row2.getString(0); @@ -193,16 +194,16 @@ public class JavaDataFrameSuite { DataFrame df = context.table("testData2"); DataFrame crosstab = df.stat().crosstab("a", "b"); String[] columnNames = crosstab.schema().fieldNames(); - Assert.assertEquals(columnNames[0], "a_b"); - Assert.assertEquals(columnNames[1], "1"); - Assert.assertEquals(columnNames[2], "2"); + Assert.assertEquals("a_b", columnNames[0]); + Assert.assertEquals("1", columnNames[1]); + Assert.assertEquals("2", columnNames[2]); Row[] rows = crosstab.collect(); - Arrays.sort(rows, CrosstabRowComparator); + Arrays.sort(rows, crosstabRowComparator); Integer count = 1; for (Row row : rows) { Assert.assertEquals(row.get(0).toString(), count.toString()); - Assert.assertEquals(row.getLong(1), 1L); - Assert.assertEquals(row.getLong(2), 1L); + Assert.assertEquals(1L, row.getLong(1)); + Assert.assertEquals(1L, row.getLong(2)); count++; } } @@ -210,7 +211,7 @@ public class JavaDataFrameSuite { @Test public void testFrequentItems() { DataFrame df = context.table("testData2"); - String[] cols = new String[]{"a"}; + String[] cols = {"a"}; DataFrame results = df.stat().freqItems(cols, 0.2); Assert.assertTrue(results.collect()[0].getSeq(0).contains(1)); } @@ -219,14 +220,14 @@ public class JavaDataFrameSuite { public void testCorrelation() { DataFrame df = context.table("testData2"); Double pearsonCorr = df.stat().corr("a", "b", "pearson"); - Assert.assertTrue(Math.abs(pearsonCorr) < 1e-6); + Assert.assertTrue(Math.abs(pearsonCorr) < 1.0e-6); } @Test public void testCovariance() { DataFrame df = context.table("testData2"); Double result = df.stat().cov("a", "b"); - Assert.assertTrue(Math.abs(result) < 1e-6); + Assert.assertTrue(Math.abs(result) < 1.0e-6); } @Test @@ -234,7 +235,7 @@ public class JavaDataFrameSuite { DataFrame df = context.range(0, 100, 1, 2).select(col("id").mod(3).as("key")); DataFrame sampled = df.stat().<Integer>sampleBy("key", ImmutableMap.of(0, 0.1, 1, 0.2), 0L); Row[] actual = sampled.groupBy("key").count().orderBy("key").collect(); - Row[] expected = new Row[] {RowFactory.create(0, 5), RowFactory.create(1, 8)}; + Row[] expected = {RowFactory.create(0, 5), RowFactory.create(1, 8)}; Assert.assertArrayEquals(expected, actual); } } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaRowSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaRowSuite.java index 4ce1d1dddb..3ab4db2a03 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaRowSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaRowSuite.java @@ -18,6 +18,7 @@ package test.org.apache.spark.sql; import java.math.BigDecimal; +import java.nio.charset.StandardCharsets; import java.sql.Date; import java.sql.Timestamp; import java.util.Arrays; @@ -52,12 +53,12 @@ public class JavaRowSuite { shortValue = (short)32767; intValue = 2147483647; longValue = 9223372036854775807L; - floatValue = (float)3.4028235E38; + floatValue = 3.4028235E38f; doubleValue = 1.7976931348623157E308; decimalValue = new BigDecimal("1.7976931348623157E328"); booleanValue = true; stringValue = "this is a string"; - binaryValue = stringValue.getBytes(); + binaryValue = stringValue.getBytes(StandardCharsets.UTF_8); dateValue = Date.valueOf("2014-06-30"); timestampValue = Timestamp.valueOf("2014-06-30 09:20:00.0"); } @@ -123,8 +124,8 @@ public class JavaRowSuite { Assert.assertEquals(binaryValue, simpleRow.get(16)); Assert.assertEquals(dateValue, simpleRow.get(17)); Assert.assertEquals(timestampValue, simpleRow.get(18)); - Assert.assertEquals(true, simpleRow.isNullAt(19)); - Assert.assertEquals(null, simpleRow.get(19)); + Assert.assertTrue(simpleRow.isNullAt(19)); + Assert.assertNull(simpleRow.get(19)); } @Test @@ -134,7 +135,7 @@ public class JavaRowSuite { stringValue + " (1)", stringValue + " (2)", stringValue + "(3)"); // Simple map - Map<String, Long> simpleMap = new HashMap<String, Long>(); + Map<String, Long> simpleMap = new HashMap<>(); simpleMap.put(stringValue + " (1)", longValue); simpleMap.put(stringValue + " (2)", longValue - 1); simpleMap.put(stringValue + " (3)", longValue - 2); @@ -149,7 +150,7 @@ public class JavaRowSuite { List<Row> arrayOfRows = Arrays.asList(simpleStruct); // Complex map - Map<List<Row>, Row> complexMap = new HashMap<List<Row>, Row>(); + Map<List<Row>, Row> complexMap = new HashMap<>(); complexMap.put(arrayOfRows, simpleStruct); // Complex struct @@ -167,7 +168,7 @@ public class JavaRowSuite { Assert.assertEquals(arrayOfMaps, complexStruct.get(3)); Assert.assertEquals(arrayOfRows, complexStruct.get(4)); Assert.assertEquals(complexMap, complexStruct.get(5)); - Assert.assertEquals(null, complexStruct.get(6)); + Assert.assertNull(complexStruct.get(6)); // A very complex row Row complexRow = RowFactory.create(arrayOfMaps, arrayOfRows, complexMap, complexStruct); diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDFSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDFSuite.java index bb02b58cca..4a78dca7fe 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDFSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDFSuite.java @@ -20,6 +20,7 @@ package test.org.apache.spark.sql; import java.io.Serializable; import org.junit.After; +import org.junit.Assert; import org.junit.Before; import org.junit.Test; @@ -61,13 +62,13 @@ public class JavaUDFSuite implements Serializable { sqlContext.udf().register("stringLengthTest", new UDF1<String, Integer>() { @Override - public Integer call(String str) throws Exception { + public Integer call(String str) { return str.length(); } }, DataTypes.IntegerType); Row result = sqlContext.sql("SELECT stringLengthTest('test')").head(); - assert(result.getInt(0) == 4); + Assert.assertEquals(4, result.getInt(0)); } @SuppressWarnings("unchecked") @@ -81,12 +82,12 @@ public class JavaUDFSuite implements Serializable { sqlContext.udf().register("stringLengthTest", new UDF2<String, String, Integer>() { @Override - public Integer call(String str1, String str2) throws Exception { + public Integer call(String str1, String str2) { return str1.length() + str2.length(); } }, DataTypes.IntegerType); Row result = sqlContext.sql("SELECT stringLengthTest('test', 'test2')").head(); - assert(result.getInt(0) == 9); + Assert.assertEquals(9, result.getInt(0)); } } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaSaveLoadSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaSaveLoadSuite.java index 6f9e7f68dc..9e241f2098 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaSaveLoadSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaSaveLoadSuite.java @@ -44,7 +44,7 @@ public class JavaSaveLoadSuite { File path; DataFrame df; - private void checkAnswer(DataFrame actual, List<Row> expected) { + private static void checkAnswer(DataFrame actual, List<Row> expected) { String errorMessage = QueryTest$.MODULE$.checkAnswer(actual, expected); if (errorMessage != null) { Assert.fail(errorMessage); @@ -64,7 +64,7 @@ public class JavaSaveLoadSuite { path.delete(); } - List<String> jsonObjects = new ArrayList<String>(10); + List<String> jsonObjects = new ArrayList<>(10); for (int i = 0; i < 10; i++) { jsonObjects.add("{\"a\":" + i + ", \"b\":\"str" + i + "\"}"); } @@ -82,7 +82,7 @@ public class JavaSaveLoadSuite { @Test public void saveAndLoad() { - Map<String, String> options = new HashMap<String, String>(); + Map<String, String> options = new HashMap<>(); options.put("path", path.toString()); df.write().mode(SaveMode.ErrorIfExists).format("json").options(options).save(); DataFrame loadedDF = sqlContext.read().format("json").options(options).load(); @@ -91,11 +91,11 @@ public class JavaSaveLoadSuite { @Test public void saveAndLoadWithSchema() { - Map<String, String> options = new HashMap<String, String>(); + Map<String, String> options = new HashMap<>(); options.put("path", path.toString()); df.write().format("json").mode(SaveMode.ErrorIfExists).options(options).save(); - List<StructField> fields = new ArrayList<StructField>(); + List<StructField> fields = new ArrayList<>(); fields.add(DataTypes.createStructField("b", DataTypes.StringType, true)); StructType schema = DataTypes.createStructType(fields); DataFrame loadedDF = sqlContext.read().format("json").schema(schema).options(options).load(); |