aboutsummaryrefslogtreecommitdiff
path: root/sql
diff options
context:
space:
mode:
authorCheng Lian <lian@databricks.com>2016-11-29 09:01:03 +0800
committerWenchen Fan <wenchen@databricks.com>2016-11-29 09:01:03 +0800
commit2e809903d459b5b5aa6fd882b5c4a0c915af4d43 (patch)
tree62a9386361887183961ac523020516279e23eab1 /sql
parent05f7c6ffab2a6be548375cd624dc27092677232f (diff)
downloadspark-2e809903d459b5b5aa6fd882b5c4a0c915af4d43.tar.gz
spark-2e809903d459b5b5aa6fd882b5c4a0c915af4d43.tar.bz2
spark-2e809903d459b5b5aa6fd882b5c4a0c915af4d43.zip
[SPARK-18403][SQL] Fix unsafe data false sharing issue in ObjectHashAggregateExec
## What changes were proposed in this pull request? This PR fixes a random OOM issue occurred while running `ObjectHashAggregateSuite`. This issue can be steadily reproduced under the following conditions: 1. The aggregation must be evaluated using `ObjectHashAggregateExec`; 2. There must be an input column whose data type involves `ArrayType` (an input column of `MapType` may even cause SIGSEGV); 3. Sort-based aggregation fallback must be triggered during evaluation. The root cause is that while falling back to sort-based aggregation, we must sort and feed already evaluated partial aggregation buffers living in the hash map to the sort-based aggregator using an external sorter. However, the underlying mutable byte buffer of `UnsafeRow`s produced by the iterator of the external sorter is reused and may get overwritten when the iterator steps forward. After the last entry is consumed, the byte buffer points to a block of uninitialized memory filled by `5a`. Therefore, while reading an `UnsafeArrayData` out of the `UnsafeRow`, `5a5a5a5a` is treated as array size and triggers a memory allocation for a ridiculously large array and immediately blows up the JVM with an OOM. To fix this issue, we only need to add `.copy()` accordingly. ## How was this patch tested? New regression test case added in `ObjectHashAggregateSuite`. Author: Cheng Lian <lian@databricks.com> Closes #15976 from liancheng/investigate-oom.
Diffstat (limited to 'sql')
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationIterator.scala11
-rw-r--r--sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ObjectHashAggregateSuite.scala164
2 files changed, 101 insertions, 74 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationIterator.scala
index 3c7b9ee317..3a7fcf1fa9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationIterator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationIterator.scala
@@ -262,7 +262,9 @@ class SortBasedAggregator(
// Firstly, update the aggregation buffer with input rows.
while (hasNextInput &&
groupingKeyOrdering.compare(inputIterator.getKey, groupingKey) == 0) {
- processRow(result.aggregationBuffer, inputIterator.getValue)
+ // Since `inputIterator.getValue` is an `UnsafeRow` whose underlying buffer will be
+ // overwritten when `inputIterator` steps forward, we need to do a deep copy here.
+ processRow(result.aggregationBuffer, inputIterator.getValue.copy())
hasNextInput = inputIterator.next()
}
@@ -271,7 +273,12 @@ class SortBasedAggregator(
// be called after calling processRow.
while (hasNextAggBuffer &&
groupingKeyOrdering.compare(initialAggBufferIterator.getKey, groupingKey) == 0) {
- mergeAggregationBuffers(result.aggregationBuffer, initialAggBufferIterator.getValue)
+ mergeAggregationBuffers(
+ result.aggregationBuffer,
+ // Since `inputIterator.getValue` is an `UnsafeRow` whose underlying buffer will be
+ // overwritten when `inputIterator` steps forward, we need to do a deep copy here.
+ initialAggBufferIterator.getValue.copy()
+ )
hasNextAggBuffer = initialAggBufferIterator.next()
}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ObjectHashAggregateSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ObjectHashAggregateSuite.scala
index b7f91d8c3a..9a8d4498bb 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ObjectHashAggregateSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ObjectHashAggregateSuite.scala
@@ -205,23 +205,19 @@ class ObjectHashAggregateSuite
// A TypedImperativeAggregate function
val typed = percentile_approx($"c0", 0.5)
- // A Hive UDAF without partial aggregation support
- val withoutPartial = function("hive_max", $"c1")
-
// A Spark SQL native aggregate function with partial aggregation support that can be executed
// by the Tungsten `HashAggregateExec`
- val withPartialUnsafe = max($"c2")
+ val withPartialUnsafe = max($"c1")
// A Spark SQL native aggregate function with partial aggregation support that can only be
// executed by the Tungsten `HashAggregateExec`
- val withPartialSafe = max($"c3")
+ val withPartialSafe = max($"c2")
// A Spark SQL native distinct aggregate function
- val withDistinct = countDistinct($"c4")
+ val withDistinct = countDistinct($"c3")
val allAggs = Seq(
"typed" -> typed,
- "without partial" -> withoutPartial,
"with partial + unsafe" -> withPartialUnsafe,
"with partial + safe" -> withPartialSafe,
"with distinct" -> withDistinct
@@ -276,10 +272,9 @@ class ObjectHashAggregateSuite
// Generates a random schema for the randomized data generator
val schema = new StructType()
.add("c0", numericTypes(random.nextInt(numericTypes.length)), nullable = true)
- .add("c1", orderedTypes(random.nextInt(orderedTypes.length)), nullable = true)
- .add("c2", fixedLengthTypes(random.nextInt(fixedLengthTypes.length)), nullable = true)
- .add("c3", varLenOrderedTypes(random.nextInt(varLenOrderedTypes.length)), nullable = true)
- .add("c4", allTypes(random.nextInt(allTypes.length)), nullable = true)
+ .add("c1", fixedLengthTypes(random.nextInt(fixedLengthTypes.length)), nullable = true)
+ .add("c2", varLenOrderedTypes(random.nextInt(varLenOrderedTypes.length)), nullable = true)
+ .add("c3", allTypes(random.nextInt(allTypes.length)), nullable = true)
logInfo(
s"""Using the following random schema to generate all the randomized aggregation tests:
@@ -325,70 +320,67 @@ class ObjectHashAggregateSuite
// Currently Spark SQL doesn't support evaluating distinct aggregate function together
// with aggregate functions without partial aggregation support.
- if (!(aggs.contains(withoutPartial) && aggs.contains(withDistinct))) {
- // TODO Re-enables them after fixing SPARK-18403
- ignore(
- s"randomized aggregation test - " +
- s"${names.mkString("[", ", ", "]")} - " +
- s"${if (withGroupingKeys) "with" else "without"} grouping keys - " +
- s"with ${if (emptyInput) "empty" else "non-empty"} input"
- ) {
- var expected: Seq[Row] = null
- var actual1: Seq[Row] = null
- var actual2: Seq[Row] = null
-
- // Disables `ObjectHashAggregateExec` to obtain a standard answer
- withSQLConf(SQLConf.USE_OBJECT_HASH_AGG.key -> "false") {
- val aggDf = doAggregation(df)
-
- if (aggs.intersect(Seq(withoutPartial, withPartialSafe, typed)).nonEmpty) {
- assert(containsSortAggregateExec(aggDf))
- assert(!containsObjectHashAggregateExec(aggDf))
- assert(!containsHashAggregateExec(aggDf))
- } else {
- assert(!containsSortAggregateExec(aggDf))
- assert(!containsObjectHashAggregateExec(aggDf))
- assert(containsHashAggregateExec(aggDf))
- }
-
- expected = aggDf.collect().toSeq
+ test(
+ s"randomized aggregation test - " +
+ s"${names.mkString("[", ", ", "]")} - " +
+ s"${if (withGroupingKeys) "with" else "without"} grouping keys - " +
+ s"with ${if (emptyInput) "empty" else "non-empty"} input"
+ ) {
+ var expected: Seq[Row] = null
+ var actual1: Seq[Row] = null
+ var actual2: Seq[Row] = null
+
+ // Disables `ObjectHashAggregateExec` to obtain a standard answer
+ withSQLConf(SQLConf.USE_OBJECT_HASH_AGG.key -> "false") {
+ val aggDf = doAggregation(df)
+
+ if (aggs.intersect(Seq(withPartialSafe, typed)).nonEmpty) {
+ assert(containsSortAggregateExec(aggDf))
+ assert(!containsObjectHashAggregateExec(aggDf))
+ assert(!containsHashAggregateExec(aggDf))
+ } else {
+ assert(!containsSortAggregateExec(aggDf))
+ assert(!containsObjectHashAggregateExec(aggDf))
+ assert(containsHashAggregateExec(aggDf))
}
- // Enables `ObjectHashAggregateExec`
- withSQLConf(SQLConf.USE_OBJECT_HASH_AGG.key -> "true") {
- val aggDf = doAggregation(df)
-
- if (aggs.contains(typed) && !aggs.contains(withoutPartial)) {
- assert(!containsSortAggregateExec(aggDf))
- assert(containsObjectHashAggregateExec(aggDf))
- assert(!containsHashAggregateExec(aggDf))
- } else if (aggs.intersect(Seq(withoutPartial, withPartialSafe)).nonEmpty) {
- assert(containsSortAggregateExec(aggDf))
- assert(!containsObjectHashAggregateExec(aggDf))
- assert(!containsHashAggregateExec(aggDf))
- } else {
- assert(!containsSortAggregateExec(aggDf))
- assert(!containsObjectHashAggregateExec(aggDf))
- assert(containsHashAggregateExec(aggDf))
- }
-
- // Disables sort-based aggregation fallback (we only generate 50 rows, so 100 is
- // big enough) to obtain a result to be checked.
- withSQLConf(SQLConf.OBJECT_AGG_SORT_BASED_FALLBACK_THRESHOLD.key -> "100") {
- actual1 = aggDf.collect().toSeq
- }
-
- // Enables sort-based aggregation fallback to obtain another result to be checked.
- withSQLConf(SQLConf.OBJECT_AGG_SORT_BASED_FALLBACK_THRESHOLD.key -> "3") {
- // Here we are not reusing `aggDf` because the physical plan in `aggDf` is
- // cached and won't be re-planned using the new fallback threshold.
- actual2 = doAggregation(df).collect().toSeq
- }
+ expected = aggDf.collect().toSeq
+ }
+
+ // Enables `ObjectHashAggregateExec`
+ withSQLConf(SQLConf.USE_OBJECT_HASH_AGG.key -> "true") {
+ val aggDf = doAggregation(df)
+
+ if (aggs.contains(typed)) {
+ assert(!containsSortAggregateExec(aggDf))
+ assert(containsObjectHashAggregateExec(aggDf))
+ assert(!containsHashAggregateExec(aggDf))
+ } else if (aggs.contains(withPartialSafe)) {
+ assert(containsSortAggregateExec(aggDf))
+ assert(!containsObjectHashAggregateExec(aggDf))
+ assert(!containsHashAggregateExec(aggDf))
+ } else {
+ assert(!containsSortAggregateExec(aggDf))
+ assert(!containsObjectHashAggregateExec(aggDf))
+ assert(containsHashAggregateExec(aggDf))
}
- doubleSafeCheckRows(actual1, expected, 1e-4)
- doubleSafeCheckRows(actual2, expected, 1e-4)
+ // Disables sort-based aggregation fallback (we only generate 50 rows, so 100 is
+ // big enough) to obtain a result to be checked.
+ withSQLConf(SQLConf.OBJECT_AGG_SORT_BASED_FALLBACK_THRESHOLD.key -> "100") {
+ actual1 = aggDf.collect().toSeq
+ }
+
+ // Enables sort-based aggregation fallback to obtain another result to be checked.
+ withSQLConf(SQLConf.OBJECT_AGG_SORT_BASED_FALLBACK_THRESHOLD.key -> "3") {
+ // Here we are not reusing `aggDf` because the physical plan in `aggDf` is
+ // cached and won't be re-planned using the new fallback threshold.
+ actual2 = doAggregation(df).collect().toSeq
+ }
}
+
+ doubleSafeCheckRows(actual1, expected, 1e-4)
+ doubleSafeCheckRows(actual2, expected, 1e-4)
}
}
}
@@ -425,7 +417,35 @@ class ObjectHashAggregateSuite
}
}
- private def function(name: String, args: Column*): Column = {
- Column(UnresolvedFunction(FunctionIdentifier(name), args.map(_.expr), isDistinct = false))
+ test("SPARK-18403 Fix unsafe data false sharing issue in ObjectHashAggregateExec") {
+ // SPARK-18403: An unsafe data false sharing issue may trigger OOM / SIGSEGV when evaluating
+ // certain aggregate functions. To reproduce this issue, the following conditions must be
+ // met:
+ //
+ // 1. The aggregation must be evaluated using `ObjectHashAggregateExec`;
+ // 2. There must be an input column whose data type involves `ArrayType` or `MapType`;
+ // 3. Sort-based aggregation fallback must be triggered during evaluation.
+ withSQLConf(
+ SQLConf.USE_OBJECT_HASH_AGG.key -> "true",
+ SQLConf.OBJECT_AGG_SORT_BASED_FALLBACK_THRESHOLD.key -> "1"
+ ) {
+ checkAnswer(
+ Seq
+ .fill(2)(Tuple1(Array.empty[Int]))
+ .toDF("c0")
+ .groupBy(lit(1))
+ .agg(typed_count($"c0"), max($"c0")),
+ Row(1, 2, Array.empty[Int])
+ )
+
+ checkAnswer(
+ Seq
+ .fill(2)(Tuple1(Map.empty[Int, Int]))
+ .toDF("c0")
+ .groupBy(lit(1))
+ .agg(typed_count($"c0"), first($"c0")),
+ Row(1, 2, Map.empty[Int, Int])
+ )
+ }
}
}