author     Davies Liu <davies@databricks.com>  2015-02-10 19:40:12 -0800
committer  Michael Armbrust <michael@databricks.com>  2015-02-10 19:40:12 -0800
commit     ea60284095cad43aa7ac98256576375d0e91a52a (patch)
tree       35ac6e3935e1e7c731f7b9a850f2daa9640387d1 /sql/core
parent     a60aea86b4d4b716b5ec3bff776b509fe0831342 (diff)
[SPARK-5704] [SQL] [PySpark] createDataFrame from RDD with columns
Deprecate inferSchema() and applySchema(); use createDataFrame() instead, which can take an optional `schema` to create a DataFrame from an RDD. The `schema` can be a StructType or a list of column names.

Author: Davies Liu <davies@databricks.com>

Closes #4498 from davies/create and squashes the following commits:

08469c1 [Davies Liu] remove Scala/Java API for now
c80a7a9 [Davies Liu] fix hive test
d1bd8f2 [Davies Liu] cleanup applySchema
9526e97 [Davies Liu] createDataFrame from RDD with columns
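For orientation, a minimal sketch of the renamed entry point, following the Scaladoc example updated in this patch. It assumes an existing SparkContext `sc`; the sample data and the extra `types` import are illustrative additions, not part of the patch:

    import org.apache.spark.sql._
    import org.apache.spark.sql.types._

    val sqlContext = new SQLContext(sc)

    // Explicit schema, exactly as before; only the method name changes.
    val schema = StructType(
      StructField("name", StringType, false) ::
      StructField("age", IntegerType, true) :: Nil)

    val people = sc.parallelize(Seq("Alice, 30", "Bob, 25"))
      .map(_.split(","))
      .map(p => Row(p(0), p(1).trim.toInt))

    // New name introduced by this patch:
    val df = sqlContext.createDataFrame(people, schema)

    // Old name still compiles, but is @deprecated as of 1.3.0:
    // val df2 = sqlContext.applySchema(people, schema)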
Diffstat (limited to 'sql/core')
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala | 95
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala | 4
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala | 9
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala | 2
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala | 18
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala | 4
6 files changed, 104 insertions, 28 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index 801505bceb..523911d108 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -243,7 +243,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
* val people =
* sc.textFile("examples/src/main/resources/people.txt").map(
* _.split(",")).map(p => Row(p(0), p(1).trim.toInt))
- * val dataFrame = sqlContext. applySchema(people, schema)
+ * val dataFrame = sqlContext.createDataFrame(people, schema)
* dataFrame.printSchema
* // root
* // |-- name: string (nullable = false)
@@ -252,11 +252,9 @@ class SQLContext(@transient val sparkContext: SparkContext)
* dataFrame.registerTempTable("people")
* sqlContext.sql("select name from people").collect.foreach(println)
* }}}
- *
- * @group userf
*/
@DeveloperApi
- def applySchema(rowRDD: RDD[Row], schema: StructType): DataFrame = {
+ def createDataFrame(rowRDD: RDD[Row], schema: StructType): DataFrame = {
// TODO: use MutableProjection when rowRDD is another DataFrame and the applied
// schema differs from the existing schema on any field data type.
val logicalPlan = LogicalRDD(schema.toAttributes, rowRDD)(self)
@@ -264,8 +262,21 @@ class SQLContext(@transient val sparkContext: SparkContext)
}
@DeveloperApi
- def applySchema(rowRDD: JavaRDD[Row], schema: StructType): DataFrame = {
- applySchema(rowRDD.rdd, schema);
+ def createDataFrame(rowRDD: JavaRDD[Row], schema: StructType): DataFrame = {
+ createDataFrame(rowRDD.rdd, schema)
+ }
+
+ /**
+ * Creates a [[DataFrame]] from a [[JavaRDD]] containing [[Row]]s by applying
+ * a sequence of column names to this RDD; the data type of each column will
+ * be inferred from the first row.
+ *
+ * @param rowRDD a JavaRDD of Row
+ * @param columns names for each column
+ * @return DataFrame
+ */
+ def createDataFrame(rowRDD: JavaRDD[Row], columns: java.util.List[String]): DataFrame = {
+ createDataFrame(rowRDD.rdd, columns.toSeq)
}
/**
@@ -274,7 +285,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
* WARNING: Since there is no guaranteed ordering for fields in a Java Bean,
* SELECT * queries will return the columns in an undefined order.
*/
- def applySchema(rdd: RDD[_], beanClass: Class[_]): DataFrame = {
+ def createDataFrame(rdd: RDD[_], beanClass: Class[_]): DataFrame = {
val attributeSeq = getSchema(beanClass)
val className = beanClass.getName
val rowRdd = rdd.mapPartitions { iter =>
@@ -301,8 +312,72 @@ class SQLContext(@transient val sparkContext: SparkContext)
* WARNING: Since there is no guaranteed ordering for fields in a Java Bean,
* SELECT * queries will return the columns in an undefined order.
*/
+ def createDataFrame(rdd: JavaRDD[_], beanClass: Class[_]): DataFrame = {
+ createDataFrame(rdd.rdd, beanClass)
+ }
+
+ /**
+ * :: DeveloperApi ::
+ * Creates a [[DataFrame]] from an [[RDD]] containing [[Row]]s by applying a schema to this RDD.
+ * It is important to make sure that the structure of every [[Row]] of the provided RDD matches
+ * the provided schema. Otherwise, there will be a runtime exception.
+ * Example:
+ * {{{
+ * import org.apache.spark.sql._
+ * val sqlContext = new org.apache.spark.sql.SQLContext(sc)
+ *
+ * val schema =
+ * StructType(
+ * StructField("name", StringType, false) ::
+ * StructField("age", IntegerType, true) :: Nil)
+ *
+ * val people =
+ * sc.textFile("examples/src/main/resources/people.txt").map(
+ * _.split(",")).map(p => Row(p(0), p(1).trim.toInt))
+ * val dataFrame = sqlContext.applySchema(people, schema)
+ * dataFrame.printSchema
+ * // root
+ * // |-- name: string (nullable = false)
+ * // |-- age: integer (nullable = true)
+ *
+ * dataFrame.registerTempTable("people")
+ * sqlContext.sql("select name from people").collect.foreach(println)
+ * }}}
+ *
+ * @group userf
+ */
+ @DeveloperApi
+ @deprecated("use createDataFrame", "1.3.0")
+ def applySchema(rowRDD: RDD[Row], schema: StructType): DataFrame = {
+ createDataFrame(rowRDD, schema)
+ }
+
+ @DeveloperApi
+ @deprecated("use createDataFrame", "1.3.0")
+ def applySchema(rowRDD: JavaRDD[Row], schema: StructType): DataFrame = {
+ createDataFrame(rowRDD, schema)
+ }
+
+ /**
+ * Applies a schema to an RDD of Java Beans.
+ *
+ * WARNING: Since there is no guaranteed ordering for fields in a Java Bean,
+ * SELECT * queries will return the columns in an undefined order.
+ */
+ @deprecated("use createDataFrame", "1.3.0")
+ def applySchema(rdd: RDD[_], beanClass: Class[_]): DataFrame = {
+ createDataFrame(rdd, beanClass)
+ }
+
+ /**
+ * Applies a schema to an RDD of Java Beans.
+ *
+ * WARNING: Since there is no guaranteed ordering for fields in a Java Bean,
+ * SELECT * queries will return the columns in an undefined order.
+ */
+ @deprecated("use createDataFrame", "1.3.0")
def applySchema(rdd: JavaRDD[_], beanClass: Class[_]): DataFrame = {
- applySchema(rdd.rdd, beanClass)
+ createDataFrame(rdd, beanClass)
}
/**
@@ -375,7 +450,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
JsonRDD.nullTypeToStringType(
JsonRDD.inferSchema(json, 1.0, columnNameOfCorruptJsonRecord)))
val rowRDD = JsonRDD.jsonStringToRow(json, appliedSchema, columnNameOfCorruptJsonRecord)
- applySchema(rowRDD, appliedSchema)
+ createDataFrame(rowRDD, appliedSchema)
}
@Experimental
@@ -393,7 +468,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
JsonRDD.nullTypeToStringType(
JsonRDD.inferSchema(json, samplingRatio, columnNameOfCorruptJsonRecord))
val rowRDD = JsonRDD.jsonStringToRow(json, appliedSchema, columnNameOfCorruptJsonRecord)
- applySchema(rowRDD, appliedSchema)
+ createDataFrame(rowRDD, appliedSchema)
}
@Experimental
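The Java-friendly overload added above, createDataFrame(rowRDD: JavaRDD[Row], columns: java.util.List[String]), infers each column's type from the first Row. A minimal sketch under the same assumptions as the example near the top of this page (an existing `sc` and `sqlContext`); names and data are illustrative:

    import java.util.Arrays
    import org.apache.spark.sql.Row

    // Types are sampled from the first Row only, so every row must be consistently typed.
    val javaRows = sc.parallelize(Seq(Row("Alice", 30), Row("Bob", 25))).toJavaRDD()

    val df = sqlContext.createDataFrame(javaRows, Arrays.asList("name", "age"))
    df.printSchema()   // name should come back as string, age as integer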
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
index fa4cdecbcb..1d71039872 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
@@ -180,7 +180,7 @@ class ColumnExpressionSuite extends QueryTest {
}
test("!==") {
- val nullData = TestSQLContext.applySchema(TestSQLContext.sparkContext.parallelize(
+ val nullData = TestSQLContext.createDataFrame(TestSQLContext.sparkContext.parallelize(
Row(1, 1) ::
Row(1, 2) ::
Row(1, null) ::
@@ -240,7 +240,7 @@ class ColumnExpressionSuite extends QueryTest {
testData2.collect().toSeq.filter(r => r.getInt(0) <= r.getInt(1)))
}
- val booleanData = TestSQLContext.applySchema(TestSQLContext.sparkContext.parallelize(
+ val booleanData = TestSQLContext.createDataFrame(TestSQLContext.sparkContext.parallelize(
Row(false, false) ::
Row(false, true) ::
Row(true, false) ::
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 55fd0b0892..bba8899651 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -34,6 +34,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
TestData
import org.apache.spark.sql.test.TestSQLContext.implicits._
+ val sqlCtx = TestSQLContext
test("SPARK-4625 support SORT BY in SimpleSQLParser & DSL") {
checkAnswer(
@@ -669,7 +670,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
Row(values(0).toInt, values(1), values(2).toBoolean, v4)
}
- val df1 = applySchema(rowRDD1, schema1)
+ val df1 = sqlCtx.createDataFrame(rowRDD1, schema1)
df1.registerTempTable("applySchema1")
checkAnswer(
sql("SELECT * FROM applySchema1"),
@@ -699,7 +700,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
Row(Row(values(0).toInt, values(2).toBoolean), Map(values(1) -> v4))
}
- val df2 = applySchema(rowRDD2, schema2)
+ val df2 = sqlCtx.createDataFrame(rowRDD2, schema2)
df2.registerTempTable("applySchema2")
checkAnswer(
sql("SELECT * FROM applySchema2"),
@@ -724,7 +725,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
Row(Row(values(0).toInt, values(2).toBoolean), scala.collection.mutable.Map(values(1) -> v4))
}
- val df3 = applySchema(rowRDD3, schema2)
+ val df3 = sqlCtx.createDataFrame(rowRDD3, schema2)
df3.registerTempTable("applySchema3")
checkAnswer(
@@ -769,7 +770,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
.build()
val schemaWithMeta = new StructType(Array(
schema("id"), schema("name").copy(metadata = metadata), schema("age")))
- val personWithMeta = applySchema(person.rdd, schemaWithMeta)
+ val personWithMeta = sqlCtx.createDataFrame(person.rdd, schemaWithMeta)
def validateMetadata(rdd: DataFrame): Unit = {
assert(rdd.schema("name").metadata.getString(docKey) == docValue)
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
index df108a9d26..c3210733f1 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
@@ -71,7 +71,7 @@ class PlannerSuite extends FunSuite {
val schema = StructType(fields)
val row = Row.fromSeq(Seq.fill(fields.size)(null))
val rowRDD = org.apache.spark.sql.test.TestSQLContext.sparkContext.parallelize(row :: Nil)
- applySchema(rowRDD, schema).registerTempTable("testLimit")
+ createDataFrame(rowRDD, schema).registerTempTable("testLimit")
val planned = sql(
"""
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala
index e581ac9b50..21e7093610 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala
@@ -54,7 +54,7 @@ class JDBCWriteSuite extends FunSuite with BeforeAndAfter {
StructField("seq", IntegerType) :: Nil)
test("Basic CREATE") {
- val srdd = TestSQLContext.applySchema(sc.parallelize(arr2x2), schema2)
+ val srdd = TestSQLContext.createDataFrame(sc.parallelize(arr2x2), schema2)
srdd.createJDBCTable(url, "TEST.BASICCREATETEST", false)
assert(2 == TestSQLContext.jdbcRDD(url, "TEST.BASICCREATETEST").count)
@@ -62,8 +62,8 @@ class JDBCWriteSuite extends FunSuite with BeforeAndAfter {
}
test("CREATE with overwrite") {
- val srdd = TestSQLContext.applySchema(sc.parallelize(arr2x3), schema3)
- val srdd2 = TestSQLContext.applySchema(sc.parallelize(arr1x2), schema2)
+ val srdd = TestSQLContext.createDataFrame(sc.parallelize(arr2x3), schema3)
+ val srdd2 = TestSQLContext.createDataFrame(sc.parallelize(arr1x2), schema2)
srdd.createJDBCTable(url, "TEST.DROPTEST", false)
assert(2 == TestSQLContext.jdbcRDD(url, "TEST.DROPTEST").count)
@@ -75,8 +75,8 @@ class JDBCWriteSuite extends FunSuite with BeforeAndAfter {
}
test("CREATE then INSERT to append") {
- val srdd = TestSQLContext.applySchema(sc.parallelize(arr2x2), schema2)
- val srdd2 = TestSQLContext.applySchema(sc.parallelize(arr1x2), schema2)
+ val srdd = TestSQLContext.createDataFrame(sc.parallelize(arr2x2), schema2)
+ val srdd2 = TestSQLContext.createDataFrame(sc.parallelize(arr1x2), schema2)
srdd.createJDBCTable(url, "TEST.APPENDTEST", false)
srdd2.insertIntoJDBC(url, "TEST.APPENDTEST", false)
@@ -85,8 +85,8 @@ class JDBCWriteSuite extends FunSuite with BeforeAndAfter {
}
test("CREATE then INSERT to truncate") {
- val srdd = TestSQLContext.applySchema(sc.parallelize(arr2x2), schema2)
- val srdd2 = TestSQLContext.applySchema(sc.parallelize(arr1x2), schema2)
+ val srdd = TestSQLContext.createDataFrame(sc.parallelize(arr2x2), schema2)
+ val srdd2 = TestSQLContext.createDataFrame(sc.parallelize(arr1x2), schema2)
srdd.createJDBCTable(url, "TEST.TRUNCATETEST", false)
srdd2.insertIntoJDBC(url, "TEST.TRUNCATETEST", true)
@@ -95,8 +95,8 @@ class JDBCWriteSuite extends FunSuite with BeforeAndAfter {
}
test("Incompatible INSERT to append") {
- val srdd = TestSQLContext.applySchema(sc.parallelize(arr2x2), schema2)
- val srdd2 = TestSQLContext.applySchema(sc.parallelize(arr2x3), schema3)
+ val srdd = TestSQLContext.createDataFrame(sc.parallelize(arr2x2), schema2)
+ val srdd2 = TestSQLContext.createDataFrame(sc.parallelize(arr2x3), schema3)
srdd.createJDBCTable(url, "TEST.INCOMPATIBLETEST", false)
intercept[org.apache.spark.SparkException] {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala
index 4fc92e3e3b..fde4b47438 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala
@@ -820,7 +820,7 @@ class JsonSuite extends QueryTest {
Row(values(0).toInt, values(1), values(2).toBoolean, r.split(",").toList, v5)
}
- val df1 = applySchema(rowRDD1, schema1)
+ val df1 = createDataFrame(rowRDD1, schema1)
df1.registerTempTable("applySchema1")
val df2 = df1.toDataFrame
val result = df2.toJSON.collect()
@@ -841,7 +841,7 @@ class JsonSuite extends QueryTest {
Row(Row(values(0).toInt, values(2).toBoolean), Map(values(1) -> v4))
}
- val df3 = applySchema(rowRDD2, schema2)
+ val df3 = createDataFrame(rowRDD2, schema2)
df3.registerTempTable("applySchema2")
val df4 = df3.toDataFrame
val result2 = df4.toJSON.collect()