author     Josh Rosen <joshrosen@databricks.com>      2015-09-23 11:31:01 -0700
committer  Michael Armbrust <michael@databricks.com>  2015-09-23 11:31:01 -0700
commit     a18208047f06a4244703c17023bb20cbe1f59d73 (patch)
tree       2321398bcd999ab2bab28f5f2bd5dad3c706bee0
parent     27bfa9ab3a610e072c011fd88ee4684cea6ceb76 (diff)
[SPARK-10403] Allow UnsafeRowSerializer to work with tungsten-sort ShuffleManager
This patch attempts to fix an issue where Spark SQL's UnsafeRowSerializer was incompatible with the `tungsten-sort` ShuffleManager.
Author: Josh Rosen <joshrosen@databricks.com>
Closes #8873 from JoshRosen/SPARK-10403.
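The core of the fix is visible in the hunks below: the serializer no longer writes an in-band end-of-stream sentinel (-1) on close, and the deserializer instead treats a clean EOFException on the length read as end of stream. This matters when serialized streams are concatenated back-to-back, as a sort-based shuffle writer may do when merging spill files: a -1 emitted at the close of one stream would otherwise read as "end of stream" in the middle of the merged bytes. Below is a minimal, self-contained sketch of this sentinel-free framing, using hypothetical names rather than Spark's actual classes:

    import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream, EOFException}

    object LengthPrefixedFraming {
      // Each record is a 4-byte big-endian length followed by the record's bytes.
      // Nothing is appended at close, so two serialized streams concatenate
      // byte-for-byte into one valid stream.
      def writeRecords(records: Seq[Array[Byte]]): Array[Byte] = {
        val bytes = new ByteArrayOutputStream()
        val out = new DataOutputStream(bytes)
        records.foreach { rec =>
          out.writeInt(rec.length)
          out.write(rec)
        }
        out.close()
        bytes.toByteArray
      }

      // Read until the underlying stream is exhausted; an EOFException on the
      // length read means "no more records", mirroring the patch's readSize().
      def readRecords(data: Array[Byte]): Seq[Array[Byte]] = {
        val in = new DataInputStream(new ByteArrayInputStream(data))
        val records = Seq.newBuilder[Array[Byte]]
        var done = false
        while (!done) {
          try {
            val size = in.readInt()
            val buf = new Array[Byte](size)
            in.readFully(buf)
            records += buf
          } catch {
            case _: EOFException =>
              in.close()
              done = true
          }
        }
        records.result()
      }

      def main(args: Array[String]): Unit = {
        val a = writeRecords(Seq("row1".getBytes, "row2".getBytes))
        val b = writeRecords(Seq("row3".getBytes))
        // Concatenated streams decode as a single stream of three records.
        assert(readRecords(a ++ b).map(new String(_)) == Seq("row1", "row2", "row3"))
      }
    }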
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala       | 22
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala  | 23
2 files changed, 27 insertions(+), 18 deletions(-)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala
index e060c06d9e..7e981268de 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala
@@ -45,16 +45,9 @@ private[sql] class UnsafeRowSerializer(numFields: Int) extends Serializer with S
 }
 
 private class UnsafeRowSerializerInstance(numFields: Int) extends SerializerInstance {
-
-  /**
-   * Marks the end of a stream written with [[serializeStream()]].
-   */
-  private[this] val EOF: Int = -1
-
   /**
    * Serializes a stream of UnsafeRows. Within the stream, each record consists of a record
    * length (stored as a 4-byte integer, written high byte first), followed by the record's bytes.
-   * The end of the stream is denoted by a record with the special length `EOF` (-1).
    */
   override def serializeStream(out: OutputStream): SerializationStream = new SerializationStream {
     private[this] var writeBuffer: Array[Byte] = new Array[Byte](4096)
@@ -92,7 +85,6 @@ private class UnsafeRowSerializerInstance(numFields: Int) extends SerializerInst
 
     override def close(): Unit = {
       writeBuffer = null
-      dOut.writeInt(EOF)
       dOut.close()
     }
   }
@@ -104,12 +96,20 @@ private class UnsafeRowSerializerInstance(numFields: Int) extends SerializerInst
     private[this] var rowBuffer: Array[Byte] = new Array[Byte](1024)
     private[this] var row: UnsafeRow = new UnsafeRow()
     private[this] var rowTuple: (Int, UnsafeRow) = (0, row)
+    private[this] val EOF: Int = -1
 
     override def asKeyValueIterator: Iterator[(Int, UnsafeRow)] = {
       new Iterator[(Int, UnsafeRow)] {
-        private[this] var rowSize: Int = dIn.readInt()
-        if (rowSize == EOF) dIn.close()
+        private[this] def readSize(): Int = try {
+          dIn.readInt()
+        } catch {
+          case e: EOFException =>
+            dIn.close()
+            EOF
+        }
+
+        private[this] var rowSize: Int = readSize()
 
         override def hasNext: Boolean = rowSize != EOF
 
         override def next(): (Int, UnsafeRow) = {
@@ -118,7 +118,7 @@ private class UnsafeRowSerializerInstance(numFields: Int) extends SerializerInst
           }
           ByteStreams.readFully(dIn, rowBuffer, 0, rowSize)
           row.pointTo(rowBuffer, Platform.BYTE_ARRAY_OFFSET, numFields, rowSize)
-          rowSize = dIn.readInt() // read the next row's size
+          rowSize = readSize()
           if (rowSize == EOF) { // We are returning the last row in this stream
             dIn.close()
             val _rowTuple = rowTuple
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala
index 0113d052e3..f7d48bc53e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala
@@ -17,9 +17,10 @@
 
 package org.apache.spark.sql.execution
 
-import java.io.{File, DataOutputStream, ByteArrayInputStream, ByteArrayOutputStream}
+import java.io.{File, ByteArrayInputStream, ByteArrayOutputStream}
 
 import org.apache.spark.executor.ShuffleWriteMetrics
+import org.apache.spark.rdd.RDD
 import org.apache.spark.storage.ShuffleBlockId
 import org.apache.spark.util.collection.ExternalSorter
 import org.apache.spark.util.Utils
@@ -41,7 +42,7 @@ class ClosableByteArrayInputStream(buf: Array[Byte]) extends ByteArrayInputStrea
   }
 }
 
-class UnsafeRowSerializerSuite extends SparkFunSuite {
+class UnsafeRowSerializerSuite extends SparkFunSuite with LocalSparkContext {
 
   private def toUnsafeRow(row: Row, schema: Array[DataType]): UnsafeRow = {
     val converter = unsafeRowConverter(schema)
@@ -87,11 +88,7 @@ class UnsafeRowSerializerSuite extends SparkFunSuite {
   }
 
   test("close empty input stream") {
-    val baos = new ByteArrayOutputStream()
-    val dout = new DataOutputStream(baos)
-    dout.writeInt(-1) // EOF
-    dout.flush()
-    val input = new ClosableByteArrayInputStream(baos.toByteArray)
+    val input = new ClosableByteArrayInputStream(Array.empty)
     val serializer = new UnsafeRowSerializer(numFields = 2).newInstance()
     val deserializerIter = serializer.deserializeStream(input).asKeyValueIterator
     assert(!deserializerIter.hasNext)
@@ -143,4 +140,16 @@
       }
     }
   }
+
+  test("SPARK-10403: unsafe row serializer with UnsafeShuffleManager") {
+    val conf = new SparkConf()
+      .set("spark.shuffle.manager", "tungsten-sort")
+    sc = new SparkContext("local", "test", conf)
+    val row = Row("Hello", 123)
+    val unsafeRow = toUnsafeRow(row, Array(StringType, IntegerType))
+    val rowsRDD = sc.parallelize(Seq((0, unsafeRow), (1, unsafeRow), (0, unsafeRow)))
+      .asInstanceOf[RDD[Product2[Int, InternalRow]]]
+    val shuffled = new ShuffledRowRDD(rowsRDD, new UnsafeRowSerializer(2), 2)
+    shuffled.count()
+  }
 }
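As a side note on the control flow the patched deserializer adopts: the iterator keeps a one-record lookahead, reading the next record's length eagerly and mapping a clean EOFException to the EOF sentinel, so hasNext simply checks that sentinel. A compact sketch under the same assumptions as above (hypothetical class name, plain byte-array records instead of UnsafeRow):

    import java.io.{DataInputStream, EOFException, InputStream}

    // Hypothetical reduction of the patched iterator: one-record lookahead,
    // with EOFException on the length read translated into an EOF sentinel.
    class LengthPrefixedIterator(in: InputStream) extends Iterator[Array[Byte]] {
      private val dIn = new DataInputStream(in)
      private val EOF = -1

      // Like the patch's readSize(): a clean end-of-stream becomes EOF.
      private def readSize(): Int =
        try dIn.readInt()
        catch { case _: EOFException => dIn.close(); EOF }

      // Lookahead: the size of the next record, or EOF once the stream ends.
      private var nextSize = readSize()

      override def hasNext: Boolean = nextSize != EOF

      override def next(): Array[Byte] = {
        val buf = new Array[Byte](nextSize)
        dIn.readFully(buf)
        nextSize = readSize() // advance the lookahead; closes the stream at EOF
        buf
      }
    }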