author    | Davies Liu <davies@databricks.com> | 2016-03-31 16:40:20 -0700
committer | Davies Liu <davies.liu@gmail.com>  | 2016-03-31 16:40:20 -0700
commit    | f0afafdc5dfee80d7e5cd2fc1fa8187def7f262d (patch)
tree      | b374ad4a7c98e11f8f85fbd44618422bd4fe6a1b /core/src
parent    | 8de201baedc8e839e06098c536ba31b3dafd54b5 (diff)
[SPARK-14267] [SQL] [PYSPARK] execute multiple Python UDFs within single batch
## What changes were proposed in this pull request?
This PR adds support for executing multiple Python UDFs within a single batch, and also improves performance.
```python
>>> from pyspark.sql.types import IntegerType
>>> sqlContext.registerFunction("double", lambda x: x * 2, IntegerType())
>>> sqlContext.registerFunction("add", lambda x, y: x + y, IntegerType())
>>> sqlContext.sql("SELECT double(add(1, 2)), add(double(2), 1)").explain(True)
== Parsed Logical Plan ==
'Project [unresolvedalias('double('add(1, 2)), None),unresolvedalias('add('double(2), 1), None)]
+- OneRowRelation$
== Analyzed Logical Plan ==
double(add(1, 2)): int, add(double(2), 1): int
Project [double(add(1, 2))#14,add(double(2), 1)#15]
+- Project [double(add(1, 2))#14,add(double(2), 1)#15]
+- Project [pythonUDF0#16 AS double(add(1, 2))#14,pythonUDF0#18 AS add(double(2), 1)#15]
+- EvaluatePython [add(pythonUDF1#17, 1)], [pythonUDF0#18]
+- EvaluatePython [double(add(1, 2)),double(2)], [pythonUDF0#16,pythonUDF1#17]
+- OneRowRelation$
== Optimized Logical Plan ==
Project [pythonUDF0#16 AS double(add(1, 2))#14,pythonUDF0#18 AS add(double(2), 1)#15]
+- EvaluatePython [add(pythonUDF1#17, 1)], [pythonUDF0#18]
+- EvaluatePython [double(add(1, 2)),double(2)], [pythonUDF0#16,pythonUDF1#17]
+- OneRowRelation$
== Physical Plan ==
WholeStageCodegen
: +- Project [pythonUDF0#16 AS double(add(1, 2))#14,pythonUDF0#18 AS add(double(2), 1)#15]
: +- INPUT
+- !BatchPythonEvaluation [add(pythonUDF1#17, 1)], [pythonUDF0#16,pythonUDF1#17,pythonUDF0#18]
+- !BatchPythonEvaluation [double(add(1, 2)),double(2)], [pythonUDF0#16,pythonUDF1#17]
+- Scan OneRowRelation[]
```
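The two `EvaluatePython` batches in the plan are independent round trips to Python: the first evaluates `double(add(1, 2))` and `double(2)` together, and the second evaluates `add(pythonUDF1#17, 1)` on top of the first's output. Within a batch, a nested call such as `double(add(1, 2))` is shipped as a chain of functions applied bottom-to-top. Here is a minimal sketch of that composition, assuming plain Python callables; the `chain` helper is illustrative only, not the actual worker code:
```python
# Illustrative sketch only: how a chain like [add, double] (bottom to top)
# composes into a single callable for double(add(x, y)).
def chain(funcs):
    def chained(*args):
        result = funcs[0](*args)   # the bottom function sees the input columns
        for f in funcs[1:]:        # each later function consumes the previous result
            result = f(result)
        return result
    return chained

add = lambda x, y: x + y
double = lambda x: x * 2
assert chain([add, double])(1, 2) == 6  # double(add(1, 2))
```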
## How was this patch tested?
Added new tests.
The following script was used to benchmark 1, 2, and 3 UDFs:
```python
# Run in the pyspark shell, where sqlContext is predefined
from pyspark.sql import functions as F
from pyspark.sql.types import LongType

df = sqlContext.range(1, 1 << 23, 1, 4)
double = F.udf(lambda x: x * 2, LongType())
print(df.select(double(df.id)).count())
print(df.select(double(df.id), double(df.id + 1)).count())
print(df.select(double(df.id), double(df.id + 1), double(df.id + 2)).count())
```
Here are the results:
N | Before | After | Speedup
--- | --- | --- | ---
1 | 22 s | 7 s | 3.1X
2 | 38 s | 13 s | 2.9X
3 | 58 s | 16 s | 3.6X
This benchmark ran locally with 4 CPUs. For 3 UDFs, it launched 12 Python processes before this patch (4 partitions × 3 UDFs) and 4 processes after it (one per partition). After this patch, it also uses less memory for multiple UDFs than before (less buffering).
Author: Davies Liu <davies@databricks.com>
Closes #12057 from davies/multi_udfs.
Diffstat (limited to 'core/src')
-rw-r--r-- | core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala | 64
1 file changed, 49 insertions, 15 deletions
```diff
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index 0f579b4ef5..6faa03c12b 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -59,7 +59,7 @@ private[spark] class PythonRDD(
   val asJavaRDD: JavaRDD[Array[Byte]] = JavaRDD.fromRDD(this)
 
   override def compute(split: Partition, context: TaskContext): Iterator[Array[Byte]] = {
-    val runner = new PythonRunner(Seq(func), bufferSize, reuse_worker, false)
+    val runner = PythonRunner(func, bufferSize, reuse_worker)
     runner.compute(firstParent.iterator(split, context), split.index, context)
   }
 }
@@ -78,21 +78,41 @@ private[spark] case class PythonFunction(
     accumulator: Accumulator[JList[Array[Byte]]])
 
 /**
- * A helper class to run Python UDFs in Spark.
+ * A wrapper for chained Python functions (from bottom to top).
+ * @param funcs
+ */
+private[spark] case class ChainedPythonFunctions(funcs: Seq[PythonFunction])
+
+private[spark] object PythonRunner {
+  def apply(func: PythonFunction, bufferSize: Int, reuse_worker: Boolean): PythonRunner = {
+    new PythonRunner(
+      Seq(ChainedPythonFunctions(Seq(func))), bufferSize, reuse_worker, false, Array(Array(0)))
+  }
+}
+
+/**
+ * A helper class to run Python mapPartition/UDFs in Spark.
+ *
+ * funcs is a list of independent Python functions, each one of them is a list of chained Python
+ * functions (from bottom to top).
  */
 private[spark] class PythonRunner(
-    funcs: Seq[PythonFunction],
+    funcs: Seq[ChainedPythonFunctions],
     bufferSize: Int,
     reuse_worker: Boolean,
-    rowBased: Boolean)
+    isUDF: Boolean,
+    argOffsets: Array[Array[Int]])
   extends Logging {
 
+  require(funcs.length == argOffsets.length, "argOffsets should have the same length as funcs")
+
   // All the Python functions should have the same exec, version and envvars.
-  private val envVars = funcs.head.envVars
-  private val pythonExec = funcs.head.pythonExec
-  private val pythonVer = funcs.head.pythonVer
+  private val envVars = funcs.head.funcs.head.envVars
+  private val pythonExec = funcs.head.funcs.head.pythonExec
+  private val pythonVer = funcs.head.funcs.head.pythonVer
 
-  private val accumulator = funcs.head.accumulator // TODO: support accumulator in multiple UDF
+  // TODO: support accumulator in multiple UDF
+  private val accumulator = funcs.head.funcs.head.accumulator
 
   def compute(
       inputIterator: Iterator[_],
@@ -232,8 +252,8 @@ private[spark] class PythonRunner(
 
     @volatile private var _exception: Exception = null
 
-    private val pythonIncludes = funcs.flatMap(_.pythonIncludes.asScala).toSet
-    private val broadcastVars = funcs.flatMap(_.broadcastVars.asScala)
+    private val pythonIncludes = funcs.flatMap(_.funcs.flatMap(_.pythonIncludes.asScala)).toSet
+    private val broadcastVars = funcs.flatMap(_.funcs.flatMap(_.broadcastVars.asScala))
 
     setDaemon(true)
 
@@ -284,11 +304,25 @@ private[spark] class PythonRunner(
         }
         dataOut.flush()
         // Serialized command:
-        dataOut.writeInt(if (rowBased) 1 else 0)
-        dataOut.writeInt(funcs.length)
-        funcs.foreach { f =>
-          dataOut.writeInt(f.command.length)
-          dataOut.write(f.command)
+        if (isUDF) {
+          dataOut.writeInt(1)
+          dataOut.writeInt(funcs.length)
+          funcs.zip(argOffsets).foreach { case (chained, offsets) =>
+            dataOut.writeInt(offsets.length)
+            offsets.foreach { offset =>
+              dataOut.writeInt(offset)
+            }
+            dataOut.writeInt(chained.funcs.length)
+            chained.funcs.foreach { f =>
+              dataOut.writeInt(f.command.length)
+              dataOut.write(f.command)
+            }
+          }
+        } else {
+          dataOut.writeInt(0)
+          val command = funcs.head.funcs.head.command
+          dataOut.writeInt(command.length)
+          dataOut.write(command)
         }
         // Data values
         PythonRDD.writeIteratorToStream(inputIterator, dataOut)
```
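For reference, here is a hedged sketch of how the worker side could decode the UDF header that the `isUDF` branch above writes (`DataOutputStream.writeInt` emits big-endian 4-byte ints). The reader below is hypothetical; the matching changes to `pyspark/worker.py` are outside this diffstat, which is limited to `core/src`:
```python
import struct

def read_int(stream):
    # DataOutputStream.writeInt emits a big-endian 4-byte int
    return struct.unpack(">i", stream.read(4))[0]

def read_udfs(stream):
    # Mirrors the isUDF branch above; assumes the leading flag (1) was consumed
    udfs = []
    for _ in range(read_int(stream)):  # number of independent UDFs
        # input-column offsets for this chain: count, then each offset
        arg_offsets = [read_int(stream) for _ in range(read_int(stream))]
        # pickled commands, bottom to top: count, then length-prefixed bytes
        chained = [stream.read(read_int(stream))
                   for _ in range(read_int(stream))]
        udfs.append((arg_offsets, chained))
    return udfs
```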