author     Josh Rosen <joshrosen@eecs.berkeley.edu>  2012-08-25 18:00:25 -0700
committer  Josh Rosen <joshrosen@eecs.berkeley.edu>  2012-08-27 00:24:47 -0700
commit     bff6a46359131a8f9bc38b93149b22baa7c711cd (patch)
tree       b40947ec8714fdf60a16b8d47e0858bdcaa23d96 /pyspark
parent     200d248dcc5903295296bf897211cf543b37f8c1 (diff)
Add pipe(), saveAsTextFile(), sc.union() to Python API.
Diffstat (limited to 'pyspark')
-rw-r--r--  pyspark/pyspark/context.py  14
-rw-r--r--  pyspark/pyspark/rdd.py      25
2 files changed, 31 insertions, 8 deletions
diff --git a/pyspark/pyspark/context.py b/pyspark/pyspark/context.py
index b8490019e3..04932c93f2 100644
--- a/pyspark/pyspark/context.py
+++ b/pyspark/pyspark/context.py
@@ -7,6 +7,8 @@ from pyspark.java_gateway import launch_gateway
from pyspark.serializers import dump_pickle, write_with_length
from pyspark.rdd import RDD
+from py4j.java_collections import ListConverter
+
class SparkContext(object):
@@ -39,12 +41,6 @@ class SparkContext(object):
self._jsc = None
def parallelize(self, c, numSlices=None):
- """
- >>> sc = SparkContext("local", "test")
- >>> rdd = sc.parallelize([(1, 2), (3, 4)])
- >>> rdd.collect()
- [(1, 2), (3, 4)]
- """
numSlices = numSlices or self.defaultParallelism
# Calling the Java parallelize() method with an ArrayList is too slow,
# because it sends O(n) Py4J commands. As an alternative, serialized
@@ -62,6 +58,12 @@ class SparkContext(object):
jrdd = self._jsc.textFile(name, minSplits)
return RDD(jrdd, self)
+ def union(self, rdds):
+ first = rdds[0]._jrdd
+ rest = [x._jrdd for x in rdds[1:]]
+ rest = ListConverter().convert(rest, self.gateway._gateway_client)
+ return RDD(self._jsc.union(first, rest), self)
+
def broadcast(self, value):
jbroadcast = self._jsc.broadcast(bytearray(dump_pickle(value)))
return Broadcast(jbroadcast.uuid().toString(), value, jbroadcast,
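
Not part of the patch: a minimal usage sketch of the new sc.union() helper, assuming a local SparkContext and hypothetical example data (Python 2 style, matching the codebase above, and assuming SparkContext is importable from pyspark.context).

    from pyspark.context import SparkContext

    sc = SparkContext("local", "union-example")   # hypothetical app name
    rdd1 = sc.parallelize([1, 2])
    rdd2 = sc.parallelize([3, 4])
    # union() takes a list of RDDs; the first JavaRDD plus a converted list of
    # the rest are handed to the JVM-side union() call shown in the diff.
    print sc.union([rdd1, rdd2]).collect()        # expected: [1, 2, 3, 4]
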
diff --git a/pyspark/pyspark/rdd.py b/pyspark/pyspark/rdd.py
index 21e822ba9f..8477f6dd02 100644
--- a/pyspark/pyspark/rdd.py
+++ b/pyspark/pyspark/rdd.py
@@ -1,6 +1,9 @@
from base64 import standard_b64encode as b64enc
from collections import Counter
from itertools import chain, ifilter, imap
+import shlex
+from subprocess import Popen, PIPE
+from threading import Thread
from pyspark import cloudpickle
from pyspark.serializers import dump_pickle, load_pickle
@@ -118,7 +121,20 @@ class RDD(object):
"""
return self.map(lambda x: (f(x), x)).groupByKey(numSplits)
- # TODO: pipe
+ def pipe(self, command, env={}):
+ """
+ >>> sc.parallelize([1, 2, 3]).pipe('cat').collect()
+ ['1', '2', '3']
+ """
+ def func(iterator):
+ pipe = Popen(shlex.split(command), env=env, stdin=PIPE, stdout=PIPE)
+ def pipe_objs(out):
+ for obj in iterator:
+ out.write(str(obj).rstrip('\n') + '\n')
+ out.close()
+ Thread(target=pipe_objs, args=[pipe.stdin]).start()
+ return (x.rstrip('\n') for x in pipe.stdout)
+ return self.mapPartitions(func)
def foreach(self, f):
"""
@@ -206,7 +222,12 @@ class RDD(object):
"""
return load_pickle(bytes(self.ctx.asPickle(self._jrdd.first())))
- # TODO: saveAsTextFile
+ def saveAsTextFile(self, path):
+ def func(iterator):
+ return (str(x).encode("utf-8") for x in iterator)
+ keyed = PipelinedRDD(self, func)
+ keyed._bypass_serializer = True
+ keyed._jrdd.map(self.ctx.jvm.BytesToString()).saveAsTextFile(path)
# Pair functions
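
Not part of the patch: a similar sketch for the new pipe() and saveAsTextFile() methods, assuming a local SparkContext; the command, app name, and output path are hypothetical.

    from pyspark.context import SparkContext

    sc = SparkContext("local", "pipe-example")    # hypothetical app name
    rdd = sc.parallelize([1, 2, 3])
    # pipe() streams each element through an external command, one line per
    # element, and yields the command's stdout lines as stripped strings.
    print rdd.pipe('cat').collect()               # expected: ['1', '2', '3']
    # saveAsTextFile() writes each partition as UTF-8 text under the given path.
    rdd.saveAsTextFile('/tmp/pipe-example-output')  # hypothetical output path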