From f292018f8e57779debc04998456ec875f628133b Mon Sep 17 00:00:00 2001
From: jerryshao
Date: Tue, 1 Dec 2015 15:26:10 -0800
Subject: [SPARK-12002][STREAMING][PYSPARK] Fix python direct stream
 checkpoint recovery issue

Fixed a minor race condition in #10017

Closes #10017

Author: jerryshao
Author: Shixiong Zhu

Closes #10074 from zsxwing/review-pr10017.
---
 python/pyspark/streaming/tests.py | 49 +++++++++++++++++++++++++++++++++++++++
 python/pyspark/streaming/util.py  | 13 ++++++-----
 2 files changed, 56 insertions(+), 6 deletions(-)

diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py
index a647e6bf39..d50c6b8d4a 100644
--- a/python/pyspark/streaming/tests.py
+++ b/python/pyspark/streaming/tests.py
@@ -1149,6 +1149,55 @@ class KafkaStreamTests(PySparkStreamingTestCase):
         self.assertNotEqual(topic_and_partition_a, topic_and_partition_c)
         self.assertNotEqual(topic_and_partition_a, topic_and_partition_d)
 
+    @unittest.skipIf(sys.version >= "3", "long type not support")
+    def test_kafka_direct_stream_transform_with_checkpoint(self):
+        """Test the Python direct Kafka stream transform with checkpoint correctly recovered."""
+        topic = self._randomTopic()
+        sendData = {"a": 1, "b": 2, "c": 3}
+        kafkaParams = {"metadata.broker.list": self._kafkaTestUtils.brokerAddress(),
+                       "auto.offset.reset": "smallest"}
+
+        self._kafkaTestUtils.createTopic(topic)
+        self._kafkaTestUtils.sendMessages(topic, sendData)
+
+        offsetRanges = []
+
+        def transformWithOffsetRanges(rdd):
+            for o in rdd.offsetRanges():
+                offsetRanges.append(o)
+            return rdd
+
+        self.ssc.stop(False)
+        self.ssc = None
+        tmpdir = "checkpoint-test-%d" % random.randint(0, 10000)
+
+        def setup():
+            ssc = StreamingContext(self.sc, 0.5)
+            ssc.checkpoint(tmpdir)
+            stream = KafkaUtils.createDirectStream(ssc, [topic], kafkaParams)
+            stream.transform(transformWithOffsetRanges).count().pprint()
+            return ssc
+
+        try:
+            ssc1 = StreamingContext.getOrCreate(tmpdir, setup)
+            ssc1.start()
+            self.wait_for(offsetRanges, 1)
+            self.assertEqual(offsetRanges, [OffsetRange(topic, 0, long(0), long(6))])
+
+            # To make sure some checkpoint is written
+            time.sleep(3)
+            ssc1.stop(False)
+            ssc1 = None
+
+            # Restart again to make sure the checkpoint is recovered correctly
+            ssc2 = StreamingContext.getOrCreate(tmpdir, setup)
+            ssc2.start()
+            ssc2.awaitTermination(3)
+            ssc2.stop(stopSparkContext=False, stopGraceFully=True)
+            ssc2 = None
+        finally:
+            shutil.rmtree(tmpdir)
+
     @unittest.skipIf(sys.version >= "3", "long type not support")
     def test_kafka_rdd_message_handler(self):
         """Test Python direct Kafka RDD MessageHandler."""
diff --git a/python/pyspark/streaming/util.py b/python/pyspark/streaming/util.py
index c7f02bca2a..abbbf6eb93 100644
--- a/python/pyspark/streaming/util.py
+++ b/python/pyspark/streaming/util.py
@@ -37,11 +37,11 @@ class TransformFunction(object):
         self.ctx = ctx
         self.func = func
         self.deserializers = deserializers
-        self._rdd_wrapper = lambda jrdd, ctx, ser: RDD(jrdd, ctx, ser)
+        self.rdd_wrap_func = lambda jrdd, ctx, ser: RDD(jrdd, ctx, ser)
         self.failure = None
 
     def rdd_wrapper(self, func):
-        self._rdd_wrapper = func
+        self.rdd_wrap_func = func
         return self
 
     def call(self, milliseconds, jrdds):
@@ -59,7 +59,7 @@ class TransformFunction(object):
             if len(sers) < len(jrdds):
                 sers += (sers[0],) * (len(jrdds) - len(sers))
 
-            rdds = [self._rdd_wrapper(jrdd, self.ctx, ser) if jrdd else None
+            rdds = [self.rdd_wrap_func(jrdd, self.ctx, ser) if jrdd else None
                     for jrdd, ser in zip(jrdds, sers)]
             t = datetime.fromtimestamp(milliseconds / 1000.0)
             r = self.func(t, *rdds)
@@ -101,7 +101,8 @@ class TransformFunctionSerializer(object):
         self.failure = None
         try:
             func = self.gateway.gateway_property.pool[id]
-            return bytearray(self.serializer.dumps((func.func, func.deserializers)))
+            return bytearray(self.serializer.dumps((
+                func.func, func.rdd_wrap_func, func.deserializers)))
         except:
             self.failure = traceback.format_exc()
 
@@ -109,8 +110,8 @@ class TransformFunctionSerializer(object):
         # Clear the failure
         self.failure = None
         try:
-            f, deserializers = self.serializer.loads(bytes(data))
-            return TransformFunction(self.ctx, f, *deserializers)
+            f, wrap_func, deserializers = self.serializer.loads(bytes(data))
+            return TransformFunction(self.ctx, f, *deserializers).rdd_wrapper(wrap_func)
         except:
             self.failure = traceback.format_exc()
 
--
cgit v1.2.3
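
The util.py hunks are the substance of the fix. Before this patch,
TransformFunctionSerializer.dumps pickled only (func.func, func.deserializers)
into the checkpoint, so a DStream that had installed a custom RDD wrapper (as
the direct Kafka stream does, to hand back RDDs exposing offsetRanges()) was
restored with the default plain-RDD wrapper, and rdd.offsetRanges() failed
after recovery. The round trip can be sketched without Spark at all; the
following is a toy stand-in using plain pickle (Spark itself uses a
CloudPickle-based serializer), and Holder, default_wrap, kafka_wrap, and
identity are names invented for this example only:

    import pickle

    def default_wrap(jrdd):
        # Stand-in for the default wrapper that yields a plain RDD.
        return ("plain-rdd", jrdd)

    def kafka_wrap(jrdd):
        # Stand-in for the Kafka wrapper that adds offsetRanges().
        return ("kafka-rdd", jrdd)

    class Holder(object):
        """Toy TransformFunction: a user function plus an RDD wrapper."""
        def __init__(self, func):
            self.func = func
            self.rdd_wrap_func = default_wrap

        def rdd_wrapper(self, f):
            self.rdd_wrap_func = f
            return self

    def dumps(holder):
        # The fix, in miniature: checkpoint the wrapper with the function.
        return pickle.dumps((holder.func, holder.rdd_wrap_func))

    def loads(data):
        func, wrap_func = pickle.loads(data)
        return Holder(func).rdd_wrapper(wrap_func)

    def identity(rdd):
        return rdd

    recovered = loads(dumps(Holder(identity).rdd_wrapper(kafka_wrap)))
    assert recovered.rdd_wrap_func is kafka_wrap  # wrapper survives recovery

Dropping rdd_wrap_func from the dumps() tuple reproduces the bug: loads()
would rebuild the holder with default_wrap, and the assertion would fail.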
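
The two start/stop cycles in the new test hinge on how
StreamingContext.getOrCreate behaves: the setup function runs only when no
usable checkpoint exists under the given path, so the second getOrCreate call
rebuilds the whole DStream graph, including the pickled transform function and
its RDD wrapper, from the checkpoint, exercising the loads() path above. A
minimal sketch of that pattern, where checkpoint_dir and create_context are
placeholder names and sc is assumed to be an already-running SparkContext:

    from pyspark.streaming import StreamingContext

    checkpoint_dir = "/tmp/checkpoint-demo"  # assumption: any writable path

    def create_context():
        # Called only when checkpoint_dir holds no checkpoint; after a
        # restart the graph is deserialized instead and this is skipped.
        ssc = StreamingContext(sc, 0.5)
        ssc.checkpoint(checkpoint_dir)
        # ... define the DStream graph here ...
        return ssc

    ssc = StreamingContext.getOrCreate(checkpoint_dir, create_context)
    ssc.start()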