Diffstat (limited to 'streaming')
 -rw-r--r--  streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala |  62
 -rw-r--r--  streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala      |  25
 -rw-r--r--  streaming/src/main/scala/org/apache/spark/streaming/util/BatchedWriteAheadLog.scala      | 223
 -rw-r--r--  streaming/src/main/scala/org/apache/spark/streaming/util/WriteAheadLogUtils.scala        |  21
 -rw-r--r--  streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala        | 506
 -rw-r--r--  streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogUtilsSuite.scala   | 122
 6 files changed, 767 insertions(+), 192 deletions(-)
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala
index f2711d1355..500dc70c98 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala
@@ -22,12 +22,13 @@ import java.nio.ByteBuffer
import scala.collection.JavaConverters._
import scala.collection.mutable
import scala.language.implicitConversions
+import scala.util.control.NonFatal
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path
import org.apache.spark.streaming.Time
-import org.apache.spark.streaming.util.{WriteAheadLog, WriteAheadLogUtils}
+import org.apache.spark.streaming.util.{BatchedWriteAheadLog, WriteAheadLog, WriteAheadLogUtils}
import org.apache.spark.util.{Clock, Utils}
import org.apache.spark.{Logging, SparkConf}
@@ -41,7 +42,6 @@ private[streaming] case class BatchAllocationEvent(time: Time, allocatedBlocks:
private[streaming] case class BatchCleanupEvent(times: Seq[Time])
extends ReceivedBlockTrackerLogEvent
-
/** Class representing the blocks of all the streams allocated to a batch */
private[streaming]
case class AllocatedBlocks(streamIdToAllocatedBlocks: Map[Int, Seq[ReceivedBlockInfo]]) {
@@ -82,15 +82,22 @@ private[streaming] class ReceivedBlockTracker(
}
/** Add received block. This event will get written to the write ahead log (if enabled). */
- def addBlock(receivedBlockInfo: ReceivedBlockInfo): Boolean = synchronized {
+ def addBlock(receivedBlockInfo: ReceivedBlockInfo): Boolean = {
try {
- writeToLog(BlockAdditionEvent(receivedBlockInfo))
- getReceivedBlockQueue(receivedBlockInfo.streamId) += receivedBlockInfo
- logDebug(s"Stream ${receivedBlockInfo.streamId} received " +
- s"block ${receivedBlockInfo.blockStoreResult.blockId}")
- true
+ val writeResult = writeToLog(BlockAdditionEvent(receivedBlockInfo))
+ if (writeResult) {
+ synchronized {
+ getReceivedBlockQueue(receivedBlockInfo.streamId) += receivedBlockInfo
+ }
+ logDebug(s"Stream ${receivedBlockInfo.streamId} received " +
+ s"block ${receivedBlockInfo.blockStoreResult.blockId}")
+ } else {
+ logDebug(s"Failed to acknowledge stream ${receivedBlockInfo.streamId} receiving " +
+ s"block ${receivedBlockInfo.blockStoreResult.blockId} in the Write Ahead Log.")
+ }
+ writeResult
} catch {
- case e: Exception =>
+ case NonFatal(e) =>
logError(s"Error adding block $receivedBlockInfo", e)
false
}
@@ -106,10 +113,12 @@ private[streaming] class ReceivedBlockTracker(
(streamId, getReceivedBlockQueue(streamId).dequeueAll(x => true))
}.toMap
val allocatedBlocks = AllocatedBlocks(streamIdToBlocks)
- writeToLog(BatchAllocationEvent(batchTime, allocatedBlocks))
- timeToAllocatedBlocks(batchTime) = allocatedBlocks
- lastAllocatedBatchTime = batchTime
- allocatedBlocks
+ if (writeToLog(BatchAllocationEvent(batchTime, allocatedBlocks))) {
+ timeToAllocatedBlocks.put(batchTime, allocatedBlocks)
+ lastAllocatedBatchTime = batchTime
+ } else {
+ logInfo(s"Possibly processed batch $batchTime need to be processed again in WAL recovery")
+ }
} else {
// This situation occurs when:
// 1. WAL is ended with BatchAllocationEvent, but without BatchCleanupEvent,
@@ -157,9 +166,12 @@ private[streaming] class ReceivedBlockTracker(
require(cleanupThreshTime.milliseconds < clock.getTimeMillis())
val timesToCleanup = timeToAllocatedBlocks.keys.filter { _ < cleanupThreshTime }.toSeq
logInfo("Deleting batches " + timesToCleanup)
- writeToLog(BatchCleanupEvent(timesToCleanup))
- timeToAllocatedBlocks --= timesToCleanup
- writeAheadLogOption.foreach(_.clean(cleanupThreshTime.milliseconds, waitForCompletion))
+ if (writeToLog(BatchCleanupEvent(timesToCleanup))) {
+ timeToAllocatedBlocks --= timesToCleanup
+ writeAheadLogOption.foreach(_.clean(cleanupThreshTime.milliseconds, waitForCompletion))
+ } else {
+ logWarning("Failed to acknowledge batch clean up in the Write Ahead Log.")
+ }
}
/** Stop the block tracker. */
@@ -185,8 +197,8 @@ private[streaming] class ReceivedBlockTracker(
logTrace(s"Recovery: Inserting allocated batch for time $batchTime to " +
s"${allocatedBlocks.streamIdToAllocatedBlocks}")
streamIdToUnallocatedBlockQueues.values.foreach { _.clear() }
- lastAllocatedBatchTime = batchTime
timeToAllocatedBlocks.put(batchTime, allocatedBlocks)
+ lastAllocatedBatchTime = batchTime
}
// Cleanup the batch allocations
@@ -213,12 +225,20 @@ private[streaming] class ReceivedBlockTracker(
}
/** Write an update to the tracker to the write ahead log */
- private def writeToLog(record: ReceivedBlockTrackerLogEvent) {
+ private def writeToLog(record: ReceivedBlockTrackerLogEvent): Boolean = {
if (isWriteAheadLogEnabled) {
- logDebug(s"Writing to log $record")
- writeAheadLogOption.foreach { logManager =>
- logManager.write(ByteBuffer.wrap(Utils.serialize(record)), clock.getTimeMillis())
+ logTrace(s"Writing record: $record")
+ try {
+ writeAheadLogOption.get.write(ByteBuffer.wrap(Utils.serialize(record)),
+ clock.getTimeMillis())
+ true
+ } catch {
+ case NonFatal(e) =>
+ logWarning(s"Exception thrown while writing record: $record to the WriteAheadLog.", e)
+ false
}
+ } else {
+ true
}
}
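
Note: the ReceivedBlockTracker change above turns writeToLog into a Boolean-returning call so that in-memory state is only mutated once the WAL write has succeeded. A minimal standalone sketch of that "write first, update state on success" pattern; the names WalFirstUpdateSketch, addEntry and walWrite are hypothetical and not part of Spark:

object WalFirstUpdateSketch {
  import scala.collection.mutable.ArrayBuffer
  import scala.util.control.NonFatal

  // walWrite stands in for ReceivedBlockTracker.writeToLog and may throw on failure.
  def addEntry(entry: String, state: ArrayBuffer[String])(walWrite: String => Unit): Boolean = {
    try {
      walWrite(entry)                        // attempt the durable write first
      state.synchronized { state += entry }  // mutate in-memory state only after the write succeeds
      true
    } catch {
      case NonFatal(_) => false              // report the failure instead of silently keeping the entry
    }
  }
}
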
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala
index b183d856f5..ea5d12b50f 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala
@@ -20,7 +20,7 @@ package org.apache.spark.streaming.scheduler
import java.util.concurrent.{CountDownLatch, TimeUnit}
import scala.collection.mutable.HashMap
-import scala.concurrent.ExecutionContext
+import scala.concurrent.{Future, ExecutionContext}
import scala.language.existentials
import scala.util.{Failure, Success}
@@ -437,7 +437,12 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false
// TODO Remove this thread pool after https://github.com/apache/spark/issues/7385 is merged
private val submitJobThreadPool = ExecutionContext.fromExecutorService(
- ThreadUtils.newDaemonCachedThreadPool("submit-job-thead-pool"))
+ ThreadUtils.newDaemonCachedThreadPool("submit-job-thread-pool"))
+
+ private val walBatchingThreadPool = ExecutionContext.fromExecutorService(
+ ThreadUtils.newDaemonCachedThreadPool("wal-batching-thread-pool"))
+
+ @volatile private var active: Boolean = true
override def receive: PartialFunction[Any, Unit] = {
// Local messages
@@ -488,7 +493,19 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false
registerReceiver(streamId, typ, host, executorId, receiverEndpoint, context.senderAddress)
context.reply(successful)
case AddBlock(receivedBlockInfo) =>
- context.reply(addBlock(receivedBlockInfo))
+ if (WriteAheadLogUtils.isBatchingEnabled(ssc.conf, isDriver = true)) {
+ walBatchingThreadPool.execute(new Runnable {
+ override def run(): Unit = Utils.tryLogNonFatalError {
+ if (active) {
+ context.reply(addBlock(receivedBlockInfo))
+ } else {
+ throw new IllegalStateException("ReceiverTracker RpcEndpoint shut down.")
+ }
+ }
+ })
+ } else {
+ context.reply(addBlock(receivedBlockInfo))
+ }
case DeregisterReceiver(streamId, message, error) =>
deregisterReceiver(streamId, message, error)
context.reply(true)
@@ -599,6 +616,8 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false
override def onStop(): Unit = {
submitJobThreadPool.shutdownNow()
+ active = false
+ walBatchingThreadPool.shutdown()
}
/**
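
Note: with batching enabled, the ReceiverTracker answers AddBlock from a dedicated thread pool so the RPC dispatcher thread is not held while the batched WAL write waits for its flush. A rough sketch of that hand-off outside Spark; AsyncReplySketch, handleAddBlock, doWrite and reply are placeholders, not Spark APIs:

object AsyncReplySketch {
  import java.util.concurrent.Executors
  import scala.concurrent.ExecutionContext

  private val pool = ExecutionContext.fromExecutorService(Executors.newCachedThreadPool())
  @volatile private var active = true

  // doWrite stands in for addBlock (a potentially blocking WAL write);
  // reply stands in for RpcCallContext.reply.
  def handleAddBlock(doWrite: () => Boolean, reply: Boolean => Unit): Unit = {
    pool.execute(new Runnable {
      override def run(): Unit = {
        if (active) reply(doWrite())
        else throw new IllegalStateException("endpoint has been shut down")
      }
    })
  }

  def stop(): Unit = { active = false; pool.shutdown() }
}
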
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/BatchedWriteAheadLog.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/BatchedWriteAheadLog.scala
new file mode 100644
index 0000000000..9727ed2ba1
--- /dev/null
+++ b/streaming/src/main/scala/org/apache/spark/streaming/util/BatchedWriteAheadLog.scala
@@ -0,0 +1,223 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.streaming.util
+
+import java.nio.ByteBuffer
+import java.util.concurrent.LinkedBlockingQueue
+import java.util.{Iterator => JIterator}
+
+import scala.collection.JavaConverters._
+import scala.collection.mutable.ArrayBuffer
+import scala.concurrent.{Await, Promise}
+import scala.concurrent.duration._
+import scala.util.control.NonFatal
+
+import org.apache.spark.{Logging, SparkConf}
+import org.apache.spark.util.Utils
+
+/**
+ * A wrapper for a WriteAheadLog that batches records before writing data. Handles aggregation
+ * during writes, and de-aggregation in the `readAll` method. The end consumer has to handle
+ * de-aggregation after the `read` method. In addition, the `WriteAheadLogRecordHandle` returned
+ * after the write will contain the batch of records rather than individual records.
+ *
+ * When writing a batch of records, the `time` passed to the `wrappedLog` will be the timestamp
+ * of the latest record in the batch. This is very important in achieving correctness. Consider the
+ * following example:
+ * We receive records with timestamps 1, 3, 5, 7. We use "log-1" as the filename. Once we receive
+ * a clean up request for timestamp 3, we would clean up the file "log-1", and lose data regarding
+ * 5 and 7.
+ *
+ * This means the caller can assume the same write semantics as any other WriteAheadLog
+ * implementation despite the batching in the background - when the write() returns, the data is
+ * written to the WAL and is durable. To take advantage of the batching, the caller can write from
+ * multiple threads, each of which will stay blocked until the corresponding data has been written.
+ *
+ * All other methods of the WriteAheadLog interface will be passed on to the wrapped WriteAheadLog.
+ */
+private[util] class BatchedWriteAheadLog(val wrappedLog: WriteAheadLog, conf: SparkConf)
+ extends WriteAheadLog with Logging {
+
+ import BatchedWriteAheadLog._
+
+ private val walWriteQueue = new LinkedBlockingQueue[Record]()
+
+ // Whether the writer thread is active
+ @volatile private var active: Boolean = true
+ private val buffer = new ArrayBuffer[Record]()
+
+ private val batchedWriterThread = startBatchedWriterThread()
+
+ /**
+ * Write a byte buffer to the log file. This method adds the byteBuffer to a queue and blocks
+ * until the record is properly written by the parent.
+ */
+ override def write(byteBuffer: ByteBuffer, time: Long): WriteAheadLogRecordHandle = {
+ val promise = Promise[WriteAheadLogRecordHandle]()
+ val putSuccessfully = synchronized {
+ if (active) {
+ walWriteQueue.offer(Record(byteBuffer, time, promise))
+ true
+ } else {
+ false
+ }
+ }
+ if (putSuccessfully) {
+ Await.result(promise.future, WriteAheadLogUtils.getBatchingTimeout(conf).milliseconds)
+ } else {
+ throw new IllegalStateException("close() was called on BatchedWriteAheadLog before " +
+ s"write request with time $time could be fulfilled.")
+ }
+ }
+
+ /**
+ * This method is not supported as the resulting ByteBuffer would actually require de-aggregation.
+ * This method is primarily used in testing, and to ensure that it is not used in production,
+ * we throw an UnsupportedOperationException.
+ */
+ override def read(segment: WriteAheadLogRecordHandle): ByteBuffer = {
+ throw new UnsupportedOperationException("read() is not supported for BatchedWriteAheadLog " +
+ "as the data may require de-aggregation.")
+ }
+
+ /**
+ * Read all the existing logs from the log directory. The output of the wrapped WriteAheadLog
+ * will be de-aggregated.
+ */
+ override def readAll(): JIterator[ByteBuffer] = {
+ wrappedLog.readAll().asScala.flatMap(deaggregate).asJava
+ }
+
+ /**
+ * Delete the log files that are older than the threshold time.
+ *
+ * This method is handled by the parent WriteAheadLog.
+ */
+ override def clean(threshTime: Long, waitForCompletion: Boolean): Unit = {
+ wrappedLog.clean(threshTime, waitForCompletion)
+ }
+
+
+ /**
+ * Stop the batched writer thread, fulfill promises with failures and close the wrapped WAL.
+ */
+ override def close(): Unit = {
+ logInfo(s"BatchedWriteAheadLog shutting down at time: ${System.currentTimeMillis()}.")
+ synchronized {
+ active = false
+ }
+ batchedWriterThread.interrupt()
+ batchedWriterThread.join()
+ while (!walWriteQueue.isEmpty) {
+ val Record(_, time, promise) = walWriteQueue.poll()
+ promise.failure(new IllegalStateException("close() was called on BatchedWriteAheadLog " +
+ s"before write request with time $time could be fulfilled."))
+ }
+ wrappedLog.close()
+ }
+
+ /** Start the actual log writer on a separate thread. */
+ private def startBatchedWriterThread(): Thread = {
+ val thread = new Thread(new Runnable {
+ override def run(): Unit = {
+ while (active) {
+ try {
+ flushRecords()
+ } catch {
+ case NonFatal(e) =>
+ logWarning("Encountered exception in Batched Writer Thread.", e)
+ }
+ }
+ logInfo("BatchedWriteAheadLog Writer thread exiting.")
+ }
+ }, "BatchedWriteAheadLog Writer")
+ thread.setDaemon(true)
+ thread.start()
+ thread
+ }
+
+ /** Write all the records in the buffer to the write ahead log. */
+ private def flushRecords(): Unit = {
+ try {
+ buffer.append(walWriteQueue.take())
+ val numBatched = walWriteQueue.drainTo(buffer.asJava) + 1
+ logDebug(s"Received $numBatched records from queue")
+ } catch {
+ case _: InterruptedException =>
+ logWarning("BatchedWriteAheadLog Writer queue interrupted.")
+ }
+ try {
+ var segment: WriteAheadLogRecordHandle = null
+ if (buffer.length > 0) {
+ logDebug(s"Batched ${buffer.length} records for Write Ahead Log write")
+ // We take the latest record for the timestamp. Please refer to the class Javadoc for
+ // detailed explanation
+ val time = buffer.last.time
+ segment = wrappedLog.write(aggregate(buffer), time)
+ }
+ buffer.foreach(_.promise.success(segment))
+ } catch {
+ case e: InterruptedException =>
+ logWarning("BatchedWriteAheadLog Writer queue interrupted.", e)
+ buffer.foreach(_.promise.failure(e))
+ case NonFatal(e) =>
+ logWarning(s"BatchedWriteAheadLog Writer failed to write $buffer", e)
+ buffer.foreach(_.promise.failure(e))
+ } finally {
+ buffer.clear()
+ }
+ }
+}
+
+/** Static methods for aggregating and de-aggregating records. */
+private[util] object BatchedWriteAheadLog {
+
+ /**
+ * Wrapper class for representing the records that we will write to the WriteAheadLog. Coupled
+ * with the timestamp for the write request of the record, and the promise that will block the
+ * write request, while a separate thread is actually performing the write.
+ */
+ case class Record(data: ByteBuffer, time: Long, promise: Promise[WriteAheadLogRecordHandle])
+
+ /** Copies the byte array of a ByteBuffer. */
+ private def getByteArray(buffer: ByteBuffer): Array[Byte] = {
+ val byteArray = new Array[Byte](buffer.remaining())
+ buffer.get(byteArray)
+ byteArray
+ }
+
+ /** Aggregate multiple serialized ReceivedBlockTrackerLogEvents in a single ByteBuffer. */
+ def aggregate(records: Seq[Record]): ByteBuffer = {
+ ByteBuffer.wrap(Utils.serialize[Array[Array[Byte]]](
+ records.map(record => getByteArray(record.data)).toArray))
+ }
+
+ /**
+ * De-aggregate serialized ReceivedBlockTrackerLogEvents in a single ByteBuffer.
+ * A stream may not have used batching initially, but started using it after a restart. This
+ * method therefore needs to be backwards compatible.
+ */
+ def deaggregate(buffer: ByteBuffer): Array[ByteBuffer] = {
+ try {
+ Utils.deserialize[Array[Array[Byte]]](getByteArray(buffer)).map(ByteBuffer.wrap)
+ } catch {
+ case _: ClassCastException => // users may restart a stream with batching enabled
+ Array(buffer)
+ }
+ }
+}
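
Note: the companion object's aggregate and deaggregate are symmetric (a batch is serialized as an Array[Array[Byte]]), and deaggregate hands back the original buffer when the payload predates batching. A quick round-trip check in the spirit of the serialization test added further down; since BatchedWriteAheadLog and Utils are package-private, this is assumed to live inside a test body in Spark's own org.apache.spark.streaming.util test sources:

import java.nio.ByteBuffer

import org.apache.spark.streaming.util.BatchedWriteAheadLog
import org.apache.spark.streaming.util.BatchedWriteAheadLog.Record
import org.apache.spark.util.Utils

// Aggregate a few fake records; the Promise slot is irrelevant for aggregation, so null is fine here.
val records = Seq("a", "b", "c").zipWithIndex.map { case (s, i) =>
  Record(ByteBuffer.wrap(Utils.serialize(s)), i.toLong, null)
}
val batched = BatchedWriteAheadLog.aggregate(records)
val roundTripped =
  BatchedWriteAheadLog.deaggregate(batched).map(b => Utils.deserialize[String](b.array()))
assert(roundTripped.toSeq === Seq("a", "b", "c"))

// Backwards compatibility: a record written before batching was enabled does not deserialize as
// Array[Array[Byte]], so deaggregate catches the ClassCastException and returns the original buffer.
val legacy = ByteBuffer.wrap(Utils.serialize("written-before-upgrade"))
assert(BatchedWriteAheadLog.deaggregate(legacy).toSeq === Seq(legacy))
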
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/WriteAheadLogUtils.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/WriteAheadLogUtils.scala
index 0ea970e61b..731a369fc9 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/util/WriteAheadLogUtils.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/util/WriteAheadLogUtils.scala
@@ -38,6 +38,8 @@ private[streaming] object WriteAheadLogUtils extends Logging {
val DRIVER_WAL_ROLLING_INTERVAL_CONF_KEY =
"spark.streaming.driver.writeAheadLog.rollingIntervalSecs"
val DRIVER_WAL_MAX_FAILURES_CONF_KEY = "spark.streaming.driver.writeAheadLog.maxFailures"
+ val DRIVER_WAL_BATCHING_CONF_KEY = "spark.streaming.driver.writeAheadLog.allowBatching"
+ val DRIVER_WAL_BATCHING_TIMEOUT_CONF_KEY = "spark.streaming.driver.writeAheadLog.batchingTimeout"
val DRIVER_WAL_CLOSE_AFTER_WRITE_CONF_KEY =
"spark.streaming.driver.writeAheadLog.closeFileAfterWrite"
@@ -64,6 +66,18 @@ private[streaming] object WriteAheadLogUtils extends Logging {
}
}
+ def isBatchingEnabled(conf: SparkConf, isDriver: Boolean): Boolean = {
+ isDriver && conf.getBoolean(DRIVER_WAL_BATCHING_CONF_KEY, defaultValue = false)
+ }
+
+ /**
+ * How long we will wait for the wrappedLog in the BatchedWriteAheadLog to write the records
+ * before we fail the write attempt to unblock receivers.
+ */
+ def getBatchingTimeout(conf: SparkConf): Long = {
+ conf.getLong(DRIVER_WAL_BATCHING_TIMEOUT_CONF_KEY, defaultValue = 5000)
+ }
+
def shouldCloseFileAfterWrite(conf: SparkConf, isDriver: Boolean): Boolean = {
if (isDriver) {
conf.getBoolean(DRIVER_WAL_CLOSE_AFTER_WRITE_CONF_KEY, defaultValue = false)
@@ -115,7 +129,7 @@ private[streaming] object WriteAheadLogUtils extends Logging {
} else {
sparkConf.getOption(RECEIVER_WAL_CLASS_CONF_KEY)
}
- classNameOption.map { className =>
+ val wal = classNameOption.map { className =>
try {
instantiateClass(
Utils.classForName(className).asInstanceOf[Class[_ <: WriteAheadLog]], sparkConf)
@@ -128,6 +142,11 @@ private[streaming] object WriteAheadLogUtils extends Logging {
getRollingIntervalSecs(sparkConf, isDriver), getMaxFailures(sparkConf, isDriver),
shouldCloseFileAfterWrite(sparkConf, isDriver))
}
+ if (isBatchingEnabled(sparkConf, isDriver)) {
+ new BatchedWriteAheadLog(wal, sparkConf)
+ } else {
+ wal
+ }
}
/** Instantiate the class, either using single arg constructor or zero arg constructor */
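
Note: batching is opt-in and driver-only; isBatchingEnabled only consults the driver key, so receiver WALs are never wrapped. A small configuration sketch using the keys introduced above (the 10000 ms timeout is just an illustrative value; the default is 5000 ms):

object BatchingConfSketch {
  import org.apache.spark.SparkConf

  // allowBatching wraps the driver-side WAL in BatchedWriteAheadLog (default: false).
  // batchingTimeout is how long a blocked write() waits for the batched flush (default: 5000 ms).
  val conf: SparkConf = new SparkConf()
    .set("spark.streaming.driver.writeAheadLog.allowBatching", "true")
    .set("spark.streaming.driver.writeAheadLog.batchingTimeout", "10000")
}
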
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala
index 93ae41a3d2..e96f4c2a29 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala
@@ -18,31 +18,47 @@ package org.apache.spark.streaming.util
import java.io._
import java.nio.ByteBuffer
-import java.util
+import java.util.concurrent.{ExecutionException, ThreadPoolExecutor}
+import java.util.concurrent.atomic.AtomicInteger
import scala.collection.JavaConverters._
import scala.collection.mutable.ArrayBuffer
+import scala.concurrent._
import scala.concurrent.duration._
import scala.language.{implicitConversions, postfixOps}
-import scala.reflect.ClassTag
+import scala.util.{Failure, Success}
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path
+import org.mockito.Matchers.{eq => meq}
+import org.mockito.Matchers._
+import org.mockito.Mockito._
+import org.mockito.invocation.InvocationOnMock
+import org.mockito.stubbing.Answer
+import org.scalatest.concurrent.Eventually
import org.scalatest.concurrent.Eventually._
-import org.scalatest.BeforeAndAfter
+import org.scalatest.{BeforeAndAfterEach, BeforeAndAfter}
+import org.scalatest.mock.MockitoSugar
-import org.apache.spark.util.{ManualClock, Utils}
-import org.apache.spark.{SparkConf, SparkException, SparkFunSuite}
+import org.apache.spark.streaming.scheduler._
+import org.apache.spark.util.{ThreadUtils, ManualClock, Utils}
+import org.apache.spark.{SparkException, SparkConf, SparkFunSuite}
-class WriteAheadLogSuite extends SparkFunSuite with BeforeAndAfter {
+/** Common tests for WriteAheadLogs that we would like to test with different configurations. */
+abstract class CommonWriteAheadLogTests(
+ allowBatching: Boolean,
+ closeFileAfterWrite: Boolean,
+ testTag: String = "")
+ extends SparkFunSuite with BeforeAndAfter {
import WriteAheadLogSuite._
- val hadoopConf = new Configuration()
- var tempDir: File = null
- var testDir: String = null
- var testFile: String = null
- var writeAheadLog: FileBasedWriteAheadLog = null
+ protected val hadoopConf = new Configuration()
+ protected var tempDir: File = null
+ protected var testDir: String = null
+ protected var testFile: String = null
+ protected var writeAheadLog: WriteAheadLog = null
+ protected def testPrefix = if (testTag != "") testTag + " - " else testTag
before {
tempDir = Utils.createTempDir()
@@ -58,49 +74,130 @@ class WriteAheadLogSuite extends SparkFunSuite with BeforeAndAfter {
Utils.deleteRecursively(tempDir)
}
- test("WriteAheadLogUtils - log selection and creation") {
- val logDir = Utils.createTempDir().getAbsolutePath()
+ test(testPrefix + "read all logs") {
+ // Write data manually for testing reading through WriteAheadLog
+ val writtenData = (1 to 10).map { i =>
+ val data = generateRandomData()
+ val file = testDir + s"/log-$i-$i"
+ writeDataManually(data, file, allowBatching)
+ data
+ }.flatten
- def assertDriverLogClass[T <: WriteAheadLog: ClassTag](conf: SparkConf): WriteAheadLog = {
- val log = WriteAheadLogUtils.createLogForDriver(conf, logDir, hadoopConf)
- assert(log.getClass === implicitly[ClassTag[T]].runtimeClass)
- log
+ val logDirectoryPath = new Path(testDir)
+ val fileSystem = HdfsUtils.getFileSystemForPath(logDirectoryPath, hadoopConf)
+ assert(fileSystem.exists(logDirectoryPath) === true)
+
+ // Read data using manager and verify
+ val readData = readDataUsingWriteAheadLog(testDir, closeFileAfterWrite, allowBatching)
+ assert(readData === writtenData)
+ }
+
+ test(testPrefix + "write logs") {
+ // Write data with rotation using WriteAheadLog class
+ val dataToWrite = generateRandomData()
+ writeDataUsingWriteAheadLog(testDir, dataToWrite, closeFileAfterWrite = closeFileAfterWrite,
+ allowBatching = allowBatching)
+
+ // Read data manually to verify the written data
+ val logFiles = getLogFilesInDirectory(testDir)
+ assert(logFiles.size > 1)
+ val writtenData = readAndDeserializeDataManually(logFiles, allowBatching)
+ assert(writtenData === dataToWrite)
+ }
+
+ test(testPrefix + "read all logs after write") {
+ // Write data with manager, recover with new manager and verify
+ val dataToWrite = generateRandomData()
+ writeDataUsingWriteAheadLog(testDir, dataToWrite, closeFileAfterWrite, allowBatching)
+ val logFiles = getLogFilesInDirectory(testDir)
+ assert(logFiles.size > 1)
+ val readData = readDataUsingWriteAheadLog(testDir, closeFileAfterWrite, allowBatching)
+ assert(dataToWrite === readData)
+ }
+
+ test(testPrefix + "clean old logs") {
+ logCleanUpTest(waitForCompletion = false)
+ }
+
+ test(testPrefix + "clean old logs synchronously") {
+ logCleanUpTest(waitForCompletion = true)
+ }
+
+ private def logCleanUpTest(waitForCompletion: Boolean): Unit = {
+ // Write data with manager, recover with new manager and verify
+ val manualClock = new ManualClock
+ val dataToWrite = generateRandomData()
+ writeAheadLog = writeDataUsingWriteAheadLog(testDir, dataToWrite, closeFileAfterWrite,
+ allowBatching, manualClock, closeLog = false)
+ val logFiles = getLogFilesInDirectory(testDir)
+ assert(logFiles.size > 1)
+
+ writeAheadLog.clean(manualClock.getTimeMillis() / 2, waitForCompletion)
+
+ if (waitForCompletion) {
+ assert(getLogFilesInDirectory(testDir).size < logFiles.size)
+ } else {
+ eventually(Eventually.timeout(1 second), interval(10 milliseconds)) {
+ assert(getLogFilesInDirectory(testDir).size < logFiles.size)
+ }
}
+ }
- def assertReceiverLogClass[T: ClassTag](conf: SparkConf): WriteAheadLog = {
- val log = WriteAheadLogUtils.createLogForReceiver(conf, logDir, hadoopConf)
- assert(log.getClass === implicitly[ClassTag[T]].runtimeClass)
- log
+ test(testPrefix + "handling file errors while reading rotating logs") {
+ // Generate a set of log files
+ val manualClock = new ManualClock
+ val dataToWrite1 = generateRandomData()
+ writeDataUsingWriteAheadLog(testDir, dataToWrite1, closeFileAfterWrite, allowBatching,
+ manualClock)
+ val logFiles1 = getLogFilesInDirectory(testDir)
+ assert(logFiles1.size > 1)
+
+
+ // Recover old files and generate a second set of log files
+ val dataToWrite2 = generateRandomData()
+ manualClock.advance(100000)
+ writeDataUsingWriteAheadLog(testDir, dataToWrite2, closeFileAfterWrite, allowBatching ,
+ manualClock)
+ val logFiles2 = getLogFilesInDirectory(testDir)
+ assert(logFiles2.size > logFiles1.size)
+
+ // Read the files and verify that all the written data can be read
+ val readData1 = readDataUsingWriteAheadLog(testDir, closeFileAfterWrite, allowBatching)
+ assert(readData1 === (dataToWrite1 ++ dataToWrite2))
+
+ // Corrupt the first set of files so that they are basically unreadable
+ logFiles1.foreach { f =>
+ val raf = new FileOutputStream(f, true).getChannel()
+ raf.truncate(1)
+ raf.close()
}
- val emptyConf = new SparkConf() // no log configuration
- assertDriverLogClass[FileBasedWriteAheadLog](emptyConf)
- assertReceiverLogClass[FileBasedWriteAheadLog](emptyConf)
-
- // Verify setting driver WAL class
- val conf1 = new SparkConf().set("spark.streaming.driver.writeAheadLog.class",
- classOf[MockWriteAheadLog0].getName())
- assertDriverLogClass[MockWriteAheadLog0](conf1)
- assertReceiverLogClass[FileBasedWriteAheadLog](conf1)
-
- // Verify setting receiver WAL class
- val receiverWALConf = new SparkConf().set("spark.streaming.receiver.writeAheadLog.class",
- classOf[MockWriteAheadLog0].getName())
- assertDriverLogClass[FileBasedWriteAheadLog](receiverWALConf)
- assertReceiverLogClass[MockWriteAheadLog0](receiverWALConf)
-
- // Verify setting receiver WAL class with 1-arg constructor
- val receiverWALConf2 = new SparkConf().set("spark.streaming.receiver.writeAheadLog.class",
- classOf[MockWriteAheadLog1].getName())
- assertReceiverLogClass[MockWriteAheadLog1](receiverWALConf2)
-
- // Verify failure setting receiver WAL class with 2-arg constructor
- intercept[SparkException] {
- val receiverWALConf3 = new SparkConf().set("spark.streaming.receiver.writeAheadLog.class",
- classOf[MockWriteAheadLog2].getName())
- assertReceiverLogClass[MockWriteAheadLog1](receiverWALConf3)
+ // Verify that the corrupted files do not prevent reading of the second set of data
+ val readData = readDataUsingWriteAheadLog(testDir, closeFileAfterWrite, allowBatching)
+ assert(readData === dataToWrite2)
+ }
+
+ test(testPrefix + "do not create directories or files unless write") {
+ val nonexistentTempPath = File.createTempFile("test", "")
+ nonexistentTempPath.delete()
+ assert(!nonexistentTempPath.exists())
+
+ val writtenSegment = writeDataManually(generateRandomData(), testFile, allowBatching)
+ val wal = createWriteAheadLog(testDir, closeFileAfterWrite, allowBatching)
+ assert(!nonexistentTempPath.exists(), "Directory created just by creating log object")
+ if (allowBatching) {
+ intercept[UnsupportedOperationException](wal.read(writtenSegment.head))
+ } else {
+ wal.read(writtenSegment.head)
}
+ assert(!nonexistentTempPath.exists(), "Directory created just by attempting to read segment")
}
+}
+
+class FileBasedWriteAheadLogSuite
+ extends CommonWriteAheadLogTests(false, false, "FileBasedWriteAheadLog") {
+
+ import WriteAheadLogSuite._
test("FileBasedWriteAheadLogWriter - writing data") {
val dataToWrite = generateRandomData()
@@ -122,7 +219,7 @@ class WriteAheadLogSuite extends SparkFunSuite with BeforeAndAfter {
test("FileBasedWriteAheadLogReader - sequentially reading data") {
val writtenData = generateRandomData()
- writeDataManually(writtenData, testFile)
+ writeDataManually(writtenData, testFile, allowBatching = false)
val reader = new FileBasedWriteAheadLogReader(testFile, hadoopConf)
val readData = reader.toSeq.map(byteBufferToString)
assert(readData === writtenData)
@@ -166,7 +263,7 @@ class WriteAheadLogSuite extends SparkFunSuite with BeforeAndAfter {
test("FileBasedWriteAheadLogRandomReader - reading data using random reader") {
// Write data manually for testing the random reader
val writtenData = generateRandomData()
- val segments = writeDataManually(writtenData, testFile)
+ val segments = writeDataManually(writtenData, testFile, allowBatching = false)
// Get a random order of these segments and read them back
val writtenDataAndSegments = writtenData.zip(segments).toSeq.permutations.take(10).flatten
@@ -190,163 +287,212 @@ class WriteAheadLogSuite extends SparkFunSuite with BeforeAndAfter {
}
reader.close()
}
+}
- test("FileBasedWriteAheadLog - write rotating logs") {
- // Write data with rotation using WriteAheadLog class
- val dataToWrite = generateRandomData()
- writeDataUsingWriteAheadLog(testDir, dataToWrite)
-
- // Read data manually to verify the written data
- val logFiles = getLogFilesInDirectory(testDir)
- assert(logFiles.size > 1)
- val writtenData = logFiles.flatMap { file => readDataManually(file)}
- assert(writtenData === dataToWrite)
- }
+abstract class CloseFileAfterWriteTests(allowBatching: Boolean, testTag: String)
+ extends CommonWriteAheadLogTests(allowBatching, closeFileAfterWrite = true, testTag) {
- test("FileBasedWriteAheadLog - close after write flag") {
+ import WriteAheadLogSuite._
+ test(testPrefix + "close after write flag") {
// Write data with rotation using WriteAheadLog class
val numFiles = 3
val dataToWrite = Seq.tabulate(numFiles)(_.toString)
// total advance time is less than 1000, therefore log shouldn't be rolled, but manually closed
writeDataUsingWriteAheadLog(testDir, dataToWrite, closeLog = false, clockAdvanceTime = 100,
- closeFileAfterWrite = true)
+ closeFileAfterWrite = true, allowBatching = allowBatching)
// Read data manually to verify the written data
val logFiles = getLogFilesInDirectory(testDir)
assert(logFiles.size === numFiles)
- val writtenData = logFiles.flatMap { file => readDataManually(file)}
+ val writtenData: Seq[String] = readAndDeserializeDataManually(logFiles, allowBatching)
assert(writtenData === dataToWrite)
}
+}
- test("FileBasedWriteAheadLog - read rotating logs") {
- // Write data manually for testing reading through WriteAheadLog
- val writtenData = (1 to 10).map { i =>
- val data = generateRandomData()
- val file = testDir + s"/log-$i-$i"
- writeDataManually(data, file)
- data
- }.flatten
+class FileBasedWriteAheadLogWithFileCloseAfterWriteSuite
+ extends CloseFileAfterWriteTests(allowBatching = false, "FileBasedWriteAheadLog")
- val logDirectoryPath = new Path(testDir)
- val fileSystem = HdfsUtils.getFileSystemForPath(logDirectoryPath, hadoopConf)
- assert(fileSystem.exists(logDirectoryPath) === true)
+class BatchedWriteAheadLogSuite extends CommonWriteAheadLogTests(
+ allowBatching = true,
+ closeFileAfterWrite = false,
+ "BatchedWriteAheadLog") with MockitoSugar with BeforeAndAfterEach with Eventually {
- // Read data using manager and verify
- val readData = readDataUsingWriteAheadLog(testDir)
- assert(readData === writtenData)
- }
+ import BatchedWriteAheadLog._
+ import WriteAheadLogSuite._
- test("FileBasedWriteAheadLog - recover past logs when creating new manager") {
- // Write data with manager, recover with new manager and verify
- val dataToWrite = generateRandomData()
- writeDataUsingWriteAheadLog(testDir, dataToWrite)
- val logFiles = getLogFilesInDirectory(testDir)
- assert(logFiles.size > 1)
- val readData = readDataUsingWriteAheadLog(testDir)
- assert(dataToWrite === readData)
+ private var wal: WriteAheadLog = _
+ private var walHandle: WriteAheadLogRecordHandle = _
+ private var walBatchingThreadPool: ThreadPoolExecutor = _
+ private var walBatchingExecutionContext: ExecutionContextExecutorService = _
+ private val sparkConf = new SparkConf()
+
+ override def beforeEach(): Unit = {
+ wal = mock[WriteAheadLog]
+ walHandle = mock[WriteAheadLogRecordHandle]
+ walBatchingThreadPool = ThreadUtils.newDaemonFixedThreadPool(8, "wal-test-thread-pool")
+ walBatchingExecutionContext = ExecutionContext.fromExecutorService(walBatchingThreadPool)
}
- test("FileBasedWriteAheadLog - clean old logs") {
- logCleanUpTest(waitForCompletion = false)
+ override def afterEach(): Unit = {
+ if (walBatchingExecutionContext != null) {
+ walBatchingExecutionContext.shutdownNow()
+ }
}
- test("FileBasedWriteAheadLog - clean old logs synchronously") {
- logCleanUpTest(waitForCompletion = true)
- }
+ test("BatchedWriteAheadLog - serializing and deserializing batched records") {
+ val events = Seq(
+ BlockAdditionEvent(ReceivedBlockInfo(0, None, None, null)),
+ BatchAllocationEvent(null, null),
+ BatchCleanupEvent(Nil)
+ )
- private def logCleanUpTest(waitForCompletion: Boolean): Unit = {
- // Write data with manager, recover with new manager and verify
- val manualClock = new ManualClock
- val dataToWrite = generateRandomData()
- writeAheadLog = writeDataUsingWriteAheadLog(testDir, dataToWrite, manualClock, closeLog = false)
- val logFiles = getLogFilesInDirectory(testDir)
- assert(logFiles.size > 1)
+ val buffers = events.map(e => Record(ByteBuffer.wrap(Utils.serialize(e)), 0L, null))
+ val batched = BatchedWriteAheadLog.aggregate(buffers)
+ val deaggregate = BatchedWriteAheadLog.deaggregate(batched).map(buffer =>
+ Utils.deserialize[ReceivedBlockTrackerLogEvent](buffer.array()))
- writeAheadLog.clean(manualClock.getTimeMillis() / 2, waitForCompletion)
+ assert(deaggregate.toSeq === events)
+ }
- if (waitForCompletion) {
- assert(getLogFilesInDirectory(testDir).size < logFiles.size)
- } else {
- eventually(timeout(1 second), interval(10 milliseconds)) {
- assert(getLogFilesInDirectory(testDir).size < logFiles.size)
- }
+ test("BatchedWriteAheadLog - failures in wrappedLog get bubbled up") {
+ when(wal.write(any[ByteBuffer], anyLong)).thenThrow(new RuntimeException("Hello!"))
+ // the BatchedWriteAheadLog should bubble up any exceptions that may have happened during writes
+ val batchedWal = new BatchedWriteAheadLog(wal, sparkConf)
+
+ intercept[RuntimeException] {
+ val buffer = mock[ByteBuffer]
+ batchedWal.write(buffer, 2L)
}
}
- test("FileBasedWriteAheadLog - handling file errors while reading rotating logs") {
- // Generate a set of log files
- val manualClock = new ManualClock
- val dataToWrite1 = generateRandomData()
- writeDataUsingWriteAheadLog(testDir, dataToWrite1, manualClock)
- val logFiles1 = getLogFilesInDirectory(testDir)
- assert(logFiles1.size > 1)
+ // we make the write requests in separate threads so that we don't block the test thread
+ private def promiseWriteEvent(wal: WriteAheadLog, event: String, time: Long): Promise[Unit] = {
+ val p = Promise[Unit]()
+ p.completeWith(Future {
+ val v = wal.write(event, time)
+ assert(v === walHandle)
+ }(walBatchingExecutionContext))
+ p
+ }
+ /**
+ * In order to block the writes on the writer thread, we mock the write method, and block it
+ * for some time with a promise.
+ */
+ private def writeBlockingPromise(wal: WriteAheadLog): Promise[Any] = {
+ // we would like to block the write so that we can queue requests
+ val promise = Promise[Any]()
+ when(wal.write(any[ByteBuffer], any[Long])).thenAnswer(
+ new Answer[WriteAheadLogRecordHandle] {
+ override def answer(invocation: InvocationOnMock): WriteAheadLogRecordHandle = {
+ Await.ready(promise.future, 4.seconds)
+ walHandle
+ }
+ }
+ )
+ promise
+ }
- // Recover old files and generate a second set of log files
- val dataToWrite2 = generateRandomData()
- manualClock.advance(100000)
- writeDataUsingWriteAheadLog(testDir, dataToWrite2, manualClock)
- val logFiles2 = getLogFilesInDirectory(testDir)
- assert(logFiles2.size > logFiles1.size)
+ test("BatchedWriteAheadLog - name log with aggregated entries with the timestamp of last entry") {
+ val batchedWal = new BatchedWriteAheadLog(wal, sparkConf)
+ // block the write so that we can batch some records
+ val promise = writeBlockingPromise(wal)
+
+ val event1 = "hello"
+ val event2 = "world"
+ val event3 = "this"
+ val event4 = "is"
+ val event5 = "doge"
+
+ // The queue.take() immediately takes the 3, and there is nothing left in the queue at that
+ // moment. Then the promise blocks the writing of 3. The rest get queued.
+ promiseWriteEvent(batchedWal, event1, 3L)
+ // rest of the records will be batched while it takes 3 to get written
+ promiseWriteEvent(batchedWal, event2, 5L)
+ promiseWriteEvent(batchedWal, event3, 8L)
+ promiseWriteEvent(batchedWal, event4, 12L)
+ promiseWriteEvent(batchedWal, event5, 10L)
+ eventually(timeout(1 second)) {
+ assert(walBatchingThreadPool.getActiveCount === 5)
+ }
+ promise.success(true)
- // Read the files and verify that all the written data can be read
- val readData1 = readDataUsingWriteAheadLog(testDir)
- assert(readData1 === (dataToWrite1 ++ dataToWrite2))
+ val buffer1 = wrapArrayArrayByte(Array(event1))
+ val buffer2 = wrapArrayArrayByte(Array(event2, event3, event4, event5))
- // Corrupt the first set of files so that they are basically unreadable
- logFiles1.foreach { f =>
- val raf = new FileOutputStream(f, true).getChannel()
- raf.truncate(1)
- raf.close()
+ eventually(timeout(1 second)) {
+ verify(wal, times(1)).write(meq(buffer1), meq(3L))
+ // the file name should be the timestamp of the last record, as events should be naturally
+ // in order of timestamp, and we need the last element.
+ verify(wal, times(1)).write(meq(buffer2), meq(10L))
}
-
- // Verify that the corrupted files do not prevent reading of the second set of data
- val readData = readDataUsingWriteAheadLog(testDir)
- assert(readData === dataToWrite2)
}
- test("FileBasedWriteAheadLog - do not create directories or files unless write") {
- val nonexistentTempPath = File.createTempFile("test", "")
- nonexistentTempPath.delete()
- assert(!nonexistentTempPath.exists())
+ test("BatchedWriteAheadLog - shutdown properly") {
+ val batchedWal = new BatchedWriteAheadLog(wal, sparkConf)
+ batchedWal.close()
+ verify(wal, times(1)).close()
- val writtenSegment = writeDataManually(generateRandomData(), testFile)
- val wal = new FileBasedWriteAheadLog(new SparkConf(), tempDir.getAbsolutePath,
- new Configuration(), 1, 1, closeFileAfterWrite = false)
- assert(!nonexistentTempPath.exists(), "Directory created just by creating log object")
- wal.read(writtenSegment.head)
- assert(!nonexistentTempPath.exists(), "Directory created just by attempting to read segment")
+ intercept[IllegalStateException](batchedWal.write(mock[ByteBuffer], 12L))
}
-}
-object WriteAheadLogSuite {
+ test("BatchedWriteAheadLog - fail everything in queue during shutdown") {
+ val batchedWal = new BatchedWriteAheadLog(wal, sparkConf)
- class MockWriteAheadLog0() extends WriteAheadLog {
- override def write(record: ByteBuffer, time: Long): WriteAheadLogRecordHandle = { null }
- override def read(handle: WriteAheadLogRecordHandle): ByteBuffer = { null }
- override def readAll(): util.Iterator[ByteBuffer] = { null }
- override def clean(threshTime: Long, waitForCompletion: Boolean): Unit = { }
- override def close(): Unit = { }
- }
+ // block the write so that we can batch some records
+ writeBlockingPromise(wal)
+
+ val event1 = ("hello", 3L)
+ val event2 = ("world", 5L)
+ val event3 = ("this", 8L)
+ val event4 = ("is", 9L)
+ val event5 = ("doge", 10L)
+
+ // The queue.take() immediately takes the 3, and there is nothing left in the queue at that
+ // moment. Then the promise blocks the writing of 3. The rest get queued.
+ val writePromises = Seq(event1, event2, event3, event4, event5).map { event =>
+ promiseWriteEvent(batchedWal, event._1, event._2)
+ }
- class MockWriteAheadLog1(val conf: SparkConf) extends MockWriteAheadLog0()
+ eventually(timeout(1 second)) {
+ assert(walBatchingThreadPool.getActiveCount === 5)
+ }
+
+ batchedWal.close()
+ eventually(timeout(1 second)) {
+ assert(writePromises.forall(_.isCompleted))
+ assert(writePromises.forall(_.future.value.get.isFailure)) // all should have failed
+ }
+ }
+}
- class MockWriteAheadLog2(val conf: SparkConf, x: Int) extends MockWriteAheadLog0()
+class BatchedWriteAheadLogWithCloseFileAfterWriteSuite
+ extends CloseFileAfterWriteTests(allowBatching = true, "BatchedWriteAheadLog")
+object WriteAheadLogSuite {
private val hadoopConf = new Configuration()
/** Write data to a file directly and return an array of the file segments written. */
- def writeDataManually(data: Seq[String], file: String): Seq[FileBasedWriteAheadLogSegment] = {
+ def writeDataManually(
+ data: Seq[String],
+ file: String,
+ allowBatching: Boolean): Seq[FileBasedWriteAheadLogSegment] = {
val segments = new ArrayBuffer[FileBasedWriteAheadLogSegment]()
val writer = HdfsUtils.getOutputStream(file, hadoopConf)
- data.foreach { item =>
+ def writeToStream(bytes: Array[Byte]): Unit = {
val offset = writer.getPos
- val bytes = Utils.serialize(item)
writer.writeInt(bytes.size)
writer.write(bytes)
segments += FileBasedWriteAheadLogSegment(file, offset, bytes.size)
}
+ if (allowBatching) {
+ writeToStream(wrapArrayArrayByte(data.toArray[String]).array())
+ } else {
+ data.foreach { item =>
+ writeToStream(Utils.serialize(item))
+ }
+ }
writer.close()
segments
}
@@ -356,8 +502,7 @@ object WriteAheadLogSuite {
*/
def writeDataUsingWriter(
filePath: String,
- data: Seq[String]
- ): Seq[FileBasedWriteAheadLogSegment] = {
+ data: Seq[String]): Seq[FileBasedWriteAheadLogSegment] = {
val writer = new FileBasedWriteAheadLogWriter(filePath, hadoopConf)
val segments = data.map {
item => writer.write(item)
@@ -370,13 +515,13 @@ object WriteAheadLogSuite {
def writeDataUsingWriteAheadLog(
logDirectory: String,
data: Seq[String],
+ closeFileAfterWrite: Boolean,
+ allowBatching: Boolean,
manualClock: ManualClock = new ManualClock,
closeLog: Boolean = true,
- clockAdvanceTime: Int = 500,
- closeFileAfterWrite: Boolean = false): FileBasedWriteAheadLog = {
+ clockAdvanceTime: Int = 500): WriteAheadLog = {
if (manualClock.getTimeMillis() < 100000) manualClock.setTime(10000)
- val wal = new FileBasedWriteAheadLog(new SparkConf(), logDirectory, hadoopConf, 1, 1,
- closeFileAfterWrite)
+ val wal = createWriteAheadLog(logDirectory, closeFileAfterWrite, allowBatching)
// Ensure that 500 does not get sorted after 2000, so put a high base value.
data.foreach { item =>
@@ -406,16 +551,16 @@ object WriteAheadLogSuite {
}
/** Read all the data from a log file directly and return the list of byte buffers. */
- def readDataManually(file: String): Seq[String] = {
+ def readDataManually[T](file: String): Seq[T] = {
val reader = HdfsUtils.getInputStream(file, hadoopConf)
- val buffer = new ArrayBuffer[String]
+ val buffer = new ArrayBuffer[T]
try {
while (true) {
// Read till EOF is thrown
val length = reader.readInt()
val bytes = new Array[Byte](length)
reader.read(bytes)
- buffer += Utils.deserialize[String](bytes)
+ buffer += Utils.deserialize[T](bytes)
}
} catch {
case ex: EOFException =>
@@ -434,15 +579,17 @@ object WriteAheadLogSuite {
}
/** Read all the data in the log file in a directory using the WriteAheadLog class. */
- def readDataUsingWriteAheadLog(logDirectory: String): Seq[String] = {
- val wal = new FileBasedWriteAheadLog(new SparkConf(), logDirectory, hadoopConf, 1, 1,
- closeFileAfterWrite = false)
+ def readDataUsingWriteAheadLog(
+ logDirectory: String,
+ closeFileAfterWrite: Boolean,
+ allowBatching: Boolean): Seq[String] = {
+ val wal = createWriteAheadLog(logDirectory, closeFileAfterWrite, allowBatching)
val data = wal.readAll().asScala.map(byteBufferToString).toSeq
wal.close()
data
}
- /** Get the log files in a direction */
+ /** Get the log files in a directory. */
def getLogFilesInDirectory(directory: String): Seq[String] = {
val logDirectoryPath = new Path(directory)
val fileSystem = HdfsUtils.getFileSystemForPath(logDirectoryPath, hadoopConf)
@@ -458,10 +605,31 @@ object WriteAheadLogSuite {
}
}
+ def createWriteAheadLog(
+ logDirectory: String,
+ closeFileAfterWrite: Boolean,
+ allowBatching: Boolean): WriteAheadLog = {
+ val sparkConf = new SparkConf
+ val wal = new FileBasedWriteAheadLog(sparkConf, logDirectory, hadoopConf, 1, 1,
+ closeFileAfterWrite)
+ if (allowBatching) new BatchedWriteAheadLog(wal, sparkConf) else wal
+ }
+
def generateRandomData(): Seq[String] = {
(1 to 100).map { _.toString }
}
+ def readAndDeserializeDataManually(logFiles: Seq[String], allowBatching: Boolean): Seq[String] = {
+ if (allowBatching) {
+ logFiles.flatMap { file =>
+ val data = readDataManually[Array[Array[Byte]]](file)
+ data.flatMap(byteArray => byteArray.map(Utils.deserialize[String]))
+ }
+ } else {
+ logFiles.flatMap { file => readDataManually[String](file)}
+ }
+ }
+
implicit def stringToByteBuffer(str: String): ByteBuffer = {
ByteBuffer.wrap(Utils.serialize(str))
}
@@ -469,4 +637,8 @@ object WriteAheadLogSuite {
implicit def byteBufferToString(byteBuffer: ByteBuffer): String = {
Utils.deserialize[String](byteBuffer.array)
}
+
+ def wrapArrayArrayByte[T](records: Array[T]): ByteBuffer = {
+ ByteBuffer.wrap(Utils.serialize[Array[Array[Byte]]](records.map(Utils.serialize[T])))
+ }
}
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogUtilsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogUtilsSuite.scala
new file mode 100644
index 0000000000..9152728191
--- /dev/null
+++ b/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogUtilsSuite.scala
@@ -0,0 +1,122 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.streaming.util
+
+import java.nio.ByteBuffer
+import java.util
+
+import scala.reflect.ClassTag
+
+import org.apache.hadoop.conf.Configuration
+
+import org.apache.spark.{SparkException, SparkConf, SparkFunSuite}
+import org.apache.spark.util.Utils
+
+class WriteAheadLogUtilsSuite extends SparkFunSuite {
+ import WriteAheadLogUtilsSuite._
+
+ private val logDir = Utils.createTempDir().getAbsolutePath()
+ private val hadoopConf = new Configuration()
+
+ def assertDriverLogClass[T <: WriteAheadLog: ClassTag](
+ conf: SparkConf,
+ isBatched: Boolean = false): WriteAheadLog = {
+ val log = WriteAheadLogUtils.createLogForDriver(conf, logDir, hadoopConf)
+ if (isBatched) {
+ assert(log.isInstanceOf[BatchedWriteAheadLog])
+ val parentLog = log.asInstanceOf[BatchedWriteAheadLog].wrappedLog
+ assert(parentLog.getClass === implicitly[ClassTag[T]].runtimeClass)
+ } else {
+ assert(log.getClass === implicitly[ClassTag[T]].runtimeClass)
+ }
+ log
+ }
+
+ def assertReceiverLogClass[T <: WriteAheadLog: ClassTag](conf: SparkConf): WriteAheadLog = {
+ val log = WriteAheadLogUtils.createLogForReceiver(conf, logDir, hadoopConf)
+ assert(log.getClass === implicitly[ClassTag[T]].runtimeClass)
+ log
+ }
+
+ test("log selection and creation") {
+
+ val emptyConf = new SparkConf() // no log configuration
+ assertDriverLogClass[FileBasedWriteAheadLog](emptyConf)
+ assertReceiverLogClass[FileBasedWriteAheadLog](emptyConf)
+
+ // Verify setting driver WAL class
+ val driverWALConf = new SparkConf().set("spark.streaming.driver.writeAheadLog.class",
+ classOf[MockWriteAheadLog0].getName())
+ assertDriverLogClass[MockWriteAheadLog0](driverWALConf)
+ assertReceiverLogClass[FileBasedWriteAheadLog](driverWALConf)
+
+ // Verify setting receiver WAL class
+ val receiverWALConf = new SparkConf().set("spark.streaming.receiver.writeAheadLog.class",
+ classOf[MockWriteAheadLog0].getName())
+ assertDriverLogClass[FileBasedWriteAheadLog](receiverWALConf)
+ assertReceiverLogClass[MockWriteAheadLog0](receiverWALConf)
+
+ // Verify setting receiver WAL class with 1-arg constructor
+ val receiverWALConf2 = new SparkConf().set("spark.streaming.receiver.writeAheadLog.class",
+ classOf[MockWriteAheadLog1].getName())
+ assertReceiverLogClass[MockWriteAheadLog1](receiverWALConf2)
+
+ // Verify failure setting receiver WAL class with 2-arg constructor
+ intercept[SparkException] {
+ val receiverWALConf3 = new SparkConf().set("spark.streaming.receiver.writeAheadLog.class",
+ classOf[MockWriteAheadLog2].getName())
+ assertReceiverLogClass[MockWriteAheadLog1](receiverWALConf3)
+ }
+ }
+
+ test("wrap WriteAheadLog in BatchedWriteAheadLog when batching is enabled") {
+ def getBatchedSparkConf: SparkConf =
+ new SparkConf().set("spark.streaming.driver.writeAheadLog.allowBatching", "true")
+
+ val justBatchingConf = getBatchedSparkConf
+ assertDriverLogClass[FileBasedWriteAheadLog](justBatchingConf, isBatched = true)
+ assertReceiverLogClass[FileBasedWriteAheadLog](justBatchingConf)
+
+ // Verify setting driver WAL class
+ val driverWALConf = getBatchedSparkConf.set("spark.streaming.driver.writeAheadLog.class",
+ classOf[MockWriteAheadLog0].getName())
+ assertDriverLogClass[MockWriteAheadLog0](driverWALConf, isBatched = true)
+ assertReceiverLogClass[FileBasedWriteAheadLog](driverWALConf)
+
+ // Verify receivers are not wrapped
+ val receiverWALConf = getBatchedSparkConf.set("spark.streaming.receiver.writeAheadLog.class",
+ classOf[MockWriteAheadLog0].getName())
+ assertDriverLogClass[FileBasedWriteAheadLog](receiverWALConf, isBatched = true)
+ assertReceiverLogClass[MockWriteAheadLog0](receiverWALConf)
+ }
+}
+
+object WriteAheadLogUtilsSuite {
+
+ class MockWriteAheadLog0() extends WriteAheadLog {
+ override def write(record: ByteBuffer, time: Long): WriteAheadLogRecordHandle = { null }
+ override def read(handle: WriteAheadLogRecordHandle): ByteBuffer = { null }
+ override def readAll(): util.Iterator[ByteBuffer] = { null }
+ override def clean(threshTime: Long, waitForCompletion: Boolean): Unit = { }
+ override def close(): Unit = { }
+ }
+
+ class MockWriteAheadLog1(val conf: SparkConf) extends MockWriteAheadLog0()
+
+ class MockWriteAheadLog2(val conf: SparkConf, x: Int) extends MockWriteAheadLog0()
+}