diff options
Diffstat (limited to 'external/flume-sink/src/main')
3 files changed, 35 insertions, 1 deletions
diff --git a/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkAvroCallbackHandler.scala b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkAvroCallbackHandler.scala index e77cf7bfa5..3c656a381b 100644 --- a/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkAvroCallbackHandler.scala +++ b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkAvroCallbackHandler.scala @@ -16,7 +16,7 @@ */ package org.apache.spark.streaming.flume.sink -import java.util.concurrent.{ConcurrentHashMap, Executors} +import java.util.concurrent.{CountDownLatch, ConcurrentHashMap, Executors} import java.util.concurrent.atomic.AtomicLong import scala.collection.JavaConversions._ @@ -58,8 +58,12 @@ private[flume] class SparkAvroCallbackHandler(val threads: Int, val channel: Cha private val seqBase = RandomStringUtils.randomAlphanumeric(8) private val seqCounter = new AtomicLong(0) + @volatile private var stopped = false + @volatile private var isTest = false + private var testLatch: CountDownLatch = null + /** * Returns a bunch of events to Spark over Avro RPC. * @param n Maximum number of events to return in a batch @@ -90,6 +94,9 @@ private[flume] class SparkAvroCallbackHandler(val threads: Int, val channel: Cha val processor = new TransactionProcessor( channel, seq, n, transactionTimeout, backOffInterval, this) sequenceNumberToProcessor.put(seq, processor) + if (isTest) { + processor.countDownWhenBatchAcked(testLatch) + } Some(processor) } else { None @@ -141,6 +148,11 @@ private[flume] class SparkAvroCallbackHandler(val threads: Int, val channel: Cha } } + private[sink] def countDownWhenBatchAcked(latch: CountDownLatch) { + testLatch = latch + isTest = true + } + /** * Shuts down the executor used to process transactions. */ diff --git a/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkSink.scala b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkSink.scala index 98ae7d783a..14dffb15fe 100644 --- a/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkSink.scala +++ b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkSink.scala @@ -138,6 +138,16 @@ class SparkSink extends AbstractSink with Logging with Configurable { throw new RuntimeException("Server was not started!") ) } + + /** + * Pass in a [[CountDownLatch]] for testing purposes. This batch is counted down when each + * batch is received. The test can simply call await on this latch till the expected number of + * batches are received. + * @param latch + */ + private[flume] def countdownWhenBatchReceived(latch: CountDownLatch) { + handler.foreach(_.countDownWhenBatchAcked(latch)) + } } /** diff --git a/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/TransactionProcessor.scala b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/TransactionProcessor.scala index 13f3aa94be..ea45b14294 100644 --- a/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/TransactionProcessor.scala +++ b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/TransactionProcessor.scala @@ -62,6 +62,10 @@ private class TransactionProcessor(val channel: Channel, val seqNum: String, @volatile private var stopped = false + @volatile private var isTest = false + + private var testLatch: CountDownLatch = null + // The transaction that this processor would handle var txOpt: Option[Transaction] = None @@ -182,6 +186,9 @@ private class TransactionProcessor(val channel: Channel, val seqNum: String, rollbackAndClose(tx, close = false) // tx will be closed later anyway } finally { tx.close() + if (isTest) { + testLatch.countDown() + } } } else { logWarning("Spark could not commit transaction, NACK received. Rolling back transaction.") @@ -237,4 +244,9 @@ private class TransactionProcessor(val channel: Channel, val seqNum: String, processAckOrNack() null } + + private[sink] def countDownWhenBatchAcked(latch: CountDownLatch) { + testLatch = latch + isTest = true + } } |