author    Davies Liu <davies@databricks.com>    2015-08-13 17:35:11 -0700
committer Davies Liu <davies.liu@gmail.com>    2015-08-13 17:35:11 -0700
commit    c50f97dafd2d5bf5a8351efcc1c8d3e2b87efc72 (patch)
tree      fb01f5f832a949eda36d392d016dabb244286dc8
parent    693949ba4096c01a0b41da2542ff316823464a16 (diff)
[SPARK-9943] [SQL] deserialized UnsafeHashedRelation should be serializable
When the free memory in an executor runs low, cached broadcast objects need to be serialized to disk, but currently a deserialized UnsafeHashedRelation can't be serialized and fails with an NPE. This PR fixes that.

cc rxin

Author: Davies Liu <davies@databricks.com>

Closes #8174 from davies/serialize_hashed.
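In short, an UnsafeHashedRelation keeps its rows either in an on-heap hash table (before it is first serialized) or in an off-heap BytesToBytesMap (after deserialization), and the old writeExternal only handled the former. The sketch below is a simplified, hypothetical illustration of the pattern the patch adopts, using plain byte arrays instead of Spark's off-heap map; the class and field names here are ours, not Spark's.

import java.io.{Externalizable, ObjectInput, ObjectOutput}

// Simplified sketch of a relation with two backing representations.
// Exactly one of the two fields is non-null at a time; writeExternal must
// handle both, otherwise serializing a previously deserialized instance
// (e.g. when a cached broadcast block is spilled to disk) hits an NPE.
class TwoFormRelation extends Externalizable {
  private var hashTable: java.util.HashMap[String, Array[Byte]] = _   // "fresh" form
  private var binaryEntries: Array[(Array[Byte], Array[Byte])] = _    // deserialized form

  override def writeExternal(out: ObjectOutput): Unit = {
    if (binaryEntries != null) {
      // A previously deserialized instance: dump the raw key/value bytes.
      out.writeInt(binaryEntries.length)
      binaryEntries.foreach { case (k, v) =>
        out.writeInt(k.length)
        out.writeInt(v.length)
        out.write(k)
        out.write(v)
      }
    } else {
      assert(hashTable != null)
      out.writeInt(hashTable.size())
      val it = hashTable.entrySet().iterator()
      while (it.hasNext) {
        val e = it.next()
        val k = e.getKey.getBytes("UTF-8")
        out.writeInt(k.length)
        out.writeInt(e.getValue.length)
        out.write(k)
        out.write(e.getValue)
      }
    }
  }

  override def readExternal(in: ObjectInput): Unit = {
    // Reading always produces the compact, deserialized form.
    val n = in.readInt()
    binaryEntries = Array.fill(n) {
      val keyLen = in.readInt()
      val valLen = in.readInt()
      val key = new Array[Byte](keyLen)
      val value = new Array[Byte](valLen)
      in.readFully(key)
      in.readFully(value)
      (key, value)
    }
    hashTable = null
  }
}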
-rw-r--r-- sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala      | 93
-rw-r--r-- sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala | 14
2 files changed, 74 insertions(+), 33 deletions(-)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
index ea02076b41..6c0196c21a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
@@ -28,7 +28,7 @@ import org.apache.spark.sql.execution.SparkSqlSerializer
 import org.apache.spark.sql.execution.metric.LongSQLMetric
 import org.apache.spark.unsafe.Platform
 import org.apache.spark.unsafe.map.BytesToBytesMap
-import org.apache.spark.unsafe.memory.{ExecutorMemoryManager, MemoryAllocator, TaskMemoryManager}
+import org.apache.spark.unsafe.memory.{MemoryLocation, ExecutorMemoryManager, MemoryAllocator, TaskMemoryManager}
 import org.apache.spark.util.Utils
 import org.apache.spark.util.collection.CompactBuffer
 import org.apache.spark.{SparkConf, SparkEnv}
@@ -247,40 +247,67 @@ private[joins] final class UnsafeHashedRelation(
   }
 
   override def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException {
-    out.writeInt(hashTable.size())
-
-    val iter = hashTable.entrySet().iterator()
-    while (iter.hasNext) {
-      val entry = iter.next()
-      val key = entry.getKey
-      val values = entry.getValue
-
-      // write all the values as single byte array
-      var totalSize = 0L
-      var i = 0
-      while (i < values.length) {
-        totalSize += values(i).getSizeInBytes + 4 + 4
-        i += 1
+    if (binaryMap != null) {
+      // This could happen when a cached broadcast object need to be dumped into disk to free memory
+      out.writeInt(binaryMap.numElements())
+
+      var buffer = new Array[Byte](64)
+      def write(addr: MemoryLocation, length: Int): Unit = {
+        if (buffer.length < length) {
+          buffer = new Array[Byte](length)
+        }
+        Platform.copyMemory(addr.getBaseObject, addr.getBaseOffset,
+          buffer, Platform.BYTE_ARRAY_OFFSET, length)
+        out.write(buffer, 0, length)
       }
-      assert(totalSize < Integer.MAX_VALUE, "values are too big")
-
-      // [key size] [values size] [key bytes] [values bytes]
-      out.writeInt(key.getSizeInBytes)
-      out.writeInt(totalSize.toInt)
-      out.write(key.getBytes)
-      i = 0
-      while (i < values.length) {
-        // [num of fields] [num of bytes] [row bytes]
-        // write the integer in native order, so they can be read by UNSAFE.getInt()
-        if (ByteOrder.nativeOrder() == ByteOrder.BIG_ENDIAN) {
-          out.writeInt(values(i).numFields())
-          out.writeInt(values(i).getSizeInBytes)
-        } else {
-          out.writeInt(Integer.reverseBytes(values(i).numFields()))
-          out.writeInt(Integer.reverseBytes(values(i).getSizeInBytes))
+
+      val iter = binaryMap.iterator()
+      while (iter.hasNext) {
+        val loc = iter.next()
+        // [key size] [values size] [key bytes] [values bytes]
+        out.writeInt(loc.getKeyLength)
+        out.writeInt(loc.getValueLength)
+        write(loc.getKeyAddress, loc.getKeyLength)
+        write(loc.getValueAddress, loc.getValueLength)
+      }
+
+    } else {
+      assert(hashTable != null)
+      out.writeInt(hashTable.size())
+
+      val iter = hashTable.entrySet().iterator()
+      while (iter.hasNext) {
+        val entry = iter.next()
+        val key = entry.getKey
+        val values = entry.getValue
+
+        // write all the values as single byte array
+        var totalSize = 0L
+        var i = 0
+        while (i < values.length) {
+          totalSize += values(i).getSizeInBytes + 4 + 4
+          i += 1
+        }
+        assert(totalSize < Integer.MAX_VALUE, "values are too big")
+
+        // [key size] [values size] [key bytes] [values bytes]
+        out.writeInt(key.getSizeInBytes)
+        out.writeInt(totalSize.toInt)
+        out.write(key.getBytes)
+        i = 0
+        while (i < values.length) {
+          // [num of fields] [num of bytes] [row bytes]
+          // write the integer in native order, so they can be read by UNSAFE.getInt()
+          if (ByteOrder.nativeOrder() == ByteOrder.BIG_ENDIAN) {
+            out.writeInt(values(i).numFields())
+            out.writeInt(values(i).getSizeInBytes)
+          } else {
+            out.writeInt(Integer.reverseBytes(values(i).numFields()))
+            out.writeInt(Integer.reverseBytes(values(i).getSizeInBytes))
+          }
+          out.write(values(i).getBytes)
+          i += 1
         }
-        out.write(values(i).getBytes)
-        i += 1
       }
     }
   }
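For reference, the stream produced by the new binaryMap branch above is: an element count, then for each entry [key size] [values size] [key bytes] [values bytes]. The helper below is a hypothetical illustration of reading that framing back into plain byte arrays; it is not the readExternal in this file, which (per the hunk above) rebuilds the off-heap map instead. The object and method names are ours.

import java.io.ObjectInput

// Hypothetical reader for the framing written by the binaryMap branch above:
// element count first, then [key size] [values size] [key bytes] [values bytes].
object BinaryMapFraming {
  def readEntries(in: ObjectInput): Seq[(Array[Byte], Array[Byte])] = {
    val numElements = in.readInt()
    (0 until numElements).map { _ =>
      val keyLen = in.readInt()      // [key size]
      val valLen = in.readInt()      // [values size]
      val key = new Array[Byte](keyLen)
      val value = new Array[Byte](valLen)
      in.readFully(key)              // [key bytes]
      in.readFully(value)            // [values bytes]
      (key, value)
    }
  }
}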
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala
index c635b2d51f..d33a967093 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala
@@ -102,6 +102,14 @@ class HashedRelationSuite extends SparkFunSuite {
     assert(hashed2.get(toUnsafe(InternalRow(10))) === null)
     assert(hashed2.get(unsafeData(2)) === data2)
     assert(numDataRows.value.value === data.length)
+
+    val os2 = new ByteArrayOutputStream()
+    val out2 = new ObjectOutputStream(os2)
+    hashed2.asInstanceOf[UnsafeHashedRelation].writeExternal(out2)
+    out2.flush()
+    // This depends on that the order of items in BytesToBytesMap.iterator() is exactly the same
+    // as they are inserted
+    assert(java.util.Arrays.equals(os2.toByteArray, os.toByteArray))
   }
 
   test("test serialization empty hash map") {
@@ -119,5 +127,11 @@
     val toUnsafe = UnsafeProjection.create(schema)
     val row = toUnsafe(InternalRow(0))
     assert(hashed2.get(row) === null)
+
+    val os2 = new ByteArrayOutputStream()
+    val out2 = new ObjectOutputStream(os2)
+    hashed2.writeExternal(out2)
+    out2.flush()
+    assert(java.util.Arrays.equals(os2.toByteArray, os.toByteArray))
   }
 }
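Both added tests follow the same round-trip pattern: write the relation with writeExternal into an ObjectOutputStream over a ByteArrayOutputStream, read it back with readExternal, write it again, and compare the byte arrays. A generic sketch of that pattern (helper names are ours, and, as the comment in the test notes, it assumes the map iterates in insertion order) might look like:

import java.io.{ByteArrayInputStream, ByteArrayOutputStream, Externalizable,
  ObjectInputStream, ObjectOutputStream}

// Hypothetical helpers mirroring the round-trip check used in HashedRelationSuite.
object RoundTrip {
  // Serialize via writeExternal; the ObjectOutputStream header is included,
  // just as in the suite's os/os2 byte arrays.
  def serialize(e: Externalizable): Array[Byte] = {
    val bos = new ByteArrayOutputStream()
    val out = new ObjectOutputStream(bos)
    e.writeExternal(out)
    out.flush()
    bos.toByteArray
  }

  // Write, read back into `empty`, write again, and check the two encodings
  // match byte for byte (valid only if iteration order is stable).
  def reserializesIdentically[T <: Externalizable](original: T, empty: T): Boolean = {
    val first = serialize(original)
    val in = new ObjectInputStream(new ByteArrayInputStream(first))
    empty.readExternal(in)
    java.util.Arrays.equals(first, serialize(empty))
  }
}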