Diffstat (limited to 'core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala')
-rw-r--r-- | core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala | 80 |
1 file changed, 62 insertions, 18 deletions
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index f423b2ee56..ab5b6c8380 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -59,7 +59,7 @@ private[spark] class PythonRDD(
   val asJavaRDD: JavaRDD[Array[Byte]] = JavaRDD.fromRDD(this)

   override def compute(split: Partition, context: TaskContext): Iterator[Array[Byte]] = {
-    val runner = new PythonRunner(func, bufferSize, reuse_worker)
+    val runner = PythonRunner(func, bufferSize, reuse_worker)
     runner.compute(firstParent.iterator(split, context), split.index, context)
   }
 }
@@ -78,17 +78,41 @@ private[spark] case class PythonFunction(
     accumulator: Accumulator[JList[Array[Byte]]])

 /**
- * A helper class to run Python UDFs in Spark.
+ * A wrapper for chained Python functions (from bottom to top).
+ * @param funcs
  */
+private[spark] case class ChainedPythonFunctions(funcs: Seq[PythonFunction])
+
+private[spark] object PythonRunner {
+  def apply(func: PythonFunction, bufferSize: Int, reuse_worker: Boolean): PythonRunner = {
+    new PythonRunner(
+      Seq(ChainedPythonFunctions(Seq(func))), bufferSize, reuse_worker, false, Array(Array(0)))
+  }
+}
+
+/**
+ * A helper class to run Python mapPartition/UDFs in Spark.
+ *
+ * funcs is a list of independent Python functions, each one of them is a list of chained Python
+ * functions (from bottom to top).
+ */
 private[spark] class PythonRunner(
-    func: PythonFunction,
+    funcs: Seq[ChainedPythonFunctions],
     bufferSize: Int,
-    reuse_worker: Boolean)
+    reuse_worker: Boolean,
+    isUDF: Boolean,
+    argOffsets: Array[Array[Int]])
   extends Logging {

-  private val envVars = func.envVars
-  private val pythonExec = func.pythonExec
-  private val accumulator = func.accumulator
+  require(funcs.length == argOffsets.length, "argOffsets should have the same length as funcs")
+
+  // All the Python functions should have the same exec, version and envvars.
+  private val envVars = funcs.head.funcs.head.envVars
+  private val pythonExec = funcs.head.funcs.head.pythonExec
+  private val pythonVer = funcs.head.funcs.head.pythonVer
+
+  // TODO: support accumulator in multiple UDF
+  private val accumulator = funcs.head.funcs.head.accumulator

   def compute(
       inputIterator: Iterator[_],
@@ -228,10 +252,8 @@ private[spark] class PythonRunner(

     @volatile private var _exception: Exception = null

-    private val pythonVer = func.pythonVer
-    private val pythonIncludes = func.pythonIncludes
-    private val broadcastVars = func.broadcastVars
-    private val command = func.command
+    private val pythonIncludes = funcs.flatMap(_.funcs.flatMap(_.pythonIncludes.asScala)).toSet
+    private val broadcastVars = funcs.flatMap(_.funcs.flatMap(_.broadcastVars.asScala))

     setDaemon(true)

@@ -256,13 +278,13 @@ private[spark] class PythonRunner(
         // sparkFilesDir
         PythonRDD.writeUTF(SparkFiles.getRootDirectory(), dataOut)
         // Python includes (*.zip and *.egg files)
-        dataOut.writeInt(pythonIncludes.size())
-        for (include <- pythonIncludes.asScala) {
+        dataOut.writeInt(pythonIncludes.size)
+        for (include <- pythonIncludes) {
           PythonRDD.writeUTF(include, dataOut)
         }
         // Broadcast variables
         val oldBids = PythonRDD.getWorkerBroadcasts(worker)
-        val newBids = broadcastVars.asScala.map(_.id).toSet
+        val newBids = broadcastVars.map(_.id).toSet
         // number of different broadcasts
         val toRemove = oldBids.diff(newBids)
         val cnt = toRemove.size + newBids.diff(oldBids).size
@@ -272,7 +294,7 @@ private[spark] class PythonRunner(
           dataOut.writeLong(- bid - 1)  // bid >= 0
           oldBids.remove(bid)
         }
-        for (broadcast <- broadcastVars.asScala) {
+        for (broadcast <- broadcastVars) {
           if (!oldBids.contains(broadcast.id)) {
             // send new broadcast
             dataOut.writeLong(broadcast.id)
@@ -282,8 +304,26 @@ private[spark] class PythonRunner(
         }
         dataOut.flush()
         // Serialized command:
-        dataOut.writeInt(command.length)
-        dataOut.write(command)
+        if (isUDF) {
+          dataOut.writeInt(1)
+          dataOut.writeInt(funcs.length)
+          funcs.zip(argOffsets).foreach { case (chained, offsets) =>
+            dataOut.writeInt(offsets.length)
+            offsets.foreach { offset =>
+              dataOut.writeInt(offset)
+            }
+            dataOut.writeInt(chained.funcs.length)
+            chained.funcs.foreach { f =>
+              dataOut.writeInt(f.command.length)
+              dataOut.write(f.command)
+            }
+          }
+        } else {
+          dataOut.writeInt(0)
+          val command = funcs.head.funcs.head.command
+          dataOut.writeInt(command.length)
+          dataOut.write(command)
+        }
         // Data values
         PythonRDD.writeIteratorToStream(inputIterator, dataOut)
         dataOut.writeInt(SpecialLengths.END_OF_DATA_SECTION)
@@ -413,6 +453,10 @@ private[spark] object PythonRDD extends Logging {
     serveIterator(rdd.collect().iterator, s"serve RDD ${rdd.id}")
   }

+  def toLocalIteratorAndServe[T](rdd: RDD[T]): Int = {
+    serveIterator(rdd.toLocalIterator, s"serve toLocalIterator")
+  }
+
   def readRDDFromFile(sc: JavaSparkContext, filename: String, parallelism: Int):
   JavaRDD[Array[Byte]] = {
     val file = new DataInputStream(new FileInputStream(filename))
@@ -426,7 +470,7 @@ private[spark] object PythonRDD extends Logging {
         objs.append(obj)
       }
     } catch {
-      case eof: EOFException => {}
+      case eof: EOFException => // No-op
       }
       JavaRDD.fromRDD(sc.sc.parallelize(objs, parallelism))
     } finally {
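
The "from bottom to top" ordering means that, within one ChainedPythonFunctions, the first function in the Seq is applied first and every later function wraps the previous result. A minimal Scala sketch of that composition rule (the helper applyChain is hypothetical and only illustrates the intended evaluation order; it is not part of this patch):

    // Seq(f, g, h) evaluates as h(g(f(input))): index 0 is the bottom of the chain.
    def applyChain[A](chain: Seq[A => A])(input: A): A =
      chain.foldLeft(input)((acc, fn) => fn(acc))

For example, applyChain(Seq[Int => Int](_ + 1, _ * 2))(3) yields 8, i.e. (3 + 1) * 2.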
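With the new signature, a caller describes several independent UDFs at once: one ChainedPythonFunctions per UDF and one array of argument offsets per UDF, which is what the require on matching lengths enforces. A hedged sketch, assuming the code lives under the org.apache.spark package (these classes are private[spark]); the helper name twoUdfRunner, the buffer size, and the PythonFunction arguments f1, f2 and g are hypothetical:

    import org.apache.spark.api.python.{ChainedPythonFunctions, PythonFunction, PythonRunner}

    // Evaluate two independent UDFs in one pass over rows of three columns:
    //   udf1 = f1(col0)   udf2 = g(f2(col1, col2)), chained bottom to top
    def twoUdfRunner(f1: PythonFunction, f2: PythonFunction, g: PythonFunction): PythonRunner =
      new PythonRunner(
        Seq(ChainedPythonFunctions(Seq(f1)), ChainedPythonFunctions(Seq(f2, g))),
        bufferSize = 65536,
        reuse_worker = true,
        isUDF = true,
        argOffsets = Array(Array(0), Array(1, 2)))  // one offsets array per independent UDF

The new PythonRunner.apply factory is just the degenerate case of this: a single chain of one function, isUDF = false, and a single offset array Array(Array(0)).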
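For a reused worker, the writer only ships the difference between the broadcast ids the worker already holds (oldBids) and the ids this task needs (newBids); a removal is encoded as -(bid + 1) so it cannot collide with a real id (ids are >= 0), and an addition is written as the id followed by the value. A small sketch of that bookkeeping, using a hypothetical helper that is not in the patch:

    // For oldBids = {1, 2} and newBids = {2, 3}: remove bid 1 (written as -2) and add bid 3.
    def broadcastDelta(oldBids: Set[Long], newBids: Set[Long]): (Set[Long], Set[Long]) =
      (oldBids.diff(newBids), newBids.diff(oldBids))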
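The serialized-command section now starts with a marker: 1 selects the UDF layout (number of UDFs, then for each UDF its argument offsets followed by its chain of pickled commands), while 0 keeps the old single-command layout. For illustration only, a Scala reader that mirrors the writes above; the real consumer is the Python worker, so the function name and return shape here are hypothetical:

    import java.io.DataInputStream

    def readUdfCommands(in: DataInputStream): Seq[(Array[Int], Seq[Array[Byte]])] = {
      require(in.readInt() == 1, "expected the UDF marker")
      val numUdfs = in.readInt()
      (0 until numUdfs).map { _ =>
        val offsets = Array.fill(in.readInt())(in.readInt())  // arg offsets for this UDF
        val chained = (0 until in.readInt()).map { _ =>
          val cmd = new Array[Byte](in.readInt())              // one pickled command
          in.readFully(cmd)
          cmd
        }
        (offsets, chained)                                     // chain runs bottom to top
      }
    }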
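toLocalIteratorAndServe is the lazy counterpart of collectAndServe: collect() materializes the entire RDD on the driver before serving it, whereas RDD.toLocalIterator pulls one partition at a time, so the driver holds roughly one partition of data while streaming results to the Python side. A sketch of the trade-off through the same serveIterator endpoint (again assuming code under the org.apache.spark package; the helper names are hypothetical):

    import org.apache.spark.api.python.PythonRDD
    import org.apache.spark.rdd.RDD

    def serveAllAtOnce[T](rdd: RDD[T]): Int =            // driver holds the whole RDD
      PythonRDD.serveIterator(rdd.collect().iterator, s"serve RDD ${rdd.id}")

    def serveOnePartitionAtATime[T](rdd: RDD[T]): Int =  // driver holds ~one partition
      PythonRDD.serveIterator(rdd.toLocalIterator, "serve toLocalIterator")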