author    Ala Luszczak <ala@databricks.com>  2017-02-09 19:07:06 +0100
committer Reynold Xin <rxin@databricks.com>  2017-02-09 19:07:06 +0100
commit    4064574d031215fcfdf899a1ee9f3b6fecb1bfc9
tree      73e2eb131378c3bb24a3f0a32f86fcfe491e44c9 /sql
parent    3fc8e8caf81d0049daf9b776ad4059b0df81630f
[SPARK-19514] Making range interruptible.
## What changes were proposed in this pull request?

Previously, the range operator could not be interrupted. For example, using DAGScheduler.cancelStage(...) on a query with range might have been ineffective. This change adds periodic checks of TaskContext.isInterrupted to the codegen version, and wraps the non-codegen version in an InterruptibleIterator. I benchmarked the performance of the codegen version on a sample query `spark.range(1000L * 1000 * 1000 * 10).count()` and there is no measurable difference.

## How was this patch tested?

Adds a unit test.

Author: Ala Luszczak <ala@databricks.com>

Closes #16872 from ala/SPARK-19514b.
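As a usage sketch (not part of the patch), a long-running range job can now be cancelled from another thread, for example through a job group. The group id, the sleep, and the thread handling below are illustrative, assuming a live `spark` session:

```scala
import org.apache.spark.SparkException

// Illustrative only: run a large range job in a background thread under a
// job group, then cancel the group from the caller. With this patch the
// range tasks observe the kill flag and stop instead of running to the end.
val worker = new Thread(new Runnable {
  override def run(): Unit = {
    spark.sparkContext.setJobGroup("range-demo", "interruptible range", interruptOnCancel = true)
    try {
      spark.range(Long.MaxValue).count()
    } catch {
      case e: SparkException => println("job cancelled: " + e.getMessage)
    }
  }
})
worker.start()
Thread.sleep(1000) // give the job time to start some tasks
spark.sparkContext.cancelJobGroup("range-demo")
worker.join()
```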
Diffstat (limited to 'sql')
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala | 8
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala | 12
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala | 38
3 files changed, 52 insertions, 6 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index 04b812e79e..374d714ad5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -23,14 +23,14 @@ import java.util.{Map => JavaMap}
import scala.collection.JavaConverters._
import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
+import scala.language.existentials
import scala.util.control.NonFatal
import com.google.common.cache.{CacheBuilder, CacheLoader}
import org.codehaus.janino.{ByteArrayClassLoader, ClassBodyEvaluator, SimpleCompiler}
import org.codehaus.janino.util.ClassFile
-import scala.language.existentials
-import org.apache.spark.SparkEnv
+import org.apache.spark.{SparkEnv, TaskContext, TaskKilledException}
import org.apache.spark.internal.Logging
import org.apache.spark.metrics.source.CodegenMetrics
import org.apache.spark.sql.catalyst.InternalRow
@@ -933,7 +933,9 @@ object CodeGenerator extends Logging {
classOf[UnsafeArrayData].getName,
classOf[MapData].getName,
classOf[UnsafeMapData].getName,
- classOf[Expression].getName
+ classOf[Expression].getName,
+ classOf[TaskContext].getName,
+ classOf[TaskKilledException].getName
))
evaluator.setExtendedClass(classOf[GeneratedClass])
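The two classes added to the default imports above become referenceable from Janino-generated code. A minimal sketch of the check the generated loop performs, written here as plain Scala rather than emitted Java, assuming it runs inside a Spark task:

```scala
import org.apache.spark.{TaskContext, TaskKilledException}

// Shape of the per-batch check emitted by RangeExec's codegen: between
// batches, consult the task's kill flag and bail out promptly if it is set.
def checkInterrupted(): Unit = {
  val taskContext = TaskContext.get() // null outside a running task
  if (taskContext != null && taskContext.isInterrupted()) {
    throw new TaskKilledException()
  }
}
```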
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
index 792fb3e795..649c21b294 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.execution
import scala.concurrent.{ExecutionContext, Future}
import scala.concurrent.duration.Duration
-import org.apache.spark.SparkException
+import org.apache.spark.{InterruptibleIterator, SparkException, TaskContext}
import org.apache.spark.rdd.{PartitionwiseSampledRDD, RDD}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
@@ -363,6 +363,9 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range)
val ev = ExprCode("", "false", value)
val BigInt = classOf[java.math.BigInteger].getName
+ val taskContext = ctx.freshName("taskContext")
+ ctx.addMutableState("TaskContext", taskContext, s"$taskContext = TaskContext.get();")
+
// In order to periodically update the metrics without inflicting performance penalty, this
// operator produces elements in batches. After a batch is complete, the metrics are updated
// and a new batch is started.
@@ -443,6 +446,10 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range)
| if (shouldStop()) return;
| }
|
+ | if ($taskContext.isInterrupted()) {
+ | throw new TaskKilledException();
+ | }
+ |
| long $nextBatchTodo;
| if ($numElementsTodo > ${batchSize}L) {
| $nextBatchTodo = ${batchSize}L;
@@ -482,7 +489,7 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range)
val rowSize = UnsafeRow.calculateBitSetWidthInBytes(1) + LongType.defaultSize
val unsafeRow = UnsafeRow.createFromByteArray(rowSize, 1)
- new Iterator[InternalRow] {
+ val iter = new Iterator[InternalRow] {
private[this] var number: Long = safePartitionStart
private[this] var overflow: Boolean = false
@@ -511,6 +518,7 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range)
unsafeRow
}
}
+ new InterruptibleIterator(TaskContext.get(), iter)
}
}
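For the non-codegen path, the patch wraps the row iterator in InterruptibleIterator, which makes every hasNext() call consult the task's kill flag. The same pattern for an arbitrary iterator, sketched with an illustrative helper (`interruptible` is not part of the patch):

```scala
import org.apache.spark.{InterruptibleIterator, TaskContext}

// Illustrative helper: wrap any iterator so that hasNext() checks the
// task's kill flag and throws TaskKilledException once the task is killed.
def interruptible[T](underlying: Iterator[T]): Iterator[T] =
  new InterruptibleIterator[T](TaskContext.get(), underlying)
```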
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala
index 6d2d776c92..3ebfd9ac3d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala
@@ -17,14 +17,20 @@
package org.apache.spark.sql
+import scala.concurrent.duration._
import scala.math.abs
import scala.util.Random
+import org.scalatest.concurrent.Eventually
+
+import org.apache.spark.SparkException
+import org.apache.spark.scheduler.{SparkListener, SparkListenerExecutorMetricsUpdate, SparkListenerTaskStart}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSQLContext
-class DataFrameRangeSuite extends QueryTest with SharedSQLContext {
+
+class DataFrameRangeSuite extends QueryTest with SharedSQLContext with Eventually {
test("SPARK-7150 range api") {
// numSlice is greater than length
@@ -127,4 +133,34 @@ class DataFrameRangeSuite extends QueryTest with SharedSQLContext {
}
}
}
+
+ test("Cancelling stage in a query with Range.") {
+ val listener = new SparkListener {
+ override def onTaskStart(taskStart: SparkListenerTaskStart): Unit = {
+ Thread.sleep(100)
+ sparkContext.cancelStage(taskStart.stageId)
+ }
+ }
+
+ sparkContext.addSparkListener(listener)
+ for (codegen <- Seq(true, false)) {
+ withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> codegen.toString()) {
+ val ex = intercept[SparkException] {
+ spark.range(100000L).crossJoin(spark.range(100000L))
+ .toDF("a", "b").agg(sum("a"), sum("b")).collect()
+ }
+ ex.getCause() match {
+ case null =>
+ assert(ex.getMessage().contains("cancelled"))
+ case cause: SparkException =>
+ assert(cause.getMessage().contains("cancelled"))
+ case cause: Throwable =>
+ fail("Expected the casue to be SparkException, got " + cause.toString() + " instead.")
+ }
+ }
+ eventually(timeout(20.seconds)) {
+ assert(sparkContext.statusTracker.getExecutorInfos.map(_.numRunningTasks()).sum == 0)
+ }
+ }
+ }
}