path: root/sql
author     Davies Liu <davies@databricks.com>    2016-05-19 11:45:18 -0700
committer  Davies Liu <davies.liu@gmail.com>     2016-05-19 11:45:18 -0700
commit     9308bf119204015c8733fab0c2aef70ff2e41d74 (patch)
tree       b9ea76cfa10517918ef991f6371a53086bf19950 /sql
parent     31f63ac25da43746fdef2a9477f6a79ac046112f (diff)
download   spark-9308bf119204015c8733fab0c2aef70ff2e41d74.tar.gz
           spark-9308bf119204015c8733fab0c2aef70ff2e41d74.tar.bz2
           spark-9308bf119204015c8733fab0c2aef70ff2e41d74.zip
[SPARK-15390] fix broadcast with 100 millions rows
## What changes were proposed in this pull request?

When broadcasting a table with more than 100 million rows (which ideally should not happen), the calculation of the required memory size overflows. This PR fixes the overflow by converting the value to Long when computing the amount of memory needed. It also adds more checks during broadcast so that reasonable error messages are shown.

## How was this patch tested?

Added a test.

Author: Davies Liu <davies@databricks.com>

Closes #13182 from davies/fix_broadcast.
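To make the failure mode concrete, here is a minimal, standalone Scala sketch (not part of the patch) that mirrors the sizing arithmetic changed in LongToUnsafeRowMap.init below: once the capacity rounds up to 2^27 slots or more, the Int expression n * 2 * 8 wraps, while the Long form gives the intended byte count.

object OverflowSketch {
  def main(args: Array[String]): Unit = {
    val capacity = 100000000                // ~100 million rows, as in the report
    var n = 1
    while (n < capacity) n *= 2             // rounds up to 2^27 = 134217728 slots
    val bytesAsInt  = n * 2 * 8             // Int arithmetic wraps to -2147483648
    val bytesAsLong = n * 2L * 8            // Long arithmetic: 2147483648 bytes (2 GB)
    println(s"Int: $bytesAsInt, Long: $bytesAsLong")
  }
}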
Diffstat (limited to 'sql')
-rw-r--r--   sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala   13
-rw-r--r--   sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala               5
-rw-r--r--   sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala         15
3 files changed, 29 insertions, 4 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala
index b6ecd3cb06..d3081ba7ac 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.exchange
import scala.concurrent.{ExecutionContext, Future}
import scala.concurrent.duration._
-import org.apache.spark.broadcast
+import org.apache.spark.{broadcast, SparkException}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
@@ -72,9 +72,18 @@ case class BroadcastExchangeExec(
val beforeCollect = System.nanoTime()
// Note that we use .executeCollect() because we don't want to convert data to Scala types
val input: Array[InternalRow] = child.executeCollect()
+ if (input.length >= 512000000) {
+ throw new SparkException(
+ s"Cannot broadcast the table with more than 512 millions rows: ${input.length} rows")
+ }
val beforeBuild = System.nanoTime()
longMetric("collectTime") += (beforeBuild - beforeCollect) / 1000000
- longMetric("dataSize") += input.map(_.asInstanceOf[UnsafeRow].getSizeInBytes.toLong).sum
+ val dataSize = input.map(_.asInstanceOf[UnsafeRow].getSizeInBytes.toLong).sum
+ longMetric("dataSize") += dataSize
+ if (dataSize >= (8L << 30)) {
+ throw new SparkException(
+ s"Cannot broadcast the table that is larger than 8GB: ${dataSize >> 30} GB")
+ }
// Construct and broadcast the relation.
val relation = mode.transform(input)
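A companion sketch for the hunk above (the row sizes here are invented for illustration): UnsafeRow.getSizeInBytes returns an Int per row, so summing without the .toLong widening can wrap once the total passes 2 GB, and the new guard compares the Long total against 8L << 30 bytes.

object BroadcastGuardSketch {
  def main(args: Array[String]): Unit = {
    val rowSizes = Array.fill(1100)(2 * 1024 * 1024)  // pretend per-row sizes: 1100 rows of 2 MB
    val sumAsInt  = rowSizes.sum                      // wraps past Int.MaxValue to -1988100096
    val sumAsLong = rowSizes.map(_.toLong).sum        // 2306867200 bytes, the real total
    println(s"Int sum: $sumAsInt, Long sum: $sumAsLong")
    val limit = 8L << 30                              // 8589934592 bytes, i.e. the 8 GB threshold
    println(s"over the broadcast limit: ${sumAsLong >= limit}")
  }
}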
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
index cb41457b66..cd6b97a855 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
@@ -410,9 +410,10 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap
private def init(): Unit = {
if (mm != null) {
+ require(capacity < 512000000, "Cannot broadcast more than 512 millions rows")
var n = 1
while (n < capacity) n *= 2
- ensureAcquireMemory(n * 2 * 8 + (1 << 20))
+ ensureAcquireMemory(n * 2L * 8 + (1 << 20))
array = new Array[Long](n * 2)
mask = n * 2 - 2
page = new Array[Long](1 << 17) // 1M bytes
@@ -788,7 +789,7 @@ private[joins] object LongHashedRelation {
sizeEstimate: Int,
taskMemoryManager: TaskMemoryManager): LongHashedRelation = {
- val map: LongToUnsafeRowMap = new LongToUnsafeRowMap(taskMemoryManager, sizeEstimate)
+ val map = new LongToUnsafeRowMap(taskMemoryManager, sizeEstimate)
val keyGenerator = UnsafeProjection.create(key)
// Create a mapping of key -> rows
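For a sense of scale at the new cap in the hunks above: just under 512 million rows, the power-of-two rounding in init() already implies an Array[Long] of 2^30 entries, roughly 8 GB for the key/offset array alone, and the old Int expression n * 2 * 8 would have wrapped long before reaching it. A standalone sketch of the arithmetic (not from the patch):

object CapacitySketch {
  def main(args: Array[String]): Unit = {
    val capacity = 512000000 - 1        // just under the new require() threshold
    var n = 1
    while (n < capacity) n *= 2         // rounds up to 2^29 = 536870912 slots
    val arrayLen   = n * 2              // 1073741824 Long entries in the backing array
    val arrayBytes = n * 2L * 8         // 8589934592 bytes, about 8 GB, before any row pages
    println(s"slots=$n, arrayLen=$arrayLen, arrayBytes=$arrayBytes")
  }
}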
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala
index b7b08dc4b1..a5b56541c9 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala
@@ -212,4 +212,19 @@ class HashedRelationSuite extends SparkFunSuite with SharedSQLContext {
assert(longRelation.estimatedSize > (2L << 30))
longRelation.close()
}
+
+ test("build HashedRelation with more than 100 millions rows") {
+ val unsafeProj = UnsafeProjection.create(
+ Seq(BoundReference(0, IntegerType, false),
+ BoundReference(1, StringType, true)))
+ val unsafeRow = unsafeProj(InternalRow(0, UTF8String.fromString(" " * 100)))
+ val key = Seq(BoundReference(0, IntegerType, false))
+ val rows = (0 until (1 << 10)).iterator.map { i =>
+ unsafeRow.setInt(0, i % 1000000)
+ unsafeRow.setInt(1, i)
+ unsafeRow
+ }
+ val m = LongHashedRelation(rows, key, 100 << 20, mm)
+ m.close()
+ }
}