author     Josh Rosen <joshrosen@eecs.berkeley.edu>  2012-10-28 22:30:28 -0700
committer  Josh Rosen <joshrosen@eecs.berkeley.edu>  2012-10-28 22:30:28 -0700
commit     2ccf3b665280bf5b0919e3801d028126cb070dbd (patch)
tree       c486342f7d31dd1f529444ca40339e2a879ac219 /core
parent     7859879aaa1860ff6b383e32a18fd9a410a97416 (diff)
Fix PySpark hash partitioning bug.
A Java array's hashCode is based on its object identity, not its elements, so serialized keys were being hashed incorrectly. This commit adds a PySpark-specific workaround and more tests.
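To see the problem concretely, here is a minimal Scala sketch (the values are illustrative; identity hash codes are not formally guaranteed to differ, but in practice they do):

import java.util.Arrays

val a = Array[Byte](1, 2, 3)
val b = Array[Byte](1, 2, 3)

// An array's hashCode comes from Object, so arrays with equal contents
// still hash by identity and would land in different partitions:
println(a.hashCode == b.hashCode)                  // false in practice

// java.util.Arrays.hashCode is content-based, so equal keys hash equally:
println(Arrays.hashCode(a) == Arrays.hashCode(b))  // true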
Diffstat (limited to 'core')
-rw-r--r--  core/src/main/scala/spark/api/python/PythonPartitioner.scala  40
-rw-r--r--  core/src/main/scala/spark/api/python/PythonRDD.scala          10
2 files changed, 44 insertions, 6 deletions
diff --git a/core/src/main/scala/spark/api/python/PythonPartitioner.scala b/core/src/main/scala/spark/api/python/PythonPartitioner.scala
new file mode 100644
index 0000000000..ef9f808fb2
--- /dev/null
+++ b/core/src/main/scala/spark/api/python/PythonPartitioner.scala
@@ -0,0 +1,40 @@
+package spark.api.python
+
+import spark.Partitioner
+
+import java.util.Arrays
+
+/**
+ * A [[spark.Partitioner]] that hashes byte arrays by value rather than by object identity, for use by the Python API.
+ */
+class PythonPartitioner(override val numPartitions: Int) extends Partitioner {
+
+  override def getPartition(key: Any): Int = {
+    if (key == null) {
+      return 0
+    }
+    else {
+      val hashCode = {
+        if (key.isInstanceOf[Array[Byte]]) {
+          // Hash byte arrays by their contents rather than by object identity
+          Arrays.hashCode(key.asInstanceOf[Array[Byte]])
+        } else {
+          key.hashCode()
+        }
+      }
+      val mod = hashCode % numPartitions
+      if (mod < 0) {
+        mod + numPartitions  // Guard against negative hash codes
+      } else {
+        mod
+      }
+    }
+  }
+
+  override def equals(other: Any): Boolean = other match {
+    case h: PythonPartitioner =>
+      h.numPartitions == numPartitions
+    case _ =>
+      false
+  }
+}
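A quick worked example of the partitioner above (a sketch, assuming the class as defined in this commit; the key values are illustrative):

// In Scala/Java the sign of % follows the dividend, so a negative hash
// such as -7 gives -7 % 3 == -1; the guard then adds numPartitions to
// map it back into [0, numPartitions): -1 + 3 == 2.
val part = new PythonPartitioner(3)

// Equal serialized keys now land in the same partition:
val k1 = Array[Byte](10, 20)
val k2 = Array[Byte](10, 20)
println(part.getPartition(k1) == part.getPartition(k2))  // true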
diff --git a/core/src/main/scala/spark/api/python/PythonRDD.scala b/core/src/main/scala/spark/api/python/PythonRDD.scala
index a593e53efd..50094d6b0f 100644
--- a/core/src/main/scala/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/spark/api/python/PythonRDD.scala
@@ -179,14 +179,12 @@ object PythonRDD {
     val dOut = new DataOutputStream(baos);
     if (elem.isInstanceOf[Array[Byte]]) {
       elem.asInstanceOf[Array[Byte]]
-    } else if (elem.isInstanceOf[scala.Tuple2[_, _]]) {
-      val t = elem.asInstanceOf[scala.Tuple2[_, _]]
-      val t1 = t._1.asInstanceOf[Array[Byte]]
-      val t2 = t._2.asInstanceOf[Array[Byte]]
+    } else if (elem.isInstanceOf[scala.Tuple2[Array[Byte], Array[Byte]]]) {
+      val t = elem.asInstanceOf[scala.Tuple2[Array[Byte], Array[Byte]]]
       dOut.writeByte(Pickle.PROTO)
       dOut.writeByte(Pickle.TWO)
-      dOut.write(PythonRDD.stripPickle(t1))
-      dOut.write(PythonRDD.stripPickle(t2))
+      dOut.write(PythonRDD.stripPickle(t._1))
+      dOut.write(PythonRDD.stripPickle(t._2))
       dOut.writeByte(Pickle.TUPLE2)
       dOut.writeByte(Pickle.STOP)
       baos.toByteArray()
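For context, the writes above splice two already-pickled values into one stream that unpickles as a Python 2-tuple. Below is a rough, self-contained sketch of that framing; the opcode values are assumed to mirror Python's pickle protocol 2, and stripPickle here is a hypothetical stand-in for PythonRDD.stripPickle:

import java.io.{ByteArrayOutputStream, DataOutputStream}

object PickleFraming {
  // Assumed opcodes, matching Python's pickle protocol 2:
  val PROTO: Byte  = 0x80.toByte // protocol marker
  val TWO: Byte    = 0x02        // protocol version 2
  val TUPLE2: Byte = 0x86.toByte // build a 2-tuple from the top two stack items
  val STOP: Byte   = 0x2e        // '.', end of the pickle stream

  // Hypothetical stand-in for PythonRDD.stripPickle: drop the two-byte
  // PROTO/TWO header and the trailing STOP from a pickled element.
  def stripPickle(pickled: Array[Byte]): Array[Byte] =
    pickled.slice(2, pickled.length - 1)

  def pickledPair(key: Array[Byte], value: Array[Byte]): Array[Byte] = {
    val baos = new ByteArrayOutputStream()
    val dOut = new DataOutputStream(baos)
    dOut.writeByte(PROTO)
    dOut.writeByte(TWO)
    dOut.write(stripPickle(key))   // pushes the key onto the unpickler's stack
    dOut.write(stripPickle(value)) // pushes the value
    dOut.writeByte(TUPLE2)         // combines the top two stack items into (key, value)
    dOut.writeByte(STOP)
    baos.toByteArray()
  }
}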