 core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala         |  24
 python/pyspark/broadcast.py                                             |   4
 python/pyspark/context.py                                               |   5
 python/pyspark/serializers.py                                           | 185
 python/pyspark/tests.py                                                 |  52
 python/pyspark/worker.py                                                |   8
 python/run-tests                                                        |   2
 sql/core/src/main/scala/org/apache/spark/sql/UdfRegistration.scala      |   2
 sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala |   2
 9 files changed, 257 insertions(+), 27 deletions(-)
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index 45beb8fc8c..b80c771d58 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -47,7 +47,7 @@ private[spark] class PythonRDD(
pythonIncludes: JList[String],
preservePartitoning: Boolean,
pythonExec: String,
- broadcastVars: JList[Broadcast[Array[Byte]]],
+ broadcastVars: JList[Broadcast[Array[Array[Byte]]]],
accumulator: Accumulator[JList[Array[Byte]]])
extends RDD[Array[Byte]](parent) {
@@ -230,8 +230,8 @@ private[spark] class PythonRDD(
if (!oldBids.contains(broadcast.id)) {
// send new broadcast
dataOut.writeLong(broadcast.id)
- dataOut.writeInt(broadcast.value.length)
- dataOut.write(broadcast.value)
+ dataOut.writeLong(broadcast.value.map(_.length.toLong).sum)
+ broadcast.value.foreach(dataOut.write)
oldBids.add(broadcast.id)
}
}
@@ -368,16 +368,24 @@ private[spark] object PythonRDD extends Logging {
}
}
- def readBroadcastFromFile(sc: JavaSparkContext, filename: String): Broadcast[Array[Byte]] = {
+ def readBroadcastFromFile(
+ sc: JavaSparkContext,
+ filename: String): Broadcast[Array[Array[Byte]]] = {
+ val size = new File(filename).length()
val file = new DataInputStream(new FileInputStream(filename))
+ val blockSize = 1 << 20
+ val n = ((size + blockSize - 1) / blockSize).toInt
+ val obj = new Array[Array[Byte]](n)
try {
- val length = file.readInt()
- val obj = new Array[Byte](length)
- file.readFully(obj)
- sc.broadcast(obj)
+ for (i <- 0 until n) {
+ val length = if (i < (n - 1)) blockSize else (size % blockSize).toInt
+ obj(i) = new Array[Byte](length)
+ file.readFully(obj(i))
+ }
} finally {
file.close()
}
+ sc.broadcast(obj)
}
def writeIteratorToStream[T](iter: Iterator[T], dataOut: DataOutputStream) {
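
For reference, a minimal Python-side sketch of the framing the rewritten broadcast writer above produces. The helper read_broadcast_payload is hypothetical and not part of the patch; it assumes the JVM writes the broadcast id as a big-endian signed 64-bit long, then the summed block length as a long, then the raw 1 MB blocks back to back, as in the hunk above:

    import struct

    def read_broadcast_payload(infile):
        """Consume one broadcast record framed as: long id, long total size, raw bytes."""
        bid = struct.unpack("!q", infile.read(8))[0]      # broadcast id
        total = struct.unpack("!q", infile.read(8))[0]    # sum of all block lengths
        chunks, remaining = [], total
        while remaining > 0:
            chunk = infile.read(min(remaining, 1 << 20))  # read at most 1 MB at a time
            if not chunk:
                raise EOFError("stream ended before %d bytes were read" % total)
            chunks.append(chunk)
            remaining -= len(chunk)
        return bid, "".join(chunks)
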
diff --git a/python/pyspark/broadcast.py b/python/pyspark/broadcast.py
index f124dc6c07..01cac3c72c 100644
--- a/python/pyspark/broadcast.py
+++ b/python/pyspark/broadcast.py
@@ -29,7 +29,7 @@
"""
import os
-from pyspark.serializers import CompressedSerializer, PickleSerializer
+from pyspark.serializers import LargeObjectSerializer
__all__ = ['Broadcast']
@@ -73,7 +73,7 @@ class Broadcast(object):
""" Return the broadcasted value
"""
if not hasattr(self, "_value") and self.path is not None:
- ser = CompressedSerializer(PickleSerializer())
+ ser = LargeObjectSerializer()
self._value = ser.load_stream(open(self.path)).next()
return self._value
diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index b6c991453d..ec67ec8d0f 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -29,7 +29,7 @@ from pyspark.conf import SparkConf
from pyspark.files import SparkFiles
from pyspark.java_gateway import launch_gateway
from pyspark.serializers import PickleSerializer, BatchedSerializer, UTF8Deserializer, \
- PairDeserializer, CompressedSerializer, AutoBatchedSerializer, NoOpSerializer
+ PairDeserializer, AutoBatchedSerializer, NoOpSerializer, LargeObjectSerializer
from pyspark.storagelevel import StorageLevel
from pyspark.rdd import RDD
from pyspark.traceback_utils import CallSite, first_spark_call
@@ -624,7 +624,8 @@ class SparkContext(object):
object for reading it in distributed functions. The variable will
be sent to each cluster only once.
"""
- ser = CompressedSerializer(PickleSerializer())
+ ser = LargeObjectSerializer()
+
# pass large object by py4j is very slow and need much memory
tempFile = NamedTemporaryFile(delete=False, dir=self._temp_dir)
ser.dump_stream([value], tempFile)
diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py
index d597cbf94e..760a509f0e 100644
--- a/python/pyspark/serializers.py
+++ b/python/pyspark/serializers.py
@@ -133,6 +133,8 @@ class FramedSerializer(Serializer):
def _write_with_length(self, obj, stream):
serialized = self.dumps(obj)
+ if len(serialized) > (1 << 31):
+ raise ValueError("can not serialize object larger than 2G")
write_int(len(serialized), stream)
if self._only_write_strings:
stream.write(str(serialized))
@@ -446,20 +448,184 @@ class AutoSerializer(FramedSerializer):
raise ValueError("invalid sevialization type: %s" % _type)
-class CompressedSerializer(FramedSerializer):
+class SizeLimitedStream(object):
"""
- Compress the serialized data
+ Read at most `limit` bytes from underlying stream
+
+ >>> from StringIO import StringIO
+ >>> io = StringIO()
+ >>> io.write("Hello world")
+ >>> io.seek(0)
+ >>> lio = SizeLimitedStream(io, 5)
+ >>> lio.read()
+ 'Hello'
+ """
+ def __init__(self, stream, limit):
+ self.stream = stream
+ self.limit = limit
+
+ def read(self, n=0):
+ if n > self.limit or n == 0:
+ n = self.limit
+ buf = self.stream.read(n)
+ self.limit -= len(buf)
+ return buf
+
+
+class CompressedStream(object):
+ """
+ Compress the data using zlib
+
+ >>> from StringIO import StringIO
+ >>> io = StringIO()
+ >>> wio = CompressedStream(io, 'w')
+ >>> wio.write("Hello world")
+ >>> wio.flush()
+ >>> io.seek(0)
+ >>> rio = CompressedStream(io, 'r')
+ >>> rio.read()
+ 'Hello world'
+ >>> rio.read()
+ ''
+ """
+ MAX_BATCH = 1 << 20 # 1MB
+
+ def __init__(self, stream, mode='w', level=1):
+ self.stream = stream
+ self.mode = mode
+ if mode == 'w':
+ self.compresser = zlib.compressobj(level)
+ elif mode == 'r':
+ self.decompresser = zlib.decompressobj()
+ self.buf = ''
+ else:
+ raise ValueError("can only support mode 'w' or 'r' ")
+
+ def write(self, buf):
+ assert self.mode == 'w', "It's not opened for write"
+ if len(buf) > self.MAX_BATCH:
+ # zlib can not compress string larger than 2G
+ batches = len(buf) / self.MAX_BATCH + 1 # last one may be empty
+ for i in xrange(batches):
+ self.write(buf[i * self.MAX_BATCH:(i + 1) * self.MAX_BATCH])
+ else:
+ compressed = self.compresser.compress(buf)
+ self.stream.write(compressed)
+
+ def flush(self, mode=zlib.Z_FULL_FLUSH):
+ if self.mode == 'w':
+ d = self.compresser.flush(mode)
+ self.stream.write(d)
+ self.stream.flush()
+
+ def close(self):
+ if self.mode == 'w':
+ self.flush(zlib.Z_FINISH)
+ self.stream.close()
+
+ def read(self, size=0):
+ assert self.mode == 'r', "It's not opened for read"
+ if not size:
+ data = self.stream.read()
+ result = self.decompresser.decompress(data)
+ last = self.decompresser.flush()
+ return self.buf + result + last
+
+ # fast path for small read()
+ if size <= len(self.buf):
+ result = self.buf[:size]
+ self.buf = self.buf[size:]
+ return result
+
+ result = [self.buf]
+ size -= len(self.buf)
+ self.buf = ''
+ while size:
+ need = min(size, self.MAX_BATCH)
+ input = self.stream.read(need)
+ if input:
+ buf = self.decompresser.decompress(input)
+ else:
+ buf = self.decompresser.flush()
+
+ if len(buf) >= size:
+ self.buf = buf[size:]
+ result.append(buf[:size])
+ return ''.join(result)
+
+ size -= len(buf)
+ result.append(buf)
+ if not input:
+ return ''.join(result)
+
+ def readline(self):
+ """
+ This is needed for pickle, but not used in protocol 2
+ """
+ line = []
+ b = self.read(1)
+ while b and b != '\n':
+ line.append(b)
+ b = self.read(1)
+ line.append(b)
+ return ''.join(line)
+
+
+class LargeObjectSerializer(Serializer):
+ """
+ Serialize large object which could be larger than 2G
+
+ It uses cPickle to serialize the objects
"""
+ def dump_stream(self, iterator, stream):
+ stream = CompressedStream(stream, 'w')
+ for value in iterator:
+ if isinstance(value, basestring):
+ if isinstance(value, unicode):
+ stream.write('U')
+ value = value.encode("utf-8")
+ else:
+ stream.write('S')
+ write_long(len(value), stream)
+ stream.write(value)
+ else:
+ stream.write('P')
+ cPickle.dump(value, stream, 2)
+ stream.flush()
+ def load_stream(self, stream):
+ stream = CompressedStream(stream, 'r')
+ while True:
+ type = stream.read(1)
+ if not type:
+ return
+ if type in ('S', 'U'):
+ length = read_long(stream)
+ value = stream.read(length)
+ if type == 'U':
+ value = value.decode('utf-8')
+ yield value
+ elif type == 'P':
+ yield cPickle.load(stream)
+ else:
+ raise ValueError("unknown type: %s" % type)
+
+
+class CompressedSerializer(Serializer):
+ """
+ Compress the serialized data
+ """
def __init__(self, serializer):
- FramedSerializer.__init__(self)
self.serializer = serializer
- def dumps(self, obj):
- return zlib.compress(self.serializer.dumps(obj), 1)
+ def load_stream(self, stream):
+ stream = CompressedStream(stream, "r")
+ return self.serializer.load_stream(stream)
- def loads(self, obj):
- return self.serializer.loads(zlib.decompress(obj))
+ def dump_stream(self, iterator, stream):
+ stream = CompressedStream(stream, "w")
+ self.serializer.dump_stream(iterator, stream)
+ stream.flush()
class UTF8Deserializer(Serializer):
@@ -517,3 +683,8 @@ def write_int(value, stream):
def write_with_length(obj, stream):
write_int(len(obj), stream)
stream.write(obj)
+
+
+if __name__ == '__main__':
+ import doctest
+ doctest.testmod()
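
A minimal usage sketch of the new LargeObjectSerializer, mirroring how SparkContext.broadcast and Broadcast.value use it through a temp file (illustrative only, not part of the patch; the literal value and the file handling are made up for the example):

    from tempfile import NamedTemporaryFile
    from pyspark.serializers import LargeObjectSerializer

    ser = LargeObjectSerializer()

    # Writer side: SparkContext.broadcast dumps the single value into a temp file.
    f = NamedTemporaryFile(delete=False)
    ser.dump_stream([{"weights": range(10)}], f)
    f.close()

    # Reader side: Broadcast.value lazily loads it back from self.path.
    value = ser.load_stream(open(f.name)).next()
    assert value == {"weights": range(10)}
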
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index 491e445a21..a01bd8d415 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -32,6 +32,7 @@ import time
import zipfile
import random
import threading
+import hashlib
if sys.version_info[:2] <= (2, 6):
try:
@@ -47,7 +48,7 @@ from pyspark.conf import SparkConf
from pyspark.context import SparkContext
from pyspark.files import SparkFiles
from pyspark.serializers import read_int, BatchedSerializer, MarshalSerializer, PickleSerializer, \
- CloudPickleSerializer
+ CloudPickleSerializer, SizeLimitedStream, CompressedSerializer, LargeObjectSerializer
from pyspark.shuffle import Aggregator, InMemoryMerger, ExternalMerger, ExternalSorter
from pyspark.sql import SQLContext, IntegerType, Row, ArrayType, StructType, StructField, \
UserDefinedType, DoubleType
@@ -236,6 +237,27 @@ class SerializationTestCase(unittest.TestCase):
self.assertTrue("exit" in foo.func_code.co_names)
ser.dumps(foo)
+ def _test_serializer(self, ser):
+ from StringIO import StringIO
+ io = StringIO()
+ ser.dump_stream(["abc", u"123", range(5)], io)
+ io.seek(0)
+ self.assertEqual(["abc", u"123", range(5)], list(ser.load_stream(io)))
+ size = io.tell()
+ ser.dump_stream(range(1000), io)
+ io.seek(0)
+ first = SizeLimitedStream(io, size)
+ self.assertEqual(["abc", u"123", range(5)], list(ser.load_stream(first)))
+ self.assertEqual(range(1000), list(ser.load_stream(io)))
+
+ def test_compressed_serializer(self):
+ ser = CompressedSerializer(PickleSerializer())
+ self._test_serializer(ser)
+
+ def test_large_object_serializer(self):
+ ser = LargeObjectSerializer()
+ self._test_serializer(ser)
+
class PySparkTestCase(unittest.TestCase):
@@ -440,7 +462,7 @@ class RDDTests(ReusedPySparkTestCase):
subset = data.takeSample(False, 10)
self.assertEqual(len(subset), 10)
- def testAggregateByKey(self):
+ def test_aggregate_by_key(self):
data = self.sc.parallelize([(1, 1), (1, 1), (3, 2), (5, 1), (5, 3)], 2)
def seqOp(x, y):
@@ -478,6 +500,32 @@ class RDDTests(ReusedPySparkTestCase):
m = self.sc.parallelize(range(1), 1).map(lambda x: len(bdata.value)).sum()
self.assertEquals(N, m)
+ def test_multiple_broadcasts(self):
+ N = 1 << 21
+ b1 = self.sc.broadcast(set(range(N))) # multiple blocks in JVM
+ r = range(1 << 15)
+ random.shuffle(r)
+ s = str(r)
+ checksum = hashlib.md5(s).hexdigest()
+ b2 = self.sc.broadcast(s)
+ r = list(set(self.sc.parallelize(range(10), 10).map(
+ lambda x: (len(b1.value), hashlib.md5(b2.value).hexdigest())).collect()))
+ self.assertEqual(1, len(r))
+ size, csum = r[0]
+ self.assertEqual(N, size)
+ self.assertEqual(checksum, csum)
+
+ random.shuffle(r)
+ s = str(r)
+ checksum = hashlib.md5(s).hexdigest()
+ b2 = self.sc.broadcast(s)
+ r = list(set(self.sc.parallelize(range(10), 10).map(
+ lambda x: (len(b1.value), hashlib.md5(b2.value).hexdigest())).collect()))
+ self.assertEqual(1, len(r))
+ size, csum = r[0]
+ self.assertEqual(N, size)
+ self.assertEqual(checksum, csum)
+
def test_large_closure(self):
N = 1000000
data = [float(i) for i in xrange(N)]
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index 2bdccb5e93..e1552a0b0b 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -31,7 +31,7 @@ from pyspark.broadcast import Broadcast, _broadcastRegistry
from pyspark.files import SparkFiles
from pyspark.serializers import write_with_length, write_int, read_long, \
write_long, read_int, SpecialLengths, UTF8Deserializer, PickleSerializer, \
- CompressedSerializer
+ SizeLimitedStream, LargeObjectSerializer
from pyspark import shuffle
pickleSer = PickleSerializer()
@@ -78,11 +78,13 @@ def main(infile, outfile):
# fetch names and values of broadcast variables
num_broadcast_variables = read_int(infile)
- ser = CompressedSerializer(pickleSer)
+ bser = LargeObjectSerializer()
for _ in range(num_broadcast_variables):
bid = read_long(infile)
if bid >= 0:
- value = ser._read_with_length(infile)
+ size = read_long(infile)
+ s = SizeLimitedStream(infile, size)
+ value = list((bser.load_stream(s)))[0] # read out all the bytes
_broadcastRegistry[bid] = Broadcast(bid, value)
else:
bid = - bid - 1
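
The SizeLimitedStream wrapper above matters because LargeObjectSerializer.load_stream otherwise keeps reading the underlying stream until it is exhausted, which would swallow bytes belonging to the next broadcast on the worker's input pipe. A standalone sketch of that length-bounded pattern, with StringIO standing in for the worker's socket (illustrative, not part of the patch):

    from StringIO import StringIO
    from pyspark.serializers import (LargeObjectSerializer, SizeLimitedStream,
                                     read_long, write_long)

    ser = LargeObjectSerializer()

    # Two broadcast payloads share one pipe; each is preceded by its byte length.
    pipe = StringIO()
    for value in ("first", "second"):
        buf = StringIO()
        ser.dump_stream([value], buf)
        write_long(len(buf.getvalue()), pipe)
        pipe.write(buf.getvalue())
    pipe.seek(0)

    # The reader must stop at each length boundary, exactly as worker.main does.
    for expected in ("first", "second"):
        size = read_long(pipe)
        bounded = SizeLimitedStream(pipe, size)
        assert list(ser.load_stream(bounded)) == [expected]
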
diff --git a/python/run-tests b/python/run-tests
index e66854b44d..9ee19ed6e6 100755
--- a/python/run-tests
+++ b/python/run-tests
@@ -56,7 +56,7 @@ function run_core_tests() {
run_test "pyspark/conf.py"
PYSPARK_DOC_TEST=1 run_test "pyspark/broadcast.py"
PYSPARK_DOC_TEST=1 run_test "pyspark/accumulators.py"
- PYSPARK_DOC_TEST=1 run_test "pyspark/serializers.py"
+ run_test "pyspark/serializers.py"
run_test "pyspark/shuffle.py"
run_test "pyspark/tests.py"
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UdfRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/UdfRegistration.scala
index 6d4c0d82ac..ddcb5db6c3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/UdfRegistration.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/UdfRegistration.scala
@@ -39,7 +39,7 @@ private[sql] trait UDFRegistration {
envVars: JMap[String, String],
pythonIncludes: JList[String],
pythonExec: String,
- broadcastVars: JList[Broadcast[Array[Byte]]],
+ broadcastVars: JList[Broadcast[Array[Array[Byte]]]],
accumulator: Accumulator[JList[Array[Byte]]],
stringDataType: String): Unit = {
log.debug(
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala
index a83cf5d441..f98cae3f17 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala
@@ -45,7 +45,7 @@ private[spark] case class PythonUDF(
envVars: JMap[String, String],
pythonIncludes: JList[String],
pythonExec: String,
- broadcastVars: JList[Broadcast[Array[Byte]]],
+ broadcastVars: JList[Broadcast[Array[Array[Byte]]]],
accumulator: Accumulator[JList[Array[Byte]]],
dataType: DataType,
children: Seq[Expression]) extends Expression with SparkLogging {