author    Davies Liu <davies@databricks.com>    2016-04-27 21:23:40 -0700
committer Davies Liu <davies.liu@gmail.com>     2016-04-27 21:23:40 -0700
commit    ae4e3def5eacb8e383a3535e6c685897fd1aaf4c (patch)
tree      47b2a6bf17de619e8166fce6359e183d0f216579 /sql
parent    f5da592fc63b8d3bc09d49c196d6c5d98cd2a013 (diff)
[SPARK-14961] Build HashedRelation larger than 1G
## What changes were proposed in this pull request?

Currently, LongToUnsafeRowMap uses a byte array as the underlying page, which can't be larger than 1G. This PR improves LongToUnsafeRowMap to scale up to 8G bytes by using an array of Long instead of an array of Byte.

## How was this patch tested?

Manually ran a test to confirm that both UnsafeHashedRelation and LongHashedRelation could build a map larger than 2G.

Author: Davies Liu <davies@databricks.com>

Closes #12740 from davies/larger_broadcast.
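The core of the change is the address encoding: each value's location is packed into one Long, with the upper bits holding the byte offset into the page and the lower 28 bits holding the row size. Moving the page from Array[Byte] to Array[Long] keeps the array index within Java's Int limit while covering 8 GB, and shrinking the size field from 32 to 28 bits frees 36 bits for offsets. A minimal standalone sketch of that encoding, mirroring the patch's SIZE_BITS/SIZE_MASK constants:

    // Pack/unpack (offset, size) into a single Long, as LongToUnsafeRowMap does.
    // offset is a byte offset into the Long page (36 bits, up to 64 GB);
    // size is the row length in bytes (28 bits, up to 256 MB).
    object PackedAddress {
      val SIZE_BITS = 28
      val SIZE_MASK = 0xfffffff

      def encode(offset: Long, size: Int): Long = {
        require(size < (1 << SIZE_BITS), "Does not support row that is larger than 256M")
        (offset << SIZE_BITS) | size
      }

      def offsetOf(address: Long): Long = address >>> SIZE_BITS
      def sizeOf(address: Long): Int = (address & SIZE_MASK).toInt
    }

For example, encode(16L, 128) records a 128-byte row at byte offset 16, and offsetOf/sizeOf recover exactly those values.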
Diffstat (limited to 'sql')
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala       | 134
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala  |  30
2 files changed, 107 insertions, 57 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
index 0427db4e3b..b280c76c70 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
@@ -173,8 +173,8 @@ private[joins] class UnsafeHashedRelation(
override def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException {
out.writeInt(numFields)
// TODO: move these into BytesToBytesMap
- out.writeInt(binaryMap.numKeys())
- out.writeInt(binaryMap.numValues())
+ out.writeLong(binaryMap.numKeys())
+ out.writeLong(binaryMap.numValues())
var buffer = new Array[Byte](64)
def write(base: Object, offset: Long, length: Int): Unit = {
@@ -199,8 +199,8 @@ private[joins] class UnsafeHashedRelation(
override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException {
numFields = in.readInt()
resultRow = new UnsafeRow(numFields)
- val nKeys = in.readInt()
- val nValues = in.readInt()
+ val nKeys = in.readLong()
+ val nValues = in.readLong()
// This is used in Broadcast, shared by multiple tasks, so we use on-heap memory
// TODO(josh): This needs to be revisited before we merge this patch; making this change now
// so that tests compile:
@@ -345,16 +345,20 @@ private[execution] final class LongToUnsafeRowMap(var mm: TaskMemoryManager, cap
// The page to store all bytes of UnsafeRow and the pointer to next rows.
// [row1][pointer1] [row2][pointer2]
- private var page: Array[Byte] = null
+ private var page: Array[Long] = null
// Current write cursor in the page.
- private var cursor = Platform.BYTE_ARRAY_OFFSET
+ private var cursor: Long = Platform.LONG_ARRAY_OFFSET
+
+ // The number of bits for size in address
+ private val SIZE_BITS = 28
+ private val SIZE_MASK = 0xfffffff
// The total number of values of all keys.
- private var numValues = 0
+ private var numValues = 0L
// The number of unique keys.
- private var numKeys = 0
+ private var numKeys = 0L
// needed by serializer
def this() = {
@@ -390,7 +394,7 @@ private[execution] final class LongToUnsafeRowMap(var mm: TaskMemoryManager, cap
acquireMemory(n * 2 * 8 + (1 << 20))
array = new Array[Long](n * 2)
mask = n * 2 - 2
- page = new Array[Byte](1 << 20) // 1M bytes
+ page = new Array[Long](1 << 17) // 1M bytes
}
}
@@ -406,7 +410,7 @@ private[execution] final class LongToUnsafeRowMap(var mm: TaskMemoryManager, cap
/**
* Returns total memory consumption.
*/
- def getTotalMemoryConsumption: Long = array.length * 8 + page.length
+ def getTotalMemoryConsumption: Long = array.length * 8L + page.length * 8L
/**
* Returns the first slot of array that store the keys (sparse mode).
@@ -422,8 +426,8 @@ private[execution] final class LongToUnsafeRowMap(var mm: TaskMemoryManager, cap
private def nextSlot(pos: Int): Int = (pos + 2) & mask
private def getRow(address: Long, resultRow: UnsafeRow): UnsafeRow = {
- val offset = address >>> 32
- val size = address & 0xffffffffL
+ val offset = address >>> SIZE_BITS
+ val size = address & SIZE_MASK
resultRow.pointTo(page, offset, size.toInt)
resultRow
}
@@ -450,15 +454,15 @@ private[execution] final class LongToUnsafeRowMap(var mm: TaskMemoryManager, cap
}
/**
- * Returns an interator of UnsafeRow for multiple linked values.
+ * Returns an iterator of UnsafeRow for multiple linked values.
*/
private def valueIter(address: Long, resultRow: UnsafeRow): Iterator[UnsafeRow] = {
new Iterator[UnsafeRow] {
var addr = address
override def hasNext: Boolean = addr != 0
override def next(): UnsafeRow = {
- val offset = addr >>> 32
- val size = addr & 0xffffffffL
+ val offset = addr >>> SIZE_BITS
+ val size = addr & SIZE_MASK
resultRow.pointTo(page, offset, size.toInt)
addr = Platform.getLong(page, offset + size)
resultRow
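Multiple values under one key form a linked list inside the page: each row is followed by an 8-byte slot holding the encoded address of the next row, and 0 terminates the chain. A toy model of that walk (an illustration, not the patch's code), using element indices instead of byte offsets to keep it short:

    // Toy page: each entry is [one long of payload][address of next entry],
    // where an address is (offset << 28) | size and 0 marks the end.
    def valueChain(page: Array[Long], first: Long): Iterator[Long] =
      new Iterator[Long] {
        private var addr = first
        def hasNext: Boolean = addr != 0
        def next(): Long = {
          val offset = (addr >>> 28).toInt
          val payload = page(offset)
          addr = page(offset + 1)   // follow the trailing pointer
          payload
        }
      }

    // 42 at slot 0 links to 7 at slot 2, which ends the chain.
    val page = Array[Long](42L, (2L << 28) | 1, 7L, 0L)
    valueChain(page, (0L << 28) | 1).toList   // List(42, 7)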
@@ -491,6 +495,11 @@ private[execution] final class LongToUnsafeRowMap(var mm: TaskMemoryManager, cap
* Appends the key and row into this map.
*/
def append(key: Long, row: UnsafeRow): Unit = {
+ val sizeInBytes = row.getSizeInBytes
+ if (sizeInBytes >= (1 << SIZE_BITS)) {
+ sys.error("Does not support row that is larger than 256M")
+ }
+
if (key < minKey) {
minKey = key
}
@@ -499,16 +508,17 @@ private[execution] final class LongToUnsafeRowMap(var mm: TaskMemoryManager, cap
}
// There is 8 bytes for the pointer to next value
- if (cursor + 8 + row.getSizeInBytes > page.length + Platform.BYTE_ARRAY_OFFSET) {
+ if (cursor + 8 + row.getSizeInBytes > page.length * 8L + Platform.LONG_ARRAY_OFFSET) {
val used = page.length
- if (used * 2L > (1L << 31)) {
- sys.error("Can't allocate a page that is larger than 2G")
+ if (used >= (1 << 30)) {
+ sys.error("Can not build a HashedRelation that is larger than 8G")
}
- acquireMemory(used * 2)
- val newPage = new Array[Byte](used * 2)
- System.arraycopy(page, 0, newPage, 0, cursor - Platform.BYTE_ARRAY_OFFSET)
+ acquireMemory(used * 8L * 2)
+ val newPage = new Array[Long](used * 2)
+ Platform.copyMemory(page, Platform.LONG_ARRAY_OFFSET, newPage, Platform.LONG_ARRAY_OFFSET,
+ cursor - Platform.LONG_ARRAY_OFFSET)
page = newPage
- freeMemory(used)
+ freeMemory(used * 8)
}
// copy the bytes of UnsafeRow
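The growth path above doubles the Long page until it would pass 2^30 elements, which is where the 8G error comes from. The limits implied by the constants (our arithmetic, derived from the patch):

    // Capacity math behind the 8G and 256M errors (back-of-the-envelope).
    val maxPageElems = 1 << 30             // largest Array[Long] page allowed
    val maxPageBytes = maxPageElems * 8L   // 8589934592 bytes = 8 GB
    val maxRowBytes  = 1 << 28             // 268435456 bytes = 256 MB (28-bit size field)
    // The page starts at 1 << 17 longs (1 MB) and can double 13 times before hitting 8 GB.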
@@ -518,7 +528,7 @@ private[execution] final class LongToUnsafeRowMap(var mm: TaskMemoryManager, cap
Platform.putLong(page, cursor, 0)
cursor += 8
numValues += 1
- updateIndex(key, (offset.toLong << 32) | row.getSizeInBytes)
+ updateIndex(key, (offset.toLong << SIZE_BITS) | row.getSizeInBytes)
}
/**
@@ -536,11 +546,17 @@ private[execution] final class LongToUnsafeRowMap(var mm: TaskMemoryManager, cap
numKeys += 1
if (numKeys * 4 > array.length) {
// reach half of the capacity
- growArray()
+ if (array.length < (1 << 30)) {
+ // Cannot allocate an array with 2G elements
+ growArray()
+ } else if (numKeys > array.length / 2 * 0.75) {
+ // The fill ratio should be less than 0.75
+ sys.error("Cannot build HashedRelation with more than 1/3 billions unique keys")
+ }
}
} else {
// there are some values for this key, put the address in the front of them.
- val pointer = (address >>> 32) + (address & 0xffffffffL)
+ val pointer = (address >>> SIZE_BITS) + (address & SIZE_MASK)
Platform.putLong(page, pointer, array(pos + 1))
array(pos + 1) = address
}
@@ -550,7 +566,7 @@ private[execution] final class LongToUnsafeRowMap(var mm: TaskMemoryManager, cap
var old_array = array
val n = array.length
numKeys = 0
- acquireMemory(n * 2 * 8)
+ acquireMemory(n * 2 * 8L)
array = new Array[Long](n * 2)
mask = n * 2 - 2
var i = 0
@@ -599,7 +615,7 @@ private[execution] final class LongToUnsafeRowMap(var mm: TaskMemoryManager, cap
*/
def free(): Unit = {
if (page != null) {
- freeMemory(page.length)
+ freeMemory(page.length * 8)
page = null
}
if (array != null) {
@@ -608,52 +624,58 @@ private[execution] final class LongToUnsafeRowMap(var mm: TaskMemoryManager, cap
}
}
+ private def writeLongArray(out: ObjectOutput, arr: Array[Long], len: Int): Unit = {
+ val buffer = new Array[Byte](4 << 10)
+ var offset: Long = Platform.LONG_ARRAY_OFFSET
+ val end = len * 8L + Platform.LONG_ARRAY_OFFSET
+ while (offset < end) {
+ val size = Math.min(buffer.length, (end - offset).toInt)
+ Platform.copyMemory(arr, offset, buffer, Platform.BYTE_ARRAY_OFFSET, size)
+ out.write(buffer, 0, size)
+ offset += size
+ }
+ }
+
override def writeExternal(out: ObjectOutput): Unit = {
out.writeBoolean(isDense)
out.writeLong(minKey)
out.writeLong(maxKey)
- out.writeInt(numKeys)
- out.writeInt(numValues)
+ out.writeLong(numKeys)
+ out.writeLong(numValues)
+
+ out.writeLong(array.length)
+ writeLongArray(out, array, array.length)
+ val used = ((cursor - Platform.LONG_ARRAY_OFFSET) / 8).toInt
+ out.writeLong(used)
+ writeLongArray(out, page, used)
+ }
- out.writeInt(array.length)
+ private def readLongArray(in: ObjectInput, length: Int): Array[Long] = {
+ val array = new Array[Long](length)
val buffer = new Array[Byte](4 << 10)
- var offset = Platform.LONG_ARRAY_OFFSET
- val end = array.length * 8 + Platform.LONG_ARRAY_OFFSET
+ var offset: Long = Platform.LONG_ARRAY_OFFSET
+ val end = length * 8L + Platform.LONG_ARRAY_OFFSET
while (offset < end) {
- val size = Math.min(buffer.length, end - offset)
- Platform.copyMemory(array, offset, buffer, Platform.BYTE_ARRAY_OFFSET, size)
- out.write(buffer, 0, size)
+ val size = Math.min(buffer.length, (end - offset).toInt)
+ in.readFully(buffer, 0, size)
+ Platform.copyMemory(buffer, Platform.BYTE_ARRAY_OFFSET, array, offset, size)
offset += size
}
-
- val used = cursor - Platform.BYTE_ARRAY_OFFSET
- out.writeInt(used)
- out.write(page, 0, used)
+ array
}
override def readExternal(in: ObjectInput): Unit = {
isDense = in.readBoolean()
minKey = in.readLong()
maxKey = in.readLong()
- numKeys = in.readInt()
- numValues = in.readInt()
+ numKeys = in.readLong()
+ numValues = in.readLong()
- val length = in.readInt()
- array = new Array[Long](length)
+ val length = in.readLong().toInt
mask = length - 2
- val buffer = new Array[Byte](4 << 10)
- var offset = Platform.LONG_ARRAY_OFFSET
- val end = length * 8 + Platform.LONG_ARRAY_OFFSET
- while (offset < end) {
- val size = Math.min(buffer.length, end - offset)
- in.readFully(buffer, 0, size)
- Platform.copyMemory(buffer, Platform.BYTE_ARRAY_OFFSET, array, offset, size)
- offset += size
- }
-
- val numBytes = in.readInt()
- page = new Array[Byte](numBytes)
- in.readFully(page)
+ array = readLongArray(in, length)
+ val pageLength = in.readLong().toInt
+ page = readLongArray(in, pageLength)
}
}
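Serialization now streams both Long arrays through a small byte buffer rather than writing the page with a single write call. A self-contained sketch of the same chunked round trip using java.nio instead of Spark's Platform (an illustration under that substitution, not the patch's code):

    import java.io.{DataInputStream, DataOutputStream}

    def writeLongs(out: DataOutputStream, arr: Array[Long], len: Int): Unit = {
      val buf = java.nio.ByteBuffer.allocate(4 << 10)   // 4 KB chunks, as in the patch
      var i = 0
      while (i < len) {
        buf.clear()
        val n = math.min(buf.capacity / 8, len - i)     // longs that fit in one chunk
        var j = 0
        while (j < n) { buf.putLong(arr(i + j)); j += 1 }
        out.write(buf.array, 0, n * 8)
        i += n
      }
    }

    def readLongs(in: DataInputStream, len: Int): Array[Long] = {
      val arr = new Array[Long](len)
      var i = 0
      while (i < len) { arr(i) = in.readLong(); i += 1 }
      arr
    }

Writing the element count first, as writeExternal does for array.length and the used portion of the page, lets readExternal allocate the destination arrays before filling them.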
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala
index 371a9ed617..3ee25c0996 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala
@@ -24,8 +24,9 @@ import org.apache.spark.memory.{StaticMemoryManager, TaskMemoryManager}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.test.SharedSQLContext
-import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
+import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
import org.apache.spark.unsafe.map.BytesToBytesMap
+import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.util.collection.CompactBuffer
class HashedRelationSuite extends SparkFunSuite with SharedSQLContext {
@@ -149,4 +150,31 @@ class HashedRelationSuite extends SparkFunSuite with SharedSQLContext {
assert(rows(1).getInt(1) === i + 1)
}
}
+
+ // This test require 4G heap to run, should run it manually
+ ignore("build HashedRelation that is larger than 1G") {
+ val unsafeProj = UnsafeProjection.create(
+ Seq(BoundReference(0, IntegerType, false),
+ BoundReference(1, StringType, true)))
+ val unsafeRow = unsafeProj(InternalRow(0, UTF8String.fromString(" " * 100)))
+ val key = Seq(BoundReference(0, IntegerType, false))
+ val rows = (0 until (1 << 24)).iterator.map { i =>
+ unsafeRow.setInt(0, i % 1000000)
+ unsafeRow.setInt(1, i)
+ unsafeRow
+ }
+
+ val unsafeRelation = UnsafeHashedRelation(rows, key, 1000, mm)
+ assert(unsafeRelation.estimatedSize > (2L << 30))
+ unsafeRelation.close()
+
+ val rows2 = (0 until (1 << 24)).iterator.map { i =>
+ unsafeRow.setInt(0, i % 1000000)
+ unsafeRow.setInt(1, i)
+ unsafeRow
+ }
+ val longRelation = LongHashedRelation(rows2, key, 1000, mm)
+ assert(longRelation.estimatedSize > (2L << 30))
+ longRelation.close()
+ }
}
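Rough sizing for the assertions above (our estimate, not stated in the patch): each test row is an UnsafeRow with an 8-byte null bitset, two 8-byte fixed-width slots, and a 100-byte string padded to 104 bytes, i.e. 128 bytes, plus the map's 8-byte next pointer:

    // ~2.28e9 bytes for 1 << 24 entries, which is why both tests assert > 2G.
    val perEntry = 8 + 2 * 8 + 104 + 8     // bitset + fields + padded string + pointer
    val total    = (1L << 24) * perEntry   // 2281701376 bytes ≈ 2.1 GiB > (2L << 30)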