aboutsummaryrefslogtreecommitdiff
path: root/sql/core/src/test/scala
diff options
context:
space:
mode:
authorTejas Patil <tejasp@fb.com>2017-03-15 20:18:39 +0100
committerHerman van Hovell <hvanhovell@databricks.com>2017-03-15 20:18:39 +0100
commit02c274eaba0a8e7611226e0d4e93d3c36253f4ce (patch)
tree52852e05f5a0b0729a6c92c1d360a6379a52a380 /sql/core/src/test/scala
parent7387126f83dc0489eb1df734bfeba705709b7861 (diff)
downloadspark-02c274eaba0a8e7611226e0d4e93d3c36253f4ce.tar.gz
spark-02c274eaba0a8e7611226e0d4e93d3c36253f4ce.tar.bz2
spark-02c274eaba0a8e7611226e0d4e93d3c36253f4ce.zip
[SPARK-13450] Introduce ExternalAppendOnlyUnsafeRowArray. Change CartesianProductExec, SortMergeJoin, WindowExec to use it
## What issue does this PR address ? Jira: https://issues.apache.org/jira/browse/SPARK-13450 In `SortMergeJoinExec`, rows of the right relation having the same value for a join key are buffered in-memory. In case of skew, this causes OOMs (see comments in SPARK-13450 for more details). Heap dump from a failed job confirms this : https://issues.apache.org/jira/secure/attachment/12846382/heap-dump-analysis.png . While its possible to increase the heap size to workaround, Spark should be resilient to such issues as skews can happen arbitrarily. ## Change proposed in this pull request - Introduces `ExternalAppendOnlyUnsafeRowArray` - It holds `UnsafeRow`s in-memory upto a certain threshold. - After the threshold is hit, it switches to `UnsafeExternalSorter` which enables spilling of the rows to disk. It does NOT sort the data. - Allows iterating the array multiple times. However, any alteration to the array (using `add` or `clear`) will invalidate the existing iterator(s) - `WindowExec` was already using `UnsafeExternalSorter` to support spilling. Changed it to use the new array - Changed `SortMergeJoinExec` to use the new array implementation - NOTE: I have not changed FULL OUTER JOIN to use this new array implementation. Changing that will need more surgery and I will rather put up a separate PR for that once this gets in. - Changed `CartesianProductExec` to use the new array implementation #### Note for reviewers The diff can be divided into 3 parts. My motive behind having all the changes in a single PR was to demonstrate that the API is sane and supports 2 use cases. If reviewing as 3 separate PRs would help, I am happy to make the split. ## How was this patch tested ? #### Unit testing - Added unit tests `ExternalAppendOnlyUnsafeRowArray` to validate all its APIs and access patterns - Added unit test for `SortMergeExec` - with and without spill for inner join, left outer join, right outer join to confirm that the spill threshold config behaves as expected and output is as expected. - This PR touches the scanning logic in `SortMergeExec` for _all_ joins (except FULL OUTER JOIN). However, I expect existing test cases to cover that there is no regression in correctness. - Added unit test for `WindowExec` to check behavior of spilling and correctness of results. #### Stress testing - Confirmed that OOM is gone by running against a production job which used to OOM - Since I cannot share details about prod workload externally, created synthetic data to mimic the issue. Ran before and after the fix to demonstrate the issue and query success with this PR Generating the synthetic data ``` ./bin/spark-shell --driver-memory=6G import org.apache.spark.sql._ val hc = SparkSession.builder.master("local").getOrCreate() hc.sql("DROP TABLE IF EXISTS spark_13450_large_table").collect hc.sql("DROP TABLE IF EXISTS spark_13450_one_row_table").collect val df1 = (0 until 1).map(i => ("10", "100", i.toString, (i * 2).toString)).toDF("i", "j", "str1", "str2") df1.write.format("org.apache.spark.sql.hive.orc.OrcFileFormat").bucketBy(100, "i", "j").sortBy("i", "j").saveAsTable("spark_13450_one_row_table") val df2 = (0 until 3000000).map(i => ("10", "100", i.toString, (i * 2).toString)).toDF("i", "j", "str1", "str2") df2.write.format("org.apache.spark.sql.hive.orc.OrcFileFormat").bucketBy(100, "i", "j").sortBy("i", "j").saveAsTable("spark_13450_large_table") ``` Ran this against trunk VS local build with this PR. OOM repros with trunk and with the fix this query runs fine. ``` ./bin/spark-shell --driver-java-options="-XX:+HeapDumpOnOutOfMemoryError -XX:HeapDumpPath=/tmp/spark.driver.heapdump.hprof" import org.apache.spark.sql._ val hc = SparkSession.builder.master("local").getOrCreate() hc.sql("SET spark.sql.autoBroadcastJoinThreshold=1") hc.sql("SET spark.sql.sortMergeJoinExec.buffer.spill.threshold=10000") hc.sql("DROP TABLE IF EXISTS spark_13450_result").collect hc.sql(""" CREATE TABLE spark_13450_result AS SELECT a.i AS a_i, a.j AS a_j, a.str1 AS a_str1, a.str2 AS a_str2, b.i AS b_i, b.j AS b_j, b.str1 AS b_str1, b.str2 AS b_str2 FROM spark_13450_one_row_table a JOIN spark_13450_large_table b ON a.i=b.i AND a.j=b.j """) ``` ## Performance comparison ### Macro-benchmark I ran a SMB join query over two real world tables (2 trillion rows (40 TB) and 6 million rows (120 GB)). Note that this dataset does not have skew so no spill happened. I saw improvement in CPU time by 2-4% over version without this PR. This did not add up as I was expected some regression. I think allocating array of capacity of 128 at the start (instead of starting with default size 16) is the sole reason for the perf. gain : https://github.com/tejasapatil/spark/blob/SPARK-13450_smb_buffer_oom/sql/core/src/main/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArray.scala#L43 . I could remove that and rerun, but effectively the change will be deployed in this form and I wanted to see the effect of it over large workload. ### Micro-benchmark Two types of benchmarking can be found in `ExternalAppendOnlyUnsafeRowArrayBenchmark`: [A] Comparing `ExternalAppendOnlyUnsafeRowArray` against raw `ArrayBuffer` when all rows fit in-memory and there is no spill ``` Array with 1000 rows: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ ArrayBuffer 7821 / 7941 33.5 29.8 1.0X ExternalAppendOnlyUnsafeRowArray 8798 / 8819 29.8 33.6 0.9X Array with 30000 rows: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ ArrayBuffer 19200 / 19206 25.6 39.1 1.0X ExternalAppendOnlyUnsafeRowArray 19558 / 19562 25.1 39.8 1.0X Array with 100000 rows: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ ArrayBuffer 5949 / 6028 17.2 58.1 1.0X ExternalAppendOnlyUnsafeRowArray 6078 / 6138 16.8 59.4 1.0X ``` [B] Comparing `ExternalAppendOnlyUnsafeRowArray` against raw `UnsafeExternalSorter` when there is spilling of data ``` Spilling with 1000 rows: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ UnsafeExternalSorter 9239 / 9470 28.4 35.2 1.0X ExternalAppendOnlyUnsafeRowArray 8857 / 8909 29.6 33.8 1.0X Spilling with 10000 rows: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ UnsafeExternalSorter 4 / 5 39.3 25.5 1.0X ExternalAppendOnlyUnsafeRowArray 5 / 6 29.8 33.5 0.8X ``` Author: Tejas Patil <tejasp@fb.com> Closes #16909 from tejasapatil/SPARK-13450_smb_buffer_oom.
Diffstat (limited to 'sql/core/src/test/scala')
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala136
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArrayBenchmark.scala233
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArraySuite.scala351
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/execution/SQLWindowFunctionSuite.scala33
4 files changed, 752 insertions, 1 deletions
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
index 2e006735d1..1a66aa85f5 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
@@ -17,6 +17,7 @@
package org.apache.spark.sql
+import scala.collection.mutable.ListBuffer
import scala.language.existentials
import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
@@ -24,7 +25,7 @@ import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.execution.joins._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSQLContext
-
+import org.apache.spark.TestUtils.{assertNotSpilled, assertSpilled}
class JoinSuite extends QueryTest with SharedSQLContext {
import testImplicits._
@@ -604,4 +605,137 @@ class JoinSuite extends QueryTest with SharedSQLContext {
cartesianQueries.foreach(checkCartesianDetection)
}
+
+ test("test SortMergeJoin (without spill)") {
+ withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "1",
+ "spark.sql.sortMergeJoinExec.buffer.spill.threshold" -> Int.MaxValue.toString) {
+
+ assertNotSpilled(sparkContext, "inner join") {
+ checkAnswer(
+ sql("SELECT * FROM testData JOIN testData2 ON key = a where key = 2"),
+ Row(2, "2", 2, 1) :: Row(2, "2", 2, 2) :: Nil
+ )
+ }
+
+ val expected = new ListBuffer[Row]()
+ expected.append(
+ Row(1, "1", 1, 1), Row(1, "1", 1, 2),
+ Row(2, "2", 2, 1), Row(2, "2", 2, 2),
+ Row(3, "3", 3, 1), Row(3, "3", 3, 2)
+ )
+ for (i <- 4 to 100) {
+ expected.append(Row(i, i.toString, null, null))
+ }
+
+ assertNotSpilled(sparkContext, "left outer join") {
+ checkAnswer(
+ sql(
+ """
+ |SELECT
+ | big.key, big.value, small.a, small.b
+ |FROM
+ | testData big
+ |LEFT OUTER JOIN
+ | testData2 small
+ |ON
+ | big.key = small.a
+ """.stripMargin),
+ expected
+ )
+ }
+
+ assertNotSpilled(sparkContext, "right outer join") {
+ checkAnswer(
+ sql(
+ """
+ |SELECT
+ | big.key, big.value, small.a, small.b
+ |FROM
+ | testData2 small
+ |RIGHT OUTER JOIN
+ | testData big
+ |ON
+ | big.key = small.a
+ """.stripMargin),
+ expected
+ )
+ }
+ }
+ }
+
+ test("test SortMergeJoin (with spill)") {
+ withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "1",
+ "spark.sql.sortMergeJoinExec.buffer.spill.threshold" -> "0") {
+
+ assertSpilled(sparkContext, "inner join") {
+ checkAnswer(
+ sql("SELECT * FROM testData JOIN testData2 ON key = a where key = 2"),
+ Row(2, "2", 2, 1) :: Row(2, "2", 2, 2) :: Nil
+ )
+ }
+
+ val expected = new ListBuffer[Row]()
+ expected.append(
+ Row(1, "1", 1, 1), Row(1, "1", 1, 2),
+ Row(2, "2", 2, 1), Row(2, "2", 2, 2),
+ Row(3, "3", 3, 1), Row(3, "3", 3, 2)
+ )
+ for (i <- 4 to 100) {
+ expected.append(Row(i, i.toString, null, null))
+ }
+
+ assertSpilled(sparkContext, "left outer join") {
+ checkAnswer(
+ sql(
+ """
+ |SELECT
+ | big.key, big.value, small.a, small.b
+ |FROM
+ | testData big
+ |LEFT OUTER JOIN
+ | testData2 small
+ |ON
+ | big.key = small.a
+ """.stripMargin),
+ expected
+ )
+ }
+
+ assertSpilled(sparkContext, "right outer join") {
+ checkAnswer(
+ sql(
+ """
+ |SELECT
+ | big.key, big.value, small.a, small.b
+ |FROM
+ | testData2 small
+ |RIGHT OUTER JOIN
+ | testData big
+ |ON
+ | big.key = small.a
+ """.stripMargin),
+ expected
+ )
+ }
+
+ // FULL OUTER JOIN still does not use [[ExternalAppendOnlyUnsafeRowArray]]
+ // so should not cause any spill
+ assertNotSpilled(sparkContext, "full outer join") {
+ checkAnswer(
+ sql(
+ """
+ |SELECT
+ | big.key, big.value, small.a, small.b
+ |FROM
+ | testData2 small
+ |FULL OUTER JOIN
+ | testData big
+ |ON
+ | big.key = small.a
+ """.stripMargin),
+ expected
+ )
+ }
+ }
+ }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArrayBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArrayBenchmark.scala
new file mode 100644
index 0000000000..00c5f2550c
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArrayBenchmark.scala
@@ -0,0 +1,233 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution
+
+import scala.collection.mutable.ArrayBuffer
+
+import org.apache.spark.{SparkConf, SparkContext, SparkEnv, TaskContext}
+import org.apache.spark.memory.MemoryTestingUtils
+import org.apache.spark.sql.catalyst.expressions.UnsafeRow
+import org.apache.spark.util.Benchmark
+import org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter
+
+object ExternalAppendOnlyUnsafeRowArrayBenchmark {
+
+ def testAgainstRawArrayBuffer(numSpillThreshold: Int, numRows: Int, iterations: Int): Unit = {
+ val random = new java.util.Random()
+ val rows = (1 to numRows).map(_ => {
+ val row = new UnsafeRow(1)
+ row.pointTo(new Array[Byte](64), 16)
+ row.setLong(0, random.nextLong())
+ row
+ })
+
+ val benchmark = new Benchmark(s"Array with $numRows rows", iterations * numRows)
+
+ // Internally, `ExternalAppendOnlyUnsafeRowArray` will create an
+ // in-memory buffer of size `numSpillThreshold`. This will mimic that
+ val initialSize =
+ Math.min(
+ ExternalAppendOnlyUnsafeRowArray.DefaultInitialSizeOfInMemoryBuffer,
+ numSpillThreshold)
+
+ benchmark.addCase("ArrayBuffer") { _: Int =>
+ var sum = 0L
+ for (_ <- 0L until iterations) {
+ val array = new ArrayBuffer[UnsafeRow](initialSize)
+
+ // Internally, `ExternalAppendOnlyUnsafeRowArray` will create a
+ // copy of the row. This will mimic that
+ rows.foreach(x => array += x.copy())
+
+ var i = 0
+ val n = array.length
+ while (i < n) {
+ sum = sum + array(i).getLong(0)
+ i += 1
+ }
+ array.clear()
+ }
+ }
+
+ benchmark.addCase("ExternalAppendOnlyUnsafeRowArray") { _: Int =>
+ var sum = 0L
+ for (_ <- 0L until iterations) {
+ val array = new ExternalAppendOnlyUnsafeRowArray(numSpillThreshold)
+ rows.foreach(x => array.add(x))
+
+ val iterator = array.generateIterator()
+ while (iterator.hasNext) {
+ sum = sum + iterator.next().getLong(0)
+ }
+ array.clear()
+ }
+ }
+
+ val conf = new SparkConf(false)
+ // Make the Java serializer write a reset instruction (TC_RESET) after each object to test
+ // for a bug we had with bytes written past the last object in a batch (SPARK-2792)
+ conf.set("spark.serializer.objectStreamReset", "1")
+ conf.set("spark.serializer", "org.apache.spark.serializer.JavaSerializer")
+
+ val sc = new SparkContext("local", "test", conf)
+ val taskContext = MemoryTestingUtils.fakeTaskContext(SparkEnv.get)
+ TaskContext.setTaskContext(taskContext)
+ benchmark.run()
+ sc.stop()
+ }
+
+ def testAgainstRawUnsafeExternalSorter(
+ numSpillThreshold: Int,
+ numRows: Int,
+ iterations: Int): Unit = {
+
+ val random = new java.util.Random()
+ val rows = (1 to numRows).map(_ => {
+ val row = new UnsafeRow(1)
+ row.pointTo(new Array[Byte](64), 16)
+ row.setLong(0, random.nextLong())
+ row
+ })
+
+ val benchmark = new Benchmark(s"Spilling with $numRows rows", iterations * numRows)
+
+ benchmark.addCase("UnsafeExternalSorter") { _: Int =>
+ var sum = 0L
+ for (_ <- 0L until iterations) {
+ val array = UnsafeExternalSorter.create(
+ TaskContext.get().taskMemoryManager(),
+ SparkEnv.get.blockManager,
+ SparkEnv.get.serializerManager,
+ TaskContext.get(),
+ null,
+ null,
+ 1024,
+ SparkEnv.get.memoryManager.pageSizeBytes,
+ numSpillThreshold,
+ false)
+
+ rows.foreach(x =>
+ array.insertRecord(
+ x.getBaseObject,
+ x.getBaseOffset,
+ x.getSizeInBytes,
+ 0,
+ false))
+
+ val unsafeRow = new UnsafeRow(1)
+ val iter = array.getIterator
+ while (iter.hasNext) {
+ iter.loadNext()
+ unsafeRow.pointTo(iter.getBaseObject, iter.getBaseOffset, iter.getRecordLength)
+ sum = sum + unsafeRow.getLong(0)
+ }
+ array.cleanupResources()
+ }
+ }
+
+ benchmark.addCase("ExternalAppendOnlyUnsafeRowArray") { _: Int =>
+ var sum = 0L
+ for (_ <- 0L until iterations) {
+ val array = new ExternalAppendOnlyUnsafeRowArray(numSpillThreshold)
+ rows.foreach(x => array.add(x))
+
+ val iterator = array.generateIterator()
+ while (iterator.hasNext) {
+ sum = sum + iterator.next().getLong(0)
+ }
+ array.clear()
+ }
+ }
+
+ val conf = new SparkConf(false)
+ // Make the Java serializer write a reset instruction (TC_RESET) after each object to test
+ // for a bug we had with bytes written past the last object in a batch (SPARK-2792)
+ conf.set("spark.serializer.objectStreamReset", "1")
+ conf.set("spark.serializer", "org.apache.spark.serializer.JavaSerializer")
+
+ val sc = new SparkContext("local", "test", conf)
+ val taskContext = MemoryTestingUtils.fakeTaskContext(SparkEnv.get)
+ TaskContext.setTaskContext(taskContext)
+ benchmark.run()
+ sc.stop()
+ }
+
+ def main(args: Array[String]): Unit = {
+
+ // ========================================================================================= //
+ // WITHOUT SPILL
+ // ========================================================================================= //
+
+ val spillThreshold = 100 * 1000
+
+ /*
+ Intel(R) Core(TM) i7-6920HQ CPU @ 2.90GHz
+
+ Array with 1000 rows: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
+ ------------------------------------------------------------------------------------------------
+ ArrayBuffer 7821 / 7941 33.5 29.8 1.0X
+ ExternalAppendOnlyUnsafeRowArray 8798 / 8819 29.8 33.6 0.9X
+ */
+ testAgainstRawArrayBuffer(spillThreshold, 1000, 1 << 18)
+
+ /*
+ Intel(R) Core(TM) i7-6920HQ CPU @ 2.90GHz
+
+ Array with 30000 rows: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
+ ------------------------------------------------------------------------------------------------
+ ArrayBuffer 19200 / 19206 25.6 39.1 1.0X
+ ExternalAppendOnlyUnsafeRowArray 19558 / 19562 25.1 39.8 1.0X
+ */
+ testAgainstRawArrayBuffer(spillThreshold, 30 * 1000, 1 << 14)
+
+ /*
+ Intel(R) Core(TM) i7-6920HQ CPU @ 2.90GHz
+
+ Array with 100000 rows: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
+ ------------------------------------------------------------------------------------------------
+ ArrayBuffer 5949 / 6028 17.2 58.1 1.0X
+ ExternalAppendOnlyUnsafeRowArray 6078 / 6138 16.8 59.4 1.0X
+ */
+ testAgainstRawArrayBuffer(spillThreshold, 100 * 1000, 1 << 10)
+
+ // ========================================================================================= //
+ // WITH SPILL
+ // ========================================================================================= //
+
+ /*
+ Intel(R) Core(TM) i7-6920HQ CPU @ 2.90GHz
+
+ Spilling with 1000 rows: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
+ ------------------------------------------------------------------------------------------------
+ UnsafeExternalSorter 9239 / 9470 28.4 35.2 1.0X
+ ExternalAppendOnlyUnsafeRowArray 8857 / 8909 29.6 33.8 1.0X
+ */
+ testAgainstRawUnsafeExternalSorter(100 * 1000, 1000, 1 << 18)
+
+ /*
+ Intel(R) Core(TM) i7-6920HQ CPU @ 2.90GHz
+
+ Spilling with 10000 rows: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
+ ------------------------------------------------------------------------------------------------
+ UnsafeExternalSorter 4 / 5 39.3 25.5 1.0X
+ ExternalAppendOnlyUnsafeRowArray 5 / 6 29.8 33.5 0.8X
+ */
+ testAgainstRawUnsafeExternalSorter(
+ UnsafeExternalSorter.DEFAULT_NUM_ELEMENTS_FOR_SPILL_THRESHOLD.toInt, 10 * 1000, 1 << 4)
+ }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArraySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArraySuite.scala
new file mode 100644
index 0000000000..53c4163994
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArraySuite.scala
@@ -0,0 +1,351 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution
+
+import java.util.ConcurrentModificationException
+
+import scala.collection.mutable.ArrayBuffer
+
+import org.apache.spark._
+import org.apache.spark.memory.MemoryTestingUtils
+import org.apache.spark.sql.catalyst.expressions.UnsafeRow
+
+class ExternalAppendOnlyUnsafeRowArraySuite extends SparkFunSuite with LocalSparkContext {
+ private val random = new java.util.Random()
+ private var taskContext: TaskContext = _
+
+ override def afterAll(): Unit = TaskContext.unset()
+
+ private def withExternalArray(spillThreshold: Int)
+ (f: ExternalAppendOnlyUnsafeRowArray => Unit): Unit = {
+ sc = new SparkContext("local", "test", new SparkConf(false))
+
+ taskContext = MemoryTestingUtils.fakeTaskContext(SparkEnv.get)
+ TaskContext.setTaskContext(taskContext)
+
+ val array = new ExternalAppendOnlyUnsafeRowArray(
+ taskContext.taskMemoryManager(),
+ SparkEnv.get.blockManager,
+ SparkEnv.get.serializerManager,
+ taskContext,
+ 1024,
+ SparkEnv.get.memoryManager.pageSizeBytes,
+ spillThreshold)
+ try f(array) finally {
+ array.clear()
+ }
+ }
+
+ private def insertRow(array: ExternalAppendOnlyUnsafeRowArray): Long = {
+ val valueInserted = random.nextLong()
+
+ val row = new UnsafeRow(1)
+ row.pointTo(new Array[Byte](64), 16)
+ row.setLong(0, valueInserted)
+ array.add(row)
+ valueInserted
+ }
+
+ private def checkIfValueExists(iterator: Iterator[UnsafeRow], expectedValue: Long): Unit = {
+ assert(iterator.hasNext)
+ val actualRow = iterator.next()
+ assert(actualRow.getLong(0) == expectedValue)
+ assert(actualRow.getSizeInBytes == 16)
+ }
+
+ private def validateData(
+ array: ExternalAppendOnlyUnsafeRowArray,
+ expectedValues: ArrayBuffer[Long]): Iterator[UnsafeRow] = {
+ val iterator = array.generateIterator()
+ for (value <- expectedValues) {
+ checkIfValueExists(iterator, value)
+ }
+
+ assert(!iterator.hasNext)
+ iterator
+ }
+
+ private def populateRows(
+ array: ExternalAppendOnlyUnsafeRowArray,
+ numRowsToBePopulated: Int): ArrayBuffer[Long] = {
+ val populatedValues = new ArrayBuffer[Long]
+ populateRows(array, numRowsToBePopulated, populatedValues)
+ }
+
+ private def populateRows(
+ array: ExternalAppendOnlyUnsafeRowArray,
+ numRowsToBePopulated: Int,
+ populatedValues: ArrayBuffer[Long]): ArrayBuffer[Long] = {
+ for (_ <- 0 until numRowsToBePopulated) {
+ populatedValues.append(insertRow(array))
+ }
+ populatedValues
+ }
+
+ private def getNumBytesSpilled: Long = {
+ TaskContext.get().taskMetrics().memoryBytesSpilled
+ }
+
+ private def assertNoSpill(): Unit = {
+ assert(getNumBytesSpilled == 0)
+ }
+
+ private def assertSpill(): Unit = {
+ assert(getNumBytesSpilled > 0)
+ }
+
+ test("insert rows less than the spillThreshold") {
+ val spillThreshold = 100
+ withExternalArray(spillThreshold) { array =>
+ assert(array.isEmpty)
+
+ val expectedValues = populateRows(array, 1)
+ assert(!array.isEmpty)
+ assert(array.length == 1)
+
+ val iterator1 = validateData(array, expectedValues)
+
+ // Add more rows (but not too many to trigger switch to [[UnsafeExternalSorter]])
+ // Verify that NO spill has happened
+ populateRows(array, spillThreshold - 1, expectedValues)
+ assert(array.length == spillThreshold)
+ assertNoSpill()
+
+ val iterator2 = validateData(array, expectedValues)
+
+ assert(!iterator1.hasNext)
+ assert(!iterator2.hasNext)
+ }
+ }
+
+ test("insert rows more than the spillThreshold to force spill") {
+ val spillThreshold = 100
+ withExternalArray(spillThreshold) { array =>
+ val numValuesInserted = 20 * spillThreshold
+
+ assert(array.isEmpty)
+ val expectedValues = populateRows(array, 1)
+ assert(array.length == 1)
+
+ val iterator1 = validateData(array, expectedValues)
+
+ // Populate more rows to trigger spill. Verify that spill has happened
+ populateRows(array, numValuesInserted - 1, expectedValues)
+ assert(array.length == numValuesInserted)
+ assertSpill()
+
+ val iterator2 = validateData(array, expectedValues)
+ assert(!iterator2.hasNext)
+
+ assert(!iterator1.hasNext)
+ intercept[ConcurrentModificationException](iterator1.next())
+ }
+ }
+
+ test("iterator on an empty array should be empty") {
+ withExternalArray(spillThreshold = 10) { array =>
+ val iterator = array.generateIterator()
+ assert(array.isEmpty)
+ assert(array.length == 0)
+ assert(!iterator.hasNext)
+ }
+ }
+
+ test("generate iterator with negative start index") {
+ withExternalArray(spillThreshold = 2) { array =>
+ val exception =
+ intercept[ArrayIndexOutOfBoundsException](array.generateIterator(startIndex = -10))
+
+ assert(exception.getMessage.contains(
+ "Invalid `startIndex` provided for generating iterator over the array")
+ )
+ }
+ }
+
+ test("generate iterator with start index exceeding array's size (without spill)") {
+ val spillThreshold = 2
+ withExternalArray(spillThreshold) { array =>
+ populateRows(array, spillThreshold / 2)
+
+ val exception =
+ intercept[ArrayIndexOutOfBoundsException](
+ array.generateIterator(startIndex = spillThreshold * 10))
+ assert(exception.getMessage.contains(
+ "Invalid `startIndex` provided for generating iterator over the array"))
+ }
+ }
+
+ test("generate iterator with start index exceeding array's size (with spill)") {
+ val spillThreshold = 2
+ withExternalArray(spillThreshold) { array =>
+ populateRows(array, spillThreshold * 2)
+
+ val exception =
+ intercept[ArrayIndexOutOfBoundsException](
+ array.generateIterator(startIndex = spillThreshold * 10))
+
+ assert(exception.getMessage.contains(
+ "Invalid `startIndex` provided for generating iterator over the array"))
+ }
+ }
+
+ test("generate iterator with custom start index (without spill)") {
+ val spillThreshold = 10
+ withExternalArray(spillThreshold) { array =>
+ val expectedValues = populateRows(array, spillThreshold)
+ val startIndex = spillThreshold / 2
+ val iterator = array.generateIterator(startIndex = startIndex)
+ for (i <- startIndex until expectedValues.length) {
+ checkIfValueExists(iterator, expectedValues(i))
+ }
+ }
+ }
+
+ test("generate iterator with custom start index (with spill)") {
+ val spillThreshold = 10
+ withExternalArray(spillThreshold) { array =>
+ val expectedValues = populateRows(array, spillThreshold * 10)
+ val startIndex = spillThreshold * 2
+ val iterator = array.generateIterator(startIndex = startIndex)
+ for (i <- startIndex until expectedValues.length) {
+ checkIfValueExists(iterator, expectedValues(i))
+ }
+ }
+ }
+
+ test("test iterator invalidation (without spill)") {
+ withExternalArray(spillThreshold = 10) { array =>
+ // insert 2 rows, iterate until the first row
+ populateRows(array, 2)
+
+ var iterator = array.generateIterator()
+ assert(iterator.hasNext)
+ iterator.next()
+
+ // Adding more row(s) should invalidate any old iterators
+ populateRows(array, 1)
+ assert(!iterator.hasNext)
+ intercept[ConcurrentModificationException](iterator.next())
+
+ // Clearing the array should also invalidate any old iterators
+ iterator = array.generateIterator()
+ assert(iterator.hasNext)
+ iterator.next()
+
+ array.clear()
+ assert(!iterator.hasNext)
+ intercept[ConcurrentModificationException](iterator.next())
+ }
+ }
+
+ test("test iterator invalidation (with spill)") {
+ val spillThreshold = 10
+ withExternalArray(spillThreshold) { array =>
+ // Populate enough rows so that spill has happens
+ populateRows(array, spillThreshold * 2)
+ assertSpill()
+
+ var iterator = array.generateIterator()
+ assert(iterator.hasNext)
+ iterator.next()
+
+ // Adding more row(s) should invalidate any old iterators
+ populateRows(array, 1)
+ assert(!iterator.hasNext)
+ intercept[ConcurrentModificationException](iterator.next())
+
+ // Clearing the array should also invalidate any old iterators
+ iterator = array.generateIterator()
+ assert(iterator.hasNext)
+ iterator.next()
+
+ array.clear()
+ assert(!iterator.hasNext)
+ intercept[ConcurrentModificationException](iterator.next())
+ }
+ }
+
+ test("clear on an empty the array") {
+ withExternalArray(spillThreshold = 2) { array =>
+ val iterator = array.generateIterator()
+ assert(!iterator.hasNext)
+
+ // multiple clear'ing should not have an side-effect
+ array.clear()
+ array.clear()
+ array.clear()
+ assert(array.isEmpty)
+ assert(array.length == 0)
+
+ // Clearing an empty array should also invalidate any old iterators
+ assert(!iterator.hasNext)
+ intercept[ConcurrentModificationException](iterator.next())
+ }
+ }
+
+ test("clear array (without spill)") {
+ val spillThreshold = 10
+ withExternalArray(spillThreshold) { array =>
+ // Populate rows ... but not enough to trigger spill
+ populateRows(array, spillThreshold / 2)
+ assertNoSpill()
+
+ // Clear the array
+ array.clear()
+ assert(array.isEmpty)
+
+ // Re-populate few rows so that there is no spill
+ // Verify the data. Verify that there was no spill
+ val expectedValues = populateRows(array, spillThreshold / 3)
+ validateData(array, expectedValues)
+ assertNoSpill()
+
+ // Populate more rows .. enough to not trigger a spill.
+ // Verify the data. Verify that there was no spill
+ populateRows(array, spillThreshold / 3, expectedValues)
+ validateData(array, expectedValues)
+ assertNoSpill()
+ }
+ }
+
+ test("clear array (with spill)") {
+ val spillThreshold = 10
+ withExternalArray(spillThreshold) { array =>
+ // Populate enough rows to trigger spill
+ populateRows(array, spillThreshold * 2)
+ val bytesSpilled = getNumBytesSpilled
+ assert(bytesSpilled > 0)
+
+ // Clear the array
+ array.clear()
+ assert(array.isEmpty)
+
+ // Re-populate the array ... but NOT upto the point that there is spill.
+ // Verify data. Verify that there was NO "extra" spill
+ val expectedValues = populateRows(array, spillThreshold / 2)
+ validateData(array, expectedValues)
+ assert(getNumBytesSpilled == bytesSpilled)
+
+ // Populate more rows to trigger spill
+ // Verify the data. Verify that there was "extra" spill
+ populateRows(array, spillThreshold * 2, expectedValues)
+ validateData(array, expectedValues)
+ assert(getNumBytesSpilled > bytesSpilled)
+ }
+ }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLWindowFunctionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLWindowFunctionSuite.scala
index afd47897ed..52e4f04722 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLWindowFunctionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLWindowFunctionSuite.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.execution
import org.apache.spark.sql.{AnalysisException, QueryTest, Row}
import org.apache.spark.sql.test.SharedSQLContext
+import org.apache.spark.TestUtils.assertSpilled
case class WindowData(month: Int, area: String, product: Int)
@@ -412,4 +413,36 @@ class SQLWindowFunctionSuite extends QueryTest with SharedSQLContext {
""".stripMargin),
Row(1, 3, null) :: Row(2, null, 4) :: Nil)
}
+
+ test("test with low buffer spill threshold") {
+ val nums = sparkContext.parallelize(1 to 10).map(x => (x, x % 2)).toDF("x", "y")
+ nums.createOrReplaceTempView("nums")
+
+ val expected =
+ Row(1, 1, 1) ::
+ Row(0, 2, 3) ::
+ Row(1, 3, 6) ::
+ Row(0, 4, 10) ::
+ Row(1, 5, 15) ::
+ Row(0, 6, 21) ::
+ Row(1, 7, 28) ::
+ Row(0, 8, 36) ::
+ Row(1, 9, 45) ::
+ Row(0, 10, 55) :: Nil
+
+ val actual = sql(
+ """
+ |SELECT y, x, sum(x) OVER w1 AS running_sum
+ |FROM nums
+ |WINDOW w1 AS (ORDER BY x ROWS BETWEEN UNBOUNDED PRECEDiNG AND CURRENT RoW)
+ """.stripMargin)
+
+ withSQLConf("spark.sql.windowExec.buffer.spill.threshold" -> "1") {
+ assertSpilled(sparkContext, "test with low buffer spill threshold") {
+ checkAnswer(actual, expected)
+ }
+ }
+
+ spark.catalog.dropTempView("nums")
+ }
}