author     Josh Rosen <joshrosen@eecs.berkeley.edu>   2012-08-25 18:00:25 -0700
committer  Josh Rosen <joshrosen@eecs.berkeley.edu>   2012-08-27 00:24:47 -0700
commit     bff6a46359131a8f9bc38b93149b22baa7c711cd (patch)
tree       b40947ec8714fdf60a16b8d47e0858bdcaa23d96 /pyspark
parent     200d248dcc5903295296bf897211cf543b37f8c1 (diff)
Add pipe(), saveAsTextFile(), sc.union() to Python API.
Diffstat (limited to 'pyspark')
-rw-r--r--  pyspark/pyspark/context.py | 14
-rw-r--r--  pyspark/pyspark/rdd.py     | 25
2 files changed, 31 insertions, 8 deletions
diff --git a/pyspark/pyspark/context.py b/pyspark/pyspark/context.py
index b8490019e3..04932c93f2 100644
--- a/pyspark/pyspark/context.py
+++ b/pyspark/pyspark/context.py
@@ -7,6 +7,8 @@
 from pyspark.java_gateway import launch_gateway
 from pyspark.serializers import dump_pickle, write_with_length
 from pyspark.rdd import RDD

+from py4j.java_collections import ListConverter
+

 class SparkContext(object):
@@ -39,12 +41,6 @@ class SparkContext(object):
         self._jsc = None

     def parallelize(self, c, numSlices=None):
-        """
-        >>> sc = SparkContext("local", "test")
-        >>> rdd = sc.parallelize([(1, 2), (3, 4)])
-        >>> rdd.collect()
-        [(1, 2), (3, 4)]
-        """
         numSlices = numSlices or self.defaultParallelism
         # Calling the Java parallelize() method with an ArrayList is too slow,
         # because it sends O(n) Py4J commands.  As an alternative, serialized
@@ -62,6 +58,12 @@ class SparkContext(object):
         jrdd = self._jsc.textFile(name, minSplits)
         return RDD(jrdd, self)

+    def union(self, rdds):
+        first = rdds[0]._jrdd
+        rest = [x._jrdd for x in rdds[1:]]
+        rest = ListConverter().convert(rest, self.gateway._gateway_client)
+        return RDD(self._jsc.union(first, rest), self)
+
     def broadcast(self, value):
         jbroadcast = self._jsc.broadcast(bytearray(dump_pickle(value)))
         return Broadcast(jbroadcast.uuid().toString(), value, jbroadcast,
diff --git a/pyspark/pyspark/rdd.py b/pyspark/pyspark/rdd.py
index 21e822ba9f..8477f6dd02 100644
--- a/pyspark/pyspark/rdd.py
+++ b/pyspark/pyspark/rdd.py
@@ -1,6 +1,9 @@
 from base64 import standard_b64encode as b64enc
 from collections import Counter
 from itertools import chain, ifilter, imap
+import shlex
+from subprocess import Popen, PIPE
+from threading import Thread

 from pyspark import cloudpickle
 from pyspark.serializers import dump_pickle, load_pickle
@@ -118,7 +121,20 @@ class RDD(object):
         """
         return self.map(lambda x: (f(x), x)).groupByKey(numSplits)

-    # TODO: pipe
+    def pipe(self, command, env={}):
+        """
+        >>> sc.parallelize([1, 2, 3]).pipe('cat').collect()
+        ['1', '2', '3']
+        """
+        def func(iterator):
+            pipe = Popen(shlex.split(command), env=env, stdin=PIPE, stdout=PIPE)
+            def pipe_objs(out):
+                for obj in iterator:
+                    out.write(str(obj).rstrip('\n') + '\n')
+                out.close()
+            Thread(target=pipe_objs, args=[pipe.stdin]).start()
+            return (x.rstrip('\n') for x in pipe.stdout)
+        return self.mapPartitions(func)

     def foreach(self, f):
         """
@@ -206,7 +222,12 @@ class RDD(object):
         """
         return load_pickle(bytes(self.ctx.asPickle(self._jrdd.first())))

-    # TODO: saveAsTextFile
+    def saveAsTextFile(self, path):
+        def func(iterator):
+            return (str(x).encode("utf-8") for x in iterator)
+        keyed = PipelinedRDD(self, func)
+        keyed._bypass_serializer = True
+        keyed._jrdd.map(self.ctx.jvm.BytesToString()).saveAsTextFile(path)

     # Pair functions
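
A minimal usage sketch of the three new methods (the app name "demo" and the
output path "/tmp/pyspark-out" are illustrative, not part of the commit; the
code is Python 2, matching the itertools.ifilter/imap imports above):

from pyspark.context import SparkContext

sc = SparkContext("local", "demo")
a = sc.parallelize([1, 2])
b = sc.parallelize([3, 4])

# sc.union() merges several Python RDDs through the Java-side union() call.
print sc.union([a, b]).collect()                       # [1, 2, 3, 4]

# pipe() streams each element through an external command, one line each.
print sc.parallelize([1, 2, 3]).pipe('cat').collect()  # ['1', '2', '3']

# saveAsTextFile() writes one UTF-8 encoded line per element to the path.
sc.union([a, b]).saveAsTextFile('/tmp/pyspark-out')

Note the ListConverter step in union(): Py4J does not automatically convert a
Python list of JavaRDDs, so the tail of the list is converted to a
java.util.List before calling the Java-side union(first, rest) method.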
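
The pipe() implementation writes the partition to the child's stdin from a
background thread while the returned generator lazily reads its stdout; doing
both sequentially on one thread could deadlock once either OS pipe buffer
fills. A self-contained, Spark-free sketch of the same writer-thread pattern
(pipe_through() is a hypothetical helper, not part of the commit):

import shlex
from subprocess import Popen, PIPE
from threading import Thread

def pipe_through(iterator, command):
    proc = Popen(shlex.split(command), stdin=PIPE, stdout=PIPE)
    def feed(out):
        # Background writer: one line per element, then EOF so the
        # child process can finish and close its stdout.
        for obj in iterator:
            out.write(str(obj).rstrip('\n') + '\n')
        out.close()
    Thread(target=feed, args=[proc.stdin]).start()
    # Lazily yield the child's output lines, newline-stripped.
    return (line.rstrip('\n') for line in proc.stdout)

print list(pipe_through(iter([1, 2, 3]), 'cat'))  # prints ['1', '2', '3']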