-rw-r--r--  core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala         | 54
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala | 80
2 files changed, 89 insertions(+), 45 deletions(-)
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index 69da180593..3788d18297 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -24,6 +24,7 @@ import java.util.{Collections, ArrayList => JArrayList, List => JList, Map => JM
import scala.collection.JavaConverters._
import scala.collection.mutable
import scala.language.existentials
+import scala.util.control.NonFatal
import com.google.common.base.Charsets.UTF_8
import org.apache.hadoop.conf.Configuration
@@ -38,7 +39,6 @@ import org.apache.spark.input.PortableDataStream
import org.apache.spark.rdd.RDD
import org.apache.spark.util.{SerializableConfiguration, Utils}
-import scala.util.control.NonFatal
private[spark] class PythonRDD(
parent: RDD[_],
@@ -61,11 +61,39 @@ private[spark] class PythonRDD(
if (preservePartitoning) firstParent.partitioner else None
}
+ val asJavaRDD: JavaRDD[Array[Byte]] = JavaRDD.fromRDD(this)
+
override def compute(split: Partition, context: TaskContext): Iterator[Array[Byte]] = {
+ val runner = new PythonRunner(
+ command, envVars, pythonIncludes, pythonExec, pythonVer, broadcastVars, accumulator,
+ bufferSize, reuse_worker)
+ runner.compute(firstParent.iterator(split, context), split.index, context)
+ }
+}
+
+
+/**
+ * A helper class to run Python UDFs in Spark.
+ */
+private[spark] class PythonRunner(
+ command: Array[Byte],
+ envVars: JMap[String, String],
+ pythonIncludes: JList[String],
+ pythonExec: String,
+ pythonVer: String,
+ broadcastVars: JList[Broadcast[PythonBroadcast]],
+ accumulator: Accumulator[JList[Array[Byte]]],
+ bufferSize: Int,
+ reuse_worker: Boolean)
+ extends Logging {
+
+ def compute(
+ inputIterator: Iterator[_],
+ partitionIndex: Int,
+ context: TaskContext): Iterator[Array[Byte]] = {
val startTime = System.currentTimeMillis
val env = SparkEnv.get
- val localdir = env.blockManager.diskBlockManager.localDirs.map(
- f => f.getPath()).mkString(",")
+ val localdir = env.blockManager.diskBlockManager.localDirs.map(f => f.getPath()).mkString(",")
envVars.put("SPARK_LOCAL_DIRS", localdir) // it's also used in monitor thread
if (reuse_worker) {
envVars.put("SPARK_REUSE_WORKER", "1")
@@ -75,7 +103,7 @@ private[spark] class PythonRDD(
@volatile var released = false
// Start a thread to feed the process input from our parent's iterator
- val writerThread = new WriterThread(env, worker, split, context)
+ val writerThread = new WriterThread(env, worker, inputIterator, partitionIndex, context)
context.addTaskCompletionListener { context =>
writerThread.shutdownOnTaskCompletion()
@@ -183,13 +211,16 @@ private[spark] class PythonRDD(
new InterruptibleIterator(context, stdoutIterator)
}
- val asJavaRDD : JavaRDD[Array[Byte]] = JavaRDD.fromRDD(this)
-
/**
* The thread responsible for writing the data from the PythonRDD's parent iterator to the
* Python process.
*/
- class WriterThread(env: SparkEnv, worker: Socket, split: Partition, context: TaskContext)
+ class WriterThread(
+ env: SparkEnv,
+ worker: Socket,
+ inputIterator: Iterator[_],
+ partitionIndex: Int,
+ context: TaskContext)
extends Thread(s"stdout writer for $pythonExec") {
@volatile private var _exception: Exception = null
@@ -211,11 +242,11 @@ private[spark] class PythonRDD(
val stream = new BufferedOutputStream(worker.getOutputStream, bufferSize)
val dataOut = new DataOutputStream(stream)
// Partition index
- dataOut.writeInt(split.index)
+ dataOut.writeInt(partitionIndex)
// Python version of driver
PythonRDD.writeUTF(pythonVer, dataOut)
// sparkFilesDir
- PythonRDD.writeUTF(SparkFiles.getRootDirectory, dataOut)
+ PythonRDD.writeUTF(SparkFiles.getRootDirectory(), dataOut)
// Python includes (*.zip and *.egg files)
dataOut.writeInt(pythonIncludes.size())
for (include <- pythonIncludes.asScala) {
@@ -246,7 +277,7 @@ private[spark] class PythonRDD(
dataOut.writeInt(command.length)
dataOut.write(command)
// Data values
- PythonRDD.writeIteratorToStream(firstParent.iterator(split, context), dataOut)
+ PythonRDD.writeIteratorToStream(inputIterator, dataOut)
dataOut.writeInt(SpecialLengths.END_OF_DATA_SECTION)
dataOut.writeInt(SpecialLengths.END_OF_STREAM)
dataOut.flush()
@@ -327,7 +358,8 @@ private[spark] object PythonRDD extends Logging {
// remember the broadcasts sent to each worker
private val workerBroadcasts = new mutable.WeakHashMap[Socket, mutable.Set[Long]]()
- private def getWorkerBroadcasts(worker: Socket) = {
+
+ def getWorkerBroadcasts(worker: Socket): mutable.Set[Long] = {
synchronized {
workerBroadcasts.getOrElseUpdate(worker, new mutable.HashSet[Long]())
}
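
The hunks above pull the worker handshake and stream protocol out of PythonRDD.compute into the reusable PythonRunner, so code inside the org.apache.spark package that holds a plain iterator can stream it through a Python worker without first wrapping it in an RDD. The following is a minimal sketch of such a caller, not part of the patch: the object name and package are hypothetical, the command bytes, accumulator and broadcast list are placeholders that in real code come from a serialized Python lambda, and it assumes it runs inside a live task so TaskContext.get() is non-null.

package org.apache.spark.example  // hypothetical; PythonRunner is private[spark]

import java.util.{ArrayList => JArrayList, HashMap => JHashMap, List => JList}

import org.apache.spark.{Accumulator, TaskContext}
import org.apache.spark.api.python.{PythonBroadcast, PythonRunner}
import org.apache.spark.broadcast.Broadcast

object PythonRunnerSketch {
  /** Feed one partition's serialized rows to a Python worker and return its output. */
  def runOnPartition(
      command: Array[Byte],                      // pickled Python function (placeholder)
      accumulator: Accumulator[JList[Array[Byte]]],
      broadcastVars: JList[Broadcast[PythonBroadcast]],
      input: Iterator[Array[Byte]]): Iterator[Array[Byte]] = {
    val context = TaskContext.get()              // only valid inside a running task
    val runner = new PythonRunner(
      command,
      new JHashMap[String, String](),            // envVars
      new JArrayList[String](),                  // pythonIncludes
      "python",                                  // pythonExec
      "2.7",                                     // pythonVer
      broadcastVars,
      accumulator,
      65536,                                     // bufferSize (spark.buffer.size default)
      false)                                     // reuse_worker
    // Same call PythonRDD.compute and BatchPythonEvaluation now make: stream the
    // iterator to the worker and lazily read the pickled results back.
    runner.compute(input, context.partitionId(), context)
  }
}

The bufferSize value mirrors the spark.buffer.size default that the SQL operator reads in the second file of this diff.
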
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala
index d0411da6fd..c35c726bfc 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala
@@ -25,7 +25,7 @@ import scala.collection.JavaConverters._
import net.razorvine.pickle._
import org.apache.spark.annotation.DeveloperApi
-import org.apache.spark.api.python.{PythonBroadcast, PythonRDD, SerDeUtil}
+import org.apache.spark.api.python.{PythonRunner, PythonBroadcast, PythonRDD, SerDeUtil}
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
@@ -35,7 +35,7 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
-import org.apache.spark.{Accumulator, Logging => SparkLogging}
+import org.apache.spark.{Logging => SparkLogging, TaskContext, Accumulator}
/**
* A serialized version of a Python lambda function. Suitable for use in a [[PythonRDD]].
@@ -329,7 +329,13 @@ case class EvaluatePython(
/**
* :: DeveloperApi ::
* Uses PythonRDD to evaluate a [[PythonUDF]], one partition of tuples at a time.
- * The input data is zipped with the result of the udf evaluation.
+ *
+ * Python evaluation works by sending the necessary (projected) input data via a socket to an
+ * external Python process, and combining the result from the Python process with the original row.
+ *
+ * For each row we send to Python, we also put it in a queue. For each output row from Python,
+ * we drain the queue to find the original input row. Note that if the Python process is way too
+ * slow, this could lead to the queue growing unbounded and eventually running out of memory.
*/
@DeveloperApi
case class BatchPythonEvaluation(udf: PythonUDF, output: Seq[Attribute], child: SparkPlan)
@@ -342,51 +348,57 @@ case class BatchPythonEvaluation(udf: PythonUDF, output: Seq[Attribute], child:
override def canProcessSafeRows: Boolean = true
protected override def doExecute(): RDD[InternalRow] = {
- val childResults = child.execute().map(_.copy())
+ val inputRDD = child.execute().map(_.copy())
+ val bufferSize = inputRDD.conf.getInt("spark.buffer.size", 65536)
+ val reuseWorker = inputRDD.conf.getBoolean("spark.python.worker.reuse", defaultValue = true)
- val parent = childResults.mapPartitions { iter =>
+ inputRDD.mapPartitions { iter =>
EvaluatePython.registerPicklers() // register pickler for Row
+
+ // The queue used to buffer input rows so we can drain it to
+ // combine input with output from Python.
+ val queue = new java.util.concurrent.ConcurrentLinkedQueue[InternalRow]()
+
val pickle = new Pickler
val currentRow = newMutableProjection(udf.children, child.output)()
val fields = udf.children.map(_.dataType)
val schema = new StructType(fields.map(t => new StructField("", t, true)).toArray)
- iter.grouped(100).map { inputRows =>
+
+ // Input iterator to Python: input rows are grouped so we send them in batches to Python.
+ // For each row, add it to the queue.
+ val inputIterator = iter.grouped(100).map { inputRows =>
val toBePickled = inputRows.map { row =>
+ queue.add(row)
EvaluatePython.toJava(currentRow(row), schema)
}.toArray
pickle.dumps(toBePickled)
}
- }
- val pyRDD = new PythonRDD(
- parent,
- udf.command,
- udf.envVars,
- udf.pythonIncludes,
- false,
- udf.pythonExec,
- udf.pythonVer,
- udf.broadcastVars,
- udf.accumulator
- ).mapPartitions { iter =>
- val pickle = new Unpickler
- iter.flatMap { pickedResult =>
- val unpickledBatch = pickle.loads(pickedResult)
- unpickledBatch.asInstanceOf[java.util.ArrayList[Any]].asScala
- }
- }.mapPartitions { iter =>
+ val context = TaskContext.get()
+
+ // Output iterator for results from Python.
+ val outputIterator = new PythonRunner(
+ udf.command,
+ udf.envVars,
+ udf.pythonIncludes,
+ udf.pythonExec,
+ udf.pythonVer,
+ udf.broadcastVars,
+ udf.accumulator,
+ bufferSize,
+ reuseWorker
+ ).compute(inputIterator, context.partitionId(), context)
+
+ val unpickle = new Unpickler
val row = new GenericMutableRow(1)
- iter.map { result =>
- row(0) = EvaluatePython.fromJava(result, udf.dataType)
- row: InternalRow
- }
- }
+ val joined = new JoinedRow
- childResults.zip(pyRDD).mapPartitions { iter =>
- val joinedRow = new JoinedRow()
- iter.map {
- case (row, udfResult) =>
- joinedRow(row, udfResult)
+ outputIterator.flatMap { pickedResult =>
+ val unpickledBatch = unpickle.loads(pickedResult)
+ unpickledBatch.asInstanceOf[java.util.ArrayList[Any]].asScala
+ }.map { result =>
+ row(0) = EvaluatePython.fromJava(result, udf.dataType)
+ joined(queue.poll(), row)
}
}
}
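
The queue-based pairing described in the new BatchPythonEvaluation doc comment is easy to see in isolation. The sketch below is an illustrative stand-in, not the actual operator: evalExternally plays the role of PythonRunner.compute plus unpickling, and the generic pair plays the role of the JoinedRow built at the end of the last hunk.

import java.util.concurrent.ConcurrentLinkedQueue

object QueueJoinSketch {
  // Every input element is enqueued just before it is handed to the external
  // evaluator; every result is paired with the head of the queue on the way out.
  // If the evaluator falls far behind, the queue (and memory) grows unbounded,
  // which is exactly the caveat noted in the doc comment.
  def joinWithResults[A, B](
      input: Iterator[A],
      evalExternally: Iterator[A] => Iterator[B]): Iterator[(A, B)] = {
    val queue = new ConcurrentLinkedQueue[A]()
    val tracked = input.map { a =>
      queue.add(a)                     // remember the original before sending it
      a
    }
    // Assumes the evaluator emits exactly one output per input, in order,
    // which is the contract Python UDF evaluation relies on.
    evalExternally(tracked).map { b =>
      (queue.poll(), b)
    }
  }
}

A trivial local evaluator such as (xs: Iterator[Int]) => xs.map(_ * 2) exercises the same one-in/one-out ordering contract and shows why nothing in the pattern requires zipping two RDDs.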