aboutsummaryrefslogtreecommitdiff
path: root/core
diff options
context:
space:
mode:
authorzsxwing <zsxwing@gmail.com>2015-05-16 00:44:29 -0700
committerReynold Xin <rxin@databricks.com>2015-05-16 00:44:29 -0700
commit47e7ffe36b8a8a246fe9af522aff480d19c0c8a6 (patch)
tree8cfa6ec81fbd0f72e5c9bf74e8b1be0b5217b150 /core
parent0ac8b01a07840f199bbc79fb845762284aead6de (diff)
downloadspark-47e7ffe36b8a8a246fe9af522aff480d19c0c8a6.tar.gz
spark-47e7ffe36b8a8a246fe9af522aff480d19c0c8a6.tar.bz2
spark-47e7ffe36b8a8a246fe9af522aff480d19c0c8a6.zip
[SPARK-7655][Core][SQL] Remove 'scala.concurrent.ExecutionContext.Implicits.global' in 'ask' and 'BroadcastHashJoin'
Because both `AkkaRpcEndpointRef.ask` and `BroadcastHashJoin` uses `scala.concurrent.ExecutionContext.Implicits.global`. However, because the tasks in `BroadcastHashJoin` are usually long-running tasks, which will occupy all threads in `global`. Then `ask` cannot get a chance to process the replies. For `ask`, actually the tasks are very simple, so we can use `MoreExecutors.sameThreadExecutor()`. For `BroadcastHashJoin`, it's better to use `ThreadUtils.newDaemonCachedThreadPool`. Author: zsxwing <zsxwing@gmail.com> Closes #6200 from zsxwing/SPARK-7655-2 and squashes the following commits: cfdc605 [zsxwing] Remove redundant imort and minor doc fix cf83153 [zsxwing] Add "sameThread" and "newDaemonCachedThreadPool with maxThreadNumber" to ThreadUtils 08ad0ee [zsxwing] Remove 'scala.concurrent.ExecutionContext.Implicits.global' in 'ask' and 'BroadcastHashJoin'
Diffstat (limited to 'core')
-rw-r--r--core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala8
-rw-r--r--core/src/main/scala/org/apache/spark/util/ThreadUtils.scala24
-rw-r--r--core/src/test/scala/org/apache/spark/util/ThreadUtilsSuite.scala12
3 files changed, 40 insertions, 4 deletions
diff --git a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala
index ba0d468f11..0161962cde 100644
--- a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala
+++ b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala
@@ -29,9 +29,11 @@ import akka.actor.{ActorSystem, ExtendedActorSystem, Actor, ActorRef, Props, Add
import akka.event.Logging.Error
import akka.pattern.{ask => akkaAsk}
import akka.remote.{AssociationEvent, AssociatedEvent, DisassociatedEvent, AssociationErrorEvent}
+import com.google.common.util.concurrent.MoreExecutors
+
import org.apache.spark.{SparkException, Logging, SparkConf}
import org.apache.spark.rpc._
-import org.apache.spark.util.{ActorLogReceive, AkkaUtils}
+import org.apache.spark.util.{ActorLogReceive, AkkaUtils, ThreadUtils}
/**
* A RpcEnv implementation based on Akka.
@@ -294,8 +296,8 @@ private[akka] class AkkaRpcEndpointRef(
}
override def ask[T: ClassTag](message: Any, timeout: FiniteDuration): Future[T] = {
- import scala.concurrent.ExecutionContext.Implicits.global
actorRef.ask(AkkaMessage(message, true))(timeout).flatMap {
+ // The function will run in the calling thread, so it should be short and never block.
case msg @ AkkaMessage(message, reply) =>
if (reply) {
logError(s"Receive $msg but the sender cannot reply")
@@ -305,7 +307,7 @@ private[akka] class AkkaRpcEndpointRef(
}
case AkkaFailure(e) =>
Future.failed(e)
- }.mapTo[T]
+ }(ThreadUtils.sameThread).mapTo[T]
}
override def toString: String = s"${getClass.getSimpleName}($actorRef)"
diff --git a/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala b/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala
index 098a4b7949..ca5624a3d8 100644
--- a/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala
+++ b/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala
@@ -20,10 +20,22 @@ package org.apache.spark.util
import java.util.concurrent._
-import com.google.common.util.concurrent.ThreadFactoryBuilder
+import scala.concurrent.{ExecutionContext, ExecutionContextExecutor}
+
+import com.google.common.util.concurrent.{MoreExecutors, ThreadFactoryBuilder}
private[spark] object ThreadUtils {
+ private val sameThreadExecutionContext =
+ ExecutionContext.fromExecutorService(MoreExecutors.sameThreadExecutor())
+
+ /**
+ * An `ExecutionContextExecutor` that runs each task in the thread that invokes `execute/submit`.
+ * The caller should make sure the tasks running in this `ExecutionContextExecutor` are short and
+ * never block.
+ */
+ def sameThread: ExecutionContextExecutor = sameThreadExecutionContext
+
/**
* Create a thread factory that names threads with a prefix and also sets the threads to daemon.
*/
@@ -41,6 +53,16 @@ private[spark] object ThreadUtils {
}
/**
+ * Create a cached thread pool whose max number of threads is `maxThreadNumber`. Thread names
+ * are formatted as prefix-ID, where ID is a unique, sequentially assigned integer.
+ */
+ def newDaemonCachedThreadPool(prefix: String, maxThreadNumber: Int): ThreadPoolExecutor = {
+ val threadFactory = namedThreadFactory(prefix)
+ new ThreadPoolExecutor(
+ 0, maxThreadNumber, 60L, TimeUnit.SECONDS, new SynchronousQueue[Runnable], threadFactory)
+ }
+
+ /**
* Wrapper over newFixedThreadPool. Thread names are formatted as prefix-ID, where ID is a
* unique, sequentially assigned integer.
*/
diff --git a/core/src/test/scala/org/apache/spark/util/ThreadUtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/ThreadUtilsSuite.scala
index a3aa3e953f..751d3df9cc 100644
--- a/core/src/test/scala/org/apache/spark/util/ThreadUtilsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/ThreadUtilsSuite.scala
@@ -20,6 +20,9 @@ package org.apache.spark.util
import java.util.concurrent.{CountDownLatch, TimeUnit}
+import scala.concurrent.{Await, Future}
+import scala.concurrent.duration._
+
import org.scalatest.FunSuite
class ThreadUtilsSuite extends FunSuite {
@@ -54,4 +57,13 @@ class ThreadUtilsSuite extends FunSuite {
executor.shutdownNow()
}
}
+
+ test("sameThread") {
+ val callerThreadName = Thread.currentThread().getName()
+ val f = Future {
+ Thread.currentThread().getName()
+ }(ThreadUtils.sameThread)
+ val futureThreadName = Await.result(f, 10.seconds)
+ assert(futureThreadName === callerThreadName)
+ }
}