Diffstat (limited to 'external/flume-sink/src/main/scala/org/apache')
-rw-r--r--  external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkAvroCallbackHandler.scala  14
-rw-r--r--  external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkSink.scala  10
-rw-r--r--  external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/TransactionProcessor.scala  12
3 files changed, 35 insertions(+), 1 deletion(-)
diff --git a/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkAvroCallbackHandler.scala b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkAvroCallbackHandler.scala
index e77cf7bfa5..3c656a381b 100644
--- a/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkAvroCallbackHandler.scala
+++ b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkAvroCallbackHandler.scala
@@ -16,7 +16,7 @@
*/
package org.apache.spark.streaming.flume.sink
-import java.util.concurrent.{ConcurrentHashMap, Executors}
+import java.util.concurrent.{CountDownLatch, ConcurrentHashMap, Executors}
import java.util.concurrent.atomic.AtomicLong
import scala.collection.JavaConversions._
@@ -58,8 +58,12 @@ private[flume] class SparkAvroCallbackHandler(val threads: Int, val channel: Cha
private val seqBase = RandomStringUtils.randomAlphanumeric(8)
private val seqCounter = new AtomicLong(0)
+
@volatile private var stopped = false
+ @volatile private var isTest = false
+ private var testLatch: CountDownLatch = null
+
/**
* Returns a bunch of events to Spark over Avro RPC.
* @param n Maximum number of events to return in a batch
@@ -90,6 +94,9 @@ private[flume] class SparkAvroCallbackHandler(val threads: Int, val channel: Cha
val processor = new TransactionProcessor(
channel, seq, n, transactionTimeout, backOffInterval, this)
sequenceNumberToProcessor.put(seq, processor)
+ if (isTest) {
+ processor.countDownWhenBatchAcked(testLatch)
+ }
Some(processor)
} else {
None
@@ -141,6 +148,11 @@ private[flume] class SparkAvroCallbackHandler(val threads: Int, val channel: Cha
}
}
+ private[sink] def countDownWhenBatchAcked(latch: CountDownLatch) {
+ testLatch = latch
+ isTest = true
+ }
+
/**
* Shuts down the executor used to process transactions.
*/
diff --git a/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkSink.scala b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkSink.scala
index 98ae7d783a..14dffb15fe 100644
--- a/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkSink.scala
+++ b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkSink.scala
@@ -138,6 +138,16 @@ class SparkSink extends AbstractSink with Logging with Configurable {
throw new RuntimeException("Server was not started!")
)
}
+
+ /**
+ * Pass in a [[CountDownLatch]] for testing purposes. This latch is counted down each time a
+ * batch is received. The test can simply call await on this latch until the expected number of
+ * batches have been received.
+ * @param latch the latch to count down once per received batch
+ */
+ private[flume] def countdownWhenBatchReceived(latch: CountDownLatch) {
+ handler.foreach(_.countDownWhenBatchAcked(latch))
+ }
}
/**
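The comment on countdownWhenBatchReceived above describes the intended test flow: install a latch, then await it. A minimal sketch of that flow inside a test body, assuming an already-configured and started SparkSink instance named sink (the sink setup and the code that pushes events through the Flume channel are assumptions, not part of this diff):

import java.util.concurrent.{CountDownLatch, TimeUnit}

// Hypothetical test flow for the hook above; `sink` is an assumed
// configured-and-started SparkSink instance.
val expectedBatches = 5
val latch = new CountDownLatch(expectedBatches)
sink.countdownWhenBatchReceived(latch)

// ... push expectedBatches batches through the Flume channel and let the
// Spark receiver pull and ACK them ...

// Blocks until every batch has been ACKed, or fails after the timeout.
assert(latch.await(30, TimeUnit.SECONDS), "timed out waiting for batches")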
diff --git a/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/TransactionProcessor.scala b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/TransactionProcessor.scala
index 13f3aa94be..ea45b14294 100644
--- a/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/TransactionProcessor.scala
+++ b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/TransactionProcessor.scala
@@ -62,6 +62,10 @@ private class TransactionProcessor(val channel: Channel, val seqNum: String,
@volatile private var stopped = false
+ @volatile private var isTest = false
+
+ private var testLatch: CountDownLatch = null
+
// The transaction that this processor would handle
var txOpt: Option[Transaction] = None
@@ -182,6 +186,9 @@ private class TransactionProcessor(val channel: Channel, val seqNum: String,
rollbackAndClose(tx, close = false) // tx will be closed later anyway
} finally {
tx.close()
+ if (isTest) {
+ testLatch.countDown()
+ }
}
} else {
logWarning("Spark could not commit transaction, NACK received. Rolling back transaction.")
@@ -237,4 +244,9 @@ private class TransactionProcessor(val channel: Channel, val seqNum: String,
processAckOrNack()
null
}
+
+ private[sink] def countDownWhenBatchAcked(latch: CountDownLatch) {
+ testLatch = latch
+ isTest = true
+ }
}
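A note on the guard shared by the handler and the processor: the plain write to testLatch happens before the volatile write to isTest, so under the Java memory model any thread that observes isTest == true also sees the installed latch, and the production path pays only a single volatile read per batch. A self-contained sketch of that pattern, with names mirroring the diff (the class is a simplified stand-in, not Spark code):

import java.util.concurrent.CountDownLatch

// Simplified stand-in for the test hook added in this commit: install a
// latch once, then count it down per ACKed batch, guarded by a volatile
// flag so the non-test path stays cheap.
class BatchAckNotifier {
  @volatile private var isTest = false
  private var testLatch: CountDownLatch = null

  // Test-only installer; must be called before batches start flowing.
  // Writing testLatch before the volatile isTest makes the latch safely
  // visible to any reader that sees the flag set.
  def countDownWhenBatchAcked(latch: CountDownLatch): Unit = {
    testLatch = latch
    isTest = true
  }

  // Invoked after a batch is committed and its transaction closed.
  def batchAcked(): Unit = {
    if (isTest) {
      testLatch.countDown()
    }
  }
}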