about summary refs log tree commit diff
path: root/core
diff options
context:
space:
mode:
author Josh Rosen <joshrosen@eecs.berkeley.edu> 2013-01-14 15:30:42 -0800
committer Josh Rosen <joshrosen@eecs.berkeley.edu> 2013-01-20 15:41:42 -0800
commit 9f211dd3f0132daf72fb39883fa4b28e4fd547ca (patch)
tree 270d3bf88a053e858921277d329b5ace6843bac1 /core
parent fe85a075117a79675971aff0cd020bba446c0233 (diff)
download spark-9f211dd3f0132daf72fb39883fa4b28e4fd547ca.tar.gz
spark-9f211dd3f0132daf72fb39883fa4b28e4fd547ca.tar.bz2
spark-9f211dd3f0132daf72fb39883fa4b28e4fd547ca.zip
Fix PythonPartitioner equality; see SPARK-654.
PythonPartitioner did not take the Python-side partitioning function into account when checking for equality, which might cause problems in the future.
Diffstat (limited to 'core')
-rw-r--r-- core/src/main/scala/spark/api/python/PythonPartitioner.scala | 13
-rw-r--r-- core/src/main/scala/spark/api/python/PythonRDD.scala | 5
2 files changed, 11 insertions, 7 deletions
diff --git a/core/src/main/scala/spark/api/python/PythonPartitioner.scala b/core/src/main/scala/spark/api/python/PythonPartitioner.scala
index 648d9402b0..519e310323 100644
--- a/core/src/main/scala/spark/api/python/PythonPartitioner.scala
+++ b/core/src/main/scala/spark/api/python/PythonPartitioner.scala
@@ -6,8 +6,17 @@ import java.util.Arrays
/**
* A [[spark.Partitioner]] that performs handling of byte arrays, for use by the Python API.
+ *
+ * Stores the unique id() of the Python-side partitioning function so that it is incorporated into
+ * equality comparisons. Correctness requires that the id is a unique identifier for the
+ * lifetime of the job (i.e. that it is not re-used as the id of a different partitioning
+ * function). This can be ensured by using the Python id() function and maintaining a reference
+ * to the Python partitioning function so that its id() is not reused.
*/
-private[spark] class PythonPartitioner(override val numPartitions: Int) extends Partitioner {
+private[spark] class PythonPartitioner(
+ override val numPartitions: Int,
+ val pyPartitionFunctionId: Long)
+ extends Partitioner {
override def getPartition(key: Any): Int = {
if (key == null) {
@@ -32,7 +41,7 @@ private[spark] class PythonPartitioner(override val numPartitions: Int) extends
override def equals(other: Any): Boolean = other match {
case h: PythonPartitioner =>
- h.numPartitions == numPartitions
+ h.numPartitions == numPartitions && h.pyPartitionFunctionId == pyPartitionFunctionId
case _ =>
false
}
diff --git a/core/src/main/scala/spark/api/python/PythonRDD.scala b/core/src/main/scala/spark/api/python/PythonRDD.scala
index 89f7c316dc..e4c0530241 100644
--- a/core/src/main/scala/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/spark/api/python/PythonRDD.scala
@@ -252,11 +252,6 @@ private object Pickle {
val APPENDS: Byte = 'e'
}
-private class ExtractValue extends spark.api.java.function.Function[(Array[Byte],
- Array[Byte]), Array[Byte]] {
- override def call(pair: (Array[Byte], Array[Byte])) : Array[Byte] = pair._2
-}
-
private class BytesToString extends spark.api.java.function.Function[Array[Byte], String] {
override def call(arr: Array[Byte]) : String = new String(arr, "UTF-8")
}