-rw-r--r--  core/src/main/scala/spark/api/python/PythonPartitioner.scala | 41
-rw-r--r--  core/src/main/scala/spark/api/python/PythonRDD.scala         | 10
-rw-r--r--  pyspark/pyspark/rdd.py                                       | 12
3 files changed, 54 insertions(+), 9 deletions(-)
diff --git a/core/src/main/scala/spark/api/python/PythonPartitioner.scala b/core/src/main/scala/spark/api/python/PythonPartitioner.scala
new file mode 100644
index 0000000000..ef9f808fb2
--- /dev/null
+++ b/core/src/main/scala/spark/api/python/PythonPartitioner.scala
@@ -0,0 +1,41 @@
+package spark.api.python
+
+import spark.Partitioner
+
+import java.util.Arrays
+
+/**
+ * A [[spark.Partitioner]] that performs handling of byte arrays, for use by the Python API.
+ */
+class PythonPartitioner(override val numPartitions: Int) extends Partitioner {
+
+  override def getPartition(key: Any): Int = {
+    if (key == null) {
+      return 0
+    }
+    else {
+      val hashCode = {
+        if (key.isInstanceOf[Array[Byte]]) {
+          System.err.println("Dumping a byte array!" +
+            Arrays.hashCode(key.asInstanceOf[Array[Byte]]))
+          Arrays.hashCode(key.asInstanceOf[Array[Byte]])
+        }
+        else
+          key.hashCode()
+      }
+      val mod = hashCode % numPartitions
+      if (mod < 0) {
+        mod + numPartitions
+      } else {
+        mod // Guard against negative hash codes
+      }
+    }
+  }
+
+  override def equals(other: Any): Boolean = other match {
+    case h: PythonPartitioner =>
+      h.numPartitions == numPartitions
+    case _ =>
+      false
+  }
+}
diff --git a/core/src/main/scala/spark/api/python/PythonRDD.scala b/core/src/main/scala/spark/api/python/PythonRDD.scala
index a593e53efd..50094d6b0f 100644
--- a/core/src/main/scala/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/spark/api/python/PythonRDD.scala
@@ -179,14 +179,12 @@ object PythonRDD {
     val dOut = new DataOutputStream(baos);
     if (elem.isInstanceOf[Array[Byte]]) {
       elem.asInstanceOf[Array[Byte]]
-    } else if (elem.isInstanceOf[scala.Tuple2[_, _]]) {
-      val t = elem.asInstanceOf[scala.Tuple2[_, _]]
-      val t1 = t._1.asInstanceOf[Array[Byte]]
-      val t2 = t._2.asInstanceOf[Array[Byte]]
+    } else if (elem.isInstanceOf[scala.Tuple2[Array[Byte], Array[Byte]]]) {
+      val t = elem.asInstanceOf[scala.Tuple2[Array[Byte], Array[Byte]]]
       dOut.writeByte(Pickle.PROTO)
       dOut.writeByte(Pickle.TWO)
-      dOut.write(PythonRDD.stripPickle(t1))
-      dOut.write(PythonRDD.stripPickle(t2))
+      dOut.write(PythonRDD.stripPickle(t._1))
+      dOut.write(PythonRDD.stripPickle(t._2))
       dOut.writeByte(Pickle.TUPLE2)
       dOut.writeByte(Pickle.STOP)
       baos.toByteArray()
diff --git a/pyspark/pyspark/rdd.py b/pyspark/pyspark/rdd.py
index e4878c08ba..85a24c6854 100644
--- a/pyspark/pyspark/rdd.py
+++ b/pyspark/pyspark/rdd.py
@@ -310,6 +310,12 @@ class RDD(object):
         return python_right_outer_join(self, other, numSplits)
 
     def partitionBy(self, numSplits, hashFunc=hash):
+        """
+        >>> pairs = sc.parallelize([1, 2, 3, 4, 2, 4, 1]).map(lambda x: (x, x))
+        >>> sets = pairs.partitionBy(2).glom().collect()
+        >>> set(sets[0]).intersection(set(sets[1]))
+        set([])
+        """
         if numSplits is None:
             numSplits = self.ctx.defaultParallelism
         def add_shuffle_key(iterator):
@@ -319,7 +325,7 @@ class RDD(object):
         keyed = PipelinedRDD(self, add_shuffle_key)
         keyed._bypass_serializer = True
         pairRDD = self.ctx.jvm.PairwiseRDD(keyed._jrdd.rdd()).asJavaPairRDD()
-        partitioner = self.ctx.jvm.spark.HashPartitioner(numSplits)
+        partitioner = self.ctx.jvm.spark.api.python.PythonPartitioner(numSplits)
         jrdd = pairRDD.partitionBy(partitioner)
         jrdd = jrdd.map(self.ctx.jvm.ExtractValue())
         return RDD(jrdd, self.ctx)
@@ -391,7 +397,7 @@ class RDD(object):
         """
         >>> x = sc.parallelize([("a", 1), ("b", 4)])
         >>> y = sc.parallelize([("a", 2)])
-        >>> x.cogroup(y).collect()
+        >>> sorted(x.cogroup(y).collect())
         [('a', ([1], [2])), ('b', ([4], []))]
         """
         return python_cogroup(self, other, numSplits)
@@ -462,7 +468,7 @@
 def _test():
     import doctest
     from pyspark.context import SparkContext
     globs = globals().copy()
-    globs['sc'] = SparkContext('local', 'PythonTest')
+    globs['sc'] = SparkContext('local[4]', 'PythonTest')
     doctest.testmod(globs=globs)
     globs['sc'].stop()
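
Why a dedicated partitioner: keys arriving from Python are pickled Array[Byte] values, and on the JVM Array[Byte].hashCode is identity-based, so the stock spark.HashPartitioner would scatter equal keys across partitions. The following standalone sketch shows the same hashing and non-negative-mod logic outside Spark; the names here are illustrative, not part of the patch.

import java.util.Arrays

// Minimal sketch of PythonPartitioner's getPartition logic.
object PartitionSketch {
  def partitionFor(key: Any, numPartitions: Int): Int = {
    if (key == null) return 0
    val hash = key match {
      case bytes: Array[Byte] => Arrays.hashCode(bytes) // content-based hash
      case other              => other.hashCode()       // JVM default
    }
    // Force the modulus into [0, numPartitions): Scala's % can be negative.
    val mod = hash % numPartitions
    if (mod < 0) mod + numPartitions else mod
  }

  def main(args: Array[String]): Unit = {
    val a = Array[Byte](1, 2, 3)
    val b = Array[Byte](1, 2, 3)
    // Identity hashing: equal contents, (almost certainly) different codes...
    println(a.hashCode == b.hashCode)                 // false
    // ...while content hashing sends both keys to the same partition.
    println(partitionFor(a, 8) == partitionFor(b, 8)) // true
  }
}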
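The PythonRDD.writeAsPickle hunk splices two already-pickled byte strings into a single pickled 2-tuple without a round trip through Python. The sketch below shows that framing under two assumptions that are my reading of the hunk, not shown in it: the Pickle.* constants hold the standard protocol-2 opcode bytes, and stripPickle drops each element's two-byte PROTO header and trailing STOP byte.

import java.io.{ByteArrayOutputStream, DataOutputStream}

// Sketch of the tuple framing in PythonRDD.writeAsPickle (assumed opcodes).
object PickleSketch {
  val PROTO: Byte  = 0x80.toByte // pickle PROTO opcode
  val TWO: Byte    = 0x02        // protocol version 2
  val TUPLE2: Byte = 0x86.toByte // pop two items, push a 2-tuple
  val STOP: Byte   = '.'.toByte  // end of pickle stream

  // Assumed behavior of PythonRDD.stripPickle: drop the 2-byte
  // PROTO header and the trailing STOP byte.
  def stripPickle(pickled: Array[Byte]): Array[Byte] =
    pickled.slice(2, pickled.length - 1)

  // Splice two independently pickled values into one pickled 2-tuple.
  def pickledPair(k: Array[Byte], v: Array[Byte]): Array[Byte] = {
    val baos = new ByteArrayOutputStream()
    val dOut = new DataOutputStream(baos)
    dOut.writeByte(PROTO)
    dOut.writeByte(TWO)
    dOut.write(stripPickle(k))
    dOut.write(stripPickle(v))
    dOut.writeByte(TUPLE2)
    dOut.writeByte(STOP)
    baos.toByteArray()
  }
}

If the assumptions hold, the output should match what CPython's pickle.dumps((k, v), 2) produces for the original values, which is what lets the Python worker unpickle the pairs the JVM builds.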
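Notes on the rdd.py side: the new partitionBy doctest asserts that the two halves produced by partitionBy(2) share no keys; cogroup's expected output is now wrapped in sorted() because collect() returns partitions in no guaranteed order; and the doctest SparkContext moves from 'local' to 'local[4]', presumably so the tests exercise more than one worker thread.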