diff options
author | Takuya UESHIN <ueshin@happy-camper.st> | 2014-08-26 15:04:08 -0700 |
---|---|---|
committer | Michael Armbrust <michael@databricks.com> | 2014-08-26 15:04:08 -0700 |
commit | 6b5584ef1c605cd30f25dbe7099ab32aea1746fb (patch) | |
tree | 2e5abc226595b23bdb1b43de2800942afe252ea5 /sql/core | |
parent | 98c2bb0bbde6fb2b6f64af3efffefcb0dae94c12 (diff) | |
download | spark-6b5584ef1c605cd30f25dbe7099ab32aea1746fb.tar.gz spark-6b5584ef1c605cd30f25dbe7099ab32aea1746fb.tar.bz2 spark-6b5584ef1c605cd30f25dbe7099ab32aea1746fb.zip |
[SPARK-3063][SQL] ExistingRdd should convert Map to catalyst Map.
Currently `ExistingRdd.convertToCatalyst` doesn't convert `Map` value.
Author: Takuya UESHIN <ueshin@happy-camper.st>
Closes #1963 from ueshin/issues/SPARK-3063 and squashes the following commits:
3ba41f2 [Takuya UESHIN] Merge branch 'master' into issues/SPARK-3063
4d7bae2 [Takuya UESHIN] Merge branch 'master' into issues/SPARK-3063
9321379 [Takuya UESHIN] Merge branch 'master' into issues/SPARK-3063
d8a900a [Takuya UESHIN] Make ExistingRdd.convertToCatalyst be able to convert Map value.
Diffstat (limited to 'sql/core')
-rw-r--r-- | sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala | 3 | ||||
-rw-r--r-- | sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala | 46 |
2 files changed, 48 insertions, 1 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index f9dfa3c92f..374af48b82 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -206,7 +206,8 @@ case class Sort( object ExistingRdd { def convertToCatalyst(a: Any): Any = a match { case o: Option[_] => o.orNull - case s: Seq[Any] => s.map(convertToCatalyst) + case s: Seq[_] => s.map(convertToCatalyst) + case m: Map[_, _] => m.map { case (k, v) => convertToCatalyst(k) -> convertToCatalyst(v) } case p: Product => new GenericRow(p.productIterator.map(convertToCatalyst).toArray) case other => other } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala index 5b84c658db..e24c521d24 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala @@ -21,6 +21,7 @@ import java.sql.Timestamp import org.scalatest.FunSuite +import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.test.TestSQLContext._ case class ReflectData( @@ -56,6 +57,22 @@ case class OptionalReflectData( case class ReflectBinary(data: Array[Byte]) +case class Nested(i: Option[Int], s: String) + +case class Data( + array: Seq[Int], + arrayContainsNull: Seq[Option[Int]], + map: Map[Int, Long], + mapContainsNul: Map[Int, Option[Long]], + nested: Nested) + +case class ComplexReflectData( + arrayField: Seq[Int], + arrayFieldContainsNull: Seq[Option[Int]], + mapField: Map[Int, Long], + mapFieldContainsNull: Map[Int, Option[Long]], + dataField: Data) + class ScalaReflectionRelationSuite extends FunSuite { test("query case class RDD") { val data = ReflectData("a", 1, 1L, 1.toFloat, 1.toDouble, 1.toShort, 1.toByte, true, @@ -90,4 +107,33 @@ class ScalaReflectionRelationSuite extends FunSuite { val result = sql("SELECT data FROM reflectBinary").collect().head(0).asInstanceOf[Array[Byte]] assert(result.toSeq === Seq[Byte](1)) } + + test("query complex data") { + val data = ComplexReflectData( + Seq(1, 2, 3), + Seq(Some(1), Some(2), None), + Map(1 -> 10L, 2 -> 20L), + Map(1 -> Some(10L), 2 -> Some(20L), 3 -> None), + Data( + Seq(10, 20, 30), + Seq(Some(10), Some(20), None), + Map(10 -> 100L, 20 -> 200L), + Map(10 -> Some(100L), 20 -> Some(200L), 30 -> None), + Nested(None, "abc"))) + val rdd = sparkContext.parallelize(data :: Nil) + rdd.registerTempTable("reflectComplexData") + + assert(sql("SELECT * FROM reflectComplexData").collect().head === + new GenericRow(Array[Any]( + Seq(1, 2, 3), + Seq(1, 2, null), + Map(1 -> 10L, 2 -> 20L), + Map(1 -> 10L, 2 -> 20L, 3 -> null), + new GenericRow(Array[Any]( + Seq(10, 20, 30), + Seq(10, 20, null), + Map(10 -> 100L, 20 -> 200L), + Map(10 -> 100L, 20 -> 200L, 30 -> null), + new GenericRow(Array[Any](null, "abc"))))))) + } } |