author    Yin Huai <yhuai@databricks.com>    2015-05-07 20:59:42 -0700
committer Yin Huai <yhuai@databricks.com>    2015-05-07 20:59:42 -0700
commit    3af423c92f117b5dd4dc6832dc50911cedb29abc (patch)
tree      97f6bc9e5642fd2ee2b533c55a2aa74f8026abe1
parent    92f8f803a68e0c16771e9793098c6d76dfdf99af (diff)
[SPARK-6986] [SQL] Use Serializer2 in more cases.
With https://github.com/apache/spark/commit/0a2b15ce43cf6096e1a7ae060b7c8a4010ce3b92, the serialization stream and deserialization stream have enough information to determine whether they are handling a key-value pair, a key, or a value. It is therefore safe to use `SparkSqlSerializer2` in more cases.

Author: Yin Huai <yhuai@databricks.com>

Closes #5849 from yhuai/serializer2MoreCases and squashes the following commits:

53a5eaa [Yin Huai] Josh's comments.
487f540 [Yin Huai] Use BufferedOutputStream.
8385f95 [Yin Huai] Always create a new row at the deserialization side to work with sort merge join.
c7e2129 [Yin Huai] Update tests.
4513d13 [Yin Huai] Use Serializer2 in more places.
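
The core of the patch is the new hasKeyOrdering flag: when a key ordering is set on the ShuffledRDD, downstream operators such as sort merge join buffer the rows they read, so the deserialization stream must return a fresh row per record instead of reusing a single mutable row. The hazard can be shown in plain Scala; the sketch below uses plain arrays rather than Spark's row classes, purely as an illustration.

object RowReuseSketch {
  def main(args: Array[String]): Unit = {
    val records = Seq(Array[Any](3, "c"), Array[Any](1, "a"), Array[Any](2, "b"))

    // Unsafe pattern: hand out one shared mutable buffer for every record. Anything
    // that buffers rows (a sort, a join build side) ends up holding aliases of it.
    val shared = new Array[Any](2)
    val reused = records.map { rec => shared(0) = rec(0); shared(1) = rec(1); shared }
    println(reused.map(_.mkString(",")).distinct)   // List(2,b) -- earlier rows were overwritten

    // Safe pattern (what the patch does when hasKeyOrdering is true): a fresh row per record.
    val fresh = records.map(rec => Array[Any](rec(0), rec(1)))
    println(fresh.map(_.mkString(",")).distinct)    // List(3,c, 1,a, 2,b)
  }
}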
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala                  23
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala       74
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlSerializer2Suite.scala  30
3 files changed, 69 insertions, 58 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
index 5b2e46962c..f0d54cd6cd 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
@@ -84,18 +84,8 @@ case class Exchange(
def serializer(
keySchema: Array[DataType],
valueSchema: Array[DataType],
+ hasKeyOrdering: Boolean,
numPartitions: Int): Serializer = {
- // In ExternalSorter's spillToMergeableFile function, key-value pairs are written out
- // through write(key) and then write(value) instead of write((key, value)). Because
- // SparkSqlSerializer2 assumes that objects passed in are Product2, we cannot safely use
- // it when spillToMergeableFile in ExternalSorter will be used.
- // So, we will not use SparkSqlSerializer2 when
- // - Sort-based shuffle is enabled and the number of reducers (numPartitions) is greater
- // then the bypassMergeThreshold; or
- // - newOrdering is defined.
- val cannotUseSqlSerializer2 =
- (sortBasedShuffleOn && numPartitions > bypassMergeThreshold) || newOrdering.nonEmpty
-
// It is true when there is no field that needs to be write out.
// For now, we will not use SparkSqlSerializer2 when noField is true.
val noField =
@@ -104,14 +94,13 @@ case class Exchange(
val useSqlSerializer2 =
child.sqlContext.conf.useSqlSerializer2 && // SparkSqlSerializer2 is enabled.
- !cannotUseSqlSerializer2 && // Safe to use Serializer2.
SparkSqlSerializer2.support(keySchema) && // The schema of key is supported.
SparkSqlSerializer2.support(valueSchema) && // The schema of value is supported.
!noField
val serializer = if (useSqlSerializer2) {
logInfo("Using SparkSqlSerializer2.")
- new SparkSqlSerializer2(keySchema, valueSchema)
+ new SparkSqlSerializer2(keySchema, valueSchema, hasKeyOrdering)
} else {
logInfo("Using SparkSqlSerializer.")
new SparkSqlSerializer(sparkConf)
@@ -154,7 +143,8 @@ case class Exchange(
}
val keySchema = expressions.map(_.dataType).toArray
val valueSchema = child.output.map(_.dataType).toArray
- shuffled.setSerializer(serializer(keySchema, valueSchema, numPartitions))
+ shuffled.setSerializer(
+ serializer(keySchema, valueSchema, newOrdering.nonEmpty, numPartitions))
shuffled.map(_._2)
@@ -179,7 +169,8 @@ case class Exchange(
new ShuffledRDD[Row, Null, Null](rdd, part)
}
val keySchema = child.output.map(_.dataType).toArray
- shuffled.setSerializer(serializer(keySchema, null, numPartitions))
+ shuffled.setSerializer(
+ serializer(keySchema, null, newOrdering.nonEmpty, numPartitions))
shuffled.map(_._1)
@@ -199,7 +190,7 @@ case class Exchange(
val partitioner = new HashPartitioner(1)
val shuffled = new ShuffledRDD[Null, Row, Row](rdd, partitioner)
val valueSchema = child.output.map(_.dataType).toArray
- shuffled.setSerializer(serializer(null, valueSchema, 1))
+ shuffled.setSerializer(serializer(null, valueSchema, false, 1))
shuffled.map(_._2)
case _ => sys.error(s"Exchange not implemented for $newPartitioning")
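
The guard removed above existed because ExternalSorter's spillToMergeableFile writes keys and values separately (write(key), then write(value)) rather than as Product2 pairs, which SparkSqlSerializer2 could not handle. With the dedicated writeKey/writeValue (and readKey/readValue) hooks introduced by the referenced commit, the stream always knows whether it is being given a key, a value, or a whole pair, so both shuffle paths can share one serializer. A minimal, self-contained sketch of that idea follows; SimpleSerializationStream and PairAwareStream are hypothetical stand-ins, not Spark's classes.

import java.io.{ByteArrayOutputStream, DataOutputStream}

trait SimpleSerializationStream {
  // A pair is just a key write followed by a value write.
  def writeObject(kv: (String, Int)): this.type = { writeKey(kv._1); writeValue(kv._2); this }
  def writeKey(k: String): this.type
  def writeValue(v: Int): this.type
}

class PairAwareStream(out: DataOutputStream) extends SimpleSerializationStream {
  // Each call knows what it is handling, so no Product2 assumption is needed.
  def writeKey(k: String): this.type = { out.writeUTF(k); this }
  def writeValue(v: Int): this.type = { out.writeInt(v); this }
}

object StreamSketch {
  def main(args: Array[String]): Unit = {
    val bytes = new ByteArrayOutputStream()
    val stream = new PairAwareStream(new DataOutputStream(bytes))
    stream.writeObject(("a", 1))         // hash-style path: the whole pair at once
    stream.writeKey("b").writeValue(2)   // spillToMergeableFile path: key, then value
    println(bytes.size())                // 14 -- both call patterns write key bytes, then value bytes
  }
}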
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala
index 35ad987eb1..256d527d7b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala
@@ -27,7 +27,7 @@ import scala.reflect.ClassTag
import org.apache.spark.serializer._
import org.apache.spark.Logging
import org.apache.spark.sql.Row
-import org.apache.spark.sql.catalyst.expressions.SpecificMutableRow
+import org.apache.spark.sql.catalyst.expressions.{SpecificMutableRow, MutableRow, GenericMutableRow}
import org.apache.spark.sql.types._
/**
@@ -49,9 +49,9 @@ private[sql] class Serializer2SerializationStream(
out: OutputStream)
extends SerializationStream with Logging {
- val rowOut = new DataOutputStream(out)
- val writeKeyFunc = SparkSqlSerializer2.createSerializationFunction(keySchema, rowOut)
- val writeValueFunc = SparkSqlSerializer2.createSerializationFunction(valueSchema, rowOut)
+ private val rowOut = new DataOutputStream(new BufferedOutputStream(out))
+ private val writeKeyFunc = SparkSqlSerializer2.createSerializationFunction(keySchema, rowOut)
+ private val writeValueFunc = SparkSqlSerializer2.createSerializationFunction(valueSchema, rowOut)
override def writeObject[T: ClassTag](t: T): SerializationStream = {
val kv = t.asInstanceOf[Product2[Row, Row]]
@@ -86,31 +86,44 @@ private[sql] class Serializer2SerializationStream(
private[sql] class Serializer2DeserializationStream(
keySchema: Array[DataType],
valueSchema: Array[DataType],
+ hasKeyOrdering: Boolean,
in: InputStream)
extends DeserializationStream with Logging {
- val rowIn = new DataInputStream(new BufferedInputStream(in))
+ private val rowIn = new DataInputStream(new BufferedInputStream(in))
+
+ private def rowGenerator(schema: Array[DataType]): () => (MutableRow) = {
+ if (schema == null) {
+ () => null
+ } else {
+ if (hasKeyOrdering) {
+ // We have key ordering specified in a ShuffledRDD, it is not safe to reuse a mutable row.
+ () => new GenericMutableRow(schema.length)
+ } else {
+ // It is safe to reuse the mutable row.
+ val mutableRow = new SpecificMutableRow(schema)
+ () => mutableRow
+ }
+ }
+ }
- val key = if (keySchema != null) new SpecificMutableRow(keySchema) else null
- val value = if (valueSchema != null) new SpecificMutableRow(valueSchema) else null
- val readKeyFunc = SparkSqlSerializer2.createDeserializationFunction(keySchema, rowIn, key)
- val readValueFunc = SparkSqlSerializer2.createDeserializationFunction(valueSchema, rowIn, value)
+ // Functions used to return rows for key and value.
+ private val getKey = rowGenerator(keySchema)
+ private val getValue = rowGenerator(valueSchema)
+ // Functions used to read a serialized row from the InputStream and deserialize it.
+ private val readKeyFunc = SparkSqlSerializer2.createDeserializationFunction(keySchema, rowIn)
+ private val readValueFunc = SparkSqlSerializer2.createDeserializationFunction(valueSchema, rowIn)
override def readObject[T: ClassTag](): T = {
- readKeyFunc()
- readValueFunc()
-
- (key, value).asInstanceOf[T]
+ (readKeyFunc(getKey()), readValueFunc(getValue())).asInstanceOf[T]
}
override def readKey[T: ClassTag](): T = {
- readKeyFunc()
- key.asInstanceOf[T]
+ readKeyFunc(getKey()).asInstanceOf[T]
}
override def readValue[T: ClassTag](): T = {
- readValueFunc()
- value.asInstanceOf[T]
+ readValueFunc(getValue()).asInstanceOf[T]
}
override def close(): Unit = {
@@ -118,9 +131,10 @@ private[sql] class Serializer2DeserializationStream(
}
}
-private[sql] class ShuffleSerializerInstance(
+private[sql] class SparkSqlSerializer2Instance(
keySchema: Array[DataType],
- valueSchema: Array[DataType])
+ valueSchema: Array[DataType],
+ hasKeyOrdering: Boolean)
extends SerializerInstance {
def serialize[T: ClassTag](t: T): ByteBuffer =
@@ -137,7 +151,7 @@ private[sql] class ShuffleSerializerInstance(
}
def deserializeStream(s: InputStream): DeserializationStream = {
- new Serializer2DeserializationStream(keySchema, valueSchema, s)
+ new Serializer2DeserializationStream(keySchema, valueSchema, hasKeyOrdering, s)
}
}
@@ -148,12 +162,16 @@ private[sql] class ShuffleSerializerInstance(
* The schema of keys is represented by `keySchema` and that of values is represented by
* `valueSchema`.
*/
-private[sql] class SparkSqlSerializer2(keySchema: Array[DataType], valueSchema: Array[DataType])
+private[sql] class SparkSqlSerializer2(
+ keySchema: Array[DataType],
+ valueSchema: Array[DataType],
+ hasKeyOrdering: Boolean)
extends Serializer
with Logging
with Serializable{
- def newInstance(): SerializerInstance = new ShuffleSerializerInstance(keySchema, valueSchema)
+ def newInstance(): SerializerInstance =
+ new SparkSqlSerializer2Instance(keySchema, valueSchema, hasKeyOrdering)
override def supportsRelocationOfSerializedObjects: Boolean = {
// SparkSqlSerializer2 is stateless and writes no stream headers
@@ -323,11 +341,11 @@ private[sql] object SparkSqlSerializer2 {
*/
def createDeserializationFunction(
schema: Array[DataType],
- in: DataInputStream,
- mutableRow: SpecificMutableRow): () => Unit = {
- () => {
- // If the schema is null, the returned function does nothing when it get called.
- if (schema != null) {
+ in: DataInputStream): (MutableRow) => Row = {
+ if (schema == null) {
+ (mutableRow: MutableRow) => null
+ } else {
+ (mutableRow: MutableRow) => {
var i = 0
while (i < schema.length) {
schema(i) match {
@@ -440,6 +458,8 @@ private[sql] object SparkSqlSerializer2 {
}
i += 1
}
+
+ mutableRow
}
}
}
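
The createDeserializationFunction change is the other half of the refactor: instead of a () => Unit closure that mutates a SpecificMutableRow captured at construction time, it now yields a MutableRow => Row function, so the caller decides per record whether to pass in a reused row or a freshly allocated one. A plain-Scala sketch of that shape, with a hypothetical ArrayRow standing in for Spark's MutableRow:

import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}

object ReaderSketch {
  final class ArrayRow(n: Int) { val values = new Array[Any](n) }   // stand-in for MutableRow

  // The reader no longer captures the row it fills; it receives one per call and returns it.
  def createReader(schema: Array[String], in: DataInputStream): ArrayRow => ArrayRow =
    if (schema == null) { _ => null }
    else { row =>
      var i = 0
      while (i < schema.length) {
        row.values(i) = schema(i) match {
          case "int"    => in.readInt()
          case "string" => in.readUTF()
        }
        i += 1
      }
      row
    }

  def main(args: Array[String]): Unit = {
    val buf = new ByteArrayOutputStream()
    val out = new DataOutputStream(buf)
    out.writeInt(1); out.writeUTF("a"); out.writeInt(2); out.writeUTF("b")

    val in = new DataInputStream(new ByteArrayInputStream(buf.toByteArray))
    val read = createReader(Array("int", "string"), in)
    // The caller picks the allocation policy: here, a fresh row per record.
    val rows = Seq.fill(2)(read(new ArrayRow(2)))
    println(rows.map(_.values.mkString(",")))   // List(1,a, 2,b)
  }
}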
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlSerializer2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlSerializer2Suite.scala
index 27f063d73a..15337c4045 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlSerializer2Suite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlSerializer2Suite.scala
@@ -148,6 +148,15 @@ abstract class SparkSqlSerializer2Suite extends QueryTest with BeforeAndAfterAll
table("shuffle").collect())
}
+ test("key schema is null") {
+ val aggregations = allColumns.split(",").map(c => s"COUNT($c)").mkString(",")
+ val df = sql(s"SELECT $aggregations FROM shuffle")
+ checkSerializer(df.queryExecution.executedPlan, serializerClass)
+ checkAnswer(
+ df,
+ Row(1000, 1000, 0, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000))
+ }
+
test("value schema is null") {
val df = sql(s"SELECT col0 FROM shuffle ORDER BY col0")
checkSerializer(df.queryExecution.executedPlan, serializerClass)
@@ -167,29 +176,20 @@ class SparkSqlSerializer2SortShuffleSuite extends SparkSqlSerializer2Suite {
override def beforeAll(): Unit = {
super.beforeAll()
// Sort merge will not be triggered.
- sql("set spark.sql.shuffle.partitions = 200")
- }
-
- test("key schema is null") {
- val aggregations = allColumns.split(",").map(c => s"COUNT($c)").mkString(",")
- val df = sql(s"SELECT $aggregations FROM shuffle")
- checkSerializer(df.queryExecution.executedPlan, serializerClass)
- checkAnswer(
- df,
- Row(1000, 1000, 0, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000))
+ val bypassMergeThreshold =
+ sparkContext.conf.getInt("spark.shuffle.sort.bypassMergeThreshold", 200)
+ sql(s"set spark.sql.shuffle.partitions=${bypassMergeThreshold-1}")
}
}
/** For now, we will use SparkSqlSerializer for sort based shuffle with sort merge. */
class SparkSqlSerializer2SortMergeShuffleSuite extends SparkSqlSerializer2Suite {
- // We are expecting SparkSqlSerializer.
- override val serializerClass: Class[Serializer] =
- classOf[SparkSqlSerializer].asInstanceOf[Class[Serializer]]
-
override def beforeAll(): Unit = {
super.beforeAll()
// To trigger the sort merge.
- sql("set spark.sql.shuffle.partitions = 201")
+ val bypassMergeThreshold =
+ sparkContext.conf.getInt("spark.shuffle.sort.bypassMergeThreshold", 200)
+ sql(s"set spark.sql.shuffle.partitions=${bypassMergeThreshold + 1}")
}
}
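
The test changes replace the hard-coded 200/201 partition counts with values derived from spark.shuffle.sort.bypassMergeThreshold, so the two suites keep exercising the intended shuffle paths even if the threshold's default changes. The relationship they rely on, sketched below with an assumed default of 200 (the actual suites read the value from SparkConf):

object ThresholdSketch {
  def main(args: Array[String]): Unit = {
    // Sort-based shuffle spills to mergeable files (writing keys and values separately)
    // only when numPartitions exceeds spark.shuffle.sort.bypassMergeThreshold.
    val bypassMergeThreshold = 200                         // assumed default
    val noSortMergePartitions = bypassMergeThreshold - 1   // SortShuffleSuite: stays below the threshold
    val sortMergePartitions = bypassMergeThreshold + 1     // SortMergeShuffleSuite: forces the merge path
    assert(noSortMergePartitions <= bypassMergeThreshold && sortMergePartitions > bypassMergeThreshold)
    println(s"set spark.sql.shuffle.partitions=$noSortMergePartitions / $sortMergePartitions")
  }
}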