aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark/context.py
diff options
context:
space:
mode:
Diffstat (limited to 'python/pyspark/context.py')
-rw-r--r--python/pyspark/context.py13
1 files changed, 6 insertions, 7 deletions
diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index 6011caf9f1..78dccc4047 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -21,6 +21,8 @@ import sys
from threading import Lock
from tempfile import NamedTemporaryFile
+from py4j.java_collections import ListConverter
+
from pyspark import accumulators
from pyspark.accumulators import Accumulator
from pyspark.broadcast import Broadcast
@@ -30,13 +32,11 @@ from pyspark.java_gateway import launch_gateway
from pyspark.serializers import PickleSerializer, BatchedSerializer, UTF8Deserializer, \
PairDeserializer, AutoBatchedSerializer, NoOpSerializer
from pyspark.storagelevel import StorageLevel
-from pyspark.rdd import RDD
+from pyspark.rdd import RDD, _load_from_socket
from pyspark.traceback_utils import CallSite, first_spark_call
from pyspark.status import StatusTracker
from pyspark.profiler import ProfilerCollector, BasicProfiler
-from py4j.java_collections import ListConverter
-
__all__ = ['SparkContext']
@@ -59,7 +59,6 @@ class SparkContext(object):
_gateway = None
_jvm = None
- _writeToFile = None
_next_accum_id = 0
_active_spark_context = None
_lock = Lock()
@@ -221,7 +220,6 @@ class SparkContext(object):
if not SparkContext._gateway:
SparkContext._gateway = gateway or launch_gateway()
SparkContext._jvm = SparkContext._gateway.jvm
- SparkContext._writeToFile = SparkContext._jvm.PythonRDD.writeToFile
if instance:
if (SparkContext._active_spark_context and
@@ -840,8 +838,9 @@ class SparkContext(object):
# by runJob() in order to avoid having to pass a Python lambda into
# SparkContext#runJob.
mappedRDD = rdd.mapPartitions(partitionFunc)
- it = self._jvm.PythonRDD.runJob(self._jsc.sc(), mappedRDD._jrdd, javaPartitions, allowLocal)
- return list(mappedRDD._collect_iterator_through_file(it))
+ port = self._jvm.PythonRDD.runJob(self._jsc.sc(), mappedRDD._jrdd, javaPartitions,
+ allowLocal)
+ return list(_load_from_socket(port, mappedRDD._jrdd_deserializer))
def show_profiles(self):
""" Print the profile stats to stdout """