path: root/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
Diffstat (limited to 'core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala')
-rw-r--r--  core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala | 80
1 file changed, 62 insertions(+), 18 deletions(-)
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index f423b2ee56..ab5b6c8380 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -59,7 +59,7 @@ private[spark] class PythonRDD(
val asJavaRDD: JavaRDD[Array[Byte]] = JavaRDD.fromRDD(this)
override def compute(split: Partition, context: TaskContext): Iterator[Array[Byte]] = {
- val runner = new PythonRunner(func, bufferSize, reuse_worker)
+ val runner = PythonRunner(func, bufferSize, reuse_worker)
runner.compute(firstParent.iterator(split, context), split.index, context)
}
}
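For reference, the factory call above is equivalent to constructing the runner directly with the new chained representation; this sketch merely restates the companion object introduced in the next hunk and implies no new behaviour:

    val runner = new PythonRunner(
      Seq(ChainedPythonFunctions(Seq(func))),  // a single chain holding the one function
      bufferSize,
      reuse_worker,
      false,                                   // isUDF = false: plain mapPartitions-style command
      Array(Array(0)))                         // placeholder offsets for the non-UDF case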
@@ -78,17 +78,41 @@ private[spark] case class PythonFunction(
accumulator: Accumulator[JList[Array[Byte]]])
/**
- * A helper class to run Python UDFs in Spark.
+ * A wrapper for chained Python functions (from bottom to top).
+ * @param funcs the chained Python functions, applied from bottom (first) to top (last)
+ */
+private[spark] case class ChainedPythonFunctions(funcs: Seq[PythonFunction])
+
+private[spark] object PythonRunner {
+ def apply(func: PythonFunction, bufferSize: Int, reuse_worker: Boolean): PythonRunner = {
+ new PythonRunner(
+ Seq(ChainedPythonFunctions(Seq(func))), bufferSize, reuse_worker, false, Array(Array(0)))
+ }
+}
+
+/**
+ * A helper class to run Python mapPartitions functions and UDFs in Spark.
+ *
+ * funcs is a list of independent Python functions; each element is itself a chain of Python
+ * functions (applied from bottom to top).
*/
private[spark] class PythonRunner(
- func: PythonFunction,
+ funcs: Seq[ChainedPythonFunctions],
bufferSize: Int,
- reuse_worker: Boolean)
+ reuse_worker: Boolean,
+ isUDF: Boolean,
+ argOffsets: Array[Array[Int]])
extends Logging {
- private val envVars = func.envVars
- private val pythonExec = func.pythonExec
- private val accumulator = func.accumulator
+ require(funcs.length == argOffsets.length, "argOffsets should have the same length as funcs")
+
+ // All the Python functions should share the same executable, version and environment variables.
+ private val envVars = funcs.head.funcs.head.envVars
+ private val pythonExec = funcs.head.funcs.head.pythonExec
+ private val pythonVer = funcs.head.funcs.head.pythonVer
+
+ // TODO: support accumulators when there are multiple UDFs
+ private val accumulator = funcs.head.funcs.head.accumulator
def compute(
inputIterator: Iterator[_],
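As a hedged illustration of the new signature (the udf1/udf2/udf3 PythonFunction values, bufferSize and reuse_worker below are invented for this example, not taken from the patch), two independent UDFs, one of them a chain of two functions, would be wired up roughly as follows:

    // Hypothetical: evaluating udf1(a) and udf2(udf3(b, c)) over input columns (a, b, c)
    val runner = new PythonRunner(
      Seq(
        ChainedPythonFunctions(Seq(udf1)),         // produces the first result column
        ChainedPythonFunctions(Seq(udf3, udf2))),  // udf2 is applied on top of udf3's output
      bufferSize,
      reuse_worker,
      true,                                        // isUDF
      Array(
        Array(0),                                  // udf1 reads input offset 0
        Array(1, 2)))                              // the udf3/udf2 chain reads offsets 1 and 2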
@@ -228,10 +252,8 @@ private[spark] class PythonRunner(
@volatile private var _exception: Exception = null
- private val pythonVer = func.pythonVer
- private val pythonIncludes = func.pythonIncludes
- private val broadcastVars = func.broadcastVars
- private val command = func.command
+ private val pythonIncludes = funcs.flatMap(_.funcs.flatMap(_.pythonIncludes.asScala)).toSet
+ private val broadcastVars = funcs.flatMap(_.funcs.flatMap(_.broadcastVars.asScala))
setDaemon(true)
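Because includes and broadcast variables are now gathered from every function in every chain, dependencies shared by several UDFs collapse into a single entry; a minimal sketch of the flattening, with invented file names:

    val perFunctionIncludes = Seq(Seq("deps.zip", "udf1.py"), Seq("deps.zip", "udf2.py"))
    val pythonIncludes = perFunctionIncludes.flatten.toSet
    // Set("deps.zip", "udf1.py", "udf2.py"): the shared archive is shipped to the worker once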
@@ -256,13 +278,13 @@ private[spark] class PythonRunner(
// sparkFilesDir
PythonRDD.writeUTF(SparkFiles.getRootDirectory(), dataOut)
// Python includes (*.zip and *.egg files)
- dataOut.writeInt(pythonIncludes.size())
- for (include <- pythonIncludes.asScala) {
+ dataOut.writeInt(pythonIncludes.size)
+ for (include <- pythonIncludes) {
PythonRDD.writeUTF(include, dataOut)
}
// Broadcast variables
val oldBids = PythonRDD.getWorkerBroadcasts(worker)
- val newBids = broadcastVars.asScala.map(_.id).toSet
+ val newBids = broadcastVars.map(_.id).toSet
// number of different broadcasts
val toRemove = oldBids.diff(newBids)
val cnt = toRemove.size + newBids.diff(oldBids).size
@@ -272,7 +294,7 @@ private[spark] class PythonRunner(
dataOut.writeLong(- bid - 1) // bid >= 0
oldBids.remove(bid)
}
- for (broadcast <- broadcastVars.asScala) {
+ for (broadcast <- broadcastVars) {
if (!oldBids.contains(broadcast.id)) {
// send new broadcast
dataOut.writeLong(broadcast.id)
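The broadcast bookkeeping above can be traced with a small worked example (the ids are illustrative): suppose a reused worker already holds broadcasts {1, 2} and the current run needs {2, 3}:

    import scala.collection.mutable
    val oldBids = mutable.Set(1L, 2L)                     // broadcasts the worker already holds
    val newBids = Set(2L, 3L)                             // broadcasts this run needs
    val toRemove = oldBids.diff(newBids)                  // Set(1): removal is encoded as -1 - 1 = -2
    val cnt = toRemove.size + newBids.diff(oldBids).size  // 1 removal + 1 new broadcast = 2 entries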
@@ -282,8 +304,26 @@ private[spark] class PythonRunner(
}
dataOut.flush()
// Serialized command:
- dataOut.writeInt(command.length)
- dataOut.write(command)
+ if (isUDF) {
+ dataOut.writeInt(1)
+ dataOut.writeInt(funcs.length)
+ funcs.zip(argOffsets).foreach { case (chained, offsets) =>
+ dataOut.writeInt(offsets.length)
+ offsets.foreach { offset =>
+ dataOut.writeInt(offset)
+ }
+ dataOut.writeInt(chained.funcs.length)
+ chained.funcs.foreach { f =>
+ dataOut.writeInt(f.command.length)
+ dataOut.write(f.command)
+ }
+ }
+ } else {
+ dataOut.writeInt(0)
+ val command = funcs.head.funcs.head.command
+ dataOut.writeInt(command.length)
+ dataOut.write(command)
+ }
// Data values
PythonRDD.writeIteratorToStream(inputIterator, dataOut)
dataOut.writeInt(SpecialLengths.END_OF_DATA_SECTION)
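To make the framing above concrete, here is a hedged Scala sketch of a reader that mirrors the byte layout just written; the real consumer is the Python worker, and this helper exists only for illustration:

    def readCommands(in: java.io.DataInputStream): Unit = {
      if (in.readInt() == 1) {                                 // UDF mode
        val numUdfs = in.readInt()
        for (_ <- 0 until numUdfs) {
          val offsets = Array.fill(in.readInt())(in.readInt()) // argument offsets for this UDF
          val chainLength = in.readInt()
          for (_ <- 0 until chainLength) {
            val command = new Array[Byte](in.readInt())        // one pickled function in the chain
            in.readFully(command)
          }
        }
      } else {                                                 // plain mapPartitions mode
        val command = new Array[Byte](in.readInt())
        in.readFully(command)
      }
    }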
@@ -413,6 +453,10 @@ private[spark] object PythonRDD extends Logging {
serveIterator(rdd.collect().iterator, s"serve RDD ${rdd.id}")
}
+ def toLocalIteratorAndServe[T](rdd: RDD[T]): Int = {
+ serveIterator(rdd.toLocalIterator, s"serve toLocalIterator")
+ }
+
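toLocalIteratorAndServe is the JVM entry point behind a lazy, partition-at-a-time pull of an RDD back to Python; the Int it returns is whatever serveIterator returns, assumed here to be the port of the local socket serving the data. A hedged usage sketch with an invented rdd value:

    // Hypothetical driver-side call (e.g. reached via Py4J): serves the RDD lazily instead of
    // collecting it all at once, and returns the port to read the stream from.
    val port: Int = PythonRDD.toLocalIteratorAndServe(someRdd)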
def readRDDFromFile(sc: JavaSparkContext, filename: String, parallelism: Int):
JavaRDD[Array[Byte]] = {
val file = new DataInputStream(new FileInputStream(filename))
@@ -426,7 +470,7 @@ private[spark] object PythonRDD extends Logging {
objs.append(obj)
}
} catch {
- case eof: EOFException => {}
+ case eof: EOFException => // No-op
}
JavaRDD.fromRDD(sc.sc.parallelize(objs, parallelism))
} finally {