Diffstat (limited to 'python')

 python/pyspark/join.py           |  5 +++--
 python/pyspark/rdd.py            | 10 +++++-----
 python/pyspark/resultiterable.py | 33 +++++++++++++++++++++++++++++++++
 3 files changed, 41 insertions(+), 7 deletions(-)
diff --git a/python/pyspark/join.py b/python/pyspark/join.py
index 5f4294fb1b..6f94d26ef8 100644
--- a/python/pyspark/join.py
+++ b/python/pyspark/join.py
@@ -31,11 +31,12 @@ THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 """
+from pyspark.resultiterable import ResultIterable


 def _do_python_join(rdd, other, numPartitions, dispatch):
     vs = rdd.map(lambda (k, v): (k, (1, v)))
     ws = other.map(lambda (k, v): (k, (2, v)))
-    return vs.union(ws).groupByKey(numPartitions).flatMapValues(dispatch)
+    return vs.union(ws).groupByKey(numPartitions).flatMapValues(lambda x: dispatch(x.__iter__()))


 def python_join(rdd, other, numPartitions):
@@ -88,5 +89,5 @@ def python_cogroup(rdd, other, numPartitions):
                 vbuf.append(v)
             elif n == 2:
                 wbuf.append(v)
-        return (vbuf, wbuf)
+        return (ResultIterable(vbuf), ResultIterable(wbuf))
     return vs.union(ws).groupByKey(numPartitions).mapValues(dispatch)
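
With this change, _do_python_join hands each grouped value to dispatch as a plain iterator (x.__iter__()), and python_cogroup wraps its output buffers in ResultIterable. A minimal standalone sketch of the new dispatch contract (the sample data and the simplified dispatch body below are illustrative, not the exact join dispatchers):

    from pyspark.resultiterable import ResultIterable

    def dispatch(seq):
        # Same shape as the join dispatchers: split tagged values into
        # left (1) and right (2) buffers, then pair them up.
        vbuf, wbuf = [], []
        for n, v in seq:
            if n == 1:
                vbuf.append(v)
            elif n == 2:
                wbuf.append(v)
        return [(v, w) for v in vbuf for w in wbuf]

    grouped = ResultIterable([(1, "left"), (2, "right")])
    print(dispatch(grouped.__iter__()))  # [('left', 'right')]
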
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index fb27863e07..91fc7e637e 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -38,6 +38,7 @@ from pyspark.join import python_join, python_left_outer_join, \
 from pyspark.statcounter import StatCounter
 from pyspark.rddsampler import RDDSampler
 from pyspark.storagelevel import StorageLevel
+from pyspark.resultiterable import ResultIterable

 from py4j.java_collections import ListConverter, MapConverter
@@ -1118,7 +1119,7 @@ class RDD(object):
         Hash-partitions the resulting RDD into numPartitions partitions.

         >>> x = sc.parallelize([("a", 1), ("b", 1), ("a", 1)])
-        >>> sorted(x.groupByKey().collect())
+        >>> map((lambda (x,y): (x, list(y))), sorted(x.groupByKey().collect()))
         [('a', [1, 1]), ('b', [1])]
         """
@@ -1133,7 +1134,7 @@ class RDD(object):
             return a + b
         return self.combineByKey(createCombiner, mergeValue, mergeCombiners,
-                                 numPartitions)
+                                 numPartitions).mapValues(lambda x: ResultIterable(x))

     # TODO: add tests
     def flatMapValues(self, f):
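
The .mapValues(lambda x: ResultIterable(x)) tail is the user-visible core of the patch: groupByKey now yields an iterable per key rather than a list. A sketch of the effect, assuming a live SparkContext named sc:

    pairs = sc.parallelize([("a", 1), ("b", 1), ("a", 1)])
    grouped = pairs.groupByKey()
    # Each value is a ResultIterable; materialize explicitly when a list is needed.
    print(sorted(grouped.mapValues(list).collect()))  # [('a', [1, 1]), ('b', [1])]
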
@@ -1180,7 +1181,7 @@ class RDD(object):
         >>> x = sc.parallelize([("a", 1), ("b", 4)])
         >>> y = sc.parallelize([("a", 2)])
-        >>> sorted(x.cogroup(y).collect())
+        >>> map((lambda (x,y): (x, (list(y[0]), list(y[1])))), sorted(list(x.cogroup(y).collect())))
         [('a', ([1], [2])), ('b', ([4], []))]
         """
         return python_cogroup(self, other, numPartitions)
@@ -1217,7 +1218,7 @@ class RDD(object):
         >>> x = sc.parallelize(range(0,3)).keyBy(lambda x: x*x)
         >>> y = sc.parallelize(zip(range(0,5), range(0,5)))
-        >>> sorted(x.cogroup(y).collect())
+        >>> map((lambda (x,y): (x, (list(y[0]), list(y[1])))), sorted(x.cogroup(y).collect()))
         [(0, ([0], [0])), (1, ([1], [1])), (2, ([], [2])), (3, ([], [3])), (4, ([2], [4]))]
         """
         return self.map(lambda x: (f(x), x))
@@ -1317,7 +1318,6 @@ class RDD(object):
         # keys in the pairs. This could be an expensive operation, since those
         # hashes aren't retained.
-

 class PipelinedRDD(RDD):
     """
     Pipelined maps:
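
The same caveat applies to cogroup: both sides of each grouped pair come back as ResultIterable, so code that previously indexed into lists must materialize them first. A sketch, again assuming a live SparkContext named sc:

    x = sc.parallelize([("a", 1), ("b", 4)])
    y = sc.parallelize([("a", 2)])
    for key, (xs, ys) in sorted(x.cogroup(y).collect()):
        # xs and ys are ResultIterable wrappers around the buffered values.
        print("%s -> (%s, %s)" % (key, list(xs), list(ys)))
    # a -> ([1], [2])
    # b -> ([4], [])
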
diff --git a/python/pyspark/resultiterable.py b/python/pyspark/resultiterable.py
new file mode 100644
index 0000000000..7f418f8d2e
--- /dev/null
+++ b/python/pyspark/resultiterable.py
@@ -0,0 +1,33 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+__all__ = ["ResultIterable"]
+
+import collections
+
+class ResultIterable(collections.Iterable):
+    """
+    A special result iterable. This is used because the standard iterator cannot be pickled.
+    """
+    def __init__(self, data):
+        self.data = data
+        self.index = 0
+        self.maxindex = len(data)
+    def __iter__(self):
+        return iter(self.data)
+    def __len__(self):
+        return len(self.data)
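
The point of the wrapper is picklability: workers must serialize grouped results, and a bare iterator cannot be pickled in Python 2, whereas ResultIterable pickles through its underlying list. (Note the index/maxindex attributes are set but unused by the methods shown here.) A small sketch:

    import pickle
    from pyspark.resultiterable import ResultIterable

    r = ResultIterable([1, 2, 3])
    r2 = pickle.loads(pickle.dumps(r))  # round-trips via the plain list in self.data
    assert list(r2) == [1, 2, 3] and len(r2) == 3

    # By contrast, pickle.dumps(iter([1, 2, 3])) raises TypeError on Python 2.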