author    Josh Rosen <joshrosen@databricks.com>    2016-12-20 01:19:38 +0100
committer    Herman van Hovell <hvanhovell@databricks.com>    2016-12-20 01:19:38 +0100
commit    5857b9ac2d9808d9b89a5b29620b5052e2beebf5 (patch)
tree    84e43c79b6e4590613116a93649809620682b603 /sql/core/src/main
parent    4cb49412d1d7d10ffcc738475928c7de2bc59fd4 (diff)
[SPARK-18928] Check TaskContext.isInterrupted() in FileScanRDD, JDBCRDD & UnsafeSorter
## What changes were proposed in this pull request?

In order to respond to task cancellation, Spark tasks must periodically check `TaskContext.isInterrupted()`, but this check is missing on a few critical read paths used in Spark SQL, including `FileScanRDD`, `JDBCRDD`, and UnsafeSorter-based sorts. This can cause interrupted / cancelled tasks to continue running and become zombies (as also described in #16189).

This patch fixes the problem by adding `TaskContext.isInterrupted()` checks to these paths. Note that I could have used `InterruptibleIterator` to simply wrap the affected iterators, but in some cases that would impose a performance penalty or might not be effective due to certain special uses of iterators in Spark SQL. Instead, I inlined `InterruptibleIterator`-style logic into the existing iterator subclasses.

## How was this patch tested?

Tested manually in `spark-shell` with two different reproductions of non-cancellable tasks: one involving scans of huge files and another involving sort-merge joins that spill to disk. Both causes of zombie tasks are fixed by the changes added here.

Author: Josh Rosen <joshrosen@databricks.com>

Closes #16340 from JoshRosen/sql-task-interruption.
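As a rough illustration of the inlined check described above, here is a minimal sketch of the pattern. The `CancellableIterator` class and its `delegate` field are hypothetical stand-ins for exposition, not the actual `FileScanRDD` code, which appears in the diff below:

```scala
import org.apache.spark.{TaskContext, TaskKilledException}

// Illustrative sketch only: the real change inlines this check into
// FileScanRDD's internal iterator rather than defining a new class.
class CancellableIterator[T](context: TaskContext, delegate: Iterator[T])
  extends Iterator[T] {

  override def hasNext: Boolean = {
    // Mirror InterruptibleIterator's logic without adding another iterator
    // layer: fail fast once the task has been marked as killed.
    if (context.isInterrupted()) {
      throw new TaskKilledException
    }
    delegate.hasNext
  }

  override def next(): T = delegate.next()
}
```

The point of inlining is to keep the interruption check on the hot `hasNext` path without wrapping every scan in an additional iterator object.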
Diffstat (limited to 'sql/core/src/main')
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala | 12
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala |  5
2 files changed, 13 insertions(+), 4 deletions(-)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala
index e753cd962a..dced536136 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala
@@ -21,7 +21,7 @@ import java.io.IOException
import scala.collection.mutable
-import org.apache.spark.{Partition => RDDPartition, TaskContext}
+import org.apache.spark.{Partition => RDDPartition, TaskContext, TaskKilledException}
import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.rdd.{InputFileBlockHolder, RDD}
import org.apache.spark.sql.SparkSession
@@ -99,7 +99,15 @@ class FileScanRDD(
private[this] var currentFile: PartitionedFile = null
private[this] var currentIterator: Iterator[Object] = null
- def hasNext: Boolean = (currentIterator != null && currentIterator.hasNext) || nextIterator()
+ def hasNext: Boolean = {
+ // Kill the task in case it has been marked as killed. This logic is from
+ // InterruptibleIterator, but we inline it here instead of wrapping the iterator in order
+ // to avoid performance overhead.
+ if (context.isInterrupted()) {
+ throw new TaskKilledException
+ }
+ (currentIterator != null && currentIterator.hasNext) || nextIterator()
+ }
def next(): Object = {
val nextElement = currentIterator.next()
// TODO: we should have a better separation of row based and batch based scan, so that we
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
index d5b11e7bec..2bdc432541 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
@@ -23,7 +23,7 @@ import scala.util.control.NonFatal
import org.apache.commons.lang3.StringUtils
-import org.apache.spark.{Partition, SparkContext, TaskContext}
+import org.apache.spark.{InterruptibleIterator, Partition, SparkContext, TaskContext}
import org.apache.spark.internal.Logging
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
@@ -301,6 +301,7 @@ private[jdbc] class JDBCRDD(
rs = stmt.executeQuery()
val rowsIterator = JdbcUtils.resultSetToSparkInternalRows(rs, schema, inputMetrics)
- CompletionIterator[InternalRow, Iterator[InternalRow]](rowsIterator, close())
+ CompletionIterator[InternalRow, Iterator[InternalRow]](
+ new InterruptibleIterator(context, rowsIterator), close())
}
}
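The `JDBCRDD` change above takes the wrapping approach instead, reusing Spark's existing `InterruptibleIterator`. A simplified sketch of that pattern, with a generic `rows` iterator standing in for the result of `JdbcUtils.resultSetToSparkInternalRows`, might look like:

```scala
import org.apache.spark.{InterruptibleIterator, TaskContext}

// Sketch only: `rows` stands in for the ResultSet-backed iterator that
// JdbcUtils.resultSetToSparkInternalRows produces in the real code.
def wrapForCancellation[T](context: TaskContext, rows: Iterator[T]): Iterator[T] = {
  // InterruptibleIterator checks context.isInterrupted() on each hasNext call
  // and throws TaskKilledException, so a killed task stops consuming rows.
  new InterruptibleIterator[T](context, rows)
}
```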