diff options
Diffstat (limited to 'core')
-rw-r--r-- | core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala | 37 |
1 files changed, 22 insertions, 15 deletions
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index f12e2dfafa..05d1c31a08 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -42,14 +42,8 @@ import org.apache.spark.util.{SerializableConfiguration, Utils} private[spark] class PythonRDD( parent: RDD[_], - command: Array[Byte], - envVars: JMap[String, String], - pythonIncludes: JList[String], - preservePartitoning: Boolean, - pythonExec: String, - pythonVer: String, - broadcastVars: JList[Broadcast[PythonBroadcast]], - accumulator: Accumulator[JList[Array[Byte]]]) + func: PythonFunction, + preservePartitoning: Boolean) extends RDD[Array[Byte]](parent) { val bufferSize = conf.getInt("spark.buffer.size", 65536) @@ -64,29 +58,37 @@ private[spark] class PythonRDD( val asJavaRDD: JavaRDD[Array[Byte]] = JavaRDD.fromRDD(this) override def compute(split: Partition, context: TaskContext): Iterator[Array[Byte]] = { - val runner = new PythonRunner( - command, envVars, pythonIncludes, pythonExec, pythonVer, broadcastVars, accumulator, - bufferSize, reuse_worker) + val runner = new PythonRunner(func, bufferSize, reuse_worker) runner.compute(firstParent.iterator(split, context), split.index, context) } } - /** - * A helper class to run Python UDFs in Spark. + * A wrapper for a Python function, contains all necessary context to run the function in Python + * runner. */ -private[spark] class PythonRunner( +private[spark] case class PythonFunction( command: Array[Byte], envVars: JMap[String, String], pythonIncludes: JList[String], pythonExec: String, pythonVer: String, broadcastVars: JList[Broadcast[PythonBroadcast]], - accumulator: Accumulator[JList[Array[Byte]]], + accumulator: Accumulator[JList[Array[Byte]]]) + +/** + * A helper class to run Python UDFs in Spark. + */ +private[spark] class PythonRunner( + func: PythonFunction, bufferSize: Int, reuse_worker: Boolean) extends Logging { + private val envVars = func.envVars + private val pythonExec = func.pythonExec + private val accumulator = func.accumulator + def compute( inputIterator: Iterator[_], partitionIndex: Int, @@ -225,6 +227,11 @@ private[spark] class PythonRunner( @volatile private var _exception: Exception = null + private val pythonVer = func.pythonVer + private val pythonIncludes = func.pythonIncludes + private val broadcastVars = func.broadcastVars + private val command = func.command + setDaemon(true) /** Contains the exception thrown while writing the parent iterator to the Python process. */ |