diff options
-rw-r--r-- | streaming/src/main/scala/org/apache/spark/streaming/util/BatchedWriteAheadLog.scala | 3 | ||||
-rw-r--r-- | streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala | 124 |
2 files changed, 80 insertions, 47 deletions
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/BatchedWriteAheadLog.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/BatchedWriteAheadLog.scala index 9727ed2ba1..6e6ed8d819 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/BatchedWriteAheadLog.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/BatchedWriteAheadLog.scala @@ -182,6 +182,9 @@ private[util] class BatchedWriteAheadLog(val wrappedLog: WriteAheadLog, conf: Sp buffer.clear() } } + + /** Method for querying the queue length. Should only be used in tests. */ + private def getQueueLength(): Int = walWriteQueue.size() } /** Static methods for aggregating and de-aggregating records. */ diff --git a/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala index e96f4c2a29..9e13f25c2e 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala @@ -18,15 +18,14 @@ package org.apache.spark.streaming.util import java.io._ import java.nio.ByteBuffer -import java.util.concurrent.{ExecutionException, ThreadPoolExecutor} -import java.util.concurrent.atomic.AtomicInteger +import java.util.{Iterator => JIterator} +import java.util.concurrent.ThreadPoolExecutor import scala.collection.JavaConverters._ import scala.collection.mutable.ArrayBuffer import scala.concurrent._ import scala.concurrent.duration._ import scala.language.{implicitConversions, postfixOps} -import scala.util.{Failure, Success} import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path @@ -37,12 +36,12 @@ import org.mockito.invocation.InvocationOnMock import org.mockito.stubbing.Answer import org.scalatest.concurrent.Eventually import org.scalatest.concurrent.Eventually._ -import org.scalatest.{BeforeAndAfterEach, BeforeAndAfter} +import org.scalatest.{PrivateMethodTester, BeforeAndAfterEach, BeforeAndAfter} import org.scalatest.mock.MockitoSugar import org.apache.spark.streaming.scheduler._ import org.apache.spark.util.{ThreadUtils, ManualClock, Utils} -import org.apache.spark.{SparkException, SparkConf, SparkFunSuite} +import org.apache.spark.{SparkConf, SparkFunSuite} /** Common tests for WriteAheadLogs that we would like to test with different configurations. */ abstract class CommonWriteAheadLogTests( @@ -315,7 +314,11 @@ class FileBasedWriteAheadLogWithFileCloseAfterWriteSuite class BatchedWriteAheadLogSuite extends CommonWriteAheadLogTests( allowBatching = true, closeFileAfterWrite = false, - "BatchedWriteAheadLog") with MockitoSugar with BeforeAndAfterEach with Eventually { + "BatchedWriteAheadLog") + with MockitoSugar + with BeforeAndAfterEach + with Eventually + with PrivateMethodTester { import BatchedWriteAheadLog._ import WriteAheadLogSuite._ @@ -326,6 +329,8 @@ class BatchedWriteAheadLogSuite extends CommonWriteAheadLogTests( private var walBatchingExecutionContext: ExecutionContextExecutorService = _ private val sparkConf = new SparkConf() + private val queueLength = PrivateMethod[Int]('getQueueLength) + override def beforeEach(): Unit = { wal = mock[WriteAheadLog] walHandle = mock[WriteAheadLogRecordHandle] @@ -366,7 +371,7 @@ class BatchedWriteAheadLogSuite extends CommonWriteAheadLogTests( } // we make the write requests in separate threads so that we don't block the test thread - private def promiseWriteEvent(wal: WriteAheadLog, event: String, time: Long): Promise[Unit] = { + private def writeAsync(wal: WriteAheadLog, event: String, time: Long): Promise[Unit] = { val p = Promise[Unit]() p.completeWith(Future { val v = wal.write(event, time) @@ -375,28 +380,9 @@ class BatchedWriteAheadLogSuite extends CommonWriteAheadLogTests( p } - /** - * In order to block the writes on the writer thread, we mock the write method, and block it - * for some time with a promise. - */ - private def writeBlockingPromise(wal: WriteAheadLog): Promise[Any] = { - // we would like to block the write so that we can queue requests - val promise = Promise[Any]() - when(wal.write(any[ByteBuffer], any[Long])).thenAnswer( - new Answer[WriteAheadLogRecordHandle] { - override def answer(invocation: InvocationOnMock): WriteAheadLogRecordHandle = { - Await.ready(promise.future, 4.seconds) - walHandle - } - } - ) - promise - } - test("BatchedWriteAheadLog - name log with aggregated entries with the timestamp of last entry") { - val batchedWal = new BatchedWriteAheadLog(wal, sparkConf) - // block the write so that we can batch some records - val promise = writeBlockingPromise(wal) + val blockingWal = new BlockingWriteAheadLog(wal, walHandle) + val batchedWal = new BatchedWriteAheadLog(blockingWal, sparkConf) val event1 = "hello" val event2 = "world" @@ -406,21 +392,27 @@ class BatchedWriteAheadLogSuite extends CommonWriteAheadLogTests( // The queue.take() immediately takes the 3, and there is nothing left in the queue at that // moment. Then the promise blocks the writing of 3. The rest get queued. - promiseWriteEvent(batchedWal, event1, 3L) - // rest of the records will be batched while it takes 3 to get written - promiseWriteEvent(batchedWal, event2, 5L) - promiseWriteEvent(batchedWal, event3, 8L) - promiseWriteEvent(batchedWal, event4, 12L) - promiseWriteEvent(batchedWal, event5, 10L) + writeAsync(batchedWal, event1, 3L) + eventually(timeout(1 second)) { + assert(blockingWal.isBlocked) + assert(batchedWal.invokePrivate(queueLength()) === 0) + } + // rest of the records will be batched while it takes time for 3 to get written + writeAsync(batchedWal, event2, 5L) + writeAsync(batchedWal, event3, 8L) + writeAsync(batchedWal, event4, 12L) + writeAsync(batchedWal, event5, 10L) eventually(timeout(1 second)) { assert(walBatchingThreadPool.getActiveCount === 5) + assert(batchedWal.invokePrivate(queueLength()) === 4) } - promise.success(true) + blockingWal.allowWrite() val buffer1 = wrapArrayArrayByte(Array(event1)) val buffer2 = wrapArrayArrayByte(Array(event2, event3, event4, event5)) eventually(timeout(1 second)) { + assert(batchedWal.invokePrivate(queueLength()) === 0) verify(wal, times(1)).write(meq(buffer1), meq(3L)) // the file name should be the timestamp of the last record, as events should be naturally // in order of timestamp, and we need the last element. @@ -437,27 +429,32 @@ class BatchedWriteAheadLogSuite extends CommonWriteAheadLogTests( } test("BatchedWriteAheadLog - fail everything in queue during shutdown") { - val batchedWal = new BatchedWriteAheadLog(wal, sparkConf) + val blockingWal = new BlockingWriteAheadLog(wal, walHandle) + val batchedWal = new BatchedWriteAheadLog(blockingWal, sparkConf) - // block the write so that we can batch some records - writeBlockingPromise(wal) - - val event1 = ("hello", 3L) - val event2 = ("world", 5L) - val event3 = ("this", 8L) - val event4 = ("is", 9L) - val event5 = ("doge", 10L) + val event1 = "hello" + val event2 = "world" + val event3 = "this" // The queue.take() immediately takes the 3, and there is nothing left in the queue at that // moment. Then the promise blocks the writing of 3. The rest get queued. - val writePromises = Seq(event1, event2, event3, event4, event5).map { event => - promiseWriteEvent(batchedWal, event._1, event._2) + val promise1 = writeAsync(batchedWal, event1, 3L) + eventually(timeout(1 second)) { + assert(blockingWal.isBlocked) + assert(batchedWal.invokePrivate(queueLength()) === 0) } + // rest of the records will be batched while it takes time for 3 to get written + val promise2 = writeAsync(batchedWal, event2, 5L) + val promise3 = writeAsync(batchedWal, event3, 8L) eventually(timeout(1 second)) { - assert(walBatchingThreadPool.getActiveCount === 5) + assert(walBatchingThreadPool.getActiveCount === 3) + assert(blockingWal.isBlocked) + assert(batchedWal.invokePrivate(queueLength()) === 2) // event1 is being written } + val writePromises = Seq(promise1, promise2, promise3) + batchedWal.close() eventually(timeout(1 second)) { assert(writePromises.forall(_.isCompleted)) @@ -641,4 +638,37 @@ object WriteAheadLogSuite { def wrapArrayArrayByte[T](records: Array[T]): ByteBuffer = { ByteBuffer.wrap(Utils.serialize[Array[Array[Byte]]](records.map(Utils.serialize[T]))) } + + /** + * A wrapper WriteAheadLog that blocks the write function to allow batching with the + * BatchedWriteAheadLog. + */ + class BlockingWriteAheadLog( + wal: WriteAheadLog, + handle: WriteAheadLogRecordHandle) extends WriteAheadLog { + @volatile private var isWriteCalled: Boolean = false + @volatile private var blockWrite: Boolean = true + + override def write(record: ByteBuffer, time: Long): WriteAheadLogRecordHandle = { + isWriteCalled = true + eventually(Eventually.timeout(2 second)) { + assert(!blockWrite) + } + wal.write(record, time) + isWriteCalled = false + handle + } + override def read(segment: WriteAheadLogRecordHandle): ByteBuffer = wal.read(segment) + override def readAll(): JIterator[ByteBuffer] = wal.readAll() + override def clean(threshTime: Long, waitForCompletion: Boolean): Unit = { + wal.clean(threshTime, waitForCompletion) + } + override def close(): Unit = wal.close() + + def allowWrite(): Unit = { + blockWrite = false + } + + def isBlocked: Boolean = isWriteCalled + } } |