From 53a1ebf33b5c349ae3a40d7eebf357b839b363af Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Tue, 24 Feb 2015 18:51:41 -0800 Subject: [SPARK-5904][SQL] DataFrame Java API test suites. Added a new test suite to make sure Java DF programs can use varargs properly. Also moved all suites into test.org.apache.spark package to make sure the suites also test for method visibility. Author: Reynold Xin Closes #4751 from rxin/df-tests and squashes the following commits: 1e8b8e4 [Reynold Xin] Fixed imports and renamed JavaAPISuite. a6ca53b [Reynold Xin] [SPARK-5904][SQL] DataFrame Java API test suites. --- .../apache/spark/sql/api/java/JavaAPISuite.java | 86 --------- .../spark/sql/api/java/JavaApplySchemaSuite.java | 206 --------------------- .../org/apache/spark/sql/api/java/JavaDsl.java | 120 ------------ .../apache/spark/sql/api/java/JavaRowSuite.java | 179 ------------------ .../spark/sql/sources/JavaSaveLoadSuite.java | 97 ---------- .../org/apache/spark/sql/JavaApplySchemaSuite.java | 204 ++++++++++++++++++++ .../org/apache/spark/sql/JavaDataFrameSuite.java | 84 +++++++++ .../test/org/apache/spark/sql/JavaRowSuite.java | 179 ++++++++++++++++++ .../test/org/apache/spark/sql/JavaUDFSuite.java | 88 +++++++++ .../spark/sql/sources/JavaSaveLoadSuite.java | 98 ++++++++++ .../org/apache/spark/sql/DataFrameSuite.scala | 4 +- 11 files changed, 655 insertions(+), 690 deletions(-) delete mode 100644 sql/core/src/test/java/org/apache/spark/sql/api/java/JavaAPISuite.java delete mode 100644 sql/core/src/test/java/org/apache/spark/sql/api/java/JavaApplySchemaSuite.java delete mode 100644 sql/core/src/test/java/org/apache/spark/sql/api/java/JavaDsl.java delete mode 100644 sql/core/src/test/java/org/apache/spark/sql/api/java/JavaRowSuite.java delete mode 100644 sql/core/src/test/java/org/apache/spark/sql/sources/JavaSaveLoadSuite.java create mode 100644 sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java create mode 100644 sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java create mode 100644 sql/core/src/test/java/test/org/apache/spark/sql/JavaRowSuite.java create mode 100644 sql/core/src/test/java/test/org/apache/spark/sql/JavaUDFSuite.java create mode 100644 sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaSaveLoadSuite.java (limited to 'sql') diff --git a/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaAPISuite.java b/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaAPISuite.java deleted file mode 100644 index a21a154090..0000000000 --- a/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaAPISuite.java +++ /dev/null @@ -1,86 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.api.java; - -import java.io.Serializable; - -import org.apache.spark.sql.test.TestSQLContext$; -import org.junit.After; -import org.junit.Before; -import org.junit.Test; - -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.SQLContext; -import org.apache.spark.sql.types.DataTypes; - -// The test suite itself is Serializable so that anonymous Function implementations can be -// serialized, as an alternative to converting these anonymous classes to static inner classes; -// see http://stackoverflow.com/questions/758570/. -public class JavaAPISuite implements Serializable { - private transient JavaSparkContext sc; - private transient SQLContext sqlContext; - - @Before - public void setUp() { - sqlContext = TestSQLContext$.MODULE$; - sc = new JavaSparkContext(sqlContext.sparkContext()); - } - - @After - public void tearDown() { - } - - @SuppressWarnings("unchecked") - @Test - public void udf1Test() { - // With Java 8 lambdas: - // sqlContext.registerFunction( - // "stringLengthTest", (String str) -> str.length(), DataType.IntegerType); - - sqlContext.udf().register("stringLengthTest", new UDF1() { - @Override - public Integer call(String str) throws Exception { - return str.length(); - } - }, DataTypes.IntegerType); - - Row result = sqlContext.sql("SELECT stringLengthTest('test')").head(); - assert(result.getInt(0) == 4); - } - - @SuppressWarnings("unchecked") - @Test - public void udf2Test() { - // With Java 8 lambdas: - // sqlContext.registerFunction( - // "stringLengthTest", - // (String str1, String str2) -> str1.length() + str2.length, - // DataType.IntegerType); - - sqlContext.udf().register("stringLengthTest", new UDF2() { - @Override - public Integer call(String str1, String str2) throws Exception { - return str1.length() + str2.length(); - } - }, DataTypes.IntegerType); - - Row result = sqlContext.sql("SELECT stringLengthTest('test', 'test2')").head(); - assert(result.getInt(0) == 9); - } -} diff --git a/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaApplySchemaSuite.java b/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaApplySchemaSuite.java deleted file mode 100644 index 643b891ab1..0000000000 --- a/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaApplySchemaSuite.java +++ /dev/null @@ -1,206 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.api.java; - -import java.io.Serializable; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.List; - -import org.apache.spark.sql.test.TestSQLContext$; -import org.junit.After; -import org.junit.Assert; -import org.junit.Before; -import org.junit.Test; - -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.api.java.function.Function; -import org.apache.spark.sql.*; -import org.apache.spark.sql.types.*; - -// The test suite itself is Serializable so that anonymous Function implementations can be -// serialized, as an alternative to converting these anonymous classes to static inner classes; -// see http://stackoverflow.com/questions/758570/. -public class JavaApplySchemaSuite implements Serializable { - private transient JavaSparkContext javaCtx; - private transient SQLContext javaSqlCtx; - - @Before - public void setUp() { - javaSqlCtx = TestSQLContext$.MODULE$; - javaCtx = new JavaSparkContext(javaSqlCtx.sparkContext()); - } - - @After - public void tearDown() { - javaCtx = null; - javaSqlCtx = null; - } - - public static class Person implements Serializable { - private String name; - private int age; - - public String getName() { - return name; - } - - public void setName(String name) { - this.name = name; - } - - public int getAge() { - return age; - } - - public void setAge(int age) { - this.age = age; - } - } - - @Test - public void applySchema() { - List personList = new ArrayList(2); - Person person1 = new Person(); - person1.setName("Michael"); - person1.setAge(29); - personList.add(person1); - Person person2 = new Person(); - person2.setName("Yin"); - person2.setAge(28); - personList.add(person2); - - JavaRDD rowRDD = javaCtx.parallelize(personList).map( - new Function() { - public Row call(Person person) throws Exception { - return RowFactory.create(person.getName(), person.getAge()); - } - }); - - List fields = new ArrayList(2); - fields.add(DataTypes.createStructField("name", DataTypes.StringType, false)); - fields.add(DataTypes.createStructField("age", DataTypes.IntegerType, false)); - StructType schema = DataTypes.createStructType(fields); - - DataFrame df = javaSqlCtx.applySchema(rowRDD, schema); - df.registerTempTable("people"); - Row[] actual = javaSqlCtx.sql("SELECT * FROM people").collect(); - - List expected = new ArrayList(2); - expected.add(RowFactory.create("Michael", 29)); - expected.add(RowFactory.create("Yin", 28)); - - Assert.assertEquals(expected, Arrays.asList(actual)); - } - - - - @Test - public void dataFrameRDDOperations() { - List personList = new ArrayList(2); - Person person1 = new Person(); - person1.setName("Michael"); - person1.setAge(29); - personList.add(person1); - Person person2 = new Person(); - person2.setName("Yin"); - person2.setAge(28); - personList.add(person2); - - JavaRDD rowRDD = javaCtx.parallelize(personList).map( - new Function() { - public Row call(Person person) throws Exception { - return RowFactory.create(person.getName(), person.getAge()); - } - }); - - List fields = new ArrayList(2); - fields.add(DataTypes.createStructField("name", DataTypes.StringType, false)); - fields.add(DataTypes.createStructField("age", DataTypes.IntegerType, false)); - StructType schema = DataTypes.createStructType(fields); - - DataFrame df = javaSqlCtx.applySchema(rowRDD, schema); - df.registerTempTable("people"); - List actual = javaSqlCtx.sql("SELECT * FROM people").toJavaRDD().map(new Function() { - - public String call(Row row) { - return row.getString(0) + "_" + row.get(1).toString(); - } - }).collect(); - - List expected = new ArrayList(2); - expected.add("Michael_29"); - expected.add("Yin_28"); - - Assert.assertEquals(expected, actual); - } - - @Test - public void applySchemaToJSON() { - JavaRDD jsonRDD = javaCtx.parallelize(Arrays.asList( - "{\"string\":\"this is a simple string.\", \"integer\":10, \"long\":21474836470, " + - "\"bigInteger\":92233720368547758070, \"double\":1.7976931348623157E308, " + - "\"boolean\":true, \"null\":null}", - "{\"string\":\"this is another simple string.\", \"integer\":11, \"long\":21474836469, " + - "\"bigInteger\":92233720368547758069, \"double\":1.7976931348623157E305, " + - "\"boolean\":false, \"null\":null}")); - List fields = new ArrayList(7); - fields.add(DataTypes.createStructField("bigInteger", DataTypes.createDecimalType(), true)); - fields.add(DataTypes.createStructField("boolean", DataTypes.BooleanType, true)); - fields.add(DataTypes.createStructField("double", DataTypes.DoubleType, true)); - fields.add(DataTypes.createStructField("integer", DataTypes.LongType, true)); - fields.add(DataTypes.createStructField("long", DataTypes.LongType, true)); - fields.add(DataTypes.createStructField("null", DataTypes.StringType, true)); - fields.add(DataTypes.createStructField("string", DataTypes.StringType, true)); - StructType expectedSchema = DataTypes.createStructType(fields); - List expectedResult = new ArrayList(2); - expectedResult.add( - RowFactory.create( - new java.math.BigDecimal("92233720368547758070"), - true, - 1.7976931348623157E308, - 10, - 21474836470L, - null, - "this is a simple string.")); - expectedResult.add( - RowFactory.create( - new java.math.BigDecimal("92233720368547758069"), - false, - 1.7976931348623157E305, - 11, - 21474836469L, - null, - "this is another simple string.")); - - DataFrame df1 = javaSqlCtx.jsonRDD(jsonRDD); - StructType actualSchema1 = df1.schema(); - Assert.assertEquals(expectedSchema, actualSchema1); - df1.registerTempTable("jsonTable1"); - List actual1 = javaSqlCtx.sql("select * from jsonTable1").collectAsList(); - Assert.assertEquals(expectedResult, actual1); - - DataFrame df2 = javaSqlCtx.jsonRDD(jsonRDD, expectedSchema); - StructType actualSchema2 = df2.schema(); - Assert.assertEquals(expectedSchema, actualSchema2); - df2.registerTempTable("jsonTable2"); - List actual2 = javaSqlCtx.sql("select * from jsonTable2").collectAsList(); - Assert.assertEquals(expectedResult, actual2); - } -} diff --git a/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaDsl.java b/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaDsl.java deleted file mode 100644 index 05233dc5ff..0000000000 --- a/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaDsl.java +++ /dev/null @@ -1,120 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.api.java; - -import com.google.common.collect.ImmutableMap; - -import org.apache.spark.sql.Column; -import org.apache.spark.sql.DataFrame; -import org.apache.spark.sql.types.DataTypes; - -import static org.apache.spark.sql.functions.*; - -/** - * This test doesn't actually run anything. It is here to check the API compatibility for Java. - */ -public class JavaDsl { - - public static void testDataFrame(final DataFrame df) { - DataFrame df1 = df.select("colA"); - df1 = df.select("colA", "colB"); - - df1 = df.select(col("colA"), col("colB"), lit("literal value").$plus(1)); - - df1 = df.filter(col("colA")); - - java.util.Map aggExprs = ImmutableMap.builder() - .put("colA", "sum") - .put("colB", "avg") - .build(); - - df1 = df.agg(aggExprs); - - df1 = df.groupBy("groupCol").agg(aggExprs); - - df1 = df.join(df1, col("key1").$eq$eq$eq(col("key2")), "outer"); - - df.orderBy("colA"); - df.orderBy("colA", "colB", "colC"); - df.orderBy(col("colA").desc()); - df.orderBy(col("colA").desc(), col("colB").asc()); - - df.sort("colA"); - df.sort("colA", "colB", "colC"); - df.sort(col("colA").desc()); - df.sort(col("colA").desc(), col("colB").asc()); - - df.as("b"); - - df.limit(5); - - df.unionAll(df1); - df.intersect(df1); - df.except(df1); - - df.sample(true, 0.1, 234); - - df.head(); - df.head(5); - df.first(); - df.count(); - } - - public static void testColumn(final Column c) { - c.asc(); - c.desc(); - - c.endsWith("abcd"); - c.startsWith("afgasdf"); - - c.like("asdf%"); - c.rlike("wef%asdf"); - - c.as("newcol"); - - c.cast("int"); - c.cast(DataTypes.IntegerType); - } - - public static void testDsl() { - // Creating a column. - Column c = col("abcd"); - Column c1 = column("abcd"); - - // Literals - Column l1 = lit(1); - Column l2 = lit(1.0); - Column l3 = lit("abcd"); - - // Functions - Column a = upper(c); - a = lower(c); - a = sqrt(c); - a = abs(c); - - // Aggregates - a = min(c); - a = max(c); - a = sum(c); - a = sumDistinct(c); - a = countDistinct(c, a); - a = avg(c); - a = first(c); - a = last(c); - } -} diff --git a/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaRowSuite.java b/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaRowSuite.java deleted file mode 100644 index fbfcd3f59d..0000000000 --- a/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaRowSuite.java +++ /dev/null @@ -1,179 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.api.java; - -import java.math.BigDecimal; -import java.sql.Date; -import java.sql.Timestamp; -import java.util.Arrays; -import java.util.HashMap; -import java.util.List; -import java.util.Map; - -import org.junit.Assert; -import org.junit.Before; -import org.junit.Test; - -import org.apache.spark.sql.Row; -import org.apache.spark.sql.RowFactory; - -public class JavaRowSuite { - private byte byteValue; - private short shortValue; - private int intValue; - private long longValue; - private float floatValue; - private double doubleValue; - private BigDecimal decimalValue; - private boolean booleanValue; - private String stringValue; - private byte[] binaryValue; - private Date dateValue; - private Timestamp timestampValue; - - @Before - public void setUp() { - byteValue = (byte)127; - shortValue = (short)32767; - intValue = 2147483647; - longValue = 9223372036854775807L; - floatValue = (float)3.4028235E38; - doubleValue = 1.7976931348623157E308; - decimalValue = new BigDecimal("1.7976931348623157E328"); - booleanValue = true; - stringValue = "this is a string"; - binaryValue = stringValue.getBytes(); - dateValue = Date.valueOf("2014-06-30"); - timestampValue = Timestamp.valueOf("2014-06-30 09:20:00.0"); - } - - @Test - public void constructSimpleRow() { - Row simpleRow = RowFactory.create( - byteValue, // ByteType - new Byte(byteValue), - shortValue, // ShortType - new Short(shortValue), - intValue, // IntegerType - new Integer(intValue), - longValue, // LongType - new Long(longValue), - floatValue, // FloatType - new Float(floatValue), - doubleValue, // DoubleType - new Double(doubleValue), - decimalValue, // DecimalType - booleanValue, // BooleanType - new Boolean(booleanValue), - stringValue, // StringType - binaryValue, // BinaryType - dateValue, // DateType - timestampValue, // TimestampType - null // null - ); - - Assert.assertEquals(byteValue, simpleRow.getByte(0)); - Assert.assertEquals(byteValue, simpleRow.get(0)); - Assert.assertEquals(byteValue, simpleRow.getByte(1)); - Assert.assertEquals(byteValue, simpleRow.get(1)); - Assert.assertEquals(shortValue, simpleRow.getShort(2)); - Assert.assertEquals(shortValue, simpleRow.get(2)); - Assert.assertEquals(shortValue, simpleRow.getShort(3)); - Assert.assertEquals(shortValue, simpleRow.get(3)); - Assert.assertEquals(intValue, simpleRow.getInt(4)); - Assert.assertEquals(intValue, simpleRow.get(4)); - Assert.assertEquals(intValue, simpleRow.getInt(5)); - Assert.assertEquals(intValue, simpleRow.get(5)); - Assert.assertEquals(longValue, simpleRow.getLong(6)); - Assert.assertEquals(longValue, simpleRow.get(6)); - Assert.assertEquals(longValue, simpleRow.getLong(7)); - Assert.assertEquals(longValue, simpleRow.get(7)); - // When we create the row, we do not do any conversion - // for a float/double value, so we just set the delta to 0. - Assert.assertEquals(floatValue, simpleRow.getFloat(8), 0); - Assert.assertEquals(floatValue, simpleRow.get(8)); - Assert.assertEquals(floatValue, simpleRow.getFloat(9), 0); - Assert.assertEquals(floatValue, simpleRow.get(9)); - Assert.assertEquals(doubleValue, simpleRow.getDouble(10), 0); - Assert.assertEquals(doubleValue, simpleRow.get(10)); - Assert.assertEquals(doubleValue, simpleRow.getDouble(11), 0); - Assert.assertEquals(doubleValue, simpleRow.get(11)); - Assert.assertEquals(decimalValue, simpleRow.get(12)); - Assert.assertEquals(booleanValue, simpleRow.getBoolean(13)); - Assert.assertEquals(booleanValue, simpleRow.get(13)); - Assert.assertEquals(booleanValue, simpleRow.getBoolean(14)); - Assert.assertEquals(booleanValue, simpleRow.get(14)); - Assert.assertEquals(stringValue, simpleRow.getString(15)); - Assert.assertEquals(stringValue, simpleRow.get(15)); - Assert.assertEquals(binaryValue, simpleRow.get(16)); - Assert.assertEquals(dateValue, simpleRow.get(17)); - Assert.assertEquals(timestampValue, simpleRow.get(18)); - Assert.assertEquals(true, simpleRow.isNullAt(19)); - Assert.assertEquals(null, simpleRow.get(19)); - } - - @Test - public void constructComplexRow() { - // Simple array - List simpleStringArray = Arrays.asList( - stringValue + " (1)", stringValue + " (2)", stringValue + "(3)"); - - // Simple map - Map simpleMap = new HashMap(); - simpleMap.put(stringValue + " (1)", longValue); - simpleMap.put(stringValue + " (2)", longValue - 1); - simpleMap.put(stringValue + " (3)", longValue - 2); - - // Simple struct - Row simpleStruct = RowFactory.create( - doubleValue, stringValue, timestampValue, null); - - // Complex array - @SuppressWarnings("unchecked") - List> arrayOfMaps = Arrays.asList(simpleMap); - List arrayOfRows = Arrays.asList(simpleStruct); - - // Complex map - Map, Row> complexMap = new HashMap, Row>(); - complexMap.put(arrayOfRows, simpleStruct); - - // Complex struct - Row complexStruct = RowFactory.create( - simpleStringArray, - simpleMap, - simpleStruct, - arrayOfMaps, - arrayOfRows, - complexMap, - null); - Assert.assertEquals(simpleStringArray, complexStruct.get(0)); - Assert.assertEquals(simpleMap, complexStruct.get(1)); - Assert.assertEquals(simpleStruct, complexStruct.get(2)); - Assert.assertEquals(arrayOfMaps, complexStruct.get(3)); - Assert.assertEquals(arrayOfRows, complexStruct.get(4)); - Assert.assertEquals(complexMap, complexStruct.get(5)); - Assert.assertEquals(null, complexStruct.get(6)); - - // A very complex row - Row complexRow = RowFactory.create(arrayOfMaps, arrayOfRows, complexMap, complexStruct); - Assert.assertEquals(arrayOfMaps, complexRow.get(0)); - Assert.assertEquals(arrayOfRows, complexRow.get(1)); - Assert.assertEquals(complexMap, complexRow.get(2)); - Assert.assertEquals(complexStruct, complexRow.get(3)); - } -} diff --git a/sql/core/src/test/java/org/apache/spark/sql/sources/JavaSaveLoadSuite.java b/sql/core/src/test/java/org/apache/spark/sql/sources/JavaSaveLoadSuite.java deleted file mode 100644 index 311f1bdd07..0000000000 --- a/sql/core/src/test/java/org/apache/spark/sql/sources/JavaSaveLoadSuite.java +++ /dev/null @@ -1,97 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.sql.sources; - -import java.io.File; -import java.io.IOException; -import java.util.*; - -import org.junit.Assert; -import org.junit.Before; -import org.junit.Test; - -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.sql.test.TestSQLContext$; -import org.apache.spark.sql.*; -import org.apache.spark.sql.types.DataTypes; -import org.apache.spark.sql.types.StructField; -import org.apache.spark.sql.types.StructType; -import org.apache.spark.util.Utils; - -public class JavaSaveLoadSuite { - - private transient JavaSparkContext sc; - private transient SQLContext sqlContext; - - String originalDefaultSource; - File path; - DataFrame df; - - private void checkAnswer(DataFrame actual, List expected) { - String errorMessage = QueryTest$.MODULE$.checkAnswer(actual, expected); - if (errorMessage != null) { - Assert.fail(errorMessage); - } - } - - @Before - public void setUp() throws IOException { - sqlContext = TestSQLContext$.MODULE$; - sc = new JavaSparkContext(sqlContext.sparkContext()); - - originalDefaultSource = sqlContext.conf().defaultDataSourceName(); - path = - Utils.createTempDir(System.getProperty("java.io.tmpdir"), "datasource").getCanonicalFile(); - if (path.exists()) { - path.delete(); - } - - List jsonObjects = new ArrayList(10); - for (int i = 0; i < 10; i++) { - jsonObjects.add("{\"a\":" + i + ", \"b\":\"str" + i + "\"}"); - } - JavaRDD rdd = sc.parallelize(jsonObjects); - df = sqlContext.jsonRDD(rdd); - df.registerTempTable("jsonTable"); - } - - @Test - public void saveAndLoad() { - Map options = new HashMap(); - options.put("path", path.toString()); - df.save("org.apache.spark.sql.json", SaveMode.ErrorIfExists, options); - - DataFrame loadedDF = sqlContext.load("org.apache.spark.sql.json", options); - - checkAnswer(loadedDF, df.collectAsList()); - } - - @Test - public void saveAndLoadWithSchema() { - Map options = new HashMap(); - options.put("path", path.toString()); - df.save("org.apache.spark.sql.json", SaveMode.ErrorIfExists, options); - - List fields = new ArrayList(); - fields.add(DataTypes.createStructField("b", DataTypes.StringType, true)); - StructType schema = DataTypes.createStructType(fields); - DataFrame loadedDF = sqlContext.load("org.apache.spark.sql.json", schema, options); - - checkAnswer(loadedDF, sqlContext.sql("SELECT b FROM jsonTable").collectAsList()); - } -} diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java new file mode 100644 index 0000000000..c344a9b095 --- /dev/null +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java @@ -0,0 +1,204 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package test.org.apache.spark.sql; + +import java.io.Serializable; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +import org.apache.spark.sql.test.TestSQLContext$; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.api.java.function.Function; +import org.apache.spark.sql.*; +import org.apache.spark.sql.types.*; + +// The test suite itself is Serializable so that anonymous Function implementations can be +// serialized, as an alternative to converting these anonymous classes to static inner classes; +// see http://stackoverflow.com/questions/758570/. +public class JavaApplySchemaSuite implements Serializable { + private transient JavaSparkContext javaCtx; + private transient SQLContext sqlContext; + + @Before + public void setUp() { + sqlContext = TestSQLContext$.MODULE$; + javaCtx = new JavaSparkContext(sqlContext.sparkContext()); + } + + @After + public void tearDown() { + javaCtx = null; + sqlContext = null; + } + + public static class Person implements Serializable { + private String name; + private int age; + + public String getName() { + return name; + } + + public void setName(String name) { + this.name = name; + } + + public int getAge() { + return age; + } + + public void setAge(int age) { + this.age = age; + } + } + + @Test + public void applySchema() { + List personList = new ArrayList(2); + Person person1 = new Person(); + person1.setName("Michael"); + person1.setAge(29); + personList.add(person1); + Person person2 = new Person(); + person2.setName("Yin"); + person2.setAge(28); + personList.add(person2); + + JavaRDD rowRDD = javaCtx.parallelize(personList).map( + new Function() { + public Row call(Person person) throws Exception { + return RowFactory.create(person.getName(), person.getAge()); + } + }); + + List fields = new ArrayList(2); + fields.add(DataTypes.createStructField("name", DataTypes.StringType, false)); + fields.add(DataTypes.createStructField("age", DataTypes.IntegerType, false)); + StructType schema = DataTypes.createStructType(fields); + + DataFrame df = sqlContext.applySchema(rowRDD, schema); + df.registerTempTable("people"); + Row[] actual = sqlContext.sql("SELECT * FROM people").collect(); + + List expected = new ArrayList(2); + expected.add(RowFactory.create("Michael", 29)); + expected.add(RowFactory.create("Yin", 28)); + + Assert.assertEquals(expected, Arrays.asList(actual)); + } + + @Test + public void dataFrameRDDOperations() { + List personList = new ArrayList(2); + Person person1 = new Person(); + person1.setName("Michael"); + person1.setAge(29); + personList.add(person1); + Person person2 = new Person(); + person2.setName("Yin"); + person2.setAge(28); + personList.add(person2); + + JavaRDD rowRDD = javaCtx.parallelize(personList).map( + new Function() { + public Row call(Person person) throws Exception { + return RowFactory.create(person.getName(), person.getAge()); + } + }); + + List fields = new ArrayList(2); + fields.add(DataTypes.createStructField("name", DataTypes.StringType, false)); + fields.add(DataTypes.createStructField("age", DataTypes.IntegerType, false)); + StructType schema = DataTypes.createStructType(fields); + + DataFrame df = sqlContext.applySchema(rowRDD, schema); + df.registerTempTable("people"); + List actual = sqlContext.sql("SELECT * FROM people").toJavaRDD().map(new Function() { + + public String call(Row row) { + return row.getString(0) + "_" + row.get(1).toString(); + } + }).collect(); + + List expected = new ArrayList(2); + expected.add("Michael_29"); + expected.add("Yin_28"); + + Assert.assertEquals(expected, actual); + } + + @Test + public void applySchemaToJSON() { + JavaRDD jsonRDD = javaCtx.parallelize(Arrays.asList( + "{\"string\":\"this is a simple string.\", \"integer\":10, \"long\":21474836470, " + + "\"bigInteger\":92233720368547758070, \"double\":1.7976931348623157E308, " + + "\"boolean\":true, \"null\":null}", + "{\"string\":\"this is another simple string.\", \"integer\":11, \"long\":21474836469, " + + "\"bigInteger\":92233720368547758069, \"double\":1.7976931348623157E305, " + + "\"boolean\":false, \"null\":null}")); + List fields = new ArrayList(7); + fields.add(DataTypes.createStructField("bigInteger", DataTypes.createDecimalType(), true)); + fields.add(DataTypes.createStructField("boolean", DataTypes.BooleanType, true)); + fields.add(DataTypes.createStructField("double", DataTypes.DoubleType, true)); + fields.add(DataTypes.createStructField("integer", DataTypes.LongType, true)); + fields.add(DataTypes.createStructField("long", DataTypes.LongType, true)); + fields.add(DataTypes.createStructField("null", DataTypes.StringType, true)); + fields.add(DataTypes.createStructField("string", DataTypes.StringType, true)); + StructType expectedSchema = DataTypes.createStructType(fields); + List expectedResult = new ArrayList(2); + expectedResult.add( + RowFactory.create( + new java.math.BigDecimal("92233720368547758070"), + true, + 1.7976931348623157E308, + 10, + 21474836470L, + null, + "this is a simple string.")); + expectedResult.add( + RowFactory.create( + new java.math.BigDecimal("92233720368547758069"), + false, + 1.7976931348623157E305, + 11, + 21474836469L, + null, + "this is another simple string.")); + + DataFrame df1 = sqlContext.jsonRDD(jsonRDD); + StructType actualSchema1 = df1.schema(); + Assert.assertEquals(expectedSchema, actualSchema1); + df1.registerTempTable("jsonTable1"); + List actual1 = sqlContext.sql("select * from jsonTable1").collectAsList(); + Assert.assertEquals(expectedResult, actual1); + + DataFrame df2 = sqlContext.jsonRDD(jsonRDD, expectedSchema); + StructType actualSchema2 = df2.schema(); + Assert.assertEquals(expectedSchema, actualSchema2); + df2.registerTempTable("jsonTable2"); + List actual2 = sqlContext.sql("select * from jsonTable2").collectAsList(); + Assert.assertEquals(expectedResult, actual2); + } +} diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java new file mode 100644 index 0000000000..c1c51f80d6 --- /dev/null +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java @@ -0,0 +1,84 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package test.org.apache.spark.sql; + +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.sql.*; +import org.apache.spark.sql.test.TestSQLContext$; +import static org.apache.spark.sql.functions.*; + + +public class JavaDataFrameSuite { + private transient SQLContext context; + + @Before + public void setUp() { + // Trigger static initializer of TestData + TestData$.MODULE$.testData(); + context = TestSQLContext$.MODULE$; + } + + @After + public void tearDown() { + context = null; + } + + @Test + public void testExecution() { + DataFrame df = context.table("testData").filter("key = 1"); + Assert.assertEquals(df.select("key").collect()[0].get(0), 1); + } + + /** + * See SPARK-5904. Abstract vararg methods defined in Scala do not work in Java. + */ + @Test + public void testVarargMethods() { + DataFrame df = context.table("testData"); + + df.toDF("key1", "value1"); + + df.select("key", "value"); + df.select(col("key"), col("value")); + df.selectExpr("key", "value + 1"); + + df.sort("key", "value"); + df.sort(col("key"), col("value")); + df.orderBy("key", "value"); + df.orderBy(col("key"), col("value")); + + df.groupBy("key", "value").agg(col("key"), col("value"), sum("value")); + df.groupBy(col("key"), col("value")).agg(col("key"), col("value"), sum("value")); + df.agg(first("key"), sum("value")); + + df.groupBy().avg("key"); + df.groupBy().mean("key"); + df.groupBy().max("key"); + df.groupBy().min("key"); + df.groupBy().sum("key"); + + // Varargs in column expressions + df.groupBy().agg(countDistinct("key", "value")); + df.groupBy().agg(countDistinct(col("key"), col("value"))); + df.select(coalesce(col("key"))); + } +} diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaRowSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaRowSuite.java new file mode 100644 index 0000000000..4ce1d1dddb --- /dev/null +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaRowSuite.java @@ -0,0 +1,179 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package test.org.apache.spark.sql; + +import java.math.BigDecimal; +import java.sql.Date; +import java.sql.Timestamp; +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; + +public class JavaRowSuite { + private byte byteValue; + private short shortValue; + private int intValue; + private long longValue; + private float floatValue; + private double doubleValue; + private BigDecimal decimalValue; + private boolean booleanValue; + private String stringValue; + private byte[] binaryValue; + private Date dateValue; + private Timestamp timestampValue; + + @Before + public void setUp() { + byteValue = (byte)127; + shortValue = (short)32767; + intValue = 2147483647; + longValue = 9223372036854775807L; + floatValue = (float)3.4028235E38; + doubleValue = 1.7976931348623157E308; + decimalValue = new BigDecimal("1.7976931348623157E328"); + booleanValue = true; + stringValue = "this is a string"; + binaryValue = stringValue.getBytes(); + dateValue = Date.valueOf("2014-06-30"); + timestampValue = Timestamp.valueOf("2014-06-30 09:20:00.0"); + } + + @Test + public void constructSimpleRow() { + Row simpleRow = RowFactory.create( + byteValue, // ByteType + new Byte(byteValue), + shortValue, // ShortType + new Short(shortValue), + intValue, // IntegerType + new Integer(intValue), + longValue, // LongType + new Long(longValue), + floatValue, // FloatType + new Float(floatValue), + doubleValue, // DoubleType + new Double(doubleValue), + decimalValue, // DecimalType + booleanValue, // BooleanType + new Boolean(booleanValue), + stringValue, // StringType + binaryValue, // BinaryType + dateValue, // DateType + timestampValue, // TimestampType + null // null + ); + + Assert.assertEquals(byteValue, simpleRow.getByte(0)); + Assert.assertEquals(byteValue, simpleRow.get(0)); + Assert.assertEquals(byteValue, simpleRow.getByte(1)); + Assert.assertEquals(byteValue, simpleRow.get(1)); + Assert.assertEquals(shortValue, simpleRow.getShort(2)); + Assert.assertEquals(shortValue, simpleRow.get(2)); + Assert.assertEquals(shortValue, simpleRow.getShort(3)); + Assert.assertEquals(shortValue, simpleRow.get(3)); + Assert.assertEquals(intValue, simpleRow.getInt(4)); + Assert.assertEquals(intValue, simpleRow.get(4)); + Assert.assertEquals(intValue, simpleRow.getInt(5)); + Assert.assertEquals(intValue, simpleRow.get(5)); + Assert.assertEquals(longValue, simpleRow.getLong(6)); + Assert.assertEquals(longValue, simpleRow.get(6)); + Assert.assertEquals(longValue, simpleRow.getLong(7)); + Assert.assertEquals(longValue, simpleRow.get(7)); + // When we create the row, we do not do any conversion + // for a float/double value, so we just set the delta to 0. + Assert.assertEquals(floatValue, simpleRow.getFloat(8), 0); + Assert.assertEquals(floatValue, simpleRow.get(8)); + Assert.assertEquals(floatValue, simpleRow.getFloat(9), 0); + Assert.assertEquals(floatValue, simpleRow.get(9)); + Assert.assertEquals(doubleValue, simpleRow.getDouble(10), 0); + Assert.assertEquals(doubleValue, simpleRow.get(10)); + Assert.assertEquals(doubleValue, simpleRow.getDouble(11), 0); + Assert.assertEquals(doubleValue, simpleRow.get(11)); + Assert.assertEquals(decimalValue, simpleRow.get(12)); + Assert.assertEquals(booleanValue, simpleRow.getBoolean(13)); + Assert.assertEquals(booleanValue, simpleRow.get(13)); + Assert.assertEquals(booleanValue, simpleRow.getBoolean(14)); + Assert.assertEquals(booleanValue, simpleRow.get(14)); + Assert.assertEquals(stringValue, simpleRow.getString(15)); + Assert.assertEquals(stringValue, simpleRow.get(15)); + Assert.assertEquals(binaryValue, simpleRow.get(16)); + Assert.assertEquals(dateValue, simpleRow.get(17)); + Assert.assertEquals(timestampValue, simpleRow.get(18)); + Assert.assertEquals(true, simpleRow.isNullAt(19)); + Assert.assertEquals(null, simpleRow.get(19)); + } + + @Test + public void constructComplexRow() { + // Simple array + List simpleStringArray = Arrays.asList( + stringValue + " (1)", stringValue + " (2)", stringValue + "(3)"); + + // Simple map + Map simpleMap = new HashMap(); + simpleMap.put(stringValue + " (1)", longValue); + simpleMap.put(stringValue + " (2)", longValue - 1); + simpleMap.put(stringValue + " (3)", longValue - 2); + + // Simple struct + Row simpleStruct = RowFactory.create( + doubleValue, stringValue, timestampValue, null); + + // Complex array + @SuppressWarnings("unchecked") + List> arrayOfMaps = Arrays.asList(simpleMap); + List arrayOfRows = Arrays.asList(simpleStruct); + + // Complex map + Map, Row> complexMap = new HashMap, Row>(); + complexMap.put(arrayOfRows, simpleStruct); + + // Complex struct + Row complexStruct = RowFactory.create( + simpleStringArray, + simpleMap, + simpleStruct, + arrayOfMaps, + arrayOfRows, + complexMap, + null); + Assert.assertEquals(simpleStringArray, complexStruct.get(0)); + Assert.assertEquals(simpleMap, complexStruct.get(1)); + Assert.assertEquals(simpleStruct, complexStruct.get(2)); + Assert.assertEquals(arrayOfMaps, complexStruct.get(3)); + Assert.assertEquals(arrayOfRows, complexStruct.get(4)); + Assert.assertEquals(complexMap, complexStruct.get(5)); + Assert.assertEquals(null, complexStruct.get(6)); + + // A very complex row + Row complexRow = RowFactory.create(arrayOfMaps, arrayOfRows, complexMap, complexStruct); + Assert.assertEquals(arrayOfMaps, complexRow.get(0)); + Assert.assertEquals(arrayOfRows, complexRow.get(1)); + Assert.assertEquals(complexMap, complexRow.get(2)); + Assert.assertEquals(complexStruct, complexRow.get(3)); + } +} diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDFSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDFSuite.java new file mode 100644 index 0000000000..79d92734ff --- /dev/null +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDFSuite.java @@ -0,0 +1,88 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package test.org.apache.spark.sql; + +import java.io.Serializable; + +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.api.java.UDF1; +import org.apache.spark.sql.api.java.UDF2; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.test.TestSQLContext$; +import org.apache.spark.sql.types.DataTypes; + +// The test suite itself is Serializable so that anonymous Function implementations can be +// serialized, as an alternative to converting these anonymous classes to static inner classes; +// see http://stackoverflow.com/questions/758570/. +public class JavaUDFSuite implements Serializable { + private transient JavaSparkContext sc; + private transient SQLContext sqlContext; + + @Before + public void setUp() { + sqlContext = TestSQLContext$.MODULE$; + sc = new JavaSparkContext(sqlContext.sparkContext()); + } + + @After + public void tearDown() { + } + + @SuppressWarnings("unchecked") + @Test + public void udf1Test() { + // With Java 8 lambdas: + // sqlContext.registerFunction( + // "stringLengthTest", (String str) -> str.length(), DataType.IntegerType); + + sqlContext.udf().register("stringLengthTest", new UDF1() { + @Override + public Integer call(String str) throws Exception { + return str.length(); + } + }, DataTypes.IntegerType); + + Row result = sqlContext.sql("SELECT stringLengthTest('test')").head(); + assert(result.getInt(0) == 4); + } + + @SuppressWarnings("unchecked") + @Test + public void udf2Test() { + // With Java 8 lambdas: + // sqlContext.registerFunction( + // "stringLengthTest", + // (String str1, String str2) -> str1.length() + str2.length, + // DataType.IntegerType); + + sqlContext.udf().register("stringLengthTest", new UDF2() { + @Override + public Integer call(String str1, String str2) throws Exception { + return str1.length() + str2.length(); + } + }, DataTypes.IntegerType); + + Row result = sqlContext.sql("SELECT stringLengthTest('test', 'test2')").head(); + assert(result.getInt(0) == 9); + } +} diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaSaveLoadSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaSaveLoadSuite.java new file mode 100644 index 0000000000..b76f7d421f --- /dev/null +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaSaveLoadSuite.java @@ -0,0 +1,98 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package test.org.apache.spark.sql.sources; + +import java.io.File; +import java.io.IOException; +import java.util.*; + +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.test.TestSQLContext$; +import org.apache.spark.sql.*; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.util.Utils; + +public class JavaSaveLoadSuite { + + private transient JavaSparkContext sc; + private transient SQLContext sqlContext; + + String originalDefaultSource; + File path; + DataFrame df; + + private void checkAnswer(DataFrame actual, List expected) { + String errorMessage = QueryTest$.MODULE$.checkAnswer(actual, expected); + if (errorMessage != null) { + Assert.fail(errorMessage); + } + } + + @Before + public void setUp() throws IOException { + sqlContext = TestSQLContext$.MODULE$; + sc = new JavaSparkContext(sqlContext.sparkContext()); + + originalDefaultSource = sqlContext.conf().defaultDataSourceName(); + path = + Utils.createTempDir(System.getProperty("java.io.tmpdir"), "datasource").getCanonicalFile(); + if (path.exists()) { + path.delete(); + } + + List jsonObjects = new ArrayList(10); + for (int i = 0; i < 10; i++) { + jsonObjects.add("{\"a\":" + i + ", \"b\":\"str" + i + "\"}"); + } + JavaRDD rdd = sc.parallelize(jsonObjects); + df = sqlContext.jsonRDD(rdd); + df.registerTempTable("jsonTable"); + } + + @Test + public void saveAndLoad() { + Map options = new HashMap(); + options.put("path", path.toString()); + df.save("org.apache.spark.sql.json", SaveMode.ErrorIfExists, options); + + DataFrame loadedDF = sqlContext.load("org.apache.spark.sql.json", options); + + checkAnswer(loadedDF, df.collectAsList()); + } + + @Test + public void saveAndLoadWithSchema() { + Map options = new HashMap(); + options.put("path", path.toString()); + df.save("org.apache.spark.sql.json", SaveMode.ErrorIfExists, options); + + List fields = new ArrayList(); + fields.add(DataTypes.createStructField("b", DataTypes.StringType, true)); + StructType schema = DataTypes.createStructType(fields); + DataFrame loadedDF = sqlContext.load("org.apache.spark.sql.json", schema, options); + + checkAnswer(loadedDF, sqlContext.sql("SELECT b FROM jsonTable").collectAsList()); + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index e71e9bee3a..30e77e4ef3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -411,7 +411,7 @@ class DataFrameSuite extends QueryTest { ) } - test("addColumn") { + test("withColumn") { val df = testData.toDF().withColumn("newCol", col("key") + 1) checkAnswer( df, @@ -421,7 +421,7 @@ class DataFrameSuite extends QueryTest { assert(df.schema.map(_.name).toSeq === Seq("key", "value", "newCol")) } - test("renameColumn") { + test("withColumnRenamed") { val df = testData.toDF().withColumn("newCol", col("key") + 1) .withColumnRenamed("value", "valueRenamed") checkAnswer( -- cgit v1.2.3