author     Davies Liu <davies@databricks.com>    2016-03-31 16:40:20 -0700
committer  Davies Liu <davies.liu@gmail.com>     2016-03-31 16:40:20 -0700
commit     f0afafdc5dfee80d7e5cd2fc1fa8187def7f262d (patch)
tree       b374ad4a7c98e11f8f85fbd44618422bd4fe6a1b /core/src/main/scala/org/apache
parent     8de201baedc8e839e06098c536ba31b3dafd54b5 (diff)
download   spark-f0afafdc5dfee80d7e5cd2fc1fa8187def7f262d.tar.gz
           spark-f0afafdc5dfee80d7e5cd2fc1fa8187def7f262d.tar.bz2
           spark-f0afafdc5dfee80d7e5cd2fc1fa8187def7f262d.zip
[SPARK-14267] [SQL] [PYSPARK] execute multiple Python UDFs within single batch
## What changes were proposed in this pull request?

This PR supports executing multiple Python UDFs within a single batch and also improves performance.

```python
>>> from pyspark.sql.types import IntegerType
>>> sqlContext.registerFunction("double", lambda x: x * 2, IntegerType())
>>> sqlContext.registerFunction("add", lambda x, y: x + y, IntegerType())
>>> sqlContext.sql("SELECT double(add(1, 2)), add(double(2), 1)").explain(True)
== Parsed Logical Plan ==
'Project [unresolvedalias('double('add(1, 2)), None),unresolvedalias('add('double(2), 1), None)]
+- OneRowRelation$

== Analyzed Logical Plan ==
double(add(1, 2)): int, add(double(2), 1): int
Project [double(add(1, 2))#14,add(double(2), 1)#15]
+- Project [double(add(1, 2))#14,add(double(2), 1)#15]
   +- Project [pythonUDF0#16 AS double(add(1, 2))#14,pythonUDF0#18 AS add(double(2), 1)#15]
      +- EvaluatePython [add(pythonUDF1#17, 1)], [pythonUDF0#18]
         +- EvaluatePython [double(add(1, 2)),double(2)], [pythonUDF0#16,pythonUDF1#17]
            +- OneRowRelation$

== Optimized Logical Plan ==
Project [pythonUDF0#16 AS double(add(1, 2))#14,pythonUDF0#18 AS add(double(2), 1)#15]
+- EvaluatePython [add(pythonUDF1#17, 1)], [pythonUDF0#18]
   +- EvaluatePython [double(add(1, 2)),double(2)], [pythonUDF0#16,pythonUDF1#17]
      +- OneRowRelation$

== Physical Plan ==
WholeStageCodegen
:  +- Project [pythonUDF0#16 AS double(add(1, 2))#14,pythonUDF0#18 AS add(double(2), 1)#15]
:     +- INPUT
+- !BatchPythonEvaluation [add(pythonUDF1#17, 1)], [pythonUDF0#16,pythonUDF1#17,pythonUDF0#18]
   +- !BatchPythonEvaluation [double(add(1, 2)),double(2)], [pythonUDF0#16,pythonUDF1#17]
      +- Scan OneRowRelation[]
```

## How was this patch tested?

Added new tests.

The following script was used to benchmark 1, 2 and 3 UDFs:

```python
df = sqlContext.range(1, 1 << 23, 1, 4)
double = F.udf(lambda x: x * 2, LongType())

print df.select(double(df.id)).count()
print df.select(double(df.id), double(df.id + 1)).count()
print df.select(double(df.id), double(df.id + 1), double(df.id + 2)).count()
```

Here are the results:

N | Before | After | speed up
---- | ------ | ----- | --------
1 | 22 s | 7 s | 3.1X
2 | 38 s | 13 s | 2.9X
3 | 58 s | 16 s | 3.6X

This benchmark ran locally with 4 CPUs. For 3 UDFs, it launched 12 Python processes before this patch and 4 processes after this patch. After this patch, multiple UDFs also use less memory than before (less buffering).

Author: Davies Liu <davies@databricks.com>

Closes #12057 from davies/multi_udfs.
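As a rough illustration of the constructor change in the diff below: each output UDF column becomes one `ChainedPythonFunctions` (its nested calls evaluated bottom-up), all independent columns are handed to a single `PythonRunner`, and `argOffsets` tells each chain which input columns to read. The sketch below is not code from this patch; `mkFunc` and the `pickled*` placeholders are hypothetical stand-ins for building a `PythonFunction` from a pickled command.

```scala
// Sketch only: assumes a hypothetical mkFunc helper; the real PythonFunction also
// carries envVars, includes, exec/version info, broadcast vars and an accumulator.
def mkFunc(command: Array[Byte]): PythonFunction = ???   // hypothetical, not in Spark

val pickledAdd: Array[Byte] = Array.empty[Byte]          // placeholder pickled command
val pickledDouble: Array[Byte] = Array.empty[Byte]       // placeholder pickled command

// Two independent UDF columns, each a bottom-to-top chain:
//   column 0: double(add(c0, c1))  -> chain [add, double], reads input offsets 0 and 1
//   column 1: double(c2)           -> chain [double],      reads input offset 2
val runner = new PythonRunner(
  Seq(
    ChainedPythonFunctions(Seq(mkFunc(pickledAdd), mkFunc(pickledDouble))),
    ChainedPythonFunctions(Seq(mkFunc(pickledDouble)))),
  bufferSize = 65536,
  reuse_worker = true,
  isUDF = true,
  argOffsets = Array(Array(0, 1), Array(2)))
```

The plain `mapPartitions` path keeps its old behaviour through the new `PythonRunner.apply` in the diff, which wraps a single function as a one-element chain with `Array(Array(0))` as its offsets.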
Diffstat (limited to 'core/src/main/scala/org/apache')
-rw-r--r--   core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala   64
1 file changed, 49 insertions(+), 15 deletions(-)
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index 0f579b4ef5..6faa03c12b 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -59,7 +59,7 @@ private[spark] class PythonRDD(
   val asJavaRDD: JavaRDD[Array[Byte]] = JavaRDD.fromRDD(this)
 
   override def compute(split: Partition, context: TaskContext): Iterator[Array[Byte]] = {
-    val runner = new PythonRunner(Seq(func), bufferSize, reuse_worker, false)
+    val runner = PythonRunner(func, bufferSize, reuse_worker)
     runner.compute(firstParent.iterator(split, context), split.index, context)
   }
 }
@@ -78,21 +78,41 @@ private[spark] case class PythonFunction(
     accumulator: Accumulator[JList[Array[Byte]]])
 
 /**
- * A helper class to run Python UDFs in Spark.
+ * A wrapper for chained Python functions (from bottom to top).
+ * @param funcs
+ */
+private[spark] case class ChainedPythonFunctions(funcs: Seq[PythonFunction])
+
+private[spark] object PythonRunner {
+  def apply(func: PythonFunction, bufferSize: Int, reuse_worker: Boolean): PythonRunner = {
+    new PythonRunner(
+      Seq(ChainedPythonFunctions(Seq(func))), bufferSize, reuse_worker, false, Array(Array(0)))
+  }
+}
+
+/**
+ * A helper class to run Python mapPartition/UDFs in Spark.
+ *
+ * funcs is a list of independent Python functions, each one of them is a list of chained Python
+ * functions (from bottom to top).
  */
 private[spark] class PythonRunner(
-    funcs: Seq[PythonFunction],
+    funcs: Seq[ChainedPythonFunctions],
     bufferSize: Int,
     reuse_worker: Boolean,
-    rowBased: Boolean)
+    isUDF: Boolean,
+    argOffsets: Array[Array[Int]])
   extends Logging {
 
+  require(funcs.length == argOffsets.length, "argOffsets should have the same length as funcs")
+
   // All the Python functions should have the same exec, version and envvars.
-  private val envVars = funcs.head.envVars
-  private val pythonExec = funcs.head.pythonExec
-  private val pythonVer = funcs.head.pythonVer
+  private val envVars = funcs.head.funcs.head.envVars
+  private val pythonExec = funcs.head.funcs.head.pythonExec
+  private val pythonVer = funcs.head.funcs.head.pythonVer
 
-  private val accumulator = funcs.head.accumulator // TODO: support accumulator in multiple UDF
+  // TODO: support accumulator in multiple UDF
+  private val accumulator = funcs.head.funcs.head.accumulator
 
   def compute(
       inputIterator: Iterator[_],
@@ -232,8 +252,8 @@ private[spark] class PythonRunner(
     @volatile private var _exception: Exception = null
 
-    private val pythonIncludes = funcs.flatMap(_.pythonIncludes.asScala).toSet
-    private val broadcastVars = funcs.flatMap(_.broadcastVars.asScala)
+    private val pythonIncludes = funcs.flatMap(_.funcs.flatMap(_.pythonIncludes.asScala)).toSet
+    private val broadcastVars = funcs.flatMap(_.funcs.flatMap(_.broadcastVars.asScala))
 
     setDaemon(true)
@@ -284,11 +304,25 @@ private[spark] class PythonRunner(
         }
         dataOut.flush()
         // Serialized command:
-        dataOut.writeInt(if (rowBased) 1 else 0)
-        dataOut.writeInt(funcs.length)
-        funcs.foreach { f =>
-          dataOut.writeInt(f.command.length)
-          dataOut.write(f.command)
+        if (isUDF) {
+          dataOut.writeInt(1)
+          dataOut.writeInt(funcs.length)
+          funcs.zip(argOffsets).foreach { case (chained, offsets) =>
+            dataOut.writeInt(offsets.length)
+            offsets.foreach { offset =>
+              dataOut.writeInt(offset)
+            }
+            dataOut.writeInt(chained.funcs.length)
+            chained.funcs.foreach { f =>
+              dataOut.writeInt(f.command.length)
+              dataOut.write(f.command)
+            }
+          }
+        } else {
+          dataOut.writeInt(0)
+          val command = funcs.head.funcs.head.command
+          dataOut.writeInt(command.length)
+          dataOut.write(command)
         }
         // Data values
         PythonRDD.writeIteratorToStream(inputIterator, dataOut)
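For reference, the `isUDF` branch above frames the header as: a `1` marker, the number of independent chains, then for each chain its argument offsets followed by its chained pickled commands, after which the input rows are streamed via `PythonRDD.writeIteratorToStream`. Below is a minimal decoder sketch for that layout (illustrative only; the actual consumer is the Python worker, which is not part of this file):

```scala
import java.io.DataInputStream

// Sketch: read back the UDF header written by the isUDF branch above.
// Returns, for each independent chain, its argument offsets and chained commands.
def readUDFHeader(in: DataInputStream): Seq[(Array[Int], Seq[Array[Byte]])] = {
  require(in.readInt() == 1, "expected the UDF marker")
  val numChains = in.readInt()                            // number of independent chains
  (0 until numChains).map { _ =>
    val offsets = Array.fill(in.readInt())(in.readInt())  // arg offsets for this chain
    val commands = (0 until in.readInt()).map { _ =>      // chained functions, bottom to top
      val cmd = new Array[Byte](in.readInt())
      in.readFully(cmd)
      cmd
    }
    (offsets, commands)
  }
}
```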