author     Josh Rosen <joshrosen@eecs.berkeley.edu>  2013-01-14 15:30:42 -0800
committer  Josh Rosen <joshrosen@eecs.berkeley.edu>  2013-01-20 15:41:42 -0800
commit     9f211dd3f0132daf72fb39883fa4b28e4fd547ca
tree       270d3bf88a053e858921277d329b5ace6843bac1
parent     fe85a075117a79675971aff0cd020bba446c0233
Fix PythonPartitioner equality; see SPARK-654.
PythonPartitioner did not take the Python-side partitioning function into account when checking for equality, so two partitioners built from different Python functions (but the same number of partitions) compared equal; this could cause problems in the future.
-rw-r--r--  core/src/main/scala/spark/api/python/PythonPartitioner.scala | 13
-rw-r--r--  core/src/main/scala/spark/api/python/PythonRDD.scala | 5
-rw-r--r--  python/pyspark/rdd.py | 17
3 files changed, 22 insertions, 13 deletions
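
Spark uses partitioner equality to decide whether two RDDs are already co-partitioned (and hence whether a shuffle can be skipped), which is why the Python-side function has to participate in the comparison. A minimal sketch of the scenario the commit message alludes to, assuming a running SparkContext sc and the partitionBy() signature introduced in this patch (the example itself is not part of the commit):

    pairs = sc.parallelize([("a", 1), ("b", 2), ("c", 3)])

    by_hash  = pairs.partitionBy(2)               # default partitionFunc=hash
    by_const = pairs.partitionBy(2, lambda k: 0)  # every key goes to split 0

    # Both underlying PythonPartitioners have numPartitions == 2. Before this
    # fix they compared equal, so the two RDDs could be treated as
    # co-partitioned even though their keys live in different splits.
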
diff --git a/core/src/main/scala/spark/api/python/PythonPartitioner.scala b/core/src/main/scala/spark/api/python/PythonPartitioner.scala
index 648d9402b0..519e310323 100644
--- a/core/src/main/scala/spark/api/python/PythonPartitioner.scala
+++ b/core/src/main/scala/spark/api/python/PythonPartitioner.scala
@@ -6,8 +6,17 @@ import java.util.Arrays
/**
* A [[spark.Partitioner]] that performs handling of byte arrays, for use by the Python API.
+ *
+ * Stores the unique id() of the Python-side partitioning function so that it is incorporated into
+ * equality comparisons. Correctness requires that the id is a unique identifier for the
+ * lifetime of the job (i.e. that it is not re-used as the id of a different partitioning
+ * function). This can be ensured by using the Python id() function and maintaining a reference
+ * to the Python partitioning function so that its id() is not reused.
*/
-private[spark] class PythonPartitioner(override val numPartitions: Int) extends Partitioner {
+private[spark] class PythonPartitioner(
+ override val numPartitions: Int,
+ val pyPartitionFunctionId: Long)
+ extends Partitioner {
override def getPartition(key: Any): Int = {
if (key == null) {
@@ -32,7 +41,7 @@ private[spark] class PythonPartitioner(override val numPartitions: Int) extends
override def equals(other: Any): Boolean = other match {
case h: PythonPartitioner =>
- h.numPartitions == numPartitions
+ h.numPartitions == numPartitions && h.pyPartitionFunctionId == pyPartitionFunctionId
case _ =>
false
}
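
A sketch of the new equality semantics, constructing the JVM-side partitioner through the py4j gateway the same way partitionBy() now does; ctx stands for an assumed SparkContext and the lambdas are placeholders:

    f = lambda k: hash(k)
    g = lambda k: 0

    p1 = ctx.jvm.PythonPartitioner(4, id(f))
    p2 = ctx.jvm.PythonPartitioner(4, id(g))
    p3 = ctx.jvm.PythonPartitioner(4, id(f))

    p1.equals(p2)  # False: same numPartitions, different Python partition functions
    p1.equals(p3)  # True: both numPartitions and pyPartitionFunctionId match
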
diff --git a/core/src/main/scala/spark/api/python/PythonRDD.scala b/core/src/main/scala/spark/api/python/PythonRDD.scala
index 89f7c316dc..e4c0530241 100644
--- a/core/src/main/scala/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/spark/api/python/PythonRDD.scala
@@ -252,11 +252,6 @@ private object Pickle {
val APPENDS: Byte = 'e'
}
-private class ExtractValue extends spark.api.java.function.Function[(Array[Byte],
- Array[Byte]), Array[Byte]] {
- override def call(pair: (Array[Byte], Array[Byte])) : Array[Byte] = pair._2
-}
-
private class BytesToString extends spark.api.java.function.Function[Array[Byte], String] {
override def call(arr: Array[Byte]) : String = new String(arr, "UTF-8")
}
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index d705f0f9e1..b58bf24e3e 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -33,6 +33,7 @@ class RDD(object):
self._jrdd = jrdd
self.is_cached = False
self.ctx = ctx
+ self._partitionFunc = None
@property
def context(self):
@@ -497,7 +498,7 @@ class RDD(object):
return python_right_outer_join(self, other, numSplits)
# TODO: add option to control map-side combining
- def partitionBy(self, numSplits, hashFunc=hash):
+ def partitionBy(self, numSplits, partitionFunc=hash):
"""
Return a copy of the RDD partitioned using the specified partitioner.
@@ -514,17 +515,21 @@ class RDD(object):
def add_shuffle_key(split, iterator):
buckets = defaultdict(list)
for (k, v) in iterator:
- buckets[hashFunc(k) % numSplits].append((k, v))
+ buckets[partitionFunc(k) % numSplits].append((k, v))
for (split, items) in buckets.iteritems():
yield str(split)
yield dump_pickle(Batch(items))
keyed = PipelinedRDD(self, add_shuffle_key)
keyed._bypass_serializer = True
pairRDD = self.ctx.jvm.PairwiseRDD(keyed._jrdd.rdd()).asJavaPairRDD()
- partitioner = self.ctx.jvm.spark.api.python.PythonPartitioner(numSplits)
- jrdd = pairRDD.partitionBy(partitioner)
- jrdd = jrdd.map(self.ctx.jvm.ExtractValue())
- return RDD(jrdd, self.ctx)
+ partitioner = self.ctx.jvm.PythonPartitioner(numSplits,
+ id(partitionFunc))
+ jrdd = pairRDD.partitionBy(partitioner).values()
+ rdd = RDD(jrdd, self.ctx)
+ # This is required so that id(partitionFunc) remains unique, even if
+ # partitionFunc is a lambda:
+ rdd._partitionFunc = partitionFunc
+ return rdd
# TODO: add control over map-side aggregation
def combineByKey(self, createCombiner, mergeValue, mergeCombiners,
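
The new _partitionFunc attribute exists only to keep the Python partition function alive: CPython id() values are unique only among live objects and can be recycled once an object is garbage collected. A standalone illustration of the reuse hazard (plain CPython, no Spark required):

    f = lambda k: hash(k)
    old_id = id(f)
    del f                    # the lambda may now be garbage collected

    g = lambda k: 0          # a new object can land at the same address, so...
    print(id(g) == old_id)   # ...this may print True; old_id would then wrongly
                             # identify a completely different partition function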