author     Davies Liu <davies@databricks.com>        2015-02-10 19:40:12 -0800
committer  Michael Armbrust <michael@databricks.com> 2015-02-10 19:40:12 -0800
commit     ea60284095cad43aa7ac98256576375d0e91a52a (patch)
tree       35ac6e3935e1e7c731f7b9a850f2daa9640387d1 /sql
parent     a60aea86b4d4b716b5ec3bff776b509fe0831342 (diff)
download   spark-ea60284095cad43aa7ac98256576375d0e91a52a.tar.gz
           spark-ea60284095cad43aa7ac98256576375d0e91a52a.tar.bz2
           spark-ea60284095cad43aa7ac98256576375d0e91a52a.zip
[SPARK-5704] [SQL] [PySpark] createDataFrame from RDD with columns
Deprecate inferSchema() and applySchema(); use createDataFrame() instead, which can take an optional `schema` to create a DataFrame from an RDD. The `schema` can be a StructType or a list of column names.

Author: Davies Liu <davies@databricks.com>

Closes #4498 from davies/create and squashes the following commits:

08469c1 [Davies Liu] remove Scala/Java API for now
c80a7a9 [Davies Liu] fix hive test
d1bd8f2 [Davies Liu] cleanup applySchema
9526e97 [Davies Liu] createDataFrame from RDD with columns
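The new entry point is used roughly as in the docstring example below. This is a minimal sketch (not part of the patch): it assumes a live SparkContext `sc`, the Spark 1.3-era SQL types, and the bundled people.txt example file.

    import org.apache.spark.sql.{Row, SQLContext}
    import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}

    val sqlContext = new SQLContext(sc)

    // Build a Row RDD and an explicit schema, then call the new createDataFrame()
    // instead of the now-deprecated applySchema().
    val schema = StructType(
      StructField("name", StringType, false) ::
      StructField("age", IntegerType, true) :: Nil)

    val rows = sc.textFile("examples/src/main/resources/people.txt")
      .map(_.split(","))
      .map(p => Row(p(0), p(1).trim.toInt))

    val people = sqlContext.createDataFrame(rows, schema)
    people.registerTempTable("people")
    sqlContext.sql("SELECT name FROM people").collect().foreach(println)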
Diffstat (limited to 'sql')
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala                        | 95
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala             |  4
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala                     |  9
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala            |  2
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala               | 18
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala                    |  4
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala     |  8
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala      |  4
8 files changed, 111 insertions, 33 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index 801505bceb..523911d108 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -243,7 +243,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
* val people =
* sc.textFile("examples/src/main/resources/people.txt").map(
* _.split(",")).map(p => Row(p(0), p(1).trim.toInt))
- * val dataFrame = sqlContext. applySchema(people, schema)
+ * val dataFrame = sqlContext.createDataFrame(people, schema)
* dataFrame.printSchema
* // root
* // |-- name: string (nullable = false)
@@ -252,11 +252,9 @@ class SQLContext(@transient val sparkContext: SparkContext)
* dataFrame.registerTempTable("people")
* sqlContext.sql("select name from people").collect.foreach(println)
* }}}
- *
- * @group userf
*/
@DeveloperApi
- def applySchema(rowRDD: RDD[Row], schema: StructType): DataFrame = {
+ def createDataFrame(rowRDD: RDD[Row], schema: StructType): DataFrame = {
// TODO: use MutableProjection when rowRDD is another DataFrame and the applied
// schema differs from the existing schema on any field data type.
val logicalPlan = LogicalRDD(schema.toAttributes, rowRDD)(self)
@@ -264,8 +262,21 @@ class SQLContext(@transient val sparkContext: SparkContext)
}
@DeveloperApi
- def applySchema(rowRDD: JavaRDD[Row], schema: StructType): DataFrame = {
- applySchema(rowRDD.rdd, schema);
+ def createDataFrame(rowRDD: JavaRDD[Row], schema: StructType): DataFrame = {
+ createDataFrame(rowRDD.rdd, schema)
+ }
+
+ /**
+ * Creates a [[DataFrame]] from a [[JavaRDD]] containing [[Row]]s by applying
+ * a sequence of column names to this RDD; the data type of each column will
+ * be inferred from the first row.
+ *
+ * @param rowRDD a JavaRDD of Row
+ * @param columns names for each column
+ * @return DataFrame
+ */
+ def createDataFrame(rowRDD: JavaRDD[Row], columns: java.util.List[String]): DataFrame = {
+ createDataFrame(rowRDD.rdd, columns.toSeq)
}
/**
@@ -274,7 +285,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
* WARNING: Since there is no guaranteed ordering for fields in a Java Bean,
* SELECT * queries will return the columns in an undefined order.
*/
- def applySchema(rdd: RDD[_], beanClass: Class[_]): DataFrame = {
+ def createDataFrame(rdd: RDD[_], beanClass: Class[_]): DataFrame = {
val attributeSeq = getSchema(beanClass)
val className = beanClass.getName
val rowRdd = rdd.mapPartitions { iter =>
@@ -301,8 +312,72 @@ class SQLContext(@transient val sparkContext: SparkContext)
* WARNING: Since there is no guaranteed ordering for fields in a Java Bean,
* SELECT * queries will return the columns in an undefined order.
*/
+ def createDataFrame(rdd: JavaRDD[_], beanClass: Class[_]): DataFrame = {
+ createDataFrame(rdd.rdd, beanClass)
+ }
+
+ /**
+ * :: DeveloperApi ::
+ * Creates a [[DataFrame]] from an [[RDD]] containing [[Row]]s by applying a schema to this RDD.
+ * It is important to make sure that the structure of every [[Row]] of the provided RDD matches
+ * the provided schema. Otherwise, a runtime exception will be thrown.
+ * Example:
+ * {{{
+ * import org.apache.spark.sql._
+ * val sqlContext = new org.apache.spark.sql.SQLContext(sc)
+ *
+ * val schema =
+ * StructType(
+ * StructField("name", StringType, false) ::
+ * StructField("age", IntegerType, true) :: Nil)
+ *
+ * val people =
+ * sc.textFile("examples/src/main/resources/people.txt").map(
+ * _.split(",")).map(p => Row(p(0), p(1).trim.toInt))
+ * val dataFrame = sqlContext.applySchema(people, schema)
+ * dataFrame.printSchema
+ * // root
+ * // |-- name: string (nullable = false)
+ * // |-- age: integer (nullable = true)
+ *
+ * dataFrame.registerTempTable("people")
+ * sqlContext.sql("select name from people").collect.foreach(println)
+ * }}}
+ *
+ * @group userf
+ */
+ @DeveloperApi
+ @deprecated("use createDataFrame", "1.3.0")
+ def applySchema(rowRDD: RDD[Row], schema: StructType): DataFrame = {
+ createDataFrame(rowRDD, schema)
+ }
+
+ @DeveloperApi
+ @deprecated("use createDataFrame", "1.3.0")
+ def applySchema(rowRDD: JavaRDD[Row], schema: StructType): DataFrame = {
+ createDataFrame(rowRDD, schema)
+ }
+
+ /**
+ * Applies a schema to an RDD of Java Beans.
+ *
+ * WARNING: Since there is no guaranteed ordering for fields in a Java Bean,
+ * SELECT * queries will return the columns in an undefined order.
+ */
+ @deprecated("use createDataFrame", "1.3.0")
+ def applySchema(rdd: RDD[_], beanClass: Class[_]): DataFrame = {
+ createDataFrame(rdd, beanClass)
+ }
+
+ /**
+ * Applies a schema to an RDD of Java Beans.
+ *
+ * WARNING: Since there is no guaranteed ordering for fields in a Java Bean,
+ * SELECT * queries will return the columns in an undefined order.
+ */
+ @deprecated("use createDataFrame", "1.3.0")
def applySchema(rdd: JavaRDD[_], beanClass: Class[_]): DataFrame = {
- applySchema(rdd.rdd, beanClass)
+ createDataFrame(rdd, beanClass)
}
/**
@@ -375,7 +450,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
JsonRDD.nullTypeToStringType(
JsonRDD.inferSchema(json, 1.0, columnNameOfCorruptJsonRecord)))
val rowRDD = JsonRDD.jsonStringToRow(json, appliedSchema, columnNameOfCorruptJsonRecord)
- applySchema(rowRDD, appliedSchema)
+ createDataFrame(rowRDD, appliedSchema)
}
@Experimental
@@ -393,7 +468,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
JsonRDD.nullTypeToStringType(
JsonRDD.inferSchema(json, samplingRatio, columnNameOfCorruptJsonRecord))
val rowRDD = JsonRDD.jsonStringToRow(json, appliedSchema, columnNameOfCorruptJsonRecord)
- applySchema(rowRDD, appliedSchema)
+ createDataFrame(rowRDD, appliedSchema)
}
@Experimental
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
index fa4cdecbcb..1d71039872 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
@@ -180,7 +180,7 @@ class ColumnExpressionSuite extends QueryTest {
}
test("!==") {
- val nullData = TestSQLContext.applySchema(TestSQLContext.sparkContext.parallelize(
+ val nullData = TestSQLContext.createDataFrame(TestSQLContext.sparkContext.parallelize(
Row(1, 1) ::
Row(1, 2) ::
Row(1, null) ::
@@ -240,7 +240,7 @@ class ColumnExpressionSuite extends QueryTest {
testData2.collect().toSeq.filter(r => r.getInt(0) <= r.getInt(1)))
}
- val booleanData = TestSQLContext.applySchema(TestSQLContext.sparkContext.parallelize(
+ val booleanData = TestSQLContext.createDataFrame(TestSQLContext.sparkContext.parallelize(
Row(false, false) ::
Row(false, true) ::
Row(true, false) ::
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 55fd0b0892..bba8899651 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -34,6 +34,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
TestData
import org.apache.spark.sql.test.TestSQLContext.implicits._
+ val sqlCtx = TestSQLContext
test("SPARK-4625 support SORT BY in SimpleSQLParser & DSL") {
checkAnswer(
@@ -669,7 +670,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
Row(values(0).toInt, values(1), values(2).toBoolean, v4)
}
- val df1 = applySchema(rowRDD1, schema1)
+ val df1 = sqlCtx.createDataFrame(rowRDD1, schema1)
df1.registerTempTable("applySchema1")
checkAnswer(
sql("SELECT * FROM applySchema1"),
@@ -699,7 +700,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
Row(Row(values(0).toInt, values(2).toBoolean), Map(values(1) -> v4))
}
- val df2 = applySchema(rowRDD2, schema2)
+ val df2 = sqlCtx.createDataFrame(rowRDD2, schema2)
df2.registerTempTable("applySchema2")
checkAnswer(
sql("SELECT * FROM applySchema2"),
@@ -724,7 +725,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
Row(Row(values(0).toInt, values(2).toBoolean), scala.collection.mutable.Map(values(1) -> v4))
}
- val df3 = applySchema(rowRDD3, schema2)
+ val df3 = sqlCtx.createDataFrame(rowRDD3, schema2)
df3.registerTempTable("applySchema3")
checkAnswer(
@@ -769,7 +770,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
.build()
val schemaWithMeta = new StructType(Array(
schema("id"), schema("name").copy(metadata = metadata), schema("age")))
- val personWithMeta = applySchema(person.rdd, schemaWithMeta)
+ val personWithMeta = sqlCtx.createDataFrame(person.rdd, schemaWithMeta)
def validateMetadata(rdd: DataFrame): Unit = {
assert(rdd.schema("name").metadata.getString(docKey) == docValue)
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
index df108a9d26..c3210733f1 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
@@ -71,7 +71,7 @@ class PlannerSuite extends FunSuite {
val schema = StructType(fields)
val row = Row.fromSeq(Seq.fill(fields.size)(null))
val rowRDD = org.apache.spark.sql.test.TestSQLContext.sparkContext.parallelize(row :: Nil)
- applySchema(rowRDD, schema).registerTempTable("testLimit")
+ createDataFrame(rowRDD, schema).registerTempTable("testLimit")
val planned = sql(
"""
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala
index e581ac9b50..21e7093610 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala
@@ -54,7 +54,7 @@ class JDBCWriteSuite extends FunSuite with BeforeAndAfter {
StructField("seq", IntegerType) :: Nil)
test("Basic CREATE") {
- val srdd = TestSQLContext.applySchema(sc.parallelize(arr2x2), schema2)
+ val srdd = TestSQLContext.createDataFrame(sc.parallelize(arr2x2), schema2)
srdd.createJDBCTable(url, "TEST.BASICCREATETEST", false)
assert(2 == TestSQLContext.jdbcRDD(url, "TEST.BASICCREATETEST").count)
@@ -62,8 +62,8 @@ class JDBCWriteSuite extends FunSuite with BeforeAndAfter {
}
test("CREATE with overwrite") {
- val srdd = TestSQLContext.applySchema(sc.parallelize(arr2x3), schema3)
- val srdd2 = TestSQLContext.applySchema(sc.parallelize(arr1x2), schema2)
+ val srdd = TestSQLContext.createDataFrame(sc.parallelize(arr2x3), schema3)
+ val srdd2 = TestSQLContext.createDataFrame(sc.parallelize(arr1x2), schema2)
srdd.createJDBCTable(url, "TEST.DROPTEST", false)
assert(2 == TestSQLContext.jdbcRDD(url, "TEST.DROPTEST").count)
@@ -75,8 +75,8 @@ class JDBCWriteSuite extends FunSuite with BeforeAndAfter {
}
test("CREATE then INSERT to append") {
- val srdd = TestSQLContext.applySchema(sc.parallelize(arr2x2), schema2)
- val srdd2 = TestSQLContext.applySchema(sc.parallelize(arr1x2), schema2)
+ val srdd = TestSQLContext.createDataFrame(sc.parallelize(arr2x2), schema2)
+ val srdd2 = TestSQLContext.createDataFrame(sc.parallelize(arr1x2), schema2)
srdd.createJDBCTable(url, "TEST.APPENDTEST", false)
srdd2.insertIntoJDBC(url, "TEST.APPENDTEST", false)
@@ -85,8 +85,8 @@ class JDBCWriteSuite extends FunSuite with BeforeAndAfter {
}
test("CREATE then INSERT to truncate") {
- val srdd = TestSQLContext.applySchema(sc.parallelize(arr2x2), schema2)
- val srdd2 = TestSQLContext.applySchema(sc.parallelize(arr1x2), schema2)
+ val srdd = TestSQLContext.createDataFrame(sc.parallelize(arr2x2), schema2)
+ val srdd2 = TestSQLContext.createDataFrame(sc.parallelize(arr1x2), schema2)
srdd.createJDBCTable(url, "TEST.TRUNCATETEST", false)
srdd2.insertIntoJDBC(url, "TEST.TRUNCATETEST", true)
@@ -95,8 +95,8 @@ class JDBCWriteSuite extends FunSuite with BeforeAndAfter {
}
test("Incompatible INSERT to append") {
- val srdd = TestSQLContext.applySchema(sc.parallelize(arr2x2), schema2)
- val srdd2 = TestSQLContext.applySchema(sc.parallelize(arr2x3), schema3)
+ val srdd = TestSQLContext.createDataFrame(sc.parallelize(arr2x2), schema2)
+ val srdd2 = TestSQLContext.createDataFrame(sc.parallelize(arr2x3), schema3)
srdd.createJDBCTable(url, "TEST.INCOMPATIBLETEST", false)
intercept[org.apache.spark.SparkException] {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala
index 4fc92e3e3b..fde4b47438 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala
@@ -820,7 +820,7 @@ class JsonSuite extends QueryTest {
Row(values(0).toInt, values(1), values(2).toBoolean, r.split(",").toList, v5)
}
- val df1 = applySchema(rowRDD1, schema1)
+ val df1 = createDataFrame(rowRDD1, schema1)
df1.registerTempTable("applySchema1")
val df2 = df1.toDataFrame
val result = df2.toJSON.collect()
@@ -841,7 +841,7 @@ class JsonSuite extends QueryTest {
Row(Row(values(0).toInt, values(2).toBoolean), Map(values(1) -> v4))
}
- val df3 = applySchema(rowRDD2, schema2)
+ val df3 = createDataFrame(rowRDD2, schema2)
df3.registerTempTable("applySchema2")
val df4 = df3.toDataFrame
val result2 = df4.toJSON.collect()
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala
index 43da7519ac..89b18c3439 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala
@@ -97,7 +97,7 @@ class InsertIntoHiveTableSuite extends QueryTest with BeforeAndAfter {
val schema = StructType(StructField("m", MapType(StringType, StringType), true) :: Nil)
val rowRDD = TestHive.sparkContext.parallelize(
(1 to 100).map(i => Row(scala.collection.mutable.HashMap(s"key$i" -> s"value$i"))))
- val df = applySchema(rowRDD, schema)
+ val df = TestHive.createDataFrame(rowRDD, schema)
df.registerTempTable("tableWithMapValue")
sql("CREATE TABLE hiveTableWithMapValue(m MAP <STRING, STRING>)")
sql("INSERT OVERWRITE TABLE hiveTableWithMapValue SELECT m FROM tableWithMapValue")
@@ -142,7 +142,7 @@ class InsertIntoHiveTableSuite extends QueryTest with BeforeAndAfter {
val schema = StructType(Seq(
StructField("a", ArrayType(StringType, containsNull = false))))
val rowRDD = TestHive.sparkContext.parallelize((1 to 100).map(i => Row(Seq(s"value$i"))))
- val df = applySchema(rowRDD, schema)
+ val df = TestHive.createDataFrame(rowRDD, schema)
df.registerTempTable("tableWithArrayValue")
sql("CREATE TABLE hiveTableWithArrayValue(a Array <STRING>)")
sql("INSERT OVERWRITE TABLE hiveTableWithArrayValue SELECT a FROM tableWithArrayValue")
@@ -159,7 +159,7 @@ class InsertIntoHiveTableSuite extends QueryTest with BeforeAndAfter {
StructField("m", MapType(StringType, StringType, valueContainsNull = false))))
val rowRDD = TestHive.sparkContext.parallelize(
(1 to 100).map(i => Row(Map(s"key$i" -> s"value$i"))))
- val df = applySchema(rowRDD, schema)
+ val df = TestHive.createDataFrame(rowRDD, schema)
df.registerTempTable("tableWithMapValue")
sql("CREATE TABLE hiveTableWithMapValue(m Map <STRING, STRING>)")
sql("INSERT OVERWRITE TABLE hiveTableWithMapValue SELECT m FROM tableWithMapValue")
@@ -176,7 +176,7 @@ class InsertIntoHiveTableSuite extends QueryTest with BeforeAndAfter {
StructField("s", StructType(Seq(StructField("f", StringType, nullable = false))))))
val rowRDD = TestHive.sparkContext.parallelize(
(1 to 100).map(i => Row(Row(s"value$i"))))
- val df = applySchema(rowRDD, schema)
+ val df = TestHive.createDataFrame(rowRDD, schema)
df.registerTempTable("tableWithStructValue")
sql("CREATE TABLE hiveTableWithStructValue(s Struct <f: STRING>)")
sql("INSERT OVERWRITE TABLE hiveTableWithStructValue SELECT s FROM tableWithStructValue")
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
index 49fe79d989..9a6e8650a0 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
@@ -18,6 +18,7 @@
package org.apache.spark.sql.hive.execution
import org.apache.spark.sql.hive.HiveShim
+import org.apache.spark.sql.hive.test.TestHive
import org.apache.spark.sql.hive.test.TestHive._
import org.apache.spark.sql.types._
import org.apache.spark.sql.{QueryTest, Row, SQLConf}
@@ -34,6 +35,7 @@ case class Nested3(f3: Int)
class SQLQuerySuite extends QueryTest {
import org.apache.spark.sql.hive.test.TestHive.implicits._
+ val sqlCtx = TestHive
test("SPARK-4512 Fix attribute reference resolution error when using SORT BY") {
checkAnswer(
@@ -277,7 +279,7 @@ class SQLQuerySuite extends QueryTest {
val rowRdd = sparkContext.parallelize(row :: Nil)
- applySchema(rowRdd, schema).registerTempTable("testTable")
+ sqlCtx.createDataFrame(rowRdd, schema).registerTempTable("testTable")
sql(
"""CREATE TABLE nullValuesInInnerComplexTypes