author    jerryshao <sshao@hortonworks.com>    2015-11-17 16:57:52 -0800
committer Tathagata Das <tathagata.das1565@gmail.com>    2015-11-17 16:57:52 -0800
commit  75a292291062783129d02607302f91c85655975e (patch)
tree    7ba09e47b8aa3b3810bc3d115817a862a729aac0 /python/pyspark/streaming
parent  b362d50fca30693f97bd859984157bb8a76d48a1 (diff)
[SPARK-9065][STREAMING][PYSPARK] Add MessageHandler for Kafka Python API
Fixed the merge conflicts in #7410.

Closes #7410

Author: Shixiong Zhu <shixiong@databricks.com>
Author: jerryshao <saisai.shao@intel.com>
Author: jerryshao <sshao@hortonworks.com>

Closes #9742 from zsxwing/pr7410.
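As a quick illustration of the new API, here is a minimal sketch of a direct stream that keeps per-record metadata via messageHandler. The broker address and topic name are placeholders, and it assumes the spark-streaming-kafka assembly is available on the driver classpath:

    from pyspark import SparkContext
    from pyspark.streaming import StreamingContext
    from pyspark.streaming.kafka import KafkaUtils

    sc = SparkContext(appName="KafkaMessageHandlerSketch")
    ssc = StreamingContext(sc, 2)

    # Keep topic/partition/offset alongside the decoded key and message.
    def handler(m):
        return (m.topic, m.partition, m.offset, m.key, m.message)

    stream = KafkaUtils.createDirectStream(
        ssc, ["my-topic"],  # placeholder topic name
        {"metadata.broker.list": "localhost:9092"},  # placeholder broker
        messageHandler=handler)
    stream.pprint()

    ssc.start()
    ssc.awaitTermination()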
Diffstat (limited to 'python/pyspark/streaming')
-rw-r--r--  python/pyspark/streaming/kafka.py   111
-rw-r--r--  python/pyspark/streaming/tests.py    35
2 files changed, 134 insertions(+), 12 deletions(-)
diff --git a/python/pyspark/streaming/kafka.py b/python/pyspark/streaming/kafka.py
index 06e159172a..cdf97ec73a 100644
--- a/python/pyspark/streaming/kafka.py
+++ b/python/pyspark/streaming/kafka.py
@@ -19,12 +19,14 @@ from py4j.protocol import Py4JJavaError
from pyspark.rdd import RDD
from pyspark.storagelevel import StorageLevel
-from pyspark.serializers import PairDeserializer, NoOpSerializer
+from pyspark.serializers import AutoBatchedSerializer, PickleSerializer, PairDeserializer, \
+ NoOpSerializer
from pyspark.streaming import DStream
from pyspark.streaming.dstream import TransformedDStream
from pyspark.streaming.util import TransformFunction
-__all__ = ['Broker', 'KafkaUtils', 'OffsetRange', 'TopicAndPartition', 'utf8_decoder']
+__all__ = ['Broker', 'KafkaMessageAndMetadata', 'KafkaUtils', 'OffsetRange',
+ 'TopicAndPartition', 'utf8_decoder']
def utf8_decoder(s):
@@ -82,7 +84,8 @@ class KafkaUtils(object):
@staticmethod
def createDirectStream(ssc, topics, kafkaParams, fromOffsets=None,
- keyDecoder=utf8_decoder, valueDecoder=utf8_decoder):
+ keyDecoder=utf8_decoder, valueDecoder=utf8_decoder,
+ messageHandler=None):
"""
.. note:: Experimental
@@ -107,6 +110,8 @@ class KafkaUtils(object):
point of the stream.
:param keyDecoder: A function used to decode key (default is utf8_decoder).
:param valueDecoder: A function used to decode value (default is utf8_decoder).
+ :param messageHandler: A function used to convert KafkaMessageAndMetadata. You can access
+ the message metadata through messageHandler (default is None).
:return: A DStream object
"""
if fromOffsets is None:
@@ -116,6 +121,14 @@ class KafkaUtils(object):
if not isinstance(kafkaParams, dict):
raise TypeError("kafkaParams should be dict")
+ def funcWithoutMessageHandler(k_v):
+ return (keyDecoder(k_v[0]), valueDecoder(k_v[1]))
+
+ def funcWithMessageHandler(m):
+ m._set_key_decoder(keyDecoder)
+ m._set_value_decoder(valueDecoder)
+ return messageHandler(m)
+
try:
helperClass = ssc._jvm.java.lang.Thread.currentThread().getContextClassLoader() \
.loadClass("org.apache.spark.streaming.kafka.KafkaUtilsPythonHelper")
@@ -123,20 +136,28 @@ class KafkaUtils(object):
jfromOffsets = dict([(k._jTopicAndPartition(helper),
v) for (k, v) in fromOffsets.items()])
- jstream = helper.createDirectStream(ssc._jssc, kafkaParams, set(topics), jfromOffsets)
+ if messageHandler is None:
+ ser = PairDeserializer(NoOpSerializer(), NoOpSerializer())
+ func = funcWithoutMessageHandler
+ jstream = helper.createDirectStreamWithoutMessageHandler(
+ ssc._jssc, kafkaParams, set(topics), jfromOffsets)
+ else:
+ ser = AutoBatchedSerializer(PickleSerializer())
+ func = funcWithMessageHandler
+ jstream = helper.createDirectStreamWithMessageHandler(
+ ssc._jssc, kafkaParams, set(topics), jfromOffsets)
except Py4JJavaError as e:
if 'ClassNotFoundException' in str(e.java_exception):
KafkaUtils._printErrorMsg(ssc.sparkContext)
raise e
- ser = PairDeserializer(NoOpSerializer(), NoOpSerializer())
- stream = DStream(jstream, ssc, ser) \
- .map(lambda k_v: (keyDecoder(k_v[0]), valueDecoder(k_v[1])))
+ stream = DStream(jstream, ssc, ser).map(func)
return KafkaDStream(stream._jdstream, ssc, stream._jrdd_deserializer)
@staticmethod
def createRDD(sc, kafkaParams, offsetRanges, leaders=None,
- keyDecoder=utf8_decoder, valueDecoder=utf8_decoder):
+ keyDecoder=utf8_decoder, valueDecoder=utf8_decoder,
+ messageHandler=None):
"""
.. note:: Experimental
@@ -149,6 +170,8 @@ class KafkaUtils(object):
map, in which case leaders will be looked up on the driver.
:param keyDecoder: A function used to decode key (default is utf8_decoder)
:param valueDecoder: A function used to decode value (default is utf8_decoder)
+ :param messageHandler: A function used to convert KafkaMessageAndMetadata. You can access
+ the message metadata through messageHandler (default is None).
:return: An RDD object
"""
if leaders is None:
@@ -158,6 +181,14 @@ class KafkaUtils(object):
if not isinstance(offsetRanges, list):
raise TypeError("offsetRanges should be list")
+ def funcWithoutMessageHandler(k_v):
+ return (keyDecoder(k_v[0]), valueDecoder(k_v[1]))
+
+ def funcWithMessageHandler(m):
+ m._set_key_decoder(keyDecoder)
+ m._set_value_decoder(valueDecoder)
+ return messageHandler(m)
+
try:
helperClass = sc._jvm.java.lang.Thread.currentThread().getContextClassLoader() \
.loadClass("org.apache.spark.streaming.kafka.KafkaUtilsPythonHelper")
@@ -165,15 +196,21 @@ class KafkaUtils(object):
joffsetRanges = [o._jOffsetRange(helper) for o in offsetRanges]
jleaders = dict([(k._jTopicAndPartition(helper),
v._jBroker(helper)) for (k, v) in leaders.items()])
- jrdd = helper.createRDD(sc._jsc, kafkaParams, joffsetRanges, jleaders)
+ if messageHandler is None:
+ jrdd = helper.createRDDWithoutMessageHandler(
+ sc._jsc, kafkaParams, joffsetRanges, jleaders)
+ ser = PairDeserializer(NoOpSerializer(), NoOpSerializer())
+ rdd = RDD(jrdd, sc, ser).map(funcWithoutMessageHandler)
+ else:
+ jrdd = helper.createRDDWithMessageHandler(
+ sc._jsc, kafkaParams, joffsetRanges, jleaders)
+ rdd = RDD(jrdd, sc).map(funcWithMessageHandler)
except Py4JJavaError as e:
if 'ClassNotFoundException' in str(e.java_exception):
KafkaUtils._printErrorMsg(sc)
raise e
- ser = PairDeserializer(NoOpSerializer(), NoOpSerializer())
- rdd = RDD(jrdd, sc, ser).map(lambda k_v: (keyDecoder(k_v[0]), valueDecoder(k_v[1])))
- return KafkaRDD(rdd._jrdd, rdd.ctx, rdd._jrdd_deserializer)
+ return KafkaRDD(rdd._jrdd, sc, rdd._jrdd_deserializer)
@staticmethod
def _printErrorMsg(sc):
@@ -365,3 +402,53 @@ class KafkaTransformedDStream(TransformedDStream):
dstream = self._sc._jvm.PythonTransformedDStream(self.prev._jdstream.dstream(), jfunc)
self._jdstream_val = dstream.asJavaDStream()
return self._jdstream_val
+
+
+class KafkaMessageAndMetadata(object):
+ """
+ Kafka message and metadata information, including topic, partition, offset, key, and message.
+ """
+
+ def __init__(self, topic, partition, offset, key, message):
+ """
+ Python wrapper of Kafka MessageAndMetadata
+ :param topic: topic name of this Kafka message
+ :param partition: partition id of this Kafka message
+ :param offset: offset of this Kafka message in the specific partition
+ :param key: key payload of this Kafka message; can be None if the message has no key.
+ The raw payload is an undecoded bytearray.
+ :param message: actual message payload of this Kafka message. The raw payload is an
+ undecoded bytearray.
+ """
+ self.topic = topic
+ self.partition = partition
+ self.offset = offset
+ self._rawKey = key
+ self._rawMessage = message
+ self._keyDecoder = utf8_decoder
+ self._valueDecoder = utf8_decoder
+
+ def __str__(self):
+ return "KafkaMessageAndMetadata(topic: %s, partition: %d, offset: %d, key and message...)" \
+ % (self.topic, self.partition, self.offset)
+
+ def __repr__(self):
+ return self.__str__()
+
+ def __reduce__(self):
+ return (KafkaMessageAndMetadata,
+ (self.topic, self.partition, self.offset, self._rawKey, self._rawMessage))
+
+ def _set_key_decoder(self, decoder):
+ self._keyDecoder = decoder
+
+ def _set_value_decoder(self, decoder):
+ self._valueDecoder = decoder
+
+ @property
+ def key(self):
+ return self._keyDecoder(self._rawKey)
+
+ @property
+ def message(self):
+ return self._valueDecoder(self._rawMessage)
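A note on the __reduce__ method above: records are pickled on the JVM side and unpickled in Python, so only the raw fields survive the transfer, and the decoders default back to utf8_decoder (they are re-attached via _set_key_decoder/_set_value_decoder before the user's messageHandler runs). A small standalone sketch of this round trip; no Kafka is needed, and constructing the object directly is purely illustrative:

    import pickle
    from pyspark.streaming.kafka import KafkaMessageAndMetadata

    # In practice the JVM side pickles these records and PySpark unpickles them.
    m = KafkaMessageAndMetadata("t", 0, 42, bytearray(b"k"), bytearray(b"v"))
    m2 = pickle.loads(pickle.dumps(m))  # goes through __reduce__

    # Only topic/partition/offset and the raw payloads are pickled; after
    # unpickling, the key/message properties decode with utf8_decoder.
    assert (m2.topic, m2.partition, m2.offset) == ("t", 0, 42)
    assert m2.key == u"k" and m2.message == u"v"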
diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py
index ff95639146..0bcd1f1553 100644
--- a/python/pyspark/streaming/tests.py
+++ b/python/pyspark/streaming/tests.py
@@ -1042,6 +1042,41 @@ class KafkaStreamTests(PySparkStreamingTestCase):
self.assertNotEqual(topic_and_partition_a, topic_and_partition_c)
self.assertNotEqual(topic_and_partition_a, topic_and_partition_d)
+ @unittest.skipIf(sys.version >= "3", "long type not supported")
+ def test_kafka_rdd_message_handler(self):
+ """Test Python direct Kafka RDD MessageHandler."""
+ topic = self._randomTopic()
+ sendData = {"a": 1, "b": 1, "c": 2}
+ offsetRanges = [OffsetRange(topic, 0, long(0), long(sum(sendData.values())))]
+ kafkaParams = {"metadata.broker.list": self._kafkaTestUtils.brokerAddress()}
+
+ def getKeyAndDoubleMessage(m):
+ return m and (m.key, m.message * 2)
+
+ self._kafkaTestUtils.createTopic(topic)
+ self._kafkaTestUtils.sendMessages(topic, sendData)
+ rdd = KafkaUtils.createRDD(self.sc, kafkaParams, offsetRanges,
+ messageHandler=getKeyAndDoubleMessage)
+ self._validateRddResult({"aa": 1, "bb": 1, "cc": 2}, rdd)
+
+ @unittest.skipIf(sys.version >= "3", "long type not supported")
+ def test_kafka_direct_stream_message_handler(self):
+ """Test the Python direct Kafka stream MessageHandler."""
+ topic = self._randomTopic()
+ sendData = {"a": 1, "b": 2, "c": 3}
+ kafkaParams = {"metadata.broker.list": self._kafkaTestUtils.brokerAddress(),
+ "auto.offset.reset": "smallest"}
+
+ self._kafkaTestUtils.createTopic(topic)
+ self._kafkaTestUtils.sendMessages(topic, sendData)
+
+ def getKeyAndDoubleMessage(m):
+ return m and (m.key, m.message * 2)
+
+ stream = KafkaUtils.createDirectStream(self.ssc, [topic], kafkaParams,
+ messageHandler=getKeyAndDoubleMessage)
+ self._validateStreamResult({"aa": 1, "bb": 2, "cc": 3}, stream)
+
class FlumeStreamTests(PySparkStreamingTestCase):
timeout = 20 # seconds