 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala |  6 +++---
 sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala    | 44 ++++++++++++++++++++
 2 files changed, 47 insertions(+), 3 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index 792ef6cee6..196695a0a1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -41,6 +41,9 @@ object ScalaReflection {
 
   /** Returns a catalyst DataType for the given Scala Type using reflection. */
   def schemaFor(tpe: `Type`): DataType = tpe match {
+    case t if t <:< typeOf[Option[_]] =>
+      val TypeRef(_, _, Seq(optType)) = t
+      schemaFor(optType)
     case t if t <:< typeOf[Product] =>
       val params = t.member("<init>": TermName).asMethod.paramss
       StructType(
@@ -59,9 +62,6 @@ object ScalaReflection {
     case t if t <:< typeOf[String] => StringType
     case t if t <:< typeOf[Timestamp] => TimestampType
     case t if t <:< typeOf[BigDecimal] => DecimalType
-    case t if t <:< typeOf[Option[_]] =>
-      val TypeRef(_, _, Seq(optType)) = t
-      schemaFor(optType)
     case t if t <:< typeOf[java.lang.Integer] => IntegerType
     case t if t <:< typeOf[java.lang.Long] => LongType
     case t if t <:< typeOf[java.lang.Double] => DoubleType
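
Note on the reordering above: Scala's Option extends Product, so the `case t if t <:< typeOf[Product]` guard also matches Option[_]. With the Option case listed after the Product case it was unreachable, and Option fields were treated as one-field structs instead of resolving to their element type. A minimal standalone sketch of the ordering issue (not part of the patch; requires only scala-reflect on the classpath):

    import scala.reflect.runtime.universe._

    object MatchOrderSketch extends App {
      def describe(tpe: Type): String = tpe match {
        case t if t <:< typeOf[Option[_]] =>
          // Unwrap the element type, mirroring the schemaFor change above.
          val TypeRef(_, _, Seq(elemType)) = t
          s"optional ${describe(elemType)}"
        case t if t <:< typeOf[Product] => "struct" // case classes, tuples -- and Option, too
        case t if t <:< typeOf[Int]     => "int"
        case _                          => "other"
      }

      println(typeOf[Option[Int]] <:< typeOf[Product]) // true -- hence the reordering
      println(describe(typeOf[Option[Int]]))           // "optional int"
    }

Swapping the first two cases in `describe` would print "struct" for Option[Int], which is exactly the bug the moved case fixes in schemaFor.
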
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala
index d9c9b9a076..ff1677eb8a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala
@@ -42,6 +42,20 @@ import org.apache.spark.sql.test.TestSQLContext._
 
 case class TestRDDEntry(key: Int, value: String)
 
+case class NullReflectData(
+    intField: java.lang.Integer,
+    longField: java.lang.Long,
+    floatField: java.lang.Float,
+    doubleField: java.lang.Double,
+    booleanField: java.lang.Boolean)
+
+case class OptionalReflectData(
+    intField: Option[Int],
+    longField: Option[Long],
+    floatField: Option[Float],
+    doubleField: Option[Double],
+    booleanField: Option[Boolean])
+
 class ParquetQuerySuite extends QueryTest with FunSuite with BeforeAndAfterAll {
   import TestData._
   TestData // Load test data tables.
@@ -195,5 +209,35 @@ class ParquetQuerySuite extends QueryTest with FunSuite with BeforeAndAfterAll {
 
     Utils.deleteRecursively(ParquetTestData.testDir)
     ParquetTestData.writeFile()
   }
+
+  test("save and load case class RDD with nulls as parquet") {
+    val data = NullReflectData(null, null, null, null, null)
+    val rdd = sparkContext.parallelize(data :: Nil)
+
+    val file = getTempFilePath("parquet")
+    val path = file.toString
+    rdd.saveAsParquetFile(path)
+    val readFile = parquetFile(path)
+
+    val rdd_saved = readFile.collect()
+    assert(rdd_saved(0) === Seq.fill(5)(null))
+    Utils.deleteRecursively(file)
+    assert(true)
+  }
+
+  test("save and load case class RDD with Nones as parquet") {
+    val data = OptionalReflectData(null, null, null, null, null)
+    val rdd = sparkContext.parallelize(data :: Nil)
+
+    val file = getTempFilePath("parquet")
+    val path = file.toString
+    rdd.saveAsParquetFile(path)
+    val readFile = parquetFile(path)
+
+    val rdd_saved = readFile.collect()
+    assert(rdd_saved(0) === Seq.fill(5)(null))
+    Utils.deleteRecursively(file)
+    assert(true)
+  }
 }
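
Taken together, the two fixtures exercise the same schema path: after the reordering, Option-wrapped primitives unwrap to the same Catalyst column types that the boxed java.lang fields already mapped to, which is why both tests expect a row of five nulls back from Parquet. A hedged sketch of that expectation (assumes spark-catalyst on the classpath, the two case classes from the suite in scope, and that schemaFor accepts a runtime-universe Type as the visible signature suggests; the comment describes the intended result, not captured output):

    import scala.reflect.runtime.universe.typeOf
    import org.apache.spark.sql.catalyst.ScalaReflection

    object SchemaForSketch extends App {
      // With the reordered match, Option[Int] now unwraps to IntegerType just
      // as java.lang.Integer already did, so both case classes should resolve
      // to the same primitive column types.
      println(ScalaReflection.schemaFor(typeOf[OptionalReflectData]))
      println(ScalaReflection.schemaFor(typeOf[NullReflectData]))
    }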