aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark/streaming/util.py
diff options
context:
space:
mode:
Diffstat (limited to 'python/pyspark/streaming/util.py')
-rw-r--r--python/pyspark/streaming/util.py13
1 files changed, 7 insertions, 6 deletions
diff --git a/python/pyspark/streaming/util.py b/python/pyspark/streaming/util.py
index c7f02bca2a..abbbf6eb93 100644
--- a/python/pyspark/streaming/util.py
+++ b/python/pyspark/streaming/util.py
@@ -37,11 +37,11 @@ class TransformFunction(object):
self.ctx = ctx
self.func = func
self.deserializers = deserializers
- self._rdd_wrapper = lambda jrdd, ctx, ser: RDD(jrdd, ctx, ser)
+ self.rdd_wrap_func = lambda jrdd, ctx, ser: RDD(jrdd, ctx, ser)
self.failure = None
def rdd_wrapper(self, func):
- self._rdd_wrapper = func
+ self.rdd_wrap_func = func
return self
def call(self, milliseconds, jrdds):
@@ -59,7 +59,7 @@ class TransformFunction(object):
if len(sers) < len(jrdds):
sers += (sers[0],) * (len(jrdds) - len(sers))
- rdds = [self._rdd_wrapper(jrdd, self.ctx, ser) if jrdd else None
+ rdds = [self.rdd_wrap_func(jrdd, self.ctx, ser) if jrdd else None
for jrdd, ser in zip(jrdds, sers)]
t = datetime.fromtimestamp(milliseconds / 1000.0)
r = self.func(t, *rdds)
@@ -101,7 +101,8 @@ class TransformFunctionSerializer(object):
self.failure = None
try:
func = self.gateway.gateway_property.pool[id]
- return bytearray(self.serializer.dumps((func.func, func.deserializers)))
+ return bytearray(self.serializer.dumps((
+ func.func, func.rdd_wrap_func, func.deserializers)))
except:
self.failure = traceback.format_exc()
@@ -109,8 +110,8 @@ class TransformFunctionSerializer(object):
# Clear the failure
self.failure = None
try:
- f, deserializers = self.serializer.loads(bytes(data))
- return TransformFunction(self.ctx, f, *deserializers)
+ f, wrap_func, deserializers = self.serializer.loads(bytes(data))
+ return TransformFunction(self.ctx, f, *deserializers).rdd_wrapper(wrap_func)
except:
self.failure = traceback.format_exc()