author     Josh Rosen <joshrosen@databricks.com>     2016-03-11 11:18:51 -0800
committer  Shixiong Zhu <shixiong@databricks.com>    2016-03-11 11:18:51 -0800
commit     073bf9d4d91e0242a813f3d227e52e76c26a2200 (patch)
tree       469a2cdc7e9e67c64005aa2938dfb0cdfcebac22 /python
parent     ff776b2fc1cd4c571fd542dbf807e6fa3373cb34 (diff)
[SPARK-13807] De-duplicate `Python*Helper` instantiation code in PySpark streaming
This patch de-duplicates code in PySpark streaming which loads the `Python*Helper` classes. I also changed a few `raise e` statements to simply `raise` in order to preserve the full exception stacktrace when re-throwing.

Here's a link to the whitespace-change-free diff: https://github.com/apache/spark/compare/master...JoshRosen:pyspark-reflection-deduplication?w=0

Author: Josh Rosen <joshrosen@databricks.com>

Closes #11641 from JoshRosen/pyspark-reflection-deduplication.
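The consolidation and the re-raise fix are easiest to see in isolation. Below is a minimal sketch of the factored-out helper loader; the helper_class_name and print_error_msg parameters are illustrative stand-ins for the hard-coded class names and the *Utils._printErrorMsg staticmethods used in the real code:

    from py4j.protocol import Py4JJavaError

    def _get_helper(sc, helper_class_name, print_error_msg):
        # Load a Python*Helper class via the JVM's context class loader,
        # mirroring the pattern this patch de-duplicates across modules.
        try:
            helperClass = sc._jvm.java.lang.Thread.currentThread().getContextClassLoader() \
                .loadClass(helper_class_name)
            return helperClass.newInstance()
        except Py4JJavaError as e:
            if 'ClassNotFoundException' in str(e.java_exception):
                print_error_msg(sc)
            # Bare `raise` re-throws the active exception with its original
            # traceback; under Python 2, `raise e` would discard it.
            raise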
Diffstat (limited to 'python')
-rw-r--r--  python/pyspark/streaming/flume.py    |  40
-rw-r--r--  python/pyspark/streaming/kafka.py    | 100
-rw-r--r--  python/pyspark/streaming/kinesis.py  |   2
-rw-r--r--  python/pyspark/streaming/mqtt.py     |   2
4 files changed, 60 insertions(+), 84 deletions(-)
diff --git a/python/pyspark/streaming/flume.py b/python/pyspark/streaming/flume.py
index b1fff0a5c7..edd5886a85 100644
--- a/python/pyspark/streaming/flume.py
+++ b/python/pyspark/streaming/flume.py
@@ -55,17 +55,8 @@ class FlumeUtils(object):
:return: A DStream object
"""
jlevel = ssc._sc._getJavaStorageLevel(storageLevel)
-
- try:
- helperClass = ssc._jvm.java.lang.Thread.currentThread().getContextClassLoader()\
- .loadClass("org.apache.spark.streaming.flume.FlumeUtilsPythonHelper")
- helper = helperClass.newInstance()
- jstream = helper.createStream(ssc._jssc, hostname, port, jlevel, enableDecompression)
- except Py4JJavaError as e:
- if 'ClassNotFoundException' in str(e.java_exception):
- FlumeUtils._printErrorMsg(ssc.sparkContext)
- raise e
-
+ helper = FlumeUtils._get_helper(ssc._sc)
+ jstream = helper.createStream(ssc._jssc, hostname, port, jlevel, enableDecompression)
return FlumeUtils._toPythonDStream(ssc, jstream, bodyDecoder)
@staticmethod
@@ -95,18 +86,9 @@ class FlumeUtils(object):
for (host, port) in addresses:
hosts.append(host)
ports.append(port)
-
- try:
- helperClass = ssc._jvm.java.lang.Thread.currentThread().getContextClassLoader() \
- .loadClass("org.apache.spark.streaming.flume.FlumeUtilsPythonHelper")
- helper = helperClass.newInstance()
- jstream = helper.createPollingStream(
- ssc._jssc, hosts, ports, jlevel, maxBatchSize, parallelism)
- except Py4JJavaError as e:
- if 'ClassNotFoundException' in str(e.java_exception):
- FlumeUtils._printErrorMsg(ssc.sparkContext)
- raise e
-
+ helper = FlumeUtils._get_helper(ssc._sc)
+ jstream = helper.createPollingStream(
+ ssc._jssc, hosts, ports, jlevel, maxBatchSize, parallelism)
return FlumeUtils._toPythonDStream(ssc, jstream, bodyDecoder)
@staticmethod
@@ -127,6 +109,18 @@ class FlumeUtils(object):
return stream.map(func)
@staticmethod
+ def _get_helper(sc):
+ try:
+ helperClass = sc._jvm.java.lang.Thread.currentThread().getContextClassLoader() \
+ .loadClass("org.apache.spark.streaming.flume.FlumeUtilsPythonHelper")
+ return helperClass.newInstance()
+ except Py4JJavaError as e:
+ # TODO: use --jar once it also works on the driver
+ if 'ClassNotFoundException' in str(e.java_exception):
+ FlumeUtils._printErrorMsg(sc)
+ raise
+
+ @staticmethod
def _printErrorMsg(sc):
print("""
________________________________________________________________________________________________
diff --git a/python/pyspark/streaming/kafka.py b/python/pyspark/streaming/kafka.py
index 13f8f9578e..a70b99249d 100644
--- a/python/pyspark/streaming/kafka.py
+++ b/python/pyspark/streaming/kafka.py
@@ -66,18 +66,8 @@ class KafkaUtils(object):
if not isinstance(topics, dict):
raise TypeError("topics should be dict")
jlevel = ssc._sc._getJavaStorageLevel(storageLevel)
-
- try:
- # Use KafkaUtilsPythonHelper to access Scala's KafkaUtils (see SPARK-6027)
- helperClass = ssc._jvm.java.lang.Thread.currentThread().getContextClassLoader()\
- .loadClass("org.apache.spark.streaming.kafka.KafkaUtilsPythonHelper")
- helper = helperClass.newInstance()
- jstream = helper.createStream(ssc._jssc, kafkaParams, topics, jlevel)
- except Py4JJavaError as e:
- # TODO: use --jar once it also work on driver
- if 'ClassNotFoundException' in str(e.java_exception):
- KafkaUtils._printErrorMsg(ssc.sparkContext)
- raise e
+ helper = KafkaUtils._get_helper(ssc._sc)
+ jstream = helper.createStream(ssc._jssc, kafkaParams, topics, jlevel)
ser = PairDeserializer(NoOpSerializer(), NoOpSerializer())
stream = DStream(jstream, ssc, ser)
return stream.map(lambda k_v: (keyDecoder(k_v[0]), valueDecoder(k_v[1])))
@@ -129,27 +119,20 @@ class KafkaUtils(object):
m._set_value_decoder(valueDecoder)
return messageHandler(m)
- try:
- helperClass = ssc._jvm.java.lang.Thread.currentThread().getContextClassLoader() \
- .loadClass("org.apache.spark.streaming.kafka.KafkaUtilsPythonHelper")
- helper = helperClass.newInstance()
-
- jfromOffsets = dict([(k._jTopicAndPartition(helper),
- v) for (k, v) in fromOffsets.items()])
- if messageHandler is None:
- ser = PairDeserializer(NoOpSerializer(), NoOpSerializer())
- func = funcWithoutMessageHandler
- jstream = helper.createDirectStreamWithoutMessageHandler(
- ssc._jssc, kafkaParams, set(topics), jfromOffsets)
- else:
- ser = AutoBatchedSerializer(PickleSerializer())
- func = funcWithMessageHandler
- jstream = helper.createDirectStreamWithMessageHandler(
- ssc._jssc, kafkaParams, set(topics), jfromOffsets)
- except Py4JJavaError as e:
- if 'ClassNotFoundException' in str(e.java_exception):
- KafkaUtils._printErrorMsg(ssc.sparkContext)
- raise e
+ helper = KafkaUtils._get_helper(ssc._sc)
+
+ jfromOffsets = dict([(k._jTopicAndPartition(helper),
+ v) for (k, v) in fromOffsets.items()])
+ if messageHandler is None:
+ ser = PairDeserializer(NoOpSerializer(), NoOpSerializer())
+ func = funcWithoutMessageHandler
+ jstream = helper.createDirectStreamWithoutMessageHandler(
+ ssc._jssc, kafkaParams, set(topics), jfromOffsets)
+ else:
+ ser = AutoBatchedSerializer(PickleSerializer())
+ func = funcWithMessageHandler
+ jstream = helper.createDirectStreamWithMessageHandler(
+ ssc._jssc, kafkaParams, set(topics), jfromOffsets)
stream = DStream(jstream, ssc, ser).map(func)
return KafkaDStream(stream._jdstream, ssc, stream._jrdd_deserializer)
@@ -189,28 +172,35 @@ class KafkaUtils(object):
m._set_value_decoder(valueDecoder)
return messageHandler(m)
+ helper = KafkaUtils._get_helper(sc)
+
+ joffsetRanges = [o._jOffsetRange(helper) for o in offsetRanges]
+ jleaders = dict([(k._jTopicAndPartition(helper),
+ v._jBroker(helper)) for (k, v) in leaders.items()])
+ if messageHandler is None:
+ jrdd = helper.createRDDWithoutMessageHandler(
+ sc._jsc, kafkaParams, joffsetRanges, jleaders)
+ ser = PairDeserializer(NoOpSerializer(), NoOpSerializer())
+ rdd = RDD(jrdd, sc, ser).map(funcWithoutMessageHandler)
+ else:
+ jrdd = helper.createRDDWithMessageHandler(
+ sc._jsc, kafkaParams, joffsetRanges, jleaders)
+ rdd = RDD(jrdd, sc).map(funcWithMessageHandler)
+
+ return KafkaRDD(rdd._jrdd, sc, rdd._jrdd_deserializer)
+
+ @staticmethod
+ def _get_helper(sc):
try:
+ # Use KafkaUtilsPythonHelper to access Scala's KafkaUtils (see SPARK-6027)
helperClass = sc._jvm.java.lang.Thread.currentThread().getContextClassLoader() \
.loadClass("org.apache.spark.streaming.kafka.KafkaUtilsPythonHelper")
- helper = helperClass.newInstance()
- joffsetRanges = [o._jOffsetRange(helper) for o in offsetRanges]
- jleaders = dict([(k._jTopicAndPartition(helper),
- v._jBroker(helper)) for (k, v) in leaders.items()])
- if messageHandler is None:
- jrdd = helper.createRDDWithoutMessageHandler(
- sc._jsc, kafkaParams, joffsetRanges, jleaders)
- ser = PairDeserializer(NoOpSerializer(), NoOpSerializer())
- rdd = RDD(jrdd, sc, ser).map(funcWithoutMessageHandler)
- else:
- jrdd = helper.createRDDWithMessageHandler(
- sc._jsc, kafkaParams, joffsetRanges, jleaders)
- rdd = RDD(jrdd, sc).map(funcWithMessageHandler)
+ return helperClass.newInstance()
except Py4JJavaError as e:
+ # TODO: use --jar once it also works on the driver
if 'ClassNotFoundException' in str(e.java_exception):
KafkaUtils._printErrorMsg(sc)
- raise e
-
- return KafkaRDD(rdd._jrdd, sc, rdd._jrdd_deserializer)
+ raise
@staticmethod
def _printErrorMsg(sc):
@@ -333,16 +323,8 @@ class KafkaRDD(RDD):
Get the OffsetRange of specific KafkaRDD.
:return: A list of OffsetRange
"""
- try:
- helperClass = self.ctx._jvm.java.lang.Thread.currentThread().getContextClassLoader() \
- .loadClass("org.apache.spark.streaming.kafka.KafkaUtilsPythonHelper")
- helper = helperClass.newInstance()
- joffsetRanges = helper.offsetRangesOfKafkaRDD(self._jrdd.rdd())
- except Py4JJavaError as e:
- if 'ClassNotFoundException' in str(e.java_exception):
- KafkaUtils._printErrorMsg(self.ctx)
- raise e
-
+ helper = KafkaUtils._get_helper(self.ctx)
+ joffsetRanges = helper.offsetRangesOfKafkaRDD(self._jrdd.rdd())
ranges = [OffsetRange(o.topic(), o.partition(), o.fromOffset(), o.untilOffset())
for o in joffsetRanges]
return ranges
diff --git a/python/pyspark/streaming/kinesis.py b/python/pyspark/streaming/kinesis.py
index af72c3d690..e681301681 100644
--- a/python/pyspark/streaming/kinesis.py
+++ b/python/pyspark/streaming/kinesis.py
@@ -83,7 +83,7 @@ class KinesisUtils(object):
except Py4JJavaError as e:
if 'ClassNotFoundException' in str(e.java_exception):
KinesisUtils._printErrorMsg(ssc.sparkContext)
- raise e
+ raise
stream = DStream(jstream, ssc, NoOpSerializer())
return stream.map(lambda v: decoder(v))
diff --git a/python/pyspark/streaming/mqtt.py b/python/pyspark/streaming/mqtt.py
index 3a515ea499..388e9526ba 100644
--- a/python/pyspark/streaming/mqtt.py
+++ b/python/pyspark/streaming/mqtt.py
@@ -48,7 +48,7 @@ class MQTTUtils(object):
except Py4JJavaError as e:
if 'ClassNotFoundException' in str(e.java_exception):
MQTTUtils._printErrorMsg(ssc.sparkContext)
- raise e
+ raise
return DStream(jstream, ssc, UTF8Deserializer())
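
For reference, the traceback-preservation behavior that the one-line kinesis and mqtt changes rely on can be demonstrated without Spark at all (function names here are illustrative):

    import traceback

    def fail():
        raise ValueError("boom")

    def rethrow():
        try:
            fail()
        except ValueError:
            raise  # bare raise: traceback still points into fail()

    try:
        rethrow()
    except ValueError:
        traceback.print_exc()

With the bare `raise`, the printed traceback ends at the line inside fail() that actually threw. Under Python 2, writing `except ValueError as e: raise e` instead reports the `raise e` line as the origin, hiding the real failure site; Python 3 keeps the traceback on e.__traceback__, so there the difference is only an extra frame.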