path: root/sql
author     yzhou2001 <yzhou_1999@yahoo.com>        2016-05-03 13:41:04 -0700
committer  Davies Liu <davies.liu@gmail.com>       2016-05-03 13:41:04 -0700
commit     a4aed71719b4fc728de93afc623aef05d27bc89a (patch)
tree       5663ce28289fe28e1484a1eca6c17788170e6886 /sql
parent     659f635d3bd0c0d025bf514dfb1747ed7386ba45 (diff)
[SPARK-14521] [SQL] StackOverflowError in Kryo when executing TPC-DS
## What changes were proposed in this pull request?

A StackOverflowError was observed in Kryo when executing TPC-DS Query 27. The Spark thrift server disables Kryo reference tracking unless it is specified in the conf; when "spark.kryo.referenceTracking" is explicitly set to true in spark-defaults.conf, the query executes successfully.

The root cause is that the TaskMemoryManager inside MemoryConsumer and LongToUnsafeRowMap was not transient, so it was serialized and broadcast from within LongHashedRelation, which can introduce a circular reference that Kryo cannot handle without reference tracking. The TaskMemoryManager is per task and should not be passed around in the first place; this fix makes it transient.

## How was this patch tested?

core/test, hive/test, sql/test, catalyst/test, dev/lint-scala, org.apache.spark.sql.hive.execution.HiveCompatibilitySuite, dev/scalastyle, and a manual run of TPC-DS Query 27 with 1GB of data, but without the "limit 100", which would otherwise cause an NPE due to SPARK-14752.

Author: yzhou2001 <yzhou_1999@yahoo.com>

Closes #12598 from yzhou2001/master.
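The failure mode described above can be reproduced outside Spark. The sketch below is not part of the patch: the `Node` class and `ReferenceTrackingDemo` object are made up for illustration, standing in for the back-references that pulled the TaskMemoryManager into the broadcast graph. It builds a two-object cycle and serializes it with Kryo reference tracking disabled (the thrift server's default), so Kryo recurses on the cycle until the stack overflows.

```scala
import java.io.ByteArrayOutputStream

import com.esotericsoftware.kryo.Kryo
import com.esotericsoftware.kryo.io.Output

// Hypothetical class whose instances can reference each other.
class Node(var next: Node)

object ReferenceTrackingDemo {
  def main(args: Array[String]): Unit = {
    val a = new Node(null)
    val b = new Node(a)
    a.next = b // a -> b -> a forms a cycle

    val kryo = new Kryo()
    // Equivalent to spark.kryo.referenceTracking=false: Kryo no longer
    // remembers already-written objects, so cycles are never detected.
    kryo.setReferences(false)
    kryo.register(classOf[Node])

    val out = new Output(new ByteArrayOutputStream())
    // Recurses a -> b -> a -> ... and eventually throws StackOverflowError.
    kryo.writeObject(out, a)
  }
}
```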
Diffstat (limited to 'sql')
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala       | 136
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala  |  35
2 files changed, 129 insertions, 42 deletions
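Besides hiding the TaskMemoryManager from serialization, the hunks below make both relation classes implement KryoSerializable and funnel the Java and Kryo paths through shared private read/write methods parameterized by the primitive reader/writer functions. The following condensed sketch shows that pattern only; the class name `SharedSerialization` and the single `numFields` field are simplifications, not the actual Spark code.

```scala
import java.io.{Externalizable, ObjectInput, ObjectOutput}

import com.esotericsoftware.kryo.{Kryo, KryoSerializable}
import com.esotericsoftware.kryo.io.{Input, Output}

// Simplified stand-in for UnsafeHashedRelation / LongToUnsafeRowMap:
// both serialization frameworks call into one private method, so the
// two code paths cannot drift apart.
class SharedSerialization extends Externalizable with KryoSerializable {
  private var numFields: Int = 0

  private def write(writeInt: Int => Unit): Unit = {
    writeInt(numFields)
    // ... write the rest of the payload with the supplied writers ...
  }

  private def read(readInt: () => Int): Unit = {
    numFields = readInt()
    // ... read the rest of the payload with the supplied readers ...
  }

  // Java serialization path (Externalizable).
  override def writeExternal(out: ObjectOutput): Unit = write(out.writeInt)
  override def readExternal(in: ObjectInput): Unit = read(in.readInt)

  // Kryo serialization path (KryoSerializable).
  override def write(kryo: Kryo, out: Output): Unit = write(out.writeInt)
  override def read(kryo: Kryo, in: Input): Unit = read(in.readInt)
}
```

Passing the writer/reader functions (rather than the ObjectOutput/Output objects themselves) is what lets the same body serve both APIs, which differ in method names such as readFully vs. readBytes.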
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
index b280c76c70..315ef6a879 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
@@ -17,7 +17,10 @@
package org.apache.spark.sql.execution.joins
-import java.io.{Externalizable, IOException, ObjectInput, ObjectOutput}
+import java.io._
+
+import com.esotericsoftware.kryo.{Kryo, KryoSerializable}
+import com.esotericsoftware.kryo.io.{Input, Output}
import org.apache.spark.{SparkConf, SparkEnv, SparkException}
import org.apache.spark.memory.{MemoryConsumer, MemoryMode, StaticMemoryManager, TaskMemoryManager}
@@ -116,7 +119,7 @@ private[execution] object HashedRelation {
private[joins] class UnsafeHashedRelation(
private var numFields: Int,
private var binaryMap: BytesToBytesMap)
- extends HashedRelation with Externalizable {
+ extends HashedRelation with Externalizable with KryoSerializable {
private[joins] def this() = this(0, null) // Needed for serialization
@@ -171,10 +174,21 @@ private[joins] class UnsafeHashedRelation(
}
override def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException {
- out.writeInt(numFields)
+ write(out.writeInt, out.writeLong, out.write)
+ }
+
+ override def write(kryo: Kryo, out: Output): Unit = Utils.tryOrIOException {
+ write(out.writeInt, out.writeLong, out.write)
+ }
+
+ private def write(
+ writeInt: (Int) => Unit,
+ writeLong: (Long) => Unit,
+ writeBuffer: (Array[Byte], Int, Int) => Unit) : Unit = {
+ writeInt(numFields)
// TODO: move these into BytesToBytesMap
- out.writeLong(binaryMap.numKeys())
- out.writeLong(binaryMap.numValues())
+ writeLong(binaryMap.numKeys())
+ writeLong(binaryMap.numValues())
var buffer = new Array[Byte](64)
def write(base: Object, offset: Long, length: Int): Unit = {
@@ -182,25 +196,32 @@ private[joins] class UnsafeHashedRelation(
buffer = new Array[Byte](length)
}
Platform.copyMemory(base, offset, buffer, Platform.BYTE_ARRAY_OFFSET, length)
- out.write(buffer, 0, length)
+ writeBuffer(buffer, 0, length)
}
val iter = binaryMap.iterator()
while (iter.hasNext) {
val loc = iter.next()
// [key size] [values size] [key bytes] [value bytes]
- out.writeInt(loc.getKeyLength)
- out.writeInt(loc.getValueLength)
+ writeInt(loc.getKeyLength)
+ writeInt(loc.getValueLength)
write(loc.getKeyBase, loc.getKeyOffset, loc.getKeyLength)
write(loc.getValueBase, loc.getValueOffset, loc.getValueLength)
}
}
override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException {
- numFields = in.readInt()
+ read(in.readInt, in.readLong, in.readFully)
+ }
+
+ private def read(
+ readInt: () => Int,
+ readLong: () => Long,
+ readBuffer: (Array[Byte], Int, Int) => Unit): Unit = {
+ numFields = readInt()
resultRow = new UnsafeRow(numFields)
- val nKeys = in.readLong()
- val nValues = in.readLong()
+ val nKeys = readLong()
+ val nValues = readLong()
// This is used in Broadcast, shared by multiple tasks, so we use on-heap memory
// TODO(josh): This needs to be revisited before we merge this patch; making this change now
// so that tests compile:
@@ -227,16 +248,16 @@ private[joins] class UnsafeHashedRelation(
var keyBuffer = new Array[Byte](1024)
var valuesBuffer = new Array[Byte](1024)
while (i < nValues) {
- val keySize = in.readInt()
- val valuesSize = in.readInt()
+ val keySize = readInt()
+ val valuesSize = readInt()
if (keySize > keyBuffer.length) {
keyBuffer = new Array[Byte](keySize)
}
- in.readFully(keyBuffer, 0, keySize)
+ readBuffer(keyBuffer, 0, keySize)
if (valuesSize > valuesBuffer.length) {
valuesBuffer = new Array[Byte](valuesSize)
}
- in.readFully(valuesBuffer, 0, valuesSize)
+ readBuffer(valuesBuffer, 0, valuesSize)
val loc = binaryMap.lookup(keyBuffer, Platform.BYTE_ARRAY_OFFSET, keySize)
val putSuceeded = loc.append(keyBuffer, Platform.BYTE_ARRAY_OFFSET, keySize,
@@ -248,6 +269,10 @@ private[joins] class UnsafeHashedRelation(
i += 1
}
}
+
+ override def read(kryo: Kryo, in: Input): Unit = Utils.tryOrIOException {
+ read(in.readInt, in.readLong, in.readBytes)
+ }
}
private[joins] object UnsafeHashedRelation {
@@ -324,8 +349,8 @@ private[joins] object UnsafeHashedRelation {
*
* see http://java-performance.info/implementing-world-fastest-java-int-to-int-hash-map/
*/
-private[execution] final class LongToUnsafeRowMap(var mm: TaskMemoryManager, capacity: Int)
- extends MemoryConsumer(mm) with Externalizable {
+private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, capacity: Int)
+ extends MemoryConsumer(mm) with Externalizable with KryoSerializable {
// Whether the keys are stored in dense mode or not.
private var isDense = false
@@ -624,58 +649,85 @@ private[execution] final class LongToUnsafeRowMap(var mm: TaskMemoryManager, cap
}
}
- private def writeLongArray(out: ObjectOutput, arr: Array[Long], len: Int): Unit = {
+ private def writeLongArray(
+ writeBuffer: (Array[Byte], Int, Int) => Unit,
+ arr: Array[Long],
+ len: Int): Unit = {
val buffer = new Array[Byte](4 << 10)
var offset: Long = Platform.LONG_ARRAY_OFFSET
val end = len * 8L + Platform.LONG_ARRAY_OFFSET
while (offset < end) {
val size = Math.min(buffer.length, (end - offset).toInt)
Platform.copyMemory(arr, offset, buffer, Platform.BYTE_ARRAY_OFFSET, size)
- out.write(buffer, 0, size)
+ writeBuffer(buffer, 0, size)
offset += size
}
}
- override def writeExternal(out: ObjectOutput): Unit = {
- out.writeBoolean(isDense)
- out.writeLong(minKey)
- out.writeLong(maxKey)
- out.writeLong(numKeys)
- out.writeLong(numValues)
-
- out.writeLong(array.length)
- writeLongArray(out, array, array.length)
+ private def write(
+ writeBoolean: (Boolean) => Unit,
+ writeLong: (Long) => Unit,
+ writeBuffer: (Array[Byte], Int, Int) => Unit): Unit = {
+ writeBoolean(isDense)
+ writeLong(minKey)
+ writeLong(maxKey)
+ writeLong(numKeys)
+ writeLong(numValues)
+
+ writeLong(array.length)
+ writeLongArray(writeBuffer, array, array.length)
val used = ((cursor - Platform.LONG_ARRAY_OFFSET) / 8).toInt
- out.writeLong(used)
- writeLongArray(out, page, used)
+ writeLong(used)
+ writeLongArray(writeBuffer, page, used)
}
- private def readLongArray(in: ObjectInput, length: Int): Array[Long] = {
+ override def writeExternal(output: ObjectOutput): Unit = {
+ write(output.writeBoolean, output.writeLong, output.write)
+ }
+
+ override def write(kryo: Kryo, out: Output): Unit = {
+ write(out.writeBoolean, out.writeLong, out.write)
+ }
+
+ private def readLongArray(
+ readBuffer: (Array[Byte], Int, Int) => Unit,
+ length: Int): Array[Long] = {
val array = new Array[Long](length)
val buffer = new Array[Byte](4 << 10)
var offset: Long = Platform.LONG_ARRAY_OFFSET
val end = length * 8L + Platform.LONG_ARRAY_OFFSET
while (offset < end) {
val size = Math.min(buffer.length, (end - offset).toInt)
- in.readFully(buffer, 0, size)
+ readBuffer(buffer, 0, size)
Platform.copyMemory(buffer, Platform.BYTE_ARRAY_OFFSET, array, offset, size)
offset += size
}
array
}
- override def readExternal(in: ObjectInput): Unit = {
- isDense = in.readBoolean()
- minKey = in.readLong()
- maxKey = in.readLong()
- numKeys = in.readLong
- numValues = in.readLong()
+ private def read(
+ readBoolean: () => Boolean,
+ readLong: () => Long,
+ readBuffer: (Array[Byte], Int, Int) => Unit): Unit = {
+ isDense = readBoolean()
+ minKey = readLong()
+ maxKey = readLong()
+ numKeys = readLong()
+ numValues = readLong()
- val length = in.readLong().toInt
+ val length = readLong().toInt
mask = length - 2
- array = readLongArray(in, length)
- val pageLength = in.readLong().toInt
- page = readLongArray(in, pageLength)
+ array = readLongArray(readBuffer, length)
+ val pageLength = readLong().toInt
+ page = readLongArray(readBuffer, pageLength)
+ }
+
+ override def readExternal(in: ObjectInput): Unit = {
+ read(in.readBoolean, in.readLong, in.readFully)
+ }
+
+ override def read(kryo: Kryo, in: Input): Unit = {
+ read(in.readBoolean, in.readLong, in.readBytes)
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala
index 3ee25c0996..9826a64fe2 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala
@@ -21,6 +21,7 @@ import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream,
import org.apache.spark.{SparkConf, SparkFunSuite}
import org.apache.spark.memory.{StaticMemoryManager, TaskMemoryManager}
+import org.apache.spark.serializer.KryoSerializer
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.test.SharedSQLContext
@@ -151,6 +152,40 @@ class HashedRelationSuite extends SparkFunSuite with SharedSQLContext {
}
}
+ test("Spark-14521") {
+ val ser = new KryoSerializer(
+ (new SparkConf).set("spark.kryo.referenceTracking", "false")).newInstance()
+ val key = Seq(BoundReference(0, IntegerType, false))
+
+ // Testing Kryo serialization of HashedRelation
+ val unsafeProj = UnsafeProjection.create(
+ Seq(BoundReference(0, IntegerType, false), BoundReference(1, IntegerType, true)))
+ val rows = (0 until 100).map(i => unsafeProj(InternalRow(i, i + 1)).copy())
+ val longRelation = LongHashedRelation(rows.iterator ++ rows.iterator, key, 100, mm)
+ val longRelation2 = ser.deserialize[LongHashedRelation](ser.serialize(longRelation))
+ (0 until 100).foreach { i =>
+ val rows = longRelation2.get(i).toArray
+ assert(rows.length === 2)
+ assert(rows(0).getInt(0) === i)
+ assert(rows(0).getInt(1) === i + 1)
+ assert(rows(1).getInt(0) === i)
+ assert(rows(1).getInt(1) === i + 1)
+ }
+
+ // Testing Kryo serialization of UnsafeHashedRelation
+ val unsafeHashed = UnsafeHashedRelation(rows.iterator, key, 1, mm)
+ val os = new ByteArrayOutputStream()
+ val out = new ObjectOutputStream(os)
+ unsafeHashed.asInstanceOf[UnsafeHashedRelation].writeExternal(out)
+ out.flush()
+ val unsafeHashed2 = ser.deserialize[UnsafeHashedRelation](ser.serialize(unsafeHashed))
+ val os2 = new ByteArrayOutputStream()
+ val out2 = new ObjectOutputStream(os2)
+ unsafeHashed2.writeExternal(out2)
+ out2.flush()
+ assert(java.util.Arrays.equals(os.toByteArray, os2.toByteArray))
+ }
+
// This test require 4G heap to run, should run it manually
ignore("build HashedRelation that is larger than 1G") {
val unsafeProj = UnsafeProjection.create(