diff options
author | Cheng Lian <lian@databricks.com> | 2016-11-29 09:01:03 +0800 |
---|---|---|
committer | Wenchen Fan <wenchen@databricks.com> | 2016-11-29 09:01:03 +0800 |
commit | 2e809903d459b5b5aa6fd882b5c4a0c915af4d43 (patch) | |
tree | 62a9386361887183961ac523020516279e23eab1 /sql/hive/src | |
parent | 05f7c6ffab2a6be548375cd624dc27092677232f (diff) | |
download | spark-2e809903d459b5b5aa6fd882b5c4a0c915af4d43.tar.gz spark-2e809903d459b5b5aa6fd882b5c4a0c915af4d43.tar.bz2 spark-2e809903d459b5b5aa6fd882b5c4a0c915af4d43.zip |
[SPARK-18403][SQL] Fix unsafe data false sharing issue in ObjectHashAggregateExec
## What changes were proposed in this pull request?
This PR fixes a random OOM issue occurred while running `ObjectHashAggregateSuite`.
This issue can be steadily reproduced under the following conditions:
1. The aggregation must be evaluated using `ObjectHashAggregateExec`;
2. There must be an input column whose data type involves `ArrayType` (an input column of `MapType` may even cause SIGSEGV);
3. Sort-based aggregation fallback must be triggered during evaluation.
The root cause is that while falling back to sort-based aggregation, we must sort and feed already evaluated partial aggregation buffers living in the hash map to the sort-based aggregator using an external sorter. However, the underlying mutable byte buffer of `UnsafeRow`s produced by the iterator of the external sorter is reused and may get overwritten when the iterator steps forward. After the last entry is consumed, the byte buffer points to a block of uninitialized memory filled by `5a`. Therefore, while reading an `UnsafeArrayData` out of the `UnsafeRow`, `5a5a5a5a` is treated as array size and triggers a memory allocation for a ridiculously large array and immediately blows up the JVM with an OOM.
To fix this issue, we only need to add `.copy()` accordingly.
## How was this patch tested?
New regression test case added in `ObjectHashAggregateSuite`.
Author: Cheng Lian <lian@databricks.com>
Closes #15976 from liancheng/investigate-oom.
Diffstat (limited to 'sql/hive/src')
-rw-r--r-- | sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ObjectHashAggregateSuite.scala | 164 |
1 files changed, 92 insertions, 72 deletions
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ObjectHashAggregateSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ObjectHashAggregateSuite.scala index b7f91d8c3a..9a8d4498bb 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ObjectHashAggregateSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ObjectHashAggregateSuite.scala @@ -205,23 +205,19 @@ class ObjectHashAggregateSuite // A TypedImperativeAggregate function val typed = percentile_approx($"c0", 0.5) - // A Hive UDAF without partial aggregation support - val withoutPartial = function("hive_max", $"c1") - // A Spark SQL native aggregate function with partial aggregation support that can be executed // by the Tungsten `HashAggregateExec` - val withPartialUnsafe = max($"c2") + val withPartialUnsafe = max($"c1") // A Spark SQL native aggregate function with partial aggregation support that can only be // executed by the Tungsten `HashAggregateExec` - val withPartialSafe = max($"c3") + val withPartialSafe = max($"c2") // A Spark SQL native distinct aggregate function - val withDistinct = countDistinct($"c4") + val withDistinct = countDistinct($"c3") val allAggs = Seq( "typed" -> typed, - "without partial" -> withoutPartial, "with partial + unsafe" -> withPartialUnsafe, "with partial + safe" -> withPartialSafe, "with distinct" -> withDistinct @@ -276,10 +272,9 @@ class ObjectHashAggregateSuite // Generates a random schema for the randomized data generator val schema = new StructType() .add("c0", numericTypes(random.nextInt(numericTypes.length)), nullable = true) - .add("c1", orderedTypes(random.nextInt(orderedTypes.length)), nullable = true) - .add("c2", fixedLengthTypes(random.nextInt(fixedLengthTypes.length)), nullable = true) - .add("c3", varLenOrderedTypes(random.nextInt(varLenOrderedTypes.length)), nullable = true) - .add("c4", allTypes(random.nextInt(allTypes.length)), nullable = true) + .add("c1", fixedLengthTypes(random.nextInt(fixedLengthTypes.length)), nullable = true) + .add("c2", varLenOrderedTypes(random.nextInt(varLenOrderedTypes.length)), nullable = true) + .add("c3", allTypes(random.nextInt(allTypes.length)), nullable = true) logInfo( s"""Using the following random schema to generate all the randomized aggregation tests: @@ -325,70 +320,67 @@ class ObjectHashAggregateSuite // Currently Spark SQL doesn't support evaluating distinct aggregate function together // with aggregate functions without partial aggregation support. - if (!(aggs.contains(withoutPartial) && aggs.contains(withDistinct))) { - // TODO Re-enables them after fixing SPARK-18403 - ignore( - s"randomized aggregation test - " + - s"${names.mkString("[", ", ", "]")} - " + - s"${if (withGroupingKeys) "with" else "without"} grouping keys - " + - s"with ${if (emptyInput) "empty" else "non-empty"} input" - ) { - var expected: Seq[Row] = null - var actual1: Seq[Row] = null - var actual2: Seq[Row] = null - - // Disables `ObjectHashAggregateExec` to obtain a standard answer - withSQLConf(SQLConf.USE_OBJECT_HASH_AGG.key -> "false") { - val aggDf = doAggregation(df) - - if (aggs.intersect(Seq(withoutPartial, withPartialSafe, typed)).nonEmpty) { - assert(containsSortAggregateExec(aggDf)) - assert(!containsObjectHashAggregateExec(aggDf)) - assert(!containsHashAggregateExec(aggDf)) - } else { - assert(!containsSortAggregateExec(aggDf)) - assert(!containsObjectHashAggregateExec(aggDf)) - assert(containsHashAggregateExec(aggDf)) - } - - expected = aggDf.collect().toSeq + test( + s"randomized aggregation test - " + + s"${names.mkString("[", ", ", "]")} - " + + s"${if (withGroupingKeys) "with" else "without"} grouping keys - " + + s"with ${if (emptyInput) "empty" else "non-empty"} input" + ) { + var expected: Seq[Row] = null + var actual1: Seq[Row] = null + var actual2: Seq[Row] = null + + // Disables `ObjectHashAggregateExec` to obtain a standard answer + withSQLConf(SQLConf.USE_OBJECT_HASH_AGG.key -> "false") { + val aggDf = doAggregation(df) + + if (aggs.intersect(Seq(withPartialSafe, typed)).nonEmpty) { + assert(containsSortAggregateExec(aggDf)) + assert(!containsObjectHashAggregateExec(aggDf)) + assert(!containsHashAggregateExec(aggDf)) + } else { + assert(!containsSortAggregateExec(aggDf)) + assert(!containsObjectHashAggregateExec(aggDf)) + assert(containsHashAggregateExec(aggDf)) } - // Enables `ObjectHashAggregateExec` - withSQLConf(SQLConf.USE_OBJECT_HASH_AGG.key -> "true") { - val aggDf = doAggregation(df) - - if (aggs.contains(typed) && !aggs.contains(withoutPartial)) { - assert(!containsSortAggregateExec(aggDf)) - assert(containsObjectHashAggregateExec(aggDf)) - assert(!containsHashAggregateExec(aggDf)) - } else if (aggs.intersect(Seq(withoutPartial, withPartialSafe)).nonEmpty) { - assert(containsSortAggregateExec(aggDf)) - assert(!containsObjectHashAggregateExec(aggDf)) - assert(!containsHashAggregateExec(aggDf)) - } else { - assert(!containsSortAggregateExec(aggDf)) - assert(!containsObjectHashAggregateExec(aggDf)) - assert(containsHashAggregateExec(aggDf)) - } - - // Disables sort-based aggregation fallback (we only generate 50 rows, so 100 is - // big enough) to obtain a result to be checked. - withSQLConf(SQLConf.OBJECT_AGG_SORT_BASED_FALLBACK_THRESHOLD.key -> "100") { - actual1 = aggDf.collect().toSeq - } - - // Enables sort-based aggregation fallback to obtain another result to be checked. - withSQLConf(SQLConf.OBJECT_AGG_SORT_BASED_FALLBACK_THRESHOLD.key -> "3") { - // Here we are not reusing `aggDf` because the physical plan in `aggDf` is - // cached and won't be re-planned using the new fallback threshold. - actual2 = doAggregation(df).collect().toSeq - } + expected = aggDf.collect().toSeq + } + + // Enables `ObjectHashAggregateExec` + withSQLConf(SQLConf.USE_OBJECT_HASH_AGG.key -> "true") { + val aggDf = doAggregation(df) + + if (aggs.contains(typed)) { + assert(!containsSortAggregateExec(aggDf)) + assert(containsObjectHashAggregateExec(aggDf)) + assert(!containsHashAggregateExec(aggDf)) + } else if (aggs.contains(withPartialSafe)) { + assert(containsSortAggregateExec(aggDf)) + assert(!containsObjectHashAggregateExec(aggDf)) + assert(!containsHashAggregateExec(aggDf)) + } else { + assert(!containsSortAggregateExec(aggDf)) + assert(!containsObjectHashAggregateExec(aggDf)) + assert(containsHashAggregateExec(aggDf)) } - doubleSafeCheckRows(actual1, expected, 1e-4) - doubleSafeCheckRows(actual2, expected, 1e-4) + // Disables sort-based aggregation fallback (we only generate 50 rows, so 100 is + // big enough) to obtain a result to be checked. + withSQLConf(SQLConf.OBJECT_AGG_SORT_BASED_FALLBACK_THRESHOLD.key -> "100") { + actual1 = aggDf.collect().toSeq + } + + // Enables sort-based aggregation fallback to obtain another result to be checked. + withSQLConf(SQLConf.OBJECT_AGG_SORT_BASED_FALLBACK_THRESHOLD.key -> "3") { + // Here we are not reusing `aggDf` because the physical plan in `aggDf` is + // cached and won't be re-planned using the new fallback threshold. + actual2 = doAggregation(df).collect().toSeq + } } + + doubleSafeCheckRows(actual1, expected, 1e-4) + doubleSafeCheckRows(actual2, expected, 1e-4) } } } @@ -425,7 +417,35 @@ class ObjectHashAggregateSuite } } - private def function(name: String, args: Column*): Column = { - Column(UnresolvedFunction(FunctionIdentifier(name), args.map(_.expr), isDistinct = false)) + test("SPARK-18403 Fix unsafe data false sharing issue in ObjectHashAggregateExec") { + // SPARK-18403: An unsafe data false sharing issue may trigger OOM / SIGSEGV when evaluating + // certain aggregate functions. To reproduce this issue, the following conditions must be + // met: + // + // 1. The aggregation must be evaluated using `ObjectHashAggregateExec`; + // 2. There must be an input column whose data type involves `ArrayType` or `MapType`; + // 3. Sort-based aggregation fallback must be triggered during evaluation. + withSQLConf( + SQLConf.USE_OBJECT_HASH_AGG.key -> "true", + SQLConf.OBJECT_AGG_SORT_BASED_FALLBACK_THRESHOLD.key -> "1" + ) { + checkAnswer( + Seq + .fill(2)(Tuple1(Array.empty[Int])) + .toDF("c0") + .groupBy(lit(1)) + .agg(typed_count($"c0"), max($"c0")), + Row(1, 2, Array.empty[Int]) + ) + + checkAnswer( + Seq + .fill(2)(Tuple1(Map.empty[Int, Int])) + .toDF("c0") + .groupBy(lit(1)) + .agg(typed_count($"c0"), first($"c0")), + Row(1, 2, Map.empty[Int, Int]) + ) + } } } |