commit    ce7ddabbcd330b19f6d0c17082304dfa6e1621b2
tree      8f7b0feeffb49c44b8735429651dd568304186dc
parent    517bdf36aecdc94ef569b68f0a96892e707b5c7b
author    Yin Huai <yhuai@databricks.com>             2015-04-20 18:42:50 -0700
committer Michael Armbrust <michael@databricks.com>   2015-04-20 18:42:50 -0700
[SPARK-6368][SQL] Build a specialized serializer for Exchange operator.
JIRA: https://issues.apache.org/jira/browse/SPARK-6368

Author: Yin Huai <yhuai@databricks.com>

Closes #5497 from yhuai/serializer2 and squashes the following commits:

da562c5 [Yin Huai] Merge remote-tracking branch 'upstream/master' into serializer2
50e0c3d [Yin Huai] When no filed is emitted to shuffle, use SparkSqlSerializer for now.
9f1ed92 [Yin Huai] Merge remote-tracking branch 'upstream/master' into serializer2
6d07678 [Yin Huai] Address comments.
4273b8c [Yin Huai] Enabled SparkSqlSerializer2.
09e587a [Yin Huai] Remove TODO.
791b96a [Yin Huai] Use UTF8String.
60a1487 [Yin Huai] Merge remote-tracking branch 'upstream/master' into serializer2
3e09655 [Yin Huai] Use getAs for Date column.
43b9fb4 [Yin Huai] Test.
8297732 [Yin Huai] Fix test.
c9373c8 [Yin Huai] Support DecimalType.
2379eeb [Yin Huai] ASF header.
39704ab [Yin Huai] Specialized serializer for Exchange.
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala                              |   4
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala                   |  59
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala        | 421
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlSerializer2Suite.scala   | 195
4 files changed, 673 insertions(+), 6 deletions(-)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
index 5c65f04ee8..4fc5de7e82 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
@@ -64,6 +64,8 @@ private[spark] object SQLConf {
// Set to false when debugging requires the ability to look at invalid query plans.
val DATAFRAME_EAGER_ANALYSIS = "spark.sql.eagerAnalysis"
+ val USE_SQL_SERIALIZER2 = "spark.sql.useSerializer2"
+
object Deprecated {
val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks"
}
@@ -147,6 +149,8 @@ private[sql] class SQLConf extends Serializable {
*/
private[spark] def codegenEnabled: Boolean = getConf(CODEGEN_ENABLED, "false").toBoolean
+ private[spark] def useSqlSerializer2: Boolean = getConf(USE_SQL_SERIALIZER2, "true").toBoolean
+
/**
* Upper bound on the sizes (in bytes) of the tables qualified for the auto conversion to
* a broadcast value during the physical executions of join operations. Setting this to -1
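Note: the new flag defaults to true, so SparkSqlSerializer2 is picked automatically whenever it applies. As an illustrative sketch only (not part of this patch), assuming a Spark 1.3-style SQLContext bound to a value named `sqlContext`, the flag could be toggled from user code like this:

    // Fall back to the generic SparkSqlSerializer for shuffles.
    sqlContext.setConf("spark.sql.useSerializer2", "false")
    // Re-enable the schema-specialized serializer (the default).
    sqlContext.sql("SET spark.sql.useSerializer2=true")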
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
index 69a620e1ec..5b2e46962c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
@@ -19,13 +19,15 @@ package org.apache.spark.sql.execution
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.shuffle.sort.SortShuffleManager
-import org.apache.spark.{SparkEnv, HashPartitioner, RangePartitioner, SparkConf}
+import org.apache.spark.{SparkEnv, HashPartitioner, RangePartitioner}
import org.apache.spark.rdd.{RDD, ShuffledRDD}
+import org.apache.spark.serializer.Serializer
import org.apache.spark.sql.{SQLContext, Row}
import org.apache.spark.sql.catalyst.errors.attachTree
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.types.DataType
import org.apache.spark.util.MutablePair
object Exchange {
@@ -77,9 +79,48 @@ case class Exchange(
}
}
- override def execute(): RDD[Row] = attachTree(this , "execute") {
- lazy val sparkConf = child.sqlContext.sparkContext.getConf
+ @transient private lazy val sparkConf = child.sqlContext.sparkContext.getConf
+
+ def serializer(
+ keySchema: Array[DataType],
+ valueSchema: Array[DataType],
+ numPartitions: Int): Serializer = {
+ // In ExternalSorter's spillToMergeableFile function, key-value pairs are written out
+ // through write(key) and then write(value) instead of write((key, value)). Because
+ // SparkSqlSerializer2 assumes that objects passed in are Product2, we cannot safely use
+ // it when spillToMergeableFile in ExternalSorter will be used.
+ // So, we will not use SparkSqlSerializer2 when
+ // - Sort-based shuffle is enabled and the number of reducers (numPartitions) is greater
+ //     than the bypassMergeThreshold; or
+ // - newOrdering is defined.
+ val cannotUseSqlSerializer2 =
+ (sortBasedShuffleOn && numPartitions > bypassMergeThreshold) || newOrdering.nonEmpty
+
+ // It is true when there is no field that needs to be written out.
+ // For now, we will not use SparkSqlSerializer2 when noField is true.
+ val noField =
+ (keySchema == null || keySchema.length == 0) &&
+ (valueSchema == null || valueSchema.length == 0)
+
+ val useSqlSerializer2 =
+ child.sqlContext.conf.useSqlSerializer2 && // SparkSqlSerializer2 is enabled.
+ !cannotUseSqlSerializer2 && // Safe to use Serializer2.
+ SparkSqlSerializer2.support(keySchema) && // The schema of key is supported.
+ SparkSqlSerializer2.support(valueSchema) && // The schema of value is supported.
+ !noField
+
+ val serializer = if (useSqlSerializer2) {
+ logInfo("Using SparkSqlSerializer2.")
+ new SparkSqlSerializer2(keySchema, valueSchema)
+ } else {
+ logInfo("Using SparkSqlSerializer.")
+ new SparkSqlSerializer(sparkConf)
+ }
+
+ serializer
+ }
+ override def execute(): RDD[Row] = attachTree(this , "execute") {
newPartitioning match {
case HashPartitioning(expressions, numPartitions) =>
// TODO: Eliminate redundant expressions in grouping key and value.
@@ -111,7 +152,10 @@ case class Exchange(
} else {
new ShuffledRDD[Row, Row, Row](rdd, part)
}
- shuffled.setSerializer(new SparkSqlSerializer(sparkConf))
+ val keySchema = expressions.map(_.dataType).toArray
+ val valueSchema = child.output.map(_.dataType).toArray
+ shuffled.setSerializer(serializer(keySchema, valueSchema, numPartitions))
+
shuffled.map(_._2)
case RangePartitioning(sortingExpressions, numPartitions) =>
@@ -134,7 +178,9 @@ case class Exchange(
} else {
new ShuffledRDD[Row, Null, Null](rdd, part)
}
- shuffled.setSerializer(new SparkSqlSerializer(sparkConf))
+ val keySchema = child.output.map(_.dataType).toArray
+ shuffled.setSerializer(serializer(keySchema, null, numPartitions))
+
shuffled.map(_._1)
case SinglePartition =>
@@ -152,7 +198,8 @@ case class Exchange(
}
val partitioner = new HashPartitioner(1)
val shuffled = new ShuffledRDD[Null, Row, Row](rdd, partitioner)
- shuffled.setSerializer(new SparkSqlSerializer(sparkConf))
+ val valueSchema = child.output.map(_.dataType).toArray
+ shuffled.setSerializer(serializer(null, valueSchema, 1))
shuffled.map(_._2)
case _ => sys.error(s"Exchange not implemented for $newPartitioning")
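For readers, the selection logic above can be restated as a standalone predicate. The sketch below is illustrative only; the real code lives inside Exchange, also builds the serializer instances, and reads these flags from SparkEnv and the SQL conf, and the parameter names here are hypothetical:

    import org.apache.spark.sql.types.DataType

    // Whether SparkSqlSerializer2 may be used for a shuffle with these schemas and settings.
    def canUseSerializer2(
        keySchema: Array[DataType],
        valueSchema: Array[DataType],
        numPartitions: Int,
        sortBasedShuffleOn: Boolean,
        bypassMergeThreshold: Int,
        hasNewOrdering: Boolean,
        serializer2Enabled: Boolean): Boolean = {
      // ExternalSorter's merge path writes keys and values separately, which breaks
      // the Product2 assumption of SparkSqlSerializer2.
      val cannotUse =
        (sortBasedShuffleOn && numPartitions > bypassMergeThreshold) || hasNewOrdering
      // If both schemas are empty there is nothing to write, so fall back for now.
      val noField =
        (keySchema == null || keySchema.isEmpty) &&
        (valueSchema == null || valueSchema.isEmpty)
      serializer2Enabled && !cannotUse && !noField &&
        SparkSqlSerializer2.support(keySchema) &&
        SparkSqlSerializer2.support(valueSchema)
    }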
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala
new file mode 100644
index 0000000000..cec97de2cd
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala
@@ -0,0 +1,421 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution
+
+import java.io._
+import java.math.{BigDecimal, BigInteger}
+import java.nio.ByteBuffer
+import java.sql.Timestamp
+
+import scala.reflect.ClassTag
+
+import org.apache.spark.serializer._
+import org.apache.spark.Logging
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.catalyst.expressions.SpecificMutableRow
+import org.apache.spark.sql.types._
+
+/**
+ * The serialization stream for [[SparkSqlSerializer2]]. It assumes that the objects passed to
+ * its `writeObject` are [[Product2]]s. The serialization functions for the key and value of a
+ * [[Product2]] are constructed based on their schemata.
+ * The benefit of this serialization stream is that, compared with general-purpose serializers
+ * like Kryo and the Java serializer, it can significantly reduce the size of the serialized data
+ * and has a lower allocation cost, which can benefit the shuffle operation. Right now, its main
+ * limitations are:
+ * 1. It does not support complex types, i.e. Map, Array, and Struct.
+ * 2. It assumes that the objects passed in are [[Product2]]s. So, it cannot be used when
+ * [[org.apache.spark.util.collection.ExternalSorter]]'s merge sort operation is used, because
+ * the objects passed to the serializer are not of type [[Product2]]. Also see the comment on
+ * the `serializer` method in [[Exchange]] for more information.
+ */
+private[sql] class Serializer2SerializationStream(
+ keySchema: Array[DataType],
+ valueSchema: Array[DataType],
+ out: OutputStream)
+ extends SerializationStream with Logging {
+
+ val rowOut = new DataOutputStream(out)
+ val writeKey = SparkSqlSerializer2.createSerializationFunction(keySchema, rowOut)
+ val writeValue = SparkSqlSerializer2.createSerializationFunction(valueSchema, rowOut)
+
+ def writeObject[T: ClassTag](t: T): SerializationStream = {
+ val kv = t.asInstanceOf[Product2[Row, Row]]
+ writeKey(kv._1)
+ writeValue(kv._2)
+
+ this
+ }
+
+ def flush(): Unit = {
+ rowOut.flush()
+ }
+
+ def close(): Unit = {
+ rowOut.close()
+ }
+}
+
+/**
+ * The corresponding deserialization stream for [[Serializer2SerializationStream]].
+ */
+private[sql] class Serializer2DeserializationStream(
+ keySchema: Array[DataType],
+ valueSchema: Array[DataType],
+ in: InputStream)
+ extends DeserializationStream with Logging {
+
+ val rowIn = new DataInputStream(new BufferedInputStream(in))
+
+ val key = if (keySchema != null) new SpecificMutableRow(keySchema) else null
+ val value = if (valueSchema != null) new SpecificMutableRow(valueSchema) else null
+ val readKey = SparkSqlSerializer2.createDeserializationFunction(keySchema, rowIn, key)
+ val readValue = SparkSqlSerializer2.createDeserializationFunction(valueSchema, rowIn, value)
+
+ def readObject[T: ClassTag](): T = {
+ readKey()
+ readValue()
+
+ (key, value).asInstanceOf[T]
+ }
+
+ def close(): Unit = {
+ rowIn.close()
+ }
+}
+
+private[sql] class ShuffleSerializerInstance(
+ keySchema: Array[DataType],
+ valueSchema: Array[DataType])
+ extends SerializerInstance {
+
+ def serialize[T: ClassTag](t: T): ByteBuffer =
+ throw new UnsupportedOperationException("Not supported.")
+
+ def deserialize[T: ClassTag](bytes: ByteBuffer): T =
+ throw new UnsupportedOperationException("Not supported.")
+
+ def deserialize[T: ClassTag](bytes: ByteBuffer, loader: ClassLoader): T =
+ throw new UnsupportedOperationException("Not supported.")
+
+ def serializeStream(s: OutputStream): SerializationStream = {
+ new Serializer2SerializationStream(keySchema, valueSchema, s)
+ }
+
+ def deserializeStream(s: InputStream): DeserializationStream = {
+ new Serializer2DeserializationStream(keySchema, valueSchema, s)
+ }
+}
+
+/**
+ * SparkSqlSerializer2 is a special serializer that creates its serialization and
+ * deserialization functions based on the schema of the data. It assumes that the values
+ * passed in are key/value pairs and that the values it returns are also key/value pairs.
+ * The schema of keys is represented by `keySchema` and that of values is represented by
+ * `valueSchema`.
+ */
+private[sql] class SparkSqlSerializer2(keySchema: Array[DataType], valueSchema: Array[DataType])
+ extends Serializer
+ with Logging
+ with Serializable {
+
+ def newInstance(): SerializerInstance = new ShuffleSerializerInstance(keySchema, valueSchema)
+}
+
+private[sql] object SparkSqlSerializer2 {
+
+ final val NULL = 0
+ final val NOT_NULL = 1
+
+ /**
+ * Check whether rows with the given schema can be serialized with SparkSqlSerializer2.
+ */
+ def support(schema: Array[DataType]): Boolean = {
+ if (schema == null) return true
+
+ var i = 0
+ while (i < schema.length) {
+ schema(i) match {
+ case udt: UserDefinedType[_] => return false
+ case array: ArrayType => return false
+ case map: MapType => return false
+ case struct: StructType => return false
+ case _ =>
+ }
+ i += 1
+ }
+
+ return true
+ }
+
+ /**
+ * A utility function that creates the serialization function for the given schema.
+ */
+ def createSerializationFunction(schema: Array[DataType], out: DataOutputStream): Row => Unit = {
+ (row: Row) =>
+ // If the schema is null, the returned function does nothing when it gets called.
+ if (schema != null) {
+ var i = 0
+ while (i < schema.length) {
+ schema(i) match {
+ // When we write values to the underlying stream, we first write the null-indicator
+ // byte. Then, if the value is not null, we write the contents out.
+
+ case NullType => // Write nothing.
+
+ case BooleanType =>
+ if (row.isNullAt(i)) {
+ out.writeByte(NULL)
+ } else {
+ out.writeByte(NOT_NULL)
+ out.writeBoolean(row.getBoolean(i))
+ }
+
+ case ByteType =>
+ if (row.isNullAt(i)) {
+ out.writeByte(NULL)
+ } else {
+ out.writeByte(NOT_NULL)
+ out.writeByte(row.getByte(i))
+ }
+
+ case ShortType =>
+ if (row.isNullAt(i)) {
+ out.writeByte(NULL)
+ } else {
+ out.writeByte(NOT_NULL)
+ out.writeShort(row.getShort(i))
+ }
+
+ case IntegerType =>
+ if (row.isNullAt(i)) {
+ out.writeByte(NULL)
+ } else {
+ out.writeByte(NOT_NULL)
+ out.writeInt(row.getInt(i))
+ }
+
+ case LongType =>
+ if (row.isNullAt(i)) {
+ out.writeByte(NULL)
+ } else {
+ out.writeByte(NOT_NULL)
+ out.writeLong(row.getLong(i))
+ }
+
+ case FloatType =>
+ if (row.isNullAt(i)) {
+ out.writeByte(NULL)
+ } else {
+ out.writeByte(NOT_NULL)
+ out.writeFloat(row.getFloat(i))
+ }
+
+ case DoubleType =>
+ if (row.isNullAt(i)) {
+ out.writeByte(NULL)
+ } else {
+ out.writeByte(NOT_NULL)
+ out.writeDouble(row.getDouble(i))
+ }
+
+ case decimal: DecimalType =>
+ if (row.isNullAt(i)) {
+ out.writeByte(NULL)
+ } else {
+ out.writeByte(NOT_NULL)
+ val value = row.apply(i).asInstanceOf[Decimal]
+ val javaBigDecimal = value.toJavaBigDecimal
+ // First, write out the unscaled value.
+ val bytes: Array[Byte] = javaBigDecimal.unscaledValue().toByteArray
+ out.writeInt(bytes.length)
+ out.write(bytes)
+ // Then, write out the scale.
+ out.writeInt(javaBigDecimal.scale())
+ }
+
+ case DateType =>
+ if (row.isNullAt(i)) {
+ out.writeByte(NULL)
+ } else {
+ out.writeByte(NOT_NULL)
+ out.writeInt(row.getAs[Int](i))
+ }
+
+ case TimestampType =>
+ if (row.isNullAt(i)) {
+ out.writeByte(NULL)
+ } else {
+ out.writeByte(NOT_NULL)
+ val timestamp = row.getAs[java.sql.Timestamp](i)
+ val time = timestamp.getTime
+ val nanos = timestamp.getNanos
+ out.writeLong(time - (nanos / 1000000)) // Write the milliseconds value.
+ out.writeInt(nanos) // Write the nanoseconds part.
+ }
+
+ case StringType =>
+ if (row.isNullAt(i)) {
+ out.writeByte(NULL)
+ } else {
+ out.writeByte(NOT_NULL)
+ val bytes = row.getAs[UTF8String](i).getBytes
+ out.writeInt(bytes.length)
+ out.write(bytes)
+ }
+
+ case BinaryType =>
+ if (row.isNullAt(i)) {
+ out.writeByte(NULL)
+ } else {
+ out.writeByte(NOT_NULL)
+ val bytes = row.getAs[Array[Byte]](i)
+ out.writeInt(bytes.length)
+ out.write(bytes)
+ }
+ }
+ i += 1
+ }
+ }
+ }
+
+ /**
+ * A utility function that creates the deserialization function for the given schema.
+ */
+ def createDeserializationFunction(
+ schema: Array[DataType],
+ in: DataInputStream,
+ mutableRow: SpecificMutableRow): () => Unit = {
+ () => {
+ // If the schema is null, the returned function does nothing when it gets called.
+ if (schema != null) {
+ var i = 0
+ while (i < schema.length) {
+ schema(i) match {
+ // When we read values from the underlying stream, we first read the null-indicator
+ // byte. Then, if the value is not null, we update the field of the mutable row.
+
+ case NullType => mutableRow.setNullAt(i) // Read nothing.
+
+ case BooleanType =>
+ if (in.readByte() == NULL) {
+ mutableRow.setNullAt(i)
+ } else {
+ mutableRow.setBoolean(i, in.readBoolean())
+ }
+
+ case ByteType =>
+ if (in.readByte() == NULL) {
+ mutableRow.setNullAt(i)
+ } else {
+ mutableRow.setByte(i, in.readByte())
+ }
+
+ case ShortType =>
+ if (in.readByte() == NULL) {
+ mutableRow.setNullAt(i)
+ } else {
+ mutableRow.setShort(i, in.readShort())
+ }
+
+ case IntegerType =>
+ if (in.readByte() == NULL) {
+ mutableRow.setNullAt(i)
+ } else {
+ mutableRow.setInt(i, in.readInt())
+ }
+
+ case LongType =>
+ if (in.readByte() == NULL) {
+ mutableRow.setNullAt(i)
+ } else {
+ mutableRow.setLong(i, in.readLong())
+ }
+
+ case FloatType =>
+ if (in.readByte() == NULL) {
+ mutableRow.setNullAt(i)
+ } else {
+ mutableRow.setFloat(i, in.readFloat())
+ }
+
+ case DoubleType =>
+ if (in.readByte() == NULL) {
+ mutableRow.setNullAt(i)
+ } else {
+ mutableRow.setDouble(i, in.readDouble())
+ }
+
+ case decimal: DecimalType =>
+ if (in.readByte() == NULL) {
+ mutableRow.setNullAt(i)
+ } else {
+ // First, read in the unscaled value.
+ val length = in.readInt()
+ val bytes = new Array[Byte](length)
+ in.readFully(bytes)
+ val unscaledVal = new BigInteger(bytes)
+ // Then, read the scale.
+ val scale = in.readInt()
+ // Finally, create the Decimal object and set it in the row.
+ mutableRow.update(i, Decimal(new BigDecimal(unscaledVal, scale)))
+ }
+
+ case DateType =>
+ if (in.readByte() == NULL) {
+ mutableRow.setNullAt(i)
+ } else {
+ mutableRow.update(i, in.readInt())
+ }
+
+ case TimestampType =>
+ if (in.readByte() == NULL) {
+ mutableRow.setNullAt(i)
+ } else {
+ val time = in.readLong() // Read the milliseconds value.
+ val nanos = in.readInt() // Read the nanoseconds part.
+ val timestamp = new Timestamp(time)
+ timestamp.setNanos(nanos)
+ mutableRow.update(i, timestamp)
+ }
+
+ case StringType =>
+ if (in.readByte() == NULL) {
+ mutableRow.setNullAt(i)
+ } else {
+ val length = in.readInt()
+ val bytes = new Array[Byte](length)
+ in.readFully(bytes)
+ mutableRow.update(i, UTF8String(bytes))
+ }
+
+ case BinaryType =>
+ if (in.readByte() == NULL) {
+ mutableRow.setNullAt(i)
+ } else {
+ val length = in.readInt()
+ val bytes = new Array[Byte](length)
+ in.readFully(bytes)
+ mutableRow.update(i, bytes)
+ }
+ }
+ i += 1
+ }
+ }
+ }
+ }
+}
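The per-column wire format implemented above (a null-indicator byte, followed by the value only when it is not null) can be exercised in isolation with plain java.io streams. The snippet below is a self-contained sketch for a single nullable IntegerType column; it does not touch the private[sql] classes and is not part of this patch:

    import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}

    val NULL: Byte = 0
    val NOT_NULL: Byte = 1

    // Write the null-indicator byte; write the 4-byte value only when it is present.
    def writeNullableInt(out: DataOutputStream, value: Option[Int]): Unit = value match {
      case None    => out.writeByte(NULL)
      case Some(v) => out.writeByte(NOT_NULL); out.writeInt(v)
    }

    // Read the null-indicator byte first, then the value if the column was not null.
    def readNullableInt(in: DataInputStream): Option[Int] =
      if (in.readByte() == NULL) None else Some(in.readInt())

    val buffer = new ByteArrayOutputStream()
    val out = new DataOutputStream(buffer)
    writeNullableInt(out, Some(42))
    writeNullableInt(out, None)
    out.flush()

    val in = new DataInputStream(new ByteArrayInputStream(buffer.toByteArray))
    assert(readNullableInt(in) == Some(42))
    assert(readNullableInt(in) == None)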
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlSerializer2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlSerializer2Suite.scala
new file mode 100644
index 0000000000..27f063d73a
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlSerializer2Suite.scala
@@ -0,0 +1,195 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution
+
+import java.sql.{Timestamp, Date}
+
+import org.scalatest.{FunSuite, BeforeAndAfterAll}
+
+import org.apache.spark.rdd.ShuffledRDD
+import org.apache.spark.serializer.Serializer
+import org.apache.spark.ShuffleDependency
+import org.apache.spark.sql.types._
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.test.TestSQLContext._
+import org.apache.spark.sql.{MyDenseVectorUDT, QueryTest}
+
+class SparkSqlSerializer2DataTypeSuite extends FunSuite {
+ // Make sure that we will not use serializer2 for unsupported data types.
+ def checkSupported(dataType: DataType, isSupported: Boolean): Unit = {
+ val testName =
+ s"${if (dataType == null) null else dataType.toString} is " +
+ s"${if (isSupported) "supported" else "unsupported"}"
+
+ test(testName) {
+ assert(SparkSqlSerializer2.support(Array(dataType)) === isSupported)
+ }
+ }
+
+ checkSupported(null, isSupported = true)
+ checkSupported(NullType, isSupported = true)
+ checkSupported(BooleanType, isSupported = true)
+ checkSupported(ByteType, isSupported = true)
+ checkSupported(ShortType, isSupported = true)
+ checkSupported(IntegerType, isSupported = true)
+ checkSupported(LongType, isSupported = true)
+ checkSupported(FloatType, isSupported = true)
+ checkSupported(DoubleType, isSupported = true)
+ checkSupported(DateType, isSupported = true)
+ checkSupported(TimestampType, isSupported = true)
+ checkSupported(StringType, isSupported = true)
+ checkSupported(BinaryType, isSupported = true)
+ checkSupported(DecimalType(10, 5), isSupported = true)
+ checkSupported(DecimalType.Unlimited, isSupported = true)
+
+ // For now, ArrayType, MapType, and StructType are not supported.
+ checkSupported(ArrayType(DoubleType, true), isSupported = false)
+ checkSupported(ArrayType(StringType, false), isSupported = false)
+ checkSupported(MapType(IntegerType, StringType, true), isSupported = false)
+ checkSupported(MapType(IntegerType, ArrayType(DoubleType), false), isSupported = false)
+ checkSupported(StructType(StructField("a", IntegerType, true) :: Nil), isSupported = false)
+ // UDTs are not supported right now.
+ checkSupported(new MyDenseVectorUDT, isSupported = false)
+}
+
+abstract class SparkSqlSerializer2Suite extends QueryTest with BeforeAndAfterAll {
+ var allColumns: String = _
+ val serializerClass: Class[Serializer] =
+ classOf[SparkSqlSerializer2].asInstanceOf[Class[Serializer]]
+ var numShufflePartitions: Int = _
+ var useSerializer2: Boolean = _
+
+ override def beforeAll(): Unit = {
+ numShufflePartitions = conf.numShufflePartitions
+ useSerializer2 = conf.useSqlSerializer2
+
+ sql("set spark.sql.useSerializer2=true")
+
+ val supportedTypes =
+ Seq(StringType, BinaryType, NullType, BooleanType,
+ ByteType, ShortType, IntegerType, LongType,
+ FloatType, DoubleType, DecimalType.Unlimited, DecimalType(6, 5),
+ DateType, TimestampType)
+
+ val fields = supportedTypes.zipWithIndex.map { case (dataType, index) =>
+ StructField(s"col$index", dataType, true)
+ }
+ allColumns = fields.map(_.name).mkString(",")
+ val schema = StructType(fields)
+
+ // Create an RDD with all data types supported by SparkSqlSerializer2.
+ val rdd =
+ sparkContext.parallelize((1 to 1000), 10).map { i =>
+ Row(
+ s"str${i}: test serializer2.",
+ s"binary${i}: test serializer2.".getBytes("UTF-8"),
+ null,
+ i % 2 == 0,
+ i.toByte,
+ i.toShort,
+ i,
+ Long.MaxValue - i.toLong,
+ (i + 0.25).toFloat,
+ (i + 0.75),
+ BigDecimal(Long.MaxValue.toString + ".12345"),
+ new java.math.BigDecimal(s"${i % 9 + 1}" + ".23456"),
+ new Date(i),
+ new Timestamp(i))
+ }
+
+ createDataFrame(rdd, schema).registerTempTable("shuffle")
+
+ super.beforeAll()
+ }
+
+ override def afterAll(): Unit = {
+ dropTempTable("shuffle")
+ sql(s"set spark.sql.shuffle.partitions=$numShufflePartitions")
+ sql(s"set spark.sql.useSerializer2=$useSerializer2")
+ super.afterAll()
+ }
+
+ def checkSerializer[T <: Serializer](
+ executedPlan: SparkPlan,
+ expectedSerializerClass: Class[T]): Unit = {
+ executedPlan.foreach {
+ case exchange: Exchange =>
+ val shuffledRDD = exchange.execute().firstParent.asInstanceOf[ShuffledRDD[_, _, _]]
+ val dependency = shuffledRDD.getDependencies.head.asInstanceOf[ShuffleDependency[_, _, _]]
+ val serializerNotSetMessage =
+ s"Expected $expectedSerializerClass as the serializer of Exchange. " +
+ s"However, the serializer was not set."
+ val serializer = dependency.serializer.getOrElse(fail(serializerNotSetMessage))
+ assert(serializer.getClass === expectedSerializerClass)
+ case _ => // Ignore other nodes.
+ }
+ }
+
+ test("key schema and value schema are not nulls") {
+ val df = sql(s"SELECT DISTINCT ${allColumns} FROM shuffle")
+ checkSerializer(df.queryExecution.executedPlan, serializerClass)
+ checkAnswer(
+ df,
+ table("shuffle").collect())
+ }
+
+ test("value schema is null") {
+ val df = sql(s"SELECT col0 FROM shuffle ORDER BY col0")
+ checkSerializer(df.queryExecution.executedPlan, serializerClass)
+ assert(
+ df.map(r => r.getString(0)).collect().toSeq ===
+ table("shuffle").select("col0").map(r => r.getString(0)).collect().sorted.toSeq)
+ }
+
+ test("no map output field") {
+ val df = sql(s"SELECT 1 + 1 FROM shuffle")
+ checkSerializer(df.queryExecution.executedPlan, classOf[SparkSqlSerializer])
+ }
+}
+
+/** Tests SparkSqlSerializer2 with sort based shuffle without sort merge. */
+class SparkSqlSerializer2SortShuffleSuite extends SparkSqlSerializer2Suite {
+ override def beforeAll(): Unit = {
+ super.beforeAll()
+ // Sort merge will not be triggered.
+ sql("set spark.sql.shuffle.partitions = 200")
+ }
+
+ test("key schema is null") {
+ val aggregations = allColumns.split(",").map(c => s"COUNT($c)").mkString(",")
+ val df = sql(s"SELECT $aggregations FROM shuffle")
+ checkSerializer(df.queryExecution.executedPlan, serializerClass)
+ checkAnswer(
+ df,
+ Row(1000, 1000, 0, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000))
+ }
+}
+
+/** For now, we will use SparkSqlSerializer for sort based shuffle with sort merge. */
+class SparkSqlSerializer2SortMergeShuffleSuite extends SparkSqlSerializer2Suite {
+
+ // We are expecting SparkSqlSerializer.
+ override val serializerClass: Class[Serializer] =
+ classOf[SparkSqlSerializer].asInstanceOf[Class[Serializer]]
+
+ override def beforeAll(): Unit = {
+ super.beforeAll()
+ // To trigger the sort merge.
+ sql("set spark.sql.shuffle.partitions = 201")
+ }
+}
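The two concrete suites above differ only in whether the number of shuffle partitions exceeds the sort shuffle's bypass-merge threshold (spark.shuffle.sort.bypassMergeThreshold, 200 by default), which is exactly the condition that decides between the two serializers in Exchange. A sketch of exercising both paths interactively, assuming a SQLContext bound to `sqlContext` and a registered temp table named `shuffle` like the one built in beforeAll:

    // At or below the bypass-merge threshold: SparkSqlSerializer2 is eligible.
    sqlContext.sql("SET spark.sql.shuffle.partitions=200")
    sqlContext.sql("SELECT DISTINCT col0 FROM shuffle").collect()

    // Above the threshold the merge path may be used, so Exchange falls back
    // to the generic SparkSqlSerializer.
    sqlContext.sql("SET spark.sql.shuffle.partitions=201")
    sqlContext.sql("SELECT DISTINCT col0 FROM shuffle").collect()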