path: root/sql/catalyst/src
author     bomeng <bmeng@us.ibm.com>  2016-05-02 18:20:29 -0700
committer  Michael Armbrust <michael@databricks.com>  2016-05-02 18:20:29 -0700
commit     0fd95be3cd815154a11ce7d6998311e7c86bc6b9 (patch)
tree       37a9c1a0d3b7e4cd92e6b2df77e2e63beeec876a /sql/catalyst/src
parent     1c19c2769edecaefabc2cd67b3b32f901feb3f59 (diff)
download   spark-0fd95be3cd815154a11ce7d6998311e7c86bc6b9.tar.gz
           spark-0fd95be3cd815154a11ce7d6998311e7c86bc6b9.tar.bz2
           spark-0fd95be3cd815154a11ce7d6998311e7c86bc6b9.zip
[SPARK-15062][SQL] fix list type infer serializer issue
## What changes were proposed in this pull request?

Make the serializer correctly inferred when the input type is `List[_]`. Since `List[_]` is a subtype of `Seq[_]`, it was previously matched by a different case (`case t if definedByConstructorParams(t)`) instead of the `Seq[_]` case.

## How was this patch tested?

A new test case was added.

Author: bomeng <bmeng@us.ibm.com>

Closes #12849 from bomeng/SPARK-15062.
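The root cause is ordinary Scala pattern-match ordering: `List` extends both `Product` and `Seq`, so whichever case is tested first wins. The following is a standalone sketch of that ordering hazard, not Spark code; the object and method names are made up for illustration:

```scala
// Sketch of the ordering hazard fixed in this patch (not Spark code).
// List extends both Product and Seq, so a match that tests the
// Product-style case first captures lists before the Seq case is reached.
object CaseOrderSketch {
  def buggy(v: Any): String = v match {
    case _: Product => "struct serializer (constructor params)" // List lands here
    case _: Seq[_]  => "array serializer"
  }

  def fixed(v: Any): String = v match {
    case _: Seq[_]  => "array serializer"                       // Seq checked first
    case _: Product => "struct serializer (constructor params)"
  }

  def main(args: Array[String]): Unit = {
    println(buggy(List(1, 2, 3))) // struct serializer (constructor params) -- wrong
    println(fixed(List(1, 2, 3))) // array serializer -- intended
  }
}
```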
Diffstat (limited to 'sql/catalyst/src')
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala     | 11
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala | 15
2 files changed, 21 insertions(+), 5 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index be0d75a830..d158a64a85 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -509,6 +509,13 @@ object ScalaReflection extends ScalaReflection {
serializerFor(unwrapped, optType, newPath))
}
+ // Since List[_] also belongs to localTypeOf[Product], we put this case before
+ // "case t if definedByConstructorParams(t)" so that it is matched by the
+ // "localTypeOf[Seq[_]]" case
+ case t if t <:< localTypeOf[Seq[_]] =>
+ val TypeRef(_, _, Seq(elementType)) = t
+ toCatalystArray(inputObject, elementType)
+
case t if definedByConstructorParams(t) =>
val params = getConstructorParameters(t)
val nonNullOutput = CreateNamedStruct(params.flatMap { case (fieldName, fieldType) =>
@@ -524,10 +531,6 @@ object ScalaReflection extends ScalaReflection {
val TypeRef(_, _, Seq(elementType)) = t
toCatalystArray(inputObject, elementType)
- case t if t <:< localTypeOf[Seq[_]] =>
- val TypeRef(_, _, Seq(elementType)) = t
- toCatalystArray(inputObject, elementType)
-
case t if t <:< localTypeOf[Map[_, _]] =>
val TypeRef(_, _, Seq(keyType, valueType)) = t
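For reference, both subtype tests succeed for `List`, which is why only case ordering in `serializerFor` decides the outcome. A minimal, runnable check using plain `scala-reflect` (here `typeOf` stands in for Spark's `localTypeOf` helper; the object name is made up):

```scala
import scala.reflect.runtime.universe._

object SubtypeChecks extends App {
  // Both checks are true for List, so only case ordering picks the serializer.
  println(typeOf[List[Int]] <:< typeOf[Seq[_]])  // true -> array serializer (the fix)
  println(typeOf[List[Int]] <:< typeOf[Product]) // true -> old struct-style case
}
```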
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
index 0672551b29..c00e9c7e39 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
@@ -23,8 +23,9 @@ import java.sql.{Date, Timestamp}
import scala.reflect.runtime.universe.typeOf
import org.apache.spark.SparkFunSuite
-import org.apache.spark.sql.catalyst.expressions.{BoundReference, SpecificMutableRow}
+import org.apache.spark.sql.catalyst.expressions.{BoundReference, Literal, NewInstance, SpecificMutableRow}
import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.util.Utils
case class PrimitiveData(
@@ -277,6 +278,18 @@ class ScalaReflectionSuite extends SparkFunSuite {
assert(anyTypes === Seq(classOf[java.lang.Object], classOf[java.lang.Object]))
}
+ test("SPARK-15062: Get correct serializer for List[_]") {
+ val list = List(1, 2, 3)
+ val serializer = serializerFor[List[Int]](BoundReference(
+ 0, ObjectType(list.getClass), nullable = false))
+ assert(serializer.children.size == 2)
+ assert(serializer.children.head.isInstanceOf[Literal])
+ assert(serializer.children.head.asInstanceOf[Literal].value === UTF8String.fromString("value"))
+ assert(serializer.children.last.isInstanceOf[NewInstance])
+ assert(serializer.children.last.asInstanceOf[NewInstance]
+ .cls.isInstanceOf[Class[org.apache.spark.sql.catalyst.util.GenericArrayData]])
+ }
+
private val dataTypeForComplexData = dataTypeFor[ComplexData]
private val typeOfComplexData = typeOf[ComplexData]
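As an end-to-end illustration of the behavior this patch enables, a sketch assuming a running `SparkSession` named `spark` (not part of the patch):

```scala
// Sketch only: with the Seq case matched first, the encoder for List[Int]
// serializes the list as a Catalyst array column rather than attempting to
// treat the cons cells as a Product/struct.
import spark.implicits._

val ds = Seq(List(1, 2, 3), List(4, 5)).toDS()
ds.printSchema() // roughly: value: array (element: integer)
```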