author     Xiangrui Meng <meng@databricks.com>        2015-04-02 17:57:01 +0800
committer  Michael Armbrust <michael@databricks.com>  2015-04-02 16:56:21 -0700
commit     c2694bba61268d61cb6cdf1a72e493e49c824564 (patch)
tree       42f07b7cd43a401575769a60f144a6635659d70d
parent     e6ee95cbda7a52a40426b9461f34bb611bfcf077 (diff)
[SPARK-6672][SQL] convert row to catalyst in createDataFrame(RDD[Row], ...)
We assume that `RDD[Row]` contains Scala types. So we need to convert them into catalyst types in createDataFrame.

liancheng

Author: Xiangrui Meng <meng@databricks.com>

Closes #5329 from mengxr/SPARK-6672 and squashes the following commits:

2d52644 [Xiangrui Meng] set needsConversion = false in jsonRDD
06896e4 [Xiangrui Meng] add createDataFrame without conversion
4a3767b [Xiangrui Meng] convert Row to catalyst

(cherry picked from commit 424e987dfebbbaa37f4496d44090d469a931ce76)
Signed-off-by: Michael Armbrust <michael@databricks.com>
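For context, a minimal Scala sketch of the scenario this patch fixes, modeled on the regression test added below (ExamplePoint and ExamplePointUDT come from Spark's sql test package): an RDD[Row] carrying a user-defined type is passed to createDataFrame, and each Row must be converted to its Catalyst representation before the query can be executed.

    import org.apache.spark.sql.Row
    import org.apache.spark.sql.test.TestSQLContext
    import org.apache.spark.sql.test.{ExamplePoint, ExamplePointUDT}
    import org.apache.spark.sql.types.{StructField, StructType}

    // An RDD[Row] whose single column holds a UDT-backed Scala object.
    val rowRDD = TestSQLContext.sparkContext.parallelize(
      Seq(Row(new ExamplePoint(1.0, 2.0))))
    val schema = StructType(Array(StructField("point", new ExamplePointUDT(), false)))

    // With this patch, createDataFrame converts each Row to its Catalyst form
    // (needsConversion = true by default), so the collect below succeeds.
    val df = TestSQLContext.createDataFrame(rowRDD, schema)
    df.rdd.collect()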
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala   5
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala                      3
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala                     20
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala             3
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala               3
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala           2
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala                 9
7 files changed, 37 insertions, 8 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index 2220970085..8bfd0471d9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -72,6 +72,11 @@ trait ScalaReflection {
case (d: BigDecimal, _) => Decimal(d)
case (d: java.math.BigDecimal, _) => Decimal(d)
case (d: java.sql.Date, _) => DateUtils.fromJavaDate(d)
+ case (r: Row, structType: StructType) =>
+ new GenericRow(
+ r.toSeq.zip(structType.fields).map { case (elem, field) =>
+ convertToCatalyst(elem, field.dataType)
+ }.toArray)
case (other, _) => other
}
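To illustrate the case added above: a standalone sketch (with simplified stand-ins for Row and StructType, not Spark's actual Catalyst classes) of how the conversion recurses field by field, so nested rows are converted along with their fields.

    sealed trait DataType
    case object DoubleType extends DataType
    case class StructField(name: String, dataType: DataType)
    case class StructType(fields: Seq[StructField]) extends DataType
    case class Row(values: Seq[Any])

    // Mirrors the new case: zip each value with its field's data type and
    // convert recursively, so nested structs are handled as well.
    def convertToCatalyst(a: Any, dataType: DataType): Any = (a, dataType) match {
      case (r: Row, structType: StructType) =>
        Row(r.values.zip(structType.fields).map { case (elem, field) =>
          convertToCatalyst(elem, field.dataType)
        })
      case (other, _) => other
    }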
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
index de0a192c31..2b6b0e7997 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
@@ -904,7 +904,8 @@ class DataFrame private[sql](
*/
override def repartition(numPartitions: Int): DataFrame = {
sqlContext.createDataFrame(
- queryExecution.toRdd.map(_.copy()).repartition(numPartitions), schema)
+ queryExecution.toRdd.map(_.copy()).repartition(numPartitions),
+ schema, needsConversion = false)
}
/**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index 5dc3e34664..b7f72e716c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -379,9 +379,23 @@ class SQLContext(@transient val sparkContext: SparkContext)
*/
@DeveloperApi
def createDataFrame(rowRDD: RDD[Row], schema: StructType): DataFrame = {
+ createDataFrame(rowRDD, schema, needsConversion = true)
+ }
+
+ /**
+ * Creates a DataFrame from an RDD[Row]. User can specify whether the input rows should be
+ * converted to Catalyst rows.
+ */
+ private[sql]
+ def createDataFrame(rowRDD: RDD[Row], schema: StructType, needsConversion: Boolean) = {
// TODO: use MutableProjection when rowRDD is another DataFrame and the applied
// schema differs from the existing schema on any field data type.
- val logicalPlan = LogicalRDD(schema.toAttributes, rowRDD)(self)
+ val catalystRows = if (needsConversion) {
+ rowRDD.map(ScalaReflection.convertToCatalyst(_, schema).asInstanceOf[Row])
+ } else {
+ rowRDD
+ }
+ val logicalPlan = LogicalRDD(schema.toAttributes, catalystRows)(self)
DataFrame(this, logicalPlan)
}
@@ -591,7 +605,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
JsonRDD.nullTypeToStringType(
JsonRDD.inferSchema(json, 1.0, columnNameOfCorruptJsonRecord)))
val rowRDD = JsonRDD.jsonStringToRow(json, appliedSchema, columnNameOfCorruptJsonRecord)
- createDataFrame(rowRDD, appliedSchema)
+ createDataFrame(rowRDD, appliedSchema, needsConversion = false)
}
/**
@@ -620,7 +634,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
JsonRDD.nullTypeToStringType(
JsonRDD.inferSchema(json, samplingRatio, columnNameOfCorruptJsonRecord))
val rowRDD = JsonRDD.jsonStringToRow(json, appliedSchema, columnNameOfCorruptJsonRecord)
- createDataFrame(rowRDD, appliedSchema)
+ createDataFrame(rowRDD, appliedSchema, needsConversion = false)
}
/**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala
index b297f1935e..ccc28bea79 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala
@@ -122,7 +122,8 @@ private[sql] class DefaultSource
val df =
sqlContext.createDataFrame(
data.queryExecution.toRdd,
- data.schema.asNullable)
+ data.schema.asNullable,
+ needsConversion = false)
val createdRelation =
createRelation(sqlContext, parameters, df.schema).asInstanceOf[ParquetRelation2]
createdRelation.insert(df, overwrite = mode == SaveMode.Overwrite)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala
index 9bbe06e59b..dbdb0d39c2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala
@@ -31,7 +31,8 @@ private[sql] case class InsertIntoDataSource(
val relation = logicalRelation.relation.asInstanceOf[InsertableRelation]
val data = DataFrame(sqlContext, query)
// Apply the schema of the existing table to the new data.
- val df = sqlContext.createDataFrame(data.queryExecution.toRdd, logicalRelation.schema)
+ val df = sqlContext.createDataFrame(
+ data.queryExecution.toRdd, logicalRelation.schema, needsConversion = false)
relation.insert(df, overwrite)
// Invalidate the cache.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala b/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala
index c11d0ae5bf..2fdd798b44 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala
@@ -28,7 +28,7 @@ import org.apache.spark.sql.types._
* @param y y coordinate
*/
@SQLUserDefinedType(udt = classOf[ExamplePointUDT])
-private[sql] class ExamplePoint(val x: Double, val y: Double)
+private[sql] class ExamplePoint(val x: Double, val y: Double) extends Serializable
/**
* User-defined type for [[ExamplePoint]].
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 6761d996fd..5297cc01ed 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -21,7 +21,7 @@ import scala.language.postfixOps
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
-import org.apache.spark.sql.test.TestSQLContext
+import org.apache.spark.sql.test.{ExamplePointUDT, ExamplePoint, TestSQLContext}
import org.apache.spark.sql.test.TestSQLContext.logicalPlanToSparkQuery
import org.apache.spark.sql.test.TestSQLContext.implicits._
import org.apache.spark.sql.test.TestSQLContext.sql
@@ -506,4 +506,11 @@ class DataFrameSuite extends QueryTest {
testData.select($"*").show()
testData.select($"*").show(1000)
}
+
+ test("createDataFrame(RDD[Row], StructType) should convert UDTs (SPARK-6672)") {
+ val rowRDD = TestSQLContext.sparkContext.parallelize(Seq(Row(new ExamplePoint(1.0, 2.0))))
+ val schema = StructType(Array(StructField("point", new ExamplePointUDT(), false)))
+ val df = TestSQLContext.createDataFrame(rowRDD, schema)
+ df.rdd.collect()
+ }
}