author    Michal Senkyr <mike.senkyr@gmail.com>  2017-01-06 15:05:20 +0800
committer Wenchen Fan <wenchen@databricks.com>   2017-01-06 15:05:20 +0800
commit    903bb8e8a2b84b9ea82acbb8ae9d58754862be3a
tree      1df577fa49e4fd3400920234cc79865f40fbebdc
parent    bcc510b021391035abe6d07c5b82bb0f0be31167
[SPARK-16792][SQL] Dataset containing a Case Class with a List type causes a CompileException (converting sequence to list)
## What changes were proposed in this pull request?

Added a `to` call at the end of the code generated by `ScalaReflection.deserializerFor` whenever the requested type is not a supertype of `WrappedArray[_]`; the call uses `CanBuildFrom[_, _, _]` to convert the result into an arbitrary subtype of `Seq[_]`. Care was taken to preserve the original deserialization where possible, to avoid the overhead of conversion in cases where it is not needed. `ScalaReflection.serializerFor` could already serialize any `Seq[_]`, so it was not altered. `SQLImplicits` had to be altered and new implicit encoders added to permit serialization of other sequence types.

Also fixes [SPARK-16815] Dataset[List[T]] leads to ArrayStoreException.

## How was this patch tested?

```bash
./build/mvn -DskipTests clean package && ./dev/run-tests
```

Also by manual execution of the following sets of commands in the Spark shell:

```scala
case class TestCC(key: Int, letters: List[String])

val ds1 = sc.makeRDD(Seq(
  (List("D")),
  (List("S","H")),
  (List("F","H")),
  (List("D","L","L"))
)).map(x => (x.length, x)).toDF("key", "letters").as[TestCC]

val test1 = ds1.map(_.key)
test1.show
```

```scala
case class X(l: List[String])
spark.createDataset(Seq(List("A"))).map(X).show
```

```scala
spark.sqlContext.createDataset(sc.parallelize(List(1) :: Nil)).collect
```

After adding arbitrary sequence support, the following commands were also tested:

```scala
case class QueueClass(q: scala.collection.immutable.Queue[Int])

spark.createDataset(Seq(List(1,2,3))).map(x => QueueClass(scala.collection.immutable.Queue(x: _*))).map(_.q.dequeue).collect
```

Author: Michal Senkyr <mike.senkyr@gmail.com>

Closes #16240 from michalsenkyr/sql-caseclass-list-fix.
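As a side note for readers, here is a minimal sketch (not part of this patch, targeting the Scala 2.11/2.12 collections API) of the conversion the generated deserializer performs at runtime: the row data is first materialized as a `WrappedArray`, which is then converted to the requested `Seq` subtype through that type's `CanBuildFrom`. The generated code passes the companion's `canBuildFrom` explicitly; the implicit variant of `to` is used below for brevity.

```scala
// Minimal sketch of the runtime conversion on Scala 2.11/2.12 collections.
import scala.collection.immutable.Queue
import scala.collection.mutable.WrappedArray

val wrapped: Seq[Int] = WrappedArray.make[Int](Array(1, 2, 3))

// Equivalent of the generated Invoke(wrappedArray, "to", ..., canBuildFrom) call:
val asList: List[Int]   = wrapped.to[List]   // List(1, 2, 3)
val asQueue: Queue[Int] = wrapped.to[Queue]  // Queue(1, 2, 3)
```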
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala      | 40
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala | 31
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala                      | 115
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala             | 67
4 files changed, 231 insertions(+), 22 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index ad218cf88d..7f7dd51aa2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -312,12 +312,50 @@ object ScalaReflection extends ScalaReflection {
"array",
ObjectType(classOf[Array[Any]]))
- StaticInvoke(
+ val wrappedArray = StaticInvoke(
scala.collection.mutable.WrappedArray.getClass,
ObjectType(classOf[Seq[_]]),
"make",
array :: Nil)
+ if (localTypeOf[scala.collection.mutable.WrappedArray[_]] <:< t.erasure) {
+ wrappedArray
+ } else {
+ // Convert to another type using `to`
+ val cls = mirror.runtimeClass(t.typeSymbol.asClass)
+ import scala.collection.generic.CanBuildFrom
+ import scala.reflect.ClassTag
+
+ // Some canBuildFrom methods take an implicit ClassTag parameter
+ val cbfParams = try {
+ cls.getDeclaredMethod("canBuildFrom", classOf[ClassTag[_]])
+ StaticInvoke(
+ ClassTag.getClass,
+ ObjectType(classOf[ClassTag[_]]),
+ "apply",
+ StaticInvoke(
+ cls,
+ ObjectType(classOf[Class[_]]),
+ "getClass"
+ ) :: Nil
+ ) :: Nil
+ } catch {
+ case _: NoSuchMethodException => Nil
+ }
+
+ Invoke(
+ wrappedArray,
+ "to",
+ ObjectType(cls),
+ StaticInvoke(
+ cls,
+ ObjectType(classOf[CanBuildFrom[_, _, _]]),
+ "canBuildFrom",
+ cbfParams
+ ) :: Nil
+ )
+ }
+
case t if t <:< localTypeOf[Map[_, _]] =>
// TODO: add walked type path for map
val TypeRef(_, _, Seq(keyType, valueType)) = t
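To illustrate when the new branch above skips the conversion, the hypothetical helper below (not part of the patch) mirrors the `WrappedArray[_] <:< t.erasure` check: conversion is skipped only when a `WrappedArray` already satisfies the requested type, so `Seq` keeps the original deserializer while `List` and `Queue` go through the `to` call.

```scala
// Hypothetical helper mirroring the check in deserializerFor (Scala 2.11/2.12).
import scala.reflect.runtime.universe._
import scala.collection.mutable.WrappedArray

def needsConversion[T: TypeTag]: Boolean =
  !(typeOf[WrappedArray[_]] <:< typeOf[T].erasure)

needsConversion[Seq[Int]]                               // false: a WrappedArray is already a Seq
needsConversion[List[Int]]                              // true:  must convert via CanBuildFrom
needsConversion[scala.collection.immutable.Queue[Int]]  // true
```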
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
index 43b6afd9ad..650a35398f 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
@@ -291,6 +291,37 @@ class ScalaReflectionSuite extends SparkFunSuite {
.cls.isAssignableFrom(classOf[org.apache.spark.sql.catalyst.util.GenericArrayData]))
}
+ test("SPARK 16792: Get correct deserializer for List[_]") {
+ val listDeserializer = deserializerFor[List[Int]]
+ assert(listDeserializer.dataType == ObjectType(classOf[List[_]]))
+ }
+
+ test("serialize and deserialize arbitrary sequence types") {
+ import scala.collection.immutable.Queue
+ val queueSerializer = serializerFor[Queue[Int]](BoundReference(
+ 0, ObjectType(classOf[Queue[Int]]), nullable = false))
+ assert(queueSerializer.dataType.head.dataType ==
+ ArrayType(IntegerType, containsNull = false))
+ val queueDeserializer = deserializerFor[Queue[Int]]
+ assert(queueDeserializer.dataType == ObjectType(classOf[Queue[_]]))
+
+ import scala.collection.mutable.ArrayBuffer
+ val arrayBufferSerializer = serializerFor[ArrayBuffer[Int]](BoundReference(
+ 0, ObjectType(classOf[ArrayBuffer[Int]]), nullable = false))
+ assert(arrayBufferSerializer.dataType.head.dataType ==
+ ArrayType(IntegerType, containsNull = false))
+ val arrayBufferDeserializer = deserializerFor[ArrayBuffer[Int]]
+ assert(arrayBufferDeserializer.dataType == ObjectType(classOf[ArrayBuffer[_]]))
+
+ // Check whether conversion is skipped when using WrappedArray[_] supertype
+ // (would otherwise needlessly add overhead)
+ import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke
+ val seqDeserializer = deserializerFor[Seq[Int]]
+ assert(seqDeserializer.asInstanceOf[StaticInvoke].staticObject ==
+ scala.collection.mutable.WrappedArray.getClass)
+ assert(seqDeserializer.asInstanceOf[StaticInvoke].functionName == "make")
+ }
+
private val dataTypeForComplexData = dataTypeFor[ComplexData]
private val typeOfComplexData = typeOf[ComplexData]
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala
index 872a78b578..2caf723669 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala
@@ -30,7 +30,7 @@ import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
* @since 1.6.0
*/
@InterfaceStability.Evolving
-abstract class SQLImplicits {
+abstract class SQLImplicits extends LowPrioritySQLImplicits {
protected def _sqlContext: SQLContext
@@ -45,9 +45,6 @@ abstract class SQLImplicits {
}
}
- /** @since 1.6.0 */
- implicit def newProductEncoder[T <: Product : TypeTag]: Encoder[T] = Encoders.product[T]
-
// Primitives
/** @since 1.6.0 */
@@ -112,33 +109,96 @@ abstract class SQLImplicits {
// Seqs
- /** @since 1.6.1 */
- implicit def newIntSeqEncoder: Encoder[Seq[Int]] = ExpressionEncoder()
+ /**
+ * @since 1.6.1
+ * @deprecated use [[newIntSequenceEncoder]]
+ */
+ def newIntSeqEncoder: Encoder[Seq[Int]] = ExpressionEncoder()
- /** @since 1.6.1 */
- implicit def newLongSeqEncoder: Encoder[Seq[Long]] = ExpressionEncoder()
+ /**
+ * @since 1.6.1
+ * @deprecated use [[newLongSequenceEncoder]]
+ */
+ def newLongSeqEncoder: Encoder[Seq[Long]] = ExpressionEncoder()
- /** @since 1.6.1 */
- implicit def newDoubleSeqEncoder: Encoder[Seq[Double]] = ExpressionEncoder()
+ /**
+ * @since 1.6.1
+ * @deprecated use [[newDoubleSequenceEncoder]]
+ */
+ def newDoubleSeqEncoder: Encoder[Seq[Double]] = ExpressionEncoder()
- /** @since 1.6.1 */
- implicit def newFloatSeqEncoder: Encoder[Seq[Float]] = ExpressionEncoder()
+ /**
+ * @since 1.6.1
+ * @deprecated use [[newFloatSequenceEncoder]]
+ */
+ def newFloatSeqEncoder: Encoder[Seq[Float]] = ExpressionEncoder()
- /** @since 1.6.1 */
- implicit def newByteSeqEncoder: Encoder[Seq[Byte]] = ExpressionEncoder()
+ /**
+ * @since 1.6.1
+ * @deprecated use [[newByteSequenceEncoder]]
+ */
+ def newByteSeqEncoder: Encoder[Seq[Byte]] = ExpressionEncoder()
- /** @since 1.6.1 */
- implicit def newShortSeqEncoder: Encoder[Seq[Short]] = ExpressionEncoder()
+ /**
+ * @since 1.6.1
+ * @deprecated use [[newShortSequenceEncoder]]
+ */
+ def newShortSeqEncoder: Encoder[Seq[Short]] = ExpressionEncoder()
- /** @since 1.6.1 */
- implicit def newBooleanSeqEncoder: Encoder[Seq[Boolean]] = ExpressionEncoder()
+ /**
+ * @since 1.6.1
+ * @deprecated use [[newBooleanSequenceEncoder]]
+ */
+ def newBooleanSeqEncoder: Encoder[Seq[Boolean]] = ExpressionEncoder()
- /** @since 1.6.1 */
- implicit def newStringSeqEncoder: Encoder[Seq[String]] = ExpressionEncoder()
+ /**
+ * @since 1.6.1
+ * @deprecated use [[newStringSequenceEncoder]]
+ */
+ def newStringSeqEncoder: Encoder[Seq[String]] = ExpressionEncoder()
- /** @since 1.6.1 */
+ /**
+ * @since 1.6.1
+ * @deprecated use [[newProductSequenceEncoder]]
+ */
implicit def newProductSeqEncoder[A <: Product : TypeTag]: Encoder[Seq[A]] = ExpressionEncoder()
+ /** @since 2.2.0 */
+ implicit def newIntSequenceEncoder[T <: Seq[Int] : TypeTag]: Encoder[T] =
+ ExpressionEncoder()
+
+ /** @since 2.2.0 */
+ implicit def newLongSequenceEncoder[T <: Seq[Long] : TypeTag]: Encoder[T] =
+ ExpressionEncoder()
+
+ /** @since 2.2.0 */
+ implicit def newDoubleSequenceEncoder[T <: Seq[Double] : TypeTag]: Encoder[T] =
+ ExpressionEncoder()
+
+ /** @since 2.2.0 */
+ implicit def newFloatSequenceEncoder[T <: Seq[Float] : TypeTag]: Encoder[T] =
+ ExpressionEncoder()
+
+ /** @since 2.2.0 */
+ implicit def newByteSequenceEncoder[T <: Seq[Byte] : TypeTag]: Encoder[T] =
+ ExpressionEncoder()
+
+ /** @since 2.2.0 */
+ implicit def newShortSequenceEncoder[T <: Seq[Short] : TypeTag]: Encoder[T] =
+ ExpressionEncoder()
+
+ /** @since 2.2.0 */
+ implicit def newBooleanSequenceEncoder[T <: Seq[Boolean] : TypeTag]: Encoder[T] =
+ ExpressionEncoder()
+
+ /** @since 2.2.0 */
+ implicit def newStringSequenceEncoder[T <: Seq[String] : TypeTag]: Encoder[T] =
+ ExpressionEncoder()
+
+ /** @since 2.2.0 */
+ implicit def newProductSequenceEncoder[T <: Seq[Product] : TypeTag]: Encoder[T] =
+ ExpressionEncoder()
+
// Arrays
/** @since 1.6.1 */
@@ -193,3 +253,16 @@ abstract class SQLImplicits {
implicit def symbolToColumn(s: Symbol): ColumnName = new ColumnName(s.name)
}
+
+/**
+ * Lower priority implicit methods for converting Scala objects into [[Dataset]]s.
+ * Conflicting implicits are placed here to disambiguate resolution.
+ *
+ * Reasons for including specific implicits:
+ * newProductEncoder - to disambiguate for [[List]]s which are both [[Seq]] and [[Product]]
+ */
+trait LowPrioritySQLImplicits {
+ /** @since 1.6.0 */
+ implicit def newProductEncoder[T <: Product : TypeTag]: Encoder[T] = Encoders.product[T]
+
+}
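The trait split above relies on a standard Scala idiom: when two implicits both apply, the one defined in a subclass of the other's owner is considered more specific and wins. The self-contained sketch below (illustrative names, not Spark's API; on Scala 2.11/2.12, where `List` extends both `Seq` and `Product`) shows why this lets a `List` resolve to the sequence encoder instead of triggering an ambiguity error.

```scala
// Illustrative sketch of the low-priority-implicits idiom (names are made up).
trait Enc[T]

trait LowPrioritySketch {
  // Would match List[Int], because List extends Product.
  implicit def productEnc[T <: Product]: Enc[T] = new Enc[T] {}
}

object SketchImplicits extends LowPrioritySketch {
  // Also matches List[Int]; wins because it is defined in the subclass.
  implicit def seqEnc[T <: Seq[_]]: Enc[T] = new Enc[T] {}
}

import SketchImplicits._
implicitly[Enc[List[Int]]]  // resolves to seqEnc; with both implicits at equal
                            // priority this would be an ambiguous implicit error
```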
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala
index f8d4c61967..6b50cb3e48 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala
@@ -17,10 +17,21 @@
package org.apache.spark.sql
+import scala.collection.immutable.Queue
+import scala.collection.mutable.ArrayBuffer
+
import org.apache.spark.sql.test.SharedSQLContext
case class IntClass(value: Int)
+case class SeqClass(s: Seq[Int])
+
+case class ListClass(l: List[Int])
+
+case class QueueClass(q: Queue[Int])
+
+case class ComplexClass(seq: SeqClass, list: ListClass, queue: QueueClass)
+
package object packageobject {
case class PackageClass(value: Int)
}
@@ -130,6 +141,62 @@ class DatasetPrimitiveSuite extends QueryTest with SharedSQLContext {
checkDataset(Seq(Array(Tuple1(1))).toDS(), Array(Tuple1(1)))
}
+ test("arbitrary sequences") {
+ checkDataset(Seq(Queue(1)).toDS(), Queue(1))
+ checkDataset(Seq(Queue(1.toLong)).toDS(), Queue(1.toLong))
+ checkDataset(Seq(Queue(1.toDouble)).toDS(), Queue(1.toDouble))
+ checkDataset(Seq(Queue(1.toFloat)).toDS(), Queue(1.toFloat))
+ checkDataset(Seq(Queue(1.toByte)).toDS(), Queue(1.toByte))
+ checkDataset(Seq(Queue(1.toShort)).toDS(), Queue(1.toShort))
+ checkDataset(Seq(Queue(true)).toDS(), Queue(true))
+ checkDataset(Seq(Queue("test")).toDS(), Queue("test"))
+ checkDataset(Seq(Queue(Tuple1(1))).toDS(), Queue(Tuple1(1)))
+
+ checkDataset(Seq(ArrayBuffer(1)).toDS(), ArrayBuffer(1))
+ checkDataset(Seq(ArrayBuffer(1.toLong)).toDS(), ArrayBuffer(1.toLong))
+ checkDataset(Seq(ArrayBuffer(1.toDouble)).toDS(), ArrayBuffer(1.toDouble))
+ checkDataset(Seq(ArrayBuffer(1.toFloat)).toDS(), ArrayBuffer(1.toFloat))
+ checkDataset(Seq(ArrayBuffer(1.toByte)).toDS(), ArrayBuffer(1.toByte))
+ checkDataset(Seq(ArrayBuffer(1.toShort)).toDS(), ArrayBuffer(1.toShort))
+ checkDataset(Seq(ArrayBuffer(true)).toDS(), ArrayBuffer(true))
+ checkDataset(Seq(ArrayBuffer("test")).toDS(), ArrayBuffer("test"))
+ checkDataset(Seq(ArrayBuffer(Tuple1(1))).toDS(), ArrayBuffer(Tuple1(1)))
+ }
+
+ test("sequence and product combinations") {
+ // Case classes
+ checkDataset(Seq(SeqClass(Seq(1))).toDS(), SeqClass(Seq(1)))
+ checkDataset(Seq(Seq(SeqClass(Seq(1)))).toDS(), Seq(SeqClass(Seq(1))))
+ checkDataset(Seq(List(SeqClass(Seq(1)))).toDS(), List(SeqClass(Seq(1))))
+ checkDataset(Seq(Queue(SeqClass(Seq(1)))).toDS(), Queue(SeqClass(Seq(1))))
+
+ checkDataset(Seq(ListClass(List(1))).toDS(), ListClass(List(1)))
+ checkDataset(Seq(Seq(ListClass(List(1)))).toDS(), Seq(ListClass(List(1))))
+ checkDataset(Seq(List(ListClass(List(1)))).toDS(), List(ListClass(List(1))))
+ checkDataset(Seq(Queue(ListClass(List(1)))).toDS(), Queue(ListClass(List(1))))
+
+ checkDataset(Seq(QueueClass(Queue(1))).toDS(), QueueClass(Queue(1)))
+ checkDataset(Seq(Seq(QueueClass(Queue(1)))).toDS(), Seq(QueueClass(Queue(1))))
+ checkDataset(Seq(List(QueueClass(Queue(1)))).toDS(), List(QueueClass(Queue(1))))
+ checkDataset(Seq(Queue(QueueClass(Queue(1)))).toDS(), Queue(QueueClass(Queue(1))))
+
+ val complex = ComplexClass(SeqClass(Seq(1)), ListClass(List(2)), QueueClass(Queue(3)))
+ checkDataset(Seq(complex).toDS(), complex)
+ checkDataset(Seq(Seq(complex)).toDS(), Seq(complex))
+ checkDataset(Seq(List(complex)).toDS(), List(complex))
+ checkDataset(Seq(Queue(complex)).toDS(), Queue(complex))
+
+ // Tuples
+ checkDataset(Seq(Seq(1) -> Seq(2)).toDS(), Seq(1) -> Seq(2))
+ checkDataset(Seq(List(1) -> Queue(2)).toDS(), List(1) -> Queue(2))
+ checkDataset(Seq(List(Seq("test1") -> List(Queue("test2")))).toDS(),
+ List(Seq("test1") -> List(Queue("test2"))))
+
+ // Complex
+ checkDataset(Seq(ListClass(List(1)) -> Queue("test" -> SeqClass(Seq(2)))).toDS(),
+ ListClass(List(1)) -> Queue("test" -> SeqClass(Seq(2))))
+ }
+
test("package objects") {
import packageobject._
checkDataset(Seq(PackageClass(1)).toDS(), PackageClass(1))