 core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala          |  73
 python/pyspark/broadcast.py                                              |  95
 python/pyspark/context.py                                                |  12
 python/pyspark/serializers.py                                            | 178
 python/pyspark/tests.py                                                  |  18
 python/pyspark/worker.py                                                 |  10
 sql/core/src/main/scala/org/apache/spark/sql/UdfRegistration.scala       |   3
 sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala  |   4
 8 files changed, 135 insertions(+), 258 deletions(-)
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index b80c771d58..e0bc00e1eb 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -19,7 +19,7 @@ package org.apache.spark.api.python
 
 import java.io._
 import java.net._
-import java.util.{List => JList, ArrayList => JArrayList, Map => JMap, Collections}
+import java.util.{List => JList, ArrayList => JArrayList, Map => JMap, UUID, Collections}
 
 import org.apache.spark.input.PortableDataStream
 
@@ -47,7 +47,7 @@ private[spark] class PythonRDD(
     pythonIncludes: JList[String],
     preservePartitoning: Boolean,
     pythonExec: String,
-    broadcastVars: JList[Broadcast[Array[Array[Byte]]]],
+    broadcastVars: JList[Broadcast[PythonBroadcast]],
     accumulator: Accumulator[JList[Array[Byte]]])
   extends RDD[Array[Byte]](parent) {
 
@@ -230,8 +230,7 @@ private[spark] class PythonRDD(
           if (!oldBids.contains(broadcast.id)) {
             // send new broadcast
             dataOut.writeLong(broadcast.id)
-            dataOut.writeLong(broadcast.value.map(_.length.toLong).sum)
-            broadcast.value.foreach(dataOut.write)
+            PythonRDD.writeUTF(broadcast.value.path, dataOut)
             oldBids.add(broadcast.id)
           }
         }
@@ -368,24 +367,8 @@ private[spark] object PythonRDD extends Logging {
     }
   }
 
-  def readBroadcastFromFile(
-      sc: JavaSparkContext,
-      filename: String): Broadcast[Array[Array[Byte]]] = {
-    val size = new File(filename).length()
-    val file = new DataInputStream(new FileInputStream(filename))
-    val blockSize = 1 << 20
-    val n = ((size + blockSize - 1) / blockSize).toInt
-    val obj = new Array[Array[Byte]](n)
-    try {
-      for (i <- 0 until n) {
-        val length = if (i < (n - 1)) blockSize else (size % blockSize).toInt
-        obj(i) = new Array[Byte](length)
-        file.readFully(obj(i))
-      }
-    } finally {
-      file.close()
-    }
-    sc.broadcast(obj)
+  def readBroadcastFromFile(sc: JavaSparkContext, path: String): Broadcast[PythonBroadcast] = {
+    sc.broadcast(new PythonBroadcast(path))
   }
 
   def writeIteratorToStream[T](iter: Iterator[T], dataOut: DataOutputStream) {
@@ -824,3 +807,49 @@ private class PythonAccumulatorParam(@transient serverHost: String, serverPort:
     }
   }
 }
+
+/**
+ * An Wrapper for Python Broadcast, which is written into disk by Python. It also will
+ * write the data into disk after deserialization, then Python can read it from disks.
+ */
+private[spark] class PythonBroadcast(@transient var path: String) extends Serializable {
+
+  /**
+   * Read data from disks, then copy it to `out`
+   */
+  private def writeObject(out: ObjectOutputStream): Unit = Utils.tryOrIOException {
+    val in = new FileInputStream(new File(path))
+    try {
+      Utils.copyStream(in, out)
+    } finally {
+      in.close()
+    }
+  }
+
+  /**
+   * Write data into disk, using randomly generated name.
+   */
+  private def readObject(in: ObjectInputStream): Unit = Utils.tryOrIOException {
+    val dir = new File(Utils.getLocalDir(SparkEnv.get.conf))
+    val file = File.createTempFile("broadcast", "", dir)
+    path = file.getAbsolutePath
+    val out = new FileOutputStream(file)
+    try {
+      Utils.copyStream(in, out)
+    } finally {
+      out.close()
+    }
+  }
+
+  /**
+   * Delete the file once the object is GCed.
+   */
+  override def finalize() {
+    if (!path.isEmpty) {
+      val file = new File(path)
+      if (file.exists()) {
+        file.delete()
+      }
+    }
+  }
+}
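Editor's note: the PythonBroadcast class above ships the file's bytes through Java serialization and rematerializes them as a fresh local file on the receiving executor, so Python only ever sees a path. As a rough, purely illustrative sketch of that pattern in Python (the names FileBackedPayload and _rebuild are made up, this is not Spark code):

import os
import tempfile


class FileBackedPayload(object):
    """Carries the bytes of a local file; recreates the file when unpickled."""

    def __init__(self, path):
        self.path = path

    def __reduce__(self):
        # serialize the file contents, not the machine-local path
        with open(self.path, 'rb') as f:
            return _rebuild, (f.read(),)


def _rebuild(data):
    # on deserialization, write the bytes to a fresh temp file and wrap its path
    fd, path = tempfile.mkstemp(prefix='broadcast')
    with os.fdopen(fd, 'wb') as f:
        f.write(data)
    return FileBackedPayload(path)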
diff --git a/python/pyspark/broadcast.py b/python/pyspark/broadcast.py
index 01cac3c72c..6b8a8b256a 100644
--- a/python/pyspark/broadcast.py
+++ b/python/pyspark/broadcast.py
@@ -15,21 +15,10 @@
 # limitations under the License.
 #
 
-"""
->>> from pyspark.context import SparkContext
->>> sc = SparkContext('local', 'test')
->>> b = sc.broadcast([1, 2, 3, 4, 5])
->>> b.value
-[1, 2, 3, 4, 5]
->>> sc.parallelize([0, 0]).flatMap(lambda x: b.value).collect()
-[1, 2, 3, 4, 5, 1, 2, 3, 4, 5]
->>> b.unpersist()
-
->>> large_broadcast = sc.broadcast(list(range(10000)))
-"""
 import os
-
-from pyspark.serializers import LargeObjectSerializer
+import cPickle
+import gc
+from tempfile import NamedTemporaryFile
 
 
 __all__ = ['Broadcast']
 
@@ -49,44 +38,88 @@ def _from_id(bid):
 
 class Broadcast(object):
 
     """
-    A broadcast variable created with
-    L{SparkContext.broadcast()<pyspark.context.SparkContext.broadcast>}.
+    A broadcast variable created with L{SparkContext.broadcast()}.
     Access its value through C{.value}.
+
+    Examples:
+
+    >>> from pyspark.context import SparkContext
+    >>> sc = SparkContext('local', 'test')
+    >>> b = sc.broadcast([1, 2, 3, 4, 5])
+    >>> b.value
+    [1, 2, 3, 4, 5]
+    >>> sc.parallelize([0, 0]).flatMap(lambda x: b.value).collect()
+    [1, 2, 3, 4, 5, 1, 2, 3, 4, 5]
+    >>> b.unpersist()
+
+    >>> large_broadcast = sc.broadcast(range(10000))
     """
 
-    def __init__(self, bid, value, java_broadcast=None,
-                 pickle_registry=None, path=None):
+    def __init__(self, sc=None, value=None, pickle_registry=None, path=None):
         """
-        Should not be called directly by users -- use
-        L{SparkContext.broadcast()<pyspark.context.SparkContext.broadcast>}
+        Should not be called directly by users -- use L{SparkContext.broadcast()}
         instead.
         """
-        self.bid = bid
-        if path is None:
-            self._value = value
-        self._jbroadcast = java_broadcast
-        self._pickle_registry = pickle_registry
-        self.path = path
+        if sc is not None:
+            f = NamedTemporaryFile(delete=False, dir=sc._temp_dir)
+            self._path = self.dump(value, f)
+            self._jbroadcast = sc._jvm.PythonRDD.readBroadcastFromFile(sc._jsc, self._path)
+            self._pickle_registry = pickle_registry
+        else:
+            self._jbroadcast = None
+            self._path = path
+
+    def dump(self, value, f):
+        if isinstance(value, basestring):
+            if isinstance(value, unicode):
+                f.write('U')
+                value = value.encode('utf8')
+            else:
+                f.write('S')
+            f.write(value)
+        else:
+            f.write('P')
+            cPickle.dump(value, f, 2)
+        f.close()
+        return f.name
+
+    def load(self, path):
+        with open(path, 'rb', 1 << 20) as f:
+            flag = f.read(1)
+            data = f.read()
+            if flag == 'P':
+                # cPickle.loads() may create lots of objects, disable GC
+                # temporary for better performance
+                gc.disable()
+                try:
+                    return cPickle.loads(data)
+                finally:
+                    gc.enable()
+            else:
+                return data.decode('utf8') if flag == 'U' else data
 
     @property
     def value(self):
         """ Return the broadcasted value
         """
-        if not hasattr(self, "_value") and self.path is not None:
-            ser = LargeObjectSerializer()
-            self._value = ser.load_stream(open(self.path)).next()
+        if not hasattr(self, "_value") and self._path is not None:
+            self._value = self.load(self._path)
         return self._value
 
     def unpersist(self, blocking=False):
         """
         Delete cached copies of this broadcast on the executors.
         """
+        if self._jbroadcast is None:
+            raise Exception("Broadcast can only be unpersisted in driver")
         self._jbroadcast.unpersist(blocking)
-        os.unlink(self.path)
+        os.unlink(self._path)
 
     def __reduce__(self):
+        if self._jbroadcast is None:
+            raise Exception("Broadcast can only be serialized in driver")
         self._pickle_registry.add(self)
-        return (_from_id, (self.bid, ))
+        return _from_id, (self._jbroadcast.id(),)
 
 
 if __name__ == "__main__":
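Editor's note: the on-disk layout produced by dump() is a single tag byte followed by the payload: 'S' for raw bytes, 'U' for UTF-8 text, 'P' for a pickle. A standalone, purely illustrative restatement of that format (assumed helper names, not part of pyspark):

import cPickle
from tempfile import NamedTemporaryFile


def dump_tagged(value, f):
    if isinstance(value, unicode):
        f.write('U')
        f.write(value.encode('utf8'))
    elif isinstance(value, str):
        f.write('S')          # raw bytes are written as-is, no pickling cost
        f.write(value)
    else:
        f.write('P')          # everything else goes through cPickle protocol 2
        cPickle.dump(value, f, 2)
    f.close()
    return f.name


def load_tagged(path):
    with open(path, 'rb') as f:
        flag = f.read(1)
        data = f.read()
    if flag == 'P':
        return cPickle.loads(data)
    return data.decode('utf8') if flag == 'U' else data


path = dump_tagged({'a': 1}, NamedTemporaryFile(delete=False))
print(load_tagged(path))  # {'a': 1}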
""" + if self._jbroadcast is None: + raise Exception("Broadcast can only be unpersisted in driver") self._jbroadcast.unpersist(blocking) - os.unlink(self.path) + os.unlink(self._path) def __reduce__(self): + if self._jbroadcast is None: + raise Exception("Broadcast can only be serialized in driver") self._pickle_registry.add(self) - return (_from_id, (self.bid, )) + return _from_id, (self._jbroadcast.id(),) if __name__ == "__main__": diff --git a/python/pyspark/context.py b/python/pyspark/context.py index ec67ec8d0f..ed7351d60c 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -29,7 +29,7 @@ from pyspark.conf import SparkConf from pyspark.files import SparkFiles from pyspark.java_gateway import launch_gateway from pyspark.serializers import PickleSerializer, BatchedSerializer, UTF8Deserializer, \ - PairDeserializer, AutoBatchedSerializer, NoOpSerializer, LargeObjectSerializer + PairDeserializer, AutoBatchedSerializer, NoOpSerializer from pyspark.storagelevel import StorageLevel from pyspark.rdd import RDD from pyspark.traceback_utils import CallSite, first_spark_call @@ -624,15 +624,7 @@ class SparkContext(object): object for reading it in distributed functions. The variable will be sent to each cluster only once. """ - ser = LargeObjectSerializer() - - # pass large object by py4j is very slow and need much memory - tempFile = NamedTemporaryFile(delete=False, dir=self._temp_dir) - ser.dump_stream([value], tempFile) - tempFile.close() - jbroadcast = self._jvm.PythonRDD.readBroadcastFromFile(self._jsc, tempFile.name) - return Broadcast(jbroadcast.id(), None, jbroadcast, - self._pickled_broadcast_vars, tempFile.name) + return Broadcast(self, value, self._pickled_broadcast_vars) def accumulator(self, value, accum_param=None): """ diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index 760a509f0e..33aa55f7f1 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -448,184 +448,20 @@ class AutoSerializer(FramedSerializer): raise ValueError("invalid sevialization type: %s" % _type) -class SizeLimitedStream(object): - """ - Read at most `limit` bytes from underlying stream - - >>> from StringIO import StringIO - >>> io = StringIO() - >>> io.write("Hello world") - >>> io.seek(0) - >>> lio = SizeLimitedStream(io, 5) - >>> lio.read() - 'Hello' - """ - def __init__(self, stream, limit): - self.stream = stream - self.limit = limit - - def read(self, n=0): - if n > self.limit or n == 0: - n = self.limit - buf = self.stream.read(n) - self.limit -= len(buf) - return buf - - -class CompressedStream(object): - """ - Compress the data using zlib - - >>> from StringIO import StringIO - >>> io = StringIO() - >>> wio = CompressedStream(io, 'w') - >>> wio.write("Hello world") - >>> wio.flush() - >>> io.seek(0) - >>> rio = CompressedStream(io, 'r') - >>> rio.read() - 'Hello world' - >>> rio.read() - '' - """ - MAX_BATCH = 1 << 20 # 1MB - - def __init__(self, stream, mode='w', level=1): - self.stream = stream - self.mode = mode - if mode == 'w': - self.compresser = zlib.compressobj(level) - elif mode == 'r': - self.decompresser = zlib.decompressobj() - self.buf = '' - else: - raise ValueError("can only support mode 'w' or 'r' ") - - def write(self, buf): - assert self.mode == 'w', "It's not opened for write" - if len(buf) > self.MAX_BATCH: - # zlib can not compress string larger than 2G - batches = len(buf) / self.MAX_BATCH + 1 # last one may be empty - for i in xrange(batches): - self.write(buf[i * self.MAX_BATCH:(i + 1) * self.MAX_BATCH]) 
diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py
index 760a509f0e..33aa55f7f1 100644
--- a/python/pyspark/serializers.py
+++ b/python/pyspark/serializers.py
@@ -448,184 +448,20 @@ class AutoSerializer(FramedSerializer):
             raise ValueError("invalid sevialization type: %s" % _type)
 
 
-class SizeLimitedStream(object):
-    """
-    Read at most `limit` bytes from underlying stream
-
-    >>> from StringIO import StringIO
-    >>> io = StringIO()
-    >>> io.write("Hello world")
-    >>> io.seek(0)
-    >>> lio = SizeLimitedStream(io, 5)
-    >>> lio.read()
-    'Hello'
-    """
-    def __init__(self, stream, limit):
-        self.stream = stream
-        self.limit = limit
-
-    def read(self, n=0):
-        if n > self.limit or n == 0:
-            n = self.limit
-        buf = self.stream.read(n)
-        self.limit -= len(buf)
-        return buf
-
-
-class CompressedStream(object):
-    """
-    Compress the data using zlib
-
-    >>> from StringIO import StringIO
-    >>> io = StringIO()
-    >>> wio = CompressedStream(io, 'w')
-    >>> wio.write("Hello world")
-    >>> wio.flush()
-    >>> io.seek(0)
-    >>> rio = CompressedStream(io, 'r')
-    >>> rio.read()
-    'Hello world'
-    >>> rio.read()
-    ''
-    """
-    MAX_BATCH = 1 << 20  # 1MB
-
-    def __init__(self, stream, mode='w', level=1):
-        self.stream = stream
-        self.mode = mode
-        if mode == 'w':
-            self.compresser = zlib.compressobj(level)
-        elif mode == 'r':
-            self.decompresser = zlib.decompressobj()
-            self.buf = ''
-        else:
-            raise ValueError("can only support mode 'w' or 'r' ")
-
-    def write(self, buf):
-        assert self.mode == 'w', "It's not opened for write"
-        if len(buf) > self.MAX_BATCH:
-            # zlib can not compress string larger than 2G
-            batches = len(buf) / self.MAX_BATCH + 1  # last one may be empty
-            for i in xrange(batches):
-                self.write(buf[i * self.MAX_BATCH:(i + 1) * self.MAX_BATCH])
-        else:
-            compressed = self.compresser.compress(buf)
-            self.stream.write(compressed)
-
-    def flush(self, mode=zlib.Z_FULL_FLUSH):
-        if self.mode == 'w':
-            d = self.compresser.flush(mode)
-            self.stream.write(d)
-            self.stream.flush()
-
-    def close(self):
-        if self.mode == 'w':
-            self.flush(zlib.Z_FINISH)
-        self.stream.close()
-
-    def read(self, size=0):
-        assert self.mode == 'r', "It's not opened for read"
-        if not size:
-            data = self.stream.read()
-            result = self.decompresser.decompress(data)
-            last = self.decompresser.flush()
-            return self.buf + result + last
-
-        # fast path for small read()
-        if size <= len(self.buf):
-            result = self.buf[:size]
-            self.buf = self.buf[size:]
-            return result
-
-        result = [self.buf]
-        size -= len(self.buf)
-        self.buf = ''
-        while size:
-            need = min(size, self.MAX_BATCH)
-            input = self.stream.read(need)
-            if input:
-                buf = self.decompresser.decompress(input)
-            else:
-                buf = self.decompresser.flush()
-
-            if len(buf) >= size:
-                self.buf = buf[size:]
-                result.append(buf[:size])
-                return ''.join(result)
-
-            size -= len(buf)
-            result.append(buf)
-            if not input:
-                return ''.join(result)
-
-    def readline(self):
-        """
-        This is needed for pickle, but not used in protocol 2
-        """
-        line = []
-        b = self.read(1)
-        while b and b != '\n':
-            line.append(b)
-            b = self.read(1)
-        line.append(b)
-        return ''.join(line)
-
-
-class LargeObjectSerializer(Serializer):
-    """
-    Serialize large object which could be larger than 2G
-
-    It uses cPickle to serialize the objects
-    """
-    def dump_stream(self, iterator, stream):
-        stream = CompressedStream(stream, 'w')
-        for value in iterator:
-            if isinstance(value, basestring):
-                if isinstance(value, unicode):
-                    stream.write('U')
-                    value = value.encode("utf-8")
-                else:
-                    stream.write('S')
-                write_long(len(value), stream)
-                stream.write(value)
-            else:
-                stream.write('P')
-                cPickle.dump(value, stream, 2)
-        stream.flush()
-
-    def load_stream(self, stream):
-        stream = CompressedStream(stream, 'r')
-        while True:
-            type = stream.read(1)
-            if not type:
-                return
-            if type in ('S', 'U'):
-                length = read_long(stream)
-                value = stream.read(length)
-                if type == 'U':
-                    value = value.decode('utf-8')
-                yield value
-            elif type == 'P':
-                yield cPickle.load(stream)
-            else:
-                raise ValueError("unknown type: %s" % type)
-
-
-class CompressedSerializer(Serializer):
+class CompressedSerializer(FramedSerializer):
     """
     Compress the serialized data
     """
 
     def __init__(self, serializer):
+        FramedSerializer.__init__(self)
+        assert isinstance(serializer, FramedSerializer), "serializer must be a FramedSerializer"
        self.serializer = serializer
 
-    def load_stream(self, stream):
-        stream = CompressedStream(stream, "r")
-        return self.serializer.load_stream(stream)
+    def dumps(self, obj):
+        return zlib.compress(self.serializer.dumps(obj), 1)
 
-    def dump_stream(self, iterator, stream):
-        stream = CompressedStream(stream, "w")
-        self.serializer.dump_stream(iterator, stream)
-        stream.flush()
+    def loads(self, obj):
+        return self.serializer.loads(zlib.decompress(obj))
 
 
 class UTF8Deserializer(Serializer):
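Editor's note: the reworked CompressedSerializer leaves the length-prefixed framing to FramedSerializer and only compresses each frame. A standalone sketch of that framed-zlib idea (assumed helper names, not the pyspark classes):

import pickle
import struct
import zlib


def dump_stream(iterator, stream):
    # each object is pickled, zlib-compressed (level 1, as in the diff), and
    # written as a 4-byte length-prefixed frame
    for obj in iterator:
        frame = zlib.compress(pickle.dumps(obj, 2), 1)
        stream.write(struct.pack("!i", len(frame)))
        stream.write(frame)


def load_stream(stream):
    # frames can be read back one object at a time
    while True:
        header = stream.read(4)
        if len(header) < 4:
            return
        length = struct.unpack("!i", header)[0]
        yield pickle.loads(zlib.decompress(stream.read(length)))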
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index 29bcd38908..32645778c2 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -48,7 +48,7 @@ from pyspark.conf import SparkConf
 from pyspark.context import SparkContext
 from pyspark.files import SparkFiles
 from pyspark.serializers import read_int, BatchedSerializer, MarshalSerializer, PickleSerializer, \
-    CloudPickleSerializer, SizeLimitedStream, CompressedSerializer, LargeObjectSerializer
+    CloudPickleSerializer, CompressedSerializer
 from pyspark.shuffle import Aggregator, InMemoryMerger, ExternalMerger, ExternalSorter
 from pyspark.sql import SQLContext, IntegerType, Row, ArrayType, StructType, StructField, \
     UserDefinedType, DoubleType
@@ -237,26 +237,16 @@ class SerializationTestCase(unittest.TestCase):
         self.assertTrue("exit" in foo.func_code.co_names)
         ser.dumps(foo)
 
-    def _test_serializer(self, ser):
+    def test_compressed_serializer(self):
+        ser = CompressedSerializer(PickleSerializer())
         from StringIO import StringIO
         io = StringIO()
         ser.dump_stream(["abc", u"123", range(5)], io)
         io.seek(0)
         self.assertEqual(["abc", u"123", range(5)], list(ser.load_stream(io)))
-        size = io.tell()
         ser.dump_stream(range(1000), io)
         io.seek(0)
-        first = SizeLimitedStream(io, size)
-        self.assertEqual(["abc", u"123", range(5)], list(ser.load_stream(first)))
-        self.assertEqual(range(1000), list(ser.load_stream(io)))
-
-    def test_compressed_serializer(self):
-        ser = CompressedSerializer(PickleSerializer())
-        self._test_serializer(ser)
-
-    def test_large_object_serializer(self):
-        ser = LargeObjectSerializer()
-        self._test_serializer(ser)
+        self.assertEqual(["abc", u"123", range(5)] + range(1000), list(ser.load_stream(io)))
 
 
 class PySparkTestCase(unittest.TestCase):
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index e1552a0b0b..7e5343c973 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -30,8 +30,7 @@ from pyspark.accumulators import _accumulatorRegistry
 from pyspark.broadcast import Broadcast, _broadcastRegistry
 from pyspark.files import SparkFiles
 from pyspark.serializers import write_with_length, write_int, read_long, \
-    write_long, read_int, SpecialLengths, UTF8Deserializer, PickleSerializer, \
-    SizeLimitedStream, LargeObjectSerializer
+    write_long, read_int, SpecialLengths, UTF8Deserializer, PickleSerializer
 from pyspark import shuffle
 
 pickleSer = PickleSerializer()
@@ -78,14 +77,11 @@ def main(infile, outfile):
 
     # fetch names and values of broadcast variables
     num_broadcast_variables = read_int(infile)
-    bser = LargeObjectSerializer()
     for _ in range(num_broadcast_variables):
         bid = read_long(infile)
         if bid >= 0:
-            size = read_long(infile)
-            s = SizeLimitedStream(infile, size)
-            value = list((bser.load_stream(s)))[0]  # read out all the bytes
-            _broadcastRegistry[bid] = Broadcast(bid, value)
+            path = utf8_deserializer.loads(infile)
+            _broadcastRegistry[bid] = Broadcast(path=path)
         else:
             bid = - bid - 1
             _broadcastRegistry.pop(bid)
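Editor's note: on the wire, the JVM now sends only a broadcast id plus the path of the dumped file for each new broadcast; the worker resolves the value from local disk. A hedged, standalone sketch of that framing (assumed helper names, mirroring the big-endian long/int framing used by pyspark.serializers):

import struct


def read_long(stream):
    return struct.unpack("!q", stream.read(8))[0]


def read_utf8(stream):
    length = struct.unpack("!i", stream.read(4))[0]
    return stream.read(length).decode("utf-8")


def read_broadcast_updates(stream, num_broadcasts):
    # a non-negative id is followed by the local path of the dumped value;
    # a negative id means "drop broadcast (-id - 1)"
    added, removed = {}, []
    for _ in range(num_broadcasts):
        bid = read_long(stream)
        if bid >= 0:
            added[bid] = read_utf8(stream)
        else:
            removed.append(-bid - 1)
    return added, removed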
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UdfRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/UdfRegistration.scala
index ddcb5db6c3..00d6b43a57 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/UdfRegistration.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/UdfRegistration.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql
 import java.util.{List => JList, Map => JMap}
 
 import org.apache.spark.Accumulator
+import org.apache.spark.api.python.PythonBroadcast
 import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.sql.catalyst.ScalaReflection
 import org.apache.spark.sql.catalyst.expressions.{Expression, ScalaUdf}
@@ -39,7 +40,7 @@ private[sql] trait UDFRegistration {
       envVars: JMap[String, String],
       pythonIncludes: JList[String],
       pythonExec: String,
-      broadcastVars: JList[Broadcast[Array[Array[Byte]]]],
+      broadcastVars: JList[Broadcast[PythonBroadcast]],
       accumulator: Accumulator[JList[Array[Byte]]],
       stringDataType: String): Unit = {
     log.debug(
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala
index f98cae3f17..2b4a88d5e8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala
@@ -26,7 +26,7 @@ import scala.collection.JavaConverters._
 import net.razorvine.pickle.{Pickler, Unpickler}
 
 import org.apache.spark.annotation.DeveloperApi
-import org.apache.spark.api.python.PythonRDD
+import org.apache.spark.api.python.{PythonBroadcast, PythonRDD}
 import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.sql.catalyst.expressions.Row
 import org.apache.spark.sql.catalyst.expressions._
@@ -45,7 +45,7 @@ private[spark] case class PythonUDF(
     envVars: JMap[String, String],
     pythonIncludes: JList[String],
     pythonExec: String,
-    broadcastVars: JList[Broadcast[Array[Array[Byte]]]],
+    broadcastVars: JList[Broadcast[PythonBroadcast]],
     accumulator: Accumulator[JList[Array[Byte]]],
     dataType: DataType,
     children: Seq[Expression]) extends Expression with SparkLogging {