From fb3081d3b38a50aa5e023c603e1b191e57f7c876 Mon Sep 17 00:00:00 2001
From: Shixiong Zhu
Date: Tue, 13 Dec 2016 09:53:22 -0800
Subject: [SPARK-13747][CORE] Fix potential ThreadLocal leaks in RPC when using ForkJoinPool

## What changes were proposed in this pull request?

Some places in SQL may call `RpcEndpointRef.askWithRetry` (e.g., ParquetFileFormat.buildReader -> SparkContext.broadcast -> ... -> BlockManagerMaster.updateBlockInfo -> RpcEndpointRef.askWithRetry), which will finally call `Await.result`. It may cause `java.lang.IllegalArgumentException: spark.sql.execution.id is already set` when running in Scala ForkJoinPool.

This PR includes the following changes to fix this issue:

- Remove `ThreadUtils.awaitResult`
- Rename `ThreadUtils.awaitResultInForkJoinSafely` to `ThreadUtils.awaitResult`
- Replace `Await.result` in RpcTimeout with `ThreadUtils.awaitResult`.

A standalone sketch of the resulting `awaitResult` pattern is shown after the patch.

## How was this patch tested?

Jenkins

Author: Shixiong Zhu

Closes #16230 from zsxwing/fix-SPARK-13747.
---
 .../scala/org/apache/spark/rpc/RpcTimeout.scala    | 12 ++-----
 .../scala/org/apache/spark/util/ThreadUtils.scala  | 41 +++++++++-------------
 .../apache/spark/rdd/AsyncRDDActionsSuite.scala    |  3 +-
 .../scheduler/OutputCommitCoordinatorSuite.scala   |  3 +-
 4 files changed, 21 insertions(+), 38 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcTimeout.scala b/core/src/main/scala/org/apache/spark/rpc/RpcTimeout.scala
index 2761d39e37..efd26486ab 100644
--- a/core/src/main/scala/org/apache/spark/rpc/RpcTimeout.scala
+++ b/core/src/main/scala/org/apache/spark/rpc/RpcTimeout.scala
@@ -24,7 +24,7 @@ import scala.concurrent.duration._
 import scala.util.control.NonFatal
 
 import org.apache.spark.{SparkConf, SparkException}
-import org.apache.spark.util.Utils
+import org.apache.spark.util.{ThreadUtils, Utils}
 
 /**
  * An exception thrown if RpcTimeout modifies a [[TimeoutException]].
@@ -72,15 +72,9 @@ private[spark] class RpcTimeout(val duration: FiniteDuration, val timeoutProp: S
    *         is still not ready
    */
   def awaitResult[T](future: Future[T]): T = {
-    val wrapAndRethrow: PartialFunction[Throwable, T] = {
-      case NonFatal(t) =>
-        throw new SparkException("Exception thrown in awaitResult", t)
-    }
     try {
-      // scalastyle:off awaitresult
-      Await.result(future, duration)
-      // scalastyle:on awaitresult
-    } catch addMessageIfTimeout.orElse(wrapAndRethrow)
+      ThreadUtils.awaitResult(future, duration)
+    } catch addMessageIfTimeout
   }
 }
 
diff --git a/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala b/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala
index 60a6e82c6f..1aa4456ed0 100644
--- a/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala
+++ b/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala
@@ -19,7 +19,7 @@ package org.apache.spark.util
 
 import java.util.concurrent._
 
-import scala.concurrent.{Await, Awaitable, ExecutionContext, ExecutionContextExecutor}
+import scala.concurrent.{Awaitable, ExecutionContext, ExecutionContextExecutor}
 import scala.concurrent.duration.Duration
 import scala.concurrent.forkjoin.{ForkJoinPool => SForkJoinPool, ForkJoinWorkerThread => SForkJoinWorkerThread}
 import scala.util.control.NonFatal
@@ -180,39 +180,30 @@ private[spark] object ThreadUtils {
 
   // scalastyle:off awaitresult
   /**
-   * Preferred alternative to `Await.result()`. This method wraps and re-throws any exceptions
-   * thrown by the underlying `Await` call, ensuring that this thread's stack trace appears in
-   * logs.
-   */
-  @throws(classOf[SparkException])
-  def awaitResult[T](awaitable: Awaitable[T], atMost: Duration): T = {
-    try {
-      Await.result(awaitable, atMost)
-      // scalastyle:on awaitresult
-    } catch {
-      case NonFatal(t) =>
-        throw new SparkException("Exception thrown in awaitResult: ", t)
-    }
-  }
-
-  /**
-   * Calls `Awaitable.result` directly to avoid using `ForkJoinPool`'s `BlockingContext`, wraps
-   * and re-throws any exceptions with nice stack track.
+   * Preferred alternative to `Await.result()`.
+   *
+   * This method wraps and re-throws any exceptions thrown by the underlying `Await` call, ensuring
+   * that this thread's stack trace appears in logs.
    *
-   * Codes running in the user's thread may be in a thread of Scala ForkJoinPool. As concurrent
-   * executions in ForkJoinPool may see some [[ThreadLocal]] value unexpectedly, this method
-   * basically prevents ForkJoinPool from running other tasks in the current waiting thread.
+   * In addition, it calls `Awaitable.result` directly to avoid using `ForkJoinPool`'s
+   * `BlockingContext`. Codes running in the user's thread may be in a thread of Scala ForkJoinPool.
+   * As concurrent executions in ForkJoinPool may see some [[ThreadLocal]] value unexpectedly, this
+   * method basically prevents ForkJoinPool from running other tasks in the current waiting thread.
+   * In general, we should use this method because many places in Spark use [[ThreadLocal]] and it's
+   * hard to debug when [[ThreadLocal]]s leak to other tasks.
    */
   @throws(classOf[SparkException])
-  def awaitResultInForkJoinSafely[T](awaitable: Awaitable[T], atMost: Duration): T = {
+  def awaitResult[T](awaitable: Awaitable[T], atMost: Duration): T = {
     try {
       // `awaitPermission` is not actually used anywhere so it's safe to pass in null here.
       // See SPARK-13747.
       val awaitPermission = null.asInstanceOf[scala.concurrent.CanAwait]
-      awaitable.result(Duration.Inf)(awaitPermission)
+      awaitable.result(atMost)(awaitPermission)
     } catch {
-      case NonFatal(t) =>
+      // TimeoutException is thrown in the current thread, so no need to wrap the exception.
+      case NonFatal(t) if !t.isInstanceOf[TimeoutException] =>
         throw new SparkException("Exception thrown in awaitResult: ", t)
     }
   }
+  // scalastyle:on awaitresult
 }
diff --git a/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala b/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala
index 58664e77d2..b29a53cffe 100644
--- a/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala
@@ -199,10 +199,9 @@ class AsyncRDDActionsSuite extends SparkFunSuite with BeforeAndAfterAll with Tim
     val f = sc.parallelize(1 to 100, 4)
       .mapPartitions(itr => { Thread.sleep(20); itr })
       .countAsync()
-    val e = intercept[SparkException] {
+    intercept[TimeoutException] {
       ThreadUtils.awaitResult(f, Duration(20, "milliseconds"))
     }
-    assert(e.getCause.isInstanceOf[TimeoutException])
   }
 
   private def testAsyncAction[R](action: RDD[Int] => FutureAction[R]): Unit = {
diff --git a/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala
index 83288db92b..8c4e389e86 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala
@@ -158,10 +158,9 @@ class OutputCommitCoordinatorSuite extends SparkFunSuite with BeforeAndAfter {
       0 until rdd.partitions.size, resultHandler, () => Unit)
     // It's an error if the job completes successfully even though no committer was authorized,
     // so throw an exception if the job was allowed to complete.
-    val e = intercept[SparkException] {
+    intercept[TimeoutException] {
       ThreadUtils.awaitResult(futureAction, 5 seconds)
     }
-    assert(e.getCause.isInstanceOf[TimeoutException])
     assert(tempDir.list().size === 0)
   }
 
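As an illustration of the pattern this patch adopts, here is a minimal, self-contained sketch: wait by calling `Awaitable.result` directly so that `Await.result`'s `BlockContext` machinery (and with it the ForkJoinPool blocking hooks the commit message describes) is never engaged while the thread waits. The object name `AwaitResultSketch`, the `main` driver, and the plain `RuntimeException` standing in for `SparkException` are illustrative choices only; the try/catch shape mirrors the new `ThreadUtils.awaitResult` above.

```scala
import java.util.concurrent.TimeoutException

import scala.concurrent.{Awaitable, ExecutionContext, Future}
import scala.concurrent.duration._
import scala.util.control.NonFatal

// Hypothetical standalone sketch; not part of the patch.
object AwaitResultSketch {

  /**
   * Wait by calling `Awaitable.result` directly instead of going through `Await.result`,
   * so the calling thread never enters `BlockContext.blockOn` and the ForkJoinPool never
   * treats it as a managed blocker.
   */
  def awaitResult[T](awaitable: Awaitable[T], atMost: Duration): T = {
    try {
      // `CanAwait` is only a compile-time capability marker, so passing a null instance is
      // safe; this is the same trick the patch uses (see SPARK-13747).
      val awaitPermission = null.asInstanceOf[scala.concurrent.CanAwait]
      awaitable.result(atMost)(awaitPermission)
    } catch {
      // A TimeoutException already originates in the calling thread, so rethrow it as-is;
      // wrap anything else so this thread's stack trace appears in logs.
      // (Spark wraps in SparkException; RuntimeException is a stand-in here.)
      case NonFatal(t) if !t.isInstanceOf[TimeoutException] =>
        throw new RuntimeException("Exception thrown in awaitResult", t)
    }
  }

  def main(args: Array[String]): Unit = {
    implicit val ec: ExecutionContext = ExecutionContext.global

    val slow: Future[Int] = Future { Thread.sleep(100); 42 }

    // Blocks the calling thread directly instead of routing through BlockContext.
    println(awaitResult(slow, 1.second)) // prints 42 once the future completes
  }
}
```

Because the waiting thread never reports itself as blocked to the pool, the pool gets no opportunity to schedule other tasks onto it while it waits, which is the behavior the new doc comment relies on ("prevents ForkJoinPool from running other tasks in the current waiting thread") to keep `ThreadLocal` values such as `spark.sql.execution.id` from leaking between tasks.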