 core/src/main/scala/org/apache/spark/util/ThreadUtils.scala                           | 18
 streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala | 23
 streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala     |  9
 3 files changed, 36 insertions(+), 14 deletions(-)
diff --git a/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala b/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala
index f9fbe2ff85..9abbf4a7a3 100644
--- a/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala
+++ b/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala
@@ -20,6 +20,7 @@ package org.apache.spark.util
import java.util.concurrent._

import scala.concurrent.{ExecutionContext, ExecutionContextExecutor}
+import scala.concurrent.forkjoin.{ForkJoinPool => SForkJoinPool, ForkJoinWorkerThread => SForkJoinWorkerThread}
import scala.util.control.NonFatal

import com.google.common.util.concurrent.{MoreExecutors, ThreadFactoryBuilder}
@@ -156,4 +157,21 @@ private[spark] object ThreadUtils {
result
}
}
+
+ /**
+ * Construct a new Scala ForkJoinPool with a specified max parallelism and name prefix.
+ */
+ def newForkJoinPool(prefix: String, maxThreadNumber: Int): SForkJoinPool = {
+ // Custom factory to set thread names
+ val factory = new SForkJoinPool.ForkJoinWorkerThreadFactory {
+ override def newThread(pool: SForkJoinPool) =
+ new SForkJoinWorkerThread(pool) {
+ setName(prefix + "-" + super.getName)
+ }
+ }
+ new SForkJoinPool(maxThreadNumber, factory,
+ null, // handler
+ false // asyncMode
+ )
+ }
}
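
The new helper wraps Scala's ForkJoinWorkerThreadFactory solely to prepend a recognizable prefix to worker-thread names. A minimal usage sketch, assuming code that compiles inside the org.apache.spark package (ThreadUtils is private[spark]); the pool name and parallelism are illustrative, not taken from the patch:

    // Illustrative only: ThreadUtils is private[spark], so this must live
    // somewhere under the org.apache.spark package.
    import scala.concurrent.{Await, ExecutionContext, Future}
    import scala.concurrent.duration._

    import org.apache.spark.util.ThreadUtils

    object ForkJoinPoolExample {
      def main(args: Array[String]): Unit = {
        val pool = ThreadUtils.newForkJoinPool("example-pool", 4)
        val ec = ExecutionContext.fromExecutorService(pool)
        try {
          // Worker threads carry the "example-pool-" prefix set by the factory.
          val name = Future(Thread.currentThread().getName)(ec)
          println(Await.result(name, 10.seconds))
        } finally {
          pool.shutdown()
        }
      }
    }
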
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala
index 314263f26e..a3b7e783ac 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala
@@ -18,11 +18,11 @@ package org.apache.spark.streaming.util

import java.nio.ByteBuffer
import java.util.{Iterator => JIterator}
-import java.util.concurrent.{RejectedExecutionException, ThreadPoolExecutor}
+import java.util.concurrent.RejectedExecutionException

import scala.collection.JavaConverters._
import scala.collection.mutable.ArrayBuffer
-import scala.collection.parallel.ThreadPoolTaskSupport
+import scala.collection.parallel.ExecutionContextTaskSupport
import scala.concurrent.{Await, ExecutionContext, Future}
import scala.language.postfixOps

@@ -62,8 +62,8 @@ private[streaming] class FileBasedWriteAheadLog(
private val threadpoolName = {
"WriteAheadLogManager" + callerName.map(c => s" for $c").getOrElse("")
}
- private val threadpool = ThreadUtils.newDaemonCachedThreadPool(threadpoolName, 20)
- private val executionContext = ExecutionContext.fromExecutorService(threadpool)
+ private val forkJoinPool = ThreadUtils.newForkJoinPool(threadpoolName, 20)
+ private val executionContext = ExecutionContext.fromExecutorService(forkJoinPool)

override protected def logName = {
getClass.getName.stripSuffix("$") +
@@ -144,7 +144,7 @@ private[streaming] class FileBasedWriteAheadLog(
} else {
// For performance gains, it makes sense to parallelize the recovery if
// closeFileAfterWrite = true
- seqToParIterator(threadpool, logFilesToRead, readFile).asJava
+ seqToParIterator(executionContext, logFilesToRead, readFile).asJava
}
}
@@ -283,16 +283,17 @@ private[streaming] object FileBasedWriteAheadLog {

/**
* This creates an iterator from a parallel collection, by keeping at most `n` objects in memory
- * at any given time, where `n` is the size of the thread pool. This is crucial for use cases
- * where we create `FileBasedWriteAheadLogReader`s during parallel recovery. We don't want to
- * open up `k` streams altogether where `k` is the size of the Seq that we want to parallelize.
+ * at any given time, where `n` is the size of the thread pool or 8, whichever is greater. This
+ * is crucial for use cases where we create `FileBasedWriteAheadLogReader`s during parallel
+ * recovery: we don't want to open `k` streams at once, where `k` is the size of the Seq that we
+ * want to parallelize.
*/
def seqToParIterator[I, O](
- tpool: ThreadPoolExecutor,
+ executionContext: ExecutionContext,
source: Seq[I],
handler: I => Iterator[O]): Iterator[O] = {
- val taskSupport = new ThreadPoolTaskSupport(tpool)
- val groupSize = tpool.getMaximumPoolSize.max(8)
+ val taskSupport = new ExecutionContextTaskSupport(executionContext)
+ val groupSize = taskSupport.parallelismLevel.max(8)
source.grouped(groupSize).flatMap { group =>
val parallelCollection = group.par
parallelCollection.tasksupport = taskSupport
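
The rewritten seqToParIterator bounds memory the same way as before: the input Seq is split into groups of groupSize, and each group is materialized in parallel only when the consumer reaches it, so at most one group of readers is open at a time. A self-contained sketch of that grouping trick with illustrative names (note that from Scala 2.13 the parallel collections live in the separate scala-parallel-collections module; at this commit's Scala version they are built in):

    import java.util.concurrent.atomic.AtomicInteger

    import scala.collection.parallel.ExecutionContextTaskSupport
    import scala.concurrent.ExecutionContext

    object BoundedParDemo extends App {
      val taskSupport = new ExecutionContextTaskSupport(ExecutionContext.global)
      val groupSize = taskSupport.parallelismLevel.max(8)
      val opened = new AtomicInteger() // stands in for open WAL readers

      // grouped(...) keeps the result lazy: a group's handlers only run once
      // the consumer reaches that group, so at most one group of "readers"
      // is ever materialized at a time.
      val out: Iterator[Int] = (1 to 100).grouped(groupSize).flatMap { group =>
        val par = group.par
        par.tasksupport = taskSupport
        par.map { i => opened.incrementAndGet(); i * 2 }.seq
      }

      println(out.take(groupSize).sum)         // forces only the first group
      println(s"opened so far: ${opened.get}") // <= groupSize, not 100
    }
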
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala
index 7460e8629b..8c980dee2c 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala
@@ -228,7 +228,9 @@ class FileBasedWriteAheadLogSuite
the list of files.
*/
val numThreads = 8
- val tpool = ThreadUtils.newDaemonFixedThreadPool(numThreads, "wal-test-thread-pool")
+ val fpool = ThreadUtils.newForkJoinPool("wal-test-thread-pool", numThreads)
+ val executionContext = ExecutionContext.fromExecutorService(fpool)
+
class GetMaxCounter {
private val value = new AtomicInteger()
@volatile private var max: Int = 0
@@ -258,7 +260,8 @@ class FileBasedWriteAheadLogSuite
val t = new Thread() {
override def run() {
// run the calculation on a separate thread so that we can release the latch
- val iterator = FileBasedWriteAheadLog.seqToParIterator[Int, Int](tpool, testSeq, handle)
+ val iterator = FileBasedWriteAheadLog.seqToParIterator[Int, Int](executionContext,
+ testSeq, handle)
collected = iterator.toSeq
}
}
@@ -273,7 +276,7 @@ class FileBasedWriteAheadLogSuite
// make sure we didn't open too many Iterators
assert(counter.getMax() <= numThreads)
} finally {
- tpool.shutdownNow()
+ fpool.shutdownNow()
}
}
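
The test now builds the ForkJoinPool and ExecutionContext by hand and tears the pool down in the finally block. A hypothetical test utility (not part of this patch) that packages that setup/teardown pattern, again assuming access to the private[spark] ThreadUtils:

    import scala.concurrent.{ExecutionContext, ExecutionContextExecutorService}

    import org.apache.spark.util.ThreadUtils

    object WalTestPools {
      // Hypothetical helper: lend a ForkJoinPool-backed ExecutionContext to
      // `body` and always shut the pool down afterwards.
      def withForkJoinEC[T](prefix: String, maxThreads: Int)
          (body: ExecutionContextExecutorService => T): T = {
        val pool = ThreadUtils.newForkJoinPool(prefix, maxThreads)
        val ec = ExecutionContext.fromExecutorService(pool)
        try body(ec) finally pool.shutdownNow()
      }
    }

With such a helper, the body of this test would reduce to withForkJoinEC("wal-test-thread-pool", numThreads) { ec => ... } with no explicit finally.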