Diffstat (limited to 'core')
-rw-r--r--  core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala  76
1 file changed, 59 insertions(+), 17 deletions(-)
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index b1cec0f647..8d4a53b4ca 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -19,26 +19,27 @@ package org.apache.spark.api.python
import java.io._
import java.net._
-import java.util.{List => JList, ArrayList => JArrayList, Map => JMap, UUID, Collections}
-
-import org.apache.spark.input.PortableDataStream
+import java.util.{Collections, ArrayList => JArrayList, List => JList, Map => JMap}
import scala.collection.JavaConversions._
import scala.collection.mutable
import scala.language.existentials
import com.google.common.base.Charsets.UTF_8
-
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.io.compress.CompressionCodec
-import org.apache.hadoop.mapred.{InputFormat, OutputFormat, JobConf}
+import org.apache.hadoop.mapred.{InputFormat, JobConf, OutputFormat}
import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat, OutputFormat => NewOutputFormat}
+
import org.apache.spark._
-import org.apache.spark.api.java.{JavaSparkContext, JavaPairRDD, JavaRDD}
+import org.apache.spark.api.java.{JavaPairRDD, JavaRDD, JavaSparkContext}
import org.apache.spark.broadcast.Broadcast
+import org.apache.spark.input.PortableDataStream
import org.apache.spark.rdd.RDD
import org.apache.spark.util.Utils
+import scala.util.control.NonFatal
+
private[spark] class PythonRDD(
@transient parent: RDD[_],
command: Array[Byte],
@@ -341,21 +342,33 @@ private[spark] object PythonRDD extends Logging {
/**
* Adapter for calling SparkContext#runJob from Python.
*
- * This method will return an iterator of an array that contains all elements in the RDD
+ * This method will serve an iterator of an array that contains all elements in the RDD
* (effectively a collect()), but allows you to run on a certain subset of partitions,
* or to enable local execution.
+ *
+ * @return the port number of a local socket which serves the data collected from this job.
*/
def runJob(
sc: SparkContext,
rdd: JavaRDD[Array[Byte]],
partitions: JArrayList[Int],
- allowLocal: Boolean): Iterator[Array[Byte]] = {
+ allowLocal: Boolean): Int = {
type ByteArray = Array[Byte]
type UnrolledPartition = Array[ByteArray]
val allPartitions: Array[UnrolledPartition] =
sc.runJob(rdd, (x: Iterator[ByteArray]) => x.toArray, partitions, allowLocal)
val flattenedPartition: UnrolledPartition = Array.concat(allPartitions: _*)
- flattenedPartition.iterator
+ serveIterator(flattenedPartition.iterator,
+ s"serve RDD ${rdd.id} with partitions ${partitions.mkString(",")}")
+ }
+
+ /**
+ * A helper function to collect an RDD as an iterator, then serve it via socket.
+ *
+ * @return the port number of a local socket which serves the data collected from this job.
+ */
+ def collectAndServe[T](rdd: RDD[T]): Int = {
+ serveIterator(rdd.collect().iterator, s"serve RDD ${rdd.id}")
}
def readRDDFromFile(sc: JavaSparkContext, filename: String, parallelism: Int):
@@ -575,15 +588,44 @@ private[spark] object PythonRDD extends Logging {
dataOut.write(bytes)
}
- def writeToFile[T](items: java.util.Iterator[T], filename: String) {
- import scala.collection.JavaConverters._
- writeToFile(items.asScala, filename)
- }
+ /**
+   * Create a socket server and a background thread to serve the data in `items`.
+   *
+   * The socket server accepts only one connection, and closes if no connection
+   * is made within 3 seconds.
+   *
+   * Once a connection comes in, it serializes all the data in `items`
+   * and sends it over that connection.
+   *
+   * The thread terminates once all the data has been sent, or if any exception occurs.
+ */
+ private def serveIterator[T](items: Iterator[T], threadName: String): Int = {
+ val serverSocket = new ServerSocket(0, 1)
+ serverSocket.setReuseAddress(true)
+    // Close the socket if no connection is made within 3 seconds
+ serverSocket.setSoTimeout(3000)
+
+ new Thread(threadName) {
+ setDaemon(true)
+ override def run() {
+ try {
+ val sock = serverSocket.accept()
+ val out = new DataOutputStream(new BufferedOutputStream(sock.getOutputStream))
+ try {
+ writeIteratorToStream(items, out)
+ } finally {
+ out.close()
+ }
+ } catch {
+ case NonFatal(e) =>
+ logError(s"Error while sending iterator", e)
+ } finally {
+ serverSocket.close()
+ }
+ }
+ }.start()
- def writeToFile[T](items: Iterator[T], filename: String) {
- val file = new DataOutputStream(new FileOutputStream(filename))
- writeIteratorToStream(items, file)
- file.close()
+ serverSocket.getLocalPort
}
private def getMergedConf(confAsMap: java.util.HashMap[String, String],
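
For context, a minimal, hypothetical sketch (not part of this patch) of the consuming side of `serveIterator`: it assumes the length-prefixed framing that `writeIteratorToStream` appears to use for `Array[Byte]` elements (a 4-byte length followed by the payload), and assumes the client connects to the returned port on localhost before the 3-second accept timeout expires. The object and method names are illustrative only.

    import java.io.{BufferedInputStream, DataInputStream, EOFException}
    import java.net.Socket

    import scala.collection.mutable.ArrayBuffer

    object ServeIteratorClient {
      /** Connect to the port returned by serveIterator and read length-prefixed records. */
      def readServedRecords(port: Int): Seq[Array[Byte]] = {
        val sock = new Socket("localhost", port)
        val in = new DataInputStream(new BufferedInputStream(sock.getInputStream))
        val records = ArrayBuffer[Array[Byte]]()
        try {
          while (true) {
            // Assumed framing: a 4-byte big-endian length followed by the payload.
            val length = in.readInt()
            val buf = new Array[Byte](length)
            in.readFully(buf)
            records += buf
          }
        } catch {
          case _: EOFException => // the server closes the stream after the last record
        } finally {
          in.close()
          sock.close()
        }
        records
      }
    }

In PySpark itself the Python worker performs this read; the sketch only illustrates the wire-level contract implied by the serveIterator comment above.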