Diffstat (limited to 'pyspark/pyspark/context.py')
-rw-r--r--  pyspark/pyspark/context.py  49
1 file changed, 19 insertions(+), 30 deletions(-)
diff --git a/pyspark/pyspark/context.py b/pyspark/pyspark/context.py
index 587ab12b5f..ac7e4057e9 100644
--- a/pyspark/pyspark/context.py
+++ b/pyspark/pyspark/context.py
@@ -3,22 +3,24 @@ import atexit
from tempfile import NamedTemporaryFile
from pyspark.java_gateway import launch_gateway
-from pyspark.serializers import JSONSerializer, NopSerializer
-from pyspark.rdd import RDD, PairRDD
+from pyspark.serializers import PickleSerializer, dumps
+from pyspark.rdd import RDD
class SparkContext(object):
gateway = launch_gateway()
jvm = gateway.jvm
- python_dump = jvm.spark.api.python.PythonRDD.pythonDump
+ pickleFile = jvm.spark.api.python.PythonRDD.pickleFile
+ asPickle = jvm.spark.api.python.PythonRDD.asPickle
+ arrayAsPickle = jvm.spark.api.python.PythonRDD.arrayAsPickle
- def __init__(self, master, name, defaultSerializer=JSONSerializer,
- defaultParallelism=None, pythonExec='python'):
+
+ def __init__(self, master, name, defaultParallelism=None,
+ pythonExec='python'):
self.master = master
self.name = name
self._jsc = self.jvm.JavaSparkContext(master, name)
- self.defaultSerializer = defaultSerializer
self.defaultParallelism = \
defaultParallelism or self._jsc.sc().defaultParallelism()
self.pythonExec = pythonExec
@@ -31,39 +33,26 @@ class SparkContext(object):
self._jsc.stop()
self._jsc = None
- def parallelize(self, c, numSlices=None, serializer=None):
- serializer = serializer or self.defaultSerializer
- numSlices = numSlices or self.defaultParallelism
- # Calling the Java parallelize() method with an ArrayList is too slow,
- # because it sends O(n) Py4J commands. As an alternative, serialized
- # objects are written to a file and loaded through textFile().
- tempFile = NamedTemporaryFile(delete=False)
- tempFile.writelines(serializer.dumps(x) + '\n' for x in c)
- tempFile.close()
- atexit.register(lambda: os.unlink(tempFile.name))
- return self.textFile(tempFile.name, numSlices, serializer)
-
- def parallelizePairs(self, c, numSlices=None, keySerializer=None,
- valSerializer=None):
+ def parallelize(self, c, numSlices=None):
"""
>>> sc = SparkContext("local", "test")
- >>> rdd = sc.parallelizePairs([(1, 2), (3, 4)])
+ >>> rdd = sc.parallelize([(1, 2), (3, 4)])
>>> rdd.collect()
[(1, 2), (3, 4)]
"""
- keySerializer = keySerializer or self.defaultSerializer
- valSerializer = valSerializer or self.defaultSerializer
numSlices = numSlices or self.defaultParallelism
+ # Calling the Java parallelize() method with an ArrayList is too slow,
+ # because it sends O(n) Py4J commands. As an alternative, serialized
+ # objects are written to a file and loaded through textFile().
tempFile = NamedTemporaryFile(delete=False)
- for (k, v) in c:
- tempFile.write(keySerializer.dumps(k).rstrip('\r\n') + '\n')
- tempFile.write(valSerializer.dumps(v).rstrip('\r\n') + '\n')
+ for x in c:
+ dumps(PickleSerializer.dumps(x), tempFile)
tempFile.close()
atexit.register(lambda: os.unlink(tempFile.name))
- jrdd = self.textFile(tempFile.name, numSlices)._pipePairs([], "echo")
- return PairRDD(jrdd, self, keySerializer, valSerializer)
+ jrdd = self.pickleFile(self._jsc, tempFile.name, numSlices)
+ return RDD(jrdd, self)
- def textFile(self, name, numSlices=None, serializer=NopSerializer):
+ def textFile(self, name, numSlices=None):
numSlices = numSlices or self.defaultParallelism
jrdd = self._jsc.textFile(name, numSlices)
- return RDD(jrdd, self, serializer)
+ return RDD(jrdd, self)
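
Note on the new mechanism: parallelize() now pickles each element and writes it to a temporary file as a framed record, which the JVM-side PythonRDD.pickleFile helper then loads back as an RDD. The dumps() helper imported from pyspark.serializers is not part of this diff; the sketch below is only an assumption about the framing it performs (a 4-byte length prefix followed by the raw pickled bytes), not the actual implementation.

    import struct
    import cPickle

    class PickleSerializer(object):
        @staticmethod
        def dumps(obj):
            # Pickle with protocol 2 for a compact binary representation.
            return cPickle.dumps(obj, 2)

    def dumps(data, stream):
        # Assumed framing: big-endian 4-byte length, then the pickled bytes,
        # so the reader can split the file back into individual records.
        stream.write(struct.pack("!i", len(data)))
        stream.write(data)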
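One practical difference behind the switch from JSONSerializer to PickleSerializer: JSON has no tuple type, so tuples do not survive a JSON round trip, whereas pickle preserves arbitrary picklable Python objects. That is what lets the doctest above collect [(1, 2), (3, 4)] back as tuples. A quick illustration in plain Python, independent of Spark:

    >>> import json, pickle
    >>> json.loads(json.dumps((1, 2)))      # JSON turns the tuple into a list
    [1, 2]
    >>> pickle.loads(pickle.dumps((1, 2)))  # pickle preserves the tuple
    (1, 2)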