From c50f97dafd2d5bf5a8351efcc1c8d3e2b87efc72 Mon Sep 17 00:00:00 2001
From: Davies Liu
Date: Thu, 13 Aug 2015 17:35:11 -0700
Subject: [SPARK-9943] [SQL] deserialized UnsafeHashedRelation should be serializable

When free memory in the executor runs low, cached broadcast objects need to
be serialized to disk, but currently a deserialized UnsafeHashedRelation
can't be serialized again and fails with an NPE. This PR fixes that.

cc rxin

Author: Davies Liu

Closes #8174 from davies/serialize_hashed.
---
 .../spark/sql/execution/joins/HashedRelation.scala | 93 ++++++++++++++--------
 .../sql/execution/joins/HashedRelationSuite.scala  | 14 ++++
 2 files changed, 74 insertions(+), 33 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
index ea02076b41..6c0196c21a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
@@ -28,7 +28,7 @@ import org.apache.spark.sql.execution.SparkSqlSerializer
 import org.apache.spark.sql.execution.metric.LongSQLMetric
 import org.apache.spark.unsafe.Platform
 import org.apache.spark.unsafe.map.BytesToBytesMap
-import org.apache.spark.unsafe.memory.{ExecutorMemoryManager, MemoryAllocator, TaskMemoryManager}
+import org.apache.spark.unsafe.memory.{MemoryLocation, ExecutorMemoryManager, MemoryAllocator, TaskMemoryManager}
 import org.apache.spark.util.Utils
 import org.apache.spark.util.collection.CompactBuffer
 import org.apache.spark.{SparkConf, SparkEnv}
@@ -247,40 +247,67 @@ private[joins] final class UnsafeHashedRelation(
   }
 
   override def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException {
-    out.writeInt(hashTable.size())
-
-    val iter = hashTable.entrySet().iterator()
-    while (iter.hasNext) {
-      val entry = iter.next()
-      val key = entry.getKey
-      val values = entry.getValue
-
-      // write all the values as single byte array
-      var totalSize = 0L
-      var i = 0
-      while (i < values.length) {
-        totalSize += values(i).getSizeInBytes + 4 + 4
-        i += 1
+    if (binaryMap != null) {
+      // This could happen when a cached broadcast object need to be dumped into disk to free memory
+      out.writeInt(binaryMap.numElements())
+
+      var buffer = new Array[Byte](64)
+      def write(addr: MemoryLocation, length: Int): Unit = {
+        if (buffer.length < length) {
+          buffer = new Array[Byte](length)
+        }
+        Platform.copyMemory(addr.getBaseObject, addr.getBaseOffset,
+          buffer, Platform.BYTE_ARRAY_OFFSET, length)
+        out.write(buffer, 0, length)
       }
-      assert(totalSize < Integer.MAX_VALUE, "values are too big")
-
-      // [key size] [values size] [key bytes] [values bytes]
-      out.writeInt(key.getSizeInBytes)
-      out.writeInt(totalSize.toInt)
-      out.write(key.getBytes)
-      i = 0
-      while (i < values.length) {
-        // [num of fields] [num of bytes] [row bytes]
-        // write the integer in native order, so they can be read by UNSAFE.getInt()
-        if (ByteOrder.nativeOrder() == ByteOrder.BIG_ENDIAN) {
-          out.writeInt(values(i).numFields())
-          out.writeInt(values(i).getSizeInBytes)
-        } else {
-          out.writeInt(Integer.reverseBytes(values(i).numFields()))
-          out.writeInt(Integer.reverseBytes(values(i).getSizeInBytes))
+
+      val iter = binaryMap.iterator()
+      while (iter.hasNext) {
+        val loc = iter.next()
+        // [key size] [values size] [key bytes] [values bytes]
+        out.writeInt(loc.getKeyLength)
+        out.writeInt(loc.getValueLength)
+        write(loc.getKeyAddress, loc.getKeyLength)
+        write(loc.getValueAddress, loc.getValueLength)
+      }
+
+    } else {
+      assert(hashTable != null)
+      out.writeInt(hashTable.size())
+
+      val iter = hashTable.entrySet().iterator()
+      while (iter.hasNext) {
+        val entry = iter.next()
+        val key = entry.getKey
+        val values = entry.getValue
+
+        // write all the values as single byte array
+        var totalSize = 0L
+        var i = 0
+        while (i < values.length) {
+          totalSize += values(i).getSizeInBytes + 4 + 4
+          i += 1
+        }
+        assert(totalSize < Integer.MAX_VALUE, "values are too big")
+
+        // [key size] [values size] [key bytes] [values bytes]
+        out.writeInt(key.getSizeInBytes)
+        out.writeInt(totalSize.toInt)
+        out.write(key.getBytes)
+        i = 0
+        while (i < values.length) {
+          // [num of fields] [num of bytes] [row bytes]
+          // write the integer in native order, so they can be read by UNSAFE.getInt()
+          if (ByteOrder.nativeOrder() == ByteOrder.BIG_ENDIAN) {
+            out.writeInt(values(i).numFields())
+            out.writeInt(values(i).getSizeInBytes)
+          } else {
+            out.writeInt(Integer.reverseBytes(values(i).numFields()))
+            out.writeInt(Integer.reverseBytes(values(i).getSizeInBytes))
+          }
+          out.write(values(i).getBytes)
+          i += 1
         }
-        out.write(values(i).getBytes)
-        i += 1
       }
     }
   }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala
index c635b2d51f..d33a967093 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala
@@ -102,6 +102,14 @@ class HashedRelationSuite extends SparkFunSuite {
     assert(hashed2.get(toUnsafe(InternalRow(10))) === null)
     assert(hashed2.get(unsafeData(2)) === data2)
     assert(numDataRows.value.value === data.length)
+
+    val os2 = new ByteArrayOutputStream()
+    val out2 = new ObjectOutputStream(os2)
+    hashed2.asInstanceOf[UnsafeHashedRelation].writeExternal(out2)
+    out2.flush()
+    // This depends on that the order of items in BytesToBytesMap.iterator() is exactly the same
+    // as they are inserted
+    assert(java.util.Arrays.equals(os2.toByteArray, os.toByteArray))
   }
 
   test("test serialization empty hash map") {
@@ -119,5 +127,11 @@ class HashedRelationSuite extends SparkFunSuite {
     val toUnsafe = UnsafeProjection.create(schema)
     val row = toUnsafe(InternalRow(0))
     assert(hashed2.get(row) === null)
+
+    val os2 = new ByteArrayOutputStream()
+    val out2 = new ObjectOutputStream(os2)
+    hashed2.writeExternal(out2)
+    out2.flush()
+    assert(java.util.Arrays.equals(os2.toByteArray, os.toByteArray))
   }
 }
--
cgit v1.2.3
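
For context on the on-disk format used by writeExternal above: each entry is written as [key size] [values size] [key bytes] [values bytes]. The following standalone Scala sketch illustrates that record layout with a simple round trip; it is not part of the patch, and the RecordLayoutSketch object with its writeEntry/readEntry helpers is hypothetical.

import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}

// Hypothetical helpers illustrating the [key size][values size][key bytes][values bytes]
// record layout; not part of the patch.
object RecordLayoutSketch {
  // Writes one entry: two length prefixes followed by the raw key and values bytes.
  def writeEntry(out: DataOutputStream, key: Array[Byte], values: Array[Byte]): Unit = {
    out.writeInt(key.length)
    out.writeInt(values.length)
    out.write(key)
    out.write(values)
  }

  // Reads one entry back, returning (key bytes, values bytes).
  def readEntry(in: DataInputStream): (Array[Byte], Array[Byte]) = {
    val keySize = in.readInt()
    val valuesSize = in.readInt()
    val key = new Array[Byte](keySize)
    val values = new Array[Byte](valuesSize)
    in.readFully(key)
    in.readFully(values)
    (key, values)
  }

  def main(args: Array[String]): Unit = {
    val bos = new ByteArrayOutputStream()
    val out = new DataOutputStream(bos)
    writeEntry(out, Array[Byte](1, 2, 3), Array[Byte](4, 5, 6, 7))
    out.flush()

    val in = new DataInputStream(new ByteArrayInputStream(bos.toByteArray))
    val (key, values) = readEntry(in)
    assert(key.sameElements(Array[Byte](1, 2, 3)))
    assert(values.sameElements(Array[Byte](4, 5, 6, 7)))
  }
}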
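
The fallback (hashTable) path also relies on a byte-order detail: DataOutput.writeInt always writes big-endian, so on little-endian hardware the patch pre-flips each length with Integer.reverseBytes, leaving the serialized bytes in native order so they can later be read directly by UNSAFE.getInt. Below is a minimal sketch of that trick under stated assumptions; the NativeOrderSketch object is hypothetical, and a ByteBuffer read in native order stands in for the unsafe access.

import java.io.{ByteArrayOutputStream, DataOutputStream}
import java.nio.{ByteBuffer, ByteOrder}

// Sketch of the byte-order handling in the fallback serialization path.
object NativeOrderSketch {
  def main(args: Array[String]): Unit = {
    val value = 42
    val bos = new ByteArrayOutputStream()
    val out = new DataOutputStream(bos)
    if (ByteOrder.nativeOrder() == ByteOrder.BIG_ENDIAN) {
      out.writeInt(value)                        // writeInt is big-endian, already native order
    } else {
      out.writeInt(Integer.reverseBytes(value))  // pre-flip so the stored bytes are little-endian
    }
    out.flush()

    // Reading the stored bytes in native order recovers the original value,
    // which is effectively what UNSAFE.getInt does on the deserialized buffer.
    val recovered = ByteBuffer.wrap(bos.toByteArray).order(ByteOrder.nativeOrder()).getInt
    assert(recovered == value)
  }
}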