From 073bf9d4d91e0242a813f3d227e52e76c26a2200 Mon Sep 17 00:00:00 2001
From: Josh Rosen
Date: Fri, 11 Mar 2016 11:18:51 -0800
Subject: [SPARK-13807] De-duplicate `Python*Helper` instantiation code in PySpark streaming

This patch de-duplicates code in PySpark streaming which loads the
`Python*Helper` classes. I also changed a few `raise e` statements to
simply `raise` in order to preserve the full exception stacktrace when
re-throwing.

Here's a link to the whitespace-change-free diff:
https://github.com/apache/spark/compare/master...JoshRosen:pyspark-reflection-deduplication?w=0

Author: Josh Rosen

Closes #11641 from JoshRosen/pyspark-reflection-deduplication.
---
 python/pyspark/streaming/flume.py   |  40 ++++++---------
 python/pyspark/streaming/kafka.py   | 100 +++++++++++++++---------------------
 python/pyspark/streaming/kinesis.py |   2 +-
 python/pyspark/streaming/mqtt.py    |   2 +-
 4 files changed, 60 insertions(+), 84 deletions(-)

(limited to 'python')

diff --git a/python/pyspark/streaming/flume.py b/python/pyspark/streaming/flume.py
index b1fff0a5c7..edd5886a85 100644
--- a/python/pyspark/streaming/flume.py
+++ b/python/pyspark/streaming/flume.py
@@ -55,17 +55,8 @@ class FlumeUtils(object):
         :return: A DStream object
         """
         jlevel = ssc._sc._getJavaStorageLevel(storageLevel)
-
-        try:
-            helperClass = ssc._jvm.java.lang.Thread.currentThread().getContextClassLoader()\
-                .loadClass("org.apache.spark.streaming.flume.FlumeUtilsPythonHelper")
-            helper = helperClass.newInstance()
-            jstream = helper.createStream(ssc._jssc, hostname, port, jlevel, enableDecompression)
-        except Py4JJavaError as e:
-            if 'ClassNotFoundException' in str(e.java_exception):
-                FlumeUtils._printErrorMsg(ssc.sparkContext)
-            raise e
-
+        helper = FlumeUtils._get_helper(ssc._sc)
+        jstream = helper.createStream(ssc._jssc, hostname, port, jlevel, enableDecompression)
         return FlumeUtils._toPythonDStream(ssc, jstream, bodyDecoder)
 
     @staticmethod
@@ -95,18 +86,9 @@ class FlumeUtils(object):
         for (host, port) in addresses:
             hosts.append(host)
             ports.append(port)
-
-        try:
-            helperClass = ssc._jvm.java.lang.Thread.currentThread().getContextClassLoader() \
-                .loadClass("org.apache.spark.streaming.flume.FlumeUtilsPythonHelper")
-            helper = helperClass.newInstance()
-            jstream = helper.createPollingStream(
-                ssc._jssc, hosts, ports, jlevel, maxBatchSize, parallelism)
-        except Py4JJavaError as e:
-            if 'ClassNotFoundException' in str(e.java_exception):
-                FlumeUtils._printErrorMsg(ssc.sparkContext)
-            raise e
-
+        helper = FlumeUtils._get_helper(ssc._sc)
+        jstream = helper.createPollingStream(
+            ssc._jssc, hosts, ports, jlevel, maxBatchSize, parallelism)
         return FlumeUtils._toPythonDStream(ssc, jstream, bodyDecoder)
 
     @staticmethod
@@ -126,6 +108,18 @@ class FlumeUtils(object):
             return (headers, body)
         return stream.map(func)
 
+    @staticmethod
+    def _get_helper(sc):
+        try:
+            helperClass = sc._jvm.java.lang.Thread.currentThread().getContextClassLoader() \
+                .loadClass("org.apache.spark.streaming.flume.FlumeUtilsPythonHelper")
+            return helperClass.newInstance()
+        except Py4JJavaError as e:
+            # TODO: use --jar once it also work on driver
+            if 'ClassNotFoundException' in str(e.java_exception):
+                FlumeUtils._printErrorMsg(sc)
+            raise
+
     @staticmethod
     def _printErrorMsg(sc):
         print("""
diff --git a/python/pyspark/streaming/kafka.py b/python/pyspark/streaming/kafka.py
index 13f8f9578e..a70b99249d 100644
--- a/python/pyspark/streaming/kafka.py
+++ b/python/pyspark/streaming/kafka.py
@@ -66,18 +66,8 @@ class KafkaUtils(object):
         if not isinstance(topics, dict):
TypeError("topics should be dict") jlevel = ssc._sc._getJavaStorageLevel(storageLevel) - - try: - # Use KafkaUtilsPythonHelper to access Scala's KafkaUtils (see SPARK-6027) - helperClass = ssc._jvm.java.lang.Thread.currentThread().getContextClassLoader()\ - .loadClass("org.apache.spark.streaming.kafka.KafkaUtilsPythonHelper") - helper = helperClass.newInstance() - jstream = helper.createStream(ssc._jssc, kafkaParams, topics, jlevel) - except Py4JJavaError as e: - # TODO: use --jar once it also work on driver - if 'ClassNotFoundException' in str(e.java_exception): - KafkaUtils._printErrorMsg(ssc.sparkContext) - raise e + helper = KafkaUtils._get_helper(ssc._sc) + jstream = helper.createStream(ssc._jssc, kafkaParams, topics, jlevel) ser = PairDeserializer(NoOpSerializer(), NoOpSerializer()) stream = DStream(jstream, ssc, ser) return stream.map(lambda k_v: (keyDecoder(k_v[0]), valueDecoder(k_v[1]))) @@ -129,27 +119,20 @@ class KafkaUtils(object): m._set_value_decoder(valueDecoder) return messageHandler(m) - try: - helperClass = ssc._jvm.java.lang.Thread.currentThread().getContextClassLoader() \ - .loadClass("org.apache.spark.streaming.kafka.KafkaUtilsPythonHelper") - helper = helperClass.newInstance() - - jfromOffsets = dict([(k._jTopicAndPartition(helper), - v) for (k, v) in fromOffsets.items()]) - if messageHandler is None: - ser = PairDeserializer(NoOpSerializer(), NoOpSerializer()) - func = funcWithoutMessageHandler - jstream = helper.createDirectStreamWithoutMessageHandler( - ssc._jssc, kafkaParams, set(topics), jfromOffsets) - else: - ser = AutoBatchedSerializer(PickleSerializer()) - func = funcWithMessageHandler - jstream = helper.createDirectStreamWithMessageHandler( - ssc._jssc, kafkaParams, set(topics), jfromOffsets) - except Py4JJavaError as e: - if 'ClassNotFoundException' in str(e.java_exception): - KafkaUtils._printErrorMsg(ssc.sparkContext) - raise e + helper = KafkaUtils._get_helper(ssc._sc) + + jfromOffsets = dict([(k._jTopicAndPartition(helper), + v) for (k, v) in fromOffsets.items()]) + if messageHandler is None: + ser = PairDeserializer(NoOpSerializer(), NoOpSerializer()) + func = funcWithoutMessageHandler + jstream = helper.createDirectStreamWithoutMessageHandler( + ssc._jssc, kafkaParams, set(topics), jfromOffsets) + else: + ser = AutoBatchedSerializer(PickleSerializer()) + func = funcWithMessageHandler + jstream = helper.createDirectStreamWithMessageHandler( + ssc._jssc, kafkaParams, set(topics), jfromOffsets) stream = DStream(jstream, ssc, ser).map(func) return KafkaDStream(stream._jdstream, ssc, stream._jrdd_deserializer) @@ -189,28 +172,35 @@ class KafkaUtils(object): m._set_value_decoder(valueDecoder) return messageHandler(m) + helper = KafkaUtils._get_helper(sc) + + joffsetRanges = [o._jOffsetRange(helper) for o in offsetRanges] + jleaders = dict([(k._jTopicAndPartition(helper), + v._jBroker(helper)) for (k, v) in leaders.items()]) + if messageHandler is None: + jrdd = helper.createRDDWithoutMessageHandler( + sc._jsc, kafkaParams, joffsetRanges, jleaders) + ser = PairDeserializer(NoOpSerializer(), NoOpSerializer()) + rdd = RDD(jrdd, sc, ser).map(funcWithoutMessageHandler) + else: + jrdd = helper.createRDDWithMessageHandler( + sc._jsc, kafkaParams, joffsetRanges, jleaders) + rdd = RDD(jrdd, sc).map(funcWithMessageHandler) + + return KafkaRDD(rdd._jrdd, sc, rdd._jrdd_deserializer) + + @staticmethod + def _get_helper(sc): try: + # Use KafkaUtilsPythonHelper to access Scala's KafkaUtils (see SPARK-6027) helperClass = 
             helperClass = sc._jvm.java.lang.Thread.currentThread().getContextClassLoader() \
                 .loadClass("org.apache.spark.streaming.kafka.KafkaUtilsPythonHelper")
-            helper = helperClass.newInstance()
-            joffsetRanges = [o._jOffsetRange(helper) for o in offsetRanges]
-            jleaders = dict([(k._jTopicAndPartition(helper),
-                              v._jBroker(helper)) for (k, v) in leaders.items()])
-            if messageHandler is None:
-                jrdd = helper.createRDDWithoutMessageHandler(
-                    sc._jsc, kafkaParams, joffsetRanges, jleaders)
-                ser = PairDeserializer(NoOpSerializer(), NoOpSerializer())
-                rdd = RDD(jrdd, sc, ser).map(funcWithoutMessageHandler)
-            else:
-                jrdd = helper.createRDDWithMessageHandler(
-                    sc._jsc, kafkaParams, joffsetRanges, jleaders)
-                rdd = RDD(jrdd, sc).map(funcWithMessageHandler)
+            return helperClass.newInstance()
         except Py4JJavaError as e:
+            # TODO: use --jar once it also work on driver
             if 'ClassNotFoundException' in str(e.java_exception):
                 KafkaUtils._printErrorMsg(sc)
-            raise e
-
-        return KafkaRDD(rdd._jrdd, sc, rdd._jrdd_deserializer)
+            raise
 
     @staticmethod
     def _printErrorMsg(sc):
@@ -333,16 +323,8 @@ class KafkaRDD(RDD):
         Get the OffsetRange of specific KafkaRDD.
         :return: A list of OffsetRange
         """
-        try:
-            helperClass = self.ctx._jvm.java.lang.Thread.currentThread().getContextClassLoader() \
-                .loadClass("org.apache.spark.streaming.kafka.KafkaUtilsPythonHelper")
-            helper = helperClass.newInstance()
-            joffsetRanges = helper.offsetRangesOfKafkaRDD(self._jrdd.rdd())
-        except Py4JJavaError as e:
-            if 'ClassNotFoundException' in str(e.java_exception):
-                KafkaUtils._printErrorMsg(self.ctx)
-            raise e
-
+        helper = KafkaUtils._get_helper(self.ctx)
+        joffsetRanges = helper.offsetRangesOfKafkaRDD(self._jrdd.rdd())
         ranges = [OffsetRange(o.topic(), o.partition(), o.fromOffset(), o.untilOffset())
                   for o in joffsetRanges]
         return ranges
diff --git a/python/pyspark/streaming/kinesis.py b/python/pyspark/streaming/kinesis.py
index af72c3d690..e681301681 100644
--- a/python/pyspark/streaming/kinesis.py
+++ b/python/pyspark/streaming/kinesis.py
@@ -83,7 +83,7 @@ class KinesisUtils(object):
         except Py4JJavaError as e:
             if 'ClassNotFoundException' in str(e.java_exception):
                 KinesisUtils._printErrorMsg(ssc.sparkContext)
-            raise e
+            raise
 
         stream = DStream(jstream, ssc, NoOpSerializer())
         return stream.map(lambda v: decoder(v))
diff --git a/python/pyspark/streaming/mqtt.py b/python/pyspark/streaming/mqtt.py
index 3a515ea499..388e9526ba 100644
--- a/python/pyspark/streaming/mqtt.py
+++ b/python/pyspark/streaming/mqtt.py
@@ -48,7 +48,7 @@ class MQTTUtils(object):
         except Py4JJavaError as e:
             if 'ClassNotFoundException' in str(e.java_exception):
                 MQTTUtils._printErrorMsg(ssc.sparkContext)
-            raise e
+            raise
 
         return DStream(jstream, ssc, UTF8Deserializer())
 
--
cgit v1.2.3
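
A note on the `raise e` -> `raise` change in this patch: under Python 2, which
PySpark still supported at the time, `raise e` rebuilds the exception's
traceback starting at the re-raise site, so the frames showing where the error
actually occurred are lost; a bare `raise` re-raises the active exception with
its original traceback intact. The standalone snippet below (an illustration
for this note, not part of the patch) demonstrates the pattern the patch
standardizes on:

    import traceback

    def fail():
        raise ValueError("boom")

    def rethrow():
        try:
            fail()
        except ValueError:
            # Bare `raise` re-raises the exception currently being handled
            # and keeps the traceback pointing into fail().  Under Python 2,
            # `raise e` would restart the traceback here instead, hiding the
            # true origin of the error.
            raise

    try:
        rethrow()
    except ValueError:
        traceback.print_exc()  # the printed stack still includes fail()

This mirrors the `except Py4JJavaError` handlers above: each prints a helpful
message in the ClassNotFoundException case and then re-raises without
disturbing the original stacktrace.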