aboutsummaryrefslogtreecommitdiff
path: root/sql
diff options
context:
space:
mode:
authorJosh Rosen <joshrosen@databricks.com>2015-09-23 11:31:01 -0700
committerMichael Armbrust <michael@databricks.com>2015-09-23 11:31:01 -0700
commita18208047f06a4244703c17023bb20cbe1f59d73 (patch)
tree2321398bcd999ab2bab28f5f2bd5dad3c706bee0 /sql
parent27bfa9ab3a610e072c011fd88ee4684cea6ceb76 (diff)
downloadspark-a18208047f06a4244703c17023bb20cbe1f59d73.tar.gz
spark-a18208047f06a4244703c17023bb20cbe1f59d73.tar.bz2
spark-a18208047f06a4244703c17023bb20cbe1f59d73.zip
[SPARK-10403] Allow UnsafeRowSerializer to work with tungsten-sort ShuffleManager
This patch attempts to fix an issue where Spark SQL's UnsafeRowSerializer was incompatible with the `tungsten-sort` ShuffleManager. Author: Josh Rosen <joshrosen@databricks.com> Closes #8873 from JoshRosen/SPARK-10403.
Diffstat (limited to 'sql')
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala22
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala23
2 files changed, 27 insertions, 18 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala
index e060c06d9e..7e981268de 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala
@@ -45,16 +45,9 @@ private[sql] class UnsafeRowSerializer(numFields: Int) extends Serializer with S
}
private class UnsafeRowSerializerInstance(numFields: Int) extends SerializerInstance {
-
- /**
- * Marks the end of a stream written with [[serializeStream()]].
- */
- private[this] val EOF: Int = -1
-
/**
* Serializes a stream of UnsafeRows. Within the stream, each record consists of a record
* length (stored as a 4-byte integer, written high byte first), followed by the record's bytes.
- * The end of the stream is denoted by a record with the special length `EOF` (-1).
*/
override def serializeStream(out: OutputStream): SerializationStream = new SerializationStream {
private[this] var writeBuffer: Array[Byte] = new Array[Byte](4096)
@@ -92,7 +85,6 @@ private class UnsafeRowSerializerInstance(numFields: Int) extends SerializerInst
override def close(): Unit = {
writeBuffer = null
- dOut.writeInt(EOF)
dOut.close()
}
}
@@ -104,12 +96,20 @@ private class UnsafeRowSerializerInstance(numFields: Int) extends SerializerInst
private[this] var rowBuffer: Array[Byte] = new Array[Byte](1024)
private[this] var row: UnsafeRow = new UnsafeRow()
private[this] var rowTuple: (Int, UnsafeRow) = (0, row)
+ private[this] val EOF: Int = -1
override def asKeyValueIterator: Iterator[(Int, UnsafeRow)] = {
new Iterator[(Int, UnsafeRow)] {
- private[this] var rowSize: Int = dIn.readInt()
- if (rowSize == EOF) dIn.close()
+ private[this] def readSize(): Int = try {
+ dIn.readInt()
+ } catch {
+ case e: EOFException =>
+ dIn.close()
+ EOF
+ }
+
+ private[this] var rowSize: Int = readSize()
override def hasNext: Boolean = rowSize != EOF
override def next(): (Int, UnsafeRow) = {
@@ -118,7 +118,7 @@ private class UnsafeRowSerializerInstance(numFields: Int) extends SerializerInst
}
ByteStreams.readFully(dIn, rowBuffer, 0, rowSize)
row.pointTo(rowBuffer, Platform.BYTE_ARRAY_OFFSET, numFields, rowSize)
- rowSize = dIn.readInt() // read the next row's size
+ rowSize = readSize()
if (rowSize == EOF) { // We are returning the last row in this stream
dIn.close()
val _rowTuple = rowTuple
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala
index 0113d052e3..f7d48bc53e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala
@@ -17,9 +17,10 @@
package org.apache.spark.sql.execution
-import java.io.{File, DataOutputStream, ByteArrayInputStream, ByteArrayOutputStream}
+import java.io.{File, ByteArrayInputStream, ByteArrayOutputStream}
import org.apache.spark.executor.ShuffleWriteMetrics
+import org.apache.spark.rdd.RDD
import org.apache.spark.storage.ShuffleBlockId
import org.apache.spark.util.collection.ExternalSorter
import org.apache.spark.util.Utils
@@ -41,7 +42,7 @@ class ClosableByteArrayInputStream(buf: Array[Byte]) extends ByteArrayInputStrea
}
}
-class UnsafeRowSerializerSuite extends SparkFunSuite {
+class UnsafeRowSerializerSuite extends SparkFunSuite with LocalSparkContext {
private def toUnsafeRow(row: Row, schema: Array[DataType]): UnsafeRow = {
val converter = unsafeRowConverter(schema)
@@ -87,11 +88,7 @@ class UnsafeRowSerializerSuite extends SparkFunSuite {
}
test("close empty input stream") {
- val baos = new ByteArrayOutputStream()
- val dout = new DataOutputStream(baos)
- dout.writeInt(-1) // EOF
- dout.flush()
- val input = new ClosableByteArrayInputStream(baos.toByteArray)
+ val input = new ClosableByteArrayInputStream(Array.empty)
val serializer = new UnsafeRowSerializer(numFields = 2).newInstance()
val deserializerIter = serializer.deserializeStream(input).asKeyValueIterator
assert(!deserializerIter.hasNext)
@@ -143,4 +140,16 @@ class UnsafeRowSerializerSuite extends SparkFunSuite {
}
}
}
+
+ test("SPARK-10403: unsafe row serializer with UnsafeShuffleManager") {
+ val conf = new SparkConf()
+ .set("spark.shuffle.manager", "tungsten-sort")
+ sc = new SparkContext("local", "test", conf)
+ val row = Row("Hello", 123)
+ val unsafeRow = toUnsafeRow(row, Array(StringType, IntegerType))
+ val rowsRDD = sc.parallelize(Seq((0, unsafeRow), (1, unsafeRow), (0, unsafeRow)))
+ .asInstanceOf[RDD[Product2[Int, InternalRow]]]
+ val shuffled = new ShuffledRowRDD(rowsRDD, new UnsafeRowSerializer(2), 2)
+ shuffled.count()
+ }
}