Diffstat (limited to 'pyspark/pyspark/context.py')
-rw-r--r--  pyspark/pyspark/context.py  |  49  +++++++++++++++++++------------------------------
1 files changed, 19 insertions, 30 deletions
diff --git a/pyspark/pyspark/context.py b/pyspark/pyspark/context.py
index 587ab12b5f..ac7e4057e9 100644
--- a/pyspark/pyspark/context.py
+++ b/pyspark/pyspark/context.py
@@ -3,22 +3,24 @@
 import atexit
 from tempfile import NamedTemporaryFile
 from pyspark.java_gateway import launch_gateway
-from pyspark.serializers import JSONSerializer, NopSerializer
-from pyspark.rdd import RDD, PairRDD
+from pyspark.serializers import PickleSerializer, dumps
+from pyspark.rdd import RDD
 
 
 class SparkContext(object):
 
     gateway = launch_gateway()
     jvm = gateway.jvm
-    python_dump = jvm.spark.api.python.PythonRDD.pythonDump
+    pickleFile = jvm.spark.api.python.PythonRDD.pickleFile
+    asPickle = jvm.spark.api.python.PythonRDD.asPickle
+    arrayAsPickle = jvm.spark.api.python.PythonRDD.arrayAsPickle
 
-    def __init__(self, master, name, defaultSerializer=JSONSerializer,
-            defaultParallelism=None, pythonExec='python'):
+
+    def __init__(self, master, name, defaultParallelism=None,
+            pythonExec='python'):
         self.master = master
         self.name = name
         self._jsc = self.jvm.JavaSparkContext(master, name)
-        self.defaultSerializer = defaultSerializer
         self.defaultParallelism = \
             defaultParallelism or self._jsc.sc().defaultParallelism()
         self.pythonExec = pythonExec
@@ -31,39 +33,26 @@ class SparkContext(object):
         self._jsc.stop()
         self._jsc = None
 
-    def parallelize(self, c, numSlices=None, serializer=None):
-        serializer = serializer or self.defaultSerializer
-        numSlices = numSlices or self.defaultParallelism
-        # Calling the Java parallelize() method with an ArrayList is too slow,
-        # because it sends O(n) Py4J commands. As an alternative, serialized
-        # objects are written to a file and loaded through textFile().
-        tempFile = NamedTemporaryFile(delete=False)
-        tempFile.writelines(serializer.dumps(x) + '\n' for x in c)
-        tempFile.close()
-        atexit.register(lambda: os.unlink(tempFile.name))
-        return self.textFile(tempFile.name, numSlices, serializer)
-
-    def parallelizePairs(self, c, numSlices=None, keySerializer=None,
-            valSerializer=None):
+    def parallelize(self, c, numSlices=None):
         """
         >>> sc = SparkContext("local", "test")
-        >>> rdd = sc.parallelizePairs([(1, 2), (3, 4)])
+        >>> rdd = sc.parallelize([(1, 2), (3, 4)])
         >>> rdd.collect()
         [(1, 2), (3, 4)]
         """
-        keySerializer = keySerializer or self.defaultSerializer
-        valSerializer = valSerializer or self.defaultSerializer
         numSlices = numSlices or self.defaultParallelism
+        # Calling the Java parallelize() method with an ArrayList is too slow,
+        # because it sends O(n) Py4J commands. As an alternative, serialized
+        # objects are written to a file and loaded through textFile().
         tempFile = NamedTemporaryFile(delete=False)
-        for (k, v) in c:
-            tempFile.write(keySerializer.dumps(k).rstrip('\r\n') + '\n')
-            tempFile.write(valSerializer.dumps(v).rstrip('\r\n') + '\n')
+        for x in c:
+            dumps(PickleSerializer.dumps(x), tempFile)
         tempFile.close()
        atexit.register(lambda: os.unlink(tempFile.name))
-        jrdd = self.textFile(tempFile.name, numSlices)._pipePairs([], "echo")
-        return PairRDD(jrdd, self, keySerializer, valSerializer)
+        jrdd = self.pickleFile(self._jsc, tempFile.name, numSlices)
+        return RDD(jrdd, self)
 
-    def textFile(self, name, numSlices=None, serializer=NopSerializer):
+    def textFile(self, name, numSlices=None):
         numSlices = numSlices or self.defaultParallelism
         jrdd = self._jsc.textFile(name, numSlices)
-        return RDD(jrdd, self, serializer)
+        return RDD(jrdd, self)
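
The comment inside the new parallelize() is the heart of this change: pushing n elements through Java's parallelize() costs one Py4J command per element, while writing them to a temp file and loading it from the JVM costs a constant number of calls. Below is a minimal, self-contained sketch of that temp-file handoff, assuming a simple length-prefixed framing; the frame_write()/frame_read() helpers are hypothetical stand-ins for pyspark.serializers.dumps and for whatever record splitting PythonRDD.pickleFile does on the JVM side, neither of which is shown in this diff.

    # Sketch of the pickle-to-temp-file pattern used by the new parallelize().
    import os
    import atexit
    import struct
    from tempfile import NamedTemporaryFile

    try:
        import cPickle as pickle  # Python 2, matching the era of this diff
    except ImportError:
        import pickle


    def frame_write(data, stream):
        # Length-prefix each pickled record (4-byte big-endian length, then
        # payload) so a reader that cannot parse pickle, such as the JVM,
        # can still split the file into records.
        stream.write(struct.pack(">i", len(data)))
        stream.write(data)


    def frame_read(stream):
        # Return one payload, or None at end-of-file.
        header = stream.read(4)
        if len(header) < 4:
            return None
        (length,) = struct.unpack(">i", header)
        return stream.read(length)


    elements = [(1, 2), (3, 4)]

    # Writer side: one local file write per element instead of one Py4J
    # command per element -- the O(n)-commands problem the comment describes.
    tempFile = NamedTemporaryFile(delete=False)
    for x in elements:
        frame_write(pickle.dumps(x, 2), tempFile)
    tempFile.close()
    atexit.register(lambda: os.unlink(tempFile.name))

    # Reader side: what a pickleFile()-style loader would do with the file.
    results = []
    with open(tempFile.name, "rb") as f:
        while True:
            payload = frame_read(f)
            if payload is None:
                break
            results.append(pickle.loads(payload))

    assert results == elements

In the diff itself, dumps(PickleSerializer.dumps(x), tempFile) plays the frame_write() role, and handing the file path to self.pickleFile(self._jsc, ...) replaces the old trick of round-tripping text through textFile() plus an "echo" pipe, which is why parallelizePairs() and the per-RDD serializer parameters could be deleted.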