author    Takuya UESHIN <ueshin@happy-camper.st>  2014-05-15 11:20:21 -0700
committer Reynold Xin <rxin@apache.org>           2014-05-15 11:20:21 -0700
commit    db8cc6f28abe4326cea6f53feb604920e4867a27 (patch)
tree      ab5dfcb3b8a458129d728ba69c128ecf223696a7 /sql/core
parent    3abe2b734a5578966f671c34f1de34b4446b90f1 (diff)
[SPARK-1845] [SQL] Use AllScalaRegistrar for SparkSqlSerializer to register serializers of Scala collections.

When I execute `orderBy` or `limit` on a `SchemaRDD` that includes an `ArrayType` or `MapType`, `SparkSqlSerializer` throws one of the following exceptions:

```
com.esotericsoftware.kryo.KryoException: Class cannot be created (missing no-arg constructor): scala.collection.immutable.$colon$colon
```

```
com.esotericsoftware.kryo.KryoException: Class cannot be created (missing no-arg constructor): scala.collection.immutable.Vector
```

```
com.esotericsoftware.kryo.KryoException: Class cannot be created (missing no-arg constructor): scala.collection.immutable.HashMap$HashTrieMap
```

and so on. This happens because `SparkSqlSerializer` is missing serializer registrations for the concrete collection classes. I believe it should use `AllScalaRegistrar`, which covers the serializers for the concrete `Seq` and `Map` classes backing `ArrayType` and `MapType`.

Author: Takuya UESHIN <ueshin@happy-camper.st>

Closes #790 from ueshin/issues/SPARK-1845 and squashes the following commits:

d1ed992 [Takuya UESHIN] Use AllScalaRegistrar for SparkSqlSerializer to register serializers of Scala collections.
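For illustration, here is a minimal standalone sketch of the failure mode and the fix. It assumes Twitter's chill library (a Spark dependency) is on the classpath; the object name `AllScalaRegistrarDemo` is hypothetical. Without the `AllScalaRegistrar` call, round-tripping the `List` below fails with the `scala.collection.immutable.$colon$colon` error above, because Kryo's default field serializer needs a no-arg constructor:

```scala
import java.io.ByteArrayOutputStream

import com.esotericsoftware.kryo.Kryo
import com.esotericsoftware.kryo.io.{Input, Output}
import com.twitter.chill.AllScalaRegistrar

object AllScalaRegistrarDemo {
  def main(args: Array[String]): Unit = {
    val kryo = new Kryo()
    kryo.setRegistrationRequired(false)
    // Registers serializers for the concrete Scala collection classes
    // (::, Vector, HashMap$HashTrieMap, ...) that lack no-arg constructors.
    new AllScalaRegistrar().apply(kryo)

    // Serialize a List, which Kryo alone cannot reconstruct.
    val baos = new ByteArrayOutputStream()
    val output = new Output(baos)
    kryo.writeClassAndObject(output, List(1, 2, 3))
    output.close()

    // Deserialize and print the restored collection.
    val input = new Input(baos.toByteArray)
    println(kryo.readClassAndObject(input)) // List(1, 2, 3)
    input.close()
  }
}
```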
Diffstat (limited to 'sql/core')

 sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala | 28
 sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala                | 24
 sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala                | 30
 sql/core/src/test/scala/org/apache/spark/sql/TestData.scala                     | 10
 4 files changed, 66 insertions(+), 26 deletions(-)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala
index 94c2a249ef..34b355e906 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala
@@ -24,6 +24,7 @@ import scala.reflect.ClassTag
import com.clearspring.analytics.stream.cardinality.HyperLogLog
import com.esotericsoftware.kryo.io.{Input, Output}
import com.esotericsoftware.kryo.{Serializer, Kryo}
+import com.twitter.chill.AllScalaRegistrar
import org.apache.spark.{SparkEnv, SparkConf}
import org.apache.spark.serializer.KryoSerializer
@@ -35,22 +36,14 @@ private[sql] class SparkSqlSerializer(conf: SparkConf) extends KryoSerializer(co
val kryo = new Kryo()
kryo.setRegistrationRequired(false)
kryo.register(classOf[MutablePair[_, _]])
- kryo.register(classOf[Array[Any]])
- // This is kinda hacky...
- kryo.register(classOf[scala.collection.immutable.Map$Map1], new MapSerializer)
- kryo.register(classOf[scala.collection.immutable.Map$Map2], new MapSerializer)
- kryo.register(classOf[scala.collection.immutable.Map$Map3], new MapSerializer)
- kryo.register(classOf[scala.collection.immutable.Map$Map4], new MapSerializer)
- kryo.register(classOf[scala.collection.immutable.Map[_,_]], new MapSerializer)
- kryo.register(classOf[scala.collection.Map[_,_]], new MapSerializer)
kryo.register(classOf[org.apache.spark.sql.catalyst.expressions.GenericRow])
kryo.register(classOf[org.apache.spark.sql.catalyst.expressions.GenericMutableRow])
kryo.register(classOf[com.clearspring.analytics.stream.cardinality.HyperLogLog],
new HyperLogLogSerializer)
- kryo.register(classOf[scala.collection.mutable.ArrayBuffer[_]])
kryo.register(classOf[scala.math.BigDecimal], new BigDecimalSerializer)
kryo.setReferences(false)
kryo.setClassLoader(Utils.getSparkClassLoader)
+ new AllScalaRegistrar().apply(kryo)
kryo
}
}
@@ -97,20 +90,3 @@ private[sql] class HyperLogLogSerializer extends Serializer[HyperLogLog] {
HyperLogLog.Builder.build(bytes)
}
}
-
-/**
- * Maps do not have a no arg constructor and so cannot be serialized by default. So, we serialize
- * them as `Array[(k,v)]`.
- */
-private[sql] class MapSerializer extends Serializer[Map[_,_]] {
- def write(kryo: Kryo, output: Output, map: Map[_,_]) {
- kryo.writeObject(output, map.flatMap(e => Seq(e._1, e._2)).toArray)
- }
-
- def read(kryo: Kryo, input: Input, tpe: Class[Map[_,_]]): Map[_,_] = {
- kryo.readObject(input, classOf[Array[Any]])
- .sliding(2,2)
- .map { case Array(k,v) => (k,v) }
- .toMap
- }
-}
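For context on why the hand-rolled `MapSerializer` above could be deleted: it worked around the missing no-arg constructor by flattening each map into an alternating key/value array, while chill's `AllScalaRegistrar` already registers serializers that reconstruct the immutable `Map` implementations directly. A minimal round-trip sketch under the same assumptions as before (chill on the classpath, hypothetical object name):

```scala
import java.io.ByteArrayOutputStream

import com.esotericsoftware.kryo.Kryo
import com.esotericsoftware.kryo.io.{Input, Output}
import com.twitter.chill.AllScalaRegistrar

object MapRoundTripDemo {
  def main(args: Array[String]): Unit = {
    val kryo = new Kryo()
    new AllScalaRegistrar().apply(kryo)

    val original = Map(1 -> "a", 2 -> "b", 3 -> "c")

    // Serialize the map directly; no manual flattening to Array[(k, v)].
    val baos = new ByteArrayOutputStream()
    val output = new Output(baos)
    kryo.writeClassAndObject(output, original)
    output.close()

    // Deserialize and verify the round trip preserves the map.
    val input = new Input(baos.toByteArray)
    val restored = kryo.readClassAndObject(input).asInstanceOf[Map[Int, String]]
    input.close()
    assert(restored == original)
  }
}
```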
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala
index 92a707ea57..f43e98d614 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala
@@ -69,12 +69,36 @@ class DslQuerySuite extends QueryTest {
checkAnswer(
testData2.orderBy('a.desc, 'b.asc),
Seq((3,1), (3,2), (2,1), (2,2), (1,1), (1,2)))
+
+ checkAnswer(
+ arrayData.orderBy(GetItem('data, 0).asc),
+ arrayData.collect().sortBy(_.data(0)).toSeq)
+
+ checkAnswer(
+ arrayData.orderBy(GetItem('data, 0).desc),
+ arrayData.collect().sortBy(_.data(0)).reverse.toSeq)
+
+ checkAnswer(
+ mapData.orderBy(GetItem('data, 1).asc),
+ mapData.collect().sortBy(_.data(1)).toSeq)
+
+ checkAnswer(
+ mapData.orderBy(GetItem('data, 1).desc),
+ mapData.collect().sortBy(_.data(1)).reverse.toSeq)
}
test("limit") {
checkAnswer(
testData.limit(10),
testData.take(10).toSeq)
+
+ checkAnswer(
+ arrayData.limit(1),
+ arrayData.take(1).toSeq)
+
+ checkAnswer(
+ mapData.limit(1),
+ mapData.take(1).toSeq)
}
test("average") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 524549eb54..189dccd525 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -85,6 +85,36 @@ class SQLQuerySuite extends QueryTest {
checkAnswer(
sql("SELECT * FROM testData2 ORDER BY a DESC, b ASC"),
Seq((3,1), (3,2), (2,1), (2,2), (1,1), (1,2)))
+
+ checkAnswer(
+ sql("SELECT * FROM arrayData ORDER BY data[0] ASC"),
+ arrayData.collect().sortBy(_.data(0)).toSeq)
+
+ checkAnswer(
+ sql("SELECT * FROM arrayData ORDER BY data[0] DESC"),
+ arrayData.collect().sortBy(_.data(0)).reverse.toSeq)
+
+ checkAnswer(
+ sql("SELECT * FROM mapData ORDER BY data[1] ASC"),
+ mapData.collect().sortBy(_.data(1)).toSeq)
+
+ checkAnswer(
+ sql("SELECT * FROM mapData ORDER BY data[1] DESC"),
+ mapData.collect().sortBy(_.data(1)).reverse.toSeq)
+ }
+
+ test("limit") {
+ checkAnswer(
+ sql("SELECT * FROM testData LIMIT 10"),
+ testData.take(10).toSeq)
+
+ checkAnswer(
+ sql("SELECT * FROM arrayData LIMIT 1"),
+ arrayData.collect().take(1).toSeq)
+
+ checkAnswer(
+ sql("SELECT * FROM mapData LIMIT 1"),
+ mapData.collect().take(1).toSeq)
}
test("average") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala
index aa71e274f7..1aca387252 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala
@@ -74,6 +74,16 @@ object TestData {
ArrayData(Seq(2,3,4), Seq(Seq(2,3,4))) :: Nil)
arrayData.registerAsTable("arrayData")
+ case class MapData(data: Map[Int, String])
+ val mapData =
+ TestSQLContext.sparkContext.parallelize(
+ MapData(Map(1 -> "a1", 2 -> "b1", 3 -> "c1", 4 -> "d1", 5 -> "e1")) ::
+ MapData(Map(1 -> "a2", 2 -> "b2", 3 -> "c2", 4 -> "d2")) ::
+ MapData(Map(1 -> "a3", 2 -> "b3", 3 -> "c3")) ::
+ MapData(Map(1 -> "a4", 2 -> "b4")) ::
+ MapData(Map(1 -> "a5")) :: Nil)
+ mapData.registerAsTable("mapData")
+
case class StringData(s: String)
val repeatedData =
TestSQLContext.sparkContext.parallelize(List.fill(2)(StringData("test")))