Diffstat (limited to 'streaming')
-rw-r--r--  streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala | 23
-rw-r--r--  streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala     |  9
2 files changed, 18 insertions(+), 14 deletions(-)
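The patch swaps the deprecated `ThreadPoolTaskSupport` for `ExecutionContextTaskSupport`, and backs the executor with a fork-join pool so the task support can still discover the pool's parallelism. A minimal sketch of the new wiring, assuming a plain `java.util.concurrent.ForkJoinPool` in place of Spark's internal `ThreadUtils.newForkJoinPool` helper and a Scala version (e.g. 2.12) where parallel collections ship with the standard library:

import java.util.concurrent.ForkJoinPool

import scala.collection.parallel.ExecutionContextTaskSupport
import scala.concurrent.ExecutionContext

object TaskSupportWiringSketch {
  def main(args: Array[String]): Unit = {
    // Wrap a ForkJoinPool in an ExecutionContext. ExecutionContextTaskSupport
    // recognizes the underlying fork-join pool and adopts its parallelism.
    val pool = new ForkJoinPool(20)
    val ec = ExecutionContext.fromExecutorService(pool)
    val taskSupport = new ExecutionContextTaskSupport(ec)

    val xs = (1 to 100).par
    xs.tasksupport = taskSupport           // parallel ops now run on `pool`
    println(xs.map(_ * 2).sum)             // 10100
    println(taskSupport.parallelismLevel)  // 20 for this pool
    pool.shutdown()
  }
}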
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala
index 314263f26e..a3b7e783ac 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala
@@ -18,11 +18,11 @@ package org.apache.spark.streaming.util
import java.nio.ByteBuffer
import java.util.{Iterator => JIterator}
-import java.util.concurrent.{RejectedExecutionException, ThreadPoolExecutor}
+import java.util.concurrent.RejectedExecutionException
import scala.collection.JavaConverters._
import scala.collection.mutable.ArrayBuffer
-import scala.collection.parallel.ThreadPoolTaskSupport
+import scala.collection.parallel.ExecutionContextTaskSupport
import scala.concurrent.{Await, ExecutionContext, Future}
import scala.language.postfixOps
@@ -62,8 +62,8 @@ private[streaming] class FileBasedWriteAheadLog(
private val threadpoolName = {
"WriteAheadLogManager" + callerName.map(c => s" for $c").getOrElse("")
}
- private val threadpool = ThreadUtils.newDaemonCachedThreadPool(threadpoolName, 20)
- private val executionContext = ExecutionContext.fromExecutorService(threadpool)
+ private val forkJoinPool = ThreadUtils.newForkJoinPool(threadpoolName, 20)
+ private val executionContext = ExecutionContext.fromExecutorService(forkJoinPool)
override protected def logName = {
getClass.getName.stripSuffix("$") +
@@ -144,7 +144,7 @@ private[streaming] class FileBasedWriteAheadLog(
} else {
// For performance gains, it makes sense to parallelize the recovery if
// closeFileAfterWrite = true
- seqToParIterator(threadpool, logFilesToRead, readFile).asJava
+ seqToParIterator(executionContext, logFilesToRead, readFile).asJava
}
}
@@ -283,16 +283,17 @@ private[streaming] object FileBasedWriteAheadLog {
/**
* This creates an iterator from a parallel collection, by keeping at most `n` objects in memory
- * at any given time, where `n` is the size of the thread pool. This is crucial for use cases
- * where we create `FileBasedWriteAheadLogReader`s during parallel recovery. We don't want to
- * open up `k` streams altogether where `k` is the size of the Seq that we want to parallelize.
+ * at any given time, where `n` is the maximum of the thread pool size and 8. This is crucial
+ * for use cases where we create `FileBasedWriteAheadLogReader`s during parallel recovery.
+ * We don't want to open up `k` streams altogether, where `k` is the size of the Seq that we
+ * want to parallelize.
*/
def seqToParIterator[I, O](
- tpool: ThreadPoolExecutor,
+ executionContext: ExecutionContext,
source: Seq[I],
handler: I => Iterator[O]): Iterator[O] = {
- val taskSupport = new ThreadPoolTaskSupport(tpool)
- val groupSize = tpool.getMaximumPoolSize.max(8)
+ val taskSupport = new ExecutionContextTaskSupport(executionContext)
+ val groupSize = taskSupport.parallelismLevel.max(8)
source.grouped(groupSize).flatMap { group =>
val parallelCollection = group.par
parallelCollection.tasksupport = taskSupport
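For reference, a self-contained sketch of the bounded-recovery technique that `seqToParIterator` implements, runnable outside Spark (the pool size and the toy handler are illustrative, and `flatMap(identity)` stands in for the real method's `.flatten`):

import java.util.concurrent.ForkJoinPool

import scala.collection.parallel.ExecutionContextTaskSupport
import scala.concurrent.ExecutionContext

object SeqToParIteratorSketch {
  // Lazily walk `source` in groups so that at most `groupSize` handler
  // iterators are ever open at once, while each group runs in parallel.
  def seqToParIterator[I, O](
      ec: ExecutionContext,
      source: Seq[I],
      handler: I => Iterator[O]): Iterator[O] = {
    val taskSupport = new ExecutionContextTaskSupport(ec)
    val groupSize = taskSupport.parallelismLevel.max(8)
    source.grouped(groupSize).flatMap { group =>
      val parGroup = group.par
      parGroup.tasksupport = taskSupport
      parGroup.map(handler).seq   // handlers for one group run in parallel
    }.flatMap(identity)           // concatenate the per-element iterators lazily
  }

  def main(args: Array[String]): Unit = {
    val ec = ExecutionContext.fromExecutorService(new ForkJoinPool(4))
    val out = seqToParIterator[Int, Int](ec, 1 to 20, i => Iterator(i, i * i))
    println(out.toList)
    ec.shutdown()
  }
}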
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala
index 7460e8629b..8c980dee2c 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala
@@ -228,7 +228,9 @@ class FileBasedWriteAheadLogSuite
the list of files.
*/
val numThreads = 8
- val tpool = ThreadUtils.newDaemonFixedThreadPool(numThreads, "wal-test-thread-pool")
+ val fpool = ThreadUtils.newForkJoinPool("wal-test-thread-pool", numThreads)
+ val executionContext = ExecutionContext.fromExecutorService(fpool)
+
class GetMaxCounter {
private val value = new AtomicInteger()
@volatile private var max: Int = 0
@@ -258,7 +260,8 @@ class FileBasedWriteAheadLogSuite
val t = new Thread() {
override def run() {
// run the calculation on a separate thread so that we can release the latch
- val iterator = FileBasedWriteAheadLog.seqToParIterator[Int, Int](tpool, testSeq, handle)
+ val iterator = FileBasedWriteAheadLog.seqToParIterator[Int, Int](executionContext,
+ testSeq, handle)
collected = iterator.toSeq
}
}
@@ -273,7 +276,7 @@ class FileBasedWriteAheadLogSuite
// make sure we didn't open too many Iterators
assert(counter.getMax() <= numThreads)
} finally {
- tpool.shutdownNow()
+ fpool.shutdownNow()
}
}
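The updated test keeps the suite's original invariant: recovery never holds more than `numThreads` handler iterators open at once. A condensed sketch of that check, reusing `SeqToParIteratorSketch` from above (the sleeping handler is a simplification of the suite's latch-based blocking, and names like `GetMaxCounter` mirror the test code):

import java.util.concurrent.ForkJoinPool
import java.util.concurrent.atomic.AtomicInteger

import scala.concurrent.ExecutionContext

object MaxOpenIteratorsSketch {
  // High-water mark of concurrently running handlers, as in the suite.
  class GetMaxCounter {
    private val value = new AtomicInteger()
    @volatile private var max: Int = 0
    def increment(): Unit = synchronized {
      val atInstant = value.incrementAndGet()
      if (atInstant > max) max = atInstant
    }
    def decrement(): Unit = synchronized { value.decrementAndGet() }
    def getMax(): Int = max
  }

  def main(args: Array[String]): Unit = {
    val numThreads = 8
    val pool = new ForkJoinPool(numThreads)
    val ec = ExecutionContext.fromExecutorService(pool)
    val counter = new GetMaxCounter

    // Each handler counts as "open" for the duration of its call.
    def handle(i: Int): Iterator[Int] = {
      counter.increment()
      try { Thread.sleep(10); Iterator(i) } finally counter.decrement()
    }

    val it = SeqToParIteratorSketch.seqToParIterator[Int, Int](ec, 1 to 128, handle)
    it.foreach(_ => ())  // drain lazily, one group at a time
    assert(counter.getMax() <= numThreads, s"opened ${counter.getMax()} at once")
    pool.shutdownNow()
  }
}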