diff options
author | Josh Rosen <joshrosen@eecs.berkeley.edu> | 2012-10-28 22:30:28 -0700 |
---|---|---|
committer | Josh Rosen <joshrosen@eecs.berkeley.edu> | 2012-10-28 22:30:28 -0700 |
commit | 2ccf3b665280bf5b0919e3801d028126cb070dbd (patch) | |
tree | c486342f7d31dd1f529444ca40339e2a879ac219 /core | |
parent | 7859879aaa1860ff6b383e32a18fd9a410a97416 (diff) | |
download | spark-2ccf3b665280bf5b0919e3801d028126cb070dbd.tar.gz spark-2ccf3b665280bf5b0919e3801d028126cb070dbd.tar.bz2 spark-2ccf3b665280bf5b0919e3801d028126cb070dbd.zip |
Fix PySpark hash partitioning bug.
A Java array's hashCode is based on its object
identity, not its elements, so this was causing
serialized keys to be hashed incorrectly.
This commit adds a PySpark-specific workaround
and adds more tests.
Diffstat (limited to 'core')
-rw-r--r-- | core/src/main/scala/spark/api/python/PythonPartitioner.scala | 41 | ||||
-rw-r--r-- | core/src/main/scala/spark/api/python/PythonRDD.scala | 10 |
2 files changed, 45 insertions, 6 deletions
diff --git a/core/src/main/scala/spark/api/python/PythonPartitioner.scala b/core/src/main/scala/spark/api/python/PythonPartitioner.scala new file mode 100644 index 0000000000..ef9f808fb2 --- /dev/null +++ b/core/src/main/scala/spark/api/python/PythonPartitioner.scala @@ -0,0 +1,41 @@ +package spark.api.python + +import spark.Partitioner + +import java.util.Arrays + +/** + * A [[spark.Partitioner]] that performs handling of byte arrays, for use by the Python API. + */ +class PythonPartitioner(override val numPartitions: Int) extends Partitioner { + + override def getPartition(key: Any): Int = { + if (key == null) { + return 0 + } + else { + val hashCode = { + if (key.isInstanceOf[Array[Byte]]) { + System.err.println("Dumping a byte array!" + Arrays.hashCode(key.asInstanceOf[Array[Byte]]) + ) + Arrays.hashCode(key.asInstanceOf[Array[Byte]]) + } + else + key.hashCode() + } + val mod = hashCode % numPartitions + if (mod < 0) { + mod + numPartitions + } else { + mod // Guard against negative hash codes + } + } + } + + override def equals(other: Any): Boolean = other match { + case h: PythonPartitioner => + h.numPartitions == numPartitions + case _ => + false + } +} diff --git a/core/src/main/scala/spark/api/python/PythonRDD.scala b/core/src/main/scala/spark/api/python/PythonRDD.scala index a593e53efd..50094d6b0f 100644 --- a/core/src/main/scala/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/spark/api/python/PythonRDD.scala @@ -179,14 +179,12 @@ object PythonRDD { val dOut = new DataOutputStream(baos); if (elem.isInstanceOf[Array[Byte]]) { elem.asInstanceOf[Array[Byte]] - } else if (elem.isInstanceOf[scala.Tuple2[_, _]]) { - val t = elem.asInstanceOf[scala.Tuple2[_, _]] - val t1 = t._1.asInstanceOf[Array[Byte]] - val t2 = t._2.asInstanceOf[Array[Byte]] + } else if (elem.isInstanceOf[scala.Tuple2[Array[Byte], Array[Byte]]]) { + val t = elem.asInstanceOf[scala.Tuple2[Array[Byte], Array[Byte]]] dOut.writeByte(Pickle.PROTO) dOut.writeByte(Pickle.TWO) - 
dOut.write(PythonRDD.stripPickle(t1)) - dOut.write(PythonRDD.stripPickle(t2)) + dOut.write(PythonRDD.stripPickle(t._1)) + dOut.write(PythonRDD.stripPickle(t._2)) dOut.writeByte(Pickle.TUPLE2) dOut.writeByte(Pickle.STOP) baos.toByteArray() |