-rw-r--r--  external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala | 245
-rw-r--r--  project/MimaExcludes.scala                                                      |   6
-rw-r--r--  python/pyspark/streaming/kafka.py                                               | 111
-rw-r--r--  python/pyspark/streaming/tests.py                                               |  35
4 files changed, 299 insertions(+), 98 deletions(-)
diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala
index 3128222077..ad2fb8aa5f 100644
--- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala
+++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala
@@ -17,25 +17,29 @@
package org.apache.spark.streaming.kafka
+import java.io.OutputStream
import java.lang.{Integer => JInt, Long => JLong}
import java.util.{List => JList, Map => JMap, Set => JSet}
import scala.collection.JavaConverters._
import scala.reflect.ClassTag
+import com.google.common.base.Charsets.UTF_8
import kafka.common.TopicAndPartition
import kafka.message.MessageAndMetadata
-import kafka.serializer.{Decoder, DefaultDecoder, StringDecoder}
+import kafka.serializer.{DefaultDecoder, Decoder, StringDecoder}
+import net.razorvine.pickle.{Opcodes, Pickler, IObjectPickler}
import org.apache.spark.api.java.function.{Function => JFunction}
-import org.apache.spark.api.java.{JavaPairRDD, JavaRDD, JavaSparkContext}
+import org.apache.spark.streaming.util.WriteAheadLogUtils
+import org.apache.spark.{SparkContext, SparkException}
+import org.apache.spark.api.java.{JavaSparkContext, JavaPairRDD, JavaRDD}
+import org.apache.spark.api.python.SerDeUtil
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel
import org.apache.spark.streaming.StreamingContext
-import org.apache.spark.streaming.api.java.{JavaInputDStream, JavaPairInputDStream, JavaPairReceiverInputDStream, JavaStreamingContext}
-import org.apache.spark.streaming.dstream.{InputDStream, ReceiverInputDStream}
-import org.apache.spark.streaming.util.WriteAheadLogUtils
-import org.apache.spark.{SparkContext, SparkException}
+import org.apache.spark.streaming.api.java._
+import org.apache.spark.streaming.dstream.{DStream, InputDStream, ReceiverInputDStream}
object KafkaUtils {
/**
@@ -184,6 +188,27 @@ object KafkaUtils {
}
}
+ private[kafka] def getFromOffsets(
+ kc: KafkaCluster,
+ kafkaParams: Map[String, String],
+ topics: Set[String]
+ ): Map[TopicAndPartition, Long] = {
+ val reset = kafkaParams.get("auto.offset.reset").map(_.toLowerCase)
+ val result = for {
+ topicPartitions <- kc.getPartitions(topics).right
+ leaderOffsets <- (if (reset == Some("smallest")) {
+ kc.getEarliestLeaderOffsets(topicPartitions)
+ } else {
+ kc.getLatestLeaderOffsets(topicPartitions)
+ }).right
+ } yield {
+ leaderOffsets.map { case (tp, lo) =>
+ (tp, lo.offset)
+ }
+ }
+ KafkaCluster.checkErrors(result)
+ }
+
/**
* Create a RDD from Kafka using offset ranges for each topic and partition.
*
@@ -246,7 +271,7 @@ object KafkaUtils {
// This could be avoided by refactoring KafkaRDD.leaders and KafkaCluster to use Broker
leaders.map {
case (tp: TopicAndPartition, Broker(host, port)) => (tp, (host, port))
- }.toMap
+ }
}
val cleanedHandler = sc.clean(messageHandler)
checkOffsets(kc, offsetRanges)
@@ -406,23 +431,9 @@ object KafkaUtils {
): InputDStream[(K, V)] = {
val messageHandler = (mmd: MessageAndMetadata[K, V]) => (mmd.key, mmd.message)
val kc = new KafkaCluster(kafkaParams)
- val reset = kafkaParams.get("auto.offset.reset").map(_.toLowerCase)
-
- val result = for {
- topicPartitions <- kc.getPartitions(topics).right
- leaderOffsets <- (if (reset == Some("smallest")) {
- kc.getEarliestLeaderOffsets(topicPartitions)
- } else {
- kc.getLatestLeaderOffsets(topicPartitions)
- }).right
- } yield {
- val fromOffsets = leaderOffsets.map { case (tp, lo) =>
- (tp, lo.offset)
- }
- new DirectKafkaInputDStream[K, V, KD, VD, (K, V)](
- ssc, kafkaParams, fromOffsets, messageHandler)
- }
- KafkaCluster.checkErrors(result)
+ val fromOffsets = getFromOffsets(kc, kafkaParams, topics)
+ new DirectKafkaInputDStream[K, V, KD, VD, (K, V)](
+ ssc, kafkaParams, fromOffsets, messageHandler)
}
/**
@@ -550,6 +561,8 @@ object KafkaUtils {
* takes care of known parameters instead of passing them from Python
*/
private[kafka] class KafkaUtilsPythonHelper {
+ import KafkaUtilsPythonHelper._
+
def createStream(
jssc: JavaStreamingContext,
kafkaParams: JMap[String, String],
@@ -566,86 +579,92 @@ private[kafka] class KafkaUtilsPythonHelper {
storageLevel)
}
- def createRDD(
+ def createRDDWithoutMessageHandler(
jsc: JavaSparkContext,
kafkaParams: JMap[String, String],
offsetRanges: JList[OffsetRange],
- leaders: JMap[TopicAndPartition, Broker]): JavaPairRDD[Array[Byte], Array[Byte]] = {
- val messageHandler = new JFunction[MessageAndMetadata[Array[Byte], Array[Byte]],
- (Array[Byte], Array[Byte])] {
- def call(t1: MessageAndMetadata[Array[Byte], Array[Byte]]): (Array[Byte], Array[Byte]) =
- (t1.key(), t1.message())
- }
+ leaders: JMap[TopicAndPartition, Broker]): JavaRDD[(Array[Byte], Array[Byte])] = {
+ val messageHandler =
+ (mmd: MessageAndMetadata[Array[Byte], Array[Byte]]) => (mmd.key, mmd.message)
+ new JavaRDD(createRDD(jsc, kafkaParams, offsetRanges, leaders, messageHandler))
+ }
- val jrdd = KafkaUtils.createRDD[
- Array[Byte],
- Array[Byte],
- DefaultDecoder,
- DefaultDecoder,
- (Array[Byte], Array[Byte])](
- jsc,
- classOf[Array[Byte]],
- classOf[Array[Byte]],
- classOf[DefaultDecoder],
- classOf[DefaultDecoder],
- classOf[(Array[Byte], Array[Byte])],
- kafkaParams,
- offsetRanges.toArray(new Array[OffsetRange](offsetRanges.size())),
- leaders,
- messageHandler
- )
- new JavaPairRDD(jrdd.rdd)
+ def createRDDWithMessageHandler(
+ jsc: JavaSparkContext,
+ kafkaParams: JMap[String, String],
+ offsetRanges: JList[OffsetRange],
+ leaders: JMap[TopicAndPartition, Broker]): JavaRDD[Array[Byte]] = {
+ val messageHandler = (mmd: MessageAndMetadata[Array[Byte], Array[Byte]]) =>
+ new PythonMessageAndMetadata(
+ mmd.topic, mmd.partition, mmd.offset, mmd.key(), mmd.message())
+ val rdd = createRDD(jsc, kafkaParams, offsetRanges, leaders, messageHandler).
+ mapPartitions(picklerIterator)
+ new JavaRDD(rdd)
}
- def createDirectStream(
+ private def createRDD[V: ClassTag](
+ jsc: JavaSparkContext,
+ kafkaParams: JMap[String, String],
+ offsetRanges: JList[OffsetRange],
+ leaders: JMap[TopicAndPartition, Broker],
+ messageHandler: MessageAndMetadata[Array[Byte], Array[Byte]] => V): RDD[V] = {
+ KafkaUtils.createRDD[Array[Byte], Array[Byte], DefaultDecoder, DefaultDecoder, V](
+ jsc.sc,
+ kafkaParams.asScala.toMap,
+ offsetRanges.toArray(new Array[OffsetRange](offsetRanges.size())),
+ leaders.asScala.toMap,
+ messageHandler
+ )
+ }
+
+ def createDirectStreamWithoutMessageHandler(
+ jssc: JavaStreamingContext,
+ kafkaParams: JMap[String, String],
+ topics: JSet[String],
+ fromOffsets: JMap[TopicAndPartition, JLong]): JavaDStream[(Array[Byte], Array[Byte])] = {
+ val messageHandler =
+ (mmd: MessageAndMetadata[Array[Byte], Array[Byte]]) => (mmd.key, mmd.message)
+ new JavaDStream(createDirectStream(jssc, kafkaParams, topics, fromOffsets, messageHandler))
+ }
+
+ def createDirectStreamWithMessageHandler(
jssc: JavaStreamingContext,
kafkaParams: JMap[String, String],
topics: JSet[String],
- fromOffsets: JMap[TopicAndPartition, JLong]
- ): JavaPairInputDStream[Array[Byte], Array[Byte]] = {
+ fromOffsets: JMap[TopicAndPartition, JLong]): JavaDStream[Array[Byte]] = {
+ val messageHandler = (mmd: MessageAndMetadata[Array[Byte], Array[Byte]]) =>
+ new PythonMessageAndMetadata(mmd.topic, mmd.partition, mmd.offset, mmd.key(), mmd.message())
+ val stream = createDirectStream(jssc, kafkaParams, topics, fromOffsets, messageHandler).
+ mapPartitions(picklerIterator)
+ new JavaDStream(stream)
+ }
- if (!fromOffsets.isEmpty) {
+ private def createDirectStream[V: ClassTag](
+ jssc: JavaStreamingContext,
+ kafkaParams: JMap[String, String],
+ topics: JSet[String],
+ fromOffsets: JMap[TopicAndPartition, JLong],
+ messageHandler: MessageAndMetadata[Array[Byte], Array[Byte]] => V): DStream[V] = {
+
+ val currentFromOffsets = if (!fromOffsets.isEmpty) {
val topicsFromOffsets = fromOffsets.keySet().asScala.map(_.topic)
if (topicsFromOffsets != topics.asScala.toSet) {
throw new IllegalStateException(
s"The specified topics: ${topics.asScala.toSet.mkString(" ")} " +
s"do not equal to the topic from offsets: ${topicsFromOffsets.mkString(" ")}")
}
- }
-
- if (fromOffsets.isEmpty) {
- KafkaUtils.createDirectStream[Array[Byte], Array[Byte], DefaultDecoder, DefaultDecoder](
- jssc,
- classOf[Array[Byte]],
- classOf[Array[Byte]],
- classOf[DefaultDecoder],
- classOf[DefaultDecoder],
- kafkaParams,
- topics)
+ Map(fromOffsets.asScala.mapValues { _.longValue() }.toSeq: _*)
} else {
- val messageHandler = new JFunction[MessageAndMetadata[Array[Byte], Array[Byte]],
- (Array[Byte], Array[Byte])] {
- def call(t1: MessageAndMetadata[Array[Byte], Array[Byte]]): (Array[Byte], Array[Byte]) =
- (t1.key(), t1.message())
- }
-
- val jstream = KafkaUtils.createDirectStream[
- Array[Byte],
- Array[Byte],
- DefaultDecoder,
- DefaultDecoder,
- (Array[Byte], Array[Byte])](
- jssc,
- classOf[Array[Byte]],
- classOf[Array[Byte]],
- classOf[DefaultDecoder],
- classOf[DefaultDecoder],
- classOf[(Array[Byte], Array[Byte])],
- kafkaParams,
- fromOffsets,
- messageHandler)
- new JavaPairInputDStream(jstream.inputDStream)
+ val kc = new KafkaCluster(Map(kafkaParams.asScala.toSeq: _*))
+ KafkaUtils.getFromOffsets(
+ kc, Map(kafkaParams.asScala.toSeq: _*), Set(topics.asScala.toSeq: _*))
}
+
+ KafkaUtils.createDirectStream[Array[Byte], Array[Byte], DefaultDecoder, DefaultDecoder, V](
+ jssc.ssc,
+ Map(kafkaParams.asScala.toSeq: _*),
+ Map(currentFromOffsets.toSeq: _*),
+ messageHandler)
}
def createOffsetRange(topic: String, partition: JInt, fromOffset: JLong, untilOffset: JLong
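Note the validation in createDirectStream above: when fromOffsets is supplied, its topics must exactly match the topics argument, or an IllegalStateException is raised. A hedged sketch of passing explicit starting offsets from Python, reusing the ssc and kafkaParams placeholders from the earlier sketch (long() as in the tests below, which skip on Python 3):

    from pyspark.streaming.kafka import KafkaUtils, TopicAndPartition

    # The topics of these keys must match the topics list exactly.
    fromOffsets = {TopicAndPartition("my-topic", 0): long(0)}
    stream = KafkaUtils.createDirectStream(ssc, ["my-topic"], kafkaParams,
                                           fromOffsets=fromOffsets)
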
@@ -669,3 +688,57 @@ private[kafka] class KafkaUtilsPythonHelper {
kafkaRDD.offsetRanges.toSeq.asJava
}
}
+
+private object KafkaUtilsPythonHelper {
+ private var initialized = false
+
+ def initialize(): Unit = {
+ SerDeUtil.initialize()
+ synchronized {
+ if (!initialized) {
+ new PythonMessageAndMetadataPickler().register()
+ initialized = true
+ }
+ }
+ }
+
+ initialize()
+
+ def picklerIterator(iter: Iterator[Any]): Iterator[Array[Byte]] = {
+ new SerDeUtil.AutoBatchedPickler(iter)
+ }
+
+ case class PythonMessageAndMetadata(
+ topic: String,
+ partition: JInt,
+ offset: JLong,
+ key: Array[Byte],
+ message: Array[Byte])
+
+ class PythonMessageAndMetadataPickler extends IObjectPickler {
+ private val module = "pyspark.streaming.kafka"
+
+ def register(): Unit = {
+ Pickler.registerCustomPickler(classOf[PythonMessageAndMetadata], this)
+ Pickler.registerCustomPickler(this.getClass, this)
+ }
+
+ def pickle(obj: Object, out: OutputStream, pickler: Pickler) {
+ if (obj == this) {
+ out.write(Opcodes.GLOBAL)
+ out.write(s"$module\nKafkaMessageAndMetadata\n".getBytes(UTF_8))
+ } else {
+ pickler.save(this)
+ val msgAndMetaData = obj.asInstanceOf[PythonMessageAndMetadata]
+ out.write(Opcodes.MARK)
+ pickler.save(msgAndMetaData.topic)
+ pickler.save(msgAndMetaData.partition)
+ pickler.save(msgAndMetaData.offset)
+ pickler.save(msgAndMetaData.key)
+ pickler.save(msgAndMetaData.message)
+ out.write(Opcodes.TUPLE)
+ out.write(Opcodes.REDUCE)
+ }
+ }
+ }
+}
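The pickler above writes a GLOBAL opcode naming pyspark.streaming.kafka.KafkaMessageAndMetadata, then the five fields under MARK...TUPLE, and finally REDUCE, so unpickling on the Python side is a plain constructor call. A self-contained stand-in showing the same round-trip contract (stdlib only; the real class lives in pyspark.streaming.kafka, see the next file):

    import pickle

    class KafkaMessageAndMetadata(object):
        # Minimal stand-in mirroring the pyspark class: REDUCE rebuilds
        # the object by calling the class with the five pickled fields.
        def __init__(self, topic, partition, offset, key, message):
            self.topic, self.partition, self.offset = topic, partition, offset
            self._rawKey, self._rawMessage = key, message

        def __reduce__(self):
            return (KafkaMessageAndMetadata,
                    (self.topic, self.partition, self.offset,
                     self._rawKey, self._rawMessage))

    m = KafkaMessageAndMetadata("my-topic", 0, 42, b"k", b"v")
    m2 = pickle.loads(pickle.dumps(m))
    assert (m2.topic, m2.partition, m2.offset) == ("my-topic", 0, 42)
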
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index 8b3bc96801..eb70d27c34 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -136,6 +136,12 @@ object MimaExcludes {
// SPARK-11766 add toJson to Vector
ProblemFilters.exclude[MissingMethodProblem](
"org.apache.spark.mllib.linalg.Vector.toJson")
+ ) ++ Seq(
+ // SPARK-9065 Support message handler in Kafka Python API
+ ProblemFilters.exclude[MissingMethodProblem](
+ "org.apache.spark.streaming.kafka.KafkaUtilsPythonHelper.createDirectStream"),
+ ProblemFilters.exclude[MissingMethodProblem](
+ "org.apache.spark.streaming.kafka.KafkaUtilsPythonHelper.createRDD")
)
case v if v.startsWith("1.5") =>
Seq(
diff --git a/python/pyspark/streaming/kafka.py b/python/pyspark/streaming/kafka.py
index 06e159172a..cdf97ec73a 100644
--- a/python/pyspark/streaming/kafka.py
+++ b/python/pyspark/streaming/kafka.py
@@ -19,12 +19,14 @@ from py4j.protocol import Py4JJavaError
from pyspark.rdd import RDD
from pyspark.storagelevel import StorageLevel
-from pyspark.serializers import PairDeserializer, NoOpSerializer
+from pyspark.serializers import AutoBatchedSerializer, PickleSerializer, PairDeserializer, \
+ NoOpSerializer
from pyspark.streaming import DStream
from pyspark.streaming.dstream import TransformedDStream
from pyspark.streaming.util import TransformFunction
-__all__ = ['Broker', 'KafkaUtils', 'OffsetRange', 'TopicAndPartition', 'utf8_decoder']
+__all__ = ['Broker', 'KafkaMessageAndMetadata', 'KafkaUtils', 'OffsetRange',
+ 'TopicAndPartition', 'utf8_decoder']
def utf8_decoder(s):
@@ -82,7 +84,8 @@ class KafkaUtils(object):
@staticmethod
def createDirectStream(ssc, topics, kafkaParams, fromOffsets=None,
- keyDecoder=utf8_decoder, valueDecoder=utf8_decoder):
+ keyDecoder=utf8_decoder, valueDecoder=utf8_decoder,
+ messageHandler=None):
"""
.. note:: Experimental
@@ -107,6 +110,8 @@ class KafkaUtils(object):
point of the stream.
:param keyDecoder: A function used to decode key (default is utf8_decoder).
:param valueDecoder: A function used to decode value (default is utf8_decoder).
+ :param messageHandler: A function used to convert KafkaMessageAndMetadata. You can access
+ the message metadata through this handler (default is None).
:return: A DStream object
"""
if fromOffsets is None:
@@ -116,6 +121,14 @@ class KafkaUtils(object):
if not isinstance(kafkaParams, dict):
raise TypeError("kafkaParams should be dict")
+ def funcWithoutMessageHandler(k_v):
+ return (keyDecoder(k_v[0]), valueDecoder(k_v[1]))
+
+ def funcWithMessageHandler(m):
+ m._set_key_decoder(keyDecoder)
+ m._set_value_decoder(valueDecoder)
+ return messageHandler(m)
+
try:
helperClass = ssc._jvm.java.lang.Thread.currentThread().getContextClassLoader() \
.loadClass("org.apache.spark.streaming.kafka.KafkaUtilsPythonHelper")
@@ -123,20 +136,28 @@ class KafkaUtils(object):
jfromOffsets = dict([(k._jTopicAndPartition(helper),
v) for (k, v) in fromOffsets.items()])
- jstream = helper.createDirectStream(ssc._jssc, kafkaParams, set(topics), jfromOffsets)
+ if messageHandler is None:
+ ser = PairDeserializer(NoOpSerializer(), NoOpSerializer())
+ func = funcWithoutMessageHandler
+ jstream = helper.createDirectStreamWithoutMessageHandler(
+ ssc._jssc, kafkaParams, set(topics), jfromOffsets)
+ else:
+ ser = AutoBatchedSerializer(PickleSerializer())
+ func = funcWithMessageHandler
+ jstream = helper.createDirectStreamWithMessageHandler(
+ ssc._jssc, kafkaParams, set(topics), jfromOffsets)
except Py4JJavaError as e:
if 'ClassNotFoundException' in str(e.java_exception):
KafkaUtils._printErrorMsg(ssc.sparkContext)
raise e
- ser = PairDeserializer(NoOpSerializer(), NoOpSerializer())
- stream = DStream(jstream, ssc, ser) \
- .map(lambda k_v: (keyDecoder(k_v[0]), valueDecoder(k_v[1])))
+ stream = DStream(jstream, ssc, ser).map(func)
return KafkaDStream(stream._jdstream, ssc, stream._jrdd_deserializer)
@staticmethod
def createRDD(sc, kafkaParams, offsetRanges, leaders=None,
- keyDecoder=utf8_decoder, valueDecoder=utf8_decoder):
+ keyDecoder=utf8_decoder, valueDecoder=utf8_decoder,
+ messageHandler=None):
"""
.. note:: Experimental
@@ -149,6 +170,8 @@ class KafkaUtils(object):
map, in which case leaders will be looked up on the driver.
:param keyDecoder: A function used to decode key (default is utf8_decoder)
:param valueDecoder: A function used to decode value (default is utf8_decoder)
+ :param messageHandler: A function used to convert KafkaMessageAndMetadata. You can access
+ the message metadata through this handler (default is None).
:return: A RDD object
"""
if leaders is None:
@@ -158,6 +181,14 @@ class KafkaUtils(object):
if not isinstance(offsetRanges, list):
raise TypeError("offsetRanges should be list")
+ def funcWithoutMessageHandler(k_v):
+ return (keyDecoder(k_v[0]), valueDecoder(k_v[1]))
+
+ def funcWithMessageHandler(m):
+ m._set_key_decoder(keyDecoder)
+ m._set_value_decoder(valueDecoder)
+ return messageHandler(m)
+
try:
helperClass = sc._jvm.java.lang.Thread.currentThread().getContextClassLoader() \
.loadClass("org.apache.spark.streaming.kafka.KafkaUtilsPythonHelper")
@@ -165,15 +196,21 @@ class KafkaUtils(object):
joffsetRanges = [o._jOffsetRange(helper) for o in offsetRanges]
jleaders = dict([(k._jTopicAndPartition(helper),
v._jBroker(helper)) for (k, v) in leaders.items()])
- jrdd = helper.createRDD(sc._jsc, kafkaParams, joffsetRanges, jleaders)
+ if messageHandler is None:
+ jrdd = helper.createRDDWithoutMessageHandler(
+ sc._jsc, kafkaParams, joffsetRanges, jleaders)
+ ser = PairDeserializer(NoOpSerializer(), NoOpSerializer())
+ rdd = RDD(jrdd, sc, ser).map(funcWithoutMessageHandler)
+ else:
+ jrdd = helper.createRDDWithMessageHandler(
+ sc._jsc, kafkaParams, joffsetRanges, jleaders)
+ rdd = RDD(jrdd, sc).map(funcWithMessageHandler)
except Py4JJavaError as e:
if 'ClassNotFoundException' in str(e.java_exception):
KafkaUtils._printErrorMsg(sc)
raise e
- ser = PairDeserializer(NoOpSerializer(), NoOpSerializer())
- rdd = RDD(jrdd, sc, ser).map(lambda k_v: (keyDecoder(k_v[0]), valueDecoder(k_v[1])))
- return KafkaRDD(rdd._jrdd, rdd.ctx, rdd._jrdd_deserializer)
+ return KafkaRDD(rdd._jrdd, sc, rdd._jrdd_deserializer)
@staticmethod
def _printErrorMsg(sc):
@@ -365,3 +402,53 @@ class KafkaTransformedDStream(TransformedDStream):
dstream = self._sc._jvm.PythonTransformedDStream(self.prev._jdstream.dstream(), jfunc)
self._jdstream_val = dstream.asJavaDStream()
return self._jdstream_val
+
+
+class KafkaMessageAndMetadata(object):
+ """
+ Kafka message and metadata information, including topic, partition, offset and message
+ """
+
+ def __init__(self, topic, partition, offset, key, message):
+ """
+ Python wrapper of Kafka MessageAndMetadata
+ :param topic: topic name of this Kafka message
+ :param partition: partition id of this Kafka message
+ :param offset: Offset of this Kafka message in the specific partition
+ :param key: key payload of this Kafka message; can be null if this Kafka message has no
+ key specified. The returned data is an undecoded bytearray.
+ :param message: actual message payload of this Kafka message. The returned data is an
+ undecoded bytearray.
+ """
+ self.topic = topic
+ self.partition = partition
+ self.offset = offset
+ self._rawKey = key
+ self._rawMessage = message
+ self._keyDecoder = utf8_decoder
+ self._valueDecoder = utf8_decoder
+
+ def __str__(self):
+ return "KafkaMessageAndMetadata(topic: %s, partition: %d, offset: %d, key and message...)" \
+ % (self.topic, self.partition, self.offset)
+
+ def __repr__(self):
+ return self.__str__()
+
+ def __reduce__(self):
+ return (KafkaMessageAndMetadata,
+ (self.topic, self.partition, self.offset, self._rawKey, self._rawMessage))
+
+ def _set_key_decoder(self, decoder):
+ self._keyDecoder = decoder
+
+ def _set_value_decoder(self, decoder):
+ self._valueDecoder = decoder
+
+ @property
+ def key(self):
+ return self._keyDecoder(self._rawKey)
+
+ @property
+ def message(self):
+ return self._valueDecoder(self._rawMessage)
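Taken together, the messageHandler hook lets Python code see the full metadata instead of only (key, value) pairs. A minimal usage sketch, reusing the sc/ssc/kafkaParams placeholders from the earlier sketches:

    from pyspark.streaming.kafka import KafkaUtils, OffsetRange

    def tag_with_offset(m):
        # m is a KafkaMessageAndMetadata; key and message are decoded
        # lazily through the configured decoders (utf8_decoder by default).
        return (m.offset, m.message)

    # Batch path: read a fixed range, keeping each record's offset.
    ranges = [OffsetRange("my-topic", 0, long(0), long(100))]
    rdd = KafkaUtils.createRDD(sc, kafkaParams, ranges,
                               messageHandler=tag_with_offset)

    # Streaming path: same handler signature.
    stream = KafkaUtils.createDirectStream(ssc, ["my-topic"], kafkaParams,
                                           messageHandler=tag_with_offset)
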
diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py
index ff95639146..0bcd1f1553 100644
--- a/python/pyspark/streaming/tests.py
+++ b/python/pyspark/streaming/tests.py
@@ -1042,6 +1042,41 @@ class KafkaStreamTests(PySparkStreamingTestCase):
self.assertNotEqual(topic_and_partition_a, topic_and_partition_c)
self.assertNotEqual(topic_and_partition_a, topic_and_partition_d)
+ @unittest.skipIf(sys.version >= "3", "long type not supported")
+ def test_kafka_rdd_message_handler(self):
+ """Test Python direct Kafka RDD MessageHandler."""
+ topic = self._randomTopic()
+ sendData = {"a": 1, "b": 1, "c": 2}
+ offsetRanges = [OffsetRange(topic, 0, long(0), long(sum(sendData.values())))]
+ kafkaParams = {"metadata.broker.list": self._kafkaTestUtils.brokerAddress()}
+
+ def getKeyAndDoubleMessage(m):
+ return m and (m.key, m.message * 2)
+
+ self._kafkaTestUtils.createTopic(topic)
+ self._kafkaTestUtils.sendMessages(topic, sendData)
+ rdd = KafkaUtils.createRDD(self.sc, kafkaParams, offsetRanges,
+ messageHandler=getKeyAndDoubleMessage)
+ self._validateRddResult({"aa": 1, "bb": 1, "cc": 2}, rdd)
+
+ @unittest.skipIf(sys.version >= "3", "long type not supported")
+ def test_kafka_direct_stream_message_handler(self):
+ """Test the Python direct Kafka stream MessageHandler."""
+ topic = self._randomTopic()
+ sendData = {"a": 1, "b": 2, "c": 3}
+ kafkaParams = {"metadata.broker.list": self._kafkaTestUtils.brokerAddress(),
+ "auto.offset.reset": "smallest"}
+
+ self._kafkaTestUtils.createTopic(topic)
+ self._kafkaTestUtils.sendMessages(topic, sendData)
+
+ def getKeyAndDoubleMessage(m):
+ return m and (m.key, m.message * 2)
+
+ stream = KafkaUtils.createDirectStream(self.ssc, [topic], kafkaParams,
+ messageHandler=getKeyAndDoubleMessage)
+ self._validateStreamResult({"aa": 1, "bb": 2, "cc": 3}, stream)
+
class FlumeStreamTests(PySparkStreamingTestCase):
timeout = 20 # seconds
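
The two message-handler tests above double each payload; a handler can equally use the metadata to filter. A hedged variant in the same style, assuming the same test setup (topic, kafkaParams, ssc):

    def skip_keyless(m):
        # utf8_decoder returns None for a null raw key, so m.key is None
        # for keyless records; drop those after the handler runs.
        return None if m.key is None else (m.key, m.message)

    stream = KafkaUtils.createDirectStream(ssc, [topic], kafkaParams,
                                           messageHandler=skip_keyless) \
        .filter(lambda r: r is not None)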