 core/src/main/scala/spark/api/python/PythonPartitioner.scala | 33 ++++++++++++++++
 core/src/main/scala/spark/api/python/PythonRDD.scala         | 10 ++++------
 pyspark/pyspark/rdd.py                                       | 12 +++++++++---
 3 files changed, 46 insertions(+), 9 deletions(-)
diff --git a/core/src/main/scala/spark/api/python/PythonPartitioner.scala b/core/src/main/scala/spark/api/python/PythonPartitioner.scala
new file mode 100644
index 0000000000..ef9f808fb2
--- /dev/null
+++ b/core/src/main/scala/spark/api/python/PythonPartitioner.scala
@@ -0,0 +1,33 @@
+package spark.api.python
+
+import spark.Partitioner
+
+import java.util.Arrays
+
+/**
+ * A [[spark.Partitioner]] that hashes byte-array keys by their contents, for use by the Python API.
+ */
+class PythonPartitioner(override val numPartitions: Int) extends Partitioner {
+
+  override def getPartition(key: Any): Int = {
+    if (key == null) {
+      0
+    } else {
+      val hashCode = key match {
+        // Hash byte arrays by their contents so equal keys always agree
+        case bytes: Array[Byte] => Arrays.hashCode(bytes)
+        case _ => key.hashCode()
+      }
+      // Guard against negative hash codes: Java's % keeps the dividend's sign
+      val mod = hashCode % numPartitions
+      if (mod < 0) mod + numPartitions else mod
+    }
+  }
+
+ override def equals(other: Any): Boolean = other match {
+ case h: PythonPartitioner =>
+ h.numPartitions == numPartitions
+ case _ =>
+ false
+ }
+}
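Why this class exists: keys arriving from Python are pickled byte arrays, and on the JVM Array[Byte].hashCode() hashes the array reference rather than its contents, so two equal keys would land in arbitrary partitions; java.util.Arrays.hashCode hashes the contents. A minimal Python sketch of the same partitioning rule (get_partition is a hypothetical helper, not part of the PySpark API):

    def get_partition(key, num_partitions):
        """Mirror of PythonPartitioner.getPartition for pickled (bytes) keys."""
        if key is None:
            return 0
        # Python's hash() on a byte string is already content-based, and
        # Python's % is already non-negative for a positive modulus.  The
        # Scala code needs Arrays.hashCode and an explicit guard because
        # Java's % keeps the dividend's sign: -7 % 4 == -3 on the JVM,
        # which the guard maps to -3 + 4 == 1.
        return hash(key) % num_partitions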
diff --git a/core/src/main/scala/spark/api/python/PythonRDD.scala b/core/src/main/scala/spark/api/python/PythonRDD.scala
index a593e53efd..50094d6b0f 100644
--- a/core/src/main/scala/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/spark/api/python/PythonRDD.scala
@@ -179,14 +179,12 @@ object PythonRDD {
val dOut = new DataOutputStream(baos);
if (elem.isInstanceOf[Array[Byte]]) {
elem.asInstanceOf[Array[Byte]]
- } else if (elem.isInstanceOf[scala.Tuple2[_, _]]) {
- val t = elem.asInstanceOf[scala.Tuple2[_, _]]
- val t1 = t._1.asInstanceOf[Array[Byte]]
- val t2 = t._2.asInstanceOf[Array[Byte]]
+ } else if (elem.isInstanceOf[scala.Tuple2[Array[Byte], Array[Byte]]]) {
+ val t = elem.asInstanceOf[scala.Tuple2[Array[Byte], Array[Byte]]]
dOut.writeByte(Pickle.PROTO)
dOut.writeByte(Pickle.TWO)
- dOut.write(PythonRDD.stripPickle(t1))
- dOut.write(PythonRDD.stripPickle(t2))
+ dOut.write(PythonRDD.stripPickle(t._1))
+ dOut.write(PythonRDD.stripPickle(t._2))
dOut.writeByte(Pickle.TUPLE2)
dOut.writeByte(Pickle.STOP)
baos.toByteArray()
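For reference, the branch above rebuilds by hand what pickle itself would emit for a 2-tuple: a PROTO 2 header, the opcodes of the two members (assuming stripPickle trims each member pickle's own 2-byte header and trailing STOP, as its use here suggests), then TUPLE2 and STOP, so a single pickle.loads() on the Python side yields the pair. A quick way to see the target layout:

    import pickle, pickletools

    # Disassemble a protocol-2 pickle of a pair: the stream starts with
    # PROTO 2, contains the two members' opcodes, and ends with TUPLE2, STOP.
    pickletools.dis(pickle.dumps(('key', 'value'), 2))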
diff --git a/pyspark/pyspark/rdd.py b/pyspark/pyspark/rdd.py
index e4878c08ba..85a24c6854 100644
--- a/pyspark/pyspark/rdd.py
+++ b/pyspark/pyspark/rdd.py
@@ -310,6 +310,12 @@ class RDD(object):
return python_right_outer_join(self, other, numSplits)

def partitionBy(self, numSplits, hashFunc=hash):
+ """
+ >>> pairs = sc.parallelize([1, 2, 3, 4, 2, 4, 1]).map(lambda x: (x, x))
+ >>> sets = pairs.partitionBy(2).glom().collect()
+ >>> set(sets[0]).intersection(set(sets[1]))
+ set([])
+ """
if numSplits is None:
numSplits = self.ctx.defaultParallelism
def add_shuffle_key(iterator):
@@ -319,7 +325,7 @@ class RDD(object):
keyed = PipelinedRDD(self, add_shuffle_key)
keyed._bypass_serializer = True
pairRDD = self.ctx.jvm.PairwiseRDD(keyed._jrdd.rdd()).asJavaPairRDD()
- partitioner = self.ctx.jvm.spark.HashPartitioner(numSplits)
+ partitioner = self.ctx.jvm.spark.api.python.PythonPartitioner(numSplits)
jrdd = pairRDD.partitionBy(partitioner)
jrdd = jrdd.map(self.ctx.jvm.ExtractValue())
return RDD(jrdd, self.ctx)
@@ -391,7 +397,7 @@ class RDD(object):
"""
>>> x = sc.parallelize([("a", 1), ("b", 4)])
>>> y = sc.parallelize([("a", 2)])
- >>> x.cogroup(y).collect()
+ >>> sorted(x.cogroup(y).collect())
[('a', ([1], [2])), ('b', ([4], []))]
"""
return python_cogroup(self, other, numSplits)
@@ -462,7 +468,7 @@ def _test():
import doctest
from pyspark.context import SparkContext
globs = globals().copy()
- globs['sc'] = SparkContext('local', 'PythonTest')
+ globs['sc'] = SparkContext('local[4]', 'PythonTest')
doctest.testmod(globs=globs)
globs['sc'].stop()
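A note on the last two changes: running the doctests on 'local[4]' spreads each RDD over four worker threads, so collect() no longer returns shuffled results in a fixed order. Sorting cogroup's output keeps its doctest deterministic, and the added partitionBy doctest only checks that the two partitions are disjoint rather than relying on any ordering.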