path: root/sql
author     Reynold Xin <rxin@databricks.com>    2015-08-07 11:02:53 -0700
committer  Reynold Xin <rxin@databricks.com>    2015-08-07 11:02:53 -0700
commit     76eaa701833a2ff23b50147d70ced41e85719572 (patch)
tree       fbfecb1372ceaec7ef7dc1fdb1d59e9740e4887e /sql
parent     ebfd91c542aaead343cb154277fcf9114382fee7 (diff)
[SPARK-9674][SPARK-9667] Remove SparkSqlSerializer2
It is now subsumed by various Tungsten operators.

Author: Reynold Xin <rxin@databricks.com>

Closes #7981 from rxin/SPARK-9674 and squashes the following commits:

144f96e [Reynold Xin] Re-enable test
58b7332 [Reynold Xin] Disable failing list.
fb797e3 [Reynold Xin] Match all UDTs.
be9f243 [Reynold Xin] Updated if.
71fc99c [Reynold Xin] [SPARK-9674][SPARK-9667] Remove GeneratedAggregate & SparkSqlSerializer2.
Diffstat (limited to 'sql')
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala | 6
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala | 48
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala | 426
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlSerializer2Suite.scala | 221
4 files changed, 24 insertions(+), 677 deletions(-)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
index ef35c133d9..45d3d8c863 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
@@ -416,10 +416,6 @@ private[spark] object SQLConf {
val USE_SQL_AGGREGATE2 = booleanConf("spark.sql.useAggregate2",
defaultValue = Some(true), doc = "<TODO>")
- val USE_SQL_SERIALIZER2 = booleanConf(
- "spark.sql.useSerializer2",
- defaultValue = Some(true), isPublic = false)
-
object Deprecated {
val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks"
}
@@ -488,8 +484,6 @@ private[sql] class SQLConf extends Serializable with CatalystConf {
private[spark] def useSqlAggregate2: Boolean = getConf(USE_SQL_AGGREGATE2)
- private[spark] def useSqlSerializer2: Boolean = getConf(USE_SQL_SERIALIZER2)
-
private[spark] def autoBroadcastJoinThreshold: Int = getConf(AUTO_BROADCASTJOIN_THRESHOLD)
private[spark] def defaultSizeInBytes: Long =
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
index 6ea5eeedf1..60087f2ca4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
@@ -29,6 +29,7 @@ import org.apache.spark.sql.catalyst.errors.attachTree
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.types.UserDefinedType
import org.apache.spark.util.MutablePair
import org.apache.spark.{HashPartitioner, Partitioner, RangePartitioner, SparkEnv}
@@ -39,21 +40,34 @@ import org.apache.spark.{HashPartitioner, Partitioner, RangePartitioner, SparkEnv}
@DeveloperApi
case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends UnaryNode {
- override def outputPartitioning: Partitioning = newPartitioning
-
- override def output: Seq[Attribute] = child.output
-
- override def outputsUnsafeRows: Boolean = child.outputsUnsafeRows
+ override def nodeName: String = if (tungstenMode) "TungstenExchange" else "Exchange"
- override def canProcessSafeRows: Boolean = true
-
- override def canProcessUnsafeRows: Boolean = {
+ /**
+ * Returns true iff the children outputs aggregate UDTs that are not part of the SQL type.
+ * This only happens with the old aggregate implementation and should be removed in 1.6.
+ */
+ private lazy val tungstenMode: Boolean = {
+ val unserializableUDT = child.schema.exists(_.dataType match {
+ case _: UserDefinedType[_] => true
+ case _ => false
+ })
// Do not use the Unsafe path if we are using a RangePartitioning, since this may lead to
// an interpreted RowOrdering being applied to an UnsafeRow, which will lead to
// ClassCastExceptions at runtime. This check can be removed after SPARK-9054 is fixed.
- !newPartitioning.isInstanceOf[RangePartitioning]
+ !unserializableUDT && !newPartitioning.isInstanceOf[RangePartitioning]
}
+ override def outputPartitioning: Partitioning = newPartitioning
+
+ override def output: Seq[Attribute] = child.output
+
+ // This setting is somewhat counterintuitive:
+ // If the schema works with UnsafeRow, then we tell the planner that we don't support safe row,
+ // so the planner inserts a converter to convert data into UnsafeRow if needed.
+ override def outputsUnsafeRows: Boolean = tungstenMode
+ override def canProcessSafeRows: Boolean = !tungstenMode
+ override def canProcessUnsafeRows: Boolean = tungstenMode
+
/**
* Determines whether records must be defensively copied before being sent to the shuffle.
* Several of Spark's shuffle components will buffer deserialized Java objects in memory. The
@@ -124,23 +138,9 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una
private val serializer: Serializer = {
val rowDataTypes = child.output.map(_.dataType).toArray
- // It is true when there is no field that needs to be written out.
- // For now, we will not use SparkSqlSerializer2 when noField is true.
- val noField = rowDataTypes == null || rowDataTypes.length == 0
-
- val useSqlSerializer2 =
- child.sqlContext.conf.useSqlSerializer2 && // SparkSqlSerializer2 is enabled.
- SparkSqlSerializer2.support(rowDataTypes) && // The schema of row is supported.
- !noField
-
- if (child.outputsUnsafeRows) {
- logInfo("Using UnsafeRowSerializer.")
+ if (tungstenMode) {
new UnsafeRowSerializer(child.output.size)
- } else if (useSqlSerializer2) {
- logInfo("Using SparkSqlSerializer2.")
- new SparkSqlSerializer2(rowDataTypes)
} else {
- logInfo("Using SparkSqlSerializer.")
new SparkSqlSerializer(sparkConf)
}
}
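
For readers following the new Exchange logic above: the schema inspection that feeds tungstenMode can be exercised on its own. A minimal, hypothetical sketch (the helper name schemaHasUDT is illustrative and not part of the patch; it relies only on the public Spark SQL type API of this era):

import org.apache.spark.sql.types._

// Illustrative only: true if any top-level field is backed by a UserDefinedType,
// mirroring the unserializableUDT check inside tungstenMode above.
def schemaHasUDT(schema: StructType): Boolean =
  schema.exists(_.dataType match {
    case _: UserDefinedType[_] => true
    case _ => false
  })

// A schema with only primitive fields keeps Exchange on the Tungsten path.
val primitiveSchema = StructType(StructField("a", IntegerType) :: StructField("b", StringType) :: Nil)
assert(!schemaHasUDT(primitiveSchema))
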
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala
deleted file mode 100644
index e811f1de3e..0000000000
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala
+++ /dev/null
@@ -1,426 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.execution
-
-import java.io._
-import java.math.{BigDecimal, BigInteger}
-import java.nio.ByteBuffer
-
-import scala.reflect.ClassTag
-
-import org.apache.spark.Logging
-import org.apache.spark.serializer._
-import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{MutableRow, SpecificMutableRow}
-import org.apache.spark.sql.types._
-import org.apache.spark.unsafe.types.UTF8String
-
-/**
- * The serialization stream for [[SparkSqlSerializer2]]. It assumes that the objects passed to
- * its `writeObject` are [[Product2]]. The serialization functions for the key and value of the
- * [[Product2]] are constructed based on their schemata.
- * The benefit of this serialization stream is that, compared with general-purpose serializers like
- * Kryo and the Java serializer, it can significantly reduce the size of the serialized data and has
- * a lower allocation cost, which benefits the shuffle operation. Right now, its main limitations are:
- * 1. It does not support complex types, i.e. Map, Array, and Struct.
- * 2. It assumes that the objects passed in are [[Product2]]. So, it cannot be used when
- * [[org.apache.spark.util.collection.ExternalSorter]]'s merge sort operation is used, because
- * the objects passed to the serializer are not of type [[Product2]]. Also see
- * the comment of the `serializer` method in [[Exchange]] for more information.
- */
-private[sql] class Serializer2SerializationStream(
- rowSchema: Array[DataType],
- out: OutputStream)
- extends SerializationStream with Logging {
-
- private val rowOut = new DataOutputStream(new BufferedOutputStream(out))
- private val writeRowFunc = SparkSqlSerializer2.createSerializationFunction(rowSchema, rowOut)
-
- override def writeObject[T: ClassTag](t: T): SerializationStream = {
- val kv = t.asInstanceOf[Product2[InternalRow, InternalRow]]
- writeKey(kv._1)
- writeValue(kv._2)
-
- this
- }
-
- override def writeKey[T: ClassTag](t: T): SerializationStream = {
- // No-op.
- this
- }
-
- override def writeValue[T: ClassTag](t: T): SerializationStream = {
- writeRowFunc(t.asInstanceOf[InternalRow])
- this
- }
-
- def flush(): Unit = {
- rowOut.flush()
- }
-
- def close(): Unit = {
- rowOut.close()
- }
-}
-
-/**
- * The corresponding deserialization stream for [[Serializer2SerializationStream]].
- */
-private[sql] class Serializer2DeserializationStream(
- rowSchema: Array[DataType],
- in: InputStream)
- extends DeserializationStream with Logging {
-
- private val rowIn = new DataInputStream(new BufferedInputStream(in))
-
- private def rowGenerator(schema: Array[DataType]): () => (MutableRow) = {
- if (schema == null) {
- () => null
- } else {
- // It is safe to reuse the mutable row.
- val mutableRow = new SpecificMutableRow(schema)
- () => mutableRow
- }
- }
-
- // Functions used to return rows for key and value.
- private val getRow = rowGenerator(rowSchema)
- // Functions used to read a serialized row from the InputStream and deserialize it.
- private val readRowFunc = SparkSqlSerializer2.createDeserializationFunction(rowSchema, rowIn)
-
- override def readObject[T: ClassTag](): T = {
- readValue()
- }
-
- override def readKey[T: ClassTag](): T = {
- null.asInstanceOf[T] // intentionally left blank.
- }
-
- override def readValue[T: ClassTag](): T = {
- readRowFunc(getRow()).asInstanceOf[T]
- }
-
- override def close(): Unit = {
- rowIn.close()
- }
-}
-
-private[sql] class SparkSqlSerializer2Instance(
- rowSchema: Array[DataType])
- extends SerializerInstance {
-
- def serialize[T: ClassTag](t: T): ByteBuffer =
- throw new UnsupportedOperationException("Not supported.")
-
- def deserialize[T: ClassTag](bytes: ByteBuffer): T =
- throw new UnsupportedOperationException("Not supported.")
-
- def deserialize[T: ClassTag](bytes: ByteBuffer, loader: ClassLoader): T =
- throw new UnsupportedOperationException("Not supported.")
-
- def serializeStream(s: OutputStream): SerializationStream = {
- new Serializer2SerializationStream(rowSchema, s)
- }
-
- def deserializeStream(s: InputStream): DeserializationStream = {
- new Serializer2DeserializationStream(rowSchema, s)
- }
-}
-
-/**
- * SparkSqlSerializer2 is a special serializer that creates serialization function and
- * deserialization function based on the schema of data. It assumes that values passed in
- * are Rows.
- */
-private[sql] class SparkSqlSerializer2(rowSchema: Array[DataType])
- extends Serializer
- with Logging
- with Serializable{
-
- def newInstance(): SerializerInstance = new SparkSqlSerializer2Instance(rowSchema)
-
- override def supportsRelocationOfSerializedObjects: Boolean = {
- // SparkSqlSerializer2 is stateless and writes no stream headers
- true
- }
-}
-
-private[sql] object SparkSqlSerializer2 {
-
- final val NULL = 0
- final val NOT_NULL = 1
-
- /**
- * Check if rows with the given schema can be serialized with ShuffleSerializer.
- * Right now, we do not support a schema that has complex types or UDTs, or one in which
- * all fields are of NullType.
- */
- def support(schema: Array[DataType]): Boolean = {
- if (schema == null) return true
-
- var allNullTypes = true
- var i = 0
- while (i < schema.length) {
- schema(i) match {
- case NullType => // Do nothing
- case udt: UserDefinedType[_] =>
- allNullTypes = false
- return false
- case array: ArrayType =>
- allNullTypes = false
- return false
- case map: MapType =>
- allNullTypes = false
- return false
- case struct: StructType =>
- allNullTypes = false
- return false
- case _ =>
- allNullTypes = false
- }
- i += 1
- }
-
- // If types of fields are all NullTypes, we return false.
- // Otherwise, we return true.
- return !allNullTypes
- }
-
- /**
- * The util function to create the serialization function based on the given schema.
- */
- def createSerializationFunction(schema: Array[DataType], out: DataOutputStream)
- : InternalRow => Unit = {
- (row: InternalRow) =>
- // If the schema is null, the returned function does nothing when it gets called.
- if (schema != null) {
- var i = 0
- while (i < schema.length) {
- schema(i) match {
- // When we write a value to the underlying stream, we first write its null byte.
- // Then, if the value is not null, we write the contents out.
-
- case NullType => // Write nothing.
-
- case BooleanType =>
- if (row.isNullAt(i)) {
- out.writeByte(NULL)
- } else {
- out.writeByte(NOT_NULL)
- out.writeBoolean(row.getBoolean(i))
- }
-
- case ByteType =>
- if (row.isNullAt(i)) {
- out.writeByte(NULL)
- } else {
- out.writeByte(NOT_NULL)
- out.writeByte(row.getByte(i))
- }
-
- case ShortType =>
- if (row.isNullAt(i)) {
- out.writeByte(NULL)
- } else {
- out.writeByte(NOT_NULL)
- out.writeShort(row.getShort(i))
- }
-
- case IntegerType | DateType =>
- if (row.isNullAt(i)) {
- out.writeByte(NULL)
- } else {
- out.writeByte(NOT_NULL)
- out.writeInt(row.getInt(i))
- }
-
- case LongType | TimestampType =>
- if (row.isNullAt(i)) {
- out.writeByte(NULL)
- } else {
- out.writeByte(NOT_NULL)
- out.writeLong(row.getLong(i))
- }
-
- case FloatType =>
- if (row.isNullAt(i)) {
- out.writeByte(NULL)
- } else {
- out.writeByte(NOT_NULL)
- out.writeFloat(row.getFloat(i))
- }
-
- case DoubleType =>
- if (row.isNullAt(i)) {
- out.writeByte(NULL)
- } else {
- out.writeByte(NOT_NULL)
- out.writeDouble(row.getDouble(i))
- }
-
- case StringType =>
- if (row.isNullAt(i)) {
- out.writeByte(NULL)
- } else {
- out.writeByte(NOT_NULL)
- val bytes = row.getUTF8String(i).getBytes
- out.writeInt(bytes.length)
- out.write(bytes)
- }
-
- case BinaryType =>
- if (row.isNullAt(i)) {
- out.writeByte(NULL)
- } else {
- out.writeByte(NOT_NULL)
- val bytes = row.getBinary(i)
- out.writeInt(bytes.length)
- out.write(bytes)
- }
-
- case decimal: DecimalType =>
- if (row.isNullAt(i)) {
- out.writeByte(NULL)
- } else {
- out.writeByte(NOT_NULL)
- val value = row.getDecimal(i, decimal.precision, decimal.scale)
- val javaBigDecimal = value.toJavaBigDecimal
- // First, write out the unscaled value.
- val bytes: Array[Byte] = javaBigDecimal.unscaledValue().toByteArray
- out.writeInt(bytes.length)
- out.write(bytes)
- // Then, write out the scale.
- out.writeInt(javaBigDecimal.scale())
- }
- }
- i += 1
- }
- }
- }
-
- /**
- * The util function to create the deserialization function based on the given schema.
- */
- def createDeserializationFunction(
- schema: Array[DataType],
- in: DataInputStream): (MutableRow) => InternalRow = {
- if (schema == null) {
- (mutableRow: MutableRow) => null
- } else {
- (mutableRow: MutableRow) => {
- var i = 0
- while (i < schema.length) {
- schema(i) match {
- // When we read a value from the underlying stream, we first read its null byte.
- // Then, if the value is not null, we update the field of the mutable row.
-
- case NullType => mutableRow.setNullAt(i) // Read nothing.
-
- case BooleanType =>
- if (in.readByte() == NULL) {
- mutableRow.setNullAt(i)
- } else {
- mutableRow.setBoolean(i, in.readBoolean())
- }
-
- case ByteType =>
- if (in.readByte() == NULL) {
- mutableRow.setNullAt(i)
- } else {
- mutableRow.setByte(i, in.readByte())
- }
-
- case ShortType =>
- if (in.readByte() == NULL) {
- mutableRow.setNullAt(i)
- } else {
- mutableRow.setShort(i, in.readShort())
- }
-
- case IntegerType | DateType =>
- if (in.readByte() == NULL) {
- mutableRow.setNullAt(i)
- } else {
- mutableRow.setInt(i, in.readInt())
- }
-
- case LongType | TimestampType =>
- if (in.readByte() == NULL) {
- mutableRow.setNullAt(i)
- } else {
- mutableRow.setLong(i, in.readLong())
- }
-
- case FloatType =>
- if (in.readByte() == NULL) {
- mutableRow.setNullAt(i)
- } else {
- mutableRow.setFloat(i, in.readFloat())
- }
-
- case DoubleType =>
- if (in.readByte() == NULL) {
- mutableRow.setNullAt(i)
- } else {
- mutableRow.setDouble(i, in.readDouble())
- }
-
- case StringType =>
- if (in.readByte() == NULL) {
- mutableRow.setNullAt(i)
- } else {
- val length = in.readInt()
- val bytes = new Array[Byte](length)
- in.readFully(bytes)
- mutableRow.update(i, UTF8String.fromBytes(bytes))
- }
-
- case BinaryType =>
- if (in.readByte() == NULL) {
- mutableRow.setNullAt(i)
- } else {
- val length = in.readInt()
- val bytes = new Array[Byte](length)
- in.readFully(bytes)
- mutableRow.update(i, bytes)
- }
-
- case decimal: DecimalType =>
- if (in.readByte() == NULL) {
- mutableRow.setNullAt(i)
- } else {
- // First, read in the unscaled value.
- val length = in.readInt()
- val bytes = new Array[Byte](length)
- in.readFully(bytes)
- val unscaledVal = new BigInteger(bytes)
- // Then, read the scale.
- val scale = in.readInt()
- // Finally, create the Decimal object and set it in the row.
- mutableRow.update(i,
- Decimal(new BigDecimal(unscaledVal, scale), decimal.precision, decimal.scale))
- }
- }
- i += 1
- }
-
- mutableRow
- }
- }
- }
-}
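
The deleted serializer's wire format, visible in the functions above, prefixes every nullable field with a status byte (NULL = 0, NOT_NULL = 1) followed by the raw value. A minimal, self-contained sketch of that framing for a single nullable Int, independent of Spark and for illustration only:

import java.io._

val NULL: Byte = 0
val NOT_NULL: Byte = 1

// Write the status byte, then the value only when it is non-null.
def writeNullableInt(out: DataOutputStream, v: java.lang.Integer): Unit =
  if (v == null) out.writeByte(NULL) else { out.writeByte(NOT_NULL); out.writeInt(v) }

// Read the status byte first; a NULL marker means no value bytes follow.
def readNullableInt(in: DataInputStream): java.lang.Integer =
  if (in.readByte() == NULL) null else in.readInt()

// Round trip through an in-memory buffer.
val buf = new ByteArrayOutputStream()
val out = new DataOutputStream(buf)
writeNullableInt(out, 42)
writeNullableInt(out, null)
out.flush()
val in = new DataInputStream(new ByteArrayInputStream(buf.toByteArray))
assert(readNullableInt(in) == 42)
assert(readNullableInt(in) == null)
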
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlSerializer2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlSerializer2Suite.scala
deleted file mode 100644
index 7978ed57a9..0000000000
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlSerializer2Suite.scala
+++ /dev/null
@@ -1,221 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.execution
-
-import java.sql.{Timestamp, Date}
-
-import org.apache.spark.sql.test.TestSQLContext
-import org.scalatest.BeforeAndAfterAll
-
-import org.apache.spark.rdd.ShuffledRDD
-import org.apache.spark.serializer.Serializer
-import org.apache.spark.{ShuffleDependency, SparkFunSuite}
-import org.apache.spark.sql.types._
-import org.apache.spark.sql.Row
-import org.apache.spark.sql.{MyDenseVectorUDT, QueryTest}
-
-class SparkSqlSerializer2DataTypeSuite extends SparkFunSuite {
- // Make sure that we will not use serializer2 for unsupported data types.
- def checkSupported(dataType: DataType, isSupported: Boolean): Unit = {
- val testName =
- s"${if (dataType == null) null else dataType.toString} is " +
- s"${if (isSupported) "supported" else "unsupported"}"
-
- test(testName) {
- assert(SparkSqlSerializer2.support(Array(dataType)) === isSupported)
- }
- }
-
- checkSupported(null, isSupported = true)
- checkSupported(BooleanType, isSupported = true)
- checkSupported(ByteType, isSupported = true)
- checkSupported(ShortType, isSupported = true)
- checkSupported(IntegerType, isSupported = true)
- checkSupported(LongType, isSupported = true)
- checkSupported(FloatType, isSupported = true)
- checkSupported(DoubleType, isSupported = true)
- checkSupported(DateType, isSupported = true)
- checkSupported(TimestampType, isSupported = true)
- checkSupported(StringType, isSupported = true)
- checkSupported(BinaryType, isSupported = true)
- checkSupported(DecimalType(10, 5), isSupported = true)
- checkSupported(DecimalType.SYSTEM_DEFAULT, isSupported = true)
-
- // If NullType is the only data type in the schema, we do not support it.
- checkSupported(NullType, isSupported = false)
- // For now, ArrayType, MapType, and StructType are not supported.
- checkSupported(ArrayType(DoubleType, true), isSupported = false)
- checkSupported(ArrayType(StringType, false), isSupported = false)
- checkSupported(MapType(IntegerType, StringType, true), isSupported = false)
- checkSupported(MapType(IntegerType, ArrayType(DoubleType), false), isSupported = false)
- checkSupported(StructType(StructField("a", IntegerType, true) :: Nil), isSupported = false)
- // UDTs are not supported right now.
- checkSupported(new MyDenseVectorUDT, isSupported = false)
-}
-
-abstract class SparkSqlSerializer2Suite extends QueryTest with BeforeAndAfterAll {
- var allColumns: String = _
- val serializerClass: Class[Serializer] =
- classOf[SparkSqlSerializer2].asInstanceOf[Class[Serializer]]
- var numShufflePartitions: Int = _
- var useSerializer2: Boolean = _
-
- protected lazy val ctx = TestSQLContext
-
- override def beforeAll(): Unit = {
- numShufflePartitions = ctx.conf.numShufflePartitions
- useSerializer2 = ctx.conf.useSqlSerializer2
-
- ctx.sql("set spark.sql.useSerializer2=true")
-
- val supportedTypes =
- Seq(StringType, BinaryType, NullType, BooleanType,
- ByteType, ShortType, IntegerType, LongType,
- FloatType, DoubleType, DecimalType.SYSTEM_DEFAULT, DecimalType(6, 5),
- DateType, TimestampType)
-
- val fields = supportedTypes.zipWithIndex.map { case (dataType, index) =>
- StructField(s"col$index", dataType, true)
- }
- allColumns = fields.map(_.name).mkString(",")
- val schema = StructType(fields)
-
- // Create a RDD with all data types supported by SparkSqlSerializer2.
- val rdd =
- ctx.sparkContext.parallelize((1 to 1000), 10).map { i =>
- Row(
- s"str${i}: test serializer2.",
- s"binary${i}: test serializer2.".getBytes("UTF-8"),
- null,
- i % 2 == 0,
- i.toByte,
- i.toShort,
- i,
- Long.MaxValue - i.toLong,
- (i + 0.25).toFloat,
- (i + 0.75),
- BigDecimal(Long.MaxValue.toString + ".12345"),
- new java.math.BigDecimal(s"${i % 9 + 1}" + ".23456"),
- new Date(i),
- new Timestamp(i))
- }
-
- ctx.createDataFrame(rdd, schema).registerTempTable("shuffle")
-
- super.beforeAll()
- }
-
- override def afterAll(): Unit = {
- ctx.dropTempTable("shuffle")
- ctx.sql(s"set spark.sql.shuffle.partitions=$numShufflePartitions")
- ctx.sql(s"set spark.sql.useSerializer2=$useSerializer2")
- super.afterAll()
- }
-
- def checkSerializer[T <: Serializer](
- executedPlan: SparkPlan,
- expectedSerializerClass: Class[T]): Unit = {
- executedPlan.foreach {
- case exchange: Exchange =>
- val shuffledRDD = exchange.execute()
- val dependency = shuffledRDD.dependencies.head.asInstanceOf[ShuffleDependency[_, _, _]]
- val serializerNotSetMessage =
- s"Expected $expectedSerializerClass as the serializer of Exchange. " +
- s"However, the serializer was not set."
- val serializer = dependency.serializer.getOrElse(fail(serializerNotSetMessage))
- val isExpectedSerializer =
- serializer.getClass == expectedSerializerClass ||
- serializer.getClass == classOf[UnsafeRowSerializer]
- val wrongSerializerErrorMessage =
- s"Expected ${expectedSerializerClass.getCanonicalName} or " +
- s"${classOf[UnsafeRowSerializer].getCanonicalName}. But " +
- s"${serializer.getClass.getCanonicalName} is used."
- assert(isExpectedSerializer, wrongSerializerErrorMessage)
- case _ => // Ignore other nodes.
- }
- }
-
- test("key schema and value schema are not nulls") {
- val df = ctx.sql(s"SELECT DISTINCT ${allColumns} FROM shuffle")
- checkSerializer(df.queryExecution.executedPlan, serializerClass)
- checkAnswer(
- df,
- ctx.table("shuffle").collect())
- }
-
- test("key schema is null") {
- val aggregations = allColumns.split(",").map(c => s"COUNT($c)").mkString(",")
- val df = ctx.sql(s"SELECT $aggregations FROM shuffle")
- checkSerializer(df.queryExecution.executedPlan, serializerClass)
- checkAnswer(
- df,
- Row(1000, 1000, 0, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000))
- }
-
- test("value schema is null") {
- val df = ctx.sql(s"SELECT col0 FROM shuffle ORDER BY col0")
- checkSerializer(df.queryExecution.executedPlan, serializerClass)
- assert(df.map(r => r.getString(0)).collect().toSeq ===
- ctx.table("shuffle").select("col0").map(r => r.getString(0)).collect().sorted.toSeq)
- }
-
- test("no map output field") {
- val df = ctx.sql(s"SELECT 1 + 1 FROM shuffle")
- checkSerializer(df.queryExecution.executedPlan, classOf[SparkSqlSerializer])
- }
-
- test("types of fields are all NullTypes") {
- // Test range partitioning code path.
- val nulls = ctx.sql(s"SELECT null as a, null as b, null as c")
- val df = nulls.unionAll(nulls).sort("a")
- checkSerializer(df.queryExecution.executedPlan, classOf[SparkSqlSerializer])
- checkAnswer(
- df,
- Row(null, null, null) :: Row(null, null, null) :: Nil)
-
- // Test hash partitioning code path.
- val oneRow = ctx.sql(s"SELECT DISTINCT null, null, null FROM shuffle")
- checkSerializer(oneRow.queryExecution.executedPlan, classOf[SparkSqlSerializer])
- checkAnswer(
- oneRow,
- Row(null, null, null))
- }
-}
-
-/** Tests SparkSqlSerializer2 with sort based shuffle without sort merge. */
-class SparkSqlSerializer2SortShuffleSuite extends SparkSqlSerializer2Suite {
- override def beforeAll(): Unit = {
- super.beforeAll()
- // Sort merge will not be triggered.
- val bypassMergeThreshold =
- ctx.sparkContext.conf.getInt("spark.shuffle.sort.bypassMergeThreshold", 200)
- ctx.sql(s"set spark.sql.shuffle.partitions=${bypassMergeThreshold-1}")
- }
-}
-
-/** For now, we will use SparkSqlSerializer for sort based shuffle with sort merge. */
-class SparkSqlSerializer2SortMergeShuffleSuite extends SparkSqlSerializer2Suite {
-
- override def beforeAll(): Unit = {
- super.beforeAll()
- // To trigger the sort merge.
- val bypassMergeThreshold =
- ctx.sparkContext.conf.getInt("spark.shuffle.sort.bypassMergeThreshold", 200)
- ctx.sql(s"set spark.sql.shuffle.partitions=${bypassMergeThreshold + 1}")
- }
-}
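
The last two suites above chose between the bypass-merge and sort-merge shuffle paths purely through configuration. A minimal sketch of that pattern, assuming a live SQLContext named ctx as in the suites:

// Illustrative only: size the number of shuffle partitions around the
// bypass-merge threshold to pick the shuffle path, as the two suites did.
val bypassMergeThreshold =
  ctx.sparkContext.conf.getInt("spark.shuffle.sort.bypassMergeThreshold", 200)

// Fewer partitions than the threshold: sort merge is not triggered.
ctx.sql(s"set spark.sql.shuffle.partitions=${bypassMergeThreshold - 1}")

// More partitions than the threshold: the sort-merge path is used.
ctx.sql(s"set spark.sql.shuffle.partitions=${bypassMergeThreshold + 1}")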