Diffstat (limited to 'streaming')
-rw-r--r--  streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala |  6
-rw-r--r--  streaming/src/main/scala/org/apache/spark/streaming/util/WriteAheadLogUtils.scala     | 15
-rw-r--r--  streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala     | 32
3 files changed, 44 insertions, 9 deletions
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala
index 9f4a4d6806..bc3f2486c2 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala
@@ -47,7 +47,8 @@ private[streaming] class FileBasedWriteAheadLog(
logDirectory: String,
hadoopConf: Configuration,
rollingIntervalSecs: Int,
- maxFailures: Int
+ maxFailures: Int,
+ closeFileAfterWrite: Boolean
) extends WriteAheadLog with Logging {
import FileBasedWriteAheadLog._
@@ -80,6 +81,9 @@ private[streaming] class FileBasedWriteAheadLog(
while (!succeeded && failures < maxFailures) {
try {
fileSegment = getLogWriter(time).write(byteBuffer)
+ if (closeFileAfterWrite) {
+ resetWriter()
+ }
succeeded = true
} catch {
case ex: Exception =>
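
Note on the branch added above: when closeFileAfterWrite is set, resetWriter() is invoked after every successful write. The body of resetWriter() is not part of this diff; a minimal sketch of what it is assumed to do (close and drop the current writer so the next getLogWriter(time) call opens a fresh file) would look roughly like this:

// Hypothetical sketch only; the field name currentLogWriter is an assumption, not shown in this diff.
private def resetWriter(): Unit = synchronized {
  if (currentLogWriter != null) {
    currentLogWriter.close()   // flush and close the underlying stream
    currentLogWriter = null    // forces getLogWriter(time) to open a new log file
  }
}
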
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/WriteAheadLogUtils.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/WriteAheadLogUtils.scala
index 7f6ff12c58..0ea970e61b 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/util/WriteAheadLogUtils.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/util/WriteAheadLogUtils.scala
@@ -31,11 +31,15 @@ private[streaming] object WriteAheadLogUtils extends Logging {
val RECEIVER_WAL_ROLLING_INTERVAL_CONF_KEY =
"spark.streaming.receiver.writeAheadLog.rollingIntervalSecs"
val RECEIVER_WAL_MAX_FAILURES_CONF_KEY = "spark.streaming.receiver.writeAheadLog.maxFailures"
+ val RECEIVER_WAL_CLOSE_AFTER_WRITE_CONF_KEY =
+ "spark.streaming.receiver.writeAheadLog.closeFileAfterWrite"
val DRIVER_WAL_CLASS_CONF_KEY = "spark.streaming.driver.writeAheadLog.class"
val DRIVER_WAL_ROLLING_INTERVAL_CONF_KEY =
"spark.streaming.driver.writeAheadLog.rollingIntervalSecs"
val DRIVER_WAL_MAX_FAILURES_CONF_KEY = "spark.streaming.driver.writeAheadLog.maxFailures"
+ val DRIVER_WAL_CLOSE_AFTER_WRITE_CONF_KEY =
+ "spark.streaming.driver.writeAheadLog.closeFileAfterWrite"
val DEFAULT_ROLLING_INTERVAL_SECS = 60
val DEFAULT_MAX_FAILURES = 3
@@ -60,6 +64,14 @@ private[streaming] object WriteAheadLogUtils extends Logging {
}
}
+ def shouldCloseFileAfterWrite(conf: SparkConf, isDriver: Boolean): Boolean = {
+ if (isDriver) {
+ conf.getBoolean(DRIVER_WAL_CLOSE_AFTER_WRITE_CONF_KEY, defaultValue = false)
+ } else {
+ conf.getBoolean(RECEIVER_WAL_CLOSE_AFTER_WRITE_CONF_KEY, defaultValue = false)
+ }
+ }
+
/**
* Create a WriteAheadLog for the driver. If configured with a custom WAL class, it will try
* to create an instance of that class; otherwise it will create the default FileBasedWriteAheadLog.
@@ -113,7 +125,8 @@ private[streaming] object WriteAheadLogUtils extends Logging {
}
}.getOrElse {
new FileBasedWriteAheadLog(sparkConf, fileWalLogDirectory, fileWalHadoopConf,
- getRollingIntervalSecs(sparkConf, isDriver), getMaxFailures(sparkConf, isDriver))
+ getRollingIntervalSecs(sparkConf, isDriver), getMaxFailures(sparkConf, isDriver),
+ shouldCloseFileAfterWrite(sparkConf, isDriver))
}
}
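
The two new keys are read by shouldCloseFileAfterWrite and default to false, so existing deployments are unaffected. A minimal sketch of enabling the flag per role via SparkConf (not part of this diff; assumes code inside the org.apache.spark.streaming package, since these utilities are private[streaming]):

import org.apache.spark.SparkConf

// Sketch: the flag is read separately for the receiver and the driver side.
val conf = new SparkConf()
  .set("spark.streaming.receiver.writeAheadLog.closeFileAfterWrite", "true")
WriteAheadLogUtils.shouldCloseFileAfterWrite(conf, isDriver = false)  // true
WriteAheadLogUtils.shouldCloseFileAfterWrite(conf, isDriver = true)   // false (key unset)
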
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala
index 5e49fd0076..93ae41a3d2 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala
@@ -203,6 +203,21 @@ class WriteAheadLogSuite extends SparkFunSuite with BeforeAndAfter {
assert(writtenData === dataToWrite)
}
+ test("FileBasedWriteAheadLog - close after write flag") {
+ // Write data with the close-after-write flag using the WriteAheadLog class
+ val numFiles = 3
+ val dataToWrite = Seq.tabulate(numFiles)(_.toString)
+ // total advance time is less than 1000 ms, so the log shouldn't be rolled; each file is closed by the flag instead
+ writeDataUsingWriteAheadLog(testDir, dataToWrite, closeLog = false, clockAdvanceTime = 100,
+ closeFileAfterWrite = true)
+
+ // Read data manually to verify the written data
+ val logFiles = getLogFilesInDirectory(testDir)
+ assert(logFiles.size === numFiles)
+ val writtenData = logFiles.flatMap { file => readDataManually(file)}
+ assert(writtenData === dataToWrite)
+ }
+
test("FileBasedWriteAheadLog - read rotating logs") {
// Write data manually for testing reading through WriteAheadLog
val writtenData = (1 to 10).map { i =>
@@ -296,8 +311,8 @@ class WriteAheadLogSuite extends SparkFunSuite with BeforeAndAfter {
assert(!nonexistentTempPath.exists())
val writtenSegment = writeDataManually(generateRandomData(), testFile)
- val wal = new FileBasedWriteAheadLog(
- new SparkConf(), tempDir.getAbsolutePath, new Configuration(), 1, 1)
+ val wal = new FileBasedWriteAheadLog(new SparkConf(), tempDir.getAbsolutePath,
+ new Configuration(), 1, 1, closeFileAfterWrite = false)
assert(!nonexistentTempPath.exists(), "Directory created just by creating log object")
wal.read(writtenSegment.head)
assert(!nonexistentTempPath.exists(), "Directory created just by attempting to read segment")
@@ -356,14 +371,16 @@ object WriteAheadLogSuite {
logDirectory: String,
data: Seq[String],
manualClock: ManualClock = new ManualClock,
- closeLog: Boolean = true
- ): FileBasedWriteAheadLog = {
+ closeLog: Boolean = true,
+ clockAdvanceTime: Int = 500,
+ closeFileAfterWrite: Boolean = false): FileBasedWriteAheadLog = {
if (manualClock.getTimeMillis() < 100000) manualClock.setTime(10000)
- val wal = new FileBasedWriteAheadLog(new SparkConf(), logDirectory, hadoopConf, 1, 1)
+ val wal = new FileBasedWriteAheadLog(new SparkConf(), logDirectory, hadoopConf, 1, 1,
+ closeFileAfterWrite)
// Ensure that 500 does not get sorted after 2000, so put a high base value.
data.foreach { item =>
- manualClock.advance(500)
+ manualClock.advance(clockAdvanceTime)
wal.write(item, manualClock.getTimeMillis())
}
if (closeLog) wal.close()
@@ -418,7 +435,8 @@ object WriteAheadLogSuite {
/** Read all the data in the log file in a directory using the WriteAheadLog class. */
def readDataUsingWriteAheadLog(logDirectory: String): Seq[String] = {
- val wal = new FileBasedWriteAheadLog(new SparkConf(), logDirectory, hadoopConf, 1, 1)
+ val wal = new FileBasedWriteAheadLog(new SparkConf(), logDirectory, hadoopConf, 1, 1,
+ closeFileAfterWrite = false)
val data = wal.readAll().asScala.map(byteBufferToString).toSeq
wal.close()
data
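
End to end, a driver-side caller picks the flag up from SparkConf when the log is created through the factory described in WriteAheadLogUtils. A rough sketch, again assuming code inside the org.apache.spark.streaming package; the factory name createLogForDriver is inferred from the doc comment in this diff and should be checked against the class:

// Sketch only: names outside this diff (createLogForDriver) are assumptions.
val conf = new SparkConf()
  .set("spark.streaming.driver.writeAheadLog.closeFileAfterWrite", "true")
val wal = WriteAheadLogUtils.createLogForDriver(conf, "/tmp/wal-dir", new Configuration())
wal.write(java.nio.ByteBuffer.wrap("record".getBytes("UTF-8")), System.currentTimeMillis())  // file is closed right after this write
wal.close()
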