author    Josh Rosen <joshrosen@eecs.berkeley.edu>  2012-12-24 17:20:10 -0800
committer Josh Rosen <joshrosen@eecs.berkeley.edu>  2012-12-24 17:20:10 -0800
commit    4608902fb87af64a15b97ab21fe6382cd6e5a644 (patch)
tree      ef2e9e7a9d3f88dccd70e7b6e39753354b605207 /core
parent    ccd075cf960df6c6c449b709515cdd81499a52be (diff)
Use filesystem to collect RDDs in PySpark.
Passing large volumes of data through Py4J seems to be slow. It appears to be faster to write the data to the local filesystem and read it back from Python.
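
A minimal sketch of the idea, assuming the length-prefixed record format produced by writeAsPickle in the diff below (the object and method names CollectViaFileSketch and dumpToLocalFile are hypothetical, not part of this patch): each pre-pickled element is written to a local file as a 32-bit length followed by the pickled bytes, so only a short filename has to cross the Py4J boundary.

import java.io.{DataOutputStream, File, FileOutputStream}

object CollectViaFileSketch {
  // Hypothetical helper: write length-prefixed pickle records to a temp file
  // and hand the file's path back to the Python driver over Py4J.
  def dumpToLocalFile(pickledItems: Iterator[Array[Byte]]): String = {
    val file = File.createTempFile("pyspark-collect", ".bin")
    val out = new DataOutputStream(new FileOutputStream(file))
    for (bytes <- pickledItems) {
      out.writeInt(bytes.length)   // 32-bit length prefix (DataOutputStream is big-endian)
      out.write(bytes)             // the pickled payload itself
    }
    out.close()
    file.getAbsolutePath           // Python reads the records back from this path
  }
}

The per-record framing stays the same as before; only the transport changes from a Py4J call carrying the bytes to a local file handoff.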
Diffstat (limited to 'core')
-rw-r--r--  core/src/main/scala/spark/api/python/PythonRDD.scala  66
1 file changed, 24 insertions(+), 42 deletions(-)
diff --git a/core/src/main/scala/spark/api/python/PythonRDD.scala b/core/src/main/scala/spark/api/python/PythonRDD.scala
index 50094d6b0f..4f870e837a 100644
--- a/core/src/main/scala/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/spark/api/python/PythonRDD.scala
@@ -1,6 +1,7 @@
package spark.api.python
import java.io._
+import java.util.{List => JList}
import scala.collection.Map
import scala.collection.JavaConversions._
@@ -59,36 +60,7 @@ trait PythonRDDBase {
}
out.flush()
for (elem <- parent.iterator(split)) {
- if (elem.isInstanceOf[Array[Byte]]) {
- val arr = elem.asInstanceOf[Array[Byte]]
- dOut.writeInt(arr.length)
- dOut.write(arr)
- } else if (elem.isInstanceOf[scala.Tuple2[_, _]]) {
- val t = elem.asInstanceOf[scala.Tuple2[_, _]]
- val t1 = t._1.asInstanceOf[Array[Byte]]
- val t2 = t._2.asInstanceOf[Array[Byte]]
- val length = t1.length + t2.length - 3 - 3 + 4 // stripPickle() removes 3 bytes
- dOut.writeInt(length)
- dOut.writeByte(Pickle.PROTO)
- dOut.writeByte(Pickle.TWO)
- dOut.write(PythonRDD.stripPickle(t1))
- dOut.write(PythonRDD.stripPickle(t2))
- dOut.writeByte(Pickle.TUPLE2)
- dOut.writeByte(Pickle.STOP)
- } else if (elem.isInstanceOf[String]) {
- // For uniformity, strings are wrapped into Pickles.
- val s = elem.asInstanceOf[String].getBytes("UTF-8")
- val length = 2 + 1 + 4 + s.length + 1
- dOut.writeInt(length)
- dOut.writeByte(Pickle.PROTO)
- dOut.writeByte(Pickle.TWO)
- dOut.writeByte(Pickle.BINUNICODE)
- dOut.writeInt(Integer.reverseBytes(s.length))
- dOut.write(s)
- dOut.writeByte(Pickle.STOP)
- } else {
- throw new Exception("Unexpected RDD type")
- }
+ PythonRDD.writeAsPickle(elem, dOut)
}
dOut.flush()
out.flush()
@@ -174,36 +146,45 @@ object PythonRDD {
arr.slice(2, arr.length - 1)
}
- def asPickle(elem: Any) : Array[Byte] = {
- val baos = new ByteArrayOutputStream();
- val dOut = new DataOutputStream(baos);
+ /**
+ * Write strings, pickled Python objects, or pairs of pickled objects to a data output stream.
+ * The data format is a 32-bit integer representing the pickled object's length (in bytes),
+ * followed by the pickled data.
+ * @param elem the object to write
+ * @param dOut a data output stream
+ */
+ def writeAsPickle(elem: Any, dOut: DataOutputStream) {
if (elem.isInstanceOf[Array[Byte]]) {
- elem.asInstanceOf[Array[Byte]]
+ val arr = elem.asInstanceOf[Array[Byte]]
+ dOut.writeInt(arr.length)
+ dOut.write(arr)
} else if (elem.isInstanceOf[scala.Tuple2[Array[Byte], Array[Byte]]]) {
val t = elem.asInstanceOf[scala.Tuple2[Array[Byte], Array[Byte]]]
+ val length = t._1.length + t._2.length - 3 - 3 + 4 // stripPickle() removes 3 bytes
+ dOut.writeInt(length)
dOut.writeByte(Pickle.PROTO)
dOut.writeByte(Pickle.TWO)
dOut.write(PythonRDD.stripPickle(t._1))
dOut.write(PythonRDD.stripPickle(t._2))
dOut.writeByte(Pickle.TUPLE2)
dOut.writeByte(Pickle.STOP)
- baos.toByteArray()
} else if (elem.isInstanceOf[String]) {
// For uniformity, strings are wrapped into Pickles.
val s = elem.asInstanceOf[String].getBytes("UTF-8")
+ val length = 2 + 1 + 4 + s.length + 1
+ dOut.writeInt(length)
dOut.writeByte(Pickle.PROTO)
dOut.writeByte(Pickle.TWO)
dOut.write(Pickle.BINUNICODE)
dOut.writeInt(Integer.reverseBytes(s.length))
dOut.write(s)
dOut.writeByte(Pickle.STOP)
- baos.toByteArray()
} else {
throw new Exception("Unexpected RDD type")
}
}
- def pickleFile(sc: JavaSparkContext, filename: String, parallelism: Int) :
+ def readRDDFromPickleFile(sc: JavaSparkContext, filename: String, parallelism: Int) :
JavaRDD[Array[Byte]] = {
val file = new DataInputStream(new FileInputStream(filename))
val objs = new collection.mutable.ArrayBuffer[Array[Byte]]
@@ -221,11 +202,12 @@ object PythonRDD {
JavaRDD.fromRDD(sc.sc.parallelize(objs, parallelism))
}
- def arrayAsPickle(arr : Any) : Array[Byte] = {
- val pickles : Array[Byte] = arr.asInstanceOf[Array[Any]].map(asPickle).map(stripPickle).flatten
-
- Array[Byte](Pickle.PROTO, Pickle.TWO, Pickle.EMPTY_LIST, Pickle.MARK) ++ pickles ++
- Array[Byte] (Pickle.APPENDS, Pickle.STOP)
+ def writeArrayToPickleFile[T](items: Array[T], filename: String) {
+ val file = new DataOutputStream(new FileOutputStream(filename))
+ for (item <- items) {
+ writeAsPickle(item, file)
+ }
+ file.close()
}
}
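
The reading loop inside readRDDFromPickleFile is not shown in the hunk above; a minimal sketch of how a file written by writeArrayToPickleFile could be read back, assuming the same 32-bit-length-prefixed record format (PickleFileReaderSketch and readRecords are hypothetical names for illustration):

import java.io.{DataInputStream, EOFException, FileInputStream}
import scala.collection.mutable.ArrayBuffer

object PickleFileReaderSketch {
  // Read length-prefixed pickle records until end of file.
  def readRecords(filename: String): Seq[Array[Byte]] = {
    val in = new DataInputStream(new FileInputStream(filename))
    val objs = new ArrayBuffer[Array[Byte]]
    try {
      while (true) {
        val length = in.readInt()        // 4-byte length prefix
        val obj = new Array[Byte](length)
        in.readFully(obj)                // the pickled bytes
        objs += obj
      }
    } catch {
      case _: EOFException =>            // end of file: all records consumed
    } finally {
      in.close()
    }
    objs.toSeq
  }
}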