diff options
author | Josh Rosen <joshrosen@databricks.com> | 2016-12-20 01:19:38 +0100 |
---|---|---|
committer | Herman van Hovell <hvanhovell@databricks.com> | 2016-12-20 01:19:38 +0100 |
commit | 5857b9ac2d9808d9b89a5b29620b5052e2beebf5 (patch) | |
tree | 84e43c79b6e4590613116a93649809620682b603 /sql/core/src/main | |
parent | 4cb49412d1d7d10ffcc738475928c7de2bc59fd4 (diff) | |
download | spark-5857b9ac2d9808d9b89a5b29620b5052e2beebf5.tar.gz spark-5857b9ac2d9808d9b89a5b29620b5052e2beebf5.tar.bz2 spark-5857b9ac2d9808d9b89a5b29620b5052e2beebf5.zip |
[SPARK-18928] Check TaskContext.isInterrupted() in FileScanRDD, JDBCRDD & UnsafeSorter
## What changes were proposed in this pull request?
In order to respond to task cancellation, Spark tasks must periodically check `TaskContext.isInterrupted()`, but this check is missing on a few critical read paths used in Spark SQL, including `FileScanRDD`, `JDBCRDD`, and UnsafeSorter-based sorts. This can cause interrupted / cancelled tasks to continue running and become zombies (as also described in #16189).
This patch aims to fix this problem by adding `TaskContext.isInterrupted()` checks to these paths. Note that I could have used `InterruptibleIterator` to simply wrap a bunch of iterators but in some cases this would have an adverse performance penalty or might not be effective due to certain special uses of Iterators in Spark SQL. Instead, I inlined `InterruptibleIterator`-style logic into existing iterator subclasses.
## How was this patch tested?
Tested manually in `spark-shell` with two different reproductions of non-cancellable tasks, one involving scans of huge files and another involving sort-merge joins that spill to disk. Both causes of zombie tasks are fixed by the changes added here.
Author: Josh Rosen <joshrosen@databricks.com>
Closes #16340 from JoshRosen/sql-task-interruption.
Diffstat (limited to 'sql/core/src/main')
-rw-r--r-- | sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala | 12 | ||||
-rw-r--r-- | sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala | 5 |
2 files changed, 13 insertions, 4 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala index e753cd962a..dced536136 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala @@ -21,7 +21,7 @@ import java.io.IOException import scala.collection.mutable -import org.apache.spark.{Partition => RDDPartition, TaskContext} +import org.apache.spark.{Partition => RDDPartition, TaskContext, TaskKilledException} import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.rdd.{InputFileBlockHolder, RDD} import org.apache.spark.sql.SparkSession @@ -99,7 +99,15 @@ class FileScanRDD( private[this] var currentFile: PartitionedFile = null private[this] var currentIterator: Iterator[Object] = null - def hasNext: Boolean = (currentIterator != null && currentIterator.hasNext) || nextIterator() + def hasNext: Boolean = { + // Kill the task in case it has been marked as killed. This logic is from + // InterruptibleIterator, but we inline it here instead of wrapping the iterator in order + // to avoid performance overhead. + if (context.isInterrupted()) { + throw new TaskKilledException + } + (currentIterator != null && currentIterator.hasNext) || nextIterator() + } def next(): Object = { val nextElement = currentIterator.next() // TODO: we should have a better separation of row based and batch based scan, so that we diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala index d5b11e7bec..2bdc432541 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala @@ -23,7 +23,7 @@ import scala.util.control.NonFatal import org.apache.commons.lang3.StringUtils -import org.apache.spark.{Partition, SparkContext, TaskContext} +import org.apache.spark.{InterruptibleIterator, Partition, SparkContext, TaskContext} import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow @@ -301,6 +301,7 @@ private[jdbc] class JDBCRDD( rs = stmt.executeQuery() val rowsIterator = JdbcUtils.resultSetToSparkInternalRows(rs, schema, inputMetrics) - CompletionIterator[InternalRow, Iterator[InternalRow]](rowsIterator, close()) + CompletionIterator[InternalRow, Iterator[InternalRow]]( + new InterruptibleIterator(context, rowsIterator), close()) } } |