author     Takuya UESHIN <ueshin@happy-camper.st>    2014-08-26 15:04:08 -0700
committer  Michael Armbrust <michael@databricks.com>    2014-08-26 15:04:08 -0700
commit     6b5584ef1c605cd30f25dbe7099ab32aea1746fb (patch)
tree       2e5abc226595b23bdb1b43de2800942afe252ea5 /sql
parent     98c2bb0bbde6fb2b6f64af3efffefcb0dae94c12 (diff)
[SPARK-3063][SQL] ExistingRdd should convert Map to catalyst Map.
Currently `ExistingRdd.convertToCatalyst` doesn't convert `Map` values.

Author: Takuya UESHIN <ueshin@happy-camper.st>

Closes #1963 from ueshin/issues/SPARK-3063 and squashes the following commits:

3ba41f2 [Takuya UESHIN] Merge branch 'master' into issues/SPARK-3063
4d7bae2 [Takuya UESHIN] Merge branch 'master' into issues/SPARK-3063
9321379 [Takuya UESHIN] Merge branch 'master' into issues/SPARK-3063
d8a900a [Takuya UESHIN] Make ExistingRdd.convertToCatalyst be able to convert Map value.
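The gap is easiest to see in isolation. The following is a minimal standalone sketch (illustrative only, not the Spark source; the Product/GenericRow case is omitted) of the pre-patch match: with no Map case, a Map falls through to the catch-all `case other => other`, so Option values inside it keep their Scala Some/None wrappers instead of being unwrapped to a value or null the way Seq elements are.

object OldConvertSketch {
  // Simplified pre-patch conversion: Options and Seqs are handled, Maps are not.
  def convertToCatalyst(a: Any): Any = a match {
    case o: Option[_] => o.orNull
    case s: Seq[Any]  => s.map(convertToCatalyst)
    case other        => other // a Map[_, _] ends up here, left untouched
  }

  def main(args: Array[String]): Unit = {
    println(convertToCatalyst(Seq(Some(1), None)))             // List(1, null) -- converted
    println(convertToCatalyst(Map(1 -> Some(10L), 2 -> None))) // Map(1 -> Some(10), 2 -> None) -- not converted
  }
}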
Diffstat (limited to 'sql')
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala      3
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala  46
2 files changed, 48 insertions(+), 1 deletion(-)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
index f9dfa3c92f..374af48b82 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
@@ -206,7 +206,8 @@ case class Sort(
 object ExistingRdd {
   def convertToCatalyst(a: Any): Any = a match {
     case o: Option[_] => o.orNull
-    case s: Seq[Any] => s.map(convertToCatalyst)
+    case s: Seq[_] => s.map(convertToCatalyst)
+    case m: Map[_, _] => m.map { case (k, v) => convertToCatalyst(k) -> convertToCatalyst(v) }
     case p: Product => new GenericRow(p.productIterator.map(convertToCatalyst).toArray)
     case other => other
   }
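For comparison, a standalone sketch of the patched match (again illustrative, not the Spark source; Spark wraps Product values in a GenericRow, the sketch uses a plain Seq so it runs anywhere). The new Map case converts keys and values recursively, so Options inside a Map now become nulls just as they already did inside a Seq:

object NewConvertSketch {
  def convertToCatalyst(a: Any): Any = a match {
    case o: Option[_] => o.orNull
    case s: Seq[_]    => s.map(convertToCatalyst)
    case m: Map[_, _] => m.map { case (k, v) => convertToCatalyst(k) -> convertToCatalyst(v) }
    case p: Product   => p.productIterator.map(convertToCatalyst).toSeq // Spark builds a GenericRow here
    case other        => other
  }

  def main(args: Array[String]): Unit = {
    // Keys and values are converted recursively; None becomes null, nested Seqs recurse too.
    println(convertToCatalyst(Map(1 -> Some(10L), 2 -> None)))  // Map(1 -> 10, 2 -> null)
    println(convertToCatalyst(Map("xs" -> Seq(Some(1), None)))) // Map(xs -> List(1, null))
  }
}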
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala
index 5b84c658db..e24c521d24 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala
@@ -21,6 +21,7 @@ import java.sql.Timestamp
 
 import org.scalatest.FunSuite
 
+import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.test.TestSQLContext._
 
 case class ReflectData(
@@ -56,6 +57,22 @@ case class OptionalReflectData(
 
 case class ReflectBinary(data: Array[Byte])
 
+case class Nested(i: Option[Int], s: String)
+
+case class Data(
+    array: Seq[Int],
+    arrayContainsNull: Seq[Option[Int]],
+    map: Map[Int, Long],
+    mapContainsNul: Map[Int, Option[Long]],
+    nested: Nested)
+
+case class ComplexReflectData(
+    arrayField: Seq[Int],
+    arrayFieldContainsNull: Seq[Option[Int]],
+    mapField: Map[Int, Long],
+    mapFieldContainsNull: Map[Int, Option[Long]],
+    dataField: Data)
+
 class ScalaReflectionRelationSuite extends FunSuite {
   test("query case class RDD") {
     val data = ReflectData("a", 1, 1L, 1.toFloat, 1.toDouble, 1.toShort, 1.toByte, true,
@@ -90,4 +107,33 @@ class ScalaReflectionRelationSuite extends FunSuite {
     val result = sql("SELECT data FROM reflectBinary").collect().head(0).asInstanceOf[Array[Byte]]
     assert(result.toSeq === Seq[Byte](1))
   }
+
+  test("query complex data") {
+    val data = ComplexReflectData(
+      Seq(1, 2, 3),
+      Seq(Some(1), Some(2), None),
+      Map(1 -> 10L, 2 -> 20L),
+      Map(1 -> Some(10L), 2 -> Some(20L), 3 -> None),
+      Data(
+        Seq(10, 20, 30),
+        Seq(Some(10), Some(20), None),
+        Map(10 -> 100L, 20 -> 200L),
+        Map(10 -> Some(100L), 20 -> Some(200L), 30 -> None),
+        Nested(None, "abc")))
+    val rdd = sparkContext.parallelize(data :: Nil)
+    rdd.registerTempTable("reflectComplexData")
+
+    assert(sql("SELECT * FROM reflectComplexData").collect().head ===
+      new GenericRow(Array[Any](
+        Seq(1, 2, 3),
+        Seq(1, 2, null),
+        Map(1 -> 10L, 2 -> 20L),
+        Map(1 -> 10L, 2 -> 20L, 3 -> null),
+        new GenericRow(Array[Any](
+          Seq(10, 20, 30),
+          Seq(10, 20, null),
+          Map(10 -> 100L, 20 -> 200L),
+          Map(10 -> 100L, 20 -> 200L, 30 -> null),
+          new GenericRow(Array[Any](null, "abc")))))))
+  }
 }
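End to end, the path the new test exercises looks roughly like the following in user code. This is a hedged sketch against the Spark 1.x SchemaRDD API; the SparkContext setup, the Entry case class, and the table name are assumptions for illustration, not part of this patch.

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.SQLContext

// Hypothetical case class with a Map field whose values may be absent.
case class Entry(id: Int, scores: Map[String, Option[Double]])

object MapFieldExample {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setMaster("local").setAppName("map-field-example"))
    val sqlContext = new SQLContext(sc)
    import sqlContext.createSchemaRDD // implicit RDD[Product] => SchemaRDD

    val rdd = sc.parallelize(Seq(Entry(1, Map("a" -> Some(1.0), "b" -> None))))
    rdd.registerTempTable("entries")

    // With this patch, the None inside the Map surfaces as null in the collected row.
    sqlContext.sql("SELECT scores FROM entries").collect().foreach(println)

    sc.stop()
  }
}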