 core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala | 76
 python/pyspark/context.py                                       | 13
 python/pyspark/rdd.py                                           | 30
 python/pyspark/sql/dataframe.py                                 | 14
 4 files changed, 82 insertions(+), 51 deletions(-)
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index b1cec0f647..8d4a53b4ca 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -19,26 +19,27 @@ package org.apache.spark.api.python
import java.io._
import java.net._
-import java.util.{List => JList, ArrayList => JArrayList, Map => JMap, UUID, Collections}
-
-import org.apache.spark.input.PortableDataStream
+import java.util.{Collections, ArrayList => JArrayList, List => JList, Map => JMap}
import scala.collection.JavaConversions._
import scala.collection.mutable
import scala.language.existentials
import com.google.common.base.Charsets.UTF_8
-
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.io.compress.CompressionCodec
-import org.apache.hadoop.mapred.{InputFormat, OutputFormat, JobConf}
+import org.apache.hadoop.mapred.{InputFormat, JobConf, OutputFormat}
import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat, OutputFormat => NewOutputFormat}
+
import org.apache.spark._
-import org.apache.spark.api.java.{JavaSparkContext, JavaPairRDD, JavaRDD}
+import org.apache.spark.api.java.{JavaPairRDD, JavaRDD, JavaSparkContext}
import org.apache.spark.broadcast.Broadcast
+import org.apache.spark.input.PortableDataStream
import org.apache.spark.rdd.RDD
import org.apache.spark.util.Utils
+import scala.util.control.NonFatal
+
private[spark] class PythonRDD(
@transient parent: RDD[_],
command: Array[Byte],
@@ -341,21 +342,33 @@ private[spark] object PythonRDD extends Logging {
/**
* Adapter for calling SparkContext#runJob from Python.
*
- * This method will return an iterator of an array that contains all elements in the RDD
+ * This method will serve an iterator over an array that contains all elements in the RDD
* (effectively a collect()), but allows you to run on a certain subset of partitions,
* or to enable local execution.
+ *
+ * @return the port number of a local socket which serves the data collected from this job.
*/
def runJob(
sc: SparkContext,
rdd: JavaRDD[Array[Byte]],
partitions: JArrayList[Int],
- allowLocal: Boolean): Iterator[Array[Byte]] = {
+ allowLocal: Boolean): Int = {
type ByteArray = Array[Byte]
type UnrolledPartition = Array[ByteArray]
val allPartitions: Array[UnrolledPartition] =
sc.runJob(rdd, (x: Iterator[ByteArray]) => x.toArray, partitions, allowLocal)
val flattenedPartition: UnrolledPartition = Array.concat(allPartitions: _*)
- flattenedPartition.iterator
+ serveIterator(flattenedPartition.iterator,
+ s"serve RDD ${rdd.id} with partitions ${partitions.mkString(",")}")
+ }
+
+ /**
+ * A helper function to collect an RDD as an iterator, then serve it via a local socket.
+ *
+ * @return the port number of a local socket which serves the data collected from this job.
+ */
+ def collectAndServe[T](rdd: RDD[T]): Int = {
+ serveIterator(rdd.collect().iterator, s"serve RDD ${rdd.id}")
}
def readRDDFromFile(sc: JavaSparkContext, filename: String, parallelism: Int):
@@ -575,15 +588,44 @@ private[spark] object PythonRDD extends Logging {
dataOut.write(bytes)
}
- def writeToFile[T](items: java.util.Iterator[T], filename: String) {
- import scala.collection.JavaConverters._
- writeToFile(items.asScala, filename)
- }
+ /**
+ * Create a socket server and a background thread to serve the data in `items`.
+ *
+ * The socket server accepts only one connection and closes if no connection
+ * arrives within 3 seconds.
+ *
+ * Once a connection comes in, it serializes all of the data in `items` and
+ * sends it over that connection.
+ *
+ * The thread terminates after all the data has been sent or any exception occurs.
+ */
+ private def serveIterator[T](items: Iterator[T], threadName: String): Int = {
+ val serverSocket = new ServerSocket(0, 1)
+ serverSocket.setReuseAddress(true)
+ // Close the socket if no connection in 3 seconds
+ serverSocket.setSoTimeout(3000)
+
+ new Thread(threadName) {
+ setDaemon(true)
+ override def run() {
+ try {
+ val sock = serverSocket.accept()
+ val out = new DataOutputStream(new BufferedOutputStream(sock.getOutputStream))
+ try {
+ writeIteratorToStream(items, out)
+ } finally {
+ out.close()
+ }
+ } catch {
+ case NonFatal(e) =>
+ logError(s"Error while sending iterator", e)
+ } finally {
+ serverSocket.close()
+ }
+ }
+ }.start()
- def writeToFile[T](items: Iterator[T], filename: String) {
- val file = new DataOutputStream(new FileOutputStream(filename))
- writeIteratorToStream(items, file)
- file.close()
+ serverSocket.getLocalPort
}
private def getMergedConf(confAsMap: java.util.HashMap[String, String],
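
The serveIterator helper above replaces the temp-file handoff: the JVM binds an ephemeral port, serves the collected bytes to a single client, and hands only the port number back to Python. Below is a minimal Python sketch of that same pattern, not Spark code; the name serve_iterator and the use of pickle in place of Spark's serializers are illustrative only.

import pickle
import socket
import threading


def serve_iterator(items):
    # Bind an ephemeral port with a backlog of 1, like `new ServerSocket(0, 1)`.
    server = socket.socket()
    server.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
    server.bind(("localhost", 0))
    server.listen(1)
    server.settimeout(3.0)          # close if no client connects within 3 seconds
    port = server.getsockname()[1]

    def run():
        try:
            conn, _ = server.accept()
            out = conn.makefile("wb", 65536)
            try:
                for item in items:  # stream the serialized items to the client
                    pickle.dump(item, out)
            finally:
                out.close()
                conn.close()
        except Exception as e:      # mirrors the NonFatal catch + logError
            print("Error while sending iterator: %s" % e)
        finally:
            server.close()

    t = threading.Thread(target=run, name="serve iterator sketch")
    t.daemon = True                 # matches setDaemon(true) in the Scala code
    t.start()
    return port                     # the caller only needs the port number

Binding to port 0 lets the OS choose a free port, and the 3-second accept timeout ensures the daemon thread and the socket are cleaned up even if the Python side never connects.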
diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index 6011caf9f1..78dccc4047 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -21,6 +21,8 @@ import sys
from threading import Lock
from tempfile import NamedTemporaryFile
+from py4j.java_collections import ListConverter
+
from pyspark import accumulators
from pyspark.accumulators import Accumulator
from pyspark.broadcast import Broadcast
@@ -30,13 +32,11 @@ from pyspark.java_gateway import launch_gateway
from pyspark.serializers import PickleSerializer, BatchedSerializer, UTF8Deserializer, \
PairDeserializer, AutoBatchedSerializer, NoOpSerializer
from pyspark.storagelevel import StorageLevel
-from pyspark.rdd import RDD
+from pyspark.rdd import RDD, _load_from_socket
from pyspark.traceback_utils import CallSite, first_spark_call
from pyspark.status import StatusTracker
from pyspark.profiler import ProfilerCollector, BasicProfiler
-from py4j.java_collections import ListConverter
-
__all__ = ['SparkContext']
@@ -59,7 +59,6 @@ class SparkContext(object):
_gateway = None
_jvm = None
- _writeToFile = None
_next_accum_id = 0
_active_spark_context = None
_lock = Lock()
@@ -221,7 +220,6 @@ class SparkContext(object):
if not SparkContext._gateway:
SparkContext._gateway = gateway or launch_gateway()
SparkContext._jvm = SparkContext._gateway.jvm
- SparkContext._writeToFile = SparkContext._jvm.PythonRDD.writeToFile
if instance:
if (SparkContext._active_spark_context and
@@ -840,8 +838,9 @@ class SparkContext(object):
# by runJob() in order to avoid having to pass a Python lambda into
# SparkContext#runJob.
mappedRDD = rdd.mapPartitions(partitionFunc)
- it = self._jvm.PythonRDD.runJob(self._jsc.sc(), mappedRDD._jrdd, javaPartitions, allowLocal)
- return list(mappedRDD._collect_iterator_through_file(it))
+ port = self._jvm.PythonRDD.runJob(self._jsc.sc(), mappedRDD._jrdd, javaPartitions,
+ allowLocal)
+ return list(_load_from_socket(port, mappedRDD._jrdd_deserializer))
def show_profiles(self):
""" Print the profile stats to stdout """
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index cb12fed98c..bf17f513c0 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -19,7 +19,6 @@ import copy
from collections import defaultdict
from itertools import chain, ifilter, imap
import operator
-import os
import sys
import shlex
from subprocess import Popen, PIPE
@@ -29,6 +28,7 @@ import warnings
import heapq
import bisect
import random
+import socket
from math import sqrt, log, isinf, isnan, pow, ceil
from pyspark.serializers import NoOpSerializer, CartesianDeserializer, \
@@ -111,6 +111,17 @@ def _parse_memory(s):
return int(float(s[:-1]) * units[s[-1].lower()])
+def _load_from_socket(port, serializer):
+ sock = socket.socket()
+ try:
+ sock.connect(("localhost", port))
+ rf = sock.makefile("rb", 65536)
+ for item in serializer.load_stream(rf):
+ yield item
+ finally:
+ sock.close()
+
+
class Partitioner(object):
def __init__(self, numPartitions, partitionFunc):
self.numPartitions = numPartitions
@@ -698,21 +709,8 @@ class RDD(object):
Return a list that contains all of the elements in this RDD.
"""
with SCCallSiteSync(self.context) as css:
- bytesInJava = self._jrdd.collect().iterator()
- return list(self._collect_iterator_through_file(bytesInJava))
-
- def _collect_iterator_through_file(self, iterator):
- # Transferring lots of data through Py4J can be slow because
- # socket.readline() is inefficient. Instead, we'll dump the data to a
- # file and read it back.
- tempFile = NamedTemporaryFile(delete=False, dir=self.ctx._temp_dir)
- tempFile.close()
- self.ctx._writeToFile(iterator, tempFile.name)
- # Read the data into Python and deserialize it:
- with open(tempFile.name, 'rb') as tempFile:
- for item in self._jrdd_deserializer.load_stream(tempFile):
- yield item
- os.unlink(tempFile.name)
+ port = self.ctx._jvm.PythonRDD.collectAndServe(self._jrdd.rdd())
+ return list(_load_from_socket(port, self._jrdd_deserializer))
def reduce(self, f):
"""
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 5c3b7377c3..e8ce454745 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -19,13 +19,11 @@ import sys
import itertools
import warnings
import random
-import os
-from tempfile import NamedTemporaryFile
from py4j.java_collections import ListConverter, MapConverter
from pyspark.context import SparkContext
-from pyspark.rdd import RDD
+from pyspark.rdd import RDD, _load_from_socket
from pyspark.serializers import BatchedSerializer, PickleSerializer, UTF8Deserializer
from pyspark.storagelevel import StorageLevel
from pyspark.traceback_utils import SCCallSiteSync
@@ -310,14 +308,8 @@ class DataFrame(object):
[Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')]
"""
with SCCallSiteSync(self._sc) as css:
- bytesInJava = self._jdf.javaToPython().collect().iterator()
- tempFile = NamedTemporaryFile(delete=False, dir=self._sc._temp_dir)
- tempFile.close()
- self._sc._writeToFile(bytesInJava, tempFile.name)
- # Read the data into Python and deserialize it:
- with open(tempFile.name, 'rb') as tempFile:
- rs = list(BatchedSerializer(PickleSerializer()).load_stream(tempFile))
- os.unlink(tempFile.name)
+ port = self._sc._jvm.PythonRDD.collectAndServe(self._jdf.javaToPython().rdd())
+ rs = list(_load_from_socket(port, BatchedSerializer(PickleSerializer())))
cls = _create_cls(self.schema)
return [cls(r) for r in rs]
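
DataFrame.collect() now goes through the same PythonRDD.collectAndServe path, so its user-visible behavior is unchanged. A hedged usage sketch, assuming a local master and the 1.3-era SQLContext API:

from pyspark import SparkContext
from pyspark.sql import SQLContext, Row

sc = SparkContext("local[2]", "collect-demo")
sqlContext = SQLContext(sc)

df = sqlContext.createDataFrame([Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')])
# Rows now come back through PythonRDD.collectAndServe and a local socket;
# the result is the same as before.
print(df.collect())  # [Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')]

sc.stop()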