author     jerryshao <saisai.shao@intel.com>           2015-07-09 13:54:44 -0700
committer  Tathagata Das <tathagata.das1565@gmail.com> 2015-07-09 13:54:44 -0700
commit     3ccebf36c5abe04702d4cf223552a94034d980fb
tree       40990a75e75399422a2d34926505d5521db8edbe /python/pyspark
parent     1f6b0b1234cc03aa2e07aea7fec2de7563885238
[SPARK-8389] [STREAMING] [PYSPARK] Expose KafkaRDDs offsetRange in Python
This PR proposes a simple way to expose OffsetRange in Python code. The usage of offsetRanges mirrors the Scala/Java API; in Python we can get an OffsetRange like:

```
dstream.foreachRDD(lambda r: KafkaUtils.offsetRanges(r))
```

The reason I did not follow the approach suggested in SPARK-8389 is that the Python Kafka API has one more step to decode the message compared to Scala/Java. This makes the Python API return a transformed RDD/DStream rather than a directly wrapped JavaKafkaRDD, so it is hard to backtrack to the original RDD to get the offsetRange.

Author: jerryshao <saisai.shao@intel.com>

Closes #7185 from jerryshao/SPARK-8389 and squashes the following commits:

4c6d320 [jerryshao] Another way to fix subclass deserialization issue
e6a8011 [jerryshao] Address the comments
fd13937 [jerryshao] Fix serialization bug
7debf1c [jerryshao] bug fix
cff3893 [jerryshao] refactor the code according to the comments
2aabf9e [jerryshao] Style fix
848c708 [jerryshao] Add HasOffsetRanges for Python
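For reference, a minimal usage sketch derived from the tests added in this patch. The broker address, topic name, and the pre-existing StreamingContext `ssc` are placeholders, not part of the patch:

```
from pyspark.streaming.kafka import KafkaUtils

# Hypothetical broker address and topic; `ssc` is assumed to be an existing StreamingContext.
kafkaParams = {"metadata.broker.list": "localhost:9092",
               "auto.offset.reset": "smallest"}
stream = KafkaUtils.createDirectStream(ssc, ["my-topic"], kafkaParams)

offsetRanges = []

def collectOffsetRanges(time, rdd):
    # With this patch the RDD handed to foreachRDD is a KafkaRDD, so offsetRanges() is available.
    for o in rdd.offsetRanges():
        offsetRanges.append(o)

stream.foreachRDD(collectOffsetRanges)
```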
Diffstat (limited to 'python/pyspark')
-rw-r--r--  python/pyspark/streaming/kafka.py  123
-rw-r--r--  python/pyspark/streaming/tests.py   64
-rw-r--r--  python/pyspark/streaming/util.py     7
3 files changed, 183 insertions, 11 deletions
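The same offset information can also be read inside `transform`, as exercised by the new tests in this diff; this sketch reuses the placeholder `stream` and `ssc` from above:

```
def printOffsetRanges(rdd):
    # The first transform on the direct stream still sees a KafkaRDD.
    for o in rdd.offsetRanges():
        print("partition %d: [%d -> %d)" % (o.partition, o.fromOffset, o.untilOffset))
    return rdd

stream.transform(printOffsetRanges).foreachRDD(lambda rdd: rdd.count())
ssc.start()
```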
diff --git a/python/pyspark/streaming/kafka.py b/python/pyspark/streaming/kafka.py
index 10a859a532..33dd596335 100644
--- a/python/pyspark/streaming/kafka.py
+++ b/python/pyspark/streaming/kafka.py
@@ -21,6 +21,8 @@ from pyspark.rdd import RDD
from pyspark.storagelevel import StorageLevel
from pyspark.serializers import PairDeserializer, NoOpSerializer
from pyspark.streaming import DStream
+from pyspark.streaming.dstream import TransformedDStream
+from pyspark.streaming.util import TransformFunction
__all__ = ['Broker', 'KafkaUtils', 'OffsetRange', 'TopicAndPartition', 'utf8_decoder']
@@ -122,8 +124,9 @@ class KafkaUtils(object):
raise e
ser = PairDeserializer(NoOpSerializer(), NoOpSerializer())
- stream = DStream(jstream, ssc, ser)
- return stream.map(lambda k_v: (keyDecoder(k_v[0]), valueDecoder(k_v[1])))
+ stream = DStream(jstream, ssc, ser) \
+ .map(lambda k_v: (keyDecoder(k_v[0]), valueDecoder(k_v[1])))
+ return KafkaDStream(stream._jdstream, ssc, stream._jrdd_deserializer)
@staticmethod
def createRDD(sc, kafkaParams, offsetRanges, leaders={},
@@ -161,8 +164,8 @@ class KafkaUtils(object):
raise e
ser = PairDeserializer(NoOpSerializer(), NoOpSerializer())
- rdd = RDD(jrdd, sc, ser)
- return rdd.map(lambda k_v: (keyDecoder(k_v[0]), valueDecoder(k_v[1])))
+ rdd = RDD(jrdd, sc, ser).map(lambda k_v: (keyDecoder(k_v[0]), valueDecoder(k_v[1])))
+ return KafkaRDD(rdd._jrdd, rdd.ctx, rdd._jrdd_deserializer)
@staticmethod
def _printErrorMsg(sc):
@@ -200,14 +203,30 @@ class OffsetRange(object):
:param fromOffset: Inclusive starting offset.
:param untilOffset: Exclusive ending offset.
"""
- self._topic = topic
- self._partition = partition
- self._fromOffset = fromOffset
- self._untilOffset = untilOffset
+ self.topic = topic
+ self.partition = partition
+ self.fromOffset = fromOffset
+ self.untilOffset = untilOffset
+
+ def __eq__(self, other):
+ if isinstance(other, self.__class__):
+ return (self.topic == other.topic
+ and self.partition == other.partition
+ and self.fromOffset == other.fromOffset
+ and self.untilOffset == other.untilOffset)
+ else:
+ return False
+
+ def __ne__(self, other):
+ return not self.__eq__(other)
+
+ def __str__(self):
+ return "OffsetRange(topic: %s, partition: %d, range: [%d -> %d]" \
+ % (self.topic, self.partition, self.fromOffset, self.untilOffset)
def _jOffsetRange(self, helper):
- return helper.createOffsetRange(self._topic, self._partition, self._fromOffset,
- self._untilOffset)
+ return helper.createOffsetRange(self.topic, self.partition, self.fromOffset,
+ self.untilOffset)
class TopicAndPartition(object):
@@ -244,3 +263,87 @@ class Broker(object):
def _jBroker(self, helper):
return helper.createBroker(self._host, self._port)
+
+
+class KafkaRDD(RDD):
+ """
+ A Python wrapper of KafkaRDD, providing additional Kafka-specific information on top of a normal RDD.
+ """
+
+ def __init__(self, jrdd, ctx, jrdd_deserializer):
+ RDD.__init__(self, jrdd, ctx, jrdd_deserializer)
+
+ def offsetRanges(self):
+ """
+ Get the OffsetRanges of this KafkaRDD.
+ :return: A list of OffsetRange
+ """
+ try:
+ helperClass = self.ctx._jvm.java.lang.Thread.currentThread().getContextClassLoader() \
+ .loadClass("org.apache.spark.streaming.kafka.KafkaUtilsPythonHelper")
+ helper = helperClass.newInstance()
+ joffsetRanges = helper.offsetRangesOfKafkaRDD(self._jrdd.rdd())
+ except Py4JJavaError as e:
+ if 'ClassNotFoundException' in str(e.java_exception):
+ KafkaUtils._printErrorMsg(self.ctx)
+ raise e
+
+ ranges = [OffsetRange(o.topic(), o.partition(), o.fromOffset(), o.untilOffset())
+ for o in joffsetRanges]
+ return ranges
+
+
+class KafkaDStream(DStream):
+ """
+ A Python wrapper of KafkaDStream
+ """
+
+ def __init__(self, jdstream, ssc, jrdd_deserializer):
+ DStream.__init__(self, jdstream, ssc, jrdd_deserializer)
+
+ def foreachRDD(self, func):
+ """
+ Apply a function to each RDD in this DStream.
+ """
+ if func.__code__.co_argcount == 1:
+ old_func = func
+ func = lambda r, rdd: old_func(rdd)
+ jfunc = TransformFunction(self._sc, func, self._jrdd_deserializer) \
+ .rdd_wrapper(lambda jrdd, ctx, ser: KafkaRDD(jrdd, ctx, ser))
+ api = self._ssc._jvm.PythonDStream
+ api.callForeachRDD(self._jdstream, jfunc)
+
+ def transform(self, func):
+ """
+ Return a new DStream in which each RDD is generated by applying a function
+ on each RDD of this DStream.
+
+ `func` can take either one argument (`rdd`) or two arguments (`time`, `rdd`).
+ """
+ if func.__code__.co_argcount == 1:
+ oldfunc = func
+ func = lambda t, rdd: oldfunc(rdd)
+ assert func.__code__.co_argcount == 2, "func should take one or two arguments"
+
+ return KafkaTransformedDStream(self, func)
+
+
+class KafkaTransformedDStream(TransformedDStream):
+ """
+ Kafka specific wrapper of TransformedDStream to transform on Kafka RDD.
+ """
+
+ def __init__(self, prev, func):
+ TransformedDStream.__init__(self, prev, func)
+
+ @property
+ def _jdstream(self):
+ if self._jdstream_val is not None:
+ return self._jdstream_val
+
+ jfunc = TransformFunction(self._sc, self.func, self.prev._jrdd_deserializer) \
+ .rdd_wrapper(lambda jrdd, ctx, ser: KafkaRDD(jrdd, ctx, ser))
+ dstream = self._sc._jvm.PythonTransformedDStream(self.prev._jdstream.dstream(), jfunc)
+ self._jdstream_val = dstream.asJavaDStream()
+ return self._jdstream_val
diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py
index 188c8ff120..4ecae1e4bf 100644
--- a/python/pyspark/streaming/tests.py
+++ b/python/pyspark/streaming/tests.py
@@ -678,6 +678,70 @@ class KafkaStreamTests(PySparkStreamingTestCase):
rdd = KafkaUtils.createRDD(self.sc, kafkaParams, offsetRanges, leaders)
self._validateRddResult(sendData, rdd)
+ @unittest.skipIf(sys.version >= "3", "long type not supported")
+ def test_kafka_rdd_get_offsetRanges(self):
+ """Test Python direct Kafka RDD get OffsetRanges."""
+ topic = self._randomTopic()
+ sendData = {"a": 3, "b": 4, "c": 5}
+ offsetRanges = [OffsetRange(topic, 0, long(0), long(sum(sendData.values())))]
+ kafkaParams = {"metadata.broker.list": self._kafkaTestUtils.brokerAddress()}
+
+ self._kafkaTestUtils.createTopic(topic)
+ self._kafkaTestUtils.sendMessages(topic, sendData)
+ rdd = KafkaUtils.createRDD(self.sc, kafkaParams, offsetRanges)
+ self.assertEqual(offsetRanges, rdd.offsetRanges())
+
+ @unittest.skipIf(sys.version >= "3", "long type not supported")
+ def test_kafka_direct_stream_foreach_get_offsetRanges(self):
+ """Test the Python direct Kafka stream foreachRDD get offsetRanges."""
+ topic = self._randomTopic()
+ sendData = {"a": 1, "b": 2, "c": 3}
+ kafkaParams = {"metadata.broker.list": self._kafkaTestUtils.brokerAddress(),
+ "auto.offset.reset": "smallest"}
+
+ self._kafkaTestUtils.createTopic(topic)
+ self._kafkaTestUtils.sendMessages(topic, sendData)
+
+ stream = KafkaUtils.createDirectStream(self.ssc, [topic], kafkaParams)
+
+ offsetRanges = []
+
+ def getOffsetRanges(_, rdd):
+ for o in rdd.offsetRanges():
+ offsetRanges.append(o)
+
+ stream.foreachRDD(getOffsetRanges)
+ self.ssc.start()
+ self.wait_for(offsetRanges, 1)
+
+ self.assertEqual(offsetRanges, [OffsetRange(topic, 0, long(0), long(6))])
+
+ @unittest.skipIf(sys.version >= "3", "long type not supported")
+ def test_kafka_direct_stream_transform_get_offsetRanges(self):
+ """Test the Python direct Kafka stream transform get offsetRanges."""
+ topic = self._randomTopic()
+ sendData = {"a": 1, "b": 2, "c": 3}
+ kafkaParams = {"metadata.broker.list": self._kafkaTestUtils.brokerAddress(),
+ "auto.offset.reset": "smallest"}
+
+ self._kafkaTestUtils.createTopic(topic)
+ self._kafkaTestUtils.sendMessages(topic, sendData)
+
+ stream = KafkaUtils.createDirectStream(self.ssc, [topic], kafkaParams)
+
+ offsetRanges = []
+
+ def transformWithOffsetRanges(rdd):
+ for o in rdd.offsetRanges():
+ offsetRanges.append(o)
+ return rdd
+
+ stream.transform(transformWithOffsetRanges).foreachRDD(lambda rdd: rdd.count())
+ self.ssc.start()
+ self.wait_for(offsetRanges, 1)
+
+ self.assertEqual(offsetRanges, [OffsetRange(topic, 0, long(0), long(6))])
+
class FlumeStreamTests(PySparkStreamingTestCase):
timeout = 20 # seconds
diff --git a/python/pyspark/streaming/util.py b/python/pyspark/streaming/util.py
index a9bfec2aab..b20613b128 100644
--- a/python/pyspark/streaming/util.py
+++ b/python/pyspark/streaming/util.py
@@ -37,6 +37,11 @@ class TransformFunction(object):
self.ctx = ctx
self.func = func
self.deserializers = deserializers
+ self._rdd_wrapper = lambda jrdd, ctx, ser: RDD(jrdd, ctx, ser)
+
+ def rdd_wrapper(self, func):
+ self._rdd_wrapper = func
+ return self
def call(self, milliseconds, jrdds):
try:
@@ -51,7 +56,7 @@ class TransformFunction(object):
if len(sers) < len(jrdds):
sers += (sers[0],) * (len(jrdds) - len(sers))
- rdds = [RDD(jrdd, self.ctx, ser) if jrdd else None
+ rdds = [self._rdd_wrapper(jrdd, self.ctx, ser) if jrdd else None
for jrdd, ser in zip(jrdds, sers)]
t = datetime.fromtimestamp(milliseconds / 1000.0)
r = self.func(t, *rdds)