aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala2
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala12
2 files changed, 13 insertions, 1 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index 6e20096901..ad218cf88d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -342,7 +342,7 @@ object ScalaReflection extends ScalaReflection {
StaticInvoke(
ArrayBasedMapData.getClass,
- ObjectType(classOf[Map[_, _]]),
+ ObjectType(classOf[scala.collection.immutable.Map[_, _]]),
"toScalaMap",
keyData :: valueData :: Nil)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index 37d5667ed8..3742115134 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -1120,8 +1120,20 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
// sizeInBytes is 2404280404, before the fix, it overflows to a negative number
assert(sizeInBytes > 0)
}
+
+ test("SPARK-18717: code generation works for both scala.collection.Map" +
+ " and scala.collection.imutable.Map") {
+ val ds = Seq(WithImmutableMap("hi", Map(42L -> "foo"))).toDS
+ checkDataset(ds.map(t => t), WithImmutableMap("hi", Map(42L -> "foo")))
+
+ val ds2 = Seq(WithMap("hi", Map(42L -> "foo"))).toDS
+ checkDataset(ds2.map(t => t), WithMap("hi", Map(42L -> "foo")))
+ }
}
+case class WithImmutableMap(id: String, map_test: scala.collection.immutable.Map[Long, String])
+case class WithMap(id: String, map_test: scala.collection.Map[Long, String])
+
case class Generic[T](id: T, value: Double)
case class OtherTuple(_1: String, _2: Int)