diff options
author | Reynold Xin <rxin@databricks.com> | 2015-08-07 11:02:53 -0700 |
---|---|---|
committer | Reynold Xin <rxin@databricks.com> | 2015-08-07 11:02:53 -0700 |
commit | 76eaa701833a2ff23b50147d70ced41e85719572 (patch) | |
tree | fbfecb1372ceaec7ef7dc1fdb1d59e9740e4887e /sql | |
parent | ebfd91c542aaead343cb154277fcf9114382fee7 (diff) | |
download | spark-76eaa701833a2ff23b50147d70ced41e85719572.tar.gz spark-76eaa701833a2ff23b50147d70ced41e85719572.tar.bz2 spark-76eaa701833a2ff23b50147d70ced41e85719572.zip |
[SPARK-9674][SPARK-9667] Remove SparkSqlSerializer2
It is now subsumed by various Tungsten operators.
Author: Reynold Xin <rxin@databricks.com>
Closes #7981 from rxin/SPARK-9674 and squashes the following commits:
144f96e [Reynold Xin] Re-enable test
58b7332 [Reynold Xin] Disable failing list.
fb797e3 [Reynold Xin] Match all UDTs.
be9f243 [Reynold Xin] Updated if.
71fc99c [Reynold Xin] [SPARK-9674][SPARK-9667] Remove GeneratedAggregate & SparkSqlSerializer2.
Diffstat (limited to 'sql')
4 files changed, 24 insertions, 677 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala index ef35c133d9..45d3d8c863 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala @@ -416,10 +416,6 @@ private[spark] object SQLConf { val USE_SQL_AGGREGATE2 = booleanConf("spark.sql.useAggregate2", defaultValue = Some(true), doc = "<TODO>") - val USE_SQL_SERIALIZER2 = booleanConf( - "spark.sql.useSerializer2", - defaultValue = Some(true), isPublic = false) - object Deprecated { val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks" } @@ -488,8 +484,6 @@ private[sql] class SQLConf extends Serializable with CatalystConf { private[spark] def useSqlAggregate2: Boolean = getConf(USE_SQL_AGGREGATE2) - private[spark] def useSqlSerializer2: Boolean = getConf(USE_SQL_SERIALIZER2) - private[spark] def autoBroadcastJoinThreshold: Int = getConf(AUTO_BROADCASTJOIN_THRESHOLD) private[spark] def defaultSizeInBytes: Long = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala index 6ea5eeedf1..60087f2ca4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala @@ -29,6 +29,7 @@ import org.apache.spark.sql.catalyst.errors.attachTree import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.types.UserDefinedType import org.apache.spark.util.MutablePair import org.apache.spark.{HashPartitioner, Partitioner, RangePartitioner, SparkEnv} @@ -39,21 +40,34 @@ import org.apache.spark.{HashPartitioner, Partitioner, RangePartitioner, SparkEn @DeveloperApi case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends UnaryNode { - override def outputPartitioning: Partitioning = newPartitioning - - override def output: Seq[Attribute] = child.output - - override def outputsUnsafeRows: Boolean = child.outputsUnsafeRows + override def nodeName: String = if (tungstenMode) "TungstenExchange" else "Exchange" - override def canProcessSafeRows: Boolean = true - - override def canProcessUnsafeRows: Boolean = { + /** + * Returns true iff the children outputs aggregate UDTs that are not part of the SQL type. + * This only happens with the old aggregate implementation and should be removed in 1.6. + */ + private lazy val tungstenMode: Boolean = { + val unserializableUDT = child.schema.exists(_.dataType match { + case _: UserDefinedType[_] => true + case _ => false + }) // Do not use the Unsafe path if we are using a RangePartitioning, since this may lead to // an interpreted RowOrdering being applied to an UnsafeRow, which will lead to // ClassCastExceptions at runtime. This check can be removed after SPARK-9054 is fixed. - !newPartitioning.isInstanceOf[RangePartitioning] + !unserializableUDT && !newPartitioning.isInstanceOf[RangePartitioning] } + override def outputPartitioning: Partitioning = newPartitioning + + override def output: Seq[Attribute] = child.output + + // This setting is somewhat counterintuitive: + // If the schema works with UnsafeRow, then we tell the planner that we don't support safe row, + // so the planner inserts a converter to convert data into UnsafeRow if needed. + override def outputsUnsafeRows: Boolean = tungstenMode + override def canProcessSafeRows: Boolean = !tungstenMode + override def canProcessUnsafeRows: Boolean = tungstenMode + /** * Determines whether records must be defensively copied before being sent to the shuffle. * Several of Spark's shuffle components will buffer deserialized Java objects in memory. The @@ -124,23 +138,9 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una private val serializer: Serializer = { val rowDataTypes = child.output.map(_.dataType).toArray - // It is true when there is no field that needs to be write out. - // For now, we will not use SparkSqlSerializer2 when noField is true. - val noField = rowDataTypes == null || rowDataTypes.length == 0 - - val useSqlSerializer2 = - child.sqlContext.conf.useSqlSerializer2 && // SparkSqlSerializer2 is enabled. - SparkSqlSerializer2.support(rowDataTypes) && // The schema of row is supported. - !noField - - if (child.outputsUnsafeRows) { - logInfo("Using UnsafeRowSerializer.") + if (tungstenMode) { new UnsafeRowSerializer(child.output.size) - } else if (useSqlSerializer2) { - logInfo("Using SparkSqlSerializer2.") - new SparkSqlSerializer2(rowDataTypes) } else { - logInfo("Using SparkSqlSerializer.") new SparkSqlSerializer(sparkConf) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala deleted file mode 100644 index e811f1de3e..0000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala +++ /dev/null @@ -1,426 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution - -import java.io._ -import java.math.{BigDecimal, BigInteger} -import java.nio.ByteBuffer - -import scala.reflect.ClassTag - -import org.apache.spark.Logging -import org.apache.spark.serializer._ -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{MutableRow, SpecificMutableRow} -import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.types.UTF8String - -/** - * The serialization stream for [[SparkSqlSerializer2]]. It assumes that the object passed in - * its `writeObject` are [[Product2]]. The serialization functions for the key and value of the - * [[Product2]] are constructed based on their schemata. - * The benefit of this serialization stream is that compared with general-purpose serializers like - * Kryo and Java serializer, it can significantly reduce the size of serialized and has a lower - * allocation cost, which can benefit the shuffle operation. Right now, its main limitations are: - * 1. It does not support complex types, i.e. Map, Array, and Struct. - * 2. It assumes that the objects passed in are [[Product2]]. So, it cannot be used when - * [[org.apache.spark.util.collection.ExternalSorter]]'s merge sort operation is used because - * the objects passed in the serializer are not in the type of [[Product2]]. Also also see - * the comment of the `serializer` method in [[Exchange]] for more information on it. - */ -private[sql] class Serializer2SerializationStream( - rowSchema: Array[DataType], - out: OutputStream) - extends SerializationStream with Logging { - - private val rowOut = new DataOutputStream(new BufferedOutputStream(out)) - private val writeRowFunc = SparkSqlSerializer2.createSerializationFunction(rowSchema, rowOut) - - override def writeObject[T: ClassTag](t: T): SerializationStream = { - val kv = t.asInstanceOf[Product2[InternalRow, InternalRow]] - writeKey(kv._1) - writeValue(kv._2) - - this - } - - override def writeKey[T: ClassTag](t: T): SerializationStream = { - // No-op. - this - } - - override def writeValue[T: ClassTag](t: T): SerializationStream = { - writeRowFunc(t.asInstanceOf[InternalRow]) - this - } - - def flush(): Unit = { - rowOut.flush() - } - - def close(): Unit = { - rowOut.close() - } -} - -/** - * The corresponding deserialization stream for [[Serializer2SerializationStream]]. - */ -private[sql] class Serializer2DeserializationStream( - rowSchema: Array[DataType], - in: InputStream) - extends DeserializationStream with Logging { - - private val rowIn = new DataInputStream(new BufferedInputStream(in)) - - private def rowGenerator(schema: Array[DataType]): () => (MutableRow) = { - if (schema == null) { - () => null - } else { - // It is safe to reuse the mutable row. - val mutableRow = new SpecificMutableRow(schema) - () => mutableRow - } - } - - // Functions used to return rows for key and value. - private val getRow = rowGenerator(rowSchema) - // Functions used to read a serialized row from the InputStream and deserialize it. - private val readRowFunc = SparkSqlSerializer2.createDeserializationFunction(rowSchema, rowIn) - - override def readObject[T: ClassTag](): T = { - readValue() - } - - override def readKey[T: ClassTag](): T = { - null.asInstanceOf[T] // intentionally left blank. - } - - override def readValue[T: ClassTag](): T = { - readRowFunc(getRow()).asInstanceOf[T] - } - - override def close(): Unit = { - rowIn.close() - } -} - -private[sql] class SparkSqlSerializer2Instance( - rowSchema: Array[DataType]) - extends SerializerInstance { - - def serialize[T: ClassTag](t: T): ByteBuffer = - throw new UnsupportedOperationException("Not supported.") - - def deserialize[T: ClassTag](bytes: ByteBuffer): T = - throw new UnsupportedOperationException("Not supported.") - - def deserialize[T: ClassTag](bytes: ByteBuffer, loader: ClassLoader): T = - throw new UnsupportedOperationException("Not supported.") - - def serializeStream(s: OutputStream): SerializationStream = { - new Serializer2SerializationStream(rowSchema, s) - } - - def deserializeStream(s: InputStream): DeserializationStream = { - new Serializer2DeserializationStream(rowSchema, s) - } -} - -/** - * SparkSqlSerializer2 is a special serializer that creates serialization function and - * deserialization function based on the schema of data. It assumes that values passed in - * are Rows. - */ -private[sql] class SparkSqlSerializer2(rowSchema: Array[DataType]) - extends Serializer - with Logging - with Serializable{ - - def newInstance(): SerializerInstance = new SparkSqlSerializer2Instance(rowSchema) - - override def supportsRelocationOfSerializedObjects: Boolean = { - // SparkSqlSerializer2 is stateless and writes no stream headers - true - } -} - -private[sql] object SparkSqlSerializer2 { - - final val NULL = 0 - final val NOT_NULL = 1 - - /** - * Check if rows with the given schema can be serialized with ShuffleSerializer. - * Right now, we do not support a schema having complex types or UDTs, or all data types - * of fields are NullTypes. - */ - def support(schema: Array[DataType]): Boolean = { - if (schema == null) return true - - var allNullTypes = true - var i = 0 - while (i < schema.length) { - schema(i) match { - case NullType => // Do nothing - case udt: UserDefinedType[_] => - allNullTypes = false - return false - case array: ArrayType => - allNullTypes = false - return false - case map: MapType => - allNullTypes = false - return false - case struct: StructType => - allNullTypes = false - return false - case _ => - allNullTypes = false - } - i += 1 - } - - // If types of fields are all NullTypes, we return false. - // Otherwise, we return true. - return !allNullTypes - } - - /** - * The util function to create the serialization function based on the given schema. - */ - def createSerializationFunction(schema: Array[DataType], out: DataOutputStream) - : InternalRow => Unit = { - (row: InternalRow) => - // If the schema is null, the returned function does nothing when it get called. - if (schema != null) { - var i = 0 - while (i < schema.length) { - schema(i) match { - // When we write values to the underlying stream, we also first write the null byte - // first. Then, if the value is not null, we write the contents out. - - case NullType => // Write nothing. - - case BooleanType => - if (row.isNullAt(i)) { - out.writeByte(NULL) - } else { - out.writeByte(NOT_NULL) - out.writeBoolean(row.getBoolean(i)) - } - - case ByteType => - if (row.isNullAt(i)) { - out.writeByte(NULL) - } else { - out.writeByte(NOT_NULL) - out.writeByte(row.getByte(i)) - } - - case ShortType => - if (row.isNullAt(i)) { - out.writeByte(NULL) - } else { - out.writeByte(NOT_NULL) - out.writeShort(row.getShort(i)) - } - - case IntegerType | DateType => - if (row.isNullAt(i)) { - out.writeByte(NULL) - } else { - out.writeByte(NOT_NULL) - out.writeInt(row.getInt(i)) - } - - case LongType | TimestampType => - if (row.isNullAt(i)) { - out.writeByte(NULL) - } else { - out.writeByte(NOT_NULL) - out.writeLong(row.getLong(i)) - } - - case FloatType => - if (row.isNullAt(i)) { - out.writeByte(NULL) - } else { - out.writeByte(NOT_NULL) - out.writeFloat(row.getFloat(i)) - } - - case DoubleType => - if (row.isNullAt(i)) { - out.writeByte(NULL) - } else { - out.writeByte(NOT_NULL) - out.writeDouble(row.getDouble(i)) - } - - case StringType => - if (row.isNullAt(i)) { - out.writeByte(NULL) - } else { - out.writeByte(NOT_NULL) - val bytes = row.getUTF8String(i).getBytes - out.writeInt(bytes.length) - out.write(bytes) - } - - case BinaryType => - if (row.isNullAt(i)) { - out.writeByte(NULL) - } else { - out.writeByte(NOT_NULL) - val bytes = row.getBinary(i) - out.writeInt(bytes.length) - out.write(bytes) - } - - case decimal: DecimalType => - if (row.isNullAt(i)) { - out.writeByte(NULL) - } else { - out.writeByte(NOT_NULL) - val value = row.getDecimal(i, decimal.precision, decimal.scale) - val javaBigDecimal = value.toJavaBigDecimal - // First, write out the unscaled value. - val bytes: Array[Byte] = javaBigDecimal.unscaledValue().toByteArray - out.writeInt(bytes.length) - out.write(bytes) - // Then, write out the scale. - out.writeInt(javaBigDecimal.scale()) - } - } - i += 1 - } - } - } - - /** - * The util function to create the deserialization function based on the given schema. - */ - def createDeserializationFunction( - schema: Array[DataType], - in: DataInputStream): (MutableRow) => InternalRow = { - if (schema == null) { - (mutableRow: MutableRow) => null - } else { - (mutableRow: MutableRow) => { - var i = 0 - while (i < schema.length) { - schema(i) match { - // When we read values from the underlying stream, we also first read the null byte - // first. Then, if the value is not null, we update the field of the mutable row. - - case NullType => mutableRow.setNullAt(i) // Read nothing. - - case BooleanType => - if (in.readByte() == NULL) { - mutableRow.setNullAt(i) - } else { - mutableRow.setBoolean(i, in.readBoolean()) - } - - case ByteType => - if (in.readByte() == NULL) { - mutableRow.setNullAt(i) - } else { - mutableRow.setByte(i, in.readByte()) - } - - case ShortType => - if (in.readByte() == NULL) { - mutableRow.setNullAt(i) - } else { - mutableRow.setShort(i, in.readShort()) - } - - case IntegerType | DateType => - if (in.readByte() == NULL) { - mutableRow.setNullAt(i) - } else { - mutableRow.setInt(i, in.readInt()) - } - - case LongType | TimestampType => - if (in.readByte() == NULL) { - mutableRow.setNullAt(i) - } else { - mutableRow.setLong(i, in.readLong()) - } - - case FloatType => - if (in.readByte() == NULL) { - mutableRow.setNullAt(i) - } else { - mutableRow.setFloat(i, in.readFloat()) - } - - case DoubleType => - if (in.readByte() == NULL) { - mutableRow.setNullAt(i) - } else { - mutableRow.setDouble(i, in.readDouble()) - } - - case StringType => - if (in.readByte() == NULL) { - mutableRow.setNullAt(i) - } else { - val length = in.readInt() - val bytes = new Array[Byte](length) - in.readFully(bytes) - mutableRow.update(i, UTF8String.fromBytes(bytes)) - } - - case BinaryType => - if (in.readByte() == NULL) { - mutableRow.setNullAt(i) - } else { - val length = in.readInt() - val bytes = new Array[Byte](length) - in.readFully(bytes) - mutableRow.update(i, bytes) - } - - case decimal: DecimalType => - if (in.readByte() == NULL) { - mutableRow.setNullAt(i) - } else { - // First, read in the unscaled value. - val length = in.readInt() - val bytes = new Array[Byte](length) - in.readFully(bytes) - val unscaledVal = new BigInteger(bytes) - // Then, read the scale. - val scale = in.readInt() - // Finally, create the Decimal object and set it in the row. - mutableRow.update(i, - Decimal(new BigDecimal(unscaledVal, scale), decimal.precision, decimal.scale)) - } - } - i += 1 - } - - mutableRow - } - } - } -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlSerializer2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlSerializer2Suite.scala deleted file mode 100644 index 7978ed57a9..0000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlSerializer2Suite.scala +++ /dev/null @@ -1,221 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution - -import java.sql.{Timestamp, Date} - -import org.apache.spark.sql.test.TestSQLContext -import org.scalatest.BeforeAndAfterAll - -import org.apache.spark.rdd.ShuffledRDD -import org.apache.spark.serializer.Serializer -import org.apache.spark.{ShuffleDependency, SparkFunSuite} -import org.apache.spark.sql.types._ -import org.apache.spark.sql.Row -import org.apache.spark.sql.{MyDenseVectorUDT, QueryTest} - -class SparkSqlSerializer2DataTypeSuite extends SparkFunSuite { - // Make sure that we will not use serializer2 for unsupported data types. - def checkSupported(dataType: DataType, isSupported: Boolean): Unit = { - val testName = - s"${if (dataType == null) null else dataType.toString} is " + - s"${if (isSupported) "supported" else "unsupported"}" - - test(testName) { - assert(SparkSqlSerializer2.support(Array(dataType)) === isSupported) - } - } - - checkSupported(null, isSupported = true) - checkSupported(BooleanType, isSupported = true) - checkSupported(ByteType, isSupported = true) - checkSupported(ShortType, isSupported = true) - checkSupported(IntegerType, isSupported = true) - checkSupported(LongType, isSupported = true) - checkSupported(FloatType, isSupported = true) - checkSupported(DoubleType, isSupported = true) - checkSupported(DateType, isSupported = true) - checkSupported(TimestampType, isSupported = true) - checkSupported(StringType, isSupported = true) - checkSupported(BinaryType, isSupported = true) - checkSupported(DecimalType(10, 5), isSupported = true) - checkSupported(DecimalType.SYSTEM_DEFAULT, isSupported = true) - - // If NullType is the only data type in the schema, we do not support it. - checkSupported(NullType, isSupported = false) - // For now, ArrayType, MapType, and StructType are not supported. - checkSupported(ArrayType(DoubleType, true), isSupported = false) - checkSupported(ArrayType(StringType, false), isSupported = false) - checkSupported(MapType(IntegerType, StringType, true), isSupported = false) - checkSupported(MapType(IntegerType, ArrayType(DoubleType), false), isSupported = false) - checkSupported(StructType(StructField("a", IntegerType, true) :: Nil), isSupported = false) - // UDTs are not supported right now. - checkSupported(new MyDenseVectorUDT, isSupported = false) -} - -abstract class SparkSqlSerializer2Suite extends QueryTest with BeforeAndAfterAll { - var allColumns: String = _ - val serializerClass: Class[Serializer] = - classOf[SparkSqlSerializer2].asInstanceOf[Class[Serializer]] - var numShufflePartitions: Int = _ - var useSerializer2: Boolean = _ - - protected lazy val ctx = TestSQLContext - - override def beforeAll(): Unit = { - numShufflePartitions = ctx.conf.numShufflePartitions - useSerializer2 = ctx.conf.useSqlSerializer2 - - ctx.sql("set spark.sql.useSerializer2=true") - - val supportedTypes = - Seq(StringType, BinaryType, NullType, BooleanType, - ByteType, ShortType, IntegerType, LongType, - FloatType, DoubleType, DecimalType.SYSTEM_DEFAULT, DecimalType(6, 5), - DateType, TimestampType) - - val fields = supportedTypes.zipWithIndex.map { case (dataType, index) => - StructField(s"col$index", dataType, true) - } - allColumns = fields.map(_.name).mkString(",") - val schema = StructType(fields) - - // Create a RDD with all data types supported by SparkSqlSerializer2. - val rdd = - ctx.sparkContext.parallelize((1 to 1000), 10).map { i => - Row( - s"str${i}: test serializer2.", - s"binary${i}: test serializer2.".getBytes("UTF-8"), - null, - i % 2 == 0, - i.toByte, - i.toShort, - i, - Long.MaxValue - i.toLong, - (i + 0.25).toFloat, - (i + 0.75), - BigDecimal(Long.MaxValue.toString + ".12345"), - new java.math.BigDecimal(s"${i % 9 + 1}" + ".23456"), - new Date(i), - new Timestamp(i)) - } - - ctx.createDataFrame(rdd, schema).registerTempTable("shuffle") - - super.beforeAll() - } - - override def afterAll(): Unit = { - ctx.dropTempTable("shuffle") - ctx.sql(s"set spark.sql.shuffle.partitions=$numShufflePartitions") - ctx.sql(s"set spark.sql.useSerializer2=$useSerializer2") - super.afterAll() - } - - def checkSerializer[T <: Serializer]( - executedPlan: SparkPlan, - expectedSerializerClass: Class[T]): Unit = { - executedPlan.foreach { - case exchange: Exchange => - val shuffledRDD = exchange.execute() - val dependency = shuffledRDD.dependencies.head.asInstanceOf[ShuffleDependency[_, _, _]] - val serializerNotSetMessage = - s"Expected $expectedSerializerClass as the serializer of Exchange. " + - s"However, the serializer was not set." - val serializer = dependency.serializer.getOrElse(fail(serializerNotSetMessage)) - val isExpectedSerializer = - serializer.getClass == expectedSerializerClass || - serializer.getClass == classOf[UnsafeRowSerializer] - val wrongSerializerErrorMessage = - s"Expected ${expectedSerializerClass.getCanonicalName} or " + - s"${classOf[UnsafeRowSerializer].getCanonicalName}. But " + - s"${serializer.getClass.getCanonicalName} is used." - assert(isExpectedSerializer, wrongSerializerErrorMessage) - case _ => // Ignore other nodes. - } - } - - test("key schema and value schema are not nulls") { - val df = ctx.sql(s"SELECT DISTINCT ${allColumns} FROM shuffle") - checkSerializer(df.queryExecution.executedPlan, serializerClass) - checkAnswer( - df, - ctx.table("shuffle").collect()) - } - - test("key schema is null") { - val aggregations = allColumns.split(",").map(c => s"COUNT($c)").mkString(",") - val df = ctx.sql(s"SELECT $aggregations FROM shuffle") - checkSerializer(df.queryExecution.executedPlan, serializerClass) - checkAnswer( - df, - Row(1000, 1000, 0, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000)) - } - - test("value schema is null") { - val df = ctx.sql(s"SELECT col0 FROM shuffle ORDER BY col0") - checkSerializer(df.queryExecution.executedPlan, serializerClass) - assert(df.map(r => r.getString(0)).collect().toSeq === - ctx.table("shuffle").select("col0").map(r => r.getString(0)).collect().sorted.toSeq) - } - - test("no map output field") { - val df = ctx.sql(s"SELECT 1 + 1 FROM shuffle") - checkSerializer(df.queryExecution.executedPlan, classOf[SparkSqlSerializer]) - } - - test("types of fields are all NullTypes") { - // Test range partitioning code path. - val nulls = ctx.sql(s"SELECT null as a, null as b, null as c") - val df = nulls.unionAll(nulls).sort("a") - checkSerializer(df.queryExecution.executedPlan, classOf[SparkSqlSerializer]) - checkAnswer( - df, - Row(null, null, null) :: Row(null, null, null) :: Nil) - - // Test hash partitioning code path. - val oneRow = ctx.sql(s"SELECT DISTINCT null, null, null FROM shuffle") - checkSerializer(oneRow.queryExecution.executedPlan, classOf[SparkSqlSerializer]) - checkAnswer( - oneRow, - Row(null, null, null)) - } -} - -/** Tests SparkSqlSerializer2 with sort based shuffle without sort merge. */ -class SparkSqlSerializer2SortShuffleSuite extends SparkSqlSerializer2Suite { - override def beforeAll(): Unit = { - super.beforeAll() - // Sort merge will not be triggered. - val bypassMergeThreshold = - ctx.sparkContext.conf.getInt("spark.shuffle.sort.bypassMergeThreshold", 200) - ctx.sql(s"set spark.sql.shuffle.partitions=${bypassMergeThreshold-1}") - } -} - -/** For now, we will use SparkSqlSerializer for sort based shuffle with sort merge. */ -class SparkSqlSerializer2SortMergeShuffleSuite extends SparkSqlSerializer2Suite { - - override def beforeAll(): Unit = { - super.beforeAll() - // To trigger the sort merge. - val bypassMergeThreshold = - ctx.sparkContext.conf.getInt("spark.shuffle.sort.bypassMergeThreshold", 200) - ctx.sql(s"set spark.sql.shuffle.partitions=${bypassMergeThreshold + 1}") - } -} |